mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-18 05:18:31 +00:00
✨ feat(image): image model show price (#10198)
This commit is contained in:
@@ -1056,6 +1056,7 @@ const aihubmixModels: AIChatModelCard[] = [
|
||||
id: 'gemini-2.5-flash-image',
|
||||
maxOutput: 8192,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.039,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
|
||||
@@ -370,6 +370,7 @@ const googleChatModels: AIChatModelCard[] = [
|
||||
id: 'gemini-2.5-flash-image',
|
||||
maxOutput: 8192,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.039,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
@@ -392,6 +393,7 @@ const googleChatModels: AIChatModelCard[] = [
|
||||
id: 'gemini-2.5-flash-image-preview',
|
||||
maxOutput: 8192,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.039,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
@@ -864,6 +866,7 @@ const googleImageModels: AIImageModelCard[] = [
|
||||
releasedAt: '2025-08-26',
|
||||
parameters: nanoBananaParameters,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.039,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'textOutput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
|
||||
@@ -880,6 +883,7 @@ const googleImageModels: AIImageModelCard[] = [
|
||||
releasedAt: '2025-08-26',
|
||||
parameters: CHAT_MODEL_IMAGE_GENERATION_PARAMS,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.039,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'textOutput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
|
||||
@@ -892,7 +896,7 @@ const googleImageModels: AIImageModelCard[] = [
|
||||
id: 'imagen-4.0-generate-001',
|
||||
enabled: true,
|
||||
type: 'image',
|
||||
description: 'Imagen 4th generation text-to-image model series',
|
||||
description: 'Imagen 第四代文生图模型系列',
|
||||
organization: 'Deepmind',
|
||||
releasedAt: '2025-08-15',
|
||||
parameters: imagenGenParameters,
|
||||
@@ -905,7 +909,7 @@ const googleImageModels: AIImageModelCard[] = [
|
||||
id: 'imagen-4.0-ultra-generate-001',
|
||||
enabled: true,
|
||||
type: 'image',
|
||||
description: 'Imagen 4th generation text-to-image model series Ultra version',
|
||||
description: 'Imagen 第四代文生图模型系列的 Ultra 版本',
|
||||
organization: 'Deepmind',
|
||||
releasedAt: '2025-08-15',
|
||||
parameters: imagenGenParameters,
|
||||
@@ -918,7 +922,7 @@ const googleImageModels: AIImageModelCard[] = [
|
||||
id: 'imagen-4.0-fast-generate-001',
|
||||
enabled: true,
|
||||
type: 'image',
|
||||
description: 'Imagen 4th generation text-to-image model series Fast version',
|
||||
description: 'Imagen 第四代文生图模型系列的快速版本',
|
||||
organization: 'Deepmind',
|
||||
releasedAt: '2025-08-15',
|
||||
parameters: imagenGenParameters,
|
||||
@@ -930,7 +934,7 @@ const googleImageModels: AIImageModelCard[] = [
|
||||
displayName: 'Imagen 4 Preview 06-06',
|
||||
id: 'imagen-4.0-generate-preview-06-06',
|
||||
type: 'image',
|
||||
description: 'Imagen 4th generation text-to-image model series',
|
||||
description: 'Imagen 第四代文生图模型系列',
|
||||
organization: 'Deepmind',
|
||||
releasedAt: '2025-06-06',
|
||||
parameters: imagenGenParameters,
|
||||
@@ -942,7 +946,7 @@ const googleImageModels: AIImageModelCard[] = [
|
||||
displayName: 'Imagen 4 Ultra Preview 06-06',
|
||||
id: 'imagen-4.0-ultra-generate-preview-06-06',
|
||||
type: 'image',
|
||||
description: 'Imagen 4th generation text-to-image model series Ultra version',
|
||||
description: 'Imagen 第四代文生图模型系列的 Ultra 版本',
|
||||
organization: 'Deepmind',
|
||||
releasedAt: '2025-06-11',
|
||||
parameters: imagenGenParameters,
|
||||
|
||||
@@ -1171,31 +1171,13 @@ export const openaiImageModels: AIImageModelCard[] = [
|
||||
id: 'gpt-image-1',
|
||||
parameters: gptImage1ParamsSchema,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.042,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 5, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'textInput_cacheRead', rate: 1.25, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageInput', rate: 10, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageInput_cacheRead', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageOutput', rate: 40, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{
|
||||
lookup: {
|
||||
prices: {
|
||||
high_1024x1024: 0.167,
|
||||
high_1024x1536: 0.25,
|
||||
high_1536x1024: 0.25,
|
||||
low_1024x1024: 0.011,
|
||||
low_1024x1536: 0.016,
|
||||
low_1536x1024: 0.016,
|
||||
medium_1024x1024: 0.042,
|
||||
medium_1024x1536: 0.063,
|
||||
medium_1536x1024: 0.063,
|
||||
},
|
||||
pricingParams: ['quality', 'size'],
|
||||
},
|
||||
name: 'imageGeneration',
|
||||
strategy: 'lookup',
|
||||
unit: 'image',
|
||||
},
|
||||
],
|
||||
},
|
||||
resolutions: ['1024x1024', '1024x1536', '1536x1024'],
|
||||
@@ -1208,28 +1190,13 @@ export const openaiImageModels: AIImageModelCard[] = [
|
||||
id: 'gpt-image-1-mini',
|
||||
parameters: gptImage1ParamsSchema,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.011,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 2, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'textInput_cacheRead', rate: 0.2, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageInput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageInput_cacheRead', rate: 0.25, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'imageOutput', rate: 8, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{
|
||||
lookup: {
|
||||
prices: {
|
||||
low_1024x1024: 0.005,
|
||||
low_1024x1536: 0.006,
|
||||
low_1536x1024: 0.006,
|
||||
medium_1024x1024: 0.011,
|
||||
medium_1024x1536: 0.015,
|
||||
medium_1536x1024: 0.015,
|
||||
},
|
||||
pricingParams: ['quality', 'size'],
|
||||
},
|
||||
name: 'imageGeneration',
|
||||
strategy: 'lookup',
|
||||
unit: 'image',
|
||||
},
|
||||
],
|
||||
},
|
||||
releasedAt: '2025-10-06',
|
||||
|
||||
@@ -41,6 +41,7 @@ const openrouterChatModels: AIChatModelCard[] = [
|
||||
id: 'google/gemini-2.5-flash-image-preview',
|
||||
maxOutput: 8192,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.039,
|
||||
units: [
|
||||
{ name: 'imageOutput', rate: 30, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
|
||||
@@ -135,6 +135,7 @@ const vertexaiChatModels: AIChatModelCard[] = [
|
||||
id: 'gemini-2.5-flash-image-preview',
|
||||
maxOutput: 8192,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.039,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'textOutput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
|
||||
@@ -291,6 +292,7 @@ const vertexaiImageModels: AIImageModelCard[] = [
|
||||
releasedAt: '2025-08-26',
|
||||
parameters: nanoBananaParameters,
|
||||
pricing: {
|
||||
approximatePricePerImage: 0.039,
|
||||
units: [
|
||||
{ name: 'textInput', rate: 0.3, strategy: 'fixed', unit: 'millionTokens' },
|
||||
{ name: 'textOutput', rate: 2.5, strategy: 'fixed', unit: 'millionTokens' },
|
||||
|
||||
@@ -185,6 +185,10 @@ export interface LookupPricingUnit extends PricingUnitBase {
|
||||
export type PricingUnit = FixedPricingUnit | TieredPricingUnit | LookupPricingUnit;
|
||||
|
||||
export interface Pricing {
|
||||
/**
|
||||
* Fallback approximate per-image price (USD) when detailed pricing table is unavailable
|
||||
*/
|
||||
approximatePricePerImage?: number;
|
||||
currency?: ModelPriceCurrency;
|
||||
units: PricingUnit[];
|
||||
}
|
||||
@@ -391,13 +395,22 @@ export const ToggleAiModelEnableSchema = z.object({
|
||||
|
||||
export type ToggleAiModelEnableParams = z.infer<typeof ToggleAiModelEnableSchema>;
|
||||
|
||||
//
|
||||
|
||||
export interface AiModelForSelect {
|
||||
abilities: ModelAbilities;
|
||||
/**
|
||||
* Approximate per-image price (USD), used when exact calculation is not possible
|
||||
*/
|
||||
approximatePricePerImage?: number;
|
||||
contextWindowTokens?: number;
|
||||
description?: string;
|
||||
displayName?: string;
|
||||
id: string;
|
||||
parameters?: ModelParamsSchema;
|
||||
/**
|
||||
* Exact per-image price (USD) calculated from pricing units
|
||||
*/
|
||||
pricePerImage?: number;
|
||||
pricing?: Pricing;
|
||||
}
|
||||
|
||||
export interface EnabledAiModel {
|
||||
|
||||
@@ -2,3 +2,4 @@ export { convertAnthropicUsage } from './anthropic';
|
||||
export { convertGoogleAIUsage } from './google-ai';
|
||||
export { convertOpenAIResponseUsage, convertOpenAIUsage } from './openai';
|
||||
export { computeImageCost } from './utils/computeImageCost';
|
||||
export { resolveImageSinglePrice } from './utils/resolveImageSinglePrice';
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
import { Pricing } from 'model-bank';
|
||||
|
||||
export interface ImageSinglePriceResult {
|
||||
approximatePrice?: number;
|
||||
price?: number;
|
||||
}
|
||||
|
||||
const DEFAULT_REFERENCE_MP = (1024 * 1024) / 1_000_000;
|
||||
|
||||
export const resolveImageSinglePrice = (pricing?: Pricing): ImageSinglePriceResult => {
|
||||
if (!pricing) return {};
|
||||
|
||||
// Priority 1: Use approximate price if explicitly provided
|
||||
if (typeof pricing.approximatePricePerImage === 'number') {
|
||||
return { approximatePrice: pricing.approximatePricePerImage };
|
||||
}
|
||||
|
||||
// Priority 2: Calculate exact price from pricing units
|
||||
const imageGenerationUnit = pricing.units.find((unit) => unit.name === 'imageGeneration');
|
||||
if (!imageGenerationUnit) return {};
|
||||
|
||||
if (imageGenerationUnit.strategy === 'fixed') {
|
||||
if (imageGenerationUnit.unit === 'image') {
|
||||
return { price: imageGenerationUnit.rate };
|
||||
}
|
||||
|
||||
if (imageGenerationUnit.unit === 'megapixel') {
|
||||
return { price: imageGenerationUnit.rate * DEFAULT_REFERENCE_MP };
|
||||
}
|
||||
}
|
||||
|
||||
// Lookup/tiered pricing typically requires explicit configuration; treat as unavailable here.
|
||||
return {};
|
||||
};
|
||||
@@ -1,4 +1,4 @@
|
||||
import { ModelParamsSchema , AiModelType, Pricing } from 'model-bank';
|
||||
import { AiModelType, ModelParamsSchema, Pricing } from 'model-bank';
|
||||
import { ReactNode } from 'react';
|
||||
|
||||
import { AiProviderSettings } from './aiProvider';
|
||||
|
||||
+93
@@ -0,0 +1,93 @@
|
||||
import { ModelIcon } from '@lobehub/icons';
|
||||
import { Text } from '@lobehub/ui';
|
||||
import { Popover } from 'antd';
|
||||
import { createStyles } from 'antd-style';
|
||||
import { AiModelForSelect } from 'model-bank';
|
||||
import numeral from 'numeral';
|
||||
import { memo, useMemo } from 'react';
|
||||
import { Flexbox } from 'react-layout-kit';
|
||||
|
||||
const POPOVER_MAX_WIDTH = 320;
|
||||
|
||||
const useStyles = createStyles(({ css, token, isDarkMode }) => ({
|
||||
descriptionText: css`
|
||||
color: ${isDarkMode ? token.colorText : token.colorTextSecondary};
|
||||
`,
|
||||
popover: css`
|
||||
.ant-popover-inner {
|
||||
background: ${isDarkMode ? token.colorBgSpotlight : token.colorBgElevated};
|
||||
}
|
||||
`,
|
||||
priceText: css`
|
||||
font-weight: 500;
|
||||
color: ${isDarkMode ? token.colorTextLightSolid : token.colorTextTertiary};
|
||||
`,
|
||||
}));
|
||||
|
||||
type ImageModelItemProps = AiModelForSelect & {
|
||||
/**
|
||||
* Whether to show popover on hover
|
||||
* @default true
|
||||
*/
|
||||
showPopover?: boolean;
|
||||
};
|
||||
|
||||
const ImageModelItem = memo<ImageModelItemProps>(
|
||||
({ approximatePricePerImage, description, pricePerImage, showPopover = true, ...model }) => {
|
||||
const { styles } = useStyles();
|
||||
|
||||
const priceLabel = useMemo(() => {
|
||||
// Priority 1: Use exact price
|
||||
if (typeof pricePerImage === 'number') {
|
||||
return `${numeral(pricePerImage).format('$0,0.00[000]')} / image`;
|
||||
}
|
||||
|
||||
// Priority 2: Use approximate price with prefix
|
||||
if (typeof approximatePricePerImage === 'number') {
|
||||
return `~ ${numeral(approximatePricePerImage).format('$0,0.00[000]')} / image`;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}, [approximatePricePerImage, pricePerImage]);
|
||||
|
||||
const popoverContent = useMemo(() => {
|
||||
if (!description && !priceLabel) return null;
|
||||
|
||||
return (
|
||||
<Flexbox gap={8} style={{ maxWidth: POPOVER_MAX_WIDTH }}>
|
||||
{description && <Text className={styles.descriptionText}>{description}</Text>}
|
||||
{priceLabel && <Text className={styles.priceText}>{priceLabel}</Text>}
|
||||
</Flexbox>
|
||||
);
|
||||
}, [description, priceLabel, styles.descriptionText, styles.priceText]);
|
||||
|
||||
const content = (
|
||||
<Flexbox align={'center'} gap={8} horizontal style={{ overflow: 'hidden' }}>
|
||||
<ModelIcon model={model.id} size={20} />
|
||||
<Text ellipsis title={model.displayName || model.id}>
|
||||
{model.displayName || model.id}
|
||||
</Text>
|
||||
</Flexbox>
|
||||
);
|
||||
|
||||
if (!showPopover || !popoverContent) return content;
|
||||
|
||||
return (
|
||||
<Popover
|
||||
align={{
|
||||
offset: [24, -10],
|
||||
}}
|
||||
arrow={false}
|
||||
classNames={{ root: styles.popover }}
|
||||
content={popoverContent}
|
||||
placement="rightTop"
|
||||
>
|
||||
{content}
|
||||
</Popover>
|
||||
);
|
||||
},
|
||||
);
|
||||
|
||||
ImageModelItem.displayName = 'ImageModelItem';
|
||||
|
||||
export default ImageModelItem;
|
||||
+17
-2
@@ -7,7 +7,7 @@ import { memo, useMemo } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Flexbox } from 'react-layout-kit';
|
||||
|
||||
import { ModelItemRender, ProviderItemRender } from '@/components/ModelSelect';
|
||||
import { ProviderItemRender } from '@/components/ModelSelect';
|
||||
import { isDeprecatedEdition } from '@/const/version';
|
||||
import { useAiInfraStore } from '@/store/aiInfra';
|
||||
import { aiProviderSelectors } from '@/store/aiInfra/slices/aiProvider/selectors';
|
||||
@@ -16,6 +16,8 @@ import { imageGenerationConfigSelectors } from '@/store/image/slices/generationC
|
||||
import { featureFlagsSelectors, useServerConfigStore } from '@/store/serverConfig';
|
||||
import { EnabledProviderWithModels } from '@/types/aiProvider';
|
||||
|
||||
import ImageModelItem from './ImageModelItem';
|
||||
|
||||
const useStyles = createStyles(({ css, prefixCls }) => ({
|
||||
popup: css`
|
||||
&.${prefixCls}-select-dropdown .${prefixCls}-select-item-option-grouped {
|
||||
@@ -48,7 +50,7 @@ const ModelSelect = memo(() => {
|
||||
const options = useMemo<SelectProps['options']>(() => {
|
||||
const getImageModels = (provider: EnabledProviderWithModels) => {
|
||||
const modelOptions = provider.children.map((model) => ({
|
||||
label: <ModelItemRender {...model} {...model.abilities} showInfoTag={false} />,
|
||||
label: <ImageModelItem {...model} />,
|
||||
provider: provider.id,
|
||||
value: `${provider.id}/${model.id}`,
|
||||
}));
|
||||
@@ -133,11 +135,24 @@ const ModelSelect = memo(() => {
|
||||
}));
|
||||
}, [enabledImageModelList, showLLM, t, theme.colorTextTertiary, router]);
|
||||
|
||||
const labelRender: SelectProps['labelRender'] = (props) => {
|
||||
const modelInfo = enabledImageModelList
|
||||
.flatMap((provider) =>
|
||||
provider.children.map((model) => ({ ...model, providerId: provider.id })),
|
||||
)
|
||||
.find((model) => props.value === `${model.providerId}/${model.id}`);
|
||||
|
||||
if (!modelInfo) return props.label;
|
||||
|
||||
return <ImageModelItem {...modelInfo} showPopover={false} />;
|
||||
};
|
||||
|
||||
return (
|
||||
<Select
|
||||
classNames={{
|
||||
root: styles.popup,
|
||||
}}
|
||||
labelRender={labelRender}
|
||||
onChange={(value, option) => {
|
||||
// Skip onChange for disabled options (empty states)
|
||||
if (value === 'no-provider' || value.includes('/empty')) return;
|
||||
@@ -1,276 +1,172 @@
|
||||
import * as runtimeModule from '@lobechat/model-runtime';
|
||||
import type { EnabledAiModel, ModelAbilities } from 'model-bank';
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
import type { AIImageModelCard, EnabledAiModel, ModelParamsSchema } from 'model-bank';
|
||||
import { afterEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { getModelListByType } from '../action';
|
||||
import {
|
||||
getChatModelList,
|
||||
getImageModelList,
|
||||
normalizeChatModel,
|
||||
normalizeImageModel,
|
||||
} from '../action';
|
||||
|
||||
// Test fixtures
|
||||
const createChatModel = (
|
||||
id: string,
|
||||
providerId: string,
|
||||
overrides: Partial<EnabledAiModel> = {},
|
||||
): EnabledAiModel => ({
|
||||
id,
|
||||
providerId,
|
||||
const createChatModel = (overrides: Partial<EnabledAiModel> = {}): EnabledAiModel => ({
|
||||
abilities: overrides.abilities ?? { functionCall: true },
|
||||
contextWindowTokens: overrides.contextWindowTokens ?? 8192,
|
||||
displayName: overrides.displayName ?? 'Chat Model',
|
||||
enabled: overrides.enabled ?? true,
|
||||
id: overrides.id ?? 'chat-model',
|
||||
providerId: overrides.providerId ?? 'openai',
|
||||
type: 'chat',
|
||||
abilities: { functionCall: true, files: true } satisfies ModelAbilities,
|
||||
contextWindowTokens: 8192,
|
||||
displayName: `${id} model`,
|
||||
enabled: true,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
const createImageModel = (
|
||||
id: string,
|
||||
providerId: string,
|
||||
overrides: Partial<EnabledAiModel> = {},
|
||||
): EnabledAiModel => ({
|
||||
id,
|
||||
providerId,
|
||||
type ImageEnabledModel = EnabledAiModel & AIImageModelCard;
|
||||
|
||||
const createImageModel = (overrides: Partial<ImageEnabledModel> = {}): ImageEnabledModel => ({
|
||||
abilities: overrides.abilities ?? {},
|
||||
contextWindowTokens: overrides.contextWindowTokens,
|
||||
displayName: overrides.displayName ?? 'Image Model',
|
||||
enabled: overrides.enabled ?? true,
|
||||
id: overrides.id ?? 'image-model',
|
||||
providerId: overrides.providerId ?? 'openai',
|
||||
type: 'image',
|
||||
abilities: {} satisfies ModelAbilities,
|
||||
displayName: `${id} model`,
|
||||
enabled: true,
|
||||
...overrides,
|
||||
});
|
||||
|
||||
// Core test data
|
||||
const mockChatModels = [
|
||||
createChatModel('gpt-4', 'openai', {
|
||||
displayName: 'GPT-4',
|
||||
abilities: { functionCall: true, files: true } satisfies ModelAbilities,
|
||||
}),
|
||||
createChatModel('gpt-3.5-turbo', 'openai', {
|
||||
displayName: 'GPT-3.5 Turbo',
|
||||
abilities: { functionCall: true } satisfies ModelAbilities,
|
||||
contextWindowTokens: 4096,
|
||||
}),
|
||||
createChatModel('claude-3-opus', 'anthropic', {
|
||||
displayName: 'Claude 3 Opus',
|
||||
abilities: { functionCall: false, files: true } satisfies ModelAbilities,
|
||||
contextWindowTokens: 200000,
|
||||
}),
|
||||
];
|
||||
|
||||
const mockImageModels = [
|
||||
createImageModel('dall-e-3', 'openai', {
|
||||
displayName: 'DALL-E 3',
|
||||
parameters: {
|
||||
prompt: { default: '' },
|
||||
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
|
||||
},
|
||||
}),
|
||||
createImageModel('midjourney', 'midjourney', {
|
||||
displayName: 'Midjourney',
|
||||
}),
|
||||
];
|
||||
|
||||
const allModels = [...mockChatModels, ...mockImageModels];
|
||||
|
||||
describe('getModelListByType', () => {
|
||||
describe('Core Functionality', () => {
|
||||
it('should filter models by providerId and type correctly', async () => {
|
||||
const result = await getModelListByType(allModels, 'openai', 'chat');
|
||||
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result.map((m) => m.id)).toEqual(['gpt-4', 'gpt-3.5-turbo']);
|
||||
});
|
||||
|
||||
it('should return correct model structure for chat models', async () => {
|
||||
const result = await getModelListByType(allModels, 'openai', 'chat');
|
||||
|
||||
expect(result[0]).toEqual({
|
||||
abilities: { functionCall: true, files: true },
|
||||
contextWindowTokens: 8192,
|
||||
displayName: 'GPT-4',
|
||||
id: 'gpt-4',
|
||||
});
|
||||
});
|
||||
|
||||
it('should include parameters field for image models', async () => {
|
||||
const result = await getModelListByType(allModels, 'openai', 'image');
|
||||
|
||||
expect(result[0]).toEqual({
|
||||
abilities: {},
|
||||
contextWindowTokens: undefined,
|
||||
displayName: 'DALL-E 3',
|
||||
id: 'dall-e-3',
|
||||
parameters: {
|
||||
prompt: { default: '' },
|
||||
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should exclude parameters field from chat models', async () => {
|
||||
const result = await getModelListByType(mockChatModels, 'openai', 'chat');
|
||||
|
||||
result.forEach((model) => {
|
||||
expect(model).not.toHaveProperty('parameters');
|
||||
});
|
||||
});
|
||||
|
||||
it('should remove duplicate model IDs', async () => {
|
||||
const duplicateModels = [
|
||||
createChatModel('gpt-4', 'openai', {
|
||||
displayName: 'GPT-4 Version 1',
|
||||
abilities: { functionCall: true } satisfies ModelAbilities,
|
||||
}),
|
||||
createChatModel('gpt-4', 'openai', {
|
||||
displayName: 'GPT-4 Version 2',
|
||||
abilities: { functionCall: false } satisfies ModelAbilities,
|
||||
}),
|
||||
];
|
||||
|
||||
const result = await getModelListByType(duplicateModels, 'openai', 'chat');
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].displayName).toBe('GPT-4 Version 1');
|
||||
});
|
||||
describe('aiProvider action helpers', () => {
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks();
|
||||
});
|
||||
|
||||
describe('Edge Cases and Error Handling', () => {
|
||||
it('should handle empty inputs gracefully', async () => {
|
||||
const emptyResult = await getModelListByType([], 'openai', 'chat');
|
||||
expect(emptyResult).toEqual([]);
|
||||
|
||||
const noMatchingProvider = await getModelListByType(allModels, 'nonexistent', 'chat');
|
||||
expect(noMatchingProvider).toEqual([]);
|
||||
|
||||
const noMatchingType = await getModelListByType(allModels, 'openai', 'nonexistent');
|
||||
expect(noMatchingType).toEqual([]);
|
||||
});
|
||||
|
||||
it('should handle missing optional properties', async () => {
|
||||
const modelWithMissingProps = createChatModel('test-model', 'test', {
|
||||
displayName: undefined,
|
||||
describe('normalizeChatModel', () => {
|
||||
it('fills missing optional fields with safe defaults', () => {
|
||||
const model = createChatModel({
|
||||
abilities: undefined,
|
||||
contextWindowTokens: undefined,
|
||||
displayName: undefined,
|
||||
});
|
||||
|
||||
const result = await getModelListByType([modelWithMissingProps], 'test', 'chat');
|
||||
const result = normalizeChatModel(model);
|
||||
|
||||
expect(result[0].displayName).toBe('');
|
||||
expect(result[0].abilities).toEqual({});
|
||||
expect(result[0].contextWindowTokens).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should preserve complex model properties', async () => {
|
||||
const complexModel = createChatModel('complex-model', 'test', {
|
||||
displayName: 'Complex Model with All Properties',
|
||||
abilities: {
|
||||
functionCall: true,
|
||||
files: true,
|
||||
vision: false,
|
||||
} satisfies ModelAbilities,
|
||||
contextWindowTokens: 128000,
|
||||
});
|
||||
|
||||
const result = await getModelListByType([complexModel], 'test', 'chat');
|
||||
|
||||
expect(result[0]).toEqual({
|
||||
id: 'complex-model',
|
||||
displayName: 'Complex Model with All Properties',
|
||||
abilities: {
|
||||
functionCall: true,
|
||||
files: true,
|
||||
vision: false,
|
||||
},
|
||||
contextWindowTokens: 128000,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe('Image Model Parameter Handling', () => {
|
||||
it('should use fallback parameters for image models without parameters', async () => {
|
||||
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValueOnce({
|
||||
size: '1024x1024',
|
||||
});
|
||||
|
||||
const result = await getModelListByType(allModels, 'midjourney', 'image');
|
||||
|
||||
expect(result[0]).toEqual({
|
||||
expect(result).toEqual({
|
||||
abilities: {},
|
||||
contextWindowTokens: undefined,
|
||||
displayName: 'Midjourney',
|
||||
id: 'midjourney',
|
||||
parameters: { size: '1024x1024' },
|
||||
displayName: '',
|
||||
id: 'chat-model',
|
||||
});
|
||||
});
|
||||
|
||||
it('should handle async parameter fetching for multiple models', async () => {
|
||||
const imageModelsWithoutParams = [
|
||||
createImageModel('stable-diffusion', 'stability', { displayName: 'Stable Diffusion' }),
|
||||
createImageModel('flux-schnell', 'fal', { displayName: 'FLUX Schnell' }),
|
||||
];
|
||||
|
||||
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValue({
|
||||
prompt: { default: '' },
|
||||
width: { default: 512, min: 256, max: 2048 },
|
||||
height: { default: 512, min: 256, max: 2048 },
|
||||
});
|
||||
|
||||
const result = await getModelListByType(imageModelsWithoutParams, 'stability', 'image');
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].parameters).toEqual({
|
||||
prompt: { default: '' },
|
||||
width: { default: 512, min: 256, max: 2048 },
|
||||
height: { default: 512, min: 256, max: 2048 },
|
||||
});
|
||||
|
||||
expect(runtimeModule.getModelPropertyWithFallback).toHaveBeenCalledWith(
|
||||
'stable-diffusion',
|
||||
'parameters',
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle failed parameter fallback gracefully', async () => {
|
||||
const failingModel = createImageModel('failing-model', 'test-provider', {
|
||||
displayName: 'Failing Model',
|
||||
});
|
||||
|
||||
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValueOnce(undefined);
|
||||
|
||||
const result = await getModelListByType([failingModel], 'test-provider', 'image');
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe('failing-model');
|
||||
expect(result[0].parameters).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Concurrent Processing', () => {
|
||||
it('should handle large-scale concurrent model processing', async () => {
|
||||
const manyModels = Array.from({ length: 10 }, (_, i) =>
|
||||
createChatModel(`model-${i}`, 'test-provider', {
|
||||
displayName: `Model ${i}`,
|
||||
abilities: { functionCall: i % 2 === 0 } satisfies ModelAbilities,
|
||||
contextWindowTokens: 4096 + i * 1000,
|
||||
}),
|
||||
);
|
||||
describe('normalizeImageModel', () => {
|
||||
it('preserves inline metadata and pricing information', async () => {
|
||||
const model = createImageModel({
|
||||
abilities: { vision: true },
|
||||
contextWindowTokens: 4096,
|
||||
displayName: 'Inline Model',
|
||||
parameters: {
|
||||
prompt: { default: '' },
|
||||
size: { default: '1024x1024', enum: ['512x512', '1024x1024'] },
|
||||
} as ModelParamsSchema,
|
||||
pricing: {
|
||||
units: [{ name: 'imageGeneration', rate: 0.04, strategy: 'fixed', unit: 'image' }],
|
||||
},
|
||||
});
|
||||
|
||||
const result = await getModelListByType(manyModels, 'test-provider', 'chat');
|
||||
const result = await normalizeImageModel(model);
|
||||
|
||||
expect(result).toHaveLength(10);
|
||||
expect(result.map((m) => m.id)).toEqual(manyModels.map((m) => m.id));
|
||||
|
||||
result.forEach((model, index) => {
|
||||
expect(model.abilities.functionCall).toBe(index % 2 === 0);
|
||||
expect(model.contextWindowTokens).toBe(4096 + index * 1000);
|
||||
expect(result).toMatchObject({
|
||||
abilities: { vision: true },
|
||||
displayName: 'Inline Model',
|
||||
parameters: { size: { default: '1024x1024', enum: ['512x512', '1024x1024'] } },
|
||||
pricing: {
|
||||
units: [{ name: 'imageGeneration', rate: 0.04, strategy: 'fixed', unit: 'image' }],
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should maintain model order during concurrent processing', async () => {
|
||||
const orderedModels = [
|
||||
createChatModel('first-model', 'test', { displayName: 'First Model' }),
|
||||
createChatModel('second-model', 'test', { displayName: 'Second Model' }),
|
||||
createChatModel('third-model', 'test', { displayName: 'Third Model' }),
|
||||
it('fetches fallback description/parameters/pricing when missing', async () => {
|
||||
const fallbackSpy = vi
|
||||
.spyOn(runtimeModule, 'getModelPropertyWithFallback')
|
||||
.mockImplementation(async (_id, key) => {
|
||||
if (key === 'parameters')
|
||||
return {
|
||||
prompt: { default: '' },
|
||||
size: { default: '768x768', enum: ['512x512', '768x768'] },
|
||||
} satisfies ModelParamsSchema;
|
||||
if (key === 'pricing')
|
||||
return {
|
||||
units: [{ name: 'imageGeneration', rate: 0.02, strategy: 'fixed', unit: 'image' }],
|
||||
};
|
||||
if (key === 'description') return 'Fallback description';
|
||||
return undefined;
|
||||
});
|
||||
|
||||
const model = createImageModel({
|
||||
id: 'stable-diffusion',
|
||||
providerId: 'stability',
|
||||
parameters: undefined,
|
||||
pricing: undefined,
|
||||
});
|
||||
|
||||
const result = await normalizeImageModel(model);
|
||||
|
||||
expect(result.parameters).toEqual({
|
||||
prompt: { default: '' },
|
||||
size: { default: '768x768', enum: ['512x512', '768x768'] },
|
||||
});
|
||||
expect(result.pricing).toEqual({
|
||||
units: [{ name: 'imageGeneration', rate: 0.02, strategy: 'fixed', unit: 'image' }],
|
||||
});
|
||||
expect(result.description).toBe('Fallback description');
|
||||
expect(fallbackSpy).toHaveBeenCalledWith('stable-diffusion', 'parameters', 'stability');
|
||||
expect(fallbackSpy).toHaveBeenCalledWith('stable-diffusion', 'pricing', 'stability');
|
||||
expect(fallbackSpy).toHaveBeenCalledWith('stable-diffusion', 'description', 'stability');
|
||||
});
|
||||
});
|
||||
|
||||
describe('getChatModelList', () => {
|
||||
const chatModels = [
|
||||
createChatModel({ id: 'gpt-4', providerId: 'openai', displayName: 'GPT-4' }),
|
||||
createChatModel({ id: 'gpt-3.5', providerId: 'openai', displayName: 'GPT-3.5' }),
|
||||
createChatModel({ id: 'claude-3', providerId: 'anthropic', displayName: 'Claude 3' }),
|
||||
];
|
||||
|
||||
it('filters by provider and deduplicates IDs', async () => {
|
||||
const duplicated = [
|
||||
...chatModels,
|
||||
createChatModel({ id: 'gpt-4', providerId: 'openai', displayName: 'GPT-4 Duplicate' }),
|
||||
];
|
||||
|
||||
const result = await getModelListByType(orderedModels, 'test', 'chat');
|
||||
const result = await getChatModelList(duplicated, 'openai');
|
||||
|
||||
expect(result.map((m) => m.id)).toEqual(['first-model', 'second-model', 'third-model']);
|
||||
expect(result).toHaveLength(2);
|
||||
expect(result.map((m) => m.id)).toEqual(['gpt-4', 'gpt-3.5']);
|
||||
expect(result[0].displayName).toBe('GPT-4');
|
||||
});
|
||||
|
||||
it('returns empty array when provider has no chat models', async () => {
|
||||
const result = await getChatModelList(chatModels, 'nonexistent');
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getImageModelList', () => {
|
||||
const imageModels = [
|
||||
createImageModel({ id: 'dall-e-3', providerId: 'openai', displayName: 'DALL-E 3' }),
|
||||
createImageModel({ id: 'midjourney', providerId: 'midjourney', displayName: 'Midjourney' }),
|
||||
];
|
||||
|
||||
it('collects normalized image models for a provider', async () => {
|
||||
const result = await getImageModelList(imageModels, 'openai');
|
||||
|
||||
expect(result).toHaveLength(1);
|
||||
expect(result[0].id).toBe('dall-e-3');
|
||||
expect(result[0].displayName).toBe('DALL-E 3');
|
||||
});
|
||||
|
||||
it('returns empty array when provider has no image models', async () => {
|
||||
const result = await getImageModelList(imageModels, 'unknown');
|
||||
expect(result).toEqual([]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import { isDeprecatedEdition, isDesktop, isUsePgliteDB } from '@lobechat/const';
|
||||
import { getModelPropertyWithFallback } from '@lobechat/model-runtime';
|
||||
import { getModelPropertyWithFallback, resolveImageSinglePrice } from '@lobechat/model-runtime';
|
||||
import { uniqBy } from 'lodash-es';
|
||||
import {
|
||||
AIImageModelCard,
|
||||
EnabledAiModel,
|
||||
LobeDefaultAiModelListItem,
|
||||
ModelAbilities,
|
||||
ModelParamsSchema,
|
||||
Pricing,
|
||||
} from 'model-bank';
|
||||
import { SWRResponse, mutate } from 'swr';
|
||||
import { StateCreator } from 'zustand/vanilla';
|
||||
@@ -28,52 +30,130 @@ import {
|
||||
UpdateAiProviderParams,
|
||||
} from '@/types/aiProvider';
|
||||
|
||||
/**
|
||||
* Get models by provider ID and type, with proper formatting and deduplication
|
||||
*/
|
||||
export const getModelListByType = async (
|
||||
enabledAiModels: EnabledAiModel[],
|
||||
providerId: string,
|
||||
type: string,
|
||||
) => {
|
||||
const filteredModels = enabledAiModels.filter(
|
||||
(model) => model.providerId === providerId && model.type === type,
|
||||
);
|
||||
|
||||
const models = await Promise.all(
|
||||
filteredModels.map(async (model) => ({
|
||||
abilities: (model.abilities || {}) as ModelAbilities,
|
||||
contextWindowTokens: model.contextWindowTokens,
|
||||
displayName: model.displayName ?? '',
|
||||
id: model.id,
|
||||
...(model.type === 'image' && {
|
||||
parameters:
|
||||
(model as AIImageModelCard).parameters ||
|
||||
(await getModelPropertyWithFallback(model.id, 'parameters')),
|
||||
}),
|
||||
})),
|
||||
);
|
||||
|
||||
return uniqBy(models, 'id');
|
||||
export type ProviderModelListItem = {
|
||||
abilities: ModelAbilities;
|
||||
approximatePricePerImage?: number;
|
||||
contextWindowTokens?: number;
|
||||
description?: string;
|
||||
displayName: string;
|
||||
id: string;
|
||||
parameters?: ModelParamsSchema;
|
||||
pricePerImage?: number;
|
||||
pricing?: Pricing;
|
||||
};
|
||||
|
||||
/**
|
||||
* Build provider model lists with proper async handling
|
||||
*/
|
||||
type ModelNormalizer = (model: EnabledAiModel) => Promise<ProviderModelListItem>;
|
||||
|
||||
const dedupeById = (models: ProviderModelListItem[]) => uniqBy(models, 'id');
|
||||
|
||||
const createProviderModelCollector = (
|
||||
type: EnabledAiModel['type'],
|
||||
normalizer: ModelNormalizer,
|
||||
) => {
|
||||
return async (enabledAiModels: EnabledAiModel[], providerId: string) => {
|
||||
const filteredModels = enabledAiModels.filter(
|
||||
(model) => model.providerId === providerId && model.type === type,
|
||||
);
|
||||
|
||||
if (!filteredModels.length) return [];
|
||||
|
||||
const normalized = await Promise.all(filteredModels.map((model) => normalizer(model)));
|
||||
return dedupeById(normalized);
|
||||
};
|
||||
};
|
||||
|
||||
export const normalizeChatModel = (model: EnabledAiModel): ProviderModelListItem => ({
|
||||
abilities: (model.abilities || {}) as ModelAbilities,
|
||||
contextWindowTokens: model.contextWindowTokens,
|
||||
displayName: model.displayName ?? '',
|
||||
id: model.id,
|
||||
});
|
||||
|
||||
export const normalizeImageModel = async (
|
||||
model: EnabledAiModel,
|
||||
): Promise<ProviderModelListItem> => {
|
||||
const fallbackParametersPromise = model.parameters
|
||||
? Promise.resolve<ModelParamsSchema | undefined>(model.parameters)
|
||||
: getModelPropertyWithFallback<ModelParamsSchema | undefined>(
|
||||
model.id,
|
||||
'parameters',
|
||||
model.providerId,
|
||||
);
|
||||
|
||||
const modelWithPricing = model as AIImageModelCard;
|
||||
const fallbackPricingPromise = modelWithPricing.pricing
|
||||
? Promise.resolve<Pricing | undefined>(modelWithPricing.pricing)
|
||||
: getModelPropertyWithFallback<Pricing | undefined>(model.id, 'pricing', model.providerId);
|
||||
|
||||
const fallbackDescriptionPromise = getModelPropertyWithFallback<string | undefined>(
|
||||
model.id,
|
||||
'description',
|
||||
model.providerId,
|
||||
);
|
||||
|
||||
const [fallbackParameters, fallbackPricing, fallbackDescription] = await Promise.all([
|
||||
fallbackParametersPromise,
|
||||
fallbackPricingPromise,
|
||||
fallbackDescriptionPromise,
|
||||
]);
|
||||
|
||||
const parameters = model.parameters ?? fallbackParameters;
|
||||
const pricing = fallbackPricing;
|
||||
const description = fallbackDescription;
|
||||
const { price, approximatePrice } = resolveImageSinglePrice(pricing);
|
||||
|
||||
return {
|
||||
abilities: (model.abilities || {}) as ModelAbilities,
|
||||
contextWindowTokens: model.contextWindowTokens,
|
||||
displayName: model.displayName ?? '',
|
||||
id: model.id,
|
||||
...(parameters && { parameters }),
|
||||
...(description && { description }),
|
||||
...(pricing && { pricing }),
|
||||
...(typeof approximatePrice === 'number' && { approximatePricePerImage: approximatePrice }),
|
||||
...(typeof price === 'number' && { pricePerImage: price }),
|
||||
};
|
||||
};
|
||||
|
||||
export const getChatModelList = createProviderModelCollector('chat', async (model) =>
|
||||
normalizeChatModel(model),
|
||||
);
|
||||
|
||||
export const getImageModelList = createProviderModelCollector('image', normalizeImageModel);
|
||||
|
||||
const buildProviderModelLists = async (
|
||||
providers: EnabledProvider[],
|
||||
enabledAiModels: EnabledAiModel[],
|
||||
type: 'chat' | 'image',
|
||||
collector: (
|
||||
enabledAiModels: EnabledAiModel[],
|
||||
providerId: string,
|
||||
) => Promise<ProviderModelListItem[]>,
|
||||
) => {
|
||||
return Promise.all(
|
||||
providers.map(async (provider) => ({
|
||||
...provider,
|
||||
children: await getModelListByType(enabledAiModels, provider.id, type),
|
||||
children: await collector(enabledAiModels, provider.id),
|
||||
name: provider.name || provider.id,
|
||||
})),
|
||||
);
|
||||
};
|
||||
|
||||
/**
|
||||
* Build image provider model lists with proper async handling
|
||||
*/
|
||||
const buildImageProviderModelLists = async (
|
||||
providers: EnabledProvider[],
|
||||
enabledAiModels: EnabledAiModel[],
|
||||
) => buildProviderModelLists(providers, enabledAiModels, getImageModelList);
|
||||
|
||||
/**
|
||||
* Build chat provider model lists with proper async handling
|
||||
*/
|
||||
const buildChatProviderModelLists = async (
|
||||
providers: EnabledProvider[],
|
||||
enabledAiModels: EnabledAiModel[],
|
||||
) => buildProviderModelLists(providers, enabledAiModels, getChatModelList);
|
||||
|
||||
enum AiProviderSwrKey {
|
||||
fetchAiProviderItem = 'FETCH_AI_PROVIDER_ITEM',
|
||||
fetchAiProviderList = 'FETCH_AI_PROVIDER',
|
||||
@@ -252,8 +332,8 @@ export const createAiProviderSlice: StateCreator<
|
||||
|
||||
// Build model lists with proper async handling
|
||||
const [enabledChatModelList, enabledImageModelList] = await Promise.all([
|
||||
buildProviderModelLists(data.enabledChatAiProviders, data.enabledAiModels, 'chat'),
|
||||
buildProviderModelLists(data.enabledImageAiProviders, data.enabledAiModels, 'image'),
|
||||
buildChatProviderModelLists(data.enabledChatAiProviders, data.enabledAiModels),
|
||||
buildImageProviderModelLists(data.enabledImageAiProviders, data.enabledAiModels),
|
||||
]);
|
||||
|
||||
return {
|
||||
@@ -285,8 +365,8 @@ export const createAiProviderSlice: StateCreator<
|
||||
// Build model lists for non-login state as well
|
||||
const enabledAiModels = builtinAiModelList.filter((m) => m.enabled);
|
||||
const [enabledChatModelList, enabledImageModelList] = await Promise.all([
|
||||
buildProviderModelLists(enabledChatAiProviders, enabledAiModels, 'chat'),
|
||||
buildProviderModelLists(enabledImageAiProviders, enabledAiModels, 'image'),
|
||||
buildChatProviderModelLists(enabledChatAiProviders, enabledAiModels),
|
||||
buildImageProviderModelLists(enabledImageAiProviders, enabledAiModels),
|
||||
]);
|
||||
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user