mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-13 19:20:04 +00:00
⚡️ perf: optimize chat bootstrap persistence (#14934)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import type { DBMessageItem } from '@lobechat/types';
|
||||
import { eq } from 'drizzle-orm';
|
||||
import { asc, eq } from 'drizzle-orm';
|
||||
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
import { uuid } from '@/utils/uuid';
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
messages,
|
||||
messagesFiles,
|
||||
sessions,
|
||||
topics,
|
||||
users,
|
||||
} from '../../../schemas';
|
||||
import type { LobeChatDatabase } from '../../../type';
|
||||
@@ -248,6 +249,124 @@ describe('MessageModel Create Tests', () => {
|
||||
expect(pluginResult[0].arguments).not.toContain('\u0000');
|
||||
});
|
||||
|
||||
it('should create user and assistant messages with one topic touch', async () => {
|
||||
await serverDB.insert(topics).values({
|
||||
id: 'topic-pair',
|
||||
sessionId: '1',
|
||||
title: 'Topic pair',
|
||||
userId,
|
||||
});
|
||||
|
||||
const timingEvents: string[] = [];
|
||||
const result = await messageModel.createUserAndAssistantMessages(
|
||||
{
|
||||
assistantMessage: {
|
||||
content: '',
|
||||
model: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
role: 'assistant',
|
||||
sessionId: '1',
|
||||
topicId: 'topic-pair',
|
||||
},
|
||||
userMessage: {
|
||||
content: 'hello',
|
||||
files: ['f1'],
|
||||
role: 'user',
|
||||
sessionId: '1',
|
||||
topicId: 'topic-pair',
|
||||
},
|
||||
},
|
||||
{
|
||||
timing: {
|
||||
log: (event) => timingEvents.push(event),
|
||||
},
|
||||
},
|
||||
);
|
||||
|
||||
expect(result.userMessage.id).toBeDefined();
|
||||
expect(result.assistantMessage.id).toBeDefined();
|
||||
expect(result.assistantMessage.parentId).toBe(result.userMessage.id);
|
||||
expect(result.userMessage.createdAt.getTime()).toBeLessThan(
|
||||
result.assistantMessage.createdAt.getTime(),
|
||||
);
|
||||
|
||||
const dbMessages = await serverDB
|
||||
.select()
|
||||
.from(messages)
|
||||
.where(eq(messages.userId, userId))
|
||||
.orderBy(asc(messages.createdAt));
|
||||
|
||||
expect(dbMessages.map((message) => message.id)).toEqual([
|
||||
result.userMessage.id,
|
||||
result.assistantMessage.id,
|
||||
]);
|
||||
|
||||
const messageFiles = await serverDB
|
||||
.select()
|
||||
.from(messagesFiles)
|
||||
.where(eq(messagesFiles.messageId, result.userMessage.id));
|
||||
|
||||
expect(messageFiles).toHaveLength(1);
|
||||
expect(
|
||||
timingEvents.filter(
|
||||
(event) => event === 'db.message.createUserAndAssistant.messages.insert:start',
|
||||
),
|
||||
).toHaveLength(1);
|
||||
expect(
|
||||
timingEvents.filter(
|
||||
(event) => event === 'db.message.createUserAndAssistant.topic.touchUpdatedAt:start',
|
||||
),
|
||||
).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('should skip topic touch when creating a pair for an already-created topic', async () => {
|
||||
await serverDB.insert(topics).values({
|
||||
id: 'topic-pair-no-touch',
|
||||
sessionId: '1',
|
||||
title: 'Topic pair no touch',
|
||||
userId,
|
||||
});
|
||||
|
||||
const timingEvents: string[] = [];
|
||||
const result = await messageModel.createUserAndAssistantMessages(
|
||||
{
|
||||
assistantMessage: {
|
||||
content: '',
|
||||
model: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
role: 'assistant',
|
||||
sessionId: '1',
|
||||
topicId: 'topic-pair-no-touch',
|
||||
},
|
||||
userMessage: {
|
||||
content: 'hello',
|
||||
role: 'user',
|
||||
sessionId: '1',
|
||||
topicId: 'topic-pair-no-touch',
|
||||
},
|
||||
},
|
||||
{
|
||||
timing: {
|
||||
log: (event) => timingEvents.push(event),
|
||||
},
|
||||
touchTopicUpdatedAt: false,
|
||||
},
|
||||
);
|
||||
|
||||
expect(result.userMessage.id).toBeDefined();
|
||||
expect(result.assistantMessage.parentId).toBe(result.userMessage.id);
|
||||
expect(
|
||||
timingEvents.filter(
|
||||
(event) => event === 'db.message.createUserAndAssistant.messages.insert:start',
|
||||
),
|
||||
).toHaveLength(1);
|
||||
expect(
|
||||
timingEvents.filter(
|
||||
(event) => event === 'db.message.createUserAndAssistant.topic.touchUpdatedAt:start',
|
||||
),
|
||||
).toHaveLength(0);
|
||||
});
|
||||
|
||||
describe('create with advanced parameters', () => {
|
||||
it('should create a message with custom ID', async () => {
|
||||
const customId = 'custom-msg-id';
|
||||
|
||||
@@ -95,7 +95,10 @@ describe('TopicModel - Create', () => {
|
||||
|
||||
const topicId = 'new-topic';
|
||||
|
||||
const createdTopic = await topicModel.create(topicData, topicId);
|
||||
const timingEvents: string[] = [];
|
||||
const createdTopic = await topicModel.create(topicData, topicId, {
|
||||
log: (event) => timingEvents.push(event),
|
||||
});
|
||||
|
||||
expect(createdTopic).toEqual({
|
||||
id: topicId,
|
||||
@@ -123,6 +126,8 @@ describe('TopicModel - Create', () => {
|
||||
const dbTopic = await serverDB.select().from(topics).where(eq(topics.id, topicId));
|
||||
expect(dbTopic).toHaveLength(1);
|
||||
expect(dbTopic[0]).toEqual(createdTopic);
|
||||
expect(timingEvents).toContain('db.topic.create.topics.insert:start');
|
||||
expect(timingEvents).not.toContain('db.topic.create.transaction:start');
|
||||
});
|
||||
|
||||
it('should create a new topic with agentId', async () => {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,6 +4,12 @@ import type {
|
||||
DBMessageItem,
|
||||
TopicRankItem,
|
||||
} from '@lobechat/types';
|
||||
import type { TimingSink } from '@lobechat/utils';
|
||||
import {
|
||||
getDurationMs,
|
||||
logTimingSink as logTiming,
|
||||
runTimedSinkStage as runTimedStage,
|
||||
} from '@lobechat/utils';
|
||||
import type { SQL } from 'drizzle-orm';
|
||||
import { and, count, desc, eq, gt, gte, inArray, isNull, lte, ne, not, or, sql } from 'drizzle-orm';
|
||||
|
||||
@@ -62,12 +68,15 @@ interface QueryTopicParams {
|
||||
*/
|
||||
isInbox?: boolean;
|
||||
pageSize?: number;
|
||||
timing?: ModelTimingContext;
|
||||
/**
|
||||
* Include only topics matching the given trigger types (positive filter)
|
||||
*/
|
||||
triggers?: string[];
|
||||
}
|
||||
|
||||
export interface ModelTimingContext extends TimingSink {}
|
||||
|
||||
export interface ListTopicsForMemoryExtractorCursor {
|
||||
createdAt: Date;
|
||||
id: string;
|
||||
@@ -93,8 +102,18 @@ export class TopicModel {
|
||||
pageSize = 9999,
|
||||
groupId,
|
||||
isInbox,
|
||||
timing,
|
||||
triggers,
|
||||
}: QueryTopicParams = {}) => {
|
||||
const queryStartedAt = Date.now();
|
||||
logTiming(timing, 'db.topic.query:start', {
|
||||
current,
|
||||
hasAgentId: !!agentId,
|
||||
hasContainerId: !!containerId,
|
||||
hasGroupId: !!groupId,
|
||||
isInbox: !!isInbox,
|
||||
pageSize,
|
||||
});
|
||||
const offset = current * pageSize;
|
||||
const includeTriggerCondition =
|
||||
includeTriggers && includeTriggers.length > 0
|
||||
@@ -127,29 +146,42 @@ export class TopicModel {
|
||||
);
|
||||
|
||||
const [items, totalResult] = await Promise.all([
|
||||
this.db
|
||||
.select({
|
||||
completedAt: topics.completedAt,
|
||||
createdAt: topics.createdAt,
|
||||
favorite: topics.favorite,
|
||||
historySummary: topics.historySummary,
|
||||
id: topics.id,
|
||||
metadata: topics.metadata,
|
||||
status: topics.status,
|
||||
title: topics.title,
|
||||
updatedAt: topics.updatedAt,
|
||||
})
|
||||
.from(topics)
|
||||
.where(whereCondition)
|
||||
.orderBy(desc(topics.favorite), desc(topics.updatedAt))
|
||||
.limit(pageSize)
|
||||
.offset(offset),
|
||||
this.db
|
||||
.select({ count: count(topics.id) })
|
||||
.from(topics)
|
||||
.where(whereCondition),
|
||||
runTimedStage(
|
||||
timing,
|
||||
'db.topic.query.group.items.select',
|
||||
() =>
|
||||
this.db
|
||||
.select({
|
||||
completedAt: topics.completedAt,
|
||||
createdAt: topics.createdAt,
|
||||
favorite: topics.favorite,
|
||||
historySummary: topics.historySummary,
|
||||
id: topics.id,
|
||||
metadata: topics.metadata,
|
||||
status: topics.status,
|
||||
title: topics.title,
|
||||
updatedAt: topics.updatedAt,
|
||||
})
|
||||
.from(topics)
|
||||
.where(whereCondition)
|
||||
.orderBy(desc(topics.favorite), desc(topics.updatedAt))
|
||||
.limit(pageSize)
|
||||
.offset(offset),
|
||||
{ current, pageSize },
|
||||
),
|
||||
runTimedStage(timing, 'db.topic.query.group.count.select', () =>
|
||||
this.db
|
||||
.select({ count: count(topics.id) })
|
||||
.from(topics)
|
||||
.where(whereCondition),
|
||||
),
|
||||
]);
|
||||
|
||||
logTiming(timing, 'db.topic.query:done', {
|
||||
itemCount: items.length,
|
||||
stageMs: getDurationMs(queryStartedAt),
|
||||
total: totalResult[0].count,
|
||||
});
|
||||
return { items, total: totalResult[0].count };
|
||||
}
|
||||
|
||||
@@ -159,11 +191,19 @@ export class TopicModel {
|
||||
// 3. For inbox: sessionId IS NULL AND groupId IS NULL AND agentId IS NULL (legacy inbox data)
|
||||
if (agentId) {
|
||||
// Get the associated sessionId for backward compatibility with legacy data
|
||||
const agentSession = await this.db
|
||||
.select({ sessionId: agentsToSessions.sessionId })
|
||||
.from(agentsToSessions)
|
||||
.where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)))
|
||||
.limit(1);
|
||||
const agentSession = await runTimedStage(
|
||||
timing,
|
||||
'db.topic.query.agentSession.select',
|
||||
() =>
|
||||
this.db
|
||||
.select({ sessionId: agentsToSessions.sessionId })
|
||||
.from(agentsToSessions)
|
||||
.where(
|
||||
and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)),
|
||||
)
|
||||
.limit(1),
|
||||
{ hasAgentId: true },
|
||||
);
|
||||
|
||||
const associatedSessionId = agentSession[0]?.sessionId;
|
||||
|
||||
@@ -201,29 +241,46 @@ export class TopicModel {
|
||||
);
|
||||
|
||||
const [items, totalResult] = await Promise.all([
|
||||
this.db
|
||||
.select({
|
||||
completedAt: topics.completedAt,
|
||||
createdAt: topics.createdAt,
|
||||
favorite: topics.favorite,
|
||||
historySummary: topics.historySummary,
|
||||
id: topics.id,
|
||||
metadata: topics.metadata,
|
||||
status: topics.status,
|
||||
title: topics.title,
|
||||
updatedAt: topics.updatedAt,
|
||||
})
|
||||
.from(topics)
|
||||
.where(agentWhere)
|
||||
.orderBy(desc(topics.favorite), desc(topics.updatedAt))
|
||||
.limit(pageSize)
|
||||
.offset(offset),
|
||||
this.db
|
||||
.select({ count: count(topics.id) })
|
||||
.from(topics)
|
||||
.where(agentWhere),
|
||||
runTimedStage(
|
||||
timing,
|
||||
'db.topic.query.agent.items.select',
|
||||
() =>
|
||||
this.db
|
||||
.select({
|
||||
completedAt: topics.completedAt,
|
||||
createdAt: topics.createdAt,
|
||||
favorite: topics.favorite,
|
||||
historySummary: topics.historySummary,
|
||||
id: topics.id,
|
||||
metadata: topics.metadata,
|
||||
status: topics.status,
|
||||
title: topics.title,
|
||||
updatedAt: topics.updatedAt,
|
||||
})
|
||||
.from(topics)
|
||||
.where(agentWhere)
|
||||
.orderBy(desc(topics.favorite), desc(topics.updatedAt))
|
||||
.limit(pageSize)
|
||||
.offset(offset),
|
||||
{ current, hasAssociatedSessionId: !!associatedSessionId, isInbox: !!isInbox, pageSize },
|
||||
),
|
||||
runTimedStage(
|
||||
timing,
|
||||
'db.topic.query.agent.count.select',
|
||||
() =>
|
||||
this.db
|
||||
.select({ count: count(topics.id) })
|
||||
.from(topics)
|
||||
.where(agentWhere),
|
||||
{ hasAssociatedSessionId: !!associatedSessionId, isInbox: !!isInbox },
|
||||
),
|
||||
]);
|
||||
|
||||
logTiming(timing, 'db.topic.query:done', {
|
||||
itemCount: items.length,
|
||||
stageMs: getDurationMs(queryStartedAt),
|
||||
total: totalResult[0].count,
|
||||
});
|
||||
return { items, total: totalResult[0].count };
|
||||
}
|
||||
|
||||
@@ -238,37 +295,51 @@ export class TopicModel {
|
||||
);
|
||||
|
||||
const [items, totalResult] = await Promise.all([
|
||||
this.db
|
||||
.select({
|
||||
agentId: topics.agentId,
|
||||
completedAt: topics.completedAt,
|
||||
createdAt: topics.createdAt,
|
||||
favorite: topics.favorite,
|
||||
historySummary: topics.historySummary,
|
||||
id: topics.id,
|
||||
metadata: topics.metadata,
|
||||
sessionId: topics.sessionId,
|
||||
status: topics.status,
|
||||
title: topics.title,
|
||||
updatedAt: topics.updatedAt,
|
||||
})
|
||||
.from(topics)
|
||||
.where(whereCondition)
|
||||
// In boolean sorting, false is considered "smaller" than true.
|
||||
// So here we use desc to ensure that topics with favorite as true are in front.
|
||||
.orderBy(desc(topics.favorite), desc(topics.updatedAt))
|
||||
.limit(pageSize)
|
||||
.offset(offset),
|
||||
this.db
|
||||
.select({ count: count(topics.id) })
|
||||
.from(topics)
|
||||
.where(whereCondition),
|
||||
runTimedStage(
|
||||
timing,
|
||||
'db.topic.query.container.items.select',
|
||||
() =>
|
||||
this.db
|
||||
.select({
|
||||
agentId: topics.agentId,
|
||||
completedAt: topics.completedAt,
|
||||
createdAt: topics.createdAt,
|
||||
favorite: topics.favorite,
|
||||
historySummary: topics.historySummary,
|
||||
id: topics.id,
|
||||
metadata: topics.metadata,
|
||||
sessionId: topics.sessionId,
|
||||
status: topics.status,
|
||||
title: topics.title,
|
||||
updatedAt: topics.updatedAt,
|
||||
})
|
||||
.from(topics)
|
||||
.where(whereCondition)
|
||||
// In boolean sorting, false is considered "smaller" than true.
|
||||
// So here we use desc to ensure that topics with favorite as true are in front.
|
||||
.orderBy(desc(topics.favorite), desc(topics.updatedAt))
|
||||
.limit(pageSize)
|
||||
.offset(offset),
|
||||
{ current, pageSize },
|
||||
),
|
||||
runTimedStage(timing, 'db.topic.query.container.count.select', () =>
|
||||
this.db
|
||||
.select({ count: count(topics.id) })
|
||||
.from(topics)
|
||||
.where(whereCondition),
|
||||
),
|
||||
]);
|
||||
|
||||
// Remove internal fields before returning
|
||||
|
||||
const cleanItems = items.map(({ agentId, sessionId, ...rest }) => rest);
|
||||
|
||||
logTiming(timing, 'db.topic.query:done', {
|
||||
itemCount: cleanItems.length,
|
||||
stageMs: getDurationMs(queryStartedAt),
|
||||
total: totalResult[0].count,
|
||||
});
|
||||
|
||||
return { items: cleanItems, total: totalResult[0].count };
|
||||
};
|
||||
|
||||
@@ -468,30 +539,67 @@ export class TopicModel {
|
||||
create = async (
|
||||
{ messages: messageIds, ...params }: CreateTopicParams,
|
||||
id: string = this.genId(),
|
||||
timing?: ModelTimingContext,
|
||||
): Promise<TopicItem> => {
|
||||
return this.db.transaction(async (tx) => {
|
||||
const insertData = {
|
||||
...params,
|
||||
agentId: params.agentId || null,
|
||||
groupId: params.groupId || null,
|
||||
id,
|
||||
sessionId: params.sessionId || null,
|
||||
userId: this.userId,
|
||||
};
|
||||
const insertData = {
|
||||
...params,
|
||||
agentId: params.agentId || null,
|
||||
groupId: params.groupId || null,
|
||||
id,
|
||||
sessionId: params.sessionId || null,
|
||||
userId: this.userId,
|
||||
};
|
||||
const insertMeta = {
|
||||
hasAgentId: !!params.agentId,
|
||||
hasGroupId: !!params.groupId,
|
||||
hasSessionId: !!params.sessionId,
|
||||
};
|
||||
|
||||
// Insert new topic
|
||||
const [topic] = await tx.insert(topics).values(insertData).returning();
|
||||
|
||||
// Update associated messages' topicId
|
||||
if (messageIds && messageIds.length > 0) {
|
||||
await tx
|
||||
.update(messages)
|
||||
.set({ topicId: topic.id })
|
||||
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
|
||||
}
|
||||
if (!messageIds || messageIds.length === 0) {
|
||||
const [topic] = await runTimedStage(
|
||||
timing,
|
||||
'db.topic.create.topics.insert',
|
||||
() => this.db.insert(topics).values(insertData).returning(),
|
||||
insertMeta,
|
||||
);
|
||||
|
||||
return topic;
|
||||
});
|
||||
}
|
||||
|
||||
return runTimedStage(
|
||||
timing,
|
||||
'db.topic.create.transaction',
|
||||
() =>
|
||||
this.db.transaction(async (tx) => {
|
||||
// Insert new topic
|
||||
const [topic] = await runTimedStage(
|
||||
timing,
|
||||
'db.topic.create.topics.insert',
|
||||
() => tx.insert(topics).values(insertData).returning(),
|
||||
insertMeta,
|
||||
);
|
||||
|
||||
// Update associated messages' topicId
|
||||
await runTimedStage(
|
||||
timing,
|
||||
'db.topic.create.messages.updateTopic',
|
||||
() =>
|
||||
tx
|
||||
.update(messages)
|
||||
.set({ topicId: topic.id })
|
||||
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))),
|
||||
{ messageCount: messageIds.length },
|
||||
);
|
||||
|
||||
return topic;
|
||||
}),
|
||||
{
|
||||
hasAgentId: !!params.agentId,
|
||||
hasGroupId: !!params.groupId,
|
||||
hasSessionId: !!params.sessionId,
|
||||
messageCount: messageIds?.length ?? 0,
|
||||
},
|
||||
);
|
||||
};
|
||||
|
||||
batchCreate = async (topicParams: (CreateTopicParams & { id?: string })[]) => {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { ModelUsage, TracePayload } from '@lobechat/types';
|
||||
import { createTimingHelpers, getDurationMs } from '@lobechat/utils';
|
||||
import type { ClientOptions } from 'openai';
|
||||
|
||||
import type { LobeBedrockAIParams } from '../providers/bedrock';
|
||||
@@ -32,6 +33,13 @@ import type {
|
||||
import { AgentRuntimeError } from '../utils/createError';
|
||||
import type { LobeRuntimeAI } from './BaseAI';
|
||||
|
||||
const { logger: timing } = createTimingHelpers('lobe-server:chat:lobehub:timing');
|
||||
|
||||
const getLobeHubTimingMetadata = (options?: {
|
||||
metadata?: Record<string, unknown>;
|
||||
}): Record<string, unknown> | undefined =>
|
||||
options?.metadata?.provider === 'lobehub' ? options.metadata : undefined;
|
||||
|
||||
export interface AgentChatOptions {
|
||||
enableTrace?: boolean;
|
||||
provider: string;
|
||||
@@ -126,6 +134,17 @@ export class ModelRuntime {
|
||||
* ```
|
||||
*/
|
||||
async chat(payload: ChatStreamPayload, options?: ChatMethodOptions) {
|
||||
const metadata = getLobeHubTimingMetadata(options);
|
||||
const startedAt = Date.now();
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.chat start model=%s trigger=%s traceId=%s',
|
||||
payload.model,
|
||||
metadata.trigger,
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
|
||||
if (typeof this._runtime.chat !== 'function') {
|
||||
throw AgentRuntimeError.chat({
|
||||
error: new Error('Chat is not supported by this provider'),
|
||||
@@ -135,11 +154,48 @@ export class ModelRuntime {
|
||||
}
|
||||
|
||||
try {
|
||||
const hooksStartedAt = Date.now();
|
||||
const finalOptions = await this.applyHooks(payload, options);
|
||||
return await this._runtime.chat(payload, finalOptions);
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.chat hooks done model=%s durationMs=%d traceId=%s',
|
||||
payload.model,
|
||||
getDurationMs(hooksStartedAt),
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
const runtimeStartedAt = Date.now();
|
||||
const response = await this._runtime.chat(payload, finalOptions);
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.chat runtime done model=%s durationMs=%d totalMs=%d traceId=%s',
|
||||
payload.model,
|
||||
getDurationMs(runtimeStartedAt),
|
||||
getDurationMs(startedAt),
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
return response;
|
||||
} catch (error) {
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.chat error model=%s durationMs=%d traceId=%s',
|
||||
payload.model,
|
||||
getDurationMs(startedAt),
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
if (this._hooks?.onChatError) {
|
||||
const errorHookStartedAt = Date.now();
|
||||
await this._hooks.onChatError(error as ChatCompletionErrorPayload, { options, payload });
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.chat onChatError done model=%s durationMs=%d traceId=%s',
|
||||
payload.model,
|
||||
getDurationMs(errorHookStartedAt),
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
@@ -152,7 +208,37 @@ export class ModelRuntime {
|
||||
payload: ChatStreamPayload,
|
||||
options?: ChatMethodOptions,
|
||||
): Promise<ChatMethodOptions | undefined> {
|
||||
await this._hooks?.beforeChat?.(payload, options);
|
||||
const metadata = getLobeHubTimingMetadata(options);
|
||||
const beforeChatStartedAt = Date.now();
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.beforeChat start model=%s trigger=%s traceId=%s',
|
||||
payload.model,
|
||||
metadata.trigger,
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
try {
|
||||
await this._hooks?.beforeChat?.(payload, options);
|
||||
} catch (error) {
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.beforeChat error model=%s durationMs=%d traceId=%s',
|
||||
payload.model,
|
||||
getDurationMs(beforeChatStartedAt),
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.beforeChat done model=%s durationMs=%d traceId=%s',
|
||||
payload.model,
|
||||
getDurationMs(beforeChatStartedAt),
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
|
||||
if (!this._hooks?.onChatFinal) return options;
|
||||
|
||||
@@ -163,10 +249,34 @@ export class ModelRuntime {
|
||||
callback: {
|
||||
...options?.callback,
|
||||
async onFinal(data) {
|
||||
const finalStartedAt = Date.now();
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.onChatFinal start model=%s traceId=%s',
|
||||
payload.model,
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
await existingOnFinal?.(data);
|
||||
try {
|
||||
await hookFn(data, { options, payload });
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.onChatFinal done model=%s durationMs=%d traceId=%s',
|
||||
payload.model,
|
||||
getDurationMs(finalStartedAt),
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
} catch (e) {
|
||||
if (metadata) {
|
||||
timing(
|
||||
'ModelRuntime.onChatFinal error model=%s durationMs=%d traceId=%s',
|
||||
payload.model,
|
||||
getDurationMs(finalStartedAt),
|
||||
metadata.traceId,
|
||||
);
|
||||
}
|
||||
// Hook failures (billing, tracing) must not interfere with response completion
|
||||
console.error('[ModelRuntime] onChatFinal hook error:', e);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
import type { GoogleGenAIOptions } from '@google/genai';
|
||||
import type { ChatModelCard } from '@lobechat/types';
|
||||
import { AgentRuntimeErrorType } from '@lobechat/types';
|
||||
import { createTimingHelpers, getDurationMs } from '@lobechat/utils';
|
||||
import debug from 'debug';
|
||||
import type { ClientOptions } from 'openai';
|
||||
import type OpenAI from 'openai';
|
||||
@@ -44,6 +45,7 @@ import type {
|
||||
import type { ApiType, RuntimeClass } from './apiTypes';
|
||||
|
||||
const log = debug('lobe-model-runtime:router-runtime');
|
||||
const { logger: timing } = createTimingHelpers('lobe-server:chat:lobehub:timing');
|
||||
|
||||
interface ProviderIniOptions extends Record<string, any> {
|
||||
accessKeyId?: string;
|
||||
@@ -190,6 +192,7 @@ export const createRouterRuntime = ({
|
||||
private _id: string;
|
||||
|
||||
constructor(options: ClientOptions & Record<string, any> = {}) {
|
||||
const startedAt = Date.now();
|
||||
this._options = {
|
||||
...options,
|
||||
apiKey: options.apiKey?.trim() || DEFAULT_API_KEY,
|
||||
@@ -200,36 +203,76 @@ export const createRouterRuntime = ({
|
||||
this._routers = routers;
|
||||
this._params = params;
|
||||
this._id = options.id ?? id;
|
||||
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'constructor done providerId=%s durationMs=%d hasApiKey=%s hasBaseURL=%s',
|
||||
this._id,
|
||||
getDurationMs(startedAt),
|
||||
!!this._options.apiKey,
|
||||
!!this._options.baseURL,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve routers configuration and validate
|
||||
*/
|
||||
private async resolveRouters(model?: string): Promise<RouterInstance[]> {
|
||||
const resolvedRouters =
|
||||
typeof this._routers === 'function'
|
||||
? await this._routers(this._options, { model })
|
||||
: this._routers;
|
||||
const startedAt = Date.now();
|
||||
try {
|
||||
const resolvedRouters =
|
||||
typeof this._routers === 'function'
|
||||
? await this._routers(this._options, { model })
|
||||
: this._routers;
|
||||
|
||||
if (resolvedRouters.length === 0) {
|
||||
throw AgentRuntimeError.chat({
|
||||
error: { message: 'empty providers' },
|
||||
errorType: AgentRuntimeErrorType.NoAvailableProvider,
|
||||
provider: this._id,
|
||||
});
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'resolveRouters done model=%s durationMs=%d routerCount=%d dynamic=%s',
|
||||
model,
|
||||
getDurationMs(startedAt),
|
||||
resolvedRouters.length,
|
||||
typeof this._routers === 'function',
|
||||
);
|
||||
}
|
||||
|
||||
if (resolvedRouters.length === 0) {
|
||||
throw AgentRuntimeError.chat({
|
||||
error: { message: 'empty providers' },
|
||||
errorType: AgentRuntimeErrorType.NoAvailableProvider,
|
||||
provider: this._id,
|
||||
});
|
||||
}
|
||||
|
||||
return resolvedRouters;
|
||||
} catch (error) {
|
||||
if (this._id === 'lobehub') {
|
||||
timing('resolveRouters error model=%s durationMs=%d', model, getDurationMs(startedAt));
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
|
||||
return resolvedRouters;
|
||||
}
|
||||
|
||||
private async resolveMatchedRouter(model: string): Promise<RouterInstance> {
|
||||
const startedAt = Date.now();
|
||||
const resolvedRouters = await this.resolveRouters(model);
|
||||
const baseURL = this._options.baseURL;
|
||||
|
||||
// Priority 1: Match by baseURLPattern (RegExp only)
|
||||
if (baseURL) {
|
||||
const baseURLMatch = resolvedRouters.find((router) => router.baseURLPattern?.test(baseURL));
|
||||
if (baseURLMatch) return baseURLMatch;
|
||||
if (baseURLMatch) {
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'resolveMatchedRouter done model=%s match=baseURL routerId=%s apiType=%s durationMs=%d',
|
||||
model,
|
||||
baseURLMatch.id,
|
||||
baseURLMatch.apiType,
|
||||
getDurationMs(startedAt),
|
||||
);
|
||||
}
|
||||
return baseURLMatch;
|
||||
}
|
||||
}
|
||||
|
||||
// Priority 2: Match by models
|
||||
@@ -239,19 +282,50 @@ export const createRouterRuntime = ({
|
||||
}
|
||||
return false;
|
||||
});
|
||||
if (modelMatch) return modelMatch;
|
||||
if (modelMatch) {
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'resolveMatchedRouter done model=%s match=models routerId=%s apiType=%s durationMs=%d',
|
||||
model,
|
||||
modelMatch.id,
|
||||
modelMatch.apiType,
|
||||
getDurationMs(startedAt),
|
||||
);
|
||||
}
|
||||
return modelMatch;
|
||||
}
|
||||
|
||||
// Fallback: Use the last router
|
||||
return resolvedRouters.at(-1)!;
|
||||
const fallbackRouter = resolvedRouters.at(-1)!;
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'resolveMatchedRouter done model=%s match=fallback routerId=%s apiType=%s durationMs=%d',
|
||||
model,
|
||||
fallbackRouter.id,
|
||||
fallbackRouter.apiType,
|
||||
getDurationMs(startedAt),
|
||||
);
|
||||
}
|
||||
return fallbackRouter;
|
||||
}
|
||||
|
||||
private normalizeRouterOptions(router: RouterInstance): RouterOptionItem[] {
|
||||
const startedAt = Date.now();
|
||||
const routerOptions = Array.isArray(router.options) ? router.options : [router.options];
|
||||
|
||||
if (routerOptions.length === 0 || routerOptions.some((optionItem) => !optionItem)) {
|
||||
throw new Error('empty provider options');
|
||||
}
|
||||
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'normalizeRouterOptions done routerId=%s options=%d durationMs=%d',
|
||||
router.id,
|
||||
routerOptions.length,
|
||||
getDurationMs(startedAt),
|
||||
);
|
||||
}
|
||||
|
||||
return routerOptions;
|
||||
}
|
||||
|
||||
@@ -268,6 +342,7 @@ export const createRouterRuntime = ({
|
||||
remark?: string;
|
||||
runtime: LobeRuntimeAI;
|
||||
}> {
|
||||
const startedAt = Date.now();
|
||||
const { apiType: optionApiType, id: channelId, remark, ...optionOverrides } = optionItem;
|
||||
const resolvedApiType = optionApiType ?? router.apiType;
|
||||
const finalOptions = {
|
||||
@@ -297,6 +372,16 @@ export const createRouterRuntime = ({
|
||||
if (project) vertexOptions.project = project;
|
||||
if (location) vertexOptions.location = location as GoogleGenAIOptions['location'];
|
||||
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'createRuntimeFromOption done routerId=%s channelId=%s apiType=%s durationMs=%d vertex=true',
|
||||
router.id,
|
||||
channelId,
|
||||
resolvedApiType,
|
||||
getDurationMs(startedAt),
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
channelId,
|
||||
id: resolvedApiType,
|
||||
@@ -312,6 +397,16 @@ export const createRouterRuntime = ({
|
||||
: (baseRuntimeMap[resolvedApiType] ?? LobeOpenAI);
|
||||
const runtime: LobeRuntimeAI = new providerAI({ ...finalOptions, id: this._id });
|
||||
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'createRuntimeFromOption done routerId=%s channelId=%s apiType=%s durationMs=%d',
|
||||
router.id,
|
||||
channelId,
|
||||
resolvedApiType,
|
||||
getDurationMs(startedAt),
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
channelId,
|
||||
id: resolvedApiType,
|
||||
@@ -325,10 +420,22 @@ export const createRouterRuntime = ({
|
||||
requestHandler: (runtime: LobeRuntimeAI) => Promise<T>,
|
||||
metadata?: Record<string, unknown>,
|
||||
): Promise<T> {
|
||||
const totalStartedAt = Date.now();
|
||||
const matchedRouter = await this.resolveMatchedRouter(model);
|
||||
const routerOptions = this.normalizeRouterOptions(matchedRouter);
|
||||
const totalOptions = routerOptions.length;
|
||||
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'runWithFallback start model=%s routerId=%s apiType=%s options=%d traceId=%s',
|
||||
model,
|
||||
matchedRouter.id,
|
||||
matchedRouter.apiType,
|
||||
totalOptions,
|
||||
metadata?.traceId,
|
||||
);
|
||||
}
|
||||
|
||||
log(
|
||||
'resolve router for model=%s apiType=%s options=%d',
|
||||
model,
|
||||
@@ -349,7 +456,33 @@ export const createRouterRuntime = ({
|
||||
} = await this.createRuntimeFromOption(matchedRouter, optionItem);
|
||||
|
||||
try {
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'attempt request start model=%s attempt=%d/%d routerId=%s channelId=%s apiType=%s traceId=%s',
|
||||
model,
|
||||
attempt,
|
||||
totalOptions,
|
||||
matchedRouter.id,
|
||||
channelId,
|
||||
resolvedApiType,
|
||||
metadata?.traceId,
|
||||
);
|
||||
}
|
||||
const result = await requestHandler(runtime);
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'attempt request success model=%s attempt=%d/%d routerId=%s channelId=%s apiType=%s durationMs=%d totalMs=%d traceId=%s',
|
||||
model,
|
||||
attempt,
|
||||
totalOptions,
|
||||
matchedRouter.id,
|
||||
channelId,
|
||||
resolvedApiType,
|
||||
getDurationMs(startTime),
|
||||
getDurationMs(totalStartedAt),
|
||||
metadata?.traceId,
|
||||
);
|
||||
}
|
||||
|
||||
if (totalOptions > 1 && attempt > 1) {
|
||||
log(
|
||||
@@ -392,6 +525,20 @@ export const createRouterRuntime = ({
|
||||
return result;
|
||||
} catch (error) {
|
||||
lastError = error;
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'attempt request error model=%s attempt=%d/%d routerId=%s channelId=%s apiType=%s durationMs=%d totalMs=%d traceId=%s',
|
||||
model,
|
||||
attempt,
|
||||
totalOptions,
|
||||
matchedRouter.id,
|
||||
channelId,
|
||||
resolvedApiType,
|
||||
getDurationMs(startTime),
|
||||
getDurationMs(totalStartedAt),
|
||||
metadata?.traceId,
|
||||
);
|
||||
}
|
||||
|
||||
params
|
||||
.onRouteAttempt?.({
|
||||
@@ -417,6 +564,7 @@ export const createRouterRuntime = ({
|
||||
}
|
||||
|
||||
try {
|
||||
const shouldStopStartedAt = Date.now();
|
||||
const shouldStopFallback = await params.shouldStopFallback?.({
|
||||
error,
|
||||
metadata,
|
||||
@@ -424,6 +572,18 @@ export const createRouterRuntime = ({
|
||||
optionIndex: index,
|
||||
});
|
||||
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'shouldStopFallback done model=%s attempt=%d/%d durationMs=%d shouldStop=%s traceId=%s',
|
||||
model,
|
||||
attempt,
|
||||
totalOptions,
|
||||
getDurationMs(shouldStopStartedAt),
|
||||
shouldStopFallback,
|
||||
metadata?.traceId,
|
||||
);
|
||||
}
|
||||
|
||||
if (shouldStopFallback) {
|
||||
throw error;
|
||||
}
|
||||
@@ -460,6 +620,17 @@ export const createRouterRuntime = ({
|
||||
}
|
||||
}
|
||||
|
||||
if (this._id === 'lobehub') {
|
||||
timing(
|
||||
'runWithFallback failed model=%s routerId=%s options=%d totalMs=%d traceId=%s',
|
||||
model,
|
||||
matchedRouter.id,
|
||||
totalOptions,
|
||||
getDurationMs(totalStartedAt),
|
||||
metadata?.traceId,
|
||||
);
|
||||
}
|
||||
|
||||
throw lastError ?? new Error('empty provider options');
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { AiSendMessageServerSchema } from './aiChat';
|
||||
|
||||
const createInput = (topicPageSize: number) => ({
|
||||
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
|
||||
newUserMessage: { content: 'hello' },
|
||||
topicPageSize,
|
||||
});
|
||||
|
||||
describe('AiSendMessageServerSchema', () => {
|
||||
it('should only accept positive integer topic page sizes up to 100', () => {
|
||||
for (const topicPageSize of [1, 20, 100]) {
|
||||
expect(AiSendMessageServerSchema.safeParse(createInput(topicPageSize)).success).toBe(true);
|
||||
}
|
||||
|
||||
for (const topicPageSize of [-1, 0, 1.5, 101]) {
|
||||
expect(AiSendMessageServerSchema.safeParse(createInput(topicPageSize)).success).toBe(false);
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -96,6 +96,10 @@ export interface SendMessageServerParams {
|
||||
};
|
||||
// if there is activeTopicId, then add topicId to message
|
||||
topicId?: string;
|
||||
/**
|
||||
* Page size for the topic list returned after creating a new topic.
|
||||
*/
|
||||
topicPageSize?: number;
|
||||
}
|
||||
|
||||
export const CreateThreadWithMessageSchema = z.object({
|
||||
@@ -156,6 +160,7 @@ export const AiSendMessageServerSchema = z.object({
|
||||
includeTriggers: z.array(z.string()).optional(),
|
||||
})
|
||||
.optional(),
|
||||
topicPageSize: z.number().int().min(1).max(100).optional(),
|
||||
topicId: z.string().optional(),
|
||||
});
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ export * from './pricing';
|
||||
export * from './safeParseJSON';
|
||||
export * from './sanitizeToolCallArguments';
|
||||
export * from './sleep';
|
||||
export * from './timing';
|
||||
export * from './uriParser';
|
||||
export * from './url';
|
||||
export * from './uuid';
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
import debug from 'debug';
|
||||
|
||||
export interface TimingContext {
|
||||
requestId: string;
|
||||
startedAt: number;
|
||||
}
|
||||
|
||||
export interface TimingMetadata {
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
export interface TimingParams {
|
||||
timingRequestId?: string;
|
||||
timingStartedAt?: number;
|
||||
}
|
||||
|
||||
export interface TimingSink {
|
||||
log: (event: string, metadata?: TimingMetadata) => void;
|
||||
}
|
||||
|
||||
export type TimingLogger = (formatter: string, ...args: unknown[]) => void;
|
||||
|
||||
export const createDebugTimingLogger = (namespace: string): TimingLogger => debug(namespace);
|
||||
|
||||
export const getDurationMs = (startedAt: number) => Date.now() - startedAt;
|
||||
|
||||
export const createTimingRequestId = () =>
|
||||
globalThis.crypto?.randomUUID?.() ??
|
||||
`${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`;
|
||||
|
||||
const isRecord = (value: unknown): value is Record<string, unknown> =>
|
||||
!!value && typeof value === 'object';
|
||||
|
||||
export const getTimingErrorMetadata = (error: unknown): TimingMetadata => {
|
||||
if (error instanceof Error) {
|
||||
return {
|
||||
errorMessage: error.message,
|
||||
errorName: error.name,
|
||||
};
|
||||
}
|
||||
|
||||
if (isRecord(error)) {
|
||||
return {
|
||||
errorType: typeof error.errorType === 'string' ? error.errorType : undefined,
|
||||
status: typeof error.status === 'number' ? error.status : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return { errorMessage: String(error) };
|
||||
};
|
||||
|
||||
export const toTimingContext = (params?: TimingParams): TimingContext | undefined =>
|
||||
params?.timingRequestId
|
||||
? { requestId: params.timingRequestId, startedAt: params.timingStartedAt ?? Date.now() }
|
||||
: undefined;
|
||||
|
||||
export const logTiming = (
|
||||
logger: TimingLogger,
|
||||
context: TimingContext | undefined,
|
||||
event: string,
|
||||
metadata?: TimingMetadata,
|
||||
) => {
|
||||
if (!context) return;
|
||||
|
||||
const totalMs = getDurationMs(context.startedAt);
|
||||
if (metadata) {
|
||||
logger('[%s] %s totalMs=%d %O', context.requestId, event, totalMs, metadata);
|
||||
return;
|
||||
}
|
||||
|
||||
logger('[%s] %s totalMs=%d', context.requestId, event, totalMs);
|
||||
};
|
||||
|
||||
export const logTimingSink = (
|
||||
timing: TimingSink | undefined,
|
||||
event: string,
|
||||
metadata?: TimingMetadata,
|
||||
) => {
|
||||
timing?.log(event, metadata);
|
||||
};
|
||||
|
||||
export const runTimedStage = async <T>(
|
||||
logger: TimingLogger,
|
||||
context: TimingContext | undefined,
|
||||
stage: string,
|
||||
task: () => T | Promise<T>,
|
||||
metadata?: TimingMetadata,
|
||||
): Promise<Awaited<T>> => {
|
||||
if (!context) return await task();
|
||||
|
||||
const startedAt = Date.now();
|
||||
logTiming(logger, context, `${stage}:start`, metadata);
|
||||
|
||||
try {
|
||||
const result = await task();
|
||||
logTiming(logger, context, `${stage}:done`, {
|
||||
...metadata,
|
||||
stageMs: getDurationMs(startedAt),
|
||||
});
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logTiming(logger, context, `${stage}:error`, {
|
||||
...metadata,
|
||||
...getTimingErrorMetadata(error),
|
||||
stageMs: getDurationMs(startedAt),
|
||||
});
|
||||
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const runTimedSinkStage = async <T>(
|
||||
timing: TimingSink | undefined,
|
||||
stage: string,
|
||||
task: () => T | Promise<T>,
|
||||
metadata?: TimingMetadata,
|
||||
): Promise<Awaited<T>> => {
|
||||
if (!timing) return await task();
|
||||
|
||||
const startedAt = Date.now();
|
||||
logTimingSink(timing, `${stage}:start`, metadata);
|
||||
|
||||
try {
|
||||
const result = await task();
|
||||
logTimingSink(timing, `${stage}:done`, {
|
||||
...metadata,
|
||||
stageMs: getDurationMs(startedAt),
|
||||
});
|
||||
|
||||
return result;
|
||||
} catch (error) {
|
||||
logTimingSink(timing, `${stage}:error`, {
|
||||
...metadata,
|
||||
...getTimingErrorMetadata(error),
|
||||
stageMs: getDurationMs(startedAt),
|
||||
});
|
||||
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const createPrefixedTimingContext = (
|
||||
logger: TimingLogger,
|
||||
context: TimingContext | undefined,
|
||||
prefix: string,
|
||||
): TimingSink | undefined =>
|
||||
context
|
||||
? {
|
||||
log: (event: string, metadata?: TimingMetadata) => {
|
||||
logTiming(logger, context, `${prefix}.${event}`, metadata);
|
||||
},
|
||||
}
|
||||
: undefined;
|
||||
|
||||
export const createTimingHelpers = (namespace: string) => {
|
||||
const logger = createDebugTimingLogger(namespace);
|
||||
|
||||
return {
|
||||
createPrefixedTimingContext: (context: TimingContext | undefined, prefix: string) =>
|
||||
createPrefixedTimingContext(logger, context, prefix),
|
||||
logger,
|
||||
logTiming: (context: TimingContext | undefined, event: string, metadata?: TimingMetadata) =>
|
||||
logTiming(logger, context, event, metadata),
|
||||
runTimedStage: <T>(
|
||||
context: TimingContext | undefined,
|
||||
stage: string,
|
||||
task: () => T | Promise<T>,
|
||||
metadata?: TimingMetadata,
|
||||
) => runTimedStage(logger, context, stage, task, metadata),
|
||||
toTimingContext,
|
||||
};
|
||||
};
|
||||
@@ -1,4 +1,5 @@
|
||||
// @vitest-environment node
|
||||
import type { CreateMessageParams } from '@lobechat/types';
|
||||
import { ThreadType } from '@lobechat/types';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
@@ -10,6 +11,8 @@ import { AiChatService } from '@/server/services/aiChat';
|
||||
|
||||
import { aiChatRouter } from '../aiChat';
|
||||
|
||||
const flushAsyncTasks = () => new Promise<void>((resolve) => setTimeout(resolve, 0));
|
||||
|
||||
vi.mock('@/database/models/agent');
|
||||
vi.mock('@/database/models/message');
|
||||
vi.mock('@/database/models/thread');
|
||||
@@ -24,6 +27,38 @@ vi.mock('@/server/modules/ModelRuntime', () => ({
|
||||
|
||||
describe('aiChatRouter', () => {
|
||||
const mockCtx = { userId: 'u1' };
|
||||
const mockMessageModel = (mockCreateMessage: ReturnType<typeof vi.fn>) => {
|
||||
const mockCreateUserAndAssistantMessages = vi.fn(
|
||||
async (
|
||||
{
|
||||
assistantMessage,
|
||||
userMessage,
|
||||
}: {
|
||||
assistantMessage: CreateMessageParams;
|
||||
userMessage: CreateMessageParams;
|
||||
},
|
||||
_options?: unknown,
|
||||
) => {
|
||||
const userMessageItem = await mockCreateMessage(userMessage);
|
||||
const assistantMessageItem = await mockCreateMessage({
|
||||
...assistantMessage,
|
||||
parentId: userMessageItem.id,
|
||||
});
|
||||
|
||||
return { assistantMessage: assistantMessageItem, userMessage: userMessageItem };
|
||||
},
|
||||
);
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(
|
||||
() =>
|
||||
({
|
||||
create: mockCreateMessage,
|
||||
createUserAndAssistantMessages: mockCreateUserAndAssistantMessages,
|
||||
}) as any,
|
||||
);
|
||||
|
||||
return mockCreateUserAndAssistantMessages;
|
||||
};
|
||||
|
||||
it('should create topic optionally, create user/assistant messages, and return payload', async () => {
|
||||
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
|
||||
@@ -37,7 +72,7 @@ describe('aiChatRouter', () => {
|
||||
});
|
||||
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -47,6 +82,7 @@ describe('aiChatRouter', () => {
|
||||
newTopic: { title: 'T', topicMessageIds: ['a', 'b'] },
|
||||
newUserMessage: { content: 'hi', files: ['f1'] },
|
||||
sessionId: 's1',
|
||||
topicPageSize: 20,
|
||||
} as any;
|
||||
|
||||
const res = await caller.sendMessageInServer(input);
|
||||
@@ -79,9 +115,19 @@ describe('aiChatRouter', () => {
|
||||
topicId: 't1',
|
||||
}),
|
||||
);
|
||||
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledTimes(1);
|
||||
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({ touchTopicUpdatedAt: false }),
|
||||
);
|
||||
|
||||
expect(mockGet).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ includeTopic: true, sessionId: 's1', topicId: 't1' }),
|
||||
expect.objectContaining({
|
||||
includeTopic: true,
|
||||
sessionId: 's1',
|
||||
topicId: 't1',
|
||||
topicPageSize: 20,
|
||||
}),
|
||||
);
|
||||
expect(res.assistantMessageId).toBe('m-assistant');
|
||||
expect(res.userMessageId).toBe('m-user');
|
||||
@@ -99,7 +145,7 @@ describe('aiChatRouter', () => {
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -112,6 +158,10 @@ describe('aiChatRouter', () => {
|
||||
} as any);
|
||||
|
||||
expect(mockCreateMessage).toHaveBeenCalled();
|
||||
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
expect.objectContaining({ touchTopicUpdatedAt: true }),
|
||||
);
|
||||
expect(mockGet).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
includeTopic: false,
|
||||
@@ -130,7 +180,7 @@ describe('aiChatRouter', () => {
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -175,7 +225,7 @@ describe('aiChatRouter', () => {
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -282,7 +332,7 @@ describe('aiChatRouter', () => {
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(ThreadModel).mockImplementation(() => ({ create: mockCreateThread }) as any);
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -346,7 +396,7 @@ describe('aiChatRouter', () => {
|
||||
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
vi.mocked(ThreadModel).mockImplementation(() => ({ create: mockCreateThread }) as any);
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -402,7 +452,7 @@ describe('aiChatRouter', () => {
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -427,7 +477,7 @@ describe('aiChatRouter', () => {
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
|
||||
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -459,7 +509,7 @@ describe('aiChatRouter', () => {
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
|
||||
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -489,7 +539,7 @@ describe('aiChatRouter', () => {
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -537,7 +587,7 @@ describe('aiChatRouter', () => {
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -569,7 +619,7 @@ describe('aiChatRouter', () => {
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -621,7 +671,7 @@ describe('aiChatRouter', () => {
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
@@ -677,7 +727,7 @@ describe('aiChatRouter', () => {
|
||||
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
vi.mocked(AgentModel).mockImplementation(
|
||||
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
|
||||
@@ -713,7 +763,7 @@ describe('aiChatRouter', () => {
|
||||
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
vi.mocked(AgentModel).mockImplementation(
|
||||
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
|
||||
@@ -733,6 +783,94 @@ describe('aiChatRouter', () => {
|
||||
expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1');
|
||||
});
|
||||
|
||||
it('should keep the message response when agent updatedAt touch fails', async () => {
|
||||
const consoleErrorSpy = vi.spyOn(console, 'error').mockImplementation(() => undefined);
|
||||
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
|
||||
const mockCreateMessage = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ id: 'm-user' })
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({
|
||||
messages: [{ id: 'm-user' }, { id: 'm-assistant' }],
|
||||
topics: undefined,
|
||||
});
|
||||
const touchError = new Error('touch failed');
|
||||
const mockTouchUpdatedAt = vi.fn().mockRejectedValue(touchError);
|
||||
|
||||
try {
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(
|
||||
() => ({ getMessagesAndTopics: mockGet }) as any,
|
||||
);
|
||||
vi.mocked(AgentModel).mockImplementation(
|
||||
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
|
||||
);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
|
||||
const res = await caller.sendMessageInServer({
|
||||
agentId: 'agent-1',
|
||||
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
|
||||
newTopic: { title: 'New Topic' },
|
||||
newUserMessage: { content: 'hi' },
|
||||
sessionId: 's1',
|
||||
} as any);
|
||||
|
||||
expect(res.userMessageId).toBe('m-user');
|
||||
expect(res.assistantMessageId).toBe('m-assistant');
|
||||
expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1');
|
||||
expect(consoleErrorSpy).toHaveBeenCalledWith(
|
||||
'[aiChat] Failed to touch agent updatedAt:',
|
||||
touchError,
|
||||
);
|
||||
} finally {
|
||||
consoleErrorSpy.mockRestore();
|
||||
}
|
||||
});
|
||||
|
||||
it('should create messages while agent updatedAt touch is still pending', async () => {
|
||||
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
|
||||
const mockCreateMessage = vi
|
||||
.fn()
|
||||
.mockResolvedValueOnce({ id: 'm-user' })
|
||||
.mockResolvedValueOnce({ id: 'm-assistant' });
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
|
||||
let resolveTouchUpdatedAt: () => void = () => {};
|
||||
const touchUpdatedAtPromise = new Promise<void>((resolve) => {
|
||||
resolveTouchUpdatedAt = resolve;
|
||||
});
|
||||
const mockTouchUpdatedAt = vi.fn(() => touchUpdatedAtPromise);
|
||||
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
vi.mocked(AgentModel).mockImplementation(
|
||||
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
|
||||
);
|
||||
|
||||
const caller = aiChatRouter.createCaller(mockCtx as any);
|
||||
|
||||
const request = caller.sendMessageInServer({
|
||||
agentId: 'agent-1',
|
||||
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
|
||||
newTopic: { title: 'New Topic' },
|
||||
newUserMessage: { content: 'hi' },
|
||||
sessionId: 's1',
|
||||
} as any);
|
||||
|
||||
await flushAsyncTasks();
|
||||
|
||||
try {
|
||||
expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1');
|
||||
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledTimes(1);
|
||||
} finally {
|
||||
resolveTouchUpdatedAt();
|
||||
}
|
||||
|
||||
await request;
|
||||
});
|
||||
|
||||
it('should not touch agent updatedAt when creating topic without agentId', async () => {
|
||||
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
|
||||
const mockCreateMessage = vi
|
||||
@@ -743,7 +881,7 @@ describe('aiChatRouter', () => {
|
||||
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
vi.mocked(AgentModel).mockImplementation(
|
||||
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
|
||||
@@ -771,7 +909,7 @@ describe('aiChatRouter', () => {
|
||||
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
|
||||
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
vi.mocked(MessageModel).mockImplementation(() => ({ create: mockCreateMessage }) as any);
|
||||
mockMessageModel(mockCreateMessage);
|
||||
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
|
||||
vi.mocked(AgentModel).mockImplementation(
|
||||
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { type CreateMessageParams, type SendMessageServerResponse } from '@lobechat/types';
|
||||
import type { CreateMessageParams, SendMessageServerResponse } from '@lobechat/types';
|
||||
import { AiSendMessageServerSchema, RequestTrigger, StructureOutputSchema } from '@lobechat/types';
|
||||
import { createTimingHelpers, createTimingRequestId } from '@lobechat/utils';
|
||||
import debug from 'debug';
|
||||
|
||||
import { LOADING_FLAT } from '@/const/message';
|
||||
@@ -15,6 +16,9 @@ import { AiChatService } from '@/server/services/aiChat';
|
||||
import { FileService } from '@/server/services/file';
|
||||
|
||||
const log = debug('lobe-lambda-router:ai-chat');
|
||||
const { createPrefixedTimingContext, logTiming, runTimedStage } = createTimingHelpers(
|
||||
'lobe-server:chat:lobehub:timing',
|
||||
);
|
||||
|
||||
const aiChatProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
@@ -59,6 +63,17 @@ export const aiChatRouter = router({
|
||||
sendMessageInServer: aiChatProcedure
|
||||
.input(AiSendMessageServerSchema)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const timingContext =
|
||||
input.newAssistantMessage.provider === 'lobehub'
|
||||
? { requestId: createTimingRequestId(), startedAt: Date.now() }
|
||||
: undefined;
|
||||
logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:start', {
|
||||
hasNewThread: !!input.newThread,
|
||||
hasNewTopic: !!input.newTopic,
|
||||
hasSessionId: !!input.sessionId,
|
||||
hasTopicId: !!input.topicId,
|
||||
preloadCount: input.preloadMessages?.length ?? 0,
|
||||
});
|
||||
log('sendMessageInServer called for agentId: %s', input.agentId);
|
||||
log(
|
||||
'topicId: %s, newTopic: %O, newThread: %O',
|
||||
@@ -68,7 +83,12 @@ export const aiChatRouter = router({
|
||||
);
|
||||
let sessionId = input.sessionId;
|
||||
if (!sessionId) {
|
||||
const context = await resolveContext(input, ctx.serverDB, ctx.userId);
|
||||
const context = await runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.resolveContext',
|
||||
() => resolveContext(input, ctx.serverDB, ctx.userId),
|
||||
{ hasAgentId: !!input.agentId },
|
||||
);
|
||||
if (!!context.sessionId) sessionId = context.sessionId;
|
||||
}
|
||||
|
||||
@@ -77,27 +97,54 @@ export const aiChatRouter = router({
|
||||
let createdThreadId: string | undefined;
|
||||
|
||||
let isCreateNewTopic = false;
|
||||
let agentTouchUpdatedAtTask: Promise<void> | undefined;
|
||||
|
||||
// create topic if there should be a new topic
|
||||
if (input.newTopic) {
|
||||
log('creating new topic with title: %s', input.newTopic.title);
|
||||
const topicItem = await ctx.topicModel.create({
|
||||
agentId: input.agentId,
|
||||
groupId: input.groupId,
|
||||
messages: input.newTopic.topicMessageIds,
|
||||
metadata: input.newTopic.metadata,
|
||||
sessionId,
|
||||
title: input.newTopic.title,
|
||||
trigger: input.newTopic.trigger,
|
||||
});
|
||||
const topicItem = await runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.topic.create',
|
||||
() => {
|
||||
const payload = {
|
||||
agentId: input.agentId,
|
||||
groupId: input.groupId,
|
||||
messages: input.newTopic!.topicMessageIds,
|
||||
metadata: input.newTopic!.metadata,
|
||||
sessionId,
|
||||
title: input.newTopic!.title,
|
||||
trigger: input.newTopic!.trigger,
|
||||
};
|
||||
const modelTiming = createPrefixedTimingContext(
|
||||
timingContext,
|
||||
'lambda.aiChat.topic.create',
|
||||
);
|
||||
return modelTiming
|
||||
? ctx.topicModel.create(payload, undefined, modelTiming)
|
||||
: ctx.topicModel.create(payload);
|
||||
},
|
||||
{
|
||||
messageCount: input.newTopic.topicMessageIds?.length ?? 0,
|
||||
trigger: input.newTopic.trigger,
|
||||
},
|
||||
);
|
||||
topicId = topicItem.id;
|
||||
isCreateNewTopic = true;
|
||||
log('new topic created with id: %s', topicId);
|
||||
|
||||
// update agent's updatedAt to reflect new activity
|
||||
if (input.agentId) {
|
||||
await ctx.agentModel.touchUpdatedAt(input.agentId);
|
||||
log('agent updatedAt touched for agentId: %s', input.agentId);
|
||||
agentTouchUpdatedAtTask = runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.agent.touchUpdatedAt',
|
||||
async () => {
|
||||
await ctx.agentModel.touchUpdatedAt(input.agentId!);
|
||||
},
|
||||
{ hasAgentId: true },
|
||||
).catch((error) => {
|
||||
console.error('[aiChat] Failed to touch agent updatedAt:', error);
|
||||
});
|
||||
log('agent updatedAt touch scheduled for agentId: %s', input.agentId);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,13 +155,19 @@ export const aiChatRouter = router({
|
||||
input.newThread.sourceMessageId,
|
||||
input.newThread.type,
|
||||
);
|
||||
const threadItem = await ctx.threadModel.create({
|
||||
parentThreadId: input.newThread.parentThreadId,
|
||||
sourceMessageId: input.newThread.sourceMessageId,
|
||||
title: input.newThread.title,
|
||||
topicId,
|
||||
type: input.newThread.type,
|
||||
});
|
||||
const threadItem = await runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.thread.create',
|
||||
() =>
|
||||
ctx.threadModel.create({
|
||||
parentThreadId: input.newThread!.parentThreadId,
|
||||
sourceMessageId: input.newThread!.sourceMessageId,
|
||||
title: input.newThread!.title,
|
||||
topicId,
|
||||
type: input.newThread!.type,
|
||||
}),
|
||||
{ threadType: input.newThread.type },
|
||||
);
|
||||
if (threadItem) {
|
||||
threadId = threadItem.id;
|
||||
createdThreadId = threadItem.id;
|
||||
@@ -127,24 +180,40 @@ export const aiChatRouter = router({
|
||||
if (input.preloadMessages?.length) {
|
||||
log('creating %d preload messages before user message', input.preloadMessages.length);
|
||||
|
||||
for (const preloadMessage of input.preloadMessages) {
|
||||
const preloadItem = await ctx.messageModel.create({
|
||||
agentId: input.agentId,
|
||||
content: preloadMessage.content,
|
||||
groupId: input.groupId,
|
||||
metadata: preloadMessage.metadata,
|
||||
parentId,
|
||||
plugin: preloadMessage.plugin as CreateMessageParams['plugin'],
|
||||
role: preloadMessage.role,
|
||||
sessionId,
|
||||
threadId,
|
||||
tool_call_id: preloadMessage.tool_call_id,
|
||||
tools: preloadMessage.tools as CreateMessageParams['tools'],
|
||||
topicId,
|
||||
});
|
||||
parentId = await runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.preloadMessages.create',
|
||||
async () => {
|
||||
let latestParentId = parentId;
|
||||
for (const preloadMessage of input.preloadMessages!) {
|
||||
const payload = {
|
||||
agentId: input.agentId,
|
||||
content: preloadMessage.content,
|
||||
groupId: input.groupId,
|
||||
metadata: preloadMessage.metadata,
|
||||
parentId: latestParentId,
|
||||
plugin: preloadMessage.plugin as CreateMessageParams['plugin'],
|
||||
role: preloadMessage.role,
|
||||
sessionId,
|
||||
threadId,
|
||||
tool_call_id: preloadMessage.tool_call_id,
|
||||
tools: preloadMessage.tools as CreateMessageParams['tools'],
|
||||
topicId,
|
||||
};
|
||||
const modelTiming = createPrefixedTimingContext(
|
||||
timingContext,
|
||||
'lambda.aiChat.preloadMessages.create',
|
||||
);
|
||||
const preloadItem = await (modelTiming
|
||||
? ctx.messageModel.create(payload, undefined, modelTiming)
|
||||
: ctx.messageModel.create(payload));
|
||||
|
||||
parentId = preloadItem.id;
|
||||
}
|
||||
latestParentId = preloadItem.id;
|
||||
}
|
||||
return latestParentId;
|
||||
},
|
||||
{ count: input.preloadMessages.length },
|
||||
);
|
||||
}
|
||||
|
||||
// create user message
|
||||
@@ -161,58 +230,95 @@ export const aiChatRouter = router({
|
||||
}
|
||||
: undefined;
|
||||
|
||||
const userMessageItem = await ctx.messageModel.create({
|
||||
agentId: input.agentId,
|
||||
content: input.newUserMessage.content,
|
||||
editorData: input.newUserMessage.editorData,
|
||||
files: input.newUserMessage.files,
|
||||
groupId: input.groupId,
|
||||
metadata: userMessageMetadata,
|
||||
parentId,
|
||||
role: 'user',
|
||||
sessionId,
|
||||
threadId,
|
||||
topicId,
|
||||
});
|
||||
const createMessagePairPromise = runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.messages.createUserAndAssistant',
|
||||
() => {
|
||||
const userMessage = {
|
||||
agentId: input.agentId,
|
||||
content: input.newUserMessage.content,
|
||||
editorData: input.newUserMessage.editorData,
|
||||
files: input.newUserMessage.files,
|
||||
groupId: input.groupId,
|
||||
metadata: userMessageMetadata,
|
||||
parentId,
|
||||
role: 'user',
|
||||
sessionId,
|
||||
threadId,
|
||||
topicId,
|
||||
} satisfies CreateMessageParams;
|
||||
const assistantMessage = {
|
||||
agentId: input.agentId,
|
||||
content: LOADING_FLAT,
|
||||
groupId: input.groupId,
|
||||
metadata: input.newAssistantMessage.metadata,
|
||||
model: input.newAssistantMessage.model,
|
||||
provider: input.newAssistantMessage.provider,
|
||||
role: 'assistant',
|
||||
sessionId,
|
||||
threadId,
|
||||
topicId,
|
||||
} satisfies CreateMessageParams;
|
||||
const modelTiming = createPrefixedTimingContext(
|
||||
timingContext,
|
||||
'lambda.aiChat.messages.createUserAndAssistant',
|
||||
);
|
||||
return ctx.messageModel.createUserAndAssistantMessages(
|
||||
{ assistantMessage, userMessage },
|
||||
{
|
||||
...(modelTiming ? { timing: modelTiming } : {}),
|
||||
touchTopicUpdatedAt: !isCreateNewTopic,
|
||||
},
|
||||
);
|
||||
},
|
||||
{
|
||||
contentLength: input.newUserMessage.content.length,
|
||||
fileCount: input.newUserMessage.files?.length ?? 0,
|
||||
model: input.newAssistantMessage.model,
|
||||
provider: input.newAssistantMessage.provider,
|
||||
},
|
||||
);
|
||||
const { assistantMessage: assistantMessageItem, userMessage: userMessageItem } =
|
||||
agentTouchUpdatedAtTask
|
||||
? (await Promise.all([createMessagePairPromise, agentTouchUpdatedAtTask]))[0]
|
||||
: await createMessagePairPromise;
|
||||
|
||||
const messageId = userMessageItem.id;
|
||||
log('user message created with id: %s', messageId);
|
||||
|
||||
// create assistant message
|
||||
log(
|
||||
'creating assistant message with model: %s, provider: %s, metadata: %O',
|
||||
input.newAssistantMessage.model,
|
||||
input.newAssistantMessage.provider,
|
||||
input.newAssistantMessage.metadata,
|
||||
);
|
||||
const assistantMessageItem = await ctx.messageModel.create({
|
||||
agentId: input.agentId,
|
||||
content: LOADING_FLAT,
|
||||
groupId: input.groupId,
|
||||
metadata: input.newAssistantMessage.metadata,
|
||||
model: input.newAssistantMessage.model,
|
||||
parentId: messageId,
|
||||
provider: input.newAssistantMessage.provider,
|
||||
role: 'assistant',
|
||||
sessionId,
|
||||
threadId,
|
||||
topicId,
|
||||
});
|
||||
log('assistant message created with id: %s', assistantMessageItem.id);
|
||||
|
||||
// retrieve latest messages and topic with
|
||||
log('retrieving messages and topics');
|
||||
const { messages, topics } = await ctx.aiChatService.getMessagesAndTopics({
|
||||
agentId: input.agentId,
|
||||
groupId: input.groupId,
|
||||
includeTopic: isCreateNewTopic,
|
||||
sessionId,
|
||||
threadId,
|
||||
topicFilter: input.topicFilter,
|
||||
topicId,
|
||||
});
|
||||
const { messages, topics } = await runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.messagesAndTopics.query',
|
||||
() =>
|
||||
ctx.aiChatService.getMessagesAndTopics({
|
||||
agentId: input.agentId,
|
||||
groupId: input.groupId,
|
||||
includeTopic: isCreateNewTopic,
|
||||
sessionId,
|
||||
threadId,
|
||||
topicFilter: input.topicFilter,
|
||||
topicId,
|
||||
topicPageSize: input.topicPageSize,
|
||||
...(timingContext
|
||||
? {
|
||||
timingRequestId: timingContext.requestId,
|
||||
timingStartedAt: timingContext.startedAt,
|
||||
}
|
||||
: {}),
|
||||
}),
|
||||
{ includeTopic: isCreateNewTopic },
|
||||
);
|
||||
|
||||
log('retrieved %d messages, %d topics', messages.length, topics?.items?.length ?? 0);
|
||||
logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:done', {
|
||||
isCreateNewTopic,
|
||||
messageCount: messages.length,
|
||||
topicCount: topics?.items?.length ?? 0,
|
||||
});
|
||||
|
||||
return {
|
||||
assistantMessageId: assistantMessageItem.id,
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
UpdateMessagePluginSchema,
|
||||
UpdateMessageRAGParamsSchema,
|
||||
} from '@lobechat/types';
|
||||
import { createTimingHelpers, createTimingRequestId } from '@lobechat/utils';
|
||||
import { TRPCError } from '@trpc/server';
|
||||
import { z } from 'zod';
|
||||
|
||||
@@ -18,6 +19,8 @@ import { MessageService } from '@/server/services/message';
|
||||
import { resolveAgentIdFromSession, resolveContext } from './_helpers/resolveContext';
|
||||
import { basicContextSchema } from './_schema/context';
|
||||
|
||||
const { logTiming, runTimedStage } = createTimingHelpers('lobe-server:chat:lobehub:timing');
|
||||
|
||||
const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
@@ -316,9 +319,37 @@ export const messageRouter = router({
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
const { id, value, agentId, ...options } = input;
|
||||
const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId);
|
||||
const timingContext = { requestId: createTimingRequestId(), startedAt: Date.now() };
|
||||
logTiming(timingContext, 'lambda.message.update:start', {
|
||||
hasAgentId: !!agentId,
|
||||
hasTopicId: !!options.topicId,
|
||||
valueKeys: Object.keys(value ?? {}),
|
||||
});
|
||||
|
||||
return ctx.messageService.updateMessage(id, value as any, resolved);
|
||||
const resolved = await runTimedStage(
|
||||
timingContext,
|
||||
'lambda.message.update.resolveContext',
|
||||
() => resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId),
|
||||
{ hasAgentId: !!agentId },
|
||||
);
|
||||
|
||||
const result = await runTimedStage(
|
||||
timingContext,
|
||||
'lambda.message.update.service',
|
||||
() =>
|
||||
ctx.messageService.updateMessage(id, value as any, {
|
||||
...resolved,
|
||||
timingRequestId: timingContext.requestId,
|
||||
timingStartedAt: timingContext.startedAt,
|
||||
}),
|
||||
{ hasResolvedTopicId: !!resolved.topicId },
|
||||
);
|
||||
|
||||
logTiming(timingContext, 'lambda.message.update:done', {
|
||||
messageCount: result.messages?.length ?? 0,
|
||||
success: result.success,
|
||||
});
|
||||
return result;
|
||||
}),
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type LobeChatDatabase } from '@lobechat/database';
|
||||
import type { LobeChatDatabase } from '@lobechat/database';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
@@ -31,13 +31,18 @@ describe('AiChatService', () => {
|
||||
groupId: 'group-1',
|
||||
includeTopic: true,
|
||||
sessionId: 's1',
|
||||
topicPageSize: 20,
|
||||
});
|
||||
|
||||
expect(mockQueryMessages).toHaveBeenCalledWith(
|
||||
{ agentId: 'agent-1', groupId: 'group-1', includeTopic: true, sessionId: 's1' },
|
||||
expect.objectContaining({ postProcessUrl: expect.any(Function) }),
|
||||
);
|
||||
expect(mockQueryTopics).toHaveBeenCalledWith({ agentId: 'agent-1', groupId: 'group-1' });
|
||||
expect(mockQueryTopics).toHaveBeenCalledWith({
|
||||
agentId: 'agent-1',
|
||||
groupId: 'group-1',
|
||||
pageSize: 20,
|
||||
});
|
||||
expect(res.messages).toEqual([{ id: 'm1' }]);
|
||||
expect(res.topics).toEqual([{ id: 't1' }]);
|
||||
});
|
||||
@@ -63,6 +68,7 @@ describe('AiChatService', () => {
|
||||
excludeStatuses: ['completed'],
|
||||
excludeTriggers: ['cron', 'eval'],
|
||||
},
|
||||
topicPageSize: 20,
|
||||
});
|
||||
|
||||
expect(mockQueryTopics).toHaveBeenCalledWith({
|
||||
@@ -70,12 +76,17 @@ describe('AiChatService', () => {
|
||||
excludeStatuses: ['completed'],
|
||||
excludeTriggers: ['cron', 'eval'],
|
||||
groupId: undefined,
|
||||
pageSize: 20,
|
||||
});
|
||||
// topicFilter must not leak into messageModel.query
|
||||
expect(mockQueryMessages).toHaveBeenCalledWith(
|
||||
expect.not.objectContaining({ topicFilter: expect.anything() }),
|
||||
expect.objectContaining({ postProcessUrl: expect.any(Function) }),
|
||||
);
|
||||
expect(mockQueryMessages).toHaveBeenCalledWith(
|
||||
expect.not.objectContaining({ topicPageSize: 20 }),
|
||||
expect.objectContaining({ postProcessUrl: expect.any(Function) }),
|
||||
);
|
||||
});
|
||||
|
||||
it('getMessagesAndTopics should not query topics when includeTopic is false', async () => {
|
||||
|
||||
@@ -1,9 +1,33 @@
|
||||
import { type LobeChatDatabase } from '@lobechat/database';
|
||||
import type { LobeChatDatabase } from '@lobechat/database';
|
||||
import { createTimingHelpers } from '@lobechat/utils';
|
||||
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
import { TopicModel } from '@/database/models/topic';
|
||||
import { FileService } from '@/server/services/file';
|
||||
|
||||
const { createPrefixedTimingContext, runTimedStage, toTimingContext } = createTimingHelpers(
|
||||
'lobe-server:chat:lobehub:timing',
|
||||
);
|
||||
|
||||
interface GetMessagesAndTopicsParams {
|
||||
agentId?: string;
|
||||
current?: number;
|
||||
groupId?: string;
|
||||
includeTopic?: boolean;
|
||||
pageSize?: number;
|
||||
sessionId?: string;
|
||||
threadId?: string;
|
||||
timingRequestId?: string;
|
||||
timingStartedAt?: number;
|
||||
topicFilter?: {
|
||||
excludeStatuses?: string[];
|
||||
excludeTriggers?: string[];
|
||||
includeTriggers?: string[];
|
||||
};
|
||||
topicId?: string;
|
||||
topicPageSize?: number;
|
||||
}
|
||||
|
||||
export class AiChatService {
|
||||
private userId: string;
|
||||
private messageModel: MessageModel;
|
||||
@@ -18,32 +42,48 @@ export class AiChatService {
|
||||
this.fileService = new FileService(serverDB, userId);
|
||||
}
|
||||
|
||||
async getMessagesAndTopics(params: {
|
||||
agentId?: string;
|
||||
current?: number;
|
||||
groupId?: string;
|
||||
includeTopic?: boolean;
|
||||
pageSize?: number;
|
||||
sessionId?: string;
|
||||
threadId?: string;
|
||||
topicFilter?: {
|
||||
excludeStatuses?: string[];
|
||||
excludeTriggers?: string[];
|
||||
includeTriggers?: string[];
|
||||
};
|
||||
topicId?: string;
|
||||
}) {
|
||||
const { topicFilter, ...messageParams } = params;
|
||||
async getMessagesAndTopics(params: GetMessagesAndTopicsParams) {
|
||||
const { topicFilter, topicPageSize, timingRequestId, timingStartedAt, ...messageParams } =
|
||||
params;
|
||||
const timingContext = toTimingContext({ timingRequestId, timingStartedAt });
|
||||
const messageTiming = createPrefixedTimingContext(
|
||||
timingContext,
|
||||
'lambda.aiChat.messagesAndTopics.messageModel.query',
|
||||
);
|
||||
const topicTiming = createPrefixedTimingContext(
|
||||
timingContext,
|
||||
'lambda.aiChat.messagesAndTopics.topicModel.query',
|
||||
);
|
||||
const messageQueryPromise = runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.messagesAndTopics.messageModel.query',
|
||||
() =>
|
||||
this.messageModel.query(messageParams, {
|
||||
postProcessUrl: (path) => this.fileService.getFullFileUrl(path),
|
||||
...(messageTiming ? { timing: messageTiming } : {}),
|
||||
}),
|
||||
{
|
||||
hasAgentId: !!params.agentId,
|
||||
hasThreadId: !!params.threadId,
|
||||
hasTopicId: !!params.topicId,
|
||||
},
|
||||
);
|
||||
const [messages, topics] = await Promise.all([
|
||||
this.messageModel.query(messageParams, {
|
||||
postProcessUrl: (path) => this.fileService.getFullFileUrl(path),
|
||||
}),
|
||||
messageQueryPromise,
|
||||
params.includeTopic
|
||||
? this.topicModel.query({
|
||||
agentId: params.agentId,
|
||||
groupId: params.groupId,
|
||||
...topicFilter,
|
||||
})
|
||||
? runTimedStage(
|
||||
timingContext,
|
||||
'lambda.aiChat.messagesAndTopics.topicModel.query',
|
||||
() =>
|
||||
this.topicModel.query({
|
||||
agentId: params.agentId,
|
||||
groupId: params.groupId,
|
||||
pageSize: topicPageSize,
|
||||
...(topicTiming ? { timing: topicTiming } : {}),
|
||||
...topicFilter,
|
||||
}),
|
||||
{ hasAgentId: !!params.agentId, hasGroupId: !!params.groupId },
|
||||
)
|
||||
: undefined,
|
||||
]);
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
type UIChatMessage,
|
||||
type UpdateMessageParams,
|
||||
} from '@lobechat/types';
|
||||
import { createTimingHelpers, getDurationMs } from '@lobechat/utils';
|
||||
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
|
||||
@@ -15,9 +16,26 @@ interface QueryOptions {
|
||||
groupId?: string | null;
|
||||
sessionId?: string | null;
|
||||
threadId?: string | null;
|
||||
timingRequestId?: string;
|
||||
timingStartedAt?: number;
|
||||
topicId?: string | null;
|
||||
}
|
||||
|
||||
const { createPrefixedTimingContext, logTiming, toTimingContext } = createTimingHelpers(
|
||||
'lobe-server:chat:lobehub:timing',
|
||||
);
|
||||
|
||||
const logMessageTiming = (
|
||||
options: QueryOptions | undefined,
|
||||
event: string,
|
||||
metadata?: Record<string, unknown>,
|
||||
) => {
|
||||
logTiming(toTimingContext(options), event, metadata);
|
||||
};
|
||||
|
||||
const createModelTiming = (options: QueryOptions | undefined, prefix: string) =>
|
||||
createPrefixedTimingContext(toTimingContext(options), prefix);
|
||||
|
||||
interface CreateMessageResult {
|
||||
id: string;
|
||||
messages: any[];
|
||||
@@ -70,15 +88,25 @@ export class MessageService {
|
||||
options.sessionId === undefined &&
|
||||
options.topicId === undefined)
|
||||
) {
|
||||
logMessageTiming(options, 'lambda.message.update.queryMessages:skipped');
|
||||
return { success: true };
|
||||
}
|
||||
|
||||
const { agentId, sessionId, topicId, groupId, threadId } = options;
|
||||
|
||||
const queryStartedAt = Date.now();
|
||||
const modelTiming = createModelTiming(options, 'lambda.message.update.queryMessages');
|
||||
const messages = await this.messageModel.query(
|
||||
{ agentId, groupId, sessionId, threadId, topicId },
|
||||
this.getQueryOptions(),
|
||||
{
|
||||
...this.getQueryOptions(),
|
||||
...(modelTiming ? { timing: modelTiming } : {}),
|
||||
},
|
||||
);
|
||||
logMessageTiming(options, 'lambda.message.update.queryMessages:done', {
|
||||
messageCount: messages.length,
|
||||
stageMs: getDurationMs(queryStartedAt),
|
||||
});
|
||||
|
||||
return { messages, success: true };
|
||||
}
|
||||
@@ -188,7 +216,18 @@ export class MessageService {
|
||||
value: UpdateMessageParams,
|
||||
options: QueryOptions,
|
||||
): Promise<{ messages?: UIChatMessage[]; success: boolean }> {
|
||||
await this.messageModel.update(id, value as any);
|
||||
const updateStartedAt = Date.now();
|
||||
const modelTiming = createModelTiming(options, 'lambda.message.update.dbUpdate');
|
||||
if (modelTiming) {
|
||||
await this.messageModel.update(id, value as any, modelTiming);
|
||||
} else {
|
||||
await this.messageModel.update(id, value as any);
|
||||
}
|
||||
logMessageTiming(options, 'lambda.message.update.dbUpdate:done', {
|
||||
stageMs: getDurationMs(updateStartedAt),
|
||||
valueKeys: Object.keys(value ?? {}),
|
||||
});
|
||||
|
||||
return this.queryWithSuccess(options);
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import { chatService } from '@/services/chat';
|
||||
import { messageService } from '@/services/message';
|
||||
import * as agentGroupStore from '@/store/agentGroup';
|
||||
import { messageMapKey } from '@/store/chat/utils/messageMapKey';
|
||||
import { topicMapKey } from '@/store/chat/utils/topicMapKey';
|
||||
import { getSessionStoreState } from '@/store/session';
|
||||
import * as toolStoreModule from '@/store/tool';
|
||||
|
||||
@@ -1622,7 +1623,6 @@ describe('ConversationLifecycle actions', () => {
|
||||
createMockMessage({ id: 'new-user-msg', role: 'user', topicId: newTopicId }),
|
||||
createMockMessage({ id: 'new-assistant-msg', role: 'assistant', topicId: newTopicId }),
|
||||
],
|
||||
topics: { items: [{ id: newTopicId, title: 'New Topic' }], total: 1 },
|
||||
topicId: newTopicId,
|
||||
isCreateNewTopic: true,
|
||||
assistantMessageId: 'new-assistant-msg',
|
||||
@@ -1648,6 +1648,12 @@ describe('ConversationLifecycle actions', () => {
|
||||
// After new topic creation, the _new key should be cleared
|
||||
const messagesInNewKey = useChatStore.getState().messagesMap[newKey];
|
||||
expect(messagesInNewKey ?? []).toHaveLength(0);
|
||||
|
||||
const newTopicKey = messageMapKey({ agentId, topicId: newTopicId });
|
||||
expect(useChatStore.getState().messagesMap[newTopicKey]).toHaveLength(2);
|
||||
expect(useChatStore.getState().topicDataMap[topicMapKey({ agentId })]?.items[0]).toEqual(
|
||||
expect.objectContaining({ id: newTopicId }),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -523,6 +523,7 @@ export class ConversationLifecycleActionImpl {
|
||||
operationContext.agentId,
|
||||
operationContext.groupId ?? undefined,
|
||||
),
|
||||
topicPageSize: systemStatusSelectors.topicPageSize(useGlobalStore.getState()),
|
||||
topicId: operationContext.topicId ?? undefined,
|
||||
},
|
||||
abortController,
|
||||
@@ -712,6 +713,7 @@ export class ConversationLifecycleActionImpl {
|
||||
const toolContext = formatSelectedToolsContext(dedupedTools);
|
||||
const contextSuffix = [skillContext, toolContext].filter(Boolean).join('\n');
|
||||
const persistedContent = contextSuffix ? `${message}\n\n${contextSuffix}` : message;
|
||||
const newTopicTitle = message.slice(0, 80) || t('defaultTitle', { ns: 'topic' });
|
||||
|
||||
data = await aiChatService.sendMessageInServer(
|
||||
{
|
||||
@@ -730,6 +732,7 @@ export class ConversationLifecycleActionImpl {
|
||||
operationContext.agentId,
|
||||
operationContext.groupId ?? undefined,
|
||||
),
|
||||
topicPageSize: systemStatusSelectors.topicPageSize(useGlobalStore.getState()),
|
||||
threadId: operationContext.threadId ?? undefined,
|
||||
// Support creating new thread along with message
|
||||
newThread: newThread
|
||||
@@ -741,7 +744,7 @@ export class ConversationLifecycleActionImpl {
|
||||
newTopic: !topicId
|
||||
? {
|
||||
topicMessageIds: forceNewTopicFromExisting ? [] : messages.map((m) => m.id),
|
||||
title: message.slice(0, 80) || t('defaultTitle', { ns: 'topic' }),
|
||||
title: newTopicTitle,
|
||||
}
|
||||
: undefined,
|
||||
agentId: operationContext.agentId,
|
||||
@@ -757,7 +760,7 @@ export class ConversationLifecycleActionImpl {
|
||||
abortController,
|
||||
);
|
||||
// Use created topicId/threadId if available, otherwise use original from context
|
||||
let finalTopicId = operationContext.topicId;
|
||||
let finalTopicId = data.topicId ?? operationContext.topicId;
|
||||
const finalThreadId = data.createdThreadId ?? operationContext.threadId;
|
||||
|
||||
// refresh the total data
|
||||
@@ -780,6 +783,18 @@ export class ConversationLifecycleActionImpl {
|
||||
// Record the created topicId in metadata (not context)
|
||||
this.#get().updateOperationMetadata(operationId, { createdTopicId: data.topicId });
|
||||
}
|
||||
} else if (data.isCreateNewTopic && data.topicId && !context.isolatedTopic) {
|
||||
this.#get().internal_dispatchTopic(
|
||||
{
|
||||
type: 'addTopic',
|
||||
value: {
|
||||
id: data.topicId,
|
||||
title: newTopicTitle,
|
||||
},
|
||||
},
|
||||
'sendMessage/createTopicPlaceholder',
|
||||
);
|
||||
this.#get().updateOperationMetadata(operationId, { createdTopicId: data.topicId });
|
||||
} else if (operationContext.topicId) {
|
||||
// Optimistically update topic's updatedAt so sidebar re-groups immediately
|
||||
this.#get().internal_dispatchTopic({
|
||||
|
||||
Reference in New Issue
Block a user