mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-18 13:25:45 +00:00
✨ feat(userMemories): added user memory request, implemented workflow trigger (#11749)
This commit is contained in:
@@ -62,5 +62,23 @@
|
||||
"tab.preferences": "Preferences",
|
||||
"tab.search": "Search",
|
||||
"viewMode.masonry": "Masonry",
|
||||
"viewMode.timeline": "Timeline"
|
||||
"viewMode.timeline": "Timeline",
|
||||
"analysis.action.button": "Request Memory",
|
||||
"analysis.modal.cancel": "Cancel",
|
||||
"analysis.modal.helper": "Optional time window. By default, only conversations not yet analyzed will be processed. Choose a window to limit analysis.",
|
||||
"analysis.modal.rangePlaceholder": "No time range selected, all conversations will be analyzed",
|
||||
"analysis.modal.rangeSelected": "Analyze conversations from {{start}} to {{end}}",
|
||||
"analysis.modal.submit": "Request Analysis",
|
||||
"analysis.modal.title": "Analyze conversations to generate memories",
|
||||
"analysis.range.all": "All conversations",
|
||||
"analysis.range.end": "Today",
|
||||
"analysis.range.start": "Beginning",
|
||||
"analysis.status.errorTitle": "Analysis request failed",
|
||||
"analysis.status.progress": "Processed {{completed}} / {{total}} topics",
|
||||
"analysis.status.progressUnknown": "Processed {{completed}} topics",
|
||||
"analysis.status.tip": "We're processing your conversations to generate memories. This may take a few minutes - thanks for your patience.",
|
||||
"analysis.status.title": "Memory analysis in progress",
|
||||
"analysis.toast.deduped": "A request is already running - please wait...",
|
||||
"analysis.toast.failed": "Analysis request failed, please retry.",
|
||||
"analysis.toast.started": "Memory analysis started. It will update automatically when done."
|
||||
}
|
||||
|
||||
@@ -62,5 +62,23 @@
|
||||
"tab.preferences": "偏好",
|
||||
"tab.search": "搜索",
|
||||
"viewMode.masonry": "瀑布流",
|
||||
"viewMode.timeline": "时间线"
|
||||
"viewMode.timeline": "时间线",
|
||||
"analysis.action.button": "请求记忆",
|
||||
"analysis.modal.cancel": "取消",
|
||||
"analysis.modal.helper": "可选的时间窗口,默认会对尚未分析的对话进行分析,请选择用于分析的时间窗口。",
|
||||
"analysis.modal.rangePlaceholder": "未选择时间范围,将分析所有对话",
|
||||
"analysis.modal.rangeSelected": "分析对话 {{start}} 至 {{end}}",
|
||||
"analysis.modal.submit": "请求分析记忆",
|
||||
"analysis.modal.title": "分析对话以生成记忆",
|
||||
"analysis.range.all": "全部对话",
|
||||
"analysis.range.end": "今天",
|
||||
"analysis.range.start": "最初",
|
||||
"analysis.status.errorTitle": "请求分析失败",
|
||||
"analysis.status.progress": "已处理 {{completed}} / {{total}} 个话题",
|
||||
"analysis.status.progressUnknown": "已处理 {{completed}} 个话题",
|
||||
"analysis.status.tip": "正在处理你的对话来生成记忆,可能需要几分钟,请耐心等待。",
|
||||
"analysis.status.title": "记忆分析中",
|
||||
"analysis.toast.deduped": "已有请求仍在处理中,请稍等…",
|
||||
"analysis.toast.failed": "请求分析失败,请重试。",
|
||||
"analysis.toast.started": "记忆分析已开始,完成后会自动更新。"
|
||||
}
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
// @vitest-environment node
|
||||
import { ASYNC_TASK_TIMEOUT } from '@lobechat/business-config/server';
|
||||
import { AsyncTaskStatus, AsyncTaskType } from '@lobechat/types';
|
||||
import {
|
||||
AsyncTaskStatus,
|
||||
AsyncTaskType,
|
||||
type UserMemoryExtractionMetadata,
|
||||
} from '@lobechat/types';
|
||||
import { eq } from 'drizzle-orm';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
@@ -121,6 +125,40 @@ describe('AsyncTaskModel', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('incrementUserMemoryExtractionProgress', () => {
|
||||
it('should increment completedTopics and set status to success when reaching total', async () => {
|
||||
const { id } = await serverDB
|
||||
.insert(asyncTasks)
|
||||
.values({
|
||||
metadata: {
|
||||
progress: {
|
||||
completedTopics: 0,
|
||||
totalTopics: 2,
|
||||
},
|
||||
source: 'chat_topic',
|
||||
},
|
||||
status: AsyncTaskStatus.Pending,
|
||||
type: AsyncTaskType.UserMemoryExtractionWithChatTopic,
|
||||
userId,
|
||||
})
|
||||
.returning()
|
||||
.then((res) => res[0]);
|
||||
|
||||
await asyncTaskModel.incrementUserMemoryExtractionProgress(id);
|
||||
let task = await serverDB.query.asyncTasks.findFirst({ where: eq(asyncTasks.id, id) });
|
||||
const firstMetadata = task?.metadata as UserMemoryExtractionMetadata | undefined;
|
||||
expect(firstMetadata?.progress?.completedTopics).toBe(1);
|
||||
expect(firstMetadata?.progress?.totalTopics).toBe(2);
|
||||
expect(task?.status).toBe(AsyncTaskStatus.Processing);
|
||||
|
||||
await asyncTaskModel.incrementUserMemoryExtractionProgress(id);
|
||||
task = await serverDB.query.asyncTasks.findFirst({ where: eq(asyncTasks.id, id) });
|
||||
const secondMetadata = task?.metadata as UserMemoryExtractionMetadata | undefined;
|
||||
expect(secondMetadata?.progress?.completedTopics).toBe(2);
|
||||
expect(task?.status).toBe(AsyncTaskStatus.Success);
|
||||
});
|
||||
});
|
||||
|
||||
describe('checkTimeoutTasks', () => {
|
||||
it('should mark tasks as error if they timeout', async () => {
|
||||
// Create a task with old timestamp (beyond timeout)
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
|
||||
|
||||
import { topics, users } from '../../../schemas';
|
||||
import { LobeChatDatabase } from '../../../type';
|
||||
import { TopicModel } from '../../topic';
|
||||
import { getTestDB } from '../../../core/getTestDB';
|
||||
|
||||
const userId = 'topic-memory-extractor-user';
|
||||
const serverDB: LobeChatDatabase = await getTestDB();
|
||||
const topicModel = new TopicModel(serverDB, userId);
|
||||
|
||||
describe('TopicModel - countTopicsForMemoryExtractor', () => {
|
||||
beforeEach(async () => {
|
||||
await serverDB.delete(users);
|
||||
await serverDB.insert(users).values({ id: userId });
|
||||
});
|
||||
|
||||
afterEach(async () => {
|
||||
await serverDB.delete(users);
|
||||
});
|
||||
|
||||
it('counts only unextracted topics when ignoreExtracted is false (default behavior)', async () => {
|
||||
await serverDB.insert(topics).values([
|
||||
{
|
||||
id: 't1',
|
||||
createdAt: new Date('2023-01-01'),
|
||||
metadata: {},
|
||||
userId,
|
||||
},
|
||||
{
|
||||
id: 't2',
|
||||
createdAt: new Date('2023-02-01'),
|
||||
metadata: {
|
||||
userMemoryExtractStatus: 'completed',
|
||||
},
|
||||
userId,
|
||||
},
|
||||
{
|
||||
id: 't3',
|
||||
createdAt: new Date('2023-03-01'),
|
||||
metadata: {},
|
||||
userId,
|
||||
},
|
||||
]);
|
||||
|
||||
const total = await topicModel.countTopicsForMemoryExtractor({
|
||||
ignoreExtracted: false,
|
||||
});
|
||||
|
||||
expect(total).toBe(2);
|
||||
});
|
||||
|
||||
it('includes extracted topics when ignoreExtracted is true', async () => {
|
||||
await serverDB.insert(topics).values([
|
||||
{
|
||||
id: 't1',
|
||||
createdAt: new Date('2023-01-01'),
|
||||
metadata: {},
|
||||
userId,
|
||||
},
|
||||
{
|
||||
id: 't2',
|
||||
createdAt: new Date('2023-02-01'),
|
||||
metadata: {
|
||||
userMemoryExtractStatus: 'completed',
|
||||
},
|
||||
userId,
|
||||
},
|
||||
]);
|
||||
|
||||
const total = await topicModel.countTopicsForMemoryExtractor({
|
||||
ignoreExtracted: true,
|
||||
});
|
||||
|
||||
expect(total).toBe(2);
|
||||
});
|
||||
});
|
||||
@@ -4,8 +4,9 @@ import {
|
||||
AsyncTaskErrorType,
|
||||
AsyncTaskStatus,
|
||||
AsyncTaskType,
|
||||
type UserMemoryExtractionMetadata,
|
||||
} from '@lobechat/types';
|
||||
import { and, eq, inArray, lt, or } from 'drizzle-orm';
|
||||
import { and, eq, inArray, lt, or, sql } from 'drizzle-orm';
|
||||
|
||||
import { AsyncTaskSelectItem, NewAsyncTaskItem, asyncTasks } from '../schemas';
|
||||
import { LobeChatDatabase } from '../type';
|
||||
@@ -19,7 +20,9 @@ export class AsyncTaskModel {
|
||||
this.db = db;
|
||||
}
|
||||
|
||||
create = async (params: Pick<NewAsyncTaskItem, 'type' | 'status'>): Promise<string> => {
|
||||
create = async (
|
||||
params: Pick<NewAsyncTaskItem, 'type' | 'status' | 'metadata' | 'parentId'>,
|
||||
): Promise<string> => {
|
||||
const data = await this.db
|
||||
.insert(asyncTasks)
|
||||
.values({ ...params, userId: this.userId })
|
||||
@@ -45,6 +48,47 @@ export class AsyncTaskModel {
|
||||
.where(and(eq(asyncTasks.id, taskId)));
|
||||
}
|
||||
|
||||
findActiveByType = async (type: AsyncTaskType) => {
|
||||
return this.db.query.asyncTasks.findFirst({
|
||||
where: and(
|
||||
eq(asyncTasks.userId, this.userId),
|
||||
eq(asyncTasks.type, type),
|
||||
inArray(asyncTasks.status, [AsyncTaskStatus.Pending, AsyncTaskStatus.Processing]),
|
||||
),
|
||||
});
|
||||
};
|
||||
|
||||
incrementUserMemoryExtractionProgress = async (taskId: string) => {
|
||||
const completedExpr = sql<number>`COALESCE(((${asyncTasks.metadata}) -> 'progress' ->> 'completedTopics')::int, 0) + 1`;
|
||||
const totalExpr = sql<number | null>`((${asyncTasks.metadata}) -> 'progress' ->> 'totalTopics')::int`;
|
||||
|
||||
const result = await this.db
|
||||
.update(asyncTasks)
|
||||
.set({
|
||||
metadata: sql`jsonb_set(
|
||||
jsonb_set(
|
||||
${asyncTasks.metadata},
|
||||
'{progress,completedTopics}',
|
||||
to_jsonb(${completedExpr}),
|
||||
true
|
||||
),
|
||||
'{progress,totalTopics}',
|
||||
COALESCE((${asyncTasks.metadata}) -> 'progress' -> 'totalTopics', 'null'::jsonb),
|
||||
true
|
||||
)`,
|
||||
status: sql`CASE
|
||||
WHEN ${totalExpr} IS NOT NULL AND ${completedExpr} >= ${totalExpr}
|
||||
THEN ${AsyncTaskStatus.Success}
|
||||
ELSE ${AsyncTaskStatus.Processing}
|
||||
END`,
|
||||
updatedAt: new Date(),
|
||||
})
|
||||
.where(and(eq(asyncTasks.id, taskId), eq(asyncTasks.userId, this.userId)))
|
||||
.returning({ metadata: asyncTasks.metadata, status: asyncTasks.status });
|
||||
|
||||
return result[0];
|
||||
};
|
||||
|
||||
findByIds = async (taskIds: string[], type: AsyncTaskType): Promise<AsyncTaskSelectItem[]> => {
|
||||
let chunkTasks: AsyncTaskSelectItem[] = [];
|
||||
|
||||
@@ -95,3 +139,14 @@ export class AsyncTaskModel {
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
export const initUserMemoryExtractionMetadata = (
|
||||
metadata?: UserMemoryExtractionMetadata,
|
||||
): UserMemoryExtractionMetadata => ({
|
||||
progress: {
|
||||
completedTopics: metadata?.progress?.completedTopics ?? 0,
|
||||
totalTopics: metadata?.progress?.totalTopics ?? null,
|
||||
},
|
||||
range: metadata?.range,
|
||||
source: metadata?.source ?? 'chat_topic',
|
||||
});
|
||||
|
||||
@@ -562,8 +562,8 @@ export class TopicModel {
|
||||
clientId: null,
|
||||
id: newId,
|
||||
parentId: newParentId,
|
||||
topicId: duplicatedTopic.id,
|
||||
tools: newTools,
|
||||
topicId: duplicatedTopic.id,
|
||||
})
|
||||
.returning()) as DBMessageItem[];
|
||||
|
||||
@@ -576,8 +576,8 @@ export class TopicModel {
|
||||
|
||||
await tx.insert(messagePlugins).values({
|
||||
...plugin,
|
||||
id: newId,
|
||||
clientId: null,
|
||||
id: newId,
|
||||
toolCallId: newToolCallId,
|
||||
});
|
||||
}
|
||||
@@ -739,13 +739,38 @@ export class TopicModel {
|
||||
? undefined
|
||||
: or(
|
||||
isNull(topics.metadata),
|
||||
sql`(${topics.metadata}->'memory_user_memory_extract'->>'extract_status') IS DISTINCT FROM 'completed'`,
|
||||
sql`(${topics.metadata}->>'userMemoryExtractStatus') IS DISTINCT FROM 'completed'`,
|
||||
),
|
||||
cursorCondition,
|
||||
),
|
||||
});
|
||||
};
|
||||
|
||||
countTopicsForMemoryExtractor = async (options: {
|
||||
endDate?: Date;
|
||||
ignoreExtracted?: boolean;
|
||||
startDate?: Date;
|
||||
} = {}) => {
|
||||
const result = await this.db
|
||||
.select({ total: count(topics.id) })
|
||||
.from(topics)
|
||||
.where(
|
||||
and(
|
||||
eq(topics.userId, this.userId),
|
||||
options.startDate ? gte(topics.createdAt, options.startDate) : undefined,
|
||||
options.endDate ? lte(topics.createdAt, options.endDate) : undefined,
|
||||
options.ignoreExtracted
|
||||
? undefined
|
||||
: or(
|
||||
isNull(topics.metadata),
|
||||
sql`(${topics.metadata}->>'userMemoryExtractStatus') IS DISTINCT FROM 'completed'`,
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
return result[0]?.total ?? 0;
|
||||
};
|
||||
|
||||
/**
|
||||
* Get cron topics grouped by cronJob for a specific agent
|
||||
* Returns topics where trigger='cron' and metadata contains cronJobId
|
||||
|
||||
@@ -89,8 +89,8 @@ export interface BaseCreateUserMemoryParams {
|
||||
capturedAt?: Date;
|
||||
details: string;
|
||||
detailsEmbedding?: number[];
|
||||
memoryCategory: string;
|
||||
memoryLayer?: LayersEnum;
|
||||
memoryCategory?: string | null;
|
||||
memoryLayer: LayersEnum;
|
||||
memoryType: TypesEnum;
|
||||
summary: string;
|
||||
summaryEmbedding?: number[];
|
||||
|
||||
@@ -2,6 +2,7 @@ export enum AsyncTaskType {
|
||||
Chunking = 'chunk',
|
||||
Embedding = 'embedding',
|
||||
ImageGeneration = 'image_generation',
|
||||
UserMemoryExtractionWithChatTopic = 'user_memory_extraction:chat_topic',
|
||||
}
|
||||
|
||||
export enum AsyncTaskStatus {
|
||||
@@ -67,3 +68,17 @@ export interface FileParsingTask {
|
||||
embeddingStatus?: AsyncTaskStatus | null;
|
||||
finishEmbedding?: boolean;
|
||||
}
|
||||
|
||||
export interface UserMemoryExtractionProgress {
|
||||
completedTopics: number;
|
||||
totalTopics: number | null;
|
||||
}
|
||||
|
||||
export interface UserMemoryExtractionMetadata {
|
||||
progress: UserMemoryExtractionProgress;
|
||||
range?: {
|
||||
from?: string;
|
||||
to?: string;
|
||||
};
|
||||
source: 'chat_topic';
|
||||
}
|
||||
|
||||
@@ -23,13 +23,13 @@ const bodySchema = z.object({
|
||||
});
|
||||
|
||||
export const POST = async (req: Request) => {
|
||||
const { webhookHeaders, featureFlags } = parseMemoryExtractionConfig();
|
||||
const { webhook, featureFlags } = parseMemoryExtractionConfig();
|
||||
if (!featureFlags.enableBenchmarkLoCoMo) {
|
||||
return NextResponse.json({ error: 'Not found' }, { status: 404 });
|
||||
}
|
||||
|
||||
if (webhookHeaders && Object.keys(webhookHeaders).length > 0) {
|
||||
for (const [key, value] of Object.entries(webhookHeaders)) {
|
||||
if (webhook?.headers && Object.keys(webhook?.headers).length > 0) {
|
||||
for (const [key, value] of Object.entries(webhook?.headers)) {
|
||||
const headerValue = req.headers.get(key);
|
||||
if (headerValue !== value) {
|
||||
return NextResponse.json(
|
||||
|
||||
@@ -58,10 +58,10 @@ interface SessionExtractionResult {
|
||||
|
||||
export const POST = async (req: Request) => {
|
||||
try {
|
||||
const { webhookHeaders } = parseMemoryExtractionConfig();
|
||||
const { webhook } = parseMemoryExtractionConfig();
|
||||
|
||||
if (webhookHeaders && Object.keys(webhookHeaders).length > 0) {
|
||||
for (const [key, value] of Object.entries(webhookHeaders)) {
|
||||
if (webhook.headers && Object.keys(webhook.headers).length > 0) {
|
||||
for (const [key, value] of Object.entries(webhook.headers)) {
|
||||
const headerValue = req.headers.get(key);
|
||||
if (headerValue !== value) {
|
||||
return NextResponse.json(
|
||||
|
||||
@@ -10,10 +10,10 @@ import {
|
||||
} from '@/server/services/memory/userMemory/extract';
|
||||
|
||||
export const POST = async (req: Request) => {
|
||||
const { webhookHeaders, upstashWorkflowExtraHeaders } = parseMemoryExtractionConfig();
|
||||
const { webhook, upstashWorkflowExtraHeaders } = parseMemoryExtractionConfig();
|
||||
|
||||
if (webhookHeaders && Object.keys(webhookHeaders).length > 0) {
|
||||
for (const [key, value] of Object.entries(webhookHeaders)) {
|
||||
if (webhook.headers && Object.keys(webhook.headers).length > 0) {
|
||||
for (const [key, value] of Object.entries(webhook.headers)) {
|
||||
const headerValue = req.headers.get(key);
|
||||
if (headerValue !== value) {
|
||||
return NextResponse.json(
|
||||
|
||||
+4
@@ -82,6 +82,7 @@ export const { POST } = serve<MemoryExtractionPayloadInput>(
|
||||
`memory:user-memory:extract:users:${userId}:topics:${topicId}:cep:${index}`,
|
||||
() =>
|
||||
executor.extractTopic({
|
||||
asyncTaskId: payload.asyncTaskId,
|
||||
forceAll: payload.forceAll,
|
||||
forceTopics: payload.forceTopics,
|
||||
from: payload.from,
|
||||
@@ -90,6 +91,7 @@ export const { POST } = serve<MemoryExtractionPayloadInput>(
|
||||
to: payload.to,
|
||||
topicId,
|
||||
userId,
|
||||
userInitiated: false,
|
||||
}),
|
||||
),
|
||||
),
|
||||
@@ -101,6 +103,7 @@ export const { POST } = serve<MemoryExtractionPayloadInput>(
|
||||
`memory:user-memory:extract:users:${userId}:topics:${topicId}:identity:${index}`,
|
||||
() =>
|
||||
executor.extractTopic({
|
||||
asyncTaskId: payload.asyncTaskId,
|
||||
forceAll: payload.forceAll,
|
||||
forceTopics: payload.forceTopics,
|
||||
from: payload.from,
|
||||
@@ -109,6 +112,7 @@ export const { POST } = serve<MemoryExtractionPayloadInput>(
|
||||
to: payload.to,
|
||||
topicId,
|
||||
userId,
|
||||
userInitiated: payload.userInitiated,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import { SCROLL_PARENT_ID } from '@/app/[variants]/(main)/community/features/con
|
||||
import { withSuspense } from '@/components/withSuspense';
|
||||
import { useQuery } from '@/hooks/useQuery';
|
||||
import { useDiscoverStore } from '@/store/discover';
|
||||
import { AssistantCategory } from '@/types/discover';
|
||||
import { AssistantCategory, AssistantSorts } from '@/types/discover';
|
||||
|
||||
import CategoryMenu from '../../../../components/CategoryMenu';
|
||||
import { useCategory } from './useCategory';
|
||||
@@ -28,7 +28,11 @@ const Category = memo(() => {
|
||||
const genUrl = (key: AssistantCategory) =>
|
||||
qs.stringifyUrl(
|
||||
{
|
||||
query: { category: key === AssistantCategory.All ? null : key, q, source },
|
||||
query: {
|
||||
category: [AssistantCategory.All, AssistantCategory.Discover].includes(key) ? null : key,
|
||||
q,
|
||||
sort: key === AssistantCategory.Discover ? AssistantSorts.Recommended : null,
|
||||
},
|
||||
url: '/community/agent',
|
||||
},
|
||||
{ skipNull: true },
|
||||
|
||||
@@ -2,14 +2,15 @@ import { Flexbox } from '@lobehub/ui';
|
||||
// import { PencilLineIcon } from 'lucide-react';
|
||||
import { type FC } from 'react';
|
||||
|
||||
import { SCROLL_PARENT_ID } from '@/app/[variants]/(main)/memory/features/TimeLineView/useScrollParent';
|
||||
import Loading from '@/components/Loading/BrandTextLoading';
|
||||
import NavHeader from '@/features/NavHeader';
|
||||
import WideScreenContainer from '@/features/WideScreenContainer';
|
||||
import WideScreenButton from '@/features/WideScreenContainer/WideScreenButton';
|
||||
import { useUserMemoryStore } from '@/store/userMemory';
|
||||
|
||||
import MemoryEmpty from '../features/MemoryEmpty';
|
||||
import { SCROLL_PARENT_ID } from '../features/TimeLineView/useScrollParent';
|
||||
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
|
||||
import MemoryEmpty from '@/app/[variants]/(main)/memory/features/MemoryEmpty';
|
||||
import Persona from './features/Persona';
|
||||
import PersonaHeader from './features/Persona/PersonaHeader';
|
||||
import RoleTagCloud from './features/RoleTagCloud';
|
||||
@@ -23,7 +24,11 @@ const Home: FC = () => {
|
||||
if (isLoading) return <Loading debugId={'Home'} />;
|
||||
|
||||
if (!roles || roles.length === 0) {
|
||||
return <MemoryEmpty />;
|
||||
return (
|
||||
<MemoryEmpty>
|
||||
<MemoryAnalysis />
|
||||
</MemoryEmpty>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
@@ -39,6 +44,7 @@ const Home: FC = () => {
|
||||
zIndex: 1,
|
||||
}}
|
||||
/>
|
||||
<MemoryAnalysis />
|
||||
<Flexbox
|
||||
height={'100%'}
|
||||
id={SCROLL_PARENT_ID}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
|
||||
import MemoryEmpty from '@/app/[variants]/(main)/memory/features/MemoryEmpty';
|
||||
import { useQueryState } from '@/hooks/useQueryParam';
|
||||
import { useGlobalStore } from '@/store/global';
|
||||
@@ -31,7 +32,14 @@ const ContextsList = memo<ContextsListProps>(({ isLoading, searchValue, viewMode
|
||||
const isEmpty = contexts.length === 0;
|
||||
|
||||
if (isEmpty) {
|
||||
return <MemoryEmpty search={Boolean(searchValue)} title={t('context.empty')} />;
|
||||
return (
|
||||
<MemoryEmpty
|
||||
search={Boolean(searchValue)}
|
||||
title={t('context.empty')}
|
||||
>
|
||||
<MemoryAnalysis />
|
||||
</MemoryEmpty>
|
||||
);
|
||||
}
|
||||
|
||||
return viewMode === 'timeline' ? (
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
|
||||
import MemoryEmpty from '@/app/[variants]/(main)/memory/features/MemoryEmpty';
|
||||
import { useQueryState } from '@/hooks/useQueryParam';
|
||||
import { useGlobalStore } from '@/store/global';
|
||||
@@ -30,7 +31,14 @@ const ExperiencesList = memo<ExperiencesListProps>(({ isLoading, searchValue, vi
|
||||
const isEmpty = experiences.length === 0;
|
||||
|
||||
if (isEmpty) {
|
||||
return <MemoryEmpty search={Boolean(searchValue)} title={t('experience.empty')} />;
|
||||
return (
|
||||
<MemoryEmpty
|
||||
search={Boolean(searchValue)}
|
||||
title={t('experience.empty')}
|
||||
>
|
||||
<MemoryAnalysis />
|
||||
</MemoryEmpty>
|
||||
);
|
||||
}
|
||||
|
||||
return viewMode === 'timeline' ? (
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
'use client';
|
||||
|
||||
import { memo, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import AnalysisTrigger from './AnalysisTrigger';
|
||||
|
||||
const AnalysisAction = memo(() => {
|
||||
const { t } = useTranslation('memory');
|
||||
const [range, setRange] = useState<[Date | null, Date | null]>([null, null]);
|
||||
|
||||
const footerNote = useMemo(
|
||||
() =>
|
||||
range[0] || range[1]
|
||||
? t('analysis.modal.rangeSelected', {
|
||||
end: range[1]?.toISOString().slice(0, 10)?.replaceAll('-', '/') ||
|
||||
t('analysis.range.end'),
|
||||
start:
|
||||
range[0]?.toISOString().slice(0, 10)?.replaceAll('-', '/') ||
|
||||
t('analysis.range.start'),
|
||||
})
|
||||
: t('analysis.modal.rangePlaceholder'),
|
||||
[range, t],
|
||||
);
|
||||
|
||||
return <AnalysisTrigger footerNote={footerNote} onRangeChange={setRange} range={range} />;
|
||||
});
|
||||
|
||||
AnalysisAction.displayName = 'AnalysisAction';
|
||||
|
||||
export default AnalysisAction;
|
||||
@@ -0,0 +1,76 @@
|
||||
'use client';
|
||||
|
||||
import { Button, Flexbox, Icon, Text } from '@lobehub/ui';
|
||||
import { App } from 'antd';
|
||||
import { CalendarClockIcon } from 'lucide-react';
|
||||
import { memo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { useMemoryAnalysisAsyncTask } from '@/app/[variants]/(main)/memory/features/MemoryAnalysis/useTask';
|
||||
import { memoryExtractionService } from '@/services/userMemory/extraction';
|
||||
|
||||
import DateRangeModal from './DateRangeModal';
|
||||
|
||||
interface Props {
|
||||
footerNote: string;
|
||||
onRangeChange: (range: [Date | null, Date | null]) => void;
|
||||
range: [Date | null, Date | null];
|
||||
}
|
||||
|
||||
const AnalysisTrigger = memo<Props>(({ footerNote, range, onRangeChange }) => {
|
||||
const { t } = useTranslation('memory');
|
||||
const { message } = App.useApp();
|
||||
const { isValidating, refresh } = useMemoryAnalysisAsyncTask();
|
||||
const [open, setOpen] = useState(false);
|
||||
const [submitting, setSubmitting] = useState(false);
|
||||
|
||||
const handleSubmit = async () => {
|
||||
setSubmitting(true);
|
||||
try {
|
||||
const [from, to] = range;
|
||||
const result = await memoryExtractionService.requestFromChatTopics({
|
||||
fromDate: from ?? undefined,
|
||||
toDate: to ?? undefined,
|
||||
});
|
||||
|
||||
await refresh();
|
||||
message.success(result.deduped ? t('analysis.toast.deduped') : t('analysis.toast.started'));
|
||||
|
||||
setOpen(false);
|
||||
} catch (error) {
|
||||
console.error(error);
|
||||
message.error(t('analysis.toast.failed'));
|
||||
} finally {
|
||||
setSubmitting(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Button
|
||||
icon={<Icon icon={CalendarClockIcon} />}
|
||||
loading={submitting || isValidating}
|
||||
onClick={() => setOpen(true)}
|
||||
size={'large'}
|
||||
type={'primary'}
|
||||
style={{ maxWidth: 300 }}
|
||||
>
|
||||
{t('analysis.action.button')}
|
||||
</Button>
|
||||
|
||||
<DateRangeModal
|
||||
footerNote={footerNote}
|
||||
open={open}
|
||||
onCancel={() => setOpen(false)}
|
||||
onChange={onRangeChange}
|
||||
onSubmit={handleSubmit}
|
||||
range={range}
|
||||
submitting={submitting}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
AnalysisTrigger.displayName = 'AnalysisTrigger';
|
||||
|
||||
export default AnalysisTrigger;
|
||||
@@ -0,0 +1,68 @@
|
||||
'use client';
|
||||
|
||||
import { Flexbox, Text } from '@lobehub/ui';
|
||||
import { DatePicker, Modal } from 'antd';
|
||||
import type { RangePickerProps } from 'antd/es/date-picker';
|
||||
import dayjs, { type Dayjs } from 'dayjs';
|
||||
import { memo, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
interface Props {
|
||||
footerNote: string;
|
||||
onCancel: () => void;
|
||||
onChange: (range: [Date | null, Date | null]) => void;
|
||||
onSubmit: () => void;
|
||||
open: boolean;
|
||||
range: [Date | null, Date | null];
|
||||
submitting: boolean;
|
||||
}
|
||||
|
||||
const DateRangeModal = memo<Props>(
|
||||
({ footerNote, onCancel, onChange, onSubmit, open, range, submitting }) => {
|
||||
const { t } = useTranslation('memory');
|
||||
|
||||
const disabledDate = useCallback<NonNullable<RangePickerProps['disabledDate']>>(
|
||||
(current) => current.isAfter(dayjs(), 'day'),
|
||||
[],
|
||||
);
|
||||
|
||||
return (
|
||||
<Modal
|
||||
cancelText={t('analysis.modal.cancel')}
|
||||
okButtonProps={{ loading: submitting }}
|
||||
okText={t('analysis.modal.submit')}
|
||||
onCancel={onCancel}
|
||||
onOk={onSubmit}
|
||||
open={open}
|
||||
title={t('analysis.modal.title')}
|
||||
>
|
||||
<Flexbox gap={12}>
|
||||
<Text type={'secondary'}>{t('analysis.modal.helper')}</Text>
|
||||
<DatePicker.RangePicker
|
||||
allowClear
|
||||
disabledDate={disabledDate}
|
||||
format={'YYYY/MM/DD'}
|
||||
onChange={(values) =>
|
||||
onChange([
|
||||
values?.[0]?.toDate() ?? null,
|
||||
values?.[1]?.toDate() ?? null,
|
||||
])
|
||||
}
|
||||
style={{ width: '100%' }}
|
||||
value={[
|
||||
range[0] ? dayjs(range[0]) : null,
|
||||
range[1] ? dayjs(range[1]) : null,
|
||||
]}
|
||||
/>
|
||||
<Text fontSize={12} type={'secondary'}>
|
||||
{footerNote}
|
||||
</Text>
|
||||
</Flexbox>
|
||||
</Modal>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
DateRangeModal.displayName = 'DateRangeModal';
|
||||
|
||||
export default DateRangeModal;
|
||||
@@ -0,0 +1,86 @@
|
||||
'use client';
|
||||
|
||||
import { Alert, Flexbox, Icon, Text } from '@lobehub/ui';
|
||||
import { Progress } from 'antd';
|
||||
import { Loader2Icon, TriangleAlertIcon } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { AsyncTaskStatus } from '@lobechat/types';
|
||||
|
||||
import type { MemoryExtractionTask } from '@/services/userMemory/extraction';
|
||||
|
||||
import { useMemoryAnalysisAsyncTask } from './useTask';
|
||||
|
||||
interface StatusProps {
|
||||
task?: MemoryExtractionTask | null;
|
||||
}
|
||||
|
||||
export const MemoryAnalysisStatus = memo<StatusProps>(({ task }) => {
|
||||
const { t } = useTranslation('memory');
|
||||
const data = task;
|
||||
|
||||
const status = data?.status;
|
||||
const isRunning =
|
||||
status === AsyncTaskStatus.Pending || status === AsyncTaskStatus.Processing;
|
||||
const isError = status === AsyncTaskStatus.Error;
|
||||
|
||||
if (!data || (!isRunning && !isError)) return null;
|
||||
|
||||
const { progress } = data.metadata;
|
||||
const percent =
|
||||
progress.totalTopics && progress.totalTopics > 0
|
||||
? Math.min(100, Math.round((progress.completedTopics / progress.totalTopics) * 100))
|
||||
: undefined;
|
||||
|
||||
const progressText = progress.totalTopics
|
||||
? t('analysis.status.progress', {
|
||||
completed: progress.completedTopics,
|
||||
total: progress.totalTopics,
|
||||
})
|
||||
: t('analysis.status.progressUnknown', { completed: progress.completedTopics });
|
||||
|
||||
const body = data.error?.body;
|
||||
const errorText =
|
||||
typeof body === 'string'
|
||||
? body
|
||||
: body && typeof body === 'object' && 'detail' in body && typeof body.detail === 'string'
|
||||
? body.detail
|
||||
: data.error?.name ?? t('analysis.status.errorTitle');
|
||||
|
||||
return (
|
||||
<Alert
|
||||
description={
|
||||
<Flexbox gap={12}>
|
||||
<Flexbox align="center" gap={12} horizontal wrap="wrap">
|
||||
<Progress
|
||||
percent={percent ?? 30}
|
||||
showInfo={Boolean(percent)}
|
||||
status={isError ? 'exception' : 'active'}
|
||||
style={{ flex: 1, minWidth: 220 }}
|
||||
/>
|
||||
<Text fontSize={13} type={isError ? 'danger' : 'secondary'}>
|
||||
{isError ? errorText ?? t('analysis.status.errorTitle') : progressText}
|
||||
</Text>
|
||||
</Flexbox>
|
||||
</Flexbox>
|
||||
}
|
||||
icon={<Icon icon={isError ? TriangleAlertIcon : Loader2Icon} spin={isRunning && !isError} />}
|
||||
title={isError ? t('analysis.status.errorTitle') : t('analysis.status.title')}
|
||||
type={isError ? 'error' : 'info'}
|
||||
variant={'borderless'}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
MemoryAnalysisStatus.displayName = 'MemoryAnalysisStatus';
|
||||
|
||||
const Status = memo(() => {
|
||||
const { data } = useMemoryAnalysisAsyncTask();
|
||||
|
||||
return <MemoryAnalysisStatus task={data} />;
|
||||
});
|
||||
|
||||
Status.displayName = 'MemoryAnalysisStatusWithData';
|
||||
|
||||
export default Status;
|
||||
@@ -0,0 +1,40 @@
|
||||
'use client';
|
||||
|
||||
import { Flexbox } from '@lobehub/ui';
|
||||
import { AsyncTaskStatus } from '@lobechat/types';
|
||||
import { memo, useMemo } from 'react';
|
||||
|
||||
import AnalysisAction from './Action';
|
||||
import { MemoryAnalysisStatus } from './Status';
|
||||
import { useMemoryAnalysisAsyncTask } from './useTask';
|
||||
|
||||
const MemoryAnalysis = memo(() => {
|
||||
const { data, isValidating } = useMemoryAnalysisAsyncTask();
|
||||
|
||||
const { showAction, showStatus } = useMemo(() => {
|
||||
const status = data?.status;
|
||||
const isRunning =
|
||||
status === AsyncTaskStatus.Pending || status === AsyncTaskStatus.Processing;
|
||||
const isError = status === AsyncTaskStatus.Error;
|
||||
|
||||
console.log(isRunning, isValidating, isError, data);
|
||||
|
||||
return {
|
||||
showAction: (!isRunning && (!isValidating || isError)) || !data || isError,
|
||||
showStatus: Boolean(data && (isRunning || isError)),
|
||||
};
|
||||
}, [data, isValidating]);
|
||||
|
||||
if (!showAction && !showStatus) return null;
|
||||
|
||||
return (
|
||||
<Flexbox gap={12} style={{ width: '100%', paddingTop: 16 }}>
|
||||
{showStatus && <MemoryAnalysisStatus task={data} />}
|
||||
{showAction && <AnalysisAction />}
|
||||
</Flexbox>
|
||||
);
|
||||
});
|
||||
|
||||
MemoryAnalysis.displayName = 'MemoryAnalysis';
|
||||
|
||||
export default MemoryAnalysis;
|
||||
@@ -0,0 +1,43 @@
|
||||
import { AsyncTaskStatus } from '@lobechat/types';
|
||||
import { useEffect } from 'react';
|
||||
|
||||
import { useClientDataSWR } from '@/libs/swr';
|
||||
import {
|
||||
type MemoryExtractionTask,
|
||||
memoryExtractionService,
|
||||
} from '@/services/userMemory/extraction';
|
||||
|
||||
const SWR_KEY = 'user-memory:analysis-task';
|
||||
|
||||
export const useMemoryAnalysisAsyncTask = (taskId?: string) => {
|
||||
const swr = useClientDataSWR<MemoryExtractionTask | null>(
|
||||
taskId ? [SWR_KEY, taskId] : SWR_KEY,
|
||||
() => memoryExtractionService.getTask(taskId),
|
||||
{
|
||||
refreshInterval: (data) =>
|
||||
data && [AsyncTaskStatus.Pending, AsyncTaskStatus.Processing].includes(data.status)
|
||||
? 30_000
|
||||
: 0,
|
||||
},
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (!swr.data) return;
|
||||
|
||||
const isRunning = [AsyncTaskStatus.Pending, AsyncTaskStatus.Processing].includes(
|
||||
swr.data.status,
|
||||
);
|
||||
if (!isRunning) return;
|
||||
|
||||
const timer = setInterval(() => {
|
||||
swr.mutate();
|
||||
}, 5_000);
|
||||
|
||||
return () => clearInterval(timer);
|
||||
}, [swr.data?.id, swr.data?.status, swr.mutate]);
|
||||
|
||||
return {
|
||||
...swr,
|
||||
refresh: swr.mutate,
|
||||
};
|
||||
};
|
||||
@@ -1,27 +1,36 @@
|
||||
import { Center, Empty, type EmptyProps } from '@lobehub/ui';
|
||||
import { Center, Empty, Flexbox, type EmptyProps } from '@lobehub/ui';
|
||||
import { BrainCircuitIcon } from 'lucide-react';
|
||||
import { memo } from 'react';
|
||||
import { memo, ReactNode } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
const MemoryEmpty = memo<EmptyProps & { search?: boolean }>(({ search, title, ...rest }) => {
|
||||
const { t } = useTranslation('memory');
|
||||
return (
|
||||
<Center height="100%" style={{ minHeight: '50vh' }} width="100%">
|
||||
<Empty
|
||||
description={search ? t('empty.search') : t('empty.description')}
|
||||
descriptionProps={{
|
||||
fontSize: 14,
|
||||
}}
|
||||
icon={BrainCircuitIcon}
|
||||
style={{
|
||||
maxWidth: 400,
|
||||
}}
|
||||
title={search ? undefined : title || t('empty.title')}
|
||||
type={search ? 'default' : 'page'}
|
||||
{...rest}
|
||||
/>
|
||||
</Center>
|
||||
);
|
||||
});
|
||||
const MemoryEmpty = memo<
|
||||
EmptyProps & { children?: ReactNode | ReactNode[], search?: boolean; }
|
||||
>(({ search, title, children, ...rest }) => {
|
||||
const { t } = useTranslation('memory');
|
||||
return (
|
||||
<Center height="100%" style={{ minHeight: '50vh' }} width="100%">
|
||||
<Flexbox align="center" gap={12}>
|
||||
<Empty
|
||||
description={search ? t('empty.search') : t('empty.description')}
|
||||
descriptionProps={{
|
||||
fontSize: 14,
|
||||
}}
|
||||
icon={BrainCircuitIcon}
|
||||
style={{
|
||||
maxWidth: 550,
|
||||
}}
|
||||
title={search ? undefined : title || t('empty.title')}
|
||||
type={search ? 'default' : 'page'}
|
||||
{...rest}
|
||||
>
|
||||
<Flexbox>
|
||||
{children}
|
||||
</Flexbox>
|
||||
</Empty>
|
||||
</Flexbox>
|
||||
</Center>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
export default MemoryEmpty;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
|
||||
import { useQueryState } from '@/hooks/useQueryParam';
|
||||
import { useGlobalStore } from '@/store/global';
|
||||
import { useUserMemoryStore } from '@/store/userMemory';
|
||||
@@ -30,7 +31,14 @@ const IdentitiesList = memo<IdentitiesListProps>(({ isLoading, searchValue, view
|
||||
};
|
||||
|
||||
if (!identities || identities.length === 0)
|
||||
return <MemoryEmpty search={Boolean(searchValue)} title={t('identity.empty')} />;
|
||||
return (
|
||||
<MemoryEmpty
|
||||
search={Boolean(searchValue)}
|
||||
title={t('identity.empty')}
|
||||
>
|
||||
<MemoryAnalysis />
|
||||
</MemoryEmpty>
|
||||
);
|
||||
|
||||
if (viewMode === 'timeline')
|
||||
return <TimelineView identities={identities} isLoading={isLoading} onClick={handleCardClick} />;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { memo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
|
||||
import MemoryEmpty from '@/app/[variants]/(main)/memory/features/MemoryEmpty';
|
||||
import { useQueryState } from '@/hooks/useQueryParam';
|
||||
import { useGlobalStore } from '@/store/global';
|
||||
@@ -30,7 +31,14 @@ const PreferencesList = memo<PreferencesListProps>(({ isLoading, searchValue, vi
|
||||
const isEmpty = preferences.length === 0;
|
||||
|
||||
if (isEmpty) {
|
||||
return <MemoryEmpty search={Boolean(searchValue)} title={t('preference.empty')} />;
|
||||
return (
|
||||
<MemoryEmpty
|
||||
search={Boolean(searchValue)}
|
||||
title={t('preference.empty')}
|
||||
>
|
||||
<MemoryAnalysis />
|
||||
</MemoryEmpty>
|
||||
);
|
||||
}
|
||||
|
||||
return viewMode === 'timeline' ? (
|
||||
|
||||
@@ -1,4 +1,24 @@
|
||||
export default {
|
||||
'analysis.action.button': 'Request memory analysis',
|
||||
'analysis.modal.cancel': 'Cancel',
|
||||
'analysis.modal.helper':
|
||||
'By default Lobe AI will analyze all unprocessed chats. It\'s optional to select a date range to analyze.',
|
||||
'analysis.modal.rangePlaceholder': 'No range selected; all conversations will be analyzed.',
|
||||
'analysis.modal.rangeSelected': 'Analyzing chats from {{start}} to {{end}}',
|
||||
'analysis.modal.submit': 'Request memory analysis',
|
||||
'analysis.modal.title': 'Analyze chats to generate memories',
|
||||
'analysis.range.all': 'All conversations',
|
||||
'analysis.range.end': 'Today',
|
||||
'analysis.range.start': 'Beginning',
|
||||
'analysis.status.errorTitle': 'Memory analysis request failed',
|
||||
'analysis.status.progress': 'Processed {{completed}} / {{total}} topics',
|
||||
'analysis.status.progressUnknown': 'Processed {{completed}} topics so far',
|
||||
'analysis.status.tip':
|
||||
'We are processing your conversations to build personal memories. This may take a few minutes.',
|
||||
'analysis.status.title': 'Memory analysis in progress',
|
||||
'analysis.toast.deduped': 'A memory request is already running, continuing progress…',
|
||||
'analysis.toast.failed': 'Memory analysis request failed. Please try again.',
|
||||
'analysis.toast.started': 'Memory analysis started. We will update once it’s ready.',
|
||||
'context.actions.delete': 'Delete',
|
||||
'context.actions.edit': 'Edit',
|
||||
'context.defaultType': 'Context',
|
||||
|
||||
@@ -60,7 +60,10 @@ export interface MemoryExtractionPrivateConfig {
|
||||
secretAccessKey?: string;
|
||||
};
|
||||
upstashWorkflowExtraHeaders?: Record<string, string>;
|
||||
webhookHeaders?: Record<string, string>;
|
||||
webhook: {
|
||||
headers?: Record<string, string>;
|
||||
baseUrl?: string;
|
||||
}
|
||||
whitelistUsers?: string[];
|
||||
}
|
||||
|
||||
@@ -253,7 +256,10 @@ export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig =>
|
||||
featureFlags,
|
||||
observabilityS3: extractorObservabilityS3,
|
||||
upstashWorkflowExtraHeaders,
|
||||
webhookHeaders,
|
||||
webhook: {
|
||||
headers: webhookHeaders,
|
||||
baseUrl: process.env.MEMORY_USER_MEMORY_WEBHOOK_BASE_URL,
|
||||
},
|
||||
whitelistUsers,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -0,0 +1,222 @@
|
||||
import { TRPCError } from '@trpc/server';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { userMemoryRouter } from '@/server/routers/lambda/userMemory';
|
||||
import { AsyncTaskStatus, AsyncTaskType } from '@/types/asyncTask';
|
||||
import { MemorySourceType } from '@/types/userMemory';
|
||||
|
||||
const mockFindActiveByType = vi.fn();
|
||||
const mockCreate = vi.fn();
|
||||
const mockUpdate = vi.fn();
|
||||
const mockFindById = vi.fn();
|
||||
|
||||
const mockCountTopicsForMemoryExtractor = vi.fn();
|
||||
const { mockTriggerProcessUsers } = vi.hoisted(() => ({
|
||||
mockTriggerProcessUsers: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@/database/models/asyncTask', () => ({
|
||||
AsyncTaskModel: vi.fn(() => ({
|
||||
create: mockCreate,
|
||||
findById: mockFindById,
|
||||
findActiveByType: mockFindActiveByType,
|
||||
update: mockUpdate,
|
||||
})),
|
||||
initUserMemoryExtractionMetadata: vi.fn((metadata) => metadata),
|
||||
}));
|
||||
|
||||
vi.mock('@/database/models/topic', () => ({
|
||||
TopicModel: vi.fn(() => ({
|
||||
countTopicsForMemoryExtractor: mockCountTopicsForMemoryExtractor,
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('@/envs/app', () => ({
|
||||
appEnv: {
|
||||
APP_URL: 'https://example.com',
|
||||
INTERNAL_APP_URL: 'https://internal.example.com',
|
||||
},
|
||||
}));
|
||||
|
||||
vi.mock('@/server/globalConfig/parseMemoryExtractionConfig', () => ({
|
||||
parseMemoryExtractionConfig: vi.fn(() => ({
|
||||
webhook: { baseUrl: 'https://internal.example.com' },
|
||||
upstashWorkflowExtraHeaders: { 'x-test': 'ok' },
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('@/server/services/memory/userMemory/extract', () => ({
|
||||
MemoryExtractionWorkflowService: {
|
||||
triggerProcessUsers: mockTriggerProcessUsers,
|
||||
},
|
||||
buildWorkflowPayloadInput: (payload: any) => payload,
|
||||
normalizeMemoryExtractionPayload: (payload: any) => payload,
|
||||
}));
|
||||
|
||||
const createCaller = (ctxOverrides: Partial<any> = {}) => {
|
||||
const ctx = {
|
||||
serverDB: {} as any,
|
||||
userId: 'user-1',
|
||||
...ctxOverrides,
|
||||
};
|
||||
|
||||
return userMemoryRouter.createCaller(ctx);
|
||||
};
|
||||
|
||||
describe('userMemoryRouter.requestMemoryFromChatTopic', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('dedupes when an active task exists', async () => {
|
||||
mockFindActiveByType.mockResolvedValue({
|
||||
id: 'existing-task',
|
||||
metadata: { progress: { completedTopics: 0, totalTopics: 1 } },
|
||||
status: AsyncTaskStatus.Pending,
|
||||
});
|
||||
|
||||
const caller = createCaller();
|
||||
const result = await caller.requestMemoryFromChatTopic({});
|
||||
|
||||
expect(result).toEqual({
|
||||
deduped: true,
|
||||
id: 'existing-task',
|
||||
metadata: { progress: { completedTopics: 0, totalTopics: 1 } },
|
||||
status: AsyncTaskStatus.Pending,
|
||||
});
|
||||
expect(mockCreate).not.toHaveBeenCalled();
|
||||
expect(mockTriggerProcessUsers).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('creates task and triggers workflow with user context and dates', async () => {
|
||||
mockFindActiveByType.mockResolvedValue(undefined);
|
||||
mockCreate.mockResolvedValue('new-task');
|
||||
mockCountTopicsForMemoryExtractor.mockResolvedValue(2);
|
||||
|
||||
const caller = createCaller();
|
||||
const result = await caller.requestMemoryFromChatTopic({
|
||||
fromDate: new Date('2024-01-01'),
|
||||
toDate: new Date('2024-02-01'),
|
||||
});
|
||||
|
||||
expect(mockCreate).toHaveBeenCalledWith({
|
||||
metadata: {
|
||||
progress: { completedTopics: 0, totalTopics: 2 },
|
||||
range: { from: new Date('2024-01-01').toISOString(), to: new Date('2024-02-01').toISOString() },
|
||||
source: 'chat_topic',
|
||||
},
|
||||
status: AsyncTaskStatus.Pending,
|
||||
type: AsyncTaskType.UserMemoryExtractionWithChatTopic,
|
||||
});
|
||||
expect(mockTriggerProcessUsers).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
asyncTaskId: 'new-task',
|
||||
baseUrl: 'https://internal.example.com',
|
||||
fromDate: new Date('2024-01-01'),
|
||||
sources: [MemorySourceType.ChatTopic],
|
||||
toDate: new Date('2024-02-01'),
|
||||
userIds: ['user-1'],
|
||||
userInitiated: true,
|
||||
}),
|
||||
{ extraHeaders: { 'x-test': 'ok' } },
|
||||
);
|
||||
expect(result).toMatchObject({
|
||||
deduped: false,
|
||||
id: 'new-task',
|
||||
status: AsyncTaskStatus.Pending,
|
||||
});
|
||||
});
|
||||
|
||||
it('returns success immediately when no topics', async () => {
|
||||
mockFindActiveByType.mockResolvedValue(undefined);
|
||||
mockCountTopicsForMemoryExtractor.mockResolvedValue(0);
|
||||
mockCreate.mockResolvedValue('empty-task');
|
||||
|
||||
const caller = createCaller();
|
||||
const result = await caller.requestMemoryFromChatTopic({});
|
||||
|
||||
expect(result).toEqual({
|
||||
deduped: false,
|
||||
id: 'empty-task',
|
||||
metadata: {
|
||||
progress: { completedTopics: 0, totalTopics: 0 },
|
||||
range: { from: undefined, to: undefined },
|
||||
source: 'chat_topic',
|
||||
},
|
||||
status: AsyncTaskStatus.Success,
|
||||
});
|
||||
expect(mockTriggerProcessUsers).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('throws on invalid date range', async () => {
|
||||
const caller = createCaller();
|
||||
await expect(
|
||||
caller.requestMemoryFromChatTopic({
|
||||
fromDate: new Date('2024-02-02'),
|
||||
toDate: new Date('2024-01-01'),
|
||||
}),
|
||||
).rejects.toBeInstanceOf(TRPCError);
|
||||
});
|
||||
});
|
||||
|
||||
describe('userMemoryRouter.getMemoryExtractionTask', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it('returns null when no active task', async () => {
|
||||
mockFindActiveByType.mockResolvedValue(undefined);
|
||||
|
||||
const caller = createCaller();
|
||||
const result = await caller.getMemoryExtractionTask();
|
||||
|
||||
expect(result).toBeNull();
|
||||
});
|
||||
|
||||
it('returns active task with normalized metadata', async () => {
|
||||
mockFindActiveByType.mockResolvedValue({
|
||||
id: 'task-1',
|
||||
metadata: {
|
||||
progress: { completedTopics: 1, totalTopics: 4 },
|
||||
source: 'chat_topic',
|
||||
},
|
||||
status: AsyncTaskStatus.Processing,
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
const caller = createCaller();
|
||||
const result = await caller.getMemoryExtractionTask();
|
||||
|
||||
expect(result).toEqual({
|
||||
error: undefined,
|
||||
id: 'task-1',
|
||||
metadata: {
|
||||
progress: { completedTopics: 1, totalTopics: 4 },
|
||||
range: undefined,
|
||||
source: 'chat_topic',
|
||||
},
|
||||
status: AsyncTaskStatus.Processing,
|
||||
});
|
||||
});
|
||||
|
||||
it('fetches by task id when provided', async () => {
|
||||
mockFindActiveByType.mockResolvedValue(undefined);
|
||||
mockFindById.mockResolvedValue({
|
||||
id: 'a0a0a0a0-a0a0-4a0a-a0a0-a0a0a0a0a0a0',
|
||||
metadata: {
|
||||
progress: { completedTopics: 2, totalTopics: 8 },
|
||||
source: 'chat_topic',
|
||||
},
|
||||
status: AsyncTaskStatus.Pending,
|
||||
userId: 'user-1',
|
||||
});
|
||||
|
||||
const caller = createCaller();
|
||||
const result = await caller.getMemoryExtractionTask({
|
||||
taskId: 'a0a0a0a0-a0a0-4a0a-a0a0-a0a0a0a0a0a0',
|
||||
});
|
||||
|
||||
expect(mockFindById).toHaveBeenCalledWith('a0a0a0a0-a0a0-4a0a-a0a0-a0a0a0a0a0a0');
|
||||
expect(result?.id).toBe('a0a0a0a0-a0a0-4a0a-a0a0-a0a0a0a0a0a0');
|
||||
});
|
||||
});
|
||||
@@ -1,6 +1,18 @@
|
||||
import { CreateUserMemoryIdentitySchema, UpdateUserMemoryIdentitySchema } from '@lobechat/types';
|
||||
import {
|
||||
AsyncTaskError,
|
||||
AsyncTaskErrorType,
|
||||
AsyncTaskStatus,
|
||||
AsyncTaskType,
|
||||
CreateUserMemoryIdentitySchema,
|
||||
MemorySourceType,
|
||||
UpdateUserMemoryIdentitySchema,
|
||||
UserMemoryExtractionMetadata,
|
||||
} from '@lobechat/types';
|
||||
import { TRPCError } from '@trpc/server';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { AsyncTaskModel, initUserMemoryExtractionMetadata } from '@/database/models/asyncTask';
|
||||
import { TopicModel } from '@/database/models/topic';
|
||||
import { UserMemoryModel } from '@/database/models/userMemory';
|
||||
import {
|
||||
UserMemoryContextModel,
|
||||
@@ -10,21 +22,41 @@ import {
|
||||
} from '@/database/models/userMemory/index';
|
||||
import { authedProcedure, router } from '@/libs/trpc/lambda';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda/middleware';
|
||||
import { appEnv } from '@/envs/app';
|
||||
import { parseMemoryExtractionConfig } from '@/server/globalConfig/parseMemoryExtractionConfig';
|
||||
import {
|
||||
MemoryExtractionWorkflowService,
|
||||
buildWorkflowPayloadInput,
|
||||
normalizeMemoryExtractionPayload,
|
||||
} from '@/server/services/memory/userMemory/extract';
|
||||
|
||||
const userMemoryProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
|
||||
contextModel: new UserMemoryContextModel(ctx.serverDB, ctx.userId),
|
||||
experienceModel: new UserMemoryExperienceModel(ctx.serverDB, ctx.userId),
|
||||
identityModel: new UserMemoryIdentityModel(ctx.serverDB, ctx.userId),
|
||||
preferenceModel: new UserMemoryPreferenceModel(ctx.serverDB, ctx.userId),
|
||||
topicModel: new TopicModel(ctx.serverDB, ctx.userId),
|
||||
userMemoryModel: new UserMemoryModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
const userMemoryExtractionInputSchema = z.object({
|
||||
fromDate: z.coerce.date().optional(),
|
||||
toDate: z.coerce.date().optional(),
|
||||
});
|
||||
|
||||
const userMemoryExtractionTaskInputSchema = z
|
||||
.object({
|
||||
taskId: z.string().uuid().optional(),
|
||||
})
|
||||
.optional();
|
||||
|
||||
export const userMemoryRouter = router({
|
||||
// ============ Identity CRUD ============
|
||||
createIdentity: userMemoryProcedure
|
||||
@@ -82,6 +114,25 @@ export const userMemoryRouter = router({
|
||||
return ctx.userMemoryModel.getAllIdentities();
|
||||
}),
|
||||
|
||||
getMemoryExtractionTask: userMemoryProcedure
|
||||
.input(userMemoryExtractionTaskInputSchema)
|
||||
.query(async ({ ctx, input }) => {
|
||||
const task = input?.taskId
|
||||
? await ctx.asyncTaskModel.findById(input.taskId)
|
||||
: await ctx.asyncTaskModel.findActiveByType(
|
||||
AsyncTaskType.UserMemoryExtractionWithChatTopic,
|
||||
);
|
||||
|
||||
if (!task || task.userId !== ctx.userId) return null;
|
||||
|
||||
return {
|
||||
error: task.error,
|
||||
id: task.id,
|
||||
metadata: initUserMemoryExtractionMetadata(task.metadata as UserMemoryExtractionMetadata | undefined),
|
||||
status: task.status as AsyncTaskStatus,
|
||||
};
|
||||
}),
|
||||
|
||||
// ============ Persona ============
|
||||
getPersona: userMemoryProcedure.query(async () => {
|
||||
return { content: '', summary: '' };
|
||||
@@ -91,6 +142,106 @@ export const userMemoryRouter = router({
|
||||
return ctx.userMemoryModel.searchPreferences({});
|
||||
}),
|
||||
|
||||
requestMemoryFromChatTopic: userMemoryProcedure
|
||||
.input(userMemoryExtractionInputSchema)
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
if (input.fromDate && input.toDate && input.fromDate > input.toDate) {
|
||||
throw new TRPCError({
|
||||
code: 'BAD_REQUEST',
|
||||
message: '`fromDate` cannot be later than `toDate`',
|
||||
});
|
||||
}
|
||||
|
||||
const existingTask = await ctx.asyncTaskModel.findActiveByType(
|
||||
AsyncTaskType.UserMemoryExtractionWithChatTopic,
|
||||
);
|
||||
if (existingTask) {
|
||||
return {
|
||||
deduped: true,
|
||||
id: existingTask.id,
|
||||
metadata: existingTask.metadata as UserMemoryExtractionMetadata,
|
||||
status: existingTask.status as AsyncTaskStatus,
|
||||
};
|
||||
}
|
||||
|
||||
const totalTopics = await ctx.topicModel.countTopicsForMemoryExtractor({
|
||||
endDate: input.toDate,
|
||||
ignoreExtracted: false,
|
||||
startDate: input.fromDate,
|
||||
});
|
||||
const metadata = initUserMemoryExtractionMetadata({
|
||||
progress: {
|
||||
completedTopics: 0,
|
||||
totalTopics,
|
||||
},
|
||||
range: {
|
||||
from: input.fromDate?.toISOString(),
|
||||
to: input.toDate?.toISOString(),
|
||||
},
|
||||
source: 'chat_topic',
|
||||
});
|
||||
|
||||
const initialStatus =
|
||||
totalTopics === 0 ? AsyncTaskStatus.Success : AsyncTaskStatus.Pending;
|
||||
const taskId = await ctx.asyncTaskModel.create({
|
||||
metadata,
|
||||
status: initialStatus,
|
||||
type: AsyncTaskType.UserMemoryExtractionWithChatTopic,
|
||||
});
|
||||
|
||||
if (totalTopics === 0) {
|
||||
return {
|
||||
deduped: false,
|
||||
id: taskId,
|
||||
metadata: metadata as UserMemoryExtractionMetadata,
|
||||
status: initialStatus as AsyncTaskStatus,
|
||||
};
|
||||
}
|
||||
|
||||
const { webhook, upstashWorkflowExtraHeaders } = parseMemoryExtractionConfig();
|
||||
const baseUrl = webhook.baseUrl || appEnv.INTERNAL_APP_URL || appEnv.APP_URL;
|
||||
|
||||
try {
|
||||
await MemoryExtractionWorkflowService.triggerProcessUsers(
|
||||
buildWorkflowPayloadInput(
|
||||
normalizeMemoryExtractionPayload({
|
||||
asyncTaskId: taskId,
|
||||
baseUrl,
|
||||
forceAll: false,
|
||||
forceTopics: false,
|
||||
fromDate: input.fromDate,
|
||||
mode: 'workflow',
|
||||
sources: [MemorySourceType.ChatTopic],
|
||||
toDate: input.toDate,
|
||||
userIds: [ctx.userId],
|
||||
userInitiated: true,
|
||||
}),
|
||||
),
|
||||
{ extraHeaders: upstashWorkflowExtraHeaders },
|
||||
);
|
||||
} catch (error) {
|
||||
await ctx.asyncTaskModel.update(taskId, {
|
||||
error: new AsyncTaskError(
|
||||
AsyncTaskErrorType.TaskTriggerError,
|
||||
'Failed to schedule memory extraction workflow',
|
||||
),
|
||||
status: AsyncTaskStatus.Error,
|
||||
});
|
||||
throw new TRPCError({
|
||||
cause: error,
|
||||
code: 'INTERNAL_SERVER_ERROR',
|
||||
message: 'Failed to trigger user memory extraction',
|
||||
});
|
||||
}
|
||||
|
||||
return {
|
||||
deduped: false,
|
||||
id: taskId,
|
||||
metadata: metadata as UserMemoryExtractionMetadata,
|
||||
status: AsyncTaskStatus.Pending,
|
||||
};
|
||||
}),
|
||||
|
||||
updateContext: userMemoryProcedure
|
||||
.input(
|
||||
z.object({
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { AiProviderRuntimeState } from '@lobechat/types';
|
||||
import type { EnabledAiModel } from 'model-bank';
|
||||
import { describe, expect, it } from 'vitest';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type { MemoryExtractionPrivateConfig } from '@/server/globalConfig/parseMemoryExtractionConfig';
|
||||
|
||||
@@ -39,7 +39,7 @@ const createExecutor = (privateOverrides?: Partial<MemoryExtractionPrivateConfig
|
||||
embedding: { model: 'embed-1', provider: 'provider-e' },
|
||||
featureFlags: { enableBenchmarkLoCoMo: false },
|
||||
observabilityS3: { enabled: false },
|
||||
webhookHeaders: {},
|
||||
webhook: {},
|
||||
};
|
||||
|
||||
const serverConfig = {
|
||||
@@ -70,6 +70,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
|
||||
{ abilities: {}, id: 'gate-2', providerId: 'provider-b', type: 'chat' },
|
||||
{ abilities: {}, id: 'embed-1', providerId: 'provider-e', type: 'embedding' },
|
||||
{ abilities: {}, id: 'layer-ctx', providerId: 'provider-l', type: 'chat' },
|
||||
{ abilities: {}, id: 'layer-act', providerId: 'provider-l', type: 'chat' },
|
||||
{ abilities: {}, id: 'layer-exp', providerId: 'provider-l', type: 'chat' },
|
||||
{ abilities: {}, id: 'layer-id', providerId: 'provider-l', type: 'chat' },
|
||||
{ abilities: {}, id: 'layer-pref', providerId: 'provider-l', type: 'chat' },
|
||||
@@ -98,6 +99,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
|
||||
const runtimeState = createRuntimeState(
|
||||
[
|
||||
{ abilities: {}, id: 'gate-2', providerId: 'provider-b', type: 'chat' },
|
||||
{ abilities: {}, id: 'layer-act', providerId: 'provider-l', type: 'chat' },
|
||||
{ abilities: {}, id: 'layer-ctx', providerId: 'provider-l', type: 'chat' },
|
||||
{ abilities: {}, id: 'layer-exp', providerId: 'provider-l', type: 'chat' },
|
||||
{ abilities: {}, id: 'layer-id', providerId: 'provider-l', type: 'chat' },
|
||||
|
||||
@@ -39,6 +39,7 @@ import {
|
||||
import { attributesCommon } from '@lobechat/observability-otel/node';
|
||||
import type {
|
||||
AiProviderRuntimeState,
|
||||
ChatTopicMetadata,
|
||||
IdentityMemoryDetail,
|
||||
MemoryExtractionAgentCallTrace,
|
||||
MemoryExtractionTraceError,
|
||||
@@ -54,6 +55,7 @@ import type { ListTopicsForMemoryExtractorCursor } from '@/database/models/topic
|
||||
import { TopicModel } from '@/database/models/topic';
|
||||
import type { ListUsersForMemoryExtractorCursor } from '@/database/models/user';
|
||||
import { UserModel } from '@/database/models/user';
|
||||
import { AsyncTaskModel } from '@/database/models/asyncTask';
|
||||
import { UserMemoryModel } from '@/database/models/userMemory';
|
||||
import { UserMemorySourceBenchmarkLoCoMoModel } from '@/database/models/userMemory/sources/benchmarkLoCoMo';
|
||||
import { AiInfraRepos } from '@/database/repositories/aiInfra';
|
||||
@@ -67,12 +69,12 @@ import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
|
||||
import { S3 } from '@/server/modules/S3';
|
||||
import type { GlobalMemoryLayer } from '@/types/serverConfig';
|
||||
import type { ProviderConfig } from '@/types/user/settings';
|
||||
import { LayersEnum, MemorySourceType, type MergeStrategyEnum, TypesEnum } from '@/types/userMemory';
|
||||
import {
|
||||
LayersEnum,
|
||||
MemorySourceType,
|
||||
type MergeStrategyEnum,
|
||||
TypesEnum,
|
||||
} from '@/types/userMemory';
|
||||
AsyncTaskError,
|
||||
AsyncTaskErrorType,
|
||||
AsyncTaskStatus,
|
||||
} from '@/types/asyncTask';
|
||||
import { encodeAsync } from '@/utils/tokenizer';
|
||||
|
||||
const SOURCE_ALIAS_MAP: Record<string, MemorySourceType> = {
|
||||
@@ -107,6 +109,7 @@ export interface TopicWorkflowCursor extends MemoryExtractionWorkflowCursor {
|
||||
}
|
||||
|
||||
export interface MemoryExtractionNormalizedPayload {
|
||||
asyncTaskId?: string;
|
||||
baseUrl: string;
|
||||
forceAll: boolean;
|
||||
forceTopics: boolean;
|
||||
@@ -126,9 +129,11 @@ export interface MemoryExtractionNormalizedPayload {
|
||||
userCursor?: MemoryExtractionWorkflowCursor;
|
||||
userId?: string;
|
||||
userIds: string[];
|
||||
userInitiated?: boolean;
|
||||
}
|
||||
|
||||
export const memoryExtractionPayloadSchema = z.object({
|
||||
asyncTaskId: z.string().uuid().optional(),
|
||||
baseUrl: z.string().url().optional(),
|
||||
forceAll: z.boolean().optional(),
|
||||
forceTopics: z.boolean().optional(),
|
||||
@@ -155,6 +160,7 @@ export const memoryExtractionPayloadSchema = z.object({
|
||||
.optional(),
|
||||
userId: z.string().optional(),
|
||||
userIds: z.array(z.string()).optional(),
|
||||
userInitiated: z.boolean().optional(),
|
||||
});
|
||||
|
||||
export type MemoryExtractionPayloadInput = z.infer<typeof memoryExtractionPayloadSchema>;
|
||||
@@ -188,6 +194,7 @@ export const normalizeMemoryExtractionPayload = (
|
||||
if (!baseUrl) throw new Error('Missing baseUrl for workflow trigger');
|
||||
|
||||
return {
|
||||
asyncTaskId: parsed.asyncTaskId,
|
||||
baseUrl,
|
||||
forceAll: parsed.forceAll ?? false,
|
||||
forceTopics: parsed.forceTopics ?? false,
|
||||
@@ -205,6 +212,7 @@ export const normalizeMemoryExtractionPayload = (
|
||||
userIds: Array.from(
|
||||
new Set([...(parsed.userIds || []), ...(parsed.userId ? [parsed.userId] : [])]),
|
||||
).filter(Boolean),
|
||||
userInitiated: parsed.userInitiated ?? false,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -223,6 +231,7 @@ type ProviderKeyVaultMap = Record<
|
||||
export const buildWorkflowPayloadInput = (
|
||||
payload: MemoryExtractionNormalizedPayload,
|
||||
): MemoryExtractionPayloadInput => ({
|
||||
asyncTaskId: payload.asyncTaskId,
|
||||
baseUrl: payload.baseUrl,
|
||||
forceAll: payload.forceAll,
|
||||
forceTopics: payload.forceTopics,
|
||||
@@ -238,6 +247,7 @@ export const buildWorkflowPayloadInput = (
|
||||
userCursor: payload.userCursor,
|
||||
userId: payload.userId ?? payload.userIds[0],
|
||||
userIds: payload.userIds,
|
||||
userInitiated: payload.userInitiated,
|
||||
});
|
||||
|
||||
const normalizeProvider = (provider: string) => provider.toLowerCase();
|
||||
@@ -329,12 +339,11 @@ const initRuntimeForAgent = async (agent: MemoryAgentConfig, keyVaults?: Provide
|
||||
});
|
||||
};
|
||||
|
||||
const isTopicExtracted = (metadata: any): boolean => {
|
||||
const extractStatus = metadata?.memory_user_memory_extract?.extract_status;
|
||||
const isTopicExtracted = (metadata?: ChatTopicMetadata | null): boolean => {
|
||||
const extractStatus = metadata?.userMemoryExtractStatus;
|
||||
if (extractStatus) return extractStatus === 'completed';
|
||||
|
||||
const state = metadata?.memoryExtraction?.sources?.chat_topic;
|
||||
return state?.status === 'completed' && !!state?.lastRunAt;
|
||||
return metadata?.userMemoryExtractStatus === 'completed' && !!metadata?.userMemoryExtractRunState?.lastRunAt;
|
||||
};
|
||||
|
||||
type RuntimeBundle = {
|
||||
@@ -344,6 +353,7 @@ type RuntimeBundle = {
|
||||
};
|
||||
|
||||
export interface TopicExtractionJob {
|
||||
asyncTaskId?: string;
|
||||
forceAll: boolean;
|
||||
forceTopics: boolean;
|
||||
from?: Date;
|
||||
@@ -352,6 +362,7 @@ export interface TopicExtractionJob {
|
||||
to?: Date;
|
||||
topicId: string;
|
||||
userId: string;
|
||||
userInitiated?: boolean;
|
||||
}
|
||||
|
||||
export interface TopicPaginationJob {
|
||||
@@ -1024,6 +1035,18 @@ export class MemoryExtractionExecutor {
|
||||
return res.map((item) => ({ ...item, layer: LayersEnum.Identity }));
|
||||
}
|
||||
|
||||
private async reportUserInitiatedProgress(job: TopicExtractionJob) {
|
||||
if (!job.asyncTaskId || !job.userInitiated) return;
|
||||
|
||||
try {
|
||||
const db = await this.db;
|
||||
const asyncTaskModel = new AsyncTaskModel(db, job.userId);
|
||||
await asyncTaskModel.incrementUserMemoryExtractionProgress(job.asyncTaskId);
|
||||
} catch (error) {
|
||||
console.error('[memory-extraction] failed to update async task progress', error);
|
||||
}
|
||||
}
|
||||
|
||||
async extractTopic(job: TopicExtractionJob) {
|
||||
const attributes = {
|
||||
source: job.source,
|
||||
@@ -1050,6 +1073,8 @@ export class MemoryExtractionExecutor {
|
||||
'Memory User Memory: Extract Chat Topic',
|
||||
{ attributes },
|
||||
async (span) => {
|
||||
const shouldReportProgress = job.userInitiated && !!job.asyncTaskId;
|
||||
let topicProcessed = false;
|
||||
const startTime = Date.now();
|
||||
let extractionJob: MemoryExtractionJob | null = null;
|
||||
let extraction: MemoryExtractionResult | null = null;
|
||||
@@ -1072,6 +1097,7 @@ export class MemoryExtractionExecutor {
|
||||
`[memory-extraction] topic ${job.topicId} not found for user ${job.userId}`,
|
||||
);
|
||||
span.setStatus({ code: SpanStatusCode.OK, message: 'topic_not_found' });
|
||||
topicProcessed = true;
|
||||
return {
|
||||
extracted: false,
|
||||
layers: {},
|
||||
@@ -1081,6 +1107,7 @@ export class MemoryExtractionExecutor {
|
||||
}
|
||||
if ((job.from && topic.createdAt < job.from) || (job.to && topic.createdAt > job.to)) {
|
||||
span.setStatus({ code: SpanStatusCode.OK, message: 'topic_out_of_range' });
|
||||
topicProcessed = true;
|
||||
return {
|
||||
extracted: false,
|
||||
layers: {},
|
||||
@@ -1090,6 +1117,7 @@ export class MemoryExtractionExecutor {
|
||||
}
|
||||
if (!job.forceAll && !job.forceTopics && isTopicExtracted(topic.metadata)) {
|
||||
span.setStatus({ code: SpanStatusCode.OK, message: 'already_extracted' });
|
||||
topicProcessed = true;
|
||||
return {
|
||||
extracted: false,
|
||||
layers: {},
|
||||
@@ -1292,6 +1320,7 @@ export class MemoryExtractionExecutor {
|
||||
if (!extraction) {
|
||||
this.recordJobMetrics(extractionJob, 'completed', Date.now() - startTime);
|
||||
span.setStatus({ code: SpanStatusCode.OK, message: 'no_extraction' });
|
||||
topicProcessed = true;
|
||||
return {
|
||||
extracted: false,
|
||||
layers: {},
|
||||
@@ -1318,6 +1347,7 @@ export class MemoryExtractionExecutor {
|
||||
span.setStatus({ code: SpanStatusCode.OK });
|
||||
span.setAttribute('memory.processed_memory_count', persistedRes.createdIds.length);
|
||||
|
||||
topicProcessed = true;
|
||||
return {
|
||||
extracted: true,
|
||||
layers: persistedRes.layers,
|
||||
@@ -1340,8 +1370,26 @@ export class MemoryExtractionExecutor {
|
||||
if (tracePayload) {
|
||||
tracePayload.error = serializeError(error);
|
||||
}
|
||||
if (job.asyncTaskId && job.userInitiated) {
|
||||
try {
|
||||
const asyncTaskModel = new AsyncTaskModel(await this.db, job.userId);
|
||||
await asyncTaskModel.update(job.asyncTaskId, {
|
||||
error: new AsyncTaskError(
|
||||
AsyncTaskErrorType.ServerError,
|
||||
error instanceof Error ? error.message : 'Extraction failed',
|
||||
),
|
||||
status: AsyncTaskStatus.Error,
|
||||
});
|
||||
} catch (taskError) {
|
||||
console.error('[memory-extraction] failed to update async task status', taskError);
|
||||
}
|
||||
}
|
||||
throw error;
|
||||
} finally {
|
||||
if (shouldReportProgress && topicProcessed) {
|
||||
await this.reportUserInitiatedProgress(job);
|
||||
}
|
||||
|
||||
if (observabilityS3 && tracePayload) {
|
||||
try {
|
||||
await this.uploadExtractionTrace(
|
||||
@@ -1428,6 +1476,7 @@ export class MemoryExtractionExecutor {
|
||||
const topicIds = await this.filterTopicIdsForUser(userId, payload.topicIds);
|
||||
for (const topicId of topicIds) {
|
||||
const extracted = await this.extractTopic({
|
||||
asyncTaskId: payload.asyncTaskId,
|
||||
forceAll: payload.forceAll,
|
||||
forceTopics: payload.forceTopics,
|
||||
from: payload.from,
|
||||
@@ -1436,6 +1485,7 @@ export class MemoryExtractionExecutor {
|
||||
to: payload.to,
|
||||
topicId,
|
||||
userId,
|
||||
userInitiated: payload.userInitiated,
|
||||
});
|
||||
|
||||
results.push({ ...extracted, topicId, userId });
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import type { AsyncTaskStatus, IAsyncTaskError, UserMemoryExtractionMetadata } from '@lobechat/types';
|
||||
|
||||
import { lambdaClient } from '@/libs/trpc/client';
|
||||
|
||||
export interface MemoryExtractionTask {
|
||||
error?: IAsyncTaskError | null;
|
||||
id: string;
|
||||
metadata: UserMemoryExtractionMetadata;
|
||||
status: AsyncTaskStatus;
|
||||
}
|
||||
|
||||
export interface RequestMemoryExtractionParams {
|
||||
fromDate?: Date;
|
||||
toDate?: Date;
|
||||
}
|
||||
|
||||
export interface RequestMemoryExtractionResult extends MemoryExtractionTask {
|
||||
deduped: boolean;
|
||||
}
|
||||
|
||||
class MemoryExtractionService {
|
||||
requestFromChatTopics = async (
|
||||
params: RequestMemoryExtractionParams,
|
||||
): Promise<RequestMemoryExtractionResult> => {
|
||||
return lambdaClient.userMemory.requestMemoryFromChatTopic.mutate(params);
|
||||
};
|
||||
|
||||
getTask = async (taskId?: string): Promise<MemoryExtractionTask | null> => {
|
||||
return lambdaClient.userMemory.getMemoryExtractionTask.query(
|
||||
taskId ? { taskId } : undefined,
|
||||
) as Promise<MemoryExtractionTask | null>;
|
||||
};
|
||||
}
|
||||
|
||||
export const memoryExtractionService = new MemoryExtractionService();
|
||||
@@ -136,3 +136,4 @@ class UserMemoryService {
|
||||
|
||||
export const userMemoryService = new UserMemoryService();
|
||||
export { memoryCRUDService } from './crud';
|
||||
export { memoryExtractionService } from './extraction';
|
||||
|
||||
Reference in New Issue
Block a user