From 391b16e08272d13d4c29e046467bedf09d56027e Mon Sep 17 00:00:00 2001 From: YuTengjing Date: Tue, 19 May 2026 12:53:32 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20perf:=20optimize=20chat=20?= =?UTF-8?q?bootstrap=20persistence=20(#14934)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../__tests__/messages/message.create.test.ts | 121 +- .../__tests__/topics/topic.create.test.ts | 7 +- packages/database/src/models/message.ts | 1283 ++++++++++++----- packages/database/src/models/topic.ts | 292 ++-- .../model-runtime/src/core/ModelRuntime.ts | 114 +- .../src/core/RouterRuntime/createRuntime.ts | 201 ++- packages/types/src/aiChat.test.ts | 21 + packages/types/src/aiChat.ts | 5 + packages/utils/src/index.ts | 1 + packages/utils/src/timing.ts | 173 +++ .../routers/lambda/__tests__/aiChat.test.ts | 174 ++- src/server/routers/lambda/aiChat.ts | 264 +++- src/server/routers/lambda/message.ts | 35 +- src/server/services/aiChat/index.test.ts | 15 +- src/server/services/aiChat/index.ts | 90 +- src/server/services/message/index.ts | 43 +- .../__tests__/conversationLifecycle.test.ts | 8 +- .../aiChat/actions/conversationLifecycle.ts | 19 +- 18 files changed, 2239 insertions(+), 627 deletions(-) create mode 100644 packages/types/src/aiChat.test.ts create mode 100644 packages/utils/src/timing.ts diff --git a/packages/database/src/models/__tests__/messages/message.create.test.ts b/packages/database/src/models/__tests__/messages/message.create.test.ts index 42ad9ed236..6176c17ac5 100644 --- a/packages/database/src/models/__tests__/messages/message.create.test.ts +++ b/packages/database/src/models/__tests__/messages/message.create.test.ts @@ -1,5 +1,5 @@ import type { DBMessageItem } from '@lobechat/types'; -import { eq } from 'drizzle-orm'; +import { asc, eq } from 'drizzle-orm'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { uuid } from '@/utils/uuid'; @@ -16,6 +16,7 @@ import { messages, 
messagesFiles, sessions, + topics, users, } from '../../../schemas'; import type { LobeChatDatabase } from '../../../type'; @@ -248,6 +249,124 @@ describe('MessageModel Create Tests', () => { expect(pluginResult[0].arguments).not.toContain('\u0000'); }); + it('should create user and assistant messages with one topic touch', async () => { + await serverDB.insert(topics).values({ + id: 'topic-pair', + sessionId: '1', + title: 'Topic pair', + userId, + }); + + const timingEvents: string[] = []; + const result = await messageModel.createUserAndAssistantMessages( + { + assistantMessage: { + content: '', + model: 'gpt-4o', + provider: 'openai', + role: 'assistant', + sessionId: '1', + topicId: 'topic-pair', + }, + userMessage: { + content: 'hello', + files: ['f1'], + role: 'user', + sessionId: '1', + topicId: 'topic-pair', + }, + }, + { + timing: { + log: (event) => timingEvents.push(event), + }, + }, + ); + + expect(result.userMessage.id).toBeDefined(); + expect(result.assistantMessage.id).toBeDefined(); + expect(result.assistantMessage.parentId).toBe(result.userMessage.id); + expect(result.userMessage.createdAt.getTime()).toBeLessThan( + result.assistantMessage.createdAt.getTime(), + ); + + const dbMessages = await serverDB + .select() + .from(messages) + .where(eq(messages.userId, userId)) + .orderBy(asc(messages.createdAt)); + + expect(dbMessages.map((message) => message.id)).toEqual([ + result.userMessage.id, + result.assistantMessage.id, + ]); + + const messageFiles = await serverDB + .select() + .from(messagesFiles) + .where(eq(messagesFiles.messageId, result.userMessage.id)); + + expect(messageFiles).toHaveLength(1); + expect( + timingEvents.filter( + (event) => event === 'db.message.createUserAndAssistant.messages.insert:start', + ), + ).toHaveLength(1); + expect( + timingEvents.filter( + (event) => event === 'db.message.createUserAndAssistant.topic.touchUpdatedAt:start', + ), + ).toHaveLength(1); + }); + + it('should skip topic touch when creating a pair for an 
already-created topic', async () => { + await serverDB.insert(topics).values({ + id: 'topic-pair-no-touch', + sessionId: '1', + title: 'Topic pair no touch', + userId, + }); + + const timingEvents: string[] = []; + const result = await messageModel.createUserAndAssistantMessages( + { + assistantMessage: { + content: '', + model: 'gpt-4o', + provider: 'openai', + role: 'assistant', + sessionId: '1', + topicId: 'topic-pair-no-touch', + }, + userMessage: { + content: 'hello', + role: 'user', + sessionId: '1', + topicId: 'topic-pair-no-touch', + }, + }, + { + timing: { + log: (event) => timingEvents.push(event), + }, + touchTopicUpdatedAt: false, + }, + ); + + expect(result.userMessage.id).toBeDefined(); + expect(result.assistantMessage.parentId).toBe(result.userMessage.id); + expect( + timingEvents.filter( + (event) => event === 'db.message.createUserAndAssistant.messages.insert:start', + ), + ).toHaveLength(1); + expect( + timingEvents.filter( + (event) => event === 'db.message.createUserAndAssistant.topic.touchUpdatedAt:start', + ), + ).toHaveLength(0); + }); + describe('create with advanced parameters', () => { it('should create a message with custom ID', async () => { const customId = 'custom-msg-id'; diff --git a/packages/database/src/models/__tests__/topics/topic.create.test.ts b/packages/database/src/models/__tests__/topics/topic.create.test.ts index 451d79d953..fdef1d23ad 100644 --- a/packages/database/src/models/__tests__/topics/topic.create.test.ts +++ b/packages/database/src/models/__tests__/topics/topic.create.test.ts @@ -95,7 +95,10 @@ describe('TopicModel - Create', () => { const topicId = 'new-topic'; - const createdTopic = await topicModel.create(topicData, topicId); + const timingEvents: string[] = []; + const createdTopic = await topicModel.create(topicData, topicId, { + log: (event) => timingEvents.push(event), + }); expect(createdTopic).toEqual({ id: topicId, @@ -123,6 +126,8 @@ describe('TopicModel - Create', () => { const dbTopic = await 
serverDB.select().from(topics).where(eq(topics.id, topicId)); expect(dbTopic).toHaveLength(1); expect(dbTopic[0]).toEqual(createdTopic); + expect(timingEvents).toContain('db.topic.create.topics.insert:start'); + expect(timingEvents).not.toContain('db.topic.create.transaction:start'); }); it('should create a new topic with agentId', async () => { diff --git a/packages/database/src/models/message.ts b/packages/database/src/models/message.ts index 9945566f25..2b5bd87182 100644 --- a/packages/database/src/models/message.ts +++ b/packages/database/src/models/message.ts @@ -21,6 +21,12 @@ import type { UpdateMessageRAGParams, } from '@lobechat/types'; import { MessageGroupType, ThreadType } from '@lobechat/types'; +import type { TimingSink } from '@lobechat/utils'; +import { + getDurationMs, + logTimingSink as logTiming, + runTimedSinkStage as runTimedStage, +} from '@lobechat/utils'; import type { HeatmapsProps } from '@lobehub/charts'; import dayjs from 'dayjs'; import type { SQL } from 'drizzle-orm'; @@ -84,6 +90,7 @@ export interface QueryMessagesOptions { * Post-process function for file URLs */ postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise; + timing?: ModelTimingContext; /** * Topic ID for MessageGroup aggregation queries */ @@ -94,6 +101,92 @@ export interface QueryMessagesOptions { where?: SQL; } +export interface ModelTimingContext extends TimingSink {} + +interface MessageRelatedFile { + fileType: string | null; + id: string; + messageId: string; + name: string | null; + size: number | null; + url: string; +} + +interface MessageChunkRelation { + fileId: string; + filename: string | null; + fileType: string | null; + fileUrl: string | null; + id: string | null; + messageId: string | null; + similarity: string | null; + text: string | null; +} + +interface MessageQueryRelation { + id: string; + messageId: string; + rewriteQuery: string | null; + userQuery: string | null; +} + +interface MessageThreadRelation { + metadata: unknown; 
+ sourceMessageId: string | null; + status: string | null; + threadId: string; + title: string | null; +} + +interface MessageFileRelations { + documentsMap: Record; + relatedFileList: MessageRelatedFile[]; +} + +interface CreateUserAndAssistantMessagesParams { + assistantMessage: CreateMessageParams; + userMessage: CreateMessageParams; +} + +interface CreateUserAndAssistantMessagesOptions { + timing?: ModelTimingContext; + touchTopicUpdatedAt?: boolean; +} + +interface CreateMessageInsertParams { + createdAt?: CreateMessageParams['createdAt']; + fromModel?: CreateMessageParams['model']; + fromProvider?: CreateMessageParams['provider']; + message: Omit< + CreateMessageParams, + | 'createdAt' + | 'fileChunks' + | 'files' + | 'model' + | 'plugin' + | 'pluginIntervention' + | 'pluginState' + | 'provider' + | 'ragQueryId' + | 'updatedAt' + >; + updatedAt?: CreateMessageParams['updatedAt']; +} + +interface CreateMessageRelationParams { + fileChunks?: CreateMessageParams['fileChunks']; + files?: CreateMessageParams['files']; + plugin?: CreateMessageParams['plugin']; + pluginIntervention?: CreateMessageParams['pluginIntervention']; + pluginState?: CreateMessageParams['pluginState']; + ragQueryId?: CreateMessageParams['ragQueryId']; +} + +interface SplitCreateMessageParams { + insert: CreateMessageInsertParams; + relations: CreateMessageRelationParams; +} + export class MessageModel { private userId: string; private db: LobeChatDatabase; @@ -134,26 +227,55 @@ export class MessageModel { }: QueryMessageParams = {}, options: { postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise; + timing?: ModelTimingContext; } = {}, ) => { + const queryStartedAt = Date.now(); + const timing = options.timing; + logTiming(timing, 'db.message.query:start', { + current, + hasAgentId: !!agentId, + hasGroupId: !!groupId, + hasSessionId: !!sessionId, + hasThreadId: !!threadId, + hasTopicId: !!topicId, + pageSize, + }); + // Build agent condition (handles legacy sessionId 
lookup) let agentCondition: SQL | undefined; if (agentId) { - agentCondition = await this.buildAgentCondition(agentId); + agentCondition = await runTimedStage( + timing, + 'db.message.query.buildAgentCondition', + () => this.buildAgentCondition(agentId), + { hasAgentId: true }, + ); } else if (sessionId) { agentCondition = this.matchSession(sessionId); } // For thread queries, we need to fetch complete thread data (parent + thread messages) if (threadId) { - const threadCondition = await this.buildThreadQueryCondition(threadId); - return this.queryWithWhere({ + const threadCondition = await runTimedStage( + timing, + 'db.message.query.buildThreadCondition', + () => this.buildThreadQueryCondition(threadId), + { hasThreadId: true }, + ); + const messageItems = await this.queryWithWhere({ current, pageSize, postProcessUrl: options.postProcessUrl, + timing, // Thread queries optionally add agent/session scope if provided where: agentCondition ? and(agentCondition, threadCondition) : threadCondition, }); + logTiming(timing, 'db.message.query:done', { + messageCount: messageItems.length, + stageMs: getDurationMs(queryStartedAt), + }); + return messageItems; } // For Group Chat queries: filter by groupId only (not agentId) @@ -166,13 +288,19 @@ export class MessageModel { this.matchThread(threadId), ); - return this.queryWithWhere({ + const messageItems = await this.queryWithWhere({ current, pageSize, postProcessUrl: options.postProcessUrl, + timing, topicId: topicId ?? undefined, where: whereCondition, }); + logTiming(timing, 'db.message.query:done', { + messageCount: messageItems.length, + stageMs: getDurationMs(queryStartedAt), + }); + return messageItems; } // Standard query with session/topic/group filters @@ -183,13 +311,19 @@ export class MessageModel { this.matchThread(threadId), ); - return this.queryWithWhere({ + const messageItems = await this.queryWithWhere({ current, pageSize, postProcessUrl: options.postProcessUrl, + timing, topicId: topicId ?? 
undefined, where: whereCondition, }); + logTiming(timing, 'db.message.query:done', { + messageCount: messageItems.length, + stageMs: getDurationMs(queryStartedAt), + }); + return messageItems; }; /** @@ -208,162 +342,122 @@ export class MessageModel { * @returns Messages with all related data, including MessageGroup nodes */ queryWithWhere = async (options: QueryMessagesOptions = {}): Promise => { - const { where, current = 0, pageSize = 1000, postProcessUrl, topicId } = options; + const { where, current = 0, pageSize = 1000, postProcessUrl, topicId, timing } = options; + const totalStartedAt = Date.now(); const offset = current * pageSize; // 1. get basic messages with joins, excluding messages that belong to MessageGroups - const result = await this.db - .select({ - id: messages.id, - role: messages.role, - content: messages.content, - editorData: messages.editorData, - reasoning: messages.reasoning, - search: messages.search, - metadata: messages.metadata, - error: messages.error, + const result = await runTimedStage( + timing, + 'db.message.queryWithWhere.baseSelect', + () => + this.db + .select({ + id: messages.id, + role: messages.role, + content: messages.content, + editorData: messages.editorData, + reasoning: messages.reasoning, + search: messages.search, + metadata: messages.metadata, + error: messages.error, - model: messages.model, - provider: messages.provider, + model: messages.model, + provider: messages.provider, - createdAt: messages.createdAt, - updatedAt: messages.updatedAt, + createdAt: messages.createdAt, + updatedAt: messages.updatedAt, - sessionId: messages.sessionId, - topicId: messages.topicId, - parentId: messages.parentId, - threadId: messages.threadId, + sessionId: messages.sessionId, + topicId: messages.topicId, + parentId: messages.parentId, + threadId: messages.threadId, - // Group chat fields - groupId: messages.groupId, - agentId: messages.agentId, - targetId: messages.targetId, + // Group chat fields + groupId: messages.groupId, + 
agentId: messages.agentId, + targetId: messages.targetId, - tools: messages.tools, - tool_call_id: messagePlugins.toolCallId, + tools: messages.tools, + tool_call_id: messagePlugins.toolCallId, - plugin: { - apiName: messagePlugins.apiName, - arguments: messagePlugins.arguments, - identifier: messagePlugins.identifier, - type: messagePlugins.type, - }, - pluginError: messagePlugins.error, - pluginIntervention: messagePlugins.intervention, - pluginState: messagePlugins.state, + plugin: { + apiName: messagePlugins.apiName, + arguments: messagePlugins.arguments, + identifier: messagePlugins.identifier, + type: messagePlugins.type, + }, + pluginError: messagePlugins.error, + pluginIntervention: messagePlugins.intervention, + pluginState: messagePlugins.state, - translate: { - content: messageTranslates.content, - from: messageTranslates.from, - to: messageTranslates.to, - }, + translate: { + content: messageTranslates.content, + from: messageTranslates.from, + to: messageTranslates.to, + }, - ttsId: messageTTS.id, - ttsContentMd5: messageTTS.contentMd5, - ttsFile: messageTTS.fileId, - ttsVoice: messageTTS.voice, - }) - .from(messages) - .where( - and( - eq(messages.userId, this.userId), - // Filter out messages that belong to MessageGroups - isNull(messages.messageGroupId), - where, - ), - ) - .leftJoin(messagePlugins, eq(messagePlugins.id, messages.id)) - .leftJoin(messageTranslates, eq(messageTranslates.id, messages.id)) - .leftJoin(messageTTS, eq(messageTTS.id, messages.id)) - .orderBy(asc(messages.createdAt)) - .limit(pageSize) - .offset(offset); + ttsId: messageTTS.id, + ttsContentMd5: messageTTS.contentMd5, + ttsFile: messageTTS.fileId, + ttsVoice: messageTTS.voice, + }) + .from(messages) + .where( + and( + eq(messages.userId, this.userId), + // Filter out messages that belong to MessageGroups + isNull(messages.messageGroupId), + where, + ), + ) + .leftJoin(messagePlugins, eq(messagePlugins.id, messages.id)) + .leftJoin(messageTranslates, eq(messageTranslates.id, 
messages.id)) + .leftJoin(messageTTS, eq(messageTTS.id, messages.id)) + .orderBy(asc(messages.createdAt)) + .limit(pageSize) + .offset(offset), + { current, pageSize }, + ); + logTiming(timing, 'db.message.queryWithWhere.baseSelect:rows', { rowCount: result.length }); const messageIds = result.map((message) => message.id as string); - // 2. Query MessageGroups for this topic (if topicId is available) - // For pagination support: - // - First page (current === 0): fetch all MessageGroup nodes (no time filter) - // - Subsequent pages: only fetch groups within the current page's time range - let messageGroupNodes: UIChatMessage[] = []; - if (topicId && result.length > 0) { - if (current === 0) { - // First page: fetch all groups to include compressed history - messageGroupNodes = await this.queryMessageGroupNodes(topicId, undefined, postProcessUrl); - } else { - // Subsequent pages: filter by time range to avoid duplicates - const firstMessageTime = result[0].createdAt; - const lastMessageTime = result.at(-1)!.createdAt; - messageGroupNodes = await this.queryMessageGroupNodes( - topicId, - { - endTime: lastMessageTime, - startTime: firstMessageTime, - }, - postProcessUrl, - ); - } - } else if (topicId && current === 0) { - // First page with no messages: still fetch all groups - messageGroupNodes = await this.queryMessageGroupNodes(topicId, undefined, postProcessUrl); - } + const messageGroupNodesPromise = this.queryMessageGroupNodesForPage({ + current, + postProcessUrl, + result, + timing, + topicId, + }); - // If no messages and no group nodes, return empty - if (messageIds.length === 0 && messageGroupNodes.length === 0) return []; + const taskMessageIds = result + .filter((message) => message.role === 'task') + .map((message) => { + return message.id as string; + }); - // 3. 
get relative files (only if we have messages) - let relatedFileList: { - fileType: string | null; - id: string; - messageId: string; - name: string | null; - size: number | null; - url: string; - }[] = []; + const [ + messageGroupNodes, + { documentsMap, relatedFileList }, + chunksList, + messageQueriesList, + threadData, + ] = await Promise.all([ + messageGroupNodesPromise, + this.queryMessageFileRelations(messageIds, postProcessUrl, timing), + this.queryMessageChunkRelations(messageIds, timing), + this.queryMessageQueryRelations(messageIds, timing), + this.queryMessageThreadRelations(taskMessageIds, timing), + ]); - if (messageIds.length > 0) { - const rawRelatedFileList = await this.db - .select({ - fileType: files.fileType, - id: messagesFiles.fileId, - messageId: messagesFiles.messageId, - name: files.name, - size: files.size, - url: files.url, - }) - .from(messagesFiles) - .leftJoin(files, eq(files.id, messagesFiles.fileId)) - .where(inArray(messagesFiles.messageId, messageIds)); - - relatedFileList = await Promise.all( - rawRelatedFileList.map(async (file) => ({ - ...file, - url: postProcessUrl ? 
await postProcessUrl(file.url, file as any) : (file.url as string), - })), - ); - } - - // Get associated document content - const fileIds = relatedFileList.map((file) => file.id).filter(Boolean); - - let documentsMap: Record = {}; - - if (fileIds.length > 0) { - const documentsList = await this.db - .select({ - content: documents.content, - fileId: documents.fileId, - }) - .from(documents) - .where(inArray(documents.fileId, fileIds)); - - documentsMap = documentsList.reduce( - (acc, doc) => { - if (doc.fileId) acc[doc.fileId] = doc.content as string; - return acc; - }, - {} as Record, - ); + if (messageIds.length === 0 && messageGroupNodes.length === 0) { + logTiming(timing, 'db.message.queryWithWhere:done', { + messageGroupCount: 0, + rowCount: 0, + stageMs: getDurationMs(totalStartedAt), + }); + return []; } const imageList = relatedFileList.filter((i) => (i.fileType || '').startsWith('image')); @@ -372,151 +466,75 @@ export class MessageModel { (i) => !(i.fileType || '').startsWith('image') && !(i.fileType || '').startsWith('video'), ); - // 4. get relative file chunks - let chunksList: { - fileId: string; - fileType: string | null; - fileUrl: string | null; - filename: string | null; - id: string | null; - messageId: string | null; - similarity: string | null; - text: string | null; - }[] = []; - - if (messageIds.length > 0) { - chunksList = await this.db - .select({ - fileId: files.id, - fileType: files.fileType, - fileUrl: files.url, - filename: files.name, - id: chunks.id, - messageId: messageQueryChunks.messageId, - similarity: messageQueryChunks.similarity, - text: chunks.text, - }) - .from(messageQueryChunks) - .leftJoin(chunks, eq(chunks.id, messageQueryChunks.chunkId)) - .leftJoin(fileChunks, eq(fileChunks.chunkId, chunks.id)) - .innerJoin(files, eq(fileChunks.fileId, files.id)) - .where(inArray(messageQueryChunks.messageId, messageIds)); - } - - // 5. 
get relative message query - let messageQueriesList: { - id: string; - messageId: string; - rewriteQuery: string | null; - userQuery: string | null; - }[] = []; - - if (messageIds.length > 0) { - messageQueriesList = await this.db - .select({ - id: messageQueries.id, - messageId: messageQueries.messageId, - rewriteQuery: messageQueries.rewriteQuery, - userQuery: messageQueries.userQuery, - }) - .from(messageQueries) - .where(inArray(messageQueries.messageId, messageIds)); - } - - // 5. get thread info for task messages - const taskMessageIds = result.filter((m) => m.role === 'task').map((m) => m.id as string); - - let threadMap = new Map(); - - if (taskMessageIds.length > 0) { - const threadData = await this.db - .select({ - metadata: threads.metadata, - sourceMessageId: threads.sourceMessageId, - status: threads.status, - threadId: threads.id, - title: threads.title, - }) - .from(threads) - .where( - and(eq(threads.userId, this.userId), inArray(threads.sourceMessageId, taskMessageIds)), - ); - - threadMap = new Map( - threadData.map((t) => { - const metadata = t.metadata as Record | null; - return [ - t.sourceMessageId!, - { - clientMode: metadata?.clientMode as boolean | undefined, - duration: metadata?.duration as number | undefined, - status: t.status as ThreadStatus, - threadId: t.threadId, - title: t.title ?? undefined, - totalCost: metadata?.totalCost as number | undefined, - totalMessages: metadata?.totalMessages as number | undefined, - totalTokens: metadata?.totalTokens as number | undefined, - totalToolCalls: metadata?.totalToolCalls as number | undefined, - }, - ]; - }), - ); - } + const threadMap = this.createThreadMap(threadData); // 6. 
Transform regular messages - const transformedMessages = result.map( - ({ model, provider, translate, ttsId, ttsFile, ttsContentMd5, ttsVoice, ...item }) => { - const messageQuery = messageQueriesList.find((relation) => relation.messageId === item.id); - return { - ...item, - chunksList: chunksList - .filter((relation) => relation.messageId === item.id) - .map((c) => ({ - ...c, - similarity: c.similarity === null ? undefined : Number(c.similarity), - })), + const transformedMessages = await runTimedStage( + timing, + 'db.message.queryWithWhere.transform', + () => + result.map( + ({ model, provider, translate, ttsId, ttsFile, ttsContentMd5, ttsVoice, ...item }) => { + const messageQuery = messageQueriesList.find( + (relation) => relation.messageId === item.id, + ); + return { + ...item, + chunksList: chunksList + .filter((relation) => relation.messageId === item.id) + .map((c) => ({ + ...c, + similarity: c.similarity === null ? undefined : Number(c.similarity), + })), - extra: { - model, - provider, - translate, - tts: ttsId - ? { - contentMd5: ttsContentMd5, - file: ttsFile, - voice: ttsVoice, - } - : undefined, + extra: { + model, + provider, + translate, + tts: ttsId + ? { + contentMd5: ttsContentMd5, + file: ttsFile, + voice: ttsVoice, + } + : undefined, + }, + fileList: fileList + .filter((relation) => relation.messageId === item.id) + + .map(({ id, url, size, fileType, name }) => ({ + content: documentsMap[id], + fileType: fileType!, + id, + name: name!, + size: size!, + url, + })), + imageList: imageList + .filter((relation) => relation.messageId === item.id) + + .map(({ id, url, name }) => ({ alt: name!, id, url })), + + model, + + provider, + ragQuery: messageQuery?.rewriteQuery, + ragQueryId: messageQuery?.id, + ragRawQuery: messageQuery?.userQuery, + // Add taskDetail for task messages + taskDetail: item.role === 'task' ? 
threadMap.get(item.id as string) : undefined, + videoList: videoList + .filter((relation) => relation.messageId === item.id) + + .map(({ id, url, name }) => ({ alt: name!, id, url })), + } as unknown as UIChatMessage; }, - fileList: fileList - .filter((relation) => relation.messageId === item.id) - - .map(({ id, url, size, fileType, name }) => ({ - content: documentsMap[id], - fileType: fileType!, - id, - name: name!, - size: size!, - url, - })), - imageList: imageList - .filter((relation) => relation.messageId === item.id) - - .map(({ id, url, name }) => ({ alt: name!, id, url })), - - model, - - provider, - ragQuery: messageQuery?.rewriteQuery, - ragQueryId: messageQuery?.id, - ragRawQuery: messageQuery?.userQuery, - // Add taskDetail for task messages - taskDetail: item.role === 'task' ? threadMap.get(item.id as string) : undefined, - videoList: videoList - .filter((relation) => relation.messageId === item.id) - - .map(({ id, url, name }) => ({ alt: name!, id, url })), - } as unknown as UIChatMessage; + ), + { + chunkCount: chunksList.length, + fileCount: relatedFileList.length, + messageQueryCount: messageQueriesList.length, + rowCount: result.length, }, ); @@ -528,9 +546,254 @@ export class MessageModel { return aTime - bTime; }); + logTiming(timing, 'db.message.queryWithWhere:done', { + messageGroupCount: messageGroupNodes.length, + resultCount: allItems.length, + rowCount: result.length, + stageMs: getDurationMs(totalStartedAt), + }); + return allItems; }; + private queryMessageGroupNodesForPage = async ({ + current, + postProcessUrl, + result, + timing, + topicId, + }: { + current: number; + postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise; + result: { createdAt: Date }[]; + timing?: ModelTimingContext; + topicId?: string; + }): Promise => { + if (!topicId) return []; + + if (result.length === 0) { + if (current !== 0) return []; + + return runTimedStage( + timing, + 'db.message.queryWithWhere.messageGroups', + () => 
this.queryMessageGroupNodes(topicId, undefined, postProcessUrl, timing), + { current, hasMessages: false, topicId }, + ); + } + + if (current === 0) { + return runTimedStage( + timing, + 'db.message.queryWithWhere.messageGroups', + () => this.queryMessageGroupNodes(topicId, undefined, postProcessUrl, timing), + { current, hasMessages: true, topicId }, + ); + } + + const firstMessageTime = result[0].createdAt; + const lastMessageTime = result.at(-1)!.createdAt; + + return runTimedStage( + timing, + 'db.message.queryWithWhere.messageGroups', + () => + this.queryMessageGroupNodes( + topicId, + { + endTime: lastMessageTime, + startTime: firstMessageTime, + }, + postProcessUrl, + timing, + ), + { current, hasMessages: true, topicId }, + ); + }; + + private queryMessageFileRelations = async ( + messageIds: string[], + postProcessUrl: QueryMessagesOptions['postProcessUrl'], + timing?: ModelTimingContext, + ): Promise => { + if (messageIds.length === 0) return { documentsMap: {}, relatedFileList: [] }; + + const rawRelatedFileList = await runTimedStage( + timing, + 'db.message.queryWithWhere.relatedFiles.select', + () => + this.db + .select({ + fileType: files.fileType, + id: messagesFiles.fileId, + messageId: messagesFiles.messageId, + name: files.name, + size: files.size, + url: files.url, + }) + .from(messagesFiles) + .leftJoin(files, eq(files.id, messagesFiles.fileId)) + .where(inArray(messagesFiles.messageId, messageIds)), + { messageCount: messageIds.length }, + ); + logTiming(timing, 'db.message.queryWithWhere.relatedFiles.select:rows', { + rowCount: rawRelatedFileList.length, + }); + + const relatedFileList = await runTimedStage( + timing, + 'db.message.queryWithWhere.relatedFiles.postProcess', + () => + Promise.all( + rawRelatedFileList.map(async (file) => ({ + ...file, + url: postProcessUrl + ? 
await postProcessUrl(file.url, file as unknown as { fileType: string }) + : (file.url as string), + })), + ), + { fileCount: rawRelatedFileList.length }, + ); + + const fileIds = relatedFileList.map((file) => file.id).filter(Boolean); + + if (fileIds.length === 0) return { documentsMap: {}, relatedFileList }; + + const documentsList = await runTimedStage( + timing, + 'db.message.queryWithWhere.documents.select', + () => + this.db + .select({ + content: documents.content, + fileId: documents.fileId, + }) + .from(documents) + .where(inArray(documents.fileId, fileIds)), + { fileCount: fileIds.length }, + ); + + const documentsMap = documentsList.reduce( + (acc, doc) => { + if (doc.fileId) acc[doc.fileId] = doc.content as string; + return acc; + }, + {} as Record, + ); + + return { documentsMap, relatedFileList }; + }; + + private queryMessageChunkRelations = async ( + messageIds: string[], + timing?: ModelTimingContext, + ): Promise => { + if (messageIds.length === 0) return []; + + const chunksList = await runTimedStage( + timing, + 'db.message.queryWithWhere.chunks.select', + () => + this.db + .select({ + fileId: files.id, + fileType: files.fileType, + fileUrl: files.url, + filename: files.name, + id: chunks.id, + messageId: messageQueryChunks.messageId, + similarity: messageQueryChunks.similarity, + text: chunks.text, + }) + .from(messageQueryChunks) + .leftJoin(chunks, eq(chunks.id, messageQueryChunks.chunkId)) + .leftJoin(fileChunks, eq(fileChunks.chunkId, chunks.id)) + .innerJoin(files, eq(fileChunks.fileId, files.id)) + .where(inArray(messageQueryChunks.messageId, messageIds)), + { messageCount: messageIds.length }, + ); + logTiming(timing, 'db.message.queryWithWhere.chunks.select:rows', { + rowCount: chunksList.length, + }); + + return chunksList; + }; + + private queryMessageQueryRelations = async ( + messageIds: string[], + timing?: ModelTimingContext, + ): Promise => { + if (messageIds.length === 0) return []; + + const messageQueriesList = await 
runTimedStage( + timing, + 'db.message.queryWithWhere.messageQueries.select', + () => + this.db + .select({ + id: messageQueries.id, + messageId: messageQueries.messageId, + rewriteQuery: messageQueries.rewriteQuery, + userQuery: messageQueries.userQuery, + }) + .from(messageQueries) + .where(inArray(messageQueries.messageId, messageIds)), + { messageCount: messageIds.length }, + ); + logTiming(timing, 'db.message.queryWithWhere.messageQueries.select:rows', { + rowCount: messageQueriesList.length, + }); + + return messageQueriesList; + }; + + private queryMessageThreadRelations = async ( + taskMessageIds: string[], + timing?: ModelTimingContext, + ): Promise => { + if (taskMessageIds.length === 0) return []; + + return runTimedStage( + timing, + 'db.message.queryWithWhere.taskThreads.select', + () => + this.db + .select({ + metadata: threads.metadata, + sourceMessageId: threads.sourceMessageId, + status: threads.status, + threadId: threads.id, + title: threads.title, + }) + .from(threads) + .where( + and(eq(threads.userId, this.userId), inArray(threads.sourceMessageId, taskMessageIds)), + ), + { taskMessageCount: taskMessageIds.length }, + ); + }; + + private createThreadMap = (threadData: MessageThreadRelation[]) => + new Map( + threadData.map((thread) => { + const metadata = thread.metadata as Record | null; + return [ + thread.sourceMessageId!, + { + clientMode: metadata?.clientMode as boolean | undefined, + duration: metadata?.duration as number | undefined, + status: thread.status as ThreadStatus, + threadId: thread.threadId, + title: thread.title ?? 
undefined, + totalCost: metadata?.totalCost as number | undefined, + totalMessages: metadata?.totalMessages as number | undefined, + totalTokens: metadata?.totalTokens as number | undefined, + totalToolCalls: metadata?.totalToolCalls as number | undefined, + }, + ]; + }), + ); + /** * Query messages by their IDs with full relations * @@ -804,6 +1067,7 @@ export class MessageModel { topicId: string, timeRange?: { endTime: Date; startTime: Date }, postProcessUrl?: (path: string | null, file: { fileType: string }) => Promise, + timing?: ModelTimingContext, ): Promise => { // 1. Query MessageGroups for this topic, optionally filtered by time range const whereConditions = [ @@ -819,30 +1083,51 @@ export class MessageModel { ); } - const groups = await this.db - .select() - .from(messageGroups) - .where(and(...whereConditions)) - .orderBy(asc(messageGroups.createdAt)); + const groups = await runTimedStage( + timing, + 'db.message.messageGroups.groups.select', + () => + this.db + .select() + .from(messageGroups) + .where(and(...whereConditions)) + .orderBy(asc(messageGroups.createdAt)), + { hasTimeRange: !!timeRange, topicId }, + ); + logTiming(timing, 'db.message.messageGroups.groups.select:rows', { rowCount: groups.length }); if (groups.length === 0) return []; const groupIds = groups.map((g) => g.id); // 2. 
Get all message IDs that belong to these groups (using messageGroupId relation) - const groupMessageRecords = await this.db - .select({ - favorite: messages.favorite, - id: messages.id, - messageGroupId: messages.messageGroupId, - }) - .from(messages) - .where(and(eq(messages.userId, this.userId), inArray(messages.messageGroupId, groupIds))) - .orderBy(asc(messages.createdAt)); + const groupMessageRecords = await runTimedStage( + timing, + 'db.message.messageGroups.messages.select', + () => + this.db + .select({ + favorite: messages.favorite, + id: messages.id, + messageGroupId: messages.messageGroupId, + }) + .from(messages) + .where(and(eq(messages.userId, this.userId), inArray(messages.messageGroupId, groupIds))) + .orderBy(asc(messages.createdAt)), + { groupCount: groupIds.length }, + ); + logTiming(timing, 'db.message.messageGroups.messages.select:rows', { + rowCount: groupMessageRecords.length, + }); // 3. Query full message data using queryByIds (reuses all transformation logic) const allMessageIds = groupMessageRecords.map((m) => m.id as string); - const fullMessages = await this.queryByIds(allMessageIds, { postProcessUrl }); + const fullMessages = await runTimedStage( + timing, + 'db.message.messageGroups.queryByIds', + () => this.queryByIds(allMessageIds, { postProcessUrl }), + { messageCount: allMessageIds.length }, + ); // Create a map for quick lookup const messageMap = new Map(fullMessages.map((m) => [m.id, m])); @@ -1244,45 +1529,76 @@ export class MessageModel { // **************** Create *************** // - create = async ( - { - model: fromModel, - provider: fromProvider, + private splitCreateMessageParams = ({ + fileChunks, + files, + model: fromModel, + plugin, + pluginIntervention, + pluginState, + provider: fromProvider, + ragQueryId, + updatedAt, + createdAt, + ...message + }: CreateMessageParams): SplitCreateMessageParams => ({ + insert: { + createdAt, + fromModel, + fromProvider, + message, + updatedAt, + }, + relations: { + fileChunks, 
files, plugin, pluginIntervention, pluginState, - fileChunks, ragQueryId, - updatedAt, - createdAt, - ...message - }: CreateMessageParams, - id: string = this.genId(), - ): Promise => { - return this.db.transaction(async (trx) => { - // Ensure group message does not populate sessionId - const normalizedMessage = message.groupId ? { ...message, sessionId: null } : message; + }, + }); - const [item] = (await trx - .insert(messages) - .values({ - ...normalizedMessage, - // Sanitize content to strip null bytes that PostgreSQL rejects - content: sanitizeNullBytes(normalizedMessage.content), - // TODO: remove this when the client is updated - createdAt: createdAt ? new Date(createdAt) : undefined, - id, - model: fromModel, - provider: fromProvider, - updatedAt: updatedAt ? new Date(updatedAt) : undefined, - userId: this.userId, - }) - .returning()) as DBMessageItem[]; + private buildMessageInsertValue = ( + { createdAt, fromModel, fromProvider, message, updatedAt }: CreateMessageInsertParams, + id: string, + ) => { + // Ensure group message does not populate sessionId + const normalizedMessage = message.groupId ? { ...message, sessionId: null } : message; - // Insert the plugin data if the message is a tool - if (message.role === 'tool') { - await trx.insert(messagePlugins).values({ + return { + ...normalizedMessage, + // Sanitize content to strip null bytes that PostgreSQL rejects + content: sanitizeNullBytes(normalizedMessage.content), + // TODO: remove this when the client is updated + createdAt: createdAt ? new Date(createdAt) : undefined, + id, + model: fromModel, + provider: fromProvider, + updatedAt: updatedAt ? 
new Date(updatedAt) : undefined, + userId: this.userId, + }; + }; + + private insertMessageRelationsInTransaction = async ( + trx: Transaction, + { + fileChunks, + files, + plugin, + pluginIntervention, + pluginState, + ragQueryId, + }: CreateMessageRelationParams, + message: CreateMessageInsertParams['message'], + id: string, + timing?: ModelTimingContext, + timingPrefix: string = 'db.message.create', + ): Promise => { + // Insert the plugin data if the message is a tool + if (message.role === 'tool') { + await runTimedStage(timing, `${timingPrefix}.plugin.insert`, () => + trx.insert(messagePlugins).values({ apiName: plugin?.apiName, arguments: sanitizeNullBytes(plugin?.arguments), id, @@ -1292,34 +1608,196 @@ export class MessageModel { toolCallId: message.tool_call_id, type: plugin?.type, userId: this.userId, - }); - } + }), + ); + } - if (files && files.length > 0) { - await trx - .insert(messagesFiles) - .values(files.map((file) => ({ fileId: file, messageId: id, userId: this.userId }))); - } + if (files && files.length > 0) { + await runTimedStage( + timing, + `${timingPrefix}.files.insert`, + () => + trx + .insert(messagesFiles) + .values(files.map((file) => ({ fileId: file, messageId: id, userId: this.userId }))), + { fileCount: files.length }, + ); + } - if (fileChunks && fileChunks.length > 0 && ragQueryId) { - await trx.insert(messageQueryChunks).values( - fileChunks.map((chunk) => ({ - chunkId: chunk.id, - messageId: id, - queryId: ragQueryId, - similarity: chunk.similarity?.toString(), - userId: this.userId, - })), - ); - } + if (fileChunks && fileChunks.length > 0 && ragQueryId) { + await runTimedStage( + timing, + `${timingPrefix}.fileChunks.insert`, + () => + trx.insert(messageQueryChunks).values( + fileChunks.map((chunk) => ({ + chunkId: chunk.id, + messageId: id, + queryId: ragQueryId, + similarity: chunk.similarity?.toString(), + userId: this.userId, + })), + ), + { chunkCount: fileChunks.length }, + ); + } + }; - // Touch topic's updatedAt when 
creating a message in a topic - if (message.topicId) { - await this.touchTopicUpdatedAt(trx, [message.topicId]); - } + private createInTransaction = async ( + trx: Transaction, + params: CreateMessageParams, + id: string, + timing?: ModelTimingContext, + timingPrefix: string = 'db.message.create', + ): Promise => { + const { insert, relations } = this.splitCreateMessageParams(params); - return item; - }); + const [item] = (await runTimedStage( + timing, + `${timingPrefix}.messages.insert`, + () => trx.insert(messages).values(this.buildMessageInsertValue(insert, id)).returning(), + { + hasGroupId: !!insert.message.groupId, + hasTopicId: !!insert.message.topicId, + role: insert.message.role, + }, + )) as DBMessageItem[]; + + await this.insertMessageRelationsInTransaction( + trx, + relations, + insert.message, + id, + timing, + timingPrefix, + ); + + return item; + }; + + create = async ( + params: CreateMessageParams, + id: string = this.genId(), + timing?: ModelTimingContext, + ): Promise => { + return runTimedStage( + timing, + 'db.message.create.transaction', + () => + this.db.transaction(async (trx) => { + const item = await this.createInTransaction(trx, params, id, timing); + + // Touch topic's updatedAt when creating a message in a topic + if (params.topicId) { + await runTimedStage( + timing, + 'db.message.create.topic.touchUpdatedAt', + () => this.touchTopicUpdatedAt(trx, [params.topicId!]), + { topicCount: 1 }, + ); + } + + return item; + }), + { + fileChunkCount: params.fileChunks?.length ?? 0, + fileCount: params.files?.length ?? 
0, + hasTopicId: !!params.topicId, + role: params.role, + }, + ); + }; + + createUserAndAssistantMessages = async ( + { userMessage, assistantMessage }: CreateUserAndAssistantMessagesParams, + { timing, touchTopicUpdatedAt = true }: CreateUserAndAssistantMessagesOptions = {}, + ): Promise<{ assistantMessage: DBMessageItem; userMessage: DBMessageItem }> => { + const userMessageId = this.genId(); + const assistantMessageId = this.genId(); + const createdAt = Date.now(); + const defaultUserCreatedAt = createdAt; + const defaultAssistantCreatedAt = createdAt + 1; + const userMessageWithTimestamp = { + ...userMessage, + createdAt: userMessage.createdAt ?? defaultUserCreatedAt, + updatedAt: + userMessage.updatedAt ?? (userMessage.createdAt ? undefined : defaultUserCreatedAt), + }; + const assistantMessageWithParent = { + ...assistantMessage, + createdAt: assistantMessage.createdAt ?? defaultAssistantCreatedAt, + parentId: userMessageId, + updatedAt: + assistantMessage.updatedAt ?? + (assistantMessage.createdAt ? 
undefined : defaultAssistantCreatedAt), + }; + const topicIds = [ + ...new Set([userMessage.topicId, assistantMessage.topicId].filter(Boolean) as string[]), + ]; + + return runTimedStage( + timing, + 'db.message.createUserAndAssistant.transaction', + () => + this.db.transaction(async (trx) => { + const userPayload = this.splitCreateMessageParams(userMessageWithTimestamp); + const assistantPayload = this.splitCreateMessageParams(assistantMessageWithParent); + const insertedMessages = (await runTimedStage( + timing, + 'db.message.createUserAndAssistant.messages.insert', + () => + trx + .insert(messages) + .values([ + this.buildMessageInsertValue(userPayload.insert, userMessageId), + this.buildMessageInsertValue(assistantPayload.insert, assistantMessageId), + ]) + .returning(), + { hasTopicId: topicIds.length > 0, messageCount: 2 }, + )) as DBMessageItem[]; + const messageMap = new Map(insertedMessages.map((message) => [message.id, message])); + + await this.insertMessageRelationsInTransaction( + trx, + userPayload.relations, + userPayload.insert.message, + userMessageId, + timing, + 'db.message.createUserAndAssistant.user', + ); + await this.insertMessageRelationsInTransaction( + trx, + assistantPayload.relations, + assistantPayload.insert.message, + assistantMessageId, + timing, + 'db.message.createUserAndAssistant.assistant', + ); + + if (touchTopicUpdatedAt && topicIds.length > 0) { + await runTimedStage( + timing, + 'db.message.createUserAndAssistant.topic.touchUpdatedAt', + () => this.touchTopicUpdatedAt(trx, topicIds), + { topicCount: topicIds.length }, + ); + } + + const userMessageItem = messageMap.get(userMessageId); + const assistantMessageItem = messageMap.get(assistantMessageId); + + if (!userMessageItem || !assistantMessageItem) { + throw new Error('Failed to create user and assistant messages'); + } + + return { assistantMessage: assistantMessageItem, userMessage: userMessageItem }; + }), + { + assistantFileCount: assistantMessage.files?.length ?? 
0, + hasTopicId: topicIds.length > 0, + userFileCount: userMessage.files?.length ?? 0, + }, + ); }; batchCreate = async (newMessages: DBMessageItem[]) => { @@ -1352,39 +1830,74 @@ export class MessageModel { update = async ( id: string, { imageList, metadata, ...message }: Partial, + timing?: ModelTimingContext, ): Promise<{ success: boolean }> => { try { - await this.db.transaction(async (trx) => { - // 1. insert message files - if (imageList && imageList.length > 0) { - await trx - .insert(messagesFiles) - .values( - imageList.map((file) => ({ fileId: file.id, messageId: id, userId: this.userId })), + await runTimedStage( + timing, + 'db.message.update.transaction', + () => + this.db.transaction(async (trx) => { + // 1. insert message files + if (imageList && imageList.length > 0) { + await runTimedStage( + timing, + 'db.message.update.imageFiles.insert', + () => + trx.insert(messagesFiles).values( + imageList.map((file) => ({ + fileId: file.id, + messageId: id, + userId: this.userId, + })), + ), + { imageCount: imageList.length }, + ); + } + + // 2. Handle metadata merge if provided + let mergedMetadata: Record | undefined; + if (metadata) { + const [existingMessage] = await runTimedStage( + timing, + 'db.message.update.metadata.select', + () => + trx + .select({ metadata: messages.metadata }) + .from(messages) + .where(and(eq(messages.id, id), eq(messages.userId, this.userId))), + ); + mergedMetadata = merge(existingMessage?.metadata || {}, metadata); + } + + const [updated] = await runTimedStage( + timing, + 'db.message.update.messages.update', + () => + trx + .update(messages) + .set({ ...message, ...(mergedMetadata && { metadata: mergedMetadata }) }) + .where(and(eq(messages.id, id), eq(messages.userId, this.userId))) + .returning({ topicId: messages.topicId }), + { hasMetadata: !!metadata, valueKeys: Object.keys(message) }, ); - } - // 2. 
Handle metadata merge if provided - let mergedMetadata: Record | undefined; - if (metadata) { - const [existingMessage] = await trx - .select({ metadata: messages.metadata }) - .from(messages) - .where(and(eq(messages.id, id), eq(messages.userId, this.userId))); - mergedMetadata = merge(existingMessage?.metadata || {}, metadata); - } - - const [updated] = await trx - .update(messages) - .set({ ...message, ...(mergedMetadata && { metadata: mergedMetadata }) }) - .where(and(eq(messages.id, id), eq(messages.userId, this.userId))) - .returning({ topicId: messages.topicId }); - - // Touch topic's updatedAt when updating a message - if (updated?.topicId) { - await this.touchTopicUpdatedAt(trx, [updated.topicId]); - } - }); + // Touch topic's updatedAt when updating a message + if (updated?.topicId) { + await runTimedStage( + timing, + 'db.message.update.topic.touchUpdatedAt', + () => this.touchTopicUpdatedAt(trx, [updated.topicId!]), + { topicCount: 1 }, + ); + } + }), + { + hasImageList: !!imageList?.length, + hasMetadata: !!metadata, + valueKeys: Object.keys(message), + }, + ); return { success: true }; } catch (error) { diff --git a/packages/database/src/models/topic.ts b/packages/database/src/models/topic.ts index 539c030440..24e095d8de 100644 --- a/packages/database/src/models/topic.ts +++ b/packages/database/src/models/topic.ts @@ -4,6 +4,12 @@ import type { DBMessageItem, TopicRankItem, } from '@lobechat/types'; +import type { TimingSink } from '@lobechat/utils'; +import { + getDurationMs, + logTimingSink as logTiming, + runTimedSinkStage as runTimedStage, +} from '@lobechat/utils'; import type { SQL } from 'drizzle-orm'; import { and, count, desc, eq, gt, gte, inArray, isNull, lte, ne, not, or, sql } from 'drizzle-orm'; @@ -62,12 +68,15 @@ interface QueryTopicParams { */ isInbox?: boolean; pageSize?: number; + timing?: ModelTimingContext; /** * Include only topics matching the given trigger types (positive filter) */ triggers?: string[]; } +export interface 
ModelTimingContext extends TimingSink {} + export interface ListTopicsForMemoryExtractorCursor { createdAt: Date; id: string; @@ -93,8 +102,18 @@ export class TopicModel { pageSize = 9999, groupId, isInbox, + timing, triggers, }: QueryTopicParams = {}) => { + const queryStartedAt = Date.now(); + logTiming(timing, 'db.topic.query:start', { + current, + hasAgentId: !!agentId, + hasContainerId: !!containerId, + hasGroupId: !!groupId, + isInbox: !!isInbox, + pageSize, + }); const offset = current * pageSize; const includeTriggerCondition = includeTriggers && includeTriggers.length > 0 @@ -127,29 +146,42 @@ export class TopicModel { ); const [items, totalResult] = await Promise.all([ - this.db - .select({ - completedAt: topics.completedAt, - createdAt: topics.createdAt, - favorite: topics.favorite, - historySummary: topics.historySummary, - id: topics.id, - metadata: topics.metadata, - status: topics.status, - title: topics.title, - updatedAt: topics.updatedAt, - }) - .from(topics) - .where(whereCondition) - .orderBy(desc(topics.favorite), desc(topics.updatedAt)) - .limit(pageSize) - .offset(offset), - this.db - .select({ count: count(topics.id) }) - .from(topics) - .where(whereCondition), + runTimedStage( + timing, + 'db.topic.query.group.items.select', + () => + this.db + .select({ + completedAt: topics.completedAt, + createdAt: topics.createdAt, + favorite: topics.favorite, + historySummary: topics.historySummary, + id: topics.id, + metadata: topics.metadata, + status: topics.status, + title: topics.title, + updatedAt: topics.updatedAt, + }) + .from(topics) + .where(whereCondition) + .orderBy(desc(topics.favorite), desc(topics.updatedAt)) + .limit(pageSize) + .offset(offset), + { current, pageSize }, + ), + runTimedStage(timing, 'db.topic.query.group.count.select', () => + this.db + .select({ count: count(topics.id) }) + .from(topics) + .where(whereCondition), + ), ]); + logTiming(timing, 'db.topic.query:done', { + itemCount: items.length, + stageMs: 
getDurationMs(queryStartedAt), + total: totalResult[0].count, + }); return { items, total: totalResult[0].count }; } @@ -159,11 +191,19 @@ export class TopicModel { // 3. For inbox: sessionId IS NULL AND groupId IS NULL AND agentId IS NULL (legacy inbox data) if (agentId) { // Get the associated sessionId for backward compatibility with legacy data - const agentSession = await this.db - .select({ sessionId: agentsToSessions.sessionId }) - .from(agentsToSessions) - .where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId))) - .limit(1); + const agentSession = await runTimedStage( + timing, + 'db.topic.query.agentSession.select', + () => + this.db + .select({ sessionId: agentsToSessions.sessionId }) + .from(agentsToSessions) + .where( + and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)), + ) + .limit(1), + { hasAgentId: true }, + ); const associatedSessionId = agentSession[0]?.sessionId; @@ -201,29 +241,46 @@ export class TopicModel { ); const [items, totalResult] = await Promise.all([ - this.db - .select({ - completedAt: topics.completedAt, - createdAt: topics.createdAt, - favorite: topics.favorite, - historySummary: topics.historySummary, - id: topics.id, - metadata: topics.metadata, - status: topics.status, - title: topics.title, - updatedAt: topics.updatedAt, - }) - .from(topics) - .where(agentWhere) - .orderBy(desc(topics.favorite), desc(topics.updatedAt)) - .limit(pageSize) - .offset(offset), - this.db - .select({ count: count(topics.id) }) - .from(topics) - .where(agentWhere), + runTimedStage( + timing, + 'db.topic.query.agent.items.select', + () => + this.db + .select({ + completedAt: topics.completedAt, + createdAt: topics.createdAt, + favorite: topics.favorite, + historySummary: topics.historySummary, + id: topics.id, + metadata: topics.metadata, + status: topics.status, + title: topics.title, + updatedAt: topics.updatedAt, + }) + .from(topics) + .where(agentWhere) + 
.orderBy(desc(topics.favorite), desc(topics.updatedAt)) + .limit(pageSize) + .offset(offset), + { current, hasAssociatedSessionId: !!associatedSessionId, isInbox: !!isInbox, pageSize }, + ), + runTimedStage( + timing, + 'db.topic.query.agent.count.select', + () => + this.db + .select({ count: count(topics.id) }) + .from(topics) + .where(agentWhere), + { hasAssociatedSessionId: !!associatedSessionId, isInbox: !!isInbox }, + ), ]); + logTiming(timing, 'db.topic.query:done', { + itemCount: items.length, + stageMs: getDurationMs(queryStartedAt), + total: totalResult[0].count, + }); return { items, total: totalResult[0].count }; } @@ -238,37 +295,51 @@ export class TopicModel { ); const [items, totalResult] = await Promise.all([ - this.db - .select({ - agentId: topics.agentId, - completedAt: topics.completedAt, - createdAt: topics.createdAt, - favorite: topics.favorite, - historySummary: topics.historySummary, - id: topics.id, - metadata: topics.metadata, - sessionId: topics.sessionId, - status: topics.status, - title: topics.title, - updatedAt: topics.updatedAt, - }) - .from(topics) - .where(whereCondition) - // In boolean sorting, false is considered "smaller" than true. - // So here we use desc to ensure that topics with favorite as true are in front. - .orderBy(desc(topics.favorite), desc(topics.updatedAt)) - .limit(pageSize) - .offset(offset), - this.db - .select({ count: count(topics.id) }) - .from(topics) - .where(whereCondition), + runTimedStage( + timing, + 'db.topic.query.container.items.select', + () => + this.db + .select({ + agentId: topics.agentId, + completedAt: topics.completedAt, + createdAt: topics.createdAt, + favorite: topics.favorite, + historySummary: topics.historySummary, + id: topics.id, + metadata: topics.metadata, + sessionId: topics.sessionId, + status: topics.status, + title: topics.title, + updatedAt: topics.updatedAt, + }) + .from(topics) + .where(whereCondition) + // In boolean sorting, false is considered "smaller" than true. 
+ // So here we use desc to ensure that topics with favorite as true are in front. + .orderBy(desc(topics.favorite), desc(topics.updatedAt)) + .limit(pageSize) + .offset(offset), + { current, pageSize }, + ), + runTimedStage(timing, 'db.topic.query.container.count.select', () => + this.db + .select({ count: count(topics.id) }) + .from(topics) + .where(whereCondition), + ), ]); // Remove internal fields before returning const cleanItems = items.map(({ agentId, sessionId, ...rest }) => rest); + logTiming(timing, 'db.topic.query:done', { + itemCount: cleanItems.length, + stageMs: getDurationMs(queryStartedAt), + total: totalResult[0].count, + }); + return { items: cleanItems, total: totalResult[0].count }; }; @@ -468,30 +539,67 @@ export class TopicModel { create = async ( { messages: messageIds, ...params }: CreateTopicParams, id: string = this.genId(), + timing?: ModelTimingContext, ): Promise => { - return this.db.transaction(async (tx) => { - const insertData = { - ...params, - agentId: params.agentId || null, - groupId: params.groupId || null, - id, - sessionId: params.sessionId || null, - userId: this.userId, - }; + const insertData = { + ...params, + agentId: params.agentId || null, + groupId: params.groupId || null, + id, + sessionId: params.sessionId || null, + userId: this.userId, + }; + const insertMeta = { + hasAgentId: !!params.agentId, + hasGroupId: !!params.groupId, + hasSessionId: !!params.sessionId, + }; - // Insert new topic - const [topic] = await tx.insert(topics).values(insertData).returning(); - - // Update associated messages' topicId - if (messageIds && messageIds.length > 0) { - await tx - .update(messages) - .set({ topicId: topic.id }) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))); - } + if (!messageIds || messageIds.length === 0) { + const [topic] = await runTimedStage( + timing, + 'db.topic.create.topics.insert', + () => this.db.insert(topics).values(insertData).returning(), + insertMeta, + ); return 
topic; - }); + } + + return runTimedStage( + timing, + 'db.topic.create.transaction', + () => + this.db.transaction(async (tx) => { + // Insert new topic + const [topic] = await runTimedStage( + timing, + 'db.topic.create.topics.insert', + () => tx.insert(topics).values(insertData).returning(), + insertMeta, + ); + + // Update associated messages' topicId + await runTimedStage( + timing, + 'db.topic.create.messages.updateTopic', + () => + tx + .update(messages) + .set({ topicId: topic.id }) + .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))), + { messageCount: messageIds.length }, + ); + + return topic; + }), + { + hasAgentId: !!params.agentId, + hasGroupId: !!params.groupId, + hasSessionId: !!params.sessionId, + messageCount: messageIds?.length ?? 0, + }, + ); }; batchCreate = async (topicParams: (CreateTopicParams & { id?: string })[]) => { diff --git a/packages/model-runtime/src/core/ModelRuntime.ts b/packages/model-runtime/src/core/ModelRuntime.ts index 4f183524c1..8c740a0bc2 100644 --- a/packages/model-runtime/src/core/ModelRuntime.ts +++ b/packages/model-runtime/src/core/ModelRuntime.ts @@ -1,4 +1,5 @@ import type { ModelUsage, TracePayload } from '@lobechat/types'; +import { createTimingHelpers, getDurationMs } from '@lobechat/utils'; import type { ClientOptions } from 'openai'; import type { LobeBedrockAIParams } from '../providers/bedrock'; @@ -32,6 +33,13 @@ import type { import { AgentRuntimeError } from '../utils/createError'; import type { LobeRuntimeAI } from './BaseAI'; +const { logger: timing } = createTimingHelpers('lobe-server:chat:lobehub:timing'); + +const getLobeHubTimingMetadata = (options?: { + metadata?: Record; +}): Record | undefined => + options?.metadata?.provider === 'lobehub' ? 
options.metadata : undefined; + export interface AgentChatOptions { enableTrace?: boolean; provider: string; @@ -126,6 +134,17 @@ export class ModelRuntime { * ``` */ async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) { + const metadata = getLobeHubTimingMetadata(options); + const startedAt = Date.now(); + if (metadata) { + timing( + 'ModelRuntime.chat start model=%s trigger=%s traceId=%s', + payload.model, + metadata.trigger, + metadata.traceId, + ); + } + if (typeof this._runtime.chat !== 'function') { throw AgentRuntimeError.chat({ error: new Error('Chat is not supported by this provider'), @@ -135,11 +154,48 @@ export class ModelRuntime { } try { + const hooksStartedAt = Date.now(); const finalOptions = await this.applyHooks(payload, options); - return await this._runtime.chat(payload, finalOptions); + if (metadata) { + timing( + 'ModelRuntime.chat hooks done model=%s durationMs=%d traceId=%s', + payload.model, + getDurationMs(hooksStartedAt), + metadata.traceId, + ); + } + const runtimeStartedAt = Date.now(); + const response = await this._runtime.chat(payload, finalOptions); + if (metadata) { + timing( + 'ModelRuntime.chat runtime done model=%s durationMs=%d totalMs=%d traceId=%s', + payload.model, + getDurationMs(runtimeStartedAt), + getDurationMs(startedAt), + metadata.traceId, + ); + } + return response; } catch (error) { + if (metadata) { + timing( + 'ModelRuntime.chat error model=%s durationMs=%d traceId=%s', + payload.model, + getDurationMs(startedAt), + metadata.traceId, + ); + } if (this._hooks?.onChatError) { + const errorHookStartedAt = Date.now(); await this._hooks.onChatError(error as ChatCompletionErrorPayload, { options, payload }); + if (metadata) { + timing( + 'ModelRuntime.chat onChatError done model=%s durationMs=%d traceId=%s', + payload.model, + getDurationMs(errorHookStartedAt), + metadata.traceId, + ); + } } throw error; } @@ -152,7 +208,37 @@ export class ModelRuntime { payload: ChatStreamPayload, options?: 
ChatMethodOptions, ): Promise { - await this._hooks?.beforeChat?.(payload, options); + const metadata = getLobeHubTimingMetadata(options); + const beforeChatStartedAt = Date.now(); + if (metadata) { + timing( + 'ModelRuntime.beforeChat start model=%s trigger=%s traceId=%s', + payload.model, + metadata.trigger, + metadata.traceId, + ); + } + try { + await this._hooks?.beforeChat?.(payload, options); + } catch (error) { + if (metadata) { + timing( + 'ModelRuntime.beforeChat error model=%s durationMs=%d traceId=%s', + payload.model, + getDurationMs(beforeChatStartedAt), + metadata.traceId, + ); + } + throw error; + } + if (metadata) { + timing( + 'ModelRuntime.beforeChat done model=%s durationMs=%d traceId=%s', + payload.model, + getDurationMs(beforeChatStartedAt), + metadata.traceId, + ); + } if (!this._hooks?.onChatFinal) return options; @@ -163,10 +249,34 @@ export class ModelRuntime { callback: { ...options?.callback, async onFinal(data) { + const finalStartedAt = Date.now(); + if (metadata) { + timing( + 'ModelRuntime.onChatFinal start model=%s traceId=%s', + payload.model, + metadata.traceId, + ); + } await existingOnFinal?.(data); try { await hookFn(data, { options, payload }); + if (metadata) { + timing( + 'ModelRuntime.onChatFinal done model=%s durationMs=%d traceId=%s', + payload.model, + getDurationMs(finalStartedAt), + metadata.traceId, + ); + } } catch (e) { + if (metadata) { + timing( + 'ModelRuntime.onChatFinal error model=%s durationMs=%d traceId=%s', + payload.model, + getDurationMs(finalStartedAt), + metadata.traceId, + ); + } // Hook failures (billing, tracing) must not interfere with response completion console.error('[ModelRuntime] onChatFinal hook error:', e); } diff --git a/packages/model-runtime/src/core/RouterRuntime/createRuntime.ts b/packages/model-runtime/src/core/RouterRuntime/createRuntime.ts index 60b041594a..24bf847c2c 100644 --- a/packages/model-runtime/src/core/RouterRuntime/createRuntime.ts +++ 
b/packages/model-runtime/src/core/RouterRuntime/createRuntime.ts @@ -4,6 +4,7 @@ import type { GoogleGenAIOptions } from '@google/genai'; import type { ChatModelCard } from '@lobechat/types'; import { AgentRuntimeErrorType } from '@lobechat/types'; +import { createTimingHelpers, getDurationMs } from '@lobechat/utils'; import debug from 'debug'; import type { ClientOptions } from 'openai'; import type OpenAI from 'openai'; @@ -44,6 +45,7 @@ import type { import type { ApiType, RuntimeClass } from './apiTypes'; const log = debug('lobe-model-runtime:router-runtime'); +const { logger: timing } = createTimingHelpers('lobe-server:chat:lobehub:timing'); interface ProviderIniOptions extends Record { accessKeyId?: string; @@ -190,6 +192,7 @@ export const createRouterRuntime = ({ private _id: string; constructor(options: ClientOptions & Record = {}) { + const startedAt = Date.now(); this._options = { ...options, apiKey: options.apiKey?.trim() || DEFAULT_API_KEY, @@ -200,36 +203,76 @@ export const createRouterRuntime = ({ this._routers = routers; this._params = params; this._id = options.id ?? id; + + if (this._id === 'lobehub') { + timing( + 'constructor done providerId=%s durationMs=%d hasApiKey=%s hasBaseURL=%s', + this._id, + getDurationMs(startedAt), + !!this._options.apiKey, + !!this._options.baseURL, + ); + } } /** * Resolve routers configuration and validate */ private async resolveRouters(model?: string): Promise { - const resolvedRouters = - typeof this._routers === 'function' - ? await this._routers(this._options, { model }) - : this._routers; + const startedAt = Date.now(); + try { + const resolvedRouters = + typeof this._routers === 'function' + ? 
await this._routers(this._options, { model }) + : this._routers; - if (resolvedRouters.length === 0) { - throw AgentRuntimeError.chat({ - error: { message: 'empty providers' }, - errorType: AgentRuntimeErrorType.NoAvailableProvider, - provider: this._id, - }); + if (this._id === 'lobehub') { + timing( + 'resolveRouters done model=%s durationMs=%d routerCount=%d dynamic=%s', + model, + getDurationMs(startedAt), + resolvedRouters.length, + typeof this._routers === 'function', + ); + } + + if (resolvedRouters.length === 0) { + throw AgentRuntimeError.chat({ + error: { message: 'empty providers' }, + errorType: AgentRuntimeErrorType.NoAvailableProvider, + provider: this._id, + }); + } + + return resolvedRouters; + } catch (error) { + if (this._id === 'lobehub') { + timing('resolveRouters error model=%s durationMs=%d', model, getDurationMs(startedAt)); + } + throw error; } - - return resolvedRouters; } private async resolveMatchedRouter(model: string): Promise { + const startedAt = Date.now(); const resolvedRouters = await this.resolveRouters(model); const baseURL = this._options.baseURL; // Priority 1: Match by baseURLPattern (RegExp only) if (baseURL) { const baseURLMatch = resolvedRouters.find((router) => router.baseURLPattern?.test(baseURL)); - if (baseURLMatch) return baseURLMatch; + if (baseURLMatch) { + if (this._id === 'lobehub') { + timing( + 'resolveMatchedRouter done model=%s match=baseURL routerId=%s apiType=%s durationMs=%d', + model, + baseURLMatch.id, + baseURLMatch.apiType, + getDurationMs(startedAt), + ); + } + return baseURLMatch; + } } // Priority 2: Match by models @@ -239,19 +282,50 @@ export const createRouterRuntime = ({ } return false; }); - if (modelMatch) return modelMatch; + if (modelMatch) { + if (this._id === 'lobehub') { + timing( + 'resolveMatchedRouter done model=%s match=models routerId=%s apiType=%s durationMs=%d', + model, + modelMatch.id, + modelMatch.apiType, + getDurationMs(startedAt), + ); + } + return modelMatch; + } // Fallback: 
Use the last router - return resolvedRouters.at(-1)!; + const fallbackRouter = resolvedRouters.at(-1)!; + if (this._id === 'lobehub') { + timing( + 'resolveMatchedRouter done model=%s match=fallback routerId=%s apiType=%s durationMs=%d', + model, + fallbackRouter.id, + fallbackRouter.apiType, + getDurationMs(startedAt), + ); + } + return fallbackRouter; } private normalizeRouterOptions(router: RouterInstance): RouterOptionItem[] { + const startedAt = Date.now(); const routerOptions = Array.isArray(router.options) ? router.options : [router.options]; if (routerOptions.length === 0 || routerOptions.some((optionItem) => !optionItem)) { throw new Error('empty provider options'); } + if (this._id === 'lobehub') { + timing( + 'normalizeRouterOptions done routerId=%s options=%d durationMs=%d', + router.id, + routerOptions.length, + getDurationMs(startedAt), + ); + } + return routerOptions; } @@ -268,6 +342,7 @@ export const createRouterRuntime = ({ remark?: string; runtime: LobeRuntimeAI; }> { + const startedAt = Date.now(); const { apiType: optionApiType, id: channelId, remark, ...optionOverrides } = optionItem; const resolvedApiType = optionApiType ?? router.apiType; const finalOptions = { @@ -297,6 +372,16 @@ export const createRouterRuntime = ({ if (project) vertexOptions.project = project; if (location) vertexOptions.location = location as GoogleGenAIOptions['location']; + if (this._id === 'lobehub') { + timing( + 'createRuntimeFromOption done routerId=%s channelId=%s apiType=%s durationMs=%d vertex=true', + router.id, + channelId, + resolvedApiType, + getDurationMs(startedAt), + ); + } + return { channelId, id: resolvedApiType, @@ -312,6 +397,16 @@ export const createRouterRuntime = ({ : (baseRuntimeMap[resolvedApiType] ?? 
LobeOpenAI); const runtime: LobeRuntimeAI = new providerAI({ ...finalOptions, id: this._id }); + if (this._id === 'lobehub') { + timing( + 'createRuntimeFromOption done routerId=%s channelId=%s apiType=%s durationMs=%d', + router.id, + channelId, + resolvedApiType, + getDurationMs(startedAt), + ); + } + return { channelId, id: resolvedApiType, @@ -325,10 +420,22 @@ export const createRouterRuntime = ({ requestHandler: (runtime: LobeRuntimeAI) => Promise, metadata?: Record, ): Promise { + const totalStartedAt = Date.now(); const matchedRouter = await this.resolveMatchedRouter(model); const routerOptions = this.normalizeRouterOptions(matchedRouter); const totalOptions = routerOptions.length; + if (this._id === 'lobehub') { + timing( + 'runWithFallback start model=%s routerId=%s apiType=%s options=%d traceId=%s', + model, + matchedRouter.id, + matchedRouter.apiType, + totalOptions, + metadata?.traceId, + ); + } + log( 'resolve router for model=%s apiType=%s options=%d', model, @@ -349,7 +456,33 @@ export const createRouterRuntime = ({ } = await this.createRuntimeFromOption(matchedRouter, optionItem); try { + if (this._id === 'lobehub') { + timing( + 'attempt request start model=%s attempt=%d/%d routerId=%s channelId=%s apiType=%s traceId=%s', + model, + attempt, + totalOptions, + matchedRouter.id, + channelId, + resolvedApiType, + metadata?.traceId, + ); + } const result = await requestHandler(runtime); + if (this._id === 'lobehub') { + timing( + 'attempt request success model=%s attempt=%d/%d routerId=%s channelId=%s apiType=%s durationMs=%d totalMs=%d traceId=%s', + model, + attempt, + totalOptions, + matchedRouter.id, + channelId, + resolvedApiType, + getDurationMs(startTime), + getDurationMs(totalStartedAt), + metadata?.traceId, + ); + } if (totalOptions > 1 && attempt > 1) { log( @@ -392,6 +525,20 @@ export const createRouterRuntime = ({ return result; } catch (error) { lastError = error; + if (this._id === 'lobehub') { + timing( + 'attempt request error model=%s 
attempt=%d/%d routerId=%s channelId=%s apiType=%s durationMs=%d totalMs=%d traceId=%s', + model, + attempt, + totalOptions, + matchedRouter.id, + channelId, + resolvedApiType, + getDurationMs(startTime), + getDurationMs(totalStartedAt), + metadata?.traceId, + ); + } params .onRouteAttempt?.({ @@ -417,6 +564,7 @@ export const createRouterRuntime = ({ } try { + const shouldStopStartedAt = Date.now(); const shouldStopFallback = await params.shouldStopFallback?.({ error, metadata, @@ -424,6 +572,18 @@ export const createRouterRuntime = ({ optionIndex: index, }); + if (this._id === 'lobehub') { + timing( + 'shouldStopFallback done model=%s attempt=%d/%d durationMs=%d shouldStop=%s traceId=%s', + model, + attempt, + totalOptions, + getDurationMs(shouldStopStartedAt), + shouldStopFallback, + metadata?.traceId, + ); + } + if (shouldStopFallback) { throw error; } @@ -460,6 +620,17 @@ export const createRouterRuntime = ({ } } + if (this._id === 'lobehub') { + timing( + 'runWithFallback failed model=%s routerId=%s options=%d totalMs=%d traceId=%s', + model, + matchedRouter.id, + totalOptions, + getDurationMs(totalStartedAt), + metadata?.traceId, + ); + } + throw lastError ?? 
new Error('empty provider options'); } diff --git a/packages/types/src/aiChat.test.ts b/packages/types/src/aiChat.test.ts new file mode 100644 index 0000000000..242672d603 --- /dev/null +++ b/packages/types/src/aiChat.test.ts @@ -0,0 +1,21 @@ +import { describe, expect, it } from 'vitest'; + +import { AiSendMessageServerSchema } from './aiChat'; + +const createInput = (topicPageSize: number) => ({ + newAssistantMessage: { model: 'gpt-4o', provider: 'openai' }, + newUserMessage: { content: 'hello' }, + topicPageSize, +}); + +describe('AiSendMessageServerSchema', () => { + it('should only accept positive integer topic page sizes up to 100', () => { + for (const topicPageSize of [1, 20, 100]) { + expect(AiSendMessageServerSchema.safeParse(createInput(topicPageSize)).success).toBe(true); + } + + for (const topicPageSize of [-1, 0, 1.5, 101]) { + expect(AiSendMessageServerSchema.safeParse(createInput(topicPageSize)).success).toBe(false); + } + }); +}); diff --git a/packages/types/src/aiChat.ts b/packages/types/src/aiChat.ts index 4b7ed890a9..262f47177a 100644 --- a/packages/types/src/aiChat.ts +++ b/packages/types/src/aiChat.ts @@ -96,6 +96,10 @@ export interface SendMessageServerParams { }; // if there is activeTopicId, then add topicId to message topicId?: string; + /** + * Page size for the topic list returned after creating a new topic. 
+ */ + topicPageSize?: number; } export const CreateThreadWithMessageSchema = z.object({ @@ -156,6 +160,7 @@ export const AiSendMessageServerSchema = z.object({ includeTriggers: z.array(z.string()).optional(), }) .optional(), + topicPageSize: z.number().int().min(1).max(100).optional(), topicId: z.string().optional(), }); diff --git a/packages/utils/src/index.ts b/packages/utils/src/index.ts index b1c20e9299..4dc150d78d 100644 --- a/packages/utils/src/index.ts +++ b/packages/utils/src/index.ts @@ -20,6 +20,7 @@ export * from './pricing'; export * from './safeParseJSON'; export * from './sanitizeToolCallArguments'; export * from './sleep'; +export * from './timing'; export * from './uriParser'; export * from './url'; export * from './uuid'; diff --git a/packages/utils/src/timing.ts b/packages/utils/src/timing.ts new file mode 100644 index 0000000000..d41e0bdffe --- /dev/null +++ b/packages/utils/src/timing.ts @@ -0,0 +1,173 @@ +import debug from 'debug'; + +export interface TimingContext { + requestId: string; + startedAt: number; +} + +export interface TimingMetadata { + [key: string]: unknown; +} + +export interface TimingParams { + timingRequestId?: string; + timingStartedAt?: number; +} + +export interface TimingSink { + log: (event: string, metadata?: TimingMetadata) => void; +} + +export type TimingLogger = (formatter: string, ...args: unknown[]) => void; + +export const createDebugTimingLogger = (namespace: string): TimingLogger => debug(namespace); + +export const getDurationMs = (startedAt: number) => Date.now() - startedAt; + +export const createTimingRequestId = () => + globalThis.crypto?.randomUUID?.() ?? 
+ `${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`; + +const isRecord = (value: unknown): value is Record => + !!value && typeof value === 'object'; + +export const getTimingErrorMetadata = (error: unknown): TimingMetadata => { + if (error instanceof Error) { + return { + errorMessage: error.message, + errorName: error.name, + }; + } + + if (isRecord(error)) { + return { + errorType: typeof error.errorType === 'string' ? error.errorType : undefined, + status: typeof error.status === 'number' ? error.status : undefined, + }; + } + + return { errorMessage: String(error) }; +}; + +export const toTimingContext = (params?: TimingParams): TimingContext | undefined => + params?.timingRequestId + ? { requestId: params.timingRequestId, startedAt: params.timingStartedAt ?? Date.now() } + : undefined; + +export const logTiming = ( + logger: TimingLogger, + context: TimingContext | undefined, + event: string, + metadata?: TimingMetadata, +) => { + if (!context) return; + + const totalMs = getDurationMs(context.startedAt); + if (metadata) { + logger('[%s] %s totalMs=%d %O', context.requestId, event, totalMs, metadata); + return; + } + + logger('[%s] %s totalMs=%d', context.requestId, event, totalMs); +}; + +export const logTimingSink = ( + timing: TimingSink | undefined, + event: string, + metadata?: TimingMetadata, +) => { + timing?.log(event, metadata); +}; + +export const runTimedStage = async ( + logger: TimingLogger, + context: TimingContext | undefined, + stage: string, + task: () => T | Promise, + metadata?: TimingMetadata, +): Promise> => { + if (!context) return await task(); + + const startedAt = Date.now(); + logTiming(logger, context, `${stage}:start`, metadata); + + try { + const result = await task(); + logTiming(logger, context, `${stage}:done`, { + ...metadata, + stageMs: getDurationMs(startedAt), + }); + + return result; + } catch (error) { + logTiming(logger, context, `${stage}:error`, { + ...metadata, + ...getTimingErrorMetadata(error), 
+ stageMs: getDurationMs(startedAt), + }); + + throw error; + } +}; + +export const runTimedSinkStage = async ( + timing: TimingSink | undefined, + stage: string, + task: () => T | Promise, + metadata?: TimingMetadata, +): Promise> => { + if (!timing) return await task(); + + const startedAt = Date.now(); + logTimingSink(timing, `${stage}:start`, metadata); + + try { + const result = await task(); + logTimingSink(timing, `${stage}:done`, { + ...metadata, + stageMs: getDurationMs(startedAt), + }); + + return result; + } catch (error) { + logTimingSink(timing, `${stage}:error`, { + ...metadata, + ...getTimingErrorMetadata(error), + stageMs: getDurationMs(startedAt), + }); + + throw error; + } +}; + +export const createPrefixedTimingContext = ( + logger: TimingLogger, + context: TimingContext | undefined, + prefix: string, +): TimingSink | undefined => + context + ? { + log: (event: string, metadata?: TimingMetadata) => { + logTiming(logger, context, `${prefix}.${event}`, metadata); + }, + } + : undefined; + +export const createTimingHelpers = (namespace: string) => { + const logger = createDebugTimingLogger(namespace); + + return { + createPrefixedTimingContext: (context: TimingContext | undefined, prefix: string) => + createPrefixedTimingContext(logger, context, prefix), + logger, + logTiming: (context: TimingContext | undefined, event: string, metadata?: TimingMetadata) => + logTiming(logger, context, event, metadata), + runTimedStage: ( + context: TimingContext | undefined, + stage: string, + task: () => T | Promise, + metadata?: TimingMetadata, + ) => runTimedStage(logger, context, stage, task, metadata), + toTimingContext, + }; +}; diff --git a/src/server/routers/lambda/__tests__/aiChat.test.ts b/src/server/routers/lambda/__tests__/aiChat.test.ts index eb73cf69fa..6be9d88a31 100644 --- a/src/server/routers/lambda/__tests__/aiChat.test.ts +++ b/src/server/routers/lambda/__tests__/aiChat.test.ts @@ -1,4 +1,5 @@ // @vitest-environment node +import type { 
CreateMessageParams } from '@lobechat/types'; import { ThreadType } from '@lobechat/types'; import { describe, expect, it, vi } from 'vitest'; @@ -10,6 +11,8 @@ import { AiChatService } from '@/server/services/aiChat'; import { aiChatRouter } from '../aiChat'; +const flushAsyncTasks = () => new Promise((resolve) => setTimeout(resolve, 0)); + vi.mock('@/database/models/agent'); vi.mock('@/database/models/message'); vi.mock('@/database/models/thread'); @@ -24,6 +27,38 @@ vi.mock('@/server/modules/ModelRuntime', () => ({ describe('aiChatRouter', () => { const mockCtx = { userId: 'u1' }; + const mockMessageModel = (mockCreateMessage: ReturnType) => { + const mockCreateUserAndAssistantMessages = vi.fn( + async ( + { + assistantMessage, + userMessage, + }: { + assistantMessage: CreateMessageParams; + userMessage: CreateMessageParams; + }, + _options?: unknown, + ) => { + const userMessageItem = await mockCreateMessage(userMessage); + const assistantMessageItem = await mockCreateMessage({ + ...assistantMessage, + parentId: userMessageItem.id, + }); + + return { assistantMessage: assistantMessageItem, userMessage: userMessageItem }; + }, + ); + + vi.mocked(MessageModel).mockImplementation( + () => + ({ + create: mockCreateMessage, + createUserAndAssistantMessages: mockCreateUserAndAssistantMessages, + }) as any, + ); + + return mockCreateUserAndAssistantMessages; + }; it('should create topic optionally, create user/assistant messages, and return payload', async () => { const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' }); @@ -37,7 +72,7 @@ describe('aiChatRouter', () => { }); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = 
aiChatRouter.createCaller(mockCtx as any); @@ -47,6 +82,7 @@ describe('aiChatRouter', () => { newTopic: { title: 'T', topicMessageIds: ['a', 'b'] }, newUserMessage: { content: 'hi', files: ['f1'] }, sessionId: 's1', + topicPageSize: 20, } as any; const res = await caller.sendMessageInServer(input); @@ -79,9 +115,19 @@ describe('aiChatRouter', () => { topicId: 't1', }), ); + expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledTimes(1); + expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ touchTopicUpdatedAt: false }), + ); expect(mockGet).toHaveBeenCalledWith( - expect.objectContaining({ includeTopic: true, sessionId: 's1', topicId: 't1' }), + expect.objectContaining({ + includeTopic: true, + sessionId: 's1', + topicId: 't1', + topicPageSize: 20, + }), ); expect(res.assistantMessageId).toBe('m-assistant'); expect(res.userMessageId).toBe('m-user'); @@ -99,7 +145,7 @@ describe('aiChatRouter', () => { .mockResolvedValueOnce({ id: 'm-assistant' }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -112,6 +158,10 @@ describe('aiChatRouter', () => { } as any); expect(mockCreateMessage).toHaveBeenCalled(); + expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledWith( + expect.any(Object), + expect.objectContaining({ touchTopicUpdatedAt: true }), + ); expect(mockGet).toHaveBeenCalledWith( expect.objectContaining({ includeTopic: false, @@ -130,7 +180,7 @@ describe('aiChatRouter', () => { .mockResolvedValueOnce({ id: 'm-assistant' }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); - 
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -175,7 +225,7 @@ describe('aiChatRouter', () => { .mockResolvedValueOnce({ id: 'm-assistant' }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -282,7 +332,7 @@ describe('aiChatRouter', () => { const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); vi.mocked(ThreadModel).mockImplementation(() => ({ create: mockCreateThread }) as any); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -346,7 +396,7 @@ describe('aiChatRouter', () => { vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); vi.mocked(ThreadModel).mockImplementation(() => ({ create: mockCreateThread }) as any); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -402,7 +452,7 @@ describe('aiChatRouter', () => { .mockResolvedValueOnce({ id: 'm-assistant' }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); - vi.mocked(MessageModel).mockImplementation(() => ({ 
create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -427,7 +477,7 @@ describe('aiChatRouter', () => { const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] }); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -459,7 +509,7 @@ describe('aiChatRouter', () => { const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] }); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -489,7 +539,7 @@ describe('aiChatRouter', () => { .mockResolvedValueOnce({ id: 'm-assistant' }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -537,7 +587,7 @@ describe('aiChatRouter', () => { .mockResolvedValueOnce({ id: 'm-assistant' }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + 
mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -569,7 +619,7 @@ describe('aiChatRouter', () => { .mockResolvedValueOnce({ id: 'm-assistant' }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -621,7 +671,7 @@ describe('aiChatRouter', () => { .mockResolvedValueOnce({ id: 'm-assistant' }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); const caller = aiChatRouter.createCaller(mockCtx as any); @@ -677,7 +727,7 @@ describe('aiChatRouter', () => { const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AgentModel).mockImplementation( () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, @@ -713,7 +763,7 @@ describe('aiChatRouter', () => { const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); 
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AgentModel).mockImplementation( () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, @@ -733,6 +783,94 @@ describe('aiChatRouter', () => { expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1'); }); + it('should keep the message response when agent updatedAt touch fails', async () => { + const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => undefined); + const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' }); + const mockCreateMessage = vi + .fn() + .mockResolvedValueOnce({ id: 'm-user' }) + .mockResolvedValueOnce({ id: 'm-assistant' }); + const mockGet = vi.fn().mockResolvedValue({ + messages: [{ id: 'm-user' }, { id: 'm-assistant' }], + topics: undefined, + }); + const touchError = new Error('touch failed'); + const mockTouchUpdatedAt = vi.fn().mockRejectedValue(touchError); + + try { + vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); + mockMessageModel(mockCreateMessage); + vi.mocked(AiChatService).mockImplementation( + () => ({ getMessagesAndTopics: mockGet }) as any, + ); + vi.mocked(AgentModel).mockImplementation( + () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, + ); + + const caller = aiChatRouter.createCaller(mockCtx as any); + + const res = await caller.sendMessageInServer({ + agentId: 'agent-1', + newAssistantMessage: { model: 'gpt-4o', provider: 'openai' }, + newTopic: { title: 'New Topic' }, + newUserMessage: { content: 'hi' }, + sessionId: 's1', + } as any); + + expect(res.userMessageId).toBe('m-user'); + expect(res.assistantMessageId).toBe('m-assistant'); + expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1'); + expect(consoleErrorSpy).toHaveBeenCalledWith( + '[aiChat] Failed to touch agent updatedAt:', + touchError, + ); + } finally { + consoleErrorSpy.mockRestore(); + } + }); + + it('should create messages while agent updatedAt touch is still pending', async 
() => { + const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' }); + const mockCreateMessage = vi + .fn() + .mockResolvedValueOnce({ id: 'm-user' }) + .mockResolvedValueOnce({ id: 'm-assistant' }); + const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] }); + let resolveTouchUpdatedAt: () => void = () => {}; + const touchUpdatedAtPromise = new Promise((resolve) => { + resolveTouchUpdatedAt = resolve; + }); + const mockTouchUpdatedAt = vi.fn(() => touchUpdatedAtPromise); + + vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); + const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage); + vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); + vi.mocked(AgentModel).mockImplementation( + () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, + ); + + const caller = aiChatRouter.createCaller(mockCtx as any); + + const request = caller.sendMessageInServer({ + agentId: 'agent-1', + newAssistantMessage: { model: 'gpt-4o', provider: 'openai' }, + newTopic: { title: 'New Topic' }, + newUserMessage: { content: 'hi' }, + sessionId: 's1', + } as any); + + await flushAsyncTasks(); + + try { + expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1'); + expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledTimes(1); + } finally { + resolveTouchUpdatedAt(); + } + + await request; + }); + it('should not touch agent updatedAt when creating topic without agentId', async () => { const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' }); const mockCreateMessage = vi @@ -743,7 +881,7 @@ describe('aiChatRouter', () => { const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ 
getMessagesAndTopics: mockGet }) as any); vi.mocked(AgentModel).mockImplementation( () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, @@ -771,7 +909,7 @@ describe('aiChatRouter', () => { const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined); - vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); + mockMessageModel(mockCreateMessage); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AgentModel).mockImplementation( () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, diff --git a/src/server/routers/lambda/aiChat.ts b/src/server/routers/lambda/aiChat.ts index 78a868605e..d36b49c7fa 100644 --- a/src/server/routers/lambda/aiChat.ts +++ b/src/server/routers/lambda/aiChat.ts @@ -1,5 +1,6 @@ -import { type CreateMessageParams, type SendMessageServerResponse } from '@lobechat/types'; +import type { CreateMessageParams, SendMessageServerResponse } from '@lobechat/types'; import { AiSendMessageServerSchema, RequestTrigger, StructureOutputSchema } from '@lobechat/types'; +import { createTimingHelpers, createTimingRequestId } from '@lobechat/utils'; import debug from 'debug'; import { LOADING_FLAT } from '@/const/message'; @@ -15,6 +16,9 @@ import { AiChatService } from '@/server/services/aiChat'; import { FileService } from '@/server/services/file'; const log = debug('lobe-lambda-router:ai-chat'); +const { createPrefixedTimingContext, logTiming, runTimedStage } = createTimingHelpers( + 'lobe-server:chat:lobehub:timing', +); const aiChatProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; @@ -59,6 +63,17 @@ export const aiChatRouter = router({ sendMessageInServer: aiChatProcedure .input(AiSendMessageServerSchema) .mutation(async ({ input, ctx }) => { + const timingContext = + input.newAssistantMessage.provider === 'lobehub' + ? 
{ requestId: createTimingRequestId(), startedAt: Date.now() } + : undefined; + logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:start', { + hasNewThread: !!input.newThread, + hasNewTopic: !!input.newTopic, + hasSessionId: !!input.sessionId, + hasTopicId: !!input.topicId, + preloadCount: input.preloadMessages?.length ?? 0, + }); log('sendMessageInServer called for agentId: %s', input.agentId); log( 'topicId: %s, newTopic: %O, newThread: %O', @@ -68,7 +83,12 @@ export const aiChatRouter = router({ ); let sessionId = input.sessionId; if (!sessionId) { - const context = await resolveContext(input, ctx.serverDB, ctx.userId); + const context = await runTimedStage( + timingContext, + 'lambda.aiChat.resolveContext', + () => resolveContext(input, ctx.serverDB, ctx.userId), + { hasAgentId: !!input.agentId }, + ); if (!!context.sessionId) sessionId = context.sessionId; } @@ -77,27 +97,54 @@ export const aiChatRouter = router({ let createdThreadId: string | undefined; let isCreateNewTopic = false; + let agentTouchUpdatedAtTask: Promise | undefined; // create topic if there should be a new topic if (input.newTopic) { log('creating new topic with title: %s', input.newTopic.title); - const topicItem = await ctx.topicModel.create({ - agentId: input.agentId, - groupId: input.groupId, - messages: input.newTopic.topicMessageIds, - metadata: input.newTopic.metadata, - sessionId, - title: input.newTopic.title, - trigger: input.newTopic.trigger, - }); + const topicItem = await runTimedStage( + timingContext, + 'lambda.aiChat.topic.create', + () => { + const payload = { + agentId: input.agentId, + groupId: input.groupId, + messages: input.newTopic!.topicMessageIds, + metadata: input.newTopic!.metadata, + sessionId, + title: input.newTopic!.title, + trigger: input.newTopic!.trigger, + }; + const modelTiming = createPrefixedTimingContext( + timingContext, + 'lambda.aiChat.topic.create', + ); + return modelTiming + ? 
ctx.topicModel.create(payload, undefined, modelTiming) + : ctx.topicModel.create(payload); + }, + { + messageCount: input.newTopic.topicMessageIds?.length ?? 0, + trigger: input.newTopic.trigger, + }, + ); topicId = topicItem.id; isCreateNewTopic = true; log('new topic created with id: %s', topicId); // update agent's updatedAt to reflect new activity if (input.agentId) { - await ctx.agentModel.touchUpdatedAt(input.agentId); - log('agent updatedAt touched for agentId: %s', input.agentId); + agentTouchUpdatedAtTask = runTimedStage( + timingContext, + 'lambda.aiChat.agent.touchUpdatedAt', + async () => { + await ctx.agentModel.touchUpdatedAt(input.agentId!); + }, + { hasAgentId: true }, + ).catch((error) => { + console.error('[aiChat] Failed to touch agent updatedAt:', error); + }); + log('agent updatedAt touch scheduled for agentId: %s', input.agentId); } } @@ -108,13 +155,19 @@ export const aiChatRouter = router({ input.newThread.sourceMessageId, input.newThread.type, ); - const threadItem = await ctx.threadModel.create({ - parentThreadId: input.newThread.parentThreadId, - sourceMessageId: input.newThread.sourceMessageId, - title: input.newThread.title, - topicId, - type: input.newThread.type, - }); + const threadItem = await runTimedStage( + timingContext, + 'lambda.aiChat.thread.create', + () => + ctx.threadModel.create({ + parentThreadId: input.newThread!.parentThreadId, + sourceMessageId: input.newThread!.sourceMessageId, + title: input.newThread!.title, + topicId, + type: input.newThread!.type, + }), + { threadType: input.newThread.type }, + ); if (threadItem) { threadId = threadItem.id; createdThreadId = threadItem.id; @@ -127,24 +180,40 @@ export const aiChatRouter = router({ if (input.preloadMessages?.length) { log('creating %d preload messages before user message', input.preloadMessages.length); - for (const preloadMessage of input.preloadMessages) { - const preloadItem = await ctx.messageModel.create({ - agentId: input.agentId, - content: 
preloadMessage.content, - groupId: input.groupId, - metadata: preloadMessage.metadata, - parentId, - plugin: preloadMessage.plugin as CreateMessageParams['plugin'], - role: preloadMessage.role, - sessionId, - threadId, - tool_call_id: preloadMessage.tool_call_id, - tools: preloadMessage.tools as CreateMessageParams['tools'], - topicId, - }); + parentId = await runTimedStage( + timingContext, + 'lambda.aiChat.preloadMessages.create', + async () => { + let latestParentId = parentId; + for (const preloadMessage of input.preloadMessages!) { + const payload = { + agentId: input.agentId, + content: preloadMessage.content, + groupId: input.groupId, + metadata: preloadMessage.metadata, + parentId: latestParentId, + plugin: preloadMessage.plugin as CreateMessageParams['plugin'], + role: preloadMessage.role, + sessionId, + threadId, + tool_call_id: preloadMessage.tool_call_id, + tools: preloadMessage.tools as CreateMessageParams['tools'], + topicId, + }; + const modelTiming = createPrefixedTimingContext( + timingContext, + 'lambda.aiChat.preloadMessages.create', + ); + const preloadItem = await (modelTiming + ? 
ctx.messageModel.create(payload, undefined, modelTiming) + : ctx.messageModel.create(payload)); - parentId = preloadItem.id; - } + latestParentId = preloadItem.id; + } + return latestParentId; + }, + { count: input.preloadMessages.length }, + ); } // create user message @@ -161,58 +230,95 @@ export const aiChatRouter = router({ } : undefined; - const userMessageItem = await ctx.messageModel.create({ - agentId: input.agentId, - content: input.newUserMessage.content, - editorData: input.newUserMessage.editorData, - files: input.newUserMessage.files, - groupId: input.groupId, - metadata: userMessageMetadata, - parentId, - role: 'user', - sessionId, - threadId, - topicId, - }); + const createMessagePairPromise = runTimedStage( + timingContext, + 'lambda.aiChat.messages.createUserAndAssistant', + () => { + const userMessage = { + agentId: input.agentId, + content: input.newUserMessage.content, + editorData: input.newUserMessage.editorData, + files: input.newUserMessage.files, + groupId: input.groupId, + metadata: userMessageMetadata, + parentId, + role: 'user', + sessionId, + threadId, + topicId, + } satisfies CreateMessageParams; + const assistantMessage = { + agentId: input.agentId, + content: LOADING_FLAT, + groupId: input.groupId, + metadata: input.newAssistantMessage.metadata, + model: input.newAssistantMessage.model, + provider: input.newAssistantMessage.provider, + role: 'assistant', + sessionId, + threadId, + topicId, + } satisfies CreateMessageParams; + const modelTiming = createPrefixedTimingContext( + timingContext, + 'lambda.aiChat.messages.createUserAndAssistant', + ); + return ctx.messageModel.createUserAndAssistantMessages( + { assistantMessage, userMessage }, + { + ...(modelTiming ? { timing: modelTiming } : {}), + touchTopicUpdatedAt: !isCreateNewTopic, + }, + ); + }, + { + contentLength: input.newUserMessage.content.length, + fileCount: input.newUserMessage.files?.length ?? 
0, + model: input.newAssistantMessage.model, + provider: input.newAssistantMessage.provider, + }, + ); + const { assistantMessage: assistantMessageItem, userMessage: userMessageItem } = + agentTouchUpdatedAtTask + ? (await Promise.all([createMessagePairPromise, agentTouchUpdatedAtTask]))[0] + : await createMessagePairPromise; const messageId = userMessageItem.id; log('user message created with id: %s', messageId); - // create assistant message - log( - 'creating assistant message with model: %s, provider: %s, metadata: %O', - input.newAssistantMessage.model, - input.newAssistantMessage.provider, - input.newAssistantMessage.metadata, - ); - const assistantMessageItem = await ctx.messageModel.create({ - agentId: input.agentId, - content: LOADING_FLAT, - groupId: input.groupId, - metadata: input.newAssistantMessage.metadata, - model: input.newAssistantMessage.model, - parentId: messageId, - provider: input.newAssistantMessage.provider, - role: 'assistant', - sessionId, - threadId, - topicId, - }); log('assistant message created with id: %s', assistantMessageItem.id); // retrieve latest messages and topic with log('retrieving messages and topics'); - const { messages, topics } = await ctx.aiChatService.getMessagesAndTopics({ - agentId: input.agentId, - groupId: input.groupId, - includeTopic: isCreateNewTopic, - sessionId, - threadId, - topicFilter: input.topicFilter, - topicId, - }); + const { messages, topics } = await runTimedStage( + timingContext, + 'lambda.aiChat.messagesAndTopics.query', + () => + ctx.aiChatService.getMessagesAndTopics({ + agentId: input.agentId, + groupId: input.groupId, + includeTopic: isCreateNewTopic, + sessionId, + threadId, + topicFilter: input.topicFilter, + topicId, + topicPageSize: input.topicPageSize, + ...(timingContext + ? 
{ + timingRequestId: timingContext.requestId, + timingStartedAt: timingContext.startedAt, + } + : {}), + }), + { includeTopic: isCreateNewTopic }, + ); log('retrieved %d messages, %d topics', messages.length, topics?.items?.length ?? 0); + logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:done', { + isCreateNewTopic, + messageCount: messages.length, + topicCount: topics?.items?.length ?? 0, + }); return { assistantMessageId: assistantMessageItem.id, diff --git a/src/server/routers/lambda/message.ts b/src/server/routers/lambda/message.ts index fc3f5947cb..b29711d73d 100644 --- a/src/server/routers/lambda/message.ts +++ b/src/server/routers/lambda/message.ts @@ -4,6 +4,7 @@ import { UpdateMessagePluginSchema, UpdateMessageRAGParamsSchema, } from '@lobechat/types'; +import { createTimingHelpers, createTimingRequestId } from '@lobechat/utils'; import { TRPCError } from '@trpc/server'; import { z } from 'zod'; @@ -18,6 +19,8 @@ import { MessageService } from '@/server/services/message'; import { resolveAgentIdFromSession, resolveContext } from './_helpers/resolveContext'; import { basicContextSchema } from './_schema/context'; +const { logTiming, runTimedStage } = createTimingHelpers('lobe-server:chat:lobehub:timing'); + const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; @@ -316,9 +319,37 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { id, value, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const timingContext = { requestId: createTimingRequestId(), startedAt: Date.now() }; + logTiming(timingContext, 'lambda.message.update:start', { + hasAgentId: !!agentId, + hasTopicId: !!options.topicId, + valueKeys: Object.keys(value ?? 
{}), + }); - return ctx.messageService.updateMessage(id, value as any, resolved); + const resolved = await runTimedStage( + timingContext, + 'lambda.message.update.resolveContext', + () => resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId), + { hasAgentId: !!agentId }, + ); + + const result = await runTimedStage( + timingContext, + 'lambda.message.update.service', + () => + ctx.messageService.updateMessage(id, value as any, { + ...resolved, + timingRequestId: timingContext.requestId, + timingStartedAt: timingContext.startedAt, + }), + { hasResolvedTopicId: !!resolved.topicId }, + ); + + logTiming(timingContext, 'lambda.message.update:done', { + messageCount: result.messages?.length ?? 0, + success: result.success, + }); + return result; }), /** diff --git a/src/server/services/aiChat/index.test.ts b/src/server/services/aiChat/index.test.ts index e2a05a1dce..401796269b 100644 --- a/src/server/services/aiChat/index.test.ts +++ b/src/server/services/aiChat/index.test.ts @@ -1,4 +1,4 @@ -import { type LobeChatDatabase } from '@lobechat/database'; +import type { LobeChatDatabase } from '@lobechat/database'; import { describe, expect, it, vi } from 'vitest'; import { MessageModel } from '@/database/models/message'; @@ -31,13 +31,18 @@ describe('AiChatService', () => { groupId: 'group-1', includeTopic: true, sessionId: 's1', + topicPageSize: 20, }); expect(mockQueryMessages).toHaveBeenCalledWith( { agentId: 'agent-1', groupId: 'group-1', includeTopic: true, sessionId: 's1' }, expect.objectContaining({ postProcessUrl: expect.any(Function) }), ); - expect(mockQueryTopics).toHaveBeenCalledWith({ agentId: 'agent-1', groupId: 'group-1' }); + expect(mockQueryTopics).toHaveBeenCalledWith({ + agentId: 'agent-1', + groupId: 'group-1', + pageSize: 20, + }); expect(res.messages).toEqual([{ id: 'm1' }]); expect(res.topics).toEqual([{ id: 't1' }]); }); @@ -63,6 +68,7 @@ describe('AiChatService', () => { excludeStatuses: ['completed'], excludeTriggers: ['cron', 'eval'], 
}, + topicPageSize: 20, }); expect(mockQueryTopics).toHaveBeenCalledWith({ @@ -70,12 +76,17 @@ describe('AiChatService', () => { excludeStatuses: ['completed'], excludeTriggers: ['cron', 'eval'], groupId: undefined, + pageSize: 20, }); // topicFilter must not leak into messageModel.query expect(mockQueryMessages).toHaveBeenCalledWith( expect.not.objectContaining({ topicFilter: expect.anything() }), expect.objectContaining({ postProcessUrl: expect.any(Function) }), ); + expect(mockQueryMessages).toHaveBeenCalledWith( + expect.not.objectContaining({ topicPageSize: 20 }), + expect.objectContaining({ postProcessUrl: expect.any(Function) }), + ); }); it('getMessagesAndTopics should not query topics when includeTopic is false', async () => { diff --git a/src/server/services/aiChat/index.ts b/src/server/services/aiChat/index.ts index a798011267..85a24c0ebb 100644 --- a/src/server/services/aiChat/index.ts +++ b/src/server/services/aiChat/index.ts @@ -1,9 +1,33 @@ -import { type LobeChatDatabase } from '@lobechat/database'; +import type { LobeChatDatabase } from '@lobechat/database'; +import { createTimingHelpers } from '@lobechat/utils'; import { MessageModel } from '@/database/models/message'; import { TopicModel } from '@/database/models/topic'; import { FileService } from '@/server/services/file'; +const { createPrefixedTimingContext, runTimedStage, toTimingContext } = createTimingHelpers( + 'lobe-server:chat:lobehub:timing', +); + +interface GetMessagesAndTopicsParams { + agentId?: string; + current?: number; + groupId?: string; + includeTopic?: boolean; + pageSize?: number; + sessionId?: string; + threadId?: string; + timingRequestId?: string; + timingStartedAt?: number; + topicFilter?: { + excludeStatuses?: string[]; + excludeTriggers?: string[]; + includeTriggers?: string[]; + }; + topicId?: string; + topicPageSize?: number; +} + export class AiChatService { private userId: string; private messageModel: MessageModel; @@ -18,32 +42,48 @@ export class AiChatService { 
this.fileService = new FileService(serverDB, userId); } - async getMessagesAndTopics(params: { - agentId?: string; - current?: number; - groupId?: string; - includeTopic?: boolean; - pageSize?: number; - sessionId?: string; - threadId?: string; - topicFilter?: { - excludeStatuses?: string[]; - excludeTriggers?: string[]; - includeTriggers?: string[]; - }; - topicId?: string; - }) { - const { topicFilter, ...messageParams } = params; + async getMessagesAndTopics(params: GetMessagesAndTopicsParams) { + const { topicFilter, topicPageSize, timingRequestId, timingStartedAt, ...messageParams } = + params; + const timingContext = toTimingContext({ timingRequestId, timingStartedAt }); + const messageTiming = createPrefixedTimingContext( + timingContext, + 'lambda.aiChat.messagesAndTopics.messageModel.query', + ); + const topicTiming = createPrefixedTimingContext( + timingContext, + 'lambda.aiChat.messagesAndTopics.topicModel.query', + ); + const messageQueryPromise = runTimedStage( + timingContext, + 'lambda.aiChat.messagesAndTopics.messageModel.query', + () => + this.messageModel.query(messageParams, { + postProcessUrl: (path) => this.fileService.getFullFileUrl(path), + ...(messageTiming ? { timing: messageTiming } : {}), + }), + { + hasAgentId: !!params.agentId, + hasThreadId: !!params.threadId, + hasTopicId: !!params.topicId, + }, + ); const [messages, topics] = await Promise.all([ - this.messageModel.query(messageParams, { - postProcessUrl: (path) => this.fileService.getFullFileUrl(path), - }), + messageQueryPromise, params.includeTopic - ? this.topicModel.query({ - agentId: params.agentId, - groupId: params.groupId, - ...topicFilter, - }) + ? runTimedStage( + timingContext, + 'lambda.aiChat.messagesAndTopics.topicModel.query', + () => + this.topicModel.query({ + agentId: params.agentId, + groupId: params.groupId, + pageSize: topicPageSize, + ...(topicTiming ? 
{ timing: topicTiming } : {}), + ...topicFilter, + }), + { hasAgentId: !!params.agentId, hasGroupId: !!params.groupId }, + ) : undefined, ]); diff --git a/src/server/services/message/index.ts b/src/server/services/message/index.ts index 07f55cdc64..f6d180774a 100644 --- a/src/server/services/message/index.ts +++ b/src/server/services/message/index.ts @@ -5,6 +5,7 @@ import { type UIChatMessage, type UpdateMessageParams, } from '@lobechat/types'; +import { createTimingHelpers, getDurationMs } from '@lobechat/utils'; import { MessageModel } from '@/database/models/message'; @@ -15,9 +16,26 @@ interface QueryOptions { groupId?: string | null; sessionId?: string | null; threadId?: string | null; + timingRequestId?: string; + timingStartedAt?: number; topicId?: string | null; } +const { createPrefixedTimingContext, logTiming, toTimingContext } = createTimingHelpers( + 'lobe-server:chat:lobehub:timing', +); + +const logMessageTiming = ( + options: QueryOptions | undefined, + event: string, + metadata?: Record<string, unknown>, +) => { + logTiming(toTimingContext(options), event, metadata); +}; + +const createModelTiming = (options: QueryOptions | undefined, prefix: string) => + createPrefixedTimingContext(toTimingContext(options), prefix); + interface CreateMessageResult { id: string; messages: any[]; @@ -70,15 +88,25 @@ export class MessageService { options.sessionId === undefined && options.topicId === undefined) ) { + logMessageTiming(options, 'lambda.message.update.queryMessages:skipped'); return { success: true }; } const { agentId, sessionId, topicId, groupId, threadId } = options; + const queryStartedAt = Date.now(); + const modelTiming = createModelTiming(options, 'lambda.message.update.queryMessages'); const messages = await this.messageModel.query( { agentId, groupId, sessionId, threadId, topicId }, - this.getQueryOptions(), + { + ...this.getQueryOptions(), + ...(modelTiming ? 
{ timing: modelTiming } : {}), + }, ); + logMessageTiming(options, 'lambda.message.update.queryMessages:done', { + messageCount: messages.length, + stageMs: getDurationMs(queryStartedAt), + }); return { messages, success: true }; } @@ -188,7 +216,18 @@ export class MessageService { value: UpdateMessageParams, options: QueryOptions, ): Promise<{ messages?: UIChatMessage[]; success: boolean }> { - await this.messageModel.update(id, value as any); + const updateStartedAt = Date.now(); + const modelTiming = createModelTiming(options, 'lambda.message.update.dbUpdate'); + if (modelTiming) { + await this.messageModel.update(id, value as any, modelTiming); + } else { + await this.messageModel.update(id, value as any); + } + logMessageTiming(options, 'lambda.message.update.dbUpdate:done', { + stageMs: getDurationMs(updateStartedAt), + valueKeys: Object.keys(value ?? {}), + }); + return this.queryWithSuccess(options); } diff --git a/src/store/chat/slices/aiChat/actions/__tests__/conversationLifecycle.test.ts b/src/store/chat/slices/aiChat/actions/__tests__/conversationLifecycle.test.ts index c754be8045..411fe5abcc 100644 --- a/src/store/chat/slices/aiChat/actions/__tests__/conversationLifecycle.test.ts +++ b/src/store/chat/slices/aiChat/actions/__tests__/conversationLifecycle.test.ts @@ -9,6 +9,7 @@ import { chatService } from '@/services/chat'; import { messageService } from '@/services/message'; import * as agentGroupStore from '@/store/agentGroup'; import { messageMapKey } from '@/store/chat/utils/messageMapKey'; +import { topicMapKey } from '@/store/chat/utils/topicMapKey'; import { getSessionStoreState } from '@/store/session'; import * as toolStoreModule from '@/store/tool'; @@ -1622,7 +1623,6 @@ describe('ConversationLifecycle actions', () => { createMockMessage({ id: 'new-user-msg', role: 'user', topicId: newTopicId }), createMockMessage({ id: 'new-assistant-msg', role: 'assistant', topicId: newTopicId }), ], - topics: { items: [{ id: newTopicId, title: 'New Topic' 
}], total: 1 }, topicId: newTopicId, isCreateNewTopic: true, assistantMessageId: 'new-assistant-msg', @@ -1648,6 +1648,12 @@ describe('ConversationLifecycle actions', () => { // After new topic creation, the _new key should be cleared const messagesInNewKey = useChatStore.getState().messagesMap[newKey]; expect(messagesInNewKey ?? []).toHaveLength(0); + + const newTopicKey = messageMapKey({ agentId, topicId: newTopicId }); + expect(useChatStore.getState().messagesMap[newTopicKey]).toHaveLength(2); + expect(useChatStore.getState().topicDataMap[topicMapKey({ agentId })]?.items[0]).toEqual( + expect.objectContaining({ id: newTopicId }), + ); }); }); }); diff --git a/src/store/chat/slices/aiChat/actions/conversationLifecycle.ts b/src/store/chat/slices/aiChat/actions/conversationLifecycle.ts index 185043e603..151d7ffbaf 100644 --- a/src/store/chat/slices/aiChat/actions/conversationLifecycle.ts +++ b/src/store/chat/slices/aiChat/actions/conversationLifecycle.ts @@ -523,6 +523,7 @@ export class ConversationLifecycleActionImpl { operationContext.agentId, operationContext.groupId ?? undefined, ), + topicPageSize: systemStatusSelectors.topicPageSize(useGlobalStore.getState()), topicId: operationContext.topicId ?? undefined, }, abortController, @@ -712,6 +713,7 @@ export class ConversationLifecycleActionImpl { const toolContext = formatSelectedToolsContext(dedupedTools); const contextSuffix = [skillContext, toolContext].filter(Boolean).join('\n'); const persistedContent = contextSuffix ? `${message}\n\n${contextSuffix}` : message; + const newTopicTitle = message.slice(0, 80) || t('defaultTitle', { ns: 'topic' }); data = await aiChatService.sendMessageInServer( { @@ -730,6 +732,7 @@ export class ConversationLifecycleActionImpl { operationContext.agentId, operationContext.groupId ?? undefined, ), + topicPageSize: systemStatusSelectors.topicPageSize(useGlobalStore.getState()), threadId: operationContext.threadId ?? 
undefined, // Support creating new thread along with message newThread: newThread @@ -741,7 +744,7 @@ export class ConversationLifecycleActionImpl { newTopic: !topicId ? { topicMessageIds: forceNewTopicFromExisting ? [] : messages.map((m) => m.id), - title: message.slice(0, 80) || t('defaultTitle', { ns: 'topic' }), + title: newTopicTitle, } : undefined, agentId: operationContext.agentId, @@ -757,7 +760,7 @@ export class ConversationLifecycleActionImpl { abortController, ); // Use created topicId/threadId if available, otherwise use original from context - let finalTopicId = operationContext.topicId; + let finalTopicId = data.topicId ?? operationContext.topicId; const finalThreadId = data.createdThreadId ?? operationContext.threadId; // refresh the total data @@ -780,6 +783,18 @@ export class ConversationLifecycleActionImpl { // Record the created topicId in metadata (not context) this.#get().updateOperationMetadata(operationId, { createdTopicId: data.topicId }); } + } else if (data.isCreateNewTopic && data.topicId && !context.isolatedTopic) { + this.#get().internal_dispatchTopic( + { + type: 'addTopic', + value: { + id: data.topicId, + title: newTopicTitle, + }, + }, + 'sendMessage/createTopicPlaceholder', + ); + this.#get().updateOperationMetadata(operationId, { createdTopicId: data.topicId }); } else if (operationContext.topicId) { // Optimistically update topic's updatedAt so sidebar re-groups immediately this.#get().internal_dispatchTopic({