mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-13 19:20:04 +00:00
🐛 fix: model runtime provider issue (#11314)
* fix * upload * update * fix * fix tests
This commit is contained in:
@@ -102,8 +102,8 @@ jobs:
|
||||
- name: Install deps
|
||||
run: bun i
|
||||
|
||||
- name: Run tests with blob reporter
|
||||
run: bunx vitest --coverage --reporter=blob --silent='passed-only' --shard=${{ matrix.shard }}/2
|
||||
- name: Run tests
|
||||
run: bunx vitest --coverage --silent='passed-only' --shard=${{ matrix.shard }}/2
|
||||
|
||||
- name: Upload blob report
|
||||
if: ${{ !cancelled() }}
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
merge-multiple: true
|
||||
|
||||
- name: Merge reports
|
||||
run: bunx vitest --merge-reports --coverage
|
||||
run: bunx vitest --merge-reports --reporter=default --coverage
|
||||
|
||||
- name: Upload App Coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
|
||||
@@ -116,3 +116,4 @@ CLAUDE.local.md
|
||||
e2e/reports
|
||||
out
|
||||
i18n-unused-keys-report.json
|
||||
.vitest-reports
|
||||
|
||||
@@ -26,7 +26,7 @@ import { ClientSecretPayload } from '@lobechat/types';
|
||||
import { ModelProvider } from 'model-bank';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { initModelRuntimeWithUserPayload } from './index';
|
||||
import { buildPayloadFromKeyVaults, initModelRuntimeWithUserPayload } from './index';
|
||||
|
||||
// 模拟依赖项
|
||||
vi.mock('@/envs/llm', () => ({
|
||||
@@ -496,3 +496,216 @@ describe('initModelRuntimeWithUserPayload method', () => {
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* Test cases for buildPayloadFromKeyVaults function
|
||||
* This function builds ClientSecretPayload based on runtimeProvider (sdkType)
|
||||
* to ensure provider-specific fields are correctly forwarded
|
||||
*/
|
||||
describe('buildPayloadFromKeyVaults', () => {
|
||||
describe('should build payload with correct fields based on runtimeProvider', () => {
|
||||
it('OpenAI compatible: returns apiKey, baseURL and runtimeProvider', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: 'https://custom-endpoint.com/v1',
|
||||
};
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.OpenAI);
|
||||
|
||||
expect(payload).toEqual({
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: 'https://custom-endpoint.com/v1',
|
||||
runtimeProvider: ModelProvider.OpenAI,
|
||||
});
|
||||
});
|
||||
|
||||
it('Azure: returns apiKey, baseURL, azureApiVersion and runtimeProvider', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'azure-api-key',
|
||||
baseURL: 'https://my-azure.openai.azure.com',
|
||||
apiVersion: '2024-06-01',
|
||||
endpoint: 'https://fallback-endpoint.com',
|
||||
};
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure);
|
||||
|
||||
expect(payload).toEqual({
|
||||
apiKey: 'azure-api-key',
|
||||
azureApiVersion: '2024-06-01',
|
||||
baseURL: 'https://my-azure.openai.azure.com',
|
||||
runtimeProvider: ModelProvider.Azure,
|
||||
});
|
||||
});
|
||||
|
||||
it('Azure: uses endpoint as fallback when baseURL is not provided', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'azure-api-key',
|
||||
endpoint: 'https://fallback-endpoint.com',
|
||||
apiVersion: '2024-06-01',
|
||||
};
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure);
|
||||
|
||||
expect(payload.baseURL).toBe('https://fallback-endpoint.com');
|
||||
});
|
||||
|
||||
it('Cloudflare: returns apiKey, cloudflareBaseURLOrAccountID and runtimeProvider', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'cloudflare-api-key',
|
||||
baseURLOrAccountID: 'my-account-id',
|
||||
};
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Cloudflare);
|
||||
|
||||
expect(payload).toEqual({
|
||||
apiKey: 'cloudflare-api-key',
|
||||
cloudflareBaseURLOrAccountID: 'my-account-id',
|
||||
runtimeProvider: ModelProvider.Cloudflare,
|
||||
});
|
||||
});
|
||||
|
||||
it('Bedrock: returns AWS credentials and runtimeProvider', () => {
|
||||
const keyVaults = {
|
||||
accessKeyId: 'aws-access-key',
|
||||
secretAccessKey: 'aws-secret-key',
|
||||
region: 'us-east-1',
|
||||
sessionToken: 'session-token',
|
||||
};
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Bedrock);
|
||||
|
||||
expect(payload).toEqual({
|
||||
apiKey: 'aws-secret-keyaws-access-key',
|
||||
awsAccessKeyId: 'aws-access-key',
|
||||
awsRegion: 'us-east-1',
|
||||
awsSecretAccessKey: 'aws-secret-key',
|
||||
awsSessionToken: 'session-token',
|
||||
runtimeProvider: ModelProvider.Bedrock,
|
||||
});
|
||||
});
|
||||
|
||||
it('Ollama: returns baseURL and runtimeProvider', () => {
|
||||
const keyVaults = {
|
||||
baseURL: 'http://localhost:11434',
|
||||
};
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Ollama);
|
||||
|
||||
expect(payload).toEqual({
|
||||
baseURL: 'http://localhost:11434',
|
||||
runtimeProvider: ModelProvider.Ollama,
|
||||
});
|
||||
});
|
||||
|
||||
it('VertexAI: returns apiKey, baseURL, vertexAIRegion and runtimeProvider', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'vertex-credentials-json',
|
||||
baseURL: 'https://vertex-endpoint.com',
|
||||
region: 'us-central1',
|
||||
};
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.VertexAI);
|
||||
|
||||
expect(payload).toEqual({
|
||||
apiKey: 'vertex-credentials-json',
|
||||
baseURL: 'https://vertex-endpoint.com',
|
||||
runtimeProvider: ModelProvider.VertexAI,
|
||||
vertexAIRegion: 'us-central1',
|
||||
});
|
||||
});
|
||||
|
||||
it('ComfyUI: returns all auth fields and runtimeProvider', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'comfyui-api-key',
|
||||
authType: 'bearer',
|
||||
baseURL: 'http://localhost:8188',
|
||||
customHeaders: { 'X-Custom': 'header' },
|
||||
password: 'pass',
|
||||
username: 'user',
|
||||
} as const;
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.ComfyUI);
|
||||
|
||||
expect(payload).toEqual({
|
||||
apiKey: 'comfyui-api-key',
|
||||
authType: 'bearer',
|
||||
baseURL: 'http://localhost:8188',
|
||||
customHeaders: { 'X-Custom': 'header' },
|
||||
password: 'pass',
|
||||
runtimeProvider: ModelProvider.ComfyUI,
|
||||
username: 'user',
|
||||
});
|
||||
});
|
||||
|
||||
it('Unknown provider: falls back to default with apiKey, baseURL and runtimeProvider', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'unknown-api-key',
|
||||
baseURL: 'https://unknown-endpoint.com',
|
||||
};
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, 'unknown-provider');
|
||||
|
||||
expect(payload).toEqual({
|
||||
apiKey: 'unknown-api-key',
|
||||
baseURL: 'https://unknown-endpoint.com',
|
||||
runtimeProvider: 'unknown-provider',
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('custom provider with sdkType should include provider-specific fields', () => {
|
||||
it('custom provider with Azure sdkType includes azureApiVersion', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'custom-azure-key',
|
||||
baseURL: 'https://custom-azure.openai.azure.com',
|
||||
apiVersion: '2024-06-01',
|
||||
};
|
||||
// Simulates a custom provider where runtimeProvider is resolved to 'azure'
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Azure);
|
||||
|
||||
expect(payload.azureApiVersion).toBe('2024-06-01');
|
||||
expect(payload.runtimeProvider).toBe(ModelProvider.Azure);
|
||||
});
|
||||
|
||||
it('custom provider with Cloudflare sdkType includes cloudflareBaseURLOrAccountID', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'custom-cloudflare-key',
|
||||
baseURLOrAccountID: 'custom-account-id',
|
||||
};
|
||||
// Simulates a custom provider where runtimeProvider is resolved to 'cloudflare'
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Cloudflare);
|
||||
|
||||
expect(payload.cloudflareBaseURLOrAccountID).toBe('custom-account-id');
|
||||
expect(payload.runtimeProvider).toBe(ModelProvider.Cloudflare);
|
||||
});
|
||||
|
||||
it('custom provider with Bedrock sdkType includes AWS credentials', () => {
|
||||
const keyVaults = {
|
||||
accessKeyId: 'custom-aws-id',
|
||||
secretAccessKey: 'custom-aws-secret',
|
||||
region: 'eu-west-1',
|
||||
};
|
||||
// Simulates a custom provider where runtimeProvider is resolved to 'bedrock'
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Bedrock);
|
||||
|
||||
expect(payload.awsAccessKeyId).toBe('custom-aws-id');
|
||||
expect(payload.awsSecretAccessKey).toBe('custom-aws-secret');
|
||||
expect(payload.awsRegion).toBe('eu-west-1');
|
||||
expect(payload.runtimeProvider).toBe(ModelProvider.Bedrock);
|
||||
});
|
||||
|
||||
it('custom provider with Ollama sdkType includes baseURL', () => {
|
||||
const keyVaults = {
|
||||
baseURL: 'http://custom-ollama:11434',
|
||||
};
|
||||
// Simulates a custom provider where runtimeProvider is resolved to 'ollama'
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.Ollama);
|
||||
|
||||
expect(payload.baseURL).toBe('http://custom-ollama:11434');
|
||||
expect(payload.runtimeProvider).toBe(ModelProvider.Ollama);
|
||||
});
|
||||
|
||||
it('custom provider with VertexAI sdkType includes vertexAIRegion', () => {
|
||||
const keyVaults = {
|
||||
apiKey: 'custom-vertex-creds',
|
||||
region: 'asia-northeast1',
|
||||
};
|
||||
// Simulates a custom provider where runtimeProvider is resolved to 'vertexai'
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, ModelProvider.VertexAI);
|
||||
|
||||
expect(payload.vertexAIRegion).toBe('asia-northeast1');
|
||||
expect(payload.runtimeProvider).toBe(ModelProvider.VertexAI);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -32,6 +32,24 @@ type ProviderKeyVaults = OpenAICompatibleKeyVault &
|
||||
ComfyUIKeyVault &
|
||||
VertexAIKeyVault;
|
||||
|
||||
/**
|
||||
* Resolve the runtime provider for a given provider.
|
||||
*
|
||||
* This is the server-side equivalent of the frontend's resolveRuntimeProvider function.
|
||||
* For builtin providers, returns the provider as-is.
|
||||
* For custom providers, returns the sdkType from settings (defaults to 'openai').
|
||||
*
|
||||
* @param provider - The provider id
|
||||
* @param sdkType - The sdkType from provider settings
|
||||
* @returns The resolved runtime provider
|
||||
*/
|
||||
const resolveRuntimeProvider = (provider: string, sdkType?: string): string => {
|
||||
const isBuiltin = Object.values(ModelProvider).includes(provider as ModelProvider);
|
||||
if (isBuiltin) return provider;
|
||||
|
||||
return sdkType || 'openai';
|
||||
};
|
||||
|
||||
/**
|
||||
* Build ClientSecretPayload from keyVaults stored in database
|
||||
*
|
||||
@@ -39,15 +57,21 @@ type ProviderKeyVaults = OpenAICompatibleKeyVault &
|
||||
* It converts the keyVaults object from database to the ClientSecretPayload format
|
||||
* expected by initModelRuntimeWithUserPayload.
|
||||
*
|
||||
* @param provider - The model provider
|
||||
* For custom providers, we use runtimeProvider (sdkType) to determine which fields
|
||||
* to include in the payload. This ensures that provider-specific fields like
|
||||
* cloudflareBaseURLOrAccountID or azureApiVersion are correctly forwarded.
|
||||
*
|
||||
* @param keyVaults - The keyVaults object from database (already decrypted)
|
||||
* @param runtimeProvider - The runtime provider (sdkType) to use for building payload
|
||||
* @returns ClientSecretPayload for the provider
|
||||
*/
|
||||
export const buildPayloadFromKeyVaults = (
|
||||
provider: string,
|
||||
keyVaults: ProviderKeyVaults,
|
||||
runtimeProvider: string,
|
||||
): ClientSecretPayload => {
|
||||
switch (provider) {
|
||||
// Use runtimeProvider to determine which fields to include
|
||||
// This handles both builtin providers and custom providers with sdkType
|
||||
switch (runtimeProvider) {
|
||||
case ModelProvider.Bedrock: {
|
||||
const { accessKeyId, region, secretAccessKey, sessionToken } = keyVaults;
|
||||
const apiKey = (secretAccessKey || '') + (accessKeyId || '');
|
||||
@@ -58,6 +82,7 @@ export const buildPayloadFromKeyVaults = (
|
||||
awsRegion: region,
|
||||
awsSecretAccessKey: secretAccessKey,
|
||||
awsSessionToken: sessionToken,
|
||||
runtimeProvider,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -66,17 +91,19 @@ export const buildPayloadFromKeyVaults = (
|
||||
apiKey: keyVaults.apiKey,
|
||||
azureApiVersion: keyVaults.apiVersion,
|
||||
baseURL: keyVaults.baseURL || keyVaults.endpoint,
|
||||
runtimeProvider,
|
||||
};
|
||||
}
|
||||
|
||||
case ModelProvider.Ollama: {
|
||||
return { baseURL: keyVaults.baseURL };
|
||||
return { baseURL: keyVaults.baseURL, runtimeProvider };
|
||||
}
|
||||
|
||||
case ModelProvider.Cloudflare: {
|
||||
return {
|
||||
apiKey: keyVaults.apiKey,
|
||||
cloudflareBaseURLOrAccountID: keyVaults.baseURLOrAccountID,
|
||||
runtimeProvider,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -87,6 +114,7 @@ export const buildPayloadFromKeyVaults = (
|
||||
baseURL: keyVaults.baseURL,
|
||||
customHeaders: keyVaults.customHeaders,
|
||||
password: keyVaults.password,
|
||||
runtimeProvider,
|
||||
username: keyVaults.username,
|
||||
};
|
||||
}
|
||||
@@ -95,6 +123,7 @@ export const buildPayloadFromKeyVaults = (
|
||||
return {
|
||||
apiKey: keyVaults.apiKey,
|
||||
baseURL: keyVaults.baseURL,
|
||||
runtimeProvider,
|
||||
vertexAIRegion: keyVaults.region,
|
||||
};
|
||||
}
|
||||
@@ -103,6 +132,7 @@ export const buildPayloadFromKeyVaults = (
|
||||
return {
|
||||
apiKey: keyVaults.apiKey,
|
||||
baseURL: keyVaults.baseURL,
|
||||
runtimeProvider,
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -350,10 +380,16 @@ export const initModelRuntimeFromDB = async (
|
||||
KeyVaultsGateKeeper.getUserKeyVaults,
|
||||
);
|
||||
|
||||
// 2. Build ClientSecretPayload from keyVaults
|
||||
const keyVaults = (providerConfig?.keyVaults || {}) as ProviderKeyVaults;
|
||||
const payload = buildPayloadFromKeyVaults(provider, keyVaults);
|
||||
// 2. Resolve the runtime provider for custom providers
|
||||
// For custom providers, use sdkType from settings (defaults to 'openai')
|
||||
const sdkType = providerConfig?.settings?.sdkType;
|
||||
const runtimeProvider = resolveRuntimeProvider(provider, sdkType);
|
||||
|
||||
// 3. Initialize ModelRuntime with the payload
|
||||
// 3. Build ClientSecretPayload from keyVaults based on runtimeProvider
|
||||
// This ensures provider-specific fields (e.g., cloudflareBaseURLOrAccountID) are included
|
||||
const keyVaults = (providerConfig?.keyVaults || {}) as ProviderKeyVaults;
|
||||
const payload = buildPayloadFromKeyVaults(keyVaults, runtimeProvider);
|
||||
|
||||
// 4. Initialize ModelRuntime with the payload
|
||||
return initModelRuntimeWithUserPayload(provider, payload);
|
||||
};
|
||||
|
||||
@@ -92,6 +92,7 @@ export default defineConfig({
|
||||
'**/e2e/**',
|
||||
],
|
||||
globals: true,
|
||||
reporters: ['default', 'blob'],
|
||||
server: {
|
||||
deps: {
|
||||
inline: ['vitest-canvas-mock', '@lobehub/ui', '@lobehub/fluent-emoji'],
|
||||
|
||||
Reference in New Issue
Block a user