mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-14 03:30:19 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a9e626ea3c |
+5
-1
@@ -71,6 +71,7 @@
|
||||
"prepare": "husky",
|
||||
"prettier": "prettier -c --write \"**/**\"",
|
||||
"pull": "git pull",
|
||||
"qstash": "pnpx @upstash/qstash-cli@latest dev",
|
||||
"reinstall": "rm -rf pnpm-lock.yaml && rm -rf node_modules && pnpm -r exec rm -rf node_modules && pnpm install",
|
||||
"reinstall:desktop": "rm -rf pnpm-lock.yaml && rm -rf node_modules && pnpm -r exec rm -rf node_modules && pnpm install --node-linker=hoisted",
|
||||
"release": "semantic-release",
|
||||
@@ -84,8 +85,9 @@
|
||||
"test:e2e": "pnpm --filter @lobechat/e2e-tests test",
|
||||
"test:e2e:smoke": "pnpm --filter @lobechat/e2e-tests test:smoke",
|
||||
"test:update": "vitest -u",
|
||||
"tunnel:cloudflare": "cloudflared tunnel --url http://localhost:3010",
|
||||
"tunnel:ngrok": "ngrok http http://localhost:3011",
|
||||
"type-check": "tsgo --noEmit",
|
||||
"webhook:ngrok": "ngrok http http://localhost:3011",
|
||||
"workflow:cdn": "tsx ./scripts/cdnWorkflow/index.ts",
|
||||
"workflow:changelog": "tsx ./scripts/changelogWorkflow/index.ts",
|
||||
"workflow:countCharters": "tsx scripts/countEnWord.ts",
|
||||
@@ -186,6 +188,7 @@
|
||||
"@trpc/next": "^11.7.1",
|
||||
"@trpc/react-query": "^11.7.1",
|
||||
"@trpc/server": "^11.7.1",
|
||||
"@upstash/qstash": "^2.8.2",
|
||||
"@vercel/analytics": "^1.5.0",
|
||||
"@vercel/edge-config": "^1.4.3",
|
||||
"@vercel/functions": "^3.3.2",
|
||||
@@ -218,6 +221,7 @@
|
||||
"i18next-browser-languagedetector": "^8.2.0",
|
||||
"i18next-resources-to-backend": "^1.2.1",
|
||||
"immer": "^10.2.0",
|
||||
"ioredis": "^5.7.0",
|
||||
"jose": "^5.10.0",
|
||||
"js-sha256": "^0.11.1",
|
||||
"jsonl-parse-stringify": "^1.0.3",
|
||||
|
||||
@@ -8,6 +8,7 @@ import { LobeAgentTTSConfig } from './tts';
|
||||
|
||||
export interface LobeAgentConfig {
|
||||
chatConfig: LobeAgentChatConfig;
|
||||
enableAgentMode?: boolean;
|
||||
fewShots?: FewShots;
|
||||
files?: FileItem[];
|
||||
id?: string;
|
||||
|
||||
@@ -64,6 +64,8 @@ export const LobeChatPluginApiSchema = z.object({
|
||||
url: z.string().optional(),
|
||||
});
|
||||
|
||||
import type { HumanInterventionPolicy } from './intervention';
|
||||
|
||||
export interface BuiltinToolManifest {
|
||||
api: LobeChatPluginApi[];
|
||||
|
||||
|
||||
@@ -3,4 +3,5 @@ export * from './correctOIDCUrl';
|
||||
export * from './geo';
|
||||
export * from './response';
|
||||
export * from './responsive';
|
||||
export * from './sse';
|
||||
export * from './xor';
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
/**
|
||||
* SSE (Server-Sent Events) utilities for agent streaming
|
||||
*/
|
||||
|
||||
export interface SSEEvent {
|
||||
data: any;
|
||||
event?: string;
|
||||
id?: string;
|
||||
retry?: number;
|
||||
}
|
||||
|
||||
/**
|
||||
* Formats data into SSE format with id/event/data structure
|
||||
* @param event - The SSE event configuration
|
||||
* @returns Formatted SSE string
|
||||
*/
|
||||
export function formatSSEEvent({ id, event, data, retry }: SSEEvent): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
if (id !== undefined) {
|
||||
lines.push(`id: ${id}`);
|
||||
}
|
||||
|
||||
if (event !== undefined) {
|
||||
lines.push(`event: ${event}`);
|
||||
}
|
||||
|
||||
if (retry !== undefined) {
|
||||
lines.push(`retry: ${retry}`);
|
||||
}
|
||||
|
||||
// Handle data serialization
|
||||
const dataString = typeof data === 'string' ? data : JSON.stringify(data);
|
||||
|
||||
// Split multi-line data and prefix each line with "data: "
|
||||
const dataLines = dataString.split('\n');
|
||||
dataLines.forEach((line) => {
|
||||
lines.push(`data: ${line}`);
|
||||
});
|
||||
|
||||
// End with double newline
|
||||
lines.push('', '');
|
||||
|
||||
return lines.join('\n');
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a utility for enqueueing SSE events to a ReadableStream controller
|
||||
* @param controller - The ReadableStreamDefaultController
|
||||
* @returns Helper function for sending SSE events
|
||||
*/
|
||||
export function createSSEWriter(controller: ReadableStreamDefaultController<string>) {
|
||||
return {
|
||||
/**
|
||||
* Send a connection event
|
||||
*/
|
||||
writeConnection(sessionId: string, lastEventId: string, timestamp: number = Date.now()) {
|
||||
this.writeEvent({
|
||||
data: {
|
||||
lastEventId,
|
||||
sessionId,
|
||||
timestamp,
|
||||
type: 'connected',
|
||||
},
|
||||
event: 'connected',
|
||||
id: `conn_${timestamp}`,
|
||||
});
|
||||
},
|
||||
|
||||
/**
|
||||
* Send an error event
|
||||
*/
|
||||
writeError(error: any, sessionId: string, phase?: string, timestamp: number = Date.now()) {
|
||||
this.writeEvent({
|
||||
data: {
|
||||
data: {
|
||||
error: error.message || String(error),
|
||||
phase: phase || 'unknown',
|
||||
...(error.stack && { stack: error.stack }),
|
||||
},
|
||||
sessionId,
|
||||
timestamp,
|
||||
type: 'error',
|
||||
},
|
||||
event: 'error',
|
||||
id: `error_${timestamp}`,
|
||||
});
|
||||
},
|
||||
|
||||
/**
|
||||
* Send an SSE event
|
||||
*/
|
||||
writeEvent(event: SSEEvent) {
|
||||
controller.enqueue(formatSSEEvent(event));
|
||||
},
|
||||
|
||||
/**
|
||||
* Send a heartbeat/keep-alive event
|
||||
*/
|
||||
writeHeartbeat(timestamp: number = Date.now()) {
|
||||
this.writeEvent({
|
||||
data: {
|
||||
timestamp,
|
||||
type: 'heartbeat',
|
||||
},
|
||||
event: 'heartbeat',
|
||||
id: `heartbeat_${timestamp}`,
|
||||
});
|
||||
},
|
||||
|
||||
/**
|
||||
* Send a stream event (for historical or real-time events)
|
||||
*/
|
||||
writeStreamEvent(eventData: any, eventId?: string) {
|
||||
this.writeEvent({
|
||||
data: eventData,
|
||||
event: eventData.type || 'stream',
|
||||
id: eventId || `event_${Date.now()}`,
|
||||
});
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Agent stream event types
|
||||
*/
|
||||
export type AgentStreamEventType =
|
||||
| 'connected'
|
||||
| 'stream'
|
||||
| 'error'
|
||||
| 'heartbeat'
|
||||
| 'stream_start'
|
||||
| 'stream_chunk'
|
||||
| 'stream_end'
|
||||
| 'stream_error';
|
||||
|
||||
/**
|
||||
* Creates SSE headers for agent streaming
|
||||
*/
|
||||
// eslint-disable-next-line no-undef
|
||||
export function createSSEHeaders(): HeadersInit {
|
||||
return {
|
||||
'Access-Control-Allow-Headers': 'Cache-Control, Last-Event-ID',
|
||||
'Access-Control-Allow-Methods': 'GET',
|
||||
'Access-Control-Allow-Origin': '*',
|
||||
'Cache-Control': 'no-cache, no-transform',
|
||||
'Connection': 'keep-alive',
|
||||
'Content-Type': 'text/event-stream',
|
||||
'X-Accel-Buffering': 'no',
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
import { serverDBEnv } from '@/config/db';
|
||||
import { isServerMode } from '@/const/version';
|
||||
|
||||
export const isEnableAgent = (): boolean => {
|
||||
if (!isServerMode) return false;
|
||||
|
||||
if (!serverDBEnv.REDIS_URL) return false;
|
||||
|
||||
if (!serverDBEnv.QSTASH_TOKEN) return false;
|
||||
|
||||
// TODO: V2 的 DB 版本默认需要 REDIS 和 QSTASH_TOKEN 了
|
||||
return true;
|
||||
};
|
||||
@@ -0,0 +1,115 @@
|
||||
import debug from 'debug';
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
|
||||
import { getServerDB } from '@/database/core/db-adaptor';
|
||||
import { AgentRuntimeService } from '@/server/services/agentRuntime';
|
||||
|
||||
import { isEnableAgent } from '../isEnableAgent';
|
||||
|
||||
const log = debug('api-route:agent:execute-step');
|
||||
|
||||
export async function POST(request: NextRequest) {
|
||||
if (!isEnableAgent()) {
|
||||
return NextResponse.json({ error: 'Agent features are not enabled' }, { status: 404 });
|
||||
}
|
||||
|
||||
// Initialize service
|
||||
const serverDB = await getServerDB();
|
||||
// TODO: remove userId
|
||||
const agentRuntimeService = new AgentRuntimeService(serverDB, 'github|28616219');
|
||||
|
||||
const startTime = Date.now();
|
||||
|
||||
const body = await request.json();
|
||||
try {
|
||||
const {
|
||||
sessionId,
|
||||
stepIndex = 0,
|
||||
context,
|
||||
humanInput,
|
||||
approvedToolCall,
|
||||
rejectionReason,
|
||||
} = body;
|
||||
|
||||
if (!sessionId) {
|
||||
return NextResponse.json({ error: 'sessionId is required' }, { status: 400 });
|
||||
}
|
||||
|
||||
log(`[${sessionId}] Starting step ${stepIndex}`);
|
||||
|
||||
// 使用 AgentRuntimeService 执行步骤
|
||||
const result = await agentRuntimeService.executeStep({
|
||||
approvedToolCall,
|
||||
context,
|
||||
humanInput,
|
||||
rejectionReason,
|
||||
sessionId,
|
||||
stepIndex,
|
||||
});
|
||||
|
||||
const executionTime = Date.now() - startTime;
|
||||
|
||||
const responseData = {
|
||||
completed: result.state.status === 'done',
|
||||
error: result.state.status === 'error' ? result.state.error : undefined,
|
||||
executionTime,
|
||||
nextStepIndex: result.nextStepScheduled ? stepIndex + 1 : undefined,
|
||||
nextStepScheduled: result.nextStepScheduled,
|
||||
pendingApproval: result.state.pendingToolsCalling,
|
||||
pendingPrompt: result.state.pendingHumanPrompt,
|
||||
pendingSelect: result.state.pendingHumanSelect,
|
||||
sessionId,
|
||||
status: result.state.status,
|
||||
stepIndex,
|
||||
success: result.success,
|
||||
totalCost: result.state.cost?.total || 0,
|
||||
totalSteps: result.state.stepCount,
|
||||
waitingForHuman: result.state.status === 'waiting_for_human',
|
||||
};
|
||||
|
||||
log(
|
||||
`[${sessionId}] Step ${stepIndex} completed (${executionTime}ms, status: ${result.state.status})`,
|
||||
);
|
||||
|
||||
return NextResponse.json(responseData);
|
||||
} catch (error: any) {
|
||||
const executionTime = Date.now() - startTime;
|
||||
console.error('Error in execution: %O', error);
|
||||
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: error.message,
|
||||
executionTime,
|
||||
sessionId: body?.sessionId,
|
||||
stepIndex: body?.stepIndex || 0,
|
||||
},
|
||||
{ status: 500 },
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 健康检查端点
|
||||
*/
|
||||
export async function GET() {
|
||||
if (!isEnableAgent()) {
|
||||
return NextResponse.json({ error: 'Agent features are not enabled' }, { status: 404 });
|
||||
}
|
||||
|
||||
try {
|
||||
return NextResponse.json({
|
||||
healthy: true,
|
||||
message: 'Agent execution service is running',
|
||||
timestamp: new Date().toISOString(),
|
||||
});
|
||||
} catch (error: any) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: error.message,
|
||||
healthy: false,
|
||||
timestamp: new Date().toISOString(),
|
||||
},
|
||||
{ status: 503 },
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,654 @@
|
||||
// @vitest-environment node
|
||||
import { NextRequest } from 'next/server';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { StreamEventManager } from '@/server/modules/AgentRuntime';
|
||||
|
||||
import * as isEnableAgentModule from '../../isEnableAgent';
|
||||
import { GET } from '../route';
|
||||
|
||||
// Mock dependencies first
|
||||
const mockStreamEventManager = {
|
||||
getStreamHistory: vi.fn(),
|
||||
subscribeStreamEvents: vi.fn(),
|
||||
};
|
||||
|
||||
vi.mock('@/server/modules/AgentRuntime', () => ({
|
||||
StreamEventManager: vi.fn(() => mockStreamEventManager),
|
||||
}));
|
||||
|
||||
describe('/api/agent/stream route', () => {
|
||||
const isEnableAgentSpy = vi.spyOn(isEnableAgentModule, 'isEnableAgent');
|
||||
const MOCK_TIMESTAMP = 1758203237000;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks();
|
||||
// Default to enabled for most tests
|
||||
isEnableAgentSpy.mockReturnValue(true);
|
||||
// Mock Date.now to return consistent timestamp
|
||||
vi.spyOn(Date, 'now').mockReturnValue(MOCK_TIMESTAMP);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('GET handler', () => {
|
||||
it('should return 404 when agent features are not enabled', async () => {
|
||||
isEnableAgentSpy.mockReturnValue(false);
|
||||
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=test-session');
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(404);
|
||||
const data = await response.json();
|
||||
expect(data.error).toBe('Agent features are not enabled');
|
||||
});
|
||||
|
||||
it('should return 400 when sessionId parameter is missing', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream');
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
const data = await response.json();
|
||||
expect(data.error).toBe('sessionId parameter is required');
|
||||
});
|
||||
|
||||
it('should return SSE stream with correct headers when sessionId is provided', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=test-session');
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.headers.get('Content-Type')).toBe('text/event-stream');
|
||||
expect(response.headers.get('Cache-Control')).toBe('no-cache, no-transform');
|
||||
expect(response.headers.get('Connection')).toBe('keep-alive');
|
||||
expect(response.headers.get('Access-Control-Allow-Origin')).toBe('*');
|
||||
expect(response.headers.get('Access-Control-Allow-Methods')).toBe('GET');
|
||||
expect(response.headers.get('Access-Control-Allow-Headers')).toBe(
|
||||
'Cache-Control, Last-Event-ID',
|
||||
);
|
||||
expect(response.headers.get('X-Accel-Buffering')).toBe('no');
|
||||
});
|
||||
});
|
||||
|
||||
describe('Stream functionality with exact data verification', () => {
|
||||
it('should send connection event in exact SSE format', async () => {
|
||||
const request = new NextRequest(
|
||||
'https://test.com/api/agent/stream?sessionId=test-session&lastEventId=123',
|
||||
);
|
||||
|
||||
const response = await GET(request);
|
||||
const decoder = new TextDecoder();
|
||||
const reader = response.body!.getReader();
|
||||
|
||||
// Collect all chunks
|
||||
const chunks = [];
|
||||
let readCount = 0;
|
||||
const maxReads = 1; // Only read connection event
|
||||
|
||||
try {
|
||||
while (readCount < maxReads) {
|
||||
const readPromise = reader.read();
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Read timeout')), 1000),
|
||||
);
|
||||
|
||||
const result = (await Promise.race([
|
||||
readPromise,
|
||||
timeoutPromise,
|
||||
])) as ReadableStreamReadResult<Uint8Array>;
|
||||
|
||||
if (result.done) break;
|
||||
if (result.value) {
|
||||
const chunk =
|
||||
result.value instanceof Uint8Array
|
||||
? decoder.decode(result.value)
|
||||
: String(result.value);
|
||||
chunks.push(chunk);
|
||||
readCount++;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Timeout or error
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
// Verify exact stream format with mocked timestamp (new SSE format)
|
||||
expect(chunks).toEqual([
|
||||
`id: conn_${MOCK_TIMESTAMP}\nevent: connected\ndata: {"lastEventId":"123","sessionId":"test-session","timestamp":${MOCK_TIMESTAMP},"type":"connected"}\n\n`,
|
||||
]);
|
||||
});
|
||||
|
||||
it('should verify getStreamHistory with exact historical events format', async () => {
|
||||
const request = new NextRequest(
|
||||
'https://test.com/api/agent/stream?sessionId=test-session&includeHistory=true&lastEventId=100',
|
||||
);
|
||||
|
||||
// Mock getStreamHistory to return specific events
|
||||
const mockEvents = [
|
||||
{
|
||||
type: 'stream_end',
|
||||
timestamp: 300,
|
||||
sessionId: 'test-session',
|
||||
data: { messageId: 'msg3' },
|
||||
},
|
||||
{
|
||||
type: 'stream_chunk',
|
||||
timestamp: 250,
|
||||
sessionId: 'test-session',
|
||||
data: { content: 'world' },
|
||||
},
|
||||
{
|
||||
type: 'stream_start',
|
||||
timestamp: 150,
|
||||
sessionId: 'test-session',
|
||||
data: { messageId: 'msg1' },
|
||||
},
|
||||
];
|
||||
mockStreamEventManager.getStreamHistory.mockResolvedValue(mockEvents);
|
||||
|
||||
const response = await GET(request);
|
||||
const decoder = new TextDecoder();
|
||||
const reader = response.body!.getReader();
|
||||
|
||||
// Collect all chunks
|
||||
const chunks = [];
|
||||
let readCount = 0;
|
||||
const maxReads = 3; // connection + 2 filtered historical events (timestamp > 100)
|
||||
|
||||
try {
|
||||
while (readCount < maxReads) {
|
||||
const readPromise = reader.read();
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Read timeout')), 500),
|
||||
);
|
||||
|
||||
const result = (await Promise.race([
|
||||
readPromise,
|
||||
timeoutPromise,
|
||||
])) as ReadableStreamReadResult<Uint8Array>;
|
||||
|
||||
if (result.done) break;
|
||||
if (result.value) {
|
||||
const chunk =
|
||||
result.value instanceof Uint8Array
|
||||
? decoder.decode(result.value)
|
||||
: String(result.value);
|
||||
chunks.push(chunk);
|
||||
readCount++;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Timeout or error
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
// Verify exact stream format - connection event + filtered historical events (new SSE format)
|
||||
expect(chunks).toEqual([
|
||||
`id: conn_${MOCK_TIMESTAMP}\nevent: connected\ndata: {"lastEventId":"100","sessionId":"test-session","timestamp":${MOCK_TIMESTAMP},"type":"connected"}\n\n`,
|
||||
`id: test-session\nevent: stream_start\ndata: {"type":"stream_start","timestamp":150,"sessionId":"test-session","data":{"messageId":"msg1"}}\n\n`,
|
||||
`id: test-session\nevent: stream_chunk\ndata: {"type":"stream_chunk","timestamp":250,"sessionId":"test-session","data":{"content":"world"}}\n\n`,
|
||||
]);
|
||||
|
||||
// Verify API calls
|
||||
expect(mockStreamEventManager.getStreamHistory).toHaveBeenCalledWith('test-session', 50);
|
||||
});
|
||||
|
||||
it('should verify event filtering with exact format', async () => {
|
||||
const request = new NextRequest(
|
||||
'https://test.com/api/agent/stream?sessionId=test-session&includeHistory=true&lastEventId=200',
|
||||
);
|
||||
|
||||
// Mock events where some should be filtered out
|
||||
const mockEvents = [
|
||||
{
|
||||
type: 'stream_end',
|
||||
timestamp: 300,
|
||||
sessionId: 'test-session',
|
||||
data: { messageId: 'msg3' },
|
||||
}, // Should be included (300 > 200)
|
||||
{
|
||||
type: 'stream_chunk',
|
||||
timestamp: 250,
|
||||
sessionId: 'test-session',
|
||||
data: { content: 'world' },
|
||||
}, // Should be included (250 > 200)
|
||||
{
|
||||
type: 'stream_chunk',
|
||||
timestamp: 200,
|
||||
sessionId: 'test-session',
|
||||
data: { content: 'hello' },
|
||||
}, // Should be excluded (200 = 200)
|
||||
{
|
||||
type: 'stream_start',
|
||||
timestamp: 150,
|
||||
sessionId: 'test-session',
|
||||
data: { messageId: 'msg1' },
|
||||
}, // Should be excluded (150 < 200)
|
||||
];
|
||||
mockStreamEventManager.getStreamHistory.mockResolvedValue(mockEvents);
|
||||
|
||||
const response = await GET(request);
|
||||
const decoder = new TextDecoder();
|
||||
const reader = response.body!.getReader();
|
||||
|
||||
// Collect all chunks
|
||||
const chunks = [];
|
||||
let readCount = 0;
|
||||
const maxReads = 3; // connection + 2 filtered events
|
||||
|
||||
try {
|
||||
while (readCount < maxReads) {
|
||||
const readPromise = reader.read();
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Read timeout')), 500),
|
||||
);
|
||||
|
||||
const result = (await Promise.race([
|
||||
readPromise,
|
||||
timeoutPromise,
|
||||
])) as ReadableStreamReadResult<Uint8Array>;
|
||||
|
||||
if (result.done) break;
|
||||
if (result.value) {
|
||||
const chunk =
|
||||
result.value instanceof Uint8Array
|
||||
? decoder.decode(result.value)
|
||||
: String(result.value);
|
||||
chunks.push(chunk);
|
||||
readCount++;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Timeout or error
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
// Verify exact stream format - only events with timestamp > 200 are included (new SSE format)
|
||||
// Note: indices are based on original array position, not filtered position
|
||||
expect(chunks).toEqual([
|
||||
`id: conn_${MOCK_TIMESTAMP}
|
||||
event: connected
|
||||
data: {"lastEventId":"200","sessionId":"test-session","timestamp":${MOCK_TIMESTAMP},"type":"connected"}
|
||||
|
||||
`,
|
||||
`id: test-session
|
||||
event: stream_chunk
|
||||
data: {"type":"stream_chunk","timestamp":250,"sessionId":"test-session","data":{"content":"world"}}
|
||||
|
||||
`,
|
||||
`id: test-session
|
||||
event: stream_end
|
||||
data: {"type":"stream_end","timestamp":300,"sessionId":"test-session","data":{"messageId":"msg3"}}
|
||||
\n`,
|
||||
]);
|
||||
|
||||
// Verify API calls
|
||||
expect(mockStreamEventManager.getStreamHistory).toHaveBeenCalledWith('test-session', 50);
|
||||
});
|
||||
|
||||
it('should handle errors with exact error event format', async () => {
|
||||
const request = new NextRequest(
|
||||
'https://test.com/api/agent/stream?sessionId=test-session&includeHistory=true',
|
||||
);
|
||||
|
||||
// Mock getStreamHistory to reject
|
||||
mockStreamEventManager.getStreamHistory.mockRejectedValue(
|
||||
new Error('Redis connection failed'),
|
||||
);
|
||||
|
||||
const response = await GET(request);
|
||||
const decoder = new TextDecoder();
|
||||
const reader = response.body!.getReader();
|
||||
|
||||
// Collect all chunks
|
||||
const chunks = [];
|
||||
let readCount = 0;
|
||||
const maxReads = 2; // connection + error event
|
||||
|
||||
try {
|
||||
while (readCount < maxReads) {
|
||||
const readPromise = reader.read();
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Read timeout')), 500),
|
||||
);
|
||||
|
||||
const result = (await Promise.race([
|
||||
readPromise,
|
||||
timeoutPromise,
|
||||
])) as ReadableStreamReadResult<Uint8Array>;
|
||||
|
||||
if (result.done) break;
|
||||
if (result.value) {
|
||||
const chunk =
|
||||
result.value instanceof Uint8Array
|
||||
? decoder.decode(result.value)
|
||||
: String(result.value);
|
||||
chunks.push(chunk);
|
||||
readCount++;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Timeout or error
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
// Verify exact stream format - connection event + error event (new SSE format)
|
||||
// Parse error event to check format (error includes stack trace dynamically)
|
||||
const errorChunk = chunks[1];
|
||||
expect(errorChunk).toMatch(/^id: error_\d+\nevent: error\ndata: \{.*"type":"error".*\}\n\n$/);
|
||||
expect(errorChunk).toContain('"error":"Redis connection failed"');
|
||||
expect(errorChunk).toContain('"phase":"history_loading"');
|
||||
expect(errorChunk).toContain('"sessionId":"test-session"');
|
||||
expect(errorChunk).toContain(`"timestamp":${MOCK_TIMESTAMP}`);
|
||||
|
||||
// Verify connection event format
|
||||
expect(chunks[0]).toEqual(
|
||||
`id: conn_${MOCK_TIMESTAMP}\nevent: connected\ndata: {"lastEventId":"0","sessionId":"test-session","timestamp":${MOCK_TIMESTAMP},"type":"connected"}\n\n`,
|
||||
);
|
||||
|
||||
// Verify getStreamHistory was called
|
||||
expect(mockStreamEventManager.getStreamHistory).toHaveBeenCalledWith('test-session', 50);
|
||||
});
|
||||
|
||||
it('should verify stream subscription with exact parameters', async () => {
|
||||
const request = new NextRequest(
|
||||
'https://test.com/api/agent/stream?sessionId=test-session&lastEventId=456',
|
||||
);
|
||||
|
||||
mockStreamEventManager.subscribeStreamEvents.mockResolvedValue(undefined);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Verify exact parameter passing
|
||||
expect(mockStreamEventManager.subscribeStreamEvents).toHaveBeenCalledWith(
|
||||
'test-session',
|
||||
'456',
|
||||
expect.any(Function), // callback function
|
||||
expect.any(AbortSignal), // abort signal
|
||||
);
|
||||
|
||||
// Verify the callback function structure
|
||||
const callArgs = mockStreamEventManager.subscribeStreamEvents.mock.calls[0];
|
||||
expect(callArgs).toHaveLength(4);
|
||||
expect(typeof callArgs[2]).toBe('function'); // callback
|
||||
expect(callArgs[3]).toBeInstanceOf(AbortSignal); // signal
|
||||
});
|
||||
|
||||
it('should verify default parameters with exact values', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=test-session');
|
||||
|
||||
mockStreamEventManager.subscribeStreamEvents.mockResolvedValue(undefined);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Verify default values are used
|
||||
expect(mockStreamEventManager.subscribeStreamEvents).toHaveBeenCalledWith(
|
||||
'test-session',
|
||||
'0', // default lastEventId
|
||||
expect.any(Function),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
|
||||
// Verify getStreamHistory is NOT called when includeHistory defaults to false
|
||||
expect(mockStreamEventManager.getStreamHistory).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should verify SSE message structure with exact format specification', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=test-session');
|
||||
|
||||
const response = await GET(request);
|
||||
const decoder = new TextDecoder();
|
||||
const reader = response.body!.getReader();
|
||||
|
||||
// Collect all chunks
|
||||
const chunks = [];
|
||||
let readCount = 0;
|
||||
const maxReads = 1; // Only read connection event
|
||||
|
||||
try {
|
||||
while (readCount < maxReads) {
|
||||
const readPromise = reader.read();
|
||||
const timeoutPromise = new Promise((_, reject) =>
|
||||
setTimeout(() => reject(new Error('Read timeout')), 1000),
|
||||
);
|
||||
|
||||
const result = (await Promise.race([
|
||||
readPromise,
|
||||
timeoutPromise,
|
||||
])) as ReadableStreamReadResult<Uint8Array>;
|
||||
|
||||
if (result.done) break;
|
||||
if (result.value) {
|
||||
const chunk =
|
||||
result.value instanceof Uint8Array
|
||||
? decoder.decode(result.value)
|
||||
: String(result.value);
|
||||
chunks.push(chunk);
|
||||
readCount++;
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// Timeout or error
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
|
||||
// Verify exact stream format with default lastEventId (new SSE format)
|
||||
expect(chunks).toEqual([
|
||||
`id: conn_${MOCK_TIMESTAMP}\nevent: connected\ndata: {"lastEventId":"0","sessionId":"test-session","timestamp":${MOCK_TIMESTAMP},"type":"connected"}\n\n`,
|
||||
]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('Agent Runtime Lifecycle', () => {
|
||||
it('should verify agent runtime event handling and connection closure logic', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=test-session');
|
||||
|
||||
// Capture the event callback so we can test the event processing logic directly
|
||||
let capturedCallback: ((events: any[]) => void) | null = null;
|
||||
let capturedSignal: AbortSignal | null = null;
|
||||
|
||||
mockStreamEventManager.subscribeStreamEvents.mockImplementation(
|
||||
(sessionId, lastEventId, callback, signal) => {
|
||||
capturedCallback = callback;
|
||||
capturedSignal = signal;
|
||||
return Promise.resolve();
|
||||
},
|
||||
);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
// Verify the subscription was set up correctly
|
||||
expect(mockStreamEventManager.subscribeStreamEvents).toHaveBeenCalledWith(
|
||||
'test-session',
|
||||
'0',
|
||||
expect.any(Function),
|
||||
expect.any(AbortSignal),
|
||||
);
|
||||
expect(capturedCallback).toBeDefined();
|
||||
expect(capturedSignal).toBeDefined();
|
||||
|
||||
// Verify response headers are correct
|
||||
expect(response.status).toBe(200);
|
||||
expect(response.headers.get('Content-Type')).toBe('text/event-stream');
|
||||
|
||||
// Test that the callback exists and can be called
|
||||
expect(typeof capturedCallback).toBe('function');
|
||||
expect(capturedSignal).toBeInstanceOf(AbortSignal);
|
||||
});
|
||||
|
||||
it('should verify subscribeStreamEvents callback can handle agent_runtime_init events', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=test-session');
|
||||
|
||||
let capturedCallback: ((events: any[]) => void) | null = null;
|
||||
|
||||
mockStreamEventManager.subscribeStreamEvents.mockImplementation(
|
||||
(sessionId, lastEventId, callback, signal) => {
|
||||
capturedCallback = callback;
|
||||
return Promise.resolve();
|
||||
},
|
||||
);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
// Verify we captured the callback
|
||||
expect(capturedCallback).toBeDefined();
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Test agent_runtime_init event processing
|
||||
const initEvent = {
|
||||
type: 'agent_runtime_init',
|
||||
timestamp: MOCK_TIMESTAMP + 100,
|
||||
sessionId: 'test-session',
|
||||
data: {
|
||||
userId: 'user-123',
|
||||
modelConfig: { model: 'gpt-4', temperature: 0.7 },
|
||||
agentType: 'assistant',
|
||||
},
|
||||
};
|
||||
|
||||
// The callback should be callable without throwing errors
|
||||
expect(() => capturedCallback!([initEvent])).not.toThrow();
|
||||
});
|
||||
|
||||
it('should verify subscribeStreamEvents callback can handle agent_runtime_end events', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=test-session');
|
||||
|
||||
let capturedCallback: ((events: any[]) => void) | null = null;
|
||||
|
||||
mockStreamEventManager.subscribeStreamEvents.mockImplementation(
|
||||
(sessionId, lastEventId, callback, signal) => {
|
||||
capturedCallback = callback;
|
||||
return Promise.resolve();
|
||||
},
|
||||
);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
// Verify we captured the callback
|
||||
expect(capturedCallback).toBeDefined();
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Test agent_runtime_end event processing
|
||||
const endEvent = {
|
||||
type: 'agent_runtime_end',
|
||||
timestamp: MOCK_TIMESTAMP + 600,
|
||||
sessionId: 'test-session',
|
||||
data: {
|
||||
totalSteps: 1,
|
||||
executionTime: 500,
|
||||
status: 'completed',
|
||||
},
|
||||
};
|
||||
|
||||
// The callback should be callable without throwing errors
|
||||
expect(() => capturedCallback!([endEvent])).not.toThrow();
|
||||
});
|
||||
|
||||
it('should verify complete agent runtime lifecycle event types are supported', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=test-session');
|
||||
|
||||
let capturedCallback: ((events: any[]) => void) | null = null;
|
||||
|
||||
mockStreamEventManager.subscribeStreamEvents.mockImplementation(
|
||||
(sessionId, lastEventId, callback, signal) => {
|
||||
capturedCallback = callback;
|
||||
return Promise.resolve();
|
||||
},
|
||||
);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
expect(capturedCallback).toBeDefined();
|
||||
expect(response.status).toBe(200);
|
||||
|
||||
// Test complete lifecycle events can be processed
|
||||
const lifecycleEvents = [
|
||||
{
|
||||
type: 'agent_runtime_init',
|
||||
timestamp: MOCK_TIMESTAMP + 100,
|
||||
sessionId: 'test-session',
|
||||
data: { userId: 'user-123', agentType: 'assistant' },
|
||||
},
|
||||
{
|
||||
type: 'stream_start',
|
||||
timestamp: MOCK_TIMESTAMP + 200,
|
||||
sessionId: 'test-session',
|
||||
data: { messageId: 'msg-001' },
|
||||
},
|
||||
{
|
||||
type: 'stream_chunk',
|
||||
timestamp: MOCK_TIMESTAMP + 300,
|
||||
sessionId: 'test-session',
|
||||
data: { content: 'Hello world' },
|
||||
},
|
||||
{
|
||||
type: 'stream_end',
|
||||
timestamp: MOCK_TIMESTAMP + 400,
|
||||
sessionId: 'test-session',
|
||||
data: { messageId: 'msg-001' },
|
||||
},
|
||||
{
|
||||
type: 'agent_runtime_end',
|
||||
timestamp: MOCK_TIMESTAMP + 500,
|
||||
sessionId: 'test-session',
|
||||
data: { status: 'completed', totalSteps: 1 },
|
||||
},
|
||||
];
|
||||
|
||||
// All lifecycle events should be processable without throwing errors
|
||||
expect(() => capturedCallback!(lifecycleEvents)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Parameter validation', () => {
|
||||
it('should handle sessionId with special characters', async () => {
|
||||
const sessionId = 'test-session-123_456';
|
||||
const request = new NextRequest(`https://test.com/api/agent/stream?sessionId=${sessionId}`);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
});
|
||||
|
||||
it('should handle lastEventId as string number', async () => {
|
||||
const request = new NextRequest(
|
||||
'https://test.com/api/agent/stream?sessionId=test&lastEventId=12345',
|
||||
);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
});
|
||||
|
||||
it('should handle includeHistory as string boolean', async () => {
|
||||
const request = new NextRequest(
|
||||
'https://test.com/api/agent/stream?sessionId=test&includeHistory=false',
|
||||
);
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(200);
|
||||
expect(mockStreamEventManager.getStreamHistory).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle invalid URL gracefully', async () => {
|
||||
const request = new NextRequest('https://test.com/api/agent/stream?sessionId=');
|
||||
|
||||
const response = await GET(request);
|
||||
|
||||
expect(response.status).toBe(400);
|
||||
const data = await response.json();
|
||||
expect(data.error).toBe('sessionId parameter is required');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,196 @@
|
||||
import { createSSEHeaders, createSSEWriter } from '@lobechat/utils/server';
|
||||
import debug from 'debug';
|
||||
import { NextRequest, NextResponse } from 'next/server';
|
||||
|
||||
import { StreamEventManager } from '@/server/modules/AgentRuntime';
|
||||
|
||||
import { isEnableAgent } from '../isEnableAgent';
|
||||
|
||||
const log = debug('api-route:agent:stream');
|
||||
|
||||
/**
|
||||
* Server-Sent Events (SSE) endpoint
|
||||
* Provides real-time Agent execution event stream for clients
|
||||
*/
|
||||
export async function GET(request: NextRequest) {
|
||||
if (!isEnableAgent()) {
|
||||
return NextResponse.json({ error: 'Agent features are not enabled' }, { status: 404 });
|
||||
}
|
||||
|
||||
// Initialize stream event manager
|
||||
const streamManager = new StreamEventManager();
|
||||
|
||||
const { searchParams } = new URL(request.url);
|
||||
const sessionId = searchParams.get('sessionId');
|
||||
const lastEventId = searchParams.get('lastEventId') || '0';
|
||||
const includeHistory = searchParams.get('includeHistory') === 'true';
|
||||
|
||||
if (!sessionId) {
|
||||
return NextResponse.json(
|
||||
{
|
||||
error: 'sessionId parameter is required',
|
||||
},
|
||||
{ status: 400 },
|
||||
);
|
||||
}
|
||||
|
||||
log(`Starting SSE connection for session ${sessionId} from eventId ${lastEventId}`);
|
||||
|
||||
// 创建 Server-Sent Events 流
|
||||
const stream = new ReadableStream({
|
||||
cancel(reason) {
|
||||
log(`SSE connection cancelled for session ${sessionId}:`, reason);
|
||||
|
||||
// Call cleanup function
|
||||
if ((this as any)._cleanup) {
|
||||
(this as any)._cleanup();
|
||||
}
|
||||
},
|
||||
|
||||
start(controller) {
|
||||
const writer = createSSEWriter(controller);
|
||||
|
||||
// 发送连接确认事件
|
||||
writer.writeConnection(sessionId, lastEventId);
|
||||
log(`SSE connection established for session ${sessionId}`);
|
||||
|
||||
// 如果需要,先发送历史事件
|
||||
if (includeHistory) {
|
||||
streamManager
|
||||
.getStreamHistory(sessionId, 50)
|
||||
.then((history) => {
|
||||
// 按时间顺序发送历史事件(最早的在前面)
|
||||
const sortedHistory = history.reverse();
|
||||
|
||||
sortedHistory.forEach((event) => {
|
||||
// 只发送比 lastEventId 更新的事件
|
||||
if (!lastEventId || lastEventId === '0' || event.timestamp.toString() > lastEventId) {
|
||||
try {
|
||||
// 添加 SSE 特定的字段,保持与实时事件格式一致
|
||||
const sseEvent = {
|
||||
...event,
|
||||
sessionId,
|
||||
timestamp: event.timestamp || Date.now(),
|
||||
};
|
||||
writer.writeStreamEvent(sseEvent, sessionId);
|
||||
} catch (error) {
|
||||
console.error('[Agent Stream] Error sending history event:', error);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (sortedHistory.length > 0) {
|
||||
log(`Sent ${sortedHistory.length} historical events for session ${sessionId}`);
|
||||
}
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('[Agent Stream] Failed to load history:', error);
|
||||
|
||||
try {
|
||||
writer.writeError(error, sessionId, 'history_loading');
|
||||
} catch (controllerError) {
|
||||
console.error('[Agent Stream] Failed to send error event:', controllerError);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// 创建 AbortController 用于取消订阅
|
||||
const abortController = new AbortController();
|
||||
|
||||
// 订阅新的流式事件
|
||||
const subscribeToEvents = async () => {
|
||||
try {
|
||||
await streamManager.subscribeStreamEvents(
|
||||
sessionId,
|
||||
lastEventId,
|
||||
(events) => {
|
||||
events.forEach((event) => {
|
||||
try {
|
||||
// 添加 SSE 特定的字段
|
||||
const sseEvent = {
|
||||
...event,
|
||||
sessionId,
|
||||
timestamp: event.timestamp || Date.now(),
|
||||
};
|
||||
|
||||
writer.writeStreamEvent(sseEvent, sessionId);
|
||||
|
||||
// 如果收到 agent_runtime_end 事件,停止心跳并准备关闭连接
|
||||
if (event.type === 'agent_runtime_end') {
|
||||
log(
|
||||
`Agent runtime ended for session ${sessionId}, preparing to close connection`,
|
||||
);
|
||||
|
||||
// 延迟关闭连接,确保客户端有时间处理最后的事件
|
||||
setTimeout(() => {
|
||||
try {
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
cleanup();
|
||||
controller.close();
|
||||
log(
|
||||
`SSE connection closed after agent runtime end for session ${sessionId}`,
|
||||
);
|
||||
} catch (closeError) {
|
||||
console.error('[Agent Stream] Error closing connection:', closeError);
|
||||
}
|
||||
}, 1000); // 1秒延迟,给客户端处理时间
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[Agent Stream] Error sending event:', error);
|
||||
}
|
||||
});
|
||||
},
|
||||
abortController.signal,
|
||||
);
|
||||
} catch (error) {
|
||||
if (!abortController.signal.aborted) {
|
||||
console.error('[Agent Stream] Subscription error:', error);
|
||||
|
||||
try {
|
||||
writer.writeError(error as Error, sessionId, 'stream_subscription');
|
||||
} catch (controllerError) {
|
||||
console.error('[Agent Stream] Failed to send subscription error:', controllerError);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 开始订阅
|
||||
subscribeToEvents();
|
||||
|
||||
// 定期发送心跳(每 30 秒)
|
||||
const heartbeatInterval = setInterval(() => {
|
||||
try {
|
||||
const heartbeat = {
|
||||
sessionId,
|
||||
timestamp: Date.now(),
|
||||
type: 'heartbeat',
|
||||
};
|
||||
|
||||
controller.enqueue(`data: ${JSON.stringify(heartbeat)}\n\n`);
|
||||
} catch (error) {
|
||||
console.error('[Agent Stream] Heartbeat error:', error);
|
||||
clearInterval(heartbeatInterval);
|
||||
}
|
||||
}, 30_000);
|
||||
|
||||
// Cleanup function
|
||||
const cleanup = () => {
|
||||
abortController.abort();
|
||||
clearInterval(heartbeatInterval);
|
||||
log(`SSE connection closed for session ${sessionId}`);
|
||||
};
|
||||
|
||||
// 监听连接关闭
|
||||
request.signal?.addEventListener('abort', cleanup);
|
||||
|
||||
// 存储清理函数以便在 cancel 时调用
|
||||
(controller as any)._cleanup = cleanup;
|
||||
},
|
||||
});
|
||||
|
||||
// 设置 SSE 响应头
|
||||
return new Response(stream, {
|
||||
headers: createSSEHeaders(),
|
||||
});
|
||||
}
|
||||
+14
-1
@@ -7,6 +7,8 @@ import { Flexbox } from 'react-layout-kit';
|
||||
|
||||
import { type ActionKeys, ChatInputProvider, DesktopChatInput } from '@/features/ChatInput';
|
||||
import WideScreenContainer from '@/features/ChatList/components/WideScreenContainer';
|
||||
import { useAgentStore } from '@/store/agent';
|
||||
import { agentSelectors } from '@/store/agent/slices/chat';
|
||||
import { useChatStore } from '@/store/chat';
|
||||
import { aiChatSelectors } from '@/store/chat/selectors';
|
||||
|
||||
@@ -14,7 +16,17 @@ import { useSend } from '../useSend';
|
||||
import MessageFromUrl from './MessageFromUrl';
|
||||
import { useSendMenuItems } from './useSendMenuItems';
|
||||
|
||||
const leftAgentActions: ActionKeys[] = [
|
||||
'agentMode',
|
||||
'model',
|
||||
'search',
|
||||
'tools',
|
||||
'mainToken',
|
||||
'clear',
|
||||
];
|
||||
|
||||
const leftActions: ActionKeys[] = [
|
||||
'agentMode',
|
||||
'model',
|
||||
'search',
|
||||
'typo',
|
||||
@@ -32,6 +44,7 @@ const ClassicChatInput = memo(() => {
|
||||
const { t } = useTranslation('chat');
|
||||
const { send, generating, disabled, stop } = useSend();
|
||||
|
||||
const enableAgentMode = useAgentStore(agentSelectors.enableAgentMode);
|
||||
const [mainInputSendErrorMsg, clearSendMessageError] = useChatStore((s) => [
|
||||
aiChatSelectors.isCurrentSendMessageError(s),
|
||||
s.clearSendMessageError,
|
||||
@@ -45,7 +58,7 @@ const ClassicChatInput = memo(() => {
|
||||
if (!instance) return;
|
||||
useChatStore.setState({ mainInputEditor: instance });
|
||||
}}
|
||||
leftActions={leftActions}
|
||||
leftActions={enableAgentMode ? leftAgentActions : leftActions}
|
||||
onMarkdownContentChange={(content) => {
|
||||
useChatStore.setState({ inputMessage: content });
|
||||
}}
|
||||
|
||||
+1
@@ -19,6 +19,7 @@ import { sessionSelectors } from '@/store/session/selectors';
|
||||
import { useSend, useSendGroupMessage } from '../useSend';
|
||||
|
||||
const leftActions: ActionKeys[] = [
|
||||
'agentMode',
|
||||
'model',
|
||||
'search',
|
||||
'fileUpload',
|
||||
|
||||
+1
@@ -18,6 +18,7 @@ import SendButton from './Send';
|
||||
import { useSendMessage } from './useSend';
|
||||
|
||||
const defaultLeftActions: ActionKeys[] = [
|
||||
'agentMode',
|
||||
'model',
|
||||
'search',
|
||||
'fileUpload',
|
||||
|
||||
+6
-1
@@ -7,9 +7,11 @@ export const getServerDBConfig = () => {
|
||||
DATABASE_DRIVER: process.env.DATABASE_DRIVER || 'neon',
|
||||
DATABASE_TEST_URL: process.env.DATABASE_TEST_URL,
|
||||
DATABASE_URL: process.env.DATABASE_URL,
|
||||
|
||||
KEY_VAULTS_SECRET: process.env.KEY_VAULTS_SECRET,
|
||||
|
||||
QSTASH_TOKEN: process.env.QSTASH_TOKEN,
|
||||
REDIS_URL: process.env.REDIS_URL,
|
||||
|
||||
REMOVE_GLOBAL_FILE: process.env.DISABLE_REMOVE_GLOBAL_FILE !== '0',
|
||||
},
|
||||
server: {
|
||||
@@ -19,6 +21,9 @@ export const getServerDBConfig = () => {
|
||||
|
||||
KEY_VAULTS_SECRET: z.string().optional(),
|
||||
|
||||
QSTASH_TOKEN: z.string().optional(),
|
||||
REDIS_URL: z.string().optional(),
|
||||
|
||||
REMOVE_GLOBAL_FILE: z.boolean().optional(),
|
||||
},
|
||||
});
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import { ActionIcon } from '@lobehub/ui';
|
||||
import { Bot } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useAgentStore } from '@/store/agent';
|
||||
import { agentSelectors } from '@/store/agent/selectors';
|
||||
|
||||
const AgentModeToggle = memo(() => {
|
||||
const { t } = useTranslation('chat');
|
||||
const [enableAgentMode, updateAgentConfig] = useAgentStore((s) => [
|
||||
agentSelectors.enableAgentMode(s),
|
||||
s.updateAgentConfig,
|
||||
]);
|
||||
|
||||
const handleToggle = (checked: boolean) => {
|
||||
updateAgentConfig({ enableAgentMode: checked });
|
||||
};
|
||||
|
||||
return (
|
||||
<ActionIcon
|
||||
icon={Bot}
|
||||
onClick={() => {
|
||||
handleToggle(!enableAgentMode);
|
||||
}}
|
||||
style={{
|
||||
color: enableAgentMode ? 'var(--colorPrimary)' : undefined,
|
||||
}}
|
||||
title={t('agentMode.title', { defaultValue: 'Agent Mode' })}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
AgentModeToggle.displayName = 'AgentModeToggle';
|
||||
|
||||
export default AgentModeToggle;
|
||||
@@ -1,3 +1,4 @@
|
||||
import AgentMode from './AgentMode';
|
||||
import Clear from './Clear';
|
||||
import History from './History';
|
||||
import Knowledge from './Knowledge';
|
||||
@@ -13,6 +14,7 @@ import Typo from './Typo';
|
||||
import Upload from './Upload';
|
||||
|
||||
export const actionMap = {
|
||||
agentMode: AgentMode,
|
||||
clear: Clear,
|
||||
fileUpload: Upload,
|
||||
groupChatToken: GroupChatToken,
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
import debug from 'debug';
|
||||
import Redis from 'ioredis';
|
||||
|
||||
import { getRedisConnectionDescription, getRedisUrl } from './config';
|
||||
|
||||
const log = debug('redis:client');
|
||||
|
||||
/**
|
||||
* 创建 Redis 客户端实例
|
||||
*/
|
||||
export const createRedisClient = (url?: string): Redis | null => {
|
||||
const redisUrl = url || getRedisUrl();
|
||||
|
||||
if (!redisUrl) {
|
||||
console.warn('[Redis Client] No Redis URL available. Redis features are disabled.');
|
||||
return null;
|
||||
}
|
||||
|
||||
const client = new Redis(redisUrl, {
|
||||
maxRetriesPerRequest: 3,
|
||||
});
|
||||
|
||||
client.on('connect', () => {
|
||||
log(`Connected to Redis: ${getRedisConnectionDescription(redisUrl)}`);
|
||||
});
|
||||
|
||||
client.on('error', (error) => {
|
||||
console.error('[Redis Client] Redis connection error:', error);
|
||||
});
|
||||
|
||||
client.on('close', () => {
|
||||
log('Redis connection closed');
|
||||
});
|
||||
|
||||
return client;
|
||||
};
|
||||
|
||||
/**
|
||||
* 全局 Redis 客户端实例(单例模式)
|
||||
*/
|
||||
let globalRedisClient: Redis | null = null;
|
||||
let redisInitialized = false;
|
||||
|
||||
/**
|
||||
* 获取全局 Redis 客户端实例
|
||||
*/
|
||||
export function getRedisClient(): Redis | null {
|
||||
if (!redisInitialized) {
|
||||
globalRedisClient = createRedisClient();
|
||||
redisInitialized = true;
|
||||
}
|
||||
return globalRedisClient;
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭全局 Redis 客户端连接
|
||||
*/
|
||||
export async function closeRedisClient(): Promise<void> {
|
||||
if (globalRedisClient) {
|
||||
await globalRedisClient.quit();
|
||||
globalRedisClient = null;
|
||||
redisInitialized = false;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Get Redis URL from environment variables
|
||||
*/
|
||||
export const getRedisUrl = (): string | undefined => {
|
||||
const redisUrl = process.env.REDIS_URL;
|
||||
|
||||
if (!redisUrl) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return redisUrl;
|
||||
};
|
||||
|
||||
/**
|
||||
* Validate if Redis URL is valid
|
||||
*/
|
||||
export const validateRedisUrl = (url: string): boolean => {
|
||||
try {
|
||||
new URL(url);
|
||||
return true;
|
||||
} catch {
|
||||
console.error('[Redis Config] Invalid REDIS_URL format:', url);
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Get Redis connection description string for logging (hide password)
|
||||
*/
|
||||
export const getRedisConnectionDescription = (url: string): string => {
|
||||
try {
|
||||
const urlObj = new URL(url);
|
||||
if (urlObj.password) {
|
||||
urlObj.password = '***';
|
||||
}
|
||||
return urlObj.toString();
|
||||
} catch {
|
||||
return 'Invalid URL';
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,2 @@
|
||||
export { closeRedisClient,createRedisClient, getRedisClient } from './client';
|
||||
export { getRedisConnectionDescription,getRedisUrl, validateRedisUrl } from './config';
|
||||
@@ -186,6 +186,7 @@ const defaultMiddleware = (request: NextRequest) => {
|
||||
const isPublicRoute = createRouteMatcher([
|
||||
// backend api
|
||||
'/api/auth(.*)',
|
||||
'/api/agent(.*)',
|
||||
'/api/webhooks(.*)',
|
||||
'/webapi(.*)',
|
||||
'/trpc(.*)',
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
import { AgentState } from '@lobechat/agent-runtime';
|
||||
import debug from 'debug';
|
||||
|
||||
import { AgentSessionMetadata, AgentStateManager, StepResult } from './AgentStateManager';
|
||||
import { StreamEventManager } from './StreamEventManager';
|
||||
|
||||
const log = debug('lobe-server:agent-runtime:coordinator');
|
||||
|
||||
/**
|
||||
* Agent Runtime Coordinator
|
||||
* 协调 AgentStateManager 和 StreamEventManager 的操作
|
||||
* 负责在状态变更时发送相应的事件
|
||||
*/
|
||||
export class AgentRuntimeCoordinator {
|
||||
private stateManager: AgentStateManager;
|
||||
private streamEventManager: StreamEventManager;
|
||||
|
||||
constructor() {
|
||||
this.stateManager = new AgentStateManager();
|
||||
this.streamEventManager = new StreamEventManager();
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建新的 Agent 会话并发送初始化事件
|
||||
*/
|
||||
async createAgentSession(
|
||||
sessionId: string,
|
||||
data: {
|
||||
agentConfig?: any;
|
||||
modelRuntimeConfig?: any;
|
||||
userId?: string;
|
||||
},
|
||||
): Promise<void> {
|
||||
try {
|
||||
// 创建会话元数据
|
||||
await this.stateManager.createSessionMetadata(sessionId, data);
|
||||
|
||||
// 获取创建的元数据
|
||||
const metadata = await this.stateManager.getSessionMetadata(sessionId);
|
||||
|
||||
if (metadata) {
|
||||
// 发送 agent runtime init 事件
|
||||
await this.streamEventManager.publishAgentRuntimeInit(sessionId, metadata);
|
||||
log('[%s] Agent session created and initialized', sessionId);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to create agent session:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存 Agent 状态并处理相应事件
|
||||
*/
|
||||
async saveAgentState(sessionId: string, state: AgentState): Promise<void> {
|
||||
try {
|
||||
const previousState = await this.stateManager.loadAgentState(sessionId);
|
||||
|
||||
// 保存状态
|
||||
await this.stateManager.saveAgentState(sessionId, state);
|
||||
|
||||
// 如果状态变为 done,发送 agent runtime end 事件
|
||||
if (state.status === 'done' && previousState?.status !== 'done') {
|
||||
await this.streamEventManager.publishAgentRuntimeEnd(sessionId, state.stepCount, state);
|
||||
log('[%s] Agent runtime completed', sessionId);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to save agent state and handle events:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存步骤结果并处理相应事件
|
||||
*/
|
||||
async saveStepResult(sessionId: string, stepResult: StepResult): Promise<void> {
|
||||
try {
|
||||
// 保存步骤结果
|
||||
await this.stateManager.saveStepResult(sessionId, stepResult);
|
||||
|
||||
// 不在这里发送 agent_runtime_end 事件,让 saveAgentState 统一处理
|
||||
// 这样确保 agent_runtime_end 是最后一个事件
|
||||
} catch (error) {
|
||||
console.error('Failed to save step result and handle events:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Agent 状态
|
||||
*/
|
||||
async loadAgentState(sessionId: string): Promise<AgentState | null> {
|
||||
return this.stateManager.loadAgentState(sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话元数据
|
||||
*/
|
||||
async getSessionMetadata(sessionId: string): Promise<AgentSessionMetadata | null> {
|
||||
return this.stateManager.getSessionMetadata(sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取执行历史
|
||||
*/
|
||||
async getExecutionHistory(sessionId: string, limit?: number): Promise<any[]> {
|
||||
return this.stateManager.getExecutionHistory(sessionId, limit);
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除 Agent 会话
|
||||
*/
|
||||
async deleteAgentSession(sessionId: string): Promise<void> {
|
||||
try {
|
||||
await Promise.all([
|
||||
this.stateManager.deleteAgentSession(sessionId),
|
||||
this.streamEventManager.cleanupSession(sessionId),
|
||||
]);
|
||||
log('Agent session deleted: %s', sessionId);
|
||||
} catch (error) {
|
||||
console.error('Failed to delete agent session:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取活跃会话
|
||||
*/
|
||||
async getActiveSessions(): Promise<string[]> {
|
||||
return this.stateManager.getActiveSessions();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取统计信息
|
||||
*/
|
||||
async getStats(): Promise<{
|
||||
activeSessions: number;
|
||||
completedSessions: number;
|
||||
errorSessions: number;
|
||||
totalSessions: number;
|
||||
}> {
|
||||
return this.stateManager.getStats();
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理过期会话
|
||||
*/
|
||||
async cleanupExpiredSessions(): Promise<number> {
|
||||
return this.stateManager.cleanupExpiredSessions();
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭连接
|
||||
*/
|
||||
async disconnect(): Promise<void> {
|
||||
await Promise.all([this.stateManager.disconnect(), this.streamEventManager.disconnect()]);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,405 @@
|
||||
import { AgentEvent, AgentRuntimeContext, AgentState } from '@lobechat/agent-runtime';
|
||||
import debug from 'debug';
|
||||
import Redis from 'ioredis';
|
||||
|
||||
import { getRedisClient } from '@/libs/redis';
|
||||
|
||||
const log = debug('lobe-server:agent-runtime:agent-state-manager');
|
||||
|
||||
export interface StepResult {
|
||||
events?: AgentEvent[];
|
||||
executionTime: number;
|
||||
newState: AgentState;
|
||||
nextContext?: AgentRuntimeContext;
|
||||
stepIndex: number;
|
||||
}
|
||||
|
||||
export interface AgentSessionMetadata {
|
||||
agentConfig?: any;
|
||||
createdAt: string;
|
||||
lastActiveAt: string;
|
||||
modelRuntimeConfig?: any;
|
||||
status: AgentState['status'];
|
||||
totalCost: number;
|
||||
totalSteps: number;
|
||||
userId?: string;
|
||||
}
|
||||
|
||||
export class AgentStateManager {
|
||||
private redis: Redis;
|
||||
private readonly STATE_PREFIX = 'agent_runtime_state';
|
||||
private readonly STEPS_PREFIX = 'agent_runtime_steps';
|
||||
private readonly METADATA_PREFIX = 'agent_runtime_meta';
|
||||
private readonly EVENTS_PREFIX = 'agent_runtime_events';
|
||||
private readonly DEFAULT_TTL = 12 * 3600; // 12h
|
||||
|
||||
constructor() {
|
||||
const redisClient = getRedisClient();
|
||||
if (!redisClient) {
|
||||
throw new Error('Redis is not available. Please configure REDIS_URL environment variable.');
|
||||
}
|
||||
this.redis = redisClient;
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存 Agent 状态
|
||||
*/
|
||||
async saveAgentState(sessionId: string, state: AgentState): Promise<void> {
|
||||
const stateKey = `${this.STATE_PREFIX}:${sessionId}`;
|
||||
|
||||
try {
|
||||
const serializedState = JSON.stringify(state);
|
||||
await this.redis.setex(stateKey, this.DEFAULT_TTL, serializedState);
|
||||
|
||||
// 更新元数据
|
||||
await this.updateSessionMetadata(sessionId, {
|
||||
lastActiveAt: new Date().toISOString(),
|
||||
status: state.status,
|
||||
totalCost: state.cost?.total || 0,
|
||||
totalSteps: state.stepCount,
|
||||
});
|
||||
|
||||
// 状态变更事件通过 saveStepResult 中的 events 数组记录
|
||||
|
||||
log('[%s] Saved state for step %d', sessionId, state.stepCount);
|
||||
} catch (error) {
|
||||
console.error('Failed to save agent state:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 加载 Agent 状态
|
||||
*/
|
||||
async loadAgentState(sessionId: string): Promise<AgentState | null> {
|
||||
const stateKey = `${this.STATE_PREFIX}:${sessionId}`;
|
||||
|
||||
try {
|
||||
const serializedState = await this.redis.get(stateKey);
|
||||
|
||||
if (!serializedState) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const state = JSON.parse(serializedState) as AgentState;
|
||||
log('[%s] Loaded state (step %d)', sessionId, state.stepCount);
|
||||
|
||||
return state;
|
||||
} catch (error) {
|
||||
console.error('Failed to load agent state:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 保存步骤执行结果
|
||||
*/
|
||||
async saveStepResult(sessionId: string, stepResult: StepResult): Promise<void> {
|
||||
const pipeline = this.redis.multi();
|
||||
|
||||
try {
|
||||
// 保存最新状态
|
||||
const stateKey = `${this.STATE_PREFIX}:${sessionId}`;
|
||||
pipeline.setex(stateKey, this.DEFAULT_TTL, JSON.stringify(stepResult.newState));
|
||||
|
||||
// 保存步骤历史
|
||||
const stepsKey = `${this.STEPS_PREFIX}:${sessionId}`;
|
||||
const stepData = {
|
||||
context: stepResult.nextContext,
|
||||
cost: stepResult.newState.cost?.total || 0,
|
||||
executionTime: stepResult.executionTime,
|
||||
status: stepResult.newState.status,
|
||||
stepIndex: stepResult.stepIndex,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
pipeline.lpush(stepsKey, JSON.stringify(stepData));
|
||||
pipeline.ltrim(stepsKey, 0, 199); // 保留最近 200 步
|
||||
pipeline.expire(stepsKey, this.DEFAULT_TTL);
|
||||
|
||||
// 保存步骤的事件序列到 agent_runtime_events
|
||||
if (stepResult.events && stepResult.events.length > 0) {
|
||||
const eventsKey = `${this.EVENTS_PREFIX}:${sessionId}`;
|
||||
|
||||
pipeline.lpush(eventsKey, JSON.stringify(stepResult.events));
|
||||
pipeline.ltrim(eventsKey, 0, 199); // 保留最近 200 步的事件
|
||||
pipeline.expire(eventsKey, this.DEFAULT_TTL);
|
||||
}
|
||||
|
||||
// 更新会话元数据
|
||||
const metaKey = `${this.METADATA_PREFIX}:${sessionId}`;
|
||||
const metadata: Partial<AgentSessionMetadata> = {
|
||||
lastActiveAt: new Date().toISOString(),
|
||||
status: stepResult.newState.status,
|
||||
totalCost: stepResult.newState.cost?.total || 0,
|
||||
totalSteps: stepResult.newState.stepCount,
|
||||
};
|
||||
pipeline.hmset(metaKey, metadata as any);
|
||||
pipeline.expire(metaKey, this.DEFAULT_TTL);
|
||||
|
||||
await pipeline.exec();
|
||||
|
||||
log(
|
||||
'[%s:%d] Saved step result with %d events',
|
||||
sessionId,
|
||||
stepResult.stepIndex,
|
||||
stepResult.events?.length || 0,
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Failed to save step result:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取执行历史
|
||||
*/
|
||||
async getExecutionHistory(sessionId: string, limit: number = 50): Promise<any[]> {
|
||||
const stepsKey = `${this.STEPS_PREFIX}:${sessionId}`;
|
||||
|
||||
try {
|
||||
const history = await this.redis.lrange(stepsKey, 0, limit - 1);
|
||||
return history.map((item) => JSON.parse(item)).reverse(); // 最早的在前面
|
||||
} catch (error) {
|
||||
console.error('Failed to get execution history:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话元数据
|
||||
*/
|
||||
async getSessionMetadata(sessionId: string): Promise<AgentSessionMetadata | null> {
|
||||
const metaKey = `${this.METADATA_PREFIX}:${sessionId}`;
|
||||
|
||||
try {
|
||||
const metadata = await this.redis.hgetall(metaKey);
|
||||
|
||||
if (Object.keys(metadata).length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return {
|
||||
agentConfig: metadata.agentConfig ? JSON.parse(metadata.agentConfig) : undefined,
|
||||
createdAt: metadata.createdAt,
|
||||
lastActiveAt: metadata.lastActiveAt,
|
||||
modelRuntimeConfig: metadata.modelRuntimeConfig
|
||||
? JSON.parse(metadata.modelRuntimeConfig)
|
||||
: undefined,
|
||||
status: metadata.status as AgentState['status'],
|
||||
totalCost: parseFloat(metadata.totalCost) || 0,
|
||||
totalSteps: parseInt(metadata.totalSteps) || 0,
|
||||
userId: metadata.userId,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Failed to get session metadata:', error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建新的会话元数据
|
||||
*/
|
||||
async createSessionMetadata(
|
||||
sessionId: string,
|
||||
data: {
|
||||
agentConfig?: any;
|
||||
modelRuntimeConfig?: any;
|
||||
userId?: string;
|
||||
},
|
||||
): Promise<void> {
|
||||
const metaKey = `${this.METADATA_PREFIX}:${sessionId}`;
|
||||
|
||||
try {
|
||||
const metadata: AgentSessionMetadata = {
|
||||
agentConfig: data.agentConfig,
|
||||
createdAt: new Date().toISOString(),
|
||||
lastActiveAt: new Date().toISOString(),
|
||||
modelRuntimeConfig: data.modelRuntimeConfig,
|
||||
status: 'idle',
|
||||
totalCost: 0,
|
||||
totalSteps: 0,
|
||||
userId: data.userId,
|
||||
};
|
||||
|
||||
// 序列化复杂对象
|
||||
const redisData: Record<string, string> = {
|
||||
createdAt: metadata.createdAt,
|
||||
lastActiveAt: metadata.lastActiveAt,
|
||||
status: metadata.status,
|
||||
totalCost: metadata.totalCost.toString(),
|
||||
totalSteps: metadata.totalSteps.toString(),
|
||||
};
|
||||
|
||||
if (metadata.userId) redisData.userId = metadata.userId;
|
||||
if (metadata.modelRuntimeConfig)
|
||||
redisData.modelRuntimeConfig = JSON.stringify(metadata.modelRuntimeConfig);
|
||||
if (metadata.agentConfig) redisData.agentConfig = JSON.stringify(metadata.agentConfig);
|
||||
|
||||
await this.redis.hmset(metaKey, redisData);
|
||||
await this.redis.expire(metaKey, this.DEFAULT_TTL);
|
||||
|
||||
log('[%s]Created session metadata', sessionId);
|
||||
} catch (error) {
|
||||
console.error('Failed to create session metadata:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 更新会话元数据
|
||||
*/
|
||||
private async updateSessionMetadata(
|
||||
sessionId: string,
|
||||
updates: Partial<AgentSessionMetadata>,
|
||||
): Promise<void> {
|
||||
const metaKey = `${this.METADATA_PREFIX}:${sessionId}`;
|
||||
|
||||
try {
|
||||
const redisUpdates: Record<string, string> = {};
|
||||
|
||||
Object.entries(updates).forEach(([key, value]) => {
|
||||
if (value !== undefined) {
|
||||
if (typeof value === 'object') {
|
||||
redisUpdates[key] = JSON.stringify(value);
|
||||
} else {
|
||||
redisUpdates[key] = value.toString();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (Object.keys(redisUpdates).length > 0) {
|
||||
await this.redis.hmset(metaKey, redisUpdates);
|
||||
await this.redis.expire(metaKey, this.DEFAULT_TTL);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to update session metadata:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 删除 Agent 会话的所有数据
|
||||
*/
|
||||
async deleteAgentSession(sessionId: string): Promise<void> {
|
||||
const keys = [
|
||||
`${this.STATE_PREFIX}:${sessionId}`,
|
||||
`${this.STEPS_PREFIX}:${sessionId}`,
|
||||
`${this.METADATA_PREFIX}:${sessionId}`,
|
||||
`${this.EVENTS_PREFIX}:${sessionId}`,
|
||||
];
|
||||
|
||||
try {
|
||||
await this.redis.del(...keys);
|
||||
log('Deleted session %s', sessionId);
|
||||
} catch (error) {
|
||||
console.error('Failed to delete agent session:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有活跃会话
|
||||
*/
|
||||
async getActiveSessions(): Promise<string[]> {
|
||||
try {
|
||||
const pattern = `${this.STATE_PREFIX}:*`;
|
||||
const keys = await this.redis.keys(pattern);
|
||||
return keys.map((key) => key.replace(`${this.STATE_PREFIX}:`, ''));
|
||||
} catch (error) {
|
||||
console.error('Failed to get active sessions:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理过期的会话数据
|
||||
*/
|
||||
async cleanupExpiredSessions(): Promise<number> {
|
||||
try {
|
||||
const activeSessions = await this.getActiveSessions();
|
||||
let cleanedCount = 0;
|
||||
|
||||
for (const sessionId of activeSessions) {
|
||||
const metadata = await this.getSessionMetadata(sessionId);
|
||||
|
||||
if (metadata) {
|
||||
const lastActiveTime = new Date(metadata.lastActiveAt).getTime();
|
||||
const now = Date.now();
|
||||
const daysSinceActive = (now - lastActiveTime) / (1000 * 60 * 60 * 24);
|
||||
|
||||
// 清理超过 7 天未活跃的会话
|
||||
if (daysSinceActive > 7) {
|
||||
await this.deleteAgentSession(sessionId);
|
||||
cleanedCount++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
log('Cleaned up %d expired sessions', cleanedCount);
|
||||
return cleanedCount;
|
||||
} catch (error) {
|
||||
console.error('Failed to cleanup expired sessions:', error);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取统计信息
|
||||
*/
|
||||
async getStats(): Promise<{
|
||||
activeSessions: number;
|
||||
completedSessions: number;
|
||||
errorSessions: number;
|
||||
totalSessions: number;
|
||||
}> {
|
||||
try {
|
||||
const sessions = await this.getActiveSessions();
|
||||
const stats = {
|
||||
activeSessions: 0,
|
||||
completedSessions: 0,
|
||||
errorSessions: 0,
|
||||
totalSessions: sessions.length,
|
||||
};
|
||||
|
||||
for (const sessionId of sessions) {
|
||||
const metadata = await this.getSessionMetadata(sessionId);
|
||||
|
||||
if (metadata) {
|
||||
switch (metadata.status) {
|
||||
case 'running':
|
||||
case 'waiting_for_human': {
|
||||
stats.activeSessions++;
|
||||
break;
|
||||
}
|
||||
case 'done': {
|
||||
stats.completedSessions++;
|
||||
break;
|
||||
}
|
||||
case 'error':
|
||||
case 'interrupted': {
|
||||
stats.errorSessions++;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return stats;
|
||||
} catch (error) {
|
||||
console.error('Failed to get stats:', error);
|
||||
return {
|
||||
activeSessions: 0,
|
||||
completedSessions: 0,
|
||||
errorSessions: 0,
|
||||
totalSessions: 0,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭 Redis 连接
|
||||
*/
|
||||
async disconnect(): Promise<void> {
|
||||
await this.redis.quit();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,226 @@
|
||||
import {
|
||||
AgentInstruction,
|
||||
AgentRuntimeContext,
|
||||
AgentState,
|
||||
InterventionChecker,
|
||||
} from '@lobechat/agent-runtime';
|
||||
import {
|
||||
BuiltinToolManifest,
|
||||
ChatToolPayload,
|
||||
HumanInterventionPolicy,
|
||||
MessageToolCall,
|
||||
} from '@lobechat/types';
|
||||
import debug from 'debug';
|
||||
|
||||
const log = debug('lobe-server:agent-runtime:general-agent');
|
||||
|
||||
export interface ChatAgentConfig {
|
||||
agentConfig?: {
|
||||
[key: string]: any;
|
||||
maxSteps?: number;
|
||||
};
|
||||
modelRuntimeConfig?: {
|
||||
model: string;
|
||||
provider: string;
|
||||
};
|
||||
sessionId: string;
|
||||
userId?: string;
|
||||
}
|
||||
|
||||
export interface GeneralAgentLLMResultPayload {
|
||||
hasToolsCalling: boolean;
|
||||
result: { content: string; tool_calls: MessageToolCall[] };
|
||||
toolsCalling: ChatToolPayload[];
|
||||
}
|
||||
|
||||
export interface GeneralAgentToolResultPayload {
|
||||
data: any;
|
||||
executionTime: number;
|
||||
isSuccess: boolean;
|
||||
toolCall: ChatToolPayload;
|
||||
toolCallId: string;
|
||||
}
|
||||
|
||||
export class GeneralAgent {
|
||||
private config: ChatAgentConfig;
|
||||
|
||||
constructor(config: ChatAgentConfig) {
|
||||
this.config = config;
|
||||
}
|
||||
|
||||
async runner(
|
||||
context: AgentRuntimeContext,
|
||||
state: AgentState,
|
||||
): Promise<AgentInstruction | AgentInstruction[]> {
|
||||
log(`[${this.config.sessionId}] Processing phase: %s`, context.phase);
|
||||
|
||||
switch (context.phase) {
|
||||
case 'user_input': {
|
||||
// call LLM
|
||||
return {
|
||||
payload: {
|
||||
messages: state.messages,
|
||||
model: this.config.modelRuntimeConfig?.model,
|
||||
provider: this.config.modelRuntimeConfig?.provider,
|
||||
tools: state.tools,
|
||||
},
|
||||
type: 'call_llm',
|
||||
};
|
||||
}
|
||||
|
||||
case 'llm_result': {
|
||||
// LLM completed, determine next action based on tool calls
|
||||
const payload = context.payload as GeneralAgentLLMResultPayload;
|
||||
|
||||
// Execute tools if present
|
||||
if (payload.hasToolsCalling) {
|
||||
const toolsCalling = payload.toolsCalling;
|
||||
|
||||
// Check intervention for all tool calls
|
||||
const toolsNeedingApproval: ChatToolPayload[] = [];
|
||||
|
||||
for (const toolCall of toolsCalling) {
|
||||
const manifest = state.toolManifestMap[toolCall.identifier];
|
||||
const policy = this.checkToolIntervention(toolCall, manifest);
|
||||
|
||||
// For now, only handle 'always' policy (skip 'first' and 'never')
|
||||
if (policy === 'always') {
|
||||
toolsNeedingApproval.push(toolCall);
|
||||
}
|
||||
}
|
||||
|
||||
// If any tools need approval, request human intervention
|
||||
if (toolsNeedingApproval.length > 0) {
|
||||
log(
|
||||
`[${this.config.sessionId}] Tools requiring approval: %o`,
|
||||
toolsNeedingApproval.map((t) => `${t.identifier}/${t.apiName}`),
|
||||
);
|
||||
|
||||
return {
|
||||
pendingToolsCalling: toolsNeedingApproval as any,
|
||||
reason: 'Tools require human approval',
|
||||
type: 'request_human_approve',
|
||||
};
|
||||
}
|
||||
|
||||
// No intervention needed, proceed with tool execution
|
||||
// Use batch execution for multiple tool calls to improve performance
|
||||
if (toolsCalling.length > 1) {
|
||||
return {
|
||||
payload: toolsCalling as any,
|
||||
type: 'call_tools_batch',
|
||||
};
|
||||
} else if (toolsCalling.length === 1) {
|
||||
// Single tool executes directly
|
||||
return {
|
||||
payload: toolsCalling[0] as any,
|
||||
type: 'call_tool',
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Finish if no tools
|
||||
return {
|
||||
reason: 'completed',
|
||||
reasonDetail: 'General agent completed successfully',
|
||||
type: 'finish',
|
||||
};
|
||||
}
|
||||
|
||||
case 'tool_result':
|
||||
case 'tools_batch_result': {
|
||||
// Continue calling LLM after tool execution completes
|
||||
return {
|
||||
payload: {
|
||||
messages: state.messages,
|
||||
model: this.config.modelRuntimeConfig?.model,
|
||||
provider: this.config.modelRuntimeConfig?.provider,
|
||||
tools: state.tools,
|
||||
},
|
||||
type: 'call_llm',
|
||||
};
|
||||
}
|
||||
|
||||
default: {
|
||||
return {
|
||||
reason: 'error_recovery',
|
||||
reasonDetail: `Unknown phase: ${context.phase}`,
|
||||
type: 'finish',
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a tool call requires human intervention
|
||||
* @param toolCall - Tool call to check
|
||||
* @param manifest - Tool manifest containing intervention config
|
||||
* @returns Intervention policy to apply
|
||||
*/
|
||||
private checkToolIntervention(
|
||||
toolCall: ChatToolPayload,
|
||||
manifest: BuiltinToolManifest | undefined,
|
||||
): HumanInterventionPolicy {
|
||||
// No manifest means no intervention config
|
||||
if (!manifest) {
|
||||
log(`[${this.config.sessionId}] No manifest found for tool: ${toolCall.identifier}`);
|
||||
return 'never';
|
||||
}
|
||||
|
||||
// First, try to get API-level intervention config
|
||||
const api = manifest.api.find((a) => a.name === toolCall.apiName);
|
||||
const apiLevelConfig = api?.humanIntervention;
|
||||
|
||||
// If API has its own intervention config, use it
|
||||
if (apiLevelConfig) {
|
||||
// Parse tool arguments
|
||||
let toolArgs: Record<string, any> = {};
|
||||
try {
|
||||
toolArgs = JSON.parse(toolCall.arguments);
|
||||
} catch (error) {
|
||||
log(
|
||||
`[${this.config.sessionId}] Failed to parse tool arguments for ${toolCall.identifier}/${toolCall.apiName}: %o`,
|
||||
error,
|
||||
);
|
||||
}
|
||||
|
||||
// Check intervention using InterventionChecker
|
||||
const policy = InterventionChecker.shouldIntervene({
|
||||
config: apiLevelConfig,
|
||||
toolArgs,
|
||||
// TODO: Add confirmedHistory support when implementing 'first' policy
|
||||
// confirmedHistory: state.metadata?.confirmedToolCalls || [],
|
||||
// toolKey: InterventionChecker.generateToolKey(toolCall.identifier, toolCall.apiName),
|
||||
});
|
||||
|
||||
log(
|
||||
`[${this.config.sessionId}] API-level intervention check for ${toolCall.identifier}/${toolCall.apiName}: %s`,
|
||||
policy,
|
||||
);
|
||||
|
||||
return policy;
|
||||
}
|
||||
|
||||
// Otherwise, use tool-level default intervention policy
|
||||
const toolLevelPolicy = manifest.humanIntervention || 'never';
|
||||
|
||||
log(
|
||||
`[${this.config.sessionId}] Tool-level intervention check for ${toolCall.identifier}/${toolCall.apiName}: %s`,
|
||||
toolLevelPolicy,
|
||||
);
|
||||
|
||||
return toolLevelPolicy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Empty tools registry
|
||||
*/
|
||||
tools = {};
|
||||
|
||||
/**
|
||||
* Get configuration
|
||||
*/
|
||||
getConfig() {
|
||||
return this.config;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,611 @@
|
||||
import {
|
||||
AgentEvent,
|
||||
AgentInstruction,
|
||||
CallLLMPayload,
|
||||
InstructionExecutor,
|
||||
UsageCounter,
|
||||
} from '@lobechat/agent-runtime';
|
||||
import { ToolNameResolver } from '@lobechat/context-engine';
|
||||
import { consumeStreamUntilDone } from '@lobechat/model-runtime';
|
||||
import { ChatToolPayload, ClientSecretPayload, MessageToolCall } from '@lobechat/types';
|
||||
import debug from 'debug';
|
||||
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
import { GeneralAgentLLMResultPayload } from '@/server/modules/AgentRuntime/GeneralAgent';
|
||||
import { initModelRuntimeWithUserPayload } from '@/server/modules/ModelRuntime';
|
||||
import { ToolExecutionService } from '@/server/services/toolExecution';
|
||||
|
||||
import { StreamEventManager } from './StreamEventManager';
|
||||
|
||||
const log = debug('lobe-server:agent-runtime:streaming-executors');
|
||||
|
||||
// Tool pricing configuration (USD per call)
|
||||
const TOOL_PRICING: Record<string, number> = {
|
||||
'lobe-web-browsing/craw': 0.002,
|
||||
'lobe-web-browsing/search': 0.001,
|
||||
};
|
||||
|
||||
export interface RuntimeExecutorContext {
|
||||
fileService?: any;
|
||||
messageModel: MessageModel;
|
||||
sessionId: string;
|
||||
stepIndex: number;
|
||||
streamManager: StreamEventManager;
|
||||
toolExecutionService: ToolExecutionService;
|
||||
userId?: string;
|
||||
userPayload?: ClientSecretPayload;
|
||||
}
|
||||
|
||||
export const createRuntimeExecutors = (
|
||||
ctx: RuntimeExecutorContext,
|
||||
): Partial<Record<AgentInstruction['type'], InstructionExecutor>> => ({
|
||||
/**
|
||||
* 创建流式 LLM 执行器
|
||||
* 集成 Agent Runtime 和流式事件发布
|
||||
*/
|
||||
call_llm: async (instruction, state) => {
|
||||
const { payload } = instruction as Extract<AgentInstruction, { type: 'call_llm' }>;
|
||||
const llmPayload = payload as CallLLMPayload;
|
||||
const { sessionId, stepIndex, streamManager } = ctx;
|
||||
const events: AgentEvent[] = [];
|
||||
|
||||
// 类型断言确保 payload 的正确性
|
||||
const sessionLogId = `${sessionId}:${stepIndex}`;
|
||||
|
||||
const stagePrefix = `[${sessionLogId}][call_llm]`;
|
||||
|
||||
log(`${stagePrefix} Starting session`);
|
||||
|
||||
// create assistant message
|
||||
const assistantMessageItem = await ctx.messageModel.create({
|
||||
content: '',
|
||||
fromModel: llmPayload.model,
|
||||
fromProvider: llmPayload.provider,
|
||||
role: 'assistant',
|
||||
sessionId: state.metadata!.sessionId!,
|
||||
threadId: state.metadata?.threadId,
|
||||
topicId: state.metadata?.topicId,
|
||||
});
|
||||
|
||||
// 发布流式开始事件
|
||||
await streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
assistantMessage: assistantMessageItem,
|
||||
model: llmPayload.model,
|
||||
provider: llmPayload.provider,
|
||||
},
|
||||
stepIndex,
|
||||
type: 'stream_start',
|
||||
});
|
||||
|
||||
try {
|
||||
let content = '';
|
||||
let toolsCalling: ChatToolPayload[] = [];
|
||||
let tool_calls: MessageToolCall[] = [];
|
||||
let thinkingContent = '';
|
||||
let imageList: any[] = [];
|
||||
let grounding: any = null;
|
||||
let currentStepUsage: any = undefined;
|
||||
|
||||
// 初始化 ModelRuntime
|
||||
const modelRuntime = initModelRuntimeWithUserPayload(
|
||||
llmPayload.provider,
|
||||
ctx.userPayload || {},
|
||||
);
|
||||
|
||||
// 构造 ChatStreamPayload
|
||||
const chatPayload = {
|
||||
messages: llmPayload.messages,
|
||||
model: llmPayload.model,
|
||||
tools: llmPayload.tools,
|
||||
};
|
||||
|
||||
log(
|
||||
`${stagePrefix} calling model-runtime chat (model: %s, messages: %d, tools: %d)`,
|
||||
llmPayload.model,
|
||||
llmPayload.messages.length,
|
||||
llmPayload.tools?.length ?? 0,
|
||||
);
|
||||
|
||||
// Buffer:累积 text 和 reasoning,每 50ms 发送一次
|
||||
const BUFFER_INTERVAL = 50;
|
||||
let textBuffer = '';
|
||||
let reasoningBuffer = '';
|
||||
// eslint-disable-next-line no-undef
|
||||
let textBufferTimer: NodeJS.Timeout | null = null;
|
||||
// eslint-disable-next-line no-undef
|
||||
let reasoningBufferTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
const flushTextBuffer = async () => {
|
||||
const delta = textBuffer;
|
||||
textBuffer = '';
|
||||
|
||||
if (!!delta) {
|
||||
log(`[${sessionLogId}] flushTextBuffer:`, delta);
|
||||
|
||||
// 构建标准 Agent Runtime 事件
|
||||
events.push({
|
||||
chunk: { text: delta, type: 'text' },
|
||||
type: 'llm_stream',
|
||||
});
|
||||
|
||||
await streamManager.publishStreamChunk(sessionId, stepIndex, {
|
||||
chunkType: 'text',
|
||||
content: delta,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const flushReasoningBuffer = async () => {
|
||||
const delta = reasoningBuffer;
|
||||
|
||||
reasoningBuffer = '';
|
||||
|
||||
if (!!delta) {
|
||||
log(`[${sessionLogId}] flushReasoningBuffer:`, delta);
|
||||
|
||||
events.push({
|
||||
chunk: { text: delta, type: 'reasoning' },
|
||||
type: 'llm_stream',
|
||||
});
|
||||
|
||||
await streamManager.publishStreamChunk(sessionId, stepIndex, {
|
||||
chunkType: 'reasoning',
|
||||
reasoning: delta,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// 调用 model-runtime chat
|
||||
const response = await modelRuntime.chat(chatPayload, {
|
||||
callback: {
|
||||
onCompletion: async (data) => {
|
||||
// 捕获 usage (可能包含 cost,也可能不包含)
|
||||
if (data.usage) {
|
||||
currentStepUsage = data.usage;
|
||||
}
|
||||
},
|
||||
onText: async (text) => {
|
||||
// log(`[${sessionLogId}][text]`, text);
|
||||
content += text;
|
||||
|
||||
textBuffer += text;
|
||||
|
||||
// 如果没有定时器,创建一个
|
||||
if (!textBufferTimer) {
|
||||
textBufferTimer = setTimeout(async () => {
|
||||
await flushTextBuffer();
|
||||
textBufferTimer = null;
|
||||
}, BUFFER_INTERVAL);
|
||||
}
|
||||
},
|
||||
onThinking: async (reasoning) => {
|
||||
// log(`[${sessionLogId}][reasoning]`, reasoning);
|
||||
thinkingContent += reasoning;
|
||||
|
||||
// Buffer reasoning 内容
|
||||
reasoningBuffer += reasoning;
|
||||
|
||||
// 如果没有定时器,创建一个
|
||||
if (!reasoningBufferTimer) {
|
||||
reasoningBufferTimer = setTimeout(async () => {
|
||||
await flushReasoningBuffer();
|
||||
reasoningBufferTimer = null;
|
||||
}, BUFFER_INTERVAL);
|
||||
}
|
||||
},
|
||||
onToolsCalling: async ({ toolsCalling: raw }) => {
|
||||
const payload = new ToolNameResolver().resolve(raw, state.toolManifestMap);
|
||||
// log(`[${sessionLogId}][toolsCalling]`, payload);
|
||||
toolsCalling = payload;
|
||||
tool_calls = raw;
|
||||
|
||||
// 如果有 textBuffer,先推一次
|
||||
if (!!textBuffer) {
|
||||
await flushTextBuffer();
|
||||
}
|
||||
|
||||
await streamManager.publishStreamChunk(sessionId, stepIndex, {
|
||||
chunkType: 'tools_calling',
|
||||
toolsCalling: payload,
|
||||
});
|
||||
},
|
||||
},
|
||||
user: ctx.userId,
|
||||
});
|
||||
|
||||
// 消费流确保所有回调执行完成
|
||||
await consumeStreamUntilDone(response);
|
||||
|
||||
await flushTextBuffer();
|
||||
await flushReasoningBuffer();
|
||||
|
||||
// 清理定时器并 flush 剩余 buffer
|
||||
if (textBufferTimer) {
|
||||
clearTimeout(textBufferTimer);
|
||||
textBufferTimer = null;
|
||||
}
|
||||
|
||||
if (reasoningBufferTimer) {
|
||||
clearTimeout(reasoningBufferTimer);
|
||||
reasoningBufferTimer = null;
|
||||
}
|
||||
|
||||
log(`[${sessionLogId}] finish model-runtime calling`);
|
||||
|
||||
if (thinkingContent) {
|
||||
log(`[${sessionLogId}][reasoning]`, thinkingContent);
|
||||
}
|
||||
if (content) {
|
||||
log(`[${sessionLogId}][content]`, content);
|
||||
}
|
||||
if (toolsCalling.length > 0) {
|
||||
log(`[${sessionLogId}][toolsCalling] `, toolsCalling);
|
||||
}
|
||||
|
||||
// 日志输出 usage
|
||||
if (currentStepUsage) {
|
||||
log(`[${sessionLogId}][usage] %O`, currentStepUsage);
|
||||
}
|
||||
|
||||
// 添加一个完整的 llm_stream 事件(包含所有流式块)
|
||||
events.push({
|
||||
result: { content, reasoning: thinkingContent, tool_calls, usage: currentStepUsage },
|
||||
type: 'llm_result',
|
||||
});
|
||||
|
||||
// 发布流式结束事件
|
||||
await streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
finalContent: content,
|
||||
grounding: grounding,
|
||||
imageList: imageList.length > 0 ? imageList : undefined,
|
||||
reasoning: thinkingContent || undefined,
|
||||
toolsCalling: toolsCalling,
|
||||
usage: currentStepUsage,
|
||||
},
|
||||
stepIndex,
|
||||
type: 'stream_end',
|
||||
});
|
||||
|
||||
log('[%s:%d] call_llm completed', sessionId, stepIndex);
|
||||
|
||||
// ===== 1. 先保存原始 usage 到 message.metadata =====
|
||||
try {
|
||||
await ctx.messageModel.update(assistantMessageItem.id, {
|
||||
content,
|
||||
// 保存原始 usage,不做任何修改
|
||||
metadata: currentStepUsage,
|
||||
reasoning: {
|
||||
content: thinkingContent,
|
||||
},
|
||||
search: grounding,
|
||||
tools: toolsCalling.length > 0 ? toolsCalling : undefined,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[call_llm] Failed to update message:', error);
|
||||
}
|
||||
|
||||
// ===== 2. 然后累加到 AgentState =====
|
||||
let newState = structuredClone(state);
|
||||
|
||||
newState.messages.push({
|
||||
content,
|
||||
role: 'assistant',
|
||||
tool_calls: tool_calls.length > 0 ? tool_calls : undefined,
|
||||
});
|
||||
|
||||
if (currentStepUsage) {
|
||||
// 使用 UsageCounter 统一累加 usage 和 cost
|
||||
const { usage, cost } = UsageCounter.accumulateLLM({
|
||||
cost: newState.cost,
|
||||
model: llmPayload.model,
|
||||
modelUsage: currentStepUsage,
|
||||
provider: llmPayload.provider,
|
||||
usage: newState.usage,
|
||||
});
|
||||
|
||||
newState.usage = usage;
|
||||
if (cost) newState.cost = cost;
|
||||
}
|
||||
|
||||
return {
|
||||
events,
|
||||
newState,
|
||||
nextContext: {
|
||||
payload: {
|
||||
hasToolsCalling: toolsCalling.length > 0,
|
||||
result: { content, tool_calls },
|
||||
toolsCalling: toolsCalling,
|
||||
} as GeneralAgentLLMResultPayload,
|
||||
phase: 'llm_result',
|
||||
session: {
|
||||
eventCount: events.length,
|
||||
messageCount: newState.messages.length,
|
||||
sessionId: state.sessionId,
|
||||
status: 'running',
|
||||
stepCount: state.stepCount + 1,
|
||||
},
|
||||
stepUsage: currentStepUsage,
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
// 发布错误事件
|
||||
await streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
error: (error as Error).message,
|
||||
phase: 'llm_execution',
|
||||
},
|
||||
stepIndex,
|
||||
type: 'error',
|
||||
});
|
||||
|
||||
console.error(
|
||||
`[StreamingLLMExecutor][${sessionId}:${stepIndex}] LLM execution failed:`,
|
||||
error,
|
||||
);
|
||||
throw error;
|
||||
}
|
||||
},
|
||||
/**
|
||||
* 工具执行
|
||||
*/
|
||||
call_tool: async (instruction, state) => {
|
||||
const { payload } = instruction as Extract<AgentInstruction, { type: 'call_tool' }>;
|
||||
const { sessionId, stepIndex, streamManager, toolExecutionService } = ctx;
|
||||
const events: AgentEvent[] = [];
|
||||
|
||||
const sessionLogId = `${sessionId}:${stepIndex}`;
|
||||
log(`[${sessionLogId}] payload: %O`, payload);
|
||||
|
||||
// 发布工具执行开始事件
|
||||
await streamManager.publishStreamEvent(sessionId, {
|
||||
data: payload,
|
||||
stepIndex,
|
||||
type: 'tool_start',
|
||||
});
|
||||
|
||||
try {
|
||||
// Convert CallingToolPayload to ChatToolPayload for ToolExecutionService
|
||||
const chatToolPayload: ChatToolPayload = {
|
||||
apiName: payload.apiName,
|
||||
arguments: payload.arguments,
|
||||
id: payload.id,
|
||||
identifier: payload.identifier,
|
||||
type: payload.type as any, // CallingToolPayload.type is compatible
|
||||
};
|
||||
|
||||
const toolName = `${chatToolPayload.identifier}/${chatToolPayload.apiName}`;
|
||||
// Execute tool using ToolExecutionService
|
||||
log(`[${sessionLogId}] Executing tool ${toolName} ...`);
|
||||
const executionResult = await toolExecutionService.executeTool(chatToolPayload, {
|
||||
toolManifestMap: state.toolManifestMap,
|
||||
userId: ctx.userId,
|
||||
userPayload: ctx.userPayload,
|
||||
});
|
||||
|
||||
const executionTime = executionResult.executionTime;
|
||||
const isSuccess = executionResult.success;
|
||||
log(
|
||||
`[${sessionLogId}] Executing ${toolName} in ${executionTime}ms, result: %O`,
|
||||
executionResult,
|
||||
);
|
||||
|
||||
// 发布工具执行结果事件
|
||||
await streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
executionTime,
|
||||
isSuccess,
|
||||
payload,
|
||||
phase: 'tool_execution',
|
||||
result: executionResult,
|
||||
},
|
||||
stepIndex,
|
||||
type: 'tool_end',
|
||||
});
|
||||
|
||||
// 最终更新数据库
|
||||
try {
|
||||
await ctx.messageModel.create({
|
||||
content: executionResult.content,
|
||||
plugin: payload as any,
|
||||
pluginError: executionResult.error,
|
||||
pluginState: executionResult.state,
|
||||
role: 'tool',
|
||||
sessionId: state.metadata!.sessionId!,
|
||||
threadId: state.metadata?.threadId,
|
||||
tool_call_id: payload.id,
|
||||
topicId: state.metadata?.topicId,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[StreamingToolExecutor] Failed to create tool message: %O', error);
|
||||
}
|
||||
|
||||
const newState = structuredClone(state);
|
||||
|
||||
newState.messages.push({
|
||||
content: executionResult.content,
|
||||
role: 'tool',
|
||||
tool_call_id: payload.id,
|
||||
});
|
||||
|
||||
events.push({ id: payload.id, result: executionResult, type: 'tool_result' });
|
||||
|
||||
// 获取工具单价
|
||||
const toolCost = TOOL_PRICING[toolName] || 0;
|
||||
|
||||
// 使用 UsageCounter 统一累加 tool usage
|
||||
const { usage, cost } = UsageCounter.accumulateTool({
|
||||
cost: newState.cost,
|
||||
executionTime,
|
||||
success: isSuccess,
|
||||
toolCost,
|
||||
toolName,
|
||||
usage: newState.usage,
|
||||
});
|
||||
|
||||
newState.usage = usage;
|
||||
if (cost) newState.cost = cost;
|
||||
|
||||
// 查找当前工具的统计信息
|
||||
const currentToolStats = usage.tools.byTool.find((t) => t.name === toolName);
|
||||
|
||||
// 日志输出 usage
|
||||
log(
|
||||
`[${sessionLogId}][tool usage] %s: calls=%d, time=%dms, success=%s, cost=$%s`,
|
||||
toolName,
|
||||
currentToolStats?.calls || 0,
|
||||
executionTime,
|
||||
isSuccess,
|
||||
toolCost.toFixed(4),
|
||||
);
|
||||
|
||||
log('[%s:%d] Tool execution completed', sessionId, stepIndex);
|
||||
|
||||
return {
|
||||
events,
|
||||
newState,
|
||||
nextContext: {
|
||||
payload: {
|
||||
data: executionResult,
|
||||
executionTime,
|
||||
isSuccess,
|
||||
toolCall: payload,
|
||||
toolCallId: payload.id,
|
||||
},
|
||||
phase: 'tool_result',
|
||||
session: {
|
||||
eventCount: events.length,
|
||||
messageCount: newState.messages.length,
|
||||
sessionId: state.sessionId,
|
||||
status: 'running',
|
||||
stepCount: state.stepCount + 1,
|
||||
},
|
||||
stepUsage: {
|
||||
cost: toolCost,
|
||||
toolName,
|
||||
unitPrice: toolCost,
|
||||
usageCount: 1,
|
||||
},
|
||||
},
|
||||
};
|
||||
} catch (error) {
|
||||
// 发布工具执行错误事件
|
||||
await streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
error: (error as Error).message,
|
||||
phase: 'tool_execution',
|
||||
},
|
||||
stepIndex,
|
||||
type: 'error',
|
||||
});
|
||||
|
||||
events.push({
|
||||
error: error,
|
||||
type: 'error',
|
||||
});
|
||||
|
||||
console.error(
|
||||
`[StreamingToolExecutor] Tool execution failed for session ${sessionId}:${stepIndex}:`,
|
||||
error,
|
||||
);
|
||||
|
||||
return {
|
||||
events,
|
||||
newState: state, // 状态不变
|
||||
};
|
||||
}
|
||||
},
|
||||
/**
|
||||
* 完成 runtime 运行
|
||||
*/
|
||||
finish: async (instruction, state) => {
|
||||
const { reason, reasonDetail } = instruction as Extract<AgentInstruction, { type: 'finish' }>;
|
||||
const { sessionId, stepIndex, streamManager } = ctx;
|
||||
|
||||
log('[%s:%d] Finishing execution: (%s)', sessionId, stepIndex, reason);
|
||||
|
||||
// 发布执行完成事件
|
||||
await streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
finalState: { ...state, status: 'done' },
|
||||
phase: 'execution_complete',
|
||||
reason,
|
||||
reasonDetail,
|
||||
},
|
||||
stepIndex,
|
||||
type: 'step_complete',
|
||||
});
|
||||
|
||||
const newState = structuredClone(state);
|
||||
newState.lastModified = new Date().toISOString();
|
||||
newState.status = 'done';
|
||||
|
||||
const events: AgentEvent[] = [
|
||||
{
|
||||
finalState: newState,
|
||||
reason,
|
||||
reasonDetail,
|
||||
type: 'done',
|
||||
},
|
||||
];
|
||||
|
||||
return { events, newState };
|
||||
},
|
||||
|
||||
/**
|
||||
* 人工审批
|
||||
*/
|
||||
request_human_approve: async (instruction, state) => {
|
||||
const { pendingToolsCalling } = instruction as Extract<
|
||||
AgentInstruction,
|
||||
{ type: 'request_human_approve' }
|
||||
>;
|
||||
const { sessionId, stepIndex, streamManager } = ctx;
|
||||
|
||||
log('[%s:%d] Requesting human approval for %O', sessionId, stepIndex, pendingToolsCalling);
|
||||
|
||||
// 发布人工审批请求事件
|
||||
await streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
pendingToolsCalling,
|
||||
phase: 'human_approval',
|
||||
requiresApproval: true,
|
||||
},
|
||||
stepIndex,
|
||||
type: 'step_start',
|
||||
});
|
||||
|
||||
const newState = structuredClone(state);
|
||||
newState.lastModified = new Date().toISOString();
|
||||
newState.status = 'waiting_for_human';
|
||||
newState.pendingToolsCalling = pendingToolsCalling;
|
||||
|
||||
// 通过流式系统通知前端显示审批 UI
|
||||
await streamManager.publishStreamChunk(sessionId, stepIndex, {
|
||||
// 使用 sessionId 作为 messageId
|
||||
chunkType: 'tools_calling',
|
||||
toolsCalling: pendingToolsCalling as any,
|
||||
});
|
||||
|
||||
const events: AgentEvent[] = [
|
||||
{
|
||||
pendingToolsCalling,
|
||||
sessionId: newState.sessionId,
|
||||
type: 'human_approve_required',
|
||||
},
|
||||
{
|
||||
toolCalls: pendingToolsCalling,
|
||||
type: 'tool_pending',
|
||||
},
|
||||
];
|
||||
|
||||
log('Human approval requested for session %s:%d', sessionId, stepIndex);
|
||||
|
||||
return {
|
||||
events,
|
||||
newState,
|
||||
// 不提供 nextContext,因为需要等待人工干预
|
||||
};
|
||||
},
|
||||
});
|
||||
@@ -0,0 +1,284 @@
|
||||
import { ChatToolPayload } from '@lobechat/types';
|
||||
import debug from 'debug';
|
||||
import Redis from 'ioredis';
|
||||
|
||||
import { getRedisClient } from '@/libs/redis';
|
||||
|
||||
const log = debug('lobe-server:agent-runtime:stream-event-manager');
|
||||
|
||||
export interface StreamEvent {
|
||||
data: any;
|
||||
id?: string; // Redis Stream event ID
|
||||
sessionId: string;
|
||||
stepIndex: number;
|
||||
timestamp: number;
|
||||
type:
|
||||
| 'agent_runtime_init'
|
||||
| 'agent_runtime_end'
|
||||
| 'stream_start'
|
||||
| 'stream_chunk'
|
||||
| 'stream_end'
|
||||
| 'tool_start'
|
||||
| 'tool_end'
|
||||
| 'step_start'
|
||||
| 'step_complete'
|
||||
| 'error';
|
||||
}
|
||||
|
||||
export interface StreamChunkData {
|
||||
chunkType: 'text' | 'reasoning' | 'tools_calling' | 'image' | 'grounding';
|
||||
content?: string;
|
||||
images?: any[];
|
||||
reasoning?: string;
|
||||
toolsCalling?: ChatToolPayload[];
|
||||
}
|
||||
|
||||
export class StreamEventManager {
|
||||
private redis: Redis;
|
||||
private readonly STREAM_PREFIX = 'agent_runtime_stream';
|
||||
private readonly STREAM_RETENTION = 3600; // 1小时
|
||||
|
||||
constructor() {
|
||||
const redisClient = getRedisClient();
|
||||
if (!redisClient) {
|
||||
throw new Error('Redis is not available. Please configure REDIS_URL environment variable.');
|
||||
}
|
||||
this.redis = redisClient;
|
||||
}
|
||||
|
||||
/**
|
||||
* 发布流式事件到 Redis Stream
|
||||
*/
|
||||
async publishStreamEvent(
|
||||
sessionId: string,
|
||||
event: Omit<StreamEvent, 'sessionId' | 'timestamp'>,
|
||||
): Promise<string> {
|
||||
const streamKey = `${this.STREAM_PREFIX}:${sessionId}`;
|
||||
|
||||
const eventData: StreamEvent = {
|
||||
...event,
|
||||
sessionId,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const eventId = await this.redis.xadd(
|
||||
streamKey,
|
||||
'MAXLEN',
|
||||
'~',
|
||||
'1000', // 限制流长度,防止内存溢出
|
||||
'*', // 自动生成 ID
|
||||
'type',
|
||||
eventData.type,
|
||||
'stepIndex',
|
||||
eventData.stepIndex.toString(),
|
||||
'sessionId',
|
||||
eventData.sessionId,
|
||||
'data',
|
||||
JSON.stringify(eventData.data),
|
||||
'timestamp',
|
||||
eventData.timestamp.toString(),
|
||||
);
|
||||
|
||||
// 设置过期时间
|
||||
await this.redis.expire(streamKey, this.STREAM_RETENTION);
|
||||
|
||||
log('Published event %s for session %s:%d', eventData.type, sessionId, eventData.stepIndex);
|
||||
|
||||
return eventId as string;
|
||||
} catch (error) {
|
||||
console.error('[StreamEventManager] Failed to publish stream event:', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 发布流式内容块
|
||||
*/
|
||||
async publishStreamChunk(
|
||||
sessionId: string,
|
||||
stepIndex: number,
|
||||
chunkData: StreamChunkData,
|
||||
): Promise<string> {
|
||||
return this.publishStreamEvent(sessionId, {
|
||||
data: chunkData,
|
||||
stepIndex,
|
||||
type: 'stream_chunk',
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 发布 Agent 运行时初始化事件
|
||||
*/
|
||||
async publishAgentRuntimeInit(sessionId: string, initialState: any): Promise<string> {
|
||||
return this.publishStreamEvent(sessionId, {
|
||||
data: initialState,
|
||||
stepIndex: 0,
|
||||
type: 'agent_runtime_init',
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 发布 Agent 运行时结束事件
|
||||
*/
|
||||
async publishAgentRuntimeEnd(
|
||||
sessionId: string,
|
||||
stepIndex: number,
|
||||
finalState: any,
|
||||
reason?: string,
|
||||
reasonDetail?: string,
|
||||
): Promise<string> {
|
||||
return this.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
finalState,
|
||||
phase: 'execution_complete',
|
||||
reason: reason || 'completed',
|
||||
reasonDetail: reasonDetail || 'Agent runtime completed successfully',
|
||||
sessionId,
|
||||
},
|
||||
stepIndex,
|
||||
type: 'agent_runtime_end',
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 订阅流式事件(用于 WebSocket/SSE)
|
||||
*/
|
||||
async subscribeStreamEvents(
|
||||
sessionId: string,
|
||||
lastEventId: string = '0',
|
||||
onEvents: (events: StreamEvent[]) => void,
|
||||
signal?: AbortSignal,
|
||||
): Promise<void> {
|
||||
const streamKey = `${this.STREAM_PREFIX}:${sessionId}`;
|
||||
let currentLastId = lastEventId;
|
||||
|
||||
log('Starting subscription for session %s from %s', sessionId, lastEventId);
|
||||
|
||||
while (!signal?.aborted) {
|
||||
try {
|
||||
const results = await this.redis.xread(
|
||||
'BLOCK',
|
||||
1000, // 1秒超时
|
||||
'STREAMS',
|
||||
streamKey,
|
||||
currentLastId,
|
||||
);
|
||||
|
||||
if (results && results.length > 0) {
|
||||
const [, messages] = results[0];
|
||||
const events: StreamEvent[] = [];
|
||||
|
||||
for (const [id, fields] of messages) {
|
||||
const eventData: any = {};
|
||||
|
||||
// 解析 Redis Stream 字段
|
||||
for (let i = 0; i < fields.length; i += 2) {
|
||||
const key = fields[i];
|
||||
const value = fields[i + 1];
|
||||
|
||||
if (key === 'data') {
|
||||
eventData[key] = JSON.parse(value);
|
||||
} else if (key === 'stepIndex' || key === 'timestamp') {
|
||||
eventData[key] = parseInt(value);
|
||||
} else {
|
||||
eventData[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
events.push({
|
||||
...eventData,
|
||||
id, // Redis Stream 事件 ID
|
||||
} as StreamEvent);
|
||||
|
||||
currentLastId = id;
|
||||
}
|
||||
|
||||
if (events.length > 0) {
|
||||
onEvents(events);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
if (signal?.aborted) {
|
||||
break;
|
||||
}
|
||||
|
||||
console.error('[StreamEventManager] Stream subscription error:', error);
|
||||
// 短暂延迟后重试
|
||||
await new Promise((resolve) => {
|
||||
setTimeout(resolve, 1000);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
log('Subscription ended for session %s', sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取流式事件历史
|
||||
*/
|
||||
async getStreamHistory(sessionId: string, count: number = 100): Promise<StreamEvent[]> {
|
||||
const streamKey = `${this.STREAM_PREFIX}:${sessionId}`;
|
||||
|
||||
try {
|
||||
const results = await this.redis.xrevrange(streamKey, '+', '-', 'COUNT', count);
|
||||
|
||||
return results.map(([id, fields]) => {
|
||||
const eventData: any = { id };
|
||||
|
||||
for (let i = 0; i < fields.length; i += 2) {
|
||||
const key = fields[i];
|
||||
const value = fields[i + 1];
|
||||
|
||||
if (key === 'data') {
|
||||
eventData[key] = JSON.parse(value);
|
||||
} else if (key === 'stepIndex' || key === 'timestamp') {
|
||||
eventData[key] = parseInt(value);
|
||||
} else {
|
||||
eventData[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
return eventData as StreamEvent;
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('[StreamEventManager] Failed to get stream history:', error);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理会话的流式数据
|
||||
*/
|
||||
async cleanupSession(sessionId: string): Promise<void> {
|
||||
const streamKey = `${this.STREAM_PREFIX}:${sessionId}`;
|
||||
|
||||
try {
|
||||
await this.redis.del(streamKey);
|
||||
log('Cleaned up session %s', sessionId);
|
||||
} catch (error) {
|
||||
console.error('[StreamEventManager] Failed to cleanup session:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取活跃会话数量
|
||||
*/
|
||||
async getActiveSessionsCount(): Promise<number> {
|
||||
try {
|
||||
const pattern = `${this.STREAM_PREFIX}:*`;
|
||||
const keys = await this.redis.keys(pattern);
|
||||
return keys.length;
|
||||
} catch (error) {
|
||||
console.error('[StreamEventManager] Failed to get active sessions count:', error);
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 关闭 Redis 连接
|
||||
*/
|
||||
async disconnect(): Promise<void> {
|
||||
await this.redis.quit();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
Feature: Agent 运行时协调器
|
||||
作为一个使用 Agent 运行时系统的开发者
|
||||
我想要协调智能体状态管理和事件流
|
||||
以便我能够正确跟踪智能体生命周期并处理事件
|
||||
|
||||
Background:
|
||||
Given 智能体运行时系统已经启动
|
||||
And 协调器已经初始化完成
|
||||
|
||||
Scenario: 创建新的智能体会话
|
||||
When 我创建一个新的智能体会话 "chat-session-001"
|
||||
And 会话配置为:
|
||||
| 字段 | 值 |
|
||||
| 用户ID | user-12345 |
|
||||
| 模型配置 | GPT-4 温度0.7 |
|
||||
| 智能体类型 | 对话助手 |
|
||||
Then 会话应该成功创建
|
||||
And 系统应该发布会话初始化事件
|
||||
And 事件应该包含会话的基本信息
|
||||
|
||||
Scenario: 智能体执行任务并完成
|
||||
Given 智能体会话 "chat-session-001" 已经创建
|
||||
And 智能体当前状态为 "运行中"
|
||||
When 智能体完成所有任务
|
||||
And 会话状态更新为 "已完成"
|
||||
Then 系统应该保存最终状态
|
||||
And 应该发布会话结束事件
|
||||
And 事件应该包含执行结果和统计信息
|
||||
|
||||
Scenario: 智能体正在执行任务
|
||||
Given 智能体会话 "chat-session-001" 已经创建
|
||||
And 智能体当前状态为 "空闲"
|
||||
When 智能体开始执行任务
|
||||
And 会话状态更新为 "运行中"
|
||||
Then 系统应该保存当前状态
|
||||
But 不应该发布会话结束事件
|
||||
|
||||
Scenario: 完成单个执行步骤
|
||||
Given 智能体会话 "chat-session-001" 正在运行
|
||||
When 智能体完成第5个执行步骤
|
||||
And 该步骤标记为最终步骤
|
||||
Then 系统应该保存步骤结果
|
||||
And 应该发布会话结束事件
|
||||
And 事件应该包含完整的执行历史
|
||||
|
||||
Scenario: 完成中间执行步骤
|
||||
Given 智能体会话 "chat-session-001" 正在运行
|
||||
When 智能体完成第3个执行步骤
|
||||
And 该步骤不是最终步骤
|
||||
Then 系统应该保存步骤结果
|
||||
But 不应该发布会话结束事件
|
||||
|
||||
Scenario: 查询智能体执行状态
|
||||
Given 智能体会话 "chat-session-001" 存在
|
||||
When 我查询会话的当前状态
|
||||
Then 应该返回最新的状态信息
|
||||
And 信息应该包含当前步骤数和执行状态
|
||||
|
||||
Scenario: 查询会话元数据
|
||||
Given 智能体会话 "chat-session-001" 存在
|
||||
When 我查询会话的元数据
|
||||
Then 应该返回会话的配置信息
|
||||
And 信息应该包含用户ID、模型配置和创建时间
|
||||
|
||||
Scenario: 查询执行历史
|
||||
Given 智能体会话 "chat-session-001" 已经执行了多个步骤
|
||||
When 我查询最近10步的执行历史
|
||||
Then 应该返回包含10个步骤的历史记录
|
||||
And 每个记录应该包含步骤索引、执行时间和状态
|
||||
|
||||
Scenario: 清理会话资源
|
||||
Given 智能体会话 "chat-session-001" 已经完成
|
||||
When 我删除该会话
|
||||
Then 会话的所有数据应该被清理
|
||||
And 相关的事件流数据也应该被清理
|
||||
|
||||
Scenario: 断开系统连接
|
||||
Given 协调器正在运行中
|
||||
When 我关闭协调器
|
||||
Then 所有的数据库连接应该被正确关闭
|
||||
And 所有的资源应该被释放
|
||||
|
||||
Scenario: 获取系统统计信息
|
||||
Given 系统中存在多个智能体会话
|
||||
When 我查询系统统计信息
|
||||
Then 应该返回活跃会话数量
|
||||
And 应该返回已完成会话数量
|
||||
And 应该返回出错会话数量
|
||||
|
||||
Scenario: 清理过期会话
|
||||
Given 系统中存在一些长时间未活跃的会话
|
||||
When 我执行过期会话清理
|
||||
Then 超过保留期限的会话应该被删除
|
||||
And 应该返回清理的会话数量
|
||||
@@ -0,0 +1,245 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { AgentRuntimeCoordinator } from '../AgentRuntimeCoordinator';
|
||||
import { AgentStateManager } from '../AgentStateManager';
|
||||
import { StreamEventManager } from '../StreamEventManager';
|
||||
|
||||
// Mock AgentStateManager
|
||||
vi.mock('../AgentStateManager', () => ({
|
||||
AgentStateManager: vi.fn(() => ({
|
||||
cleanupExpiredSessions: vi.fn(),
|
||||
createSessionMetadata: vi.fn(),
|
||||
deleteAgentSession: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getActiveSessions: vi.fn(),
|
||||
getExecutionHistory: vi.fn(),
|
||||
getSessionMetadata: vi.fn(),
|
||||
getStats: vi.fn(),
|
||||
loadAgentState: vi.fn(),
|
||||
saveAgentState: vi.fn(),
|
||||
saveStepResult: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock StreamEventManager
|
||||
vi.mock('../StreamEventManager', () => ({
|
||||
StreamEventManager: vi.fn(() => ({
|
||||
cleanupSession: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
publishAgentRuntimeEnd: vi.fn(),
|
||||
publishAgentRuntimeInit: vi.fn(),
|
||||
publishStreamEvent: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('AgentRuntimeCoordinator', () => {
|
||||
const MockedAgentStateManager = AgentStateManager as any;
|
||||
const MockedStreamEventManager = StreamEventManager as any;
|
||||
let coordinator: AgentRuntimeCoordinator;
|
||||
let mockStateManager: any;
|
||||
let mockStreamManager: any;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
mockStateManager = {
|
||||
cleanupExpiredSessions: vi.fn(),
|
||||
createSessionMetadata: vi.fn(),
|
||||
deleteAgentSession: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
getActiveSessions: vi.fn(),
|
||||
getExecutionHistory: vi.fn(),
|
||||
getSessionMetadata: vi.fn(),
|
||||
getStats: vi.fn(),
|
||||
loadAgentState: vi.fn(),
|
||||
saveAgentState: vi.fn(),
|
||||
saveStepResult: vi.fn(),
|
||||
};
|
||||
|
||||
mockStreamManager = {
|
||||
cleanupSession: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
publishAgentRuntimeEnd: vi.fn(),
|
||||
publishAgentRuntimeInit: vi.fn(),
|
||||
publishStreamEvent: vi.fn(),
|
||||
};
|
||||
|
||||
MockedAgentStateManager.mockImplementation(() => mockStateManager);
|
||||
MockedStreamEventManager.mockImplementation(() => mockStreamManager);
|
||||
|
||||
coordinator = new AgentRuntimeCoordinator();
|
||||
});
|
||||
|
||||
describe('createAgentSession', () => {
|
||||
it('should create session metadata and publish init event', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const data = {
|
||||
agentConfig: { test: true },
|
||||
modelRuntimeConfig: { model: 'gpt-4' },
|
||||
userId: 'user-123',
|
||||
};
|
||||
const metadata = {
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
status: 'idle',
|
||||
totalCost: 0,
|
||||
totalSteps: 0,
|
||||
...data,
|
||||
};
|
||||
|
||||
mockStateManager.getSessionMetadata.mockResolvedValue(metadata);
|
||||
|
||||
await coordinator.createAgentSession(sessionId, data);
|
||||
|
||||
expect(mockStateManager.createSessionMetadata).toHaveBeenCalledWith(sessionId, data);
|
||||
expect(mockStateManager.getSessionMetadata).toHaveBeenCalledWith(sessionId);
|
||||
expect(mockStreamManager.publishAgentRuntimeInit).toHaveBeenCalledWith(sessionId, metadata);
|
||||
});
|
||||
|
||||
it('should not publish init event if metadata creation fails', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const data = { userId: 'user-123' };
|
||||
|
||||
mockStateManager.getSessionMetadata.mockResolvedValue(null);
|
||||
|
||||
await coordinator.createAgentSession(sessionId, data);
|
||||
|
||||
expect(mockStateManager.createSessionMetadata).toHaveBeenCalledWith(sessionId, data);
|
||||
expect(mockStreamManager.publishAgentRuntimeInit).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveAgentState', () => {
|
||||
it('should save state and publish end event when status changes to done', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const previousState = { status: 'running', stepCount: 3 };
|
||||
const newState = { status: 'done', stepCount: 5 };
|
||||
|
||||
mockStateManager.loadAgentState.mockResolvedValue(previousState);
|
||||
|
||||
await coordinator.saveAgentState(sessionId, newState as any);
|
||||
|
||||
expect(mockStateManager.saveAgentState).toHaveBeenCalledWith(sessionId, newState);
|
||||
expect(mockStreamManager.publishAgentRuntimeEnd).toHaveBeenCalledWith(
|
||||
sessionId,
|
||||
newState.stepCount,
|
||||
newState,
|
||||
);
|
||||
});
|
||||
|
||||
it('should not publish end event when status was already done', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const previousState = { status: 'done', stepCount: 5 };
|
||||
const newState = { status: 'done', stepCount: 5 };
|
||||
|
||||
mockStateManager.loadAgentState.mockResolvedValue(previousState);
|
||||
|
||||
await coordinator.saveAgentState(sessionId, newState as any);
|
||||
|
||||
expect(mockStateManager.saveAgentState).toHaveBeenCalledWith(sessionId, newState);
|
||||
expect(mockStreamManager.publishAgentRuntimeEnd).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not publish end event when status is not done', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const previousState = { status: 'idle', stepCount: 0 };
|
||||
const newState = { status: 'running', stepCount: 1 };
|
||||
|
||||
mockStateManager.loadAgentState.mockResolvedValue(previousState);
|
||||
|
||||
await coordinator.saveAgentState(sessionId, newState as any);
|
||||
|
||||
expect(mockStateManager.saveAgentState).toHaveBeenCalledWith(sessionId, newState);
|
||||
expect(mockStreamManager.publishAgentRuntimeEnd).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveStepResult', () => {
|
||||
it('should save step result but not publish end event (left to saveAgentState)', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const stepResult = {
|
||||
executionTime: 1000,
|
||||
newState: { status: 'done', stepCount: 5 },
|
||||
stepIndex: 5,
|
||||
};
|
||||
|
||||
await coordinator.saveStepResult(sessionId, stepResult as any);
|
||||
|
||||
expect(mockStateManager.saveStepResult).toHaveBeenCalledWith(sessionId, stepResult);
|
||||
// agent_runtime_end 事件现在由 saveAgentState 统一处理,确保它是最后一个事件
|
||||
expect(mockStreamManager.publishAgentRuntimeEnd).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should not publish end event when status is not done', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const stepResult = {
|
||||
executionTime: 500,
|
||||
newState: { status: 'running', stepCount: 3 },
|
||||
stepIndex: 3,
|
||||
};
|
||||
|
||||
await coordinator.saveStepResult(sessionId, stepResult as any);
|
||||
|
||||
expect(mockStateManager.saveStepResult).toHaveBeenCalledWith(sessionId, stepResult);
|
||||
expect(mockStreamManager.publishAgentRuntimeEnd).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('deleteAgentSession', () => {
|
||||
it('should delete session from both state manager and stream manager', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
|
||||
await coordinator.deleteAgentSession(sessionId);
|
||||
|
||||
expect(mockStateManager.deleteAgentSession).toHaveBeenCalledWith(sessionId);
|
||||
expect(mockStreamManager.cleanupSession).toHaveBeenCalledWith(sessionId);
|
||||
});
|
||||
});
|
||||
|
||||
describe('disconnect', () => {
|
||||
it('should disconnect both managers', async () => {
|
||||
await coordinator.disconnect();
|
||||
|
||||
expect(mockStateManager.disconnect).toHaveBeenCalled();
|
||||
expect(mockStreamManager.disconnect).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('delegation methods', () => {
|
||||
it('should delegate loadAgentState to state manager', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const expectedState = { status: 'running' };
|
||||
|
||||
mockStateManager.loadAgentState.mockResolvedValue(expectedState);
|
||||
|
||||
const result = await coordinator.loadAgentState(sessionId);
|
||||
|
||||
expect(mockStateManager.loadAgentState).toHaveBeenCalledWith(sessionId);
|
||||
expect(result).toBe(expectedState);
|
||||
});
|
||||
|
||||
it('should delegate getSessionMetadata to state manager', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const expectedMetadata = { status: 'idle' };
|
||||
|
||||
mockStateManager.getSessionMetadata.mockResolvedValue(expectedMetadata);
|
||||
|
||||
const result = await coordinator.getSessionMetadata(sessionId);
|
||||
|
||||
expect(mockStateManager.getSessionMetadata).toHaveBeenCalledWith(sessionId);
|
||||
expect(result).toBe(expectedMetadata);
|
||||
});
|
||||
|
||||
it('should delegate getExecutionHistory to state manager', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const limit = 10;
|
||||
const expectedHistory = [{ step: 1 }];
|
||||
|
||||
mockStateManager.getExecutionHistory.mockResolvedValue(expectedHistory);
|
||||
|
||||
const result = await coordinator.getExecutionHistory(sessionId, limit);
|
||||
|
||||
expect(mockStateManager.getExecutionHistory).toHaveBeenCalledWith(sessionId, limit);
|
||||
expect(result).toBe(expectedHistory);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,107 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { AgentStateManager } from '../AgentStateManager';
|
||||
|
||||
// Mock Redis client
|
||||
vi.mock('@/libs/redis', () => ({
|
||||
getRedisClient: () => ({
|
||||
del: vi.fn(),
|
||||
expire: vi.fn(),
|
||||
get: vi.fn(),
|
||||
hgetall: vi.fn(),
|
||||
hmset: vi.fn(),
|
||||
keys: vi.fn(),
|
||||
multi: vi.fn(() => ({
|
||||
exec: vi.fn(),
|
||||
expire: vi.fn(),
|
||||
hmset: vi.fn(),
|
||||
lpush: vi.fn(),
|
||||
ltrim: vi.fn(),
|
||||
setex: vi.fn(),
|
||||
})),
|
||||
quit: vi.fn(),
|
||||
setex: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe('AgentStateManager', () => {
|
||||
let stateManager: AgentStateManager;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
stateManager = new AgentStateManager();
|
||||
});
|
||||
|
||||
describe('createSessionMetadata', () => {
|
||||
it('should create session metadata successfully', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const data = {
|
||||
agentConfig: { test: true },
|
||||
modelRuntimeConfig: { model: 'gpt-4' },
|
||||
userId: 'user-123',
|
||||
};
|
||||
|
||||
await expect(stateManager.createSessionMetadata(sessionId, data)).resolves.not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveAgentState', () => {
|
||||
it('should save agent state successfully', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const state = {
|
||||
cost: { total: 100 },
|
||||
status: 'done' as const,
|
||||
stepCount: 5,
|
||||
};
|
||||
|
||||
await expect(stateManager.saveAgentState(sessionId, state as any)).resolves.not.toThrow();
|
||||
});
|
||||
|
||||
it('should save agent state with running status', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const state = {
|
||||
cost: { total: 50 },
|
||||
status: 'running' as const,
|
||||
stepCount: 3,
|
||||
};
|
||||
|
||||
await expect(stateManager.saveAgentState(sessionId, state as any)).resolves.not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('saveStepResult', () => {
|
||||
it('should save step result successfully when status is done', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const stepResult = {
|
||||
executionTime: 1000,
|
||||
newState: {
|
||||
cost: { total: 200 },
|
||||
status: 'done' as const,
|
||||
stepCount: 10,
|
||||
},
|
||||
stepIndex: 10,
|
||||
};
|
||||
|
||||
await expect(
|
||||
stateManager.saveStepResult(sessionId, stepResult as any),
|
||||
).resolves.not.toThrow();
|
||||
});
|
||||
|
||||
it('should save step result successfully when status is not done', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const stepResult = {
|
||||
executionTime: 500,
|
||||
newState: {
|
||||
cost: { total: 75 },
|
||||
status: 'running' as const,
|
||||
stepCount: 3,
|
||||
},
|
||||
stepIndex: 3,
|
||||
};
|
||||
|
||||
await expect(
|
||||
stateManager.saveStepResult(sessionId, stepResult as any),
|
||||
).resolves.not.toThrow();
|
||||
});
|
||||
});
|
||||
});
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,148 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { StreamEventManager } from '../StreamEventManager';
|
||||
|
||||
// Mock Redis client
|
||||
const mockRedis = {
|
||||
del: vi.fn(),
|
||||
expire: vi.fn(),
|
||||
keys: vi.fn(),
|
||||
quit: vi.fn(),
|
||||
xadd: vi.fn(),
|
||||
xread: vi.fn(),
|
||||
xrevrange: vi.fn(),
|
||||
};
|
||||
|
||||
vi.mock('@/libs/redis', () => ({
|
||||
getRedisClient: () => mockRedis,
|
||||
}));
|
||||
|
||||
describe('StreamEventManager', () => {
|
||||
let streamManager: StreamEventManager;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
streamManager = new StreamEventManager();
|
||||
});
|
||||
|
||||
describe('publishAgentRuntimeInit', () => {
|
||||
it('should publish agent runtime init event with correct data', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const metadata = {
|
||||
agentConfig: { test: true },
|
||||
createdAt: '2024-01-01T00:00:00.000Z',
|
||||
modelRuntimeConfig: { model: 'gpt-4' },
|
||||
status: 'idle',
|
||||
totalCost: 0,
|
||||
totalSteps: 0,
|
||||
userId: 'user-123',
|
||||
};
|
||||
|
||||
mockRedis.xadd.mockResolvedValue('event-id-123');
|
||||
|
||||
const result = await streamManager.publishAgentRuntimeInit(sessionId, metadata);
|
||||
|
||||
expect(result).toBe('event-id-123');
|
||||
expect(mockRedis.xadd).toHaveBeenCalledWith(
|
||||
`agent_runtime_stream:${sessionId}`,
|
||||
'MAXLEN',
|
||||
'~',
|
||||
'1000',
|
||||
'*',
|
||||
'type',
|
||||
'agent_runtime_init',
|
||||
'stepIndex',
|
||||
'0',
|
||||
'sessionId',
|
||||
sessionId,
|
||||
'data',
|
||||
JSON.stringify(metadata),
|
||||
'timestamp',
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('publishAgentRuntimeEnd', () => {
|
||||
it('should publish agent runtime end event with correct data', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const stepIndex = 5;
|
||||
const finalState = {
|
||||
cost: { total: 100 },
|
||||
status: 'done',
|
||||
stepCount: 5,
|
||||
};
|
||||
|
||||
mockRedis.xadd.mockResolvedValue('event-id-456');
|
||||
|
||||
const result = await streamManager.publishAgentRuntimeEnd(sessionId, stepIndex, finalState);
|
||||
|
||||
expect(result).toBe('event-id-456');
|
||||
expect(mockRedis.xadd).toHaveBeenCalledWith(
|
||||
`agent_runtime_stream:${sessionId}`,
|
||||
'MAXLEN',
|
||||
'~',
|
||||
'1000',
|
||||
'*',
|
||||
'type',
|
||||
'agent_runtime_end',
|
||||
'stepIndex',
|
||||
'5',
|
||||
'sessionId',
|
||||
sessionId,
|
||||
'data',
|
||||
JSON.stringify({
|
||||
finalState,
|
||||
phase: 'execution_complete',
|
||||
reason: 'completed',
|
||||
reasonDetail: 'Agent runtime completed successfully',
|
||||
sessionId,
|
||||
}),
|
||||
'timestamp',
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
it('should accept custom reason and reasonDetail', async () => {
|
||||
const sessionId = 'test-session-id';
|
||||
const stepIndex = 3;
|
||||
const finalState = { status: 'error' };
|
||||
const reason = 'error';
|
||||
const reasonDetail = 'Agent failed due to timeout';
|
||||
|
||||
mockRedis.xadd.mockResolvedValue('event-id-789');
|
||||
|
||||
await streamManager.publishAgentRuntimeEnd(
|
||||
sessionId,
|
||||
stepIndex,
|
||||
finalState,
|
||||
reason,
|
||||
reasonDetail,
|
||||
);
|
||||
|
||||
expect(mockRedis.xadd).toHaveBeenCalledWith(
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
sessionId,
|
||||
'data',
|
||||
JSON.stringify({
|
||||
finalState,
|
||||
phase: 'execution_complete',
|
||||
reason,
|
||||
reasonDetail,
|
||||
sessionId,
|
||||
}),
|
||||
expect.any(String),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,5 @@
|
||||
export { AgentRuntimeCoordinator } from './AgentRuntimeCoordinator';
|
||||
export { AgentStateManager } from './AgentStateManager';
|
||||
export { GeneralAgent } from './GeneralAgent';
|
||||
export { createRuntimeExecutors } from './RuntimeExecutors';
|
||||
export { StreamEventManager } from './StreamEventManager';
|
||||
@@ -0,0 +1,299 @@
|
||||
import { AgentRuntimeContext } from '@lobechat/agent-runtime';
|
||||
import { TRPCError } from '@trpc/server';
|
||||
import debug from 'debug';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { isEnableAgent } from '@/app/(backend)/api/agent/isEnableAgent';
|
||||
import { authedProcedure, router } from '@/libs/trpc/lambda';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda/middleware';
|
||||
import { AgentRuntimeService } from '@/server/services/agentRuntime';
|
||||
|
||||
const log = debug('lobe-server:ai-agent-router');
|
||||
|
||||
// Zod schemas for agent session operations
|
||||
const CreateAgentSessionSchema = z.object({
|
||||
agentConfig: z.record(z.any()).optional().default({}),
|
||||
agentSessionId: z.string().optional(),
|
||||
autoStart: z.boolean().optional().default(true),
|
||||
messages: z.array(z.any()).optional().default([]),
|
||||
modelRuntimeConfig: z.object({
|
||||
model: z.string(),
|
||||
provider: z.string(),
|
||||
}),
|
||||
threadId: z.string().optional().nullable(),
|
||||
toolManifestMap: z.record(z.string(), z.any()).default({}),
|
||||
tools: z.array(z.any()).optional(),
|
||||
topicId: z.string().optional().nullable(),
|
||||
userId: z.string().optional(),
|
||||
});
|
||||
|
||||
const GetSessionStatusSchema = z.object({
|
||||
historyLimit: z.number().optional().default(10),
|
||||
includeHistory: z.boolean().optional().default(false),
|
||||
sessionId: z.string(),
|
||||
});
|
||||
|
||||
const ProcessHumanInterventionSchema = z.object({
|
||||
action: z.enum(['approve', 'reject', 'input', 'select']),
|
||||
data: z
|
||||
.object({
|
||||
approvedToolCall: z.any().optional(),
|
||||
input: z.any().optional(),
|
||||
selection: z.any().optional(),
|
||||
})
|
||||
.optional(),
|
||||
reason: z.string().optional(),
|
||||
sessionId: z.string(),
|
||||
stepIndex: z.number().optional().default(0),
|
||||
});
|
||||
|
||||
const GetPendingInterventionsSchema = z
|
||||
.object({
|
||||
sessionId: z.string().optional(),
|
||||
userId: z.string().optional(),
|
||||
})
|
||||
.refine((data) => data.sessionId || data.userId, {
|
||||
message: 'Either sessionId or userId must be provided',
|
||||
});
|
||||
|
||||
const StartExecutionSchema = z.object({
|
||||
context: z.any().optional(),
|
||||
delay: z.number().optional().default(1000),
|
||||
priority: z.enum(['high', 'normal', 'low']).optional().default('normal'),
|
||||
sessionId: z.string(),
|
||||
});
|
||||
|
||||
const aiAgentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
agentRuntimeService: new AgentRuntimeService(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
export const aiAgentRouter = router({
|
||||
createSession: aiAgentProcedure
|
||||
.input(CreateAgentSessionSchema)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
if (!isEnableAgent()) {
|
||||
throw new TRPCError({ code: 'NOT_IMPLEMENTED', message: 'Agent features are not enabled' });
|
||||
}
|
||||
|
||||
const {
|
||||
agentConfig = {},
|
||||
agentSessionId,
|
||||
autoStart = true,
|
||||
messages = [],
|
||||
modelRuntimeConfig,
|
||||
threadId,
|
||||
topicId,
|
||||
tools,
|
||||
toolManifestMap,
|
||||
} = input;
|
||||
log('input: %O', input);
|
||||
|
||||
// Validate required parameters
|
||||
if (!modelRuntimeConfig.model || !modelRuntimeConfig.provider) {
|
||||
throw new TRPCError({
|
||||
code: 'BAD_REQUEST',
|
||||
message: 'modelRuntimeConfig.model and modelRuntimeConfig.provider are required',
|
||||
});
|
||||
}
|
||||
|
||||
// Generate runtime session ID
|
||||
const runtimeSessionId = `agent_${Date.now()}_${Math.random().toString(36).slice(2, 11)}`;
|
||||
|
||||
log(`Creating session ${runtimeSessionId} for user ${ctx.userId}`);
|
||||
|
||||
// Create initial context
|
||||
const initialContext: AgentRuntimeContext = {
|
||||
payload: {},
|
||||
phase: 'user_input' as const,
|
||||
session: {
|
||||
messageCount: messages.length,
|
||||
sessionId: runtimeSessionId,
|
||||
status: 'idle' as const,
|
||||
stepCount: 0,
|
||||
},
|
||||
};
|
||||
|
||||
// Create session using AgentRuntimeService
|
||||
const result = await ctx.agentRuntimeService.createSession({
|
||||
agentConfig,
|
||||
appContext: {
|
||||
sessionId: agentSessionId,
|
||||
threadId,
|
||||
topicId,
|
||||
},
|
||||
autoStart,
|
||||
initialContext,
|
||||
initialMessages: messages,
|
||||
modelRuntimeConfig,
|
||||
sessionId: runtimeSessionId,
|
||||
toolManifestMap,
|
||||
tools,
|
||||
userId: ctx.userId,
|
||||
});
|
||||
|
||||
let firstStepResult;
|
||||
if (result.autoStarted) {
|
||||
firstStepResult = {
|
||||
context: initialContext,
|
||||
messageId: result.messageId,
|
||||
scheduled: true,
|
||||
};
|
||||
|
||||
log(
|
||||
`Session ${runtimeSessionId} created and first step scheduled (messageId: ${result.messageId})`,
|
||||
);
|
||||
} else {
|
||||
log(`Session ${runtimeSessionId} created without auto-start`);
|
||||
}
|
||||
|
||||
return {
|
||||
autoStart,
|
||||
createdAt: new Date().toISOString(),
|
||||
firstStep: firstStepResult,
|
||||
sessionId: runtimeSessionId,
|
||||
status: 'created',
|
||||
success: true,
|
||||
};
|
||||
}),
|
||||
|
||||
getPendingInterventions: aiAgentProcedure
|
||||
.input(GetPendingInterventionsSchema)
|
||||
.query(async ({ input, ctx }) => {
|
||||
if (!isEnableAgent()) {
|
||||
throw new TRPCError({ code: 'NOT_IMPLEMENTED', message: 'Agent features are not enabled' });
|
||||
}
|
||||
|
||||
const { sessionId, userId } = input;
|
||||
|
||||
log('Getting pending interventions for sessionId: %s, userId: %s', sessionId, userId);
|
||||
|
||||
// Get pending interventions using AgentRuntimeService
|
||||
const result = await ctx.agentRuntimeService.getPendingInterventions({
|
||||
sessionId: sessionId || undefined,
|
||||
userId: userId || undefined,
|
||||
});
|
||||
|
||||
return result;
|
||||
}),
|
||||
|
||||
getSessionStatus: aiAgentProcedure.input(GetSessionStatusSchema).query(async ({ input, ctx }) => {
|
||||
if (!isEnableAgent()) {
|
||||
throw new Error('Agent features are not enabled');
|
||||
}
|
||||
|
||||
const { historyLimit, includeHistory, sessionId } = input;
|
||||
|
||||
if (!sessionId) {
|
||||
throw new Error('sessionId parameter is required');
|
||||
}
|
||||
|
||||
log('Getting session status for %s', sessionId);
|
||||
|
||||
// Get session status using AgentRuntimeService
|
||||
const sessionStatus = await ctx.agentRuntimeService.getSessionStatus({
|
||||
historyLimit,
|
||||
includeHistory,
|
||||
sessionId,
|
||||
});
|
||||
|
||||
return sessionStatus;
|
||||
}),
|
||||
|
||||
processHumanIntervention: aiAgentProcedure
|
||||
.input(ProcessHumanInterventionSchema)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
if (!isEnableAgent()) {
|
||||
throw new TRPCError({ code: 'NOT_IMPLEMENTED', message: 'Agent features are not enabled' });
|
||||
}
|
||||
|
||||
const { sessionId, action, data, reason, stepIndex } = input;
|
||||
|
||||
log(`Processing ${action} for session ${sessionId}`);
|
||||
|
||||
// Build intervention parameters
|
||||
let interventionParams: any = {
|
||||
action,
|
||||
sessionId,
|
||||
stepIndex,
|
||||
};
|
||||
|
||||
switch (action) {
|
||||
case 'approve': {
|
||||
if (!data?.approvedToolCall) {
|
||||
throw new TRPCError({
|
||||
code: 'BAD_REQUEST',
|
||||
message: 'approvedToolCall is required for approve action',
|
||||
});
|
||||
}
|
||||
interventionParams.approvedToolCall = data.approvedToolCall;
|
||||
break;
|
||||
}
|
||||
case 'reject': {
|
||||
interventionParams.rejectionReason = reason || 'Tool call rejected by user';
|
||||
break;
|
||||
}
|
||||
case 'input': {
|
||||
if (!data?.input) {
|
||||
throw new TRPCError({
|
||||
code: 'BAD_REQUEST',
|
||||
message: 'input is required for input action',
|
||||
});
|
||||
}
|
||||
interventionParams.humanInput = { response: data.input };
|
||||
break;
|
||||
}
|
||||
case 'select': {
|
||||
if (!data?.selection) {
|
||||
throw new TRPCError({
|
||||
code: 'BAD_REQUEST',
|
||||
message: 'selection is required for select action',
|
||||
});
|
||||
}
|
||||
interventionParams.humanInput = { selection: data.selection };
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Process human intervention using AgentRuntimeService
|
||||
const result = await ctx.agentRuntimeService.processHumanIntervention(interventionParams);
|
||||
|
||||
return {
|
||||
action,
|
||||
message: `Human intervention processed successfully. Execution resumed.`,
|
||||
scheduledMessageId: result.messageId,
|
||||
sessionId,
|
||||
success: true,
|
||||
timestamp: new Date().toISOString(),
|
||||
};
|
||||
}),
|
||||
|
||||
startExecution: aiAgentProcedure.input(StartExecutionSchema).mutation(async ({ input, ctx }) => {
|
||||
if (!isEnableAgent()) {
|
||||
throw new TRPCError({ code: 'NOT_IMPLEMENTED', message: 'Agent features are not enabled' });
|
||||
}
|
||||
|
||||
const { sessionId, context, priority, delay } = input;
|
||||
|
||||
log('Starting execution for session %s', sessionId);
|
||||
|
||||
// Start execution using AgentRuntimeService
|
||||
const result = await ctx.agentRuntimeService.startExecution({
|
||||
context,
|
||||
delay,
|
||||
priority,
|
||||
sessionId,
|
||||
});
|
||||
|
||||
return {
|
||||
...result,
|
||||
message: 'Agent execution started successfully',
|
||||
timestamp: new Date().toISOString(),
|
||||
};
|
||||
}),
|
||||
});
|
||||
@@ -4,6 +4,7 @@
|
||||
import { publicProcedure, router } from '@/libs/trpc/lambda';
|
||||
|
||||
import { agentRouter } from './agent';
|
||||
import { aiAgentRouter } from './aiAgent';
|
||||
import { aiChatRouter } from './aiChat';
|
||||
import { aiModelRouter } from './aiModel';
|
||||
import { aiProviderRouter } from './aiProvider';
|
||||
@@ -35,6 +36,7 @@ import { userRouter } from './user';
|
||||
|
||||
export const lambdaRouter = router({
|
||||
agent: agentRouter,
|
||||
aiAgent: aiAgentRouter,
|
||||
aiChat: aiChatRouter,
|
||||
aiModel: aiModelRouter,
|
||||
aiProvider: aiProviderRouter,
|
||||
|
||||
@@ -0,0 +1,813 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { AgentRuntimeService } from './AgentRuntimeService';
|
||||
import type { AgentExecutionParams, SessionCreationParams, StartExecutionParams } from './types';
|
||||
|
||||
// Mock database and models
|
||||
vi.mock('@/database/models/message', () => ({
|
||||
MessageModel: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@/server/modules/AgentRuntime', () => ({
|
||||
AgentRuntimeCoordinator: vi.fn().mockImplementation(() => ({
|
||||
createAgentSession: vi.fn(),
|
||||
saveAgentState: vi.fn(),
|
||||
loadAgentState: vi.fn(),
|
||||
getSessionMetadata: vi.fn(),
|
||||
saveStepResult: vi.fn(),
|
||||
getExecutionHistory: vi.fn(),
|
||||
getActiveSessions: vi.fn(),
|
||||
deleteAgentSession: vi.fn(),
|
||||
disconnect: vi.fn(),
|
||||
})),
|
||||
StreamEventManager: vi.fn().mockImplementation(() => ({
|
||||
publishStreamEvent: vi.fn(),
|
||||
getStreamHistory: vi.fn(),
|
||||
})),
|
||||
DurableLobeChatAgent: vi.fn(),
|
||||
createStreamingLLMExecutor: vi.fn(),
|
||||
createStreamingToolExecutor: vi.fn(),
|
||||
createStreamingFinishExecutor: vi.fn(),
|
||||
createStreamingHumanApprovalExecutor: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@lobechat/agent-runtime', () => ({
|
||||
AgentRuntime: vi.fn().mockImplementation((agent, options) => ({
|
||||
step: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('@/server/services/queue', () => ({
|
||||
QueueService: vi.fn().mockImplementation(() => ({
|
||||
scheduleMessage: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('AgentRuntimeService', () => {
|
||||
let service: AgentRuntimeService;
|
||||
let mockCoordinator: any;
|
||||
let mockStreamManager: any;
|
||||
let mockQueueService: any;
|
||||
let mockDb: any;
|
||||
const mockUserId = 'test-user-id';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
process.env.AGENT_RUNTIME_BASE_URL = 'http://localhost:3010';
|
||||
|
||||
// Mock database
|
||||
mockDb = {
|
||||
select: vi.fn(),
|
||||
insert: vi.fn(),
|
||||
update: vi.fn(),
|
||||
delete: vi.fn(),
|
||||
};
|
||||
|
||||
service = new AgentRuntimeService(mockDb, mockUserId);
|
||||
|
||||
// Get mocked instances
|
||||
mockCoordinator = (service as any).coordinator;
|
||||
mockStreamManager = (service as any).streamManager;
|
||||
mockQueueService = (service as any).queueService;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
delete process.env.AGENT_RUNTIME_BASE_URL;
|
||||
});
|
||||
|
||||
describe('constructor', () => {
|
||||
it('should initialize with default base URL', () => {
|
||||
delete process.env.AGENT_RUNTIME_BASE_URL;
|
||||
const newService = new AgentRuntimeService(mockDb, mockUserId);
|
||||
expect((newService as any).baseURL).toBe('http://localhost:3010/api/agent');
|
||||
});
|
||||
|
||||
it('should initialize with custom base URL from environment', () => {
|
||||
process.env.AGENT_RUNTIME_BASE_URL = 'http://custom:3000';
|
||||
const newService = new AgentRuntimeService(mockDb, mockUserId);
|
||||
expect((newService as any).baseURL).toBe('http://custom:3000/api/agent');
|
||||
});
|
||||
});
|
||||
|
||||
describe('createSession', () => {
|
||||
const mockParams: SessionCreationParams = {
|
||||
sessionId: 'test-session-1',
|
||||
initialContext: {
|
||||
phase: 'user_input',
|
||||
payload: {
|
||||
message: { content: 'test' },
|
||||
sessionId: 'test-session-1',
|
||||
isFirstMessage: true,
|
||||
},
|
||||
session: { sessionId: 'test-session-1', status: 'idle', stepCount: 0, messageCount: 0 },
|
||||
},
|
||||
appContext: {},
|
||||
agentConfig: { name: 'test-agent' },
|
||||
modelRuntimeConfig: { model: 'gpt-4' },
|
||||
toolManifestMap: {},
|
||||
userId: 'user-123',
|
||||
autoStart: true,
|
||||
initialMessages: [],
|
||||
};
|
||||
|
||||
it('should create session successfully with autoStart=true', async () => {
|
||||
mockQueueService.scheduleMessage.mockResolvedValueOnce('message-123');
|
||||
|
||||
const result = await service.createSession(mockParams);
|
||||
|
||||
expect(result).toEqual({
|
||||
success: true,
|
||||
sessionId: 'test-session-1',
|
||||
autoStarted: true,
|
||||
messageId: 'message-123',
|
||||
});
|
||||
|
||||
expect(mockCoordinator.saveAgentState).toHaveBeenCalledWith(
|
||||
'test-session-1',
|
||||
expect.objectContaining({
|
||||
sessionId: 'test-session-1',
|
||||
status: 'idle',
|
||||
stepCount: 0,
|
||||
messages: [],
|
||||
events: [],
|
||||
}),
|
||||
);
|
||||
|
||||
expect(mockCoordinator.createAgentSession).toHaveBeenCalledWith('test-session-1', {
|
||||
agentConfig: mockParams.agentConfig,
|
||||
modelRuntimeConfig: mockParams.modelRuntimeConfig,
|
||||
userId: mockParams.userId,
|
||||
});
|
||||
|
||||
expect(mockQueueService.scheduleMessage).toHaveBeenCalledWith({
|
||||
sessionId: 'test-session-1',
|
||||
stepIndex: 0,
|
||||
context: mockParams.initialContext,
|
||||
endpoint: 'http://localhost:3010/api/agent/run',
|
||||
priority: 'high',
|
||||
delay: 50,
|
||||
});
|
||||
});
|
||||
|
||||
it('should create session successfully with autoStart=false', async () => {
|
||||
const params = { ...mockParams, autoStart: false };
|
||||
|
||||
const result = await service.createSession(params);
|
||||
|
||||
expect(result).toEqual({
|
||||
success: true,
|
||||
sessionId: 'test-session-1',
|
||||
autoStarted: false,
|
||||
messageId: undefined,
|
||||
});
|
||||
|
||||
expect(mockQueueService.scheduleMessage).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle errors during session creation', async () => {
|
||||
mockCoordinator.saveAgentState.mockRejectedValueOnce(new Error('Database error'));
|
||||
|
||||
await expect(service.createSession(mockParams)).rejects.toThrow('Database error');
|
||||
});
|
||||
});
|
||||
|
||||
describe('executeStep', () => {
|
||||
const mockParams: AgentExecutionParams = {
|
||||
sessionId: 'test-session-1',
|
||||
stepIndex: 1,
|
||||
context: {
|
||||
phase: 'user_input',
|
||||
payload: {
|
||||
message: { content: 'test' },
|
||||
sessionId: 'test-session-1',
|
||||
isFirstMessage: false,
|
||||
},
|
||||
session: { sessionId: 'test-session-1', status: 'running', stepCount: 1, messageCount: 1 },
|
||||
},
|
||||
};
|
||||
|
||||
const mockState = {
|
||||
sessionId: 'test-session-1',
|
||||
status: 'running',
|
||||
stepCount: 1,
|
||||
messages: [],
|
||||
events: [],
|
||||
lastModified: new Date().toISOString(),
|
||||
};
|
||||
|
||||
const mockMetadata = {
|
||||
userId: 'user-123',
|
||||
agentConfig: { name: 'test-agent' },
|
||||
modelRuntimeConfig: { model: 'gpt-4' },
|
||||
createdAt: new Date().toISOString(),
|
||||
lastActiveAt: new Date().toISOString(),
|
||||
status: 'running',
|
||||
totalCost: 0,
|
||||
totalSteps: 1,
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockCoordinator.loadAgentState.mockResolvedValue(mockState);
|
||||
mockCoordinator.getSessionMetadata.mockResolvedValue(mockMetadata);
|
||||
});
|
||||
|
||||
it('should execute step successfully', async () => {
|
||||
const mockStepResult = {
|
||||
newState: { ...mockState, stepCount: 2, status: 'running' },
|
||||
nextContext: mockParams.context,
|
||||
events: [],
|
||||
};
|
||||
|
||||
// Mock runtime.step
|
||||
const mockRuntime = { step: vi.fn().mockResolvedValue(mockStepResult) };
|
||||
vi.spyOn(service as any, 'createAgentRuntime').mockReturnValue({ runtime: mockRuntime });
|
||||
|
||||
const result = await service.executeStep(mockParams);
|
||||
|
||||
expect(result).toEqual({
|
||||
success: true,
|
||||
state: mockStepResult.newState,
|
||||
stepResult: expect.objectContaining(mockStepResult),
|
||||
nextStepScheduled: true,
|
||||
});
|
||||
|
||||
expect(mockStreamManager.publishStreamEvent).toHaveBeenCalledWith('test-session-1', {
|
||||
type: 'step_start',
|
||||
stepIndex: 1,
|
||||
data: {},
|
||||
});
|
||||
|
||||
expect(mockStreamManager.publishStreamEvent).toHaveBeenCalledWith('test-session-1', {
|
||||
type: 'step_complete',
|
||||
stepIndex: 1,
|
||||
data: {
|
||||
stepIndex: 1,
|
||||
finalState: mockStepResult.newState,
|
||||
nextStepScheduled: true,
|
||||
},
|
||||
});
|
||||
|
||||
expect(mockCoordinator.saveStepResult).toHaveBeenCalled();
|
||||
expect(mockQueueService.scheduleMessage).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should handle missing agent state', async () => {
|
||||
mockCoordinator.loadAgentState.mockResolvedValue(null);
|
||||
|
||||
await expect(service.executeStep(mockParams)).rejects.toThrow(
|
||||
'Agent state not found for session test-session-1',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle execution errors', async () => {
|
||||
const error = new Error('Runtime error');
|
||||
const mockRuntime = { step: vi.fn().mockRejectedValue(error) };
|
||||
vi.spyOn(service as any, 'createAgentRuntime').mockReturnValue({ runtime: mockRuntime });
|
||||
|
||||
await expect(service.executeStep(mockParams)).rejects.toThrow('Runtime error');
|
||||
|
||||
expect(mockStreamManager.publishStreamEvent).toHaveBeenCalledWith('test-session-1', {
|
||||
type: 'error',
|
||||
stepIndex: 1,
|
||||
data: {
|
||||
stepIndex: 1,
|
||||
phase: 'step_execution',
|
||||
error: 'Runtime error',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle human intervention', async () => {
|
||||
const paramsWithIntervention = {
|
||||
...mockParams,
|
||||
humanInput: { type: 'text', content: 'user input' },
|
||||
approvedToolCall: { toolName: 'calculator', args: {} },
|
||||
rejectionReason: 'Not safe',
|
||||
};
|
||||
|
||||
const mockStepResult = {
|
||||
newState: { ...mockState, stepCount: 2, status: 'done' },
|
||||
nextContext: null,
|
||||
events: [],
|
||||
};
|
||||
|
||||
const mockRuntime = { step: vi.fn().mockResolvedValue(mockStepResult) };
|
||||
vi.spyOn(service as any, 'createAgentRuntime').mockReturnValue({ runtime: mockRuntime });
|
||||
vi.spyOn(service as any, 'handleHumanIntervention').mockResolvedValue({
|
||||
newState: mockState,
|
||||
nextContext: mockParams.context,
|
||||
});
|
||||
|
||||
const result = await service.executeStep(paramsWithIntervention);
|
||||
|
||||
expect((service as any).handleHumanIntervention).toHaveBeenCalledWith(
|
||||
mockRuntime,
|
||||
mockState,
|
||||
{
|
||||
humanInput: paramsWithIntervention.humanInput,
|
||||
approvedToolCall: paramsWithIntervention.approvedToolCall,
|
||||
rejectionReason: paramsWithIntervention.rejectionReason,
|
||||
},
|
||||
);
|
||||
|
||||
expect(result.success).toBe(true);
|
||||
expect(result.nextStepScheduled).toBe(false); // Should not schedule next step when status is 'done'
|
||||
});
|
||||
});
|
||||
|
||||
describe('getSessionStatus', () => {
|
||||
const mockState = {
|
||||
sessionId: 'test-session-1',
|
||||
status: 'running',
|
||||
stepCount: 5,
|
||||
messages: [{ content: 'msg1' }, { content: 'msg2' }],
|
||||
cost: { total: 0.1 },
|
||||
usage: { tokens: 100 },
|
||||
lastModified: new Date().toISOString(),
|
||||
};
|
||||
|
||||
const mockMetadata = {
|
||||
userId: 'user-123',
|
||||
createdAt: new Date(Date.now() - 3600000).toISOString(), // 1 hour ago
|
||||
lastActiveAt: new Date(Date.now() - 1800000).toISOString(), // 30 minutes ago
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockCoordinator.loadAgentState.mockResolvedValue(mockState);
|
||||
mockCoordinator.getSessionMetadata.mockResolvedValue(mockMetadata);
|
||||
});
|
||||
|
||||
it('should get session status successfully', async () => {
|
||||
const result = await service.getSessionStatus({
|
||||
sessionId: 'test-session-1',
|
||||
includeHistory: false,
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
sessionId: 'test-session-1',
|
||||
currentState: expect.objectContaining({
|
||||
status: 'running',
|
||||
stepCount: 5,
|
||||
cost: { total: 0.1 },
|
||||
usage: { tokens: 100 },
|
||||
}),
|
||||
metadata: mockMetadata,
|
||||
isActive: true,
|
||||
isCompleted: false,
|
||||
hasError: false,
|
||||
needsHumanInput: false,
|
||||
stats: {
|
||||
totalSteps: 5,
|
||||
totalMessages: 2,
|
||||
totalCost: 0.1,
|
||||
uptime: expect.any(Number),
|
||||
lastActiveTime: expect.any(Number),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should include history when requested', async () => {
|
||||
const mockHistory = [{ stepIndex: 1, timestamp: Date.now() }];
|
||||
const mockEvents = [{ type: 'step_start', timestamp: Date.now() }];
|
||||
|
||||
mockCoordinator.getExecutionHistory.mockResolvedValue(mockHistory);
|
||||
mockStreamManager.getStreamHistory.mockResolvedValue(mockEvents);
|
||||
|
||||
const result = await service.getSessionStatus({
|
||||
sessionId: 'test-session-1',
|
||||
includeHistory: true,
|
||||
historyLimit: 20,
|
||||
});
|
||||
|
||||
expect(result.executionHistory).toEqual(mockHistory);
|
||||
expect(result.recentEvents).toEqual(mockEvents.slice(0, 10));
|
||||
});
|
||||
|
||||
it('should handle missing session', async () => {
|
||||
mockCoordinator.loadAgentState.mockResolvedValue(null);
|
||||
mockCoordinator.getSessionMetadata.mockResolvedValue(null);
|
||||
|
||||
await expect(
|
||||
service.getSessionStatus({
|
||||
sessionId: 'nonexistent-session',
|
||||
}),
|
||||
).rejects.toThrow('Session not found');
|
||||
});
|
||||
|
||||
it('should handle different session statuses', async () => {
|
||||
// Test waiting_for_human status
|
||||
const waitingState = { ...mockState, status: 'waiting_for_human' };
|
||||
mockCoordinator.loadAgentState.mockResolvedValue(waitingState);
|
||||
|
||||
const result = await service.getSessionStatus({
|
||||
sessionId: 'test-session-1',
|
||||
});
|
||||
|
||||
expect(result.isActive).toBe(true);
|
||||
expect(result.needsHumanInput).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getPendingInterventions', () => {
|
||||
it('should get pending interventions for specific session', async () => {
|
||||
const mockState = {
|
||||
status: 'waiting_for_human',
|
||||
pendingToolsCalling: [{ toolName: 'calculator', args: {} }],
|
||||
stepCount: 3,
|
||||
lastModified: new Date().toISOString(),
|
||||
};
|
||||
|
||||
const mockMetadata = {
|
||||
userId: 'user-123',
|
||||
modelRuntimeConfig: { model: 'gpt-4' },
|
||||
};
|
||||
|
||||
mockCoordinator.loadAgentState.mockResolvedValue(mockState);
|
||||
mockCoordinator.getSessionMetadata.mockResolvedValue(mockMetadata);
|
||||
|
||||
const result = await service.getPendingInterventions({
|
||||
sessionId: 'test-session-1',
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
totalCount: 1,
|
||||
timestamp: expect.any(String),
|
||||
pendingInterventions: [
|
||||
{
|
||||
sessionId: 'test-session-1',
|
||||
type: 'tool_approval',
|
||||
status: 'waiting_for_human',
|
||||
stepCount: 3,
|
||||
lastModified: mockState.lastModified,
|
||||
userId: 'user-123',
|
||||
modelRuntimeConfig: { model: 'gpt-4' },
|
||||
pendingToolsCalling: mockState.pendingToolsCalling,
|
||||
},
|
||||
],
|
||||
});
|
||||
});
|
||||
|
||||
it('should get pending interventions for user', async () => {
|
||||
const mockSessions = ['session-1', 'session-2'];
|
||||
mockCoordinator.getActiveSessions.mockResolvedValue(mockSessions);
|
||||
|
||||
// Mock metadata for filtering by userId
|
||||
mockCoordinator.getSessionMetadata
|
||||
.mockResolvedValueOnce({ userId: 'user-123' })
|
||||
.mockResolvedValueOnce({ userId: 'other-user' });
|
||||
|
||||
// Mock states - only first session needs intervention
|
||||
mockCoordinator.loadAgentState
|
||||
.mockResolvedValueOnce({
|
||||
status: 'waiting_for_human',
|
||||
pendingHumanPrompt: 'Please confirm',
|
||||
stepCount: 2,
|
||||
lastModified: new Date().toISOString(),
|
||||
})
|
||||
.mockResolvedValueOnce({
|
||||
status: 'running',
|
||||
stepCount: 1,
|
||||
lastModified: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const result = await service.getPendingInterventions({
|
||||
userId: 'user-123',
|
||||
});
|
||||
|
||||
expect(result.totalCount).toBe(1);
|
||||
expect(result.pendingInterventions[0]).toEqual({
|
||||
sessionId: 'session-1',
|
||||
type: 'human_prompt',
|
||||
status: 'waiting_for_human',
|
||||
pendingHumanPrompt: 'Please confirm',
|
||||
stepCount: 2,
|
||||
lastModified: expect.any(String),
|
||||
userId: undefined, // getSessionMetadata is not called due to the way sessions are filtered
|
||||
modelRuntimeConfig: undefined,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return empty list when no interventions needed', async () => {
|
||||
mockCoordinator.loadAgentState.mockResolvedValue({
|
||||
status: 'running',
|
||||
stepCount: 1,
|
||||
});
|
||||
mockCoordinator.getSessionMetadata.mockResolvedValue({ userId: 'user-123' });
|
||||
|
||||
const result = await service.getPendingInterventions({
|
||||
sessionId: 'test-session-1',
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
totalCount: 0,
|
||||
timestamp: expect.any(String),
|
||||
pendingInterventions: [],
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('startExecution', () => {
|
||||
const mockParams: StartExecutionParams = {
|
||||
sessionId: 'test-session-1',
|
||||
context: {
|
||||
phase: 'user_input',
|
||||
payload: {
|
||||
message: { content: 'test' },
|
||||
sessionId: 'test-session-1',
|
||||
isFirstMessage: false,
|
||||
},
|
||||
session: { sessionId: 'test-session-1', status: 'idle', stepCount: 0, messageCount: 0 },
|
||||
},
|
||||
priority: 'high',
|
||||
delay: 500,
|
||||
};
|
||||
|
||||
const mockState = {
|
||||
sessionId: 'test-session-1',
|
||||
status: 'idle',
|
||||
stepCount: 2,
|
||||
messages: [{ content: 'msg1' }],
|
||||
lastModified: new Date().toISOString(),
|
||||
};
|
||||
|
||||
const mockMetadata = {
|
||||
userId: 'user-123',
|
||||
agentConfig: { name: 'test-agent' },
|
||||
modelRuntimeConfig: { model: 'gpt-4' },
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockCoordinator.getSessionMetadata.mockResolvedValue(mockMetadata);
|
||||
mockCoordinator.loadAgentState.mockResolvedValue(mockState);
|
||||
mockQueueService.scheduleMessage.mockResolvedValue('message-456');
|
||||
});
|
||||
|
||||
it('should start execution successfully', async () => {
|
||||
const result = await service.startExecution(mockParams);
|
||||
|
||||
expect(result).toEqual({
|
||||
success: true,
|
||||
scheduled: true,
|
||||
sessionId: 'test-session-1',
|
||||
messageId: 'message-456',
|
||||
});
|
||||
|
||||
expect(mockCoordinator.saveAgentState).toHaveBeenCalledWith(
|
||||
'test-session-1',
|
||||
expect.objectContaining({
|
||||
status: 'running',
|
||||
lastModified: expect.any(String),
|
||||
}),
|
||||
);
|
||||
|
||||
expect(mockQueueService.scheduleMessage).toHaveBeenCalledWith({
|
||||
sessionId: 'test-session-1',
|
||||
stepIndex: 2,
|
||||
context: mockParams.context,
|
||||
endpoint: 'http://localhost:3010/api/agent/run',
|
||||
priority: 'high',
|
||||
delay: 500,
|
||||
});
|
||||
});
|
||||
|
||||
it('should create default context when none provided', async () => {
|
||||
const paramsWithoutContext = { ...mockParams };
|
||||
delete paramsWithoutContext.context;
|
||||
|
||||
await service.startExecution(paramsWithoutContext);
|
||||
|
||||
expect(mockQueueService.scheduleMessage).toHaveBeenCalledWith({
|
||||
sessionId: 'test-session-1',
|
||||
stepIndex: 2,
|
||||
context: expect.objectContaining({
|
||||
phase: 'user_input',
|
||||
payload: expect.objectContaining({
|
||||
sessionId: 'test-session-1',
|
||||
isFirstMessage: true,
|
||||
message: expect.objectContaining({
|
||||
content: '',
|
||||
}),
|
||||
}),
|
||||
session: expect.objectContaining({
|
||||
sessionId: 'test-session-1',
|
||||
status: 'idle',
|
||||
stepCount: 2,
|
||||
messageCount: 1,
|
||||
}),
|
||||
}),
|
||||
endpoint: 'http://localhost:3010/api/agent/run',
|
||||
priority: 'high', // Uses the provided priority from params
|
||||
delay: 500, // Uses the provided delay from params
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle session not found', async () => {
|
||||
mockCoordinator.getSessionMetadata.mockResolvedValue(null);
|
||||
|
||||
await expect(service.startExecution(mockParams)).rejects.toThrow(
|
||||
'Session test-session-1 not found',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle already running session', async () => {
|
||||
mockCoordinator.loadAgentState.mockResolvedValue({
|
||||
...mockState,
|
||||
status: 'running',
|
||||
});
|
||||
|
||||
await expect(service.startExecution(mockParams)).rejects.toThrow(
|
||||
'Session test-session-1 is already running',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle completed session', async () => {
|
||||
mockCoordinator.loadAgentState.mockResolvedValue({
|
||||
...mockState,
|
||||
status: 'done',
|
||||
});
|
||||
|
||||
await expect(service.startExecution(mockParams)).rejects.toThrow(
|
||||
'Session test-session-1 is already completed',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle error state session', async () => {
|
||||
mockCoordinator.loadAgentState.mockResolvedValue({
|
||||
...mockState,
|
||||
status: 'error',
|
||||
});
|
||||
|
||||
await expect(service.startExecution(mockParams)).rejects.toThrow(
|
||||
'Session test-session-1 is in error state',
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('processHumanIntervention', () => {
|
||||
it('should process human intervention successfully', async () => {
|
||||
mockQueueService.scheduleMessage.mockResolvedValue('message-789');
|
||||
|
||||
const result = await service.processHumanIntervention({
|
||||
sessionId: 'test-session-1',
|
||||
stepIndex: 2,
|
||||
action: 'approve',
|
||||
approvedToolCall: { toolName: 'calculator', args: {} },
|
||||
});
|
||||
|
||||
expect(result).toEqual({
|
||||
messageId: 'message-789',
|
||||
});
|
||||
|
||||
expect(mockQueueService.scheduleMessage).toHaveBeenCalledWith({
|
||||
sessionId: 'test-session-1',
|
||||
stepIndex: 2,
|
||||
context: undefined,
|
||||
endpoint: 'http://localhost:3010/api/agent/run',
|
||||
priority: 'high',
|
||||
delay: 100,
|
||||
payload: {
|
||||
approvedToolCall: { toolName: 'calculator', args: {} },
|
||||
humanInput: undefined,
|
||||
rejectionReason: undefined,
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle different intervention actions', async () => {
|
||||
mockQueueService.scheduleMessage.mockResolvedValue('message-890');
|
||||
|
||||
await service.processHumanIntervention({
|
||||
sessionId: 'test-session-1',
|
||||
stepIndex: 3,
|
||||
action: 'input',
|
||||
humanInput: { type: 'text', content: 'user response' },
|
||||
});
|
||||
|
||||
expect(mockQueueService.scheduleMessage).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
payload: expect.objectContaining({
|
||||
humanInput: { type: 'text', content: 'user response' },
|
||||
}),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe('private methods', () => {
|
||||
describe('shouldContinueExecution', () => {
|
||||
it('should return false for completed status', () => {
|
||||
const shouldContinue = (service as any).shouldContinueExecution(
|
||||
{ status: 'done' },
|
||||
{ phase: 'user_input' },
|
||||
);
|
||||
expect(shouldContinue).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false when waiting for human input', () => {
|
||||
const shouldContinue = (service as any).shouldContinueExecution(
|
||||
{ status: 'waiting_for_human' },
|
||||
{ phase: 'user_input' },
|
||||
);
|
||||
expect(shouldContinue).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false when max steps reached', () => {
|
||||
const shouldContinue = (service as any).shouldContinueExecution(
|
||||
{ status: 'running', maxSteps: 10, stepCount: 10 },
|
||||
{ phase: 'user_input' },
|
||||
);
|
||||
expect(shouldContinue).toBe(false);
|
||||
});
|
||||
|
||||
it('should return false when cost limit exceeded with stop action', () => {
|
||||
const shouldContinue = (service as any).shouldContinueExecution(
|
||||
{
|
||||
status: 'running',
|
||||
cost: { total: 1.0 },
|
||||
costLimit: { maxTotalCost: 0.5, onExceeded: 'stop' },
|
||||
},
|
||||
{ phase: 'user_input' },
|
||||
);
|
||||
expect(shouldContinue).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true when cost limit exceeded with continue action', () => {
|
||||
const shouldContinue = (service as any).shouldContinueExecution(
|
||||
{
|
||||
status: 'running',
|
||||
cost: { total: 1.0 },
|
||||
costLimit: { maxTotalCost: 0.5, onExceeded: 'continue' },
|
||||
},
|
||||
{ phase: 'user_input' },
|
||||
);
|
||||
expect(shouldContinue).toBe(true);
|
||||
});
|
||||
|
||||
it('should return false when no context provided', () => {
|
||||
const shouldContinue = (service as any).shouldContinueExecution(
|
||||
{ status: 'running' },
|
||||
null,
|
||||
);
|
||||
expect(shouldContinue).toBe(false);
|
||||
});
|
||||
|
||||
it('should return true for normal running state', () => {
|
||||
const shouldContinue = (service as any).shouldContinueExecution(
|
||||
{ status: 'running' },
|
||||
{ phase: 'user_input' },
|
||||
);
|
||||
expect(shouldContinue).toBe(true);
|
||||
});
|
||||
});
|
||||
|
||||
describe('calculateStepDelay', () => {
|
||||
it('should return base delay for normal step', () => {
|
||||
const delay = (service as any).calculateStepDelay({
|
||||
events: [{ type: 'llm_response' }],
|
||||
});
|
||||
expect(delay).toBe(1000);
|
||||
});
|
||||
|
||||
it('should return longer delay for tool calls', () => {
|
||||
const delay = (service as any).calculateStepDelay({
|
||||
events: [{ type: 'tool_result' }],
|
||||
});
|
||||
expect(delay).toBe(2000);
|
||||
});
|
||||
|
||||
it('should return exponential backoff delay for errors', () => {
|
||||
const delay = (service as any).calculateStepDelay({
|
||||
events: [{ type: 'error' }],
|
||||
});
|
||||
expect(delay).toBe(2000);
|
||||
});
|
||||
});
|
||||
|
||||
describe('calculatePriority', () => {
|
||||
it('should return high priority for human input needed', () => {
|
||||
const priority = (service as any).calculatePriority({
|
||||
newState: { status: 'waiting_for_human' },
|
||||
events: [],
|
||||
});
|
||||
expect(priority).toBe('high');
|
||||
});
|
||||
|
||||
it('should return normal priority for errors', () => {
|
||||
const priority = (service as any).calculatePriority({
|
||||
newState: { status: 'running' },
|
||||
events: [{ type: 'error' }],
|
||||
});
|
||||
expect(priority).toBe('normal');
|
||||
});
|
||||
|
||||
it('should return normal priority by default', () => {
|
||||
const priority = (service as any).calculatePriority({
|
||||
newState: { status: 'running' },
|
||||
events: [{ type: 'llm_response' }],
|
||||
});
|
||||
expect(priority).toBe('normal');
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,699 @@
|
||||
import { AgentRuntime, AgentState } from '@lobechat/agent-runtime';
|
||||
import debug from 'debug';
|
||||
import urlJoin from 'url-join';
|
||||
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
import { LobeChatDatabase } from '@/database/type';
|
||||
import {
|
||||
AgentRuntimeCoordinator,
|
||||
GeneralAgent,
|
||||
StreamEventManager,
|
||||
} from '@/server/modules/AgentRuntime';
|
||||
import {
|
||||
RuntimeExecutorContext,
|
||||
createRuntimeExecutors,
|
||||
} from '@/server/modules/AgentRuntime/RuntimeExecutors';
|
||||
import { mcpService } from '@/server/services/mcp';
|
||||
import { PluginGatewayService } from '@/server/services/pluginGateway';
|
||||
import { QueueService } from '@/server/services/queue';
|
||||
import { ToolExecutionService } from '@/server/services/toolExecution';
|
||||
import { BuiltinToolsExecutor } from '@/server/services/toolExecution/builtin';
|
||||
|
||||
import type {
|
||||
AgentExecutionParams,
|
||||
AgentExecutionResult,
|
||||
PendingInterventionsResult,
|
||||
SessionCreationParams,
|
||||
SessionCreationResult,
|
||||
SessionStatusResult,
|
||||
StartExecutionParams,
|
||||
StartExecutionResult,
|
||||
} from './types';
|
||||
|
||||
const log = debug('lobe-server:agent-runtime-service');
|
||||
|
||||
/**
|
||||
* Agent Runtime Service
|
||||
* 封装 Agent 执行相关的逻辑,提供统一的服务接口
|
||||
*/
|
||||
export class AgentRuntimeService {
|
||||
private coordinator: AgentRuntimeCoordinator;
|
||||
private streamManager: StreamEventManager;
|
||||
private queueService: QueueService;
|
||||
private toolExecutionService: ToolExecutionService;
|
||||
private get baseURL() {
|
||||
const baseUrl =
|
||||
process.env.AGENT_RUNTIME_BASE_URL || process.env.APP_URL || 'http://localhost:3010';
|
||||
|
||||
return urlJoin(baseUrl, '/api/agent');
|
||||
}
|
||||
private userId: string;
|
||||
private db: LobeChatDatabase;
|
||||
private messageModel: MessageModel;
|
||||
|
||||
constructor(db: LobeChatDatabase, userId: string) {
|
||||
this.coordinator = new AgentRuntimeCoordinator();
|
||||
this.streamManager = new StreamEventManager();
|
||||
this.queueService = new QueueService();
|
||||
this.userId = userId;
|
||||
this.db = db;
|
||||
this.messageModel = new MessageModel(db, this.userId);
|
||||
|
||||
// Initialize ToolExecutionService with dependencies
|
||||
const pluginGatewayService = new PluginGatewayService();
|
||||
const builtinToolsExecutor = new BuiltinToolsExecutor();
|
||||
|
||||
this.toolExecutionService = new ToolExecutionService({
|
||||
builtinToolsExecutor,
|
||||
mcpService,
|
||||
pluginGatewayService,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建新的 Agent 会话
|
||||
*/
|
||||
async createSession(params: SessionCreationParams): Promise<SessionCreationResult> {
|
||||
const {
|
||||
sessionId,
|
||||
initialContext,
|
||||
agentConfig,
|
||||
modelRuntimeConfig,
|
||||
userId,
|
||||
autoStart = true,
|
||||
tools,
|
||||
initialMessages = [],
|
||||
appContext,
|
||||
toolManifestMap,
|
||||
} = params;
|
||||
|
||||
try {
|
||||
log('[ %s] Creating new session (autoStart: %s)', sessionId, autoStart);
|
||||
|
||||
// 初始化会话状态 - 先创建状态再保存
|
||||
const initialState = {
|
||||
createdAt: new Date().toISOString(),
|
||||
lastModified: new Date().toISOString(),
|
||||
// 使用传入的初始消息
|
||||
messages: initialMessages,
|
||||
metadata: {
|
||||
agentConfig,
|
||||
modelRuntimeConfig,
|
||||
userId,
|
||||
...appContext,
|
||||
},
|
||||
sessionId,
|
||||
status: 'idle',
|
||||
stepCount: 0,
|
||||
toolManifestMap,
|
||||
tools,
|
||||
} as Partial<AgentState>;
|
||||
|
||||
// 使用协调器创建会话,自动发送初始化事件
|
||||
await this.coordinator.createAgentSession(sessionId, {
|
||||
agentConfig,
|
||||
modelRuntimeConfig,
|
||||
userId,
|
||||
});
|
||||
|
||||
// 保存初始状态
|
||||
await this.coordinator.saveAgentState(sessionId, initialState as any);
|
||||
|
||||
let messageId: string | undefined;
|
||||
let autoStarted = false;
|
||||
|
||||
// 只有在 autoStart 为 true 时才调度第一步执行
|
||||
if (autoStart) {
|
||||
messageId = await this.queueService.scheduleMessage({
|
||||
context: initialContext,
|
||||
delay: 50, // 短延迟启动
|
||||
endpoint: `${this.baseURL}/run`,
|
||||
priority: 'high',
|
||||
sessionId,
|
||||
stepIndex: 0,
|
||||
});
|
||||
autoStarted = true;
|
||||
log('[%s]Scheduled first step (messageId: %s)', sessionId, messageId);
|
||||
} else {
|
||||
log('[%s]created session without auto-start', sessionId);
|
||||
}
|
||||
|
||||
return { autoStarted, messageId, sessionId, success: true };
|
||||
} catch (error) {
|
||||
log('Failed to create session %s: %O', sessionId, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 Agent 步骤
|
||||
*/
|
||||
async executeStep(params: AgentExecutionParams): Promise<AgentExecutionResult> {
|
||||
const { sessionId, stepIndex, context, humanInput, approvedToolCall, rejectionReason } = params;
|
||||
|
||||
try {
|
||||
log(`[${sessionId}] Executing step %d`, stepIndex);
|
||||
|
||||
// 发布步骤开始事件
|
||||
await this.streamManager.publishStreamEvent(sessionId, {
|
||||
data: {},
|
||||
stepIndex,
|
||||
type: 'step_start',
|
||||
});
|
||||
|
||||
// 获取会话状态和元数据
|
||||
const [agentState, sessionMetadata] = await Promise.all([
|
||||
this.coordinator.loadAgentState(sessionId),
|
||||
this.coordinator.getSessionMetadata(sessionId),
|
||||
]);
|
||||
|
||||
if (!agentState) {
|
||||
throw new Error(`Agent state not found for session ${sessionId}`);
|
||||
}
|
||||
|
||||
// 创建 Agent 和 Runtime 实例
|
||||
const { runtime } = await this.createAgentRuntime({
|
||||
metadata: sessionMetadata,
|
||||
sessionId,
|
||||
stepIndex,
|
||||
});
|
||||
|
||||
// 处理人工干预
|
||||
let currentContext = context;
|
||||
let currentState = agentState;
|
||||
|
||||
if (humanInput || approvedToolCall || rejectionReason) {
|
||||
const interventionResult = await this.handleHumanIntervention(runtime, currentState, {
|
||||
approvedToolCall,
|
||||
humanInput,
|
||||
rejectionReason,
|
||||
});
|
||||
currentState = interventionResult.newState;
|
||||
currentContext = interventionResult.nextContext;
|
||||
}
|
||||
|
||||
// 执行步骤
|
||||
const startAt = Date.now();
|
||||
const stepResult = await runtime.step(currentState, currentContext);
|
||||
|
||||
// 保存状态,协调器会自动处理事件发送
|
||||
await this.coordinator.saveStepResult(sessionId, {
|
||||
...stepResult,
|
||||
executionTime: Date.now() - startAt,
|
||||
stepIndex, // placeholder
|
||||
});
|
||||
|
||||
// 决定是否调度下一步
|
||||
const shouldContinue = this.shouldContinueExecution(
|
||||
stepResult.newState,
|
||||
stepResult.nextContext,
|
||||
);
|
||||
let nextStepScheduled = false;
|
||||
|
||||
// 发布步骤完成事件
|
||||
await this.streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
finalState: stepResult.newState,
|
||||
nextStepScheduled,
|
||||
stepIndex,
|
||||
},
|
||||
stepIndex,
|
||||
type: 'step_complete',
|
||||
});
|
||||
|
||||
log(`[${sessionId}] Step %d completed`, stepIndex);
|
||||
|
||||
if (shouldContinue && stepResult.nextContext) {
|
||||
const nextStepIndex = stepIndex + 1;
|
||||
const delay = this.calculateStepDelay(stepResult);
|
||||
const priority = this.calculatePriority(stepResult);
|
||||
|
||||
await this.queueService.scheduleMessage({
|
||||
context: stepResult.nextContext,
|
||||
delay,
|
||||
endpoint: `${this.baseURL}/run`,
|
||||
priority,
|
||||
sessionId,
|
||||
stepIndex: nextStepIndex,
|
||||
});
|
||||
nextStepScheduled = true;
|
||||
|
||||
log(`[${sessionId}] Scheduled next step %d for session %s`, nextStepIndex);
|
||||
}
|
||||
|
||||
return {
|
||||
nextStepScheduled,
|
||||
state: stepResult.newState,
|
||||
stepResult,
|
||||
success: true,
|
||||
};
|
||||
} catch (error) {
|
||||
log('Step %d failed for session %s: %O', stepIndex, sessionId, error);
|
||||
|
||||
// 发布错误事件
|
||||
await this.streamManager.publishStreamEvent(sessionId, {
|
||||
data: {
|
||||
error: (error as Error).message,
|
||||
phase: 'step_execution',
|
||||
stepIndex,
|
||||
},
|
||||
stepIndex,
|
||||
type: 'error',
|
||||
});
|
||||
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取会话状态
|
||||
*/
|
||||
async getSessionStatus(params: {
|
||||
historyLimit?: number;
|
||||
includeHistory?: boolean;
|
||||
sessionId: string;
|
||||
}): Promise<SessionStatusResult> {
|
||||
const { sessionId, includeHistory = false, historyLimit = 10 } = params;
|
||||
|
||||
try {
|
||||
log('Getting session status for %s', sessionId);
|
||||
|
||||
// 获取当前状态和元数据
|
||||
const [currentState, sessionMetadata] = await Promise.all([
|
||||
this.coordinator.loadAgentState(sessionId),
|
||||
this.coordinator.getSessionMetadata(sessionId),
|
||||
]);
|
||||
|
||||
if (!currentState || !sessionMetadata) {
|
||||
throw new Error('Session not found');
|
||||
}
|
||||
|
||||
// 获取执行历史(如果需要)
|
||||
let executionHistory;
|
||||
if (includeHistory) {
|
||||
try {
|
||||
executionHistory = await this.coordinator.getExecutionHistory(sessionId, historyLimit);
|
||||
} catch (error) {
|
||||
log('Failed to load execution history: %O', error);
|
||||
executionHistory = [];
|
||||
}
|
||||
}
|
||||
|
||||
// 获取最近的流式事件(用于调试)
|
||||
let recentEvents;
|
||||
if (includeHistory) {
|
||||
try {
|
||||
recentEvents = await this.streamManager.getStreamHistory(sessionId, 20);
|
||||
} catch (error) {
|
||||
log('Failed to load recent events: %O', error);
|
||||
recentEvents = [];
|
||||
}
|
||||
}
|
||||
|
||||
// 计算会话统计信息
|
||||
const stats = {
|
||||
lastActiveTime: sessionMetadata.lastActiveAt
|
||||
? Date.now() - new Date(sessionMetadata.lastActiveAt).getTime()
|
||||
: 0,
|
||||
totalCost: currentState.cost?.total || 0,
|
||||
totalMessages: currentState.messages?.length || 0,
|
||||
totalSteps: currentState.stepCount || 0,
|
||||
uptime: sessionMetadata.createdAt
|
||||
? Date.now() - new Date(sessionMetadata.createdAt).getTime()
|
||||
: 0,
|
||||
};
|
||||
|
||||
return {
|
||||
currentState: {
|
||||
cost: currentState.cost,
|
||||
costLimit: currentState.costLimit,
|
||||
error: currentState.error,
|
||||
interruption: currentState.interruption,
|
||||
lastModified: currentState.lastModified,
|
||||
maxSteps: currentState.maxSteps,
|
||||
pendingHumanPrompt: currentState.pendingHumanPrompt,
|
||||
pendingHumanSelect: currentState.pendingHumanSelect,
|
||||
pendingToolsCalling: currentState.pendingToolsCalling,
|
||||
status: currentState.status,
|
||||
stepCount: currentState.stepCount,
|
||||
usage: currentState.usage,
|
||||
},
|
||||
executionHistory: executionHistory?.slice(0, historyLimit),
|
||||
hasError: currentState.status === 'error',
|
||||
isActive: ['running', 'waiting_for_human'].includes(currentState.status),
|
||||
isCompleted: currentState.status === 'done',
|
||||
metadata: sessionMetadata,
|
||||
needsHumanInput: currentState.status === 'waiting_for_human',
|
||||
recentEvents: recentEvents?.slice(0, 10),
|
||||
sessionId,
|
||||
stats,
|
||||
};
|
||||
} catch (error) {
|
||||
log('Failed to get session status for %s: %O', sessionId, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取待处理的人工干预列表
|
||||
*/
|
||||
async getPendingInterventions(params: {
|
||||
sessionId?: string;
|
||||
userId?: string;
|
||||
}): Promise<PendingInterventionsResult> {
|
||||
const { sessionId, userId } = params;
|
||||
|
||||
try {
|
||||
log('Getting pending interventions for sessionId: %s, userId: %s', sessionId, userId);
|
||||
|
||||
let sessions: string[] = [];
|
||||
|
||||
if (sessionId) {
|
||||
sessions = [sessionId];
|
||||
} else if (userId) {
|
||||
// 获取用户的所有活跃会话
|
||||
try {
|
||||
const activeSessions = await this.coordinator.getActiveSessions();
|
||||
|
||||
// 过滤出属于该用户的会话
|
||||
const userSessions = [];
|
||||
for (const session of activeSessions) {
|
||||
try {
|
||||
const metadata = await this.coordinator.getSessionMetadata(session);
|
||||
if (metadata?.userId === userId) {
|
||||
userSessions.push(session);
|
||||
}
|
||||
} catch (error) {
|
||||
log('Failed to get metadata for session %s: %O', session, error);
|
||||
}
|
||||
}
|
||||
sessions = userSessions;
|
||||
} catch (error) {
|
||||
log('Failed to get active sessions: %O', error);
|
||||
sessions = [];
|
||||
}
|
||||
}
|
||||
|
||||
// 检查每个会话的状态
|
||||
const pendingInterventions = [];
|
||||
|
||||
for (const session of sessions) {
|
||||
try {
|
||||
const [state, metadata] = await Promise.all([
|
||||
this.coordinator.loadAgentState(session),
|
||||
this.coordinator.getSessionMetadata(session),
|
||||
]);
|
||||
|
||||
if (state?.status === 'waiting_for_human') {
|
||||
const intervention: any = {
|
||||
lastModified: state.lastModified,
|
||||
modelRuntimeConfig: metadata?.modelRuntimeConfig,
|
||||
sessionId: session,
|
||||
status: state.status,
|
||||
stepCount: state.stepCount,
|
||||
userId: metadata?.userId,
|
||||
};
|
||||
|
||||
// 添加具体的待处理内容
|
||||
if (state.pendingToolsCalling) {
|
||||
intervention.type = 'tool_approval';
|
||||
intervention.pendingToolsCalling = state.pendingToolsCalling;
|
||||
} else if (state.pendingHumanPrompt) {
|
||||
intervention.type = 'human_prompt';
|
||||
intervention.pendingHumanPrompt = state.pendingHumanPrompt;
|
||||
} else if (state.pendingHumanSelect) {
|
||||
intervention.type = 'human_select';
|
||||
intervention.pendingHumanSelect = state.pendingHumanSelect;
|
||||
}
|
||||
|
||||
pendingInterventions.push(intervention);
|
||||
}
|
||||
} catch (error) {
|
||||
log('Failed to get state for session %s: %O', session, error);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
pendingInterventions,
|
||||
timestamp: new Date().toISOString(),
|
||||
totalCount: pendingInterventions.length,
|
||||
};
|
||||
} catch (error) {
|
||||
log('Failed to get pending interventions: %O', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 显式启动会话执行
|
||||
*/
|
||||
async startExecution(params: StartExecutionParams): Promise<StartExecutionResult> {
|
||||
const { sessionId, context, priority = 'normal', delay = 50 } = params;
|
||||
|
||||
try {
|
||||
log('Starting execution for session %s', sessionId);
|
||||
|
||||
// 检查会话是否存在
|
||||
const sessionMetadata = await this.coordinator.getSessionMetadata(sessionId);
|
||||
if (!sessionMetadata) {
|
||||
throw new Error(`Session ${sessionId} not found`);
|
||||
}
|
||||
|
||||
// 获取当前状态
|
||||
const currentState = await this.coordinator.loadAgentState(sessionId);
|
||||
if (!currentState) {
|
||||
throw new Error(`Agent state not found for session ${sessionId}`);
|
||||
}
|
||||
|
||||
// 检查会话状态
|
||||
if (currentState.status === 'running') {
|
||||
throw new Error(`Session ${sessionId} is already running`);
|
||||
}
|
||||
|
||||
if (currentState.status === 'done') {
|
||||
throw new Error(`Session ${sessionId} is already completed`);
|
||||
}
|
||||
|
||||
if (currentState.status === 'error') {
|
||||
throw new Error(`Session ${sessionId} is in error state`);
|
||||
}
|
||||
|
||||
// 构建执行上下文
|
||||
let executionContext = context;
|
||||
if (!executionContext) {
|
||||
// 如果没有提供上下文,从元数据构建默认上下文
|
||||
executionContext = {
|
||||
payload: {
|
||||
isFirstMessage: true,
|
||||
message: [{ content: '' }],
|
||||
},
|
||||
phase: 'user_input' as const,
|
||||
session: {
|
||||
messageCount: currentState.messages?.length || 0,
|
||||
sessionId,
|
||||
status: 'idle' as const,
|
||||
stepCount: currentState.stepCount || 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// 更新会话状态为运行中
|
||||
await this.coordinator.saveAgentState(sessionId, {
|
||||
...currentState,
|
||||
lastModified: new Date().toISOString(),
|
||||
status: 'running',
|
||||
});
|
||||
|
||||
// 调度执行
|
||||
const messageId = await this.queueService.scheduleMessage({
|
||||
context: executionContext,
|
||||
delay,
|
||||
endpoint: `${this.baseURL}/run`,
|
||||
priority,
|
||||
sessionId,
|
||||
stepIndex: currentState.stepCount || 0,
|
||||
});
|
||||
|
||||
log('Scheduled execution for session %s (messageId: %s)', sessionId, messageId);
|
||||
|
||||
return {
|
||||
messageId,
|
||||
scheduled: true,
|
||||
sessionId,
|
||||
success: true,
|
||||
};
|
||||
} catch (error) {
|
||||
log('Failed to start execution for session %s: %O', sessionId, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理人工干预
|
||||
*/
|
||||
async processHumanIntervention(params: {
|
||||
action: 'approve' | 'reject' | 'input' | 'select';
|
||||
approvedToolCall?: any;
|
||||
humanInput?: any;
|
||||
rejectionReason?: string;
|
||||
sessionId: string;
|
||||
stepIndex: number;
|
||||
}): Promise<{ messageId: string }> {
|
||||
const { sessionId, stepIndex, action, approvedToolCall, humanInput, rejectionReason } = params;
|
||||
|
||||
try {
|
||||
log(
|
||||
'Processing human intervention for session %s:%d (action: %s)',
|
||||
sessionId,
|
||||
stepIndex,
|
||||
action,
|
||||
);
|
||||
|
||||
// 高优先级调度执行
|
||||
const messageId = await this.queueService.scheduleMessage({
|
||||
context: undefined, // 会从状态管理器中获取
|
||||
delay: 100,
|
||||
endpoint: `${this.baseURL}/run`,
|
||||
payload: { approvedToolCall, humanInput, rejectionReason },
|
||||
priority: 'high',
|
||||
sessionId,
|
||||
stepIndex,
|
||||
});
|
||||
|
||||
log('Scheduled immediate execution for session %s (messageId: %s)', sessionId, messageId);
|
||||
|
||||
return { messageId };
|
||||
} catch (error) {
|
||||
log('Failed to process human intervention for session %s: %O', sessionId, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 Agent Runtime 实例
|
||||
*/
|
||||
private async createAgentRuntime({
|
||||
metadata,
|
||||
sessionId,
|
||||
stepIndex,
|
||||
}: {
|
||||
metadata?: any;
|
||||
sessionId: string;
|
||||
stepIndex: number;
|
||||
}) {
|
||||
// 创建 Durable Agent 实例
|
||||
const agent = new GeneralAgent({
|
||||
agentConfig: metadata?.agentConfig,
|
||||
modelRuntimeConfig: metadata?.modelRuntimeConfig,
|
||||
sessionId,
|
||||
userId: metadata?.userId,
|
||||
});
|
||||
|
||||
// 创建流式执行器上下文
|
||||
const executorContext: RuntimeExecutorContext = {
|
||||
messageModel: this.messageModel,
|
||||
sessionId,
|
||||
stepIndex,
|
||||
streamManager: this.streamManager,
|
||||
toolExecutionService: this.toolExecutionService,
|
||||
userId: metadata?.userId,
|
||||
};
|
||||
|
||||
// 创建 Agent Runtime 实例
|
||||
const runtime = new AgentRuntime(agent as any, {
|
||||
executors: createRuntimeExecutors(executorContext),
|
||||
});
|
||||
|
||||
return { agent, runtime };
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理人工干预逻辑
|
||||
*/
|
||||
private async handleHumanIntervention(
|
||||
runtime: AgentRuntime,
|
||||
state: any,
|
||||
intervention: { approvedToolCall?: any; humanInput?: any; rejectionReason?: string },
|
||||
) {
|
||||
const { humanInput, approvedToolCall, rejectionReason } = intervention;
|
||||
|
||||
if (approvedToolCall && state.status === 'waiting_for_human') {
|
||||
// TODO: 实现 approveToolCall 逻辑
|
||||
return { newState: state, nextContext: undefined };
|
||||
} else if (rejectionReason && state.status === 'waiting_for_human') {
|
||||
// TODO: 实现 rejectToolCall 逻辑
|
||||
return { newState: state, nextContext: undefined };
|
||||
} else if (humanInput) {
|
||||
// TODO: 实现 processHumanInput 逻辑
|
||||
return { newState: state, nextContext: undefined };
|
||||
}
|
||||
|
||||
return { newState: state, nextContext: undefined };
|
||||
}
|
||||
|
||||
/**
|
||||
* 决定是否继续执行
|
||||
*/
|
||||
private shouldContinueExecution(state: any, context?: any): boolean {
|
||||
// 已完成
|
||||
if (state.status === 'done') return false;
|
||||
|
||||
// 需要人工干预
|
||||
if (state.status === 'waiting_for_human') return false;
|
||||
|
||||
// 出错了
|
||||
if (state.status === 'error') return false;
|
||||
|
||||
// 被中断
|
||||
if (state.status === 'interrupted') return false;
|
||||
|
||||
// 达到最大步数
|
||||
if (state.maxSteps && state.stepCount >= state.maxSteps) return false;
|
||||
|
||||
// 超过成本限制
|
||||
if (state.costLimit && state.cost?.total >= state.costLimit.maxTotalCost) {
|
||||
return state.costLimit.onExceeded !== 'stop';
|
||||
}
|
||||
|
||||
// 没有下一个上下文
|
||||
if (!context) return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算步骤延迟
|
||||
*/
|
||||
private calculateStepDelay(stepResult: any): number {
|
||||
const baseDelay = 50;
|
||||
|
||||
// 如果有工具调用,延迟长一点
|
||||
if (stepResult.events?.some((e: any) => e.type === 'tool_result')) {
|
||||
return baseDelay + 50;
|
||||
}
|
||||
|
||||
// 如果有错误,使用指数退避
|
||||
if (stepResult.events?.some((e: any) => e.type === 'error')) {
|
||||
return Math.min(baseDelay * 2, 1000);
|
||||
}
|
||||
|
||||
return baseDelay;
|
||||
}
|
||||
|
||||
/**
|
||||
* 计算优先级
|
||||
*/
|
||||
private calculatePriority(stepResult: any): 'high' | 'normal' | 'low' {
|
||||
// 如果需要人工干预,高优先级
|
||||
if (stepResult.newState?.status === 'waiting_for_human') {
|
||||
return 'high';
|
||||
}
|
||||
|
||||
// 如果有错误,正常优先级
|
||||
if (stepResult.events?.some((e: any) => e.type === 'error')) {
|
||||
return 'normal';
|
||||
}
|
||||
|
||||
return 'normal';
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
export * from './AgentRuntimeService';
|
||||
export * from './types';
|
||||
@@ -0,0 +1,105 @@
|
||||
import { AgentRuntimeContext } from '@lobechat/agent-runtime';
|
||||
import { LobeToolManifest } from '@lobechat/context-engine/src/tools/types';
|
||||
|
||||
export interface AgentExecutionParams {
|
||||
approvedToolCall?: any;
|
||||
context?: AgentRuntimeContext;
|
||||
humanInput?: any;
|
||||
rejectionReason?: string;
|
||||
sessionId: string;
|
||||
stepIndex: number;
|
||||
}
|
||||
|
||||
export interface AgentExecutionResult {
|
||||
nextStepScheduled: boolean;
|
||||
state: any;
|
||||
stepResult?: any;
|
||||
success: boolean;
|
||||
}
|
||||
|
||||
export interface SessionCreationParams {
|
||||
agentConfig?: any;
|
||||
appContext: {
|
||||
sessionId?: string;
|
||||
threadId?: string | null;
|
||||
topicId?: string | null;
|
||||
};
|
||||
autoStart?: boolean;
|
||||
initialContext: AgentRuntimeContext;
|
||||
initialMessages?: any[];
|
||||
modelRuntimeConfig?: any;
|
||||
sessionId: string;
|
||||
toolManifestMap: Record<string, LobeToolManifest>;
|
||||
tools?: any[];
|
||||
userId?: string;
|
||||
}
|
||||
|
||||
export interface SessionCreationResult {
|
||||
autoStarted: boolean;
|
||||
messageId?: string;
|
||||
sessionId: string;
|
||||
success: boolean;
|
||||
}
|
||||
|
||||
export interface SessionStatusResult {
|
||||
currentState: {
|
||||
cost?: any;
|
||||
costLimit?: any;
|
||||
error?: string;
|
||||
interruption?: any;
|
||||
lastModified: string;
|
||||
maxSteps?: number;
|
||||
pendingHumanPrompt?: any;
|
||||
pendingHumanSelect?: any;
|
||||
pendingToolsCalling?: any;
|
||||
status: string;
|
||||
stepCount: number;
|
||||
usage?: any;
|
||||
};
|
||||
executionHistory?: any[];
|
||||
hasError: boolean;
|
||||
isActive: boolean;
|
||||
isCompleted: boolean;
|
||||
metadata: any;
|
||||
needsHumanInput: boolean;
|
||||
recentEvents?: any[];
|
||||
sessionId: string;
|
||||
stats: {
|
||||
lastActiveTime: number;
|
||||
totalCost: number;
|
||||
totalMessages: number;
|
||||
totalSteps: number;
|
||||
uptime: number;
|
||||
};
|
||||
}
|
||||
|
||||
export interface PendingInterventionsResult {
|
||||
pendingInterventions: Array<{
|
||||
lastModified: string;
|
||||
modelRuntimeConfig?: any;
|
||||
pendingHumanPrompt?: any;
|
||||
pendingHumanSelect?: any;
|
||||
pendingToolsCalling?: any[];
|
||||
sessionId: string;
|
||||
status: string;
|
||||
stepCount: number;
|
||||
type: 'tool_approval' | 'human_prompt' | 'human_select';
|
||||
userId?: string;
|
||||
}>;
|
||||
timestamp: string;
|
||||
totalCount: number;
|
||||
}
|
||||
|
||||
export interface StartExecutionParams {
|
||||
context?: AgentRuntimeContext;
|
||||
delay?: number;
|
||||
priority?: 'high' | 'normal' | 'low';
|
||||
sessionId: string;
|
||||
}
|
||||
|
||||
export interface StartExecutionResult {
|
||||
messageId: string;
|
||||
scheduled: boolean;
|
||||
sessionId: string;
|
||||
success: boolean;
|
||||
}
|
||||
@@ -322,6 +322,8 @@ export class MCPService {
|
||||
return {
|
||||
api: tools,
|
||||
identifier,
|
||||
// @ts-ignore
|
||||
mcpParams,
|
||||
meta: {
|
||||
avatar: metadata?.avatar || 'MCP_AVATAR',
|
||||
description:
|
||||
@@ -338,13 +340,15 @@ export class MCPService {
|
||||
params: Omit<StdioMCPParams, 'type'>,
|
||||
metadata?: CustomPluginMetadata,
|
||||
): Promise<LobeChatPluginManifest> {
|
||||
const client = await this.getClient({
|
||||
const mcpParams = {
|
||||
args: params.args,
|
||||
command: params.command,
|
||||
env: params.env,
|
||||
name: params.name,
|
||||
type: 'stdio',
|
||||
}); // Get client using params
|
||||
type: 'stdio' as const,
|
||||
};
|
||||
|
||||
const client = await this.getClient(mcpParams); // Get client using params
|
||||
|
||||
const manifest = await client.listManifests();
|
||||
|
||||
@@ -365,6 +369,8 @@ export class MCPService {
|
||||
title: metadata?.name || identifier,
|
||||
},
|
||||
...manifest,
|
||||
// @ts-ignore
|
||||
mcpParams,
|
||||
// TODO: temporary
|
||||
type: 'mcp' as any,
|
||||
} as LobeChatPluginManifest;
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
import { ChatToolPayload } from '@lobechat/types';
|
||||
import { safeParseJSON } from '@lobechat/utils';
|
||||
import { PluginRequestPayload } from '@lobehub/chat-plugin-sdk';
|
||||
import { Gateway, GatewaySuccessResponse } from '@lobehub/chat-plugins-gateway';
|
||||
import debug from 'debug';
|
||||
|
||||
import { parserPluginSettings } from '@/app/(backend)/webapi/plugin/gateway/settings';
|
||||
import { getAppConfig } from '@/envs/app';
|
||||
import { ToolExecutionContext } from '@/server/services/toolExecution/types';
|
||||
|
||||
const log = debug('lobe-server:plugin-gateway-service');
|
||||
|
||||
export class PluginGatewayService {
|
||||
private gateway: Gateway;
|
||||
|
||||
constructor() {
|
||||
const { PLUGINS_INDEX_URL, PLUGIN_SETTINGS } = getAppConfig();
|
||||
|
||||
this.gateway = new Gateway({
|
||||
defaultPluginSettings: parserPluginSettings(PLUGIN_SETTINGS),
|
||||
pluginsIndexUrl: PLUGINS_INDEX_URL,
|
||||
});
|
||||
}
|
||||
|
||||
async execute(payload: ChatToolPayload, context: ToolExecutionContext) {
|
||||
const { identifier, apiName, arguments: argsStr } = payload;
|
||||
const args = safeParseJSON(argsStr) || {};
|
||||
|
||||
log('Executing plugin: %s:%s with args: %O', identifier, apiName, args, context);
|
||||
|
||||
try {
|
||||
// Construct plugin request
|
||||
const requestBody: PluginRequestPayload = {
|
||||
apiName,
|
||||
arguments: JSON.stringify(args),
|
||||
identifier,
|
||||
manifest: context.toolManifestMap[identifier] as any,
|
||||
};
|
||||
|
||||
const response = await this.gateway.execute(requestBody);
|
||||
|
||||
log('Plugin execution result: %O', response);
|
||||
|
||||
return {
|
||||
content: (response as GatewaySuccessResponse).data,
|
||||
success: true,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error('Error executing plugin %s:%s: %O', identifier, apiName, error);
|
||||
return {
|
||||
content: (error as Error).message,
|
||||
error: {
|
||||
message: (error as Error).message,
|
||||
},
|
||||
success: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
import { QueueServiceImpl, createQueueServiceModule } from './impls';
|
||||
import { HealthCheckResult, QueueMessage, QueueStats } from './types';
|
||||
|
||||
/**
|
||||
* Queue Service
|
||||
* Uses modular implementation approach to provide queue operation services
|
||||
*/
|
||||
export class QueueService {
|
||||
private impl: QueueServiceImpl;
|
||||
|
||||
constructor() {
|
||||
this.impl = createQueueServiceModule();
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedule a message to the queue
|
||||
*/
|
||||
async scheduleMessage(message: QueueMessage): Promise<string> {
|
||||
return this.impl.scheduleMessage(message);
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedule multiple messages to the queue
|
||||
*/
|
||||
async scheduleBatchMessages(messages: QueueMessage[]): Promise<string[]> {
|
||||
return this.impl.scheduleBatchMessages(messages);
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel scheduled task
|
||||
*/
|
||||
async cancelScheduledTask(taskId: string): Promise<void> {
|
||||
return this.impl.cancelScheduledTask(taskId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get queue statistics
|
||||
*/
|
||||
async getQueueStats(): Promise<QueueStats> {
|
||||
return this.impl.getQueueStats();
|
||||
}
|
||||
|
||||
/**
|
||||
* Health check
|
||||
*/
|
||||
async healthCheck(): Promise<HealthCheckResult> {
|
||||
return this.impl.healthCheck();
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate delay time (dynamically adjusted based on different situations)
|
||||
*/
|
||||
static calculateDelay(params: {
|
||||
hasErrors: boolean;
|
||||
hasToolCalls: boolean;
|
||||
priority: 'high' | 'normal' | 'low';
|
||||
stepIndex: number;
|
||||
}): number {
|
||||
const { stepIndex, hasErrors, hasToolCalls, priority } = params;
|
||||
|
||||
let baseDelay = 1000; // 1 second base delay
|
||||
|
||||
// Adjust based on priority
|
||||
switch (priority) {
|
||||
case 'high': {
|
||||
baseDelay = 200;
|
||||
break;
|
||||
}
|
||||
case 'low': {
|
||||
baseDelay = 5000;
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
baseDelay = 1000;
|
||||
}
|
||||
}
|
||||
|
||||
// If there are tool calls, delay a bit longer to wait for tool execution completion
|
||||
if (hasToolCalls) {
|
||||
baseDelay += 1000;
|
||||
}
|
||||
|
||||
// If there are errors, delay longer to avoid consecutive failures
|
||||
if (hasErrors) {
|
||||
baseDelay += Math.min(stepIndex * 1000, 10_000); // Exponential backoff, max 10 seconds
|
||||
}
|
||||
|
||||
return baseDelay;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
import { QStashQueueServiceImpl } from './qstash';
|
||||
import { SimpleQueueServiceImpl } from './simple';
|
||||
import { QueueServiceImpl } from './type';
|
||||
|
||||
/**
|
||||
* Create queue service module
|
||||
* Automatically select QStash or simple queue implementation based on environment variables
|
||||
*/
|
||||
export const createQueueServiceModule = (): QueueServiceImpl => {
|
||||
// Check if QStash is configured
|
||||
const qstashToken = process.env.QSTASH_TOKEN;
|
||||
|
||||
if (qstashToken) {
|
||||
return new QStashQueueServiceImpl({ qstashToken });
|
||||
}
|
||||
|
||||
return new SimpleQueueServiceImpl();
|
||||
};
|
||||
|
||||
export type { QueueServiceImpl } from './type';
|
||||
@@ -0,0 +1,117 @@
|
||||
import { Client } from '@upstash/qstash';
|
||||
import debug from 'debug';
|
||||
|
||||
import { HealthCheckResult, QueueMessage, QueueStats } from '../types';
|
||||
import { QueueServiceImpl } from './type';
|
||||
|
||||
const log = debug('lobe-server:service:queue:qstash');
|
||||
|
||||
/**
|
||||
* QStash queue service implementation
|
||||
*/
|
||||
export class QStashQueueServiceImpl implements QueueServiceImpl {
|
||||
private qstashClient: Client;
|
||||
|
||||
constructor(config: { publishUrl?: string; qstashToken: string }) {
|
||||
if (!config.qstashToken) {
|
||||
throw new Error('QStash token is required for queue service');
|
||||
}
|
||||
|
||||
this.qstashClient = new Client({ token: config.qstashToken });
|
||||
log('Initialized QStash queue service');
|
||||
}
|
||||
|
||||
async scheduleMessage(message: QueueMessage): Promise<string> {
|
||||
const {
|
||||
sessionId,
|
||||
stepIndex,
|
||||
context,
|
||||
endpoint,
|
||||
payload,
|
||||
delay = 50,
|
||||
priority = 'normal',
|
||||
retries = 3,
|
||||
} = message;
|
||||
|
||||
try {
|
||||
const response = await this.qstashClient.publishJSON({
|
||||
body: {
|
||||
context,
|
||||
payload,
|
||||
priority,
|
||||
sessionId,
|
||||
stepIndex,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
delay: Math.ceil(delay / 1000), // 将毫秒转换为秒
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-Agent-Priority': priority,
|
||||
'X-Agent-Session-Id': sessionId,
|
||||
'X-Agent-Step-Index': stepIndex.toString(),
|
||||
},
|
||||
retries,
|
||||
url: endpoint,
|
||||
});
|
||||
|
||||
log(
|
||||
`[${sessionId}] Scheduled step %d to %s with %dms delay (messageId: %s)`,
|
||||
stepIndex,
|
||||
endpoint,
|
||||
delay,
|
||||
'messageId' in response ? response.messageId : 'batch-message',
|
||||
);
|
||||
|
||||
return 'messageId' in response ? response.messageId : `scheduled-${Date.now()}`;
|
||||
} catch (error) {
|
||||
log('Failed to schedule step %d for session %s: %O', stepIndex, sessionId, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async scheduleBatchMessages(messages: QueueMessage[]): Promise<string[]> {
|
||||
try {
|
||||
// Use Promise.all for concurrent execution
|
||||
const messageIds = await Promise.all(
|
||||
messages.map((message) => this.scheduleMessage(message)),
|
||||
);
|
||||
|
||||
log('Scheduled %d batch messages', messages.length);
|
||||
return messageIds;
|
||||
} catch (error) {
|
||||
log('Failed to schedule batch messages: %O', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async cancelScheduledTask(messageId: string): Promise<void> {
|
||||
try {
|
||||
// QStash currently doesn't support task cancellation, can record to Redis as cancellation marker
|
||||
// Check this marker during actual execution
|
||||
log('Requested cancellation for message %s', messageId);
|
||||
|
||||
// TODO: Implement cancellation logic, can store cancellation list via Redis
|
||||
// await this.redis.sadd('cancelled_tasks', messageId);
|
||||
} catch (error) {
|
||||
log('Failed to cancel task %s: %O', messageId, error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async getQueueStats(): Promise<QueueStats> {
|
||||
return {
|
||||
completedCount: 0,
|
||||
failedCount: 0,
|
||||
pendingCount: 0,
|
||||
processingCount: 0,
|
||||
};
|
||||
}
|
||||
|
||||
async healthCheck(): Promise<HealthCheckResult> {
|
||||
// Simple health check without sending actual messages
|
||||
return {
|
||||
healthy: true,
|
||||
message: 'QStash queue service is ready',
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
import debug from 'debug';
|
||||
|
||||
import { HealthCheckResult, QueueMessage, QueueStats } from '../types';
|
||||
import { QueueServiceImpl } from './type';
|
||||
|
||||
const log = debug('queue:simple');
|
||||
|
||||
/**
|
||||
* Simplified queue service implementation for scenarios not using QStash
|
||||
*/
|
||||
export class SimpleQueueServiceImpl implements QueueServiceImpl {
|
||||
// eslint-disable-next-line no-undef
|
||||
private timeouts: Map<string, NodeJS.Timeout> = new Map();
|
||||
|
||||
async scheduleMessage(message: QueueMessage): Promise<string> {
|
||||
const { sessionId, stepIndex, context, endpoint, payload, delay = 1000 } = message;
|
||||
|
||||
const taskId = `${sessionId}_${stepIndex}_${Date.now()}`;
|
||||
|
||||
const timeout = setTimeout(async () => {
|
||||
try {
|
||||
// Directly call execution endpoint
|
||||
const response = await fetch(endpoint, {
|
||||
body: JSON.stringify({
|
||||
context,
|
||||
payload,
|
||||
sessionId,
|
||||
stepIndex,
|
||||
timestamp: Date.now(),
|
||||
}),
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
method: 'POST',
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
|
||||
}
|
||||
|
||||
log('Executed step %d for session %s to endpoint %s', stepIndex, sessionId, endpoint);
|
||||
} catch (error) {
|
||||
log('Failed to execute step %d for session %s: %O', stepIndex, sessionId, error);
|
||||
} finally {
|
||||
this.timeouts.delete(taskId);
|
||||
}
|
||||
}, delay);
|
||||
|
||||
this.timeouts.set(taskId, timeout);
|
||||
|
||||
log('Scheduled step %d for session %s to %s with %dms delay', stepIndex, sessionId, endpoint, delay);
|
||||
|
||||
return taskId;
|
||||
}
|
||||
|
||||
async scheduleBatchMessages(messages: QueueMessage[]): Promise<string[]> {
|
||||
const taskIds: string[] = [];
|
||||
|
||||
try {
|
||||
for (const message of messages) {
|
||||
const taskId = await this.scheduleMessage(message);
|
||||
taskIds.push(taskId);
|
||||
}
|
||||
|
||||
log('Scheduled %d batch messages', messages.length);
|
||||
return taskIds;
|
||||
} catch (error) {
|
||||
log('Failed to schedule batch messages: %O', error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
|
||||
async cancelScheduledTask(taskId: string): Promise<void> {
|
||||
const timeout = this.timeouts.get(taskId);
|
||||
if (timeout) {
|
||||
clearTimeout(timeout);
|
||||
this.timeouts.delete(taskId);
|
||||
log('Cancelled task %s', taskId);
|
||||
}
|
||||
}
|
||||
|
||||
async getQueueStats(): Promise<QueueStats> {
|
||||
return {
|
||||
completedCount: 0,
|
||||
failedCount: 0,
|
||||
pendingCount: this.timeouts.size,
|
||||
processingCount: 0,
|
||||
};
|
||||
}
|
||||
|
||||
async healthCheck(): Promise<HealthCheckResult> {
|
||||
return {
|
||||
healthy: true,
|
||||
message: `Simple queue service healthy, ${this.timeouts.size} pending tasks`,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
import { HealthCheckResult, QueueMessage, QueueStats } from '../types';
|
||||
|
||||
/**
|
||||
* Queue service implementation interface
|
||||
*/
|
||||
export interface QueueServiceImpl {
|
||||
/**
|
||||
* Cancel scheduled task
|
||||
*/
|
||||
cancelScheduledTask(taskId: string): Promise<void>;
|
||||
|
||||
/**
|
||||
* Get queue statistics
|
||||
*/
|
||||
getQueueStats(): Promise<QueueStats>;
|
||||
|
||||
/**
|
||||
* Health check
|
||||
*/
|
||||
healthCheck(): Promise<HealthCheckResult>;
|
||||
|
||||
/**
|
||||
* Schedule multiple messages to the queue
|
||||
*/
|
||||
scheduleBatchMessages(messages: QueueMessage[]): Promise<string[]>;
|
||||
|
||||
/**
|
||||
* Schedule a message to the queue
|
||||
*/
|
||||
scheduleMessage(message: QueueMessage): Promise<string>;
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
export type { QueueServiceImpl } from './impls';
|
||||
export { QueueService } from './QueueService';
|
||||
export type { HealthCheckResult, QueueMessage, QueueStats } from './types';
|
||||
@@ -0,0 +1,24 @@
|
||||
import { AgentRuntimeContext } from '@lobechat/agent-runtime';
|
||||
|
||||
export interface QueueMessage {
|
||||
context?: AgentRuntimeContext;
|
||||
delay?: number;
|
||||
endpoint: string;
|
||||
payload?: any;
|
||||
priority?: 'high' | 'normal' | 'low';
|
||||
retries?: number;
|
||||
sessionId: string;
|
||||
stepIndex: number;
|
||||
}
|
||||
|
||||
export interface QueueStats {
|
||||
completedCount: number;
|
||||
failedCount: number;
|
||||
pendingCount: number;
|
||||
processingCount: number;
|
||||
}
|
||||
|
||||
export interface HealthCheckResult {
|
||||
healthy: boolean;
|
||||
message?: string;
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
import { ChatToolPayload } from '@lobechat/types';
|
||||
import { safeParseJSON } from '@lobechat/utils';
|
||||
import debug from 'debug';
|
||||
|
||||
import { SearchService } from '@/server/services/search';
|
||||
import { BuiltinToolServerRuntimes } from '@/tools/executionRuntimes';
|
||||
|
||||
import { IToolExecutor, ToolExecutionContext, ToolExecutionResult } from './types';
|
||||
|
||||
const log = debug('lobe-server:builtin-tools-executor');
|
||||
|
||||
export class BuiltinToolsExecutor implements IToolExecutor {
|
||||
async execute(
|
||||
payload: ChatToolPayload,
|
||||
context: ToolExecutionContext,
|
||||
): Promise<ToolExecutionResult> {
|
||||
const { identifier, apiName, arguments: argsStr } = payload;
|
||||
const args = safeParseJSON(argsStr) || {};
|
||||
|
||||
log('Executing builtin tool: %s:%s with args: %O', identifier, apiName, args, context);
|
||||
|
||||
const ServerRuntime = BuiltinToolServerRuntimes[identifier];
|
||||
|
||||
if (!ServerRuntime) {
|
||||
throw new Error(`Builtin tool "${identifier}" is not implemented`);
|
||||
}
|
||||
|
||||
const runtime = new ServerRuntime({
|
||||
searchService: new SearchService(),
|
||||
});
|
||||
|
||||
if (!runtime[apiName]) {
|
||||
throw new Error(`Builtin tool ${identifier} 's ${apiName} is not implemented`);
|
||||
}
|
||||
|
||||
try {
|
||||
return await runtime[apiName](args);
|
||||
} catch (e) {
|
||||
const error = e as Error;
|
||||
console.error('Error executing builtin tool %s:%s: %O', identifier, apiName, error);
|
||||
|
||||
return { content: error.message, error: error, success: false };
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
import { ChatToolPayload } from '@lobechat/types';
|
||||
import debug from 'debug';
|
||||
|
||||
import { MCPService } from '../mcp';
|
||||
import { PluginGatewayService } from '../pluginGateway';
|
||||
import { BuiltinToolsExecutor } from './builtin';
|
||||
import { ToolExecutionContext, ToolExecutionResult, ToolExecutionResultResponse } from './types';
|
||||
|
||||
const log = debug('lobe-server:tool-execution-service');
|
||||
|
||||
interface ToolExecutionServiceDeps {
|
||||
builtinToolsExecutor: BuiltinToolsExecutor;
|
||||
mcpService: MCPService;
|
||||
pluginGatewayService: PluginGatewayService;
|
||||
}
|
||||
|
||||
export class ToolExecutionService {
|
||||
private builtinToolsExecutor: BuiltinToolsExecutor;
|
||||
private mcpService: MCPService;
|
||||
private pluginGatewayService: PluginGatewayService;
|
||||
|
||||
constructor({
|
||||
mcpService,
|
||||
pluginGatewayService,
|
||||
builtinToolsExecutor,
|
||||
}: ToolExecutionServiceDeps) {
|
||||
this.builtinToolsExecutor = builtinToolsExecutor;
|
||||
this.mcpService = mcpService;
|
||||
this.pluginGatewayService = pluginGatewayService;
|
||||
}
|
||||
|
||||
async executeTool(
|
||||
payload: ChatToolPayload,
|
||||
context: ToolExecutionContext,
|
||||
): Promise<ToolExecutionResultResponse> {
|
||||
const { identifier, apiName, type } = payload;
|
||||
|
||||
log('Executing tool: %s:%s (type: %s)', identifier, apiName, type);
|
||||
|
||||
const startTime = Date.now();
|
||||
try {
|
||||
const typeStr = type as string;
|
||||
let data: ToolExecutionResult;
|
||||
switch (typeStr) {
|
||||
case 'builtin': {
|
||||
data = await this.builtinToolsExecutor.execute(payload, context);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'mcp': {
|
||||
data = await this.executeMCPTool(payload, context);
|
||||
break;
|
||||
}
|
||||
|
||||
default: {
|
||||
data = await this.pluginGatewayService.execute(payload, context);
|
||||
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
const executionTime = Date.now() - startTime;
|
||||
|
||||
return {
|
||||
...data,
|
||||
executionTime,
|
||||
};
|
||||
|
||||
// Handle MCP and other types (default, standalone, markdown, mcp)
|
||||
} catch (error) {
|
||||
const executionTime = Date.now() - startTime;
|
||||
log('Error executing tool %s:%s: %O', identifier, apiName, error);
|
||||
return {
|
||||
content: (error as Error).message,
|
||||
error: {
|
||||
message: (error as Error).message,
|
||||
},
|
||||
executionTime,
|
||||
success: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private async executeMCPTool(
|
||||
payload: ChatToolPayload,
|
||||
context: ToolExecutionContext,
|
||||
): Promise<ToolExecutionResult> {
|
||||
const { identifier, apiName, arguments: args } = payload;
|
||||
|
||||
log('Executing MCP tool: %s:%s', identifier, apiName);
|
||||
|
||||
// Get the manifest from context
|
||||
const manifest = context.toolManifestMap[identifier];
|
||||
if (!manifest) {
|
||||
log('Manifest not found for MCP tool: %s', identifier);
|
||||
return {
|
||||
content: `Manifest not found for tool: ${identifier}`,
|
||||
error: {
|
||||
code: 'MANIFEST_NOT_FOUND',
|
||||
message: `Manifest not found for tool: ${identifier}`,
|
||||
},
|
||||
success: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Extract MCP params from manifest (stored in customParams.mcp in LobeTool)
|
||||
const mcpParams = (manifest as any).mcpParams;
|
||||
if (!mcpParams) {
|
||||
log('MCP configuration not found in manifest for: %s ', identifier);
|
||||
return {
|
||||
content: `MCP configuration not found for tool: ${identifier}, please tell user TRY TO REINSTALL THE MCP PLUGIN`,
|
||||
error: {
|
||||
code: 'MCP_CONFIG_NOT_FOUND',
|
||||
message: `MCP configuration not found for tool: ${identifier}`,
|
||||
},
|
||||
success: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Construct MCPClientParams from the mcp config
|
||||
|
||||
log('Calling MCP service with params for: %s:%s', identifier, apiName);
|
||||
|
||||
try {
|
||||
// Call the MCP service
|
||||
const result = await this.mcpService.callTool(mcpParams, apiName, args);
|
||||
|
||||
log('MCP tool execution successful for: %s:%s', identifier, apiName);
|
||||
|
||||
return {
|
||||
content: typeof result === 'string' ? result : JSON.stringify(result),
|
||||
state: typeof result === 'object' ? result : undefined,
|
||||
success: true,
|
||||
};
|
||||
} catch (error) {
|
||||
log('MCP tool execution failed for %s:%s: %O', identifier, apiName, error);
|
||||
return {
|
||||
content: (error as Error).message,
|
||||
error: {
|
||||
code: 'MCP_EXECUTION_ERROR',
|
||||
message: (error as Error).message,
|
||||
},
|
||||
success: false,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export * from './types';
|
||||
@@ -0,0 +1,23 @@
|
||||
import { LobeToolManifest } from '@lobechat/context-engine';
|
||||
import { ChatToolPayload, ClientSecretPayload } from '@lobechat/types';
|
||||
|
||||
export interface ToolExecutionContext {
|
||||
toolManifestMap: Record<string, LobeToolManifest>;
|
||||
userId?: string;
|
||||
userPayload?: ClientSecretPayload;
|
||||
}
|
||||
|
||||
export interface ToolExecutionResult {
|
||||
content: string;
|
||||
error?: any;
|
||||
state?: Record<string, any>;
|
||||
success: boolean;
|
||||
}
|
||||
|
||||
export interface ToolExecutionResultResponse extends ToolExecutionResult {
|
||||
executionTime: number;
|
||||
}
|
||||
|
||||
export interface IToolExecutor {
|
||||
execute(payload: ChatToolPayload, context: ToolExecutionContext): Promise<ToolExecutionResult>;
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
// @ts-nocheck
|
||||
// @vitest-environment happy-dom
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { agentRuntimeClient } from '../client';
|
||||
|
||||
// Mock EventSource
|
||||
class MockEventSource {
|
||||
url: string;
|
||||
readyState: number = 0;
|
||||
onopen: ((event: any) => void) | null = null;
|
||||
onmessage: ((event: any) => void) | null = null;
|
||||
onerror: ((event: any) => void) | null = null;
|
||||
|
||||
private listeners: { [key: string]: Array<(event: any) => void> } = {};
|
||||
|
||||
constructor(url: string) {
|
||||
this.url = url;
|
||||
this.readyState = 1; // OPEN
|
||||
|
||||
// Simulate connection opening - use setImmediate to ensure listeners are set up first
|
||||
setImmediate(() => {
|
||||
this.onopen?.({ type: 'open' });
|
||||
this.listeners['open']?.forEach((listener) => listener({ type: 'open' }));
|
||||
});
|
||||
}
|
||||
|
||||
addEventListener(type: string, listener: (event: any) => void) {
|
||||
if (!this.listeners[type]) {
|
||||
this.listeners[type] = [];
|
||||
}
|
||||
this.listeners[type].push(listener);
|
||||
}
|
||||
|
||||
removeEventListener(type: string, listener: (event: any) => void) {
|
||||
if (this.listeners[type]) {
|
||||
const index = this.listeners[type].indexOf(listener);
|
||||
if (index > -1) {
|
||||
this.listeners[type].splice(index, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
close() {
|
||||
this.readyState = 2; // CLOSED
|
||||
}
|
||||
|
||||
// Test helper method to simulate receiving messages
|
||||
_simulateMessage(data: string, eventType?: string, id?: string) {
|
||||
const event = {
|
||||
data,
|
||||
type: 'message',
|
||||
lastEventId: id || '',
|
||||
};
|
||||
|
||||
this.onmessage?.(event);
|
||||
}
|
||||
|
||||
// Test helper method to simulate errors
|
||||
_simulateError() {
|
||||
this.onerror?.({ type: 'error' });
|
||||
}
|
||||
}
|
||||
|
||||
// Mock global EventSource
|
||||
global.EventSource = MockEventSource as any;
|
||||
|
||||
describe('AgentRuntimeClient', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('createStreamConnection', () => {
|
||||
it('should create EventSource with correct URL and parameters', () => {
|
||||
const sessionId = 'agent_1758302563222_0g28qmdmu';
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionId, {
|
||||
includeHistory: false,
|
||||
lastEventId: '0',
|
||||
});
|
||||
|
||||
expect(eventSource.url).toBe(
|
||||
'/api/agent/stream?includeHistory=false&lastEventId=0&sessionId=agent_1758302563222_0g28qmdmu',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle complete agent runtime lifecycle with real stream data', async () => {
|
||||
const sessionId = 'agent_1758302563222_abc';
|
||||
const events: any[] = [];
|
||||
let connectCalled = false;
|
||||
let disconnectCalled = false;
|
||||
let errorCalled = false;
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionId, {
|
||||
includeHistory: false,
|
||||
onConnect: () => {
|
||||
connectCalled = true;
|
||||
},
|
||||
onDisconnect: () => {
|
||||
disconnectCalled = true;
|
||||
},
|
||||
onError: (error) => {
|
||||
errorCalled = true;
|
||||
},
|
||||
onEvent: (event) => {
|
||||
events.push(event);
|
||||
},
|
||||
}) as MockEventSource;
|
||||
|
||||
// Wait for connection
|
||||
await new Promise((resolve) => setTimeout(resolve, 20));
|
||||
expect(connectCalled).toBe(true);
|
||||
|
||||
// Simulate the complete SSE stream from the real data
|
||||
const streamEvents = [
|
||||
// 1. Connected event
|
||||
`{"lastEventId":"0","sessionId":"${sessionId}","timestamp":1758302567925,"type":"connected"}`,
|
||||
|
||||
// 2. Agent runtime init
|
||||
`{"type":"agent_runtime_init","stepIndex":0,"sessionId":"${sessionId}","data":{"agentConfig":{"enableSearch":true,"maxSteps":50},"createdAt":"2025-09-19T17:22:43.222Z","lastActiveAt":"2025-09-19T17:22:43.222Z","modelRuntimeConfig":{"model":"gpt-5-mini","provider":"openai"},"status":"idle","totalCost":0,"totalSteps":0,"userId":"inbox"},"timestamp":1758302564421,"id":"1758302567005-0"}`,
|
||||
|
||||
// 3. Step start
|
||||
`{"type":"step_start","stepIndex":0,"sessionId":"${sessionId}","data":{"sessionId":"${sessionId}","stepIndex":0},"timestamp":1758302573386,"id":"1758302573993-0"}`,
|
||||
|
||||
// 4. Stream start
|
||||
`{"type":"stream_start","stepIndex":0,"sessionId":"${sessionId}","data":{"messageId":"unknown","model":"gpt-5-mini","provider":"openai","sessionId":"${sessionId}"},"timestamp":1758302574552,"id":"1758302574662-0"}`,
|
||||
|
||||
// 5. Some stream chunks (sample)
|
||||
`{"type":"stream_chunk","stepIndex":0,"sessionId":"${sessionId}","data":{"chunkType":"text","content":"Do","fullContent":"Do","messageId":"unknown"},"timestamp":1758302578042,"id":"1758302578151-0"}`,
|
||||
`{"type":"stream_chunk","stepIndex":0,"sessionId":"${sessionId}","data":{"chunkType":"text","content":" you","fullContent":"Do you","messageId":"unknown"},"timestamp":1758302578490,"id":"1758302578600-0"}`,
|
||||
`{"type":"stream_chunk","stepIndex":0,"sessionId":"${sessionId}","data":{"chunkType":"text","content":" mean","fullContent":"Do you mean","messageId":"unknown"},"timestamp":1758302578935,"id":"1758302579045-0"}`,
|
||||
|
||||
// 6. Stream end
|
||||
`{"type":"stream_end","stepIndex":0,"sessionId":"${sessionId}","data":{"finalContent":"Do you mean the number 123? How can I help with it — conversions, math, encoding, or something else?\\n\\nQuick facts in case it's useful:\\n- Decimal: 123\\n- Binary: 1111011\\n- Hex: 0x7B\\n- Octal: 173\\n- Prime factorization: 3 × 41\\n- Sum of digits: 1+2+3 = 6\\n- ASCII code 123 = '{'\\n\\nTell me which of these (or something else) you want.","grounding":null,"messageId":"unknown","toolCalls":[]},"timestamp":1758302626595,"id":"1758302626704-0"}`,
|
||||
|
||||
// 7. Step complete (first step)
|
||||
`{"type":"step_complete","stepIndex":0,"sessionId":"${sessionId}","data":{"finalState":{"events":[],"lastModified":"2025-09-19T17:22:54.552Z","messages":[{"content":"123","role":"user"},{"content":"Do you mean the number 123? How can I help with it — conversions, math, encoding, or something else?\\n\\nQuick facts in case it's useful:\\n- Decimal: 123\\n- Binary: 1111011\\n- Hex: 0x7B\\n- Octal: 173\\n- Prime factorization: 3 × 41\\n- Sum of digits: 1+2+3 = 6\\n- ASCII code 123 = '{'\\n\\nTell me which of these (or something else) you want.","role":"assistant"}],"metadata":{"agentConfig":{"enableSearch":true,"maxSteps":50},"createdAt":"2025-09-19T17:22:43.222Z","modelRuntimeConfig":{"model":"gpt-5-mini","provider":"openai"},"userId":"inbox"},"sessionId":"${sessionId}","status":"idle","stepCount":1},"nextStepScheduled":true,"stepIndex":0},"timestamp":1758302627846,"id":"1758302627955-0"}`,
|
||||
|
||||
// 8. Step start (second step)
|
||||
`{"type":"step_start","stepIndex":1,"sessionId":"${sessionId}","data":{"sessionId":"${sessionId}","stepIndex":1},"timestamp":1758302629691,"id":"1758302629800-0"}`,
|
||||
|
||||
// 9. Step complete (final execution)
|
||||
`{"type":"step_complete","stepIndex":0,"sessionId":"${sessionId}","data":{"finalState":{"events":[],"lastModified":"2025-09-19T17:23:50.360Z","messages":[{"content":"123","role":"user"},{"content":"Do you mean the number 123? How can I help with it — conversions, math, encoding, or something else?\\n\\nQuick facts in case it's useful:\\n- Decimal: 123\\n- Binary: 1111011\\n- Hex: 0x7B\\n- Octal: 173\\n- Prime factorization: 3 × 41\\n- Sum of digits: 1+2+3 = 6\\n- ASCII code 123 = '{'\\n\\nTell me which of these (or something else) you want.","role":"assistant"}],"metadata":{"agentConfig":{"enableSearch":true,"maxSteps":50},"createdAt":"2025-09-19T17:22:43.222Z","modelRuntimeConfig":{"model":"gpt-5-mini","provider":"openai"},"userId":"inbox"},"sessionId":"${sessionId}","status":"done","stepCount":2},"phase":"execution_complete","reason":"completed","reasonDetail":"Simple agent completed successfully"},"timestamp":1758302630360,"id":"1758302630469-0"}`,
|
||||
|
||||
// 10. Agent runtime end
|
||||
`{"type":"agent_runtime_end","stepIndex":1,"sessionId":"${sessionId}","data":{"finalState":{"events":[],"lastModified":"2025-09-19T17:23:50.360Z","messages":[{"content":"123","role":"user"},{"content":"Do you mean the number 123? How can I help with it — conversions, math, encoding, or something else?\\n\\nQuick facts in case it's useful:\\n- Decimal: 123\\n- Binary: 1111011\\n- Hex: 0x7B\\n- Octal: 173\\n- Prime factorization: 3 × 41\\n- Sum of digits: 1+2+3 = 6\\n- ASCII code 123 = '{'\\n\\nTell me which of these (or something else) you want.","role":"assistant"}],"metadata":{"agentConfig":{"enableSearch":true,"maxSteps":50},"createdAt":"2025-09-19T17:22:43.222Z","modelRuntimeConfig":{"model":"gpt-5-mini","provider":"openai"},"userId":"inbox"},"sessionId":"${sessionId}","status":"done","stepCount":2},"phase":"execution_complete","reason":"completed","reasonDetail":"Agent runtime completed successfully","sessionId":"${sessionId}"},"timestamp":1758302631030,"id":"1758302631139-0"}`,
|
||||
|
||||
// 11. Final step complete
|
||||
`{"type":"step_complete","stepIndex":1,"sessionId":"${sessionId}","data":{"finalState":{"events":[],"lastModified":"2025-09-19T17:23:50.360Z","messages":[{"content":"123","role":"user"},{"content":"Do you mean the number 123? How can I help with it — conversions, math, encoding, or something else?\\n\\nQuick facts in case it's useful:\\n- Decimal: 123\\n- Binary: 1111011\\n- Hex: 0x7B\\n- Octal: 173\\n- Prime factorization: 3 × 41\\n- Sum of digits: 1+2+3 = 6\\n- ASCII code 123 = '{'\\n\\nTell me which of these (or something else) you want.","role":"assistant"}],"metadata":{"agentConfig":{"enableSearch":true,"maxSteps":50},"createdAt":"2025-09-19T17:22:43.222Z","modelRuntimeConfig":{"model":"gpt-5-mini","provider":"openai"},"userId":"inbox"},"sessionId":"${sessionId}","status":"done","stepCount":2},"nextStepScheduled":false,"stepIndex":1},"timestamp":1758302631475,"id":"1758302631584-0"}`,
|
||||
];
|
||||
|
||||
// Simulate receiving all events
|
||||
for (const eventData of streamEvents) {
|
||||
eventSource._simulateMessage(eventData);
|
||||
// Small delay to simulate real-time streaming
|
||||
await new Promise((resolve) => setTimeout(resolve, 5));
|
||||
}
|
||||
|
||||
// Verify all events were received and parsed correctly (13 total events)
|
||||
expect(events).toHaveLength(13);
|
||||
|
||||
// Verify event sequence
|
||||
expect(events[0].type).toBe('connected');
|
||||
expect(events[1].type).toBe('agent_runtime_init');
|
||||
expect(events[2].type).toBe('step_start');
|
||||
expect(events[3].type).toBe('stream_start');
|
||||
expect(events[4].type).toBe('stream_chunk');
|
||||
expect(events[5].type).toBe('stream_chunk');
|
||||
expect(events[6].type).toBe('stream_chunk');
|
||||
expect(events[7].type).toBe('stream_end');
|
||||
expect(events[8].type).toBe('step_complete');
|
||||
expect(events[9].type).toBe('step_start');
|
||||
expect(events[10].type).toBe('step_complete');
|
||||
expect(events[11].type).toBe('agent_runtime_end');
|
||||
expect(events[12].type).toBe('step_complete');
|
||||
|
||||
// Verify specific event data
|
||||
const connectedEvent = events[0];
|
||||
expect(connectedEvent.sessionId).toBe(sessionId);
|
||||
expect(connectedEvent.lastEventId).toBe('0');
|
||||
|
||||
const initEvent = events[1];
|
||||
expect(initEvent.stepIndex).toBe(0);
|
||||
expect(initEvent.initialState.agentConfig.enableSearch).toBe(true);
|
||||
expect(initEvent.initialState.modelRuntimeConfig.model).toBe('gpt-5-mini');
|
||||
|
||||
const streamStartEvent = events[3];
|
||||
expect(streamStartEvent.data.model).toBe('gpt-5-mini');
|
||||
expect(streamStartEvent.data.provider).toBe('openai');
|
||||
|
||||
const streamChunkEvent = events[4];
|
||||
expect(streamChunkEvent.data.chunkType).toBe('text');
|
||||
expect(streamChunkEvent.data.content).toBe('Do');
|
||||
expect(streamChunkEvent.data.fullContent).toBe('Do');
|
||||
|
||||
const streamEndEvent = events[7];
|
||||
expect(streamEndEvent.data.finalContent).toContain('Do you mean the number 123?');
|
||||
expect(streamEndEvent.data.toolCalls).toEqual([]);
|
||||
|
||||
const stepCompleteEvent = events[8];
|
||||
expect(stepCompleteEvent.data.finalState.status).toBe('idle');
|
||||
expect(stepCompleteEvent.data.nextStepScheduled).toBe(true);
|
||||
|
||||
const agentRuntimeEndEvent = events[11];
|
||||
expect(agentRuntimeEndEvent.type).toBe('agent_runtime_end');
|
||||
expect(agentRuntimeEndEvent.data.finalState.status).toBe('done');
|
||||
|
||||
const finalStepCompleteEvent = events[12];
|
||||
expect(finalStepCompleteEvent.type).toBe('step_complete');
|
||||
expect(finalStepCompleteEvent.data.finalState.status).toBe('done');
|
||||
expect(finalStepCompleteEvent.data.nextStepScheduled).toBe(false);
|
||||
|
||||
// Verify session ID consistency
|
||||
events.forEach((event, index) => {
|
||||
expect(event.sessionId).toBe(sessionId);
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle heartbeat events correctly', async () => {
|
||||
const sessionId = 'test-session';
|
||||
const events: any[] = [];
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionId, {
|
||||
onEvent: (event) => {
|
||||
events.push(event);
|
||||
},
|
||||
}) as MockEventSource;
|
||||
|
||||
// Simulate heartbeat event (no event type, just data)
|
||||
const heartbeatData = `{"sessionId":"${sessionId}","timestamp":1758302597927,"type":"heartbeat"}`;
|
||||
eventSource._simulateMessage(heartbeatData);
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
|
||||
expect(events).toHaveLength(1);
|
||||
expect(events[0].type).toBe('heartbeat');
|
||||
expect(events[0].sessionId).toBe(sessionId);
|
||||
});
|
||||
|
||||
it('should handle connection errors', async () => {
|
||||
const sessionId = 'test-session';
|
||||
let errorOccurred = false;
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionId, {
|
||||
onError: (error) => {
|
||||
errorOccurred = true;
|
||||
},
|
||||
}) as MockEventSource;
|
||||
|
||||
// Simulate error
|
||||
eventSource._simulateError();
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
|
||||
expect(errorOccurred).toBe(true);
|
||||
});
|
||||
|
||||
it('should handle malformed JSON gracefully', async () => {
|
||||
const sessionId = 'test-session';
|
||||
const events: any[] = [];
|
||||
let errorOccurred = false;
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionId, {
|
||||
onEvent: (event) => {
|
||||
events.push(event);
|
||||
},
|
||||
onError: (error) => {
|
||||
errorOccurred = true;
|
||||
},
|
||||
}) as MockEventSource;
|
||||
|
||||
// Simulate malformed JSON
|
||||
eventSource._simulateMessage('invalid json');
|
||||
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
|
||||
expect(events).toHaveLength(0);
|
||||
expect(errorOccurred).toBe(true);
|
||||
});
|
||||
|
||||
it('should call onDisconnect when EventSource is closed', () => {
|
||||
const sessionId = 'test-session';
|
||||
let disconnectCalled = false;
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionId, {
|
||||
onDisconnect: () => {
|
||||
disconnectCalled = true;
|
||||
},
|
||||
});
|
||||
|
||||
eventSource.close();
|
||||
|
||||
expect(disconnectCalled).toBe(true);
|
||||
});
|
||||
|
||||
it('should include correct parameters in URL', () => {
|
||||
const sessionId = 'test-session-123';
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionId, {
|
||||
includeHistory: true,
|
||||
lastEventId: '12345',
|
||||
});
|
||||
|
||||
expect(eventSource.url).toBe(
|
||||
'/api/agent/stream?includeHistory=true&lastEventId=12345&sessionId=test-session-123',
|
||||
);
|
||||
});
|
||||
|
||||
it('should use default parameters when not provided', () => {
|
||||
const sessionId = 'test-session';
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionId);
|
||||
|
||||
expect(eventSource.url).toBe(
|
||||
'/api/agent/stream?includeHistory=false&lastEventId=0&sessionId=test-session',
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,77 @@
|
||||
import { fetchEventSource } from '@lobechat/utils/client';
|
||||
import debug from 'debug';
|
||||
|
||||
import { StreamConnectionOptions, StreamEvent } from './type';
|
||||
|
||||
const log = debug('lobe-agent-runtime:client');
|
||||
|
||||
/**
|
||||
* Agent Client Service for communicating with durable agents
|
||||
*/
|
||||
class AgentRuntimeClient {
|
||||
private baseUrl = '/api/agent';
|
||||
|
||||
/**
|
||||
* Create a streaming connection to receive real-time agent events
|
||||
*/
|
||||
createStreamConnection(
|
||||
sessionId: string,
|
||||
options: StreamConnectionOptions = {},
|
||||
): AbortController {
|
||||
const {
|
||||
includeHistory = false,
|
||||
lastEventId = '0',
|
||||
onEvent,
|
||||
onError,
|
||||
onConnect,
|
||||
onDisconnect,
|
||||
} = options;
|
||||
|
||||
const params = new URLSearchParams({
|
||||
includeHistory: includeHistory.toString(),
|
||||
lastEventId,
|
||||
sessionId,
|
||||
});
|
||||
|
||||
const controller = new AbortController();
|
||||
|
||||
fetchEventSource(`${this.baseUrl}/stream?${params}`, {
|
||||
headers: {
|
||||
'Cache-Control': 'no-cache',
|
||||
'Last-Event-ID': lastEventId,
|
||||
},
|
||||
onclose: () => {
|
||||
log(`Stream connection closed for session ${sessionId}`);
|
||||
onDisconnect?.();
|
||||
},
|
||||
onerror: (error) => {
|
||||
console.error(`[AgentClientService] Stream error for session ${sessionId}:`, error);
|
||||
onError?.(error instanceof Error ? error : new Error('Stream connection error'));
|
||||
},
|
||||
onmessage: (event) => {
|
||||
try {
|
||||
const data = JSON.parse(event.data) as StreamEvent;
|
||||
log(`Received event: ${event.event || 'message'}`, event.data);
|
||||
|
||||
onEvent?.(data);
|
||||
} catch (error) {
|
||||
console.error('[AgentClientService] Failed to parse stream event:', error);
|
||||
onError?.(new Error('Failed to parse stream event'));
|
||||
}
|
||||
},
|
||||
onopen: async (response) => {
|
||||
if (response.ok) {
|
||||
log(`Stream connection opened for session ${sessionId}`);
|
||||
onConnect?.();
|
||||
} else {
|
||||
throw new Error(`Failed to open stream: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
},
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
return controller;
|
||||
}
|
||||
}
|
||||
|
||||
export const agentRuntimeClient = new AgentRuntimeClient();
|
||||
@@ -0,0 +1,104 @@
|
||||
import { ChatMessage } from '@lobechat/types';
|
||||
|
||||
import { lambdaClient } from '@/libs/trpc/client';
|
||||
import { createAgentToolsEngine } from '@/services/agentRuntime/toolEngine';
|
||||
import { HumanInterventionRequest } from '@/services/agentRuntime/type';
|
||||
import { contextEngineering } from '@/services/chat/contextEngineering';
|
||||
import { getAgentStoreState } from '@/store/agent';
|
||||
import { agentChatConfigSelectors, agentSelectors } from '@/store/agent/selectors';
|
||||
|
||||
export { agentRuntimeClient } from './client';
|
||||
export * from './type';
|
||||
|
||||
interface AgentSessionRequest {
|
||||
agentSessionId?: string;
|
||||
autoStart?: boolean;
|
||||
messages: ChatMessage[];
|
||||
threadId?: string;
|
||||
topicId?: string;
|
||||
userMessageId: string;
|
||||
}
|
||||
|
||||
class AgentRuntimeService {
|
||||
createSession = async (data: AgentSessionRequest) => {
|
||||
const agentStoreState = getAgentStoreState();
|
||||
const agentConfig = agentSelectors.currentAgentConfig(agentStoreState);
|
||||
const chatConfig = agentChatConfigSelectors.currentChatConfig(agentStoreState);
|
||||
|
||||
const modelRuntimeConfig = {
|
||||
model: agentConfig.model,
|
||||
provider: agentConfig.provider!,
|
||||
};
|
||||
|
||||
const toolsEngine = createAgentToolsEngine({
|
||||
model: agentConfig.model,
|
||||
provider: agentConfig.provider!,
|
||||
});
|
||||
|
||||
const { tools, enabledToolIds } = toolsEngine.generateToolsDetailed({
|
||||
model: agentConfig.model,
|
||||
provider: agentConfig.provider!,
|
||||
toolIds: agentConfig.plugins,
|
||||
});
|
||||
|
||||
// Apply context engineering with preprocessing configuration
|
||||
const llmMessages = await contextEngineering({
|
||||
enableHistoryCount: agentChatConfigSelectors.enableHistoryCount(agentStoreState),
|
||||
// include user messages
|
||||
historyCount: agentChatConfigSelectors.historyCount(agentStoreState) + 2,
|
||||
inputTemplate: chatConfig.inputTemplate,
|
||||
messages: data.messages as any,
|
||||
...modelRuntimeConfig,
|
||||
systemRole: agentConfig.systemRole,
|
||||
tools: enabledToolIds,
|
||||
});
|
||||
|
||||
const toolManifestMap = Object.fromEntries(
|
||||
toolsEngine.getEnabledPluginManifests(enabledToolIds).entries(),
|
||||
);
|
||||
|
||||
return await lambdaClient.aiAgent.createSession.mutate({
|
||||
...data,
|
||||
agentConfig: {
|
||||
enableSearch: agentChatConfigSelectors.isAgentEnableSearch(agentStoreState),
|
||||
maxSteps: 50,
|
||||
// costLimit: agentChatConfig.costLimit,
|
||||
// enableRAG: false,
|
||||
// humanApprovalRequired: agentChatConfig.humanApprovalRequired || false,
|
||||
},
|
||||
messages: llmMessages,
|
||||
modelRuntimeConfig,
|
||||
toolManifestMap,
|
||||
tools,
|
||||
});
|
||||
};
|
||||
|
||||
/**
|
||||
* Delete a session
|
||||
*/
|
||||
async deleteSession(sessionId: string): Promise<void> {
|
||||
await lambdaClient.session.removeSession.mutate({ id: sessionId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get session status
|
||||
*/
|
||||
async getSessionStatus(sessionId: string, includeHistory = false): Promise<any> {
|
||||
return await lambdaClient.aiAgent.getSessionStatus.query({ includeHistory, sessionId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle human intervention
|
||||
*/
|
||||
async handleHumanIntervention(request: HumanInterventionRequest): Promise<any> {
|
||||
return await lambdaClient.aiAgent.processHumanIntervention.mutate({
|
||||
action: request.action,
|
||||
data: request.data,
|
||||
reason: request.reason,
|
||||
sessionId: request.sessionId,
|
||||
stepIndex: 0, // Default to 0 since it's not provided in the request type
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export const agentRuntimeService = new AgentRuntimeService();
|
||||
@@ -0,0 +1,22 @@
|
||||
import { WorkingModel } from '@lobechat/types';
|
||||
|
||||
import { getSearchConfig } from '@/helpers/getSearchConfig';
|
||||
import { createToolsEngine } from '@/helpers/toolEngineering';
|
||||
import { WebBrowsingManifest } from '@/tools/web-browsing';
|
||||
|
||||
export const createAgentToolsEngine = (workingModel: WorkingModel) =>
|
||||
createToolsEngine({
|
||||
// Add WebBrowsingManifest as default tool
|
||||
defaultToolIds: [WebBrowsingManifest.identifier],
|
||||
// Create search-aware enableChecker for this request
|
||||
enableChecker: ({ pluginId }) => {
|
||||
// For WebBrowsingManifest, apply search logic
|
||||
if (pluginId === WebBrowsingManifest.identifier) {
|
||||
const searchConfig = getSearchConfig(workingModel.model, workingModel.provider);
|
||||
return searchConfig.useApplicationBuiltinSearchTool;
|
||||
}
|
||||
|
||||
// For all other plugins, enable by default
|
||||
return true;
|
||||
},
|
||||
});
|
||||
@@ -0,0 +1,70 @@
|
||||
import { OpenAIChatMessage } from '@/types/openai/chat';
|
||||
|
||||
export interface StreamEvent {
|
||||
data?: any;
|
||||
sessionId?: string;
|
||||
stepIndex?: number;
|
||||
timestamp: number;
|
||||
type:
|
||||
| 'connected'
|
||||
| 'stream_start'
|
||||
| 'stream_chunk'
|
||||
| 'stream_end'
|
||||
| 'step_start'
|
||||
| 'step_complete'
|
||||
| 'error'
|
||||
| 'heartbeat';
|
||||
}
|
||||
|
||||
export interface StreamConnectionOptions {
|
||||
includeHistory?: boolean;
|
||||
lastEventId?: string;
|
||||
onConnect?: () => void;
|
||||
onDisconnect?: () => void;
|
||||
onError?: (error: Error) => void;
|
||||
onEvent?: (event: StreamEvent) => void;
|
||||
}
|
||||
|
||||
export interface AgentSessionRequest {
|
||||
agentConfig?: {
|
||||
[key: string]: any;
|
||||
costLimit?: {
|
||||
currency: string;
|
||||
maxTotalCost: number;
|
||||
onExceeded: 'stop' | 'interrupt' | 'continue';
|
||||
};
|
||||
enableRAG?: boolean;
|
||||
enableSearch?: boolean;
|
||||
humanApprovalRequired?: boolean;
|
||||
maxSteps?: number;
|
||||
};
|
||||
autoStart?: boolean;
|
||||
messages: OpenAIChatMessage[];
|
||||
modelRuntimeConfig: {
|
||||
[key: string]: any;
|
||||
model: string;
|
||||
provider: string;
|
||||
};
|
||||
sessionId?: string;
|
||||
userMessageId: string;
|
||||
}
|
||||
|
||||
export interface AgentSessionResponse {
|
||||
autoStart: boolean;
|
||||
createdAt: string;
|
||||
firstStep?: {
|
||||
context?: any;
|
||||
messageId?: string;
|
||||
scheduled: boolean;
|
||||
};
|
||||
sessionId: string;
|
||||
status: string;
|
||||
success: boolean;
|
||||
}
|
||||
|
||||
export interface HumanInterventionRequest {
|
||||
action: 'approve' | 'reject' | 'input' | 'select';
|
||||
data?: any;
|
||||
reason?: string;
|
||||
sessionId: string;
|
||||
}
|
||||
@@ -169,6 +169,8 @@ const openingQuestions = (s: AgentStoreState) =>
|
||||
currentAgentConfig(s).openingQuestions || DEFAULT_OPENING_QUESTIONS;
|
||||
const openingMessage = (s: AgentStoreState) => currentAgentConfig(s).openingMessage || '';
|
||||
|
||||
const enableAgentMode = (s: AgentStoreState) => currentAgentConfig(s).enableAgentMode || true;
|
||||
|
||||
export const agentSelectors = {
|
||||
currentAgentConfig,
|
||||
currentAgentFiles,
|
||||
@@ -182,6 +184,7 @@ export const agentSelectors = {
|
||||
currentEnabledKnowledge,
|
||||
currentKnowledgeIds,
|
||||
displayableAgentPlugins,
|
||||
enableAgentMode,
|
||||
getAgentConfigByAgentId,
|
||||
getAgentConfigById,
|
||||
hasEnabledKnowledge,
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
import { StateCreator } from 'zustand/vanilla';
|
||||
|
||||
import { ChatStore } from '@/store/chat/store';
|
||||
|
||||
import { AgentAction, agentSlice } from './runAgent';
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-empty-interface
|
||||
export interface ChatAIAgentAction extends AgentAction {
|
||||
/**/
|
||||
}
|
||||
|
||||
export const chatAiAgent: StateCreator<
|
||||
ChatStore,
|
||||
[['zustand/devtools', never]],
|
||||
[],
|
||||
ChatAIAgentAction
|
||||
> = (...params) => ({
|
||||
...agentSlice(...params),
|
||||
});
|
||||
@@ -0,0 +1,515 @@
|
||||
import { LOADING_FLAT, isDesktop } from '@lobechat/const';
|
||||
import { ChatToolPayload, CreateMessageParams, SendMessageParams } from '@lobechat/types';
|
||||
import debug from 'debug';
|
||||
import { produce } from 'immer';
|
||||
import { StateCreator } from 'zustand/vanilla';
|
||||
|
||||
import { StreamEvent, agentRuntimeClient, agentRuntimeService } from '@/services/agentRuntime';
|
||||
import { messageService } from '@/services/message';
|
||||
import { chatSelectors } from '@/store/chat/selectors';
|
||||
import { ChatStore } from '@/store/chat/store';
|
||||
import { setNamespace } from '@/utils/storeDebug';
|
||||
|
||||
const log = debug('store:chat:ai-agent:runAgent');
|
||||
const n = setNamespace('agent');
|
||||
|
||||
interface StreamingContext {
|
||||
assistantId: string;
|
||||
content: string;
|
||||
reasoning: string;
|
||||
tmpAssistantId: string;
|
||||
toolsCalling?: ChatToolPayload[];
|
||||
}
|
||||
export interface AgentAction {
|
||||
internal_cleanupAgentSession: (assistantId: string) => void;
|
||||
internal_handleAgentError: (assistantId: string, error: string) => void;
|
||||
/**
|
||||
* Agent Runtime 相关方法
|
||||
*/
|
||||
internal_handleAgentStreamEvent: (
|
||||
sessionId: string,
|
||||
event: StreamEvent,
|
||||
context: StreamingContext,
|
||||
) => Promise<void>;
|
||||
internal_handleHumanIntervention: (
|
||||
assistantId: string,
|
||||
action: string,
|
||||
data?: any,
|
||||
) => Promise<void>;
|
||||
/**
|
||||
* Sends a message through the agent runtime workflow
|
||||
*/
|
||||
sendAgentMessage: (params: SendMessageParams) => Promise<void>;
|
||||
}
|
||||
|
||||
export const agentSlice: StateCreator<ChatStore, [['zustand/devtools', never]], [], AgentAction> = (
|
||||
set,
|
||||
get,
|
||||
) => ({
|
||||
internal_cleanupAgentSession: (assistantId: string) => {
|
||||
const session = get().agentSessions[assistantId];
|
||||
if (!session) return;
|
||||
|
||||
log(`Cleaning up agent session for ${assistantId}`);
|
||||
|
||||
// 关闭 EventSource 连接
|
||||
if (session.eventSource) {
|
||||
session.eventSource.close();
|
||||
}
|
||||
|
||||
// 删除会话信息
|
||||
set(
|
||||
produce((draft) => {
|
||||
delete draft.agentSessions[assistantId];
|
||||
}),
|
||||
false,
|
||||
n('cleanupAgentSession', { assistantId }),
|
||||
);
|
||||
|
||||
// 如果有错误,删除服务端会话
|
||||
if (session.sessionId && session.status === 'error') {
|
||||
agentRuntimeService.deleteSession(session.sessionId).catch((error: Error) => {
|
||||
console.warn(
|
||||
`[Agent Runtime] Failed to delete server session ${session.sessionId}:`,
|
||||
error,
|
||||
);
|
||||
});
|
||||
}
|
||||
},
|
||||
|
||||
internal_handleAgentError: (assistantId: string, errorMessage: string) => {
|
||||
log(`Agent error for ${assistantId}: ${errorMessage}`);
|
||||
|
||||
// 更新会话错误状态
|
||||
set(
|
||||
produce((draft) => {
|
||||
if (draft.agentSessions[assistantId]) {
|
||||
draft.agentSessions[assistantId].status = 'error';
|
||||
draft.agentSessions[assistantId].error = errorMessage;
|
||||
}
|
||||
}),
|
||||
false,
|
||||
n('setAgentError', { assistantId, errorMessage }),
|
||||
);
|
||||
|
||||
// 更新消息错误状态
|
||||
messageService.updateMessageError(assistantId, {
|
||||
message: errorMessage,
|
||||
type: 'UnknownError' as any,
|
||||
});
|
||||
get().refreshMessages();
|
||||
|
||||
// 停止 loading 状态
|
||||
get().internal_toggleChatLoading(false, assistantId);
|
||||
|
||||
// 清理会话
|
||||
get().internal_cleanupAgentSession(assistantId);
|
||||
},
|
||||
|
||||
// ======== Agent Runtime 相关方法 ========
|
||||
internal_handleAgentStreamEvent: async (sessionId, event, context) => {
|
||||
const { internal_dispatchMessage } = get();
|
||||
const session = get().agentSessions[sessionId];
|
||||
if (!session) {
|
||||
log(`No session found for ${sessionId}, ignoring event ${event.type}`);
|
||||
return;
|
||||
}
|
||||
|
||||
// 更新会话状态
|
||||
set(
|
||||
produce((draft) => {
|
||||
if (draft.agentSessions[sessionId]) {
|
||||
draft.agentSessions[sessionId].lastEventId = event.timestamp.toString();
|
||||
if (event.stepIndex !== undefined) {
|
||||
draft.agentSessions[sessionId].stepCount = event.stepIndex;
|
||||
}
|
||||
}
|
||||
}),
|
||||
false,
|
||||
n('updateAgentSessionFromEvent', { eventType: event.type }),
|
||||
);
|
||||
const assistantId = context.assistantId || context.tmpAssistantId;
|
||||
log(`assistantMessageId: ${assistantId}`);
|
||||
|
||||
switch (event.type) {
|
||||
case 'connected': {
|
||||
log(`Agent stream connected for ${assistantId}`);
|
||||
break;
|
||||
}
|
||||
|
||||
case 'heartbeat': {
|
||||
// 心跳事件,保持连接活跃
|
||||
break;
|
||||
}
|
||||
|
||||
case 'stream_start': {
|
||||
log(`Stream started for ${assistantId}:`, event.data);
|
||||
internal_dispatchMessage({
|
||||
id: context.tmpAssistantId,
|
||||
type: 'deleteMessage',
|
||||
});
|
||||
|
||||
context.assistantId = event.data.assistantMessage.id;
|
||||
|
||||
internal_dispatchMessage({
|
||||
id: context.assistantId,
|
||||
type: 'createMessage',
|
||||
value: event.data.assistantMessage,
|
||||
});
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case 'stream_chunk': {
|
||||
// 处理流式内容块
|
||||
const { chunkType } = event.data || {};
|
||||
|
||||
switch (chunkType) {
|
||||
case 'text': {
|
||||
// 更新文本内容
|
||||
context.content += event.data.content;
|
||||
log(`Stream(${event.sessionId}) chunk type=${chunkType}: `, event.data.content);
|
||||
|
||||
internal_dispatchMessage({
|
||||
id: assistantId,
|
||||
type: 'updateMessage',
|
||||
value: { content: context.content },
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case 'reasoning': {
|
||||
// 更新文本内容
|
||||
context.reasoning += event.data.reasoning;
|
||||
log(`Stream(${event.sessionId}) chunk type=${chunkType}: `, event.data.reasoning);
|
||||
|
||||
internal_dispatchMessage({
|
||||
id: assistantId,
|
||||
type: 'updateMessage',
|
||||
value: { reasoning: { content: context.reasoning } },
|
||||
});
|
||||
break;
|
||||
}
|
||||
|
||||
case 'tools_calling': {
|
||||
context.toolsCalling = event.data.toolsCalling;
|
||||
|
||||
internal_dispatchMessage({
|
||||
id: assistantId,
|
||||
type: 'updateMessage',
|
||||
value: { tools: context.toolsCalling },
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
case 'stream_end': {
|
||||
// 流式结束,更新最终内容
|
||||
const { finalContent, toolCalls, reasoning, imageList, grounding } = event.data || {};
|
||||
log(`Stream ended for ${assistantId}:`, {
|
||||
hasFinalContent: !!finalContent,
|
||||
hasGrounding: !!grounding,
|
||||
hasImageList: !!(imageList && imageList.length > 0),
|
||||
hasReasoning: !!reasoning,
|
||||
hasToolCalls: !!(toolCalls && toolCalls.length > 0),
|
||||
});
|
||||
|
||||
if (finalContent !== undefined) {
|
||||
await get().internal_updateMessageContent(assistantId, finalContent, {
|
||||
...(toolCalls && toolCalls.length > 0 ? { toolCalls } : {}),
|
||||
...(reasoning ? { reasoning } : {}),
|
||||
...(imageList && imageList.length > 0 ? { imageList } : {}),
|
||||
...(grounding ? { search: grounding } : {}),
|
||||
});
|
||||
}
|
||||
|
||||
// 停止 loading 状态
|
||||
log(`Stopping loading for ${assistantId}`);
|
||||
get().internal_toggleChatLoading(false, assistantId);
|
||||
|
||||
// 显示桌面通知
|
||||
if (isDesktop) {
|
||||
try {
|
||||
const { desktopNotificationService } = await import(
|
||||
'@/services/electron/desktopNotification'
|
||||
);
|
||||
await desktopNotificationService.showNotification({
|
||||
body: 'AI 回复生成完成',
|
||||
title: 'AI 回复完成', // TODO: 使用 i18n
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Desktop notification error:', error);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'step_start': {
|
||||
const { phase, toolCall, pendingToolsCalling, requiresApproval } = event.data || {};
|
||||
|
||||
if (phase === 'human_approval' && requiresApproval) {
|
||||
// 需要人工批准
|
||||
log(`Human approval required for ${assistantId}:`, pendingToolsCalling);
|
||||
set(
|
||||
produce((draft) => {
|
||||
if (draft.agentSessions[assistantId]) {
|
||||
draft.agentSessions[assistantId].needsHumanInput = true;
|
||||
draft.agentSessions[assistantId].pendingApproval = pendingToolsCalling;
|
||||
}
|
||||
}),
|
||||
false,
|
||||
n('setHumanApprovalNeeded', { assistantId }),
|
||||
);
|
||||
|
||||
// 停止 loading 状态,等待人工干预
|
||||
log(`Stopping loading for human approval: ${assistantId}`);
|
||||
get().internal_toggleChatLoading(false, assistantId);
|
||||
} else if (phase === 'tool_execution' && toolCall) {
|
||||
log(`Tool execution started for ${assistantId}: ${toolCall.function?.name}`);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'step_complete': {
|
||||
const { phase, result, executionTime, finalState } = event.data || {};
|
||||
|
||||
if (phase === 'tool_execution' && result) {
|
||||
log(`Tool execution completed for ${assistantId} in ${executionTime}ms:`, result);
|
||||
// 刷新消息以显示工具结果
|
||||
await get().refreshMessages();
|
||||
} else if (phase === 'execution_complete' && finalState) {
|
||||
// Agent 执行完成
|
||||
log(`Agent execution completed for ${assistantId}:`, finalState);
|
||||
set(
|
||||
produce((draft) => {
|
||||
if (draft.agentSessions[assistantId]) {
|
||||
draft.agentSessions[assistantId].status = finalState.status;
|
||||
}
|
||||
}),
|
||||
false,
|
||||
n('updateAgentFinalStatus', { assistantId, status: finalState.status }),
|
||||
);
|
||||
|
||||
log(`Stopping loading for completed agent: ${assistantId}`);
|
||||
get().internal_toggleChatLoading(false, assistantId);
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
case 'error': {
|
||||
const { error, message, phase } = event.data || {};
|
||||
log(`Error in ${phase} for ${assistantId}:`, error);
|
||||
get().internal_handleAgentError(assistantId, message || error || 'Unknown agent error');
|
||||
break;
|
||||
}
|
||||
|
||||
default: {
|
||||
log(`Handling event ${event.type} for ${assistantId}:`, event);
|
||||
break;
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
internal_handleHumanIntervention: async (assistantId: string, action: string, data?: any) => {
|
||||
const session = get().agentSessions[assistantId];
|
||||
if (!session || !session.needsHumanInput) {
|
||||
log(`No human intervention needed for ${assistantId}`);
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
log(`Handling human intervention ${action} for ${assistantId}:`, data);
|
||||
|
||||
// 发送人工干预请求
|
||||
await agentRuntimeService.handleHumanIntervention({
|
||||
action: action as any,
|
||||
data,
|
||||
sessionId: session.sessionId,
|
||||
});
|
||||
|
||||
// 重新开始 loading 状态
|
||||
get().internal_toggleChatLoading(true, assistantId);
|
||||
|
||||
// 清除人工干预状态
|
||||
set(
|
||||
produce((draft) => {
|
||||
if (draft.agentSessions[assistantId]) {
|
||||
draft.agentSessions[assistantId].needsHumanInput = false;
|
||||
draft.agentSessions[assistantId].pendingApproval = undefined;
|
||||
draft.agentSessions[assistantId].pendingPrompt = undefined;
|
||||
draft.agentSessions[assistantId].pendingSelect = undefined;
|
||||
}
|
||||
}),
|
||||
false,
|
||||
n('clearHumanIntervention', { action, assistantId }),
|
||||
);
|
||||
|
||||
log(`Human intervention ${action} processed for ${assistantId}`);
|
||||
} catch (error) {
|
||||
log(`Failed to handle human intervention for ${assistantId}:`, error);
|
||||
get().internal_handleAgentError(
|
||||
assistantId,
|
||||
`Human intervention failed: ${(error as Error).message}`,
|
||||
);
|
||||
}
|
||||
},
|
||||
|
||||
sendAgentMessage: async ({
|
||||
message,
|
||||
files,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
isWelcomeQuestion: _isWelcomeQuestion,
|
||||
}: SendMessageParams) => {
|
||||
const { activeTopicId, activeId, activeThreadId } = get();
|
||||
if (!activeId) {
|
||||
log('No active session ID, cannot send agent message');
|
||||
return;
|
||||
}
|
||||
|
||||
log(`Starting agent message for session ${activeId}:`, {
|
||||
fileCount: files?.length || 0,
|
||||
message,
|
||||
});
|
||||
|
||||
set({ isCreatingMessage: true }, false, n('creatingMessage/start(agent)'));
|
||||
|
||||
const fileIdList = files?.map((f: any) => f.id);
|
||||
|
||||
// First add the user message
|
||||
const newMessage: CreateMessageParams = {
|
||||
content: message,
|
||||
files: fileIdList,
|
||||
role: 'user',
|
||||
sessionId: activeId,
|
||||
threadId: activeThreadId,
|
||||
topicId: activeTopicId,
|
||||
};
|
||||
|
||||
const tmpUserMessageId = get().internal_createTmpMessage(newMessage);
|
||||
|
||||
// Create message in server (for persistence)
|
||||
let userMessageId: string | undefined;
|
||||
|
||||
try {
|
||||
userMessageId = await get().internal_createMessage(newMessage, {
|
||||
tempMessageId: tmpUserMessageId,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Failed to create user message:', error);
|
||||
get().internal_dispatchMessage({ id: tmpUserMessageId, type: 'deleteMessage' });
|
||||
set({ isCreatingMessage: false }, false, n('creatingMessage/error'));
|
||||
}
|
||||
|
||||
set({ isCreatingMessage: false }, false, n('creatingMessage/end'));
|
||||
|
||||
if (!userMessageId) return;
|
||||
const messages = chatSelectors.activeBaseChats(get());
|
||||
|
||||
// Create a placeholder AI message for the agent response
|
||||
const agentMessageId = get().internal_createTmpMessage({
|
||||
content: LOADING_FLAT,
|
||||
role: 'assistant',
|
||||
sessionId: activeId,
|
||||
threadId: activeThreadId,
|
||||
topicId: activeTopicId,
|
||||
});
|
||||
|
||||
// Start durable agent runtime processing
|
||||
try {
|
||||
set({ isCreatingMessage: true }, false, n('agentWorkflow/start'));
|
||||
get().internal_toggleChatLoading(true, agentMessageId, n('sendAgentMessage/start') as string);
|
||||
|
||||
// 创建 Agent 会话
|
||||
|
||||
const sessionResponse = await agentRuntimeService.createSession({
|
||||
agentSessionId: activeId,
|
||||
autoStart: true,
|
||||
messages,
|
||||
threadId: activeThreadId,
|
||||
topicId: activeTopicId,
|
||||
userMessageId,
|
||||
});
|
||||
|
||||
log(`Created session ${sessionResponse.sessionId}:`, sessionResponse);
|
||||
|
||||
// 存储 Agent 会话信息
|
||||
set(
|
||||
produce((draft) => {
|
||||
draft.agentSessions[sessionResponse.sessionId] = {
|
||||
lastEventId: '0',
|
||||
sessionId: sessionResponse.sessionId,
|
||||
status: 'created', // 使用后端返回的实际状态
|
||||
stepCount: 0,
|
||||
totalCost: 0,
|
||||
};
|
||||
}),
|
||||
false,
|
||||
n('createAgentSession', {
|
||||
assistantId: agentMessageId,
|
||||
sessionId: sessionResponse.sessionId,
|
||||
}),
|
||||
);
|
||||
|
||||
// 创建流式连接
|
||||
log(`[StreamConnection] Creating stream connection for session ${sessionResponse.sessionId}`);
|
||||
|
||||
const context: StreamingContext = {
|
||||
assistantId: '',
|
||||
content: '',
|
||||
reasoning: '',
|
||||
tmpAssistantId: agentMessageId,
|
||||
};
|
||||
|
||||
const eventSource = agentRuntimeClient.createStreamConnection(sessionResponse.sessionId, {
|
||||
includeHistory: false,
|
||||
onConnect: () => {
|
||||
log(`Stream connected to ${sessionResponse.sessionId}`);
|
||||
},
|
||||
onDisconnect: () => {
|
||||
log(`Stream disconnected to ${sessionResponse.sessionId}`);
|
||||
get().internal_cleanupAgentSession(agentMessageId);
|
||||
},
|
||||
onError: (error: Error) => {
|
||||
log(`Stream error for ${sessionResponse.sessionId}:`, error);
|
||||
get().internal_handleAgentError(agentMessageId, error.message);
|
||||
},
|
||||
onEvent: async (event: StreamEvent) => {
|
||||
await get().internal_handleAgentStreamEvent(sessionResponse.sessionId, event, context);
|
||||
},
|
||||
});
|
||||
|
||||
// 保存 EventSource 引用
|
||||
set(
|
||||
produce((draft) => {
|
||||
if (draft.agentSessions[agentMessageId]) {
|
||||
draft.agentSessions[agentMessageId].eventSource = eventSource;
|
||||
}
|
||||
}),
|
||||
false,
|
||||
n('saveAgentEventSource', { assistantId: agentMessageId }),
|
||||
);
|
||||
} catch (error) {
|
||||
log(`Failed to start agent session for ${agentMessageId}:`, error);
|
||||
|
||||
// 更新错误状态
|
||||
await messageService.updateMessageError(agentMessageId, {
|
||||
message: (error as Error).message,
|
||||
type: 'UnknownError' as any,
|
||||
});
|
||||
await get().refreshMessages();
|
||||
|
||||
get().internal_toggleChatLoading(
|
||||
false,
|
||||
agentMessageId,
|
||||
n('generateMessage(error)', { error, messageId: agentMessageId }),
|
||||
);
|
||||
|
||||
throw error;
|
||||
} finally {
|
||||
set({ isCreatingMessage: false }, false, n('agentWorkflow/end'));
|
||||
}
|
||||
},
|
||||
});
|
||||
@@ -1,6 +1,24 @@
|
||||
import type { ChatInputEditor } from '@/features/ChatInput';
|
||||
|
||||
export interface AgentSessionInfo {
|
||||
error?: string;
|
||||
eventSource?: EventSource;
|
||||
lastEventId?: string;
|
||||
needsHumanInput?: boolean;
|
||||
pendingApproval?: any[];
|
||||
pendingPrompt?: any;
|
||||
pendingSelect?: any;
|
||||
sessionId: string;
|
||||
status: string;
|
||||
stepCount: number;
|
||||
totalCost?: number;
|
||||
}
|
||||
|
||||
export interface ChatAIChatState {
|
||||
/**
|
||||
* Agent sessions map, keyed by messageId (assistantMessageId)
|
||||
*/
|
||||
agentSessions: Record<string, AgentSessionInfo>;
|
||||
inputFiles: File[];
|
||||
inputMessage: string;
|
||||
mainInputEditor: ChatInputEditor | null;
|
||||
@@ -17,6 +35,7 @@ export interface ChatAIChatState {
|
||||
}
|
||||
|
||||
export const initialAiChatState: ChatAIChatState = {
|
||||
agentSessions: {},
|
||||
inputFiles: [],
|
||||
inputMessage: '',
|
||||
mainInputEditor: null,
|
||||
|
||||
@@ -17,6 +17,7 @@ import { ChatTTSAction, chatTTS } from './slices/tts/action';
|
||||
import { ChatThreadAction, chatThreadMessage } from './slices/thread/action';
|
||||
import { chatAiGroupChat, ChatGroupChatAction } from './slices/aiChat/actions/generateAIGroupChat';
|
||||
import { OperationActions, operationActions } from './slices/operation/actions';
|
||||
import { ChatAIAgentAction, chatAiAgent } from './slices/aiAgent/actions';
|
||||
|
||||
export interface ChatStoreAction
|
||||
extends ChatMessageAction,
|
||||
@@ -29,7 +30,8 @@ export interface ChatStoreAction
|
||||
ChatPluginAction,
|
||||
ChatBuiltinToolAction,
|
||||
ChatPortalAction,
|
||||
OperationActions {}
|
||||
OperationActions,
|
||||
ChatAIAgentAction {}
|
||||
|
||||
export type ChatStore = ChatStoreAction & ChatStoreState;
|
||||
|
||||
@@ -49,6 +51,7 @@ const createStore: StateCreator<ChatStore, [['zustand/devtools', never]]> = (...
|
||||
...chatPlugin(...params),
|
||||
...chatPortalSlice(...params),
|
||||
...operationActions(...params),
|
||||
...chatAiAgent(...params),
|
||||
|
||||
// cloud
|
||||
});
|
||||
|
||||
@@ -14,6 +14,7 @@ export const WebBrowsingManifest: BuiltinToolManifest = {
|
||||
{
|
||||
description:
|
||||
'a search service. Useful for when you need to answer questions about current events. Input should be a search query. Output is a JSON array of the query results',
|
||||
// humanIntervention: 'always',
|
||||
name: WebBrowsingApiName.search,
|
||||
parameters: {
|
||||
properties: {
|
||||
|
||||
Reference in New Issue
Block a user