️ perf: optimize chat bootstrap persistence (#14934)

This commit is contained in:
YuTengjing
2026-05-19 12:53:32 +08:00
committed by GitHub
parent 97ea30e48b
commit 391b16e082
18 changed files with 2239 additions and 627 deletions
@@ -1,5 +1,5 @@
import type { DBMessageItem } from '@lobechat/types'; import type { DBMessageItem } from '@lobechat/types';
import { eq } from 'drizzle-orm'; import { asc, eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { uuid } from '@/utils/uuid'; import { uuid } from '@/utils/uuid';
@@ -16,6 +16,7 @@ import {
messages, messages,
messagesFiles, messagesFiles,
sessions, sessions,
topics,
users, users,
} from '../../../schemas'; } from '../../../schemas';
import type { LobeChatDatabase } from '../../../type'; import type { LobeChatDatabase } from '../../../type';
@@ -248,6 +249,124 @@ describe('MessageModel Create Tests', () => {
expect(pluginResult[0].arguments).not.toContain('\u0000'); expect(pluginResult[0].arguments).not.toContain('\u0000');
}); });
it('should create user and assistant messages with one topic touch', async () => {
await serverDB.insert(topics).values({
id: 'topic-pair',
sessionId: '1',
title: 'Topic pair',
userId,
});
const timingEvents: string[] = [];
const result = await messageModel.createUserAndAssistantMessages(
{
assistantMessage: {
content: '',
model: 'gpt-4o',
provider: 'openai',
role: 'assistant',
sessionId: '1',
topicId: 'topic-pair',
},
userMessage: {
content: 'hello',
files: ['f1'],
role: 'user',
sessionId: '1',
topicId: 'topic-pair',
},
},
{
timing: {
log: (event) => timingEvents.push(event),
},
},
);
expect(result.userMessage.id).toBeDefined();
expect(result.assistantMessage.id).toBeDefined();
expect(result.assistantMessage.parentId).toBe(result.userMessage.id);
expect(result.userMessage.createdAt.getTime()).toBeLessThan(
result.assistantMessage.createdAt.getTime(),
);
const dbMessages = await serverDB
.select()
.from(messages)
.where(eq(messages.userId, userId))
.orderBy(asc(messages.createdAt));
expect(dbMessages.map((message) => message.id)).toEqual([
result.userMessage.id,
result.assistantMessage.id,
]);
const messageFiles = await serverDB
.select()
.from(messagesFiles)
.where(eq(messagesFiles.messageId, result.userMessage.id));
expect(messageFiles).toHaveLength(1);
expect(
timingEvents.filter(
(event) => event === 'db.message.createUserAndAssistant.messages.insert:start',
),
).toHaveLength(1);
expect(
timingEvents.filter(
(event) => event === 'db.message.createUserAndAssistant.topic.touchUpdatedAt:start',
),
).toHaveLength(1);
});
it('should skip topic touch when creating a pair for an already-created topic', async () => {
await serverDB.insert(topics).values({
id: 'topic-pair-no-touch',
sessionId: '1',
title: 'Topic pair no touch',
userId,
});
const timingEvents: string[] = [];
const result = await messageModel.createUserAndAssistantMessages(
{
assistantMessage: {
content: '',
model: 'gpt-4o',
provider: 'openai',
role: 'assistant',
sessionId: '1',
topicId: 'topic-pair-no-touch',
},
userMessage: {
content: 'hello',
role: 'user',
sessionId: '1',
topicId: 'topic-pair-no-touch',
},
},
{
timing: {
log: (event) => timingEvents.push(event),
},
touchTopicUpdatedAt: false,
},
);
expect(result.userMessage.id).toBeDefined();
expect(result.assistantMessage.parentId).toBe(result.userMessage.id);
expect(
timingEvents.filter(
(event) => event === 'db.message.createUserAndAssistant.messages.insert:start',
),
).toHaveLength(1);
expect(
timingEvents.filter(
(event) => event === 'db.message.createUserAndAssistant.topic.touchUpdatedAt:start',
),
).toHaveLength(0);
});
describe('create with advanced parameters', () => { describe('create with advanced parameters', () => {
it('should create a message with custom ID', async () => { it('should create a message with custom ID', async () => {
const customId = 'custom-msg-id'; const customId = 'custom-msg-id';
@@ -95,7 +95,10 @@ describe('TopicModel - Create', () => {
const topicId = 'new-topic'; const topicId = 'new-topic';
const createdTopic = await topicModel.create(topicData, topicId); const timingEvents: string[] = [];
const createdTopic = await topicModel.create(topicData, topicId, {
log: (event) => timingEvents.push(event),
});
expect(createdTopic).toEqual({ expect(createdTopic).toEqual({
id: topicId, id: topicId,
@@ -123,6 +126,8 @@ describe('TopicModel - Create', () => {
const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId)); const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId));
expect(dbTopic).toHaveLength(1); expect(dbTopic).toHaveLength(1);
expect(dbTopic[0]).toEqual(createdTopic); expect(dbTopic[0]).toEqual(createdTopic);
expect(timingEvents).toContain('db.topic.create.topics.insert:start');
expect(timingEvents).not.toContain('db.topic.create.transaction:start');
}); });
it('should create a new topic with agentId', async () => { it('should create a new topic with agentId', async () => {
File diff suppressed because it is too large Load Diff
+200 -92
View File
@@ -4,6 +4,12 @@ import type {
DBMessageItem, DBMessageItem,
TopicRankItem, TopicRankItem,
} from '@lobechat/types'; } from '@lobechat/types';
import type { TimingSink } from '@lobechat/utils';
import {
getDurationMs,
logTimingSink as logTiming,
runTimedSinkStage as runTimedStage,
} from '@lobechat/utils';
import type { SQL } from 'drizzle-orm'; import type { SQL } from 'drizzle-orm';
import { and, count, desc, eq, gt, gte, inArray, isNull, lte, ne, not, or, sql } from 'drizzle-orm'; import { and, count, desc, eq, gt, gte, inArray, isNull, lte, ne, not, or, sql } from 'drizzle-orm';
@@ -62,12 +68,15 @@ interface QueryTopicParams {
*/ */
isInbox?: boolean; isInbox?: boolean;
pageSize?: number; pageSize?: number;
timing?: ModelTimingContext;
/** /**
* Include only topics matching the given trigger types (positive filter) * Include only topics matching the given trigger types (positive filter)
*/ */
triggers?: string[]; triggers?: string[];
} }
export interface ModelTimingContext extends TimingSink {}
export interface ListTopicsForMemoryExtractorCursor { export interface ListTopicsForMemoryExtractorCursor {
createdAt: Date; createdAt: Date;
id: string; id: string;
@@ -93,8 +102,18 @@ export class TopicModel {
pageSize = 9999, pageSize = 9999,
groupId, groupId,
isInbox, isInbox,
timing,
triggers, triggers,
}: QueryTopicParams = {}) => { }: QueryTopicParams = {}) => {
const queryStartedAt = Date.now();
logTiming(timing, 'db.topic.query:start', {
current,
hasAgentId: !!agentId,
hasContainerId: !!containerId,
hasGroupId: !!groupId,
isInbox: !!isInbox,
pageSize,
});
const offset = current * pageSize; const offset = current * pageSize;
const includeTriggerCondition = const includeTriggerCondition =
includeTriggers && includeTriggers.length > 0 includeTriggers && includeTriggers.length > 0
@@ -127,29 +146,42 @@ export class TopicModel {
); );
const [items, totalResult] = await Promise.all([ const [items, totalResult] = await Promise.all([
this.db runTimedStage(
.select({ timing,
completedAt: topics.completedAt, 'db.topic.query.group.items.select',
createdAt: topics.createdAt, () =>
favorite: topics.favorite, this.db
historySummary: topics.historySummary, .select({
id: topics.id, completedAt: topics.completedAt,
metadata: topics.metadata, createdAt: topics.createdAt,
status: topics.status, favorite: topics.favorite,
title: topics.title, historySummary: topics.historySummary,
updatedAt: topics.updatedAt, id: topics.id,
}) metadata: topics.metadata,
.from(topics) status: topics.status,
.where(whereCondition) title: topics.title,
.orderBy(desc(topics.favorite), desc(topics.updatedAt)) updatedAt: topics.updatedAt,
.limit(pageSize) })
.offset(offset), .from(topics)
this.db .where(whereCondition)
.select({ count: count(topics.id) }) .orderBy(desc(topics.favorite), desc(topics.updatedAt))
.from(topics) .limit(pageSize)
.where(whereCondition), .offset(offset),
{ current, pageSize },
),
runTimedStage(timing, 'db.topic.query.group.count.select', () =>
this.db
.select({ count: count(topics.id) })
.from(topics)
.where(whereCondition),
),
]); ]);
logTiming(timing, 'db.topic.query:done', {
itemCount: items.length,
stageMs: getDurationMs(queryStartedAt),
total: totalResult[0].count,
});
return { items, total: totalResult[0].count }; return { items, total: totalResult[0].count };
} }
@@ -159,11 +191,19 @@ export class TopicModel {
// 3. For inbox: sessionId IS NULL AND groupId IS NULL AND agentId IS NULL (legacy inbox data) // 3. For inbox: sessionId IS NULL AND groupId IS NULL AND agentId IS NULL (legacy inbox data)
if (agentId) { if (agentId) {
// Get the associated sessionId for backward compatibility with legacy data // Get the associated sessionId for backward compatibility with legacy data
const agentSession = await this.db const agentSession = await runTimedStage(
.select({ sessionId: agentsToSessions.sessionId }) timing,
.from(agentsToSessions) 'db.topic.query.agentSession.select',
.where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId))) () =>
.limit(1); this.db
.select({ sessionId: agentsToSessions.sessionId })
.from(agentsToSessions)
.where(
and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)),
)
.limit(1),
{ hasAgentId: true },
);
const associatedSessionId = agentSession[0]?.sessionId; const associatedSessionId = agentSession[0]?.sessionId;
@@ -201,29 +241,46 @@ export class TopicModel {
); );
const [items, totalResult] = await Promise.all([ const [items, totalResult] = await Promise.all([
this.db runTimedStage(
.select({ timing,
completedAt: topics.completedAt, 'db.topic.query.agent.items.select',
createdAt: topics.createdAt, () =>
favorite: topics.favorite, this.db
historySummary: topics.historySummary, .select({
id: topics.id, completedAt: topics.completedAt,
metadata: topics.metadata, createdAt: topics.createdAt,
status: topics.status, favorite: topics.favorite,
title: topics.title, historySummary: topics.historySummary,
updatedAt: topics.updatedAt, id: topics.id,
}) metadata: topics.metadata,
.from(topics) status: topics.status,
.where(agentWhere) title: topics.title,
.orderBy(desc(topics.favorite), desc(topics.updatedAt)) updatedAt: topics.updatedAt,
.limit(pageSize) })
.offset(offset), .from(topics)
this.db .where(agentWhere)
.select({ count: count(topics.id) }) .orderBy(desc(topics.favorite), desc(topics.updatedAt))
.from(topics) .limit(pageSize)
.where(agentWhere), .offset(offset),
{ current, hasAssociatedSessionId: !!associatedSessionId, isInbox: !!isInbox, pageSize },
),
runTimedStage(
timing,
'db.topic.query.agent.count.select',
() =>
this.db
.select({ count: count(topics.id) })
.from(topics)
.where(agentWhere),
{ hasAssociatedSessionId: !!associatedSessionId, isInbox: !!isInbox },
),
]); ]);
logTiming(timing, 'db.topic.query:done', {
itemCount: items.length,
stageMs: getDurationMs(queryStartedAt),
total: totalResult[0].count,
});
return { items, total: totalResult[0].count }; return { items, total: totalResult[0].count };
} }
@@ -238,37 +295,51 @@ export class TopicModel {
); );
const [items, totalResult] = await Promise.all([ const [items, totalResult] = await Promise.all([
this.db runTimedStage(
.select({ timing,
agentId: topics.agentId, 'db.topic.query.container.items.select',
completedAt: topics.completedAt, () =>
createdAt: topics.createdAt, this.db
favorite: topics.favorite, .select({
historySummary: topics.historySummary, agentId: topics.agentId,
id: topics.id, completedAt: topics.completedAt,
metadata: topics.metadata, createdAt: topics.createdAt,
sessionId: topics.sessionId, favorite: topics.favorite,
status: topics.status, historySummary: topics.historySummary,
title: topics.title, id: topics.id,
updatedAt: topics.updatedAt, metadata: topics.metadata,
}) sessionId: topics.sessionId,
.from(topics) status: topics.status,
.where(whereCondition) title: topics.title,
// In boolean sorting, false is considered "smaller" than true. updatedAt: topics.updatedAt,
// So here we use desc to ensure that topics with favorite as true are in front. })
.orderBy(desc(topics.favorite), desc(topics.updatedAt)) .from(topics)
.limit(pageSize) .where(whereCondition)
.offset(offset), // In boolean sorting, false is considered "smaller" than true.
this.db // So here we use desc to ensure that topics with favorite as true are in front.
.select({ count: count(topics.id) }) .orderBy(desc(topics.favorite), desc(topics.updatedAt))
.from(topics) .limit(pageSize)
.where(whereCondition), .offset(offset),
{ current, pageSize },
),
runTimedStage(timing, 'db.topic.query.container.count.select', () =>
this.db
.select({ count: count(topics.id) })
.from(topics)
.where(whereCondition),
),
]); ]);
// Remove internal fields before returning // Remove internal fields before returning
const cleanItems = items.map(({ agentId, sessionId, ...rest }) => rest); const cleanItems = items.map(({ agentId, sessionId, ...rest }) => rest);
logTiming(timing, 'db.topic.query:done', {
itemCount: cleanItems.length,
stageMs: getDurationMs(queryStartedAt),
total: totalResult[0].count,
});
return { items: cleanItems, total: totalResult[0].count }; return { items: cleanItems, total: totalResult[0].count };
}; };
@@ -468,30 +539,67 @@ export class TopicModel {
create = async ( create = async (
{ messages: messageIds, ...params }: CreateTopicParams, { messages: messageIds, ...params }: CreateTopicParams,
id: string = this.genId(), id: string = this.genId(),
timing?: ModelTimingContext,
): Promise<TopicItem> => { ): Promise<TopicItem> => {
return this.db.transaction(async (tx) => { const insertData = {
const insertData = { ...params,
...params, agentId: params.agentId || null,
agentId: params.agentId || null, groupId: params.groupId || null,
groupId: params.groupId || null, id,
id, sessionId: params.sessionId || null,
sessionId: params.sessionId || null, userId: this.userId,
userId: this.userId, };
}; const insertMeta = {
hasAgentId: !!params.agentId,
hasGroupId: !!params.groupId,
hasSessionId: !!params.sessionId,
};
// Insert new topic if (!messageIds || messageIds.length === 0) {
const [topic] = await tx.insert(topics).values(insertData).returning(); const [topic] = await runTimedStage(
timing,
// Update associated messages' topicId 'db.topic.create.topics.insert',
if (messageIds && messageIds.length > 0) { () => this.db.insert(topics).values(insertData).returning(),
await tx insertMeta,
.update(messages) );
.set({ topicId: topic.id })
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
}
return topic; return topic;
}); }
return runTimedStage(
timing,
'db.topic.create.transaction',
() =>
this.db.transaction(async (tx) => {
// Insert new topic
const [topic] = await runTimedStage(
timing,
'db.topic.create.topics.insert',
() => tx.insert(topics).values(insertData).returning(),
insertMeta,
);
// Update associated messages' topicId
await runTimedStage(
timing,
'db.topic.create.messages.updateTopic',
() =>
tx
.update(messages)
.set({ topicId: topic.id })
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))),
{ messageCount: messageIds.length },
);
return topic;
}),
{
hasAgentId: !!params.agentId,
hasGroupId: !!params.groupId,
hasSessionId: !!params.sessionId,
messageCount: messageIds?.length ?? 0,
},
);
}; };
batchCreate = async (topicParams: (CreateTopicParams & { id?: string })[]) => { batchCreate = async (topicParams: (CreateTopicParams & { id?: string })[]) => {
+112 -2
View File
@@ -1,4 +1,5 @@
import type { ModelUsage, TracePayload } from '@lobechat/types'; import type { ModelUsage, TracePayload } from '@lobechat/types';
import { createTimingHelpers, getDurationMs } from '@lobechat/utils';
import type { ClientOptions } from 'openai'; import type { ClientOptions } from 'openai';
import type { LobeBedrockAIParams } from '../providers/bedrock'; import type { LobeBedrockAIParams } from '../providers/bedrock';
@@ -32,6 +33,13 @@ import type {
import { AgentRuntimeError } from '../utils/createError'; import { AgentRuntimeError } from '../utils/createError';
import type { LobeRuntimeAI } from './BaseAI'; import type { LobeRuntimeAI } from './BaseAI';
const { logger: timing } = createTimingHelpers('lobe-server:chat:lobehub:timing');
const getLobeHubTimingMetadata = (options?: {
metadata?: Record<string, unknown>;
}): Record<string, unknown> | undefined =>
options?.metadata?.provider === 'lobehub' ? options.metadata : undefined;
export interface AgentChatOptions { export interface AgentChatOptions {
enableTrace?: boolean; enableTrace?: boolean;
provider: string; provider: string;
@@ -126,6 +134,17 @@ export class ModelRuntime {
* ``` * ```
*/ */
async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) { async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
const metadata = getLobeHubTimingMetadata(options);
const startedAt = Date.now();
if (metadata) {
timing(
'ModelRuntime.chat start model=%s trigger=%s traceId=%s',
payload.model,
metadata.trigger,
metadata.traceId,
);
}
if (typeof this._runtime.chat !== 'function') { if (typeof this._runtime.chat !== 'function') {
throw AgentRuntimeError.chat({ throw AgentRuntimeError.chat({
error: new Error('Chat is not supported by this provider'), error: new Error('Chat is not supported by this provider'),
@@ -135,11 +154,48 @@ export class ModelRuntime {
} }
try { try {
const hooksStartedAt = Date.now();
const finalOptions = await this.applyHooks(payload, options); const finalOptions = await this.applyHooks(payload, options);
return await this._runtime.chat(payload, finalOptions); if (metadata) {
timing(
'ModelRuntime.chat hooks done model=%s durationMs=%d traceId=%s',
payload.model,
getDurationMs(hooksStartedAt),
metadata.traceId,
);
}
const runtimeStartedAt = Date.now();
const response = await this._runtime.chat(payload, finalOptions);
if (metadata) {
timing(
'ModelRuntime.chat runtime done model=%s durationMs=%d totalMs=%d traceId=%s',
payload.model,
getDurationMs(runtimeStartedAt),
getDurationMs(startedAt),
metadata.traceId,
);
}
return response;
} catch (error) { } catch (error) {
if (metadata) {
timing(
'ModelRuntime.chat error model=%s durationMs=%d traceId=%s',
payload.model,
getDurationMs(startedAt),
metadata.traceId,
);
}
if (this._hooks?.onChatError) { if (this._hooks?.onChatError) {
const errorHookStartedAt = Date.now();
await this._hooks.onChatError(error as ChatCompletionErrorPayload, { options, payload }); await this._hooks.onChatError(error as ChatCompletionErrorPayload, { options, payload });
if (metadata) {
timing(
'ModelRuntime.chat onChatError done model=%s durationMs=%d traceId=%s',
payload.model,
getDurationMs(errorHookStartedAt),
metadata.traceId,
);
}
} }
throw error; throw error;
} }
@@ -152,7 +208,37 @@ export class ModelRuntime {
payload: ChatStreamPayload, payload: ChatStreamPayload,
options?: ChatMethodOptions, options?: ChatMethodOptions,
): Promise<ChatMethodOptions | undefined> { ): Promise<ChatMethodOptions | undefined> {
await this._hooks?.beforeChat?.(payload, options); const metadata = getLobeHubTimingMetadata(options);
const beforeChatStartedAt = Date.now();
if (metadata) {
timing(
'ModelRuntime.beforeChat start model=%s trigger=%s traceId=%s',
payload.model,
metadata.trigger,
metadata.traceId,
);
}
try {
await this._hooks?.beforeChat?.(payload, options);
} catch (error) {
if (metadata) {
timing(
'ModelRuntime.beforeChat error model=%s durationMs=%d traceId=%s',
payload.model,
getDurationMs(beforeChatStartedAt),
metadata.traceId,
);
}
throw error;
}
if (metadata) {
timing(
'ModelRuntime.beforeChat done model=%s durationMs=%d traceId=%s',
payload.model,
getDurationMs(beforeChatStartedAt),
metadata.traceId,
);
}
if (!this._hooks?.onChatFinal) return options; if (!this._hooks?.onChatFinal) return options;
@@ -163,10 +249,34 @@ export class ModelRuntime {
callback: { callback: {
...options?.callback, ...options?.callback,
async onFinal(data) { async onFinal(data) {
const finalStartedAt = Date.now();
if (metadata) {
timing(
'ModelRuntime.onChatFinal start model=%s traceId=%s',
payload.model,
metadata.traceId,
);
}
await existingOnFinal?.(data); await existingOnFinal?.(data);
try { try {
await hookFn(data, { options, payload }); await hookFn(data, { options, payload });
if (metadata) {
timing(
'ModelRuntime.onChatFinal done model=%s durationMs=%d traceId=%s',
payload.model,
getDurationMs(finalStartedAt),
metadata.traceId,
);
}
} catch (e) { } catch (e) {
if (metadata) {
timing(
'ModelRuntime.onChatFinal error model=%s durationMs=%d traceId=%s',
payload.model,
getDurationMs(finalStartedAt),
metadata.traceId,
);
}
// Hook failures (billing, tracing) must not interfere with response completion // Hook failures (billing, tracing) must not interfere with response completion
console.error('[ModelRuntime] onChatFinal hook error:', e); console.error('[ModelRuntime] onChatFinal hook error:', e);
} }
@@ -4,6 +4,7 @@
import type { GoogleGenAIOptions } from '@google/genai'; import type { GoogleGenAIOptions } from '@google/genai';
import type { ChatModelCard } from '@lobechat/types'; import type { ChatModelCard } from '@lobechat/types';
import { AgentRuntimeErrorType } from '@lobechat/types'; import { AgentRuntimeErrorType } from '@lobechat/types';
import { createTimingHelpers, getDurationMs } from '@lobechat/utils';
import debug from 'debug'; import debug from 'debug';
import type { ClientOptions } from 'openai'; import type { ClientOptions } from 'openai';
import type OpenAI from 'openai'; import type OpenAI from 'openai';
@@ -44,6 +45,7 @@ import type {
import type { ApiType, RuntimeClass } from './apiTypes'; import type { ApiType, RuntimeClass } from './apiTypes';
const log = debug('lobe-model-runtime:router-runtime'); const log = debug('lobe-model-runtime:router-runtime');
const { logger: timing } = createTimingHelpers('lobe-server:chat:lobehub:timing');
interface ProviderIniOptions extends Record<string, any> { interface ProviderIniOptions extends Record<string, any> {
accessKeyId?: string; accessKeyId?: string;
@@ -190,6 +192,7 @@ export const createRouterRuntime = ({
private _id: string; private _id: string;
constructor(options: ClientOptions & Record<string, any> = {}) { constructor(options: ClientOptions & Record<string, any> = {}) {
const startedAt = Date.now();
this._options = { this._options = {
...options, ...options,
apiKey: options.apiKey?.trim() || DEFAULT_API_KEY, apiKey: options.apiKey?.trim() || DEFAULT_API_KEY,
@@ -200,36 +203,76 @@ export const createRouterRuntime = ({
this._routers = routers; this._routers = routers;
this._params = params; this._params = params;
this._id = options.id ?? id; this._id = options.id ?? id;
if (this._id === 'lobehub') {
timing(
'constructor done providerId=%s durationMs=%d hasApiKey=%s hasBaseURL=%s',
this._id,
getDurationMs(startedAt),
!!this._options.apiKey,
!!this._options.baseURL,
);
}
} }
/** /**
* Resolve routers configuration and validate * Resolve routers configuration and validate
*/ */
private async resolveRouters(model?: string): Promise<RouterInstance[]> { private async resolveRouters(model?: string): Promise<RouterInstance[]> {
const resolvedRouters = const startedAt = Date.now();
typeof this._routers === 'function' try {
? await this._routers(this._options, { model }) const resolvedRouters =
: this._routers; typeof this._routers === 'function'
? await this._routers(this._options, { model })
: this._routers;
if (resolvedRouters.length === 0) { if (this._id === 'lobehub') {
throw AgentRuntimeError.chat({ timing(
error: { message: 'empty providers' }, 'resolveRouters done model=%s durationMs=%d routerCount=%d dynamic=%s',
errorType: AgentRuntimeErrorType.NoAvailableProvider, model,
provider: this._id, getDurationMs(startedAt),
}); resolvedRouters.length,
typeof this._routers === 'function',
);
}
if (resolvedRouters.length === 0) {
throw AgentRuntimeError.chat({
error: { message: 'empty providers' },
errorType: AgentRuntimeErrorType.NoAvailableProvider,
provider: this._id,
});
}
return resolvedRouters;
} catch (error) {
if (this._id === 'lobehub') {
timing('resolveRouters error model=%s durationMs=%d', model, getDurationMs(startedAt));
}
throw error;
} }
return resolvedRouters;
} }
private async resolveMatchedRouter(model: string): Promise<RouterInstance> { private async resolveMatchedRouter(model: string): Promise<RouterInstance> {
const startedAt = Date.now();
const resolvedRouters = await this.resolveRouters(model); const resolvedRouters = await this.resolveRouters(model);
const baseURL = this._options.baseURL; const baseURL = this._options.baseURL;
// Priority 1: Match by baseURLPattern (RegExp only) // Priority 1: Match by baseURLPattern (RegExp only)
if (baseURL) { if (baseURL) {
const baseURLMatch = resolvedRouters.find((router) => router.baseURLPattern?.test(baseURL)); const baseURLMatch = resolvedRouters.find((router) => router.baseURLPattern?.test(baseURL));
if (baseURLMatch) return baseURLMatch; if (baseURLMatch) {
if (this._id === 'lobehub') {
timing(
'resolveMatchedRouter done model=%s match=baseURL routerId=%s apiType=%s durationMs=%d',
model,
baseURLMatch.id,
baseURLMatch.apiType,
getDurationMs(startedAt),
);
}
return baseURLMatch;
}
} }
// Priority 2: Match by models // Priority 2: Match by models
@@ -239,19 +282,50 @@ export const createRouterRuntime = ({
} }
return false; return false;
}); });
if (modelMatch) return modelMatch; if (modelMatch) {
if (this._id === 'lobehub') {
timing(
'resolveMatchedRouter done model=%s match=models routerId=%s apiType=%s durationMs=%d',
model,
modelMatch.id,
modelMatch.apiType,
getDurationMs(startedAt),
);
}
return modelMatch;
}
// Fallback: Use the last router // Fallback: Use the last router
return resolvedRouters.at(-1)!; const fallbackRouter = resolvedRouters.at(-1)!;
if (this._id === 'lobehub') {
timing(
'resolveMatchedRouter done model=%s match=fallback routerId=%s apiType=%s durationMs=%d',
model,
fallbackRouter.id,
fallbackRouter.apiType,
getDurationMs(startedAt),
);
}
return fallbackRouter;
} }
private normalizeRouterOptions(router: RouterInstance): RouterOptionItem[] { private normalizeRouterOptions(router: RouterInstance): RouterOptionItem[] {
const startedAt = Date.now();
const routerOptions = Array.isArray(router.options) ? router.options : [router.options]; const routerOptions = Array.isArray(router.options) ? router.options : [router.options];
if (routerOptions.length === 0 || routerOptions.some((optionItem) => !optionItem)) { if (routerOptions.length === 0 || routerOptions.some((optionItem) => !optionItem)) {
throw new Error('empty provider options'); throw new Error('empty provider options');
} }
if (this._id === 'lobehub') {
timing(
'normalizeRouterOptions done routerId=%s options=%d durationMs=%d',
router.id,
routerOptions.length,
getDurationMs(startedAt),
);
}
return routerOptions; return routerOptions;
} }
@@ -268,6 +342,7 @@ export const createRouterRuntime = ({
remark?: string; remark?: string;
runtime: LobeRuntimeAI; runtime: LobeRuntimeAI;
}> { }> {
const startedAt = Date.now();
const { apiType: optionApiType, id: channelId, remark, ...optionOverrides } = optionItem; const { apiType: optionApiType, id: channelId, remark, ...optionOverrides } = optionItem;
const resolvedApiType = optionApiType ?? router.apiType; const resolvedApiType = optionApiType ?? router.apiType;
const finalOptions = { const finalOptions = {
@@ -297,6 +372,16 @@ export const createRouterRuntime = ({
if (project) vertexOptions.project = project; if (project) vertexOptions.project = project;
if (location) vertexOptions.location = location as GoogleGenAIOptions['location']; if (location) vertexOptions.location = location as GoogleGenAIOptions['location'];
if (this._id === 'lobehub') {
timing(
'createRuntimeFromOption done routerId=%s channelId=%s apiType=%s durationMs=%d vertex=true',
router.id,
channelId,
resolvedApiType,
getDurationMs(startedAt),
);
}
return { return {
channelId, channelId,
id: resolvedApiType, id: resolvedApiType,
@@ -312,6 +397,16 @@ export const createRouterRuntime = ({
: (baseRuntimeMap[resolvedApiType] ?? LobeOpenAI); : (baseRuntimeMap[resolvedApiType] ?? LobeOpenAI);
const runtime: LobeRuntimeAI = new providerAI({ ...finalOptions, id: this._id }); const runtime: LobeRuntimeAI = new providerAI({ ...finalOptions, id: this._id });
if (this._id === 'lobehub') {
timing(
'createRuntimeFromOption done routerId=%s channelId=%s apiType=%s durationMs=%d',
router.id,
channelId,
resolvedApiType,
getDurationMs(startedAt),
);
}
return { return {
channelId, channelId,
id: resolvedApiType, id: resolvedApiType,
@@ -325,10 +420,22 @@ export const createRouterRuntime = ({
requestHandler: (runtime: LobeRuntimeAI) => Promise<T>, requestHandler: (runtime: LobeRuntimeAI) => Promise<T>,
metadata?: Record<string, unknown>, metadata?: Record<string, unknown>,
): Promise<T> { ): Promise<T> {
const totalStartedAt = Date.now();
const matchedRouter = await this.resolveMatchedRouter(model); const matchedRouter = await this.resolveMatchedRouter(model);
const routerOptions = this.normalizeRouterOptions(matchedRouter); const routerOptions = this.normalizeRouterOptions(matchedRouter);
const totalOptions = routerOptions.length; const totalOptions = routerOptions.length;
if (this._id === 'lobehub') {
timing(
'runWithFallback start model=%s routerId=%s apiType=%s options=%d traceId=%s',
model,
matchedRouter.id,
matchedRouter.apiType,
totalOptions,
metadata?.traceId,
);
}
log( log(
'resolve router for model=%s apiType=%s options=%d', 'resolve router for model=%s apiType=%s options=%d',
model, model,
@@ -349,7 +456,33 @@ export const createRouterRuntime = ({
} = await this.createRuntimeFromOption(matchedRouter, optionItem); } = await this.createRuntimeFromOption(matchedRouter, optionItem);
try { try {
if (this._id === 'lobehub') {
timing(
'attempt request start model=%s attempt=%d/%d routerId=%s channelId=%s apiType=%s traceId=%s',
model,
attempt,
totalOptions,
matchedRouter.id,
channelId,
resolvedApiType,
metadata?.traceId,
);
}
const result = await requestHandler(runtime); const result = await requestHandler(runtime);
if (this._id === 'lobehub') {
timing(
'attempt request success model=%s attempt=%d/%d routerId=%s channelId=%s apiType=%s durationMs=%d totalMs=%d traceId=%s',
model,
attempt,
totalOptions,
matchedRouter.id,
channelId,
resolvedApiType,
getDurationMs(startTime),
getDurationMs(totalStartedAt),
metadata?.traceId,
);
}
if (totalOptions > 1 && attempt > 1) { if (totalOptions > 1 && attempt > 1) {
log( log(
@@ -392,6 +525,20 @@ export const createRouterRuntime = ({
return result; return result;
} catch (error) { } catch (error) {
lastError = error; lastError = error;
if (this._id === 'lobehub') {
timing(
'attempt request error model=%s attempt=%d/%d routerId=%s channelId=%s apiType=%s durationMs=%d totalMs=%d traceId=%s',
model,
attempt,
totalOptions,
matchedRouter.id,
channelId,
resolvedApiType,
getDurationMs(startTime),
getDurationMs(totalStartedAt),
metadata?.traceId,
);
}
params params
.onRouteAttempt?.({ .onRouteAttempt?.({
@@ -417,6 +564,7 @@ export const createRouterRuntime = ({
} }
try { try {
const shouldStopStartedAt = Date.now();
const shouldStopFallback = await params.shouldStopFallback?.({ const shouldStopFallback = await params.shouldStopFallback?.({
error, error,
metadata, metadata,
@@ -424,6 +572,18 @@ export const createRouterRuntime = ({
optionIndex: index, optionIndex: index,
}); });
if (this._id === 'lobehub') {
timing(
'shouldStopFallback done model=%s attempt=%d/%d durationMs=%d shouldStop=%s traceId=%s',
model,
attempt,
totalOptions,
getDurationMs(shouldStopStartedAt),
shouldStopFallback,
metadata?.traceId,
);
}
if (shouldStopFallback) { if (shouldStopFallback) {
throw error; throw error;
} }
@@ -460,6 +620,17 @@ export const createRouterRuntime = ({
} }
} }
if (this._id === 'lobehub') {
timing(
'runWithFallback failed model=%s routerId=%s options=%d totalMs=%d traceId=%s',
model,
matchedRouter.id,
totalOptions,
getDurationMs(totalStartedAt),
metadata?.traceId,
);
}
throw lastError ?? new Error('empty provider options'); throw lastError ?? new Error('empty provider options');
} }
+21
View File
@@ -0,0 +1,21 @@
import { describe, expect, it } from 'vitest';
import { AiSendMessageServerSchema } from './aiChat';
const createInput = (topicPageSize: number) => ({
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newUserMessage: { content: 'hello' },
topicPageSize,
});
describe('AiSendMessageServerSchema', () => {
it('should only accept positive integer topic page sizes up to 100', () => {
for (const topicPageSize of [1, 20, 100]) {
expect(AiSendMessageServerSchema.safeParse(createInput(topicPageSize)).success).toBe(true);
}
for (const topicPageSize of [-1, 0, 1.5, 101]) {
expect(AiSendMessageServerSchema.safeParse(createInput(topicPageSize)).success).toBe(false);
}
});
});
+5
View File
@@ -96,6 +96,10 @@ export interface SendMessageServerParams {
}; };
// if there is activeTopicId, then add topicId to message // if there is activeTopicId, then add topicId to message
topicId?: string; topicId?: string;
/**
* Page size for the topic list returned after creating a new topic.
*/
topicPageSize?: number;
} }
export const CreateThreadWithMessageSchema = z.object({ export const CreateThreadWithMessageSchema = z.object({
@@ -156,6 +160,7 @@ export const AiSendMessageServerSchema = z.object({
includeTriggers: z.array(z.string()).optional(), includeTriggers: z.array(z.string()).optional(),
}) })
.optional(), .optional(),
topicPageSize: z.number().int().min(1).max(100).optional(),
topicId: z.string().optional(), topicId: z.string().optional(),
}); });
+1
View File
@@ -20,6 +20,7 @@ export * from './pricing';
export * from './safeParseJSON'; export * from './safeParseJSON';
export * from './sanitizeToolCallArguments'; export * from './sanitizeToolCallArguments';
export * from './sleep'; export * from './sleep';
export * from './timing';
export * from './uriParser'; export * from './uriParser';
export * from './url'; export * from './url';
export * from './uuid'; export * from './uuid';
+173
View File
@@ -0,0 +1,173 @@
import debug from 'debug';
export interface TimingContext {
requestId: string;
startedAt: number;
}
export interface TimingMetadata {
[key: string]: unknown;
}
export interface TimingParams {
timingRequestId?: string;
timingStartedAt?: number;
}
export interface TimingSink {
log: (event: string, metadata?: TimingMetadata) => void;
}
export type TimingLogger = (formatter: string, ...args: unknown[]) => void;
export const createDebugTimingLogger = (namespace: string): TimingLogger => debug(namespace);
export const getDurationMs = (startedAt: number) => Date.now() - startedAt;
export const createTimingRequestId = () =>
globalThis.crypto?.randomUUID?.() ??
`${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`;
const isRecord = (value: unknown): value is Record<string, unknown> =>
!!value && typeof value === 'object';
export const getTimingErrorMetadata = (error: unknown): TimingMetadata => {
if (error instanceof Error) {
return {
errorMessage: error.message,
errorName: error.name,
};
}
if (isRecord(error)) {
return {
errorType: typeof error.errorType === 'string' ? error.errorType : undefined,
status: typeof error.status === 'number' ? error.status : undefined,
};
}
return { errorMessage: String(error) };
};
export const toTimingContext = (params?: TimingParams): TimingContext | undefined =>
params?.timingRequestId
? { requestId: params.timingRequestId, startedAt: params.timingStartedAt ?? Date.now() }
: undefined;
export const logTiming = (
logger: TimingLogger,
context: TimingContext | undefined,
event: string,
metadata?: TimingMetadata,
) => {
if (!context) return;
const totalMs = getDurationMs(context.startedAt);
if (metadata) {
logger('[%s] %s totalMs=%d %O', context.requestId, event, totalMs, metadata);
return;
}
logger('[%s] %s totalMs=%d', context.requestId, event, totalMs);
};
export const logTimingSink = (
timing: TimingSink | undefined,
event: string,
metadata?: TimingMetadata,
) => {
timing?.log(event, metadata);
};
export const runTimedStage = async <T>(
logger: TimingLogger,
context: TimingContext | undefined,
stage: string,
task: () => T | Promise<T>,
metadata?: TimingMetadata,
): Promise<Awaited<T>> => {
if (!context) return await task();
const startedAt = Date.now();
logTiming(logger, context, `${stage}:start`, metadata);
try {
const result = await task();
logTiming(logger, context, `${stage}:done`, {
...metadata,
stageMs: getDurationMs(startedAt),
});
return result;
} catch (error) {
logTiming(logger, context, `${stage}:error`, {
...metadata,
...getTimingErrorMetadata(error),
stageMs: getDurationMs(startedAt),
});
throw error;
}
};
export const runTimedSinkStage = async <T>(
timing: TimingSink | undefined,
stage: string,
task: () => T | Promise<T>,
metadata?: TimingMetadata,
): Promise<Awaited<T>> => {
if (!timing) return await task();
const startedAt = Date.now();
logTimingSink(timing, `${stage}:start`, metadata);
try {
const result = await task();
logTimingSink(timing, `${stage}:done`, {
...metadata,
stageMs: getDurationMs(startedAt),
});
return result;
} catch (error) {
logTimingSink(timing, `${stage}:error`, {
...metadata,
...getTimingErrorMetadata(error),
stageMs: getDurationMs(startedAt),
});
throw error;
}
};
export const createPrefixedTimingContext = (
logger: TimingLogger,
context: TimingContext | undefined,
prefix: string,
): TimingSink | undefined =>
context
? {
log: (event: string, metadata?: TimingMetadata) => {
logTiming(logger, context, `${prefix}.${event}`, metadata);
},
}
: undefined;
export const createTimingHelpers = (namespace: string) => {
const logger = createDebugTimingLogger(namespace);
return {
createPrefixedTimingContext: (context: TimingContext | undefined, prefix: string) =>
createPrefixedTimingContext(logger, context, prefix),
logger,
logTiming: (context: TimingContext | undefined, event: string, metadata?: TimingMetadata) =>
logTiming(logger, context, event, metadata),
runTimedStage: <T>(
context: TimingContext | undefined,
stage: string,
task: () => T | Promise<T>,
metadata?: TimingMetadata,
) => runTimedStage(logger, context, stage, task, metadata),
toTimingContext,
};
};
@@ -1,4 +1,5 @@
// @vitest-environment node // @vitest-environment node
import type { CreateMessageParams } from '@lobechat/types';
import { ThreadType } from '@lobechat/types'; import { ThreadType } from '@lobechat/types';
import { describe, expect, it, vi } from 'vitest'; import { describe, expect, it, vi } from 'vitest';
@@ -10,6 +11,8 @@ import { AiChatService } from '@/server/services/aiChat';
import { aiChatRouter } from '../aiChat'; import { aiChatRouter } from '../aiChat';
const flushAsyncTasks = () => new Promise<void>((resolve) => setTimeout(resolve, 0));
vi.mock('@/database/models/agent'); vi.mock('@/database/models/agent');
vi.mock('@/database/models/message'); vi.mock('@/database/models/message');
vi.mock('@/database/models/thread'); vi.mock('@/database/models/thread');
@@ -24,6 +27,38 @@ vi.mock('@/server/modules/ModelRuntime', () => ({
describe('aiChatRouter', () => { describe('aiChatRouter', () => {
const mockCtx = { userId: 'u1' }; const mockCtx = { userId: 'u1' };
const mockMessageModel = (mockCreateMessage: ReturnType<typeof vi.fn>) => {
const mockCreateUserAndAssistantMessages = vi.fn(
async (
{
assistantMessage,
userMessage,
}: {
assistantMessage: CreateMessageParams;
userMessage: CreateMessageParams;
},
_options?: unknown,
) => {
const userMessageItem = await mockCreateMessage(userMessage);
const assistantMessageItem = await mockCreateMessage({
...assistantMessage,
parentId: userMessageItem.id,
});
return { assistantMessage: assistantMessageItem, userMessage: userMessageItem };
},
);
vi.mocked(MessageModel).mockImplementation(
() =>
({
create: mockCreateMessage,
createUserAndAssistantMessages: mockCreateUserAndAssistantMessages,
}) as any,
);
return mockCreateUserAndAssistantMessages;
};
it('should create topic optionally, create user/assistant messages, and return payload', async () => { it('should create topic optionally, create user/assistant messages, and return payload', async () => {
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' }); const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
@@ -37,7 +72,7 @@ describe('aiChatRouter', () => {
}); });
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -47,6 +82,7 @@ describe('aiChatRouter', () => {
newTopic: { title: 'T', topicMessageIds: ['a', 'b'] }, newTopic: { title: 'T', topicMessageIds: ['a', 'b'] },
newUserMessage: { content: 'hi', files: ['f1'] }, newUserMessage: { content: 'hi', files: ['f1'] },
sessionId: 's1', sessionId: 's1',
topicPageSize: 20,
} as any; } as any;
const res = await caller.sendMessageInServer(input); const res = await caller.sendMessageInServer(input);
@@ -79,9 +115,19 @@ describe('aiChatRouter', () => {
topicId: 't1', topicId: 't1',
}), }),
); );
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledTimes(1);
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({ touchTopicUpdatedAt: false }),
);
expect(mockGet).toHaveBeenCalledWith( expect(mockGet).toHaveBeenCalledWith(
expect.objectContaining({ includeTopic: true, sessionId: 's1', topicId: 't1' }), expect.objectContaining({
includeTopic: true,
sessionId: 's1',
topicId: 't1',
topicPageSize: 20,
}),
); );
expect(res.assistantMessageId).toBe('m-assistant'); expect(res.assistantMessageId).toBe('m-assistant');
expect(res.userMessageId).toBe('m-user'); expect(res.userMessageId).toBe('m-user');
@@ -99,7 +145,7 @@ describe('aiChatRouter', () => {
.mockResolvedValueOnce({ id: 'm-assistant' }); .mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -112,6 +158,10 @@ describe('aiChatRouter', () => {
} as any); } as any);
expect(mockCreateMessage).toHaveBeenCalled(); expect(mockCreateMessage).toHaveBeenCalled();
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({ touchTopicUpdatedAt: true }),
);
expect(mockGet).toHaveBeenCalledWith( expect(mockGet).toHaveBeenCalledWith(
expect.objectContaining({ expect.objectContaining({
includeTopic: false, includeTopic: false,
@@ -130,7 +180,7 @@ describe('aiChatRouter', () => {
.mockResolvedValueOnce({ id: 'm-assistant' }); .mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -175,7 +225,7 @@ describe('aiChatRouter', () => {
.mockResolvedValueOnce({ id: 'm-assistant' }); .mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -282,7 +332,7 @@ describe('aiChatRouter', () => {
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(ThreadModel).mockImplementation(() => ({ create: mockCreateThread }) as any); vi.mocked(ThreadModel).mockImplementation(() => ({ create: mockCreateThread }) as any);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -346,7 +396,7 @@ describe('aiChatRouter', () => {
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
vi.mocked(ThreadModel).mockImplementation(() => ({ create: mockCreateThread }) as any); vi.mocked(ThreadModel).mockImplementation(() => ({ create: mockCreateThread }) as any);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -402,7 +452,7 @@ describe('aiChatRouter', () => {
.mockResolvedValueOnce({ id: 'm-assistant' }); .mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -427,7 +477,7 @@ describe('aiChatRouter', () => {
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -459,7 +509,7 @@ describe('aiChatRouter', () => {
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -489,7 +539,7 @@ describe('aiChatRouter', () => {
.mockResolvedValueOnce({ id: 'm-assistant' }); .mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -537,7 +587,7 @@ describe('aiChatRouter', () => {
.mockResolvedValueOnce({ id: 'm-assistant' }); .mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -569,7 +619,7 @@ describe('aiChatRouter', () => {
.mockResolvedValueOnce({ id: 'm-assistant' }); .mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -621,7 +671,7 @@ describe('aiChatRouter', () => {
.mockResolvedValueOnce({ id: 'm-assistant' }); .mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const caller = aiChatRouter.createCaller(mockCtx as any); const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -677,7 +727,7 @@ describe('aiChatRouter', () => {
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined); const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
vi.mocked(AgentModel).mockImplementation( vi.mocked(AgentModel).mockImplementation(
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
@@ -713,7 +763,7 @@ describe('aiChatRouter', () => {
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined); const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
vi.mocked(AgentModel).mockImplementation( vi.mocked(AgentModel).mockImplementation(
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
@@ -733,6 +783,94 @@ describe('aiChatRouter', () => {
expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1'); expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1');
}); });
it('should keep the message response when agent updatedAt touch fails', async () => {
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => undefined);
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
const mockCreateMessage = vi
.fn()
.mockResolvedValueOnce({ id: 'm-user' })
.mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({
messages: [{ id: 'm-user' }, { id: 'm-assistant' }],
topics: undefined,
});
const touchError = new Error('touch failed');
const mockTouchUpdatedAt = vi.fn().mockRejectedValue(touchError);
try {
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(
() => ({ getMessagesAndTopics: mockGet }) as any,
);
vi.mocked(AgentModel).mockImplementation(
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
);
const caller = aiChatRouter.createCaller(mockCtx as any);
const res = await caller.sendMessageInServer({
agentId: 'agent-1',
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newTopic: { title: 'New Topic' },
newUserMessage: { content: 'hi' },
sessionId: 's1',
} as any);
expect(res.userMessageId).toBe('m-user');
expect(res.assistantMessageId).toBe('m-assistant');
expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1');
expect(consoleErrorSpy).toHaveBeenCalledWith(
'[aiChat] Failed to touch agent updatedAt:',
touchError,
);
} finally {
consoleErrorSpy.mockRestore();
}
});
it('should create messages while agent updatedAt touch is still pending', async () => {
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
const mockCreateMessage = vi
.fn()
.mockResolvedValueOnce({ id: 'm-user' })
.mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
let resolveTouchUpdatedAt: () => void = () => {};
const touchUpdatedAtPromise = new Promise<void>((resolve) => {
resolveTouchUpdatedAt = resolve;
});
const mockTouchUpdatedAt = vi.fn(() => touchUpdatedAtPromise);
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
vi.mocked(AgentModel).mockImplementation(
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
);
const caller = aiChatRouter.createCaller(mockCtx as any);
const request = caller.sendMessageInServer({
agentId: 'agent-1',
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newTopic: { title: 'New Topic' },
newUserMessage: { content: 'hi' },
sessionId: 's1',
} as any);
await flushAsyncTasks();
try {
expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1');
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledTimes(1);
} finally {
resolveTouchUpdatedAt();
}
await request;
});
it('should not touch agent updatedAt when creating topic without agentId', async () => { it('should not touch agent updatedAt when creating topic without agentId', async () => {
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' }); const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
const mockCreateMessage = vi const mockCreateMessage = vi
@@ -743,7 +881,7 @@ describe('aiChatRouter', () => {
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined); const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any); vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
vi.mocked(AgentModel).mockImplementation( vi.mocked(AgentModel).mockImplementation(
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
@@ -771,7 +909,7 @@ describe('aiChatRouter', () => {
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined }); const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined); const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any); mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any); vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
vi.mocked(AgentModel).mockImplementation( vi.mocked(AgentModel).mockImplementation(
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any, () => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
+185 -79
View File
@@ -1,5 +1,6 @@
import { type CreateMessageParams, type SendMessageServerResponse } from '@lobechat/types'; import type { CreateMessageParams, SendMessageServerResponse } from '@lobechat/types';
import { AiSendMessageServerSchema, RequestTrigger, StructureOutputSchema } from '@lobechat/types'; import { AiSendMessageServerSchema, RequestTrigger, StructureOutputSchema } from '@lobechat/types';
import { createTimingHelpers, createTimingRequestId } from '@lobechat/utils';
import debug from 'debug'; import debug from 'debug';
import { LOADING_FLAT } from '@/const/message'; import { LOADING_FLAT } from '@/const/message';
@@ -15,6 +16,9 @@ import { AiChatService } from '@/server/services/aiChat';
import { FileService } from '@/server/services/file'; import { FileService } from '@/server/services/file';
const log = debug('lobe-lambda-router:ai-chat'); const log = debug('lobe-lambda-router:ai-chat');
const { createPrefixedTimingContext, logTiming, runTimedStage } = createTimingHelpers(
'lobe-server:chat:lobehub:timing',
);
const aiChatProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { const aiChatProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts; const { ctx } = opts;
@@ -59,6 +63,17 @@ export const aiChatRouter = router({
sendMessageInServer: aiChatProcedure sendMessageInServer: aiChatProcedure
.input(AiSendMessageServerSchema) .input(AiSendMessageServerSchema)
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
const timingContext =
input.newAssistantMessage.provider === 'lobehub'
? { requestId: createTimingRequestId(), startedAt: Date.now() }
: undefined;
logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:start', {
hasNewThread: !!input.newThread,
hasNewTopic: !!input.newTopic,
hasSessionId: !!input.sessionId,
hasTopicId: !!input.topicId,
preloadCount: input.preloadMessages?.length ?? 0,
});
log('sendMessageInServer called for agentId: %s', input.agentId); log('sendMessageInServer called for agentId: %s', input.agentId);
log( log(
'topicId: %s, newTopic: %O, newThread: %O', 'topicId: %s, newTopic: %O, newThread: %O',
@@ -68,7 +83,12 @@ export const aiChatRouter = router({
); );
let sessionId = input.sessionId; let sessionId = input.sessionId;
if (!sessionId) { if (!sessionId) {
const context = await resolveContext(input, ctx.serverDB, ctx.userId); const context = await runTimedStage(
timingContext,
'lambda.aiChat.resolveContext',
() => resolveContext(input, ctx.serverDB, ctx.userId),
{ hasAgentId: !!input.agentId },
);
if (!!context.sessionId) sessionId = context.sessionId; if (!!context.sessionId) sessionId = context.sessionId;
} }
@@ -77,27 +97,54 @@ export const aiChatRouter = router({
let createdThreadId: string | undefined; let createdThreadId: string | undefined;
let isCreateNewTopic = false; let isCreateNewTopic = false;
let agentTouchUpdatedAtTask: Promise<void> | undefined;
// create topic if there should be a new topic // create topic if there should be a new topic
if (input.newTopic) { if (input.newTopic) {
log('creating new topic with title: %s', input.newTopic.title); log('creating new topic with title: %s', input.newTopic.title);
const topicItem = await ctx.topicModel.create({ const topicItem = await runTimedStage(
agentId: input.agentId, timingContext,
groupId: input.groupId, 'lambda.aiChat.topic.create',
messages: input.newTopic.topicMessageIds, () => {
metadata: input.newTopic.metadata, const payload = {
sessionId, agentId: input.agentId,
title: input.newTopic.title, groupId: input.groupId,
trigger: input.newTopic.trigger, messages: input.newTopic!.topicMessageIds,
}); metadata: input.newTopic!.metadata,
sessionId,
title: input.newTopic!.title,
trigger: input.newTopic!.trigger,
};
const modelTiming = createPrefixedTimingContext(
timingContext,
'lambda.aiChat.topic.create',
);
return modelTiming
? ctx.topicModel.create(payload, undefined, modelTiming)
: ctx.topicModel.create(payload);
},
{
messageCount: input.newTopic.topicMessageIds?.length ?? 0,
trigger: input.newTopic.trigger,
},
);
topicId = topicItem.id; topicId = topicItem.id;
isCreateNewTopic = true; isCreateNewTopic = true;
log('new topic created with id: %s', topicId); log('new topic created with id: %s', topicId);
// update agent's updatedAt to reflect new activity // update agent's updatedAt to reflect new activity
if (input.agentId) { if (input.agentId) {
await ctx.agentModel.touchUpdatedAt(input.agentId); agentTouchUpdatedAtTask = runTimedStage(
log('agent updatedAt touched for agentId: %s', input.agentId); timingContext,
'lambda.aiChat.agent.touchUpdatedAt',
async () => {
await ctx.agentModel.touchUpdatedAt(input.agentId!);
},
{ hasAgentId: true },
).catch((error) => {
console.error('[aiChat] Failed to touch agent updatedAt:', error);
});
log('agent updatedAt touch scheduled for agentId: %s', input.agentId);
} }
} }
@@ -108,13 +155,19 @@ export const aiChatRouter = router({
input.newThread.sourceMessageId, input.newThread.sourceMessageId,
input.newThread.type, input.newThread.type,
); );
const threadItem = await ctx.threadModel.create({ const threadItem = await runTimedStage(
parentThreadId: input.newThread.parentThreadId, timingContext,
sourceMessageId: input.newThread.sourceMessageId, 'lambda.aiChat.thread.create',
title: input.newThread.title, () =>
topicId, ctx.threadModel.create({
type: input.newThread.type, parentThreadId: input.newThread!.parentThreadId,
}); sourceMessageId: input.newThread!.sourceMessageId,
title: input.newThread!.title,
topicId,
type: input.newThread!.type,
}),
{ threadType: input.newThread.type },
);
if (threadItem) { if (threadItem) {
threadId = threadItem.id; threadId = threadItem.id;
createdThreadId = threadItem.id; createdThreadId = threadItem.id;
@@ -127,24 +180,40 @@ export const aiChatRouter = router({
if (input.preloadMessages?.length) { if (input.preloadMessages?.length) {
log('creating %d preload messages before user message', input.preloadMessages.length); log('creating %d preload messages before user message', input.preloadMessages.length);
for (const preloadMessage of input.preloadMessages) { parentId = await runTimedStage(
const preloadItem = await ctx.messageModel.create({ timingContext,
agentId: input.agentId, 'lambda.aiChat.preloadMessages.create',
content: preloadMessage.content, async () => {
groupId: input.groupId, let latestParentId = parentId;
metadata: preloadMessage.metadata, for (const preloadMessage of input.preloadMessages!) {
parentId, const payload = {
plugin: preloadMessage.plugin as CreateMessageParams['plugin'], agentId: input.agentId,
role: preloadMessage.role, content: preloadMessage.content,
sessionId, groupId: input.groupId,
threadId, metadata: preloadMessage.metadata,
tool_call_id: preloadMessage.tool_call_id, parentId: latestParentId,
tools: preloadMessage.tools as CreateMessageParams['tools'], plugin: preloadMessage.plugin as CreateMessageParams['plugin'],
topicId, role: preloadMessage.role,
}); sessionId,
threadId,
tool_call_id: preloadMessage.tool_call_id,
tools: preloadMessage.tools as CreateMessageParams['tools'],
topicId,
};
const modelTiming = createPrefixedTimingContext(
timingContext,
'lambda.aiChat.preloadMessages.create',
);
const preloadItem = await (modelTiming
? ctx.messageModel.create(payload, undefined, modelTiming)
: ctx.messageModel.create(payload));
parentId = preloadItem.id; latestParentId = preloadItem.id;
} }
return latestParentId;
},
{ count: input.preloadMessages.length },
);
} }
// create user message // create user message
@@ -161,58 +230,95 @@ export const aiChatRouter = router({
} }
: undefined; : undefined;
const userMessageItem = await ctx.messageModel.create({ const createMessagePairPromise = runTimedStage(
agentId: input.agentId, timingContext,
content: input.newUserMessage.content, 'lambda.aiChat.messages.createUserAndAssistant',
editorData: input.newUserMessage.editorData, () => {
files: input.newUserMessage.files, const userMessage = {
groupId: input.groupId, agentId: input.agentId,
metadata: userMessageMetadata, content: input.newUserMessage.content,
parentId, editorData: input.newUserMessage.editorData,
role: 'user', files: input.newUserMessage.files,
sessionId, groupId: input.groupId,
threadId, metadata: userMessageMetadata,
topicId, parentId,
}); role: 'user',
sessionId,
threadId,
topicId,
} satisfies CreateMessageParams;
const assistantMessage = {
agentId: input.agentId,
content: LOADING_FLAT,
groupId: input.groupId,
metadata: input.newAssistantMessage.metadata,
model: input.newAssistantMessage.model,
provider: input.newAssistantMessage.provider,
role: 'assistant',
sessionId,
threadId,
topicId,
} satisfies CreateMessageParams;
const modelTiming = createPrefixedTimingContext(
timingContext,
'lambda.aiChat.messages.createUserAndAssistant',
);
return ctx.messageModel.createUserAndAssistantMessages(
{ assistantMessage, userMessage },
{
...(modelTiming ? { timing: modelTiming } : {}),
touchTopicUpdatedAt: !isCreateNewTopic,
},
);
},
{
contentLength: input.newUserMessage.content.length,
fileCount: input.newUserMessage.files?.length ?? 0,
model: input.newAssistantMessage.model,
provider: input.newAssistantMessage.provider,
},
);
const { assistantMessage: assistantMessageItem, userMessage: userMessageItem } =
agentTouchUpdatedAtTask
? (await Promise.all([createMessagePairPromise, agentTouchUpdatedAtTask]))[0]
: await createMessagePairPromise;
const messageId = userMessageItem.id; const messageId = userMessageItem.id;
log('user message created with id: %s', messageId); log('user message created with id: %s', messageId);
// create assistant message
log(
'creating assistant message with model: %s, provider: %s, metadata: %O',
input.newAssistantMessage.model,
input.newAssistantMessage.provider,
input.newAssistantMessage.metadata,
);
const assistantMessageItem = await ctx.messageModel.create({
agentId: input.agentId,
content: LOADING_FLAT,
groupId: input.groupId,
metadata: input.newAssistantMessage.metadata,
model: input.newAssistantMessage.model,
parentId: messageId,
provider: input.newAssistantMessage.provider,
role: 'assistant',
sessionId,
threadId,
topicId,
});
log('assistant message created with id: %s', assistantMessageItem.id); log('assistant message created with id: %s', assistantMessageItem.id);
// retrieve latest messages and topic with // retrieve latest messages and topic with
log('retrieving messages and topics'); log('retrieving messages and topics');
const { messages, topics } = await ctx.aiChatService.getMessagesAndTopics({ const { messages, topics } = await runTimedStage(
agentId: input.agentId, timingContext,
groupId: input.groupId, 'lambda.aiChat.messagesAndTopics.query',
includeTopic: isCreateNewTopic, () =>
sessionId, ctx.aiChatService.getMessagesAndTopics({
threadId, agentId: input.agentId,
topicFilter: input.topicFilter, groupId: input.groupId,
topicId, includeTopic: isCreateNewTopic,
}); sessionId,
threadId,
topicFilter: input.topicFilter,
topicId,
topicPageSize: input.topicPageSize,
...(timingContext
? {
timingRequestId: timingContext.requestId,
timingStartedAt: timingContext.startedAt,
}
: {}),
}),
{ includeTopic: isCreateNewTopic },
);
log('retrieved %d messages, %d topics', messages.length, topics?.items?.length ?? 0); log('retrieved %d messages, %d topics', messages.length, topics?.items?.length ?? 0);
logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:done', {
isCreateNewTopic,
messageCount: messages.length,
topicCount: topics?.items?.length ?? 0,
});
return { return {
assistantMessageId: assistantMessageItem.id, assistantMessageId: assistantMessageItem.id,
+33 -2
View File
@@ -4,6 +4,7 @@ import {
UpdateMessagePluginSchema, UpdateMessagePluginSchema,
UpdateMessageRAGParamsSchema, UpdateMessageRAGParamsSchema,
} from '@lobechat/types'; } from '@lobechat/types';
import { createTimingHelpers, createTimingRequestId } from '@lobechat/utils';
import { TRPCError } from '@trpc/server'; import { TRPCError } from '@trpc/server';
import { z } from 'zod'; import { z } from 'zod';
@@ -18,6 +19,8 @@ import { MessageService } from '@/server/services/message';
import { resolveAgentIdFromSession, resolveContext } from './_helpers/resolveContext'; import { resolveAgentIdFromSession, resolveContext } from './_helpers/resolveContext';
import { basicContextSchema } from './_schema/context'; import { basicContextSchema } from './_schema/context';
const { logTiming, runTimedStage } = createTimingHelpers('lobe-server:chat:lobehub:timing');
const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts; const { ctx } = opts;
@@ -316,9 +319,37 @@ export const messageRouter = router({
) )
.mutation(async ({ input, ctx }) => { .mutation(async ({ input, ctx }) => {
const { id, value, agentId, ...options } = input; const { id, value, agentId, ...options } = input;
const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); const timingContext = { requestId: createTimingRequestId(), startedAt: Date.now() };
logTiming(timingContext, 'lambda.message.update:start', {
hasAgentId: !!agentId,
hasTopicId: !!options.topicId,
valueKeys: Object.keys(value ?? {}),
});
return ctx.messageService.updateMessage(id, value as any, resolved); const resolved = await runTimedStage(
timingContext,
'lambda.message.update.resolveContext',
() => resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId),
{ hasAgentId: !!agentId },
);
const result = await runTimedStage(
timingContext,
'lambda.message.update.service',
() =>
ctx.messageService.updateMessage(id, value as any, {
...resolved,
timingRequestId: timingContext.requestId,
timingStartedAt: timingContext.startedAt,
}),
{ hasResolvedTopicId: !!resolved.topicId },
);
logTiming(timingContext, 'lambda.message.update:done', {
messageCount: result.messages?.length ?? 0,
success: result.success,
});
return result;
}), }),
/** /**
+13 -2
View File
@@ -1,4 +1,4 @@
import { type LobeChatDatabase } from '@lobechat/database'; import type { LobeChatDatabase } from '@lobechat/database';
import { describe, expect, it, vi } from 'vitest'; import { describe, expect, it, vi } from 'vitest';
import { MessageModel } from '@/database/models/message'; import { MessageModel } from '@/database/models/message';
@@ -31,13 +31,18 @@ describe('AiChatService', () => {
groupId: 'group-1', groupId: 'group-1',
includeTopic: true, includeTopic: true,
sessionId: 's1', sessionId: 's1',
topicPageSize: 20,
}); });
expect(mockQueryMessages).toHaveBeenCalledWith( expect(mockQueryMessages).toHaveBeenCalledWith(
{ agentId: 'agent-1', groupId: 'group-1', includeTopic: true, sessionId: 's1' }, { agentId: 'agent-1', groupId: 'group-1', includeTopic: true, sessionId: 's1' },
expect.objectContaining({ postProcessUrl: expect.any(Function) }), expect.objectContaining({ postProcessUrl: expect.any(Function) }),
); );
expect(mockQueryTopics).toHaveBeenCalledWith({ agentId: 'agent-1', groupId: 'group-1' }); expect(mockQueryTopics).toHaveBeenCalledWith({
agentId: 'agent-1',
groupId: 'group-1',
pageSize: 20,
});
expect(res.messages).toEqual([{ id: 'm1' }]); expect(res.messages).toEqual([{ id: 'm1' }]);
expect(res.topics).toEqual([{ id: 't1' }]); expect(res.topics).toEqual([{ id: 't1' }]);
}); });
@@ -63,6 +68,7 @@ describe('AiChatService', () => {
excludeStatuses: ['completed'], excludeStatuses: ['completed'],
excludeTriggers: ['cron', 'eval'], excludeTriggers: ['cron', 'eval'],
}, },
topicPageSize: 20,
}); });
expect(mockQueryTopics).toHaveBeenCalledWith({ expect(mockQueryTopics).toHaveBeenCalledWith({
@@ -70,12 +76,17 @@ describe('AiChatService', () => {
excludeStatuses: ['completed'], excludeStatuses: ['completed'],
excludeTriggers: ['cron', 'eval'], excludeTriggers: ['cron', 'eval'],
groupId: undefined, groupId: undefined,
pageSize: 20,
}); });
// topicFilter must not leak into messageModel.query // topicFilter must not leak into messageModel.query
expect(mockQueryMessages).toHaveBeenCalledWith( expect(mockQueryMessages).toHaveBeenCalledWith(
expect.not.objectContaining({ topicFilter: expect.anything() }), expect.not.objectContaining({ topicFilter: expect.anything() }),
expect.objectContaining({ postProcessUrl: expect.any(Function) }), expect.objectContaining({ postProcessUrl: expect.any(Function) }),
); );
expect(mockQueryMessages).toHaveBeenCalledWith(
expect.not.objectContaining({ topicPageSize: 20 }),
expect.objectContaining({ postProcessUrl: expect.any(Function) }),
);
}); });
it('getMessagesAndTopics should not query topics when includeTopic is false', async () => { it('getMessagesAndTopics should not query topics when includeTopic is false', async () => {
+65 -25
View File
@@ -1,9 +1,33 @@
import { type LobeChatDatabase } from '@lobechat/database'; import type { LobeChatDatabase } from '@lobechat/database';
import { createTimingHelpers } from '@lobechat/utils';
import { MessageModel } from '@/database/models/message'; import { MessageModel } from '@/database/models/message';
import { TopicModel } from '@/database/models/topic'; import { TopicModel } from '@/database/models/topic';
import { FileService } from '@/server/services/file'; import { FileService } from '@/server/services/file';
const { createPrefixedTimingContext, runTimedStage, toTimingContext } = createTimingHelpers(
'lobe-server:chat:lobehub:timing',
);
interface GetMessagesAndTopicsParams {
agentId?: string;
current?: number;
groupId?: string;
includeTopic?: boolean;
pageSize?: number;
sessionId?: string;
threadId?: string;
timingRequestId?: string;
timingStartedAt?: number;
topicFilter?: {
excludeStatuses?: string[];
excludeTriggers?: string[];
includeTriggers?: string[];
};
topicId?: string;
topicPageSize?: number;
}
export class AiChatService { export class AiChatService {
private userId: string; private userId: string;
private messageModel: MessageModel; private messageModel: MessageModel;
@@ -18,32 +42,48 @@ export class AiChatService {
this.fileService = new FileService(serverDB, userId); this.fileService = new FileService(serverDB, userId);
} }
async getMessagesAndTopics(params: { async getMessagesAndTopics(params: GetMessagesAndTopicsParams) {
agentId?: string; const { topicFilter, topicPageSize, timingRequestId, timingStartedAt, ...messageParams } =
current?: number; params;
groupId?: string; const timingContext = toTimingContext({ timingRequestId, timingStartedAt });
includeTopic?: boolean; const messageTiming = createPrefixedTimingContext(
pageSize?: number; timingContext,
sessionId?: string; 'lambda.aiChat.messagesAndTopics.messageModel.query',
threadId?: string; );
topicFilter?: { const topicTiming = createPrefixedTimingContext(
excludeStatuses?: string[]; timingContext,
excludeTriggers?: string[]; 'lambda.aiChat.messagesAndTopics.topicModel.query',
includeTriggers?: string[]; );
}; const messageQueryPromise = runTimedStage(
topicId?: string; timingContext,
}) { 'lambda.aiChat.messagesAndTopics.messageModel.query',
const { topicFilter, ...messageParams } = params; () =>
this.messageModel.query(messageParams, {
postProcessUrl: (path) => this.fileService.getFullFileUrl(path),
...(messageTiming ? { timing: messageTiming } : {}),
}),
{
hasAgentId: !!params.agentId,
hasThreadId: !!params.threadId,
hasTopicId: !!params.topicId,
},
);
const [messages, topics] = await Promise.all([ const [messages, topics] = await Promise.all([
this.messageModel.query(messageParams, { messageQueryPromise,
postProcessUrl: (path) => this.fileService.getFullFileUrl(path),
}),
params.includeTopic params.includeTopic
? this.topicModel.query({ ? runTimedStage(
agentId: params.agentId, timingContext,
groupId: params.groupId, 'lambda.aiChat.messagesAndTopics.topicModel.query',
...topicFilter, () =>
}) this.topicModel.query({
agentId: params.agentId,
groupId: params.groupId,
pageSize: topicPageSize,
...(topicTiming ? { timing: topicTiming } : {}),
...topicFilter,
}),
{ hasAgentId: !!params.agentId, hasGroupId: !!params.groupId },
)
: undefined, : undefined,
]); ]);
+41 -2
View File
@@ -5,6 +5,7 @@ import {
type UIChatMessage, type UIChatMessage,
type UpdateMessageParams, type UpdateMessageParams,
} from '@lobechat/types'; } from '@lobechat/types';
import { createTimingHelpers, getDurationMs } from '@lobechat/utils';
import { MessageModel } from '@/database/models/message'; import { MessageModel } from '@/database/models/message';
@@ -15,9 +16,26 @@ interface QueryOptions {
groupId?: string | null; groupId?: string | null;
sessionId?: string | null; sessionId?: string | null;
threadId?: string | null; threadId?: string | null;
timingRequestId?: string;
timingStartedAt?: number;
topicId?: string | null; topicId?: string | null;
} }
const { createPrefixedTimingContext, logTiming, toTimingContext } = createTimingHelpers(
'lobe-server:chat:lobehub:timing',
);
const logMessageTiming = (
options: QueryOptions | undefined,
event: string,
metadata?: Record<string, unknown>,
) => {
logTiming(toTimingContext(options), event, metadata);
};
const createModelTiming = (options: QueryOptions | undefined, prefix: string) =>
createPrefixedTimingContext(toTimingContext(options), prefix);
interface CreateMessageResult { interface CreateMessageResult {
id: string; id: string;
messages: any[]; messages: any[];
@@ -70,15 +88,25 @@ export class MessageService {
options.sessionId === undefined && options.sessionId === undefined &&
options.topicId === undefined) options.topicId === undefined)
) { ) {
logMessageTiming(options, 'lambda.message.update.queryMessages:skipped');
return { success: true }; return { success: true };
} }
const { agentId, sessionId, topicId, groupId, threadId } = options; const { agentId, sessionId, topicId, groupId, threadId } = options;
const queryStartedAt = Date.now();
const modelTiming = createModelTiming(options, 'lambda.message.update.queryMessages');
const messages = await this.messageModel.query( const messages = await this.messageModel.query(
{ agentId, groupId, sessionId, threadId, topicId }, { agentId, groupId, sessionId, threadId, topicId },
this.getQueryOptions(), {
...this.getQueryOptions(),
...(modelTiming ? { timing: modelTiming } : {}),
},
); );
logMessageTiming(options, 'lambda.message.update.queryMessages:done', {
messageCount: messages.length,
stageMs: getDurationMs(queryStartedAt),
});
return { messages, success: true }; return { messages, success: true };
} }
@@ -188,7 +216,18 @@ export class MessageService {
value: UpdateMessageParams, value: UpdateMessageParams,
options: QueryOptions, options: QueryOptions,
): Promise<{ messages?: UIChatMessage[]; success: boolean }> { ): Promise<{ messages?: UIChatMessage[]; success: boolean }> {
await this.messageModel.update(id, value as any); const updateStartedAt = Date.now();
const modelTiming = createModelTiming(options, 'lambda.message.update.dbUpdate');
if (modelTiming) {
await this.messageModel.update(id, value as any, modelTiming);
} else {
await this.messageModel.update(id, value as any);
}
logMessageTiming(options, 'lambda.message.update.dbUpdate:done', {
stageMs: getDurationMs(updateStartedAt),
valueKeys: Object.keys(value ?? {}),
});
return this.queryWithSuccess(options); return this.queryWithSuccess(options);
} }
@@ -9,6 +9,7 @@ import { chatService } from '@/services/chat';
import { messageService } from '@/services/message'; import { messageService } from '@/services/message';
import * as agentGroupStore from '@/store/agentGroup'; import * as agentGroupStore from '@/store/agentGroup';
import { messageMapKey } from '@/store/chat/utils/messageMapKey'; import { messageMapKey } from '@/store/chat/utils/messageMapKey';
import { topicMapKey } from '@/store/chat/utils/topicMapKey';
import { getSessionStoreState } from '@/store/session'; import { getSessionStoreState } from '@/store/session';
import * as toolStoreModule from '@/store/tool'; import * as toolStoreModule from '@/store/tool';
@@ -1622,7 +1623,6 @@ describe('ConversationLifecycle actions', () => {
createMockMessage({ id: 'new-user-msg', role: 'user', topicId: newTopicId }), createMockMessage({ id: 'new-user-msg', role: 'user', topicId: newTopicId }),
createMockMessage({ id: 'new-assistant-msg', role: 'assistant', topicId: newTopicId }), createMockMessage({ id: 'new-assistant-msg', role: 'assistant', topicId: newTopicId }),
], ],
topics: { items: [{ id: newTopicId, title: 'New Topic' }], total: 1 },
topicId: newTopicId, topicId: newTopicId,
isCreateNewTopic: true, isCreateNewTopic: true,
assistantMessageId: 'new-assistant-msg', assistantMessageId: 'new-assistant-msg',
@@ -1648,6 +1648,12 @@ describe('ConversationLifecycle actions', () => {
// After new topic creation, the _new key should be cleared // After new topic creation, the _new key should be cleared
const messagesInNewKey = useChatStore.getState().messagesMap[newKey]; const messagesInNewKey = useChatStore.getState().messagesMap[newKey];
expect(messagesInNewKey ?? []).toHaveLength(0); expect(messagesInNewKey ?? []).toHaveLength(0);
const newTopicKey = messageMapKey({ agentId, topicId: newTopicId });
expect(useChatStore.getState().messagesMap[newTopicKey]).toHaveLength(2);
expect(useChatStore.getState().topicDataMap[topicMapKey({ agentId })]?.items[0]).toEqual(
expect.objectContaining({ id: newTopicId }),
);
}); });
}); });
}); });
@@ -523,6 +523,7 @@ export class ConversationLifecycleActionImpl {
operationContext.agentId, operationContext.agentId,
operationContext.groupId ?? undefined, operationContext.groupId ?? undefined,
), ),
topicPageSize: systemStatusSelectors.topicPageSize(useGlobalStore.getState()),
topicId: operationContext.topicId ?? undefined, topicId: operationContext.topicId ?? undefined,
}, },
abortController, abortController,
@@ -712,6 +713,7 @@ export class ConversationLifecycleActionImpl {
const toolContext = formatSelectedToolsContext(dedupedTools); const toolContext = formatSelectedToolsContext(dedupedTools);
const contextSuffix = [skillContext, toolContext].filter(Boolean).join('\n'); const contextSuffix = [skillContext, toolContext].filter(Boolean).join('\n');
const persistedContent = contextSuffix ? `${message}\n\n${contextSuffix}` : message; const persistedContent = contextSuffix ? `${message}\n\n${contextSuffix}` : message;
const newTopicTitle = message.slice(0, 80) || t('defaultTitle', { ns: 'topic' });
data = await aiChatService.sendMessageInServer( data = await aiChatService.sendMessageInServer(
{ {
@@ -730,6 +732,7 @@ export class ConversationLifecycleActionImpl {
operationContext.agentId, operationContext.agentId,
operationContext.groupId ?? undefined, operationContext.groupId ?? undefined,
), ),
topicPageSize: systemStatusSelectors.topicPageSize(useGlobalStore.getState()),
threadId: operationContext.threadId ?? undefined, threadId: operationContext.threadId ?? undefined,
// Support creating new thread along with message // Support creating new thread along with message
newThread: newThread newThread: newThread
@@ -741,7 +744,7 @@ export class ConversationLifecycleActionImpl {
newTopic: !topicId newTopic: !topicId
? { ? {
topicMessageIds: forceNewTopicFromExisting ? [] : messages.map((m) => m.id), topicMessageIds: forceNewTopicFromExisting ? [] : messages.map((m) => m.id),
title: message.slice(0, 80) || t('defaultTitle', { ns: 'topic' }), title: newTopicTitle,
} }
: undefined, : undefined,
agentId: operationContext.agentId, agentId: operationContext.agentId,
@@ -757,7 +760,7 @@ export class ConversationLifecycleActionImpl {
abortController, abortController,
); );
// Use created topicId/threadId if available, otherwise use original from context // Use created topicId/threadId if available, otherwise use original from context
let finalTopicId = operationContext.topicId; let finalTopicId = data.topicId ?? operationContext.topicId;
const finalThreadId = data.createdThreadId ?? operationContext.threadId; const finalThreadId = data.createdThreadId ?? operationContext.threadId;
// refresh the total data // refresh the total data
@@ -780,6 +783,18 @@ export class ConversationLifecycleActionImpl {
// Record the created topicId in metadata (not context) // Record the created topicId in metadata (not context)
this.#get().updateOperationMetadata(operationId, { createdTopicId: data.topicId }); this.#get().updateOperationMetadata(operationId, { createdTopicId: data.topicId });
} }
} else if (data.isCreateNewTopic && data.topicId && !context.isolatedTopic) {
this.#get().internal_dispatchTopic(
{
type: 'addTopic',
value: {
id: data.topicId,
title: newTopicTitle,
},
},
'sendMessage/createTopicPlaceholder',
);
this.#get().updateOperationMetadata(operationId, { createdTopicId: data.topicId });
} else if (operationContext.topicId) { } else if (operationContext.topicId) {
// Optimistically update topic's updatedAt so sidebar re-groups immediately // Optimistically update topic's updatedAt so sidebar re-groups immediately
this.#get().internal_dispatchTopic({ this.#get().internal_dispatchTopic({