feat(userMemories): added user memory request, implemented workflow trigger (#11749)

This commit is contained in:
Neko
2026-01-24 23:20:55 +08:00
committed by GitHub
parent 7bca7d6f79
commit 9df3b88c49
33 changed files with 1196 additions and 64 deletions
+19 -1
View File
@@ -62,5 +62,23 @@
"tab.preferences": "Preferences",
"tab.search": "Search",
"viewMode.masonry": "Masonry",
"viewMode.timeline": "Timeline"
"viewMode.timeline": "Timeline",
"analysis.action.button": "Request Memory",
"analysis.modal.cancel": "Cancel",
"analysis.modal.helper": "Optional time window. By default, only conversations not yet analyzed will be processed. Choose a window to limit analysis.",
"analysis.modal.rangePlaceholder": "No time range selected, all conversations will be analyzed",
"analysis.modal.rangeSelected": "Analyze conversations from {{start}} to {{end}}",
"analysis.modal.submit": "Request Analysis",
"analysis.modal.title": "Analyze conversations to generate memories",
"analysis.range.all": "All conversations",
"analysis.range.end": "Today",
"analysis.range.start": "Beginning",
"analysis.status.errorTitle": "Analysis request failed",
"analysis.status.progress": "Processed {{completed}} / {{total}} topics",
"analysis.status.progressUnknown": "Processed {{completed}} topics",
"analysis.status.tip": "We're processing your conversations to generate memories. This may take a few minutes - thanks for your patience.",
"analysis.status.title": "Memory analysis in progress",
"analysis.toast.deduped": "A request is already running - please wait...",
"analysis.toast.failed": "Analysis request failed, please retry.",
"analysis.toast.started": "Memory analysis started. It will update automatically when done."
}
+19 -1
View File
@@ -62,5 +62,23 @@
"tab.preferences": "偏好",
"tab.search": "搜索",
"viewMode.masonry": "瀑布流",
"viewMode.timeline": "时间线"
"viewMode.timeline": "时间线",
"analysis.action.button": "请求记忆",
"analysis.modal.cancel": "取消",
"analysis.modal.helper": "可选的时间窗口,默认会对尚未分析的对话进行分析,请选择用于分析的时间窗口。",
"analysis.modal.rangePlaceholder": "未选择时间范围,将分析所有对话",
"analysis.modal.rangeSelected": "分析对话 {{start}} 至 {{end}}",
"analysis.modal.submit": "请求分析记忆",
"analysis.modal.title": "分析对话以生成记忆",
"analysis.range.all": "全部对话",
"analysis.range.end": "今天",
"analysis.range.start": "最初",
"analysis.status.errorTitle": "请求分析失败",
"analysis.status.progress": "已处理 {{completed}} / {{total}} 个话题",
"analysis.status.progressUnknown": "已处理 {{completed}} 个话题",
"analysis.status.tip": "正在处理你的对话来生成记忆,可能需要几分钟,请耐心等待。",
"analysis.status.title": "记忆分析中",
"analysis.toast.deduped": "已有请求仍在处理中,请稍等…",
"analysis.toast.failed": "请求分析失败,请重试。",
"analysis.toast.started": "记忆分析已开始,完成后会自动更新。"
}
@@ -1,6 +1,10 @@
// @vitest-environment node
import { ASYNC_TASK_TIMEOUT } from '@lobechat/business-config/server';
import { AsyncTaskStatus, AsyncTaskType } from '@lobechat/types';
import {
AsyncTaskStatus,
AsyncTaskType,
type UserMemoryExtractionMetadata,
} from '@lobechat/types';
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
@@ -121,6 +125,40 @@ describe('AsyncTaskModel', () => {
});
});
describe('incrementUserMemoryExtractionProgress', () => {
it('should increment completedTopics and set status to success when reaching total', async () => {
const { id } = await serverDB
.insert(asyncTasks)
.values({
metadata: {
progress: {
completedTopics: 0,
totalTopics: 2,
},
source: 'chat_topic',
},
status: AsyncTaskStatus.Pending,
type: AsyncTaskType.UserMemoryExtractionWithChatTopic,
userId,
})
.returning()
.then((res) => res[0]);
await asyncTaskModel.incrementUserMemoryExtractionProgress(id);
let task = await serverDB.query.asyncTasks.findFirst({ where: eq(asyncTasks.id, id) });
const firstMetadata = task?.metadata as UserMemoryExtractionMetadata | undefined;
expect(firstMetadata?.progress?.completedTopics).toBe(1);
expect(firstMetadata?.progress?.totalTopics).toBe(2);
expect(task?.status).toBe(AsyncTaskStatus.Processing);
await asyncTaskModel.incrementUserMemoryExtractionProgress(id);
task = await serverDB.query.asyncTasks.findFirst({ where: eq(asyncTasks.id, id) });
const secondMetadata = task?.metadata as UserMemoryExtractionMetadata | undefined;
expect(secondMetadata?.progress?.completedTopics).toBe(2);
expect(task?.status).toBe(AsyncTaskStatus.Success);
});
});
describe('checkTimeoutTasks', () => {
it('should mark tasks as error if they timeout', async () => {
// Create a task with old timestamp (beyond timeout)
@@ -0,0 +1,77 @@
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { topics, users } from '../../../schemas';
import { LobeChatDatabase } from '../../../type';
import { TopicModel } from '../../topic';
import { getTestDB } from '../../../core/getTestDB';
const userId = 'topic-memory-extractor-user';
const serverDB: LobeChatDatabase = await getTestDB();
const topicModel = new TopicModel(serverDB, userId);
describe('TopicModel - countTopicsForMemoryExtractor', () => {
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values({ id: userId });
});
afterEach(async () => {
await serverDB.delete(users);
});
it('counts only unextracted topics when ignoreExtracted is false (default behavior)', async () => {
await serverDB.insert(topics).values([
{
id: 't1',
createdAt: new Date('2023-01-01'),
metadata: {},
userId,
},
{
id: 't2',
createdAt: new Date('2023-02-01'),
metadata: {
userMemoryExtractStatus: 'completed',
},
userId,
},
{
id: 't3',
createdAt: new Date('2023-03-01'),
metadata: {},
userId,
},
]);
const total = await topicModel.countTopicsForMemoryExtractor({
ignoreExtracted: false,
});
expect(total).toBe(2);
});
it('includes extracted topics when ignoreExtracted is true', async () => {
await serverDB.insert(topics).values([
{
id: 't1',
createdAt: new Date('2023-01-01'),
metadata: {},
userId,
},
{
id: 't2',
createdAt: new Date('2023-02-01'),
metadata: {
userMemoryExtractStatus: 'completed',
},
userId,
},
]);
const total = await topicModel.countTopicsForMemoryExtractor({
ignoreExtracted: true,
});
expect(total).toBe(2);
});
});
+57 -2
View File
@@ -4,8 +4,9 @@ import {
AsyncTaskErrorType,
AsyncTaskStatus,
AsyncTaskType,
type UserMemoryExtractionMetadata,
} from '@lobechat/types';
import { and, eq, inArray, lt, or } from 'drizzle-orm';
import { and, eq, inArray, lt, or, sql } from 'drizzle-orm';
import { AsyncTaskSelectItem, NewAsyncTaskItem, asyncTasks } from '../schemas';
import { LobeChatDatabase } from '../type';
@@ -19,7 +20,9 @@ export class AsyncTaskModel {
this.db = db;
}
create = async (params: Pick<NewAsyncTaskItem, 'type' | 'status'>): Promise<string> => {
create = async (
params: Pick<NewAsyncTaskItem, 'type' | 'status' | 'metadata' | 'parentId'>,
): Promise<string> => {
const data = await this.db
.insert(asyncTasks)
.values({ ...params, userId: this.userId })
@@ -45,6 +48,47 @@ export class AsyncTaskModel {
.where(and(eq(asyncTasks.id, taskId)));
}
findActiveByType = async (type: AsyncTaskType) => {
return this.db.query.asyncTasks.findFirst({
where: and(
eq(asyncTasks.userId, this.userId),
eq(asyncTasks.type, type),
inArray(asyncTasks.status, [AsyncTaskStatus.Pending, AsyncTaskStatus.Processing]),
),
});
};
incrementUserMemoryExtractionProgress = async (taskId: string) => {
const completedExpr = sql<number>`COALESCE(((${asyncTasks.metadata}) -> 'progress' ->> 'completedTopics')::int, 0) + 1`;
const totalExpr = sql<number | null>`((${asyncTasks.metadata}) -> 'progress' ->> 'totalTopics')::int`;
const result = await this.db
.update(asyncTasks)
.set({
metadata: sql`jsonb_set(
jsonb_set(
${asyncTasks.metadata},
'{progress,completedTopics}',
to_jsonb(${completedExpr}),
true
),
'{progress,totalTopics}',
COALESCE((${asyncTasks.metadata}) -> 'progress' -> 'totalTopics', 'null'::jsonb),
true
)`,
status: sql`CASE
WHEN ${totalExpr} IS NOT NULL AND ${completedExpr} >= ${totalExpr}
THEN ${AsyncTaskStatus.Success}
ELSE ${AsyncTaskStatus.Processing}
END`,
updatedAt: new Date(),
})
.where(and(eq(asyncTasks.id, taskId), eq(asyncTasks.userId, this.userId)))
.returning({ metadata: asyncTasks.metadata, status: asyncTasks.status });
return result[0];
};
findByIds = async (taskIds: string[], type: AsyncTaskType): Promise<AsyncTaskSelectItem[]> => {
let chunkTasks: AsyncTaskSelectItem[] = [];
@@ -95,3 +139,14 @@ export class AsyncTaskModel {
}
};
}
export const initUserMemoryExtractionMetadata = (
metadata?: UserMemoryExtractionMetadata,
): UserMemoryExtractionMetadata => ({
progress: {
completedTopics: metadata?.progress?.completedTopics ?? 0,
totalTopics: metadata?.progress?.totalTopics ?? null,
},
range: metadata?.range,
source: metadata?.source ?? 'chat_topic',
});
+28 -3
View File
@@ -562,8 +562,8 @@ export class TopicModel {
clientId: null,
id: newId,
parentId: newParentId,
topicId: duplicatedTopic.id,
tools: newTools,
topicId: duplicatedTopic.id,
})
.returning()) as DBMessageItem[];
@@ -576,8 +576,8 @@ export class TopicModel {
await tx.insert(messagePlugins).values({
...plugin,
id: newId,
clientId: null,
id: newId,
toolCallId: newToolCallId,
});
}
@@ -739,13 +739,38 @@ export class TopicModel {
? undefined
: or(
isNull(topics.metadata),
sql`(${topics.metadata}->'memory_user_memory_extract'->>'extract_status') IS DISTINCT FROM 'completed'`,
sql`(${topics.metadata}->>'userMemoryExtractStatus') IS DISTINCT FROM 'completed'`,
),
cursorCondition,
),
});
};
countTopicsForMemoryExtractor = async (options: {
endDate?: Date;
ignoreExtracted?: boolean;
startDate?: Date;
} = {}) => {
const result = await this.db
.select({ total: count(topics.id) })
.from(topics)
.where(
and(
eq(topics.userId, this.userId),
options.startDate ? gte(topics.createdAt, options.startDate) : undefined,
options.endDate ? lte(topics.createdAt, options.endDate) : undefined,
options.ignoreExtracted
? undefined
: or(
isNull(topics.metadata),
sql`(${topics.metadata}->>'userMemoryExtractStatus') IS DISTINCT FROM 'completed'`,
),
),
);
return result[0]?.total ?? 0;
};
/**
* Get cron topics grouped by cronJob for a specific agent
* Returns topics where trigger='cron' and metadata contains cronJobId
@@ -89,8 +89,8 @@ export interface BaseCreateUserMemoryParams {
capturedAt?: Date;
details: string;
detailsEmbedding?: number[];
memoryCategory: string;
memoryLayer?: LayersEnum;
memoryCategory?: string | null;
memoryLayer: LayersEnum;
memoryType: TypesEnum;
summary: string;
summaryEmbedding?: number[];
+15
View File
@@ -2,6 +2,7 @@ export enum AsyncTaskType {
Chunking = 'chunk',
Embedding = 'embedding',
ImageGeneration = 'image_generation',
UserMemoryExtractionWithChatTopic = 'user_memory_extraction:chat_topic',
}
export enum AsyncTaskStatus {
@@ -67,3 +68,17 @@ export interface FileParsingTask {
embeddingStatus?: AsyncTaskStatus | null;
finishEmbedding?: boolean;
}
export interface UserMemoryExtractionProgress {
completedTopics: number;
totalTopics: number | null;
}
export interface UserMemoryExtractionMetadata {
progress: UserMemoryExtractionProgress;
range?: {
from?: string;
to?: string;
};
source: 'chat_topic';
}
@@ -23,13 +23,13 @@ const bodySchema = z.object({
});
export const POST = async (req: Request) => {
const { webhookHeaders, featureFlags } = parseMemoryExtractionConfig();
const { webhook, featureFlags } = parseMemoryExtractionConfig();
if (!featureFlags.enableBenchmarkLoCoMo) {
return NextResponse.json({ error: 'Not found' }, { status: 404 });
}
if (webhookHeaders && Object.keys(webhookHeaders).length > 0) {
for (const [key, value] of Object.entries(webhookHeaders)) {
if (webhook?.headers && Object.keys(webhook?.headers).length > 0) {
for (const [key, value] of Object.entries(webhook?.headers)) {
const headerValue = req.headers.get(key);
if (headerValue !== value) {
return NextResponse.json(
@@ -58,10 +58,10 @@ interface SessionExtractionResult {
export const POST = async (req: Request) => {
try {
const { webhookHeaders } = parseMemoryExtractionConfig();
const { webhook } = parseMemoryExtractionConfig();
if (webhookHeaders && Object.keys(webhookHeaders).length > 0) {
for (const [key, value] of Object.entries(webhookHeaders)) {
if (webhook.headers && Object.keys(webhook.headers).length > 0) {
for (const [key, value] of Object.entries(webhook.headers)) {
const headerValue = req.headers.get(key);
if (headerValue !== value) {
return NextResponse.json(
@@ -10,10 +10,10 @@ import {
} from '@/server/services/memory/userMemory/extract';
export const POST = async (req: Request) => {
const { webhookHeaders, upstashWorkflowExtraHeaders } = parseMemoryExtractionConfig();
const { webhook, upstashWorkflowExtraHeaders } = parseMemoryExtractionConfig();
if (webhookHeaders && Object.keys(webhookHeaders).length > 0) {
for (const [key, value] of Object.entries(webhookHeaders)) {
if (webhook.headers && Object.keys(webhook.headers).length > 0) {
for (const [key, value] of Object.entries(webhook.headers)) {
const headerValue = req.headers.get(key);
if (headerValue !== value) {
return NextResponse.json(
@@ -82,6 +82,7 @@ export const { POST } = serve<MemoryExtractionPayloadInput>(
`memory:user-memory:extract:users:${userId}:topics:${topicId}:cep:${index}`,
() =>
executor.extractTopic({
asyncTaskId: payload.asyncTaskId,
forceAll: payload.forceAll,
forceTopics: payload.forceTopics,
from: payload.from,
@@ -90,6 +91,7 @@ export const { POST } = serve<MemoryExtractionPayloadInput>(
to: payload.to,
topicId,
userId,
userInitiated: false,
}),
),
),
@@ -101,6 +103,7 @@ export const { POST } = serve<MemoryExtractionPayloadInput>(
`memory:user-memory:extract:users:${userId}:topics:${topicId}:identity:${index}`,
() =>
executor.extractTopic({
asyncTaskId: payload.asyncTaskId,
forceAll: payload.forceAll,
forceTopics: payload.forceTopics,
from: payload.from,
@@ -109,6 +112,7 @@ export const { POST } = serve<MemoryExtractionPayloadInput>(
to: payload.to,
topicId,
userId,
userInitiated: payload.userInitiated,
}),
);
}
@@ -9,7 +9,7 @@ import { SCROLL_PARENT_ID } from '@/app/[variants]/(main)/community/features/con
import { withSuspense } from '@/components/withSuspense';
import { useQuery } from '@/hooks/useQuery';
import { useDiscoverStore } from '@/store/discover';
import { AssistantCategory } from '@/types/discover';
import { AssistantCategory, AssistantSorts } from '@/types/discover';
import CategoryMenu from '../../../../components/CategoryMenu';
import { useCategory } from './useCategory';
@@ -28,7 +28,11 @@ const Category = memo(() => {
const genUrl = (key: AssistantCategory) =>
qs.stringifyUrl(
{
query: { category: key === AssistantCategory.All ? null : key, q, source },
query: {
category: [AssistantCategory.All, AssistantCategory.Discover].includes(key) ? null : key,
q,
sort: key === AssistantCategory.Discover ? AssistantSorts.Recommended : null,
},
url: '/community/agent',
},
{ skipNull: true },
@@ -2,14 +2,15 @@ import { Flexbox } from '@lobehub/ui';
// import { PencilLineIcon } from 'lucide-react';
import { type FC } from 'react';
import { SCROLL_PARENT_ID } from '@/app/[variants]/(main)/memory/features/TimeLineView/useScrollParent';
import Loading from '@/components/Loading/BrandTextLoading';
import NavHeader from '@/features/NavHeader';
import WideScreenContainer from '@/features/WideScreenContainer';
import WideScreenButton from '@/features/WideScreenContainer/WideScreenButton';
import { useUserMemoryStore } from '@/store/userMemory';
import MemoryEmpty from '../features/MemoryEmpty';
import { SCROLL_PARENT_ID } from '../features/TimeLineView/useScrollParent';
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
import MemoryEmpty from '@/app/[variants]/(main)/memory/features/MemoryEmpty';
import Persona from './features/Persona';
import PersonaHeader from './features/Persona/PersonaHeader';
import RoleTagCloud from './features/RoleTagCloud';
@@ -23,7 +24,11 @@ const Home: FC = () => {
if (isLoading) return <Loading debugId={'Home'} />;
if (!roles || roles.length === 0) {
return <MemoryEmpty />;
return (
<MemoryEmpty>
<MemoryAnalysis />
</MemoryEmpty>
);
}
return (
@@ -39,6 +44,7 @@ const Home: FC = () => {
zIndex: 1,
}}
/>
<MemoryAnalysis />
<Flexbox
height={'100%'}
id={SCROLL_PARENT_ID}
@@ -1,6 +1,7 @@
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
import MemoryEmpty from '@/app/[variants]/(main)/memory/features/MemoryEmpty';
import { useQueryState } from '@/hooks/useQueryParam';
import { useGlobalStore } from '@/store/global';
@@ -31,7 +32,14 @@ const ContextsList = memo<ContextsListProps>(({ isLoading, searchValue, viewMode
const isEmpty = contexts.length === 0;
if (isEmpty) {
return <MemoryEmpty search={Boolean(searchValue)} title={t('context.empty')} />;
return (
<MemoryEmpty
search={Boolean(searchValue)}
title={t('context.empty')}
>
<MemoryAnalysis />
</MemoryEmpty>
);
}
return viewMode === 'timeline' ? (
@@ -1,6 +1,7 @@
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
import MemoryEmpty from '@/app/[variants]/(main)/memory/features/MemoryEmpty';
import { useQueryState } from '@/hooks/useQueryParam';
import { useGlobalStore } from '@/store/global';
@@ -30,7 +31,14 @@ const ExperiencesList = memo<ExperiencesListProps>(({ isLoading, searchValue, vi
const isEmpty = experiences.length === 0;
if (isEmpty) {
return <MemoryEmpty search={Boolean(searchValue)} title={t('experience.empty')} />;
return (
<MemoryEmpty
search={Boolean(searchValue)}
title={t('experience.empty')}
>
<MemoryAnalysis />
</MemoryEmpty>
);
}
return viewMode === 'timeline' ? (
@@ -0,0 +1,31 @@
'use client';
import { memo, useMemo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import AnalysisTrigger from './AnalysisTrigger';
const AnalysisAction = memo(() => {
const { t } = useTranslation('memory');
const [range, setRange] = useState<[Date | null, Date | null]>([null, null]);
const footerNote = useMemo(
() =>
range[0] || range[1]
? t('analysis.modal.rangeSelected', {
end: range[1]?.toISOString().slice(0, 10)?.replaceAll('-', '/') ||
t('analysis.range.end'),
start:
range[0]?.toISOString().slice(0, 10)?.replaceAll('-', '/') ||
t('analysis.range.start'),
})
: t('analysis.modal.rangePlaceholder'),
[range, t],
);
return <AnalysisTrigger footerNote={footerNote} onRangeChange={setRange} range={range} />;
});
AnalysisAction.displayName = 'AnalysisAction';
export default AnalysisAction;
@@ -0,0 +1,76 @@
'use client';
import { Button, Flexbox, Icon, Text } from '@lobehub/ui';
import { App } from 'antd';
import { CalendarClockIcon } from 'lucide-react';
import { memo, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useMemoryAnalysisAsyncTask } from '@/app/[variants]/(main)/memory/features/MemoryAnalysis/useTask';
import { memoryExtractionService } from '@/services/userMemory/extraction';
import DateRangeModal from './DateRangeModal';
interface Props {
footerNote: string;
onRangeChange: (range: [Date | null, Date | null]) => void;
range: [Date | null, Date | null];
}
const AnalysisTrigger = memo<Props>(({ footerNote, range, onRangeChange }) => {
const { t } = useTranslation('memory');
const { message } = App.useApp();
const { isValidating, refresh } = useMemoryAnalysisAsyncTask();
const [open, setOpen] = useState(false);
const [submitting, setSubmitting] = useState(false);
const handleSubmit = async () => {
setSubmitting(true);
try {
const [from, to] = range;
const result = await memoryExtractionService.requestFromChatTopics({
fromDate: from ?? undefined,
toDate: to ?? undefined,
});
await refresh();
message.success(result.deduped ? t('analysis.toast.deduped') : t('analysis.toast.started'));
setOpen(false);
} catch (error) {
console.error(error);
message.error(t('analysis.toast.failed'));
} finally {
setSubmitting(false);
}
};
return (
<>
<Button
icon={<Icon icon={CalendarClockIcon} />}
loading={submitting || isValidating}
onClick={() => setOpen(true)}
size={'large'}
type={'primary'}
style={{ maxWidth: 300 }}
>
{t('analysis.action.button')}
</Button>
<DateRangeModal
footerNote={footerNote}
open={open}
onCancel={() => setOpen(false)}
onChange={onRangeChange}
onSubmit={handleSubmit}
range={range}
submitting={submitting}
/>
</>
);
});
AnalysisTrigger.displayName = 'AnalysisTrigger';
export default AnalysisTrigger;
@@ -0,0 +1,68 @@
'use client';
import { Flexbox, Text } from '@lobehub/ui';
import { DatePicker, Modal } from 'antd';
import type { RangePickerProps } from 'antd/es/date-picker';
import dayjs, { type Dayjs } from 'dayjs';
import { memo, useCallback } from 'react';
import { useTranslation } from 'react-i18next';
interface Props {
footerNote: string;
onCancel: () => void;
onChange: (range: [Date | null, Date | null]) => void;
onSubmit: () => void;
open: boolean;
range: [Date | null, Date | null];
submitting: boolean;
}
const DateRangeModal = memo<Props>(
({ footerNote, onCancel, onChange, onSubmit, open, range, submitting }) => {
const { t } = useTranslation('memory');
const disabledDate = useCallback<NonNullable<RangePickerProps['disabledDate']>>(
(current) => current.isAfter(dayjs(), 'day'),
[],
);
return (
<Modal
cancelText={t('analysis.modal.cancel')}
okButtonProps={{ loading: submitting }}
okText={t('analysis.modal.submit')}
onCancel={onCancel}
onOk={onSubmit}
open={open}
title={t('analysis.modal.title')}
>
<Flexbox gap={12}>
<Text type={'secondary'}>{t('analysis.modal.helper')}</Text>
<DatePicker.RangePicker
allowClear
disabledDate={disabledDate}
format={'YYYY/MM/DD'}
onChange={(values) =>
onChange([
values?.[0]?.toDate() ?? null,
values?.[1]?.toDate() ?? null,
])
}
style={{ width: '100%' }}
value={[
range[0] ? dayjs(range[0]) : null,
range[1] ? dayjs(range[1]) : null,
]}
/>
<Text fontSize={12} type={'secondary'}>
{footerNote}
</Text>
</Flexbox>
</Modal>
);
},
);
DateRangeModal.displayName = 'DateRangeModal';
export default DateRangeModal;
@@ -0,0 +1,86 @@
'use client';
import { Alert, Flexbox, Icon, Text } from '@lobehub/ui';
import { Progress } from 'antd';
import { Loader2Icon, TriangleAlertIcon } from 'lucide-react';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { AsyncTaskStatus } from '@lobechat/types';
import type { MemoryExtractionTask } from '@/services/userMemory/extraction';
import { useMemoryAnalysisAsyncTask } from './useTask';
interface StatusProps {
task?: MemoryExtractionTask | null;
}
export const MemoryAnalysisStatus = memo<StatusProps>(({ task }) => {
const { t } = useTranslation('memory');
const data = task;
const status = data?.status;
const isRunning =
status === AsyncTaskStatus.Pending || status === AsyncTaskStatus.Processing;
const isError = status === AsyncTaskStatus.Error;
if (!data || (!isRunning && !isError)) return null;
const { progress } = data.metadata;
const percent =
progress.totalTopics && progress.totalTopics > 0
? Math.min(100, Math.round((progress.completedTopics / progress.totalTopics) * 100))
: undefined;
const progressText = progress.totalTopics
? t('analysis.status.progress', {
completed: progress.completedTopics,
total: progress.totalTopics,
})
: t('analysis.status.progressUnknown', { completed: progress.completedTopics });
const body = data.error?.body;
const errorText =
typeof body === 'string'
? body
: body && typeof body === 'object' && 'detail' in body && typeof body.detail === 'string'
? body.detail
: data.error?.name ?? t('analysis.status.errorTitle');
return (
<Alert
description={
<Flexbox gap={12}>
<Flexbox align="center" gap={12} horizontal wrap="wrap">
<Progress
percent={percent ?? 30}
showInfo={Boolean(percent)}
status={isError ? 'exception' : 'active'}
style={{ flex: 1, minWidth: 220 }}
/>
<Text fontSize={13} type={isError ? 'danger' : 'secondary'}>
{isError ? errorText ?? t('analysis.status.errorTitle') : progressText}
</Text>
</Flexbox>
</Flexbox>
}
icon={<Icon icon={isError ? TriangleAlertIcon : Loader2Icon} spin={isRunning && !isError} />}
title={isError ? t('analysis.status.errorTitle') : t('analysis.status.title')}
type={isError ? 'error' : 'info'}
variant={'borderless'}
/>
);
});
MemoryAnalysisStatus.displayName = 'MemoryAnalysisStatus';
const Status = memo(() => {
const { data } = useMemoryAnalysisAsyncTask();
return <MemoryAnalysisStatus task={data} />;
});
Status.displayName = 'MemoryAnalysisStatusWithData';
export default Status;
@@ -0,0 +1,40 @@
'use client';
import { Flexbox } from '@lobehub/ui';
import { AsyncTaskStatus } from '@lobechat/types';
import { memo, useMemo } from 'react';
import AnalysisAction from './Action';
import { MemoryAnalysisStatus } from './Status';
import { useMemoryAnalysisAsyncTask } from './useTask';
const MemoryAnalysis = memo(() => {
const { data, isValidating } = useMemoryAnalysisAsyncTask();
const { showAction, showStatus } = useMemo(() => {
const status = data?.status;
const isRunning =
status === AsyncTaskStatus.Pending || status === AsyncTaskStatus.Processing;
const isError = status === AsyncTaskStatus.Error;
console.log(isRunning, isValidating, isError, data);
return {
showAction: (!isRunning && (!isValidating || isError)) || !data || isError,
showStatus: Boolean(data && (isRunning || isError)),
};
}, [data, isValidating]);
if (!showAction && !showStatus) return null;
return (
<Flexbox gap={12} style={{ width: '100%', paddingTop: 16 }}>
{showStatus && <MemoryAnalysisStatus task={data} />}
{showAction && <AnalysisAction />}
</Flexbox>
);
});
MemoryAnalysis.displayName = 'MemoryAnalysis';
export default MemoryAnalysis;
@@ -0,0 +1,43 @@
import { AsyncTaskStatus } from '@lobechat/types';
import { useEffect } from 'react';
import { useClientDataSWR } from '@/libs/swr';
import {
type MemoryExtractionTask,
memoryExtractionService,
} from '@/services/userMemory/extraction';
const SWR_KEY = 'user-memory:analysis-task';
export const useMemoryAnalysisAsyncTask = (taskId?: string) => {
const swr = useClientDataSWR<MemoryExtractionTask | null>(
taskId ? [SWR_KEY, taskId] : SWR_KEY,
() => memoryExtractionService.getTask(taskId),
{
refreshInterval: (data) =>
data && [AsyncTaskStatus.Pending, AsyncTaskStatus.Processing].includes(data.status)
? 30_000
: 0,
},
);
useEffect(() => {
if (!swr.data) return;
const isRunning = [AsyncTaskStatus.Pending, AsyncTaskStatus.Processing].includes(
swr.data.status,
);
if (!isRunning) return;
const timer = setInterval(() => {
swr.mutate();
}, 5_000);
return () => clearInterval(timer);
}, [swr.data?.id, swr.data?.status, swr.mutate]);
return {
...swr,
refresh: swr.mutate,
};
};
@@ -1,27 +1,36 @@
import { Center, Empty, type EmptyProps } from '@lobehub/ui';
import { Center, Empty, Flexbox, type EmptyProps } from '@lobehub/ui';
import { BrainCircuitIcon } from 'lucide-react';
import { memo } from 'react';
import { memo, ReactNode } from 'react';
import { useTranslation } from 'react-i18next';
const MemoryEmpty = memo<EmptyProps & { search?: boolean }>(({ search, title, ...rest }) => {
const { t } = useTranslation('memory');
return (
<Center height="100%" style={{ minHeight: '50vh' }} width="100%">
<Empty
description={search ? t('empty.search') : t('empty.description')}
descriptionProps={{
fontSize: 14,
}}
icon={BrainCircuitIcon}
style={{
maxWidth: 400,
}}
title={search ? undefined : title || t('empty.title')}
type={search ? 'default' : 'page'}
{...rest}
/>
</Center>
);
});
const MemoryEmpty = memo<
EmptyProps & { children?: ReactNode | ReactNode[], search?: boolean; }
>(({ search, title, children, ...rest }) => {
const { t } = useTranslation('memory');
return (
<Center height="100%" style={{ minHeight: '50vh' }} width="100%">
<Flexbox align="center" gap={12}>
<Empty
description={search ? t('empty.search') : t('empty.description')}
descriptionProps={{
fontSize: 14,
}}
icon={BrainCircuitIcon}
style={{
maxWidth: 550,
}}
title={search ? undefined : title || t('empty.title')}
type={search ? 'default' : 'page'}
{...rest}
>
<Flexbox>
{children}
</Flexbox>
</Empty>
</Flexbox>
</Center>
);
},
);
export default MemoryEmpty;
@@ -1,6 +1,7 @@
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
import { useQueryState } from '@/hooks/useQueryParam';
import { useGlobalStore } from '@/store/global';
import { useUserMemoryStore } from '@/store/userMemory';
@@ -30,7 +31,14 @@ const IdentitiesList = memo<IdentitiesListProps>(({ isLoading, searchValue, view
};
if (!identities || identities.length === 0)
return <MemoryEmpty search={Boolean(searchValue)} title={t('identity.empty')} />;
return (
<MemoryEmpty
search={Boolean(searchValue)}
title={t('identity.empty')}
>
<MemoryAnalysis />
</MemoryEmpty>
);
if (viewMode === 'timeline')
return <TimelineView identities={identities} isLoading={isLoading} onClick={handleCardClick} />;
@@ -1,6 +1,7 @@
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import MemoryAnalysis from '@/app/[variants]/(main)/memory/features/MemoryAnalysis';
import MemoryEmpty from '@/app/[variants]/(main)/memory/features/MemoryEmpty';
import { useQueryState } from '@/hooks/useQueryParam';
import { useGlobalStore } from '@/store/global';
@@ -30,7 +31,14 @@ const PreferencesList = memo<PreferencesListProps>(({ isLoading, searchValue, vi
const isEmpty = preferences.length === 0;
if (isEmpty) {
return <MemoryEmpty search={Boolean(searchValue)} title={t('preference.empty')} />;
return (
<MemoryEmpty
search={Boolean(searchValue)}
title={t('preference.empty')}
>
<MemoryAnalysis />
</MemoryEmpty>
);
}
return viewMode === 'timeline' ? (
+20
View File
@@ -1,4 +1,24 @@
export default {
'analysis.action.button': 'Request memory analysis',
'analysis.modal.cancel': 'Cancel',
'analysis.modal.helper':
'By default Lobe AI will analyze all unprocessed chats. It\'s optional to select a date range to analyze.',
'analysis.modal.rangePlaceholder': 'No range selected; all conversations will be analyzed.',
'analysis.modal.rangeSelected': 'Analyzing chats from {{start}} to {{end}}',
'analysis.modal.submit': 'Request memory analysis',
'analysis.modal.title': 'Analyze chats to generate memories',
'analysis.range.all': 'All conversations',
'analysis.range.end': 'Today',
'analysis.range.start': 'Beginning',
'analysis.status.errorTitle': 'Memory analysis request failed',
'analysis.status.progress': 'Processed {{completed}} / {{total}} topics',
'analysis.status.progressUnknown': 'Processed {{completed}} topics so far',
'analysis.status.tip':
'We are processing your conversations to build personal memories. This may take a few minutes.',
'analysis.status.title': 'Memory analysis in progress',
'analysis.toast.deduped': 'A memory request is already running, continuing progress…',
'analysis.toast.failed': 'Memory analysis request failed. Please try again.',
'analysis.toast.started': 'Memory analysis started. We will update once its ready.',
'context.actions.delete': 'Delete',
'context.actions.edit': 'Edit',
'context.defaultType': 'Context',
@@ -60,7 +60,10 @@ export interface MemoryExtractionPrivateConfig {
secretAccessKey?: string;
};
upstashWorkflowExtraHeaders?: Record<string, string>;
webhookHeaders?: Record<string, string>;
webhook: {
headers?: Record<string, string>;
baseUrl?: string;
}
whitelistUsers?: string[];
}
@@ -253,7 +256,10 @@ export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig =>
featureFlags,
observabilityS3: extractorObservabilityS3,
upstashWorkflowExtraHeaders,
webhookHeaders,
webhook: {
headers: webhookHeaders,
baseUrl: process.env.MEMORY_USER_MEMORY_WEBHOOK_BASE_URL,
},
whitelistUsers,
};
};
@@ -0,0 +1,222 @@
import { TRPCError } from '@trpc/server';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { userMemoryRouter } from '@/server/routers/lambda/userMemory';
import { AsyncTaskStatus, AsyncTaskType } from '@/types/asyncTask';
import { MemorySourceType } from '@/types/userMemory';
const mockFindActiveByType = vi.fn();
const mockCreate = vi.fn();
const mockUpdate = vi.fn();
const mockFindById = vi.fn();
const mockCountTopicsForMemoryExtractor = vi.fn();
const { mockTriggerProcessUsers } = vi.hoisted(() => ({
mockTriggerProcessUsers: vi.fn(),
}));
vi.mock('@/database/models/asyncTask', () => ({
AsyncTaskModel: vi.fn(() => ({
create: mockCreate,
findById: mockFindById,
findActiveByType: mockFindActiveByType,
update: mockUpdate,
})),
initUserMemoryExtractionMetadata: vi.fn((metadata) => metadata),
}));
vi.mock('@/database/models/topic', () => ({
TopicModel: vi.fn(() => ({
countTopicsForMemoryExtractor: mockCountTopicsForMemoryExtractor,
})),
}));
vi.mock('@/envs/app', () => ({
appEnv: {
APP_URL: 'https://example.com',
INTERNAL_APP_URL: 'https://internal.example.com',
},
}));
vi.mock('@/server/globalConfig/parseMemoryExtractionConfig', () => ({
parseMemoryExtractionConfig: vi.fn(() => ({
webhook: { baseUrl: 'https://internal.example.com' },
upstashWorkflowExtraHeaders: { 'x-test': 'ok' },
})),
}));
vi.mock('@/server/services/memory/userMemory/extract', () => ({
MemoryExtractionWorkflowService: {
triggerProcessUsers: mockTriggerProcessUsers,
},
buildWorkflowPayloadInput: (payload: any) => payload,
normalizeMemoryExtractionPayload: (payload: any) => payload,
}));
const createCaller = (ctxOverrides: Partial<any> = {}) => {
const ctx = {
serverDB: {} as any,
userId: 'user-1',
...ctxOverrides,
};
return userMemoryRouter.createCaller(ctx);
};
describe('userMemoryRouter.requestMemoryFromChatTopic', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('dedupes when an active task exists', async () => {
mockFindActiveByType.mockResolvedValue({
id: 'existing-task',
metadata: { progress: { completedTopics: 0, totalTopics: 1 } },
status: AsyncTaskStatus.Pending,
});
const caller = createCaller();
const result = await caller.requestMemoryFromChatTopic({});
expect(result).toEqual({
deduped: true,
id: 'existing-task',
metadata: { progress: { completedTopics: 0, totalTopics: 1 } },
status: AsyncTaskStatus.Pending,
});
expect(mockCreate).not.toHaveBeenCalled();
expect(mockTriggerProcessUsers).not.toHaveBeenCalled();
});
it('creates task and triggers workflow with user context and dates', async () => {
mockFindActiveByType.mockResolvedValue(undefined);
mockCreate.mockResolvedValue('new-task');
mockCountTopicsForMemoryExtractor.mockResolvedValue(2);
const caller = createCaller();
const result = await caller.requestMemoryFromChatTopic({
fromDate: new Date('2024-01-01'),
toDate: new Date('2024-02-01'),
});
expect(mockCreate).toHaveBeenCalledWith({
metadata: {
progress: { completedTopics: 0, totalTopics: 2 },
range: { from: new Date('2024-01-01').toISOString(), to: new Date('2024-02-01').toISOString() },
source: 'chat_topic',
},
status: AsyncTaskStatus.Pending,
type: AsyncTaskType.UserMemoryExtractionWithChatTopic,
});
expect(mockTriggerProcessUsers).toHaveBeenCalledWith(
expect.objectContaining({
asyncTaskId: 'new-task',
baseUrl: 'https://internal.example.com',
fromDate: new Date('2024-01-01'),
sources: [MemorySourceType.ChatTopic],
toDate: new Date('2024-02-01'),
userIds: ['user-1'],
userInitiated: true,
}),
{ extraHeaders: { 'x-test': 'ok' } },
);
expect(result).toMatchObject({
deduped: false,
id: 'new-task',
status: AsyncTaskStatus.Pending,
});
});
it('returns success immediately when no topics', async () => {
mockFindActiveByType.mockResolvedValue(undefined);
mockCountTopicsForMemoryExtractor.mockResolvedValue(0);
mockCreate.mockResolvedValue('empty-task');
const caller = createCaller();
const result = await caller.requestMemoryFromChatTopic({});
expect(result).toEqual({
deduped: false,
id: 'empty-task',
metadata: {
progress: { completedTopics: 0, totalTopics: 0 },
range: { from: undefined, to: undefined },
source: 'chat_topic',
},
status: AsyncTaskStatus.Success,
});
expect(mockTriggerProcessUsers).not.toHaveBeenCalled();
});
it('throws on invalid date range', async () => {
const caller = createCaller();
await expect(
caller.requestMemoryFromChatTopic({
fromDate: new Date('2024-02-02'),
toDate: new Date('2024-01-01'),
}),
).rejects.toBeInstanceOf(TRPCError);
});
});
describe('userMemoryRouter.getMemoryExtractionTask', () => {
beforeEach(() => {
vi.clearAllMocks();
});
it('returns null when no active task', async () => {
mockFindActiveByType.mockResolvedValue(undefined);
const caller = createCaller();
const result = await caller.getMemoryExtractionTask();
expect(result).toBeNull();
});
it('returns active task with normalized metadata', async () => {
mockFindActiveByType.mockResolvedValue({
id: 'task-1',
metadata: {
progress: { completedTopics: 1, totalTopics: 4 },
source: 'chat_topic',
},
status: AsyncTaskStatus.Processing,
userId: 'user-1',
});
const caller = createCaller();
const result = await caller.getMemoryExtractionTask();
expect(result).toEqual({
error: undefined,
id: 'task-1',
metadata: {
progress: { completedTopics: 1, totalTopics: 4 },
range: undefined,
source: 'chat_topic',
},
status: AsyncTaskStatus.Processing,
});
});
it('fetches by task id when provided', async () => {
mockFindActiveByType.mockResolvedValue(undefined);
mockFindById.mockResolvedValue({
id: 'a0a0a0a0-a0a0-4a0a-a0a0-a0a0a0a0a0a0',
metadata: {
progress: { completedTopics: 2, totalTopics: 8 },
source: 'chat_topic',
},
status: AsyncTaskStatus.Pending,
userId: 'user-1',
});
const caller = createCaller();
const result = await caller.getMemoryExtractionTask({
taskId: 'a0a0a0a0-a0a0-4a0a-a0a0-a0a0a0a0a0a0',
});
expect(mockFindById).toHaveBeenCalledWith('a0a0a0a0-a0a0-4a0a-a0a0-a0a0a0a0a0a0');
expect(result?.id).toBe('a0a0a0a0-a0a0-4a0a-a0a0-a0a0a0a0a0a0');
});
});
+152 -1
View File
@@ -1,6 +1,18 @@
import { CreateUserMemoryIdentitySchema, UpdateUserMemoryIdentitySchema } from '@lobechat/types';
import {
AsyncTaskError,
AsyncTaskErrorType,
AsyncTaskStatus,
AsyncTaskType,
CreateUserMemoryIdentitySchema,
MemorySourceType,
UpdateUserMemoryIdentitySchema,
UserMemoryExtractionMetadata,
} from '@lobechat/types';
import { TRPCError } from '@trpc/server';
import { z } from 'zod';
import { AsyncTaskModel, initUserMemoryExtractionMetadata } from '@/database/models/asyncTask';
import { TopicModel } from '@/database/models/topic';
import { UserMemoryModel } from '@/database/models/userMemory';
import {
UserMemoryContextModel,
@@ -10,21 +22,41 @@ import {
} from '@/database/models/userMemory/index';
import { authedProcedure, router } from '@/libs/trpc/lambda';
import { serverDatabase } from '@/libs/trpc/lambda/middleware';
import { appEnv } from '@/envs/app';
import { parseMemoryExtractionConfig } from '@/server/globalConfig/parseMemoryExtractionConfig';
import {
MemoryExtractionWorkflowService,
buildWorkflowPayloadInput,
normalizeMemoryExtractionPayload,
} from '@/server/services/memory/userMemory/extract';
const userMemoryProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
contextModel: new UserMemoryContextModel(ctx.serverDB, ctx.userId),
experienceModel: new UserMemoryExperienceModel(ctx.serverDB, ctx.userId),
identityModel: new UserMemoryIdentityModel(ctx.serverDB, ctx.userId),
preferenceModel: new UserMemoryPreferenceModel(ctx.serverDB, ctx.userId),
topicModel: new TopicModel(ctx.serverDB, ctx.userId),
userMemoryModel: new UserMemoryModel(ctx.serverDB, ctx.userId),
},
});
});
const userMemoryExtractionInputSchema = z.object({
fromDate: z.coerce.date().optional(),
toDate: z.coerce.date().optional(),
});
const userMemoryExtractionTaskInputSchema = z
.object({
taskId: z.string().uuid().optional(),
})
.optional();
export const userMemoryRouter = router({
// ============ Identity CRUD ============
createIdentity: userMemoryProcedure
@@ -82,6 +114,25 @@ export const userMemoryRouter = router({
return ctx.userMemoryModel.getAllIdentities();
}),
getMemoryExtractionTask: userMemoryProcedure
.input(userMemoryExtractionTaskInputSchema)
.query(async ({ ctx, input }) => {
const task = input?.taskId
? await ctx.asyncTaskModel.findById(input.taskId)
: await ctx.asyncTaskModel.findActiveByType(
AsyncTaskType.UserMemoryExtractionWithChatTopic,
);
if (!task || task.userId !== ctx.userId) return null;
return {
error: task.error,
id: task.id,
metadata: initUserMemoryExtractionMetadata(task.metadata as UserMemoryExtractionMetadata | undefined),
status: task.status as AsyncTaskStatus,
};
}),
// ============ Persona ============
getPersona: userMemoryProcedure.query(async () => {
return { content: '', summary: '' };
@@ -91,6 +142,106 @@ export const userMemoryRouter = router({
return ctx.userMemoryModel.searchPreferences({});
}),
requestMemoryFromChatTopic: userMemoryProcedure
.input(userMemoryExtractionInputSchema)
.mutation(async ({ ctx, input }) => {
if (input.fromDate && input.toDate && input.fromDate > input.toDate) {
throw new TRPCError({
code: 'BAD_REQUEST',
message: '`fromDate` cannot be later than `toDate`',
});
}
const existingTask = await ctx.asyncTaskModel.findActiveByType(
AsyncTaskType.UserMemoryExtractionWithChatTopic,
);
if (existingTask) {
return {
deduped: true,
id: existingTask.id,
metadata: existingTask.metadata as UserMemoryExtractionMetadata,
status: existingTask.status as AsyncTaskStatus,
};
}
const totalTopics = await ctx.topicModel.countTopicsForMemoryExtractor({
endDate: input.toDate,
ignoreExtracted: false,
startDate: input.fromDate,
});
const metadata = initUserMemoryExtractionMetadata({
progress: {
completedTopics: 0,
totalTopics,
},
range: {
from: input.fromDate?.toISOString(),
to: input.toDate?.toISOString(),
},
source: 'chat_topic',
});
const initialStatus =
totalTopics === 0 ? AsyncTaskStatus.Success : AsyncTaskStatus.Pending;
const taskId = await ctx.asyncTaskModel.create({
metadata,
status: initialStatus,
type: AsyncTaskType.UserMemoryExtractionWithChatTopic,
});
if (totalTopics === 0) {
return {
deduped: false,
id: taskId,
metadata: metadata as UserMemoryExtractionMetadata,
status: initialStatus as AsyncTaskStatus,
};
}
const { webhook, upstashWorkflowExtraHeaders } = parseMemoryExtractionConfig();
const baseUrl = webhook.baseUrl || appEnv.INTERNAL_APP_URL || appEnv.APP_URL;
try {
await MemoryExtractionWorkflowService.triggerProcessUsers(
buildWorkflowPayloadInput(
normalizeMemoryExtractionPayload({
asyncTaskId: taskId,
baseUrl,
forceAll: false,
forceTopics: false,
fromDate: input.fromDate,
mode: 'workflow',
sources: [MemorySourceType.ChatTopic],
toDate: input.toDate,
userIds: [ctx.userId],
userInitiated: true,
}),
),
{ extraHeaders: upstashWorkflowExtraHeaders },
);
} catch (error) {
await ctx.asyncTaskModel.update(taskId, {
error: new AsyncTaskError(
AsyncTaskErrorType.TaskTriggerError,
'Failed to schedule memory extraction workflow',
),
status: AsyncTaskStatus.Error,
});
throw new TRPCError({
cause: error,
code: 'INTERNAL_SERVER_ERROR',
message: 'Failed to trigger user memory extraction',
});
}
return {
deduped: false,
id: taskId,
metadata: metadata as UserMemoryExtractionMetadata,
status: AsyncTaskStatus.Pending,
};
}),
updateContext: userMemoryProcedure
.input(
z.object({
@@ -1,6 +1,6 @@
import type { AiProviderRuntimeState } from '@lobechat/types';
import type { EnabledAiModel } from 'model-bank';
import { describe, expect, it } from 'vitest';
import { describe, expect, it, vi } from 'vitest';
import type { MemoryExtractionPrivateConfig } from '@/server/globalConfig/parseMemoryExtractionConfig';
@@ -39,7 +39,7 @@ const createExecutor = (privateOverrides?: Partial<MemoryExtractionPrivateConfig
embedding: { model: 'embed-1', provider: 'provider-e' },
featureFlags: { enableBenchmarkLoCoMo: false },
observabilityS3: { enabled: false },
webhookHeaders: {},
webhook: {},
};
const serverConfig = {
@@ -70,6 +70,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
{ abilities: {}, id: 'gate-2', providerId: 'provider-b', type: 'chat' },
{ abilities: {}, id: 'embed-1', providerId: 'provider-e', type: 'embedding' },
{ abilities: {}, id: 'layer-ctx', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-act', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-exp', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-id', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-pref', providerId: 'provider-l', type: 'chat' },
@@ -98,6 +99,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
const runtimeState = createRuntimeState(
[
{ abilities: {}, id: 'gate-2', providerId: 'provider-b', type: 'chat' },
{ abilities: {}, id: 'layer-act', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-ctx', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-exp', providerId: 'provider-l', type: 'chat' },
{ abilities: {}, id: 'layer-id', providerId: 'provider-l', type: 'chat' },
@@ -39,6 +39,7 @@ import {
import { attributesCommon } from '@lobechat/observability-otel/node';
import type {
AiProviderRuntimeState,
ChatTopicMetadata,
IdentityMemoryDetail,
MemoryExtractionAgentCallTrace,
MemoryExtractionTraceError,
@@ -54,6 +55,7 @@ import type { ListTopicsForMemoryExtractorCursor } from '@/database/models/topic
import { TopicModel } from '@/database/models/topic';
import type { ListUsersForMemoryExtractorCursor } from '@/database/models/user';
import { UserModel } from '@/database/models/user';
import { AsyncTaskModel } from '@/database/models/asyncTask';
import { UserMemoryModel } from '@/database/models/userMemory';
import { UserMemorySourceBenchmarkLoCoMoModel } from '@/database/models/userMemory/sources/benchmarkLoCoMo';
import { AiInfraRepos } from '@/database/repositories/aiInfra';
@@ -67,12 +69,12 @@ import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { S3 } from '@/server/modules/S3';
import type { GlobalMemoryLayer } from '@/types/serverConfig';
import type { ProviderConfig } from '@/types/user/settings';
import { LayersEnum, MemorySourceType, type MergeStrategyEnum, TypesEnum } from '@/types/userMemory';
import {
LayersEnum,
MemorySourceType,
type MergeStrategyEnum,
TypesEnum,
} from '@/types/userMemory';
AsyncTaskError,
AsyncTaskErrorType,
AsyncTaskStatus,
} from '@/types/asyncTask';
import { encodeAsync } from '@/utils/tokenizer';
const SOURCE_ALIAS_MAP: Record<string, MemorySourceType> = {
@@ -107,6 +109,7 @@ export interface TopicWorkflowCursor extends MemoryExtractionWorkflowCursor {
}
export interface MemoryExtractionNormalizedPayload {
asyncTaskId?: string;
baseUrl: string;
forceAll: boolean;
forceTopics: boolean;
@@ -126,9 +129,11 @@ export interface MemoryExtractionNormalizedPayload {
userCursor?: MemoryExtractionWorkflowCursor;
userId?: string;
userIds: string[];
userInitiated?: boolean;
}
export const memoryExtractionPayloadSchema = z.object({
asyncTaskId: z.string().uuid().optional(),
baseUrl: z.string().url().optional(),
forceAll: z.boolean().optional(),
forceTopics: z.boolean().optional(),
@@ -155,6 +160,7 @@ export const memoryExtractionPayloadSchema = z.object({
.optional(),
userId: z.string().optional(),
userIds: z.array(z.string()).optional(),
userInitiated: z.boolean().optional(),
});
export type MemoryExtractionPayloadInput = z.infer<typeof memoryExtractionPayloadSchema>;
@@ -188,6 +194,7 @@ export const normalizeMemoryExtractionPayload = (
if (!baseUrl) throw new Error('Missing baseUrl for workflow trigger');
return {
asyncTaskId: parsed.asyncTaskId,
baseUrl,
forceAll: parsed.forceAll ?? false,
forceTopics: parsed.forceTopics ?? false,
@@ -205,6 +212,7 @@ export const normalizeMemoryExtractionPayload = (
userIds: Array.from(
new Set([...(parsed.userIds || []), ...(parsed.userId ? [parsed.userId] : [])]),
).filter(Boolean),
userInitiated: parsed.userInitiated ?? false,
};
};
@@ -223,6 +231,7 @@ type ProviderKeyVaultMap = Record<
export const buildWorkflowPayloadInput = (
payload: MemoryExtractionNormalizedPayload,
): MemoryExtractionPayloadInput => ({
asyncTaskId: payload.asyncTaskId,
baseUrl: payload.baseUrl,
forceAll: payload.forceAll,
forceTopics: payload.forceTopics,
@@ -238,6 +247,7 @@ export const buildWorkflowPayloadInput = (
userCursor: payload.userCursor,
userId: payload.userId ?? payload.userIds[0],
userIds: payload.userIds,
userInitiated: payload.userInitiated,
});
const normalizeProvider = (provider: string) => provider.toLowerCase();
@@ -329,12 +339,11 @@ const initRuntimeForAgent = async (agent: MemoryAgentConfig, keyVaults?: Provide
});
};
const isTopicExtracted = (metadata: any): boolean => {
const extractStatus = metadata?.memory_user_memory_extract?.extract_status;
const isTopicExtracted = (metadata?: ChatTopicMetadata | null): boolean => {
const extractStatus = metadata?.userMemoryExtractStatus;
if (extractStatus) return extractStatus === 'completed';
const state = metadata?.memoryExtraction?.sources?.chat_topic;
return state?.status === 'completed' && !!state?.lastRunAt;
return metadata?.userMemoryExtractStatus === 'completed' && !!metadata?.userMemoryExtractRunState?.lastRunAt;
};
type RuntimeBundle = {
@@ -344,6 +353,7 @@ type RuntimeBundle = {
};
export interface TopicExtractionJob {
asyncTaskId?: string;
forceAll: boolean;
forceTopics: boolean;
from?: Date;
@@ -352,6 +362,7 @@ export interface TopicExtractionJob {
to?: Date;
topicId: string;
userId: string;
userInitiated?: boolean;
}
export interface TopicPaginationJob {
@@ -1024,6 +1035,18 @@ export class MemoryExtractionExecutor {
return res.map((item) => ({ ...item, layer: LayersEnum.Identity }));
}
private async reportUserInitiatedProgress(job: TopicExtractionJob) {
if (!job.asyncTaskId || !job.userInitiated) return;
try {
const db = await this.db;
const asyncTaskModel = new AsyncTaskModel(db, job.userId);
await asyncTaskModel.incrementUserMemoryExtractionProgress(job.asyncTaskId);
} catch (error) {
console.error('[memory-extraction] failed to update async task progress', error);
}
}
async extractTopic(job: TopicExtractionJob) {
const attributes = {
source: job.source,
@@ -1050,6 +1073,8 @@ export class MemoryExtractionExecutor {
'Memory User Memory: Extract Chat Topic',
{ attributes },
async (span) => {
const shouldReportProgress = job.userInitiated && !!job.asyncTaskId;
let topicProcessed = false;
const startTime = Date.now();
let extractionJob: MemoryExtractionJob | null = null;
let extraction: MemoryExtractionResult | null = null;
@@ -1072,6 +1097,7 @@ export class MemoryExtractionExecutor {
`[memory-extraction] topic ${job.topicId} not found for user ${job.userId}`,
);
span.setStatus({ code: SpanStatusCode.OK, message: 'topic_not_found' });
topicProcessed = true;
return {
extracted: false,
layers: {},
@@ -1081,6 +1107,7 @@ export class MemoryExtractionExecutor {
}
if ((job.from && topic.createdAt < job.from) || (job.to && topic.createdAt > job.to)) {
span.setStatus({ code: SpanStatusCode.OK, message: 'topic_out_of_range' });
topicProcessed = true;
return {
extracted: false,
layers: {},
@@ -1090,6 +1117,7 @@ export class MemoryExtractionExecutor {
}
if (!job.forceAll && !job.forceTopics && isTopicExtracted(topic.metadata)) {
span.setStatus({ code: SpanStatusCode.OK, message: 'already_extracted' });
topicProcessed = true;
return {
extracted: false,
layers: {},
@@ -1292,6 +1320,7 @@ export class MemoryExtractionExecutor {
if (!extraction) {
this.recordJobMetrics(extractionJob, 'completed', Date.now() - startTime);
span.setStatus({ code: SpanStatusCode.OK, message: 'no_extraction' });
topicProcessed = true;
return {
extracted: false,
layers: {},
@@ -1318,6 +1347,7 @@ export class MemoryExtractionExecutor {
span.setStatus({ code: SpanStatusCode.OK });
span.setAttribute('memory.processed_memory_count', persistedRes.createdIds.length);
topicProcessed = true;
return {
extracted: true,
layers: persistedRes.layers,
@@ -1340,8 +1370,26 @@ export class MemoryExtractionExecutor {
if (tracePayload) {
tracePayload.error = serializeError(error);
}
if (job.asyncTaskId && job.userInitiated) {
try {
const asyncTaskModel = new AsyncTaskModel(await this.db, job.userId);
await asyncTaskModel.update(job.asyncTaskId, {
error: new AsyncTaskError(
AsyncTaskErrorType.ServerError,
error instanceof Error ? error.message : 'Extraction failed',
),
status: AsyncTaskStatus.Error,
});
} catch (taskError) {
console.error('[memory-extraction] failed to update async task status', taskError);
}
}
throw error;
} finally {
if (shouldReportProgress && topicProcessed) {
await this.reportUserInitiatedProgress(job);
}
if (observabilityS3 && tracePayload) {
try {
await this.uploadExtractionTrace(
@@ -1428,6 +1476,7 @@ export class MemoryExtractionExecutor {
const topicIds = await this.filterTopicIdsForUser(userId, payload.topicIds);
for (const topicId of topicIds) {
const extracted = await this.extractTopic({
asyncTaskId: payload.asyncTaskId,
forceAll: payload.forceAll,
forceTopics: payload.forceTopics,
from: payload.from,
@@ -1436,6 +1485,7 @@ export class MemoryExtractionExecutor {
to: payload.to,
topicId,
userId,
userInitiated: payload.userInitiated,
});
results.push({ ...extracted, topicId, userId });
+35
View File
@@ -0,0 +1,35 @@
import type { AsyncTaskStatus, IAsyncTaskError, UserMemoryExtractionMetadata } from '@lobechat/types';
import { lambdaClient } from '@/libs/trpc/client';
export interface MemoryExtractionTask {
error?: IAsyncTaskError | null;
id: string;
metadata: UserMemoryExtractionMetadata;
status: AsyncTaskStatus;
}
export interface RequestMemoryExtractionParams {
fromDate?: Date;
toDate?: Date;
}
export interface RequestMemoryExtractionResult extends MemoryExtractionTask {
deduped: boolean;
}
class MemoryExtractionService {
requestFromChatTopics = async (
params: RequestMemoryExtractionParams,
): Promise<RequestMemoryExtractionResult> => {
return lambdaClient.userMemory.requestMemoryFromChatTopic.mutate(params);
};
getTask = async (taskId?: string): Promise<MemoryExtractionTask | null> => {
return lambdaClient.userMemory.getMemoryExtractionTask.query(
taskId ? { taskId } : undefined,
) as Promise<MemoryExtractionTask | null>;
};
}
export const memoryExtractionService = new MemoryExtractionService();
+1
View File
@@ -136,3 +136,4 @@ class UserMemoryService {
export const userMemoryService = new UserMemoryService();
export { memoryCRUDService } from './crud';
export { memoryExtractionService } from './extraction';