From f3d5d03cf5c36a8ebe2d05ea21935791ae55c58b Mon Sep 17 00:00:00 2001 From: Neko Date: Mon, 25 May 2026 04:06:50 +0800 Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20refactor(userMemories):=20?= =?UTF-8?q?support=20resolving=20agent=20config=20from=20ServiceModel=20(#?= =?UTF-8?q?15138)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ♻️ refactor(userMemories): support resolving agent config from ServiceModel * ♻️ refactor(userMemories): share memory analysis service model --- locales/en-US/setting.json | 11 + locales/zh-CN/setting.json | 11 + packages/const/src/settings/systemAgent.ts | 20 +- packages/types/src/serverConfig.ts | 4 +- packages/types/src/user/settings/index.ts | 4 +- .../types/src/user/settings/systemAgent.ts | 12 + .../ServiceModel/ModelAssignmentsForm.tsx | 63 +++- src/locales/default/setting.ts | 14 + .../globalConfig/parseSystemAgent.test.ts | 19 ++ src/server/globalConfig/parseSystemAgent.ts | 20 +- .../__tests__/extract.runtime.test.ts | 205 ++++++++++++- .../services/memory/userMemory/extract.ts | 274 ++++++++++++++---- .../persona/__tests__/service.test.ts | 45 ++- .../memory/userMemory/persona/service.ts | 56 +++- src/store/user/slices/settings/action.ts | 3 +- .../__snapshots__/settings.test.ts.snap | 12 + vitest.config.mts | 8 + 17 files changed, 688 insertions(+), 93 deletions(-) diff --git a/locales/en-US/setting.json b/locales/en-US/setting.json index b61a6aad82..b3d663764a 100644 --- a/locales/en-US/setting.json +++ b/locales/en-US/setting.json @@ -503,6 +503,8 @@ "plugin.settings.tooltip": "Skill Configuration", "plugin.store": "Skill Store", "publishToCommunity": "Publish to Community", + "serviceModel.contextLimit.placeholder": "Context limit", + "serviceModel.memoryModels.title": "Memory Models", "serviceModel.modelAssignments.title": "Model Assignments", "serviceModel.optionalFeatures.title": "Optional Features", "settingAgent.avatar.sizeExceeded": "Image size exceeds 1MB limit, please choose a smaller image", @@ -850,6 +852,9 @@ "systemAgent.inputCompletion.label": "Model", "systemAgent.inputCompletion.modelDesc": "Suggests text while you type. When enabled, this model generates the suggestions.", "systemAgent.inputCompletion.title": "Input Suggestions", + "systemAgent.memoryAnalysisAgentConfig.label": "Model", + "systemAgent.memoryAnalysisAgentConfig.modelDesc": "Model used to decide whether conversations contain memory and extract identities, preferences, contexts, activities, and experiences.", + "systemAgent.memoryAnalysisAgentConfig.title": "Memory Analysis", "systemAgent.promptRewrite.label": "Model", "systemAgent.promptRewrite.modelDesc": "Improves prompts before generation. When enabled, this model rewrites the prompt.", "systemAgent.promptRewrite.title": "Prompt Rewriting", @@ -863,6 +868,12 @@ "systemAgent.translation.label": "Model", "systemAgent.translation.modelDesc": "Model used to translate messages", "systemAgent.translation.title": "Message Translation", + "systemAgent.userMemoryEmbedding.label": "Model", + "systemAgent.userMemoryEmbedding.modelDesc": "Model used to embed memory content for retrieval. The context limit caps each embedding input.", + "systemAgent.userMemoryEmbedding.title": "Memory Embedding", + "systemAgent.userMemoryPersonaWriter.label": "Model", + "systemAgent.userMemoryPersonaWriter.modelDesc": "Model used to write persona-oriented memory summaries.", + "systemAgent.userMemoryPersonaWriter.title": "Memory Persona Writer", "tab.about": "About", "tab.addAgentSkill": "Add Agent Skill", "tab.addCustomMcp": "Add Custom MCP Skill", diff --git a/locales/zh-CN/setting.json b/locales/zh-CN/setting.json index 1f2093cb66..1a8c8469de 100644 --- a/locales/zh-CN/setting.json +++ b/locales/zh-CN/setting.json @@ -503,6 +503,8 @@ "plugin.settings.tooltip": "技能配置", "plugin.store": "技能商店", "publishToCommunity": "发布到社区", + "serviceModel.contextLimit.placeholder": "上下文限制", + "serviceModel.memoryModels.title": "记忆模型", "serviceModel.modelAssignments.title": "模型分配", "serviceModel.optionalFeatures.title": "可选功能", "settingAgent.avatar.sizeExceeded": "图片大小超过 1MB 限制,请选择更小的图片", @@ -850,6 +852,9 @@ "systemAgent.inputCompletion.label": "模型", "systemAgent.inputCompletion.modelDesc": "输入时生成文本建议。开启后,由该模型生成建议。", "systemAgent.inputCompletion.title": "输入建议", + "systemAgent.memoryAnalysisAgentConfig.label": "模型", + "systemAgent.memoryAnalysisAgentConfig.modelDesc": "用于判断对话是否包含记忆,并提取身份、偏好、上下文、活动和经历。", + "systemAgent.memoryAnalysisAgentConfig.title": "记忆分析", "systemAgent.promptRewrite.label": "模型", "systemAgent.promptRewrite.modelDesc": "生成前优化提示词。开启后,由该模型改写提示词。", "systemAgent.promptRewrite.title": "提示词改写", @@ -863,6 +868,12 @@ "systemAgent.translation.label": "模型", "systemAgent.translation.modelDesc": "用于翻译消息内容的模型", "systemAgent.translation.title": "消息内容翻译", + "systemAgent.userMemoryEmbedding.label": "模型", + "systemAgent.userMemoryEmbedding.modelDesc": "用于为记忆内容生成向量以支持检索。上下文限制会约束每次向量化输入。", + "systemAgent.userMemoryEmbedding.title": "记忆向量化", + "systemAgent.userMemoryPersonaWriter.label": "模型", + "systemAgent.userMemoryPersonaWriter.modelDesc": "用于生成面向画像的记忆摘要。", + "systemAgent.userMemoryPersonaWriter.title": "记忆画像写入", "tab.about": "关于", "tab.addAgentSkill": "添加 Agent 技能", "tab.addCustomMcp": "添加自定义 MCP 技能", diff --git a/packages/const/src/settings/systemAgent.ts b/packages/const/src/settings/systemAgent.ts index 07c6335444..4b895a2b12 100644 --- a/packages/const/src/settings/systemAgent.ts +++ b/packages/const/src/settings/systemAgent.ts @@ -1,11 +1,15 @@ -import { DEFAULT_MINI_PROVIDER, DEFAULT_PROVIDER } from '@lobechat/business-const'; +import { + DEFAULT_EMBEDDING_PROVIDER, + DEFAULT_MINI_PROVIDER, + DEFAULT_PROVIDER, +} from '@lobechat/business-const'; import type { PromptRewriteSystemAgent, SystemAgentItem, - UserSystemAgentConfig, + UserServiceModelConfig, } from '@lobechat/types'; -import { DEFAULT_MINI_MODEL, DEFAULT_MODEL } from './llm'; +import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MINI_MODEL, DEFAULT_MODEL } from './llm'; export const DEFAULT_SYSTEM_AGENT_ITEM: SystemAgentItem = { model: DEFAULT_MODEL, @@ -35,12 +39,20 @@ export const DEFAULT_FOLLOW_UP_ACTION_SYSTEM_AGENT_ITEM: SystemAgentItem = { provider: DEFAULT_MINI_SYSTEM_AGENT_ITEM.provider, }; -export const DEFAULT_SYSTEM_AGENT_CONFIG: UserSystemAgentConfig = { +export const DEFAULT_USER_MEMORY_EMBEDDING_SYSTEM_AGENT_ITEM: SystemAgentItem = { + model: DEFAULT_EMBEDDING_MODEL, + provider: DEFAULT_EMBEDDING_PROVIDER, +}; + +export const DEFAULT_SYSTEM_AGENT_CONFIG: UserServiceModelConfig = { agentMeta: DEFAULT_SYSTEM_AGENT_ITEM, followUpAction: DEFAULT_FOLLOW_UP_ACTION_SYSTEM_AGENT_ITEM, generationTopic: DEFAULT_MINI_SYSTEM_AGENT_ITEM, historyCompress: DEFAULT_SYSTEM_AGENT_ITEM, inputCompletion: DEFAULT_INPUT_COMPLETION_SYSTEM_AGENT_ITEM, + memoryAnalysisAgentConfig: DEFAULT_MINI_SYSTEM_AGENT_ITEM, + userMemoryEmbedding: DEFAULT_USER_MEMORY_EMBEDDING_SYSTEM_AGENT_ITEM, + userMemoryPersonaWriter: DEFAULT_MINI_SYSTEM_AGENT_ITEM, promptRewrite: DEFAULT_PROMPT_REWRITE_SYSTEM_AGENT_ITEM, thread: DEFAULT_SYSTEM_AGENT_ITEM, topic: DEFAULT_MINI_SYSTEM_AGENT_ITEM, diff --git a/packages/types/src/serverConfig.ts b/packages/types/src/serverConfig.ts index 403a66760a..5871401748 100644 --- a/packages/types/src/serverConfig.ts +++ b/packages/types/src/serverConfig.ts @@ -7,7 +7,7 @@ import type { GlobalLLMProviderKey, UserDefaultAgent, UserImageConfig, - UserSystemAgentConfig, + UserServiceModelConfig, } from './user/settings'; export type GlobalMemoryLayer = 'activity' | 'context' | 'experience' | 'identity' | 'preference'; @@ -76,7 +76,7 @@ export interface GlobalServerConfig { image?: PartialDeep; memory?: GlobalMemoryConfig; oAuthSSOProviders?: string[]; - systemAgent?: PartialDeep; + systemAgent?: PartialDeep; telemetry: { langfuse?: boolean; }; diff --git a/packages/types/src/user/settings/index.ts b/packages/types/src/user/settings/index.ts index effb416126..a2a5f4070a 100644 --- a/packages/types/src/user/settings/index.ts +++ b/packages/types/src/user/settings/index.ts @@ -9,7 +9,7 @@ import type { MarketAuthTokens } from './market'; import type { UserMemorySettings } from './memory'; import type { UserModelProviderConfig } from './modelProvider'; import type { NotificationSettings } from './notification'; -import type { UserSystemAgentConfig } from './systemAgent'; +import type { UserServiceModelConfig } from './systemAgent'; import type { UserToolConfig } from './tool'; import type { UserTTSConfig } from './tts'; @@ -42,7 +42,7 @@ export interface UserSettings { market?: MarketAuthTokens; memory?: UserMemorySettings; notification?: NotificationSettings; - systemAgent: UserSystemAgentConfig; + systemAgent: UserServiceModelConfig; tool: UserToolConfig; tts: UserTTSConfig; } diff --git a/packages/types/src/user/settings/systemAgent.ts b/packages/types/src/user/settings/systemAgent.ts index e652d04538..3e1f796f5a 100644 --- a/packages/types/src/user/settings/systemAgent.ts +++ b/packages/types/src/user/settings/systemAgent.ts @@ -1,4 +1,5 @@ export interface SystemAgentItem { + contextLimit?: number; customPrompt?: string; enabled?: boolean; model: string; @@ -21,4 +22,15 @@ export interface UserSystemAgentConfig { translation: SystemAgentItem; } +export interface UserMemoryServiceModelConfig { + memoryAnalysisAgentConfig: SystemAgentItem; + userMemoryEmbedding: SystemAgentItem; + userMemoryPersonaWriter: SystemAgentItem; +} + +export interface UserServiceModelConfig + extends UserSystemAgentConfig, UserMemoryServiceModelConfig {} + export type UserSystemAgentConfigKey = keyof UserSystemAgentConfig; +export type UserMemoryServiceModelConfigKey = keyof UserMemoryServiceModelConfig; +export type UserServiceModelConfigKey = keyof UserServiceModelConfig; diff --git a/src/features/ServiceModel/ModelAssignmentsForm.tsx b/src/features/ServiceModel/ModelAssignmentsForm.tsx index 9fa20ce063..df2fe2faec 100644 --- a/src/features/ServiceModel/ModelAssignmentsForm.tsx +++ b/src/features/ServiceModel/ModelAssignmentsForm.tsx @@ -1,7 +1,7 @@ 'use client'; import type { FormGroupItemType, FormItemProps } from '@lobehub/ui'; -import { Flexbox, Form, Icon, Skeleton } from '@lobehub/ui'; +import { Flexbox, Form, Icon, InputNumber, Skeleton } from '@lobehub/ui'; import { Switch } from 'antd'; import isEqual from 'fast-deep-equal'; import { Loader2Icon } from 'lucide-react'; @@ -12,13 +12,14 @@ import { FORM_STYLE } from '@/const/layoutTokens'; import ModelSelect from '@/features/ModelSelect'; import { useUserStore } from '@/store/user'; import { settingsSelectors } from '@/store/user/selectors'; -import type { SystemAgentItem, UserSystemAgentConfigKey } from '@/types/user/settings'; +import type { SystemAgentItem, UserServiceModelConfigKey } from '@/types/user/settings'; interface SystemAgentModelItem { - key: UserSystemAgentConfigKey; + contextLimit?: boolean; + key: UserServiceModelConfigKey; } -type LoadingKey = 'defaultAgent' | UserSystemAgentConfigKey; +type LoadingKey = 'defaultAgent' | UserServiceModelConfigKey; const SYSTEM_AGENT_MODEL_ITEMS: SystemAgentModelItem[] = [ { key: 'topic' }, @@ -34,6 +35,12 @@ const OPTIONAL_FEATURE_ITEMS: SystemAgentModelItem[] = [ { key: 'promptRewrite' }, ]; +const MEMORY_MODEL_ITEMS: SystemAgentModelItem[] = [ + { contextLimit: true, key: 'memoryAnalysisAgentConfig' }, + { contextLimit: true, key: 'userMemoryPersonaWriter' }, + { contextLimit: true, key: 'userMemoryEmbedding' }, +]; + const ModelAssignmentsForm = memo(() => { const { t } = useTranslation('setting'); const [defaultAgent, systemAgentSettings] = useUserStore( @@ -69,7 +76,7 @@ const ModelAssignmentsForm = memo(() => { }; const updateSystemAgentModel = async ( - key: UserSystemAgentConfigKey, + key: UserServiceModelConfigKey, value: Partial, ) => { setLoadingKey(key); @@ -121,6 +128,39 @@ const ModelAssignmentsForm = memo(() => { } satisfies FormItemProps; }); + const memoryModelItems: FormItemProps[] = MEMORY_MODEL_ITEMS.map(({ contextLimit, key }) => { + const value = systemAgentSettings[key]; + + return { + children: ( + + updateSystemAgentModel(key, props)} + /> + {contextLimit && ( + + updateSystemAgentModel(key, { + contextLimit: typeof contextLimit === 'number' ? contextLimit : undefined, + }) + } + /> + )} + + ), + desc: t(`systemAgent.${key}.modelDesc`), + label: t(`systemAgent.${key}.title`), + minWidth: undefined, + } satisfies FormItemProps; + }); + const optionalFeatureItems: FormItemProps[] = OPTIONAL_FEATURE_ITEMS.map(({ key }) => { const value = systemAgentSettings[key]; const disabled = value.enabled === false; @@ -167,7 +207,8 @@ const ModelAssignmentsForm = memo(() => { loadingKey === 'followUpAction' || loadingKey === 'inputCompletion' || loadingKey === 'promptRewrite'; - const isModelAssignmentLoading = loadingKey && !isOptionalFeatureLoading; + const isMemoryModelLoading = MEMORY_MODEL_ITEMS.some(({ key }) => loadingKey === key); + const isModelAssignmentLoading = loadingKey && !isOptionalFeatureLoading && !isMemoryModelLoading; const modelAssignments: FormGroupItemType = { children: [defaultAgentItem, ...systemModelItems], @@ -185,10 +226,18 @@ const ModelAssignmentsForm = memo(() => { title: t('serviceModel.optionalFeatures.title'), }; + const memoryModels: FormGroupItemType = { + children: memoryModelItems, + extra: isMemoryModelLoading && ( + + ), + title: t('serviceModel.memoryModels.title'), + }; + return (
{ expect(result.agentMeta).toEqual({ provider: 'ollama', model: 'deepseek-v3' }); expect(result.historyCompress).toEqual({ provider: 'ollama', model: 'deepseek-v3' }); expect(result.thread).toEqual({ provider: 'ollama', model: 'deepseek-v3' }); + expect(result.userMemoryEmbedding).toBeUndefined(); + expect(result.memoryAnalysisAgentConfig).toBeUndefined(); + expect(result.userMemoryPersonaWriter).toBeUndefined(); + }); + + it('should parse memory service model assignments explicitly', () => { + const envValue = + 'memoryAnalysisAgentConfig=lobehub/gpt-5.4-mini,userMemoryEmbedding=openai/text-embedding-3-large'; + + const result = parseSystemAgent(envValue); + + expect(result.memoryAnalysisAgentConfig).toEqual({ + provider: 'lobehub', + model: 'gpt-5.4-mini', + }); + expect(result.userMemoryEmbedding).toEqual({ + provider: 'openai', + model: 'text-embedding-3-large', + }); }); it('should override default setting with specific settings', () => { diff --git a/src/server/globalConfig/parseSystemAgent.ts b/src/server/globalConfig/parseSystemAgent.ts index dee7989192..42f999960c 100644 --- a/src/server/globalConfig/parseSystemAgent.ts +++ b/src/server/globalConfig/parseSystemAgent.ts @@ -1,14 +1,20 @@ import { DEFAULT_SYSTEM_AGENT_CONFIG } from '@/const/settings'; -import { type UserSystemAgentConfig } from '@/types/user/settings'; +import { type UserServiceModelConfig } from '@/types/user/settings'; const protectedKeys = Object.keys(DEFAULT_SYSTEM_AGENT_CONFIG); const defaultTrueLey = new Set(['promptRewrite', 'autoSuggestion']); +const memoryServiceModelKeys = new Set([ + 'memoryAnalysisAgentConfig', + 'userMemoryEmbedding', + 'userMemoryPersonaWriter', +]); +const defaultModelAssignmentKeys = protectedKeys.filter((key) => !memoryServiceModelKeys.has(key)); -export const parseSystemAgent = (envString: string = ''): Partial => { +export const parseSystemAgent = (envString: string = ''): Partial => { if (!envString) return {}; - const config: Partial = {}; + const config: Partial = {}; // Handle full-width commas and extra spaces const envValue = envString.replaceAll(',', ',').trim(); @@ -39,7 +45,7 @@ export const parseSystemAgent = (envString: string = ''): Partial { + const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig(); + + return (executor as any).resolveRuntimeKeyVaults(runtimeState, memoryServiceConfig); +}; + describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => { + it('drops fallback credentials when user memory provider is overridden', () => { + const executor = createExecutor({ + embedding: { + apiKey: 'openai-system-key', + baseURL: 'https://openai.example.com', + model: 'embed-1', + provider: 'openai', + }, + }); + + const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig({ + userMemoryEmbedding: { + model: 'embed-2', + provider: 'anthropic', + }, + }); + + expect(memoryServiceConfig.agents.embedding).toMatchObject({ + model: 'embed-2', + provider: 'anthropic', + }); + expect(memoryServiceConfig.agents.embedding.apiKey).toBeUndefined(); + expect(memoryServiceConfig.agents.embedding.baseURL).toBeUndefined(); + }); + + it('keeps fallback credentials when user memory provider is unchanged', () => { + const executor = createExecutor({ + embedding: { + apiKey: 'openai-system-key', + baseURL: 'https://openai.example.com', + model: 'embed-1', + provider: 'openai', + }, + }); + + const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig({ + userMemoryEmbedding: { + model: 'embed-2', + provider: 'openai', + }, + }); + + expect(memoryServiceConfig.agents.embedding).toMatchObject({ + apiKey: 'openai-system-key', + baseURL: 'https://openai.example.com', + model: 'embed-2', + provider: 'openai', + }); + }); + + it('shares ServiceModel memory analysis config between gatekeeper and layer extractor', () => { + const executor = createExecutor({ + agentGateKeeper: { + apiKey: 'gate-system-key', + baseURL: 'https://gate.example.com', + model: 'gate-1', + provider: 'provider-gate', + }, + agentLayerExtractor: { + apiKey: 'layer-system-key', + baseURL: 'https://layer.example.com', + contextLimit: 2048, + layers: { + activity: 'layer-act', + context: 'layer-ctx', + experience: 'layer-exp', + identity: 'layer-id', + preference: 'layer-pref', + }, + model: 'layer-1', + provider: 'provider-layer', + }, + }); + + const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig({ + memoryAnalysisAgentConfig: { + contextLimit: 4096, + model: 'analysis-1', + provider: 'provider-analysis', + }, + }); + + expect(memoryServiceConfig.agents.gatekeeper).toMatchObject({ + model: 'analysis-1', + provider: 'provider-analysis', + }); + expect(memoryServiceConfig.agents.layerExtractor).toMatchObject({ + contextLimit: 4096, + model: 'analysis-1', + provider: 'provider-analysis', + }); + expect(memoryServiceConfig.agents.gatekeeper.apiKey).toBeUndefined(); + expect(memoryServiceConfig.agents.layerExtractor.apiKey).toBeUndefined(); + expect(memoryServiceConfig.modelConfig.gateModel).toBe('analysis-1'); + expect(memoryServiceConfig.modelConfig.layerModels).toEqual({ + activity: 'analysis-1', + context: 'analysis-1', + experience: 'analysis-1', + identity: 'analysis-1', + preference: 'analysis-1', + }); + }); + + it('uses ServiceModel provider before env preferred providers when provider is overridden', async () => { + const executor = createExecutor({ + agentGateKeeper: { + model: 'gate-1', + provider: 'provider-g', + }, + agentLayerExtractor: { + contextLimit: 2048, + layers: { + activity: 'layer-1', + context: 'layer-1', + experience: 'layer-1', + identity: 'layer-1', + preference: 'layer-1', + }, + model: 'layer-1', + provider: 'provider-l', + }, + embedding: { + apiKey: 'openai-system-key', + baseURL: 'https://openai.example.com', + model: 'embed-1', + provider: 'openai', + }, + embeddingPreferredProviders: ['provider-b'], + }); + + const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig({ + userMemoryEmbedding: { + model: 'embed-2', + provider: 'provider-a', + }, + }); + const runtimeState = createRuntimeState( + [ + { + abilities: {}, + enabled: true, + id: 'gate-1', + providerId: 'provider-g', + type: 'chat', + }, + { + abilities: {}, + enabled: true, + id: 'layer-1', + providerId: 'provider-l', + type: 'chat', + }, + { + abilities: {}, + enabled: true, + id: 'embed-2', + providerId: 'provider-a', + type: 'embedding', + }, + { + abilities: {}, + enabled: true, + id: 'embed-2', + providerId: 'provider-b', + type: 'embedding', + }, + ], + { + 'provider-a': { apiKey: 'a-key' }, + 'provider-b': { apiKey: 'b-key' }, + 'provider-g': { apiKey: 'g-key' }, + 'provider-l': { apiKey: 'l-key' }, + }, + ); + + const keyVaults = await (executor as any).resolveRuntimeKeyVaults( + runtimeState, + memoryServiceConfig, + ); + + expect(keyVaults).toMatchObject({ + 'provider-a': { apiKey: 'a-key' }, + }); + expect(keyVaults).not.toHaveProperty('provider-b'); + }); + it('prefers configured providers/models for gatekeeper, embedding, and layer extractors', async () => { const executor = createExecutor({ embeddingPreferredProviders: ['provider-c', 'provider-a'], @@ -119,7 +314,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => { }, ); - const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState); + const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState); expect(keyVaults).toMatchObject({ 'provider-a': { apiKey: 'a-key' }, @@ -182,7 +377,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => { }, ); - const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState); + const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState); expect(keyVaults).toMatchObject({ 'provider-b': { apiKey: 'b-key' }, @@ -222,7 +417,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => { }, ); - const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState); + const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState); expect(keyVaults).toMatchObject({ 'provider-a': { apiKey: 'a-key' }, @@ -253,7 +448,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => { }, ); - const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState); + const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState); expect(keyVaults).toMatchObject({ 'provider-b': { apiKey: 'b-key' }, // picks first preferred provider @@ -271,7 +466,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => { 'provider-fallback': { apiKey: 'fb-key' }, }); - const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState); + const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState); expect(keyVaults).toMatchObject({ 'provider-fallback': { apiKey: 'fb-key' }, diff --git a/src/server/services/memory/userMemory/extract.ts b/src/server/services/memory/userMemory/extract.ts index 6fc987f3c2..6a0eae0234 100644 --- a/src/server/services/memory/userMemory/extract.ts +++ b/src/server/services/memory/userMemory/extract.ts @@ -45,6 +45,7 @@ import type { MemoryExtractionAgentCallTrace, MemoryExtractionTraceError, MemoryExtractionTracePayload, + UserServiceModelConfig, } from '@lobechat/types'; import { RequestTrigger } from '@lobechat/types'; import { type FlowControl } from '@upstash/qstash'; @@ -583,6 +584,38 @@ type RuntimeBundle = { layerExtractor: ModelRuntime; }; +interface MemoryExtractionModelConfig { + embeddingsModel: string; + gateModel: string; + layerModels: Partial>; + observabilityS3: MemoryExtractionConfig['observabilityS3']; +} + +interface ResolvedMemoryServiceConfig { + agents: { + embedding: MemoryAgentConfig; + gatekeeper: MemoryAgentConfig; + layerExtractor: MemoryAgentConfig; + }; + embeddingContextLimit?: number; + extractorContextLimit?: number; + modelConfig: MemoryExtractionModelConfig; + overrides: { + embedding: { + model: boolean; + provider: boolean; + }; + gatekeeper: { + model: boolean; + provider: boolean; + }; + layerExtractor: { + model: boolean; + provider: boolean; + }; + }; +} + export interface TopicExtractionJob { asyncTaskId?: string; forceAll: boolean; @@ -623,14 +656,7 @@ export class MemoryExtractionExecutor { private readonly layerPreferredModels?: string[]; private readonly layerPreferredProviders?: string[]; private readonly privateConfig: MemoryExtractionConfig; - private readonly modelConfig: { - embeddingsModel: string; - gateModel: string; - layerModels: Partial>; - observabilityS3: MemoryExtractionConfig['observabilityS3']; - }; - private readonly embeddingContextLimit?: number; - + private readonly modelConfig: MemoryExtractionModelConfig; private readonly runtimeCache = new Map(); private readonly db = getServerDB(); @@ -659,9 +685,6 @@ export class MemoryExtractionExecutor { ), observabilityS3: privateConfig.observabilityS3, }; - - this.embeddingContextLimit = - privateConfig.embedding?.contextLimit ?? privateConfig.agentLayerExtractor.contextLimit; } static async create() { @@ -673,6 +696,95 @@ export class MemoryExtractionExecutor { return new MemoryExtractionExecutor(serverConfig, privateConfig); } + private resolveUserMemoryAgent( + systemAgent: Partial | undefined, + key: keyof Pick< + UserServiceModelConfig, + 'userMemoryEmbedding' | 'memoryAnalysisAgentConfig' | 'userMemoryPersonaWriter' + >, + fallback: MemoryAgentConfig, + ): MemoryAgentConfig { + const override = systemAgent?.[key]; + const provider = override?.provider || fallback.provider; + const shouldInheritCredentials = + !override?.provider || + normalizeProvider(override.provider) === normalizeProvider(fallback.provider || 'openai'); + const contextLimit = + typeof override?.contextLimit === 'number' && + Number.isFinite(override.contextLimit) && + override.contextLimit > 0 + ? Math.floor(override.contextLimit) + : fallback.contextLimit; + + return { + apiKey: shouldInheritCredentials ? fallback.apiKey : undefined, + baseURL: shouldInheritCredentials ? fallback.baseURL : undefined, + contextLimit, + language: fallback.language, + model: override?.model || fallback.model, + provider, + }; + } + + private resolveUserMemoryServiceConfig( + systemAgent?: Partial, + ): ResolvedMemoryServiceConfig { + const gatekeeper = this.resolveUserMemoryAgent( + systemAgent, + 'memoryAnalysisAgentConfig', + this.privateConfig.agentGateKeeper, + ); + const layerExtractor = this.resolveUserMemoryAgent( + systemAgent, + 'memoryAnalysisAgentConfig', + this.privateConfig.agentLayerExtractor, + ); + const embedding = this.resolveUserMemoryAgent( + systemAgent, + 'userMemoryEmbedding', + this.privateConfig.embedding, + ); + const layerModels = systemAgent?.memoryAnalysisAgentConfig?.model + ? { + activity: layerExtractor.model, + context: layerExtractor.model, + experience: layerExtractor.model, + identity: layerExtractor.model, + preference: layerExtractor.model, + } + : resolveLayerModels(undefined, this.privateConfig.agentLayerExtractor.layers); + + return { + agents: { + embedding, + gatekeeper, + layerExtractor, + }, + embeddingContextLimit: embedding.contextLimit ?? layerExtractor.contextLimit, + extractorContextLimit: layerExtractor.contextLimit, + modelConfig: { + embeddingsModel: embedding.model, + gateModel: gatekeeper.model, + layerModels, + observabilityS3: this.privateConfig.observabilityS3, + }, + overrides: { + embedding: { + model: Boolean(systemAgent?.userMemoryEmbedding?.model), + provider: Boolean(systemAgent?.userMemoryEmbedding?.provider), + }, + gatekeeper: { + model: Boolean(systemAgent?.memoryAnalysisAgentConfig?.model), + provider: Boolean(systemAgent?.memoryAnalysisAgentConfig?.provider), + }, + layerExtractor: { + model: Boolean(systemAgent?.memoryAnalysisAgentConfig?.model), + provider: Boolean(systemAgent?.memoryAnalysisAgentConfig?.provider), + }, + }, + }; + } + private buildBaseMetadata( job: MemoryExtractionJob, messageIds: string[], @@ -1431,10 +1543,16 @@ export class MemoryExtractionExecutor { userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults), this.getAiProviderRuntimeState(job.userId), ]); - const keyVaults = await this.resolveRuntimeKeyVaults(aiProviderRuntimeState); + const memoryServiceConfig = this.resolveUserMemoryServiceConfig( + userState.settings?.systemAgent as Partial | undefined, + ); + const keyVaults = await this.resolveRuntimeKeyVaults( + aiProviderRuntimeState, + memoryServiceConfig, + ); const language = userState.settings?.general?.responseLanguage; - const runtimes = await this.getRuntime(job.userId, keyVaults); + const runtimes = await this.getRuntime(job.userId, memoryServiceConfig, keyVaults); const conversations = await this.listConversationsForTopic( job.userId, @@ -1454,8 +1572,9 @@ export class MemoryExtractionExecutor { }; } - const extractorContextLimit = this.privateConfig.agentLayerExtractor.contextLimit; - const embeddingContextLimit = this.embeddingContextLimit ?? extractorContextLimit; + const extractorContextLimit = memoryServiceConfig.extractorContextLimit; + const embeddingContextLimit = + memoryServiceConfig.embeddingContextLimit ?? extractorContextLimit; const extractorConversations = await this.trimConversationsToTokenLimit( conversations, extractorContextLimit, @@ -1508,7 +1627,7 @@ export class MemoryExtractionExecutor { searchResult = await this.listRelevantUserMemories( extractionJob, runtimes.embeddings, - this.modelConfig.embeddingsModel, + memoryServiceConfig.modelConfig.embeddingsModel, job.userId, embeddingConversations, embeddingContextLimit, @@ -1621,7 +1740,7 @@ export class MemoryExtractionExecutor { }; const service = new MemoryExtractionService({ - config: this.modelConfig, + config: memoryServiceConfig.modelConfig, db, language, runtimes, @@ -1676,6 +1795,7 @@ export class MemoryExtractionExecutor { messageIds, extraction, runtimes, + memoryServiceConfig, db, ); if (retrievalErrors.length > 0) { @@ -2023,6 +2143,7 @@ export class MemoryExtractionExecutor { messageIds: string[], extraction: MemoryExtractionResult, runtimes: RuntimeBundle, + memoryServiceConfig: ResolvedMemoryServiceConfig, db: Awaited>, ): Promise { const createdIds: string[] = []; @@ -2108,8 +2229,8 @@ export class MemoryExtractionExecutor { messageIds, activityOutput.data, runtimes.embeddings, - this.modelConfig.embeddingsModel, - this.embeddingContextLimit, + memoryServiceConfig.modelConfig.embeddingsModel, + memoryServiceConfig.embeddingContextLimit, db, ), ); @@ -2126,8 +2247,8 @@ export class MemoryExtractionExecutor { messageIds, contextOutput.data, runtimes.embeddings, - this.modelConfig.embeddingsModel, - this.embeddingContextLimit, + memoryServiceConfig.modelConfig.embeddingsModel, + memoryServiceConfig.embeddingContextLimit, db, ), ); @@ -2144,8 +2265,8 @@ export class MemoryExtractionExecutor { messageIds, experienceOutput.data, runtimes.embeddings, - this.modelConfig.embeddingsModel, - this.embeddingContextLimit, + memoryServiceConfig.modelConfig.embeddingsModel, + memoryServiceConfig.embeddingContextLimit, db, ), ); @@ -2162,8 +2283,8 @@ export class MemoryExtractionExecutor { messageIds, preferenceOutput.data, runtimes.embeddings, - this.modelConfig.embeddingsModel, - this.embeddingContextLimit, + memoryServiceConfig.modelConfig.embeddingsModel, + memoryServiceConfig.embeddingContextLimit, db, ), ); @@ -2180,8 +2301,8 @@ export class MemoryExtractionExecutor { messageIds, identityOutput.data, runtimes.embeddings, - this.modelConfig.embeddingsModel, - this.embeddingContextLimit, + memoryServiceConfig.modelConfig.embeddingsModel, + memoryServiceConfig.embeddingContextLimit, db, ), ); @@ -2210,6 +2331,7 @@ export class MemoryExtractionExecutor { private async resolveRuntimeKeyVaults( runtimeState: AiProviderRuntimeState, + memoryServiceConfig: ResolvedMemoryServiceConfig, ): Promise { const normalizedRuntimeConfig = Object.fromEntries( Object.entries(runtimeState.runtimeConfig || {}).map(([providerId, config]) => [ @@ -2221,11 +2343,15 @@ export class MemoryExtractionExecutor { const keyVaults: ProviderKeyVaultMap = {}; const gatekeeperProvider = await AiInfraRepos.tryMatchingProviderFrom(runtimeState, { - fallbackProvider: this.privateConfig.agentGateKeeper.provider, + fallbackProvider: memoryServiceConfig.agents.gatekeeper.provider, label: 'gatekeeper', - modelId: this.modelConfig.gateModel, - preferredModels: this.gatekeeperPreferredModels, - preferredProviders: this.gatekeeperPreferredProviders, + modelId: memoryServiceConfig.modelConfig.gateModel, + preferredModels: memoryServiceConfig.overrides.gatekeeper.model + ? undefined + : this.gatekeeperPreferredModels, + preferredProviders: memoryServiceConfig.overrides.gatekeeper.provider + ? undefined + : this.gatekeeperPreferredProviders, }); const gatekeeperRuntime = normalizedRuntimeConfig[gatekeeperProvider]; if (gatekeeperRuntime?.keyVaults) { @@ -2233,25 +2359,33 @@ export class MemoryExtractionExecutor { } const embeddingProvider = await AiInfraRepos.tryMatchingProviderFrom(runtimeState, { - fallbackProvider: this.privateConfig.embedding.provider, + fallbackProvider: memoryServiceConfig.agents.embedding.provider, label: 'embedding', - modelId: this.modelConfig.embeddingsModel, - preferredModels: this.embeddingPreferredModels, - preferredProviders: this.embeddingPreferredProviders, + modelId: memoryServiceConfig.modelConfig.embeddingsModel, + preferredModels: memoryServiceConfig.overrides.embedding.model + ? undefined + : this.embeddingPreferredModels, + preferredProviders: memoryServiceConfig.overrides.embedding.provider + ? undefined + : this.embeddingPreferredProviders, }); const embeddingRuntime = normalizedRuntimeConfig[embeddingProvider]; if (embeddingRuntime?.keyVaults) { keyVaults[embeddingProvider] = embeddingRuntime.keyVaults; } - for (const model of Object.values(this.modelConfig.layerModels)) { + for (const model of Object.values(memoryServiceConfig.modelConfig.layerModels)) { if (!model) continue; const providerId = await AiInfraRepos.tryMatchingProviderFrom(runtimeState, { - fallbackProvider: this.privateConfig.agentLayerExtractor.provider, + fallbackProvider: memoryServiceConfig.agents.layerExtractor.provider, label: 'layer extractor', modelId: model, - preferredModels: this.layerPreferredModels, - preferredProviders: this.layerPreferredProviders, + preferredModels: memoryServiceConfig.overrides.layerExtractor.model + ? undefined + : this.layerPreferredModels, + preferredProviders: memoryServiceConfig.overrides.layerExtractor.provider + ? undefined + : this.layerPreferredProviders, }); const runtime = normalizedRuntimeConfig[providerId]; if (runtime?.keyVaults) { @@ -2264,6 +2398,7 @@ export class MemoryExtractionExecutor { private async getRuntime( userId: string, + memoryServiceConfig: ResolvedMemoryServiceConfig, keyVaults?: ProviderKeyVaultMap, ): Promise { // TODO: implement a better cache eviction strategy @@ -2272,33 +2407,51 @@ export class MemoryExtractionExecutor { this.runtimeCache.clear(); } - const cached = this.runtimeCache.get(userId); + const cacheKey = [ + userId, + memoryServiceConfig.agents.embedding.provider, + memoryServiceConfig.agents.gatekeeper.provider, + memoryServiceConfig.agents.layerExtractor.provider, + ].join(':'); + const cached = this.runtimeCache.get(cacheKey); if (cached) return cached; const embeddingOptions: RuntimeResolveOptions = { fallback: { - apiKey: this.privateConfig.embedding.apiKey, - baseURL: this.privateConfig.embedding.baseURL, + apiKey: memoryServiceConfig.agents.embedding.apiKey, + baseURL: memoryServiceConfig.agents.embedding.baseURL, + }, + preferred: { + providerIds: memoryServiceConfig.overrides.embedding.provider + ? undefined + : this.embeddingPreferredProviders, }, - preferred: { providerIds: this.embeddingPreferredProviders }, userId, }; const gatekeeperOptions: RuntimeResolveOptions = { fallback: { - apiKey: this.privateConfig.agentGateKeeper.apiKey, - baseURL: this.privateConfig.agentGateKeeper.baseURL, + apiKey: memoryServiceConfig.agents.gatekeeper.apiKey, + baseURL: memoryServiceConfig.agents.gatekeeper.baseURL, + }, + preferred: { + providerIds: memoryServiceConfig.overrides.gatekeeper.provider + ? undefined + : this.gatekeeperPreferredProviders, }, - preferred: { providerIds: this.gatekeeperPreferredProviders }, userId, }; const layerExtractorOptions: RuntimeResolveOptions = { fallback: { - apiKey: this.privateConfig.agentLayerExtractor.apiKey, - baseURL: this.privateConfig.agentLayerExtractor.baseURL, + apiKey: memoryServiceConfig.agents.layerExtractor.apiKey, + baseURL: memoryServiceConfig.agents.layerExtractor.baseURL, + }, + preferred: { + providerIds: memoryServiceConfig.overrides.layerExtractor.provider + ? undefined + : this.layerPreferredProviders, }, - preferred: { providerIds: this.layerPreferredProviders }, userId, }; @@ -2306,26 +2459,26 @@ export class MemoryExtractionExecutor { const runtimes: RuntimeBundle = { embeddings: await resolveRuntimeAgentConfig( - { ...this.privateConfig.embedding }, + memoryServiceConfig.agents.embedding, keyVaults, embeddingOptions, hooks, ), gatekeeper: await resolveRuntimeAgentConfig( - { ...this.privateConfig.agentGateKeeper }, + memoryServiceConfig.agents.gatekeeper, keyVaults, gatekeeperOptions, hooks, ), layerExtractor: await resolveRuntimeAgentConfig( - { ...this.privateConfig.agentLayerExtractor }, + memoryServiceConfig.agents.layerExtractor, keyVaults, layerExtractorOptions, hooks, ), }; - this.runtimeCache.set(userId, runtimes); + this.runtimeCache.set(cacheKey, runtimes); return runtimes; } @@ -2361,10 +2514,16 @@ export class MemoryExtractionExecutor { userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults), this.getAiProviderRuntimeState(params.userId), ]); - const keyVaults = await this.resolveRuntimeKeyVaults(aiProviderRuntimeState); + const memoryServiceConfig = this.resolveUserMemoryServiceConfig( + userState.settings?.systemAgent as Partial | undefined, + ); + const keyVaults = await this.resolveRuntimeKeyVaults( + aiProviderRuntimeState, + memoryServiceConfig, + ); const language = params.language || userState.settings?.general?.responseLanguage; - const runtimes = await this.getRuntime(params.userId, keyVaults); + const runtimes = await this.getRuntime(params.userId, memoryServiceConfig, keyVaults); const contextProvider = params.contextProvider || new BenchmarkLocomoContextProvider({ @@ -2393,7 +2552,7 @@ export class MemoryExtractionExecutor { }; const builtContext = await contextProvider.buildContext(extractionJob.userId); - const extractorContextLimit = this.privateConfig.agentLayerExtractor.contextLimit; + const extractorContextLimit = memoryServiceConfig.extractorContextLimit; const trimmedContext = await this.trimTextToTokenLimit( builtContext.context, extractorContextLimit, @@ -2431,7 +2590,7 @@ export class MemoryExtractionExecutor { }; const service = new MemoryExtractionService({ - config: this.modelConfig, + config: memoryServiceConfig.modelConfig, db, language, runtimes, @@ -2476,6 +2635,7 @@ export class MemoryExtractionExecutor { [], extraction, runtimes, + memoryServiceConfig, db, ); diff --git a/src/server/services/memory/userMemory/persona/__tests__/service.test.ts b/src/server/services/memory/userMemory/persona/__tests__/service.test.ts index 9c8e48bc55..87fe7eacf6 100644 --- a/src/server/services/memory/userMemory/persona/__tests__/service.test.ts +++ b/src/server/services/memory/userMemory/persona/__tests__/service.test.ts @@ -1,11 +1,12 @@ // @vitest-environment node import { type LobeChatDatabase } from '@lobechat/database'; -import { users } from '@lobechat/database/schemas'; +import { users, userSettings } from '@lobechat/database/schemas'; import { getTestDB } from '@lobechat/database/test-utils'; import { beforeEach, describe, expect, it, vi } from 'vitest'; import { UserPersonaModel } from '@/database/models/userMemory/persona'; import type * as AiInfraReposModule from '@/database/repositories/aiInfra'; +import { resolveRuntimeAgentConfig } from '@/server/services/memory/userMemory/extract'; import { UserPersonaService } from '../service'; @@ -134,4 +135,46 @@ describe('UserPersonaService', () => { }), ); }); + + it('drops fallback credentials when persona writer provider is overridden', async () => { + await db.insert(userSettings).values({ + id: userId, + systemAgent: { + userMemoryPersonaWriter: { + model: 'claude-mock', + provider: 'anthropic', + }, + }, + }); + aiInfraMocks.tryMatchingProviderFrom.mockResolvedValue('anthropic'); + aiInfraMocks.getAiProviderRuntimeState.mockResolvedValue({ + enabledAiModels: [ + { abilities: {}, enabled: true, id: 'claude-mock', providerId: 'anthropic', type: 'chat' }, + ], + enabledAiProviders: [], + enabledChatAiProviders: [], + enabledImageAiProviders: [], + runtimeConfig: {}, + }); + + const service = new UserPersonaService(db); + await service.composeWriting({ userId, username: 'User' }); + + expect(resolveRuntimeAgentConfig).toHaveBeenLastCalledWith( + expect.objectContaining({ + apiKey: undefined, + baseURL: undefined, + model: 'claude-mock', + provider: 'anthropic', + }), + expect.any(Object), + expect.objectContaining({ + fallback: { + apiKey: undefined, + baseURL: undefined, + }, + }), + undefined, + ); + }); }); diff --git a/src/server/services/memory/userMemory/persona/service.ts b/src/server/services/memory/userMemory/persona/service.ts index 3d4ee3108d..e778b1e918 100644 --- a/src/server/services/memory/userMemory/persona/service.ts +++ b/src/server/services/memory/userMemory/persona/service.ts @@ -9,9 +9,11 @@ import { RetrievalUserMemoryIdentitiesProvider, UserPersonaExtractor, } from '@lobechat/memory-user-memory'; +import type { UserServiceModelConfig } from '@lobechat/types'; import { desc, eq } from 'drizzle-orm'; import { getBusinessModelRuntimeHooks } from '@/business/server/model-runtime'; +import { UserModel } from '@/database/models/user'; import { UserMemoryModel } from '@/database/models/userMemory'; import { UserPersonaModel } from '@/database/models/userMemory/persona'; import { AiInfraRepos } from '@/database/repositories/aiInfra'; @@ -47,6 +49,14 @@ interface UserPersonaAgentResult { document: UserPersonaDocument; } +const resolvePositiveInteger = (value?: number) => { + if (typeof value !== 'number' || !Number.isFinite(value) || value <= 0) return undefined; + + return Math.floor(value); +}; + +const normalizeProvider = (provider: string) => provider.toLowerCase(); + export class UserPersonaService { private readonly preferredLanguage?: string; private readonly db: LobeChatDatabase; @@ -60,15 +70,40 @@ export class UserPersonaService { this.agentConfig = agentPersonaWriter; } + private async resolveAgentConfig(userId: string): Promise { + const userModel = new UserModel(this.db, userId); + const settings = await userModel.getUserSettings(); + const userMemoryPersonaWriter = ( + settings?.systemAgent as Partial | undefined + )?.userMemoryPersonaWriter; + const provider = userMemoryPersonaWriter?.provider || this.agentConfig.provider; + const shouldInheritCredentials = + !userMemoryPersonaWriter?.provider || + normalizeProvider(userMemoryPersonaWriter.provider) === + normalizeProvider(this.agentConfig.provider || 'openai'); + + return { + apiKey: shouldInheritCredentials ? this.agentConfig.apiKey : undefined, + baseURL: shouldInheritCredentials ? this.agentConfig.baseURL : undefined, + contextLimit: + resolvePositiveInteger(userMemoryPersonaWriter?.contextLimit) ?? + this.agentConfig.contextLimit, + language: this.agentConfig.language, + model: userMemoryPersonaWriter?.model || this.agentConfig.model, + provider, + }; + } + async composeWriting(payload: UserPersonaAgentPayload): Promise { + const agentConfig = await this.resolveAgentConfig(payload.userId); const aiInfraRepos = new AiInfraRepos(this.db, payload.userId, {}); const runtimeState = await aiInfraRepos.getAiProviderRuntimeState( KeyVaultsGateKeeper.getUserKeyVaults, ); const providerId = await AiInfraRepos.tryMatchingProviderFrom(runtimeState, { - fallbackProvider: this.agentConfig.provider, + fallbackProvider: agentConfig.provider, label: 'persona writer', - modelId: this.agentConfig.model, + modelId: agentConfig.model, }); const keyVaults: ProviderKeyVaultMap = Object.entries(runtimeState.runtimeConfig || {}).reduce( @@ -82,12 +117,12 @@ export class UserPersonaService { const hooks = getBusinessModelRuntimeHooks(payload.userId, 'lobehub'); const runtime = await resolveRuntimeAgentConfig( - { ...this.agentConfig }, + agentConfig, keyVaults, { fallback: { - apiKey: this.agentConfig.apiKey, - baseURL: this.agentConfig.baseURL, + apiKey: agentConfig.apiKey, + baseURL: agentConfig.baseURL, }, preferred: { providerIds: [providerId] }, userId: payload.userId, @@ -101,7 +136,7 @@ export class UserPersonaService { const extractor = new UserPersonaExtractor({ agent: 'user-persona', - model: this.agentConfig.model, + model: agentConfig.model, modelRuntime: runtime, }); @@ -136,7 +171,14 @@ export const buildUserPersonaJobInput = async (db: LobeChatDatabase, userId: str const personaModel = new UserPersonaModel(db, userId); const latestPersona = await personaModel.getLatestPersonaDocument(); const { agentPersonaWriter } = parseMemoryExtractionConfig(); - const personaContextLimit = agentPersonaWriter.contextLimit; + const userModel = new UserModel(db, userId); + const settings = await userModel.getUserSettings(); + const userMemoryPersonaWriter = ( + settings?.systemAgent as Partial | undefined + )?.userMemoryPersonaWriter; + const personaContextLimit = + resolvePositiveInteger(userMemoryPersonaWriter?.contextLimit) ?? + agentPersonaWriter.contextLimit; const userMemoryModel = new UserMemoryModel(db, userId); diff --git a/src/store/user/slices/settings/action.ts b/src/store/user/slices/settings/action.ts index 1bc9d24f33..a888083102 100644 --- a/src/store/user/slices/settings/action.ts +++ b/src/store/user/slices/settings/action.ts @@ -11,6 +11,7 @@ import type { SystemAgentItem, UserGeneralConfig, UserKeyVaults, + UserServiceModelConfigKey, UserSettings, UserSystemAgentConfigKey, } from '@/types/user/settings'; @@ -237,7 +238,7 @@ export class UserSettingsActionImpl { }; updateSystemAgent = async ( - key: UserSystemAgentConfigKey, + key: UserServiceModelConfigKey, value: Partial, ): Promise => { await this.#get().setSettings({ diff --git a/src/store/user/slices/settings/selectors/__snapshots__/settings.test.ts.snap b/src/store/user/slices/settings/selectors/__snapshots__/settings.test.ts.snap index 9531e1ca69..e5ac478563 100644 --- a/src/store/user/slices/settings/selectors/__snapshots__/settings.test.ts.snap +++ b/src/store/user/slices/settings/selectors/__snapshots__/settings.test.ts.snap @@ -73,6 +73,10 @@ exports[`settingsSelectors > currentSystemAgent > should merge DEFAULT_SYSTEM_AG "model": "gpt-5.4-mini", "provider": "openai", }, + "memoryAnalysisAgentConfig": { + "model": "gpt-5.4-mini", + "provider": "openai", + }, "promptRewrite": { "enabled": true, "model": "gpt-5.4-mini", @@ -91,6 +95,14 @@ exports[`settingsSelectors > currentSystemAgent > should merge DEFAULT_SYSTEM_AG "model": "gpt-5.4-mini", "provider": "openai", }, + "userMemoryEmbedding": { + "model": "text-embedding-3-small", + "provider": "openai", + }, + "userMemoryPersonaWriter": { + "model": "gpt-5.4-mini", + "provider": "openai", + }, } `; diff --git a/vitest.config.mts b/vitest.config.mts index 8aa8982b4d..5136bca778 100644 --- a/vitest.config.mts +++ b/vitest.config.mts @@ -16,6 +16,14 @@ const alias = { __dirname, './packages/business/model-runtime/src/index.ts', ), + '@lobechat/business-model-bank/model-config': resolve( + __dirname, + './packages/business/model-bank/src/model-config.ts', + ), + '@lobechat/business-model-bank': resolve( + __dirname, + './packages/business/model-bank/src/index.ts', + ), '@emoji-mart/data': resolve(__dirname, './tests/mocks/emojiMartData.ts'), '@emoji-mart/react': resolve(__dirname, './tests/mocks/emojiMartReact.tsx'), '@/database/_deprecated': resolve(__dirname, './src/database/_deprecated'),