feat(image): implement model selection memory functionality (#9160)

This commit is contained in:
YuTengjing
2025-09-08 23:46:47 +08:00
committed by GitHub
parent 58378fd10f
commit b00e6d7817
16 changed files with 1031 additions and 377 deletions
+1
View File
@@ -9,6 +9,7 @@ alwaysApply: false
- 如果要写复杂样式的话用 antd-style ,简单的话可以用 style 属性直接写内联样式
- 如果需要 flex 布局或者居中布局应该使用 react-layout-kit 的 Flexbox 和 Center 组件
- 选择组件时优先顺序应该是 src/components > 安装的组件 package > lobe-ui > antd
- 使用 selector 访问 zustand store 的数据,而不是直接从 store 获取
## antd-style token system
@@ -0,0 +1,53 @@
'use client';
import { Skeleton } from 'antd';
import { memo } from 'react';
import { Flexbox } from 'react-layout-kit';
/**
* Skeleton loading state for image configuration panel
* Provides visual feedback during initialization
*/
const ImageConfigSkeleton = memo(() => {
return (
<Flexbox gap={32} padding="12px 12px 0 12px" style={{ height: '100%' }}>
{/* Model Selection */}
<Flexbox gap={8}>
<Skeleton.Input size="small" style={{ width: 100 }} />
<Skeleton.Input size="large" style={{ width: '100%' }} />
</Flexbox>
{/* Image Upload Area */}
<Flexbox gap={8}>
<Skeleton.Input size="small" style={{ width: 60 }} />
<Skeleton.Node
style={{
borderRadius: 8,
height: 100,
width: '100%',
}}
/>
</Flexbox>
{/* Parameter Controls */}
{Array.from({ length: 2 }, (_, index) => (
<Flexbox gap={8} key={index}>
<Skeleton.Input size="small" style={{ width: 80 }} />
<Skeleton.Input size="default" style={{ width: '100%' }} />
</Flexbox>
))}
{/* Image Number Control (Sticky at bottom) */}
<Flexbox padding="12px 0" style={{ marginTop: 'auto' }}>
<Flexbox gap={8}>
<Skeleton.Input size="small" style={{ width: 60 }} />
<Skeleton.Input size="default" style={{ width: '100%' }} />
</Flexbox>
</Flexbox>
</Flexbox>
);
});
ImageConfigSkeleton.displayName = 'ImageConfigSkeleton';
export default ImageConfigSkeleton;
@@ -6,11 +6,13 @@ import { ReactNode, memo, useCallback, useEffect, useMemo, useRef, useState } fr
import { useTranslation } from 'react-i18next';
import { Flexbox } from 'react-layout-kit';
import { useFetchAiImageConfig } from '@/hooks/useFetchAiImageConfig';
import { imageGenerationConfigSelectors } from '@/store/image/selectors';
import { useDimensionControl } from '@/store/image/slices/generationConfig/hooks';
import { useImageStore } from '@/store/image/store';
import DimensionControlGroup from './components/DimensionControlGroup';
import ImageConfigSkeleton from './components/ImageConfigSkeleton';
import ImageNum from './components/ImageNum';
import ImageUrl from './components/ImageUrl';
import ImageUrlsUpload from './components/ImageUrlsUpload';
@@ -38,9 +40,15 @@ const isSupportedParamSelector = imageGenerationConfigSelectors.isSupportedParam
const ConfigPanel = memo(() => {
const { t } = useTranslation('image');
const theme = useTheme();
// Initialize image configuration
useFetchAiImageConfig();
// All hooks must be called before any early returns
const scrollContainerRef = useRef<HTMLDivElement>(null);
const [isScrollable, setIsScrollable] = useState(false);
const isInit = useImageStore((s) => s.isInit);
const isSupportImageUrl = useImageStore(isSupportedParamSelector('imageUrl'));
const isSupportSize = useImageStore(isSupportedParamSelector('size'));
const isSupportSeed = useImageStore(isSupportedParamSelector('seed'));
@@ -103,8 +111,7 @@ const ConfigPanel = memo(() => {
backgroundColor: theme.colorBgContainer,
borderTop: `1px solid ${theme.colorBorder}`,
// Use negative margin to extend background to container edges
marginLeft: -12,
marginLeft: -12,
marginRight: -12,
marginTop: 20,
// Add back internal padding
@@ -115,6 +122,11 @@ marginLeft: -12,
[isScrollable, theme.colorBgContainer, theme.colorBorder],
);
// Show loading state if not initialized
if (!isInit) {
return <ImageConfigSkeleton />;
}
return (
<Flexbox
gap={32}
+49
View File
@@ -0,0 +1,49 @@
import { useEffect, useMemo } from 'react';
import { aiProviderSelectors, useAiInfraStore } from '@/store/aiInfra';
import { useGlobalStore } from '@/store/global';
import { systemStatusSelectors } from '@/store/global/selectors';
import { useImageStore } from '@/store/image';
import { useUserStore } from '@/store/user';
import { authSelectors } from '@/store/user/selectors';
/**
* Manages image configuration initialization
* Uses optimized state checks to reduce unnecessary re-renders
*/
export const useFetchAiImageConfig = () => {
// Individual state checks for better performance
const isStatusInit = useGlobalStore(systemStatusSelectors.isStatusInit);
const isAuthLoaded = useUserStore(authSelectors.isLoaded);
const isInitAiProviderRuntimeState = useAiInfraStore(
aiProviderSelectors.isInitAiProviderRuntimeState,
);
// Combined readiness check with memoization
const isReadyForInit = useMemo(
() => isStatusInit && isAuthLoaded && isInitAiProviderRuntimeState,
[isStatusInit, isAuthLoaded, isInitAiProviderRuntimeState],
);
const isLogin = useUserStore(authSelectors.isLogin);
const { lastSelectedImageModel, lastSelectedImageProvider } = useGlobalStore((s) => ({
lastSelectedImageModel: s.status.lastSelectedImageModel,
lastSelectedImageProvider: s.status.lastSelectedImageProvider,
}));
const isInit = useImageStore((s) => s.isInit);
const initializeImageConfig = useImageStore((s) => s.initializeImageConfig);
useEffect(() => {
if (isReadyForInit && !isInit) {
initializeImageConfig(isLogin, lastSelectedImageModel, lastSelectedImageProvider);
}
}, [
isReadyForInit,
isInit,
isLogin,
lastSelectedImageModel,
lastSelectedImageProvider,
initializeImageConfig,
]);
};
@@ -67,6 +67,7 @@ describe('aiModelSelectors', () => {
providerSearchKeyword: '',
aiProviderRuntimeConfig: {},
initAiProviderList: false,
isInitAiProviderRuntimeState: false,
};
describe('aiProviderChatModelListIds', () => {
@@ -4,63 +4,71 @@ import { describe, expect, it, vi } from 'vitest';
import { getModelListByType } from '../action';
// Test fixtures
const createChatModel = (
id: string,
providerId: string,
overrides: Partial<EnabledAiModel> = {},
): EnabledAiModel => ({
id,
providerId,
type: 'chat',
abilities: { functionCall: true, files: true } satisfies ModelAbilities,
contextWindowTokens: 8192,
displayName: `${id} model`,
enabled: true,
...overrides,
});
const createImageModel = (
id: string,
providerId: string,
overrides: Partial<EnabledAiModel> = {},
): EnabledAiModel => ({
id,
providerId,
type: 'image',
abilities: {} satisfies ModelAbilities,
displayName: `${id} model`,
enabled: true,
...overrides,
});
// Core test data
const mockChatModels = [
createChatModel('gpt-4', 'openai', {
displayName: 'GPT-4',
abilities: { functionCall: true, files: true } satisfies ModelAbilities,
}),
createChatModel('gpt-3.5-turbo', 'openai', {
displayName: 'GPT-3.5 Turbo',
abilities: { functionCall: true } satisfies ModelAbilities,
contextWindowTokens: 4096,
}),
createChatModel('claude-3-opus', 'anthropic', {
displayName: 'Claude 3 Opus',
abilities: { functionCall: false, files: true } satisfies ModelAbilities,
contextWindowTokens: 200000,
}),
];
const mockImageModels = [
createImageModel('dall-e-3', 'openai', {
displayName: 'DALL-E 3',
parameters: {
prompt: { default: '' },
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
},
}),
createImageModel('midjourney', 'midjourney', {
displayName: 'Midjourney',
}),
];
const allModels = [...mockChatModels, ...mockImageModels];
describe('getModelListByType', () => {
const mockChatModels: EnabledAiModel[] = [
{
id: 'gpt-4',
providerId: 'openai',
type: 'chat',
abilities: { functionCall: true, files: true } as ModelAbilities,
contextWindowTokens: 8192,
displayName: 'GPT-4',
enabled: true,
},
{
id: 'gpt-3.5-turbo',
providerId: 'openai',
type: 'chat',
abilities: { functionCall: true } as ModelAbilities,
contextWindowTokens: 4096,
displayName: 'GPT-3.5 Turbo',
enabled: true,
},
{
id: 'claude-3-opus',
providerId: 'anthropic',
type: 'chat',
abilities: { functionCall: false, files: true } as ModelAbilities,
contextWindowTokens: 200000,
displayName: 'Claude 3 Opus',
enabled: true,
},
];
const mockImageModels: EnabledAiModel[] = [
{
id: 'dall-e-3',
providerId: 'openai',
type: 'image',
abilities: {} as ModelAbilities,
displayName: 'DALL-E 3',
enabled: true,
parameters: {
prompt: { default: '' },
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
},
},
{
id: 'midjourney',
providerId: 'midjourney',
type: 'image',
abilities: {} as ModelAbilities,
displayName: 'Midjourney',
enabled: true,
},
];
const allModels = [...mockChatModels, ...mockImageModels];
describe('basic functionality', () => {
describe('Core Functionality', () => {
it('should filter models by providerId and type correctly', async () => {
const result = await getModelListByType(allModels, 'openai', 'chat');
@@ -68,7 +76,7 @@ describe('getModelListByType', () => {
expect(result.map((m) => m.id)).toEqual(['gpt-4', 'gpt-3.5-turbo']);
});
it('should return correct model structure', async () => {
it('should return correct model structure for chat models', async () => {
const result = await getModelListByType(allModels, 'openai', 'chat');
expect(result[0]).toEqual({
@@ -79,7 +87,7 @@ describe('getModelListByType', () => {
});
});
it('should add parameters field for image models', async () => {
it('should include parameters field for image models', async () => {
const result = await getModelListByType(allModels, 'openai', 'image');
expect(result[0]).toEqual({
@@ -94,8 +102,87 @@ describe('getModelListByType', () => {
});
});
it('should exclude parameters field from chat models', async () => {
const result = await getModelListByType(mockChatModels, 'openai', 'chat');
result.forEach((model) => {
expect(model).not.toHaveProperty('parameters');
});
});
it('should remove duplicate model IDs', async () => {
const duplicateModels = [
createChatModel('gpt-4', 'openai', {
displayName: 'GPT-4 Version 1',
abilities: { functionCall: true } satisfies ModelAbilities,
}),
createChatModel('gpt-4', 'openai', {
displayName: 'GPT-4 Version 2',
abilities: { functionCall: false } satisfies ModelAbilities,
}),
];
const result = await getModelListByType(duplicateModels, 'openai', 'chat');
expect(result).toHaveLength(1);
expect(result[0].displayName).toBe('GPT-4 Version 1');
});
});
describe('Edge Cases and Error Handling', () => {
it('should handle empty inputs gracefully', async () => {
const emptyResult = await getModelListByType([], 'openai', 'chat');
expect(emptyResult).toEqual([]);
const noMatchingProvider = await getModelListByType(allModels, 'nonexistent', 'chat');
expect(noMatchingProvider).toEqual([]);
const noMatchingType = await getModelListByType(allModels, 'openai', 'nonexistent');
expect(noMatchingType).toEqual([]);
});
it('should handle missing optional properties', async () => {
const modelWithMissingProps = createChatModel('test-model', 'test', {
displayName: undefined,
abilities: undefined,
contextWindowTokens: undefined,
});
const result = await getModelListByType([modelWithMissingProps], 'test', 'chat');
expect(result[0].displayName).toBe('');
expect(result[0].abilities).toEqual({});
expect(result[0].contextWindowTokens).toBeUndefined();
});
it('should preserve complex model properties', async () => {
const complexModel = createChatModel('complex-model', 'test', {
displayName: 'Complex Model with All Properties',
abilities: {
functionCall: true,
files: true,
vision: false,
} satisfies ModelAbilities,
contextWindowTokens: 128000,
});
const result = await getModelListByType([complexModel], 'test', 'chat');
expect(result[0]).toEqual({
id: 'complex-model',
displayName: 'Complex Model with All Properties',
abilities: {
functionCall: true,
files: true,
vision: false,
},
contextWindowTokens: 128000,
});
});
});
describe('Image Model Parameter Handling', () => {
it('should use fallback parameters for image models without parameters', async () => {
// Mock getModelPropertyWithFallback
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValueOnce({
size: '1024x1024',
});
@@ -110,108 +197,80 @@ describe('getModelListByType', () => {
parameters: { size: '1024x1024' },
});
});
});
describe('edge cases', () => {
it('should handle empty model list', async () => {
const result = await getModelListByType([], 'openai', 'chat');
expect(result).toEqual([]);
});
it('should handle non-existent providerId', async () => {
const result = await getModelListByType(allModels, 'nonexistent', 'chat');
expect(result).toEqual([]);
});
it('should handle non-existent type', async () => {
const result = await getModelListByType(allModels, 'openai', 'nonexistent');
expect(result).toEqual([]);
});
it('should handle missing displayName', async () => {
const modelsWithoutDisplayName: EnabledAiModel[] = [
{
id: 'test-model',
providerId: 'test',
type: 'chat',
abilities: {} as ModelAbilities,
enabled: true,
},
it('should handle async parameter fetching for multiple models', async () => {
const imageModelsWithoutParams = [
createImageModel('stable-diffusion', 'stability', { displayName: 'Stable Diffusion' }),
createImageModel('flux-schnell', 'fal', { displayName: 'FLUX Schnell' }),
];
const result = await getModelListByType(modelsWithoutDisplayName, 'test', 'chat');
expect(result[0].displayName).toBe('');
});
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValue({
prompt: { default: '' },
width: { default: 512, min: 256, max: 2048 },
height: { default: 512, min: 256, max: 2048 },
});
it('should handle missing abilities', async () => {
const modelsWithoutAbilities: EnabledAiModel[] = [
{
id: 'test-model',
providerId: 'test',
type: 'chat',
enabled: true,
} as EnabledAiModel,
];
const result = await getModelListByType(modelsWithoutAbilities, 'test', 'chat');
expect(result[0].abilities).toEqual({});
});
});
describe('deduplication', () => {
it('should remove duplicate model IDs', async () => {
const duplicateModels: EnabledAiModel[] = [
{
id: 'gpt-4',
providerId: 'openai',
type: 'chat',
abilities: { functionCall: true } as ModelAbilities,
displayName: 'GPT-4 Version 1',
enabled: true,
},
{
id: 'gpt-4',
providerId: 'openai',
type: 'chat',
abilities: { functionCall: false } as ModelAbilities,
displayName: 'GPT-4 Version 2',
enabled: true,
},
];
const result = await getModelListByType(duplicateModels, 'openai', 'chat');
const result = await getModelListByType(imageModelsWithoutParams, 'stability', 'image');
expect(result).toHaveLength(1);
expect(result[0].displayName).toBe('GPT-4 Version 1');
expect(result[0].parameters).toEqual({
prompt: { default: '' },
width: { default: 512, min: 256, max: 2048 },
height: { default: 512, min: 256, max: 2048 },
});
expect(runtimeModule.getModelPropertyWithFallback).toHaveBeenCalledWith(
'stable-diffusion',
'parameters',
);
});
it('should handle failed parameter fallback gracefully', async () => {
const failingModel = createImageModel('failing-model', 'test-provider', {
displayName: 'Failing Model',
});
vi.spyOn(runtimeModule, 'getModelPropertyWithFallback').mockResolvedValueOnce(undefined);
const result = await getModelListByType([failingModel], 'test-provider', 'image');
expect(result).toHaveLength(1);
expect(result[0].id).toBe('failing-model');
expect(result[0].parameters).toBeUndefined();
});
});
describe('type casting', () => {
it('should handle image model type casting correctly', async () => {
const imageModel: EnabledAiModel[] = [
{
id: 'dall-e-3',
providerId: 'openai',
type: 'image',
abilities: {} as ModelAbilities,
displayName: 'DALL-E 3',
enabled: true,
parameters: { size: '1024x1024' },
} as any, // Simulate AIImageModelCard type
];
describe('Concurrent Processing', () => {
it('should handle large-scale concurrent model processing', async () => {
const manyModels = Array.from({ length: 10 }, (_, i) =>
createChatModel(`model-${i}`, 'test-provider', {
displayName: `Model ${i}`,
abilities: { functionCall: i % 2 === 0 } satisfies ModelAbilities,
contextWindowTokens: 4096 + i * 1000,
}),
);
const result = await getModelListByType(imageModel, 'openai', 'image');
const result = await getModelListByType(manyModels, 'test-provider', 'chat');
expect(result[0]).toHaveProperty('parameters');
expect(result[0].parameters).toEqual({ size: '1024x1024' });
expect(result).toHaveLength(10);
expect(result.map((m) => m.id)).toEqual(manyModels.map((m) => m.id));
result.forEach((model, index) => {
expect(model.abilities.functionCall).toBe(index % 2 === 0);
expect(model.contextWindowTokens).toBe(4096 + index * 1000);
});
});
it('should not add parameters field for non-image models', async () => {
const result = await getModelListByType(mockChatModels, 'openai', 'chat');
it('should maintain model order during concurrent processing', async () => {
const orderedModels = [
createChatModel('first-model', 'test', { displayName: 'First Model' }),
createChatModel('second-model', 'test', { displayName: 'Second Model' }),
createChatModel('third-model', 'test', { displayName: 'Third Model' }),
];
result.forEach((model) => {
expect(model).not.toHaveProperty('parameters');
});
const result = await getModelListByType(orderedModels, 'test', 'chat');
expect(result.map((m) => m.id)).toEqual(['first-model', 'second-model', 'third-model']);
});
});
});
+11 -4
View File
@@ -13,6 +13,8 @@ import { StateCreator } from 'zustand/vanilla';
import { useClientDataSWR } from '@/libs/swr';
import { aiProviderService } from '@/services/aiProvider';
import { AiInfraStore } from '@/store/aiInfra/store';
import { useUserStore } from '@/store/user';
import { authSelectors } from '@/store/user/selectors';
import {
AiProviderDetailItem,
AiProviderListItem,
@@ -227,9 +229,12 @@ export const createAiProviderSlice: StateCreator<
},
),
useFetchAiProviderRuntimeState: (isLogin) =>
useClientDataSWR<AiProviderRuntimeStateWithBuiltinModels | undefined>(
!isDeprecatedEdition ? [AiProviderSwrKey.fetchAiProviderRuntimeState, isLogin] : null,
useFetchAiProviderRuntimeState: (isLogin) => {
const isAuthLoaded = authSelectors.isLoaded(useUserStore.getState());
return useClientDataSWR<AiProviderRuntimeStateWithBuiltinModels | undefined>(
isAuthLoaded && !isDeprecatedEdition
? [AiProviderSwrKey.fetchAiProviderRuntimeState, isLogin]
: null,
async ([, isLogin]) => {
const [{ LOBE_DEFAULT_MODEL_LIST: builtinAiModelList }, { DEFAULT_MODEL_PROVIDER_LIST }] =
await Promise.all([import('model-bank'), import('@/config/modelProviders')]);
@@ -300,11 +305,13 @@ export const createAiProviderSlice: StateCreator<
enabledAiProviders: data.enabledAiProviders,
enabledChatModelList: data.enabledChatModelList || [],
enabledImageModelList: data.enabledImageModelList || [],
isInitAiProviderRuntimeState: true,
},
false,
'useFetchAiProviderRuntimeState',
);
},
},
),
);
},
});
@@ -22,6 +22,7 @@ export interface AIProviderState {
enabledChatModelList?: EnabledProviderWithModels[];
enabledImageModelList?: EnabledProviderWithModels[];
initAiProviderList: boolean;
isInitAiProviderRuntimeState: boolean;
providerSearchKeyword: string;
}
@@ -32,5 +33,6 @@ export const initialAIProviderState: AIProviderState = {
aiProviderLoadingIds: [],
aiProviderRuntimeConfig: {},
initAiProviderList: false,
isInitAiProviderRuntimeState: false,
providerSearchKeyword: '',
};
@@ -111,6 +111,8 @@ const isProviderEnableResponseApi = (id: string) => (s: AIProviderStoreState) =>
return false;
};
const isInitAiProviderRuntimeState = (s: AIProviderStoreState) => !!s.isInitAiProviderRuntimeState;
export const aiProviderSelectors = {
activeProviderConfig,
disabledAiProviderList,
@@ -119,6 +121,7 @@ export const aiProviderSelectors = {
isActiveProviderApiKeyNotEmpty,
isActiveProviderEndpointNotEmpty,
isAiProviderConfigLoading,
isInitAiProviderRuntimeState,
isProviderConfigUpdating,
isProviderEnableResponseApi,
isProviderEnabled,
+8
View File
@@ -61,6 +61,14 @@ export interface SystemStatus {
isEnablePglite?: boolean;
isShowCredit?: boolean;
language?: LocaleMode;
/**
* 记住用户最后选择的图像生成模型
*/
lastSelectedImageModel?: string;
/**
* 记住用户最后选择的图像生成提供商
*/
lastSelectedImageProvider?: string;
latestChangelogId?: string;
mobileShowPortal?: boolean;
mobileShowTopic?: boolean;
+5 -3
View File
@@ -29,15 +29,16 @@ const filePanelWidth = (s: GlobalState) => s.status.filePanelWidth;
const imagePanelWidth = (s: GlobalState) => s.status.imagePanelWidth;
const imageTopicPanelWidth = (s: GlobalState) => s.status.imageTopicPanelWidth;
const wideScreen = (s: GlobalState) => s.status.wideScreen;
const isStatusInit = (s: GlobalState) => !!s.isStatusInit;
const isPgliteNotEnabled = (s: GlobalState) =>
isUsePgliteDB && !isServerMode && s.isStatusInit && !s.status.isEnablePglite;
isUsePgliteDB && !isServerMode && isStatusInit(s) && !s.status.isEnablePglite;
/**
* 当且仅当 client db 模式,且 pglite 未初始化完成时返回 true
*/
const isPgliteNotInited = (s: GlobalState) =>
isUsePgliteDB &&
s.isStatusInit &&
isStatusInit(s) &&
s.status.isEnablePglite &&
s.initClientDBStage !== DatabaseLoadingState.Ready;
@@ -45,7 +46,7 @@ const isPgliteNotInited = (s: GlobalState) =>
* 当且仅当 client db 模式,且 pglite 初始化完成时返回 true
*/
const isPgliteInited = (s: GlobalState): boolean =>
(s.isStatusInit &&
(isStatusInit(s) &&
s.status.isEnablePglite &&
s.initClientDBStage === DatabaseLoadingState.Ready) ||
false;
@@ -72,6 +73,7 @@ export const systemStatusSelectors = {
isPgliteNotEnabled,
isPgliteNotInited,
isShowCredit,
isStatusInit,
language,
mobileShowPortal,
mobileShowTopic,
@@ -6,48 +6,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { useImageStore } from '@/store/image';
// Mock external dependencies
vi.mock('@/store/aiInfra', () => ({
aiProviderSelectors: {
enabledImageModelList: vi.fn(() => [
{
id: 'fal',
name: 'Fal',
children: [
{
id: 'flux/schnell',
displayName: 'FLUX.1 Schnell',
type: 'image',
parameters: fluxSchnellParamsSchema,
releasedAt: '2024-08-01',
} as AIImageModelCard,
],
},
{
id: 'custom-provider',
name: 'Custom Provider',
children: [
{
id: 'custom-model',
displayName: 'Custom Model',
type: 'image',
parameters: {
prompt: { default: '' },
width: { default: 1024, min: 256, max: 2048, step: 64 },
height: { default: 1024, min: 256, max: 2048, step: 64 },
steps: { default: 20, min: 1, max: 50 },
} as ModelParamsSchema,
releasedAt: '2024-01-01',
} as AIImageModelCard,
],
},
]),
},
getAiInfraStoreState: vi.fn(() => ({})),
}));
const fluxSchnellDefaultValues = extractDefaultValues(fluxSchnellParamsSchema);
// Test fixtures
const customModelSchema: ModelParamsSchema = {
prompt: { default: '' },
width: { default: 1024, min: 256, max: 2048, step: 64 },
@@ -55,25 +14,67 @@ const customModelSchema: ModelParamsSchema = {
steps: { default: 20, min: 1, max: 50 },
};
const testImageModels: AIImageModelCard[] = [
{
id: 'flux/schnell',
displayName: 'FLUX.1 Schnell',
type: 'image',
parameters: fluxSchnellParamsSchema,
releasedAt: '2024-08-01',
},
{
id: 'custom-model',
displayName: 'Custom Model',
type: 'image',
parameters: customModelSchema,
releasedAt: '2024-01-01',
},
];
const mockProviders = [
{
id: 'fal',
name: 'Fal',
children: [testImageModels[0]],
},
{
id: 'custom-provider',
name: 'Custom Provider',
children: [testImageModels[1]],
},
];
// Mock external dependencies
vi.mock('@/store/aiInfra', () => ({
aiProviderSelectors: {
enabledImageModelList: vi.fn(() => mockProviders),
},
getAiInfraStoreState: vi.fn(() => ({})),
}));
// Test data
const fluxSchnellDefaultValues = extractDefaultValues(fluxSchnellParamsSchema);
const customModelDefaultValues = extractDefaultValues(customModelSchema);
const initialTestState = {
model: 'initial-model',
provider: 'initial-provider',
imageNum: 1,
parameters: {
prompt: 'initial prompt',
width: 512,
height: 512,
} satisfies Partial<RuntimeImageGenParams>,
parametersSchema: {
prompt: { default: '' },
width: { default: 512, min: 256, max: 1024 },
height: { default: 512, min: 256, max: 1024 },
} satisfies ModelParamsSchema,
};
beforeEach(() => {
vi.clearAllMocks();
// Reset store state
useImageStore.setState({
model: 'initial-model',
provider: 'initial-provider',
imageNum: 1,
parameters: {
prompt: 'initial prompt',
width: 512,
height: 512,
} as RuntimeImageGenParams,
parametersSchema: {
prompt: { default: '' },
width: { default: 512, min: 256, max: 1024 },
height: { default: 512, min: 256, max: 1024 },
},
});
useImageStore.setState(initialTestState);
});
afterEach(() => {
@@ -81,8 +82,26 @@ afterEach(() => {
});
describe('GenerationConfigAction', () => {
describe('setParamOnInput', () => {
it('should update a single parameter in the parameters object', async () => {
// Helper function to create test parameters
const createTestParameters = (overrides: Partial<RuntimeImageGenParams> = {}) =>
({
prompt: '',
width: 512,
height: 512,
...overrides,
}) satisfies Partial<RuntimeImageGenParams>;
// Helper function to create test schema
const createTestSchema = (overrides: Partial<ModelParamsSchema> = {}) =>
({
prompt: { default: '' },
width: { default: 512, min: 256, max: 2048 },
height: { default: 512, min: 256, max: 2048 },
...overrides,
}) satisfies ModelParamsSchema;
describe('Parameter Management', () => {
it('should update individual parameters via setParamOnInput', () => {
const { result } = renderHook(() => useImageStore());
act(() => {
@@ -96,43 +115,45 @@ describe('GenerationConfigAction', () => {
});
});
it('should update numeric parameters correctly', async () => {
it('should handle different parameter types (string, number, null, array)', () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setParamOnInput('width', 2048);
});
expect(result.current.parameters).toMatchObject({
prompt: 'initial prompt',
width: 2048,
height: 512,
});
});
it('should handle null values correctly', async () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setParamOnInput('seed', null);
});
expect(result.current.parameters?.seed).toBeNull();
});
it('should handle array values correctly', async () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setParamOnInput('imageUrls', ['test1.jpg', 'test2.jpg']);
});
expect(result.current.parameters?.imageUrls).toEqual(['test1.jpg', 'test2.jpg']);
expect(result.current.parameters).toMatchObject({
width: 2048,
seed: null,
imageUrls: ['test1.jpg', 'test2.jpg'],
});
});
it('should update imageNum independently', () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setImageNum(4);
});
expect(result.current.imageNum).toBe(4);
});
it('should handle edge case values for imageNum', () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setImageNum(0);
});
expect(result.current.imageNum).toBe(0);
});
});
describe('setModelAndProviderOnSelect', () => {
it('should set model, provider, parameters and parametersSchema for flux/schnell', async () => {
describe('Model and Provider Selection', () => {
it('should set complete configuration for flux/schnell model', () => {
const { result } = renderHook(() => useImageStore());
act(() => {
@@ -145,81 +166,40 @@ describe('GenerationConfigAction', () => {
expect(result.current.parametersSchema).toEqual(fluxSchnellParamsSchema);
});
it('should handle model selection with custom parameters', async () => {
it('should handle custom model configuration', () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setModelAndProviderOnSelect('custom-model', 'custom-provider');
});
const expectedParams = extractDefaultValues(customModelSchema);
expect(result.current.model).toBe('custom-model');
expect(result.current.provider).toBe('custom-provider');
expect(result.current.parameters).toEqual(expectedParams);
expect(result.current.parameters).toEqual(customModelDefaultValues);
expect(result.current.parametersSchema).toEqual(customModelSchema);
});
it('should replace all previous parameters with new model defaults', async () => {
it('should completely replace parameters when switching models', () => {
const { result } = renderHook(() => useImageStore());
// First set some custom parameters
// Set some custom parameters
act(() => {
result.current.setParamOnInput('prompt', 'custom prompt');
result.current.setParamOnInput('steps', 50);
});
// Then switch model
// Switch model
act(() => {
result.current.setModelAndProviderOnSelect('flux/schnell', 'fal');
});
// Should completely replace parameters with model defaults
expect(result.current.parameters).toEqual(fluxSchnellDefaultValues);
expect(result.current.parameters?.prompt).toBe(''); // Default value, not 'custom prompt'
expect(result.current.parameters?.prompt).toBe('');
});
});
describe('setImageNum', () => {
it('should update the imageNum value', async () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setImageNum(4);
});
expect(result.current.imageNum).toBe(4);
});
it('should handle different imageNum values', async () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setImageNum(8);
});
expect(result.current.imageNum).toBe(8);
act(() => {
result.current.setImageNum(1);
});
expect(result.current.imageNum).toBe(1);
});
it('should handle edge case values', async () => {
const { result } = renderHook(() => useImageStore());
act(() => {
result.current.setImageNum(0);
});
expect(result.current.imageNum).toBe(0);
});
});
describe('reuseSettings', () => {
it('should set model, provider and merge settings with default values', async () => {
describe('Settings Reuse', () => {
it('should merge custom settings with model defaults', () => {
const { result } = renderHook(() => useImageStore());
const customSettings: Partial<RuntimeImageGenParams> = {
prompt: 'custom prompt',
@@ -240,25 +220,7 @@ describe('GenerationConfigAction', () => {
expect(result.current.parametersSchema).toEqual(fluxSchnellParamsSchema);
});
it('should override default values with provided settings', async () => {
const { result } = renderHook(() => useImageStore());
const customSettings: Partial<RuntimeImageGenParams> = {
width: 1536,
height: 1536,
};
act(() => {
result.current.reuseSettings('flux/schnell', 'fal', customSettings);
});
expect(result.current.parameters).toEqual({
...fluxSchnellDefaultValues,
width: 1536,
height: 1536,
});
});
it('should handle empty settings object', async () => {
it('should handle empty and null settings', () => {
const { result } = renderHook(() => useImageStore());
act(() => {
@@ -266,64 +228,32 @@ describe('GenerationConfigAction', () => {
});
expect(result.current.parameters).toEqual(fluxSchnellDefaultValues);
});
it('should handle partial settings with null values', async () => {
const { result } = renderHook(() => useImageStore());
const customSettings: Partial<RuntimeImageGenParams> = {
seed: null,
imageUrl: null,
};
act(() => {
result.current.reuseSettings('flux/schnell', 'fal', customSettings);
result.current.reuseSettings('flux/schnell', 'fal', { seed: null, imageUrl: null });
});
expect(result.current.parameters?.seed).toBeNull();
expect(result.current.parameters?.imageUrl).toBeNull();
});
});
describe('reuseSeed', () => {
it('should update only the seed parameter', async () => {
const { result } = renderHook(() => useImageStore());
const newSeed = 98765;
act(() => {
result.current.reuseSeed(newSeed);
});
expect(result.current.parameters).toMatchObject({
prompt: 'initial prompt',
width: 512,
height: 512,
seed: newSeed,
});
});
it('should preserve other parameters when updating seed', async () => {
it('should update only seed parameter via reuseSeed', () => {
const { result } = renderHook(() => useImageStore());
// First set some parameters
act(() => {
result.current.setParamOnInput('prompt', 'test prompt');
result.current.setParamOnInput('width', 1024);
});
const newSeed = 11111;
act(() => {
result.current.reuseSeed(newSeed);
result.current.reuseSeed(98765);
});
expect(result.current.parameters).toMatchObject({
prompt: 'test prompt',
width: 1024,
width: 512,
height: 512,
seed: newSeed,
seed: 98765,
});
});
it('should handle seed value of 0', async () => {
it('should handle edge case seed values', () => {
const { result } = renderHook(() => useImageStore());
act(() => {
@@ -331,12 +261,8 @@ describe('GenerationConfigAction', () => {
});
expect(result.current.parameters?.seed).toBe(0);
});
it('should handle large seed values within range', async () => {
const { result } = renderHook(() => useImageStore());
const largeSeed = 2147483647; // MAX_SEED
const largeSeed = 2147483647;
act(() => {
result.current.reuseSeed(largeSeed);
});
@@ -344,4 +270,259 @@ describe('GenerationConfigAction', () => {
expect(result.current.parameters?.seed).toBe(largeSeed);
});
});
describe('Aspect Ratio and Dimension Control', () => {
it('should update width without affecting height when aspect ratio is unlocked', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: createTestParameters(),
parametersSchema: createTestSchema(),
isAspectRatioLocked: false,
});
act(() => {
result.current.setWidth(1024);
});
expect(result.current.parameters).toMatchObject({
width: 1024,
height: 512,
});
});
it('should update both dimensions when aspect ratio is locked', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: createTestParameters(),
parametersSchema: createTestSchema(),
isAspectRatioLocked: true,
activeAspectRatio: '1:1',
});
act(() => {
result.current.setWidth(1024);
});
expect(result.current.parameters).toMatchObject({
width: 1024,
height: 1024,
});
});
it('should clamp dimensions to schema bounds when aspect ratio is locked', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: createTestParameters(),
parametersSchema: createTestSchema({
height: { default: 512, min: 256, max: 1024 },
}),
isAspectRatioLocked: true,
activeAspectRatio: '1:1',
});
act(() => {
result.current.setWidth(2048);
});
expect(result.current.parameters).toMatchObject({
width: 2048,
height: 1024, // Clamped to max
});
});
it('should update height with proportional width adjustment when locked', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: createTestParameters(),
parametersSchema: createTestSchema(),
isAspectRatioLocked: true,
activeAspectRatio: '2:1',
});
act(() => {
result.current.setHeight(512);
});
expect(result.current.parameters).toMatchObject({
width: 1024,
height: 512,
});
});
it('should toggle aspect ratio lock state', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({ isAspectRatioLocked: false });
act(() => {
result.current.toggleAspectRatioLock();
});
expect(result.current.isAspectRatioLocked).toBe(true);
act(() => {
result.current.toggleAspectRatioLock();
});
expect(result.current.isAspectRatioLocked).toBe(false);
});
it('should adjust dimensions when locking with mismatched ratio', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: createTestParameters({ width: 1024, height: 512 }), // 2:1 ratio
parametersSchema: createTestSchema(),
isAspectRatioLocked: false,
activeAspectRatio: '1:1', // Target 1:1 ratio
});
act(() => {
result.current.toggleAspectRatioLock();
});
expect(result.current.isAspectRatioLocked).toBe(true);
expect(result.current.parameters).toMatchObject({
width: 1024,
height: 1024,
});
});
});
describe('Aspect Ratio Setting', () => {
it('should update active aspect ratio', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: createTestParameters(),
parametersSchema: createTestSchema(),
});
act(() => {
result.current.setAspectRatio('16:9');
});
expect(result.current.activeAspectRatio).toBe('16:9');
});
it('should calculate dimensions for width/height-based models', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: createTestParameters(),
parametersSchema: createTestSchema(),
});
act(() => {
result.current.setAspectRatio('16:9');
});
const params = result.current.parameters!;
expect(params.width).toBeGreaterThan(params.height!);
const ratio = params.width! / params.height!;
expect(ratio).toBeCloseTo(16 / 9, 1);
});
it('should update aspectRatio parameter for models with native support', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: { aspectRatio: '1:1', prompt: '' },
parametersSchema: createTestSchema({
aspectRatio: { default: '1:1', enum: ['1:1', '16:9', '4:3'] },
}),
});
act(() => {
result.current.setAspectRatio('16:9');
});
expect(result.current.parameters?.aspectRatio).toBe('16:9');
expect(result.current.activeAspectRatio).toBe('16:9');
});
it('should handle missing parameters or schema gracefully', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
parameters: undefined,
parametersSchema: undefined,
});
expect(() => {
act(() => {
result.current.setAspectRatio('16:9');
});
}).not.toThrow();
});
});
describe('Configuration Initialization', () => {
beforeEach(() => {
vi.doMock('@/store/global', () => ({
useGlobalStore: {
getState: () => ({
status: {
lastSelectedImageModel: 'flux/schnell',
lastSelectedImageProvider: 'fal',
},
}),
},
}));
vi.doMock('@/store/user', () => ({
useUserStore: {
getState: () => ({ user: { id: 'test' } }),
},
}));
});
it('should initialize with remembered model when user is logged in', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({
isInit: false,
model: '',
provider: '',
});
act(() => {
result.current.initializeImageConfig(true, 'flux/schnell', 'fal');
});
expect(result.current.model).toBe('flux/schnell');
expect(result.current.provider).toBe('fal');
expect(result.current.parameters).toEqual(fluxSchnellDefaultValues);
expect(result.current.isInit).toBe(true);
});
it('should handle initialization without remembered preferences', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({ isInit: false });
act(() => {
result.current.initializeImageConfig(false);
});
expect(result.current.isInit).toBe(true);
});
it('should handle initialization errors gracefully', () => {
const { result } = renderHook(() => useImageStore());
useImageStore.setState({ isInit: false });
act(() => {
result.current.initializeImageConfig(true, 'invalid-model', 'invalid-provider');
});
expect(result.current.isInit).toBe(true);
});
});
});
+100 -23
View File
@@ -9,8 +9,12 @@ import {
import { StateCreator } from 'zustand/vanilla';
import { aiProviderSelectors, getAiInfraStoreState } from '@/store/aiInfra';
import { useGlobalStore } from '@/store/global';
import { useUserStore } from '@/store/user';
import { authSelectors } from '@/store/user/selectors';
import type { ImageStore } from '../../store';
import { calculateInitialAspectRatio } from '../../utils/aspectRatio';
import { adaptSizeToRatio, parseRatio } from '../../utils/size';
export interface GenerationConfigAction {
@@ -34,6 +38,13 @@ export interface GenerationConfigAction {
setHeight(height: number): void;
toggleAspectRatioLock(): void;
setAspectRatio(aspectRatio: string): void;
// 初始化相关方法
initializeImageConfig(
isLogin?: boolean,
lastSelectedImageModel?: string,
lastSelectedImageProvider?: string,
): void;
}
/**
@@ -43,9 +54,22 @@ export interface GenerationConfigAction {
*/
export function getModelAndDefaults(model: string, provider: string) {
const enabledImageModelList = aiProviderSelectors.enabledImageModelList(getAiInfraStoreState());
const activeModel = enabledImageModelList
.find((providerItem) => providerItem.id === provider)
?.children.find((modelItem) => modelItem.id === model) as unknown as AIImageModelCard;
const providerItem = enabledImageModelList.find((providerItem) => providerItem.id === provider);
if (!providerItem) {
throw new Error(
`Provider "${provider}" not found in enabled image provider list. Available providers: ${enabledImageModelList.map((p) => p.id).join(', ')}`,
);
}
const activeModel = providerItem.children.find(
(modelItem) => modelItem.id === model,
) as unknown as AIImageModelCard;
if (!activeModel) {
throw new Error(
`Model "${model}" not found in provider "${provider}". Available models: ${providerItem.children.map((m) => m.id).join(', ')}`,
);
}
const parametersSchema = activeModel.parameters as ModelParamsSchema;
const defaultValues = extractDefaultValues(parametersSchema);
@@ -53,6 +77,22 @@ export function getModelAndDefaults(model: string, provider: string) {
return { defaultValues, activeModel, parametersSchema };
}
/**
* @internal Helper
* Internal utility to derive initial config for a given provider/model.
* Not exported; tests should cover through public actions.
*/
function prepareModelConfigState(model: string, provider: string) {
const { defaultValues, parametersSchema } = getModelAndDefaults(model, provider);
const initialActiveRatio = calculateInitialAspectRatio(parametersSchema, defaultValues);
return {
defaultValues,
parametersSchema,
initialActiveRatio,
};
}
export const createGenerationConfigSlice: StateCreator<
ImageStore,
[['zustand/devtools', never]],
@@ -237,38 +277,32 @@ export const createGenerationConfigSlice: StateCreator<
},
setModelAndProviderOnSelect: (model, provider) => {
const { defaultValues, activeModel } = getModelAndDefaults(model, provider);
const parametersSchema = activeModel.parameters;
let initialActiveRatio: string | null = null;
// 如果模型没有原生比例或尺寸参数,但有宽高,则启用虚拟比例控制
if (
!parametersSchema?.aspectRatio &&
!parametersSchema?.size &&
parametersSchema?.width &&
parametersSchema?.height
) {
const { width, height } = defaultValues;
if (typeof width === 'number' && typeof height === 'number' && width > 0 && height > 0) {
initialActiveRatio = `${width}:${height}`;
} else {
initialActiveRatio = '1:1';
}
}
const { defaultValues, parametersSchema, initialActiveRatio } = prepareModelConfigState(
model,
provider,
);
set(
{
model,
provider,
parameters: defaultValues,
parametersSchema: parametersSchema,
parametersSchema,
isAspectRatioLocked: false,
activeAspectRatio: initialActiveRatio,
},
false,
`setModelAndProviderOnSelect/${model}/${provider}`,
);
// 仅在登录用户下记忆上次选择,保持与恢复策略一致
const isLogin = authSelectors.isLogin(useUserStore.getState());
if (isLogin) {
useGlobalStore.getState().updateSystemStatus({
lastSelectedImageModel: model,
lastSelectedImageProvider: provider,
});
}
},
setImageNum: (imageNum) => {
@@ -292,4 +326,47 @@ export const createGenerationConfigSlice: StateCreator<
reuseSeed: (seed: number) => {
set((state) => ({ parameters: { ...state.parameters, seed } }), false, `reuseSeed/${seed}`);
},
initializeImageConfig: (isLogin, lastSelectedImageModel, lastSelectedImageProvider) => {
// If no parameters are passed, get from store (backward compatibility)
let actualIsLogin = isLogin;
let actualLastSelectedImageModel = lastSelectedImageModel;
let actualLastSelectedImageProvider = lastSelectedImageProvider;
if (typeof isLogin === 'undefined') {
const globalStatus = useGlobalStore.getState().status;
actualIsLogin = authSelectors.isLogin(useUserStore.getState());
actualLastSelectedImageModel = globalStatus.lastSelectedImageModel;
actualLastSelectedImageProvider = globalStatus.lastSelectedImageProvider;
}
if (actualIsLogin && actualLastSelectedImageModel && actualLastSelectedImageProvider) {
try {
const { defaultValues, parametersSchema, initialActiveRatio } = prepareModelConfigState(
actualLastSelectedImageModel,
actualLastSelectedImageProvider,
);
set(
{
model: actualLastSelectedImageModel,
provider: actualLastSelectedImageProvider,
parameters: defaultValues,
parametersSchema,
isAspectRatioLocked: false,
activeAspectRatio: initialActiveRatio,
isInit: true,
},
false,
`initializeImageConfig/${actualLastSelectedImageModel}/${actualLastSelectedImageProvider}`,
);
} catch {
// If restoration fails, simply mark as initialized to use default configuration
set({ isInit: true }, false, 'initializeImageConfig/fallback');
}
} else {
// No remembered model, directly mark as initialized (use default values)
set({ isInit: true }, false, 'initializeImageConfig/default');
}
},
});
@@ -21,6 +21,11 @@ export interface GenerationConfigState {
isAspectRatioLocked: boolean;
activeAspectRatio: string | null; // string - 虚拟比例; null - 原生比例
/**
* 标记配置是否已初始化(包括从记忆中恢复)
*/
isInit: boolean;
}
export const DEFAULT_IMAGE_GENERATION_PARAMETERS: RuntimeImageGenParams =
@@ -34,4 +39,5 @@ export const initialGenerationConfigState: GenerationConfigState = {
parametersSchema: gptImage1ParamsSchema,
isAspectRatioLocked: false,
activeAspectRatio: null,
isInit: false,
};
+148
View File
@@ -0,0 +1,148 @@
import { ModelParamsSchema } from 'model-bank';
import { describe, expect, it } from 'vitest';
import { calculateInitialAspectRatio, supportsVirtualAspectRatio } from './aspectRatio';
// Test data fixtures
const createBaseSchema = (overrides: Partial<ModelParamsSchema> = {}): ModelParamsSchema => ({
prompt: { default: '' },
...overrides,
});
const createDimensionSchema = (overrides: Partial<ModelParamsSchema> = {}): ModelParamsSchema =>
createBaseSchema({
width: { default: 512, min: 256, max: 2048 },
height: { default: 512, min: 256, max: 2048 },
...overrides,
});
const createDefaultValues = (values: Record<string, any> = {}) => ({
prompt: '',
...values,
});
describe('aspectRatio utils', () => {
describe('calculateInitialAspectRatio', () => {
it('should return null when native aspect controls are present', () => {
// Models with native aspectRatio parameter
const aspectRatioSchema = createBaseSchema({
aspectRatio: { default: '1:1', enum: ['1:1', '16:9', '4:3'] },
});
const aspectRatioValues = createDefaultValues({ aspectRatio: '1:1' });
expect(calculateInitialAspectRatio(aspectRatioSchema, aspectRatioValues)).toBeNull();
// Models with native size parameter
const sizeSchema = createBaseSchema({
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
});
const sizeValues = createDefaultValues({ size: '1024x1024' });
expect(calculateInitialAspectRatio(sizeSchema, sizeValues)).toBeNull();
});
it('should return null when width or height parameters are missing', () => {
const schemaWithoutWidth = createBaseSchema({
height: { default: 512, min: 256, max: 2048 },
});
const valuesWithoutWidth = createDefaultValues({ height: 512 });
expect(calculateInitialAspectRatio(schemaWithoutWidth, valuesWithoutWidth)).toBeNull();
const schemaWithoutHeight = createBaseSchema({
width: { default: 512, min: 256, max: 2048 },
});
const valuesWithoutHeight = createDefaultValues({ width: 512 });
expect(calculateInitialAspectRatio(schemaWithoutHeight, valuesWithoutHeight)).toBeNull();
});
it('should calculate aspect ratio from width and height values', () => {
const schema = createDimensionSchema({
width: { default: 1024, min: 256, max: 2048 },
height: { default: 768, min: 256, max: 2048 },
});
const values = createDefaultValues({ width: 1024, height: 768 });
expect(calculateInitialAspectRatio(schema, values)).toBe('1024:768');
});
it('should handle square dimensions correctly', () => {
const schema = createDimensionSchema();
const values = createDefaultValues({ width: 512, height: 512 });
expect(calculateInitialAspectRatio(schema, values)).toBe('512:512');
});
it('should return fallback ratio for invalid dimension values', () => {
const schema = createDimensionSchema();
// Invalid values should fallback to 1:1
const testCases = [
{ width: NaN, height: NaN },
{ width: 0, height: 512 },
{ width: -512, height: 512 },
{ height: 512 }, // missing width
{ width: 512 }, // missing height
];
testCases.forEach((testCase) => {
const values = createDefaultValues(testCase);
expect(calculateInitialAspectRatio(schema, values)).toBe('1:1');
});
});
});
describe('supportsVirtualAspectRatio', () => {
it('should return true for models with width/height but no native aspect controls', () => {
const schema = createDimensionSchema();
expect(supportsVirtualAspectRatio(schema)).toBe(true);
});
it('should return false when native aspect controls are present', () => {
// Schema with native aspectRatio parameter
const aspectRatioSchema = createDimensionSchema({
aspectRatio: { default: '1:1', enum: ['1:1', '16:9', '4:3'] },
});
expect(supportsVirtualAspectRatio(aspectRatioSchema)).toBe(false);
// Schema with native size parameter
const sizeSchema = createDimensionSchema({
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
});
expect(supportsVirtualAspectRatio(sizeSchema)).toBe(false);
// Schema with both aspectRatio and size parameters
const bothSchema = createDimensionSchema({
aspectRatio: { default: '1:1', enum: ['1:1', '16:9', '4:3'] },
size: { default: '1024x1024', enum: ['512x512', '1024x1024', '1536x1536'] },
});
expect(supportsVirtualAspectRatio(bothSchema)).toBe(false);
});
it('should return false when required dimension parameters are missing', () => {
// Missing width parameter
const schemaWithoutWidth = createBaseSchema({
height: { default: 512, min: 256, max: 2048 },
});
expect(supportsVirtualAspectRatio(schemaWithoutWidth)).toBe(false);
// Missing height parameter
const schemaWithoutHeight = createBaseSchema({
width: { default: 512, min: 256, max: 2048 },
});
expect(supportsVirtualAspectRatio(schemaWithoutHeight)).toBe(false);
// Missing both width and height parameters
const emptySchema = createBaseSchema();
expect(supportsVirtualAspectRatio(emptySchema)).toBe(false);
});
});
});
+45
View File
@@ -0,0 +1,45 @@
import { ModelParamsSchema } from 'model-bank';
/**
* Calculate initial aspect ratio for image generation models
* @param parametersSchema - The model's parameter schema
* @param defaultValues - Default parameter values from the model
* @returns Initial aspect ratio string or null if not applicable
*/
export const calculateInitialAspectRatio = (
parametersSchema: ModelParamsSchema,
defaultValues: Record<string, any>,
): string | null => {
// If model has native aspect ratio or size parameters, don't use virtual ratio control
if (parametersSchema?.aspectRatio || parametersSchema?.size) {
return null;
}
// If model doesn't have width/height parameters, no virtual ratio needed
if (!parametersSchema?.width || !parametersSchema?.height) {
return null;
}
const { width, height } = defaultValues;
// Ensure we have valid numeric width and height values
if (typeof width === 'number' && typeof height === 'number' && width > 0 && height > 0) {
return `${width}:${height}`;
}
// Default fallback ratio
return '1:1';
};
/**
* Check if a model supports virtual aspect ratio control
* Virtual aspect ratio is enabled when model has width/height but no native aspect ratio/size controls
*/
export const supportsVirtualAspectRatio = (parametersSchema: ModelParamsSchema): boolean => {
return (
!parametersSchema?.aspectRatio &&
!parametersSchema?.size &&
!!parametersSchema?.width &&
!!parametersSchema?.height
);
};