feat: add memory implement

This commit is contained in:
Neko
2025-12-20 22:03:44 +08:00
committed by arvinxx
parent 6ff8efacb3
commit fdae83ca2d
24 changed files with 1629 additions and 30 deletions
+4
View File
@@ -14,6 +14,7 @@ import { cleanObject } from '@/utils/object';
import { genServerAiProvidersConfig } from './genServerAiProviderConfig';
import { parseAgentConfig } from './parseDefaultAgent';
import { parseFilesConfig } from './parseFilesConfig';
import { getPublicMemoryExtractionConfig } from './parseMemoryExtractionConfig';
/**
* Get Better-Auth SSO providers list
@@ -74,6 +75,9 @@ export const getServerGlobalConfig = async () => {
image: cleanObject({
defaultImageNum: imageEnv.AI_IMAGE_DEFAULT_IMAGE_NUM,
}),
memory: {
userMemory: cleanObject(getPublicMemoryExtractionConfig()),
},
oAuthSSOProviders: authEnv.NEXT_PUBLIC_ENABLE_BETTER_AUTH
? getBetterAuthSSOProviders()
: authEnv.NEXT_AUTH_SSO_PROVIDERS.trim().split(/[,]/),
@@ -0,0 +1,210 @@
import { DEFAULT_USER_MEMORY_EMBEDDING_MODEL_ITEM } from '@lobechat/const';
import {
GlobalMemoryExtractionConfig,
GlobalMemoryLayer,
MemoryAgentPublicConfig,
MemoryLayerExtractorPublicConfig,
} from '@/types/serverConfig';
const MEMORY_LAYERS: GlobalMemoryLayer[] = ['context', 'experience', 'identity', 'preference'];
const DEFAULT_GATE_MODEL = 'gpt-5-mini';
const DEFAULT_PROVIDER = 'openai';
const parseTokenLimitEnv = (value?: string) => {
if (value === undefined) return undefined;
const parsed = Number(value);
if (!Number.isFinite(parsed) || parsed <= 0) return undefined;
return Math.floor(parsed);
};
export type MemoryAgentConfig = MemoryAgentPublicConfig & {
apiKey?: string;
language?: string;
model: string;
};
export type MemoryLayerExtractorConfig = MemoryLayerExtractorPublicConfig &
MemoryAgentConfig & {
layers: Record<GlobalMemoryLayer, string>;
};
export interface MemoryExtractionPrivateConfig {
agentGateKeeper: MemoryAgentConfig;
agentLayerExtractor: MemoryLayerExtractorConfig;
concurrency?: number;
embedding: MemoryAgentConfig;
observabilityS3?: {
accessKeyId?: string;
bucketName?: string;
enabled: boolean;
endpoint?: string;
forcePathStyle?: boolean;
pathPrefix?: string;
region?: string;
secretAccessKey?: string;
};
webhookHeaders?: Record<string, string>;
whitelistUsers?: string[];
}
const parseGateKeeperAgent = (): MemoryAgentConfig => {
const apiKey = process.env.MEMORY_USER_MEMORY_GATEKEEPER_API_KEY;
const baseURL = process.env.MEMORY_USER_MEMORY_GATEKEEPER_BASE_URL;
const model = process.env.MEMORY_USER_MEMORY_GATEKEEPER_MODEL || DEFAULT_GATE_MODEL;
const provider = process.env.MEMORY_USER_MEMORY_GATEKEEPER_PROVIDER || DEFAULT_PROVIDER;
const language = process.env.MEMORY_USER_MEMORY_GATEKEEPER_LANGUAGE || 'English';
return {
apiKey,
baseURL,
language,
model,
provider,
};
};
const parseLayerExtractorAgent = (fallbackModel: string): MemoryLayerExtractorConfig => {
const apiKey = process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_API_KEY;
const baseURL = process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_BASE_URL;
const model = process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_MODEL || fallbackModel;
const provider = process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_PROVIDER || DEFAULT_PROVIDER;
const contextLimit = parseTokenLimitEnv(
process.env.MEMORY_USER_MEMORY_LAYER_EXTRACTOR_CONTEXT_LIMIT,
);
const layers = MEMORY_LAYERS.reduce<Record<GlobalMemoryLayer, string>>(
(acc, layer) => {
const envKey = `MEMORY_USER_MEMORY_LAYER_EXTRACTOR_${layer.toUpperCase()}_MODEL`;
const override = (process.env as Record<string, string | undefined>)[envKey];
acc[layer] = override || model;
return acc;
},
{} as Record<GlobalMemoryLayer, string>,
);
return {
apiKey,
baseURL,
contextLimit,
layers,
model,
provider,
};
};
const parseEmbeddingAgent = (
fallbackModel: string,
fallbackProvider: string,
fallbackApiKey?: string,
): MemoryAgentConfig => {
const { model: defaultModel, provider: defaultProvider } =
DEFAULT_USER_MEMORY_EMBEDDING_MODEL_ITEM;
const model = process.env.MEMORY_USER_MEMORY_EMBEDDING_MODEL || fallbackModel || defaultModel;
const provider =
process.env.MEMORY_USER_MEMORY_EMBEDDING_PROVIDER ||
fallbackProvider ||
defaultProvider ||
DEFAULT_PROVIDER;
return {
apiKey: process.env.MEMORY_USER_MEMORY_EMBEDDING_API_KEY ?? fallbackApiKey,
baseURL: process.env.MEMORY_USER_MEMORY_EMBEDDING_BASE_URL,
contextLimit: parseTokenLimitEnv(process.env.MEMORY_USER_MEMORY_EMBEDDING_CONTEXT_LIMIT),
model,
provider,
};
};
const parseExtractorAgentObservabilityS3 = () => {
const accessKeyId = process.env.MEMORY_USER_MEMORY_EXTRACTOR_S3_ACCESS_KEY_ID;
const secretAccessKey = process.env.MEMORY_USER_MEMORY_EXTRACTOR_S3_SECRET_ACCESS_KEY;
const bucketName = process.env.MEMORY_USER_MEMORY_EXTRACTOR_S3_BUCKET_NAME;
const region = process.env.MEMORY_USER_MEMORY_EXTRACTOR_S3_REGION;
const endpoint = process.env.MEMORY_USER_MEMORY_EXTRACTOR_S3_ENDPOINT;
const forcePathStyle = process.env.MEMORY_USER_MEMORY_EXTRACTOR_S3_FORCE_PATH_STYLE === 'true';
const pathPrefix = process.env.MEMORY_USER_MEMORY_EXTRACTOR_S3_PATH_PREFIX;
if (!accessKeyId || !secretAccessKey || !endpoint) {
return {
enabled: false,
};
}
return {
accessKeyId,
bucketName,
enabled: true,
endpoint,
forcePathStyle,
pathPrefix,
region,
secretAccessKey,
};
};
const sanitizeAgent = (agent?: MemoryAgentConfig): MemoryAgentPublicConfig | undefined => {
if (!agent) return undefined;
const sanitized: MemoryAgentConfig = { ...agent };
delete sanitized.apiKey;
return sanitized as MemoryAgentPublicConfig;
};
export const parseMemoryExtractionConfig = (): MemoryExtractionPrivateConfig => {
const agentGateKeeper = parseGateKeeperAgent();
const agentLayerExtractor = parseLayerExtractorAgent(agentGateKeeper.model);
const embedding = parseEmbeddingAgent(
agentLayerExtractor.model,
agentLayerExtractor.provider || DEFAULT_PROVIDER,
agentGateKeeper.apiKey || agentLayerExtractor.apiKey,
);
const extractorObservabilityS3 = parseExtractorAgentObservabilityS3();
const concurrencyRaw = process.env.MEMORY_USER_MEMORY_CONCURRENCY;
const concurrency =
concurrencyRaw !== undefined
? Number.isInteger(Number(concurrencyRaw)) && Number(concurrencyRaw) > 0
? Number(concurrencyRaw)
: undefined
: undefined;
const whitelistUsers = process.env.MEMORY_USER_MEMORY_WHITELIST_USERS?.split(',')
.filter(Boolean)
.map((s) => s.trim());
const webhookHeaders = process.env.MEMORY_USER_MEMORY_WEBHOOK_HEADERS?.split(',')
.filter(Boolean)
.reduce<Record<string, string>>((acc, pair) => {
const [key, value] = pair.split('=').map((s) => s.trim());
if (key && value) {
acc[key] = value;
}
return acc;
}, {});
return {
agentGateKeeper,
agentLayerExtractor,
concurrency,
embedding,
observabilityS3: extractorObservabilityS3,
webhookHeaders,
whitelistUsers,
};
};
export const getPublicMemoryExtractionConfig = (): GlobalMemoryExtractionConfig => {
const privateConfig = parseMemoryExtractionConfig();
return {
agentGateKeeper: sanitizeAgent(privateConfig.agentGateKeeper)!,
agentLayerExtractor: {
...sanitizeAgent(privateConfig.agentLayerExtractor),
layers: privateConfig.agentLayerExtractor.layers,
},
concurrency: privateConfig.concurrency,
embedding: sanitizeAgent(privateConfig.embedding),
};
};
+110 -21
View File
@@ -9,7 +9,7 @@ import {
import { getSignedUrl } from '@aws-sdk/s3-request-presigner';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { S3 } from './index';
import { FileS3, S3 } from './index';
// Mock AWS SDK
vi.mock('@aws-sdk/client-s3');
@@ -59,7 +59,96 @@ describe('S3', () => {
describe('constructor', () => {
it('should initialize S3 client with correct configuration', () => {
new S3();
const testFileEnv = {
S3_ACCESS_KEY_ID: 'test-access-key',
S3_BUCKET: 'test-bucket',
S3_ENABLE_PATH_STYLE: false,
S3_ENDPOINT: 'https://s3.amazonaws.com',
S3_PREVIEW_URL_EXPIRE_IN: 7200,
S3_REGION: 'us-east-1',
S3_SECRET_ACCESS_KEY: 'test-secret-key',
S3_SET_ACL: true,
};
new S3(
testFileEnv.S3_ACCESS_KEY_ID,
testFileEnv.S3_SECRET_ACCESS_KEY,
testFileEnv.S3_ENDPOINT,
{
bucket: testFileEnv.S3_BUCKET,
forcePathStyle: testFileEnv.S3_ENABLE_PATH_STYLE,
region: testFileEnv.S3_REGION,
setAcl: testFileEnv.S3_SET_ACL,
},
);
expect(S3Client).toHaveBeenCalledWith({
credentials: {
accessKeyId: 'test-access-key',
secretAccessKey: 'test-secret-key',
},
endpoint: 'https://s3.amazonaws.com',
forcePathStyle: false,
region: 'us-east-1',
requestChecksumCalculation: 'WHEN_REQUIRED',
responseChecksumValidation: 'WHEN_REQUIRED',
});
});
it('should use default region when S3_REGION is not set', () => {
const testEnvWithoutRegion = {
S3_ACCESS_KEY_ID: 'test-access-key',
S3_BUCKET: 'test-bucket',
S3_ENABLE_PATH_STYLE: false,
S3_ENDPOINT: 'https://s3.amazonaws.com',
S3_PREVIEW_URL_EXPIRE_IN: 7200,
S3_REGION: '',
S3_SECRET_ACCESS_KEY: 'test-secret-key',
S3_SET_ACL: true,
};
new S3(
testEnvWithoutRegion.S3_ACCESS_KEY_ID,
testEnvWithoutRegion.S3_SECRET_ACCESS_KEY,
testEnvWithoutRegion.S3_ENDPOINT,
{
bucket: testEnvWithoutRegion.S3_BUCKET,
forcePathStyle: testEnvWithoutRegion.S3_ENABLE_PATH_STYLE,
region: testEnvWithoutRegion.S3_REGION,
setAcl: testEnvWithoutRegion.S3_SET_ACL,
},
);
expect(S3Client).toHaveBeenCalledWith(
expect.objectContaining({
region: 'us-east-1',
}),
);
});
});
});
describe('FileS3', () => {
let mockS3ClientSend: ReturnType<typeof vi.fn>;
let mockGetSignedUrl: ReturnType<typeof vi.fn>;
beforeEach(() => {
vi.clearAllMocks();
// Setup S3Client mock
mockS3ClientSend = vi.fn();
(S3Client as unknown as ReturnType<typeof vi.fn>).mockImplementation(() => ({
send: mockS3ClientSend,
}));
// Setup getSignedUrl mock
mockGetSignedUrl = vi.fn().mockResolvedValue('https://presigned-url.example.com');
(getSignedUrl as unknown as ReturnType<typeof vi.fn>).mockImplementation(mockGetSignedUrl);
});
describe('constructor', () => {
it('should initialize S3 client with correct configuration', () => {
new FileS3();
expect(S3Client).toHaveBeenCalledWith({
credentials: {
@@ -88,7 +177,7 @@ describe('S3', () => {
},
}));
new S3();
new FileS3();
expect(S3Client).toHaveBeenCalledWith(
expect.objectContaining({
@@ -100,7 +189,7 @@ describe('S3', () => {
describe('deleteFile', () => {
it('should delete a file with the correct parameters', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
await s3.deleteFile('test-key.txt');
@@ -113,7 +202,7 @@ describe('S3', () => {
});
it('should handle deletion errors', async () => {
const s3 = new S3();
const s3 = new FileS3();
const error = new Error('Delete failed');
mockS3ClientSend.mockRejectedValue(error);
@@ -123,7 +212,7 @@ describe('S3', () => {
describe('deleteFiles', () => {
it('should delete multiple files with correct parameters', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
const keys = ['file1.txt', 'file2.txt', 'file3.txt'];
@@ -139,7 +228,7 @@ describe('S3', () => {
});
it('should handle empty array', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
await s3.deleteFiles([]);
@@ -155,7 +244,7 @@ describe('S3', () => {
describe('getFileContent', () => {
it('should retrieve file content as string', async () => {
const s3 = new S3();
const s3 = new FileS3();
const mockContent = 'Hello, World!';
mockS3ClientSend.mockResolvedValue({
Body: {
@@ -173,7 +262,7 @@ describe('S3', () => {
});
it('should throw error when response body is missing', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({
Body: undefined,
});
@@ -186,7 +275,7 @@ describe('S3', () => {
describe('getFileByteArray', () => {
it('should retrieve file content as byte array', async () => {
const s3 = new S3();
const s3 = new FileS3();
const mockBytes = new Uint8Array([1, 2, 3, 4, 5]);
mockS3ClientSend.mockResolvedValue({
Body: {
@@ -204,7 +293,7 @@ describe('S3', () => {
});
it('should throw error when response body is missing', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({
Body: undefined,
});
@@ -217,7 +306,7 @@ describe('S3', () => {
describe('createPreSignedUrl', () => {
it('should create presigned URL for upload with ACL', async () => {
const s3 = new S3();
const s3 = new FileS3();
const result = await s3.createPreSignedUrl('upload-file.txt');
@@ -235,7 +324,7 @@ describe('S3', () => {
describe('createPreSignedUrlForPreview', () => {
it('should create presigned URL for preview with default expiration', async () => {
const s3 = new S3();
const s3 = new FileS3();
const result = await s3.createPreSignedUrlForPreview('preview-file.jpg');
@@ -250,7 +339,7 @@ describe('S3', () => {
});
it('should create presigned URL for preview with custom expiration', async () => {
const s3 = new S3();
const s3 = new FileS3();
await s3.createPreSignedUrlForPreview('preview-file.jpg', 1800);
@@ -262,7 +351,7 @@ describe('S3', () => {
describe('uploadBuffer', () => {
it('should upload buffer with correct parameters', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
const buffer = Buffer.from('test data');
@@ -279,7 +368,7 @@ describe('S3', () => {
});
it('should upload buffer without content type', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
const buffer = Buffer.from('test data');
@@ -297,7 +386,7 @@ describe('S3', () => {
describe('uploadContent', () => {
it('should upload string content with correct parameters', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
const content = 'Hello, World!';
@@ -313,7 +402,7 @@ describe('S3', () => {
});
it('should handle empty content', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
await s3.uploadContent('empty.txt', '');
@@ -329,7 +418,7 @@ describe('S3', () => {
describe('uploadMedia', () => {
it('should upload media with correct content type and cache control for JPEG', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
const buffer = Buffer.from('fake image data');
@@ -347,7 +436,7 @@ describe('S3', () => {
});
it('should upload media with correct content type for PNG', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
const buffer = Buffer.from('fake image data');
@@ -362,7 +451,7 @@ describe('S3', () => {
});
it('should upload media with correct content type for GIF', async () => {
const s3 = new S3();
const s3 = new FileS3();
mockS3ClientSend.mockResolvedValue({});
const buffer = Buffer.from('fake image data');
+31 -9
View File
@@ -31,21 +31,32 @@ export class S3 {
private readonly setAcl: boolean;
constructor() {
if (!fileEnv.S3_ACCESS_KEY_ID || !fileEnv.S3_SECRET_ACCESS_KEY || !fileEnv.S3_BUCKET)
constructor(
accessKeyId: string | undefined,
secretAccessKey: string | undefined,
endpoint: string | undefined,
options?: {
bucket?: string;
forcePathStyle?: boolean;
region?: string;
setAcl?: boolean;
},
) {
if (!accessKeyId || !secretAccessKey || !endpoint)
throw new Error('S3 environment variables are not set completely, please check your env');
if (!options?.bucket) throw new Error('S3 bucket is not set, please check your env');
this.bucket = fileEnv.S3_BUCKET;
this.setAcl = fileEnv.S3_SET_ACL;
this.bucket = options?.bucket;
this.setAcl = options?.setAcl || false;
this.client = new S3Client({
credentials: {
accessKeyId: fileEnv.S3_ACCESS_KEY_ID,
secretAccessKey: fileEnv.S3_SECRET_ACCESS_KEY,
accessKeyId: accessKeyId,
secretAccessKey: secretAccessKey,
},
endpoint: fileEnv.S3_ENDPOINT,
forcePathStyle: fileEnv.S3_ENABLE_PATH_STYLE,
region: fileEnv.S3_REGION || DEFAULT_S3_REGION,
endpoint: endpoint,
forcePathStyle: options?.forcePathStyle,
region: options?.region || DEFAULT_S3_REGION,
// refs: https://github.com/lobehub/lobe-chat/pull/5479
requestChecksumCalculation: 'WHEN_REQUIRED',
responseChecksumValidation: 'WHEN_REQUIRED',
@@ -158,3 +169,14 @@ export class S3 {
await this.client.send(command);
}
}
export class FileS3 extends S3 {
constructor() {
super(fileEnv.S3_ACCESS_KEY_ID, fileEnv.S3_SECRET_ACCESS_KEY, fileEnv.S3_ENDPOINT, {
bucket: fileEnv.S3_BUCKET,
forcePathStyle: fileEnv.S3_ENABLE_PATH_STYLE,
region: fileEnv.S3_REGION,
setAcl: fileEnv.S3_SET_ACL,
});
}
}
+97
View File
@@ -0,0 +1,97 @@
import type {
AddIdentityActionSchema,
ContextMemoryItemSchema,
ExperienceMemoryItemSchema,
PreferenceMemoryItemSchema,
RemoveIdentityActionSchema,
UpdateIdentityActionSchema,
} from '@lobechat/memory-user-memory/schemas';
import { z } from 'zod';
import { lambdaClient } from '@/libs/trpc/client';
import {
AddContextMemoryResult,
AddExperienceMemoryResult,
AddIdentityMemoryResult,
AddPreferenceMemoryResult,
LayersEnum,
RemoveIdentityMemoryResult,
SearchMemoryParams,
SearchMemoryResult,
TypesEnum,
UpdateIdentityMemoryResult,
} from '@/types/userMemory';
class UserMemoryService {
addContextMemory = async (
params: z.infer<typeof ContextMemoryItemSchema>,
): Promise<AddContextMemoryResult> => {
return lambdaClient.userMemories.toolAddContextMemory.mutate(params);
};
addExperienceMemory = async (
params: z.infer<typeof ExperienceMemoryItemSchema>,
): Promise<AddExperienceMemoryResult> => {
return lambdaClient.userMemories.toolAddExperienceMemory.mutate(params);
};
addIdentityMemory = async (
params: z.infer<typeof AddIdentityActionSchema>,
): Promise<AddIdentityMemoryResult> => {
return lambdaClient.userMemories.toolAddIdentityMemory.mutate(params);
};
addPreferenceMemory = async (
params: z.infer<typeof PreferenceMemoryItemSchema>,
): Promise<AddPreferenceMemoryResult> => {
return lambdaClient.userMemories.toolAddPreferenceMemory.mutate(params);
};
removeIdentityMemory = async (
params: z.infer<typeof RemoveIdentityActionSchema>,
): Promise<RemoveIdentityMemoryResult> => {
return lambdaClient.userMemories.toolRemoveIdentityMemory.mutate(params);
};
getMemoryDetail = async (params: { id: string; layer: LayersEnum }) => {
return lambdaClient.userMemories.getMemoryDetail.query(params);
};
retrieveMemory = async (params: SearchMemoryParams): Promise<SearchMemoryResult> => {
return lambdaClient.userMemories.toolSearchMemory.query(params);
};
searchMemory = async (params: SearchMemoryParams): Promise<SearchMemoryResult> => {
return lambdaClient.userMemories.toolSearchMemory.query(params);
};
queryTags = async (params?: { layers?: LayersEnum[]; page?: number; size?: number }) => {
return lambdaClient.userMemories.queryTags.query(params);
};
queryIdentityRoles = async (params?: { page?: number; size?: number }) => {
return lambdaClient.userMemories.queryIdentityRoles.query(params);
};
queryMemories = async (params?: {
categories?: string[];
layer?: LayersEnum;
order?: 'asc' | 'desc';
page?: number;
pageSize?: number;
q?: string;
sort?: 'scoreConfidence' | 'scoreImpact' | 'scorePriority' | 'scoreUrgency';
tags?: string[];
types?: TypesEnum[];
}) => {
return lambdaClient.userMemories.queryMemories.query(params);
};
updateIdentityMemory = async (
params: z.infer<typeof UpdateIdentityActionSchema>,
): Promise<UpdateIdentityMemoryResult> => {
return lambdaClient.userMemories.toolUpdateIdentityMemory.mutate(params);
};
}
export const userMemoryService = new UserMemoryService();
+79
View File
@@ -0,0 +1,79 @@
import { NewUserMemoryIdentity } from '@lobechat/types';
import { lambdaClient } from '@/libs/trpc/client';
class MemoryCRUDService {
// ============ Identity CRUD ============
createIdentity = async (data: NewUserMemoryIdentity) => {
return lambdaClient.userMemory.createIdentity.mutate(data);
};
deleteIdentity = async (id: string) => {
return lambdaClient.userMemory.deleteIdentity.mutate({ id });
};
getIdentities = async () => {
return lambdaClient.userMemory.getIdentities.query();
};
updateIdentity = async (id: string, data: Partial<NewUserMemoryIdentity>) => {
return lambdaClient.userMemory.updateIdentity.mutate({ data, id });
};
// ============ Context CRUD ============
deleteContext = async (id: string) => {
return lambdaClient.userMemory.deleteContext.mutate({ id });
};
getContexts = async () => {
return lambdaClient.userMemory.getContexts.query();
};
updateContext = async (
id: string,
data: { currentStatus?: string; description?: string; title?: string },
) => {
return lambdaClient.userMemory.updateContext.mutate({ data, id });
};
// ============ Experience CRUD ============
deleteExperience = async (id: string) => {
return lambdaClient.userMemory.deleteExperience.mutate({ id });
};
getExperiences = async () => {
return lambdaClient.userMemory.getExperiences.query();
};
updateExperience = async (
id: string,
data: { action?: string; keyLearning?: string; situation?: string },
) => {
return lambdaClient.userMemory.updateExperience.mutate({ data, id });
};
// ============ Preference CRUD ============
deletePreference = async (id: string) => {
return lambdaClient.userMemory.deletePreference.mutate({ id });
};
getPreferences = async () => {
return lambdaClient.userMemory.getPreferences.query();
};
updatePreference = async (
id: string,
data: { conclusionDirectives?: string; suggestions?: string },
) => {
return lambdaClient.userMemory.updatePreference.mutate({ data, id });
};
}
export const memoryCRUDService = new MemoryCRUDService();
// Backward compatibility alias
export const memoryService = memoryCRUDService;
+2
View File
@@ -0,0 +1,2 @@
export * from './selectors';
export * from './store';
+93
View File
@@ -0,0 +1,93 @@
import {
DisplayContextMemory,
DisplayExperienceMemory,
DisplayIdentityMemory,
DisplayPreferenceMemory,
} from '@/database/repositories/userMemory';
import type { RetrieveMemoryParams, RetrieveMemoryResult, TypesEnum } from '@/types/userMemory';
export interface UserMemoryStoreState {
activeParams?: RetrieveMemoryParams;
activeParamsKey?: string;
contexts: DisplayContextMemory[];
contextsHasMore: boolean;
contextsInit: boolean;
contextsPage: number;
contextsQuery?: string;
contextsSearchLoading?: boolean;
contextsSort?: 'scoreImpact' | 'scoreUrgency';
contextsTotal: number;
editingMemoryContent?: string;
editingMemoryId?: string;
editingMemoryLayer?: 'context' | 'experience' | 'identity' | 'preference';
experiences: DisplayExperienceMemory[];
experiencesHasMore: boolean;
experiencesInit: boolean;
experiencesPage: number;
experiencesQuery?: string;
experiencesSearchLoading?: boolean;
experiencesSort?: 'scoreConfidence';
experiencesTotal: number;
identities: DisplayIdentityMemory[];
identitiesHasMore: boolean;
identitiesInit: boolean;
identitiesPage: number;
identitiesQuery?: string;
identitiesSearchLoading?: boolean;
identitiesTotal: number;
identitiesTypes?: TypesEnum[];
memoryFetchedAtMap: Record<string, number>;
memoryMap: Record<string, RetrieveMemoryResult>;
preferences: DisplayPreferenceMemory[];
preferencesHasMore: boolean;
preferencesInit: boolean;
preferencesPage: number;
preferencesQuery?: string;
preferencesSearchLoading?: boolean;
preferencesSort?: 'scorePriority';
preferencesTotal: number;
roles: { count: number; tag: string }[];
tags: { count: number; tag: string }[];
tagsInit: boolean;
}
export const initialState: UserMemoryStoreState = {
activeParams: undefined,
activeParamsKey: undefined,
contexts: [],
contextsHasMore: true,
contextsInit: false,
contextsPage: 1,
contextsQuery: undefined,
contextsSort: undefined,
contextsTotal: 0,
editingMemoryContent: undefined,
editingMemoryId: undefined,
editingMemoryLayer: undefined,
experiences: [],
experiencesHasMore: true,
experiencesInit: false,
experiencesPage: 1,
experiencesQuery: undefined,
experiencesSort: undefined,
experiencesTotal: 0,
identities: [],
identitiesHasMore: true,
identitiesInit: false,
identitiesPage: 1,
identitiesQuery: undefined,
identitiesTotal: 0,
identitiesTypes: undefined,
memoryFetchedAtMap: {},
memoryMap: {},
preferences: [],
preferencesHasMore: true,
preferencesInit: false,
preferencesPage: 1,
preferencesQuery: undefined,
preferencesSort: undefined,
preferencesTotal: 0,
roles: [],
tags: [],
tagsInit: false,
};
+59
View File
@@ -0,0 +1,59 @@
import type { RetrieveMemoryParams, RetrieveMemoryResult } from '@/types/userMemory';
import type { UserMemoryStoreState } from './initialState';
import { userMemoryCacheKey } from './utils/cacheKey';
const EMPTY_RESULT: RetrieveMemoryResult = {
contexts: [],
experiences: [],
preferences: [],
};
type ActiveUserMemoriesResult = {
fetchedAt: number;
memories: RetrieveMemoryResult;
};
export const userMemorySelectors = {
activeMemories: (state: UserMemoryStoreState): RetrieveMemoryResult | undefined => {
if (!state.activeParamsKey) return undefined;
return state.memoryMap[state.activeParamsKey];
},
activeMemoryFetchedAt: (state: UserMemoryStoreState): number | undefined => {
if (!state.activeParamsKey) return undefined;
return state.memoryFetchedAtMap[state.activeParamsKey];
},
activeParams: (state: UserMemoryStoreState): RetrieveMemoryParams | undefined =>
state.activeParams,
activeUserMemories:
(enabled: boolean) =>
(state: UserMemoryStoreState): ActiveUserMemoriesResult | undefined => {
if (!enabled || !state.activeParamsKey) return undefined;
const fetchedAt = state.memoryFetchedAtMap[state.activeParamsKey];
const memories = state.memoryMap[state.activeParamsKey];
if (fetchedAt === undefined || !memories) return undefined;
return {
fetchedAt,
memories,
};
},
memoriesByParams: (params?: RetrieveMemoryParams) => (state: UserMemoryStoreState) => {
if (!params) return EMPTY_RESULT;
const key = userMemoryCacheKey(params);
return state.memoryMap[key] ?? EMPTY_RESULT;
},
memoryFetchedAtByParams: (params?: RetrieveMemoryParams) => (state: UserMemoryStoreState) => {
if (!params) return undefined;
const key = userMemoryCacheKey(params);
return state.memoryFetchedAtMap[key];
},
};
+252
View File
@@ -0,0 +1,252 @@
import isEqual from 'fast-deep-equal';
import { type SWRResponse, mutate } from 'swr';
import useSWR from 'swr';
import { StateCreator } from 'zustand/vanilla';
import { useClientDataSWR } from '@/libs/swr';
import { userMemoryService } from '@/services/userMemory';
import { LayersEnum } from '@/types/userMemory';
import type { RetrieveMemoryParams, RetrieveMemoryResult } from '@/types/userMemory';
import { setNamespace } from '@/utils/storeDebug';
import { UserMemoryStore } from '../../store';
import { userMemoryCacheKey } from '../../utils/cacheKey';
import { createMemorySearchParams } from '../../utils/searchParams';
const SWR_FETCH_USER_MEMORY = 'SWR_FETCH_USER_MEMORY';
const n = setNamespace('userMemory');
type MemoryContext = Parameters<typeof createMemorySearchParams>[0];
export interface BaseAction {
clearEditingMemory: () => void;
refreshUserMemory: (params: RetrieveMemoryParams) => Promise<void>;
setActiveMemoryContext: (context?: MemoryContext) => void;
setEditingMemory: (
id: string,
content: string,
layer: 'context' | 'experience' | 'identity' | 'preference',
) => void;
updateMemory: (id: string, content: string, layer: LayersEnum) => Promise<void>;
useFetchMemoryDetail: (id: string | null, layer: LayersEnum) => SWRResponse<any>;
useFetchUserMemory: (
enable: boolean,
params?: RetrieveMemoryParams,
) => SWRResponse<RetrieveMemoryResult>;
}
export const createBaseSlice: StateCreator<
UserMemoryStore,
[['zustand/devtools', never]],
[],
BaseAction
> = (set, get) => ({
clearEditingMemory: () => {
set(
{
editingMemoryContent: undefined,
editingMemoryId: undefined,
editingMemoryLayer: undefined,
},
false,
n('clearEditingMemory'),
);
},
refreshUserMemory: async (params) => {
const key = userMemoryCacheKey(params);
await mutate([SWR_FETCH_USER_MEMORY, key]);
},
setActiveMemoryContext: (context) => {
const params = context ? createMemorySearchParams(context) : undefined;
const key = params ? userMemoryCacheKey(params) : undefined;
set(
{
activeParams: params,
activeParamsKey: key,
},
false,
n('setActiveMemoryContext', { key }),
);
},
setEditingMemory: (id, content, layer) => {
set(
{
editingMemoryContent: content,
editingMemoryId: id,
editingMemoryLayer: layer,
},
false,
n('setEditingMemory', { id, layer }),
);
},
updateMemory: async (id, content, layer) => {
const { memoryCRUDService } = await import('@/services/userMemory/index');
const { resetContextsList, resetExperiencesList, resetIdentitiesList, resetPreferencesList } =
get();
// Update the memory content based on layer
switch (layer) {
case LayersEnum.Context: {
await memoryCRUDService.updateContext(id, { description: content });
resetContextsList({ q: get().contextsQuery, sort: get().contextsSort });
break;
}
case LayersEnum.Experience: {
await memoryCRUDService.updateExperience(id, { keyLearning: content });
resetExperiencesList({ q: get().experiencesQuery, sort: get().experiencesSort });
break;
}
case LayersEnum.Identity: {
await memoryCRUDService.updateIdentity(id, { description: content });
resetIdentitiesList({ q: get().identitiesQuery, types: get().identitiesTypes });
break;
}
case LayersEnum.Preference: {
await memoryCRUDService.updatePreference(id, { conclusionDirectives: content });
resetPreferencesList({ q: get().preferencesQuery, sort: get().preferencesSort });
break;
}
}
// Clear editing state
get().clearEditingMemory();
},
useFetchMemoryDetail: (id, layer) => {
const swrKey = id ? `memoryDetail-${layer}-${id}` : null;
return useSWR(
swrKey,
async () => {
if (!id) return null;
const detail = await userMemoryService.getMemoryDetail({ id, layer });
if (!detail) return null;
// Transform nested structure to flat structure
switch (layer) {
case LayersEnum.Context: {
if (detail.layer === LayersEnum.Context) {
return {
...detail.memory,
...detail.context,
source: detail.source,
sourceType: detail.sourceType,
};
}
break;
}
case LayersEnum.Experience: {
if (detail.layer === LayersEnum.Experience) {
return {
...detail.memory,
...detail.experience,
source: detail.source,
sourceType: detail.sourceType,
};
}
break;
}
case LayersEnum.Identity: {
if (detail.layer === LayersEnum.Identity) {
return {
...detail.memory,
...detail.identity,
source: detail.source,
sourceType: detail.sourceType,
};
}
break;
}
case LayersEnum.Preference: {
if (detail.layer === LayersEnum.Preference) {
return {
...detail.memory,
...detail.preference,
source: detail.source,
sourceType: detail.sourceType,
};
}
break;
}
}
return null;
},
{
revalidateOnFocus: false,
},
);
},
useFetchUserMemory: (enable, params) => {
const resolvedParams = params ?? get().activeParams;
const key = resolvedParams ? userMemoryCacheKey(resolvedParams) : undefined;
return useClientDataSWR<RetrieveMemoryResult>(
enable && resolvedParams ? [SWR_FETCH_USER_MEMORY, key] : null,
() => userMemoryService.retrieveMemory(resolvedParams!),
{
onSuccess: (result) => {
if (!resolvedParams || !key) return;
const state = get();
const previous = state.memoryMap[key];
const next = result ?? { contexts: [], experiences: [], preferences: [] };
const fetchedAt = Date.now();
if (previous && isEqual(previous, next)) {
set(
{
memoryFetchedAtMap: {
...state.memoryFetchedAtMap,
[key]: fetchedAt,
},
},
false,
n('useFetchUserMemory/refresh', {
key,
totals: {
contexts: next.contexts.length,
experiences: next.experiences.length,
preferences: next.preferences.length,
},
}),
);
return;
}
set(
{
memoryFetchedAtMap: {
...state.memoryFetchedAtMap,
[key]: fetchedAt,
},
memoryMap: {
...state.memoryMap,
[key]: next,
},
},
false,
n('useFetchUserMemory/success', {
key,
totals: {
contexts: next.contexts.length,
experiences: next.experiences.length,
preferences: next.preferences.length,
},
}),
);
},
},
);
},
});
@@ -0,0 +1 @@
export { type BaseAction,createBaseSlice } from './action';
@@ -0,0 +1,127 @@
import { uniqBy } from 'es-toolkit/compat';
import { produce } from 'immer';
import useSWR, { SWRResponse } from 'swr';
import { StateCreator } from 'zustand/vanilla';
import { userMemoryService } from '@/services/userMemory';
import { memoryCRUDService } from '@/services/userMemory/index';
import { LayersEnum } from '@/types/userMemory';
import { setNamespace } from '@/utils/storeDebug';
import { UserMemoryStore } from '../../store';
const n = setNamespace('userMemory/context');
export interface ContextQueryParams {
page?: number;
pageSize?: number;
q?: string;
sort?: 'scoreImpact' | 'scoreUrgency';
}
export interface ContextAction {
deleteContext: (id: string) => Promise<void>;
loadMoreContexts: () => void;
resetContextsList: (params?: Omit<ContextQueryParams, 'page' | 'pageSize'>) => void;
useFetchContexts: (params: ContextQueryParams) => SWRResponse<any>;
}
export const createContextSlice: StateCreator<
UserMemoryStore,
[['zustand/devtools', never]],
[],
ContextAction
> = (set, get) => ({
deleteContext: async (id) => {
await memoryCRUDService.deleteContext(id);
// Reset list to refresh
get().resetContextsList({ q: get().contextsQuery, sort: get().contextsSort });
},
loadMoreContexts: () => {
const { contextsPage, contextsTotal, contexts } = get();
if (contexts.length < (contextsTotal || 0)) {
set(
produce((draft) => {
draft.contextsPage = contextsPage + 1;
}),
false,
n('loadMoreContexts'),
);
}
},
resetContextsList: (params) => {
set(
produce((draft) => {
draft.contexts = [];
draft.contextsPage = 1;
draft.contextsQuery = params?.q;
draft.contextsSearchLoading = true;
draft.contextsSort = params?.sort;
}),
false,
n('resetContextsList'),
);
},
useFetchContexts: (params) => {
const swrKeyParts = ['useFetchContexts', params.page, params.pageSize, params.q, params.sort];
const swrKey = swrKeyParts
.filter((part) => part !== undefined && part !== null && part !== '')
.join('-');
const page = params.page ?? 1;
return useSWR(
swrKey,
async () => {
const result = await userMemoryService.queryMemories({
layer: LayersEnum.Context,
page: params.page,
pageSize: params.pageSize,
q: params.q,
sort: params.sort,
});
return result;
},
{
onSuccess(data: any) {
set(
produce((draft) => {
draft.contextsSearchLoading = false;
// 设置基础信息
if (!draft.contextsInit) {
draft.contextsInit = true;
draft.contextsTotal = data.total;
}
// 转换数据结构
const transformedItems = data.items.map((item: any) => ({
...item.memory,
...item.context,
source: null,
}));
// 累积数据逻辑
if (page === 1) {
// 第一页,直接设置
draft.contexts = uniqBy(transformedItems, 'id');
} else {
// 后续页面,累积数据
draft.contexts = uniqBy([...draft.contexts, ...transformedItems], 'id');
}
// 更新 hasMore
draft.contextsHasMore = data.items.length >= (params.pageSize || 20);
}),
false,
n('useFetchContexts/onSuccess'),
);
},
revalidateOnFocus: false,
},
);
},
});
@@ -0,0 +1 @@
export { type ContextAction,createContextSlice } from './action';
@@ -0,0 +1,132 @@
import { uniqBy } from 'es-toolkit/compat';
import { produce } from 'immer';
import useSWR, { SWRResponse } from 'swr';
import { StateCreator } from 'zustand/vanilla';
import { userMemoryService } from '@/services/userMemory';
import { memoryCRUDService } from '@/services/userMemory/index';
import { LayersEnum } from '@/types/userMemory';
import { setNamespace } from '@/utils/storeDebug';
import { UserMemoryStore } from '../../store';
const n = setNamespace('userMemory/experience');
export interface ExperienceQueryParams {
page?: number;
pageSize?: number;
q?: string;
sort?: 'scoreConfidence';
}
export interface ExperienceAction {
deleteExperience: (id: string) => Promise<void>;
loadMoreExperiences: () => void;
resetExperiencesList: (params?: Omit<ExperienceQueryParams, 'page' | 'pageSize'>) => void;
useFetchExperiences: (params: ExperienceQueryParams) => SWRResponse<any>;
}
export const createExperienceSlice: StateCreator<
UserMemoryStore,
[['zustand/devtools', never]],
[],
ExperienceAction
> = (set, get) => ({
deleteExperience: async (id) => {
await memoryCRUDService.deleteExperience(id);
// Reset list to refresh
get().resetExperiencesList({ q: get().experiencesQuery, sort: get().experiencesSort });
},
loadMoreExperiences: () => {
const { experiencesPage, experiencesTotal, experiences } = get();
if (experiences.length < (experiencesTotal || 0)) {
set(
produce((draft) => {
draft.experiencesPage = experiencesPage + 1;
}),
false,
n('loadMoreExperiences'),
);
}
},
resetExperiencesList: (params) => {
set(
produce((draft) => {
draft.experiences = [];
draft.experiencesPage = 1;
draft.experiencesQuery = params?.q;
draft.experiencesSearchLoading = true;
draft.experiencesSort = params?.sort;
}),
false,
n('resetExperiencesList'),
);
},
useFetchExperiences: (params) => {
const swrKeyParts = [
'useFetchExperiences',
params.page,
params.pageSize,
params.q,
params.sort,
];
const swrKey = swrKeyParts
.filter((part) => part !== undefined && part !== null && part !== '')
.join('-');
const page = params.page ?? 1;
return useSWR(
swrKey,
async () => {
const result = await userMemoryService.queryMemories({
layer: LayersEnum.Experience,
page: params.page,
pageSize: params.pageSize,
q: params.q,
sort: params.sort,
});
return result;
},
{
onSuccess(data: any) {
set(
produce((draft) => {
draft.experiencesSearchLoading = false;
// 设置基础信息
if (!draft.experiencesInit) {
draft.experiencesInit = true;
draft.experiencesTotal = data.total;
}
// 转换数据结构
const transformedItems = data.items.map((item: any) => ({
...item.memory,
...item.experience,
}));
// 累积数据逻辑
if (page === 1) {
// 第一页,直接设置
draft.experiences = uniqBy(transformedItems, 'id');
} else {
// 后续页面,累积数据
draft.experiences = uniqBy([...draft.experiences, ...transformedItems], 'id');
}
// 更新 hasMore
draft.experiencesHasMore = data.items.length >= (params.pageSize || 20);
}),
false,
n('useFetchExperiences/onSuccess'),
);
},
revalidateOnFocus: false,
},
);
},
});
@@ -0,0 +1 @@
export { createExperienceSlice, type ExperienceAction } from './action';
@@ -0,0 +1,45 @@
import { SWRResponse } from 'swr';
import { StateCreator } from 'zustand/vanilla';
import { QueryIdentityRolesResult } from '@/database/models/userMemory';
import { useClientDataSWR } from '@/libs/swr';
import { userMemoryService } from '@/services/userMemory';
import { UserMemoryStore } from '../../store';
const FETCH_TAGS_KEY = 'useFetchTags';
const n = (namespace: string) => namespace;
export interface HomeAction {
useFetchTags: () => SWRResponse<QueryIdentityRolesResult>;
}
export const createHomeSlice: StateCreator<
UserMemoryStore,
[['zustand/devtools', never]],
[],
HomeAction
> = (set) => ({
useFetchTags: () =>
useClientDataSWR(
FETCH_TAGS_KEY,
() =>
userMemoryService.queryIdentityRoles({
page: 1,
size: 64,
}),
{
onSuccess: (data: QueryIdentityRolesResult | undefined) => {
set(
{
roles: data?.roles.map((item) => ({ count: item.count, tag: item.role })) || [],
tags: data?.tags || [],
tagsInit: true,
},
false,
n('useFetchTags/onSuccess'),
);
},
},
),
});
@@ -0,0 +1 @@
export * from './action';
@@ -0,0 +1,150 @@
import { NewUserMemoryIdentity, UpdateUserMemoryIdentity } from '@lobechat/types';
import { uniqBy } from 'es-toolkit/compat';
import { produce } from 'immer';
import useSWR, { SWRResponse } from 'swr';
import { StateCreator } from 'zustand/vanilla';
import { AddIdentityEntryResult } from '@/database/models/userMemory';
import { userMemoryService } from '@/services/userMemory';
import { memoryCRUDService } from '@/services/userMemory/index';
import { LayersEnum, TypesEnum } from '@/types/userMemory';
import { setNamespace } from '@/utils/storeDebug';
import { UserMemoryStore } from '../../store';
const n = setNamespace('userMemory/identity');
export interface IdentityQueryParams {
page?: number;
pageSize?: number;
q?: string;
types?: TypesEnum[];
}
export interface IdentityAction {
createIdentity: (data: NewUserMemoryIdentity) => Promise<AddIdentityEntryResult>;
deleteIdentity: (id: string) => Promise<void>;
loadMoreIdentities: () => void;
resetIdentitiesList: (params?: Omit<IdentityQueryParams, 'page' | 'pageSize'>) => void;
updateIdentity: (id: string, data: UpdateUserMemoryIdentity) => Promise<boolean>;
useFetchIdentities: (params: IdentityQueryParams) => SWRResponse<any>;
}
export const createIdentitySlice: StateCreator<
UserMemoryStore,
[['zustand/devtools', never]],
[],
IdentityAction
> = (set, get) => ({
createIdentity: async (data) => {
const result = await memoryCRUDService.createIdentity(data);
// Reset list to refresh
get().resetIdentitiesList({ q: get().identitiesQuery, types: get().identitiesTypes });
return result;
},
deleteIdentity: async (id) => {
await memoryCRUDService.deleteIdentity(id);
// Reset list to refresh
get().resetIdentitiesList({ q: get().identitiesQuery, types: get().identitiesTypes });
},
loadMoreIdentities: () => {
const { identitiesPage, identitiesTotal, identities } = get();
if (identities.length < (identitiesTotal || 0)) {
set(
produce((draft) => {
draft.identitiesPage = identitiesPage + 1;
}),
false,
n('loadMoreIdentities'),
);
}
},
resetIdentitiesList: (params) => {
set(
produce((draft) => {
draft.identities = [];
draft.identitiesPage = 1;
draft.identitiesQuery = params?.q;
draft.identitiesSearchLoading = true;
draft.identitiesTypes = params?.types;
}),
false,
n('resetIdentitiesList'),
);
},
updateIdentity: async (id, data) => {
const result = await memoryCRUDService.updateIdentity(id, data);
// Reset list to refresh
get().resetIdentitiesList({ q: get().identitiesQuery, types: get().identitiesTypes });
return result;
},
useFetchIdentities: (params) => {
const swrKeyParts = [
'useFetchIdentities',
params.page,
params.pageSize,
params.q,
params.types?.join(','),
];
const swrKey = swrKeyParts
.filter((part) => part !== undefined && part !== null && part !== '')
.join('-');
const page = params.page ?? 1;
return useSWR(
swrKey,
async () => {
const result = await userMemoryService.queryMemories({
layer: LayersEnum.Identity,
page: params.page,
pageSize: params.pageSize,
q: params.q,
types: params.types,
});
return result;
},
{
onSuccess(data: any) {
set(
produce((draft) => {
draft.identitiesSearchLoading = false;
// 设置基础信息
if (!draft.identitiesInit) {
draft.identitiesInit = true;
draft.identitiesTotal = data.total;
}
// 转换数据结构
const transformedItems = data.items.map((item: any) => ({
...item.memory,
...item.identity,
}));
// 累积数据逻辑
if (page === 1) {
// 第一页,直接设置
draft.identities = uniqBy(transformedItems, 'id');
} else {
// 后续页面,累积数据
draft.identities = uniqBy([...draft.identities, ...transformedItems], 'id');
}
// 更新 hasMore
draft.identitiesHasMore = data.items.length >= (params.pageSize || 20);
}),
false,
n('useFetchIdentities/onSuccess'),
);
},
revalidateOnFocus: false,
},
);
},
});
@@ -0,0 +1 @@
export { createIdentitySlice, type IdentityAction } from './action';
@@ -0,0 +1,132 @@
import { uniqBy } from 'es-toolkit/compat';
import { produce } from 'immer';
import useSWR, { SWRResponse } from 'swr';
import { StateCreator } from 'zustand/vanilla';
import { userMemoryService } from '@/services/userMemory';
import { memoryCRUDService } from '@/services/userMemory/index';
import { LayersEnum } from '@/types/userMemory';
import { setNamespace } from '@/utils/storeDebug';
import { UserMemoryStore } from '../../store';
const n = setNamespace('userMemory/preference');
export interface PreferenceQueryParams {
page?: number;
pageSize?: number;
q?: string;
sort?: 'scorePriority';
}
export interface PreferenceAction {
deletePreference: (id: string) => Promise<void>;
loadMorePreferences: () => void;
resetPreferencesList: (params?: Omit<PreferenceQueryParams, 'page' | 'pageSize'>) => void;
useFetchPreferences: (params: PreferenceQueryParams) => SWRResponse<any>;
}
export const createPreferenceSlice: StateCreator<
UserMemoryStore,
[['zustand/devtools', never]],
[],
PreferenceAction
> = (set, get) => ({
deletePreference: async (id) => {
await memoryCRUDService.deletePreference(id);
// Reset list to refresh
get().resetPreferencesList({ q: get().preferencesQuery, sort: get().preferencesSort });
},
loadMorePreferences: () => {
const { preferencesPage, preferencesTotal, preferences } = get();
if (preferences.length < (preferencesTotal || 0)) {
set(
produce((draft) => {
draft.preferencesPage = preferencesPage + 1;
}),
false,
n('loadMorePreferences'),
);
}
},
resetPreferencesList: (params) => {
set(
produce((draft) => {
draft.preferences = [];
draft.preferencesPage = 1;
draft.preferencesQuery = params?.q;
draft.preferencesSearchLoading = true;
draft.preferencesSort = params?.sort;
}),
false,
n('resetPreferencesList'),
);
},
useFetchPreferences: (params) => {
const swrKeyParts = [
'useFetchPreferences',
params.page,
params.pageSize,
params.q,
params.sort,
];
const swrKey = swrKeyParts
.filter((part) => part !== undefined && part !== null && part !== '')
.join('-');
const page = params.page ?? 1;
return useSWR(
swrKey,
async () => {
const result = await userMemoryService.queryMemories({
layer: LayersEnum.Preference,
page: params.page,
pageSize: params.pageSize,
q: params.q,
sort: params.sort,
});
return result;
},
{
onSuccess(data: any) {
set(
produce((draft) => {
draft.preferencesSearchLoading = false;
// 设置基础信息
if (!draft.preferencesInit) {
draft.preferencesInit = true;
draft.preferencesTotal = data.total;
}
// 转换数据结构
const transformedItems = data.items.map((item: any) => ({
...item.memory,
...item.preference,
}));
// 累积数据逻辑
if (page === 1) {
// 第一页,直接设置
draft.preferences = uniqBy(transformedItems, 'id');
} else {
// 后续页面,累积数据
draft.preferences = uniqBy([...draft.preferences, ...transformedItems], 'id');
}
// 更新 hasMore
draft.preferencesHasMore = data.items.length >= (params.pageSize || 20);
}),
false,
n('useFetchPreferences/onSuccess'),
);
},
revalidateOnFocus: false,
},
);
},
});
@@ -0,0 +1 @@
export { createPreferenceSlice, type PreferenceAction } from './action';
+43
View File
@@ -0,0 +1,43 @@
import { shallow } from 'zustand/shallow';
import { createWithEqualityFn } from 'zustand/traditional';
import { type StateCreator } from 'zustand/vanilla';
import { createDevtools } from '../middleware/createDevtools';
import { type UserMemoryStoreState, initialState } from './initialState';
import { type BaseAction, createBaseSlice } from './slices/base';
import { type ContextAction, createContextSlice } from './slices/context';
import { type ExperienceAction, createExperienceSlice } from './slices/experience';
import { type HomeAction, createHomeSlice } from './slices/home';
import { type IdentityAction, createIdentitySlice } from './slices/identity';
import { type PreferenceAction, createPreferenceSlice } from './slices/preference';
export type UserMemoryStore = UserMemoryStoreState &
BaseAction &
ContextAction &
ExperienceAction &
HomeAction &
IdentityAction &
PreferenceAction;
const createStore: StateCreator<UserMemoryStore, [['zustand/devtools', never]]> = (
set,
get,
store,
) => ({
...initialState,
...createBaseSlice(set, get, store),
...createContextSlice(set, get, store),
...createExperienceSlice(set, get, store),
...createHomeSlice(set, get, store),
...createIdentitySlice(set, get, store),
...createPreferenceSlice(set, get, store),
});
const devtools = createDevtools('userMemory');
export const useUserMemoryStore = createWithEqualityFn<UserMemoryStore>()(
devtools(createStore),
shallow,
);
export const getUserMemoryStoreState = () => useUserMemoryStore.getState();
+10
View File
@@ -0,0 +1,10 @@
import type { RetrieveMemoryParams } from '@/types/userMemory';
export const userMemoryCacheKey = (params: RetrieveMemoryParams): string => {
const { query, topK } = params;
return JSON.stringify({
query,
topK: topK ?? null,
});
};
@@ -0,0 +1,47 @@
import { find, isString, trim } from 'es-toolkit/compat';
import { DEFAULT_SEARCH_USER_MEMORY_TOP_K } from '@/const/userMemory';
import type { RetrieveMemoryParams } from '@/types/userMemory';
interface MemorySearchSource {
latestUserMessage?: string | null;
sendingMessage?: string | null;
session?: {
meta?: {
description?: string | null;
title?: string | null;
} | null;
} | null;
topic?: {
historySummary?: string | null;
title?: string | null;
} | null;
}
const pickFirstNonEmpty = (values: Array<string | null | undefined>) => {
const matched = find(values, (value) => isString(value) && trim(value).length > 0);
if (!isString(matched)) return undefined;
return trim(matched);
};
export const createMemorySearchParams = (
source: MemorySearchSource,
): RetrieveMemoryParams | undefined => {
const query = pickFirstNonEmpty([
source.topic?.historySummary,
source.session?.meta?.description,
source.latestUserMessage,
source.sendingMessage,
]);
if (!query) return undefined;
return {
query,
topK: {
...DEFAULT_SEARCH_USER_MEMORY_TOP_K,
},
} as RetrieveMemoryParams;
};