♻️ refactor: refactor service

This commit is contained in:
arvinxx
2025-12-20 22:19:42 +08:00
parent 78d07c0504
commit 91bbbf5cb0
33 changed files with 2723 additions and 1379 deletions
+24
View File
@@ -0,0 +1,24 @@
'use client';
import { PropsWithChildren, memo, useEffect } from 'react';
import { createStoreUpdater } from 'zustand-utils';
import { useUserStore } from '@/store/user';
const DesktopAuthProvider = memo<PropsWithChildren>(({ children }) => {
const useStoreUpdater = createStoreUpdater(useUserStore);
const isUserStateInit = useUserStore((s) => s.isUserStateInit);
useStoreUpdater('isLoaded', true);
// Desktop mode uses local auth (DESKTOP_USER_ID) on server,
// so client should be treated as signed-in to enable data initialization.
useEffect(() => {
if (isUserStateInit) {
useUserStore.setState({ isSignedIn: true });
}
}, [isUserStateInit]);
return children;
});
export default DesktopAuthProvider;
@@ -1,107 +1,12 @@
'use client';
import { Modal } from '@lobehub/ui';
import { createStyles } from 'antd-style';
import { ArrowRight, ShieldCheck } from 'lucide-react';
import { Block, Modal, Text } from '@lobehub/ui';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { Trans, useTranslation } from 'react-i18next';
const useStyles = createStyles(({ css, token }) => ({
content: css`
.ant-modal-content {
overflow: hidden;
padding: 0;
}
.ant-modal-header {
margin-block-end: 0;
padding: 0;
border-block-end: none;
}
.ant-modal-body {
padding: 0;
}
.ant-modal-footer {
display: flex;
align-items: center;
justify-content: space-between;
margin-block-start: 0;
padding-block: 16px;
padding-inline: 24px;
border-block-start: 1px solid ${token.colorBorder};
background: ${token.colorBgContainer};
.ant-btn {
margin: 0;
}
}
`,
description: css`
font-size: 14px;
line-height: 1.5;
color: ${token.colorTextSecondary};
text-align: center;
a {
color: ${token.colorPrimary};
text-decoration: none;
&:hover {
text-decoration: underline;
}
}
.highlight {
font-weight: 500;
color: ${token.colorText};
}
`,
header: css`
padding-block: 24px 16px;
padding-inline: 24px;
text-align: center;
`,
iconWrapper: css`
display: flex;
flex-direction: column;
align-items: center;
padding-block: 32px 0;
padding-inline: 0;
`,
okButton: css`
display: flex;
gap: 8px;
align-items: center;
`,
shieldIcon: css`
display: flex;
align-items: center;
justify-content: center;
width: 64px;
height: 64px;
border-radius: 50%;
background: ${token.colorPrimaryBg};
svg {
width: 36px;
height: 36px;
color: ${token.colorPrimary};
}
`,
title: css`
margin-block-end: 24px;
font-size: 18px;
font-weight: 600;
color: ${token.colorText};
`,
}));
import { BRANDING_NAME } from '@/const/branding';
import { PRIVACY_URL, TERMS_URL } from '@/const/url';
import AuthCard from '@/features/AuthCard';
interface MarketAuthConfirmModalProps {
onCancel: () => void;
@@ -112,42 +17,58 @@ interface MarketAuthConfirmModalProps {
const MarketAuthConfirmModal = memo<MarketAuthConfirmModalProps>(
({ open, onConfirm, onCancel }) => {
const { t } = useTranslation('marketAuth');
const { styles } = useStyles();
const footer = (
<Text align={'center'} as={'div'} fontSize={13} type={'secondary'}>
<Trans
components={{
privacy: (
<a
href={PRIVACY_URL}
style={{ color: 'inherit', cursor: 'pointer', textDecoration: 'underline' }}
>
{t('authorize.footer.terms')}
</a>
),
terms: (
<a
href={TERMS_URL}
style={{ color: 'inherit', cursor: 'pointer', textDecoration: 'underline' }}
>
{t('authorize.footer.privacy')}
</a>
),
}}
i18nKey={'authorize.footer.agreement'}
ns={'marketAuth'}
/>
</Text>
);
return (
<Modal
cancelText={t('authorize.cancel')}
className={styles.content}
okButtonProps={{
className: styles.okButton,
icon: <ArrowRight size={16} />,
}}
centered
okText={t('authorize.confirm')}
onCancel={onCancel}
onOk={onConfirm}
open={open}
paddings={{
desktop: 24,
}}
title={null}
width={440}
>
<div className={styles.iconWrapper}>
<div className={styles.shieldIcon}>
<ShieldCheck />
</div>
</div>
<div className={styles.header}>
<div className={styles.title}>{t('authorize.title')}</div>
<div className={styles.description}>
{t('authorize.description.prefix')} <span className="highlight">LobeHub</span>{' '}
<a href="https://lobehub.com/terms" rel="noopener noreferrer" target="_blank">
{t('authorize.description.terms')}
</a>{' '}
{t('authorize.description.and')}{' '}
<a href="https://lobehub.com/privacy" rel="noopener noreferrer" target="_blank">
{t('authorize.description.privacy')}
</a>
</div>
</div>
<AuthCard
footer={footer}
paddingBlock={'40px 20px'}
subtitle={t('authorize.subtitle')}
title={t('authorize.title')}
width={'100%'}
>
<Block padding={16} variant={'filled'}>
<Text align={'center'}>{t('authorize.description', { appName: BRANDING_NAME })}</Text>
</Block>
</AuthCard>
</Modal>
);
},
@@ -1,17 +1,26 @@
'use client';
import { App } from 'antd';
import { ReactNode, createContext, useContext, useEffect, useState } from 'react';
import { ReactNode, createContext, useCallback, useContext, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { mutate as globalMutate } from 'swr';
import { MARKET_OIDC_ENDPOINTS } from '@/services/_url';
import { MARKET_ENDPOINTS, MARKET_OIDC_ENDPOINTS } from '@/services/_url';
import { useUserStore } from '@/store/user';
import { settingsSelectors } from '@/store/user/slices/settings/selectors/settings';
import MarketAuthConfirmModal from './MarketAuthConfirmModal';
import ProfileSetupModal from './ProfileSetupModal';
import { MarketAuthError } from './errors';
import { MarketOIDC } from './oidc';
import { MarketAuthContextType, MarketAuthSession, MarketUserInfo, OIDCConfig } from './types';
import {
MarketAuthContextType,
MarketAuthSession,
MarketUserInfo,
MarketUserProfile,
OIDCConfig,
} from './types';
import { useMarketUserProfile } from './useMarketUserProfile';
const MarketAuthContext = createContext<MarketAuthContextType | null>(null);
@@ -21,44 +30,7 @@ interface MarketAuthProviderProps {
}
/**
* 从 cookie 中获取 token
*/
const getTokenFromCookie = (): string | null => {
if (typeof document === 'undefined') return null;
// eslint-disable-next-line unicorn/no-document-cookie
const cookies = document.cookie.split(';');
for (const cookie of cookies) {
const [name, value] = cookie.trim().split('=');
if (name === 'market-bearertoken') {
console.log('[MarketAuth] Found market token in cookie');
return value;
}
}
return null;
};
/**
* 将 token 存储到 cookie
*/
const setTokenToCookie = (token: string, expiresIn: number) => {
console.log('[MarketAuth] Storing token to cookie');
const expiresAt = new Date(Date.now() + expiresIn * 1000);
// eslint-disable-next-line unicorn/no-document-cookie
document.cookie = `market-bearertoken=${token}; expires=${expiresAt.toUTCString()}; path=/; secure; samesite=strict`;
};
/**
* 从 cookie 中删除 token
*/
const removeTokenFromCookie = () => {
console.log('[MarketAuth] Removing token from cookie');
// eslint-disable-next-line unicorn/no-document-cookie
document.cookie = 'market-bearertoken=; expires=Thu, 01 Jan 1970 00:00:00 UTC; path=/;';
};
/**
* 获取用户信息
* 获取用户信息(从 OIDC userinfo endpoint
*/
const fetchUserInfo = async (accessToken: string): Promise<MarketUserInfo | null> => {
try {
@@ -70,8 +42,6 @@ const fetchUserInfo = async (accessToken: string): Promise<MarketUserInfo | null
method: 'POST',
});
console.log('[MarketAuth] User info response:', response);
if (!response.ok) {
console.error(
'[MarketAuth] Failed to fetch user info:',
@@ -82,7 +52,6 @@ const fetchUserInfo = async (accessToken: string): Promise<MarketUserInfo | null
}
const userInfo = (await response.json()) as MarketUserInfo;
console.log('[MarketAuth] User info fetched successfully:', userInfo);
return userInfo;
} catch (error) {
@@ -107,7 +76,6 @@ const saveMarketTokensToDB = async (
refreshToken?: string,
expiresAt?: number,
) => {
console.log('[MarketAuth] Saving tokens to DB');
try {
await useUserStore.getState().setSettings({
market: {
@@ -116,7 +84,6 @@ const saveMarketTokensToDB = async (
refreshToken,
},
});
console.log('[MarketAuth] Tokens saved to DB successfully');
} catch (error) {
console.error('[MarketAuth] Failed to save tokens to DB:', error);
}
@@ -126,16 +93,16 @@ const saveMarketTokensToDB = async (
* 清除 DB 中的 market tokens
*/
const clearMarketTokensFromDB = async () => {
console.log('[MarketAuth] Clearing tokens from DB');
// 如果已经没有 tokens,不需要调用 setSettings
const currentTokens = getMarketTokensFromDB();
if (!currentTokens?.accessToken && !currentTokens?.refreshToken && !currentTokens?.expiresAt) {
return;
}
try {
await useUserStore.getState().setSettings({
market: {
accessToken: undefined,
expiresAt: undefined,
refreshToken: undefined,
},
market: undefined,
});
console.log('[MarketAuth] Tokens cleared from DB successfully');
} catch (error) {
console.error('[MarketAuth] Failed to clear tokens from DB:', error);
}
@@ -148,11 +115,9 @@ const getRefreshToken = (): string | null => {
// 优先从 DB 获取
const dbTokens = getMarketTokensFromDB();
if (dbTokens?.refreshToken) {
console.log('[MarketAuth] Retrieved refresh token from DB');
return dbTokens.refreshToken;
}
console.log('[MarketAuth] No refresh token found');
return null;
};
@@ -160,10 +125,28 @@ const getRefreshToken = (): string | null => {
* 刷新令牌(暂时简化,后续可以实现 refresh token 逻辑)
*/
const refreshToken = async (): Promise<boolean> => {
console.log('[MarketAuth] Refresh token not implemented yet');
return false;
};
/**
* 检查用户是否需要设置用户名(首次登录)
*/
const checkNeedsProfileSetup = async (username: string): Promise<boolean> => {
try {
const response = await fetch(MARKET_ENDPOINTS.getUserProfile(username));
if (!response.ok) {
// User profile not found, needs setup
return true;
}
const profile = (await response.json()) as MarketUserProfile;
// If userName is not set, user needs to complete profile setup
return !profile.userName;
} catch {
// Error fetching profile, assume needs setup
return true;
}
};
/**
* Market 授权上下文提供者
*/
@@ -174,14 +157,21 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
const [session, setSession] = useState<MarketAuthSession | null>(null);
const [status, setStatus] = useState<'loading' | 'authenticated' | 'unauthenticated'>('loading');
const [oidcClient, setOidcClient] = useState<MarketOIDC | null>(null);
const [shouldReauthorize, setShouldReauthorize] = useState(false);
const [showConfirmModal, setShowConfirmModal] = useState(false);
const [showProfileSetupModal, setShowProfileSetupModal] = useState(false);
const [isFirstTimeSetup, setIsFirstTimeSetup] = useState(false);
const [pendingSignInResolve, setPendingSignInResolve] = useState<
((value: number | null) => void) | null
((_value: number | null) => void) | null
>(null);
const [pendingSignInReject, setPendingSignInReject] = useState<((reason?: any) => void) | null>(
const [pendingSignInReject, setPendingSignInReject] = useState<((_reason?: any) => void) | null>(
null,
);
const [pendingProfileSuccessCallback, setPendingProfileSuccessCallback] = useState<
((profile: MarketUserProfile) => void) | null
>(null);
// 订阅 user store 的初始化状态,当 isUserStateInit 为 true 时,settings 数据已加载完成
const isUserStateInit = useUserStore((s) => s.isUserStateInit);
// 初始化 OIDC 客户端(仅在客户端)
useEffect(() => {
@@ -205,117 +195,54 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
}, [isDesktop]);
/**
* 检查并恢复会话
* 初始化:检查并恢复会话,获取用户信息
*/
const restoreSession = () => {
console.log('[MarketAuth] Attempting to restore session');
const initializeSession = async () => {
setStatus('loading');
// 优先级 1: 从 DB 中获取 token(优先级最高)
const dbTokens = getMarketTokensFromDB();
if (dbTokens?.accessToken && dbTokens?.expiresAt) {
// 检查 DB 中的 token 是否过期
if (dbTokens.expiresAt > Date.now()) {
console.log('[MarketAuth] Session restored from DB');
// 尝试从 sessionStorage 获取用户信息(如果有的话)
let userInfo: MarketUserInfo | undefined;
const userInfoData = sessionStorage.getItem('market_user_info');
if (userInfoData) {
try {
userInfo = JSON.parse(userInfoData);
} catch (error) {
console.error('[MarketAuth] Failed to parse stored user info:', error);
}
}
// 创建会话对象
const restoredSession: MarketAuthSession = {
accessToken: dbTokens.accessToken,
expiresAt: dbTokens.expiresAt,
expiresIn: Math.floor((dbTokens.expiresAt - Date.now()) / 1000),
scope: 'openid profile email',
tokenType: 'Bearer',
userInfo,
};
// 同步到 cookie 和 sessionStorage
setTokenToCookie(dbTokens.accessToken, restoredSession.expiresIn);
sessionStorage.setItem('market_auth_session', JSON.stringify(restoredSession));
setSession(restoredSession);
setStatus('authenticated');
return;
} else {
console.log('[MarketAuth] DB token has expired, will trigger re-authorization');
// 清理过期的 DB tokens
clearMarketTokensFromDB();
sessionStorage.removeItem('market_auth_session');
removeTokenFromCookie();
// 标记需要重新授权,等待 oidcClient 准备好
setShouldReauthorize(true);
return;
}
// 检查 DB 中是否有 token
if (!dbTokens?.accessToken) {
setStatus('unauthenticated');
return;
}
// 优先级 2: 从 cookie 和 sessionStorage 中获取(DB 中没有时的备选方案)
const token = getTokenFromCookie();
if (token) {
// 从 sessionStorage 中获取完整的会话信息
const sessionData = sessionStorage.getItem('market_auth_session');
if (sessionData) {
try {
const parsedSession = JSON.parse(sessionData) as MarketAuthSession;
// 检查 token 是否过期
if (parsedSession.expiresAt > Date.now()) {
console.log('[MarketAuth] Session restored from cookie/sessionStorage');
// 如果 session 中没有 userInfo,尝试从单独的存储中获取
if (!parsedSession.userInfo) {
const userInfoData = sessionStorage.getItem('market_user_info');
if (userInfoData) {
try {
parsedSession.userInfo = JSON.parse(userInfoData);
} catch (error) {
console.error('[MarketAuth] Failed to parse stored user info:', error);
}
} else {
setShouldReauthorize(true);
}
}
// 同步到 DB
saveMarketTokensToDB(parsedSession.accessToken, undefined, parsedSession.expiresAt);
setSession(parsedSession);
setStatus('authenticated');
return;
} else {
console.log('[MarketAuth] Stored session has expired, will trigger re-authorization');
sessionStorage.removeItem('market_auth_session');
removeTokenFromCookie();
// 标记需要重新授权,等待 oidcClient 准备好
setShouldReauthorize(true);
return;
}
} catch (error) {
console.error('[MarketAuth] Failed to parse stored session:', error);
sessionStorage.removeItem('market_auth_session');
removeTokenFromCookie();
}
}
// 检查 token 是否过期
if (!dbTokens.expiresAt || dbTokens.expiresAt <= Date.now()) {
// 清理过期的 DB tokens
await clearMarketTokensFromDB();
setStatus('unauthenticated');
return;
}
console.log('[MarketAuth] No valid session found');
setStatus('unauthenticated');
// 获取用户信息
const userInfo = await fetchUserInfo(dbTokens.accessToken);
if (!userInfo) {
// 清理无效的 token
await clearMarketTokensFromDB();
setStatus('unauthenticated');
return;
}
const restoredSession: MarketAuthSession = {
accessToken: dbTokens.accessToken,
expiresAt: dbTokens.expiresAt,
expiresIn: Math.floor((dbTokens.expiresAt - Date.now()) / 1000),
scope: 'openid profile email',
tokenType: 'Bearer',
userInfo,
};
setSession(restoredSession);
setStatus('authenticated');
};
/**
* 实际执行登录的方法(内部使用)
*/
const handleActualSignIn = async (): Promise<number | null> => {
console.log('[MarketAuth] Starting sign in process');
if (!oidcClient) {
console.error('[MarketAuth] OIDC client not initialized');
throw new MarketAuthError('oidcNotReady', { message: 'OIDC client not initialized' });
@@ -326,7 +253,6 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
// 启动 OIDC 授权流程并获取授权码
const authResult = await oidcClient.startAuthorization();
console.log('[MarketAuth] Authorization successful, exchanging code for token', authResult);
// 用授权码换取访问令牌
const tokenResponse = await oidcClient.exchangeCodeForToken(
@@ -334,40 +260,39 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
authResult.state,
);
console.log('[MarketAuth] Token response:', tokenResponse);
// 获取用户信息
const userInfo = await fetchUserInfo(tokenResponse.accessToken);
// 创建会话对象
const expiresAt = Date.now() + tokenResponse.expiresIn * 1000;
const newSession: MarketAuthSession = {
accessToken: tokenResponse.accessToken,
expiresAt: Date.now() + tokenResponse.expiresIn * 1000,
expiresAt,
expiresIn: tokenResponse.expiresIn,
scope: tokenResponse.scope,
tokenType: tokenResponse.tokenType as 'Bearer',
userInfo: userInfo || undefined,
};
// 存储 token 到 cookie 和 sessionStorage
setTokenToCookie(tokenResponse.accessToken, tokenResponse.expiresIn);
sessionStorage.setItem('market_auth_session', JSON.stringify(newSession));
// 单独存储用户信息到 sessionStorage 供其他地方使用
if (userInfo) {
sessionStorage.setItem('market_user_info', JSON.stringify(userInfo));
}
// 存储 tokens 到 DB
await saveMarketTokensToDB(
tokenResponse.accessToken,
tokenResponse.refreshToken,
newSession.expiresAt,
);
await saveMarketTokensToDB(tokenResponse.accessToken, tokenResponse.refreshToken, expiresAt);
setSession(newSession);
setStatus('authenticated');
// Check if user needs to set up profile (first-time login)
if (userInfo?.sub) {
const needsSetup = await checkNeedsProfileSetup(userInfo.sub);
if (needsSetup) {
// Wait for next tick to ensure session state is updated before opening modal
// This prevents the edge case where accessToken is null when modal opens
setTimeout(() => {
setIsFirstTimeSetup(true);
setShowProfileSetupModal(true);
}, 0);
}
}
return userInfo?.accountId ?? null;
} catch (error) {
setStatus('unauthenticated');
@@ -433,10 +358,6 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
const signOut = async () => {
setSession(null);
setStatus('unauthenticated');
removeTokenFromCookie();
sessionStorage.removeItem('market_auth_session');
sessionStorage.removeItem('market_user_info');
// 清除 DB 中的 tokens
await clearMarketTokensFromDB();
};
@@ -444,123 +365,57 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
* 获取当前用户信息
*/
const getCurrentUserInfo = (): MarketUserInfo | null => {
console.log('getCurrentUserInfo-session', session, session?.userInfo);
if (session?.userInfo) {
return session.userInfo;
}
// 如果 session 中没有,尝试从 sessionStorage 中获取
try {
const userInfoData = sessionStorage.getItem('market_user_info');
if (userInfoData) {
return JSON.parse(userInfoData) as MarketUserInfo;
}
} catch (error) {
console.error('[MarketAuth] Failed to get user info from storage:', error);
}
return null;
return session?.userInfo ?? null;
};
/**
* 获取 access token(优先从 DB 获取,否则从 session 获取)
* 获取 access token(优先从 session 获取,否则从 DB 获取)
*/
const getAccessToken = (): string | null => {
// 优先从 DB 获取
const dbTokens = getMarketTokensFromDB();
if (dbTokens?.accessToken) {
console.log('[MarketAuth] Retrieved access token from DB');
return dbTokens.accessToken;
}
// 如果 DB 中没有,从 session 获取
// 优先从 session 获取(内存中的状态)
if (session?.accessToken) {
console.log('[MarketAuth] Retrieved access token from session');
return session.accessToken;
}
// 如果 session 中也没有,尝试从 sessionStorage 获取
try {
const sessionData = sessionStorage.getItem('market_auth_session');
if (sessionData) {
const parsedSession = JSON.parse(sessionData) as MarketAuthSession;
if (parsedSession.accessToken) {
console.log('[MarketAuth] Retrieved access token from sessionStorage');
return parsedSession.accessToken;
}
}
} catch (error) {
console.error('[MarketAuth] Failed to get access token from sessionStorage:', error);
}
return null;
// 备选从 DB 获取
const dbTokens = getMarketTokensFromDB();
return dbTokens?.accessToken ?? null;
};
/**
* 初始化时恢复会话
* 打开个人资料设置模态框(用于用户手动编辑)
*/
useEffect(() => {
restoreSession();
const openProfileSetup = useCallback((onSuccess?: (profile: MarketUserProfile) => void) => {
setIsFirstTimeSetup(false);
setPendingProfileSuccessCallback(() => onSuccess || null);
setShowProfileSetupModal(true);
}, []);
/**
* 当需要重新授权且 OIDC 客户端准备好时,自动触发重新授权
* 关闭个人资料设置模态框
*/
const handleCloseProfileSetup = useCallback(() => {
setShowProfileSetupModal(false);
setIsFirstTimeSetup(false);
setPendingProfileSuccessCallback(null);
}, []);
/**
* 个人资料更新成功回调
*/
const handleProfileUpdateSuccess = useCallback(() => {
// Profile is updated, modal will close automatically
}, []);
/**
* 初始化时恢复会话并获取用户信息
* 等待 isUserStateInit 为 true,此时 useInitUserState 的 SWR 请求已完成,settings 数据已加载
*/
useEffect(() => {
const handleAutoReauthorization = async () => {
if (shouldReauthorize && oidcClient) {
setShouldReauthorize(false); // 重置标识,避免重复触发
try {
setStatus('loading');
// 启动 OIDC 授权流程并获取授权码
const authResult = await oidcClient.startAuthorization();
// 用授权码换取访问令牌
const tokenResponse = await oidcClient.exchangeCodeForToken(
authResult.code,
authResult.state,
);
// 获取用户信息
const userInfo = await fetchUserInfo(tokenResponse.accessToken);
// 创建会话对象
const newSession: MarketAuthSession = {
accessToken: tokenResponse.accessToken,
expiresAt: Date.now() + tokenResponse.expiresIn * 1000,
expiresIn: tokenResponse.expiresIn,
scope: tokenResponse.scope,
tokenType: tokenResponse.tokenType as 'Bearer',
userInfo: userInfo || undefined,
};
// 存储 token 到 cookie 和 sessionStorage
setTokenToCookie(tokenResponse.accessToken, tokenResponse.expiresIn);
sessionStorage.setItem('market_auth_session', JSON.stringify(newSession));
// 单独存储用户信息到 sessionStorage 供其他地方使用
if (userInfo) {
sessionStorage.setItem('market_user_info', JSON.stringify(userInfo));
}
// 存储 tokens 到 DB
await saveMarketTokensToDB(
tokenResponse.accessToken,
tokenResponse.refreshToken,
newSession.expiresAt,
);
setSession(newSession);
setStatus('authenticated');
console.log('[MarketAuth] Auto re-authorization completed successfully');
} catch (error) {
console.error('[MarketAuth] Auto re-authorization failed:', error);
setStatus('unauthenticated');
}
}
};
handleAutoReauthorization();
}, [shouldReauthorize, oidcClient]);
if (isUserStateInit) {
initializeSession();
}
}, [isUserStateInit]);
const contextValue: MarketAuthContextType = {
getAccessToken,
@@ -568,6 +423,7 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
getRefreshToken,
isAuthenticated: status === 'authenticated',
isLoading: status === 'loading',
openProfileSetup,
refreshToken,
session,
signIn,
@@ -575,6 +431,40 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
status,
};
// Get current user's profile for the edit modal
const userInfo = session?.userInfo;
const username = userInfo?.sub;
const { data: userProfile, mutate: mutateUserProfile } = useMarketUserProfile(username);
// Handle profile update success - also refresh the cached profile
const handleProfileSuccess = useCallback(
(profile: MarketUserProfile) => {
handleProfileUpdateSuccess();
// Update the SWR cache with the new profile
mutateUserProfile(profile, false);
// Also refresh the discover store's user profile cache
// The discover store uses keys like 'user-profile-{locale}-{username}'
if (profile.userName) {
globalMutate(
(key) =>
typeof key === 'string' &&
key.includes(`user-profile`) &&
key.includes(profile.userName!),
undefined,
{ revalidate: true },
);
}
// Call the external success callback if provided
if (pendingProfileSuccessCallback) {
pendingProfileSuccessCallback(profile);
setPendingProfileSuccessCallback(null);
}
},
[handleProfileUpdateSuccess, mutateUserProfile, pendingProfileSuccessCallback],
);
return (
<MarketAuthContext.Provider value={contextValue}>
{children}
@@ -583,6 +473,15 @@ export const MarketAuthProvider = ({ children, isDesktop }: MarketAuthProviderPr
onConfirm={handleConfirmAuth}
open={showConfirmModal}
/>
<ProfileSetupModal
accessToken={session?.accessToken ?? null}
defaultDisplayName={userProfile?.displayName || ''}
isFirstTimeSetup={isFirstTimeSetup}
onClose={handleCloseProfileSetup}
onSuccess={handleProfileSuccess}
open={showProfileSetupModal}
userProfile={userProfile}
/>
</MarketAuthContext.Provider>
);
};
@@ -0,0 +1,493 @@
'use client';
import { SiGithub, SiX } from '@icons-pack/react-simple-icons';
import { Modal, Text } from '@lobehub/ui';
import { App, Divider, Form, Input, Tooltip, Upload, UploadProps } from 'antd';
import { useTheme } from 'antd-style';
import { CircleHelp, Globe, ImagePlus, Trash2 } from 'lucide-react';
import { memo, useCallback, useEffect, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { Center, Flexbox } from 'react-layout-kit';
import EmojiPicker from '@/components/EmojiPicker';
import { MARKET_ENDPOINTS } from '@/services/_url';
import { useFileStore } from '@/store/file';
import { useGlobalStore } from '@/store/global';
import { globalGeneralSelectors } from '@/store/global/selectors';
import { MarketUserProfile } from './types';
const MAX_FILE_SIZE = 2 * 1024 * 1024; // 2MB limit
interface ProfileSetupModalProps {
accessToken: string | null;
/**
* Default display name to use (typically from OIDC)
*/
defaultDisplayName?: string;
/**
* Whether this is the first-time setup (after initial sign in)
*/
isFirstTimeSetup?: boolean;
onClose: () => void;
/**
* Callback when profile is successfully updated
*/
onSuccess?: (profile: MarketUserProfile) => void;
open: boolean;
/**
* Current user profile (for editing existing profile)
*/
userProfile?: MarketUserProfile | null;
}
interface FormValues {
description?: string;
displayName: string;
github?: string;
twitter?: string;
userName: string;
website?: string;
}
const ProfileSetupModal = memo<ProfileSetupModalProps>(
({
open,
onClose,
onSuccess,
accessToken,
defaultDisplayName,
userProfile,
isFirstTimeSetup = false,
}) => {
const { t } = useTranslation('marketAuth');
const theme = useTheme();
const { message } = App.useApp();
const [form] = Form.useForm<FormValues>();
const [loading, setLoading] = useState(false);
const locale = useGlobalStore(globalGeneralSelectors.currentLanguage);
// Avatar state
const [avatarUrl, setAvatarUrl] = useState<string | null>(null);
const [avatarUploading, setAvatarUploading] = useState(false);
// Banner state
const [bannerUrl, setBannerUrl] = useState<string | null>(null);
const [bannerUploading, setBannerUploading] = useState(false);
// File upload
const uploadWithProgress = useFileStore((s) => s.uploadWithProgress);
// Reset form when modal opens
useEffect(() => {
if (open) {
// For userName default: use existing userName, or generate from displayName
const existingUserName = userProfile?.userName;
const existingDisplayName = userProfile?.displayName || defaultDisplayName || '';
// Generate default userName from displayName (remove invalid chars, lowercase)
const generatedUserName = existingDisplayName
.toLowerCase()
.replaceAll(/[^\w-]/g, '')
.slice(0, 32);
form.setFieldsValue({
description: userProfile?.description || '',
displayName: existingDisplayName,
github: userProfile?.socialLinks?.github || '',
twitter: userProfile?.socialLinks?.twitter || '',
userName: existingUserName || generatedUserName,
website: userProfile?.socialLinks?.website || '',
});
// Reset avatar and banner
setAvatarUrl(userProfile?.avatarUrl || null);
setBannerUrl(userProfile?.bannerUrl || null);
}
}, [open, userProfile, defaultDisplayName, form]);
// Handle avatar change (emoji)
const handleAvatarChange = useCallback((emoji: string) => {
setAvatarUrl(emoji);
}, []);
// Handle avatar upload
const handleAvatarUpload = useCallback(
async (file: File) => {
if (file.size > MAX_FILE_SIZE) {
message.error(t('profileSetup.errors.fileTooLarge'));
return;
}
setAvatarUploading(true);
try {
const result = await uploadWithProgress({ file });
if (result?.url) {
setAvatarUrl(result.url);
}
} catch (error) {
console.error('[ProfileSetupModal] Avatar upload failed:', error);
message.error(t('profileSetup.errors.uploadFailed'));
} finally {
setAvatarUploading(false);
}
},
[uploadWithProgress, message, t],
);
// Handle avatar delete
const handleAvatarDelete = useCallback(() => {
setAvatarUrl(null);
}, []);
// Handle banner upload
const handleBannerUpload: UploadProps['customRequest'] = useCallback(
async (options: Parameters<NonNullable<UploadProps['customRequest']>>[0]) => {
const file = options.file as File;
if (file.size > MAX_FILE_SIZE) {
message.error(t('profileSetup.errors.fileTooLarge'));
options.onError?.(new Error('File too large'));
return;
}
setBannerUploading(true);
try {
const result = await uploadWithProgress({ file });
if (result?.url) {
setBannerUrl(result.url);
options.onSuccess?.(result);
}
} catch (error) {
console.error('[ProfileSetupModal] Banner upload failed:', error);
message.error(t('profileSetup.errors.uploadFailed'));
options.onError?.(error as Error);
} finally {
setBannerUploading(false);
}
},
[uploadWithProgress, message, t],
);
// Handle banner delete
const handleBannerDelete = useCallback(() => {
setBannerUrl(null);
}, []);
const handleSubmit = useCallback(async () => {
if (!accessToken) {
message.error(t('profileSetup.errors.notAuthenticated'));
return;
}
try {
const values = await form.validateFields();
setLoading(true);
// Build socialLinks object (only include non-empty values)
const socialLinks: { github?: string; twitter?: string; website?: string } = {};
if (values.github) socialLinks.github = values.github;
if (values.twitter) socialLinks.twitter = values.twitter;
if (values.website) socialLinks.website = values.website;
// Build meta object (socialLinks should be inside meta)
const meta: {
bannerUrl?: string;
description?: string;
socialLinks?: { github?: string; twitter?: string; website?: string };
} = {};
if (values.description) meta.description = values.description;
if (bannerUrl) meta.bannerUrl = bannerUrl;
if (Object.keys(socialLinks).length > 0) meta.socialLinks = socialLinks;
const response = await fetch(MARKET_ENDPOINTS.updateUserProfile, {
body: JSON.stringify({
avatarUrl: avatarUrl || undefined,
displayName: values.displayName,
meta: Object.keys(meta).length > 0 ? meta : undefined,
userName: values.userName,
}),
headers: {
'Authorization': `Bearer ${accessToken}`,
'Content-Type': 'application/json',
},
method: 'PUT',
});
if (!response.ok) {
const errorData = await response.json();
if (errorData.error === 'username_taken') {
message.error(t('profileSetup.errors.usernameTaken'));
return;
}
throw new Error(errorData.message || 'Update failed');
}
const data = await response.json();
message.success(t('profileSetup.success'));
onSuccess?.(data.user);
onClose();
} catch (error) {
console.error('[ProfileSetupModal] Update failed:', error);
if (error instanceof Error && error.message !== 'Validation failed') {
message.error(t('profileSetup.errors.updateFailed'));
}
} finally {
setLoading(false);
}
}, [accessToken, avatarUrl, bannerUrl, form, message, onClose, onSuccess, t]);
const handleCancel = useCallback(() => {
if (!isFirstTimeSetup) {
onClose();
}
}, [isFirstTimeSetup, onClose]);
return (
<Modal
cancelButtonProps={isFirstTimeSetup ? { style: { display: 'none' } } : undefined}
cancelText={t('profileSetup.cancel')}
centered
closable={!isFirstTimeSetup}
confirmLoading={loading}
keyboard={!isFirstTimeSetup}
maskClosable={!isFirstTimeSetup}
okText={isFirstTimeSetup ? t('profileSetup.getStarted') : t('profileSetup.save')}
onCancel={handleCancel}
onOk={handleSubmit}
open={open}
title={isFirstTimeSetup ? t('profileSetup.titleFirstTime') : t('profileSetup.titleEdit')}
width={480}
>
<Text style={{ display: 'block', marginBottom: 24 }} type="secondary">
{isFirstTimeSetup
? t('profileSetup.descriptionFirstTime')
: t('profileSetup.descriptionEdit')}
</Text>
<Form form={form} layout="vertical">
{/* Avatar Section */}
<Form.Item label={t('profileSetup.fields.avatar.label')}>
<EmojiPicker
allowDelete={!!avatarUrl}
allowUpload
loading={avatarUploading}
locale={locale}
onChange={handleAvatarChange}
onDelete={handleAvatarDelete}
onUpload={handleAvatarUpload}
shape="circle"
size={80}
value={avatarUrl || undefined}
/>
</Form.Item>
<Form.Item
label={t('profileSetup.fields.displayName.label')}
name="displayName"
rules={[
{ message: t('profileSetup.fields.displayName.required'), required: true },
{
max: 50,
message: t('profileSetup.fields.displayName.maxLength'),
},
]}
>
<Input
maxLength={50}
placeholder={t('profileSetup.fields.displayName.placeholder')}
showCount
/>
</Form.Item>
<Form.Item
label={
<Flexbox align="center" gap={4} horizontal>
{t('profileSetup.fields.userName.label')}
<Tooltip title={t('profileSetup.fields.userName.tooltip')}>
<CircleHelp size={14} style={{ cursor: 'help', opacity: 0.5 }} />
</Tooltip>
</Flexbox>
}
name="userName"
rules={[
{ message: t('profileSetup.fields.userName.required'), required: true },
{
message: t('profileSetup.fields.userName.pattern'),
pattern: /^[\w-]+$/,
},
{
max: 32,
message: t('profileSetup.fields.userName.maxLength'),
},
{
message: t('profileSetup.fields.userName.minLength'),
min: 3,
},
]}
>
<Input
maxLength={32}
placeholder={t('profileSetup.fields.userName.placeholder')}
prefix="@"
showCount
/>
</Form.Item>
<Form.Item
label={t('profileSetup.fields.description.label')}
name="description"
rules={[
{
max: 200,
message: t('profileSetup.fields.description.maxLength'),
},
]}
>
<Input.TextArea
maxLength={200}
placeholder={t('profileSetup.fields.description.placeholder')}
rows={3}
showCount
/>
</Form.Item>
{/* Only show banner and social links in edit mode, not first-time setup */}
{!isFirstTimeSetup && (
<>
<Divider style={{ margin: '16px 0' }} />
{/* Banner Upload Section */}
<Form.Item
label={
<Flexbox align="center" gap={4} horizontal>
{t('profileSetup.fields.bannerUrl.label')}
<Tooltip title={t('profileSetup.fields.bannerUrl.tooltip')}>
<CircleHelp size={14} style={{ cursor: 'help', opacity: 0.5 }} />
</Tooltip>
</Flexbox>
}
>
<Flexbox gap={8} width="100%">
<Upload
accept="image/*"
customRequest={handleBannerUpload}
maxCount={1}
showUploadList={false}
style={{ display: 'block', width: '100%' }}
>
<div
style={{
backgroundColor: bannerUrl ? undefined : theme.colorFillTertiary,
backgroundImage: bannerUrl ? `url(${bannerUrl})` : undefined,
backgroundPosition: 'center',
backgroundSize: 'cover',
borderRadius: theme.borderRadiusLG,
cursor: 'pointer',
height: 120,
overflow: 'hidden',
position: 'relative',
width: '100%',
}}
>
<Center
onMouseEnter={(e) => {
e.currentTarget.style.opacity = '1';
}}
onMouseLeave={(e) => {
if (bannerUrl) e.currentTarget.style.opacity = '0';
}}
style={{
background: bannerUrl ? 'rgba(0,0,0,0.4)' : 'transparent',
height: '100%',
opacity: bannerUrl ? 0 : 1,
transition: 'opacity 0.2s',
width: '100%',
}}
>
<Flexbox align="center" gap={8}>
<ImagePlus
size={24}
style={{ color: bannerUrl ? '#fff' : theme.colorTextSecondary }}
/>
<Text
style={{
color: bannerUrl ? '#fff' : theme.colorTextSecondary,
fontSize: 12,
}}
>
{bannerUploading
? t('profileSetup.fields.bannerUrl.uploading')
: t('profileSetup.fields.bannerUrl.clickToUpload')}
</Text>
</Flexbox>
</Center>
</div>
</Upload>
{bannerUrl && (
<Flexbox align="center" gap={8} horizontal justify="flex-end">
<Text
onClick={(e) => {
e.stopPropagation();
handleBannerDelete();
}}
style={{
color: theme.colorError,
cursor: 'pointer',
fontSize: 12,
}}
>
<Flexbox align="center" gap={4} horizontal>
<Trash2 size={12} />
{t('profileSetup.fields.bannerUrl.remove')}
</Flexbox>
</Text>
</Flexbox>
)}
</Flexbox>
</Form.Item>
<Divider style={{ margin: '16px 0' }} />
<Text style={{ display: 'block', marginBottom: 12 }} type="secondary">
{t('profileSetup.socialLinks.title')}
</Text>
<Form.Item name="github">
<Input
placeholder={t('profileSetup.fields.github.placeholder')}
prefix={<SiGithub size={14} style={{ opacity: 0.5 }} />}
/>
</Form.Item>
<Form.Item name="twitter">
<Input
placeholder={t('profileSetup.fields.twitter.placeholder')}
prefix={<SiX size={14} style={{ opacity: 0.5 }} />}
/>
</Form.Item>
<Form.Item
name="website"
rules={[
{
message: t('profileSetup.fields.website.invalidUrl'),
type: 'url',
},
]}
>
<Input
placeholder={t('profileSetup.fields.website.placeholder')}
prefix={<Globe size={14} style={{ opacity: 0.5 }} />}
/>
</Form.Item>
</>
)}
</Form>
</Modal>
);
},
);
ProfileSetupModal.displayName = 'ProfileSetupModal';
export default ProfileSetupModal;
+7 -1
View File
@@ -1,2 +1,8 @@
export { MarketAuthProvider, useMarketAuth } from './MarketAuthProvider';
export type { MarketAuthContextType, MarketAuthSession, MarketAuthState } from './types';
export type {
MarketAuthContextType,
MarketAuthSession,
MarketAuthState,
MarketUserProfile,
} from './types';
export { useMarketUserProfile } from './useMarketUserProfile';
@@ -0,0 +1,38 @@
import { describe, expect, it, vi } from 'vitest';
import { MARKET_OIDC_ENDPOINTS } from '@/services/_url';
import { MarketOIDC } from './oidc';
describe('MarketOIDC.buildAuthUrl', () => {
it('should join market baseUrl with OIDC auth path correctly (no string concat issues)', async () => {
const client = new MarketOIDC({
baseUrl: 'https://market.lobehub.com/', // trailing slash on purpose
clientId: 'lobehub-desktop',
redirectUri: 'https://market.lobehub.com/lobehub-oidc/callback/desktop',
scope: 'openid profile email',
});
vi.spyOn(client, 'generatePKCEParams').mockResolvedValue({
codeChallenge: 'code_challenge',
codeVerifier: 'code_verifier',
state: 'state_value',
});
const url = await client.buildAuthUrl();
expect(url).toContain('https://market.lobehub.com/lobehub-oidc/auth?');
expect(url).toContain(`client_id=${encodeURIComponent('lobehub-desktop')}`);
expect(url).toContain(`redirect_uri=${encodeURIComponent('https://market.lobehub.com/lobehub-oidc/callback/desktop')}`);
expect(url).toContain(`state=${encodeURIComponent('state_value')}`);
expect(url).toContain(`code_challenge=${encodeURIComponent('code_challenge')}`);
const parsed = new URL(url);
expect(parsed.searchParams.get('scope')).toBe('openid profile email');
// The auth endpoint must be a plain path; it is opened in a real browser.
expect(MARKET_OIDC_ENDPOINTS.auth).toBe('/lobehub-oidc/auth');
});
});
+1 -15
View File
@@ -89,9 +89,7 @@ export class MarketOIDC {
console.log('[MarketOIDC] this.config:', this.config);
const authUrl = new URL(
`${this.config.baseUrl.replace(/\/$/, '')}${MARKET_OIDC_ENDPOINTS.auth}`,
);
const authUrl = new URL(MARKET_OIDC_ENDPOINTS.auth, this.config.baseUrl);
authUrl.searchParams.set('client_id', this.config.clientId);
authUrl.searchParams.set('redirect_uri', this.config.redirectUri);
authUrl.searchParams.set('response_type', 'code');
@@ -132,14 +130,6 @@ export class MarketOIDC {
grant_type: 'authorization_code',
redirect_uri: this.config.redirectUri,
});
console.log('[MarketOIDC] Sending token exchange request', {
client_id: this.config.clientId,
code,
code_verifier: codeVerifier,
grant_type: 'authorization_code',
redirect_uri: this.config.redirectUri,
});
const response = await fetch(tokenUrl, {
body: body.toString(),
headers: {
@@ -148,11 +138,8 @@ export class MarketOIDC {
method: 'POST',
});
console.log('[MarketOIDC] Token exchange response:', response);
if (!response.ok) {
const errorData = await response.json().catch(() => undefined);
console.log('[MarketOIDC] Token exchange error data:', errorData);
const errorMessage =
`Token exchange failed: ${response.status} ${response.statusText} ${errorData?.error_description || errorData?.error || ''}`.trim();
console.error('[MarketOIDC]', errorMessage);
@@ -180,7 +167,6 @@ export class MarketOIDC {
* 启动授权流程并返回授权结果
*/
async startAuthorization(): Promise<{ code: string; state: string }> {
console.log('[MarketOIDC] Starting authorization flow');
const authUrl = await this.buildAuthUrl();
if (typeof window === 'undefined') {
@@ -19,6 +19,26 @@ export interface MarketUserInfo {
};
}
/**
* Market User Profile - Extended user information from Market SDK
*/
export interface MarketUserProfile {
avatarUrl: string | null;
bannerUrl: string | null;
createdAt: string;
description: string | null;
displayName: string | null;
id: number;
namespace: string;
socialLinks: {
github?: string;
twitter?: string;
website?: string;
} | null;
type: string | null;
userName: string | null;
}
export interface MarketAuthSession {
accessToken: string;
expiresAt: number;
@@ -39,6 +59,7 @@ export interface MarketAuthContextType extends MarketAuthState {
getAccessToken: () => string | null;
getCurrentUserInfo: () => MarketUserInfo | null;
getRefreshToken: () => string | null;
openProfileSetup: (onSuccess?: (profile: MarketUserProfile) => void) => void;
refreshToken: () => Promise<boolean>;
signIn: () => Promise<number | null>;
signOut: () => Promise<void>;
@@ -0,0 +1,36 @@
import useSWR from 'swr';
import { MARKET_ENDPOINTS } from '@/services/_url';
import { MarketUserProfile } from './types';
/**
* Fetcher function for user profile
*/
const fetchUserProfile = async (username: string): Promise<MarketUserProfile | null> => {
const response = await fetch(MARKET_ENDPOINTS.getUserProfile(username));
if (!response.ok) {
throw new Error(`Failed to fetch user profile: ${response.status}`);
}
return response.json();
};
/**
* Hook to fetch and cache Market user profile using SWR
*
* @param username - The username to fetch profile for (typically userInfo.sub)
* @returns SWR response with user profile data
*/
export const useMarketUserProfile = (username: string | null | undefined) => {
return useSWR<MarketUserProfile | null>(
username ? ['market-user-profile', username] : null,
() => fetchUserProfile(username!),
{
dedupingInterval: 60_000, // 1 minute deduplication
revalidateOnFocus: false,
revalidateOnReconnect: false,
},
);
};
+17 -18
View File
@@ -1,33 +1,32 @@
import { isDesktop } from '@lobechat/const';
import { PropsWithChildren } from 'react';
import { isDesktop } from '@/const/version';
import { authEnv } from '@/envs/auth';
import BetterAuth from './BetterAuth';
import Clerk from './Clerk';
import { MarketAuthProvider } from './MarketAuth';
import Desktop from './Desktop';
import NextAuth from './NextAuth';
import NoAuth from './NoAuth';
const AuthProvider = ({ children }: PropsWithChildren) => {
// 获取内部 AuthProvider
let InnerAuthProvider;
if (authEnv.NEXT_PUBLIC_ENABLE_CLERK_AUTH) {
InnerAuthProvider = ({ children }: PropsWithChildren) => <Clerk>{children}</Clerk>;
} else if (authEnv.NEXT_PUBLIC_ENABLE_BETTER_AUTH) {
InnerAuthProvider = ({ children }: PropsWithChildren) => <BetterAuth>{children}</BetterAuth>;
} else if (authEnv.NEXT_PUBLIC_ENABLE_NEXT_AUTH) {
InnerAuthProvider = ({ children }: PropsWithChildren) => <NextAuth>{children}</NextAuth>;
} else {
InnerAuthProvider = ({ children }: PropsWithChildren) => <NoAuth>{children}</NoAuth>;
if (isDesktop) {
return <Desktop>{children}</Desktop>;
}
// 将 MarketAuthProvider 包装在内部 AuthProvider 之外
return (
<InnerAuthProvider>
<MarketAuthProvider isDesktop={isDesktop}>{children}</MarketAuthProvider>
</InnerAuthProvider>
);
if (authEnv.NEXT_PUBLIC_ENABLE_CLERK_AUTH) {
return <Clerk>{children}</Clerk>;
}
if (authEnv.NEXT_PUBLIC_ENABLE_BETTER_AUTH) {
return <BetterAuth>{children}</BetterAuth>;
}
if (authEnv.NEXT_PUBLIC_ENABLE_NEXT_AUTH) {
return <NextAuth>{children}</NextAuth>;
}
return <NoAuth>{children}</NoAuth>;
};
export default AuthProvider;
-470
View File
@@ -1,470 +0,0 @@
'use client';
import { Tag } from '@lobehub/ui';
import { createStyles } from 'antd-style';
import { Command } from 'cmdk';
import {
ArrowLeft,
ArrowUpDown,
BookOpen,
Bot,
Compass,
CornerDownLeft,
Github,
MessageCircle,
Monitor,
Moon,
Palette,
Settings,
Star,
Sun,
} from 'lucide-react';
import { usePathname, useRouter } from 'next/navigation';
import { memo, useEffect, useState } from 'react';
import { createPortal } from 'react-dom';
import { useTranslation } from 'react-i18next';
import { useHotkeyById } from '@/hooks/useHotkeys/useHotkeyById';
import { useGlobalStore } from '@/store/global';
import { featureFlagsSelectors, useServerConfigStore } from '@/store/serverConfig';
import { useSessionStore } from '@/store/session';
import { HotkeyEnum } from '@/types/hotkey';
const useStyles = createStyles(({ css, token }) => ({
backTag: css`
cursor: pointer;
&:hover {
opacity: 0.8;
}
`,
commandFooter: css`
display: flex;
gap: 16px;
align-items: center;
justify-content: flex-end;
padding-block: 8px;
padding-inline: 16px;
border-block-start: 1px solid ${token.colorBorderSecondary};
background: ${token.colorBgContainer};
`,
commandRoot: css`
overflow: hidden;
display: flex;
flex-direction: column;
width: min(640px, 90vw);
max-height: min(500px, 70vh);
border-radius: ${token.borderRadiusLG}px;
background: ${token.colorBgElevated};
box-shadow: ${token.boxShadowSecondary};
animation: slide-down 0.12s ease-out;
@keyframes slide-down {
from {
transform: translateY(-20px) scale(0.96);
opacity: 0;
}
to {
transform: translateY(0) scale(1);
opacity: 1;
}
}
[cmdk-input] {
flex: 1;
min-width: 0;
padding: 0;
border: none;
font-family: inherit;
font-size: 16px;
color: ${token.colorText};
background: transparent;
outline: none;
&::placeholder {
color: ${token.colorTextPlaceholder};
}
}
[cmdk-list] {
overflow-y: auto;
max-height: 400px;
padding: 8px;
}
[cmdk-empty] {
padding-block: 32px;
padding-inline: 16px;
font-size: 14px;
color: ${token.colorTextTertiary};
text-align: center;
}
[cmdk-item] {
cursor: pointer;
user-select: none;
display: flex;
gap: 12px;
align-items: center;
padding-block: 12px;
padding-inline: 16px;
border-radius: ${token.borderRadius}px;
color: ${token.colorText};
transition: all 0.15s ease;
&[aria-selected='true'] {
background: ${token.colorBgTextHover};
}
&:hover {
background: ${token.colorBgTextHover};
}
}
[cmdk-group-heading] {
user-select: none;
padding-block: 8px;
padding-inline: 16px;
font-size: 12px;
font-weight: 500;
color: ${token.colorTextSecondary};
}
[cmdk-separator] {
height: 1px;
margin-block: 4px;
background: ${token.colorBorderSecondary};
}
`,
icon: css`
flex-shrink: 0;
width: 20px;
height: 20px;
color: ${token.colorTextSecondary};
`,
inputWrapper: css`
display: flex;
gap: 8px;
align-items: center;
padding: 16px;
border-block-end: 1px solid ${token.colorBorderSecondary};
`,
itemContent: css`
flex: 1;
min-width: 0;
`,
itemDescription: css`
margin-block-start: 2px;
font-size: 12px;
line-height: 1.4;
color: ${token.colorTextTertiary};
`,
itemLabel: css`
font-size: 14px;
font-weight: 500;
line-height: 1.4;
`,
kbd: css`
display: inline-flex;
gap: 4px;
align-items: center;
padding-block: 2px;
padding-inline: 6px;
border-radius: ${token.borderRadiusSM}px;
font-size: 11px;
font-weight: 500;
line-height: 1.2;
color: ${token.colorTextSecondary};
background: ${token.colorFillQuaternary};
`,
kbdIcon: css`
width: 12px;
height: 12px;
`,
overlay: css`
position: fixed;
z-index: 9999;
inset: 0;
display: flex;
justify-content: center;
padding-block-start: 15vh;
background: ${token.colorBgMask};
animation: fade-in 0.1s ease-in-out;
@keyframes fade-in {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
`,
}));
const Cmdk = memo(() => {
const [open, setOpen] = useState(false);
const [mounted, setMounted] = useState(false);
const [search, setSearch] = useState('');
const [pages, setPages] = useState<string[]>([]);
const router = useRouter();
const pathname = usePathname();
const { t } = useTranslation('common');
const { styles } = useStyles();
const switchThemeMode = useGlobalStore((s) => s.switchThemeMode);
const createSession = useSessionStore((s) => s.createSession);
const { showCreateSession } = useServerConfigStore(featureFlagsSelectors);
const page = pages.at(-1);
// Ensure we're mounted on the client
useEffect(() => {
setMounted(true);
}, []);
// Register Cmd+K / Ctrl+K hotkey
useHotkeyById(HotkeyEnum.CommandPalette, () => {
setOpen((prev) => !prev);
});
// Close on Escape key and prevent body scroll
useEffect(() => {
if (open) {
const originalStyle = window.getComputedStyle(document.body).overflow;
document.body.style.overflow = 'hidden';
return () => {
document.body.style.overflow = originalStyle;
};
}
}, [open]);
// Reset pages and search when opening/closing
useEffect(() => {
if (open) {
setPages([]);
setSearch('');
}
}, [open]);
const handleNavigate = (path: string) => {
router.push(path);
setOpen(false);
};
const handleExternalLink = (url: string) => {
window.open(url, '_blank', 'noopener,noreferrer');
setOpen(false);
};
const handleThemeChange = (theme: 'light' | 'dark' | 'auto') => {
switchThemeMode(theme);
setOpen(false);
};
if (!mounted || !open) return null;
return createPortal(
<div className={styles.overlay} onClick={() => setOpen(false)}>
<div onClick={(e) => e.stopPropagation()}>
<Command
className={styles.commandRoot}
onKeyDown={(e) => {
// Escape goes to previous page or closes
if (e.key === 'Escape') {
e.preventDefault();
if (pages.length > 0) {
setPages((prev) => prev.slice(0, -1));
} else {
setOpen(false);
}
}
// Backspace goes to previous page when search is empty
if (e.key === 'Backspace' && !search && pages.length > 0) {
e.preventDefault();
setPages((prev) => prev.slice(0, -1));
}
}}
shouldFilter={true}
>
<div className={styles.inputWrapper}>
{pages.length > 0 && (
<Tag
className={styles.backTag}
icon={<ArrowLeft size={12} />}
onClick={() => setPages((prev) => prev.slice(0, -1))}
/>
)}
<Command.Input
autoFocus
onValueChange={setSearch}
placeholder={t('cmdk.searchPlaceholder')}
value={search}
/>
<Tag>ESC</Tag>
</div>
<Command.List>
<Command.Empty>{t('cmdk.noResults')}</Command.Empty>
{!page && (
<>
{showCreateSession && (
<Command.Item
onSelect={() => {
createSession();
setOpen(false);
}}
value="new-agent"
>
<Bot className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.newAgent')}</div>
</div>
</Command.Item>
)}
{!pathname?.startsWith('/settings') && (
<Command.Item onSelect={() => handleNavigate('/settings')} value="settings">
<Settings className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.settings')}</div>
</div>
</Command.Item>
)}
<Command.Item onSelect={() => setPages([...pages, 'theme'])} value="theme">
<Monitor className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.theme')}</div>
</div>
</Command.Item>
<Command.Group heading={t('cmdk.navigate')}>
{!pathname?.startsWith('/discover') && (
<Command.Item onSelect={() => handleNavigate('/discover')} value="discover">
<Compass className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.discover')}</div>
</div>
</Command.Item>
)}
{!pathname?.startsWith('/image') && (
<Command.Item onSelect={() => handleNavigate('/image')} value="painting">
<Palette className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.painting')}</div>
</div>
</Command.Item>
)}
{!pathname?.startsWith('/knowledge') && (
<Command.Item onSelect={() => handleNavigate('/knowledge')} value="knowledge">
<BookOpen className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.knowledgeBase')}</div>
</div>
</Command.Item>
)}
</Command.Group>
<Command.Group heading={t('cmdk.about')}>
<Command.Item
onSelect={() =>
handleExternalLink('https://github.com/lobehub/lobe-chat/issues/new/choose')
}
value="submit-issue"
>
<Github className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.submitIssue')}</div>
</div>
</Command.Item>
<Command.Item
onSelect={() => handleExternalLink('https://github.com/lobehub/lobe-chat')}
value="star-github"
>
<Star className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.starOnGitHub')}</div>
</div>
</Command.Item>
<Command.Item
onSelect={() => handleExternalLink('https://discord.gg/AYFPHvv2jT')}
value="discord"
>
<MessageCircle className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.communitySupport')}</div>
</div>
</Command.Item>
</Command.Group>
</>
)}
{page === 'theme' && (
<>
<Command.Item onSelect={() => handleThemeChange('light')} value="theme-light">
<Sun className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.themeLight')}</div>
</div>
</Command.Item>
<Command.Item onSelect={() => handleThemeChange('dark')} value="theme-dark">
<Moon className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.themeDark')}</div>
</div>
</Command.Item>
<Command.Item onSelect={() => handleThemeChange('auto')} value="theme-auto">
<Monitor className={styles.icon} />
<div className={styles.itemContent}>
<div className={styles.itemLabel}>{t('cmdk.themeAuto')}</div>
</div>
</Command.Item>
</>
)}
</Command.List>
<div className={styles.commandFooter}>
<div className={styles.kbd}>
<CornerDownLeft className={styles.kbdIcon} />
<span>{t('cmdk.toOpen')}</span>
</div>
<div className={styles.kbd}>
<ArrowUpDown className={styles.kbdIcon} />
<span>{t('cmdk.toSelect')}</span>
</div>
</div>
</Command>
</div>
</div>,
document.body,
);
});
Cmdk.displayName = 'Cmdk';
export default Cmdk;
+9 -9
View File
@@ -1,16 +1,16 @@
'use client';
import dynamic from 'next/dynamic';
import { memo } from 'react';
import { Suspense, lazy, memo } from 'react';
// Lazy load the CMDK component with Next.js dynamic import
// This splits the CMDK code into a separate chunk that only loads when needed
// ssr: false ensures it only loads on the client side
const CmdkComponent = dynamic(() => import('./Cmdk'), {
ssr: false,
});
// Lazy load the CommandMenu component with React lazy
// This splits the CommandMenu code into a separate chunk that only loads when needed
const CmdkComponent = lazy(() => import('@/features/CommandMenu'));
const CmdkLazy = memo(() => <CmdkComponent />);
const CmdkLazy = memo(() => (
<Suspense fallback={null}>
<CmdkComponent />
</Suspense>
));
CmdkLazy.displayName = 'CmdkLazy';
@@ -1,7 +1,7 @@
'use client';
import { enableNextAuth } from '@lobechat/const';
import { useRouter } from 'next/navigation';
import { INBOX_SESSION_ID, enableNextAuth } from '@lobechat/const';
import { usePathname } from 'next/navigation';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { createStoreUpdater } from 'zustand-utils';
@@ -14,19 +14,14 @@ import { electronSyncSelectors } from '@/store/electron/selectors';
import { useGlobalStore } from '@/store/global';
import { useServerConfigStore } from '@/store/serverConfig';
import { serverConfigSelectors } from '@/store/serverConfig/selectors';
import { useUrlHydrationStore } from '@/store/urlHydration';
import { useUserStore } from '@/store/user';
import { authSelectors } from '@/store/user/selectors';
import { authSelectors, onboardingSelectors } from '@/store/user/selectors';
const StoreInitialization = memo(() => {
// prefetch error ns to avoid don't show error content correctly
useTranslation('error');
// Initialize from URL (one-time)
const initAgentPinnedFromUrl = useUrlHydrationStore((s) => s.initAgentPinnedFromUrl);
initAgentPinnedFromUrl();
const router = useRouter();
const pathname = usePathname();
const [isLogin, isSignedIn, useInitUserState] = useUserStore((s) => [
authSelectors.isLogin(s),
s.isSignedIn,
@@ -37,7 +32,7 @@ const StoreInitialization = memo(() => {
const useInitSystemStatus = useGlobalStore((s) => s.useInitSystemStatus);
const useInitAgentStore = useAgentStore((s) => s.useInitInboxAgentStore);
const useInitBuiltinAgent = useAgentStore((s) => s.useInitBuiltinAgent);
const useInitAiProviderKeyVaults = useAiInfraStore((s) => s.useFetchAiProviderRuntimeState);
// init the system preference
@@ -62,8 +57,8 @@ const StoreInitialization = memo(() => {
*/
const isLoginOnInit = Boolean(enableNextAuth ? isSignedIn : isLogin);
// init inbox agent and default agent config
useInitAgentStore(isLoginOnInit, serverConfig.defaultAgent?.config);
// init inbox agent via builtin agent mechanism
useInitBuiltinAgent(INBOX_SESSION_ID, { isLogin: isLoginOnInit });
const isSyncActive = useElectronStore((s) => electronSyncSelectors.isSyncActive(s));
@@ -73,8 +68,11 @@ const StoreInitialization = memo(() => {
// init user state
useInitUserState(isLoginOnInit, serverConfig, {
onSuccess: (state) => {
if (state.isOnboard === false) {
router.push('/onboard');
// Skip redirect if already on onboarding page
if (pathname?.includes('/onboarding')) return;
if (onboardingSelectors.needsOnboarding(state)) {
window.location.href = '/onboarding';
}
},
});
+4 -3
View File
@@ -1,3 +1,4 @@
import { LazyMotion, domMax } from 'motion/react';
import { ReactNode, Suspense } from 'react';
import { LobeAnalyticsProviderWrapper } from '@/components/Analytics/LobeAnalyticsProviderWrapper';
@@ -10,7 +11,6 @@ import { getAntdLocale } from '@/utils/locale';
import AntdV5MonkeyPatch from './AntdV5MonkeyPatch';
import AppTheme from './AppTheme';
import CmdkLazy from './CmdkLazy';
import ImportSettings from './ImportSettings';
import Locale from './Locale';
import QueryProvider from './Query';
@@ -59,14 +59,15 @@ const GlobalLayout = async ({
serverConfig={serverConfig}
>
<QueryProvider>
<LobeAnalyticsProviderWrapper>{children}</LobeAnalyticsProviderWrapper>
<LazyMotion features={domMax}>
<LobeAnalyticsProviderWrapper>{children}</LobeAnalyticsProviderWrapper>
</LazyMotion>
</QueryProvider>
<StoreInitialization />
<Suspense>
<ImportSettings />
{process.env.NODE_ENV === 'development' && <DevPanel />}
</Suspense>
<CmdkLazy />
</ServerConfigStoreProvider>
</AppTheme>
</Locale>
+141
View File
@@ -1,6 +1,85 @@
import { AgentItem, LobeAgentConfig, MetaData } from '@lobechat/types';
import type { PartialDeep } from 'type-fest';
import { lambdaClient } from '@/libs/trpc/client';
/**
* Market agent model can be either a string or an object with model details
*/
type MarketAgentModel =
| LobeAgentConfig['model']
| {
model: LobeAgentConfig['model'];
parameters?: Partial<LobeAgentConfig['params']>;
provider?: LobeAgentConfig['provider'];
};
/**
* Normalize market agent config to standard agent config.
* Handles the case where market returns model as an object instead of string.
*/
const normalizeMarketAgentModel = (config?: PartialDeep<AgentItem>): PartialDeep<AgentItem> => {
if (!config) return {};
const model = config.model as MarketAgentModel | undefined;
// If model is not an object, return config as-is
if (typeof model !== 'object' || model === null) {
return config;
}
// Extract model info and merge parameters
const { model: modelName, provider: modelProvider, parameters } = model;
const existingParams = (config.params ?? {}) as Record<string, any>;
const mergedParams = { ...parameters, ...existingParams };
return {
...config,
model: modelName,
params: Object.keys(mergedParams).length > 0 ? mergedParams : undefined,
provider: config.provider ?? modelProvider,
};
};
export interface CreateAgentParams {
config?: PartialDeep<AgentItem>;
groupId?: string;
}
export interface CreateAgentResult {
agentId?: string;
sessionId: string;
}
class AgentService {
/**
* Check if an agent with the given marketIdentifier already exists
*/
checkByMarketIdentifier = async (marketIdentifier: string): Promise<boolean> => {
return lambdaClient.agent.checkByMarketIdentifier.query({ marketIdentifier });
};
/**
* Get an agent by marketIdentifier
* @returns agent id if exists, null otherwise
*/
getAgentByMarketIdentifier = async (marketIdentifier: string): Promise<string | null> => {
return lambdaClient.agent.getAgentByMarketIdentifier.query({ marketIdentifier });
};
/**
* Create a new agent with session.
* Automatically normalizes market agent config (handles model as object).
*/
createAgent = async (params: CreateAgentParams): Promise<CreateAgentResult> => {
const normalizedConfig = normalizeMarketAgentModel(params.config);
return lambdaClient.agent.createAgent.mutate({
config: normalizedConfig as any,
groupId: params.groupId,
});
};
createAgentKnowledgeBase = async (
agentId: string,
knowledgeBaseId: string,
@@ -44,6 +123,68 @@ class AgentService {
getFilesAndKnowledgeBases = async (agentId: string) => {
return lambdaClient.agent.getKnowledgeBasesAndFiles.query({ agentId });
};
getAgentConfigById = async (agentId: string) => {
return lambdaClient.agent.getAgentConfigById.query({ agentId });
};
/**
* @deprecated use getAgentConfigById instead
*/
getSessionConfig = async (sessionId: string) => {
return lambdaClient.agent.getAgentConfig.query({ sessionId });
};
/**
* Update agent config and return the updated agent data
*/
updateAgentConfig = async (
agentId: string,
config: PartialDeep<LobeAgentConfig>,
signal?: AbortSignal,
) => {
return lambdaClient.agent.updateAgentConfig.mutate(
{ agentId, value: config },
{ context: { showNotification: false }, signal },
);
};
/**
* Update agent meta and return the updated agent data
*/
updateAgentMeta = async (agentId: string, meta: Partial<MetaData>, signal?: AbortSignal) => {
return lambdaClient.agent.updateAgentConfig.mutate({ agentId, value: meta }, { signal });
};
/**
* Get a builtin agent by slug, creating it if it doesn't exist.
* This is a generic interface for all builtin agents (page-copilot, inbox, etc.)
*/
getBuiltinAgent = async (slug: string) => {
return lambdaClient.agent.getBuiltinAgent.query({ slug });
};
/**
* Remove an agent and its associated session
*/
removeAgent = async (agentId: string) => {
return lambdaClient.agent.removeAgent.mutate({ agentId });
};
/**
* Query non-virtual agents with optional keyword filter.
* Returns agents with minimal info (id, title, description, avatar, backgroundColor).
*/
queryAgents = async (params?: { keyword?: string; limit?: number; offset?: number }) => {
return lambdaClient.agent.queryAgents.query(params);
};
/**
* Pin or unpin an agent
*/
updateAgentPinned = async (agentId: string, pinned: boolean) => {
return lambdaClient.agent.updateAgentPinned.mutate({ id: agentId, pinned });
};
}
export const agentService = new AgentService();
+107 -51
View File
@@ -9,7 +9,7 @@ import { type Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vite
import { DEFAULT_USER_AVATAR } from '@/const/meta';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
import * as toolEngineeringModule from '@/helpers/toolEngineering';
import { agentChatConfigSelectors } from '@/store/agent/selectors';
import { agentSelectors, chatConfigByIdSelectors } from '@/store/agent/selectors';
import { aiModelSelectors } from '@/store/aiInfra';
import { useToolStore } from '@/store/tool';
import { WebBrowsingManifest } from '@/tools/web-browsing';
@@ -58,6 +58,15 @@ beforeEach(async () => {
isDeprecatedEdition: true,
isDesktop: false,
}));
// Default mock for agentSelectors - resolveAgentConfig needs these
vi.spyOn(agentSelectors, 'getAgentConfigById').mockReturnValue(
() => ({ plugins: [], systemRole: '' }) as any,
);
vi.spyOn(agentSelectors, 'getAgentSlugById').mockReturnValue(() => undefined);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() => ({ searchMode: 'off' }) as any,
);
});
// mock auth
@@ -128,11 +137,14 @@ describe('ChatService', () => {
vi.spyOn(aiModelSelectors, 'modelExtendParams').mockReturnValue(() => ['enableReasoning']);
// Mock agent chat config with reasoning enabled
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
enableReasoning: true,
reasoningBudgetToken: 2048,
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableReasoning: true,
reasoningBudgetToken: 2048,
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
@@ -161,10 +173,13 @@ describe('ChatService', () => {
vi.spyOn(aiModelSelectors, 'modelExtendParams').mockReturnValue(() => ['enableReasoning']);
// Mock agent chat config with reasoning disabled
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
enableReasoning: false,
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableReasoning: false,
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
@@ -193,11 +208,14 @@ describe('ChatService', () => {
vi.spyOn(aiModelSelectors, 'modelExtendParams').mockReturnValue(() => ['enableReasoning']);
// Mock agent chat config with reasoning enabled but no custom budget
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
enableReasoning: true,
// reasoningBudgetToken is undefined
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableReasoning: true,
// reasoningBudgetToken is undefined
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
@@ -226,10 +244,13 @@ describe('ChatService', () => {
vi.spyOn(aiModelSelectors, 'modelExtendParams').mockReturnValue(() => ['reasoningEffort']);
// Mock agent chat config with reasoning effort set
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
reasoningEffort: 'high',
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
reasoningEffort: 'high',
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
@@ -255,10 +276,13 @@ describe('ChatService', () => {
vi.spyOn(aiModelSelectors, 'modelExtendParams').mockReturnValue(() => ['thinkingBudget']);
// Mock agent chat config with thinking budget set
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
thinkingBudget: 5000,
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
thinkingBudget: 5000,
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
@@ -339,6 +363,7 @@ describe('ChatService', () => {
],
model: 'gpt-4-vision-preview',
provider: 'openai',
stream: true,
enabledSearch: undefined,
tools: undefined,
},
@@ -362,6 +387,7 @@ describe('ChatService', () => {
{ content: 'Hello', role: 'user' },
{ content: 'Hey', role: 'assistant' },
],
stream: true,
tools: undefined,
},
undefined,
@@ -457,6 +483,7 @@ describe('ChatService', () => {
},
],
model: 'gpt-4-vision-preview',
stream: true,
enabledSearch: undefined,
tools: undefined,
},
@@ -544,6 +571,7 @@ describe('ChatService', () => {
},
],
model: 'gpt-4-vision-preview',
stream: true,
enabledSearch: undefined,
tools: undefined,
},
@@ -713,7 +741,9 @@ describe('ChatService', () => {
expect(getChatCompletionSpy).toHaveBeenCalledWith(
{
enabledSearch: undefined,
model: 'gpt-3.5-turbo-1106',
stream: true,
top_p: 1,
tools: [
{
@@ -813,7 +843,9 @@ describe('ChatService', () => {
expect(getChatCompletionSpy).toHaveBeenCalledWith(
{
enabledSearch: undefined,
model: 'gpt-3.5-turbo-1106',
stream: true,
top_p: 1,
tools: [
{
@@ -867,7 +899,10 @@ describe('ChatService', () => {
expect(getChatCompletionSpy).toHaveBeenCalledWith(
{
enabledSearch: undefined,
model: 'gpt-3.5-turbo-1106',
stream: true,
tools: undefined,
top_p: 1,
messages: [
{
@@ -889,10 +924,13 @@ describe('ChatService', () => {
const messages = [{ content: 'Search for something', role: 'user' }] as UIChatMessage[];
// Mock agent store state with search enabled
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValueOnce({
searchMode: 'auto', // not 'off'
useModelBuiltinSearch: false,
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
searchMode: 'auto', // not 'off'
useModelBuiltinSearch: false,
}) as any,
);
// Mock AI infra store state
vi.spyOn(aiModelSelectors, 'isModelHasBuiltinSearch').mockReturnValueOnce(() => false);
@@ -940,10 +978,13 @@ describe('ChatService', () => {
const messages = [{ content: 'Search for something', role: 'user' }] as UIChatMessage[];
// Mock agent store state with search enabled and useModelBuiltinSearch enabled
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValueOnce({
searchMode: 'auto', // not 'off'
useModelBuiltinSearch: true,
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
searchMode: 'auto', // not 'off'
useModelBuiltinSearch: true,
}) as any,
);
// Mock AI infra store state - model has built-in search
vi.spyOn(aiModelSelectors, 'isModelHasBuiltinSearch').mockReturnValueOnce(() => true);
@@ -985,10 +1026,13 @@ describe('ChatService', () => {
const messages = [{ content: 'Search for something', role: 'user' }] as UIChatMessage[];
// Mock agent store state with search disabled
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValueOnce({
searchMode: 'off',
useModelBuiltinSearch: true,
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
searchMode: 'off',
useModelBuiltinSearch: true,
}) as any,
);
// Mock AI infra store state
vi.spyOn(aiModelSelectors, 'isModelHasBuiltinSearch').mockReturnValueOnce(() => true);
@@ -1305,10 +1349,13 @@ describe('ChatService private methods', () => {
]);
// Mock agent chat config with context caching disabled
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
disableContextCaching: true,
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
disableContextCaching: true,
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
@@ -1338,10 +1385,13 @@ describe('ChatService private methods', () => {
]);
// Mock agent chat config with context caching enabled (default)
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
disableContextCaching: false,
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
disableContextCaching: false,
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
@@ -1364,10 +1414,13 @@ describe('ChatService private methods', () => {
vi.spyOn(aiModelSelectors, 'modelExtendParams').mockReturnValue(() => ['reasoningEffort']);
// Mock agent chat config with reasoning effort set
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
reasoningEffort: 'high',
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
reasoningEffort: 'high',
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
@@ -1393,10 +1446,13 @@ describe('ChatService private methods', () => {
vi.spyOn(aiModelSelectors, 'modelExtendParams').mockReturnValue(() => ['thinkingBudget']);
// Mock agent chat config with thinking budget set
vi.spyOn(agentChatConfigSelectors, 'currentChatConfig').mockReturnValue({
thinkingBudget: 5000,
searchMode: 'off',
} as any);
vi.spyOn(chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
thinkingBudget: 5000,
searchMode: 'off',
}) as any,
);
await chatService.createAssistantMessage({
messages,
-136
View File
@@ -1,136 +0,0 @@
import { isDesktop } from '@lobechat/const';
import {
ContextEngine,
GroupMessageFlattenProcessor,
HistorySummaryProvider,
HistoryTruncateProcessor,
InputTemplateProcessor,
KnowledgeInjector,
MessageCleanupProcessor,
MessageContentProcessor,
PlaceholderVariablesProcessor,
SystemRoleInjector,
ToolCallProcessor,
ToolMessageReorder,
ToolNameResolver,
ToolSystemRoleProvider,
} from '@lobechat/context-engine';
import { historySummaryPrompt } from '@lobechat/prompts';
import { OpenAIChatMessage, UIChatMessage } from '@lobechat/types';
import { VARIABLE_GENERATORS } from '@lobechat/utils/client';
import { isCanUseFC } from '@/helpers/isCanUseFC';
import { getAgentStoreState } from '@/store/agent';
import { agentSelectors } from '@/store/agent/selectors';
import { getToolStoreState } from '@/store/tool';
import { toolSelectors } from '@/store/tool/selectors';
import { isCanUseVideo, isCanUseVision } from './helper';
interface ContextEngineeringContext {
enableHistoryCount?: boolean;
historyCount?: number;
historySummary?: string;
inputTemplate?: string;
messages: UIChatMessage[];
model: string;
provider: string;
sessionId?: string;
systemRole?: string;
tools?: string[];
}
export const contextEngineering = async ({
messages = [],
tools,
model,
provider,
systemRole,
inputTemplate,
enableHistoryCount,
historyCount,
historySummary,
}: ContextEngineeringContext): Promise<OpenAIChatMessage[]> => {
const toolNameResolver = new ToolNameResolver();
// Get enabled agent files with content and knowledge bases from agent store
const agentStoreState = getAgentStoreState();
const agentFiles = agentSelectors.currentAgentFiles(agentStoreState);
const agentKnowledgeBases = agentSelectors.currentAgentKnowledgeBases(agentStoreState);
const fileContents = agentFiles
.filter((file) => file.enabled && file.content)
.map((file) => ({ content: file.content!, fileId: file.id, filename: file.name }));
const knowledgeBases = agentKnowledgeBases
.filter((kb) => kb.enabled)
.map((kb) => ({ description: kb.description, id: kb.id, name: kb.name }));
const pipeline = new ContextEngine({
pipeline: [
// 1. History truncation (MUST be first, before any message injection)
new HistoryTruncateProcessor({ enableHistoryCount, historyCount }),
// --------- Create system role injection providers
// 2. System role injection (agent's system role)
new SystemRoleInjector({ systemRole }),
// 3. Knowledge injection (full content for agent files + metadata for knowledge bases)
new KnowledgeInjector({ fileContents, knowledgeBases }),
// 4. Tool system role injection
new ToolSystemRoleProvider({
getToolSystemRoles: (tools) => toolSelectors.enabledSystemRoles(tools)(getToolStoreState()),
isCanUseFC,
model,
provider,
tools,
}),
// 5. History summary injection
new HistorySummaryProvider({
formatHistorySummary: historySummaryPrompt,
historySummary: historySummary,
}),
// Create message processing processors
// 6. Input template processing
new InputTemplateProcessor({ inputTemplate }),
// 7. Placeholder variables processing
new PlaceholderVariablesProcessor({ variableGenerators: VARIABLE_GENERATORS }),
// 8. Group message flatten (convert role=group to standard assistant + tool messages)
new GroupMessageFlattenProcessor(),
// 8.5 Message content processing
new MessageContentProcessor({
fileContext: { enabled: true, includeFileUrl: !isDesktop },
isCanUseVideo,
isCanUseVision,
model,
provider,
}),
// 9. Tool call processing
new ToolCallProcessor({
genToolCallingName: toolNameResolver.generate.bind(toolNameResolver),
isCanUseFC,
model,
provider,
}),
// 10. Tool message reordering
new ToolMessageReorder(),
// 11. Message cleanup (final step, keep only necessary fields)
new MessageCleanupProcessor(),
],
});
const result = await pipeline.process({ messages });
return result.messages;
};
+144 -124
View File
@@ -1,3 +1,7 @@
import { AgentBuilderIdentifier } from '@lobechat/builtin-tool-agent-builder';
import { MemoryManifest } from '@lobechat/builtin-tool-memory';
import { KLAVIS_SERVER_TYPES } from '@lobechat/const';
import type { OfficialToolItem } from '@lobechat/context-engine';
import {
FetchSSEOptions,
fetchSSE,
@@ -7,40 +11,50 @@ import {
import { AgentRuntimeError, ChatCompletionErrorPayload } from '@lobechat/model-runtime';
import { ChatErrorType, TracePayload, TraceTagMap, UIChatMessage } from '@lobechat/types';
import { PluginRequestPayload, createHeadersWithPluginSettings } from '@lobehub/chat-plugin-sdk';
import { merge } from 'lodash-es';
import { merge } from 'es-toolkit/compat';
import { ModelProvider } from 'model-bank';
import { enableAuth } from '@/const/auth';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
import { isDesktop } from '@/const/version';
import { getSearchConfig } from '@/helpers/getSearchConfig';
import { createAgentToolsEngine, createToolsEngine } from '@/helpers/toolEngineering';
import { getAgentStoreState } from '@/store/agent';
import { agentChatConfigSelectors, agentSelectors } from '@/store/agent/selectors';
import { aiModelSelectors, aiProviderSelectors, getAiInfraStoreState } from '@/store/aiInfra';
import { getSessionStoreState } from '@/store/session';
import { sessionMetaSelectors } from '@/store/session/selectors';
import { getToolStoreState } from '@/store/tool';
import { pluginSelectors } from '@/store/tool/selectors';
import { getUserStoreState, useUserStore } from '@/store/user';
import {
preferenceSelectors,
userGeneralSettingsSelectors,
userProfileSelectors,
} from '@/store/user/selectors';
agentByIdSelectors,
agentChatConfigSelectors,
agentSelectors,
chatConfigByIdSelectors,
} from '@/store/agent/selectors';
import { aiProviderSelectors, getAiInfraStoreState } from '@/store/aiInfra';
import { getChatStoreState } from '@/store/chat';
import { getToolStoreState } from '@/store/tool';
import {
builtinToolSelectors,
klavisStoreSelectors,
pluginSelectors,
} from '@/store/tool/selectors';
import { getUserStoreState, useUserStore } from '@/store/user';
import { userGeneralSettingsSelectors, userProfileSelectors } from '@/store/user/selectors';
import type { ChatStreamPayload, OpenAIChatMessage } from '@/types/openai/chat';
import { fetchWithInvokeStream } from '@/utils/electron/desktopRemoteRPCFetch';
import { createErrorResponse } from '@/utils/errorResponse';
import { createTraceHeader, getTraceId } from '@/utils/trace';
import { createHeaderWithAuth } from '../_auth';
import { API_ENDPOINTS } from '../_url';
import { initializeWithClientStore } from './clientModelRuntime';
import { contextEngineering } from './contextEngineering';
import { findDeploymentName, isEnableFetchOnClient, resolveRuntimeProvider } from './helper';
import {
contextEngineering,
getTargetAgentId,
initializeWithClientStore,
resolveAgentConfig,
resolveModelExtendParams,
resolveUserMemories,
} from './mecha';
import { FetchOptions } from './types';
interface GetChatCompletionPayload extends Partial<Omit<ChatStreamPayload, 'messages'>> {
agentId?: string;
groupId?: string;
messages: UIChatMessage[];
}
@@ -72,7 +86,7 @@ interface CreateAssistantMessageStream extends FetchSSEOptions {
class ChatService {
createAssistantMessage = async (
{ plugins: enabledPlugins, messages, ...params }: GetChatCompletionPayload,
{ plugins: enabledPlugins, messages, agentId, groupId, ...params }: GetChatCompletionPayload,
options?: FetchOptions,
) => {
const payload = merge(
@@ -84,11 +98,25 @@ class ChatService {
params,
);
const searchConfig = getSearchConfig(payload.model, payload.provider!);
// =================== 1. resolve agent config =================== //
// =================== 1. preprocess tools =================== //
const targetAgentId = getTargetAgentId(agentId);
const pluginIds = [...(enabledPlugins || [])];
// Resolve agent config with builtin agent runtime config merged
// plugins is already merged (runtime plugins > agent config plugins)
const {
agentConfig,
chatConfig,
plugins: pluginIds,
} = resolveAgentConfig({
agentId: targetAgentId,
model: payload.model,
plugins: enabledPlugins,
provider: payload.provider,
});
// Get search config with agentId for agent-specific settings
const searchConfig = getSearchConfig(payload.model, payload.provider!, targetAgentId);
const toolsEngine = createAgentToolsEngine({
model: payload.model,
@@ -101,17 +129,93 @@ class ChatService {
toolIds: pluginIds,
});
// ============ 2. preprocess messages ============ //
// =================== 1.1 process user memories =================== //
const agentStoreState = getAgentStoreState();
const agentConfig = agentSelectors.currentAgentConfig(agentStoreState);
const chatConfig = agentChatConfigSelectors.currentChatConfig(agentStoreState);
const isMemoryPluginEnabled =
pluginIds.includes(MemoryManifest.identifier) ||
enabledToolIds.includes(MemoryManifest.identifier);
const userMemories = await resolveUserMemories({
isMemoryPluginEnabled,
messages,
});
// =================== 1.2 build agent builder context =================== //
// Check if Agent Builder tool is enabled and build context for it
// Note: When Agent Builder is active, we need to get the context of the agent being edited,
// which is stored in chatStore.activeAgentId, not the targetAgentId (which is the Agent Builder itself)
const isAgentBuilderEnabled = enabledToolIds.includes(AgentBuilderIdentifier);
let agentBuilderContext;
if (isAgentBuilderEnabled) {
const activeAgentId = getChatStoreState().activeAgentId || '';
const baseContext =
agentByIdSelectors.getAgentBuilderContextById(activeAgentId)(getAgentStoreState());
// Build official tools list (builtin tools + Klavis tools)
const toolState = getToolStoreState();
const enabledPlugins =
agentSelectors.getAgentConfigById(activeAgentId)(getAgentStoreState()).plugins || [];
const officialTools: OfficialToolItem[] = [];
// Get builtin tools (excluding Klavis tools)
const builtinTools = builtinToolSelectors.metaList(toolState);
const klavisIdentifiers = new Set(KLAVIS_SERVER_TYPES.map((t) => t.identifier));
for (const tool of builtinTools) {
// Skip Klavis tools in builtin list (they'll be shown separately)
if (klavisIdentifiers.has(tool.identifier)) continue;
officialTools.push({
description: tool.meta?.description,
enabled: enabledPlugins.includes(tool.identifier),
identifier: tool.identifier,
installed: true,
name: tool.meta?.title || tool.identifier,
type: 'builtin',
});
}
// Get Klavis tools (if enabled)
const isKlavisEnabled =
typeof window !== 'undefined' &&
window.global_serverConfigStore?.getState()?.serverConfig?.enableKlavis;
if (isKlavisEnabled) {
const allKlavisServers = klavisStoreSelectors.getServers(toolState);
for (const klavisType of KLAVIS_SERVER_TYPES) {
const server = allKlavisServers.find((s) => s.identifier === klavisType.identifier);
officialTools.push({
description: `LobeHub Mcp Server: ${klavisType.label}`,
enabled: enabledPlugins.includes(klavisType.identifier),
identifier: klavisType.identifier,
installed: !!server,
name: klavisType.label,
type: 'klavis',
});
}
}
agentBuilderContext = {
...baseContext,
officialTools,
};
}
// Apply context engineering with preprocessing configuration
const oaiMessages = await contextEngineering({
enableHistoryCount: agentChatConfigSelectors.enableHistoryCount(agentStoreState),
// include user messages
historyCount: agentChatConfigSelectors.historyCount(agentStoreState) + 2,
// Note: agentConfig.systemRole is already resolved by resolveAgentConfig for builtin agents
const modelMessages = await contextEngineering({
agentBuilderContext,
agentId: targetAgentId,
enableHistoryCount:
chatConfigByIdSelectors.getEnableHistoryCountById(targetAgentId)(getAgentStoreState()),
groupId,
historyCount:
chatConfigByIdSelectors.getHistoryCountById(targetAgentId)(getAgentStoreState()) + 2,
inputTemplate: chatConfig.inputTemplate,
messages,
model: payload.model,
@@ -119,106 +223,25 @@ class ChatService {
sessionId: options?.trace?.sessionId,
systemRole: agentConfig.systemRole,
tools: enabledToolIds,
userMemories,
});
// ============ 3. process extend params ============ //
let extendParams: Record<string, any> = {};
const aiInfraStoreState = getAiInfraStoreState();
const isModelHasExtendParams = aiModelSelectors.isModelHasExtendParams(
payload.model,
payload.provider!,
)(aiInfraStoreState);
// model
if (isModelHasExtendParams) {
const modelExtendParams = aiModelSelectors.modelExtendParams(
payload.model,
payload.provider!,
)(aiInfraStoreState);
// if model has extended params, then we need to check if the model can use reasoning
if (modelExtendParams!.includes('enableReasoning')) {
if (chatConfig.enableReasoning) {
extendParams.thinking = {
budget_tokens: chatConfig.reasoningBudgetToken || 1024,
type: 'enabled',
};
} else {
extendParams.thinking = {
budget_tokens: 0,
type: 'disabled',
};
}
} else if (modelExtendParams!.includes('reasoningBudgetToken')) {
// For models that only have reasoningBudgetToken without enableReasoning
extendParams.thinking = {
budget_tokens: chatConfig.reasoningBudgetToken || 1024,
type: 'enabled',
};
}
if (
modelExtendParams!.includes('disableContextCaching') &&
chatConfig.disableContextCaching
) {
extendParams.enabledContextCaching = false;
}
if (modelExtendParams!.includes('reasoningEffort') && chatConfig.reasoningEffort) {
extendParams.reasoning_effort = chatConfig.reasoningEffort;
}
if (modelExtendParams!.includes('gpt5ReasoningEffort') && chatConfig.gpt5ReasoningEffort) {
extendParams.reasoning_effort = chatConfig.gpt5ReasoningEffort;
}
if (
modelExtendParams!.includes('gpt5_1ReasoningEffort') &&
chatConfig.gpt5_1ReasoningEffort
) {
extendParams.reasoning_effort = chatConfig.gpt5_1ReasoningEffort;
}
if (modelExtendParams!.includes('textVerbosity') && chatConfig.textVerbosity) {
extendParams.verbosity = chatConfig.textVerbosity;
}
if (modelExtendParams!.includes('thinking') && chatConfig.thinking) {
extendParams.thinking = { type: chatConfig.thinking };
}
if (
modelExtendParams!.includes('thinkingBudget') &&
chatConfig.thinkingBudget !== undefined
) {
extendParams.thinkingBudget = chatConfig.thinkingBudget;
}
if (modelExtendParams!.includes('thinkingLevel') && chatConfig.thinkingLevel) {
extendParams.thinkingLevel = chatConfig.thinkingLevel;
}
if (modelExtendParams!.includes('urlContext') && chatConfig.urlContext) {
extendParams.urlContext = chatConfig.urlContext;
}
if (modelExtendParams!.includes('imageAspectRatio') && chatConfig.imageAspectRatio) {
extendParams.imageAspectRatio = chatConfig.imageAspectRatio;
}
if (modelExtendParams!.includes('imageResolution') && chatConfig.imageResolution) {
extendParams.imageResolution = chatConfig.imageResolution;
}
}
const extendParams = resolveModelExtendParams({
chatConfig,
model: payload.model,
provider: payload.provider!,
});
return this.getChatCompletion(
{
...params,
...extendParams,
enabledSearch: searchConfig.enabledSearch && searchConfig.useModelSearch ? true : undefined,
messages: oaiMessages,
messages: modelMessages,
// Use the chatConfig from the target agent for streaming preference
stream: chatConfig.enableStreaming !== false,
tools,
},
options,
@@ -303,10 +326,7 @@ class ChatService {
let fetcher: typeof fetch | undefined = undefined;
// Add desktop remote RPC fetch support
if (isDesktop) {
fetcher = fetchWithInvokeStream;
} else if (enableFetchOnClient) {
if (enableFetchOnClient) {
/**
* Notes:
* 1. Browser agent runtime will skip auth check if a key and endpoint provided by
@@ -457,9 +477,9 @@ class ChatService {
};
private mapTrace = (trace?: TracePayload, tag?: TraceTagMap): TracePayload => {
const tags = sessionMetaSelectors.currentAgentMeta(getSessionStoreState()).tags || [];
const tags = agentSelectors.currentAgentMeta(getAgentStoreState()).tags || [];
const enabled = preferenceSelectors.userAllowTrace(getUserStoreState());
const enabled = userGeneralSettingsSelectors.telemetry(getUserStoreState());
if (!enabled) return { ...trace, enabled: false };
@@ -0,0 +1,407 @@
import * as builtinAgents from '@lobechat/builtin-agents';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import * as agentStore from '@/store/agent';
import * as agentSelectors from '@/store/agent/selectors';
import { resolveAgentConfig } from './agentConfigResolver';
describe('resolveAgentConfig', () => {
const mockAgentStoreState = { someState: true };
const mockAgentConfig = {
model: 'gpt-4',
plugins: ['plugin-a', 'plugin-b'],
systemRole: 'You are a helpful assistant',
};
const mockChatConfig = {
enableStreaming: true,
};
beforeEach(() => {
vi.restoreAllMocks();
vi.spyOn(agentStore, 'getAgentStoreState').mockReturnValue(mockAgentStoreState as any);
vi.spyOn(agentSelectors.agentSelectors, 'getAgentConfigById').mockReturnValue(
() => mockAgentConfig as any,
);
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() => mockChatConfig as any,
);
});
describe('regular agent (non-builtin)', () => {
beforeEach(() => {
// No slug means regular agent
vi.spyOn(agentSelectors.agentSelectors, 'getAgentSlugById').mockReturnValue(() => undefined);
});
it('should return plugins from agent config', () => {
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.plugins).toEqual(['plugin-a', 'plugin-b']);
expect(result.isBuiltinAgent).toBe(false);
});
it('should return empty array when agent config has no plugins', () => {
vi.spyOn(agentSelectors.agentSelectors, 'getAgentConfigById').mockReturnValue(
() =>
({
...mockAgentConfig,
plugins: undefined,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.plugins).toEqual([]);
expect(result.isBuiltinAgent).toBe(false);
});
it('should return empty array when agent config plugins is null', () => {
vi.spyOn(agentSelectors.agentSelectors, 'getAgentConfigById').mockReturnValue(
() =>
({
...mockAgentConfig,
plugins: null as any,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.plugins).toEqual([]);
});
it('should return agent config and chat config correctly', () => {
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.agentConfig).toEqual(mockAgentConfig);
expect(result.chatConfig).toEqual(mockChatConfig);
});
describe('params adjustment based on chatConfig', () => {
const mockAgentConfigWithParams = {
model: 'gpt-4',
params: {
max_tokens: 4096,
reasoning_effort: 'high',
temperature: 0.7,
},
plugins: ['plugin-a'],
systemRole: 'You are a helpful assistant',
};
beforeEach(() => {
vi.spyOn(agentSelectors.agentSelectors, 'getAgentConfigById').mockReturnValue(
() => mockAgentConfigWithParams as any,
);
});
it('should include max_tokens when enableMaxTokens is true', () => {
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableMaxTokens: true,
enableReasoningEffort: false,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.agentConfig.params.max_tokens).toBe(4096);
expect(result.agentConfig.params.reasoning_effort).toBeUndefined();
expect(result.agentConfig.params.temperature).toBe(0.7);
});
it('should set max_tokens to undefined when enableMaxTokens is false', () => {
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableMaxTokens: false,
enableReasoningEffort: true,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.agentConfig.params.max_tokens).toBeUndefined();
expect(result.agentConfig.params.reasoning_effort).toBe('high');
});
it('should include reasoning_effort when enableReasoningEffort is true', () => {
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableMaxTokens: false,
enableReasoningEffort: true,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.agentConfig.params.reasoning_effort).toBe('high');
});
it('should set reasoning_effort to undefined when enableReasoningEffort is false', () => {
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableMaxTokens: true,
enableReasoningEffort: false,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.agentConfig.params.reasoning_effort).toBeUndefined();
});
it('should handle both params being enabled', () => {
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableMaxTokens: true,
enableReasoningEffort: true,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.agentConfig.params.max_tokens).toBe(4096);
expect(result.agentConfig.params.reasoning_effort).toBe('high');
});
it('should handle both params being disabled', () => {
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableMaxTokens: false,
enableReasoningEffort: false,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.agentConfig.params.max_tokens).toBeUndefined();
expect(result.agentConfig.params.reasoning_effort).toBeUndefined();
});
it('should not mutate original agent config', () => {
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableMaxTokens: false,
enableReasoningEffort: false,
}) as any,
);
resolveAgentConfig({ agentId: 'test-agent' });
// Original should be unchanged
expect(mockAgentConfigWithParams.params.max_tokens).toBe(4096);
expect(mockAgentConfigWithParams.params.reasoning_effort).toBe('high');
});
it('should skip params adjustment when params is undefined', () => {
vi.spyOn(agentSelectors.agentSelectors, 'getAgentConfigById').mockReturnValue(
() =>
({
model: 'gpt-4',
plugins: ['plugin-a'],
systemRole: 'You are a helpful assistant',
}) as any,
);
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableMaxTokens: true,
enableReasoningEffort: true,
}) as any,
);
const result = resolveAgentConfig({ agentId: 'test-agent' });
expect(result.agentConfig.params).toBeUndefined();
});
});
});
describe('builtin agent', () => {
beforeEach(() => {
// Has slug means builtin agent
vi.spyOn(agentSelectors.agentSelectors, 'getAgentSlugById').mockReturnValue(
() => 'agent-builder',
);
});
it('should use runtime plugins when available', () => {
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
plugins: ['runtime-plugin-1', 'runtime-plugin-2'],
systemRole: 'Runtime system role',
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.plugins).toEqual(['runtime-plugin-1', 'runtime-plugin-2']);
expect(result.isBuiltinAgent).toBe(true);
expect(result.slug).toBe('agent-builder');
});
it('should fallback to agent config plugins when runtime plugins is undefined', () => {
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
plugins: undefined,
systemRole: 'Runtime system role',
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.plugins).toEqual(['plugin-a', 'plugin-b']);
expect(result.isBuiltinAgent).toBe(true);
});
it('should fallback to agent config plugins when runtime plugins is empty array', () => {
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
plugins: [],
systemRole: 'Runtime system role',
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.plugins).toEqual(['plugin-a', 'plugin-b']);
expect(result.isBuiltinAgent).toBe(true);
});
it('should fallback to agent config plugins when runtimeConfig is undefined', () => {
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue(undefined);
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.plugins).toEqual(['plugin-a', 'plugin-b']);
expect(result.isBuiltinAgent).toBe(true);
});
it('should use runtime systemRole when available', () => {
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
plugins: ['runtime-plugin'],
systemRole: 'Runtime system role',
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.agentConfig.systemRole).toBe('Runtime system role');
});
it('should fallback to agent config systemRole when runtime systemRole is undefined', () => {
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
plugins: ['runtime-plugin'],
systemRole: undefined as any,
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.agentConfig.systemRole).toBe('You are a helpful assistant');
});
it('should return empty plugins when both runtime and agent config have no plugins', () => {
vi.spyOn(agentSelectors.agentSelectors, 'getAgentConfigById').mockReturnValue(
() =>
({
...mockAgentConfig,
plugins: undefined,
}) as any,
);
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
plugins: undefined,
systemRole: 'Runtime system role',
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.plugins).toEqual([]);
});
it('should pass context parameters to getAgentRuntimeConfig', () => {
const getAgentRuntimeConfigSpy = vi
.spyOn(builtinAgents, 'getAgentRuntimeConfig')
.mockReturnValue({
plugins: ['runtime-plugin'],
systemRole: 'Runtime system role',
});
const targetAgentConfig = { model: 'target-model' };
resolveAgentConfig({
agentId: 'builtin-agent',
documentContent: 'some document content',
model: 'gpt-4-turbo',
plugins: ['input-plugin'],
targetAgentConfig: targetAgentConfig as any,
});
expect(getAgentRuntimeConfigSpy).toHaveBeenCalledWith('agent-builder', {
documentContent: 'some document content',
model: 'gpt-4-turbo',
plugins: ['input-plugin'],
targetAgentConfig,
});
});
it('should merge runtime chatConfig with base chatConfig', () => {
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
chatConfig: {
enableHistoryCount: false,
historyCount: 10,
},
plugins: ['runtime-plugin'],
systemRole: 'Runtime system role',
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
// Base chatConfig has enableStreaming: true
// Runtime chatConfig adds enableHistoryCount: false and historyCount: 10
expect(result.chatConfig).toEqual({
enableHistoryCount: false,
enableStreaming: true,
historyCount: 10,
});
});
it('should override base chatConfig values with runtime chatConfig', () => {
vi.spyOn(agentSelectors.chatConfigByIdSelectors, 'getChatConfigById').mockReturnValue(
() =>
({
enableHistoryCount: true,
enableStreaming: true,
historyCount: 20,
}) as any,
);
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
chatConfig: {
enableHistoryCount: false,
},
plugins: ['runtime-plugin'],
systemRole: 'Runtime system role',
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.chatConfig).toEqual({
enableHistoryCount: false,
enableStreaming: true,
historyCount: 20,
});
});
it('should use base chatConfig when runtime chatConfig is undefined', () => {
vi.spyOn(builtinAgents, 'getAgentRuntimeConfig').mockReturnValue({
chatConfig: undefined,
plugins: ['runtime-plugin'],
systemRole: 'Runtime system role',
});
const result = resolveAgentConfig({ agentId: 'builtin-agent' });
expect(result.chatConfig).toEqual(mockChatConfig);
});
});
});
@@ -0,0 +1,181 @@
import { BUILTIN_AGENT_SLUGS, getAgentRuntimeConfig } from '@lobechat/builtin-agents';
import { LobeAgentChatConfig, LobeAgentConfig } from '@lobechat/types';
import { produce } from 'immer';
import { getAgentStoreState } from '@/store/agent';
import { agentSelectors, chatConfigByIdSelectors } from '@/store/agent/selectors';
import { getChatGroupStoreState } from '@/store/agentGroup';
import { agentGroupSelectors } from '@/store/agentGroup/selectors';
/**
* Applies params adjustments based on chatConfig settings.
*
* This function handles the conditional enabling/disabling of certain params:
* - max_tokens: Only included if chatConfig.enableMaxTokens is true
* - reasoning_effort: Only included if chatConfig.enableReasoningEffort is true
*
* Uses immer to create a new object without mutating the original.
*/
const applyParamsFromChatConfig = (
agentConfig: LobeAgentConfig,
chatConfig: LobeAgentChatConfig,
): LobeAgentConfig => {
// If params is not defined, return agentConfig as-is
if (!agentConfig.params) {
return agentConfig;
}
return produce(agentConfig, (draft) => {
// Only include max_tokens if enableMaxTokens is true
draft.params.max_tokens = chatConfig.enableMaxTokens ? draft.params.max_tokens : undefined;
// Only include reasoning_effort if enableReasoningEffort is true
draft.params.reasoning_effort = chatConfig.enableReasoningEffort
? draft.params.reasoning_effort
: undefined;
});
};
/**
* Runtime context for resolving agent config
*/
export interface AgentConfigResolverContext {
/** Agent ID to resolve config for */
agentId: string;
// Builtin agent specific context
/** Document content for page-agent */
documentContent?: string;
/** Current model being used (for template variables) */
model?: string;
/** Plugins enabled for the agent */
plugins?: string[];
/** Current provider */
provider?: string;
/** Target agent config for agent-builder */
targetAgentConfig?: LobeAgentConfig;
}
/**
* Resolved agent config with runtime values merged
*/
export interface ResolvedAgentConfig {
/** The resolved agent config */
agentConfig: LobeAgentConfig;
/** The chat config */
chatConfig: LobeAgentChatConfig;
/** Whether this is a builtin agent */
isBuiltinAgent: boolean;
/**
* Final merged plugins for the agent
* For builtin agents: runtime plugins (if any) or fallback to agent config plugins
* For regular agents: agent config plugins
*/
plugins: string[];
/** The agent's slug (if builtin) */
slug?: string;
}
/**
* Resolves the agent config, merging runtime config for builtin agents
*
* For builtin agents (identified by slug), this will:
* 1. Get the base config from the agent store
* 2. Get the runtime config from @lobechat/builtin-agents
* 3. Merge the runtime systemRole into the agent config
*
* For regular agents, this simply returns the config from the store.
*/
export const resolveAgentConfig = (ctx: AgentConfigResolverContext): ResolvedAgentConfig => {
const { agentId, model, documentContent, plugins, targetAgentConfig } = ctx;
const agentStoreState = getAgentStoreState();
// Get base config from store
const agentConfig = agentSelectors.getAgentConfigById(agentId)(agentStoreState);
const chatConfig = chatConfigByIdSelectors.getChatConfigById(agentId)(agentStoreState);
// Base plugins from agent config
const basePlugins = agentConfig.plugins ?? [];
// Check if this is a builtin agent
const slug = agentSelectors.getAgentSlugById(agentId)(agentStoreState);
if (!slug) {
// Regular agent - use provided plugins if available, fallback to agent's plugins
const finalPlugins = plugins && plugins.length > 0 ? plugins : basePlugins;
// Apply params adjustments based on chatConfig
const finalAgentConfig = applyParamsFromChatConfig(agentConfig, chatConfig);
return {
agentConfig: finalAgentConfig,
chatConfig,
isBuiltinAgent: false,
plugins: finalPlugins,
};
}
// Build groupSupervisorContext if this is a group-supervisor agent
let groupSupervisorContext;
if (slug === BUILTIN_AGENT_SLUGS.groupSupervisor) {
const groupStoreState = getChatGroupStoreState();
// Find the group by supervisor agent ID
const group = agentGroupSelectors.getGroupBySupervisorAgentId(agentId)(groupStoreState);
if (group) {
const groupMembers = agentGroupSelectors.getGroupMembers(group.id)(groupStoreState);
groupSupervisorContext = {
availableAgents: groupMembers.map((agent) => ({ id: agent.id, title: agent.title })),
groupId: group.id,
groupTitle: group.title || 'Group Chat',
systemPrompt: group.config?.systemPrompt,
};
}
}
// Builtin agent - merge runtime config
const runtimeConfig = getAgentRuntimeConfig(slug, {
documentContent,
groupSupervisorContext,
model,
plugins,
targetAgentConfig,
});
// Merge runtime systemRole into agent config
const resolvedAgentConfig: LobeAgentConfig = {
...agentConfig,
// Use runtime systemRole if available, otherwise fallback to stored systemRole
systemRole: runtimeConfig?.systemRole ?? agentConfig.systemRole,
};
// Merge plugins: runtime plugins take priority, fallback to base plugins
const finalPlugins =
runtimeConfig?.plugins && runtimeConfig.plugins.length > 0
? runtimeConfig.plugins
: basePlugins;
// Merge chatConfig: runtime chatConfig overrides base chatConfig
const resolvedChatConfig: LobeAgentChatConfig = {
...chatConfig,
...runtimeConfig?.chatConfig,
};
// Apply params adjustments based on chatConfig
const finalAgentConfig = applyParamsFromChatConfig(resolvedAgentConfig, resolvedChatConfig);
return {
agentConfig: finalAgentConfig,
chatConfig: resolvedChatConfig,
isBuiltinAgent: true,
plugins: finalPlugins,
slug,
};
};
/**
* Get the target agent ID, falling back to active agent if not provided
*/
export const getTargetAgentId = (agentId?: string): string => {
const agentStoreState = getAgentStoreState();
return agentId || agentStoreState.activeAgentId || '';
};
@@ -1,6 +1,6 @@
import { ModelRuntime } from '@lobechat/model-runtime';
import { createPayloadWithKeyVaults } from '../_auth';
import { createPayloadWithKeyVaults } from '../../_auth';
export interface InitializeWithClientStoreOptions {
payload?: any;
@@ -3,8 +3,9 @@ import { afterEach, describe, expect, it, vi } from 'vitest';
import * as isCanUseFCModule from '@/helpers/isCanUseFC';
import * as helpers from '../helper';
import { contextEngineering } from './contextEngineering';
import * as helpers from './helper';
import type { UserMemoriesResult } from './memoryManager';
// Mock VARIABLE_GENERATORS
vi.mock('@/utils/client/parserPlaceholder', () => ({
@@ -434,6 +435,66 @@ describe('contextEngineering', () => {
expect(content[1].image_url.url).toBe('data:image/png;base64,abc123');
});
it('should merge custom memory placeholder variables', async () => {
const messages: UIChatMessage[] = [
{
role: 'system',
content:
'Memory load: available={{memory_available}}, total contexts={{memory_contexts_count}}\n{{memory_summary}}',
createdAt: Date.now(),
id: 'memory-placeholder-test',
meta: {},
updatedAt: Date.now(),
},
];
const userMemories: UserMemoriesResult = {
fetchedAt: Date.now(),
memories: {
contexts: [
{
accessedAt: new Date('2024-01-01T00:00:00.000Z'),
associatedObjects: [],
associatedSubjects: [],
createdAt: new Date('2024-01-01T00:00:00.000Z'),
currentStatus: 'active',
description: 'Weekly syncs for LobeHub',
id: 'ctx-1',
metadata: {},
scoreImpact: 0.8,
scoreUrgency: 0.5,
tags: ['project'],
title: 'LobeHub',
type: 'project',
updatedAt: new Date('2024-01-02T00:00:00.000Z'),
userMemoryIds: ['mem-1'],
},
],
experiences: [],
preferences: [],
},
};
const result = await contextEngineering({
userMemories,
messages,
model: 'gpt-4',
provider: 'openai',
});
expect(result[0].role).toBe('system');
expect(result[0].content).toContain(
'<user_memories contexts="1" experiences="0" memory_fetched_at="',
);
expect(result[0].content).toContain(
'" preferences="0"><user_memories_context id="ctx-1"><context_title>LobeHub</context_title><context_description>Weekly syncs for LobeHub</context_description></user_memories_context></user_memories>',
);
expect(result[0].content).toContain('<context_title>LobeHub</context_title>');
expect(result[1].content).toBe(
'Memory load: available={{memory_available}}, total contexts={{memory_contexts_count}}\n{{memory_summary}}',
);
});
it('should handle missing placeholder variables gracefully', async () => {
const messages: UIChatMessage[] = [
{
@@ -0,0 +1,283 @@
import { AgentBuilderIdentifier } from '@lobechat/builtin-tool-agent-builder';
import { GroupAgentBuilderIdentifier } from '@lobechat/builtin-tool-group-agent-builder';
import { KLAVIS_SERVER_TYPES, isDesktop } from '@lobechat/const';
import {
AgentBuilderContext,
AgentGroupConfig,
GroupAgentBuilderContext,
GroupOfficialToolItem,
MessagesEngine,
} from '@lobechat/context-engine';
import { historySummaryPrompt } from '@lobechat/prompts';
import { OpenAIChatMessage, UIChatMessage } from '@lobechat/types';
import { VARIABLE_GENERATORS } from '@lobechat/utils/client';
import debug from 'debug';
import { isCanUseFC } from '@/helpers/isCanUseFC';
import { getAgentStoreState } from '@/store/agent';
import { agentSelectors } from '@/store/agent/selectors';
import { getChatGroupStoreState } from '@/store/agentGroup';
import { agentGroupSelectors } from '@/store/agentGroup/selectors';
import { getChatStoreState } from '@/store/chat';
import { getToolStoreState } from '@/store/tool';
import { builtinToolSelectors, klavisStoreSelectors, toolSelectors } from '@/store/tool/selectors';
import { isCanUseVideo, isCanUseVision } from '../helper';
import type { UserMemoriesResult } from './memoryManager';
const log = debug('context-engine:contextEngineering');
interface ContextEngineeringContext {
/** Agent Builder context for injecting current agent info */
agentBuilderContext?: AgentBuilderContext;
/** The agent ID that will respond (for group context injection) */
agentId?: string;
enableHistoryCount?: boolean;
/** Group ID for multi-agent scenarios */
groupId?: string;
historyCount?: number;
historySummary?: string;
inputTemplate?: string;
messages: UIChatMessage[];
model: string;
provider: string;
sessionId?: string;
systemRole?: string;
tools?: string[];
userMemories?: UserMemoriesResult;
}
// REVIEW:可能这里可以约束一下 identitypreferenceexp 的 重新排序或者裁切过的上下文进来而不是全部丢进来
export const contextEngineering = async ({
messages = [],
tools,
model,
provider,
systemRole,
inputTemplate,
userMemories,
enableHistoryCount,
historyCount,
historySummary,
agentBuilderContext,
agentId,
groupId,
}: ContextEngineeringContext): Promise<OpenAIChatMessage[]> => {
log('tools: %o', tools);
// Check if Agent Builder tool is enabled
const isAgentBuilderEnabled = tools?.includes(AgentBuilderIdentifier) ?? false;
// Check if Group Agent Builder tool is enabled
const isGroupAgentBuilderEnabled = tools?.includes(GroupAgentBuilderIdentifier) ?? false;
log('isAgentBuilderEnabled: %s', isAgentBuilderEnabled);
log('isGroupAgentBuilderEnabled: %s', isGroupAgentBuilderEnabled);
// Build agent group configuration if groupId is provided
let agentGroup: AgentGroupConfig | undefined;
if (groupId) {
const groupStoreState = getChatGroupStoreState();
const groupDetail = agentGroupSelectors.getGroupById(groupId)(groupStoreState);
if (groupDetail?.agents && groupDetail.agents.length > 0) {
const agentMap: AgentGroupConfig['agentMap'] = {};
const members: AgentGroupConfig['members'] = [];
// Find the responding agent to get its name and role
let currentAgentName: string | undefined;
let currentAgentRole: 'supervisor' | 'participant' | undefined;
for (const agent of groupDetail.agents) {
const role = agent.isSupervisor ? 'supervisor' : 'participant';
const name = agent.title || 'Untitled Agent';
agentMap[agent.id] = { name, role };
members.push({ id: agent.id, name, role });
// Capture responding agent info
if (agentId && agent.id === agentId) {
currentAgentName = name;
currentAgentRole = role;
}
}
agentGroup = {
agentMap,
currentAgentId: agentId,
currentAgentName,
currentAgentRole,
groupTitle: groupDetail.title || undefined,
members,
systemPrompt: groupDetail.config?.systemPrompt || undefined,
};
log('agentGroup built: %o', agentGroup);
}
}
// Get agent store state (used for both group agent builder context and file/knowledge base)
const agentStoreState = getAgentStoreState();
// Build group agent builder context if Group Agent Builder is enabled
// Note: Uses activeGroupId from chatStore to get the group being edited
let groupAgentBuilderContext: GroupAgentBuilderContext | undefined;
if (isGroupAgentBuilderEnabled) {
const activeGroupId = getChatStoreState().activeGroupId;
if (activeGroupId) {
const groupStoreState = getChatGroupStoreState();
const activeGroupDetail = agentGroupSelectors.getGroupById(activeGroupId)(groupStoreState);
if (activeGroupDetail) {
// Get supervisor agent config if supervisorAgentId exists
let supervisorConfig: GroupAgentBuilderContext['supervisorConfig'];
let enabledPlugins: string[] = [];
if (activeGroupDetail.supervisorAgentId) {
const supervisorAgentConfig = agentSelectors.getAgentConfigById(
activeGroupDetail.supervisorAgentId,
)(agentStoreState);
supervisorConfig = {
model: supervisorAgentConfig.model,
plugins: supervisorAgentConfig.plugins,
provider: supervisorAgentConfig.provider,
};
enabledPlugins = supervisorAgentConfig.plugins || [];
}
// Build official tools list (builtin tools + Klavis tools)
const toolState = getToolStoreState();
const officialTools: GroupOfficialToolItem[] = [];
// Get builtin tools (excluding Klavis tools)
const builtinTools = builtinToolSelectors.metaList(toolState);
const klavisIdentifiers = new Set(KLAVIS_SERVER_TYPES.map((t) => t.identifier));
for (const tool of builtinTools) {
// Skip Klavis tools in builtin list (they'll be shown separately)
if (klavisIdentifiers.has(tool.identifier)) continue;
officialTools.push({
description: tool.meta?.description,
enabled: enabledPlugins.includes(tool.identifier),
identifier: tool.identifier,
installed: true,
name: tool.meta?.title || tool.identifier,
type: 'builtin',
});
}
// Get Klavis tools (if enabled)
const isKlavisEnabled =
typeof window !== 'undefined' &&
window.global_serverConfigStore?.getState()?.serverConfig?.enableKlavis;
if (isKlavisEnabled) {
const allKlavisServers = klavisStoreSelectors.getServers(toolState);
for (const klavisType of KLAVIS_SERVER_TYPES) {
const server = allKlavisServers.find((s) => s.identifier === klavisType.identifier);
officialTools.push({
description: `LobeHub Mcp Server: ${klavisType.label}`,
enabled: enabledPlugins.includes(klavisType.identifier),
identifier: klavisType.identifier,
installed: !!server,
name: klavisType.label,
type: 'klavis',
});
}
}
groupAgentBuilderContext = {
config: {
openingMessage: activeGroupDetail.config?.openingMessage || undefined,
openingQuestions: activeGroupDetail.config?.openingQuestions,
systemPrompt: activeGroupDetail.config?.systemPrompt || undefined,
},
groupId: activeGroupId,
groupTitle: activeGroupDetail.title || undefined,
members: activeGroupDetail.agents?.map((agent) => ({
description: agent.description || undefined,
id: agent.id,
isSupervisor: agent.isSupervisor,
title: agent.title || 'Untitled Agent',
})),
officialTools,
supervisorConfig,
};
log('groupAgentBuilderContext built from activeGroupId: %o', groupAgentBuilderContext);
}
}
}
// Get enabled agent files with content and knowledge bases from agent store
const agentFiles = agentSelectors.currentAgentFiles(agentStoreState);
const agentKnowledgeBases = agentSelectors.currentAgentKnowledgeBases(agentStoreState);
const fileContents = agentFiles
.filter((file) => file.enabled && file.content)
.map((file) => ({ content: file.content!, fileId: file.id, filename: file.name }));
const knowledgeBases = agentKnowledgeBases
.filter((kb) => kb.enabled)
.map((kb) => ({ description: kb.description, id: kb.id, name: kb.name }));
// Create MessagesEngine with injected dependencies
/* eslint-disable sort-keys-fix/sort-keys-fix */
const engine = new MessagesEngine({
// Agent configuration
enableHistoryCount,
formatHistorySummary: historySummaryPrompt,
historyCount,
historySummary,
inputTemplate,
systemRole,
// Capability injection
capabilities: {
isCanUseFC,
isCanUseVideo,
isCanUseVision,
},
// File context configuration
fileContext: { enabled: true, includeFileUrl: !isDesktop },
// Knowledge injection
knowledge: {
fileContents,
knowledgeBases,
},
// Messages
messages,
// Model info
model,
provider,
// Tools configuration
toolsConfig: {
getToolSystemRoles: (tools) => toolSelectors.enabledSystemRoles(tools)(getToolStoreState()),
tools,
},
// User memory configuration
userMemory: userMemories
? {
enabled: !!userMemories.memories,
fetchedAt: userMemories.fetchedAt,
memories: userMemories.memories,
}
: undefined,
// Variable generators
variableGenerators: VARIABLE_GENERATORS,
// Extended contexts - only pass when enabled
...(isAgentBuilderEnabled && { agentBuilderContext }),
...(isGroupAgentBuilderEnabled && { groupAgentBuilderContext }),
...(agentGroup && { agentGroup }),
});
const result = await engine.process();
return result.messages;
};
+25
View File
@@ -0,0 +1,25 @@
/**
* Mecha - Core AI execution module
*
* This module provides the core functionality for AI agent execution,
* including agent configuration resolution, context engineering,
* model parameter handling, and memory management.
*/
// Agent configuration
export type { AgentConfigResolverContext, ResolvedAgentConfig } from './agentConfigResolver';
export { getTargetAgentId, resolveAgentConfig } from './agentConfigResolver';
// Context engineering
export { contextEngineering } from './contextEngineering';
// Client model runtime
export { initializeWithClientStore } from './clientModelRuntime';
// Model parameters
export type { ModelExtendParams, ModelParamsContext } from './modelParamsResolver';
export { resolveModelExtendParams } from './modelParamsResolver';
// Memory management
export type { MemoryResolverContext, UserMemoriesResult } from './memoryManager';
export { resolveUserMemories } from './memoryManager';
+128
View File
@@ -0,0 +1,128 @@
import { UIChatMessage } from '@lobechat/types';
import type { RetrieveMemoryResult } from '@lobechat/types';
import { userMemoryService } from '@/services/userMemory';
import { getChatStoreState } from '@/store/chat';
import { topicSelectors } from '@/store/chat/selectors';
import { getSessionStoreState } from '@/store/session';
import { sessionSelectors } from '@/store/session/selectors';
import {
getUserMemoryStoreState,
useUserMemoryStore,
userMemorySelectors,
} from '@/store/userMemory';
import { userMemoryCacheKey } from '@/store/userMemory/utils/cacheKey';
import { createMemorySearchParams } from '@/store/userMemory/utils/searchParams';
/**
* User memories with fetch metadata
*/
export interface UserMemoriesResult {
fetchedAt: number;
memories: RetrieveMemoryResult;
}
/**
* Context for resolving user memories
*/
export interface MemoryResolverContext {
/** Whether memory plugin is enabled */
isMemoryPluginEnabled: boolean;
/** Chat messages for context extraction */
messages: UIChatMessage[];
}
const EMPTY_MEMORIES: RetrieveMemoryResult = {
contexts: [],
experiences: [],
preferences: [],
};
/**
* Resolves user memories for context injection
*
* This function handles:
* 1. Checking if memories are already cached
* 2. Building memory context from messages and session
* 3. Fetching memories from the service if not cached
* 4. Caching the fetched memories for future use
*/
export const resolveUserMemories = async (
ctx: MemoryResolverContext,
): Promise<UserMemoriesResult | undefined> => {
const { isMemoryPluginEnabled, messages } = ctx;
// Check if already have cached memories
let userMemories =
userMemorySelectors.activeUserMemories(isMemoryPluginEnabled)(getUserMemoryStoreState());
if (userMemories) {
return userMemories;
}
// If memory plugin not enabled, return undefined
if (!isMemoryPluginEnabled) {
return undefined;
}
// Build memory context from messages and session
const chatStoreState = getChatStoreState();
const sessionStoreState = getSessionStoreState();
const historyMessages = messages.slice(0, -1);
const latestHistoryMessage = [...historyMessages]
.reverse()
.find((item) => typeof item?.content === 'string' && item.content.trim().length > 0);
const pendingMessage = messages.at(-1);
const memoryContext = {
latestMessageContent: latestHistoryMessage?.content,
pendingMessageContent: pendingMessage?.content,
session: sessionSelectors.currentSession(sessionStoreState),
topic: topicSelectors.currentActiveTopic(chatStoreState),
};
// Set active memory context in store
useUserMemoryStore.getState().setActiveMemoryContext(memoryContext);
const updatedMemoryState = getUserMemoryStoreState();
const memoryParams = updatedMemoryState.activeParams ?? createMemorySearchParams(memoryContext);
if (!memoryParams) {
return undefined;
}
const key = userMemoryCacheKey(memoryParams);
const cachedMemories = updatedMemoryState.memoryMap[key];
// Return cached memories if available
if (cachedMemories) {
const cachedAt = updatedMemoryState.memoryFetchedAtMap[key] ?? Date.now();
return {
fetchedAt: cachedAt,
memories: cachedMemories,
};
}
// Fetch memories from service
const result = await userMemoryService.retrieveMemory(memoryParams);
const memories = result ?? EMPTY_MEMORIES;
const fetchedAt = Date.now();
// Cache the fetched memories
useUserMemoryStore.setState((state) => ({
memoryFetchedAtMap: {
...state.memoryFetchedAtMap,
[key]: fetchedAt,
},
memoryMap: {
...state.memoryMap,
[key]: memories,
},
}));
return {
fetchedAt,
memories,
};
};
@@ -0,0 +1,131 @@
import { LobeAgentChatConfig } from '@lobechat/types';
import { aiModelSelectors, getAiInfraStoreState } from '@/store/aiInfra';
/**
* Context for resolving model parameters
*/
export interface ModelParamsContext {
chatConfig: LobeAgentChatConfig;
model: string;
provider: string;
}
/**
* Extended parameters for model runtime
*/
export interface ModelExtendParams {
enabledContextCaching?: boolean;
imageAspectRatio?: string;
imageResolution?: string;
reasoning_effort?: string;
thinking?: {
budget_tokens?: number;
type: string;
};
thinkingBudget?: number;
thinkingLevel?: string;
urlContext?: boolean;
verbosity?: string;
}
/**
* Resolves extended parameters for model runtime based on model capabilities and chat config
*
* This function checks what extended parameters the model supports and applies
* the corresponding values from chat config.
*/
export const resolveModelExtendParams = (ctx: ModelParamsContext): ModelExtendParams => {
const { model, provider, chatConfig } = ctx;
const extendParams: ModelExtendParams = {};
const aiInfraStoreState = getAiInfraStoreState();
const isModelHasExtendParams = aiModelSelectors.isModelHasExtendParams(
model,
provider,
)(aiInfraStoreState);
if (!isModelHasExtendParams) {
return extendParams;
}
const modelExtendParams = aiModelSelectors.modelExtendParams(model, provider)(aiInfraStoreState);
if (!modelExtendParams) {
return extendParams;
}
// Reasoning configuration
if (modelExtendParams.includes('enableReasoning')) {
if (chatConfig.enableReasoning) {
extendParams.thinking = {
budget_tokens: chatConfig.reasoningBudgetToken || 1024,
type: 'enabled',
};
} else {
extendParams.thinking = {
budget_tokens: 0,
type: 'disabled',
};
}
} else if (modelExtendParams.includes('reasoningBudgetToken')) {
// For models that only have reasoningBudgetToken without enableReasoning
extendParams.thinking = {
budget_tokens: chatConfig.reasoningBudgetToken || 1024,
type: 'enabled',
};
}
// Context caching
if (modelExtendParams.includes('disableContextCaching') && chatConfig.disableContextCaching) {
extendParams.enabledContextCaching = false;
}
// Reasoning effort variants
if (modelExtendParams.includes('reasoningEffort') && chatConfig.reasoningEffort) {
extendParams.reasoning_effort = chatConfig.reasoningEffort;
}
if (modelExtendParams.includes('gpt5ReasoningEffort') && chatConfig.gpt5ReasoningEffort) {
extendParams.reasoning_effort = chatConfig.gpt5ReasoningEffort;
}
if (modelExtendParams.includes('gpt5_1ReasoningEffort') && chatConfig.gpt5_1ReasoningEffort) {
extendParams.reasoning_effort = chatConfig.gpt5_1ReasoningEffort;
}
// Text verbosity
if (modelExtendParams.includes('textVerbosity') && chatConfig.textVerbosity) {
extendParams.verbosity = chatConfig.textVerbosity;
}
// Thinking configuration
if (modelExtendParams.includes('thinking') && chatConfig.thinking) {
extendParams.thinking = { type: chatConfig.thinking };
}
if (modelExtendParams.includes('thinkingBudget') && chatConfig.thinkingBudget !== undefined) {
extendParams.thinkingBudget = chatConfig.thinkingBudget;
}
if (modelExtendParams.includes('thinkingLevel') && chatConfig.thinkingLevel) {
extendParams.thinkingLevel = chatConfig.thinkingLevel;
}
// URL context
if (modelExtendParams.includes('urlContext') && chatConfig.urlContext) {
extendParams.urlContext = chatConfig.urlContext;
}
// Image generation params
if (modelExtendParams.includes('imageAspectRatio') && chatConfig.imageAspectRatio) {
extendParams.imageAspectRatio = chatConfig.imageAspectRatio;
}
if (modelExtendParams.includes('imageResolution') && chatConfig.imageResolution) {
extendParams.imageResolution = chatConfig.imageResolution;
}
return extendParams;
};
+48 -3
View File
@@ -1,3 +1,5 @@
import { AgentGroupDetail, AgentItem } from '@lobechat/types';
import {
ChatGroupAgentItem,
ChatGroupItem,
@@ -6,14 +8,50 @@ import {
} from '@/database/schemas';
import { lambdaClient } from '@/libs/trpc/client';
export interface GroupMemberConfig {
avatar?: string;
backgroundColor?: string;
description?: string;
model?: string;
plugins?: string[];
provider?: string;
systemRole?: string;
tags?: string[];
title?: string;
}
class ChatGroupService {
createGroup = (params: Omit<NewChatGroup, 'userId'>): Promise<ChatGroupItem> => {
/**
* Create a group with a supervisor agent.
* The supervisor agent is automatically created as a virtual agent.
*/
createGroup = (
params: Omit<NewChatGroup, 'userId'>,
): Promise<{ group: ChatGroupItem; supervisorAgentId: string }> => {
return lambdaClient.group.createGroup.mutate({
...params,
config: params.config as any,
});
};
/**
* Create a group with virtual member agents in one request.
* This is the recommended way to create a group from a template.
* Returns groupId, supervisorAgentId, and member agentIds.
*/
createGroupWithMembers = (
groupConfig: Omit<NewChatGroup, 'userId'>,
members: GroupMemberConfig[],
): Promise<{ agentIds: string[]; groupId: string; supervisorAgentId: string }> => {
return lambdaClient.group.createGroupWithMembers.mutate({
groupConfig: {
...groupConfig,
config: groupConfig.config as any,
},
members: members as Partial<AgentItem>[],
});
};
updateGroup = (id: string, value: Partial<ChatGroupItem>): Promise<ChatGroupItem> => {
return lambdaClient.group.updateGroup.mutate({
id,
@@ -32,11 +70,18 @@ class ChatGroupService {
return lambdaClient.group.getGroup.query({ id });
};
getGroupDetail = (id: string): Promise<AgentGroupDetail | null> => {
return lambdaClient.group.getGroupDetail.query({ id });
};
getGroups = (): Promise<ChatGroupItem[]> => {
return lambdaClient.group.getGroups.query();
};
addAgentsToGroup = (groupId: string, agentIds: string[]): Promise<ChatGroupAgentItem[]> => {
addAgentsToGroup = (
groupId: string,
agentIds: string[],
): Promise<{ added: NewChatGroupAgent[]; existing: string[] }> => {
return lambdaClient.group.addAgentsToGroup.mutate({ agentIds, groupId });
};
@@ -48,7 +93,7 @@ class ChatGroupService {
groupId: string,
agentId: string,
updates: Partial<Pick<NewChatGroupAgent, 'order' | 'role'>>,
): Promise<ChatGroupAgentItem> => {
): Promise<NewChatGroupAgent> => {
return lambdaClient.group.updateAgentInGroup.mutate({
agentId,
groupId,
+20
View File
@@ -0,0 +1,20 @@
import type { SidebarAgentItem, SidebarAgentListResponse } from '@/database/repositories/home';
import { lambdaClient } from '@/libs/trpc/client';
export class HomeService {
/**
* Get sidebar agent list with pinned, grouped, and ungrouped items
*/
getSidebarAgentList = (): Promise<SidebarAgentListResponse> => {
return lambdaClient.home.getSidebarAgentList.query();
};
/**
* Search agents by keyword
*/
searchAgents = (keyword: string): Promise<SidebarAgentItem[]> => {
return lambdaClient.home.searchAgents.query({ keyword });
};
}
export const homeService = new HomeService();
+79 -98
View File
@@ -15,44 +15,32 @@ import {
} from '@lobechat/types';
import type { HeatmapsProps } from '@lobehub/charts';
import { INBOX_SESSION_ID } from '@/const/session';
import { lambdaClient } from '@/libs/trpc/client';
import { abortableRequest } from '../utils/abortableRequest';
/**
* Query context for message operations
* Contains identifiers needed for querying/filtering messages after mutations
*/
export interface MessageQueryContext {
agentId?: string;
groupId?: string;
threadId?: string | null;
topicId?: string | null;
}
export class MessageService {
createMessage = async ({
sessionId,
...params
}: CreateMessageParams): Promise<CreateMessageResult> => {
return lambdaClient.message.createMessage.mutate({
...params,
sessionId: sessionId ? this.toDbSessionId(sessionId) : undefined,
});
createMessage = async (params: CreateMessageParams): Promise<CreateMessageResult> => {
return lambdaClient.message.createMessage.mutate(params as any);
};
getMessages = async (
sessionId: string,
topicId?: string,
groupId?: string,
): Promise<UIChatMessage[]> => {
const data = await lambdaClient.message.getMessages.query({
groupId,
sessionId: this.toDbSessionId(sessionId),
topicId,
});
getMessages = async (params: MessageQueryContext): Promise<UIChatMessage[]> => {
const data = await lambdaClient.message.getMessages.query(params);
return data as unknown as UIChatMessage[];
};
getGroupMessages = async (groupId: string, topicId?: string): Promise<UIChatMessage[]> => {
const data = await lambdaClient.message.getMessages.query({
groupId,
topicId,
});
return data as unknown as UIChatMessage[];
};
countMessages = async (params?: {
endDate?: string;
range?: [string, string];
@@ -77,19 +65,14 @@ export class MessageService {
return lambdaClient.message.getHeatmaps.query();
};
updateMessageError = async (
id: string,
value: ChatMessageError,
options?: { sessionId?: string | null; topicId?: string | null },
) => {
updateMessageError = async (id: string, value: ChatMessageError, ctx?: MessageQueryContext) => {
const error = value.type
? value
: { body: value, message: value.message, type: 'ApplicationRuntimeError' };
return lambdaClient.message.update.mutate({
...ctx,
id,
sessionId: options?.sessionId,
topicId: options?.topicId,
value: { error },
});
};
@@ -99,15 +82,30 @@ export class MessageService {
return lambdaClient.message.updateMessagePlugin.mutate({ id, value: { arguments: args } });
};
/**
* Update tool arguments by toolCallId - updates both tool message and parent assistant message in one transaction
* This is the preferred method for updating tool arguments as it prevents race conditions
*
* @param toolCallId - The tool call ID (stable identifier from AI response)
* @param value - The new arguments value
* @param ctx - Message query context
*/
updateToolArguments = async (
toolCallId: string,
value: string | Record<string, unknown>,
ctx?: MessageQueryContext,
) => {
return lambdaClient.message.updateToolArguments.mutate({ ...ctx, toolCallId, value });
};
updateMessage = async (
id: string,
value: Partial<UpdateMessageParams>,
options?: { sessionId?: string | null; topicId?: string | null },
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return lambdaClient.message.update.mutate({
...ctx,
id,
sessionId: options?.sessionId,
topicId: options?.topicId,
value,
});
};
@@ -123,115 +121,98 @@ export class MessageService {
updateMessageMetadata = async (
id: string,
value: Partial<MessageMetadata>,
options?: { sessionId?: string | null; topicId?: string | null },
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return abortableRequest.execute(`message-metadata-${id}`, (signal) =>
lambdaClient.message.updateMetadata.mutate(
{
id,
sessionId: options?.sessionId,
topicId: options?.topicId,
value,
},
{ signal },
),
lambdaClient.message.updateMetadata.mutate({ ...ctx, id, value }, { signal }),
);
};
updateMessagePluginState = async (
id: string,
value: Record<string, any>,
options?: { sessionId?: string | null; topicId?: string | null },
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return lambdaClient.message.updatePluginState.mutate({
id,
sessionId: options?.sessionId,
topicId: options?.topicId,
value,
});
return lambdaClient.message.updatePluginState.mutate({ ...ctx, id, value });
};
updateMessagePluginError = async (
id: string,
error: ChatMessagePluginError | null,
options?: { sessionId?: string | null; topicId?: string | null },
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return lambdaClient.message.updatePluginError.mutate({
id,
sessionId: options?.sessionId,
topicId: options?.topicId,
value: error as any,
});
return lambdaClient.message.updatePluginError.mutate({ ...ctx, id, value: error as any });
};
updateMessagePlugin = async (
id: string,
value: Partial<Omit<MessagePluginItem, 'id'>>,
options?: { sessionId?: string | null; topicId?: string | null },
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return lambdaClient.message.updateMessagePlugin.mutate({
id,
sessionId: options?.sessionId,
topicId: options?.topicId,
value,
});
return lambdaClient.message.updateMessagePlugin.mutate({ ...ctx, id, value });
};
updateMessageRAG = async (
id: string,
data: UpdateMessageRAGParams,
options?: { sessionId?: string | null; topicId?: string | null },
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return lambdaClient.message.updateMessageRAG.mutate({
id,
sessionId: options?.sessionId,
topicId: options?.topicId,
value: data,
});
return lambdaClient.message.updateMessageRAG.mutate({ ...ctx, id, value: data });
};
removeMessage = async (
/**
* Update tool message with content, metadata, pluginState, and pluginError in a single request
* This prevents race conditions when updating multiple fields
* Uses abortableRequest to cancel previous requests for the same message
*/
updateToolMessage = async (
id: string,
options?: { sessionId?: string | null; topicId?: string | null },
value: {
content?: string;
metadata?: Record<string, any>;
pluginError?: any;
pluginState?: Record<string, any>;
},
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return lambdaClient.message.removeMessage.mutate({
id,
sessionId: options?.sessionId,
topicId: options?.topicId,
});
return abortableRequest.execute(`tool-message-${id}`, (signal) =>
lambdaClient.message.updateToolMessage.mutate({ ...ctx, id, value }, { signal }),
);
};
removeMessage = async (id: string, ctx?: MessageQueryContext): Promise<UpdateMessageResult> => {
return lambdaClient.message.removeMessage.mutate({ ...ctx, id });
};
removeMessages = async (
ids: string[],
options?: { sessionId?: string | null; topicId?: string | null },
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return lambdaClient.message.removeMessages.mutate({
ids,
sessionId: options?.sessionId,
topicId: options?.topicId,
});
return lambdaClient.message.removeMessages.mutate({ ...ctx, ids });
};
removeMessagesByAssistant = async (sessionId: string, topicId?: string) => {
return lambdaClient.message.removeMessagesByAssistant.mutate({
sessionId: this.toDbSessionId(sessionId),
topicId,
});
return lambdaClient.message.removeMessagesByAssistant.mutate({ sessionId, topicId });
};
removeMessagesByGroup = async (groupId: string, topicId?: string) => {
return lambdaClient.message.removeMessagesByGroup.mutate({
groupId,
topicId,
});
return lambdaClient.message.removeMessagesByGroup.mutate({ groupId, topicId });
};
removeAllMessages = async () => {
return lambdaClient.message.removeAllMessages.mutate();
};
private toDbSessionId = (sessionId: string | undefined) => {
return sessionId === INBOX_SESSION_ID ? null : sessionId;
/**
* Add files to a message
* Used to associate exported files from code interpreter with the tool message
*/
addFilesToMessage = async (
id: string,
fileIds: string[],
ctx?: MessageQueryContext,
): Promise<UpdateMessageResult> => {
return lambdaClient.message.addFilesToMessage.mutate({ ...ctx, fileIds, id });
};
}
+60 -26
View File
@@ -1,44 +1,78 @@
import { describe, expect, it } from 'vitest';
import { describe, expect, it, vi } from 'vitest';
import { INBOX_SESSION_ID } from '@/const/session';
import { lambdaClient } from '@/libs/trpc/client';
import { MessageService } from './index';
vi.mock('@/libs/trpc/client', () => ({
lambdaClient: {
message: {
createMessage: { mutate: vi.fn() },
getMessages: { query: vi.fn() },
removeMessagesByAssistant: { mutate: vi.fn() },
},
},
}));
describe('MessageService', () => {
describe('toDbSessionId', () => {
describe('createMessage', () => {
const service = new MessageService();
// @ts-ignore access private method for testing
const toDbSessionId = service.toDbSessionId;
it('should return null for INBOX_SESSION_ID', () => {
expect(toDbSessionId(INBOX_SESSION_ID)).toBeNull();
afterEach(() => {
vi.clearAllMocks();
});
it('should return the same session id for non-inbox sessions', () => {
const sessionId = 'test-session-123';
expect(toDbSessionId(sessionId)).toBe(sessionId);
it('should pass params directly to lambdaClient', async () => {
vi.mocked(lambdaClient.message.createMessage.mutate).mockResolvedValue({
id: 'msg-1',
messages: [],
});
await service.createMessage({
content: 'test',
role: 'user',
agentId: 'agent-123',
});
expect(lambdaClient.message.createMessage.mutate).toHaveBeenCalledWith({
content: 'test',
role: 'user',
agentId: 'agent-123',
});
});
});
describe('removeMessagesByAssistant', () => {
const service = new MessageService();
afterEach(() => {
vi.clearAllMocks();
});
it('should handle undefined input', () => {
expect(toDbSessionId(undefined)).toBeUndefined(); // Updated to match the actual behavior
it('should pass sessionId to lambdaClient', async () => {
vi.mocked(lambdaClient.message.removeMessagesByAssistant.mutate).mockResolvedValue(
undefined as any,
);
await service.removeMessagesByAssistant('session-123');
expect(lambdaClient.message.removeMessagesByAssistant.mutate).toHaveBeenCalledWith({
sessionId: 'session-123',
topicId: undefined,
});
});
it('should handle empty string input', () => {
expect(toDbSessionId('')).toBe(''); // No changes needed
});
it('should pass sessionId and topicId to lambdaClient', async () => {
vi.mocked(lambdaClient.message.removeMessagesByAssistant.mutate).mockResolvedValue(
undefined as any,
);
it('should handle special characters in session id', () => {
const specialSessionId = '!@#$%^&*()_+';
expect(toDbSessionId(specialSessionId)).toBe(specialSessionId);
});
await service.removeMessagesByAssistant('session-123', 'topic-1');
it('should handle numeric session id', () => {
const numericSessionId = '12345';
expect(toDbSessionId(numericSessionId)).toBe(numericSessionId);
});
it('should handle null session id', () => {
expect(toDbSessionId(null as any)).toBeNull(); // Cast null to any to bypass type errors
expect(lambdaClient.message.removeMessagesByAssistant.mutate).toHaveBeenCalledWith({
sessionId: 'session-123',
topicId: 'topic-1',
});
});
});
});
+1 -1
View File
@@ -5,8 +5,8 @@ import { aiProviderSelectors, getAiInfraStoreState } from '@/store/aiInfra';
import { ChatModelCard } from '@/types/llm';
import { API_ENDPOINTS } from './_url';
import { initializeWithClientStore } from './chat/clientModelRuntime';
import { resolveRuntimeProvider } from './chat/helper';
import { initializeWithClientStore } from './chat/mecha';
const isEnableFetchOnClient = (provider: string) =>
aiProviderSelectors.isProviderFetchOnClient(provider)(getAiInfraStoreState());
+21 -6
View File
@@ -1,7 +1,13 @@
import { INBOX_SESSION_ID } from '@/const/session';
import { lambdaClient } from '@/libs/trpc/client';
import { BatchTaskResult } from '@/types/service';
import { ChatTopic, CreateTopicParams, QueryTopicParams, TopicRankItem } from '@/types/topic';
import {
ChatTopic,
CreateTopicParams,
QueryTopicParams,
RecentTopic,
TopicRankItem,
} from '@/types/topic';
export class TopicService {
createTopic = (params: CreateTopicParams): Promise<string> => {
@@ -19,10 +25,13 @@ export class TopicService {
return lambdaClient.topic.cloneTopic.mutate({ id, newTitle });
};
getTopics = (params: QueryTopicParams): Promise<ChatTopic[]> => {
getTopics = async (params: QueryTopicParams): Promise<{ items: ChatTopic[]; total: number }> => {
return lambdaClient.topic.getTopics.query({
...params,
containerId: this.toDbSessionId(params.containerId),
agentId: params.agentId,
current: params.current,
groupId: params.groupId,
isInbox: params.isInbox,
pageSize: params.pageSize,
}) as any;
};
@@ -31,6 +40,8 @@ export class TopicService {
};
countTopics = async (params?: {
agentId?: string;
containerId?: string | null;
endDate?: string;
range?: [string, string];
startDate?: string;
@@ -42,11 +53,15 @@ export class TopicService {
return lambdaClient.topic.rankTopics.query(limit);
};
searchTopics = (keywords: string, sessionId?: string, groupId?: string): Promise<ChatTopic[]> => {
getRecentTopics = async (limit?: number): Promise<RecentTopic[]> => {
return lambdaClient.topic.recentTopics.query({ limit });
};
searchTopics = (keywords: string, agentId?: string, groupId?: string): Promise<ChatTopic[]> => {
return lambdaClient.topic.searchTopics.query({
agentId,
groupId,
keywords,
sessionId: this.toDbSessionId(sessionId),
}) as any;
};