mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-13 19:20:04 +00:00
♻️ refactor(userMemories): support resolving agent config from ServiceModel (#15138)
* ♻️ refactor(userMemories): support resolving agent config from ServiceModel * ♻️ refactor(userMemories): share memory analysis service model
This commit is contained in:
@@ -503,6 +503,8 @@
|
||||
"plugin.settings.tooltip": "Skill Configuration",
|
||||
"plugin.store": "Skill Store",
|
||||
"publishToCommunity": "Publish to Community",
|
||||
"serviceModel.contextLimit.placeholder": "Context limit",
|
||||
"serviceModel.memoryModels.title": "Memory Models",
|
||||
"serviceModel.modelAssignments.title": "Model Assignments",
|
||||
"serviceModel.optionalFeatures.title": "Optional Features",
|
||||
"settingAgent.avatar.sizeExceeded": "Image size exceeds 1MB limit, please choose a smaller image",
|
||||
@@ -850,6 +852,9 @@
|
||||
"systemAgent.inputCompletion.label": "Model",
|
||||
"systemAgent.inputCompletion.modelDesc": "Suggests text while you type. When enabled, this model generates the suggestions.",
|
||||
"systemAgent.inputCompletion.title": "Input Suggestions",
|
||||
"systemAgent.memoryAnalysisAgentConfig.label": "Model",
|
||||
"systemAgent.memoryAnalysisAgentConfig.modelDesc": "Model used to decide whether conversations contain memory and extract identities, preferences, contexts, activities, and experiences.",
|
||||
"systemAgent.memoryAnalysisAgentConfig.title": "Memory Analysis",
|
||||
"systemAgent.promptRewrite.label": "Model",
|
||||
"systemAgent.promptRewrite.modelDesc": "Improves prompts before generation. When enabled, this model rewrites the prompt.",
|
||||
"systemAgent.promptRewrite.title": "Prompt Rewriting",
|
||||
@@ -863,6 +868,12 @@
|
||||
"systemAgent.translation.label": "Model",
|
||||
"systemAgent.translation.modelDesc": "Model used to translate messages",
|
||||
"systemAgent.translation.title": "Message Translation",
|
||||
"systemAgent.userMemoryEmbedding.label": "Model",
|
||||
"systemAgent.userMemoryEmbedding.modelDesc": "Model used to embed memory content for retrieval. The context limit caps each embedding input.",
|
||||
"systemAgent.userMemoryEmbedding.title": "Memory Embedding",
|
||||
"systemAgent.userMemoryPersonaWriter.label": "Model",
|
||||
"systemAgent.userMemoryPersonaWriter.modelDesc": "Model used to write persona-oriented memory summaries.",
|
||||
"systemAgent.userMemoryPersonaWriter.title": "Memory Persona Writer",
|
||||
"tab.about": "About",
|
||||
"tab.addAgentSkill": "Add Agent Skill",
|
||||
"tab.addCustomMcp": "Add Custom MCP Skill",
|
||||
|
||||
@@ -503,6 +503,8 @@
|
||||
"plugin.settings.tooltip": "技能配置",
|
||||
"plugin.store": "技能商店",
|
||||
"publishToCommunity": "发布到社区",
|
||||
"serviceModel.contextLimit.placeholder": "上下文限制",
|
||||
"serviceModel.memoryModels.title": "记忆模型",
|
||||
"serviceModel.modelAssignments.title": "模型分配",
|
||||
"serviceModel.optionalFeatures.title": "可选功能",
|
||||
"settingAgent.avatar.sizeExceeded": "图片大小超过 1MB 限制,请选择更小的图片",
|
||||
@@ -850,6 +852,9 @@
|
||||
"systemAgent.inputCompletion.label": "模型",
|
||||
"systemAgent.inputCompletion.modelDesc": "输入时生成文本建议。开启后,由该模型生成建议。",
|
||||
"systemAgent.inputCompletion.title": "输入建议",
|
||||
"systemAgent.memoryAnalysisAgentConfig.label": "模型",
|
||||
"systemAgent.memoryAnalysisAgentConfig.modelDesc": "用于判断对话是否包含记忆,并提取身份、偏好、上下文、活动和经历。",
|
||||
"systemAgent.memoryAnalysisAgentConfig.title": "记忆分析",
|
||||
"systemAgent.promptRewrite.label": "模型",
|
||||
"systemAgent.promptRewrite.modelDesc": "生成前优化提示词。开启后,由该模型改写提示词。",
|
||||
"systemAgent.promptRewrite.title": "提示词改写",
|
||||
@@ -863,6 +868,12 @@
|
||||
"systemAgent.translation.label": "模型",
|
||||
"systemAgent.translation.modelDesc": "用于翻译消息内容的模型",
|
||||
"systemAgent.translation.title": "消息内容翻译",
|
||||
"systemAgent.userMemoryEmbedding.label": "模型",
|
||||
"systemAgent.userMemoryEmbedding.modelDesc": "用于为记忆内容生成向量以支持检索。上下文限制会约束每次向量化输入。",
|
||||
"systemAgent.userMemoryEmbedding.title": "记忆向量化",
|
||||
"systemAgent.userMemoryPersonaWriter.label": "模型",
|
||||
"systemAgent.userMemoryPersonaWriter.modelDesc": "用于生成面向画像的记忆摘要。",
|
||||
"systemAgent.userMemoryPersonaWriter.title": "记忆画像写入",
|
||||
"tab.about": "关于",
|
||||
"tab.addAgentSkill": "添加 Agent 技能",
|
||||
"tab.addCustomMcp": "添加自定义 MCP 技能",
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import { DEFAULT_MINI_PROVIDER, DEFAULT_PROVIDER } from '@lobechat/business-const';
|
||||
import {
|
||||
DEFAULT_EMBEDDING_PROVIDER,
|
||||
DEFAULT_MINI_PROVIDER,
|
||||
DEFAULT_PROVIDER,
|
||||
} from '@lobechat/business-const';
|
||||
import type {
|
||||
PromptRewriteSystemAgent,
|
||||
SystemAgentItem,
|
||||
UserSystemAgentConfig,
|
||||
UserServiceModelConfig,
|
||||
} from '@lobechat/types';
|
||||
|
||||
import { DEFAULT_MINI_MODEL, DEFAULT_MODEL } from './llm';
|
||||
import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MINI_MODEL, DEFAULT_MODEL } from './llm';
|
||||
|
||||
export const DEFAULT_SYSTEM_AGENT_ITEM: SystemAgentItem = {
|
||||
model: DEFAULT_MODEL,
|
||||
@@ -35,12 +39,20 @@ export const DEFAULT_FOLLOW_UP_ACTION_SYSTEM_AGENT_ITEM: SystemAgentItem = {
|
||||
provider: DEFAULT_MINI_SYSTEM_AGENT_ITEM.provider,
|
||||
};
|
||||
|
||||
export const DEFAULT_SYSTEM_AGENT_CONFIG: UserSystemAgentConfig = {
|
||||
export const DEFAULT_USER_MEMORY_EMBEDDING_SYSTEM_AGENT_ITEM: SystemAgentItem = {
|
||||
model: DEFAULT_EMBEDDING_MODEL,
|
||||
provider: DEFAULT_EMBEDDING_PROVIDER,
|
||||
};
|
||||
|
||||
export const DEFAULT_SYSTEM_AGENT_CONFIG: UserServiceModelConfig = {
|
||||
agentMeta: DEFAULT_SYSTEM_AGENT_ITEM,
|
||||
followUpAction: DEFAULT_FOLLOW_UP_ACTION_SYSTEM_AGENT_ITEM,
|
||||
generationTopic: DEFAULT_MINI_SYSTEM_AGENT_ITEM,
|
||||
historyCompress: DEFAULT_SYSTEM_AGENT_ITEM,
|
||||
inputCompletion: DEFAULT_INPUT_COMPLETION_SYSTEM_AGENT_ITEM,
|
||||
memoryAnalysisAgentConfig: DEFAULT_MINI_SYSTEM_AGENT_ITEM,
|
||||
userMemoryEmbedding: DEFAULT_USER_MEMORY_EMBEDDING_SYSTEM_AGENT_ITEM,
|
||||
userMemoryPersonaWriter: DEFAULT_MINI_SYSTEM_AGENT_ITEM,
|
||||
promptRewrite: DEFAULT_PROMPT_REWRITE_SYSTEM_AGENT_ITEM,
|
||||
thread: DEFAULT_SYSTEM_AGENT_ITEM,
|
||||
topic: DEFAULT_MINI_SYSTEM_AGENT_ITEM,
|
||||
|
||||
@@ -7,7 +7,7 @@ import type {
|
||||
GlobalLLMProviderKey,
|
||||
UserDefaultAgent,
|
||||
UserImageConfig,
|
||||
UserSystemAgentConfig,
|
||||
UserServiceModelConfig,
|
||||
} from './user/settings';
|
||||
|
||||
export type GlobalMemoryLayer = 'activity' | 'context' | 'experience' | 'identity' | 'preference';
|
||||
@@ -76,7 +76,7 @@ export interface GlobalServerConfig {
|
||||
image?: PartialDeep<UserImageConfig>;
|
||||
memory?: GlobalMemoryConfig;
|
||||
oAuthSSOProviders?: string[];
|
||||
systemAgent?: PartialDeep<UserSystemAgentConfig>;
|
||||
systemAgent?: PartialDeep<UserServiceModelConfig>;
|
||||
telemetry: {
|
||||
langfuse?: boolean;
|
||||
};
|
||||
|
||||
@@ -9,7 +9,7 @@ import type { MarketAuthTokens } from './market';
|
||||
import type { UserMemorySettings } from './memory';
|
||||
import type { UserModelProviderConfig } from './modelProvider';
|
||||
import type { NotificationSettings } from './notification';
|
||||
import type { UserSystemAgentConfig } from './systemAgent';
|
||||
import type { UserServiceModelConfig } from './systemAgent';
|
||||
import type { UserToolConfig } from './tool';
|
||||
import type { UserTTSConfig } from './tts';
|
||||
|
||||
@@ -42,7 +42,7 @@ export interface UserSettings {
|
||||
market?: MarketAuthTokens;
|
||||
memory?: UserMemorySettings;
|
||||
notification?: NotificationSettings;
|
||||
systemAgent: UserSystemAgentConfig;
|
||||
systemAgent: UserServiceModelConfig;
|
||||
tool: UserToolConfig;
|
||||
tts: UserTTSConfig;
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
export interface SystemAgentItem {
|
||||
contextLimit?: number;
|
||||
customPrompt?: string;
|
||||
enabled?: boolean;
|
||||
model: string;
|
||||
@@ -21,4 +22,15 @@ export interface UserSystemAgentConfig {
|
||||
translation: SystemAgentItem;
|
||||
}
|
||||
|
||||
export interface UserMemoryServiceModelConfig {
|
||||
memoryAnalysisAgentConfig: SystemAgentItem;
|
||||
userMemoryEmbedding: SystemAgentItem;
|
||||
userMemoryPersonaWriter: SystemAgentItem;
|
||||
}
|
||||
|
||||
export interface UserServiceModelConfig
|
||||
extends UserSystemAgentConfig, UserMemoryServiceModelConfig {}
|
||||
|
||||
export type UserSystemAgentConfigKey = keyof UserSystemAgentConfig;
|
||||
export type UserMemoryServiceModelConfigKey = keyof UserMemoryServiceModelConfig;
|
||||
export type UserServiceModelConfigKey = keyof UserServiceModelConfig;
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
'use client';
|
||||
|
||||
import type { FormGroupItemType, FormItemProps } from '@lobehub/ui';
|
||||
import { Flexbox, Form, Icon, Skeleton } from '@lobehub/ui';
|
||||
import { Flexbox, Form, Icon, InputNumber, Skeleton } from '@lobehub/ui';
|
||||
import { Switch } from 'antd';
|
||||
import isEqual from 'fast-deep-equal';
|
||||
import { Loader2Icon } from 'lucide-react';
|
||||
@@ -12,13 +12,14 @@ import { FORM_STYLE } from '@/const/layoutTokens';
|
||||
import ModelSelect from '@/features/ModelSelect';
|
||||
import { useUserStore } from '@/store/user';
|
||||
import { settingsSelectors } from '@/store/user/selectors';
|
||||
import type { SystemAgentItem, UserSystemAgentConfigKey } from '@/types/user/settings';
|
||||
import type { SystemAgentItem, UserServiceModelConfigKey } from '@/types/user/settings';
|
||||
|
||||
interface SystemAgentModelItem {
|
||||
key: UserSystemAgentConfigKey;
|
||||
contextLimit?: boolean;
|
||||
key: UserServiceModelConfigKey;
|
||||
}
|
||||
|
||||
type LoadingKey = 'defaultAgent' | UserSystemAgentConfigKey;
|
||||
type LoadingKey = 'defaultAgent' | UserServiceModelConfigKey;
|
||||
|
||||
const SYSTEM_AGENT_MODEL_ITEMS: SystemAgentModelItem[] = [
|
||||
{ key: 'topic' },
|
||||
@@ -34,6 +35,12 @@ const OPTIONAL_FEATURE_ITEMS: SystemAgentModelItem[] = [
|
||||
{ key: 'promptRewrite' },
|
||||
];
|
||||
|
||||
const MEMORY_MODEL_ITEMS: SystemAgentModelItem[] = [
|
||||
{ contextLimit: true, key: 'memoryAnalysisAgentConfig' },
|
||||
{ contextLimit: true, key: 'userMemoryPersonaWriter' },
|
||||
{ contextLimit: true, key: 'userMemoryEmbedding' },
|
||||
];
|
||||
|
||||
const ModelAssignmentsForm = memo(() => {
|
||||
const { t } = useTranslation('setting');
|
||||
const [defaultAgent, systemAgentSettings] = useUserStore(
|
||||
@@ -69,7 +76,7 @@ const ModelAssignmentsForm = memo(() => {
|
||||
};
|
||||
|
||||
const updateSystemAgentModel = async (
|
||||
key: UserSystemAgentConfigKey,
|
||||
key: UserServiceModelConfigKey,
|
||||
value: Partial<SystemAgentItem>,
|
||||
) => {
|
||||
setLoadingKey(key);
|
||||
@@ -121,6 +128,39 @@ const ModelAssignmentsForm = memo(() => {
|
||||
} satisfies FormItemProps;
|
||||
});
|
||||
|
||||
const memoryModelItems: FormItemProps[] = MEMORY_MODEL_ITEMS.map(({ contextLimit, key }) => {
|
||||
const value = systemAgentSettings[key];
|
||||
|
||||
return {
|
||||
children: (
|
||||
<Flexbox direction="vertical" gap={8} style={{ width: 'min(100%, 448px)' }}>
|
||||
<ModelSelect
|
||||
showAbility={false}
|
||||
style={{ minWidth: 0, width: '100%' }}
|
||||
value={value}
|
||||
onChange={(props) => updateSystemAgentModel(key, props)}
|
||||
/>
|
||||
{contextLimit && (
|
||||
<InputNumber
|
||||
min={1}
|
||||
placeholder={t('serviceModel.contextLimit.placeholder')}
|
||||
style={{ alignSelf: 'flex-end', width: 180 }}
|
||||
value={value.contextLimit}
|
||||
onChange={(contextLimit) =>
|
||||
updateSystemAgentModel(key, {
|
||||
contextLimit: typeof contextLimit === 'number' ? contextLimit : undefined,
|
||||
})
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Flexbox>
|
||||
),
|
||||
desc: t(`systemAgent.${key}.modelDesc`),
|
||||
label: t(`systemAgent.${key}.title`),
|
||||
minWidth: undefined,
|
||||
} satisfies FormItemProps;
|
||||
});
|
||||
|
||||
const optionalFeatureItems: FormItemProps[] = OPTIONAL_FEATURE_ITEMS.map(({ key }) => {
|
||||
const value = systemAgentSettings[key];
|
||||
const disabled = value.enabled === false;
|
||||
@@ -167,7 +207,8 @@ const ModelAssignmentsForm = memo(() => {
|
||||
loadingKey === 'followUpAction' ||
|
||||
loadingKey === 'inputCompletion' ||
|
||||
loadingKey === 'promptRewrite';
|
||||
const isModelAssignmentLoading = loadingKey && !isOptionalFeatureLoading;
|
||||
const isMemoryModelLoading = MEMORY_MODEL_ITEMS.some(({ key }) => loadingKey === key);
|
||||
const isModelAssignmentLoading = loadingKey && !isOptionalFeatureLoading && !isMemoryModelLoading;
|
||||
|
||||
const modelAssignments: FormGroupItemType = {
|
||||
children: [defaultAgentItem, ...systemModelItems],
|
||||
@@ -185,10 +226,18 @@ const ModelAssignmentsForm = memo(() => {
|
||||
title: t('serviceModel.optionalFeatures.title'),
|
||||
};
|
||||
|
||||
const memoryModels: FormGroupItemType = {
|
||||
children: memoryModelItems,
|
||||
extra: isMemoryModelLoading && (
|
||||
<Icon spin icon={Loader2Icon} size={16} style={{ opacity: 0.5 }} />
|
||||
),
|
||||
title: t('serviceModel.memoryModels.title'),
|
||||
};
|
||||
|
||||
return (
|
||||
<Form
|
||||
collapsible={false}
|
||||
items={[modelAssignments, optionalFeatures]}
|
||||
items={[modelAssignments, memoryModels, optionalFeatures]}
|
||||
itemsType={'group'}
|
||||
variant={'filled'}
|
||||
{...FORM_STYLE}
|
||||
|
||||
@@ -812,6 +812,8 @@ export default {
|
||||
'settingSystem.oauth.signout.success': 'Sign out successful',
|
||||
'settingSystem.title': 'System Settings',
|
||||
'serviceModel.modelAssignments.title': 'Model Assignments',
|
||||
'serviceModel.contextLimit.placeholder': 'Context limit',
|
||||
'serviceModel.memoryModels.title': 'Memory Models',
|
||||
'serviceModel.optionalFeatures.title': 'Optional Features',
|
||||
'settingSystemTools.appEnvironment.chromium.desc': 'Chromium browser engine version',
|
||||
'settingSystemTools.appEnvironment.desc': 'Built-in runtime versions in the desktop app',
|
||||
@@ -991,6 +993,18 @@ When I am ___, I need ___
|
||||
'systemAgent.inputCompletion.modelDesc':
|
||||
'Suggests text while you type. When enabled, this model generates the suggestions.',
|
||||
'systemAgent.inputCompletion.title': 'Input Suggestions',
|
||||
'systemAgent.userMemoryEmbedding.label': 'Model',
|
||||
'systemAgent.userMemoryEmbedding.modelDesc':
|
||||
'Model used to embed memory content for retrieval. The context limit caps each embedding input.',
|
||||
'systemAgent.userMemoryEmbedding.title': 'Memory Embedding',
|
||||
'systemAgent.memoryAnalysisAgentConfig.label': 'Model',
|
||||
'systemAgent.memoryAnalysisAgentConfig.modelDesc':
|
||||
'Model used to decide whether conversations contain memory and extract identities, preferences, contexts, activities, and experiences.',
|
||||
'systemAgent.memoryAnalysisAgentConfig.title': 'Memory Analysis',
|
||||
'systemAgent.userMemoryPersonaWriter.label': 'Model',
|
||||
'systemAgent.userMemoryPersonaWriter.modelDesc':
|
||||
'Model used to write persona-oriented memory summaries.',
|
||||
'systemAgent.userMemoryPersonaWriter.title': 'Memory Persona Writer',
|
||||
'systemAgent.promptRewrite.label': 'Model',
|
||||
'systemAgent.promptRewrite.modelDesc':
|
||||
'Improves prompts before generation. When enabled, this model rewrites the prompt.',
|
||||
|
||||
@@ -103,6 +103,25 @@ describe('parseSystemAgent', () => {
|
||||
expect(result.agentMeta).toEqual({ provider: 'ollama', model: 'deepseek-v3' });
|
||||
expect(result.historyCompress).toEqual({ provider: 'ollama', model: 'deepseek-v3' });
|
||||
expect(result.thread).toEqual({ provider: 'ollama', model: 'deepseek-v3' });
|
||||
expect(result.userMemoryEmbedding).toBeUndefined();
|
||||
expect(result.memoryAnalysisAgentConfig).toBeUndefined();
|
||||
expect(result.userMemoryPersonaWriter).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should parse memory service model assignments explicitly', () => {
|
||||
const envValue =
|
||||
'memoryAnalysisAgentConfig=lobehub/gpt-5.4-mini,userMemoryEmbedding=openai/text-embedding-3-large';
|
||||
|
||||
const result = parseSystemAgent(envValue);
|
||||
|
||||
expect(result.memoryAnalysisAgentConfig).toEqual({
|
||||
provider: 'lobehub',
|
||||
model: 'gpt-5.4-mini',
|
||||
});
|
||||
expect(result.userMemoryEmbedding).toEqual({
|
||||
provider: 'openai',
|
||||
model: 'text-embedding-3-large',
|
||||
});
|
||||
});
|
||||
|
||||
it('should override default setting with specific settings', () => {
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
import { DEFAULT_SYSTEM_AGENT_CONFIG } from '@/const/settings';
|
||||
import { type UserSystemAgentConfig } from '@/types/user/settings';
|
||||
import { type UserServiceModelConfig } from '@/types/user/settings';
|
||||
|
||||
const protectedKeys = Object.keys(DEFAULT_SYSTEM_AGENT_CONFIG);
|
||||
|
||||
const defaultTrueLey = new Set(['promptRewrite', 'autoSuggestion']);
|
||||
const memoryServiceModelKeys = new Set([
|
||||
'memoryAnalysisAgentConfig',
|
||||
'userMemoryEmbedding',
|
||||
'userMemoryPersonaWriter',
|
||||
]);
|
||||
const defaultModelAssignmentKeys = protectedKeys.filter((key) => !memoryServiceModelKeys.has(key));
|
||||
|
||||
export const parseSystemAgent = (envString: string = ''): Partial<UserSystemAgentConfig> => {
|
||||
export const parseSystemAgent = (envString: string = ''): Partial<UserServiceModelConfig> => {
|
||||
if (!envString) return {};
|
||||
|
||||
const config: Partial<UserSystemAgentConfig> = {};
|
||||
const config: Partial<UserServiceModelConfig> = {};
|
||||
|
||||
// Handle full-width commas and extra spaces
|
||||
const envValue = envString.replaceAll(',', ',').trim();
|
||||
@@ -39,7 +45,7 @@ export const parseSystemAgent = (envString: string = ''): Partial<UserSystemAgen
|
||||
}
|
||||
|
||||
if (protectedKeys.includes(key)) {
|
||||
config[key as keyof UserSystemAgentConfig] = {
|
||||
config[key as keyof UserServiceModelConfig] = {
|
||||
enabled: defaultTrueLey.has(key) ? true : undefined,
|
||||
model: model.trim(),
|
||||
provider: provider.trim(),
|
||||
@@ -52,9 +58,9 @@ export const parseSystemAgent = (envString: string = ''): Partial<UserSystemAgen
|
||||
|
||||
// If there are default settings, apply them to all unconfigured system agents
|
||||
if (defaultSetting) {
|
||||
for (const key of protectedKeys) {
|
||||
if (!config[key as keyof UserSystemAgentConfig]) {
|
||||
config[key as keyof UserSystemAgentConfig] = {
|
||||
for (const key of defaultModelAssignmentKeys) {
|
||||
if (!config[key as keyof UserServiceModelConfig]) {
|
||||
config[key as keyof UserServiceModelConfig] = {
|
||||
enabled: defaultTrueLey.has(key) ? true : undefined,
|
||||
model: defaultSetting.model,
|
||||
provider: defaultSetting.provider,
|
||||
|
||||
@@ -57,7 +57,202 @@ const createExecutor = (privateOverrides?: Partial<MemoryExtractionPrivateConfig
|
||||
});
|
||||
};
|
||||
|
||||
const resolveRuntimeKeyVaults = async (
|
||||
executor: MemoryExtractionExecutor,
|
||||
runtimeState: AiProviderRuntimeState,
|
||||
) => {
|
||||
const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig();
|
||||
|
||||
return (executor as any).resolveRuntimeKeyVaults(runtimeState, memoryServiceConfig);
|
||||
};
|
||||
|
||||
describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
|
||||
it('drops fallback credentials when user memory provider is overridden', () => {
|
||||
const executor = createExecutor({
|
||||
embedding: {
|
||||
apiKey: 'openai-system-key',
|
||||
baseURL: 'https://openai.example.com',
|
||||
model: 'embed-1',
|
||||
provider: 'openai',
|
||||
},
|
||||
});
|
||||
|
||||
const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig({
|
||||
userMemoryEmbedding: {
|
||||
model: 'embed-2',
|
||||
provider: 'anthropic',
|
||||
},
|
||||
});
|
||||
|
||||
expect(memoryServiceConfig.agents.embedding).toMatchObject({
|
||||
model: 'embed-2',
|
||||
provider: 'anthropic',
|
||||
});
|
||||
expect(memoryServiceConfig.agents.embedding.apiKey).toBeUndefined();
|
||||
expect(memoryServiceConfig.agents.embedding.baseURL).toBeUndefined();
|
||||
});
|
||||
|
||||
it('keeps fallback credentials when user memory provider is unchanged', () => {
|
||||
const executor = createExecutor({
|
||||
embedding: {
|
||||
apiKey: 'openai-system-key',
|
||||
baseURL: 'https://openai.example.com',
|
||||
model: 'embed-1',
|
||||
provider: 'openai',
|
||||
},
|
||||
});
|
||||
|
||||
const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig({
|
||||
userMemoryEmbedding: {
|
||||
model: 'embed-2',
|
||||
provider: 'openai',
|
||||
},
|
||||
});
|
||||
|
||||
expect(memoryServiceConfig.agents.embedding).toMatchObject({
|
||||
apiKey: 'openai-system-key',
|
||||
baseURL: 'https://openai.example.com',
|
||||
model: 'embed-2',
|
||||
provider: 'openai',
|
||||
});
|
||||
});
|
||||
|
||||
it('shares ServiceModel memory analysis config between gatekeeper and layer extractor', () => {
|
||||
const executor = createExecutor({
|
||||
agentGateKeeper: {
|
||||
apiKey: 'gate-system-key',
|
||||
baseURL: 'https://gate.example.com',
|
||||
model: 'gate-1',
|
||||
provider: 'provider-gate',
|
||||
},
|
||||
agentLayerExtractor: {
|
||||
apiKey: 'layer-system-key',
|
||||
baseURL: 'https://layer.example.com',
|
||||
contextLimit: 2048,
|
||||
layers: {
|
||||
activity: 'layer-act',
|
||||
context: 'layer-ctx',
|
||||
experience: 'layer-exp',
|
||||
identity: 'layer-id',
|
||||
preference: 'layer-pref',
|
||||
},
|
||||
model: 'layer-1',
|
||||
provider: 'provider-layer',
|
||||
},
|
||||
});
|
||||
|
||||
const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig({
|
||||
memoryAnalysisAgentConfig: {
|
||||
contextLimit: 4096,
|
||||
model: 'analysis-1',
|
||||
provider: 'provider-analysis',
|
||||
},
|
||||
});
|
||||
|
||||
expect(memoryServiceConfig.agents.gatekeeper).toMatchObject({
|
||||
model: 'analysis-1',
|
||||
provider: 'provider-analysis',
|
||||
});
|
||||
expect(memoryServiceConfig.agents.layerExtractor).toMatchObject({
|
||||
contextLimit: 4096,
|
||||
model: 'analysis-1',
|
||||
provider: 'provider-analysis',
|
||||
});
|
||||
expect(memoryServiceConfig.agents.gatekeeper.apiKey).toBeUndefined();
|
||||
expect(memoryServiceConfig.agents.layerExtractor.apiKey).toBeUndefined();
|
||||
expect(memoryServiceConfig.modelConfig.gateModel).toBe('analysis-1');
|
||||
expect(memoryServiceConfig.modelConfig.layerModels).toEqual({
|
||||
activity: 'analysis-1',
|
||||
context: 'analysis-1',
|
||||
experience: 'analysis-1',
|
||||
identity: 'analysis-1',
|
||||
preference: 'analysis-1',
|
||||
});
|
||||
});
|
||||
|
||||
it('uses ServiceModel provider before env preferred providers when provider is overridden', async () => {
|
||||
const executor = createExecutor({
|
||||
agentGateKeeper: {
|
||||
model: 'gate-1',
|
||||
provider: 'provider-g',
|
||||
},
|
||||
agentLayerExtractor: {
|
||||
contextLimit: 2048,
|
||||
layers: {
|
||||
activity: 'layer-1',
|
||||
context: 'layer-1',
|
||||
experience: 'layer-1',
|
||||
identity: 'layer-1',
|
||||
preference: 'layer-1',
|
||||
},
|
||||
model: 'layer-1',
|
||||
provider: 'provider-l',
|
||||
},
|
||||
embedding: {
|
||||
apiKey: 'openai-system-key',
|
||||
baseURL: 'https://openai.example.com',
|
||||
model: 'embed-1',
|
||||
provider: 'openai',
|
||||
},
|
||||
embeddingPreferredProviders: ['provider-b'],
|
||||
});
|
||||
|
||||
const memoryServiceConfig = (executor as any).resolveUserMemoryServiceConfig({
|
||||
userMemoryEmbedding: {
|
||||
model: 'embed-2',
|
||||
provider: 'provider-a',
|
||||
},
|
||||
});
|
||||
const runtimeState = createRuntimeState(
|
||||
[
|
||||
{
|
||||
abilities: {},
|
||||
enabled: true,
|
||||
id: 'gate-1',
|
||||
providerId: 'provider-g',
|
||||
type: 'chat',
|
||||
},
|
||||
{
|
||||
abilities: {},
|
||||
enabled: true,
|
||||
id: 'layer-1',
|
||||
providerId: 'provider-l',
|
||||
type: 'chat',
|
||||
},
|
||||
{
|
||||
abilities: {},
|
||||
enabled: true,
|
||||
id: 'embed-2',
|
||||
providerId: 'provider-a',
|
||||
type: 'embedding',
|
||||
},
|
||||
{
|
||||
abilities: {},
|
||||
enabled: true,
|
||||
id: 'embed-2',
|
||||
providerId: 'provider-b',
|
||||
type: 'embedding',
|
||||
},
|
||||
],
|
||||
{
|
||||
'provider-a': { apiKey: 'a-key' },
|
||||
'provider-b': { apiKey: 'b-key' },
|
||||
'provider-g': { apiKey: 'g-key' },
|
||||
'provider-l': { apiKey: 'l-key' },
|
||||
},
|
||||
);
|
||||
|
||||
const keyVaults = await (executor as any).resolveRuntimeKeyVaults(
|
||||
runtimeState,
|
||||
memoryServiceConfig,
|
||||
);
|
||||
|
||||
expect(keyVaults).toMatchObject({
|
||||
'provider-a': { apiKey: 'a-key' },
|
||||
});
|
||||
expect(keyVaults).not.toHaveProperty('provider-b');
|
||||
});
|
||||
|
||||
it('prefers configured providers/models for gatekeeper, embedding, and layer extractors', async () => {
|
||||
const executor = createExecutor({
|
||||
embeddingPreferredProviders: ['provider-c', 'provider-a'],
|
||||
@@ -119,7 +314,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
|
||||
},
|
||||
);
|
||||
|
||||
const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState);
|
||||
const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState);
|
||||
|
||||
expect(keyVaults).toMatchObject({
|
||||
'provider-a': { apiKey: 'a-key' },
|
||||
@@ -182,7 +377,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
|
||||
},
|
||||
);
|
||||
|
||||
const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState);
|
||||
const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState);
|
||||
|
||||
expect(keyVaults).toMatchObject({
|
||||
'provider-b': { apiKey: 'b-key' },
|
||||
@@ -222,7 +417,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
|
||||
},
|
||||
);
|
||||
|
||||
const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState);
|
||||
const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState);
|
||||
|
||||
expect(keyVaults).toMatchObject({
|
||||
'provider-a': { apiKey: 'a-key' },
|
||||
@@ -253,7 +448,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
|
||||
},
|
||||
);
|
||||
|
||||
const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState);
|
||||
const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState);
|
||||
|
||||
expect(keyVaults).toMatchObject({
|
||||
'provider-b': { apiKey: 'b-key' }, // picks first preferred provider
|
||||
@@ -271,7 +466,7 @@ describe('MemoryExtractionExecutor.resolveRuntimeKeyVaults', () => {
|
||||
'provider-fallback': { apiKey: 'fb-key' },
|
||||
});
|
||||
|
||||
const keyVaults = await (executor as any).resolveRuntimeKeyVaults(runtimeState);
|
||||
const keyVaults = await resolveRuntimeKeyVaults(executor, runtimeState);
|
||||
|
||||
expect(keyVaults).toMatchObject({
|
||||
'provider-fallback': { apiKey: 'fb-key' },
|
||||
|
||||
@@ -45,6 +45,7 @@ import type {
|
||||
MemoryExtractionAgentCallTrace,
|
||||
MemoryExtractionTraceError,
|
||||
MemoryExtractionTracePayload,
|
||||
UserServiceModelConfig,
|
||||
} from '@lobechat/types';
|
||||
import { RequestTrigger } from '@lobechat/types';
|
||||
import { type FlowControl } from '@upstash/qstash';
|
||||
@@ -583,6 +584,38 @@ type RuntimeBundle = {
|
||||
layerExtractor: ModelRuntime;
|
||||
};
|
||||
|
||||
interface MemoryExtractionModelConfig {
|
||||
embeddingsModel: string;
|
||||
gateModel: string;
|
||||
layerModels: Partial<Record<LayersEnum, string>>;
|
||||
observabilityS3: MemoryExtractionConfig['observabilityS3'];
|
||||
}
|
||||
|
||||
interface ResolvedMemoryServiceConfig {
|
||||
agents: {
|
||||
embedding: MemoryAgentConfig;
|
||||
gatekeeper: MemoryAgentConfig;
|
||||
layerExtractor: MemoryAgentConfig;
|
||||
};
|
||||
embeddingContextLimit?: number;
|
||||
extractorContextLimit?: number;
|
||||
modelConfig: MemoryExtractionModelConfig;
|
||||
overrides: {
|
||||
embedding: {
|
||||
model: boolean;
|
||||
provider: boolean;
|
||||
};
|
||||
gatekeeper: {
|
||||
model: boolean;
|
||||
provider: boolean;
|
||||
};
|
||||
layerExtractor: {
|
||||
model: boolean;
|
||||
provider: boolean;
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
export interface TopicExtractionJob {
|
||||
asyncTaskId?: string;
|
||||
forceAll: boolean;
|
||||
@@ -623,14 +656,7 @@ export class MemoryExtractionExecutor {
|
||||
private readonly layerPreferredModels?: string[];
|
||||
private readonly layerPreferredProviders?: string[];
|
||||
private readonly privateConfig: MemoryExtractionConfig;
|
||||
private readonly modelConfig: {
|
||||
embeddingsModel: string;
|
||||
gateModel: string;
|
||||
layerModels: Partial<Record<LayersEnum, string>>;
|
||||
observabilityS3: MemoryExtractionConfig['observabilityS3'];
|
||||
};
|
||||
private readonly embeddingContextLimit?: number;
|
||||
|
||||
private readonly modelConfig: MemoryExtractionModelConfig;
|
||||
private readonly runtimeCache = new Map<string, RuntimeBundle>();
|
||||
private readonly db = getServerDB();
|
||||
|
||||
@@ -659,9 +685,6 @@ export class MemoryExtractionExecutor {
|
||||
),
|
||||
observabilityS3: privateConfig.observabilityS3,
|
||||
};
|
||||
|
||||
this.embeddingContextLimit =
|
||||
privateConfig.embedding?.contextLimit ?? privateConfig.agentLayerExtractor.contextLimit;
|
||||
}
|
||||
|
||||
static async create() {
|
||||
@@ -673,6 +696,95 @@ export class MemoryExtractionExecutor {
|
||||
return new MemoryExtractionExecutor(serverConfig, privateConfig);
|
||||
}
|
||||
|
||||
private resolveUserMemoryAgent(
|
||||
systemAgent: Partial<UserServiceModelConfig> | undefined,
|
||||
key: keyof Pick<
|
||||
UserServiceModelConfig,
|
||||
'userMemoryEmbedding' | 'memoryAnalysisAgentConfig' | 'userMemoryPersonaWriter'
|
||||
>,
|
||||
fallback: MemoryAgentConfig,
|
||||
): MemoryAgentConfig {
|
||||
const override = systemAgent?.[key];
|
||||
const provider = override?.provider || fallback.provider;
|
||||
const shouldInheritCredentials =
|
||||
!override?.provider ||
|
||||
normalizeProvider(override.provider) === normalizeProvider(fallback.provider || 'openai');
|
||||
const contextLimit =
|
||||
typeof override?.contextLimit === 'number' &&
|
||||
Number.isFinite(override.contextLimit) &&
|
||||
override.contextLimit > 0
|
||||
? Math.floor(override.contextLimit)
|
||||
: fallback.contextLimit;
|
||||
|
||||
return {
|
||||
apiKey: shouldInheritCredentials ? fallback.apiKey : undefined,
|
||||
baseURL: shouldInheritCredentials ? fallback.baseURL : undefined,
|
||||
contextLimit,
|
||||
language: fallback.language,
|
||||
model: override?.model || fallback.model,
|
||||
provider,
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Builds the per-user memory service configuration by layering the user's service-model
 * settings over `this.privateConfig`.
 *
 * The gatekeeper and the layer extractor both read the `memoryAnalysisAgentConfig` entry;
 * the embedding agent reads `userMemoryEmbedding`.
 *
 * @param systemAgent - The user's service-model settings, if any.
 * @returns Resolved agents, context limits, model config, and flags recording which
 *   model/provider fields were supplied by the user.
 */
private resolveUserMemoryServiceConfig(
  systemAgent?: Partial<UserServiceModelConfig>,
): ResolvedMemoryServiceConfig {
  const {
    agentGateKeeper,
    agentLayerExtractor,
    embedding: embeddingFallback,
    observabilityS3,
  } = this.privateConfig;
  const analysisOverride = systemAgent?.memoryAnalysisAgentConfig;
  const embeddingOverride = systemAgent?.userMemoryEmbedding;

  const gatekeeper = this.resolveUserMemoryAgent(
    systemAgent,
    'memoryAnalysisAgentConfig',
    agentGateKeeper,
  );
  const layerExtractor = this.resolveUserMemoryAgent(
    systemAgent,
    'memoryAnalysisAgentConfig',
    agentLayerExtractor,
  );
  const embedding = this.resolveUserMemoryAgent(
    systemAgent,
    'userMemoryEmbedding',
    embeddingFallback,
  );

  // A user-selected analysis model applies to every layer; otherwise keep the configured
  // per-layer models.
  const sharedLayerModel = layerExtractor.model;
  const layerModels = analysisOverride?.model
    ? {
        activity: sharedLayerModel,
        context: sharedLayerModel,
        experience: sharedLayerModel,
        identity: sharedLayerModel,
        preference: sharedLayerModel,
      }
    : resolveLayerModels(undefined, agentLayerExtractor.layers);

  // Called once per consumer so gatekeeper and layerExtractor receive distinct objects.
  const analysisOverrideFlags = () => ({
    model: Boolean(analysisOverride?.model),
    provider: Boolean(analysisOverride?.provider),
  });

  return {
    agents: { embedding, gatekeeper, layerExtractor },
    embeddingContextLimit: embedding.contextLimit ?? layerExtractor.contextLimit,
    extractorContextLimit: layerExtractor.contextLimit,
    modelConfig: {
      embeddingsModel: embedding.model,
      gateModel: gatekeeper.model,
      layerModels,
      observabilityS3,
    },
    overrides: {
      embedding: {
        model: Boolean(embeddingOverride?.model),
        provider: Boolean(embeddingOverride?.provider),
      },
      gatekeeper: analysisOverrideFlags(),
      layerExtractor: analysisOverrideFlags(),
    },
  };
}
|
||||
|
||||
private buildBaseMetadata(
|
||||
job: MemoryExtractionJob,
|
||||
messageIds: string[],
|
||||
@@ -1431,10 +1543,16 @@ export class MemoryExtractionExecutor {
|
||||
userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults),
|
||||
this.getAiProviderRuntimeState(job.userId),
|
||||
]);
|
||||
const keyVaults = await this.resolveRuntimeKeyVaults(aiProviderRuntimeState);
|
||||
const memoryServiceConfig = this.resolveUserMemoryServiceConfig(
|
||||
userState.settings?.systemAgent as Partial<UserServiceModelConfig> | undefined,
|
||||
);
|
||||
const keyVaults = await this.resolveRuntimeKeyVaults(
|
||||
aiProviderRuntimeState,
|
||||
memoryServiceConfig,
|
||||
);
|
||||
const language = userState.settings?.general?.responseLanguage;
|
||||
|
||||
const runtimes = await this.getRuntime(job.userId, keyVaults);
|
||||
const runtimes = await this.getRuntime(job.userId, memoryServiceConfig, keyVaults);
|
||||
|
||||
const conversations = await this.listConversationsForTopic(
|
||||
job.userId,
|
||||
@@ -1454,8 +1572,9 @@ export class MemoryExtractionExecutor {
|
||||
};
|
||||
}
|
||||
|
||||
const extractorContextLimit = this.privateConfig.agentLayerExtractor.contextLimit;
|
||||
const embeddingContextLimit = this.embeddingContextLimit ?? extractorContextLimit;
|
||||
const extractorContextLimit = memoryServiceConfig.extractorContextLimit;
|
||||
const embeddingContextLimit =
|
||||
memoryServiceConfig.embeddingContextLimit ?? extractorContextLimit;
|
||||
const extractorConversations = await this.trimConversationsToTokenLimit(
|
||||
conversations,
|
||||
extractorContextLimit,
|
||||
@@ -1508,7 +1627,7 @@ export class MemoryExtractionExecutor {
|
||||
searchResult = await this.listRelevantUserMemories(
|
||||
extractionJob,
|
||||
runtimes.embeddings,
|
||||
this.modelConfig.embeddingsModel,
|
||||
memoryServiceConfig.modelConfig.embeddingsModel,
|
||||
job.userId,
|
||||
embeddingConversations,
|
||||
embeddingContextLimit,
|
||||
@@ -1621,7 +1740,7 @@ export class MemoryExtractionExecutor {
|
||||
};
|
||||
|
||||
const service = new MemoryExtractionService({
|
||||
config: this.modelConfig,
|
||||
config: memoryServiceConfig.modelConfig,
|
||||
db,
|
||||
language,
|
||||
runtimes,
|
||||
@@ -1676,6 +1795,7 @@ export class MemoryExtractionExecutor {
|
||||
messageIds,
|
||||
extraction,
|
||||
runtimes,
|
||||
memoryServiceConfig,
|
||||
db,
|
||||
);
|
||||
if (retrievalErrors.length > 0) {
|
||||
@@ -2023,6 +2143,7 @@ export class MemoryExtractionExecutor {
|
||||
messageIds: string[],
|
||||
extraction: MemoryExtractionResult,
|
||||
runtimes: RuntimeBundle,
|
||||
memoryServiceConfig: ResolvedMemoryServiceConfig,
|
||||
db: Awaited<ReturnType<typeof getServerDB>>,
|
||||
): Promise<PersistedMemoryResult> {
|
||||
const createdIds: string[] = [];
|
||||
@@ -2108,8 +2229,8 @@ export class MemoryExtractionExecutor {
|
||||
messageIds,
|
||||
activityOutput.data,
|
||||
runtimes.embeddings,
|
||||
this.modelConfig.embeddingsModel,
|
||||
this.embeddingContextLimit,
|
||||
memoryServiceConfig.modelConfig.embeddingsModel,
|
||||
memoryServiceConfig.embeddingContextLimit,
|
||||
db,
|
||||
),
|
||||
);
|
||||
@@ -2126,8 +2247,8 @@ export class MemoryExtractionExecutor {
|
||||
messageIds,
|
||||
contextOutput.data,
|
||||
runtimes.embeddings,
|
||||
this.modelConfig.embeddingsModel,
|
||||
this.embeddingContextLimit,
|
||||
memoryServiceConfig.modelConfig.embeddingsModel,
|
||||
memoryServiceConfig.embeddingContextLimit,
|
||||
db,
|
||||
),
|
||||
);
|
||||
@@ -2144,8 +2265,8 @@ export class MemoryExtractionExecutor {
|
||||
messageIds,
|
||||
experienceOutput.data,
|
||||
runtimes.embeddings,
|
||||
this.modelConfig.embeddingsModel,
|
||||
this.embeddingContextLimit,
|
||||
memoryServiceConfig.modelConfig.embeddingsModel,
|
||||
memoryServiceConfig.embeddingContextLimit,
|
||||
db,
|
||||
),
|
||||
);
|
||||
@@ -2162,8 +2283,8 @@ export class MemoryExtractionExecutor {
|
||||
messageIds,
|
||||
preferenceOutput.data,
|
||||
runtimes.embeddings,
|
||||
this.modelConfig.embeddingsModel,
|
||||
this.embeddingContextLimit,
|
||||
memoryServiceConfig.modelConfig.embeddingsModel,
|
||||
memoryServiceConfig.embeddingContextLimit,
|
||||
db,
|
||||
),
|
||||
);
|
||||
@@ -2180,8 +2301,8 @@ export class MemoryExtractionExecutor {
|
||||
messageIds,
|
||||
identityOutput.data,
|
||||
runtimes.embeddings,
|
||||
this.modelConfig.embeddingsModel,
|
||||
this.embeddingContextLimit,
|
||||
memoryServiceConfig.modelConfig.embeddingsModel,
|
||||
memoryServiceConfig.embeddingContextLimit,
|
||||
db,
|
||||
),
|
||||
);
|
||||
@@ -2210,6 +2331,7 @@ export class MemoryExtractionExecutor {
|
||||
|
||||
private async resolveRuntimeKeyVaults(
|
||||
runtimeState: AiProviderRuntimeState,
|
||||
memoryServiceConfig: ResolvedMemoryServiceConfig,
|
||||
): Promise<ProviderKeyVaultMap> {
|
||||
const normalizedRuntimeConfig = Object.fromEntries(
|
||||
Object.entries(runtimeState.runtimeConfig || {}).map(([providerId, config]) => [
|
||||
@@ -2221,11 +2343,15 @@ export class MemoryExtractionExecutor {
|
||||
const keyVaults: ProviderKeyVaultMap = {};
|
||||
|
||||
const gatekeeperProvider = await AiInfraRepos.tryMatchingProviderFrom(runtimeState, {
|
||||
fallbackProvider: this.privateConfig.agentGateKeeper.provider,
|
||||
fallbackProvider: memoryServiceConfig.agents.gatekeeper.provider,
|
||||
label: 'gatekeeper',
|
||||
modelId: this.modelConfig.gateModel,
|
||||
preferredModels: this.gatekeeperPreferredModels,
|
||||
preferredProviders: this.gatekeeperPreferredProviders,
|
||||
modelId: memoryServiceConfig.modelConfig.gateModel,
|
||||
preferredModels: memoryServiceConfig.overrides.gatekeeper.model
|
||||
? undefined
|
||||
: this.gatekeeperPreferredModels,
|
||||
preferredProviders: memoryServiceConfig.overrides.gatekeeper.provider
|
||||
? undefined
|
||||
: this.gatekeeperPreferredProviders,
|
||||
});
|
||||
const gatekeeperRuntime = normalizedRuntimeConfig[gatekeeperProvider];
|
||||
if (gatekeeperRuntime?.keyVaults) {
|
||||
@@ -2233,25 +2359,33 @@ export class MemoryExtractionExecutor {
|
||||
}
|
||||
|
||||
const embeddingProvider = await AiInfraRepos.tryMatchingProviderFrom(runtimeState, {
|
||||
fallbackProvider: this.privateConfig.embedding.provider,
|
||||
fallbackProvider: memoryServiceConfig.agents.embedding.provider,
|
||||
label: 'embedding',
|
||||
modelId: this.modelConfig.embeddingsModel,
|
||||
preferredModels: this.embeddingPreferredModels,
|
||||
preferredProviders: this.embeddingPreferredProviders,
|
||||
modelId: memoryServiceConfig.modelConfig.embeddingsModel,
|
||||
preferredModels: memoryServiceConfig.overrides.embedding.model
|
||||
? undefined
|
||||
: this.embeddingPreferredModels,
|
||||
preferredProviders: memoryServiceConfig.overrides.embedding.provider
|
||||
? undefined
|
||||
: this.embeddingPreferredProviders,
|
||||
});
|
||||
const embeddingRuntime = normalizedRuntimeConfig[embeddingProvider];
|
||||
if (embeddingRuntime?.keyVaults) {
|
||||
keyVaults[embeddingProvider] = embeddingRuntime.keyVaults;
|
||||
}
|
||||
|
||||
for (const model of Object.values(this.modelConfig.layerModels)) {
|
||||
for (const model of Object.values(memoryServiceConfig.modelConfig.layerModels)) {
|
||||
if (!model) continue;
|
||||
const providerId = await AiInfraRepos.tryMatchingProviderFrom(runtimeState, {
|
||||
fallbackProvider: this.privateConfig.agentLayerExtractor.provider,
|
||||
fallbackProvider: memoryServiceConfig.agents.layerExtractor.provider,
|
||||
label: 'layer extractor',
|
||||
modelId: model,
|
||||
preferredModels: this.layerPreferredModels,
|
||||
preferredProviders: this.layerPreferredProviders,
|
||||
preferredModels: memoryServiceConfig.overrides.layerExtractor.model
|
||||
? undefined
|
||||
: this.layerPreferredModels,
|
||||
preferredProviders: memoryServiceConfig.overrides.layerExtractor.provider
|
||||
? undefined
|
||||
: this.layerPreferredProviders,
|
||||
});
|
||||
const runtime = normalizedRuntimeConfig[providerId];
|
||||
if (runtime?.keyVaults) {
|
||||
@@ -2264,6 +2398,7 @@ export class MemoryExtractionExecutor {
|
||||
|
||||
private async getRuntime(
|
||||
userId: string,
|
||||
memoryServiceConfig: ResolvedMemoryServiceConfig,
|
||||
keyVaults?: ProviderKeyVaultMap,
|
||||
): Promise<RuntimeBundle> {
|
||||
// TODO: implement a better cache eviction strategy
|
||||
@@ -2272,33 +2407,51 @@ export class MemoryExtractionExecutor {
|
||||
this.runtimeCache.clear();
|
||||
}
|
||||
|
||||
const cached = this.runtimeCache.get(userId);
|
||||
const cacheKey = [
|
||||
userId,
|
||||
memoryServiceConfig.agents.embedding.provider,
|
||||
memoryServiceConfig.agents.gatekeeper.provider,
|
||||
memoryServiceConfig.agents.layerExtractor.provider,
|
||||
].join(':');
|
||||
const cached = this.runtimeCache.get(cacheKey);
|
||||
if (cached) return cached;
|
||||
|
||||
const embeddingOptions: RuntimeResolveOptions = {
|
||||
fallback: {
|
||||
apiKey: this.privateConfig.embedding.apiKey,
|
||||
baseURL: this.privateConfig.embedding.baseURL,
|
||||
apiKey: memoryServiceConfig.agents.embedding.apiKey,
|
||||
baseURL: memoryServiceConfig.agents.embedding.baseURL,
|
||||
},
|
||||
preferred: {
|
||||
providerIds: memoryServiceConfig.overrides.embedding.provider
|
||||
? undefined
|
||||
: this.embeddingPreferredProviders,
|
||||
},
|
||||
preferred: { providerIds: this.embeddingPreferredProviders },
|
||||
userId,
|
||||
};
|
||||
|
||||
const gatekeeperOptions: RuntimeResolveOptions = {
|
||||
fallback: {
|
||||
apiKey: this.privateConfig.agentGateKeeper.apiKey,
|
||||
baseURL: this.privateConfig.agentGateKeeper.baseURL,
|
||||
apiKey: memoryServiceConfig.agents.gatekeeper.apiKey,
|
||||
baseURL: memoryServiceConfig.agents.gatekeeper.baseURL,
|
||||
},
|
||||
preferred: {
|
||||
providerIds: memoryServiceConfig.overrides.gatekeeper.provider
|
||||
? undefined
|
||||
: this.gatekeeperPreferredProviders,
|
||||
},
|
||||
preferred: { providerIds: this.gatekeeperPreferredProviders },
|
||||
userId,
|
||||
};
|
||||
|
||||
const layerExtractorOptions: RuntimeResolveOptions = {
|
||||
fallback: {
|
||||
apiKey: this.privateConfig.agentLayerExtractor.apiKey,
|
||||
baseURL: this.privateConfig.agentLayerExtractor.baseURL,
|
||||
apiKey: memoryServiceConfig.agents.layerExtractor.apiKey,
|
||||
baseURL: memoryServiceConfig.agents.layerExtractor.baseURL,
|
||||
},
|
||||
preferred: {
|
||||
providerIds: memoryServiceConfig.overrides.layerExtractor.provider
|
||||
? undefined
|
||||
: this.layerPreferredProviders,
|
||||
},
|
||||
preferred: { providerIds: this.layerPreferredProviders },
|
||||
userId,
|
||||
};
|
||||
|
||||
@@ -2306,26 +2459,26 @@ export class MemoryExtractionExecutor {
|
||||
|
||||
const runtimes: RuntimeBundle = {
|
||||
embeddings: await resolveRuntimeAgentConfig(
|
||||
{ ...this.privateConfig.embedding },
|
||||
memoryServiceConfig.agents.embedding,
|
||||
keyVaults,
|
||||
embeddingOptions,
|
||||
hooks,
|
||||
),
|
||||
gatekeeper: await resolveRuntimeAgentConfig(
|
||||
{ ...this.privateConfig.agentGateKeeper },
|
||||
memoryServiceConfig.agents.gatekeeper,
|
||||
keyVaults,
|
||||
gatekeeperOptions,
|
||||
hooks,
|
||||
),
|
||||
layerExtractor: await resolveRuntimeAgentConfig(
|
||||
{ ...this.privateConfig.agentLayerExtractor },
|
||||
memoryServiceConfig.agents.layerExtractor,
|
||||
keyVaults,
|
||||
layerExtractorOptions,
|
||||
hooks,
|
||||
),
|
||||
};
|
||||
|
||||
this.runtimeCache.set(userId, runtimes);
|
||||
this.runtimeCache.set(cacheKey, runtimes);
|
||||
|
||||
return runtimes;
|
||||
}
|
||||
@@ -2361,10 +2514,16 @@ export class MemoryExtractionExecutor {
|
||||
userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults),
|
||||
this.getAiProviderRuntimeState(params.userId),
|
||||
]);
|
||||
const keyVaults = await this.resolveRuntimeKeyVaults(aiProviderRuntimeState);
|
||||
const memoryServiceConfig = this.resolveUserMemoryServiceConfig(
|
||||
userState.settings?.systemAgent as Partial<UserServiceModelConfig> | undefined,
|
||||
);
|
||||
const keyVaults = await this.resolveRuntimeKeyVaults(
|
||||
aiProviderRuntimeState,
|
||||
memoryServiceConfig,
|
||||
);
|
||||
const language = params.language || userState.settings?.general?.responseLanguage;
|
||||
|
||||
const runtimes = await this.getRuntime(params.userId, keyVaults);
|
||||
const runtimes = await this.getRuntime(params.userId, memoryServiceConfig, keyVaults);
|
||||
const contextProvider =
|
||||
params.contextProvider ||
|
||||
new BenchmarkLocomoContextProvider({
|
||||
@@ -2393,7 +2552,7 @@ export class MemoryExtractionExecutor {
|
||||
};
|
||||
|
||||
const builtContext = await contextProvider.buildContext(extractionJob.userId);
|
||||
const extractorContextLimit = this.privateConfig.agentLayerExtractor.contextLimit;
|
||||
const extractorContextLimit = memoryServiceConfig.extractorContextLimit;
|
||||
const trimmedContext = await this.trimTextToTokenLimit(
|
||||
builtContext.context,
|
||||
extractorContextLimit,
|
||||
@@ -2431,7 +2590,7 @@ export class MemoryExtractionExecutor {
|
||||
};
|
||||
|
||||
const service = new MemoryExtractionService({
|
||||
config: this.modelConfig,
|
||||
config: memoryServiceConfig.modelConfig,
|
||||
db,
|
||||
language,
|
||||
runtimes,
|
||||
@@ -2476,6 +2635,7 @@ export class MemoryExtractionExecutor {
|
||||
[],
|
||||
extraction,
|
||||
runtimes,
|
||||
memoryServiceConfig,
|
||||
db,
|
||||
);
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
// @vitest-environment node
|
||||
import { type LobeChatDatabase } from '@lobechat/database';
|
||||
import { users } from '@lobechat/database/schemas';
|
||||
import { users, userSettings } from '@lobechat/database/schemas';
|
||||
import { getTestDB } from '@lobechat/database/test-utils';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { UserPersonaModel } from '@/database/models/userMemory/persona';
|
||||
import type * as AiInfraReposModule from '@/database/repositories/aiInfra';
|
||||
import { resolveRuntimeAgentConfig } from '@/server/services/memory/userMemory/extract';
|
||||
|
||||
import { UserPersonaService } from '../service';
|
||||
|
||||
@@ -134,4 +135,46 @@ describe('UserPersonaService', () => {
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('drops fallback credentials when persona writer provider is overridden', async () => {
  // Persist a per-user persona-writer override that names its own provider and model.
  const overriddenWriter = { model: 'claude-mock', provider: 'anthropic' };
  await db.insert(userSettings).values({
    id: userId,
    systemAgent: { userMemoryPersonaWriter: overriddenWriter },
  });

  // Provider matching resolves to the overridden provider; the runtime state carries no key vaults.
  aiInfraMocks.tryMatchingProviderFrom.mockResolvedValue('anthropic');
  aiInfraMocks.getAiProviderRuntimeState.mockResolvedValue({
    enabledAiModels: [
      { abilities: {}, enabled: true, id: 'claude-mock', providerId: 'anthropic', type: 'chat' },
    ],
    enabledAiProviders: [],
    enabledChatAiProviders: [],
    enabledImageAiProviders: [],
    runtimeConfig: {},
  });

  const service = new UserPersonaService(db);
  await service.composeWriting({ userId, username: 'User' });

  // Both the agent config and the runtime fallback must arrive without credentials.
  const withoutCredentials = { apiKey: undefined, baseURL: undefined };
  expect(resolveRuntimeAgentConfig).toHaveBeenLastCalledWith(
    expect.objectContaining({ ...withoutCredentials, ...overriddenWriter }),
    expect.any(Object),
    expect.objectContaining({ fallback: withoutCredentials }),
    undefined,
  );
});
|
||||
});
|
||||
|
||||
@@ -9,9 +9,11 @@ import {
|
||||
RetrievalUserMemoryIdentitiesProvider,
|
||||
UserPersonaExtractor,
|
||||
} from '@lobechat/memory-user-memory';
|
||||
import type { UserServiceModelConfig } from '@lobechat/types';
|
||||
import { desc, eq } from 'drizzle-orm';
|
||||
|
||||
import { getBusinessModelRuntimeHooks } from '@/business/server/model-runtime';
|
||||
import { UserModel } from '@/database/models/user';
|
||||
import { UserMemoryModel } from '@/database/models/userMemory';
|
||||
import { UserPersonaModel } from '@/database/models/userMemory/persona';
|
||||
import { AiInfraRepos } from '@/database/repositories/aiInfra';
|
||||
@@ -47,6 +49,14 @@ interface UserPersonaAgentResult {
|
||||
document: UserPersonaDocument;
|
||||
}
|
||||
|
||||
/**
 * Coerces an optional numeric setting to a positive integer.
 *
 * @param value - Candidate number; may be omitted.
 * @returns `Math.floor(value)` when that result is at least 1; `undefined` for a missing,
 *   non-numeric, non-finite, zero, or negative value, and for fractions below 1.
 */
const resolvePositiveInteger = (value?: number): number | undefined => {
  if (typeof value !== 'number' || !Number.isFinite(value)) return undefined;

  // Floor BEFORE the positivity check: a fraction in (0, 1) would otherwise floor to 0,
  // and callers that chain `??` do not treat 0 as missing.
  const floored = Math.floor(value);

  return floored > 0 ? floored : undefined;
};
|
||||
|
||||
/** Lower-cases a provider id so that provider comparisons ignore letter case. */
const normalizeProvider = (provider: string): string => {
  return provider.toLowerCase();
};
|
||||
|
||||
export class UserPersonaService {
|
||||
private readonly preferredLanguage?: string;
|
||||
private readonly db: LobeChatDatabase;
|
||||
@@ -60,15 +70,40 @@ export class UserPersonaService {
|
||||
this.agentConfig = agentPersonaWriter;
|
||||
}
|
||||
|
||||
/**
 * Resolves the persona-writer agent config for one user by layering the user's
 * `userMemoryPersonaWriter` setting over `this.agentConfig`.
 *
 * @param userId - User whose settings are loaded through `UserModel`.
 * @returns The effective agent config. `apiKey` and `baseURL` come from `this.agentConfig`
 *   only when the user names no provider, or names the same provider after
 *   `normalizeProvider`; otherwise both are `undefined`. `contextLimit` uses the user's
 *   value when `resolvePositiveInteger` accepts it.
 */
private async resolveAgentConfig(userId: string): Promise<MemoryAgentConfig> {
  const settings = await new UserModel(this.db, userId).getUserSettings();
  const systemAgent = settings?.systemAgent as Partial<UserServiceModelConfig> | undefined;
  const writerOverride = systemAgent?.userMemoryPersonaWriter;
  const base = this.agentConfig;

  const overrideProvider = writerOverride?.provider;
  // Credentials configured for the base provider must not be sent to a different provider.
  const keepsBaseProvider =
    !overrideProvider ||
    normalizeProvider(overrideProvider) === normalizeProvider(base.provider || 'openai');

  return {
    apiKey: keepsBaseProvider ? base.apiKey : undefined,
    baseURL: keepsBaseProvider ? base.baseURL : undefined,
    contextLimit: resolvePositiveInteger(writerOverride?.contextLimit) ?? base.contextLimit,
    language: base.language,
    model: writerOverride?.model || base.model,
    provider: overrideProvider || base.provider,
  };
}
|
||||
|
||||
async composeWriting(payload: UserPersonaAgentPayload): Promise<UserPersonaAgentResult> {
|
||||
const agentConfig = await this.resolveAgentConfig(payload.userId);
|
||||
const aiInfraRepos = new AiInfraRepos(this.db, payload.userId, {});
|
||||
const runtimeState = await aiInfraRepos.getAiProviderRuntimeState(
|
||||
KeyVaultsGateKeeper.getUserKeyVaults,
|
||||
);
|
||||
const providerId = await AiInfraRepos.tryMatchingProviderFrom(runtimeState, {
|
||||
fallbackProvider: this.agentConfig.provider,
|
||||
fallbackProvider: agentConfig.provider,
|
||||
label: 'persona writer',
|
||||
modelId: this.agentConfig.model,
|
||||
modelId: agentConfig.model,
|
||||
});
|
||||
|
||||
const keyVaults: ProviderKeyVaultMap = Object.entries(runtimeState.runtimeConfig || {}).reduce(
|
||||
@@ -82,12 +117,12 @@ export class UserPersonaService {
|
||||
const hooks = getBusinessModelRuntimeHooks(payload.userId, 'lobehub');
|
||||
|
||||
const runtime = await resolveRuntimeAgentConfig(
|
||||
{ ...this.agentConfig },
|
||||
agentConfig,
|
||||
keyVaults,
|
||||
{
|
||||
fallback: {
|
||||
apiKey: this.agentConfig.apiKey,
|
||||
baseURL: this.agentConfig.baseURL,
|
||||
apiKey: agentConfig.apiKey,
|
||||
baseURL: agentConfig.baseURL,
|
||||
},
|
||||
preferred: { providerIds: [providerId] },
|
||||
userId: payload.userId,
|
||||
@@ -101,7 +136,7 @@ export class UserPersonaService {
|
||||
|
||||
const extractor = new UserPersonaExtractor({
|
||||
agent: 'user-persona',
|
||||
model: this.agentConfig.model,
|
||||
model: agentConfig.model,
|
||||
modelRuntime: runtime,
|
||||
});
|
||||
|
||||
@@ -136,7 +171,14 @@ export const buildUserPersonaJobInput = async (db: LobeChatDatabase, userId: str
|
||||
const personaModel = new UserPersonaModel(db, userId);
|
||||
const latestPersona = await personaModel.getLatestPersonaDocument();
|
||||
const { agentPersonaWriter } = parseMemoryExtractionConfig();
|
||||
const personaContextLimit = agentPersonaWriter.contextLimit;
|
||||
const userModel = new UserModel(db, userId);
|
||||
const settings = await userModel.getUserSettings();
|
||||
const userMemoryPersonaWriter = (
|
||||
settings?.systemAgent as Partial<UserServiceModelConfig> | undefined
|
||||
)?.userMemoryPersonaWriter;
|
||||
const personaContextLimit =
|
||||
resolvePositiveInteger(userMemoryPersonaWriter?.contextLimit) ??
|
||||
agentPersonaWriter.contextLimit;
|
||||
|
||||
const userMemoryModel = new UserMemoryModel(db, userId);
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import type {
|
||||
SystemAgentItem,
|
||||
UserGeneralConfig,
|
||||
UserKeyVaults,
|
||||
UserServiceModelConfigKey,
|
||||
UserSettings,
|
||||
UserSystemAgentConfigKey,
|
||||
} from '@/types/user/settings';
|
||||
@@ -237,7 +238,7 @@ export class UserSettingsActionImpl {
|
||||
};
|
||||
|
||||
updateSystemAgent = async (
|
||||
key: UserSystemAgentConfigKey,
|
||||
key: UserServiceModelConfigKey,
|
||||
value: Partial<SystemAgentItem>,
|
||||
): Promise<void> => {
|
||||
await this.#get().setSettings({
|
||||
|
||||
@@ -73,6 +73,10 @@ exports[`settingsSelectors > currentSystemAgent > should merge DEFAULT_SYSTEM_AG
|
||||
"model": "gpt-5.4-mini",
|
||||
"provider": "openai",
|
||||
},
|
||||
"memoryAnalysisAgentConfig": {
|
||||
"model": "gpt-5.4-mini",
|
||||
"provider": "openai",
|
||||
},
|
||||
"promptRewrite": {
|
||||
"enabled": true,
|
||||
"model": "gpt-5.4-mini",
|
||||
@@ -91,6 +95,14 @@ exports[`settingsSelectors > currentSystemAgent > should merge DEFAULT_SYSTEM_AG
|
||||
"model": "gpt-5.4-mini",
|
||||
"provider": "openai",
|
||||
},
|
||||
"userMemoryEmbedding": {
|
||||
"model": "text-embedding-3-small",
|
||||
"provider": "openai",
|
||||
},
|
||||
"userMemoryPersonaWriter": {
|
||||
"model": "gpt-5.4-mini",
|
||||
"provider": "openai",
|
||||
},
|
||||
}
|
||||
`;
|
||||
|
||||
|
||||
@@ -16,6 +16,14 @@ const alias = {
|
||||
__dirname,
|
||||
'./packages/business/model-runtime/src/index.ts',
|
||||
),
|
||||
'@lobechat/business-model-bank/model-config': resolve(
|
||||
__dirname,
|
||||
'./packages/business/model-bank/src/model-config.ts',
|
||||
),
|
||||
'@lobechat/business-model-bank': resolve(
|
||||
__dirname,
|
||||
'./packages/business/model-bank/src/index.ts',
|
||||
),
|
||||
'@emoji-mart/data': resolve(__dirname, './tests/mocks/emojiMartData.ts'),
|
||||
'@emoji-mart/react': resolve(__dirname, './tests/mocks/emojiMartReact.tsx'),
|
||||
'@/database/_deprecated': resolve(__dirname, './src/database/_deprecated'),
|
||||
|
||||
Reference in New Issue
Block a user