♻️ refactor: refactor client mode upload to match server mode (#5111)

* ♻️ refactor: refactor upload method

* fix tests

*  test: add tests

* 🐛 fix: fix image
This commit is contained in:
Arvin Xu
2024-12-21 21:31:42 +08:00
committed by GitHub
parent 9c51c689ec
commit 0361ced7c2
11 changed files with 472 additions and 106 deletions
+16 -2
View File
@@ -1,5 +1,6 @@
import { DBModel } from '@/database/_deprecated/core/types/db';
import { DB_File, DB_FileSchema } from '@/database/_deprecated/schemas/files';
import { clientS3Storage } from '@/services/file/ClientS3';
import { nanoid } from '@/utils/uuid';
import { BaseModel } from '../core';
@@ -20,9 +21,15 @@ class _FileModel extends BaseModel<'files'> {
if (!item) return;
// arrayBuffer to url
const base64 = Buffer.from(item.data!).toString('base64');
let base64;
if (!item.data) {
const hash = (item.url as string).replace('client-s3://', '');
base64 = await this.getBase64ByFileHash(hash);
} else {
base64 = Buffer.from(item.data).toString('base64');
}
return { ...item, url: `data:${item.fileType};base64,${base64}` };
return { ...item, base64, url: `data:${item.fileType};base64,${base64}` };
}
async delete(id: string) {
@@ -32,6 +39,13 @@ class _FileModel extends BaseModel<'files'> {
async clear() {
return this.table.clear();
}
private async getBase64ByFileHash(hash: string) {
const fileItem = await clientS3Storage.getObject(hash);
if (!fileItem) throw new Error('file not found');
return Buffer.from(await fileItem.arrayBuffer()).toString('base64');
}
}
export const FileModel = new _FileModel();
+1 -3
View File
@@ -32,9 +32,7 @@ export const fileRouter = router({
}),
createFile: fileProcedure
.input(
UploadFileSchema.omit({ data: true, saveMode: true, url: true }).extend({ url: z.string() }),
)
.input(UploadFileSchema.omit({ url: true }).extend({ url: z.string() }))
.mutation(async ({ ctx, input }) => {
const { isExist } = await ctx.fileModel.checkHash(input.hash!);
+175
View File
@@ -0,0 +1,175 @@
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { fileEnv } from '@/config/file';
import { edgeClient } from '@/libs/trpc/client';
import { API_ENDPOINTS } from '@/services/_url';
import { clientS3Storage } from '@/services/file/ClientS3';
import { UPLOAD_NETWORK_ERROR, uploadService } from '../upload';
// Mock dependencies
vi.mock('@/libs/trpc/client', () => ({
edgeClient: {
upload: {
createS3PreSignedUrl: {
mutate: vi.fn(),
},
},
},
}));
vi.mock('@/services/file/ClientS3', () => ({
clientS3Storage: {
putObject: vi.fn(),
},
}));
vi.mock('@/utils/uuid', () => ({
uuid: () => 'mock-uuid',
}));
describe('UploadService', () => {
const mockFile = new File(['test'], 'test.png', { type: 'image/png' });
const mockPreSignUrl = 'https://example.com/presign';
beforeEach(() => {
vi.clearAllMocks();
// Mock Date.now
vi.spyOn(Date, 'now').mockImplementation(() => 3600000); // 1 hour in milliseconds
});
describe('uploadWithProgress', () => {
beforeEach(() => {
// Mock XMLHttpRequest
const xhrMock = {
upload: {
addEventListener: vi.fn(),
},
open: vi.fn(),
send: vi.fn(),
setRequestHeader: vi.fn(),
addEventListener: vi.fn(),
status: 200,
};
global.XMLHttpRequest = vi.fn(() => xhrMock) as any;
// Mock createS3PreSignedUrl
(edgeClient.upload.createS3PreSignedUrl.mutate as any).mockResolvedValue(mockPreSignUrl);
});
it('should upload file successfully with progress', async () => {
const onProgress = vi.fn();
const xhr = new XMLHttpRequest();
// Simulate successful upload
vi.spyOn(xhr, 'addEventListener').mockImplementation((event, handler) => {
if (event === 'load') {
// @ts-ignore
handler({ target: { status: 200 } });
}
});
const result = await uploadService.uploadWithProgress(mockFile, { onProgress });
expect(result).toEqual({
date: '1',
dirname: `${fileEnv.NEXT_PUBLIC_S3_FILE_PATH}/1`,
filename: 'mock-uuid.png',
path: `${fileEnv.NEXT_PUBLIC_S3_FILE_PATH}/1/mock-uuid.png`,
});
});
it('should handle network error', async () => {
const xhr = new XMLHttpRequest();
// Simulate network error
vi.spyOn(xhr, 'addEventListener').mockImplementation((event, handler) => {
if (event === 'error') {
Object.assign(xhr, { status: 0 });
// @ts-ignore
handler({});
}
});
await expect(uploadService.uploadWithProgress(mockFile, {})).rejects.toBe(
UPLOAD_NETWORK_ERROR,
);
});
it('should handle upload error', async () => {
const xhr = new XMLHttpRequest();
// Simulate upload error
vi.spyOn(xhr, 'addEventListener').mockImplementation((event, handler) => {
if (event === 'load') {
Object.assign(xhr, { status: 400, statusText: 'Bad Request' });
// @ts-ignore
handler({});
}
});
await expect(uploadService.uploadWithProgress(mockFile, {})).rejects.toBe('Bad Request');
});
});
describe('uploadToClientS3', () => {
it('should upload file to client S3 successfully', async () => {
const hash = 'test-hash';
const expectedResult = {
date: '1',
dirname: '',
filename: mockFile.name,
path: `client-s3://${hash}`,
};
(clientS3Storage.putObject as any).mockResolvedValue(undefined);
const result = await uploadService.uploadToClientS3(hash, mockFile);
expect(clientS3Storage.putObject).toHaveBeenCalledWith(hash, mockFile);
expect(result).toEqual(expectedResult);
});
});
describe('getImageFileByUrlWithCORS', () => {
beforeEach(() => {
global.fetch = vi.fn();
});
it('should fetch and create file from URL', async () => {
const url = 'https://example.com/image.png';
const filename = 'test.png';
const mockArrayBuffer = new ArrayBuffer(8);
(global.fetch as any).mockResolvedValue({
arrayBuffer: () => Promise.resolve(mockArrayBuffer),
});
const result = await uploadService.getImageFileByUrlWithCORS(url, filename);
expect(global.fetch).toHaveBeenCalledWith(API_ENDPOINTS.proxy, {
body: url,
method: 'POST',
});
expect(result).toBeInstanceOf(File);
expect(result.name).toBe(filename);
expect(result.type).toBe('image/png');
});
it('should handle custom file type', async () => {
const url = 'https://example.com/image.jpg';
const filename = 'test.jpg';
const fileType = 'image/jpeg';
const mockArrayBuffer = new ArrayBuffer(8);
(global.fetch as any).mockResolvedValue({
arrayBuffer: () => Promise.resolve(mockArrayBuffer),
});
const result = await uploadService.getImageFileByUrlWithCORS(url, filename, fileType);
expect(result.type).toBe(fileType);
});
});
});
+115
View File
@@ -0,0 +1,115 @@
import { createStore, del, get, set } from 'idb-keyval';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { BrowserS3Storage } from './index';
// Mock idb-keyval
vi.mock('idb-keyval', () => ({
createStore: vi.fn(),
set: vi.fn(),
get: vi.fn(),
del: vi.fn(),
}));
let storage: BrowserS3Storage;
let mockStore = {};
beforeEach(() => {
// Reset all mocks before each test
vi.clearAllMocks();
mockStore = {};
(createStore as any).mockReturnValue(mockStore);
storage = new BrowserS3Storage();
});
describe('BrowserS3Storage', () => {
describe('constructor', () => {
it('should create store when in browser environment', () => {
expect(createStore).toHaveBeenCalledWith('lobechat-local-s3', 'objects');
});
});
describe('putObject', () => {
it('should successfully put a file object', async () => {
const mockFile = new File(['test content'], 'test.txt', { type: 'text/plain' });
const mockArrayBuffer = new ArrayBuffer(8);
vi.spyOn(mockFile, 'arrayBuffer').mockResolvedValue(mockArrayBuffer);
(set as any).mockResolvedValue(undefined);
await storage.putObject('1-test-key', mockFile);
expect(set).toHaveBeenCalledWith(
'1-test-key',
{
data: mockArrayBuffer,
name: 'test.txt',
type: 'text/plain',
},
mockStore,
);
});
it('should throw error when put operation fails', async () => {
const mockFile = new File(['test content'], 'test.txt', { type: 'text/plain' });
const mockError = new Error('Storage error');
(set as any).mockRejectedValue(mockError);
await expect(storage.putObject('test-key', mockFile)).rejects.toThrow(
'Failed to put file test.txt: Storage error',
);
});
});
describe('getObject', () => {
it('should successfully get a file object', async () => {
const mockData = {
data: new ArrayBuffer(8),
name: 'test.txt',
type: 'text/plain',
};
(get as any).mockResolvedValue(mockData);
const result = await storage.getObject('test-key');
expect(result).toBeInstanceOf(File);
expect(result?.name).toBe('test.txt');
expect(result?.type).toBe('text/plain');
});
it('should return undefined when file not found', async () => {
(get as any).mockResolvedValue(undefined);
const result = await storage.getObject('test-key');
expect(result).toBeUndefined();
});
it('should throw error when get operation fails', async () => {
const mockError = new Error('Storage error');
(get as any).mockRejectedValue(mockError);
await expect(storage.getObject('test-key')).rejects.toThrow(
'Failed to get object (key=test-key): Storage error',
);
});
});
describe('deleteObject', () => {
it('should successfully delete a file object', async () => {
(del as any).mockResolvedValue(undefined);
await storage.deleteObject('test-key2');
expect(del).toHaveBeenCalledWith('test-key2', {});
});
it('should throw error when delete operation fails', async () => {
const mockError = new Error('Storage error');
(del as any).mockRejectedValue(mockError);
await expect(storage.deleteObject('test-key')).rejects.toThrow(
'Failed to delete object (key=test-key): Storage error',
);
});
});
});
+58
View File
@@ -0,0 +1,58 @@
import { createStore, del, get, set } from 'idb-keyval';
const BROWSER_S3_DB_NAME = 'lobechat-local-s3';
export class BrowserS3Storage {
private store;
constructor() {
// skip server-side rendering
if (typeof window === 'undefined') return;
this.store = createStore(BROWSER_S3_DB_NAME, 'objects');
}
/**
* 上传文件
* @param key 文件 hash
* @param file File 对象
*/
async putObject(key: string, file: File): Promise<void> {
try {
const data = await file.arrayBuffer();
await set(key, { data, name: file.name, type: file.type }, this.store);
} catch (e) {
throw new Error(`Failed to put file ${file.name}: ${(e as Error).message}`);
}
}
/**
* 获取文件
* @param key 文件 hash
* @returns File 对象
*/
async getObject(key: string): Promise<File | undefined> {
try {
const res = await get<{ data: ArrayBuffer; name: string; type: string }>(key, this.store);
if (!res) return;
return new File([res.data], res!.name, { type: res?.type });
} catch (e) {
throw new Error(`Failed to get object (key=${key}): ${(e as Error).message}`);
}
}
/**
* 删除文件
* @param key 文件 hash
*/
async deleteObject(key: string): Promise<void> {
try {
await del(key, this.store);
} catch (e) {
throw new Error(`Failed to delete object (key=${key}): ${(e as Error).message}`);
}
}
}
export const clientS3Storage = new BrowserS3Storage();
+9 -4
View File
@@ -3,6 +3,7 @@ import { Mock, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest';
import { fileEnv } from '@/config/file';
import { FileModel } from '@/database/_deprecated/models/file';
import { DB_File } from '@/database/_deprecated/schemas/files';
import { clientS3Storage } from '@/services/file/ClientS3';
import { serverConfigSelectors } from '@/store/serverConfig/selectors';
import { createServerConfigStore } from '@/store/serverConfig/store';
@@ -45,19 +46,23 @@ beforeEach(() => {
describe('FileService', () => {
it('createFile should save the file to the database', async () => {
const localFile: DB_File = {
const localFile = {
name: 'test',
data: new ArrayBuffer(1),
fileType: 'image/png',
saveMode: 'local',
url: 'client-s3://123',
size: 1,
hash: '123',
};
await clientS3Storage.putObject(
'123',
new File([new ArrayBuffer(1)], 'test.png', { type: 'image/png' }),
);
(FileModel.create as Mock).mockResolvedValue(localFile);
const result = await fileService.createFile(localFile);
expect(FileModel.create).toHaveBeenCalledWith(localFile);
expect(result).toEqual({ url: 'data:image/png;base64,AA==' });
});
+36 -8
View File
@@ -1,16 +1,27 @@
import { FileModel } from '@/database/_deprecated/models/file';
import { DB_File } from '@/database/_deprecated/schemas/files';
import { FileItem } from '@/types/files';
import { clientS3Storage } from '@/services/file/ClientS3';
import { FileItem, UploadFileParams } from '@/types/files';
import { IFileService } from './type';
export class ClientService implements IFileService {
async createFile(file: DB_File) {
async createFile(file: UploadFileParams) {
// save to local storage
// we may want to save to a remote server later
const res = await FileModel.create(file);
// arrayBuffer to url
const base64 = Buffer.from(file.data!).toString('base64');
const res = await FileModel.create({
createdAt: Date.now(),
data: undefined,
fileHash: file.hash,
fileType: file.fileType,
metadata: file.metadata,
name: file.name,
saveMode: 'url',
size: file.size,
url: file.url,
} as any);
// get file to base64 url
const base64 = await this.getBase64ByFileHash(file.hash!);
return {
id: res.id,
@@ -18,14 +29,24 @@ export class ClientService implements IFileService {
};
}
// eslint-disable-next-line @typescript-eslint/no-unused-vars
async checkFileHash(_hash: string) {
return { isExist: false, metadata: {} };
}
async getFile(id: string): Promise<FileItem> {
const item = await FileModel.findById(id);
if (!item) {
throw new Error('file not found');
}
// arrayBuffer to url
const url = URL.createObjectURL(new Blob([item.data!], { type: item.fileType }));
// arrayBuffer to blob or base64 to blob
const blob = !!item.data
? new Blob([item.data!], { type: item.fileType })
: // @ts-ignore
new Blob([Buffer.from(item.base64!, 'base64')], { type: item.fileType });
const url = URL.createObjectURL(blob);
return {
createdAt: new Date(item.createdAt),
@@ -49,4 +70,11 @@ export class ClientService implements IFileService {
async removeAllFiles() {
return FileModel.clear();
}
private async getBase64ByFileHash(hash: string) {
const fileItem = await clientS3Storage.getObject(hash);
if (!fileItem) throw new Error('file not found');
return Buffer.from(await fileItem.arrayBuffer()).toString('base64');
}
}
+8 -16
View File
@@ -1,7 +1,8 @@
import { fileEnv } from '@/config/file';
import { edgeClient } from '@/libs/trpc/client';
import { API_ENDPOINTS } from '@/services/_url';
import { FileMetadata, UploadFileParams } from '@/types/files';
import { clientS3Storage } from '@/services/file/ClientS3';
import { FileMetadata } from '@/types/files';
import { FileUploadState, FileUploadStatus } from '@/types/files/upload';
import { uuid } from '@/utils/uuid';
@@ -66,23 +67,14 @@ class UploadService {
return result;
};
uploadToClientDB = async (params: UploadFileParams, file: File) => {
const { FileModel } = await import('@/database/_deprecated/models/file');
const fileArrayBuffer = await file.arrayBuffer();
// save to local storage
// we may want to save to a remote server later
const res = await FileModel.create({
createdAt: Date.now(),
...params,
data: fileArrayBuffer,
});
// arrayBuffer to url
const base64 = Buffer.from(fileArrayBuffer).toString('base64');
uploadToClientS3 = async (hash: string, file: File): Promise<FileMetadata> => {
await clientS3Storage.putObject(hash, file);
return {
id: res.id,
url: `data:${params.fileType};base64,${base64}`,
date: (Date.now() / 1000 / 60 / 60).toFixed(0),
dirname: '',
filename: file.name,
path: `client-s3://${hash}`,
};
};
@@ -2,6 +2,8 @@ import { act, renderHook } from '@testing-library/react';
import { describe, expect, it, vi } from 'vitest';
import { fileService } from '@/services/file';
import { ClientService } from '@/services/file/client';
import { messageService } from '@/services/message';
import { imageGenerationService } from '@/services/textToImage';
import { uploadService } from '@/services/upload';
import { chatSelectors } from '@/store/chat/selectors';
@@ -39,17 +41,23 @@ describe('chatToolSlice', () => {
vi.spyOn(uploadService, 'getImageFileByUrlWithCORS').mockResolvedValue(
new File(['1'], 'file.png', { type: 'image/png' }),
);
vi.spyOn(uploadService, 'uploadToClientDB').mockResolvedValue({} as any);
vi.spyOn(fileService, 'createFile').mockResolvedValue({ id: mockId, url: '' });
vi.spyOn(uploadService, 'uploadToClientS3').mockResolvedValue({} as any);
vi.spyOn(ClientService.prototype, 'createFile').mockResolvedValue({
id: mockId,
url: '',
});
vi.spyOn(result.current, 'toggleDallEImageLoading');
vi.spyOn(ClientService.prototype, 'checkFileHash').mockImplementation(async () => ({
isExist: false,
metadata: {},
}));
await act(async () => {
await result.current.generateImageFromPrompts(prompts, messageId);
});
// For each prompt, loading is toggled on and then off
expect(imageGenerationService.generateImage).toHaveBeenCalledTimes(prompts.length);
expect(uploadService.uploadToClientDB).toHaveBeenCalledTimes(prompts.length);
expect(uploadService.uploadToClientS3).toHaveBeenCalledTimes(prompts.length);
expect(result.current.toggleDallEImageLoading).toHaveBeenCalledTimes(prompts.length * 2);
});
});
@@ -75,6 +83,7 @@ describe('chatToolSlice', () => {
content: initialMessageContent,
}) as ChatMessage,
);
vi.spyOn(messageService, 'updateMessage').mockResolvedValueOnce(undefined);
await act(async () => {
await result.current.updateImageItem(messageId, updateFunction);
+33 -67
View File
@@ -6,14 +6,11 @@ import { message } from '@/components/AntdStaticMethods';
import { LOBE_CHAT_CLOUD } from '@/const/branding';
import { isServerMode } from '@/const/version';
import { fileService } from '@/services/file';
import { ServerService } from '@/services/file/server';
import { uploadService } from '@/services/upload';
import { FileMetadata, UploadFileItem } from '@/types/files';
import { FileStore } from '../../store';
const serverFileService = new ServerService();
interface UploadWithProgressParams {
file: File;
knowledgeBaseId?: string;
@@ -43,10 +40,6 @@ interface UploadWithProgressResult {
}
export interface FileUploadAction {
internal_uploadToClientDB: (
params: Omit<UploadWithProgressParams, 'knowledgeBaseId'>,
) => Promise<UploadWithProgressResult | undefined>;
internal_uploadToServer: (params: UploadWithProgressParams) => Promise<UploadWithProgressResult>;
uploadWithProgress: (
params: UploadWithProgressParams,
) => Promise<UploadWithProgressResult | undefined>;
@@ -57,51 +50,14 @@ export const createFileUploadSlice: StateCreator<
[['zustand/devtools', never]],
[],
FileUploadAction
> = (set, get) => ({
internal_uploadToClientDB: async ({ file, onStatusUpdate, skipCheckFileType }) => {
if (!skipCheckFileType && !file.type.startsWith('image')) {
onStatusUpdate?.({ id: file.name, type: 'removeFile' });
message.info({
content: t('upload.fileOnlySupportInServerMode', {
cloud: LOBE_CHAT_CLOUD,
ext: file.name.split('.').pop(),
ns: 'error',
}),
duration: 5,
});
return;
}
const fileArrayBuffer = await file.arrayBuffer();
const hash = sha256(fileArrayBuffer);
const data = await uploadService.uploadToClientDB(
{ fileType: file.type, hash, name: file.name, saveMode: 'local', size: file.size },
file,
);
onStatusUpdate?.({
id: file.name,
type: 'updateFile',
value: {
fileUrl: data.url,
id: data.id,
status: 'success',
uploadState: { progress: 100, restTime: 0, speed: 0 },
},
});
return data;
},
internal_uploadToServer: async ({ file, onStatusUpdate, knowledgeBaseId }) => {
> = () => ({
uploadWithProgress: async ({ file, onStatusUpdate, knowledgeBaseId, skipCheckFileType }) => {
const fileArrayBuffer = await file.arrayBuffer();
// 1. check file hash
const hash = sha256(fileArrayBuffer);
const checkStatus = await serverFileService.checkFileHash(hash);
const checkStatus = await fileService.checkFileHash(hash);
let metadata: FileMetadata;
// 2. if file exist, just skip upload
@@ -112,17 +68,37 @@ export const createFileUploadSlice: StateCreator<
type: 'updateFile',
value: { status: 'processing', uploadState: { progress: 100, restTime: 0, speed: 0 } },
});
} else {
// 2. if file don't exist, need upload files
metadata = await uploadService.uploadWithProgress(file, {
onProgress: (status, upload) => {
onStatusUpdate?.({
id: file.name,
type: 'updateFile',
value: { status: status === 'success' ? 'processing' : status, uploadState: upload },
}
// 2. if file don't exist, need upload files
else {
// if is server mode, upload to server s3, or upload to client s3
if (isServerMode) {
metadata = await uploadService.uploadWithProgress(file, {
onProgress: (status, upload) => {
onStatusUpdate?.({
id: file.name,
type: 'updateFile',
value: { status: status === 'success' ? 'processing' : status, uploadState: upload },
});
},
});
} else {
if (!skipCheckFileType && !file.type.startsWith('image')) {
onStatusUpdate?.({ id: file.name, type: 'removeFile' });
message.info({
content: t('upload.fileOnlySupportInServerMode', {
cloud: LOBE_CHAT_CLOUD,
ext: file.name.split('.').pop(),
ns: 'error',
}),
duration: 5,
});
},
});
return;
}
// Upload to the indexeddb in the browser
metadata = await uploadService.uploadToClientS3(hash, file);
}
}
// 3. use more powerful file type detector to get file type
@@ -138,12 +114,10 @@ export const createFileUploadSlice: StateCreator<
// 4. create file to db
const data = await fileService.createFile(
{
createdAt: Date.now(),
fileType,
hash,
metadata,
name: file.name,
saveMode: 'url',
size: file.size,
url: metadata.path,
},
@@ -163,12 +137,4 @@ export const createFileUploadSlice: StateCreator<
return data;
},
uploadWithProgress: async (payload) => {
const { internal_uploadToServer, internal_uploadToClientDB } = get();
if (isServerMode) return internal_uploadToServer(payload);
return internal_uploadToClientDB(payload);
},
});
+8 -2
View File
@@ -53,7 +53,6 @@ export const FileMetadataSchema = z.object({
export type FileMetadata = z.infer<typeof FileMetadataSchema>;
export const UploadFileSchema = z.object({
data: z.instanceof(ArrayBuffer).optional(),
/**
* file type
* @example 'image/png'
@@ -77,7 +76,6 @@ export const UploadFileSchema = z.object({
* local mean save the raw file into data
* url mean upload the file to a cdn and then save the url
*/
saveMode: z.enum(['local', 'url']),
/**
* file size
*/
@@ -89,3 +87,11 @@ export const UploadFileSchema = z.object({
});
export type UploadFileParams = z.infer<typeof UploadFileSchema>;
export interface CheckFileHashResult {
fileType?: string;
isExist: boolean;
metadata?: unknown;
size?: number;
url?: string;
}