🐛 fix: model runtime provider issue (#11314)

* fix

* upload

* update

* fix

* fix tests
This commit is contained in:
Arvin Xu
2026-01-07 23:22:19 +08:00
committed by GitHub
parent f9a35eb036
commit b4ba8bf454
5 changed files with 263 additions and 12 deletions
+3 -3
View File
@@ -102,8 +102,8 @@ jobs:
- name: Install deps
run: bun i
- name: Run tests with blob reporter
run: bunx vitest --coverage --reporter=blob --silent='passed-only' --shard=${{ matrix.shard }}/2
- name: Run tests
run: bunx vitest --coverage --silent='passed-only' --shard=${{ matrix.shard }}/2
- name: Upload blob report
if: ${{ !cancelled() }}
@@ -139,7 +139,7 @@ jobs:
merge-multiple: true
- name: Merge reports
run: bunx vitest --merge-reports --coverage
run: bunx vitest --merge-reports --reporter=default --coverage
- name: Upload App Coverage to Codecov
uses: codecov/codecov-action@v5
+1
View File
@@ -116,3 +116,4 @@ CLAUDE.local.md
e2e/reports
out
i18n-unused-keys-report.json
.vitest-reports
+214 -1
View File
@@ -26,7 +26,7 @@ import { ClientSecretPayload } from '@lobechat/types';
import { ModelProvider } from 'model-bank';
import { describe, expect, it, vi } from 'vitest';
import { initModelRuntimeWithUserPayload } from './index';
import { buildPayloadFromKeyVaults, initModelRuntimeWithUserPayload } from './index';
// 模拟依赖项
vi.mock('@/envs/llm', () => ({
@@ -496,3 +496,216 @@ describe('initModelRuntimeWithUserPayload method', () => {
});
});
});
/**
* Test cases for buildPayloadFromKeyVaults function
* This function builds ClientSecretPayload based on runtimeProvider (sdkType)
* to ensure provider-specific fields are correctly forwarded
*/
describe('buildPayloadFromKeyVaults', () => {
describe('should build payload with correct fields based on runtimeProvider', () => {
it('OpenAI compatible: returns apiKey, baseURL and runtimeProvider', () => {
const keyVaults = {
apiKey: 'test-api-key',
baseURL: 'https://custom-endpoint.com/v1',
};
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.OpenAI);
expect(payload).toEqual({
apiKey: 'test-api-key',
baseURL: 'https://custom-endpoint.com/v1',
runtimeProvider: ModelProvider.OpenAI,
});
});
it('Azure: returns apiKey, baseURL, azureApiVersion and runtimeProvider', () => {
const keyVaults = {
apiKey: 'azure-api-key',
baseURL: 'https://my-azure.openai.azure.com',
apiVersion: '2024-06-01',
endpoint: 'https://fallback-endpoint.com',
};
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure);
expect(payload).toEqual({
apiKey: 'azure-api-key',
azureApiVersion: '2024-06-01',
baseURL: 'https://my-azure.openai.azure.com',
runtimeProvider: ModelProvider.Azure,
});
});
it('Azure: uses endpoint as fallback when baseURL is not provided', () => {
const keyVaults = {
apiKey: 'azure-api-key',
endpoint: 'https://fallback-endpoint.com',
apiVersion: '2024-06-01',
};
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure);
expect(payload.baseURL).toBe('https://fallback-endpoint.com');
});
it('Cloudflare: returns apiKey, cloudflareBaseURLOrAccountID and runtimeProvider', () => {
const keyVaults = {
apiKey: 'cloudflare-api-key',
baseURLOrAccountID: 'my-account-id',
};
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Cloudflare);
expect(payload).toEqual({
apiKey: 'cloudflare-api-key',
cloudflareBaseURLOrAccountID: 'my-account-id',
runtimeProvider: ModelProvider.Cloudflare,
});
});
it('Bedrock: returns AWS credentials and runtimeProvider', () => {
const keyVaults = {
accessKeyId: 'aws-access-key',
secretAccessKey: 'aws-secret-key',
region: 'us-east-1',
sessionToken: 'session-token',
};
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Bedrock);
expect(payload).toEqual({
apiKey: 'aws-secret-keyaws-access-key',
awsAccessKeyId: 'aws-access-key',
awsRegion: 'us-east-1',
awsSecretAccessKey: 'aws-secret-key',
awsSessionToken: 'session-token',
runtimeProvider: ModelProvider.Bedrock,
});
});
it('Ollama: returns baseURL and runtimeProvider', () => {
const keyVaults = {
baseURL: 'http://localhost:11434',
};
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Ollama);
expect(payload).toEqual({
baseURL: 'http://localhost:11434',
runtimeProvider: ModelProvider.Ollama,
});
});
it('VertexAI: returns apiKey, baseURL, vertexAIRegion and runtimeProvider', () => {
const keyVaults = {
apiKey: 'vertex-credentials-json',
baseURL: 'https://vertex-endpoint.com',
region: 'us-central1',
};
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.VertexAI);
expect(payload).toEqual({
apiKey: 'vertex-credentials-json',
baseURL: 'https://vertex-endpoint.com',
runtimeProvider: ModelProvider.VertexAI,
vertexAIRegion: 'us-central1',
});
});
it('ComfyUI: returns all auth fields and runtimeProvider', () => {
const keyVaults = {
apiKey: 'comfyui-api-key',
authType: 'bearer',
baseURL: 'http://localhost:8188',
customHeaders: { 'X-Custom': 'header' },
password: 'pass',
username: 'user',
} as const;
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.ComfyUI);
expect(payload).toEqual({
apiKey: 'comfyui-api-key',
authType: 'bearer',
baseURL: 'http://localhost:8188',
customHeaders: { 'X-Custom': 'header' },
password: 'pass',
runtimeProvider: ModelProvider.ComfyUI,
username: 'user',
});
});
it('Unknown provider: falls back to default with apiKey, baseURL and runtimeProvider', () => {
const keyVaults = {
apiKey: 'unknown-api-key',
baseURL: 'https://unknown-endpoint.com',
};
const payload = buildPayloadFromKeyVaults(keyVaults, 'unknown-provider');
expect(payload).toEqual({
apiKey: 'unknown-api-key',
baseURL: 'https://unknown-endpoint.com',
runtimeProvider: 'unknown-provider',
});
});
});
describe('custom provider with sdkType should include provider-specific fields', () => {
it('custom provider with Azure sdkType includes azureApiVersion', () => {
const keyVaults = {
apiKey: 'custom-azure-key',
baseURL: 'https://custom-azure.openai.azure.com',
apiVersion: '2024-06-01',
};
// Simulates a custom provider where runtimeProvider is resolved to 'azure'
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure);
expect(payload.azureApiVersion).toBe('2024-06-01');
expect(payload.runtimeProvider).toBe(ModelProvider.Azure);
});
it('custom provider with Cloudflare sdkType includes cloudflareBaseURLOrAccountID', () => {
const keyVaults = {
apiKey: 'custom-cloudflare-key',
baseURLOrAccountID: 'custom-account-id',
};
// Simulates a custom provider where runtimeProvider is resolved to 'cloudflare'
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Cloudflare);
expect(payload.cloudflareBaseURLOrAccountID).toBe('custom-account-id');
expect(payload.runtimeProvider).toBe(ModelProvider.Cloudflare);
});
it('custom provider with Bedrock sdkType includes AWS credentials', () => {
const keyVaults = {
accessKeyId: 'custom-aws-id',
secretAccessKey: 'custom-aws-secret',
region: 'eu-west-1',
};
// Simulates a custom provider where runtimeProvider is resolved to 'bedrock'
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Bedrock);
expect(payload.awsAccessKeyId).toBe('custom-aws-id');
expect(payload.awsSecretAccessKey).toBe('custom-aws-secret');
expect(payload.awsRegion).toBe('eu-west-1');
expect(payload.runtimeProvider).toBe(ModelProvider.Bedrock);
});
it('custom provider with Ollama sdkType includes baseURL', () => {
const keyVaults = {
baseURL: 'http://custom-ollama:11434',
};
// Simulates a custom provider where runtimeProvider is resolved to 'ollama'
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Ollama);
expect(payload.baseURL).toBe('http://custom-ollama:11434');
expect(payload.runtimeProvider).toBe(ModelProvider.Ollama);
});
it('custom provider with VertexAI sdkType includes vertexAIRegion', () => {
const keyVaults = {
apiKey: 'custom-vertex-creds',
region: 'asia-northeast1',
};
// Simulates a custom provider where runtimeProvider is resolved to 'vertexai'
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.VertexAI);
expect(payload.vertexAIRegion).toBe('asia-northeast1');
expect(payload.runtimeProvider).toBe(ModelProvider.VertexAI);
});
});
});
+44 -8
View File
@@ -32,6 +32,24 @@ type ProviderKeyVaults = OpenAICompatibleKeyVault &
ComfyUIKeyVault &
VertexAIKeyVault;
/**
* Resolve the runtime provider for a given provider.
*
* This is the server-side equivalent of the frontend's resolveRuntimeProvider function.
* For builtin providers, returns the provider as-is.
* For custom providers, returns the sdkType from settings (defaults to 'openai').
*
* @param provider - The provider id
* @param sdkType - The sdkType from provider settings
* @returns The resolved runtime provider
*/
const resolveRuntimeProvider = (provider: string, sdkType?: string): string => {
const isBuiltin = Object.values(ModelProvider).includes(provider as ModelProvider);
if (isBuiltin) return provider;
return sdkType || 'openai';
};
/**
* Build ClientSecretPayload from keyVaults stored in database
*
@@ -39,15 +57,21 @@ type ProviderKeyVaults = OpenAICompatibleKeyVault &
* It converts the keyVaults object from database to the ClientSecretPayload format
* expected by initModelRuntimeWithUserPayload.
*
* @param provider - The model provider
* For custom providers, we use runtimeProvider (sdkType) to determine which fields
* to include in the payload. This ensures that provider-specific fields like
* cloudflareBaseURLOrAccountID or azureApiVersion are correctly forwarded.
*
* @param keyVaults - The keyVaults object from database (already decrypted)
* @param runtimeProvider - The runtime provider (sdkType) to use for building payload
* @returns ClientSecretPayload for the provider
*/
export const buildPayloadFromKeyVaults = (
provider: string,
keyVaults: ProviderKeyVaults,
runtimeProvider: string,
): ClientSecretPayload => {
switch (provider) {
// Use runtimeProvider to determine which fields to include
// This handles both builtin providers and custom providers with sdkType
switch (runtimeProvider) {
case ModelProvider.Bedrock: {
const { accessKeyId, region, secretAccessKey, sessionToken } = keyVaults;
const apiKey = (secretAccessKey || '') + (accessKeyId || '');
@@ -58,6 +82,7 @@ export const buildPayloadFromKeyVaults = (
awsRegion: region,
awsSecretAccessKey: secretAccessKey,
awsSessionToken: sessionToken,
runtimeProvider,
};
}
@@ -66,17 +91,19 @@ export const buildPayloadFromKeyVaults = (
apiKey: keyVaults.apiKey,
azureApiVersion: keyVaults.apiVersion,
baseURL: keyVaults.baseURL || keyVaults.endpoint,
runtimeProvider,
};
}
case ModelProvider.Ollama: {
return { baseURL: keyVaults.baseURL };
return { baseURL: keyVaults.baseURL, runtimeProvider };
}
case ModelProvider.Cloudflare: {
return {
apiKey: keyVaults.apiKey,
cloudflareBaseURLOrAccountID: keyVaults.baseURLOrAccountID,
runtimeProvider,
};
}
@@ -87,6 +114,7 @@ export const buildPayloadFromKeyVaults = (
baseURL: keyVaults.baseURL,
customHeaders: keyVaults.customHeaders,
password: keyVaults.password,
runtimeProvider,
username: keyVaults.username,
};
}
@@ -95,6 +123,7 @@ export const buildPayloadFromKeyVaults = (
return {
apiKey: keyVaults.apiKey,
baseURL: keyVaults.baseURL,
runtimeProvider,
vertexAIRegion: keyVaults.region,
};
}
@@ -103,6 +132,7 @@ export const buildPayloadFromKeyVaults = (
return {
apiKey: keyVaults.apiKey,
baseURL: keyVaults.baseURL,
runtimeProvider,
};
}
}
@@ -350,10 +380,16 @@ export const initModelRuntimeFromDB = async (
KeyVaultsGateKeeper.getUserKeyVaults,
);
// 2. Build ClientSecretPayload from keyVaults
const keyVaults = (providerConfig?.keyVaults || {}) as ProviderKeyVaults;
const payload = buildPayloadFromKeyVaults(provider, keyVaults);
// 2. Resolve the runtime provider for custom providers
// For custom providers, use sdkType from settings (defaults to 'openai')
const sdkType = providerConfig?.settings?.sdkType;
const runtimeProvider = resolveRuntimeProvider(provider, sdkType);
// 3. Initialize ModelRuntime with the payload
// 3. Build ClientSecretPayload from keyVaults based on runtimeProvider
// This ensures provider-specific fields (e.g., cloudflareBaseURLOrAccountID) are included
const keyVaults = (providerConfig?.keyVaults || {}) as ProviderKeyVaults;
const payload = buildPayloadFromKeyVaults(keyVaults, runtimeProvider);
// 4. Initialize ModelRuntime with the payload
return initModelRuntimeWithUserPayload(provider, payload);
};
+1
View File
@@ -92,6 +92,7 @@ export default defineConfig({
'**/e2e/**',
],
globals: true,
reporters: ['default', 'blob'],
server: {
deps: {
inline: ['vitest-canvas-mock', '@lobehub/ui', '@lobehub/fluent-emoji'],