🐛 fix: Custom provider fails when client requests are enabled (#9534)

*  fix: (启用客户端请求,自定义服务商未遵循指定请求格式) 更新 initializeWithClientStore 函数,支持通过选项对象传递 provider 和 payload,增强代码可读性

*  feat: 添加 runtimeProvider 支持,优化模型服务的提供者初始化逻辑

* add test
This commit is contained in:
sxjeru
2025-10-04 19:15:59 +08:00
committed by GitHub
parent ba3f67f7d4
commit 8b12fdfb82
11 changed files with 242 additions and 49 deletions
@@ -41,7 +41,7 @@ export class ModelRuntime {
*
* @example - Use without trace
* ```ts
* const agentRuntime = await initializeWithClientStore(provider, payload);
* const agentRuntime = await initializeWithClientStore({ provider, payload });
* const data = payload as ChatStreamPayload;
* return await agentRuntime.chat(data);
* ```
+2
View File
@@ -16,6 +16,8 @@ export interface ClientSecretPayload {
*/
baseURL?: string;
runtimeProvider?: string;
azureApiVersion?: string;
awsAccessKeyId?: string;
@@ -87,6 +87,19 @@ describe('initModelRuntimeWithUserPayload method', () => {
expect(runtime['_runtime'].baseURL).toBe(jwtPayload.baseURL);
});
it('Custom provider should use runtimeProvider to init', async () => {
const jwtPayload: ClientSecretPayload = {
apiKey: 'user-azure-key',
azureApiVersion: '2024-06-01',
baseURL: 'user-azure-endpoint',
runtimeProvider: ModelProvider.Azure,
};
const runtime = await initModelRuntimeWithUserPayload('custom-provider', jwtPayload);
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI);
expect(runtime['_runtime'].baseURL).toBe(jwtPayload.baseURL);
});
it('ZhiPu AI provider: with apikey', async () => {
const jwtPayload: ClientSecretPayload = { apiKey: 'zhipu.user-key' };
const runtime = await initModelRuntimeWithUserPayload(ModelProvider.ZhiPu, jwtPayload);
+4 -2
View File
@@ -128,8 +128,10 @@ export const initModelRuntimeWithUserPayload = (
payload: ClientSecretPayload,
params: any = {},
) => {
return ModelRuntime.initializeWithProvider(provider, {
...getParamsFromPayload(provider, payload),
const runtimeProvider = payload.runtimeProvider ?? provider;
return ModelRuntime.initializeWithProvider(runtimeProvider, {
...getParamsFromPayload(runtimeProvider, payload),
...params,
});
};
+95 -5
View File
@@ -1,21 +1,111 @@
import { Mock, describe, expect, it, vi } from 'vitest';
import { Mock, beforeEach, describe, expect, it, vi } from 'vitest';
import { aiProviderSelectors } from '@/store/aiInfra';
import { createHeaderWithAuth } from '../_auth';
import { initializeWithClientStore } from '../chat/clientModelRuntime';
import { resolveRuntimeProvider } from '../chat/helper';
import { ModelsService } from '../models';
vi.stubGlobal('fetch', vi.fn());
// 创建一个测试用的 ModelsService 实例
vi.mock('@/const/version', () => ({
isDeprecatedEdition: false,
}));
vi.mock('../_auth', () => ({
createHeaderWithAuth: vi.fn(async () => ({})),
}));
vi.mock('../chat/helper', () => ({
resolveRuntimeProvider: vi.fn((provider: string) => provider),
}));
vi.mock('../chat/clientModelRuntime', () => ({
initializeWithClientStore: vi.fn(),
}));
vi.mock('@/store/aiInfra', () => ({
aiProviderSelectors: {
isProviderFetchOnClient: () => () => false,
},
getAiInfraStoreState: () => ({}),
}));
vi.mock('@/store/user', () => ({
useUserStore: {
getState: vi.fn(),
},
}));
vi.mock('@/store/user/selectors', () => ({
modelConfigSelectors: {
isProviderFetchOnClient: () => () => false,
},
}));
// 创建一个测试用的 ModelsService 实例
const modelsService = new ModelsService();
const mockedCreateHeaderWithAuth = vi.mocked(createHeaderWithAuth);
const mockedResolveRuntimeProvider = vi.mocked(resolveRuntimeProvider);
const mockedInitializeWithClientStore = vi.mocked(initializeWithClientStore);
describe('ModelsService', () => {
beforeEach(() => {
(fetch as Mock).mockClear();
mockedCreateHeaderWithAuth.mockClear();
mockedResolveRuntimeProvider.mockReset();
mockedResolveRuntimeProvider.mockImplementation((provider: string) => provider);
mockedInitializeWithClientStore.mockClear();
});
describe('getModels', () => {
it('should call the appropriate endpoint for a generic provider', async () => {
(fetch as Mock).mockResolvedValueOnce(new Response(JSON.stringify({ models: [] })));
it('should call the endpoint for runtime provider when server fetching', async () => {
(fetch as Mock).mockResolvedValueOnce(
new Response(JSON.stringify({ models: [] }), { status: 200 }),
);
await modelsService.getModels('openai');
expect(fetch).toHaveBeenCalled();
expect(mockedResolveRuntimeProvider).toHaveBeenCalledWith('openai');
expect(fetch).toHaveBeenCalledWith('/webapi/models/openai', { headers: {} });
expect(mockedInitializeWithClientStore).not.toHaveBeenCalled();
});
it('should map custom provider to runtime provider endpoint', async () => {
mockedResolveRuntimeProvider.mockImplementation(() => 'openai');
(fetch as Mock).mockResolvedValueOnce(
new Response(JSON.stringify({ models: [] }), { status: 200 }),
);
await modelsService.getModels('custom-provider');
expect(mockedResolveRuntimeProvider).toHaveBeenCalledWith('custom-provider');
expect(fetch).toHaveBeenCalledWith('/webapi/models/openai', { headers: {} });
expect(mockedInitializeWithClientStore).not.toHaveBeenCalled();
});
it('should fetch models on client when isProviderFetchOnClient is true', async () => {
// Mock isProviderFetchOnClient to return true
const spyIsClient = vi
.spyOn(aiProviderSelectors, 'isProviderFetchOnClient')
.mockReturnValue(() => true);
// Mock initializeWithClientStore to return a runtime with a models() method
const mockModels = vi.fn().mockResolvedValue({ models: ['model1', 'model2'] });
mockedInitializeWithClientStore.mockResolvedValue({ models: mockModels } as any);
const result = await modelsService.getModels('openai');
expect(spyIsClient).toHaveBeenCalledWith('openai');
expect(mockedInitializeWithClientStore).toHaveBeenCalledWith({
provider: 'openai',
runtimeProvider: 'openai',
});
expect(mockModels).toHaveBeenCalled();
expect(result).toEqual({ models: ['model1', 'model2'] });
spyIsClient.mockRestore();
});
});
});
+8 -1
View File
@@ -14,6 +14,8 @@ import { useUserStore } from '@/store/user';
import { keyVaultsConfigSelectors, userProfileSelectors } from '@/store/user/selectors';
import { obfuscatePayloadWithXOR } from '@/utils/client/xor-obfuscation';
import { resolveRuntimeProvider } from './chat/helper';
export const getProviderAuthPayload = (
provider: string,
keyVaults: OpenAICompatibleKeyVault &
@@ -104,7 +106,12 @@ export const createPayloadWithKeyVaults = (provider: string) => {
keyVaults = aiProviderSelectors.providerKeyVaults(provider)(useAiInfraStore.getState()) || {};
}
return getProviderAuthPayload(provider, keyVaults);
const runtimeProvider = resolveRuntimeProvider(provider);
return {
...getProviderAuthPayload(runtimeProvider, keyVaults as any),
runtimeProvider,
};
};
export const createXorKeyVaultsPayload = (provider: string) => {
+68 -17
View File
@@ -108,7 +108,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.OpenAI, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.OpenAI,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI);
expect(runtime['_runtime'].baseURL).toBe('user-openai-endpoint');
@@ -127,7 +130,10 @@ describe('ModelRuntimeOnClient', () => {
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Azure, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Azure,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeAzureOpenAI);
});
@@ -142,7 +148,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Google, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Google,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeGoogleAI);
});
@@ -157,7 +166,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Moonshot, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Moonshot,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeMoonshotAI);
});
@@ -174,7 +186,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Bedrock, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Bedrock,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeBedrockAI);
});
@@ -189,7 +204,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Ollama, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Ollama,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeOllamaAI);
});
@@ -204,7 +222,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Perplexity, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Perplexity,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobePerplexityAI);
});
@@ -219,7 +240,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Anthropic, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Anthropic,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeAnthropicAI);
});
@@ -234,7 +258,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Mistral, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Mistral,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeMistralAI);
});
@@ -249,7 +276,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.OpenRouter, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.OpenRouter,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeOpenRouterAI);
});
@@ -264,7 +294,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.TogetherAI, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.TogetherAI,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeTogetherAI);
});
@@ -279,7 +312,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.ZeroOne, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.ZeroOne,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeZeroOneAI);
});
@@ -295,7 +331,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Groq, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Groq,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
const lobeOpenAICompatibleInstance = runtime['_runtime'] as LobeOpenAICompatibleRuntime;
expect(lobeOpenAICompatibleInstance).toBeInstanceOf(LobeGroq);
@@ -314,7 +353,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.DeepSeek, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.DeepSeek,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeDeepSeekAI);
});
@@ -329,7 +371,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.Qwen, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.Qwen,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeQwenAI);
});
@@ -349,7 +394,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as any as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore('unknown' as ModelProvider, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: 'unknown' as ModelProvider,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeOpenAI);
});
@@ -376,7 +424,10 @@ describe('ModelRuntimeOnClient', () => {
},
},
} as UserSettingsState) as unknown as UserStore;
const runtime = await initializeWithClientStore(ModelProvider.ZhiPu, {});
const runtime = await initializeWithClientStore({
payload: {},
provider: ModelProvider.ZhiPu,
});
expect(runtime).toBeInstanceOf(ModelRuntime);
expect(runtime['_runtime']).toBeInstanceOf(LobeZhipuAI);
});
+15 -4
View File
@@ -2,15 +2,26 @@ import { ModelRuntime } from '@lobechat/model-runtime';
import { createPayloadWithKeyVaults } from '../_auth';
export interface InitializeWithClientStoreOptions {
payload?: any;
provider: string;
runtimeProvider?: string;
}
/**
* Initializes the AgentRuntime with the client store.
* @param provider - The provider name.
* @param payload - Init options
* @param options.provider - Provider identifier used to resolve key vaults.
* @param options.runtimeProvider - Actual runtime implementation key (defaults to provider).
* @param options.payload - Additional initialization payload.
* @returns The initialized AgentRuntime instance
*
* **Note**: if you try to fetch directly, use `fetchOnClient` instead.
*/
export const initializeWithClientStore = (provider: string, payload?: any) => {
export const initializeWithClientStore = ({
provider,
runtimeProvider,
payload,
}: InitializeWithClientStoreOptions) => {
/**
* Since #5267, we map parameters for client-fetch in function `getProviderAuthPayload`
* which called by `createPayloadWithKeyVaults` below.
@@ -26,7 +37,7 @@ export const initializeWithClientStore = (provider: string, payload?: any) => {
* Configuration override order:
* payload -> providerAuthPayload -> commonOptions
*/
return ModelRuntime.initializeWithProvider(provider, {
return ModelRuntime.initializeWithProvider(runtimeProvider ?? provider, {
...commonOptions,
...providerAuthPayload,
...payload,
+11
View File
@@ -54,3 +54,14 @@ export const isEnableFetchOnClient = (provider: string) => {
return aiProviderSelectors.isProviderFetchOnClient(provider)(getAiInfraStoreState());
}
};
export const resolveRuntimeProvider = (provider: string) => {
if (isDeprecatedEdition) return provider;
const isBuiltin = Object.values(ModelProvider).includes(provider as any);
if (isBuiltin) return provider;
const providerConfig = aiProviderSelectors.providerConfigById(provider)(getAiInfraStoreState());
return providerConfig?.settings.sdkType || 'openai';
};
+11 -15
View File
@@ -6,7 +6,7 @@ import { ModelProvider } from 'model-bank';
import { enableAuth } from '@/const/auth';
import { DEFAULT_AGENT_CONFIG } from '@/const/settings';
import { isDeprecatedEdition, isDesktop } from '@/const/version';
import { isDesktop } from '@/const/version';
import { getSearchConfig } from '@/helpers/getSearchConfig';
import { createChatToolsEngine, createToolsEngine } from '@/helpers/toolEngineering';
import { getAgentStoreState } from '@/store/agent';
@@ -38,7 +38,7 @@ import { createHeaderWithAuth } from '../_auth';
import { API_ENDPOINTS } from '../_url';
import { initializeWithClientStore } from './clientModelRuntime';
import { contextEngineering } from './contextEngineering';
import { findDeploymentName, isEnableFetchOnClient } from './helper';
import { findDeploymentName, isEnableFetchOnClient, resolveRuntimeProvider } from './helper';
import { FetchOptions } from './types';
interface GetChatCompletionPayload extends Partial<Omit<ChatStreamPayload, 'messages'>> {
@@ -268,6 +268,8 @@ class ChatService {
{ ...res, apiMode, model },
);
const sdkType = resolveRuntimeProvider(provider);
/**
* Use browser agent runtime
*/
@@ -287,7 +289,7 @@ class ChatService {
*/
fetcher = async () => {
try {
return await this.fetchOnClient({ payload, provider, signal });
return await this.fetchOnClient({ payload, provider, runtimeProvider: sdkType, signal });
} catch (e) {
const {
errorType = ChatErrorType.BadRequest,
@@ -314,17 +316,6 @@ class ChatService {
const { DEFAULT_MODEL_PROVIDER_LIST } = await import('@/config/modelProviders');
const providerConfig = DEFAULT_MODEL_PROVIDER_LIST.find((item) => item.id === provider);
let sdkType = provider;
const isBuiltin = Object.values(ModelProvider).includes(provider as any);
// TODO: remove `!isDeprecatedEdition` condition in V2.0
if (!isDeprecatedEdition && !isBuiltin) {
const providerConfig =
aiProviderSelectors.providerConfigById(provider)(getAiInfraStoreState());
sdkType = providerConfig?.settings.sdkType || 'openai';
}
const userPreferTransitionMode =
userGeneralSettingsSelectors.transitionMode(getUserStoreState());
@@ -461,6 +452,7 @@ class ChatService {
private fetchOnClient = async (params: {
payload: Partial<ChatStreamPayload>;
provider: string;
runtimeProvider: string;
signal?: AbortSignal;
}) => {
/**
@@ -471,7 +463,11 @@ class ChatService {
throw AgentRuntimeError.createError(ChatErrorType.InvalidAccessCode);
}
const agentRuntime = await initializeWithClientStore(params.provider, params.payload);
const agentRuntime = await initializeWithClientStore({
payload: params.payload,
provider: params.provider,
runtimeProvider: params.runtimeProvider,
});
const data = params.payload as ChatStreamPayload;
return agentRuntime.chat(data, { signal: params.signal });
+14 -4
View File
@@ -8,6 +8,7 @@ import { getMessageError } from '@/utils/fetch';
import { API_ENDPOINTS } from './_url';
import { initializeWithClientStore } from './chat/clientModelRuntime';
import { resolveRuntimeProvider } from './chat/helper';
const isEnableFetchOnClient = (provider: string) => {
// TODO: remove this condition in V2.0
@@ -41,17 +42,22 @@ export class ModelsService {
headers: { 'Content-Type': 'application/json' },
provider,
});
const runtimeProvider = resolveRuntimeProvider(provider);
try {
/**
* Use browser agent runtime
*/
const enableFetchOnClient = isEnableFetchOnClient(provider);
if (enableFetchOnClient) {
const agentRuntime = await initializeWithClientStore(provider);
const agentRuntime = await initializeWithClientStore({
provider,
runtimeProvider,
});
return agentRuntime.models();
}
const res = await fetch(API_ENDPOINTS.models(provider), { headers });
const res = await fetch(API_ENDPOINTS.models(runtimeProvider), { headers });
if (!res.ok) return;
return res.json();
@@ -77,15 +83,19 @@ export class ModelsService {
provider,
});
const runtimeProvider = resolveRuntimeProvider(provider);
const enableFetchOnClient = isEnableFetchOnClient(provider);
console.log('enableFetchOnClient', enableFetchOnClient);
let res: Response;
if (enableFetchOnClient) {
const agentRuntime = await initializeWithClientStore(provider);
const agentRuntime = await initializeWithClientStore({
provider,
runtimeProvider,
});
res = (await agentRuntime.pullModel({ model }, { signal }))!;
} else {
res = await fetch(API_ENDPOINTS.modelPull(provider), {
res = await fetch(API_ENDPOINTS.modelPull(runtimeProvider), {
body: JSON.stringify({ model }),
headers,
method: 'POST',