mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-15 12:10:16 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c6bd5fb03b | |||
| dfc9ec2c05 | |||
| cd8ae7f976 | |||
| 91cac37bc3 | |||
| 7eecd7bff1 | |||
| cc2890766f |
@@ -83,7 +83,7 @@ export interface QueryMessagesOptions {
|
||||
/**
|
||||
* Post-process function for file URLs
|
||||
*/
|
||||
postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise<string>;
|
||||
postProcessUrl?: PostProcessUrl;
|
||||
/**
|
||||
* Topic ID for MessageGroup aggregation queries
|
||||
*/
|
||||
@@ -94,13 +94,17 @@ export interface QueryMessagesOptions {
|
||||
where?: SQL;
|
||||
}
|
||||
|
||||
export type PostProcessUrl = (path: string | null, file: { fileType: string }) => Promise<string>;
|
||||
|
||||
export class MessageModel {
|
||||
private userId: string;
|
||||
private db: LobeChatDatabase;
|
||||
private defaultPostProcessUrl?: PostProcessUrl;
|
||||
|
||||
constructor(db: LobeChatDatabase, userId: string) {
|
||||
constructor(db: LobeChatDatabase, userId: string, options?: { postProcessUrl?: PostProcessUrl }) {
|
||||
this.userId = userId;
|
||||
this.db = db;
|
||||
this.defaultPostProcessUrl = options?.postProcessUrl;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -133,9 +137,11 @@ export class MessageModel {
|
||||
threadId,
|
||||
}: QueryMessageParams = {},
|
||||
options: {
|
||||
postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise<string>;
|
||||
postProcessUrl?: PostProcessUrl;
|
||||
} = {},
|
||||
) => {
|
||||
const postProcessUrl = options.postProcessUrl ?? this.defaultPostProcessUrl;
|
||||
|
||||
// Build agent condition (handles legacy sessionId lookup)
|
||||
let agentCondition: SQL | undefined;
|
||||
if (agentId) {
|
||||
@@ -150,7 +156,7 @@ export class MessageModel {
|
||||
return this.queryWithWhere({
|
||||
current,
|
||||
pageSize,
|
||||
postProcessUrl: options.postProcessUrl,
|
||||
postProcessUrl,
|
||||
// Thread queries optionally add agent/session scope if provided
|
||||
where: agentCondition ? and(agentCondition, threadCondition) : threadCondition,
|
||||
});
|
||||
@@ -169,7 +175,7 @@ export class MessageModel {
|
||||
return this.queryWithWhere({
|
||||
current,
|
||||
pageSize,
|
||||
postProcessUrl: options.postProcessUrl,
|
||||
postProcessUrl,
|
||||
topicId: topicId ?? undefined,
|
||||
where: whereCondition,
|
||||
});
|
||||
@@ -186,7 +192,7 @@ export class MessageModel {
|
||||
return this.queryWithWhere({
|
||||
current,
|
||||
pageSize,
|
||||
postProcessUrl: options.postProcessUrl,
|
||||
postProcessUrl,
|
||||
topicId: topicId ?? undefined,
|
||||
where: whereCondition,
|
||||
});
|
||||
@@ -208,7 +214,13 @@ export class MessageModel {
|
||||
* @returns Messages with all related data, including MessageGroup nodes
|
||||
*/
|
||||
queryWithWhere = async (options: QueryMessagesOptions = {}): Promise<UIChatMessage[]> => {
|
||||
const { where, current = 0, pageSize = 1000, postProcessUrl, topicId } = options;
|
||||
const {
|
||||
where,
|
||||
current = 0,
|
||||
pageSize = 1000,
|
||||
postProcessUrl = this.defaultPostProcessUrl,
|
||||
topicId,
|
||||
} = options;
|
||||
const offset = current * pageSize;
|
||||
|
||||
// 1. get basic messages with joins, excluding messages that belong to MessageGroups
|
||||
@@ -543,12 +555,12 @@ export class MessageModel {
|
||||
queryByIds = async (
|
||||
messageIds: string[],
|
||||
options: {
|
||||
postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise<string>;
|
||||
postProcessUrl?: PostProcessUrl;
|
||||
} = {},
|
||||
): Promise<UIChatMessage[]> => {
|
||||
if (messageIds.length === 0) return [];
|
||||
|
||||
const { postProcessUrl } = options;
|
||||
const postProcessUrl = options.postProcessUrl ?? this.defaultPostProcessUrl;
|
||||
|
||||
// 1. Query messages with joins
|
||||
const result = await this.db
|
||||
@@ -801,7 +813,7 @@ export class MessageModel {
|
||||
private queryMessageGroupNodes = async (
|
||||
topicId: string,
|
||||
timeRange?: { endTime: Date; startTime: Date },
|
||||
postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise<string>,
|
||||
postProcessUrl?: PostProcessUrl,
|
||||
): Promise<UIChatMessage[]> => {
|
||||
// 1. Query MessageGroups for this topic, optionally filtered by time range
|
||||
const whereConditions = [
|
||||
|
||||
@@ -32,6 +32,13 @@ vi.mock('@/server/services/aiChat', () => ({
|
||||
AiChatService: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('aiAgentRouter.execSubAgentTask', () => {
|
||||
let serverDB: LobeChatDatabase;
|
||||
let userId: string;
|
||||
|
||||
@@ -40,6 +40,13 @@ vi.mock('@/server/services/aiChat', () => ({
|
||||
AiChatService: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('aiAgentRouter.getSubAgentTaskStatus', () => {
|
||||
let serverDB: LobeChatDatabase;
|
||||
let userId: string;
|
||||
|
||||
@@ -29,6 +29,13 @@ vi.mock('@/server/services/aiChat', () => ({
|
||||
AiChatService: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('aiAgentRouter.interruptTask', () => {
|
||||
let serverDB: LobeChatDatabase;
|
||||
let userId: string;
|
||||
|
||||
+7
@@ -35,6 +35,13 @@ vi.mock('@/server/services/aiChat', () => ({
|
||||
AiChatService: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('createClientGroupAgentTaskThread Integration', () => {
|
||||
let serverDB: LobeChatDatabase;
|
||||
let userId: string;
|
||||
|
||||
+7
@@ -37,6 +37,13 @@ vi.mock('@/server/services/aiChat', () => ({
|
||||
AiChatService: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('createClientTaskThread Integration', () => {
|
||||
let serverDB: LobeChatDatabase;
|
||||
let userId: string;
|
||||
|
||||
@@ -38,6 +38,13 @@ vi.mock('@/server/services/aiChat', () => ({
|
||||
AiChatService: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('Agent Task Integration', () => {
|
||||
let serverDB: LobeChatDatabase;
|
||||
let userId: string;
|
||||
|
||||
@@ -15,6 +15,7 @@ import { serverDatabase } from '@/libs/trpc/lambda/middleware';
|
||||
import { AgentRuntimeService } from '@/server/services/agentRuntime';
|
||||
import { AiAgentService } from '@/server/services/aiAgent';
|
||||
import { AiChatService } from '@/server/services/aiChat';
|
||||
import { FileService } from '@/server/services/file';
|
||||
import { nanoid } from '@/utils/uuid';
|
||||
|
||||
const log = debug('lobe-server:ai-agent-router');
|
||||
@@ -235,12 +236,16 @@ const InterruptTaskSchema = z
|
||||
const aiAgentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
const fileService = new FileService(ctx.serverDB, ctx.userId);
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
agentRuntimeService: new AgentRuntimeService(ctx.serverDB, ctx.userId),
|
||||
aiAgentService: new AiAgentService(ctx.serverDB, ctx.userId),
|
||||
aiChatService: new AiChatService(ctx.serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(ctx.serverDB, ctx.userId, {
|
||||
postProcessUrl: (path) => fileService.getFullFileUrl(path),
|
||||
}),
|
||||
threadModel: new ThreadModel(ctx.serverDB, ctx.userId),
|
||||
topicModel: new TopicModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
|
||||
@@ -21,11 +21,14 @@ import { basicContextSchema } from './_schema/context';
|
||||
const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
const fileService = new FileService(ctx.serverDB, ctx.userId);
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
compressionRepo: new CompressionRepository(ctx.serverDB, ctx.userId),
|
||||
fileService: new FileService(ctx.serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(ctx.serverDB, ctx.userId, {
|
||||
postProcessUrl: (path) => fileService.getFullFileUrl(path),
|
||||
}),
|
||||
messageService: new MessageService(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
@@ -201,15 +204,12 @@ export const messageRouter = router({
|
||||
ctx.userId ?? undefined,
|
||||
);
|
||||
|
||||
const messageModel = new MessageModel(ctx.serverDB, share.ownerId);
|
||||
const fileService = new FileService(ctx.serverDB, share.ownerId);
|
||||
const messageModel = new MessageModel(ctx.serverDB, share.ownerId, {
|
||||
postProcessUrl: (path) => fileService.getFullFileUrl(path),
|
||||
});
|
||||
|
||||
return messageModel.query(
|
||||
{ ...queryParams, topicId: share.topicId },
|
||||
{
|
||||
postProcessUrl: (path) => fileService.getFullFileUrl(path),
|
||||
},
|
||||
);
|
||||
return messageModel.query({ ...queryParams, topicId: share.topicId });
|
||||
}
|
||||
|
||||
// Authenticated access - require userId
|
||||
@@ -217,12 +217,12 @@ export const messageRouter = router({
|
||||
throw new TRPCError({ code: 'UNAUTHORIZED', message: 'Authentication required' });
|
||||
}
|
||||
|
||||
const messageModel = new MessageModel(ctx.serverDB, ctx.userId);
|
||||
const fileService = new FileService(ctx.serverDB, ctx.userId);
|
||||
|
||||
return messageModel.query(queryParams, {
|
||||
const messageModel = new MessageModel(ctx.serverDB, ctx.userId, {
|
||||
postProcessUrl: (path) => fileService.getFullFileUrl(path),
|
||||
});
|
||||
|
||||
return messageModel.query(queryParams);
|
||||
}),
|
||||
|
||||
rankModels: messageProcedure.query(async ({ ctx }) => {
|
||||
|
||||
@@ -26,6 +26,13 @@ vi.mock('@/database/models/message', () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('@/database/models/agent', () => ({
|
||||
AgentModel: vi.fn().mockImplementation(() => ({
|
||||
getAgentConfigById: vi.fn(),
|
||||
|
||||
@@ -14,6 +14,7 @@ import { AgentRuntimeCoordinator, createStreamEventManager } from '@/server/modu
|
||||
import { type RuntimeExecutorContext } from '@/server/modules/AgentRuntime/RuntimeExecutors';
|
||||
import { createRuntimeExecutors } from '@/server/modules/AgentRuntime/RuntimeExecutors';
|
||||
import { type IStreamEventManager } from '@/server/modules/AgentRuntime/types';
|
||||
import { FileService } from '@/server/services/file';
|
||||
import { mcpService } from '@/server/services/mcp';
|
||||
import { PluginGatewayService } from '@/server/services/pluginGateway';
|
||||
import { QueueService } from '@/server/services/queue';
|
||||
@@ -157,7 +158,10 @@ export class AgentRuntimeService {
|
||||
this.snapshotStore = options?.snapshotStore ?? this.createDefaultSnapshotStore();
|
||||
this.serverDB = db;
|
||||
this.userId = userId;
|
||||
this.messageModel = new MessageModel(db, this.userId);
|
||||
const fileService = new FileService(db, this.userId);
|
||||
this.messageModel = new MessageModel(db, this.userId, {
|
||||
postProcessUrl: (path) => fileService.getFullFileUrl(path),
|
||||
});
|
||||
|
||||
// Initialize ToolExecutionService with dependencies
|
||||
const pluginGatewayService = new PluginGatewayService();
|
||||
|
||||
@@ -17,6 +17,13 @@ vi.mock('@/database/models/message', () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock ModelRuntime
|
||||
vi.mock('@/server/modules/ModelRuntime', () => ({
|
||||
ApiKeyManager: vi.fn().mockImplementation(() => ({
|
||||
|
||||
@@ -8,6 +8,12 @@ vi.mock('@/envs/app', () => ({ appEnv: { APP_URL: 'http://localhost:3010' } }));
|
||||
vi.mock('@/database/models/message', () => ({
|
||||
MessageModel: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
vi.mock('@/server/modules/AgentRuntime', () => ({
|
||||
AgentRuntimeCoordinator: vi.fn().mockImplementation(() => ({
|
||||
loadAgentState: vi.fn(),
|
||||
|
||||
@@ -17,6 +17,13 @@ vi.mock('@/database/models/message', () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock ModelRuntime
|
||||
vi.mock('@/server/modules/ModelRuntime', () => ({
|
||||
initializeRuntimeOptions: vi.fn(),
|
||||
|
||||
@@ -18,6 +18,13 @@ vi.mock('@/database/models/message', () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock ModelRuntime
|
||||
vi.mock('@/server/modules/ModelRuntime', () => ({
|
||||
ApiKeyManager: vi.fn().mockImplementation(() => ({
|
||||
|
||||
@@ -174,7 +174,6 @@ describe('AiAgentService.execAgent - topic history loading', () => {
|
||||
// Verify messageModel.query was called to load history for the topic
|
||||
expect(mockMessageQuery).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ topicId: 'topic-existing' }),
|
||||
expect.objectContaining({ postProcessUrl: expect.any(Function) }),
|
||||
);
|
||||
|
||||
// Verify createOperation received all history messages + the new user message
|
||||
|
||||
@@ -48,6 +48,13 @@ vi.mock('@/database/models/topic', () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock FileService to avoid S3 dependency
|
||||
vi.mock('@/server/services/file', () => ({
|
||||
FileService: vi.fn().mockImplementation(() => ({
|
||||
getFullFileUrl: vi.fn((path: string | null) => path),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock AgentService
|
||||
vi.mock('@/server/services/agent', () => ({
|
||||
AgentService: vi.fn().mockImplementation(() => ({
|
||||
|
||||
@@ -187,6 +187,7 @@ export class AiAgentService {
|
||||
private readonly threadModel: ThreadModel;
|
||||
private readonly topicModel: TopicModel;
|
||||
private readonly agentRuntimeService: AgentRuntimeService;
|
||||
private readonly fileService: FileService;
|
||||
private readonly marketService: MarketService;
|
||||
private readonly klavisService: KlavisService;
|
||||
|
||||
@@ -200,7 +201,10 @@ export class AiAgentService {
|
||||
this.agentDocumentsService = new AgentDocumentsService(db, userId);
|
||||
this.agentModel = new AgentModel(db, userId);
|
||||
this.agentService = new AgentService(db, userId);
|
||||
this.messageModel = new MessageModel(db, userId);
|
||||
this.fileService = new FileService(db, userId);
|
||||
this.messageModel = new MessageModel(db, userId, {
|
||||
postProcessUrl: (path) => this.fileService.getFullFileUrl(path),
|
||||
});
|
||||
this.pluginModel = new PluginModel(db, userId);
|
||||
this.threadModel = new ThreadModel(db, userId);
|
||||
this.topicModel = new TopicModel(db, userId);
|
||||
@@ -765,31 +769,20 @@ export class AiAgentService {
|
||||
}
|
||||
|
||||
// 11. Get existing messages if provided
|
||||
// Use postProcessUrl to resolve S3 keys in imageList to publicly accessible URLs,
|
||||
// matching the frontend flow in aiChatService.getMessagesAndTopics.
|
||||
const fileService = new FileService(this.db, this.userId);
|
||||
const postProcessUrl = (path: string | null) => fileService.getFullFileUrl(path);
|
||||
|
||||
let historyMessages: any[] = [];
|
||||
if (existingMessageIds.length > 0) {
|
||||
historyMessages = await this.messageModel.query(
|
||||
{
|
||||
sessionId: appContext?.sessionId,
|
||||
topicId: appContext?.topicId ?? undefined,
|
||||
},
|
||||
{ postProcessUrl },
|
||||
);
|
||||
historyMessages = await this.messageModel.query({
|
||||
sessionId: appContext?.sessionId,
|
||||
topicId: appContext?.topicId ?? undefined,
|
||||
});
|
||||
const idSet = new Set(existingMessageIds);
|
||||
historyMessages = historyMessages.filter((msg) => idSet.has(msg.id));
|
||||
} else if (appContext?.topicId) {
|
||||
// Follow-up message in existing topic: load all history for context
|
||||
historyMessages = await this.messageModel.query(
|
||||
{
|
||||
sessionId: appContext?.sessionId,
|
||||
topicId: appContext.topicId,
|
||||
},
|
||||
{ postProcessUrl },
|
||||
);
|
||||
historyMessages = await this.messageModel.query({
|
||||
sessionId: appContext?.sessionId,
|
||||
topicId: appContext.topicId,
|
||||
});
|
||||
}
|
||||
|
||||
await throwIfExecutionAborted('message history loading');
|
||||
@@ -806,7 +799,7 @@ export class AiAgentService {
|
||||
await throwIfExecutionAborted('file upload');
|
||||
|
||||
try {
|
||||
const result = await ingestAttachment(file, fileService, this.userId);
|
||||
const result = await ingestAttachment(file, this.fileService, this.userId);
|
||||
fileIds.push(result.fileId);
|
||||
|
||||
if (result.isImage) {
|
||||
|
||||
@@ -33,10 +33,12 @@ describe('AiChatService', () => {
|
||||
sessionId: 's1',
|
||||
});
|
||||
|
||||
expect(mockQueryMessages).toHaveBeenCalledWith(
|
||||
{ agentId: 'agent-1', groupId: 'group-1', includeTopic: true, sessionId: 's1' },
|
||||
expect.objectContaining({ postProcessUrl: expect.any(Function) }),
|
||||
);
|
||||
expect(mockQueryMessages).toHaveBeenCalledWith({
|
||||
agentId: 'agent-1',
|
||||
groupId: 'group-1',
|
||||
includeTopic: true,
|
||||
sessionId: 's1',
|
||||
});
|
||||
expect(mockQueryTopics).toHaveBeenCalledWith({ agentId: 'agent-1', groupId: 'group-1' });
|
||||
expect(res.messages).toEqual([{ id: 'm1' }]);
|
||||
expect(res.topics).toEqual([{ id: 't1' }]);
|
||||
|
||||
@@ -7,15 +7,16 @@ import { FileService } from '@/server/services/file';
|
||||
export class AiChatService {
|
||||
private userId: string;
|
||||
private messageModel: MessageModel;
|
||||
private fileService: FileService;
|
||||
private topicModel: TopicModel;
|
||||
|
||||
constructor(serverDB: LobeChatDatabase, userId: string) {
|
||||
this.userId = userId;
|
||||
|
||||
this.messageModel = new MessageModel(serverDB, userId);
|
||||
const fileService = new FileService(serverDB, userId);
|
||||
this.messageModel = new MessageModel(serverDB, userId, {
|
||||
postProcessUrl: (path) => fileService.getFullFileUrl(path),
|
||||
});
|
||||
this.topicModel = new TopicModel(serverDB, userId);
|
||||
this.fileService = new FileService(serverDB, userId);
|
||||
}
|
||||
|
||||
async getMessagesAndTopics(params: {
|
||||
@@ -29,9 +30,7 @@ export class AiChatService {
|
||||
topicId?: string;
|
||||
}) {
|
||||
const [messages, topics] = await Promise.all([
|
||||
this.messageModel.query(params, {
|
||||
postProcessUrl: (path) => this.fileService.getFullFileUrl(path),
|
||||
}),
|
||||
this.messageModel.query(params),
|
||||
params.includeTopic
|
||||
? this.topicModel.query({ agentId: params.agentId, groupId: params.groupId })
|
||||
: undefined,
|
||||
|
||||
@@ -60,12 +60,11 @@ describe('MessageService', () => {
|
||||
const result = await messageService.removeMessage(messageId, { sessionId: 'session-1' });
|
||||
|
||||
expect(mockMessageModel.deleteMessage).toHaveBeenCalledWith(messageId);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId: undefined, sessionId: 'session-1', topicId: undefined },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId: undefined,
|
||||
sessionId: 'session-1',
|
||||
topicId: undefined,
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
|
||||
@@ -77,12 +76,11 @@ describe('MessageService', () => {
|
||||
const result = await messageService.removeMessage(messageId, { topicId: 'topic-1' });
|
||||
|
||||
expect(mockMessageModel.deleteMessage).toHaveBeenCalledWith(messageId);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId: undefined, sessionId: undefined, topicId: 'topic-1' },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId: undefined,
|
||||
sessionId: undefined,
|
||||
topicId: 'topic-1',
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
});
|
||||
@@ -261,12 +259,11 @@ describe('MessageService', () => {
|
||||
});
|
||||
|
||||
expect(mockMessageModel.updateMetadata).toHaveBeenCalledWith(messageId, metadata);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId: undefined, sessionId: undefined, topicId: 'topic-1' },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId: undefined,
|
||||
sessionId: undefined,
|
||||
topicId: 'topic-1',
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
});
|
||||
@@ -287,19 +284,14 @@ describe('MessageService', () => {
|
||||
const result = await messageService.createMessage(params as any);
|
||||
|
||||
expect(mockMessageModel.create).toHaveBeenCalledWith(params);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{
|
||||
agentId: 'agent-1',
|
||||
current: 0,
|
||||
groupId: undefined,
|
||||
pageSize: 9999,
|
||||
threadId: undefined,
|
||||
topicId: undefined,
|
||||
},
|
||||
expect.objectContaining({
|
||||
postProcessUrl: expect.any(Function),
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
agentId: 'agent-1',
|
||||
current: 0,
|
||||
groupId: undefined,
|
||||
pageSize: 9999,
|
||||
threadId: undefined,
|
||||
topicId: undefined,
|
||||
});
|
||||
expect(result).toEqual({
|
||||
id: 'msg-1',
|
||||
messages: mockMessages,
|
||||
@@ -322,19 +314,14 @@ describe('MessageService', () => {
|
||||
|
||||
const result = await messageService.createMessage(params as any);
|
||||
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{
|
||||
agentId: 'agent-1',
|
||||
current: 0,
|
||||
groupId: 'group-1',
|
||||
pageSize: 9999,
|
||||
threadId: undefined,
|
||||
topicId: 'topic-1',
|
||||
},
|
||||
expect.objectContaining({
|
||||
postProcessUrl: expect.any(Function),
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
agentId: 'agent-1',
|
||||
current: 0,
|
||||
groupId: 'group-1',
|
||||
pageSize: 9999,
|
||||
threadId: undefined,
|
||||
topicId: 'topic-1',
|
||||
});
|
||||
expect(result.id).toBe('msg-1');
|
||||
expect(result.messages).toEqual(mockMessages);
|
||||
});
|
||||
@@ -357,19 +344,14 @@ describe('MessageService', () => {
|
||||
const result = await messageService.createMessage(params as any);
|
||||
|
||||
expect(mockMessageModel.create).toHaveBeenCalledWith(params);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{
|
||||
agentId: 'agent-1',
|
||||
current: 0,
|
||||
groupId: 'group-1',
|
||||
pageSize: 9999,
|
||||
threadId: 'thread-1',
|
||||
topicId: 'topic-1',
|
||||
},
|
||||
expect.objectContaining({
|
||||
postProcessUrl: expect.any(Function),
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
agentId: 'agent-1',
|
||||
current: 0,
|
||||
groupId: 'group-1',
|
||||
pageSize: 9999,
|
||||
threadId: 'thread-1',
|
||||
topicId: 'topic-1',
|
||||
});
|
||||
expect(result.id).toBe('msg-1');
|
||||
expect(result.messages).toEqual(mockMessages);
|
||||
});
|
||||
@@ -387,12 +369,11 @@ describe('MessageService', () => {
|
||||
const result = await messageService.removeMessage(messageId, { groupId, topicId });
|
||||
|
||||
expect(mockMessageModel.deleteMessage).toHaveBeenCalledWith(messageId);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId, sessionId: undefined, topicId },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId,
|
||||
sessionId: undefined,
|
||||
topicId,
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
|
||||
@@ -404,12 +385,11 @@ describe('MessageService', () => {
|
||||
const result = await messageService.removeMessages(messageIds, { groupId, topicId });
|
||||
|
||||
expect(mockMessageModel.deleteMessages).toHaveBeenCalledWith(messageIds);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId, sessionId: undefined, topicId },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId,
|
||||
sessionId: undefined,
|
||||
topicId,
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
|
||||
@@ -425,12 +405,11 @@ describe('MessageService', () => {
|
||||
});
|
||||
|
||||
expect(mockMessageModel.update).toHaveBeenCalledWith(messageId, value);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId, sessionId: undefined, topicId },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId,
|
||||
sessionId: undefined,
|
||||
topicId,
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
|
||||
@@ -446,12 +425,11 @@ describe('MessageService', () => {
|
||||
});
|
||||
|
||||
expect(mockMessageModel.updateMetadata).toHaveBeenCalledWith(messageId, metadata);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId, sessionId: undefined, topicId },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId,
|
||||
sessionId: undefined,
|
||||
topicId,
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
|
||||
@@ -467,12 +445,11 @@ describe('MessageService', () => {
|
||||
});
|
||||
|
||||
expect(mockMessageModel.updatePluginState).toHaveBeenCalledWith(messageId, state);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId, sessionId: undefined, topicId },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId,
|
||||
sessionId: undefined,
|
||||
topicId,
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
|
||||
@@ -488,12 +465,11 @@ describe('MessageService', () => {
|
||||
});
|
||||
|
||||
expect(mockMessageModel.updateMessagePlugin).toHaveBeenCalledWith(messageId, { error });
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId, sessionId: undefined, topicId },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId,
|
||||
sessionId: undefined,
|
||||
topicId,
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
|
||||
@@ -509,12 +485,11 @@ describe('MessageService', () => {
|
||||
});
|
||||
|
||||
expect(mockMessageModel.updateMessageRAG).toHaveBeenCalledWith(messageId, ragValue);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith(
|
||||
{ groupId, sessionId: undefined, topicId },
|
||||
expect.objectContaining({
|
||||
groupAssistantMessages: false,
|
||||
}),
|
||||
);
|
||||
expect(mockMessageModel.query).toHaveBeenCalledWith({
|
||||
groupId,
|
||||
sessionId: undefined,
|
||||
topicId,
|
||||
});
|
||||
expect(result).toEqual({ messages: mockMessages, success: true });
|
||||
});
|
||||
});
|
||||
|
||||
@@ -31,32 +31,16 @@ interface CreateMessageResult {
|
||||
*/
|
||||
export class MessageService {
|
||||
private messageModel: MessageModel;
|
||||
private fileService: FileService;
|
||||
private compressionRepository: CompressionRepository;
|
||||
|
||||
constructor(db: LobeChatDatabase, userId: string) {
|
||||
this.messageModel = new MessageModel(db, userId);
|
||||
this.fileService = new FileService(db, userId);
|
||||
const fileService = new FileService(db, userId);
|
||||
this.messageModel = new MessageModel(db, userId, {
|
||||
postProcessUrl: (path) => fileService.getFullFileUrl(path),
|
||||
});
|
||||
this.compressionRepository = new CompressionRepository(db, userId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unified URL processing function
|
||||
*/
|
||||
private get postProcessUrl() {
|
||||
return (path: string | null) => this.fileService.getFullFileUrl(path);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unified query options
|
||||
*/
|
||||
private getQueryOptions() {
|
||||
return {
|
||||
groupAssistantMessages: false,
|
||||
postProcessUrl: this.postProcessUrl,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Query messages and return response with success status (used after mutations)
|
||||
* Prioritize agentId, fallback to sessionId if not provided (for backwards compatibility)
|
||||
@@ -75,10 +59,13 @@ export class MessageService {
|
||||
|
||||
const { agentId, sessionId, topicId, groupId, threadId } = options;
|
||||
|
||||
const messages = await this.messageModel.query(
|
||||
{ agentId, groupId, sessionId, threadId, topicId },
|
||||
this.getQueryOptions(),
|
||||
);
|
||||
const messages = await this.messageModel.query({
|
||||
agentId,
|
||||
groupId,
|
||||
sessionId,
|
||||
threadId,
|
||||
topicId,
|
||||
});
|
||||
|
||||
return { messages, success: true };
|
||||
}
|
||||
@@ -96,19 +83,14 @@ export class MessageService {
|
||||
|
||||
// 2. Query all messages for this agent/topic
|
||||
// Use agentId field for query
|
||||
const messages = await this.messageModel.query(
|
||||
{
|
||||
agentId: params.agentId,
|
||||
current: 0,
|
||||
groupId: params.groupId,
|
||||
pageSize: 9999,
|
||||
threadId: params.threadId,
|
||||
topicId: params.topicId,
|
||||
},
|
||||
{
|
||||
postProcessUrl: this.postProcessUrl,
|
||||
},
|
||||
);
|
||||
const messages = await this.messageModel.query({
|
||||
agentId: params.agentId,
|
||||
current: 0,
|
||||
groupId: params.groupId,
|
||||
pageSize: 9999,
|
||||
threadId: params.threadId,
|
||||
topicId: params.topicId,
|
||||
});
|
||||
|
||||
// 3. Return the result
|
||||
return {
|
||||
@@ -286,10 +268,7 @@ export class MessageService {
|
||||
success: boolean;
|
||||
}> {
|
||||
// 1. Get messages that need to be summarized (before marking them as compressed)
|
||||
const allMessages = await this.messageModel.query(
|
||||
{ topicId, ...options },
|
||||
this.getQueryOptions(),
|
||||
);
|
||||
const allMessages = await this.messageModel.query({ topicId, ...options });
|
||||
|
||||
const messagesToSummarize = allMessages.filter((msg) => messageIds.includes(msg.id));
|
||||
|
||||
@@ -304,7 +283,7 @@ export class MessageService {
|
||||
});
|
||||
|
||||
// 3. Query updated messages (compressed messages will be grouped)
|
||||
const messages = await this.messageModel.query({ topicId, ...options }, this.getQueryOptions());
|
||||
const messages = await this.messageModel.query({ topicId, ...options });
|
||||
|
||||
return {
|
||||
messageGroupId,
|
||||
@@ -338,7 +317,7 @@ export class MessageService {
|
||||
|
||||
// 2. Query final messages
|
||||
const queryOptions = { agentId, groupId, threadId, topicId };
|
||||
const finalMessages = await this.messageModel.query(queryOptions, this.getQueryOptions());
|
||||
const finalMessages = await this.messageModel.query(queryOptions);
|
||||
|
||||
return {
|
||||
messages: finalMessages,
|
||||
@@ -356,7 +335,7 @@ export class MessageService {
|
||||
): Promise<{ messages: UIChatMessage[] }> {
|
||||
await this.compressionRepository.updateMetadata(messageGroupId, metadata);
|
||||
|
||||
const messages = await this.messageModel.query(context, this.getQueryOptions());
|
||||
const messages = await this.messageModel.query(context);
|
||||
|
||||
return { messages };
|
||||
}
|
||||
@@ -375,7 +354,7 @@ export class MessageService {
|
||||
await this.compressionRepository.deleteCompressionGroup(messageGroupId);
|
||||
|
||||
// Query updated messages
|
||||
const messages = await this.messageModel.query(context, this.getQueryOptions());
|
||||
const messages = await this.messageModel.query(context);
|
||||
|
||||
return { messages, success: true };
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user