mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-14 19:50:09 +00:00
♻️ refactor: refactor the db to context inject mode (#7255)
* refactor with new db init mode * fix tests * fix tests * move the separate index * fix tests * fix tests * fix db issue * fix db * refactor to clean * Update index.ts * fix error * fix the exist inbox slug session * fix the tests
This commit is contained in:
@@ -1,6 +1,24 @@
|
||||
// import { isDesktop } from '@/const/version';
|
||||
import { getDBInstance } from '@/database/core/web-server';
|
||||
import { LobeChatDatabase } from '@/database/type';
|
||||
|
||||
// import { getPgliteInstance } from './electron';
|
||||
/**
|
||||
* 懒加载数据库实例
|
||||
* 避免每次模块导入时都初始化数据库
|
||||
*/
|
||||
let cachedDB: LobeChatDatabase | null = null;
|
||||
|
||||
export const getServerDB = async (): Promise<LobeChatDatabase> => {
|
||||
// 如果已经有缓存的实例,直接返回
|
||||
if (cachedDB) return cachedDB;
|
||||
|
||||
try {
|
||||
// 根据环境选择合适的数据库实例
|
||||
cachedDB = getDBInstance();
|
||||
return cachedDB;
|
||||
} catch (error) {
|
||||
console.error('❌ Failed to initialize database:', error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
export const serverDB = getDBInstance();
|
||||
|
||||
@@ -4,6 +4,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { LobeChatDatabase } from '@/database/type';
|
||||
import { ModelProvider } from '@/libs/agent-runtime';
|
||||
import { sleep } from '@/utils/sleep';
|
||||
|
||||
import { aiProviders, users } from '../../schemas';
|
||||
import { AiProviderModel } from '../aiProvider';
|
||||
@@ -96,6 +97,7 @@ describe('AiProviderModel', () => {
|
||||
describe('query', () => {
|
||||
it('should query ai providers for the user', async () => {
|
||||
await aiProviderModel.create({ name: 'AiHubMix', source: 'custom', id: 'aihubmix' });
|
||||
await sleep(10);
|
||||
await aiProviderModel.create({ name: 'AiHubMix', source: 'custom', id: 'aihubmix-2' });
|
||||
|
||||
const userGroups = await aiProviderModel.query();
|
||||
|
||||
@@ -193,6 +193,14 @@ export class SessionModel {
|
||||
type: 'agent' | 'group';
|
||||
}): Promise<SessionItem> => {
|
||||
return this.db.transaction(async (trx) => {
|
||||
if (slug) {
|
||||
const existResult = await trx.query.sessions.findFirst({
|
||||
where: and(eq(sessions.slug, slug), eq(sessions.userId, this.userId)),
|
||||
});
|
||||
|
||||
if (existResult) return existResult;
|
||||
}
|
||||
|
||||
const newAgents = await trx
|
||||
.insert(agents)
|
||||
.values({
|
||||
|
||||
@@ -85,7 +85,13 @@ export class UserModel {
|
||||
const state = result[0];
|
||||
|
||||
// Decrypt keyVaults
|
||||
const decryptKeyVaults = await decryptor(state.settingsKeyVaults, this.userId);
|
||||
let decryptKeyVaults = {};
|
||||
|
||||
try {
|
||||
decryptKeyVaults = await decryptor(state.settingsKeyVaults, this.userId);
|
||||
} catch {
|
||||
/* empty */
|
||||
}
|
||||
|
||||
const settings: DeepPartial<UserSettings> = {
|
||||
defaultAgent: state.settingsDefaultAgent || {},
|
||||
|
||||
@@ -1 +1 @@
|
||||
export { serverDB } from '../core/db-adaptor';
|
||||
export { getServerDB, serverDB } from '../core/db-adaptor';
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import { getServerDB } from '@/database/core/db-adaptor';
|
||||
|
||||
import { asyncAuth } from './asyncAuth';
|
||||
import { asyncTrpc } from './init';
|
||||
|
||||
@@ -5,6 +7,14 @@ export const publicProcedure = asyncTrpc.procedure;
|
||||
|
||||
export const asyncRouter = asyncTrpc.router;
|
||||
|
||||
export const asyncAuthedProcedure = asyncTrpc.procedure.use(asyncAuth);
|
||||
export const asyncAuthedProcedure = asyncTrpc.procedure.use(asyncAuth).use(
|
||||
asyncTrpc.middleware(async (opts) => {
|
||||
const serverDB = await getServerDB();
|
||||
|
||||
return opts.next({
|
||||
ctx: { serverDB },
|
||||
});
|
||||
}),
|
||||
);
|
||||
|
||||
export const createAsyncCallerFactory = asyncTrpc.createCallerFactory;
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
export * from './serverDatabase';
|
||||
@@ -0,0 +1,10 @@
|
||||
import { getServerDB } from '@/database/core/db-adaptor';
|
||||
import { trpc } from '@/libs/trpc/init';
|
||||
|
||||
export const serverDatabase = trpc.middleware(async (opts) => {
|
||||
const serverDB = await getServerDB();
|
||||
|
||||
return opts.next({
|
||||
ctx: { serverDB },
|
||||
});
|
||||
});
|
||||
@@ -11,7 +11,6 @@ import { ChunkModel } from '@/database/models/chunk';
|
||||
import { EmbeddingModel } from '@/database/models/embedding';
|
||||
import { FileModel } from '@/database/models/file';
|
||||
import { NewChunkItem, NewEmbeddingsItem } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { asyncAuthedProcedure, asyncRouter as router } from '@/libs/trpc/async';
|
||||
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
|
||||
import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';
|
||||
@@ -31,11 +30,11 @@ const fileProcedure = asyncAuthedProcedure.use(async (opts) => {
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(serverDB, ctx.userId),
|
||||
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
|
||||
chunkService: new ChunkService(ctx.userId),
|
||||
embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
|
||||
fileModel: new FileModel(serverDB, ctx.userId),
|
||||
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -7,7 +7,6 @@ import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
|
||||
import { ChunkModel } from '@/database/models/chunk';
|
||||
import { EmbeddingModel } from '@/database/models/embedding';
|
||||
import { FileModel } from '@/database/models/file';
|
||||
import { serverDB } from '@/database/server';
|
||||
import {
|
||||
EvalDatasetRecordModel,
|
||||
EvalEvaluationModel,
|
||||
@@ -25,13 +24,13 @@ const ragEvalProcedure = asyncAuthedProcedure.use(async (opts) => {
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
chunkModel: new ChunkModel(serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
|
||||
chunkService: new ChunkService(ctx.userId),
|
||||
datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
|
||||
embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
|
||||
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
|
||||
evalRecordModel: new EvaluationRecordModel(ctx.userId),
|
||||
evaluationModel: new EvalEvaluationModel(ctx.userId),
|
||||
fileModel: new FileModel(serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2,16 +2,16 @@ import { z } from 'zod';
|
||||
|
||||
import { SessionGroupModel } from '@/database/models/sessionGroup';
|
||||
import { insertSessionGroupSchema } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { SessionGroupItem } from '@/types/session';
|
||||
|
||||
const sessionProcedure = authedProcedure.use(async (opts) => {
|
||||
const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
|
||||
sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -69,8 +69,6 @@ export const sessionGroupRouter = router({
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ input, ctx }) => {
|
||||
console.log('sortMap:', input.sortMap);
|
||||
|
||||
return ctx.sessionGroupModel.updateOrder(input.sortMap);
|
||||
}),
|
||||
});
|
||||
|
||||
@@ -7,22 +7,22 @@ import { FileModel } from '@/database/models/file';
|
||||
import { KnowledgeBaseModel } from '@/database/models/knowledgeBase';
|
||||
import { SessionModel } from '@/database/models/session';
|
||||
import { UserModel } from '@/database/models/user';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { pino } from '@/libs/logger';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { AgentService } from '@/server/services/agent';
|
||||
import { KnowledgeItem, KnowledgeType } from '@/types/knowledgeBase';
|
||||
|
||||
const agentProcedure = authedProcedure.use(async (opts) => {
|
||||
const agentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
agentModel: new AgentModel(serverDB, ctx.userId),
|
||||
agentService: new AgentService(serverDB, ctx.userId),
|
||||
fileModel: new FileModel(serverDB, ctx.userId),
|
||||
knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId),
|
||||
sessionModel: new SessionModel(serverDB, ctx.userId),
|
||||
agentModel: new AgentModel(ctx.serverDB, ctx.userId),
|
||||
agentService: new AgentService(ctx.serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
knowledgeBaseModel: new KnowledgeBaseModel(ctx.serverDB, ctx.userId),
|
||||
sessionModel: new SessionModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -90,7 +90,7 @@ export const agentRouter = router({
|
||||
// if there is no session for user, create one
|
||||
if (!item) {
|
||||
// if there is no user, return default config
|
||||
const user = await UserModel.findById(serverDB, ctx.userId);
|
||||
const user = await UserModel.findById(ctx.serverDB, ctx.userId);
|
||||
if (!user) return DEFAULT_AGENT_CONFIG;
|
||||
|
||||
const res = await ctx.agentService.createInbox();
|
||||
|
||||
@@ -3,8 +3,8 @@ import { z } from 'zod';
|
||||
import { AiModelModel } from '@/database/models/aiModel';
|
||||
import { UserModel } from '@/database/models/user';
|
||||
import { AiInfraRepos } from '@/database/repositories/aiInfra';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { getServerGlobalConfig } from '@/server/globalConfig';
|
||||
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
|
||||
import {
|
||||
@@ -15,7 +15,7 @@ import {
|
||||
} from '@/types/aiModel';
|
||||
import { ProviderConfig } from '@/types/user/settings';
|
||||
|
||||
const aiModelProcedure = authedProcedure.use(async (opts) => {
|
||||
const aiModelProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
|
||||
@@ -24,13 +24,13 @@ const aiModelProcedure = authedProcedure.use(async (opts) => {
|
||||
return opts.next({
|
||||
ctx: {
|
||||
aiInfraRepos: new AiInfraRepos(
|
||||
serverDB,
|
||||
ctx.serverDB,
|
||||
ctx.userId,
|
||||
aiProvider as Record<string, ProviderConfig>,
|
||||
),
|
||||
aiModelModel: new AiModelModel(serverDB, ctx.userId),
|
||||
aiModelModel: new AiModelModel(ctx.serverDB, ctx.userId),
|
||||
gateKeeper,
|
||||
userModel: new UserModel(serverDB, ctx.userId),
|
||||
userModel: new UserModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { AiProviderModel } from '@/database/models/aiProvider';
|
||||
import { UserModel } from '@/database/models/user';
|
||||
import { AiInfraRepos } from '@/database/repositories/aiInfra';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { getServerGlobalConfig } from '@/server/globalConfig';
|
||||
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
|
||||
import { AiProviderDetailItem, AiProviderRuntimeState } from '@/types/aiProvider';
|
||||
|
||||
@@ -3,8 +3,8 @@ import { z } from 'zod';
|
||||
import { AiProviderModel } from '@/database/models/aiProvider';
|
||||
import { UserModel } from '@/database/models/user';
|
||||
import { AiInfraRepos } from '@/database/repositories/aiInfra';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { getServerGlobalConfig } from '@/server/globalConfig';
|
||||
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
|
||||
import {
|
||||
@@ -16,7 +16,7 @@ import {
|
||||
} from '@/types/aiProvider';
|
||||
import { ProviderConfig } from '@/types/user/settings';
|
||||
|
||||
const aiProviderProcedure = authedProcedure.use(async (opts) => {
|
||||
const aiProviderProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
const { aiProvider } = await getServerGlobalConfig();
|
||||
@@ -25,13 +25,13 @@ const aiProviderProcedure = authedProcedure.use(async (opts) => {
|
||||
return opts.next({
|
||||
ctx: {
|
||||
aiInfraRepos: new AiInfraRepos(
|
||||
serverDB,
|
||||
ctx.serverDB,
|
||||
ctx.userId,
|
||||
aiProvider as Record<string, ProviderConfig>,
|
||||
),
|
||||
aiProviderModel: new AiProviderModel(serverDB, ctx.userId),
|
||||
aiProviderModel: new AiProviderModel(ctx.serverDB, ctx.userId),
|
||||
gateKeeper,
|
||||
userModel: new UserModel(serverDB, ctx.userId),
|
||||
userModel: new UserModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -9,28 +9,31 @@ import { EmbeddingModel } from '@/database/models/embedding';
|
||||
import { FileModel } from '@/database/models/file';
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
import { knowledgeBaseFiles } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { keyVaults } from '@/libs/trpc/middleware/keyVaults';
|
||||
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
|
||||
import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';
|
||||
import { ChunkService } from '@/server/services/chunk';
|
||||
import { SemanticSearchSchema } from '@/types/rag';
|
||||
|
||||
const chunkProcedure = authedProcedure.use(keyVaults).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
const chunkProcedure = authedProcedure
|
||||
.use(serverDatabase)
|
||||
.use(keyVaults)
|
||||
.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(serverDB, ctx.userId),
|
||||
chunkService: new ChunkService(ctx.userId),
|
||||
embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
|
||||
fileModel: new FileModel(serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(serverDB, ctx.userId),
|
||||
},
|
||||
return opts.next({
|
||||
ctx: {
|
||||
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
|
||||
chunkService: new ChunkService(ctx.userId),
|
||||
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
export const chunkRouter = router({
|
||||
createEmbeddingChunksTask: chunkProcedure
|
||||
@@ -173,7 +176,7 @@ export const chunkRouter = router({
|
||||
let finalFileIds = input.fileIds ?? [];
|
||||
|
||||
if (input.knowledgeIds && input.knowledgeIds.length > 0) {
|
||||
const knowledgeFiles = await serverDB.query.knowledgeBaseFiles.findMany({
|
||||
const knowledgeFiles = await ctx.serverDB.query.knowledgeBaseFiles.findMany({
|
||||
where: inArray(knowledgeBaseFiles.knowledgeBaseId, input.knowledgeIds),
|
||||
});
|
||||
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import { DrizzleMigrationModel } from '@/database/models/drizzleMigration';
|
||||
import { DataExporterRepos } from '@/database/repositories/dataExporter';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { ExportDatabaseData } from '@/types/export';
|
||||
|
||||
const exportProcedure = authedProcedure.use(async (opts) => {
|
||||
const exportProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
const dataExporterRepos = new DataExporterRepos(serverDB, ctx.userId);
|
||||
const drizzleMigration = new DrizzleMigrationModel(serverDB);
|
||||
const dataExporterRepos = new DataExporterRepos(ctx.serverDB, ctx.userId);
|
||||
const drizzleMigration = new DrizzleMigrationModel(ctx.serverDB);
|
||||
|
||||
return opts.next({
|
||||
ctx: { dataExporterRepos, drizzleMigration },
|
||||
|
||||
@@ -5,21 +5,21 @@ import { serverDBEnv } from '@/config/db';
|
||||
import { AsyncTaskModel } from '@/database/models/asyncTask';
|
||||
import { ChunkModel } from '@/database/models/chunk';
|
||||
import { FileModel } from '@/database/models/file';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { S3 } from '@/server/modules/S3';
|
||||
import { getFullFileUrl } from '@/server/utils/files';
|
||||
import { AsyncTaskStatus, AsyncTaskType } from '@/types/asyncTask';
|
||||
import { FileListItem, QueryFileListSchema, UploadFileSchema } from '@/types/files';
|
||||
|
||||
const fileProcedure = authedProcedure.use(async (opts) => {
|
||||
const fileProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(serverDB, ctx.userId),
|
||||
fileModel: new FileModel(serverDB, ctx.userId),
|
||||
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2,15 +2,15 @@ import { TRPCError } from '@trpc/server';
|
||||
import { z } from 'zod';
|
||||
|
||||
import { DataImporterRepos } from '@/database/repositories/dataImporter';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { S3 } from '@/server/modules/S3';
|
||||
import { ImportPgDataStructure } from '@/types/export';
|
||||
import { ImportResultData, ImporterEntryData } from '@/types/importer';
|
||||
|
||||
const importProcedure = authedProcedure.use(async (opts) => {
|
||||
const importProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
const dataImporterService = new DataImporterRepos(serverDB, ctx.userId);
|
||||
const dataImporterService = new DataImporterRepos(ctx.serverDB, ctx.userId);
|
||||
|
||||
return opts.next({
|
||||
ctx: { dataImporterService },
|
||||
|
||||
@@ -2,16 +2,16 @@ import { z } from 'zod';
|
||||
|
||||
import { KnowledgeBaseModel } from '@/database/models/knowledgeBase';
|
||||
import { insertKnowledgeBasesSchema } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { KnowledgeBaseItem } from '@/types/knowledgeBase';
|
||||
|
||||
const knowledgeBaseProcedure = authedProcedure.use(async (opts) => {
|
||||
const knowledgeBaseProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId),
|
||||
knowledgeBaseModel: new KnowledgeBaseModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2,19 +2,20 @@ import { z } from 'zod';
|
||||
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
import { updateMessagePluginSchema } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { getServerDB } from '@/database/server';
|
||||
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { getFullFileUrl } from '@/server/utils/files';
|
||||
import { ChatMessage } from '@/types/message';
|
||||
import { BatchTaskResult } from '@/types/service';
|
||||
|
||||
type ChatMessageList = ChatMessage[];
|
||||
|
||||
const messageProcedure = authedProcedure.use(async (opts) => {
|
||||
const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: { messageModel: new MessageModel(serverDB, ctx.userId) },
|
||||
ctx: { messageModel: new MessageModel(ctx.serverDB, ctx.userId) },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -95,6 +96,7 @@ export const messageRouter = router({
|
||||
)
|
||||
.query(async ({ input, ctx }) => {
|
||||
if (!ctx.userId) return [];
|
||||
const serverDB = await getServerDB();
|
||||
|
||||
const messageModel = new MessageModel(serverDB, ctx.userId);
|
||||
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import { PluginModel } from '@/database/models/plugin';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { getServerDB } from '@/database/server';
|
||||
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { LobeTool } from '@/types/tool';
|
||||
|
||||
const pluginProcedure = authedProcedure.use(async (opts) => {
|
||||
const pluginProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: { pluginModel: new PluginModel(serverDB, ctx.userId) },
|
||||
ctx: { pluginModel: new PluginModel(ctx.serverDB, ctx.userId) },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -66,6 +67,7 @@ export const pluginRouter = router({
|
||||
getPlugins: publicProcedure.query(async ({ ctx }): Promise<LobeTool[]> => {
|
||||
if (!ctx.userId) return [];
|
||||
|
||||
const serverDB = await getServerDB();
|
||||
const pluginModel = new PluginModel(serverDB, ctx.userId);
|
||||
|
||||
return pluginModel.query();
|
||||
|
||||
@@ -7,7 +7,6 @@ import { z } from 'zod';
|
||||
|
||||
import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
|
||||
import { FileModel } from '@/database/models/file';
|
||||
import { serverDB } from '@/database/server';
|
||||
import {
|
||||
EvalDatasetModel,
|
||||
EvalDatasetRecordModel,
|
||||
@@ -15,6 +14,7 @@ import {
|
||||
EvaluationRecordModel,
|
||||
} from '@/database/server/models/ragEval';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { keyVaults } from '@/libs/trpc/middleware/keyVaults';
|
||||
import { S3 } from '@/server/modules/S3';
|
||||
import { createAsyncServerClient } from '@/server/routers/async';
|
||||
@@ -29,20 +29,23 @@ import {
|
||||
insertEvalEvaluationSchema,
|
||||
} from '@/types/eval';
|
||||
|
||||
const ragEvalProcedure = authedProcedure.use(keyVaults).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
const ragEvalProcedure = authedProcedure
|
||||
.use(serverDatabase)
|
||||
.use(keyVaults)
|
||||
.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
datasetModel: new EvalDatasetModel(ctx.userId),
|
||||
fileModel: new FileModel(serverDB, ctx.userId),
|
||||
datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
|
||||
evaluationModel: new EvalEvaluationModel(ctx.userId),
|
||||
evaluationRecordModel: new EvaluationRecordModel(ctx.userId),
|
||||
s3: new S3(),
|
||||
},
|
||||
return opts.next({
|
||||
ctx: {
|
||||
datasetModel: new EvalDatasetModel(ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
|
||||
evaluationModel: new EvalEvaluationModel(ctx.userId),
|
||||
evaluationRecordModel: new EvaluationRecordModel(ctx.userId),
|
||||
s3: new S3(),
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
export const ragEvalRouter = router({
|
||||
createDataset: ragEvalProcedure
|
||||
|
||||
@@ -3,21 +3,22 @@ import { z } from 'zod';
|
||||
import { SessionModel } from '@/database/models/session';
|
||||
import { SessionGroupModel } from '@/database/models/sessionGroup';
|
||||
import { insertAgentSchema, insertSessionSchema } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { getServerDB } from '@/database/server';
|
||||
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { AgentChatConfigSchema } from '@/types/agent';
|
||||
import { LobeMetaDataSchema } from '@/types/meta';
|
||||
import { BatchTaskResult } from '@/types/service';
|
||||
import { ChatSessionList } from '@/types/session';
|
||||
import { merge } from '@/utils/merge';
|
||||
|
||||
const sessionProcedure = authedProcedure.use(async (opts) => {
|
||||
const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
|
||||
sessionModel: new SessionModel(serverDB, ctx.userId),
|
||||
sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId),
|
||||
sessionModel: new SessionModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -95,6 +96,7 @@ export const sessionRouter = router({
|
||||
sessions: [],
|
||||
};
|
||||
|
||||
const serverDB = await getServerDB();
|
||||
const sessionModel = new SessionModel(serverDB, ctx.userId);
|
||||
|
||||
return sessionModel.queryWithGroups();
|
||||
|
||||
@@ -2,16 +2,16 @@ import { z } from 'zod';
|
||||
|
||||
import { SessionGroupModel } from '@/database/models/sessionGroup';
|
||||
import { insertSessionGroupSchema } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { SessionGroupItem } from '@/types/session';
|
||||
|
||||
const sessionProcedure = authedProcedure.use(async (opts) => {
|
||||
const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
|
||||
sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,17 +3,17 @@ import { z } from 'zod';
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
import { ThreadModel } from '@/database/models/thread';
|
||||
import { insertThreadSchema } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { ThreadItem, createThreadSchema } from '@/types/topic/thread';
|
||||
|
||||
const threadProcedure = authedProcedure.use(async (opts) => {
|
||||
const threadProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
messageModel: new MessageModel(serverDB, ctx.userId),
|
||||
threadModel: new ThreadModel(serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
|
||||
threadModel: new ThreadModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import { z } from 'zod';
|
||||
|
||||
import { TopicModel } from '@/database/models/topic';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { getServerDB } from '@/database/server';
|
||||
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { BatchTaskResult } from '@/types/service';
|
||||
|
||||
const topicProcedure = authedProcedure.use(async (opts) => {
|
||||
const topicProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: { topicModel: new TopicModel(serverDB, ctx.userId) },
|
||||
ctx: { topicModel: new TopicModel(ctx.serverDB, ctx.userId) },
|
||||
});
|
||||
});
|
||||
|
||||
@@ -101,6 +102,7 @@ export const topicRouter = router({
|
||||
.query(async ({ input, ctx }) => {
|
||||
if (!ctx.userId) return [];
|
||||
|
||||
const serverDB = await getServerDB();
|
||||
const topicModel = new TopicModel(serverDB, ctx.userId);
|
||||
|
||||
return topicModel.query(input);
|
||||
|
||||
@@ -5,10 +5,10 @@ import { enableClerk } from '@/const/auth';
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
import { SessionModel } from '@/database/models/session';
|
||||
import { UserModel, UserNotFoundError } from '@/database/models/user';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { ClerkAuth } from '@/libs/clerk-auth';
|
||||
import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter';
|
||||
import { authedProcedure, router } from '@/libs/trpc';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda';
|
||||
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
|
||||
import { UserService } from '@/server/services/user';
|
||||
import {
|
||||
@@ -19,12 +19,12 @@ import {
|
||||
} from '@/types/user';
|
||||
import { UserSettings } from '@/types/user/settings';
|
||||
|
||||
const userProcedure = authedProcedure.use(async (opts) => {
|
||||
return opts.next({
|
||||
const userProcedure = authedProcedure.use(serverDatabase).use(async ({ ctx, next }) => {
|
||||
return next({
|
||||
ctx: {
|
||||
clerkAuth: new ClerkAuth(),
|
||||
nextAuthDbAdapter: LobeNextAuthDbAdapter(serverDB),
|
||||
userModel: new UserModel(serverDB, opts.ctx.userId),
|
||||
nextAuthDbAdapter: LobeNextAuthDbAdapter(ctx.serverDB),
|
||||
userModel: new UserModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
@@ -77,10 +77,10 @@ export const userRouter = router({
|
||||
}
|
||||
}
|
||||
|
||||
const messageModel = new MessageModel(serverDB, ctx.userId);
|
||||
const messageModel = new MessageModel(ctx.serverDB, ctx.userId);
|
||||
const hasMoreThan4Messages = await messageModel.hasMoreThanN(4);
|
||||
|
||||
const sessionModel = new SessionModel(serverDB, ctx.userId);
|
||||
const sessionModel = new SessionModel(ctx.serverDB, ctx.userId);
|
||||
const hasAnyMessages = await messageModel.hasMoreThanN(0);
|
||||
const hasExtraSession = await sessionModel.hasMoreThanN(1);
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import { UserModel } from '@/database/models/user';
|
||||
import { UserItem } from '@/database/schemas';
|
||||
import { serverDB } from '@/database/server';
|
||||
import { pino } from '@/libs/logger';
|
||||
import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter';
|
||||
|
||||
import { NextAuthUserService } from './index';
|
||||
|
||||
@@ -23,7 +22,7 @@ vi.mock('@/database/server');
|
||||
describe('NextAuthUserService', () => {
|
||||
let service: NextAuthUserService;
|
||||
|
||||
beforeEach(() => {
|
||||
beforeEach(async () => {
|
||||
vi.clearAllMocks();
|
||||
service = new NextAuthUserService();
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user