feat(image): image model show price (#10198)

This commit is contained in:
YuTengjing
2025-11-14 16:10:22 +08:00
committed by GitHub
parent f46edeb2d1
commit b87e0e422e
13 changed files with 430 additions and 323 deletions
@@ -1056,6 +1056,7 @@ const aihubmixModels: AIChatModelCard[] = [
id: 'gemini-2.5-flash-image',
maxOutput: 8192,
pricing: {
approximatePricePerImage: 0.039,
units: [
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
+9 -5
View File
@@ -370,6 +370,7 @@ const googleChatModels: AIChatModelCard[] = [
id: 'gemini-2.5-flash-image',
maxOutput: 8192,
pricing: {
approximatePricePerImage: 0.039,
units: [
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
@@ -392,6 +393,7 @@ const googleChatModels: AIChatModelCard[] = [
id: 'gemini-2.5-flash-image-preview',
maxOutput: 8192,
pricing: {
approximatePricePerImage: 0.039,
units: [
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
@@ -864,6 +866,7 @@ const googleImageModels: AIImageModelCard[] = [
releasedAt: '2025-08-26',
parameters: nanoBananaParameters,
pricing: {
approximatePricePerImage: 0.039,
units: [
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'textOutput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
@@ -880,6 +883,7 @@ const googleImageModels: AIImageModelCard[] = [
releasedAt: '2025-08-26',
parameters: CHAT_MODEL_IMAGE_GENERATION_PARAMS,
pricing: {
approximatePricePerImage: 0.039,
units: [
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'textOutput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
@@ -892,7 +896,7 @@ const googleImageModels: AIImageModelCard[] = [
id: 'imagen-4.0-generate-001',
enabled: true,
type: 'image',
description: 'Imagen 4th generation text-to-image model series',
description: 'Imagen 第四代文生图模型系列',
organization: 'Deepmind',
releasedAt: '2025-08-15',
parameters: imagenGenParameters,
@@ -905,7 +909,7 @@ const googleImageModels: AIImageModelCard[] = [
id: 'imagen-4.0-ultra-generate-001',
enabled: true,
type: 'image',
description: 'Imagen 4th generation text-to-image model series Ultra version',
description: 'Imagen 第四代文生图模型系列的 Ultra 版本',
organization: 'Deepmind',
releasedAt: '2025-08-15',
parameters: imagenGenParameters,
@@ -918,7 +922,7 @@ const googleImageModels: AIImageModelCard[] = [
id: 'imagen-4.0-fast-generate-001',
enabled: true,
type: 'image',
description: 'Imagen 4th generation text-to-image model series Fast version',
description: 'Imagen 第四代文生图模型系列的快速版本',
organization: 'Deepmind',
releasedAt: '2025-08-15',
parameters: imagenGenParameters,
@@ -930,7 +934,7 @@ const googleImageModels: AIImageModelCard[] = [
displayName: 'Imagen 4 Preview 06-06',
id: 'imagen-4.0-generate-preview-06-06',
type: 'image',
description: 'Imagen 4th generation text-to-image model series',
description: 'Imagen 第四代文生图模型系列',
organization: 'Deepmind',
releasedAt: '2025-06-06',
parameters: imagenGenParameters,
@@ -942,7 +946,7 @@ const googleImageModels: AIImageModelCard[] = [
displayName: 'Imagen 4 Ultra Preview 06-06',
id: 'imagen-4.0-ultra-generate-preview-06-06',
type: 'image',
description: 'Imagen 4th generation text-to-image model series Ultra version',
description: 'Imagen 第四代文生图模型系列的 Ultra 版本',
organization: 'Deepmind',
releasedAt: '2025-06-11',
parameters: imagenGenParameters,
+2 -35
View File
@@ -1171,31 +1171,13 @@ export const openaiImageModels: AIImageModelCard[] = [
id: 'gpt-image-1',
parameters: gptImage1ParamsSchema,
pricing: {
approximatePricePerImage: 0.042,
units: [
{ name: 'textInput', rate: 5, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'textInput_cacheRead', rate: 1.25, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageInput', rate: 10, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageInput_cacheRead', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageOutput', rate: 40, strategy: 'fixed', unit: 'millionTokens' },
{
lookup: {
prices: {
high_1024x1024: 0.167,
high_1024x1536: 0.25,
high_1536x1024: 0.25,
low_1024x1024: 0.011,
low_1024x1536: 0.016,
low_1536x1024: 0.016,
medium_1024x1024: 0.042,
medium_1024x1536: 0.063,
medium_1536x1024: 0.063,
},
pricingParams: ['quality', 'size'],
},
name: 'imageGeneration',
strategy: 'lookup',
unit: 'image',
},
],
},
resolutions: ['1024x1024', '1024x1536', '1536x1024'],
@@ -1208,28 +1190,13 @@ export const openaiImageModels: AIImageModelCard[] = [
id: 'gpt-image-1-mini',
parameters: gptImage1ParamsSchema,
pricing: {
approximatePricePerImage: 0.011,
units: [
{ name: 'textInput', rate: 2, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'textInput_cacheRead', rate: 0.2, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageInput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageInput_cacheRead', rate: 0.25, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'imageOutput', rate: 8, strategy: 'fixed', unit: 'millionTokens' },
{
lookup: {
prices: {
low_1024x1024: 0.005,
low_1024x1536: 0.006,
low_1536x1024: 0.006,
medium_1024x1024: 0.011,
medium_1024x1536: 0.015,
medium_1536x1024: 0.015,
},
pricingParams: ['quality', 'size'],
},
name: 'imageGeneration',
strategy: 'lookup',
unit: 'image',
},
],
},
releasedAt: '2025-10-06',
@@ -41,6 +41,7 @@ const openrouterChatModels: AIChatModelCard[] = [
id: 'google/gemini-2.5-flash-image-preview',
maxOutput: 8192,
pricing: {
approximatePricePerImage: 0.039,
units: [
{ name: 'imageOutput', rate: 30, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
@@ -135,6 +135,7 @@ const vertexaiChatModels: AIChatModelCard[] = [
id: 'gemini-2.5-flash-image-preview',
maxOutput: 8192,
pricing: {
approximatePricePerImage: 0.039,
units: [
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'textOutput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
@@ -291,6 +292,7 @@ const vertexaiImageModels: AIImageModelCard[] = [
releasedAt: '2025-08-26',
parameters: nanoBananaParameters,
pricing: {
approximatePricePerImage: 0.039,
units: [
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
{ name: 'textOutput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
+15 -2
View File
@@ -185,6 +185,10 @@ export interface LookupPricingUnit extends PricingUnitBase {
export type PricingUnit = FixedPricingUnit | TieredPricingUnit | LookupPricingUnit;
export interface Pricing {
/**
* Fallback approximate per-image price (USD) when detailed pricing table is unavailable
*/
approximatePricePerImage?: number;
currency?: ModelPriceCurrency;
units: PricingUnit[];
}
@@ -391,13 +395,22 @@ export const ToggleAiModelEnableSchema = z.object({
export type ToggleAiModelEnableParams = z.infer<typeof ToggleAiModelEnableSchema>;
//
export interface AiModelForSelect {
abilities: ModelAbilities;
/**
* Approximate per-image price (USD), used when exact calculation is not possible
*/
approximatePricePerImage?: number;
contextWindowTokens?: number;
description?: string;
displayName?: string;
id: string;
parameters?: ModelParamsSchema;
/**
* Exact per-image price (USD) calculated from pricing units
*/
pricePerImage?: number;
pricing?: Pricing;
}
export interface EnabledAiModel {
@@ -2,3 +2,4 @@ export { convertAnthropicUsage } from './anthropic';
export { convertGoogleAIUsage } from './google-ai';
export { convertOpenAIResponseUsage, convertOpenAIUsage } from './openai';
export { computeImageCost } from './utils/computeImageCost';
export { resolveImageSinglePrice } from './utils/resolveImageSinglePrice';
@@ -0,0 +1,34 @@
import { Pricing } from 'model-bank';
export interface ImageSinglePriceResult {
approximatePrice?: number;
price?: number;
}
const DEFAULT_REFERENCE_MP = (1024 * 1024) / 1_000_000;
export const resolveImageSinglePrice = (pricing?: Pricing): ImageSinglePriceResult => {
if (!pricing) return {};
// Priority 1: Use approximate price if explicitly provided
if (typeof pricing.approximatePricePerImage === 'number') {
return { approximatePrice: pricing.approximatePricePerImage };
}
// Priority 2: Calculate exact price from pricing units
const imageGenerationUnit = pricing.units.find((unit) => unit.name === 'imageGeneration');
if (!imageGenerationUnit) return {};
if (imageGenerationUnit.strategy === 'fixed') {
if (imageGenerationUnit.unit === 'image') {
return { price: imageGenerationUnit.rate };
}
if (imageGenerationUnit.unit === 'megapixel') {
return { price: imageGenerationUnit.rate * DEFAULT_REFERENCE_MP };
}
}
// Lookup/tiered pricing typically requires explicit configuration; treat as unavailable here.
return {};
};
+1 -1
View File
@@ -1,4 +1,4 @@
import { ModelParamsSchema , AiModelType, Pricing } from 'model-bank';
import { AiModelType, ModelParamsSchema, Pricing } from 'model-bank';
import { ReactNode } from 'react';
import { AiProviderSettings } from './aiProvider';
@@ -0,0 +1,93 @@
import { ModelIcon } from '@lobehub/icons';
import { Text } from '@lobehub/ui';
import { Popover } from 'antd';
import { createStyles } from 'antd-style';
import { AiModelForSelect } from 'model-bank';
import numeral from 'numeral';
import { memo, useMemo } from 'react';
import { Flexbox } from 'react-layout-kit';
const POPOVER_MAX_WIDTH = 320;
const useStyles = createStyles(({ css, token, isDarkMode }) => ({
descriptionText: css`
color: ${isDarkMode ? token.colorText : token.colorTextSecondary};
`,
popover: css`
.ant-popover-inner {
background: ${isDarkMode ? token.colorBgSpotlight : token.colorBgElevated};
}
`,
priceText: css`
font-weight: 500;
color: ${isDarkMode ? token.colorTextLightSolid : token.colorTextTertiary};
`,
}));
type ImageModelItemProps = AiModelForSelect & {
/**
* Whether to show popover on hover
* @default true
*/
showPopover?: boolean;
};
const ImageModelItem = memo<ImageModelItemProps>(
({ approximatePricePerImage, description, pricePerImage, showPopover = true, ...model }) => {
const { styles } = useStyles();
const priceLabel = useMemo(() => {
// Priority 1: Use exact price
if (typeof pricePerImage === 'number') {
return `${numeral(pricePerImage).format('$0,0.00[000]')} / image`;
}
// Priority 2: Use approximate price with prefix
if (typeof approximatePricePerImage === 'number') {
return `~ ${numeral(approximatePricePerImage).format('$0,0.00[000]')} / image`;
}
return undefined;
}, [approximatePricePerImage, pricePerImage]);
const popoverContent = useMemo(() => {
if (!description && !priceLabel) return null;
return (
<Flexbox gap={8} style={{ maxWidth: POPOVER_MAX_WIDTH }}>
{description && <Text className={styles.descriptionText}>{description}</Text>}
{priceLabel && <Text className={styles.priceText}>{priceLabel}</Text>}
</Flexbox>
);
}, [description, priceLabel, styles.descriptionText, styles.priceText]);
const content = (
<Flexbox align={'center'} gap={8} horizontal style={{ overflow: 'hidden' }}>
<ModelIcon model={model.id} size={20} />
<Text ellipsis title={model.displayName || model.id}>
{model.displayName || model.id}
</Text>
</Flexbox>
);
if (!showPopover || !popoverContent) return content;
return (
<Popover
align={{
offset: [24, -10],
}}
arrow={false}
classNames={{ root: styles.popover }}
content={popoverContent}
placement="rightTop"
>
{content}
</Popover>
);
},
);
ImageModelItem.displayName = 'ImageModelItem';
export default ImageModelItem;
@@ -7,7 +7,7 @@ import { memo, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
import { ModelItemRender, ProviderItemRender } from '@/components/ModelSelect';
import { ProviderItemRender } from '@/components/ModelSelect';
import { isDeprecatedEdition } from '@/const/version';
import { useAiInfraStore } from '@/store/aiInfra';
import { aiProviderSelectors } from '@/store/aiInfra/slices/aiProvider/selectors';
@@ -16,6 +16,8 @@ import { imageGenerationConfigSelectors } from '@/store/image/slices/generationC
import { featureFlagsSelectors, useServerConfigStore } from '@/store/serverConfig';
import { EnabledProviderWithModels } from '@/types/aiProvider';
import ImageModelItem from './ImageModelItem';
const useStyles = createStyles(({ css, prefixCls }) => ({
popup: css`
&.${prefixCls}-select-dropdown .${prefixCls}-select-item-option-grouped {
@@ -48,7 +50,7 @@ const ModelSelect = memo(() => {
const options = useMemo<SelectProps['options']>(() => {
const getImageModels = (provider: EnabledProviderWithModels) => {
const modelOptions = provider.children.map((model) => ({
label: <ModelItemRender {...model} {...model.abilities} showInfoTag={false} />,
label: <ImageModelItem {...model} />,
provider: provider.id,
value: `${provider.id}/${model.id}`,
}));
@@ -133,11 +135,24 @@ const ModelSelect = memo(() => {
}));
}, [enabledImageModelList, showLLM, t, theme.colorTextTertiary, router]);
const labelRender: SelectProps['labelRender'] = (props) => {
const modelInfo = enabledImageModelList
.flatMap((provider) =>
provider.children.map((model) => ({ ...model, providerId: provider.id })),
)
.find((model) => props.value === `${model.providerId}/${model.id}`);
if (!modelInfo) return props.label;
return <ImageModelItem {...modelInfo} showPopover={false} />;
};
return (
<Select
classNames={{
root: styles.popup,
}}
labelRender={labelRender}
onChange={(value, option) => {
// Skip onChange for disabled options (empty states)
if (value === 'no-provider' || value.includes('/empty')) return;
@@ -1,276 +1,172 @@
import * as runtimeModule from '@lobechat/model-runtime';
import type { EnabledAiModel, ModelAbilities } from 'model-bank';
import { describe, expect, it, vi } from 'vitest';
import type { AIImageModelCard, EnabledAiModel, ModelParamsSchema } from 'model-bank';
import { afterEach, describe, expect, it, vi } from 'vitest';
import { getModelListByType } from '../action';
import {
getChatModelList,
getImageModelList,
normalizeChatModel,
normalizeImageModel,
} from '../action';
// Test fixtures
const createChatModel = (
id: string,
providerId: string,
overrides: Partial<EnabledAiModel> = {},
): EnabledAiModel => ({
id,
providerId,
const createChatModel = (overrides: Partial<EnabledAiModel> = {}): EnabledAiModel => ({
abilities: overrides.abilities ?? { functionCall: true },
contextWindowTokens: overrides.contextWindowTokens ?? 8192,
displayName: overrides.displayName ?? 'Chat Model',
enabled: overrides.enabled ?? true,
id: overrides.id ?? 'chat-model',
providerId: overrides.providerId ?? 'openai',
type: 'chat',
abilities: { functionCall: true, files: true } satisfies ModelAbilities,
contextWindowTokens: 8192,
displayName: `${id} model`,
enabled: true,
...overrides,
});
const createImageModel = (
id: string,
providerId: string,
overrides: Partial<EnabledAiModel> = {},
): EnabledAiModel => ({
id,
providerId,
type ImageEnabledModel = EnabledAiModel & AIImageModelCard;
const createImageModel = (overrides: Partial<ImageEnabledModel> = {}): ImageEnabledModel => ({
abilities: overrides.abilities ?? {},
contextWindowTokens: overrides.contextWindowTokens,
displayName: overrides.displayName ?? 'Image Model',
enabled: overrides.enabled ?? true,
id: overrides.id ?? 'image-model',
providerId: overrides.providerId ?? 'openai',
type: 'image',
abilities: {} satisfies ModelAbilities,
displayName: `${id} model`,
enabled: true,
...overrides,
});
// Core test data
const mockChatModels = [
createChatModel('gpt-4', 'openai', {
displayName: 'GPT-4',
abilities: { functionCall: true, files: true } satisfies ModelAbilities,
}),
createChatModel('gpt-3.5-turbo', 'openai', {
displayName: 'GPT-3.5 Turbo',
abilities: { functionCall: true } satisfies ModelAbilities,
contextWindowTokens: 4096,
}),
createChatModel('claude-3-opus', 'anthropic', {
displayName: 'Claude 3 Opus',
abilities: { functionCall: false, files: true } satisfies ModelAbilities,
contextWindowTokens: 200000,
}),
];
const mockImageModels = [
createImageModel('dall-e-3', 'openai', {
displayName: 'DALL-E 3',
parameters: {
prompt: { default: '' },
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
},
}),
createImageModel('midjourney', 'midjourney', {
displayName: 'Midjourney',
}),
];
const allModels = [...mockChatModels, ...mockImageModels];
describe('getModelListByType', () => {
describe('Core Functionality', () => {
it('should filter models by providerId and type correctly', async () => {
const result = await getModelListByType(allModels, 'openai', 'chat');
expect(result).toHaveLength(2);
expect(result.map((m) => m.id)).toEqual(['gpt-4', 'gpt-3.5-turbo']);
});
it('should return correct model structure for chat models', async () => {
const result = await getModelListByType(allModels, 'openai', 'chat');
expect(result[0]).toEqual({
abilities: { functionCall: true, files: true },
contextWindowTokens: 8192,
displayName: 'GPT-4',
id: 'gpt-4',
});
});
it('should include parameters field for image models', async () => {
const result = await getModelListByType(allModels, 'openai', 'image');
expect(result[0]).toEqual({
abilities: {},
contextWindowTokens: undefined,
displayName: 'DALL-E 3',
id: 'dall-e-3',
parameters: {
prompt: { default: '' },
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
},
});
});
it('should exclude parameters field from chat models', async () => {
const result = await getModelListByType(mockChatModels, 'openai', 'chat');
result.forEach((model) => {
expect(model).not.toHaveProperty('parameters');
});
});
it('should remove duplicate model IDs', async () => {
const duplicateModels = [
createChatModel('gpt-4', 'openai', {
displayName: 'GPT-4 Version 1',
abilities: { functionCall: true } satisfies ModelAbilities,
}),
createChatModel('gpt-4', 'openai', {
displayName: 'GPT-4 Version 2',
abilities: { functionCall: false } satisfies ModelAbilities,
}),
];
const result = await getModelListByType(duplicateModels, 'openai', 'chat');
expect(result).toHaveLength(1);
expect(result[0].displayName).toBe('GPT-4 Version 1');
});
describe('aiProvider action helpers', () => {
afterEach(() => {
vi.restoreAllMocks();
});
describe('Edge Cases and Error Handling', () => {
it('should handle empty inputs gracefully', async () => {
const emptyResult = await getModelListByType([], 'openai', 'chat');
expect(emptyResult).toEqual([]);
const noMatchingProvider = await getModelListByType(allModels, 'nonexistent', 'chat');
expect(noMatchingProvider).toEqual([]);
const noMatchingType = await getModelListByType(allModels, 'openai', 'nonexistent');
expect(noMatchingType).toEqual([]);
});
it('should handle missing optional properties', async () => {
const modelWithMissingProps = createChatModel('test-model', 'test', {
displayName: undefined,
describe('normalizeChatModel', () => {
it('fills missing optional fields with safe defaults', () => {
const model = createChatModel({
abilities: undefined,
contextWindowTokens: undefined,
displayName: undefined,
});
const result = await getModelListByType([modelWithMissingProps], 'test', 'chat');
const result = normalizeChatModel(model);
expect(result[0].displayName).toBe('');
expect(result[0].abilities).toEqual({});
expect(result[0].contextWindowTokens).toBeUndefined();
});
it('should preserve complex model properties', async () => {
const complexModel = createChatModel('complex-model', 'test', {
displayName: 'Complex Model with All Properties',
abilities: {
functionCall: true,
files: true,
vision: false,
} satisfies ModelAbilities,
contextWindowTokens: 128000,
});
const result = await getModelListByType([complexModel], 'test', 'chat');
expect(result[0]).toEqual({
id: 'complex-model',
displayName: 'Complex Model with All Properties',
abilities: {
functionCall: true,
files: true,
vision: false,
},
contextWindowTokens: 128000,
});
});
});
describe('Image Model Parameter Handling', () => {
it('should use fallback parameters for image models without parameters', async () => {
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValueOnce({
size: '1024x1024',
});
const result = await getModelListByType(allModels, 'midjourney', 'image');
expect(result[0]).toEqual({
expect(result).toEqual({
abilities: {},
contextWindowTokens: undefined,
displayName: 'Midjourney',
id: 'midjourney',
parameters: { size: '1024x1024' },
displayName: '',
id: 'chat-model',
});
});
it('should handle async parameter fetching for multiple models', async () => {
const imageModelsWithoutParams = [
createImageModel('stable-diffusion', 'stability', { displayName: 'Stable Diffusion' }),
createImageModel('flux-schnell', 'fal', { displayName: 'FLUX Schnell' }),
];
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValue({
prompt: { default: '' },
width: { default: 512, min: 256, max: 2048 },
height: { default: 512, min: 256, max: 2048 },
});
const result = await getModelListByType(imageModelsWithoutParams, 'stability', 'image');
expect(result).toHaveLength(1);
expect(result[0].parameters).toEqual({
prompt: { default: '' },
width: { default: 512, min: 256, max: 2048 },
height: { default: 512, min: 256, max: 2048 },
});
expect(runtimeModule.getModelPropertyWithFallback).toHaveBeenCalledWith(
'stable-diffusion',
'parameters',
);
});
it('should handle failed parameter fallback gracefully', async () => {
const failingModel = createImageModel('failing-model', 'test-provider', {
displayName: 'Failing Model',
});
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValueOnce(undefined);
const result = await getModelListByType([failingModel], 'test-provider', 'image');
expect(result).toHaveLength(1);
expect(result[0].id).toBe('failing-model');
expect(result[0].parameters).toBeUndefined();
});
});
describe('Concurrent Processing', () => {
it('should handle large-scale concurrent model processing', async () => {
const manyModels = Array.from({ length: 10 }, (_, i) =>
createChatModel(`model-${i}`, 'test-provider', {
displayName: `Model ${i}`,
abilities: { functionCall: i % 2 === 0 } satisfies ModelAbilities,
contextWindowTokens: 4096 + i * 1000,
}),
);
describe('normalizeImageModel', () => {
it('preserves inline metadata and pricing information', async () => {
const model = createImageModel({
abilities: { vision: true },
contextWindowTokens: 4096,
displayName: 'Inline Model',
parameters: {
prompt: { default: '' },
size: { default: '1024x1024', enum: ['512x512', '1024x1024'] },
} as ModelParamsSchema,
pricing: {
units: [{ name: 'imageGeneration', rate: 0.04, strategy: 'fixed', unit: 'image' }],
},
});
const result = await getModelListByType(manyModels, 'test-provider', 'chat');
const result = await normalizeImageModel(model);
expect(result).toHaveLength(10);
expect(result.map((m) => m.id)).toEqual(manyModels.map((m) => m.id));
result.forEach((model, index) => {
expect(model.abilities.functionCall).toBe(index % 2 === 0);
expect(model.contextWindowTokens).toBe(4096 + index * 1000);
expect(result).toMatchObject({
abilities: { vision: true },
displayName: 'Inline Model',
parameters: { size: { default: '1024x1024', enum: ['512x512', '1024x1024'] } },
pricing: {
units: [{ name: 'imageGeneration', rate: 0.04, strategy: 'fixed', unit: 'image' }],
},
});
});
it('should maintain model order during concurrent processing', async () => {
const orderedModels = [
createChatModel('first-model', 'test', { displayName: 'First Model' }),
createChatModel('second-model', 'test', { displayName: 'Second Model' }),
createChatModel('third-model', 'test', { displayName: 'Third Model' }),
it('fetches fallback description/parameters/pricing when missing', async () => {
const fallbackSpy = vi
.spyOn(runtimeModule, 'getModelPropertyWithFallback')
.mockImplementation(async (_id, key) => {
if (key === 'parameters')
return {
prompt: { default: '' },
size: { default: '768x768', enum: ['512x512', '768x768'] },
} satisfies ModelParamsSchema;
if (key === 'pricing')
return {
units: [{ name: 'imageGeneration', rate: 0.02, strategy: 'fixed', unit: 'image' }],
};
if (key === 'description') return 'Fallback description';
return undefined;
});
const model = createImageModel({
id: 'stable-diffusion',
providerId: 'stability',
parameters: undefined,
pricing: undefined,
});
const result = await normalizeImageModel(model);
expect(result.parameters).toEqual({
prompt: { default: '' },
size: { default: '768x768', enum: ['512x512', '768x768'] },
});
expect(result.pricing).toEqual({
units: [{ name: 'imageGeneration', rate: 0.02, strategy: 'fixed', unit: 'image' }],
});
expect(result.description).toBe('Fallback description');
expect(fallbackSpy).toHaveBeenCalledWith('stable-diffusion', 'parameters', 'stability');
expect(fallbackSpy).toHaveBeenCalledWith('stable-diffusion', 'pricing', 'stability');
expect(fallbackSpy).toHaveBeenCalledWith('stable-diffusion', 'description', 'stability');
});
});
describe('getChatModelList', () => {
const chatModels = [
createChatModel({ id: 'gpt-4', providerId: 'openai', displayName: 'GPT-4' }),
createChatModel({ id: 'gpt-3.5', providerId: 'openai', displayName: 'GPT-3.5' }),
createChatModel({ id: 'claude-3', providerId: 'anthropic', displayName: 'Claude 3' }),
];
it('filters by provider and deduplicates IDs', async () => {
const duplicated = [
...chatModels,
createChatModel({ id: 'gpt-4', providerId: 'openai', displayName: 'GPT-4 Duplicate' }),
];
const result = await getModelListByType(orderedModels, 'test', 'chat');
const result = await getChatModelList(duplicated, 'openai');
expect(result.map((m) => m.id)).toEqual(['first-model', 'second-model', 'third-model']);
expect(result).toHaveLength(2);
expect(result.map((m) => m.id)).toEqual(['gpt-4', 'gpt-3.5']);
expect(result[0].displayName).toBe('GPT-4');
});
it('returns empty array when provider has no chat models', async () => {
const result = await getChatModelList(chatModels, 'nonexistent');
expect(result).toEqual([]);
});
});
describe('getImageModelList', () => {
const imageModels = [
createImageModel({ id: 'dall-e-3', providerId: 'openai', displayName: 'DALL-E 3' }),
createImageModel({ id: 'midjourney', providerId: 'midjourney', displayName: 'Midjourney' }),
];
it('collects normalized image models for a provider', async () => {
const result = await getImageModelList(imageModels, 'openai');
expect(result).toHaveLength(1);
expect(result[0].id).toBe('dall-e-3');
expect(result[0].displayName).toBe('DALL-E 3');
});
it('returns empty array when provider has no image models', async () => {
const result = await getImageModelList(imageModels, 'unknown');
expect(result).toEqual([]);
});
});
});
+117 -37
View File
@@ -1,11 +1,13 @@
import { isDeprecatedEdition, isDesktop, isUsePgliteDB } from '@lobechat/const';
import { getModelPropertyWithFallback } from '@lobechat/model-runtime';
import { getModelPropertyWithFallback, resolveImageSinglePrice } from '@lobechat/model-runtime';
import { uniqBy } from 'lodash-es';
import {
AIImageModelCard,
EnabledAiModel,
LobeDefaultAiModelListItem,
ModelAbilities,
ModelParamsSchema,
Pricing,
} from 'model-bank';
import { SWRResponse, mutate } from 'swr';
import { StateCreator } from 'zustand/vanilla';
@@ -28,52 +30,130 @@ import {
UpdateAiProviderParams,
} from '@/types/aiProvider';
/**
* Get models by provider ID and type, with proper formatting and deduplication
*/
export const getModelListByType = async (
enabledAiModels: EnabledAiModel[],
providerId: string,
type: string,
) => {
const filteredModels = enabledAiModels.filter(
(model) => model.providerId === providerId && model.type === type,
);
const models = await Promise.all(
filteredModels.map(async (model) => ({
abilities: (model.abilities || {}) as ModelAbilities,
contextWindowTokens: model.contextWindowTokens,
displayName: model.displayName ?? '',
id: model.id,
...(model.type === 'image' && {
parameters:
(model as AIImageModelCard).parameters ||
(await getModelPropertyWithFallback(model.id, 'parameters')),
}),
})),
);
return uniqBy(models, 'id');
export type ProviderModelListItem = {
abilities: ModelAbilities;
approximatePricePerImage?: number;
contextWindowTokens?: number;
description?: string;
displayName: string;
id: string;
parameters?: ModelParamsSchema;
pricePerImage?: number;
pricing?: Pricing;
};
/**
* Build provider model lists with proper async handling
*/
type ModelNormalizer = (model: EnabledAiModel) => Promise<ProviderModelListItem>;
const dedupeById = (models: ProviderModelListItem[]) => uniqBy(models, 'id');
const createProviderModelCollector = (
type: EnabledAiModel['type'],
normalizer: ModelNormalizer,
) => {
return async (enabledAiModels: EnabledAiModel[], providerId: string) => {
const filteredModels = enabledAiModels.filter(
(model) => model.providerId === providerId && model.type === type,
);
if (!filteredModels.length) return [];
const normalized = await Promise.all(filteredModels.map((model) => normalizer(model)));
return dedupeById(normalized);
};
};
export const normalizeChatModel = (model: EnabledAiModel): ProviderModelListItem => ({
abilities: (model.abilities || {}) as ModelAbilities,
contextWindowTokens: model.contextWindowTokens,
displayName: model.displayName ?? '',
id: model.id,
});
export const normalizeImageModel = async (
model: EnabledAiModel,
): Promise<ProviderModelListItem> => {
const fallbackParametersPromise = model.parameters
? Promise.resolve<ModelParamsSchema | undefined>(model.parameters)
: getModelPropertyWithFallback<ModelParamsSchema | undefined>(
model.id,
'parameters',
model.providerId,
);
const modelWithPricing = model as AIImageModelCard;
const fallbackPricingPromise = modelWithPricing.pricing
? Promise.resolve<Pricing | undefined>(modelWithPricing.pricing)
: getModelPropertyWithFallback<Pricing | undefined>(model.id, 'pricing', model.providerId);
const fallbackDescriptionPromise = getModelPropertyWithFallback<string | undefined>(
model.id,
'description',
model.providerId,
);
const [fallbackParameters, fallbackPricing, fallbackDescription] = await Promise.all([
fallbackParametersPromise,
fallbackPricingPromise,
fallbackDescriptionPromise,
]);
const parameters = model.parameters ?? fallbackParameters;
const pricing = fallbackPricing;
const description = fallbackDescription;
const { price, approximatePrice } = resolveImageSinglePrice(pricing);
return {
abilities: (model.abilities || {}) as ModelAbilities,
contextWindowTokens: model.contextWindowTokens,
displayName: model.displayName ?? '',
id: model.id,
...(parameters && { parameters }),
...(description && { description }),
...(pricing && { pricing }),
...(typeof approximatePrice === 'number' && { approximatePricePerImage: approximatePrice }),
...(typeof price === 'number' && { pricePerImage: price }),
};
};
export const getChatModelList = createProviderModelCollector('chat', async (model) =>
normalizeChatModel(model),
);
export const getImageModelList = createProviderModelCollector('image', normalizeImageModel);
const buildProviderModelLists = async (
providers: EnabledProvider[],
enabledAiModels: EnabledAiModel[],
type: 'chat' | 'image',
collector: (
enabledAiModels: EnabledAiModel[],
providerId: string,
) => Promise<ProviderModelListItem[]>,
) => {
return Promise.all(
providers.map(async (provider) => ({
...provider,
children: await getModelListByType(enabledAiModels, provider.id, type),
children: await collector(enabledAiModels, provider.id),
name: provider.name || provider.id,
})),
);
};
/**
* Build image provider model lists with proper async handling
*/
const buildImageProviderModelLists = async (
providers: EnabledProvider[],
enabledAiModels: EnabledAiModel[],
) => buildProviderModelLists(providers, enabledAiModels, getImageModelList);
/**
* Build chat provider model lists with proper async handling
*/
const buildChatProviderModelLists = async (
providers: EnabledProvider[],
enabledAiModels: EnabledAiModel[],
) => buildProviderModelLists(providers, enabledAiModels, getChatModelList);
enum AiProviderSwrKey {
fetchAiProviderItem = 'FETCH_AI_PROVIDER_ITEM',
fetchAiProviderList = 'FETCH_AI_PROVIDER',
@@ -252,8 +332,8 @@ export const createAiProviderSlice: StateCreator<
// Build model lists with proper async handling
const [enabledChatModelList, enabledImageModelList] = await Promise.all([
buildProviderModelLists(data.enabledChatAiProviders, data.enabledAiModels, 'chat'),
buildProviderModelLists(data.enabledImageAiProviders, data.enabledAiModels, 'image'),
buildChatProviderModelLists(data.enabledChatAiProviders, data.enabledAiModels),
buildImageProviderModelLists(data.enabledImageAiProviders, data.enabledAiModels),
]);
return {
@@ -285,8 +365,8 @@ export const createAiProviderSlice: StateCreator<
// Build model lists for non-login state as well
const enabledAiModels = builtinAiModelList.filter((m) => m.enabled);
const [enabledChatModelList, enabledImageModelList] = await Promise.all([
buildProviderModelLists(enabledChatAiProviders, enabledAiModels, 'chat'),
buildProviderModelLists(enabledImageAiProviders, enabledAiModels, 'image'),
buildChatProviderModelLists(enabledChatAiProviders, enabledAiModels),
buildImageProviderModelLists(enabledImageAiProviders, enabledAiModels),
]);
return {