♻️ refactor: refactor the db to context inject mode (#7255)

* refactor with new db init mode

* fix tests

* fix tests

* move the separate index

* fix tests

* fix tests

* fix db issue

* fix db

* refactor to clean

* Update index.ts

* fix error

* fix the exist inbox slug session

* fix the tests
This commit is contained in:
Arvin Xu
2025-04-01 20:53:03 +08:00
committed by GitHub
parent 3a52f5cf97
commit ffd0dbc7f5
29 changed files with 172 additions and 110 deletions
+20 -2
View File
@@ -1,6 +1,24 @@
// import { isDesktop } from '@/const/version';
import { getDBInstance } from '@/database/core/web-server';
import { LobeChatDatabase } from '@/database/type';
// import { getPgliteInstance } from './electron';
/**
* 懒加载数据库实例
* 避免每次模块导入时都初始化数据库
*/
let cachedDB: LobeChatDatabase | null = null;
export const getServerDB = async (): Promise<LobeChatDatabase> => {
// 如果已经有缓存的实例,直接返回
if (cachedDB) return cachedDB;
try {
// 根据环境选择合适的数据库实例
cachedDB = getDBInstance();
return cachedDB;
} catch (error) {
console.error('❌ Failed to initialize database:', error);
throw error;
}
};
export const serverDB = getDBInstance();
@@ -4,6 +4,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { LobeChatDatabase } from '@/database/type';
import { ModelProvider } from '@/libs/agent-runtime';
import { sleep } from '@/utils/sleep';
import { aiProviders, users } from '../../schemas';
import { AiProviderModel } from '../aiProvider';
@@ -96,6 +97,7 @@ describe('AiProviderModel', () => {
describe('query', () => {
it('should query ai providers for the user', async () => {
await aiProviderModel.create({ name: 'AiHubMix', source: 'custom', id: 'aihubmix' });
await sleep(10);
await aiProviderModel.create({ name: 'AiHubMix', source: 'custom', id: 'aihubmix-2' });
const userGroups = await aiProviderModel.query();
+8
View File
@@ -193,6 +193,14 @@ export class SessionModel {
type: 'agent' | 'group';
}): Promise<SessionItem> => {
return this.db.transaction(async (trx) => {
if (slug) {
const existResult = await trx.query.sessions.findFirst({
where: and(eq(sessions.slug, slug), eq(sessions.userId, this.userId)),
});
if (existResult) return existResult;
}
const newAgents = await trx
.insert(agents)
.values({
+7 -1
View File
@@ -85,7 +85,13 @@ export class UserModel {
const state = result[0];
// Decrypt keyVaults
const decryptKeyVaults = await decryptor(state.settingsKeyVaults, this.userId);
let decryptKeyVaults = {};
try {
decryptKeyVaults = await decryptor(state.settingsKeyVaults, this.userId);
} catch {
/* empty */
}
const settings: DeepPartial<UserSettings> = {
defaultAgent: state.settingsDefaultAgent || {},
+1 -1
View File
@@ -1 +1 @@
export { serverDB } from '../core/db-adaptor';
export { getServerDB, serverDB } from '../core/db-adaptor';
+11 -1
View File
@@ -1,3 +1,5 @@
import { getServerDB } from '@/database/core/db-adaptor';
import { asyncAuth } from './asyncAuth';
import { asyncTrpc } from './init';
@@ -5,6 +7,14 @@ export const publicProcedure = asyncTrpc.procedure;
export const asyncRouter = asyncTrpc.router;
export const asyncAuthedProcedure = asyncTrpc.procedure.use(asyncAuth);
export const asyncAuthedProcedure = asyncTrpc.procedure.use(asyncAuth).use(
asyncTrpc.middleware(async (opts) => {
const serverDB = await getServerDB();
return opts.next({
ctx: { serverDB },
});
}),
);
export const createAsyncCallerFactory = asyncTrpc.createCallerFactory;
+1
View File
@@ -0,0 +1 @@
export * from './serverDatabase';
+10
View File
@@ -0,0 +1,10 @@
import { getServerDB } from '@/database/core/db-adaptor';
import { trpc } from '@/libs/trpc/init';
export const serverDatabase = trpc.middleware(async (opts) => {
const serverDB = await getServerDB();
return opts.next({
ctx: { serverDB },
});
});
+4 -5
View File
@@ -11,7 +11,6 @@ import { ChunkModel } from '@/database/models/chunk';
import { EmbeddingModel } from '@/database/models/embedding';
import { FileModel } from '@/database/models/file';
import { NewChunkItem, NewEmbeddingsItem } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { asyncAuthedProcedure, asyncRouter as router } from '@/libs/trpc/async';
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';
@@ -31,11 +30,11 @@ const fileProcedure = asyncAuthedProcedure.use(async (opts) => {
return opts.next({
ctx: {
asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
chunkModel: new ChunkModel(serverDB, ctx.userId),
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
chunkService: new ChunkService(ctx.userId),
embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
fileModel: new FileModel(serverDB, ctx.userId),
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
fileModel: new FileModel(ctx.serverDB, ctx.userId),
},
});
});
+3 -4
View File
@@ -7,7 +7,6 @@ import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
import { ChunkModel } from '@/database/models/chunk';
import { EmbeddingModel } from '@/database/models/embedding';
import { FileModel } from '@/database/models/file';
import { serverDB } from '@/database/server';
import {
EvalDatasetRecordModel,
EvalEvaluationModel,
@@ -25,13 +24,13 @@ const ragEvalProcedure = asyncAuthedProcedure.use(async (opts) => {
return opts.next({
ctx: {
chunkModel: new ChunkModel(serverDB, ctx.userId),
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
chunkService: new ChunkService(ctx.userId),
datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
evalRecordModel: new EvaluationRecordModel(ctx.userId),
evaluationModel: new EvalEvaluationModel(ctx.userId),
fileModel: new FileModel(serverDB, ctx.userId),
fileModel: new FileModel(ctx.serverDB, ctx.userId),
},
});
});
+3 -5
View File
@@ -2,16 +2,16 @@ import { z } from 'zod';
import { SessionGroupModel } from '@/database/models/sessionGroup';
import { insertSessionGroupSchema } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { SessionGroupItem } from '@/types/session';
const sessionProcedure = authedProcedure.use(async (opts) => {
const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId),
},
});
});
@@ -69,8 +69,6 @@ export const sessionGroupRouter = router({
}),
)
.mutation(async ({ input, ctx }) => {
console.log('sortMap:', input.sortMap);
return ctx.sessionGroupModel.updateOrder(input.sortMap);
}),
});
+8 -8
View File
@@ -7,22 +7,22 @@ import { FileModel } from '@/database/models/file';
import { KnowledgeBaseModel } from '@/database/models/knowledgeBase';
import { SessionModel } from '@/database/models/session';
import { UserModel } from '@/database/models/user';
import { serverDB } from '@/database/server';
import { pino } from '@/libs/logger';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { AgentService } from '@/server/services/agent';
import { KnowledgeItem, KnowledgeType } from '@/types/knowledgeBase';
const agentProcedure = authedProcedure.use(async (opts) => {
const agentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
agentModel: new AgentModel(serverDB, ctx.userId),
agentService: new AgentService(serverDB, ctx.userId),
fileModel: new FileModel(serverDB, ctx.userId),
knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId),
sessionModel: new SessionModel(serverDB, ctx.userId),
agentModel: new AgentModel(ctx.serverDB, ctx.userId),
agentService: new AgentService(ctx.serverDB, ctx.userId),
fileModel: new FileModel(ctx.serverDB, ctx.userId),
knowledgeBaseModel: new KnowledgeBaseModel(ctx.serverDB, ctx.userId),
sessionModel: new SessionModel(ctx.serverDB, ctx.userId),
},
});
});
@@ -90,7 +90,7 @@ export const agentRouter = router({
// if there is no session for user, create one
if (!item) {
// if there is no user, return default config
const user = await UserModel.findById(serverDB, ctx.userId);
const user = await UserModel.findById(ctx.serverDB, ctx.userId);
if (!user) return DEFAULT_AGENT_CONFIG;
const res = await ctx.agentService.createInbox();
+5 -5
View File
@@ -3,8 +3,8 @@ import { z } from 'zod';
import { AiModelModel } from '@/database/models/aiModel';
import { UserModel } from '@/database/models/user';
import { AiInfraRepos } from '@/database/repositories/aiInfra';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { getServerGlobalConfig } from '@/server/globalConfig';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import {
@@ -15,7 +15,7 @@ import {
} from '@/types/aiModel';
import { ProviderConfig } from '@/types/user/settings';
const aiModelProcedure = authedProcedure.use(async (opts) => {
const aiModelProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey();
@@ -24,13 +24,13 @@ const aiModelProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
aiInfraRepos: new AiInfraRepos(
serverDB,
ctx.serverDB,
ctx.userId,
aiProvider as Record<string, ProviderConfig>,
),
aiModelModel: new AiModelModel(serverDB, ctx.userId),
aiModelModel: new AiModelModel(ctx.serverDB, ctx.userId),
gateKeeper,
userModel: new UserModel(serverDB, ctx.userId),
userModel: new UserModel(ctx.serverDB, ctx.userId),
},
});
});
@@ -1,9 +1,7 @@
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { AiProviderModel } from '@/database/models/aiProvider';
import { UserModel } from '@/database/models/user';
import { AiInfraRepos } from '@/database/repositories/aiInfra';
import { serverDB } from '@/database/server';
import { getServerGlobalConfig } from '@/server/globalConfig';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { AiProviderDetailItem, AiProviderRuntimeState } from '@/types/aiProvider';
+5 -5
View File
@@ -3,8 +3,8 @@ import { z } from 'zod';
import { AiProviderModel } from '@/database/models/aiProvider';
import { UserModel } from '@/database/models/user';
import { AiInfraRepos } from '@/database/repositories/aiInfra';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { getServerGlobalConfig } from '@/server/globalConfig';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import {
@@ -16,7 +16,7 @@ import {
} from '@/types/aiProvider';
import { ProviderConfig } from '@/types/user/settings';
const aiProviderProcedure = authedProcedure.use(async (opts) => {
const aiProviderProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
const { aiProvider } = await getServerGlobalConfig();
@@ -25,13 +25,13 @@ const aiProviderProcedure = authedProcedure.use(async (opts) => {
return opts.next({
ctx: {
aiInfraRepos: new AiInfraRepos(
serverDB,
ctx.serverDB,
ctx.userId,
aiProvider as Record<string, ProviderConfig>,
),
aiProviderModel: new AiProviderModel(serverDB, ctx.userId),
aiProviderModel: new AiProviderModel(ctx.serverDB, ctx.userId),
gateKeeper,
userModel: new UserModel(serverDB, ctx.userId),
userModel: new UserModel(ctx.serverDB, ctx.userId),
},
});
});
+17 -14
View File
@@ -9,28 +9,31 @@ import { EmbeddingModel } from '@/database/models/embedding';
import { FileModel } from '@/database/models/file';
import { MessageModel } from '@/database/models/message';
import { knowledgeBaseFiles } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { keyVaults } from '@/libs/trpc/middleware/keyVaults';
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
import { initAgentRuntimeWithUserPayload } from '@/server/modules/AgentRuntime';
import { ChunkService } from '@/server/services/chunk';
import { SemanticSearchSchema } from '@/types/rag';
const chunkProcedure = authedProcedure.use(keyVaults).use(async (opts) => {
const { ctx } = opts;
const chunkProcedure = authedProcedure
.use(serverDatabase)
.use(keyVaults)
.use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
chunkModel: new ChunkModel(serverDB, ctx.userId),
chunkService: new ChunkService(ctx.userId),
embeddingModel: new EmbeddingModel(serverDB, ctx.userId),
fileModel: new FileModel(serverDB, ctx.userId),
messageModel: new MessageModel(serverDB, ctx.userId),
},
return opts.next({
ctx: {
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
chunkService: new ChunkService(ctx.userId),
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
fileModel: new FileModel(ctx.serverDB, ctx.userId),
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
},
});
});
});
export const chunkRouter = router({
createEmbeddingChunksTask: chunkProcedure
@@ -173,7 +176,7 @@ export const chunkRouter = router({
let finalFileIds = input.fileIds ?? [];
if (input.knowledgeIds && input.knowledgeIds.length > 0) {
const knowledgeFiles = await serverDB.query.knowledgeBaseFiles.findMany({
const knowledgeFiles = await ctx.serverDB.query.knowledgeBaseFiles.findMany({
where: inArray(knowledgeBaseFiles.knowledgeBaseId, input.knowledgeIds),
});
+4 -4
View File
@@ -1,13 +1,13 @@
import { DrizzleMigrationModel } from '@/database/models/drizzleMigration';
import { DataExporterRepos } from '@/database/repositories/dataExporter';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { ExportDatabaseData } from '@/types/export';
const exportProcedure = authedProcedure.use(async (opts) => {
const exportProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
const dataExporterRepos = new DataExporterRepos(serverDB, ctx.userId);
const drizzleMigration = new DrizzleMigrationModel(serverDB);
const dataExporterRepos = new DataExporterRepos(ctx.serverDB, ctx.userId);
const drizzleMigration = new DrizzleMigrationModel(ctx.serverDB);
return opts.next({
ctx: { dataExporterRepos, drizzleMigration },
+5 -5
View File
@@ -5,21 +5,21 @@ import { serverDBEnv } from '@/config/db';
import { AsyncTaskModel } from '@/database/models/asyncTask';
import { ChunkModel } from '@/database/models/chunk';
import { FileModel } from '@/database/models/file';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { S3 } from '@/server/modules/S3';
import { getFullFileUrl } from '@/server/utils/files';
import { AsyncTaskStatus, AsyncTaskType } from '@/types/asyncTask';
import { FileListItem, QueryFileListSchema, UploadFileSchema } from '@/types/files';
const fileProcedure = authedProcedure.use(async (opts) => {
const fileProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
asyncTaskModel: new AsyncTaskModel(serverDB, ctx.userId),
chunkModel: new ChunkModel(serverDB, ctx.userId),
fileModel: new FileModel(serverDB, ctx.userId),
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
fileModel: new FileModel(ctx.serverDB, ctx.userId),
},
});
});
+3 -3
View File
@@ -2,15 +2,15 @@ import { TRPCError } from '@trpc/server';
import { z } from 'zod';
import { DataImporterRepos } from '@/database/repositories/dataImporter';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { S3 } from '@/server/modules/S3';
import { ImportPgDataStructure } from '@/types/export';
import { ImportResultData, ImporterEntryData } from '@/types/importer';
const importProcedure = authedProcedure.use(async (opts) => {
const importProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
const dataImporterService = new DataImporterRepos(serverDB, ctx.userId);
const dataImporterService = new DataImporterRepos(ctx.serverDB, ctx.userId);
return opts.next({
ctx: { dataImporterService },
+3 -3
View File
@@ -2,16 +2,16 @@ import { z } from 'zod';
import { KnowledgeBaseModel } from '@/database/models/knowledgeBase';
import { insertKnowledgeBasesSchema } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { KnowledgeBaseItem } from '@/types/knowledgeBase';
const knowledgeBaseProcedure = authedProcedure.use(async (opts) => {
const knowledgeBaseProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
knowledgeBaseModel: new KnowledgeBaseModel(serverDB, ctx.userId),
knowledgeBaseModel: new KnowledgeBaseModel(ctx.serverDB, ctx.userId),
},
});
});
+5 -3
View File
@@ -2,19 +2,20 @@ import { z } from 'zod';
import { MessageModel } from '@/database/models/message';
import { updateMessagePluginSchema } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { getServerDB } from '@/database/server';
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { getFullFileUrl } from '@/server/utils/files';
import { ChatMessage } from '@/types/message';
import { BatchTaskResult } from '@/types/service';
type ChatMessageList = ChatMessage[];
const messageProcedure = authedProcedure.use(async (opts) => {
const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: { messageModel: new MessageModel(serverDB, ctx.userId) },
ctx: { messageModel: new MessageModel(ctx.serverDB, ctx.userId) },
});
});
@@ -95,6 +96,7 @@ export const messageRouter = router({
)
.query(async ({ input, ctx }) => {
if (!ctx.userId) return [];
const serverDB = await getServerDB();
const messageModel = new MessageModel(serverDB, ctx.userId);
+5 -3
View File
@@ -1,15 +1,16 @@
import { z } from 'zod';
import { PluginModel } from '@/database/models/plugin';
import { serverDB } from '@/database/server';
import { getServerDB } from '@/database/server';
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { LobeTool } from '@/types/tool';
const pluginProcedure = authedProcedure.use(async (opts) => {
const pluginProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: { pluginModel: new PluginModel(serverDB, ctx.userId) },
ctx: { pluginModel: new PluginModel(ctx.serverDB, ctx.userId) },
});
});
@@ -66,6 +67,7 @@ export const pluginRouter = router({
getPlugins: publicProcedure.query(async ({ ctx }): Promise<LobeTool[]> => {
if (!ctx.userId) return [];
const serverDB = await getServerDB();
const pluginModel = new PluginModel(serverDB, ctx.userId);
return pluginModel.query();
+16 -13
View File
@@ -7,7 +7,6 @@ import { z } from 'zod';
import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings';
import { FileModel } from '@/database/models/file';
import { serverDB } from '@/database/server';
import {
EvalDatasetModel,
EvalDatasetRecordModel,
@@ -15,6 +14,7 @@ import {
EvaluationRecordModel,
} from '@/database/server/models/ragEval';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { keyVaults } from '@/libs/trpc/middleware/keyVaults';
import { S3 } from '@/server/modules/S3';
import { createAsyncServerClient } from '@/server/routers/async';
@@ -29,20 +29,23 @@ import {
insertEvalEvaluationSchema,
} from '@/types/eval';
const ragEvalProcedure = authedProcedure.use(keyVaults).use(async (opts) => {
const { ctx } = opts;
const ragEvalProcedure = authedProcedure
.use(serverDatabase)
.use(keyVaults)
.use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
datasetModel: new EvalDatasetModel(ctx.userId),
fileModel: new FileModel(serverDB, ctx.userId),
datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
evaluationModel: new EvalEvaluationModel(ctx.userId),
evaluationRecordModel: new EvaluationRecordModel(ctx.userId),
s3: new S3(),
},
return opts.next({
ctx: {
datasetModel: new EvalDatasetModel(ctx.userId),
fileModel: new FileModel(ctx.serverDB, ctx.userId),
datasetRecordModel: new EvalDatasetRecordModel(ctx.userId),
evaluationModel: new EvalEvaluationModel(ctx.userId),
evaluationRecordModel: new EvaluationRecordModel(ctx.userId),
s3: new S3(),
},
});
});
});
export const ragEvalRouter = router({
createDataset: ragEvalProcedure
+6 -4
View File
@@ -3,21 +3,22 @@ import { z } from 'zod';
import { SessionModel } from '@/database/models/session';
import { SessionGroupModel } from '@/database/models/sessionGroup';
import { insertAgentSchema, insertSessionSchema } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { getServerDB } from '@/database/server';
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { AgentChatConfigSchema } from '@/types/agent';
import { LobeMetaDataSchema } from '@/types/meta';
import { BatchTaskResult } from '@/types/service';
import { ChatSessionList } from '@/types/session';
import { merge } from '@/utils/merge';
const sessionProcedure = authedProcedure.use(async (opts) => {
const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
sessionModel: new SessionModel(serverDB, ctx.userId),
sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId),
sessionModel: new SessionModel(ctx.serverDB, ctx.userId),
},
});
});
@@ -95,6 +96,7 @@ export const sessionRouter = router({
sessions: [],
};
const serverDB = await getServerDB();
const sessionModel = new SessionModel(serverDB, ctx.userId);
return sessionModel.queryWithGroups();
+3 -3
View File
@@ -2,16 +2,16 @@ import { z } from 'zod';
import { SessionGroupModel } from '@/database/models/sessionGroup';
import { insertSessionGroupSchema } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { SessionGroupItem } from '@/types/session';
const sessionProcedure = authedProcedure.use(async (opts) => {
const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
sessionGroupModel: new SessionGroupModel(serverDB, ctx.userId),
sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId),
},
});
});
+4 -4
View File
@@ -3,17 +3,17 @@ import { z } from 'zod';
import { MessageModel } from '@/database/models/message';
import { ThreadModel } from '@/database/models/thread';
import { insertThreadSchema } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { ThreadItem, createThreadSchema } from '@/types/topic/thread';
const threadProcedure = authedProcedure.use(async (opts) => {
const threadProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: {
messageModel: new MessageModel(serverDB, ctx.userId),
threadModel: new ThreadModel(serverDB, ctx.userId),
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
threadModel: new ThreadModel(ctx.serverDB, ctx.userId),
},
});
});
+5 -3
View File
@@ -1,15 +1,16 @@
import { z } from 'zod';
import { TopicModel } from '@/database/models/topic';
import { serverDB } from '@/database/server';
import { getServerDB } from '@/database/server';
import { authedProcedure, publicProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { BatchTaskResult } from '@/types/service';
const topicProcedure = authedProcedure.use(async (opts) => {
const topicProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
return opts.next({
ctx: { topicModel: new TopicModel(serverDB, ctx.userId) },
ctx: { topicModel: new TopicModel(ctx.serverDB, ctx.userId) },
});
});
@@ -101,6 +102,7 @@ export const topicRouter = router({
.query(async ({ input, ctx }) => {
if (!ctx.userId) return [];
const serverDB = await getServerDB();
const topicModel = new TopicModel(serverDB, ctx.userId);
return topicModel.query(input);
+7 -7
View File
@@ -5,10 +5,10 @@ import { enableClerk } from '@/const/auth';
import { MessageModel } from '@/database/models/message';
import { SessionModel } from '@/database/models/session';
import { UserModel, UserNotFoundError } from '@/database/models/user';
import { serverDB } from '@/database/server';
import { ClerkAuth } from '@/libs/clerk-auth';
import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter';
import { authedProcedure, router } from '@/libs/trpc';
import { serverDatabase } from '@/libs/trpc/lambda';
import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import { UserService } from '@/server/services/user';
import {
@@ -19,12 +19,12 @@ import {
} from '@/types/user';
import { UserSettings } from '@/types/user/settings';
const userProcedure = authedProcedure.use(async (opts) => {
return opts.next({
const userProcedure = authedProcedure.use(serverDatabase).use(async ({ ctx, next }) => {
return next({
ctx: {
clerkAuth: new ClerkAuth(),
nextAuthDbAdapter: LobeNextAuthDbAdapter(serverDB),
userModel: new UserModel(serverDB, opts.ctx.userId),
nextAuthDbAdapter: LobeNextAuthDbAdapter(ctx.serverDB),
userModel: new UserModel(ctx.serverDB, ctx.userId),
},
});
});
@@ -77,10 +77,10 @@ export const userRouter = router({
}
}
const messageModel = new MessageModel(serverDB, ctx.userId);
const messageModel = new MessageModel(ctx.serverDB, ctx.userId);
const hasMoreThan4Messages = await messageModel.hasMoreThanN(4);
const sessionModel = new SessionModel(serverDB, ctx.userId);
const sessionModel = new SessionModel(ctx.serverDB, ctx.userId);
const hasAnyMessages = await messageModel.hasMoreThanN(0);
const hasExtraSession = await sessionModel.hasMoreThanN(1);
@@ -6,7 +6,6 @@ import { UserModel } from '@/database/models/user';
import { UserItem } from '@/database/schemas';
import { serverDB } from '@/database/server';
import { pino } from '@/libs/logger';
import { LobeNextAuthDbAdapter } from '@/libs/next-auth/adapter';
import { NextAuthUserService } from './index';
@@ -23,7 +22,7 @@ vi.mock('@/database/server');
describe('NextAuthUserService', () => {
let service: NextAuthUserService;
beforeEach(() => {
beforeEach(async () => {
vi.clearAllMocks();
service = new NextAuthUserService();
});