From ccb33fa48c56fbbc65eccec263c0fa753fc13e28 Mon Sep 17 00:00:00 2001 From: Rdmclin2 Date: Tue, 9 Jun 2026 15:54:26 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20workspace=20backend=20servi?= =?UTF-8?q?ce=20slice=20(#15560)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Backend-only slice of the workspace feature (server routers/services, database models with workspaceId threading, openapi middleware, business/server stubs, const/types). Excludes all UI (features/routes/store/hooks). Deploys dark behind the workspace feature flag. Includes open-source stub fixes: workspaceCreds router stub, ChargeParams workspaceId, usage.ts null-coalesce, DBMessageItem.workspaceId. Co-authored-by: Claude Opus 4.8 --- package.json | 4 +- packages/agent-signal/src/base/types.ts | 6 + packages/const/package.json | 1 + packages/const/src/index.ts | 1 + packages/const/src/rbac.ts | 238 ++++- packages/const/src/taskTemplate.ts | 13 + packages/const/src/workspace.ts | 10 + packages/database/src/index.ts | 2 + .../src/models/__tests__/agent.test.ts | 47 + .../models/__tests__/agentBotProvider.test.ts | 113 +- .../agentDocuments.workspace.test.ts | 89 ++ .../__tests__/agentEval.workspace.test.ts | 184 ++++ ...agentSignalReviewContext.workspace.test.ts | 108 ++ .../models/__tests__/agentTransfer.test.ts | 189 ++++ .../src/models/__tests__/chatGroup.test.ts | 43 +- .../src/models/__tests__/chunk.test.ts | 38 +- .../models/__tests__/documentTransfer.test.ts | 206 ++++ .../src/models/__tests__/generation.test.ts | 39 + .../models/__tests__/knowledgeBase.test.ts | 325 ++++-- .../__tests__/messages/message.update.test.ts | 83 ++ .../messages/message.workspace.test.ts | 73 ++ .../__tests__/messengerAccountLink.test.ts | 64 +- .../src/models/__tests__/notification.test.ts | 43 + .../__tests__/ragEval.workspace.test.ts | 204 ++++ .../src/models/__tests__/rbac.test.ts | 205 ++++ .../__tests__/session.workspace.test.ts | 79 ++ .../__tests__/topics/topic.create.test.ts | 160 ++- .../__tests__/topics/topic.query.test.ts | 44 + .../src/models/__tests__/workspace.test.ts | 282 +++++ packages/database/src/models/agent.ts | 385 +++++-- .../database/src/models/agentBotProvider.ts | 109 +- packages/database/src/models/agentCronJob.ts | 49 +- .../models/agentDocuments/agentDocument.ts | 98 +- .../src/models/agentEval/benchmark.ts | 78 +- .../database/src/models/agentEval/dataset.ts | 38 +- packages/database/src/models/agentEval/run.ts | 22 +- .../database/src/models/agentEval/runTopic.ts | 37 +- .../database/src/models/agentEval/testCase.ts | 40 +- .../database/src/models/agentOperation.ts | 15 +- .../src/models/agentSignal/nightlyReview.ts | 17 +- .../src/models/agentSignal/reviewContext.ts | 45 +- packages/database/src/models/agentSkill.ts | 32 +- packages/database/src/models/apiKey.ts | 27 +- packages/database/src/models/asyncTask.ts | 41 +- packages/database/src/models/brief.ts | 34 +- packages/database/src/models/chatGroup.ts | 74 +- packages/database/src/models/chunk.ts | 49 +- packages/database/src/models/connector.ts | 37 +- packages/database/src/models/connectorTool.ts | 60 +- packages/database/src/models/device.ts | 11 + packages/database/src/models/document.ts | 252 ++++- .../database/src/models/documentHistory.ts | 46 +- packages/database/src/models/documentShare.ts | 66 +- packages/database/src/models/embedding.ts | 28 +- packages/database/src/models/file.ts | 137 ++- packages/database/src/models/generation.ts | 24 +- .../database/src/models/generationBatch.ts | 30 +- .../database/src/models/generationTopic.ts | 30 +- packages/database/src/models/knowledgeBase.ts | 369 ++++++- .../src/models/llmGenerationTracing.ts | 19 +- packages/database/src/models/message.ts | 266 +++-- .../src/models/messengerAccountLink.ts | 63 +- packages/database/src/models/notification.ts | 24 +- packages/database/src/models/plugin.ts | 30 +- .../database/src/models/ragEval/dataset.ts | 27 +- .../src/models/ragEval/datasetRecord.ts | 41 +- .../database/src/models/ragEval/evaluation.ts | 34 +- .../src/models/ragEval/evaluationRecord.ts | 36 +- packages/database/src/models/rbac.ts | 171 ++- packages/database/src/models/recent.ts | 24 +- packages/database/src/models/session.ts | 181 ++-- packages/database/src/models/sessionGroup.ts | 29 +- packages/database/src/models/task.ts | 382 ++++++- packages/database/src/models/taskTopic.ts | 84 +- packages/database/src/models/thread.ts | 27 +- packages/database/src/models/topic.ts | 139 +-- packages/database/src/models/topicDocument.ts | 26 +- packages/database/src/models/topicShare.ts | 22 +- packages/database/src/models/topicUsage.ts | 15 +- .../src/models/userMemory/activity.ts | 29 +- .../database/src/models/userMemory/context.ts | 18 +- .../src/models/userMemory/experience.ts | 32 +- .../src/models/userMemory/identity.ts | 31 +- .../database/src/models/userMemory/model.ts | 149 +-- .../src/models/userMemory/preference.ts | 28 +- .../database/src/models/userMemory/query.ts | 52 +- packages/database/src/models/workspace.ts | 334 ++++++ .../database/src/models/workspaceAuditLog.ts | 99 ++ .../database/src/models/workspaceMember.ts | 133 +++ .../src/repositories/agentGroup/index.test.ts | 593 +++++++++++ .../src/repositories/agentGroup/index.ts | 466 ++++++++- .../src/repositories/agentMigration/index.ts | 59 +- .../src/repositories/compression/index.ts | 38 +- .../repositories/dataExporter/index.test.ts | 136 +++ .../src/repositories/dataExporter/index.ts | 33 +- .../dataImporter/deprecated/index.ts | 37 +- .../src/repositories/dataImporter/index.ts | 35 +- .../database/src/repositories/home/index.ts | 19 +- .../src/repositories/knowledge/index.test.ts | 73 ++ .../src/repositories/knowledge/index.ts | 41 +- .../database/src/repositories/search/index.ts | 31 +- .../src/repositories/topicImporter/index.ts | 9 +- .../userMemory/UserMemoryTopicRepository.ts | 7 +- packages/database/src/schemas/agentCronJob.ts | 2 +- .../database/src/schemas/agentOperations.ts | 2 +- packages/database/src/schemas/connector.ts | 2 + packages/database/src/schemas/device.ts | 8 + .../database/src/schemas/documentHistory.ts | 2 +- .../src/schemas/llmGenerationTracing.ts | 2 +- packages/database/src/schemas/relations.ts | 4 +- packages/database/src/schemas/topic.ts | 2 +- .../src/schemas/userMemories/index.ts | 1 - packages/database/src/utils/idGenerator.ts | 3 + .../database/src/utils/seedWorkspaceRoles.ts | 227 ++++ packages/database/src/utils/workspace.test.ts | 45 + packages/database/src/utils/workspace.ts | 70 ++ packages/openapi/src/app.ts | 2 + .../openapi/src/common/base.controller.ts | 17 +- .../openapi/src/common/base.service.test.ts | 158 +++ packages/openapi/src/common/base.service.ts | 95 +- .../src/controllers/agent-group.controller.ts | 30 +- .../src/controllers/agent.controller.ts | 10 +- .../src/controllers/chat.controller.ts | 6 +- .../src/controllers/file.controller.ts | 22 +- .../controllers/knowledge-base.controller.ts | 18 +- .../message-translation.controller.ts | 8 +- .../src/controllers/message.controller.ts | 14 +- .../src/controllers/model.controller.ts | 8 +- .../src/controllers/permission.controller.ts | 30 +- .../src/controllers/provider.controller.ts | 10 +- .../src/controllers/responses.controller.ts | 2 +- .../src/controllers/role.controller.ts | 16 +- .../src/controllers/topic.controller.ts | 10 +- .../src/controllers/user.controller.ts | 18 +- packages/openapi/src/middleware/auth.ts | 12 +- packages/openapi/src/middleware/index.ts | 1 + .../src/middleware/permission-check.ts | 5 +- .../openapi/src/middleware/workspace.test.ts | 142 +++ packages/openapi/src/middleware/workspace.ts | 60 ++ .../src/services/agent-group.service.ts | 25 +- .../openapi/src/services/agent.service.ts | 29 +- packages/openapi/src/services/chat.service.ts | 33 +- packages/openapi/src/services/file.service.ts | 56 +- .../src/services/knowledge-base.service.ts | 16 +- .../services/message-translations.service.ts | 30 +- .../openapi/src/services/message.service.ts | 37 +- .../openapi/src/services/model.service.ts | 23 +- .../src/services/permission.service.ts | 4 +- .../openapi/src/services/provider.service.ts | 36 +- .../openapi/src/services/responses.service.ts | 10 +- packages/openapi/src/services/role.service.ts | 49 +- .../openapi/src/services/topic.service.ts | 37 +- packages/openapi/src/services/user.service.ts | 46 +- packages/types/package.json | 2 +- packages/types/src/discover/assistants.ts | 4 + packages/types/src/discover/fork.ts | 10 + packages/types/src/discover/groupAgents.ts | 4 + packages/types/src/discover/index.ts | 1 + packages/types/src/document/index.ts | 2 + packages/types/src/files/list.ts | 1 + packages/types/src/message/db/item.ts | 1 + packages/types/src/task/index.ts | 2 + packages/types/src/topic/topic.ts | 1 + scripts/codemodWorkspaceNav.ts | 371 +++++++ .../extract/chat-topic/cancel/route.ts | 2 +- .../api/webhooks/video/[provider]/route.ts | 14 +- .../agent-eval-run/execute-test-case/route.ts | 4 +- .../agent-eval-run/finalize-run/route.ts | 10 +- .../on-thread-complete/route.ts | 6 +- .../on-trajectory-complete/route.ts | 6 +- .../paginate-test-cases/route.ts | 9 +- .../resume-agent-trajectory/route.ts | 4 +- .../resume-thread-trajectory/route.ts | 4 +- .../run-agent-trajectory/route.ts | 4 +- .../agent-eval-run/run-benchmark/route.ts | 7 +- .../run-thread-trajectory/route.ts | 4 +- .../market/social/[[...segments]]/route.ts | 7 +- .../oauth/connector/callback/route.ts | 2 +- .../(backend)/webapi/_utils/workspace.test.ts | 103 ++ src/app/(backend)/webapi/_utils/workspace.ts | 32 + .../webapi/chat/[provider]/route.test.ts | 1 + .../(backend)/webapi/chat/[provider]/route.ts | 6 +- .../webapi/models/[provider]/pull/route.ts | 6 +- .../webapi/models/[provider]/route.ts | 6 +- .../image-generation/chargeAfterGenerate.ts | 1 + .../image-generation/chargeBeforeGenerate.ts | 1 + src/business/server/lambda-routers/file.ts | 11 + .../server/lambda-routers/workspace.ts | 4 + .../lambda-routers/workspaceAuditLog.ts | 3 + .../server/lambda-routers/workspaceCredits.ts | 3 + .../server/lambda-routers/workspaceCreds.ts | 3 + .../server/lambda-routers/workspaceData.ts | 3 + .../server/lambda-routers/workspaceMember.ts | 3 + .../server/lambda-routers/workspaceUsage.ts | 3 + src/business/server/model-runtime.ts | 1 + .../server/trpc-middlewares/rbacPermission.ts | 28 + .../server/trpc-middlewares/workspaceAuth.ts | 22 + .../trpc-middlewares/workspaceContext.ts | 9 + .../video-generation/chargeAfterGenerate.ts | 1 + .../video-generation/chargeBeforeGenerate.ts | 1 + src/config/featureFlags/schema.test.ts | 10 + src/config/featureFlags/schema.ts | 3 + src/libs/better-auth/email-templates/index.ts | 2 + .../email-templates/workspace-invite.ts | 108 ++ .../workspace-member-removed.ts | 91 ++ src/libs/mcp/connectorPermissionCheck.ts | 6 +- src/libs/trpc/lambda/context.ts | 12 +- .../trpc/lambda/middleware/marketUserInfo.ts | 4 + src/libs/trusted-client/index.ts | 8 + src/locales/default/agent.ts | 18 + src/locales/default/auth.ts | 8 + src/locales/default/chat.ts | 16 + src/locales/default/discover.ts | 108 ++ src/locales/default/error.ts | 9 + src/locales/default/file.ts | 58 ++ src/locales/default/messenger.ts | 3 + src/locales/default/notification.ts | 30 + src/locales/default/setting.ts | 980 +++++++++++++++++- src/locales/default/subscription.ts | 39 + src/locales/default/topic.ts | 2 + src/locales/resources.ts | 10 +- .../handlers/__tests__/runStep.test.ts | 24 + src/server/agent-hono/handlers/execAgent.ts | 5 +- src/server/agent-hono/handlers/gatewayCron.ts | 8 +- src/server/agent-hono/handlers/runStep.ts | 7 +- .../AgentRuntime/AgentRuntimeCoordinator.ts | 1 + .../modules/AgentRuntime/AgentStateManager.ts | 11 + .../AgentRuntime/InMemoryAgentStateManager.ts | 2 + .../modules/AgentRuntime/RuntimeExecutors.ts | 95 +- .../__tests__/RuntimeExecutors.test.ts | 87 ++ src/server/modules/AgentRuntime/types.ts | 1 + .../modules/Mecha/ContextEngineering/index.ts | 2 + .../modules/Mecha/ContextEngineering/types.ts | 3 + src/server/modules/ModelRuntime/index.ts | 7 +- .../routers/async/__tests__/file.test.ts | 31 + .../routers/async/__tests__/ragEval.test.ts | 72 ++ src/server/routers/async/file.ts | 89 +- src/server/routers/async/image.ts | 27 +- src/server/routers/async/ragEval.ts | 38 +- src/server/routers/async/video.ts | 18 +- .../routers/lambda/__tests__/file.test.ts | 154 ++- .../lambda/__tests__/knowledgeBase.test.ts | 96 ++ .../__tests__/llmGenerationTracing.test.ts | 17 +- .../lambda/__tests__/messenger.test.ts | 121 ++- .../lambda/_helpers/resolveContext.test.ts | 45 + .../routers/lambda/_helpers/resolveContext.ts | 23 +- src/server/routers/lambda/agent.ts | 113 +- src/server/routers/lambda/agentBotProvider.ts | 35 +- src/server/routers/lambda/agentDocument.ts | 72 +- src/server/routers/lambda/agentEval.ts | 78 +- .../routers/lambda/agentEvalExternal.ts | 36 +- src/server/routers/lambda/agentGroup.ts | 131 ++- src/server/routers/lambda/agentNotify.ts | 16 +- src/server/routers/lambda/agentSignal.ts | 10 +- src/server/routers/lambda/agentSkills.ts | 50 +- src/server/routers/lambda/aiAgent.ts | 245 +++-- src/server/routers/lambda/aiChat.ts | 30 +- src/server/routers/lambda/aiModel.ts | 26 +- src/server/routers/lambda/aiProvider.ts | 21 +- src/server/routers/lambda/apiKey.ts | 20 +- src/server/routers/lambda/botMessage.ts | 30 +- src/server/routers/lambda/brief.ts | 41 +- src/server/routers/lambda/chunk.ts | 37 +- src/server/routers/lambda/comfyui.ts | 2 + src/server/routers/lambda/connector.ts | 14 +- src/server/routers/lambda/document.ts | 160 ++- src/server/routers/lambda/exporter.ts | 18 +- src/server/routers/lambda/file.ts | 226 +++- src/server/routers/lambda/followUpAction.ts | 12 +- src/server/routers/lambda/generation.ts | 14 +- src/server/routers/lambda/generationBatch.ts | 12 +- src/server/routers/lambda/generationTopic.ts | 17 +- src/server/routers/lambda/home.ts | 14 +- src/server/routers/lambda/image/index.test.ts | 42 +- src/server/routers/lambda/image/index.ts | 490 ++++----- src/server/routers/lambda/importer.ts | 19 +- src/server/routers/lambda/index.ts | 14 + src/server/routers/lambda/klavis.ts | 13 +- src/server/routers/lambda/knowledge.ts | 18 +- src/server/routers/lambda/knowledgeBase.ts | 149 ++- .../routers/lambda/llmGenerationTracing.ts | 17 +- .../routers/lambda/market/agent.test.ts | 98 ++ src/server/routers/lambda/market/agent.ts | 203 ++-- .../routers/lambda/market/agentGroup.test.ts | 120 +++ .../routers/lambda/market/agentGroup.ts | 86 +- src/server/routers/lambda/market/creds.ts | 52 +- src/server/routers/lambda/market/social.ts | 7 +- .../lambda/market/socialProfile.test.ts | 53 + .../routers/lambda/market/socialProfile.ts | 4 + src/server/routers/lambda/message.ts | 145 ++- src/server/routers/lambda/messenger.ts | 225 +++- src/server/routers/lambda/notebook.ts | 20 +- src/server/routers/lambda/notification.ts | 21 +- src/server/routers/lambda/oauthDeviceFlow.ts | 14 +- src/server/routers/lambda/plugin.ts | 30 +- src/server/routers/lambda/ragEval.ts | 40 +- src/server/routers/lambda/recent.ts | 7 +- src/server/routers/lambda/search.ts | 8 +- src/server/routers/lambda/session.ts | 79 +- src/server/routers/lambda/sessionGroup.ts | 21 +- src/server/routers/lambda/task.ts | 194 +++- src/server/routers/lambda/thread.ts | 55 +- src/server/routers/lambda/topic.ts | 61 +- src/server/routers/lambda/upload.ts | 2 + src/server/routers/lambda/usage.ts | 11 +- src/server/routers/lambda/user.ts | 62 +- src/server/routers/lambda/userMemories.ts | 24 +- src/server/routers/lambda/userMemory.ts | 38 +- src/server/routers/lambda/video/index.ts | 483 ++++----- src/server/routers/lambda/webBrowsing.ts | 11 +- src/server/routers/mobile/topic.ts | 32 +- src/server/routers/tools/klavis.ts | 10 +- src/server/routers/tools/market.ts | 23 +- src/server/routers/tools/mcp.ts | 12 +- src/server/services/agent/index.test.ts | 20 +- src/server/services/agent/index.ts | 8 +- src/server/services/agentDocumentVfs/index.ts | 6 +- .../mounts/skills/createSkillMount.ts | 10 +- src/server/services/agentDocuments/index.ts | 8 +- src/server/services/agentEvalRun/index.ts | 47 +- src/server/services/agentGroup/index.ts | 8 +- .../agentRuntime/AbandonOperationService.ts | 8 +- .../agentRuntime/AgentRuntimeService.ts | 24 +- .../agentRuntime/CompletionLifecycle.ts | 9 +- .../__tests__/executeStep.test.ts | 23 + src/server/services/agentRuntime/types.ts | 6 + .../__tests__/index.integration.test.ts | 74 ++ src/server/services/agentSignal/emitter.ts | 9 +- .../services/agentSignal/orchestrator.ts | 5 + .../__tests__/feedbackSatisfaction.test.ts | 7 +- .../__tests__/skillIntent.test.ts | 7 +- .../actions/__tests__/skillManagement.test.ts | 17 + .../analyzeIntent/actions/skillManagement.ts | 3 + .../analyzeIntent/actions/userMemory.ts | 14 +- .../policies/analyzeIntent/feedbackDomain.ts | 10 +- .../analyzeIntent/feedbackDomainAgent.ts | 4 + .../analyzeIntent/feedbackSatisfaction.ts | 18 +- .../policies/analyzeIntent/skillIntent.ts | 4 + .../services/briefs/selfReview.test.ts | 2 +- .../agentSignal/services/briefs/selfReview.ts | 21 +- .../dispatch/enqueueSelfIterationRun.ts | 8 +- .../selfIteration/feedback/handler.ts | 3 + .../services/selfIteration/feedback/server.ts | 4 +- .../selfIteration/reflection/handler.ts | 3 + .../selfIteration/reflection/server.ts | 8 +- .../services/selfIteration/review/brief.ts | 8 +- .../services/selfIteration/review/handler.ts | 3 + .../services/selfIteration/review/server.ts | 12 +- .../services/selfIteration/server.ts | 2 + .../selfIteration/tools/runtimePrimitives.ts | 5 +- .../hydration/clientRuntimeComplete.ts | 4 +- .../sources/hydration/clientRuntimeStart.ts | 4 +- src/server/services/aiAgent/index.ts | 53 +- src/server/services/aiChat/index.ts | 8 +- src/server/services/aiGeneration/index.ts | 8 +- src/server/services/bot/AgentBridgeService.ts | 23 +- src/server/services/bot/BotCallbackService.ts | 44 +- src/server/services/bot/BotMessageRouter.ts | 16 +- .../bot/__tests__/AgentBridgeService.test.ts | 29 + .../bot/__tests__/BotMessageRouter.test.ts | 24 +- src/server/services/brief/index.ts | 12 +- src/server/services/chunk/index.ts | 34 +- src/server/services/discover/index.ts | 9 +- src/server/services/document/history.ts | 24 +- src/server/services/document/index.ts | 38 +- .../services/file/__tests__/index.test.ts | 8 + .../file/extractFileIdsFromEditorData.ts | 7 +- src/server/services/file/index.ts | 4 +- .../services/file/resolveAttachments.ts | 13 +- .../services/followUpAction/index.test.ts | 29 + src/server/services/followUpAction/index.ts | 10 +- .../services/gateway/GatewayManager.test.ts | 60 +- src/server/services/gateway/GatewayManager.ts | 13 +- .../gateway/__tests__/GatewayManager.test.ts | 11 +- src/server/services/gateway/index.ts | 44 +- src/server/services/generation/index.ts | 4 +- src/server/services/generation/video.ts | 4 +- .../generation/videoBackgroundPolling.ts | 12 +- .../services/heterogeneousAgent/index.ts | 12 +- src/server/services/klavis/index.ts | 16 +- src/server/services/knowledgeBase/index.ts | 26 +- .../services/llmGenerationTracing/hook.ts | 2 + .../services/llmGenerationTracing/index.ts | 6 +- src/server/services/market/index.ts | 65 +- .../services/memory/userMemory/extract.ts | 51 +- .../memory/userMemory/persona/service.ts | 3 + src/server/services/message/index.ts | 8 +- .../messenger/MessengerRouter.test.ts | 195 ++++ .../services/messenger/MessengerRouter.ts | 269 ++++- .../platforms/discord/binder.test.ts | 14 + .../messenger/platforms/discord/binder.ts | 11 +- .../messenger/platforms/slack/binder.ts | 15 +- .../messenger/platforms/telegram/binder.ts | 10 +- src/server/services/messenger/types.ts | 17 +- src/server/services/notebook/index.ts | 15 +- src/server/services/skill/importer.ts | 8 +- src/server/services/skill/resource.ts | 4 +- .../SkillManagementDocumentService.test.ts | 1 + .../SkillManagementDocumentService.ts | 6 +- src/server/services/systemAgent/index.ts | 11 +- src/server/services/task/index.ts | 42 +- src/server/services/taskGraph/index.ts | 4 +- src/server/services/taskLifecycle/index.ts | 40 +- src/server/services/taskReview/index.ts | 11 +- .../services/taskRunner/buildTaskPrompt.ts | 6 +- .../services/taskRunner/heartbeatTick.ts | 17 +- src/server/services/taskRunner/index.ts | 20 +- .../services/taskRunner/scheduleTick.test.ts | 38 +- .../services/taskRunner/scheduleTick.ts | 21 +- .../services/taskTemplate/index.test.ts | 70 +- src/server/services/taskTemplate/index.ts | 33 +- .../toolExecution/archiveToolResult.ts | 6 +- src/server/services/toolExecution/builtin.ts | 1 + src/server/services/toolExecution/index.ts | 1 + .../__tests__/agentDocuments.test.ts | 20 +- .../__tests__/agentManagement.test.ts | 18 + .../agentSignalSkillManagement.test.ts | 32 + .../__tests__/lobeAgent.test.ts | 16 + .../__tests__/lobeAgentPlan.test.ts | 29 + .../serverRuntimes/__tests__/message.test.ts | 4 +- .../__tests__/skillManagement.test.ts | 5 +- .../__tests__/topicReference.test.ts | 20 + .../toolExecution/serverRuntimes/activator.ts | 2 +- .../serverRuntimes/agentBuilder.ts | 4 +- .../serverRuntimes/agentDocuments.ts | 18 +- .../serverRuntimes/agentManagement.ts | 4 +- .../agentSignalFeedbackIntent.ts | 5 +- .../serverRuntimes/agentSignalReflection.ts | 5 +- .../serverRuntimes/agentSignalReview.ts | 13 +- .../agentSignalSkillManagement.ts | 4 +- .../toolExecution/serverRuntimes/brief.ts | 156 +-- .../serverRuntimes/cloudSandbox.ts | 2 +- .../serverRuntimes/knowledgeBase.ts | 16 +- .../toolExecution/serverRuntimes/lobeAgent.ts | 12 +- .../serverRuntimes/lobeAgentPlan.ts | 5 +- .../toolExecution/serverRuntimes/memory.ts | 24 +- .../serverRuntimes/message/index.ts | 9 +- .../toolExecution/serverRuntimes/notebook.ts | 1 + .../serverRuntimes/skillManagement.ts | 6 +- .../toolExecution/serverRuntimes/skills.ts | 14 +- .../toolExecution/serverRuntimes/task.ts | 151 ++- .../serverRuntimes/topicReference.ts | 14 +- .../serverRuntimes/webBrowsing.ts | 4 +- .../serverRuntimes/webOnboarding.ts | 6 +- src/server/services/toolExecution/types.ts | 8 + src/server/services/usage/index.ts | 10 +- src/server/services/webBrowsing/index.ts | 10 +- .../workflows/processTopic.ts | 28 +- .../workflows/processTopics.ts | 8 +- .../workflows/processUserTopics.ts | 15 +- .../workflows/processUsers.ts | 8 +- .../task/handlers/onTopicComplete.ts | 12 +- .../workflows-hono/task/handlers/watchdog.ts | 5 +- src/server/workflows/agentEvalRun/index.ts | 6 +- src/server/workflows/agentEvalRun/utils.ts | 24 + src/server/workflows/agentSignal/run.ts | 27 +- src/server/workflows/agentSignal/types.ts | 6 + src/types/transferError.ts | 10 + src/types/workspaceSettings.ts | 24 + src/utils/dayjsLocale.test.ts | 15 + src/utils/dayjsLocale.ts | 3 + src/utils/locale.test.ts | 24 + src/utils/locale.ts | 15 +- tests/mocks/storeWorkspace.ts | 42 + vitest.config.mts | 4 + 465 files changed, 17609 insertions(+), 3891 deletions(-) create mode 100644 packages/const/src/workspace.ts create mode 100644 packages/database/src/models/__tests__/agentDocuments.workspace.test.ts create mode 100644 packages/database/src/models/__tests__/agentEval.workspace.test.ts create mode 100644 packages/database/src/models/__tests__/agentSignalReviewContext.workspace.test.ts create mode 100644 packages/database/src/models/__tests__/agentTransfer.test.ts create mode 100644 packages/database/src/models/__tests__/documentTransfer.test.ts create mode 100644 packages/database/src/models/__tests__/messages/message.workspace.test.ts create mode 100644 packages/database/src/models/__tests__/notification.test.ts create mode 100644 packages/database/src/models/__tests__/ragEval.workspace.test.ts create mode 100644 packages/database/src/models/__tests__/rbac.test.ts create mode 100644 packages/database/src/models/__tests__/session.workspace.test.ts create mode 100644 packages/database/src/models/__tests__/workspace.test.ts create mode 100644 packages/database/src/models/workspace.ts create mode 100644 packages/database/src/models/workspaceAuditLog.ts create mode 100644 packages/database/src/models/workspaceMember.ts create mode 100644 packages/database/src/utils/seedWorkspaceRoles.ts create mode 100644 packages/database/src/utils/workspace.test.ts create mode 100644 packages/database/src/utils/workspace.ts create mode 100644 packages/openapi/src/common/base.service.test.ts create mode 100644 packages/openapi/src/middleware/workspace.test.ts create mode 100644 packages/openapi/src/middleware/workspace.ts create mode 100644 scripts/codemodWorkspaceNav.ts create mode 100644 src/app/(backend)/webapi/_utils/workspace.test.ts create mode 100644 src/app/(backend)/webapi/_utils/workspace.ts create mode 100644 src/business/server/lambda-routers/workspace.ts create mode 100644 src/business/server/lambda-routers/workspaceAuditLog.ts create mode 100644 src/business/server/lambda-routers/workspaceCredits.ts create mode 100644 src/business/server/lambda-routers/workspaceCreds.ts create mode 100644 src/business/server/lambda-routers/workspaceData.ts create mode 100644 src/business/server/lambda-routers/workspaceMember.ts create mode 100644 src/business/server/lambda-routers/workspaceUsage.ts create mode 100644 src/business/server/trpc-middlewares/rbacPermission.ts create mode 100644 src/business/server/trpc-middlewares/workspaceAuth.ts create mode 100644 src/business/server/trpc-middlewares/workspaceContext.ts create mode 100644 src/libs/better-auth/email-templates/workspace-invite.ts create mode 100644 src/libs/better-auth/email-templates/workspace-member-removed.ts create mode 100644 src/server/routers/async/__tests__/ragEval.test.ts create mode 100644 src/server/routers/lambda/__tests__/knowledgeBase.test.ts create mode 100644 src/server/routers/lambda/market/agent.test.ts create mode 100644 src/server/routers/lambda/market/agentGroup.test.ts create mode 100644 src/server/routers/lambda/market/socialProfile.test.ts create mode 100644 src/server/services/toolExecution/serverRuntimes/__tests__/agentSignalSkillManagement.test.ts create mode 100644 src/server/services/toolExecution/serverRuntimes/__tests__/lobeAgentPlan.test.ts create mode 100644 src/server/workflows/agentEvalRun/utils.ts create mode 100644 src/types/transferError.ts create mode 100644 src/types/workspaceSettings.ts create mode 100644 src/utils/dayjsLocale.test.ts create mode 100644 tests/mocks/storeWorkspace.ts diff --git a/package.json b/package.json index a9033af1e9..1624c97e83 100644 --- a/package.json +++ b/package.json @@ -48,6 +48,8 @@ "build-migrate-db": "bun run db:migrate", "build-sitemap": "tsx ./scripts/buildSitemapIndex/index.ts", "clean:node_modules": "bash -lc 'set -e; echo \"Removing all node_modules...\"; rm -rf node_modules; pnpm -r exec rm -rf node_modules; rm -rf apps/desktop/node_modules; echo \"All node_modules removed.\"'", + "codemod:workspace-nav": "tsx ./scripts/codemodWorkspaceNav.ts", + "codemod:workspace-nav:check": "tsx ./scripts/codemodWorkspaceNav.ts --check", "db:generate": "drizzle-kit generate && npm run workflow:dbml", "db:migrate": "cross-env MIGRATION_DB=1 tsx ./scripts/migrateServerDB/index.ts", "db:studio": "drizzle-kit studio", @@ -287,7 +289,7 @@ "@lobehub/desktop-ipc-typings": "workspace:*", "@lobehub/editor": "^4.17.0", "@lobehub/icons": "^5.0.0", - "@lobehub/market-sdk": "0.33.3", + "@lobehub/market-sdk": "0.34.0", "@lobehub/tts": "^5.1.2", "@lobehub/ui": "^5.15.10", "@modelcontextprotocol/sdk": "^1.26.0", diff --git a/packages/agent-signal/src/base/types.ts b/packages/agent-signal/src/base/types.ts index 6002bbb105..fc99409377 100644 --- a/packages/agent-signal/src/base/types.ts +++ b/packages/agent-signal/src/base/types.ts @@ -5,6 +5,12 @@ export interface AgentSignalScope { taskId?: string; topicId?: string; userId: string; + /** + * Workspace identifier when the chain runs inside a team workspace. Omitted + * for personal-mode chains. Action handlers that write workspace-scoped + * tables (messages, memories) must honor this when present. + */ + workspaceId?: string; } /** Causal chain metadata for source, signal, and action nodes. */ diff --git a/packages/const/package.json b/packages/const/package.json index afb92958f5..b44e39a588 100644 --- a/packages/const/package.json +++ b/packages/const/package.json @@ -7,6 +7,7 @@ "./currency": "./src/currency.ts", "./desktopGlobalShortcuts": "./src/desktopGlobalShortcuts.ts", "./hotkeys": "./src/hotkeys.ts", + "./rbac": "./src/rbac.ts", "./visualRef": "./src/visualRef.ts" }, "main": "./src/index.ts", diff --git a/packages/const/src/index.ts b/packages/const/src/index.ts index 40d806463d..f2520635bf 100644 --- a/packages/const/src/index.ts +++ b/packages/const/src/index.ts @@ -28,3 +28,4 @@ export * from './url'; export * from './user'; export * from './userMemory'; export * from './version'; +export * from './workspace'; diff --git a/packages/const/src/rbac.ts b/packages/const/src/rbac.ts index dd5f49d5be..a55c1c841d 100644 --- a/packages/const/src/rbac.ts +++ b/packages/const/src/rbac.ts @@ -1,5 +1,3 @@ -/* eslint-disable sort-keys-fix/sort-keys-fix */ - /** * RBAC Permission Actions Definition * Defines all executable permission action types in the system @@ -153,6 +151,40 @@ export const PERMISSION_ACTIONS = { USER_READ: 'user:read', USER_UPDATE: 'user:update', + + // ==================== Workspace Management ==================== + WORKSPACE_READ: 'workspace:read', + + WORKSPACE_UPDATE: 'workspace:update', + + WORKSPACE_DELETE: 'workspace:delete', + + WORKSPACE_SETTINGS_UPDATE: 'workspace:settings_update', + + WORKSPACE_BILLING_READ: 'workspace:billing_read', + + WORKSPACE_BILLING_MANAGE: 'workspace:billing_manage', + + // ==================== Workspace Member Management ==================== + WORKSPACE_MEMBER_READ: 'workspace_member:read', + + WORKSPACE_MEMBER_INVITE: 'workspace_member:invite', + + WORKSPACE_MEMBER_REMOVE: 'workspace_member:remove', + + WORKSPACE_MEMBER_UPDATE_ROLE: 'workspace_member:update_role', + + // ==================== Workspace Audit ==================== + WORKSPACE_AUDIT_READ: 'workspace_audit:read', + + // ==================== Workspace Role Management ==================== + WORKSPACE_ROLE_READ: 'workspace_role:read', + + WORKSPACE_ROLE_CREATE: 'workspace_role:create', + + WORKSPACE_ROLE_UPDATE: 'workspace_role:update', + + WORKSPACE_ROLE_DELETE: 'workspace_role:delete', } as const; /** @@ -176,6 +208,11 @@ export const getAllowedScopesForAction = ( // RBAC resources: ALL only (system-level resource) if (resource === 'rbac') return ['ALL']; + // Workspace-scoped resources: ALL only. The workspace itself is the isolation + // boundary, so an "OWNER" sub-scope (resource-author-only) is redundant — + // workspace_member.role + assigned permissions already pin who can do what. + if (resource.startsWith('workspace')) return ['ALL']; + // user resource nuance: create/delete without OWNER; read/update allow OWNER if (resource === 'user') { if (action === 'create' || action === 'delete') return ['ALL']; @@ -236,3 +273,200 @@ export const SYSTEM_DEFAULT_ROLES = { export const ROLE_DESCRIPTIONS = { [SYSTEM_DEFAULT_ROLES.SUPER_ADMIN]: 'Administrator with all system permissions', } as const; + +/** + * Built-in role names for workspace-scoped RBAC. Each workspace is seeded with + * exactly these three system roles on creation; their `workspace_id` is the + * owning workspace, distinguishing them from the global `super_admin` role. + */ +export const WORKSPACE_SYSTEM_ROLES = { + OWNER: 'workspace_owner', + MEMBER: 'workspace_member', + VIEWER: 'workspace_viewer', +} as const; + +export type WorkspaceSystemRoleName = + (typeof WORKSPACE_SYSTEM_ROLES)[keyof typeof WORKSPACE_SYSTEM_ROLES]; + +const action = (key: keyof typeof PERMISSION_ACTIONS): string => PERMISSION_ACTIONS[key]; + +/** + * Permission codes granted to each built-in workspace role. The lists are the + * source of truth used both by `seedWorkspaceRoles` (DB seeding) and the + * migration backfill SQL — keep them aligned. + * + * Scope semantics: + * - `workspace_owner` — every workspace-domain permission + every content + * permission (`:all`) so they can manage other members' resources too. + * - `workspace_member` — read workspace + members; create/update/delete their + * own content (`:owner`) on every content resource. + * - `workspace_viewer` — strict read-only on workspace + members + content. + * No model invocation: chat without SESSION/MESSAGE write grants would + * either burn workspace budget without persisting history or require + * special-case bypasses. Use `workspace_member` if "can chat" is needed. + */ +export const WORKSPACE_ROLE_PERMISSIONS: Record = { + [WORKSPACE_SYSTEM_ROLES.OWNER]: [ + // Workspace + `${action('WORKSPACE_READ')}:all`, + `${action('WORKSPACE_UPDATE')}:all`, + `${action('WORKSPACE_DELETE')}:all`, + `${action('WORKSPACE_SETTINGS_UPDATE')}:all`, + `${action('WORKSPACE_BILLING_READ')}:all`, + `${action('WORKSPACE_BILLING_MANAGE')}:all`, + // Members + `${action('WORKSPACE_MEMBER_READ')}:all`, + `${action('WORKSPACE_MEMBER_INVITE')}:all`, + `${action('WORKSPACE_MEMBER_REMOVE')}:all`, + `${action('WORKSPACE_MEMBER_UPDATE_ROLE')}:all`, + // Audit + `${action('WORKSPACE_AUDIT_READ')}:all`, + // Custom roles + `${action('WORKSPACE_ROLE_READ')}:all`, + `${action('WORKSPACE_ROLE_CREATE')}:all`, + `${action('WORKSPACE_ROLE_UPDATE')}:all`, + `${action('WORKSPACE_ROLE_DELETE')}:all`, + // Content — owner can read/write everyone's resources + `${action('AGENT_READ')}:all`, + `${action('AGENT_CREATE')}:all`, + `${action('AGENT_UPDATE')}:all`, + `${action('AGENT_DELETE')}:all`, + `${action('AGENT_FORK')}:all`, + `${action('SESSION_READ')}:all`, + `${action('SESSION_CREATE')}:all`, + `${action('SESSION_UPDATE')}:all`, + `${action('SESSION_DELETE')}:all`, + `${action('SESSION_GROUP_READ')}:all`, + `${action('SESSION_GROUP_CREATE')}:all`, + `${action('SESSION_GROUP_UPDATE')}:all`, + `${action('SESSION_GROUP_DELETE')}:all`, + `${action('MESSAGE_READ')}:all`, + `${action('MESSAGE_CREATE')}:all`, + `${action('MESSAGE_UPDATE')}:all`, + `${action('MESSAGE_DELETE')}:all`, + `${action('TOPIC_READ')}:all`, + `${action('TOPIC_CREATE')}:all`, + `${action('TOPIC_UPDATE')}:all`, + `${action('TOPIC_DELETE')}:all`, + `${action('FILE_READ')}:all`, + `${action('FILE_UPLOAD')}:all`, + `${action('FILE_UPDATE')}:all`, + `${action('FILE_DELETE')}:all`, + `${action('DOCUMENT_READ')}:all`, + `${action('DOCUMENT_CREATE')}:all`, + `${action('DOCUMENT_UPDATE')}:all`, + `${action('DOCUMENT_DELETE')}:all`, + `${action('KNOWLEDGE_BASE_READ')}:all`, + `${action('KNOWLEDGE_BASE_CREATE')}:all`, + `${action('KNOWLEDGE_BASE_UPDATE')}:all`, + `${action('KNOWLEDGE_BASE_DELETE')}:all`, + `${action('AI_MODEL_READ')}:all`, + `${action('AI_MODEL_INVOKE')}:all`, + `${action('AI_MODEL_CREATE')}:all`, + `${action('AI_MODEL_UPDATE')}:all`, + `${action('AI_MODEL_DELETE')}:all`, + `${action('AI_PROVIDER_READ')}:all`, + `${action('AI_PROVIDER_CREATE')}:all`, + `${action('AI_PROVIDER_UPDATE')}:all`, + `${action('AI_PROVIDER_DELETE')}:all`, + `${action('API_KEY_READ')}:all`, + `${action('API_KEY_CREATE')}:all`, + `${action('API_KEY_UPDATE')}:all`, + `${action('API_KEY_DELETE')}:all`, + ], + [WORKSPACE_SYSTEM_ROLES.MEMBER]: [ + // Workspace — read only + `${action('WORKSPACE_READ')}:all`, + `${action('WORKSPACE_MEMBER_READ')}:all`, + // Content — can write own + `${action('AGENT_READ')}:all`, + `${action('AGENT_CREATE')}:owner`, + `${action('AGENT_UPDATE')}:owner`, + `${action('AGENT_DELETE')}:owner`, + `${action('AGENT_FORK')}:owner`, + `${action('SESSION_READ')}:all`, + `${action('SESSION_CREATE')}:owner`, + `${action('SESSION_UPDATE')}:owner`, + `${action('SESSION_DELETE')}:owner`, + `${action('SESSION_GROUP_READ')}:all`, + `${action('SESSION_GROUP_CREATE')}:owner`, + `${action('SESSION_GROUP_UPDATE')}:owner`, + `${action('SESSION_GROUP_DELETE')}:owner`, + `${action('MESSAGE_READ')}:all`, + `${action('MESSAGE_CREATE')}:owner`, + `${action('MESSAGE_UPDATE')}:owner`, + `${action('MESSAGE_DELETE')}:owner`, + `${action('TOPIC_READ')}:all`, + `${action('TOPIC_CREATE')}:owner`, + `${action('TOPIC_UPDATE')}:owner`, + `${action('TOPIC_DELETE')}:owner`, + `${action('FILE_READ')}:all`, + `${action('FILE_UPLOAD')}:owner`, + `${action('FILE_UPDATE')}:owner`, + `${action('FILE_DELETE')}:owner`, + `${action('DOCUMENT_READ')}:all`, + `${action('DOCUMENT_CREATE')}:owner`, + `${action('DOCUMENT_UPDATE')}:owner`, + `${action('DOCUMENT_DELETE')}:owner`, + `${action('KNOWLEDGE_BASE_READ')}:all`, + `${action('KNOWLEDGE_BASE_CREATE')}:owner`, + `${action('KNOWLEDGE_BASE_UPDATE')}:owner`, + `${action('KNOWLEDGE_BASE_DELETE')}:owner`, + `${action('AI_MODEL_READ')}:all`, + `${action('AI_MODEL_INVOKE')}:all`, + `${action('AI_PROVIDER_READ')}:all`, + `${action('API_KEY_READ')}:owner`, + `${action('API_KEY_CREATE')}:owner`, + `${action('API_KEY_UPDATE')}:owner`, + `${action('API_KEY_DELETE')}:owner`, + ], + [WORKSPACE_SYSTEM_ROLES.VIEWER]: [ + // Read-only across the board + `${action('WORKSPACE_READ')}:all`, + `${action('WORKSPACE_MEMBER_READ')}:all`, + `${action('AGENT_READ')}:all`, + `${action('SESSION_READ')}:all`, + `${action('SESSION_GROUP_READ')}:all`, + `${action('MESSAGE_READ')}:all`, + `${action('TOPIC_READ')}:all`, + `${action('FILE_READ')}:all`, + `${action('DOCUMENT_READ')}:all`, + `${action('KNOWLEDGE_BASE_READ')}:all`, + `${action('AI_MODEL_READ')}:all`, + `${action('AI_PROVIDER_READ')}:all`, + ], +}; + +export const WORKSPACE_ROLE_DESCRIPTIONS: Record = { + [WORKSPACE_SYSTEM_ROLES.OWNER]: 'Full access including billing, members, and all content.', + [WORKSPACE_SYSTEM_ROLES.MEMBER]: 'Can create and edit own content, read shared content.', + [WORKSPACE_SYSTEM_ROLES.VIEWER]: 'Read-only access to workspace content.', +}; + +export const WORKSPACE_ROLE_DISPLAY_NAMES: Record = { + [WORKSPACE_SYSTEM_ROLES.OWNER]: 'Owner', + [WORKSPACE_SYSTEM_ROLES.MEMBER]: 'Member', + [WORKSPACE_SYSTEM_ROLES.VIEWER]: 'Viewer', +}; + +/** + * Translate a legacy `workspace_members.role` text value to its corresponding + * built-in role name. Used by the migration backfill and member CRUD code that + * still double-writes to `workspace_members.role` for label/UI purposes. + */ +export const legacyRoleToWorkspaceRole = (role: string): WorkspaceSystemRoleName | null => { + switch (role) { + case 'owner': { + return WORKSPACE_SYSTEM_ROLES.OWNER; + } + case 'member': { + return WORKSPACE_SYSTEM_ROLES.MEMBER; + } + case 'viewer': { + return WORKSPACE_SYSTEM_ROLES.VIEWER; + } + default: { + return null; + } + } +}; diff --git a/packages/const/src/taskTemplate.ts b/packages/const/src/taskTemplate.ts index 4bebaa515d..bd0bbea1ee 100644 --- a/packages/const/src/taskTemplate.ts +++ b/packages/const/src/taskTemplate.ts @@ -63,6 +63,19 @@ export const TASK_TEMPLATE_FALLBACK_CATEGORIES: TaskTemplateCategory[] = [ 'learning-research', ]; +/** + * Categories that only make sense in a personal context. When the recommendation + * is requested from inside a workspace, every template under these categories + * is removed from the candidate pool — both matched and fallback — so a team + * dashboard never surfaces "bedtime gratitude" / "weekly family finance" etc. + */ +export const TASK_TEMPLATE_PERSONAL_ONLY_CATEGORIES: TaskTemplateCategory[] = [ + 'parenting', + 'health', + 'hobbies', + 'personal-life', +]; + export const TASK_TEMPLATE_RECOMMEND_COUNT = 3; export const taskTemplates: TaskTemplate[] = [ diff --git a/packages/const/src/workspace.ts b/packages/const/src/workspace.ts new file mode 100644 index 0000000000..15725a0f4b --- /dev/null +++ b/packages/const/src/workspace.ts @@ -0,0 +1,10 @@ +/** + * Number of days a workspace invitation token stays valid before it expires. + * Shared by `WorkspaceMemberModel.createInvitation` (sets `expiresAt`) and the + * cloud invite-email template (renders the human-facing expiry copy), so the + * actual TTL and what we promise to recipients can't drift apart. + * + * If you change this, also update the "expire after 1 week" copy in + * `lobehub/src/locales/default/setting.ts` (`workspace.members.invite.modal.expiryWarning`). + */ +export const INVITATION_EXPIRY_DAYS = 7; diff --git a/packages/database/src/index.ts b/packages/database/src/index.ts index 21c8e00da1..a341ed8545 100644 --- a/packages/database/src/index.ts +++ b/packages/database/src/index.ts @@ -2,3 +2,5 @@ export * from './core/db-adaptor'; export * from './repositories/compression'; export * from './type'; export * from './utils/idGenerator'; +export * from './utils/seedWorkspaceRoles'; +export * from './utils/workspace'; diff --git a/packages/database/src/models/__tests__/agent.test.ts b/packages/database/src/models/__tests__/agent.test.ts index 41473c545d..87967f5340 100644 --- a/packages/database/src/models/__tests__/agent.test.ts +++ b/packages/database/src/models/__tests__/agent.test.ts @@ -17,6 +17,7 @@ import { sessions, topics, users, + workspaces, } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; import { AgentModel } from '../agent'; @@ -1309,6 +1310,52 @@ describe('AgentModel', () => { expect(result?.virtual).toBe(true); }); }); + + describe('workspace mode', () => { + it('should create workspace-scoped inbox agent', async () => { + const [workspace] = await serverDB + .insert(workspaces) + .values({ name: 'ws', primaryOwnerId: userId, slug: 'ws-slug' }) + .returning(); + + const wsAgentModel = new AgentModel(serverDB, userId, workspace.id); + const result = await wsAgentModel.getBuiltinAgent(INBOX_SESSION_ID); + + expect(result).toBeDefined(); + expect(result?.slug).toBe(INBOX_SESSION_ID); + expect(result?.workspaceId).toBe(workspace.id); + expect(result?.userId).toBe(userId); + }); + + it('should allow workspace inbox to coexist with personal inbox for the same user', async () => { + const personal = await agentModel.getBuiltinAgent(INBOX_SESSION_ID); + expect(personal?.workspaceId).toBeNull(); + + const [workspace] = await serverDB + .insert(workspaces) + .values({ name: 'ws2', primaryOwnerId: userId, slug: 'ws2-slug' }) + .returning(); + + const wsAgentModel = new AgentModel(serverDB, userId, workspace.id); + const ws = await wsAgentModel.getBuiltinAgent(INBOX_SESSION_ID); + + expect(ws?.id).not.toBe(personal?.id); + expect(ws?.workspaceId).toBe(workspace.id); + }); + + it('should be idempotent in workspace mode', async () => { + const [workspace] = await serverDB + .insert(workspaces) + .values({ name: 'ws3', primaryOwnerId: userId, slug: 'ws3-slug' }) + .returning(); + + const wsAgentModel = new AgentModel(serverDB, userId, workspace.id); + const first = await wsAgentModel.getBuiltinAgent(INBOX_SESSION_ID); + const second = await wsAgentModel.getBuiltinAgent(INBOX_SESSION_ID); + + expect(first?.id).toBe(second?.id); + }); + }); }); describe('batchDelete', () => { diff --git a/packages/database/src/models/__tests__/agentBotProvider.test.ts b/packages/database/src/models/__tests__/agentBotProvider.test.ts index 0be0116490..ca80220e64 100644 --- a/packages/database/src/models/__tests__/agentBotProvider.test.ts +++ b/packages/database/src/models/__tests__/agentBotProvider.test.ts @@ -2,7 +2,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { getTestDB } from '../../core/getTestDB'; -import { agentBotProviders, agents, users } from '../../schemas'; +import { agentBotProviders, agents, users, workspaces } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; import { AgentBotProviderModel } from '../agentBotProvider'; @@ -337,6 +337,117 @@ describe('AgentBotProviderModel', () => { }); }); + describe('findEnabledByPlatformAndAppId (static)', () => { + it('should find an enabled provider that lives in a workspace (system-wide, ignores ownership scope)', async () => { + // Regression: workspace-scoped bots could not be connected because the + // gateway looked them up in personal scope (workspace_id IS NULL). + const workspaceId = 'bot-provider-test-workspace'; + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Test WS', + primaryOwnerId: userId, + slug: 'test-ws', + }); + + const wsModel = new AgentBotProviderModel(serverDB, userId, mockGateKeeper, workspaceId); + await wsModel.create({ + agentId, + applicationId: 'ws-app', + credentials: { botToken: 'ws-tok' }, + platform: 'discord', + }); + + // The personal-scope instance lookup misses the workspace row — this is + // the exact failure the static method exists to avoid. + const personalModel = new AgentBotProviderModel(serverDB, userId, mockGateKeeper); + expect(await personalModel.findEnabledByApplicationId('discord', 'ws-app')).toBeNull(); + + // The system-wide static lookup finds it and decrypts credentials. + const result = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + 'discord', + 'ws-app', + mockGateKeeper, + ); + expect(result).not.toBeNull(); + expect(result!.applicationId).toBe('ws-app'); + expect(result!.workspaceId).toBe(workspaceId); + expect(result!.credentials.botToken).toBe('ws-tok'); + }); + + it('should find a provider owned by any user', async () => { + const model2 = new AgentBotProviderModel(serverDB, userId2); + await model2.create({ + agentId: agentId2, + applicationId: 'other-user-app', + credentials: { botToken: 'tok' }, + platform: 'slack', + }); + + const result = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + 'slack', + 'other-user-app', + ); + expect(result).not.toBeNull(); + expect(result!.applicationId).toBe('other-user-app'); + }); + + it('should return null for a disabled provider', async () => { + const model = new AgentBotProviderModel(serverDB, userId); + const created = await model.create({ + agentId, + applicationId: 'disabled-app', + credentials: { botToken: 'tok' }, + platform: 'discord', + }); + await model.update(created.id, { enabled: false }); + + const result = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + 'discord', + 'disabled-app', + ); + expect(result).toBeNull(); + }); + + it('should return null for a non-existent combination', async () => { + const result = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + 'discord', + 'no-such-app', + ); + expect(result).toBeNull(); + }); + }); + + describe('findByAgentId (static)', () => { + it('should return all providers for an agent regardless of ownership scope, decrypted', async () => { + const model = new AgentBotProviderModel(serverDB, userId, mockGateKeeper); + await model.create({ + agentId, + applicationId: 'agent-app-1', + credentials: { botToken: 'tok-1' }, + platform: 'discord', + }); + const disabled = await model.create({ + agentId, + applicationId: 'agent-app-2', + credentials: { botToken: 'tok-2' }, + platform: 'slack', + }); + await model.update(disabled.id, { enabled: false }); + + const results = await AgentBotProviderModel.findByAgentId(serverDB, agentId, mockGateKeeper); + + // Returns both enabled and disabled rows (caller filters by `enabled`). + expect(results).toHaveLength(2); + const byApp = Object.fromEntries(results.map((r) => [r.applicationId, r])); + expect(byApp['agent-app-1'].credentials.botToken).toBe('tok-1'); + expect(byApp['agent-app-2'].credentials.botToken).toBe('tok-2'); + }); + }); + describe('findEnabledByPlatform (static)', () => { it('should return Discord providers with botToken', async () => { const model = new AgentBotProviderModel(serverDB, userId); diff --git a/packages/database/src/models/__tests__/agentDocuments.workspace.test.ts b/packages/database/src/models/__tests__/agentDocuments.workspace.test.ts new file mode 100644 index 0000000000..4b92a25c2f --- /dev/null +++ b/packages/database/src/models/__tests__/agentDocuments.workspace.test.ts @@ -0,0 +1,89 @@ +// @vitest-environment node +import { eq } from 'drizzle-orm'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { agentDocuments, agents, documents, users, workspaces } from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { AgentDocumentModel } from '../agentDocuments'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'agent-document-workspace-user'; +const workspaceId = 'agent-document-workspace'; + +beforeEach(async () => { + await serverDB.delete(users); + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Agent Document Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); + await serverDB.insert(agents).values([ + { id: 'personal-agent-document-agent', title: 'Personal Agent', userId, workspaceId: null }, + { id: 'workspace-agent-document-agent', title: 'Workspace Agent', userId, workspaceId }, + ]); +}); + +afterEach(async () => { + await serverDB.delete(users); +}); + +describe('AgentDocumentModel workspace scope', () => { + it('isolates document reads and deletes between personal and workspace scopes', async () => { + const personalModel = new AgentDocumentModel(serverDB, userId); + const workspaceModel = new AgentDocumentModel(serverDB, userId, workspaceId); + + const personalDoc = await personalModel.create( + 'personal-agent-document-agent', + 'README.md', + '# Personal', + ); + const workspaceDoc = await workspaceModel.create( + 'workspace-agent-document-agent', + 'README.md', + '# Workspace', + ); + + await expect(personalModel.findById(workspaceDoc.id)).resolves.toBeUndefined(); + await expect(workspaceModel.findById(personalDoc.id)).resolves.toBeUndefined(); + + await expect( + serverDB.query.agentDocuments.findFirst({ + where: eq(agentDocuments.id, personalDoc.id), + }), + ).resolves.toMatchObject({ id: personalDoc.id, workspaceId: null }); + await expect( + serverDB.query.agentDocuments.findFirst({ + where: eq(agentDocuments.id, workspaceDoc.id), + }), + ).resolves.toMatchObject({ id: workspaceDoc.id, workspaceId }); + + await expect(personalModel.findByAgent('personal-agent-document-agent')).resolves.toEqual([ + expect.objectContaining({ id: personalDoc.id }), + ]); + await expect(workspaceModel.findByAgent('workspace-agent-document-agent')).resolves.toEqual([ + expect.objectContaining({ id: workspaceDoc.id }), + ]); + + await personalModel.deleteByAgent('personal-agent-document-agent'); + + await expect(personalModel.findById(personalDoc.id)).resolves.toBeUndefined(); + await expect(workspaceModel.findById(workspaceDoc.id)).resolves.toMatchObject({ + id: workspaceDoc.id, + }); + + await personalModel.permanentlyDelete(workspaceDoc.id); + await expect(workspaceModel.findById(workspaceDoc.id)).resolves.toMatchObject({ + id: workspaceDoc.id, + }); + }); +}); + +afterEach(async () => { + await serverDB.delete(agentDocuments).where(eq(agentDocuments.userId, userId)); + await serverDB.delete(documents).where(eq(documents.userId, userId)); + await serverDB.delete(agents).where(eq(agents.userId, userId)); +}); diff --git a/packages/database/src/models/__tests__/agentEval.workspace.test.ts b/packages/database/src/models/__tests__/agentEval.workspace.test.ts new file mode 100644 index 0000000000..f078cf3765 --- /dev/null +++ b/packages/database/src/models/__tests__/agentEval.workspace.test.ts @@ -0,0 +1,184 @@ +// @vitest-environment node +import { eq } from 'drizzle-orm'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { + agentEvalBenchmarks, + agentEvalDatasets, + agentEvalRuns, + agentEvalRunTopics, + agentEvalTestCases, + topics, + users, + workspaces, +} from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { + AgentEvalBenchmarkModel, + AgentEvalDatasetModel, + AgentEvalRunModel, + AgentEvalRunTopicModel, + AgentEvalTestCaseModel, +} from '../agentEval'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'agent-eval-workspace-user'; +const workspaceId = 'agent-eval-workspace'; + +beforeEach(async () => { + await serverDB.delete(users); + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Agent Eval Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); +}); + +afterEach(async () => { + await serverDB.delete(users); +}); + +describe('Agent eval workspace scope', () => { + it('isolates benchmarks, datasets, test cases, runs, and run topics', async () => { + const personalBenchmarkModel = new AgentEvalBenchmarkModel(serverDB, userId); + const workspaceBenchmarkModel = new AgentEvalBenchmarkModel(serverDB, userId, workspaceId); + const personalDatasetModel = new AgentEvalDatasetModel(serverDB, userId); + const workspaceDatasetModel = new AgentEvalDatasetModel(serverDB, userId, workspaceId); + const personalTestCaseModel = new AgentEvalTestCaseModel(serverDB, userId); + const workspaceTestCaseModel = new AgentEvalTestCaseModel(serverDB, userId, workspaceId); + const personalRunModel = new AgentEvalRunModel(serverDB, userId); + const workspaceRunModel = new AgentEvalRunModel(serverDB, userId, workspaceId); + const personalRunTopicModel = new AgentEvalRunTopicModel(serverDB, userId); + const workspaceRunTopicModel = new AgentEvalRunTopicModel(serverDB, userId, workspaceId); + + const personalBenchmark = await personalBenchmarkModel.create({ + identifier: 'shared-benchmark', + isSystem: false, + name: 'Personal benchmark', + rubrics: [], + }); + const workspaceBenchmark = await workspaceBenchmarkModel.create({ + identifier: 'shared-benchmark', + isSystem: false, + name: 'Workspace benchmark', + rubrics: [], + }); + + await expect( + personalBenchmarkModel.findByIdentifier('shared-benchmark'), + ).resolves.toMatchObject({ + id: personalBenchmark.id, + workspaceId: null, + }); + await expect( + workspaceBenchmarkModel.findByIdentifier('shared-benchmark'), + ).resolves.toMatchObject({ + id: workspaceBenchmark.id, + workspaceId, + }); + + const personalDataset = await personalDatasetModel.create({ + benchmarkId: personalBenchmark.id, + identifier: 'shared-dataset', + name: 'Personal dataset', + }); + const workspaceDataset = await workspaceDatasetModel.create({ + benchmarkId: workspaceBenchmark.id, + identifier: 'shared-dataset', + name: 'Workspace dataset', + }); + + await expect(personalDatasetModel.query(personalBenchmark.id)).resolves.toEqual([ + expect.objectContaining({ id: personalDataset.id }), + ]); + await expect(workspaceDatasetModel.query(workspaceBenchmark.id)).resolves.toEqual([ + expect.objectContaining({ id: workspaceDataset.id }), + ]); + await expect(personalDatasetModel.findById(personalDataset.id)).resolves.toMatchObject({ + id: personalDataset.id, + workspaceId: null, + }); + await expect(workspaceDatasetModel.findById(workspaceDataset.id)).resolves.toMatchObject({ + id: workspaceDataset.id, + workspaceId, + }); + + const personalTestCase = await personalTestCaseModel.create({ + content: { expected: 'personal', input: 'question' }, + datasetId: personalDataset.id, + }); + const workspaceTestCase = await workspaceTestCaseModel.create({ + content: { expected: 'workspace', input: 'question' }, + datasetId: workspaceDataset.id, + }); + + await expect(personalTestCaseModel.findById(workspaceTestCase.id)).resolves.toBeUndefined(); + await expect(workspaceTestCaseModel.findById(personalTestCase.id)).resolves.toBeUndefined(); + + const personalRun = await personalRunModel.create({ + datasetId: personalDataset.id, + name: 'Personal run', + }); + const workspaceRun = await workspaceRunModel.create({ + datasetId: workspaceDataset.id, + name: 'Workspace run', + }); + + await expect(personalRunModel.findById(workspaceRun.id)).resolves.toBeUndefined(); + await expect(workspaceRunModel.findById(personalRun.id)).resolves.toBeUndefined(); + + await serverDB.insert(topics).values([ + { id: 'agent-eval-personal-topic', title: 'Personal topic', userId, workspaceId: null }, + { id: 'agent-eval-workspace-topic', title: 'Workspace topic', userId, workspaceId }, + ]); + + await personalRunTopicModel.batchCreate([ + { + runId: personalRun.id, + status: 'completed', + testCaseId: personalTestCase.id, + topicId: 'agent-eval-personal-topic', + }, + ]); + const [workspaceRunTopic] = await workspaceRunTopicModel.batchCreate([ + { + runId: workspaceRun.id, + status: 'completed', + testCaseId: workspaceTestCase.id, + topicId: 'agent-eval-workspace-topic', + }, + ]); + + expect(workspaceRunTopic).toMatchObject({ runId: workspaceRun.id, workspaceId }); + + await expect(personalRunTopicModel.findByRunId(workspaceRun.id)).resolves.toEqual([]); + await expect(workspaceRunTopicModel.findByRunId(personalRun.id)).resolves.toEqual([]); + await expect(workspaceRunTopicModel.findByRunId(workspaceRun.id)).resolves.toEqual([ + expect.objectContaining({ + runId: workspaceRun.id, + topicId: 'agent-eval-workspace-topic', + }), + ]); + + await personalBenchmarkModel.delete(personalBenchmark.id); + + await expect(personalBenchmarkModel.findById(personalBenchmark.id)).resolves.toBeUndefined(); + await expect(workspaceBenchmarkModel.findById(workspaceBenchmark.id)).resolves.toMatchObject({ + id: workspaceBenchmark.id, + workspaceId, + }); + }); +}); + +afterEach(async () => { + await serverDB.delete(agentEvalRunTopics).where(eq(agentEvalRunTopics.userId, userId)); + await serverDB.delete(topics).where(eq(topics.userId, userId)); + await serverDB.delete(agentEvalRuns).where(eq(agentEvalRuns.userId, userId)); + await serverDB.delete(agentEvalTestCases).where(eq(agentEvalTestCases.userId, userId)); + await serverDB.delete(agentEvalDatasets).where(eq(agentEvalDatasets.userId, userId)); + await serverDB.delete(agentEvalBenchmarks).where(eq(agentEvalBenchmarks.userId, userId)); +}); diff --git a/packages/database/src/models/__tests__/agentSignalReviewContext.workspace.test.ts b/packages/database/src/models/__tests__/agentSignalReviewContext.workspace.test.ts new file mode 100644 index 0000000000..d7be21b93d --- /dev/null +++ b/packages/database/src/models/__tests__/agentSignalReviewContext.workspace.test.ts @@ -0,0 +1,108 @@ +// @vitest-environment node +import { eq } from 'drizzle-orm'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { agentDocuments, agents, documents, users, workspaces } from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { AgentDocumentModel } from '../agentDocuments'; +import { AgentSignalReviewContextModel } from '../agentSignal/reviewContext'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'agent-signal-review-workspace-user'; +const workspaceId = 'agent-signal-review-workspace'; + +beforeEach(async () => { + await serverDB.delete(users); + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Agent Signal Review Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); + await serverDB.insert(agents).values([ + { + chatConfig: { selfIteration: { enabled: true } }, + id: 'personal-review-agent', + title: 'Personal Review Agent', + userId, + virtual: false, + workspaceId: null, + }, + { + chatConfig: { selfIteration: { enabled: true } }, + id: 'workspace-review-agent', + title: 'Workspace Review Agent', + userId, + virtual: false, + workspaceId, + }, + ]); +}); + +afterEach(async () => { + await serverDB.delete(users); +}); + +describe('AgentSignalReviewContextModel workspace scope', () => { + it('isolates self-iteration checks and document activity by workspace', async () => { + const personalContext = new AgentSignalReviewContextModel(serverDB, userId); + const workspaceContext = new AgentSignalReviewContextModel(serverDB, userId, workspaceId); + const personalDocumentModel = new AgentDocumentModel(serverDB, userId); + const workspaceDocumentModel = new AgentDocumentModel(serverDB, userId, workspaceId); + + const personalDoc = await personalDocumentModel.create( + 'personal-review-agent', + 'personal.md', + '# Personal', + ); + const workspaceDoc = await workspaceDocumentModel.create( + 'workspace-review-agent', + 'workspace.md', + '# Workspace', + ); + + await expect(personalContext.canAgentRunSelfIteration('personal-review-agent')).resolves.toBe( + true, + ); + await expect(personalContext.canAgentRunSelfIteration('workspace-review-agent')).resolves.toBe( + false, + ); + await expect(workspaceContext.canAgentRunSelfIteration('personal-review-agent')).resolves.toBe( + false, + ); + await expect(workspaceContext.canAgentRunSelfIteration('workspace-review-agent')).resolves.toBe( + true, + ); + + const window = { + agentId: 'workspace-review-agent', + windowEnd: new Date('2100-01-01'), + windowStart: new Date('2000-01-01'), + }; + + await expect(personalContext.listDocumentActivity(window)).resolves.toEqual([]); + await expect(workspaceContext.listDocumentActivity(window)).resolves.toEqual([ + expect.objectContaining({ + agentDocumentId: workspaceDoc.id, + documentId: workspaceDoc.documentId, + }), + ]); + + await expect( + workspaceContext.listDocumentActivity({ + ...window, + agentId: 'personal-review-agent', + }), + ).resolves.toEqual([]); + await expect(personalDocumentModel.findById(personalDoc.id)).resolves.toBeDefined(); + }); +}); + +afterEach(async () => { + await serverDB.delete(agentDocuments).where(eq(agentDocuments.userId, userId)); + await serverDB.delete(documents).where(eq(documents.userId, userId)); + await serverDB.delete(agents).where(eq(agents.userId, userId)); +}); diff --git a/packages/database/src/models/__tests__/agentTransfer.test.ts b/packages/database/src/models/__tests__/agentTransfer.test.ts new file mode 100644 index 0000000000..776b08303e --- /dev/null +++ b/packages/database/src/models/__tests__/agentTransfer.test.ts @@ -0,0 +1,189 @@ +// @vitest-environment node +import { eq } from 'drizzle-orm'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { + agentBotProviders, + agents, + agentsToSessions, + chatGroups, + chatGroupsAgents, + messages, + sessions, + topics, + users, + workspaces, +} from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { AgentModel } from '../agent'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'transfer-test-user'; +const wsId1 = 'transfer-test-ws-1'; +const wsId2 = 'transfer-test-ws-2'; + +beforeEach(async () => { + await serverDB.delete(users); + await serverDB.insert(users).values([{ id: userId }]); + await serverDB.insert(workspaces).values([ + { id: wsId1, name: 'WS 1', slug: 'ws-1', primaryOwnerId: userId }, + { id: wsId2, name: 'WS 2', slug: 'ws-2', primaryOwnerId: userId }, + ]); +}); + +afterEach(async () => { + await serverDB.delete(users); +}); + +describe('AgentModel.transferAgent', () => { + it('should transfer agent from personal to workspace', async () => { + const model = new AgentModel(serverDB, userId); + const agent = await model.create({ title: 'Test Agent', slug: 'test-agent' }); + + const result = await model.transferAgent(agent.id, wsId1, userId); + + expect(result.agentId).toBe(agent.id); + + const updated = await serverDB.query.agents.findFirst({ + where: eq(agents.id, agent.id), + }); + expect(updated?.workspaceId).toBe(wsId1); + expect(updated?.userId).toBe(userId); + }); + + it('should transfer agent from workspace to personal', async () => { + const model = new AgentModel(serverDB, userId, wsId1); + const agent = await model.create({ title: 'WS Agent', slug: 'ws-agent' }); + + const result = await model.transferAgent(agent.id, null, userId); + + expect(result.agentId).toBe(agent.id); + + const updated = await serverDB.query.agents.findFirst({ + where: eq(agents.id, agent.id), + }); + expect(updated?.workspaceId).toBeNull(); + expect(updated?.userId).toBe(userId); + }); + + it('should transfer agent between workspaces', async () => { + const model = new AgentModel(serverDB, userId, wsId1); + const agent = await model.create({ title: 'WS1 Agent', slug: 'ws1-agent' }); + + const result = await model.transferAgent(agent.id, wsId2, userId); + + expect(result.agentId).toBe(agent.id); + + const updated = await serverDB.query.agents.findFirst({ + where: eq(agents.id, agent.id), + }); + expect(updated?.workspaceId).toBe(wsId2); + }); + + it('should handle slug conflict by appending suffix', async () => { + const model = new AgentModel(serverDB, userId, wsId1); + const agent1 = await model.create({ title: 'Agent', slug: 'my-agent' }); + + // Create an agent with the same slug in target workspace + const model2 = new AgentModel(serverDB, userId, wsId2); + await model2.create({ title: 'Existing Agent', slug: 'my-agent' }); + + const result = await model.transferAgent(agent1.id, wsId2, userId); + + expect(result.slug).toBe('my-agent-1'); + + const updated = await serverDB.query.agents.findFirst({ + where: eq(agents.id, agent1.id), + }); + expect(updated?.slug).toBe('my-agent-1'); + }); + + it('should update related sessions and agentsToSessions', async () => { + const model = new AgentModel(serverDB, userId); + const agent = await model.create({ title: 'Agent' }); + + // Create a session linked to the agent + await serverDB.insert(sessions).values({ id: 'sess-1', userId, type: 'agent' }); + await serverDB + .insert(agentsToSessions) + .values({ agentId: agent.id, sessionId: 'sess-1', userId }); + + await model.transferAgent(agent.id, wsId1, userId); + + const [session] = await serverDB.select().from(sessions).where(eq(sessions.id, 'sess-1')); + expect(session.workspaceId).toBe(wsId1); + + const [link] = await serverDB + .select() + .from(agentsToSessions) + .where(eq(agentsToSessions.agentId, agent.id)); + expect(link.workspaceId).toBe(wsId1); + }); + + it('should update topics and messages', async () => { + const model = new AgentModel(serverDB, userId); + const agent = await model.create({ title: 'Agent' }); + + await serverDB.insert(topics).values({ id: 'topic-1', agentId: agent.id, userId }); + await serverDB + .insert(messages) + .values({ id: 'msg-1', agentId: agent.id, userId, role: 'assistant' }); + + await model.transferAgent(agent.id, wsId1, userId); + + const [topic] = await serverDB.select().from(topics).where(eq(topics.id, 'topic-1')); + expect(topic.workspaceId).toBe(wsId1); + + const [msg] = await serverDB.select().from(messages).where(eq(messages.id, 'msg-1')); + expect(msg.workspaceId).toBe(wsId1); + }); + + it('should update bot providers', async () => { + const model = new AgentModel(serverDB, userId); + const agent = await model.create({ title: 'Agent' }); + + await serverDB.insert(agentBotProviders).values({ + agentId: agent.id, + userId, + platform: 'discord', + applicationId: 'app-1', + credentials: 'encrypted-creds', + }); + + await model.transferAgent(agent.id, wsId1, userId); + + const [bot] = await serverDB + .select() + .from(agentBotProviders) + .where(eq(agentBotProviders.agentId, agent.id)); + expect(bot.workspaceId).toBe(wsId1); + expect(bot.userId).toBe(userId); + }); + + it('should remove chat group associations', async () => { + const model = new AgentModel(serverDB, userId); + const agent = await model.create({ title: 'Agent' }); + + await serverDB.insert(chatGroups).values({ id: 'group-1', userId }); + await serverDB + .insert(chatGroupsAgents) + .values({ chatGroupId: 'group-1', agentId: agent.id, userId }); + + await model.transferAgent(agent.id, wsId1, userId); + + const groupLinks = await serverDB + .select() + .from(chatGroupsAgents) + .where(eq(chatGroupsAgents.agentId, agent.id)); + expect(groupLinks).toHaveLength(0); + }); + + it('should throw when agent not found', async () => { + const model = new AgentModel(serverDB, userId); + await expect(model.transferAgent('nonexistent', wsId1, userId)).rejects.toThrow( + 'Agent not found', + ); + }); +}); diff --git a/packages/database/src/models/__tests__/chatGroup.test.ts b/packages/database/src/models/__tests__/chatGroup.test.ts index b3228251da..92b0b27642 100644 --- a/packages/database/src/models/__tests__/chatGroup.test.ts +++ b/packages/database/src/models/__tests__/chatGroup.test.ts @@ -7,11 +7,18 @@ import type { LobeChatDatabase } from '@/database/type'; import { getTestDB } from '../../core/getTestDB'; import type { NewChatGroup } from '../../schemas'; -import { agents as agentsTable, chatGroups, chatGroupsAgents, users } from '../../schemas'; +import { + agents as agentsTable, + chatGroups, + chatGroupsAgents, + users, + workspaces, +} from '../../schemas'; import { ChatGroupModel } from '../chatGroup'; const userId = 'test-user'; const otherUserId = 'other-user'; +const workspaceId = 'chat-group-workspace'; const serverDB: LobeChatDatabase = await getTestDB(); @@ -26,11 +33,18 @@ type RelationAgent = { const toRelationAgents = (agents: unknown): RelationAgent[] => agents as RelationAgent[]; const chatGroupModel = new ChatGroupModel(serverDB, userId); +const workspaceChatGroupModel = new ChatGroupModel(serverDB, otherUserId, workspaceId); beforeEach(async () => { await serverDB.delete(users); // Create test users await serverDB.insert(users).values([{ id: userId }, { id: otherUserId }]); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Chat Group Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); }); afterEach(async () => { @@ -983,5 +997,32 @@ describe('ChatGroupModel', () => { expect(result[0].id).toBe('user-group'); expect(result[0].userId).toBe(userId); }); + + it('should return workspace groups for members even when rows were created by another user', async () => { + await serverDB.transaction(async (trx) => { + await trx.insert(chatGroups).values({ + id: 'workspace-group', + title: 'Workspace Group', + userId, + workspaceId, + }); + await trx.insert(agentsTable).values({ + id: 'workspace-agent', + title: 'Workspace Agent', + userId, + workspaceId, + }); + await trx.insert(chatGroupsAgents).values({ + agentId: 'workspace-agent', + chatGroupId: 'workspace-group', + userId, + workspaceId, + }); + }); + + const result = await workspaceChatGroupModel.getGroupsWithAgents(['workspace-agent']); + + expect(result).toEqual([expect.objectContaining({ id: 'workspace-group', workspaceId })]); + }); }); }); diff --git a/packages/database/src/models/__tests__/chunk.test.ts b/packages/database/src/models/__tests__/chunk.test.ts index 8c572e5f3f..5021d1c31d 100644 --- a/packages/database/src/models/__tests__/chunk.test.ts +++ b/packages/database/src/models/__tests__/chunk.test.ts @@ -5,7 +5,15 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { uuid } from '@/utils/uuid'; import { getTestDB } from '../../core/getTestDB'; -import { chunks, embeddings, fileChunks, files, unstructuredChunks, users } from '../../schemas'; +import { + chunks, + embeddings, + fileChunks, + files, + unstructuredChunks, + users, + workspaces, +} from '../../schemas'; import type { LobeChatDatabase } from '../../type'; import { ChunkModel } from '../chunk'; import { codeEmbedding, designThinkingQuery, designThinkingQuery2 } from './fixtures/embedding'; @@ -13,6 +21,7 @@ import { codeEmbedding, designThinkingQuery, designThinkingQuery2 } from './fixt const serverDB: LobeChatDatabase = await getTestDB(); const userId = 'chunk-model-test-user-id'; +const workspaceId = 'chunk-model-workspace'; const chunkModel = new ChunkModel(serverDB, userId); const sharedFileList = [ { @@ -44,6 +53,12 @@ const sharedFileList = [ beforeEach(async () => { await serverDB.delete(users); await serverDB.insert(users).values([{ id: userId }]); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Chunk Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); await serverDB.insert(files).values(sharedFileList); }); @@ -382,6 +397,27 @@ describe('ChunkModel', () => { expect(result).toHaveLength(0); }); + + it('should not count workspace chunks from personal scope', async () => { + await serverDB.insert(files).values({ + id: 'workspace-file', + name: 'workspace.pdf', + url: 'https://example.com/workspace.pdf', + size: 1000, + fileType: 'application/pdf', + userId, + workspaceId, + }); + const [chunk] = await serverDB + .insert(chunks) + .values({ text: 'Workspace Chunk', userId, workspaceId }) + .returning(); + await serverDB + .insert(fileChunks) + .values({ chunkId: chunk.id, fileId: 'workspace-file', userId, workspaceId }); + + await expect(chunkModel.countByFileIds(['workspace-file'])).resolves.toHaveLength(0); + }); }); describe('countByFileId', () => { diff --git a/packages/database/src/models/__tests__/documentTransfer.test.ts b/packages/database/src/models/__tests__/documentTransfer.test.ts new file mode 100644 index 0000000000..81f7904808 --- /dev/null +++ b/packages/database/src/models/__tests__/documentTransfer.test.ts @@ -0,0 +1,206 @@ +// @vitest-environment node +import { eq, inArray } from 'drizzle-orm'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { DOCUMENT_FOLDER_TYPE, documents, files, users, workspaces } from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { DocumentModel } from '../document'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'doc-transfer-test-user'; +const wsId1 = 'doc-transfer-test-ws-1'; +const wsId2 = 'doc-transfer-test-ws-2'; + +const createFolder = async ( + model: DocumentModel, + filename: string, + slug: string, + parentId?: string, +) => + model.create({ + content: '', + fileType: DOCUMENT_FOLDER_TYPE, + filename, + parentId, + slug, + source: '', + sourceType: 'api', + title: filename, + totalCharCount: 0, + totalLineCount: 0, + }); + +const createPage = async ( + model: DocumentModel, + filename: string, + slug: string, + parentId?: string, +) => + model.create({ + content: 'hello', + fileType: 'page', + filename, + parentId, + slug, + source: '', + sourceType: 'api', + title: filename, + totalCharCount: 5, + totalLineCount: 1, + }); + +beforeEach(async () => { + await serverDB.delete(users); + await serverDB.insert(users).values([{ id: userId }]); + await serverDB.insert(workspaces).values([ + { id: wsId1, name: 'Doc WS 1', slug: 'doc-ws-1', primaryOwnerId: userId }, + { id: wsId2, name: 'Doc WS 2', slug: 'doc-ws-2', primaryOwnerId: userId }, + ]); +}); + +afterEach(async () => { + await serverDB.delete(users); +}); + +describe('DocumentModel.transferTo', () => { + it('transfers a single page from personal to workspace', async () => { + const model = new DocumentModel(serverDB, userId); + const page = await createPage(model, 'My Page', 'my-page'); + + const result = await model.transferTo(page.id, wsId1, userId); + + expect(result.documentIds).toEqual([page.id]); + const updated = await serverDB.query.documents.findFirst({ where: eq(documents.id, page.id) }); + expect(updated?.workspaceId).toBe(wsId1); + expect(updated?.userId).toBe(userId); + }); + + it('transfers a folder and all descendants', async () => { + const model = new DocumentModel(serverDB, userId); + const folder = await createFolder(model, 'Folder', 'folder-1'); + const child = await createPage(model, 'Child', 'child-1', folder.id); + const subFolder = await createFolder(model, 'Sub', 'sub-1', folder.id); + const grandchild = await createPage(model, 'Grand', 'grand-1', subFolder.id); + + const result = await model.transferTo(folder.id, wsId1, userId); + + expect(result.documentIds.sort()).toEqual( + [folder.id, child.id, subFolder.id, grandchild.id].sort(), + ); + + const rows = await serverDB + .select({ id: documents.id, workspaceId: documents.workspaceId }) + .from(documents) + .where(inArray(documents.id, result.documentIds)); + for (const row of rows) expect(row.workspaceId).toBe(wsId1); + }); + + it('resolves slug conflicts by suffixing', async () => { + const ws1 = new DocumentModel(serverDB, userId, wsId1); + await createPage(ws1, 'Existing', 'shared-slug'); + + const personal = new DocumentModel(serverDB, userId); + const mine = await createPage(personal, 'Mine', 'shared-slug'); + + await personal.transferTo(mine.id, wsId1, userId); + + const updated = await serverDB.query.documents.findFirst({ where: eq(documents.id, mine.id) }); + expect(updated?.slug).toBe('shared-slug-1'); + expect(updated?.workspaceId).toBe(wsId1); + }); + + it('moves files anchored to documents in the transferred subtree', async () => { + const model = new DocumentModel(serverDB, userId); + const folder = await createFolder(model, 'Folder', 'transfer-folder'); + + await serverDB.insert(files).values({ + id: 'file-x', + userId, + fileType: 'image/png', + name: 'pic.png', + size: 10, + url: 'http://x', + parentId: folder.id, + }); + + await model.transferTo(folder.id, wsId1, userId); + + const [file] = await serverDB.select().from(files).where(eq(files.id, 'file-x')); + expect(file.workspaceId).toBe(wsId1); + expect(file.userId).toBe(userId); + }); + + it('transfers from workspace back to personal', async () => { + const ws = new DocumentModel(serverDB, userId, wsId1); + const page = await createPage(ws, 'In WS', 'in-ws'); + + await ws.transferTo(page.id, null, userId); + + const updated = await serverDB.query.documents.findFirst({ where: eq(documents.id, page.id) }); + expect(updated?.workspaceId).toBeNull(); + }); +}); + +describe('DocumentModel.copyToWorkspace', () => { + it('clones a single page into the target workspace with a fresh id', async () => { + const model = new DocumentModel(serverDB, userId); + const page = await createPage(model, 'Page', 'page-x'); + + const { rootId } = await model.copyToWorkspace(page.id, wsId1, userId); + + expect(rootId).not.toBe(page.id); + const clone = await serverDB.query.documents.findFirst({ where: eq(documents.id, rootId) }); + expect(clone?.workspaceId).toBe(wsId1); + expect(clone?.title).toBe('Page'); + expect(clone?.content).toBe('hello'); + + // Original untouched + const original = await serverDB.query.documents.findFirst({ where: eq(documents.id, page.id) }); + expect(original?.workspaceId).toBeNull(); + }); + + it('clones a folder + descendants preserving the parent topology', async () => { + const model = new DocumentModel(serverDB, userId); + const folder = await createFolder(model, 'Folder', 'copy-folder'); + const child = await createPage(model, 'Child', 'copy-child', folder.id); + const sub = await createFolder(model, 'Sub', 'copy-sub', folder.id); + const grand = await createPage(model, 'Grand', 'copy-grand', sub.id); + + const { rootId } = await model.copyToWorkspace(folder.id, wsId1, userId); + + const cloned = await serverDB.select().from(documents).where(eq(documents.workspaceId, wsId1)); + + expect(cloned).toHaveLength(4); + const root = cloned.find((d) => d.id === rootId)!; + expect(root.parentId).toBeNull(); + + const childrenOfRoot = cloned.filter((d) => d.parentId === rootId); + expect(childrenOfRoot).toHaveLength(2); + + // Locate cloned sub folder, then grandchild beneath it + const clonedSub = childrenOfRoot.find((d) => d.title === 'Sub')!; + const clonedGrand = cloned.find((d) => d.parentId === clonedSub.id)!; + expect(clonedGrand.title).toBe('Grand'); + + // Verify originals untouched + const originals = await serverDB + .select() + .from(documents) + .where(inArray(documents.id, [folder.id, child.id, sub.id, grand.id])); + for (const row of originals) expect(row.workspaceId).toBeNull(); + }); + + it('reassigns slug on conflict in target scope', async () => { + const ws1 = new DocumentModel(serverDB, userId, wsId1); + await createPage(ws1, 'Existing', 'dupe-slug'); + + const personal = new DocumentModel(serverDB, userId); + const mine = await createPage(personal, 'Mine', 'dupe-slug'); + + const { rootId } = await personal.copyToWorkspace(mine.id, wsId1, userId); + const clone = await serverDB.query.documents.findFirst({ where: eq(documents.id, rootId) }); + expect(clone?.slug).toBe('dupe-slug-1'); + }); +}); diff --git a/packages/database/src/models/__tests__/generation.test.ts b/packages/database/src/models/__tests__/generation.test.ts index 60072ccb93..f04cc9a2b9 100644 --- a/packages/database/src/models/__tests__/generation.test.ts +++ b/packages/database/src/models/__tests__/generation.test.ts @@ -13,6 +13,7 @@ import { generations, generationTopics, users, + workspaces, } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; import { GenerationModel } from '../generation'; @@ -37,6 +38,7 @@ vi.mock('../file', () => ({ const userId = 'generation-test-user-id'; const otherUserId = 'other-user-id'; +const workspaceId = 'generation-workspace'; const generationModel = new GenerationModel(serverDB, userId); // Test data @@ -101,6 +103,12 @@ beforeEach(async () => { // Clear database and create test users await serverDB.delete(users); await serverDB.insert(users).values([{ id: userId }, { id: otherUserId }]); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Generation Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); // Create test topic await serverDB.insert(generationTopics).values(testTopic); @@ -956,5 +964,36 @@ describe('GenerationModel', () => { ); expect(result).toBeUndefined(); }); + + it('should not return workspace generation from personal scope', async () => { + const workspaceAsyncTaskId = '550e8400-e29b-41d4-a716-446655440111'; + await serverDB.insert(generationTopics).values({ + ...testTopic, + id: 'workspace-topic-id', + workspaceId, + }); + await serverDB.insert(generationBatches).values({ + ...testBatch, + id: 'workspace-batch-id', + generationTopicId: 'workspace-topic-id', + workspaceId, + }); + await serverDB.insert(asyncTasks).values({ + ...testAsyncTask, + id: workspaceAsyncTaskId, + workspaceId, + }); + await serverDB.insert(generations).values({ + ...testGeneration, + asyncTaskId: workspaceAsyncTaskId, + generationBatchId: 'workspace-batch-id', + userId, + workspaceId, + }); + + await expect( + generationModel.findByAsyncTaskId(workspaceAsyncTaskId), + ).resolves.toBeUndefined(); + }); }); }); diff --git a/packages/database/src/models/__tests__/knowledgeBase.test.ts b/packages/database/src/models/__tests__/knowledgeBase.test.ts index 92f48352bb..069d7c6b70 100644 --- a/packages/database/src/models/__tests__/knowledgeBase.test.ts +++ b/packages/database/src/models/__tests__/knowledgeBase.test.ts @@ -13,6 +13,7 @@ import { knowledgeBaseFiles, knowledgeBases, users, + workspaces, } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; import { KnowledgeBaseModel } from '../knowledgeBase'; @@ -156,6 +157,15 @@ describe('KnowledgeBaseModel', () => { }, ]; + const createWorkspace = async (id: string, slug: string) => { + await serverDB.insert(workspaces).values({ + id, + name: slug, + primaryOwnerId: userId, + slug, + }); + }; + describe('addFilesToKnowledgeBase', () => { it('should add files to a knowledge base', async () => { await serverDB.insert(globalFiles).values([ @@ -683,30 +693,26 @@ describe('KnowledgeBaseModel', () => { }); it('should return empty array when all files are shared', async () => { - await serverDB - .insert(globalFiles) - .values([ - { - hashId: 'hash1', - url: 'https://example.com/a.pdf', - size: 100, - fileType: 'application/pdf', - creator: userId, - }, - ]); - await serverDB - .insert(files) - .values([ - { - id: 'file1', - name: 'a.pdf', - url: 'https://example.com/a.pdf', - fileHash: 'hash1', - size: 100, - fileType: 'application/pdf', - userId, - }, - ]); + await serverDB.insert(globalFiles).values([ + { + hashId: 'hash1', + url: 'https://example.com/a.pdf', + size: 100, + fileType: 'application/pdf', + creator: userId, + }, + ]); + await serverDB.insert(files).values([ + { + id: 'file1', + name: 'a.pdf', + url: 'https://example.com/a.pdf', + fileHash: 'hash1', + size: 100, + fileType: 'application/pdf', + userId, + }, + ]); const { id: kb1 } = await knowledgeBaseModel.create({ name: 'KB1' }); const { id: kb2 } = await knowledgeBaseModel.create({ name: 'KB2' }); await knowledgeBaseModel.addFilesToKnowledgeBase(kb1, ['file1']); @@ -767,30 +773,26 @@ describe('KnowledgeBaseModel', () => { describe('deleteWithFiles', () => { it('should delete KB and its exclusive files', async () => { - await serverDB - .insert(globalFiles) - .values([ - { - hashId: 'hash1', - url: 'https://example.com/a.pdf', - size: 100, - fileType: 'application/pdf', - creator: userId, - }, - ]); - await serverDB - .insert(files) - .values([ - { - id: 'file1', - name: 'a.pdf', - url: 'https://example.com/a.pdf', - fileHash: 'hash1', - size: 100, - fileType: 'application/pdf', - userId, - }, - ]); + await serverDB.insert(globalFiles).values([ + { + hashId: 'hash1', + url: 'https://example.com/a.pdf', + size: 100, + fileType: 'application/pdf', + creator: userId, + }, + ]); + await serverDB.insert(files).values([ + { + id: 'file1', + name: 'a.pdf', + url: 'https://example.com/a.pdf', + fileHash: 'hash1', + size: 100, + fileType: 'application/pdf', + userId, + }, + ]); const { id: kbId } = await knowledgeBaseModel.create({ name: 'KB1' }); await knowledgeBaseModel.addFilesToKnowledgeBase(kbId, ['file1']); const result = await knowledgeBaseModel.deleteWithFiles(kbId); @@ -931,30 +933,26 @@ describe('KnowledgeBaseModel', () => { }); it('should delete shared file when both KBs sharing it are deleted', async () => { - await serverDB - .insert(globalFiles) - .values([ - { - hashId: 'hash1', - url: 'https://example.com/a.pdf', - size: 100, - fileType: 'application/pdf', - creator: userId, - }, - ]); - await serverDB - .insert(files) - .values([ - { - id: 'file1', - name: 'a.pdf', - url: 'https://example.com/a.pdf', - fileHash: 'hash1', - size: 100, - fileType: 'application/pdf', - userId, - }, - ]); + await serverDB.insert(globalFiles).values([ + { + hashId: 'hash1', + url: 'https://example.com/a.pdf', + size: 100, + fileType: 'application/pdf', + creator: userId, + }, + ]); + await serverDB.insert(files).values([ + { + id: 'file1', + name: 'a.pdf', + url: 'https://example.com/a.pdf', + fileHash: 'hash1', + size: 100, + fileType: 'application/pdf', + userId, + }, + ]); const { id: kb1 } = await knowledgeBaseModel.create({ name: 'KB1' }); const { id: kb2 } = await knowledgeBaseModel.create({ name: 'KB2' }); await knowledgeBaseModel.addFilesToKnowledgeBase(kb1, ['file1']); @@ -976,6 +974,189 @@ describe('KnowledgeBaseModel', () => { }); }); + describe('transferTo', () => { + it('should transfer a knowledge base and its resources to another workspace', async () => { + await createWorkspace('workspace-target', 'workspace-target'); + await serverDB.insert(globalFiles).values([ + { + hashId: 'hash-transfer', + url: 'https://example.com/transfer.pdf', + size: 1000, + fileType: 'application/pdf', + creator: userId, + }, + ]); + await serverDB.insert(files).values({ + id: 'file-transfer', + name: 'transfer.pdf', + url: 'https://example.com/transfer.pdf', + fileHash: 'hash-transfer', + size: 1000, + fileType: 'application/pdf', + userId, + }); + + const { id: knowledgeBaseId } = await knowledgeBaseModel.create({ name: 'Transfer KB' }); + await serverDB.insert(documents).values({ + id: 'docs_transfer_folder', + title: 'Folder', + content: '', + fileType: 'custom/folder', + totalCharCount: 0, + totalLineCount: 0, + sourceType: 'api', + source: '', + knowledgeBaseId, + userId, + }); + await knowledgeBaseModel.addFilesToKnowledgeBase(knowledgeBaseId, ['file-transfer']); + + await knowledgeBaseModel.transferTo(knowledgeBaseId, 'workspace-target', userId); + + const transferredKb = await serverDB.query.knowledgeBases.findFirst({ + where: eq(knowledgeBases.id, knowledgeBaseId), + }); + const transferredFile = await serverDB.query.files.findFirst({ + where: eq(files.id, 'file-transfer'), + }); + const transferredDocument = await serverDB.query.documents.findFirst({ + where: eq(documents.id, 'docs_transfer_folder'), + }); + const transferredLink = await serverDB.query.knowledgeBaseFiles.findFirst({ + where: eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId), + }); + + expect(transferredKb?.workspaceId).toBe('workspace-target'); + expect(transferredFile?.workspaceId).toBe('workspace-target'); + expect(transferredDocument?.workspaceId).toBe('workspace-target'); + expect(transferredLink?.workspaceId).toBe('workspace-target'); + }); + + it('should rename the transferred knowledge base when the target has the same name', async () => { + await createWorkspace('workspace-rename-target', 'workspace-rename-target'); + const targetModel = new KnowledgeBaseModel(serverDB, userId, 'workspace-rename-target'); + await targetModel.create({ name: 'Shared KB' }); + const { id: knowledgeBaseId } = await knowledgeBaseModel.create({ name: 'Shared KB' }); + + await knowledgeBaseModel.transferTo(knowledgeBaseId, 'workspace-rename-target', userId); + + const transferredKb = await serverDB.query.knowledgeBases.findFirst({ + where: eq(knowledgeBases.id, knowledgeBaseId), + }); + + expect(transferredKb?.name).toBe('Shared KB (1)'); + }); + }); + + describe('copyToWorkspace', () => { + it('should copy a knowledge base with files and document hierarchy to another workspace', async () => { + await createWorkspace('workspace-copy-target', 'workspace-copy-target'); + await serverDB.insert(globalFiles).values([ + { + hashId: 'hash-copy', + url: 'https://example.com/copy.pdf', + size: 1000, + fileType: 'application/pdf', + creator: userId, + }, + ]); + + const { id: knowledgeBaseId } = await knowledgeBaseModel.create({ name: 'Copy KB' }); + await serverDB.insert(documents).values([ + { + id: 'docs_copy_folder', + title: 'Folder', + content: '', + fileType: 'custom/folder', + totalCharCount: 0, + totalLineCount: 0, + sourceType: 'api', + source: '', + knowledgeBaseId, + userId, + }, + { + id: 'docs_copy_note', + title: 'Note', + content: 'note content', + fileType: 'custom/document', + totalCharCount: 12, + totalLineCount: 1, + sourceType: 'api', + source: '', + knowledgeBaseId, + parentId: 'docs_copy_folder', + userId, + }, + ]); + await serverDB.insert(files).values({ + id: 'file-copy', + name: 'copy.pdf', + url: 'https://example.com/copy.pdf', + fileHash: 'hash-copy', + size: 1000, + fileType: 'application/pdf', + parentId: 'docs_copy_folder', + userId, + }); + await knowledgeBaseModel.addFilesToKnowledgeBase(knowledgeBaseId, ['file-copy']); + + const result = await knowledgeBaseModel.copyToWorkspace( + knowledgeBaseId, + 'workspace-copy-target', + userId, + ); + + expect(result.id).not.toBe(knowledgeBaseId); + + const copiedKb = await serverDB.query.knowledgeBases.findFirst({ + where: eq(knowledgeBases.id, result.id), + }); + const copiedLinks = await serverDB.query.knowledgeBaseFiles.findMany({ + where: eq(knowledgeBaseFiles.knowledgeBaseId, result.id), + }); + const copiedDocs = await serverDB.query.documents.findMany({ + where: eq(documents.knowledgeBaseId, result.id), + }); + const originalKb = await serverDB.query.knowledgeBases.findFirst({ + where: eq(knowledgeBases.id, knowledgeBaseId), + }); + + expect(copiedKb).toMatchObject({ + name: 'Copy KB', + workspaceId: 'workspace-copy-target', + }); + expect(copiedLinks).toHaveLength(1); + expect(copiedLinks[0].fileId).not.toBe('file-copy'); + expect(copiedLinks[0].workspaceId).toBe('workspace-copy-target'); + expect(copiedDocs).toHaveLength(2); + expect(copiedDocs.every((doc) => doc.workspaceId === 'workspace-copy-target')).toBe(true); + expect(copiedDocs.find((doc) => doc.title === 'Note')?.parentId).toBe( + copiedDocs.find((doc) => doc.title === 'Folder')?.id, + ); + expect(originalKb?.workspaceId).toBeNull(); + }); + + it('should rename the copied knowledge base when the target has the same name', async () => { + await createWorkspace('workspace-copy-rename-target', 'workspace-copy-rename-target'); + const targetModel = new KnowledgeBaseModel(serverDB, userId, 'workspace-copy-rename-target'); + await targetModel.create({ name: 'Shared KB' }); + const { id: knowledgeBaseId } = await knowledgeBaseModel.create({ name: 'Shared KB' }); + + const result = await knowledgeBaseModel.copyToWorkspace( + knowledgeBaseId, + 'workspace-copy-rename-target', + userId, + ); + + const copiedKb = await serverDB.query.knowledgeBases.findFirst({ + where: eq(knowledgeBases.id, result.id), + }); + + expect(copiedKb?.name).toBe('Shared KB (1)'); + }); + }); + describe('static findById', () => { it('should find a knowledge base by id without user restriction', async () => { const { id } = await knowledgeBaseModel.create({ name: 'Test Group' }); diff --git a/packages/database/src/models/__tests__/messages/message.update.test.ts b/packages/database/src/models/__tests__/messages/message.update.test.ts index aab40b2a51..300d9ba191 100644 --- a/packages/database/src/models/__tests__/messages/message.update.test.ts +++ b/packages/database/src/models/__tests__/messages/message.update.test.ts @@ -15,6 +15,7 @@ import { sessions, topics, users, + workspaces, } from '../../../schemas'; import type { LobeChatDatabase } from '../../../type'; import { MessageModel } from '../../message'; @@ -24,7 +25,9 @@ const serverDB: LobeChatDatabase = await getTestDB(); const userId = 'message-update-test'; const otherUserId = 'message-update-test-other'; +const workspaceId = 'message-update-workspace'; const messageModel = new MessageModel(serverDB, userId); +const workspaceMessageModel = new MessageModel(serverDB, otherUserId, workspaceId); const embeddingsId = uuid(); beforeEach(async () => { @@ -33,6 +36,12 @@ beforeEach(async () => { await trx.delete(users).where(eq(users.id, userId)); await trx.delete(users).where(eq(users.id, otherUserId)); await trx.insert(users).values([{ id: userId }, { id: otherUserId }]); + await trx.insert(workspaces).values({ + id: workspaceId, + name: 'Message Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); await trx.insert(sessions).values([ // { id: 'session1', userId }, @@ -950,6 +959,30 @@ describe('MessageModel Update Tests', () => { expect(dbResult[0].metadata).toEqual({ originalKey: 'originalValue' }); }); + it('should update workspace messages even when created by another user', async () => { + await serverDB.insert(messages).values({ + id: 'msg-workspace-metadata', + userId, + workspaceId, + role: 'user', + content: 'test message', + metadata: { originalKey: 'originalValue' }, + }); + + await workspaceMessageModel.updateMetadata('msg-workspace-metadata', { + workspaceKey: 'workspaceValue', + }); + + const dbResult = await serverDB + .select() + .from(messages) + .where(eq(messages.id, 'msg-workspace-metadata')); + expect(dbResult[0].metadata).toEqual({ + originalKey: 'originalValue', + workspaceKey: 'workspaceValue', + }); + }); + it('should handle complex nested metadata updates', async () => { // Create test data await serverDB.insert(messages).values({ @@ -1273,6 +1306,33 @@ describe('MessageModel Update Tests', () => { expect(result[0].content).toBe('translated message 1'); }); + it('should insert workspaceId for workspace translate records', async () => { + await serverDB.insert(messages).values({ + id: 'workspace-translate', + userId, + workspaceId, + role: 'user', + content: 'message 1', + }); + + await workspaceMessageModel.updateTranslate('workspace-translate', { + content: 'translated message 1', + from: 'en', + to: 'zh', + }); + + const result = await serverDB + .select() + .from(messageTranslates) + .where(eq(messageTranslates.id, 'workspace-translate')); + + expect(result[0]).toMatchObject({ + id: 'workspace-translate', + userId: otherUserId, + workspaceId, + }); + }); + it('should update the corresponding fields if message exists in messageTranslates table', async () => { // Create test data await serverDB.transaction(async (trx) => { @@ -1314,6 +1374,29 @@ describe('MessageModel Update Tests', () => { expect(result[0].voice).toBe('voice1'); }); + it('should insert workspaceId for workspace TTS records', async () => { + await serverDB.insert(messages).values({ + id: 'workspace-tts', + userId, + workspaceId, + role: 'user', + content: 'message 1', + }); + + await workspaceMessageModel.updateTTS('workspace-tts', { + contentMd5: 'md5', + file: 'f1', + voice: 'voice1', + }); + + const result = await serverDB + .select() + .from(messageTTS) + .where(eq(messageTTS.id, 'workspace-tts')); + + expect(result[0]).toMatchObject({ id: 'workspace-tts', userId: otherUserId, workspaceId }); + }); + it('should update the corresponding fields if message exists in messageTTS table', async () => { // Create test data await serverDB.transaction(async (trx) => { diff --git a/packages/database/src/models/__tests__/messages/message.workspace.test.ts b/packages/database/src/models/__tests__/messages/message.workspace.test.ts new file mode 100644 index 0000000000..c65f6c1e81 --- /dev/null +++ b/packages/database/src/models/__tests__/messages/message.workspace.test.ts @@ -0,0 +1,73 @@ +// @vitest-environment node +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../../core/getTestDB'; +import { messages, sessions, topics, users, workspaces } from '../../../schemas'; +import type { LobeChatDatabase } from '../../../type'; +import { MessageModel } from '../../message'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'message-workspace-user'; +const workspaceId = 'message-workspace'; + +beforeEach(async () => { + await serverDB.delete(users); + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); + await serverDB.insert(sessions).values([ + { id: 'personal-session', userId, workspaceId: null }, + { id: 'workspace-session', userId, workspaceId }, + ]); + await serverDB.insert(topics).values([ + { id: 'personal-topic', sessionId: 'personal-session', userId, workspaceId: null }, + { id: 'workspace-topic', sessionId: 'workspace-session', userId, workspaceId }, + ]); +}); + +afterEach(async () => { + await serverDB.delete(users); +}); + +describe('MessageModel workspace scope', () => { + it('isolates personal and workspace messages for the same user', async () => { + await serverDB.insert(messages).values([ + { + content: 'personal', + id: 'personal-message', + role: 'user', + sessionId: 'personal-session', + topicId: 'personal-topic', + userId, + workspaceId: null, + }, + { + content: 'workspace', + id: 'workspace-message', + role: 'user', + sessionId: 'workspace-session', + topicId: 'workspace-topic', + userId, + workspaceId, + }, + ]); + + await expect( + new MessageModel(serverDB, userId).query({ + sessionId: 'personal-session', + topicId: 'personal-topic', + }), + ).resolves.toEqual([expect.objectContaining({ id: 'personal-message' })]); + await expect( + new MessageModel(serverDB, userId, workspaceId).query({ + sessionId: 'workspace-session', + topicId: 'workspace-topic', + }), + ).resolves.toEqual([expect.objectContaining({ id: 'workspace-message' })]); + }); +}); diff --git a/packages/database/src/models/__tests__/messengerAccountLink.test.ts b/packages/database/src/models/__tests__/messengerAccountLink.test.ts index 4f3d312d01..8dcb10a8f0 100644 --- a/packages/database/src/models/__tests__/messengerAccountLink.test.ts +++ b/packages/database/src/models/__tests__/messengerAccountLink.test.ts @@ -2,7 +2,7 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDB } from '../../core/getTestDB'; -import { agents, messengerAccountLinks, users } from '../../schemas'; +import { agents, messengerAccountLinks, users, workspaces } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; import { MessengerAccountLinkConflictError, @@ -16,19 +16,29 @@ const userA = 'msg-link-user-a'; const userB = 'msg-link-user-b'; const agentA = 'msg-link-agent-a'; const agentB = 'msg-link-agent-b'; +const workspaceA = 'msg-link-workspace-a'; +const workspaceAgentA = 'msg-link-agent-workspace-a'; beforeEach(async () => { await serverDB.delete(users); await serverDB.insert(users).values([{ id: userA }, { id: userB }]); + await serverDB.insert(workspaces).values({ + id: workspaceA, + name: 'Workspace A', + primaryOwnerId: userA, + slug: 'workspace-a', + }); await serverDB.insert(agents).values([ { id: agentA, userId: userA }, { id: agentB, userId: userB }, + { id: workspaceAgentA, userId: userA, workspaceId: workspaceA }, ]); }); afterEach(async () => { await serverDB.delete(messengerAccountLinks); await serverDB.delete(agents); + await serverDB.delete(workspaces); await serverDB.delete(users); }); @@ -225,6 +235,56 @@ describe('MessengerAccountLinkModel', () => { }); }); + describe('active scope (workspaceId)', () => { + // A given IM identity has exactly one link; `workspaceId` on it is the + // *active scope* derived from the active agent (personal → null), not part + // of the link's identity. Switching scope reuses the same row. + it('persists the active scope passed at upsert time', async () => { + const model = new MessengerAccountLinkModel(serverDB, userA); + const personal = await model.upsertForPlatform({ + activeAgentId: agentA, + platform: 'telegram', + platformUserId: 'tg-scope', + workspaceId: null, + }); + expect(personal.workspaceId).toBeNull(); + + // Re-asserting the same identity with a workspace agent flips the active + // scope on the same row — no relink, single identity link. + const switched = await model.upsertForPlatform({ + activeAgentId: workspaceAgentA, + platform: 'telegram', + platformUserId: 'tg-scope', + workspaceId: workspaceA, + }); + expect(switched.id).toBe(personal.id); + expect(switched.workspaceId).toBe(workspaceA); + expect(switched.activeAgentId).toBe(workspaceAgentA); + }); + + it('setActiveAgent updates both the active agent and the derived scope', async () => { + const model = new MessengerAccountLinkModel(serverDB, userA); + await model.upsertForPlatform({ + activeAgentId: agentA, + platform: 'telegram', + platformUserId: 'tg-switch', + workspaceId: null, + }); + + // Switch into a workspace agent. + await model.setActiveAgent('telegram', workspaceAgentA, workspaceA); + let link = await model.findByPlatform('telegram'); + expect(link?.activeAgentId).toBe(workspaceAgentA); + expect(link?.workspaceId).toBe(workspaceA); + + // Switch back to personal. + await model.setActiveAgent('telegram', agentA, null); + link = await model.findByPlatform('telegram'); + expect(link?.activeAgentId).toBe(agentA); + expect(link?.workspaceId).toBeNull(); + }); + }); + describe('setActiveAgent', () => { it('only updates the targeted (platform, tenant) row', async () => { const model = new MessengerAccountLinkModel(serverDB, userA); @@ -241,7 +301,7 @@ describe('MessengerAccountLinkModel', () => { tenantId: 'T_BETA', }); - await model.setActiveAgent('slack', null, 'T_ACME'); + await model.setActiveAgent('slack', null, null, 'T_ACME'); const acme = await model.findByPlatform('slack', 'T_ACME'); const beta = await model.findByPlatform('slack', 'T_BETA'); diff --git a/packages/database/src/models/__tests__/notification.test.ts b/packages/database/src/models/__tests__/notification.test.ts new file mode 100644 index 0000000000..6700471246 --- /dev/null +++ b/packages/database/src/models/__tests__/notification.test.ts @@ -0,0 +1,43 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { NotificationModel } from '../../models/notification'; +import { notifications } from '../../schemas/notification'; +import type { LobeChatDatabase } from '../../type'; + +describe('NotificationModel', () => { + const returning = vi.fn(); + const onConflictDoNothing = vi.fn(() => ({ returning })); + const values = vi.fn((_payload?: unknown) => ({ onConflictDoNothing })); + const insert = vi.fn(() => ({ values })); + const db = { insert } as unknown as LobeChatDatabase; + + beforeEach(() => { + vi.clearAllMocks(); + returning.mockResolvedValue([{ id: 'notification-1' }]); + }); + + describe('create', () => { + it('creates user-scoped notifications without persisting workspace context', async () => { + const model = new NotificationModel(db, 'user-1'); + + await model.create({ + category: 'workspace', + content: 'You have been removed from the workspace.', + dedupeKey: 'member_removed_workspace-1_user-1', + title: 'Removed from workspace', + type: 'workspace_member_removed', + }); + + const [payload] = values.mock.calls[0]; + + expect(payload).toMatchObject({ + dedupeKey: 'member_removed_workspace-1_user-1', + userId: 'user-1', + }); + expect(payload).not.toHaveProperty('workspaceId'); + expect(onConflictDoNothing).toHaveBeenCalledWith({ + target: [notifications.userId, notifications.dedupeKey], + }); + }); + }); +}); diff --git a/packages/database/src/models/__tests__/ragEval.workspace.test.ts b/packages/database/src/models/__tests__/ragEval.workspace.test.ts new file mode 100644 index 0000000000..1dd5b24d8e --- /dev/null +++ b/packages/database/src/models/__tests__/ragEval.workspace.test.ts @@ -0,0 +1,204 @@ +// @vitest-environment node +import { eq } from 'drizzle-orm'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { + evalDatasetRecords, + evalDatasets, + evalEvaluation, + evaluationRecords, + knowledgeBases, + users, + workspaces, +} from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { + EvalDatasetModel, + EvalDatasetRecordModel, + EvalEvaluationModel, + EvaluationRecordModel, +} from '../ragEval'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'rag-eval-workspace-user'; +const workspaceId = 'rag-eval-workspace'; +const personalKnowledgeBaseId = 'rag-eval-personal-kb'; +const workspaceKnowledgeBaseId = 'rag-eval-workspace-kb'; + +beforeEach(async () => { + await serverDB.delete(users); + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'RAG Eval Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); + await serverDB.insert(knowledgeBases).values([ + { + id: personalKnowledgeBaseId, + name: 'Personal KB', + userId, + workspaceId: null, + }, + { + id: workspaceKnowledgeBaseId, + name: 'Workspace KB', + userId, + workspaceId, + }, + ]); +}); + +afterEach(async () => { + await serverDB.delete(users); +}); + +describe('RAG eval workspace scope', () => { + it('isolates datasets and dataset records between personal and workspace scopes', async () => { + const personalDatasetModel = new EvalDatasetModel(serverDB, userId); + const workspaceDatasetModel = new EvalDatasetModel(serverDB, userId, workspaceId); + + const personalDataset = await personalDatasetModel.create({ + knowledgeBaseId: personalKnowledgeBaseId, + name: 'Personal dataset', + }); + const workspaceDataset = await workspaceDatasetModel.create({ + knowledgeBaseId: workspaceKnowledgeBaseId, + name: 'Workspace dataset', + }); + + await expect(personalDatasetModel.query(personalKnowledgeBaseId)).resolves.toEqual([ + expect.objectContaining({ id: personalDataset.id }), + ]); + await expect(workspaceDatasetModel.query(workspaceKnowledgeBaseId)).resolves.toEqual([ + expect.objectContaining({ id: workspaceDataset.id }), + ]); + await expect(personalDatasetModel.findById(personalDataset.id)).resolves.toMatchObject({ + id: personalDataset.id, + workspaceId: null, + }); + await expect(workspaceDatasetModel.findById(workspaceDataset.id)).resolves.toMatchObject({ + id: workspaceDataset.id, + workspaceId, + }); + + const personalRecordModel = new EvalDatasetRecordModel(serverDB, userId); + const workspaceRecordModel = new EvalDatasetRecordModel(serverDB, userId, workspaceId); + + const personalRecord = await personalRecordModel.create({ + datasetId: personalDataset.id, + question: 'Personal question', + }); + const workspaceRecord = await workspaceRecordModel.create({ + datasetId: workspaceDataset.id, + question: 'Workspace question', + }); + + await expect(personalRecordModel.findById(workspaceRecord.id)).resolves.toBeUndefined(); + await expect(workspaceRecordModel.findById(personalRecord.id)).resolves.toBeUndefined(); + + await personalRecordModel.update(personalRecord.id, { question: 'Updated personal question' }); + await expect(personalRecordModel.findById(personalRecord.id)).resolves.toMatchObject({ + question: 'Updated personal question', + workspaceId: null, + }); + + await personalDatasetModel.delete(personalDataset.id); + + await expect(personalDatasetModel.findById(personalDataset.id)).resolves.toBeUndefined(); + await expect(workspaceDatasetModel.findById(workspaceDataset.id)).resolves.toMatchObject({ + id: workspaceDataset.id, + workspaceId, + }); + }); + + it('isolates evaluations and evaluation records between personal and workspace scopes', async () => { + const personalDatasetModel = new EvalDatasetModel(serverDB, userId); + const workspaceDatasetModel = new EvalDatasetModel(serverDB, userId, workspaceId); + const personalRecordModel = new EvalDatasetRecordModel(serverDB, userId); + const workspaceRecordModel = new EvalDatasetRecordModel(serverDB, userId, workspaceId); + const personalEvaluationModel = new EvalEvaluationModel(serverDB, userId); + const workspaceEvaluationModel = new EvalEvaluationModel(serverDB, userId, workspaceId); + const personalEvaluationRecordModel = new EvaluationRecordModel(serverDB, userId); + const workspaceEvaluationRecordModel = new EvaluationRecordModel(serverDB, userId, workspaceId); + + const personalDataset = await personalDatasetModel.create({ + knowledgeBaseId: personalKnowledgeBaseId, + name: 'Personal dataset', + }); + const workspaceDataset = await workspaceDatasetModel.create({ + knowledgeBaseId: workspaceKnowledgeBaseId, + name: 'Workspace dataset', + }); + const personalDatasetRecord = await personalRecordModel.create({ + datasetId: personalDataset.id, + question: 'Personal question', + }); + const workspaceDatasetRecord = await workspaceRecordModel.create({ + datasetId: workspaceDataset.id, + question: 'Workspace question', + }); + + const personalEvaluation = await personalEvaluationModel.create({ + datasetId: personalDataset.id, + knowledgeBaseId: personalKnowledgeBaseId, + name: 'Personal evaluation', + }); + const workspaceEvaluation = await workspaceEvaluationModel.create({ + datasetId: workspaceDataset.id, + knowledgeBaseId: workspaceKnowledgeBaseId, + name: 'Workspace evaluation', + }); + + await expect( + personalEvaluationModel.queryByKnowledgeBaseId(personalKnowledgeBaseId), + ).resolves.toEqual([expect.objectContaining({ id: personalEvaluation.id })]); + await expect( + workspaceEvaluationModel.queryByKnowledgeBaseId(workspaceKnowledgeBaseId), + ).resolves.toEqual([expect.objectContaining({ id: workspaceEvaluation.id })]); + await expect(personalEvaluationModel.findById(personalEvaluation.id)).resolves.toMatchObject({ + id: personalEvaluation.id, + workspaceId: null, + }); + await expect(workspaceEvaluationModel.findById(workspaceEvaluation.id)).resolves.toMatchObject({ + id: workspaceEvaluation.id, + workspaceId, + }); + + const personalEvaluationRecord = await personalEvaluationRecordModel.create({ + datasetRecordId: personalDatasetRecord.id, + evaluationId: personalEvaluation.id, + question: 'Personal eval question', + }); + const workspaceEvaluationRecord = await workspaceEvaluationRecordModel.create({ + datasetRecordId: workspaceDatasetRecord.id, + evaluationId: workspaceEvaluation.id, + question: 'Workspace eval question', + }); + + await expect( + personalEvaluationRecordModel.findById(workspaceEvaluationRecord.id), + ).resolves.toBeUndefined(); + await expect( + workspaceEvaluationRecordModel.findById(personalEvaluationRecord.id), + ).resolves.toBeUndefined(); + + await personalEvaluationRecordModel.delete(personalEvaluationRecord.id); + + await expect(personalEvaluationRecordModel.query(personalEvaluation.id)).resolves.toEqual([]); + await expect(workspaceEvaluationRecordModel.query(workspaceEvaluation.id)).resolves.toEqual([ + expect.objectContaining({ id: workspaceEvaluationRecord.id, workspaceId }), + ]); + }); +}); + +afterEach(async () => { + await serverDB.delete(evaluationRecords).where(eq(evaluationRecords.userId, userId)); + await serverDB.delete(evalEvaluation).where(eq(evalEvaluation.userId, userId)); + await serverDB.delete(evalDatasetRecords).where(eq(evalDatasetRecords.userId, userId)); + await serverDB.delete(evalDatasets).where(eq(evalDatasets.userId, userId)); + await serverDB.delete(knowledgeBases).where(eq(knowledgeBases.userId, userId)); +}); diff --git a/packages/database/src/models/__tests__/rbac.test.ts b/packages/database/src/models/__tests__/rbac.test.ts new file mode 100644 index 0000000000..49dfc6e6c7 --- /dev/null +++ b/packages/database/src/models/__tests__/rbac.test.ts @@ -0,0 +1,205 @@ +// @vitest-environment node +import { + PERMISSION_ACTIONS, + WORKSPACE_ROLE_PERMISSIONS, + WORKSPACE_SYSTEM_ROLES, +} from '@lobechat/const/rbac'; +import { and, eq } from 'drizzle-orm'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { permissions, rolePermissions, roles, userRoles, users, workspaces } from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { seedWorkspaceRoles } from '../../utils/seedWorkspaceRoles'; +import { RbacModel } from '../rbac'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'rbac-model-test-user-id'; +const otherUserId = 'rbac-model-test-other-user-id'; +const workspaceAId = 'rbac-ws-a'; +const workspaceBId = 'rbac-ws-b'; + +const cleanup = async () => { + // userRoles + rolePermissions cascade via FK, but workspace-scoped roles only + // cascade when the workspace itself is deleted — so do it explicitly here. + await serverDB.delete(userRoles); + await serverDB.delete(rolePermissions); + await serverDB.delete(roles); + await serverDB.delete(permissions); + await serverDB.delete(workspaces); + await serverDB.delete(users); +}; + +beforeEach(async () => { + await cleanup(); + await serverDB.insert(users).values([{ id: userId }, { id: otherUserId }]); + await serverDB.insert(workspaces).values([ + { id: workspaceAId, name: 'A', primaryOwnerId: userId, slug: 'ws-a' }, + { id: workspaceBId, name: 'B', primaryOwnerId: userId, slug: 'ws-b' }, + ]); + await seedWorkspaceRoles(serverDB, workspaceAId); + await seedWorkspaceRoles(serverDB, workspaceBId); +}); + +afterEach(async () => { + await cleanup(); +}); + +describe('RbacModel — workspace scope', () => { + const ownerCode = `${PERMISSION_ACTIONS.WORKSPACE_UPDATE}:all`; + const memberCode = `${PERMISSION_ACTIONS.WORKSPACE_READ}:all`; + + describe('assignWorkspaceRole / hasPermission with workspaceId', () => { + it('returns true for a permission granted via the assigned role in that workspace', async () => { + const rbac = new RbacModel(serverDB, userId); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceAId, + }); + + expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceAId })).toBe(true); + }); + + it('returns false for a permission the assigned role does not include', async () => { + const rbac = new RbacModel(serverDB, userId); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.VIEWER, + userId, + workspaceId: workspaceAId, + }); + + // viewer never gets workspace:update:all (only owner does). + expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceAId })).toBe(false); + // but viewer does have workspace:read:all. + expect(await rbac.hasPermission(memberCode, { workspaceId: workspaceAId })).toBe(true); + }); + + it('does not leak permissions across workspaces', async () => { + const rbac = new RbacModel(serverDB, userId); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceAId, + }); + + expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceAId })).toBe(true); + expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceBId })).toBe(false); + }); + + it('is idempotent', async () => { + const rbac = new RbacModel(serverDB, userId); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceAId, + }); + // Re-assigning is a no-op thanks to the (userId, roleId, workspaceId) + // unique index — must not throw. + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceAId, + }); + + const grants = await serverDB.query.userRoles.findMany({ + where: and(eq(userRoles.userId, userId), eq(userRoles.workspaceId, workspaceAId)), + }); + expect(grants).toHaveLength(1); + }); + }); + + describe('revokeWorkspaceRole', () => { + it('drops every grant in the named workspace and leaves others untouched', async () => { + const rbac = new RbacModel(serverDB, userId); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceAId, + }); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceBId, + }); + + await rbac.revokeWorkspaceRole({ userId, workspaceId: workspaceAId }); + + expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceAId })).toBe(false); + expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceBId })).toBe(true); + }); + + it('is a no-op when the user has no grants in the workspace', async () => { + const rbac = new RbacModel(serverDB, userId); + await expect( + rbac.revokeWorkspaceRole({ userId, workspaceId: workspaceAId }), + ).resolves.not.toThrow(); + }); + }); + + describe('getUserPermissions with workspaceId', () => { + it('returns scoped codes for the named workspace, de-duped', async () => { + const rbac = new RbacModel(serverDB, userId); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceAId, + }); + + const codes = await rbac.getUserPermissions({ workspaceId: workspaceAId }); + + const expected = new Set(WORKSPACE_ROLE_PERMISSIONS[WORKSPACE_SYSTEM_ROLES.OWNER]); + // every code the owner role grants should appear in the result + for (const code of expected) { + expect(codes).toContain(code); + } + // ...and no duplicates + expect(codes).toHaveLength(new Set(codes).size); + }); + + it('does not include workspace B permissions when scoped to workspace A', async () => { + const rbac = new RbacModel(serverDB, userId); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceBId, + }); + // user has no grant in workspaceA + const codes = await rbac.getUserPermissions({ workspaceId: workspaceAId }); + expect(codes).toEqual([]); + }); + }); + + describe('listWorkspaceRoles', () => { + it('lists the three built-in roles seeded for that workspace', async () => { + const rbac = new RbacModel(serverDB, userId); + const list = await rbac.listWorkspaceRoles(workspaceAId); + const names = list.map((r) => r.name).sort(); + expect(names).toEqual( + [ + WORKSPACE_SYSTEM_ROLES.MEMBER, + WORKSPACE_SYSTEM_ROLES.OWNER, + WORKSPACE_SYSTEM_ROLES.VIEWER, + ].sort(), + ); + expect(list.every((r) => r.workspaceId === workspaceAId)).toBe(true); + }); + }); + + describe('back-compat: no workspaceId', () => { + it('still matches workspace-scoped grants when no workspaceId is given (legacy behavior)', async () => { + // Hono routes call `hasPermission(code)` without workspaceId. This must + // keep returning true for users whose only grant is workspace-scoped, + // otherwise every Hono content route regresses on workspace users. + const rbac = new RbacModel(serverDB, userId); + await rbac.assignWorkspaceRole({ + roleName: WORKSPACE_SYSTEM_ROLES.OWNER, + userId, + workspaceId: workspaceAId, + }); + + expect(await rbac.hasPermission(ownerCode)).toBe(true); + }); + }); +}); diff --git a/packages/database/src/models/__tests__/session.workspace.test.ts b/packages/database/src/models/__tests__/session.workspace.test.ts new file mode 100644 index 0000000000..f8c95bc873 --- /dev/null +++ b/packages/database/src/models/__tests__/session.workspace.test.ts @@ -0,0 +1,79 @@ +// @vitest-environment node +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { agents, agentsToSessions, sessions, users, workspaces } from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { SessionModel } from '../session'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const userId = 'session-workspace-user'; +const workspaceId = 'session-workspace'; + +beforeEach(async () => { + await serverDB.delete(users); + await serverDB.insert(users).values({ id: userId }); + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); +}); + +afterEach(async () => { + await serverDB.delete(users); +}); + +describe('SessionModel workspace scope', () => { + it('isolates personal and workspace sessions for the same user', async () => { + await serverDB.insert(sessions).values([ + { id: 'personal-session', updatedAt: new Date('2023-01-01'), userId, workspaceId: null }, + { + id: 'workspace-session', + updatedAt: new Date('2023-02-01'), + userId, + workspaceId, + }, + ]); + + await expect(new SessionModel(serverDB, userId).query()).resolves.toEqual([ + expect.objectContaining({ id: 'personal-session' }), + ]); + await expect(new SessionModel(serverDB, userId, workspaceId).query()).resolves.toEqual([ + expect.objectContaining({ id: 'workspace-session' }), + ]); + }); + + it('deleteAll on personal scope does not delete workspace sessions or links', async () => { + await serverDB.transaction(async (trx) => { + await trx.insert(sessions).values([ + { id: 'personal-session', updatedAt: new Date('2023-01-01'), userId, workspaceId: null }, + { + id: 'workspace-session', + updatedAt: new Date('2023-02-01'), + userId, + workspaceId, + }, + ]); + await trx.insert(agents).values([ + { id: 'personal-agent', userId, title: 'Personal Agent', workspaceId: null }, + { id: 'workspace-agent', userId, title: 'Workspace Agent', workspaceId }, + ]); + await trx.insert(agentsToSessions).values([ + { agentId: 'personal-agent', sessionId: 'personal-session', userId, workspaceId: null }, + { agentId: 'workspace-agent', sessionId: 'workspace-session', userId, workspaceId }, + ]); + }); + + await new SessionModel(serverDB, userId).deleteAll(); + + await expect(serverDB.select().from(sessions)).resolves.toEqual([ + expect.objectContaining({ id: 'workspace-session', workspaceId }), + ]); + await expect(serverDB.select().from(agentsToSessions)).resolves.toEqual([ + expect.objectContaining({ agentId: 'workspace-agent', sessionId: 'workspace-session' }), + ]); + }); +}); diff --git a/packages/database/src/models/__tests__/topics/topic.create.test.ts b/packages/database/src/models/__tests__/topics/topic.create.test.ts index 5e3db65cef..91c56c588c 100644 --- a/packages/database/src/models/__tests__/topics/topic.create.test.ts +++ b/packages/database/src/models/__tests__/topics/topic.create.test.ts @@ -2,7 +2,15 @@ import { asc, eq, inArray } from 'drizzle-orm'; import { afterEach, beforeEach, describe, expect, it } from 'vitest'; import { getTestDB } from '../../../core/getTestDB'; -import { agents, messagePlugins, messages, sessions, topics, users } from '../../../schemas'; +import { + agents, + messagePlugins, + messages, + sessions, + topics, + users, + workspaces, +} from '../../../schemas'; import type { LobeChatDatabase } from '../../../type'; import type { CreateTopicParams } from '../../topic'; import { TopicModel } from '../../topic'; @@ -95,6 +103,50 @@ describe('TopicModel - Create', () => { expect(unassociatedMessage[0].topicId).toBeNull(); }); + it('should associate workspace messages created by another member', async () => { + const workspaceId = 'topic-create-workspace'; + const workspaceSessionId = 'topic-create-workspace-session'; + const workspaceTopicModel = new TopicModel(serverDB, userId, workspaceId); + + await serverDB.transaction(async (tx) => { + await tx.insert(workspaces).values({ + id: workspaceId, + name: 'Topic Create Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); + await tx.insert(sessions).values({ + id: workspaceSessionId, + userId, + workspaceId, + }); + await tx.insert(messages).values({ + id: 'workspace-message-other-member', + role: 'user', + sessionId: workspaceSessionId, + userId: userId2, + workspaceId, + }); + }); + + const createdTopic = await workspaceTopicModel.create( + { + messages: ['workspace-message-other-member'], + sessionId: workspaceSessionId, + title: 'Workspace Topic', + }, + 'workspace-topic-created', + ); + + const [updatedMessage] = await serverDB + .select() + .from(messages) + .where(eq(messages.id, 'workspace-message-other-member')); + + expect(createdTopic.workspaceId).toBe(workspaceId); + expect(updatedMessage.topicId).toBe(createdTopic.id); + }); + it('should create a new topic without associating messages', async () => { const topicData = { title: 'New Topic', @@ -230,6 +282,62 @@ describe('TopicModel - Create', () => { expect(updatedMessages[2].topicId).toBe(createdTopics[1].id); }); + it('should batch associate workspace messages created by other members', async () => { + const workspaceId = 'topic-batch-workspace'; + const workspaceSessionId = 'topic-batch-workspace-session'; + const workspaceTopicModel = new TopicModel(serverDB, userId, workspaceId); + + await serverDB.transaction(async (tx) => { + await tx.insert(workspaces).values({ + id: workspaceId, + name: 'Topic Batch Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); + await tx.insert(sessions).values({ + id: workspaceSessionId, + userId, + workspaceId, + }); + await tx.insert(messages).values([ + { + id: 'workspace-batch-message-1', + role: 'user', + sessionId: workspaceSessionId, + userId: userId2, + workspaceId, + }, + { + id: 'workspace-batch-message-2', + role: 'assistant', + sessionId: workspaceSessionId, + userId, + workspaceId, + }, + ]); + }); + + const createdTopics = await workspaceTopicModel.batchCreate([ + { + messages: ['workspace-batch-message-1', 'workspace-batch-message-2'], + sessionId: workspaceSessionId, + title: 'Workspace Batch Topic', + }, + ]); + + const updatedMessages = await serverDB + .select() + .from(messages) + .where(inArray(messages.id, ['workspace-batch-message-1', 'workspace-batch-message-2'])) + .orderBy(asc(messages.id)); + + expect(createdTopics[0].workspaceId).toBe(workspaceId); + expect(updatedMessages.map((message) => message.topicId)).toEqual([ + createdTopics[0].id, + createdTopics[0].id, + ]); + }); + it('should generate topic IDs if not provided', async () => { const topicParams = [ { title: 'Topic 1', favorite: true, sessionId }, @@ -309,6 +417,56 @@ describe('TopicModel - Create', () => { expect(duplicatedMessages[1].content).toBe('Assistant message'); }); + it('should duplicate workspace messages created by other members', async () => { + const workspaceId = 'topic-duplicate-workspace'; + const topicId = 'workspace-topic-duplicate'; + const workspaceTopicModel = new TopicModel(serverDB, userId, workspaceId); + + await serverDB.transaction(async (tx) => { + await tx.insert(workspaces).values({ + id: workspaceId, + name: 'Topic Duplicate Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); + await tx.insert(topics).values({ + id: topicId, + title: 'Workspace Original Topic', + userId: userId2, + workspaceId, + }); + await tx.insert(messages).values([ + { + content: 'Other member user message', + id: 'workspace-duplicate-message-1', + role: 'user', + topicId, + userId: userId2, + workspaceId, + }, + { + content: 'Current member assistant message', + id: 'workspace-duplicate-message-2', + role: 'assistant', + topicId, + userId, + workspaceId, + }, + ]); + }); + + const { topic: duplicatedTopic, messages: duplicatedMessages } = + await workspaceTopicModel.duplicate(topicId, 'Workspace Duplicated Topic'); + + expect(duplicatedTopic.workspaceId).toBe(workspaceId); + expect(duplicatedMessages).toHaveLength(2); + expect(duplicatedMessages.map((message) => message.content).sort()).toEqual([ + 'Current member assistant message', + 'Other member user message', + ]); + expect(duplicatedMessages.every((message) => message.workspaceId === workspaceId)).toBe(true); + }); + it('should correctly map parentId references when duplicating messages', async () => { const topicId = 'topic-with-parent-refs'; diff --git a/packages/database/src/models/__tests__/topics/topic.query.test.ts b/packages/database/src/models/__tests__/topics/topic.query.test.ts index 0ebb66dc47..34d38341c2 100644 --- a/packages/database/src/models/__tests__/topics/topic.query.test.ts +++ b/packages/database/src/models/__tests__/topics/topic.query.test.ts @@ -9,6 +9,7 @@ import { sessions, topics, users, + workspaces, } from '../../../schemas'; import type { LobeChatDatabase } from '../../../type'; import { TopicModel } from '../../topic'; @@ -53,6 +54,49 @@ describe('TopicModel - Query', () => { expect(result.items[2].id).toBe('4'); }); + it('should isolate personal and workspace topics for the same user', async () => { + await serverDB.insert(workspaces).values({ + id: 'topic-workspace', + name: 'Workspace', + primaryOwnerId: userId, + slug: 'topic-workspace', + }); + await serverDB.insert(sessions).values({ + id: 'topic-workspace-session', + userId, + workspaceId: 'topic-workspace', + }); + await serverDB.insert(topics).values([ + { + id: 'personal-topic', + sessionId, + updatedAt: new Date('2023-01-01'), + userId, + workspaceId: null, + }, + { + id: 'workspace-topic', + sessionId: 'topic-workspace-session', + updatedAt: new Date('2023-02-01'), + userId, + workspaceId: 'topic-workspace', + }, + ]); + + await expect(topicModel.query({ containerId: sessionId })).resolves.toMatchObject({ + items: [expect.objectContaining({ id: 'personal-topic' })], + total: 1, + }); + await expect( + new TopicModel(serverDB, userId, 'topic-workspace').query({ + containerId: 'topic-workspace-session', + }), + ).resolves.toMatchObject({ + items: [expect.objectContaining({ id: 'workspace-topic' })], + total: 1, + }); + }); + it('should order by status priority when sortBy is "status"', async () => { await serverDB.insert(topics).values([ // favorite floats to the top regardless of its (lower-priority) status diff --git a/packages/database/src/models/__tests__/workspace.test.ts b/packages/database/src/models/__tests__/workspace.test.ts new file mode 100644 index 0000000000..8ad77f999d --- /dev/null +++ b/packages/database/src/models/__tests__/workspace.test.ts @@ -0,0 +1,282 @@ +// @vitest-environment node +import { eq } from 'drizzle-orm'; +import { afterEach, beforeEach, describe, expect, it } from 'vitest'; + +import { getTestDB } from '../../core/getTestDB'; +import { + users, + workspaceAuditLogs, + workspaceInvitations, + workspaceMembers, + workspaces, +} from '../../schemas'; +import type { LobeChatDatabase } from '../../type'; +import { WorkspaceModel } from '../workspace'; +import { WorkspaceAuditLogModel } from '../workspaceAuditLog'; +import { WorkspaceMemberModel } from '../workspaceMember'; + +const serverDB: LobeChatDatabase = await getTestDB(); + +const ownerId = 'workspace-model-owner'; +const memberId = 'workspace-model-member'; +const secondOwnerId = 'workspace-model-second-owner'; +const outsiderId = 'workspace-model-outsider'; + +const cleanup = async () => { + await serverDB.delete(workspaceAuditLogs); + await serverDB.delete(workspaceInvitations); + await serverDB.delete(workspaceMembers); + await serverDB.delete(workspaces); + await serverDB.delete(users); +}; + +const createWorkspace = async (id = 'workspace-model-ws') => { + await serverDB.insert(workspaces).values({ + id, + name: id, + primaryOwnerId: ownerId, + settings: { gracePeriodUntil: 123, keep: true }, + slug: id, + }); + await serverDB.insert(workspaceMembers).values([ + { role: 'owner', userId: ownerId, workspaceId: id }, + { role: 'member', userId: memberId, workspaceId: id }, + { role: 'owner', userId: secondOwnerId, workspaceId: id }, + ]); + return id; +}; + +beforeEach(async () => { + await cleanup(); + await serverDB + .insert(users) + .values([{ id: ownerId }, { id: memberId }, { id: secondOwnerId }, { id: outsiderId }]); +}); + +afterEach(async () => { + await cleanup(); +}); + +describe('WorkspaceModel', () => { + it('creates the workspace and inserts the creator as owner member', async () => { + const model = new WorkspaceModel(serverDB, ownerId); + + const workspace = await model.create({ + avatar: 'avatar.png', + description: 'Team workspace', + name: 'Acme', + slug: 'acme', + }); + + expect(workspace.primaryOwnerId).toBe(ownerId); + expect(workspace.slug).toBe('acme'); + + const membership = await serverDB.query.workspaceMembers.findFirst({ + where: eq(workspaceMembers.workspaceId, workspace.id), + }); + expect(membership).toMatchObject({ + role: 'owner', + userId: ownerId, + workspaceId: workspace.id, + }); + }); + + it('lists active memberships with their workspace roles and skips deleted memberships', async () => { + const workspaceId = await createWorkspace(); + await serverDB + .update(workspaceMembers) + .set({ deletedAt: new Date() }) + .where(eq(workspaceMembers.userId, memberId)); + + const ownerWorkspaces = await new WorkspaceModel(serverDB, ownerId).listUserWorkspaces(); + const memberWorkspaces = await new WorkspaceModel(serverDB, memberId).listUserWorkspaces(); + + expect(ownerWorkspaces).toEqual([expect.objectContaining({ id: workspaceId, role: 'owner' })]); + expect(memberWorkspaces).toEqual([]); + }); + + it('does not delete workspaces owned by another primary owner', async () => { + const workspaceId = await createWorkspace(); + + await new WorkspaceModel(serverDB, outsiderId).delete(workspaceId); + + const workspace = await serverDB.query.workspaces.findFirst({ + where: eq(workspaces.id, workspaceId), + }); + expect(workspace).toBeDefined(); + }); + + it('transfers primary ownership only to an active owner member', async () => { + const workspaceId = await createWorkspace(); + const model = new WorkspaceModel(serverDB, ownerId); + + await expect(model.transferPrimaryOwnership(workspaceId, memberId)).rejects.toThrow( + 'Target user must already be an owner', + ); + + await expect(model.transferPrimaryOwnership(workspaceId, secondOwnerId)).resolves.toEqual({ + newPrimaryOwnerUserId: secondOwnerId, + previousPrimaryOwnerUserId: ownerId, + workspaceId, + }); + + const workspace = await serverDB.query.workspaces.findFirst({ + where: eq(workspaces.id, workspaceId), + }); + expect(workspace?.primaryOwnerId).toBe(secondOwnerId); + }); + + it('downgrades to solo by removing non-primary members and clearing grace period', async () => { + const workspaceId = await createWorkspace(); + + const result = await new WorkspaceModel(serverDB, ownerId).downgradeToSolo(workspaceId); + + expect(result.removedUserIds.sort()).toEqual([memberId, secondOwnerId].sort()); + expect(result.workspace.settings).toEqual({ keep: true }); + + const activeMembers = await serverDB.query.workspaceMembers.findMany({ + where: eq(workspaceMembers.workspaceId, workspaceId), + }); + expect(activeMembers).toEqual( + expect.arrayContaining([ + expect.objectContaining({ deletedAt: null, userId: ownerId }), + expect.objectContaining({ userId: memberId }), + expect.objectContaining({ userId: secondOwnerId }), + ]), + ); + expect( + activeMembers.filter((member) => !member.deletedAt).map((member) => member.userId), + ).toEqual([ownerId]); + }); + + it('sets and clears grace period without dropping unrelated settings', async () => { + const workspaceId = await createWorkspace(); + const model = new WorkspaceModel(serverDB, ownerId); + + await model.setGracePeriod(workspaceId, 456); + await expect(model.getSettings(workspaceId)).resolves.toEqual({ + gracePeriodUntil: 456, + keep: true, + }); + + await model.setGracePeriod(workspaceId, null); + await expect(model.getSettings(workspaceId)).resolves.toEqual({ keep: true }); + }); +}); + +describe('WorkspaceMemberModel', () => { + it('revives a deleted member on addMember and applies the new role', async () => { + const workspaceId = await createWorkspace(); + const model = new WorkspaceMemberModel(serverDB, ownerId); + + await model.removeMember(workspaceId, memberId); + const revived = await model.addMember({ role: 'viewer', userId: memberId, workspaceId }); + + expect(revived).toMatchObject({ + deletedAt: null, + role: 'viewer', + userId: memberId, + workspaceId, + }); + }); + + it('lists only active members unless includeDeleted is requested', async () => { + const workspaceId = await createWorkspace(); + const model = new WorkspaceMemberModel(serverDB, ownerId); + + await model.removeMember(workspaceId, memberId); + + const active = await model.listMembers(workspaceId); + const all = await model.listMembers(workspaceId, { includeDeleted: true }); + + expect(active.map((member) => member.userId).sort()).toEqual([ownerId, secondOwnerId].sort()); + expect(all.map((member) => member.userId).sort()).toEqual( + [ownerId, memberId, secondOwnerId].sort(), + ); + }); + + it('creates pending invitations with a default member role and expiry', async () => { + const workspaceId = await createWorkspace(); + const before = new Date(); + + const invitation = await new WorkspaceMemberModel(serverDB, ownerId).createInvitation({ + email: 'new@example.com', + workspaceId, + }); + + expect(invitation).toMatchObject({ + email: 'new@example.com', + inviterId: ownerId, + role: 'member', + status: 'pending', + workspaceId, + }); + expect(invitation.token).toHaveLength(32); + expect(invitation.expiresAt.getTime()).toBeGreaterThan( + before.getTime() + 6 * 24 * 60 * 60 * 1000, + ); + }); +}); + +describe('WorkspaceAuditLogModel', () => { + it('creates logs with empty metadata by default', async () => { + const workspaceId = await createWorkspace(); + + const log = await new WorkspaceAuditLogModel(serverDB).create({ + action: 'workspace.created', + userId: ownerId, + workspaceId, + }); + + expect(log).toMatchObject({ + action: 'workspace.created', + metadata: {}, + userId: ownerId, + workspaceId, + }); + }); + + it('lists logs by workspace and action with cursor pagination', async () => { + const workspaceId = await createWorkspace(); + await serverDB.insert(workspaceAuditLogs).values([ + { + action: 'workspace.created', + createdAt: new Date('2026-01-01T00:00:00.000Z'), + resourceId: 'old', + userId: ownerId, + workspaceId, + }, + { + action: 'workspace.updated', + createdAt: new Date('2026-01-02T00:00:00.000Z'), + resourceId: 'middle', + userId: ownerId, + workspaceId, + }, + { + action: 'workspace.updated', + createdAt: new Date('2026-01-03T00:00:00.000Z'), + resourceId: 'new', + userId: ownerId, + workspaceId, + }, + ]); + + const result = await new WorkspaceAuditLogModel(serverDB).list({ + action: 'workspace.updated', + limit: 1, + workspaceId, + }); + + expect(result.items.map((item) => item.resourceId)).toEqual(['new']); + expect(result.nextCursor).toBe('2026-01-03T00:00:00.000Z'); + + const next = await new WorkspaceAuditLogModel(serverDB).list({ + action: 'workspace.updated', + cursor: new Date(result.nextCursor!), + limit: 1, + workspaceId, + }); + expect(next.items.map((item) => item.resourceId)).toEqual(['middle']); + }); +}); diff --git a/packages/database/src/models/agent.ts b/packages/database/src/models/agent.ts index 6bc874abf2..bc7d3a5012 100644 --- a/packages/database/src/models/agent.ts +++ b/packages/database/src/models/agent.ts @@ -8,25 +8,33 @@ import { merge } from '@/utils/merge'; import type { AgentItem } from '../schemas'; import { + agentBotProviders, + agentCronJobs, agents, agentsFiles, agentsKnowledgeBases, agentsToSessions, + chatGroupsAgents, documents, files, knowledgeBases, + messages, sessions, + threads, topics, } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class AgentModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } /** @@ -45,21 +53,44 @@ export class AgentModel { }) .from(agents) .leftJoin(topics, eq(topics.agentId, agents.id)) - .where( - and( - eq(agents.userId, this.userId), - or(eq(agents.slug, INBOX_SESSION_ID), ne(agents.virtual, true)), - ), - ) + .where(and(this.ownership(), or(eq(agents.slug, INBOX_SESSION_ID), ne(agents.virtual, true)))) .groupBy(agents.id) .having(({ count }) => gt(count, 0)) .orderBy(desc(sql`count`)) .limit(limit); }; + /** + * Compat-mode ownership predicate for the `agents` table. + * - team mode (workspaceId set): `workspace_id = ?` (every member sees the same agents) + * - personal mode: `user_id = ? AND workspace_id IS NULL` + */ + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agents); + + /** Same predicate but for the `sessions` table (used in delete cascade). */ + private sessionsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, sessions); + + /** Ownership predicates for the agent join/related tables. */ + private documentsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents); + + private agentsFilesOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentsFiles); + + private agentsKnowledgeBasesOwnership = () => + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + agentsKnowledgeBases, + ); + + private agentsToSessionsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentsToSessions); + getAgentConfigById = async (id: string) => { const agent = await this.db.query.agents.findFirst({ - where: and(eq(agents.id, id), eq(agents.userId, this.userId)), + where: and(eq(agents.id, id), this.ownership()), }); if (!agent) return null; @@ -71,7 +102,7 @@ export class AgentModel { const rows = await this.db .select({ id: agents.id }) .from(agents) - .where(and(eq(agents.id, id), eq(agents.userId, this.userId))) + .where(and(eq(agents.id, id), this.ownership())) .limit(1); return rows.length > 0; @@ -90,9 +121,7 @@ export class AgentModel { const rows = await this.db .select({ model: agents.model, provider: agents.provider }) .from(agents) - .where( - and(eq(agents.userId, this.userId), or(eq(agents.id, idOrSlug), eq(agents.slug, idOrSlug))), - ) + .where(and(this.ownership(), or(eq(agents.id, idOrSlug), eq(agents.slug, idOrSlug)))) .limit(1); const row = rows[0]; @@ -107,7 +136,7 @@ export class AgentModel { private buildQueryAgentsWhere = (keyword?: string) => { // Include agents where virtual is false OR null (legacy data without virtual field) const baseConditions = and( - eq(agents.userId, this.userId), + this.ownership(), or(eq(agents.virtual, false), isNull(agents.virtual)), ); @@ -173,7 +202,7 @@ export class AgentModel { title: agents.title, }) .from(agents) - .where(and(eq(agents.userId, this.userId), inArray(agents.id, ids))); + .where(and(this.ownership(), inArray(agents.id, ids))); return rows.map(({ slug, ...row }) => ({ ...row, @@ -186,12 +215,15 @@ export class AgentModel { * Get agent config by ID or slug (single query with OR condition) */ getAgentConfig = async (idOrSlug: string) => { - const agent = await this.db.query.agents.findFirst({ - where: and( - eq(agents.userId, this.userId), - or(eq(agents.id, idOrSlug), eq(agents.slug, idOrSlug)), - ), - }); + // Prefer an exact ID match over a slug match. The combined `or(id, slug)` + // query has no inherent ordering, so resolve ID first for determinism. + const agent = + (await this.db.query.agents.findFirst({ + where: and(this.ownership(), eq(agents.id, idOrSlug)), + })) ?? + (await this.db.query.agents.findFirst({ + where: and(this.ownership(), eq(agents.slug, idOrSlug)), + })); if (!agent) return null; @@ -214,7 +246,7 @@ export class AgentModel { if (enabledFileIds.length > 0) { const documentsData = await this.db.query.documents.findMany({ - where: and(eq(documents.userId, this.userId), inArray(documents.fileId, enabledFileIds)), + where: and(this.documentsOwnership(), inArray(documents.fileId, enabledFileIds)), }); const documentMap = new Map(documentsData.map((doc) => [doc.fileId, doc.content])); @@ -234,15 +266,13 @@ export class AgentModel { this.db .select({ enabled: agentsKnowledgeBases.enabled, knowledgeBases }) .from(agentsKnowledgeBases) - .where( - and(eq(agentsKnowledgeBases.agentId, id), eq(agentsKnowledgeBases.userId, this.userId)), - ) + .where(and(eq(agentsKnowledgeBases.agentId, id), this.agentsKnowledgeBasesOwnership())) .orderBy(desc(agentsKnowledgeBases.createdAt)) .leftJoin(knowledgeBases, eq(knowledgeBases.id, agentsKnowledgeBases.knowledgeBaseId)), this.db .select({ enabled: agentsFiles.enabled, files }) .from(agentsFiles) - .where(and(eq(agentsFiles.agentId, id), eq(agentsFiles.userId, this.userId))) + .where(and(eq(agentsFiles.agentId, id), this.agentsFilesOwnership())) .orderBy(desc(agentsFiles.createdAt)) .leftJoin(files, eq(files.id, agentsFiles.fileId)), ]); @@ -264,10 +294,7 @@ export class AgentModel { */ findBySessionId = async (sessionId: string) => { const item = await this.db.query.agentsToSessions.findFirst({ - where: and( - eq(agentsToSessions.sessionId, sessionId), - eq(agentsToSessions.userId, this.userId), - ), + where: and(eq(agentsToSessions.sessionId, sessionId), this.agentsToSessionsOwnership()), }); if (!item) return; @@ -282,12 +309,14 @@ export class AgentModel { knowledgeBaseId: string, enabled: boolean = true, ) => { - return this.db.insert(agentsKnowledgeBases).values({ - agentId, - enabled, - knowledgeBaseId, - userId: this.userId, - }); + return this.db + .insert(agentsKnowledgeBases) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { agentId, enabled, knowledgeBaseId }, + ), + ); }; deleteAgentKnowledgeBase = async (agentId: string, knowledgeBaseId: string) => { @@ -297,7 +326,7 @@ export class AgentModel { and( eq(agentsKnowledgeBases.agentId, agentId), eq(agentsKnowledgeBases.knowledgeBaseId, knowledgeBaseId), - eq(agentsKnowledgeBases.userId, this.userId), + this.agentsKnowledgeBasesOwnership(), ), ); }; @@ -310,7 +339,7 @@ export class AgentModel { and( eq(agentsKnowledgeBases.agentId, agentId), eq(agentsKnowledgeBases.knowledgeBaseId, knowledgeBaseId), - eq(agentsKnowledgeBases.userId, this.userId), + this.agentsKnowledgeBasesOwnership(), ), ); }; @@ -323,7 +352,7 @@ export class AgentModel { .where( and( eq(agentsFiles.agentId, agentId), - eq(agentsFiles.userId, this.userId), + this.agentsFilesOwnership(), inArray(agentsFiles.fileId, fileIds), ), ); @@ -337,7 +366,12 @@ export class AgentModel { return this.db .insert(agentsFiles) .values( - needToInsertFileIds.map((fileId) => ({ agentId, enabled, fileId, userId: this.userId })), + needToInsertFileIds.map((fileId) => + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { agentId, enabled, fileId }, + ), + ), ); }; @@ -348,7 +382,7 @@ export class AgentModel { and( eq(agentsFiles.agentId, agentId), eq(agentsFiles.fileId, fileId), - eq(agentsFiles.userId, this.userId), + this.agentsFilesOwnership(), ), ); }; @@ -363,28 +397,24 @@ export class AgentModel { const links = await trx .select({ sessionId: agentsToSessions.sessionId }) .from(agentsToSessions) - .where( - and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)), - ); + .where(and(eq(agentsToSessions.agentId, agentId), this.agentsToSessionsOwnership())); const sessionIds = links.map((link) => link.sessionId); // 2. Delete links in agentsToSessions await trx .delete(agentsToSessions) - .where( - and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)), - ); + .where(and(eq(agentsToSessions.agentId, agentId), this.agentsToSessionsOwnership())); // 3. Delete associated sessions (this will cascade delete messages, topics, etc.) if (sessionIds.length > 0) { await trx .delete(sessions) - .where(and(inArray(sessions.id, sessionIds), eq(sessions.userId, this.userId))); + .where(and(inArray(sessions.id, sessionIds), this.sessionsOwnership())); } // 4. Delete the agent itself - return trx.delete(agents).where(and(eq(agents.id, agentId), eq(agents.userId, this.userId))); + return trx.delete(agents).where(and(eq(agents.id, agentId), this.ownership())); }); }; @@ -396,9 +426,7 @@ export class AgentModel { batchDelete = async (agentIds: string[]) => { if (agentIds.length === 0) return; - return this.db - .delete(agents) - .where(and(eq(agents.userId, this.userId), inArray(agents.id, agentIds))); + return this.db.delete(agents).where(and(this.ownership(), inArray(agents.id, agentIds))); }; toggleFile = async (agentId: string, fileId: string, enabled?: boolean) => { @@ -409,7 +437,7 @@ export class AgentModel { and( eq(agentsFiles.agentId, agentId), eq(agentsFiles.fileId, fileId), - eq(agentsFiles.userId, this.userId), + this.agentsFilesOwnership(), ), ); }; @@ -422,11 +450,13 @@ export class AgentModel { const [result] = await this.db .insert(agents) .values([ - { - ...config, - model: typeof config.model === 'string' ? config.model : null, - userId: this.userId, - }, + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...config, + model: typeof config.model === 'string' ? config.model : null, + }, + ), ]) .returning(); @@ -443,11 +473,15 @@ export class AgentModel { return this.db .insert(agents) .values( - configs.map((config) => ({ - ...config, - model: typeof config.model === 'string' ? config.model : null, - userId: this.userId, - })), + configs.map((config) => + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...config, + model: typeof config.model === 'string' ? config.model : null, + }, + ), + ), ) .returning(); }; @@ -456,7 +490,7 @@ export class AgentModel { return this.db .update(agents) .set({ ...data, updatedAt: new Date() }) - .where(and(eq(agents.id, agentId), eq(agents.userId, this.userId))); + .where(and(eq(agents.id, agentId), this.ownership())); }; touchUpdatedAt = async (agentId: string) => { @@ -469,7 +503,7 @@ export class AgentModel { */ checkByMarketIdentifier = async (marketIdentifier: string): Promise => { const result = await this.db.query.agents.findFirst({ - where: and(eq(agents.marketIdentifier, marketIdentifier), eq(agents.userId, this.userId)), + where: and(eq(agents.marketIdentifier, marketIdentifier), this.ownership()), }); return !!result; }; @@ -483,7 +517,7 @@ export class AgentModel { const result = await this.db.query.agents.findFirst({ columns: { id: true }, orderBy: (agents, { desc }) => [desc(agents.updatedAt)], - where: and(eq(agents.marketIdentifier, marketIdentifier), eq(agents.userId, this.userId)), + where: and(eq(agents.marketIdentifier, marketIdentifier), this.ownership()), }); return result?.id ?? null; }; @@ -498,7 +532,7 @@ export class AgentModel { columns: { id: true }, orderBy: (agents, { desc }) => [desc(agents.updatedAt)], where: and( - eq(agents.userId, this.userId), + this.ownership(), sql`${agents.params}->>'forkedFromIdentifier' = ${forkedFromIdentifier}`, ), }); @@ -509,7 +543,7 @@ export class AgentModel { if (!data || Object.keys(data).length === 0) return; const agent = await this.db.query.agents.findFirst({ - where: and(eq(agents.id, agentId), eq(agents.userId, this.userId)), + where: and(eq(agents.id, agentId), this.ownership()), }); if (!agent) return; @@ -562,7 +596,7 @@ export class AgentModel { return this.db .update(agents) .set(updateData) - .where(and(eq(agents.id, agentId), eq(agents.userId, this.userId))); + .where(and(eq(agents.id, agentId), this.ownership())); }; /** @@ -572,7 +606,7 @@ export class AgentModel { const result = await this.db .update(agents) .set({ sessionGroupId, updatedAt: new Date() }) - .where(and(eq(agents.id, agentId), eq(agents.userId, this.userId))) + .where(and(eq(agents.id, agentId), this.ownership())) .returning(); return result[0]; @@ -585,7 +619,7 @@ export class AgentModel { duplicate = async (agentId: string, newTitle?: string): Promise<{ agentId: string } | null> => { // Get the source agent const sourceAgent = await this.db.query.agents.findFirst({ - where: and(eq(agents.id, agentId), eq(agents.userId, this.userId)), + where: and(eq(agents.id, agentId), this.ownership()), }); if (!sourceAgent) return null; @@ -593,32 +627,35 @@ export class AgentModel { // Create new agent with explicit include fields const [newAgent] = await this.db .insert(agents) - .values({ - avatar: sourceAgent.avatar, - backgroundColor: sourceAgent.backgroundColor, - chatConfig: sourceAgent.chatConfig, - description: sourceAgent.description, - fewShots: sourceAgent.fewShots, - model: sourceAgent.model, - openingMessage: sourceAgent.openingMessage, - openingQuestions: sourceAgent.openingQuestions, - params: sourceAgent.params, - pinned: sourceAgent.pinned, - // Config - plugins: sourceAgent.plugins, - provider: sourceAgent.provider, + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + avatar: sourceAgent.avatar, + backgroundColor: sourceAgent.backgroundColor, + chatConfig: sourceAgent.chatConfig, + description: sourceAgent.description, + fewShots: sourceAgent.fewShots, + model: sourceAgent.model, + openingMessage: sourceAgent.openingMessage, + openingQuestions: sourceAgent.openingQuestions, + params: sourceAgent.params, + pinned: sourceAgent.pinned, + // Config + plugins: sourceAgent.plugins, + provider: sourceAgent.provider, - // Session group - sessionGroupId: sourceAgent.sessionGroupId, - systemRole: sourceAgent.systemRole, + // Session group + sessionGroupId: sourceAgent.sessionGroupId, + systemRole: sourceAgent.systemRole, - tags: sourceAgent.tags, - // Metadata - title: newTitle || (sourceAgent.title ? `${sourceAgent.title} (Copy)` : 'Copy'), - tts: sourceAgent.tts, - // User - userId: this.userId, - }) + tags: sourceAgent.tags, + // Metadata + title: newTitle || (sourceAgent.title ? `${sourceAgent.title} (Copy)` : 'Copy'), + tts: sourceAgent.tts, + }, + ), + ) .returning(); return { agentId: newAgent.id }; @@ -632,7 +669,7 @@ export class AgentModel { getBuiltinAgent = async (slug: string): Promise => { // 1. First try to find existing agent by slug const existing = await this.db.query.agents.findFirst({ - where: and(eq(agents.slug, slug), eq(agents.userId, this.userId)), + where: and(eq(agents.slug, slug), this.ownership()), }); if (existing) return existing; @@ -647,7 +684,7 @@ export class AgentModel { .from(sessions) .innerJoin(agentsToSessions, eq(sessions.id, agentsToSessions.sessionId)) .innerJoin(agents, eq(agentsToSessions.agentId, agents.id)) - .where(and(eq(sessions.slug, INBOX_SESSION_ID), eq(sessions.userId, this.userId))) + .where(and(eq(sessions.slug, INBOX_SESSION_ID), this.sessionsOwnership())) .limit(1); if (result.length > 0 && result[0].agent) { @@ -673,30 +710,156 @@ export class AgentModel { // `onConflictDoNothing`, the loser hits the `agents_slug_user_id_unique` // constraint; with it, the loser's `.returning()` is empty and we re-read // the row that won. - // `agents_slug_user_id_unique` is a partial index (WHERE workspace_id IS - // NULL) since migration 0109, so the conflict arbiter must carry the same - // predicate; builtin agents are always workspace-less (workspace_id NULL). + // Bare `onConflictDoNothing()` (no target) does NOT pin an arbiter index, + // so it works whether `agents_slug_user_id_unique` is the legacy full + // unique or the migration-0109 partial (WHERE workspace_id IS NULL) — this + // is the transition-safe form while 0109 rolls out. Tighten back to a + // partitioned { target, where } once 0109 has flipped the index in every + // environment. Payload still carries workspaceId so workspace-scoped + // builtin agents land in the right workspace. const result = await this.db .insert(agents) - .values({ - model: persistConfig.model, - provider: persistConfig.provider, - slug: persistConfig.slug, - userId: this.userId, - virtual: true, - }) - .onConflictDoNothing({ - target: [agents.slug, agents.userId], - where: isNull(agents.workspaceId), - }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + model: persistConfig.model, + provider: persistConfig.provider, + slug: persistConfig.slug, + virtual: true, + }, + ), + ) + .onConflictDoNothing() .returning(); if (result[0]) return result[0]; return ( (await this.db.query.agents.findFirst({ - where: and(eq(agents.slug, slug), eq(agents.userId, this.userId)), + where: and(eq(agents.slug, slug), this.ownership()), })) ?? null ); }; + + /** + * Transfer an agent and all its associated data to a different workspace or personal account. + * Runs in a single transaction to ensure atomicity. + */ + transferAgent = async ( + agentId: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ agentId: string; slug: string | null }> => { + return this.db.transaction(async (trx) => { + // 1. Verify agent exists and belongs to current scope + const agent = await trx.query.agents.findFirst({ + where: and(eq(agents.id, agentId), this.ownership()), + }); + if (!agent) throw new Error('Agent not found'); + + // 2. Handle slug conflict in target scope + let slug = agent.slug; + if (slug) { + const buildConflictCheck = (candidate: string) => + targetWorkspaceId + ? and(eq(agents.slug, candidate), eq(agents.workspaceId, targetWorkspaceId)) + : and( + eq(agents.slug, candidate), + eq(agents.userId, targetUserId), + isNull(agents.workspaceId), + ); + + const existing = await trx.query.agents.findFirst({ + where: buildConflictCheck(slug), + }); + if (existing) { + let suffix = 1; + while (suffix < 100) { + const candidate = `${slug}-${suffix}`; + const conflict = await trx.query.agents.findFirst({ + where: buildConflictCheck(candidate), + }); + if (!conflict) { + slug = candidate; + break; + } + suffix++; + } + } + } + + // 3. Build ownership update payload + const ownershipUpdate = { + userId: targetUserId, + workspaceId: targetWorkspaceId, + }; + + // 4. Update the agent record + await trx + .update(agents) + .set({ ...ownershipUpdate, slug, updatedAt: new Date() }) + .where(eq(agents.id, agentId)); + + // 5. Update sessions linked via agentsToSessions + const links = await trx + .select({ sessionId: agentsToSessions.sessionId }) + .from(agentsToSessions) + .where(eq(agentsToSessions.agentId, agentId)); + + const sessionIds = links.map((l) => l.sessionId); + + if (sessionIds.length > 0) { + await trx.update(sessions).set(ownershipUpdate).where(inArray(sessions.id, sessionIds)); + } + + await trx + .update(agentsToSessions) + .set(ownershipUpdate) + .where(eq(agentsToSessions.agentId, agentId)); + + // 6. Update topics (linked via sessionId or agentId) + const topicCondition = + sessionIds.length > 0 + ? or(inArray(topics.sessionId, sessionIds), eq(topics.agentId, agentId)) + : eq(topics.agentId, agentId); + await trx.update(topics).set(ownershipUpdate).where(topicCondition!); + + // 7. Update messages (linked via sessionId or agentId) + const messageCondition = + sessionIds.length > 0 + ? or(inArray(messages.sessionId, sessionIds), eq(messages.agentId, agentId)) + : eq(messages.agentId, agentId); + await trx.update(messages).set(ownershipUpdate).where(messageCondition!); + + // 8. Update threads (linked via agentId) + await trx.update(threads).set(ownershipUpdate).where(eq(threads.agentId, agentId)); + + // 9. Update agent files associations + await trx.update(agentsFiles).set(ownershipUpdate).where(eq(agentsFiles.agentId, agentId)); + + // 10. Update agent knowledge base associations + await trx + .update(agentsKnowledgeBases) + .set(ownershipUpdate) + .where(eq(agentsKnowledgeBases.agentId, agentId)); + + // 11. Update agent cron jobs + await trx + .update(agentCronJobs) + .set(ownershipUpdate) + .where(eq(agentCronJobs.agentId, agentId)); + + // 12. Update agent bot providers (transfer, not delete) + await trx + .update(agentBotProviders) + .set(ownershipUpdate) + .where(eq(agentBotProviders.agentId, agentId)); + + // 13. Remove chat group associations (groups belong to source workspace context) + await trx.delete(chatGroupsAgents).where(eq(chatGroupsAgents.agentId, agentId)); + + return { agentId, slug }; + }); + }; } diff --git a/packages/database/src/models/agentBotProvider.ts b/packages/database/src/models/agentBotProvider.ts index 43407833e9..f6993b8268 100644 --- a/packages/database/src/models/agentBotProvider.ts +++ b/packages/database/src/models/agentBotProvider.ts @@ -3,6 +3,7 @@ import { and, desc, eq } from 'drizzle-orm'; import type { AgentBotProviderItem, NewAgentBotProvider } from '../schemas'; import { agentBotProviders } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; interface GateKeeper { decrypt: (ciphertext: string) => Promise<{ plaintext: string }>; @@ -16,14 +17,19 @@ export interface DecryptedBotProvider extends Omit + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentBotProviders); + // --------------- User-scoped CRUD --------------- create = async ( @@ -35,7 +41,12 @@ export class AgentBotProviderModel { const [result] = await this.db .insert(agentBotProviders) - .values({ ...params, credentials, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params, credentials }, + ), + ) .returning(); return result; @@ -44,11 +55,11 @@ export class AgentBotProviderModel { delete = async (id: string) => { return this.db .delete(agentBotProviders) - .where(and(eq(agentBotProviders.id, id), eq(agentBotProviders.userId, this.userId))); + .where(and(eq(agentBotProviders.id, id), this.ownership())); }; query = async (params?: { agentId?: string; platform?: string }) => { - const conditions = [eq(agentBotProviders.userId, this.userId)]; + const conditions = [this.ownership()]; if (params?.agentId) { conditions.push(eq(agentBotProviders.agentId, params.agentId)); @@ -70,7 +81,7 @@ export class AgentBotProviderModel { const [result] = await this.db .select() .from(agentBotProviders) - .where(and(eq(agentBotProviders.id, id), eq(agentBotProviders.userId, this.userId))) + .where(and(eq(agentBotProviders.id, id), this.ownership())) .limit(1); if (!result) return result; @@ -82,7 +93,7 @@ export class AgentBotProviderModel { const results = await this.db .select() .from(agentBotProviders) - .where(and(eq(agentBotProviders.agentId, agentId), eq(agentBotProviders.userId, this.userId))) + .where(and(eq(agentBotProviders.agentId, agentId), this.ownership())) .orderBy(desc(agentBotProviders.updatedAt)); return Promise.all(results.map((r) => this.decryptRow(r))); @@ -104,7 +115,7 @@ export class AgentBotProviderModel { return this.db .update(agentBotProviders) .set({ ...updateValue, updatedAt: new Date() }) - .where(and(eq(agentBotProviders.id, id), eq(agentBotProviders.userId, this.userId))); + .where(and(eq(agentBotProviders.id, id), this.ownership())); }; // --------------- System-wide static methods --------------- @@ -139,7 +150,7 @@ export class AgentBotProviderModel { and( eq(agentBotProviders.platform, platform), eq(agentBotProviders.applicationId, applicationId), - eq(agentBotProviders.userId, this.userId), + this.ownership(), eq(agentBotProviders.enabled, true), ), ) @@ -152,6 +163,88 @@ export class AgentBotProviderModel { // --------------- System-wide static methods --------------- + /** + * System-wide lookup of an enabled provider by platform + applicationId. + * + * `(platform, applicationId)` is globally unique, so this returns the single + * matching row regardless of which user / workspace owns it. Use only from + * post-authorization runtime layers (gateway service / manager / connect-queue + * cron) where the caller has already been authorized at the router boundary — + * never as an authorization check itself. + */ + static findEnabledByPlatformAndAppId = async ( + db: LobeChatDatabase, + platform: string, + applicationId: string, + gateKeeper?: GateKeeper, + ): Promise => { + const [result] = await db + .select() + .from(agentBotProviders) + .where( + and( + eq(agentBotProviders.platform, platform), + eq(agentBotProviders.applicationId, applicationId), + eq(agentBotProviders.enabled, true), + ), + ) + .limit(1); + + if (!result) return null; + + if (!result.credentials) return { ...result, credentials: {} }; + + try { + const credentials = gateKeeper + ? JSON.parse((await gateKeeper.decrypt(result.credentials)).plaintext) + : JSON.parse(result.credentials); + + return { ...result, credentials }; + } catch { + return { ...result, credentials: {} }; + } + }; + + /** + * System-wide lookup of all providers under an agent. + * + * An agent belongs to a single owner / workspace, so this returns every row + * for the agent regardless of scope. Same authorization caveat as + * {@link findEnabledByPlatformAndAppId}: runtime-layer use only. + */ + static findByAgentId = async ( + db: LobeChatDatabase, + agentId: string, + gateKeeper?: GateKeeper, + ): Promise => { + const results = await db + .select() + .from(agentBotProviders) + .where(eq(agentBotProviders.agentId, agentId)) + .orderBy(desc(agentBotProviders.updatedAt)); + + const decrypted: DecryptedBotProvider[] = []; + + for (const r of results) { + if (!r.credentials) { + decrypted.push({ ...r, credentials: {} }); + continue; + } + + try { + const credentials = gateKeeper + ? JSON.parse((await gateKeeper.decrypt(r.credentials)).plaintext) + : JSON.parse(r.credentials); + + decrypted.push({ ...r, credentials }); + } catch { + decrypted.push({ ...r, credentials: {} }); + } + } + + return decrypted; + }; + static findEnabledByPlatform = async ( db: LobeChatDatabase, platform: string, diff --git a/packages/database/src/models/agentCronJob.ts b/packages/database/src/models/agentCronJob.ts index 402a223778..b9d0c07bde 100644 --- a/packages/database/src/models/agentCronJob.ts +++ b/packages/database/src/models/agentCronJob.ts @@ -8,27 +8,36 @@ import type { } from '../schemas/agentCronJob'; import { agentCronJobs } from '../schemas/agentCronJob'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class AgentCronJobModel { private readonly userId: string; private readonly db: LobeChatDatabase; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId?: string) { + constructor(db: LobeChatDatabase, userId?: string, workspaceId?: string) { this.db = db; this.userId = userId!; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentCronJobs); + // Create a new cron job async create(data: CreateAgentCronJobData): Promise { const cronJob = await this.db .insert(agentCronJobs) - .values({ - ...data, - // Initialize remaining executions to match max executions - remainingExecutions: data.maxExecutions, - - userId: this.userId, - } as NewAgentCronJob) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...data, + // Initialize remaining executions to match max executions + remainingExecutions: data.maxExecutions, + }, + ) as NewAgentCronJob, + ) .returning(); return cronJob[0]; @@ -39,7 +48,7 @@ export class AgentCronJobModel { const result = await this.db .select() .from(agentCronJobs) - .where(and(eq(agentCronJobs.id, id), eq(agentCronJobs.userId, this.userId))) + .where(and(eq(agentCronJobs.id, id), this.ownership())) .limit(1); return result[0] || null; @@ -50,7 +59,7 @@ export class AgentCronJobModel { return this.db .select() .from(agentCronJobs) - .where(and(eq(agentCronJobs.agentId, agentId), eq(agentCronJobs.userId, this.userId))) + .where(and(eq(agentCronJobs.agentId, agentId), this.ownership())) .orderBy(desc(agentCronJobs.createdAt)); } @@ -59,7 +68,7 @@ export class AgentCronJobModel { return this.db .select() .from(agentCronJobs) - .where(eq(agentCronJobs.userId, this.userId)) + .where(this.ownership()) .orderBy(desc(agentCronJobs.lastExecutedAt)); } @@ -109,7 +118,7 @@ export class AgentCronJobModel { const result = await this.db .update(agentCronJobs) .set(updateData) - .where(and(eq(agentCronJobs.id, id), eq(agentCronJobs.userId, this.userId))) + .where(and(eq(agentCronJobs.id, id), this.ownership())) .returning(); return result[0] || null; @@ -119,7 +128,7 @@ export class AgentCronJobModel { async delete(id: string): Promise { const result = await this.db .delete(agentCronJobs) - .where(and(eq(agentCronJobs.id, id), eq(agentCronJobs.userId, this.userId))) + .where(and(eq(agentCronJobs.id, id), this.ownership())) .returning(); return result.length > 0; @@ -181,7 +190,7 @@ export class AgentCronJobModel { totalExecutions: 0, updatedAt: new Date(), }) - .where(and(eq(agentCronJobs.id, id), eq(agentCronJobs.userId, this.userId))) + .where(and(eq(agentCronJobs.id, id), this.ownership())) .returning(); return result[0] || null; @@ -194,7 +203,7 @@ export class AgentCronJobModel { .from(agentCronJobs) .where( and( - eq(agentCronJobs.userId, this.userId), + this.ownership(), eq(agentCronJobs.enabled, true), gt(agentCronJobs.remainingExecutions, 0), sql`${agentCronJobs.remainingExecutions} <= ${threshold}`, @@ -208,7 +217,7 @@ export class AgentCronJobModel { return this.db .select() .from(agentCronJobs) - .where(and(eq(agentCronJobs.userId, this.userId), eq(agentCronJobs.enabled, enabled))) + .where(and(this.ownership(), eq(agentCronJobs.enabled, enabled))) .orderBy(desc(agentCronJobs.updatedAt)); } @@ -232,7 +241,7 @@ export class AgentCronJobModel { totalJobs: sql`count(*)`, }) .from(agentCronJobs) - .where(eq(agentCronJobs.userId, this.userId)); + .where(this.ownership()); const stats = result[0]; return { @@ -251,7 +260,7 @@ export class AgentCronJobModel { enabled, updatedAt: new Date(), }) - .where(and(inArray(agentCronJobs.id, ids), eq(agentCronJobs.userId, this.userId))) + .where(and(inArray(agentCronJobs.id, ids), this.ownership())) .returning(); return result.length; @@ -262,7 +271,7 @@ export class AgentCronJobModel { const result = await this.db .select({ count: sql`count(*)` }) .from(agentCronJobs) - .where(and(eq(agentCronJobs.agentId, agentId), eq(agentCronJobs.userId, this.userId))); + .where(and(eq(agentCronJobs.agentId, agentId), this.ownership())); return Number(result[0].count); } @@ -276,7 +285,7 @@ export class AgentCronJobModel { }): Promise<{ jobs: AgentCronJob[]; total: number }> { const { agentId, enabled, limit = 20, offset = 0 } = options; - const whereConditions = [eq(agentCronJobs.userId, this.userId)]; + const whereConditions = [this.ownership()]; if (agentId) { whereConditions.push(eq(agentCronJobs.agentId, agentId)); diff --git a/packages/database/src/models/agentDocuments/agentDocument.ts b/packages/database/src/models/agentDocuments/agentDocument.ts index e7bb723ca8..32c8cefcd2 100644 --- a/packages/database/src/models/agentDocuments/agentDocument.ts +++ b/packages/database/src/models/agentDocuments/agentDocument.ts @@ -4,6 +4,7 @@ import { and, asc, desc, eq, inArray, isNotNull, isNull, like, or, sql } from 'd import type { DocumentItem, NewAgentDocument, NewDocument } from '../../schemas'; import { AGENT_SKILL_TEMPLATE_ID, agentDocuments, documents } from '../../schemas'; import type { LobeChatDatabase, Transaction } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; import { deriveAgentDocumentFields } from './deriveFields'; import { buildDocumentFilename } from './filename'; import { @@ -71,13 +72,31 @@ interface ConvertAgentDocumentToSkillIndexParams { export class AgentDocumentModel { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; + this.workspaceId = workspaceId; this.db = db; } + /** + * Workspace-aware ownership predicate for the `agent_documents` binding table. + * Personal mode → `user_id = ? AND workspace_id IS NULL`; workspace mode → `workspace_id = ?`. + */ + private agentDocOwnership() { + return buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + agentDocuments, + ); + } + + /** Workspace-aware ownership predicate for the backing `documents` rows. */ + private documentOwnership() { + return buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents); + } + private getDocumentStats(content: string) { if (!content) return { totalCharCount: 0, totalLineCount: 0 }; @@ -175,7 +194,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), inArray(documents.parentId, parentIds), ...this.buildDeletedAtFilters(options), @@ -212,7 +231,7 @@ export class AgentDocumentModel { const [doc] = await trx .select() .from(documents) - .where(and(eq(documents.id, documentId), eq(documents.userId, this.userId))) + .where(and(eq(documents.id, documentId), this.documentOwnership())) .limit(1); if (!doc) return { id: '' }; @@ -235,6 +254,7 @@ export class AgentDocumentModel { policyLoadPosition: DocumentLoadPosition.BEFORE_FIRST_USER, policyLoadRule: DocumentLoadRule.ALWAYS, userId: this.userId, + workspaceId: this.workspaceId ?? null, }) .onConflictDoNothing() .returning({ id: agentDocuments.id }); @@ -332,6 +352,7 @@ export class AgentDocumentModel { totalLineCount: stats.totalLineCount, updatedAt: updatedAt ?? createdAt, userId: this.userId, + workspaceId: this.workspaceId ?? null, }; const [insertedDocument] = await trx.insert(documents).values(documentPayload).returning(); @@ -361,6 +382,7 @@ export class AgentDocumentModel { templateId, updatedAt: updatedAt ?? createdAt, userId: this.userId, + workspaceId: this.workspaceId ?? null, }; const [settings] = await trx.insert(agentDocuments).values(newDoc).returning(); @@ -414,7 +436,7 @@ export class AgentDocumentModel { .where( and( eq(agentDocuments.id, params.agentDocumentId), - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), isNull(agentDocuments.deletedAt), ), ) @@ -445,7 +467,7 @@ export class AgentDocumentModel { totalLineCount: stats.totalLineCount, updatedAt, }) - .where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId))); + .where(and(eq(documents.id, existing.documentId), this.documentOwnership())); await trx .update(agentDocuments) @@ -457,7 +479,7 @@ export class AgentDocumentModel { .where( and( eq(agentDocuments.id, params.agentDocumentId), - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), isNull(agentDocuments.deletedAt), ), ); @@ -469,7 +491,7 @@ export class AgentDocumentModel { .where( and( eq(agentDocuments.id, params.agentDocumentId), - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), isNull(agentDocuments.deletedAt), ), ) @@ -559,13 +581,13 @@ export class AgentDocumentModel { await trx .update(documents) .set(documentUpdate) - .where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId))); + .where(and(eq(documents.id, existing.documentId), this.documentOwnership())); } await trx .update(agentDocuments) .set(settingsUpdate) - .where(and(eq(agentDocuments.id, documentId), eq(agentDocuments.userId, this.userId))); + .where(and(eq(agentDocuments.id, documentId), this.agentDocOwnership())); }); } @@ -616,7 +638,7 @@ export class AgentDocumentModel { ...(params.parentId !== undefined && { parentId: params.parentId }), ...(params.title !== undefined && { title: params.title }), }) - .where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId))); + .where(and(eq(documents.id, existing.documentId), this.documentOwnership())); return this.findById(agentDocumentId); } @@ -658,7 +680,7 @@ export class AgentDocumentModel { source, title, }) - .where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId))); + .where(and(eq(documents.id, existing.documentId), this.documentOwnership())); }); return this.findById(documentId); @@ -701,7 +723,7 @@ export class AgentDocumentModel { source, title: filename, }) - .where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId))); + .where(and(eq(documents.id, existing.documentId), this.documentOwnership())); }); return this.findById(documentId); @@ -749,7 +771,7 @@ export class AgentDocumentModel { .where( and( eq(agentDocuments.id, documentId), - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), isNull(agentDocuments.deletedAt), ), ); @@ -775,7 +797,7 @@ export class AgentDocumentModel { .where( and( eq(agentDocuments.id, documentId), - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), ...this.buildDeletedAtFilters(options), ), ) @@ -871,7 +893,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), isNull(agentDocuments.deletedAt), ), @@ -895,7 +917,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), isNull(agentDocuments.deletedAt), or( @@ -958,7 +980,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), isNull(agentDocuments.deletedAt), ), @@ -1013,7 +1035,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), inArray(agentDocuments.documentId, documentIds), isNull(agentDocuments.deletedAt), @@ -1037,7 +1059,7 @@ export class AgentDocumentModel { .from(agentDocuments) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), isNull(agentDocuments.deletedAt), ), @@ -1054,7 +1076,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), eq(agentDocuments.templateId, templateId), isNull(agentDocuments.deletedAt), @@ -1083,7 +1105,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), eq(documents.filename, filename), ...this.buildDeletedAtFilters(options), @@ -1109,7 +1131,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), eq(documents.filename, filename), parentId ? eq(documents.parentId, parentId) : isNull(documents.parentId), @@ -1149,7 +1171,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), eq(documents.filename, filename), parentId ? eq(documents.parentId, parentId) : isNull(documents.parentId), @@ -1174,7 +1196,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), eq(agentDocuments.documentId, documentId), ...this.buildDeletedAtFilters(options), @@ -1199,7 +1221,7 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), parentId ? eq(documents.parentId, parentId) : isNull(documents.parentId), ...this.buildDeletedAtFilters(options), @@ -1219,9 +1241,9 @@ export class AgentDocumentModel { .innerJoin(documents, eq(agentDocuments.documentId, documents.id)) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), eq(agentDocuments.agentId, agentId), - eq(documents.userId, this.userId), + this.documentOwnership(), isNotNull(agentDocuments.deletedAt), ), ) @@ -1270,7 +1292,7 @@ export class AgentDocumentModel { .where( and( eq(agentDocuments.id, documentId), - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), isNull(agentDocuments.deletedAt), ), ); @@ -1296,7 +1318,7 @@ export class AgentDocumentModel { }) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), inArray( agentDocuments.id, subtree.map((item) => item.id), @@ -1336,7 +1358,7 @@ export class AgentDocumentModel { deletedByUserId: null, policyLoad: PolicyLoad.PROGRESSIVE, }) - .where(and(eq(agentDocuments.id, documentId), eq(agentDocuments.userId, this.userId))); + .where(and(eq(agentDocuments.id, documentId), this.agentDocOwnership())); }); } @@ -1358,7 +1380,7 @@ export class AgentDocumentModel { }) .where( and( - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), inArray( agentDocuments.id, subtree.map((item) => item.id), @@ -1375,11 +1397,11 @@ export class AgentDocumentModel { await this.db.transaction(async (trx) => { await trx .delete(agentDocuments) - .where(and(eq(agentDocuments.id, documentId), eq(agentDocuments.userId, this.userId))); + .where(and(eq(agentDocuments.id, documentId), this.agentDocOwnership())); await trx .delete(documents) - .where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId))); + .where(and(eq(documents.id, existing.documentId), this.documentOwnership())); }); } @@ -1399,13 +1421,11 @@ export class AgentDocumentModel { await this.db.transaction(async (trx) => { await trx .delete(agentDocuments) - .where( - and(eq(agentDocuments.userId, this.userId), inArray(agentDocuments.id, agentDocumentIds)), - ); + .where(and(this.agentDocOwnership(), inArray(agentDocuments.id, agentDocumentIds))); await trx .delete(documents) - .where(and(eq(documents.userId, this.userId), inArray(documents.id, documentIds))); + .where(and(this.documentOwnership(), inArray(documents.id, documentIds))); }); } @@ -1423,7 +1443,7 @@ export class AgentDocumentModel { .where( and( eq(agentDocuments.agentId, agentId), - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), isNull(agentDocuments.deletedAt), ), ); @@ -1447,7 +1467,7 @@ export class AgentDocumentModel { and( eq(agentDocuments.agentId, agentId), eq(agentDocuments.templateId, templateId), - eq(agentDocuments.userId, this.userId), + this.agentDocOwnership(), isNull(agentDocuments.deletedAt), ), ); diff --git a/packages/database/src/models/agentEval/benchmark.ts b/packages/database/src/models/agentEval/benchmark.ts index 883654e014..674aa23c85 100644 --- a/packages/database/src/models/agentEval/benchmark.ts +++ b/packages/database/src/models/agentEval/benchmark.ts @@ -8,23 +8,46 @@ import { type NewAgentEvalBenchmark, } from '../../schemas'; import { type LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class AgentEvalBenchmarkModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + /** + * Ownership predicate: rows the current actor can see/edit. Includes + * workspace-scoped or personal rows AND system rows (`userId IS NULL`). + */ + private ownership = () => + or( + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + agentEvalBenchmarks, + ), + isNull(agentEvalBenchmarks.userId), + ); + + /** Mutate-only predicate excluding system rows. */ + private mutableOwnership = () => + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + agentEvalBenchmarks, + ); + /** * Create a new benchmark */ create = async (params: NewAgentEvalBenchmark) => { const [result] = await this.db .insert(agentEvalBenchmarks) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return result; }; @@ -39,23 +62,20 @@ export class AgentEvalBenchmarkModel { and( eq(agentEvalBenchmarks.id, id), eq(agentEvalBenchmarks.isSystem, false), - eq(agentEvalBenchmarks.userId, this.userId), + this.mutableOwnership(), ), ); }; /** - * Query benchmarks (system + user-created) + * Query benchmarks (system + user/workspace-created) * @param includeSystem - Whether to include system benchmarks (default: true) */ query = async (includeSystem = true) => { - const userCondition = or( - eq(agentEvalBenchmarks.userId, this.userId), - isNull(agentEvalBenchmarks.userId), - ); + const userCondition = this.ownership(); const conditions = includeSystem ? userCondition - : and(eq(agentEvalBenchmarks.isSystem, false), userCondition); + : and(eq(agentEvalBenchmarks.isSystem, false), this.ownership()); const datasetCountSq = this.db .select({ @@ -63,6 +83,12 @@ export class AgentEvalBenchmarkModel { count: count().as('dataset_count'), }) .from(agentEvalDatasets) + .where( + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + agentEvalDatasets, + ), + ) .groupBy(agentEvalDatasets.benchmarkId) .as('dc'); @@ -73,6 +99,12 @@ export class AgentEvalBenchmarkModel { }) .from(agentEvalTestCases) .innerJoin(agentEvalDatasets, eq(agentEvalTestCases.datasetId, agentEvalDatasets.id)) + .where( + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + agentEvalDatasets, + ), + ) .groupBy(agentEvalDatasets.benchmarkId) .as('tc'); @@ -83,7 +115,9 @@ export class AgentEvalBenchmarkModel { }) .from(agentEvalRuns) .innerJoin(agentEvalDatasets, eq(agentEvalRuns.datasetId, agentEvalDatasets.id)) - .where(eq(agentEvalRuns.userId, this.userId)) + .where( + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalRuns), + ) .groupBy(agentEvalDatasets.benchmarkId) .as('rc'); @@ -109,7 +143,13 @@ export class AgentEvalBenchmarkModel { .from(agentEvalRuns) .innerJoin(agentEvalDatasets, eq(agentEvalRuns.datasetId, agentEvalDatasets.id)) .where( - and(eq(agentEvalDatasets.benchmarkId, row.id), eq(agentEvalRuns.userId, this.userId)), + and( + eq(agentEvalDatasets.benchmarkId, row.id), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + agentEvalRuns, + ), + ), ) .orderBy(desc(agentEvalRuns.createdAt)) .limit(5); @@ -144,12 +184,7 @@ export class AgentEvalBenchmarkModel { const [result] = await this.db .select() .from(agentEvalBenchmarks) - .where( - and( - eq(agentEvalBenchmarks.id, id), - or(eq(agentEvalBenchmarks.userId, this.userId), isNull(agentEvalBenchmarks.userId)), - ), - ) + .where(and(eq(agentEvalBenchmarks.id, id), this.ownership())) .limit(1); return result; }; @@ -161,12 +196,7 @@ export class AgentEvalBenchmarkModel { const [result] = await this.db .select() .from(agentEvalBenchmarks) - .where( - and( - eq(agentEvalBenchmarks.identifier, identifier), - or(eq(agentEvalBenchmarks.userId, this.userId), isNull(agentEvalBenchmarks.userId)), - ), - ) + .where(and(eq(agentEvalBenchmarks.identifier, identifier), this.ownership())) .limit(1); return result; }; @@ -182,7 +212,7 @@ export class AgentEvalBenchmarkModel { and( eq(agentEvalBenchmarks.id, id), eq(agentEvalBenchmarks.isSystem, false), - eq(agentEvalBenchmarks.userId, this.userId), + this.mutableOwnership(), ), ) .returning(); diff --git a/packages/database/src/models/agentEval/dataset.ts b/packages/database/src/models/agentEval/dataset.ts index f4a33256ec..a2c105776c 100644 --- a/packages/database/src/models/agentEval/dataset.ts +++ b/packages/database/src/models/agentEval/dataset.ts @@ -2,23 +2,40 @@ import { and, asc, count, desc, eq, isNull, or } from 'drizzle-orm'; import { agentEvalDatasets, agentEvalTestCases, type NewAgentEvalDataset } from '../../schemas'; import { type LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class AgentEvalDatasetModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + /** Includes system datasets (`userId IS NULL`) on read. */ + private ownership = () => + or( + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + agentEvalDatasets, + ), + isNull(agentEvalDatasets.userId), + ); + + /** Mutate-only predicate excluding system rows. */ + private mutableOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalDatasets); + /** * Create a new dataset */ create = async (params: NewAgentEvalDataset) => { const [result] = await this.db .insert(agentEvalDatasets) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return result; }; @@ -29,17 +46,15 @@ export class AgentEvalDatasetModel { delete = async (id: string) => { return this.db .delete(agentEvalDatasets) - .where(and(eq(agentEvalDatasets.id, id), eq(agentEvalDatasets.userId, this.userId))); + .where(and(eq(agentEvalDatasets.id, id), this.mutableOwnership())); }; /** - * Query datasets (system + user-owned) with test case counts + * Query datasets (system + user/workspace-owned) with test case counts * @param benchmarkId - Optional benchmark filter */ query = async (benchmarkId?: string) => { - const conditions = [ - or(eq(agentEvalDatasets.userId, this.userId), isNull(agentEvalDatasets.userId)), - ]; + const conditions = [this.ownership()]; if (benchmarkId) { conditions.push(eq(agentEvalDatasets.benchmarkId, benchmarkId)); @@ -74,12 +89,7 @@ export class AgentEvalDatasetModel { const [dataset] = await this.db .select() .from(agentEvalDatasets) - .where( - and( - eq(agentEvalDatasets.id, id), - or(eq(agentEvalDatasets.userId, this.userId), isNull(agentEvalDatasets.userId)), - ), - ) + .where(and(eq(agentEvalDatasets.id, id), this.ownership())) .limit(1); if (!dataset) return undefined; @@ -100,7 +110,7 @@ export class AgentEvalDatasetModel { const [result] = await this.db .update(agentEvalDatasets) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(agentEvalDatasets.id, id), eq(agentEvalDatasets.userId, this.userId))) + .where(and(eq(agentEvalDatasets.id, id), this.mutableOwnership())) .returning(); return result; }; diff --git a/packages/database/src/models/agentEval/run.ts b/packages/database/src/models/agentEval/run.ts index 4642b7c9da..39d6de09cb 100644 --- a/packages/database/src/models/agentEval/run.ts +++ b/packages/database/src/models/agentEval/run.ts @@ -2,23 +2,29 @@ import { and, count, desc, eq, inArray } from 'drizzle-orm'; import { agentEvalDatasets, agentEvalRuns, type NewAgentEvalRun } from '../../schemas'; import { type LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class AgentEvalRunModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalRuns); + /** * Create a new run */ create = async (params: Omit) => { const [result] = await this.db .insert(agentEvalRuns) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return result; }; @@ -33,7 +39,7 @@ export class AgentEvalRunModel { offset?: number; status?: 'idle' | 'pending' | 'running' | 'completed' | 'failed' | 'aborted' | 'external'; }) => { - const conditions = [eq(agentEvalRuns.userId, this.userId)]; + const conditions = [this.ownership()]; if (filter?.datasetId) { conditions.push(eq(agentEvalRuns.datasetId, filter.datasetId)); @@ -77,7 +83,7 @@ export class AgentEvalRunModel { const [result] = await this.db .select() .from(agentEvalRuns) - .where(and(eq(agentEvalRuns.id, id), eq(agentEvalRuns.userId, this.userId))) + .where(and(eq(agentEvalRuns.id, id), this.ownership())) .limit(1); return result; }; @@ -89,7 +95,7 @@ export class AgentEvalRunModel { const [result] = await this.db .update(agentEvalRuns) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(agentEvalRuns.id, id), eq(agentEvalRuns.userId, this.userId))) + .where(and(eq(agentEvalRuns.id, id), this.ownership())) .returning(); return result; }; @@ -98,9 +104,7 @@ export class AgentEvalRunModel { * Delete run (only user-created runs) */ delete = async (id: string) => { - return this.db - .delete(agentEvalRuns) - .where(and(eq(agentEvalRuns.id, id), eq(agentEvalRuns.userId, this.userId))); + return this.db.delete(agentEvalRuns).where(and(eq(agentEvalRuns.id, id), this.ownership())); }; /** @@ -110,7 +114,7 @@ export class AgentEvalRunModel { const result = await this.db .select({ value: count() }) .from(agentEvalRuns) - .where(and(eq(agentEvalRuns.datasetId, datasetId), eq(agentEvalRuns.userId, this.userId))); + .where(and(eq(agentEvalRuns.datasetId, datasetId), this.ownership())); return Number(result[0]?.value) || 0; }; } diff --git a/packages/database/src/models/agentEval/runTopic.ts b/packages/database/src/models/agentEval/runTopic.ts index abc28609c5..695df03de8 100644 --- a/packages/database/src/models/agentEval/runTopic.ts +++ b/packages/database/src/models/agentEval/runTopic.ts @@ -9,22 +9,32 @@ import { topics, } from '../../schemas'; import { type LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class AgentEvalRunTopicModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalRunTopics); + /** * Batch create run-topic associations */ batchCreate = async (items: Omit[]) => { if (items.length === 0) return []; - const withUserId = items.map((item) => ({ ...item, userId: this.userId })); + const withUserId = items.map((item) => ({ + ...item, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + })); return this.db.insert(agentEvalRunTopics).values(withUserId).returning(); }; @@ -48,7 +58,7 @@ export class AgentEvalRunTopicModel { .from(agentEvalRunTopics) .leftJoin(agentEvalTestCases, eq(agentEvalRunTopics.testCaseId, agentEvalTestCases.id)) .leftJoin(topics, eq(agentEvalRunTopics.topicId, topics.id)) - .where(and(eq(agentEvalRunTopics.runId, runId), eq(agentEvalRunTopics.userId, this.userId))) + .where(and(eq(agentEvalRunTopics.runId, runId), this.ownership())) .orderBy(asc(agentEvalTestCases.sortOrder)); return rows; @@ -60,7 +70,7 @@ export class AgentEvalRunTopicModel { deleteByRunId = async (runId: string) => { return this.db .delete(agentEvalRunTopics) - .where(and(eq(agentEvalRunTopics.runId, runId), eq(agentEvalRunTopics.userId, this.userId))); + .where(and(eq(agentEvalRunTopics.runId, runId), this.ownership())); }; /** @@ -82,12 +92,7 @@ export class AgentEvalRunTopicModel { .from(agentEvalRunTopics) .leftJoin(agentEvalRuns, eq(agentEvalRunTopics.runId, agentEvalRuns.id)) .leftJoin(topics, eq(agentEvalRunTopics.topicId, topics.id)) - .where( - and( - eq(agentEvalRunTopics.testCaseId, testCaseId), - eq(agentEvalRunTopics.userId, this.userId), - ), - ) + .where(and(eq(agentEvalRunTopics.testCaseId, testCaseId), this.ownership())) .orderBy(desc(agentEvalRunTopics.createdAt)); return rows; @@ -117,7 +122,7 @@ export class AgentEvalRunTopicModel { and( eq(agentEvalRunTopics.runId, runId), eq(agentEvalRunTopics.testCaseId, testCaseId), - eq(agentEvalRunTopics.userId, this.userId), + this.ownership(), ), ) .limit(1); @@ -136,7 +141,7 @@ export class AgentEvalRunTopicModel { .set({ status: 'error', evalResult: { error: 'Aborted' } }) .where( and( - eq(agentEvalRunTopics.userId, this.userId), + this.ownership(), eq(agentEvalRunTopics.runId, runId), or(eq(agentEvalRunTopics.status, 'pending'), eq(agentEvalRunTopics.status, 'running')), ), @@ -151,7 +156,7 @@ export class AgentEvalRunTopicModel { .set({ status: 'timeout' }) .where( and( - eq(agentEvalRunTopics.userId, this.userId), + this.ownership(), eq(agentEvalRunTopics.runId, runId), eq(agentEvalRunTopics.status, 'running'), lt(agentEvalRunTopics.createdAt, deadline), @@ -165,7 +170,7 @@ export class AgentEvalRunTopicModel { .delete(agentEvalRunTopics) .where( and( - eq(agentEvalRunTopics.userId, this.userId), + this.ownership(), eq(agentEvalRunTopics.runId, runId), eq(agentEvalRunTopics.testCaseId, testCaseId), ), @@ -181,7 +186,7 @@ export class AgentEvalRunTopicModel { .delete(agentEvalRunTopics) .where( and( - eq(agentEvalRunTopics.userId, this.userId), + this.ownership(), eq(agentEvalRunTopics.runId, runId), or(eq(agentEvalRunTopics.status, 'error'), eq(agentEvalRunTopics.status, 'timeout')), ), @@ -205,7 +210,7 @@ export class AgentEvalRunTopicModel { .set(value) .where( and( - eq(agentEvalRunTopics.userId, this.userId), + this.ownership(), eq(agentEvalRunTopics.runId, runId), eq(agentEvalRunTopics.topicId, topicId), ), diff --git a/packages/database/src/models/agentEval/testCase.ts b/packages/database/src/models/agentEval/testCase.ts index 80cf2d6bec..4de89f6657 100644 --- a/packages/database/src/models/agentEval/testCase.ts +++ b/packages/database/src/models/agentEval/testCase.ts @@ -2,21 +2,31 @@ import { and, count, eq, sql } from 'drizzle-orm'; import { agentEvalTestCases, type NewAgentEvalTestCase } from '../../schemas'; import { type LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class AgentEvalTestCaseModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalTestCases); + /** * Create a single test case */ create = async (params: Omit) => { - let finalParams: NewAgentEvalTestCase = { ...params, userId: this.userId }; + let finalParams: NewAgentEvalTestCase = { + ...params, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + }; if (finalParams.sortOrder === undefined || finalParams.sortOrder === null) { const [maxResult] = await this.db @@ -35,7 +45,11 @@ export class AgentEvalTestCaseModel { * Batch create test cases */ batchCreate = async (cases: Omit[]) => { - const withUserId = cases.map((c) => ({ ...c, userId: this.userId })); + const withUserId = cases.map((c) => ({ + ...c, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + })); return this.db.insert(agentEvalTestCases).values(withUserId).returning(); }; @@ -45,7 +59,7 @@ export class AgentEvalTestCaseModel { delete = async (id: string) => { return this.db .delete(agentEvalTestCases) - .where(and(eq(agentEvalTestCases.id, id), eq(agentEvalTestCases.userId, this.userId))); + .where(and(eq(agentEvalTestCases.id, id), this.ownership())); }; /** @@ -55,7 +69,7 @@ export class AgentEvalTestCaseModel { const [result] = await this.db .select() .from(agentEvalTestCases) - .where(and(eq(agentEvalTestCases.id, id), eq(agentEvalTestCases.userId, this.userId))) + .where(and(eq(agentEvalTestCases.id, id), this.ownership())) .limit(1); return result; }; @@ -67,12 +81,7 @@ export class AgentEvalTestCaseModel { const query = this.db .select() .from(agentEvalTestCases) - .where( - and( - eq(agentEvalTestCases.datasetId, datasetId), - eq(agentEvalTestCases.userId, this.userId), - ), - ) + .where(and(eq(agentEvalTestCases.datasetId, datasetId), this.ownership())) .orderBy(agentEvalTestCases.sortOrder); if (limit !== undefined) { @@ -92,12 +101,7 @@ export class AgentEvalTestCaseModel { const result = await this.db .select({ value: count() }) .from(agentEvalTestCases) - .where( - and( - eq(agentEvalTestCases.datasetId, datasetId), - eq(agentEvalTestCases.userId, this.userId), - ), - ); + .where(and(eq(agentEvalTestCases.datasetId, datasetId), this.ownership())); return Number(result[0]?.value) || 0; }; @@ -108,7 +112,7 @@ export class AgentEvalTestCaseModel { const [result] = await this.db .update(agentEvalTestCases) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(agentEvalTestCases.id, id), eq(agentEvalTestCases.userId, this.userId))) + .where(and(eq(agentEvalTestCases.id, id), this.ownership())) .returning(); return result; }; diff --git a/packages/database/src/models/agentOperation.ts b/packages/database/src/models/agentOperation.ts index ef42aae95b..3c25d152a4 100644 --- a/packages/database/src/models/agentOperation.ts +++ b/packages/database/src/models/agentOperation.ts @@ -11,6 +11,7 @@ import type { } from '../schemas/agentOperations'; import { agentOperations } from '../schemas/agentOperations'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; /** Verify rollup states, mirrors the `verify_status` enum column. */ export type VerifyStatus = @@ -80,12 +81,17 @@ export interface RecordOperationCompletionParams { export class AgentOperationModel { private readonly db: LobeChatDatabase; private readonly userId: string; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentOperations); + /** * Insert the initial row when an operation is created. Idempotent via * `onConflictDoNothing` on the primary key so resumed operations don't @@ -110,6 +116,7 @@ export class AgentOperationModel { topicId: params.topicId ?? null, trigger: params.trigger, userId: this.userId, + workspaceId: this.workspaceId ?? null, }; await this.db.insert(agentOperations).values(values).onConflictDoNothing(); @@ -151,14 +158,14 @@ export class AgentOperationModel { await this.db .update(agentOperations) .set(updates) - .where(and(eq(agentOperations.id, operationId), eq(agentOperations.userId, this.userId))); + .where(and(eq(agentOperations.id, operationId), this.ownership())); } async findById(operationId: string) { const [row] = await this.db .select() .from(agentOperations) - .where(and(eq(agentOperations.id, operationId), eq(agentOperations.userId, this.userId))) + .where(and(eq(agentOperations.id, operationId), this.ownership())) .limit(1); return row ?? null; } @@ -182,7 +189,7 @@ export class AgentOperationModel { .from(agentOperations) .where( and( - eq(agentOperations.userId, this.userId), + this.ownership(), isNotNull(agentOperations.startedAt), isNotNull(agentOperations.completedAt), gte(agentOperations.createdAt, startDate), diff --git a/packages/database/src/models/agentSignal/nightlyReview.ts b/packages/database/src/models/agentSignal/nightlyReview.ts index 5b68cd774b..4a077a2730 100644 --- a/packages/database/src/models/agentSignal/nightlyReview.ts +++ b/packages/database/src/models/agentSignal/nightlyReview.ts @@ -192,16 +192,27 @@ export class AgentSignalNightlyReviewModel { topicCount: countDistinct(messages.topicId), }) .from(messages) - .leftJoin(topics, and(eq(topics.id, messages.topicId), eq(topics.userId, userId))) - .innerJoin(agents, and(eq(agents.id, effectiveAgentId), eq(agents.userId, userId))) + .leftJoin( + topics, + and(eq(topics.id, messages.topicId), eq(topics.userId, userId), isNull(topics.workspaceId)), + ) + .innerJoin( + agents, + and(eq(agents.id, effectiveAgentId), eq(agents.userId, userId), isNull(agents.workspaceId)), + ) .leftJoin(userSettings, eq(userSettings.id, userId)) .leftJoin( messagePlugins, - and(eq(messagePlugins.id, messages.id), eq(messagePlugins.userId, userId)), + and( + eq(messagePlugins.id, messages.id), + eq(messagePlugins.userId, userId), + isNull(messagePlugins.workspaceId), + ), ) .where( and( eq(messages.userId, userId), + isNull(messages.workspaceId), agentFilter, gte(messages.createdAt, options.windowStart), lte(messages.createdAt, options.windowEnd), diff --git a/packages/database/src/models/agentSignal/reviewContext.ts b/packages/database/src/models/agentSignal/reviewContext.ts index 018a21bab5..bfca104929 100644 --- a/packages/database/src/models/agentSignal/reviewContext.ts +++ b/packages/database/src/models/agentSignal/reviewContext.ts @@ -1,5 +1,6 @@ import { INBOX_SESSION_ID } from '@lobechat/const'; import { and, count, desc, eq, gte, isNull, lte, or, sql } from 'drizzle-orm'; +import type { AnyPgColumn } from 'drizzle-orm/pg-core'; import { agentDocuments, @@ -11,6 +12,7 @@ import { userMemories, } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; const parseAggregateTimestamp = (value: Date | string) => value instanceof Date ? value : new Date(value); @@ -99,12 +101,17 @@ export interface AgentSignalDocumentActivityRow { export class AgentSignalReviewContextModel { private readonly db: LobeChatDatabase; private readonly userId: string; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ws = (cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }) => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, cols); + /** Checks agent ownership, virtual status, and self-iteration opt-in. */ canAgentRunSelfIteration = async (agentId: string) => { const [agent] = await this.db @@ -113,7 +120,7 @@ export class AgentSignalReviewContextModel { .where( and( eq(agents.id, agentId), - eq(agents.userId, this.userId), + this.ws(agents), or(eq(agents.virtual, false), isNull(agents.virtual), eq(agents.slug, INBOX_SESSION_ID)), or( eq(agents.slug, INBOX_SESSION_ID), @@ -183,14 +190,11 @@ export class AgentSignalReviewContextModel { totalCount: count(messagePlugins.id), }) .from(messagePlugins) - .innerJoin( - messages, - and(eq(messages.id, messagePlugins.id), eq(messages.userId, this.userId)), - ) - .leftJoin(topics, and(eq(topics.id, messages.topicId), eq(topics.userId, this.userId))) + .innerJoin(messages, and(eq(messages.id, messagePlugins.id), this.ws(messages))) + .leftJoin(topics, and(eq(topics.id, messages.topicId), this.ws(topics))) .where( and( - eq(messagePlugins.userId, this.userId), + this.ws(messagePlugins), eq(effectiveAgentId, options.agentId), gte(messages.createdAt, options.windowStart), lte(messages.createdAt, options.windowEnd), @@ -220,13 +224,10 @@ export class AgentSignalReviewContextModel { updatedAt: agentDocuments.updatedAt, }) .from(agentDocuments) - .innerJoin( - documents, - and(eq(documents.id, agentDocuments.documentId), eq(documents.userId, this.userId)), - ) + .innerJoin(documents, and(eq(documents.id, agentDocuments.documentId), this.ws(documents))) .where( and( - eq(agentDocuments.userId, this.userId), + this.ws(agentDocuments), eq(agentDocuments.agentId, options.agentId), isNull(agentDocuments.deletedAt), gte(agentDocuments.updatedAt, options.windowStart), @@ -285,14 +286,11 @@ export class AgentSignalReviewContextModel { topicId: topics.id, }) .from(messages) - .leftJoin(topics, and(eq(topics.id, messages.topicId), eq(topics.userId, this.userId))) - .leftJoin( - messagePlugins, - and(eq(messagePlugins.id, messages.id), eq(messagePlugins.userId, this.userId)), - ) + .leftJoin(topics, and(eq(topics.id, messages.topicId), this.ws(topics))) + .leftJoin(messagePlugins, and(eq(messagePlugins.id, messages.id), this.ws(messagePlugins))) .where( and( - eq(messages.userId, this.userId), + this.ws(messages), eq(effectiveAgentId, options.agentId), gte(messages.createdAt, options.windowStart), lte(messages.createdAt, options.windowEnd), @@ -349,14 +347,11 @@ export class AgentSignalReviewContextModel { topicId: topics.id, }) .from(messages) - .leftJoin(topics, and(eq(topics.id, messages.topicId), eq(topics.userId, this.userId))) - .leftJoin( - messagePlugins, - and(eq(messagePlugins.id, messages.id), eq(messagePlugins.userId, this.userId)), - ) + .leftJoin(topics, and(eq(topics.id, messages.topicId), this.ws(topics))) + .leftJoin(messagePlugins, and(eq(messagePlugins.id, messages.id), this.ws(messagePlugins))) .where( and( - eq(messages.userId, this.userId), + this.ws(messages), eq(messages.agentId, options.agentId), gte(messages.createdAt, options.windowStart), lte(messages.createdAt, options.windowEnd), diff --git a/packages/database/src/models/agentSkill.ts b/packages/database/src/models/agentSkill.ts index eb8766ff64..1b6a372b03 100644 --- a/packages/database/src/models/agentSkill.ts +++ b/packages/database/src/models/agentSkill.ts @@ -2,9 +2,10 @@ import type { SkillItem, SkillListItem } from '@lobechat/types'; import { merge } from '@lobechat/utils'; import { and, desc, eq, ilike, inArray, or } from 'drizzle-orm'; -import type {NewAgentSkill } from '../schemas'; +import type { NewAgentSkill } from '../schemas'; import { agentSkills } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; const skillItemColumns = { content: agentSkills.content, @@ -35,19 +36,24 @@ const skillListColumns = { export class AgentSkillModel { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private scopeWhere = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentSkills); + // ========== Create ========== - create = async (data: Omit): Promise => { + create = async (data: Omit): Promise => { const [result] = await this.db .insert(agentSkills) - .values({ ...data, userId: this.userId }) + .values(buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, data)) .returning(skillItemColumns); return result; }; @@ -58,7 +64,7 @@ export class AgentSkillModel { const [result] = await this.db .select(skillItemColumns) .from(agentSkills) - .where(and(eq(agentSkills.id, id), eq(agentSkills.userId, this.userId))) + .where(and(eq(agentSkills.id, id), this.scopeWhere())) .limit(1); return result; }; @@ -67,7 +73,7 @@ export class AgentSkillModel { const [result] = await this.db .select(skillItemColumns) .from(agentSkills) - .where(and(eq(agentSkills.identifier, identifier), eq(agentSkills.userId, this.userId))) + .where(and(eq(agentSkills.identifier, identifier), this.scopeWhere())) .limit(1); return result; }; @@ -76,7 +82,7 @@ export class AgentSkillModel { const [result] = await this.db .select(skillItemColumns) .from(agentSkills) - .where(and(eq(agentSkills.name, name), eq(agentSkills.userId, this.userId))) + .where(and(eq(agentSkills.name, name), this.scopeWhere())) .limit(1); return result; }; @@ -85,7 +91,7 @@ export class AgentSkillModel { const data = await this.db .select(skillListColumns) .from(agentSkills) - .where(eq(agentSkills.userId, this.userId)) + .where(this.scopeWhere()) .orderBy(desc(agentSkills.updatedAt)); return { data, total: data.length }; @@ -96,7 +102,7 @@ export class AgentSkillModel { return this.db .select(skillItemColumns) .from(agentSkills) - .where(and(inArray(agentSkills.id, ids), eq(agentSkills.userId, this.userId))); + .where(and(inArray(agentSkills.id, ids), this.scopeWhere())); }; listBySource = async ( @@ -105,7 +111,7 @@ export class AgentSkillModel { const data = await this.db .select(skillListColumns) .from(agentSkills) - .where(and(eq(agentSkills.source, source), eq(agentSkills.userId, this.userId))) + .where(and(eq(agentSkills.source, source), this.scopeWhere())) .orderBy(desc(agentSkills.updatedAt)); return { data, total: data.length }; @@ -117,7 +123,7 @@ export class AgentSkillModel { .from(agentSkills) .where( and( - eq(agentSkills.userId, this.userId), + this.scopeWhere(), or(ilike(agentSkills.name, `%${query}%`), ilike(agentSkills.description, `%${query}%`)), ), ) @@ -136,7 +142,7 @@ export class AgentSkillModel { const [result] = await this.db .update(agentSkills) .set(updateData) - .where(and(eq(agentSkills.id, id), eq(agentSkills.userId, this.userId))) + .where(and(eq(agentSkills.id, id), this.scopeWhere())) .returning(skillItemColumns); return result; }; @@ -146,7 +152,7 @@ export class AgentSkillModel { delete = async (id: string): Promise<{ success: boolean }> => { const result = await this.db .delete(agentSkills) - .where(and(eq(agentSkills.id, id), eq(agentSkills.userId, this.userId))); + .where(and(eq(agentSkills.id, id), this.scopeWhere())); return { success: (result.rowCount ?? 0) > 0 }; }; diff --git a/packages/database/src/models/apiKey.ts b/packages/database/src/models/apiKey.ts index d0c114c428..4e41db1f82 100644 --- a/packages/database/src/models/apiKey.ts +++ b/packages/database/src/models/apiKey.ts @@ -7,6 +7,7 @@ import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import type { ApiKeyItem, NewApiKeyItem } from '../schemas'; import { apiKeys } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class ApiKeyModel { static findByKey = async (db: LobeChatDatabase, key: string) => { @@ -22,13 +23,18 @@ export class ApiKeyModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; private gateKeeperPromise: Promise | null = null; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, apiKeys); + private async getGateKeeper() { if (!this.gateKeeperPromise) { this.gateKeeperPromise = KeyVaultsGateKeeper.initWithEnvKey(); @@ -45,24 +51,29 @@ export class ApiKeyModel { const [result] = await this.db .insert(apiKeys) - .values({ ...params, key: encryptedKey, keyHash, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params, key: encryptedKey, keyHash }, + ), + ) .returning(); return result; }; delete = async (id: string) => { - return this.db.delete(apiKeys).where(and(eq(apiKeys.id, id), eq(apiKeys.userId, this.userId))); + return this.db.delete(apiKeys).where(and(eq(apiKeys.id, id), this.ownership())); }; deleteAll = async () => { - return this.db.delete(apiKeys).where(eq(apiKeys.userId, this.userId)); + return this.db.delete(apiKeys).where(this.ownership()); }; query = async () => { const results = await this.db.query.apiKeys.findMany({ orderBy: [desc(apiKeys.updatedAt)], - where: eq(apiKeys.userId, this.userId), + where: this.ownership(), }); const gateKeeper = await this.getGateKeeper(); @@ -103,12 +114,12 @@ export class ApiKeyModel { return this.db .update(apiKeys) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(apiKeys.id, id), eq(apiKeys.userId, this.userId))); + .where(and(eq(apiKeys.id, id), this.ownership())); }; findById = async (id: string) => { return this.db.query.apiKeys.findFirst({ - where: and(eq(apiKeys.id, id), eq(apiKeys.userId, this.userId)), + where: and(eq(apiKeys.id, id), this.ownership()), }); }; @@ -116,6 +127,6 @@ export class ApiKeyModel { return this.db .update(apiKeys) .set({ lastUsedAt: new Date() }) - .where(and(eq(apiKeys.id, id), eq(apiKeys.userId, this.userId))); + .where(and(eq(apiKeys.id, id), this.ownership())); }; } diff --git a/packages/database/src/models/asyncTask.ts b/packages/database/src/models/asyncTask.ts index c337b8fe8f..6f6c72b386 100644 --- a/packages/database/src/models/asyncTask.ts +++ b/packages/database/src/models/asyncTask.ts @@ -11,35 +11,46 @@ import { and, eq, inArray, lt, or, sql } from 'drizzle-orm'; import type { AsyncTaskSelectItem, NewAsyncTaskItem } from '../schemas'; import { asyncTasks } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class AsyncTaskModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, asyncTasks); + create = async ( params: Pick, ): Promise => { const data = await this.db .insert(asyncTasks) - .values({ ...params, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params }, + ), + ) .returning(); return data[0].id; }; delete = async (id: string) => { - return this.db - .delete(asyncTasks) - .where(and(eq(asyncTasks.id, id), eq(asyncTasks.userId, this.userId))); + return this.db.delete(asyncTasks).where(and(eq(asyncTasks.id, id), this.ownership())); }; findById = async (id: string) => { - return this.db.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) }); + return this.db.query.asyncTasks.findFirst({ + where: and(eq(asyncTasks.id, id), this.ownership()), + }); }; static findByInferenceId = async (db: LobeChatDatabase, inferenceId: string) => { @@ -52,13 +63,13 @@ export class AsyncTaskModel { return this.db .update(asyncTasks) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(asyncTasks.id, taskId))); + .where(and(eq(asyncTasks.id, taskId), this.ownership())); } findActiveByType = async (type: AsyncTaskType) => { return this.db.query.asyncTasks.findFirst({ where: and( - eq(asyncTasks.userId, this.userId), + this.ownership(), eq(asyncTasks.type, type), inArray(asyncTasks.status, [AsyncTaskStatus.Pending, AsyncTaskStatus.Processing]), ), @@ -98,7 +109,7 @@ export class AsyncTaskModel { `, updatedAt: new Date(), }) - .where(and(eq(asyncTasks.id, taskId), eq(asyncTasks.userId, this.userId))) + .where(and(eq(asyncTasks.id, taskId), this.ownership())) .returning({ metadata: asyncTasks.metadata, status: asyncTasks.status }); return result[0]; @@ -110,7 +121,7 @@ export class AsyncTaskModel { if (taskIds.length > 0) { await this.checkTimeoutTasks(taskIds); chunkTasks = await this.db.query.asyncTasks.findMany({ - where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type)), + where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type), this.ownership()), }); } @@ -138,6 +149,7 @@ export class AsyncTaskModel { .where( and( inArray(asyncTasks.id, ids), + this.ownership(), or( eq(asyncTasks.status, AsyncTaskStatus.Pending), eq(asyncTasks.status, AsyncTaskStatus.Processing), @@ -157,9 +169,12 @@ export class AsyncTaskModel { status: AsyncTaskStatus.Error, }) .where( - inArray( - asyncTasks.id, - tasks.map((item) => item.id), + and( + inArray( + asyncTasks.id, + tasks.map((item) => item.id), + ), + this.ownership(), ), ); } diff --git a/packages/database/src/models/brief.ts b/packages/database/src/models/brief.ts index 8b2583b075..cfaf7f0ca6 100644 --- a/packages/database/src/models/brief.ts +++ b/packages/database/src/models/brief.ts @@ -4,6 +4,7 @@ import { agents } from '../schemas/agent'; import type { BriefItem, NewBrief } from '../schemas/task'; import { briefs, tasks } from '../schemas/task'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export interface UnresolvedBriefRow { agentAvatar: string | null; @@ -17,16 +18,23 @@ export interface UnresolvedBriefRow { export class BriefModel { private readonly userId: string; private readonly db: LobeChatDatabase; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, briefs); + async create(data: Omit): Promise { const result = await this.db .insert(briefs) - .values({ ...data, userId: this.userId }) + .values( + buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, { ...data }), + ) .returning(); return result[0]; @@ -36,7 +44,7 @@ export class BriefModel { const result = await this.db .select() .from(briefs) - .where(and(eq(briefs.id, id), eq(briefs.userId, this.userId))) + .where(and(eq(briefs.id, id), this.ownership())) .limit(1); return result[0] || null; @@ -49,7 +57,7 @@ export class BriefModel { }): Promise<{ briefs: BriefItem[]; total: number }> { const { type, limit = 50, offset = 0 } = options || {}; - const conditions = [eq(briefs.userId, this.userId)]; + const conditions = [this.ownership()]; if (type) conditions.push(eq(briefs.type, type)); const where = and(...conditions); @@ -90,7 +98,7 @@ export class BriefModel { .from(briefs) .leftJoin(agents, eq(briefs.agentId, agents.id)) .leftJoin(tasks, eq(briefs.taskId, tasks.id)) - .where(and(eq(briefs.userId, this.userId), isNull(briefs.resolvedAt))) + .where(and(this.ownership(), isNull(briefs.resolvedAt))) .orderBy( sql`CASE WHEN ${briefs.priority} = 'urgent' THEN 0 @@ -130,7 +138,7 @@ export class BriefModel { .from(briefs) .where( and( - eq(briefs.userId, this.userId), + this.ownership(), eq(briefs.agentId, agentId), eq(briefs.trigger, trigger), isNull(briefs.resolvedAt), @@ -144,7 +152,7 @@ export class BriefModel { return this.db .select() .from(briefs) - .where(and(eq(briefs.taskId, taskId), eq(briefs.userId, this.userId))) + .where(and(eq(briefs.taskId, taskId), this.ownership())) .orderBy(desc(briefs.createdAt)); } @@ -159,7 +167,7 @@ export class BriefModel { ): Promise { const excludeTypes = options?.excludeTypes ?? []; const conditions = [ - eq(briefs.userId, this.userId), + this.ownership(), eq(briefs.taskId, taskId), eq(briefs.priority, 'urgent'), isNull(briefs.resolvedAt), @@ -180,7 +188,7 @@ export class BriefModel { return this.db .select() .from(briefs) - .where(and(eq(briefs.cronJobId, cronJobId), eq(briefs.userId, this.userId))) + .where(and(eq(briefs.cronJobId, cronJobId), this.ownership())) .orderBy(desc(briefs.createdAt)); } @@ -188,7 +196,7 @@ export class BriefModel { const result = await this.db .update(briefs) .set({ readAt: new Date() }) - .where(and(eq(briefs.id, id), eq(briefs.userId, this.userId))) + .where(and(eq(briefs.id, id), this.ownership())) .returning(); return result[0] || null; @@ -206,7 +214,7 @@ export class BriefModel { resolvedAt: new Date(), resolvedComment: options?.comment, }) - .where(and(eq(briefs.id, id), eq(briefs.userId, this.userId))) + .where(and(eq(briefs.id, id), this.ownership())) .returning(); return result[0] || null; @@ -229,7 +237,7 @@ export class BriefModel { const result = await this.db .update(briefs) .set({ metadata }) - .where(and(eq(briefs.id, id), eq(briefs.userId, this.userId))) + .where(and(eq(briefs.id, id), this.ownership())) .returning(); return result[0] || null; @@ -238,7 +246,7 @@ export class BriefModel { async delete(id: string): Promise { const result = await this.db .delete(briefs) - .where(and(eq(briefs.id, id), eq(briefs.userId, this.userId))) + .where(and(eq(briefs.id, id), this.ownership())) .returning(); return result.length > 0; diff --git a/packages/database/src/models/chatGroup.ts b/packages/database/src/models/chatGroup.ts index d1b6f6eaf4..142c61047c 100644 --- a/packages/database/src/models/chatGroup.ts +++ b/packages/database/src/models/chatGroup.ts @@ -8,20 +8,27 @@ import type { } from '../schemas'; import { chatGroups, chatGroupsAgents } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class ChatGroupModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, chatGroups); + // ******* Query Methods ******* // async findById(id: string): Promise { const item = await this.db.query.chatGroups.findFirst({ - where: and(eq(chatGroups.id, id), eq(chatGroups.userId, this.userId)), + where: and(eq(chatGroups.id, id), this.ownership()), }); return item; @@ -30,7 +37,7 @@ export class ChatGroupModel { async query(): Promise { return this.db.query.chatGroups.findMany({ orderBy: [desc(chatGroups.updatedAt)], - where: eq(chatGroups.userId, this.userId), + where: this.ownership(), }); } @@ -44,7 +51,7 @@ export class ChatGroupModel { columns: { id: true }, orderBy: [desc(chatGroups.updatedAt)], where: and( - eq(chatGroups.userId, this.userId), + this.ownership(), sql`${chatGroups.config}->>'forkedFromIdentifier' = ${forkedFromIdentifier}`, ), }); @@ -58,7 +65,7 @@ export class ChatGroupModel { const groupIds = groups.map((g) => g.id); const groupAgents = await this.db.query.chatGroupsAgents.findMany({ - where: inArray(chatGroupsAgents.chatGroupId, groupIds), + where: and(inArray(chatGroupsAgents.chatGroupId, groupIds), this.agentsOwnership()), with: { agent: true }, }); @@ -87,7 +94,7 @@ export class ChatGroupModel { const agents = await this.db.query.chatGroupsAgents.findMany({ orderBy: [chatGroupsAgents.order], - where: eq(chatGroupsAgents.chatGroupId, groupId), + where: and(eq(chatGroupsAgents.chatGroupId, groupId), this.agentsOwnership()), }); return { agents, group }; @@ -98,7 +105,12 @@ export class ChatGroupModel { async create(params: Omit): Promise { const [result] = await this.db .insert(chatGroups) - .values({ ...params, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params }, + ), + ) .returning(); return result; @@ -119,6 +131,7 @@ export class ChatGroupModel { chatGroupId: group.id, order: index, userId: this.userId, + workspaceId: this.workspaceId ?? null, })); const agents = await this.db.insert(chatGroupsAgents).values(agentParams).returning(); @@ -132,7 +145,7 @@ export class ChatGroupModel { const [result] = await this.db .update(chatGroups) .set(value) - .where(and(eq(chatGroups.id, id), eq(chatGroups.userId, this.userId))) + .where(and(eq(chatGroups.id, id), this.ownership())) .returning(); if (!result) { @@ -153,6 +166,7 @@ export class ChatGroupModel { order: options?.order || 0, role: options?.role || 'assistant', userId: this.userId, + workspaceId: this.workspaceId ?? null, }; const [result] = await this.db.insert(chatGroupsAgents).values(params).returning(); @@ -189,6 +203,7 @@ export class ChatGroupModel { chatGroupId: groupId, enabled: true, userId: this.userId, + workspaceId: this.workspaceId ?? null, })); const added = await this.db.insert(chatGroupsAgents).values(newAgents).returning(); @@ -196,10 +211,19 @@ export class ChatGroupModel { return { added, existing: existingIds }; } + private agentsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, chatGroupsAgents); + async removeAgentFromGroup(groupId: string, agentId: string): Promise { await this.db .delete(chatGroupsAgents) - .where(and(eq(chatGroupsAgents.chatGroupId, groupId), eq(chatGroupsAgents.agentId, agentId))); + .where( + and( + eq(chatGroupsAgents.chatGroupId, groupId), + eq(chatGroupsAgents.agentId, agentId), + this.agentsOwnership(), + ), + ); } /** @@ -212,7 +236,11 @@ export class ChatGroupModel { await this.db .delete(chatGroupsAgents) .where( - and(eq(chatGroupsAgents.chatGroupId, groupId), inArray(chatGroupsAgents.agentId, agentIds)), + and( + eq(chatGroupsAgents.chatGroupId, groupId), + inArray(chatGroupsAgents.agentId, agentIds), + this.agentsOwnership(), + ), ); } @@ -224,7 +252,13 @@ export class ChatGroupModel { const [result] = await this.db .update(chatGroupsAgents) .set({ ...updates, updatedAt: new Date() }) - .where(and(eq(chatGroupsAgents.chatGroupId, groupId), eq(chatGroupsAgents.agentId, agentId))) + .where( + and( + eq(chatGroupsAgents.chatGroupId, groupId), + eq(chatGroupsAgents.agentId, agentId), + this.agentsOwnership(), + ), + ) .returning(); return result; @@ -236,7 +270,7 @@ export class ChatGroupModel { // Agents are automatically deleted due to CASCADE constraint const [result] = await this.db .delete(chatGroups) - .where(and(eq(chatGroups.id, id), eq(chatGroups.userId, this.userId))) + .where(and(eq(chatGroups.id, id), this.ownership())) .returning(); if (!result) { @@ -247,7 +281,7 @@ export class ChatGroupModel { } async deleteAll(): Promise { - await this.db.delete(chatGroups).where(eq(chatGroups.userId, this.userId)); + await this.db.delete(chatGroups).where(this.ownership()); } // ******* Agent Query Methods ******* // @@ -255,14 +289,18 @@ export class ChatGroupModel { async getGroupAgents(groupId: string): Promise { return this.db.query.chatGroupsAgents.findMany({ orderBy: [chatGroupsAgents.order], - where: eq(chatGroupsAgents.chatGroupId, groupId), + where: and(eq(chatGroupsAgents.chatGroupId, groupId), this.agentsOwnership()), }); } async getEnabledGroupAgents(groupId: string): Promise { return this.db.query.chatGroupsAgents.findMany({ orderBy: [chatGroupsAgents.order], - where: and(eq(chatGroupsAgents.chatGroupId, groupId), eq(chatGroupsAgents.enabled, true)), + where: and( + eq(chatGroupsAgents.chatGroupId, groupId), + eq(chatGroupsAgents.enabled, true), + this.agentsOwnership(), + ), }); } @@ -275,9 +313,7 @@ export class ChatGroupModel { const groupIds = await this.db .selectDistinct({ chatGroupId: chatGroupsAgents.chatGroupId }) .from(chatGroupsAgents) - .where( - and(eq(chatGroupsAgents.userId, this.userId), inArray(chatGroupsAgents.agentId, agentIds)), - ); + .where(and(this.agentsOwnership(), inArray(chatGroupsAgents.agentId, agentIds))); if (groupIds.length === 0) return []; @@ -288,7 +324,7 @@ export class ChatGroupModel { chatGroups.id, groupIds.map((g) => g.chatGroupId), ), - eq(chatGroups.userId, this.userId), + this.ownership(), ), }); } diff --git a/packages/database/src/models/chunk.ts b/packages/database/src/models/chunk.ts index f431f963a0..532eae4fc2 100644 --- a/packages/database/src/models/chunk.ts +++ b/packages/database/src/models/chunk.ts @@ -5,27 +5,44 @@ import { chunk } from 'es-toolkit/compat'; import type { NewChunkItem, NewUnstructuredChunkItem } from '../schemas'; import { chunks, embeddings, fileChunks, files, unstructuredChunks } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class ChunkModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, chunks); + + private fileChunksOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, fileChunks); + bulkCreate = async (params: NewChunkItem[], fileId: string) => { return this.db.transaction(async (trx) => { if (params.length === 0) return []; - const result = await trx.insert(chunks).values(params).returning(); + const result = await trx + .insert(chunks) + .values( + params.map((p) => + buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, p), + ), + ) + .returning(); const fileChunksData = result.map((chunk) => ({ chunkId: chunk.id, fileId, userId: this.userId, + workspaceId: this.workspaceId ?? null, })); if (fileChunksData.length > 0) { @@ -37,11 +54,15 @@ export class ChunkModel { }; bulkCreateUnstructuredChunks = async (params: NewUnstructuredChunkItem[]) => { - return this.db.insert(unstructuredChunks).values(params); + return this.db + .insert(unstructuredChunks) + .values( + params.map((p) => ({ ...p, workspaceId: p.workspaceId ?? this.workspaceId ?? null })), + ); }; delete = async (id: string) => { - return this.db.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId))); + return this.db.delete(chunks).where(and(eq(chunks.id, id), this.ownership())); }; deleteOrphanChunks = async () => { @@ -67,7 +88,7 @@ export class ChunkModel { findById = async (id: string) => { return this.db.query.chunks.findFirst({ - where: and(eq(chunks.id, id)), + where: and(eq(chunks.id, id), this.ownership()), }); }; @@ -85,7 +106,7 @@ export class ChunkModel { }) .from(chunks) .innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) - .where(and(eq(fileChunks.fileId, id), eq(chunks.userId, this.userId))) + .where(and(eq(fileChunks.fileId, id), this.ownership(), this.fileChunksOwnership())) .limit(20) .offset(page * 20) .orderBy(asc(chunks.index)); @@ -102,7 +123,7 @@ export class ChunkModel { .select() .from(chunks) .innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) - .where(eq(fileChunks.fileId, id)); + .where(and(eq(fileChunks.fileId, id), this.ownership(), this.fileChunksOwnership())); return data .map((item) => item.chunks) @@ -119,7 +140,7 @@ export class ChunkModel { id: fileChunks.fileId, }) .from(fileChunks) - .where(inArray(fileChunks.fileId, ids)) + .where(and(inArray(fileChunks.fileId, ids), this.fileChunksOwnership())) .groupBy(fileChunks.fileId); }; @@ -130,7 +151,7 @@ export class ChunkModel { id: fileChunks.fileId, }) .from(fileChunks) - .where(eq(fileChunks.fileId, ids)) + .where(and(eq(fileChunks.fileId, ids), this.fileChunksOwnership())) .groupBy(fileChunks.fileId); return data[0]?.count ?? 0; @@ -161,7 +182,13 @@ export class ChunkModel { .leftJoin(embeddings, eq(chunks.id, embeddings.chunkId)) .leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) .leftJoin(files, eq(fileChunks.fileId, files.id)) - .where(fileIds ? inArray(fileChunks.fileId, fileIds) : undefined) + .where( + and( + this.ownership(), + fileIds ? this.fileChunksOwnership() : undefined, + fileIds ? inArray(fileChunks.fileId, fileIds) : undefined, + ), + ) .orderBy((t) => desc(t.similarity)) .limit(30); @@ -202,7 +229,7 @@ export class ChunkModel { .leftJoin(embeddings, eq(chunks.id, embeddings.chunkId)) .leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId)) .leftJoin(files, eq(files.id, fileChunks.fileId)) - .where(inArray(fileChunks.fileId, fileIds)) + .where(and(inArray(fileChunks.fileId, fileIds), this.ownership(), this.fileChunksOwnership())) .orderBy((t) => desc(t.similarity)) // Relaxed to 15 for now .limit(topK); diff --git a/packages/database/src/models/connector.ts b/packages/database/src/models/connector.ts index 8e40e1554f..ad2cee657f 100644 --- a/packages/database/src/models/connector.ts +++ b/packages/database/src/models/connector.ts @@ -8,6 +8,7 @@ import type { } from '../schemas'; import { userConnectors } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; interface GateKeeper { decrypt: (ciphertext: string) => Promise<{ plaintext: string }>; @@ -28,13 +29,18 @@ export class ConnectorModel { private userId: string; private db: LobeChatDatabase; private gateKeeper?: GateKeeper; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string, gateKeeper?: GateKeeper) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string, gateKeeper?: GateKeeper) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; this.gateKeeper = gateKeeper; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, userConnectors); + create = async ( params: CreateConnectorParams, gateKeeper: GateKeeper | undefined = this.gateKeeper, @@ -45,25 +51,25 @@ export class ConnectorModel { const [result] = await this.db .insert(userConnectors) - .values({ ...params, credentials, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params, credentials }, + ), + ) .returning(); return result; }; delete = async (id: string): Promise => { - await this.db - .delete(userConnectors) - .where(and(eq(userConnectors.id, id), eq(userConnectors.userId, this.userId))); + await this.db.delete(userConnectors).where(and(eq(userConnectors.id, id), this.ownership())); }; query = async ( gateKeeper: GateKeeper | undefined = this.gateKeeper, ): Promise => { - const rows = await this.db - .select() - .from(userConnectors) - .where(eq(userConnectors.userId, this.userId)); + const rows = await this.db.select().from(userConnectors).where(this.ownership()); return Promise.all(rows.map((r) => decryptRow(r, gateKeeper))); }; @@ -77,12 +83,7 @@ export class ConnectorModel { const rows = await this.db .select() .from(userConnectors) - .where( - and( - eq(userConnectors.userId, this.userId), - inArray(userConnectors.identifier, identifiers), - ), - ); + .where(and(this.ownership(), inArray(userConnectors.identifier, identifiers))); return Promise.all(rows.map((r) => decryptRow(r, gateKeeper))); }; @@ -94,7 +95,7 @@ export class ConnectorModel { const [row] = await this.db .select() .from(userConnectors) - .where(and(eq(userConnectors.id, id), eq(userConnectors.userId, this.userId))) + .where(and(eq(userConnectors.id, id), this.ownership())) .limit(1); if (!row) return null; @@ -120,14 +121,14 @@ export class ConnectorModel { await this.db .update(userConnectors) .set(set) - .where(and(eq(userConnectors.id, id), eq(userConnectors.userId, this.userId))); + .where(and(eq(userConnectors.id, id), this.ownership())); }; updateStatus = async (id: string, status: ConnectorStatus): Promise => { await this.db .update(userConnectors) .set({ status, updatedAt: new Date() }) - .where(and(eq(userConnectors.id, id), eq(userConnectors.userId, this.userId))); + .where(and(eq(userConnectors.id, id), this.ownership())); }; } diff --git a/packages/database/src/models/connectorTool.ts b/packages/database/src/models/connectorTool.ts index 3381020351..fd12465258 100644 --- a/packages/database/src/models/connectorTool.ts +++ b/packages/database/src/models/connectorTool.ts @@ -8,6 +8,7 @@ import type { } from '../schemas'; import { ConnectorToolPermission as Permission, userConnectorTools } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export interface SyncToolInput { crudType: ToolCRUDType; @@ -24,12 +25,17 @@ export interface SyncToolInput { export class ConnectorToolModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, userConnectorTools); + /** * Batch-upsert tools from a manifest sync. * @@ -41,19 +47,23 @@ export class ConnectorToolModel { upsertMany = async (userConnectorId: string, tools: SyncToolInput[]): Promise => { if (tools.length === 0) return; - const values: NewUserConnectorTool[] = tools.map((t) => ({ - crudType: t.crudType, - description: t.description ?? null, - displayName: t.displayName ?? null, - inputSchema: t.inputSchema ?? null, - isWorkArtifact: false, - outputSchema: t.outputSchema ?? null, - permission: t.defaultPermission ?? Permission.auto, - renderConfig: t.renderConfig ?? null, - toolName: t.toolName, - userConnectorId, - userId: this.userId, - })); + const values: NewUserConnectorTool[] = tools.map((t) => + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + crudType: t.crudType, + description: t.description ?? null, + displayName: t.displayName ?? null, + inputSchema: t.inputSchema ?? null, + isWorkArtifact: false, + outputSchema: t.outputSchema ?? null, + permission: t.defaultPermission ?? Permission.auto, + renderConfig: t.renderConfig ?? null, + toolName: t.toolName, + userConnectorId, + }, + ), + ); await this.db .insert(userConnectorTools) @@ -80,19 +90,14 @@ export class ConnectorToolModel { await this.db .update(userConnectorTools) .set({ permission, updatedAt: new Date() }) - .where(and(eq(userConnectorTools.id, toolId), eq(userConnectorTools.userId, this.userId))); + .where(and(eq(userConnectorTools.id, toolId), this.ownership())); }; queryByConnector = async (userConnectorId: string): Promise => { return this.db .select() .from(userConnectorTools) - .where( - and( - eq(userConnectorTools.userConnectorId, userConnectorId), - eq(userConnectorTools.userId, this.userId), - ), - ); + .where(and(eq(userConnectorTools.userConnectorId, userConnectorId), this.ownership())); }; /** @@ -107,7 +112,7 @@ export class ConnectorToolModel { .from(userConnectorTools) .where( and( - eq(userConnectorTools.userId, this.userId), + this.ownership(), inArray(userConnectorTools.userConnectorId, connectorIds), ne(userConnectorTools.permission, Permission.disabled), ), @@ -124,12 +129,7 @@ export class ConnectorToolModel { return this.db .select() .from(userConnectorTools) - .where( - and( - eq(userConnectorTools.userId, this.userId), - inArray(userConnectorTools.userConnectorId, connectorIds), - ), - ); + .where(and(this.ownership(), inArray(userConnectorTools.userConnectorId, connectorIds))); }; /** @@ -140,9 +140,7 @@ export class ConnectorToolModel { const results = await this.db .select() .from(userConnectorTools) - .where( - and(eq(userConnectorTools.userId, this.userId), eq(userConnectorTools.toolName, toolName)), - ) + .where(and(this.ownership(), eq(userConnectorTools.toolName, toolName))) .limit(1); return results[0]; }; diff --git a/packages/database/src/models/device.ts b/packages/database/src/models/device.ts index d865f4df31..d5ce13559b 100644 --- a/packages/database/src/models/device.ts +++ b/packages/database/src/models/device.ts @@ -19,6 +19,17 @@ export interface UpdateDeviceParams { workingDirs?: WorkingDirEntry[]; } +/** + * Devices are intentionally USER-LEVEL, not workspace-scoped. + * + * Even though the `devices` table carries a nullable `workspace_id` column, a + * physical machine belongs to the user across every workspace they're in (the + * unique key is `(userId, deviceId)`). This model therefore scopes all reads + * and writes by `userId` only and deliberately does NOT take a `workspaceId` + * argument or use `buildWorkspaceWhere` / `buildWorkspacePayload`. Switching it + * to workspace-scoped lookups would hide a user's own device inside their + * workspaces. See the matching note on `devices.workspaceId` in the schema. + */ export class DeviceModel { private userId: string; private db: LobeChatDatabase; diff --git a/packages/database/src/models/document.ts b/packages/database/src/models/document.ts index 96f7c7a57f..911ba2fd2f 100644 --- a/packages/database/src/models/document.ts +++ b/packages/database/src/models/document.ts @@ -1,8 +1,9 @@ -import { and, count, desc, eq, inArray, isNull, notInArray } from 'drizzle-orm'; +import { and, count, desc, eq, inArray, isNull, notInArray, sum } from 'drizzle-orm'; import type { DocumentItem, NewDocument } from '../schemas'; -import { DOCUMENT_FOLDER_TYPE, documents } from '../schemas'; +import { DOCUMENT_FOLDER_TYPE, documents, files } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export interface QueryDocumentParams { current?: number; @@ -14,16 +15,21 @@ export interface QueryDocumentParams { export class DocumentModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents); + findOrCreateFolder = async (name: string, parentId?: string): Promise => { const existing = await this.db.query.documents.findFirst({ where: and( - eq(documents.userId, this.userId), + this.ownership(), eq(documents.fileType, DOCUMENT_FOLDER_TYPE), eq(documents.filename, name), parentId ? eq(documents.parentId, parentId) : isNull(documents.parentId), @@ -48,20 +54,23 @@ export class DocumentModel { create = async (params: Omit): Promise => { const result = (await this.db .insert(documents) - .values({ ...params, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params }, + ), + ) .returning()) as DocumentItem[]; return result[0]!; }; delete = async (id: string) => { - return this.db - .delete(documents) - .where(and(eq(documents.id, id), eq(documents.userId, this.userId))); + return this.db.delete(documents).where(and(eq(documents.id, id), this.ownership())); }; deleteAll = async () => { - return this.db.delete(documents).where(eq(documents.userId, this.userId)); + return this.db.delete(documents).where(this.ownership()); }; query = async ({ @@ -74,7 +83,7 @@ export class DocumentModel { total: number; }> => { const offset = current * pageSize; - const conditions = [eq(documents.userId, this.userId)]; + const conditions = [this.ownership()]; if (fileTypes?.length) { conditions.push(inArray(documents.fileType, fileTypes)); @@ -141,19 +150,19 @@ export class DocumentModel { findById = async (id: string): Promise => { return this.db.query.documents.findFirst({ - where: and(eq(documents.userId, this.userId), eq(documents.id, id)), + where: and(this.ownership(), eq(documents.id, id)), }); }; findByFileId = async (fileId: string) => { return this.db.query.documents.findFirst({ - where: and(eq(documents.userId, this.userId), eq(documents.fileId, fileId)), + where: and(this.ownership(), eq(documents.fileId, fileId)), }); }; findBySlug = async (slug: string): Promise => { return this.db.query.documents.findFirst({ - where: and(eq(documents.userId, this.userId), eq(documents.slug, slug)), + where: and(this.ownership(), eq(documents.slug, slug)), }); }; @@ -170,7 +179,7 @@ export class DocumentModel { ): Promise => { return this.db.query.documents.findFirst({ where: and( - eq(documents.userId, this.userId), + this.ownership(), eq(documents.source, source), eq(documents.sourceType, sourceType), ), @@ -181,6 +190,219 @@ export class DocumentModel { return this.db .update(documents) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(documents.userId, this.userId), eq(documents.id, id))); + .where(and(this.ownership(), eq(documents.id, id))); + }; + + /** + * Collect a document and all its descendants (folders + leaves) via BFS. + * Honors the current ownership scope. + */ + private collectSubtree = async ( + rootId: string, + runner: LobeChatDatabase = this.db, + ): Promise => { + const root = await runner.query.documents.findFirst({ + where: and(this.ownership(), eq(documents.id, rootId)), + }); + if (!root) return []; + + const collected: DocumentItem[] = [root]; + let frontier: string[] = [root.id]; + + while (frontier.length > 0) { + const children = await runner.query.documents.findMany({ + where: and(this.ownership(), inArray(documents.parentId, frontier)), + }); + if (children.length === 0) break; + collected.push(...children); + frontier = children.map((c) => c.id); + } + + return collected; + }; + + countFileUsageInSubtree = async ( + rootId: string, + runner: LobeChatDatabase = this.db, + ): Promise => { + const subtree = await this.collectSubtree(rootId, runner); + if (subtree.length === 0) return 0; + + const ids = subtree.map((d) => d.id); + const result = await runner + .select({ totalSize: sum(files.size) }) + .from(files) + .where( + and( + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files), + inArray(files.parentId, ids), + ), + ); + + return parseInt(result[0]?.totalSize ?? '0') || 0; + }; + + /** + * Transfer a document (and its subtree) to another workspace / personal scope. + * Files anchored to documents in the subtree are also re-homed so the + * resource manager view stays consistent. + */ + transferTo = async ( + documentId: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ documentIds: string[] }> => { + return this.db.transaction(async (trx) => { + const scopedTrx = new DocumentModel(trx as LobeChatDatabase, this.userId, this.workspaceId); + const subtree = await scopedTrx.collectSubtree(documentId, trx as LobeChatDatabase); + if (subtree.length === 0) throw new Error('Document not found'); + + const ids = subtree.map((d) => d.id); + const ownershipUpdate = { userId: targetUserId, workspaceId: targetWorkspaceId }; + + // Resolve slug conflicts in the target scope + for (const doc of subtree) { + if (!doc.slug) continue; + const slug = await this.findAvailableSlug( + trx as LobeChatDatabase, + doc.slug, + targetWorkspaceId, + targetUserId, + doc.id, + ); + if (slug !== doc.slug) { + await (trx as LobeChatDatabase) + .update(documents) + .set({ slug }) + .where(eq(documents.id, doc.id)); + } + } + + await (trx as LobeChatDatabase) + .update(documents) + .set({ ...ownershipUpdate, updatedAt: new Date() }) + .where(inArray(documents.id, ids)); + + // Move files anchored to these documents + await (trx as LobeChatDatabase) + .update(files) + .set(ownershipUpdate) + .where(inArray(files.parentId, ids)); + + return { documentIds: ids }; + }); + }; + + /** + * Deep clone a document (and its subtree) into another workspace / personal + * scope. Generates fresh ids and preserves the parent/child topology. + */ + copyToWorkspace = async ( + documentId: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ rootId: string }> => { + return this.db.transaction(async (trx) => { + const scopedTrx = new DocumentModel(trx as LobeChatDatabase, this.userId, this.workspaceId); + const subtree = await scopedTrx.collectSubtree(documentId, trx as LobeChatDatabase); + if (subtree.length === 0) throw new Error('Document not found'); + + // BFS clone: parents are inserted before children so we always know the + // new parent id by the time we get to the child. + const idMap = new Map(); + const byId = new Map(subtree.map((d) => [d.id, d])); + const queue: string[] = [documentId]; + const seen = new Set(); + + while (queue.length > 0) { + const currentId = queue.shift()!; + if (seen.has(currentId)) continue; + seen.add(currentId); + const original = byId.get(currentId); + if (!original) continue; + + const newParentId = + currentId === documentId ? null : (idMap.get(original.parentId!) ?? null); + + let newSlug = original.slug; + if (newSlug) { + newSlug = await this.findAvailableSlug( + trx as LobeChatDatabase, + newSlug, + targetWorkspaceId, + targetUserId, + ); + } + + const inserted = (await (trx as LobeChatDatabase) + .insert(documents) + .values({ + accessedAt: original.accessedAt, + clientId: null, + content: original.content, + editorData: original.editorData, + fileId: null, + fileType: original.fileType, + filename: original.filename, + knowledgeBaseId: null, + metadata: { ...original.metadata, duplicatedFrom: original.id }, + pages: original.pages, + parentId: newParentId, + slug: newSlug, + source: original.source, + sourceType: original.sourceType, + title: original.title, + totalCharCount: original.totalCharCount, + totalLineCount: original.totalLineCount, + userId: targetUserId, + workspaceId: targetWorkspaceId, + } as NewDocument) + .returning({ id: documents.id })) as { id: string }[]; + + idMap.set(original.id, inserted[0]!.id); + + for (const c of subtree) { + if (c.parentId === original.id) queue.push(c.id); + } + } + + return { rootId: idMap.get(documentId)! }; + }); + }; + + /** + * Find a slug not already taken in the target (workspaceId, userId) scope. + * Tries `slug`, `slug-1`, …, `slug-99`. Mirrors the agent transfer behavior. + */ + private findAvailableSlug = async ( + runner: LobeChatDatabase, + baseSlug: string, + targetWorkspaceId: string | null, + targetUserId: string, + ignoreDocumentId?: string, + ): Promise => { + const buildWhere = (candidate: string) => + targetWorkspaceId + ? and(eq(documents.slug, candidate), eq(documents.workspaceId, targetWorkspaceId)) + : and( + eq(documents.slug, candidate), + eq(documents.userId, targetUserId), + isNull(documents.workspaceId), + ); + + const isFree = async (candidate: string): Promise => { + const existing = await runner.query.documents.findFirst({ where: buildWhere(candidate) }); + if (!existing) return true; + return ignoreDocumentId !== undefined && existing.id === ignoreDocumentId; + }; + + if (await isFree(baseSlug)) return baseSlug; + + for (let suffix = 1; suffix < 100; suffix++) { + const candidate = `${baseSlug}-${suffix}`; + if (await isFree(candidate)) return candidate; + } + // Fallback: append timestamp to guarantee uniqueness + return `${baseSlug}-${Date.now()}`; }; } diff --git a/packages/database/src/models/documentHistory.ts b/packages/database/src/models/documentHistory.ts index cd6c63208a..2ff6127c60 100644 --- a/packages/database/src/models/documentHistory.ts +++ b/packages/database/src/models/documentHistory.ts @@ -3,6 +3,7 @@ import { and, desc, eq, lt, or } from 'drizzle-orm'; import type { DocumentHistoryItem, NewDocumentHistory } from '../schemas'; import { documentHistories, documents } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export interface QueryDocumentHistoryParams { beforeId?: string; @@ -13,18 +14,32 @@ export interface QueryDocumentHistoryParams { export class DocumentHistoryModel { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; + this.workspaceId = workspaceId; this.db = db; } + private ownership() { + return buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + documentHistories, + ); + } + create = async (params: Omit): Promise => { const [document] = await this.db .select({ id: documents.id }) .from(documents) - .where(and(eq(documents.id, params.documentId), eq(documents.userId, this.userId))) + .where( + and( + eq(documents.id, params.documentId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), + ), + ) .limit(1); if (!document) { @@ -33,7 +48,7 @@ export class DocumentHistoryModel { const [result] = await this.db .insert(documentHistories) - .values({ ...params, userId: this.userId }) + .values(buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, params)) .returning(); return result!; @@ -42,29 +57,24 @@ export class DocumentHistoryModel { delete = async (id: string) => { return this.db .delete(documentHistories) - .where(and(eq(documentHistories.id, id), eq(documentHistories.userId, this.userId))); + .where(and(eq(documentHistories.id, id), this.ownership())); }; deleteByDocumentId = async (documentId: string) => { return this.db .delete(documentHistories) - .where( - and( - eq(documentHistories.documentId, documentId), - eq(documentHistories.userId, this.userId), - ), - ); + .where(and(eq(documentHistories.documentId, documentId), this.ownership())); }; deleteAll = async () => { - return this.db.delete(documentHistories).where(eq(documentHistories.userId, this.userId)); + return this.db.delete(documentHistories).where(this.ownership()); }; findById = async (id: string): Promise => { const [result] = await this.db .select() .from(documentHistories) - .where(and(eq(documentHistories.id, id), eq(documentHistories.userId, this.userId))) + .where(and(eq(documentHistories.id, id), this.ownership())) .limit(1); return result; @@ -74,12 +84,7 @@ export class DocumentHistoryModel { const [result] = await this.db .select() .from(documentHistories) - .where( - and( - eq(documentHistories.documentId, documentId), - eq(documentHistories.userId, this.userId), - ), - ) + .where(and(eq(documentHistories.documentId, documentId), this.ownership())) .orderBy(desc(documentHistories.savedAt), desc(documentHistories.id)) .limit(1); @@ -92,10 +97,7 @@ export class DocumentHistoryModel { documentId, limit = 50, }: QueryDocumentHistoryParams): Promise => { - const conditions = [ - eq(documentHistories.documentId, documentId), - eq(documentHistories.userId, this.userId), - ]; + const conditions = [eq(documentHistories.documentId, documentId), this.ownership()]; if (beforeSavedAt !== undefined) { if (beforeId !== undefined) { diff --git a/packages/database/src/models/documentShare.ts b/packages/database/src/models/documentShare.ts index c0ea335ac6..1a854a34d2 100644 --- a/packages/database/src/models/documentShare.ts +++ b/packages/database/src/models/documentShare.ts @@ -4,6 +4,7 @@ import { and, eq, sql } from 'drizzle-orm'; import { documents, documentShares, users } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export interface DocumentShareAccessResult { document: typeof documents.$inferSelect; @@ -17,10 +18,12 @@ export interface DocumentShareAccessResult { export class DocumentShareModel { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; + this.workspaceId = workspaceId; this.db = db; } @@ -34,7 +37,12 @@ export class DocumentShareModel { const [doc] = await this.db .select({ id: documents.id }) .from(documents) - .where(and(eq(documents.id, documentId), eq(documents.userId, this.userId))) + .where( + and( + eq(documents.id, documentId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), + ), + ) .limit(1); if (!doc) { @@ -43,12 +51,16 @@ export class DocumentShareModel { const [result] = await this.db .insert(documentShares) - .values({ - documentId, - permission: params.permission ?? 'read', - userId: this.userId, - visibility: params.visibility ?? 'private', - }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + documentId, + permission: params.permission ?? 'read', + visibility: params.visibility ?? 'private', + }, + ), + ) .onConflictDoNothing({ target: documentShares.documentId }) .returning(); @@ -63,7 +75,15 @@ export class DocumentShareModel { const [result] = await this.db .update(documentShares) .set({ updatedAt: new Date(), visibility }) - .where(and(eq(documentShares.documentId, documentId), eq(documentShares.userId, this.userId))) + .where( + and( + eq(documentShares.documentId, documentId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + documentShares, + ), + ), + ) .returning(); return result || null; @@ -73,7 +93,15 @@ export class DocumentShareModel { const [result] = await this.db .update(documentShares) .set({ permission, updatedAt: new Date() }) - .where(and(eq(documentShares.documentId, documentId), eq(documentShares.userId, this.userId))) + .where( + and( + eq(documentShares.documentId, documentId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + documentShares, + ), + ), + ) .returning(); return result || null; @@ -83,7 +111,13 @@ export class DocumentShareModel { return this.db .delete(documentShares) .where( - and(eq(documentShares.documentId, documentId), eq(documentShares.userId, this.userId)), + and( + eq(documentShares.documentId, documentId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + documentShares, + ), + ), ); }; @@ -98,7 +132,15 @@ export class DocumentShareModel { visibility: documentShares.visibility, }) .from(documentShares) - .where(and(eq(documentShares.documentId, documentId), eq(documentShares.userId, this.userId))) + .where( + and( + eq(documentShares.documentId, documentId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + documentShares, + ), + ), + ) .limit(1); return result[0] || null; diff --git a/packages/database/src/models/embedding.ts b/packages/database/src/models/embedding.ts index 922edfbad5..e0534e814f 100644 --- a/packages/database/src/models/embedding.ts +++ b/packages/database/src/models/embedding.ts @@ -3,20 +3,26 @@ import { and, count, eq } from 'drizzle-orm'; import type { NewEmbeddingsItem } from '../schemas'; import { embeddings } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; export class EmbeddingModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, embeddings); + create = async (value: Omit) => { const [item] = await this.db .insert(embeddings) - .values({ ...value, userId: this.userId }) + .values({ ...value, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return item.id as string; @@ -25,27 +31,31 @@ export class EmbeddingModel { bulkCreate = async (values: Omit[]) => { return this.db .insert(embeddings) - .values(values.map((item) => ({ ...item, userId: this.userId }))) + .values( + values.map((item) => ({ + ...item, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + })), + ) .onConflictDoNothing({ target: [embeddings.chunkId], }); }; delete = async (id: string) => { - return this.db - .delete(embeddings) - .where(and(eq(embeddings.id, id), eq(embeddings.userId, this.userId))); + return this.db.delete(embeddings).where(and(eq(embeddings.id, id), this.ownership())); }; query = async () => { return this.db.query.embeddings.findMany({ - where: eq(embeddings.userId, this.userId), + where: this.ownership(), }); }; findById = async (id: string) => { return this.db.query.embeddings.findFirst({ - where: and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)), + where: and(eq(embeddings.id, id), this.ownership()), }); }; @@ -55,7 +65,7 @@ export class EmbeddingModel { count: count(), }) .from(embeddings) - .where(eq(embeddings.userId, this.userId)); + .where(this.ownership()); return result[0].count; }; diff --git a/packages/database/src/models/file.ts b/packages/database/src/models/file.ts index 1a4f56aa4c..85daeea52a 100644 --- a/packages/database/src/models/file.ts +++ b/packages/database/src/models/file.ts @@ -20,6 +20,7 @@ import { topics, } from '../schemas'; import type { LobeChatDatabase, Transaction } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; /** * Minimal file descriptor used to bootstrap user-uploaded files into a sandbox. @@ -36,12 +37,17 @@ export interface SandboxInitFileItem { export class FileModel { private readonly userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files); + /** * Get file by ID without userId filter (public access) * Use this for scenarios like file proxy where file should be accessible by ID alone @@ -82,17 +88,26 @@ export class FileModel { const result = (await tx .insert(files) - .values({ ...params, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params }, + ), + ) .returning()) as FileItem[]; const item = result[0]!; if (params.knowledgeBaseId) { - await tx.insert(knowledgeBaseFiles).values({ - fileId: item.id, - knowledgeBaseId: params.knowledgeBaseId, - userId: this.userId, - }); + await tx.insert(knowledgeBaseFiles).values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + fileId: item.id, + knowledgeBaseId: params.knowledgeBaseId, + }, + ), + ); } return item; @@ -150,7 +165,7 @@ export class FileModel { .where( and( eq(documents.fileId, id), - eq(documents.userId, this.userId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), eq(documents.sourceType, 'file'), ), ); @@ -166,7 +181,7 @@ export class FileModel { } // 4. Delete file record - await tx.delete(files).where(and(eq(files.id, id), eq(files.userId, this.userId))); + await tx.delete(files).where(and(eq(files.id, id), this.ownership())); if (!fileHash) return; @@ -200,7 +215,7 @@ export class FileModel { totalSize: sum(files.size), }) .from(files) - .where(eq(files.userId, this.userId)); + .where(this.ownership()); return parseInt(result[0].totalSize!) || 0; }; @@ -211,7 +226,7 @@ export class FileModel { return await this.db.transaction(async (trx) => { // 1. First get the file list to return the deleted files const fileList = await trx.query.files.findMany({ - where: and(inArray(files.id, ids), eq(files.userId, this.userId)), + where: and(inArray(files.id, ids), this.ownership()), }); if (fileList.length === 0) return []; @@ -229,7 +244,7 @@ export class FileModel { .where( and( inArray(documents.fileId, ids), - eq(documents.userId, this.userId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), eq(documents.sourceType, 'file'), ), ); @@ -243,7 +258,7 @@ export class FileModel { } // 5. Delete file records - await trx.delete(files).where(and(inArray(files.id, ids), eq(files.userId, this.userId))); + await trx.delete(files).where(and(inArray(files.id, ids), this.ownership())); // If global files don't need to be deleted, no storage object should be removed. if (!removeGlobalFile || hashList.length === 0) return []; @@ -275,7 +290,7 @@ export class FileModel { }; clear = async () => { - return this.db.delete(files).where(eq(files.userId, this.userId)); + return this.db.delete(files).where(this.ownership()); }; query = async ({ @@ -287,10 +302,7 @@ export class FileModel { showFilesInKnowledgeBase, }: QueryFileListParams = {}) => { // 1. Build where clause - let whereClause = and( - q ? ilike(files.name, `%${q}%`) : undefined, - eq(files.userId, this.userId), - ); + let whereClause = and(q ? ilike(files.name, `%${q}%`) : undefined, this.ownership()); if (category && category !== FilesTabs.All && category !== FilesTabs.Home) { const fileTypePrefix = this.getFileTypePrefix(category as FilesTabs); if (Array.isArray(fileTypePrefix)) { @@ -333,6 +345,7 @@ export class FileModel { size: files.size, updatedAt: files.updatedAt, url: files.url, + userId: files.userId, }) .from(files); @@ -365,14 +378,14 @@ export class FileModel { findByIds = async (ids: string[]) => { return this.db.query.files.findMany({ - where: and(inArray(files.id, ids), eq(files.userId, this.userId)), + where: and(inArray(files.id, ids), this.ownership()), }); }; findById = async (id: string, trx?: Transaction) => { const database = trx || this.db; return database.query.files.findFirst({ - where: and(eq(files.id, id), eq(files.userId, this.userId)), + where: and(eq(files.id, id), this.ownership()), }); }; @@ -429,7 +442,7 @@ export class FileModel { this.db .update(files) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(files.id, id), eq(files.userId, this.userId))); + .where(and(eq(files.id, id), this.ownership())); /** * get the corresponding file type prefix according to FilesTabs @@ -459,10 +472,7 @@ export class FileModel { findByNames = async (fileNames: string[]) => this.db.query.files.findMany({ - where: and( - or(...fileNames.map((name) => like(files.name, `${name}%`))), - eq(files.userId, this.userId), - ), + where: and(or(...fileNames.map((name) => like(files.name, `${name}%`))), this.ownership()), }); // Abstract common method for deleting chunks @@ -514,4 +524,81 @@ export class FileModel { return chunkIds; }; + + // ========== Transfer / Copy ========== + + /** + * Transfer a single file (not a folder — folders live in `documents` and are + * handled by `DocumentModel.transferTo`, which already cascades into `files` + * via `parentId`). Updates ownership + knowledgeBaseFiles linkage so the + * file remains visible in the target scope's resource manager. + */ + transferTo = async ( + fileId: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ fileId: string }> => { + return this.db.transaction(async (trx) => { + const file = await trx.query.files.findFirst({ + where: and(eq(files.id, fileId), this.ownership()), + }); + if (!file) throw new Error('File not found'); + + const ownershipUpdate = { userId: targetUserId, workspaceId: targetWorkspaceId }; + + await trx + .update(files) + .set({ ...ownershipUpdate, updatedAt: new Date() }) + .where(eq(files.id, fileId)); + + // Knowledge base links are scoped per-user; keep them pointed at the new owner. + await trx + .update(knowledgeBaseFiles) + .set({ userId: targetUserId }) + .where(eq(knowledgeBaseFiles.fileId, fileId)); + + return { fileId }; + }); + }; + + /** + * Clone a file record into another workspace / personal scope. The physical + * blob is shared via `fileHash` → `globalFiles`, so we only copy the row. AI + * index references (`chunkTaskId` / `embeddingTaskId`) are reset; the new + * scope is expected to re-index lazily. + */ + copyToWorkspace = async ( + fileId: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ fileId: string }> => { + return this.db.transaction(async (trx) => { + const file = await trx.query.files.findFirst({ + where: and(eq(files.id, fileId), this.ownership()), + }); + if (!file) throw new Error('File not found'); + + const inserted = (await trx + .insert(files) + .values({ + chunkTaskId: null, + clientId: null, + embeddingTaskId: null, + fileHash: file.fileHash, + fileType: file.fileType, + metadata: { ...(file.metadata as Record), duplicatedFrom: file.id }, + name: file.name, + // parentId would dangle in target scope; the user can drag it under a folder later. + parentId: null, + size: file.size, + source: file.source, + url: file.url, + userId: targetUserId, + workspaceId: targetWorkspaceId, + } as NewFile) + .returning({ id: files.id })) as { id: string }[]; + + return { fileId: inserted[0]!.id }; + }); + }; } diff --git a/packages/database/src/models/generation.ts b/packages/database/src/models/generation.ts index 58a6845649..145d5cf421 100644 --- a/packages/database/src/models/generation.ts +++ b/packages/database/src/models/generation.ts @@ -16,6 +16,7 @@ import type { NewFile } from '../schemas'; import type { GenerationItem, GenerationWithAsyncTask, NewGeneration } from '../schemas/generation'; import { generations } from '../schemas/generation'; import type { LobeChatDatabase, Transaction } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; import { FileModel } from './file'; // Create debug logger @@ -24,16 +25,21 @@ const log = debug('lobe-image:generation-model'); export class GenerationModel { private db: LobeChatDatabase; private userId: string; + private workspaceId?: string; private fileModel: FileModel; private fileService: FileService; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; - this.fileModel = new FileModel(db, userId); + this.workspaceId = workspaceId; + this.fileModel = new FileModel(db, userId, workspaceId); this.fileService = new FileService(db, userId); } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, generations); + async create(value: Omit): Promise { log('Creating generation: %O', { generationBatchId: value.generationBatchId, @@ -42,7 +48,9 @@ export class GenerationModel { const [result] = await this.db .insert(generations) - .values({ ...value, userId: this.userId }) + .values( + buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, { ...value }), + ) .returning(); log('Generation created successfully: %s', result.id); @@ -53,7 +61,7 @@ export class GenerationModel { log('Finding generation by ID: %s for user: %s', id, this.userId); const result = await this.db.query.generations.findFirst({ - where: and(eq(generations.id, id), eq(generations.userId, this.userId)), + where: and(eq(generations.id, id), this.ownership()), }); log('Generation %s: %s', id, result ? 'found' : 'not found'); @@ -64,7 +72,7 @@ export class GenerationModel { log('Finding generation by ID: %s for user: %s', id, this.userId); const result = await this.db.query.generations.findFirst({ - where: and(eq(generations.id, id), eq(generations.userId, this.userId)), + where: and(eq(generations.id, id), this.ownership()), with: { asyncTask: true, }, @@ -84,7 +92,7 @@ export class GenerationModel { return await tx .update(generations) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(generations.id, id), eq(generations.userId, this.userId))); + .where(and(eq(generations.id, id), this.ownership())); }; const result = await (trx ? executeUpdate(trx) : this.db.transaction(executeUpdate)); @@ -136,7 +144,7 @@ export class GenerationModel { log('Finding generation by asyncTaskId: %s', asyncTaskId); return this.db.query.generations.findFirst({ - where: eq(generations.asyncTaskId, asyncTaskId), + where: and(eq(generations.asyncTaskId, asyncTaskId), this.ownership()), }); } @@ -146,7 +154,7 @@ export class GenerationModel { const executeDelete = async (tx: Transaction) => { return await tx .delete(generations) - .where(and(eq(generations.id, id), eq(generations.userId, this.userId))) + .where(and(eq(generations.id, id), this.ownership())) .returning(); }; diff --git a/packages/database/src/models/generationBatch.ts b/packages/database/src/models/generationBatch.ts index f8ea26c958..3a99e6f70f 100644 --- a/packages/database/src/models/generationBatch.ts +++ b/packages/database/src/models/generationBatch.ts @@ -17,6 +17,7 @@ import type { } from '../schemas/generation'; import { generationBatches } from '../schemas/generation'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; import { GenerationModel } from './generation'; const log = debug('lobe-image:generation-batch-model'); @@ -24,16 +25,21 @@ const log = debug('lobe-image:generation-batch-model'); export class GenerationBatchModel { private db: LobeChatDatabase; private userId: string; + private workspaceId?: string; private fileService: FileService; private generationModel: GenerationModel; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; this.fileService = new FileService(db, userId); - this.generationModel = new GenerationModel(db, userId); + this.generationModel = new GenerationModel(db, userId, workspaceId); } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, generationBatches); + async create(value: NewGenerationBatch): Promise { log('Creating generation batch: %O', { topicId: value.generationTopicId, @@ -42,7 +48,9 @@ export class GenerationBatchModel { const [result] = await this.db .insert(generationBatches) - .values({ ...value, userId: this.userId }) + .values( + buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, { ...value }), + ) .returning(); log('Generation batch created successfully: %s', result.id); @@ -53,7 +61,7 @@ export class GenerationBatchModel { log('Finding generation batch by ID: %s for user: %s', id, this.userId); const result = await this.db.query.generationBatches.findFirst({ - where: and(eq(generationBatches.id, id), eq(generationBatches.userId, this.userId)), + where: and(eq(generationBatches.id, id), this.ownership()), }); log('Generation batch %s: %s', id, result ? 'found' : 'not found'); @@ -65,10 +73,7 @@ export class GenerationBatchModel { const results = await this.db.query.generationBatches.findMany({ orderBy: (table, { desc }) => [desc(table.createdAt)], - where: and( - eq(generationBatches.generationTopicId, topicId), - eq(generationBatches.userId, this.userId), - ), + where: and(eq(generationBatches.generationTopicId, topicId), this.ownership()), }); log('Found %d generation batches for topic %s', results.length, topicId); @@ -87,10 +92,7 @@ export class GenerationBatchModel { const results = await this.db.query.generationBatches.findMany({ orderBy: (table, { asc }) => [asc(table.createdAt)], - where: and( - eq(generationBatches.generationTopicId, topicId), - eq(generationBatches.userId, this.userId), - ), + where: and(eq(generationBatches.generationTopicId, topicId), this.ownership()), with: { generations: { orderBy: (table, { asc }) => [asc(table.createdAt), asc(table.id)], @@ -184,7 +186,7 @@ export class GenerationBatchModel { // 1. First, get generations with their assets to collect file URLs for cleanup const batchWithGenerations = await this.db.query.generationBatches.findFirst({ - where: and(eq(generationBatches.id, id), eq(generationBatches.userId, this.userId)), + where: and(eq(generationBatches.id, id), this.ownership()), with: { generations: { columns: { @@ -215,7 +217,7 @@ export class GenerationBatchModel { // 3. Delete the batch record (this will cascade delete all associated generations) const [deletedBatch] = await this.db .delete(generationBatches) - .where(and(eq(generationBatches.id, id), eq(generationBatches.userId, this.userId))) + .where(and(eq(generationBatches.id, id), this.ownership())) .returning(); log( diff --git a/packages/database/src/models/generationTopic.ts b/packages/database/src/models/generationTopic.ts index 45cafc5471..428b2e4b94 100644 --- a/packages/database/src/models/generationTopic.ts +++ b/packages/database/src/models/generationTopic.ts @@ -11,20 +11,26 @@ import type { GenerationTopicItem } from '../schemas/generation'; import { generationTopics } from '../schemas/generation'; import type { LobeChatDatabase } from '../type'; import type { GenerationTopicType } from '../types/generation'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class GenerationTopicModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; private fileService: FileService; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; this.fileService = new FileService(db, userId); } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, generationTopics); + queryAll = async (type?: GenerationTopicType) => { - const conditions = [eq(generationTopics.userId, this.userId)]; + const conditions = [this.ownership()]; if (type) { conditions.push(eq(generationTopics.type, type)); } @@ -51,11 +57,15 @@ export class GenerationTopicModel { create = async (title: string, type?: GenerationTopicType) => { const [newGenerationTopic] = await this.db .insert(generationTopics) - .values({ - title, - type: type ?? 'image', - userId: this.userId, - }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + title, + type: type ?? 'image', + }, + ), + ) .returning(); return newGenerationTopic; @@ -68,7 +78,7 @@ export class GenerationTopicModel { const [updatedTopic] = await this.db .update(generationTopics) .set({ ...data, updatedAt: new Date() }) - .where(and(eq(generationTopics.id, id), eq(generationTopics.userId, this.userId))) + .where(and(eq(generationTopics.id, id), this.ownership())) .returning(); return updatedTopic; @@ -90,7 +100,7 @@ export class GenerationTopicModel { ): Promise<{ deletedTopic: GenerationTopicItem; filesToDelete: string[] } | undefined> => { // 1. First, get the topic with all its batches and generations to collect file URLs const topicWithBatches = await this.db.query.generationTopics.findFirst({ - where: and(eq(generationTopics.id, id), eq(generationTopics.userId, this.userId)), + where: and(eq(generationTopics.id, id), this.ownership()), with: { batches: { with: { @@ -134,7 +144,7 @@ export class GenerationTopicModel { // 3. Delete the topic record (this will cascade delete all batches and generations) const [deletedTopic] = await this.db .delete(generationTopics) - .where(and(eq(generationTopics.id, id), eq(generationTopics.userId, this.userId))) + .where(and(eq(generationTopics.id, id), this.ownership())) .returning(); return { diff --git a/packages/database/src/models/knowledgeBase.ts b/packages/database/src/models/knowledgeBase.ts index 1d267983c0..3e774f857b 100644 --- a/packages/database/src/models/knowledgeBase.ts +++ b/packages/database/src/models/knowledgeBase.ts @@ -1,26 +1,37 @@ import type { KnowledgeBaseItem } from '@lobechat/types'; -import { and, count, desc, eq, inArray } from 'drizzle-orm'; +import { and, count, desc, eq, inArray, or, sum } from 'drizzle-orm'; -import type { NewKnowledgeBase } from '../schemas'; -import { documents, knowledgeBaseFiles, knowledgeBases } from '../schemas'; +import type { NewDocument, NewFile, NewKnowledgeBase } from '../schemas'; +import { documents, files, knowledgeBaseFiles, knowledgeBases } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; import { FileModel } from './file'; export class KnowledgeBaseModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, knowledgeBases); + // create create = async (params: Omit) => { const [result] = await this.db .insert(knowledgeBases) - .values({ ...params, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params }, + ), + ) .returning(); return result; @@ -29,7 +40,7 @@ export class KnowledgeBaseModel { addFilesToKnowledgeBase = async (id: string, fileIds: string[]) => { // Verify the target knowledge base belongs to the current user const kb = await this.db.query.knowledgeBases.findFirst({ - where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)), + where: and(eq(knowledgeBases.id, id), this.ownership()), }); if (!kb) return []; @@ -43,7 +54,12 @@ export class KnowledgeBaseModel { const docsWithFiles = await this.db .select({ fileId: documents.fileId }) .from(documents) - .where(and(inArray(documents.id, documentIds), eq(documents.userId, this.userId))); + .where( + and( + inArray(documents.id, documentIds), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), + ), + ); const mirrorFileIds = docsWithFiles .map((doc) => doc.fileId) @@ -54,7 +70,12 @@ export class KnowledgeBaseModel { await this.db .update(documents) .set({ knowledgeBaseId: id }) - .where(and(inArray(documents.id, documentIds), eq(documents.userId, this.userId))); + .where( + and( + inArray(documents.id, documentIds), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), + ), + ); } // Insert using resolved file IDs @@ -65,20 +86,23 @@ export class KnowledgeBaseModel { return this.db .insert(knowledgeBaseFiles) .values( - resolvedFileIds.map((fileId) => ({ fileId, knowledgeBaseId: id, userId: this.userId })), + resolvedFileIds.map((fileId) => ({ + fileId, + knowledgeBaseId: id, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + })), ) .returning(); }; // delete delete = async (id: string) => { - return this.db - .delete(knowledgeBases) - .where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId))); + return this.db.delete(knowledgeBases).where(and(eq(knowledgeBases.id, id), this.ownership())); }; deleteAll = async () => { - return this.db.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId)); + return this.db.delete(knowledgeBases).where(this.ownership()); }; removeFilesFromKnowledgeBase = async (knowledgeBaseId: string, ids: string[]) => { @@ -92,7 +116,12 @@ export class KnowledgeBaseModel { const docsWithFiles = await this.db .select({ fileId: documents.fileId }) .from(documents) - .where(and(inArray(documents.id, documentIds), eq(documents.userId, this.userId))); + .where( + and( + inArray(documents.id, documentIds), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), + ), + ); const mirrorFileIds = docsWithFiles .map((doc) => doc.fileId) @@ -106,7 +135,7 @@ export class KnowledgeBaseModel { .where( and( inArray(documents.id, documentIds), - eq(documents.userId, this.userId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), eq(documents.knowledgeBaseId, knowledgeBaseId), ), ); @@ -121,7 +150,10 @@ export class KnowledgeBaseModel { .delete(knowledgeBaseFiles) .where( and( - eq(knowledgeBaseFiles.userId, this.userId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + knowledgeBaseFiles, + ), eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId), inArray(knowledgeBaseFiles.fileId, resolvedFileIds), ), @@ -142,7 +174,7 @@ export class KnowledgeBaseModel { updatedAt: knowledgeBases.updatedAt, }) .from(knowledgeBases) - .where(eq(knowledgeBases.userId, this.userId)) + .where(this.ownership()) .orderBy(desc(knowledgeBases.updatedAt)); return data as KnowledgeBaseItem[]; @@ -150,16 +182,288 @@ export class KnowledgeBaseModel { findById = async (id: string) => { return this.db.query.knowledgeBases.findFirst({ - where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)), + where: and(eq(knowledgeBases.id, id), this.ownership()), }); }; + countFileUsage = async (id: string): Promise => { + const result = await this.db + .select({ totalSize: sum(files.size) }) + .from(knowledgeBaseFiles) + .innerJoin(files, eq(files.id, knowledgeBaseFiles.fileId)) + .where( + and( + eq(knowledgeBaseFiles.knowledgeBaseId, id), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + knowledgeBaseFiles, + ), + ), + ); + + return parseInt(result[0]?.totalSize ?? '0') || 0; + }; + // update update = async (id: string, value: Partial) => this.db .update(knowledgeBases) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId))); + .where(and(eq(knowledgeBases.id, id), this.ownership())); + + private resolveAvailableName = async ( + db: LobeChatDatabase, + name: string, + targetWorkspaceId: string | null, + targetUserId: string, + excludeId?: string, + ): Promise => { + const existingKnowledgeBases = await db + .select({ id: knowledgeBases.id, name: knowledgeBases.name }) + .from(knowledgeBases) + .where( + buildWorkspaceWhere( + { userId: targetUserId, workspaceId: targetWorkspaceId ?? undefined }, + knowledgeBases, + ), + ); + const existingNames = new Set( + existingKnowledgeBases + .filter((knowledgeBase) => knowledgeBase.id !== excludeId) + .map((knowledgeBase) => knowledgeBase.name), + ); + + if (!existingNames.has(name)) return name; + + let index = 1; + let candidate = `${name} (${index})`; + while (existingNames.has(candidate)) { + index += 1; + candidate = `${name} (${index})`; + } + + return candidate; + }; + + transferTo = async ( + id: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ id: string }> => { + return this.db.transaction(async (trx) => { + const [knowledgeBase] = await trx + .select() + .from(knowledgeBases) + .where(and(eq(knowledgeBases.id, id), this.ownership())) + .limit(1); + if (!knowledgeBase) throw new Error('Knowledge base not found'); + + const fileLinks = await trx + .select({ fileId: knowledgeBaseFiles.fileId }) + .from(knowledgeBaseFiles) + .where(eq(knowledgeBaseFiles.knowledgeBaseId, id)); + const fileIds = fileLinks.map((item) => item.fileId); + const now = new Date(); + const ownershipUpdate = { userId: targetUserId, workspaceId: targetWorkspaceId }; + const targetName = await this.resolveAvailableName( + trx as LobeChatDatabase, + knowledgeBase.name, + targetWorkspaceId, + targetUserId, + id, + ); + + await trx + .update(knowledgeBases) + .set({ ...ownershipUpdate, name: targetName, updatedAt: now }) + .where(eq(knowledgeBases.id, id)); + + await trx + .update(knowledgeBaseFiles) + .set(ownershipUpdate) + .where(eq(knowledgeBaseFiles.knowledgeBaseId, id)); + + if (fileIds.length > 0) { + await trx + .update(files) + .set({ ...ownershipUpdate, updatedAt: now }) + .where(inArray(files.id, fileIds)); + } + + const documentWhere = + fileIds.length > 0 + ? or(eq(documents.knowledgeBaseId, id), inArray(documents.fileId, fileIds)) + : eq(documents.knowledgeBaseId, id); + + await trx + .update(documents) + .set({ ...ownershipUpdate, updatedAt: now }) + .where(documentWhere); + + return { id }; + }); + }; + + copyToWorkspace = async ( + id: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ id: string }> => { + return this.db.transaction(async (trx) => { + const [knowledgeBase] = await trx + .select() + .from(knowledgeBases) + .where(and(eq(knowledgeBases.id, id), this.ownership())) + .limit(1); + if (!knowledgeBase) throw new Error('Knowledge base not found'); + const targetName = await this.resolveAvailableName( + trx as LobeChatDatabase, + knowledgeBase.name, + targetWorkspaceId, + targetUserId, + ); + + const [copiedKnowledgeBase] = await trx + .insert(knowledgeBases) + .values({ + avatar: knowledgeBase.avatar, + description: knowledgeBase.description, + isPublic: knowledgeBase.isPublic, + name: targetName, + settings: knowledgeBase.settings, + type: knowledgeBase.type, + userId: targetUserId, + workspaceId: targetWorkspaceId, + } as NewKnowledgeBase) + .returning(); + + const fileLinks = await trx + .select({ fileId: knowledgeBaseFiles.fileId }) + .from(knowledgeBaseFiles) + .where(eq(knowledgeBaseFiles.knowledgeBaseId, id)); + const fileIds = fileLinks.map((item) => item.fileId); + + const documentWhere = + fileIds.length > 0 + ? or(eq(documents.knowledgeBaseId, id), inArray(documents.fileId, fileIds)) + : eq(documents.knowledgeBaseId, id); + const sourceDocuments = await trx.select().from(documents).where(documentWhere); + const sourceDocumentIds = new Set(sourceDocuments.map((item) => item.id)); + const documentIdMap = new Map(); + let pendingDocuments = [...sourceDocuments]; + + while (pendingDocuments.length > 0) { + const readyDocuments = pendingDocuments.filter( + (document) => + !document.parentId || + !sourceDocumentIds.has(document.parentId) || + documentIdMap.has(document.parentId), + ); + const documentsToCopy = readyDocuments.length > 0 ? readyDocuments : pendingDocuments; + + for (const document of documentsToCopy) { + const metadata = + document.metadata && typeof document.metadata === 'object' + ? { ...document.metadata, duplicatedFrom: document.id } + : { duplicatedFrom: document.id }; + const [copiedDocument] = await trx + .insert(documents) + .values({ + clientId: null, + content: document.content, + description: document.description, + editorData: document.editorData, + fileId: null, + fileType: document.fileType, + filename: document.filename, + knowledgeBaseId: + document.knowledgeBaseId === id ? copiedKnowledgeBase.id : document.knowledgeBaseId, + metadata, + pages: document.pages, + parentId: document.parentId ? (documentIdMap.get(document.parentId) ?? null) : null, + source: document.source, + sourceType: document.sourceType, + title: document.title, + totalCharCount: document.totalCharCount, + totalLineCount: document.totalLineCount, + userId: targetUserId, + workspaceId: targetWorkspaceId, + } as NewDocument) + .returning({ id: documents.id }); + + documentIdMap.set(document.id, copiedDocument.id); + } + + const copiedIds = new Set(documentsToCopy.map((document) => document.id)); + pendingDocuments = pendingDocuments.filter((document) => !copiedIds.has(document.id)); + } + + const fileIdMap = new Map(); + if (fileIds.length > 0) { + const sourceFiles = await trx.select().from(files).where(inArray(files.id, fileIds)); + + for (const file of sourceFiles) { + const metadata = + file.metadata && typeof file.metadata === 'object' + ? { ...file.metadata, duplicatedFrom: file.id } + : { duplicatedFrom: file.id }; + const [copiedFile] = await trx + .insert(files) + .values({ + chunkTaskId: null, + clientId: null, + embeddingTaskId: null, + fileHash: file.fileHash, + fileType: file.fileType, + metadata, + name: file.name, + parentId: file.parentId ? (documentIdMap.get(file.parentId) ?? null) : null, + size: file.size, + source: file.source, + url: file.url, + userId: targetUserId, + workspaceId: targetWorkspaceId, + } as NewFile) + .returning({ id: files.id }); + + fileIdMap.set(file.id, copiedFile.id); + } + + const copiedLinks = fileLinks.flatMap((link) => { + const fileId = fileIdMap.get(link.fileId); + if (!fileId) return []; + + return [ + { + fileId, + knowledgeBaseId: copiedKnowledgeBase.id, + userId: targetUserId, + workspaceId: targetWorkspaceId, + }, + ]; + }); + + if (copiedLinks.length > 0) { + await trx.insert(knowledgeBaseFiles).values(copiedLinks); + } + } + + for (const document of sourceDocuments) { + if (!document.fileId) continue; + + const copiedDocumentId = documentIdMap.get(document.id); + const copiedFileId = fileIdMap.get(document.fileId); + if (!copiedDocumentId || !copiedFileId) continue; + + await trx + .update(documents) + .set({ fileId: copiedFileId }) + .where(eq(documents.id, copiedDocumentId)); + } + + return { id: copiedKnowledgeBase.id }; + }); + }; findExclusiveFileIds = async (knowledgeBaseId: string): Promise => { const kbFiles = await this.db @@ -168,7 +472,10 @@ export class KnowledgeBaseModel { .where( and( eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId), - eq(knowledgeBaseFiles.userId, this.userId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + knowledgeBaseFiles, + ), ), ); const fileIds = kbFiles.map((f) => f.fileId); @@ -183,7 +490,10 @@ export class KnowledgeBaseModel { .where( and( inArray(knowledgeBaseFiles.fileId, fileIds), - eq(knowledgeBaseFiles.userId, this.userId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + knowledgeBaseFiles, + ), ), ) .groupBy(knowledgeBaseFiles.fileId); @@ -196,14 +506,12 @@ export class KnowledgeBaseModel { let deletedFiles: Array<{ id: string; url: string | null }> = []; if (exclusiveFileIds.length > 0) { - const fileModel = new FileModel(this.db, this.userId); + const fileModel = new FileModel(this.db, this.userId, this.workspaceId); const result = await fileModel.deleteMany(exclusiveFileIds, removeGlobalFile); deletedFiles = (result || []).map((f) => ({ id: f.id, url: f.url })); } - await this.db - .delete(knowledgeBases) - .where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId))); + await this.db.delete(knowledgeBases).where(and(eq(knowledgeBases.id, id), this.ownership())); return { deletedFiles }; }; @@ -212,18 +520,23 @@ export class KnowledgeBaseModel { const allKbFileIds = await this.db .select({ fileId: knowledgeBaseFiles.fileId }) .from(knowledgeBaseFiles) - .where(eq(knowledgeBaseFiles.userId, this.userId)); + .where( + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + knowledgeBaseFiles, + ), + ); const fileIds = [...new Set(allKbFileIds.map((f) => f.fileId))]; let deletedFiles: Array<{ id: string; url: string | null }> = []; if (fileIds.length > 0) { - const fileModel = new FileModel(this.db, this.userId); + const fileModel = new FileModel(this.db, this.userId, this.workspaceId); const result = await fileModel.deleteMany(fileIds, removeGlobalFile); deletedFiles = (result || []).map((f) => ({ id: f.id, url: f.url })); } - await this.db.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId)); + await this.db.delete(knowledgeBases).where(this.ownership()); return { deletedFiles }; }; diff --git a/packages/database/src/models/llmGenerationTracing.ts b/packages/database/src/models/llmGenerationTracing.ts index b22df378e9..04ffbed38e 100644 --- a/packages/database/src/models/llmGenerationTracing.ts +++ b/packages/database/src/models/llmGenerationTracing.ts @@ -7,6 +7,7 @@ import type { } from '../schemas/llmGenerationTracing'; import { llmGenerationTracing } from '../schemas/llmGenerationTracing'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; export interface RecordLlmGenerationParams { agentId?: string | null; @@ -52,10 +53,19 @@ export interface UpdateLlmGenerationFeedbackParams { export class LlmGenerationTracingModel { private readonly db: LobeChatDatabase; private readonly userId: string; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; + } + + private ownership() { + return buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + llmGenerationTracing, + ); } async record(params: RecordLlmGenerationParams): Promise<{ id: string }> { @@ -86,6 +96,7 @@ export class LlmGenerationTracingModel { trigger: params.trigger ?? null, userId: this.userId, validationFailed: params.validationFailed ?? false, + workspaceId: this.workspaceId ?? null, }; const [row] = await this.db @@ -116,7 +127,7 @@ export class LlmGenerationTracingModel { feedbackSource: params.source, feedbackUpdatedAt: new Date(), }) - .where(and(eq(llmGenerationTracing.id, id), eq(llmGenerationTracing.userId, this.userId))) + .where(and(eq(llmGenerationTracing.id, id), this.ownership())) .returning({ id: llmGenerationTracing.id }); return { updated: rows.length > 0 }; } @@ -125,7 +136,7 @@ export class LlmGenerationTracingModel { const [row] = await this.db .select() .from(llmGenerationTracing) - .where(and(eq(llmGenerationTracing.id, id), eq(llmGenerationTracing.userId, this.userId))) + .where(and(eq(llmGenerationTracing.id, id), this.ownership())) .limit(1); return row ?? null; } @@ -134,7 +145,7 @@ export class LlmGenerationTracingModel { return this.db .select() .from(llmGenerationTracing) - .where(eq(llmGenerationTracing.userId, this.userId)) + .where(this.ownership()) .orderBy(desc(llmGenerationTracing.createdAt)) .limit(limit); } diff --git a/packages/database/src/models/message.ts b/packages/database/src/models/message.ts index d3430af22f..2b5d48548a 100644 --- a/packages/database/src/models/message.ts +++ b/packages/database/src/models/message.ts @@ -74,6 +74,7 @@ import type { LobeChatDatabase, Transaction } from '../type'; import { sanitizeBm25Query } from '../utils/bm25'; import { genEndDateWhere, genRangeWhere, genStartDateWhere, genWhere } from '../utils/genWhere'; import { idGenerator } from '../utils/idGenerator'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; import { recomputeTopicUsage } from './topicUsage'; /** @@ -195,12 +196,29 @@ interface SplitCreateMessageParams { export class MessageModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages); + + private pluginsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messagePlugins); + + private translatesOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messageTranslates); + + private ttsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messageTTS); + + private agentsToSessionsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentsToSessions); + /** * Touch topics' updatedAt timestamp within a transaction */ @@ -209,7 +227,12 @@ export class MessageModel { await trx .update(topics) .set({ updatedAt: new Date() }) - .where(and(inArray(topics.id, topicIds), eq(topics.userId, this.userId))); + .where( + and( + inArray(topics.id, topicIds), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics), + ), + ); } // **************** Query *************** // @@ -414,7 +437,7 @@ export class MessageModel { .from(messages) .where( and( - eq(messages.userId, this.userId), + this.ownership(), // Filter out messages that belong to MessageGroups isNull(messages.messageGroupId), where, @@ -785,7 +808,10 @@ export class MessageModel { }) .from(threads) .where( - and(eq(threads.userId, this.userId), inArray(threads.sourceMessageId, taskMessageIds)), + and( + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, threads), + inArray(threads.sourceMessageId, taskMessageIds), + ), ), { taskMessageCount: taskMessageIds.length }, ); @@ -889,7 +915,7 @@ export class MessageModel { ttsVoice: messageTTS.voice, }) .from(messages) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))) + .where(and(this.ownership(), inArray(messages.id, messageIds))) .leftJoin(messagePlugins, eq(messagePlugins.id, messages.id)) .leftJoin(messageTranslates, eq(messageTranslates.id, messages.id)) .leftJoin(messageTTS, eq(messageTTS.id, messages.id)) @@ -957,7 +983,10 @@ export class MessageModel { .from(threads) .where( and( - eq(threads.userId, this.userId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + threads, + ), inArray(threads.sourceMessageId, taskMessageIds), ), ) @@ -1104,7 +1133,7 @@ export class MessageModel { ): Promise => { // 1. Query MessageGroups for this topic, optionally filtered by time range const whereConditions = [ - eq(messageGroups.userId, this.userId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messageGroups), eq(messageGroups.topicId, topicId), ]; @@ -1145,7 +1174,7 @@ export class MessageModel { messageGroupId: messages.messageGroupId, }) .from(messages) - .where(and(eq(messages.userId, this.userId), inArray(messages.messageGroupId, groupIds))) + .where(and(this.ownership(), inArray(messages.messageGroupId, groupIds))) .orderBy(asc(messages.createdAt)), { groupCount: groupIds.length }, ); @@ -1247,7 +1276,10 @@ export class MessageModel { private buildThreadQueryCondition = async (threadId: string): Promise => { // Fetch the thread info to get sourceMessageId and type const thread = await this.db.query.threads.findFirst({ - where: and(eq(threads.id, threadId), eq(threads.userId, this.userId)), + where: and( + eq(threads.id, threadId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, threads), + ), }); if (!thread?.sourceMessageId || !thread?.topicId) { @@ -1280,7 +1312,7 @@ export class MessageModel { const agentSession = await this.db .select({ sessionId: agentsToSessions.sessionId }) .from(agentsToSessions) - .where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId))) + .where(and(eq(agentsToSessions.agentId, agentId), this.agentsToSessionsOwnership())) .limit(1); const associatedSessionId = agentSession[0]?.sessionId; @@ -1293,7 +1325,7 @@ export class MessageModel { findById = async (id: string) => { return this.db.query.messages.findFirst({ - where: and(eq(messages.id, id), eq(messages.userId, this.userId)), + where: and(eq(messages.id, id), this.ownership()), }); }; @@ -1340,7 +1372,7 @@ export class MessageModel { // For Standalone type, only return the source message if (threadType === ThreadType.Standalone) { const sourceMessage = await this.db.query.messages.findFirst({ - where: and(eq(messages.id, sourceMessageId), eq(messages.userId, this.userId)), + where: and(eq(messages.id, sourceMessageId), this.ownership()), }); return sourceMessage ? [sourceMessage as DBMessageItem] : []; @@ -1348,7 +1380,7 @@ export class MessageModel { // For Continuation type, get the source message first to know its createdAt const sourceMessage = await this.db.query.messages.findFirst({ - where: and(eq(messages.id, sourceMessageId), eq(messages.userId, this.userId)), + where: and(eq(messages.id, sourceMessageId), this.ownership()), }); if (!sourceMessage) return []; @@ -1361,7 +1393,7 @@ export class MessageModel { .from(messages) .where( and( - eq(messages.userId, this.userId), + this.ownership(), eq(messages.topicId, topicId), isNull(messages.threadId), // Only main conversation messages (not in any thread) or( @@ -1400,7 +1432,7 @@ export class MessageModel { const result = await this.db .select() .from(messages) - .where(eq(messages.userId, this.userId)) + .where(and(this.ownership())) .orderBy(desc(messages.createdAt)) .limit(pageSize) .offset(offset); @@ -1411,7 +1443,7 @@ export class MessageModel { queryBySessionId = async (sessionId?: string | null) => { const result = await this.db.query.messages.findMany({ orderBy: [asc(messages.createdAt)], - where: and(eq(messages.userId, this.userId), this.matchSession(sessionId)), + where: and(this.ownership(), this.matchSession(sessionId)), }); return result as DBMessageItem[]; @@ -1424,7 +1456,7 @@ export class MessageModel { const result = await this.db .select() .from(messages) - .where(and(eq(messages.userId, this.userId), sql`${messages.content} @@@ ${bm25Query}`)) + .where(and(this.ownership(), sql`${messages.content} @@@ ${bm25Query}`)) .orderBy(desc(messages.createdAt)); return result as DBMessageItem[]; @@ -1442,7 +1474,7 @@ export class MessageModel { .from(messages) .where( genWhere([ - eq(messages.userId, this.userId), + this.ownership(), params?.range ? genRangeWhere(params.range, messages.createdAt, (date) => date.toDate()) : undefined, @@ -1462,7 +1494,7 @@ export class MessageModel { const rows = await this.db .select({ id: messages.id }) .from(messages) - .where(and(eq(messages.userId, this.userId), eq(messages.topicId, topicId))) + .where(and(eq(messages.topicId, topicId), this.ownership())) .limit(1); return rows.length > 0; @@ -1472,13 +1504,7 @@ export class MessageModel { const rows = (await this.db .select() .from(messages) - .where( - and( - eq(messages.userId, this.userId), - eq(messages.topicId, topicId), - eq(messages.role, 'assistant'), - ), - ) + .where(and(eq(messages.topicId, topicId), eq(messages.role, 'assistant'), this.ownership())) .orderBy(asc(messages.createdAt)) .limit(1)) as DBMessageItem[]; @@ -1497,7 +1523,7 @@ export class MessageModel { .from(messages) .where( genWhere([ - eq(messages.userId, this.userId), + this.ownership(), params?.range ? genRangeWhere(params.range, messages.createdAt, (date) => date.toDate()) : undefined, @@ -1520,7 +1546,7 @@ export class MessageModel { id: messages.model, }) .from(messages) - .where(and(eq(messages.userId, this.userId), isNotNull(messages.model))) + .where(and(this.ownership(), isNotNull(messages.model))) .having(({ count }) => gt(count, 0)) .groupBy(messages.model) .orderBy(desc(sql`count`), asc(messages.model)) @@ -1539,7 +1565,7 @@ export class MessageModel { .from(messages) .where( genWhere([ - eq(messages.userId, this.userId), + this.ownership(), genRangeWhere( [startDate.format('YYYY-MM-DD'), endDate.add(1, 'day').format('YYYY-MM-DD')], messages.createdAt, @@ -1606,7 +1632,7 @@ export class MessageModel { .from(messages) .where( genWhere([ - eq(messages.userId, this.userId), + this.ownership(), eq(messages.role, 'assistant'), genRangeWhere( [startDate.format('YYYY-MM-DD'), endDate.add(1, 'day').format('YYYY-MM-DD')], @@ -1658,7 +1684,7 @@ export class MessageModel { const result = await this.db .select({ id: messages.id }) .from(messages) - .where(eq(messages.userId, this.userId)) + .where(and(this.ownership())) .limit(n + 1); return result.length > n; @@ -1671,7 +1697,7 @@ export class MessageModel { const result = await this.db .select({ id: messages.id }) .from(messages) - .where(eq(messages.userId, this.userId)) + .where(and(this.ownership())) .limit(n); return result.length; @@ -1716,23 +1742,25 @@ export class MessageModel { // Ensure group message does not populate sessionId const normalizedMessage = message.groupId ? { ...message, sessionId: null } : message; - return { - ...normalizedMessage, - // Sanitize content to strip null bytes that PostgreSQL rejects - content: sanitizeNullBytes(normalizedMessage.content), - // TODO: remove this when the client is updated - createdAt: createdAt ? new Date(createdAt) : undefined, - id, - model: fromModel, - provider: fromProvider, - updatedAt: updatedAt ? new Date(updatedAt) : undefined, - // Promote token usage into the dedicated `usage` column, preferring a - // top-level `usage` over the legacy `metadata.usage`. - usage: - normalizedMessage.usage ?? - (normalizedMessage.metadata as { usage?: ModelUsage } | undefined)?.usage, - userId: this.userId, - }; + return buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...normalizedMessage, + // Sanitize content to strip null bytes that PostgreSQL rejects + content: sanitizeNullBytes(normalizedMessage.content), + // TODO: remove this when the client is updated + createdAt: createdAt ? new Date(createdAt) : undefined, + id, + model: fromModel, + provider: fromProvider, + updatedAt: updatedAt ? new Date(updatedAt) : undefined, + // Promote token usage into the dedicated `usage` column, preferring a + // top-level `usage` over the legacy `metadata.usage`. + usage: + normalizedMessage.usage ?? + (normalizedMessage.metadata as { usage?: ModelUsage } | undefined)?.usage, + }, + ); }; private insertMessageRelationsInTransaction = async ( @@ -1763,6 +1791,7 @@ export class MessageModel { toolCallId: message.tool_call_id, type: plugin?.type, userId: this.userId, + workspaceId: this.workspaceId ?? null, }), ); } @@ -1772,9 +1801,14 @@ export class MessageModel { timing, `${timingPrefix}.files.insert`, () => - trx - .insert(messagesFiles) - .values(files.map((file) => ({ fileId: file, messageId: id, userId: this.userId }))), + trx.insert(messagesFiles).values( + files.map((file) => ({ + fileId: file, + messageId: id, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + })), + ), { fileCount: files.length }, ); } @@ -1791,6 +1825,7 @@ export class MessageModel { queryId: ragQueryId, similarity: chunk.similarity?.toString(), userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ), { chunkCount: fileChunks.length }, @@ -1956,10 +1991,13 @@ export class MessageModel { }; batchCreate = async (newMessages: DBMessageItem[]) => { - const messagesToInsert = newMessages.map((m) => { - // TODO: need a better way to handle this - return { ...m, role: m.role as any, userId: this.userId }; - }); + const messagesToInsert = newMessages.map((m) => + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + // TODO: need a better way to handle this + { ...m, role: m.role as any }, + ), + ); const topicIds = [...new Set(newMessages.map((m) => m.topicId).filter(Boolean))] as string[]; @@ -1975,7 +2013,7 @@ export class MessageModel { createMessageQuery = async (params: NewMessageQueryParams) => { const result = await this.db .insert(messageQueries) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return result[0]; @@ -2017,6 +2055,7 @@ export class MessageModel { fileId: file.id, messageId: id, userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ), { imageCount: imageList.length }, @@ -2034,7 +2073,7 @@ export class MessageModel { trx .select({ metadata: messages.metadata }) .from(messages) - .where(and(eq(messages.id, id), eq(messages.userId, this.userId))), + .where(and(eq(messages.id, id), this.ownership())), ); mergedMetadata = merge(existingMessage?.metadata || {}, metadataPatch); } @@ -2050,7 +2089,7 @@ export class MessageModel { ...(mergedMetadata && { metadata: mergedMetadata }), ...(usageToWrite && { usage: usageToWrite }), }) - .where(and(eq(messages.id, id), eq(messages.userId, this.userId))) + .where(and(eq(messages.id, id), this.ownership())) .returning({ topicId: messages.topicId }), { hasMetadata: !!metadataPatch, valueKeys: Object.keys(message) }, ); @@ -2072,7 +2111,7 @@ export class MessageModel { await runTimedStage( timing, 'db.message.update.topic.recomputeUsage', - () => recomputeTopicUsage(trx, this.userId, updated.topicId!), + () => recomputeTopicUsage(trx, this.userId, updated.topicId!, this.workspaceId), { topicCount: 1 }, ); } @@ -2094,7 +2133,7 @@ export class MessageModel { updateMetadata = async (id: string, metadata: Record) => { const item = await this.db.query.messages.findFirst({ - where: and(eq(messages.id, id), eq(messages.userId, this.userId)), + where: and(eq(messages.id, id), this.ownership()), }); if (!item) return; @@ -2107,28 +2146,31 @@ export class MessageModel { return this.db .update(messages) .set({ metadata: mergedMetadata, ...(usageToWrite && { usage: usageToWrite }) }) - .where(and(eq(messages.userId, this.userId), eq(messages.id, id))); + .where(and(eq(messages.id, id), this.ownership())); }; updatePluginState = async (id: string, state: Record): Promise => { const item = await this.db.query.messagePlugins.findFirst({ - where: eq(messagePlugins.id, id), + where: and(eq(messagePlugins.id, id), this.pluginsOwnership()), }); if (!item) throw new Error('Plugin not found'); await this.db .update(messagePlugins) .set({ state: merge(item.state || {}, state) }) - .where(eq(messagePlugins.id, id)); + .where(and(eq(messagePlugins.id, id), this.pluginsOwnership())); }; updateMessagePlugin = async (id: string, value: Partial) => { const item = await this.db.query.messagePlugins.findFirst({ - where: eq(messagePlugins.id, id), + where: and(eq(messagePlugins.id, id), this.pluginsOwnership()), }); if (!item) throw new Error('Plugin not found'); - return this.db.update(messagePlugins).set(value).where(eq(messagePlugins.id, id)); + return this.db + .update(messagePlugins) + .set(value) + .where(and(eq(messagePlugins.id, id), this.pluginsOwnership())); }; /** @@ -2144,7 +2186,7 @@ export class MessageModel { */ findMessagePlugin = async (messageId: string): Promise => { const row = await this.db.query.messagePlugins.findFirst({ - where: eq(messagePlugins.id, messageId), + where: and(eq(messagePlugins.id, messageId), this.pluginsOwnership()), }); if (!row) return undefined; return { @@ -2185,7 +2227,7 @@ export class MessageModel { }) .from(messagePlugins) .innerJoin(messages, eq(messagePlugins.id, messages.id)) - .where(and(eq(messages.topicId, topicId), eq(messages.userId, this.userId))) + .where(and(eq(messages.topicId, topicId), this.ownership(), this.pluginsOwnership())) .orderBy(asc(messages.createdAt), asc(messages.id)); return rows.map((row) => ({ @@ -2231,7 +2273,7 @@ export class MessageModel { if (metadata !== undefined) { // Need to merge with existing metadata const existingMessage = await trx.query.messages.findFirst({ - where: and(eq(messages.id, id), eq(messages.userId, this.userId)), + where: and(eq(messages.id, id), this.ownership()), }); messageUpdateData.metadata = merge(existingMessage?.metadata || {}, metadata); } @@ -2240,14 +2282,14 @@ export class MessageModel { await trx .update(messages) .set(messageUpdateData) - .where(and(eq(messages.id, id), eq(messages.userId, this.userId))); + .where(and(eq(messages.id, id), this.ownership())); } } // Update messagePlugins table (pluginState, pluginError) if (pluginState !== undefined || pluginError !== undefined) { const pluginItem = await trx.query.messagePlugins.findFirst({ - where: eq(messagePlugins.id, id), + where: and(eq(messagePlugins.id, id), this.pluginsOwnership()), }); if (pluginItem) { @@ -2265,7 +2307,7 @@ export class MessageModel { await trx .update(messagePlugins) .set(pluginUpdateData) - .where(eq(messagePlugins.id, id)); + .where(and(eq(messagePlugins.id, id), this.pluginsOwnership())); } } } @@ -2300,7 +2342,7 @@ export class MessageModel { }) .from(messagePlugins) .innerJoin(messages, eq(messages.id, messagePlugins.id)) - .where(and(eq(messagePlugins.toolCallId, toolCallId), eq(messages.userId, this.userId))) + .where(and(eq(messagePlugins.toolCallId, toolCallId), this.ownership())) .limit(1); if (!toolResult?.parentId) { @@ -2311,7 +2353,7 @@ export class MessageModel { const [parentMessage] = await trx .select({ id: messages.id, tools: messages.tools }) .from(messages) - .where(eq(messages.id, toolResult.parentId)) + .where(and(eq(messages.id, toolResult.parentId), this.ownership())) .limit(1); if (!parentMessage?.tools) { @@ -2334,12 +2376,12 @@ export class MessageModel { trx .update(messagePlugins) .set({ arguments: args }) - .where(eq(messagePlugins.id, toolResult.toolPluginId)), + .where(and(eq(messagePlugins.id, toolResult.toolPluginId), this.pluginsOwnership())), // Update parent assistant message's tools trx .update(messages) .set({ tools: updatedTools }) - .where(eq(messages.id, parentMessage.id)), + .where(and(eq(messages.id, parentMessage.id), this.ownership())), ]); }); @@ -2352,21 +2394,29 @@ export class MessageModel { updateTranslate = async (id: string, translate: Partial) => { const result = await this.db.query.messageTranslates.findFirst({ - where: and(eq(messageTranslates.id, id)), + where: and(eq(messageTranslates.id, id), this.translatesOwnership()), }); // If the message does not exist in the translate table, insert it if (!result) { - return this.db.insert(messageTranslates).values({ ...translate, id, userId: this.userId }); + return this.db.insert(messageTranslates).values({ + ...translate, + id, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + }); } // or just update the existing one - return this.db.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id)); + return this.db + .update(messageTranslates) + .set(translate) + .where(and(eq(messageTranslates.id, id), this.translatesOwnership())); }; updateTTS = async (id: string, tts: Partial) => { const result = await this.db.query.messageTTS.findFirst({ - where: and(eq(messageTTS.id, id)), + where: and(eq(messageTTS.id, id), this.ttsOwnership()), }); // If the message does not exist in the translate table, insert it @@ -2377,6 +2427,7 @@ export class MessageModel { id, userId: this.userId, voice: tts.voice, + workspaceId: this.workspaceId ?? null, }); } @@ -2384,7 +2435,7 @@ export class MessageModel { return this.db .update(messageTTS) .set({ contentMd5: tts.contentMd5, fileId: tts.file, voice: tts.voice }) - .where(eq(messageTTS.id, id)); + .where(and(eq(messageTTS.id, id), this.ttsOwnership())); }; async updateMessageRAG(id: string, { ragQueryId, fileChunks }: UpdateMessageRAGParams) { @@ -2395,6 +2446,7 @@ export class MessageModel { queryId: ragQueryId, similarity: chunk.similarity?.toString(), userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ); } @@ -2407,7 +2459,7 @@ export class MessageModel { const message = await tx .select() .from(messages) - .where(and(eq(messages.id, id), eq(messages.userId, this.userId))) + .where(and(eq(messages.id, id), this.ownership())) .limit(1); // If the message to be deleted is not found, return directly @@ -2418,7 +2470,7 @@ export class MessageModel { await tx .update(messages) .set({ parentId: message[0].parentId }) - .where(and(eq(messages.parentId, id), eq(messages.userId, this.userId))); + .where(and(eq(messages.parentId, id), this.ownership())); // 3. Check if the message contains tools const toolCallIds = (message[0].tools as ChatToolPayload[]) @@ -2443,12 +2495,12 @@ export class MessageModel { // 6. Delete all related messages await tx .delete(messages) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIdsToDelete))); + .where(and(this.ownership(), inArray(messages.id, messageIdsToDelete))); // 7. Keep the topic's usage rollup in sync (pure derived — a removed // assistant message must drop out of the topic totals). if (message[0].topicId) { - await recomputeTopicUsage(tx, this.userId, message[0].topicId); + await recomputeTopicUsage(tx, this.userId, message[0].topicId, this.workspaceId); } }); }; @@ -2461,7 +2513,7 @@ export class MessageModel { const toDelete = await tx .select({ id: messages.id, parentId: messages.parentId, topicId: messages.topicId }) .from(messages) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, ids))); + .where(and(this.ownership(), inArray(messages.id, ids))); if (toDelete.length === 0) return; @@ -2506,30 +2558,27 @@ export class MessageModel { .select({ id: messages.id, parentId: messages.parentId }) .from(messages) .where( - and( - eq(messages.userId, this.userId), - inArray(messages.parentId, ids), - not(inArray(messages.id, ids)), - ), + and(this.ownership(), inArray(messages.parentId, ids), not(inArray(messages.id, ids))), ); // 5. Update each child's parentId to the final ancestor for (const child of children) { const newParentId = finalAncestorMap.get(child.parentId!) ?? null; - await tx.update(messages).set({ parentId: newParentId }).where(eq(messages.id, child.id)); + await tx + .update(messages) + .set({ parentId: newParentId }) + .where(and(eq(messages.id, child.id), this.ownership())); } // 6. Delete the messages - await tx - .delete(messages) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, ids))); + await tx.delete(messages).where(and(this.ownership(), inArray(messages.id, ids))); // 7. Recompute the usage rollup for every affected topic (pure derived). const affectedTopicIds = [ ...new Set(toDelete.map((m) => m.topicId).filter(Boolean) as string[]), ]; for (const topicId of affectedTopicIds) { - await recomputeTopicUsage(tx, this.userId, topicId); + await recomputeTopicUsage(tx, this.userId, topicId, this.workspaceId); } }); }; @@ -2547,6 +2596,7 @@ export class MessageModel { fileId, messageId, userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ); return { success: true }; @@ -2559,17 +2609,23 @@ export class MessageModel { deleteMessageTranslate = async (id: string) => this.db .delete(messageTranslates) - .where(and(eq(messageTranslates.id, id), eq(messageTranslates.userId, this.userId))); + .where(and(eq(messageTranslates.id, id), this.translatesOwnership())); deleteMessageTTS = async (id: string) => - this.db - .delete(messageTTS) - .where(and(eq(messageTTS.id, id), eq(messageTTS.userId, this.userId))); + this.db.delete(messageTTS).where(and(eq(messageTTS.id, id), this.ttsOwnership())); deleteMessageQuery = async (id: string) => this.db .delete(messageQueries) - .where(and(eq(messageQueries.id, id), eq(messageQueries.userId, this.userId))); + .where( + and( + eq(messageQueries.id, id), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + messageQueries, + ), + ), + ); deleteMessagesBySession = async ( sessionId?: string | null, @@ -2580,7 +2636,7 @@ export class MessageModel { .delete(messages) .where( and( - eq(messages.userId, this.userId), + this.ownership(), this.matchSession(sessionId), this.matchTopic(topicId), this.matchGroup(groupId), @@ -2588,7 +2644,7 @@ export class MessageModel { ); deleteAllMessages = async () => { - return this.db.delete(messages).where(eq(messages.userId, this.userId)); + return this.db.delete(messages).where(and(this.ownership())); }; /** @@ -2602,7 +2658,7 @@ export class MessageModel { const agentSession = await this.db .select({ sessionId: agentsToSessions.sessionId }) .from(agentsToSessions) - .where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId))) + .where(and(eq(agentsToSessions.agentId, agentId), this.agentsToSessionsOwnership())) .limit(1); const associatedSessionId = agentSession[0]?.sessionId; @@ -2612,7 +2668,7 @@ export class MessageModel { ? or(eq(messages.agentId, agentId), eq(messages.sessionId, associatedSessionId)) : eq(messages.agentId, agentId); - return this.db.delete(messages).where(and(eq(messages.userId, this.userId), agentCondition)); + return this.db.delete(messages).where(and(this.ownership(), agentCondition)); }; // **************** Helper *************** // diff --git a/packages/database/src/models/messengerAccountLink.ts b/packages/database/src/models/messengerAccountLink.ts index ee9d196db0..f9c0809684 100644 --- a/packages/database/src/models/messengerAccountLink.ts +++ b/packages/database/src/models/messengerAccountLink.ts @@ -50,6 +50,13 @@ export class MessengerAccountLinkModel { this.db = db; } + // A given IM identity maps to exactly one link per `(userId, platform, + // tenantId)` — the unique index already enforces this — so ownership is + // purely by `userId`. `workspaceId` on the row is the *active scope* (derived + // from the active agent), NOT part of the link's identity, so it must not + // scope lookups; otherwise switching scope would orphan the existing link. + private ownership = (): SQL => eq(messengerAccountLinks.userId, this.userId); + // --------------- User-scoped CRUD --------------- /** @@ -87,6 +94,7 @@ export class MessengerAccountLinkModel { tenantId, updatedAt: now, userId: this.userId, + workspaceId: params.workspaceId ?? null, }) .onConflictDoNothing({ target: [ @@ -131,6 +139,7 @@ export class MessengerAccountLinkModel { activeAgentId: params.activeAgentId ?? byIdentity.activeAgentId, platformUsername: params.platformUsername ?? null, updatedAt: now, + workspaceId: params.workspaceId ?? null, }) .where(eq(messengerAccountLinks.id, byIdentity.id)) .returning(); @@ -149,6 +158,7 @@ export class MessengerAccountLinkModel { activeAgentId: params.activeAgentId ?? existingForUser.activeAgentId, platformUsername: params.platformUsername ?? null, updatedAt: now, + workspaceId: params.workspaceId ?? null, }) .where(eq(messengerAccountLinks.id, existingForUser.id)) .returning(); @@ -161,14 +171,11 @@ export class MessengerAccountLinkModel { delete = async (id: string) => { return this.db .delete(messengerAccountLinks) - .where(and(eq(messengerAccountLinks.id, id), eq(messengerAccountLinks.userId, this.userId))); + .where(and(eq(messengerAccountLinks.id, id), this.ownership())); }; deleteByPlatform = async (platform: string, tenantId?: string) => { - const conditions: SQL[] = [ - eq(messengerAccountLinks.userId, this.userId), - eq(messengerAccountLinks.platform, platform), - ]; + const conditions: SQL[] = [this.ownership(), eq(messengerAccountLinks.platform, platform)]; if (tenantId !== undefined) { conditions.push(eq(messengerAccountLinks.tenantId, tenantId)); } @@ -176,10 +183,7 @@ export class MessengerAccountLinkModel { }; list = async (): Promise => { - return this.db - .select() - .from(messengerAccountLinks) - .where(eq(messengerAccountLinks.userId, this.userId)); + return this.db.select().from(messengerAccountLinks).where(this.ownership()); }; /** @@ -192,10 +196,7 @@ export class MessengerAccountLinkModel { platform: string, tenantId?: string, ): Promise => { - const conditions: SQL[] = [ - eq(messengerAccountLinks.userId, this.userId), - eq(messengerAccountLinks.platform, platform), - ]; + const conditions: SQL[] = [this.ownership(), eq(messengerAccountLinks.platform, platform)]; if (tenantId !== undefined) { conditions.push(eq(messengerAccountLinks.tenantId, tenantId)); } @@ -208,23 +209,25 @@ export class MessengerAccountLinkModel { return result; }; - /** Update which agent the IM session is currently routed to. */ + /** + * Update which agent the IM session is currently routed to, together with + * the active scope (`workspaceId`) derived from that agent. Passing + * `agentId: null` clears the active agent and resets the scope to personal. + */ setActiveAgent = async ( platform: string, agentId: string | null, + workspaceId: string | null, tenantId?: string, ): Promise => { - const conditions: SQL[] = [ - eq(messengerAccountLinks.userId, this.userId), - eq(messengerAccountLinks.platform, platform), - ]; + const conditions: SQL[] = [this.ownership(), eq(messengerAccountLinks.platform, platform)]; if (tenantId !== undefined) { conditions.push(eq(messengerAccountLinks.tenantId, tenantId)); } const [updated] = await this.db .update(messengerAccountLinks) - .set({ activeAgentId: agentId, updatedAt: new Date() }) + .set({ activeAgentId: agentId, updatedAt: new Date(), workspaceId }) .where(and(...conditions)) .returning(); @@ -276,4 +279,26 @@ export class MessengerAccountLinkModel { .returning(); return updated; }; + + /** + * Static scope switch used by IM `/switch`. Moves the link to a new active + * scope (personal → `null`, or a workspace id) and sets the active agent to + * `agentId` — callers pass the scope's default agent (inbox/LobeAI) so + * switching never leaves the session agent-less; pass `null` only when the + * target scope has no agents. Caller must authorize access to the target + * scope first. + */ + static setActiveScope = async ( + db: LobeChatDatabase, + linkId: string, + workspaceId: string | null, + agentId: string | null = null, + ): Promise => { + const [updated] = await db + .update(messengerAccountLinks) + .set({ activeAgentId: agentId, updatedAt: new Date(), workspaceId }) + .where(eq(messengerAccountLinks.id, linkId)) + .returning(); + return updated; + }; } diff --git a/packages/database/src/models/notification.ts b/packages/database/src/models/notification.ts index c6b47706de..47dd401500 100644 --- a/packages/database/src/models/notification.ts +++ b/packages/database/src/models/notification.ts @@ -13,12 +13,14 @@ export class NotificationModel { this.userId = userId; } + private ownership = () => eq(notifications.userId, this.userId); + async list( opts: { category?: string; cursor?: string; limit?: number; unreadOnly?: boolean } = {}, ) { const { cursor, limit = 20, category, unreadOnly } = opts; - const conditions = [eq(notifications.userId, this.userId), eq(notifications.isArchived, false)]; + const conditions = [this.ownership(), eq(notifications.isArchived, false)]; if (unreadOnly) { conditions.push(eq(notifications.isRead, false)); @@ -32,7 +34,7 @@ export class NotificationModel { const cursorRow = await this.db .select({ createdAt: notifications.createdAt, id: notifications.id }) .from(notifications) - .where(and(eq(notifications.id, cursor), eq(notifications.userId, this.userId))) + .where(and(eq(notifications.id, cursor), this.ownership())) .limit(1); if (cursorRow[0]) { @@ -60,11 +62,7 @@ export class NotificationModel { .select({ count: count() }) .from(notifications) .where( - and( - eq(notifications.userId, this.userId), - eq(notifications.isRead, false), - eq(notifications.isArchived, false), - ), + and(this.ownership(), eq(notifications.isRead, false), eq(notifications.isArchived, false)), ); return result?.count ?? 0; @@ -76,7 +74,7 @@ export class NotificationModel { return this.db .update(notifications) .set({ isRead: true, updatedAt: new Date() }) - .where(and(eq(notifications.userId, this.userId), inArray(notifications.id, ids))); + .where(and(this.ownership(), inArray(notifications.id, ids))); } async markAllAsRead() { @@ -84,11 +82,7 @@ export class NotificationModel { .update(notifications) .set({ isRead: true, updatedAt: new Date() }) .where( - and( - eq(notifications.userId, this.userId), - eq(notifications.isRead, false), - eq(notifications.isArchived, false), - ), + and(this.ownership(), eq(notifications.isRead, false), eq(notifications.isArchived, false)), ); } @@ -96,14 +90,14 @@ export class NotificationModel { return this.db .update(notifications) .set({ isArchived: true, updatedAt: new Date() }) - .where(and(eq(notifications.id, id), eq(notifications.userId, this.userId))); + .where(and(eq(notifications.id, id), this.ownership())); } async archiveAll() { return this.db .update(notifications) .set({ isArchived: true, updatedAt: new Date() }) - .where(and(eq(notifications.userId, this.userId), eq(notifications.isArchived, false))); + .where(and(this.ownership(), eq(notifications.isArchived, false))); } // ─── Write-side (used by NotificationService in cloud) ───────── diff --git a/packages/database/src/models/plugin.ts b/packages/database/src/models/plugin.ts index 0798a8f292..3becd7ee4a 100644 --- a/packages/database/src/models/plugin.ts +++ b/packages/database/src/models/plugin.ts @@ -4,16 +4,25 @@ import { and, desc, eq } from 'drizzle-orm'; import type { InstalledPluginItem, NewInstalledPlugin } from '../schemas'; import { userInstalledPlugins } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; export class PluginModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + userInstalledPlugins, + ); + create = async ( params: Pick< NewInstalledPlugin, @@ -22,7 +31,7 @@ export class PluginModel { ) => { const [result] = await this.db .insert(userInstalledPlugins) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .onConflictDoUpdate({ set: { ...params, updatedAt: new Date() }, target: [userInstalledPlugins.identifier, userInstalledPlugins.userId], @@ -35,13 +44,11 @@ export class PluginModel { delete = async (id: string) => { return this.db .delete(userInstalledPlugins) - .where( - and(eq(userInstalledPlugins.identifier, id), eq(userInstalledPlugins.userId, this.userId)), - ); + .where(and(eq(userInstalledPlugins.identifier, id), this.ownership())); }; deleteAll = async () => { - return this.db.delete(userInstalledPlugins).where(eq(userInstalledPlugins.userId, this.userId)); + return this.db.delete(userInstalledPlugins).where(this.ownership()); }; query = async () => { @@ -57,7 +64,7 @@ export class PluginModel { updatedAt: userInstalledPlugins.updatedAt, }) .from(userInstalledPlugins) - .where(eq(userInstalledPlugins.userId, this.userId)) + .where(this.ownership()) .orderBy(desc(userInstalledPlugins.createdAt)); return data.map((item) => ({ @@ -68,10 +75,7 @@ export class PluginModel { findById = async (id: string) => { return this.db.query.userInstalledPlugins.findFirst({ - where: and( - eq(userInstalledPlugins.identifier, id), - eq(userInstalledPlugins.userId, this.userId), - ), + where: and(eq(userInstalledPlugins.identifier, id), this.ownership()), }); }; @@ -79,8 +83,6 @@ export class PluginModel { return this.db .update(userInstalledPlugins) .set({ ...value, updatedAt: new Date() }) - .where( - and(eq(userInstalledPlugins.identifier, id), eq(userInstalledPlugins.userId, this.userId)), - ); + .where(and(eq(userInstalledPlugins.identifier, id), this.ownership())); }; } diff --git a/packages/database/src/models/ragEval/dataset.ts b/packages/database/src/models/ragEval/dataset.ts index a4b2b49195..db2e68d0e7 100644 --- a/packages/database/src/models/ragEval/dataset.ts +++ b/packages/database/src/models/ragEval/dataset.ts @@ -1,31 +1,35 @@ import type { RAGEvalDataSetItem } from '@lobechat/types'; import { and, desc, eq } from 'drizzle-orm'; -import type {NewEvalDatasetsItem } from '../../schemas'; +import type { NewEvalDatasetsItem } from '../../schemas'; import { evalDatasets } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class EvalDatasetModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, evalDatasets); + create = async (params: NewEvalDatasetsItem) => { const [result] = await this.db .insert(evalDatasets) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return result; }; delete = async (id: string) => { - return this.db - .delete(evalDatasets) - .where(and(eq(evalDatasets.id, id), eq(evalDatasets.userId, this.userId))); + return this.db.delete(evalDatasets).where(and(eq(evalDatasets.id, id), this.ownership())); }; query = async (knowledgeBaseId: string): Promise => { @@ -38,18 +42,13 @@ export class EvalDatasetModel { updatedAt: evalDatasets.updatedAt, }) .from(evalDatasets) - .where( - and( - eq(evalDatasets.userId, this.userId), - eq(evalDatasets.knowledgeBaseId, knowledgeBaseId), - ), - ) + .where(and(this.ownership(), eq(evalDatasets.knowledgeBaseId, knowledgeBaseId))) .orderBy(desc(evalDatasets.createdAt)); }; findById = async (id: string) => { return this.db.query.evalDatasets.findFirst({ - where: and(eq(evalDatasets.id, id), eq(evalDatasets.userId, this.userId)), + where: and(eq(evalDatasets.id, id), this.ownership()), }); }; @@ -57,6 +56,6 @@ export class EvalDatasetModel { return this.db .update(evalDatasets) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(evalDatasets.id, id), eq(evalDatasets.userId, this.userId))); + .where(and(eq(evalDatasets.id, id), this.ownership())); }; } diff --git a/packages/database/src/models/ragEval/datasetRecord.ts b/packages/database/src/models/ragEval/datasetRecord.ts index 6fd81c3a81..0950223b4a 100644 --- a/packages/database/src/models/ragEval/datasetRecord.ts +++ b/packages/database/src/models/ragEval/datasetRecord.ts @@ -1,23 +1,32 @@ import type { EvalDatasetRecordRefFile } from '@lobechat/types'; import { and, eq, inArray } from 'drizzle-orm'; -import type {NewEvalDatasetRecordsItem } from '../../schemas'; +import type { NewEvalDatasetRecordsItem } from '../../schemas'; import { evalDatasetRecords, files } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class EvalDatasetRecordModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, evalDatasetRecords); + + private filesOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files); + create = async (params: NewEvalDatasetRecordsItem) => { const [result] = await this.db .insert(evalDatasetRecords) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return result; }; @@ -25,7 +34,13 @@ export class EvalDatasetRecordModel { batchCreate = async (params: NewEvalDatasetRecordsItem[]) => { const [result] = await this.db .insert(evalDatasetRecords) - .values(params.map((item) => ({ ...item, userId: this.userId }))) + .values( + params.map((item) => ({ + ...item, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + })), + ) .returning(); return result; @@ -34,22 +49,19 @@ export class EvalDatasetRecordModel { delete = async (id: string) => { return this.db .delete(evalDatasetRecords) - .where(and(eq(evalDatasetRecords.id, id), eq(evalDatasetRecords.userId, this.userId))); + .where(and(eq(evalDatasetRecords.id, id), this.ownership())); }; query = async (datasetId: string) => { const list = await this.db.query.evalDatasetRecords.findMany({ - where: and( - eq(evalDatasetRecords.datasetId, datasetId), - eq(evalDatasetRecords.userId, this.userId), - ), + where: and(eq(evalDatasetRecords.datasetId, datasetId), this.ownership()), }); const fileList = list.flatMap((item) => item.referenceFiles).filter(Boolean) as string[]; const fileItems = await this.db .select({ fileType: files.fileType, id: files.id, name: files.name }) .from(files) - .where(and(inArray(files.id, fileList), eq(files.userId, this.userId))); + .where(and(inArray(files.id, fileList), this.filesOwnership())); return list.map((item) => { return { @@ -63,16 +75,13 @@ export class EvalDatasetRecordModel { findByDatasetId = async (datasetId: string) => { return this.db.query.evalDatasetRecords.findMany({ - where: and( - eq(evalDatasetRecords.datasetId, datasetId), - eq(evalDatasetRecords.userId, this.userId), - ), + where: and(eq(evalDatasetRecords.datasetId, datasetId), this.ownership()), }); }; findById = async (id: string) => { return this.db.query.evalDatasetRecords.findFirst({ - where: and(eq(evalDatasetRecords.id, id), eq(evalDatasetRecords.userId, this.userId)), + where: and(eq(evalDatasetRecords.id, id), this.ownership()), }); }; @@ -80,6 +89,6 @@ export class EvalDatasetRecordModel { return this.db .update(evalDatasetRecords) .set(value) - .where(and(eq(evalDatasetRecords.id, id), eq(evalDatasetRecords.userId, this.userId))); + .where(and(eq(evalDatasetRecords.id, id), this.ownership())); }; } diff --git a/packages/database/src/models/ragEval/evaluation.ts b/packages/database/src/models/ragEval/evaluation.ts index 4076c5e48c..7673c7c0ed 100644 --- a/packages/database/src/models/ragEval/evaluation.ts +++ b/packages/database/src/models/ragEval/evaluation.ts @@ -3,36 +3,35 @@ import { EvalEvaluationStatus } from '@lobechat/types'; import type { SQL } from 'drizzle-orm'; import { and, count, desc, eq, inArray } from 'drizzle-orm'; -import type { - NewEvalEvaluationItem} from '../../schemas'; -import { - evalDatasets, - evalEvaluation, - evaluationRecords -} from '../../schemas'; +import type { NewEvalEvaluationItem } from '../../schemas'; +import { evalDatasets, evalEvaluation, evaluationRecords } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class EvalEvaluationModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, evalEvaluation); + create = async (params: NewEvalEvaluationItem) => { const [result] = await this.db .insert(evalEvaluation) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return result; }; delete = async (id: string) => { - return this.db - .delete(evalEvaluation) - .where(and(eq(evalEvaluation.id, id), eq(evalEvaluation.userId, this.userId))); + return this.db.delete(evalEvaluation).where(and(eq(evalEvaluation.id, id), this.ownership())); }; queryByKnowledgeBaseId = async (knowledgeBaseId: string) => { @@ -52,12 +51,7 @@ export class EvalEvaluationModel { .from(evalEvaluation) .leftJoin(evalDatasets, eq(evalDatasets.id, evalEvaluation.datasetId)) .orderBy(desc(evalEvaluation.createdAt)) - .where( - and( - eq(evalEvaluation.userId, this.userId), - eq(evalEvaluation.knowledgeBaseId, knowledgeBaseId), - ), - ); + .where(and(this.ownership(), eq(evalEvaluation.knowledgeBaseId, knowledgeBaseId))); // Then query record statistics for each evaluation const evaluationIds = evaluations.map((evals) => evals.id); @@ -88,7 +82,7 @@ export class EvalEvaluationModel { findById = async (id: string) => { return this.db.query.evalEvaluation.findFirst({ - where: and(eq(evalEvaluation.id, id), eq(evalEvaluation.userId, this.userId)), + where: and(eq(evalEvaluation.id, id), this.ownership()), }); }; @@ -96,6 +90,6 @@ export class EvalEvaluationModel { return this.db .update(evalEvaluation) .set(value) - .where(and(eq(evalEvaluation.id, id), eq(evalEvaluation.userId, this.userId))); + .where(and(eq(evalEvaluation.id, id), this.ownership())); }; } diff --git a/packages/database/src/models/ragEval/evaluationRecord.ts b/packages/database/src/models/ragEval/evaluationRecord.ts index f400538dbb..3db1b11bf8 100644 --- a/packages/database/src/models/ragEval/evaluationRecord.ts +++ b/packages/database/src/models/ragEval/evaluationRecord.ts @@ -1,22 +1,28 @@ import { and, eq } from 'drizzle-orm'; -import type {NewEvaluationRecordsItem } from '../../schemas'; +import type { NewEvaluationRecordsItem } from '../../schemas'; import { evaluationRecords } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export class EvaluationRecordModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, evaluationRecords); + create = async (params: NewEvaluationRecordsItem) => { const [result] = await this.db .insert(evaluationRecords) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); return result; }; @@ -24,37 +30,37 @@ export class EvaluationRecordModel { batchCreate = async (params: NewEvaluationRecordsItem[]) => { return this.db .insert(evaluationRecords) - .values(params.map((item) => ({ ...item, userId: this.userId }))) + .values( + params.map((item) => ({ + ...item, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + })), + ) .returning(); }; delete = async (id: string) => { return this.db .delete(evaluationRecords) - .where(and(eq(evaluationRecords.id, id), eq(evaluationRecords.userId, this.userId))); + .where(and(eq(evaluationRecords.id, id), this.ownership())); }; query = async (reportId: string) => { return this.db.query.evaluationRecords.findMany({ - where: and( - eq(evaluationRecords.evaluationId, reportId), - eq(evaluationRecords.userId, this.userId), - ), + where: and(eq(evaluationRecords.evaluationId, reportId), this.ownership()), }); }; findById = async (id: string) => { return this.db.query.evaluationRecords.findFirst({ - where: and(eq(evaluationRecords.id, id), eq(evaluationRecords.userId, this.userId)), + where: and(eq(evaluationRecords.id, id), this.ownership()), }); }; findByEvaluationId = async (evaluationId: string) => { return this.db.query.evaluationRecords.findMany({ - where: and( - eq(evaluationRecords.evaluationId, evaluationId), - eq(evaluationRecords.userId, this.userId), - ), + where: and(eq(evaluationRecords.evaluationId, evaluationId), this.ownership()), }); }; @@ -62,6 +68,6 @@ export class EvaluationRecordModel { return this.db .update(evaluationRecords) .set(value) - .where(and(eq(evaluationRecords.id, id), eq(evaluationRecords.userId, this.userId))); + .where(and(eq(evaluationRecords.id, id), this.ownership())); }; } diff --git a/packages/database/src/models/rbac.ts b/packages/database/src/models/rbac.ts index 8b5a660aef..fc71eb28e6 100644 --- a/packages/database/src/models/rbac.ts +++ b/packages/database/src/models/rbac.ts @@ -1,8 +1,14 @@ -import { and, eq, inArray, sql } from 'drizzle-orm'; +import type { WorkspaceSystemRoleName } from '@lobechat/const/rbac'; +import { and, eq, inArray, isNull, or, sql } from 'drizzle-orm'; -import { LobeChatDatabase } from '@/database/type'; +import type { LobeChatDatabase } from '@/database/type'; -import { RoleItem, permissions, rolePermissions, roles, userRoles } from '../schemas/rbac'; +import type { RoleItem } from '../schemas/rbac'; +import { permissions, rolePermissions, roles, userRoles } from '../schemas/rbac'; +import { + assignWorkspaceRoleToUser, + revokeWorkspaceRolesForUser, +} from '../utils/seedWorkspaceRoles'; export interface UserPermissionInfo { category: string; @@ -11,6 +17,48 @@ export interface UserPermissionInfo { roleName: string; } +/** + * Optional scope for a permission/role query. + * + * - `workspaceId: 'xxx'` — match grants in that workspace plus globally-granted + * roles (`rbac_user_roles.workspace_id IS NULL`, e.g. `super_admin`). This is + * what tRPC `withRbacPermission` uses inside a workspace request. + * - `workspaceId` omitted — match **any** grant, regardless of workspace. This + * preserves backward-compat with pre-workspace-scope callers (Hono routes + * that just check `agent:read:all` against the whole user, with workspace + * isolation enforced by the resource-level query elsewhere). + * + * Callers that want to assert "only globally-granted roles count" must do that + * filter themselves on the result set; we don't expose a third mode here + * because no production caller needs it today. + */ +export interface RbacScopeOptions { + userId?: string; + workspaceId?: string; +} + +/** + * Build the `WHERE rbac_user_roles.workspace_id ...` predicate used by every + * permission/role lookup. Encodes the rule above in one place so the four + * query methods don't drift. Returns `undefined` when no workspace scope + * filter should be applied (legacy behavior). + */ +const buildScopeWhere = (workspaceId: string | undefined) => + workspaceId + ? or(eq(userRoles.workspaceId, workspaceId), isNull(userRoles.workspaceId)) + : undefined; + +/** + * Back-compat shim: existing call sites pass a bare `userId` string as the + * second arg. New call sites pass `{ userId?, workspaceId? }`. Normalise both + * forms into the option object. + */ +const normalizeScope = (arg: string | RbacScopeOptions | undefined): RbacScopeOptions => { + if (!arg) return {}; + if (typeof arg === 'string') return { userId: arg }; + return arg; +}; + export class RbacModel { private userId: string; private db: LobeChatDatabase; @@ -21,12 +69,14 @@ export class RbacModel { } /** - * Get all permissions for a specific user - * @param userId - User ID to query permissions for - * @returns Array of permission codes that the user has + * Get all permissions for a specific user. Accepts either a plain `userId` + * string (legacy global-scope check) or `{ userId?, workspaceId? }` + * (workspace-aware). Permission codes returned include the `:all`/`:owner` + * scope suffix as stored in `rbac_permissions.code`. */ - getUserPermissions = async (userId?: string): Promise => { - const targetUserId = userId || this.userId; + getUserPermissions = async (arg?: string | RbacScopeOptions): Promise => { + const opts = normalizeScope(arg); + const targetUserId = opts.userId || this.userId; const result = await this.db .select({ @@ -41,21 +91,26 @@ export class RbacModel { eq(userRoles.userId, targetUserId), eq(roles.isActive, true), eq(permissions.isActive, true), + buildScopeWhere(opts.workspaceId), // Check if role assignment is not expired sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`, ), ); - return result.map((row) => row.permissionCode); + // De-dupe — the same code can come from multiple roles (e.g. owner + + // member if a user somehow ends up with both). + return [...new Set(result.map((row) => row.permissionCode))]; }; /** - * Get detailed permission information for a user - * @param userId - User ID to query permissions for - * @returns Array of detailed permission information + * Get detailed permission information for a user. Same scope rules as + * `getUserPermissions`. */ - getUserPermissionDetails = async (userId?: string): Promise => { - const targetUserId = userId || this.userId; + getUserPermissionDetails = async ( + arg?: string | RbacScopeOptions, + ): Promise => { + const opts = normalizeScope(arg); + const targetUserId = opts.userId || this.userId; return await this.db .select({ @@ -73,6 +128,7 @@ export class RbacModel { eq(userRoles.userId, targetUserId), eq(roles.isActive, true), eq(permissions.isActive, true), + buildScopeWhere(opts.workspaceId), // Check if role assignment is not expired sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`, ), @@ -81,13 +137,15 @@ export class RbacModel { }; /** - * Check if user has a specific permission - * @param permissionCode - Permission code to check - * @param userId - User ID to check (optional, defaults to instance userId) - * @returns Boolean indicating if user has the permission + * Check if user has a specific permission. Pass `{ workspaceId }` to scope + * the check to a workspace (global grants still apply). */ - hasPermission = async (permissionCode: string, userId?: string): Promise => { - const targetUserId = userId || this.userId; + hasPermission = async ( + permissionCode: string, + arg?: string | RbacScopeOptions, + ): Promise => { + const opts = normalizeScope(arg); + const targetUserId = opts.userId || this.userId; const result = await this.db .select({ count: sql`count(*)` }) @@ -101,6 +159,7 @@ export class RbacModel { inArray(permissions.code, [permissionCode]), eq(roles.isActive, true), eq(permissions.isActive, true), + buildScopeWhere(opts.workspaceId), // Check if role assignment is not expired sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`, ), @@ -110,15 +169,16 @@ export class RbacModel { }; /** - * Check if user has any of the specified permissions (OR logic) - * @param permissionCodes - Array of permission codes to check - * @param userId - User ID to check (optional, defaults to instance userId) - * @returns Boolean indicating if user has at least one of the permissions + * Check if user has any of the specified permissions (OR logic). */ - hasAnyPermission = async (permissionCodes: string[], userId?: string): Promise => { + hasAnyPermission = async ( + permissionCodes: string[], + arg?: string | RbacScopeOptions, + ): Promise => { if (permissionCodes.length === 0) return false; - const targetUserId = userId || this.userId; + const opts = normalizeScope(arg); + const targetUserId = opts.userId || this.userId; const result = await this.db .select({ count: sql`count(*)` }) @@ -132,6 +192,7 @@ export class RbacModel { inArray(permissions.code, permissionCodes), eq(roles.isActive, true), eq(permissions.isActive, true), + buildScopeWhere(opts.workspaceId), // Check if role assignment is not expired sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`, ), @@ -141,27 +202,24 @@ export class RbacModel { }; /** - * Check if user has all of the specified permissions (AND logic) - * @param permissionCodes - Array of permission codes to check - * @param userId - User ID to check (optional, defaults to instance userId) - * @returns Boolean indicating if user has all of the permissions + * Check if user has all of the specified permissions (AND logic). */ - hasAllPermissions = async (permissionCodes: string[], userId?: string): Promise => { + hasAllPermissions = async ( + permissionCodes: string[], + arg?: string | RbacScopeOptions, + ): Promise => { if (permissionCodes.length === 0) return true; - const checks = await Promise.all( - permissionCodes.map((code) => this.hasPermission(code, userId)), - ); + const checks = await Promise.all(permissionCodes.map((code) => this.hasPermission(code, arg))); return checks.every(Boolean); }; /** - * Get user's active roles - * @param userId - User ID to query roles for - * @returns Array of role information + * Get user's active roles. Same scope rules as `hasPermission`. */ - getUserRoles = async (userId?: string): Promise => { - const targetUserId = userId || this.userId; + getUserRoles = async (arg?: string | RbacScopeOptions): Promise => { + const opts = normalizeScope(arg); + const targetUserId = opts.userId || this.userId; return await this.db .select({ @@ -183,6 +241,7 @@ export class RbacModel { and( eq(userRoles.userId, targetUserId), eq(roles.isActive, true), + buildScopeWhere(opts.workspaceId), // Check if role assignment is not expired sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`, ), @@ -190,6 +249,40 @@ export class RbacModel { .orderBy(userRoles.createdAt); }; + /** + * List all roles defined inside a workspace (both built-in and custom). + * Used by the upcoming custom-role admin UI (LOBE-9193) and any client that + * wants to show available roles for a workspace. + */ + listWorkspaceRoles = async (workspaceId: string): Promise => { + return this.db.query.roles.findMany({ + orderBy: (table, { asc }) => [asc(table.isSystem), asc(table.name)], + where: and(eq(roles.workspaceId, workspaceId), eq(roles.isActive, true)), + }); + }; + + /** + * Grant a built-in workspace role (`workspace_owner` | `workspace_member` | + * `workspace_viewer`) to a user inside a workspace. Delegates to the seed + * util so the onConflict + role-lookup logic lives in one place. + */ + assignWorkspaceRole = async (params: { + roleName: WorkspaceSystemRoleName; + userId: string; + workspaceId: string; + }): Promise => { + await assignWorkspaceRoleToUser(this.db, params); + }; + + /** + * Revoke every workspace-scoped role this user holds in `workspaceId`. + * Idempotent. Used by member removal/leave flows and by `updateRole` before + * granting the new role. + */ + revokeWorkspaceRole = async (params: { userId: string; workspaceId: string }): Promise => { + await revokeWorkspaceRolesForUser(this.db, params); + }; + /** * Update user roles using a transaction to ensure atomicity * @param userId User ID diff --git a/packages/database/src/models/recent.ts b/packages/database/src/models/recent.ts index 28b21199bc..7edabe6fbc 100644 --- a/packages/database/src/models/recent.ts +++ b/packages/database/src/models/recent.ts @@ -4,6 +4,7 @@ import { unionAll } from 'drizzle-orm/pg-core'; import { agents, DOCUMENT_FOLDER_TYPE, documents, tasks, topics } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; export interface RecentDbItem { id: string; @@ -30,14 +31,24 @@ const TASK_FINAL_STATUSES = ['completed', 'canceled']; export class RecentModel { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } queryRecent = async (limit: number = 10): Promise => { + const scope = { userId: this.userId, workspaceId: this.workspaceId }; + + // `tasks` uses `createdByUserId` instead of `userId`, so apply the + // workspace-aware predicate inline. + const taskScopeWhere = this.workspaceId + ? eq(tasks.workspaceId, this.workspaceId) + : and(eq(tasks.createdByUserId, this.userId), isNull(tasks.workspaceId)); + const topicArm = this.db .select({ id: topics.id, @@ -53,7 +64,7 @@ export class RecentModel { .leftJoin(agents, eq(topics.agentId, agents.id)) .where( and( - eq(topics.userId, this.userId), + buildWorkspaceWhere(scope, topics), or( isNotNull(topics.groupId), eq(agents.slug, 'inbox'), @@ -80,7 +91,7 @@ export class RecentModel { .from(documents) .where( and( - eq(documents.userId, this.userId), + buildWorkspaceWhere(scope, documents), not(inArray(documents.sourceType, TOOL_DOCUMENT_SOURCE_TYPES)), isNull(documents.knowledgeBaseId), ne(documents.fileType, DOCUMENT_FOLDER_TYPE), @@ -101,12 +112,7 @@ export class RecentModel { updatedAt: tasks.updatedAt, }) .from(tasks) - .where( - and( - eq(tasks.createdByUserId, this.userId), - not(inArray(tasks.status, TASK_FINAL_STATUSES)), - ), - ); + .where(and(taskScopeWhere, not(inArray(tasks.status, TASK_FINAL_STATUSES)))); const rows = await unionAll(topicArm, documentArm, taskArm) .orderBy(desc(sql`updated_at`)) diff --git a/packages/database/src/models/session.ts b/packages/database/src/models/session.ts index 060bba0ad7..056bd0ac3e 100644 --- a/packages/database/src/models/session.ts +++ b/packages/database/src/models/session.ts @@ -16,15 +16,27 @@ import type { LobeChatDatabase } from '../type'; import { sanitizeBm25Query } from '../utils/bm25'; import { genEndDateWhere, genRangeWhere, genStartDateWhere, genWhere } from '../utils/genWhere'; import { idGenerator } from '../utils/idGenerator'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class SessionModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, sessions); + + private agentsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agents); + + private agentsToSessionsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentsToSessions); // **************** Query *************** // query = async ({ current = 0, pageSize = 9999 } = {}) => { @@ -44,7 +56,7 @@ export class SessionModel { .leftJoin(agentsToSessions, eq(sessions.id, agentsToSessions.sessionId)) .leftJoin(agents, eq(agentsToSessions.agentId, agents.id)) .leftJoin(sessionGroups, eq(sessions.groupId, sessionGroups.id)) - .where(and(eq(sessions.userId, this.userId), not(eq(sessions.slug, INBOX_SESSION_ID)))) + .where(and(this.ownership(), not(eq(sessions.slug, INBOX_SESSION_ID)))) .orderBy(desc(sessions.updatedAt)) .limit(pageSize) .offset(offset); @@ -76,7 +88,7 @@ export class SessionModel { const groups = await this.db.query.sessionGroups.findMany({ orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)], - where: eq(sessions.userId, this.userId), + where: and(this.ownership()), }); const mappedSessions = result.map((item) => this.mapSessionItem(item as any)); @@ -108,12 +120,7 @@ export class SessionModel { session: sessions, }) .from(sessions) - .where( - and( - or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)), - eq(sessions.userId, this.userId), - ), - ) + .where(and(or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)), this.ownership())) .leftJoin(agentsToSessions, eq(sessions.id, agentsToSessions.sessionId)) .leftJoin(agents, eq(agentsToSessions.agentId, agents.id)) .leftJoin(sessionGroups, eq(sessions.groupId, sessionGroups.id)) @@ -136,7 +143,7 @@ export class SessionModel { .from(sessions) .where( genWhere([ - eq(sessions.userId, this.userId), + this.ownership(), params?.range ? genRangeWhere(params.range, sessions.createdAt, (date) => date.toDate()) : undefined, @@ -156,7 +163,7 @@ export class SessionModel { const result = await this.db .select({ id: sessions.id }) .from(sessions) - .where(eq(sessions.userId, this.userId)) + .where(and(this.ownership())) .limit(n + 1); return result.length > n; @@ -184,7 +191,7 @@ export class SessionModel { return this.db.transaction(async (trx) => { if (slug) { const existResult = await trx.query.sessions.findFirst({ - where: and(eq(sessions.slug, slug), eq(sessions.userId, this.userId)), + where: and(eq(sessions.slug, slug), this.ownership()), }); if (existResult) return existResult; @@ -220,15 +227,19 @@ export class SessionModel { if (type === 'group') { const result = await trx .insert(sessions) - .values({ - ...session, - createdAt: new Date(), - id, - slug, - type, - updatedAt: new Date(), - userId: this.userId, - }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...session, + createdAt: new Date(), + id, + slug, + type, + updatedAt: new Date(), + }, + ), + ) .returning(); return result[0]; @@ -236,48 +247,57 @@ export class SessionModel { const newAgents = await trx .insert(agents) - .values({ - avatar, - backgroundColor, - chatConfig: chatConfig || {}, - createdAt: new Date(), - description, - editorData: editorData || null, - fewShots: examples || null, // Map examples to fewShots field - id: idGenerator('agents'), - marketIdentifier: identifier || marketIdentifier, - model: typeof model === 'string' ? model : null, - openingMessage, - openingQuestions, - params: params || {}, - plugins, - provider, - systemRole, - tags, - title, - tts: tts || {}, - updatedAt: new Date(), - userId: this.userId, - }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + avatar, + backgroundColor, + chatConfig: chatConfig || {}, + createdAt: new Date(), + description, + editorData: editorData || null, + fewShots: examples || null, // Map examples to fewShots field + id: idGenerator('agents'), + marketIdentifier: identifier || marketIdentifier, + model: typeof model === 'string' ? model : null, + openingMessage, + openingQuestions, + params: params || {}, + plugins, + provider, + systemRole, + tags, + title, + tts: tts || {}, + updatedAt: new Date(), + }, + ), + ) .returning(); const result = await trx .insert(sessions) - .values({ - ...session, - createdAt: new Date(), - id, - slug, - type, - updatedAt: new Date(), - userId: this.userId, - }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...session, + createdAt: new Date(), + id, + slug, + type, + updatedAt: new Date(), + }, + ), + ) .returning(); await trx.insert(agentsToSessions).values({ agentId: newAgents[0].id, sessionId: id, userId: this.userId, + workspaceId: this.workspaceId ?? null, }); return result[0]; @@ -286,7 +306,7 @@ export class SessionModel { createInbox = async (defaultAgentConfig: PartialDeep) => { const item = await this.db.query.sessions.findFirst({ - where: and(eq(sessions.userId, this.userId), eq(sessions.slug, INBOX_SESSION_ID)), + where: and(this.ownership(), eq(sessions.slug, INBOX_SESSION_ID)), }); if (item) return; @@ -299,13 +319,15 @@ export class SessionModel { }; batchCreate = async (newSessions: NewSession[]) => { - const sessionsToInsert = newSessions.map((s) => { - return { - ...s, - id: this.genId(), - userId: this.userId, - }; - }); + const sessionsToInsert = newSessions.map((s) => + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...s, + id: this.genId(), + }, + ), + ); return this.db.insert(sessions).values(sessionsToInsert); }; @@ -343,19 +365,17 @@ export class SessionModel { const links = await trx .select({ agentId: agentsToSessions.agentId }) .from(agentsToSessions) - .where(and(eq(agentsToSessions.sessionId, id), eq(agentsToSessions.userId, this.userId))); + .where(and(eq(agentsToSessions.sessionId, id), this.agentsToSessionsOwnership())); const agentIds = links.map((link) => link.agentId); // Delete links in agentsToSessions await trx .delete(agentsToSessions) - .where(and(eq(agentsToSessions.sessionId, id), eq(agentsToSessions.userId, this.userId))); + .where(and(eq(agentsToSessions.sessionId, id), this.agentsToSessionsOwnership())); // Delete the session (this will cascade delete messages, topics, etc.) - const result = await trx - .delete(sessions) - .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId))); + const result = await trx.delete(sessions).where(and(eq(sessions.id, id), this.ownership())); // Delete orphaned agents await this.clearOrphanAgent(agentIds, trx); @@ -375,23 +395,19 @@ export class SessionModel { const links = await trx .select({ agentId: agentsToSessions.agentId }) .from(agentsToSessions) - .where( - and(inArray(agentsToSessions.sessionId, ids), eq(agentsToSessions.userId, this.userId)), - ); + .where(and(inArray(agentsToSessions.sessionId, ids), this.agentsToSessionsOwnership())); const agentIds = [...new Set(links.map((link) => link.agentId))]; // Delete links in agentsToSessions await trx .delete(agentsToSessions) - .where( - and(inArray(agentsToSessions.sessionId, ids), eq(agentsToSessions.userId, this.userId)), - ); + .where(and(inArray(agentsToSessions.sessionId, ids), this.agentsToSessionsOwnership())); // Delete the sessions const result = await trx .delete(sessions) - .where(and(inArray(sessions.id, ids), eq(sessions.userId, this.userId))); + .where(and(inArray(sessions.id, ids), this.ownership())); // Delete orphaned agents await this.clearOrphanAgent(agentIds, trx); @@ -405,14 +421,9 @@ export class SessionModel { */ deleteAll = async () => { return this.db.transaction(async (trx) => { - // Delete all agentsToSessions for this user - await trx.delete(agentsToSessions).where(eq(agentsToSessions.userId, this.userId)); - - // Delete all agents that were only used by this user's sessions - await trx.delete(agents).where(eq(agents.userId, this.userId)); - - // Delete all sessions for this user - return trx.delete(sessions).where(eq(sessions.userId, this.userId)); + await trx.delete(agentsToSessions).where(this.agentsToSessionsOwnership()); + await trx.delete(agents).where(this.agentsOwnership()); + return trx.delete(sessions).where(this.ownership()); }); }; @@ -435,7 +446,7 @@ export class SessionModel { if (orphanedAgentIds.length > 0) { await trx .delete(agents) - .where(and(inArray(agents.id, orphanedAgentIds), eq(agents.userId, this.userId))); + .where(and(inArray(agents.id, orphanedAgentIds), this.agentsOwnership())); } }; @@ -445,7 +456,7 @@ export class SessionModel { return this.db .update(sessions) .set(data) - .where(and(eq(sessions.id, id), eq(sessions.userId, this.userId))) + .where(and(eq(sessions.id, id), this.ownership())) .returning(); }; @@ -505,7 +516,7 @@ export class SessionModel { return this.db .update(agents) .set(mergedValue) - .where(and(eq(agents.id, session.agent.id), eq(agents.userId, this.userId))); + .where(and(eq(agents.id, session.agent.id), this.agentsOwnership())); }; // **************** Helper *************** // @@ -598,7 +609,7 @@ export class SessionModel { // Keep deterministic ordering for keyword search results orderBy: [asc(agents.id)], where: and( - eq(agents.userId, this.userId), + this.agentsOwnership(), sql`(${agents.title} @@@ ${bm25Query} OR ${agents.description} @@@ ${bm25Query})`, ), with: { agentsToSessions: { columns: {}, with: { session: true } } }, diff --git a/packages/database/src/models/sessionGroup.ts b/packages/database/src/models/sessionGroup.ts index 771848c3d3..dc595d162e 100644 --- a/packages/database/src/models/sessionGroup.ts +++ b/packages/database/src/models/sessionGroup.ts @@ -4,45 +4,54 @@ import type { SessionGroupItem } from '../schemas'; import { sessionGroups } from '../schemas'; import type { LobeChatDatabase } from '../type'; import { idGenerator } from '../utils/idGenerator'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; export class SessionGroupModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, sessionGroups); + create = async (params: { name: string; sort?: number }) => { const [result] = await this.db .insert(sessionGroups) - .values({ ...params, id: this.genId(), userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { ...params, id: this.genId() }, + ), + ) .returning(); return result; }; delete = async (id: string) => { - return this.db - .delete(sessionGroups) - .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); + return this.db.delete(sessionGroups).where(and(eq(sessionGroups.id, id), this.ownership())); }; deleteAll = async () => { - return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId)); + return this.db.delete(sessionGroups).where(this.ownership()); }; query = async () => { return this.db.query.sessionGroups.findMany({ orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)], - where: eq(sessionGroups.userId, this.userId), + where: this.ownership(), }); }; findById = async (id: string) => { return this.db.query.sessionGroups.findFirst({ - where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)), + where: and(eq(sessionGroups.id, id), this.ownership()), }); }; @@ -50,7 +59,7 @@ export class SessionGroupModel { return this.db .update(sessionGroups) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); + .where(and(eq(sessionGroups.id, id), this.ownership())); }; updateOrder = async (sortMap: { id: string; sort: number }[]) => { @@ -59,7 +68,7 @@ export class SessionGroupModel { return tx .update(sessionGroups) .set({ sort, updatedAt: new Date() }) - .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); + .where(and(eq(sessionGroups.id, id), this.ownership())); }); await Promise.all(updates); diff --git a/packages/database/src/models/task.ts b/packages/database/src/models/task.ts index 5671151986..7b87c7d5ee 100644 --- a/packages/database/src/models/task.ts +++ b/packages/database/src/models/task.ts @@ -7,6 +7,7 @@ import type { WorkspaceTreeNode, } from '@lobechat/types'; import { and, desc, eq, gte, inArray, isNotNull, isNull, ne, notInArray, sql } from 'drizzle-orm'; +import type { AnyPgColumn } from 'drizzle-orm/pg-core'; import { merge } from '@/utils/merge'; @@ -14,16 +15,50 @@ import { documents } from '../schemas/file'; import type { NewTaskComment, TaskCommentItem } from '../schemas/task'; import { taskComments, taskDependencies, taskDocuments, tasks } from '../schemas/task'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; export class TaskModel { private readonly userId: string; private readonly db: LobeChatDatabase; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + /** + * Compat-mode ownership predicate for the `tasks` table. + * `tasks` uses `createdByUserId` instead of `userId`. + */ + private ownership = () => + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + { userId: tasks.createdByUserId, workspaceId: tasks.workspaceId }, + ); + + /** + * Ownership predicate for task child tables (deps / docs / comments) that + * use a `userId` column instead of `createdByUserId`. + */ + private childOwnership = (cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }) => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, cols); + + /** + * Raw-SQL ownership clause for use inside `db.execute(sql...)` CTEs that + * can't easily compose with drizzle's `and(...)` helpers. Mirrors + * `buildWorkspaceWhere` semantics: + * - workspace mode → `workspace_id = $ws` + * - personal mode → `created_by_user_id = $userId AND workspace_id IS NULL` + */ + private ownershipSql = (alias?: string) => { + const prefix = alias ? sql.raw(`${alias}.`) : sql.raw(''); + return this.workspaceId + ? sql`${prefix}workspace_id = ${this.workspaceId}` + : sql`${prefix}created_by_user_id = ${this.userId} AND ${prefix}workspace_id IS NULL`; + }; + // ========== CRUD ========== async create( @@ -37,10 +72,13 @@ export class TaskModel { const maxRetries = 5; for (let attempt = 0; attempt < maxRetries; attempt++) { try { + // Seq is allocated per ownership scope: workspace-wide in team mode, + // user-private in personal mode. This keeps `T-N` identifiers stable + // within the surface the user actually sees. const seqResult = await this.db .select({ maxSeq: sql`COALESCE(MAX(${tasks.seq}), 0)` }) .from(tasks) - .where(eq(tasks.createdByUserId, this.userId)); + .where(this.ownership()); const nextSeq = Number(seqResult[0].maxSeq) + 1; const identifier = `${identifierPrefix}-${nextSeq}`; @@ -52,6 +90,7 @@ export class TaskModel { createdByUserId: this.userId, identifier, seq: nextSeq, + workspaceId: this.workspaceId ?? null, } as NewTask) .returning(); @@ -79,7 +118,7 @@ export class TaskModel { const result = await this.db .select() .from(tasks) - .where(and(eq(tasks.id, id), eq(tasks.createdByUserId, this.userId))) + .where(and(eq(tasks.id, id), this.ownership())) .limit(1); return result[0] || null; @@ -90,7 +129,7 @@ export class TaskModel { return this.db .select() .from(tasks) - .where(and(inArray(tasks.id, ids), eq(tasks.createdByUserId, this.userId))); + .where(and(inArray(tasks.id, ids), this.ownership())); } // Resolve id or identifier (e.g. 'T-1') to a task @@ -103,7 +142,7 @@ export class TaskModel { const result = await this.db .select() .from(tasks) - .where(and(eq(tasks.identifier, identifier), eq(tasks.createdByUserId, this.userId))) + .where(and(eq(tasks.identifier, identifier), this.ownership())) .limit(1); return result[0] || null; @@ -118,7 +157,7 @@ export class TaskModel { const updated = await this.db .update(tasks) .set({ ...data, updatedAt: new Date() }) - .where(and(eq(tasks.id, id), eq(tasks.createdByUserId, this.userId))) + .where(and(eq(tasks.id, id), this.ownership())) .returning(); return updated[0] || null; } @@ -126,17 +165,14 @@ export class TaskModel { async delete(id: string): Promise { const result = await this.db .delete(tasks) - .where(and(eq(tasks.id, id), eq(tasks.createdByUserId, this.userId))) + .where(and(eq(tasks.id, id), this.ownership())) .returning(); return result.length > 0; } async deleteAll(): Promise { - const result = await this.db - .delete(tasks) - .where(eq(tasks.createdByUserId, this.userId)) - .returning(); + const result = await this.db.delete(tasks).where(this.ownership()).returning(); return result.length; } @@ -164,7 +200,7 @@ export class TaskModel { > { const { groups, assigneeAgentId, parentTaskId } = options; - const baseConditions = [eq(tasks.createdByUserId, this.userId)]; + const baseConditions = [this.ownership()]; if (assigneeAgentId) baseConditions.push(eq(tasks.assigneeAgentId, assigneeAgentId)); if (parentTaskId === null) { baseConditions.push(isNull(tasks.parentTaskId)); @@ -232,7 +268,7 @@ export class TaskModel { offset = 0, } = options || {}; - const conditions = [eq(tasks.createdByUserId, this.userId)]; + const conditions = [this.ownership()]; if (statuses?.length) conditions.push(inArray(tasks.status, statuses)); if (priorities?.length) conditions.push(inArray(tasks.priority, priorities)); @@ -271,7 +307,7 @@ export class TaskModel { await this.db .update(tasks) .set({ sortOrder: item.sortOrder, updatedAt: new Date() }) - .where(and(eq(tasks.id, item.id), eq(tasks.createdByUserId, this.userId))); + .where(and(eq(tasks.id, item.id), this.ownership())); } } @@ -279,7 +315,7 @@ export class TaskModel { return this.db .select() .from(tasks) - .where(and(eq(tasks.parentTaskId, parentTaskId), eq(tasks.createdByUserId, this.userId))) + .where(and(eq(tasks.parentTaskId, parentTaskId), this.ownership())) .orderBy(tasks.sortOrder, tasks.seq); } @@ -295,7 +331,7 @@ export class TaskModel { const children = await this.db .select() .from(tasks) - .where(and(inArray(tasks.parentTaskId, parentIds), eq(tasks.createdByUserId, this.userId))) + .where(and(inArray(tasks.parentTaskId, parentIds), this.ownership())) .orderBy(tasks.sortOrder, tasks.seq); if (children.length === 0) break; @@ -309,9 +345,10 @@ export class TaskModel { // Recursive query to get full task tree async getTaskTree(rootTaskId: string): Promise { + const ownership = this.ownershipSql(); const result = await this.db.execute(sql` WITH RECURSIVE task_tree AS ( - SELECT * FROM tasks WHERE id = ${rootTaskId} AND created_by_user_id = ${this.userId} + SELECT * FROM tasks WHERE id = ${rootTaskId} AND ${ownership} UNION ALL SELECT t.* FROM tasks t JOIN task_tree tt ON t.parent_task_id = tt.id @@ -333,18 +370,20 @@ export class TaskModel { const taskIdParams = taskIds.map((id) => sql`${id}`); const taskIdList = sql.join(taskIdParams, sql`, `); + const ownershipBare = this.ownershipSql(); + const ownershipAliased = this.ownershipSql('t'); const result = await this.db.execute(sql` WITH RECURSIVE ancestors AS ( SELECT id AS origin_id, id, parent_task_id FROM tasks WHERE id IN (${taskIdList}) - AND created_by_user_id = ${this.userId} + AND ${ownershipBare} UNION ALL SELECT a.origin_id, t.id, t.parent_task_id FROM tasks t JOIN ancestors a ON t.id = a.parent_task_id - WHERE t.created_by_user_id = ${this.userId} + WHERE ${ownershipAliased} ), roots AS ( SELECT DISTINCT ON (origin_id) origin_id, id AS root_id @@ -355,12 +394,12 @@ export class TaskModel { SELECT r.origin_id, t.id, t.assignee_agent_id, t.created_by_agent_id FROM tasks t JOIN roots r ON t.id = r.root_id - WHERE t.created_by_user_id = ${this.userId} + WHERE ${ownershipAliased} UNION ALL SELECT d.origin_id, t.id, t.assignee_agent_id, t.created_by_agent_id FROM tasks t JOIN descendants d ON t.parent_task_id = d.id - WHERE t.created_by_user_id = ${this.userId} + WHERE ${ownershipAliased} ) SELECT origin_id, assignee_agent_id, created_by_agent_id FROM descendants @@ -392,7 +431,7 @@ export class TaskModel { const result = await this.db .update(tasks) .set({ status, updatedAt: new Date() }) - .where(and(inArray(tasks.id, ids), eq(tasks.createdByUserId, this.userId))) + .where(and(inArray(tasks.id, ids), this.ownership())) .returning(); return result.length; @@ -476,7 +515,7 @@ export class TaskModel { await this.db .update(tasks) .set({ lastHeartbeatAt: new Date(), updatedAt: new Date() }) - .where(eq(tasks.id, id)); + .where(and(eq(tasks.id, id), this.ownership())); } // Tasks eligible for cron-based dispatch. @@ -513,10 +552,22 @@ export class TaskModel { // ========== Dependencies ========== + private depsOwnership = () => + this.childOwnership({ + userId: taskDependencies.userId, + workspaceId: taskDependencies.workspaceId, + }); + async addDependency(taskId: string, dependsOnId: string, type: string = 'blocks'): Promise { await this.db .insert(taskDependencies) - .values({ dependsOnId, taskId, type, userId: this.userId }) + .values({ + dependsOnId, + taskId, + type, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + }) .onConflictDoNothing(); } @@ -524,21 +575,34 @@ export class TaskModel { await this.db .delete(taskDependencies) .where( - and(eq(taskDependencies.taskId, taskId), eq(taskDependencies.dependsOnId, dependsOnId)), + and( + eq(taskDependencies.taskId, taskId), + eq(taskDependencies.dependsOnId, dependsOnId), + this.depsOwnership(), + ), ); } async getDependencies(taskId: string) { - return this.db.select().from(taskDependencies).where(eq(taskDependencies.taskId, taskId)); + return this.db + .select() + .from(taskDependencies) + .where(and(eq(taskDependencies.taskId, taskId), this.depsOwnership())); } async getDependenciesByTaskIds(taskIds: string[]) { if (taskIds.length === 0) return []; - return this.db.select().from(taskDependencies).where(inArray(taskDependencies.taskId, taskIds)); + return this.db + .select() + .from(taskDependencies) + .where(and(inArray(taskDependencies.taskId, taskIds), this.depsOwnership())); } async getDependents(taskId: string) { - return this.db.select().from(taskDependencies).where(eq(taskDependencies.dependsOnId, taskId)); + return this.db + .select() + .from(taskDependencies) + .where(and(eq(taskDependencies.dependsOnId, taskId), this.depsOwnership())); } // Check if all dependencies of a task are completed @@ -552,6 +616,7 @@ export class TaskModel { eq(taskDependencies.taskId, taskId), eq(taskDependencies.type, 'blocks'), ne(tasks.status, 'completed'), + this.depsOwnership(), ), ); @@ -587,11 +652,7 @@ export class TaskModel { .select({ count: sql`count(*)` }) .from(tasks) .where( - and( - eq(tasks.parentTaskId, parentTaskId), - ne(tasks.status, 'completed'), - eq(tasks.createdByUserId, this.userId), - ), + and(eq(tasks.parentTaskId, parentTaskId), ne(tasks.status, 'completed'), this.ownership()), ); return Number(result[0].count) === 0; @@ -599,24 +660,42 @@ export class TaskModel { // ========== Documents (MVP Workspace) ========== + private docsOwnership = () => + this.childOwnership({ + userId: taskDocuments.userId, + workspaceId: taskDocuments.workspaceId, + }); + async pinDocument(taskId: string, documentId: string, pinnedBy: string = 'agent'): Promise { await this.db .insert(taskDocuments) - .values({ documentId, pinnedBy, taskId, userId: this.userId }) + .values({ + documentId, + pinnedBy, + taskId, + userId: this.userId, + workspaceId: this.workspaceId ?? null, + }) .onConflictDoNothing(); } async unpinDocument(taskId: string, documentId: string): Promise { await this.db .delete(taskDocuments) - .where(and(eq(taskDocuments.taskId, taskId), eq(taskDocuments.documentId, documentId))); + .where( + and( + eq(taskDocuments.taskId, taskId), + eq(taskDocuments.documentId, documentId), + this.docsOwnership(), + ), + ); } async getPinnedDocuments(taskId: string) { return this.db .select() .from(taskDocuments) - .where(eq(taskDocuments.taskId, taskId)) + .where(and(eq(taskDocuments.taskId, taskId), this.docsOwnership())) .orderBy(taskDocuments.createdAt); } @@ -642,7 +721,7 @@ export class TaskModel { .where( and( eq(taskDocuments.taskId, taskId), - eq(taskDocuments.userId, this.userId), + this.docsOwnership(), gte(taskDocuments.createdAt, since), ), ); @@ -656,12 +735,18 @@ export class TaskModel { // Get all pinned docs from a task tree (recursive), returns nodeMap + tree structure async getTreePinnedDocuments(rootTaskId: string): Promise { + const rootOwnership = this.ownershipSql(); + const recursiveOwnership = this.ownershipSql('t'); + const docsOwnership = this.workspaceId + ? sql`td.workspace_id = ${this.workspaceId}` + : sql`td.user_id = ${this.userId} AND td.workspace_id IS NULL`; const result = await this.db.execute(sql` WITH RECURSIVE task_tree AS ( - SELECT id, identifier FROM tasks WHERE id = ${rootTaskId} + SELECT id, identifier FROM tasks WHERE id = ${rootTaskId} AND ${rootOwnership} UNION ALL SELECT t.id, t.identifier FROM tasks t JOIN task_tree tt ON t.parent_task_id = tt.id + WHERE ${recursiveOwnership} ) SELECT td.*, tt.id as source_task_id, tt.identifier as source_task_identifier, d.title as document_title, d.file_type as document_file_type, d.parent_id as document_parent_id, @@ -669,6 +754,7 @@ export class TaskModel { FROM task_documents td JOIN task_tree tt ON td.task_id = tt.id LEFT JOIN documents d ON td.document_id = d.id + WHERE ${docsOwnership} ORDER BY td.created_at `); @@ -725,20 +811,29 @@ export class TaskModel { totalTopics: sql`${tasks.totalTopics} + 1`, updatedAt: new Date(), }) - .where(eq(tasks.id, id)); + .where(and(eq(tasks.id, id), this.ownership())); } async updateCurrentTopic(id: string, topicId: string): Promise { await this.db .update(tasks) .set({ currentTopicId: topicId, updatedAt: new Date() }) - .where(eq(tasks.id, id)); + .where(and(eq(tasks.id, id), this.ownership())); } // ========== Comments ========== + private commentsOwnership = () => + this.childOwnership({ + userId: taskComments.userId, + workspaceId: taskComments.workspaceId, + }); + async addComment(data: Omit): Promise { - const [comment] = await this.db.insert(taskComments).values(data).returning(); + const [comment] = await this.db + .insert(taskComments) + .values({ ...data, workspaceId: this.workspaceId ?? null }) + .returning(); return comment; } @@ -746,14 +841,14 @@ export class TaskModel { return this.db .select() .from(taskComments) - .where(eq(taskComments.taskId, taskId)) + .where(and(eq(taskComments.taskId, taskId), this.commentsOwnership())) .orderBy(taskComments.createdAt); } async deleteComment(id: string): Promise { const result = await this.db .delete(taskComments) - .where(and(eq(taskComments.id, id), eq(taskComments.userId, this.userId))) + .where(and(eq(taskComments.id, id), this.commentsOwnership())) .returning(); return result.length > 0; } @@ -770,8 +865,207 @@ export class TaskModel { ...(opts?.editorData !== undefined ? { editorData: opts.editorData as never } : {}), updatedAt: new Date(), }) - .where(and(eq(taskComments.id, id), eq(taskComments.userId, this.userId))) + .where(and(eq(taskComments.id, id), this.commentsOwnership())) .returning(); return comment; } + + // ========== Transfer / Copy ========== + + /** + * Collect a task and all its descendants (parentTaskId-linked) via BFS. + * Honors the current ownership scope. + */ + private async collectTaskSubtree(rootId: string, runner: LobeChatDatabase): Promise { + const [root] = await runner + .select() + .from(tasks) + .where(and(eq(tasks.id, rootId), this.ownership())) + .limit(1); + if (!root) return []; + + const collected: TaskItem[] = [root]; + let frontier: string[] = [root.id]; + + while (frontier.length > 0) { + const children = await runner + .select() + .from(tasks) + .where(and(inArray(tasks.parentTaskId, frontier), this.ownership())); + if (children.length === 0) break; + collected.push(...children); + frontier = children.map((c) => c.id); + } + + return collected; + } + + /** + * Allocate a contiguous block of seq numbers + identifiers in the target + * scope. Returns the next available seq baseline. + */ + private async nextSeqIn( + runner: LobeChatDatabase, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise { + const where = targetWorkspaceId + ? eq(tasks.workspaceId, targetWorkspaceId) + : and(eq(tasks.createdByUserId, targetUserId), isNull(tasks.workspaceId)); + const [{ maxSeq }] = await runner + .select({ maxSeq: sql`COALESCE(MAX(${tasks.seq}), 0)` }) + .from(tasks) + .where(where!); + return Number(maxSeq) + 1; + } + + /** + * Transfer a task subtree to another workspace / personal scope. Reallocates + * `identifier`/`seq` in the target scope and rewrites every dependent child + * table (`task_dependencies`, `task_documents`, `task_topics`, + * `task_comments`, `briefs`) so the ownership predicates remain consistent. + * + * Cross-scope references that may no longer be valid are cleared: + * - `assigneeAgentId` (workspace move: agent likely doesn't exist there) + * - `currentTopicId` (topic ownership is also moving but the link is + * reset to avoid surfacing a stale active topic in the new scope) + */ + async transferTo( + taskId: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ taskIds: string[] }> { + return this.db.transaction(async (trx) => { + const scoped = new TaskModel(trx as LobeChatDatabase, this.userId, this.workspaceId); + const subtree = await scoped.collectTaskSubtree(taskId, trx as LobeChatDatabase); + if (subtree.length === 0) throw new Error('Task not found'); + + const ids = subtree.map((t) => t.id); + + // Reallocate identifier + seq in target scope to avoid collisions. + const baseSeq = await this.nextSeqIn( + trx as LobeChatDatabase, + targetWorkspaceId, + targetUserId, + ); + // Update each task individually because identifier/seq are per-row. + for (const [idx, task] of subtree.entries()) { + const seq = baseSeq + idx; + const identifier = `T-${seq}`; + await (trx as LobeChatDatabase) + .update(tasks) + .set({ + // Clear cross-scope refs: agent / topic may be invalid in new scope. + assigneeAgentId: targetWorkspaceId === this.workspaceId ? task.assigneeAgentId : null, + createdByUserId: targetUserId, + currentTopicId: null, + identifier, + seq, + updatedAt: new Date(), + workspaceId: targetWorkspaceId, + }) + .where(eq(tasks.id, task.id)); + } + + // Update child tables that key off taskId. + const ownershipUpdate = { userId: targetUserId, workspaceId: targetWorkspaceId }; + await (trx as LobeChatDatabase) + .update(taskDependencies) + .set(ownershipUpdate) + .where(inArray(taskDependencies.taskId, ids)); + await (trx as LobeChatDatabase) + .update(taskDocuments) + .set(ownershipUpdate) + .where(inArray(taskDocuments.taskId, ids)); + await (trx as LobeChatDatabase) + .update(taskComments) + .set(ownershipUpdate) + .where(inArray(taskComments.taskId, ids)); + + return { taskIds: ids }; + }); + } + + /** + * Deep clone a task subtree into another workspace / personal scope. Fresh + * ids, fresh identifiers, preserved parent/child topology. Cross-scope refs + * (agent / topic / brief / current topic) are cleared on the clones so the + * copies start clean in the new scope. + */ + async copyToWorkspace( + taskId: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ rootId: string }> { + return this.db.transaction(async (trx) => { + const scoped = new TaskModel(trx as LobeChatDatabase, this.userId, this.workspaceId); + const subtree = await scoped.collectTaskSubtree(taskId, trx as LobeChatDatabase); + if (subtree.length === 0) throw new Error('Task not found'); + + // BFS clone — parent inserted before children, so we always know the + // new parentTaskId by the time we reach the child. + const idMap = new Map(); + const byId = new Map(subtree.map((t) => [t.id, t])); + const queue: string[] = [taskId]; + const seen = new Set(); + + let seq = await this.nextSeqIn(trx as LobeChatDatabase, targetWorkspaceId, targetUserId); + + while (queue.length > 0) { + const currentId = queue.shift()!; + if (seen.has(currentId)) continue; + seen.add(currentId); + const original = byId.get(currentId); + if (!original) continue; + + const newParentId = + currentId === taskId ? null : (idMap.get(original.parentTaskId!) ?? null); + + const identifier = `T-${seq}`; + const inserted = (await (trx as LobeChatDatabase) + .insert(tasks) + .values({ + assigneeAgentId: null, + assigneeUserId: null, + automationMode: original.automationMode, + config: original.config ?? {}, + context: { + ...(original.context as Record), + duplicatedFrom: original.id, + }, + createdByAgentId: null, + createdByUserId: targetUserId, + currentTopicId: null, + description: original.description, + error: null, + heartbeatInterval: original.heartbeatInterval, + heartbeatTimeout: original.heartbeatTimeout, + identifier, + instruction: original.instruction, + maxTopics: original.maxTopics, + name: original.name, + parentTaskId: newParentId, + priority: original.priority, + schedulePattern: original.schedulePattern, + scheduleTimezone: original.scheduleTimezone, + seq, + sortOrder: original.sortOrder, + // Reset lifecycle: copy starts fresh, not mid-run. + status: 'backlog', + totalTopics: 0, + workspaceId: targetWorkspaceId, + } as NewTask) + .returning({ id: tasks.id })) as { id: string }[]; + + idMap.set(original.id, inserted[0]!.id); + seq++; + + for (const c of subtree) { + if (c.parentTaskId === original.id) queue.push(c.id); + } + } + + return { rootId: idMap.get(taskId)! }; + }); + } } diff --git a/packages/database/src/models/taskTopic.ts b/packages/database/src/models/taskTopic.ts index 4a90b7138d..b4f1f7221a 100644 --- a/packages/database/src/models/taskTopic.ts +++ b/packages/database/src/models/taskTopic.ts @@ -5,18 +5,24 @@ import type { TaskTopicItem } from '../schemas/task'; import { tasks, taskTopics } from '../schemas/task'; import { topics } from '../schemas/topic'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; const TERMINAL_TOPIC_STATUSES = new Set(['canceled', 'completed', 'failed', 'timeout']); export class TaskTopicModel { private readonly userId: string; private readonly db: LobeChatDatabase; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, taskTopics); + /** * Mirror a terminal taskTopic transition onto the underlying topic record: * stamp `topics.completedAt` so duration can be computed at read time, and @@ -29,7 +35,12 @@ export class TaskTopicModel { await this.db .update(topics) .set(setClause) - .where(and(eq(topics.id, topicId), eq(topics.userId, this.userId))); + .where( + and( + eq(topics.id, topicId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics), + ), + ); } async add( @@ -45,6 +56,7 @@ export class TaskTopicModel { taskId, topicId, userId: this.userId, + workspaceId: this.workspaceId ?? null, }) .onConflictDoNothing(); } @@ -53,13 +65,7 @@ export class TaskTopicModel { await this.db .update(taskTopics) .set({ status }) - .where( - and( - eq(taskTopics.taskId, taskId), - eq(taskTopics.topicId, topicId), - eq(taskTopics.userId, this.userId), - ), - ); + .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership())); if (TERMINAL_TOPIC_STATUSES.has(status)) { await this.markTopicEnded(topicId, status); @@ -79,7 +85,7 @@ export class TaskTopicModel { eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), eq(taskTopics.status, 'running'), - eq(taskTopics.userId, this.userId), + this.ownership(), ), ) .returning(); @@ -93,26 +99,14 @@ export class TaskTopicModel { await this.db .update(taskTopics) .set({ operationId }) - .where( - and( - eq(taskTopics.taskId, taskId), - eq(taskTopics.topicId, topicId), - eq(taskTopics.userId, this.userId), - ), - ); + .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership())); } async updateHandoff(taskId: string, topicId: string, handoff: TaskTopicHandoff): Promise { await this.db .update(taskTopics) .set({ handoff }) - .where( - and( - eq(taskTopics.taskId, taskId), - eq(taskTopics.topicId, topicId), - eq(taskTopics.userId, this.userId), - ), - ); + .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership())); } /** @@ -131,13 +125,7 @@ export class TaskTopicModel { .set({ handoff: sql`jsonb_set(COALESCE(${taskTopics.handoff}, '{}'::jsonb), '{briefDecision}', ${JSON.stringify(decision)}::jsonb)`, }) - .where( - and( - eq(taskTopics.taskId, taskId), - eq(taskTopics.topicId, topicId), - eq(taskTopics.userId, this.userId), - ), - ); + .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership())); } async updateReview( @@ -159,26 +147,14 @@ export class TaskTopicModel { reviewScores: review.scores, reviewedAt: new Date(), }) - .where( - and( - eq(taskTopics.taskId, taskId), - eq(taskTopics.topicId, topicId), - eq(taskTopics.userId, this.userId), - ), - ); + .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership())); } async timeoutRunning(taskId: string): Promise { const result = await this.db .update(taskTopics) .set({ status: 'timeout' }) - .where( - and( - eq(taskTopics.taskId, taskId), - eq(taskTopics.status, 'running'), - eq(taskTopics.userId, this.userId), - ), - ) + .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.status, 'running'), this.ownership())) .returning({ topicId: taskTopics.topicId }); await Promise.all( @@ -195,13 +171,13 @@ export class TaskTopicModel { const result = await this.db .select() .from(taskTopics) - .where(and(eq(taskTopics.topicId, topicId), eq(taskTopics.userId, this.userId))) + .where(and(eq(taskTopics.topicId, topicId), this.ownership())) .limit(1); return result[0] || null; } async countByTask(taskId: string, options?: { since?: Date }): Promise { - const conditions = [eq(taskTopics.taskId, taskId), eq(taskTopics.userId, this.userId)]; + const conditions = [eq(taskTopics.taskId, taskId), this.ownership()]; if (options?.since) conditions.push(gte(taskTopics.createdAt, options.since)); const rows = await this.db @@ -215,7 +191,7 @@ export class TaskTopicModel { return this.db .select() .from(taskTopics) - .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.userId, this.userId))) + .where(and(eq(taskTopics.taskId, taskId), this.ownership())) .orderBy(desc(taskTopics.seq)); } @@ -239,7 +215,7 @@ export class TaskTopicModel { }) .from(taskTopics) .innerJoin(topics, eq(taskTopics.topicId, topics.id)) - .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.userId, this.userId))) + .where(and(eq(taskTopics.taskId, taskId), this.ownership())) .orderBy(desc(taskTopics.seq)); } @@ -258,7 +234,7 @@ export class TaskTopicModel { }) .from(taskTopics) .leftJoin(topics, eq(taskTopics.topicId, topics.id)) - .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.userId, this.userId))) + .where(and(eq(taskTopics.taskId, taskId), this.ownership())) .orderBy(desc(taskTopics.seq)) .limit(limit); } @@ -266,13 +242,7 @@ export class TaskTopicModel { async remove(taskId: string, topicId: string): Promise { const result = await this.db .delete(taskTopics) - .where( - and( - eq(taskTopics.taskId, taskId), - eq(taskTopics.topicId, topicId), - eq(taskTopics.userId, this.userId), - ), - ) + .where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership())) .returning(); if (result.length > 0) { diff --git a/packages/database/src/models/thread.ts b/packages/database/src/models/thread.ts index 64e7dfa1be..38f653def0 100644 --- a/packages/database/src/models/thread.ts +++ b/packages/database/src/models/thread.ts @@ -5,6 +5,7 @@ import { and, desc, eq, sql } from 'drizzle-orm'; import type { ThreadItem } from '../schemas'; import { messages, threads } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; /** * Per-thread subagent metrics, derived from the child messages at read time @@ -67,17 +68,27 @@ const queryColumns = { export class ThreadModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, threads); + create = async (params: CreateThreadParams) => { // @ts-ignore const [result] = await this.db .insert(threads) - .values({ status: ThreadStatus.Active, ...params, userId: this.userId }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { status: ThreadStatus.Active, ...params }, + ), + ) .onConflictDoNothing() .returning(); @@ -85,18 +96,18 @@ export class ThreadModel { }; delete = async (id: string) => { - return this.db.delete(threads).where(and(eq(threads.id, id), eq(threads.userId, this.userId))); + return this.db.delete(threads).where(and(eq(threads.id, id), this.ownership())); }; deleteAll = async () => { - return this.db.delete(threads).where(eq(threads.userId, this.userId)); + return this.db.delete(threads).where(this.ownership()); }; query = async () => { const data = await this.db .select(queryColumns) .from(threads) - .where(eq(threads.userId, this.userId)) + .where(this.ownership()) .orderBy(desc(threads.updatedAt)); return data as ThreadItem[]; @@ -110,7 +121,7 @@ export class ThreadModel { .select({ ...queryColumns, ...subagentMetricColumns }) .from(threads) .leftJoin(messages, eq(messages.threadId, threads.id)) - .where(and(eq(threads.topicId, topicId), eq(threads.userId, this.userId))) + .where(and(eq(threads.topicId, topicId), this.ownership())) .groupBy(threads.id) .orderBy(desc(threads.updatedAt)); @@ -119,7 +130,7 @@ export class ThreadModel { findById = async (id: string) => { return this.db.query.threads.findFirst({ - where: and(eq(threads.id, id), eq(threads.userId, this.userId)), + where: and(eq(threads.id, id), this.ownership()), }); }; @@ -127,6 +138,6 @@ export class ThreadModel { return this.db .update(threads) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(threads.id, id), eq(threads.userId, this.userId))); + .where(and(eq(threads.id, id), this.ownership())); }; } diff --git a/packages/database/src/models/topic.ts b/packages/database/src/models/topic.ts index 5ef06f24c1..c05ce9e513 100644 --- a/packages/database/src/models/topic.ts +++ b/packages/database/src/models/topic.ts @@ -35,6 +35,7 @@ import type { LobeChatDatabase } from '../type'; import { sanitizeBm25Query } from '../utils/bm25'; import { genEndDateWhere, genRangeWhere, genStartDateWhere, genWhere } from '../utils/genWhere'; import { idGenerator } from '../utils/idGenerator'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace'; import { recomputeTopicUsage } from './topicUsage'; type OnboardingSessionMetadataPatch = Partial>; @@ -141,11 +142,18 @@ const buildTopicOrderBy = (sortBy?: TopicQuerySortBy): SQL[] => export class TopicModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics); + private messageOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages); // **************** Query *************** // query = async ({ @@ -234,7 +242,7 @@ export class TopicModel { // If groupId is provided, query topics by groupId directly if (groupId) { const whereCondition = and( - eq(topics.userId, this.userId), + this.ownership(), eq(topics.groupId, groupId), includeTriggerCondition, excludeTriggerCondition, @@ -299,7 +307,7 @@ export class TopicModel { : eq(topics.agentId, agentId); const agentWhere = and( - eq(topics.userId, this.userId), + this.ownership(), agentCondition, includeTriggerCondition, excludeTriggerCondition, @@ -355,7 +363,7 @@ export class TopicModel { // Fallback to containerId-based query (backward compatibility) const whereCondition = and( - eq(topics.userId, this.userId), + this.ownership(), this.matchContainer(containerId), includeTriggerCondition, excludeTriggerCondition, @@ -414,16 +422,12 @@ export class TopicModel { findById = async (id: string) => { return this.db.query.topics.findFirst({ - where: and(eq(topics.id, id), eq(topics.userId, this.userId)), + where: and(eq(topics.id, id), this.ownership()), }); }; queryAll = async (): Promise => { - return this.db - .select() - .from(topics) - .orderBy(topics.updatedAt) - .where(eq(topics.userId, this.userId)); + return this.db.select().from(topics).orderBy(topics.updatedAt).where(and(this.ownership())); }; queryByKeyword = async (keyword: string, containerId?: string | null): Promise => { @@ -439,7 +443,7 @@ export class TopicModel { .from(topics) .where( and( - eq(topics.userId, this.userId), + this.ownership(), this.matchContainer(containerId), sql`${topics.title} @@@ ${bm25Query}`, ), @@ -452,9 +456,9 @@ export class TopicModel { .innerJoin(topics, eq(messages.topicId, topics.id)) .where( and( - eq(messages.userId, this.userId), + this.messageOwnership(), sql`${messages.content} @@@ ${bm25Query}`, - eq(topics.userId, this.userId), + this.ownership(), this.matchContainer(containerId), ), ) @@ -472,7 +476,7 @@ export class TopicModel { const topicsByMessages = await this.db.query.topics.findMany({ orderBy: [desc(topics.updatedAt)], - where: and(eq(topics.userId, this.userId), inArray(topics.id, topicIds)), + where: and(this.ownership(), inArray(topics.id, topicIds)), }); // Merge results and deduplicate @@ -509,7 +513,7 @@ export class TopicModel { .from(topics) .where( genWhere([ - eq(topics.userId, this.userId), + this.ownership(), agentCondition, params?.containerId ? this.matchContainer(params.containerId) : undefined, params?.range @@ -536,7 +540,7 @@ export class TopicModel { title: topics.title, }) .from(topics) - .where(and(eq(topics.userId, this.userId))) + .where(and(this.ownership())) .leftJoin(messages, eq(topics.id, messages.topicId)) .groupBy(topics.id) .orderBy(desc(sql`count`)) @@ -565,7 +569,7 @@ export class TopicModel { .leftJoin(agents, eq(topics.agentId, agents.id)) .where( and( - eq(topics.userId, this.userId), + this.ownership(), or( // Group topics: has groupId not(isNull(topics.groupId)), @@ -592,14 +596,16 @@ export class TopicModel { id: string = this.genId(), timing?: ModelTimingContext, ): Promise => { - const insertData = { - ...params, - agentId: params.agentId || null, - groupId: params.groupId || null, - id, - sessionId: params.sessionId || null, - userId: this.userId, - }; + const insertData = buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...params, + agentId: params.agentId || null, + groupId: params.groupId || null, + id, + sessionId: params.sessionId || null, + }, + ); const insertMeta = { hasAgentId: !!params.agentId, hasGroupId: !!params.groupId, @@ -638,7 +644,7 @@ export class TopicModel { tx .update(messages) .set({ topicId: topic.id }) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))), + .where(and(this.messageOwnership(), inArray(messages.id, messageIds))), { messageCount: messageIds.length }, ); @@ -660,16 +666,20 @@ export class TopicModel { const createdTopics = await tx .insert(topics) .values( - topicParams.map((params) => ({ - agentId: params.agentId || null, - favorite: params.favorite, - groupId: params.sessionId ? null : params.groupId, - id: params.id || this.genId(), - sessionId: params.groupId ? null : params.sessionId, - title: params.title, - trigger: params.trigger, - userId: this.userId, - })), + topicParams.map((params) => + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + agentId: params.agentId || null, + favorite: params.favorite, + groupId: params.sessionId ? null : params.groupId, + id: params.id || this.genId(), + sessionId: params.groupId ? null : params.sessionId, + title: params.title, + trigger: params.trigger, + }, + ), + ), ) .returning(); @@ -681,7 +691,7 @@ export class TopicModel { await tx .update(messages) .set({ topicId: topic.id }) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))); + .where(and(this.messageOwnership(), inArray(messages.id, messageIds))); } }), ); @@ -694,7 +704,7 @@ export class TopicModel { return this.db.transaction(async (tx) => { // find original topic const originalTopic = await tx.query.topics.findFirst({ - where: and(eq(topics.id, topicId), eq(topics.userId, this.userId)), + where: and(eq(topics.id, topicId), this.ownership()), }); if (!originalTopic) { @@ -704,19 +714,24 @@ export class TopicModel { // copy topic const [duplicatedTopic] = await tx .insert(topics) - .values({ - ...originalTopic, - clientId: null, - id: this.genId(), - title: newTitle || originalTopic?.title, - }) + .values( + buildWorkspacePayload( + { userId: this.userId, workspaceId: this.workspaceId }, + { + ...originalTopic, + clientId: null, + id: this.genId(), + title: newTitle || originalTopic?.title, + }, + ), + ) .returning(); // Find messages associated with the original topic, ordered by createdAt const originalMessages = await tx .select() .from(messages) - .where(and(eq(messages.topicId, topicId), eq(messages.userId, this.userId))) + .where(and(eq(messages.topicId, topicId), this.messageOwnership())) .orderBy(messages.createdAt); // Find all messagePlugins for this topic @@ -800,47 +815,39 @@ export class TopicModel { * Delete a session, also delete all messages and topics associated with it. */ delete = async (id: string) => { - return this.db.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId))); + return this.db.delete(topics).where(and(eq(topics.id, id), this.ownership())); }; /** * Deletes multiple topics based on the sessionId. */ batchDeleteBySessionId = async (sessionId?: string | null) => { - return this.db - .delete(topics) - .where(and(this.matchSession(sessionId), eq(topics.userId, this.userId))); + return this.db.delete(topics).where(and(this.matchSession(sessionId), this.ownership())); }; /** * Deletes multiple topics based on the groupId. */ batchDeleteByGroupId = async (groupId?: string | null) => { - return this.db - .delete(topics) - .where(and(this.matchGroup(groupId), eq(topics.userId, this.userId))); + return this.db.delete(topics).where(and(this.matchGroup(groupId), this.ownership())); }; /** * Deletes all topics matching the given agentId (`topics.agentId`). */ batchDeleteByAgentId = async (agentId: string) => { - return this.db - .delete(topics) - .where(and(eq(topics.userId, this.userId), eq(topics.agentId, agentId))); + return this.db.delete(topics).where(and(this.ownership(), eq(topics.agentId, agentId))); }; /** * Deletes multiple topics and all messages associated with them in a transaction. */ batchDelete = async (ids: string[]) => { - return this.db - .delete(topics) - .where(and(inArray(topics.id, ids), eq(topics.userId, this.userId))); + return this.db.delete(topics).where(and(inArray(topics.id, ids), this.ownership())); }; deleteAll = async () => { - return this.db.delete(topics).where(eq(topics.userId, this.userId)); + return this.db.delete(topics).where(and(this.ownership())); }; // **************** Update *************** // @@ -849,7 +856,7 @@ export class TopicModel { return this.db .update(topics) .set({ ...data, updatedAt: new Date() }) - .where(and(eq(topics.id, id), eq(topics.userId, this.userId))) + .where(and(eq(topics.id, id), this.ownership())) .returning(); }; @@ -860,7 +867,7 @@ export class TopicModel { * external callers use this wrapper. Runs in a transaction for consistency. */ recomputeUsage = async (id: string) => - this.db.transaction((trx) => recomputeTopicUsage(trx, this.userId, id)); + this.db.transaction((trx) => recomputeTopicUsage(trx, this.userId, id, this.workspaceId)); /** * Update topic metadata with merge logic @@ -870,7 +877,7 @@ export class TopicModel { // Get existing topic to merge metadata const existing = await this.db.query.topics.findFirst({ columns: { metadata: true }, - where: and(eq(topics.id, id), eq(topics.userId, this.userId)), + where: and(eq(topics.id, id), this.ownership()), }); const mergedOnboardingSession = @@ -890,7 +897,7 @@ export class TopicModel { return this.db .update(topics) .set({ metadata: mergedMetadata }) - .where(and(eq(topics.id, id), eq(topics.userId, this.userId))) + .where(and(eq(topics.id, id), this.ownership())) .returning(); }; @@ -902,7 +909,7 @@ export class TopicModel { .from(topics) .where( and( - eq(topics.userId, this.userId), + this.ownership(), eq(topics.agentId, agentId), eq(topics.trigger, 'cron'), sql`(${topics.metadata}->>'cronJobId') IS NOT NULL`, @@ -970,7 +977,7 @@ export class TopicModel { limit: options.limit, orderBy: (fields, { asc }) => [asc(fields.createdAt), asc(fields.id)], where: and( - eq(topics.userId, this.userId), + this.ownership(), options.startDate ? gte(topics.createdAt, options.startDate) : undefined, options.endDate ? lte(topics.createdAt, options.endDate) : undefined, options.ignoreExtracted @@ -996,7 +1003,7 @@ export class TopicModel { .from(topics) .where( and( - eq(topics.userId, this.userId), + this.ownership(), options.startDate ? gte(topics.createdAt, options.startDate) : undefined, options.endDate ? lte(topics.createdAt, options.endDate) : undefined, options.ignoreExtracted diff --git a/packages/database/src/models/topicDocument.ts b/packages/database/src/models/topicDocument.ts index 1e2e9fa1d3..6e0850e201 100644 --- a/packages/database/src/models/topicDocument.ts +++ b/packages/database/src/models/topicDocument.ts @@ -3,6 +3,7 @@ import { and, desc, eq } from 'drizzle-orm'; import type { DocumentItem, NewTopicDocument } from '../schemas'; import { documents, topicDocuments } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; export interface TopicDocumentWithDetails extends DocumentItem { associatedAt: Date; @@ -11,12 +12,17 @@ export interface TopicDocumentWithDetails extends DocumentItem { export class TopicDocumentModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topicDocuments); + /** * Associate a document with a topic. * @@ -30,7 +36,7 @@ export class TopicDocumentModel { ): Promise<{ documentId: string; topicId: string }> => { await this.db .insert(topicDocuments) - .values({ ...params, userId: this.userId }) + .values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null }) .onConflictDoNothing(); return { documentId: params.documentId, topicId: params.topicId }; @@ -46,7 +52,7 @@ export class TopicDocumentModel { and( eq(topicDocuments.documentId, documentId), eq(topicDocuments.topicId, topicId), - eq(topicDocuments.userId, this.userId), + this.ownership(), ), ); }; @@ -68,7 +74,7 @@ export class TopicDocumentModel { .where( and( eq(topicDocuments.topicId, topicId), - eq(topicDocuments.userId, this.userId), + this.ownership(), filter?.type ? eq(documents.fileType, filter.type) : undefined, ), ) @@ -87,9 +93,7 @@ export class TopicDocumentModel { const results = await this.db .select({ topicId: topicDocuments.topicId }) .from(topicDocuments) - .where( - and(eq(topicDocuments.documentId, documentId), eq(topicDocuments.userId, this.userId)), - ); + .where(and(eq(topicDocuments.documentId, documentId), this.ownership())); return results.map((r) => r.topicId); }; @@ -102,7 +106,7 @@ export class TopicDocumentModel { where: and( eq(topicDocuments.documentId, documentId), eq(topicDocuments.topicId, topicId), - eq(topicDocuments.userId, this.userId), + this.ownership(), ), }); @@ -115,7 +119,7 @@ export class TopicDocumentModel { deleteByTopicId = async (topicId: string) => { return this.db .delete(topicDocuments) - .where(and(eq(topicDocuments.topicId, topicId), eq(topicDocuments.userId, this.userId))); + .where(and(eq(topicDocuments.topicId, topicId), this.ownership())); }; /** @@ -124,8 +128,6 @@ export class TopicDocumentModel { deleteByDocumentId = async (documentId: string) => { return this.db .delete(topicDocuments) - .where( - and(eq(topicDocuments.documentId, documentId), eq(topicDocuments.userId, this.userId)), - ); + .where(and(eq(topicDocuments.documentId, documentId), this.ownership())); }; } diff --git a/packages/database/src/models/topicShare.ts b/packages/database/src/models/topicShare.ts index 5a7a5da257..4266cff0f4 100644 --- a/packages/database/src/models/topicShare.ts +++ b/packages/database/src/models/topicShare.ts @@ -4,6 +4,7 @@ import { and, asc, eq, sql } from 'drizzle-orm'; import { agents, chatGroups, chatGroupsAgents, topics, topicShares } from '../schemas'; import type { LobeChatDatabase } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; export type TopicShareData = NonNullable< Awaited> @@ -12,21 +13,29 @@ export type TopicShareData = NonNullable< export class TopicShareModel { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ownership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topicShares); + /** * Create or get existing share for a topic. * Each topic can only have one share record (enforced by unique constraint). * If record already exists, returns the existing one. */ create = async (topicId: string, visibility: ShareVisibility = 'private') => { - // First verify the topic belongs to the user + // First verify the topic belongs to the user (or workspace). const topic = await this.db.query.topics.findFirst({ - where: and(eq(topics.id, topicId), eq(topics.userId, this.userId)), + where: and( + eq(topics.id, topicId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics), + ), }); if (!topic) { @@ -39,6 +48,7 @@ export class TopicShareModel { topicId, userId: this.userId, visibility, + workspaceId: this.workspaceId ?? null, }) .onConflictDoNothing({ target: topicShares.topicId }) .returning(); @@ -58,7 +68,7 @@ export class TopicShareModel { const [result] = await this.db .update(topicShares) .set({ updatedAt: new Date(), visibility }) - .where(and(eq(topicShares.topicId, topicId), eq(topicShares.userId, this.userId))) + .where(and(eq(topicShares.topicId, topicId), this.ownership())) .returning(); return result || null; @@ -70,7 +80,7 @@ export class TopicShareModel { deleteByTopicId = async (topicId: string) => { return this.db .delete(topicShares) - .where(and(eq(topicShares.topicId, topicId), eq(topicShares.userId, this.userId))); + .where(and(eq(topicShares.topicId, topicId), this.ownership())); }; /** @@ -85,7 +95,7 @@ export class TopicShareModel { visibility: topicShares.visibility, }) .from(topicShares) - .where(and(eq(topicShares.topicId, topicId), eq(topicShares.userId, this.userId))) + .where(and(eq(topicShares.topicId, topicId), this.ownership())) .limit(1); return result[0] || null; diff --git a/packages/database/src/models/topicUsage.ts b/packages/database/src/models/topicUsage.ts index 9ef3a2deb0..695a435b12 100644 --- a/packages/database/src/models/topicUsage.ts +++ b/packages/database/src/models/topicUsage.ts @@ -2,6 +2,7 @@ import { and, eq, sql } from 'drizzle-orm'; import { topics } from '../schemas'; import type { Transaction } from '../type'; +import { buildWorkspaceWhere } from '../utils/workspace'; /** * ModelUsage numeric fields summed per (provider, model) to build the @@ -60,6 +61,7 @@ export const recomputeTopicUsage = async ( trx: Transaction, userId: string, topicId: string, + workspaceId?: string, ): Promise => { // Reads prefer the dedicated `usage` column, falling back to legacy // `metadata->'usage'` for rows written before the migration. @@ -67,6 +69,13 @@ export const recomputeTopicUsage = async ( (f) => `sum((COALESCE(usage, metadata->'usage')->>'${f}')::numeric) AS "${f}"`, ).join(',\n '); + // Workspace-aware ownership predicate for the raw messages aggregate: in team + // mode rows are scoped by workspace_id (creator user_id is not part of the + // filter); in personal mode by user_id with workspace_id IS NULL. + const messageOwnership = workspaceId + ? sql`workspace_id = ${workspaceId}` + : sql`user_id = ${userId} AND workspace_id IS NULL`; + const { rows } = await trx.execute(sql` SELECT provider, @@ -77,7 +86,7 @@ export const recomputeTopicUsage = async ( ${sql.raw(fieldSelects)} FROM messages WHERE topic_id = ${topicId} - AND user_id = ${userId} + AND ${messageOwnership} AND role = 'assistant' AND (usage IS NOT NULL OR metadata ? 'usage') GROUP BY provider, model @@ -99,7 +108,7 @@ export const recomputeTopicUsage = async ( totalTokens: null, usage: null, }) - .where(and(eq(topics.id, topicId), eq(topics.userId, userId))); + .where(and(eq(topics.id, topicId), buildWorkspaceWhere({ userId, workspaceId }, topics))); return; } @@ -191,5 +200,5 @@ export const recomputeTopicUsage = async ( totalTokens, usage, }) - .where(and(eq(topics.id, topicId), eq(topics.userId, userId))); + .where(and(eq(topics.id, topicId), buildWorkspaceWhere({ userId, workspaceId }, topics))); }; diff --git a/packages/database/src/models/userMemory/activity.ts b/packages/database/src/models/userMemory/activity.ts index 45ce0988ae..3c1350f68b 100644 --- a/packages/database/src/models/userMemory/activity.ts +++ b/packages/database/src/models/userMemory/activity.ts @@ -16,6 +16,10 @@ export class UserMemoryActivityModel { this.db = db; } + private memoryWhere(table: { userId: any }) { + return eq(table.userId, this.userId); + } + create = async (params: Omit) => { const [result] = await this.db .insert(userMemoriesActivities) @@ -28,10 +32,7 @@ export class UserMemoryActivityModel { delete = async (id: string) => { return this.db.transaction(async (tx) => { const activity = await tx.query.userMemoriesActivities.findFirst({ - where: and( - eq(userMemoriesActivities.id, id), - eq(userMemoriesActivities.userId, this.userId), - ), + where: and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities)), }); if (!activity || !activity.userMemoryId) { @@ -40,25 +41,21 @@ export class UserMemoryActivityModel { await tx .delete(userMemories) - .where( - and(eq(userMemories.id, activity.userMemoryId), eq(userMemories.userId, this.userId)), - ); + .where(and(eq(userMemories.id, activity.userMemoryId), this.memoryWhere(userMemories))); return { success: true }; }); }; deleteAll = async () => { - return this.db - .delete(userMemoriesActivities) - .where(eq(userMemoriesActivities.userId, this.userId)); + return this.db.delete(userMemoriesActivities).where(this.memoryWhere(userMemoriesActivities)); }; query = async (limit = 50) => { return this.db.query.userMemoriesActivities.findMany({ limit, orderBy: [desc(userMemoriesActivities.createdAt)], - where: eq(userMemoriesActivities.userId, this.userId), + where: this.memoryWhere(userMemoriesActivities), }); }; @@ -74,7 +71,7 @@ export class UserMemoryActivityModel { : ''; const conditions: Array = [ - eq(userMemoriesActivities.userId, this.userId), + this.memoryWhere(userMemoriesActivities), normalizedQuery ? sql`(${userMemories.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('title', ${bm25MatchQuery}, conjunction_mode => true)]) OR ${userMemoriesActivities.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('narrative', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('notes', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('feedback', ${bm25MatchQuery}, conjunction_mode => true)]))` : undefined, @@ -108,7 +105,7 @@ export class UserMemoryActivityModel { const joinCondition = and( eq(userMemories.id, userMemoriesActivities.userMemoryId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemories), ); const [rows, totalResult] = await Promise.all([ @@ -151,7 +148,7 @@ export class UserMemoryActivityModel { findById = async (id: string) => { return this.db.query.userMemoriesActivities.findFirst({ - where: and(eq(userMemoriesActivities.id, id), eq(userMemoriesActivities.userId, this.userId)), + where: and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities)), }); }; @@ -159,8 +156,6 @@ export class UserMemoryActivityModel { return this.db .update(userMemoriesActivities) .set({ ...value, updatedAt: new Date() }) - .where( - and(eq(userMemoriesActivities.id, id), eq(userMemoriesActivities.userId, this.userId)), - ); + .where(and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities))); }; } diff --git a/packages/database/src/models/userMemory/context.ts b/packages/database/src/models/userMemory/context.ts index e3fb656bba..242f03f44c 100644 --- a/packages/database/src/models/userMemory/context.ts +++ b/packages/database/src/models/userMemory/context.ts @@ -13,6 +13,10 @@ export class UserMemoryContextModel { this.db = db; } + private memoryWhere(table: { userId: any }) { + return eq(table.userId, this.userId); + } + create = async (params: Omit) => { const [result] = await this.db .insert(userMemoriesContexts) @@ -25,7 +29,7 @@ export class UserMemoryContextModel { delete = async (id: string) => { return this.db.transaction(async (tx) => { const context = await tx.query.userMemoriesContexts.findFirst({ - where: and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId)), + where: and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts)), }); if (!context) { @@ -41,34 +45,34 @@ export class UserMemoryContextModel { for (const memoryId of memoryIds) { await tx .delete(userMemories) - .where(and(eq(userMemories.id, memoryId), eq(userMemories.userId, this.userId))); + .where(and(eq(userMemories.id, memoryId), this.memoryWhere(userMemories))); } } // Delete the context entry await tx .delete(userMemoriesContexts) - .where(and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId))); + .where(and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts))); return { success: true }; }); }; deleteAll = async () => { - return this.db.delete(userMemoriesContexts).where(eq(userMemoriesContexts.userId, this.userId)); + return this.db.delete(userMemoriesContexts).where(this.memoryWhere(userMemoriesContexts)); }; query = async (limit = 50) => { return this.db.query.userMemoriesContexts.findMany({ limit, orderBy: [desc(userMemoriesContexts.createdAt)], - where: eq(userMemoriesContexts.userId, this.userId), + where: this.memoryWhere(userMemoriesContexts), }); }; findById = async (id: string) => { return this.db.query.userMemoriesContexts.findFirst({ - where: and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId)), + where: and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts)), }); }; @@ -76,6 +80,6 @@ export class UserMemoryContextModel { return this.db .update(userMemoriesContexts) .set({ ...value, updatedAt: new Date() }) - .where(and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId))); + .where(and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts))); }; } diff --git a/packages/database/src/models/userMemory/experience.ts b/packages/database/src/models/userMemory/experience.ts index 8588f0cf6e..da4954f79d 100644 --- a/packages/database/src/models/userMemory/experience.ts +++ b/packages/database/src/models/userMemory/experience.ts @@ -16,6 +16,10 @@ export class UserMemoryExperienceModel { this.db = db; } + private memoryWhere(table: { userId: any }) { + return eq(table.userId, this.userId); + } + create = async (params: Omit) => { const [result] = await this.db .insert(userMemoriesExperiences) @@ -28,10 +32,7 @@ export class UserMemoryExperienceModel { delete = async (id: string) => { return this.db.transaction(async (tx) => { const experience = await tx.query.userMemoriesExperiences.findFirst({ - where: and( - eq(userMemoriesExperiences.id, id), - eq(userMemoriesExperiences.userId, this.userId), - ), + where: and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences)), }); if (!experience || !experience.userMemoryId) { @@ -41,25 +42,21 @@ export class UserMemoryExperienceModel { // Delete the base user memory (cascade will handle the experience) await tx .delete(userMemories) - .where( - and(eq(userMemories.id, experience.userMemoryId), eq(userMemories.userId, this.userId)), - ); + .where(and(eq(userMemories.id, experience.userMemoryId), this.memoryWhere(userMemories))); return { success: true }; }); }; deleteAll = async () => { - return this.db - .delete(userMemoriesExperiences) - .where(eq(userMemoriesExperiences.userId, this.userId)); + return this.db.delete(userMemoriesExperiences).where(this.memoryWhere(userMemoriesExperiences)); }; query = async (limit = 50) => { return this.db.query.userMemoriesExperiences.findMany({ limit, orderBy: [desc(userMemoriesExperiences.createdAt)], - where: eq(userMemoriesExperiences.userId, this.userId), + where: this.memoryWhere(userMemoriesExperiences), }); }; @@ -80,7 +77,7 @@ export class UserMemoryExperienceModel { // Build WHERE conditions const conditions: Array = [ - eq(userMemoriesExperiences.userId, this.userId), + this.memoryWhere(userMemoriesExperiences), // Full-text search across title, situation, keyLearning, action normalizedQuery ? sql`(${userMemories.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('title', ${bm25MatchQuery}, conjunction_mode => true)]) OR ${userMemoriesExperiences.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('situation', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('key_learning', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('action', ${bm25MatchQuery}, conjunction_mode => true)]))` @@ -110,7 +107,7 @@ export class UserMemoryExperienceModel { // JOIN condition const joinCondition = and( eq(userMemories.id, userMemoriesExperiences.userMemoryId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemories), ); // Execute queries in parallel @@ -152,10 +149,7 @@ export class UserMemoryExperienceModel { findById = async (id: string) => { return this.db.query.userMemoriesExperiences.findFirst({ - where: and( - eq(userMemoriesExperiences.id, id), - eq(userMemoriesExperiences.userId, this.userId), - ), + where: and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences)), }); }; @@ -163,8 +157,6 @@ export class UserMemoryExperienceModel { return this.db .update(userMemoriesExperiences) .set({ ...value, updatedAt: new Date() }) - .where( - and(eq(userMemoriesExperiences.id, id), eq(userMemoriesExperiences.userId, this.userId)), - ); + .where(and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences))); }; } diff --git a/packages/database/src/models/userMemory/identity.ts b/packages/database/src/models/userMemory/identity.ts index 81fb9340be..4850bebc8e 100644 --- a/packages/database/src/models/userMemory/identity.ts +++ b/packages/database/src/models/userMemory/identity.ts @@ -17,6 +17,10 @@ export class UserMemoryIdentityModel { this.db = db; } + private memoryWhere(table: { userId: any }) { + return eq(table.userId, this.userId); + } + create = async (params: Omit) => { const [result] = await this.db .insert(userMemoriesIdentities) @@ -29,10 +33,7 @@ export class UserMemoryIdentityModel { delete = async (id: string) => { return this.db.transaction(async (tx) => { const identity = await tx.query.userMemoriesIdentities.findFirst({ - where: and( - eq(userMemoriesIdentities.id, id), - eq(userMemoriesIdentities.userId, this.userId), - ), + where: and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities)), }); if (!identity || !identity.userMemoryId) { @@ -42,25 +43,21 @@ export class UserMemoryIdentityModel { // Delete the base user memory (cascade will handle the identity) await tx .delete(userMemories) - .where( - and(eq(userMemories.id, identity.userMemoryId), eq(userMemories.userId, this.userId)), - ); + .where(and(eq(userMemories.id, identity.userMemoryId), this.memoryWhere(userMemories))); return { success: true }; }); }; deleteAll = async () => { - return this.db - .delete(userMemoriesIdentities) - .where(eq(userMemoriesIdentities.userId, this.userId)); + return this.db.delete(userMemoriesIdentities).where(this.memoryWhere(userMemoriesIdentities)); }; query = async (limit = 50) => { return this.db.query.userMemoriesIdentities.findMany({ limit, orderBy: [desc(userMemoriesIdentities.capturedAt)], - where: eq(userMemoriesIdentities.userId, this.userId), + where: this.memoryWhere(userMemoriesIdentities), }); }; @@ -81,7 +78,7 @@ export class UserMemoryIdentityModel { // Build WHERE conditions const conditions: Array = [ - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), // Full-text search across title, description, role normalizedQuery ? sql`(${userMemories.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('title', ${bm25MatchQuery}, conjunction_mode => true)]) OR ${userMemoriesIdentities.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('description', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('role', ${bm25MatchQuery}, conjunction_mode => true)]))` @@ -113,7 +110,7 @@ export class UserMemoryIdentityModel { // JOIN condition const joinCondition = and( eq(userMemories.id, userMemoriesIdentities.userMemoryId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemories), ); // Execute queries in parallel @@ -155,7 +152,7 @@ export class UserMemoryIdentityModel { findById = async (id: string) => { return this.db.query.userMemoriesIdentities.findFirst({ - where: and(eq(userMemoriesIdentities.id, id), eq(userMemoriesIdentities.userId, this.userId)), + where: and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities)), }); }; @@ -163,9 +160,7 @@ export class UserMemoryIdentityModel { return this.db .update(userMemoriesIdentities) .set({ ...value, updatedAt: new Date() }) - .where( - and(eq(userMemoriesIdentities.id, id), eq(userMemoriesIdentities.userId, this.userId)), - ); + .where(and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities))); }; /** @@ -187,7 +182,7 @@ export class UserMemoryIdentityModel { .from(userMemoriesIdentities) .where( and( - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), // Only include self identities (relationship is 'self' or null/not set) or( eq(userMemoriesIdentities.relationship, RelationshipEnum.Self), diff --git a/packages/database/src/models/userMemory/model.ts b/packages/database/src/models/userMemory/model.ts index 1c5da96faf..65e46012b0 100644 --- a/packages/database/src/models/userMemory/model.ts +++ b/packages/database/src/models/userMemory/model.ts @@ -246,7 +246,7 @@ export interface UserMemorySearchAggregatedResult { preferences: UserMemoryPreferenceWithoutVectors[]; } -const pickSingleSearchType = (types?: string[]) => (types?.length === 1 ? types[0] : undefined); +const _pickSingleSearchType = (types?: string[]) => (types?.length === 1 ? types[0] : undefined); export interface UpdateUserMemoryVectorsParams { detailsVector1024?: number[] | null; @@ -555,6 +555,10 @@ export class UserMemoryModel { this.topicModel = new TopicModel(db, userId); } + private memoryWhere(table: { userId: any }) { + return eq(table.userId, this.userId); + } + private extractSourceMetadata(metadata?: Record | null): { sourceId?: string; sourceType?: MemorySourceType; @@ -830,7 +834,7 @@ export class UserMemoryModel { const { layers, page = 1, size = 10 } = params; const offset = (page - 1) * size; - const conditions = [eq(userMemories.userId, this.userId)]; + const conditions = [this.memoryWhere(userMemories)]; if (layers && layers.length > 0) { conditions.push(inArray(userMemories.memoryLayer, layers)); } @@ -867,7 +871,7 @@ export class UserMemoryModel { const offset = (page - 1) * size; const identityConditions = [ - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), eq(userMemoriesIdentities.relationship, RelationshipEnum.Self), ]; @@ -976,7 +980,7 @@ export class UserMemoryModel { const supportsBm25 = !isPGliteDatabase(this.db); const conditions: Array = [ - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemories), categories && categories.length > 0 ? inArray(userMemories.memoryCategory, categories) : undefined, @@ -1142,7 +1146,7 @@ export class UserMemoryModel { ); const joinCondition = and( eq(userMemories.id, userMemoriesActivities.userMemoryId), - eq(userMemoriesActivities.userId, this.userId), + this.memoryWhere(userMemoriesActivities), ); const activityFilters: Array = [ @@ -1257,7 +1261,7 @@ export class UserMemoryModel { ); const joinCondition = and( eq(userMemories.id, userMemoriesExperiences.userMemoryId), - eq(userMemoriesExperiences.userId, this.userId), + this.memoryWhere(userMemoriesExperiences), ); const experienceFilters: Array = [ @@ -1352,7 +1356,7 @@ export class UserMemoryModel { ); const joinCondition = and( eq(userMemories.id, userMemoriesIdentities.userMemoryId), - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), ); const identityFilters: Array = [ @@ -1450,7 +1454,7 @@ export class UserMemoryModel { ); const joinCondition = and( eq(userMemories.id, userMemoriesPreferences.userMemoryId), - eq(userMemoriesPreferences.userId, this.userId), + this.memoryWhere(userMemoriesPreferences), ); const preferenceFilters: Array = [ @@ -1585,7 +1589,7 @@ export class UserMemoryModel { const activitySelection = selectNonVectorColumns(userMemoriesActivities); const baseConditions: Array = [ - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemories), eq(userMemories.memoryLayer, layer), ]; const baseWhere = baseConditions.filter(Boolean) as SQL[]; @@ -1638,7 +1642,7 @@ export class UserMemoryModel { ); const joinCondition = and( eq(userMemories.id, userMemoriesExperiences.userMemoryId), - eq(userMemoriesExperiences.userId, this.userId), + this.memoryWhere(userMemoriesExperiences), ); const experienceFilters: Array = [ @@ -1671,7 +1675,7 @@ export class UserMemoryModel { ); const joinCondition = and( eq(userMemories.id, userMemoriesIdentities.userMemoryId), - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), ); const identityFilters: Array = [ @@ -1704,7 +1708,7 @@ export class UserMemoryModel { ); const joinCondition = and( eq(userMemories.id, userMemoriesPreferences.userMemoryId), - eq(userMemoriesPreferences.userId, this.userId), + this.memoryWhere(userMemoriesPreferences), ); const preferenceFilters: Array = [ @@ -1763,7 +1767,7 @@ export class UserMemoryModel { userMemoryIds: userMemoriesContexts.userMemoryIds, }) .from(userMemoriesContexts) - .where(and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId))) + .where(and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts))) .limit(1); if (!context) { return undefined; @@ -1822,9 +1826,7 @@ export class UserMemoryModel { userMemoryId: userMemoriesActivities.userMemoryId, }) .from(userMemoriesActivities) - .where( - and(eq(userMemoriesActivities.id, id), eq(userMemoriesActivities.userId, this.userId)), - ) + .where(and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities))) .limit(1); if (!activity?.userMemoryId) { return undefined; @@ -1869,12 +1871,7 @@ export class UserMemoryModel { userMemoryId: userMemoriesExperiences.userMemoryId, }) .from(userMemoriesExperiences) - .where( - and( - eq(userMemoriesExperiences.id, id), - eq(userMemoriesExperiences.userId, this.userId), - ), - ) + .where(and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences))) .limit(1); if (!experience?.userMemoryId) { return undefined; @@ -1918,9 +1915,7 @@ export class UserMemoryModel { userMemoryId: userMemoriesIdentities.userMemoryId, }) .from(userMemoriesIdentities) - .where( - and(eq(userMemoriesIdentities.id, id), eq(userMemoriesIdentities.userId, this.userId)), - ) + .where(and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities))) .limit(1); if (!identity?.userMemoryId) { return undefined; @@ -1963,12 +1958,7 @@ export class UserMemoryModel { userMemoryId: userMemoriesPreferences.userMemoryId, }) .from(userMemoriesPreferences) - .where( - and( - eq(userMemoriesPreferences.id, id), - eq(userMemoriesPreferences.userId, this.userId), - ), - ) + .where(and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences))) .limit(1); if (!preference?.userMemoryId) { return undefined; @@ -2020,7 +2010,7 @@ export class UserMemoryModel { userId: userMemories.userId, }) .from(userMemories) - .where(and(eq(userMemories.id, memoryId), eq(userMemories.userId, this.userId))) + .where(and(eq(userMemories.id, memoryId), this.memoryWhere(userMemories))) .limit(1); if (!memory) { return undefined; @@ -2031,7 +2021,7 @@ export class UserMemoryModel { findById = async (id: string): Promise => { const result = await this.db.query.userMemories.findFirst({ - where: and(eq(userMemories.id, id), eq(userMemories.userId, this.userId)), + where: and(eq(userMemories.id, id), this.memoryWhere(userMemories)), }); if (result) { @@ -2045,7 +2035,7 @@ export class UserMemoryModel { await this.db .update(userMemories) .set({ ...params, updatedAt: new Date() }) - .where(and(eq(userMemories.id, id), eq(userMemories.userId, this.userId))); + .where(and(eq(userMemories.id, id), this.memoryWhere(userMemories))); }; updateUserMemoryVectors = async ( @@ -2070,7 +2060,7 @@ export class UserMemoryModel { ...vectorUpdates, updatedAt: new Date(), }) - .where(and(eq(userMemories.id, id), eq(userMemories.userId, this.userId))); + .where(and(eq(userMemories.id, id), this.memoryWhere(userMemories))); }; updateContextVectors = async (id: string, vectors: UpdateContextVectorsParams): Promise => { @@ -2088,7 +2078,7 @@ export class UserMemoryModel { ...vectorUpdates, updatedAt: new Date(), }) - .where(and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId))); + .where(and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts))); }; updatePreferenceVectors = async ( @@ -2110,9 +2100,7 @@ export class UserMemoryModel { ...vectorUpdates, updatedAt: new Date(), }) - .where( - and(eq(userMemoriesPreferences.id, id), eq(userMemoriesPreferences.userId, this.userId)), - ); + .where(and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences))); }; updateIdentityVectors = async ( @@ -2134,9 +2122,7 @@ export class UserMemoryModel { ...vectorUpdates, updatedAt: new Date(), }) - .where( - and(eq(userMemoriesIdentities.id, id), eq(userMemoriesIdentities.userId, this.userId)), - ); + .where(and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities))); }; updateExperienceVectors = async ( @@ -2164,9 +2150,7 @@ export class UserMemoryModel { ...vectorUpdates, updatedAt: new Date(), }) - .where( - and(eq(userMemoriesExperiences.id, id), eq(userMemoriesExperiences.userId, this.userId)), - ); + .where(and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences))); }; updateActivityVectors = async ( @@ -2191,9 +2175,7 @@ export class UserMemoryModel { ...vectorUpdates, updatedAt: new Date(), }) - .where( - and(eq(userMemoriesActivities.id, id), eq(userMemoriesActivities.userId, this.userId)), - ); + .where(and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities))); }; addIdentityEntry = async (params: AddIdentityEntryParams): Promise => { @@ -2269,7 +2251,7 @@ export class UserMemoryModel { const identity = await tx.query.userMemoriesIdentities.findFirst({ where: and( eq(userMemoriesIdentities.id, params.identityId), - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), ), }); if (!identity || !identity.userMemoryId) { @@ -2290,9 +2272,7 @@ export class UserMemoryModel { await tx .update(userMemories) .set(baseUpdate) - .where( - and(eq(userMemories.id, identity.userMemoryId), eq(userMemories.userId, this.userId)), - ); + .where(and(eq(userMemories.id, identity.userMemoryId), this.memoryWhere(userMemories))); } } @@ -2356,7 +2336,7 @@ export class UserMemoryModel { .where( and( eq(userMemoriesIdentities.id, params.identityId), - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), ), ); } @@ -2371,7 +2351,7 @@ export class UserMemoryModel { const identity = await tx.query.userMemoriesIdentities.findFirst({ where: and( eq(userMemoriesIdentities.id, identityId), - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), ), }); @@ -2381,9 +2361,7 @@ export class UserMemoryModel { await tx .delete(userMemories) - .where( - and(eq(userMemories.id, identity.userMemoryId), eq(userMemories.userId, this.userId)), - ); + .where(and(eq(userMemories.id, identity.userMemoryId), this.memoryWhere(userMemories))); return true; }); @@ -2392,10 +2370,7 @@ export class UserMemoryModel { removeContextEntry = async (contextId: string): Promise => { return this.db.transaction(async (tx) => { const context = await tx.query.userMemoriesContexts.findFirst({ - where: and( - eq(userMemoriesContexts.id, contextId), - eq(userMemoriesContexts.userId, this.userId), - ), + where: and(eq(userMemoriesContexts.id, contextId), this.memoryWhere(userMemoriesContexts)), }); if (!context) { @@ -2409,15 +2384,13 @@ export class UserMemoryModel { if (memoryIds.length > 0) { await tx .delete(userMemories) - .where(and(inArray(userMemories.id, memoryIds), eq(userMemories.userId, this.userId))); + .where(and(inArray(userMemories.id, memoryIds), this.memoryWhere(userMemories))); } // Delete the context entry await tx .delete(userMemoriesContexts) - .where( - and(eq(userMemoriesContexts.id, contextId), eq(userMemoriesContexts.userId, this.userId)), - ); + .where(and(eq(userMemoriesContexts.id, contextId), this.memoryWhere(userMemoriesContexts))); return true; }); @@ -2428,7 +2401,7 @@ export class UserMemoryModel { const experience = await tx.query.userMemoriesExperiences.findFirst({ where: and( eq(userMemoriesExperiences.id, experienceId), - eq(userMemoriesExperiences.userId, this.userId), + this.memoryWhere(userMemoriesExperiences), ), }); @@ -2439,9 +2412,7 @@ export class UserMemoryModel { // Delete the base user memory (cascade will handle the experience) await tx .delete(userMemories) - .where( - and(eq(userMemories.id, experience.userMemoryId), eq(userMemories.userId, this.userId)), - ); + .where(and(eq(userMemories.id, experience.userMemoryId), this.memoryWhere(userMemories))); return true; }); @@ -2452,7 +2423,7 @@ export class UserMemoryModel { const preference = await tx.query.userMemoriesPreferences.findFirst({ where: and( eq(userMemoriesPreferences.id, preferenceId), - eq(userMemoriesPreferences.userId, this.userId), + this.memoryWhere(userMemoriesPreferences), ), }); @@ -2463,9 +2434,7 @@ export class UserMemoryModel { // Delete the base user memory (cascade will handle the preference) await tx .delete(userMemories) - .where( - and(eq(userMemories.id, preference.userMemoryId), eq(userMemories.userId, this.userId)), - ); + .where(and(eq(userMemories.id, preference.userMemoryId), this.memoryWhere(userMemories))); return true; }); @@ -2474,11 +2443,11 @@ export class UserMemoryModel { delete = async (id: string): Promise => { await this.db .delete(userMemories) - .where(and(eq(userMemories.id, id), eq(userMemories.userId, this.userId))); + .where(and(eq(userMemories.id, id), this.memoryWhere(userMemories))); }; deleteAll = async (): Promise => { - await this.db.delete(userMemories).where(eq(userMemories.userId, this.userId)); + await this.db.delete(userMemories).where(this.memoryWhere(userMemories)); }; searchActivities = async (params: { @@ -2520,7 +2489,7 @@ export class UserMemoryModel { .from(userMemoriesActivities) .$dynamic(); - const conditions = [eq(userMemoriesActivities.userId, this.userId)]; + const conditions = [this.memoryWhere(userMemoriesActivities)]; if (type) { conditions.push(eq(userMemoriesActivities.type, type)); } @@ -2571,7 +2540,7 @@ export class UserMemoryModel { .from(userMemoriesContexts) .$dynamic(); - const conditions = [eq(userMemoriesContexts.userId, this.userId)]; + const conditions = [this.memoryWhere(userMemoriesContexts)]; if (type) { conditions.push(eq(userMemoriesContexts.type, type)); } @@ -2622,7 +2591,7 @@ export class UserMemoryModel { .from(userMemoriesExperiences) .$dynamic(); - const conditions = [eq(userMemoriesExperiences.userId, this.userId)]; + const conditions = [this.memoryWhere(userMemoriesExperiences)]; if (type) { conditions.push(eq(userMemoriesExperiences.type, type)); } @@ -2669,7 +2638,7 @@ export class UserMemoryModel { .from(userMemoriesPreferences) .$dynamic(); - const conditions = [eq(userMemoriesPreferences.userId, this.userId)]; + const conditions = [this.memoryWhere(userMemoriesPreferences)]; if (type) { conditions.push(eq(userMemoriesPreferences.type, type)); } @@ -2688,7 +2657,7 @@ export class UserMemoryModel { const res = await this.db .select(selectNonVectorColumns(userMemoriesIdentities)) .from(userMemoriesIdentities) - .where(eq(userMemoriesIdentities.userId, this.userId)) + .where(this.memoryWhere(userMemoriesIdentities)) .orderBy(desc(userMemoriesIdentities.capturedAt), desc(userMemoriesIdentities.createdAt)); return res; @@ -2702,7 +2671,7 @@ export class UserMemoryModel { }) .from(userMemoriesIdentities) .innerJoin(userMemories, eq(userMemories.id, userMemoriesIdentities.userMemoryId)) - .where(eq(userMemoriesIdentities.userId, this.userId)) + .where(this.memoryWhere(userMemoriesIdentities)) .orderBy(desc(userMemoriesIdentities.capturedAt), desc(userMemoriesIdentities.createdAt)); return res; @@ -2712,9 +2681,7 @@ export class UserMemoryModel { const res = await this.db .select(selectNonVectorColumns(userMemoriesIdentities)) .from(userMemoriesIdentities) - .where( - and(eq(userMemoriesIdentities.userId, this.userId), eq(userMemoriesIdentities.type, type)), - ) + .where(and(this.memoryWhere(userMemoriesIdentities), eq(userMemoriesIdentities.type, type))) .orderBy(desc(userMemoriesIdentities.capturedAt), desc(userMemoriesIdentities.createdAt)); return res; @@ -2744,7 +2711,7 @@ export class UserMemoryModel { accessedCount: sql`${userMemories.accessedCount} + 1`, lastAccessedAt: now, }) - .where(and(eq(userMemories.userId, this.userId), eq(userMemories.id, memoryId))); + .where(and(this.memoryWhere(userMemories), eq(userMemories.id, memoryId))); } const memories = await tx @@ -2753,9 +2720,7 @@ export class UserMemoryModel { layer: userMemories.memoryLayer, }) .from(userMemories) - .where( - and(eq(userMemories.userId, this.userId), inArray(userMemories.id, orderedMemoryIds)), - ); + .where(and(this.memoryWhere(userMemories), inArray(userMemories.id, orderedMemoryIds))); const experienceIds = memories .filter((memory) => memory.layer === 'experience') @@ -2766,7 +2731,7 @@ export class UserMemoryModel { .set({ accessedAt: now }) .where( and( - eq(userMemoriesExperiences.userId, this.userId), + this.memoryWhere(userMemoriesExperiences), inArray(userMemoriesExperiences.userMemoryId, experienceIds), ), ); @@ -2781,7 +2746,7 @@ export class UserMemoryModel { .set({ accessedAt: now }) .where( and( - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), inArray(userMemoriesIdentities.userMemoryId, identityIds), ), ); @@ -2796,7 +2761,7 @@ export class UserMemoryModel { .set({ accessedAt: now }) .where( and( - eq(userMemoriesPreferences.userId, this.userId), + this.memoryWhere(userMemoriesPreferences), inArray(userMemoriesPreferences.userMemoryId, preferenceIds), ), ); @@ -2809,7 +2774,7 @@ export class UserMemoryModel { .set({ accessedAt: now }) .where( and( - eq(userMemoriesContexts.userId, this.userId), + this.memoryWhere(userMemoriesContexts), inArray(userMemoriesContexts.id, orderedContextIds), ), ); diff --git a/packages/database/src/models/userMemory/preference.ts b/packages/database/src/models/userMemory/preference.ts index 2d486ea51f..1cb689b60c 100644 --- a/packages/database/src/models/userMemory/preference.ts +++ b/packages/database/src/models/userMemory/preference.ts @@ -13,6 +13,10 @@ export class UserMemoryPreferenceModel { this.db = db; } + private memoryWhere(table: { userId: any }) { + return eq(table.userId, this.userId); + } + create = async (params: Omit) => { const [result] = await this.db .insert(userMemoriesPreferences) @@ -25,10 +29,7 @@ export class UserMemoryPreferenceModel { delete = async (id: string) => { return this.db.transaction(async (tx) => { const preference = await tx.query.userMemoriesPreferences.findFirst({ - where: and( - eq(userMemoriesPreferences.id, id), - eq(userMemoriesPreferences.userId, this.userId), - ), + where: and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences)), }); if (!preference || !preference.userMemoryId) { @@ -38,34 +39,27 @@ export class UserMemoryPreferenceModel { // Delete the base user memory (cascade will handle the preference) await tx .delete(userMemories) - .where( - and(eq(userMemories.id, preference.userMemoryId), eq(userMemories.userId, this.userId)), - ); + .where(and(eq(userMemories.id, preference.userMemoryId), this.memoryWhere(userMemories))); return { success: true }; }); }; deleteAll = async () => { - return this.db - .delete(userMemoriesPreferences) - .where(eq(userMemoriesPreferences.userId, this.userId)); + return this.db.delete(userMemoriesPreferences).where(this.memoryWhere(userMemoriesPreferences)); }; query = async (limit = 50) => { return this.db.query.userMemoriesPreferences.findMany({ limit, orderBy: [desc(userMemoriesPreferences.createdAt)], - where: eq(userMemoriesPreferences.userId, this.userId), + where: this.memoryWhere(userMemoriesPreferences), }); }; findById = async (id: string) => { return this.db.query.userMemoriesPreferences.findFirst({ - where: and( - eq(userMemoriesPreferences.id, id), - eq(userMemoriesPreferences.userId, this.userId), - ), + where: and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences)), }); }; @@ -73,8 +67,6 @@ export class UserMemoryPreferenceModel { return this.db .update(userMemoriesPreferences) .set({ ...value, updatedAt: new Date() }) - .where( - and(eq(userMemoriesPreferences.id, id), eq(userMemoriesPreferences.userId, this.userId)), - ); + .where(and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences))); }; } diff --git a/packages/database/src/models/userMemory/query.ts b/packages/database/src/models/userMemory/query.ts index f065a0a13c..5aeb844fe0 100644 --- a/packages/database/src/models/userMemory/query.ts +++ b/packages/database/src/models/userMemory/query.ts @@ -664,6 +664,10 @@ export class UserMemoryQueryModel { private readonly userId: string, ) {} + private memoryWhere(table: { userId: any }) { + return eq(table.userId, this.userId); + } + /** * Hybrid memory retrieval pipeline for the five heterogeneous memory layers. * @@ -1118,7 +1122,7 @@ export class UserMemoryQueryModel { updatedAt: userMemories.updatedAt, }) .from(userMemories) - .where(and(eq(userMemories.userId, this.userId), inArray(userMemories.id, memoryIds))); + .where(and(this.memoryWhere(userMemories), inArray(userMemories.id, memoryIds))); const baseMemoryMap = new Map( baseMemories.map((memory) => [ @@ -1215,7 +1219,7 @@ export class UserMemoryQueryModel { }): Promise { const { column, layers, limit, q, timeRange } = params; const conditions = [ - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemories), layers?.length ? inArray(userMemories.memoryLayer, layers) : undefined, this.buildTimeRangeCondition( { @@ -1255,7 +1259,7 @@ export class UserMemoryQueryModel { }): Promise { const { column, layers, limit, q, timeRange } = params; const conditions = [ - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemories), layers?.length ? inArray(userMemories.memoryLayer, layers) : undefined, this.buildTimeRangeCondition( { @@ -1557,7 +1561,7 @@ export class UserMemoryQueryModel { .from(userMemoriesIdentities) .where( and( - eq(userMemoriesIdentities.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), this.buildTimeRangeCondition( { capturedAt: userMemoriesIdentities.capturedAt, @@ -1765,8 +1769,8 @@ export class UserMemoryQueryModel { params: SearchMemoryParams, ) { const conditions = [ - eq(userMemoriesActivities.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesActivities), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -1825,8 +1829,8 @@ export class UserMemoryQueryModel { params: SearchMemoryParams, ) { const conditions = [ - eq(userMemoriesContexts.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesContexts), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -1925,8 +1929,8 @@ export class UserMemoryQueryModel { params: SearchMemoryParams, ) { const conditions = [ - eq(userMemoriesExperiences.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesExperiences), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -1978,8 +1982,8 @@ export class UserMemoryQueryModel { params: SearchMemoryParams, ) { const conditions = [ - eq(userMemoriesPreferences.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesPreferences), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -2030,8 +2034,8 @@ export class UserMemoryQueryModel { params: SearchMemoryParams, ): Promise { const conditions = [ - eq(userMemoriesIdentities.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -2086,8 +2090,8 @@ export class UserMemoryQueryModel { ) { const normalizedQuery = typeof query === 'string' ? query.trim() : ''; const conditions = [ - eq(userMemoriesActivities.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesActivities), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -2154,8 +2158,8 @@ export class UserMemoryQueryModel { ) { const normalizedQuery = typeof query === 'string' ? query.trim() : ''; const conditions = [ - eq(userMemoriesContexts.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesContexts), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -2258,8 +2262,8 @@ export class UserMemoryQueryModel { ) { const normalizedQuery = typeof query === 'string' ? query.trim() : ''; const conditions = [ - eq(userMemoriesExperiences.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesExperiences), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -2319,8 +2323,8 @@ export class UserMemoryQueryModel { ) { const normalizedQuery = typeof query === 'string' ? query.trim() : ''; const conditions = [ - eq(userMemoriesPreferences.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesPreferences), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, @@ -2377,8 +2381,8 @@ export class UserMemoryQueryModel { ) { const normalizedQuery = typeof query === 'string' ? query.trim() : ''; const conditions = [ - eq(userMemoriesIdentities.userId, this.userId), - eq(userMemories.userId, this.userId), + this.memoryWhere(userMemoriesIdentities), + this.memoryWhere(userMemories), params.categories?.length ? inArray(userMemories.memoryCategory, params.categories) : undefined, diff --git a/packages/database/src/models/workspace.ts b/packages/database/src/models/workspace.ts new file mode 100644 index 0000000000..2a5ba0ba20 --- /dev/null +++ b/packages/database/src/models/workspace.ts @@ -0,0 +1,334 @@ +import { and, count, desc, eq, isNull, ne } from 'drizzle-orm'; + +import { + type NewWorkspace, + type WorkspaceItem, + workspaceMembers, + workspaces, +} from '../schemas/workspace'; +import type { LobeChatDatabase } from '../type'; + +export class WorkspaceModel { + protected readonly db: LobeChatDatabase; + protected readonly userId: string; + + constructor(db: LobeChatDatabase, userId: string) { + this.db = db; + this.userId = userId; + } + + create = async (params: { + avatar?: string; + description?: string; + name: string; + slug: string; + }) => { + return this.db.transaction(async (tx) => { + const [workspace] = await tx + .insert(workspaces) + .values({ + avatar: params.avatar, + description: params.description, + name: params.name, + primaryOwnerId: this.userId, + slug: params.slug, + } satisfies NewWorkspace) + .returning(); + + await tx.insert(workspaceMembers).values({ + role: 'owner', + userId: this.userId, + workspaceId: workspace.id, + }); + + return workspace; + }); + }; + + delete = async (id: string) => { + return this.db + .delete(workspaces) + .where(and(eq(workspaces.id, id), eq(workspaces.primaryOwnerId, this.userId))); + }; + + findById = async (id: string) => { + return this.db.query.workspaces.findFirst({ + where: eq(workspaces.id, id), + }); + }; + + findBySlug = async (slug: string) => { + return this.db.query.workspaces.findFirst({ + where: eq(workspaces.slug, slug), + }); + }; + + /** + * List ids of workspaces where this user is the primary (Stripe-bound) owner. + * Cloud callers combine with subscription-status data to enforce the Free + * workspace cap; OSS callers can use the raw count. + */ + listOwnedWorkspaceIds = async (): Promise => { + const owned = await this.db.query.workspaces.findMany({ + columns: { id: true }, + where: eq(workspaces.primaryOwnerId, this.userId), + }); + return owned.map((w) => w.id); + }; + + getSettings = async (id: string) => { + const workspace = await this.db.query.workspaces.findFirst({ + columns: { settings: true }, + where: eq(workspaces.id, id), + }); + return workspace?.settings ?? {}; + }; + + /** + * Count every workspace this user belongs to — owned + joined. Reads the + * membership table directly because owners are always inserted as members on + * `create`, so a single count covers both shapes. + */ + countUserMemberships = async (): Promise => { + const result = await this.db + .select({ count: count() }) + .from(workspaceMembers) + .where(and(eq(workspaceMembers.userId, this.userId), isNull(workspaceMembers.deletedAt))); + return result[0]?.count ?? 0; + }; + + listUserWorkspaces = async () => { + const memberships = await this.db.query.workspaceMembers.findMany({ + where: and(eq(workspaceMembers.userId, this.userId), isNull(workspaceMembers.deletedAt)), + }); + + if (memberships.length === 0) return []; + + const workspaceIds = memberships.map((m) => m.workspaceId); + + const results = await this.db.query.workspaces.findMany({ + orderBy: [desc(workspaces.updatedAt)], + where: (ws, { inArray }) => inArray(ws.id, workspaceIds), + }); + + return results.map((ws) => ({ + ...ws, + role: memberships.find((m) => m.workspaceId === ws.id)?.role ?? 'viewer', + })); + }; + + update = async ( + id: string, + value: Partial>, + ) => { + return this.db + .update(workspaces) + .set({ ...value, updatedAt: new Date() }) + .where(eq(workspaces.id, id)); + }; + + updateSettings = async (id: string, settings: Record) => { + return this.db + .update(workspaces) + .set({ settings, updatedAt: new Date() }) + .where(eq(workspaces.id, id)); + }; + + /** + * Transfer the Stripe binding (primary owner) to another existing `owner` + * member. Both users keep role='owner' afterwards — only the Stripe binding + * moves. Use `promoteToOwner` first if the target isn't already an owner. + */ + transferPrimaryOwnership = async (id: string, newPrimaryOwnerUserId: string) => { + if (newPrimaryOwnerUserId === this.userId) + throw new Error('New primary owner must be a different user'); + + return this.db.transaction(async (tx) => { + const current = await tx.query.workspaces.findFirst({ + where: eq(workspaces.id, id), + }); + + if (!current) throw new Error('Workspace not found'); + if (current.primaryOwnerId !== this.userId) + throw new Error('Only the primary owner can transfer primary ownership'); + + const targetMembership = await tx.query.workspaceMembers.findFirst({ + where: and( + eq(workspaceMembers.workspaceId, id), + eq(workspaceMembers.userId, newPrimaryOwnerUserId), + isNull(workspaceMembers.deletedAt), + ), + }); + if (!targetMembership) + throw new Error('Target user must already be a member of the workspace'); + if (targetMembership.role !== 'owner') + throw new Error('Target user must already be an owner — promote them first'); + + await tx + .update(workspaces) + .set({ primaryOwnerId: newPrimaryOwnerUserId, updatedAt: new Date() }) + .where(eq(workspaces.id, id)); + + return { + newPrimaryOwnerUserId, + previousPrimaryOwnerUserId: this.userId, + workspaceId: id, + }; + }); + }; + + promoteToOwner = async (id: string, targetUserId: string) => { + return this.db.transaction(async (tx) => { + const actor = await tx.query.workspaceMembers.findFirst({ + where: and( + eq(workspaceMembers.workspaceId, id), + eq(workspaceMembers.userId, this.userId), + isNull(workspaceMembers.deletedAt), + ), + }); + if (actor?.role !== 'owner') + throw new Error('Only an owner can promote other members to owner'); + + const target = await tx.query.workspaceMembers.findFirst({ + where: and( + eq(workspaceMembers.workspaceId, id), + eq(workspaceMembers.userId, targetUserId), + isNull(workspaceMembers.deletedAt), + ), + }); + if (!target) throw new Error('Target user is not a member of this workspace'); + if (target.role === 'owner') return target; + + await tx + .update(workspaceMembers) + .set({ role: 'owner' }) + .where( + and(eq(workspaceMembers.workspaceId, id), eq(workspaceMembers.userId, targetUserId)), + ); + + return { ...target, role: 'owner' }; + }); + }; + + demoteFromOwner = async (id: string, targetUserId: string) => { + return this.db.transaction(async (tx) => { + const workspace = await tx.query.workspaces.findFirst({ + where: eq(workspaces.id, id), + }); + if (!workspace) throw new Error('Workspace not found'); + if (workspace.primaryOwnerId === targetUserId) + throw new Error( + 'Cannot demote the primary owner — transfer primary ownership to another owner first', + ); + + const actor = await tx.query.workspaceMembers.findFirst({ + where: and( + eq(workspaceMembers.workspaceId, id), + eq(workspaceMembers.userId, this.userId), + isNull(workspaceMembers.deletedAt), + ), + }); + if (actor?.role !== 'owner') throw new Error('Only an owner can demote other owners'); + + const target = await tx.query.workspaceMembers.findFirst({ + where: and( + eq(workspaceMembers.workspaceId, id), + eq(workspaceMembers.userId, targetUserId), + isNull(workspaceMembers.deletedAt), + ), + }); + if (!target) throw new Error('Target user is not a member of this workspace'); + if (target.role !== 'owner') return target; + + await tx + .update(workspaceMembers) + .set({ role: 'member' }) + .where( + and(eq(workspaceMembers.workspaceId, id), eq(workspaceMembers.userId, targetUserId)), + ); + + return { ...target, role: 'member' }; + }); + }; + + countOtherOwners = async (workspaceId: string, excludeUserId: string): Promise => { + const result = await this.db + .select({ count: count() }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, workspaceId), + eq(workspaceMembers.role, 'owner'), + ne(workspaceMembers.userId, excludeUserId), + isNull(workspaceMembers.deletedAt), + ), + ); + return result[0]?.count ?? 0; + }; + + /** + * Demote the workspace to single-owner: remove every non-owner member and + * clear the grace-period marker. Called when a subscription is cancelled. + * Workspace-scoped resources (agents/sessions/etc.) stay attached to the + * workspace and remain accessible to the primary owner. + */ + downgradeToSolo = async (id: string) => { + return this.db.transaction(async (tx) => { + const current = await tx.query.workspaces.findFirst({ + where: eq(workspaces.id, id), + }); + + if (!current) throw new Error('Workspace not found'); + if (current.primaryOwnerId !== this.userId) + throw new Error('Only the primary owner can downgrade this workspace'); + + const removedMembers = await tx + .update(workspaceMembers) + .set({ deletedAt: new Date() }) + .where( + and( + eq(workspaceMembers.workspaceId, id), + ne(workspaceMembers.userId, current.primaryOwnerId), + isNull(workspaceMembers.deletedAt), + ), + ) + .returning(); + + const currentSettings = (current.settings as Record | null) ?? {}; + const { gracePeriodUntil: _drop, ...restSettings } = currentSettings; + + const [updated] = await tx + .update(workspaces) + .set({ + settings: restSettings, + updatedAt: new Date(), + }) + .where(eq(workspaces.id, id)) + .returning(); + + return { + removedUserIds: removedMembers.map((m) => m.userId), + workspace: updated, + }; + }); + }; + + setGracePeriod = async (id: string, gracePeriodUntil: number | null) => { + const current = await this.db.query.workspaces.findFirst({ + columns: { settings: true }, + where: eq(workspaces.id, id), + }); + if (!current) throw new Error('Workspace not found'); + + const prev = (current.settings as Record | null) ?? {}; + const next = + gracePeriodUntil === null + ? Object.fromEntries(Object.entries(prev).filter(([k]) => k !== 'gracePeriodUntil')) + : { ...prev, gracePeriodUntil }; + + await this.db + .update(workspaces) + .set({ settings: next, updatedAt: new Date() }) + .where(eq(workspaces.id, id)); + }; +} diff --git a/packages/database/src/models/workspaceAuditLog.ts b/packages/database/src/models/workspaceAuditLog.ts new file mode 100644 index 0000000000..8de55cbf85 --- /dev/null +++ b/packages/database/src/models/workspaceAuditLog.ts @@ -0,0 +1,99 @@ +import { and, desc, eq, gte, lt, lte } from 'drizzle-orm'; + +import { workspaceAuditLogs } from '../schemas/workspace'; +import type { LobeChatDatabase } from '../type'; + +export type WorkspaceAuditAction = + | 'workspace.created' + | 'workspace.updated' + | 'workspace.upgraded' + | 'workspace.downgraded' + | 'workspace.primary_ownership_transferred' + | 'workspace.deleted' + | 'workspace.cleanup_triggered' + | 'workspace.account_upgraded' + | 'workspace.data_cleared' + | 'workspace.settings_reset' + | 'member.invited' + | 'member.removed' + | 'member.role_updated' + | 'member.joined' + | 'member.left' + | 'member.promoted_to_owner' + | 'member.demoted_from_owner' + | 'invitation.revoked' + | 'invitation.resent' + | 'subscription.activated' + | 'subscription.updated' + | 'subscription.cancelled' + | 'subscription.cancellation_scheduled' + | 'subscription.cancellation_resumed' + | 'subscription.grace_period_started' + | 'billing.portal_session_created' + | 'billing.payment_method_added' + | 'billing.payment_method_removed' + | 'billing.default_payment_method_changed'; + +interface CreateAuditLogParams { + action: WorkspaceAuditAction; + ipAddress?: string; + metadata?: Record; + resourceId?: string; + resourceType?: string; + userId: string | null; + workspaceId: string; +} + +interface ListAuditLogParams { + action?: WorkspaceAuditAction; + cursor?: Date; + endDate?: Date; + limit?: number; + startDate?: Date; + workspaceId: string; +} + +export class WorkspaceAuditLogModel { + private readonly db: LobeChatDatabase; + + constructor(db: LobeChatDatabase) { + this.db = db; + } + + create = async (params: CreateAuditLogParams) => { + const [row] = await this.db + .insert(workspaceAuditLogs) + .values({ + action: params.action, + ipAddress: params.ipAddress, + metadata: params.metadata ?? {}, + resourceId: params.resourceId, + resourceType: params.resourceType, + userId: params.userId, + workspaceId: params.workspaceId, + }) + .returning(); + return row; + }; + + list = async (params: ListAuditLogParams) => { + const { workspaceId, action, startDate, endDate, cursor, limit = 50 } = params; + const conditions = [eq(workspaceAuditLogs.workspaceId, workspaceId)]; + if (action) conditions.push(eq(workspaceAuditLogs.action, action)); + if (startDate) conditions.push(gte(workspaceAuditLogs.createdAt, startDate)); + if (endDate) conditions.push(lte(workspaceAuditLogs.createdAt, endDate)); + if (cursor) conditions.push(lt(workspaceAuditLogs.createdAt, cursor)); + + const rows = await this.db.query.workspaceAuditLogs.findMany({ + limit: limit + 1, + orderBy: [desc(workspaceAuditLogs.createdAt)], + where: and(...conditions), + }); + + const hasMore = rows.length > limit; + const items = hasMore ? rows.slice(0, limit) : rows; + const nextCursor = hasMore ? items.at(-1)?.createdAt?.toISOString() : null; + + return { items, nextCursor }; + }; +} diff --git a/packages/database/src/models/workspaceMember.ts b/packages/database/src/models/workspaceMember.ts new file mode 100644 index 0000000000..d1c251997e --- /dev/null +++ b/packages/database/src/models/workspaceMember.ts @@ -0,0 +1,133 @@ +import { INVITATION_EXPIRY_DAYS } from '@lobechat/const'; +import { and, eq, isNull } from 'drizzle-orm'; +import { nanoid } from 'nanoid/non-secure'; + +import { workspaceInvitations, workspaceMembers } from '../schemas/workspace'; +import type { LobeChatDatabase } from '../type'; + +type MemberRole = 'member' | 'owner' | 'viewer'; + +export class WorkspaceMemberModel { + private readonly db: LobeChatDatabase; + private readonly userId: string; + + constructor(db: LobeChatDatabase, userId: string) { + this.db = db; + this.userId = userId; + } + + // ===== Members ===== // + + addMember = async (params: { role?: MemberRole; userId: string; workspaceId: string }) => { + const [result] = await this.db + .insert(workspaceMembers) + .values({ + role: params.role ?? 'member', + userId: params.userId, + workspaceId: params.workspaceId, + }) + .onConflictDoUpdate({ + set: { + deletedAt: null, + joinedAt: new Date(), + role: params.role ?? 'member', + }, + target: [workspaceMembers.workspaceId, workspaceMembers.userId], + }) + .returning(); + return result; + }; + + getMember = async (workspaceId: string, userId: string) => { + return this.db.query.workspaceMembers.findFirst({ + where: and( + eq(workspaceMembers.workspaceId, workspaceId), + eq(workspaceMembers.userId, userId), + isNull(workspaceMembers.deletedAt), + ), + }); + }; + + listMembers = async (workspaceId: string, options: { includeDeleted?: boolean } = {}) => { + return this.db.query.workspaceMembers.findMany({ + where: options.includeDeleted + ? eq(workspaceMembers.workspaceId, workspaceId) + : and(eq(workspaceMembers.workspaceId, workspaceId), isNull(workspaceMembers.deletedAt)), + }); + }; + + removeMember = async (workspaceId: string, userId: string) => { + return this.db + .update(workspaceMembers) + .set({ deletedAt: new Date() }) + .where( + and( + eq(workspaceMembers.workspaceId, workspaceId), + eq(workspaceMembers.userId, userId), + isNull(workspaceMembers.deletedAt), + ), + ); + }; + + updateMemberRole = async (workspaceId: string, userId: string, role: MemberRole) => { + return this.db + .update(workspaceMembers) + .set({ role }) + .where( + and( + eq(workspaceMembers.workspaceId, workspaceId), + eq(workspaceMembers.userId, userId), + isNull(workspaceMembers.deletedAt), + ), + ); + }; + + // ===== Invitations ===== // + + createInvitation = async (params: { email?: string; role?: MemberRole; workspaceId: string }) => { + const expiresAt = new Date(); + expiresAt.setDate(expiresAt.getDate() + INVITATION_EXPIRY_DAYS); + + const [result] = await this.db + .insert(workspaceInvitations) + .values({ + email: params.email, + expiresAt, + inviterId: this.userId, + role: params.role ?? 'member', + token: nanoid(32), + workspaceId: params.workspaceId, + }) + .returning(); + return result; + }; + + findInvitationByToken = async (token: string) => { + return this.db.query.workspaceInvitations.findFirst({ + where: eq(workspaceInvitations.token, token), + }); + }; + + listPendingInvitations = async (workspaceId: string) => { + return this.db.query.workspaceInvitations.findMany({ + where: and( + eq(workspaceInvitations.workspaceId, workspaceId), + eq(workspaceInvitations.status, 'pending'), + ), + }); + }; + + revokeInvitation = async (id: string) => { + return this.db + .update(workspaceInvitations) + .set({ status: 'revoked' }) + .where(eq(workspaceInvitations.id, id)); + }; + + updateInvitationStatus = async (id: string, status: 'accepted' | 'expired' | 'revoked') => { + return this.db + .update(workspaceInvitations) + .set({ status }) + .where(eq(workspaceInvitations.id, id)); + }; +} diff --git a/packages/database/src/repositories/agentGroup/index.test.ts b/packages/database/src/repositories/agentGroup/index.test.ts index 29c80735c3..ea26b89f46 100644 --- a/packages/database/src/repositories/agentGroup/index.test.ts +++ b/packages/database/src/repositories/agentGroup/index.test.ts @@ -3,9 +3,13 @@ import { BUILTIN_AGENT_SLUGS } from '@lobechat/builtin-agents'; import { beforeEach, describe, expect, it } from 'vitest'; import { getTestDB } from '../../core/getTestDB'; +import { ChatGroupModel } from '../../models/chatGroup'; import { agents } from '../../schemas/agent'; import { chatGroups, chatGroupsAgents } from '../../schemas/chatGroup'; +import { messagePlugins, messages } from '../../schemas/message'; +import { threads, topics } from '../../schemas/topic'; import { users } from '../../schemas/user'; +import { workspaces } from '../../schemas/workspace'; import type { LobeChatDatabase } from '../../type'; import { AgentGroupRepository } from './index'; @@ -1258,4 +1262,593 @@ describe('AgentGroupRepository', () => { ); }); }); + + describe('workspace scoping', () => { + const workspaceId = 'agent-group-test-ws'; + + beforeEach(async () => { + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Test Workspace', + primaryOwnerId: userId, + slug: 'agent-group-test-ws', + }); + }); + + it('stamps workspaceId on the group, supervisor agent, and junction rows', async () => { + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + + const result = await wsRepo.createGroupWithSupervisor({ title: 'WS Group' }); + + // group row carries the workspace id + expect(result.group.workspaceId).toBe(workspaceId); + + // supervisor agent carries the workspace id + const supervisor = await serverDB.query.agents.findFirst({ + where: (a, { eq }) => eq(a.id, result.supervisorAgentId), + }); + expect(supervisor!.workspaceId).toBe(workspaceId); + + // junction rows carry the workspace id + const junctions = await serverDB.query.chatGroupsAgents.findMany({ + where: (cga, { eq }) => eq(cga.chatGroupId, result.group.id), + }); + expect(junctions.every((j) => j.workspaceId === workspaceId)).toBe(true); + }); + + // Regression for "群组设定 system prompt won't save": a group created inside a + // workspace must be updatable through the workspace-scoped ChatGroupModel. + // Previously create wrote workspace_id = NULL, so the workspace-scoped UPDATE + // matched 0 rows and threw "not found or access denied". + it('allows the workspace-scoped ChatGroupModel to update a workspace-created group', async () => { + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + const { group } = await wsRepo.createGroupWithSupervisor({ title: 'WS Group' }); + + const chatGroupModel = new ChatGroupModel(serverDB, userId, workspaceId); + + const updated = await chatGroupModel.update(group.id, { + config: { systemPrompt: 'You are a helpful team.' } as any, + }); + + expect(updated.config).toMatchObject({ systemPrompt: 'You are a helpful team.' }); + }); + + it('isolates workspace groups from personal-mode reads', async () => { + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + const { group } = await wsRepo.createGroupWithSupervisor({ title: 'WS Group' }); + + // personal-mode repo (no workspaceId) must not see the workspace group + const personalRepo = new AgentGroupRepository(serverDB, userId); + expect(await personalRepo.findByIdWithAgents(group.id)).toBeNull(); + + // workspace repo sees it + expect(await wsRepo.findByIdWithAgents(group.id)).not.toBeNull(); + }); + + it('keeps personal groups out of workspace-scoped reads', async () => { + const personalRepo = new AgentGroupRepository(serverDB, userId); + const { group } = await personalRepo.createGroupWithSupervisor({ title: 'Personal Group' }); + + expect(group.workspaceId).toBeNull(); + + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + expect(await wsRepo.findByIdWithAgents(group.id)).toBeNull(); + }); + + it('transfers a workspace group with members and conversation data to the target scope', async () => { + const targetWorkspaceId = 'agent-group-target-ws'; + await serverDB.insert(workspaces).values({ + id: targetWorkspaceId, + name: 'Target Workspace', + primaryOwnerId: userId, + slug: 'agent-group-target-ws', + }); + + await serverDB.insert(chatGroups).values({ + id: 'transfer-group', + title: 'Transfer Group', + userId, + workspaceId, + }); + await serverDB.insert(agents).values([ + { + id: 'transfer-supervisor', + title: 'Supervisor', + userId, + virtual: true, + workspaceId, + }, + { + id: 'transfer-member', + title: 'Member', + userId, + virtual: false, + workspaceId, + }, + ]); + await serverDB.insert(chatGroupsAgents).values([ + { + agentId: 'transfer-supervisor', + chatGroupId: 'transfer-group', + order: -1, + role: 'supervisor', + userId, + workspaceId, + }, + { + agentId: 'transfer-member', + chatGroupId: 'transfer-group', + order: 0, + role: 'participant', + userId, + workspaceId, + }, + ]); + await serverDB.insert(topics).values({ + groupId: 'transfer-group', + id: 'transfer-topic', + title: 'Group Topic', + userId, + workspaceId, + }); + await serverDB.insert(threads).values({ + agentId: 'transfer-member', + id: 'transfer-thread', + topicId: 'transfer-topic', + type: 'continuation', + userId, + workspaceId, + }); + await serverDB.insert(messages).values({ + content: 'hello', + groupId: 'transfer-group', + id: 'transfer-message', + role: 'user', + topicId: 'transfer-topic', + userId, + workspaceId, + }); + + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + const result = await wsRepo.transferToWorkspace('transfer-group', targetWorkspaceId, userId); + + expect(result).toEqual({ groupId: 'transfer-group' }); + + const group = await serverDB.query.chatGroups.findFirst({ + where: (cg, { eq }) => eq(cg.id, 'transfer-group'), + }); + expect(group!.workspaceId).toBe(targetWorkspaceId); + + const memberAgents = await serverDB.query.agents.findMany({ + where: (a, { inArray }) => inArray(a.id, ['transfer-supervisor', 'transfer-member']), + }); + expect(memberAgents.every((agent) => agent.workspaceId === targetWorkspaceId)).toBe(true); + + const junctions = await serverDB.query.chatGroupsAgents.findMany({ + where: (cga, { eq }) => eq(cga.chatGroupId, 'transfer-group'), + }); + expect(junctions.every((junction) => junction.workspaceId === targetWorkspaceId)).toBe(true); + + const topic = await serverDB.query.topics.findFirst({ + where: (t, { eq }) => eq(t.id, 'transfer-topic'), + }); + const thread = await serverDB.query.threads.findFirst({ + where: (t, { eq }) => eq(t.id, 'transfer-thread'), + }); + const message = await serverDB.query.messages.findFirst({ + where: (m, { eq }) => eq(m.id, 'transfer-message'), + }); + expect(topic!.workspaceId).toBe(targetWorkspaceId); + expect(thread!.workspaceId).toBe(targetWorkspaceId); + expect(message!.workspaceId).toBe(targetWorkspaceId); + }); + + it('copies a workspace group and all members into the target scope', async () => { + const targetWorkspaceId = 'agent-group-copy-target-ws'; + await serverDB.insert(workspaces).values({ + id: targetWorkspaceId, + name: 'Copy Target Workspace', + primaryOwnerId: userId, + slug: 'agent-group-copy-target-ws', + }); + + await serverDB.insert(chatGroups).values({ + avatar: 'group-avatar', + id: 'copy-group', + title: 'Copy Group', + userId, + workspaceId, + }); + await serverDB.insert(agents).values([ + { + id: 'copy-supervisor', + model: 'gpt-4o', + provider: 'openai', + title: 'Supervisor', + userId, + virtual: true, + workspaceId, + }, + { + id: 'copy-member', + model: 'claude-3', + provider: 'anthropic', + title: 'Member', + userId, + virtual: false, + workspaceId, + }, + ]); + await serverDB.insert(chatGroupsAgents).values([ + { + agentId: 'copy-supervisor', + chatGroupId: 'copy-group', + order: -1, + role: 'supervisor', + userId, + workspaceId, + }, + { + agentId: 'copy-member', + chatGroupId: 'copy-group', + order: 0, + role: 'participant', + userId, + workspaceId, + }, + ]); + + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + const result = await wsRepo.copyToWorkspace('copy-group', targetWorkspaceId, userId); + + expect(result).not.toBeNull(); + expect(result!.groupId).not.toBe('copy-group'); + expect(result!.supervisorAgentId).not.toBe('copy-supervisor'); + + const copiedGroup = await serverDB.query.chatGroups.findFirst({ + where: (cg, { eq }) => eq(cg.id, result!.groupId), + }); + expect(copiedGroup).toEqual( + expect.objectContaining({ + avatar: 'group-avatar', + title: 'Copy Group (Copy)', + userId, + workspaceId: targetWorkspaceId, + }), + ); + + const copiedJunctions = await serverDB.query.chatGroupsAgents.findMany({ + where: (cga, { eq }) => eq(cga.chatGroupId, result!.groupId), + }); + expect(copiedJunctions).toHaveLength(2); + expect(copiedJunctions.every((junction) => junction.workspaceId === targetWorkspaceId)).toBe( + true, + ); + expect(copiedJunctions.some((junction) => junction.agentId === 'copy-member')).toBe(false); + + const copiedAgentIds = copiedJunctions.map((junction) => junction.agentId); + const copiedAgents = await serverDB.query.agents.findMany({ + where: (a, { inArray }) => inArray(a.id, copiedAgentIds), + }); + expect(copiedAgents.every((agent) => agent.workspaceId === targetWorkspaceId)).toBe(true); + expect(copiedAgents.map((agent) => agent.title).sort()).toEqual(['Member', 'Supervisor']); + }); + + it('copies group topics and messages when conversation history is selected', async () => { + const targetWorkspaceId = 'agent-group-copy-history-target-ws'; + await serverDB.insert(workspaces).values({ + id: targetWorkspaceId, + name: 'Copy History Target Workspace', + primaryOwnerId: userId, + slug: 'agent-group-copy-history-target-ws', + }); + + await serverDB.insert(chatGroups).values({ + id: 'copy-history-group', + title: 'Copy History Group', + userId, + workspaceId, + }); + await serverDB.insert(agents).values([ + { + id: 'copy-history-supervisor', + model: 'gpt-4o', + provider: 'openai', + title: 'Supervisor', + userId, + virtual: true, + workspaceId, + }, + { + id: 'copy-history-member', + model: 'claude-3', + provider: 'anthropic', + title: 'Member', + userId, + virtual: false, + workspaceId, + }, + ]); + await serverDB.insert(chatGroupsAgents).values([ + { + agentId: 'copy-history-supervisor', + chatGroupId: 'copy-history-group', + order: -1, + role: 'supervisor', + userId, + workspaceId, + }, + { + agentId: 'copy-history-member', + chatGroupId: 'copy-history-group', + order: 0, + role: 'participant', + userId, + workspaceId, + }, + ]); + await serverDB.insert(topics).values({ + groupId: 'copy-history-group', + id: 'copy-history-topic', + title: 'Group topic', + userId, + workspaceId, + }); + await serverDB.insert(threads).values({ + agentId: 'copy-history-member', + groupId: 'copy-history-group', + id: 'copy-history-thread', + sourceMessageId: 'copy-history-message-user', + topicId: 'copy-history-topic', + type: 'standalone', + userId, + workspaceId, + }); + await serverDB.insert(messages).values([ + { + content: 'Hello group', + groupId: 'copy-history-group', + id: 'copy-history-message-user', + role: 'user', + targetId: 'copy-history-member', + topicId: 'copy-history-topic', + userId, + workspaceId, + }, + { + agentId: 'copy-history-member', + content: 'Hello user', + groupId: 'copy-history-group', + id: 'copy-history-message-assistant', + parentId: 'copy-history-message-user', + role: 'assistant', + threadId: 'copy-history-thread', + tools: [{ id: 'toolu_old', type: 'builtin' }], + topicId: 'copy-history-topic', + userId, + workspaceId, + }, + ]); + await serverDB.insert(messagePlugins).values({ + apiName: 'search', + arguments: '{}', + id: 'copy-history-message-assistant', + toolCallId: 'toolu_old', + userId, + workspaceId, + }); + + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + const result = await wsRepo.copyToWorkspace('copy-history-group', targetWorkspaceId, userId, { + includeConversationHistory: true, + }); + + expect(result).not.toBeNull(); + + const copiedJunctions = await serverDB.query.chatGroupsAgents.findMany({ + where: (cga, { eq }) => eq(cga.chatGroupId, result!.groupId), + }); + const copiedMember = copiedJunctions.find((junction) => junction.role === 'participant'); + expect(copiedMember?.agentId).toBeDefined(); + expect(copiedMember?.agentId).not.toBe('copy-history-member'); + + const copiedTopics = await serverDB.query.topics.findMany({ + where: (topic, { eq }) => eq(topic.groupId, result!.groupId), + }); + expect(copiedTopics).toHaveLength(1); + expect(copiedTopics[0]).toEqual( + expect.objectContaining({ + clientId: null, + sessionId: null, + title: 'Group topic', + userId, + workspaceId: targetWorkspaceId, + }), + ); + + const copiedMessages = await serverDB.query.messages.findMany({ + where: (message, { eq }) => eq(message.groupId, result!.groupId), + }); + expect(copiedMessages).toHaveLength(2); + expect(copiedMessages.some((message) => message.id === 'copy-history-message-user')).toBe( + false, + ); + + const copiedAssistantMessage = copiedMessages.find((message) => message.role === 'assistant'); + const copiedUserMessage = copiedMessages.find((message) => message.role === 'user'); + expect(copiedUserMessage?.targetId).toBe(copiedMember!.agentId); + expect(copiedAssistantMessage).toEqual( + expect.objectContaining({ + agentId: copiedMember!.agentId, + clientId: null, + targetId: null, + userId, + workspaceId: targetWorkspaceId, + }), + ); + expect(copiedAssistantMessage?.tools).not.toEqual([{ id: 'toolu_old', type: 'builtin' }]); + + const copiedPlugin = await serverDB.query.messagePlugins.findFirst({ + where: (plugin, { eq }) => eq(plugin.id, copiedAssistantMessage!.id), + }); + expect(copiedPlugin?.toolCallId).not.toBe('toolu_old'); + expect(copiedPlugin?.workspaceId).toBe(targetWorkspaceId); + }); + + it('removes workspace virtual agents created by another member', async () => { + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + + await serverDB.insert(chatGroups).values({ + id: 'remove-cross-member-group', + title: 'Remove Cross Member Group', + userId, + workspaceId, + }); + await serverDB.insert(agents).values({ + id: 'remove-cross-member-virtual', + title: 'Virtual From Other Member', + userId: otherUserId, + virtual: true, + workspaceId, + }); + await serverDB.insert(chatGroupsAgents).values({ + agentId: 'remove-cross-member-virtual', + chatGroupId: 'remove-cross-member-group', + role: 'participant', + userId, + workspaceId, + }); + + const result = await wsRepo.removeAgentsFromGroup('remove-cross-member-group', [ + 'remove-cross-member-virtual', + ]); + + expect(result).toEqual({ + deletedVirtualAgentIds: ['remove-cross-member-virtual'], + removedFromGroup: 1, + }); + + const relation = await serverDB.query.chatGroupsAgents.findFirst({ + where: (cga, { eq }) => eq(cga.agentId, 'remove-cross-member-virtual'), + }); + expect(relation).toBeUndefined(); + + const deletedAgent = await serverDB.query.agents.findFirst({ + where: (agent, { eq }) => eq(agent.id, 'remove-cross-member-virtual'), + }); + expect(deletedAgent).toBeUndefined(); + }); + + it('copies workspace group history created by another member', async () => { + const targetWorkspaceId = 'agent-group-copy-member-history-target-ws'; + await serverDB.insert(workspaces).values({ + id: targetWorkspaceId, + name: 'Copy Member History Target Workspace', + primaryOwnerId: userId, + slug: 'agent-group-copy-member-history-target-ws', + }); + + await serverDB.insert(chatGroups).values({ + id: 'copy-member-history-group', + title: 'Copy Member History Group', + userId, + workspaceId, + }); + await serverDB.insert(agents).values([ + { + id: 'copy-member-history-supervisor', + title: 'Supervisor', + userId, + virtual: true, + workspaceId, + }, + { + id: 'copy-member-history-agent', + title: 'Member Agent', + userId, + virtual: false, + workspaceId, + }, + ]); + await serverDB.insert(chatGroupsAgents).values([ + { + agentId: 'copy-member-history-supervisor', + chatGroupId: 'copy-member-history-group', + order: -1, + role: 'supervisor', + userId, + workspaceId, + }, + { + agentId: 'copy-member-history-agent', + chatGroupId: 'copy-member-history-group', + order: 0, + role: 'participant', + userId, + workspaceId, + }, + ]); + await serverDB.insert(topics).values({ + groupId: 'copy-member-history-group', + id: 'copy-member-history-topic', + title: 'Topic From Other Member', + userId: otherUserId, + workspaceId, + }); + await serverDB.insert(threads).values({ + agentId: 'copy-member-history-agent', + groupId: 'copy-member-history-group', + id: 'copy-member-history-thread', + topicId: 'copy-member-history-topic', + type: 'standalone', + userId: otherUserId, + workspaceId, + }); + await serverDB.insert(messages).values({ + agentId: 'copy-member-history-agent', + content: 'created by another workspace member', + groupId: 'copy-member-history-group', + id: 'copy-member-history-message', + role: 'assistant', + threadId: 'copy-member-history-thread', + topicId: 'copy-member-history-topic', + userId: otherUserId, + workspaceId, + }); + + const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId); + const result = await wsRepo.copyToWorkspace( + 'copy-member-history-group', + targetWorkspaceId, + userId, + { includeConversationHistory: true }, + ); + + expect(result).not.toBeNull(); + + const copiedTopics = await serverDB.query.topics.findMany({ + where: (topic, { eq }) => eq(topic.groupId, result!.groupId), + }); + expect(copiedTopics).toHaveLength(1); + expect(copiedTopics[0]).toEqual( + expect.objectContaining({ + title: 'Topic From Other Member', + userId, + workspaceId: targetWorkspaceId, + }), + ); + + const copiedMessages = await serverDB.query.messages.findMany({ + where: (message, { eq }) => eq(message.groupId, result!.groupId), + }); + expect(copiedMessages).toHaveLength(1); + expect(copiedMessages[0]).toEqual( + expect.objectContaining({ + content: 'created by another workspace member', + userId, + workspaceId: targetWorkspaceId, + }), + ); + }); + }); }); diff --git a/packages/database/src/repositories/agentGroup/index.ts b/packages/database/src/repositories/agentGroup/index.ts index cc8db0b739..09859e0fd6 100644 --- a/packages/database/src/repositories/agentGroup/index.ts +++ b/packages/database/src/repositories/agentGroup/index.ts @@ -1,11 +1,32 @@ import { BUILTIN_AGENT_SLUGS } from '@lobechat/builtin-agents'; import type { AgentGroupDetail, AgentGroupMember } from '@lobechat/types'; import { cleanObject } from '@lobechat/utils'; -import { and, eq, inArray } from 'drizzle-orm'; +import { and, eq, inArray, not } from 'drizzle-orm'; -import type { AgentItem, ChatGroupItem, NewChatGroup, NewChatGroupAgent } from '../../schemas'; -import { agents, chatGroups, chatGroupsAgents } from '../../schemas'; +import type { + AgentItem, + ChatGroupItem, + NewAgent, + NewChatGroup, + NewChatGroupAgent, +} from '../../schemas'; +import { + agents, + chatGroups, + chatGroupsAgents, + messagePlugins, + messages, + threads, + topics, +} from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { idGenerator } from '../../utils/idGenerator'; +import { buildWorkspaceWhere } from '../../utils/workspace'; + +interface CopyAgentGroupToWorkspaceOptions { + includeConversationHistory?: boolean; + newTitle?: string; +} export interface SupervisorAgentConfig { avatar?: string; @@ -53,12 +74,230 @@ export interface CreateGroupWithSupervisorResult { export class AgentGroupRepository { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + /** + * Workspace-aware ownership predicate for the `chat_groups` table. In personal + * mode (`workspaceId` absent) matches `user_id = ? AND workspace_id IS NULL`; + * in team mode matches `workspace_id = ?` (shared with all members). + */ + private groupOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, chatGroups); + private agentOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agents); + private topicOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics); + private threadOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, threads); + private messageOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages); + private messagePluginOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messagePlugins); + + private buildCopiedAgent = ( + source: AgentItem | undefined, + targetWorkspaceId: string | null, + targetUserId: string, + fallbackTitle: string, + ): NewAgent => ({ + agencyConfig: source?.agencyConfig, + avatar: source?.avatar, + backgroundColor: source?.backgroundColor, + chatConfig: source?.chatConfig, + description: source?.description, + editorData: source?.editorData, + fewShots: source?.fewShots, + model: source?.model, + openingMessage: source?.openingMessage, + openingQuestions: source?.openingQuestions, + params: source?.params, + pinned: source?.pinned, + plugins: source?.plugins, + provider: source?.provider, + systemRole: source?.systemRole, + tags: source?.tags, + title: source?.title || fallbackTitle, + tts: source?.tts, + userId: targetUserId, + virtual: source?.virtual ?? true, + workspaceId: targetWorkspaceId, + }); + + private remapToolIds = (tools: unknown, toolIdMap: Map) => { + if (!Array.isArray(tools)) return tools; + + return tools.map((tool) => { + if (!tool || typeof tool !== 'object') return tool; + + const toolRecord = tool as Record; + const toolId = toolRecord.id; + if (typeof toolId !== 'string') return tool; + + return { + ...toolRecord, + id: toolIdMap.get(toolId) ?? toolId, + }; + }); + }; + + private copyGroupConversationHistory = async ({ + agentIdMap, + executor, + newGroupId, + sourceGroupId, + targetUserId, + targetWorkspaceId, + }: { + agentIdMap: Map; + executor: LobeChatDatabase; + newGroupId: string; + sourceGroupId: string; + targetUserId: string; + targetWorkspaceId: string | null; + }) => { + const mapAgentId = (agentId?: null | string) => + agentId ? (agentIdMap.get(agentId) ?? null) : null; + const mapTargetId = (targetId?: null | string) => { + if (!targetId || targetId === 'user') return targetId ?? null; + + return agentIdMap.get(targetId) ?? null; + }; + + const sourceTopics = await executor.query.topics.findMany({ + orderBy: (topic, { asc }) => [asc(topic.createdAt)], + where: and(this.topicOwnership(), eq(topics.groupId, sourceGroupId)), + }); + + if (sourceTopics.length === 0) return; + + const sourceTopicIds = sourceTopics.map((topic) => topic.id); + const topicIdMap = new Map(sourceTopics.map((topic) => [topic.id, idGenerator('topics')])); + + const sourceThreads = await executor.query.threads.findMany({ + orderBy: (thread, { asc }) => [asc(thread.createdAt)], + where: and(this.threadOwnership(), inArray(threads.topicId, sourceTopicIds)), + }); + + const threadIdMap = new Map( + sourceThreads.map((thread) => [thread.id, idGenerator('threads', 16)]), + ); + + const sourceMessages = await executor.query.messages.findMany({ + orderBy: (message, { asc }) => [asc(message.createdAt)], + where: and(this.messageOwnership(), inArray(messages.topicId, sourceTopicIds)), + }); + + const messageIdMap = new Map( + sourceMessages.map((message) => [message.id, idGenerator('messages')]), + ); + + const toolIdMap = new Map(); + for (const message of sourceMessages) { + if (!Array.isArray(message.tools)) continue; + + for (const tool of message.tools) { + if (!tool || typeof tool !== 'object') continue; + + const toolId = (tool as Record).id; + if (typeof toolId === 'string') { + toolIdMap.set(toolId, `toolu_${idGenerator('messages')}`); + } + } + } + + await executor.insert(topics).values( + sourceTopics.map((topic) => ({ + ...topic, + agentId: mapAgentId(topic.agentId), + clientId: null, + groupId: newGroupId, + id: topicIdMap.get(topic.id), + sessionId: null, + userId: targetUserId, + workspaceId: targetWorkspaceId, + })), + ); + + if (sourceThreads.length > 0) { + await executor.insert(threads).values( + sourceThreads.map((thread) => ({ + ...thread, + agentId: mapAgentId(thread.agentId), + clientId: null, + groupId: newGroupId, + id: threadIdMap.get(thread.id), + parentThreadId: thread.parentThreadId + ? (threadIdMap.get(thread.parentThreadId) ?? null) + : null, + sourceMessageId: thread.sourceMessageId + ? (messageIdMap.get(thread.sourceMessageId) ?? null) + : null, + topicId: topicIdMap.get(thread.topicId), + userId: targetUserId, + workspaceId: targetWorkspaceId, + })), + ); + } + + if (sourceMessages.length === 0) return; + + const sourceMessageIds = sourceMessages.map((message) => message.id); + const sourcePlugins = await executor.query.messagePlugins.findMany({ + where: and(this.messagePluginOwnership(), inArray(messagePlugins.id, sourceMessageIds)), + }); + + const messageRows = sourceMessages.map((message) => { + const newMessageId = messageIdMap.get(message.id)!; + const newTopicId = message.topicId ? (topicIdMap.get(message.topicId) ?? null) : null; + + return { + ...message, + agentId: mapAgentId(message.agentId), + clientId: null, + groupId: newGroupId, + id: newMessageId, + messageGroupId: null, + parentId: message.parentId ? (messageIdMap.get(message.parentId) ?? null) : null, + quotaId: message.quotaId ? (messageIdMap.get(message.quotaId) ?? null) : null, + sessionId: null, + targetId: mapTargetId(message.targetId), + threadId: message.threadId ? (threadIdMap.get(message.threadId) ?? null) : null, + tools: this.remapToolIds(message.tools, toolIdMap), + topicId: newTopicId, + userId: targetUserId, + workspaceId: targetWorkspaceId, + }; + }); + + await executor.insert(messages).values(messageRows); + + if (sourcePlugins.length > 0) { + await executor.insert(messagePlugins).values( + sourcePlugins + .map((plugin) => { + const newMessageId = messageIdMap.get(plugin.id); + if (!newMessageId) return; + + return { + ...plugin, + clientId: null, + id: newMessageId, + toolCallId: plugin.toolCallId ? (toolIdMap.get(plugin.toolCallId) ?? null) : null, + userId: targetUserId, + workspaceId: targetWorkspaceId, + }; + }) + .filter((plugin) => !!plugin), + ); + } + }; + /** * Find a chat group by ID with its associated agents. * If no supervisor exists, a virtual supervisor agent is automatically created. @@ -68,7 +307,7 @@ export class AgentGroupRepository { async findByIdWithAgents(groupId: string): Promise { // 1. Find the group const group = await this.db.query.chatGroups.findFirst({ - where: and(eq(chatGroups.id, groupId), eq(chatGroups.userId, this.userId)), + where: and(eq(chatGroups.id, groupId), this.groupOwnership()), }); if (!group) return null; @@ -115,6 +354,7 @@ export class AgentGroupRepository { title: 'Supervisor', userId: this.userId, virtual: true, + workspaceId: this.workspaceId ?? null, }) .returning(); @@ -125,6 +365,7 @@ export class AgentGroupRepository { order: -1, // Supervisor always first (negative order) role: 'supervisor', userId: this.userId, + workspaceId: this.workspaceId ?? null, }); supervisorAgentId = supervisorAgent.id; @@ -178,13 +419,14 @@ export class AgentGroupRepository { title: supervisorConfig?.title ?? 'Supervisor', userId: this.userId, virtual: true, + workspaceId: this.workspaceId ?? null, }) .returning(); // 2. Create the group const [group] = await this.db .insert(chatGroups) - .values({ ...groupParams, userId: this.userId }) + .values({ ...groupParams, userId: this.userId, workspaceId: this.workspaceId ?? null }) .returning(); // 3. Add supervisor agent to group with role 'supervisor' @@ -194,6 +436,7 @@ export class AgentGroupRepository { order: -1, // Supervisor always first (negative order) role: 'supervisor', userId: this.userId, + workspaceId: this.workspaceId ?? null, }; // 4. Add member agents to group with role 'participant' @@ -203,6 +446,7 @@ export class AgentGroupRepository { order: index, role: 'participant', userId: this.userId, + workspaceId: this.workspaceId ?? null, })); // 5. Insert all group-agent relationships @@ -245,7 +489,7 @@ export class AgentGroupRepository { virtual: agents.virtual, }) .from(agents) - .where(and(eq(agents.userId, this.userId), inArray(agents.id, agentIds))); + .where(and(this.agentOwnership(), inArray(agents.id, agentIds))); const virtualAgents: RemoveAgentsCheckResult['virtualAgents'] = []; const nonVirtualAgentIds: string[] = []; @@ -300,7 +544,7 @@ export class AgentGroupRepository { if (deleteVirtualAgents && virtualAgentIds.length > 0) { await this.db .delete(agents) - .where(and(eq(agents.userId, this.userId), inArray(agents.id, virtualAgentIds))); + .where(and(this.agentOwnership(), inArray(agents.id, virtualAgentIds))); } return { @@ -326,7 +570,7 @@ export class AgentGroupRepository { ): Promise<{ groupId: string; supervisorAgentId: string } | null> { // 1. Get the source group const sourceGroup = await this.db.query.chatGroups.findFirst({ - where: and(eq(chatGroups.id, groupId), eq(chatGroups.userId, this.userId)), + where: and(eq(chatGroups.id, groupId), this.groupOwnership()), }); if (!sourceGroup) return null; @@ -374,6 +618,7 @@ export class AgentGroupRepository { pinned: sourceGroup.pinned, title: newTitle || (sourceGroup.title ? `${sourceGroup.title} (Copy)` : 'Copy'), userId: this.userId, + workspaceId: this.workspaceId ?? null, }) .returning(); @@ -393,6 +638,7 @@ export class AgentGroupRepository { title: supervisorAgent?.title || 'Supervisor', userId: this.userId, virtual: true, + workspaceId: this.workspaceId ?? null, }) .returning(); @@ -421,6 +667,7 @@ export class AgentGroupRepository { // User & virtual flag userId: this.userId, virtual: true, + workspaceId: this.workspaceId ?? null, })); const newVirtualAgents = await trx.insert(agents).values(virtualAgentConfigs).returning(); @@ -440,6 +687,7 @@ export class AgentGroupRepository { order: -1, role: 'supervisor', userId: this.userId, + workspaceId: this.workspaceId ?? null, }, // Virtual members (using new copied agents) ...virtualMembers.map((member) => ({ @@ -449,6 +697,7 @@ export class AgentGroupRepository { order: member.order, role: member.role || 'participant', userId: this.userId, + workspaceId: this.workspaceId ?? null, })), // Non-virtual members (referencing same agents - only add relationship) ...nonVirtualMembers.map((member) => ({ @@ -458,6 +707,7 @@ export class AgentGroupRepository { order: member.order, role: member.role || 'participant', userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ]; @@ -469,4 +719,202 @@ export class AgentGroupRepository { }; }); } + + async transferToWorkspace( + groupId: string, + targetWorkspaceId: string | null, + targetUserId: string, + ): Promise<{ groupId: string } | null> { + const sourceGroup = await this.db.query.chatGroups.findFirst({ + where: and(eq(chatGroups.id, groupId), this.groupOwnership()), + }); + + if (!sourceGroup) return null; + + return this.db.transaction(async (trx) => { + const memberRows = await trx + .select({ agentId: chatGroupsAgents.agentId }) + .from(chatGroupsAgents) + .where(eq(chatGroupsAgents.chatGroupId, groupId)); + + const agentIds = memberRows.map((row) => row.agentId); + const ownershipUpdate = { + userId: targetUserId, + workspaceId: targetWorkspaceId, + }; + + await trx + .update(chatGroups) + .set({ ...ownershipUpdate, updatedAt: new Date() }) + .where(eq(chatGroups.id, groupId)); + + await trx + .update(chatGroupsAgents) + .set(ownershipUpdate) + .where(eq(chatGroupsAgents.chatGroupId, groupId)); + + if (agentIds.length > 0) { + await trx + .delete(chatGroupsAgents) + .where( + and( + inArray(chatGroupsAgents.agentId, agentIds), + not(eq(chatGroupsAgents.chatGroupId, groupId)), + ), + ); + + await trx + .update(agents) + .set({ ...ownershipUpdate, updatedAt: new Date() }) + .where(inArray(agents.id, agentIds)); + } + + const groupTopics = await trx + .select({ id: topics.id }) + .from(topics) + .where(eq(topics.groupId, groupId)); + const groupTopicIds = groupTopics.map((topic) => topic.id); + + await trx.update(topics).set(ownershipUpdate).where(eq(topics.groupId, groupId)); + await trx.update(threads).set(ownershipUpdate).where(eq(threads.groupId, groupId)); + await trx.update(messages).set(ownershipUpdate).where(eq(messages.groupId, groupId)); + + if (groupTopicIds.length > 0) { + await trx + .update(threads) + .set(ownershipUpdate) + .where(inArray(threads.topicId, groupTopicIds)); + await trx + .update(messages) + .set(ownershipUpdate) + .where(inArray(messages.topicId, groupTopicIds)); + } + + return { groupId }; + }); + } + + async copyToWorkspace( + groupId: string, + targetWorkspaceId: string | null, + targetUserId: string, + optionsOrNewTitle?: CopyAgentGroupToWorkspaceOptions | string, + ): Promise<{ groupId: string; supervisorAgentId: string } | null> { + const options = + typeof optionsOrNewTitle === 'string' + ? { newTitle: optionsOrNewTitle } + : (optionsOrNewTitle ?? {}); + const sourceGroup = await this.db.query.chatGroups.findFirst({ + where: and(eq(chatGroups.id, groupId), this.groupOwnership()), + }); + + if (!sourceGroup) return null; + + const groupAgentsWithDetails = await this.db + .select({ + agent: agents, + enabled: chatGroupsAgents.enabled, + order: chatGroupsAgents.order, + role: chatGroupsAgents.role, + }) + .from(chatGroupsAgents) + .innerJoin(agents, eq(chatGroupsAgents.agentId, agents.id)) + .where(eq(chatGroupsAgents.chatGroupId, groupId)) + .orderBy(chatGroupsAgents.order); + + const sourceSupervisor = groupAgentsWithDetails.find((row) => row.role === 'supervisor'); + const sourceMembers = groupAgentsWithDetails.filter((row) => row.role !== 'supervisor'); + + return this.db.transaction(async (trx) => { + const [newGroup] = await trx + .insert(chatGroups) + .values({ + avatar: sourceGroup.avatar, + backgroundColor: sourceGroup.backgroundColor, + config: sourceGroup.config, + content: sourceGroup.content, + description: sourceGroup.description, + editorData: sourceGroup.editorData, + pinned: sourceGroup.pinned, + title: options.newTitle || (sourceGroup.title ? `${sourceGroup.title} (Copy)` : 'Copy'), + userId: targetUserId, + workspaceId: targetWorkspaceId, + }) + .returning(); + + const [newSupervisor] = await trx + .insert(agents) + .values( + this.buildCopiedAgent( + sourceSupervisor?.agent, + targetWorkspaceId, + targetUserId, + 'Supervisor', + ), + ) + .returning(); + + const memberAgentIdMap = new Map(); + if (sourceMembers.length > 0) { + const newMembers = await trx + .insert(agents) + .values( + sourceMembers.map((member) => + this.buildCopiedAgent(member.agent, targetWorkspaceId, targetUserId, 'Agent'), + ), + ) + .returning({ id: agents.id }); + + for (const [index, member] of sourceMembers.entries()) { + memberAgentIdMap.set(member.agent.id, newMembers[index].id); + } + } + + const groupAgentValues: NewChatGroupAgent[] = [ + { + agentId: newSupervisor.id, + chatGroupId: newGroup.id, + order: -1, + role: 'supervisor', + userId: targetUserId, + workspaceId: targetWorkspaceId, + }, + ...sourceMembers.map((member) => ({ + agentId: memberAgentIdMap.get(member.agent.id)!, + chatGroupId: newGroup.id, + enabled: member.enabled, + order: member.order, + role: member.role || 'participant', + userId: targetUserId, + workspaceId: targetWorkspaceId, + })), + ]; + + await trx.insert(chatGroupsAgents).values(groupAgentValues); + + const agentIdMap = new Map(); + if (sourceSupervisor?.agent.id) { + agentIdMap.set(sourceSupervisor.agent.id, newSupervisor.id); + } + for (const [sourceAgentId, newAgentId] of memberAgentIdMap) { + agentIdMap.set(sourceAgentId, newAgentId); + } + + if (options.includeConversationHistory) { + await this.copyGroupConversationHistory({ + agentIdMap, + executor: trx, + newGroupId: newGroup.id, + sourceGroupId: groupId, + targetUserId, + targetWorkspaceId, + }); + } + + return { + groupId: newGroup.id, + supervisorAgentId: newSupervisor.id, + }; + }); + } } diff --git a/packages/database/src/repositories/agentMigration/index.ts b/packages/database/src/repositories/agentMigration/index.ts index 6106453b2d..aed41e6733 100644 --- a/packages/database/src/repositories/agentMigration/index.ts +++ b/packages/database/src/repositories/agentMigration/index.ts @@ -1,7 +1,9 @@ import { and, eq, inArray, isNotNull, isNull } from 'drizzle-orm'; +import type { AnyPgColumn } from 'drizzle-orm/pg-core'; import { agents, agentsToSessions, messages, sessions, topics } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; type MigrateBySessionParams = { agentId: string; sessionId: string }; type MigrateInboxParams = { agentId: string; isInbox: true; sessionId?: string | null }; @@ -16,12 +18,17 @@ type MigrateAgentIdParams = MigrateBySessionParams | MigrateInboxParams; export class AgentMigrationRepo { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private ws = (cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }) => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, cols); + /** * Runtime migration: backfill agentId for all legacy topics and messages * Used for progressive migration so future queries don't need agentsToSessions lookup @@ -57,7 +64,7 @@ export class AgentMigrationRepo { .from(topics) .where( and( - eq(topics.userId, this.userId), + this.ws(topics), isNull(topics.sessionId), isNull(topics.groupId), isNull(topics.agentId), @@ -74,7 +81,7 @@ export class AgentMigrationRepo { .set({ agentId, updatedAt: topics.updatedAt }) .where( and( - eq(topics.userId, this.userId), + this.ws(topics), isNull(topics.sessionId), isNull(topics.groupId), isNull(topics.agentId), @@ -85,13 +92,7 @@ export class AgentMigrationRepo { await tx .update(messages) .set({ agentId, updatedAt: messages.updatedAt }) - .where( - and( - eq(messages.userId, this.userId), - inArray(messages.topicId, topicIds), - isNull(messages.agentId), - ), - ); + .where(and(this.ws(messages), inArray(messages.topicId, topicIds), isNull(messages.agentId))); // 4. Also update messages without topicId but in inbox (sessionId IS NULL) - preserve original updatedAt await tx @@ -99,7 +100,7 @@ export class AgentMigrationRepo { .set({ agentId, updatedAt: messages.updatedAt }) .where( and( - eq(messages.userId, this.userId), + this.ws(messages), isNull(messages.sessionId), isNull(messages.topicId), isNull(messages.agentId), @@ -118,13 +119,7 @@ export class AgentMigrationRepo { const legacyTopics = await tx .select({ id: topics.id }) .from(topics) - .where( - and( - eq(topics.userId, this.userId), - eq(topics.sessionId, sessionId), - isNull(topics.agentId), - ), - ); + .where(and(this.ws(topics), eq(topics.sessionId, sessionId), isNull(topics.agentId))); const topicIds = legacyTopics.map((t) => t.id); @@ -132,13 +127,7 @@ export class AgentMigrationRepo { await tx .update(topics) .set({ agentId, updatedAt: topics.updatedAt }) - .where( - and( - eq(topics.userId, this.userId), - eq(topics.sessionId, sessionId), - isNull(topics.agentId), - ), - ); + .where(and(this.ws(topics), eq(topics.sessionId, sessionId), isNull(topics.agentId))); // 3. Update associated messages within these topics if (topicIds.length > 0) { @@ -146,11 +135,7 @@ export class AgentMigrationRepo { .update(messages) .set({ agentId, updatedAt: messages.updatedAt }) .where( - and( - eq(messages.userId, this.userId), - inArray(messages.topicId, topicIds), - isNull(messages.agentId), - ), + and(this.ws(messages), inArray(messages.topicId, topicIds), isNull(messages.agentId)), ); } @@ -160,7 +145,7 @@ export class AgentMigrationRepo { .set({ agentId, updatedAt: messages.updatedAt }) .where( and( - eq(messages.userId, this.userId), + this.ws(messages), eq(messages.sessionId, sessionId), isNull(messages.topicId), isNull(messages.agentId), @@ -175,7 +160,7 @@ export class AgentMigrationRepo { const result = await this.db .select({ sessionId: agentsToSessions.sessionId }) .from(agentsToSessions) - .where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId))) + .where(and(eq(agentsToSessions.agentId, agentId), this.ws(agentsToSessions))) .limit(1); return result[0]?.sessionId ?? null; @@ -202,13 +187,7 @@ export class AgentMigrationRepo { .from(agents) .innerJoin(agentsToSessions, eq(agents.id, agentsToSessions.agentId)) .innerJoin(sessions, eq(agentsToSessions.sessionId, sessions.id)) - .where( - and( - eq(agents.userId, this.userId), - isNull(agents.sessionGroupId), - isNotNull(sessions.groupId), - ), - ); + .where(and(this.ws(agents), isNull(agents.sessionGroupId), isNotNull(sessions.groupId))); if (agentsToMigrate.length === 0) return; @@ -220,7 +199,7 @@ export class AgentMigrationRepo { await this.db .update(agents) .set({ sessionGroupId: item.sessionGroupId, updatedAt: agents.updatedAt }) - .where(and(eq(agents.id, item.agentId), eq(agents.userId, this.userId))); + .where(and(eq(agents.id, item.agentId), this.ws(agents))); } }; } diff --git a/packages/database/src/repositories/compression/index.ts b/packages/database/src/repositories/compression/index.ts index caca9349c5..ad4622495b 100644 --- a/packages/database/src/repositories/compression/index.ts +++ b/packages/database/src/repositories/compression/index.ts @@ -5,6 +5,7 @@ import { and, eq, inArray, isNull } from 'drizzle-orm'; import type { MessageGroupItem } from '../../schemas'; import { messageGroups, messages } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export interface CreateCompressionGroupParams { content: string; @@ -31,12 +32,20 @@ export interface CompressionGroupResult { export class CompressionRepository { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } + private groupsOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messageGroups); + + private messagesOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages); + /** * Create a compression group and mark messages as compressed */ @@ -56,6 +65,7 @@ export class CompressionRepository { topicId, type: MessageGroupType.Compression, userId: this.userId, + workspaceId: this.workspaceId ?? null, }) .returning()) as MessageGroupItem[]; @@ -78,7 +88,7 @@ export class CompressionRepository { .from(messageGroups) .where( and( - eq(messageGroups.userId, this.userId), + this.groupsOwnership(), eq(messageGroups.topicId, topicId), eq(messageGroups.type, MessageGroupType.Compression), ), @@ -118,7 +128,7 @@ export class CompressionRepository { const existing = await this.db .select({ description: messageGroups.description }) .from(messageGroups) - .where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId))); + .where(and(eq(messageGroups.id, groupId), this.groupsOwnership())); const existingMetadata = existing[0]?.description ? JSON.parse(existing[0].description) : {}; updateData.description = JSON.stringify({ ...existingMetadata, ...metadata }); @@ -127,7 +137,7 @@ export class CompressionRepository { await this.db .update(messageGroups) .set(updateData) - .where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId))); + .where(and(eq(messageGroups.id, groupId), this.groupsOwnership())); } /** @@ -141,7 +151,7 @@ export class CompressionRepository { const existing = await this.db .select({ metadata: messageGroups.metadata }) .from(messageGroups) - .where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId))); + .where(and(eq(messageGroups.id, groupId), this.groupsOwnership())); const existingData = (existing[0]?.metadata as Record) || {}; const newMetadata = { ...existingData, ...metadata }; @@ -149,7 +159,7 @@ export class CompressionRepository { await this.db .update(messageGroups) .set({ metadata: newMetadata, updatedAt: new Date() }) - .where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId))); + .where(and(eq(messageGroups.id, groupId), this.groupsOwnership())); } /** @@ -161,7 +171,7 @@ export class CompressionRepository { await this.db .update(messages) .set({ messageGroupId: groupId }) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))); + .where(and(this.messagesOwnership(), inArray(messages.id, messageIds))); } /** @@ -173,7 +183,7 @@ export class CompressionRepository { await this.db .update(messages) .set({ messageGroupId: null }) - .where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))); + .where(and(this.messagesOwnership(), inArray(messages.id, messageIds))); } /** @@ -184,7 +194,7 @@ export class CompressionRepository { const [message] = await this.db .select({ metadata: messages.metadata }) .from(messages) - .where(and(eq(messages.id, messageId), eq(messages.userId, this.userId))); + .where(and(eq(messages.id, messageId), this.messagesOwnership())); if (!message) return; @@ -194,7 +204,7 @@ export class CompressionRepository { await this.db .update(messages) .set({ metadata: newMetadata }) - .where(and(eq(messages.id, messageId), eq(messages.userId, this.userId))); + .where(and(eq(messages.id, messageId), this.messagesOwnership())); } /** @@ -206,7 +216,7 @@ export class CompressionRepository { .from(messages) .where( and( - eq(messages.userId, this.userId), + this.messagesOwnership(), eq(messages.topicId, topicId), isNull(messages.messageGroupId), ), @@ -221,7 +231,7 @@ export class CompressionRepository { return this.db .select() .from(messages) - .where(and(eq(messages.userId, this.userId), eq(messages.messageGroupId, groupId))) + .where(and(this.messagesOwnership(), eq(messages.messageGroupId, groupId))) .orderBy(messages.createdAt); } @@ -233,11 +243,11 @@ export class CompressionRepository { await this.db .update(messages) .set({ messageGroupId: null }) - .where(and(eq(messages.userId, this.userId), eq(messages.messageGroupId, groupId))); + .where(and(this.messagesOwnership(), eq(messages.messageGroupId, groupId))); // 2. Delete the group await this.db .delete(messageGroups) - .where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId))); + .where(and(eq(messageGroups.id, groupId), this.groupsOwnership())); } } diff --git a/packages/database/src/repositories/dataExporter/index.test.ts b/packages/database/src/repositories/dataExporter/index.test.ts index fff26aa7b1..6686c345da 100644 --- a/packages/database/src/repositories/dataExporter/index.test.ts +++ b/packages/database/src/repositories/dataExporter/index.test.ts @@ -16,6 +16,7 @@ import { topics, users, userSettings, + workspaces, } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; import { DATA_EXPORT_CONFIG, DataExporterRepos } from './index'; @@ -361,5 +362,140 @@ describe('DataExporterRepos', () => { expect(result.sessions[0]).not.toHaveProperty('userId', anotherUserId); expect(result.sessions[0]).toHaveProperty('id', 'another-session-id'); }); + + it('should not include workspace-scoped rows in personal export', async () => { + const workspaceId = 'workspace-export-filter'; + + await db.transaction(async (trx) => { + await trx.insert(workspaces).values({ + id: workspaceId, + name: 'Workspace Export Filter', + primaryOwnerId: userId, + slug: workspaceId, + }); + await trx.insert(agents).values({ + id: 'workspace-agent-id', + title: 'Workspace Agent', + userId, + workspaceId, + }); + await trx.insert(sessions).values({ + id: 'workspace-session-id', + slug: 'workspace-session', + title: 'Workspace Session', + userId, + workspaceId, + }); + await trx.insert(topics).values({ + id: 'workspace-topic-id', + sessionId: 'workspace-session-id', + title: 'Workspace Topic', + userId, + workspaceId, + }); + await trx.insert(messages).values({ + content: 'Workspace message', + id: 'workspace-message-id', + role: 'user', + sessionId: 'workspace-session-id', + topicId: 'workspace-topic-id', + userId, + workspaceId, + }); + }); + + const result = await new DataExporterRepos(db, userId).export(); + + expect(result.agents.map((agent) => agent.id)).toEqual([testIds.agentId]); + expect(result.sessions.map((session) => session.id)).toEqual([testIds.sessionId]); + expect(result.topics.map((topic) => topic.id)).toEqual([testIds.topicId]); + expect(result.messages.map((message) => message.id)).toEqual([testIds.messageId]); + }); + + it('should export only the selected workspace scope when workspaceId is provided', async () => { + const workspaceId = 'workspace-export-scope'; + const otherWorkspaceId = 'workspace-export-other'; + + await db.transaction(async (trx) => { + await trx.insert(workspaces).values([ + { + id: workspaceId, + name: 'Workspace Export Scope', + primaryOwnerId: userId, + slug: workspaceId, + }, + { + id: otherWorkspaceId, + name: 'Other Workspace Export Scope', + primaryOwnerId: userId, + slug: otherWorkspaceId, + }, + ]); + await trx.insert(agents).values([ + { + id: 'workspace-agent-id', + title: 'Workspace Agent', + userId, + workspaceId, + }, + { + id: 'other-workspace-agent-id', + title: 'Other Workspace Agent', + userId, + workspaceId: otherWorkspaceId, + }, + ]); + await trx.insert(sessions).values([ + { + id: 'workspace-session-id', + slug: 'workspace-session', + title: 'Workspace Session', + userId, + workspaceId, + }, + { + id: 'other-workspace-session-id', + slug: 'other-workspace-session', + title: 'Other Workspace Session', + userId, + workspaceId: otherWorkspaceId, + }, + ]); + await trx.insert(agentsToSessions).values({ + agentId: 'workspace-agent-id', + sessionId: 'workspace-session-id', + userId, + }); + await trx.insert(topics).values({ + id: 'workspace-topic-id', + sessionId: 'workspace-session-id', + title: 'Workspace Topic', + userId, + workspaceId, + }); + await trx.insert(messages).values({ + content: 'Workspace message', + id: 'workspace-message-id', + role: 'user', + sessionId: 'workspace-session-id', + topicId: 'workspace-topic-id', + userId, + workspaceId, + }); + }); + + const result = await new DataExporterRepos(db, userId, workspaceId).export(); + + expect(result.userSettings).toEqual([]); + expect(result.agents.map((agent) => agent.id)).toEqual(['workspace-agent-id']); + expect(result.sessions.map((session) => session.id)).toEqual(['workspace-session-id']); + expect(result.topics.map((topic) => topic.id)).toEqual(['workspace-topic-id']); + expect(result.messages.map((message) => message.id)).toEqual(['workspace-message-id']); + expect(result.agentsToSessions).toHaveLength(1); + expect(result.agentsToSessions[0]).toMatchObject({ + agentId: 'workspace-agent-id', + sessionId: 'workspace-session-id', + }); + }); }); }); diff --git a/packages/database/src/repositories/dataExporter/index.ts b/packages/database/src/repositories/dataExporter/index.ts index 5347e75600..3db81b3258 100644 --- a/packages/database/src/repositories/dataExporter/index.ts +++ b/packages/database/src/repositories/dataExporter/index.ts @@ -3,6 +3,7 @@ import pMap from 'p-map'; import * as EXPORT_TABLES from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; interface BaseTableConfig { table: keyof typeof EXPORT_TABLES; @@ -83,10 +84,12 @@ export const DATA_EXPORT_CONFIG = { export class DataExporterRepos { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } private removeUserId(data: any[]) { @@ -110,7 +113,7 @@ export class DataExporterRepos { // If source data is empty, this table may not be able to query any data if (sourceData.length === 0) { - console.log( + console.info( `Source table ${relation.sourceTable} has no data, skipping query for ${table}`, ); return []; @@ -120,7 +123,12 @@ export class DataExporterRepos { conditions.push(inArray(tableObj[relation.field], sourceIds)); } - // If table has userId field and is not the users table, add user filter + // If table has userId field and is not the users table, add user filter. + // workspace-audit: this branch only runs for non-relation tables; relation + // tables (which carry workspace_id) are already constrained by the FK + // `inArray(sourceIds)` above, where sourceIds come from base tables that ARE + // workspace-scoped (see queryBaseTables / buildWorkspaceWhere) — so relation + // rows are transitively workspace-scoped and need no userId/workspaceId filter here. if ('userId' in tableObj && table !== 'users' && !config.relations) { conditions.push(eq(tableObj.userId, this.userId)); } @@ -132,7 +140,7 @@ export class DataExporterRepos { const result = await this.db.query[table].findMany({ where }); // Only remove userId field for tables queried with userId - console.log(`Successfully exported table: ${table}, count: ${result.length}`); + console.info(`Successfully exported table: ${table}, count: ${result.length}`); return config.relations ? result : this.removeUserId(result); } catch (error) { console.error(`Error querying table ${table}:`, error); @@ -146,17 +154,24 @@ export class DataExporterRepos { if (!tableObj) throw new Error(`Table ${table} not found`); try { + if (this.workspaceId && !('workspaceId' in tableObj)) { + return []; + } + // If there's relation config, use relation query // Default to querying with userId, use userField for special cases const userField = config.userField || 'userId'; - const where = eq(tableObj[userField], this.userId); + const where = + 'workspaceId' in tableObj + ? buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, tableObj) + : eq(tableObj[userField], this.userId); // @ts-expect-error query const result = await this.db.query[table].findMany({ where }); // Only remove userId field for tables queried with userId - console.log(`Successfully exported table: ${table}, count: ${result.length}`); + console.info(`Successfully exported table: ${table}, count: ${result.length}`); return this.removeUserId(result); } catch (error) { console.error(`Error querying table ${table}:`, error); @@ -168,7 +183,7 @@ export class DataExporterRepos { const result: Record = {}; // 1. First query all base tables concurrently - console.log('Querying base tables...'); + console.info('Querying base tables...'); const baseResults = await pMap( DATA_EXPORT_CONFIG.baseTables, async (config) => ({ data: await this.queryBaseTables(config), table: config.table }), @@ -191,7 +206,7 @@ export class DataExporterRepos { ); if (!allSourcesHaveData) { - console.log(`Skipping table ${config.table} as some source tables have no data`); + console.info(`Skipping table ${config.table} as some source tables have no data`); return { data: [], table: config.table }; } @@ -208,8 +223,6 @@ export class DataExporterRepos { result[table] = data; }); - console.log('finalResults:', result); - return result; } } diff --git a/packages/database/src/repositories/dataImporter/deprecated/index.ts b/packages/database/src/repositories/dataImporter/deprecated/index.ts index 656c87a10d..0af03556a1 100644 --- a/packages/database/src/repositories/dataImporter/deprecated/index.ts +++ b/packages/database/src/repositories/dataImporter/deprecated/index.ts @@ -1,5 +1,5 @@ import type { ImporterEntryData } from '@lobechat/types'; -import { and, eq, inArray, sql } from 'drizzle-orm'; +import { and, inArray, sql } from 'drizzle-orm'; import { sanitizeUTF8 } from '@/utils/sanitizeUTF8'; @@ -14,6 +14,7 @@ import { topics, } from '../../../schemas'; import type { LobeChatDatabase } from '../../../type'; +import { buildWorkspaceWhere } from '../../../utils/workspace'; interface ImportResult { added: number; @@ -24,6 +25,7 @@ interface ImportResult { export class DeprecatedDataImporterRepos { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; /** @@ -31,11 +33,17 @@ export class DeprecatedDataImporterRepos { */ supportVersion = 7; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; + this.workspaceId = workspaceId; this.db = db; } + /** Helper: scope predicate for workspace-aware tables. */ + private workspaceWhere(table: { userId: any; workspaceId: any }) { + return buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, table); + } + importData = async (data: ImporterEntryData) => { if (data.version > this.supportVersion) throw new Error('Unsupported version'); @@ -53,7 +61,7 @@ export class DeprecatedDataImporterRepos { if (data.sessionGroups && data.sessionGroups.length > 0) { const query = await trx.query.sessionGroups.findMany({ where: and( - eq(sessionGroups.userId, this.userId), + this.workspaceWhere(sessionGroups), inArray( sessionGroups.clientId, data.sessionGroups.map(({ id }) => id), @@ -72,6 +80,7 @@ export class DeprecatedDataImporterRepos { createdAt: new Date(createdAt), updatedAt: new Date(updatedAt), userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ) .onConflictDoUpdate({ @@ -89,7 +98,7 @@ export class DeprecatedDataImporterRepos { if (data.sessions && data.sessions.length > 0) { const query = await trx.query.sessions.findMany({ where: and( - eq(sessions.userId, this.userId), + this.workspaceWhere(sessions), inArray( sessions.clientId, data.sessions.map(({ id }) => id), @@ -109,6 +118,7 @@ export class DeprecatedDataImporterRepos { groupId: group ? sessionGroupIdMap[group] : null, updatedAt: new Date(updatedAt), userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ) .onConflictDoUpdate({ @@ -136,6 +146,7 @@ export class DeprecatedDataImporterRepos { ...config, ...meta, userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ) .returning({ id: agents.id }); @@ -145,6 +156,7 @@ export class DeprecatedDataImporterRepos { agentId: agentMapArray[index].id, sessionId: sessionIdMap[id], userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ); } @@ -154,7 +166,7 @@ export class DeprecatedDataImporterRepos { if (data.topics && data.topics.length > 0) { const skipQuery = await trx.query.topics.findMany({ where: and( - eq(topics.userId, this.userId), + this.workspaceWhere(topics), inArray( topics.clientId, data.topics.map(({ id }) => id), @@ -174,6 +186,7 @@ export class DeprecatedDataImporterRepos { sessionId: sessionId ? sessionIdMap[sessionId] : null, updatedAt: new Date(updatedAt), userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ) .onConflictDoUpdate({ @@ -190,17 +203,15 @@ export class DeprecatedDataImporterRepos { // import messages if (data.messages && data.messages.length > 0) { // 1. find skip ones - console.time('find messages'); const skipQuery = await trx.query.messages.findMany({ where: and( - eq(messages.userId, this.userId), + this.workspaceWhere(messages), inArray( messages.clientId, data.messages.map(({ id }) => id), ), ), }); - console.timeEnd('find messages'); messageResult.skips = skipQuery.length; @@ -224,10 +235,10 @@ export class DeprecatedDataImporterRepos { topicId: topicId ? topicIdMap[topicId] : null, // Temporarily set to NULL updatedAt: new Date(updatedAt), userId: this.userId, + workspaceId: this.workspaceId ?? null, }), ); - console.time('insert messages'); const BATCH_SIZE = 100; // Number of records to insert per batch for (let i = 0; i < inertValues.length; i += BATCH_SIZE) { @@ -235,14 +246,12 @@ export class DeprecatedDataImporterRepos { await trx.insert(messages).values(batch); } - console.timeEnd('insert messages'); - const messageIdArray = await trx .select({ clientId: messages.clientId, id: messages.id }) .from(messages) .where( and( - eq(messages.userId, this.userId), + this.workspaceWhere(messages), inArray( messages.clientId, data.messages.map(({ id }) => id), @@ -255,7 +264,6 @@ export class DeprecatedDataImporterRepos { ); // 3. update parentId for messages - console.time('execute updates parentId'); const parentIdUpdates = shouldInsertMessages .filter((msg) => msg.parentId) // Only process messages with parentId .map((msg) => { @@ -284,7 +292,6 @@ export class DeprecatedDataImporterRepos { // console.log('sql:', SQL.sql); // console.log('params:', SQL.params); } - console.timeEnd('execute updates parentId'); // 4. insert message plugins const pluginInserts = shouldInsertMessages.filter((msg) => msg.plugin); @@ -299,6 +306,7 @@ export class DeprecatedDataImporterRepos { toolCallId: msg.tool_call_id, type: msg.plugin?.type, userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ); } @@ -311,6 +319,7 @@ export class DeprecatedDataImporterRepos { id: messageIdMap[msg.id], ...msg.extra?.translate, userId: this.userId, + workspaceId: this.workspaceId ?? null, })), ); } diff --git a/packages/database/src/repositories/dataImporter/index.ts b/packages/database/src/repositories/dataImporter/index.ts index 4c209f3b75..9c08b1d543 100644 --- a/packages/database/src/repositories/dataImporter/index.ts +++ b/packages/database/src/repositories/dataImporter/index.ts @@ -5,6 +5,7 @@ import { uuid } from '@/utils/uuid'; import * as EXPORT_TABLES from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; import { DeprecatedDataImporterRepos } from './deprecated'; interface ImportResult { @@ -256,15 +257,17 @@ const IMPORT_TABLE_CONFIG: TableImportConfig[] = [ export class DataImporterRepos { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; private deprecatedDataImporterRepos: DeprecatedDataImporterRepos; private idMaps: Record> = {}; private conflictRecords: Record = {}; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; + this.workspaceId = workspaceId; this.db = db; - this.deprecatedDataImporterRepos = new DeprecatedDataImporterRepos(db, userId); + this.deprecatedDataImporterRepos = new DeprecatedDataImporterRepos(db, userId, workspaceId); } importData = async (data: ImporterEntryData): Promise => { @@ -301,7 +304,7 @@ export class DataImporterRepos { // Use unified import method const result = await this.importTableData(trx, config, tableData, conflictStrategy); - console.log(`imported table: ${tableName}, records: ${tableData.length}`); + console.info(`imported table: ${tableName}, records: ${tableData.length}`); if (Object.values(result).some((value) => value > 0)) { results[tableName] = result; @@ -381,8 +384,15 @@ export class DataImporterRepos { const clientIds = tableData.map((item) => item.clientId || item.id).filter(Boolean); if (clientIds.length > 0) { + const workspaceFilter = + 'workspaceId' in table + ? buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + table as any, + ) + : eq(table.userId, this.userId); existingRecords = await trx.query[tableName].findMany({ - where: and(eq(table.userId, this.userId), inArray(table.clientId, clientIds)), + where: and(workspaceFilter, inArray(table.clientId, clientIds)), }); } } @@ -453,7 +463,7 @@ export class DataImporterRepos { if (item.accessedAt) dateFields.accessedAt = new Date(item.accessedAt); // Create new record object - let newRecord: any = {}; + let newRecord: any; // Decide how to process based on whether it's composite key and whether to preserve ID if (isCompositeKey) { @@ -465,6 +475,7 @@ export class DataImporterRepos { ...dateFields, clientId: item.clientId || item.id, userId: this.userId, + ...('workspaceId' in table ? { workspaceId: this.workspaceId ?? null } : {}), }; } else { // Non-composite key table processing @@ -473,6 +484,7 @@ export class DataImporterRepos { ...dateFields, clientId: item.clientId || item.id, userId: this.userId, + ...('workspaceId' in table ? { workspaceId: this.workspaceId ?? null } : {}), }; } @@ -526,9 +538,18 @@ export class DataImporterRepos { .filter((field) => record.newRecord[field] !== undefined) .map((field) => eq(table[field], record.newRecord[field])); - // Add userId condition (if table has userId field) + // Add userId/workspaceId condition (if table has these fields) if ('userId' in table) { - whereConditions.push(eq(table.userId, this.userId)); + if ('workspaceId' in table) { + whereConditions.push( + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + table as any, + ), + ); + } else { + whereConditions.push(eq(table.userId, this.userId)); + } } if (whereConditions.length > 0) { diff --git a/packages/database/src/repositories/home/index.ts b/packages/database/src/repositories/home/index.ts index 693d471e2b..d4f73cab41 100644 --- a/packages/database/src/repositories/home/index.ts +++ b/packages/database/src/repositories/home/index.ts @@ -16,6 +16,7 @@ import { } from '../../schemas'; import { type LobeChatDatabase } from '../../type'; import { sanitizeBm25Query } from '../../utils/bm25'; +import { buildWorkspaceWhere } from '../../utils/workspace'; // Re-export types for backward compatibility export type { @@ -30,13 +31,19 @@ export type { */ export class HomeRepository { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; + this.workspaceId = workspaceId; this.db = db; } + private get scope() { + return { userId: this.userId, workspaceId: this.workspaceId }; + } + /** * Get sidebar agent list with pinned, grouped, and ungrouped items */ @@ -60,7 +67,7 @@ export class HomeRepository { .from(agents) .leftJoin(agentsToSessions, eq(agents.id, agentsToSessions.agentId)) .leftJoin(sessions, eq(agentsToSessions.sessionId, sessions.id)) - .where(and(eq(agents.userId, this.userId), not(eq(agents.virtual, true)))) + .where(and(buildWorkspaceWhere(this.scope, agents), not(eq(agents.virtual, true)))) .orderBy(desc(agents.updatedAt)); // 2. Query all chatGroups (group chats) @@ -76,7 +83,7 @@ export class HomeRepository { updatedAt: chatGroups.updatedAt, }) .from(chatGroups) - .where(eq(chatGroups.userId, this.userId)) + .where(buildWorkspaceWhere(this.scope, chatGroups)) .orderBy(desc(chatGroups.updatedAt)); // 2.1 Query member avatars for each chat group @@ -90,7 +97,7 @@ export class HomeRepository { sort: sessionGroups.sort, }) .from(sessionGroups) - .where(eq(sessionGroups.userId, this.userId)) + .where(buildWorkspaceWhere(this.scope, sessionGroups)) .orderBy(sessionGroups.sort); // 4. Process and categorize @@ -225,7 +232,7 @@ export class HomeRepository { .leftJoin(sessions, eq(agentsToSessions.sessionId, sessions.id)) .where( and( - eq(agents.userId, this.userId), + buildWorkspaceWhere(this.scope, agents), not(eq(agents.virtual, true)), sql`(${agents.title} @@@ ${bm25Query} OR ${agents.description} @@@ ${bm25Query})`, ), @@ -245,7 +252,7 @@ export class HomeRepository { .from(chatGroups) .where( and( - eq(chatGroups.userId, this.userId), + buildWorkspaceWhere(this.scope, chatGroups), sql`(${chatGroups.title} @@@ ${bm25Query} OR ${chatGroups.description} @@@ ${bm25Query})`, ), ) diff --git a/packages/database/src/repositories/knowledge/index.test.ts b/packages/database/src/repositories/knowledge/index.test.ts index 0a1ddfe1d5..65c32c25c2 100644 --- a/packages/database/src/repositories/knowledge/index.test.ts +++ b/packages/database/src/repositories/knowledge/index.test.ts @@ -8,6 +8,7 @@ import { documents, files } from '../../schemas/file'; import { chunks, embeddings } from '../../schemas/rag'; import { fileChunks } from '../../schemas/relations'; import { users } from '../../schemas/user'; +import { workspaces } from '../../schemas/workspace'; import type { LobeChatDatabase } from '../../type'; import { KnowledgeRepo } from './index'; @@ -257,6 +258,78 @@ describe('KnowledgeRepo', () => { }); }); + describe('query - workspace isolation', () => { + const workspaceId = 'knowledge-workspace'; + + beforeEach(async () => { + await serverDB.insert(workspaces).values({ + id: workspaceId, + name: 'Knowledge Workspace', + primaryOwnerId: userId, + slug: workspaceId, + }); + + await serverDB.insert(files).values([ + { + fileType: 'application/pdf', + name: 'workspace-owner-file.pdf', + size: 1024, + url: 'workspace-owner-file-url', + userId, + workspaceId, + }, + { + fileType: 'application/pdf', + name: 'viewer-personal-file.pdf', + size: 1024, + url: 'viewer-personal-file-url', + userId: otherUserId, + }, + ]); + + await serverDB.insert(documents).values([ + { + content: 'Workspace owner document', + fileType: 'application/pdf', + filename: 'workspace-owner-doc.pdf', + source: 'workspace-owner-source', + sourceType: 'api', + totalCharCount: 100, + totalLineCount: 10, + userId, + workspaceId, + }, + { + content: 'Viewer personal document', + fileType: 'application/pdf', + filename: 'viewer-personal-doc.pdf', + source: 'viewer-personal-source', + sourceType: 'api', + totalCharCount: 100, + totalLineCount: 10, + userId: otherUserId, + }, + ]); + }); + + it('should return workspace items regardless of the creator user', async () => { + const workspaceRepo = new KnowledgeRepo(serverDB, otherUserId, workspaceId); + + const results = await workspaceRepo.query({ category: FilesTabs.All }); + + const names = results.map((item) => item.name).sort(); + expect(names).toEqual(['workspace-owner-doc.pdf', 'workspace-owner-file.pdf']); + }); + + it('should not return workspace items in personal mode', async () => { + const results = await knowledgeRepo.query({ category: FilesTabs.All }); + + const names = results.map((item) => item.name).sort(); + expect(names).not.toContain('workspace-owner-doc.pdf'); + expect(names).not.toContain('workspace-owner-file.pdf'); + }); + }); + describe('query - search filtering', () => { beforeEach(async () => { await serverDB.insert(files).values([ diff --git a/packages/database/src/repositories/knowledge/index.ts b/packages/database/src/repositories/knowledge/index.ts index 13e0603be1..92d25f348f 100644 --- a/packages/database/src/repositories/knowledge/index.ts +++ b/packages/database/src/repositories/knowledge/index.ts @@ -6,6 +6,7 @@ import { DocumentModel } from '../../models/document'; import { FileModel } from '../../models/file'; import { DOCUMENT_FOLDER_TYPE, documents, files, knowledgeBaseFiles } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export interface KnowledgeItem { chunkTaskId?: string | null; @@ -39,14 +40,26 @@ export class KnowledgeRepo { private db: LobeChatDatabase; private fileModel: FileModel; private documentModel: DocumentModel; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; - this.fileModel = new FileModel(db, userId); - this.documentModel = new DocumentModel(db, userId); + this.workspaceId = workspaceId; + this.fileModel = new FileModel(db, userId, workspaceId); + this.documentModel = new DocumentModel(db, userId, workspaceId); } + private fileOwnershipSql = (alias: 'f' = 'f') => + this.workspaceId + ? sql`${sql.raw(`${alias}.workspace_id`)} = ${this.workspaceId}` + : sql`${sql.raw(`${alias}.user_id`)} = ${this.userId} AND ${sql.raw(`${alias}.workspace_id`)} IS NULL`; + + private documentOwnershipSql = (alias: 'd' | 'documents' = 'd') => + this.workspaceId + ? sql`${sql.raw(`${alias}.workspace_id`)} = ${this.workspaceId}` + : sql`${sql.raw(`${alias}.user_id`)} = ${this.userId} AND ${sql.raw(`${alias}.workspace_id`)} IS NULL`; + /** * Query combined results from files and documents tables */ @@ -183,7 +196,7 @@ export class KnowledgeRepo { FROM ${files} f LEFT JOIN ${documents} d ON f.id = d.file_id - WHERE f.user_id = ${this.userId} + WHERE ${this.fileOwnershipSql('f')} AND NOT EXISTS ( SELECT 1 FROM ${knowledgeBaseFiles} WHERE ${knowledgeBaseFiles.fileId} = f.id @@ -209,7 +222,7 @@ export class KnowledgeRepo { metadata, 'document' as source_type FROM ${documents} - WHERE user_id = ${this.userId} + WHERE ${this.documentOwnershipSql('documents')} AND source_type != ${'file'} AND knowledge_base_id IS NULL `; @@ -315,7 +328,10 @@ export class KnowledgeRepo { if (document.fileType === DOCUMENT_FOLDER_TYPE) { const children = await this.db.query.documents.findMany({ - where: and(eq(documents.parentId, id), eq(documents.userId, this.userId)), + where: and( + eq(documents.parentId, id), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), + ), }); for (const child of children) { @@ -323,7 +339,10 @@ export class KnowledgeRepo { } const childFiles = await this.db.query.files.findMany({ - where: and(eq(files.parentId, id), eq(files.userId, this.userId)), + where: and( + eq(files.parentId, id), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files), + ), }); for (const file of childFiles) { @@ -345,7 +364,7 @@ export class KnowledgeRepo { showFilesInKnowledgeBase, parentId, }: QueryFileListParams = {}): ReturnType { - const whereConditions: any[] = [sql`f.user_id = ${this.userId}`]; + const whereConditions: any[] = [this.fileOwnershipSql('f')]; // Parent ID filter if (parentId !== undefined) { @@ -376,7 +395,7 @@ export class KnowledgeRepo { // Knowledge base filter if (knowledgeBaseId) { // Build where conditions using proper table references (f.column instead of files.column) - const kbWhereConditions: any[] = [sql`f.user_id = ${this.userId}`]; + const kbWhereConditions: any[] = [this.fileOwnershipSql('f')]; // Parent ID filter if (parentId !== undefined) { @@ -477,7 +496,7 @@ export class KnowledgeRepo { parentId, }: QueryFileListParams = {}): ReturnType { const whereConditions: any[] = [ - sql`${documents.userId} = ${this.userId}`, + this.documentOwnershipSql('documents'), sql`${documents.sourceType} != ${'file'}`, ]; @@ -542,7 +561,7 @@ export class KnowledgeRepo { // Documents are linked to knowledge bases through files table via fileId if (knowledgeBaseId) { // Build where conditions using proper table references (d.column instead of documents.column) - const kbWhereConditions: any[] = [sql`d.user_id = ${this.userId}`]; + const kbWhereConditions: any[] = [this.documentOwnershipSql('d')]; // Parent ID filter if (parentId !== undefined) { diff --git a/packages/database/src/repositories/search/index.ts b/packages/database/src/repositories/search/index.ts index cc26430fcc..b13ec7293a 100644 --- a/packages/database/src/repositories/search/index.ts +++ b/packages/database/src/repositories/search/index.ts @@ -14,6 +14,7 @@ import { } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; import { sanitizeBm25Query } from '../../utils/bm25'; +import { buildWorkspaceWhere } from '../../utils/workspace'; export type SearchResultType = | 'page' @@ -199,10 +200,16 @@ const RECENCY_CANDIDATE_MULTIPLIER = 4; export class SearchRepo { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; + } + + private get scope() { + return { userId: this.userId, workspaceId: this.workspaceId }; } /** @@ -404,7 +411,7 @@ export class SearchRepo { .from(agents) .where( and( - eq(agents.userId, this.userId), + buildWorkspaceWhere(this.scope, agents), sql`(${agents.title} @@@ ${bm25Query} OR ${agents.description} @@@ ${bm25Query} OR ${agents.slug} @@@ ${bm25Query} OR ${agents.tags} @@@ ${bm25Query} OR ${agents.systemRole} @@@ ${bm25Query})`, ), ) @@ -458,10 +465,10 @@ export class SearchRepo { updatedAt: topics.updatedAt, }) .from(topics) - .leftJoin(agents, and(eq(topics.agentId, agents.id), eq(agents.userId, this.userId))) + .leftJoin(agents, and(eq(topics.agentId, agents.id), buildWorkspaceWhere(this.scope, agents))) .where( and( - eq(topics.userId, this.userId), + buildWorkspaceWhere(this.scope, topics), agentId ? eq(topics.agentId, agentId) : undefined, sql`(${topics.title} @@@ ${bm25Query} OR ${topics.content} @@@ ${bm25Query} OR ${topics.description} @@@ ${bm25Query})`, ), @@ -520,7 +527,7 @@ export class SearchRepo { .leftJoin(agents, eq(messages.agentId, agents.id)) .where( and( - eq(messages.userId, this.userId), + buildWorkspaceWhere(this.scope, messages), ne(messages.role, 'tool'), agentId ? eq(messages.agentId, agentId) : undefined, sql`${messages.content} @@@ ${bm25Query}`, @@ -574,7 +581,7 @@ export class SearchRepo { .leftJoin(knowledgeBaseFiles, eq(files.id, knowledgeBaseFiles.fileId)) .where( and( - eq(files.userId, this.userId), + buildWorkspaceWhere(this.scope, files), ne(files.fileType, 'custom/document'), sql`${files.name} @@@ ${bm25Query}`, ), @@ -619,7 +626,7 @@ export class SearchRepo { .from(documents) .where( and( - eq(documents.userId, this.userId), + buildWorkspaceWhere(this.scope, documents), eq(documents.fileType, DOCUMENT_FOLDER_TYPE), sql`(${documents.title} @@@ ${bm25Query} OR ${documents.slug} @@@ ${bm25Query} OR ${documents.description} @@@ ${bm25Query})`, ), @@ -661,7 +668,7 @@ export class SearchRepo { .from(documents) .where( and( - eq(documents.userId, this.userId), + buildWorkspaceWhere(this.scope, documents), eq(documents.fileType, 'custom/document'), sql`(${documents.title} @@@ ${bm25Query} OR ${documents.slug} @@@ ${bm25Query} OR ${documents.content} @@@ ${bm25Query})`, ), @@ -710,7 +717,7 @@ export class SearchRepo { const matchClause = sql`(${documents.title} @@@ ${bm25Query} OR ${documents.slug} @@@ ${bm25Query} OR ${documents.content} @@@ ${bm25Query})`; const folderClause = ne(documents.fileType, DOCUMENT_FOLDER_TYPE); - const userClause = eq(documents.userId, this.userId); + const userClause = buildWorkspaceWhere(this.scope, documents); const inlineRowsPromise = this.db .select({ @@ -751,7 +758,7 @@ export class SearchRepo { knowledgeBaseFiles, and( eq(knowledgeBaseFiles.fileId, documents.fileId), - eq(knowledgeBaseFiles.userId, this.userId), + buildWorkspaceWhere(this.scope, knowledgeBaseFiles), inArray(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseIds), ), ) @@ -842,7 +849,7 @@ export class SearchRepo { .from(chatGroups) .where( and( - eq(chatGroups.userId, this.userId), + buildWorkspaceWhere(this.scope, chatGroups), sql`(${chatGroups.title} @@@ ${bm25Query} OR ${chatGroups.description} @@@ ${bm25Query})`, ), ) @@ -884,7 +891,7 @@ export class SearchRepo { .from(knowledgeBases) .where( and( - eq(knowledgeBases.userId, this.userId), + buildWorkspaceWhere(this.scope, knowledgeBases), sql`(${knowledgeBases.name} @@@ ${bm25Query} OR ${knowledgeBases.description} @@@ ${bm25Query})`, ), ) diff --git a/packages/database/src/repositories/topicImporter/index.ts b/packages/database/src/repositories/topicImporter/index.ts index 8d8d2d5841..20104614a4 100644 --- a/packages/database/src/repositories/topicImporter/index.ts +++ b/packages/database/src/repositories/topicImporter/index.ts @@ -34,6 +34,7 @@ interface PreparedMessage { traceId?: string | null; updatedAt: Date; userId: string; + workspaceId: string | null; } interface PreparedMessagePlugin { @@ -46,15 +47,18 @@ interface PreparedMessagePlugin { toolCallId?: string | null; type?: string | null; userId: string; + workspaceId: string | null; } export class TopicImporterRepo { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } /** @@ -94,6 +98,7 @@ export class TopicImporterRepo { id: topicId, title: title || 'Imported Topic', userId: this.userId, + workspaceId: this.workspaceId ?? null, }); // Batch insert messages @@ -204,6 +209,7 @@ export class TopicImporterRepo { traceId: msg.traceId || null, updatedAt: new Date(msg.updatedTimestamp), userId: this.userId, + workspaceId: this.workspaceId ?? null, }); // If message has plugin data (tool messages), prepare plugin record @@ -219,6 +225,7 @@ export class TopicImporterRepo { toolCallId: msg.tool_call_id || null, type: plugin?.type || null, userId: this.userId, + workspaceId: this.workspaceId ?? null, }); } } diff --git a/packages/database/src/repositories/userMemory/UserMemoryTopicRepository.ts b/packages/database/src/repositories/userMemory/UserMemoryTopicRepository.ts index c834eaa5b5..be5b251bd1 100644 --- a/packages/database/src/repositories/userMemory/UserMemoryTopicRepository.ts +++ b/packages/database/src/repositories/userMemory/UserMemoryTopicRepository.ts @@ -2,6 +2,7 @@ import { and, asc, eq } from 'drizzle-orm'; import { messages } from '../../schemas'; import type { LobeChatDatabase } from '../../type'; +import { buildWorkspaceWhere } from '../../utils/workspace'; /** * Maximum character length for the query string used in memory search @@ -14,10 +15,12 @@ const MAX_QUERY_LENGTH = 7000; export class UserMemoryTopicRepository { private userId: string; private db: LobeChatDatabase; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; + this.workspaceId = workspaceId; } /** @@ -36,7 +39,7 @@ export class UserMemoryTopicRepository { .from(messages) .where( and( - eq(messages.userId, this.userId), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages), eq(messages.topicId, topicId), eq(messages.role, 'user'), ), diff --git a/packages/database/src/schemas/agentCronJob.ts b/packages/database/src/schemas/agentCronJob.ts index 9b46b09ba0..20eaf36949 100644 --- a/packages/database/src/schemas/agentCronJob.ts +++ b/packages/database/src/schemas/agentCronJob.ts @@ -63,10 +63,10 @@ export const agentCronJobs = pgTable( index('agent_cron_jobs_agent_id_idx').on(t.agentId), index('agent_cron_jobs_group_id_idx').on(t.groupId), index('agent_cron_jobs_user_id_idx').on(t.userId), + index('agent_cron_jobs_workspace_id_idx').on(t.workspaceId), index('agent_cron_jobs_enabled_idx').on(t.enabled), index('agent_cron_jobs_remaining_executions_idx').on(t.remainingExecutions), index('agent_cron_jobs_last_executed_at_idx').on(t.lastExecutedAt), - index('agent_cron_jobs_workspace_id_idx').on(t.workspaceId), ], ); diff --git a/packages/database/src/schemas/agentOperations.ts b/packages/database/src/schemas/agentOperations.ts index b50b46c8d2..a66fe62b36 100644 --- a/packages/database/src/schemas/agentOperations.ts +++ b/packages/database/src/schemas/agentOperations.ts @@ -161,6 +161,7 @@ export const agentOperations = pgTable( }, (t) => [ index('agent_operations_user_id_idx').on(t.userId), + index('agent_operations_workspace_id_idx').on(t.workspaceId), index('agent_operations_agent_id_idx').on(t.agentId), index('agent_operations_topic_id_idx').on(t.topicId), index('agent_operations_thread_id_idx').on(t.threadId), @@ -170,7 +171,6 @@ export const agentOperations = pgTable( index('agent_operations_status_idx').on(t.status), index('agent_operations_user_id_created_at_idx').on(t.userId, t.createdAt), index('agent_operations_metadata_idx').using('gin', t.metadata), - index('agent_operations_workspace_id_idx').on(t.workspaceId), ], ); diff --git a/packages/database/src/schemas/connector.ts b/packages/database/src/schemas/connector.ts index 993cc1ae28..174b6e04ab 100644 --- a/packages/database/src/schemas/connector.ts +++ b/packages/database/src/schemas/connector.ts @@ -292,6 +292,8 @@ export const userConnectorTools = pgTable( export type NewUserConnectorTool = typeof userConnectorTools.$inferInsert; export type UserConnectorToolItem = typeof userConnectorTools.$inferSelect; +// Deprecated legacy plugin install table. Keep workspaceId only for old rows; +// workspace audits should ignore this table instead of expanding constraints. export const userInstalledPlugins = pgTable( 'user_installed_plugins', { diff --git a/packages/database/src/schemas/device.ts b/packages/database/src/schemas/device.ts index 65d0823b5c..2e52392adc 100644 --- a/packages/database/src/schemas/device.ts +++ b/packages/database/src/schemas/device.ts @@ -22,6 +22,14 @@ export const devices = pgTable( userId: text('user_id') .references(() => users.id, { onDelete: 'cascade' }) .notNull(), + // NOTE: devices are a USER-LEVEL identity, not workspace-scoped content. A + // physical machine belongs to the user across all of their workspaces (the + // unique key is (userId, deviceId), see below). `workspaceId` here only + // records which workspace the device was registered from — it is NOT used to + // filter device lookups. So `DeviceModel`/`deviceRouter` intentionally scope + // by userId only and do NOT use `buildWorkspaceWhere`. Do not "fix" them to + // workspace-scope reads, or a user's device would disappear inside their own + // workspaces. workspaceId: text('workspace_id').references(() => workspaces.id, { onDelete: 'cascade' }), /** Machine-derived id (sha256 truncated to 32 chars; 64 leaves room for fallback randomUUID) */ diff --git a/packages/database/src/schemas/documentHistory.ts b/packages/database/src/schemas/documentHistory.ts index 7ba9f616ae..5c41fe1807 100644 --- a/packages/database/src/schemas/documentHistory.ts +++ b/packages/database/src/schemas/documentHistory.ts @@ -30,8 +30,8 @@ export const documentHistories = pgTable( (table) => [ index('document_histories_document_id_idx').on(table.documentId), index('document_histories_user_id_idx').on(table.userId), - index('document_histories_saved_at_idx').on(table.savedAt), index('document_histories_workspace_id_idx').on(table.workspaceId), + index('document_histories_saved_at_idx').on(table.savedAt), ], ); diff --git a/packages/database/src/schemas/llmGenerationTracing.ts b/packages/database/src/schemas/llmGenerationTracing.ts index d3d03d37db..13a6f7e07c 100644 --- a/packages/database/src/schemas/llmGenerationTracing.ts +++ b/packages/database/src/schemas/llmGenerationTracing.ts @@ -104,6 +104,7 @@ export const llmGenerationTracing = pgTable( index('llm_generation_tracing_user_id_idx').on(t.userId), index('llm_generation_tracing_agent_id_idx').on(t.agentId), index('llm_generation_tracing_topic_id_idx').on(t.topicId), + index('llm_generation_tracing_workspace_id_idx').on(t.workspaceId), index('llm_generation_tracing_provider_idx').on(t.provider), index('llm_generation_tracing_model_idx').on(t.model), index('llm_generation_tracing_success_idx').on(t.success), @@ -111,7 +112,6 @@ export const llmGenerationTracing = pgTable( index('llm_generation_tracing_validation_failed_idx').on(t.validationFailed), index('llm_generation_tracing_feedback_signal_idx').on(t.feedbackSignal), index('llm_generation_tracing_created_at_idx').on(t.createdAt), - index('llm_generation_tracing_workspace_id_idx').on(t.workspaceId), ], ); diff --git a/packages/database/src/schemas/relations.ts b/packages/database/src/schemas/relations.ts index c425ee39d8..45f984e7a7 100644 --- a/packages/database/src/schemas/relations.ts +++ b/packages/database/src/schemas/relations.ts @@ -63,9 +63,9 @@ export const filesToSessions = pgTable( (t) => ({ pk: primaryKey({ columns: [t.fileId, t.sessionId] }), userIdIdx: index('files_to_sessions_user_id_idx').on(t.userId), + workspaceIdIdx: index('files_to_sessions_workspace_id_idx').on(t.workspaceId), fileIdIdx: index('files_to_sessions_file_id_idx').on(t.fileId), sessionIdIdx: index('files_to_sessions_session_id_idx').on(t.sessionId), - workspaceIdIdx: index('files_to_sessions_workspace_id_idx').on(t.workspaceId), }), ); @@ -83,9 +83,9 @@ export const fileChunks = pgTable( (t) => ({ pk: primaryKey({ columns: [t.fileId, t.chunkId] }), userIdIdx: index('file_chunks_user_id_idx').on(t.userId), + workspaceIdIdx: index('file_chunks_workspace_id_idx').on(t.workspaceId), fileIdIdx: index('file_chunks_file_id_idx').on(t.fileId), chunkIdIdx: index('file_chunks_chunk_id_idx').on(t.chunkId), - workspaceIdIdx: index('file_chunks_workspace_id_idx').on(t.workspaceId), }), ); export type NewFileChunkItem = typeof fileChunks.$inferInsert; diff --git a/packages/database/src/schemas/topic.ts b/packages/database/src/schemas/topic.ts index 4ccce5d39e..c4a63b8471 100644 --- a/packages/database/src/schemas/topic.ts +++ b/packages/database/src/schemas/topic.ts @@ -86,11 +86,11 @@ export const topics = pgTable( index('topics_provider_idx').on(t.provider), index('topics_user_id_completed_at_idx').on(t.userId, t.completedAt), index('topics_sender_id_idx').on(t.senderId), + index('topics_workspace_id_idx').on(t.workspaceId), index('topics_extract_status_gin_idx').using( 'gin', sql`(metadata->'userMemoryExtractStatus') jsonb_path_ops`, ), - index('topics_workspace_id_idx').on(t.workspaceId), ], ); diff --git a/packages/database/src/schemas/userMemories/index.ts b/packages/database/src/schemas/userMemories/index.ts index bbeb48ed29..0e84225d8b 100644 --- a/packages/database/src/schemas/userMemories/index.ts +++ b/packages/database/src/schemas/userMemories/index.ts @@ -12,7 +12,6 @@ export const userMemories = pgTable( .primaryKey(), userId: text('user_id').references(() => users.id, { onDelete: 'cascade' }), - memoryCategory: varchar255('memory_category'), memoryLayer: varchar255('memory_layer'), memoryType: varchar255('memory_type'), diff --git a/packages/database/src/utils/idGenerator.ts b/packages/database/src/utils/idGenerator.ts index 9227251db9..4245a01e8b 100644 --- a/packages/database/src/utils/idGenerator.ts +++ b/packages/database/src/utils/idGenerator.ts @@ -34,6 +34,9 @@ const prefixes = { threads: 'thd', topics: 'tpc', user: 'user', + workspaceAuditLogs: 'wal', + workspaceInvitations: 'wsi', + workspaces: 'ws', } as const; export const idGenerator = (namespace: keyof typeof prefixes, size = 12) => { diff --git a/packages/database/src/utils/seedWorkspaceRoles.ts b/packages/database/src/utils/seedWorkspaceRoles.ts new file mode 100644 index 0000000000..119029f7af --- /dev/null +++ b/packages/database/src/utils/seedWorkspaceRoles.ts @@ -0,0 +1,227 @@ +import { + PERMISSION_ACTIONS, + WORKSPACE_ROLE_DESCRIPTIONS, + WORKSPACE_ROLE_DISPLAY_NAMES, + WORKSPACE_ROLE_PERMISSIONS, + WORKSPACE_SYSTEM_ROLES, + type WorkspaceSystemRoleName, +} from '@lobechat/const/rbac'; +import { and, eq, inArray } from 'drizzle-orm'; + +import { permissions, rolePermissions, roles, userRoles } from '../schemas/rbac'; +import type { LobeChatDatabase } from '../type'; + +/** + * Map a permission code (e.g. `agent:create`) to the category column — + * always the substring before the first colon. Keeps the seeded + * `rbac_permissions.category` consistent with the legacy Hono seed data + * style and makes filtering by category trivial. + */ +const codeToCategory = (code: string): string => code.split(':')[0]; + +/** + * Strip the `:all` / `:owner` suffix from a scoped permission code and look + * up the action's display name from `PERMISSION_ACTIONS`. Returns a sensible + * fallback when the code isn't recognised so seeding never throws on an + * unknown permission. + */ +const codeToName = (code: string): string => { + const base = code.replace(/:(all|owner)$/, ''); + const entry = Object.entries(PERMISSION_ACTIONS).find(([, value]) => value === base); + if (!entry) return code; + return entry[0] + .toLowerCase() + .split('_') + .map((seg) => seg.charAt(0).toUpperCase() + seg.slice(1)) + .join(' '); +}; + +/** + * Ensure every permission code referenced by the three built-in workspace + * roles exists in `rbac_permissions`. Idempotent — re-runs are safe because + * we only insert codes that are missing. + * + * Returns a map from `code` to permission `id` for use by the role-permission + * linkage step. Permissions live in the global table (no workspaceId) — only + * the *roles* are workspace-scoped. + */ +const ensurePermissionsExist = async (db: LobeChatDatabase): Promise> => { + const requiredCodes = new Set(); + for (const codes of Object.values(WORKSPACE_ROLE_PERMISSIONS)) { + for (const code of codes) requiredCodes.add(code); + } + const codeList = [...requiredCodes]; + + const existing = await db + .select({ code: permissions.code, id: permissions.id }) + .from(permissions) + .where(inArray(permissions.code, codeList)); + + const existingCodes = new Set(existing.map((p) => p.code)); + const missing = codeList.filter((code) => !existingCodes.has(code)); + + if (missing.length > 0) { + await db + .insert(permissions) + .values( + missing.map((code) => ({ + category: codeToCategory(code), + code, + isActive: true, + name: codeToName(code), + })), + ) + .onConflictDoNothing({ target: permissions.code }); + } + + const all = await db + .select({ code: permissions.code, id: permissions.id }) + .from(permissions) + .where(inArray(permissions.code, codeList)); + + return new Map(all.map((p) => [p.code, p.id] as const)); +}; + +/** + * Create the workspace's copy of one built-in role (e.g. `workspace_owner`) + * if it doesn't already exist, then ensure the role-permission links match + * the spec in `WORKSPACE_ROLE_PERMISSIONS`. Returns the role id. + * + * The role is uniquely identified by `(name, workspaceId)` via the + * `rbac_roles_name_scope_unique` index — `onConflictDoNothing` makes the + * insert idempotent. + */ +const upsertWorkspaceRole = async ( + db: LobeChatDatabase, + workspaceId: string, + roleName: WorkspaceSystemRoleName, + permissionIdByCode: Map, +): Promise => { + const existing = await db.query.roles.findFirst({ + where: and(eq(roles.name, roleName), eq(roles.workspaceId, workspaceId)), + }); + + let roleId: string; + if (existing) { + roleId = existing.id; + } else { + const [inserted] = await db + .insert(roles) + .values({ + description: WORKSPACE_ROLE_DESCRIPTIONS[roleName], + displayName: WORKSPACE_ROLE_DISPLAY_NAMES[roleName], + isActive: true, + isSystem: true, + name: roleName, + workspaceId, + }) + .returning({ id: roles.id }); + roleId = inserted.id; + } + + const targetCodes = WORKSPACE_ROLE_PERMISSIONS[roleName]; + const targetIds = targetCodes + .map((code) => permissionIdByCode.get(code)) + .filter((id): id is string => !!id); + + if (targetIds.length === 0) return roleId; + + // Insert missing links; ON CONFLICT DO NOTHING handles re-seed. + await db + .insert(rolePermissions) + .values(targetIds.map((permissionId) => ({ permissionId, roleId }))) + .onConflictDoNothing(); + + return roleId; +}; + +export interface SeededWorkspaceRoles { + memberRoleId: string; + ownerRoleId: string; + viewerRoleId: string; +} + +/** + * Idempotently provision the three built-in workspace roles plus all + * permissions they depend on. + * + * Safe to call: + * - On `workspace.create` (new workspace gets its role triplet) + * - From a bootstrap script over every existing workspace (backfill) + * - Re-run on the same workspace (no-op after the first run) + */ +export const seedWorkspaceRoles = async ( + db: LobeChatDatabase, + workspaceId: string, +): Promise => { + const permissionIdByCode = await ensurePermissionsExist(db); + const ownerRoleId = await upsertWorkspaceRole( + db, + workspaceId, + WORKSPACE_SYSTEM_ROLES.OWNER, + permissionIdByCode, + ); + const memberRoleId = await upsertWorkspaceRole( + db, + workspaceId, + WORKSPACE_SYSTEM_ROLES.MEMBER, + permissionIdByCode, + ); + const viewerRoleId = await upsertWorkspaceRole( + db, + workspaceId, + WORKSPACE_SYSTEM_ROLES.VIEWER, + permissionIdByCode, + ); + + return { memberRoleId, ownerRoleId, viewerRoleId }; +}; + +/** + * Grant `userId` the named built-in workspace role. Idempotent — re-grants + * are no-ops thanks to the `(user_id, role_id, workspace_id)` unique index. + * + * Used by: + * - `workspace.create` — assign creator to `workspace_owner` + * - `workspaceMember.invite` accept flow — assign the invited role + * - Backfill — translate every existing `workspace_members.role` row into a + * matching `rbac_user_roles` row + */ +export const assignWorkspaceRoleToUser = async ( + db: LobeChatDatabase, + params: { + roleName: WorkspaceSystemRoleName; + userId: string; + workspaceId: string; + }, +): Promise => { + const role = await db.query.roles.findFirst({ + where: and(eq(roles.name, params.roleName), eq(roles.workspaceId, params.workspaceId)), + }); + + if (!role) { + throw new Error( + `Workspace role ${params.roleName} not found for workspace ${params.workspaceId}. ` + + `Call seedWorkspaceRoles first.`, + ); + } + + await db + .insert(userRoles) + .values({ roleId: role.id, userId: params.userId, workspaceId: params.workspaceId }) + .onConflictDoNothing(); +}; + +/** + * Revoke every workspace-scoped role for `(userId, workspaceId)`. Used by + * `workspaceMember.remove / leave` and when changing a user's role (followed + * by a fresh `assignWorkspaceRoleToUser` call). + */ +export const revokeWorkspaceRolesForUser = async ( + db: LobeChatDatabase, + params: { userId: string; workspaceId: string }, +): Promise => { + await db + .delete(userRoles) + .where(and(eq(userRoles.userId, params.userId), eq(userRoles.workspaceId, params.workspaceId))); +}; diff --git a/packages/database/src/utils/workspace.test.ts b/packages/database/src/utils/workspace.test.ts new file mode 100644 index 0000000000..7ccfb4bf7b --- /dev/null +++ b/packages/database/src/utils/workspace.test.ts @@ -0,0 +1,45 @@ +import { PgDialect } from 'drizzle-orm/pg-core'; +import { describe, expect, it } from 'vitest'; + +import { agents } from '../schemas/agent'; +import { buildWorkspacePayload, buildWorkspaceWhere } from './workspace'; + +describe('workspace utils', () => { + describe('buildWorkspaceWhere', () => { + it('scopes personal reads by user and null workspace', () => { + const condition = buildWorkspaceWhere({ userId: 'user-1' }, agents); + const built = new PgDialect().sqlToQuery(condition); + + expect(built.sql).toBe('("agents"."user_id" = $1 and "agents"."workspace_id" is null)'); + expect(built.params).toStrictEqual(['user-1']); + }); + + it('scopes workspace reads by workspace id only', () => { + const condition = buildWorkspaceWhere({ userId: 'user-1', workspaceId: 'ws-1' }, agents); + const built = new PgDialect().sqlToQuery(condition); + + expect(built.sql).toBe('"agents"."workspace_id" = $1'); + expect(built.params).toStrictEqual(['ws-1']); + }); + }); + + describe('buildWorkspacePayload', () => { + it('writes personal payloads with a null workspace id', () => { + expect(buildWorkspacePayload({ userId: 'user-1' }, { title: 'Personal agent' })).toEqual({ + title: 'Personal agent', + userId: 'user-1', + workspaceId: null, + }); + }); + + it('writes workspace payloads with creator and workspace id', () => { + expect( + buildWorkspacePayload({ userId: 'user-1', workspaceId: 'ws-1' }, { title: 'Team agent' }), + ).toEqual({ + title: 'Team agent', + userId: 'user-1', + workspaceId: 'ws-1', + }); + }); + }); +}); diff --git a/packages/database/src/utils/workspace.ts b/packages/database/src/utils/workspace.ts new file mode 100644 index 0000000000..c510a0fffd --- /dev/null +++ b/packages/database/src/utils/workspace.ts @@ -0,0 +1,70 @@ +import { and, eq, isNull, type SQL } from 'drizzle-orm'; +import type { AnyPgColumn } from 'drizzle-orm/pg-core'; + +/** + * Workspace-aware ownership predicate for content tables. + * + * Compat mode semantics: + * - `ctx.workspaceId` set → row belongs to that team workspace (shared with all + * members; `user_id` only records the creator and isn't part of the filter) + * - `ctx.workspaceId` absent → personal mode: row belongs to a single user with + * `workspace_id IS NULL` + * + * Used by content router models (agent / session / message / file / topic …) + * to replace the previous `userId = ?` only filter. + * + * @example Model-side + * ```ts + * import { buildWorkspaceWhere } from '../utils/workspace'; + * + * class AgentModel { + * constructor(db, userId, workspaceId) { ... } + * + * findById = (id) => + * this.db.query.agents.findFirst({ + * where: and( + * eq(agents.id, id), + * buildWorkspaceWhere( + * { userId: this.userId, workspaceId: this.workspaceId }, + * agents, + * ), + * ), + * }); + * } + * ``` + */ +export function buildWorkspaceWhere( + ctx: { userId: string; workspaceId?: string }, + cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }, +): SQL { + return ctx.workspaceId + ? eq(cols.workspaceId, ctx.workspaceId) + : (and(eq(cols.userId, ctx.userId), isNull(cols.workspaceId)) as SQL); +} + +/** + * Companion to `buildWorkspaceWhere` for INSERT payloads. + * + * Always sets `userId` (the creator) and `workspaceId` (nullable). Personal-mode + * writes get `workspaceId: null`; team-mode writes get the workspace id. + * + * @example + * ```ts + * await db.insert(agents).values( + * buildWorkspacePayload( + * { userId: ctx.userId, workspaceId: ctx.workspaceId }, + * { title: input.title, description: input.description }, + * ), + * ); + * ``` + */ +export function buildWorkspacePayload( + ctx: { userId: string; workspaceId?: string }, + base: T, +): T & { userId: string; workspaceId: string | null } { + return { + ...base, + userId: ctx.userId, + workspaceId: ctx.workspaceId ?? null, + }; +} diff --git a/packages/openapi/src/app.ts b/packages/openapi/src/app.ts index 15099f8fdc..48665e13af 100644 --- a/packages/openapi/src/app.ts +++ b/packages/openapi/src/app.ts @@ -5,6 +5,7 @@ import { prettyJSON } from 'hono/pretty-json'; // Import user authentication middleware (supports both OIDC and API Key authentication) import { userAuthMiddleware } from './middleware/auth'; +import { workspaceAuthMiddleware } from './middleware/workspace'; // Import routes import routes from './routes'; @@ -16,6 +17,7 @@ app.use('*', cors()); app.use('*', logger()); app.use('*', prettyJSON()); app.use('*', userAuthMiddleware); // User authentication middleware +app.use('*', workspaceAuthMiddleware); // Error handling middleware app.onError((error: Error, c) => { diff --git a/packages/openapi/src/common/base.controller.ts b/packages/openapi/src/common/base.controller.ts index f1ebbf5c64..333d5cd5c4 100644 --- a/packages/openapi/src/common/base.controller.ts +++ b/packages/openapi/src/common/base.controller.ts @@ -154,6 +154,10 @@ export abstract class BaseController { return c.get('userId') || null; } + protected getWorkspaceId(c: Context): string | undefined { + return c.get('workspaceId') || undefined; + } + /** * Get authentication type (from context set by middleware) * @param c Hono Context @@ -194,7 +198,10 @@ export abstract class BaseController { ): Promise { const rbacModel = await this.getRbacModel(c); - return await rbacModel.hasPermission(RBAC_PERMISSIONS[permissionKey], this.getUserId(c)!); + return await rbacModel.hasPermission(RBAC_PERMISSIONS[permissionKey], { + userId: this.getUserId(c)!, + workspaceId: this.getWorkspaceId(c), + }); } /** @@ -212,7 +219,8 @@ export abstract class BaseController { const hasPermission = await this.hasPermission(c, permission); if (!hasPermission) { throw new HTTPException(403, { - message: errorMessage || `You do not have permission to perform this operation: ${permission}`, + message: + errorMessage || `You do not have permission to perform this operation: ${permission}`, }); } } @@ -230,7 +238,10 @@ export abstract class BaseController { const permissions = permissionKeys.map((permission) => RBAC_PERMISSIONS[permission]); const rbacModel = await this.getRbacModel(c); - return await rbacModel.hasAnyPermission(permissions, this.getUserId(c)!); + return await rbacModel.hasAnyPermission(permissions, { + userId: this.getUserId(c)!, + workspaceId: this.getWorkspaceId(c), + }); } /** diff --git a/packages/openapi/src/common/base.service.test.ts b/packages/openapi/src/common/base.service.test.ts new file mode 100644 index 0000000000..f62afd8984 --- /dev/null +++ b/packages/openapi/src/common/base.service.test.ts @@ -0,0 +1,158 @@ +// @vitest-environment node +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { LobeChatDatabase } from '@/database/type'; + +import { BaseService } from './base.service'; + +const { + mockBuildWorkspacePayload, + mockBuildWorkspaceWhere, + mockGetScopePermissions, + mockHasAnyPermission, +} = vi.hoisted(() => ({ + mockBuildWorkspacePayload: vi.fn(), + mockBuildWorkspaceWhere: vi.fn(), + mockGetScopePermissions: vi.fn(), + mockHasAnyPermission: vi.fn(), +})); + +vi.mock('@lobechat/database', () => ({ + buildWorkspacePayload: mockBuildWorkspacePayload, + buildWorkspaceWhere: mockBuildWorkspaceWhere, +})); + +vi.mock('@/const/rbac', () => ({ + ALL_SCOPE: 'all', +})); + +vi.mock('@/database/models/rbac', () => ({ + RbacModel: class { + hasAnyPermission = mockHasAnyPermission; + }, +})); + +vi.mock('@/database/schemas', () => ({ + agents: {}, + aiModels: {}, + aiProviders: {}, + files: {}, + knowledgeBases: {}, + messages: {}, + sessions: {}, + topics: {}, +})); + +vi.mock('@/utils/rbac', () => ({ + getScopePermissions: mockGetScopePermissions, +})); + +class TestService extends BaseService { + workspaceWhere(cols: Parameters[0]) { + return this.buildWorkspaceWhere(cols); + } + + workspacePayload(base: T) { + return this.buildWorkspacePayload(base); + } + + permissionWhere( + cols: Parameters[0], + condition?: Parameters[1], + ) { + return this.buildPermissionWhere(cols, condition); + } + + globalPermission(permissionKey: Parameters[0]) { + return this.hasGlobalPermission(permissionKey); + } + + ownerPermission(permissionKey: Parameters[0]) { + return this.hasOwnerPermission(permissionKey); + } +} + +const cols = { + userId: 'table.userId', + workspaceId: 'table.workspaceId', +} as any; + +describe('BaseService workspace helpers', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockBuildWorkspaceWhere.mockReturnValue('workspace-where'); + mockBuildWorkspacePayload.mockImplementation((context, base) => ({ + ...base, + userId: context.userId, + workspaceId: context.workspaceId ?? null, + })); + mockGetScopePermissions.mockReturnValue(['resolved-permission']); + mockHasAnyPermission.mockResolvedValue(true); + }); + + it('builds workspace where conditions from the current service context', () => { + const service = new TestService({} as LobeChatDatabase, 'user-1', 'workspace-1'); + + expect(service.workspaceWhere(cols)).toBe('workspace-where'); + expect(mockBuildWorkspaceWhere).toHaveBeenCalledWith( + { userId: 'user-1', workspaceId: 'workspace-1' }, + cols, + ); + }); + + it('builds insert payloads with workspace ownership fields', () => { + const service = new TestService({} as LobeChatDatabase, 'user-1', 'workspace-1'); + + expect(service.workspacePayload({ name: 'Provider' })).toEqual({ + name: 'Provider', + userId: 'user-1', + workspaceId: 'workspace-1', + }); + expect(mockBuildWorkspacePayload).toHaveBeenCalledWith( + { userId: 'user-1', workspaceId: 'workspace-1' }, + { name: 'Provider' }, + ); + }); + + it('keeps permission checks scoped to the active workspace owner context', () => { + const service = new TestService({} as LobeChatDatabase, 'user-1', 'workspace-1'); + + expect(service.permissionWhere(cols, { userId: 'other-user' })).toBe('workspace-where'); + expect(mockBuildWorkspaceWhere).toHaveBeenCalledWith( + { userId: 'user-1', workspaceId: 'workspace-1' }, + cols, + ); + }); + + it('uses the requested owner condition in personal context', () => { + const service = new TestService({} as LobeChatDatabase, 'user-1'); + + expect(service.permissionWhere(cols, { userId: 'other-user' })).toBe('workspace-where'); + expect(mockBuildWorkspaceWhere).toHaveBeenCalledWith({ userId: 'other-user' }, cols); + }); + + it('does not add a permission where clause in personal context without an owner condition', () => { + const service = new TestService({} as LobeChatDatabase, 'user-1'); + + expect(service.permissionWhere(cols)).toBeUndefined(); + expect(mockBuildWorkspaceWhere).not.toHaveBeenCalled(); + }); + + it('passes workspace context into global and owner RBAC permission checks', async () => { + const service = new TestService({} as LobeChatDatabase, 'user-1', 'workspace-1'); + + await expect(service.globalPermission('agents:create' as any)).resolves.toBe(true); + await expect(service.ownerPermission('agents:update' as any)).resolves.toBe(true); + + expect(mockGetScopePermissions).toHaveBeenNthCalledWith(1, 'agents:create', ['ALL']); + expect(mockGetScopePermissions).toHaveBeenNthCalledWith(2, 'agents:update', ['OWNER']); + expect(mockHasAnyPermission).toHaveBeenNthCalledWith(1, ['resolved-permission'], { + userId: 'user-1', + workspaceId: 'workspace-1', + }); + expect(mockHasAnyPermission).toHaveBeenNthCalledWith(2, ['resolved-permission'], { + userId: 'user-1', + workspaceId: 'workspace-1', + }); + }); +}); diff --git a/packages/openapi/src/common/base.service.ts b/packages/openapi/src/common/base.service.ts index c89d1e660e..15c84c3a17 100644 --- a/packages/openapi/src/common/base.service.ts +++ b/packages/openapi/src/common/base.service.ts @@ -1,4 +1,6 @@ -import { and, eq, inArray } from 'drizzle-orm'; +import { buildWorkspacePayload, buildWorkspaceWhere } from '@lobechat/database'; +import { and, eq, inArray, type SQL } from 'drizzle-orm'; +import type { AnyPgColumn } from 'drizzle-orm/pg-core'; import type { PERMISSION_ACTIONS } from '@/const/rbac'; import { ALL_SCOPE } from '@/const/rbac'; @@ -31,15 +33,37 @@ const isNilOrEmptyObject = (value: unknown): boolean => { */ export abstract class BaseService implements IBaseService { protected userId: string; + protected workspaceId?: string; public db: LobeChatDatabase; private rbacModel: RbacModel; - constructor(db: LobeChatDatabase, userId: string | null) { + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { this.db = db; this.userId = userId || ''; + this.workspaceId = workspaceId; this.rbacModel = new RbacModel(db, this.userId); } + protected buildWorkspaceWhere(cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }): SQL { + return buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, cols); + } + + protected buildWorkspacePayload( + base: T, + ): T & { userId: string; workspaceId: string | null } { + return buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, base); + } + + protected buildPermissionWhere( + cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }, + condition?: { userId?: string }, + ): SQL | undefined { + if (this.workspaceId) + return buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, cols); + if (condition?.userId) return buildWorkspaceWhere({ userId: condition.userId }, cols); + return; + } + /** * Business error class */ @@ -158,10 +182,10 @@ export abstract class BaseService implements IBaseService { protected async hasGlobalPermission( permissionKey: keyof typeof PERMISSION_ACTIONS, ): Promise { - return await this.rbacModel.hasAnyPermission( - getScopePermissions(permissionKey, ['ALL']), - this.userId, - ); + return await this.rbacModel.hasAnyPermission(getScopePermissions(permissionKey, ['ALL']), { + userId: this.userId, + workspaceId: this.workspaceId, + }); } /** @@ -172,10 +196,10 @@ export abstract class BaseService implements IBaseService { protected async hasOwnerPermission( permissionKey: keyof typeof PERMISSION_ACTIONS, ): Promise { - return await this.rbacModel.hasAnyPermission( - getScopePermissions(permissionKey, ['OWNER']), - this.userId, - ); + return await this.rbacModel.hasAnyPermission(getScopePermissions(permissionKey, ['OWNER']), { + userId: this.userId, + workspaceId: this.workspaceId, + }); } /** @@ -339,7 +363,11 @@ export abstract class BaseService implements IBaseService { * When the user has ALL permission, pass the check directly */ if (hasGlobalAccess) { - this.log('info', `Permission granted: current user has highest ${permissionKey} permission`, logContext); + this.log( + 'info', + `Permission granted: current user has highest ${permissionKey} permission`, + logContext, + ); return { condition: resourceBelongTo ? { userId: resourceBelongTo } : undefined, isPermitted: true, @@ -352,7 +380,11 @@ export abstract class BaseService implements IBaseService { * 2. Querying a specific user's data, but the target resource does not belong to the current user */ if (!resourceBelongTo || resourceBelongTo !== this.userId) { - this.log('warn', 'Permission denied: current user has no ALL permission, or target resource does not belong to current user', logContext); + this.log( + 'warn', + 'Permission denied: current user has no ALL permission, or target resource does not belong to current user', + logContext, + ); return { isPermitted: false, message: `no permission,current user has no ALL permission,and resource not belong to current user`, @@ -375,7 +407,11 @@ export abstract class BaseService implements IBaseService { }; } - this.log('warn', 'Permission denied: target resource belongs to current user, but user has no owner permission for this operation', logContext); + this.log( + 'warn', + 'Permission denied: target resource belongs to current user, but user has no owner permission for this operation', + logContext, + ); return { isPermitted: false, message: `no permission,resource belong to current user,but current user has no any ${permissionKey} permission`, @@ -414,7 +450,10 @@ export abstract class BaseService implements IBaseService { // If the user has global permission, allow the batch operation directly if (hasGlobalAccess) { - this.log('info', `Permission granted: batch operation, current user has ${permissionKey} ALL permission`); + this.log( + 'info', + `Permission granted: batch operation, current user has ${permissionKey} ALL permission`, + ); return { isPermitted: true }; } @@ -534,11 +573,15 @@ export abstract class BaseService implements IBaseService { } // If all resources belong to the current user but the user has no owner permission, deny the operation - this.log('warn', 'Permission denied: batch operation requires ${permissionKey} ALL/owner permission', { - permissionKey, - targetInfoIds, - userIds, - }); + this.log( + 'warn', + 'Permission denied: batch operation requires ${permissionKey} ALL/owner permission', + { + permissionKey, + targetInfoIds, + userIds, + }, + ); return { isPermitted: false, message: `no permission for batch operation, current user has no ${permissionKey} ALL/owner permission`, @@ -546,11 +589,15 @@ export abstract class BaseService implements IBaseService { } // Some resources in the operation do not belong to the current user; deny directly - this.log('warn', `Permission denied: batch operation requires ${permissionKey} ALL/owner permission`, { - permissionKey, - targetInfoIds, - userIds, - }); + this.log( + 'warn', + `Permission denied: batch operation requires ${permissionKey} ALL/owner permission`, + { + permissionKey, + targetInfoIds, + userIds, + }, + ); return { isPermitted: false, diff --git a/packages/openapi/src/controllers/agent-group.controller.ts b/packages/openapi/src/controllers/agent-group.controller.ts index df5796feff..29681da3ea 100644 --- a/packages/openapi/src/controllers/agent-group.controller.ts +++ b/packages/openapi/src/controllers/agent-group.controller.ts @@ -22,7 +22,11 @@ export class AgentGroupController extends BaseController { async getAgentGroups(c: Context): Promise { try { const db = await this.getDatabase(); - const agentGroupService = new AgentGroupService(db, this.getUserId(c)); + const agentGroupService = new AgentGroupService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); const agentGroups = await agentGroupService.getAgentGroups(); return this.success(c, agentGroups, 'Agent category list retrieved successfully'); @@ -46,7 +50,11 @@ export class AgentGroupController extends BaseController { } const db = await this.getDatabase(); - const agentGroupService = new AgentGroupService(db, this.getUserId(c)); + const agentGroupService = new AgentGroupService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); const agentGroup = await agentGroupService.getAgentGroupById(groupId); if (!agentGroup) { @@ -70,7 +78,11 @@ export class AgentGroupController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const agentGroupService = new AgentGroupService(db, this.getUserId(c)); + const agentGroupService = new AgentGroupService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); const groupId = await agentGroupService.createAgentGroup(body); return c.json( @@ -108,7 +120,11 @@ export class AgentGroupController extends BaseController { }; const db = await this.getDatabase(); - const agentGroupService = new AgentGroupService(db, this.getUserId(c)); + const agentGroupService = new AgentGroupService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); await agentGroupService.updateAgentGroup(request); return this.success(c, null, 'Agent category updated successfully'); @@ -136,7 +152,11 @@ export class AgentGroupController extends BaseController { }; const db = await this.getDatabase(); - const agentGroupService = new AgentGroupService(db, this.getUserId(c)); + const agentGroupService = new AgentGroupService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); await agentGroupService.deleteAgentGroup(request); return this.success(c, null, 'Agent category deleted successfully'); diff --git a/packages/openapi/src/controllers/agent.controller.ts b/packages/openapi/src/controllers/agent.controller.ts index a2bc2d8adf..c899b3075e 100644 --- a/packages/openapi/src/controllers/agent.controller.ts +++ b/packages/openapi/src/controllers/agent.controller.ts @@ -25,7 +25,7 @@ export class AgentController extends BaseController { const request = await this.getQuery(c); const db = await this.getDatabase(); - const agentService = new AgentService(db, this.getUserId(c)); + const agentService = new AgentService(db, this.getUserId(c), this.getWorkspaceId(c)); const agentsList = await agentService.queryAgents(request); return this.success(c, agentsList, 'Agent list retrieved successfully'); @@ -45,7 +45,7 @@ export class AgentController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const agentService = new AgentService(db, this.getUserId(c)); + const agentService = new AgentService(db, this.getUserId(c), this.getWorkspaceId(c)); const createdAgent = await agentService.createAgent(body); return this.success(c, createdAgent, 'Agent created successfully'); @@ -71,7 +71,7 @@ export class AgentController extends BaseController { }; const db = await this.getDatabase(); - const agentService = new AgentService(db, this.getUserId(c)); + const agentService = new AgentService(db, this.getUserId(c), this.getWorkspaceId(c)); const updatedAgent = await agentService.updateAgent(updateRequest); return this.success(c, updatedAgent, 'Agent updated successfully'); @@ -92,7 +92,7 @@ export class AgentController extends BaseController { const request: AgentDeleteRequest = { agentId: id }; const db = await this.getDatabase(); - const agentService = new AgentService(db, this.getUserId(c)); + const agentService = new AgentService(db, this.getUserId(c), this.getWorkspaceId(c)); await agentService.deleteAgent(request); return this.success(c, null, 'Agent deleted successfully'); @@ -111,7 +111,7 @@ export class AgentController extends BaseController { try { const { id: agentId } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const agentService = new AgentService(db, this.getUserId(c)); + const agentService = new AgentService(db, this.getUserId(c), this.getWorkspaceId(c)); const agent = await agentService.getAgentById(agentId); if (!agent) { diff --git a/packages/openapi/src/controllers/chat.controller.ts b/packages/openapi/src/controllers/chat.controller.ts index cce0c1d480..1d2ba6c89c 100644 --- a/packages/openapi/src/controllers/chat.controller.ts +++ b/packages/openapi/src/controllers/chat.controller.ts @@ -20,7 +20,7 @@ export class ChatController extends BaseController { const chatParams = (await this.getBody(c))!; const db = await this.getDatabase(); - const chatService = new ChatService(db, userId); + const chatService = new ChatService(db, userId, this.getWorkspaceId(c)); // If streaming response, return directly if (chatParams.stream) { @@ -45,7 +45,7 @@ export class ChatController extends BaseController { const translateParams = (await this.getBody(c))!; const db = await this.getDatabase(); - const chatService = new ChatService(db, userId); + const chatService = new ChatService(db, userId, this.getWorkspaceId(c)); const result = await chatService.translate(translateParams); return this.success(c, { translatedText: result }, 'Translation successful'); @@ -65,7 +65,7 @@ export class ChatController extends BaseController { const generationParams = (await this.getBody(c))!; const db = await this.getDatabase(); - const chatService = new ChatService(db, userId); + const chatService = new ChatService(db, userId, this.getWorkspaceId(c)); const result = await chatService.generateReply(generationParams); return this.success(c, { reply: result }, 'Reply generated successfully'); diff --git a/packages/openapi/src/controllers/file.controller.ts b/packages/openapi/src/controllers/file.controller.ts index 3da9ddfa53..42256ad0cf 100644 --- a/packages/openapi/src/controllers/file.controller.ts +++ b/packages/openapi/src/controllers/file.controller.ts @@ -27,7 +27,7 @@ export class FileController extends BaseController { const userId = this.getUserId(c)!; // requireAuth middleware ensures userId exists const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); // Process multipart/form-data (returns object: { fields, files }) const formData = await this.getFormData(c); @@ -87,7 +87,7 @@ export class FileController extends BaseController { const query = this.getQuery(c) as FileListQuery; const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.getFileList(query); @@ -106,7 +106,7 @@ export class FileController extends BaseController { const userId = this.getUserId(c)!; // requireAuth middleware ensures userId exists const { id } = this.getParams(c); const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.getFileDetail(id); @@ -132,7 +132,7 @@ export class FileController extends BaseController { }; const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.getFileUrl(id, options); @@ -151,7 +151,7 @@ export class FileController extends BaseController { const userId = this.getUserId(c)!; // requireAuth middleware ensures userId exists const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const formData = await this.getFormData(c); const file = formData.get('file') as File | null; @@ -201,7 +201,7 @@ export class FileController extends BaseController { }; const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.parseFile(id, options); @@ -222,7 +222,7 @@ export class FileController extends BaseController { const body = await this.getBody>(c); const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.createChunkTask(id, { autoEmbedding: body?.autoEmbedding, @@ -245,7 +245,7 @@ export class FileController extends BaseController { const { id } = this.getParams(c); const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.getFileChunkStatus(id); @@ -264,7 +264,7 @@ export class FileController extends BaseController { const userId = this.getUserId(c)!; // requireAuth middleware ensures userId exists const { id } = this.getParams(c); const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.deleteFile(id); @@ -288,7 +288,7 @@ export class FileController extends BaseController { } const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.handleQueries(body); @@ -309,7 +309,7 @@ export class FileController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.updateFile(id, body); diff --git a/packages/openapi/src/controllers/knowledge-base.controller.ts b/packages/openapi/src/controllers/knowledge-base.controller.ts index 443c456e02..612af2665b 100644 --- a/packages/openapi/src/controllers/knowledge-base.controller.ts +++ b/packages/openapi/src/controllers/knowledge-base.controller.ts @@ -27,7 +27,7 @@ export class KnowledgeBaseController extends BaseController { const query = this.getQuery(c) as KnowledgeBaseListQuery; const db = await this.getDatabase(); - const knowledgeBaseService = new KnowledgeBaseService(db, userId); + const knowledgeBaseService = new KnowledgeBaseService(db, userId, this.getWorkspaceId(c)); const result = await knowledgeBaseService.getKnowledgeBaseList(query); @@ -47,7 +47,7 @@ export class KnowledgeBaseController extends BaseController { const { id } = this.getParams(c); const db = await this.getDatabase(); - const knowledgeBaseService = new KnowledgeBaseService(db, userId); + const knowledgeBaseService = new KnowledgeBaseService(db, userId, this.getWorkspaceId(c)); const result = await knowledgeBaseService.getKnowledgeBaseDetail(id); @@ -68,7 +68,7 @@ export class KnowledgeBaseController extends BaseController { const query = this.getQuery(c) as KnowledgeBaseFileListQuery; const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.getKnowledgeBaseFileList(id, query); @@ -89,7 +89,7 @@ export class KnowledgeBaseController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.addFilesToKnowledgeBase(id, body); @@ -110,7 +110,7 @@ export class KnowledgeBaseController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.removeFilesFromKnowledgeBase(id, body); @@ -131,7 +131,7 @@ export class KnowledgeBaseController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const fileService = new FileUploadService(db, userId); + const fileService = new FileUploadService(db, userId, this.getWorkspaceId(c)); const result = await fileService.moveFilesBetweenKnowledgeBases(id, body); @@ -151,7 +151,7 @@ export class KnowledgeBaseController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const knowledgeBaseService = new KnowledgeBaseService(db, userId); + const knowledgeBaseService = new KnowledgeBaseService(db, userId, this.getWorkspaceId(c)); const result = await knowledgeBaseService.createKnowledgeBase(body); @@ -172,7 +172,7 @@ export class KnowledgeBaseController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const knowledgeBaseService = new KnowledgeBaseService(db, userId); + const knowledgeBaseService = new KnowledgeBaseService(db, userId, this.getWorkspaceId(c)); const result = await knowledgeBaseService.updateKnowledgeBase(id, body); @@ -192,7 +192,7 @@ export class KnowledgeBaseController extends BaseController { const { id } = this.getParams(c); const db = await this.getDatabase(); - const knowledgeBaseService = new KnowledgeBaseService(db, userId); + const knowledgeBaseService = new KnowledgeBaseService(db, userId, this.getWorkspaceId(c)); const result = await knowledgeBaseService.deleteKnowledgeBase(id); diff --git a/packages/openapi/src/controllers/message-translation.controller.ts b/packages/openapi/src/controllers/message-translation.controller.ts index 2d15367fd8..4f933b2996 100644 --- a/packages/openapi/src/controllers/message-translation.controller.ts +++ b/packages/openapi/src/controllers/message-translation.controller.ts @@ -20,7 +20,7 @@ export class MessageTranslationController extends BaseController { const { messageId } = this.getParams<{ messageId: string }>(c); const db = await this.getDatabase(); - const translateService = new MessageTranslateService(db, userId); + const translateService = new MessageTranslateService(db, userId, this.getWorkspaceId(c)); const translate = await translateService.getTranslateByMessageId(messageId); return this.success(c, translate, 'Translation info retrieved successfully'); @@ -41,7 +41,7 @@ export class MessageTranslationController extends BaseController { const translatePayload = (await this.getBody(c))!; const db = await this.getDatabase(); - const translateService = new MessageTranslateService(db, userId); + const translateService = new MessageTranslateService(db, userId, this.getWorkspaceId(c)); const result = await translateService.translateMessage({ messageId, ...translatePayload, @@ -65,7 +65,7 @@ export class MessageTranslationController extends BaseController { const configData = (await this.getBody(c))!; const db = await this.getDatabase(); - const translateService = new MessageTranslateService(db, userId); + const translateService = new MessageTranslateService(db, userId, this.getWorkspaceId(c)); const result = await translateService.updateTranslateInfo({ ...configData, messageId }); return this.success(c, result, 'Translation info updated successfully'); @@ -84,7 +84,7 @@ export class MessageTranslationController extends BaseController { const { messageId } = this.getParams<{ messageId: string }>(c); const db = await this.getDatabase(); - const translateService = new MessageTranslateService(db, userId); + const translateService = new MessageTranslateService(db, userId, this.getWorkspaceId(c)); const result = await translateService.deleteTranslateByMessageId(messageId); return this.success(c, result, 'Translation info deleted successfully'); diff --git a/packages/openapi/src/controllers/message.controller.ts b/packages/openapi/src/controllers/message.controller.ts index 64de38c034..c0d83db868 100644 --- a/packages/openapi/src/controllers/message.controller.ts +++ b/packages/openapi/src/controllers/message.controller.ts @@ -27,7 +27,7 @@ export class MessageController extends BaseController { }; const db = await this.getDatabase(); - const messageService = new MessageService(db, userId); + const messageService = new MessageService(db, userId, this.getWorkspaceId(c)); const result = await messageService.countMessages(processedQuery); return this.success(c, result, 'Message count retrieved successfully'); @@ -47,7 +47,7 @@ export class MessageController extends BaseController { const request = this.getQuery(c); const db = await this.getDatabase(); - const messageService = new MessageService(db, userId); + const messageService = new MessageService(db, userId, this.getWorkspaceId(c)); const result = await messageService.getMessages(request); return this.success(c, result, 'Message list retrieved successfully'); @@ -67,7 +67,7 @@ export class MessageController extends BaseController { const { id } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const messageService = new MessageService(db, userId); + const messageService = new MessageService(db, userId, this.getWorkspaceId(c)); const message = await messageService.getMessageById(id); if (!message) { @@ -91,7 +91,7 @@ export class MessageController extends BaseController { const messageData = (await this.getBody(c))!; const db = await this.getDatabase(); - const messageService = new MessageService(db, userId); + const messageService = new MessageService(db, userId, this.getWorkspaceId(c)); const result = await messageService.createMessage(messageData); return this.success(c, result, 'Message created successfully'); @@ -111,7 +111,7 @@ export class MessageController extends BaseController { const messageData = (await this.getBody(c))!; const db = await this.getDatabase(); - const messageService = new MessageService(db, userId); + const messageService = new MessageService(db, userId, this.getWorkspaceId(c)); const result = await messageService.createMessageWithAIReply(messageData); return this.success(c, result, 'Message created and AI reply generated successfully'); @@ -131,7 +131,7 @@ export class MessageController extends BaseController { const { id } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const messageService = new MessageService(db, userId); + const messageService = new MessageService(db, userId, this.getWorkspaceId(c)); await messageService.deleteMessage(id); return this.success(c, null, 'Message deleted successfully'); @@ -151,7 +151,7 @@ export class MessageController extends BaseController { const { messageIds } = (await this.getBody(c))!; const db = await this.getDatabase(); - const messageService = new MessageService(db, userId); + const messageService = new MessageService(db, userId, this.getWorkspaceId(c)); const result = await messageService.deleteBatchMessages(messageIds); return this.success(c, result, 'Messages deleted in batch successfully'); diff --git a/packages/openapi/src/controllers/model.controller.ts b/packages/openapi/src/controllers/model.controller.ts index 45195c68c4..67bc6d22db 100644 --- a/packages/openapi/src/controllers/model.controller.ts +++ b/packages/openapi/src/controllers/model.controller.ts @@ -15,7 +15,7 @@ export class ModelController extends BaseController { const query = this.getQuery(c); const db = await this.getDatabase(); - const modelService = new ModelService(db, this.getUserId(c)); + const modelService = new ModelService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await modelService.getModels(query); @@ -34,7 +34,7 @@ export class ModelController extends BaseController { const { providerId, modelId } = this.getParams<{ modelId: string; providerId: string }>(c); const db = await this.getDatabase(); - const modelService = new ModelService(db, this.getUserId(c)); + const modelService = new ModelService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await modelService.getModelDetail(providerId, modelId); return this.success(c, result, 'Model details retrieved successfully'); @@ -56,7 +56,7 @@ export class ModelController extends BaseController { } const db = await this.getDatabase(); - const modelService = new ModelService(db, this.getUserId(c)); + const modelService = new ModelService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await modelService.createModel(body); return this.success(c, result, 'Model created successfully'); @@ -79,7 +79,7 @@ export class ModelController extends BaseController { } const db = await this.getDatabase(); - const modelService = new ModelService(db, this.getUserId(c)); + const modelService = new ModelService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await modelService.updateModel(providerId, modelId, body); return this.success(c, result, 'Model updated successfully'); diff --git a/packages/openapi/src/controllers/permission.controller.ts b/packages/openapi/src/controllers/permission.controller.ts index d390824e54..1e52392519 100644 --- a/packages/openapi/src/controllers/permission.controller.ts +++ b/packages/openapi/src/controllers/permission.controller.ts @@ -21,7 +21,11 @@ export class PermissionController extends BaseController { const query = this.getQuery(c); const db = await this.getDatabase(); - const permissionService = new PermissionService(db, this.getUserId(c)); + const permissionService = new PermissionService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); const permissions = await permissionService.getPermissions(query); return this.success(c, permissions, 'Get permission list successfully'); @@ -37,7 +41,11 @@ export class PermissionController extends BaseController { try { const { id } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const permissionService = new PermissionService(db, this.getUserId(c)); + const permissionService = new PermissionService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); const permission = await permissionService.getPermissionById(id); if (!permission) { @@ -62,7 +70,11 @@ export class PermissionController extends BaseController { } const db = await this.getDatabase(); - const permissionService = new PermissionService(db, this.getUserId(c)); + const permissionService = new PermissionService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); const created = await permissionService.createPermission(body); return this.success(c, created, 'Permission created successfully'); @@ -84,7 +96,11 @@ export class PermissionController extends BaseController { } const db = await this.getDatabase(); - const permissionService = new PermissionService(db, this.getUserId(c)); + const permissionService = new PermissionService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); const updated = await permissionService.updatePermission(id, body); return this.success(c, updated, 'Permission updated successfully'); @@ -100,7 +116,11 @@ export class PermissionController extends BaseController { try { const { id } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const permissionService = new PermissionService(db, this.getUserId(c)); + const permissionService = new PermissionService( + db, + this.getUserId(c), + this.getWorkspaceId(c), + ); const result = await permissionService.deletePermission(id); return this.success(c, result, 'Permission deleted successfully'); diff --git a/packages/openapi/src/controllers/provider.controller.ts b/packages/openapi/src/controllers/provider.controller.ts index c4e9854aa5..d0d2afaf4a 100644 --- a/packages/openapi/src/controllers/provider.controller.ts +++ b/packages/openapi/src/controllers/provider.controller.ts @@ -20,7 +20,7 @@ export class ProviderController extends BaseController { try { const query = this.getQuery(c); const db = await this.getDatabase(); - const providerService = new ProviderService(db, this.getUserId(c)); + const providerService = new ProviderService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await providerService.getProviders(query); @@ -36,7 +36,7 @@ export class ProviderController extends BaseController { const request: GetProviderDetailRequest = { id }; const db = await this.getDatabase(); - const providerService = new ProviderService(db, this.getUserId(c)); + const providerService = new ProviderService(db, this.getUserId(c), this.getWorkspaceId(c)); const provider = await providerService.getProviderDetail(request); return this.success(c, provider, 'Provider details retrieved successfully'); @@ -50,7 +50,7 @@ export class ProviderController extends BaseController { const body = await this.getBody(c); const db = await this.getDatabase(); - const providerService = new ProviderService(db, this.getUserId(c)); + const providerService = new ProviderService(db, this.getUserId(c), this.getWorkspaceId(c)); const created = await providerService.createProvider({ ...body, source: 'custom' }); return this.success(c, created, 'Provider created successfully'); @@ -70,7 +70,7 @@ export class ProviderController extends BaseController { }; const db = await this.getDatabase(); - const providerService = new ProviderService(db, this.getUserId(c)); + const providerService = new ProviderService(db, this.getUserId(c), this.getWorkspaceId(c)); const updated = await providerService.updateProvider(request); return this.success(c, updated, 'Provider updated successfully'); @@ -85,7 +85,7 @@ export class ProviderController extends BaseController { const request: DeleteProviderRequest = { id }; const db = await this.getDatabase(); - const providerService = new ProviderService(db, this.getUserId(c)); + const providerService = new ProviderService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await providerService.deleteProvider(request); return this.success(c, result, 'Provider deleted successfully'); diff --git a/packages/openapi/src/controllers/responses.controller.ts b/packages/openapi/src/controllers/responses.controller.ts index 13b083778a..976541ed90 100644 --- a/packages/openapi/src/controllers/responses.controller.ts +++ b/packages/openapi/src/controllers/responses.controller.ts @@ -19,7 +19,7 @@ export class ResponsesController extends BaseController { const body = await this.getBody(c); const userId = this.getUserId(c); const db = await this.getDatabase(); - const service = new ResponsesService(db, userId); + const service = new ResponsesService(db, userId, this.getWorkspaceId(c)); if (body.stream) { return this.handleStreamingResponse(c, service, body); diff --git a/packages/openapi/src/controllers/role.controller.ts b/packages/openapi/src/controllers/role.controller.ts index 7566082ca2..6f0e097483 100644 --- a/packages/openapi/src/controllers/role.controller.ts +++ b/packages/openapi/src/controllers/role.controller.ts @@ -26,7 +26,7 @@ export class RoleController extends BaseController { const request = this.getQuery(c); const db = await this.getDatabase(); - const roleService = new RoleService(db, this.getUserId(c)); + const roleService = new RoleService(db, this.getUserId(c), this.getWorkspaceId(c)); const roles = await roleService.getRoles(request); return this.success(c, roles, 'Get roles list successfully'); @@ -47,7 +47,7 @@ export class RoleController extends BaseController { } const db = await this.getDatabase(); - const roleService = new RoleService(db, this.getUserId(c)); + const roleService = new RoleService(db, this.getUserId(c), this.getWorkspaceId(c)); const createdRole = await roleService.createRole(body); return this.success(c, createdRole, 'Role created successfully'); @@ -66,7 +66,7 @@ export class RoleController extends BaseController { const { id } = this.getParams<{ id: string }>(c); const roleId = id; const db = await this.getDatabase(); - const roleService = new RoleService(db, this.getUserId(c)); + const roleService = new RoleService(db, this.getUserId(c), this.getWorkspaceId(c)); const role = await roleService.getRoleById(roleId); if (!role) { @@ -90,7 +90,7 @@ export class RoleController extends BaseController { const request = this.getQuery(c); const db = await this.getDatabase(); - const roleService = new RoleService(db, this.getUserId(c)); + const roleService = new RoleService(db, this.getUserId(c), this.getWorkspaceId(c)); const permissions = await roleService.getRolePermissions({ roleId: id, ...request }); @@ -113,7 +113,7 @@ export class RoleController extends BaseController { } const db = await this.getDatabase(); - const roleService = new RoleService(db, this.getUserId(c)); + const roleService = new RoleService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await roleService.updateRolePermissions(id, body); return this.success(c, result, 'Role permissions updated successfully'); @@ -131,7 +131,7 @@ export class RoleController extends BaseController { const { id } = this.getParams<{ id: string }>(c); const roleId = id; const db = await this.getDatabase(); - const roleService = new RoleService(db, this.getUserId(c)); + const roleService = new RoleService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await roleService.clearRolePermissions(roleId); return this.success(c, result, 'Role permissions cleared'); @@ -156,7 +156,7 @@ export class RoleController extends BaseController { } const db = await this.getDatabase(); - const roleService = new RoleService(db, this.getUserId(c)); + const roleService = new RoleService(db, this.getUserId(c), this.getWorkspaceId(c)); const updatedRole = await roleService.updateRole(id, body); return this.success(c, updatedRole, 'Role updated successfully'); @@ -172,7 +172,7 @@ export class RoleController extends BaseController { try { const { id } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const roleService = new RoleService(db, this.getUserId(c)); + const roleService = new RoleService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await roleService.deleteRole(id); return this.success(c, result, 'Role deleted successfully'); diff --git a/packages/openapi/src/controllers/topic.controller.ts b/packages/openapi/src/controllers/topic.controller.ts index 5f8be1a810..118b1275f0 100644 --- a/packages/openapi/src/controllers/topic.controller.ts +++ b/packages/openapi/src/controllers/topic.controller.ts @@ -16,7 +16,7 @@ export class TopicController extends BaseController { const request = this.getQuery(c); const db = await this.getDatabase(); - const topicService = new TopicService(db, userId); + const topicService = new TopicService(db, userId, this.getWorkspaceId(c)); const topics = await topicService.getTopics(request); @@ -37,7 +37,7 @@ export class TopicController extends BaseController { const { id } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const topicService = new TopicService(db, userId); + const topicService = new TopicService(db, userId, this.getWorkspaceId(c)); const topic = await topicService.getTopicById(id); return this.success(c, topic, 'Topic retrieved successfully'); @@ -57,7 +57,7 @@ export class TopicController extends BaseController { const payload = await this.getBody(c); const db = await this.getDatabase(); - const topicService = new TopicService(db, userId); + const topicService = new TopicService(db, userId, this.getWorkspaceId(c)); const newTopic = await topicService.createTopic(payload); return this.success(c, newTopic, 'Topic created successfully'); @@ -78,7 +78,7 @@ export class TopicController extends BaseController { const payload = await this.getBody(c); const db = await this.getDatabase(); - const topicService = new TopicService(db, userId); + const topicService = new TopicService(db, userId, this.getWorkspaceId(c)); const updatedTopic = await topicService.updateTopic(id, payload); return this.success(c, updatedTopic, 'Topic updated successfully'); @@ -98,7 +98,7 @@ export class TopicController extends BaseController { const { id: topicId } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const topicService = new TopicService(db, userId); + const topicService = new TopicService(db, userId, this.getWorkspaceId(c)); await topicService.deleteTopic(topicId); return this.success(c, null, 'Topic deleted successfully'); diff --git a/packages/openapi/src/controllers/user.controller.ts b/packages/openapi/src/controllers/user.controller.ts index b16f9949d3..6a8bc24001 100644 --- a/packages/openapi/src/controllers/user.controller.ts +++ b/packages/openapi/src/controllers/user.controller.ts @@ -26,7 +26,7 @@ export class UserController extends BaseController { // Get database connection and create service instance const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const userInfo = await userService.getCurrentUser(includeCount); return this.success(c, userInfo, 'User info retrieved successfully'); @@ -46,7 +46,7 @@ export class UserController extends BaseController { // Get database connection and create service instance const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const userList = await userService.queryUsers(request); @@ -67,7 +67,7 @@ export class UserController extends BaseController { // Get database connection and create service instance const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const newUser = await userService.createUser(userData); return this.success(c, newUser, 'User created successfully'); @@ -87,7 +87,7 @@ export class UserController extends BaseController { // Get database connection and create service instance const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const user = await userService.getUserById(id); return this.success(c, user, 'User info retrieved successfully'); @@ -108,7 +108,7 @@ export class UserController extends BaseController { // Get database connection and create service instance const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const updatedUser = await userService.updateUser(id, userData); return this.success(c, updatedUser, 'User info updated successfully'); @@ -128,7 +128,7 @@ export class UserController extends BaseController { // Get database connection and create service instance const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await userService.deleteUser(id); return this.success(c, result, 'User deleted successfully'); @@ -154,7 +154,7 @@ export class UserController extends BaseController { // Get database connection and create service instance const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await userService.updateUserRoles(id, body); return this.success(c, result, 'User roles updated successfully'); @@ -172,7 +172,7 @@ export class UserController extends BaseController { const { id } = this.getParams<{ id: string }>(c); const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const result = await userService.clearUserRoles(id); return this.success(c, result, 'User roles cleared'); @@ -193,7 +193,7 @@ export class UserController extends BaseController { // Get database connection and create service instance const db = await this.getDatabase(); - const userService = new UserService(db, this.getUserId(c)); + const userService = new UserService(db, this.getUserId(c), this.getWorkspaceId(c)); const userRoles = await userService.getUserRoles(id); return this.success(c, userRoles, 'User roles retrieved successfully'); diff --git a/packages/openapi/src/middleware/auth.ts b/packages/openapi/src/middleware/auth.ts index ec1eaa117a..26299264b1 100644 --- a/packages/openapi/src/middleware/auth.ts +++ b/packages/openapi/src/middleware/auth.ts @@ -22,6 +22,7 @@ interface ApiKeyCacheEntry { expiresAt: Date | null; timestamp: number; userId: string; + workspaceId?: string | null; } // In-memory cache for API Key validation results @@ -67,6 +68,7 @@ export const userAuthMiddleware = async (c: Context, next: Next) => { let userId: string | null = null; let authType: string | null = null; let authData: any = null; + let workspaceId: string | null | undefined; // Try Bearer token authentication - check format first to determine type if (bearerToken) { @@ -92,6 +94,7 @@ export const userAuthMiddleware = async (c: Context, next: Next) => { userId = cachedEntry.userId; authType = 'apikey'; authData = { apiKeyId: cachedEntry.apiKeyId, apiKeyName: cachedEntry.apiKeyName }; + workspaceId = cachedEntry.workspaceId; log( 'API Key authentication successful (from cache), userId: %s, apiKeyId: %d', @@ -134,6 +137,7 @@ export const userAuthMiddleware = async (c: Context, next: Next) => { userId = apiKeyRecord.userId; authType = 'apikey'; authData = { apiKeyId: apiKeyRecord.id, apiKeyName: apiKeyRecord.name }; + workspaceId = apiKeyRecord.workspaceId; // Cache the validated API Key apiKeyCache.set(bearerToken, { @@ -142,6 +146,7 @@ export const userAuthMiddleware = async (c: Context, next: Next) => { expiresAt: apiKeyRecord.expiresAt, timestamp: now, userId: apiKeyRecord.userId, + workspaceId: apiKeyRecord.workspaceId, }); log( @@ -151,7 +156,11 @@ export const userAuthMiddleware = async (c: Context, next: Next) => { ); // Update last used timestamp (fire and forget) - const userApiKeyModel = new ApiKeyModel(db, apiKeyRecord.userId); + const userApiKeyModel = new ApiKeyModel( + db, + apiKeyRecord.userId, + apiKeyRecord.workspaceId ?? undefined, + ); userApiKeyModel.updateLastUsed(apiKeyRecord.id).catch((err) => { log('Failed to update API Key last used timestamp: %O', err); }); @@ -197,6 +206,7 @@ export const userAuthMiddleware = async (c: Context, next: Next) => { c.set('authType', authType); c.set('authData', authData); c.set('authorizationHeader', authorizationHeader); + c.set('workspaceId', workspaceId ?? undefined); log('Authentication successful - userId: %s, authType: %s', userId, authType); } else { diff --git a/packages/openapi/src/middleware/index.ts b/packages/openapi/src/middleware/index.ts index c08bade8bd..2353a366f2 100644 --- a/packages/openapi/src/middleware/index.ts +++ b/packages/openapi/src/middleware/index.ts @@ -1,2 +1,3 @@ export * from './auth'; export * from './permission-check'; +export * from './workspace'; diff --git a/packages/openapi/src/middleware/permission-check.ts b/packages/openapi/src/middleware/permission-check.ts index 39f963c915..01545c42ab 100644 --- a/packages/openapi/src/middleware/permission-check.ts +++ b/packages/openapi/src/middleware/permission-check.ts @@ -72,6 +72,7 @@ const requirePermission = (options: PermissionCheckOptions) => { // Get database instance const serverDB = await getServerDB(); const rbacModel = new RbacModel(serverDB, userId); + const workspaceId = c.get('workspaceId') as string | undefined; let hasPermission = false; const operator = options.operator || 'OR'; @@ -80,9 +81,9 @@ const requirePermission = (options: PermissionCheckOptions) => { // Check permissions based on operator if (operator === 'AND') { - hasPermission = await rbacModel.hasAllPermissions(permissionCodes); + hasPermission = await rbacModel.hasAllPermissions(permissionCodes, { userId, workspaceId }); } else { - hasPermission = await rbacModel.hasAnyPermission(permissionCodes); + hasPermission = await rbacModel.hasAnyPermission(permissionCodes, { userId, workspaceId }); } if (!hasPermission) { diff --git a/packages/openapi/src/middleware/workspace.test.ts b/packages/openapi/src/middleware/workspace.test.ts new file mode 100644 index 0000000000..d686f49416 --- /dev/null +++ b/packages/openapi/src/middleware/workspace.test.ts @@ -0,0 +1,142 @@ +import { Hono } from 'hono'; +import { HTTPException } from 'hono/http-exception'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { OPENAPI_WORKSPACE_HEADER, workspaceAuthMiddleware } from './workspace'; + +interface TestHonoEnv { + Variables: { + userId: string | null; + workspaceId: string | undefined; + workspaceRole: string | undefined; + }; +} + +const { mockGetServerDB, mockWorkspaceMembersFindFirst, mockWorkspacesFindFirst } = vi.hoisted( + () => ({ + mockGetServerDB: vi.fn(), + mockWorkspaceMembersFindFirst: vi.fn(), + mockWorkspacesFindFirst: vi.fn(), + }), +); + +vi.mock('@/database/core/db-adaptor', () => ({ + getServerDB: mockGetServerDB, +})); + +vi.mock('@/database/schemas', () => ({ + workspaceMembers: { + deletedAt: 'workspaceMembers.deletedAt', + userId: 'workspaceMembers.userId', + workspaceId: 'workspaceMembers.workspaceId', + }, + workspaces: { + id: 'workspaces.id', + }, +})); + +const createApp = (userId: string | null = 'user-1') => { + const app = new Hono(); + + app.onError((error, c) => { + if (error instanceof HTTPException) return error.getResponse(); + + return c.text(error.message, 500); + }); + + app.use('*', async (c, next) => { + c.set('userId', userId); + await next(); + }); + app.use('*', workspaceAuthMiddleware); + app.get('/workspace', (c) => + c.json({ + workspaceId: c.get('workspaceId') ?? null, + workspaceRole: c.get('workspaceRole') ?? null, + }), + ); + + return app; +}; + +describe('OpenAPI workspace middleware', () => { + beforeEach(() => { + vi.clearAllMocks(); + + mockGetServerDB.mockResolvedValue({ + query: { + workspaceMembers: { + findFirst: mockWorkspaceMembersFindFirst, + }, + workspaces: { + findFirst: mockWorkspacesFindFirst, + }, + }, + }); + mockWorkspacesFindFirst.mockResolvedValue({ id: 'workspace-1' }); + mockWorkspaceMembersFindFirst.mockResolvedValue({ role: 'admin' }); + }); + + it('continues in personal context when the workspace header is absent', async () => { + const app = createApp(); + + const response = await app.request('/workspace'); + + await expect(response.json()).resolves.toEqual({ + workspaceId: null, + workspaceRole: null, + }); + expect(response.status).toBe(200); + expect(mockGetServerDB).not.toHaveBeenCalled(); + }); + + it('rejects workspace access when the request is unauthenticated', async () => { + const app = createApp(null); + + const response = await app.request('/workspace', { + headers: { [OPENAPI_WORKSPACE_HEADER]: 'workspace-1' }, + }); + + expect(response.status).toBe(401); + expect(mockGetServerDB).not.toHaveBeenCalled(); + }); + + it('rejects an unknown workspace', async () => { + const app = createApp(); + mockWorkspacesFindFirst.mockResolvedValueOnce(undefined); + + const response = await app.request('/workspace', { + headers: { [OPENAPI_WORKSPACE_HEADER]: 'workspace-missing' }, + }); + + expect(response.status).toBe(404); + expect(mockWorkspaceMembersFindFirst).not.toHaveBeenCalled(); + }); + + it('rejects workspace access when the user is not a member', async () => { + const app = createApp(); + mockWorkspaceMembersFindFirst.mockResolvedValueOnce(undefined); + + const response = await app.request('/workspace', { + headers: { [OPENAPI_WORKSPACE_HEADER]: 'workspace-1' }, + }); + + expect(response.status).toBe(403); + }); + + it('sets workspace context when the user is a workspace member', async () => { + const app = createApp(); + + const response = await app.request('/workspace', { + headers: { [OPENAPI_WORKSPACE_HEADER]: ' workspace-1 ' }, + }); + + await expect(response.json()).resolves.toEqual({ + workspaceId: 'workspace-1', + workspaceRole: 'admin', + }); + expect(response.status).toBe(200); + expect(mockWorkspacesFindFirst).toHaveBeenCalledTimes(1); + expect(mockWorkspaceMembersFindFirst).toHaveBeenCalledTimes(1); + }); +}); diff --git a/packages/openapi/src/middleware/workspace.ts b/packages/openapi/src/middleware/workspace.ts new file mode 100644 index 0000000000..8cad253520 --- /dev/null +++ b/packages/openapi/src/middleware/workspace.ts @@ -0,0 +1,60 @@ +import debug from 'debug'; +import { and, eq, isNull } from 'drizzle-orm'; +import type { Context, Next } from 'hono'; +import { HTTPException } from 'hono/http-exception'; + +import { getServerDB } from '@/database/core/db-adaptor'; +import { workspaceMembers, workspaces } from '@/database/schemas'; + +const log = debug('lobe-hono:workspace-middleware'); + +export const OPENAPI_WORKSPACE_HEADER = 'X-Workspace-Id'; + +export const workspaceAuthMiddleware = async (c: Context, next: Next) => { + const workspaceId = c.req.header(OPENAPI_WORKSPACE_HEADER)?.trim(); + + if (!workspaceId) { + c.set('workspaceId', undefined); + c.set('workspaceRole', undefined); + return next(); + } + + const userId = c.get('userId'); + if (!userId) { + throw new HTTPException(401, { + message: 'Authentication required for workspace access', + }); + } + + const serverDB = await getServerDB(); + const workspace = await serverDB.query.workspaces.findFirst({ + columns: { id: true }, + where: eq(workspaces.id, workspaceId), + }); + + if (!workspace) { + throw new HTTPException(404, { + message: 'Workspace not found', + }); + } + + const membership = await serverDB.query.workspaceMembers.findFirst({ + columns: { role: true }, + where: and( + eq(workspaceMembers.workspaceId, workspaceId), + eq(workspaceMembers.userId, userId), + isNull(workspaceMembers.deletedAt), + ), + }); + + if (!membership) { + log('Workspace membership check failed for user %s workspace %s', userId, workspaceId); + throw new HTTPException(403, { + message: 'Not a member of this workspace', + }); + } + + c.set('workspaceId', workspaceId); + c.set('workspaceRole', membership.role); + return next(); +}; diff --git a/packages/openapi/src/services/agent-group.service.ts b/packages/openapi/src/services/agent-group.service.ts index b7bf7b7e34..5ce32f5e82 100644 --- a/packages/openapi/src/services/agent-group.service.ts +++ b/packages/openapi/src/services/agent-group.service.ts @@ -20,9 +20,9 @@ import type { export class AgentGroupService extends BaseService { private sessionGroupModel: SessionGroupModel; - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); - this.sessionGroupModel = new SessionGroupModel(db, userId!); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); + this.sessionGroupModel = new SessionGroupModel(db, userId!, workspaceId); } /** @@ -42,9 +42,8 @@ export class AgentGroupService extends BaseService { // Build query conditions const conditions = []; - if (permissionResult.condition?.userId) { - conditions.push(eq(sessionGroups.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(sessionGroups, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); const agentGroupList = await this.db.query.sessionGroups.findMany({ orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)], @@ -77,9 +76,8 @@ export class AgentGroupService extends BaseService { // Build query conditions const conditions = [eq(sessionGroups.id, groupId)]; - if (permissionResult.condition?.userId) { - conditions.push(eq(sessionGroups.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(sessionGroups, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); const agentGroup = await this.db.query.sessionGroups.findFirst({ where: and(...conditions), @@ -116,7 +114,7 @@ export class AgentGroupService extends BaseService { .values({ name: request.name, sort: request.sort, - userId: this.userId, + ...this.buildWorkspacePayload({}), }) .returning(); @@ -157,7 +155,7 @@ export class AgentGroupService extends BaseService { await this.db .update(sessionGroups) .set({ ...updateData, updatedAt: new Date() }) - .where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId))); + .where(and(eq(sessionGroups.id, id), this.buildWorkspaceWhere(sessionGroups))); this.log('info', 'Agent group updated successfully', { id }); } catch (error) { @@ -188,9 +186,8 @@ export class AgentGroupService extends BaseService { // Build query conditions const conditions = [eq(sessionGroups.id, request.id)]; - if (permissionResult.condition?.userId) { - conditions.push(eq(sessionGroups.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(sessionGroups, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); // Delete agent group; the sessionGroupId of agents in the group will be automatically set to null via database foreign key constraint await this.db.delete(sessionGroups).where(and(...conditions)); diff --git a/packages/openapi/src/services/agent.service.ts b/packages/openapi/src/services/agent.service.ts index 72cbe2bf91..6ccd3b35cf 100644 --- a/packages/openapi/src/services/agent.service.ts +++ b/packages/openapi/src/services/agent.service.ts @@ -22,8 +22,8 @@ import type { * Agent service implementation class */ export class AgentService extends BaseService { - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); } /** @@ -40,7 +40,7 @@ export class AgentService extends BaseService { try { // Base filter: current user + exclude virtual agents (inbox, supervisor, etc.) const baseConditions = and( - eq(agents.userId, this.userId), + this.buildWorkspaceWhere(agents), or(eq(agents.virtual, false), isNull(agents.virtual)), ); @@ -94,7 +94,7 @@ export class AgentService extends BaseService { systemRole: request.systemRole || null, title: request.title, updatedAt: new Date(), - userId: this.userId, + ...this.buildWorkspacePayload({}), }; // Insert into database @@ -129,9 +129,8 @@ export class AgentService extends BaseService { return await this.db.transaction(async (tx) => { // Build query conditions const whereConditions = [eq(agents.id, request.id)]; - if (permissionResult.condition?.userId) { - whereConditions.push(eq(agents.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(agents, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); // Check if the Agent exists const existingAgent = await tx.query.agents.findFirst({ @@ -207,7 +206,7 @@ export class AgentService extends BaseService { // Check if the Agent to be deleted exists const targetAgent = await this.db.query.agents.findFirst({ - where: eq(agents.id, request.agentId), + where: and(eq(agents.id, request.agentId), this.buildWorkspaceWhere(agents)), }); if (!targetAgent) { @@ -217,7 +216,7 @@ export class AgentService extends BaseService { if (request.migrateSessionTo) { // Validate that the migration target Agent exists and belongs to the current user const migrateTarget = await this.db.query.agents.findFirst({ - where: and(eq(agents.id, request.migrateSessionTo), eq(agents.userId, this.userId)), + where: and(eq(agents.id, request.migrateSessionTo), this.buildWorkspaceWhere(agents)), }); if (!migrateTarget) { @@ -235,10 +234,10 @@ export class AgentService extends BaseService { // After migration, delete the agent itself directly; sessions have been transferred so cascade delete is not needed await this.db .delete(agents) - .where(and(eq(agents.id, request.agentId), eq(agents.userId, this.userId))); + .where(and(eq(agents.id, request.agentId), this.buildWorkspaceWhere(agents))); } else { // No migration: reuse AgentModel.delete, which cascades deletion of associated sessions, messages, topics, etc. - const agentModel = new AgentModel(this.db, this.userId); + const agentModel = new AgentModel(this.db, this.userId, this.workspaceId); await agentModel.delete(request.agentId); } @@ -271,7 +270,7 @@ export class AgentService extends BaseService { } // Reuse AgentModel methods to get the full Agent configuration - const agentModel = new AgentModel(this.db, this.userId); + const agentModel = new AgentModel(this.db, this.userId, this.workspaceId); const agent = await agentModel.getAgentConfigById(agentId); if (!agent || !agent.id) { @@ -303,7 +302,7 @@ export class AgentService extends BaseService { .where( and( eq(agentsToSessions.agentId, fromAgentId), - eq(agentsToSessions.userId, this.userId), + this.buildWorkspaceWhere(agentsToSessions), ), ); @@ -318,7 +317,7 @@ export class AgentService extends BaseService { .where( and( eq(agentsToSessions.agentId, fromAgentId), - eq(agentsToSessions.userId, this.userId), + this.buildWorkspaceWhere(agentsToSessions), ), ); @@ -341,7 +340,7 @@ export class AgentService extends BaseService { newSessionIds.map((sessionId) => ({ agentId: toAgentId, sessionId, - userId: this.userId, + ...this.buildWorkspacePayload({}), })), ); } diff --git a/packages/openapi/src/services/chat.service.ts b/packages/openapi/src/services/chat.service.ts index 7826bab4bd..d386a17448 100644 --- a/packages/openapi/src/services/chat.service.ts +++ b/packages/openapi/src/services/chat.service.ts @@ -6,7 +6,7 @@ import { and, eq } from 'drizzle-orm'; import { getBusinessModelRuntimeHooks } from '@/business/server/model-runtime'; import { DEFAULT_AGENT_CHAT_CONFIG, DEFAULT_SYSTEM_AGENT_CONFIG } from '@/const/settings'; import { UserModel } from '@/database/models/user'; -import { agents, agentsToSessions, aiModels } from '@/database/schemas'; +import { agents, agentsToSessions, aiModels, aiProviders } from '@/database/schemas'; import type { LobeChatDatabase } from '@/database/type'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import { initModelRuntimeWithUserPayload } from '@/server/modules/ModelRuntime'; @@ -29,13 +29,21 @@ import type { export class ChatService extends BaseService { private config: ChatServiceConfig; - constructor(db: LobeChatDatabase, userId: string | null, config?: ChatServiceConfig) { - super(db, userId); + constructor( + db: LobeChatDatabase, + userId: string | null, + workspaceIdOrConfig?: string | ChatServiceConfig, + config?: ChatServiceConfig, + ) { + const workspaceId = typeof workspaceIdOrConfig === 'string' ? workspaceIdOrConfig : undefined; + const serviceConfig = typeof workspaceIdOrConfig === 'string' ? config : workspaceIdOrConfig; + + super(db, userId, workspaceId); this.config = { defaultModel: 'gpt-3.5-turbo', defaultProvider: 'openai', timeout: 30_000, - ...config, + ...serviceConfig, }; } @@ -117,7 +125,7 @@ export class ChatService extends BaseService { private async getAgentConfig(agentId: string): Promise { try { const agent = await this.db.query.agents.findFirst({ - where: (agents, { eq, and }) => and(eq(agents.id, agentId)), + where: and(eq(agents.id, agentId), this.buildWorkspaceWhere(agents)), }); return agent?.chatConfig || null; @@ -173,8 +181,7 @@ export class ChatService extends BaseService { const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); const aiProviderConfigs = await this.db.query.aiProviders.findMany({ - where: (aiProviders, { eq, and }) => - and(eq(aiProviders.userId, this.userId!), eq(aiProviders.id, provider)), + where: and(eq(aiProviders.id, provider), this.buildWorkspaceWhere(aiProviders)), }); if (!aiProviderConfigs || aiProviderConfigs.length === 0) { @@ -630,7 +637,12 @@ export class ChatService extends BaseService { eq(agents.provider, aiModels.providerId), // Ensure provider also matches ), ) - .where(and(eq(agentsToSessions.sessionId, params.sessionId!))); + .where( + and( + eq(agentsToSessions.sessionId, params.sessionId!), + this.buildWorkspaceWhere(agentsToSessions), + ), + ); if (!agentAndModel.length) { this.log('warn', '会话对应的模型配置不存在', { @@ -650,7 +662,10 @@ export class ChatService extends BaseService { // Find the agent corresponding to the session const agentToSession = await this.db.query.agentsToSessions.findFirst({ - where: (agentsToSessions, { eq }) => eq(agentsToSessions.sessionId, params.sessionId!), + where: and( + eq(agentsToSessions.sessionId, params.sessionId!), + this.buildWorkspaceWhere(agentsToSessions), + ), }); if (!agentToSession) { diff --git a/packages/openapi/src/services/file.service.ts b/packages/openapi/src/services/file.service.ts index daa0c4f8cc..ec4d376b97 100644 --- a/packages/openapi/src/services/file.service.ts +++ b/packages/openapi/src/services/file.service.ts @@ -71,16 +71,16 @@ export class FileUploadService extends BaseService { // Lazy import ChunkService to avoid circular dependency overhead // Note: ChunkService is only available in server-side environments - constructor(db: LobeChatDatabase, userId: string) { - super(db, userId); - this.fileModel = new FileModel(db, userId); - this.documentModel = new DocumentModel(db, userId); - this.coreFileService = new CoreFileService(db, userId!); - this.documentService = new DocumentService(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + super(db, userId, workspaceId); + this.fileModel = new FileModel(db, userId, workspaceId); + this.documentModel = new DocumentModel(db, userId, workspaceId); + this.coreFileService = new CoreFileService(db, userId!, workspaceId); + this.documentService = new DocumentService(db, userId, workspaceId); this.s3Service = new FileS3(); - this.chunkModel = new ChunkModel(db, userId); - this.asyncTaskModel = new AsyncTaskModel(db, userId); - this.knowledgeBaseModel = new KnowledgeBaseModel(db, userId); + this.chunkModel = new ChunkModel(db, userId, workspaceId); + this.asyncTaskModel = new AsyncTaskModel(db, userId, workspaceId); + this.knowledgeBaseModel = new KnowledgeBaseModel(db, userId, workspaceId); } /** @@ -128,7 +128,7 @@ export class FileUploadService extends BaseService { } const knowledgeBase = await this.db.query.knowledgeBases.findFirst({ - where: eq(knowledgeBases.id, knowledgeBaseId), + where: and(eq(knowledgeBases.id, knowledgeBaseId), this.buildWorkspaceWhere(knowledgeBases)), }); if (!knowledgeBase) { @@ -397,7 +397,7 @@ export class FileUploadService extends BaseService { const ownedFiles = await this.db.query.files.findMany({ columns: { id: true }, - where: and(inArray(files.id, uniqueFileIds), eq(files.userId, this.userId)), + where: and(inArray(files.id, uniqueFileIds), this.buildWorkspaceWhere(files)), }); const ownedIds = ownedFiles.map((file) => file.id); @@ -412,7 +412,7 @@ export class FileUploadService extends BaseService { ownedIds.map((fileId) => ({ fileId, knowledgeBaseId, - userId: this.userId, + ...this.buildWorkspacePayload({}), })), ) .onConflictDoNothing(); @@ -444,7 +444,7 @@ export class FileUploadService extends BaseService { const ownedFiles = await this.db.query.files.findMany({ columns: { id: true }, - where: and(inArray(files.id, uniqueFileIds), eq(files.userId, this.userId)), + where: and(inArray(files.id, uniqueFileIds), this.buildWorkspaceWhere(files)), }); const ownedIds = ownedFiles.map((file) => file.id); @@ -458,7 +458,7 @@ export class FileUploadService extends BaseService { .where( and( eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId), - eq(knowledgeBaseFiles.userId, this.userId), + this.buildWorkspaceWhere(knowledgeBaseFiles), inArray(knowledgeBaseFiles.fileId, ownedIds), ), ); @@ -494,7 +494,7 @@ export class FileUploadService extends BaseService { const ownedFiles = await this.db.query.files.findMany({ columns: { id: true }, - where: and(inArray(files.id, uniqueFileIds), eq(files.userId, this.userId)), + where: and(inArray(files.id, uniqueFileIds), this.buildWorkspaceWhere(files)), }); const ownedIds = ownedFiles.map((file) => file.id); @@ -516,7 +516,7 @@ export class FileUploadService extends BaseService { .where( and( eq(knowledgeBaseFiles.knowledgeBaseId, sourceKnowledgeBaseId), - eq(knowledgeBaseFiles.userId, this.userId), + this.buildWorkspaceWhere(knowledgeBaseFiles), inArray(knowledgeBaseFiles.fileId, ownedIds), ), ); @@ -527,7 +527,7 @@ export class FileUploadService extends BaseService { ownedIds.map((fileId) => ({ fileId, knowledgeBaseId: request.targetKnowledgeBaseId, - userId: this.userId, + ...this.buildWorkspacePayload({}), })), ) .onConflictDoNothing(); @@ -917,7 +917,7 @@ export class FileUploadService extends BaseService { // Trigger async chunking task const { ChunkService } = await import('@/server/services/chunk'); - const chunkService = new ChunkService(this.db, this.userId); + const chunkService = new ChunkService(this.db, this.userId, this.workspaceId); const chunkTaskId = await chunkService.asyncParseFileToChunks(fileId, req.skipExist); @@ -1127,7 +1127,7 @@ export class FileUploadService extends BaseService { columns: { sessionId: true }, where: and( eq(agentsToSessions.agentId, options.agentId), - eq(agentsToSessions.userId, this.userId), + this.buildWorkspaceWhere(agentsToSessions), ), }); @@ -1152,7 +1152,7 @@ export class FileUploadService extends BaseService { .values({ fileId, sessionId, - userId: this.userId, + ...this.buildWorkspacePayload({}), }) .onConflictDoNothing(); @@ -1227,7 +1227,7 @@ export class FileUploadService extends BaseService { private async findExistingUserFile(hash: string): Promise { try { const existingFile = await this.db.query.files.findFirst({ - where: and(eq(files.fileHash, hash), eq(files.userId, this.userId)), + where: and(eq(files.fileHash, hash), this.buildWorkspaceWhere(files)), }); return existingFile || null; @@ -1251,9 +1251,8 @@ export class FileUploadService extends BaseService { const conditions = []; // Permission conditions - if (permissionResult?.condition?.userId) { - conditions.push(eq(files.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(files, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); // Keyword search if (keyword) { @@ -1287,9 +1286,8 @@ export class FileUploadService extends BaseService { permissionResult: { condition?: { userId?: string } }, ): Promise { const whereConditions = [eq(files.id, fileId)]; - if (permissionResult.condition?.userId) { - whereConditions.push(eq(files.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(files, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); const file = await this.db.query.files.findFirst({ where: and(...whereConditions), @@ -1521,7 +1519,7 @@ export class FileUploadService extends BaseService { .where( and( eq(knowledgeBaseFiles.fileId, fileId), - eq(knowledgeBaseFiles.userId, targetUserId), + this.buildWorkspaceWhere(knowledgeBaseFiles), ), ); @@ -1539,7 +1537,7 @@ export class FileUploadService extends BaseService { await trx.insert(knowledgeBaseFiles).values({ fileId, knowledgeBaseId: updateData.knowledgeBaseId, - userId: targetUserId, + ...this.buildWorkspacePayload({}), }); } }); diff --git a/packages/openapi/src/services/knowledge-base.service.ts b/packages/openapi/src/services/knowledge-base.service.ts index 6e03bcfa22..d9608c2c83 100644 --- a/packages/openapi/src/services/knowledge-base.service.ts +++ b/packages/openapi/src/services/knowledge-base.service.ts @@ -27,9 +27,9 @@ import type { export class KnowledgeBaseService extends BaseService { private knowledgeBaseModel: KnowledgeBaseModel; - constructor(db: LobeChatDatabase, userId: string) { - super(db, userId); - this.knowledgeBaseModel = new KnowledgeBaseModel(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + super(db, userId, workspaceId); + this.knowledgeBaseModel = new KnowledgeBaseModel(db, userId, workspaceId); } /** @@ -50,7 +50,7 @@ export class KnowledgeBaseService extends BaseService { const { limit, offset } = processPaginationConditions(request); const { keyword } = request; - const conditions = [eq(knowledgeBases.userId, this.userId)]; + const conditions = [this.buildWorkspaceWhere(knowledgeBases)]; if (keyword) { conditions.push( @@ -190,7 +190,7 @@ export class KnowledgeBaseService extends BaseService { // Check if knowledge base exists and belongs to the current user const existingKb = await this.db.query.knowledgeBases.findFirst({ - where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)), + where: and(eq(knowledgeBases.id, id), this.buildWorkspaceWhere(knowledgeBases)), }); if (!existingKb) { @@ -202,7 +202,7 @@ export class KnowledgeBaseService extends BaseService { // Get updated knowledge base info const updatedKb = await this.db.query.knowledgeBases.findFirst({ - where: eq(knowledgeBases.id, id), + where: and(eq(knowledgeBases.id, id), this.buildWorkspaceWhere(knowledgeBases)), }); this.log('info', 'Knowledge base updated successfully', { id }); @@ -231,7 +231,7 @@ export class KnowledgeBaseService extends BaseService { // Check if knowledge base exists and belongs to the current user const existingKb = await this.db.query.knowledgeBases.findFirst({ - where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)), + where: and(eq(knowledgeBases.id, id), this.buildWorkspaceWhere(knowledgeBases)), }); if (!existingKb) { @@ -241,7 +241,7 @@ export class KnowledgeBaseService extends BaseService { const result = await this.knowledgeBaseModel.deleteWithFiles(id); if (result.deletedFiles.length > 0) { - const fileService = new CoreFileService(this.db, this.userId); + const fileService = new CoreFileService(this.db, this.userId, this.workspaceId); const urls = result.deletedFiles .map((f: { url: string | null }) => f.url) .filter(Boolean) as string[]; diff --git a/packages/openapi/src/services/message-translations.service.ts b/packages/openapi/src/services/message-translations.service.ts index b188908515..d53073fea6 100644 --- a/packages/openapi/src/services/message-translations.service.ts +++ b/packages/openapi/src/services/message-translations.service.ts @@ -1,4 +1,4 @@ -import { eq } from 'drizzle-orm'; +import { and, eq } from 'drizzle-orm'; import { messages, messageTranslates } from '@/database/schemas'; import type { LobeChatDatabase } from '@/database/type'; @@ -16,8 +16,8 @@ import { ChatService } from './chat.service'; type MessageTranslateItem = typeof messageTranslates.$inferSelect; export class MessageTranslateService extends BaseService { - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); } /** @@ -32,7 +32,10 @@ export class MessageTranslateService extends BaseService { try { const result = await this.db.query.messageTranslates.findFirst({ - where: eq(messageTranslates.id, messageId), + where: and( + eq(messageTranslates.id, messageId), + this.buildWorkspaceWhere(messageTranslates), + ), }); if (!result) { @@ -74,7 +77,7 @@ export class MessageTranslateService extends BaseService { try { // First fetch the original message content and sessionId const messageInfo = await this.db.query.messages.findFirst({ - where: eq(messages.id, translateData.messageId), + where: and(eq(messages.id, translateData.messageId), this.buildWorkspaceWhere(messages)), }); if (!messageInfo) { @@ -84,7 +87,7 @@ export class MessageTranslateService extends BaseService { this.log('info', '原始消息内容', { originalMessage: messageInfo.content }); // Use ChatService for translation, passing sessionId to use the correct model configuration - const chatService = new ChatService(this.db, this.userId); + const chatService = new ChatService(this.db, this.userId, this.workspaceId); const translatedContent = await chatService.translate({ ...translateData, sessionId: messageInfo.sessionId, @@ -116,7 +119,7 @@ export class MessageTranslateService extends BaseService { try { // Check if message exists const messageInfo = await this.db.query.messages.findFirst({ - where: eq(messages.id, data.messageId), + where: and(eq(messages.id, data.messageId), this.buildWorkspaceWhere(messages)), }); if (!messageInfo) { throw this.createCommonError('未找到要更新翻译信息的消息'); @@ -130,7 +133,7 @@ export class MessageTranslateService extends BaseService { from: data.from, id: data.messageId, to: data.to, - userId: this.userId, + ...this.buildWorkspacePayload({}), }) .onConflictDoUpdate({ set: { @@ -168,14 +171,21 @@ export class MessageTranslateService extends BaseService { try { // Check if the translation message exists const originalTranslation = await this.db.query.messageTranslates.findFirst({ - where: eq(messageTranslates.id, messageId), + where: and( + eq(messageTranslates.id, messageId), + this.buildWorkspaceWhere(messageTranslates), + ), }); if (!originalTranslation) { throw this.createNotFoundError('翻译消息不存在'); } - await this.db.delete(messageTranslates).where(eq(messageTranslates.id, messageId)); + await this.db + .delete(messageTranslates) + .where( + and(eq(messageTranslates.id, messageId), this.buildWorkspaceWhere(messageTranslates)), + ); return { deleted: true, messageId }; } catch (error) { diff --git a/packages/openapi/src/services/message.service.ts b/packages/openapi/src/services/message.service.ts index 1cf9cec9b1..18495d3043 100644 --- a/packages/openapi/src/services/message.service.ts +++ b/packages/openapi/src/services/message.service.ts @@ -33,10 +33,10 @@ export interface MessageCountResult { export class MessageService extends BaseService { private coreFileService: CoreFileService; - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); - this.coreFileService = new CoreFileService(db, userId!); + this.coreFileService = new CoreFileService(db, userId!, workspaceId); } /** @@ -96,7 +96,7 @@ export class MessageService extends BaseService { const result = await this.db .select({ count: count() }) .from(messages) - .where(eq(messages.userId, targetUserId)); + .where(this.buildPermissionWhere(messages, { userId: targetUserId })); const messageCount = result[0]?.count || 0; this.log('info', '用户消息统计完成', { count: messageCount }); @@ -128,7 +128,7 @@ export class MessageService extends BaseService { const result = await this.db .select({ count: count() }) .from(messages) - .where(inArray(messages.topicId, topicIds)); + .where(and(inArray(messages.topicId, topicIds), this.buildWorkspaceWhere(messages))); const messageCount = result[0]?.count || 0; this.log('info', '话题消息统计完成', { count: messageCount }); @@ -162,7 +162,7 @@ export class MessageService extends BaseService { const result = await this.db .select({ count: count() }) .from(messages) - .where(eq(messages.userId, this.userId!)); + .where(this.buildWorkspaceWhere(messages)); const messageCount = result[0]?.count || 0; this.log('info', '当前用户消息统计完成', { count: messageCount }); @@ -197,7 +197,7 @@ export class MessageService extends BaseService { const { keyword, limit = 20, offset = 0 } = searchRequest; // Build query conditions - const conditions = [eq(messages.userId, this.userId!)]; + const conditions = [this.buildWorkspaceWhere(messages)]; const contentMatchedMessages = await this.db .select({ id: messages.id }) @@ -267,7 +267,7 @@ export class MessageService extends BaseService { throw this.createAuthorizationError(permissionResult.message || '无权访问消息列表'); } - conditions.push(eq(messages.userId, request.userId)); + conditions.push(this.buildPermissionWhere(messages, { userId: request.userId })!); } // Verify topic ownership and whether the user has message read permission @@ -281,6 +281,7 @@ export class MessageService extends BaseService { } conditions.push(eq(messages.topicId, request.topicId)); + conditions.push(this.buildWorkspaceWhere(messages)); } if (request.role) { @@ -353,9 +354,8 @@ export class MessageService extends BaseService { // Build query conditions const conditions = [eq(messages.id, messageId)]; - if (permissionResult.condition?.userId) { - conditions.push(eq(messages.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(messages, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); const message = (await this.db.query.messages.findFirst({ where: and(...conditions), @@ -431,7 +431,7 @@ export class MessageService extends BaseService { tools: messageData.tools, topicId: messageData.topicId, traceId: messageData.traceId, - userId: this.userId!, + ...this.buildWorkspacePayload({}), }) .returning({ id: messages.id, @@ -449,14 +449,14 @@ export class MessageService extends BaseService { messageData.files.map((fileId) => ({ fileId, messageId: newMessage.id, - userId: this.userId!, + ...this.buildWorkspacePayload({}), })), ); } // Re-query the complete message including session and topic information const completeMessage = (await this.db.query.messages.findFirst({ - where: eq(messages.id, newMessage.id), + where: and(eq(messages.id, newMessage.id), this.buildWorkspaceWhere(messages)), with: { filesToMessages: { with: { @@ -524,7 +524,7 @@ export class MessageService extends BaseService { userId: this.userId, }); - const chatService = new ChatService(this.db, this.userId); + const chatService = new ChatService(this.db, this.userId, this.workspaceId); let aiReplyContent = ''; try { @@ -591,7 +591,7 @@ export class MessageService extends BaseService { orderBy: desc(messages.createdAt), where: and( topicId === null ? isNull(messages.topicId) : eq(messages.topicId, topicId), - eq(messages.userId, this.userId!), + this.buildWorkspaceWhere(messages), ), }); @@ -632,9 +632,8 @@ export class MessageService extends BaseService { const whereConditions = [eq(messages.id, messageId)]; // Apply permission conditions - if (permissionResult.condition?.userId) { - whereConditions.push(eq(messages.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(messages, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); // Use a transaction to delete messages and their associations with files await this.db.transaction(async (trx) => { diff --git a/packages/openapi/src/services/model.service.ts b/packages/openapi/src/services/model.service.ts index d4f3d5aeb2..417858107f 100644 --- a/packages/openapi/src/services/model.service.ts +++ b/packages/openapi/src/services/model.service.ts @@ -19,8 +19,8 @@ import type { * Provides model query and grouping functionality */ export class ModelService extends BaseService { - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); } /** @@ -45,9 +45,8 @@ export class ModelService extends BaseService { const conditions = []; // Add permission condition directly to the main conditions array - if (permissionResult.condition?.userId) { - conditions.push(eq(aiModels.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(aiModels, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); // Handle ModelsListQuery-specific parameters const { page, pageSize, keyword, provider, type, enabled } = request; @@ -117,9 +116,8 @@ export class ModelService extends BaseService { const conditions = [eq(aiModels.providerId, providerId), eq(aiModels.id, modelId)]; - if (permissionResult.condition?.userId) { - conditions.push(eq(aiModels.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(aiModels, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); const model = await this.db.query.aiModels.findFirst({ where: and(...conditions) }); @@ -154,7 +152,7 @@ export class ModelService extends BaseService { where: and( eq(aiModels.id, payload.id), eq(aiModels.providerId, payload.providerId), - eq(aiModels.userId, this.userId), + this.buildWorkspaceWhere(aiModels), ), }); @@ -180,7 +178,7 @@ export class ModelService extends BaseService { sort: payload.sort ?? null, source: payload.source ?? null, type: payload.type ?? 'chat', - userId: this.userId, + ...this.buildWorkspacePayload({}), }) .returning(); @@ -211,9 +209,8 @@ export class ModelService extends BaseService { } const conditions = [eq(aiModels.providerId, providerId), eq(aiModels.id, modelId)]; - if (permissionResult.condition?.userId) { - conditions.push(eq(aiModels.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(aiModels, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); return await this.db.transaction(async (tx) => { const existingModel = await tx.query.aiModels.findFirst({ where: and(...conditions) }); diff --git a/packages/openapi/src/services/permission.service.ts b/packages/openapi/src/services/permission.service.ts index 7b67741c2b..a7948fcb8c 100644 --- a/packages/openapi/src/services/permission.service.ts +++ b/packages/openapi/src/services/permission.service.ts @@ -15,8 +15,8 @@ import type { } from '../types/permission.type'; export class PermissionService extends BaseService { - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); } /** diff --git a/packages/openapi/src/services/provider.service.ts b/packages/openapi/src/services/provider.service.ts index 7a463efbb7..36d45b13fd 100644 --- a/packages/openapi/src/services/provider.service.ts +++ b/packages/openapi/src/services/provider.service.ts @@ -26,8 +26,8 @@ import type { export class ProviderService extends BaseService { private gateKeeperPromise: Promise | null = null; - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); } private async getGateKeeper(): Promise { @@ -99,9 +99,8 @@ export class ProviderService extends BaseService { const conditions = [] as any[]; - if (permissionResult.condition?.userId) { - conditions.push(eq(aiProviders.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(aiProviders, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); if (request.keyword) { conditions.push( @@ -164,9 +163,8 @@ export class ProviderService extends BaseService { const whereConditions = [eq(aiProviders.id, request.id)]; - if (permissionResult.condition?.userId) { - whereConditions.push(eq(aiProviders.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(aiProviders, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); const whereCondition = whereConditions.length > 1 ? and(...whereConditions) : whereConditions[0]; @@ -205,7 +203,7 @@ export class ProviderService extends BaseService { const ownerId = permissionResult.condition?.userId ?? this.userId; const existed = await this.db.query.aiProviders.findFirst({ - where: and(eq(aiProviders.id, request.id), eq(aiProviders.userId, ownerId)), + where: and(eq(aiProviders.id, request.id), this.buildWorkspaceWhere(aiProviders)), }); if (existed) { @@ -232,7 +230,7 @@ export class ProviderService extends BaseService { sort: request.sort ?? null, source: request.source, updatedAt: now, - userId: ownerId, + ...this.buildWorkspacePayload({}), }) .returning(); @@ -263,9 +261,8 @@ export class ProviderService extends BaseService { const whereConditions = [eq(aiProviders.id, request.id)]; - if (permissionResult.condition?.userId) { - whereConditions.push(eq(aiProviders.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(aiProviders, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); const whereCondition = whereConditions.length > 1 ? and(...whereConditions) : whereConditions[0]; @@ -330,9 +327,8 @@ export class ProviderService extends BaseService { const whereConditions = [eq(aiProviders.id, request.id)]; - if (permissionResult.condition?.userId) { - whereConditions.push(eq(aiProviders.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(aiProviders, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); const providerWhere = whereConditions.length > 1 ? and(...whereConditions) : whereConditions[0]!; @@ -348,9 +344,11 @@ export class ProviderService extends BaseService { await this.db.transaction(async (tx) => { const modelConditions = [eq(aiModels.providerId, request.id)]; - if (permissionResult.condition?.userId) { - modelConditions.push(eq(aiModels.userId, permissionResult.condition.userId)); - } + const modelPermissionWhere = this.buildPermissionWhere( + aiModels, + permissionResult.condition, + ); + if (modelPermissionWhere) modelConditions.push(modelPermissionWhere); const modelWhere = modelConditions.length > 1 ? and(...modelConditions) : modelConditions[0]!; diff --git a/packages/openapi/src/services/responses.service.ts b/packages/openapi/src/services/responses.service.ts index a2ae1dcc8d..f3b071c7d5 100644 --- a/packages/openapi/src/services/responses.service.ts +++ b/packages/openapi/src/services/responses.service.ts @@ -291,7 +291,9 @@ export class ResponsesService extends BaseService { // model field is used as agentId const additionalPluginIds = this.extractHostedToolIds(params.tools); const functionTools = this.extractFunctionTools(params.tools); - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const execResult = await aiAgentService.execAgent({ additionalPluginIds: additionalPluginIds.length > 0 ? additionalPluginIds : undefined, agentId: model, @@ -314,6 +316,7 @@ export class ResponsesService extends BaseService { // 2. Execute synchronously to completion const agentRuntimeService = new AgentRuntimeService(this.db, this.userId, { queueService: null, + workspaceId: this.workspaceId, }); const finalState = await agentRuntimeService.executeSync(execResult.operationId); @@ -390,7 +393,9 @@ export class ResponsesService extends BaseService { // model field is used as agentId const additionalPluginIds = this.extractHostedToolIds(params.tools); const functionTools = this.extractFunctionTools(params.tools); - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const execResult = await aiAgentService.execAgent({ additionalPluginIds: additionalPluginIds.length > 0 ? additionalPluginIds : undefined, agentId: model, @@ -434,6 +439,7 @@ export class ResponsesService extends BaseService { const agentRuntimeService = new AgentRuntimeService(this.db, this.userId, { queueService: null, streamEventManager, + workspaceId: this.workspaceId, }); // 3. Setup async event queue to bridge push events → pull-based generator diff --git a/packages/openapi/src/services/role.service.ts b/packages/openapi/src/services/role.service.ts index e35f7e60c2..1a3774a9e9 100644 --- a/packages/openapi/src/services/role.service.ts +++ b/packages/openapi/src/services/role.service.ts @@ -1,5 +1,5 @@ import type { SQL } from 'drizzle-orm'; -import { and, count, eq, ilike, inArray, or, sql } from 'drizzle-orm'; +import { and, count, eq, ilike, inArray, isNull, or, sql } from 'drizzle-orm'; import type { RoleItem } from '@/database/schemas/rbac'; import { permissions, rolePermissions, roles, userRoles } from '@/database/schemas/rbac'; @@ -19,8 +19,20 @@ import type { } from '../types/role.type'; export class RoleService extends BaseService { - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); + } + + private getRoleScopeWhere() { + return this.workspaceId + ? or(eq(roles.workspaceId, this.workspaceId), isNull(roles.workspaceId)) + : isNull(roles.workspaceId); + } + + private getUserRoleScopeWhere() { + return this.workspaceId + ? eq(userRoles.workspaceId, this.workspaceId) + : isNull(userRoles.workspaceId); } /** @@ -54,6 +66,8 @@ export class RoleService extends BaseService { ); } + conditions.push(this.getRoleScopeWhere()); + const { limit, offset } = processPaginationConditions(request); const whereExpr = conditions.length ? and(...conditions) : undefined; @@ -91,7 +105,7 @@ export class RoleService extends BaseService { try { return await this.db.query.roles.findMany({ orderBy: [roles.isSystem, roles.createdAt], - where: eq(roles.isActive, true), + where: and(eq(roles.isActive, true), this.getRoleScopeWhere()), }); } catch (error) { this.handleServiceError(error, '获取活跃角色列表'); @@ -112,7 +126,7 @@ export class RoleService extends BaseService { try { const role = await this.db.query.roles.findFirst({ - where: eq(roles.id, id), + where: and(eq(roles.id, id), this.getRoleScopeWhere()), }); return role || null; } catch (error) { @@ -134,7 +148,7 @@ export class RoleService extends BaseService { try { const role = await this.db.query.roles.findFirst({ - where: eq(roles.name, name), + where: and(eq(roles.name, name), this.getRoleScopeWhere()), }); return role || null; } catch (error) { @@ -157,7 +171,7 @@ export class RoleService extends BaseService { return await this.db.transaction(async (tx) => { // Ensure role name is unique const existingRole = await tx.query.roles.findFirst({ - where: eq(roles.name, payload.name), + where: and(eq(roles.name, payload.name), this.getRoleScopeWhere()), }); if (existingRole) { throw this.createBusinessError(`角色名称 "${payload.name}" 已存在`); @@ -171,6 +185,7 @@ export class RoleService extends BaseService { isActive: payload.isActive ?? true, isSystem: payload.isSystem ?? false, name: payload.name, + workspaceId: this.workspaceId ?? null, }) .returning(); @@ -375,7 +390,7 @@ export class RoleService extends BaseService { return await this.db.transaction(async (tx) => { // Check if the role exists const existingRole = await tx.query.roles.findFirst({ - where: eq(roles.id, id), + where: and(eq(roles.id, id), this.getRoleScopeWhere()), }); if (!existingRole) { @@ -390,7 +405,7 @@ export class RoleService extends BaseService { // If the role name is being modified, check whether the new name already exists if (updateData.name && updateData.name !== existingRole.name) { const duplicateRole = await tx.query.roles.findFirst({ - where: eq(roles.name, updateData.name), + where: and(eq(roles.name, updateData.name), this.getRoleScopeWhere()), }); if (duplicateRole) { @@ -412,7 +427,7 @@ export class RoleService extends BaseService { const [updatedRole] = await tx .update(roles) .set(updateFields) - .where(eq(roles.id, id)) + .where(and(eq(roles.id, id), this.getRoleScopeWhere())) .returning(); this.log('info', '角色更新成功', { roleId: id, roleName: updatedRole.name }); @@ -437,7 +452,9 @@ export class RoleService extends BaseService { try { // Check if the role exists - const existingRole = await this.db.query.roles.findFirst({ where: eq(roles.id, roleId) }); + const existingRole = await this.db.query.roles.findFirst({ + where: and(eq(roles.id, roleId), this.getRoleScopeWhere()), + }); if (!existingRole) { throw this.createNotFoundError(`角色 ID "${roleId}" 不存在`); } @@ -469,7 +486,9 @@ export class RoleService extends BaseService { try { return await this.db.transaction(async (tx) => { - const existingRole = await tx.query.roles.findFirst({ where: eq(roles.id, id) }); + const existingRole = await tx.query.roles.findFirst({ + where: and(eq(roles.id, id), this.getRoleScopeWhere()), + }); if (!existingRole) { throw this.createNotFoundError(`角色 ID "${id}" 不存在`); @@ -479,14 +498,16 @@ export class RoleService extends BaseService { throw this.createBusinessError('系统角色不允许删除'); } - const linkedUser = await tx.query.userRoles.findFirst({ where: eq(userRoles.roleId, id) }); + const linkedUser = await tx.query.userRoles.findFirst({ + where: and(eq(userRoles.roleId, id), this.getUserRoleScopeWhere()), + }); if (linkedUser) { throw this.createBusinessError('角色仍然关联用户,无法删除'); } const [deletedRole] = await tx .delete(roles) - .where(eq(roles.id, id)) + .where(and(eq(roles.id, id), this.getRoleScopeWhere())) .returning({ id: roles.id }); if (!deletedRole) { diff --git a/packages/openapi/src/services/topic.service.ts b/packages/openapi/src/services/topic.service.ts index 056e64b281..e123094bf8 100644 --- a/packages/openapi/src/services/topic.service.ts +++ b/packages/openapi/src/services/topic.service.ts @@ -15,8 +15,8 @@ import type { } from '../types/topic.type'; export class TopicService extends BaseService { - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); } /** @@ -37,9 +37,8 @@ export class TopicService extends BaseService { const conditions = []; // Add permission-related query conditions - if (permissionResult?.condition?.userId) { - conditions.push(eq(topics.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(topics, permissionResult.condition); + if (permissionWhere) conditions.push(permissionWhere); // Filter by groupId first if (request.groupId) { @@ -49,7 +48,12 @@ export class TopicService extends BaseService { const [relation] = await this.db .select({ sessionId: agentsToSessions.sessionId }) .from(agentsToSessions) - .where(eq(agentsToSessions.agentId, request.agentId)) + .where( + and( + eq(agentsToSessions.agentId, request.agentId), + this.buildWorkspaceWhere(agentsToSessions), + ), + ) .limit(1); if (relation) { @@ -131,9 +135,8 @@ export class TopicService extends BaseService { const whereConditions = [eq(topics.id, topicId)]; // Apply permission conditions - if (permissionResult.condition?.userId) { - whereConditions.push(eq(topics.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(topics, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); const [result] = await this.db .select({ @@ -178,7 +181,9 @@ export class TopicService extends BaseService { const [relation] = await this.db .select({ sessionId: agentsToSessions.sessionId }) .from(agentsToSessions) - .where(eq(agentsToSessions.agentId, agentId)) + .where( + and(eq(agentsToSessions.agentId, agentId), this.buildWorkspaceWhere(agentsToSessions)), + ) .limit(1); effectiveSessionId = relation?.sessionId ?? null; @@ -203,7 +208,7 @@ export class TopicService extends BaseService { id: idGenerator('topics'), sessionId: effectiveSessionId, title, - userId: this.userId, + ...this.buildWorkspacePayload({}), }) .returning(); @@ -234,9 +239,8 @@ export class TopicService extends BaseService { const whereConditions = [eq(topics.id, topicId)]; // Apply permission conditions - if (permissionResult.condition?.userId) { - whereConditions.push(eq(topics.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(topics, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); const [updatedTopic] = await this.db .update(topics) @@ -273,9 +277,8 @@ export class TopicService extends BaseService { const whereConditions = [eq(topics.id, topicId)]; // Apply permission conditions - if (permissionResult.condition?.userId) { - whereConditions.push(eq(topics.userId, permissionResult.condition.userId)); - } + const permissionWhere = this.buildPermissionWhere(topics, permissionResult.condition); + if (permissionWhere) whereConditions.push(permissionWhere); const [existingTopic] = await this.db .delete(topics) diff --git a/packages/openapi/src/services/user.service.ts b/packages/openapi/src/services/user.service.ts index 228a15d07f..2a6e4eb6ab 100644 --- a/packages/openapi/src/services/user.service.ts +++ b/packages/openapi/src/services/user.service.ts @@ -1,4 +1,4 @@ -import { and, count, desc, eq, ilike, inArray, ne, or } from 'drizzle-orm'; +import { and, count, desc, eq, ilike, inArray, isNull, ne, or } from 'drizzle-orm'; import { ALL_SCOPE } from '@/const/rbac'; import { RbacModel } from '@/database/models/rbac'; @@ -24,8 +24,14 @@ import type { * User service implementation class */ export class UserService extends BaseService { - constructor(db: LobeChatDatabase, userId: string | null) { - super(db, userId); + constructor(db: LobeChatDatabase, userId: string | null, workspaceId?: string) { + super(db, userId, workspaceId); + } + + private getRoleScopeWhere() { + return this.workspaceId + ? or(eq(roles.workspaceId, this.workspaceId), isNull(roles.workspaceId)) + : isNull(roles.workspaceId); } /** @@ -48,7 +54,7 @@ export class UserService extends BaseService { .select({ roles }) .from(userRoles) .innerJoin(roles, eq(userRoles.roleId, roles.id)) - .where(eq(userRoles.userId, userId)); + .where(and(eq(userRoles.userId, userId), this.buildPermissionWhere(userRoles, { userId }))); return { ...user, @@ -62,9 +68,12 @@ export class UserService extends BaseService { .select({ roles }) .from(userRoles) .innerJoin(roles, eq(userRoles.roleId, roles.id)) - .where(eq(userRoles.userId, userId)), + .where(and(eq(userRoles.userId, userId), this.buildPermissionWhere(userRoles, { userId }))), - this.db.select({ count: count() }).from(messages).where(eq(messages.userId, userId)), + this.db + .select({ count: count() }) + .from(messages) + .where(this.buildPermissionWhere(messages, { userId })), ]); return { @@ -407,7 +416,11 @@ export class UserService extends BaseService { // 3. Validate that all roles exist and are active if (allRoleIds.size > 0) { const existingRoles = await tx.query.roles.findMany({ - where: and(inArray(roles.id, Array.from(allRoleIds)), eq(roles.isActive, true)), + where: and( + inArray(roles.id, Array.from(allRoleIds)), + eq(roles.isActive, true), + this.getRoleScopeWhere(), + ), }); const existingRoleIds = new Set(existingRoles.map((r) => r.id)); @@ -429,7 +442,11 @@ export class UserService extends BaseService { await tx .delete(userRoles) .where( - and(eq(userRoles.userId, userId), inArray(userRoles.roleId, request.removeRoles)), + and( + eq(userRoles.userId, userId), + inArray(userRoles.roleId, request.removeRoles), + this.buildPermissionWhere(userRoles, { userId }), + ), ); this.log('info', '移除用户角色成功'); @@ -443,6 +460,7 @@ export class UserService extends BaseService { expiresAt: role.expiresAt ? new Date(role.expiresAt) : null, roleId: role.roleId, userId, + workspaceId: this.workspaceId ?? null, }; return data; }); @@ -456,7 +474,9 @@ export class UserService extends BaseService { .from(userRoles) .innerJoin(roles, eq(userRoles.roleId, roles.id)) .innerJoin(users, eq(userRoles.userId, users.id)) - .where(eq(userRoles.userId, userId)); + .where( + and(eq(userRoles.userId, userId), this.buildPermissionWhere(userRoles, { userId })), + ); this.log('info', '用户角色更新完成', { result, @@ -508,7 +528,7 @@ export class UserService extends BaseService { .select({ role: roles, userRole: userRoles }) .from(userRoles) .innerJoin(roles, eq(userRoles.roleId, roles.id)) - .where(eq(userRoles.userId, userId)); + .where(and(eq(userRoles.userId, userId), this.buildPermissionWhere(userRoles, { userId }))); return results.map((r) => ({ expiresAt: r.userRole.expiresAt, @@ -547,9 +567,11 @@ export class UserService extends BaseService { const beforeCount = await this.db .select({ count: count() }) .from(userRoles) - .where(eq(userRoles.userId, userId)); + .where(and(eq(userRoles.userId, userId), this.buildPermissionWhere(userRoles, { userId }))); - await this.db.delete(userRoles).where(eq(userRoles.userId, userId)); + await this.db + .delete(userRoles) + .where(and(eq(userRoles.userId, userId), this.buildPermissionWhere(userRoles, { userId }))); return { removed: beforeCount[0]?.count || 0, userId }; } catch (error) { diff --git a/packages/types/package.json b/packages/types/package.json index 3ed432afde..43786f0ee1 100644 --- a/packages/types/package.json +++ b/packages/types/package.json @@ -10,7 +10,7 @@ "test:update": "vitest -u" }, "dependencies": { - "@lobehub/market-sdk": "0.33.3", + "@lobehub/market-sdk": "0.34.0", "@lobehub/market-types": "^1.12.3", "model-bank": "workspace:*", "type-fest": "^4.41.0", diff --git a/packages/types/src/discover/assistants.ts b/packages/types/src/discover/assistants.ts index 56de3ed7ea..8baffbf53b 100644 --- a/packages/types/src/discover/assistants.ts +++ b/packages/types/src/discover/assistants.ts @@ -58,6 +58,10 @@ export interface DiscoverAssistantItem extends Omit, installCount?: number; isValidated?: boolean; knowledgeCount: number; + /** + * Owner account type, used to resolve the author profile link + */ + ownerType?: 'user' | 'organization'; pluginCount: number; status?: AgentStatus; tokenUsage: number; diff --git a/packages/types/src/discover/fork.ts b/packages/types/src/discover/fork.ts index c61b016b7d..170b6c1460 100644 --- a/packages/types/src/discover/fork.ts +++ b/packages/types/src/discover/fork.ts @@ -24,6 +24,14 @@ export interface AgentForkRequest { * but is carried in the body for batch payloads). */ export interface AgentForkBatchInput extends AgentForkRequest { + /** + * Optional Market organization account id to attribute the fork to. When + * present, the cloud forwards `X-Lobe-Owner-Account-Id` so the resulting + * `agents.ownerId` points at the organization rather than the calling user. + * Callers in a workspace context should resolve this via + * `WorkspaceMarketIdentityService.ensureOrganization`. + */ + actAs?: number; /** Source agent identifier to fork from */ sourceIdentifier: string; } @@ -102,6 +110,8 @@ export interface AgentForkSourceResponse { * Fork request parameters for Agent Group */ export interface AgentGroupForkRequest { + /** Market organization account id used when forking from a workspace */ + actAs?: number; /** New group identifier (required, must be globally unique) */ identifier: string; /** New group name (optional, defaults to "{original name} (Fork)") */ diff --git a/packages/types/src/discover/groupAgents.ts b/packages/types/src/discover/groupAgents.ts index 7364a05227..3f11f7c5ba 100644 --- a/packages/types/src/discover/groupAgents.ts +++ b/packages/types/src/discover/groupAgents.ts @@ -97,6 +97,10 @@ export interface DiscoverGroupAgentItem extends MetaData { * Number of member agents in the group */ memberCount: number; + /** + * Owner account type, used to resolve the author profile link + */ + ownerType?: 'user' | 'organization'; /** * Number of plugins across all member agents */ diff --git a/packages/types/src/discover/index.ts b/packages/types/src/discover/index.ts index 6291d00b97..5bff212baf 100644 --- a/packages/types/src/discover/index.ts +++ b/packages/types/src/discover/index.ts @@ -22,6 +22,7 @@ export enum DiscoverTab { Providers = 'provider', Skills = 'skill', User = 'user', + Workspace = 'workspace', } export type IdentifiersResponse = { diff --git a/packages/types/src/document/index.ts b/packages/types/src/document/index.ts index ca39da51ec..e0f5968162 100644 --- a/packages/types/src/document/index.ts +++ b/packages/types/src/document/index.ts @@ -88,6 +88,8 @@ export interface LobeDocument { * File last modified timestamp */ updatedAt: Date; + + userId?: string; } /** diff --git a/packages/types/src/files/list.ts b/packages/types/src/files/list.ts index 7420c823f6..5134b9c353 100644 --- a/packages/types/src/files/list.ts +++ b/packages/types/src/files/list.ts @@ -35,6 +35,7 @@ export interface FileListItem { sourceType: string; updatedAt: Date; url: string; + userId?: string; } export enum SortType { diff --git a/packages/types/src/message/db/item.ts b/packages/types/src/message/db/item.ts index dc0715bbbd..d0785b284e 100644 --- a/packages/types/src/message/db/item.ts +++ b/packages/types/src/message/db/item.ts @@ -36,6 +36,7 @@ export interface DBMessageItem { */ usage?: ModelUsage | null; userId: string; + workspaceId: string | null; } export interface MessagePluginItem { diff --git a/packages/types/src/task/index.ts b/packages/types/src/task/index.ts index e5c01a6d03..602a39e1ad 100644 --- a/packages/types/src/task/index.ts +++ b/packages/types/src/task/index.ts @@ -149,6 +149,7 @@ export interface TaskItem { status: string; totalTopics: number | null; updatedAt: Date; + workspaceId: string | null; } export type TaskListItem = TaskItem & { @@ -188,6 +189,7 @@ export interface NewTask { status?: string; totalTopics?: number | null; updatedAt?: Date; + workspaceId?: string | null; } // ── Task Detail (shared across CLI, viewTask tool, task.detail router) ── diff --git a/packages/types/src/topic/topic.ts b/packages/types/src/topic/topic.ts index bf35cce2c3..e8f604591d 100644 --- a/packages/types/src/topic/topic.ts +++ b/packages/types/src/topic/topic.ts @@ -217,6 +217,7 @@ export interface ChatTopic extends Omit { /** Server-side mock until real token aggregation lands. */ tokenUsage?: number | null; trigger?: string | null; + userId?: string; } export type ChatTopicMap = Record; diff --git a/scripts/codemodWorkspaceNav.ts b/scripts/codemodWorkspaceNav.ts new file mode 100644 index 0000000000..0f301a3357 --- /dev/null +++ b/scripts/codemodWorkspaceNav.ts @@ -0,0 +1,371 @@ +#!/usr/bin/env bun +/** + * Codemod: rewrite in-app callsites to be workspace-aware. + * + * useNavigate (from 'react-router-dom') → useWorkspaceAwareNavigate + * + * + * Idempotent. Re-run after rebasing the lobehub submodule onto upstream canary + * to re-apply Step B's workspace-aware navigation patches (LOBE-9024). + * + * Strategy: + * - Scope: lobehub/src/{features,routes,hooks} excluding tests, the router + * configs themselves, and the Workspace feature folder. + * - For each file, collect every `navigate('/...')` literal and every + * `` callsites individually. + * - If the file has only personal-only navigate targets → skip entirely + * (the file is correct as-is). + * - Otherwise rewrite `useNavigate` → `useWorkspaceAwareNavigate` and add + * the appropriate import. + * + * Run: + * bun run scripts/codemodWorkspaceNav.ts # apply + * bun run scripts/codemodWorkspaceNav.ts --dry # report only + * bun run scripts/codemodWorkspaceNav.ts --check # exit 1 if would change + */ + +import { readdir, readFile, stat, writeFile } from 'node:fs/promises'; +import path from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const __dirname = path.dirname(fileURLToPath(import.meta.url)); +const ROOT = path.resolve(__dirname, '..'); + +const SCAN_ROOTS = ['src/features', 'src/routes', 'src/hooks']; + +const EXCLUDE_DIR_NAMES = new Set(['__tests__', '__mocks__', 'node_modules']); + +// Files whose pathname matches these substrings are skipped. +const EXCLUDE_PATH_SUBSTRINGS = ['/spa/router/', '/features/Workspace/']; + +const EXCLUDE_FILE_SUFFIXES = ['.test.ts', '.test.tsx', '.spec.ts', '.spec.tsx', '.d.ts']; + +// Top-level personal-only routes. `/settings` is handled separately by the +// shared-tabs allowlist below so workspace-mirrored sub-paths (general, plans, +// billing, …) get auto-prefixed while truly personal sub-paths (profile, llm, +// referral, system-tools, workspace-*) stay personal. +// +// Keep in sync with `PERSONAL_PATH_REGEX` in +// `src/features/Workspace/workspaceAwarePath.ts`. +const PERSONAL_PATH_REGEX = /^\/(?:onboarding|me|share|devtools|desktop-onboarding)(?:[/?#]|$)/; + +// Keep in sync with `WORKSPACE_SETTINGS_TABS` in +// `src/features/Workspace/workspaceAwarePath.ts`. +const SHARED_SETTINGS_TABS = + '(?:apikey|billing|creds|credits|general|members|memory|messenger|plans|provider|service-model|skill|stats|usage)'; + +const SHARED_PATH_REGEX = new RegExp( + `^\\/(?:agent|group|community|memory|page|resource|image|video|eval|tasks?|settings\\/${SHARED_SETTINGS_TABS})(?:[/?#]|$)`, +); + +type Verdict = 'personal' | 'shared' | 'unknown'; + +const classifyPath = (path: string): Verdict => { + if (PERSONAL_PATH_REGEX.test(path)) return 'personal'; + if (SHARED_PATH_REGEX.test(path)) return 'shared'; + return 'unknown'; +}; + +const WORKSPACE_NAVIGATE_IMPORT = + "import { useWorkspaceAwareNavigate } from '@/features/Workspace/useWorkspaceAwareNavigate';"; +const WORKSPACE_LINK_IMPORT = "import WorkspaceLink from '@/features/Workspace/WorkspaceLink';"; + +interface Report { + file: string; + reason: string; +} + +const reports: { transformed: Report[]; skipped: Report[]; warnings: Report[] } = { + transformed: [], + skipped: [], + warnings: [], +}; + +const args = new Set(process.argv.slice(2)); +const DRY = args.has('--dry') || args.has('--dry-run'); +const CHECK = args.has('--check'); + +async function walk(dir: string, files: string[] = []): Promise { + let entries; + try { + entries = await readdir(dir, { withFileTypes: true }); + } catch { + return files; + } + for (const entry of entries) { + if (EXCLUDE_DIR_NAMES.has(entry.name)) continue; + const full = path.join(dir, entry.name); + if (entry.isDirectory()) { + await walk(full, files); + continue; + } + if (!entry.isFile()) continue; + if (!entry.name.endsWith('.ts') && !entry.name.endsWith('.tsx')) continue; + if (EXCLUDE_FILE_SUFFIXES.some((s) => entry.name.endsWith(s))) continue; + const rel = path.relative(ROOT, full).replaceAll('\\', '/'); + if (EXCLUDE_PATH_SUBSTRINGS.some((s) => rel.includes(s))) continue; + files.push(full); + } + return files; +} + +// Extract the first-arg path literal from `navigate(...)` invocations. +// Handles: navigate('/foo'), navigate("/foo"), navigate(`/foo/${x}`) +const NAVIGATE_CALL_REGEX = /\bnavigate\s*\(\s*(['"`])((?:\\.|(?!\1).)*?)\1/g; + +const collectNavigateTargets = (source: string): string[] => { + const out: string[] = []; + for (const match of source.matchAll(NAVIGATE_CALL_REGEX)) { + const raw = match[2]; + // Strip template-literal `${...}` placeholders so the prefix is comparable. + const prefix = raw.replaceAll(/\$\{[^}]*\}/g, ''); + out.push(prefix); + } + return out; +}; + +const containsUseNavigateImport = (source: string): boolean => + /from\s+['"]react-router-dom['"]/.test(source) && /\buseNavigate\b/.test(source); + +/** + * Rewrite `import { ..., useNavigate, ... } from 'react-router-dom'`: + * - drop `useNavigate` from the named imports list + * - if no other names remain, drop the entire import line + * - append a new import for `useWorkspaceAwareNavigate` + */ +const rewriteImports = (source: string): string => { + const importRegex = /^(\s*)import\s+\{([^}]+)\}\s+from\s+(['"])react-router-dom\3\s*(?:;\s*)?$/m; + const match = source.match(importRegex); + if (!match) return source; + const indent = match[1]; + const names = match[2] + .split(',') + .map((s) => s.trim()) + .filter(Boolean); + const remaining = names.filter((n) => n.replace(/\s+as\s+\w+/, '').trim() !== 'useNavigate'); + + const newWorkspaceImport = `${indent}${WORKSPACE_NAVIGATE_IMPORT}`; + + let replacement: string; + if (remaining.length === 0) { + replacement = newWorkspaceImport; + } else { + replacement = `${indent}import { ${remaining.join(', ')} } from 'react-router-dom';\n${newWorkspaceImport}`; + } + + return source.replace(importRegex, replacement); +}; + +const rewriteUseNavigateCalls = (source: string): string => + source.replaceAll(/\buseNavigate\s*\(\s*\)/g, 'useWorkspaceAwareNavigate()'); + +const ensureWorkspaceLinkImport = (source: string): string => { + if (source.includes(WORKSPACE_LINK_IMPORT)) return source; + // Insert right after the first `from 'react-router-dom'` import, or at top. + const rrdImport = /^(\s*)import\s[^;]*from\s+['"]react-router-dom['"]\s*(?:;\s*)?$/m; + const match = source.match(rrdImport); + if (match && match.index !== undefined) { + const insertAt = match.index + match[0].length; + return `${source.slice(0, insertAt)}\n${match[1]}${WORKSPACE_LINK_IMPORT}${source.slice( + insertAt, + )}`; + } + // Fallback: prepend. + return `${WORKSPACE_LINK_IMPORT}\n${source}`; +}; + +interface LinkRewriteResult { + changed: boolean; + rewrote: number; + source: string; +} + +const rewriteLinkTags = (source: string): LinkRewriteResult => { + // Walk all `` / `` tokens, pair opens with closes by depth, + // collect a single list of edits, then apply them in descending-index order + // so all positions stay valid throughout the rewrite. + const tagRegex = /<\/?Link\b[^>]*>/g; + interface OpenToken { + idx: number; + raw: string; + rewrite: boolean; + selfClosing: boolean; + } + interface CloseToken { + idx: number; + raw: string; + } + const opens: OpenToken[] = []; + const closes: CloseToken[] = []; + const allTokens: Array<{ kind: 'open' | 'close'; idx: number; raw: string }> = []; + + for (const m of source.matchAll(tagRegex)) { + const raw = m[0]; + const idx = m.index!; + if (raw.startsWith(''); + const tok: OpenToken = { idx, raw, rewrite, selfClosing }; + opens.push(tok); + allTokens.push({ kind: 'open', idx, raw }); + } + + // Pair opens with closes by walking the token stream in source order. + const stack: OpenToken[] = []; + const edits: Array<{ idx: number; length: number; replacement: string }> = []; + let rewroteCount = 0; + + for (const t of allTokens) { + if (t.kind === 'open') { + const ot = opens.find((o) => o.idx === t.idx)!; + if (ot.selfClosing) { + if (ot.rewrite) { + edits.push({ + idx: ot.idx, + length: ot.raw.length, + replacement: ot.raw.replace(/^$/, ' />'), + }); + rewroteCount++; + } + continue; + } + stack.push(ot); + } else { + const opener = stack.pop(); + if (opener?.rewrite) { + edits.push({ + idx: opener.idx, + length: opener.raw.length, + replacement: opener.raw.replace(/^', + }); + rewroteCount += 2; + } + } + } + + edits.sort((a, b) => b.idx - a.idx); + let out = source; + for (const e of edits) { + out = `${out.slice(0, e.idx)}${e.replacement}${out.slice(e.idx + e.length)}`; + } + + return { changed: edits.length > 0, source: out, rewrote: rewroteCount }; +}; + +async function processFile(absPath: string): Promise { + const rel = path.relative(ROOT, absPath); + const original = await readFile(absPath, 'utf8'); + let next = original; + + const targets = collectNavigateTargets(original); + const verdicts = targets.map(classifyPath); + const hasPersonal = verdicts.includes('personal'); + const hasShared = verdicts.includes('shared'); + + let didUseNavigateRewrite = false; + const hasUseNavigate = containsUseNavigateImport(original); + + if (hasUseNavigate) { + if (hasPersonal && hasShared) { + reports.warnings.push({ + file: rel, + reason: 'mixed personal/shared navigate targets — useNavigate left unchanged', + }); + } else if (hasPersonal && !hasShared) { + // pure personal — leave as-is + } else { + // pure shared OR no navigate calls (e.g. only Link). The latter case is + // safe: useNavigate is imported but unused for shared paths; flipping it + // is a no-op behaviorally. We only flip if useNavigate is actually CALLED + // — otherwise leave the import alone to avoid removing genuinely-unused + // imports the codemod didn't introduce. + if (/\buseNavigate\s*\(\s*\)/.test(original)) { + const afterImport = rewriteImports(next); + if (afterImport !== next) { + next = rewriteUseNavigateCalls(afterImport); + didUseNavigateRewrite = true; + } + } + } + } + + // Link rewrite — independent of useNavigate decision. + const linkResult = rewriteLinkTags(next); + if (linkResult.changed) { + next = ensureWorkspaceLinkImport(linkResult.source); + } + + if (next === original) { + if (hasUseNavigate && (hasShared || hasPersonal)) { + reports.skipped.push({ + file: rel, + reason: hasPersonal && !hasShared ? 'personal-only navigate targets' : 'no change needed', + }); + } + return; + } + + const summary: string[] = []; + if (didUseNavigateRewrite) summary.push('useNavigate'); + if (linkResult.rewrote > 0) summary.push(`${linkResult.rewrote / 2} Link tag(s)`); + + reports.transformed.push({ file: rel, reason: summary.join(' + ') }); + + if (!DRY && !CHECK) await writeFile(absPath, next, 'utf8'); +} + +async function main(): Promise { + const files: string[] = []; + for (const root of SCAN_ROOTS) { + const abs = path.join(ROOT, root); + if ( + await stat(abs).then( + () => true, + () => false, + ) + ) { + await walk(abs, files); + } + } + + for (const f of files) await processFile(f); + + const print = (label: string, list: Report[]): void => { + if (list.length === 0) return; + console.log(`\n${label} (${list.length}):`); + for (const r of list) console.log(` ${r.file} — ${r.reason}`); + }; + + print('TRANSFORMED', reports.transformed); + print('WARNINGS', reports.warnings); + if (process.env.VERBOSE) print('SKIPPED', reports.skipped); + + console.log( + `\nSummary: transformed=${reports.transformed.length} warnings=${reports.warnings.length} skipped=${reports.skipped.length} files-scanned=${files.length}`, + ); + + if (CHECK && reports.transformed.length > 0) { + console.error('\n✗ codemod would modify files. Re-run without --check to apply.'); + process.exit(1); + } +} + +await main(); diff --git a/src/app/(backend)/api/webhooks/memory-user-memory/pipelines/extract/chat-topic/cancel/route.ts b/src/app/(backend)/api/webhooks/memory-user-memory/pipelines/extract/chat-topic/cancel/route.ts index f30b77268c..041248f9c4 100644 --- a/src/app/(backend)/api/webhooks/memory-user-memory/pipelines/extract/chat-topic/cancel/route.ts +++ b/src/app/(backend)/api/webhooks/memory-user-memory/pipelines/extract/chat-topic/cancel/route.ts @@ -104,7 +104,7 @@ export const POST = async (req: Request) => { }, }; - const asyncTaskModel = new AsyncTaskModel(db, task.userId); + const asyncTaskModel = new AsyncTaskModel(db, task.userId, task.workspaceId ?? undefined); await asyncTaskModel.update(task.id, { error: new AsyncTaskError( AsyncTaskErrorType.TaskCancelled, diff --git a/src/app/(backend)/api/webhooks/video/[provider]/route.ts b/src/app/(backend)/api/webhooks/video/[provider]/route.ts index 6df2625f99..2fe1a7cd3d 100644 --- a/src/app/(backend)/api/webhooks/video/[provider]/route.ts +++ b/src/app/(backend)/api/webhooks/video/[provider]/route.ts @@ -119,7 +119,11 @@ export const POST = async (req: Request, { params }: { params: Promise<{ provide return NextResponse.json({ success: true }); } - const generationModel = new GenerationModel(db, asyncTask.userId); + const generationModel = new GenerationModel( + db, + asyncTask.userId, + asyncTask.workspaceId ?? undefined, + ); // Find generation by asyncTaskId const generation = await generationModel.findByAsyncTaskId(asyncTask.id); @@ -133,7 +137,7 @@ export const POST = async (req: Request, { params }: { params: Promise<{ provide log('Found generation: %s', generation.id); - asyncTaskModel = new AsyncTaskModel(db, asyncTask.userId); + asyncTaskModel = new AsyncTaskModel(db, asyncTask.userId, asyncTask.workspaceId ?? undefined); // Query batch to get model info for both error and success paths const batch = await db.query.generationBatches.findFirst({ @@ -182,7 +186,11 @@ export const POST = async (req: Request, { params }: { params: Promise<{ provide } // Handle success result: download video → process → upload S3 → create asset and file - const videoService = new VideoGenerationService(db, asyncTask.userId); + const videoService = new VideoGenerationService( + db, + asyncTask.userId, + asyncTask.workspaceId ?? undefined, + ); const processResult = await videoService.processVideoForGeneration(result.videoUrl); const asset: VideoGenerationAsset = { diff --git a/src/app/(backend)/api/workflows/agent-eval-run/execute-test-case/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/execute-test-case/route.ts index 12a681b241..9f42a8a22c 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/execute-test-case/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/execute-test-case/route.ts @@ -5,6 +5,7 @@ import { AgentEvalRunModel } from '@/database/models/agentEval'; import { getServerDB } from '@/database/server'; import { qstashClient } from '@/libs/qstash'; import { AgentEvalRunWorkflow, type ExecuteTestCasePayload } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:execute-test-case'); @@ -25,10 +26,11 @@ export const { POST } = serve( } const db = await getServerDB(); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); // Get run to get K value from config const run = await context.run('agent-eval-run:get-run', async () => { - const runModel = new AgentEvalRunModel(db, userId); + const runModel = new AgentEvalRunModel(db, userId, wsId); return runModel.findById(runId); }); diff --git a/src/app/(backend)/api/workflows/agent-eval-run/finalize-run/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/finalize-run/route.ts index 23d8e9a9be..fbf00896ec 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/finalize-run/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/finalize-run/route.ts @@ -6,6 +6,7 @@ import { getServerDB } from '@/database/server'; import { qstashClient } from '@/libs/qstash'; import { AgentEvalRunService } from '@/server/services/agentEvalRun'; import { type FinalizeRunPayload } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:finalize-run'); @@ -31,10 +32,11 @@ export const { POST } = serve( } const db = await getServerDB(); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); // Step 1: Get run details const run = await context.run('agent-eval-run:get-run', async () => { - const runModel = new AgentEvalRunModel(db, userId); + const runModel = new AgentEvalRunModel(db, userId, wsId); return runModel.findById(runId); }); @@ -49,7 +51,7 @@ export const { POST } = serve( // Step 2: Get all RunTopics (already evaluated in recordTrajectoryCompletion) const runTopics = await context.run('agent-eval-run:get-run-topics', async () => { - const runTopicModel = new AgentEvalRunTopicModel(db, userId); + const runTopicModel = new AgentEvalRunTopicModel(db, userId, wsId); return runTopicModel.findByRunId(runId); }); @@ -57,7 +59,7 @@ export const { POST } = serve( // Step 3: Aggregate metrics from already-evaluated RunTopics const metrics = await context.run('agent-eval-run:aggregate-metrics', async () => { - const service = new AgentEvalRunService(db, userId); + const service = new AgentEvalRunService(db, userId, wsId); return service.evaluateAndFinalizeRun({ run: { config: run.config, id: runId, metrics: run.metrics, startedAt: run.startedAt }, runTopics, @@ -80,7 +82,7 @@ export const { POST } = serve( : 'completed'; await context.run('agent-eval-run:update-run', async () => { - const runModel = new AgentEvalRunModel(db, userId); + const runModel = new AgentEvalRunModel(db, userId, wsId); return runModel.update(runId, { metrics, status: runStatus }); }); diff --git a/src/app/(backend)/api/workflows/agent-eval-run/on-thread-complete/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/on-thread-complete/route.ts index a2b2399691..951cfbe684 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/on-thread-complete/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/on-thread-complete/route.ts @@ -8,6 +8,7 @@ import { AgentEvalRunWorkflow, type OnThreadCompletePayload, } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:on-thread-complete'); @@ -57,16 +58,17 @@ export async function POST(req: Request) { ); const db = await getServerDB(); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); // Check if run was aborted — skip processing to avoid overwriting abort state - const runModel = new AgentEvalRunModel(db, userId); + const runModel = new AgentEvalRunModel(db, userId, wsId); const run = await runModel.findById(runId); if (run?.status === 'aborted') { log('Run aborted, skipping: runId=%s testCaseId=%s threadId=%s', runId, testCaseId, threadId); return NextResponse.json({ cancelled: true }); } - const service = new AgentEvalRunService(db, userId); + const service = new AgentEvalRunService(db, userId, wsId); const { allThreadsDone, allRunDone } = await service.recordThreadCompletion({ runId, diff --git a/src/app/(backend)/api/workflows/agent-eval-run/on-trajectory-complete/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/on-trajectory-complete/route.ts index a4da247309..96e52e7f06 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/on-trajectory-complete/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/on-trajectory-complete/route.ts @@ -8,6 +8,7 @@ import { AgentEvalRunWorkflow, type OnTrajectoryCompletePayload, } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:on-trajectory-complete'); @@ -58,16 +59,17 @@ export async function POST(req: Request) { ); const db = await getServerDB(); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); // Check if run was aborted — skip processing to avoid overwriting abort state - const runModel = new AgentEvalRunModel(db, userId); + const runModel = new AgentEvalRunModel(db, userId, wsId); const run = await runModel.findById(runId); if (run?.status === 'aborted') { log('Run aborted, skipping: runId=%s testCaseId=%s', runId, testCaseId); return NextResponse.json({ cancelled: true }); } - const service = new AgentEvalRunService(db, userId); + const service = new AgentEvalRunService(db, userId, wsId); const { allDone, completedCount } = await service.recordTrajectoryCompletion({ runId, diff --git a/src/app/(backend)/api/workflows/agent-eval-run/paginate-test-cases/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/paginate-test-cases/route.ts index 47ad42b70c..1d8b15eef9 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/paginate-test-cases/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/paginate-test-cases/route.ts @@ -9,6 +9,7 @@ import { AgentEvalRunWorkflow, type PaginateTestCasesPayload, } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const CHUNK_SIZE = 20; // Max items to process directly const PAGE_SIZE = 50; // Items per page @@ -34,6 +35,7 @@ export const { POST } = serve( } const db = await getServerDB(); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); // If specific testCaseIds are provided (from fanout), process them directly if (payloadTestCaseIds && payloadTestCaseIds.length > 0) { @@ -55,7 +57,7 @@ export const { POST } = serve( // Check if run was aborted before paginating const runStatus = await context.run('agent-eval-run:check-abort', async () => { - const runModel = new AgentEvalRunModel(db, userId); + const runModel = new AgentEvalRunModel(db, userId, wsId); const run = await runModel.findById(runId); return run?.status; }); @@ -68,12 +70,12 @@ export const { POST } = serve( // Paginate through test cases const testCaseBatch = await context.run('agent-eval-run:get-test-cases-page', async () => { // Get run to find datasetId and userId - const runModel = new AgentEvalRunModel(db, userId); + const runModel = new AgentEvalRunModel(db, userId, wsId); const run = await runModel.findById(runId); if (!run) return { ids: [] }; // Get test cases for this dataset - const testCaseModel = new AgentEvalTestCaseModel(db, userId); + const testCaseModel = new AgentEvalTestCaseModel(db, userId, wsId); const allTestCases = await testCaseModel.findByDatasetId(run.datasetId); // Apply cursor-based pagination @@ -108,6 +110,7 @@ export const { POST } = serve( runId, testCaseIds: batchTestCaseIds, userId, + workspaceId: wsId, }), ); diff --git a/src/app/(backend)/api/workflows/agent-eval-run/resume-agent-trajectory/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/resume-agent-trajectory/route.ts index 97c0faacc4..a4151370d3 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/resume-agent-trajectory/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/resume-agent-trajectory/route.ts @@ -5,6 +5,7 @@ import { getServerDB } from '@/database/server'; import { qstashClient } from '@/libs/qstash'; import { AgentEvalRunService } from '@/server/services/agentEvalRun'; import type { ResumeAgentTrajectoryPayload } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:resume-agent-trajectory'); @@ -27,7 +28,8 @@ export const { POST } = serve( } const db = await getServerDB(); - const service = new AgentEvalRunService(db, userId); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); + const service = new AgentEvalRunService(db, userId, wsId); await context.run('resume-agent-trajectory:exec-agent', () => service.executeResumedTrajectory(payload), diff --git a/src/app/(backend)/api/workflows/agent-eval-run/resume-thread-trajectory/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/resume-thread-trajectory/route.ts index f66b5b17ca..bf39cfbc78 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/resume-thread-trajectory/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/resume-thread-trajectory/route.ts @@ -5,6 +5,7 @@ import { getServerDB } from '@/database/server'; import { qstashClient } from '@/libs/qstash'; import { AgentEvalRunService } from '@/server/services/agentEvalRun'; import type { ResumeThreadTrajectoryPayload } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:resume-thread-trajectory'); @@ -29,7 +30,8 @@ export const { POST } = serve( } const db = await getServerDB(); - const service = new AgentEvalRunService(db, userId); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); + const service = new AgentEvalRunService(db, userId, wsId); await context.run('resume-thread-trajectory:exec-agent', () => service.executeResumedThreadTrajectory(payload), diff --git a/src/app/(backend)/api/workflows/agent-eval-run/run-agent-trajectory/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/run-agent-trajectory/route.ts index dca583dc50..379bf6ea66 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/run-agent-trajectory/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/run-agent-trajectory/route.ts @@ -8,6 +8,7 @@ import { AgentEvalRunWorkflow, type RunAgentTrajectoryPayload, } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:run-agent-trajectory'); @@ -27,7 +28,8 @@ export const { POST } = serve( } const db = await getServerDB(); - const service = new AgentEvalRunService(db, userId); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); + const service = new AgentEvalRunService(db, userId, wsId); // Step 1: Read all required data const data = await context.run('agent-eval-run:load-data', () => diff --git a/src/app/(backend)/api/workflows/agent-eval-run/run-benchmark/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/run-benchmark/route.ts index 4834f0cb4c..e38354b560 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/run-benchmark/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/run-benchmark/route.ts @@ -5,6 +5,7 @@ import { AgentEvalRunModel, AgentEvalTestCaseModel } from '@/database/models/age import { getServerDB } from '@/database/server'; import { qstashClient } from '@/libs/qstash'; import { AgentEvalRunWorkflow, type RunBenchmarkPayload } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:run-benchmark'); @@ -28,7 +29,8 @@ export const { POST } = serve( } const db = await getServerDB(); - const runModel = new AgentEvalRunModel(db, userId); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); + const runModel = new AgentEvalRunModel(db, userId, wsId); // Get run info const run = await context.run('agent-eval-run:get-run', () => runModel.findById(runId)); @@ -43,7 +45,7 @@ export const { POST } = serve( } // Get all test cases - const testCaseModel = new AgentEvalTestCaseModel(db, userId); + const testCaseModel = new AgentEvalTestCaseModel(db, userId, wsId); const allTestCases = await context.run('agent-eval-run:get-test-cases', () => testCaseModel.findByDatasetId(run.datasetId), ); @@ -66,6 +68,7 @@ export const { POST } = serve( runId, testCaseIds: allTestCaseIds, userId, + workspaceId: wsId, }), ); diff --git a/src/app/(backend)/api/workflows/agent-eval-run/run-thread-trajectory/route.ts b/src/app/(backend)/api/workflows/agent-eval-run/run-thread-trajectory/route.ts index 430ec49579..7e8cfb0d90 100644 --- a/src/app/(backend)/api/workflows/agent-eval-run/run-thread-trajectory/route.ts +++ b/src/app/(backend)/api/workflows/agent-eval-run/run-thread-trajectory/route.ts @@ -8,6 +8,7 @@ import { AgentEvalRunWorkflow, type RunThreadTrajectoryPayload, } from '@/server/workflows/agentEvalRun'; +import { resolveAgentEvalRunWorkspace } from '@/server/workflows/agentEvalRun/utils'; const log = debug('lobe-server:workflows:run-thread-trajectory'); @@ -26,7 +27,8 @@ export const { POST } = serve( } const db = await getServerDB(); - const service = new AgentEvalRunService(db, userId); + const wsId = await resolveAgentEvalRunWorkspace(db, runId); + const service = new AgentEvalRunService(db, userId, wsId); // Step 1: Load run + testCase data const data = await context.run('thread-trajectory:load-data', () => diff --git a/src/app/(backend)/market/social/[[...segments]]/route.ts b/src/app/(backend)/market/social/[[...segments]]/route.ts index efe18cdd06..10e3e15cda 100644 --- a/src/app/(backend)/market/social/[[...segments]]/route.ts +++ b/src/app/(backend)/market/social/[[...segments]]/route.ts @@ -149,9 +149,12 @@ export const GET = async (req: NextRequest, context: RouteContext) => { market.follows.getFollowing(userId, { limit: 1 }), market.follows.getFollowers(userId, { limit: 1 }), ]); + // `totalCount` is the full-list total returned by the market follows + // endpoints (added in market-sdk >= 0.34.0-beta.2). The cast keeps this + // building against older published SDK typings until the dep is bumped. return NextResponse.json({ - followersCount: (followers as any).totalCount || (followers as any).total || 0, - followingCount: (following as any).totalCount || (following as any).total || 0, + followersCount: (followers as { totalCount?: number }).totalCount ?? 0, + followingCount: (following as { totalCount?: number }).totalCount ?? 0, }); } diff --git a/src/app/(backend)/oauth/connector/callback/route.ts b/src/app/(backend)/oauth/connector/callback/route.ts index e115acc8db..ddc47bfebb 100644 --- a/src/app/(backend)/oauth/connector/callback/route.ts +++ b/src/app/(backend)/oauth/connector/callback/route.ts @@ -89,7 +89,7 @@ export const GET = async (req: NextRequest) => { } const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const connectorModel = new ConnectorModel(serverDB, payload.lobeUserId, gateKeeper); + const connectorModel = new ConnectorModel(serverDB, payload.lobeUserId, undefined, gateKeeper); const connector = await connectorModel.findById(payload.connectorId); if (!connector) { diff --git a/src/app/(backend)/webapi/_utils/workspace.test.ts b/src/app/(backend)/webapi/_utils/workspace.test.ts new file mode 100644 index 0000000000..920b33d7e4 --- /dev/null +++ b/src/app/(backend)/webapi/_utils/workspace.test.ts @@ -0,0 +1,103 @@ +// @vitest-environment node +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import type { LobeChatDatabase } from '@/database/type'; + +import { resolveValidWorkspaceIdFromRequest, WORKSPACE_ID_HEADER } from './workspace'; + +const workspaceFindFirst = vi.fn(); +const workspaceMemberFindFirst = vi.fn(); + +const serverDB = { + query: { + workspaceMembers: { + findFirst: workspaceMemberFindFirst, + }, + workspaces: { + findFirst: workspaceFindFirst, + }, + }, +} as unknown as LobeChatDatabase; + +const createRequest = (workspaceId?: string | null) => { + const headers = new Headers(); + if (workspaceId !== undefined && workspaceId !== null) + headers.set(WORKSPACE_ID_HEADER, workspaceId); + + return new Request('https://app.test/webapi/models/openai', { headers }); +}; + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe('resolveValidWorkspaceIdFromRequest', () => { + it('returns undefined without querying when the workspace header is missing', async () => { + await expect( + resolveValidWorkspaceIdFromRequest({ + req: createRequest(), + serverDB, + userId: 'user-1', + }), + ).resolves.toBeUndefined(); + + expect(workspaceFindFirst).not.toHaveBeenCalled(); + expect(workspaceMemberFindFirst).not.toHaveBeenCalled(); + }); + + it('trims blank workspace headers and treats them as absent', async () => { + await expect( + resolveValidWorkspaceIdFromRequest({ + req: createRequest(' '), + serverDB, + userId: 'user-1', + }), + ).resolves.toBeUndefined(); + + expect(workspaceFindFirst).not.toHaveBeenCalled(); + expect(workspaceMemberFindFirst).not.toHaveBeenCalled(); + }); + + it('returns undefined when the workspace id does not exist', async () => { + workspaceFindFirst.mockResolvedValueOnce(undefined); + + await expect( + resolveValidWorkspaceIdFromRequest({ + req: createRequest(' ws-1 '), + serverDB, + userId: 'user-1', + }), + ).resolves.toBeUndefined(); + + expect(workspaceFindFirst).toHaveBeenCalledTimes(1); + expect(workspaceMemberFindFirst).not.toHaveBeenCalled(); + }); + + it('returns undefined when the requester is not an active workspace member', async () => { + workspaceFindFirst.mockResolvedValueOnce({ id: 'ws-1' }); + workspaceMemberFindFirst.mockResolvedValueOnce(undefined); + + await expect( + resolveValidWorkspaceIdFromRequest({ + req: createRequest('ws-1'), + serverDB, + userId: 'user-1', + }), + ).resolves.toBeUndefined(); + + expect(workspaceMemberFindFirst).toHaveBeenCalledTimes(1); + }); + + it('returns the trimmed workspace id for an existing workspace and active member', async () => { + workspaceFindFirst.mockResolvedValueOnce({ id: 'ws-1' }); + workspaceMemberFindFirst.mockResolvedValueOnce({ userId: 'user-1' }); + + await expect( + resolveValidWorkspaceIdFromRequest({ + req: createRequest(' ws-1 '), + serverDB, + userId: 'user-1', + }), + ).resolves.toBe('ws-1'); + }); +}); diff --git a/src/app/(backend)/webapi/_utils/workspace.ts b/src/app/(backend)/webapi/_utils/workspace.ts new file mode 100644 index 0000000000..c2aaf015ce --- /dev/null +++ b/src/app/(backend)/webapi/_utils/workspace.ts @@ -0,0 +1,32 @@ +import { and, eq, isNull } from 'drizzle-orm'; + +import { workspaceMembers, workspaces } from '@/database/schemas'; +import type { LobeChatDatabase } from '@/database/type'; + +export const WORKSPACE_ID_HEADER = 'X-Workspace-Id'; + +export const resolveValidWorkspaceIdFromRequest = async (params: { + req: Request; + serverDB: LobeChatDatabase; + userId: string; +}): Promise => { + const workspaceId = params.req.headers.get(WORKSPACE_ID_HEADER)?.trim(); + if (!workspaceId) return undefined; + + const workspace = await params.serverDB.query.workspaces.findFirst({ + columns: { id: true }, + where: eq(workspaces.id, workspaceId), + }); + if (!workspace) return undefined; + + const membership = await params.serverDB.query.workspaceMembers.findFirst({ + columns: { userId: true }, + where: and( + eq(workspaceMembers.workspaceId, workspaceId), + eq(workspaceMembers.userId, params.userId), + isNull(workspaceMembers.deletedAt), + ), + }); + + return membership ? workspaceId : undefined; +}; diff --git a/src/app/(backend)/webapi/chat/[provider]/route.test.ts b/src/app/(backend)/webapi/chat/[provider]/route.test.ts index 4fc727a2ad..61992d059c 100644 --- a/src/app/(backend)/webapi/chat/[provider]/route.test.ts +++ b/src/app/(backend)/webapi/chat/[provider]/route.test.ts @@ -66,6 +66,7 @@ describe('POST handler', () => { expect.anything(), 'test-user-id', 'test-provider', + undefined, ); }); diff --git a/src/app/(backend)/webapi/chat/[provider]/route.ts b/src/app/(backend)/webapi/chat/[provider]/route.ts index ebe03185a1..d40e8d0195 100644 --- a/src/app/(backend)/webapi/chat/[provider]/route.ts +++ b/src/app/(backend)/webapi/chat/[provider]/route.ts @@ -8,6 +8,8 @@ import { type ChatStreamPayload } from '@/types/openai/chat'; import { createErrorResponse } from '@/utils/errorResponse'; import { getTracePayload } from '@/utils/trace'; +import { resolveValidWorkspaceIdFromRequest } from '../../_utils/workspace'; + // If user don't use fluid compute, will build failed // this enforce user to enable fluid compute export const maxDuration = 300; @@ -16,8 +18,10 @@ export const POST = checkAuth(async (req: Request, { params, userId, serverDB }) const provider = (await params)!.provider!; try { + const workspaceId = await resolveValidWorkspaceIdFromRequest({ req, serverDB, userId }); + // ============ 1. init chat model ============ // - const modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider); + const modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider, workspaceId); // ============ 2. create chat completion ============ // diff --git a/src/app/(backend)/webapi/models/[provider]/pull/route.ts b/src/app/(backend)/webapi/models/[provider]/pull/route.ts index 9663c30f33..4b5f8189dd 100644 --- a/src/app/(backend)/webapi/models/[provider]/pull/route.ts +++ b/src/app/(backend)/webapi/models/[provider]/pull/route.ts @@ -5,12 +5,16 @@ import { checkAuth } from '@/app/(backend)/middleware/auth'; import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime'; import { createErrorResponse } from '@/utils/errorResponse'; +import { resolveValidWorkspaceIdFromRequest } from '../../../_utils/workspace'; + export const POST = checkAuth(async (req, { params, userId, serverDB }) => { const provider = (await params)!.provider!; try { + const workspaceId = await resolveValidWorkspaceIdFromRequest({ req, serverDB, userId }); + // Read user's provider config from database - const agentRuntime = await initModelRuntimeFromDB(serverDB, userId, provider); + const agentRuntime = await initModelRuntimeFromDB(serverDB, userId, provider, workspaceId); const data = (await req.json()) as PullModelParams; diff --git a/src/app/(backend)/webapi/models/[provider]/route.ts b/src/app/(backend)/webapi/models/[provider]/route.ts index 508b46238d..97e339977a 100644 --- a/src/app/(backend)/webapi/models/[provider]/route.ts +++ b/src/app/(backend)/webapi/models/[provider]/route.ts @@ -6,12 +6,16 @@ import { checkAuth } from '@/app/(backend)/middleware/auth'; import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime'; import { createErrorResponse } from '@/utils/errorResponse'; +import { resolveValidWorkspaceIdFromRequest } from '../../_utils/workspace'; + export const GET = checkAuth(async (req, { params, userId, serverDB }) => { const provider = (await params)!.provider!; try { + const workspaceId = await resolveValidWorkspaceIdFromRequest({ req, serverDB, userId }); + // Read user's provider config from database - const agentRuntime = await initModelRuntimeFromDB(serverDB, userId, provider); + const agentRuntime = await initModelRuntimeFromDB(serverDB, userId, provider, workspaceId); const list = await agentRuntime.models(); diff --git a/src/business/server/image-generation/chargeAfterGenerate.ts b/src/business/server/image-generation/chargeAfterGenerate.ts index 7affe4d5c5..57505f8842 100644 --- a/src/business/server/image-generation/chargeAfterGenerate.ts +++ b/src/business/server/image-generation/chargeAfterGenerate.ts @@ -11,6 +11,7 @@ interface ChargeParams { modelUsage?: ModelUsage; provider: string; userId: string; + workspaceId?: string; } // eslint-disable-next-line unused-imports/no-unused-vars diff --git a/src/business/server/image-generation/chargeBeforeGenerate.ts b/src/business/server/image-generation/chargeBeforeGenerate.ts index bb93e362df..d4ab880153 100644 --- a/src/business/server/image-generation/chargeBeforeGenerate.ts +++ b/src/business/server/image-generation/chargeBeforeGenerate.ts @@ -10,6 +10,7 @@ interface ChargeParams { model: string; provider: string; userId: string; + workspaceId?: string; } type ChargeResult = diff --git a/src/business/server/lambda-routers/file.ts b/src/business/server/lambda-routers/file.ts index b38c98d51d..13cc368b5e 100644 --- a/src/business/server/lambda-routers/file.ts +++ b/src/business/server/lambda-routers/file.ts @@ -7,8 +7,19 @@ export interface BusinessFileUploadCheckParams { transaction?: Transaction; url: string; userId: string; + workspaceId?: string | null; } export async function businessFileUploadCheck( _params: BusinessFileUploadCheckParams, ): Promise {} + +export interface BusinessFileTransferStorageCheckParams { + additionalSize: number; + targetUserId: string; + targetWorkspaceId: string | null; +} + +export async function businessFileTransferStorageCheck( + _params: BusinessFileTransferStorageCheckParams, +): Promise {} diff --git a/src/business/server/lambda-routers/workspace.ts b/src/business/server/lambda-routers/workspace.ts new file mode 100644 index 0000000000..844d43eb8d --- /dev/null +++ b/src/business/server/lambda-routers/workspace.ts @@ -0,0 +1,4 @@ +import { router } from '@/libs/trpc/lambda'; + +// Cloud overrides this at the same path with the real workspaceRouter backed by cloudDB. +export const workspaceRouter = router({}); diff --git a/src/business/server/lambda-routers/workspaceAuditLog.ts b/src/business/server/lambda-routers/workspaceAuditLog.ts new file mode 100644 index 0000000000..1c97cb4206 --- /dev/null +++ b/src/business/server/lambda-routers/workspaceAuditLog.ts @@ -0,0 +1,3 @@ +import { router } from '@/libs/trpc/lambda'; + +export const workspaceAuditLogRouter = router({}); diff --git a/src/business/server/lambda-routers/workspaceCredits.ts b/src/business/server/lambda-routers/workspaceCredits.ts new file mode 100644 index 0000000000..edc506f072 --- /dev/null +++ b/src/business/server/lambda-routers/workspaceCredits.ts @@ -0,0 +1,3 @@ +import { router } from '@/libs/trpc/lambda'; + +export const workspaceCreditsRouter = router({}); diff --git a/src/business/server/lambda-routers/workspaceCreds.ts b/src/business/server/lambda-routers/workspaceCreds.ts new file mode 100644 index 0000000000..8ec1a3366f --- /dev/null +++ b/src/business/server/lambda-routers/workspaceCreds.ts @@ -0,0 +1,3 @@ +import { router } from '@/libs/trpc/lambda'; + +export const workspaceCredsRouter = router({}); diff --git a/src/business/server/lambda-routers/workspaceData.ts b/src/business/server/lambda-routers/workspaceData.ts new file mode 100644 index 0000000000..8d319fff46 --- /dev/null +++ b/src/business/server/lambda-routers/workspaceData.ts @@ -0,0 +1,3 @@ +import { router } from '@/libs/trpc/lambda'; + +export const workspaceDataRouter = router({}); diff --git a/src/business/server/lambda-routers/workspaceMember.ts b/src/business/server/lambda-routers/workspaceMember.ts new file mode 100644 index 0000000000..c8e03c84c4 --- /dev/null +++ b/src/business/server/lambda-routers/workspaceMember.ts @@ -0,0 +1,3 @@ +import { router } from '@/libs/trpc/lambda'; + +export const workspaceMemberRouter = router({}); diff --git a/src/business/server/lambda-routers/workspaceUsage.ts b/src/business/server/lambda-routers/workspaceUsage.ts new file mode 100644 index 0000000000..35c832d989 --- /dev/null +++ b/src/business/server/lambda-routers/workspaceUsage.ts @@ -0,0 +1,3 @@ +import { router } from '@/libs/trpc/lambda'; + +export const workspaceUsageRouter = router({}); diff --git a/src/business/server/model-runtime.ts b/src/business/server/model-runtime.ts index b1ca38f881..2cae4842df 100644 --- a/src/business/server/model-runtime.ts +++ b/src/business/server/model-runtime.ts @@ -3,6 +3,7 @@ import type { ModelRuntimeHooks } from '@lobechat/model-runtime'; export function getBusinessModelRuntimeHooks( _userId: string, _provider: string, + _workspaceId?: string, ): ModelRuntimeHooks | undefined { return undefined; } diff --git a/src/business/server/trpc-middlewares/rbacPermission.ts b/src/business/server/trpc-middlewares/rbacPermission.ts new file mode 100644 index 0000000000..4ac7e8516d --- /dev/null +++ b/src/business/server/trpc-middlewares/rbacPermission.ts @@ -0,0 +1,28 @@ +import { trpc } from '@/libs/trpc/lambda/init'; + +/** + * No-op stub for OSS builds. Cloud overrides this entire module via tsconfig + * path priority and provides the real workspace-RBAC-aware implementations + * (see `src/business/server/trpc-middlewares/rbacPermission.ts` in the cloud + * repo). In OSS there is no workspace concept worth gating, so every gate + * passes through. + * + * Keep the export shape identical to the cloud version so router code that + * imports from `@/business/server/trpc-middlewares/rbacPermission` compiles + * and runs in both environments without conditional imports. + */ +export const withRbacPermission = (_code: string) => trpc.middleware(async (opts) => opts.next()); + +export const withAnyRbacPermission = (_codes: string[]) => + trpc.middleware(async (opts) => opts.next()); + +export const withAllRbacPermissions = (_codes: string[]) => + trpc.middleware(async (opts) => opts.next()); + +/** + * Sugar for the "member-or-owner" gate — in cloud this fans the action code + * out into the `:all | :owner` scope pair so a member with the `:owner` grant + * passes alongside an owner with the `:all` grant. OSS no-op. + */ +export const withScopedPermission = (_action: string) => + trpc.middleware(async (opts) => opts.next()); diff --git a/src/business/server/trpc-middlewares/workspaceAuth.ts b/src/business/server/trpc-middlewares/workspaceAuth.ts new file mode 100644 index 0000000000..09e4ed4725 --- /dev/null +++ b/src/business/server/trpc-middlewares/workspaceAuth.ts @@ -0,0 +1,22 @@ +import { authedProcedure } from '@/libs/trpc/lambda'; +import { trpc } from '@/libs/trpc/lambda/init'; + +export type WorkspaceRole = 'member' | 'owner' | 'viewer'; + +export const cloudWorkspaceAuth = trpc.middleware(async (opts) => opts.next()); + +export const lobeWorkspaceAuth = trpc.middleware(async (opts) => opts.next()); + +export const requireWorkspaceRole = (_minRole: WorkspaceRole) => + trpc.middleware(async (opts) => opts.next()); + +export const requireWorkspaceRoleWhenScoped = (_minRole: WorkspaceRole) => + trpc.middleware(async (opts) => opts.next()); + +export const wsProcedure = authedProcedure; + +export const wsMemberProcedure = authedProcedure; + +export const wsOwnerProcedure = authedProcedure; + +export const wsCompatProcedure = authedProcedure; diff --git a/src/business/server/trpc-middlewares/workspaceContext.ts b/src/business/server/trpc-middlewares/workspaceContext.ts new file mode 100644 index 0000000000..8392619918 --- /dev/null +++ b/src/business/server/trpc-middlewares/workspaceContext.ts @@ -0,0 +1,9 @@ +/** + * Re-export of workspace ownership helpers from `@lobechat/database`. + * + * The actual implementation lives in `packages/database/src/utils/workspace.ts` + * because Models in that package need the same helpers, and the database + * package can't import from `src/`. Routers and middleware can import either + * from this path or from `@lobechat/database` directly. + */ +export { buildWorkspacePayload, buildWorkspaceWhere } from '@lobechat/database'; diff --git a/src/business/server/video-generation/chargeAfterGenerate.ts b/src/business/server/video-generation/chargeAfterGenerate.ts index c1f3a63207..14183dff78 100644 --- a/src/business/server/video-generation/chargeAfterGenerate.ts +++ b/src/business/server/video-generation/chargeAfterGenerate.ts @@ -14,6 +14,7 @@ interface ChargeParams { provider: string; usage?: { completionTokens: number; totalTokens: number }; userId: string; + workspaceId?: string; } // eslint-disable-next-line unused-imports/no-unused-vars diff --git a/src/business/server/video-generation/chargeBeforeGenerate.ts b/src/business/server/video-generation/chargeBeforeGenerate.ts index 5410df5e18..6b8a1273c8 100644 --- a/src/business/server/video-generation/chargeBeforeGenerate.ts +++ b/src/business/server/video-generation/chargeBeforeGenerate.ts @@ -7,6 +7,7 @@ interface ChargeParams { params: CreateVideoServicePayload['params']; provider: string; userId: string; + workspaceId?: string; } interface ErrorBatch { diff --git a/src/config/featureFlags/schema.test.ts b/src/config/featureFlags/schema.test.ts index 931d333144..1378904422 100644 --- a/src/config/featureFlags/schema.test.ts +++ b/src/config/featureFlags/schema.test.ts @@ -113,6 +113,16 @@ describe('mapFeatureFlagsEnvToState', () => { expect(mappedState.enableStorageOverage).toBe(true); }); + it('should map the workspace allowlist flag by user ID', () => { + const config = { + workspace: ['user-123'], + }; + + expect(mapFeatureFlagsEnvToState(config, 'user-123').enableWorkspace).toBe(true); + expect(mapFeatureFlagsEnvToState(config, 'user-456').enableWorkspace).toBe(false); + expect(mapFeatureFlagsEnvToState(config).enableWorkspace).toBe(false); + }); + it('should correctly map boolean feature flags to state', () => { const config = { provider_settings: true, diff --git a/src/config/featureFlags/schema.ts b/src/config/featureFlags/schema.ts index af1bfa8622..b2afbd7b86 100644 --- a/src/config/featureFlags/schema.ts +++ b/src/config/featureFlags/schema.ts @@ -36,6 +36,7 @@ export const FeatureFlagsSchema = z.object({ auth_captcha: FeatureFlagValue.optional(), cloud_promotion: FeatureFlagValue.optional(), storage_overage: FeatureFlagValue.optional(), + workspace: FeatureFlagValue.optional(), // the flags below can only be used with commercial license // if you want to use it in the commercial usage @@ -86,6 +87,7 @@ export const DEFAULT_FEATURE_FLAGS: IFeatureFlags = { auth_captcha: true, cloud_promotion: false, storage_overage: true, + workspace: false, market: true, speech_to_text: true, @@ -122,6 +124,7 @@ export const mapFeatureFlagsEnvToState = (config: IFeatureFlags, userId?: string enableStorageOverage: evaluateFeatureFlag(config.storage_overage, userId), showCloudPromotion: evaluateFeatureFlag(config.cloud_promotion, userId), + enableWorkspace: evaluateFeatureFlag(config.workspace, userId), showMarket: evaluateFeatureFlag(config.market, userId), enableSTT: evaluateFeatureFlag(config.speech_to_text, userId), diff --git a/src/libs/better-auth/email-templates/index.ts b/src/libs/better-auth/email-templates/index.ts index 905d1e5281..a4994a436b 100644 --- a/src/libs/better-auth/email-templates/index.ts +++ b/src/libs/better-auth/email-templates/index.ts @@ -3,3 +3,5 @@ export { getMagicLinkEmailTemplate } from './magic-link'; export { getResetPasswordEmailTemplate } from './reset-password'; export { getVerificationEmailTemplate } from './verification'; export { getVerificationOTPEmailTemplate } from './verification-otp'; +export { getWorkspaceInviteEmailTemplate } from './workspace-invite'; +export { getWorkspaceMemberRemovedEmailTemplate } from './workspace-member-removed'; diff --git a/src/libs/better-auth/email-templates/workspace-invite.ts b/src/libs/better-auth/email-templates/workspace-invite.ts new file mode 100644 index 0000000000..5ca952b377 --- /dev/null +++ b/src/libs/better-auth/email-templates/workspace-invite.ts @@ -0,0 +1,108 @@ +/** + * Workspace invitation email template + * Sent when a workspace owner invites someone (by email) to join their workspace. + */ +export const getWorkspaceInviteEmailTemplate = (params: { + expiresInDays: number; + inviterEmail?: string | null; + inviterName?: string | null; + role: string; + url: string; + workspaceName: string; +}) => { + const { url, workspaceName, inviterName, inviterEmail, role, expiresInDays } = params; + + const inviterLabel = inviterName || inviterEmail || 'A teammate'; + const inviterByline = + inviterEmail && inviterName ? `${inviterName} (${inviterEmail})` : inviterLabel; + const subject = `${inviterLabel} invited you to join ${workspaceName} on LobeHub`; + const roleLabel = role.charAt(0).toUpperCase() + role.slice(1); + + return { + html: ` + + + + + + ${subject} + + +
+ + +
+
+ 🤯 + LobeHub +
+
+ + +
+ + +
+

+ Join ${workspaceName} on LobeHub +

+

+ You've been invited as a ${roleLabel}. +

+
+ + +
+

+ ${inviterByline} has invited you to collaborate inside the + ${workspaceName} workspace on LobeHub. +

+ + + + + +
+

+ ⏰ This invitation will expire in ${expiresInDays} day${expiresInDays > 1 ? 's' : ''}. +

+
+ +

+ If you don't have a LobeHub account yet, you'll be guided through a quick signup before joining the workspace. +

+
+ + +
+ + +
+

+ Button not working? Copy and paste this link into your browser: +

+ + ${url} + +
+
+ + +
+

+ If you weren't expecting this invitation, you can safely ignore this email. +

+
+
+ + + `, + subject, + text: `${inviterByline} has invited you to join the "${workspaceName}" workspace on LobeHub as ${roleLabel}.\n\nAccept the invitation: ${url}\n\nThis invitation will expire in ${expiresInDays} day${expiresInDays > 1 ? 's' : ''}.\n\nIf you weren't expecting this invitation, you can safely ignore this email.`, + }; +}; diff --git a/src/libs/better-auth/email-templates/workspace-member-removed.ts b/src/libs/better-auth/email-templates/workspace-member-removed.ts new file mode 100644 index 0000000000..92497d2fa0 --- /dev/null +++ b/src/libs/better-auth/email-templates/workspace-member-removed.ts @@ -0,0 +1,91 @@ +export const getWorkspaceMemberRemovedEmailTemplate = (params: { + reason: 'downgrade' | 'removed_by_owner'; + workspaceName: string; +}) => { + const { workspaceName, reason } = params; + + const isDowngrade = reason === 'downgrade'; + + const subject = isDowngrade + ? `You have been removed from ${workspaceName} on LobeHub` + : `You have been removed from ${workspaceName} on LobeHub`; + + const heading = isDowngrade + ? `Removed from ${workspaceName}` + : `Removed from ${workspaceName}`; + + const body = isDowngrade + ? `The workspace ${workspaceName} has been downgraded, and all team members have been removed as a result. Your personal data and workspaces are not affected.` + : `The owner of ${workspaceName} has removed you from the workspace. Your personal data and workspaces are not affected.`; + + return { + html: ` + + + + + + ${subject} + + +
+ + +
+
+ 🤯 + LobeHub +
+
+ + +
+ + +
+

+ ${heading} +

+
+ + +
+

+ ${body} +

+ + +
+

+ If you believe this was a mistake, please contact the workspace owner. +

+
+
+ + +
+ + +
+

+ You can continue using LobeHub with your personal workspace. +

+
+
+ + +
+

+ This is an automated message from LobeHub. +

+
+
+ + + `, + subject, + text: isDowngrade + ? `The workspace "${workspaceName}" has been downgraded, and all team members have been removed as a result. Your personal data and workspaces are not affected. If you believe this was a mistake, please contact the workspace owner.` + : `The owner of "${workspaceName}" has removed you from the workspace. Your personal data and workspaces are not affected. If you believe this was a mistake, please contact the workspace owner.`, + }; +}; diff --git a/src/libs/mcp/connectorPermissionCheck.ts b/src/libs/mcp/connectorPermissionCheck.ts index df911fa433..13e8518e02 100644 --- a/src/libs/mcp/connectorPermissionCheck.ts +++ b/src/libs/mcp/connectorPermissionCheck.ts @@ -14,6 +14,7 @@ export { patchManifestWithPermissions } from './patchManifestPermissions'; * @param userId - Authenticated user ID * @param identifier - Connector identifier (e.g. 'gmail', 'vercel') * @param toolName - Tool API name (e.g. 'gmail_search_emails', 'deploy') + * @param workspaceId - Workspace scope (omit for personal mode) * * Returns the stored permission, or null if no connector/tool entry exists. */ @@ -22,13 +23,14 @@ export async function getConnectorToolPermission( userId: string, identifier: string, toolName: string, + workspaceId?: string, ): Promise { try { - const connectorModel = new ConnectorModel(db, userId); + const connectorModel = new ConnectorModel(db, userId, workspaceId); const [connector] = await connectorModel.queryByIdentifiers([identifier]); if (!connector) return null; - const toolModel = new ConnectorToolModel(db, userId); + const toolModel = new ConnectorToolModel(db, userId, workspaceId); const tools = await toolModel.queryByConnector(connector.id); return ( (tools.find((t) => t.toolName === toolName)?.permission as ConnectorToolPermission) ?? null diff --git a/src/libs/trpc/lambda/context.ts b/src/libs/trpc/lambda/context.ts index ba1707fd6a..f52dfa5339 100644 --- a/src/libs/trpc/lambda/context.ts +++ b/src/libs/trpc/lambda/context.ts @@ -41,7 +41,11 @@ const validateApiKeyUserId = async (apiKey: string): Promise => { if (!apiKeyRecord.enabled) return null; if (isApiKeyExpired(apiKeyRecord.expiresAt)) return null; - const userApiKeyModel = new ApiKeyModel(db, apiKeyRecord.userId); + const userApiKeyModel = new ApiKeyModel( + db, + apiKeyRecord.userId, + apiKeyRecord.workspaceId ?? undefined, + ); void userApiKeyModel.updateLastUsed(apiKeyRecord.id).catch((error) => { log('Failed to update API key last used timestamp: %O', error); console.error('Failed to update API key last used timestamp:', error); @@ -74,6 +78,7 @@ export interface AuthContext { traceContext?: OtContext; userAgent?: string; userId?: string | null; + workspaceId?: string | null; } /** @@ -87,6 +92,7 @@ export const createContextInner = async (params?: { traceContext?: OtContext; userAgent?: string; userId?: string | null; + workspaceId?: string | null; }): Promise => { log('createContextInner called with params: %O', params); const responseHeaders = new Headers(); @@ -99,6 +105,7 @@ export const createContextInner = async (params?: { traceContext: params?.traceContext, userAgent: params?.userAgent, userId: params?.userId, + workspaceId: params?.workspaceId, }; }; @@ -134,10 +141,13 @@ export const createLambdaContext = async (request: NextRequest): Promise { email: user.email, name: user.fullName || user.username || undefined, userId: ctx.userId, + // In a workspace context, the token acts as the workspace's mirrored + // organization; absent for personal requests. + ...(ctx.workspaceId ? { workspaceId: ctx.workspaceId } : {}), }; // Fetch market access token from user_settings.market diff --git a/src/libs/trusted-client/index.ts b/src/libs/trusted-client/index.ts index d6267cc6ff..9260fa9102 100644 --- a/src/libs/trusted-client/index.ts +++ b/src/libs/trusted-client/index.ts @@ -6,6 +6,13 @@ export interface TrustedClientUserInfo { email?: string; name?: string; userId: string; + /** + * Cloud workspace id the request acts on behalf of. When set, Market treats + * the caller as the workspace's mirrored organization (resolved via the + * `workspace:` clerkId convention), mirroring how `userId` + * identifies the personal account. Omit for personal requests. + */ + workspaceId?: string; } export { getSessionUser } from './getSessionUser'; @@ -38,6 +45,7 @@ export const generateTrustedClientToken = (userInfo: TrustedClientUserInfo): str email: userInfo.email || '', name: userInfo.name, userId: userInfo.userId, + workspaceId: userInfo.workspaceId, }); return createTrustedClientToken(payload, MARKET_TRUSTED_CLIENT_SECRET); diff --git a/src/locales/default/agent.ts b/src/locales/default/agent.ts index 5e8901f3bf..93ab6f5d99 100644 --- a/src/locales/default/agent.ts +++ b/src/locales/default/agent.ts @@ -303,4 +303,22 @@ export default { 'channel.statusFailed': 'Failed', 'channel.statusQueued': 'Queued', 'channel.statusStarting': 'Starting', + + 'transfer.title': 'Transfer', + 'transfer.copyTo': 'Copy To', + 'transfer.desc': 'Transfer this agent to another workspace or your personal account.', + 'transfer.button': 'Transfer', + 'transfer.selectTarget': 'Transfer Agent To', + 'transfer.searchWorkspace': 'Search workspaces...', + 'transfer.personalAccount': 'Personal Account', + 'transfer.confirm.title': 'Transfer Agent', + 'transfer.confirm.desc': + 'This will move the agent and all associated data (topics, messages, files, etc.) to the target workspace.', + 'transfer.confirm.warning': "Some features don't transfer:", + 'transfer.confirm.plugins': 'Custom plugins may not be available in the target workspace', + 'transfer.confirm.chatGroups': 'Multi-agent group associations will be removed', + 'transfer.confirm.botChannels': 'Bot channel connections may need to be refreshed after transfer', + 'transfer.success': 'Agent transferred successfully', + 'transfer.transferTo': 'Transfer To', + 'transfer.error': 'Failed to transfer agent', } as const; diff --git a/src/locales/default/auth.ts b/src/locales/default/auth.ts index cf76ce823e..dbba3c0604 100644 --- a/src/locales/default/auth.ts +++ b/src/locales/default/auth.ts @@ -250,6 +250,9 @@ export default { 'stats.updatedAt': 'Updated at', 'stats.welcome': '{{username}}, this is your {{days}} day with {{appName}}', 'stats.words': 'Total Words', + 'stats.workspace.members': '{{count}} members', + 'stats.workspace.welcome': + '{{name}} has been running for {{days}} days', 'tab.apikey': 'API Key', 'tab.profile': 'My Account', 'tab.security': 'Security', @@ -263,6 +266,10 @@ export default { 'usage.activeModels.table.model': 'Model', 'usage.activeModels.table.provider': 'Provider', 'usage.activeModels.table.spend': 'Spend', + 'usage.activeModels.table.user': 'User', + 'usage.activeModels.removedUserName': '{{name}} (Removed)', + 'usage.activeModels.userTable': 'User List', + 'usage.activeModels.users': 'Active Users', 'usage.cards.month.modelCalls': 'Model Calls', 'usage.cards.month.title': "This Month's Spend", 'usage.cards.today.title': "Today's Spend", @@ -280,4 +287,5 @@ export default { 'usage.trends.tokens': 'Tokens', 'usage.welcome.model': 'Model', 'usage.welcome.provider': 'Provider', + 'usage.welcome.user': 'User', }; diff --git a/src/locales/default/chat.ts b/src/locales/default/chat.ts index 483e23bc56..66967c8d1e 100644 --- a/src/locales/default/chat.ts +++ b/src/locales/default/chat.ts @@ -827,13 +827,29 @@ export default { 'taskList.unassignedHint': 'Lobe AI will run this task when no assignee is set', 'taskList.assigneeSearch.empty': 'No matching agent', 'taskList.assigneeSearch.placeholder': 'Search agent...', + 'taskList.contextMenu.copyConfirm': 'Copy', + 'taskList.contextMenu.copyDescription': + 'Clone this task (and all its subtasks) into another workspace. Status resets to backlog.', + 'taskList.contextMenu.copyFailed': 'Failed to copy task', 'taskList.contextMenu.copyId': 'Copy ID', 'taskList.contextMenu.copyIdSuccess': 'ID copied', 'taskList.contextMenu.copyLink': 'Copy Link', 'taskList.contextMenu.copyLinkSuccess': 'Link copied', + 'taskList.contextMenu.copySuccess': 'Task copied', + 'taskList.contextMenu.copyTitle': 'Copy task', + 'taskList.contextMenu.copyTo': 'Copy to…', 'taskList.contextMenu.priority': 'Priority', 'taskList.contextMenu.runNow': 'Run now', 'taskList.contextMenu.status': 'Status', + 'taskList.contextMenu.transferConfirm': 'Transfer', + 'taskList.contextMenu.transferDescription': + 'Move this task (and all its subtasks) to another workspace. Identifiers will be re-assigned.', + 'taskList.contextMenu.transferFailed': 'Failed to transfer task', + 'taskList.contextMenu.transferSuccess': 'Task transferred', + 'taskList.contextMenu.transferTitle': 'Transfer task', + 'taskList.contextMenu.transferTo': 'Transfer to…', + 'taskList.contextMenu.transferWarning': + 'Cross-workspace references like assigned agent and active topic will be cleared.', 'taskList.kanban.addTask': 'Create task', 'taskList.kanban.backlog': 'Backlog', 'taskList.kanban.canceled': 'Canceled', diff --git a/src/locales/default/discover.ts b/src/locales/default/discover.ts index d2975c5885..338025adad 100644 --- a/src/locales/default/discover.ts +++ b/src/locales/default/discover.ts @@ -1113,11 +1113,14 @@ export default { 'tab.user': 'User', + 'tab.workspace': 'Workspace', + 'user.agents': 'Agents', 'user.downloads': 'Downloads', 'user.editProfile': 'Edit Profile', + 'user.editWorkspaceProfile': 'Settings', 'user.favoriteAgents': 'Saved Agents', @@ -1138,18 +1141,123 @@ export default { 'user.logout': 'Logout', + 'user.openWorkspacePublicProfile': 'Open Public Link', + + 'user.setupWorkspaceProfile': 'Set up Community Profile', + 'user.myProfile': 'My Profile', + 'user.workspaceProfile.cancel': 'Cancel', + 'user.workspaceProfile.description': 'Update the public workspace profile shown in Community.', + 'user.workspaceProfile.errors.displayName': 'Enter a workspace name', + 'user.workspaceProfile.errors.fileTooLarge': 'Image must be smaller than 2 MB', + 'user.workspaceProfile.errors.namespace.length': 'Handle must be 3-32 characters', + 'user.workspaceProfile.errors.namespace.pattern': + 'Use lowercase letters, numbers, and hyphens. Start and end with a letter or number.', + 'user.workspaceProfile.errors.namespace.required': 'Enter a Community handle', + 'user.workspaceProfile.errors.uploadFailed': 'Failed to upload avatar', + 'user.workspaceProfile.errors.url': 'Enter a valid URL', + 'user.workspaceProfile.failed': 'Failed to update workspace profile', + 'user.workspaceProfile.fields.avatar': 'Avatar', + 'user.workspaceProfile.fields.bannerUrl': 'Banner Image', + 'user.workspaceProfile.fields.bannerUrl.clickToUpload': 'Click to upload banner image', + 'user.workspaceProfile.fields.bannerUrl.remove': 'Remove banner', + 'user.workspaceProfile.fields.bannerUrl.tooltip': + 'The banner is shown at the top of your community profile (16:9 ratio recommended).', + 'user.workspaceProfile.fields.bannerUrl.uploading': 'Uploading...', + 'user.workspaceProfile.fields.description': 'Description', + 'user.workspaceProfile.fields.description.maxLength': 'Description can be up to 200 characters', + 'user.workspaceProfile.fields.description.placeholder': 'Introduce your workspace…', + 'user.workspaceProfile.fields.displayName': 'Workspace name', + 'user.workspaceProfile.fields.displayName.maxLength': 'Workspace name can be up to 50 characters', + 'user.workspaceProfile.fields.displayName.placeholder': 'Enter the workspace name', + 'user.workspaceProfile.fields.namespace': 'Public Community URL', + 'user.workspaceProfile.fields.namespace.available': 'This URL is available', + 'user.workspaceProfile.fields.namespace.checking': 'Checking availability…', + 'user.workspaceProfile.fields.namespace.placeholder': 'workspace-handle', + 'user.workspaceProfile.fields.namespace.taken': 'This URL is already taken', + 'user.workspaceProfile.fields.websiteUrl': 'Website', + 'user.workspaceProfile.fields.websiteUrl.placeholder': 'Workspace website link', + 'user.workspaceProfile.optional.toggle': 'More settings', + 'user.workspaceProfile.save': 'Save', + 'user.workspaceProfile.setup.description': + 'Create a public Community profile for this workspace. This workspace profile and published resources will be visible.', + 'user.workspaceProfile.setup.empty.description': + 'Set up a public Community profile before publishing or showing workspace resources.', + 'user.workspaceProfile.setup.empty.title': 'Set up this workspace for Community', + 'user.workspaceProfile.setup.failed': 'Failed to set up Community profile', + 'user.workspaceProfile.setup.privacyAlert': + 'Members, chats, billing, credentials, and private resources stay private.', + 'user.workspaceProfile.setup.save': 'Set up profile', + 'user.workspaceProfile.setup.success': 'Community profile created', + 'user.workspaceProfile.setup.title': 'Set up Community Profile', + 'user.workspaceProfile.settings.avatar.description': + 'Use a workspace avatar for the Community profile.', + 'user.workspaceProfile.settings.avatar.hint': 'Upload an image or choose an emoji.', + 'user.workspaceProfile.settings.back': 'Back', + 'user.workspaceProfile.settings.banner.description': + 'Show a banner at the top of the Community profile.', + 'user.workspaceProfile.settings.description.description': + 'Describe what this workspace publishes or works on.', + 'user.workspaceProfile.settings.displayName.description': + 'This name appears on the Community profile.', + 'user.workspaceProfile.settings.members.column.member': 'Member', + 'user.workspaceProfile.settings.members.column.role': 'Role', + 'user.workspaceProfile.settings.members.description': + 'Members currently mirrored into this workspace’s Community organization.', + 'user.workspaceProfile.settings.members.empty': + 'No members synced yet. Sync to mirror the current workspace members.', + 'user.workspaceProfile.settings.members.role.admin': 'Admin', + 'user.workspaceProfile.settings.members.role.member': 'Member', + 'user.workspaceProfile.settings.members.sync': 'Sync members', + 'user.workspaceProfile.settings.members.syncFailed': 'Failed to sync members', + 'user.workspaceProfile.settings.members.syncHint': + 'Sync the current workspace members to the Community organization.', + 'user.workspaceProfile.settings.members.syncSuccess': 'Members synced', + 'user.workspaceProfile.settings.members.title': 'Community Members', + 'user.workspaceProfile.settings.namespace.description': + 'Set the public Community URL for this workspace.', + 'user.workspaceProfile.settings.namespace.hint': 'Use up to {{max}} characters.', + 'user.workspaceProfile.settings.namespaceTaken': + 'This Community URL is already taken. Choose another one.', + 'user.workspaceProfile.settings.noPermission': + 'Only workspace owners can manage the Community profile.', + 'user.workspaceProfile.settings.noProfile.description': + 'Set up the Community profile before editing its settings.', + 'user.workspaceProfile.settings.subtitle': + 'Manage the public Community profile for this workspace.', + 'user.workspaceProfile.settings.tabs.members': 'Members', + 'user.workspaceProfile.settings.tabs.profile': 'Profile', + 'user.workspaceProfile.settings.title': 'Community Profile Settings', + 'user.workspaceProfile.settings.updateFailed': 'Failed to update Community profile', + 'user.workspaceProfile.settings.updateSuccess': 'Community profile updated', + 'user.workspaceProfile.settings.website.description': + 'Add a website shown on the Community profile.', + 'user.workspaceProfile.settings.website.hint': 'Use a public http(s) URL.', + 'user.workspaceProfile.success': 'Workspace profile updated', + 'user.workspaceProfile.title': 'Edit Workspace', + + 'user.accountType.organization': 'Organization', + 'user.noAgents': 'This user hasn’t published any Agents yet', 'user.noAgents.ownerDescription': 'Create your first Agent and share it with the Community.', 'user.noAgents.title': 'No Agents yet', + 'user.org.noAgents': 'This organization hasn’t published any Agents yet', + + 'user.workspace.noAgents': 'This organization has not published any Agents to Community yet.', + 'user.noFavoriteAgents': 'No saved Agents yet', 'user.noFavoritePlugins': 'No saved Skills yet', + 'user.noGroups.title': 'No Agent Groups yet', + + 'user.workspace.noGroups': + 'This organization has not published any Agent Groups to Community yet.', + 'user.noForkedAgentGroups': 'No forked Agent Groups yet', 'user.noForkedAgents': 'No forked Agents yet', diff --git a/src/locales/default/error.ts b/src/locales/default/error.ts index 3f48ca0c7f..81ec737554 100644 --- a/src/locales/default/error.ts +++ b/src/locales/default/error.ts @@ -178,6 +178,14 @@ export default { 'The group host is unable to function. Please check your host configuration to ensure the correct model, API Key, and API endpoint are set.', 'testConnectionFailed': 'Test connection failed: {{error}}', 'tts.responseError': 'Service request failed, please check the configuration or try again', + 'transfer.noPermission': "You don't have permission to move this resource.", + 'transfer.ownerOnly': 'Only workspace owners can transfer resources created by other members.', + 'transfer.resourceNotFound': + 'This resource no longer exists or you no longer have access. Refresh and try again.', + 'transfer.sameWorkspace': + 'This resource is already in the selected workspace. Choose another target.', + 'transfer.targetNoWriteAccess': + 'You need Member or Owner access to move resources into the target workspace.', 'unlock.addProxyUrl': 'Add OpenAI proxy URL (optional)', 'unlock.apiKey.description': 'Enter your {{name}} API Key to start the session', 'unlock.apiKey.imageGenerationDescription': 'Enter your {{name}} API Key to start generating', @@ -218,6 +226,7 @@ export default { 'upload.storageBlock.upgradeRequired': 'Your file storage has reached the plan limit. Please upgrade your plan or delete unused files.', 'upload.storageBlock.viewPlans': 'View plans', + 'upload.permissionDenied': "You don't have permission to upload files in this workspace.", 'upload.storageLimitExceeded': 'Your file storage has reached the plan limit. Please upgrade your plan or delete unused files to free up space.', 'upload.title': 'File upload failed. Please check your network connection or try again later', diff --git a/src/locales/default/file.ts b/src/locales/default/file.ts index f121c0b337..22342cc57c 100644 --- a/src/locales/default/file.ts +++ b/src/locales/default/file.ts @@ -57,9 +57,26 @@ export default { 'home.uploadEntries.newPage.title': 'New Page', 'library.hierarchy.empty.desc': 'Add files or create a folder to get started', 'library.hierarchy.empty.title': 'Nothing here yet', + 'library.import.action': 'Import to workspace…', + 'library.import.failed': 'Failed to import knowledge base.', + 'library.import.success': 'Knowledge base imported to {{name}}.', + 'library.import.tooltip': + 'Fork this knowledge base into a workspace. Files are shared by reference; the original stays in your personal space.', 'library.list.confirmRemoveLibrary': 'You are about to delete this library. The files within it will not be deleted but moved to All Files. This action cannot be undone, so please proceed with caution.', + 'library.list.copyDescription': + 'Clone this library and all of its contents into another workspace.', + 'library.list.copyFailed': 'Failed to copy library', + 'library.list.copySuccess': 'Library copied', + 'library.list.copyTitle': 'Copy library', + 'library.list.copyTo': 'Copy to…', 'library.list.empty': 'Click <1>+ to create a new library', + 'library.list.transferDescription': + 'Move this library and all of its contents to another workspace.', + 'library.list.transferFailed': 'Failed to transfer library', + 'library.list.transferSuccess': 'Library transferred', + 'library.list.transferTitle': 'Transfer library', + 'library.list.transferTo': 'Transfer to…', 'library.new': 'New Library', 'library.title': 'Library', 'loadMore': 'Load More', @@ -139,7 +156,13 @@ export default { 'pageEditor.titlePlaceholder': 'Untitled', 'pageEditor.wordCount': '{{wordCount}} words', 'pageList.actions.openInNewTab': 'Open in New Tab', + 'pageList.copyConfirm': 'Copy', 'pageList.copyContent': 'Copy Full Text', + 'pageList.copyDescription': 'Create a copy of this page in another workspace.', + 'pageList.copyFailed': 'Failed to copy page', + 'pageList.copySuccess': 'Page copied', + 'pageList.copyTitle': 'Copy page', + 'pageList.copyTo': 'Copy to…', 'pageList.duplicate': 'Duplicate', 'pageList.empty': 'No pages yet. Click the button above to create your first one.', 'pageList.filter.all': 'All', @@ -148,7 +171,42 @@ export default { 'pageList.pageCount': '{{count}} pages in total', 'pageList.pageSizeItem': '{{count}} items', 'pageList.title': 'Pages', + 'pageList.transferConfirm': 'Transfer', + 'pageList.transferDescription': + 'Move this page (and any folders it contains) to another workspace.', + 'pageList.transferFailed': 'Failed to transfer page', + 'pageList.transferSuccess': 'Page transferred', + 'pageList.transferTitle': 'Transfer page', + 'pageList.transferTo': 'Transfer to…', + 'pageList.transferWarning': 'This is a one-way move; reverting requires another transfer.', 'pageList.untitled': 'Untitled', + 'resourceList.batchCopyDescription': 'Clone selected resources into another workspace.', + 'resourceList.batchCopyTitle': 'Copy resources', + 'resourceList.batchTransferDescription': 'Move selected resources to another workspace.', + 'resourceList.batchTransferTitle': 'Transfer resources', + 'resourceList.copyConfirm': 'Copy', + 'resourceList.copyDocumentDescription': 'Clone this document into another workspace.', + 'resourceList.copyDocumentTitle': 'Copy document', + 'resourceList.copyFailed': 'Failed to copy resource', + 'resourceList.copyFileDescription': 'Clone this file into another workspace.', + 'resourceList.copyFileTitle': 'Copy file', + 'resourceList.copyFolderDescription': + 'Clone this folder (and its contents) into another workspace.', + 'resourceList.copyFolderTitle': 'Copy folder', + 'resourceList.copySuccess': 'Resource copied', + 'resourceList.copyTo': 'Copy to…', + 'resourceList.transferConfirm': 'Transfer', + 'resourceList.transferDocumentDescription': 'Move this document to another workspace.', + 'resourceList.transferDocumentTitle': 'Transfer document', + 'resourceList.transferFailed': 'Failed to transfer resource', + 'resourceList.transferFileDescription': 'Move this file to another workspace.', + 'resourceList.transferFileTitle': 'Transfer file', + 'resourceList.transferFolderDescription': + 'Move this folder (and its contents) to another workspace.', + 'resourceList.transferFolderTitle': 'Transfer folder', + 'resourceList.transferSuccess': 'Resource transferred', + 'resourceList.transferTo': 'Transfer to…', + 'resourceList.viewTransferred': 'View', 'portal.openInPageEditor': 'Edit in Page', 'preview.downloadFile': 'Download File', 'preview.unsupportedFileAndContact': diff --git a/src/locales/default/messenger.ts b/src/locales/default/messenger.ts index c7dc664e50..e34d5789ce 100644 --- a/src/locales/default/messenger.ts +++ b/src/locales/default/messenger.ts @@ -1,6 +1,9 @@ export default { 'messenger.activeAgent': 'Active agent', 'messenger.activeAgentPlaceholder': 'Select an agent', + 'messenger.scope': 'Workspace', + 'messenger.scopePersonal': 'Personal', + 'messenger.scopePersonalTag': 'personal', 'messenger.detail.addServer': 'Add server', 'messenger.detail.addWorkspace': 'Add workspace', 'messenger.detail.connections.connected': 'Connected', diff --git a/src/locales/default/notification.ts b/src/locales/default/notification.ts index 303fdb1497..25682836d3 100644 --- a/src/locales/default/notification.ts +++ b/src/locales/default/notification.ts @@ -20,4 +20,34 @@ export default { 'storage_overage_cap_reached_title': 'Storage pay-as-you-go cap reached', 'video_generation_completed': 'Your video "{{prompt}}" is ready.', 'video_generation_completed_title': 'Video generation completed', + 'workspace_member_joined': '{{memberLabel}} joined workspace "{{workspaceName}}" as a {{role}}.', + 'workspace_member_joined_member': + '{{memberLabel}} joined workspace "{{workspaceName}}" as a Member.', + 'workspace_member_joined_member_title': 'New member joined {{workspaceName}}', + 'workspace_member_joined_owner': + '{{memberLabel}} joined workspace "{{workspaceName}}" as an Owner.', + 'workspace_member_joined_owner_title': 'New member joined {{workspaceName}}', + 'workspace_member_joined_title': 'New member joined {{workspaceName}}', + 'workspace_member_joined_viewer': + '{{memberLabel}} joined workspace "{{workspaceName}}" as a Viewer.', + 'workspace_member_joined_viewer_title': 'New member joined {{workspaceName}}', + 'workspace_member_removed': + 'You have been removed from workspace "{{workspaceName}}" by the workspace owner.', + 'workspace_member_removed_downgrade': + 'You have been removed from workspace "{{workspaceName}}" because the workspace was downgraded.', + 'workspace_member_removed_downgrade_title': 'Removed from workspace', + 'workspace_member_removed_title': 'Removed from workspace', + 'workspace_payment_failed': + 'Renewal payment for workspace "{{workspaceName}}" failed. Please update your payment method to keep the workspace active.', + 'workspace_payment_failed_title': 'Payment failed for {{workspaceName}}', + 'workspace_payment_method_removed': + 'A payment method was removed from workspace "{{workspaceName}}". Add a card before the next renewal, otherwise the subscription will fail to renew.', + 'workspace_payment_method_removed_title': 'Payment method removed from {{workspaceName}}', + 'workspace_primary_ownership_transferred': + 'You are now the primary owner of workspace "{{workspaceName}}". Billing and primary privileges have been transferred to you.', + 'workspace_primary_ownership_transferred_title': + 'You are now the primary owner of {{workspaceName}}', + 'workspace_subscription_expired': + 'The subscription for workspace "{{workspaceName}}" has ended. Renew within {{days}} days to restore full access before the workspace is downgraded.', + 'workspace_subscription_expired_title': 'Subscription ended for {{workspaceName}}', }; diff --git a/src/locales/default/setting.ts b/src/locales/default/setting.ts index 7649c49717..98e2c337fd 100644 --- a/src/locales/default/setting.ts +++ b/src/locales/default/setting.ts @@ -1,6 +1,18 @@ export default { '_cloud.officialProvider': '{{name}} Official Model Service', 'about.title': 'About', + 'agentImport.action': 'Import to workspace…', + 'agentImport.description': + 'Fork a copy of this agent into one of your workspaces. The original stays in your personal space — no sync after import.', + 'agentImport.failed': 'Failed to import agent.', + 'agentImport.modal.configIncluded': 'Agent configuration is copied by default.', + 'agentImport.modal.confirm': 'Import', + 'agentImport.modal.includeHistory': 'Copy topics and messages', + 'agentImport.modal.includeHistoryDesc': + 'Optional. Copies this agent’s conversation history into the new agent.', + 'agentImport.modal.knowledgeNotice': 'Knowledge bindings and files are not copied yet.', + 'agentImport.success': 'Agent imported to {{name}}.', + 'agentImport.title': 'Import to workspace', 'accountDeletion.cancelButton': 'Cancel Deletion', 'accountDeletion.cancelConfirmTitle': 'Cancel account deletion request?', 'accountDeletion.cancelFailed': 'Failed to cancel deletion request', @@ -512,7 +524,15 @@ export default { 'marketPublish.upload.button': 'Publish New Version', 'marketPublish.upload.tooltip': 'Publish a new version to Agent Community', 'marketPublish.uploadGroup.tooltip': 'Publish a new version to Group Community', - 'marketPublish.validation.confirmPublish': 'Are you sure you want to publish to the market?', + 'marketPublish.validation.communitySetupRequired.action': 'Set Up Now', + 'marketPublish.validation.communitySetupRequired.desc': + "This workspace hasn't set up its Community profile yet. Set it up before publishing to the Community.", + 'marketPublish.validation.communitySetupRequired.memberHint': + "This workspace hasn't set up its Community profile yet. Ask a workspace owner to set it up before publishing to the Community.", + 'marketPublish.validation.communitySetupRequired.title': 'Set Up Community Profile First', + 'marketPublish.validation.confirmPublish': 'Publish to the Market?', + 'marketPublish.validation.confirmPublishDesc': + 'Once published, this content will be publicly visible in the market and available for anyone to discover and use.', 'marketPublish.validation.emptyName': 'Cannot publish: Name is required', 'marketPublish.validation.emptySystemRole': 'Cannot publish: System Role is required', 'marketPublish.validation.underReview': @@ -541,10 +561,17 @@ export default { 'notification.category.billing.title': 'Billing', 'notification.category.generation.title': 'Generation', 'notification.category.schedule.title': 'Scheduled tasks', + 'notification.category.workspace.title': 'Workspace', 'notification.item.agent_cron_job_failed': 'Scheduled task failed', 'notification.item.image_generation_completed': 'Image generation completed', 'notification.item.storage_overage_cap_reached': 'Storage pay-as-you-go cap reached', 'notification.item.video_generation_completed': 'Video generation completed', + 'notification.item.workspace_member_joined': 'New member joined', + 'notification.item.workspace_member_removed': 'Removed from workspace', + 'notification.item.workspace_payment_failed': 'Renewal payment failed', + 'notification.item.workspace_payment_method_removed': 'Payment method removed', + 'notification.item.workspace_primary_ownership_transferred': 'Primary ownership transferred', + 'notification.item.workspace_subscription_expired': 'Subscription ended', 'notification.title': 'Notification Channels', 'myAgents.actions.cancel': 'Cancel', 'myAgents.actions.confirmDeprecate': 'Confirm Deprecate', @@ -963,6 +990,22 @@ When I am ___, I need ___ '[Skill Request] Summarize the skill you need in one sentence', 'skillStore.wantMore.reachedEnd': "You've reached the end. Can't find what you need?", 'startConversation': 'Start Conversation', + 'storage.actions.transfer.button': 'Transfer To', + 'storage.actions.transfer.desc': + 'Move agents and their data to a workspace you have access to. LobeAI, the default inbox Agent, cannot be transferred; use Copy Agents to copy it to a workspace or personal account instead.', + 'storage.actions.transfer.title': 'Agents Migration', + 'storage.actions.transferAgentGroups.button': 'Transfer To', + 'storage.actions.transferAgentGroups.desc': + 'Move agent groups, their members, and group conversation data to a workspace you have access to.', + 'storage.actions.transferAgentGroups.title': 'Agent Groups Migration', + 'storage.actions.copyLobeAI.button': 'Copy To', + 'storage.actions.copyLobeAI.desc': + 'Copy agents, including LobeAI, into another workspace or personal account. Topics and messages are optional.', + 'storage.actions.copyLobeAI.title': 'Agents Copy', + 'storage.actions.copyAgentGroups.button': 'Copy To', + 'storage.actions.copyAgentGroups.desc': + 'Copy agent groups and their member agents into another workspace or personal account.', + 'storage.actions.copyAgentGroups.title': 'Agent Groups Copy', 'storage.actions.export.button': 'Export', 'storage.actions.export.exportType.agent': 'Export Agent Settings', 'storage.actions.export.exportType.agentWithMessage': 'Export Agent and Messages', @@ -976,6 +1019,7 @@ When I am ___, I need ___ 'storage.actions.title': 'Advanced Operations', 'storage.desc': 'Current storage usage in the browser', 'storage.embeddings.used': 'Vector Storage', + 'storage.migration.title': 'Data Migration', 'storage.title': 'Data Storage', 'storage.used': 'Storage Usage', 'storageOverage.addPaymentMethod': 'Add payment method', @@ -1169,6 +1213,940 @@ When I am ___, I need ___ 'tab.uploadZip': 'Upload Zip', 'tab.uploadZip.desc': 'Upload a local .zip or .skill file', 'tab.usage': 'Usage', + 'workspace.create.descPlaceholder': 'Describe what this workspace is for (optional)', + 'workspace.create.namePlaceholder': 'e.g. Acme Team', + 'workspace.create.submit': 'Create workspace', + 'workspace.billing.credits.label': 'Credits this month', + 'workspace.billing.hobbyHint': 'Free workspace · shared monthly pool', + 'workspace.billing.platformLine': 'Pro Platform · monthly', + 'workspace.billing.plan.enterprise': 'Enterprise', + 'workspace.billing.plan.hobby': 'Hobby', + 'workspace.billing.plan.pro': 'Pro', + 'workspace.billing.seatLine': 'Additional seats × {{count}}', + 'workspace.billing.seats.cancel': 'Cancel', + 'workspace.billing.seats.confirmContent': + 'Seats will change from {{previousSeats}} to {{newSeats}}. Stripe will charge or refund a prorated amount for the remainder of this billing cycle.', + 'workspace.billing.seats.confirmTitle': 'Update seats?', + 'workspace.billing.seats.editCta': 'Manage seats', + 'workspace.billing.seats.editorLabel': 'Total seats (including owner)', + 'workspace.billing.seats.failedToast': 'Failed to update seats.', + 'workspace.billing.seats.save': 'Save', + 'workspace.billing.seats.successToast': 'Seats updated to {{seats}}.', + 'workspace.billing.paymentMethods.addCta': 'Add payment method', + 'workspace.billing.paymentMethods.defaultBadge': 'Default', + 'workspace.billing.paymentMethods.empty': + 'No payment methods yet. Add one via Stripe portal — this workspace will be billed once a card is set as default.', + 'workspace.billing.paymentMethods.expires': 'Expires {{date}}', + 'workspace.billing.paymentMethods.managePortalCta': 'Manage in Stripe portal', + 'workspace.billing.paymentMethods.portalFailed': 'Failed to open billing portal', + 'workspace.billing.paymentMethods.remove': 'Remove', + 'workspace.billing.paymentMethods.removeConfirmContent': + 'This card will no longer be used to pay for this workspace.', + 'workspace.billing.paymentMethods.removeConfirmTitle': 'Remove this card?', + 'workspace.billing.paymentMethods.removeDefaultWarning': + 'This is the default card. Removing it without setting another one will cause the next renewal to fail.', + 'workspace.billing.paymentMethods.removeFailed': 'Failed to remove payment method', + 'workspace.billing.paymentMethods.setDefault': 'Set as default', + 'workspace.billing.paymentMethods.setDefaultFailed': 'Failed to set default', + 'workspace.billing.paymentMethods.setDefaultSuccess': 'Default updated', + 'workspace.billing.paymentMethods.subtitle': 'Cards on file for this workspace.', + 'workspace.billing.paymentMethods.title': 'Payment methods', + 'workspace.billing.title': 'Bills', + 'workspace.billing.totalHint': 'Billed monthly · cancel anytime', + 'workspace.billing.totalLabel': 'Total / month', + 'workspace.billingPage.billing.activeHint': + 'Cancel anytime — you keep access until the end of the cycle.', + 'workspace.billingPage.billing.autoRenewOff': 'Auto-renew off', + 'workspace.billingPage.billing.autoRenewOffOnDate': 'Ends on {{date}}', + 'workspace.billingPage.billing.autoRenewOn': 'Auto-renew on', + 'workspace.billingPage.billing.autoRenewOnDate': 'Renews on {{date}}', + 'workspace.billingPage.billing.banner.cancelledDesc': + 'Cancellation scheduled. Your subscription stops renewing at the end of the current billing cycle — the workspace then falls back to Hobby.', + 'workspace.billingPage.billing.banner.cancelledTitle': 'Subscription pending cancellation', + 'workspace.billingPage.billing.banner.expiredDesc': + 'Your subscription has ended. Re-subscribe to restore Pro features, or downgrade to Solo.', + 'workspace.billingPage.billing.banner.expiredTitle': 'Subscription cancelled', + 'workspace.billingPage.billing.banner.inactiveDesc': + 'Subscription is inactive — credits will not refresh until you re-subscribe.', + 'workspace.billingPage.billing.banner.inactiveTitle': 'Subscription inactive', + 'workspace.billingPage.billing.banner.resumeCta': 'Resume', + 'workspace.billingPage.billing.banner.subscribeCta': 'Subscribe', + 'workspace.billingPage.billing.breakdown.creditsLine_one': + '{{seats}} seat · {{credits}} credits / month', + 'workspace.billingPage.billing.breakdown.creditsLine_other': + '{{seats}} seats · {{credits}} credits / month', + 'workspace.billingPage.billing.breakdown.extraSeats': 'Extra seats', + 'workspace.billingPage.billing.breakdown.platform': 'Platform fee', + 'workspace.billingPage.billing.breakdown.product': 'Product', + 'workspace.billingPage.billing.breakdown.quantity': 'Quantity', + 'workspace.billingPage.billing.breakdown.seatCount_one': '{{count}} Seat', + 'workspace.billingPage.billing.breakdown.seatCount_other': '{{count}} Seats', + 'workspace.billingPage.billing.breakdown.totalCost': 'Total Cost', + 'workspace.billingPage.billing.cancelConfirm': + 'Subscription will keep running until the end of the current billing cycle, then stop renewing. You can resume at any time before the cycle ends.', + 'workspace.billingPage.billing.cancelCta': 'Cancel subscription', + 'workspace.billingPage.billing.cancelSuccess': 'Cancellation scheduled.', + 'workspace.billingPage.billing.cancelTitle': 'Cancel subscription?', + 'workspace.billingPage.billing.downgrade.confirmBody': + 'This immediately downgrades the workspace to Solo. The current billing period is non-refundable, and every member except the primary owner will be removed from this workspace.', + 'workspace.billingPage.billing.downgrade.confirmCta': 'Downgrade now', + 'workspace.billingPage.billing.downgrade.confirmInputLabel': + 'Type the workspace name "{{name}}" to confirm:', + 'workspace.billingPage.billing.downgrade.confirmInputPlaceholder': 'Workspace name', + 'workspace.billingPage.billing.downgrade.confirmTitle': 'Downgrade to Solo?', + 'workspace.billingPage.billing.downgrade.failedToast': 'Failed to downgrade.', + 'workspace.billingPage.billing.downgrade.successToast': 'Workspace downgraded to Solo.', + 'workspace.billingPage.billing.hobby.subtitle': + 'Hobby workspace · shared monthly pool · no team seats', + 'workspace.billingPage.billing.hobby.title': 'Free workspace', + 'workspace.billingPage.billing.hobby.upgradeCta': 'Upgrade to Pro', + 'workspace.billingPage.billing.invoice.empty': + 'No invoices yet. Your first invoice will appear after the next renewal.', + 'workspace.billingPage.billing.invoice.emptyHint': 'Workspace created on {{date}}.', + 'workspace.billingPage.billing.invoice.nonOwner': + 'Only workspace owners can view billing history.', + 'workspace.billingPage.billing.invoice.tab.all': 'All', + 'workspace.billingPage.billing.invoice.tab.failed': 'Closed', + 'workspace.billingPage.billing.invoice.tab.open': 'Unpaid', + 'workspace.billingPage.billing.invoice.tab.paid': 'Paid', + 'workspace.billingPage.billing.invoice.subtitle': 'View and download invoices for this workspace', + 'workspace.billingPage.billing.invoice.title': 'Billing history', + 'workspace.billingPage.billing.manage.cancelItem': 'Cancel Subscription', + 'workspace.billingPage.billing.manage.cta': 'Manage', + 'workspace.billingPage.billing.manage.downgradeItem': 'Downgrade', + 'workspace.billingPage.billing.manage.resumeItem': 'Resume subscription', + 'workspace.billingPage.billing.monthlyFeeLabel': '/ month', + 'workspace.billingPage.billing.planBadge.active': 'Active', + 'workspace.billingPage.billing.planBadge.cancelled': 'Cancelled', + 'workspace.billingPage.billing.planBadge.cancelling': 'Cancelling', + 'workspace.billingPage.billing.planBadge.inactive': 'Inactive', + 'workspace.billingPage.billing.resumeCta': 'Resume subscription', + 'workspace.billingPage.billing.resumeSuccess': + 'Cancellation reversed. Your subscription will renew normally.', + 'workspace.billingPage.billing.scheduledHint': + 'Subscription will end at the close of the current billing cycle.', + 'workspace.billingPage.billing.seats.deltaDown': 'Δ -${{amount}} / mo', + 'workspace.billingPage.billing.seats.deltaUp': 'Δ +${{amount}} / mo', + 'workspace.billingPage.billing.seats.editCta': 'Edit', + 'workspace.billingPage.billing.seats.from': '{{previous}} seats → {{next}} seats', + 'workspace.billingPage.billing.seats.previewLabel': 'Price preview (approximate)', + 'workspace.billingPage.billing.seats.priceDelta': '${{previous}} / mo → ${{next}} / mo', + 'workspace.billingPage.billing.seats.proration': + 'Seat changes are settled on the next monthly invoice — no immediate charge or refund.', + 'workspace.billingPage.billing.seats.subtitle_one': + 'Currently {{count}} seat · ${{seatFee}} per extra seat / month', + 'workspace.billingPage.billing.seats.subtitle_other': + 'Currently {{count}} seats · ${{seatFee}} per extra seat / month', + 'workspace.billingPage.billing.seats.title': 'Seats', + 'workspace.billingPage.billing.subscriptionTitle': 'Subscription controls', + 'workspace.billingPage.billing.currentPlan.descHobby': + 'Solo workspace · pay only for what you use', + 'workspace.billingPage.billing.currentPlan.descPro': + 'Team workspace with monthly credits allowance', + 'workspace.billingPage.billing.currentPlan.title': 'Current plan', + 'workspace.billingPage.billing.summarySubtitle': 'Workspace subscription and billing breakdown', + 'workspace.billingPage.billing.summaryTitle': 'Subscription', + 'workspace.billingPage.billing.totalLabel': 'Total', + 'workspace.billingPage.billing.upgradeFailedToast': 'Failed to start checkout.', + 'workspace.billingPage.credits.breakdownCount': 'Ops', + 'workspace.billingPage.credits.breakdownSpend': 'Spend', + 'workspace.billingPage.credits.breakdownTitle': 'Spend by category', + 'workspace.billingPage.credits.breakdownType': 'Category', + 'workspace.billingPage.credits.empty': 'No credit data yet', + 'workspace.billingPage.credits.hero.cycleHint': '{{from}} → {{to}}', + 'workspace.billingPage.credits.hero.percentOfTotal': '{{percent}}% of {{total}}', + 'workspace.billingPage.credits.hero.planHobby': 'Hobby · solo workspace', + 'workspace.billingPage.credits.hero.planPro_one': 'Pro · {{count}} seat', + 'workspace.billingPage.credits.hero.planPro_other': 'Pro · {{count}} seats', + 'workspace.billingPage.credits.hero.poolDesc': + 'Shared across all seats. Resets each billing cycle.', + 'workspace.billingPage.credits.hero.remainingLine': 'Remaining {{amount}} credits', + 'workspace.billingPage.credits.hero.resetsIn_one': 'Resets in {{count}} day', + 'workspace.billingPage.credits.hero.resetsIn_other': 'Resets in {{count}} days', + 'workspace.billingPage.credits.hero.resetsToday': 'Resets today', + 'workspace.billingPage.credits.hero.seePlans': 'See Plans', + 'workspace.billingPage.credits.hero.title': 'Credit pool', + 'workspace.billingPage.credits.hero.usedLabel': 'Used', + 'workspace.billingPage.credits.hero.viewUsage': 'View detailed usage', + 'workspace.billingPage.credits.monthly': 'Monthly allowance', + 'workspace.billingPage.credits.packageExpiry': 'Expires', + 'workspace.billingPage.credits.packageId': 'Package', + 'workspace.billingPage.credits.packageLimit': 'Allowance', + 'workspace.billingPage.credits.packageSpend': 'Used', + 'workspace.billingPage.credits.packages.empty.cta': 'See Plans', + 'workspace.billingPage.credits.packages.empty.title': + 'No add-on packages yet. Upgrade via Plans or contact sales for extra capacity.', + 'workspace.billingPage.credits.packages.expired': 'Expired', + 'workspace.billingPage.credits.packages.expiringIn_one': 'in {{count}} day', + 'workspace.billingPage.credits.packages.expiringIn_other': 'in {{count}} days', + 'workspace.billingPage.credits.packages.fallback': 'Package #{{index}}', + 'workspace.billingPage.credits.packages.remaining': 'Remaining', + 'workspace.billingPage.credits.packages.source': 'Source', + 'workspace.billingPage.credits.packages.sourceLabel.autoTopUp': 'Auto top-up', + 'workspace.billingPage.credits.packages.sourceLabel.systemGift': 'System gift', + 'workspace.billingPage.credits.packages.sourceLabel.userPurchase': 'Sales add-on', + 'workspace.billingPage.credits.packages.subtitle': 'All credit packages owned by this workspace', + 'workspace.billingPage.credits.packages.title': 'Workspace credit packages', + 'workspace.billingPage.credits.packages.usedPercent': 'Used', + 'workspace.billingPage.credits.poolDesc': 'Shared across all seats. Resets each billing cycle.', + 'workspace.billingPage.credits.poolTitle': 'Workspace credit pool', + 'workspace.billingPage.credits.resetAt': 'Next reset: {{date}}', + 'workspace.billingPage.credits.status.cancelledCta': 'Resume subscription', + 'workspace.billingPage.credits.status.cancelledDesc': + 'Subscription is scheduled to end on {{date}}. Credits will stop refreshing after that.', + 'workspace.billingPage.credits.balance.creditBalance': 'Top-up credits balance', + 'workspace.billingPage.credits.balance.hobbyDesc': + 'Hobby workspaces do not include subscription credits — top up below or upgrade to Pro.', + 'workspace.billingPage.credits.balance.link.history': 'Top-up history', + 'workspace.billingPage.credits.balance.link.usage': 'View usage', + 'workspace.billingPage.credits.balance.plansUsage': 'Subscription credits', + 'workspace.billingPage.credits.balance.plansUsageDesc': + 'Subscription credits are used first, then top-up credits', + 'workspace.billingPage.credits.balance.sharedHint': 'Shared by all workspace members', + 'workspace.billingPage.credits.balance.sharedTag': 'Workspace-shared', + 'workspace.billingPage.credits.balance.title': 'Balance', + 'workspace.billingPage.credits.title': 'Credits', + 'workspace.billingPage.credits.topUp.custom': 'Custom', + 'workspace.billingPage.credits.topUp.maxAmountError': + 'Purchase amount cannot exceed ${{max}} per transaction.', + 'workspace.billingPage.credits.topUp.purchaseNow': 'Purchase now', + 'workspace.billingPage.credits.topUp.purchaseSuccess': 'Purchase successful.', + 'workspace.billingPage.credits.topUp.selectPackage': 'Select a credit pack', + 'workspace.billingPage.credits.topUp.subtitle': + 'Add credits to this workspace with a one-time purchase', + 'workspace.billingPage.credits.topUp.title': 'Purchase credits', + 'workspace.billingPage.credits.topUp.total': 'Total', + 'workspace.billingPage.credits.topUp.unitPriceFormat': '${{price}} per million compute credits', + 'workspace.billingPage.credits.topUp.upgradePlanName': 'Pro', + 'workspace.billingPage.credits.topUp.upgradePrefix': 'Upgrade to', + 'workspace.billingPage.credits.topUp.upgradeSuffix': 'to save ${{savings}}', + 'workspace.billingPage.credits.topUp.validityInfo': 'valid for {{months}} months', + 'workspace.billingPage.credits.autoTopUp.enable': 'Enable auto top-up', + 'workspace.billingPage.credits.autoTopUp.monthlyCap': 'Monthly cap', + 'workspace.billingPage.credits.autoTopUp.noCustomerHint': + 'Purchase credits once and save a payment method to enable auto top-up.', + 'workspace.billingPage.credits.autoTopUp.noPaymentMethodHint': + 'No saved payment method. Set one up to enable auto top-up.', + 'workspace.billingPage.credits.autoTopUp.purchaseCredits': 'Purchase credits', + 'workspace.billingPage.credits.autoTopUp.setupPaymentMethod': 'Set up payment method', + 'workspace.billingPage.credits.autoTopUp.monthlyLimitReached': + "This month's auto top-up has reached the cap; will resume next month.", + 'workspace.billingPage.credits.autoTopUp.pausedReason.manual': 'Auto top-up was paused manually.', + 'workspace.billingPage.credits.autoTopUp.pausedReason.monthly_cap': + 'Monthly cap reached. Auto top-up will resume on the next billing cycle, or you can raise the cap and re-enable.', + 'workspace.billingPage.credits.autoTopUp.pausedReason.payment_failed': + 'A recent charge failed. Update the payment method and re-enable.', + 'workspace.billingPage.credits.autoTopUp.pausedTitle': 'Auto top-up paused', + 'workspace.billingPage.credits.autoTopUp.save': 'Save', + 'workspace.billingPage.credits.autoTopUp.saveSuccess': 'Auto top-up settings updated.', + 'workspace.billingPage.credits.autoTopUp.subtitle': + 'Keep credits topped up automatically so the team is never blocked.', + 'workspace.billingPage.credits.autoTopUp.target': 'Target balance', + 'workspace.billingPage.credits.autoTopUp.threshold': 'Trigger threshold', + 'workspace.billingPage.credits.autoTopUp.title': 'Auto top-up', + 'workspace.billingPage.credits.autoTopUp.validation.targetMustExceedThreshold': + 'Target balance must be greater than the trigger threshold.', + 'workspace.billingPage.plans.cancelled': 'Pending cancellation', + 'workspace.billingPage.plans.currentTag': 'Current plan', + 'workspace.billingPage.plans.currentTitle': 'Current plan', + 'workspace.billingPage.plans.enterprise.contactCta': 'Contact', + 'workspace.billingPage.plans.enterprise.features.brandTheming': 'Brand theming', + 'workspace.billingPage.plans.enterprise.features.commercialLicense': 'Commercial license', + 'workspace.billingPage.plans.enterprise.features.customIntegration': + 'Custom integration & support', + 'workspace.billingPage.plans.enterprise.features.privateModels': 'Private models', + 'workspace.billingPage.plans.enterprise.features.selfHostedProvider': 'Self-hosted Provider', + 'workspace.billingPage.plans.enterprise.features.userManagement': 'User management', + 'workspace.billingPage.plans.enterprise.priceCaption': 'Tailored to your needs', + 'workspace.billingPage.plans.enterprise.priceText': 'Custom', + 'workspace.billingPage.plans.enterprise.tagline': + 'For teams that need private deployment or custom solutions', + 'workspace.billingPage.plans.enterprise.title': 'Enterprise Edition', + 'workspace.billingPage.plans.creditsHint': + 'Shared monthly pool · every seat draws from the same balance', + 'workspace.billingPage.plans.creditsTitle': 'Workspace credits', + 'workspace.billingPage.plans.creditsTooltip': + 'Workspace-wide monthly credits. Adding seats does NOT grow the pool — overage flows through AutoTopUp.', + 'workspace.billingPage.plans.headline': 'Choose a plan', + 'workspace.billingPage.plans.hobbyCapacity': '1 seat · solo workspace', + 'workspace.billingPage.plans.hobbyCta': 'Free forever', + 'workspace.billingPage.plans.hobbyCreditsHint': 'No monthly credits included by default', + 'workspace.billingPage.plans.hobbyCreditsTooltip': + 'Hobby workspaces do not include monthly credits. Configure your own model API or top up credits as needed.', + 'workspace.billingPage.plans.manageSeatsLink': 'Manage seats', + 'workspace.billingPage.plans.modelsHint': 'Estimated messages from the shared pool', + 'workspace.billingPage.plans.modelsTitle': 'Featured models', + 'workspace.billingPage.plans.perMonth': '/ month', + 'workspace.billingPage.plans.popularTag': 'Popular', + 'workspace.billingPage.plans.priceProCaption': 'Platform fee · billed monthly', + 'workspace.billingPage.plans.priceProHeadline': '${{fee}} / mo', + 'workspace.billingPage.plans.proCapacity': 'Up to {{max}} seats · ${{seatFee}}/seat / month', + 'workspace.billingPage.plans.pricingBannerCta': 'View pricing', + 'workspace.billingPage.plans.pricingBannerDesc': + 'See detailed input/output rates and message estimates for every supported model.', + 'workspace.billingPage.plans.pricingBannerTitle': 'Looking for per-model pricing?', + 'workspace.billingPage.plans.pricingNote': 'For per-model pricing, see {{url}}', + 'workspace.billingPage.plans.upgradeCta': 'Upgrade to Pro', + 'workspace.billingPage.plans.upgradeFailed': 'Failed to start checkout', + 'workspace.billingPage.summary.cancelling': 'Cancelling', + 'workspace.billingPage.summary.upgradeCta': 'See plans', + 'workspace.billingPage.summary.viewFullCta': 'View full billing', + 'workspace.billingPage.usage.activity.filterByMember': 'Filter member', + 'workspace.billingPage.usage.activity.filterByModel': 'Filter model', + 'workspace.billingPage.usage.activity.filterByType': 'Filter type', + 'workspace.billingPage.usage.activity.model': 'Model', + 'workspace.billingPage.usage.activity.viewAll': 'View all', + 'workspace.billingPage.usage.activity.viewAllTitle': 'All recent activity', + 'workspace.billingPage.usage.at': 'When', + 'workspace.billingPage.usage.byMemberDesc': 'Spend distribution across workspace members', + 'workspace.billingPage.usage.byMemberTitle': 'Spend by member', + 'workspace.billingPage.usage.byModelDesc': 'Spend distribution across models', + 'workspace.billingPage.usage.byTypeDesc': 'Spend distribution across credit categories', + 'workspace.billingPage.usage.creditUsage.desc': + 'Credits usage for AI chat, image generation, speech synthesis', + 'workspace.billingPage.usage.creditUsage.resetDesc': 'Quota resets in {{time}}', + 'workspace.billingPage.usage.creditUsage.title': 'Computing Credits Usage', + 'workspace.billingPage.usage.byModelTitle': 'Spend by model', + 'workspace.billingPage.usage.byTypeTitle': 'Spend by category', + 'workspace.billingPage.usage.categories.chat': 'Chat', + 'workspace.billingPage.usage.categories.embedding': 'Embedding', + 'workspace.billingPage.usage.categories.imageGeneration': 'Image generation', + 'workspace.billingPage.usage.categories.tts': 'Text to speech', + 'workspace.billingPage.usage.categories.videoGeneration': 'Video generation', + 'workspace.billingPage.usage.cycleHint': '{{from}} → {{to}}', + 'workspace.billingPage.usage.empty': 'No spend yet', + 'workspace.billingPage.usage.hero.percentOfBudget': '{{percent}}% of {{total}}', + 'workspace.billingPage.usage.hero.resetsIn_one': 'Resets in {{count}} day', + 'workspace.billingPage.usage.hero.resetsIn_other': 'Resets in {{count}} days', + 'workspace.billingPage.usage.hero.resetsToday': 'Resets today', + 'workspace.billingPage.usage.hero.usedLabel': 'Used', + 'workspace.billingPage.usage.last30': 'Last 30 days', + 'workspace.billingPage.usage.logsTitle': 'Recent activity', + 'workspace.billingPage.usage.member': 'Member', + 'workspace.billingPage.usage.messages': 'Messages', + 'workspace.billingPage.usage.model.moreModels_one': '{{count}} more model', + 'workspace.billingPage.usage.model.moreModels_other': '{{count}} more models', + 'workspace.billingPage.usage.model.showLess': 'Show less', + 'workspace.billingPage.usage.model.unknown': 'Unknown model', + 'workspace.billingPage.usage.ops': 'Operations', + 'workspace.billingPage.usage.range.30d': 'Last 30 days', + 'workspace.billingPage.usage.range.all': 'All time', + 'workspace.billingPage.usage.range.cycle': 'This cycle', + 'workspace.billingPage.usage.rank': 'Rank', + 'workspace.billingPage.usage.remaining': 'Remaining: {{amount}}', + 'workspace.billingPage.usage.selfTitle': 'Your usage', + 'workspace.billingPage.usage.spend': 'Spend', + 'workspace.billingPage.usage.summaryCardTitle': 'Credits usage statistics', + 'workspace.billingPage.usage.summaryTitle': 'Workspace credits usage', + 'workspace.billingPage.usage.topSpender': 'Top spender: {{name}} ({{amount}})', + 'workspace.billingPage.usage.trendTitle': 'Daily spend trend', + 'workspace.billingPage.usage.trendTooltip': '{{date}}: {{value}}', + 'workspace.billingPage.usage.type': 'Type', + 'workspace.create.title': 'Create a new workspace', + 'workspace.description.title': 'Description', + 'workspace.general.avatar.description': "This is your workspace's avatar.", + 'workspace.general.avatar.hint': 'An avatar is optional but strongly recommended.', + 'workspace.general.avatar.title': 'Workspace Avatar', + 'workspace.general.avatar.tooLarge': 'Avatar file must be smaller than 5MB.', + 'workspace.general.avatar.uploadFailed': 'Failed to upload avatar', + 'workspace.general.delete.confirm.content': + 'This action cannot be undone. Type the workspace name "{{name}}" to confirm.', + 'workspace.general.delete.confirm.continue': 'Continue', + 'workspace.general.delete.confirm.mismatch': "The name doesn't match. Deletion aborted.", + 'workspace.general.delete.confirm.namePrompt': 'To confirm, type "{{name}}"', + 'workspace.general.delete.confirm.ok': 'Delete workspace', + 'workspace.general.delete.confirm.phrase': 'delete my workspace', + 'workspace.general.delete.confirm.phrasePrompt': 'To confirm, type "{{phrase}}"', + 'workspace.general.delete.confirm.preparation': + 'Before deleting, cancel any active subscription. Billing and spend history will be retained for audit.', + 'workspace.general.delete.confirm.title': 'Delete Workspace', + 'workspace.general.delete.confirm.warning.items.agents': + 'All agents, skills, and their configurations', + 'workspace.general.delete.confirm.warning.items.billing': + 'Subscription, budget settings, and auto top-up', + 'workspace.general.delete.confirm.warning.items.conversations': + 'All sessions, messages, topics, and tasks', + 'workspace.general.delete.confirm.warning.items.files': + 'Uploaded files, generations, and knowledge base data', + 'workspace.general.delete.confirm.warning.items.members': + 'Members, pending invitations, and audit logs', + 'workspace.general.delete.confirm.warning.lead': + 'The {{name}} workspace will be permanently deleted, along with:', + 'workspace.general.delete.confirm.warning.tail': + 'This cannot be undone. Spend and top-up history will be retained for audit only.', + 'workspace.general.delete.cta': 'Delete Workspace', + 'workspace.general.delete.description': + 'Permanently delete this workspace and everything inside it — agents, sessions, messages, files, members, and invitations. This action cannot be reversed.', + 'workspace.general.delete.failed': 'Failed to delete workspace', + 'workspace.general.delete.hint': + 'Cancel any active subscription before deletion. Billing history is kept for audit.', + 'workspace.general.delete.notOwner': 'Only the workspace owner can delete this workspace.', + 'workspace.general.delete.title': 'Delete Workspace', + 'workspace.general.devReset.confirm.cancel': 'Cancel', + 'workspace.general.devReset.confirm.description': + 'This clears finishedAt / skippedAt / step / scenarios and reopens the wizard.', + 'workspace.general.devReset.confirm.ok': 'Reset', + 'workspace.general.devReset.confirm.title': 'Reset workspace onboarding?', + 'workspace.general.devReset.cta': 'Reset onboarding', + 'workspace.general.devReset.description': + 'Clears the onboarding gate (finishedAt / skippedAt / step / scenarios) and reopens the wizard. Dev-only — not visible in production.', + 'workspace.general.devReset.failed': 'Failed to reset onboarding', + 'workspace.general.devReset.hint': 'Dev only', + 'workspace.general.devReset.success': 'Workspace onboarding reset', + 'workspace.general.devReset.title': 'Reset workspace onboarding', + 'workspace.general.id.copied': 'Workspace ID copied', + 'workspace.general.id.description': "This is your workspace's unique ID.", + 'workspace.general.id.hint': 'Used when interacting with the API.', + 'workspace.general.id.title': 'Workspace ID', + 'workspace.general.leave.confirm.content': + 'You will lose access to "{{name}}" immediately. You can rejoin only if you are invited again.', + 'workspace.general.transferAgents.modal.back': 'Back', + 'workspace.general.transferAgents.modal.continue': 'Continue', + 'workspace.general.transferAgents.modal.failed': 'Failed to transfer agents', + 'workspace.general.transferAgents.modal.loadFailed': 'Failed to load agents', + 'workspace.general.transferAgents.modal.noAgents': 'No agents in this workspace', + 'workspace.general.transferAgents.modal.selectAgents': 'Select agents to transfer to {{target}}.', + 'workspace.general.transferAgents.modal.selectPlaceholder': + 'Select workspace or personal account...', + 'workspace.general.transferAgents.modal.selectTarget': + 'Choose a workspace or personal account to transfer agents to.', + 'workspace.general.transferAgents.modal.selected': 'selected', + 'workspace.general.transferAgents.modal.selectedAgent': 'Agent to transfer to {{target}}.', + 'workspace.general.transferAgents.modal.success': '{{count}} agent(s) transferred successfully', + 'workspace.general.transferAgents.modal.title': 'Transfer Agents', + 'workspace.general.transferAgents.modal.transfer': 'Transfer {{count}} agent(s)', + 'workspace.general.transferAgents.modal.warning': + 'Custom plugins may not be available and multi-agent group associations will be removed.', + 'workspace.general.transferAgents.personalAccount': 'Personal Account', + 'workspace.general.transferAgentGroups.modal.back': 'Back', + 'workspace.general.transferAgentGroups.modal.continue': 'Continue', + 'workspace.general.transferAgentGroups.modal.failed': 'Failed to transfer agent groups', + 'workspace.general.transferAgentGroups.modal.loadFailed': 'Failed to load agent groups', + 'workspace.general.transferAgentGroups.modal.noGroups': 'No agent groups in this workspace', + 'workspace.general.transferAgentGroups.modal.selectGroups': 'Select agent groups to transfer.', + 'workspace.general.transferAgentGroups.modal.selectPlaceholder': + 'Select workspace or personal account...', + 'workspace.general.transferAgentGroups.modal.selectTarget': + 'Choose a workspace or personal account to transfer agent groups to.', + 'workspace.general.transferAgentGroups.modal.selected': 'selected', + 'workspace.general.transferAgentGroups.modal.selectedGroup': 'Agent group to transfer.', + 'workspace.general.transferAgentGroups.modal.success': + '{{count}} agent group(s) transferred successfully', + 'workspace.general.transferAgentGroups.modal.title': 'Transfer Agent Groups', + 'workspace.general.transferAgentGroups.modal.transfer': 'Transfer {{count}} agent group(s)', + 'workspace.general.transferAgentGroups.modal.untitledGroup': 'Untitled Agent Group', + 'workspace.general.copyLobeAI.modal.back': 'Back', + 'workspace.general.copyLobeAI.modal.continue': 'Continue', + 'workspace.general.copyLobeAI.modal.copyOptions.config.desc': + 'Required. Copies the model, prompt, tools, and Agent profile.', + 'workspace.general.copyLobeAI.modal.copyOptions.config.title': 'Agent configuration', + 'workspace.general.copyLobeAI.modal.copyOptions.history.desc': + 'Optional. Copies selected agents’ topics and messages into the new agents.', + 'workspace.general.copyLobeAI.modal.copyOptions.history.title': 'Topics and messages', + 'workspace.general.copyLobeAI.modal.copyOptions.knowledgeBase.reason': + 'Not supported yet. Reconnect them in the target workspace or personal account after copying.', + 'workspace.general.copyLobeAI.modal.copyOptions.knowledgeBase.title': 'Knowledge bases and files', + 'workspace.general.copyLobeAI.modal.copyOptions.optional': 'Optional', + 'workspace.general.copyLobeAI.modal.copyOptions.required': 'Selected by default', + 'workspace.general.copyLobeAI.modal.copyOptions.title': 'Copy options', + 'workspace.general.copyLobeAI.modal.copyOptions.unsupported': 'Unavailable', + 'workspace.general.copyLobeAI.modal.create': 'Copy {{count}} agent(s)', + 'workspace.general.copyLobeAI.modal.defaultInboxTitle': 'LobeAI', + 'workspace.general.copyLobeAI.modal.failed': 'Failed to copy agents', + 'workspace.general.copyLobeAI.modal.includeHistory': 'Copy topics and messages', + 'workspace.general.copyLobeAI.modal.includeHistoryDesc': + 'Optional. Copies selected agents’ conversation history into the new agents.', + 'workspace.general.copyLobeAI.modal.loadFailed': 'Failed to load agents', + 'workspace.general.copyLobeAI.modal.noAgents': 'No agents available to copy', + 'workspace.general.copyLobeAI.modal.selected': 'selected', + 'workspace.general.copyLobeAI.modal.selectedAgent': 'Agent to copy.', + 'workspace.general.copyLobeAI.modal.selectAgents': 'Select agents to copy.', + 'workspace.general.copyLobeAI.modal.selectPlaceholder': 'Select workspace or personal account...', + 'workspace.general.copyLobeAI.modal.selectTarget': + 'Choose the target workspace or personal account. Agent configuration is copied by default.', + 'workspace.general.copyLobeAI.modal.success': '{{count}} agent(s) copied', + 'workspace.general.copyLobeAI.modal.title': 'Copy Agents', + 'workspace.general.copyLobeAI.modal.untitledAgent': 'Untitled Agent', + 'workspace.general.copyAgentGroups.modal.back': 'Back', + 'workspace.general.copyAgentGroups.modal.continue': 'Continue', + 'workspace.general.copyAgentGroups.modal.copyOptions.config.desc': + 'Required. Copies group metadata, members, member roles, and Agent profiles.', + 'workspace.general.copyAgentGroups.modal.copyOptions.config.title': 'Agent group configuration', + 'workspace.general.copyAgentGroups.modal.copyOptions.history.desc': + 'Optional. Copies selected groups’ topics and messages into the new groups.', + 'workspace.general.copyAgentGroups.modal.copyOptions.history.title': 'Topics and messages', + 'workspace.general.copyAgentGroups.modal.copyOptions.knowledgeBase.reason': + 'Not supported yet. Reconnect them in the target workspace or personal account after copying.', + 'workspace.general.copyAgentGroups.modal.copyOptions.knowledgeBase.title': + 'Knowledge bases and files', + 'workspace.general.copyAgentGroups.modal.copyOptions.optional': 'Optional', + 'workspace.general.copyAgentGroups.modal.copyOptions.required': 'Selected by default', + 'workspace.general.copyAgentGroups.modal.copyOptions.title': 'Copy options', + 'workspace.general.copyAgentGroups.modal.copyOptions.unsupported': 'Unavailable', + 'workspace.general.copyAgentGroups.modal.create': 'Copy {{count}} agent group(s)', + 'workspace.general.copyAgentGroups.modal.failed': 'Failed to copy agent groups', + 'workspace.general.copyAgentGroups.modal.loadFailed': 'Failed to load agent groups', + 'workspace.general.copyAgentGroups.modal.noGroups': 'No agent groups available to copy', + 'workspace.general.copyAgentGroups.modal.selectGroups': 'Select agent groups to copy.', + 'workspace.general.copyAgentGroups.modal.selectPlaceholder': + 'Select workspace or personal account...', + 'workspace.general.copyAgentGroups.modal.selectTarget': + 'Choose the target workspace or personal account. Group configuration and members are copied.', + 'workspace.general.copyAgentGroups.modal.selected': 'selected', + 'workspace.general.copyAgentGroups.modal.selectedGroup': 'Agent group to copy.', + 'workspace.general.copyAgentGroups.modal.success': '{{count}} agent group(s) copied', + 'workspace.general.copyAgentGroups.modal.title': 'Copy Agent Groups', + 'workspace.general.copyAgentGroups.modal.untitledGroup': 'Untitled Agent Group', + 'workspace.general.transferPrimary.cta': 'Transfer Primary Owner', + 'workspace.general.transferPrimary.description': + 'Transfer primary ownership to another owner. The new primary owner will take over billing and primary privileges for this workspace.', + 'workspace.general.transferPrimary.hint': 'You will remain an owner but lose primary privileges.', + 'workspace.general.transferPrimary.title': 'Transfer Primary Ownership', + 'workspace.general.leave.confirm.ok': 'Leave workspace', + 'workspace.general.leave.confirm.title': 'Leave this workspace?', + 'workspace.general.leave.cta': 'Leave Workspace', + 'workspace.general.leave.description': + "Revoke your access to this workspace. Any resources you've added will remain.", + 'workspace.general.leave.failed': 'Failed to leave workspace', + 'workspace.general.leave.hint': 'To rejoin later, another member must invite you again.', + 'workspace.general.leave.ownerHint': + 'Transfer ownership to another member before leaving the workspace.', + 'workspace.general.leave.title': 'Leave Workspace', + 'workspace.general.name.description': + "This is your workspace's visible name. For example, the name of your company or department.", + 'workspace.general.name.hint': 'Please use {{max}} characters at maximum.', + 'workspace.general.name.title': 'Workspace Name', + 'workspace.general.noPermissionHint': 'You need additional permissions to manage this setting.', + 'workspace.general.role.label': 'Your role', + 'workspace.general.save': 'Save', + 'workspace.general.scenarios.description': + 'Pick the areas this workspace is mainly used for. We will recommend relevant agents based on your selection.', + 'workspace.general.scenarios.hint': 'You can adjust these any time.', + 'workspace.general.scenarios.title': 'Scenarios', + 'workspace.general.subtitle': 'Manage your workspace name, URL, avatar and other settings', + 'workspace.general.title': 'General', + 'workspace.general.updateFailed': 'Failed to update workspace', + 'workspace.general.updateSuccess': 'Workspace updated', + 'workspace.general.url.confirm.content': + 'Existing links to "/{{old}}" will stop working immediately. Anyone who shared a link to this workspace will need a new one. This change cannot be undone.', + 'workspace.general.url.confirm.ok': 'Yes, change URL', + 'workspace.general.url.confirm.title': 'Change workspace URL to "{{next}}"?', + 'workspace.general.url.description': + "This is your workspace's URL namespace. Members can use it to access shared resources.", + 'workspace.general.url.hint': 'Please use {{max}} characters at maximum.', + 'workspace.general.url.invalidBrandProtected': + 'This workspace URL is associated with a protected brand. Please apply from your organization email.', + 'workspace.general.url.invalidConsecutive': 'Slug cannot contain consecutive dashes.', + 'workspace.general.url.invalidLength': 'Slug must be 3–32 characters long.', + 'workspace.general.url.invalidPattern': + 'Slug must start and end with a letter or number; only lowercase letters, numbers and single dashes are allowed.', + 'workspace.general.url.invalidReserved': 'That slug is reserved. Please choose another.', + 'workspace.general.url.renameWarning': 'Renaming will break existing links to this workspace.', + 'workspace.general.url.taken': 'This URL is already taken.', + 'workspace.general.url.title': 'Workspace URL', + 'workspace.slugBrandApply.button': 'Apply', + 'workspace.slugBrandApply.mailBody': + 'Hi LobeHub team,\n\nI would like to request the workspace URL "{{slug}}" (https://lobehub.com/{{slug}}).\n\n- Brand / organization I represent:\n- Official website / domain:\n- My role in the organization:\n- Organization email (please reply from this address):\n\nThanks!', + 'workspace.slugBrandApply.mailButton': 'Apply via email', + 'workspace.slugBrandApply.mailSubject': 'Workspace URL brand request: {{slug}}', + 'workspace.slugBrandApply.modalCreateTip': + 'You can also create your workspace with another URL now and email us at {{email}} from your organization email to request this one later.', + 'workspace.slugBrandApply.modalDesc': + 'This URL is reserved for a protected brand. If you own this brand, apply from your organization email and we’ll review your request shortly.', + 'workspace.slugBrandApply.modalTitle': 'Request a protected brand URL', + 'workspace.member.demote': 'Demote to member', + 'workspace.member.demoteConfirm.content': 'This member will lose owner privileges.', + 'workspace.member.demoteConfirm.title': 'Demote owner?', + 'workspace.member.invite': 'Invite members', + 'workspace.member.manageAccess': 'Manage access', + 'workspace.member.manageAccessModal.current': 'Current', + 'workspace.member.manageAccessModal.failed': 'Failed to update access', + 'workspace.member.manageAccessModal.save': 'Save Changes', + 'workspace.member.manageAccessModal.sectionLabel': 'Select a role', + 'workspace.member.manageAccessModal.subtitleMiddle': 'has for', + 'workspace.member.manageAccessModal.subtitlePrefix': 'Manage the roles', + 'workspace.member.manageAccessModal.subtitleSuffix': '', + 'workspace.member.manageAccessModal.success': 'Access updated', + 'workspace.member.manageAccessModal.title': 'Manage Team Access', + 'workspace.member.primaryOwner': 'Primary', + 'workspace.member.promote': 'Promote to owner', + 'workspace.member.transferPrimaryConfirm.billingNotice.acknowledge': + 'I understand the saved payment method will keep being charged until the new primary owner replaces it.', + 'workspace.member.transferPrimaryConfirm.billingNotice.description': + 'Subscription charges will keep using the payment method on file ({{email}}) until the new primary owner adds their own card in Billing → Payment Methods. Remind the new owner to update it after transfer, or agree that this account will keep paying.', + 'workspace.member.transferPrimaryConfirm.billingNotice.title': + "Payment method stays on the previous owner's card", + 'workspace.member.transferPrimaryConfirm.failed': 'Failed to transfer primary ownership', + 'workspace.member.transferPrimaryConfirm.noOwners': + 'No other owners in this workspace. Promote a member to owner first before transferring primary ownership.', + 'workspace.member.transferPrimaryConfirm.ok': 'Transfer ownership', + 'workspace.member.transferPrimaryConfirm.selectOwner': + 'Select the owner who will become the new primary owner and take over billing for this workspace.', + 'workspace.member.transferPrimaryConfirm.success': 'Primary ownership transferred', + 'workspace.member.transferPrimaryConfirm.title': 'Transfer primary ownership', + 'workspace.member.promoteConfirm.content': + 'This member will gain full owner privileges — billing, member management, and workspace deletion.', + 'workspace.member.promoteConfirm.title': 'Promote to owner?', + 'workspace.member.remove': 'Remove from Workspace', + 'workspace.member.removeConfirm.confirm': 'Confirm', + 'workspace.member.removeConfirm.content': + 'You are about to remove the following member from the workspace, are you sure you want to continue?', + 'workspace.member.removeConfirm.title': 'Remove Workspace Member', + 'workspace.member.removeSuccess': 'Member removed from workspace successfully.', + 'workspace.member.roles.freeBadge': 'Free', + 'workspace.member.roles.member': 'Member', + 'workspace.member.roles.memberDescription': + 'Run AI generations, manage conversations, and collaborate on workspace assets.', + 'workspace.member.roles.owner': 'Owner', + 'workspace.member.roles.ownerDescription': + 'Full team access — billing, member management, and workspace deletion. Only invite people you trust.', + 'workspace.member.roles.viewer': 'Viewer', + 'workspace.member.roles.viewerDescription': + "Browse the workspace, but can't run AI generations — doesn't count toward billable seats.", + 'workspace.invitePage.accept': 'Accept Invitation', + 'workspace.invitePage.acceptedSubtitle': 'You have been added to the workspace.', + 'workspace.invitePage.acceptedTitle': 'Invitation Accepted', + 'workspace.invitePage.alreadyAcceptedSubtitle': + 'This invitation has already been used. You should already be a member of the workspace.', + 'workspace.invitePage.alreadyAcceptedTitle': 'Invitation already accepted', + 'workspace.invitePage.decline': 'Decline', + 'workspace.invitePage.declineSubtitle': 'You can close this window.', + 'workspace.invitePage.declineTitle': 'Invitation Declined', + 'workspace.invitePage.differentEmailNotice': + 'You are signed in as {{currentEmail}}, but this invitation was sent to {{inviteEmail}}.', + 'workspace.invitePage.expiredSubtitle': + 'This invitation has expired. Ask the team owner to send a new one.', + 'workspace.invitePage.expiredTitle': 'Invitation Expired', + 'workspace.invitePage.expiresLabel': 'Expires', + 'workspace.invitePage.goHome': 'Go Home', + 'workspace.invitePage.goToWorkspace': 'Go to Workspace', + 'workspace.invitePage.invitedAs': 'Invited as', + 'workspace.invitePage.invitedBy': 'Invited by', + 'workspace.invitePage.invitedEmail': 'Invited email', + 'workspace.invitePage.invitedTo': '{{inviter}} invited you to join the workspace', + 'workspace.invitePage.memberLimitSubtitle': + 'This workspace already has {{limit}} members. Ask an owner to remove a member before joining.', + 'workspace.invitePage.memberLimitTitle': 'Workspace Is Full', + 'workspace.invitePage.notFoundSubtitle': + 'This invitation link is invalid. Double-check the URL or ask the team owner to send a new one.', + 'workspace.invitePage.notFoundTitle': 'Invitation Not Found', + 'workspace.invitePage.revokedSubtitle': + 'This invitation has been revoked by the workspace owner.', + 'workspace.invitePage.revokedTitle': 'Invitation Revoked', + 'workspace.invitePage.signInToAccept': 'Sign in to accept', + 'workspace.invitePage.switchAccountToAccept': 'Sign in as {{email}} to accept', + 'workspace.invitePage.signUpToJoin': 'Create account & join', + 'workspace.invitePage.title': 'Workspace Invitation', + 'workspace.invitePage.workspaceLimitSubtitle': + "You've reached the maximum of {{limit}} workspaces. Leave one before joining another.", + 'workspace.invitePage.workspaceLimitTitle': 'Workspace Limit Reached', + 'workspace.members.empty': 'No members yet', + 'workspace.members.invite.emailLabel': 'Email Address', + 'workspace.members.invite.emailPlaceholder': 'jane@example.com', + 'workspace.members.invite.errors.alreadyInvited': + '{{email}} already has a pending invitation. Resend or revoke it from the list below.', + 'workspace.members.invite.errors.alreadyMember': + '{{email}} is already a member of this workspace.', + 'workspace.members.invite.failed': 'Failed to send invitation', + 'workspace.members.invite.limitReached': + 'This workspace can have up to {{limit}} members. Remove a member before inviting more.', + 'workspace.members.invite.roleLabel': 'Role', + 'workspace.members.invite.submit': 'Invite', + 'workspace.members.invite.subtitle': + 'Add new members by entering their email address and assigning a role', + 'workspace.members.invite.addAnother': 'Add another', + 'workspace.members.invite.button': 'Invite', + 'workspace.members.invite.modal.cancel': 'Cancel', + 'workspace.members.invite.modal.confirm': 'Confirm', + 'workspace.members.invite.modal.description_one': + 'Your team is expanding! By confirming, you will invite 1 new team member to this workspace.', + 'workspace.members.invite.modal.description_other': + 'Your team is expanding! By confirming, you will invite {{count}} new team members to this workspace.', + 'workspace.members.invite.modal.expiryWarning': 'Team invites expire after 1 week.', + 'workspace.members.invite.modal.title': 'Invite Team Members', + 'workspace.members.invite.noPermissionHint': + 'Additional permissions are required to manage Team Members', + 'workspace.members.invite.partialSuccess': + '{{success}} invited, {{failed}} failed. Check the addresses and try again.', + 'workspace.members.invite.success': 'Team members invited successfully.', + 'workspace.members.invite.title': 'Invite Members', + 'workspace.members.invite.upgradeCta': 'Upgrade', + 'workspace.members.invite.upgradeHint': 'This feature is available on the Pro plan.', + 'workspace.members.pending.empty': 'No pending invitations', + 'workspace.members.pending.expiresAt': 'Expires {{date}}', + 'workspace.members.pending.resend': 'Resend', + 'workspace.members.pending.resendFailed': 'Failed to resend invitation', + 'workspace.members.pending.resendSuccess': 'Invitation email resent', + 'workspace.members.pending.revoke': 'Revoke', + 'workspace.members.pending.revokeConfirm.content': 'The invitation link will no longer be valid.', + 'workspace.members.pending.revokeConfirm.title': 'Revoke this invitation?', + 'workspace.members.subtitle': 'Manage workspace members and invitations', + 'workspace.members.tabs.members': 'Team Members', + 'workspace.members.tabs.pending': 'Pending Invitations', + 'workspace.members.title': 'Members', + 'workspace.name.placeholder': 'Workspace name', + 'workspace.name.title': 'Name', + 'workspace.newWorkspace': 'New workspace', + 'workspace.personalTag': 'Personal', + 'workspace.switchWorkspace': 'Switch workspace', + 'workspace.upgradeModal.alreadyUpgraded': 'Already upgraded', + 'workspace.upgradeModal.changeWorkspace': 'Back', + 'workspace.upgradeModal.chargeDisclosure': + 'Upon clicking Upgrade, you will be charged ${{fee}}, plus any applicable taxes and fees, immediately and then every month, until you cancel. Seat fees and on-demand usage are settled at month-end; if your usage exceeds a billing threshold during a cycle, your payment method on file may be charged before the cycle ends.', + 'workspace.upgradeModal.inviteLaterHint': + 'You can invite more members to your team in the next step.', + 'workspace.upgradeModal.memberCount_one': '{{count}} member', + 'workspace.upgradeModal.memberCount_other': '{{count}} members', + 'workspace.upgradeModal.memberIncluded': 'Included', + 'workspace.upgradeModal.ownerTag': 'Owner', + 'workspace.upgradeModal.totalPerMonth': '${{amount}} / month', + 'workspace.upgradeModal.youLabel': 'You', + 'workspace.upgradeModal.continueCta': 'Continue', + 'workspace.upgradeModal.createTeam': 'Create workspace', + 'workspace.upgradeModal.formSubtitle': + 'Only the platform fee is charged today — seat fees are settled at month-end.', + 'workspace.upgradeModal.formTitle': 'Upgrade {{name}} to Pro', + 'workspace.upgradeModal.heading': 'Upgrade a workspace to Pro', + 'workspace.upgradeModal.hobbyTag': 'Hobby', + 'workspace.upgradeModal.noHobbyHint': "You don't own any Hobby workspaces to upgrade.", + 'workspace.upgradeModal.payFailed': 'Failed to start checkout', + 'workspace.upgradeModal.pickerLabel': 'Select a workspace', + 'workspace.upgradeModal.proTag': 'Pro', + 'workspace.upgradeModal.subtitle': + 'Unlock collaboration, larger credits, and higher rate limits.', + 'workspace.upgradeModal.successPage.amountLabel': 'Total paid', + 'workspace.upgradeModal.successPage.activating': + 'Activating your subscription — this usually takes a few seconds…', + 'workspace.upgradeModal.successPage.continueCta': 'Invite & continue', + 'workspace.upgradeModal.successPage.desc': + '{{name}} is now on Pro. Invite teammates to start collaborating.', + 'workspace.upgradeModal.successPage.inviteDesc': + 'Send invites by email. They will receive a link to join the workspace.', + 'workspace.upgradeModal.successPage.inviteTitle': 'Invite teammates to {{name}}', + 'workspace.upgradeModal.successPage.paidAtLabel': 'Paid at', + 'workspace.upgradeModal.successPage.planTag': 'Workspace Pro', + 'workspace.upgradeModal.successPage.processingDesc': + 'Hang tight — we are confirming the payment with Stripe. This usually takes a few seconds.', + 'workspace.upgradeModal.successPage.processingTitle': 'Processing your payment…', + 'workspace.upgradeModal.successPage.recurring.monthly': 'Monthly', + 'workspace.upgradeModal.successPage.recurring.yearly': 'Yearly', + 'workspace.upgradeModal.successPage.recurringLabel': 'Billing cycle', + 'workspace.upgradeModal.successPage.seatsSummary': '{{count}} included seat', + 'workspace.upgradeModal.successPage.seatsSummary_other': '{{count}} included seats', + 'workspace.upgradeModal.successPage.skipCta': 'Skip for now', + 'workspace.upgradeModal.successPage.title': 'Upgrade complete', + 'workspace.upgradeModal.successToast': 'Upgrade complete — refreshing your plan…', + 'workspace.upgradeModal.title': 'Upgrade to Pro', + 'workspace.upgradeModal.upgradeButton': 'Upgrade', + 'workspace.wizard.back': 'Back', + 'workspace.wizard.cancel': 'Cancel', + 'workspace.wizard.next': 'Next', + 'workspace.wizard.step1.avatar.hint': 'Add an avatar to help your team recognize this workspace.', + 'workspace.wizard.step1.avatar.tooLarge': 'Avatar file must be smaller than 5MB.', + 'workspace.wizard.step1.avatar.uploadFailed': 'Failed to upload avatar', + 'workspace.wizard.step1.avatar.uploading': 'Uploading avatar…', + 'workspace.wizard.step1.cardNamePlaceholder': 'Your Workspace', + 'workspace.wizard.step1.description.label': 'Description', + 'workspace.wizard.step1.description.placeholder': "What's this workspace for? (optional)", + 'workspace.wizard.step1.features.admin': 'Centralized billing & admin controls', + 'workspace.wizard.step1.features.collaboration': 'Invite members to one shared workspace', + 'workspace.wizard.step1.features.roles': 'Roles & permissions', + 'workspace.wizard.step1.features.sharedAssets': 'Shared agents, files & knowledge bases', + 'workspace.wizard.step1.features.sharedCredits': 'Shared team credits pool', + 'workspace.wizard.step1.formSubtitle': 'Give your workspace a name and URL.', + 'workspace.wizard.step1.formTitle': 'Setup Your Workspace', + 'workspace.wizard.step1.name.label': 'Workspace name', + 'workspace.wizard.step1.name.placeholder': 'Example Team', + 'workspace.wizard.step1.name.required': 'Workspace name is required', + 'workspace.wizard.step1.slug.available': 'This URL is available', + 'workspace.wizard.step1.slug.checking': 'Checking availability…', + 'workspace.wizard.step1.slug.invalidBrandProtected': + 'This workspace URL is associated with a protected brand. Please apply from your organization email.', + 'workspace.wizard.step1.slug.invalidConsecutive': 'Slug cannot contain consecutive dashes.', + 'workspace.wizard.step1.slug.invalidLength': + 'Workspace URL must be {{min}}–{{max}} characters long.', + 'workspace.wizard.step1.slug.invalidPattern': + 'Slug must start and end with a letter or number; only lowercase letters, numbers and single dashes are allowed.', + 'workspace.wizard.step1.slug.invalidReserved': 'That slug is reserved. Please choose another.', + 'workspace.wizard.step1.slug.label': 'Workspace URL', + 'workspace.wizard.step1.slug.placeholder': 'example-team', + 'workspace.wizard.step1.slug.prefix': 'lobehub.com/', + 'workspace.wizard.step1.slug.required': 'Workspace URL is required', + 'workspace.wizard.step1.slug.taken': 'This URL is already taken', + 'workspace.wizard.step1.subtitle': 'Unlock Agent Collaboration With Your Teammates', + 'workspace.wizard.step1.title': 'Workspace details', + 'workspace.wizard.step2.billing.freeSummary': 'Solo workspace · free', + 'workspace.wizard.step2.billing.inviteAfterCreateHint': + 'Invite teammates after creating this workspace.', + 'workspace.wizard.step2.billing.platformLine': 'Pro platform · monthly', + 'workspace.wizard.step2.billing.seatLine': 'Additional seats × {{count}}', + 'workspace.wizard.step2.billing.seatPostpaidNote': + 'Seats are billed at month-end: ${{seatFee}}/seat/month, based on actual use.', + 'workspace.wizard.step2.billing.title': 'Cost details', + 'workspace.wizard.step2.billing.total': 'Total', + 'workspace.wizard.step2.billing.totalFreeHint': 'Free for a solo workspace.', + 'workspace.wizard.step2.billing.totalMonthHint': 'Platform fee today · seats billed at month-end', + 'workspace.wizard.step2.chargeDisclosure': + 'Confirming creates this workspace on Pro and charges ${{fee}} now, plus any applicable taxes. The subscription renews monthly until you cancel. Seats and on-demand usage are billed at month-end.', + 'workspace.wizard.step2.confirmPurchase': 'Confirm purchase', + 'workspace.wizard.step2.createFailed': 'Failed to create workspace', + 'workspace.wizard.step2.details.description': "See what's included in your selected plan.", + 'workspace.wizard.step2.details.title': 'Plan Details', + 'workspace.wizard.step2.createFree': 'Create workspace', + 'workspace.wizard.step2.createdToast': 'Workspace {{name}} created.', + 'workspace.wizard.step2.hobbyAgreement': + 'Hobby is free to create and has no monthly credits. Top-ups or AutoTopUp are billed only after you confirm them.', + 'workspace.wizard.step2.header.description': 'Each workspace is billed separately.', + 'workspace.wizard.step2.header.title': 'Select Your Plan', + 'workspace.wizard.step2.freeLimitReached': + "You've reached the free workspace limit ({{limit}}). Upgrade to Pro to create more.", + 'workspace.wizard.step2.totalLimitReached': + "You've reached the maximum of {{limit}} workspaces. Leave one before creating another.", + 'workspace.wizard.step2.features.hobby.share': 'Single-owner workspace', + 'workspace.wizard.step2.features.hobby.solo': 'Solo workspace, no member seats', + 'workspace.wizard.step2.features.hobby.onDemand': 'On-demand usage · AutoTopUp (${{price}}/M)', + 'workspace.wizard.step2.features.hobby.upgradable': 'Upgrade anytime to invite members', + 'workspace.wizard.step2.features.pro.adminControls': 'Centralized billing, roles, and audit logs', + 'workspace.wizard.step2.features.pro.collaboration': 'Invite members · share agents and files', + 'workspace.wizard.step2.features.pro.onDemand': 'On-demand usage · AutoTopUp (${{price}}/M)', + 'workspace.wizard.step2.features.pro.priorityModels': 'Priority premium models', + 'workspace.wizard.step2.features.pro.support': 'Priority email support', + 'workspace.wizard.step2.left.creditsHobbyHint': 'No monthly credits · pay as you go', + 'workspace.wizard.step2.left.creditsLabel': 'Credits per month', + 'workspace.wizard.step2.left.creditsProHint': 'Shared workspace pool · seats do not add credits', + 'workspace.wizard.step2.left.freeHeadline': 'Solo workspace', + 'workspace.wizard.step2.left.freeTagline': 'For individual use. Upgrade later to invite members.', + 'workspace.wizard.step2.left.headline': 'Team workspace', + 'workspace.wizard.step2.left.hobbyTopUpHint': + 'Top up anytime: ${{price}}/M ({{percent}}% above base, with no subscription fee)', + 'workspace.wizard.step2.left.proTagline': '${{fee}}/seat/month. Extra usage is billed on demand.', + 'workspace.wizard.step2.left.proTopUpHint': + 'Top up when credits run low: ${{price}}/M ({{percent}}% off the standard rate)', + 'workspace.wizard.step2.payFailed': 'Failed to start checkout', + 'workspace.wizard.step2.pill.free': 'Free', + 'workspace.wizard.step2.pill.freeUsed': '{{used}}/{{limit}} used', + 'workspace.wizard.step2.pill.proPerSeat': '${{fee}} / seat / month', + 'workspace.wizard.step2.plans.hobby': 'Hobby', + 'workspace.wizard.step2.plans.pro': 'Pro', + 'workspace.wizard.step2.seats.hint': 'Between {{min}} and {{max}} seats.', + 'workspace.wizard.step2.seats.label': 'Seats', + 'workspace.wizard.step2.subtitle': + 'Each workspace is billed separately. Choose a plan to finish.', + 'workspace.wizard.step2.title': 'Choose plan', + 'workspace.wizard.step3.addMore': 'Add more', + 'workspace.wizard.step3.allFailed': 'Could not send invitations', + 'workspace.wizard.step3.emailPlaceholder': 'name@company.com', + 'workspace.wizard.step3.invitedCount': 'Invited {{count}} member(s)', + 'workspace.wizard.step3.inviteAndContinue': 'Invite and continue', + 'workspace.wizard.step3.noEmails': 'No valid emails entered. Skipping invitations.', + 'workspace.wizard.step3.skip': 'Skip for now', + 'workspace.wizard.step3.subtitle': + 'Add your team members by email. You can also invite them later.', + 'workspace.wizard.step3.title': 'Welcome to {{name}}!', + 'workspace.wizard.closeConfirm.cancel': 'Keep editing', + 'workspace.wizard.closeConfirm.content': + 'Your workspace details have not been saved. Closing now will discard them.', + 'workspace.wizard.closeConfirm.ok': 'Discard', + 'workspace.wizard.closeConfirm.title': 'Discard workspace setup?', + 'workspace.wizard.title': 'Create Workspace', + // Tooltips for action buttons disabled by the active user's workspace role. + // Wired through `usePermission`; the two role buckets correspond to the + // RBAC matrix (member can create/edit own; owner has everything). + 'workspace.permission.requiresMember': + "You don't have permission to do this. Ask a workspace owner to grant you Member or higher.", + 'workspace.permission.requiresOwner': + 'Only workspace owners can do this. Ask an owner if you need this changed.', + 'workspace.permission.requiresPrimaryOwner': + 'Only the primary owner can delete this workspace. Transfer primary ownership first if needed.', + 'workspace.onboarding.title': 'Set up your workspace', + 'workspace.onboarding.stepLabel': 'Step {{current}} of {{total}}', + 'workspace.onboarding.skip': 'Skip', + 'workspace.onboarding.prev': 'Back', + 'workspace.onboarding.next': 'Next', + 'workspace.onboarding.finish': 'Finish', + 'workspace.onboarding.skipConfirm.title': 'Skip workspace setup?', + 'workspace.onboarding.skipConfirm.description': + 'You can always customize LobeAI and add agents later in Workspace settings.', + 'workspace.onboarding.skipConfirm.ok': 'Skip', + 'workspace.onboarding.skipConfirm.cancel': 'Continue setup', + 'workspace.onboarding.toast.saved': 'Saved', + 'workspace.onboarding.toast.failed': 'Could not save. Please try again.', + 'workspace.onboarding.step1.heading': 'Personalize LobeAI', + 'workspace.onboarding.step1.subtitle': + 'Give your workspace assistant an identity your team will recognize.', + 'workspace.onboarding.step1.avatarLabel': 'Avatar', + 'workspace.onboarding.step1.avatarHint': 'Pick an emoji or upload an image.', + 'workspace.onboarding.step1.avatarTooLarge': 'Avatar file must be smaller than 5MB.', + 'workspace.onboarding.step1.avatarUploadFailed': 'Failed to upload avatar', + 'workspace.onboarding.step1.nameLabel': "LobeAI's name", + 'workspace.onboarding.step1.namePlaceholder': 'e.g. LobeAI', + 'workspace.onboarding.step1.suggestion.title': 'Need ideas? Pick one to start.', + 'workspace.onboarding.step1.suggestion.switch': 'Try another set', + 'workspace.onboarding.step1.guide.name.title': 'Give It a Name', + 'workspace.onboarding.step1.guide.name.desc': + 'A name your teammates instantly recognize keeps collaboration smooth.', + 'workspace.onboarding.step1.guide.knowYou.title': 'Get to Know the Team', + 'workspace.onboarding.step1.guide.knowYou.desc': + "Share what you're working on — more context means better help.", + 'workspace.onboarding.step1.guide.growTogether.title': 'Grow with the Team', + 'workspace.onboarding.step1.guide.growTogether.desc': + "Every conversation teaches me your team's vibe — the longer we work together, the better.", + 'workspace.onboarding.step1.footer': + "Set up your workspace's LobeAI assistant — it learns from every conversation and grows into your team's go-to teammate.", + 'workspace.onboarding.step1.sentence.1': "Ready? Let me be your team's go-to teammate.", + 'workspace.onboarding.step1.sentence.2': 'What role do you want me to play in this workspace?', + 'workspace.onboarding.step1.sentence.3': 'First, give me a name your team will love :)', + 'workspace.onboarding.step2.heading': "What's this workspace mostly for?", + 'workspace.onboarding.step2.sentence.1': "What's this workspace mostly for?", + 'workspace.onboarding.step2.sentence.2': 'Which areas should I prioritize for your team?', + 'workspace.onboarding.step2.sentence.3': "Pick a few — I'll recommend matching teammates :)", + 'workspace.onboarding.step2.subtitle': + "Pick one or more — we'll suggest starter agents to match. You can change this later.", + 'workspace.onboarding.step2.scenario.business': 'Business & Strategy', + 'workspace.onboarding.step2.scenario.coding': 'Programming & Development', + 'workspace.onboarding.step2.scenario.creator': 'Creator Economy', + 'workspace.onboarding.step2.scenario.design': 'Design & Creativity', + 'workspace.onboarding.step2.scenario.education': 'Learning & Research', + 'workspace.onboarding.step2.scenario.finance-legal': 'Finance & Legal', + 'workspace.onboarding.step2.scenario.health': 'Health & Habits', + 'workspace.onboarding.step2.scenario.hobbies': 'Hobbies & Culture', + 'workspace.onboarding.step2.scenario.hr': 'People & HR', + 'workspace.onboarding.step2.scenario.investing': 'Investing & Finance', + 'workspace.onboarding.step2.scenario.marketing': 'Marketing & Promotion', + 'workspace.onboarding.step2.scenario.operations': 'Operations & Admin', + 'workspace.onboarding.step2.scenario.parenting': 'Family & Parenting', + 'workspace.onboarding.step2.scenario.personal': 'Personal Life', + 'workspace.onboarding.step2.scenario.product': 'Product & Management', + 'workspace.onboarding.step2.scenario.sales': 'Sales & Customer Relations', + 'workspace.onboarding.step2.scenario.writing': 'Content Creation', + 'workspace.onboarding.step3.heading': 'Add a few agents to your workspace', + 'workspace.onboarding.step3.subtitle': 'Pick a few to start — discover more anytime.', + 'workspace.onboarding.step3.categoryAll': 'All', + 'workspace.onboarding.step3.skipInstall': "Don't install any", + 'workspace.onboarding.step3.installed': 'Added {{count}} agent(s) to your workspace', + 'workspace.onboarding.step3.empty': 'No recommendations available right now.', + 'workspaceSetting.breadcrumb.settings': 'Settings', + 'workspaceSetting.group.admin': 'Admin', + 'workspaceSetting.group.agent': 'Agent', + 'workspaceSetting.group.general': 'General', + 'workspaceSetting.group.subscription': 'Plans', + 'workspaceSetting.tab.billing': 'Bills', + 'workspaceSetting.tab.credits': 'Credits', + 'workspaceSetting.tab.general': 'General', + 'workspaceSetting.tab.members': 'Members', + 'workspaceSetting.tab.plans': 'Plans', + 'workspaceSetting.storage.comingSoon': 'Workspace-scoped data import & export is coming soon.', + 'workspaceSetting.storage.danger.clear.desc': + 'Delete all data in this workspace, including agents, files, messages, and skills. The workspace itself will NOT be deleted.', + 'workspaceSetting.storage.danger.clear.title': 'Wipe Workspace Data', + 'workspaceSetting.storage.danger.reset.desc': + 'Restore all workspace settings to defaults. Workspace data will not be deleted.', + 'workspaceSetting.storage.danger.reset.title': 'Reset Workspace Settings', + 'workspaceSetting.storage.telemetry.desc': + 'Help us improve {{appName}} with anonymous workspace usage data', + 'workspaceSetting.storage.telemetry.title': 'Send Anonymous Workspace Usage Data', + 'workspaceSetting.tab.skill': 'Skills', + 'workspaceSetting.tab.usage': 'Usage', 'tools.add': 'Add Skill', 'tools.addSkillOrConnector': 'Add Skills / Connector', 'tools.builtins.groupName': 'Built-ins', diff --git a/src/locales/default/subscription.ts b/src/locales/default/subscription.ts index 282dc88f18..7b83963501 100644 --- a/src/locales/default/subscription.ts +++ b/src/locales/default/subscription.ts @@ -175,6 +175,12 @@ export default { 'limitation.insufficientBudget.retry': 'Retry', 'limitation.insufficientBudget.shortfall': 'Credit Shortfall', 'limitation.insufficientBudget.title': 'Insufficient Credits', + 'limitation.workspace.insufficientBudget.available': 'Available Credits', + 'limitation.workspace.insufficientBudget.desc': + 'Credits are not enough to continue. Top up credits or upgrade the plan.', + 'limitation.workspace.insufficientBudget.title': 'Workspace Credits Insufficient', + 'limitation.workspace.insufficientBudget.topup': 'Top Up Credits', + 'limitation.workspace.insufficientBudget.upgradeToPro': 'Upgrade to Pro', 'limitation.hobby.action': 'Configured, continue chatting', 'limitation.hobby.configAPI': 'Configure API', 'limitation.hobby.desc': @@ -281,6 +287,7 @@ export default { 'plans.credit.tip': '{{credit}} free credits per month', 'plans.credit.title': 'Credits', 'plans.credit.tooltip': 'Monthly model message credits', + 'plans.creditPackage.available': 'Additional credit packages available (${{price}}/M)', 'plans.current': 'Current Plan', 'plans.downgradePlan': 'Target Downgrade Plan', 'plans.downgradeTip': @@ -298,10 +305,16 @@ export default { 'plans.features.agents': 'Curated Agent Market', 'plans.features.ceAgents': 'Community Agent Market', 'plans.features.cePlugins': 'Community Plugin Market', + 'plans.features.earlyAccess': 'Early Access to SOTA Model', + 'plans.features.earlyAccessTooltip': + 'Some frontier models may only be open to subscribed users when the model is initially launched. This does not affect custom API keys.', + 'plans.features.imageGeneration': 'Image Generation', 'plans.features.internet': 'Smart Web Search', 'plans.features.plugins': 'Exclusive Premium Plugins', 'plans.features.showAll': 'View All Features', 'plans.features.title': 'Premium Features', + 'plans.features.unlimitedPages': 'Unlimited Pages', + 'plans.features.videoGeneration': 'Video Generation', 'plans.fileStorage.title': 'File Storage', 'plans.fileStorage.storagePayAsYouGo': 'Storage overages support pay-as-you-go billing', 'plans.fileStorage.tooltip': 'File storage for storing files, images, and other data', @@ -358,6 +371,13 @@ export default { 'plans.support.starter': 'Email and Community Forum', 'plans.support.title': 'Support', 'plans.support.ultimate': 'Priority Chat and Email Support', + 'plans.workspace.features.inviteMembers': 'Invite Members', + 'plans.workspace.features.roles': 'Roles & Permissions', + 'plans.workspace.includesFrom.hobby': 'Everything in Hobby, plus:', + 'plans.workspace.maxMembers': 'Up to {{count}} members', + 'plans.workspace.noSharedCredits': 'No shared credits', + 'plans.workspace.sharedCredits': '~{{count}} Credits / mo', + 'plans.workspace.solo': 'Solo (1 member)', 'plans.target': 'Target Plan', 'plans.unlimited': 'Unlimited', 'qa.desc': @@ -384,6 +404,25 @@ export default { 'qa.support.community': 'Community Support', 'qa.support.email': 'Email Support', 'qa.title': 'FAQ', + 'qa.workspace.list.cancel.a': + 'Click "Cancel subscription" on the Billing tab. Auto-renewal stops at the end of the current billing cycle and the workspace falls back to the {{hobby}} plan. Pro features remain available until then.', + 'qa.workspace.list.cancel.q': 'How do I cancel a workspace subscription?', + 'qa.workspace.list.credits.a': + 'Each workspace has its own shared credit pool, separate from any member’s personal credits. Every member’s usage draws from this pool, and the owner can monitor consumption on the Billing page.', + 'qa.workspace.list.credits.q': 'How are credits shared between workspace members?', + 'qa.workspace.list.intro.a': + 'A workspace is a shared space for team collaboration. Members share computing credits, agents, and knowledge bases inside the workspace, and it is billed independently from personal accounts.', + 'qa.workspace.list.intro.q': 'What is a workspace?', + 'qa.workspace.list.personalVsWorkspace.a': + 'Yes. They are billed independently — personal subscriptions only apply to your personal space, and workspace subscriptions only apply inside the workspace. Usage in a workspace does not deduct from your personal credits.', + 'qa.workspace.list.personalVsWorkspace.q': + 'Can workspace and personal subscriptions be used at the same time?', + 'qa.workspace.list.plans.a': + '{{hobby}} is free and includes a single seat, ideal for solo use to try out the workspace experience. {{pro}} provides a monthly credit allowance and unlocks team seats, seat management, priority support, and other advanced features.', + 'qa.workspace.list.plans.q': 'What is the difference between {{hobby}} and {{pro}}?', + 'qa.workspace.list.seats.a': + 'Pro workspace owners can adjust the seat count from the Billing tab; each additional seat is billed monthly. Seat limits adjust automatically when downgrading to {{hobby}}.', + 'qa.workspace.list.seats.q': 'How do I add or manage seats?', 'recurring.day': 'Daily', 'recurring.fullYear': 'Full Year', 'recurring.monthly': 'Monthly Billing', diff --git a/src/locales/default/topic.ts b/src/locales/default/topic.ts index 8e805e4ee7..46866c90fd 100644 --- a/src/locales/default/topic.ts +++ b/src/locales/default/topic.ts @@ -59,6 +59,8 @@ export default { 'importInvalidFormat': 'Invalid file format. Please ensure it is a valid JSON file.', 'importLoading': 'Importing conversation...', 'importSuccess': 'Successfully imported {{count}} messages', + 'info.title': 'Topic Info', + 'info.updatedAt': 'Updated at {{time}}', 'inPopup.description': 'This topic is currently open in a separate window. Continue the conversation there to keep messages in sync.', 'inPopup.focus': 'Focus Popup Window', diff --git a/src/locales/resources.ts b/src/locales/resources.ts index bc91f2950b..62d504027f 100644 --- a/src/locales/resources.ts +++ b/src/locales/resources.ts @@ -30,10 +30,14 @@ export type Locales = (typeof locales)[number]; export const normalizeLocale = (locale?: string): Locales => { if (!locale) return DEFAULT_LANG; - if (locale.startsWith('ar')) return 'ar'; - if (locale.startsWith('fa')) return 'fa-IR'; + const lowerLocale = locale.toLowerCase(); - if (locale.startsWith('cn')) return 'zh-CN'; + if (lowerLocale.startsWith('ar')) return 'ar'; + if (lowerLocale.startsWith('fa')) return 'fa-IR'; + + if (lowerLocale.startsWith('cn')) return 'zh-CN'; + if (lowerLocale.startsWith('zh-hans')) return 'zh-CN'; + if (lowerLocale.startsWith('zh-hant')) return 'zh-TW'; for (const l of locales) { if (l.startsWith(locale)) { diff --git a/src/server/agent-hono/handlers/__tests__/runStep.test.ts b/src/server/agent-hono/handlers/__tests__/runStep.test.ts index ce197c2f4a..7c4c3be36d 100644 --- a/src/server/agent-hono/handlers/__tests__/runStep.test.ts +++ b/src/server/agent-hono/handlers/__tests__/runStep.test.ts @@ -1,6 +1,8 @@ // @vitest-environment node import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; +import { AgentRuntimeService } from '@/server/services/agentRuntime'; + import { runStep, runStepHealth } from '../runStep'; const mockGetOperationMetadata = vi.fn(); @@ -84,6 +86,28 @@ describe('runStep handler', () => { expect(mockExecuteStep).not.toHaveBeenCalled(); }); + it('constructs AgentRuntimeService with the workspaceId from operation metadata', async () => { + // Regression: a workspace-scoped binding (e.g. Discord bot active agent) runs + // its steps through this QStash worker. Dropping workspaceId here makes the + // runtime personal-scoped, so the parent-message lookup misses the + // workspace-scoped row and throws ConversationParentMissing. + mockGetOperationMetadata.mockResolvedValue({ userId: 'user-1', workspaceId: 'ws-1' }); + mockExecuteStep.mockResolvedValue({ + nextStepScheduled: false, + state: { cost: { total: 0 }, status: 'done', stepCount: 1 }, + success: true, + }); + + const { ctx } = buildContext({ body: validBody }); + await runStep(ctx); + + expect(AgentRuntimeService).toHaveBeenCalledWith( + expect.anything(), + 'user-1', + expect.objectContaining({ workspaceId: 'ws-1' }), + ); + }); + it('returns 429 with Retry-After header when the step is locked', async () => { mockGetOperationMetadata.mockResolvedValue({ userId: 'user-1' }); mockExecuteStep.mockResolvedValue({ diff --git a/src/server/agent-hono/handlers/execAgent.ts b/src/server/agent-hono/handlers/execAgent.ts index c2902af64d..b87cdf342b 100644 --- a/src/server/agent-hono/handlers/execAgent.ts +++ b/src/server/agent-hono/handlers/execAgent.ts @@ -27,6 +27,7 @@ export async function execAgent(c: Context): Promise { appContext, autoStart = true, existingMessageIds, + workspaceId, } = body as Record & { autoStart?: boolean }; if (!userId) return c.json({ error: 'userId is required' }, 400); @@ -39,7 +40,9 @@ export async function execAgent(c: Context): Promise { try { const serverDB = await getServerDB(); - const aiAgentService = new AiAgentService(serverDB, userId as string); + const aiAgentService = new AiAgentService(serverDB, userId as string, { + workspaceId: typeof workspaceId === 'string' ? workspaceId : undefined, + }); const result = await aiAgentService.execAgent({ agentId: agentId as string | undefined, diff --git a/src/server/agent-hono/handlers/gatewayCron.ts b/src/server/agent-hono/handlers/gatewayCron.ts index 19f7347090..a42cf36588 100644 --- a/src/server/agent-hono/handlers/gatewayCron.ts +++ b/src/server/agent-hono/handlers/gatewayCron.ts @@ -76,8 +76,12 @@ async function processConnectQueue(remainingMs: number): Promise { continue; } - const model = new AgentBotProviderModel(serverDB, item.userId, gateKeeper); - const provider = await model.findEnabledByApplicationId(item.platform, item.applicationId); + const provider = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + item.platform, + item.applicationId, + gateKeeper, + ); if (!provider) { log('No enabled provider found for queued %s appId=%s', item.platform, item.applicationId); diff --git a/src/server/agent-hono/handlers/runStep.ts b/src/server/agent-hono/handlers/runStep.ts index 291d1f7419..5bf9dba5bf 100644 --- a/src/server/agent-hono/handlers/runStep.ts +++ b/src/server/agent-hono/handlers/runStep.ts @@ -59,7 +59,12 @@ export async function runStep(c: Context): Promise { } const serverDB = await getServerDB(); - const agentRuntimeService = new AgentRuntimeService(serverDB, metadata.userId); + // Thread the operation's workspace through so the runtime's models stay + // workspace-scoped. Without it the worker is personal-scoped and the + // parent-message lookup misses workspace-scoped rows → ConversationParentMissing. + const agentRuntimeService = new AgentRuntimeService(serverDB, metadata.userId, { + workspaceId: metadata.workspaceId, + }); const result = await agentRuntimeService.executeStep({ approvedToolCall, diff --git a/src/server/modules/AgentRuntime/AgentRuntimeCoordinator.ts b/src/server/modules/AgentRuntime/AgentRuntimeCoordinator.ts index 00f8016585..03a6998f32 100644 --- a/src/server/modules/AgentRuntime/AgentRuntimeCoordinator.ts +++ b/src/server/modules/AgentRuntime/AgentRuntimeCoordinator.ts @@ -94,6 +94,7 @@ export class AgentRuntimeCoordinator { agentConfig?: any; modelRuntimeConfig?: any; userId?: string; + workspaceId?: string; }, ): Promise { try { diff --git a/src/server/modules/AgentRuntime/AgentStateManager.ts b/src/server/modules/AgentRuntime/AgentStateManager.ts index 4302b49ca7..3fe4ae0a0e 100644 --- a/src/server/modules/AgentRuntime/AgentStateManager.ts +++ b/src/server/modules/AgentRuntime/AgentStateManager.ts @@ -27,6 +27,13 @@ export interface AgentOperationMetadata { totalCost: number; totalSteps: number; userId?: string; + /** + * Workspace the operation runs in (null/undefined = personal). Persisted so + * queue workers (e.g. QStash `runStep`) can reconstruct a workspace-scoped + * runtime; without it the runtime is personal-scoped and message/topic + * lookups miss workspace-scoped rows. + */ + workspaceId?: string; } export class AgentStateManager { @@ -194,6 +201,7 @@ export class AgentStateManager { totalCost: parseFloat(metadata.totalCost) || 0, totalSteps: parseInt(metadata.totalSteps) || 0, userId: metadata.userId, + workspaceId: metadata.workspaceId, }; } catch (error) { console.error('Failed to get operation metadata:', error); @@ -210,6 +218,7 @@ export class AgentStateManager { agentConfig?: any; modelRuntimeConfig?: any; userId?: string; + workspaceId?: string; }, ): Promise { const metaKey = `${this.METADATA_PREFIX}:${operationId}`; @@ -224,6 +233,7 @@ export class AgentStateManager { totalCost: 0, totalSteps: 0, userId: data.userId, + workspaceId: data.workspaceId, }; // Serialize complex objects @@ -236,6 +246,7 @@ export class AgentStateManager { }; if (metadata.userId) redisData.userId = metadata.userId; + if (metadata.workspaceId) redisData.workspaceId = metadata.workspaceId; if (metadata.modelRuntimeConfig) redisData.modelRuntimeConfig = JSON.stringify(metadata.modelRuntimeConfig); if (metadata.agentConfig) redisData.agentConfig = JSON.stringify(metadata.agentConfig); diff --git a/src/server/modules/AgentRuntime/InMemoryAgentStateManager.ts b/src/server/modules/AgentRuntime/InMemoryAgentStateManager.ts index 63dd4ef792..4ee8d4505e 100644 --- a/src/server/modules/AgentRuntime/InMemoryAgentStateManager.ts +++ b/src/server/modules/AgentRuntime/InMemoryAgentStateManager.ts @@ -122,6 +122,7 @@ export class InMemoryAgentStateManager implements IAgentStateManager { agentConfig?: any; modelRuntimeConfig?: any; userId?: string; + workspaceId?: string; }, ): Promise { const metadata: AgentOperationMetadata = { @@ -133,6 +134,7 @@ export class InMemoryAgentStateManager implements IAgentStateManager { totalCost: 0, totalSteps: 0, userId: data.userId, + workspaceId: data.workspaceId, }; this.metadata.set(operationId, metadata); diff --git a/src/server/modules/AgentRuntime/RuntimeExecutors.ts b/src/server/modules/AgentRuntime/RuntimeExecutors.ts index 97c5576b93..03aa02b107 100644 --- a/src/server/modules/AgentRuntime/RuntimeExecutors.ts +++ b/src/server/modules/AgentRuntime/RuntimeExecutors.ts @@ -24,6 +24,7 @@ import { BRANDING_PROVIDER } from '@lobechat/business-const'; import { KLAVIS_SERVER_TYPES } from '@lobechat/const'; import { type AgentContextDocument, + type AgentGroupConfig, type BotPlatformContext, buildStepSkillDelta, buildStepToolDelta, @@ -131,6 +132,38 @@ const LLM_RETRY_MAX_DELAY_MS = 30_000; */ const EMPTY_COMPLETION_MAX_RETRIES = 2; +const buildBotAgentGroupContext = (params: { + agentConfig?: any; + agentId?: string; + botContext?: unknown; +}): AgentGroupConfig | undefined => { + if (!params.botContext || !params.agentId) return undefined; + + const title = params.agentConfig?.title; + const description = params.agentConfig?.description; + const name = typeof title === 'string' && title.trim() ? title.trim() : 'Current Agent'; + + return { + agentMap: { + [params.agentId]: { + name, + role: 'participant', + }, + }, + currentAgentId: params.agentId, + currentAgentName: name, + currentAgentRole: 'participant', + members: [ + { + id: params.agentId, + name, + role: 'participant', + }, + ], + systemPrompt: typeof description === 'string' ? description : undefined, + }; +}; + /** * Output-token count at or below this — combined with no content, reasoning, * tool calls, or images — marks a turn as an empty completion. @@ -193,6 +226,7 @@ const archiveRuntimeToolResult = async ( toolCallId, topicId, userId, + workspaceId, }: { agentId?: string | null; identifier?: string; @@ -201,6 +235,7 @@ const archiveRuntimeToolResult = async ( toolCallId?: string; topicId?: string | null; userId?: string; + workspaceId?: string; }, ): Promise => { const archive = await archiveToolResultIfNeeded({ @@ -212,6 +247,7 @@ const archiveRuntimeToolResult = async ( toolCallId, topicId, userId, + workspaceId, }); return archive.content === result.content ? result : { ...result, content: archive.content }; @@ -226,11 +262,13 @@ const archiveRuntimeToolResult = async ( // FileService is constructed lazily so environments without S3 config (unit // tests) don't fail at context-build time; failure returns undefined, which // leaves URLs as raw keys — same behavior as before this helper existed. -const buildPostProcessUrl = (ctx: Pick) => { +const buildPostProcessUrl = ( + ctx: Pick, +) => { if (!ctx.userId || !ctx.serverDB) return undefined; let fileService: FileService | undefined; try { - fileService = new FileService(ctx.serverDB, ctx.userId); + fileService = new FileService(ctx.serverDB, ctx.userId, ctx.workspaceId); } catch { return undefined; } @@ -433,6 +471,7 @@ const buildToolDiscoveryConfig = (operationToolSet: OperationToolSet, enabledToo export interface RuntimeExecutorContext { agentConfig?: any; + botContext?: unknown; botPlatformContext?: BotPlatformContext; discordContext?: any; evalContext?: EvalContext; @@ -466,6 +505,13 @@ export interface RuntimeExecutorContext { tracingContextEngine?: (input: unknown, output: unknown) => void; userId?: string; userTimezone?: string; + /** + * Workspace scoping for ownership filters on models/services constructed + * inside the agent runtime. Threaded down from the originating request + * (chat/task router) and forwarded to tool executions via + * `ToolExecutionContext.workspaceId`. + */ + workspaceId?: string; } export const createRuntimeExecutors = ( @@ -634,8 +680,8 @@ export const createRuntimeExecutors = ( ); if (!alreadyHasTopicRefs && ctx.serverDB && ctx.userId) { - const topicModel = new TopicModel(ctx.serverDB, ctx.userId); - const messageModel = new MessageModelClass(ctx.serverDB, ctx.userId); + const topicModel = new TopicModel(ctx.serverDB, ctx.userId, ctx.workspaceId); + const messageModel = new MessageModelClass(ctx.serverDB, ctx.userId, ctx.workspaceId); topicReferences = await resolveTopicReferences( llmPayload.messages as Array<{ content: string | unknown }>, async (topicId) => topicModel.findById(topicId), @@ -658,7 +704,11 @@ export const createRuntimeExecutors = ( const agentId = state.metadata?.agentId; if (agentId && ctx.serverDB && ctx.userId) { try { - const agentDocService = new AgentDocumentsService(ctx.serverDB, ctx.userId); + const agentDocService = new AgentDocumentsService( + ctx.serverDB, + ctx.userId, + state.metadata?.workspaceId ?? ctx.workspaceId, + ); const docs = await agentDocService.getAgentContextDocuments(agentId); if (docs.length > 0) { agentDocuments = toAgentContextDocuments(docs); @@ -692,7 +742,11 @@ export const createRuntimeExecutors = ( await import('@lobechat/builtin-tool-web-onboarding/utils'); const { UserPersonaModel } = await import('@/database/models/userMemory/persona'); const onboardingService = new OnboardingService(ctx.serverDB, ctx.userId); - const docService = new AgentDocumentsService(ctx.serverDB, ctx.userId); + const docService = new AgentDocumentsService( + ctx.serverDB, + ctx.userId, + state.metadata?.workspaceId ?? ctx.workspaceId, + ); const personaModel = new UserPersonaModel(ctx.serverDB, ctx.userId); const [onboardingState, soulDoc, persona, userInfo] = await Promise.all([ @@ -752,7 +806,7 @@ export const createRuntimeExecutors = ( let lobehubSkillTopicTitle = ''; if (lobehubSkillTopicId && ctx.serverDB && ctx.userId) { try { - const topicModelForLobehub = new TopicModel(ctx.serverDB, ctx.userId); + const topicModelForLobehub = new TopicModel(ctx.serverDB, ctx.userId, ctx.workspaceId); const topicRecord = await topicModelForLobehub.findById(lobehubSkillTopicId); lobehubSkillTopicTitle = topicRecord?.title ?? ''; } catch (error) { @@ -853,7 +907,7 @@ export const createRuntimeExecutors = ( if (ctx.serverDB && ctx.userId && !!klavisEnv.KLAVIS_API_KEY) { try { const { PluginModel } = await import('@/database/models/plugin'); - const pluginModel = new PluginModel(ctx.serverDB, ctx.userId); + const pluginModel = new PluginModel(ctx.serverDB, ctx.userId, ctx.workspaceId); const allPlugins = await pluginModel.query(); const validKlavisIds = new Set(KLAVIS_SERVER_TYPES.map((t) => t.identifier)); const connectedIds = new Set( @@ -887,6 +941,11 @@ export const createRuntimeExecutors = ( const contextEngineInput = { agentDocuments, + agentGroup: buildBotAgentGroupContext({ + agentConfig, + agentId: state.metadata?.agentId, + botContext: state.metadata?.botContext ?? ctx.botContext, + }), additionalVariables: { ...state.metadata?.deviceSystemInfo, ...lobehubSkillVariables, @@ -1034,7 +1093,12 @@ export const createRuntimeExecutors = ( } // Initialize ModelRuntime (read user's keyVaults from database) - const modelRuntime = await initModelRuntimeFromDB(ctx.serverDB, ctx.userId!, provider); + const modelRuntime = await initModelRuntimeFromDB( + ctx.serverDB, + ctx.userId!, + provider, + ctx.workspaceId, + ); // Construct ChatStreamPayload const stream = ctx.stream ?? true; @@ -1874,7 +1938,11 @@ export const createRuntimeExecutors = ( } const latestAssistantMessage = dbMessages.findLast((message) => message.role === 'assistant'); - const messageService = new MessageService(ctx.serverDB, ctx.userId); + const messageService = new MessageService( + ctx.serverDB, + ctx.userId, + state.metadata?.workspaceId ?? ctx.workspaceId, + ); const compressionResult = await messageService.createCompressionGroup(topicId, messageIds, { agentId: state.metadata?.agentId, threadId: state.metadata?.threadId, @@ -1911,6 +1979,7 @@ export const createRuntimeExecutors = ( ctx.serverDB, ctx.userId, compressionModel.provider, + ctx.workspaceId, ); let summaryContent = ''; @@ -2315,6 +2384,7 @@ export const createRuntimeExecutors = ( toolResultMaxLength, topicId: ctx.topicId, userId: ctx.userId, + workspaceId: state.metadata?.workspaceId ?? ctx.workspaceId, }), { isInterrupted: () => isOperationInterrupted(ctx), @@ -2369,6 +2439,7 @@ export const createRuntimeExecutors = ( toolCallId: chatToolPayload.id, topicId: ctx.topicId ?? state.metadata?.topicId, userId: ctx.userId, + workspaceId: state.metadata?.workspaceId ?? ctx.workspaceId, }); const executionTime = executionResult.executionTime; const isSuccess = executionResult.success; @@ -2884,6 +2955,7 @@ export const createRuntimeExecutors = ( toolResultMaxLength: batchAgentConfig?.chatConfig?.toolResultMaxLength, topicId: ctx.topicId, userId: ctx.userId, + workspaceId: state.metadata?.workspaceId ?? ctx.workspaceId, }), { isInterrupted: () => isOperationInterrupted(ctx), @@ -2914,6 +2986,7 @@ export const createRuntimeExecutors = ( toolCallId: chatToolPayload.id, topicId: ctx.topicId ?? state.metadata?.topicId, userId: ctx.userId, + workspaceId: state.metadata?.workspaceId ?? ctx.workspaceId, }); const executionTime = executionResult.executionTime; const isSuccess = executionResult.success; @@ -3437,7 +3510,7 @@ export const createRuntimeExecutors = ( // Clear runningOperation from topic metadata so reconnect doesn't trigger after completion if (ctx.topicId && ctx.userId) { try { - const topicModel = new TopicModel(ctx.serverDB, ctx.userId); + const topicModel = new TopicModel(ctx.serverDB, ctx.userId, ctx.workspaceId); await topicModel.updateMetadata(ctx.topicId, { runningOperation: null }); } catch (e) { log('[%s] Failed to clear runningOperation metadata: %O', operationId, e); diff --git a/src/server/modules/AgentRuntime/__tests__/RuntimeExecutors.test.ts b/src/server/modules/AgentRuntime/__tests__/RuntimeExecutors.test.ts index c963a14f70..ef51d7da5d 100644 --- a/src/server/modules/AgentRuntime/__tests__/RuntimeExecutors.test.ts +++ b/src/server/modules/AgentRuntime/__tests__/RuntimeExecutors.test.ts @@ -237,6 +237,31 @@ describe('RuntimeExecutors', () => { ); }); + it('passes workspaceId to model runtime initialization', async () => { + const workspaceCtx = { ...ctx, workspaceId: 'ws-1' }; + const executors = createRuntimeExecutors(workspaceCtx); + const state = createMockState(); + + const instruction = { + payload: { + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-4', + provider: 'openai', + tools: [], + }, + type: 'call_llm' as const, + }; + + await executors.call_llm!(instruction, state); + + expect(initModelRuntimeFromDB).toHaveBeenCalledWith( + workspaceCtx.serverDB, + 'user-123', + 'openai', + 'ws-1', + ); + }); + it('should pass parentId from payload.parentMessageId to messageModel.create', async () => { const executors = createRuntimeExecutors(ctx); const state = createMockState(); @@ -1449,6 +1474,68 @@ describe('RuntimeExecutors', () => { expect(engineSpy).toHaveBeenCalledWith(expect.objectContaining({ evalContext })); }); + it('should inject current agent identity for bot-originated runs', async () => { + const ctxWithConfig: RuntimeExecutorContext = { + ...ctx, + agentConfig: { + description: 'Answers customer support questions.', + plugins: [], + systemRole: 'test', + title: 'Support Bot', + }, + botContext: { + applicationId: 'discord-app', + isOwner: true, + platform: 'discord', + platformThreadId: 'discord:channel-1', + senderExternalUserId: 'user-platform-id', + }, + }; + const executors = createRuntimeExecutors(ctxWithConfig); + const state = createMockState({ + metadata: { + agentId: 'agent-support', + botContext: ctxWithConfig.botContext, + topicId: 'topic-123', + }, + }); + + const instruction = { + payload: { + messages: [{ content: 'Hello', role: 'user' }], + model: 'gpt-4', + provider: 'openai', + }, + type: 'call_llm' as const, + }; + + await executors.call_llm!(instruction, state); + + expect(engineSpy).toHaveBeenCalledWith( + expect.objectContaining({ + agentGroup: { + agentMap: { + 'agent-support': { + name: 'Support Bot', + role: 'participant', + }, + }, + currentAgentId: 'agent-support', + currentAgentName: 'Support Bot', + currentAgentRole: 'participant', + members: [ + { + id: 'agent-support', + name: 'Support Bot', + role: 'participant', + }, + ], + systemPrompt: 'Answers customer support questions.', + }, + }), + ); + }); + it('should build capabilities from LOBE_DEFAULT_MODEL_LIST', async () => { const ctxWithConfig: RuntimeExecutorContext = { ...ctx, diff --git a/src/server/modules/AgentRuntime/types.ts b/src/server/modules/AgentRuntime/types.ts index ca971abbdb..60117d3c66 100644 --- a/src/server/modules/AgentRuntime/types.ts +++ b/src/server/modules/AgentRuntime/types.ts @@ -38,6 +38,7 @@ export interface IAgentStateManager { agentConfig?: any; modelRuntimeConfig?: any; userId?: string; + workspaceId?: string; }, ) => Promise; diff --git a/src/server/modules/Mecha/ContextEngineering/index.ts b/src/server/modules/Mecha/ContextEngineering/index.ts index ef7d14b7e3..606b5f1ca8 100644 --- a/src/server/modules/Mecha/ContextEngineering/index.ts +++ b/src/server/modules/Mecha/ContextEngineering/index.ts @@ -69,6 +69,7 @@ export const serverMessagesEngine = async ({ capabilities, userMemory, agentBuilderContext, + agentGroup, botPlatformContext, discordContext, evalContext, @@ -161,6 +162,7 @@ export const serverMessagesEngine = async ({ // Extended contexts ...(agentBuilderContext && { agentBuilderContext }), + ...(agentGroup && { agentGroup }), ...(botPlatformContext && { botPlatformContext }), ...(discordContext && { discordContext }), ...(evalContext && { evalContext }), diff --git a/src/server/modules/Mecha/ContextEngineering/types.ts b/src/server/modules/Mecha/ContextEngineering/types.ts index 16731b9a09..755125166a 100644 --- a/src/server/modules/Mecha/ContextEngineering/types.ts +++ b/src/server/modules/Mecha/ContextEngineering/types.ts @@ -2,6 +2,7 @@ import type { AgentBuilderContext, AgentContextDocument, + AgentGroupConfig, AgentManagementContext, BotPlatformContext, DiscordContext, @@ -78,6 +79,8 @@ export interface ServerMessagesEngineParams { // ========== Extended contexts ========== /** Agent Builder context (optional, for editing agents) */ agentBuilderContext?: AgentBuilderContext; + /** Agent identity context for bot/group-originated runs */ + agentGroup?: AgentGroupConfig; /** Agent Management context (optional, available models and plugins) */ agentManagementContext?: AgentManagementContext; // ========== Capability injection ========== diff --git a/src/server/modules/ModelRuntime/index.ts b/src/server/modules/ModelRuntime/index.ts index 71a1f2cb78..853cd569f8 100644 --- a/src/server/modules/ModelRuntime/index.ts +++ b/src/server/modules/ModelRuntime/index.ts @@ -404,8 +404,11 @@ export const initModelRuntimeFromDB = async ( db: LobeChatDatabase, userId: string, provider: string, + workspaceId?: string, ): Promise => { // 1. Get user's provider configuration from database + // NOTE: workspace-scoped ai_infra is deferred until the ai_infra surrogate-`_id` + // PK migration (LOBE-10056) lands; AiProviderModel stays personal-scoped for now. const aiProviderModel = new AiProviderModel(db, userId); // Use getAiProviderById with KeyVaultsGateKeeper.getUserKeyVaults as decryptor @@ -425,11 +428,11 @@ export const initModelRuntimeFromDB = async ( const payload = buildPayloadFromKeyVaults(keyVaults, runtimeProvider); // 4. Get business hooks (billing in cloud, undefined in OSS) - const businessHooks = getBusinessModelRuntimeHooks(userId, provider); + const businessHooks = getBusinessModelRuntimeHooks(userId, provider, workspaceId); // 5. Compose with the per-call llm_generation_tracing hook (no-op when the // service is unconfigured, so OSS / self-hosted setups pay nothing for it). - const tracingHooks = createLLMGenerationTracingHook(userId, provider); + const tracingHooks = createLLMGenerationTracingHook(userId, provider, workspaceId); const hooks = mergeModelRuntimeHooks(businessHooks, tracingHooks); // 6. Initialize ModelRuntime with the payload and hooks diff --git a/src/server/routers/async/__tests__/file.test.ts b/src/server/routers/async/__tests__/file.test.ts index 1eb58bea71..b134d72086 100644 --- a/src/server/routers/async/__tests__/file.test.ts +++ b/src/server/routers/async/__tests__/file.test.ts @@ -65,6 +65,7 @@ describe('fileRouter.parseFileToChunks — NoSuchKey + internal:// branches', () vi.mocked(DocumentService).mockImplementation(() => documentServiceMock); vi.mocked(ChunkModel).mockImplementation(() => chunkModelMock); vi.mocked(EmbeddingModel).mockImplementation(() => ({}) as any); + Reflect.set(FileModel, 'getFileById', undefined); mockCtx = { serverDB: {}, userId }; }); @@ -117,6 +118,36 @@ describe('fileRouter.parseFileToChunks — NoSuchKey + internal:// branches', () expect(result).toMatchObject({ success: false }); }); + it('resolves workspaceId from the file row when async payload omits it', async () => { + Reflect.set( + FileModel, + 'getFileById', + vi.fn().mockResolvedValue({ + id: 'file_workspace', + userId, + workspaceId: 'workspace-1', + }), + ); + fileModelMock.findById.mockResolvedValue({ + id: 'file_workspace', + name: 'workspace note', + url: 'internal://document/placeholder', + }); + + const caller = fileRouter.createCaller(mockCtx); + + await caller.parseFileToChunks({ + fileId: 'file_workspace', + taskId: 'task_workspace', + }); + + expect(AsyncTaskModel).toHaveBeenCalledWith(mockCtx.serverDB, userId, 'workspace-1'); + expect(ChunkModel).toHaveBeenCalledWith(mockCtx.serverDB, userId, 'workspace-1'); + expect(ChunkService).toHaveBeenCalledWith(mockCtx.serverDB, userId, 'workspace-1'); + expect(DocumentService).toHaveBeenCalledWith(mockCtx.serverDB, userId, 'workspace-1'); + expect(FileModel).toHaveBeenCalledWith(mockCtx.serverDB, userId, 'workspace-1'); + }); + it('marks task Error and propagates for non-NoSuchKey storage errors (does not delete)', async () => { fileModelMock.findById.mockResolvedValue({ id: 'file_other', diff --git a/src/server/routers/async/__tests__/ragEval.test.ts b/src/server/routers/async/__tests__/ragEval.test.ts new file mode 100644 index 0000000000..e7e1dc4651 --- /dev/null +++ b/src/server/routers/async/__tests__/ragEval.test.ts @@ -0,0 +1,72 @@ +// @vitest-environment node +import { TRPCError } from '@trpc/server'; +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { + EvalDatasetRecordModel, + EvalEvaluationModel, + EvaluationRecordModel, +} from '@/database/models/ragEval'; + +import { ragEvalRouter } from '../ragEval'; + +vi.mock('@/database/models/chunk', () => ({ + ChunkModel: vi.fn(() => ({})), +})); +vi.mock('@/database/models/embedding', () => ({ + EmbeddingModel: vi.fn(() => ({})), +})); +vi.mock('@/database/models/file', () => ({ + FileModel: vi.fn(() => ({})), +})); +vi.mock('@/database/models/ragEval', () => ({ + EvalDatasetRecordModel: vi.fn(() => ({ findById: vi.fn() })), + EvalEvaluationModel: vi.fn(() => ({ update: vi.fn() })), + EvaluationRecordModel: vi.fn(() => ({ findById: vi.fn().mockResolvedValue(null) })), +})); +vi.mock('@/server/modules/ModelRuntime', () => ({ + initModelRuntimeFromDB: vi.fn(), +})); +vi.mock('@/server/services/chunk', () => ({ + ChunkService: vi.fn(() => ({})), +})); + +vi.mock('@/libs/trpc/async', async () => { + const init = await vi.importActual<{ asyncTrpc: any }>('@/libs/trpc/async/init'); + const { asyncTrpc } = init; + return { + asyncAuthedProcedure: asyncTrpc.procedure, + asyncRouter: asyncTrpc.router, + createAsyncCallerFactory: asyncTrpc.createCallerFactory, + publicProcedure: asyncTrpc.procedure, + }; +}); + +describe('ragEvalRouter.runRecordEvaluation', () => { + const userId = 'user_test'; + const serverDB = { + select: vi.fn(() => ({ + from: vi.fn(() => ({ + where: vi.fn(() => ({ + limit: vi.fn().mockResolvedValue([{ workspaceId: 'workspace-1' }]), + })), + })), + })), + }; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it('resolves workspaceId from the evaluation record before reading scoped models', async () => { + const caller = ragEvalRouter.createCaller({ serverDB, userId } as any); + + await expect(caller.runRecordEvaluation({ evalRecordId: 'eval-record-1' })).rejects.toThrow( + TRPCError, + ); + + expect(EvaluationRecordModel).toHaveBeenCalledWith(serverDB, userId, 'workspace-1'); + expect(EvalEvaluationModel).toHaveBeenCalledWith(serverDB, userId, 'workspace-1'); + expect(EvalDatasetRecordModel).toHaveBeenCalledWith(serverDB, userId, 'workspace-1'); + }); +}); diff --git a/src/server/routers/async/file.ts b/src/server/routers/async/file.ts index 365c68c43e..7994e75f99 100644 --- a/src/server/routers/async/file.ts +++ b/src/server/routers/async/file.ts @@ -12,6 +12,7 @@ import { ChunkModel } from '@/database/models/chunk'; import { EmbeddingModel } from '@/database/models/embedding'; import { FileModel } from '@/database/models/file'; import { type NewChunkItem, type NewEmbeddingsItem } from '@/database/schemas'; +import type { LobeChatDatabase } from '@/database/type'; import { fileEnv } from '@/envs/file'; import { asyncAuthedProcedure, asyncRouter as router } from '@/libs/trpc/async'; import { getServerDefaultFilesConfig } from '@/server/globalConfig'; @@ -40,6 +41,22 @@ const fileProcedure = asyncAuthedProcedure.use(async (opts) => { }); }); +const resolveWorkspaceIdFromFile = async ( + serverDB: LobeChatDatabase, + userId: string, + fileId: string, + workspaceId?: string, +) => { + if (workspaceId) return workspaceId; + + if (typeof FileModel.getFileById !== 'function') return undefined; + + const file = await FileModel.getFileById(serverDB, fileId); + if (!file || file.userId !== userId) return undefined; + + return file.workspaceId ?? undefined; +}; + export const fileRouter = router({ embeddingChunks: fileProcedure .use(checkEmbeddingUsage) @@ -47,16 +64,27 @@ export const fileRouter = router({ z.object({ fileId: z.string(), taskId: z.string(), + workspaceId: z.string().optional(), }), ) .mutation(async ({ ctx, input }) => { - const file = await ctx.fileModel.findById(input.fileId); + const workspaceId = await resolveWorkspaceIdFromFile( + ctx.serverDB, + ctx.userId, + input.fileId, + input.workspaceId, + ); + const asyncTaskModel = new AsyncTaskModel(ctx.serverDB, ctx.userId, workspaceId); + const chunkModel = new ChunkModel(ctx.serverDB, ctx.userId, workspaceId); + const embeddingModel = new EmbeddingModel(ctx.serverDB, ctx.userId, workspaceId); + const fileModel = new FileModel(ctx.serverDB, ctx.userId, workspaceId); + const file = await fileModel.findById(input.fileId); if (!file) { throw new TRPCError({ code: 'BAD_REQUEST', message: 'File not found' }); } - const asyncTask = await ctx.asyncTaskModel.findById(input.taskId); + const asyncTask = await asyncTaskModel.findById(input.taskId); const { model, provider } = getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM; @@ -77,7 +105,7 @@ export const fileRouter = router({ const embeddingPromise = async () => { // update the task status to success - await ctx.asyncTaskModel.update(input.taskId, { + await asyncTaskModel.update(input.taskId, { status: AsyncTaskStatus.Processing, }); @@ -86,7 +114,7 @@ export const fileRouter = router({ const CHUNK_SIZE = fileEnv.EMBEDDING_BATCH_SIZE; const CONCURRENCY = fileEnv.EMBEDDING_CONCURRENCY; - const chunks = await ctx.chunkModel.getChunksTextByFileId(input.fileId); + const chunks = await chunkModel.getChunksTextByFileId(input.fileId); const requestArray = chunk(chunks, CHUNK_SIZE); try { await pMap( @@ -97,6 +125,7 @@ export const fileRouter = router({ ctx.serverDB, ctx.userId, provider, + workspaceId, ); const embeddings = await modelRuntime.embeddings( @@ -116,7 +145,7 @@ export const fileRouter = router({ model, })) || []; - await ctx.embeddingModel.bulkCreate(items); + await embeddingModel.bulkCreate(items); }, { concurrency: CONCURRENCY }, ); @@ -129,7 +158,7 @@ export const fileRouter = router({ const duration = Date.now() - startAt; // update the task status to success - await ctx.asyncTaskModel.update(input.taskId, { + await asyncTaskModel.update(input.taskId, { duration, status: AsyncTaskStatus.Success, }); @@ -142,7 +171,7 @@ export const fileRouter = router({ } catch (e) { console.error('embeddingChunks error', e); - await ctx.asyncTaskModel.update(input.taskId, { + await asyncTaskModel.update(input.taskId, { error: new AsyncTaskError((e as Error).name, (e as Error).message), status: AsyncTaskStatus.Error, }); @@ -159,10 +188,23 @@ export const fileRouter = router({ z.object({ fileId: z.string(), taskId: z.string(), + workspaceId: z.string().optional(), }), ) .mutation(async ({ ctx, input }) => { - const file = await ctx.fileModel.findById(input.fileId); + const workspaceId = await resolveWorkspaceIdFromFile( + ctx.serverDB, + ctx.userId, + input.fileId, + input.workspaceId, + ); + const asyncTaskModel = new AsyncTaskModel(ctx.serverDB, ctx.userId, workspaceId); + const chunkModel = new ChunkModel(ctx.serverDB, ctx.userId, workspaceId); + const chunkService = new ChunkService(ctx.serverDB, ctx.userId, workspaceId); + const documentService = new DocumentService(ctx.serverDB, ctx.userId, workspaceId); + const fileModel = new FileModel(ctx.serverDB, ctx.userId, workspaceId); + const fileService = new FileService(ctx.serverDB, ctx.userId, workspaceId); + const file = await fileModel.findById(input.fileId); if (!file) { throw new TRPCError({ code: 'BAD_REQUEST', message: 'File not found' }); } @@ -171,7 +213,7 @@ export const fileRouter = router({ // `internal://document/placeholder` marker. Their content lives on documents.content // and is intentionally not chunked — searching is handled by BM25 instead. if (file.url.startsWith('internal://')) { - await ctx.asyncTaskModel.update(input.taskId, { + await asyncTaskModel.update(input.taskId, { error: new AsyncTaskError( AsyncTaskErrorType.TaskTriggerError, 'Inline documents (custom/document) do not require chunking; content is searched via BM25.', @@ -186,7 +228,7 @@ export const fileRouter = router({ let content: Uint8Array | undefined; try { - content = await ctx.fileService.getFileByteArray(file.url); + content = await fileService.getFileByteArray(file.url); } catch (e) { console.error(e); const errorCode = (e as any).Code; @@ -195,7 +237,7 @@ export const fileRouter = router({ // into destroying chunks/embeddings/documents. Mark the task as Error // so users see a clear message and can re-upload or retry. if (errorCode === 'NoSuchKey') { - await ctx.asyncTaskModel.update(input.taskId, { + await asyncTaskModel.update(input.taskId, { error: new AsyncTaskError( AsyncTaskErrorType.TaskTriggerError, 'File content unavailable in storage. Verify storage access or re-upload.', @@ -209,7 +251,7 @@ export const fileRouter = router({ } // Other fetch errors (network, IAM, etc.) — mark the task as Error so // the user surface stays consistent, then propagate. - await ctx.asyncTaskModel.update(input.taskId, { + await asyncTaskModel.update(input.taskId, { error: new AsyncTaskError( AsyncTaskErrorType.TaskTriggerError, `Failed to fetch file content: ${(e as Error)?.message ?? errorCode ?? 'unknown error'}`, @@ -221,7 +263,7 @@ export const fileRouter = router({ if (!content) return; - const asyncTask = await ctx.asyncTaskModel.findById(input.taskId); + const asyncTask = await asyncTaskModel.findById(input.taskId); if (!asyncTask) throw new TRPCError({ code: 'BAD_REQUEST', message: 'Async Task not found' }); @@ -240,13 +282,12 @@ export const fileRouter = router({ }); const chunkingPromise = async () => { - const chunkService = ctx.chunkService; // update the task status to processing - await ctx.asyncTaskModel.update(input.taskId, { status: AsyncTaskStatus.Processing }); + await asyncTaskModel.update(input.taskId, { status: AsyncTaskStatus.Processing }); // parse file to document record first (for detailed content viewing) try { - await ctx.documentService.parseFile(input.fileId); + await documentService.parseFile(input.fileId); } catch (e) { // document parsing failure should not block chunking console.warn( @@ -268,6 +309,7 @@ export const fileRouter = router({ ...item, text: text ? sanitizeUTF8(text) : '', userId: ctx.userId, + workspaceId, }), ); @@ -282,17 +324,22 @@ export const fileRouter = router({ }; } - await ctx.chunkModel.bulkCreate(chunks, input.fileId); + await chunkModel.bulkCreate(chunks, input.fileId); if (chunkResult.unstructuredChunks) { const unstructuredChunks = chunkResult.unstructuredChunks.map( - (item): NewChunkItem => ({ ...item, fileId: input.fileId, userId: ctx.userId }), + (item): NewChunkItem => ({ + ...item, + fileId: input.fileId, + userId: ctx.userId, + workspaceId, + }), ); - await ctx.chunkModel.bulkCreateUnstructuredChunks(unstructuredChunks); + await chunkModel.bulkCreateUnstructuredChunks(unstructuredChunks); } // update the task status to success - await ctx.asyncTaskModel.update(input.taskId, { + await asyncTaskModel.update(input.taskId, { duration, status: AsyncTaskStatus.Success, }); @@ -314,7 +361,7 @@ export const fileRouter = router({ : new AsyncTaskError((error as Error).name, error.message); console.error('[Chunking Error]', asyncTaskError); - await ctx.asyncTaskModel.update(input.taskId, { + await asyncTaskModel.update(input.taskId, { error: asyncTaskError, status: AsyncTaskStatus.Error, }); diff --git a/src/server/routers/async/image.ts b/src/server/routers/async/image.ts index 3b6a3200bd..19bae6f23d 100644 --- a/src/server/routers/async/image.ts +++ b/src/server/routers/async/image.ts @@ -61,6 +61,7 @@ const createImageInputSchema = z.object({ .passthrough(), provider: z.string(), taskId: z.string(), + workspaceId: z.string().optional(), }); /** @@ -85,7 +86,12 @@ export const imageRouter = router({ provider, model, params, + workspaceId, } = input; + const asyncTaskModel = new AsyncTaskModel(ctx.serverDB, ctx.userId, workspaceId); + const generationBatchModel = new GenerationBatchModel(ctx.serverDB, ctx.userId, workspaceId); + const generationModel = new GenerationModel(ctx.serverDB, ctx.userId, workspaceId); + const generationService = new GenerationService(ctx.serverDB, ctx.userId, workspaceId); log('Starting async image generation: %O', { generationId, @@ -102,14 +108,14 @@ export const imageRouter = router({ }); // Check if generationBatch exists before processing - const generationBatch = await ctx.generationBatchModel.findById(generationBatchId); + const generationBatch = await generationBatchModel.findById(generationBatchId); if (!generationBatch) { log('Generation batch not found: %s, skipping image generation', generationBatchId); throw new TRPCError({ code: 'FORBIDDEN', message: 'Invalid Request!' }); } log('Updating task status to Processing: %s', taskId); - await ctx.asyncTaskModel.update(taskId, { status: AsyncTaskStatus.Processing }); + await asyncTaskModel.update(taskId, { status: AsyncTaskStatus.Processing }); // Use AbortController to prevent resource leaks const abortController = new AbortController(); @@ -128,7 +134,12 @@ export const imageRouter = router({ ); // Read user's provider config from database - const modelRuntime = await initModelRuntimeFromDB(ctx.serverDB, ctx.userId, provider); + const modelRuntime = await initModelRuntimeFromDB( + ctx.serverDB, + ctx.userId, + provider, + workspaceId, + ); // Check if operation has been cancelled checkAbortSignal(signal); @@ -189,7 +200,7 @@ export const imageRouter = router({ } } - const { image, thumbnailImage } = await ctx.generationService.transformImageForGeneration( + const { image, thumbnailImage } = await generationService.transformImageForGeneration( imageUrl, authHeaders, ); @@ -199,13 +210,13 @@ export const imageRouter = router({ log('Uploading image for generation'); const { imageUrl: uploadedImageUrl, thumbnailImageUrl } = - await ctx.generationService.uploadImageForGeneration(image, thumbnailImage); + await generationService.uploadImageForGeneration(image, thumbnailImage); // Check if operation has been cancelled checkAbortSignal(signal); log('Updating generation asset and file'); - await ctx.generationModel.createAssetAndFile( + await generationModel.createAssetAndFile( generationId, { height: height ?? image.height, @@ -234,7 +245,7 @@ export const imageRouter = router({ const duration = Date.now() - generationBatch.createdAt.getTime(); log('Updating task status to Success: %s, duration: %dms', taskId, duration); - await ctx.asyncTaskModel.update(taskId, { + await asyncTaskModel.update(taskId, { duration, status: AsyncTaskStatus.Success, }); @@ -316,7 +327,7 @@ export const imageRouter = router({ providerContentPolicyMessage, }); - await ctx.asyncTaskModel.update(taskId, { + await asyncTaskModel.update(taskId, { error: new AsyncTaskError(errorType, errorMessage), status: AsyncTaskStatus.Error, }); diff --git a/src/server/routers/async/ragEval.ts b/src/server/routers/async/ragEval.ts index 4a034ee328..3945bf9150 100644 --- a/src/server/routers/async/ragEval.ts +++ b/src/server/routers/async/ragEval.ts @@ -1,6 +1,7 @@ import { chainAnswerWithContext } from '@lobechat/prompts'; import { EvalEvaluationStatus, RequestTrigger } from '@lobechat/types'; import { TRPCError } from '@trpc/server'; +import { eq } from 'drizzle-orm'; import { ModelProvider } from 'model-bank'; import type OpenAI from 'openai'; import { z } from 'zod'; @@ -14,6 +15,7 @@ import { EvalEvaluationModel, EvaluationRecordModel, } from '@/database/models/ragEval'; +import { evaluationRecords } from '@/database/schemas'; import { asyncAuthedProcedure, asyncRouter as router } from '@/libs/trpc/async'; import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime'; import { ChunkService } from '@/server/services/chunk'; @@ -43,7 +45,22 @@ export const ragEvalRouter = router({ }), ) .mutation(async ({ ctx, input }) => { - const evalRecord = await ctx.evalRecordModel.findById(input.evalRecordId); + // System-level async dispatch: resolve workspace from the eval-record row + // and re-instantiate models so ownership-filtered reads/writes match the + // record's workspace (the procedure middleware defaults to personal mode). + const [rawRow] = await ctx.serverDB + .select({ workspaceId: evaluationRecords.workspaceId }) + .from(evaluationRecords) + .where(eq(evaluationRecords.id, input.evalRecordId)) + .limit(1); + const wsId = rawRow?.workspaceId ?? undefined; + const evalRecordModel = new EvaluationRecordModel(ctx.serverDB, ctx.userId, wsId); + const evaluationModel = new EvalEvaluationModel(ctx.serverDB, ctx.userId, wsId); + const datasetRecordModel = new EvalDatasetRecordModel(ctx.serverDB, ctx.userId, wsId); + const scopedEmbeddingModel = new EmbeddingModel(ctx.serverDB, ctx.userId, wsId); + const scopedChunkModel = new ChunkModel(ctx.serverDB, ctx.userId, wsId); + + const evalRecord = await evalRecordModel.findById(input.evalRecordId); if (!evalRecord) { throw new TRPCError({ code: 'BAD_REQUEST', message: 'Evaluation not found' }); @@ -56,6 +73,7 @@ export const ragEvalRouter = router({ ctx.serverDB, ctx.userId, ModelProvider.OpenAI, + wsId, ); const { question, languageModel, embeddingModel } = evalRecord; @@ -74,12 +92,12 @@ export const ragEvalRouter = router({ { metadata: { trigger: RequestTrigger.Eval }, user: ctx.userId }, ); - const embeddingId = await ctx.embeddingModel.create({ + const embeddingId = await scopedEmbeddingModel.create({ embeddings: embeddings?.[0], model: embeddingModel, }); - await ctx.evalRecordModel.update(evalRecord.id, { + await evalRecordModel.update(evalRecord.id, { questionEmbeddingId: embeddingId, }); @@ -88,18 +106,18 @@ export const ragEvalRouter = router({ // If context does not exist, perform a retrieval if (!context || context.length === 0) { - const datasetRecord = await ctx.datasetRecordModel.findById(evalRecord.datasetRecordId); + const datasetRecord = await datasetRecordModel.findById(evalRecord.datasetRecordId); - const embeddingItem = await ctx.embeddingModel.findById(questionEmbeddingId); + const embeddingItem = await scopedEmbeddingModel.findById(questionEmbeddingId); - const chunks = await ctx.chunkModel.semanticSearchForChat({ + const chunks = await scopedChunkModel.semanticSearchForChat({ embedding: embeddingItem!.embeddings!, fileIds: datasetRecord!.referenceFiles!, query: evalRecord.question, }); context = chunks.map((item) => item.text).filter(Boolean) as string[]; - await ctx.evalRecordModel.update(evalRecord.id, { context }); + await evalRecordModel.update(evalRecord.id, { context }); } // Generate LLM answer @@ -120,7 +138,7 @@ export const ragEvalRouter = router({ const answer = data.choices[0].message.content; - await ctx.evalRecordModel.update(input.evalRecordId, { + await evalRecordModel.update(input.evalRecordId, { answer, duration: Date.now() - now, languageModel, @@ -129,12 +147,12 @@ export const ragEvalRouter = router({ return { success: true }; } catch (e) { - await ctx.evalRecordModel.update(input.evalRecordId, { + await evalRecordModel.update(input.evalRecordId, { error: new AsyncTaskError((e as Error).name, (e as Error).message), status: EvalEvaluationStatus.Error, }); - await ctx.evaluationModel.update(evalRecord.evaluationId, { + await evaluationModel.update(evalRecord.evaluationId, { status: EvalEvaluationStatus.Error, }); diff --git a/src/server/routers/async/video.ts b/src/server/routers/async/video.ts index 77bf7f1220..75511b2dcf 100644 --- a/src/server/routers/async/video.ts +++ b/src/server/routers/async/video.ts @@ -47,6 +47,7 @@ const createVideoInputSchema = z.object({ model: z.string(), prechargeResult: z.any().optional(), provider: z.string(), + workspaceId: z.string().optional(), }); const checkAbortSignal = (signal: AbortSignal) => { @@ -133,7 +134,11 @@ export const videoRouter = router({ model, prechargeResult, provider, + workspaceId, } = input; + const asyncTaskModel = new AsyncTaskModel(ctx.serverDB, ctx.userId, workspaceId); + const generationModel = new GenerationModel(ctx.serverDB, ctx.userId, workspaceId); + const videoService = new VideoGenerationService(ctx.serverDB, ctx.userId, workspaceId); log('Starting async video polling: %O', { asyncTaskId, @@ -150,12 +155,13 @@ export const videoRouter = router({ try { const pollingPromise = async (signal: AbortSignal) => { - const asyncTaskModel = ctx.asyncTaskModel; - const generationModel = ctx.generationModel; - const videoService = ctx.videoService; - log('Initializing agent runtime for provider: %s', provider); - const modelRuntime = await initModelRuntimeFromDB(ctx.serverDB, ctx.userId, provider); + const modelRuntime = await initModelRuntimeFromDB( + ctx.serverDB, + ctx.userId, + provider, + workspaceId, + ); checkAbortSignal(signal); @@ -276,7 +282,7 @@ export const videoRouter = router({ userId: ctx.userId, }); - await ctx.asyncTaskModel.update(asyncTaskId, { + await asyncTaskModel.update(asyncTaskId, { error: new AsyncTaskError( providerContentPolicyMessage ? AsyncTaskErrorType.ProviderContentModeration diff --git a/src/server/routers/lambda/__tests__/file.test.ts b/src/server/routers/lambda/__tests__/file.test.ts index 0021be0f2e..3efdd5c248 100644 --- a/src/server/routers/lambda/__tests__/file.test.ts +++ b/src/server/routers/lambda/__tests__/file.test.ts @@ -1,6 +1,7 @@ import { TRPCError } from '@trpc/server'; import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { KnowledgeRepo } from '@/database/repositories/knowledge'; import { fileRouter } from '@/server/routers/lambda/file'; import { AsyncTaskStatus } from '@/types/asyncTask'; @@ -11,7 +12,15 @@ const routerMocks = vi.hoisted(() => { return { businessFileUploadCheck: vi.fn(), + businessFileTransferStorageCheck: vi.fn(), serverDB: { + select: vi.fn(() => ({ + from: vi.fn(() => ({ + where: vi.fn(() => ({ + limit: vi.fn().mockResolvedValue([{ role: 'member' }]), + })), + })), + })), transaction: vi.fn(async (callback: (trx: unknown) => unknown) => callback(transactionClient), ), @@ -69,7 +78,15 @@ function createCallerWithCtx(partialCtx: any = {}) { }; const ctx = { - serverDB: {} as any, + serverDB: { + select: vi.fn(() => ({ + from: vi.fn(() => ({ + where: vi.fn(() => ({ + limit: vi.fn().mockResolvedValue([{ role: 'member' }]), + })), + })), + })), + } as any, userId: 'test-user', asyncTaskModel, chunkModel, @@ -101,6 +118,7 @@ vi.mock('@/database/core/db-adaptor', () => ({ })); vi.mock('@/business/server/lambda-routers/file', () => ({ + businessFileTransferStorageCheck: routerMocks.businessFileTransferStorageCheck, businessFileUploadCheck: routerMocks.businessFileUploadCheck, })); @@ -134,6 +152,8 @@ const mockFileModelFindByIds = vi.fn(); const mockFileModelQuery = vi.fn(); const mockFileModelUpdateGlobalFile = vi.fn(); const mockFileModelClear = vi.fn(); +const mockFileModelTransferTo = vi.fn(); +const mockFileModelCopyToWorkspace = vi.fn(); vi.mock('@/database/models/file', () => ({ FileModel: vi.fn(() => ({ @@ -146,6 +166,8 @@ vi.mock('@/database/models/file', () => ({ query: mockFileModelQuery, updateGlobalFile: mockFileModelUpdateGlobalFile, clear: mockFileModelClear, + copyToWorkspace: mockFileModelCopyToWorkspace, + transferTo: mockFileModelTransferTo, })), })); @@ -165,6 +187,10 @@ vi.mock('@/server/services/file', () => ({ const mockKnowledgeRepoQuery = vi.fn().mockResolvedValue([]); const mockDocumentServiceDeleteDocuments = vi.fn(); +const mockDocumentModelCountFileUsageInSubtree = vi.fn(); +const mockDocumentModelCopyToWorkspace = vi.fn(); +const mockDocumentModelFindById = vi.fn(); +const mockDocumentModelTransferTo = vi.fn(); vi.mock('@/database/repositories/knowledge', () => ({ KnowledgeRepo: vi.fn(() => ({ @@ -173,7 +199,12 @@ vi.mock('@/database/repositories/knowledge', () => ({ })); vi.mock('@/database/models/document', () => ({ - DocumentModel: vi.fn(() => ({})), + DocumentModel: vi.fn(() => ({ + countFileUsageInSubtree: mockDocumentModelCountFileUsageInSubtree, + copyToWorkspace: mockDocumentModelCopyToWorkspace, + findById: mockDocumentModelFindById, + transferTo: mockDocumentModelTransferTo, + })), })); vi.mock('@/server/services/document', () => ({ @@ -190,6 +221,7 @@ describe('fileRouter', () => { beforeEach(() => { vi.clearAllMocks(); routerMocks.businessFileUploadCheck.mockResolvedValue(undefined); + routerMocks.businessFileTransferStorageCheck.mockResolvedValue(undefined); mockFile = { id: 'test-id', @@ -396,6 +428,27 @@ describe('fileRouter', () => { ); }); + it('should pass workspace context into business upload check', async () => { + ({ caller } = createCallerWithCtx({ workspaceId: 'workspace-1' })); + mockFileModelCheckHash.mockResolvedValue({ isExist: false }); + mockFileModelCreate.mockResolvedValue({ id: 'new-file-id' }); + + await caller.createFile({ + hash: 'test-hash', + fileType: 'text', + name: 'test.txt', + size: 100, + url: 'files/test.txt', + metadata: {}, + }); + + expect(routerMocks.businessFileUploadCheck).toHaveBeenCalledWith( + expect.objectContaining({ + workspaceId: 'workspace-1', + }), + ); + }); + it('should use actual file size from S3 instead of client-provided size (security fix)', async () => { // Setup: S3 returns actual size of 5000 bytes mockFileServiceGetFileMetadata.mockResolvedValue({ @@ -579,6 +632,14 @@ describe('fileRouter', () => { }); describe('getKnowledgeItems', () => { + it('should pass workspace context to the knowledge repository', async () => { + ({ caller } = createCallerWithCtx({ workspaceId: 'workspace-1' })); + + await caller.getKnowledgeItems({}); + + expect(KnowledgeRepo).toHaveBeenCalledWith(expect.anything(), 'test-user', 'workspace-1'); + }); + it('should return knowledge items with files and documents', async () => { const knowledgeItems = [ { @@ -744,6 +805,95 @@ describe('fileRouter', () => { }); }); + describe('transferEntity', () => { + it('should transfer document resources via documentModel', async () => { + ctx.workspaceId = 'workspace-active'; + mockDocumentModelFindById.mockResolvedValue({ id: 'doc-1' }); + mockDocumentModelCountFileUsageInSubtree.mockResolvedValue(4096); + mockDocumentModelTransferTo.mockResolvedValue({ id: 'doc-1' }); + + await caller.transferEntity({ + entityType: 'document', + id: 'doc-1', + targetWorkspaceId: null, + }); + + expect(mockDocumentModelFindById).toHaveBeenCalledWith('doc-1'); + expect(mockDocumentModelCountFileUsageInSubtree).toHaveBeenCalledWith('doc-1'); + expect(routerMocks.businessFileTransferStorageCheck).toHaveBeenCalledWith({ + additionalSize: 4096, + targetUserId: 'test-user', + targetWorkspaceId: null, + }); + expect(mockDocumentModelTransferTo).toHaveBeenCalledWith('doc-1', null, 'test-user'); + expect(mockFileModelFindById).not.toHaveBeenCalled(); + }); + + it('should check target storage before transferring a file resource', async () => { + mockFileModelFindById.mockResolvedValue({ id: 'file-1', size: 2048 }); + mockFileModelTransferTo.mockResolvedValue({ fileId: 'file-1' }); + + await caller.transferEntity({ + entityType: 'file', + id: 'file-1', + targetWorkspaceId: 'workspace-target', + }); + + expect(routerMocks.businessFileTransferStorageCheck).toHaveBeenCalledWith({ + additionalSize: 2048, + targetUserId: 'test-user', + targetWorkspaceId: 'workspace-target', + }); + expect(mockFileModelTransferTo).toHaveBeenCalledWith( + 'file-1', + 'workspace-target', + 'test-user', + ); + }); + }); + + describe('copyEntityToWorkspace', () => { + it('should check target storage before copying a file resource', async () => { + mockFileModelFindById.mockResolvedValue({ id: 'file-1', size: 2048 }); + mockFileModelCopyToWorkspace.mockResolvedValue({ fileId: 'file-new' }); + + await caller.copyEntityToWorkspace({ + entityType: 'file', + id: 'file-1', + targetWorkspaceId: null, + }); + + expect(routerMocks.businessFileTransferStorageCheck).toHaveBeenCalledWith({ + additionalSize: 2048, + targetUserId: 'test-user', + targetWorkspaceId: null, + }); + expect(mockFileModelCopyToWorkspace).toHaveBeenCalledWith('file-1', null, 'test-user'); + }); + + it('should copy document resources via documentModel', async () => { + mockDocumentModelCopyToWorkspace.mockResolvedValue({ id: 'doc-1' }); + mockDocumentModelFindById.mockResolvedValue({ id: 'doc-1' }); + mockDocumentModelCountFileUsageInSubtree.mockResolvedValue(4096); + + await caller.copyEntityToWorkspace({ + entityType: 'document', + id: 'doc-1', + targetWorkspaceId: null, + }); + + expect(mockDocumentModelFindById).toHaveBeenCalledWith('doc-1'); + expect(mockDocumentModelCountFileUsageInSubtree).toHaveBeenCalledWith('doc-1'); + expect(routerMocks.businessFileTransferStorageCheck).toHaveBeenCalledWith({ + additionalSize: 4096, + targetUserId: 'test-user', + targetWorkspaceId: null, + }); + expect(mockDocumentModelCopyToWorkspace).toHaveBeenCalledWith('doc-1', null, 'test-user'); + expect(mockFileModelFindById).not.toHaveBeenCalled(); + }); + }); + describe('removeFileAsyncTask', () => { it('should do nothing when file not found', async () => { ctx.fileModel.findById.mockResolvedValue(null); diff --git a/src/server/routers/lambda/__tests__/knowledgeBase.test.ts b/src/server/routers/lambda/__tests__/knowledgeBase.test.ts new file mode 100644 index 0000000000..206eccb2d3 --- /dev/null +++ b/src/server/routers/lambda/__tests__/knowledgeBase.test.ts @@ -0,0 +1,96 @@ +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { knowledgeBaseRouter } from '@/server/routers/lambda/knowledgeBase'; +import { TransferErrorCode } from '@/types/transferError'; + +const routerMocks = vi.hoisted(() => ({ + businessFileTransferStorageCheck: vi.fn(), +})); + +const mockKnowledgeBaseModelCountFileUsage = vi.fn(); +const mockKnowledgeBaseModelCopyToWorkspace = vi.fn(); +const mockKnowledgeBaseModelFindById = vi.fn(); +const mockKnowledgeBaseModelTransferTo = vi.fn(); + +vi.mock('@/business/server/lambda-routers/file', () => ({ + businessFileTransferStorageCheck: routerMocks.businessFileTransferStorageCheck, +})); + +vi.mock('@/database/models/knowledgeBase', () => ({ + KnowledgeBaseModel: vi.fn(() => ({ + copyToWorkspace: mockKnowledgeBaseModelCopyToWorkspace, + countFileUsage: mockKnowledgeBaseModelCountFileUsage, + findById: mockKnowledgeBaseModelFindById, + transferTo: mockKnowledgeBaseModelTransferTo, + })), +})); + +describe('knowledgeBaseRouter', () => { + const ctx = { + serverDB: {}, + userId: 'test-user', + workspaceId: 'workspace-active', + }; + + const caller = knowledgeBaseRouter.createCaller(ctx as any); + + beforeEach(() => { + vi.clearAllMocks(); + routerMocks.businessFileTransferStorageCheck.mockResolvedValue(undefined); + mockKnowledgeBaseModelCopyToWorkspace.mockResolvedValue({ id: 'kb-copy' }); + mockKnowledgeBaseModelCountFileUsage.mockResolvedValue(4096); + mockKnowledgeBaseModelFindById.mockResolvedValue({ id: 'kb-1' }); + mockKnowledgeBaseModelTransferTo.mockResolvedValue({ id: 'kb-1' }); + }); + + describe('transferKnowledgeBase', () => { + it('checks target storage before transferring a library', async () => { + await caller.transferKnowledgeBase({ + id: 'kb-1', + targetWorkspaceId: null, + }); + + expect(mockKnowledgeBaseModelCountFileUsage).toHaveBeenCalledWith('kb-1'); + expect(routerMocks.businessFileTransferStorageCheck).toHaveBeenCalledWith({ + additionalSize: 4096, + targetUserId: 'test-user', + targetWorkspaceId: null, + }); + expect(mockKnowledgeBaseModelTransferTo).toHaveBeenCalledWith('kb-1', null, 'test-user'); + }); + + it('returns a stable error code when the library no longer exists', async () => { + mockKnowledgeBaseModelFindById.mockResolvedValue(undefined); + + await expect( + caller.transferKnowledgeBase({ + id: 'missing-kb', + targetWorkspaceId: null, + }), + ).rejects.toMatchObject({ + cause: { + data: { + code: TransferErrorCode.ResourceNotFound, + }, + }, + }); + }); + }); + + describe('copyKnowledgeBaseToWorkspace', () => { + it('checks target storage before copying a library', async () => { + await caller.copyKnowledgeBaseToWorkspace({ + id: 'kb-1', + targetWorkspaceId: null, + }); + + expect(mockKnowledgeBaseModelCountFileUsage).toHaveBeenCalledWith('kb-1'); + expect(routerMocks.businessFileTransferStorageCheck).toHaveBeenCalledWith({ + additionalSize: 4096, + targetUserId: 'test-user', + targetWorkspaceId: null, + }); + expect(mockKnowledgeBaseModelCopyToWorkspace).toHaveBeenCalledWith('kb-1', null, 'test-user'); + }); + }); +}); diff --git a/src/server/routers/lambda/__tests__/llmGenerationTracing.test.ts b/src/server/routers/lambda/__tests__/llmGenerationTracing.test.ts index e9508c2891..bfe2189b86 100644 --- a/src/server/routers/lambda/__tests__/llmGenerationTracing.test.ts +++ b/src/server/routers/lambda/__tests__/llmGenerationTracing.test.ts @@ -39,12 +39,17 @@ describe('llmGenerationTracingRouter.recordFeedback', () => { }); expect(result).toEqual({ ok: true }); - expect(recordFeedback).toHaveBeenCalledWith('u1', tracingId, { - data: { accepted_text: 'hello' }, - score: 1, - signal: 'positive', - source: 'explicit_thumbs', - }); + expect(recordFeedback).toHaveBeenCalledWith( + 'u1', + tracingId, + { + data: { accepted_text: 'hello' }, + score: 1, + signal: 'positive', + source: 'explicit_thumbs', + }, + undefined, + ); }); it('translates LLMGenerationFeedbackError(not_found) into TRPCError NOT_FOUND', async () => { diff --git a/src/server/routers/lambda/__tests__/messenger.test.ts b/src/server/routers/lambda/__tests__/messenger.test.ts index 2a5a9ca1ea..f1c3b220a6 100644 --- a/src/server/routers/lambda/__tests__/messenger.test.ts +++ b/src/server/routers/lambda/__tests__/messenger.test.ts @@ -11,7 +11,10 @@ const { mockFindByPlatform, mockFindByPlatformUser, mockGetServerDB, + mockGetServerFeatureFlagsStateFromRuntimeConfig, + mockHasAnyPermission, mockInitWithEnvKey, + mockListUserWorkspaces, mockListByInstallerUserId, mockMarkRevoked, mockNotifyTelegramLinkSuccess, @@ -24,7 +27,10 @@ const { mockFindByPlatform: vi.fn(), mockFindByPlatformUser: vi.fn(), mockGetServerDB: vi.fn(), + mockGetServerFeatureFlagsStateFromRuntimeConfig: vi.fn(), + mockHasAnyPermission: vi.fn(), mockInitWithEnvKey: vi.fn(), + mockListUserWorkspaces: vi.fn(), mockListByInstallerUserId: vi.fn(), mockMarkRevoked: vi.fn(), mockNotifyTelegramLinkSuccess: vi.fn(), @@ -38,6 +44,18 @@ vi.mock('@/database/core/db-adaptor', () => ({ getServerDB: mockGetServerDB, })); +vi.mock('@/database/models/workspace', () => ({ + WorkspaceModel: class { + listUserWorkspaces = (...args: any[]) => mockListUserWorkspaces(...args); + }, +})); + +vi.mock('@/database/models/rbac', () => ({ + RbacModel: class { + hasAnyPermission = (...args: any[]) => mockHasAnyPermission(...args); + }, +})); + vi.mock('@/database/models/messengerInstallation', () => ({ MessengerInstallationModel: { findById: vi.fn(), @@ -62,6 +80,10 @@ vi.mock('@/server/modules/KeyVaultsEncrypt', () => ({ }, })); +vi.mock('@/server/featureFlags', () => ({ + getServerFeatureFlagsStateFromRuntimeConfig: mockGetServerFeatureFlagsStateFromRuntimeConfig, +})); + vi.mock('@/server/services/messenger', () => ({ consumeLinkToken: mockConsumeLinkToken, MessengerDiscordBinder: vi.fn(), @@ -111,6 +133,16 @@ const createSelectBuilder = (result: T) => { return builder; }; +const createAgentListBuilder = (result: T) => { + const builder = { + from: vi.fn(() => builder), + orderBy: vi.fn().mockResolvedValue(result), + where: vi.fn(() => builder), + }; + + return builder; +}; + describe('messengerRouter.listMyInstallations', () => { const serverDB = { kind: 'server-db' }; @@ -229,6 +261,7 @@ describe('messengerRouter.peekLinkToken', () => { describe('messengerRouter.confirmLink', () => { beforeEach(() => { vi.clearAllMocks(); + mockGetServerFeatureFlagsStateFromRuntimeConfig.mockResolvedValue({ enableWorkspace: true }); mockInitWithEnvKey.mockResolvedValue(undefined); }); @@ -323,7 +356,9 @@ describe('messengerRouter.confirmLink', () => { }); it('allows re-confirming the same Telegram account', async () => { - const selectBuilder = createSelectBuilder([{ id: 'agent-1', title: 'Agent 1' }]); + const selectBuilder = createSelectBuilder([ + { id: 'agent-1', title: 'Agent 1', userId: 'user-1', workspaceId: null }, + ]); const serverDB = { select: vi.fn(() => selectBuilder) }; const linkPayload = { platform: 'telegram', @@ -354,6 +389,90 @@ describe('messengerRouter.confirmLink', () => { platformUserId: 'tg-same', platformUsername: '@same', tenantId: '', + workspaceId: null, }); }); + + it('blocks binding a workspace agent when workspace feature is disabled', async () => { + const selectBuilder = createSelectBuilder([ + { + id: 'agent-1', + title: 'Workspace Agent', + userId: 'owner-1', + workspaceId: 'workspace-1', + }, + ]); + const serverDB = { select: vi.fn(() => selectBuilder) }; + const linkPayload = { + platform: 'telegram', + platformUserId: 'tg-same', + platformUsername: '@same', + tenantId: '', + }; + + mockGetServerDB.mockResolvedValue(serverDB); + mockGetServerFeatureFlagsStateFromRuntimeConfig.mockResolvedValue({ enableWorkspace: false }); + mockPeekLinkToken.mockResolvedValue(linkPayload); + mockFindByPlatformUser.mockResolvedValue(undefined); + mockFindByPlatform.mockResolvedValue(undefined); + mockListUserWorkspaces.mockResolvedValue([{ id: 'workspace-1', name: 'Workspace 1' }]); + mockHasAnyPermission.mockResolvedValue(true); + + const caller = createCaller(await createContextInner({ userId: 'user-1' })); + + await expect( + caller.confirmLink({ initialAgentId: 'agent-1', randomId: 'rand-1234' }), + ).rejects.toMatchObject({ + code: 'FORBIDDEN', + message: 'Workspace feature is not enabled for this user', + }); + + expect(mockConsumeLinkToken).not.toHaveBeenCalled(); + expect(mockUpsertForPlatform).not.toHaveBeenCalled(); + }); +}); + +describe('messengerRouter.listBindingScopes', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockGetServerDB.mockResolvedValue({}); + }); + + it('returns no workspace scopes when workspace feature is disabled', async () => { + mockGetServerFeatureFlagsStateFromRuntimeConfig.mockResolvedValue({ enableWorkspace: false }); + mockListUserWorkspaces.mockResolvedValue([{ id: 'workspace-1', name: 'Workspace 1' }]); + + const caller = createCaller(await createContextInner({ userId: 'user-1' })); + const result = await caller.listBindingScopes(); + + expect(result).toEqual([]); + expect(mockListUserWorkspaces).not.toHaveBeenCalled(); + }); +}); + +describe('messengerRouter.listAgentsForBinding', () => { + beforeEach(() => { + vi.clearAllMocks(); + mockGetServerDB.mockResolvedValue({}); + }); + + it('rejects workspace-scoped agent listing when workspace feature is disabled', async () => { + const selectBuilder = createAgentListBuilder([]); + const serverDB = { select: vi.fn(() => selectBuilder) }; + + mockGetServerDB.mockResolvedValue(serverDB); + mockGetServerFeatureFlagsStateFromRuntimeConfig.mockResolvedValue({ enableWorkspace: false }); + mockListUserWorkspaces.mockResolvedValue([{ id: 'workspace-1', name: 'Workspace 1' }]); + + const caller = createCaller(await createContextInner({ userId: 'user-1' })); + + await expect(caller.listAgentsForBinding({ workspaceId: 'workspace-1' })).rejects.toMatchObject( + { + code: 'FORBIDDEN', + message: 'Workspace feature is not enabled for this user', + }, + ); + + expect(mockListUserWorkspaces).not.toHaveBeenCalled(); + }); }); diff --git a/src/server/routers/lambda/_helpers/resolveContext.test.ts b/src/server/routers/lambda/_helpers/resolveContext.test.ts index 3d9211668c..9e3ab04c81 100644 --- a/src/server/routers/lambda/_helpers/resolveContext.test.ts +++ b/src/server/routers/lambda/_helpers/resolveContext.test.ts @@ -6,15 +6,24 @@ import { resolveContext, } from './resolveContext'; +const { mockBuildWorkspaceWhere } = vi.hoisted(() => ({ + mockBuildWorkspaceWhere: vi.fn(() => 'workspace-where'), +})); + // Mock the database module vi.mock('@/database/schemas', () => ({ agentsToSessions: { agentId: 'agent_id', sessionId: 'session_id', userId: 'user_id', + workspaceId: 'workspace_id', }, })); +vi.mock('@/database/utils/workspace', () => ({ + buildWorkspaceWhere: mockBuildWorkspaceWhere, +})); + describe('resolveContext', () => { const mockUserId = 'user-1'; @@ -65,6 +74,18 @@ describe('resolveContext', () => { expect(mockDb.select).toHaveBeenCalled(); }); + it('should scope agentId resolution by workspaceId when provided', async () => { + const mockDb = createMockDb([{ sessionId: 'workspace-session-1' }]); + + const result = await resolveContext({ agentId: 'agent-1' }, mockDb, mockUserId, 'ws-1'); + + expect(result.sessionId).toBe('workspace-session-1'); + expect(mockBuildWorkspaceWhere).toHaveBeenCalledWith( + { userId: mockUserId, workspaceId: 'ws-1' }, + expect.objectContaining({ workspaceId: 'workspace_id' }), + ); + }); + it('should prefer agentId over sessionId when both are provided', async () => { const mockDb = createMockDb([{ sessionId: 'resolved-from-agent' }]); @@ -191,6 +212,18 @@ describe('resolveContext', () => { expect(mockWhere).toHaveBeenCalled(); // The where clause should be called with userId filter }); + + it('should scope session reverse lookup by workspaceId when provided', async () => { + const mockDb = createMockDb([{ agentId: 'agent-1' }]); + + const result = await resolveAgentIdFromSession('session-1', mockDb, mockUserId, 'ws-1'); + + expect(result).toBe('agent-1'); + expect(mockBuildWorkspaceWhere).toHaveBeenCalledWith( + { userId: mockUserId, workspaceId: 'ws-1' }, + expect.objectContaining({ workspaceId: 'workspace_id' }), + ); + }); }); describe('batchResolveAgentIdFromSessions', () => { @@ -234,6 +267,18 @@ describe('resolveContext', () => { expect(result.get('session-2')).toBe('agent-2'); }); + it('should scope batch reverse lookup by workspaceId when provided', async () => { + const mockDb = createBatchMockDb([{ sessionId: 'session-1', agentId: 'agent-1' }]); + + const result = await batchResolveAgentIdFromSessions(['session-1'], mockDb, 'user-1', 'ws-1'); + + expect(result.get('session-1')).toBe('agent-1'); + expect(mockBuildWorkspaceWhere).toHaveBeenCalledWith( + { userId: 'user-1', workspaceId: 'ws-1' }, + expect.objectContaining({ workspaceId: 'workspace_id' }), + ); + }); + it('should handle partial matches', async () => { const mockDb = createBatchMockDb([{ sessionId: 'session-1', agentId: 'agent-1' }]); diff --git a/src/server/routers/lambda/_helpers/resolveContext.ts b/src/server/routers/lambda/_helpers/resolveContext.ts index c41b64d4a9..9c59c2b5a3 100644 --- a/src/server/routers/lambda/_helpers/resolveContext.ts +++ b/src/server/routers/lambda/_helpers/resolveContext.ts @@ -2,6 +2,7 @@ import { and, eq, inArray } from 'drizzle-orm'; import { agentsToSessions } from '@/database/schemas'; import { type LobeChatDatabase } from '@/database/type'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; import { type ConversationContextInput } from '../_schema/context'; @@ -28,6 +29,7 @@ export const resolveContext = async ( input: ConversationContextInput, db: LobeChatDatabase, userId: string, + workspaceId?: string, ): Promise => { let resolvedSessionId: string | null = input.sessionId ?? null; @@ -36,7 +38,12 @@ export const resolveContext = async ( const [relation] = await db .select({ sessionId: agentsToSessions.sessionId }) .from(agentsToSessions) - .where(and(eq(agentsToSessions.agentId, input.agentId), eq(agentsToSessions.userId, userId))) + .where( + and( + eq(agentsToSessions.agentId, input.agentId), + buildWorkspaceWhere({ userId, workspaceId }, agentsToSessions), + ), + ) .limit(1); if (relation) { @@ -67,11 +74,17 @@ export const resolveAgentIdFromSession = async ( sessionId: string, db: LobeChatDatabase, userId: string, + workspaceId?: string, ): Promise => { const [relation] = await db .select({ agentId: agentsToSessions.agentId }) .from(agentsToSessions) - .where(and(eq(agentsToSessions.sessionId, sessionId), eq(agentsToSessions.userId, userId))) + .where( + and( + eq(agentsToSessions.sessionId, sessionId), + buildWorkspaceWhere({ userId, workspaceId }, agentsToSessions), + ), + ) .limit(1); return relation?.agentId; @@ -91,6 +104,7 @@ export const batchResolveAgentIdFromSessions = async ( sessionIds: string[], db: LobeChatDatabase, userId: string, + workspaceId?: string, ): Promise> => { if (sessionIds.length === 0) return new Map(); @@ -98,7 +112,10 @@ export const batchResolveAgentIdFromSessions = async ( .select({ agentId: agentsToSessions.agentId, sessionId: agentsToSessions.sessionId }) .from(agentsToSessions) .where( - and(eq(agentsToSessions.userId, userId), inArray(agentsToSessions.sessionId, sessionIds)), + and( + buildWorkspaceWhere({ userId, workspaceId }, agentsToSessions), + inArray(agentsToSessions.sessionId, sessionIds), + ), ); return new Map(relations.map((r) => [r.sessionId, r.agentId])); diff --git a/src/server/routers/lambda/agent.ts b/src/server/routers/lambda/agent.ts index ec80cda616..72ca47f513 100644 --- a/src/server/routers/lambda/agent.ts +++ b/src/server/routers/lambda/agent.ts @@ -1,29 +1,36 @@ import { DEFAULT_AGENT_CONFIG, INBOX_SESSION_ID } from '@lobechat/const'; import { CreateAgentSchema, type KnowledgeItem } from '@lobechat/types'; import { KnowledgeType } from '@lobechat/types'; +import { TRPCError } from '@trpc/server'; +import { and, eq, isNull } from 'drizzle-orm'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentModel } from '@/database/models/agent'; import { ChatGroupModel } from '@/database/models/chatGroup'; import { FileModel } from '@/database/models/file'; import { KnowledgeBaseModel } from '@/database/models/knowledgeBase'; import { SessionModel } from '@/database/models/session'; import { UserModel } from '@/database/models/user'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { workspaceMembers } from '@/database/schemas'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { AgentService } from '@/server/services/agent'; +import { TransferErrorCode } from '@/types/transferError'; -const agentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const agentProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentModel: new AgentModel(ctx.serverDB, ctx.userId), - agentService: new AgentService(ctx.serverDB, ctx.userId), - chatGroupModel: new ChatGroupModel(ctx.serverDB, ctx.userId), - fileModel: new FileModel(ctx.serverDB, ctx.userId), - knowledgeBaseModel: new KnowledgeBaseModel(ctx.serverDB, ctx.userId), - sessionModel: new SessionModel(ctx.serverDB, ctx.userId), + agentModel: new AgentModel(ctx.serverDB, ctx.userId, wsId), + agentService: new AgentService(ctx.serverDB, ctx.userId, wsId), + chatGroupModel: new ChatGroupModel(ctx.serverDB, ctx.userId, wsId), + fileModel: new FileModel(ctx.serverDB, ctx.userId, wsId), + knowledgeBaseModel: new KnowledgeBaseModel(ctx.serverDB, ctx.userId, wsId), + sessionModel: new SessionModel(ctx.serverDB, ctx.userId, wsId), }, }); }); @@ -57,6 +64,7 @@ export const agentRouter = router({ * Returns the created agent ID and session ID */ createAgent: agentProcedure + .use(withScopedPermission('agent:create')) .input( z.object({ config: CreateAgentSchema.optional(), @@ -73,6 +81,7 @@ export const agentRouter = router({ }), createAgentFiles: agentProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ agentId: z.string(), @@ -85,6 +94,7 @@ export const agentRouter = router({ }), createAgentKnowledgeBase: agentProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ agentId: z.string(), @@ -106,6 +116,7 @@ export const agentRouter = router({ * Returns only the agent ID. */ createAgentOnly: agentProcedure + .use(withScopedPermission('agent:create')) .input( z.object({ config: z.object({}).passthrough().optional(), @@ -123,6 +134,7 @@ export const agentRouter = router({ }), deleteAgentFile: agentProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ agentId: z.string(), @@ -134,6 +146,7 @@ export const agentRouter = router({ }), deleteAgentKnowledgeBase: agentProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ agentId: z.string(), @@ -149,6 +162,7 @@ export const agentRouter = router({ * Returns the new agent ID and session ID. */ duplicateAgent: agentProcedure + .use(withScopedPermission('agent:fork')) .input( z.object({ agentId: z.string(), @@ -303,12 +317,14 @@ export const agentRouter = router({ * Remove an agent and its associated session */ removeAgent: agentProcedure + .use(withScopedPermission('agent:delete')) .input(z.object({ agentId: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.agentModel.delete(input.agentId); }), toggleFile: agentProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ agentId: z.string(), @@ -321,6 +337,7 @@ export const agentRouter = router({ }), toggleKnowledgeBase: agentProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ agentId: z.string(), @@ -336,7 +353,86 @@ export const agentRouter = router({ ); }), + transferAgent: agentProcedure + .use(withScopedPermission('agent:update')) + .input( + z.object({ + agentId: z.string(), + targetWorkspaceId: z.string().nullable(), + }), + ) + .mutation(async ({ input, ctx }) => { + // 1. Fetch the agent to check ownership + const agent = await ctx.agentModel.getAgentConfigById(input.agentId); + if (!agent) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'Agent not found', + }); + } + + // 2. In workspace mode, members can only transfer agents they created; + // workspace owners can transfer any agent + if (ctx.workspaceId && agent.userId !== ctx.userId) { + const [membership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, ctx.workspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + + if (!membership || membership.role !== 'owner') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.OwnerOnly } }, + code: 'FORBIDDEN', + message: 'Only workspace owners can transfer agents created by others', + }); + } + } + + // 3. Validate target workspace access (user must be member+) + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + // 4. Cannot transfer to the same workspace + if (input.targetWorkspaceId === ctx.workspaceId) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.SameWorkspace } }, + code: 'BAD_REQUEST', + message: 'Cannot transfer agent to the same workspace', + }); + } + + return ctx.agentModel.transferAgent(input.agentId, input.targetWorkspaceId, ctx.userId); + }), + updateAgentConfig: agentProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ agentId: z.string(), @@ -352,6 +448,7 @@ export const agentRouter = router({ * Pin or unpin an agent */ updateAgentPinned: agentProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ id: z.string(), diff --git a/src/server/routers/lambda/agentBotProvider.ts b/src/server/routers/lambda/agentBotProvider.ts index 1c82a8c1c4..978c675cc4 100644 --- a/src/server/routers/lambda/agentBotProvider.ts +++ b/src/server/routers/lambda/agentBotProvider.ts @@ -3,6 +3,8 @@ import { fetchQrCode, pollQrStatus } from '@lobechat/chat-adapter-wechat'; import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentBotProviderModel } from '@/database/models/agentBotProvider'; import { authedProcedure, router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; @@ -17,17 +19,24 @@ import { mergeWithDefaults, platformRegistry } from '@/server/services/bot/platf import { GatewayService } from '@/server/services/gateway'; import { getBotRuntimeStatus } from '@/server/services/gateway/runtimeStatus'; -const agentBotProviderProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const agentBotProviderProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); return opts.next({ ctx: { - agentBotProviderModel: new AgentBotProviderModel(ctx.serverDB, ctx.userId, gateKeeper), + agentBotProviderModel: new AgentBotProviderModel(ctx.serverDB, ctx.userId, gateKeeper, wsId), }, }); }); +// Write variant gates viewers out of bot-provider mutations +// (create/update/delete + start/test connections). Reads keep the bare proc. +const agentBotProviderProcedureWrite = agentBotProviderProcedure.use( + withScopedPermission('agent:update'), +); + /** * Wrap the shared access-policy validator so violations surface as * `TRPCError(BAD_REQUEST)` — keeps client forms able to highlight the @@ -49,7 +58,7 @@ export const agentBotProviderRouter = router({ return platformRegistry.listSerializedPlatforms(); }), - create: agentBotProviderProcedure + create: agentBotProviderProcedureWrite .input( z.object({ agentId: z.string(), @@ -79,7 +88,7 @@ export const agentBotProviderRouter = router({ } }), - delete: agentBotProviderProcedure + delete: agentBotProviderProcedureWrite .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { // Load record before delete to get platform + applicationId @@ -118,18 +127,18 @@ export const agentBotProviderRouter = router({ return getBotRuntimeStatus(input.platform, input.applicationId); }), - refreshRuntimeStatus: authedProcedure + refreshRuntimeStatus: agentBotProviderProcedureWrite .input(z.object({ applicationId: z.string(), platform: z.string() })) - .mutation(async ({ input, ctx }) => { + .mutation(async ({ input }) => { const service = new GatewayService(); - return service.refreshBotRuntimeStatus(input.platform, input.applicationId, ctx.userId); + return service.refreshBotRuntimeStatus(input.platform, input.applicationId); }), - refreshRuntimeStatusesByAgent: authedProcedure + refreshRuntimeStatusesByAgent: agentBotProviderProcedureWrite .input(z.object({ agentId: z.string() })) - .mutation(async ({ input, ctx }) => { + .mutation(async ({ input }) => { const service = new GatewayService(); - await service.refreshBotRuntimeStatusesByAgent(input.agentId, ctx.userId); + await service.refreshBotRuntimeStatusesByAgent(input.agentId); return { ok: true as const }; }), @@ -155,7 +164,7 @@ export const agentBotProviderRouter = router({ })); }), - connectBot: agentBotProviderProcedure + connectBot: agentBotProviderProcedureWrite .input(z.object({ applicationId: z.string(), platform: z.string() })) .mutation(async ({ input, ctx }) => { const service = new GatewayService(); @@ -164,7 +173,7 @@ export const agentBotProviderRouter = router({ return { status }; }), - testConnection: agentBotProviderProcedure + testConnection: agentBotProviderProcedureWrite .input(z.object({ applicationId: z.string(), platform: z.string() })) .mutation(async ({ input, ctx }) => { const { platform, applicationId } = input; @@ -251,7 +260,7 @@ export const agentBotProviderRouter = router({ return pollQrStatus(input.qrcode); }), - update: agentBotProviderProcedure + update: agentBotProviderProcedureWrite .input( z.object({ applicationId: z.string().optional(), diff --git a/src/server/routers/lambda/agentDocument.ts b/src/server/routers/lambda/agentDocument.ts index a6b6840dd7..01e038d5b6 100644 --- a/src/server/routers/lambda/agentDocument.ts +++ b/src/server/routers/lambda/agentDocument.ts @@ -6,10 +6,12 @@ import { import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentDocumentModel } from '@/database/models/agentDocuments'; import { TopicModel } from '@/database/models/topic'; import { TopicDocumentModel } from '@/database/models/topicDocument'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { AgentDocumentsService } from '@/server/services/agentDocuments'; import { emitAgentDocumentToolOutcomeSafely } from '@/server/services/agentDocuments/toolOutcome'; @@ -148,20 +150,28 @@ const liteXMLOperationSchema = z.union([ }), ]); -const agentDocumentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const agentDocumentProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentDocumentModel: new AgentDocumentModel(ctx.serverDB, ctx.userId), - agentDocumentService: new AgentDocumentsService(ctx.serverDB, ctx.userId), - agentDocumentVfsService: new AgentDocumentVfsService(ctx.serverDB, ctx.userId), - topicModel: new TopicModel(ctx.serverDB, ctx.userId), - topicDocumentModel: new TopicDocumentModel(ctx.serverDB, ctx.userId), + agentDocumentModel: new AgentDocumentModel(ctx.serverDB, ctx.userId, wsId), + agentDocumentService: new AgentDocumentsService(ctx.serverDB, ctx.userId, wsId), + agentDocumentVfsService: new AgentDocumentVfsService(ctx.serverDB, ctx.userId, wsId), + topicModel: new TopicModel(ctx.serverDB, ctx.userId, wsId), + topicDocumentModel: new TopicDocumentModel(ctx.serverDB, ctx.userId, wsId), }, }); }); +// Write variant gates viewers out of every agent-document mutation +// (upsert/delete/rename/copy/skill-edit, plus the VFS path-based writes). +// Read endpoints keep using `agentDocumentProcedure`. +const agentDocumentProcedureWrite = agentDocumentProcedure.use( + withScopedPermission('document:update'), +); + const emitCreateDocumentToolOutcome = async (input: { agentDocumentId?: string; agentId: string; @@ -247,7 +257,7 @@ export const agentDocumentRouter = router({ /** * Create or update a document */ - upsertDocument: agentDocumentProcedure + upsertDocument: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -272,7 +282,7 @@ export const agentDocumentRouter = router({ /** * Delete a specific document */ - deleteDocument: agentDocumentProcedure + deleteDocument: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -289,7 +299,7 @@ export const agentDocumentRouter = router({ /** * Delete all documents for an agent */ - deleteAllDocuments: agentDocumentProcedure + deleteAllDocuments: agentDocumentProcedureWrite .input(z.object({ agentId: z.string() })) .mutation(async ({ ctx, input }) => { return ctx.agentDocumentService.deleteAllDocuments(input.agentId); @@ -298,7 +308,7 @@ export const agentDocumentRouter = router({ /** * Initialize documents from a template set */ - initializeFromTemplate: agentDocumentProcedure + initializeFromTemplate: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -335,7 +345,7 @@ export const agentDocumentRouter = router({ /** * Clone documents from one agent to another */ - cloneDocuments: agentDocumentProcedure + cloneDocuments: agentDocumentProcedureWrite .input( z.object({ sourceAgentId: z.string(), @@ -466,7 +476,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: write document by VFS path */ - writeDocumentByPath: agentDocumentProcedure + writeDocumentByPath: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -492,7 +502,7 @@ export const agentDocumentRouter = router({ } }), - createSkillByPath: agentDocumentProcedure + createSkillByPath: agentDocumentProcedureWrite .input(createMountedSkillSchema) .mutation(async ({ ctx, input }) => { try { @@ -511,7 +521,7 @@ export const agentDocumentRouter = router({ } }), - updateSkillByPath: agentDocumentProcedure + updateSkillByPath: agentDocumentProcedureWrite .input(updateMountedSkillSchema) .mutation(async ({ ctx, input }) => { try { @@ -528,7 +538,7 @@ export const agentDocumentRouter = router({ } }), - deleteSkillByPath: agentDocumentProcedure + deleteSkillByPath: agentDocumentProcedureWrite .input(deleteMountedSkillSchema) .mutation(async ({ ctx, input }) => { try { @@ -543,7 +553,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: create a VFS directory */ - mkdirDocumentByPath: agentDocumentProcedure + mkdirDocumentByPath: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -570,7 +580,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: rename or move a VFS path */ - renameDocumentByPath: agentDocumentProcedure + renameDocumentByPath: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -599,7 +609,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: copy a VFS path */ - copyDocumentByPath: agentDocumentProcedure + copyDocumentByPath: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -628,7 +638,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: soft-delete a VFS path */ - deleteDocumentByPath: agentDocumentProcedure + deleteDocumentByPath: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -681,7 +691,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: restore a trash entry */ - restoreDocumentFromTrashByPath: agentDocumentProcedure + restoreDocumentFromTrashByPath: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -703,7 +713,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: permanently remove a trash entry */ - deleteDocumentPermanentlyByPath: agentDocumentProcedure + deleteDocumentPermanentlyByPath: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -727,7 +737,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: associate an existing document with an agent */ - associateDocument: agentDocumentProcedure + associateDocument: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -741,7 +751,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: create document */ - createDocument: agentDocumentProcedure + createDocument: agentDocumentProcedureWrite .input( z .object({ @@ -797,7 +807,7 @@ export const agentDocumentRouter = router({ * Create an agent document and associate it with a topic in one call. * Used by the topic → page flow to create an agent document. */ - createForTopic: agentDocumentProcedure + createForTopic: agentDocumentProcedureWrite .input( z .object({ @@ -873,7 +883,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: modify document nodes by id through LiteXML. */ - modifyNodes: agentDocumentProcedure + modifyNodes: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -892,7 +902,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: replace document content by id */ - replaceDocumentContent: agentDocumentProcedure + replaceDocumentContent: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -911,7 +921,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: remove document by id */ - removeDocument: agentDocumentProcedure + removeDocument: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -926,7 +936,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: copy document by id */ - copyDocument: agentDocumentProcedure + copyDocument: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -949,7 +959,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: rename document by id */ - renameDocument: agentDocumentProcedure + renameDocument: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), @@ -972,7 +982,7 @@ export const agentDocumentRouter = router({ /** * Tool-oriented: update document load rule by id */ - updateLoadRule: agentDocumentProcedure + updateLoadRule: agentDocumentProcedureWrite .input( z.object({ agentId: z.string(), diff --git a/src/server/routers/lambda/agentEval.ts b/src/server/routers/lambda/agentEval.ts index 02ee394ee5..c88dd154a4 100644 --- a/src/server/routers/lambda/agentEval.ts +++ b/src/server/routers/lambda/agentEval.ts @@ -3,6 +3,8 @@ import { TRPCError } from '@trpc/server'; import debug from 'debug'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentEvalBenchmarkModel, AgentEvalDatasetModel, @@ -10,7 +12,7 @@ import { AgentEvalRunTopicModel, AgentEvalTestCaseModel, } from '@/database/models/agentEval'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { AgentEvalRunService } from '@/server/services/agentEvalRun'; import { FileService } from '@/server/services/file'; @@ -52,27 +54,33 @@ const evalRunInputConfigSchema = z.object({ const log = debug('lobe-lambda-router:agent-eval'); -const agentEvalProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const agentEvalProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - benchmarkModel: new AgentEvalBenchmarkModel(ctx.serverDB, ctx.userId), - datasetModel: new AgentEvalDatasetModel(ctx.serverDB, ctx.userId), - runModel: new AgentEvalRunModel(ctx.serverDB, ctx.userId), - runService: new AgentEvalRunService(ctx.serverDB, ctx.userId), - runTopicModel: new AgentEvalRunTopicModel(ctx.serverDB, ctx.userId), - testCaseModel: new AgentEvalTestCaseModel(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), + benchmarkModel: new AgentEvalBenchmarkModel(ctx.serverDB, ctx.userId, wsId), + datasetModel: new AgentEvalDatasetModel(ctx.serverDB, ctx.userId, wsId), + runModel: new AgentEvalRunModel(ctx.serverDB, ctx.userId, wsId), + runService: new AgentEvalRunService(ctx.serverDB, ctx.userId, wsId), + runTopicModel: new AgentEvalRunTopicModel(ctx.serverDB, ctx.userId, wsId), + testCaseModel: new AgentEvalTestCaseModel(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), }, }); }); +// Write variant for mutations — gates viewers out of all eval-creation/edit +// flows. Reads keep using `agentEvalProcedure` (viewers may inspect existing +// benchmarks / runs). +const agentEvalProcedureWrite = agentEvalProcedure.use(withScopedPermission('agent:update')); + export const agentEvalRouter = router({ // ============================================ // Benchmark Operations // ============================================ - createBenchmark: agentEvalProcedure + createBenchmark: agentEvalProcedureWrite .input( z.object({ identifier: z.string(), @@ -125,7 +133,7 @@ export const agentEvalRouter = router({ return benchmark; }), - updateBenchmark: agentEvalProcedure + updateBenchmark: agentEvalProcedureWrite .input( z.object({ id: z.string(), @@ -148,7 +156,7 @@ export const agentEvalRouter = router({ return result; }), - deleteBenchmark: agentEvalProcedure + deleteBenchmark: agentEvalProcedureWrite .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -172,7 +180,7 @@ export const agentEvalRouter = router({ // ============================================ // Dataset Operations // ============================================ - createDataset: agentEvalProcedure + createDataset: agentEvalProcedureWrite .input( z.object({ benchmarkId: z.string(), @@ -232,7 +240,7 @@ export const agentEvalRouter = router({ return dataset; }), - updateDataset: agentEvalProcedure + updateDataset: agentEvalProcedureWrite .input( z.object({ id: z.string(), @@ -255,7 +263,7 @@ export const agentEvalRouter = router({ return result; }), - deleteDataset: agentEvalProcedure + deleteDataset: agentEvalProcedureWrite .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -276,7 +284,7 @@ export const agentEvalRouter = router({ } }), - parseDatasetFile: agentEvalProcedure + parseDatasetFile: agentEvalProcedureWrite .input( z.object({ pathname: z.string(), @@ -314,7 +322,7 @@ export const agentEvalRouter = router({ } }), - importDataset: agentEvalProcedure + importDataset: agentEvalProcedureWrite .input( z.object({ datasetId: z.string(), @@ -430,7 +438,7 @@ export const agentEvalRouter = router({ // ============================================ // TestCase Operations // ============================================ - createTestCase: agentEvalProcedure + createTestCase: agentEvalProcedureWrite .input( z.object({ datasetId: z.string(), @@ -471,7 +479,7 @@ export const agentEvalRouter = router({ } }), - batchCreateTestCases: agentEvalProcedure + batchCreateTestCases: agentEvalProcedureWrite .input( z.object({ datasetId: z.string(), @@ -512,7 +520,7 @@ export const agentEvalRouter = router({ } }), - updateTestCase: agentEvalProcedure + updateTestCase: agentEvalProcedureWrite .input( z.object({ id: z.string(), @@ -541,7 +549,7 @@ export const agentEvalRouter = router({ return result; }), - deleteTestCase: agentEvalProcedure + deleteTestCase: agentEvalProcedureWrite .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -591,7 +599,7 @@ export const agentEvalRouter = router({ // ============================================ // Run Operations // ============================================ - createRun: agentEvalProcedure + createRun: agentEvalProcedureWrite .input( z.object({ datasetId: z.string(), @@ -678,7 +686,7 @@ export const agentEvalRouter = router({ return result; }), - deleteRun: agentEvalProcedure + deleteRun: agentEvalProcedureWrite .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -707,7 +715,7 @@ export const agentEvalRouter = router({ * Start executing a run * Transitions: idle/failed → pending → running */ - startRun: agentEvalProcedure + startRun: agentEvalProcedureWrite .input( z.object({ id: z.string(), @@ -743,7 +751,7 @@ export const agentEvalRouter = router({ /** * Abort a running evaluation */ - abortRun: agentEvalProcedure + abortRun: agentEvalProcedureWrite .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { const run = await ctx.runModel.findById(input.id); @@ -758,13 +766,17 @@ export const agentEvalRouter = router({ }); } - const service = new AgentEvalRunService(ctx.serverDB, ctx.userId); + const service = new AgentEvalRunService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); await service.abortRun(input.id); return { success: true }; }), - retryRunErrors: agentEvalProcedure + retryRunErrors: agentEvalProcedureWrite .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { const run = await ctx.runModel.findById(input.id); @@ -788,7 +800,7 @@ export const agentEvalRouter = router({ return { retryCount, runId: input.id, success: true }; }), - retryRunCase: agentEvalProcedure + retryRunCase: agentEvalProcedureWrite .input(z.object({ runId: z.string(), testCaseId: z.string() })) .mutation(async ({ input, ctx }) => { const run = await ctx.runModel.findById(input.runId); @@ -812,7 +824,7 @@ export const agentEvalRouter = router({ return { runId: input.runId, success: true, testCaseId: input.testCaseId }; }), - resumeRunCase: agentEvalProcedure + resumeRunCase: agentEvalProcedureWrite .input(z.object({ runId: z.string(), testCaseId: z.string(), threadId: z.string().optional() })) .mutation(async ({ input, ctx }) => { log( @@ -830,7 +842,7 @@ export const agentEvalRouter = router({ return result; }), - batchResumeRunCases: agentEvalProcedure + batchResumeRunCases: agentEvalProcedureWrite .input( z.object({ runId: z.string(), @@ -926,7 +938,7 @@ export const agentEvalRouter = router({ /** * Update run status (internal use) */ - updateRunStatus: agentEvalProcedure + updateRunStatus: agentEvalProcedureWrite .input( z.object({ id: z.string(), @@ -956,7 +968,7 @@ export const agentEvalRouter = router({ /** * Update run metrics (internal use) */ - updateRunMetrics: agentEvalProcedure + updateRunMetrics: agentEvalProcedureWrite .input( z.object({ id: z.string(), @@ -986,7 +998,7 @@ export const agentEvalRouter = router({ /** * Update run (user-facing: name, datasetId, targetAgentId) */ - updateRun: agentEvalProcedure + updateRun: agentEvalProcedureWrite .input( z.object({ config: evalRunInputConfigSchema.optional(), diff --git a/src/server/routers/lambda/agentEvalExternal.ts b/src/server/routers/lambda/agentEvalExternal.ts index 907cabe6e6..60570552ee 100644 --- a/src/server/routers/lambda/agentEvalExternal.ts +++ b/src/server/routers/lambda/agentEvalExternal.ts @@ -3,6 +3,8 @@ import { TRPCError } from '@trpc/server'; import { and, asc, eq, isNull } from 'drizzle-orm'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentEvalDatasetModel, AgentEvalRunModel, @@ -11,7 +13,8 @@ import { } from '@/database/models/agentEval'; import { ThreadModel } from '@/database/models/thread'; import { messages } from '@/database/schemas'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { AgentEvalRunService } from '@/server/services/agentEvalRun'; @@ -35,20 +38,24 @@ const reportResultItemSchema = z.object({ const toIsoString = (value?: Date | null) => (value ? value.toISOString() : undefined); -const agentEvalExternalProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const agentEvalExternalProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - datasetModel: new AgentEvalDatasetModel(ctx.serverDB, ctx.userId), - runModel: new AgentEvalRunModel(ctx.serverDB, ctx.userId), - runService: new AgentEvalRunService(ctx.serverDB, ctx.userId), - runTopicModel: new AgentEvalRunTopicModel(ctx.serverDB, ctx.userId), - testCaseModel: new AgentEvalTestCaseModel(ctx.serverDB, ctx.userId), - threadModel: new ThreadModel(ctx.serverDB, ctx.userId), + datasetModel: new AgentEvalDatasetModel(ctx.serverDB, ctx.userId, wsId), + runModel: new AgentEvalRunModel(ctx.serverDB, ctx.userId, wsId), + runService: new AgentEvalRunService(ctx.serverDB, ctx.userId, wsId), + runTopicModel: new AgentEvalRunTopicModel(ctx.serverDB, ctx.userId, wsId), + testCaseModel: new AgentEvalTestCaseModel(ctx.serverDB, ctx.userId, wsId), + threadModel: new ThreadModel(ctx.serverDB, ctx.userId, wsId), }, }); }); +const agentEvalExternalWriteProcedure = agentEvalExternalProcedure.use( + withScopedPermission('agent:update'), +); type ReportResultInput = z.infer & { runId: string }; @@ -307,7 +314,10 @@ export const agentEvalExternalRouter = router({ .input(z.object({ threadId: z.string().optional(), topicId: z.string() })) .query(async ({ ctx, input }) => { const conditions = [ - eq(messages.userId, ctx.userId), + buildWorkspaceWhere( + { userId: ctx.userId, workspaceId: ctx.workspaceId ?? undefined }, + messages, + ), eq(messages.topicId, input.topicId), isNull(messages.messageGroupId), ]; @@ -336,7 +346,7 @@ export const agentEvalExternalRouter = router({ })); }), - reportResult: agentEvalExternalProcedure + reportResult: agentEvalExternalWriteProcedure .input( z.object({ correct: z.boolean(), @@ -349,7 +359,7 @@ export const agentEvalExternalRouter = router({ ) .mutation(async ({ ctx, input }) => applyReportResult(ctx, input, true)), - reportResultsBatch: agentEvalExternalProcedure + reportResultsBatch: agentEvalExternalWriteProcedure .input(z.object({ items: z.array(reportResultItemSchema).min(1), runId: z.string() })) .mutation(async ({ ctx, input }) => { const receipts = []; @@ -390,7 +400,7 @@ export const agentEvalExternalRouter = router({ }; }), - runSetStatus: agentEvalExternalProcedure + runSetStatus: agentEvalExternalWriteProcedure .input(z.object({ runId: z.string(), status: runStatusSchema })) .mutation(async ({ ctx, input }) => { const run = await ctx.runModel.findById(input.runId); @@ -449,7 +459,7 @@ export const agentEvalExternalRouter = router({ }; }), - runTopicReportResult: agentEvalExternalProcedure + runTopicReportResult: agentEvalExternalWriteProcedure .input( z.object({ correct: z.boolean(), diff --git a/src/server/routers/lambda/agentGroup.ts b/src/server/routers/lambda/agentGroup.ts index 01b3dba97f..d711e76ef0 100644 --- a/src/server/routers/lambda/agentGroup.ts +++ b/src/server/routers/lambda/agentGroup.ts @@ -1,14 +1,20 @@ import { InsertChatGroupSchema } from '@lobechat/types'; +import { TRPCError } from '@trpc/server'; +import { and, eq, isNull } from 'drizzle-orm'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentModel } from '@/database/models/agent'; import { ChatGroupModel } from '@/database/models/chatGroup'; import { UserModel } from '@/database/models/user'; import { AgentGroupRepository } from '@/database/repositories/agentGroup'; +import { workspaceMembers } from '@/database/schemas'; import { type ChatGroupConfig } from '@/database/types/chatGroup'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { AgentGroupService } from '@/server/services/agentGroup'; +import { TransferErrorCode } from '@/types/transferError'; /** * Custom schema for agent member input, replacing drizzle-generated insertAgentSchema @@ -39,22 +45,27 @@ const agentMemberInputSchema = z }) .partial(); -const agentGroupProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const agentGroupProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentGroupRepo: new AgentGroupRepository(ctx.serverDB, ctx.userId), - agentGroupService: new AgentGroupService(ctx.serverDB, ctx.userId), - agentModel: new AgentModel(ctx.serverDB, ctx.userId), - chatGroupModel: new ChatGroupModel(ctx.serverDB, ctx.userId), + agentGroupRepo: new AgentGroupRepository(ctx.serverDB, ctx.userId, wsId), + agentGroupService: new AgentGroupService(ctx.serverDB, ctx.userId, wsId), + agentModel: new AgentModel(ctx.serverDB, ctx.userId, wsId), + chatGroupModel: new ChatGroupModel(ctx.serverDB, ctx.userId, wsId), userModel: new UserModel(ctx.serverDB, ctx.userId), }, }); }); +// Write variant gates viewers out of chat-group mutations (create/update/ +// delete + member adds/removes). Reads keep the bare proc. +const agentGroupProcedureWrite = agentGroupProcedure.use(withScopedPermission('agent:update')); + export const agentGroupRouter = router({ - addAgentsToGroup: agentGroupProcedure + addAgentsToGroup: agentGroupProcedureWrite .input( z.object({ agentIds: z.array(z.string()), @@ -69,7 +80,7 @@ export const agentGroupRouter = router({ * Batch create virtual agents and add them to an existing group. * This is more efficient than calling createAgentOnly multiple times. */ - batchCreateAgentsInGroup: agentGroupProcedure + batchCreateAgentsInGroup: agentGroupProcedureWrite .input( z.object({ agents: z.array(agentMemberInputSchema), @@ -114,14 +125,16 @@ export const agentGroupRouter = router({ * The supervisor agent is automatically created as a virtual agent. * Returns the groupId and supervisorAgentId. */ - createGroup: agentGroupProcedure.input(InsertChatGroupSchema).mutation(async ({ input, ctx }) => { - const { group, supervisorAgentId } = await ctx.agentGroupRepo.createGroupWithSupervisor({ - ...input, - config: ctx.agentGroupService.normalizeGroupConfig(input.config as ChatGroupConfig | null), - }); + createGroup: agentGroupProcedureWrite + .input(InsertChatGroupSchema) + .mutation(async ({ input, ctx }) => { + const { group, supervisorAgentId } = await ctx.agentGroupRepo.createGroupWithSupervisor({ + ...input, + config: ctx.agentGroupService.normalizeGroupConfig(input.config as ChatGroupConfig | null), + }); - return { group, supervisorAgentId }; - }), + return { group, supervisorAgentId }; + }), /** * Create a group with virtual member agents in one request. @@ -132,7 +145,7 @@ export const agentGroupRouter = router({ * 3. Create the group with supervisor and member agents * Returns the groupId, supervisorAgentId, and created member agentIds. */ - createGroupWithMembers: agentGroupProcedure + createGroupWithMembers: agentGroupProcedureWrite .input( z.object({ groupConfig: InsertChatGroupSchema, @@ -188,7 +201,7 @@ export const agentGroupRouter = router({ return { agentIds: memberAgentIds, groupId: group.id, supervisorAgentId }; }), - deleteGroup: agentGroupProcedure + deleteGroup: agentGroupProcedureWrite .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.agentGroupService.deleteGroup(input.id); @@ -199,7 +212,7 @@ export const agentGroupRouter = router({ * Creates a new group with the same config, a new supervisor, and copies of virtual members. * Non-virtual members are referenced (not copied). */ - duplicateGroup: agentGroupProcedure + duplicateGroup: agentGroupProcedureWrite .input( z.object({ groupId: z.string(), @@ -273,7 +286,7 @@ export const agentGroupRouter = router({ * @param agentIds - Array of agent IDs to remove * @param deleteVirtualAgents - Whether to delete virtual agents (default: true) */ - removeAgentsFromGroup: agentGroupProcedure + removeAgentsFromGroup: agentGroupProcedureWrite .input( z.object({ agentIds: z.array(z.string()), @@ -289,7 +302,83 @@ export const agentGroupRouter = router({ ); }), - updateAgentInGroup: agentGroupProcedure + transferGroup: agentGroupProcedureWrite + .input( + z.object({ + groupId: z.string(), + targetWorkspaceId: z.string().nullable(), + }), + ) + .mutation(async ({ input, ctx }) => { + const group = await ctx.chatGroupModel.findById(input.groupId); + if (!group) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'Agent group not found', + }); + } + + if (ctx.workspaceId && group.userId !== ctx.userId) { + const [membership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, ctx.workspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + + if (!membership || membership.role !== 'owner') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.OwnerOnly } }, + code: 'FORBIDDEN', + message: 'Only workspace owners can transfer agent groups created by others', + }); + } + } + + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + if (input.targetWorkspaceId === ctx.workspaceId) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.SameWorkspace } }, + code: 'BAD_REQUEST', + message: 'Cannot transfer agent group to the same workspace', + }); + } + + return ctx.agentGroupRepo.transferToWorkspace( + input.groupId, + input.targetWorkspaceId, + ctx.userId, + ); + }), + + updateAgentInGroup: agentGroupProcedureWrite .input( z.object({ agentId: z.string(), @@ -305,7 +394,7 @@ export const agentGroupRouter = router({ return ctx.chatGroupModel.updateAgentInGroup(input.groupId, input.agentId, input.updates); }), - updateGroup: agentGroupProcedure + updateGroup: agentGroupProcedureWrite .input( z.object({ id: z.string(), diff --git a/src/server/routers/lambda/agentNotify.ts b/src/server/routers/lambda/agentNotify.ts index de0ec31e08..7e8cec23bd 100644 --- a/src/server/routers/lambda/agentNotify.ts +++ b/src/server/routers/lambda/agentNotify.ts @@ -3,9 +3,11 @@ import { TRPCError } from '@trpc/server'; import debug from 'debug'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { MessageModel } from '@/database/models/message'; import { TopicModel } from '@/database/models/topic'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { createStreamEventManager } from '@/server/modules/AgentRuntime/factory'; import { AiAgentService } from '@/server/services/aiAgent'; @@ -19,17 +21,19 @@ const getStreamManager = () => { const log = debug('lobe-server:agent-notify-router'); -const agentNotifyProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const agentNotifyProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - aiAgentService: new AiAgentService(ctx.serverDB, ctx.userId), - messageModel: new MessageModel(ctx.serverDB, ctx.userId), - topicModel: new TopicModel(ctx.serverDB, ctx.userId), + aiAgentService: new AiAgentService(ctx.serverDB, ctx.userId, { workspaceId: wsId }), + messageModel: new MessageModel(ctx.serverDB, ctx.userId, wsId), + topicModel: new TopicModel(ctx.serverDB, ctx.userId, wsId), }, }); }); +const agentNotifyWriteProcedure = agentNotifyProcedure.use(withScopedPermission('message:create')); const NotifySchema = z.object({ /** Agent ID to trigger (overrides the topic's default agent) */ @@ -76,7 +80,7 @@ export const agentNotifyRouter = router({ * role='assistant': content is written directly as an assistant message * continue=true: also trigger a new agent turn after writing */ - notify: agentNotifyProcedure.input(NotifySchema).mutation(async ({ input, ctx }) => { + notify: agentNotifyWriteProcedure.input(NotifySchema).mutation(async ({ input, ctx }) => { const { topicId, content, diff --git a/src/server/routers/lambda/agentSignal.ts b/src/server/routers/lambda/agentSignal.ts index a69160f077..65143b9ae2 100644 --- a/src/server/routers/lambda/agentSignal.ts +++ b/src/server/routers/lambda/agentSignal.ts @@ -5,7 +5,9 @@ import { import debug from 'debug'; import { z } from 'zod'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; +import { router } from '@/libs/trpc/lambda'; import { enqueueAgentSignalSourceEvent } from '@/server/services/agentSignal'; import { listAgentSignalReceipts } from '@/server/services/agentSignal/services/receiptService'; import { @@ -15,14 +17,15 @@ import { const log = debug('lobe-server:agent-signal:router'); -const agentSignalProcedure = authedProcedure; +const agentSignalProcedure = wsCompatProcedure; +const agentSignalWriteProcedure = agentSignalProcedure.use(withScopedPermission('message:create')); const clientSourceTypes = AGENT_SIGNAL_CLIENT_SOURCE_TYPES; type ClientSourceType = (typeof clientSourceTypes)[number]; type ClientSourceEventInput = AgentSignalSourceEventInput; export const agentSignalRouter = router({ - emitSourceEvent: agentSignalProcedure + emitSourceEvent: agentSignalWriteProcedure .input( z.object({ payload: z.record(z.string(), z.unknown()), @@ -46,6 +49,7 @@ export const agentSignalRouter = router({ return enqueueAgentSignalSourceEvent(input as unknown as ClientSourceEventInput, { agentId: typeof input.payload.agentId === 'string' ? input.payload.agentId : undefined, userId: ctx.userId, + workspaceId: ctx.workspaceId ?? undefined, }); }), triggerSourceEvent: agentSignalProcedure diff --git a/src/server/routers/lambda/agentSkills.ts b/src/server/routers/lambda/agentSkills.ts index 7d5720b010..218a126127 100644 --- a/src/server/routers/lambda/agentSkills.ts +++ b/src/server/routers/lambda/agentSkills.ts @@ -3,9 +3,11 @@ import { skillManifestSchema } from '@lobechat/types'; import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentSkillModel } from '@/database/models/agentSkill'; import { FileModel } from '@/database/models/file'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { MarketService } from '@/server/services/market'; @@ -51,28 +53,42 @@ const handleSkillImportError = (error: unknown): never => { throw error; }; -// ===== Procedure with Context ===== +// ===== Procedures with Context ===== -const skillProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +// Reads: workspace-aware, any member can read. In personal mode the request +// runs without workspace context (legacy behavior preserved). +const skillProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; - const skillModel = new AgentSkillModel(ctx.serverDB, ctx.userId); + const workspaceId = ctx.workspaceId ?? undefined; + const skillModel = new AgentSkillModel(ctx.serverDB, ctx.userId, workspaceId); return opts.next({ ctx: { - fileModel: new FileModel(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), + fileModel: new FileModel(ctx.serverDB, ctx.userId, workspaceId), + fileService: new FileService(ctx.serverDB, ctx.userId, workspaceId), marketService: new MarketService({ userInfo: { userId: ctx.userId } }), - skillImporter: new SkillImporter(ctx.serverDB, ctx.userId), + skillImporter: new SkillImporter(ctx.serverDB, ctx.userId, workspaceId), skillModel, }, }); }); +// Writes: workspace mode goes through RBAC (`agent:update:all | :owner`), +// gating viewers out while letting members and owners install/edit skills. +// Personal mode is unrestricted (middleware passes through when no +// workspaceId). Replaces the legacy `requireWorkspaceRoleWhenScoped('owner')` +// which was overly restrictive (member should be able to manage skills they +// own, per the role-permission matrix in @lobechat/const/rbac). +const skillWriteProcedure = skillProcedure.use(withScopedPermission('agent:update')); + const skillResourceProcedure = skillProcedure.use(async (opts) => { const { ctx } = opts; return opts.next({ ctx: { + // workspace-audit: intentionally personal-scoped (no workspaceId). This service + // only reads skill resource files by content hash (global, deduplicated files), + // never runs a per-workspace row query, so workspace scoping is a no-op here. skillResourceService: new SkillResourceService(ctx.serverDB, ctx.userId), }, }); @@ -99,7 +115,7 @@ const updateSkillSchema = z.object({ export const agentSkillsRouter = router({ // ===== Create ===== - create: skillProcedure.input(createSkillSchema).mutation(async ({ ctx, input }) => { + create: skillWriteProcedure.input(createSkillSchema).mutation(async ({ ctx, input }) => { try { return await ctx.skillImporter.createUserSkill(input); } catch (error) { @@ -109,9 +125,11 @@ export const agentSkillsRouter = router({ // ===== Delete ===== - delete: skillProcedure.input(z.object({ id: z.string() })).mutation(async ({ ctx, input }) => { - return ctx.skillModel.delete(input.id); - }), + delete: skillWriteProcedure + .input(z.object({ id: z.string() })) + .mutation(async ({ ctx, input }) => { + return ctx.skillModel.delete(input.id); + }), // ===== Query ===== @@ -150,7 +168,7 @@ export const agentSkillsRouter = router({ return ctx.skillModel.findByName(input.name); }), - importFromGitHub: skillProcedure + importFromGitHub: skillWriteProcedure .input( z.object({ branch: z.string().optional(), @@ -165,7 +183,7 @@ export const agentSkillsRouter = router({ } }), - importFromUrl: skillProcedure + importFromUrl: skillWriteProcedure .input(z.object({ url: z.string().url() })) .mutation(async ({ ctx, input }) => { try { @@ -175,7 +193,7 @@ export const agentSkillsRouter = router({ } }), - importFromZip: skillProcedure + importFromZip: skillWriteProcedure .input(z.object({ zipFileId: z.string() })) .mutation(async ({ ctx, input }) => { try { @@ -185,7 +203,7 @@ export const agentSkillsRouter = router({ } }), - importFromMarket: skillProcedure + importFromMarket: skillWriteProcedure .input(z.object({ identifier: z.string() })) .mutation(async ({ ctx, input }) => { try { @@ -266,7 +284,7 @@ export const agentSkillsRouter = router({ // ===== Update ===== - update: skillProcedure.input(updateSkillSchema).mutation(async ({ ctx, input }) => { + update: skillWriteProcedure.input(updateSkillSchema).mutation(async ({ ctx, input }) => { const { id, content, manifest } = input; return ctx.skillModel.update(id, { content, diff --git a/src/server/routers/lambda/aiAgent.ts b/src/server/routers/lambda/aiAgent.ts index 1d39902a13..083dda93fc 100644 --- a/src/server/routers/lambda/aiAgent.ts +++ b/src/server/routers/lambda/aiAgent.ts @@ -9,15 +9,18 @@ import { } from '@lobechat/types'; import { TRPCError } from '@trpc/server'; import debug from 'debug'; +import { and, eq } from 'drizzle-orm'; import pMap from 'p-map'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { MessageModel } from '@/database/models/message'; import { TaskModel } from '@/database/models/task'; -import { TaskTopicModel } from '@/database/models/taskTopic'; import { ThreadModel } from '@/database/models/thread'; import { TopicModel } from '@/database/models/topic'; -import { authedProcedure, heteroAuthedProcedure, router } from '@/libs/trpc/lambda'; +import { taskTopics, topics } from '@/database/schemas'; +import { heteroAuthedProcedure, router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { AgentRuntimeService } from '@/server/services/agentRuntime'; import { AiAgentService } from '@/server/services/aiAgent'; @@ -404,18 +407,23 @@ const HeteroFinishSchema = z.object({ topicId: z.string().min(1), }); -const aiAgentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const aiAgentProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentRuntimeService: new AgentRuntimeService(ctx.serverDB, ctx.userId), - aiAgentService: new AiAgentService(ctx.serverDB, ctx.userId), - aiChatService: new AiChatService(ctx.serverDB, ctx.userId), - heterogeneousAgentService: new HeterogeneousAgentService(ctx.serverDB, ctx.userId), - messageModel: new MessageModel(ctx.serverDB, ctx.userId), - threadModel: new ThreadModel(ctx.serverDB, ctx.userId), - topicModel: new TopicModel(ctx.serverDB, ctx.userId), + agentRuntimeService: new AgentRuntimeService(ctx.serverDB, ctx.userId, { + workspaceId: wsId, + }), + aiAgentService: new AiAgentService(ctx.serverDB, ctx.userId, { workspaceId: wsId }), + aiChatService: new AiChatService(ctx.serverDB, ctx.userId, wsId), + heterogeneousAgentService: new HeterogeneousAgentService(ctx.serverDB, ctx.userId, { + workspaceId: wsId, + }), + messageModel: new MessageModel(ctx.serverDB, ctx.userId, wsId), + threadModel: new ThreadModel(ctx.serverDB, ctx.userId, wsId), + topicModel: new TopicModel(ctx.serverDB, ctx.userId, wsId), }, }); }); @@ -423,15 +431,12 @@ const aiAgentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => // Dedicated procedure for hetero-agent ingest/finish endpoints. // Requires a `hetero-operation` JWT (4h expiry) — normal user tokens are rejected, // so only the sandbox/device that received the JWT from execAgent can call these. -const heteroAgentProcedure = heteroAuthedProcedure.use(serverDatabase).use(async (opts) => { - const { ctx } = opts; - - return opts.next({ - ctx: { - heterogeneousAgentService: new HeterogeneousAgentService(ctx.serverDB, ctx.userId), - }, - }); -}); +// +// Note: workspaceId is not on `ctx` for this procedure (the JWT is server-to-server +// and carries no workspace claim). Handlers must resolve wsId from the row keyed +// by `topicId` and construct `HeterogeneousAgentService` per request. +const heteroAgentProcedure = heteroAuthedProcedure.use(serverDatabase); +const aiAgentWriteProcedure = aiAgentProcedure.use(withScopedPermission('message:create')); export const aiAgentRouter = router({ /** @@ -442,7 +447,7 @@ export const aiAgentRouter = router({ * - The subAgentId is the worker agent that executes the task * - Thread messages query should not filter by agentId to include all parent messages */ - createClientGroupAgentTaskThread: aiAgentProcedure + createClientGroupAgentTaskThread: aiAgentWriteProcedure .input(CreateClientGroupAgentTaskThreadSchema) .mutation(async ({ input, ctx }) => { const { groupId, instruction, parentMessageId, subAgentId, title, topicId } = input; @@ -537,7 +542,7 @@ export const aiAgentRouter = router({ * This endpoint is called by desktop client when runInClient=true. * It creates the Thread but does NOT execute the task - execution happens on client side. */ - createClientTaskThread: aiAgentProcedure + createClientTaskThread: aiAgentWriteProcedure .input(CreateClientTaskThreadSchema) .mutation(async ({ input, ctx }) => { const { agentId, groupId, instruction, parentMessageId, title, topicId } = input; @@ -622,7 +627,7 @@ export const aiAgentRouter = router({ } }), - execAgent: aiAgentProcedure.input(ExecAgentSchema).mutation(async ({ input, ctx }) => { + execAgent: aiAgentWriteProcedure.input(ExecAgentSchema).mutation(async ({ input, ctx }) => { const { agentId, slug, @@ -677,7 +682,7 @@ export const aiAgentRouter = router({ * Batch execute multiple agents * Supports parallel or sequential execution */ - execAgents: aiAgentProcedure.input(ExecAgentsSchema).mutation(async ({ input, ctx }) => { + execAgents: aiAgentWriteProcedure.input(ExecAgentsSchema).mutation(async ({ input, ctx }) => { const { tasks, parallel = true } = input; log('execAgents: %d tasks, parallel=%s', tasks.length, parallel); @@ -769,48 +774,50 @@ export const aiAgentRouter = router({ * 4. Trigger Supervisor Agent execution * 5. Return operationId for SSE connection + messages for UI sync */ - execGroupAgent: aiAgentProcedure.input(ExecGroupAgentSchema).mutation(async ({ input, ctx }) => { - const { agentId, groupId, message, files, topicId, newTopic } = input; + execGroupAgent: aiAgentWriteProcedure + .input(ExecGroupAgentSchema) + .mutation(async ({ input, ctx }) => { + const { agentId, groupId, message, files, topicId, newTopic } = input; - log('execGroupAgent: agentId=%s, groupId=%s', agentId, groupId); + log('execGroupAgent: agentId=%s, groupId=%s', agentId, groupId); - try { - // Execute group agent - const result = await ctx.aiAgentService.execGroupAgent({ - agentId, - files, - groupId, - message, - newTopic, - topicId, - }); + try { + // Execute group agent + const result = await ctx.aiAgentService.execGroupAgent({ + agentId, + files, + groupId, + message, + newTopic, + topicId, + }); - // Get messages and topics for UI sync - // Messages include the assistant message with error if operation failed to start - const { messages, topics } = await ctx.aiChatService.getMessagesAndTopics({ - agentId, - groupId, - includeTopic: result.isCreateNewTopic, - topicId: result.topicId, - }); + // Get messages and topics for UI sync + // Messages include the assistant message with error if operation failed to start + const { messages, topics } = await ctx.aiChatService.getMessagesAndTopics({ + agentId, + groupId, + includeTopic: result.isCreateNewTopic, + topicId: result.topicId, + }); - // Return result with messages/topics - includes error/success fields - // Frontend can check success to decide whether to connect to SSE stream - return { ...result, messages, topics }; - } catch (error: any) { - log('execGroupAgent failed: %O', error); + // Return result with messages/topics - includes error/success fields + // Frontend can check success to decide whether to connect to SSE stream + return { ...result, messages, topics }; + } catch (error: any) { + log('execGroupAgent failed: %O', error); - if (error instanceof TRPCError) { - throw error; + if (error instanceof TRPCError) { + throw error; + } + + throw new TRPCError({ + cause: error, + code: 'INTERNAL_SERVER_ERROR', + message: `Failed to execute group agent: ${error.message}`, + }); } - - throw new TRPCError({ - cause: error, - code: 'INTERNAL_SERVER_ERROR', - message: `Failed to execute group agent: ${error.message}`, - }); - } - }), + }), /** * Execute SubAgent task (supports both Group and Single Agent mode) @@ -821,7 +828,7 @@ export const aiAgentRouter = router({ * - Group mode: pass groupId, Thread will be associated with the Group * - Single Agent mode: omit groupId, Thread will only be associated with the Agent */ - execSubAgentTask: aiAgentProcedure + execSubAgentTask: aiAgentWriteProcedure .input(ExecSubAgentTaskSchema) .mutation(async ({ input, ctx }) => { const { agentId, groupId, instruction, parentMessageId, title, topicId, timeout } = input; @@ -1150,23 +1157,25 @@ export const aiAgentRouter = router({ * This endpoint interrupts a SubAgent task by threadId or operationId. * It updates both operation status and Thread status to cancelled state. */ - interruptTask: aiAgentProcedure.input(InterruptTaskSchema).mutation(async ({ input, ctx }) => { - const { threadId, operationId, topicId } = input; + interruptTask: aiAgentWriteProcedure + .input(InterruptTaskSchema) + .mutation(async ({ input, ctx }) => { + const { threadId, operationId, topicId } = input; - log('interruptTask: threadId=%s, operationId=%s, topicId=%s', threadId, operationId, topicId); + log('interruptTask: threadId=%s, operationId=%s, topicId=%s', threadId, operationId, topicId); - try { - return await ctx.aiAgentService.interruptTask({ operationId, threadId, topicId }); - } catch (error: any) { - if (error.message === 'Thread not found') { - throw new TRPCError({ code: 'NOT_FOUND', message: 'Thread not found' }); + try { + return await ctx.aiAgentService.interruptTask({ operationId, threadId, topicId }); + } catch (error: any) { + if (error.message === 'Thread not found') { + throw new TRPCError({ code: 'NOT_FOUND', message: 'Thread not found' }); + } + if (error.message === 'Operation ID not found') { + throw new TRPCError({ code: 'BAD_REQUEST', message: 'Operation ID not found' }); + } + throw error; } - if (error.message === 'Operation ID not found') { - throw new TRPCError({ code: 'BAD_REQUEST', message: 'Operation ID not found' }); - } - throw error; - } - }), + }), /** * Ingest a batch of `AgentStreamEvent`s from a `lh hetero exec` producer @@ -1186,10 +1195,25 @@ export const aiAgentRouter = router({ ); try { + // Resolve workspaceId from the topic row so persistence writes land in + // the correct workspace scope. heteroAuthedProcedure carries no + // workspace claim, so we must look it up here per request. We bypass + // `TopicModel.findById` because it filters by workspace; here we need a + // workspace-agnostic lookup keyed only by topicId + userId. + const [topicRow] = await ctx.serverDB + .select({ workspaceId: topics.workspaceId }) + .from(topics) + .where(and(eq(topics.id, topicId), eq(topics.userId, ctx.userId))) + .limit(1); + const wsId = topicRow?.workspaceId ?? undefined; + const heteroService = new HeterogeneousAgentService(ctx.serverDB, ctx.userId, { + workspaceId: wsId, + }); + // Zod's z.any() infers `data?: any`, but the wire shape always includes // a `data` field (may be null). Cast at the boundary instead of widening // the shared `AgentStreamEvent` type or the service signature. - await ctx.heterogeneousAgentService.heteroIngest({ + await heteroService.heteroIngest({ agentType, assistantMessageId, events: events as AgentStreamEvent[], @@ -1219,7 +1243,19 @@ export const aiAgentRouter = router({ log('heteroFinish: topic=%s op=%s type=%s result=%s', topicId, operationId, agentType, result); try { - await ctx.heterogeneousAgentService.heteroFinish({ + // Resolve workspaceId from the topic row (heteroAuthedProcedure has no + // workspace claim) so persistence writes land in the correct scope. + const [topicRow] = await ctx.serverDB + .select({ workspaceId: topics.workspaceId }) + .from(topics) + .where(and(eq(topics.id, topicId), eq(topics.userId, ctx.userId))) + .limit(1); + const wsId = topicRow?.workspaceId ?? undefined; + const heteroService = new HeterogeneousAgentService(ctx.serverDB, ctx.userId, { + workspaceId: wsId, + }); + + await heteroService.heteroFinish({ agentType, error, operationId, @@ -1241,15 +1277,22 @@ export const aiAgentRouter = router({ // topic is already in a terminal state. const TERMINAL_TOPIC_STATUSES = new Set(['canceled', 'completed', 'failed', 'timeout']); try { - const taskTopicModel = new TaskTopicModel(ctx.serverDB, ctx.userId); - const taskTopic = await taskTopicModel.findByTopicId(topicId); - if (taskTopic && !TERMINAL_TOPIC_STATUSES.has(taskTopic.status)) { - const taskModel = new TaskModel(ctx.serverDB, ctx.userId); - const task = await taskModel.findById(taskTopic.taskId); + // System-level lookup: heteroFinish is a server-to-server callback from + // the CLI and doesn't carry a workspace context. Resolve the task topic + // (and downstream models) using the row's own `workspaceId`. + const [taskTopicRow] = await ctx.serverDB + .select() + .from(taskTopics) + .where(and(eq(taskTopics.topicId, topicId), eq(taskTopics.userId, ctx.userId))) + .limit(1); + if (taskTopicRow && !TERMINAL_TOPIC_STATUSES.has(taskTopicRow.status)) { + const wsId = taskTopicRow.workspaceId ?? undefined; + const taskModel = new TaskModel(ctx.serverDB, ctx.userId, wsId); + const task = await taskModel.findById(taskTopicRow.taskId); if (task) { const reason = result === 'success' ? 'done' : result === 'cancelled' ? 'interrupted' : 'error'; - const taskLifecycle = new TaskLifecycleService(ctx.serverDB, ctx.userId); + const taskLifecycle = new TaskLifecycleService(ctx.serverDB, ctx.userId, wsId); await taskLifecycle.onTopicComplete({ errorMessage: error?.message, operationId, @@ -1277,7 +1320,7 @@ export const aiAgentRouter = router({ } }), - processHumanIntervention: aiAgentProcedure + processHumanIntervention: aiAgentWriteProcedure .input(ProcessHumanInterventionSchema) .mutation(async ({ input, ctx }) => { const { operationId, action, data, reason, stepIndex, toolMessageId } = input; @@ -1348,25 +1391,27 @@ export const aiAgentRouter = router({ }; }), - startExecution: aiAgentProcedure.input(StartExecutionSchema).mutation(async ({ input, ctx }) => { - const { operationId, context, priority, delay } = input; + startExecution: aiAgentWriteProcedure + .input(StartExecutionSchema) + .mutation(async ({ input, ctx }) => { + const { operationId, context, priority, delay } = input; - log('Starting execution for operation %s', operationId); + log('Starting execution for operation %s', operationId); - // Start execution using AgentRuntimeService - const result = await ctx.agentRuntimeService.startExecution({ - context, - delay, - operationId, - priority, - }); + // Start execution using AgentRuntimeService + const result = await ctx.agentRuntimeService.startExecution({ + context, + delay, + operationId, + priority, + }); - return { - ...result, - message: 'Agent execution started successfully', - timestamp: new Date().toISOString(), - }; - }), + return { + ...result, + message: 'Agent execution started successfully', + timestamp: new Date().toISOString(), + }; + }), /** * Update Thread status after client-side task execution completes @@ -1374,7 +1419,7 @@ export const aiAgentRouter = router({ * This endpoint is called by desktop client after task execution finishes. * It updates the Thread status and metadata similar to server-side completion. */ - updateClientTaskThreadStatus: aiAgentProcedure + updateClientTaskThreadStatus: aiAgentWriteProcedure .input(UpdateClientTaskThreadStatusSchema) .mutation(async ({ input, ctx }) => { const { threadId, completionReason, error, resultContent, metadata } = input; diff --git a/src/server/routers/lambda/aiChat.ts b/src/server/routers/lambda/aiChat.ts index 5e5df9efde..f610e03fb1 100644 --- a/src/server/routers/lambda/aiChat.ts +++ b/src/server/routers/lambda/aiChat.ts @@ -9,12 +9,14 @@ import { getStatusKeyFromCode } from '@trpc/server/unstable-core-do-not-import'; import debug from 'debug'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { LOADING_FLAT } from '@/const/message'; import { AgentModel } from '@/database/models/agent'; import { MessageModel } from '@/database/models/message'; import { ThreadModel } from '@/database/models/thread'; import { TopicModel } from '@/database/models/topic'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { resolveContext } from '@/server/routers/lambda/_helpers/resolveContext'; import { AiChatService } from '@/server/services/aiChat'; @@ -59,24 +61,27 @@ const createRuntimeTRPCError = (error: unknown): TRPCError | undefined => { }); }; -const aiChatProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const aiChatProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentModel: new AgentModel(ctx.serverDB, ctx.userId), - aiChatService: new AiChatService(ctx.serverDB, ctx.userId), - aiGenerationService: new AiGenerationService(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), - messageModel: new MessageModel(ctx.serverDB, ctx.userId), - threadModel: new ThreadModel(ctx.serverDB, ctx.userId), - topicModel: new TopicModel(ctx.serverDB, ctx.userId), + agentModel: new AgentModel(ctx.serverDB, ctx.userId, wsId), + aiChatService: new AiChatService(ctx.serverDB, ctx.userId, wsId), + aiGenerationService: new AiGenerationService(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), + messageModel: new MessageModel(ctx.serverDB, ctx.userId, wsId), + threadModel: new ThreadModel(ctx.serverDB, ctx.userId, wsId), + topicModel: new TopicModel(ctx.serverDB, ctx.userId, wsId), }, }); }); +const aiChatWriteProcedure = aiChatProcedure.use(withScopedPermission('message:create')); + export const aiChatRouter = router({ - outputJSON: aiChatProcedure.input(StructureOutputSchema).mutation(async ({ input, ctx }) => { + outputJSON: aiChatWriteProcedure.input(StructureOutputSchema).mutation(async ({ input, ctx }) => { log('outputJSON called with provider: %s, model: %s', input.provider, input.model); log('messages count: %d', input.messages.length); log('schema: %O', input.schema); @@ -118,7 +123,7 @@ export const aiChatRouter = router({ return { data, tracingId }; }), - sendMessageInServer: aiChatProcedure + sendMessageInServer: aiChatWriteProcedure .input(AiSendMessageServerSchema) .mutation(async ({ input, ctx }) => { const timingContext = @@ -144,7 +149,7 @@ export const aiChatRouter = router({ const context = await runTimedStage( timingContext, 'lambda.aiChat.resolveContext', - () => resolveContext(input, ctx.serverDB, ctx.userId), + () => resolveContext(input, ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined), { hasAgentId: !!input.agentId }, ); if (!!context.sessionId) sessionId = context.sessionId; @@ -405,6 +410,7 @@ export const aiChatRouter = router({ ...input, serverDB: ctx.serverDB, userId: ctx.userId, + workspaceId: ctx.workspaceId ?? undefined, }); }), }); diff --git a/src/server/routers/lambda/aiModel.ts b/src/server/routers/lambda/aiModel.ts index 468e3cb90a..25d9094629 100644 --- a/src/server/routers/lambda/aiModel.ts +++ b/src/server/routers/lambda/aiModel.ts @@ -7,17 +7,20 @@ import { } from 'model-bank'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AiModelModel } from '@/database/models/aiModel'; import { UserModel } from '@/database/models/user'; import { AiInfraRepos } from '@/database/repositories/aiInfra'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { getServerGlobalConfig } from '@/server/globalConfig'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import { type ProviderConfig } from '@/types/user/settings'; -const aiModelProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const aiModelProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); const { aiProvider } = await getServerGlobalConfig(); @@ -38,6 +41,7 @@ const aiModelProcedure = authedProcedure.use(serverDatabase).use(async (opts) => export const aiModelRouter = router({ batchToggleAiModels: aiModelProcedure + .use(withScopedPermission('ai_model:update')) .input( z.object({ enabled: z.boolean(), @@ -49,6 +53,7 @@ export const aiModelRouter = router({ return ctx.aiModelModel.batchToggleAiModels(input.id, input.models, input.enabled); }), batchUpdateAiModels: aiModelProcedure + .use(withScopedPermission('ai_model:update')) .input( z.object({ id: z.string(), @@ -61,21 +66,26 @@ export const aiModelRouter = router({ }), clearModelsByProvider: aiModelProcedure + .use(withScopedPermission('ai_model:delete')) .input(z.object({ providerId: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.aiModelModel.clearModelsByProvider(input.providerId); }), clearRemoteModels: aiModelProcedure + .use(withScopedPermission('ai_model:delete')) .input(z.object({ providerId: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.aiModelModel.clearRemoteModels(input.providerId); }), - createAiModel: aiModelProcedure.input(CreateAiModelSchema).mutation(async ({ input, ctx }) => { - const data = await ctx.aiModelModel.create(input); + createAiModel: aiModelProcedure + .use(withScopedPermission('ai_model:create')) + .input(CreateAiModelSchema) + .mutation(async ({ input, ctx }) => { + const data = await ctx.aiModelModel.create(input); - return data?.id; - }), + return data?.id; + }), getAiModelById: aiModelProcedure .input(z.object({ id: z.string() })) @@ -104,18 +114,21 @@ export const aiModelRouter = router({ }), removeAiModel: aiModelProcedure + .use(withScopedPermission('ai_model:delete')) .input(z.object({ id: z.string(), providerId: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.aiModelModel.delete(input.id, input.providerId); }), toggleModelEnabled: aiModelProcedure + .use(withScopedPermission('ai_model:update')) .input(ToggleAiModelEnableSchema) .mutation(async ({ input, ctx }) => { return ctx.aiModelModel.toggleModelEnabled(input); }), updateAiModel: aiModelProcedure + .use(withScopedPermission('ai_model:update')) .input( z.object({ id: z.string(), @@ -128,6 +141,7 @@ export const aiModelRouter = router({ }), updateAiModelOrder: aiModelProcedure + .use(withScopedPermission('ai_model:update')) .input( z.object({ providerId: z.string(), diff --git a/src/server/routers/lambda/aiProvider.ts b/src/server/routers/lambda/aiProvider.ts index 58e5241a4d..c81c1a1633 100644 --- a/src/server/routers/lambda/aiProvider.ts +++ b/src/server/routers/lambda/aiProvider.ts @@ -2,10 +2,12 @@ import { isOfficialProvider, OFFICIAL_PROVIDER_DISABLE_ERROR } from '@lobechat/b import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AiProviderModel } from '@/database/models/aiProvider'; import { UserModel } from '@/database/models/user'; import { AiInfraRepos } from '@/database/repositories/aiInfra'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { getServerGlobalConfig } from '@/server/globalConfig'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; @@ -18,8 +20,9 @@ import { } from '@/types/aiProvider'; import { type ProviderConfig } from '@/types/user/settings'; -const aiProviderProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const aiProviderProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; const { aiProvider } = await getServerGlobalConfig(); @@ -40,6 +43,7 @@ const aiProviderProcedure = authedProcedure.use(serverDatabase).use(async (opts) export const aiProviderRouter = router({ checkProviderConnectivity: aiProviderProcedure + .use(withScopedPermission('ai_provider:update')) .input( z.object({ id: z.string(), @@ -59,7 +63,12 @@ export const aiProviderRouter = router({ } try { - const modelRuntime = await initModelRuntimeFromDB(ctx.serverDB, ctx.userId, input.id); + const modelRuntime = await initModelRuntimeFromDB( + ctx.serverDB, + ctx.userId, + input.id, + ctx.workspaceId ?? undefined, + ); const response = await modelRuntime.chat({ messages: [{ content: 'Hi', role: 'user' }], @@ -87,6 +96,7 @@ export const aiProviderRouter = router({ }), createAiProvider: aiProviderProcedure + .use(withScopedPermission('ai_provider:create')) .input(CreateAiProviderSchema) .mutation(async ({ input, ctx }) => { try { @@ -122,12 +132,14 @@ export const aiProviderRouter = router({ }), removeAiProvider: aiProviderProcedure + .use(withScopedPermission('ai_provider:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.aiProviderModel.delete(input.id); }), toggleProviderEnabled: aiProviderProcedure + .use(withScopedPermission('ai_provider:update')) .input( z.object({ enabled: z.boolean(), @@ -146,6 +158,7 @@ export const aiProviderRouter = router({ }), updateAiProvider: aiProviderProcedure + .use(withScopedPermission('ai_provider:update')) .input( z.object({ id: z.string(), @@ -157,6 +170,7 @@ export const aiProviderRouter = router({ }), updateAiProviderConfig: aiProviderProcedure + .use(withScopedPermission('ai_provider:update')) .input( z.object({ id: z.string(), @@ -173,6 +187,7 @@ export const aiProviderRouter = router({ }), updateAiProviderOrder: aiProviderProcedure + .use(withScopedPermission('ai_provider:update')) .input( z.object({ sortMap: z.array( diff --git a/src/server/routers/lambda/apiKey.ts b/src/server/routers/lambda/apiKey.ts index 4440d4a906..5431bbce36 100644 --- a/src/server/routers/lambda/apiKey.ts +++ b/src/server/routers/lambda/apiKey.ts @@ -1,21 +1,25 @@ import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { ApiKeyModel } from '@/database/models/apiKey'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; -const apiKeyProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const apiKeyProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - apiKeyModel: new ApiKeyModel(ctx.serverDB, ctx.userId), + apiKeyModel: new ApiKeyModel(ctx.serverDB, ctx.userId, wsId), }, }); }); export const apiKeyRouter = router({ createApiKey: apiKeyProcedure + .use(withScopedPermission('api_key:create')) .input( z.object({ expiresAt: z.date().optional().nullable(), @@ -26,11 +30,14 @@ export const apiKeyRouter = router({ return await ctx.apiKeyModel.create(input); }), - deleteAllApiKeys: apiKeyProcedure.mutation(async ({ ctx }) => { - return ctx.apiKeyModel.deleteAll(); - }), + deleteAllApiKeys: apiKeyProcedure + .use(withScopedPermission('api_key:delete')) + .mutation(async ({ ctx }) => { + return ctx.apiKeyModel.deleteAll(); + }), deleteApiKey: apiKeyProcedure + .use(withScopedPermission('api_key:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.apiKeyModel.delete(input.id); @@ -53,6 +60,7 @@ export const apiKeyRouter = router({ }), updateApiKey: apiKeyProcedure + .use(withScopedPermission('api_key:update')) .input( z.object({ id: z.string(), diff --git a/src/server/routers/lambda/botMessage.ts b/src/server/routers/lambda/botMessage.ts index 7c9d67c195..87030143cb 100644 --- a/src/server/routers/lambda/botMessage.ts +++ b/src/server/routers/lambda/botMessage.ts @@ -11,12 +11,14 @@ import { import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { getMessengerTelegramConfig } from '@/config/messenger'; import type { DecryptedBotProvider } from '@/database/models/agentBotProvider'; import { AgentBotProviderModel } from '@/database/models/agentBotProvider'; import { MessengerAccountLinkModel } from '@/database/models/messengerAccountLink'; import { MessengerInstallationModel } from '@/database/models/messengerInstallation'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import { mergeWithDefaults, platformRegistry } from '@/server/services/bot/platforms'; @@ -33,16 +35,18 @@ import { TELEGRAM_INSTALLATION_KEY } from '@/server/services/messenger/installat // ── Middleware ──────────────────────────────────────────── -const botMessageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const botMessageProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentBotProviderModel: new AgentBotProviderModel(ctx.serverDB, ctx.userId, gateKeeper), + agentBotProviderModel: new AgentBotProviderModel(ctx.serverDB, ctx.userId, gateKeeper, wsId), }, }); }); +const botMessageWriteProcedure = botMessageProcedure.use(withScopedPermission('message:create')); // ── Shared input schemas ───────────────────────────────── @@ -262,7 +266,7 @@ const resolveSendTarget = async ( export const botMessageRouter = router({ // ==================== Direct Messaging ==================== - sendDirectMessage: botMessageProcedure + sendDirectMessage: botMessageWriteProcedure .input( z .object({ @@ -294,7 +298,7 @@ export const botMessageRouter = router({ // ==================== Core Message Operations ==================== - sendMessage: botMessageProcedure + sendMessage: botMessageWriteProcedure .input( z .object({ @@ -359,7 +363,7 @@ export const botMessageRouter = router({ }); }), - editMessage: botMessageProcedure + editMessage: botMessageWriteProcedure .input( z.object({ botId: z.string(), @@ -378,7 +382,7 @@ export const botMessageRouter = router({ }); }), - deleteMessage: botMessageProcedure + deleteMessage: botMessageWriteProcedure .input( z.object({ botId: z.string(), @@ -418,7 +422,7 @@ export const botMessageRouter = router({ // ==================== Reactions ==================== - reactToMessage: botMessageProcedure + reactToMessage: botMessageWriteProcedure .input( z.object({ botId: z.string(), @@ -456,7 +460,7 @@ export const botMessageRouter = router({ // ==================== Pin Management ==================== - pinMessage: botMessageProcedure + pinMessage: botMessageWriteProcedure .input( z.object({ botId: z.string(), @@ -473,7 +477,7 @@ export const botMessageRouter = router({ }); }), - unpinMessage: botMessageProcedure + unpinMessage: botMessageWriteProcedure .input( z.object({ botId: z.string(), @@ -560,7 +564,7 @@ export const botMessageRouter = router({ // ==================== Thread Operations ==================== - createThread: botMessageProcedure + createThread: botMessageWriteProcedure .input( z.object({ botId: z.string(), @@ -596,7 +600,7 @@ export const botMessageRouter = router({ }); }), - replyToThread: botMessageProcedure + replyToThread: botMessageWriteProcedure .input( z .object({ @@ -622,7 +626,7 @@ export const botMessageRouter = router({ // ==================== Polls ==================== - createPoll: botMessageProcedure + createPoll: botMessageWriteProcedure .input( z.object({ botId: z.string(), diff --git a/src/server/routers/lambda/brief.ts b/src/server/routers/lambda/brief.ts index 811f450ebe..a2e1992376 100644 --- a/src/server/routers/lambda/brief.ts +++ b/src/server/routers/lambda/brief.ts @@ -1,15 +1,18 @@ import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { BriefModel } from '@/database/models/brief'; import { TaskModel } from '@/database/models/task'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { AgentSignalSelfReviewBriefService } from '@/server/services/agentSignal/services/briefs/selfReview'; import { NIGHTLY_REVIEW_BRIEF_TRIGGER } from '@/server/services/agentSignal/services/selfIteration/review/brief'; import { BriefService } from '@/server/services/brief'; -const briefProcedure = authedProcedure.use(serverDatabase); +const briefProcedure = wsCompatProcedure.use(serverDatabase); +const briefWriteProcedure = briefProcedure.use(withScopedPermission('task:update')); const idInput = z.object({ id: z.string() }); @@ -33,7 +36,7 @@ const listSchema = z.object({ }); export const briefRouter = router({ - create: briefProcedure.input(createSchema).mutation(async ({ input, ctx }) => { + create: briefWriteProcedure.input(createSchema).mutation(async ({ input, ctx }) => { try { const { artifacts, ...rest } = input; // Legacy clients pass artifacts as a flat doc-id list; the storage shape @@ -47,12 +50,12 @@ export const briefRouter = router({ // Resolve taskId if it's an identifier if (createData.taskId) { - const taskModel = new TaskModel(ctx.serverDB, ctx.userId); + const taskModel = new TaskModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const task = await taskModel.resolve(createData.taskId); if (task) createData.taskId = task.id; } - const model = new BriefModel(ctx.serverDB, ctx.userId); + const model = new BriefModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const brief = await model.create(createData); return { data: brief, message: 'Brief created', success: true }; } catch (error) { @@ -65,9 +68,9 @@ export const briefRouter = router({ } }), - delete: briefProcedure.input(idInput).mutation(async ({ input, ctx }) => { + delete: briefWriteProcedure.input(idInput).mutation(async ({ input, ctx }) => { try { - const model = new BriefModel(ctx.serverDB, ctx.userId); + const model = new BriefModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const deleted = await model.delete(input.id); if (!deleted) throw new TRPCError({ code: 'NOT_FOUND', message: 'Brief not found' }); return { message: 'Brief deleted', success: true }; @@ -84,7 +87,7 @@ export const briefRouter = router({ find: briefProcedure.input(idInput).query(async ({ input, ctx }) => { try { - const model = new BriefModel(ctx.serverDB, ctx.userId); + const model = new BriefModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const brief = await model.findById(input.id); if (!brief) throw new TRPCError({ code: 'NOT_FOUND', message: 'Brief not found' }); return { data: brief, success: true }; @@ -103,7 +106,7 @@ export const briefRouter = router({ .input(z.object({ taskId: z.string() })) .query(async ({ input, ctx }) => { try { - const model = new BriefModel(ctx.serverDB, ctx.userId); + const model = new BriefModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const items = await model.findByTaskId(input.taskId); return { data: items, success: true }; } catch (error) { @@ -118,7 +121,7 @@ export const briefRouter = router({ list: briefProcedure.input(listSchema).query(async ({ input, ctx }) => { try { - const service = new BriefService(ctx.serverDB, ctx.userId); + const service = new BriefService(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const result = await service.list(input); return { data: result.briefs, success: true, total: result.total }; @@ -134,7 +137,7 @@ export const briefRouter = router({ listUnresolved: briefProcedure.query(async ({ ctx }) => { try { - const service = new BriefService(ctx.serverDB, ctx.userId); + const service = new BriefService(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const data = await service.listUnresolved(); return { data, success: true }; } catch (error) { @@ -147,9 +150,9 @@ export const briefRouter = router({ } }), - markRead: briefProcedure.input(idInput).mutation(async ({ input, ctx }) => { + markRead: briefWriteProcedure.input(idInput).mutation(async ({ input, ctx }) => { try { - const model = new BriefModel(ctx.serverDB, ctx.userId); + const model = new BriefModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const brief = await model.markRead(input.id); if (!brief) throw new TRPCError({ code: 'NOT_FOUND', message: 'Brief not found' }); return { data: brief, message: 'Brief marked as read', success: true }; @@ -164,7 +167,7 @@ export const briefRouter = router({ } }), - resolve: briefProcedure + resolve: briefWriteProcedure .input( idInput.merge( z.object({ @@ -175,7 +178,7 @@ export const briefRouter = router({ ) .mutation(async ({ input, ctx }) => { try { - const model = new BriefModel(ctx.serverDB, ctx.userId); + const model = new BriefModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const currentBrief = await model.findById(input.id); if (!currentBrief) throw new TRPCError({ code: 'NOT_FOUND', message: 'Brief not found' }); @@ -183,13 +186,17 @@ export const briefRouter = router({ action: input.action, comment: input.comment, }; + const wsId = ctx.workspaceId ?? undefined; const brief = currentBrief.trigger === NIGHTLY_REVIEW_BRIEF_TRIGGER - ? await new AgentSignalSelfReviewBriefService(ctx.serverDB, ctx.userId).resolve( + ? await new AgentSignalSelfReviewBriefService(ctx.serverDB, ctx.userId, wsId).resolve( currentBrief, resolveOptions, ) - : await new BriefService(ctx.serverDB, ctx.userId).resolve(input.id, resolveOptions); + : await new BriefService(ctx.serverDB, ctx.userId, wsId).resolve( + input.id, + resolveOptions, + ); if (!brief) throw new TRPCError({ code: 'NOT_FOUND', message: 'Brief not found' }); return { data: brief, message: 'Brief resolved', success: true }; diff --git a/src/server/routers/lambda/chunk.ts b/src/server/routers/lambda/chunk.ts index ac2d9ad708..68debe6533 100644 --- a/src/server/routers/lambda/chunk.ts +++ b/src/server/routers/lambda/chunk.ts @@ -3,6 +3,8 @@ import { RequestTrigger, SemanticSearchSchema } from '@lobechat/types'; import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AsyncTaskModel } from '@/database/models/asyncTask'; import { ChunkModel } from '@/database/models/chunk'; import { DocumentModel } from '@/database/models/document'; @@ -10,7 +12,7 @@ import { EmbeddingModel } from '@/database/models/embedding'; import { FileModel } from '@/database/models/file'; import { MessageModel } from '@/database/models/message'; import { SearchRepo } from '@/database/repositories/search'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { getServerDefaultFilesConfig } from '@/server/globalConfig'; import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime'; @@ -18,27 +20,29 @@ import { ChunkService } from '@/server/services/chunk'; import { DocumentService } from '@/server/services/document'; import { KnowledgeBaseSearchService } from '@/server/services/knowledgeBase'; -const chunkProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const chunkProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId), - chunkModel: new ChunkModel(ctx.serverDB, ctx.userId), - chunkService: new ChunkService(ctx.serverDB, ctx.userId), - documentModel: new DocumentModel(ctx.serverDB, ctx.userId), - documentService: new DocumentService(ctx.serverDB, ctx.userId), - embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId), - fileModel: new FileModel(ctx.serverDB, ctx.userId), - knowledgeBaseSearchService: new KnowledgeBaseSearchService(ctx.serverDB, ctx.userId), - messageModel: new MessageModel(ctx.serverDB, ctx.userId), - searchRepo: new SearchRepo(ctx.serverDB, ctx.userId), + asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId, wsId), + chunkModel: new ChunkModel(ctx.serverDB, ctx.userId, wsId), + chunkService: new ChunkService(ctx.serverDB, ctx.userId, wsId), + documentModel: new DocumentModel(ctx.serverDB, ctx.userId, wsId), + documentService: new DocumentService(ctx.serverDB, ctx.userId, wsId), + embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId, wsId), + fileModel: new FileModel(ctx.serverDB, ctx.userId, wsId), + knowledgeBaseSearchService: new KnowledgeBaseSearchService(ctx.serverDB, ctx.userId, wsId), + messageModel: new MessageModel(ctx.serverDB, ctx.userId, wsId), + searchRepo: new SearchRepo(ctx.serverDB, ctx.userId, wsId), }, }); }); export const chunkRouter = router({ createEmbeddingChunksTask: chunkProcedure + .use(withScopedPermission('knowledge_base:update')) .input( z.object({ id: z.string(), @@ -51,6 +55,7 @@ export const chunkRouter = router({ }), createParseFileTask: chunkProcedure + .use(withScopedPermission('knowledge_base:update')) .input( z.object({ id: z.string(), @@ -90,6 +95,7 @@ export const chunkRouter = router({ }), retryParseFileTask: chunkProcedure + .use(withScopedPermission('knowledge_base:update')) .input( z.object({ id: z.string(), @@ -122,7 +128,12 @@ export const chunkRouter = router({ const { model, provider } = getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM; // Read user's provider config from database - const agentRuntime = await initModelRuntimeFromDB(ctx.serverDB, ctx.userId, provider); + const agentRuntime = await initModelRuntimeFromDB( + ctx.serverDB, + ctx.userId, + provider, + ctx.workspaceId ?? undefined, + ); const embeddings = await agentRuntime.embeddings( { diff --git a/src/server/routers/lambda/comfyui.ts b/src/server/routers/lambda/comfyui.ts index 9445b82e75..bfdb845b20 100644 --- a/src/server/routers/lambda/comfyui.ts +++ b/src/server/routers/lambda/comfyui.ts @@ -1,6 +1,7 @@ import { type ComfyUIKeyVault } from '@lobechat/types'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; import { authedProcedure, router } from '@/libs/trpc/lambda'; // Import Framework layer services import { ComfyUIClientService } from '@/server/services/comfyui/core/comfyUIClientService'; @@ -26,6 +27,7 @@ export const comfyuiRouter = router({ * Create image with complete business logic */ createImage: authedProcedure + .use(withScopedPermission('file:upload')) .input( z.object({ model: z.string(), diff --git a/src/server/routers/lambda/connector.ts b/src/server/routers/lambda/connector.ts index d8e5f48901..0e65011104 100644 --- a/src/server/routers/lambda/connector.ts +++ b/src/server/routers/lambda/connector.ts @@ -1,6 +1,7 @@ import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { ConnectorModel } from '@/database/models/connector'; import { ConnectorToolModel } from '@/database/models/connectorTool'; import { PluginModel } from '@/database/models/plugin'; @@ -12,7 +13,7 @@ import { ConnectorToolPermission, } from '@/database/schemas'; import { inferCrudType } from '@/libs/mcp/utils'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import { callConnectorToolById, ConnectorToolCallError } from '@/server/services/connector/exec'; @@ -28,15 +29,16 @@ import { } from '@/server/services/connector/stateStore'; import { syncConnectorToolsById } from '@/server/services/connector/sync'; -const connectorProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const connectorProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; // Credentials (OAuth tokens) are encrypted at rest — give the model a gatekeeper. const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - connectorModel: new ConnectorModel(ctx.serverDB, ctx.userId, gateKeeper), - connectorToolModel: new ConnectorToolModel(ctx.serverDB, ctx.userId), - pluginModel: new PluginModel(ctx.serverDB, ctx.userId), + connectorModel: new ConnectorModel(ctx.serverDB, ctx.userId, wsId, gateKeeper), + connectorToolModel: new ConnectorToolModel(ctx.serverDB, ctx.userId, wsId), + pluginModel: new PluginModel(ctx.serverDB, ctx.userId, wsId), }, }); }); @@ -106,7 +108,7 @@ export const connectorRouter = router({ * The Add modal must display THIS value (not a client-derived origin) so the * URI the user registers matches the one used at authorize time. */ - getRedirectUri: authedProcedure.query(() => ({ redirectUri: getConnectorRedirectUri() })), + getRedirectUri: wsCompatProcedure.query(() => ({ redirectUri: getConnectorRedirectUri() })), // ── Mutations ───────────────────────────────────────────────────────────── diff --git a/src/server/routers/lambda/document.ts b/src/server/routers/lambda/document.ts index 9e9e266cca..67fcbf35d7 100644 --- a/src/server/routers/lambda/document.ts +++ b/src/server/routers/lambda/document.ts @@ -1,13 +1,20 @@ +import { TRPCError } from '@trpc/server'; +import { and, eq, isNull } from 'drizzle-orm'; import { z } from 'zod'; +import { businessFileTransferStorageCheck } from '@/business/server/lambda-routers/file'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { FREE_DOCUMENT_HISTORY_WINDOW_DAYS } from '@/const/documentHistory'; import { ChunkModel } from '@/database/models/chunk'; import { DocumentModel } from '@/database/models/document'; import { FileModel } from '@/database/models/file'; import { MessageModel } from '@/database/models/message'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { workspaceMembers } from '@/database/schemas'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { DocumentService } from '@/server/services/document'; +import { TransferErrorCode } from '@/types/transferError'; import { compareDocumentHistoryItemsInputSchema, @@ -23,22 +30,24 @@ const getFreeDocumentHistorySince = () => { return new Date(now - FREE_DOCUMENT_HISTORY_WINDOW_DAYS * 24 * 60 * 60 * 1000); }; -const documentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const documentProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - chunkModel: new ChunkModel(ctx.serverDB, ctx.userId), - documentModel: new DocumentModel(ctx.serverDB, ctx.userId), - documentService: new DocumentService(ctx.serverDB, ctx.userId), - fileModel: new FileModel(ctx.serverDB, ctx.userId), - messageModel: new MessageModel(ctx.serverDB, ctx.userId), + chunkModel: new ChunkModel(ctx.serverDB, ctx.userId, wsId), + documentModel: new DocumentModel(ctx.serverDB, ctx.userId, wsId), + documentService: new DocumentService(ctx.serverDB, ctx.userId, wsId), + fileModel: new FileModel(ctx.serverDB, ctx.userId, wsId), + messageModel: new MessageModel(ctx.serverDB, ctx.userId, wsId), }, }); }); export const documentRouter = router({ createDocument: documentProcedure + .use(withScopedPermission('document:create')) .input( z.object({ content: z.string().optional(), @@ -71,6 +80,7 @@ export const documentRouter = router({ }), createDocuments: documentProcedure + .use(withScopedPermission('document:create')) .input( z.object({ documents: z.array( @@ -115,12 +125,14 @@ export const documentRouter = router({ }), deleteDocument: documentProcedure + .use(withScopedPermission('document:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { return ctx.documentService.deleteDocument(input.id); }), deleteDocuments: documentProcedure + .use(withScopedPermission('document:delete')) .input(z.object({ ids: z.array(z.string()) })) .mutation(async ({ ctx, input }) => { return ctx.documentService.deleteDocuments(input.ids); @@ -163,6 +175,7 @@ export const documentRouter = router({ }), saveDocumentHistory: documentProcedure + .use(withScopedPermission('document:update')) .input(saveDocumentHistoryInputSchema) .mutation(async ({ ctx, input }) => { const editorData = JSON.parse(input.editorData); @@ -199,6 +212,7 @@ export const documentRouter = router({ }), parseDocument: documentProcedure + .use(withScopedPermission('document:update')) .input( z.object({ id: z.string(), @@ -211,6 +225,7 @@ export const documentRouter = router({ }), parseFileContent: documentProcedure + .use(withScopedPermission('document:update')) .input( z.object({ id: z.string(), @@ -239,6 +254,7 @@ export const documentRouter = router({ }), updateDocument: documentProcedure + .use(withScopedPermission('document:update')) .input(updateDocumentInputSchema) .mutation(async ({ ctx, input }) => { const { id, editorData: editorDataString, ...params } = input; @@ -251,4 +267,134 @@ export const documentRouter = router({ return result; }), + + transferDocument: documentProcedure + .use(withScopedPermission('document:update')) + .input( + z.object({ + documentId: z.string(), + targetWorkspaceId: z.string().nullable(), + }), + ) + .mutation(async ({ ctx, input }) => { + const doc = await ctx.documentModel.findById(input.documentId); + if (!doc) + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'Document not found', + }); + + // Workspace mode: only owners can transfer items created by others + if (ctx.workspaceId && doc.userId !== ctx.userId) { + const [membership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, ctx.workspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!membership || membership.role !== 'owner') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.OwnerOnly } }, + code: 'FORBIDDEN', + message: 'Only workspace owners can transfer items created by others', + }); + } + } + + if (input.targetWorkspaceId === (ctx.workspaceId ?? null)) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.SameWorkspace } }, + code: 'BAD_REQUEST', + message: 'Cannot transfer document to the same workspace', + }); + } + + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + const additionalSize = await ctx.documentModel.countFileUsageInSubtree(input.documentId); + await businessFileTransferStorageCheck({ + additionalSize, + targetUserId: ctx.userId, + targetWorkspaceId: input.targetWorkspaceId, + }); + + return ctx.documentModel.transferTo(input.documentId, input.targetWorkspaceId, ctx.userId); + }), + + copyDocumentToWorkspace: documentProcedure + .use(withScopedPermission('document:create')) + .input( + z.object({ + documentId: z.string(), + targetWorkspaceId: z.string().nullable(), + }), + ) + .mutation(async ({ ctx, input }) => { + const doc = await ctx.documentModel.findById(input.documentId); + if (!doc) + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'Document not found', + }); + + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + const additionalSize = await ctx.documentModel.countFileUsageInSubtree(input.documentId); + await businessFileTransferStorageCheck({ + additionalSize, + targetUserId: ctx.userId, + targetWorkspaceId: input.targetWorkspaceId, + }); + + return ctx.documentModel.copyToWorkspace( + input.documentId, + input.targetWorkspaceId, + ctx.userId, + ); + }), }); diff --git a/src/server/routers/lambda/exporter.ts b/src/server/routers/lambda/exporter.ts index 64a44dfc89..7c151205dd 100644 --- a/src/server/routers/lambda/exporter.ts +++ b/src/server/routers/lambda/exporter.ts @@ -2,25 +2,29 @@ import { marked } from 'marked'; import PDFDocument from 'pdfkit'; import { z } from 'zod'; +import { withRbacPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { DrizzleMigrationModel } from '@/database/models/drizzleMigration'; import { MessageModel } from '@/database/models/message'; import { SessionModel } from '@/database/models/session'; import { DataExporterRepos } from '@/database/repositories/dataExporter'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { type ExportDatabaseData } from '@/types/export'; -const exportProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const exportProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; - const dataExporterRepos = new DataExporterRepos(ctx.serverDB, ctx.userId); + const wsId = ctx.workspaceId ?? undefined; + const dataExporterRepos = new DataExporterRepos(ctx.serverDB, ctx.userId, wsId); const drizzleMigration = new DrizzleMigrationModel(ctx.serverDB); - const messageModel = new MessageModel(ctx.serverDB, ctx.userId); - const sessionModel = new SessionModel(ctx.serverDB, ctx.userId); + const messageModel = new MessageModel(ctx.serverDB, ctx.userId, wsId); + const sessionModel = new SessionModel(ctx.serverDB, ctx.userId, wsId); return opts.next({ ctx: { dataExporterRepos, drizzleMigration, messageModel, sessionModel }, }); }); +const workspaceExportProcedure = exportProcedure.use(withRbacPermission('workspace:update:all')); const REGULAR_FONT_URL = 'https://cdn.jsdelivr.net/gh/adobe-fonts/source-han-sans@2.004R/OTF/SimplifiedChinese/SourceHanSansSC-Regular.otf'; @@ -168,13 +172,13 @@ const generatePdfFromMarkdown = async ( }; export const exporterRouter = router({ - exportData: exportProcedure.mutation(async ({ ctx }): Promise => { + exportData: workspaceExportProcedure.mutation(async ({ ctx }): Promise => { const data = await ctx.dataExporterRepos.export(5); const schemaHash = await ctx.drizzleMigration.getLatestMigrationHash(); return { data, schemaHash }; }), - exportPdf: exportProcedure + exportPdf: workspaceExportProcedure .input( z.object({ content: z.string(), diff --git a/src/server/routers/lambda/file.ts b/src/server/routers/lambda/file.ts index 53153333d1..a95e97c7b9 100644 --- a/src/server/routers/lambda/file.ts +++ b/src/server/routers/lambda/file.ts @@ -4,23 +4,39 @@ import { DERIVED_DOCUMENT_SOURCE_TYPE, } from '@lobechat/const'; import { TRPCError } from '@trpc/server'; +import { and, eq, isNull } from 'drizzle-orm'; import { z } from 'zod'; -import { businessFileUploadCheck } from '@/business/server/lambda-routers/file'; +import { + businessFileTransferStorageCheck, + businessFileUploadCheck, +} from '@/business/server/lambda-routers/file'; import { checkFileStorageUsage } from '@/business/server/trpc-middlewares/lambda'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { serverDBEnv } from '@/config/db'; import { AsyncTaskModel } from '@/database/models/asyncTask'; import { ChunkModel } from '@/database/models/chunk'; import { DocumentModel } from '@/database/models/document'; import { FileModel } from '@/database/models/file'; import { KnowledgeRepo } from '@/database/repositories/knowledge'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { workspaceMembers } from '@/database/schemas'; +import { appEnv } from '@/envs/app'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { DocumentService } from '@/server/services/document'; import { FileService } from '@/server/services/file'; import { AsyncTaskStatus, AsyncTaskType, type IAsyncTaskError } from '@/types/asyncTask'; import type { FileListItem, KnowledgeItemStatus } from '@/types/files'; import { QueryFileListSchema, UploadFileSchema } from '@/types/files'; +import { TransferErrorCode } from '@/types/transferError'; + +/** + * Generate file proxy URL + * Returns a unified proxy URL format: ${APP_URL}/f/:id + */ +const getFileProxyUrl = (fileId: string): string => `${appEnv.APP_URL}/f/${fileId}`; +const fileTransferEntityTypeSchema = z.enum(['document', 'file', 'folder']); const filterKnowledgeItems = < T extends { @@ -119,24 +135,26 @@ const isStoredObjectAvailable = async (fileService: FileService, url: string): P } }; -const fileProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const fileProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId), - chunkModel: new ChunkModel(ctx.serverDB, ctx.userId), - documentModel: new DocumentModel(ctx.serverDB, ctx.userId), - documentService: new DocumentService(ctx.serverDB, ctx.userId), - fileModel: new FileModel(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), - knowledgeRepo: new KnowledgeRepo(ctx.serverDB, ctx.userId), + asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId, wsId), + chunkModel: new ChunkModel(ctx.serverDB, ctx.userId, wsId), + documentModel: new DocumentModel(ctx.serverDB, ctx.userId, wsId), + documentService: new DocumentService(ctx.serverDB, ctx.userId, wsId), + fileModel: new FileModel(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), + knowledgeRepo: new KnowledgeRepo(ctx.serverDB, ctx.userId, wsId), }, }); }); export const fileRouter = router({ checkFileHash: fileProcedure + .use(withScopedPermission('file:upload')) .use(checkFileStorageUsage) .input(z.object({ hash: z.string() })) .mutation(async ({ ctx, input }) => { @@ -150,6 +168,7 @@ export const fileRouter = router({ }), createFile: fileProcedure + .use(withScopedPermission('file:upload')) .use(checkFileStorageUsage) .input( UploadFileSchema.omit({ url: true }).extend({ @@ -187,6 +206,7 @@ export const fileRouter = router({ inputSize: input.size, url: input.url, userId: ctx.userId, + workspaceId: ctx.workspaceId, }); throw new TRPCError({ code: 'BAD_REQUEST', message: 'File size cannot be negative' }); } @@ -199,6 +219,7 @@ export const fileRouter = router({ transaction: trx, url: input.url, userId: ctx.userId, + workspaceId: ctx.workspaceId, }); let shouldRefreshGlobalFile = false; @@ -424,6 +445,7 @@ export const fileRouter = router({ }), deleteKnowledgeItemsByQuery: fileProcedure + .use(withScopedPermission('file:delete')) .input(QueryFileListSchema) .mutation(async ({ ctx, input }): Promise<{ count: number }> => { const fileIds: string[] = []; @@ -554,33 +576,39 @@ export const fileRouter = router({ .slice(0, limit); }), - removeAllFiles: fileProcedure.mutation(async ({ ctx }) => { - // Get all file IDs for this user - const allFiles = await ctx.fileModel.query({ showFilesInKnowledgeBase: true }); - const fileIds = allFiles.map((f) => f.id); + removeAllFiles: fileProcedure + .use(withScopedPermission('file:delete')) + .mutation(async ({ ctx }) => { + // Get all file IDs for this user + const allFiles = await ctx.fileModel.query({ showFilesInKnowledgeBase: true }); + const fileIds = allFiles.map((f) => f.id); - // Use deleteMany to properly handle shared files (globalFiles reference counting) - const needToRemoveFileList = await ctx.fileModel.deleteMany( - fileIds, - serverDBEnv.REMOVE_GLOBAL_FILE, - ); + // Use deleteMany to properly handle shared files (globalFiles reference counting) + const needToRemoveFileList = await ctx.fileModel.deleteMany( + fileIds, + serverDBEnv.REMOVE_GLOBAL_FILE, + ); - // Delete S3 files only if no other users reference them - if (needToRemoveFileList && needToRemoveFileList.length > 0) { - await ctx.fileService.deleteFiles(needToRemoveFileList.map((file) => file.url!)); - } - }), + // Delete S3 files only if no other users reference them + if (needToRemoveFileList && needToRemoveFileList.length > 0) { + await ctx.fileService.deleteFiles(needToRemoveFileList.map((file) => file.url!)); + } + }), - removeFile: fileProcedure.input(z.object({ id: z.string() })).mutation(async ({ input, ctx }) => { - const file = await ctx.fileModel.delete(input.id, serverDBEnv.REMOVE_GLOBAL_FILE); + removeFile: fileProcedure + .use(withScopedPermission('file:delete')) + .input(z.object({ id: z.string() })) + .mutation(async ({ input, ctx }) => { + const file = await ctx.fileModel.delete(input.id, serverDBEnv.REMOVE_GLOBAL_FILE); - if (!file) return; + if (!file) return; - // delete the file from S3 if it is not used by other files - await ctx.fileService.deleteFile(file.url!); - }), + // delete the file from S3 if it is not used by other files + await ctx.fileService.deleteFile(file.url!); + }), removeFileAsyncTask: fileProcedure + .use(withScopedPermission('file:update')) .input( z.object({ id: z.string(), @@ -600,6 +628,7 @@ export const fileRouter = router({ }), removeFiles: fileProcedure + .use(withScopedPermission('file:delete')) .input(z.object({ ids: z.array(z.string()) })) .mutation(async ({ input, ctx }) => { const needToRemoveFileList = await ctx.fileModel.deleteMany( @@ -614,6 +643,7 @@ export const fileRouter = router({ }), updateFile: fileProcedure + .use(withScopedPermission('file:update')) .input( z.object({ id: z.string(), @@ -654,6 +684,142 @@ export const fileRouter = router({ return { success: true }; }), + + transferEntity: fileProcedure + .use(withScopedPermission('file:upload')) + .input( + z.object({ + entityType: fileTransferEntityTypeSchema, + id: z.string(), + targetWorkspaceId: z.string().nullable(), + }), + ) + .mutation(async ({ ctx, input }) => { + if (input.targetWorkspaceId === (ctx.workspaceId ?? null)) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.SameWorkspace } }, + code: 'BAD_REQUEST', + message: 'Cannot transfer to the same workspace', + }); + } + + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + if (input.entityType === 'folder' || input.entityType === 'document') { + const document = await ctx.documentModel.findById(input.id); + if (!document) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: input.entityType === 'folder' ? 'Folder not found' : 'Document not found', + }); + } + const additionalSize = await ctx.documentModel.countFileUsageInSubtree(input.id); + await businessFileTransferStorageCheck({ + additionalSize, + targetUserId: ctx.userId, + targetWorkspaceId: input.targetWorkspaceId, + }); + return ctx.documentModel.transferTo(input.id, input.targetWorkspaceId, ctx.userId); + } + + const file = await ctx.fileModel.findById(input.id); + if (!file) + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'File not found', + }); + await businessFileTransferStorageCheck({ + additionalSize: file.size, + targetUserId: ctx.userId, + targetWorkspaceId: input.targetWorkspaceId, + }); + return ctx.fileModel.transferTo(input.id, input.targetWorkspaceId, ctx.userId); + }), + + copyEntityToWorkspace: fileProcedure + .use(withScopedPermission('file:upload')) + .input( + z.object({ + entityType: fileTransferEntityTypeSchema, + id: z.string(), + targetWorkspaceId: z.string().nullable(), + }), + ) + .mutation(async ({ ctx, input }) => { + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + if (input.entityType === 'folder' || input.entityType === 'document') { + const document = await ctx.documentModel.findById(input.id); + if (!document) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: input.entityType === 'folder' ? 'Folder not found' : 'Document not found', + }); + } + const additionalSize = await ctx.documentModel.countFileUsageInSubtree(input.id); + await businessFileTransferStorageCheck({ + additionalSize, + targetUserId: ctx.userId, + targetWorkspaceId: input.targetWorkspaceId, + }); + return ctx.documentModel.copyToWorkspace(input.id, input.targetWorkspaceId, ctx.userId); + } + + const file = await ctx.fileModel.findById(input.id); + if (!file) + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'File not found', + }); + await businessFileTransferStorageCheck({ + additionalSize: file.size, + targetUserId: ctx.userId, + targetWorkspaceId: input.targetWorkspaceId, + }); + return ctx.fileModel.copyToWorkspace(input.id, input.targetWorkspaceId, ctx.userId); + }), }); export type FileRouter = typeof fileRouter; diff --git a/src/server/routers/lambda/followUpAction.ts b/src/server/routers/lambda/followUpAction.ts index 057b6006c1..8705509285 100644 --- a/src/server/routers/lambda/followUpAction.ts +++ b/src/server/routers/lambda/followUpAction.ts @@ -1,20 +1,24 @@ import { FollowUpExtractInputSchema } from '@lobechat/types'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FollowUpActionService } from '@/server/services/followUpAction'; -const followUpProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const followUpProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - followUpService: new FollowUpActionService(ctx.serverDB, ctx.userId), + followUpService: new FollowUpActionService(ctx.serverDB, ctx.userId, wsId), }, }); }); +const followUpWriteProcedure = followUpProcedure.use(withScopedPermission('message:create')); export const followUpActionRouter = router({ - extract: followUpProcedure + extract: followUpWriteProcedure .input(FollowUpExtractInputSchema) .mutation(async ({ input, ctx }) => ctx.followUpService.extract(input)), }); diff --git a/src/server/routers/lambda/generation.ts b/src/server/routers/lambda/generation.ts index d6b2fd7171..26959631a2 100644 --- a/src/server/routers/lambda/generation.ts +++ b/src/server/routers/lambda/generation.ts @@ -1,23 +1,26 @@ import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AsyncTaskModel } from '@/database/models/asyncTask'; import { GenerationModel } from '@/database/models/generation'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { type AsyncTaskError } from '@/types/asyncTask'; import { AsyncTaskStatus } from '@/types/asyncTask'; import { type Generation } from '@/types/generation'; -const generationProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const generationProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), - generationModel: new GenerationModel(ctx.serverDB, ctx.userId), + asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), + generationModel: new GenerationModel(ctx.serverDB, ctx.userId, wsId), }, }); }); @@ -30,6 +33,7 @@ export type GetGenerationStatusResult = { export const generationRouter = router({ deleteGeneration: generationProcedure + .use(withScopedPermission('file:delete')) .input(z.object({ generationId: z.string() })) .mutation(async ({ ctx, input }) => { // Delete the generation record from database and get the deleted data diff --git a/src/server/routers/lambda/generationBatch.ts b/src/server/routers/lambda/generationBatch.ts index f2d8fa21f0..15bbf65b26 100644 --- a/src/server/routers/lambda/generationBatch.ts +++ b/src/server/routers/lambda/generationBatch.ts @@ -1,24 +1,28 @@ import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { GenerationBatchModel } from '@/database/models/generationBatch'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { getVideoAvgLatency } from '@/server/services/generation/latency'; -const generationBatchProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const generationBatchProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - fileService: new FileService(ctx.serverDB, ctx.userId), - generationBatchModel: new GenerationBatchModel(ctx.serverDB, ctx.userId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), + generationBatchModel: new GenerationBatchModel(ctx.serverDB, ctx.userId, wsId), }, }); }); export const generationBatchRouter = router({ deleteGenerationBatch: generationBatchProcedure + .use(withScopedPermission('file:delete')) .input(z.object({ batchId: z.string() })) .mutation(async ({ ctx, input }) => { // 1. Delete database records and get thumbnail URLs to clean diff --git a/src/server/routers/lambda/generationTopic.ts b/src/server/routers/lambda/generationTopic.ts index bd81079ed5..acf5de8a50 100644 --- a/src/server/routers/lambda/generationTopic.ts +++ b/src/server/routers/lambda/generationTopic.ts @@ -1,20 +1,23 @@ import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { GenerationTopicModel } from '@/database/models/generationTopic'; import { type GenerationTopicItem } from '@/database/schemas/generation'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { GenerationService } from '@/server/services/generation'; -const generationTopicProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const generationTopicProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - fileService: new FileService(ctx.serverDB, ctx.userId), - generationService: new GenerationService(ctx.serverDB, ctx.userId), - generationTopicModel: new GenerationTopicModel(ctx.serverDB, ctx.userId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), + generationService: new GenerationService(ctx.serverDB, ctx.userId, wsId), + generationTopicModel: new GenerationTopicModel(ctx.serverDB, ctx.userId, wsId), }, }); }); @@ -35,12 +38,14 @@ const updateTopicCoverSchema = z.object({ export const generationTopicRouter = router({ createTopic: generationTopicProcedure + .use(withScopedPermission('topic:create')) .input(z.object({ type: z.enum(['image', 'video']).optional() }).optional()) .mutation(async ({ ctx, input }) => { const data = await ctx.generationTopicModel.create('', input?.type); return data.id; }), deleteTopic: generationTopicProcedure + .use(withScopedPermission('topic:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { // 1. Delete database records and get file URLs to clean @@ -74,11 +79,13 @@ export const generationTopicRouter = router({ return ctx.generationTopicModel.queryAll(input?.type); }), updateTopic: generationTopicProcedure + .use(withScopedPermission('topic:update')) .input(updateTopicSchema) .mutation(async ({ ctx, input }) => { return ctx.generationTopicModel.update(input.id, input.value as Partial); }), updateTopicCover: generationTopicProcedure + .use(withScopedPermission('topic:update')) .input(updateTopicCoverSchema) .mutation(async ({ ctx, input }) => { // Process the cover image and get key diff --git a/src/server/routers/lambda/home.ts b/src/server/routers/lambda/home.ts index 1f23bc2a86..f27126be51 100644 --- a/src/server/routers/lambda/home.ts +++ b/src/server/routers/lambda/home.ts @@ -1,21 +1,24 @@ import { after } from 'next/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentModel } from '@/database/models/agent'; import { AgentMigrationRepo } from '@/database/repositories/agentMigration'; import { HomeRepository } from '@/database/repositories/home'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { type HomeBriefData, HomeService } from '@/server/services/home'; -const homeProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const homeProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const workspaceId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentMigrationRepo: new AgentMigrationRepo(ctx.serverDB, ctx.userId), - agentModel: new AgentModel(ctx.serverDB, ctx.userId), - homeRepository: new HomeRepository(ctx.serverDB, ctx.userId), + agentMigrationRepo: new AgentMigrationRepo(ctx.serverDB, ctx.userId, workspaceId), + agentModel: new AgentModel(ctx.serverDB, ctx.userId, workspaceId), + homeRepository: new HomeRepository(ctx.serverDB, ctx.userId, workspaceId), homeService: new HomeService(ctx.userId), }, }); @@ -51,6 +54,7 @@ export const homeRouter = router({ }), updateAgentSessionGroupId: homeProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ agentId: z.string(), diff --git a/src/server/routers/lambda/image/index.test.ts b/src/server/routers/lambda/image/index.test.ts index d763563bfc..1f9e2e965f 100644 --- a/src/server/routers/lambda/image/index.test.ts +++ b/src/server/routers/lambda/image/index.test.ts @@ -12,6 +12,7 @@ const { mockAsyncTaskModelUpdate, mockChargeBeforeGenerate, mockCreateAsyncCaller, + mockInsertValues, mockLoadModels, mockResolveBusinessModelMapping, } = vi.hoisted(() => ({ @@ -23,6 +24,7 @@ const { mockAsyncTaskModelUpdate: vi.fn(), mockChargeBeforeGenerate: vi.fn(), mockCreateAsyncCaller: vi.fn(), + mockInsertValues: [] as unknown[], mockLoadModels: vi.fn(), mockResolveBusinessModelMapping: vi.fn(), })); @@ -114,6 +116,7 @@ describe('imageRouter', () => { beforeEach(() => { vi.clearAllMocks(); + mockInsertValues.length = 0; // Default mock implementations mockResolveBusinessModelMapping.mockImplementation( @@ -159,15 +162,18 @@ describe('imageRouter', () => { insertCallCount = 0; const tx = { insert: vi.fn().mockReturnValue({ - values: vi.fn().mockReturnValue({ - returning: vi.fn().mockImplementation(() => { - insertCallCount++; - if (insertCallCount === 1) return [mockBatch]; - if (insertCallCount === 2) return mockGenerations; - // For async tasks, return one at a time - const taskIndex = insertCallCount - 3; - return [mockAsyncTasks[taskIndex] || mockAsyncTasks[0]]; - }), + values: vi.fn((value) => { + mockInsertValues.push(value); + return { + returning: vi.fn().mockImplementation(() => { + insertCallCount++; + if (insertCallCount === 1) return [mockBatch]; + if (insertCallCount === 2) return mockGenerations; + // For async tasks, return one at a time + const taskIndex = insertCallCount - 3; + return [mockAsyncTasks[taskIndex] || mockAsyncTasks[0]]; + }), + }; }), }), update: vi.fn().mockReturnValue({ @@ -376,6 +382,24 @@ describe('imageRouter', () => { expect(mockCreateAsyncCaller).toHaveBeenCalledWith({ userId: mockUserId }); }); + it('should persist and forward workspaceId for background image tasks', async () => { + const ctx = createMockCtx({ workspaceId: 'workspace-1' }); + const input = createDefaultInput(); + + const caller = imageRouter.createCaller(ctx); + await caller.createImage(input); + + expect(mockInsertValues[0]).toEqual(expect.objectContaining({ workspaceId: 'workspace-1' })); + expect(mockInsertValues[1]).toEqual( + expect.arrayContaining([expect.objectContaining({ workspaceId: 'workspace-1' })]), + ); + expect(mockInsertValues[2]).toEqual(expect.objectContaining({ workspaceId: 'workspace-1' })); + expect(mockInsertValues[3]).toEqual(expect.objectContaining({ workspaceId: 'workspace-1' })); + expect(mockAsyncCallerCreateImage).toHaveBeenCalledWith( + expect.objectContaining({ workspaceId: 'workspace-1' }), + ); + }); + it('should handle async caller creation failure', async () => { mockCreateAsyncCaller.mockRejectedValue(new Error('Caller creation failed')); diff --git a/src/server/routers/lambda/image/index.ts b/src/server/routers/lambda/image/index.ts index f0f3f22cff..e9e215ed00 100644 --- a/src/server/routers/lambda/image/index.ts +++ b/src/server/routers/lambda/image/index.ts @@ -9,10 +9,12 @@ import { isProviderModelAvailable } from 'model-bank'; import { z } from 'zod'; import { chargeBeforeGenerate } from '@/business/server/image-generation/chargeBeforeGenerate'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AsyncTaskModel } from '@/database/models/asyncTask'; import { type NewGeneration, type NewGenerationBatch } from '@/database/schemas'; import { asyncTasks, generationBatches, generations } from '@/database/schemas'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { createAsyncCaller } from '@/server/routers/async/caller'; import { FileService } from '@/server/services/file'; @@ -28,17 +30,20 @@ import { validateNoUrlsInConfig } from './utils'; const log = debug('lobe-image:lambda'); -const imageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const imageProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), + asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), }, }); }); +const imageCreateProcedure = imageProcedure.use(withScopedPermission('file:upload')); + const createImageInputSchema = z.object({ generationTopicId: z.string(), imageNum: z.number(), @@ -59,265 +64,278 @@ const createImageInputSchema = z.object({ export type CreateImageServicePayload = z.infer; export const imageRouter = router({ - createImage: imageProcedure.input(createImageInputSchema).mutation(async ({ input, ctx }) => { - const { userId, serverDB, asyncTaskModel, fileService } = ctx; - const { generationTopicId, provider, model, imageNum, params } = input; + createImage: imageCreateProcedure + .input(createImageInputSchema) + .mutation(async ({ input, ctx }) => { + const { userId, serverDB, asyncTaskModel, fileService } = ctx; + const wsId = ctx.workspaceId ?? undefined; + const { generationTopicId, provider, model, imageNum, params } = input; - log('Starting image creation process, input: %O', input); + log('Starting image creation process, input: %O', input); - const { resolvedModelId } = await resolveBusinessModelMapping(provider, model); + const { resolvedModelId } = await resolveBusinessModelMapping(provider, model); - // Reject lobehub model ids that are no longer in the model bank so callers get a - // clear error instead of an opaque downstream failure when the underlying channel - // can't serve the requested id. - if ( - provider === BRANDING_PROVIDER && - !isProviderModelAvailable(await loadModels(), BRANDING_PROVIDER, resolvedModelId, 'image') - ) { - throw new TRPCError({ - cause: { data: { modelType: 'image', requestedModel: model } }, - code: 'BAD_REQUEST', - message: ChatErrorType.LobeHubModelDeprecated, - }); - } - - // Normalize reference image addresses, store S3 keys uniformly (avoid storing expiring presigned URLs in database) - let configForDatabase = { ...params }; - // 1) Process multiple images in imageUrls - if (Array.isArray(params.imageUrls) && params.imageUrls.length > 0) { - log('Converting imageUrls to S3 keys for database storage: %O', params.imageUrls); - try { - const imageKeysWithNull = await Promise.all( - params.imageUrls.map(async (url) => { - const key = await fileService.getKeyFromFullUrl(url); - if (key) { - log('Converted URL %s to key %s', url, key); - } else { - log('Failed to extract key from URL: %s', url); - } - return key; - }), - ); - const imageKeys = imageKeysWithNull.filter((key): key is string => key !== null); - - configForDatabase = { - ...configForDatabase, - imageUrls: imageKeys, - }; - log('Successfully converted imageUrls to keys for database: %O', imageKeys); - } catch (error) { - console.error('Error converting imageUrls to keys: %O', error); - console.error('Keeping original imageUrls due to conversion error'); - } - } - // 2) Process single image in imageUrl - if (typeof params.imageUrl === 'string' && params.imageUrl) { - try { - const key = await fileService.getKeyFromFullUrl(params.imageUrl); - if (key) { - log('Converted single imageUrl to key: %s -> %s', params.imageUrl, key); - configForDatabase = { ...configForDatabase, imageUrl: key }; - } else { - log('Failed to extract key from single imageUrl: %s', params.imageUrl); - } - } catch (error) { - console.error('Error converting imageUrl to key: %O', error); - // Keep original value if conversion fails - } - } - - // In development, convert localhost proxy URLs to S3 URLs for async task access - let generationParams = params; - if (process.env.NODE_ENV === 'development') { - const updates: Record = {}; - - // Handle single imageUrl: localhost/f/{id} -> S3 URL - if (typeof params.imageUrl === 'string' && params.imageUrl) { - const s3Url = await fileService.getFullFileUrl(configForDatabase.imageUrl as string); - if (s3Url) { - log('Dev: converted proxy URL to S3 URL: %s -> %s', params.imageUrl, s3Url); - updates.imageUrl = s3Url; - } + // Reject lobehub model ids that are no longer in the model bank so callers get a + // clear error instead of an opaque downstream failure when the underlying channel + // can't serve the requested id. + if ( + provider === BRANDING_PROVIDER && + !isProviderModelAvailable(await loadModels(), BRANDING_PROVIDER, resolvedModelId, 'image') + ) { + throw new TRPCError({ + cause: { data: { modelType: 'image', requestedModel: model } }, + code: 'BAD_REQUEST', + message: ChatErrorType.LobeHubModelDeprecated, + }); } - // Handle multiple imageUrls + // Normalize reference image addresses, store S3 keys uniformly (avoid storing expiring presigned URLs in database) + let configForDatabase = { ...params }; + // 1) Process multiple images in imageUrls if (Array.isArray(params.imageUrls) && params.imageUrls.length > 0) { - const s3Urls = await Promise.all( - (configForDatabase.imageUrls as string[]).map((key) => fileService.getFullFileUrl(key)), - ); - log('Dev: converted proxy URLs to S3 URLs: %O', s3Urls); - updates.imageUrls = s3Urls; + log('Converting imageUrls to S3 keys for database storage: %O', params.imageUrls); + try { + const imageKeysWithNull = await Promise.all( + params.imageUrls.map(async (url) => { + const key = await fileService.getKeyFromFullUrl(url); + if (key) { + log('Converted URL %s to key %s', url, key); + } else { + log('Failed to extract key from URL: %s', url); + } + return key; + }), + ); + const imageKeys = imageKeysWithNull.filter((key): key is string => key !== null); + + configForDatabase = { + ...configForDatabase, + imageUrls: imageKeys, + }; + log('Successfully converted imageUrls to keys for database: %O', imageKeys); + } catch (error) { + console.error('Error converting imageUrls to keys: %O', error); + console.error('Keeping original imageUrls due to conversion error'); + } + } + // 2) Process single image in imageUrl + if (typeof params.imageUrl === 'string' && params.imageUrl) { + try { + const key = await fileService.getKeyFromFullUrl(params.imageUrl); + if (key) { + log('Converted single imageUrl to key: %s -> %s', params.imageUrl, key); + configForDatabase = { ...configForDatabase, imageUrl: key }; + } else { + log('Failed to extract key from single imageUrl: %s', params.imageUrl); + } + } catch (error) { + console.error('Error converting imageUrl to key: %O', error); + // Keep original value if conversion fails + } } - if (Object.keys(updates).length > 0) { - generationParams = { ...params, ...updates }; + // In development, convert localhost proxy URLs to S3 URLs for async task access + let generationParams = params; + if (process.env.NODE_ENV === 'development') { + const updates: Record = {}; + + // Handle single imageUrl: localhost/f/{id} -> S3 URL + if (typeof params.imageUrl === 'string' && params.imageUrl) { + const s3Url = await fileService.getFullFileUrl(configForDatabase.imageUrl as string); + if (s3Url) { + log('Dev: converted proxy URL to S3 URL: %s -> %s', params.imageUrl, s3Url); + updates.imageUrl = s3Url; + } + } + + // Handle multiple imageUrls + if (Array.isArray(params.imageUrls) && params.imageUrls.length > 0) { + const s3Urls = await Promise.all( + (configForDatabase.imageUrls as string[]).map((key) => fileService.getFullFileUrl(key)), + ); + log('Dev: converted proxy URLs to S3 URLs: %O', s3Urls); + updates.imageUrls = s3Urls; + } + + if (Object.keys(updates).length > 0) { + generationParams = { ...params, ...updates }; + } } - } - // Defensive check: ensure no full URLs enter the database - validateNoUrlsInConfig(configForDatabase, 'configForDatabase'); + // Defensive check: ensure no full URLs enter the database + validateNoUrlsInConfig(configForDatabase, 'configForDatabase'); - const chargeResult = await chargeBeforeGenerate({ - clientIp: ctx.clientIp, - configForDatabase, - generationParams, - generationTopicId, - imageNum, - model, - provider, - userId, - }); - if (chargeResult) { - return chargeResult; - } - - // Step 1: Atomically create all database records in a transaction - const { batch: createdBatch, generationsWithTasks } = await serverDB.transaction(async (tx) => { - log('Starting database transaction for image generation'); - - // 1. Create generationBatch - const newBatch: NewGenerationBatch = { - config: configForDatabase, + const chargeResult = await chargeBeforeGenerate({ + clientIp: ctx.clientIp, + configForDatabase, + generationParams, generationTopicId, - height: params.height, + imageNum, model, - prompt: params.prompt, provider, userId, - width: params.width, // Use converted config for database storage - }; - log('Creating generation batch: %O', newBatch); - const [batch] = await tx.insert(generationBatches).values(newBatch).returning(); - log('Generation batch created successfully: %s', batch.id); - - // 2. Create generations - const seeds = - 'seed' in params - ? generateUniqueSeeds(imageNum) - : Array.from({ length: imageNum }, () => null); - const newGenerations: NewGeneration[] = Array.from({ length: imageNum }, (_, index) => { - return { - generationBatchId: batch.id, - seed: seeds[index], - userId, - }; + workspaceId: wsId, }); + if (chargeResult) { + return chargeResult; + } - log('Creating %d generations for batch: %s', newGenerations.length, batch.id); - const createdGenerations = await tx.insert(generations).values(newGenerations).returning(); - log( - 'Generations created successfully: %O', - createdGenerations.map((g) => g.id), - ); + // Step 1: Atomically create all database records in a transaction + const { batch: createdBatch, generationsWithTasks } = await serverDB.transaction( + async (tx) => { + log('Starting database transaction for image generation'); - // 3. Concurrently create asyncTask for each generation (within transaction) - log('Creating async tasks for generations'); - const generationsWithTasks = await Promise.all( - createdGenerations.map(async (generation) => { - // Create asyncTask directly in transaction - const [createdAsyncTask] = await tx - .insert(asyncTasks) - .values({ - status: AsyncTaskStatus.Pending, - type: AsyncTaskType.ImageGeneration, + // 1. Create generationBatch + const newBatch: NewGenerationBatch = { + config: configForDatabase, + generationTopicId, + height: params.height, + model, + prompt: params.prompt, + provider, + userId, + workspaceId: wsId, + width: params.width, // Use converted config for database storage + }; + log('Creating generation batch: %O', newBatch); + const [batch] = await tx.insert(generationBatches).values(newBatch).returning(); + log('Generation batch created successfully: %s', batch.id); + + // 2. Create generations + const seeds = + 'seed' in params + ? generateUniqueSeeds(imageNum) + : Array.from({ length: imageNum }, () => null); + const newGenerations: NewGeneration[] = Array.from({ length: imageNum }, (_, index) => { + return { + generationBatchId: batch.id, + seed: seeds[index], userId, - }) + workspaceId: wsId, + }; + }); + + log('Creating %d generations for batch: %s', newGenerations.length, batch.id); + const createdGenerations = await tx + .insert(generations) + .values(newGenerations) .returning(); + log( + 'Generations created successfully: %O', + createdGenerations.map((g) => g.id), + ); - const asyncTaskId = createdAsyncTask.id; - log('Created async task %s for generation %s', asyncTaskId, generation.id); + // 3. Concurrently create asyncTask for each generation (within transaction) + log('Creating async tasks for generations'); + const generationsWithTasks = await Promise.all( + createdGenerations.map(async (generation) => { + // Create asyncTask directly in transaction + const [createdAsyncTask] = await tx + .insert(asyncTasks) + .values({ + status: AsyncTaskStatus.Pending, + type: AsyncTaskType.ImageGeneration, + userId, + workspaceId: wsId, + }) + .returning(); - // Update generation's asyncTaskId - await tx - .update(generations) - .set({ asyncTaskId }) - .where(and(eq(generations.id, generation.id), eq(generations.userId, userId))); + const asyncTaskId = createdAsyncTask.id; + log('Created async task %s for generation %s', asyncTaskId, generation.id); - return { asyncTaskId, generation }; - }), + // Update generation's asyncTaskId + await tx + .update(generations) + .set({ asyncTaskId }) + .where(and(eq(generations.id, generation.id), eq(generations.userId, userId))); + + return { asyncTaskId, generation }; + }), + ); + log('All async tasks created in transaction'); + + return { + batch, + generationsWithTasks, + }; + }, ); - log('All async tasks created in transaction'); + + log('Database transaction completed successfully. Starting async task triggers directly.'); + + // Step 2: Trigger background image generation tasks using after() API + log('Starting async image generation tasks with after()'); + + try { + log('Creating unified async caller for userId: %s', userId); + + // Async router will read keyVaults from DB, no need to pass jwtPayload + const asyncCaller = await createAsyncCaller({ + userId: ctx.userId, + }); + + log('Unified async caller created successfully for userId: %s', ctx.userId); + log('Processing %d async image generation tasks', generationsWithTasks.length); + + // Fire-and-forget: trigger async tasks without awaiting + // These calls go to the async router which handles them independently + // Do NOT use after() here as it would keep the lambda alive unnecessarily + generationsWithTasks.forEach(({ generation, asyncTaskId }) => { + log('Starting background async task %s for generation %s', asyncTaskId, generation.id); + + asyncCaller.image.createImage({ + generationBatchId: createdBatch.id, + generationId: generation.id, + generationTopicId, + model, + params: generationParams, + provider, + taskId: asyncTaskId, + workspaceId: wsId, + }); + }); + + log('All %d background async image generation tasks started', generationsWithTasks.length); + } catch (e) { + console.error('Failed to process async tasks:', e); + console.error('Failed to process async tasks: %O', e); + + // If overall failure occurs, update all task statuses to failed + try { + await Promise.allSettled( + generationsWithTasks.map(({ asyncTaskId }) => + asyncTaskModel.update(asyncTaskId, { + error: new AsyncTaskError( + AsyncTaskErrorType.ServerError, + 'start async task error: ' + (e instanceof Error ? e.message : 'Unknown error'), + ), + status: AsyncTaskStatus.Error, + }), + ), + ); + } catch (batchUpdateError) { + console.error('Failed to update batch task statuses:', batchUpdateError); + } + } + + const createdGenerations = generationsWithTasks.map((item) => ({ + ...item.generation, + asyncTaskId: item.asyncTaskId, + })); + log('Image creation process completed successfully: %O', { + batchId: createdBatch.id, + generationCount: createdGenerations.length, + generationIds: createdGenerations.map((g) => g.id), + }); return { - batch, - generationsWithTasks, + data: { + batch: createdBatch, + generations: createdGenerations, + }, + success: true, }; - }); - - log('Database transaction completed successfully. Starting async task triggers directly.'); - - // Step 2: Trigger background image generation tasks using after() API - log('Starting async image generation tasks with after()'); - - try { - log('Creating unified async caller for userId: %s', userId); - - // Async router will read keyVaults from DB, no need to pass jwtPayload - const asyncCaller = await createAsyncCaller({ - userId: ctx.userId, - }); - - log('Unified async caller created successfully for userId: %s', ctx.userId); - log('Processing %d async image generation tasks', generationsWithTasks.length); - - // Fire-and-forget: trigger async tasks without awaiting - // These calls go to the async router which handles them independently - // Do NOT use after() here as it would keep the lambda alive unnecessarily - generationsWithTasks.forEach(({ generation, asyncTaskId }) => { - log('Starting background async task %s for generation %s', asyncTaskId, generation.id); - - asyncCaller.image.createImage({ - generationBatchId: createdBatch.id, - generationId: generation.id, - generationTopicId, - model, - params: generationParams, - provider, - taskId: asyncTaskId, - }); - }); - - log('All %d background async image generation tasks started', generationsWithTasks.length); - } catch (e) { - console.error('Failed to process async tasks:', e); - console.error('Failed to process async tasks: %O', e); - - // If overall failure occurs, update all task statuses to failed - try { - await Promise.allSettled( - generationsWithTasks.map(({ asyncTaskId }) => - asyncTaskModel.update(asyncTaskId, { - error: new AsyncTaskError( - AsyncTaskErrorType.ServerError, - 'start async task error: ' + (e instanceof Error ? e.message : 'Unknown error'), - ), - status: AsyncTaskStatus.Error, - }), - ), - ); - } catch (batchUpdateError) { - console.error('Failed to update batch task statuses:', batchUpdateError); - } - } - - const createdGenerations = generationsWithTasks.map((item) => ({ - ...item.generation, - asyncTaskId: item.asyncTaskId, - })); - log('Image creation process completed successfully: %O', { - batchId: createdBatch.id, - generationCount: createdGenerations.length, - generationIds: createdGenerations.map((g) => g.id), - }); - - return { - data: { - batch: createdBatch, - generations: createdGenerations, - }, - success: true, - }; - }), + }), }); export type ImageRouter = typeof imageRouter; diff --git a/src/server/routers/lambda/importer.ts b/src/server/routers/lambda/importer.ts index 4a922d06c6..b043ed9f36 100644 --- a/src/server/routers/lambda/importer.ts +++ b/src/server/routers/lambda/importer.ts @@ -1,26 +1,31 @@ import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withRbacPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { DataImporterRepos } from '@/database/repositories/dataImporter'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { type ImportPgDataStructure } from '@/types/export'; import { type ImporterEntryData, type ImportResultData } from '@/types/importer'; -const importProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const importProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - dataImporterService: new DataImporterRepos(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), + dataImporterService: new DataImporterRepos(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), }, }); }); +const workspaceImportProcedure = importProcedure.use(withRbacPermission('workspace:update:all')); + export const importerRouter = router({ - importByFile: importProcedure + importByFile: workspaceImportProcedure .input(z.object({ pathname: z.string() })) .mutation(async ({ input, ctx }): Promise => { let data: ImporterEntryData | undefined; @@ -54,7 +59,7 @@ export const importerRouter = router({ return result; }), - importByPost: importProcedure + importByPost: workspaceImportProcedure .input( z.object({ data: z.object({ @@ -69,7 +74,7 @@ export const importerRouter = router({ .mutation(async ({ input, ctx }): Promise => { return ctx.dataImporterService.importData(input.data); }), - importPgByPost: importProcedure + importPgByPost: workspaceImportProcedure .input( z.object({ data: z.record(z.string(), z.array(z.any())), diff --git a/src/server/routers/lambda/index.ts b/src/server/routers/lambda/index.ts index b06ead8b55..894ee1aec2 100644 --- a/src/server/routers/lambda/index.ts +++ b/src/server/routers/lambda/index.ts @@ -9,6 +9,13 @@ import { storageOverageRouter } from '@/business/server/lambda-routers/storageOv import { subscriptionRouter } from '@/business/server/lambda-routers/subscription'; import { taskTemplateRouter } from '@/business/server/lambda-routers/taskTemplate'; import { topUpRouter } from '@/business/server/lambda-routers/topUp'; +import { workspaceRouter } from '@/business/server/lambda-routers/workspace'; +import { workspaceAuditLogRouter } from '@/business/server/lambda-routers/workspaceAuditLog'; +import { workspaceCreditsRouter } from '@/business/server/lambda-routers/workspaceCredits'; +import { workspaceCredsRouter } from '@/business/server/lambda-routers/workspaceCreds'; +import { workspaceDataRouter } from '@/business/server/lambda-routers/workspaceData'; +import { workspaceMemberRouter } from '@/business/server/lambda-routers/workspaceMember'; +import { workspaceUsageRouter } from '@/business/server/lambda-routers/workspaceUsage'; import { publicProcedure, router } from '@/libs/trpc/lambda'; import { agentRouter } from './agent'; @@ -136,6 +143,13 @@ export const lambdaRouter = router({ verify: verifyRouter, video: videoRouter, webBrowsing: webBrowsingRouter, + workspace: workspaceRouter, + workspaceAuditLog: workspaceAuditLogRouter, + workspaceCreds: workspaceCredsRouter, + workspaceCredits: workspaceCreditsRouter, + workspaceData: workspaceDataRouter, + workspaceMember: workspaceMemberRouter, + workspaceUsage: workspaceUsageRouter, accountDeletion: accountDeletionRouter, pageShare: pageShareRouter, referral: referralRouter, diff --git a/src/server/routers/lambda/klavis.ts b/src/server/routers/lambda/klavis.ts index ec7017c0a9..5b3cd6efe1 100644 --- a/src/server/routers/lambda/klavis.ts +++ b/src/server/routers/lambda/klavis.ts @@ -1,17 +1,20 @@ import { type ToolManifest } from '@lobechat/types'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { PluginModel } from '@/database/models/plugin'; import { getKlavisClient } from '@/libs/klavis'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; /** * Klavis procedure with API key validation and database access */ -const klavisProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const klavisProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const client = getKlavisClient(); - const pluginModel = new PluginModel(opts.ctx.serverDB, opts.ctx.userId); + const wsId = opts.ctx.workspaceId ?? undefined; + const pluginModel = new PluginModel(opts.ctx.serverDB, opts.ctx.userId, wsId); return opts.next({ ctx: { ...opts.ctx, klavisClient: client, pluginModel }, @@ -24,6 +27,7 @@ export const klavisRouter = router({ * Returns: { serverUrl, instanceId, oauthUrl?, identifier, serverName } */ createServerInstance: klavisProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ /** Identifier for storage (e.g., 'google-calendar') */ @@ -96,6 +100,7 @@ export const klavisRouter = router({ * Delete a server instance */ deleteServerInstance: klavisProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ /** Identifier for storage (e.g., 'google-calendar') */ @@ -190,6 +195,7 @@ export const klavisRouter = router({ * Remove Klavis plugin from database by identifier */ removeKlavisPlugin: klavisProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ /** Identifier for storage (e.g., 'google-calendar') */ @@ -205,6 +211,7 @@ export const klavisRouter = router({ * Update Klavis plugin with tools and auth status in database */ updateKlavisPlugin: klavisProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ /** Identifier for storage (e.g., 'google-calendar') */ diff --git a/src/server/routers/lambda/knowledge.ts b/src/server/routers/lambda/knowledge.ts index 84181479ee..a01b2b49d6 100644 --- a/src/server/routers/lambda/knowledge.ts +++ b/src/server/routers/lambda/knowledge.ts @@ -1,26 +1,28 @@ +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AsyncTaskModel } from '@/database/models/asyncTask'; import { ChunkModel } from '@/database/models/chunk'; import { DocumentModel } from '@/database/models/document'; import { FileModel } from '@/database/models/file'; import { KnowledgeRepo } from '@/database/repositories/knowledge'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { AsyncTaskStatus, AsyncTaskType } from '@/types/asyncTask'; import { type FileListItem } from '@/types/files'; import { QueryFileListSchema } from '@/types/files'; -const knowledgeProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const knowledgeProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId), - chunkModel: new ChunkModel(ctx.serverDB, ctx.userId), - documentModel: new DocumentModel(ctx.serverDB, ctx.userId), - fileModel: new FileModel(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), - knowledgeRepo: new KnowledgeRepo(ctx.serverDB, ctx.userId), + asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId, wsId), + chunkModel: new ChunkModel(ctx.serverDB, ctx.userId, wsId), + documentModel: new DocumentModel(ctx.serverDB, ctx.userId, wsId), + fileModel: new FileModel(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), + knowledgeRepo: new KnowledgeRepo(ctx.serverDB, ctx.userId, wsId), }, }); }); diff --git a/src/server/routers/lambda/knowledgeBase.ts b/src/server/routers/lambda/knowledgeBase.ts index cea518f091..1bd75ea00e 100644 --- a/src/server/routers/lambda/knowledgeBase.ts +++ b/src/server/routers/lambda/knowledgeBase.ts @@ -1,26 +1,33 @@ import { TRPCError } from '@trpc/server'; +import { and, eq, isNull } from 'drizzle-orm'; import { z } from 'zod'; +import { businessFileTransferStorageCheck } from '@/business/server/lambda-routers/file'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { serverDBEnv } from '@/config/db'; import { KnowledgeBaseModel } from '@/database/models/knowledgeBase'; -import { insertKnowledgeBasesSchema } from '@/database/schemas'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { insertKnowledgeBasesSchema, workspaceMembers } from '@/database/schemas'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { type KnowledgeBaseItem } from '@/types/knowledgeBase'; +import { TransferErrorCode } from '@/types/transferError'; -const knowledgeBaseProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const knowledgeBaseProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - knowledgeBaseModel: new KnowledgeBaseModel(ctx.serverDB, ctx.userId), + knowledgeBaseModel: new KnowledgeBaseModel(ctx.serverDB, ctx.userId, wsId), }, }); }); export const knowledgeBaseRouter = router({ addFilesToKnowledgeBase: knowledgeBaseProcedure + .use(withScopedPermission('knowledge_base:update')) .input(z.object({ ids: z.array(z.string()), knowledgeBaseId: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -42,6 +49,7 @@ export const knowledgeBaseRouter = router({ }), createKnowledgeBase: knowledgeBaseProcedure + .use(withScopedPermission('knowledge_base:create')) .input( z.object({ avatar: z.string().optional(), @@ -59,6 +67,55 @@ export const knowledgeBaseRouter = router({ return data?.id; }), + copyKnowledgeBaseToWorkspace: knowledgeBaseProcedure + .use(withScopedPermission('knowledge_base:create')) + .input( + z.object({ + id: z.string(), + targetWorkspaceId: z.string().nullable(), + }), + ) + .mutation(async ({ input, ctx }) => { + const knowledgeBase = await ctx.knowledgeBaseModel.findById(input.id); + if (!knowledgeBase) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'Knowledge base not found', + }); + } + + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + const additionalSize = await ctx.knowledgeBaseModel.countFileUsage(input.id); + await businessFileTransferStorageCheck({ + additionalSize, + targetUserId: ctx.userId, + targetWorkspaceId: input.targetWorkspaceId, + }); + + return ctx.knowledgeBaseModel.copyToWorkspace(input.id, input.targetWorkspaceId, ctx.userId); + }), + getKnowledgeBaseById: knowledgeBaseProcedure .input(z.object({ id: z.string() })) .query(async ({ ctx, input }): Promise => { @@ -69,25 +126,31 @@ export const knowledgeBaseRouter = router({ return ctx.knowledgeBaseModel.query(); }), - removeAllKnowledgeBases: knowledgeBaseProcedure.mutation(async ({ ctx }) => { - const result = await ctx.knowledgeBaseModel.deleteAllWithFiles(serverDBEnv.REMOVE_GLOBAL_FILE); + removeAllKnowledgeBases: knowledgeBaseProcedure + .use(withScopedPermission('knowledge_base:delete')) + .mutation(async ({ ctx }) => { + const result = await ctx.knowledgeBaseModel.deleteAllWithFiles( + serverDBEnv.REMOVE_GLOBAL_FILE, + ); - if (result.deletedFiles.length > 0) { - const fileService = new FileService(ctx.serverDB, ctx.userId); - const urls = result.deletedFiles.map((f) => f.url).filter(Boolean) as string[]; - if (urls.length > 0) { - await fileService.deleteFiles(urls); + if (result.deletedFiles.length > 0) { + const fileService = new FileService(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); + const urls = result.deletedFiles.map((f) => f.url).filter(Boolean) as string[]; + if (urls.length > 0) { + await fileService.deleteFiles(urls); + } } - } - }), + }), removeFilesFromKnowledgeBase: knowledgeBaseProcedure + .use(withScopedPermission('knowledge_base:update')) .input(z.object({ ids: z.array(z.string()), knowledgeBaseId: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.knowledgeBaseModel.removeFilesFromKnowledgeBase(input.knowledgeBaseId, input.ids); }), removeKnowledgeBase: knowledgeBaseProcedure + .use(withScopedPermission('knowledge_base:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { const result = await ctx.knowledgeBaseModel.deleteWithFiles( @@ -96,7 +159,7 @@ export const knowledgeBaseRouter = router({ ); if (result.deletedFiles.length > 0) { - const fileService = new FileService(ctx.serverDB, ctx.userId); + const fileService = new FileService(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); const urls = result.deletedFiles.map((f) => f.url).filter(Boolean) as string[]; if (urls.length > 0) { await fileService.deleteFiles(urls); @@ -104,7 +167,65 @@ export const knowledgeBaseRouter = router({ } }), + transferKnowledgeBase: knowledgeBaseProcedure + .use(withScopedPermission('knowledge_base:create')) + .input( + z.object({ + id: z.string(), + targetWorkspaceId: z.string().nullable(), + }), + ) + .mutation(async ({ input, ctx }) => { + if (input.targetWorkspaceId === (ctx.workspaceId ?? null)) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.SameWorkspace } }, + code: 'BAD_REQUEST', + message: 'Cannot transfer to the same workspace', + }); + } + + const knowledgeBase = await ctx.knowledgeBaseModel.findById(input.id); + if (!knowledgeBase) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'Knowledge base not found', + }); + } + + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + const additionalSize = await ctx.knowledgeBaseModel.countFileUsage(input.id); + await businessFileTransferStorageCheck({ + additionalSize, + targetUserId: ctx.userId, + targetWorkspaceId: input.targetWorkspaceId, + }); + + return ctx.knowledgeBaseModel.transferTo(input.id, input.targetWorkspaceId, ctx.userId); + }), + updateKnowledgeBase: knowledgeBaseProcedure + .use(withScopedPermission('knowledge_base:update')) .input( z.object({ id: z.string(), diff --git a/src/server/routers/lambda/llmGenerationTracing.ts b/src/server/routers/lambda/llmGenerationTracing.ts index 09ec8e0f45..3527cc472d 100644 --- a/src/server/routers/lambda/llmGenerationTracing.ts +++ b/src/server/routers/lambda/llmGenerationTracing.ts @@ -43,12 +43,17 @@ export const llmGenerationTracingRouter = router({ ) .mutation(async ({ ctx, input }) => { try { - await getLLMGenerationTracingService().recordFeedback(ctx.userId, input.tracingId, { - data: input.data, - score: input.score, - signal: input.signal, - source: input.source, - }); + await getLLMGenerationTracingService().recordFeedback( + ctx.userId, + input.tracingId, + { + data: input.data, + score: input.score, + signal: input.signal, + source: input.source, + }, + ctx.workspaceId ?? undefined, + ); } catch (err) { if (err instanceof LLMGenerationFeedbackError) { throw new TRPCError({ diff --git a/src/server/routers/lambda/market/agent.test.ts b/src/server/routers/lambda/market/agent.test.ts new file mode 100644 index 0000000000..0778b1e013 --- /dev/null +++ b/src/server/routers/lambda/market/agent.test.ts @@ -0,0 +1,98 @@ +// @vitest-environment node +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { agentRouter } from './agent'; + +const { mockMarketSDK, mockCreateAgentVersionHeader } = vi.hoisted(() => { + const mockCreateAgentVersionHeader = vi.fn(); + const mockMarketSDK = { + agents: { + createAgent: vi.fn(), + createAgentVersion: vi.fn(async () => { + mockCreateAgentVersionHeader(mockMarketSDK.headers['x-lobe-owner-account-id']); + return { success: true }; + }), + getAgentDetail: vi.fn(), + }, + headers: {} as Record, + }; + + return { mockCreateAgentVersionHeader, mockMarketSDK }; +}); + +vi.mock('@/business/server/trpc-middlewares/rbacPermission', () => ({ + withScopedPermission: vi.fn(() => (opts: any) => opts.next({ ctx: opts.ctx })), +})); + +vi.mock('@/database/models/user', () => ({ + UserModel: vi.fn(() => ({ + getUserState: vi.fn(async () => ({ settings: {} })), + })), +})); + +vi.mock('@/libs/trpc/lambda/middleware', () => ({ + marketSDK: vi.fn((opts: any) => + opts.next({ + ctx: { + ...opts.ctx, + marketSDK: mockMarketSDK, + }, + }), + ), + marketUserInfo: vi.fn((opts: any) => + opts.next({ + ctx: { + ...opts.ctx, + marketUserInfo: { email: 'actor@example.com', name: 'Actor', userId: 'user-1' }, + }, + }), + ), + serverDatabase: vi.fn((opts: any) => opts.next({ ctx: opts.ctx })), +})); + +vi.mock('@/libs/trusted-client', () => ({ + generateTrustedClientToken: vi.fn(() => 'trust-token'), +})); + +describe('agentRouter.publishOrCreate', () => { + let fetchSpy: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + mockMarketSDK.headers = {}; + mockMarketSDK.agents.getAgentDetail.mockResolvedValue({ + identifier: 'existing-agent', + name: 'Existing Agent', + ownerId: 123, + }); + fetchSpy = vi.spyOn(globalThis, 'fetch' as never); + fetchSpy.mockResolvedValue( + new Response(JSON.stringify({ accountId: 999, sub: 'user-1' }), { + status: 200, + }), + ); + }); + + it('uses the acting organization account for ownership checks and version uploads', async () => { + const caller = agentRouter.createCaller({ serverDB: {}, userId: 'user-1' } as any); + + const result = await caller.publishOrCreate({ + actAs: 123, + identifier: 'existing-agent', + name: 'Existing Agent', + }); + + expect(result).toEqual({ + identifier: 'existing-agent', + isNewAgent: false, + success: true, + }); + expect(mockMarketSDK.agents.createAgent).not.toHaveBeenCalled(); + expect(mockMarketSDK.agents.createAgentVersion).toHaveBeenCalledWith({ + identifier: 'existing-agent', + name: 'Existing Agent', + }); + expect(mockCreateAgentVersionHeader).toHaveBeenCalledWith('123'); + expect(mockMarketSDK.headers['x-lobe-owner-account-id']).toBeUndefined(); + }); +}); diff --git a/src/server/routers/lambda/market/agent.ts b/src/server/routers/lambda/market/agent.ts index 3ba173fc86..e71a09a1ea 100644 --- a/src/server/routers/lambda/market/agent.ts +++ b/src/server/routers/lambda/market/agent.ts @@ -3,6 +3,7 @@ import debug from 'debug'; import { customAlphabet } from 'nanoid/non-secure'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; import { authedProcedure, router } from '@/libs/trpc/lambda'; import { marketSDK, marketUserInfo, serverDatabase } from '@/libs/trpc/lambda/middleware'; import { type TrustedClientUserInfo } from '@/libs/trusted-client'; @@ -117,7 +118,36 @@ const buildMarketAuthHeaders = (ctx: { return headers; }; +const withActingAccountHeader = async ( + marketSDK: unknown, + actAs: number | undefined, + operation: () => Promise, +): Promise => { + const headers = (marketSDK as { headers?: Record }).headers; + if (actAs === undefined || !headers) return operation(); + + const previous = headers['x-lobe-owner-account-id']; + headers['x-lobe-owner-account-id'] = String(actAs); + + try { + return await operation(); + } finally { + if (previous === undefined) { + delete headers['x-lobe-owner-account-id']; + } else { + headers['x-lobe-owner-account-id'] = previous; + } + } +}; + interface ForkAgentItemInput { + /** + * When present, fork is attributed to the given Market organization account. + * Forwarded as `X-Lobe-Owner-Account-Id` so the Market resolves writes to the + * organization instead of the calling user. The actor must be a member of the + * target org (enforced server-side by `resolveActingAccount`). + */ + actAs?: number; identifier: string; name?: string; sourceIdentifier: string; @@ -128,10 +158,15 @@ interface ForkAgentItemInput { const forkOneAgent = async ( item: ForkAgentItemInput, - headers: Record, + baseHeaders: Record, ): Promise => { try { const forkUrl = `${MARKET_BASE_URL}/api/v1/agents/${item.sourceIdentifier}/fork`; + // Clone so per-item actAs doesn't leak across the batch. + const headers = { ...baseHeaders }; + if (item.actAs !== undefined) { + headers['x-lobe-owner-account-id'] = String(item.actAs); + } const response = await fetch(forkUrl, { body: JSON.stringify({ identifier: item.identifier, @@ -180,6 +215,12 @@ const forkOneAgent = async ( }; const forkAgentItemSchema = z.object({ + /** + * Optional Market organization account id to attribute the fork to. Triggers + * `X-Lobe-Owner-Account-Id` on the fork request. Caller is responsible for + * resolving the organization account id before passing this field. + */ + actAs: z.number().int().positive().optional(), identifier: z.string(), name: z.string().optional(), sourceIdentifier: z.string(), @@ -216,9 +257,11 @@ const agentProcedure = authedProcedure }, }); }); +const agentWriteProcedure = agentProcedure.use(withScopedPermission('agent:create')); // Schema definitions const createAgentSchema = z.object({ + actAs: z.number().int().positive().optional(), homepage: z.string().optional(), identifier: z.string(), isFeatured: z.boolean().optional(), @@ -229,6 +272,7 @@ const createAgentSchema = z.object({ }); const createAgentVersionSchema = z.object({ + actAs: z.number().int().positive().optional(), a2aProtocolVersion: z.string().optional(), avatar: z.string().optional(), category: z.string().optional(), @@ -263,6 +307,7 @@ const paginationSchema = z.object({ // Schema for the unified publish/create flow const publishOrCreateSchema = z.object({ + actAs: z.number().int().positive().optional(), // Version data avatar: z.string().optional(), @@ -352,11 +397,14 @@ export const agentRouter = router({ * Create a new agent in the marketplace * POST /market/agent/create */ - createAgent: agentProcedure.input(createAgentSchema).mutation(async ({ input, ctx }) => { + createAgent: agentWriteProcedure.input(createAgentSchema).mutation(async ({ input, ctx }) => { log('createAgent input: %O', input); try { - const response = await ctx.marketSDK.agents.createAgent(input); + const { actAs, ...agentData } = input; + const response = await withActingAccountHeader(ctx.marketSDK, actAs, () => + ctx.marketSDK.agents.createAgent(agentData), + ); return response; } catch (error) { log('Error creating agent: %O', error); @@ -372,13 +420,16 @@ export const agentRouter = router({ * Create a new version for an existing agent * POST /market/agent/versions/create */ - createAgentVersion: agentProcedure + createAgentVersion: agentWriteProcedure .input(createAgentVersionSchema) .mutation(async ({ input, ctx }) => { log('createAgentVersion input: %O', input); try { - const response = await ctx.marketSDK.agents.createAgentVersion(input); + const { actAs, ...versionData } = input; + const response = await withActingAccountHeader(ctx.marketSDK, actAs, () => + ctx.marketSDK.agents.createAgentVersion(versionData), + ); return response; } catch (error) { log('Error creating agent version: %O', error); @@ -394,7 +445,7 @@ export const agentRouter = router({ * Deprecate an agent (permanently hide, cannot be republished) * POST /market/agent/:identifier/deprecate */ - deprecateAgent: agentProcedure + deprecateAgent: agentWriteProcedure .input(z.object({ identifier: z.string() })) .mutation(async ({ input, ctx }) => { log('deprecateAgent input: %O', input); @@ -419,7 +470,7 @@ export const agentRouter = router({ * Best-effort: single-item failures are returned in-line as * `{ success: false, error }` and do not abort the rest of the batch. */ - forkAgent: agentProcedure + forkAgent: agentWriteProcedure .input(z.object({ items: z.array(forkAgentItemSchema).min(1) })) .mutation(async ({ input, ctx }) => { log('forkAgent batch size: %d', input.items.length); @@ -649,7 +700,7 @@ export const agentRouter = router({ * Publish an agent (make it visible in marketplace) * POST /market/agent/:identifier/publish */ - publishAgent: agentProcedure + publishAgent: agentWriteProcedure .input(z.object({ identifier: z.string() })) .mutation(async ({ input, ctx }) => { log('publishAgent input: %O', input); @@ -676,93 +727,101 @@ export const agentRouter = router({ * * Returns: { identifier, isNewAgent, success } */ - publishOrCreate: agentProcedure.input(publishOrCreateSchema).mutation(async ({ input, ctx }) => { - log('publishOrCreate input: %O', input); + publishOrCreate: agentWriteProcedure + .input(publishOrCreateSchema) + .mutation(async ({ input, ctx }) => { + log('publishOrCreate input: %O', input); - const { identifier: inputIdentifier, name, ...versionData } = input; - let finalIdentifier = inputIdentifier; - let isNewAgent = false; + const { actAs, identifier: inputIdentifier, name, ...versionData } = input; + let finalIdentifier = inputIdentifier; + let isNewAgent = false; - try { - // Step 1: Check ownership if identifier is provided - if (inputIdentifier) { - try { - const agentDetail = await ctx.marketSDK.agents.getAgentDetail(inputIdentifier); - log('Agent detail for ownership check: ownerId=%s', agentDetail?.ownerId); + try { + // Step 1: Check ownership if identifier is provided + if (inputIdentifier) { + try { + const agentDetail = await ctx.marketSDK.agents.getAgentDetail(inputIdentifier); + log('Agent detail for ownership check: ownerId=%s', agentDetail?.ownerId); - // Get Market user info to get accountId (Market's user ID) - // Support both trustedClientToken and OIDC accessToken authentication - const userInfo = ctx.marketUserInfo as TrustedClientUserInfo | undefined; - const accessToken = (ctx as { marketOidcAccessToken?: string }).marketOidcAccessToken; - let currentAccountId: number | null = null; + // Get Market user info to get accountId (Market's user ID) + // Support both trustedClientToken and OIDC accessToken authentication + const userInfo = ctx.marketUserInfo as TrustedClientUserInfo | undefined; + const accessToken = (ctx as { marketOidcAccessToken?: string }).marketOidcAccessToken; + let currentAccountId: number | null = null; - const marketUserInfoResult = await fetchMarketUserInfo({ accessToken, userInfo }); - currentAccountId = marketUserInfoResult?.accountId ?? null; - log('Market user info: accountId=%s', currentAccountId); + const marketUserInfoResult = await fetchMarketUserInfo({ accessToken, userInfo }); + currentAccountId = marketUserInfoResult?.accountId ?? null; + log('Market user info: accountId=%s', currentAccountId); - const ownerId = agentDetail?.ownerId; + const ownerId = agentDetail?.ownerId; - log('Ownership check: currentAccountId=%s, ownerId=%s', currentAccountId, ownerId); + log('Ownership check: currentAccountId=%s, ownerId=%s', currentAccountId, ownerId); - if (!currentAccountId || `${ownerId}` !== `${currentAccountId}`) { - // Not the owner, need to create a new agent - log('User is not owner, will create new agent'); + const actingAccountId = actAs ?? currentAccountId; + + if (!actingAccountId || `${ownerId}` !== `${actingAccountId}`) { + // Not the owner, need to create a new agent + log('User is not owner, will create new agent'); + finalIdentifier = undefined; + isNewAgent = true; + } + } catch (detailError) { + // Agent not found or error, create new + log('Agent not found or error, will create new: %O', detailError); finalIdentifier = undefined; isNewAgent = true; } - } catch (detailError) { - // Agent not found or error, create new - log('Agent not found or error, will create new: %O', detailError); - finalIdentifier = undefined; + } else { isNewAgent = true; } - } else { - isNewAgent = true; - } - // Step 2: Create new agent if needed - if (!finalIdentifier) { - // Generate a unique 8-character identifier - finalIdentifier = generateMarketIdentifier(); - isNewAgent = true; + // Step 2: Create new agent if needed + if (!finalIdentifier) { + // Generate a unique 8-character identifier + finalIdentifier = generateMarketIdentifier(); + isNewAgent = true; - log('Creating new agent with identifier: %s', finalIdentifier); + log('Creating new agent with identifier: %s', finalIdentifier); - await ctx.marketSDK.agents.createAgent({ + await withActingAccountHeader(ctx.marketSDK, actAs, () => + ctx.marketSDK.agents.createAgent({ + identifier: finalIdentifier!, + name, + }), + ); + } + + // Step 3: Create version for the agent + log('Creating version for agent: %s', finalIdentifier); + + await withActingAccountHeader(ctx.marketSDK, actAs, () => + ctx.marketSDK.agents.createAgentVersion({ + ...versionData, + identifier: finalIdentifier!, + name, + }), + ); + + return { identifier: finalIdentifier, - name, + isNewAgent, + success: true, + }; + } catch (error) { + log('Error in publishOrCreate: %O', error); + throw new TRPCError({ + cause: error, + code: 'INTERNAL_SERVER_ERROR', + message: error instanceof Error ? error.message : 'Failed to publish agent', }); } - - // Step 3: Create version for the agent - log('Creating version for agent: %s', finalIdentifier); - - await ctx.marketSDK.agents.createAgentVersion({ - ...versionData, - identifier: finalIdentifier, - name, - }); - - return { - identifier: finalIdentifier, - isNewAgent, - success: true, - }; - } catch (error) { - log('Error in publishOrCreate: %O', error); - throw new TRPCError({ - cause: error, - code: 'INTERNAL_SERVER_ERROR', - message: error instanceof Error ? error.message : 'Failed to publish agent', - }); - } - }), + }), /** * Unpublish an agent (hide from marketplace, can be republished) * POST /market/agent/:identifier/unpublish */ - unpublishAgent: agentProcedure + unpublishAgent: agentWriteProcedure .input(z.object({ identifier: z.string() })) .mutation(async ({ input, ctx }) => { log('unpublishAgent input: %O', input); diff --git a/src/server/routers/lambda/market/agentGroup.test.ts b/src/server/routers/lambda/market/agentGroup.test.ts new file mode 100644 index 0000000000..0181d021db --- /dev/null +++ b/src/server/routers/lambda/market/agentGroup.test.ts @@ -0,0 +1,120 @@ +// @vitest-environment node +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { agentGroupRouter } from './agentGroup'; + +const { mockMarketSDK } = vi.hoisted(() => ({ + mockMarketSDK: { + agentGroups: { + getAgentGroupDetail: vi.fn(), + }, + headers: {} as Record, + }, +})); + +vi.mock('@/business/server/trpc-middlewares/rbacPermission', () => ({ + withScopedPermission: vi.fn(() => (opts: any) => opts.next({ ctx: opts.ctx })), +})); + +vi.mock('@/database/models/user', () => ({ + UserModel: vi.fn(() => ({ + getUserState: vi.fn(async () => ({ settings: {} })), + })), +})); + +vi.mock('@/libs/trpc/lambda/middleware', () => ({ + marketSDK: vi.fn((opts: any) => + opts.next({ + ctx: { + ...opts.ctx, + marketSDK: mockMarketSDK, + }, + }), + ), + marketUserInfo: vi.fn((opts: any) => + opts.next({ + ctx: { + ...opts.ctx, + marketUserInfo: { email: 'actor@example.com', name: 'Actor', userId: 'user-1' }, + }, + }), + ), + serverDatabase: vi.fn((opts: any) => opts.next({ ctx: opts.ctx })), +})); + +vi.mock('@/libs/trusted-client', () => ({ + generateTrustedClientToken: vi.fn(() => 'trust-token'), +})); + +describe('agentGroupRouter.forkAgentGroup', () => { + let fetchSpy: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + mockMarketSDK.headers = {}; + fetchSpy = vi.spyOn(globalThis, 'fetch' as never); + fetchSpy.mockResolvedValue( + new Response( + JSON.stringify({ + group: { identifier: 'forked-group' }, + groupVersion: { versionNumber: 1 }, + memberAgents: [], + }), + { status: 200 }, + ), + ); + }); + + it('attributes workspace forks to the acting organization account', async () => { + const caller = agentGroupRouter.createCaller({ serverDB: {}, userId: 'user-1' } as any); + + await caller.forkAgentGroup({ + actAs: 321, + identifier: 'forked-group', + name: 'Forked Group', + sourceIdentifier: 'source-group', + status: 'published', + visibility: 'public', + } as unknown as Parameters[0]); + + const [, requestInit] = fetchSpy.mock.calls[0] as [string, RequestInit]; + expect((requestInit.headers as Record)['x-lobe-owner-account-id']).toBe('321'); + }); +}); + +describe('agentGroupRouter.checkOwnership', () => { + let fetchSpy: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + mockMarketSDK.agentGroups.getAgentGroupDetail.mockResolvedValue({ + group: { + avatar: '👥', + identifier: 'workspace-group', + name: 'Workspace Group', + ownerId: 321, + }, + }); + fetchSpy = vi.spyOn(globalThis, 'fetch' as never); + fetchSpy.mockResolvedValue( + new Response(JSON.stringify({ accountId: 999, sub: 'user-1' }), { + status: 200, + }), + ); + }); + + it('uses the acting organization account when checking workspace-owned groups', async () => { + const caller = agentGroupRouter.createCaller({ serverDB: {}, userId: 'user-1' } as any); + + const result = await caller.checkOwnership({ + actAs: 321, + identifier: 'workspace-group', + } as unknown as Parameters[0]); + + expect(result).toMatchObject({ + exists: true, + isOwner: true, + originalGroup: null, + }); + }); +}); diff --git a/src/server/routers/lambda/market/agentGroup.ts b/src/server/routers/lambda/market/agentGroup.ts index e9b970d378..1e2edebf61 100644 --- a/src/server/routers/lambda/market/agentGroup.ts +++ b/src/server/routers/lambda/market/agentGroup.ts @@ -3,6 +3,7 @@ import debug from 'debug'; import { customAlphabet } from 'nanoid/non-secure'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; import { authedProcedure, router } from '@/libs/trpc/lambda'; import { marketSDK, marketUserInfo, serverDatabase } from '@/libs/trpc/lambda/middleware'; import { type TrustedClientUserInfo } from '@/libs/trusted-client'; @@ -82,6 +83,28 @@ const fetchMarketUserInfo = async ( } }; +const withActingAccountHeader = async ( + marketSDK: unknown, + actAs: number | undefined, + operation: () => Promise, +): Promise => { + const headers = (marketSDK as { headers?: Record }).headers; + if (actAs === undefined || !headers) return operation(); + + const previous = headers['x-lobe-owner-account-id']; + headers['x-lobe-owner-account-id'] = String(actAs); + + try { + return await operation(); + } finally { + if (previous === undefined) { + delete headers['x-lobe-owner-account-id']; + } else { + headers['x-lobe-owner-account-id'] = previous; + } + } +}; + // Authenticated procedure for agent group management const agentGroupProcedure = authedProcedure .use(serverDatabase) @@ -106,6 +129,7 @@ const agentGroupProcedure = authedProcedure }, }); }); +const agentGroupWriteProcedure = agentGroupProcedure.use(withScopedPermission('agent:create')); // Schema definitions const memberAgentSchema = z.object({ @@ -121,6 +145,7 @@ const memberAgentSchema = z.object({ }); const publishOrCreateGroupSchema = z.object({ + actAs: z.number().int().positive().optional(), avatar: z.string().nullish(), backgroundColor: z.string().nullish(), category: z.string().optional(), @@ -146,7 +171,7 @@ export const agentGroupRouter = router({ * Check if current user owns the specified group */ checkOwnership: agentGroupProcedure - .input(z.object({ identifier: z.string() })) + .input(z.object({ actAs: z.number().int().positive().optional(), identifier: z.string() })) .query(async ({ input, ctx }) => { log('checkOwnership input: %O', input); @@ -169,12 +194,14 @@ export const agentGroupRouter = router({ currentAccountId = marketUserInfoResult?.accountId ?? null; const ownerId = groupDetail.group.ownerId; - const isOwner = currentAccountId !== null && `${ownerId}` === `${currentAccountId}`; + const actingAccountId = input.actAs ?? currentAccountId; + const isOwner = actingAccountId !== null && `${ownerId}` === `${actingAccountId}`; log( - 'checkOwnership result: isOwner=%s, currentAccountId=%s, ownerId=%s', + 'checkOwnership result: isOwner=%s, currentAccountId=%s, actingAccountId=%s, ownerId=%s', isOwner, currentAccountId, + actingAccountId, ownerId, ); @@ -205,7 +232,7 @@ export const agentGroupRouter = router({ * Deprecate agent group * POST /market/agent-group/:identifier/deprecate */ - deprecateAgentGroup: agentGroupProcedure + deprecateAgentGroup: agentGroupWriteProcedure .input(z.object({ identifier: z.string() })) .mutation(async ({ input, ctx }) => { log('deprecateAgentGroup input: %O', input); @@ -263,9 +290,10 @@ export const agentGroupRouter = router({ * Fork an agent group * POST /market/agent-group/:identifier/fork */ - forkAgentGroup: agentGroupProcedure + forkAgentGroup: agentGroupWriteProcedure .input( z.object({ + actAs: z.number().int().positive().optional(), identifier: z.string(), name: z.string().optional(), sourceIdentifier: z.string(), @@ -300,6 +328,10 @@ export const agentGroupRouter = router({ headers['Authorization'] = `Bearer ${accessToken}`; } + if (input.actAs !== undefined) { + headers['x-lobe-owner-account-id'] = String(input.actAs); + } + const response = await fetch(forkUrl, { body: JSON.stringify({ identifier: input.identifier, @@ -590,7 +622,7 @@ export const agentGroupRouter = router({ * Publish agent group * POST /market/agent-group/:identifier/publish */ - publishAgentGroup: agentGroupProcedure + publishAgentGroup: agentGroupWriteProcedure .input(z.object({ identifier: z.string() })) .mutation(async ({ input, ctx }) => { log('publishAgentGroup input: %O', input); @@ -650,12 +682,12 @@ export const agentGroupRouter = router({ * 2. If not owner or no identifier, create new group * 3. Create new version for the group if updating */ - publishOrCreate: agentGroupProcedure + publishOrCreate: agentGroupWriteProcedure .input(publishOrCreateGroupSchema) .mutation(async ({ input, ctx }) => { log('publishOrCreate input: %O', input); - const { identifier: inputIdentifier, name, memberAgents, ...groupData } = input; + const { actAs, identifier: inputIdentifier, name, memberAgents, ...groupData } = input; let finalIdentifier = inputIdentifier; let isNewGroup = false; @@ -679,7 +711,9 @@ export const agentGroupRouter = router({ log('Ownership check: currentAccountId=%s, ownerId=%s', currentAccountId, ownerId); - if (!currentAccountId || `${ownerId}` !== `${currentAccountId}`) { + const actingAccountId = actAs ?? currentAccountId; + + if (!actingAccountId || `${ownerId}` !== `${actingAccountId}`) { // Not the owner, need to create a new group log('User is not owner, will create new group'); finalIdentifier = undefined; @@ -703,24 +737,28 @@ export const agentGroupRouter = router({ log('Creating new group with identifier: %s', finalIdentifier); - await ctx.marketSDK.agentGroups.createAgentGroup({ - ...groupData, - identifier: finalIdentifier, - // @ts-ignore - memberAgents, - name, - }); + await withActingAccountHeader(ctx.marketSDK, actAs, () => + ctx.marketSDK.agentGroups.createAgentGroup({ + ...groupData, + identifier: finalIdentifier!, + // @ts-ignore + memberAgents, + name, + }), + ); } else { // Update existing group - create new version log('Creating new version for group: %s', finalIdentifier); - await ctx.marketSDK.agentGroups.createAgentGroupVersion({ - ...groupData, - identifier: finalIdentifier, - // @ts-ignore - memberAgents, - name, - }); + await withActingAccountHeader(ctx.marketSDK, actAs, () => + ctx.marketSDK.agentGroups.createAgentGroupVersion({ + ...groupData, + identifier: finalIdentifier!, + // @ts-ignore + memberAgents, + name, + }), + ); } return { @@ -742,7 +780,7 @@ export const agentGroupRouter = router({ * Unpublish agent group * POST /market/agent-group/:identifier/unpublish */ - unpublishAgentGroup: agentGroupProcedure + unpublishAgentGroup: agentGroupWriteProcedure .input(z.object({ identifier: z.string() })) .mutation(async ({ input, ctx }) => { log('unpublishAgentGroup input: %O', input); diff --git a/src/server/routers/lambda/market/creds.ts b/src/server/routers/lambda/market/creds.ts index 6ca9bf84a2..d9bc89f573 100644 --- a/src/server/routers/lambda/market/creds.ts +++ b/src/server/routers/lambda/market/creds.ts @@ -2,6 +2,7 @@ import { TRPCError } from '@trpc/server'; import debug from 'debug'; import { z } from 'zod'; +import { withRbacPermission } from '@/business/server/trpc-middlewares/rbacPermission'; import { publicProcedure, router } from '@/libs/trpc/lambda'; import { marketUserInfo, requireMarketAuth, serverDatabase } from '@/libs/trpc/lambda/middleware'; import { MarketService } from '@/server/services/market'; @@ -23,10 +24,11 @@ const credsProcedure = publicProcedure }, }); }); +const credsManageProcedure = credsProcedure.use(withRbacPermission('workspace:update:all')); export const credsRouter = router({ // Create file credential - createFile: credsProcedure + createFile: credsManageProcedure .input( z.object({ description: z.string().optional(), @@ -54,7 +56,7 @@ export const credsRouter = router({ }), // Create KV credential (kv-env or kv-header) - createKV: credsProcedure + createKV: credsManageProcedure .input( z.object({ description: z.string().optional(), @@ -82,7 +84,7 @@ export const credsRouter = router({ }), // Create OAuth credential - createOAuth: credsProcedure + createOAuth: credsManageProcedure .input( z.object({ description: z.string().optional(), @@ -109,25 +111,27 @@ export const credsRouter = router({ }), // Delete credential by ID - delete: credsProcedure.input(z.object({ id: z.number() })).mutation(async ({ ctx, input }) => { - log('delete input: %O', input); + delete: credsManageProcedure + .input(z.object({ id: z.number() })) + .mutation(async ({ ctx, input }) => { + log('delete input: %O', input); - try { - const result = await ctx.marketService.market.creds.delete(input.id); - log('delete success'); - return result; - } catch (error) { - log('delete error: %O', error); - throw new TRPCError({ - cause: error, - code: 'INTERNAL_SERVER_ERROR', - message: 'Failed to delete credential', - }); - } - }), + try { + const result = await ctx.marketService.market.creds.delete(input.id); + log('delete success'); + return result; + } catch (error) { + log('delete error: %O', error); + throw new TRPCError({ + cause: error, + code: 'INTERNAL_SERVER_ERROR', + message: 'Failed to delete credential', + }); + } + }), // Delete credential by key - deleteByKey: credsProcedure + deleteByKey: credsManageProcedure .input(z.object({ key: z.string() })) .mutation(async ({ ctx, input }) => { log('deleteByKey input: %O', input); @@ -147,7 +151,7 @@ export const credsRouter = router({ }), // Get single credential (optionally with decrypted values) - get: credsProcedure + get: credsManageProcedure .input( z.object({ decrypt: z.boolean().optional(), @@ -174,7 +178,7 @@ export const credsRouter = router({ }), // Get single credential by key (optionally with decrypted values) - getByKey: credsProcedure + getByKey: credsManageProcedure .input( z.object({ decrypt: z.boolean().optional(), @@ -326,7 +330,7 @@ export const credsRouter = router({ }), // List OAuth connections (for creating OAuth credentials) - listOAuthConnections: credsProcedure.query(async ({ ctx }) => { + listOAuthConnections: credsManageProcedure.query(async ({ ctx }) => { log('listOAuthConnections called'); try { @@ -344,7 +348,7 @@ export const credsRouter = router({ }), // Upload credential file - uploadFile: credsProcedure + uploadFile: credsManageProcedure .input( z.object({ file: z.string(), // base64 encoded file content @@ -370,7 +374,7 @@ export const credsRouter = router({ }), // Update credential - update: credsProcedure + update: credsManageProcedure .input( z.object({ description: z.string().optional(), diff --git a/src/server/routers/lambda/market/social.ts b/src/server/routers/lambda/market/social.ts index 2ba0028f8c..f3334f7b32 100644 --- a/src/server/routers/lambda/market/social.ts +++ b/src/server/routers/lambda/market/social.ts @@ -181,9 +181,12 @@ export const socialRouter = router({ ctx.marketSDK.follows.getFollowers(input.userId, { limit: 1 }), ]); + // `totalCount` is the full-list total returned by the market follows + // endpoints (added in market-sdk >= 0.34.0-beta.2). The cast keeps this + // building against older published SDK typings until the dep is bumped. return { - followersCount: (followers as any).totalCount || (followers as any).total || 0, - followingCount: (following as any).totalCount || (following as any).total || 0, + followersCount: (followers as { totalCount?: number }).totalCount ?? 0, + followingCount: (following as { totalCount?: number }).totalCount ?? 0, }; } catch (error) { log('Error getting follow counts: %O', error); diff --git a/src/server/routers/lambda/market/socialProfile.test.ts b/src/server/routers/lambda/market/socialProfile.test.ts new file mode 100644 index 0000000000..bae0833d72 --- /dev/null +++ b/src/server/routers/lambda/market/socialProfile.test.ts @@ -0,0 +1,53 @@ +// @vitest-environment node +import { beforeEach, describe, expect, it, vi } from 'vitest'; + +import { socialProfileRouter } from './socialProfile'; + +const { mockMarketSDKHeaders } = vi.hoisted(() => ({ + mockMarketSDKHeaders: { + Authorization: 'Bearer market-token', + }, +})); + +vi.mock('@/libs/trpc/lambda/middleware', () => ({ + marketSDK: vi.fn((opts: any) => + opts.next({ + ctx: { + ...opts.ctx, + marketSDK: { + headers: mockMarketSDKHeaders, + }, + }, + }), + ), + marketUserInfo: vi.fn((opts: any) => opts.next({ ctx: opts.ctx })), + serverDatabase: vi.fn((opts: any) => opts.next({ ctx: opts.ctx })), +})); + +describe('socialProfileRouter.submitRepo', () => { + let fetchSpy: ReturnType; + + beforeEach(() => { + vi.clearAllMocks(); + fetchSpy = vi.spyOn(globalThis, 'fetch' as never); + fetchSpy.mockResolvedValue( + new Response(JSON.stringify({ message: 'submitted' }), { + status: 200, + }), + ); + }); + + it('attributes workspace skill repo submissions to the acting organization account', async () => { + const caller = socialProfileRouter.createCaller({ userId: 'user-1' } as any); + + await caller.submitRepo({ + actAs: 123, + gitUrl: 'https://github.com/lobehub/example-skill', + type: 'skill', + }); + + const call = fetchSpy.mock.calls[0] as [string, RequestInit] | undefined; + expect(String(call?.[0])).toMatch(/\/api\/v1\/user\/claims\/submit-repo$/); + expect((call?.[1]?.headers as Record)['x-lobe-owner-account-id']).toBe('123'); + }); +}); diff --git a/src/server/routers/lambda/market/socialProfile.ts b/src/server/routers/lambda/market/socialProfile.ts index 4073765dc7..983045ea0a 100644 --- a/src/server/routers/lambda/market/socialProfile.ts +++ b/src/server/routers/lambda/market/socialProfile.ts @@ -179,6 +179,7 @@ export const socialProfileRouter = router({ submitRepo: socialProfileAuthProcedure .input( z.object({ + actAs: z.number().int().positive().optional(), branch: z.string().optional(), gitUrl: z.string().url(), type: z.enum(['skill', 'plugin']).default('skill'), @@ -200,6 +201,9 @@ export const socialProfileRouter = router({ headers: { ...headers, 'Content-Type': 'application/json', + ...(input.actAs === undefined + ? {} + : { 'x-lobe-owner-account-id': String(input.actAs) }), }, method: 'POST', }); diff --git a/src/server/routers/lambda/message.ts b/src/server/routers/lambda/message.ts index 677e32e6e1..a4e1a2215f 100644 --- a/src/server/routers/lambda/message.ts +++ b/src/server/routers/lambda/message.ts @@ -8,10 +8,15 @@ import { createTimingHelpers, createTimingRequestId } from '@lobechat/utils'; import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { + cloudWorkspaceAuth, + wsCompatProcedure, +} from '@/business/server/trpc-middlewares/workspaceAuth'; import { MessageModel } from '@/database/models/message'; import { TopicShareModel } from '@/database/models/topicShare'; import { CompressionRepository } from '@/database/repositories/compression'; -import { authedProcedure, publicProcedure, router } from '@/libs/trpc/lambda'; +import { publicProcedure, router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { MessageService } from '@/server/services/message'; @@ -21,21 +26,23 @@ import { basicContextSchema } from './_schema/context'; const { logTiming, runTimedStage } = createTimingHelpers('lobe-server:chat:lobehub:timing'); -const messageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const messageProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - compressionRepo: new CompressionRepository(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), - messageModel: new MessageModel(ctx.serverDB, ctx.userId), - messageService: new MessageService(ctx.serverDB, ctx.userId), + compressionRepo: new CompressionRepository(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), + messageModel: new MessageModel(ctx.serverDB, ctx.userId, wsId), + messageService: new MessageService(ctx.serverDB, ctx.userId, wsId), }, }); }); export const messageRouter = router({ addFilesToMessage: messageProcedure + .use(withScopedPermission('message:update')) .input( z .object({ @@ -46,7 +53,12 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { id, fileIds, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.addFilesToMessage(id, fileIds, resolved); }), @@ -55,6 +67,7 @@ export const messageRouter = router({ * Cancel compression by deleting the compression group and restoring original messages */ cancelCompression: messageProcedure + .use(withScopedPermission('message:update')) .input( z.object({ agentId: z.string(), @@ -122,6 +135,7 @@ export const messageRouter = router({ * Returns messages to summarize for frontend AI generation */ createCompressionGroup: messageProcedure + .use(withScopedPermission('message:update')) .input( z.object({ agentId: z.string(), @@ -143,12 +157,18 @@ export const messageRouter = router({ }), createMessage: messageProcedure + .use(withScopedPermission('message:create')) .input(CreateNewMessageParamsSchema) .mutation(async ({ input, ctx }) => { // If there's no agentId but has sessionId, resolve agentId from sessionId let agentId = input.agentId; if (!agentId && input.sessionId) { - agentId = (await resolveAgentIdFromSession(input.sessionId, ctx.serverDB, ctx.userId))!; + agentId = (await resolveAgentIdFromSession( + input.sessionId, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ))!; } // Create message with the resolved agentId @@ -159,6 +179,7 @@ export const messageRouter = router({ * Finalize compression by updating the group with generated summary */ finalizeCompression: messageProcedure + .use(withScopedPermission('message:update')) .input( z.object({ agentId: z.string(), @@ -184,6 +205,7 @@ export const messageRouter = router({ }), getMessages: publicProcedure + .use(cloudWorkspaceAuth) .use(serverDatabase) .input( z.object({ @@ -225,8 +247,9 @@ export const messageRouter = router({ throw new TRPCError({ code: 'UNAUTHORIZED', message: 'Authentication required' }); } - const messageModel = new MessageModel(ctx.serverDB, ctx.userId); - const fileService = new FileService(ctx.serverDB, ctx.userId); + const wsId = ctx.workspaceId ?? undefined; + const messageModel = new MessageModel(ctx.serverDB, ctx.userId, wsId); + const fileService = new FileService(ctx.serverDB, ctx.userId, wsId); return messageModel.query(queryParams, { postProcessUrl: (path, file) => fileService.getFileAccessUrl({ id: file.id, url: path }), @@ -237,11 +260,14 @@ export const messageRouter = router({ return ctx.messageModel.rankModels(); }), - removeAllMessages: messageProcedure.mutation(async ({ ctx }) => { - return ctx.messageModel.deleteAllMessages(); - }), + removeAllMessages: messageProcedure + .use(withScopedPermission('message:delete')) + .mutation(async ({ ctx }) => { + return ctx.messageModel.deleteAllMessages(); + }), removeMessage: messageProcedure + .use(withScopedPermission('message:delete')) .input( z .object({ @@ -251,18 +277,25 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { id, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.removeMessage(id, resolved); }), removeMessageQuery: messageProcedure + .use(withScopedPermission('message:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.messageModel.deleteMessageQuery(input.id); }), removeMessages: messageProcedure + .use(withScopedPermission('message:delete')) .input( z .object({ @@ -272,12 +305,18 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { ids, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.removeMessages(ids, resolved); }), removeMessagesByAssistant: messageProcedure + .use(withScopedPermission('message:delete')) .input( z .object({ @@ -287,7 +326,12 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageModel.deleteMessagesBySession( resolved.sessionId, @@ -297,6 +341,7 @@ export const messageRouter = router({ }), removeMessagesByGroup: messageProcedure + .use(withScopedPermission('message:delete')) .input( z.object({ groupId: z.string(), @@ -314,6 +359,7 @@ export const messageRouter = router({ }), update: messageProcedure + .use(withScopedPermission('message:update')) .input( z .object({ @@ -334,7 +380,13 @@ export const messageRouter = router({ const resolved = await runTimedStage( timingContext, 'lambda.message.update.resolveContext', - () => resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId), + () => + resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ), { hasAgentId: !!agentId }, ); @@ -361,6 +413,7 @@ export const messageRouter = router({ * Update message group metadata (e.g., expanded state) */ updateMessageGroupMetadata: messageProcedure + .use(withScopedPermission('message:update')) .input( z.object({ context: z.object({ @@ -380,6 +433,7 @@ export const messageRouter = router({ }), updateMessagePlugin: messageProcedure + .use(withScopedPermission('message:update')) .input( z .object({ @@ -390,21 +444,33 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { id, value, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.updateMessagePlugin(id, value, resolved); }), updateMessageRAG: messageProcedure + .use(withScopedPermission('message:update')) .input(UpdateMessageRAGParamsSchema.extend(basicContextSchema.shape)) .mutation(async ({ input, ctx }) => { const { id, value, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.updateMessageRAG(id, value, resolved); }), updateMetadata: messageProcedure + .use(withScopedPermission('message:update')) .input( z .object({ @@ -415,12 +481,18 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { id, value, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.updateMetadata(id, value, resolved); }), updatePluginError: messageProcedure + .use(withScopedPermission('message:update')) .input( z .object({ @@ -431,12 +503,18 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { id, value, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.updatePluginError(id, value, resolved); }), updatePluginState: messageProcedure + .use(withScopedPermission('message:update')) .input( z .object({ @@ -447,12 +525,18 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { id, value, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.updatePluginState(id, value, resolved); }), updateTTS: messageProcedure + .use(withScopedPermission('message:update')) .input( z.object({ id: z.string(), @@ -474,6 +558,7 @@ export const messageRouter = router({ }), updateToolArguments: messageProcedure + .use(withScopedPermission('message:update')) .input( z .object({ @@ -484,7 +569,12 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { toolCallId, value, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.updateToolArguments(toolCallId, value, resolved); }), @@ -494,6 +584,7 @@ export const messageRouter = router({ * This prevents race conditions when updating multiple fields */ updateToolMessage: messageProcedure + .use(withScopedPermission('message:update')) .input( z .object({ @@ -509,11 +600,17 @@ export const messageRouter = router({ ) .mutation(async ({ input, ctx }) => { const { id, value, agentId, ...options } = input; - const resolved = await resolveContext({ agentId, ...options }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId, ...options }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return ctx.messageService.updateToolMessage(id, value, resolved); }), updateTranslate: messageProcedure + .use(withScopedPermission('message:update')) .input( z.object({ id: z.string(), diff --git a/src/server/routers/lambda/messenger.ts b/src/server/routers/lambda/messenger.ts index 9d28252259..4cfa4fb473 100644 --- a/src/server/routers/lambda/messenger.ts +++ b/src/server/routers/lambda/messenger.ts @@ -3,6 +3,7 @@ import { TRPCError } from '@trpc/server'; import { and, desc, eq, ne, or } from 'drizzle-orm'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; import { getEnabledMessengerPlatforms, getMessengerDiscordConfig, @@ -18,10 +19,14 @@ import { } from '@/database/models/messengerAccountLink'; import type { DecryptedMessengerInstallation } from '@/database/models/messengerInstallation'; import { MessengerInstallationModel } from '@/database/models/messengerInstallation'; +import { RbacModel } from '@/database/models/rbac'; +import { WorkspaceModel } from '@/database/models/workspace'; import { agents, users } from '@/database/schemas'; import type { LobeChatDatabase } from '@/database/type'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; import { authedProcedure, publicProcedure, router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; +import { getServerFeatureFlagsStateFromRuntimeConfig } from '@/server/featureFlags'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import { SlackApi } from '@/server/services/bot/platforms/slack/api'; import { @@ -54,6 +59,22 @@ const extractSlackAuthErrorCode = (error: unknown): string | null => { return match?.[1] ?? null; }; +const WORKSPACE_FEATURE_DISABLED_MESSAGE = 'Workspace feature is not enabled for this user'; + +const isWorkspaceFeatureEnabledForUser = async (userId: string): Promise => { + const featureFlags = await getServerFeatureFlagsStateFromRuntimeConfig(userId); + return featureFlags.enableWorkspace === true; +}; + +const assertWorkspaceFeatureEnabledForUser = async (userId: string): Promise => { + if (await isWorkspaceFeatureEnabledForUser(userId)) return; + + throw new TRPCError({ + code: 'FORBIDDEN', + message: WORKSPACE_FEATURE_DISABLED_MESSAGE, + }); +}; + const reconcileSlackInstallation = async ( serverDB: LobeChatDatabase, row: DecryptedMessengerInstallation, @@ -95,10 +116,70 @@ const messengerProcedure = authedProcedure.use(serverDatabase).use(async (opts) const { ctx } = opts; return opts.next({ ctx: { + // The System Bot is a single shared bot; the workspace a conversation + // runs in is derived from the *active agent*, not the ambient + // `X-Workspace-Id` header. So the link model is identity-scoped (by + // userId), and per-agent authorization happens in-handler via + // `resolveAuthorizedAgentScope`. messengerLinkModel: new MessengerAccountLinkModel(ctx.serverDB, ctx.userId), }, }); }); +const messengerWriteProcedure = messengerProcedure.use(withScopedPermission('agent:update')); + +/** + * Resolve the workspace scope of an agent the user wants to route the System + * Bot to, authorizing access along the way. Because the bot is shared and + * which LobeHub context a conversation runs in is derived from the *active + * agent*, every place that sets the active agent must re-authorize against + * that agent's own workspace: + * + * - personal agent (`workspace_id IS NULL`): must be owned by the caller. + * - workspace agent: caller must be a member AND hold `agent:update` in that + * workspace (mirrors `withScopedPermission('agent:update')`). + * + * Returns the derived `workspaceId` (null for personal) + the agent title. + * Throws `NOT_FOUND` / `FORBIDDEN`. + */ +const resolveAuthorizedAgentScope = async ( + serverDB: LobeChatDatabase, + userId: string, + agentId: string, +): Promise<{ title: string | null; workspaceId: string | null }> => { + const [agentRow] = await serverDB + .select({ title: agents.title, userId: agents.userId, workspaceId: agents.workspaceId }) + .from(agents) + .where(eq(agents.id, agentId)) + .limit(1); + if (!agentRow) { + throw new TRPCError({ code: 'NOT_FOUND', message: 'messenger.error.agentNotFound' }); + } + + // Personal agent — only the owner may route the bot to it. + if (!agentRow.workspaceId) { + if (agentRow.userId !== userId) { + throw new TRPCError({ code: 'NOT_FOUND', message: 'messenger.error.agentNotFound' }); + } + return { title: agentRow.title, workspaceId: null }; + } + + // Workspace agent — caller must be a member with `agent:update`. + await assertWorkspaceFeatureEnabledForUser(userId); + + const userWorkspaces = await new WorkspaceModel(serverDB, userId).listUserWorkspaces(); + const isMember = userWorkspaces.some((w) => w.id === agentRow.workspaceId); + if (!isMember) { + throw new TRPCError({ code: 'FORBIDDEN', message: 'messenger.error.agentNotFound' }); + } + const allowed = await new RbacModel(serverDB, userId).hasAnyPermission( + ['agent:update:all', 'agent:update:owner'], + { workspaceId: agentRow.workspaceId }, + ); + if (!allowed) { + throw new TRPCError({ code: 'FORBIDDEN', message: 'messenger.error.agentNotFound' }); + } + return { title: agentRow.title, workspaceId: agentRow.workspaceId }; +}; export const messengerRouter = router({ /** @@ -280,14 +361,13 @@ export const messengerRouter = router({ }); } - const [agentRow] = await ctx.serverDB - .select({ id: agents.id, title: agents.title }) - .from(agents) - .where(and(eq(agents.id, input.initialAgentId), eq(agents.userId, ctx.userId))) - .limit(1); - if (!agentRow) { - throw new TRPCError({ code: 'NOT_FOUND', message: 'messenger.error.agentNotFound' }); - } + // Authorize the chosen initial agent against its own workspace (personal + // or a workspace the user can access) and derive the active scope. + const agentScope = await resolveAuthorizedAgentScope( + ctx.serverDB, + ctx.userId, + input.initialAgentId, + ); // Now safe to consume — token is single-use; do this last so any error // above leaves the token available for retry. @@ -304,11 +384,12 @@ export const messengerRouter = router({ let link; try { link = await ctx.messengerLinkModel.upsertForPlatform({ - activeAgentId: agentRow.id, + activeAgentId: input.initialAgentId, platform: payload.platform, platformUserId: payload.platformUserId, platformUsername: payload.platformUsername ?? null, tenantId: payload.tenantId ?? '', + workspaceId: agentScope.workspaceId, }); } catch (error) { // Race backstop: the IM identity got bound to another LobeHub user @@ -331,7 +412,7 @@ export const messengerRouter = router({ // Best-effort confirmation back to the IM platform. void notifyLinkSuccess(payload.platform, { - activeAgentName: agentRow.title ?? undefined, + activeAgentName: agentScope.title ?? undefined, platformUserId: payload.platformUserId, tenantId: payload.tenantId, }); @@ -352,42 +433,80 @@ export const messengerRouter = router({ * inbox agent resolves to `"LobeAI"` + default avatar; everything else falls * back on the client via `common.defaultSession`. */ - listAgentsForBinding: messengerProcedure.query(async ({ ctx }) => { - const rows = await ctx.serverDB - .select({ - avatar: agents.avatar, - backgroundColor: agents.backgroundColor, - id: agents.id, - slug: agents.slug, - title: agents.title, - }) - .from(agents) - .where( - and( - eq(agents.userId, ctx.userId), - or(ne(agents.virtual, true), eq(agents.slug, INBOX_SESSION_ID)), - ), - ) - .orderBy(desc(agents.updatedAt)); + listAgentsForBinding: messengerProcedure + .input(z.object({ workspaceId: z.string().nullish() }).optional()) + .query(async ({ ctx, input }) => { + const { serverDB, userId } = ctx; + // Cascading scope: the caller picks a scope (personal or one of the + // workspaces they belong to) and we return just that scope's agents. + // Omitting `workspaceId` (or `null`) means personal. + const workspaceId = input?.workspaceId ?? null; - const mapped = rows - .filter((row) => row.id) - .map((row) => ({ - avatar: row.avatar || (row.slug === INBOX_SESSION_ID ? DEFAULT_INBOX_AVATAR : null), - backgroundColor: row.backgroundColor, - id: row.id, - slug: row.slug, - title: row.title || (row.slug === INBOX_SESSION_ID ? 'LobeAI' : null), + // Authorize the requested scope. Personal is always the caller's own; a + // workspace scope requires membership, otherwise the caller could + // enumerate another workspace's agents. + if (workspaceId) { + await assertWorkspaceFeatureEnabledForUser(userId); + + const userWorkspaces = await new WorkspaceModel(serverDB, userId).listUserWorkspaces(); + if (!userWorkspaces.some((w) => w.id === workspaceId)) { + throw new TRPCError({ code: 'FORBIDDEN', message: 'messenger.error.agentNotFound' }); + } + } + + const rows = await serverDB + .select({ + avatar: agents.avatar, + backgroundColor: agents.backgroundColor, + id: agents.id, + slug: agents.slug, + title: agents.title, + }) + .from(agents) + .where( + and( + buildWorkspaceWhere({ userId, workspaceId: workspaceId ?? undefined }, agents), + or(ne(agents.virtual, true), eq(agents.slug, INBOX_SESSION_ID)), + ), + ) + .orderBy(desc(agents.updatedAt)); + + const mapped = rows + .filter((row) => row.id) + .map((row) => ({ + avatar: row.avatar || (row.slug === INBOX_SESSION_ID ? DEFAULT_INBOX_AVATAR : null), + backgroundColor: row.backgroundColor, + id: row.id, + slug: row.slug, + title: row.title || (row.slug === INBOX_SESSION_ID ? 'LobeAI' : null), + })); + + // Pin the inbox/LobeAI agent to the top regardless of updatedAt — it's + // the implicit "default" agent and should always be the first option. + const inboxIdx = mapped.findIndex((row) => row.slug === INBOX_SESSION_ID); + if (inboxIdx > 0) { + const [inbox] = mapped.splice(inboxIdx, 1); + mapped.unshift(inbox); + } + return mapped.map(({ slug, ...rest }) => ({ + ...rest, + isInbox: slug === INBOX_SESSION_ID, })); + }), - // Pin the inbox/LobeAI agent to the top regardless of updatedAt — it's the - // implicit "default" agent and should always be the first option. - const inboxIdx = mapped.findIndex((row) => row.slug === INBOX_SESSION_ID); - if (inboxIdx > 0) { - const [inbox] = mapped.splice(inboxIdx, 1); - mapped.unshift(inbox); - } - return mapped.map(({ slug: _slug, ...rest }) => rest); + /** + * List the scopes the user can route the System Bot to: their personal space + * plus every workspace they belong to. Drives the connection card's + * first-level "scope" selector; the second level then calls + * `listAgentsForBinding({ workspaceId })` for the picked scope. Personal scope + * is implicit (the client prepends it) — this only returns workspaces, so in + * OSS / personal-only deployments it's simply an empty array. + */ + listBindingScopes: messengerProcedure.query(async ({ ctx }) => { + if (!(await isWorkspaceFeatureEnabledForUser(ctx.userId))) return []; + + const workspaces = await new WorkspaceModel(ctx.serverDB, ctx.userId).listUserWorkspaces(); + return workspaces.map((w) => ({ avatar: w.avatar, id: w.id, name: w.name })); }), /** @@ -419,21 +538,19 @@ export const messengerRouter = router({ }), ) .mutation(async ({ input, ctx }) => { - // Validate ownership when setting a non-null agent. + // Authorize the target agent against its own workspace and derive the + // active scope (personal → null). Clearing (`agentId: null`) resets to + // personal scope. + let workspaceId: string | null = null; if (input.agentId !== null) { - const [agentRow] = await ctx.serverDB - .select({ id: agents.id }) - .from(agents) - .where(and(eq(agents.id, input.agentId), eq(agents.userId, ctx.userId))) - .limit(1); - if (!agentRow) { - throw new TRPCError({ code: 'NOT_FOUND', message: 'messenger.error.agentNotFound' }); - } + const scope = await resolveAuthorizedAgentScope(ctx.serverDB, ctx.userId, input.agentId); + workspaceId = scope.workspaceId; } const updated = await ctx.messengerLinkModel.setActiveAgent( input.platform, input.agentId, + workspaceId, input.tenantId, ); if (!updated) { @@ -446,7 +563,7 @@ export const messengerRouter = router({ }), /** Remove the user's account link for a platform (optionally scoped to one tenant). */ - unlink: messengerProcedure + unlink: messengerWriteProcedure .input(z.object({ platform: platformEnum, tenantId: z.string().optional() })) .mutation(async ({ input, ctx }) => { if (!(await isMessengerPlatformEnabled(input.platform))) { @@ -509,7 +626,7 @@ export const messengerRouter = router({ * Slack's `auth.revoke` to invalidate the token server-side is a * nice-to-have (frees a workspace bot slot), deferred to PR3. */ - uninstallInstallation: messengerProcedure + uninstallInstallation: messengerWriteProcedure .input(z.object({ installationId: z.string().min(1) })) .mutation(async ({ input, ctx }) => { const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey().catch(() => undefined); diff --git a/src/server/routers/lambda/notebook.ts b/src/server/routers/lambda/notebook.ts index 52efbb3b2a..2ac9581eb0 100644 --- a/src/server/routers/lambda/notebook.ts +++ b/src/server/routers/lambda/notebook.ts @@ -1,26 +1,34 @@ import { type NotebookDocument } from '@lobechat/types'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { DocumentModel } from '@/database/models/document'; import { TopicDocumentModel } from '@/database/models/topicDocument'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { NotebookRuntimeService } from '@/server/services/notebook'; -const notebookProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const notebookProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - documentModel: new DocumentModel(ctx.serverDB, ctx.userId), - notebookService: new NotebookRuntimeService({ serverDB: ctx.serverDB, userId: ctx.userId }), - topicDocumentModel: new TopicDocumentModel(ctx.serverDB, ctx.userId), + documentModel: new DocumentModel(ctx.serverDB, ctx.userId, wsId), + notebookService: new NotebookRuntimeService({ + serverDB: ctx.serverDB, + userId: ctx.userId, + workspaceId: wsId, + }), + topicDocumentModel: new TopicDocumentModel(ctx.serverDB, ctx.userId, wsId), }, }); }); export const notebookRouter = router({ createDocument: notebookProcedure + .use(withScopedPermission('document:create')) .input( z.object({ content: z.string(), @@ -60,6 +68,7 @@ export const notebookRouter = router({ }), deleteDocument: notebookProcedure + .use(withScopedPermission('document:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { await ctx.notebookService.deleteDocument(input.id); @@ -104,6 +113,7 @@ export const notebookRouter = router({ }), updateDocument: notebookProcedure + .use(withScopedPermission('document:update')) .input( z.object({ append: z.boolean().optional(), diff --git a/src/server/routers/lambda/notification.ts b/src/server/routers/lambda/notification.ts index d4f50dedb0..4c23323f25 100644 --- a/src/server/routers/lambda/notification.ts +++ b/src/server/routers/lambda/notification.ts @@ -1,25 +1,32 @@ import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { NotificationModel } from '@/database/models/notification'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; -const notificationProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const notificationProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; return opts.next({ - ctx: { notificationModel: new NotificationModel(ctx.serverDB, ctx.userId) }, + ctx: { + notificationModel: new NotificationModel(ctx.serverDB, ctx.userId), + }, }); }); +const notificationWriteProcedure = notificationProcedure.use( + withScopedPermission('message:create'), +); export const notificationRouter = router({ - archive: notificationProcedure + archive: notificationWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { return ctx.notificationModel.archive(input.id); }), - archiveAll: notificationProcedure.mutation(async ({ ctx }) => { + archiveAll: notificationWriteProcedure.mutation(async ({ ctx }) => { return ctx.notificationModel.archiveAll(); }), @@ -36,11 +43,11 @@ export const notificationRouter = router({ return ctx.notificationModel.list(input); }), - markAllAsRead: notificationProcedure.mutation(async ({ ctx }) => { + markAllAsRead: notificationWriteProcedure.mutation(async ({ ctx }) => { return ctx.notificationModel.markAllAsRead(); }), - markAsRead: notificationProcedure + markAsRead: notificationWriteProcedure .input(z.object({ ids: z.array(z.string()).min(1) })) .mutation(async ({ ctx, input }) => { return ctx.notificationModel.markAsRead(input.ids); diff --git a/src/server/routers/lambda/oauthDeviceFlow.ts b/src/server/routers/lambda/oauthDeviceFlow.ts index 5cd1c3b3ea..5876ac3edd 100644 --- a/src/server/routers/lambda/oauthDeviceFlow.ts +++ b/src/server/routers/lambda/oauthDeviceFlow.ts @@ -2,8 +2,10 @@ import { TRPCError } from '@trpc/server'; import { DEFAULT_MODEL_PROVIDER_LIST } from 'model-bank/modelProviders'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AiProviderModel } from '@/database/models/aiProvider'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import { @@ -11,8 +13,9 @@ import { GithubCopilotOAuthService, } from '@/server/services/oauthDeviceFlow/providers/githubCopilot'; -const oauthProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const oauthProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); return opts.next({ @@ -22,6 +25,7 @@ const oauthProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { }, }); }); +const oauthWriteProcedure = oauthProcedure.use(withScopedPermission('ai_provider:update')); /** * Get OAuth Device Flow config for a provider @@ -70,7 +74,7 @@ export const oauthDeviceFlowRouter = router({ /** * Initiate OAuth Device Flow - request a device code */ - initiateDeviceCode: oauthProcedure + initiateDeviceCode: oauthWriteProcedure .input(z.object({ providerId: z.string() })) .mutation(async ({ input }) => { const config = getOAuthConfig(input.providerId); @@ -97,7 +101,7 @@ export const oauthDeviceFlowRouter = router({ /** * Poll for authorization status and exchange tokens if authorized */ - pollAuthStatus: oauthProcedure + pollAuthStatus: oauthWriteProcedure .input( z.object({ deviceCode: z.string(), @@ -177,7 +181,7 @@ export const oauthDeviceFlowRouter = router({ /** * Revoke OAuth authorization for a provider */ - revokeAuth: oauthProcedure + revokeAuth: oauthWriteProcedure .input(z.object({ providerId: z.string() })) .mutation(async ({ input, ctx }) => { // Clear OAuth tokens and user info from keyVaults diff --git a/src/server/routers/lambda/plugin.ts b/src/server/routers/lambda/plugin.ts index 37c4bccde3..e31db9f60d 100644 --- a/src/server/routers/lambda/plugin.ts +++ b/src/server/routers/lambda/plugin.ts @@ -1,21 +1,24 @@ import { type LobeTool } from '@lobechat/types'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { PluginModel } from '@/database/models/plugin'; -import { getServerDB } from '@/database/server'; -import { authedProcedure, publicProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; -const pluginProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const pluginProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ - ctx: { pluginModel: new PluginModel(ctx.serverDB, ctx.userId) }, + ctx: { pluginModel: new PluginModel(ctx.serverDB, ctx.userId, wsId) }, }); }); export const pluginRouter = router({ createOrInstallPlugin: pluginProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ customParams: z.any(), @@ -46,6 +49,7 @@ export const pluginRouter = router({ }), createPlugin: pluginProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ customParams: z.any(), @@ -65,27 +69,27 @@ export const pluginRouter = router({ return data.identifier; }), - // TODO: In the future, this method also needs to use authedProcedure - getPlugins: publicProcedure.query(async ({ ctx }): Promise => { - if (!ctx.userId) return []; - - const serverDB = await getServerDB(); - const pluginModel = new PluginModel(serverDB, ctx.userId); + getPlugins: wsCompatProcedure.use(serverDatabase).query(async ({ ctx }): Promise => { + const pluginModel = new PluginModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined); return pluginModel.query(); }), - removeAllPlugins: pluginProcedure.mutation(async ({ ctx }) => { - return ctx.pluginModel.deleteAll(); - }), + removeAllPlugins: pluginProcedure + .use(withScopedPermission('agent:update')) + .mutation(async ({ ctx }) => { + return ctx.pluginModel.deleteAll(); + }), removePlugin: pluginProcedure + .use(withScopedPermission('agent:update')) .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.pluginModel.delete(input.id); }), updatePlugin: pluginProcedure + .use(withScopedPermission('agent:update')) .input( z.object({ customParams: z.any().optional(), diff --git a/src/server/routers/lambda/ragEval.ts b/src/server/routers/lambda/ragEval.ts index 110058dce1..f65cb647ca 100644 --- a/src/server/routers/lambda/ragEval.ts +++ b/src/server/routers/lambda/ragEval.ts @@ -15,6 +15,8 @@ import JSONL from 'jsonl-parse-stringify'; import pMap from 'p-map'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { DEFAULT_EMBEDDING_MODEL, DEFAULT_MODEL } from '@/const/settings'; import { FileModel } from '@/database/models/file'; import { @@ -23,28 +25,30 @@ import { EvalEvaluationModel, EvaluationRecordModel, } from '@/database/models/ragEval'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { createAsyncCaller } from '@/server/routers/async'; import { FileService } from '@/server/services/file'; -const ragEvalProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const ragEvalProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - datasetModel: new EvalDatasetModel(ctx.serverDB, ctx.userId), - fileModel: new FileModel(ctx.serverDB, ctx.userId), - datasetRecordModel: new EvalDatasetRecordModel(ctx.serverDB, ctx.userId), - evaluationModel: new EvalEvaluationModel(ctx.serverDB, ctx.userId), - evaluationRecordModel: new EvaluationRecordModel(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), + datasetModel: new EvalDatasetModel(ctx.serverDB, ctx.userId, wsId), + fileModel: new FileModel(ctx.serverDB, ctx.userId, wsId), + datasetRecordModel: new EvalDatasetRecordModel(ctx.serverDB, ctx.userId, wsId), + evaluationModel: new EvalEvaluationModel(ctx.serverDB, ctx.userId, wsId), + evaluationRecordModel: new EvaluationRecordModel(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), }, }); }); +const ragEvalWriteProcedure = ragEvalProcedure.use(withScopedPermission('knowledge_base:update')); export const ragEvalRouter = router({ - createDataset: ragEvalProcedure + createDataset: ragEvalWriteProcedure .input( z.object({ description: z.string().optional(), @@ -69,13 +73,13 @@ export const ragEvalRouter = router({ return ctx.datasetModel.query(input.knowledgeBaseId); }), - removeDataset: ragEvalProcedure + removeDataset: ragEvalWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.datasetModel.delete(input.id); }), - updateDataset: ragEvalProcedure + updateDataset: ragEvalWriteProcedure .input( z.object({ id: z.string(), @@ -87,7 +91,7 @@ export const ragEvalRouter = router({ }), // Dataset Item operations - createDatasetRecords: ragEvalProcedure + createDatasetRecords: ragEvalWriteProcedure .input( z.object({ datasetId: z.string(), @@ -108,13 +112,13 @@ export const ragEvalRouter = router({ return ctx.datasetRecordModel.query(input.datasetId); }), - removeDatasetRecords: ragEvalProcedure + removeDatasetRecords: ragEvalWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.datasetRecordModel.delete(input.id); }), - updateDatasetRecords: ragEvalProcedure + updateDatasetRecords: ragEvalWriteProcedure .input( z.object({ id: z.string(), @@ -132,7 +136,7 @@ export const ragEvalRouter = router({ return ctx.datasetRecordModel.update(input.id, input.value); }), - importDatasetRecords: ragEvalProcedure + importDatasetRecords: ragEvalWriteProcedure .input( z.object({ datasetId: z.string(), @@ -170,7 +174,7 @@ export const ragEvalRouter = router({ }), // Evaluation operations - startEvaluationTask: ragEvalProcedure + startEvaluationTask: ragEvalWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { // Start evaluation task @@ -272,7 +276,7 @@ export const ragEvalRouter = router({ return { success: isSuccess }; }), - createEvaluation: ragEvalProcedure + createEvaluation: ragEvalWriteProcedure .input(insertEvalEvaluationSchema) .mutation(async ({ input, ctx }) => { const data = await ctx.evaluationModel.create({ @@ -285,7 +289,7 @@ export const ragEvalRouter = router({ return data?.id; }), - removeEvaluation: ragEvalProcedure + removeEvaluation: ragEvalWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.evaluationModel.delete(input.id); diff --git a/src/server/routers/lambda/recent.ts b/src/server/routers/lambda/recent.ts index 5e666a9ae5..5765f89c9b 100644 --- a/src/server/routers/lambda/recent.ts +++ b/src/server/routers/lambda/recent.ts @@ -1,9 +1,10 @@ import type { TaskStatus } from '@lobechat/types'; import { z } from 'zod'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { SESSION_CHAT_TOPIC_URL } from '@/const/url'; import { RecentModel } from '@/database/models/recent'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import type { ChatTopicMetadata } from '@/types/topic'; @@ -20,11 +21,11 @@ export interface RecentItem { updatedAt: Date; } -const recentProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const recentProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; return opts.next({ ctx: { - recentModel: new RecentModel(ctx.serverDB, ctx.userId), + recentModel: new RecentModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined), }, }); }); diff --git a/src/server/routers/lambda/search.ts b/src/server/routers/lambda/search.ts index ac11d995cf..0b38484a2c 100644 --- a/src/server/routers/lambda/search.ts +++ b/src/server/routers/lambda/search.ts @@ -1,7 +1,8 @@ import { z } from 'zod'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { SearchRepo } from '@/database/repositories/search'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { DiscoverService } from '@/server/services/discover'; @@ -19,13 +20,14 @@ function calculateMarketplaceRelevance(query: string, title: string): number { return 4; } -const searchProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const searchProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { discoverService: new DiscoverService({ accessToken: ctx.marketAccessToken }), - searchRepo: new SearchRepo(ctx.serverDB, ctx.userId), + searchRepo: new SearchRepo(ctx.serverDB, ctx.userId, wsId), }, }); }); diff --git a/src/server/routers/lambda/session.ts b/src/server/routers/lambda/session.ts index 6525d5562d..2df87d2e2b 100644 --- a/src/server/routers/lambda/session.ts +++ b/src/server/routers/lambda/session.ts @@ -1,24 +1,26 @@ import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { ChatGroupModel } from '@/database/models/chatGroup'; import { SessionModel } from '@/database/models/session'; import { SessionGroupModel } from '@/database/models/sessionGroup'; import { insertAgentSchema, insertSessionSchema } from '@/database/schemas'; -import { getServerDB } from '@/database/server'; -import { authedProcedure, publicProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { AgentChatConfigSchema } from '@/types/agent'; import { LobeMetaDataSchema } from '@/types/meta'; import { type BatchTaskResult } from '@/types/service'; import { type ChatSessionList, type LobeGroupSession } from '@/types/session'; -const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const sessionProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId), - sessionModel: new SessionModel(ctx.serverDB, ctx.userId), + sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId, wsId), + sessionModel: new SessionModel(ctx.serverDB, ctx.userId, wsId), }, }); }); @@ -32,6 +34,7 @@ const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => export const sessionRouter = router({ /** @deprecated Use agent.createAgent instead */ batchCreateSessions: sessionProcedure + .use(withScopedPermission('session:create')) .input( z.array( z @@ -58,6 +61,7 @@ export const sessionRouter = router({ }), cloneSession: sessionProcedure + .use(withScopedPermission('session:create')) .input(z.object({ id: z.string(), newTitle: z.string() })) .mutation(async ({ input, ctx }) => { const data = await ctx.sessionModel.duplicate(input.id, input.newTitle); @@ -81,6 +85,7 @@ export const sessionRouter = router({ /** @deprecated Use agent.createAgent instead */ createSession: sessionProcedure + .use(withScopedPermission('session:create')) .input( z.object({ config: insertAgentSchema @@ -104,35 +109,36 @@ export const sessionRouter = router({ return data.id; }), - getGroupedSessions: publicProcedure.query(async ({ ctx }): Promise => { - const userId = ctx.userId; - if (!userId) return { sessionGroups: [], sessions: [] }; + getGroupedSessions: wsCompatProcedure + .use(serverDatabase) + .query(async ({ ctx }): Promise => { + const userId = ctx.userId; + const serverDB = ctx.serverDB; + const wsId = ctx.workspaceId ?? undefined; + const sessionModel = new SessionModel(serverDB, userId, wsId); + const chatGroupModel = new ChatGroupModel(serverDB, userId, wsId); - const serverDB = await getServerDB(); - const sessionModel = new SessionModel(serverDB, userId); - const chatGroupModel = new ChatGroupModel(serverDB, userId); + const [{ sessions, sessionGroups }, chatGroups] = await Promise.all([ + sessionModel.queryWithGroups(), + chatGroupModel.queryWithMemberDetails(), + ]); - const [{ sessions, sessionGroups }, chatGroups] = await Promise.all([ - sessionModel.queryWithGroups(), - chatGroupModel.queryWithMemberDetails(), - ]); + const groupSessions: LobeGroupSession[] = chatGroups.map((group) => { + const { title, description, avatar, backgroundColor, groupId, ...rest } = group; + return { + ...rest, + group: groupId, // Map groupId to group for consistent API + meta: { avatar, backgroundColor, description, title }, + type: 'group', + }; + }); - const groupSessions: LobeGroupSession[] = chatGroups.map((group) => { - const { title, description, avatar, backgroundColor, groupId, ...rest } = group; - return { - ...rest, - group: groupId, // Map groupId to group for consistent API - meta: { avatar, backgroundColor, description, title }, - type: 'group', - }; - }); + const allSessions = [...sessions, ...groupSessions].sort( + (a, b) => new Date(b.updatedAt).getTime() - new Date(a.updatedAt).getTime(), + ); - const allSessions = [...sessions, ...groupSessions].sort( - (a, b) => new Date(b.updatedAt).getTime() - new Date(a.updatedAt).getTime(), - ); - - return { sessionGroups, sessions: allSessions }; - }), + return { sessionGroups, sessions: allSessions }; + }), getSessions: sessionProcedure .input( @@ -147,11 +153,15 @@ export const sessionRouter = router({ return ctx.sessionModel.query({ current, pageSize }); }), - removeAllSessions: sessionProcedure.mutation(async ({ ctx }) => { - return ctx.sessionModel.deleteAll(); - }), + // Owner-only — bulk wipes everyone's sessions in the workspace. + removeAllSessions: sessionProcedure + .use(withScopedPermission('session:delete')) + .mutation(async ({ ctx }) => { + return ctx.sessionModel.deleteAll(); + }), removeSession: sessionProcedure + .use(withScopedPermission('session:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.sessionModel.delete(input.id); @@ -164,6 +174,7 @@ export const sessionRouter = router({ }), updateSession: sessionProcedure + .use(withScopedPermission('session:update')) .input( z.object({ id: z.string(), @@ -174,6 +185,7 @@ export const sessionRouter = router({ return ctx.sessionModel.update(input.id, input.value); }), updateSessionChatConfig: sessionProcedure + .use(withScopedPermission('session:update')) .input( z.object({ id: z.string(), @@ -186,6 +198,7 @@ export const sessionRouter = router({ }); }), updateSessionConfig: sessionProcedure + .use(withScopedPermission('session:update')) .input( z.object({ id: z.string(), diff --git a/src/server/routers/lambda/sessionGroup.ts b/src/server/routers/lambda/sessionGroup.ts index 059e315f5d..ce61df6b7c 100644 --- a/src/server/routers/lambda/sessionGroup.ts +++ b/src/server/routers/lambda/sessionGroup.ts @@ -1,23 +1,27 @@ import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { SessionGroupModel } from '@/database/models/sessionGroup'; import { insertSessionGroupSchema } from '@/database/schemas'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { type SessionGroupItem } from '@/types/session'; -const sessionProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const sessionProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId), + sessionGroupModel: new SessionGroupModel(ctx.serverDB, ctx.userId, wsId), }, }); }); export const sessionGroupRouter = router({ createSessionGroup: sessionProcedure + .use(withScopedPermission('session_group:create')) .input( z.object({ name: z.string(), @@ -37,17 +41,21 @@ export const sessionGroupRouter = router({ return ctx.sessionGroupModel.query() as any; }), - removeAllSessionGroups: sessionProcedure.mutation(async ({ ctx }) => { - return ctx.sessionGroupModel.deleteAll(); - }), + removeAllSessionGroups: sessionProcedure + .use(withScopedPermission('session_group:delete')) + .mutation(async ({ ctx }) => { + return ctx.sessionGroupModel.deleteAll(); + }), removeSessionGroup: sessionProcedure + .use(withScopedPermission('session_group:delete')) .input(z.object({ id: z.string(), removeChildren: z.boolean().optional() })) .mutation(async ({ input, ctx }) => { return ctx.sessionGroupModel.delete(input.id); }), updateSessionGroup: sessionProcedure + .use(withScopedPermission('session_group:update')) .input( z.object({ id: z.string(), @@ -58,6 +66,7 @@ export const sessionGroupRouter = router({ return ctx.sessionGroupModel.update(input.id, input.value); }), updateSessionGroupOrder: sessionProcedure + .use(withScopedPermission('session_group:update')) .input( z.object({ sortMap: z.array( diff --git a/src/server/routers/lambda/task.ts b/src/server/routers/lambda/task.ts index 1207c6c515..5daeb1a1d0 100644 --- a/src/server/routers/lambda/task.ts +++ b/src/server/routers/lambda/task.ts @@ -1,29 +1,45 @@ import { TASK_STATUSES } from '@lobechat/builtin-tool-task'; import type { TaskListItem, TaskParticipant } from '@lobechat/types'; import { TRPCError } from '@trpc/server'; +import { and, eq, isNull } from 'drizzle-orm'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentModel } from '@/database/models/agent'; import { BriefModel } from '@/database/models/brief'; import { TaskModel } from '@/database/models/task'; import { TaskTopicModel } from '@/database/models/taskTopic'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { TopicModel } from '@/database/models/topic'; +import { workspaceMembers } from '@/database/schemas'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { TaskService } from '@/server/services/task'; +import { TaskLifecycleService } from '@/server/services/taskLifecycle'; import { TaskRunnerService } from '@/server/services/taskRunner'; +import { TransferErrorCode } from '@/types/transferError'; -const taskProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const taskProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentModel: new AgentModel(ctx.serverDB, ctx.userId), - taskModel: new TaskModel(ctx.serverDB, ctx.userId), - taskService: new TaskService(ctx.serverDB, ctx.userId), - taskTopicModel: new TaskTopicModel(ctx.serverDB, ctx.userId), + agentModel: new AgentModel(ctx.serverDB, ctx.userId, wsId), + briefModel: new BriefModel(ctx.serverDB, ctx.userId, wsId), + taskLifecycle: new TaskLifecycleService(ctx.serverDB, ctx.userId, wsId), + taskModel: new TaskModel(ctx.serverDB, ctx.userId, wsId), + taskService: new TaskService(ctx.serverDB, ctx.userId, wsId), + taskTopicModel: new TaskTopicModel(ctx.serverDB, ctx.userId, wsId), + topicModel: new TopicModel(ctx.serverDB, ctx.userId, wsId), }, }); }); +// Write variant gates viewers out of every task mutation (create/update/delete/ +// run). Reads keep using `taskProcedure` so viewers can still inspect tasks +// and their status. +const taskProcedureWrite = taskProcedure.use(withScopedPermission('agent:update')); + // All procedures that take an id accept either raw id (task_xxx) or identifier (TASK-1) // Resolution happens in the model layer via model.resolve() const idInput = z.object({ id: z.string() }); @@ -150,7 +166,7 @@ async function resolveSafeParentTaskId( } export const taskRouter = router({ - reorderSubtasks: taskProcedure + reorderSubtasks: taskProcedureWrite .input( z.object({ id: z.string(), @@ -203,7 +219,7 @@ export const taskRouter = router({ } }), - addComment: taskProcedure + addComment: taskProcedureWrite .input( z.object({ authorAgentId: z.string().optional(), @@ -241,7 +257,7 @@ export const taskRouter = router({ } }), - deleteComment: taskProcedure + deleteComment: taskProcedureWrite .input(z.object({ commentId: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -261,7 +277,7 @@ export const taskRouter = router({ } }), - updateComment: taskProcedure + updateComment: taskProcedureWrite .input( z.object({ commentId: z.string(), @@ -289,7 +305,7 @@ export const taskRouter = router({ } }), - addDependency: taskProcedure + addDependency: taskProcedureWrite .input( z.object({ dependsOnId: z.string(), @@ -315,7 +331,7 @@ export const taskRouter = router({ } }), - cancelTopic: taskProcedure + cancelTopic: taskProcedureWrite .input(z.object({ topicId: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -332,7 +348,7 @@ export const taskRouter = router({ } }), - deleteTopic: taskProcedure + deleteTopic: taskProcedureWrite .input(z.object({ topicId: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -349,7 +365,7 @@ export const taskRouter = router({ } }), - create: taskProcedure.input(createSchema).mutation(async ({ input, ctx }) => { + create: taskProcedureWrite.input(createSchema).mutation(async ({ input, ctx }) => { try { const task = await ctx.taskService.createTask(input); return { data: task, message: 'Task created', success: true }; @@ -365,7 +381,7 @@ export const taskRouter = router({ } }), - clearAll: taskProcedure.mutation(async ({ ctx }) => { + clearAll: taskProcedureWrite.mutation(async ({ ctx }) => { try { const model = ctx.taskModel; const count = await model.deleteAll(); @@ -380,7 +396,7 @@ export const taskRouter = router({ } }), - delete: taskProcedure.input(idInput).mutation(async ({ input, ctx }) => { + delete: taskProcedureWrite.input(idInput).mutation(async ({ input, ctx }) => { try { const model = ctx.taskModel; const task = await resolveOrThrow(model, input.id); @@ -517,7 +533,7 @@ export const taskRouter = router({ } }), - heartbeat: taskProcedure.input(idInput).mutation(async ({ input, ctx }) => { + heartbeat: taskProcedureWrite.input(idInput).mutation(async ({ input, ctx }) => { try { const model = ctx.taskModel; const task = await resolveOrThrow(model, input.id); @@ -534,20 +550,21 @@ export const taskRouter = router({ } }), - watchdog: taskProcedure.mutation(async ({ ctx }) => { + watchdog: taskProcedureWrite.mutation(async ({ ctx }) => { try { const stuckTasks = await TaskModel.findStuckTasks(ctx.serverDB); const failed: string[] = []; for (const task of stuckTasks) { - const model = new TaskModel(ctx.serverDB, task.createdByUserId); + const wsId = task.workspaceId ?? undefined; + const model = new TaskModel(ctx.serverDB, task.createdByUserId, wsId); await model.updateStatus(task.id, 'failed', { completedAt: new Date(), error: 'Heartbeat timeout', }); // Create error brief - const briefModel = new BriefModel(ctx.serverDB, task.createdByUserId); + const briefModel = new BriefModel(ctx.serverDB, task.createdByUserId, wsId); await briefModel.create({ agentId: task.assigneeAgentId || undefined, priority: 'urgent', @@ -654,7 +671,7 @@ export const taskRouter = router({ } }), - run: taskProcedure + run: taskProcedureWrite .input( idInput.merge( z.object({ @@ -665,7 +682,11 @@ export const taskRouter = router({ ) .mutation(async ({ input, ctx }) => { try { - const runner = new TaskRunnerService(ctx.serverDB, ctx.userId); + const runner = new TaskRunnerService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); return await runner.runTask({ continueTopicId: input.continueTopicId, extraPrompt: input.prompt, @@ -682,7 +703,7 @@ export const taskRouter = router({ } }), - pinDocument: taskProcedure + pinDocument: taskProcedureWrite .input( z.object({ documentId: z.string(), @@ -707,7 +728,7 @@ export const taskRouter = router({ } }), - removeDependency: taskProcedure + removeDependency: taskProcedureWrite .input(z.object({ dependsOnId: z.string(), taskId: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -727,7 +748,7 @@ export const taskRouter = router({ } }), - unpinDocument: taskProcedure + unpinDocument: taskProcedureWrite .input(z.object({ documentId: z.string(), taskId: z.string() })) .mutation(async ({ input, ctx }) => { try { @@ -763,7 +784,7 @@ export const taskRouter = router({ } }), - updateCheckpoint: taskProcedure + updateCheckpoint: taskProcedureWrite .input( idInput.merge( z.object({ @@ -824,7 +845,7 @@ export const taskRouter = router({ } }), - updateReview: taskProcedure + updateReview: taskProcedureWrite .input( idInput.merge( z.object({ @@ -876,7 +897,7 @@ export const taskRouter = router({ } }), - runReview: taskProcedure + runReview: taskProcedureWrite .input( idInput.merge( z.object({ @@ -900,7 +921,7 @@ export const taskRouter = router({ } }), - update: taskProcedure.input(idInput.merge(updateSchema)).mutation(async ({ input, ctx }) => { + update: taskProcedureWrite.input(idInput.merge(updateSchema)).mutation(async ({ input, ctx }) => { const { id, parentTaskId, ...data } = input; try { const model = ctx.taskModel; @@ -926,7 +947,7 @@ export const taskRouter = router({ } }), - updateConfig: taskProcedure + updateConfig: taskProcedureWrite .input(idInput.merge(z.object({ config: z.record(z.unknown()) }))) .mutation(async ({ input, ctx }) => { const { id, config } = input; @@ -962,7 +983,7 @@ export const taskRouter = router({ } }), - runReadySubtasks: taskProcedure.input(idInput).mutation(async ({ input, ctx }) => { + runReadySubtasks: taskProcedureWrite.input(idInput).mutation(async ({ input, ctx }) => { try { const result = await ctx.taskService.runReadySubtasks(input.id); return { data: result, success: result.failed.length === 0 }; @@ -977,7 +998,7 @@ export const taskRouter = router({ } }), - updateStatus: taskProcedure + updateStatus: taskProcedureWrite .input( z.object({ error: z.string().optional(), @@ -1009,4 +1030,113 @@ export const taskRouter = router({ }); } }), + + transferTask: taskProcedureWrite + .input( + z.object({ + targetWorkspaceId: z.string().nullable(), + taskId: z.string(), + }), + ) + .mutation(async ({ ctx, input }) => { + const task = await ctx.taskModel.resolve(input.taskId); + if (!task) + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'Task not found', + }); + + if (ctx.workspaceId && task.createdByUserId !== ctx.userId) { + const [membership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, ctx.workspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!membership || membership.role !== 'owner') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.OwnerOnly } }, + code: 'FORBIDDEN', + message: 'Only workspace owners can transfer tasks created by others', + }); + } + } + + if (input.targetWorkspaceId === (ctx.workspaceId ?? null)) { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.SameWorkspace } }, + code: 'BAD_REQUEST', + message: 'Cannot transfer task to the same workspace', + }); + } + + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + return ctx.taskModel.transferTo(task.id, input.targetWorkspaceId, ctx.userId); + }), + + copyTaskToWorkspace: taskProcedureWrite + .input( + z.object({ + targetWorkspaceId: z.string().nullable(), + taskId: z.string(), + }), + ) + .mutation(async ({ ctx, input }) => { + const task = await ctx.taskModel.resolve(input.taskId); + if (!task) + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.ResourceNotFound } }, + code: 'NOT_FOUND', + message: 'Task not found', + }); + + if (input.targetWorkspaceId) { + const [targetMembership] = await ctx.serverDB + .select({ role: workspaceMembers.role }) + .from(workspaceMembers) + .where( + and( + eq(workspaceMembers.workspaceId, input.targetWorkspaceId), + eq(workspaceMembers.userId, ctx.userId), + isNull(workspaceMembers.deletedAt), + ), + ) + .limit(1); + if (!targetMembership || targetMembership.role === 'viewer') { + throw new TRPCError({ + cause: { data: { code: TransferErrorCode.TargetNoWriteAccess } }, + code: 'FORBIDDEN', + message: 'No write access to target workspace', + }); + } + } + + return ctx.taskModel.copyToWorkspace(task.id, input.targetWorkspaceId, ctx.userId); + }), }); diff --git a/src/server/routers/lambda/thread.ts b/src/server/routers/lambda/thread.ts index dcaec04d31..1b48b53cf6 100644 --- a/src/server/routers/lambda/thread.ts +++ b/src/server/routers/lambda/thread.ts @@ -1,10 +1,12 @@ import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { MessageModel } from '@/database/models/message'; import { ThreadModel } from '@/database/models/thread'; import { insertThreadSchema } from '@/database/schemas'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { type ThreadItem } from '@/types/topic/thread'; import { createThreadSchema } from '@/types/topic/thread'; @@ -33,35 +35,40 @@ const ensureThreadCreated = ( }); }; -const threadProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const threadProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - messageModel: new MessageModel(ctx.serverDB, ctx.userId), - threadModel: new ThreadModel(ctx.serverDB, ctx.userId), + messageModel: new MessageModel(ctx.serverDB, ctx.userId, wsId), + threadModel: new ThreadModel(ctx.serverDB, ctx.userId, wsId), }, }); }); export const threadRouter = router({ - createThread: threadProcedure.input(createThreadSchema).mutation(async ({ input, ctx }) => { - const thread = ensureThreadCreated( - await ctx.threadModel.create({ - id: input.id, - metadata: input.metadata, - parentThreadId: input.parentThreadId, - sourceMessageId: input.sourceMessageId, - title: input.title, - topicId: input.topicId, - type: input.type, - }), - input.id, - ); + createThread: threadProcedure + .use(withScopedPermission('topic:create')) + .input(createThreadSchema) + .mutation(async ({ input, ctx }) => { + const thread = ensureThreadCreated( + await ctx.threadModel.create({ + id: input.id, + metadata: input.metadata, + parentThreadId: input.parentThreadId, + sourceMessageId: input.sourceMessageId, + title: input.title, + topicId: input.topicId, + type: input.type, + }), + input.id, + ); - return thread.id; - }), + return thread.id; + }), createThreadWithMessage: threadProcedure + .use(withScopedPermission('topic:create')) .input( createThreadSchema.extend({ message: z.any(), @@ -95,17 +102,21 @@ export const threadRouter = router({ return ctx.threadModel.queryByTopicId(input.topicId); }), - removeAllThreads: threadProcedure.mutation(async ({ ctx }) => { - return ctx.threadModel.deleteAll(); - }), + removeAllThreads: threadProcedure + .use(withScopedPermission('topic:delete')) + .mutation(async ({ ctx }) => { + return ctx.threadModel.deleteAll(); + }), removeThread: threadProcedure + .use(withScopedPermission('topic:delete')) .input(z.object({ id: z.string(), removeChildren: z.boolean().optional() })) .mutation(async ({ input, ctx }) => { return ctx.threadModel.delete(input.id); }), updateThread: threadProcedure + .use(withScopedPermission('topic:update')) .input( z.object({ id: z.string(), diff --git a/src/server/routers/lambda/topic.ts b/src/server/routers/lambda/topic.ts index e85555d330..1856562a62 100644 --- a/src/server/routers/lambda/topic.ts +++ b/src/server/routers/lambda/topic.ts @@ -8,6 +8,8 @@ import { eq, inArray } from 'drizzle-orm'; import { after } from 'next/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentOperationModel } from '@/database/models/agentOperation'; import { MessageModel } from '@/database/models/message'; import { TopicModel } from '@/database/models/topic'; @@ -15,7 +17,7 @@ import { TopicShareModel } from '@/database/models/topicShare'; import { AgentMigrationRepo } from '@/database/repositories/agentMigration'; import { TopicImporterRepo } from '@/database/repositories/topicImporter'; import { agents, chatGroups, chatGroupsAgents } from '@/database/schemas'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { type BatchTaskResult } from '@/types/service'; @@ -26,16 +28,17 @@ import { } from './_helpers/resolveContext'; import { basicContextSchema } from './_schema/context'; -const topicProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const topicProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - agentMigrationRepo: new AgentMigrationRepo(ctx.serverDB, ctx.userId), - agentOperationModel: new AgentOperationModel(ctx.serverDB, ctx.userId), - topicImporterRepo: new TopicImporterRepo(ctx.serverDB, ctx.userId), - topicModel: new TopicModel(ctx.serverDB, ctx.userId), - topicShareModel: new TopicShareModel(ctx.serverDB, ctx.userId), + agentMigrationRepo: new AgentMigrationRepo(ctx.serverDB, ctx.userId, wsId), + agentOperationModel: new AgentOperationModel(ctx.serverDB, ctx.userId, wsId), + topicImporterRepo: new TopicImporterRepo(ctx.serverDB, ctx.userId, wsId), + topicModel: new TopicModel(ctx.serverDB, ctx.userId, wsId), + topicShareModel: new TopicShareModel(ctx.serverDB, ctx.userId, wsId), }, }); }); @@ -69,7 +72,8 @@ export const topicRouter = router({ } // Fallback: fetch recent messages with correct agentId/groupId - const messageModel = new MessageModel(ctx.serverDB, ctx.userId); + const wsId = ctx.workspaceId ?? undefined; + const messageModel = new MessageModel(ctx.serverDB, ctx.userId, wsId); const messages = await messageModel.query({ agentId: topic.agentId ?? undefined, groupId: topic.groupId ?? undefined, @@ -92,6 +96,7 @@ export const topicRouter = router({ }), batchCreateTopics: topicProcedure + .use(withScopedPermission('topic:create')) .input( z.array( z @@ -113,6 +118,7 @@ export const topicRouter = router({ { agentId, sessionId: rest.sessionId }, ctx.serverDB, ctx.userId, + ctx.workspaceId ?? undefined, ); return { ...rest, sessionId: resolved.sessionId }; }), @@ -124,18 +130,21 @@ export const topicRouter = router({ }), batchDelete: topicProcedure + .use(withScopedPermission('topic:delete')) .input(z.object({ ids: z.array(z.string()) })) .mutation(async ({ input, ctx }) => { return ctx.topicModel.batchDelete(input.ids); }), batchDeleteByAgentId: topicProcedure + .use(withScopedPermission('topic:delete')) .input(z.object({ agentId: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.topicModel.batchDeleteByAgentId(input.agentId); }), batchDeleteBySessionId: topicProcedure + .use(withScopedPermission('topic:delete')) .input( z.object({ agentId: z.string().optional(), @@ -147,12 +156,14 @@ export const topicRouter = router({ { agentId: input.agentId, sessionId: input.id }, ctx.serverDB, ctx.userId, + ctx.workspaceId ?? undefined, ); return ctx.topicModel.batchDeleteBySessionId(resolved.sessionId); }), cloneTopic: topicProcedure + .use(withScopedPermission('topic:create')) .input(z.object({ id: z.string(), newTitle: z.string().optional() })) .mutation(async ({ input, ctx }) => { const data = await ctx.topicModel.duplicate(input.id, input.newTitle); @@ -177,6 +188,7 @@ export const topicRouter = router({ }), createTopic: topicProcedure + .use(withScopedPermission('topic:create')) .input( z .object({ @@ -194,6 +206,7 @@ export const topicRouter = router({ { agentId, sessionId: rest.sessionId }, ctx.serverDB, ctx.userId, + ctx.workspaceId ?? undefined, ); const data = await ctx.topicModel.create({ ...rest, sessionId: resolved.sessionId }); @@ -205,6 +218,7 @@ export const topicRouter = router({ * Disable sharing for a topic (deletes share record) */ disableSharing: topicProcedure + .use(withScopedPermission('topic:update')) .input(z.object({ topicId: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.topicShareModel.deleteByTopicId(input.topicId); @@ -214,6 +228,7 @@ export const topicRouter = router({ * Enable sharing for a topic (creates share record) */ enableSharing: topicProcedure + .use(withScopedPermission('topic:update')) .input( z.object({ topicId: z.string(), @@ -288,7 +303,12 @@ export const topicRouter = router({ // If sessionId is provided but no agentId, need to reverse lookup agentId let effectiveAgentId = rest.agentId; if (!effectiveAgentId && sessionId) { - effectiveAgentId = await resolveAgentIdFromSession(sessionId, ctx.serverDB, ctx.userId); + effectiveAgentId = await resolveAgentIdFromSession( + sessionId, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); } const result = await ctx.topicModel.query({ @@ -310,6 +330,7 @@ export const topicRouter = router({ { agentId: effectiveAgentId }, ctx.serverDB, ctx.userId, + ctx.workspaceId ?? undefined, ); const migrationParams = isInbox @@ -338,6 +359,7 @@ export const topicRouter = router({ }), importTopic: topicProcedure + .use(withScopedPermission('topic:create')) .input( z.object({ agentId: z.string(), @@ -383,6 +405,7 @@ export const topicRouter = router({ sessionIds, ctx.serverDB, ctx.userId, + ctx.workspaceId ?? undefined, ); // Build agentId map: merge existing agentId with resolved ones @@ -517,11 +540,14 @@ export const topicRouter = router({ }); }), - removeAllTopics: topicProcedure.mutation(async ({ ctx }) => { - return ctx.topicModel.deleteAll(); - }), + removeAllTopics: topicProcedure + .use(withScopedPermission('topic:delete')) + .mutation(async ({ ctx }) => { + return ctx.topicModel.deleteAll(); + }), removeTopic: topicProcedure + .use(withScopedPermission('topic:delete')) .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.topicModel.delete(input.id); @@ -541,6 +567,7 @@ export const topicRouter = router({ { agentId: input.agentId, sessionId: input.sessionId }, ctx.serverDB, ctx.userId, + ctx.workspaceId ?? undefined, ); return ctx.topicModel.queryByKeyword(input.keywords, resolved.sessionId); @@ -550,6 +577,7 @@ export const topicRouter = router({ * Update share visibility */ updateShareVisibility: topicProcedure + .use(withScopedPermission('topic:update')) .input( z.object({ topicId: z.string(), @@ -561,6 +589,7 @@ export const topicRouter = router({ }), updateTopic: topicProcedure + .use(withScopedPermission('topic:update')) .input( z.object({ id: z.string(), @@ -599,7 +628,12 @@ export const topicRouter = router({ // If agentId is provided, resolve to sessionId let resolvedSessionId = restValue.sessionId; if (agentId && !resolvedSessionId) { - const resolved = await resolveContext({ agentId }, ctx.serverDB, ctx.userId); + const resolved = await resolveContext( + { agentId }, + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); resolvedSessionId = resolved.sessionId ?? undefined; } @@ -607,6 +641,7 @@ export const topicRouter = router({ }), updateTopicMetadata: topicProcedure + .use(withScopedPermission('topic:update')) .input( z.object({ id: z.string(), diff --git a/src/server/routers/lambda/upload.ts b/src/server/routers/lambda/upload.ts index dcb644c7a9..159173f6c7 100644 --- a/src/server/routers/lambda/upload.ts +++ b/src/server/routers/lambda/upload.ts @@ -1,10 +1,12 @@ import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; import { authedProcedure, router } from '@/libs/trpc/lambda'; import { FileS3 } from '@/server/modules/S3'; export const uploadRouter = router({ createS3PreSignedUrl: authedProcedure + .use(withScopedPermission('file:upload')) .input(z.object({ pathname: z.string() })) .mutation(async ({ input }) => { const s3 = new FileS3(); diff --git a/src/server/routers/lambda/usage.ts b/src/server/routers/lambda/usage.ts index 651614cb96..9fd08170eb 100644 --- a/src/server/routers/lambda/usage.ts +++ b/src/server/routers/lambda/usage.ts @@ -1,14 +1,19 @@ import { z } from 'zod'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { UsageRecordService } from '@/server/services/usage'; -const usageProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const usageProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; return opts.next({ ctx: { - usageRecordService: new UsageRecordService(ctx.serverDB, ctx.userId), + usageRecordService: new UsageRecordService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ), }, }); }); diff --git a/src/server/routers/lambda/user.ts b/src/server/routers/lambda/user.ts index 848850c0af..53770568eb 100644 --- a/src/server/routers/lambda/user.ts +++ b/src/server/routers/lambda/user.ts @@ -29,6 +29,7 @@ import { onUserActivityForBusiness, } from '@/business/server/user'; import { MessageModel } from '@/database/models/message'; +import { RbacModel } from '@/database/models/rbac'; import { SessionModel } from '@/database/models/session'; import { UserModel } from '@/database/models/user'; import { authedProcedure, router } from '@/libs/trpc/lambda'; @@ -47,6 +48,10 @@ const usernameSchema = z .regex(/^\w+$/, { message: 'USERNAME_INVALID' }); const AVATAR_WEBAPI_PREFIX = '/webapi/'; +const OWNER_SETTING_KEYS = ['defaultAgent', 'image', 'memory', 'systemAgent', 'tts'] as const; +const MEMBER_SETTING_KEYS = ['tool'] as const; +const WORKSPACE_UPDATE_PERMISSION = 'workspace:update:all'; +const WORKSPACE_CONTENT_PERMISSIONS = ['agent:update:all', 'agent:update:owner'] as const; // Accept only: base64 data URL, absolute http(s) URL, empty string, // or an internal /webapi/user/avatar//... path scoped to the caller. @@ -70,10 +75,19 @@ const assertSafeAvatarInput = (input: string, userId: string) => { throw new TRPCError({ code: 'BAD_REQUEST', message: 'INVALID_AVATAR_URL' }); }; +const hasOwnerSettingChange = (input: Partial) => + OWNER_SETTING_KEYS.some((key) => input[key] !== undefined); + +const hasMemberSettingChange = (input: Partial) => + MEMBER_SETTING_KEYS.some((key) => input[key] !== undefined); + const userProcedure = authedProcedure.use(serverDatabase).use(async ({ ctx, next }) => { return next({ ctx: { fileService: new FileService(ctx.serverDB, ctx.userId), + // workspace-audit: intentionally personal-scoped (no workspaceId). These models + // only feed `getUserState`'s user-lifetime onboarding gates (hasConversation / + // canEnablePWAGuide / canEnableTrace), which are per-user, not per-workspace. messageModel: new MessageModel(ctx.serverDB, ctx.userId), sessionModel: new SessionModel(ctx.serverDB, ctx.userId), userModel: new UserModel(ctx.serverDB, ctx.userId), @@ -271,7 +285,11 @@ export const userRouter = router({ getOnboardingAgentContext: userProcedure.query(async ({ ctx }) => { const onboardingService = new OnboardingService(ctx.serverDB, ctx.userId); - const docService = new AgentDocumentsService(ctx.serverDB, ctx.userId); + const docService = new AgentDocumentsService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); const { UserPersonaModel } = await import('@/database/models/userMemory/persona'); const personaModel = new UserPersonaModel(ctx.serverDB, ctx.userId); @@ -315,7 +333,11 @@ export const userRouter = router({ .query(async ({ ctx, input }) => { if (input.type === 'soul') { const onboardingService = new OnboardingService(ctx.serverDB, ctx.userId); - const docService = new AgentDocumentsService(ctx.serverDB, ctx.userId); + const docService = new AgentDocumentsService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); const inboxAgentId = await onboardingService.getInboxAgentId(); const doc = await docService.getDocumentByFilename(inboxAgentId, 'SOUL.md'); @@ -342,7 +364,11 @@ export const userRouter = router({ .mutation(async ({ ctx, input }) => { if (input.type === 'soul') { const onboardingService = new OnboardingService(ctx.serverDB, ctx.userId); - const docService = new AgentDocumentsService(ctx.serverDB, ctx.userId); + const docService = new AgentDocumentsService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); const inboxAgentId = await onboardingService.getInboxAgentId(); const doc = await docService.upsertDocumentByFilename({ agentId: inboxAgentId, @@ -407,7 +433,11 @@ export const userRouter = router({ const readCurrent = async (): Promise => { if (input.type === 'soul') { const onboardingService = new OnboardingService(ctx.serverDB, ctx.userId); - const docService = new AgentDocumentsService(ctx.serverDB, ctx.userId); + const docService = new AgentDocumentsService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); const inboxAgentId = await onboardingService.getInboxAgentId(); const doc = await docService.getDocumentByFilename(inboxAgentId, 'SOUL.md'); return doc?.content ?? ''; @@ -431,7 +461,11 @@ export const userRouter = router({ if (input.type === 'soul') { const onboardingService = new OnboardingService(ctx.serverDB, ctx.userId); - const docService = new AgentDocumentsService(ctx.serverDB, ctx.userId); + const docService = new AgentDocumentsService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); const inboxAgentId = await onboardingService.getInboxAgentId(); const doc = await docService.upsertDocumentByFilename({ agentId: inboxAgentId, @@ -476,6 +510,24 @@ export const userRouter = router({ updateSettings: userProcedure.input(UserSettingsSchema).mutation(async ({ ctx, input }) => { const { keyVaults, ...res } = input as Partial; + if (ctx.workspaceId && (hasOwnerSettingChange(res) || hasMemberSettingChange(res))) { + const rbac = new RbacModel(ctx.serverDB, ctx.userId); + const allowed = hasOwnerSettingChange(res) + ? await rbac.hasPermission(WORKSPACE_UPDATE_PERMISSION, { + workspaceId: ctx.workspaceId, + }) + : await rbac.hasAnyPermission([...WORKSPACE_CONTENT_PERMISSIONS], { + workspaceId: ctx.workspaceId, + }); + + if (!allowed) { + throw new TRPCError({ + code: 'FORBIDDEN', + message: 'You do not have permission to perform this action.', + }); + } + } + // Encrypt keyVaults let encryptedKeyVaults: string | null = null; diff --git a/src/server/routers/lambda/userMemories.ts b/src/server/routers/lambda/userMemories.ts index 7928659d14..6532985f1e 100644 --- a/src/server/routers/lambda/userMemories.ts +++ b/src/server/routers/lambda/userMemories.ts @@ -21,6 +21,7 @@ import { and, asc, eq, gte, lte } from 'drizzle-orm'; import pMap from 'p-map'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; import { type IdentityEntryBasePayload, type IdentityEntryPayload, @@ -255,6 +256,7 @@ const memoryProcedure = authedProcedure.use(serverDatabase).use(async (opts) => }, }); }); +const memoryWriteProcedure = memoryProcedure.use(withScopedPermission('message:create')); export const userMemoriesRouter = router({ getMemoryDetail: memoryProcedure @@ -453,7 +455,7 @@ export const userMemoriesRouter = router({ } }), - reEmbedMemories: memoryProcedure + reEmbedMemories: memoryWriteProcedure .input(reEmbedInputSchema.optional()) .mutation(async ({ ctx, input }) => { try { @@ -933,7 +935,11 @@ export const userMemoriesRouter = router({ try { // Get concatenated user messages for this topic - const userMemoryTopicRepo = new UserMemoryTopicRepository(ctx.serverDB, ctx.userId); + const userMemoryTopicRepo = new UserMemoryTopicRepository( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ); const query = await userMemoryTopicRepo.getUserMessagesQueryForTopic(input.topicId); if (!query) { @@ -964,7 +970,7 @@ export const userMemoriesRouter = router({ } }), - toolAddActivityMemory: memoryProcedure + toolAddActivityMemory: memoryWriteProcedure .input(ActivityMemoryItemSchema) .mutation(async ({ input, ctx }) => { try { @@ -1026,7 +1032,7 @@ export const userMemoriesRouter = router({ } }), - toolAddContextMemory: memoryProcedure + toolAddContextMemory: memoryWriteProcedure .input(ContextMemoryItemSchema) .mutation(async ({ input, ctx }) => { try { @@ -1081,7 +1087,7 @@ export const userMemoriesRouter = router({ } }), - toolAddExperienceMemory: memoryProcedure + toolAddExperienceMemory: memoryWriteProcedure .input(ExperienceMemoryItemSchema) .mutation(async ({ input, ctx }) => { try { @@ -1137,7 +1143,7 @@ export const userMemoriesRouter = router({ } }), - toolAddIdentityMemory: memoryProcedure + toolAddIdentityMemory: memoryWriteProcedure .input(AddIdentityActionSchema) .mutation(async ({ input, ctx }) => { try { @@ -1205,7 +1211,7 @@ export const userMemoriesRouter = router({ } }), - toolAddPreferenceMemory: memoryProcedure + toolAddPreferenceMemory: memoryWriteProcedure .input(PreferenceMemoryItemSchema) .mutation(async ({ input, ctx }) => { try { @@ -1265,7 +1271,7 @@ export const userMemoriesRouter = router({ } }), - toolRemoveIdentityMemory: memoryProcedure + toolRemoveIdentityMemory: memoryWriteProcedure .input(RemoveIdentityActionSchema) .mutation(async ({ input, ctx }) => { try { @@ -1298,7 +1304,7 @@ export const userMemoriesRouter = router({ return result; }), - toolUpdateIdentityMemory: memoryProcedure + toolUpdateIdentityMemory: memoryWriteProcedure .input(UpdateIdentityActionSchema) .mutation(async ({ input, ctx }) => { try { diff --git a/src/server/routers/lambda/userMemory.ts b/src/server/routers/lambda/userMemory.ts index a608c7305d..d03ae5e78a 100644 --- a/src/server/routers/lambda/userMemory.ts +++ b/src/server/routers/lambda/userMemory.ts @@ -11,6 +11,8 @@ import { import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AsyncTaskModel, initUserMemoryExtractionMetadata } from '@/database/models/asyncTask'; import { TopicModel } from '@/database/models/topic'; import { @@ -23,7 +25,7 @@ import { } from '@/database/models/userMemory'; import { UserPersonaModel } from '@/database/models/userMemory/persona'; import { appEnv } from '@/envs/app'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { parseMemoryExtractionConfig } from '@/server/globalConfig/parseMemoryExtractionConfig'; import { @@ -32,23 +34,25 @@ import { normalizeMemoryExtractionPayload, } from '@/server/services/memory/userMemory/extract'; -const userMemoryProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const userMemoryProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { activityModel: new UserMemoryActivityModel(ctx.serverDB, ctx.userId), - asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId), + asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId, wsId), contextModel: new UserMemoryContextModel(ctx.serverDB, ctx.userId), experienceModel: new UserMemoryExperienceModel(ctx.serverDB, ctx.userId), identityModel: new UserMemoryIdentityModel(ctx.serverDB, ctx.userId), personaModel: new UserPersonaModel(ctx.serverDB, ctx.userId), preferenceModel: new UserMemoryPreferenceModel(ctx.serverDB, ctx.userId), - topicModel: new TopicModel(ctx.serverDB, ctx.userId), + topicModel: new TopicModel(ctx.serverDB, ctx.userId, wsId), userMemoryModel: new UserMemoryModel(ctx.serverDB, ctx.userId), }, }); }); +const userMemoryWriteProcedure = userMemoryProcedure.use(withScopedPermission('message:create')); const userMemoryExtractionInputSchema = z.object({ fromDate: z.coerce.date().optional(), @@ -75,7 +79,7 @@ const getUserMemoryExtractionTimeoutMs = (metadata: UserMemoryExtractionMetadata export const userMemoryRouter = router({ // ============ Identity CRUD ============ - createIdentity: userMemoryProcedure + createIdentity: userMemoryWriteProcedure .input(CreateUserMemoryIdentitySchema) .mutation(async ({ ctx, input }) => { return ctx.userMemoryModel.addIdentityEntry({ @@ -92,40 +96,40 @@ export const userMemoryRouter = router({ }), // ============ Activity CRUD ============ - deleteActivity: userMemoryProcedure + deleteActivity: userMemoryWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { return ctx.activityModel.delete(input.id); }), - deleteAll: userMemoryProcedure.mutation(async ({ ctx }) => { + deleteAll: userMemoryWriteProcedure.mutation(async ({ ctx }) => { await ctx.userMemoryModel.deleteAll(); return { success: true }; }), // ============ Context CRUD ============ - deleteContext: userMemoryProcedure + deleteContext: userMemoryWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { return ctx.contextModel.delete(input.id); }), // ============ Experience CRUD ============ - deleteExperience: userMemoryProcedure + deleteExperience: userMemoryWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { return ctx.experienceModel.delete(input.id); }), - deleteIdentity: userMemoryProcedure + deleteIdentity: userMemoryWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { return ctx.userMemoryModel.removeIdentityEntry(input.id); }), // ============ Preference CRUD ============ - deletePreference: userMemoryProcedure + deletePreference: userMemoryWriteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ ctx, input }) => { return ctx.preferenceModel.delete(input.id); @@ -216,7 +220,7 @@ export const userMemoryRouter = router({ return ctx.userMemoryModel.searchPreferences({}); }), - requestMemoryFromChatTopic: userMemoryProcedure + requestMemoryFromChatTopic: userMemoryWriteProcedure .input(userMemoryExtractionInputSchema) .mutation(async ({ ctx, input }) => { if (input.fromDate && input.toDate && input.fromDate > input.toDate) { @@ -326,7 +330,7 @@ export const userMemoryRouter = router({ }; }), - updateActivity: userMemoryProcedure + updateActivity: userMemoryWriteProcedure .input( z.object({ data: z.object({ @@ -341,7 +345,7 @@ export const userMemoryRouter = router({ return ctx.activityModel.update(input.id, input.data); }), - updateContext: userMemoryProcedure + updateContext: userMemoryWriteProcedure .input( z.object({ data: z.object({ @@ -356,7 +360,7 @@ export const userMemoryRouter = router({ return ctx.contextModel.update(input.id, input.data); }), - updateExperience: userMemoryProcedure + updateExperience: userMemoryWriteProcedure .input( z.object({ data: z.object({ @@ -371,7 +375,7 @@ export const userMemoryRouter = router({ return ctx.experienceModel.update(input.id, input.data); }), - updateIdentity: userMemoryProcedure + updateIdentity: userMemoryWriteProcedure .input( z.object({ data: UpdateUserMemoryIdentitySchema, @@ -392,7 +396,7 @@ export const userMemoryRouter = router({ }); }), - updatePreference: userMemoryProcedure + updatePreference: userMemoryWriteProcedure .input( z.object({ data: z.object({ diff --git a/src/server/routers/lambda/video/index.ts b/src/server/routers/lambda/video/index.ts index 17dc601f20..b3eebeba8f 100644 --- a/src/server/routers/lambda/video/index.ts +++ b/src/server/routers/lambda/video/index.ts @@ -15,6 +15,8 @@ import { after } from 'next/server'; import { z } from 'zod'; import { getProviderContentPolicyErrorMessage } from '@/business/server/getProviderContentPolicyErrorMessage'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { chargeAfterGenerate } from '@/business/server/video-generation/chargeAfterGenerate'; import { chargeBeforeGenerate } from '@/business/server/video-generation/chargeBeforeGenerate'; import { getVideoFreeQuota } from '@/business/server/video-generation/getVideoFreeQuota'; @@ -39,17 +41,20 @@ import { createVideoTaskSubmitError } from './error'; const log = debug('lobe-video:lambda'); -const videoProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const videoProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; + const wsId = ctx.workspaceId ?? undefined; return opts.next({ ctx: { - asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId), - fileService: new FileService(ctx.serverDB, ctx.userId), + asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId, wsId), + fileService: new FileService(ctx.serverDB, ctx.userId, wsId), }, }); }); +const videoCreateProcedure = videoProcedure.use(withScopedPermission('file:upload')); + const createVideoInputSchema = z.object({ generationTopicId: z.string(), model: z.string(), @@ -71,284 +76,292 @@ const createVideoInputSchema = z.object({ export type CreateVideoServicePayload = z.infer; export const videoRouter = router({ - createVideo: videoProcedure.input(createVideoInputSchema).mutation(async ({ input, ctx }) => { - const { userId, serverDB, asyncTaskModel, fileService } = ctx; - const { generationTopicId, provider, model, params } = input; + createVideo: videoCreateProcedure + .input(createVideoInputSchema) + .mutation(async ({ input, ctx }) => { + const { userId, serverDB, asyncTaskModel, fileService } = ctx; + const wsId = ctx.workspaceId ?? undefined; + const { generationTopicId, provider, model, params } = input; - const { resolvedModelId } = await resolveBusinessModelMapping(provider, model); + const { resolvedModelId } = await resolveBusinessModelMapping(provider, model); - // Reject lobehub model ids that are no longer in the model bank so callers get a - // clear error instead of an opaque downstream failure when the resolved channel - // model is no longer in the model bank. - if ( - provider === BRANDING_PROVIDER && - !isProviderModelAvailable(await loadModels(), BRANDING_PROVIDER, resolvedModelId, 'video') - ) { - throw new TRPCError({ - cause: { data: { modelType: 'video', requestedModel: model } }, - code: 'BAD_REQUEST', - message: ChatErrorType.LobeHubModelDeprecated, - }); - } - - log('Starting video creation process, input: %O', input); - - // Normalize image URLs to S3 keys for database storage - let configForDatabase = { ...params }; - - // Process first-frame imageUrl - if (typeof params.imageUrl === 'string' && params.imageUrl) { - try { - const key = await fileService.getKeyFromFullUrl(params.imageUrl); - if (key) { - log('Converted imageUrl to key: %s -> %s', params.imageUrl, key); - configForDatabase = { ...configForDatabase, imageUrl: key }; - } - } catch (error) { - console.error('Error converting imageUrl to key: %O', error); + // Reject lobehub model ids that are no longer in the model bank so callers get a + // clear error instead of an opaque downstream failure when the resolved channel + // model is no longer in the model bank. + if ( + provider === BRANDING_PROVIDER && + !isProviderModelAvailable(await loadModels(), BRANDING_PROVIDER, resolvedModelId, 'video') + ) { + throw new TRPCError({ + cause: { data: { modelType: 'video', requestedModel: model } }, + code: 'BAD_REQUEST', + message: ChatErrorType.LobeHubModelDeprecated, + }); } - } - // Process last-frame endImageUrl - if (typeof params.endImageUrl === 'string' && params.endImageUrl) { - try { - const key = await fileService.getKeyFromFullUrl(params.endImageUrl); - if (key) { - log('Converted endImageUrl to key: %s -> %s', params.endImageUrl, key); - configForDatabase = { ...configForDatabase, endImageUrl: key }; - } - } catch (error) { - console.error('Error converting endImageUrl to key: %O', error); - } - } + log('Starting video creation process, input: %O', input); - // In development, convert localhost proxy URLs to S3 URLs for API access - let generationParams = params; - if (process.env.NODE_ENV === 'development') { - const updates: Record = {}; + // Normalize image URLs to S3 keys for database storage + let configForDatabase = { ...params }; + // Process first-frame imageUrl if (typeof params.imageUrl === 'string' && params.imageUrl) { - const s3Url = await fileService.getFullFileUrl(configForDatabase.imageUrl as string); - if (s3Url) { - log('Dev: converted imageUrl proxy URL to S3 URL: %s -> %s', params.imageUrl, s3Url); - updates.imageUrl = s3Url; + try { + const key = await fileService.getKeyFromFullUrl(params.imageUrl); + if (key) { + log('Converted imageUrl to key: %s -> %s', params.imageUrl, key); + configForDatabase = { ...configForDatabase, imageUrl: key }; + } + } catch (error) { + console.error('Error converting imageUrl to key: %O', error); } } + // Process last-frame endImageUrl if (typeof params.endImageUrl === 'string' && params.endImageUrl) { - const s3Url = await fileService.getFullFileUrl(configForDatabase.endImageUrl as string); - if (s3Url) { - log( - 'Dev: converted endImageUrl proxy URL to S3 URL: %s -> %s', - params.endImageUrl, - s3Url, - ); - updates.endImageUrl = s3Url; + try { + const key = await fileService.getKeyFromFullUrl(params.endImageUrl); + if (key) { + log('Converted endImageUrl to key: %s -> %s', params.endImageUrl, key); + configForDatabase = { ...configForDatabase, endImageUrl: key }; + } + } catch (error) { + console.error('Error converting endImageUrl to key: %O', error); } } - if (Object.keys(updates).length > 0) { - generationParams = { ...params, ...updates }; + // In development, convert localhost proxy URLs to S3 URLs for API access + let generationParams = params; + if (process.env.NODE_ENV === 'development') { + const updates: Record = {}; + + if (typeof params.imageUrl === 'string' && params.imageUrl) { + const s3Url = await fileService.getFullFileUrl(configForDatabase.imageUrl as string); + if (s3Url) { + log('Dev: converted imageUrl proxy URL to S3 URL: %s -> %s', params.imageUrl, s3Url); + updates.imageUrl = s3Url; + } + } + + if (typeof params.endImageUrl === 'string' && params.endImageUrl) { + const s3Url = await fileService.getFullFileUrl(configForDatabase.endImageUrl as string); + if (s3Url) { + log( + 'Dev: converted endImageUrl proxy URL to S3 URL: %s -> %s', + params.endImageUrl, + s3Url, + ); + updates.endImageUrl = s3Url; + } + } + + if (Object.keys(updates).length > 0) { + generationParams = { ...params, ...updates }; + } } - } - // Step 0: Pre-charge (atomic budget deduction to prevent concurrent abuse) - const { errorBatch, prechargeResult } = await chargeBeforeGenerate({ - generationTopicId, - model, - params, - provider, - userId, - }); - if (errorBatch) return errorBatch; - - // Generate a one-time token for webhook callback verification - const webhookToken = randomBytes(32).toString('hex'); - - // Step 1: Atomically create all database records in a transaction - const { - asyncTaskCreatedAt, - asyncTaskId, - batch: createdBatch, - generation: createdGeneration, - } = await serverDB.transaction(async (tx) => { - log('Starting database transaction for video generation'); - - // 1. Create generationBatch - const newBatch: NewGenerationBatch = { - config: configForDatabase, + // Step 0: Pre-charge (atomic budget deduction to prevent concurrent abuse) + const { errorBatch, prechargeResult } = await chargeBeforeGenerate({ generationTopicId, model, - prompt: params.prompt, + params, provider, userId, - }; - log('Creating generation batch: %O', newBatch); - const [batch] = await tx.insert(generationBatches).values(newBatch).returning(); - log('Generation batch created: %s', batch.id); + workspaceId: wsId, + }); + if (errorBatch) return errorBatch; - // 2. Create single generation (video is always 1) - const newGeneration: NewGeneration = { - generationBatchId: batch.id, - seed: params.seed ?? null, - userId, - }; - const [generation] = await tx.insert(generations).values(newGeneration).returning(); - log('Generation created: %s', generation.id); + // Generate a one-time token for webhook callback verification + const webhookToken = randomBytes(32).toString('hex'); - // 3. Create asyncTask with precharge metadata - const [asyncTask] = await tx - .insert(asyncTasks) - .values({ - metadata: { - ...(prechargeResult ? { precharge: prechargeResult } : {}), - webhookToken, - }, - status: AsyncTaskStatus.Pending, - type: AsyncTaskType.VideoGeneration, + // Step 1: Atomically create all database records in a transaction + const { + asyncTaskCreatedAt, + asyncTaskId, + batch: createdBatch, + generation: createdGeneration, + } = await serverDB.transaction(async (tx) => { + log('Starting database transaction for video generation'); + + // 1. Create generationBatch + const newBatch: NewGenerationBatch = { + config: configForDatabase, + generationTopicId, + model, + prompt: params.prompt, + provider, userId, - }) - .returning(); - log('Async task created: %s', asyncTask.id); + workspaceId: wsId, + }; + log('Creating generation batch: %O', newBatch); + const [batch] = await tx.insert(generationBatches).values(newBatch).returning(); + log('Generation batch created: %s', batch.id); - // 4. Link asyncTask to generation - await tx - .update(generations) - .set({ asyncTaskId: asyncTask.id }) - .where(and(eq(generations.id, generation.id), eq(generations.userId, userId))); + // 2. Create single generation (video is always 1) + const newGeneration: NewGeneration = { + generationBatchId: batch.id, + seed: params.seed ?? null, + userId, + workspaceId: wsId, + }; + const [generation] = await tx.insert(generations).values(newGeneration).returning(); + log('Generation created: %s', generation.id); - return { - asyncTaskCreatedAt: asyncTask.createdAt, - asyncTaskId: asyncTask.id, - batch, - generation, - }; - }); + // 3. Create asyncTask with precharge metadata + const [asyncTask] = await tx + .insert(asyncTasks) + .values({ + metadata: { + ...(prechargeResult ? { precharge: prechargeResult } : {}), + webhookToken, + }, + status: AsyncTaskStatus.Pending, + type: AsyncTaskType.VideoGeneration, + userId, + workspaceId: wsId, + }) + .returning(); + log('Async task created: %s', asyncTask.id); - log('Database transaction completed. Calling model runtime for video generation.'); + // 4. Link asyncTask to generation + await tx + .update(generations) + .set({ asyncTaskId: asyncTask.id }) + .where(and(eq(generations.id, generation.id), eq(generations.userId, userId))); - // Step 2: Call model runtime to submit video generation task - try { - const modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider); + return { + asyncTaskCreatedAt: asyncTask.createdAt, + asyncTaskId: asyncTask.id, + batch, + generation, + }; + }); - const callbackBaseUrl = process.env.WEBHOOK_PROXY_URL || appEnv.APP_URL; - const callbackUrl = `${callbackBaseUrl}/api/webhooks/video/${provider}?token=${webhookToken}`; - log('Using callback URL: %s', callbackUrl); + log('Database transaction completed. Calling model runtime for video generation.'); - const response = await modelRuntime.createVideo( - { - callbackUrl, - model: resolvedModelId, - params: generationParams, - }, - { metadata: { trigger: RequestTrigger.Video } }, - ); + // Step 2: Call model runtime to submit video generation task + try { + const modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider, wsId); - log('Video task submitted successfully, inferenceId: %s', response?.inferenceId); + const callbackBaseUrl = process.env.WEBHOOK_PROXY_URL || appEnv.APP_URL; + const callbackUrl = `${callbackBaseUrl}/api/webhooks/video/${provider}?token=${webhookToken}`; + log('Using callback URL: %s', callbackUrl); - // Determine async strategy based on response: - // - useWebhook: provider registered a callback URL, wait for webhook - // - otherwise: use background polling to check status - const useWebhook = response && 'useWebhook' in response && response.useWebhook; - - if (useWebhook) { - // Webhook-based provider (e.g. Volcengine): wait for callback - log('Webhook-based provider detected, waiting for callback'); - - await asyncTaskModel.update(asyncTaskId, { - inferenceId: response?.inferenceId, - status: AsyncTaskStatus.Processing, - }); - } else if (response) { - // Polling-based provider (e.g. OpenAI Sora): use background polling - log( - 'Polling-based provider detected (inferenceId only), using after() for background polling', + const response = await modelRuntime.createVideo( + { + callbackUrl, + model: resolvedModelId, + params: generationParams, + }, + { metadata: { trigger: RequestTrigger.Video } }, ); + log('Video task submitted successfully, inferenceId: %s', response?.inferenceId); + + // Determine async strategy based on response: + // - useWebhook: provider registered a callback URL, wait for webhook + // - otherwise: use background polling to check status + const useWebhook = response && 'useWebhook' in response && response.useWebhook; + + if (useWebhook) { + // Webhook-based provider (e.g. Volcengine): wait for callback + log('Webhook-based provider detected, waiting for callback'); + + await asyncTaskModel.update(asyncTaskId, { + inferenceId: response?.inferenceId, + status: AsyncTaskStatus.Processing, + }); + } else if (response) { + // Polling-based provider (e.g. OpenAI Sora): use background polling + log( + 'Polling-based provider detected (inferenceId only), using after() for background polling', + ); + + await asyncTaskModel.update(asyncTaskId, { + inferenceId: response.inferenceId, + status: AsyncTaskStatus.Processing, + }); + + after(async () => { + log('After() hook executing background video polling for task: %s', asyncTaskId); + + try { + const db = await getServerDB(); + + await processBackgroundVideoPolling(db, { + asyncTaskCreatedAt, + asyncTaskId, + generationBatchId: createdBatch.id, + generationId: createdGeneration.id, + generationTopicId, + inferenceId: response.inferenceId, + model, + prechargeResult, + provider, + userId, + workspaceId: wsId, + }); + + log('Background video polling completed for task: %s', asyncTaskId); + } catch (error) { + console.error('[video] Background polling failed:', error); + } + }); + + log('After() hook registered for background video polling: %s', asyncTaskId); + } + } catch (e) { + console.error('Failed to submit video generation task:', e); + + const providerContentPolicyMessage = await getProviderContentPolicyErrorMessage({ + error: e, + provider, + trigger: RequestTrigger.Video, + userId, + }); await asyncTaskModel.update(asyncTaskId, { - inferenceId: response.inferenceId, - status: AsyncTaskStatus.Processing, + error: createVideoTaskSubmitError(e, providerContentPolicyMessage), + status: AsyncTaskStatus.Error, }); - after(async () => { - log('After() hook executing background video polling for task: %s', asyncTaskId); - + if (prechargeResult) { try { - const db = await getServerDB(); - - await processBackgroundVideoPolling(db, { - asyncTaskCreatedAt, - asyncTaskId, - generationBatchId: createdBatch.id, - generationId: createdGeneration.id, - generationTopicId, - inferenceId: response.inferenceId, - model, + await chargeAfterGenerate({ + isError: true, + metadata: { + asyncTaskId, + generationBatchId: createdBatch.id, + topicId: generationTopicId, + ...buildMappedBusinessModelFields({ + provider, + requestedModelId: resolvedModelId === model ? undefined : model, + resolvedModelId, + }), + }, + model: resolvedModelId, prechargeResult, provider, userId, }); - - log('Background video polling completed for task: %s', asyncTaskId); - } catch (error) { - console.error('[video] Background polling failed:', error); + } catch (chargeError) { + console.error('[video] chargeAfterGenerate failed:', chargeError); } - }); - - log('After() hook registered for background video polling: %s', asyncTaskId); - } - } catch (e) { - console.error('Failed to submit video generation task:', e); - - const providerContentPolicyMessage = await getProviderContentPolicyErrorMessage({ - error: e, - provider, - trigger: RequestTrigger.Video, - userId, - }); - await asyncTaskModel.update(asyncTaskId, { - error: createVideoTaskSubmitError(e, providerContentPolicyMessage), - status: AsyncTaskStatus.Error, - }); - - if (prechargeResult) { - try { - await chargeAfterGenerate({ - isError: true, - metadata: { - asyncTaskId, - generationBatchId: createdBatch.id, - topicId: generationTopicId, - ...buildMappedBusinessModelFields({ - provider, - requestedModelId: resolvedModelId === model ? undefined : model, - resolvedModelId, - }), - }, - model: resolvedModelId, - prechargeResult, - provider, - userId, - }); - } catch (chargeError) { - console.error('[video] chargeAfterGenerate failed:', chargeError); } } - } - log('Video creation process completed: %O', { - batchId: createdBatch.id, - generationId: createdGeneration.id, - }); + log('Video creation process completed: %O', { + batchId: createdBatch.id, + generationId: createdGeneration.id, + }); - return { - data: { - batch: createdBatch, - generations: [{ ...createdGeneration, asyncTaskId }], - }, - success: true, - }; - }), + return { + data: { + batch: createdBatch, + generations: [{ ...createdGeneration, asyncTaskId }], + }, + success: true, + }; + }), getVideoFreeQuota: authedProcedure .input(z.object({ model: z.string() })) diff --git a/src/server/routers/lambda/webBrowsing.ts b/src/server/routers/lambda/webBrowsing.ts index d8699cd2a6..7f5459be96 100644 --- a/src/server/routers/lambda/webBrowsing.ts +++ b/src/server/routers/lambda/webBrowsing.ts @@ -1,14 +1,19 @@ import { z } from 'zod'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { WebBrowsingDocumentService } from '@/server/services/webBrowsing'; -const webBrowsingProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const webBrowsingProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; return opts.next({ ctx: { - webBrowsingService: new WebBrowsingDocumentService(ctx.serverDB, ctx.userId), + webBrowsingService: new WebBrowsingDocumentService( + ctx.serverDB, + ctx.userId, + ctx.workspaceId ?? undefined, + ), }, }); }); diff --git a/src/server/routers/mobile/topic.ts b/src/server/routers/mobile/topic.ts index 13e9562dac..62b58a6cba 100644 --- a/src/server/routers/mobile/topic.ts +++ b/src/server/routers/mobile/topic.ts @@ -1,21 +1,29 @@ import { z } from 'zod'; +import { withScopedPermission } from '@/business/server/trpc-middlewares/rbacPermission'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { TopicModel } from '@/database/models/topic'; import { getServerDB } from '@/database/server'; -import { authedProcedure, publicProcedure, router } from '@/libs/trpc/lambda'; +import { publicProcedure, router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { type BatchTaskResult } from '@/types/service'; -const topicProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const topicProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const { ctx } = opts; return opts.next({ - ctx: { topicModel: new TopicModel(ctx.serverDB, ctx.userId) }, + ctx: { + topicModel: new TopicModel(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined), + }, }); }); +const topicCreateProcedure = topicProcedure.use(withScopedPermission('topic:create')); +const topicDeleteProcedure = topicProcedure.use(withScopedPermission('topic:delete')); +const topicUpdateProcedure = topicProcedure.use(withScopedPermission('topic:update')); + export const topicRouter = router({ - batchCreateTopics: topicProcedure + batchCreateTopics: topicCreateProcedure .input( z.array( z.object({ @@ -37,19 +45,19 @@ export const topicRouter = router({ return { added: data.length, ids: [], skips: [], success: true }; }), - batchDelete: topicProcedure + batchDelete: topicDeleteProcedure .input(z.object({ ids: z.array(z.string()) })) .mutation(async ({ input, ctx }) => { return ctx.topicModel.batchDelete(input.ids); }), - batchDeleteBySessionId: topicProcedure + batchDeleteBySessionId: topicDeleteProcedure .input(z.object({ id: z.string().nullable().optional() })) .mutation(async ({ input, ctx }) => { return ctx.topicModel.batchDeleteBySessionId(input.id); }), - cloneTopic: topicProcedure + cloneTopic: topicCreateProcedure .input(z.object({ id: z.string(), newTitle: z.string().optional() })) .mutation(async ({ input, ctx }) => { const data = await ctx.topicModel.duplicate(input.id, input.newTitle); @@ -71,7 +79,7 @@ export const topicRouter = router({ return ctx.topicModel.count(input); }), - createTopic: topicProcedure + createTopic: topicCreateProcedure .input( z.object({ favorite: z.boolean().optional(), @@ -104,7 +112,7 @@ export const topicRouter = router({ if (!ctx.userId) return []; const serverDB = await getServerDB(); - const topicModel = new TopicModel(serverDB, ctx.userId); + const topicModel = new TopicModel(serverDB, ctx.userId, ctx.workspaceId ?? undefined); return topicModel.query(input); }), @@ -117,11 +125,11 @@ export const topicRouter = router({ return ctx.topicModel.rank(input); }), - removeAllTopics: topicProcedure.mutation(async ({ ctx }) => { + removeAllTopics: topicDeleteProcedure.mutation(async ({ ctx }) => { return ctx.topicModel.deleteAll(); }), - removeTopic: topicProcedure + removeTopic: topicDeleteProcedure .input(z.object({ id: z.string() })) .mutation(async ({ input, ctx }) => { return ctx.topicModel.delete(input.id); @@ -139,7 +147,7 @@ export const topicRouter = router({ return ctx.topicModel.queryByKeyword(input.keywords, input.sessionId); }), - updateTopic: topicProcedure + updateTopic: topicUpdateProcedure .input( z.object({ id: z.string(), diff --git a/src/server/routers/tools/klavis.ts b/src/server/routers/tools/klavis.ts index 5343dc23d5..0d389700be 100644 --- a/src/server/routers/tools/klavis.ts +++ b/src/server/routers/tools/klavis.ts @@ -1,17 +1,18 @@ import { z } from 'zod'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { ConnectorModel } from '@/database/models/connector'; import { ConnectorToolModel } from '@/database/models/connectorTool'; import { ConnectorToolPermission } from '@/database/schemas'; import { getKlavisClient } from '@/libs/klavis'; -import { authedProcedure, publicProcedure, router } from '@/libs/trpc/lambda'; +import { publicProcedure, router } from '@/libs/trpc/lambda'; import { serverDatabase } from '@/libs/trpc/lambda/middleware'; import { MCPService } from '@/server/services/mcp'; /** * Klavis procedure with client initialized in context */ -const klavisProcedure = authedProcedure.use(serverDatabase).use(async (opts) => { +const klavisProcedure = wsCompatProcedure.use(serverDatabase).use(async (opts) => { const klavisClient = getKlavisClient(); return opts.next({ @@ -43,13 +44,14 @@ export const klavisRouter = router({ // same-name collisions across connectors). Falls back to toolName-only // if identifier is absent (legacy callers). if (ctx.userId && ctx.serverDB) { - const connectorToolModel = new ConnectorToolModel(ctx.serverDB, ctx.userId); + const wsId = ctx.workspaceId ?? undefined; + const connectorToolModel = new ConnectorToolModel(ctx.serverDB, ctx.userId, wsId); let connectorTool: | Awaited> | undefined; if (input.identifier) { - const connectorModel = new ConnectorModel(ctx.serverDB, ctx.userId); + const connectorModel = new ConnectorModel(ctx.serverDB, ctx.userId, wsId); const [connector] = await connectorModel.queryByIdentifiers([input.identifier]); if (connector) { const tools = await connectorToolModel.queryByConnector(connector.id); diff --git a/src/server/routers/tools/market.ts b/src/server/routers/tools/market.ts index 7ef6297d31..f1575364a9 100644 --- a/src/server/routers/tools/market.ts +++ b/src/server/routers/tools/market.ts @@ -3,6 +3,7 @@ import { TRPCError } from '@trpc/server'; import debug from 'debug'; import { z } from 'zod'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { AgentSkillModel } from '@/database/models/agentSkill'; import { FileModel } from '@/database/models/file'; import { type ToolCallContent } from '@/libs/mcp'; @@ -50,7 +51,7 @@ const throwSandboxAuthError = () => { }; // ============================== Common Procedure ============================== -const marketToolProcedure = authedProcedure +const marketToolProcedure = wsCompatProcedure .use(serverDatabase) .use(telemetry) .use(marketUserInfo) @@ -58,18 +59,23 @@ const marketToolProcedure = authedProcedure const { UserModel } = await import('@/database/models/user'); const userModel = new UserModel(ctx.serverDB, ctx.userId); + // In a workspace context, sandbox runtime calls are attributed to the + // workspace's Market organization via the workspaceId carried in the trust + // token (`ctx.marketUserInfo.workspaceId`, set by the marketUserInfo + // middleware). Falls back to the personal account when there's no workspace. return next({ ctx: { discoverService: new DiscoverService({ accessToken: ctx.marketAccessToken, userInfo: ctx.marketUserInfo, }), - fileService: new FileService(ctx.serverDB, ctx.userId), + fileService: new FileService(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined), marketService: new MarketService({ accessToken: ctx.marketAccessToken, userInfo: ctx.marketUserInfo, }), userModel, + workspaceId: ctx.workspaceId, }, }); }); @@ -165,7 +171,13 @@ const execInSandboxHandler = async ({ input, ctx, }: { - ctx: { fileService: FileService; marketService: MarketService; serverDB: any; userId: string }; + ctx: { + fileService: FileService; + marketService: MarketService; + serverDB: any; + userId: string; + workspaceId?: string | null; + }; input: ExecInSandboxInput; }): Promise => { const { toolName, params, topicId } = input; @@ -198,8 +210,9 @@ const execInSandboxHandler = async ({ // For execScript tool, look up skill zipUrls from activatedSkills if (toolName === 'execScript' && enhancedParams.activatedSkills?.length) { - const agentSkillModel = new AgentSkillModel(ctx.serverDB, userId); - const fileModel = new FileModel(ctx.serverDB, userId); + const wsId = ctx.workspaceId ?? undefined; + const agentSkillModel = new AgentSkillModel(ctx.serverDB, userId, wsId); + const fileModel = new FileModel(ctx.serverDB, userId, wsId); // Resolve zipUrls for all activated skills const skillZipUrls: Record = {}; diff --git a/src/server/routers/tools/mcp.ts b/src/server/routers/tools/mcp.ts index dfd0bf7ae1..53b8fa1cab 100644 --- a/src/server/routers/tools/mcp.ts +++ b/src/server/routers/tools/mcp.ts @@ -6,11 +6,12 @@ import { import { TRPCError } from '@trpc/server'; import { z } from 'zod'; +import { wsCompatProcedure } from '@/business/server/trpc-middlewares/workspaceAuth'; import { ConnectorModel } from '@/database/models/connector'; import { ConnectorToolModel } from '@/database/models/connectorTool'; import { ConnectorToolPermission } from '@/database/schemas'; import { type ToolCallContent } from '@/libs/mcp'; -import { authedProcedure, router } from '@/libs/trpc/lambda'; +import { router } from '@/libs/trpc/lambda'; import { serverDatabase, telemetry } from '@/libs/trpc/lambda/middleware'; import { FileService } from '@/server/services/file'; import { mcpService } from '@/server/services/mcp'; @@ -66,13 +67,13 @@ const metaSchema = z }) .optional(); -const mcpProcedure = authedProcedure +const mcpProcedure = wsCompatProcedure .use(serverDatabase) .use(telemetry) .use(async ({ ctx, next }) => { return next({ ctx: { - fileService: new FileService(ctx.serverDB, ctx.userId), + fileService: new FileService(ctx.serverDB, ctx.userId, ctx.workspaceId ?? undefined), }, }); }); @@ -144,10 +145,11 @@ export const mcpRouter = router({ // here at execution time so the MCP server is never actually called. const connectorId = input.params.name; if (connectorId && ctx.userId) { - const connectorModel = new ConnectorModel(ctx.serverDB, ctx.userId); + const wsId = ctx.workspaceId ?? undefined; + const connectorModel = new ConnectorModel(ctx.serverDB, ctx.userId, wsId); const [connector] = await connectorModel.queryByIdentifiers([connectorId]); if (connector) { - const connectorToolModel = new ConnectorToolModel(ctx.serverDB, ctx.userId); + const connectorToolModel = new ConnectorToolModel(ctx.serverDB, ctx.userId, wsId); const tools = await connectorToolModel.queryByConnector(connector.id); const tool = tools.find((t) => t.toolName === input.toolName); if (tool?.permission === ConnectorToolPermission.disabled) { diff --git a/src/server/services/agent/index.test.ts b/src/server/services/agent/index.test.ts index a80775434e..9eb49cd5cf 100644 --- a/src/server/services/agent/index.test.ts +++ b/src/server/services/agent/index.test.ts @@ -53,6 +53,7 @@ describe('AgentService', () => { let service: AgentService; const mockDb = {} as any; const mockUserId = 'test-user-id'; + const mockWorkspaceId = 'workspace-1'; // Default mock for UserModel that returns empty settings const mockUserModel = { @@ -79,7 +80,7 @@ describe('AgentService', () => { await service.createInbox(); - expect(SessionModel).toHaveBeenCalledWith(mockDb, mockUserId); + expect(SessionModel).toHaveBeenCalledWith(mockDb, mockUserId, undefined); expect(parseAgentConfig).toHaveBeenCalledWith('model=gpt-4;temperature=0.7'); expect(mockSessionModel.createInbox).toHaveBeenCalledWith(mockConfig); }); @@ -94,10 +95,25 @@ describe('AgentService', () => { await service.createInbox(); - expect(SessionModel).toHaveBeenCalledWith(mockDb, mockUserId); + expect(SessionModel).toHaveBeenCalledWith(mockDb, mockUserId, undefined); expect(parseAgentConfig).toHaveBeenCalledWith('model=gpt-4;temperature=0.7'); expect(mockSessionModel.createInbox).toHaveBeenCalledWith({}); }); + + it('should create workspace inbox in the active workspace scope', async () => { + const mockSessionModel = { + createInbox: vi.fn(), + }; + + (SessionModel as any).mockImplementation(() => mockSessionModel); + (parseAgentConfig as any).mockReturnValue({}); + + const workspaceService = new AgentService(mockDb, mockUserId, mockWorkspaceId); + await workspaceService.createInbox(); + + expect(SessionModel).toHaveBeenCalledWith(mockDb, mockUserId, mockWorkspaceId); + expect(mockSessionModel.createInbox).toHaveBeenCalledWith({}); + }); }); describe('getBuiltinAgent', () => { diff --git a/src/server/services/agent/index.ts b/src/server/services/agent/index.ts index 2930f2ee57..febe845080 100644 --- a/src/server/services/agent/index.ts +++ b/src/server/services/agent/index.ts @@ -46,16 +46,18 @@ export class AgentService { private readonly db: LobeChatDatabase; private readonly agentModel: AgentModel; private readonly userModel: UserModel; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; - this.agentModel = new AgentModel(db, userId); + this.workspaceId = workspaceId; + this.agentModel = new AgentModel(db, userId, workspaceId); this.userModel = new UserModel(db, userId); } async createInbox() { - const sessionModel = new SessionModel(this.db, this.userId); + const sessionModel = new SessionModel(this.db, this.userId, this.workspaceId); const defaultAgentConfig = getServerDefaultAgentConfig(); await sessionModel.createInbox(defaultAgentConfig); } diff --git a/src/server/services/agentDocumentVfs/index.ts b/src/server/services/agentDocumentVfs/index.ts index 651da0e2eb..e41b4c48a6 100644 --- a/src/server/services/agentDocumentVfs/index.ts +++ b/src/server/services/agentDocumentVfs/index.ts @@ -96,9 +96,9 @@ export class AgentDocumentVfsService { private agentDocumentModel: AgentDocumentModel; private skillMount: SkillMount; - constructor(db: LobeChatDatabase, userId: string) { - this.agentDocumentModel = new AgentDocumentModel(db, userId); - this.skillMount = createSkillMount(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.agentDocumentModel = new AgentDocumentModel(db, userId, workspaceId); + this.skillMount = createSkillMount(db, userId, workspaceId); } /** diff --git a/src/server/services/agentDocumentVfs/mounts/skills/createSkillMount.ts b/src/server/services/agentDocumentVfs/mounts/skills/createSkillMount.ts index 7fb3466a17..dc7ef0e49e 100644 --- a/src/server/services/agentDocumentVfs/mounts/skills/createSkillMount.ts +++ b/src/server/services/agentDocumentVfs/mounts/skills/createSkillMount.ts @@ -25,11 +25,11 @@ import { SkillMount } from './SkillMount'; * Returns: * - A skill mount that routes unified skill paths to namespace-specific providers. */ -export const createSkillMount = (db: LobeChatDatabase, userId: string) => { - const agentModel = new AgentModel(db, userId); - const agentDocumentModel = new AgentDocumentModel(db, userId); - const documentService = new DocumentService(db, userId); - const skillModel = new AgentSkillModel(db, userId); +export const createSkillMount = (db: LobeChatDatabase, userId: string, workspaceId?: string) => { + const agentModel = new AgentModel(db, userId, workspaceId); + const agentDocumentModel = new AgentDocumentModel(db, userId, workspaceId); + const documentService = new DocumentService(db, userId, workspaceId); + const skillModel = new AgentSkillModel(db, userId, workspaceId); const skillResourceService = new SkillResourceService(db, userId); return new SkillMount({ 'agent': new ProviderSkillsAgentDocument('agent', { diff --git a/src/server/services/agentDocuments/index.ts b/src/server/services/agentDocuments/index.ts index 7d3ebb766a..c7cc6d0078 100644 --- a/src/server/services/agentDocuments/index.ts +++ b/src/server/services/agentDocuments/index.ts @@ -126,10 +126,10 @@ export class AgentDocumentsService { private documentService: DocumentService; private topicDocumentModel: TopicDocumentModel; - constructor(db: LobeChatDatabase, userId: string) { - this.agentDocumentModel = new AgentDocumentModel(db, userId); - this.documentService = new DocumentService(db, userId); - this.topicDocumentModel = new TopicDocumentModel(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.agentDocumentModel = new AgentDocumentModel(db, userId, workspaceId); + this.documentService = new DocumentService(db, userId, workspaceId); + this.topicDocumentModel = new TopicDocumentModel(db, userId, workspaceId); } private async projectDocumentContent(doc: T): Promise; diff --git a/src/server/services/agentEvalRun/index.ts b/src/server/services/agentEvalRun/index.ts index dfc7b20ef0..a1b90802a7 100644 --- a/src/server/services/agentEvalRun/index.ts +++ b/src/server/services/agentEvalRun/index.ts @@ -86,18 +86,21 @@ export class AgentEvalRunService { private readonly topicModel: TopicModel; private readonly agentService: AgentService; - constructor(db: LobeChatDatabase, userId: string) { + private workspaceId?: string; + + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; - this.runModel = new AgentEvalRunModel(db, userId); - this.benchmarkModel = new AgentEvalBenchmarkModel(db, userId); - this.datasetModel = new AgentEvalDatasetModel(db, userId); - this.runTopicModel = new AgentEvalRunTopicModel(db, userId); - this.testCaseModel = new AgentEvalTestCaseModel(db, userId); - this.messageModel = new MessageModel(db, userId); - this.threadModel = new ThreadModel(db, userId); - this.topicModel = new TopicModel(db, userId); - this.agentService = new AgentService(db, userId); + this.workspaceId = workspaceId; + this.runModel = new AgentEvalRunModel(db, userId, workspaceId); + this.benchmarkModel = new AgentEvalBenchmarkModel(db, userId, workspaceId); + this.datasetModel = new AgentEvalDatasetModel(db, userId, workspaceId); + this.runTopicModel = new AgentEvalRunTopicModel(db, userId, workspaceId); + this.testCaseModel = new AgentEvalTestCaseModel(db, userId, workspaceId); + this.messageModel = new MessageModel(db, userId, workspaceId); + this.threadModel = new ThreadModel(db, userId, workspaceId); + this.topicModel = new TopicModel(db, userId, workspaceId); + this.agentService = new AgentService(db, userId, workspaceId); } async createRun(params: { @@ -531,7 +534,9 @@ export class AgentEvalRunService { await this.runModel.update(runId, { startedAt: now, status: 'running' }); - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const webhookUrl = '/api/workflows/agent-eval-run/on-trajectory-complete'; const userId = this.userId; const db = this.db; @@ -546,7 +551,7 @@ export class AgentEvalRunService { { handler: async (event) => { // Local mode: directly record completion - const service = new AgentEvalRunService(db, userId); + const service = new AgentEvalRunService(db, userId, this.workspaceId); await service.recordTrajectoryCompletion({ runId, status: event.status || event.reason || 'done', @@ -673,7 +678,9 @@ export class AgentEvalRunService { await this.runModel.update(runId, { startedAt: now, status: 'running' }); - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const webhookUrl = '/api/workflows/agent-eval-run/on-thread-complete'; const userId = this.userId; const db = this.db; @@ -688,7 +695,7 @@ export class AgentEvalRunService { { handler: async (event) => { // Local mode: directly record thread completion - const service = new AgentEvalRunService(db, userId); + const service = new AgentEvalRunService(db, userId, this.workspaceId); await service.recordThreadCompletion({ runId, status: event.status || event.reason || 'done', @@ -955,7 +962,9 @@ export class AgentEvalRunService { // Update status from 'pending' to 'running' await this.runTopicModel.updateByRunAndTopic(runId, topicId, { status: 'running' }); - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const webhookUrl = '/api/workflows/agent-eval-run/on-trajectory-complete'; const userId = this.userId; const db = this.db; @@ -970,7 +979,7 @@ export class AgentEvalRunService { { handler: async (event) => { // Local mode: directly record completion - const service = new AgentEvalRunService(db, userId); + const service = new AgentEvalRunService(db, userId, this.workspaceId); await service.recordTrajectoryCompletion({ runId, status: event.status || event.reason || 'done', @@ -1112,7 +1121,9 @@ export class AgentEvalRunService { }) { const { envPrompt, run, runId, testCaseId, threadId, topicId } = params; - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const webhookUrl = '/api/workflows/agent-eval-run/on-thread-complete'; const userId = this.userId; const db = this.db; @@ -1127,7 +1138,7 @@ export class AgentEvalRunService { { handler: async (event) => { // Local mode: directly record thread completion - const service = new AgentEvalRunService(db, userId); + const service = new AgentEvalRunService(db, userId, this.workspaceId); await service.recordThreadCompletion({ runId, status: event.status || event.reason || 'done', diff --git a/src/server/services/agentGroup/index.ts b/src/server/services/agentGroup/index.ts index ad564e3351..a4acb66211 100644 --- a/src/server/services/agentGroup/index.ts +++ b/src/server/services/agentGroup/index.ts @@ -24,10 +24,10 @@ export class AgentGroupService { private readonly chatGroupModel: ChatGroupModel; private readonly agentGroupRepo: AgentGroupRepository; - constructor(db: LobeChatDatabase, userId: string) { - this.agentModel = new AgentModel(db, userId); - this.chatGroupModel = new ChatGroupModel(db, userId); - this.agentGroupRepo = new AgentGroupRepository(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.agentModel = new AgentModel(db, userId, workspaceId); + this.chatGroupModel = new ChatGroupModel(db, userId, workspaceId); + this.agentGroupRepo = new AgentGroupRepository(db, userId, workspaceId); } /** diff --git a/src/server/services/agentRuntime/AbandonOperationService.ts b/src/server/services/agentRuntime/AbandonOperationService.ts index 3b7909ebb0..ac55cf825a 100644 --- a/src/server/services/agentRuntime/AbandonOperationService.ts +++ b/src/server/services/agentRuntime/AbandonOperationService.ts @@ -69,7 +69,11 @@ export class AbandonOperationService { } result.found = true; - const metadata = (state.metadata ?? {}) as { assistantMessageId?: string; userId?: string }; + const metadata = (state.metadata ?? {}) as { + assistantMessageId?: string; + userId?: string; + workspaceId?: string; + }; const message = `Operation abandoned: ${reason}`; const error: ChatMessageError = { body: { message }, @@ -105,7 +109,7 @@ export class AbandonOperationService { if (metadata.userId && metadata.assistantMessageId) { try { - const messageModel = new MessageModel(this.db, metadata.userId); + const messageModel = new MessageModel(this.db, metadata.userId, metadata.workspaceId); await messageModel.update(metadata.assistantMessageId, { error }); result.assistantMessageUpdated = true; } catch (e) { diff --git a/src/server/services/agentRuntime/AgentRuntimeService.ts b/src/server/services/agentRuntime/AgentRuntimeService.ts index 153d235b81..2ff3f69bdb 100644 --- a/src/server/services/agentRuntime/AgentRuntimeService.ts +++ b/src/server/services/agentRuntime/AgentRuntimeService.ts @@ -129,6 +129,11 @@ export interface AgentRuntimeServiceOptions { * Can pass InMemoryStreamEventManager in test environments */ streamEventManager?: IStreamEventManager; + /** + * Workspace id for scoping all DB reads/writes (messages, agent_operations). + * Falls back to user-personal scope when omitted. + */ + workspaceId?: string; } /** @@ -164,6 +169,7 @@ export class AgentRuntimeService { } private serverDB: LobeChatDatabase; private userId: string; + private workspaceId?: string; private messageModel: MessageModel; // Lazily constructed because MessageService instantiates a FileService // which eagerly creates the S3 client and throws when S3 env vars are @@ -172,7 +178,11 @@ export class AgentRuntimeService { private messageServiceInstance?: MessageService; private get messageService(): MessageService { if (!this.messageServiceInstance) { - this.messageServiceInstance = new MessageService(this.serverDB, this.userId); + this.messageServiceInstance = new MessageService( + this.serverDB, + this.userId, + this.workspaceId, + ); } return this.messageServiceInstance; } @@ -200,8 +210,10 @@ export class AgentRuntimeService { this.execSubAgentTaskCallback = options?.execSubAgentTask; this.serverDB = db; this.userId = userId; - this.messageModel = new MessageModel(db, this.userId); - this.completionLifecycle = new CompletionLifecycle(db, userId); + this.workspaceId = options?.workspaceId; + const workspaceId = this.workspaceId; + this.messageModel = new MessageModel(db, this.userId, workspaceId); + this.completionLifecycle = new CompletionLifecycle(db, userId, workspaceId); this.humanIntervention = new HumanInterventionHandler(db, this.messageModel); // Initialize ToolExecutionService with dependencies @@ -298,6 +310,7 @@ export class AgentRuntimeService { signal, userTimezone, initialStepCount = 0, + workspaceId, } = params; // Persist initial agent_operations row. CompletionLifecycle owns both @@ -379,6 +392,7 @@ export class AgentRuntimeService { userMemory, userTimezone, workingDirectory: agentConfig?.chatConfig?.runtimeEnv?.workingDirectory, + workspaceId, ...appContext, }, maxSteps, @@ -402,6 +416,7 @@ export class AgentRuntimeService { agentConfig, modelRuntimeConfig, userId, + workspaceId: this.workspaceId, }); operationCreated = true; @@ -673,6 +688,7 @@ export class AgentRuntimeService { agentId: beforeStepMetadata?.agentId, db: this.serverDB, userId: beforeStepMetadata?.userId || this.userId, + workspaceId: this.workspaceId, }, { ignoreError: true }, ); @@ -1640,6 +1656,7 @@ export class AgentRuntimeService { // Create streaming executor context const executorContext: RuntimeExecutorContext = { agentConfig: metadata?.agentConfig, + botContext: metadata?.botContext, botPlatformContext: metadata?.botPlatformContext, discordContext: metadata?.discordContext, userTimezone: metadata?.userTimezone, @@ -1657,6 +1674,7 @@ export class AgentRuntimeService { topicId: metadata?.topicId, tracingContextEngine, userId: metadata?.userId, + workspaceId: this.workspaceId, }; // Create Agent Runtime instance diff --git a/src/server/services/agentRuntime/CompletionLifecycle.ts b/src/server/services/agentRuntime/CompletionLifecycle.ts index 84376134d6..86440b5074 100644 --- a/src/server/services/agentRuntime/CompletionLifecycle.ts +++ b/src/server/services/agentRuntime/CompletionLifecycle.ts @@ -44,13 +44,16 @@ const toAgentSignalSnapshotEvents = ( export class CompletionLifecycle { private readonly messageModel: MessageModel; private readonly agentOperationModel: AgentOperationModel; + private readonly workspaceId?: string; constructor( private readonly serverDB: LobeChatDatabase, private readonly userId: string, + workspaceId?: string, ) { - this.messageModel = new MessageModel(serverDB, userId); - this.agentOperationModel = new AgentOperationModel(serverDB, userId); + this.workspaceId = workspaceId; + this.messageModel = new MessageModel(serverDB, userId, workspaceId); + this.agentOperationModel = new AgentOperationModel(serverDB, userId, workspaceId); } /** @@ -218,6 +221,7 @@ export class CompletionLifecycle { agentId: metadata?.agentId, db: this.serverDB, userId: metadata?.userId || this.userId, + workspaceId: this.workspaceId, }, { ignoreError: true }, ) @@ -242,6 +246,7 @@ export class CompletionLifecycle { agentId: metadata?.agentId, db: this.serverDB, userId: metadata?.userId || this.userId, + workspaceId: this.workspaceId, }, { ignoreError: true }, ); diff --git a/src/server/services/agentRuntime/__tests__/executeStep.test.ts b/src/server/services/agentRuntime/__tests__/executeStep.test.ts index 8cae6e8883..256a9f8db0 100644 --- a/src/server/services/agentRuntime/__tests__/executeStep.test.ts +++ b/src/server/services/agentRuntime/__tests__/executeStep.test.ts @@ -1,6 +1,8 @@ // @vitest-environment node import { describe, expect, it, vi } from 'vitest'; +import { createRuntimeExecutors } from '@/server/modules/AgentRuntime/RuntimeExecutors'; + import { AgentRuntimeService } from '../AgentRuntimeService'; import { hookDispatcher } from '../hooks'; @@ -170,6 +172,27 @@ describe('AgentRuntimeService.executeStep - early exit on terminal state', () => unregisterSpy.mockRestore(); }); + it('threads workspaceId into runtime executors for workspace-scoped agent runs', async () => { + const service = new AgentRuntimeService({} as any, 'user-1', { + queueService: null, + workspaceId: 'ws-1', + }); + + await (service as any).createAgentRuntime({ + metadata: { + agentConfig: {}, + modelRuntimeConfig: { model: 'gpt-test', provider: 'lobehub' }, + userId: 'user-1', + }, + operationId: 'op-workspace', + stepIndex: 0, + }); + + expect(createRuntimeExecutors).toHaveBeenCalledWith( + expect.objectContaining({ workspaceId: 'ws-1' }), + ); + }); + it('should NOT skip step when operation status is "running"', async () => { const service = createService(); diff --git a/src/server/services/agentRuntime/types.ts b/src/server/services/agentRuntime/types.ts index df217b9037..7d09a95918 100644 --- a/src/server/services/agentRuntime/types.ts +++ b/src/server/services/agentRuntime/types.ts @@ -239,6 +239,12 @@ export interface OperationCreationParams { userMemory?: ServerUserMemoryConfig; /** User's timezone from settings (e.g. 'Asia/Shanghai') */ userTimezone?: string; + /** + * Workspace ID propagated down from the originating chat/task router so + * tool executions (createBrief / pinTask / etc.) ownership-filter to the + * caller's workspace. Stored on `state.metadata.workspaceId`. + */ + workspaceId?: string; } export interface OperationCreationResult { diff --git a/src/server/services/agentSignal/__tests__/index.integration.test.ts b/src/server/services/agentSignal/__tests__/index.integration.test.ts index 09ae1d5e89..37b68cfb4d 100644 --- a/src/server/services/agentSignal/__tests__/index.integration.test.ts +++ b/src/server/services/agentSignal/__tests__/index.integration.test.ts @@ -4,6 +4,7 @@ import { afterEach, describe, expect, it, vi } from 'vitest'; interface LoadIndexIntegrationModuleOptions { featureGateEnabled?: boolean; + mockCreateDefaultAgentSignalPolicies?: Mock; mockEmitSourceEvent?: Mock; mockInitModelRuntimeFromDB?: Mock; mockProjectObservability?: Mock; @@ -37,6 +38,7 @@ const loadIndexIntegrationModule = async (options: LoadIndexIntegrationModuleOpt vi.doUnmock('../observability/projector'); vi.doUnmock('../observability/store'); vi.doUnmock('../orchestrator'); + vi.doUnmock('../policies'); vi.doUnmock('../runtime/AgentSignalRuntime'); vi.doUnmock('../sources'); vi.doUnmock('@/server/services/agentDocuments'); @@ -76,6 +78,12 @@ const loadIndexIntegrationModule = async (options: LoadIndexIntegrationModuleOpt })); } + if (options.mockCreateDefaultAgentSignalPolicies) { + vi.doMock('../policies', () => ({ + createDefaultAgentSignalPolicies: options.mockCreateDefaultAgentSignalPolicies, + })); + } + const module = await import('../index'); return { @@ -216,6 +224,72 @@ describe('emitAgentSignalSourceEvent integration', () => { expect(mocks.persistAgentSignalObservability).toHaveBeenCalledTimes(1); }); + it('threads workspaceId into default policy options for workspace-scoped skill management', async () => { + const emitSourceEvent = vi.fn().mockResolvedValue({ + deduped: false, + source: { + chain: { chainId: 'chain:workspace', rootSourceId: 'workspace-source' }, + payload: {}, + scopeKey: 'topic:topic-1', + sourceId: 'workspace-source', + sourceType: 'agent.user.message', + timestamp: 1_710_000_000_000, + }, + trigger: { + scopeKey: 'topic:topic-1', + token: 'trigger:workspace-source', + windowEventCount: 1, + }, + }); + const createDefaultAgentSignalPolicies = vi.fn().mockReturnValue([]); + const emitNormalized = vi.fn().mockResolvedValue({ + status: 'completed', + trace: { + actions: [], + results: [], + signals: [], + source: { + chain: { chainId: 'chain:workspace', rootSourceId: 'workspace-source' }, + payload: {}, + scopeKey: 'topic:topic-1', + sourceId: 'workspace-source', + sourceType: 'agent.user.message', + timestamp: 1_710_000_000_000, + }, + }, + }); + + const { emitAgentSignalSourceEvent } = await loadIndexIntegrationModule({ + mockCreateDefaultAgentSignalPolicies: createDefaultAgentSignalPolicies, + mockEmitSourceEvent: emitSourceEvent, + mockRuntimeFactory: vi.fn().mockReturnValue({ emitNormalized }), + }); + + await emitAgentSignalSourceEvent( + { + payload: { message: 'Create a reusable skill.', messageId: 'msg-workspace-skill' }, + scopeKey: 'topic:topic-1', + sourceId: 'workspace-source', + sourceType: 'agent.user.message', + timestamp: 1_710_000_000_000, + }, + { + agentId: 'agent-1', + db: {} as never, + userId: 'user-1', + workspaceId: 'ws-1', + }, + ); + + expect(createDefaultAgentSignalPolicies).toHaveBeenCalledWith( + expect.objectContaining({ + skillManagement: expect.objectContaining({ + workspaceId: 'ws-1', + }), + }), + ); + }); + it( 'projects and persists observability for the real source-to-runtime path', { timeout: 10_000 }, diff --git a/src/server/services/agentSignal/emitter.ts b/src/server/services/agentSignal/emitter.ts index d0df012b2f..2ef3baf62a 100644 --- a/src/server/services/agentSignal/emitter.ts +++ b/src/server/services/agentSignal/emitter.ts @@ -27,6 +27,12 @@ export interface AgentSignalExecutionContext { agentId?: string; db: LobeChatDatabase; userId: string; + /** + * Workspace id when the originating producer ran inside a team workspace. + * Threaded through to action handlers so workspace-scoped writes (e.g. + * `userMemories`) can target the correct workspace. + */ + workspaceId?: string; } type RuntimeProducerSourceType = @@ -132,7 +138,7 @@ export const emitAgentSignalSourceEvent = async ( input: AgentSignalSourceEventInput, - context: Pick, + context: Pick, ): Promise => { const db = await getServerDB(); @@ -160,6 +166,7 @@ export const enqueueAgentSignalSourceEvent = async { ctx, ); - expect(initModelRuntimeFromDB).toHaveBeenCalledWith({} as LobeChatDatabase, 'user_1', 'openai'); + expect(initModelRuntimeFromDB).toHaveBeenCalledWith( + {} as LobeChatDatabase, + 'user_1', + 'openai', + undefined, + ); expect(mockGenerateObject).toHaveBeenCalledWith( expect.objectContaining({ messages: [ diff --git a/src/server/services/agentSignal/policies/analyzeIntent/__tests__/skillIntent.test.ts b/src/server/services/agentSignal/policies/analyzeIntent/__tests__/skillIntent.test.ts index 281a46e56b..b9a3654326 100644 --- a/src/server/services/agentSignal/policies/analyzeIntent/__tests__/skillIntent.test.ts +++ b/src/server/services/agentSignal/policies/analyzeIntent/__tests__/skillIntent.test.ts @@ -296,7 +296,12 @@ describe('skillIntent classifier', () => { topicLabel: 'login-debugging', }); - expect(initModelRuntimeFromDB).toHaveBeenCalledWith({} as LobeChatDatabase, 'user_1', 'openai'); + expect(initModelRuntimeFromDB).toHaveBeenCalledWith( + {} as LobeChatDatabase, + 'user_1', + 'openai', + undefined, + ); expect(mockGenerateObject).toHaveBeenCalledWith( expect.objectContaining({ messages: [ diff --git a/src/server/services/agentSignal/policies/analyzeIntent/actions/__tests__/skillManagement.test.ts b/src/server/services/agentSignal/policies/analyzeIntent/actions/__tests__/skillManagement.test.ts index 26d571bec7..167130fc9a 100644 --- a/src/server/services/agentSignal/policies/analyzeIntent/actions/__tests__/skillManagement.test.ts +++ b/src/server/services/agentSignal/policies/analyzeIntent/actions/__tests__/skillManagement.test.ts @@ -71,6 +71,23 @@ describe('defineSkillManagementActionHandler', () => { expect(context.runtimeState.touchGuardState).toHaveBeenCalledTimes(1); }); + it('forwards the workspaceId onto the dispatched run so the write stays workspace-scoped', async () => { + dispatch.mockResolvedValue({ operationId: 'op_1', topicId: 'topic_1' }); + + const handler = defineSkillManagementActionHandler({ + db: {} as never, + dispatch, + selfIterationEnabled: true, + userId: 'user_1', + workspaceId: 'ws_1', + }); + + await handler.handle(skillAction, createContext()); + + expect(dispatch).toHaveBeenCalledTimes(1); + expect(dispatch.mock.calls[0][0].workspaceId).toBe('ws_1'); + }); + it('skips when self-iteration is disabled (no dispatch)', async () => { const handler = defineSkillManagementActionHandler({ db: {} as never, diff --git a/src/server/services/agentSignal/policies/analyzeIntent/actions/skillManagement.ts b/src/server/services/agentSignal/policies/analyzeIntent/actions/skillManagement.ts index 862d8c5746..8f11818b92 100644 --- a/src/server/services/agentSignal/policies/analyzeIntent/actions/skillManagement.ts +++ b/src/server/services/agentSignal/policies/analyzeIntent/actions/skillManagement.ts @@ -30,6 +30,8 @@ export interface SkillManagementActionHandlerOptions { responseLanguage?: string; selfIterationEnabled: boolean; userId: string; + /** Workspace id when the run belongs to a team workspace; scopes the skill write. */ + workspaceId?: string; } const finalizeAttempt = ( @@ -185,6 +187,7 @@ export const executeSkillManagementAction = async ( threadTitle: 'Agent Signal Skill', ...(topicId ? { topicId } : {}), userId: options.userId, + ...(options.workspaceId ? { workspaceId: options.workspaceId } : {}), }); await markAppliedActionIdempotency(context, idempotencyKey); diff --git a/src/server/services/agentSignal/policies/analyzeIntent/actions/userMemory.ts b/src/server/services/agentSignal/policies/analyzeIntent/actions/userMemory.ts index 2510238751..918267baf7 100644 --- a/src/server/services/agentSignal/policies/analyzeIntent/actions/userMemory.ts +++ b/src/server/services/agentSignal/policies/analyzeIntent/actions/userMemory.ts @@ -75,6 +75,7 @@ export interface UserMemoryActionHandlerOptions { }) => Promise; pluginModel?: Pick; userId: string; + workspaceId?: string; } const finalizeAttempt = ( @@ -174,8 +175,10 @@ export const runMemoryActionAgent = async ( }; } - const agentService = options.agentService ?? new AgentService(options.db, options.userId); - const pluginModel = options.pluginModel ?? new PluginModel(options.db, options.userId); + const agentService = + options.agentService ?? new AgentService(options.db, options.userId, options.workspaceId); + const pluginModel = + options.pluginModel ?? new PluginModel(options.db, options.userId, options.workspaceId); const agentConfig = await agentService.getAgentConfig(input.agentId); const memoryLanguage = input.memoryLanguage ?? 'English'; @@ -240,7 +243,7 @@ export const runMemoryActionAgent = async ( let threadId: string | undefined; if (input.topicId && input.sourceMessageId) { try { - const threadModel = new ThreadModel(options.db, options.userId); + const threadModel = new ThreadModel(options.db, options.userId, options.workspaceId); const thread = await threadModel.create({ agentId: input.agentId, metadata: { operationId }, @@ -292,7 +295,9 @@ export const runMemoryActionAgent = async ( // The durable receipt is projected on the completion path from the run's // finalState — no blocking executeSync. if (dispatch) { - const runtimeService = new AgentRuntimeService(options.db, options.userId); + const runtimeService = new AgentRuntimeService(options.db, options.userId, { + workspaceId: options.workspaceId, + }); await runtimeService.createOperation({ ...createParams, appContext: { ...baseAppContext, agentSignal: dispatch.marker }, @@ -312,6 +317,7 @@ export const runMemoryActionAgent = async ( }, queueService: null, streamEventManager, + workspaceId: options.workspaceId, }); await runtimeService.createOperation({ ...createParams, diff --git a/src/server/services/agentSignal/policies/analyzeIntent/feedbackDomain.ts b/src/server/services/agentSignal/policies/analyzeIntent/feedbackDomain.ts index 546ae9deb4..f588c7fe8c 100644 --- a/src/server/services/agentSignal/policies/analyzeIntent/feedbackDomain.ts +++ b/src/server/services/agentSignal/policies/analyzeIntent/feedbackDomain.ts @@ -48,10 +48,12 @@ export interface CreateFeedbackDomainJudgePolicyOptions { feedbackDomainJudge?: Partial & { db: LobeChatDatabase; userId: string; + workspaceId?: string; }; skillIntentClassifier?: Partial & { db: LobeChatDatabase; userId: string; + workspaceId?: string; }; } @@ -67,6 +69,7 @@ const createDomainResolver = ( runtimeDeps.db, runtimeDeps.userId, runtimeDeps, + runtimeDeps.workspaceId, ); return ( @@ -88,7 +91,12 @@ export const createSkillIntentClassifier = ( if (!runtimeDeps) return undefined; - return new SkillIntentClassifierAgentService(runtimeDeps.db, runtimeDeps.userId, runtimeDeps); + return new SkillIntentClassifierAgentService( + runtimeDeps.db, + runtimeDeps.userId, + runtimeDeps, + runtimeDeps.workspaceId, + ); }; /** diff --git a/src/server/services/agentSignal/policies/analyzeIntent/feedbackDomainAgent.ts b/src/server/services/agentSignal/policies/analyzeIntent/feedbackDomainAgent.ts index b693c5ca6e..7471451dd8 100644 --- a/src/server/services/agentSignal/policies/analyzeIntent/feedbackDomainAgent.ts +++ b/src/server/services/agentSignal/policies/analyzeIntent/feedbackDomainAgent.ts @@ -157,14 +157,17 @@ export class FeedbackDomainJudgeAgentService { private readonly db: LobeChatDatabase; private readonly modelConfig: FeedbackDomainJudgeAgentModelConfig; private readonly userId: string; + private readonly workspaceId?: string; constructor( db: LobeChatDatabase, userId: string, modelConfig: Partial = {}, + workspaceId?: string, ) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; this.modelConfig = { model: modelConfig.model ?? DEFAULT_MINI_SYSTEM_AGENT_ITEM.model, provider: modelConfig.provider ?? DEFAULT_MINI_SYSTEM_AGENT_ITEM.provider, @@ -190,6 +193,7 @@ export class FeedbackDomainJudgeAgentService { this.db, this.userId, this.modelConfig.provider, + this.workspaceId, ); log('judgeDomains model=%s provider=%s', this.modelConfig.model, this.modelConfig.provider); diff --git a/src/server/services/agentSignal/policies/analyzeIntent/feedbackSatisfaction.ts b/src/server/services/agentSignal/policies/analyzeIntent/feedbackSatisfaction.ts index 8e30f28774..0f90092c64 100644 --- a/src/server/services/agentSignal/policies/analyzeIntent/feedbackSatisfaction.ts +++ b/src/server/services/agentSignal/policies/analyzeIntent/feedbackSatisfaction.ts @@ -153,6 +153,7 @@ export interface CreateFeedbackSatisfactionJudgePolicyOptions { model?: string; provider?: string; userId?: string; + workspaceId?: string; } /** @@ -172,14 +173,17 @@ export class FeedbackSatisfactionJudgeAgentService implements FeedbackSatisfacti private readonly db: LobeChatDatabase; private readonly modelConfig: FeedbackSatisfactionJudgeAgentModelConfig; private readonly userId: string; + private readonly workspaceId?: string; constructor( db: LobeChatDatabase, userId: string, modelConfig: Partial = {}, + workspaceId?: string, ) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; this.modelConfig = { model: modelConfig.model ?? DEFAULT_MINI_SYSTEM_AGENT_ITEM.model, provider: modelConfig.provider ?? DEFAULT_MINI_SYSTEM_AGENT_ITEM.provider, @@ -206,6 +210,7 @@ export class FeedbackSatisfactionJudgeAgentService implements FeedbackSatisfacti this.db, this.userId, this.modelConfig.provider, + this.workspaceId, ); log( @@ -240,10 +245,15 @@ const resolveJudge = ( ); } - return new FeedbackSatisfactionJudgeAgentService(options.db, options.userId, { - model: options.model, - provider: options.provider, - }); + return new FeedbackSatisfactionJudgeAgentService( + options.db, + options.userId, + { + model: options.model, + provider: options.provider, + }, + options.workspaceId, + ); }; /** diff --git a/src/server/services/agentSignal/policies/analyzeIntent/skillIntent.ts b/src/server/services/agentSignal/policies/analyzeIntent/skillIntent.ts index 5c06984d8e..be0afded49 100644 --- a/src/server/services/agentSignal/policies/analyzeIntent/skillIntent.ts +++ b/src/server/services/agentSignal/policies/analyzeIntent/skillIntent.ts @@ -273,14 +273,17 @@ export class SkillIntentClassifierAgentService implements SkillIntentClassifierS private readonly db: LobeChatDatabase; private readonly modelConfig: SkillIntentClassifierAgentModelConfig; private readonly userId: string; + private readonly workspaceId?: string; constructor( db: LobeChatDatabase, userId: string, modelConfig: Partial = {}, + workspaceId?: string, ) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; this.modelConfig = { model: modelConfig.model ?? DEFAULT_MINI_SYSTEM_AGENT_ITEM.model, provider: modelConfig.provider ?? DEFAULT_MINI_SYSTEM_AGENT_ITEM.provider, @@ -305,6 +308,7 @@ export class SkillIntentClassifierAgentService implements SkillIntentClassifierS this.db, this.userId, this.modelConfig.provider, + this.workspaceId, ); log( diff --git a/src/server/services/agentSignal/services/briefs/selfReview.test.ts b/src/server/services/agentSignal/services/briefs/selfReview.test.ts index 6c698196ee..b05a5a3876 100644 --- a/src/server/services/agentSignal/services/briefs/selfReview.test.ts +++ b/src/server/services/agentSignal/services/briefs/selfReview.test.ts @@ -134,7 +134,7 @@ describe('AgentSignalSelfReviewBriefService', () => { const brief = (await mockBriefModel.findById(id)) as BriefItem | null; if (!brief) return null; - return new AgentSignalSelfReviewBriefService(db, userId, { + return new AgentSignalSelfReviewBriefService(db, userId, undefined, { selfReviewProposalResolver, }).resolve(brief, options); }; diff --git a/src/server/services/agentSignal/services/briefs/selfReview.ts b/src/server/services/agentSignal/services/briefs/selfReview.ts index 2d0cc38980..c93e3839b7 100644 --- a/src/server/services/agentSignal/services/briefs/selfReview.ts +++ b/src/server/services/agentSignal/services/briefs/selfReview.ts @@ -131,16 +131,19 @@ export class AgentSignalSelfReviewBriefService { private db: LobeChatDatabase; private selfReviewProposalResolver?: AgentSignalSelfReviewBriefServiceOptions['selfReviewProposalResolver']; private userId: string; + private workspaceId?: string; constructor( db: LobeChatDatabase, userId: string, + workspaceId: string | undefined, options: AgentSignalSelfReviewBriefServiceOptions = {}, ) { this.db = db; this.userId = userId; - this.briefService = new BriefService(db, userId); - this.briefModel = new BriefModel(db, userId); + this.workspaceId = workspaceId; + this.briefService = new BriefService(db, userId, workspaceId); + this.briefModel = new BriefModel(db, userId, workspaceId); this.selfReviewProposalResolver = options.selfReviewProposalResolver; } @@ -211,7 +214,11 @@ export class AgentSignalSelfReviewBriefService { typeof metadata?.sourceId === 'string' ? metadata.sourceId : `self-review-proposal-approve:${brief.id}`; - const skillDocumentService = new SkillManagementDocumentService(this.db, this.userId); + const skillDocumentService = new SkillManagementDocumentService( + this.db, + this.userId, + this.workspaceId, + ); const preflight = createSelfReviewProposalPreflightService({ isSkillNameAvailable: async ({ name }) => { const skill = await skillDocumentService.getSkill({ @@ -375,9 +382,11 @@ export class AgentSignalSelfReviewBriefService { checkGates: () => briefSelfReview.canApplySelfReviewProposal({ checkAgentGate: () => - new AgentSignalReviewContextModel(this.db, this.userId).canAgentRunSelfIteration( - brief.agentId ?? '', - ), + new AgentSignalReviewContextModel( + this.db, + this.userId, + this.workspaceId, + ).canAgentRunSelfIteration(brief.agentId ?? ''), checkServerGate: () => true, checkUserGate: () => isAgentSignalEnabledForUser(this.db, this.userId), }), diff --git a/src/server/services/agentSignal/services/selfIteration/dispatch/enqueueSelfIterationRun.ts b/src/server/services/agentSignal/services/selfIteration/dispatch/enqueueSelfIterationRun.ts index 5544d0f17d..a96fd98f61 100644 --- a/src/server/services/agentSignal/services/selfIteration/dispatch/enqueueSelfIterationRun.ts +++ b/src/server/services/agentSignal/services/selfIteration/dispatch/enqueueSelfIterationRun.ts @@ -34,6 +34,8 @@ export interface EnqueueSelfIterationRunInput { /** Topic the run is scoped to; a new topic is created when absent. */ topicId?: string; userId: string; + /** Workspace id when the run belongs to a team workspace; scopes the operation. */ + workspaceId?: string; } export interface EnqueueSelfIterationRunResult { @@ -66,7 +68,7 @@ export const enqueueSelfIterationRun = async ( let threadId: string | undefined; if (input.topicId && input.sourceMessageId) { try { - const thread = await new ThreadModel(input.db, input.userId).create({ + const thread = await new ThreadModel(input.db, input.userId, input.workspaceId).create({ agentId: input.agentId, sourceMessageId: input.sourceMessageId, title: input.threadTitle ?? 'Agent Signal Self-Iteration', @@ -80,7 +82,9 @@ export const enqueueSelfIterationRun = async ( } const { AiAgentService } = await import('@/server/services/aiAgent'); - const result = await new AiAgentService(input.db, input.userId).execAgent({ + const result = await new AiAgentService(input.db, input.userId, { + workspaceId: input.workspaceId, + }).execAgent({ appContext: { // No agentId here — the run executes under the builtin `slug` (which // supplies its tools / systemRole / model). The reviewed user agent diff --git a/src/server/services/agentSignal/services/selfIteration/feedback/handler.ts b/src/server/services/agentSignal/services/selfIteration/feedback/handler.ts index 1ba04a2d12..f2c57f09e8 100644 --- a/src/server/services/agentSignal/services/selfIteration/feedback/handler.ts +++ b/src/server/services/agentSignal/services/selfIteration/feedback/handler.ts @@ -153,6 +153,8 @@ export interface CreateSelfFeedbackIntentSourceHandlerDependencies { * @default builtin agent default */ maxSteps?: number; + /** Workspace id, threaded so the enqueued run targets the correct workspace. */ + workspaceId?: string; } const isValidConfidence = (value: unknown): value is number => @@ -404,6 +406,7 @@ export const createSelfFeedbackIntentSourceHandler = ( slug: BUILTIN_AGENT_SLUGS.selfFeedbackIntent, ...(payload.topicId ? { topicId: payload.topicId } : {}), userId: payload.userId, + workspaceId: deps.workspaceId, }); return { diff --git a/src/server/services/agentSignal/services/selfIteration/feedback/server.ts b/src/server/services/agentSignal/services/selfIteration/feedback/server.ts index d7d36d16eb..2c599d8c30 100644 --- a/src/server/services/agentSignal/services/selfIteration/feedback/server.ts +++ b/src/server/services/agentSignal/services/selfIteration/feedback/server.ts @@ -34,8 +34,9 @@ export const createServerSelfFeedbackIntentPolicyOptions = ({ db, selfIterationEnabled = false, userId, + workspaceId, }: CreateServerSelfIterationPolicyOptions): CreateSelfFeedbackIntentSourceHandlerDependencies => { - const reviewContextModel = new AgentSignalReviewContextModel(db, userId); + const reviewContextModel = new AgentSignalReviewContextModel(db, userId, workspaceId); return { acquireReviewGuard: (input) => @@ -62,5 +63,6 @@ export const createServerSelfFeedbackIntentPolicyOptions = ({ }, ], }), + workspaceId, }; }; diff --git a/src/server/services/agentSignal/services/selfIteration/reflection/handler.ts b/src/server/services/agentSignal/services/selfIteration/reflection/handler.ts index d75951083e..8dca4bfec1 100644 --- a/src/server/services/agentSignal/services/selfIteration/reflection/handler.ts +++ b/src/server/services/agentSignal/services/selfIteration/reflection/handler.ts @@ -131,6 +131,8 @@ export interface CreateSelfReflectionSourceHandlerDependencies { * @default builtin agent default */ maxSteps?: number; + /** Workspace id, threaded so the enqueued run targets the correct workspace. */ + workspaceId?: string; } const isSelfReflectionScopeType = (value: unknown): value is SelfReflectionSourceScopeType => @@ -306,6 +308,7 @@ export const createSelfReflectionSourceHandler = ( slug: BUILTIN_AGENT_SLUGS.selfReflection, ...(payload.topicId ? { topicId: payload.topicId } : {}), userId: payload.userId, + workspaceId: deps.workspaceId, }); return { diff --git a/src/server/services/agentSignal/services/selfIteration/reflection/server.ts b/src/server/services/agentSignal/services/selfIteration/reflection/server.ts index 7534c400a4..f22bb79241 100644 --- a/src/server/services/agentSignal/services/selfIteration/reflection/server.ts +++ b/src/server/services/agentSignal/services/selfIteration/reflection/server.ts @@ -82,8 +82,9 @@ export const createServerSelfReflectionPolicyOptions = ({ db, selfIterationEnabled = false, userId, + workspaceId, }: CreateServerSelfIterationPolicyOptions): CreateSelfReflectionSourceHandlerDependencies => { - const reviewContextModel = new AgentSignalReviewContextModel(db, userId); + const reviewContextModel = new AgentSignalReviewContextModel(db, userId, workspaceId); return { acquireReviewGuard: (input) => @@ -103,6 +104,7 @@ export const createServerSelfReflectionPolicyOptions = ({ }, collectContext: (input) => collectSelfReflectionContext(reviewContextModel, input), db, + workspaceId, }; }; @@ -132,8 +134,9 @@ export const createServerProcedurePolicyOptions = ({ db, selfIterationEnabled = false, userId, + workspaceId, }: CreateServerSelfIterationPolicyOptions) => { - const reviewContextModel = new AgentSignalReviewContextModel(db, userId); + const reviewContextModel = new AgentSignalReviewContextModel(db, userId, workspaceId); return createProcedurePolicyOptions({ policyStateStore: redisPolicyStateStore, @@ -162,6 +165,7 @@ export const createServerProcedurePolicyOptions = ({ return enqueueAgentSignalSourceEvent(event, { agentId, userId, + workspaceId, }); }, }), diff --git a/src/server/services/agentSignal/services/selfIteration/review/brief.ts b/src/server/services/agentSignal/services/selfIteration/review/brief.ts index 449703befe..de7b7ae6f1 100644 --- a/src/server/services/agentSignal/services/selfIteration/review/brief.ts +++ b/src/server/services/agentSignal/services/selfIteration/review/brief.ts @@ -557,8 +557,12 @@ export const createBriefSelfReviewService = () => ({ * Returns: * - A writer whose `writeDailyBrief` method creates or refreshes proposal briefs */ -export const createServerSelfReviewBriefWriter = (db: LobeChatDatabase, userId: string) => { - const model = new BriefModel(db, userId); +export const createServerSelfReviewBriefWriter = ( + db: LobeChatDatabase, + userId: string, + workspaceId?: string, +) => { + const model = new BriefModel(db, userId, workspaceId); return { writeDailyBrief: (brief: SelfReviewBriefProjection) => { diff --git a/src/server/services/agentSignal/services/selfIteration/review/handler.ts b/src/server/services/agentSignal/services/selfIteration/review/handler.ts index 92f1b1f3b3..a8cd3315b3 100644 --- a/src/server/services/agentSignal/services/selfIteration/review/handler.ts +++ b/src/server/services/agentSignal/services/selfIteration/review/handler.ts @@ -89,6 +89,8 @@ export interface CreateNightlyReviewSourceHandlerDependencies { * @default builtin agent default */ maxSteps?: number; + /** Workspace id, threaded so the enqueued run targets the correct workspace. */ + workspaceId?: string; } interface NightlyReviewSpanLike { @@ -342,6 +344,7 @@ export const createNightlyReviewSourceHandler = ( prompt, slug: BUILTIN_AGENT_SLUGS.nightlyReview, userId: payload.userId, + workspaceId: deps.workspaceId, }); }, (span, result) => { diff --git a/src/server/services/agentSignal/services/selfIteration/review/server.ts b/src/server/services/agentSignal/services/selfIteration/review/server.ts index 1075ada131..616b6454c4 100644 --- a/src/server/services/agentSignal/services/selfIteration/review/server.ts +++ b/src/server/services/agentSignal/services/selfIteration/review/server.ts @@ -412,6 +412,7 @@ export interface ReviewRuntimePrimitiveDeps { skillDocumentService: SkillManagementDocumentService; sourceId: string; userId: string; + workspaceId?: string; } /** @@ -440,6 +441,7 @@ export const createReviewRuntimePrimitives = ( skillDocumentService, sourceId, userId, + workspaceId, } = deps; const isSkillNameAvailable = async ({ @@ -681,7 +683,7 @@ export const createReviewRuntimePrimitives = ( message: content, reason: `Agent Signal self-review memory candidate from ${evidenceRefs.length} evidence refs.`, }, - { db, userId }, + { db, userId, workspaceId }, ); if (result.status !== 'applied') { @@ -739,11 +741,12 @@ export const createServerSelfReviewPolicyOptions = ({ db, selfIterationEnabled = false, userId, + workspaceId, }: CreateServerSelfIterationPolicyOptions): CreateNightlyReviewSourceHandlerDependencies => { const nightlyReviewModel = new AgentSignalNightlyReviewModel(db); - const reviewContextModel = new AgentSignalReviewContextModel(db, userId); - const briefModel = new BriefModel(db, userId); - const skillDocumentService = new SkillManagementDocumentService(db, userId); + const reviewContextModel = new AgentSignalReviewContextModel(db, userId, workspaceId); + const briefModel = new BriefModel(db, userId, workspaceId); + const skillDocumentService = new SkillManagementDocumentService(db, userId, workspaceId); const collector = createSelfReviewContextService({ listDocumentActivity: async ({ agentId: targetAgentId, reviewWindowEnd, reviewWindowStart }) => tracer.startActiveSpan( @@ -972,5 +975,6 @@ export const createServerSelfReviewPolicyOptions = ({ }, collectContext: (input) => collector.collect(input), db, + workspaceId, }; }; diff --git a/src/server/services/agentSignal/services/selfIteration/server.ts b/src/server/services/agentSignal/services/selfIteration/server.ts index 5fa4ca4b3f..d38a601ffc 100644 --- a/src/server/services/agentSignal/services/selfIteration/server.ts +++ b/src/server/services/agentSignal/services/selfIteration/server.ts @@ -39,6 +39,8 @@ export interface CreateServerSelfIterationPolicyOptions { selfIterationEnabled?: boolean; /** User id from the workflow payload. */ userId: string; + /** Workspace id from the workflow payload, when running inside a team workspace. */ + workspaceId?: string; } /** diff --git a/src/server/services/agentSignal/services/selfIteration/tools/runtimePrimitives.ts b/src/server/services/agentSignal/services/selfIteration/tools/runtimePrimitives.ts index e9769af3f5..59890faaff 100644 --- a/src/server/services/agentSignal/services/selfIteration/tools/runtimePrimitives.ts +++ b/src/server/services/agentSignal/services/selfIteration/tools/runtimePrimitives.ts @@ -21,6 +21,8 @@ export interface ResourceRuntimePrimitiveDeps { memoryReason: (evidenceCount: number) => string; skillDocumentService: SkillManagementDocumentService; userId: string; + /** Workspace id, so memory-write operations target the correct workspace. */ + workspaceId?: string; } /** @@ -39,6 +41,7 @@ export const createResourceRuntimePrimitives = ({ memoryReason, skillDocumentService, userId, + workspaceId, }: ResourceRuntimePrimitiveDeps): AgentSignalRuntimeService => { const isSkillNameAvailable = async ({ agentId: targetAgentId, @@ -178,7 +181,7 @@ export const createResourceRuntimePrimitives = ({ writeMemory: async ({ content, evidenceRefs, idempotencyKey }) => { const result = await runMemoryActionAgent( { agentId, message: content, reason: memoryReason(evidenceRefs.length) }, - { db, userId }, + { db, userId, workspaceId }, ); if (result.status !== 'applied') { diff --git a/src/server/services/agentSignal/sources/hydration/clientRuntimeComplete.ts b/src/server/services/agentSignal/sources/hydration/clientRuntimeComplete.ts index 4374f96602..0bf1d1002f 100644 --- a/src/server/services/agentSignal/sources/hydration/clientRuntimeComplete.ts +++ b/src/server/services/agentSignal/sources/hydration/clientRuntimeComplete.ts @@ -78,7 +78,7 @@ const getTrustedScopeKey = ( */ export const resolveClientRuntimeCompleteFeedbackSource = async ( sourceEvent: AgentSignalSourceEvent, - input: { db: LobeChatDatabase; userId: string }, + input: { db: LobeChatDatabase; userId: string; workspaceId?: string }, ): Promise => { if (sourceEvent.payload.status !== 'completed') { return skipped('non-completed-status'); @@ -90,7 +90,7 @@ export const resolveClientRuntimeCompleteFeedbackSource = async ( return skipped('missing-assistant-message-id'); } - const messageModel = new MessageModel(input.db, input.userId); + const messageModel = new MessageModel(input.db, input.userId, input.workspaceId); const assistantMessage = await messageModel.findById(assistantMessageId); if (!assistantMessage) { diff --git a/src/server/services/agentSignal/sources/hydration/clientRuntimeStart.ts b/src/server/services/agentSignal/sources/hydration/clientRuntimeStart.ts index 25e1990199..1473eaba22 100644 --- a/src/server/services/agentSignal/sources/hydration/clientRuntimeStart.ts +++ b/src/server/services/agentSignal/sources/hydration/clientRuntimeStart.ts @@ -54,7 +54,7 @@ const getTrustedScopeKey = ( */ export const resolveClientRuntimeStartFeedbackSource = async ( sourceEvent: SourceEventClientRuntimeStart, - input: { db: LobeChatDatabase; userId: string }, + input: { db: LobeChatDatabase; userId: string; workspaceId?: string }, ): Promise => { if (sourceEvent.payload.parentMessageType !== 'user') { return { diagnostic: { reason: 'non-user-parent', status: 'skipped' } }; @@ -64,7 +64,7 @@ export const resolveClientRuntimeStartFeedbackSource = async ( return { diagnostic: { reason: 'missing-parent-message-id', status: 'skipped' } }; } - const messageModel = new MessageModel(input.db, input.userId); + const messageModel = new MessageModel(input.db, input.userId, input.workspaceId); const parentMessage = await messageModel.findById(sourceEvent.payload.parentMessageId); if (!parentMessage) { diff --git a/src/server/services/aiAgent/index.ts b/src/server/services/aiAgent/index.ts index ece5ef7295..da5ff1f925 100644 --- a/src/server/services/aiAgent/index.ts +++ b/src/server/services/aiAgent/index.ts @@ -273,29 +273,34 @@ export class AiAgentService { private readonly marketService: MarketService; private readonly klavisService: KlavisService; + private readonly workspaceId?: string; + constructor( db: LobeChatDatabase, userId: string, - options?: { runtimeOptions?: AgentRuntimeServiceOptions }, + options?: { runtimeOptions?: AgentRuntimeServiceOptions; workspaceId?: string }, ) { this.userId = userId; this.db = db; - this.agentDocumentsService = new AgentDocumentsService(db, userId); - this.agentModel = new AgentModel(db, userId); - this.agentService = new AgentService(db, userId); - this.messageModel = new MessageModel(db, userId); - this.connectorModel = new ConnectorModel(db, userId); - this.connectorToolModel = new ConnectorToolModel(db, userId); - this.pluginModel = new PluginModel(db, userId); - this.taskModel = new TaskModel(db, userId); - this.threadModel = new ThreadModel(db, userId); - this.topicModel = new TopicModel(db, userId); + this.workspaceId = options?.workspaceId; + const wsId = this.workspaceId; + this.agentDocumentsService = new AgentDocumentsService(db, userId, wsId); + this.agentModel = new AgentModel(db, userId, wsId); + this.agentService = new AgentService(db, userId, wsId); + this.messageModel = new MessageModel(db, userId, wsId); + this.connectorModel = new ConnectorModel(db, userId, wsId); + this.connectorToolModel = new ConnectorToolModel(db, userId, wsId); + this.pluginModel = new PluginModel(db, userId, wsId); + this.taskModel = new TaskModel(db, userId, wsId); + this.threadModel = new ThreadModel(db, userId, wsId); + this.topicModel = new TopicModel(db, userId, wsId); this.agentRuntimeService = new AgentRuntimeService(db, userId, { ...options?.runtimeOptions, execSubAgentTask: this.execSubAgentTask.bind(this), + workspaceId: wsId, }); this.marketService = new MarketService({ userInfo: { userId } }); - this.klavisService = new KlavisService({ db, userId }); + this.klavisService = new KlavisService({ db, userId, workspaceId: wsId }); } private async resolveOperationTaskId( @@ -831,7 +836,9 @@ export class AiAgentService { assistantMessageRef.current = assistantMsg.id; // Read resume session id for next-turn continuity. - const heteroService = new HeterogeneousAgentService(this.db, this.userId); + const heteroService = new HeterogeneousAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const resumeSessionId = await heteroService.getHeterogeneousResumeSessionId(topicId); // Sign an operation-scoped JWT so the CLI can authenticate against // heteroIngest / heteroFinish without full user credentials. @@ -1249,7 +1256,7 @@ export class AiAgentService { const { loadModels } = await import('@/business/client/model-bank/loadModels'); const builtinModels = await loadModels(); // Resolve file URLs before visual tool activation checks and context build. - const fileService = new FileService(this.db, this.userId); + const fileService = new FileService(this.db, this.userId, this.workspaceId); const postProcessUrl = (path: string | null, file: { id?: string | null }) => fileService.getFileAccessUrl({ id: file.id, url: path }); let historyMessagesCache: any[] | undefined; @@ -1386,7 +1393,7 @@ export class AiAgentService { : []; if (connectorEntries.length > 0) { - const toolModel = new ConnectorToolModel(this.db, this.userId); + const toolModel = new ConnectorToolModel(this.db, this.userId, this.workspaceId); const connectorToolsMap = new Map>(); await Promise.all( connectorEntries.map(async (c) => { @@ -1471,7 +1478,7 @@ export class AiAgentService { const externalFileTypes = files?.map((file) => file.mimeType ?? '') ?? []; let attachedFileTypes: string[] = []; if (attachedFileIds && attachedFileIds.length > 0) { - const fileModel = new FileModel(this.db, this.userId); + const fileModel = new FileModel(this.db, this.userId, this.workspaceId); const fileRecords = await fileModel.findByIds(Array.from(new Set(attachedFileIds))); attachedFileTypes = fileRecords.map((file) => file.fileType || ''); } @@ -1990,7 +1997,7 @@ export class AiAgentService { imageList = []; videoList = []; fileList = []; - const documentService = new DocumentService(this.db, this.userId); + const documentService = new DocumentService(this.db, this.userId, this.workspaceId); for (const file of files) { await throwIfExecutionAborted('file upload'); @@ -2077,6 +2084,7 @@ export class AiAgentService { db: this.db, fileIds: attachedFileIds, userId: this.userId, + workspaceId: this.workspaceId, }); warnings.push(...resolved.warnings); @@ -2345,7 +2353,7 @@ export class AiAgentService { identifier: s.identifier, name: s.name, })); - const skillModel = new AgentSkillModel(this.db, this.userId); + const skillModel = new AgentSkillModel(this.db, this.userId, this.workspaceId); const { data: dbSkills } = await skillModel.findAll(); const dbMetas = dbSkills.map((s) => ({ description: s.description ?? '', @@ -2509,6 +2517,7 @@ export class AiAgentService { userId: this.userId, userInterventionConfig, userMemory, + workspaceId: this.workspaceId, }); log('execAgent: created operation %s (autoStarted: %s)', operationId, result.autoStarted); @@ -2749,9 +2758,11 @@ export class AiAgentService { let inheritedTrigger: string | undefined; if (parentOperationId) { try { - const parentOp = await new AgentOperationModel(this.db, this.userId).findById( - parentOperationId, - ); + const parentOp = await new AgentOperationModel( + this.db, + this.userId, + this.workspaceId, + ).findById(parentOperationId); inheritedTrigger = parentOp?.trigger ?? undefined; } catch (error) { log('execSubAgentTask: failed to read parent operation trigger: %O', error); diff --git a/src/server/services/aiChat/index.ts b/src/server/services/aiChat/index.ts index e3826db7dc..c2881d1d1b 100644 --- a/src/server/services/aiChat/index.ts +++ b/src/server/services/aiChat/index.ts @@ -33,10 +33,10 @@ export class AiChatService { private fileService: FileService; private topicModel: TopicModel; - constructor(serverDB: LobeChatDatabase, userId: string) { - this.messageModel = new MessageModel(serverDB, userId); - this.topicModel = new TopicModel(serverDB, userId); - this.fileService = new FileService(serverDB, userId); + constructor(serverDB: LobeChatDatabase, userId: string, workspaceId?: string) { + this.messageModel = new MessageModel(serverDB, userId, workspaceId); + this.topicModel = new TopicModel(serverDB, userId, workspaceId); + this.fileService = new FileService(serverDB, userId, workspaceId); } async getMessagesAndTopics(params: GetMessagesAndTopicsParams) { diff --git a/src/server/services/aiGeneration/index.ts b/src/server/services/aiGeneration/index.ts index 68df37acd3..66558b504b 100644 --- a/src/server/services/aiGeneration/index.ts +++ b/src/server/services/aiGeneration/index.ts @@ -47,17 +47,21 @@ export interface AiGenerationObjectOptions { export class AiGenerationService { private readonly db: LobeChatDatabase; private readonly userId: string; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } async generateObject( input: AiGenerationObjectInput, options: AiGenerationObjectOptions = {}, ): Promise { - const runtime = await initModelRuntimeFromDB(this.db, this.userId, input.provider); + const runtime = this.workspaceId + ? await initModelRuntimeFromDB(this.db, this.userId, input.provider, this.workspaceId) + : await initModelRuntimeFromDB(this.db, this.userId, input.provider); return (await runtime.generateObject( { messages: input.messages as GenerateObjectPayload['messages'], diff --git a/src/server/services/bot/AgentBridgeService.ts b/src/server/services/bot/AgentBridgeService.ts index acbfd9680e..d9a345b8b1 100644 --- a/src/server/services/bot/AgentBridgeService.ts +++ b/src/server/services/bot/AgentBridgeService.ts @@ -180,6 +180,7 @@ interface ActiveReaction { export class AgentBridgeService { private readonly db: LobeChatDatabase; private readonly userId: string; + private readonly workspaceId?: string; private timezone: string | undefined; private timezoneLoaded = false; @@ -351,13 +352,16 @@ export class AgentBridgeService { } } - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } private async interruptTrackedOperation(threadId: string, operationId: string): Promise { - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const result = await aiAgentService.interruptTask({ operationId }); if (!result.success) { throw new Error(`Failed to interrupt operation ${operationId}`); @@ -575,7 +579,7 @@ export class AgentBridgeService { // Wrapped in try/catch so transient DB errors fall through to the // existing topicId rather than rejecting before the guarded section. try { - const topicModel = new TopicModel(this.db, this.userId); + const topicModel = new TopicModel(this.db, this.userId, this.workspaceId); const existingTopic = await topicModel.findById(topicId); if (existingTopic) { const elapsed = Date.now() - new Date(existingTopic.updatedAt).getTime(); @@ -728,7 +732,9 @@ export class AgentBridgeService { } = opts; const queueMode = isQueueAgentRuntimeEnabled(); - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const timezone = await this.loadTimezone(); // When the message-gateway is configured AND the platform supports typing @@ -852,6 +858,7 @@ export class AgentBridgeService { // alarm timeout instead of stopping at completion. userId: this.userId, userMessageId: userMessage.id, + workspaceId: this.workspaceId, }; log( @@ -1330,13 +1337,17 @@ export class AgentBridgeService { // title generation (the prompt itself still drives it on // the next round). if (resolvedTopicId && prompt && lastAssistantContent) { - const topicModel = new TopicModel(this.db, this.userId); + const topicModel = new TopicModel(this.db, this.userId, this.workspaceId); topicModel .findById(resolvedTopicId) .then(async (topic) => { if (topic?.title) return; - const systemAgent = new SystemAgentService(this.db, this.userId); + const systemAgent = new SystemAgentService( + this.db, + this.userId, + this.workspaceId, + ); const title = await systemAgent.generateTopicTitle({ lastAssistantContent, userPrompt: prompt, diff --git a/src/server/services/bot/BotCallbackService.ts b/src/server/services/bot/BotCallbackService.ts index 8d974e5e01..f833f60772 100644 --- a/src/server/services/bot/BotCallbackService.ts +++ b/src/server/services/bot/BotCallbackService.ts @@ -97,6 +97,7 @@ export interface BotCallbackBody { userId?: string; userMessageId?: string; userPrompt?: string; + workspaceId?: string; } // --------------- Service --------------- @@ -119,13 +120,15 @@ export class BotCallbackService { } = body; const platform = platformThreadId.split(':')[0]; - const { client, connectionId, messenger, charLimit, settings } = await this.createMessenger({ - applicationId, - messengerInstallationKey, - platform, - platformThreadId, - userId, - }); + const { client, connectionId, messenger, charLimit, settings, workspaceId } = + await this.createMessenger({ + applicationId, + messengerInstallationKey, + platform, + platformThreadId, + userId, + workspaceId: body.workspaceId, + }); const entry = platformRegistry.getPlatform(platform); const canEdit = entry?.supportsMessageEdit !== false; @@ -163,7 +166,10 @@ export class BotCallbackService { // In queue mode, the bridge handler's finally block skips this cleanup // to keep the thread marked active while the agent runs on the job queue. AgentBridgeService.clearActiveThread(platformThreadId); - this.summarizeTopicTitle(body, messenger); + this.summarizeTopicTitle( + { ...body, workspaceId: body.workspaceId ?? workspaceId ?? undefined }, + messenger, + ); } } @@ -173,12 +179,14 @@ export class BotCallbackService { platform: string; platformThreadId: string; userId?: string; + workspaceId?: string; }): Promise<{ charLimit?: number; connectionId: string; client: PlatformClient; messenger: PlatformMessenger; settings: Record; + workspaceId?: string | null; }> { const { applicationId, messengerInstallationKey, platform, platformThreadId, userId } = params; @@ -192,6 +200,7 @@ export class BotCallbackService { messengerInstallationKey, platformThreadId, userId, + params.workspaceId, ); } @@ -231,7 +240,14 @@ export class BotCallbackService { }); const messenger = client.getMessenger(platformThreadId); - return { charLimit, connectionId: row.id, messenger, client, settings }; + return { + charLimit, + client, + connectionId: row.id, + messenger, + settings, + workspaceId: row.workspaceId, + }; } /** @@ -252,12 +268,14 @@ export class BotCallbackService { installationKey: string, platformThreadId: string, userId?: string, + workspaceId?: string, ): Promise<{ charLimit?: number; connectionId: string; client: PlatformClient; messenger: PlatformMessenger; settings: Record; + workspaceId?: string; }> { const store = getInstallationStore(platform as MessengerPlatform); if (!store) { @@ -296,7 +314,7 @@ export class BotCallbackService { ? messengerConnectionIdForUser({ connectionMode, installationKey, userId }) : ''; - return { charLimit: undefined, client, connectionId, messenger, settings: {} }; + return { charLimit: undefined, client, connectionId, messenger, settings: {}, workspaceId }; } private async handleStep( @@ -595,7 +613,7 @@ export class BotCallbackService { // Thread already has a user-set name — use it as topic title, skip LLM generation if (threadName) { - const topicModel = new TopicModel(this.db, userId); + const topicModel = new TopicModel(this.db, userId, body.workspaceId); topicModel .findById(topicId) .then(async (topic) => { @@ -608,13 +626,13 @@ export class BotCallbackService { return; } - const topicModel = new TopicModel(this.db, userId); + const topicModel = new TopicModel(this.db, userId, body.workspaceId); topicModel .findById(topicId) .then(async (topic) => { if (topic?.title) return; - const systemAgent = new SystemAgentService(this.db, userId); + const systemAgent = new SystemAgentService(this.db, userId, body.workspaceId ?? undefined); const title = await systemAgent.generateTopicTitle({ lastAssistantContent, userPrompt, diff --git a/src/server/services/bot/BotMessageRouter.ts b/src/server/services/bot/BotMessageRouter.ts index 48c4053117..2999a67edc 100644 --- a/src/server/services/bot/BotMessageRouter.ts +++ b/src/server/services/bot/BotMessageRouter.ts @@ -94,6 +94,7 @@ const summarizeMessageAttachments = (message: Message): Array { - const { agentId, userId, applicationId } = provider; + const { agentId, userId, applicationId, workspaceId } = provider; const platform = entry.id; const key = buildRuntimeKey(platform, applicationId); @@ -348,6 +349,7 @@ export class BotMessageRouter { platform, providerId: provider.id, userId, + workspaceId: workspaceId ?? undefined, }); // Default to 'queue' for legacy providers that don't have `concurrency` @@ -370,6 +372,7 @@ export class BotMessageRouter { platform, settings, userId, + workspaceId: workspaceId ?? undefined, }); await chatBot.initialize(); client.applyChatPatches?.(chatBot); @@ -387,7 +390,7 @@ export class BotMessageRouter { } const registered: RegisteredBot = { - agentInfo: { agentId, userId }, + agentInfo: { agentId, userId, workspaceId: workspaceId ?? undefined }, chatBot, client, }; @@ -535,8 +538,8 @@ export class BotMessageRouter { settings?: Record; }, ): void { - const { agentId, applicationId, platform, userId } = info; - const bridge = new AgentBridgeService(serverDB, userId); + const { agentId, applicationId, platform, userId, workspaceId } = info; + const bridge = new AgentBridgeService(serverDB, userId, workspaceId); const charLimit = (info.settings?.charLimit as number) || undefined; const displayToolCalls = info.settings?.displayToolCalls === true; const dmSettings: DmSettings = extractDmSettings(info.settings); @@ -1459,6 +1462,7 @@ export class BotMessageRouter { * `/approve` to append a fresh applicant to `settings.allowFrom`. */ providerId: string; userId: string; + workspaceId?: string; }, ): BotCommand[] { const { @@ -1494,7 +1498,9 @@ export class BotMessageRouter { const operationId = AgentBridgeService.getActiveOperationId(ctx.threadId); if (operationId) { try { - const aiAgentService = new AiAgentService(serverDB, userId); + const aiAgentService = new AiAgentService(serverDB, userId, { + workspaceId: info.workspaceId ?? undefined, + }); const result = await aiAgentService.interruptTask({ operationId }); if (!result.success) { log('command /stop: runtime interrupt rejected for operationId=%s', operationId); diff --git a/src/server/services/bot/__tests__/AgentBridgeService.test.ts b/src/server/services/bot/__tests__/AgentBridgeService.test.ts index 548c99513a..4a28e605dd 100644 --- a/src/server/services/bot/__tests__/AgentBridgeService.test.ts +++ b/src/server/services/bot/__tests__/AgentBridgeService.test.ts @@ -58,6 +58,7 @@ vi.mock('@/server/services/bot/platforms', async (importOriginal) => { }); const { AgentBridgeService } = await import('../AgentBridgeService'); +const { AiAgentService } = await import('@/server/services/aiAgent'); const FAKE_DB = {} as any; const USER_ID = 'user-123'; @@ -146,6 +147,34 @@ describe('AgentBridgeService', () => { ); }); + it('constructs AiAgentService with workspaceId for workspace bot runs', async () => { + const service = new AgentBridgeService(FAKE_DB, USER_ID, 'workspace-1'); + const thread = createThread(); + const message = createMessage(); + const client = createClient(); + + await service.handleMention(thread, message, { + agentId: 'agent-1', + botContext: { platformThreadId: THREAD_ID } as any, + client, + }); + + expect(AiAgentService).toHaveBeenCalledWith(FAKE_DB, USER_ID, { + workspaceId: 'workspace-1', + }); + expect(mockExecAgent).toHaveBeenCalledWith( + expect.objectContaining({ + hooks: expect.arrayContaining([ + expect.objectContaining({ + webhook: expect.objectContaining({ + body: expect.objectContaining({ workspaceId: 'workspace-1' }), + }), + }), + ]), + }), + ); + }); + it('calls execAgent with hooks in queue mode for subscribed message', async () => { const service = new AgentBridgeService(FAKE_DB, USER_ID); const thread = createThread({ topicId: 'topic-1' }); diff --git a/src/server/services/bot/__tests__/BotMessageRouter.test.ts b/src/server/services/bot/__tests__/BotMessageRouter.test.ts index 98d09a57cb..966a7a26a3 100644 --- a/src/server/services/bot/__tests__/BotMessageRouter.test.ts +++ b/src/server/services/bot/__tests__/BotMessageRouter.test.ts @@ -102,16 +102,14 @@ vi.mock('@/server/services/aiAgent', () => ({ const mockHandleMention = vi.hoisted(() => vi.fn().mockResolvedValue(undefined)); const mockHandleSubscribedMessage = vi.hoisted(() => vi.fn().mockResolvedValue(undefined)); +const mockAgentBridgeServiceCtor = vi.hoisted(() => vi.fn()); // Default to "platform does not opt into thread isolation" so existing tests // keep their pre-behaviour. Individual tests can replace this via // `.mockResolvedValueOnce(...)` to simulate Discord's auto-thread upgrade. const mockOpenThreadForChannelWake = vi.hoisted(() => vi.fn().mockResolvedValue(undefined)); vi.mock('../AgentBridgeService', () => ({ - AgentBridgeService: vi.fn().mockImplementation(() => ({ - handleMention: mockHandleMention, - handleSubscribedMessage: mockHandleSubscribedMessage, - })), + AgentBridgeService: mockAgentBridgeServiceCtor, })); // Mock platform entries @@ -435,6 +433,10 @@ describe('BotMessageRouter', () => { mockFindEnabledByPlatform.mockResolvedValue([]); mockHandleMention.mockResolvedValue(undefined); mockHandleSubscribedMessage.mockResolvedValue(undefined); + mockAgentBridgeServiceCtor.mockImplementation(() => ({ + handleMention: mockHandleMention, + handleSubscribedMessage: mockHandleSubscribedMessage, + })); mockOpenThreadForChannelWake.mockResolvedValue(undefined); // participant tracking — restore defaults wiped by // clearAllMocks. Empty list = fresh single-human thread; individual @@ -492,6 +494,20 @@ describe('BotMessageRouter', () => { expect(mockCreateAdapter).toHaveBeenCalled(); }); + it('passes provider workspaceId to AgentBridgeService', async () => { + mockFindEnabledByPlatform.mockResolvedValue([ + makeProvider({ applicationId: 'tg-bot-123', workspaceId: 'workspace-1' }), + ]); + + const router = new BotMessageRouter(); + const handler = router.getWebhookHandler('telegram', 'tg-bot-123'); + + const req = new Request('https://example.com/webhook', { body: '{}', method: 'POST' }); + await handler(req); + + expect(mockAgentBridgeServiceCtor).toHaveBeenCalledWith(FAKE_DB, 'user-1', 'workspace-1'); + }); + it('should return cached bot on subsequent requests', async () => { mockFindEnabledByPlatform.mockResolvedValue([makeProvider({ applicationId: 'tg-bot-123' })]); diff --git a/src/server/services/brief/index.ts b/src/server/services/brief/index.ts index b06b9ad53c..d3051173c9 100644 --- a/src/server/services/brief/index.ts +++ b/src/server/services/brief/index.ts @@ -35,13 +35,15 @@ export class BriefService { private db: LobeChatDatabase; private taskModel: TaskModel; private userId: string; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; - this.agentModel = new AgentModel(db, userId); - this.briefModel = new BriefModel(db, userId); - this.taskModel = new TaskModel(db, userId); + this.workspaceId = workspaceId; + this.agentModel = new AgentModel(db, userId, workspaceId); + this.briefModel = new BriefModel(db, userId, workspaceId); + this.taskModel = new TaskModel(db, userId, workspaceId); } /** @@ -208,7 +210,7 @@ export class BriefService { // Lazy-loaded to avoid pulling ModelRuntime into BriefService's // import graph (TaskRunner → TaskLifecycle → ModelRuntime). const { TaskRunnerService } = await import('@/server/services/taskRunner'); - const runner = new TaskRunnerService(this.db, this.userId); + const runner = new TaskRunnerService(this.db, this.userId, this.workspaceId); await runner.cascadeOnCompletion(brief.taskId); } } diff --git a/src/server/services/chunk/index.ts b/src/server/services/chunk/index.ts index dd10bc2612..1c2409ca5d 100644 --- a/src/server/services/chunk/index.ts +++ b/src/server/services/chunk/index.ts @@ -13,17 +13,19 @@ import { export class ChunkService { private userId: string; + private workspaceId?: string; private chunkClient: ContentChunk; private fileModel: FileModel; private asyncTaskModel: AsyncTaskModel; - constructor(serverDB: LobeChatDatabase, userId: string) { + constructor(serverDB: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; + this.workspaceId = workspaceId; this.chunkClient = new ContentChunk(); - this.fileModel = new FileModel(serverDB, userId); - this.asyncTaskModel = new AsyncTaskModel(serverDB, userId); + this.fileModel = new FileModel(serverDB, userId, workspaceId); + this.asyncTaskModel = new AsyncTaskModel(serverDB, userId, workspaceId); } async chunkContent(params: ChunkContentParams) { @@ -50,7 +52,11 @@ export class ChunkService { // trigger embedding task asynchronously try { - await asyncCaller.file.embeddingChunks({ fileId, taskId: asyncTaskId }); + await asyncCaller.file.embeddingChunks({ + fileId, + taskId: asyncTaskId, + workspaceId: this.workspaceId, + }); } catch (e) { console.error('[embeddingFileChunks] error:', e); @@ -91,17 +97,19 @@ export class ChunkService { const asyncCaller = await createAsyncCaller({ userId: this.userId }); // trigger parse file task asynchronously - asyncCaller.file.parseFileToChunks({ fileId, taskId: asyncTaskId }).catch(async (e) => { - console.error('[ParseFileToChunks] error:', e); + asyncCaller.file + .parseFileToChunks({ fileId, taskId: asyncTaskId, workspaceId: this.workspaceId }) + .catch(async (e) => { + console.error('[ParseFileToChunks] error:', e); - await this.asyncTaskModel.update(asyncTaskId, { - error: new AsyncTaskError( - AsyncTaskErrorType.TaskTriggerError, - 'trigger chunk embedding async task error. Please make sure the APP_URL is available from your server. You can check the proxy config or WAF blocking', - ), - status: AsyncTaskStatus.Error, + await this.asyncTaskModel.update(asyncTaskId, { + error: new AsyncTaskError( + AsyncTaskErrorType.TaskTriggerError, + 'trigger chunk embedding async task error. Please make sure the APP_URL is available from your server. You can check the proxy config or WAF blocking', + ), + status: AsyncTaskStatus.Error, + }); }); - }); return asyncTaskId; } diff --git a/src/server/services/discover/index.ts b/src/server/services/discover/index.ts index b7bc73e94a..d82a77c81b 100644 --- a/src/server/services/discover/index.ts +++ b/src/server/services/discover/index.ts @@ -291,15 +291,18 @@ export class DiscoverService { return result; }; - private normalizeAuthorField = (author: unknown): { name: string; userName?: string } => { + private normalizeAuthorField = ( + author: unknown, + ): { name: string; ownerType?: 'user' | 'organization'; userName?: string } => { if (!author) return { name: '' }; if (typeof author === 'string') return { name: author }; if (typeof author === 'object') { - const { avatar, url, name, userName } = author as { + const { avatar, url, name, userName, type } = author as { avatar?: unknown; name?: unknown; + type?: unknown; url?: unknown; userName?: unknown; }; @@ -312,6 +315,7 @@ export class DiscoverService { return { name: authorName, + ownerType: type === 'organization' ? 'organization' : 'user', userName: typeof userName === 'string' ? userName : undefined, }; } @@ -599,6 +603,7 @@ export class DiscoverService { pluginCount: (data.config as any)?.plugins?.length || (data as any).pluginCount || 0, readme: data.documentationUrl || '', schemaVersion: 1, + ownerType: normalizedAuthor.ownerType, status: (data.status as AgentStatus) || undefined, summary: data.summary || '', systemRole: (data.config as any)?.systemRole || '', diff --git a/src/server/services/document/history.ts b/src/server/services/document/history.ts index 64a101eee8..4741ec2faa 100644 --- a/src/server/services/document/history.ts +++ b/src/server/services/document/history.ts @@ -1,3 +1,4 @@ +import { buildWorkspaceWhere } from '@lobechat/database'; import type { DocumentItem } from '@lobechat/database/schemas'; import { documentHistories, documents } from '@lobechat/database/schemas'; import { and, desc, eq, gte, inArray, lt, or } from 'drizzle-orm'; @@ -30,12 +31,20 @@ const getDocumentEditorData = (document: DocumentItem | undefined): Record + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents); + + private historiesOwnership = () => + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documentHistories); + createHistory = async (params: { documentId: string; editorData: Record; @@ -45,7 +54,7 @@ export class DocumentHistoryService { const [document] = await this.db .select({ id: documents.id }) .from(documents) - .where(and(eq(documents.id, params.documentId), eq(documents.userId, this.userId))) + .where(and(eq(documents.id, params.documentId), this.documentsOwnership())) .limit(1); if (!document) { @@ -58,6 +67,7 @@ export class DocumentHistoryService { saveSource: params.saveSource, savedAt: params.savedAt, userId: this.userId, + workspaceId: this.workspaceId ?? null, }); await this.trimHistoryBySource(params.documentId, params.saveSource); @@ -106,7 +116,7 @@ export class DocumentHistoryService { where: and( eq(documentHistories.id, params.historyId), eq(documentHistories.documentId, params.documentId), - eq(documentHistories.userId, this.userId), + this.historiesOwnership(), options?.historySince ? gte(documentHistories.savedAt, options.historySince) : undefined, ), }); @@ -151,7 +161,7 @@ export class DocumentHistoryService { orderBy: [desc(documentHistories.savedAt), desc(documentHistories.id)], where: and( eq(documentHistories.documentId, params.documentId), - eq(documentHistories.userId, this.userId), + this.historiesOwnership(), options?.historySince ? gte(documentHistories.savedAt, options.historySince) : undefined, params.beforeSavedAt !== undefined && params.beforeId !== undefined ? or( @@ -212,7 +222,7 @@ export class DocumentHistoryService { private findHeadDocument = async (documentId: string) => { return this.db.query.documents.findFirst({ - where: and(eq(documents.id, documentId), eq(documents.userId, this.userId)), + where: and(eq(documents.id, documentId), this.documentsOwnership()), }); }; @@ -229,7 +239,7 @@ export class DocumentHistoryService { .where( and( eq(documentHistories.documentId, documentId), - eq(documentHistories.userId, this.userId), + this.historiesOwnership(), eq(documentHistories.saveSource, saveSource), ), ) @@ -242,7 +252,7 @@ export class DocumentHistoryService { await this.db.delete(documentHistories).where( and( eq(documentHistories.documentId, documentId), - eq(documentHistories.userId, this.userId), + this.historiesOwnership(), eq(documentHistories.saveSource, saveSource), inArray( documentHistories.id, diff --git a/src/server/services/document/index.ts b/src/server/services/document/index.ts index 31acd8bb3e..d0df89c31f 100644 --- a/src/server/services/document/index.ts +++ b/src/server/services/document/index.ts @@ -10,6 +10,7 @@ import isEqual from 'fast-deep-equal'; import { DocumentModel } from '@/database/models/document'; import { FileModel } from '@/database/models/file'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; import { isValidEditorData } from '@/libs/editor/isValidEditorData'; import { normalizeEditorDataDiffNodes } from '@/libs/editor/normalizeDiffNodes'; import { type LobeDocument } from '@/types/document'; @@ -51,21 +52,28 @@ export class DocumentService { private fileServiceInstance?: FileService; private db: LobeChatDatabase; - constructor(db: LobeChatDatabase, userId: string) { + private workspaceId?: string; + + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; this.db = db; - this.fileModel = new FileModel(db, userId); - this.documentModel = new DocumentModel(db, userId); + this.workspaceId = workspaceId; + this.fileModel = new FileModel(db, userId, workspaceId); + this.documentModel = new DocumentModel(db, userId, workspaceId); } private get fileService() { - this.fileServiceInstance ??= new FileService(this.db, this.userId); + this.fileServiceInstance ??= new FileService(this.db, this.userId, this.workspaceId); return this.fileServiceInstance; } private get documentHistoryService() { - this.documentHistoryServiceInstance ??= new DocumentHistoryService(this.db, this.userId); + this.documentHistoryServiceInstance ??= new DocumentHistoryService( + this.db, + this.userId, + this.workspaceId, + ); return this.documentHistoryServiceInstance; } @@ -278,7 +286,10 @@ export class DocumentService { // If it's a folder, recursively delete all children first if (document.fileType === CUSTOM_FOLDER_FILE_TYPE) { const children = await this.db.query.documents.findMany({ - where: eq(documents.parentId, id), + where: and( + eq(documents.parentId, id), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents), + ), }); // Recursively delete all children @@ -288,7 +299,10 @@ export class DocumentService { // Also delete all files in this folder const childFiles = await this.db.query.files.findMany({ - where: and(eq(files.parentId, id), eq(files.userId, this.userId)), + where: and( + eq(files.parentId, id), + buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files), + ), }); for (const file of childFiles) { @@ -319,9 +333,13 @@ export class DocumentService { async updateDocument(id: string, params: UpdateDocumentParams): Promise { return this.db.transaction(async (tx) => { const transactionDb = tx as unknown as LobeChatDatabase; - const documentModel = new DocumentModel(transactionDb, this.userId); - const fileModel = new FileModel(transactionDb, this.userId); - const documentHistoryService = new DocumentHistoryService(transactionDb, this.userId); + const documentModel = new DocumentModel(transactionDb, this.userId, this.workspaceId); + const fileModel = new FileModel(transactionDb, this.userId, this.workspaceId); + const documentHistoryService = new DocumentHistoryService( + transactionDb, + this.userId, + this.workspaceId, + ); const currentDocument = await documentModel.findById(id); if (!currentDocument) { diff --git a/src/server/services/file/__tests__/index.test.ts b/src/server/services/file/__tests__/index.test.ts index 2118a5278c..5e9d620b59 100644 --- a/src/server/services/file/__tests__/index.test.ts +++ b/src/server/services/file/__tests__/index.test.ts @@ -82,6 +82,14 @@ describe('FileService', () => { consoleErrorSpy?.mockRestore(); }); + it('scopes FileModel to workspace when workspaceId is provided', () => { + vi.clearAllMocks(); + + new FileService(mockDb, mockUserId, 'workspace-1'); + + expect(FileModel).toHaveBeenCalledWith(mockDb, mockUserId, 'workspace-1'); + }); + describe('downloadFileToLocal', () => { const mockFile = { id: 'test-file-id', diff --git a/src/server/services/file/extractFileIdsFromEditorData.ts b/src/server/services/file/extractFileIdsFromEditorData.ts index 83f75dddfb..433ab2e360 100644 --- a/src/server/services/file/extractFileIdsFromEditorData.ts +++ b/src/server/services/file/extractFileIdsFromEditorData.ts @@ -1,7 +1,8 @@ import { files } from '@lobechat/database/schemas'; -import { and, eq, inArray } from 'drizzle-orm'; +import { and, inArray } from 'drizzle-orm'; import type { LobeChatDatabase } from '@/database/type'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; /** * Walks a serialized Lexical editor state, collects every URL referenced by @@ -86,7 +87,7 @@ function extractStorageKeyFromUrl(url: string): string | undefined { export async function extractFileIdsFromEditorData( json: unknown, - ctx: { db: LobeChatDatabase; userId: string }, + ctx: { db: LobeChatDatabase; userId: string; workspaceId?: string }, ): Promise { const urls = collectAttachmentUrlsFromEditorData(json); if (urls.length === 0) return []; @@ -115,7 +116,7 @@ export async function extractFileIdsFromEditorData( const rows = await ctx.db .select({ id: files.id, url: files.url }) .from(files) - .where(and(eq(files.userId, ctx.userId), inArray(files.url, keys))); + .where(and(buildWorkspaceWhere(ctx, files), inArray(files.url, keys))); const firstIdPerUrl = new Map(); for (const row of rows) { diff --git a/src/server/services/file/index.ts b/src/server/services/file/index.ts index a99d3a3ace..609935bcf9 100644 --- a/src/server/services/file/index.ts +++ b/src/server/services/file/index.ts @@ -31,9 +31,9 @@ export class FileService { private impl: FileServiceImpl; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; - this.fileModel = new FileModel(db, userId); + this.fileModel = new FileModel(db, userId, workspaceId); this.impl = createFileServiceModule(db); } diff --git a/src/server/services/file/resolveAttachments.ts b/src/server/services/file/resolveAttachments.ts index 9dfc632801..4f638a6bf7 100644 --- a/src/server/services/file/resolveAttachments.ts +++ b/src/server/services/file/resolveAttachments.ts @@ -25,6 +25,7 @@ interface ResolveArgs { db: LobeChatDatabase; fileIds: string[]; userId: string; + workspaceId?: string; } const dedupe = (ids: string[]) => Array.from(new Set(ids)); @@ -41,6 +42,7 @@ export const resolveAttachmentsByFileIds = async ({ db, fileIds, userId, + workspaceId, }: ResolveArgs): Promise => { const result: ResolvedAttachments = { fileList: [], @@ -52,15 +54,15 @@ export const resolveAttachmentsByFileIds = async ({ if (fileIds.length === 0) return result; const dedupedFileIds = dedupe(fileIds); - const fileModel = new FileModel(db, userId); - const fileService = new FileService(db, userId); + const fileModel = new FileModel(db, userId, workspaceId); + const fileService = new FileService(db, userId, workspaceId); const fileRecords = await fileModel.findByIds(dedupedFileIds); if (fileRecords.length === 0) { log('no file records found for fileIds=%O', dedupedFileIds); return result; } - const documentService = new DocumentService(db, userId); + const documentService = new DocumentService(db, userId, workspaceId); const recordById = new Map(fileRecords.map((f) => [f.id, f])); // Resolve every file in parallel — URL signing + PDF parsing can both be @@ -145,18 +147,19 @@ export const resolveAttachmentMetadata = async ({ fileIds, signUrls = true, userId, + workspaceId, }: ResolveArgs & { signUrls?: boolean }): Promise => { if (fileIds.length === 0) return []; const dedupedFileIds = dedupe(fileIds); - const fileModel = new FileModel(db, userId); + const fileModel = new FileModel(db, userId, workspaceId); const fileRecords = await fileModel.findByIds(dedupedFileIds); if (fileRecords.length === 0) { log('no file records found for fileIds=%O', dedupedFileIds); return []; } - const fileService = signUrls ? new FileService(db, userId) : null; + const fileService = signUrls ? new FileService(db, userId, workspaceId) : null; const recordById = new Map(fileRecords.map((f) => [f.id, f])); const items = await Promise.all( dedupedFileIds.map(async (id) => { diff --git a/src/server/services/followUpAction/index.test.ts b/src/server/services/followUpAction/index.test.ts index c847e08cf4..0bc6cabdbd 100644 --- a/src/server/services/followUpAction/index.test.ts +++ b/src/server/services/followUpAction/index.test.ts @@ -159,6 +159,7 @@ describe('FollowUpActionService.extract', () => { threadId: { col: 'threadId' }, topicId: { col: 'topicId' }, userId: { col: 'userId' }, + workspaceId: { col: 'workspaceId' }, }; const ops = { and: (...parts: any[]) => ({ op: 'and', parts }), @@ -191,6 +192,34 @@ describe('FollowUpActionService.extract', () => { expect(parts.some((p) => p.op === 'eq' && p.col === table.threadId)).toBe(false); }); + it('filters personal mode by userId and null workspaceId', async () => { + queryFindFirstSpy.mockResolvedValue(undefined); + await svc.extract({ modelConfig: MODEL_CONFIG, topicId: TEST_TOPIC }); + const { parts, table } = captureWhereOps(); + const ownership = parts.find((p) => p.op === 'and' && p.parts?.length === 2); + expect(ownership.parts).toEqual([ + { col: table.userId, op: 'eq', value: TEST_USER }, + { col: table.workspaceId, op: 'isNull' }, + ]); + }); + + it('filters workspace mode by workspaceId and forwards it to model runtime', async () => { + svc = new FollowUpActionService(dbMock, TEST_USER, 'workspace-1'); + queryFindFirstSpy.mockResolvedValue({ id: FOUND_MSG, content: 'q?' }); + runtimeMock.generateObject.mockResolvedValue({ chips: [] }); + + await svc.extract({ modelConfig: MODEL_CONFIG, topicId: TEST_TOPIC }); + + const { parts, table } = captureWhereOps(); + expect(parts).toContainEqual({ col: table.workspaceId, op: 'eq', value: 'workspace-1' }); + expect(ModelRuntimeModule.initModelRuntimeFromDB).toHaveBeenCalledWith( + dbMock, + TEST_USER, + MODEL_CONFIG.provider, + 'workspace-1', + ); + }); + it('appends onboarding addendum to system prompt when hint is onboarding', async () => { queryFindFirstSpy.mockResolvedValue({ id: FOUND_MSG, content: 'q?' }); runtimeMock.generateObject.mockResolvedValue({ chips: [] }); diff --git a/src/server/services/followUpAction/index.ts b/src/server/services/followUpAction/index.ts index f48720096f..6b09a02e6a 100644 --- a/src/server/services/followUpAction/index.ts +++ b/src/server/services/followUpAction/index.ts @@ -16,10 +16,12 @@ const EMPTY_RESULT = (messageId: string): FollowUpExtractResult => ({ chips: [], export class FollowUpActionService { private readonly db: LobeChatDatabase; private readonly userId: string; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } async extract({ @@ -35,7 +37,9 @@ export class FollowUpActionService { orderBy: (m, { desc }) => desc(m.createdAt), where: (m, { and, eq, isNotNull, isNull, ne }) => and( - eq(m.userId, this.userId), + this.workspaceId + ? eq(m.workspaceId, this.workspaceId) + : and(eq(m.userId, this.userId), isNull(m.workspaceId)), eq(m.topicId, topicId), // Discriminate thread vs main topic: an absent threadId must NOT // surface a thread reply that lives under the same topicId. @@ -54,7 +58,7 @@ export class FollowUpActionService { const { system, user } = buildSuggestionPrompt({ assistantText: text, hint }); const { model, provider } = modelConfig; - const ai = new AiGenerationService(this.db, this.userId); + const ai = new AiGenerationService(this.db, this.userId, this.workspaceId); let raw: unknown; try { raw = await ai.generateObject( diff --git a/src/server/services/gateway/GatewayManager.test.ts b/src/server/services/gateway/GatewayManager.test.ts index 51ac45d9d2..935bcd09bb 100644 --- a/src/server/services/gateway/GatewayManager.test.ts +++ b/src/server/services/gateway/GatewayManager.test.ts @@ -2,15 +2,15 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'; import { getServerDB } from '@/database/core/db-adaptor'; -import { AgentBotProviderModel } from '@/database/models/agentBotProvider'; import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt'; import type { PlatformClient, PlatformDefinition } from '@/server/services/bot/platforms'; import { createGatewayManager, GatewayManager, getGatewayManager } from './GatewayManager'; // Mock database and external dependencies -const { mockFindEnabledByPlatform } = vi.hoisted(() => ({ +const { mockFindEnabledByPlatform, mockFindEnabledByPlatformAndAppId } = vi.hoisted(() => ({ mockFindEnabledByPlatform: vi.fn().mockResolvedValue([]), + mockFindEnabledByPlatformAndAppId: vi.fn().mockResolvedValue(null), })); vi.mock('@/database/core/db-adaptor', () => ({ @@ -20,6 +20,7 @@ vi.mock('@/database/core/db-adaptor', () => ({ vi.mock('@/database/models/agentBotProvider', () => ({ AgentBotProviderModel: Object.assign(vi.fn(), { findEnabledByPlatform: mockFindEnabledByPlatform, + findEnabledByPlatformAndAppId: mockFindEnabledByPlatformAndAppId, }), })); @@ -66,19 +67,15 @@ const createFakeDefinition = ( describe('GatewayManager', () => { let mockDb: any; let mockGateKeeper: any; - let mockAgentBotProviderModel: any; beforeEach(() => { mockDb = {}; mockGateKeeper = {}; - mockAgentBotProviderModel = { - findEnabledByApplicationId: vi.fn(), - }; vi.mocked(getServerDB).mockResolvedValue(mockDb as any); vi.mocked(KeyVaultsGateKeeper.initWithEnvKey).mockResolvedValue(mockGateKeeper as any); - vi.mocked(AgentBotProviderModel).mockImplementation(() => mockAgentBotProviderModel); mockFindEnabledByPlatform.mockResolvedValue([]); + mockFindEnabledByPlatformAndAppId.mockResolvedValue(null); // Clean up global singleton between tests const globalForGateway = globalThis as any; @@ -166,7 +163,7 @@ describe('GatewayManager', () => { const factory = vi.fn().mockReturnValueOnce(mockBot1).mockReturnValueOnce(mockBot2); // Pre-load two bots by calling startClient - mockAgentBotProviderModel.findEnabledByApplicationId.mockResolvedValue({ + mockFindEnabledByPlatformAndAppId.mockResolvedValue({ applicationId: 'app-1', credentials: { token: 'tok1' }, }); @@ -174,8 +171,8 @@ describe('GatewayManager', () => { const manager = new GatewayManager({ definitions: [createFakeDefinition('slack', factory)] }); await manager.start(); - await manager.startClient('slack', 'app-1', 'user-1'); - await manager.startClient('slack', 'app-2', 'user-2'); + await manager.startClient('slack', 'app-1'); + await manager.startClient('slack', 'app-2'); await manager.stop(); @@ -187,50 +184,50 @@ describe('GatewayManager', () => { describe('startClient', () => { it('should do nothing when no provider is found in DB', async () => { - mockAgentBotProviderModel.findEnabledByApplicationId.mockResolvedValue(null); + mockFindEnabledByPlatformAndAppId.mockResolvedValue(null); const factory = vi.fn(); const manager = new GatewayManager({ definitions: [createFakeDefinition('slack', factory)] }); - await manager.startClient('slack', 'app-123', 'user-abc'); + await manager.startClient('slack', 'app-123'); expect(factory).not.toHaveBeenCalled(); }); it('should do nothing when the platform is not registered', async () => { - mockAgentBotProviderModel.findEnabledByApplicationId.mockResolvedValue({ + mockFindEnabledByPlatformAndAppId.mockResolvedValue({ applicationId: 'app-123', credentials: { token: 'tok' }, }); const manager = new GatewayManager({ definitions: [] }); // empty definitions - await manager.startClient('unsupported', 'app-123', 'user-abc'); + await manager.startClient('unsupported', 'app-123'); - // No bot should be created - expect(vi.mocked(AgentBotProviderModel)).toHaveBeenCalled(); + // No bot should be created, but the provider lookup still ran + expect(mockFindEnabledByPlatformAndAppId).toHaveBeenCalled(); }); it('should start a bot and register it', async () => { const mockBot = createMockBot(); const factory = vi.fn().mockReturnValue(mockBot); - mockAgentBotProviderModel.findEnabledByApplicationId.mockResolvedValue({ + mockFindEnabledByPlatformAndAppId.mockResolvedValue({ applicationId: 'app-123', credentials: { token: 'tok123' }, }); const manager = new GatewayManager({ definitions: [createFakeDefinition('slack', factory)] }); - await manager.startClient('slack', 'app-123', 'user-abc'); + await manager.startClient('slack', 'app-123'); expect(factory).toHaveBeenCalled(); expect(mockBot.start).toHaveBeenCalled(); }); - it('should scope lookup by the requested user but build runtime context from the provider row', async () => { + it('should look up the provider system-wide and build runtime context from the provider row', async () => { const mockBot = createMockBot(); const factory = vi.fn().mockReturnValue(mockBot); - mockAgentBotProviderModel.findEnabledByApplicationId.mockResolvedValue({ + mockFindEnabledByPlatformAndAppId.mockResolvedValue({ applicationId: 'app-123', credentials: { token: 'tok123' }, settings: {}, @@ -239,11 +236,12 @@ describe('GatewayManager', () => { const manager = new GatewayManager({ definitions: [createFakeDefinition('slack', factory)] }); - await manager.startClient('slack', 'app-123', 'requested-user'); + await manager.startClient('slack', 'app-123'); - expect(vi.mocked(AgentBotProviderModel)).toHaveBeenCalledWith( + expect(mockFindEnabledByPlatformAndAppId).toHaveBeenCalledWith( mockDb, - 'requested-user', + 'slack', + 'app-123', mockGateKeeper, ); expect(factory.mock.calls[0][1]).toMatchObject({ userId: 'provider-user' }); @@ -255,7 +253,7 @@ describe('GatewayManager', () => { const mockBot2 = createMockBot(); const factory = vi.fn().mockReturnValueOnce(mockBot1).mockReturnValueOnce(mockBot2); - mockAgentBotProviderModel.findEnabledByApplicationId.mockResolvedValue({ + mockFindEnabledByPlatformAndAppId.mockResolvedValue({ applicationId: 'app-123', credentials: { token: 'tok' }, }); @@ -263,11 +261,11 @@ describe('GatewayManager', () => { const manager = new GatewayManager({ definitions: [createFakeDefinition('slack', factory)] }); // Start bot first time - await manager.startClient('slack', 'app-123', 'user-abc'); + await manager.startClient('slack', 'app-123'); expect(mockBot1.start).toHaveBeenCalled(); // Start bot second time for same key — should stop first - await manager.startClient('slack', 'app-123', 'user-abc'); + await manager.startClient('slack', 'app-123'); expect(mockBot1.stop).toHaveBeenCalled(); expect(mockBot2.start).toHaveBeenCalled(); }); @@ -284,7 +282,7 @@ describe('GatewayManager', () => { it('should stop and remove a running bot', async () => { const mockBot = createMockBot(); const factory = vi.fn().mockReturnValue(mockBot); - mockAgentBotProviderModel.findEnabledByApplicationId.mockResolvedValue({ + mockFindEnabledByPlatformAndAppId.mockResolvedValue({ applicationId: 'app-123', credentials: { token: 'tok' }, }); @@ -292,7 +290,7 @@ describe('GatewayManager', () => { const manager = new GatewayManager({ definitions: [createFakeDefinition('slack', factory)] }); // First start the bot - await manager.startClient('slack', 'app-123', 'user-abc'); + await manager.startClient('slack', 'app-123'); expect(mockBot.start).toHaveBeenCalled(); // Then stop it @@ -305,14 +303,14 @@ describe('GatewayManager', () => { const mockBot2 = createMockBot(); const factory = vi.fn().mockReturnValueOnce(mockBot1).mockReturnValueOnce(mockBot2); - mockAgentBotProviderModel.findEnabledByApplicationId + mockFindEnabledByPlatformAndAppId .mockResolvedValueOnce({ applicationId: 'app-1', credentials: {} }) .mockResolvedValueOnce({ applicationId: 'app-2', credentials: {} }); const manager = new GatewayManager({ definitions: [createFakeDefinition('slack', factory)] }); - await manager.startClient('slack', 'app-1', 'user-1'); - await manager.startClient('slack', 'app-2', 'user-2'); + await manager.startClient('slack', 'app-1'); + await manager.startClient('slack', 'app-2'); await manager.stopClient('slack', 'app-1'); diff --git a/src/server/services/gateway/GatewayManager.ts b/src/server/services/gateway/GatewayManager.ts index 202d34a69f..d96a7a670c 100644 --- a/src/server/services/gateway/GatewayManager.ts +++ b/src/server/services/gateway/GatewayManager.ts @@ -74,7 +74,7 @@ export class GatewayManager { // Client operations (point-to-point) // ------------------------------------------------------------------ - async startClient(platform: string, applicationId: string, userId: string): Promise { + async startClient(platform: string, applicationId: string): Promise { const key = buildRuntimeKey(platform, applicationId); // Stop existing if any @@ -85,11 +85,16 @@ export class GatewayManager { this.clients.delete(key); } - // Load from DB (user-scoped, single row) + // Load from DB (system-wide single row — platform + applicationId is globally + // unique; the caller is already authorized at the router boundary). const serverDB = await getServerDB(); const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const model = new AgentBotProviderModel(serverDB, userId, gateKeeper); - const provider = await model.findEnabledByApplicationId(platform, applicationId); + const provider = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + platform, + applicationId, + gateKeeper, + ); if (!provider) { log('No enabled provider found for %s', key); diff --git a/src/server/services/gateway/__tests__/GatewayManager.test.ts b/src/server/services/gateway/__tests__/GatewayManager.test.ts index df2ebeafee..a33fc06c9a 100644 --- a/src/server/services/gateway/__tests__/GatewayManager.test.ts +++ b/src/server/services/gateway/__tests__/GatewayManager.test.ts @@ -5,7 +5,7 @@ import type { PlatformDefinition } from '@/server/services/bot/platforms'; import { GatewayManager } from '../GatewayManager'; const mockFindEnabledByPlatform = vi.hoisted(() => vi.fn()); -const mockFindEnabledByApplicationId = vi.hoisted(() => vi.fn()); +const mockFindEnabledByPlatformAndAppId = vi.hoisted(() => vi.fn()); const mockInitWithEnvKey = vi.hoisted(() => vi.fn()); const mockGetServerDB = vi.hoisted(() => vi.fn()); @@ -14,10 +14,9 @@ vi.mock('@/database/core/db-adaptor', () => ({ })); vi.mock('@/database/models/agentBotProvider', () => { - const MockModel = vi.fn().mockImplementation(() => ({ - findEnabledByApplicationId: mockFindEnabledByApplicationId, - })); + const MockModel = vi.fn(); (MockModel as any).findEnabledByPlatform = mockFindEnabledByPlatform; + (MockModel as any).findEnabledByPlatformAndAppId = mockFindEnabledByPlatformAndAppId; return { AgentBotProviderModel: MockModel }; }); @@ -65,7 +64,7 @@ describe('GatewayManager', () => { mockGetServerDB.mockResolvedValue(FAKE_DB); mockInitWithEnvKey.mockResolvedValue(FAKE_GATEKEEPER); mockFindEnabledByPlatform.mockResolvedValue([]); - mockFindEnabledByApplicationId.mockResolvedValue(null); + mockFindEnabledByPlatformAndAppId.mockResolvedValue(null); manager = new GatewayManager({ definitions: [fakeDefinition] }); }); @@ -137,7 +136,7 @@ describe('GatewayManager', () => { it('should handle missing provider gracefully', async () => { await manager.start(); - await expect(manager.startClient('fakeplatform', 'app-1', 'user-1')).resolves.toBeUndefined(); + await expect(manager.startClient('fakeplatform', 'app-1')).resolves.toBeUndefined(); }); }); diff --git a/src/server/services/gateway/index.ts b/src/server/services/gateway/index.ts index 4a1090803b..e83b34f3f2 100644 --- a/src/server/services/gateway/index.ts +++ b/src/server/services/gateway/index.ts @@ -262,8 +262,12 @@ export class GatewayService { const definition = platformRegistry.getPlatform(platform); const serverDB = await getServerDB(); const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const model = new AgentBotProviderModel(serverDB, userId, gateKeeper); - const provider = await model.findEnabledByApplicationId(platform, applicationId); + const provider = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + platform, + applicationId, + gateKeeper, + ); const connectionMode = resolveConnectionMode(definition, provider?.settings); @@ -287,7 +291,7 @@ export class GatewayService { } const manager = createGatewayManager({ definitions: platformRegistry.listPlatforms() }); - await manager.startClient(platform, applicationId, userId); + await manager.startClient(platform, applicationId); log('Started client %s:%s (direct)', platform, applicationId); return 'started'; } @@ -299,7 +303,7 @@ export class GatewayService { manager = getGatewayManager(); } - await manager!.startClient(platform, applicationId, userId); + await manager!.startClient(platform, applicationId); log('Started client %s:%s', platform, applicationId); return 'started'; } @@ -310,13 +314,12 @@ export class GatewayService { * disabled; webhook-mode providers are skipped (they have no persistent * gateway connection to query). */ - async refreshBotRuntimeStatusesByAgent(agentId: string, userId: string): Promise { + async refreshBotRuntimeStatusesByAgent(agentId: string): Promise { if (!this.useMessageGateway) return; const serverDB = await getServerDB(); const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const model = new AgentBotProviderModel(serverDB, userId, gateKeeper); - const providers = await model.findByAgentId(agentId); + const providers = await AgentBotProviderModel.findByAgentId(serverDB, agentId, gateKeeper); const client = getMessageGatewayClient(); await Promise.all( @@ -355,7 +358,6 @@ export class GatewayService { async refreshBotRuntimeStatus( platform: string, applicationId: string, - userId: string, ): Promise { const cached = await getBotRuntimeStatus(platform, applicationId); @@ -363,8 +365,12 @@ export class GatewayService { const serverDB = await getServerDB(); const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const model = new AgentBotProviderModel(serverDB, userId, gateKeeper); - const provider = await model.findEnabledByApplicationId(platform, applicationId); + const provider = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + platform, + applicationId, + gateKeeper, + ); if (!provider) return cached; @@ -406,8 +412,12 @@ export class GatewayService { if (userId) { const serverDB = await getServerDB(); const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const model = new AgentBotProviderModel(serverDB, userId, gateKeeper); - const provider = await model.findEnabledByApplicationId(platform, applicationId); + const provider = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + platform, + applicationId, + gateKeeper, + ); connectionMode = resolveConnectionMode(definition, provider?.settings); } else { connectionMode = resolveConnectionMode(definition, undefined); @@ -549,8 +559,12 @@ export class GatewayService { const serverDB = await getServerDB(); const gateKeeper = await KeyVaultsGateKeeper.initWithEnvKey(); - const model = new AgentBotProviderModel(serverDB, userId, gateKeeper); - const provider = await model.findEnabledByApplicationId(platform, applicationId); + const provider = await AgentBotProviderModel.findEnabledByPlatformAndAppId( + serverDB, + platform, + applicationId, + gateKeeper, + ); if (!provider) { log('No enabled provider found for %s:%s', platform, applicationId); @@ -565,7 +579,7 @@ export class GatewayService { // perform its own initialization (e.g. Telegram calls setWebhook). if (connectionMode === 'webhook') { const manager = createGatewayManager({ definitions: platformRegistry.listPlatforms() }); - await manager.startClient(platform, applicationId, userId); + await manager.startClient(platform, applicationId); log('Started webhook-mode client locally %s:%s', platform, applicationId); return 'started'; } diff --git a/src/server/services/generation/index.ts b/src/server/services/generation/index.ts index bc9fbd0912..aab8be89c8 100644 --- a/src/server/services/generation/index.ts +++ b/src/server/services/generation/index.ts @@ -80,8 +80,8 @@ interface ImageForGeneration { export class GenerationService { private fileService: FileService; - constructor(db: LobeChatDatabase, userId: string) { - this.fileService = new FileService(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.fileService = new FileService(db, userId, workspaceId); } /** diff --git a/src/server/services/generation/video.ts b/src/server/services/generation/video.ts index f41159d4c1..909fce2a7d 100644 --- a/src/server/services/generation/video.ts +++ b/src/server/services/generation/video.ts @@ -48,8 +48,8 @@ export interface VideoProcessResult { export class VideoGenerationService { private fileService: FileService; - constructor(db: LobeChatDatabase, userId: string) { - this.fileService = new FileService(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.fileService = new FileService(db, userId, workspaceId); } /** diff --git a/src/server/services/generation/videoBackgroundPolling.ts b/src/server/services/generation/videoBackgroundPolling.ts index 6833e4e09f..eff8626421 100644 --- a/src/server/services/generation/videoBackgroundPolling.ts +++ b/src/server/services/generation/videoBackgroundPolling.ts @@ -26,6 +26,7 @@ interface BackgroundPollingParams { prechargeResult?: any; provider: string; userId: string; + workspaceId?: string; } export async function processBackgroundVideoPolling( @@ -41,6 +42,7 @@ export async function processBackgroundVideoPolling( model, provider, userId, + workspaceId, } = params; log( @@ -51,11 +53,11 @@ export async function processBackgroundVideoPolling( ); try { - const asyncTaskModel = new AsyncTaskModel(db, userId); - const videoService = new VideoGenerationService(db, userId); - const generationModel = new GenerationModel(db, userId); + const asyncTaskModel = new AsyncTaskModel(db, userId, workspaceId); + const videoService = new VideoGenerationService(db, userId, workspaceId); + const generationModel = new GenerationModel(db, userId, workspaceId); - const modelRuntime = await initModelRuntimeFromDB(db, userId, provider); + const modelRuntime = await initModelRuntimeFromDB(db, userId, provider, workspaceId); const pollResult = await pollUntilCompletion(modelRuntime, inferenceId); if (!pollResult) { @@ -107,7 +109,7 @@ export async function processBackgroundVideoPolling( } catch (error) { log('Background video polling error for task: %s', asyncTaskId, error); - const asyncTaskModel = new AsyncTaskModel(db, userId); + const asyncTaskModel = new AsyncTaskModel(db, userId, workspaceId); const providerContentPolicyMessage = await getProviderContentPolicyErrorMessage({ error, provider, diff --git a/src/server/services/heterogeneousAgent/index.ts b/src/server/services/heterogeneousAgent/index.ts index fe54aad1f9..f265506351 100644 --- a/src/server/services/heterogeneousAgent/index.ts +++ b/src/server/services/heterogeneousAgent/index.ts @@ -53,6 +53,11 @@ export interface HeterogeneousAgentServiceOptions { streamEventManager?: IStreamEventManager; /** Inject a pre-built TopicModel (used by tests for the resume helper). */ topicModel?: TopicModel; + /** + * Workspace id for scoping internal model reads/writes (messages, topics, + * threads). Falls back to user-personal scope when omitted. + */ + workspaceId?: string; } /** @@ -81,14 +86,15 @@ export class HeterogeneousAgentService { ) { this.db = db; this.userId = userId; - this.messageModel = new MessageModel(db, userId); + const workspaceId = options.workspaceId; + this.messageModel = new MessageModel(db, userId, workspaceId); this.streamEventManager = options.streamEventManager ?? createStreamEventManager(); - this.topicModel = options.topicModel ?? new TopicModel(db, userId); + this.topicModel = options.topicModel ?? new TopicModel(db, userId, workspaceId); this.persistenceHandler = options.persistenceHandler ?? new HeterogeneousPersistenceHandler({ messageModel: this.messageModel, - threadModel: new ThreadModel(db, userId), + threadModel: new ThreadModel(db, userId, workspaceId), topicModel: this.topicModel, }); } diff --git a/src/server/services/klavis/index.ts b/src/server/services/klavis/index.ts index 6851da1c27..bbcea1209d 100644 --- a/src/server/services/klavis/index.ts +++ b/src/server/services/klavis/index.ts @@ -16,11 +16,13 @@ export interface KlavisToolExecuteParams { /** Tool identifier (same as Klavis server identifier, e.g., 'google-calendar') */ identifier: string; toolName: string; + workspaceId?: string; } export interface KlavisServiceOptions { db?: LobeChatDatabase; userId?: string; + workspaceId?: string; } /** @@ -43,15 +45,17 @@ export class KlavisService { private db?: LobeChatDatabase; private userId?: string; private pluginModel?: PluginModel; + private workspaceId?: string; constructor(options: KlavisServiceOptions = {}) { - const { db, userId } = options; + const { db, userId, workspaceId } = options; this.db = db; this.userId = userId; + this.workspaceId = workspaceId; if (db && userId) { - this.pluginModel = new PluginModel(db, userId); + this.pluginModel = new PluginModel(db, userId, workspaceId); } log( @@ -68,7 +72,7 @@ export class KlavisService { * @returns Tool execution result */ async executeKlavisTool(params: KlavisToolExecuteParams): Promise { - const { identifier, toolName, args } = params; + const { identifier, toolName, args, workspaceId } = params; log('executeKlavisTool: %s/%s with args: %O', identifier, toolName, args); @@ -95,7 +99,11 @@ export class KlavisService { try { // Get plugin from database to retrieve serverUrl - const plugin = await this.pluginModel.findById(identifier); + const pluginModel = + workspaceId && this.db && this.userId + ? new PluginModel(this.db, this.userId, workspaceId) + : this.pluginModel; + const plugin = await pluginModel.findById(identifier); if (!plugin) { return { content: `Klavis server "${identifier}" not found in database`, diff --git a/src/server/services/knowledgeBase/index.ts b/src/server/services/knowledgeBase/index.ts index 461d52af31..a224748c2c 100644 --- a/src/server/services/knowledgeBase/index.ts +++ b/src/server/services/knowledgeBase/index.ts @@ -89,17 +89,24 @@ export class KnowledgeBaseSearchService { private searchRepo: SearchRepo; private documentServiceInstance?: DocumentService; - constructor(serverDB: LobeChatDatabase, userId: string) { + private workspaceId?: string; + + constructor(serverDB: LobeChatDatabase, userId: string, workspaceId?: string) { this.serverDB = serverDB; this.userId = userId; - this.chunkModel = new ChunkModel(serverDB, userId); - this.documentModel = new DocumentModel(serverDB, userId); - this.fileModel = new FileModel(serverDB, userId); - this.searchRepo = new SearchRepo(serverDB, userId); + this.workspaceId = workspaceId; + this.chunkModel = new ChunkModel(serverDB, userId, workspaceId); + this.documentModel = new DocumentModel(serverDB, userId, workspaceId); + this.fileModel = new FileModel(serverDB, userId, workspaceId); + this.searchRepo = new SearchRepo(serverDB, userId, workspaceId); } private get documentService() { - this.documentServiceInstance ??= new DocumentService(this.serverDB, this.userId); + this.documentServiceInstance ??= new DocumentService( + this.serverDB, + this.userId, + this.workspaceId, + ); return this.documentServiceInstance; } @@ -113,7 +120,12 @@ export class KnowledgeBaseSearchService { const vectorPath = async (): Promise => { const { model, provider } = getServerDefaultFilesConfig().embeddingModel || DEFAULT_FILE_EMBEDDING_MODEL_ITEM; - const modelRuntime = await initModelRuntimeFromDB(this.serverDB, this.userId, provider); + const modelRuntime = await initModelRuntimeFromDB( + this.serverDB, + this.userId, + provider, + this.workspaceId, + ); // slice content to make sure in the context window limit const query = input.query.length > 8000 ? input.query.slice(0, 8000) : input.query; diff --git a/src/server/services/llmGenerationTracing/hook.ts b/src/server/services/llmGenerationTracing/hook.ts index 653f3f93c8..6e5345f2d0 100644 --- a/src/server/services/llmGenerationTracing/hook.ts +++ b/src/server/services/llmGenerationTracing/hook.ts @@ -72,6 +72,7 @@ const tryScheduleAfter = (work: () => Promise | void): void => { export const createLLMGenerationTracingHook = ( userId: string, provider: string, + workspaceId?: string, ): Pick => { const service = getLLMGenerationTracingService(); if (!service.isEnabled()) return {}; @@ -144,6 +145,7 @@ export const createLLMGenerationTracingHook = ( trigger, userId, validationFailed, + workspaceId, }); persistedTracingId = result?.tracingId ?? null; } catch (err) { diff --git a/src/server/services/llmGenerationTracing/index.ts b/src/server/services/llmGenerationTracing/index.ts index ed686dd074..ef1ed02faa 100644 --- a/src/server/services/llmGenerationTracing/index.ts +++ b/src/server/services/llmGenerationTracing/index.ts @@ -66,6 +66,7 @@ export interface RecordLLMGenerationCallParams { trigger?: string | null; userId: string; validationFailed?: boolean; + workspaceId?: string | null; } /** @@ -99,7 +100,7 @@ export class LLMGenerationTracingService { return null; } - const model = new LlmGenerationTracingModel(db, params.userId); + const model = new LlmGenerationTracingModel(db, params.userId, params.workspaceId ?? undefined); // Allocate the id up-front so the route can return it synchronously to // the client (e.g. for feedback wiring) even though the actual `record()` @@ -210,6 +211,7 @@ export class LLMGenerationTracingService { userId: string, tracingId: string, params: UpdateLlmGenerationFeedbackParams, + workspaceId?: string, ): Promise { let db: Awaited>; try { @@ -220,7 +222,7 @@ export class LLMGenerationTracingService { cause: err, }); } - const model = new LlmGenerationTracingModel(db, userId); + const model = new LlmGenerationTracingModel(db, userId, workspaceId); let result: { updated: boolean }; try { result = await model.updateFeedback(tracingId, params); diff --git a/src/server/services/market/index.ts b/src/server/services/market/index.ts index 0366b19a15..573c2d722a 100644 --- a/src/server/services/market/index.ts +++ b/src/server/services/market/index.ts @@ -1,5 +1,5 @@ import { type LobeToolManifest } from '@lobechat/context-engine'; -import { MarketSDK } from '@lobehub/market-sdk'; +import { MarketSDK, type OrgRef, orgRefToPathSegment } from '@lobehub/market-sdk'; import debug from 'debug'; import { type NextRequest } from 'next/server'; @@ -46,6 +46,15 @@ export interface MarketServiceOptions { clientId: string; clientSecret: string; }; + /** + * Owner account id for organization-scoped operations. + * + * When set, Market attributes reads/writes (currently: creds and inject-creds) + * to the given organization account instead of the actor's personal account. + * Used by the workspace creds router after resolving a cloud workspace to + * its Market organization via {@link WorkspaceMarketIdentityService}. + */ + ownerAccountId?: number; /** Pre-generated trusted client token (alternative to userInfo) */ trustedClientToken?: string; /** User info for generating trusted client token */ @@ -81,7 +90,8 @@ export class MarketService { market: MarketSDK; constructor(options: MarketServiceOptions = {}) { - const { accessToken, userInfo, clientCredentials, trustedClientToken } = options; + const { accessToken, userInfo, clientCredentials, trustedClientToken, ownerAccountId } = + options; // Use provided trustedClientToken or generate from userInfo const resolvedTrustedClientToken = @@ -92,15 +102,17 @@ export class MarketService { baseURL: MARKET_BASE_URL, clientId: clientCredentials?.clientId, clientSecret: clientCredentials?.clientSecret, + ownerAccountId, trustedClientToken: resolvedTrustedClientToken, }); log( - 'MarketService initialized: baseURL=%s, hasAccessToken=%s, hasTrustedToken=%s, hasClientCredentials=%s', + 'MarketService initialized: baseURL=%s, hasAccessToken=%s, hasTrustedToken=%s, hasClientCredentials=%s, ownerAccountId=%s', MARKET_BASE_URL, !!accessToken, !!resolvedTrustedClientToken, !!clientCredentials, + ownerAccountId ?? 'none', ); } @@ -589,22 +601,37 @@ export class MarketService { // ============================== Creds Methods ============================== /** - * Upload credential file to Market API - * This method directly calls the Market API since SDK doesn't support file upload yet + * Upload a credential file to Market. * - * @param file - File content as base64 string - * @param fileName - Original file name - * @param fileType - MIME type of the file + * The SDK doesn't expose multipart upload so this method calls the REST + * endpoint directly. Pass `orgId` to upload to an organization's cred + * bucket (`/api/v1/organizations/:orgId/creds/upload`); omit it for a + * personal upload (`/api/v1/user/creds/upload`). + * + * @param params.file - File content as base64 string + * @param params.fileName - Original file name + * @param params.fileType - MIME type of the file + * @param params.orgId - Optional organization account id. When set, the + * upload is attributed to the org via the org-scoped URL; org membership + * (admin) is enforced server-side by `requireOrgMembership`. * @returns Upload result with fileHashId */ async uploadCredFile(params: { file: string; // base64 encoded file content fileName: string; fileType: string; + orgId?: OrgRef; }): Promise<{ fileHashId: string; fileName: string; fileSize: number; fileType: string }> { - const { file, fileName, fileType } = params; + const { file, fileName, fileType, orgId } = params; + // Numeric account id or `workspace:` path segment. + const orgSegment = orgId === undefined ? undefined : orgRefToPathSegment(orgId); - log('uploadCredFile: fileName=%s, fileType=%s', fileName, fileType); + log( + 'uploadCredFile: fileName=%s, fileType=%s, orgId=%s', + fileName, + fileType, + orgSegment ?? 'none', + ); // Convert base64 to Blob const binaryString = atob(file); @@ -618,19 +645,25 @@ export class MarketService { const formData = new FormData(); formData.append('file', blob, fileName); - // Extract only auth headers (not Content-Type, which would break multipart/form-data) + // Extract only auth headers (not Content-Type, which would break multipart/form-data). + // We deliberately also strip `x-lobe-owner-account-id` for the org path — + // ownership is in the URL now, the header is ignored by the org route. // @ts-ignore - market.headers contains auth headers const sdkHeaders = this.market.headers as Record; const authHeaders: Record = {}; for (const [key, value] of Object.entries(sdkHeaders)) { - // Only include authorization-related headers, skip Content-Type - if (key.toLowerCase() !== 'content-type') { - authHeaders[key] = value; - } + const lower = key.toLowerCase(); + if (lower === 'content-type') continue; + if (lower === 'x-lobe-owner-account-id' && orgSegment !== undefined) continue; + authHeaders[key] = value; } // Call Market API directly - const uploadUrl = `${MARKET_BASE_URL}/api/v1/user/creds/upload`; + const uploadPath = + orgSegment === undefined + ? '/api/v1/user/creds/upload' + : `/api/v1/organizations/${orgSegment}/creds/upload`; + const uploadUrl = `${MARKET_BASE_URL}${uploadPath}`; const response = await fetch(uploadUrl, { body: formData, headers: authHeaders, diff --git a/src/server/services/memory/userMemory/extract.ts b/src/server/services/memory/userMemory/extract.ts index 6a0eae0234..3dc1aaeaa2 100644 --- a/src/server/services/memory/userMemory/extract.ts +++ b/src/server/services/memory/userMemory/extract.ts @@ -66,6 +66,7 @@ import { UserMemoryModel } from '@/database/models/userMemory'; import { UserMemorySourceBenchmarkLoCoMoModel } from '@/database/models/userMemory/sources/benchmarkLoCoMo'; import { AiInfraRepos } from '@/database/repositories/aiInfra'; import { getServerDB } from '@/database/server'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; import { getServerGlobalConfig } from '@/server/globalConfig'; import { type MemoryAgentConfig } from '@/server/globalConfig/parseMemoryExtractionConfig'; import { parseMemoryExtractionConfig } from '@/server/globalConfig/parseMemoryExtractionConfig'; @@ -145,6 +146,7 @@ export interface MemoryExtractionNormalizedPayload { userId?: string; userIds: string[]; userInitiated?: boolean; + workspaceId?: string; } export const memoryExtractionPayloadSchema = z.object({ @@ -176,6 +178,7 @@ export const memoryExtractionPayloadSchema = z.object({ userId: z.string().optional(), userIds: z.array(z.string()).optional(), userInitiated: z.boolean().optional(), + workspaceId: z.string().optional(), }); export type MemoryExtractionPayloadInput = z.infer; @@ -228,6 +231,7 @@ export const normalizeMemoryExtractionPayload = ( new Set([...(parsed.userIds || []), ...(parsed.userId ? [parsed.userId] : [])]), ).filter(Boolean), userInitiated: parsed.userInitiated ?? false, + workspaceId: parsed.workspaceId, }; }; @@ -263,6 +267,7 @@ export const buildWorkflowPayloadInput = ( userId: payload.userId ?? payload.userIds[0], userIds: payload.userIds, userInitiated: payload.userInitiated, + workspaceId: payload.workspaceId, }); const normalizeProvider = (provider: string) => provider.toLowerCase(); @@ -628,6 +633,7 @@ export interface TopicExtractionJob { topicId: string; userId: string; userInitiated?: boolean; + workspaceId?: string; } export interface TopicPaginationJob { @@ -637,6 +643,7 @@ export interface TopicPaginationJob { from?: Date; to?: Date; userId: string; + workspaceId?: string; } export interface UserPaginationResult { @@ -1327,7 +1334,12 @@ export class MemoryExtractionExecutor { return { errors: [], ids: insertedIds } satisfies PersistLayerResult; } - async listConversationsForTopic(userId: string, topicId: string, topicUpdatedAt: Date) { + async listConversationsForTopic( + userId: string, + topicId: string, + topicUpdatedAt: Date, + workspaceId?: string, + ) { const db = await this.db; const rows = await db .select({ @@ -1337,7 +1349,9 @@ export class MemoryExtractionExecutor { role: messages.role, }) .from(messages) - .where(and(eq(messages.userId, userId), eq(messages.topicId, topicId))) + .where( + and(buildWorkspaceWhere({ userId, workspaceId }, messages), eq(messages.topicId, topicId)), + ) .orderBy(asc(messages.createdAt)); const conversation = rows @@ -1442,7 +1456,7 @@ export class MemoryExtractionExecutor { try { const db = await this.db; - const asyncTaskModel = new AsyncTaskModel(db, job.userId); + const asyncTaskModel = new AsyncTaskModel(db, job.userId, job.workspaceId); await asyncTaskModel.incrementUserMemoryExtractionProgress(job.asyncTaskId); } catch (error) { console.error('[memory-extraction] failed to update async task progress', error); @@ -1492,7 +1506,10 @@ export class MemoryExtractionExecutor { const db = await this.db; const topic = await db.query.topics.findFirst({ columns: { createdAt: true, id: true, metadata: true, updatedAt: true, userId: true }, - where: and(eq(topics.id, job.topicId), eq(topics.userId, job.userId)), + where: and( + eq(topics.id, job.topicId), + buildWorkspaceWhere({ userId: job.userId, workspaceId: job.workspaceId }, topics), + ), }); if (!topic) { @@ -1541,7 +1558,7 @@ export class MemoryExtractionExecutor { const userModel = new UserModel(db, job.userId); const [userState, aiProviderRuntimeState] = await Promise.all([ userModel.getUserState(KeyVaultsGateKeeper.getUserKeyVaults), - this.getAiProviderRuntimeState(job.userId), + this.getAiProviderRuntimeState(job.userId, job.workspaceId), ]); const memoryServiceConfig = this.resolveUserMemoryServiceConfig( userState.settings?.systemAgent as Partial | undefined, @@ -1558,6 +1575,7 @@ export class MemoryExtractionExecutor { job.userId, topic.id, topic.updatedAt, + job.workspaceId, ); if (!conversations || conversations.length === 0) { if (extractionJob) { @@ -1848,7 +1866,7 @@ export class MemoryExtractionExecutor { } if (job.asyncTaskId && job.userInitiated) { try { - const asyncTaskModel = new AsyncTaskModel(await this.db, job.userId); + const asyncTaskModel = new AsyncTaskModel(await this.db, job.userId, job.workspaceId); await asyncTaskModel.update(job.asyncTaskId, { error: buildAsyncTaskErrorFrom(error), status: AsyncTaskStatus.Error, @@ -1946,7 +1964,11 @@ export class MemoryExtractionExecutor { } for (const userId of payload.userIds) { - const topicIds = await this.filterTopicIdsForUser(userId, payload.topicIds); + const topicIds = await this.filterTopicIdsForUser( + userId, + payload.topicIds, + payload.workspaceId, + ); for (const topicId of topicIds) { const extracted = await this.extractTopic({ asyncTaskId: payload.asyncTaskId, @@ -1959,6 +1981,7 @@ export class MemoryExtractionExecutor { topicId, userId, userInitiated: payload.userInitiated, + workspaceId: payload.workspaceId, }); results.push({ ...extracted, topicId, userId }); @@ -2010,7 +2033,7 @@ export class MemoryExtractionExecutor { pageSize: number, ): Promise<{ cursor?: ListTopicsForMemoryExtractorCursor; ids: string[] }> { const db = await this.db; - const topicModel = new TopicModel(db, job.userId); + const topicModel = new TopicModel(db, job.userId, job.workspaceId); const rows = await topicModel.listTopicsForMemoryExtractor({ cursor: job.cursor, endDate: job.to, @@ -2094,13 +2117,16 @@ export class MemoryExtractionExecutor { }; } - async filterTopicIdsForUser(userId: string, topicIds: string[]) { + async filterTopicIdsForUser(userId: string, topicIds: string[], workspaceId?: string) { if (!topicIds.length) return []; const db = await this.db; const rows = await db.query.topics.findMany({ columns: { id: true }, - where: and(eq(topics.userId, userId), inArray(topics.id, topicIds)), + where: and( + buildWorkspaceWhere({ userId, workspaceId }, topics), + inArray(topics.id, topicIds), + ), }); return rows.map((row) => row.id); @@ -2322,7 +2348,10 @@ export class MemoryExtractionExecutor { }; } - private async getAiProviderRuntimeState(userId: string): Promise { + private async getAiProviderRuntimeState( + userId: string, + workspaceId?: string, + ): Promise { const db = await this.db; const aiInfraRepos = new AiInfraRepos(db, userId, this.aiProviderConfig); diff --git a/src/server/services/memory/userMemory/persona/service.ts b/src/server/services/memory/userMemory/persona/service.ts index e778b1e918..bbb75fb266 100644 --- a/src/server/services/memory/userMemory/persona/service.ts +++ b/src/server/services/memory/userMemory/persona/service.ts @@ -96,6 +96,9 @@ export class UserPersonaService { async composeWriting(payload: UserPersonaAgentPayload): Promise { const agentConfig = await this.resolveAgentConfig(payload.userId); + // workspace-audit: intentionally personal-scoped (no workspaceId). Persona is a + // purely user-level feature with no workspace concept; the payload carries no + // workspaceId, so provider config is resolved against the user's personal scope. const aiInfraRepos = new AiInfraRepos(this.db, payload.userId, {}); const runtimeState = await aiInfraRepos.getAiProviderRuntimeState( KeyVaultsGateKeeper.getUserKeyVaults, diff --git a/src/server/services/message/index.ts b/src/server/services/message/index.ts index 9859bbfca0..1c366d0e7b 100644 --- a/src/server/services/message/index.ts +++ b/src/server/services/message/index.ts @@ -53,10 +53,10 @@ export class MessageService { private fileService: FileService; private compressionRepository: CompressionRepository; - constructor(db: LobeChatDatabase, userId: string) { - this.messageModel = new MessageModel(db, userId); - this.fileService = new FileService(db, userId); - this.compressionRepository = new CompressionRepository(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.messageModel = new MessageModel(db, userId, workspaceId); + this.fileService = new FileService(db, userId, workspaceId); + this.compressionRepository = new CompressionRepository(db, userId, workspaceId); } /** diff --git a/src/server/services/messenger/MessengerRouter.test.ts b/src/server/services/messenger/MessengerRouter.test.ts index f5c2fc5686..fd47f6b478 100644 --- a/src/server/services/messenger/MessengerRouter.test.ts +++ b/src/server/services/messenger/MessengerRouter.test.ts @@ -28,6 +28,12 @@ vi.mock('./oauth/slackOAuth', () => ({ verifySignature: (...args: any[]) => mockVerifySignature(...args), })); +const mockGetServerFeatureFlagsStateFromRuntimeConfig = vi.fn(); +vi.mock('@/server/featureFlags', () => ({ + getServerFeatureFlagsStateFromRuntimeConfig: (...args: any[]) => + mockGetServerFeatureFlagsStateFromRuntimeConfig(...args), +})); + vi.mock('@/config/messenger', () => ({ getEnabledMessengerPlatforms: vi.fn().mockReturnValue(['slack', 'telegram']), getMessengerSlackConfig: vi.fn().mockReturnValue({ @@ -89,21 +95,37 @@ vi.mock('@chat-adapter/state-ioredis', () => ({ // reused — see the comment in `MessengerRouter.dispatchToAgent`. const mockHandleMention = vi.fn(); const mockHandleSubscribed = vi.fn(); +const mockAgentBridgeConstructor = vi.fn(); vi.mock('@/server/services/bot/AgentBridgeService', () => ({ AgentBridgeService: class { static clearActiveThread = vi.fn(); static getActiveOperationId = vi.fn(); static isThreadActive = vi.fn(); static requestStop = vi.fn(); + constructor(...args: unknown[]) { + mockAgentBridgeConstructor(...args); + } handleMention = mockHandleMention; handleSubscribedMessage = mockHandleSubscribed; }, })); const mockFindLink = vi.fn(); +const mockSetActiveScope = vi.fn(); vi.mock('@/database/models/messengerAccountLink', () => ({ MessengerAccountLinkModel: { findByPlatformUser: (...args: any[]) => mockFindLink(...args), + setActiveScope: (...args: any[]) => mockSetActiveScope(...args), + }, +})); + +// `/switch` and the dispatch-time membership re-validation both enumerate the +// user's workspaces. Default to membership of `workspace-1` so the existing +// workspace-scoped dispatch tests pass; individual tests can override. +const mockListUserWorkspaces = vi.fn(); +vi.mock('@/database/models/workspace', () => ({ + WorkspaceModel: class { + listUserWorkspaces = (...args: any[]) => mockListUserWorkspaces(...args); }, })); vi.mock('@/server/services/aiAgent', () => ({ @@ -184,6 +206,14 @@ beforeEach(() => { telegram: mockWebhookHandler, }; mockFindLink.mockReset(); + mockSetActiveScope.mockReset(); + mockListUserWorkspaces.mockReset(); + mockListUserWorkspaces.mockResolvedValue([ + { id: 'workspace-1', name: 'Workspace 1', role: 'owner' }, + ]); + mockGetServerFeatureFlagsStateFromRuntimeConfig.mockReset(); + mockGetServerFeatureFlagsStateFromRuntimeConfig.mockResolvedValue({ enableWorkspace: true }); + mockAgentBridgeConstructor.mockReset(); mockHandleMention.mockReset(); mockHandleSubscribed.mockReset(); mockOpenDM.mockReset(); @@ -412,6 +442,7 @@ describe('MessengerRouter channel @mention', () => { platformUserId: 'U_ALICE', tenantId: 'T_ACME', userId: 'user_alice', + workspaceId: 'workspace-1', }); const handler = mockChatBot.onNewMention.mock.calls[0][0] as ( @@ -427,6 +458,11 @@ describe('MessengerRouter channel @mention', () => { // first-touch entry, so dispatch goes through `handleMention` (mirrors // BotMessageRouter). expect(mockHandleMention).toHaveBeenCalledTimes(1); + expect(mockAgentBridgeConstructor).toHaveBeenCalledWith( + expect.anything(), + 'user_alice', + 'workspace-1', + ); expect(mockHandleMention.mock.calls[0][2]).toMatchObject({ agentId: 'agt_main' }); expect(mockHandleSubscribed).not.toHaveBeenCalled(); // We deliberately do NOT subscribe channel threads — see comment in @@ -436,6 +472,35 @@ describe('MessengerRouter channel @mention', () => { expect(mockSlackBinder.replyEphemeral).not.toHaveBeenCalled(); }); + it('skips dispatch and prompts /switch when the active workspace is no longer accessible', async () => { + await loadSlackBot(); + mockFindLink.mockResolvedValue({ + activeAgentId: 'agt_main', + id: 'link_1', + platformUserId: 'U_ALICE', + tenantId: 'T_ACME', + userId: 'user_alice', + workspaceId: 'workspace-removed', + }); + // The user is only a member of workspace-1 now — not workspace-removed. + mockListUserWorkspaces.mockResolvedValue([ + { id: 'workspace-1', name: 'Workspace 1', role: 'owner' }, + ]); + + const handler = mockChatBot.onNewMention.mock.calls[0][0] as ( + thread: any, + msg: any, + ) => Promise; + await handler(fakeChannelThread(), fakeMessage({ isMention: true, text: '<@U_BOT> hi' })); + + expect(mockHandleMention).not.toHaveBeenCalled(); + expect(mockSlackBinder.replyEphemeral).toHaveBeenCalledTimes(1); + expect(mockSlackBinder.replyEphemeral.mock.calls[0][0]).toMatchObject({ + channelId: 'C_GENERAL', + userId: 'U_ALICE', + }); + }); + it('routes an unlinked channel mention through handleUnlinkedMessage with channelMentionThreadId', async () => { await loadSlackBot(); mockFindLink.mockResolvedValue(null); @@ -960,3 +1025,133 @@ describe('MessengerRouter onSubscribedMessage gating', () => { expect(mockHandleSubscribed).toHaveBeenCalledTimes(1); }); }); + +describe('MessengerRouter /switch', () => { + const personalLink = { + activeAgentId: 'agt_main', + id: 'link_1', + platformUserId: 'U_ALICE', + tenantId: '', + userId: 'user_alice', + workspaceId: null, + }; + + it('renders the scope picker with the current scope marked', async () => { + await loadSlackBot(); + mockFindLink.mockResolvedValue(personalLink); + mockListUserWorkspaces.mockResolvedValue([ + { id: 'workspace-1', name: 'Workspace 1', role: 'owner' }, + ]); + + const handler = mockChatBot.onNewMention.mock.calls[0][0] as ( + thread: any, + msg: any, + ) => Promise; + await handler(fakeDmThread(), fakeMessage({ isMention: true, text: '/switch' })); + + // Picker-capable binder → buttons, not a numbered text list, and no switch + // happens just from listing. + expect(mockSetActiveScope).not.toHaveBeenCalled(); + expect(mockSlackBinder.sendAgentPicker).toHaveBeenCalledTimes(1); + const params = mockSlackBinder.sendAgentPicker.mock.calls[0][1]; + expect(params.action).toBe('scope'); + expect(params.text).toContain('Tap a scope'); + // Personal (the active scope) carries the marker; workspaces follow. + expect(params.entries).toEqual([ + { id: 'personal', isActive: true, title: 'Personal' }, + { id: 'workspace-1', isActive: false, title: 'Workspace 1' }, + ]); + }); + + it('hides workspace scopes from /switch when workspace feature is disabled', async () => { + await loadSlackBot(); + mockFindLink.mockResolvedValue(personalLink); + mockGetServerFeatureFlagsStateFromRuntimeConfig.mockResolvedValue({ enableWorkspace: false }); + mockListUserWorkspaces.mockResolvedValue([ + { id: 'workspace-1', name: 'Workspace 1', role: 'owner' }, + ]); + + const handler = mockChatBot.onNewMention.mock.calls[0][0] as ( + thread: any, + msg: any, + ) => Promise; + await handler(fakeDmThread(), fakeMessage({ isMention: true, text: '/switch' })); + + expect(mockSetActiveScope).not.toHaveBeenCalled(); + expect(mockListUserWorkspaces).not.toHaveBeenCalled(); + expect(mockSlackBinder.sendAgentPicker).toHaveBeenCalledTimes(1); + const params = mockSlackBinder.sendAgentPicker.mock.calls[0][1]; + expect(params.entries).toEqual([{ id: 'personal', isActive: true, title: 'Personal' }]); + }); + + it('rejects workspace scope button callbacks when workspace feature is disabled', async () => { + mockFindLink.mockResolvedValue(personalLink); + mockGetServerFeatureFlagsStateFromRuntimeConfig.mockResolvedValue({ enableWorkspace: false }); + mockListUserWorkspaces.mockResolvedValue([ + { id: 'workspace-1', name: 'Workspace 1', role: 'owner' }, + ]); + + const router = new MessengerRouter(); + const ack = vi.fn(); + await (router as any).handleScopeCallback( + { platform: 'slack', tenantId: '' }, + { + callbackId: '', + chatId: 'D_DM', + data: 'messenger:scope:workspace-1', + fromUserId: 'U_ALICE', + }, + 'workspace-1', + ack, + ); + + expect(mockSetActiveScope).not.toHaveBeenCalled(); + expect(mockListUserWorkspaces).not.toHaveBeenCalled(); + expect(ack).toHaveBeenCalledWith({ toast: 'Scope not found.' }); + }); + + it('persists the scope switch and re-renders the picker when a scope button is tapped', async () => { + mockFindLink.mockResolvedValue(personalLink); + mockListUserWorkspaces.mockResolvedValue([ + { id: 'workspace-1', name: 'Workspace 1', role: 'owner' }, + ]); + // `fetchUserAgents` pins the inbox/LobeAI first; the switch lands on it. + vi.spyOn(MessengerRouter.prototype as any, 'fetchUserAgents').mockResolvedValue([ + { id: 'agt_inbox', title: 'LobeAI' }, + ]); + + const router = new MessengerRouter(); + const ack = vi.fn(); + await (router as any).handleScopeCallback( + { platform: 'slack', tenantId: '' }, + { + callbackId: '', + chatId: 'D_DM', + data: 'messenger:scope:workspace-1', + fromUserId: 'U_ALICE', + }, + 'workspace-1', + ack, + ); + + // Active agent is set to the target scope's default (the inbox), not cleared. + expect(mockSetActiveScope).toHaveBeenCalledWith( + expect.anything(), + 'link_1', + 'workspace-1', + 'agt_inbox', + ); + expect(ack).toHaveBeenCalledWith( + expect.objectContaining({ + toast: expect.stringContaining('Workspace 1'), + updatedPicker: expect.objectContaining({ + action: 'scope', + entries: [ + { id: 'personal', isActive: false, title: 'Personal' }, + { id: 'workspace-1', isActive: true, title: 'Workspace 1' }, + ], + }), + }), + ); + }); +}); diff --git a/src/server/services/messenger/MessengerRouter.ts b/src/server/services/messenger/MessengerRouter.ts index cb4b4c241c..31b2ae6362 100644 --- a/src/server/services/messenger/MessengerRouter.ts +++ b/src/server/services/messenger/MessengerRouter.ts @@ -13,9 +13,12 @@ import { and, desc, eq, ne, or } from 'drizzle-orm'; import type { MessengerPlatform } from '@/config/messenger'; import { getServerDB } from '@/database/core/db-adaptor'; import { MessengerAccountLinkModel } from '@/database/models/messengerAccountLink'; +import { WorkspaceModel } from '@/database/models/workspace'; import type { MessengerAccountLinkItem } from '@/database/schemas'; import { agents } from '@/database/schemas'; import type { LobeChatDatabase } from '@/database/type'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; +import { getServerFeatureFlagsStateFromRuntimeConfig } from '@/server/featureFlags'; import { getAgentRuntimeRedisClient } from '@/server/modules/AgentRuntime/redis'; import { AiAgentService } from '@/server/services/aiAgent'; import { AgentBridgeService } from '@/server/services/bot/AgentBridgeService'; @@ -31,10 +34,23 @@ import { import { getInstallationStore } from './installations'; import type { InstallationCredentials } from './installations/types'; import { messengerPlatformRegistry } from './platforms'; -import type { AgentPickerEntry, InboundCallbackAction, MessengerPlatformBinder } from './types'; +import type { + AgentPickerEntry, + CallbackAcknowledgement, + InboundCallbackAction, + MessengerPlatformBinder, +} from './types'; const log = debug('lobe-server:messenger:router'); +/** + * Sentinel scope token for the Personal scope (whose real `workspaceId` is + * `null`) so it can ride inside a button id (`messenger:scope:personal`), + * which can't carry a null. Real workspaces use their own id. Workspace ids + * are generated nanoids, so they never collide with this literal. + */ +const PERSONAL_SCOPE_ID = 'personal'; + interface RegisteredMessengerBot { binder: MessengerPlatformBinder; chatBot: Chat; @@ -116,6 +132,7 @@ interface MessengerCommand { const HELP_TEXT = [ 'Commands:', '• /start — bind (or rebind) your LobeHub account', + '• /switch — switch the active scope (personal or a workspace)', '• /agents — list your agents and switch the active one', '• /new — start a new conversation', '• /stop — stop the current execution', @@ -171,6 +188,27 @@ const reconstructRequest = (req: Request, rawBody: string): Request => method: req.method, } as RequestInit); +/** + * Whether `userId` is still a member of `workspaceId`. The link's active scope + * may point at a workspace the user has since been removed from; dispatch must + * re-check before running that workspace's agent. + */ +const userIsWorkspaceMember = async ( + db: LobeChatDatabase, + userId: string, + workspaceId: string, +): Promise => { + if (!(await isWorkspaceFeatureEnabledForUser(userId))) return false; + + const workspaces = await new WorkspaceModel(db, userId).listUserWorkspaces(); + return workspaces.some((w) => w.id === workspaceId); +}; + +const isWorkspaceFeatureEnabledForUser = async (userId: string): Promise => { + const featureFlags = await getServerFeatureFlagsStateFromRuntimeConfig(userId); + return featureFlags.enableWorkspace === true; +}; + /** * Routes inbound messages from the shared Messenger bots to the right * LobeHub user + agent. @@ -506,6 +544,29 @@ export class MessengerRouter { return; } + // Re-validate the active scope: it may point at a workspace the user + // has since been removed from. Don't run another workspace's agent for + // a non-member — prompt them to /switch back to a scope they can use. + if ( + link.workspaceId && + !(await userIsWorkspaceMember(serverDB, link.userId, link.workspaceId)) + ) { + const staleScopeText = + 'Your active workspace is no longer available. Send /switch to choose another scope.'; + if (isChannelMention && binder.replyEphemeral) { + const threadTs = String(thread.id).split(':')[2]; + await binder.replyEphemeral({ + channelId: chatId, + text: staleScopeText, + threadTs, + userId: senderId, + }); + } else { + await binder.sendDmText(chatId, staleScopeText); + } + return; + } + await this.dispatchToAgent( thread, message, @@ -772,6 +833,22 @@ export class MessengerRouter { }, name: 'agents', }, + { + description: 'Switch the active scope (personal or a workspace)', + // Declared so Discord/Slack surface a `/switch ` argument in the + // slash picker; the no-arg form just lists the scopes. + options: [ + { + description: 'Scope number from the /switch list', + name: 'scope', + required: false, + }, + ], + handler: async (ctx) => { + await this.runSwitchCommand(ctx); + }, + name: 'switch', + }, { description: 'Start a new conversation', handler: async (ctx) => { @@ -818,7 +895,9 @@ export class MessengerRouter { const operationId = AgentBridgeService.getActiveOperationId(ctx.thread.id); if (operationId) { try { - const aiAgentService = new AiAgentService(ctx.serverDB, ctx.link.userId); + const aiAgentService = new AiAgentService(ctx.serverDB, ctx.link.userId, { + workspaceId: ctx.link.workspaceId ?? undefined, + }); const result = await aiAgentService.interruptTask({ operationId }); if (!result.success) { log('command /stop: runtime interrupt rejected for op=%s', operationId); @@ -1022,7 +1101,7 @@ export class MessengerRouter { return; } - const userAgents = await this.fetchUserAgents(serverDB, link.userId); + const userAgents = await this.fetchUserAgents(serverDB, link.userId, link.workspaceId); if (userAgents.length === 0) { await ctx.reply('You have no agents yet. Create one in LobeHub, then come back to /agents.'); return; @@ -1076,6 +1155,122 @@ export class MessengerRouter { ); } + /** + * `/switch` changes the active *scope* of the IM session — personal or one + * of the workspaces the user belongs to. The bot is a single shared bot, so + * which LobeHub context a conversation runs in is the active agent's scope; + * switching scope clears the active agent and the user re-picks via /agents. + * + * Mirrors `/agents`: on platforms that implement `sendAgentPicker` the bot + * replies with a tap-to-switch keyboard (buttons emit `messenger:scope:` + * so the callback path can tell them apart from agent switches). Platforms + * without keyboard support fall back to a numbered text list + `/switch `. + */ + private async runSwitchCommand(ctx: MessengerCommandContext): Promise { + const { binder, chatId, link, serverDB } = ctx; + if (!link) { + await ctx.reply('You need to /start to bind your account first.'); + return; + } + + const scopes = await this.fetchUserScopes(serverDB, link.userId); + + // Text-fallback path: `/switch 2` switches without needing the keyboard, + // for platforms (or clients) where tap-buttons aren't available. + const arg = ctx.args.trim(); + if (arg && !binder.sendAgentPicker) { + const index = Number.parseInt(arg, 10); + if (!Number.isInteger(index) || index < 1 || index > scopes.length) { + await ctx.reply(`Usage: /switch , where n is between 1 and ${scopes.length}.`); + return; + } + const target = scopes[index - 1]; + if ((link.workspaceId ?? null) === target.id) { + await ctx.reply(`You're already in ${target.name}.`); + return; + } + const defaultAgent = await this.applyScopeSwitch(serverDB, link.id, link.userId, target); + await ctx.reply(this.scopeSwitchedText(target.name, defaultAgent?.title)); + return; + } + + if (binder.sendAgentPicker) { + await binder.sendAgentPicker(chatId, { + action: 'scope', + entries: this.toScopeEntries(scopes, link.workspaceId ?? null), + // Channel invocation → ephemeral so the user's personal workspace list + // isn't broadcast (mirrors the `/agents` picker rationale). + ephemeralTo: ctx.isDM ? undefined : ctx.authorUserId, + // Discord-only: forward the slash interaction so the binder can + // complete the deferred reply via the follow-up webhook. + interaction: ctx.interaction, + text: 'Tap a scope to switch:', + }); + return; + } + + // Final fallback: numbered list + usage hint for `/switch `. + const lines = scopes.map((scope, i) => { + const marker = (link.workspaceId ?? null) === scope.id ? ' (current)' : ''; + return `${i + 1}. ${scope.name}${marker}`; + }); + await ctx.reply(`Scopes:\n${lines.join('\n')}\n\nReply with /switch to switch scope.`); + } + + /** + * Resolve the user's switchable scopes: Personal (id `null`) followed by + * each workspace they belong to. + */ + private async fetchUserScopes( + serverDB: LobeChatDatabase, + userId: string, + ): Promise<{ id: string | null; name: string }[]> { + if (!(await isWorkspaceFeatureEnabledForUser(userId))) return [{ id: null, name: 'Personal' }]; + + const workspaces = await new WorkspaceModel(serverDB, userId).listUserWorkspaces(); + return [{ id: null, name: 'Personal' }, ...workspaces.map((w) => ({ id: w.id, name: w.name }))]; + } + + /** + * Persist a scope switch and land on the new scope's default agent + * (inbox/LobeAI is pinned first by `fetchUserAgents`) so switching never + * leaves the session agent-less. Falls back to `null` only when the target + * scope has no agents yet. + */ + private async applyScopeSwitch( + serverDB: LobeChatDatabase, + linkId: string, + userId: string, + target: { id: string | null; name: string }, + ): Promise { + const scopeAgents = await this.fetchUserAgents(serverDB, userId, target.id); + const defaultAgent = scopeAgents[0]; + await MessengerAccountLinkModel.setActiveScope( + serverDB, + linkId, + target.id, + defaultAgent?.id ?? null, + ); + return defaultAgent; + } + + private scopeSwitchedText(scopeName: string, defaultAgentTitle?: string): string { + return defaultAgentTitle + ? `Switched to ${scopeName}. Now chatting with ${defaultAgentTitle}. Send /agents to change.` + : `Switched to ${scopeName}. No agents here yet — create one in LobeHub, then /agents.`; + } + + private toScopeEntries( + scopes: { id: string | null; name: string }[], + currentScopeId: string | null, + ): AgentPickerEntry[] { + return scopes.map((scope) => ({ + id: scope.id ?? PERSONAL_SCOPE_ID, + isActive: scope.id === currentScopeId, + title: scope.name, + })); + } + private toPickerEntries( userAgents: AgentSummary[], activeAgentId: string | null | undefined, @@ -1134,7 +1329,7 @@ export class MessengerRouter { let activeAgentName: string | undefined; if (link.activeAgentId) { - const userAgents = await this.fetchUserAgents(serverDB, link.userId); + const userAgents = await this.fetchUserAgents(serverDB, link.userId, link.workspaceId); activeAgentName = userAgents.find((a) => a.id === link.activeAgentId)?.title; } @@ -1194,8 +1389,9 @@ export class MessengerRouter { * (Slack/Telegram) or chat-sdk's `onAction` event (Discord). Both paths * normalize to the same `InboundCallbackAction` shape and delegate the * outbound ack (toast + picker re-render) to `binder.acknowledgeCallback`. - * Today only `messenger:switch:` is recognized; new actions can - * be added by extending the dispatch below. + * Recognizes `messenger:switch:` (agent picker) and + * `messenger:scope:` (scope picker); new actions can be added by + * extending the dispatch below. */ private async handleCallbackAction( binder: MessengerPlatformBinder, @@ -1206,6 +1402,12 @@ export class MessengerRouter { const ack = binder.acknowledgeCallback.bind(binder, action); + const scopeMatch = action.data.match(/^messenger:scope:(.+)$/); + if (scopeMatch) { + await this.handleScopeCallback(creds, action, scopeMatch[1], ack); + return; + } + const switchMatch = action.data.match(/^messenger:switch:(.+)$/); if (!switchMatch) { await ack({ toast: 'Unknown action.' }); @@ -1225,7 +1427,7 @@ export class MessengerRouter { return; } - const userAgents = await this.fetchUserAgents(serverDB, link.userId); + const userAgents = await this.fetchUserAgents(serverDB, link.userId, link.workspaceId); const target = userAgents.find((agent) => agent.id === targetAgentId); if (!target) { await ack({ toast: 'Agent not found.' }); @@ -1247,6 +1449,54 @@ export class MessengerRouter { }); } + /** + * Handle a tap on a `/switch` scope button (`messenger:scope:`). + * `scopeToken` is the workspace id, or `PERSONAL_SCOPE_ID` for Personal. + * Switches the active scope (landing on the scope's default agent) and + * re-renders the picker with the new current marker. + */ + private async handleScopeCallback( + creds: InstallationCredentials, + action: InboundCallbackAction, + scopeToken: string, + ack: (ack: CallbackAcknowledgement) => Promise, + ): Promise { + const serverDB = await getServerDB(); + const link = await MessengerAccountLinkModel.findByPlatformUser( + serverDB, + creds.platform, + action.fromUserId, + creds.tenantId, + ); + if (!link) { + await ack({ toast: 'Not linked. Send /start first.' }); + return; + } + + const targetScopeId = scopeToken === PERSONAL_SCOPE_ID ? null : scopeToken; + const scopes = await this.fetchUserScopes(serverDB, link.userId); + const target = scopes.find((scope) => scope.id === targetScopeId); + if (!target) { + await ack({ toast: 'Scope not found.' }); + return; + } + + if ((link.workspaceId ?? null) === target.id) { + await ack({ toast: `You're already in ${target.name}.` }); + return; + } + + const defaultAgent = await this.applyScopeSwitch(serverDB, link.id, link.userId, target); + await ack({ + toast: this.scopeSwitchedText(target.name, defaultAgent?.title), + updatedPicker: { + action: 'scope', + entries: this.toScopeEntries(scopes, target.id), + text: 'Tap a scope to switch:', + }, + }); + } + /** * Fetch a user's agents for `/agents`. Mirrors the web * verify-im picker (and the home sidebar): @@ -1259,13 +1509,14 @@ export class MessengerRouter { private async fetchUserAgents( serverDB: LobeChatDatabase, userId: string, + workspaceId?: string | null, ): Promise { const rows = await serverDB .select({ id: agents.id, slug: agents.slug, title: agents.title }) .from(agents) .where( and( - eq(agents.userId, userId), + buildWorkspaceWhere({ userId, workspaceId: workspaceId ?? undefined }, agents), or(ne(agents.virtual, true), eq(agents.slug, INBOX_SESSION_ID)), ), ) @@ -1308,7 +1559,7 @@ export class MessengerRouter { ); const serverDB = await getServerDB(); - const bridge = new AgentBridgeService(serverDB, link.userId); + const bridge = new AgentBridgeService(serverDB, link.userId, link.workspaceId ?? undefined); // Messenger account-link routing already binds platform sender → // LobeHub user; the dispatch only fires for the linked sender. So diff --git a/src/server/services/messenger/platforms/discord/binder.test.ts b/src/server/services/messenger/platforms/discord/binder.test.ts index 2228a9c829..c1a973b3be 100644 --- a/src/server/services/messenger/platforms/discord/binder.test.ts +++ b/src/server/services/messenger/platforms/discord/binder.test.ts @@ -70,6 +70,20 @@ describe('buildDiscordSwitchButtons', () => { { customId: 'messenger:switch:agt_2', isPrimary: true, label: '✓ Coding' }, ]); }); + + it('uses the scope action namespace so /switch buttons are distinct from /agents', () => { + const buttons = buildDiscordSwitchButtons( + [ + { id: 'personal', isActive: true, title: 'Personal' }, + { id: 'workspace-1', isActive: false, title: 'love' }, + ], + 'scope', + ); + expect(buttons).toEqual([ + { customId: 'messenger:scope:personal', isPrimary: true, label: '✓ Personal' }, + { customId: 'messenger:scope:workspace-1', isPrimary: false, label: 'love' }, + ]); + }); }); describe('MessengerDiscordBinder', () => { diff --git a/src/server/services/messenger/platforms/discord/binder.ts b/src/server/services/messenger/platforms/discord/binder.ts index 253f35e0cf..80037e65b1 100644 --- a/src/server/services/messenger/platforms/discord/binder.ts +++ b/src/server/services/messenger/platforms/discord/binder.ts @@ -12,6 +12,7 @@ import type { AgentPickerEntry, CallbackAcknowledgement, InboundCallbackAction, + MessengerPickerAction, MessengerPlatformBinder, UnlinkedMessageContext, } from '../../types'; @@ -28,9 +29,10 @@ export const DISCORD_ACTION_PREFIX = 'messenger:'; export const buildDiscordSwitchButtons = ( entries: AgentPickerEntry[], + action: MessengerPickerAction = 'switch', ): Array<{ customId: string; isPrimary: boolean; label: string }> => entries.map((entry) => ({ - customId: `${DISCORD_ACTION_PREFIX}switch:${entry.id}`, + customId: `${DISCORD_ACTION_PREFIX}${action}:${entry.id}`, isPrimary: entry.isActive, // Prepend a check so the active option is recognizable on clients that // ignore the primary-style highlight (most Discord clients do honor it, @@ -229,6 +231,7 @@ export class MessengerDiscordBinder implements MessengerPlatformBinder { chatId: string, params: { entries: AgentPickerEntry[]; + action?: MessengerPickerAction; ephemeralTo?: string; interaction?: { applicationId: string; token: string }; text: string; @@ -241,7 +244,7 @@ export class MessengerDiscordBinder implements MessengerPlatformBinder { } try { const api = new DiscordApi(config.botToken); - const buttons = buildDiscordSwitchButtons(params.entries); + const buttons = buildDiscordSwitchButtons(params.entries, params.action); if (params.interaction) { await api.editInteractionOriginalWithButtons( params.interaction.applicationId, @@ -266,7 +269,7 @@ export class MessengerDiscordBinder implements MessengerPlatformBinder { async updateAgentPicker( chatId: string, messageId: string, - params: { entries: AgentPickerEntry[]; text: string }, + params: { action?: MessengerPickerAction; entries: AgentPickerEntry[]; text: string }, ): Promise { const config = await getMessengerDiscordConfig(); if (!config) return; @@ -276,7 +279,7 @@ export class MessengerDiscordBinder implements MessengerPlatformBinder { chatId, messageId, params.text, - buildDiscordSwitchButtons(params.entries), + buildDiscordSwitchButtons(params.entries, params.action), ); } catch (error) { log('updateAgentPicker: failed for chat=%s msg=%s: %O', chatId, messageId, error); diff --git a/src/server/services/messenger/platforms/slack/binder.ts b/src/server/services/messenger/platforms/slack/binder.ts index a8315f6abe..46a65e66e5 100644 --- a/src/server/services/messenger/platforms/slack/binder.ts +++ b/src/server/services/messenger/platforms/slack/binder.ts @@ -12,6 +12,7 @@ import type { AgentPickerEntry, CallbackAcknowledgement, InboundCallbackAction, + MessengerPickerAction, MessengerPlatformBinder, UnlinkedMessageContext, } from '../../types'; @@ -54,9 +55,10 @@ const buildVerifyImUrl = (params: { const buildSwitchButtons = ( entries: AgentPickerEntry[], + action: MessengerPickerAction = 'switch', ): Array<{ actionId: string; style?: 'primary'; text: string; value: string }> => entries.map((entry) => ({ - actionId: `${ACTION_PREFIX}switch:${entry.id}`, + actionId: `${ACTION_PREFIX}${action}:${entry.id}`, text: entry.isActive ? `✅ ${entry.title}` : entry.title, value: entry.id, ...(entry.isActive ? { style: 'primary' as const } : {}), @@ -262,7 +264,12 @@ export class MessengerSlackBinder implements MessengerPlatformBinder { async sendAgentPicker( chatId: string, - params: { entries: AgentPickerEntry[]; ephemeralTo?: string; text: string }, + params: { + action?: MessengerPickerAction; + entries: AgentPickerEntry[]; + ephemeralTo?: string; + text: string; + }, ): Promise { if (!this.creds) { log('sendAgentPicker: no creds, skipping'); @@ -270,7 +277,7 @@ export class MessengerSlackBinder implements MessengerPlatformBinder { } try { const api = new SlackApi(this.creds.botToken); - const buttons = buildSwitchButtons(params.entries); + const buttons = buildSwitchButtons(params.entries, params.action); // Channel invocation → ephemeral so the user's personal agent list // isn't broadcast. Slack's `chat.postEphemeral` accepts blocks and // delivers interactive button taps just like `chat.postMessage` — @@ -384,7 +391,7 @@ export class MessengerSlackBinder implements MessengerPlatformBinder { // only when the response_url is somehow absent and we have a permanent // ts to point at. if (ack.updatedPicker) { - const buttons = buildSwitchButtons(ack.updatedPicker.entries); + const buttons = buildSwitchButtons(ack.updatedPicker.entries, ack.updatedPicker.action); try { if (action.callbackId) { await api.updateEphemeralButtonGrid(action.callbackId, ack.updatedPicker.text, buttons); diff --git a/src/server/services/messenger/platforms/telegram/binder.ts b/src/server/services/messenger/platforms/telegram/binder.ts index 0a154b9b6d..496ae8bacd 100644 --- a/src/server/services/messenger/platforms/telegram/binder.ts +++ b/src/server/services/messenger/platforms/telegram/binder.ts @@ -11,6 +11,7 @@ import type { AgentPickerEntry, CallbackAcknowledgement, InboundCallbackAction, + MessengerPickerAction, MessengerPlatformBinder, UnlinkedMessageContext, } from '../../types'; @@ -24,10 +25,11 @@ const CALLBACK_PREFIX = 'messenger:'; const buildSwitchKeyboard = ( entries: AgentPickerEntry[], + action: MessengerPickerAction = 'switch', ): Array> => entries.map((entry) => [ { - callback_data: `${CALLBACK_PREFIX}switch:${entry.id}`, + callback_data: `${CALLBACK_PREFIX}${action}:${entry.id}`, text: entry.isActive ? `✅ ${entry.title}` : entry.title, }, ]); @@ -180,7 +182,7 @@ export class MessengerTelegramBinder implements MessengerPlatformBinder { async sendAgentPicker( chatId: string, - params: { entries: AgentPickerEntry[]; text: string }, + params: { action?: MessengerPickerAction; entries: AgentPickerEntry[]; text: string }, ): Promise { const config = await getMessengerTelegramConfig(); if (!config) return; @@ -189,7 +191,7 @@ export class MessengerTelegramBinder implements MessengerPlatformBinder { await api.sendMessageWithCallbackKeyboard( chatId, escapeHtml(params.text), - buildSwitchKeyboard(params.entries), + buildSwitchKeyboard(params.entries, params.action), ); } catch (error) { log('sendAgentPicker: failed for chat=%s: %O', chatId, error); @@ -244,7 +246,7 @@ export class MessengerTelegramBinder implements MessengerPlatformBinder { action.chatId, messageId as number, escapeHtml(ack.updatedPicker.text), - buildSwitchKeyboard(ack.updatedPicker.entries), + buildSwitchKeyboard(ack.updatedPicker.entries, ack.updatedPicker.action), ); } catch (error) { log('acknowledgeCallback: edit picker failed: %O', error); diff --git a/src/server/services/messenger/types.ts b/src/server/services/messenger/types.ts index 46d1a1fd1f..c6d2fb17a9 100644 --- a/src/server/services/messenger/types.ts +++ b/src/server/services/messenger/types.ts @@ -24,6 +24,15 @@ export interface AgentPickerEntry { title: string; } +/** + * Which tap-action a picker's buttons emit. `switch` re-targets the active + * agent (`/agents`); `scope` re-targets the active workspace scope + * (`/switch`). The keyword becomes the middle segment of the button id — + * `messenger:switch:` vs `messenger:scope:` — so the + * router's callback dispatch can tell the two pickers apart. + */ +export type MessengerPickerAction = 'switch' | 'scope'; + /** Raw inbound platform update used for actions chat-sdk doesn't surface. */ export interface InboundCallbackAction { /** Platform-specific raw id needed to acknowledge the action. */ @@ -42,8 +51,10 @@ export interface InboundCallbackAction { export interface CallbackAcknowledgement { /** Optional toast text shown above the user's keyboard. */ toast?: string; - /** When set, edit the picker message in place to reflect the new state. */ - updatedPicker?: { entries: AgentPickerEntry[]; text: string }; + /** When set, edit the picker message in place to reflect the new state. + * `action` defaults to `switch` so existing agent pickers re-render with + * the same button namespace they were posted with. */ + updatedPicker?: { action?: MessengerPickerAction; entries: AgentPickerEntry[]; text: string }; } /** @@ -179,6 +190,8 @@ export interface MessengerPlatformBinder { sendAgentPicker?: ( chatId: string, params: { + /** Tap-action the buttons emit; defaults to `switch` (agent picker). */ + action?: MessengerPickerAction; entries: AgentPickerEntry[]; ephemeralTo?: string; interaction?: { applicationId: string; token: string }; diff --git a/src/server/services/notebook/index.ts b/src/server/services/notebook/index.ts index 2e936b8f42..0ac6a774b6 100644 --- a/src/server/services/notebook/index.ts +++ b/src/server/services/notebook/index.ts @@ -21,6 +21,7 @@ interface DocumentServiceResult { export interface NotebookRuntimeServiceOptions { serverDB: LobeChatDatabase; userId: string; + workspaceId?: string; } const toServiceResult = (doc: { @@ -53,9 +54,17 @@ export class NotebookRuntimeService { private topicDocumentModel: TopicDocumentModel; constructor(options: NotebookRuntimeServiceOptions) { - this.documentService = new DocumentService(options.serverDB, options.userId); - this.documentModel = new DocumentModel(options.serverDB, options.userId); - this.topicDocumentModel = new TopicDocumentModel(options.serverDB, options.userId); + this.documentService = new DocumentService( + options.serverDB, + options.userId, + options.workspaceId, + ); + this.documentModel = new DocumentModel(options.serverDB, options.userId, options.workspaceId); + this.topicDocumentModel = new TopicDocumentModel( + options.serverDB, + options.userId, + options.workspaceId, + ); } associateDocumentWithTopic = async (documentId: string, topicId: string): Promise => { diff --git a/src/server/services/skill/importer.ts b/src/server/services/skill/importer.ts index 067645a877..667e6090c6 100644 --- a/src/server/services/skill/importer.ts +++ b/src/server/services/skill/importer.ts @@ -30,11 +30,11 @@ export class SkillImporter { private github: GitHub; private userId: string; - constructor(db: LobeChatDatabase, userId: string) { - this.skillModel = new AgentSkillModel(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.skillModel = new AgentSkillModel(db, userId, workspaceId); this.parser = new SkillParser(); - this.resourceService = new SkillResourceService(db, userId); - this.fileService = new FileService(db, userId); + this.resourceService = new SkillResourceService(db, userId, workspaceId); + this.fileService = new FileService(db, userId, workspaceId); this.github = new GitHub({ userAgent: 'LobeHub-Skill-Importer' }); this.userId = userId; } diff --git a/src/server/services/skill/resource.ts b/src/server/services/skill/resource.ts index 951ce92626..7edfdfc978 100644 --- a/src/server/services/skill/resource.ts +++ b/src/server/services/skill/resource.ts @@ -31,8 +31,8 @@ function isTextMimeType(mimeType: string): boolean { export class SkillResourceService { private fileService: FileService; - constructor(db: LobeChatDatabase, userId: string) { - this.fileService = new FileService(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.fileService = new FileService(db, userId, workspaceId); } /** diff --git a/src/server/services/skillManagement/SkillManagementDocumentService.test.ts b/src/server/services/skillManagement/SkillManagementDocumentService.test.ts index 835b9f55fb..de96eeff56 100644 --- a/src/server/services/skillManagement/SkillManagementDocumentService.test.ts +++ b/src/server/services/skillManagement/SkillManagementDocumentService.test.ts @@ -272,6 +272,7 @@ const createService = () => { callback({} as Transaction), } as LobeChatDatabase, 'user-1', + undefined, { agentDocumentModel, createMarkdownEditorSnapshot: createSnapshot, diff --git a/src/server/services/skillManagement/SkillManagementDocumentService.ts b/src/server/services/skillManagement/SkillManagementDocumentService.ts index bdede2a219..9ce8753caf 100644 --- a/src/server/services/skillManagement/SkillManagementDocumentService.ts +++ b/src/server/services/skillManagement/SkillManagementDocumentService.ts @@ -99,12 +99,14 @@ export class SkillManagementDocumentService { constructor( private db: LobeChatDatabase, userId: string, + workspaceId?: string, deps?: SkillManagementDocumentServiceDeps, ) { - this.agentDocumentModel = deps?.agentDocumentModel ?? new AgentDocumentModel(db, userId); + this.agentDocumentModel = + deps?.agentDocumentModel ?? new AgentDocumentModel(db, userId, workspaceId); this.createMarkdownEditorSnapshot = deps?.createMarkdownEditorSnapshot ?? createDefaultMarkdownEditorSnapshot; - this.documentService = deps?.documentService ?? new DocumentService(db, userId); + this.documentService = deps?.documentService ?? new DocumentService(db, userId, workspaceId); } /** diff --git a/src/server/services/systemAgent/index.ts b/src/server/services/systemAgent/index.ts index 727a0256d7..24f9bea41e 100644 --- a/src/server/services/systemAgent/index.ts +++ b/src/server/services/systemAgent/index.ts @@ -36,10 +36,12 @@ const TOPIC_TITLE_SCHEMA = { export class SystemAgentService { private readonly db: LobeChatDatabase; private readonly userId: string; + private readonly workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } /** @@ -66,7 +68,12 @@ export class SystemAgentService { const payload = chainSummaryTitle(messages, locale); - const modelRuntime = await initModelRuntimeFromDB(this.db, this.userId, provider); + const modelRuntime = await initModelRuntimeFromDB( + this.db, + this.userId, + provider, + this.workspaceId, + ); const result = await modelRuntime.generateObject( { messages: payload.messages as any[], diff --git a/src/server/services/task/index.ts b/src/server/services/task/index.ts index 5b77fbdba6..d552b841a7 100644 --- a/src/server/services/task/index.ts +++ b/src/server/services/task/index.ts @@ -74,15 +74,18 @@ export class TaskService { private topicModel: TopicModel; private userId: string; - constructor(db: LobeChatDatabase, userId: string) { + private workspaceId?: string; + + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; - this.agentModel = new AgentModel(db, userId); - this.taskModel = new TaskModel(db, userId); - this.taskTopicModel = new TaskTopicModel(db, userId); - this.topicModel = new TopicModel(db, userId); - this.briefModel = new BriefModel(db, userId); - this.briefService = new BriefService(db, userId); + this.workspaceId = workspaceId; + this.agentModel = new AgentModel(db, userId, workspaceId); + this.taskModel = new TaskModel(db, userId, workspaceId); + this.taskTopicModel = new TaskTopicModel(db, userId, workspaceId); + this.topicModel = new TopicModel(db, userId, workspaceId); + this.briefModel = new BriefModel(db, userId, workspaceId); + this.briefService = new BriefService(db, userId, workspaceId); } /** @@ -125,7 +128,9 @@ export class TaskService { } if (target.operationId) { - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); await aiAgentService.interruptTask({ operationId: target.operationId }); } @@ -142,7 +147,9 @@ export class TaskService { if (!target) throw new TRPCError({ code: 'NOT_FOUND', message: 'Topic not found.' }); if (target.status === 'running' && target.operationId) { - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); await aiAgentService.interruptTask({ operationId: target.operationId }); } @@ -186,7 +193,7 @@ export class TaskService { if (target?.reviewIteration) iteration = target.reviewIteration + 1; } - const reviewService = new TaskReviewService(this.db, this.userId); + const reviewService = new TaskReviewService(this.db, this.userId, this.workspaceId); const result = await reviewService.review({ content, iteration, @@ -231,7 +238,9 @@ export class TaskService { if (resolved.status === 'running' && status !== 'running') { const topics = await this.taskTopicModel.findByTaskId(resolved.id); - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); for (const t of topics) { if (t.status !== 'running' || !t.topicId) continue; @@ -296,7 +305,7 @@ export class TaskService { } // Unlock blocked tasks and actually kick them off via the runner. - const runner = new TaskRunnerService(this.db, this.userId); + const runner = new TaskRunnerService(this.db, this.userId, this.workspaceId); const cascade = await runner.cascadeOnCompletion(task.id); unlocked.push(...cascade.started); paused.push(...cascade.paused); @@ -317,7 +326,7 @@ export class TaskService { */ async previewSubtaskLayers(idOrIdentifier: string): Promise { const parent = await this.resolveOrThrow(idOrIdentifier); - const graph = new TaskGraphService(this.db, this.userId); + const graph = new TaskGraphService(this.db, this.userId, this.workspaceId); const { plan } = await graph.planForParent(parent.id); return plan; } @@ -329,7 +338,7 @@ export class TaskService { */ async runReadySubtasks(idOrIdentifier: string): Promise { const parent = await this.resolveOrThrow(idOrIdentifier); - const graph = new TaskGraphService(this.db, this.userId); + const graph = new TaskGraphService(this.db, this.userId, this.workspaceId); const { descendants, plan } = await graph.planForParent(parent.id); if (plan.layers.length === 0) { @@ -343,7 +352,7 @@ export class TaskService { const firstLayer = plan.layers[0]; const identifierToId = new Map(descendants.map((d) => [d.identifier, d.id])); - const runner = new TaskRunnerService(this.db, this.userId); + const runner = new TaskRunnerService(this.db, this.userId, this.workspaceId); const kickedOff: string[] = []; const failed: { error: string; identifier: string }[] = []; @@ -416,7 +425,7 @@ export class TaskService { ]); // Derive fileIds from persisted editor_data (single source of truth). - const extractCtx = { db: this.db, userId: this.userId }; + const extractCtx = { db: this.db, userId: this.userId, workspaceId: this.workspaceId }; const [taskFileIds, ...commentFileIdLists] = await Promise.all([ extractFileIdsFromEditorData(task.editorData, extractCtx), ...comments.map((c) => extractFileIdsFromEditorData(c.editorData, extractCtx)), @@ -432,6 +441,7 @@ export class TaskService { db: this.db, fileIds: allFileIds, userId: this.userId, + workspaceId: this.workspaceId, }); const fileById = new Map(allFileMetadata.map((f) => [f.id, f])); const taskFiles = taskFileIds.map((id) => fileById.get(id)).filter((f) => !!f); diff --git a/src/server/services/taskGraph/index.ts b/src/server/services/taskGraph/index.ts index 3ccafbe180..b0c53dba13 100644 --- a/src/server/services/taskGraph/index.ts +++ b/src/server/services/taskGraph/index.ts @@ -223,8 +223,8 @@ const findCycleMembers = (unplaced: string[], downstream: Map) export class TaskGraphService { private taskModel: TaskModel; - constructor(db: LobeChatDatabase, userId: string) { - this.taskModel = new TaskModel(db, userId); + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { + this.taskModel = new TaskModel(db, userId, workspaceId); } /** diff --git a/src/server/services/taskLifecycle/index.ts b/src/server/services/taskLifecycle/index.ts index 4270f8998b..c4e701bd57 100644 --- a/src/server/services/taskLifecycle/index.ts +++ b/src/server/services/taskLifecycle/index.ts @@ -86,14 +86,17 @@ export class TaskLifecycleService { private topicModel: TopicModel; private userId: string; - constructor(db: LobeChatDatabase, userId: string) { + private workspaceId?: string; + + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; - this.taskModel = new TaskModel(db, userId); - this.taskTopicModel = new TaskTopicModel(db, userId); - this.briefModel = new BriefModel(db, userId); - this.topicModel = new TopicModel(db, userId); - this.systemAgentService = new SystemAgentService(db, userId); + this.workspaceId = workspaceId; + this.taskModel = new TaskModel(db, userId, workspaceId); + this.taskTopicModel = new TaskTopicModel(db, userId, workspaceId); + this.briefModel = new BriefModel(db, userId, workspaceId); + this.topicModel = new TopicModel(db, userId, workspaceId); + this.systemAgentService = new SystemAgentService(db, userId, workspaceId); } /** @@ -354,7 +357,12 @@ export class TaskLifecycleService { taskName: currentTask?.name || taskIdentifier, }); - const modelRuntime = await initModelRuntimeFromDB(this.db, this.userId, provider); + const modelRuntime = await initModelRuntimeFromDB( + this.db, + this.userId, + provider, + this.workspaceId, + ); const result = await modelRuntime.generateObject( { messages: payload.messages as any[], @@ -461,7 +469,12 @@ export class TaskLifecycleService { taskName: currentTask.name || taskIdentifier, }); - const modelRuntime = await initModelRuntimeFromDB(this.db, this.userId, provider); + const modelRuntime = await initModelRuntimeFromDB( + this.db, + this.userId, + provider, + this.workspaceId, + ); const judgeResult = (await modelRuntime.generateObject( { messages: judgePayload.messages as any[], @@ -521,7 +534,12 @@ export class TaskLifecycleService { taskName: currentTask.name || taskIdentifier, }); - const modelRuntime = await initModelRuntimeFromDB(this.db, this.userId, provider); + const modelRuntime = await initModelRuntimeFromDB( + this.db, + this.userId, + provider, + this.workspaceId, + ); const result = await modelRuntime.generateObject( { messages: payload.messages as any[], @@ -596,7 +614,7 @@ export class TaskLifecycleService { const targetTopic = topicLinks.find((t) => t.topicId === topicId); const iteration = (targetTopic?.reviewIteration || 0) + 1; - const reviewService = new TaskReviewService(this.db, this.userId); + const reviewService = new TaskReviewService(this.db, this.userId, this.workspaceId); const reviewResult = await reviewService.review({ content, iteration, @@ -689,7 +707,7 @@ export class TaskLifecycleService { private async cascadeAfterAutoComplete(completedTaskId: string): Promise { try { const { TaskRunnerService } = await import('@/server/services/taskRunner'); - const runner = new TaskRunnerService(this.db, this.userId); + const runner = new TaskRunnerService(this.db, this.userId, this.workspaceId); await runner.cascadeOnCompletion(completedTaskId); } catch (e) { console.warn('[TaskLifecycle] dependency cascade failed:', e); diff --git a/src/server/services/taskReview/index.ts b/src/server/services/taskReview/index.ts index 4a1b42a6e1..8ff697f4b2 100644 --- a/src/server/services/taskReview/index.ts +++ b/src/server/services/taskReview/index.ts @@ -36,10 +36,12 @@ export interface ReviewResult { export class TaskReviewService { private db: LobeChatDatabase; private userId: string; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } async review(params: { @@ -64,7 +66,12 @@ export class TaskReviewService { ); // 2. Initialize ModelRuntime for LLM-based rubrics - const modelRuntime = await initModelRuntimeFromDB(this.db, this.userId, provider); + const modelRuntime = await initModelRuntimeFromDB( + this.db, + this.userId, + provider, + this.workspaceId, + ); // 3. Run evaluate() from @lobechat/eval-rubric const result: EvaluateResult = await evaluate( diff --git a/src/server/services/taskRunner/buildTaskPrompt.ts b/src/server/services/taskRunner/buildTaskPrompt.ts index 143dbbfc3c..81f1cc50f0 100644 --- a/src/server/services/taskRunner/buildTaskPrompt.ts +++ b/src/server/services/taskRunner/buildTaskPrompt.ts @@ -14,6 +14,7 @@ export interface BuildTaskPromptDeps { taskModel: TaskModel; taskTopicModel: TaskTopicModel; userId: string; + workspaceId?: string; } export interface BuiltTaskPrompt { @@ -35,7 +36,7 @@ export async function buildTaskPrompt( deps: BuildTaskPromptDeps, extraPrompt?: string, ): Promise { - const { briefModel, db, taskModel, taskTopicModel, userId } = deps; + const { briefModel, db, taskModel, taskTopicModel, userId, workspaceId } = deps; const [topics, briefs, comments, subtasks, dependencies, documents] = await Promise.all([ task.totalTopics && task.totalTopics > 0 @@ -53,7 +54,7 @@ export async function buildTaskPrompt( // Derive fileIds from the persisted Lexical state. editor_data is the // single source of truth — fileId is recovered from the URL in each node // (proxy URL form via regex; pre-signed dev URLs via files.url lookup). - const extractCtx = { db, userId }; + const extractCtx = { db, userId, workspaceId }; const [taskFileIds, ...commentFileIdLists] = await Promise.all([ extractFileIdsFromEditorData(task.editorData, extractCtx), ...comments.map((c) => extractFileIdsFromEditorData(c.editorData, extractCtx)), @@ -75,6 +76,7 @@ export async function buildTaskPrompt( fileIds: allFileIds, signUrls: false, userId, + workspaceId, }); const fileMetaById = new Map(fileMetadata.map((f) => [f.id, f])); diff --git a/src/server/services/taskRunner/heartbeatTick.ts b/src/server/services/taskRunner/heartbeatTick.ts index f3d95fb967..81b9bbab31 100644 --- a/src/server/services/taskRunner/heartbeatTick.ts +++ b/src/server/services/taskRunner/heartbeatTick.ts @@ -1,8 +1,9 @@ import { TRPCError } from '@trpc/server'; import debug from 'debug'; +import { and, eq } from 'drizzle-orm'; import { BriefModel } from '@/database/models/brief'; -import { TaskModel } from '@/database/models/task'; +import { tasks } from '@/database/schemas'; import { getServerDB } from '@/database/server'; import { setTaskSchedulerExecutionCallback } from '@/server/services/taskScheduler'; @@ -39,8 +40,13 @@ export async function runHeartbeatTick( ): Promise { const db = await getServerDB(); - const taskModel = new TaskModel(db, userId); - const task = await taskModel.findById(taskId); + // System-level dispatch: read the task row directly to learn its + // `workspaceId` before constructing downstream models. + const [task] = await db + .select() + .from(tasks) + .where(and(eq(tasks.id, taskId), eq(tasks.createdByUserId, userId))) + .limit(1); if (!task) { log('skip task=%s reason=not-found', taskId); return { ran: false, reason: 'not-found' }; @@ -58,13 +64,14 @@ export async function runHeartbeatTick( return { ran: false, reason: 'no-interval' }; } - const briefModel = new BriefModel(db, userId); + const wsId = task.workspaceId ?? undefined; + const briefModel = new BriefModel(db, userId, wsId); if (await briefModel.hasUnresolvedUrgentByTask(taskId)) { log('skip task=%s reason=human-waiting', taskId); return { ran: false, reason: 'human-waiting' }; } - const runner = new TaskRunnerService(db, userId); + const runner = new TaskRunnerService(db, userId, wsId); try { await runner.runTask({ taskId }); } catch (e) { diff --git a/src/server/services/taskRunner/index.ts b/src/server/services/taskRunner/index.ts index 918f9eba5d..01b5f33cff 100644 --- a/src/server/services/taskRunner/index.ts +++ b/src/server/services/taskRunner/index.ts @@ -45,14 +45,17 @@ export class TaskRunnerService { private taskTopicModel: TaskTopicModel; private userId: string; - constructor(db: LobeChatDatabase, userId: string) { + private workspaceId?: string; + + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; - this.agentModel = new AgentModel(db, userId); - this.taskModel = new TaskModel(db, userId); - this.taskTopicModel = new TaskTopicModel(db, userId); - this.briefModel = new BriefModel(db, userId); - this.taskLifecycle = new TaskLifecycleService(db, userId); + this.workspaceId = workspaceId; + this.agentModel = new AgentModel(db, userId, workspaceId); + this.taskModel = new TaskModel(db, userId, workspaceId); + this.taskTopicModel = new TaskTopicModel(db, userId, workspaceId); + this.briefModel = new BriefModel(db, userId, workspaceId); + this.taskLifecycle = new TaskLifecycleService(db, userId, workspaceId); } async runTask(params: RunTaskParams): Promise { @@ -118,6 +121,7 @@ export class TaskRunnerService { taskModel: this.taskModel, taskTopicModel: this.taskTopicModel, userId: this.userId, + workspaceId: this.workspaceId, }, extraPrompt, ); @@ -135,7 +139,9 @@ export class TaskRunnerService { const agentRef = task.assigneeAgentId!; const isSlug = !agentRef.startsWith('agt_'); - const aiAgentService = new AiAgentService(this.db, this.userId); + const aiAgentService = new AiAgentService(this.db, this.userId, { + workspaceId: this.workspaceId, + }); const taskId = task.id; const taskIdentifier = task.identifier; const taskLifecycle = this.taskLifecycle; diff --git a/src/server/services/taskRunner/scheduleTick.test.ts b/src/server/services/taskRunner/scheduleTick.test.ts index 445bb33c2e..28e0c72ca1 100644 --- a/src/server/services/taskRunner/scheduleTick.test.ts +++ b/src/server/services/taskRunner/scheduleTick.test.ts @@ -9,8 +9,18 @@ import { TaskTopicModel } from '@/database/models/taskTopic'; import { TaskRunnerService } from './index'; import { runScheduleTick } from './scheduleTick'; +const mockSelectTask = vi.fn(); + vi.mock('@/database/server', () => ({ - getServerDB: vi.fn().mockResolvedValue({}), + getServerDB: vi.fn().mockResolvedValue({ + select: () => ({ + from: () => ({ + where: () => ({ + limit: () => mockSelectTask(), + }), + }), + }), + }), })); vi.mock('@/database/models/task', () => ({ @@ -34,7 +44,6 @@ describe('runScheduleTick', () => { const userId = 'user-1'; const mockTaskModel = { - findById: vi.fn(), updateStatus: vi.fn(), }; const mockTaskTopicModel = { @@ -60,6 +69,7 @@ describe('runScheduleTick', () => { beforeEach(() => { vi.clearAllMocks(); + mockSelectTask.mockResolvedValue([]); mockBriefModel.hasUnresolvedUrgentByTask.mockResolvedValue(false); (TaskModel as any).mockImplementation(() => mockTaskModel); (TaskTopicModel as any).mockImplementation(() => mockTaskTopicModel); @@ -68,7 +78,7 @@ describe('runScheduleTick', () => { }); it('skips not-found tasks', async () => { - mockTaskModel.findById.mockResolvedValue(null); + mockSelectTask.mockResolvedValue([]); const outcome = await runScheduleTick(taskId, userId); @@ -77,7 +87,7 @@ describe('runScheduleTick', () => { }); it('skips when automationMode has been changed away from schedule', async () => { - mockTaskModel.findById.mockResolvedValue(baseTask({ automationMode: 'heartbeat' })); + mockSelectTask.mockResolvedValue([baseTask({ automationMode: 'heartbeat' })]); const outcome = await runScheduleTick(taskId, userId); @@ -86,7 +96,7 @@ describe('runScheduleTick', () => { }); it('skips terminal / paused tasks before checking maxExecutions', async () => { - mockTaskModel.findById.mockResolvedValue(baseTask({ status: 'paused' })); + mockSelectTask.mockResolvedValue([baseTask({ status: 'paused' })]); const outcome = await runScheduleTick(taskId, userId); @@ -95,7 +105,7 @@ describe('runScheduleTick', () => { }); it('runs the task when no maxExecutions is configured', async () => { - mockTaskModel.findById.mockResolvedValue(baseTask({ config: {} })); + mockSelectTask.mockResolvedValue([baseTask({ config: {} })]); mockRunner.runTask.mockResolvedValue(undefined); const outcome = await runScheduleTick(taskId, userId); @@ -106,9 +116,7 @@ describe('runScheduleTick', () => { }); it('runs the task when the run count is still under maxExecutions', async () => { - mockTaskModel.findById.mockResolvedValue( - baseTask({ config: { schedule: { maxExecutions: 10 } } }), - ); + mockSelectTask.mockResolvedValue([baseTask({ config: { schedule: { maxExecutions: 10 } } })]); mockTaskTopicModel.countByTask.mockResolvedValue(7); mockRunner.runTask.mockResolvedValue(undefined); @@ -123,9 +131,7 @@ describe('runScheduleTick', () => { }); it('marks the task completed and skips when the run count has reached maxExecutions', async () => { - mockTaskModel.findById.mockResolvedValue( - baseTask({ config: { schedule: { maxExecutions: 10 } } }), - ); + mockSelectTask.mockResolvedValue([baseTask({ config: { schedule: { maxExecutions: 10 } } })]); mockTaskTopicModel.countByTask.mockResolvedValue(10); const outcome = await runScheduleTick(taskId, userId); @@ -141,9 +147,9 @@ describe('runScheduleTick', () => { // Pre-existing scheduled tasks (created before this PR) won't have a // scheduleStartedAt stamp. They should still tick normally; the cap will // start enforcing once the user pauses + restarts. - mockTaskModel.findById.mockResolvedValue( + mockSelectTask.mockResolvedValue([ baseTask({ config: { schedule: { maxExecutions: 10 } }, context: {} }), - ); + ]); mockRunner.runTask.mockResolvedValue(undefined); const outcome = await runScheduleTick(taskId, userId); @@ -154,7 +160,7 @@ describe('runScheduleTick', () => { }); it('returns in-flight when runTask raises a CONFLICT', async () => { - mockTaskModel.findById.mockResolvedValue(baseTask({ config: {} })); + mockSelectTask.mockResolvedValue([baseTask({ config: {} })]); mockRunner.runTask.mockRejectedValue(new TRPCError({ code: 'CONFLICT', message: 'busy' })); const outcome = await runScheduleTick(taskId, userId); @@ -163,7 +169,7 @@ describe('runScheduleTick', () => { }); it('skips when a human is waiting on an urgent brief', async () => { - mockTaskModel.findById.mockResolvedValue(baseTask({ config: {} })); + mockSelectTask.mockResolvedValue([baseTask({ config: {} })]); mockBriefModel.hasUnresolvedUrgentByTask.mockResolvedValue(true); const outcome = await runScheduleTick(taskId, userId); diff --git a/src/server/services/taskRunner/scheduleTick.ts b/src/server/services/taskRunner/scheduleTick.ts index a49e00aa17..1796fffdac 100644 --- a/src/server/services/taskRunner/scheduleTick.ts +++ b/src/server/services/taskRunner/scheduleTick.ts @@ -1,9 +1,11 @@ import { TRPCError } from '@trpc/server'; import debug from 'debug'; +import { and, eq } from 'drizzle-orm'; import { BriefModel } from '@/database/models/brief'; import { TaskModel } from '@/database/models/task'; import { TaskTopicModel } from '@/database/models/taskTopic'; +import { tasks } from '@/database/schemas'; import { getServerDB } from '@/database/server'; import { TaskRunnerService } from './index'; @@ -40,12 +42,20 @@ export async function runScheduleTick( ): Promise { const db = await getServerDB(); - const taskModel = new TaskModel(db, userId); - const task = await taskModel.findById(taskId); + // System-level dispatch: we don't have the workspace context here. Read the + // task row directly (creator-scoped, workspace-agnostic) to learn + // `task.workspaceId`, then use it to instantiate downstream models so brief + // writes / lifecycle hits land in the right workspace. + const [task] = await db + .select() + .from(tasks) + .where(and(eq(tasks.id, taskId), eq(tasks.createdByUserId, userId))) + .limit(1); if (!task) { log('skip task=%s reason=not-found', taskId); return { ran: false, reason: 'not-found' }; } + const wsId = task.workspaceId ?? undefined; if (task.automationMode !== 'schedule') { log('skip task=%s reason=mode-changed (mode=%s)', taskId, task.automationMode); return { ran: false, reason: 'mode-changed' }; @@ -63,7 +73,7 @@ export async function runScheduleTick( return { ran: false, reason: 'paused' }; } - const briefModel = new BriefModel(db, userId); + const briefModel = new BriefModel(db, userId, wsId); if (await briefModel.hasUnresolvedUrgentByTask(taskId)) { log('skip task=%s reason=human-waiting', taskId); return { ran: false, reason: 'human-waiting' }; @@ -90,7 +100,7 @@ export async function runScheduleTick( const startedAtIso = scheduler.scheduleStartedAt; if (startedAtIso) { const startedAt = new Date(startedAtIso); - const topicModel = new TaskTopicModel(db, userId); + const topicModel = new TaskTopicModel(db, userId, wsId); const runCount = await topicModel.countByTask(taskId, { since: startedAt }); if (runCount >= maxExecutions) { log( @@ -99,13 +109,14 @@ export async function runScheduleTick( runCount, maxExecutions, ); + const taskModel = new TaskModel(db, userId, wsId); await taskModel.updateStatus(taskId, 'completed', { completedAt: new Date() }); return { ran: false, reason: 'max-executions-reached' }; } } } - const runner = new TaskRunnerService(db, userId); + const runner = new TaskRunnerService(db, userId, wsId); try { await runner.runTask({ taskId }); } catch (e) { diff --git a/src/server/services/taskTemplate/index.test.ts b/src/server/services/taskTemplate/index.test.ts index 297a372e4a..73419d4e5c 100644 --- a/src/server/services/taskTemplate/index.test.ts +++ b/src/server/services/taskTemplate/index.test.ts @@ -1,6 +1,10 @@ // @vitest-environment node import type { TaskTemplate } from '@lobechat/const'; -import { TASK_TEMPLATE_RECOMMEND_COUNT, taskTemplates } from '@lobechat/const'; +import { + TASK_TEMPLATE_PERSONAL_ONLY_CATEGORIES, + TASK_TEMPLATE_RECOMMEND_COUNT, + taskTemplates, +} from '@lobechat/const'; import { describe, expect, it } from 'vitest'; import { isTemplateSkillSourceEligible, TaskTemplateService } from './index'; @@ -179,6 +183,70 @@ describe('TaskTemplateService.listDailyRecommend', () => { } }); + it('drops personal-only categories in workspace mode', async () => { + const service = new TaskTemplateService('user-1'); + const personalCategories = new Set(TASK_TEMPLATE_PERSONAL_ONLY_CATEGORIES); + + // Use a personal interest that would otherwise match personal-life templates. + const picked = await service.listDailyRecommend(['personal'], { + now: UTC_DAY_1, + workspaceMode: true, + }); + for (const p of picked) { + expect(personalCategories.has(p.category), `template ${p.id} category`).toBe(false); + } + }); + + it('shuffles broadly across refreshSeeds in workspace mode with empty interests', async () => { + // The original narrow workspace fallback (operations + learning-research) + // resolved to ~4 templates after skill gating, locking "换一批" to a + // permutation of the same 3-of-4. Workspace fallback must draw from the + // full non-personal candidate set so refresh actually rotates. + const service = new TaskTemplateService('user-1'); + const seenIds = new Set(); + for (const seed of ['s1', 's2', 's3', 's4', 's5', 's6', 's7', 's8']) { + const picked = await service.listDailyRecommend([], { + now: UTC_DAY_1, + refreshSeed: seed, + workspaceMode: true, + }); + for (const p of picked) seenIds.add(p.id); + } + // 8 refreshes × 3 picks = 24 slots. Across this many seeds the pool + // should clearly exceed the old 4-template ceiling. + expect(seenIds.size).toBeGreaterThan(10); + }); + + it('keeps personal-only categories in personal mode (default)', async () => { + const service = new TaskTemplateService('user-1'); + const personalCategories = new Set(TASK_TEMPLATE_PERSONAL_ONLY_CATEGORIES); + + // Sample enough seeds so the personal fallback pool surfaces. + const reached = new Set(); + for (const seed of ['p1', 'p2', 'p3', 'p4', 'p5', 'p6']) { + const picked = await service.listDailyRecommend(['personal'], { + now: UTC_DAY_1, + refreshSeed: seed, + }); + for (const p of picked) { + if (personalCategories.has(p.category)) reached.add(p.id); + } + } + expect(reached.size).toBeGreaterThan(0); + }); + + it('produces a different shuffle seed for workspace mode vs personal mode', async () => { + // Seed namespaces are isolated so workspace recommendations don't mirror + // the personal lineup for the same user/day. + const service = new TaskTemplateService('user-1'); + const personal = await service.listDailyRecommend(['coding'], { now: UTC_DAY_1 }); + const workspace = await service.listDailyRecommend(['coding'], { + now: UTC_DAY_1, + workspaceMode: true, + }); + expect(personal.map((t) => t.id).join(',')).not.toBe(workspace.map((t) => t.id).join(',')); + }); + it('changes the first item across refreshSeeds when matched candidates are fewer than the default recommendation count', async () => { // Repro for: `health` interest matches only one template (`diet-log-companion`), // so the legacy "matched-first" logic locked it to position 0 regardless of seed. diff --git a/src/server/services/taskTemplate/index.ts b/src/server/services/taskTemplate/index.ts index 337ceecf8f..34f6978dc3 100644 --- a/src/server/services/taskTemplate/index.ts +++ b/src/server/services/taskTemplate/index.ts @@ -1,6 +1,7 @@ import type { TaskTemplate, TaskTemplateSkillSource } from '@lobechat/const'; import { TASK_TEMPLATE_FALLBACK_CATEGORIES, + TASK_TEMPLATE_PERSONAL_ONLY_CATEGORIES, TASK_TEMPLATE_RECOMMEND_COUNT, taskTemplates, } from '@lobechat/const'; @@ -82,6 +83,12 @@ export class TaskTemplateService { excludeIds?: string[]; now?: Date; refreshSeed?: string; + /** + * When true, drop every template under `TASK_TEMPLATE_PERSONAL_ONLY_CATEGORIES` + * and use the workspace-flavored fallback pool. Used by the cloud router + * whenever the request is bound to a workspace context. + */ + workspaceMode?: boolean; } = {}, ): Promise { const { @@ -90,14 +97,22 @@ export class TaskTemplateService { excludeIds, now = new Date(), refreshSeed, + workspaceMode = false, } = options; const limit = Math.max(1, count); const excluded = new Set(excludeIds ?? []); - const seedBase = `${this.userId}:${getUtcDateStr(now)}`; + const seedBase = workspaceMode + ? `${this.userId}:ws:${getUtcDateStr(now)}` + : `${this.userId}:${getUtcDateStr(now)}`; const seed = hashString(refreshSeed ? `${seedBase}:${refreshSeed}` : seedBase); + const personalOnly = new Set(TASK_TEMPLATE_PERSONAL_ONLY_CATEGORIES); + const candidates = taskTemplates.filter( - (t) => !excluded.has(t.id) && isTemplateSkillSourceEligible(t, enabledSkillSources), + (t) => + !excluded.has(t.id) && + isTemplateSkillSourceEligible(t, enabledSkillSources) && + (!workspaceMode || !personalOnly.has(t.category)), ); const matched = candidates.filter((t) => hasIntersection(t, interestKeys)); const result: TaskTemplate[] = []; @@ -108,10 +123,18 @@ export class TaskTemplateService { // Not enough interest matches: fold the fallback pool in so refreshSeed // can reorder the whole batch — otherwise a single-match interest pins // that template to position 0 forever. + // + // Personal mode keeps the narrow `personal-life + learning-research` + // fallback (it's the existing vibe). Workspace mode uses the full + // non-personal candidate set — the original 2-category workspace + // fallback resolved to ~4 templates after skill gating and made + // "换一批" a no-op. const matchedIds = new Set(matched.map((t) => t.id)); - const fallback = candidates.filter( - (t) => TASK_TEMPLATE_FALLBACK_CATEGORIES.includes(t.category) && !matchedIds.has(t.id), - ); + const fallback = workspaceMode + ? candidates.filter((t) => !matchedIds.has(t.id)) + : candidates.filter( + (t) => TASK_TEMPLATE_FALLBACK_CATEGORIES.includes(t.category) && !matchedIds.has(t.id), + ); const pool = [...matched, ...fallback]; result.push(...seededShuffle(pool, seed).slice(0, limit)); } diff --git a/src/server/services/toolExecution/archiveToolResult.ts b/src/server/services/toolExecution/archiveToolResult.ts index 64470cfc51..a9b055e0fd 100644 --- a/src/server/services/toolExecution/archiveToolResult.ts +++ b/src/server/services/toolExecution/archiveToolResult.ts @@ -28,6 +28,7 @@ interface ArchiveToolResultParams { toolCallId?: string; topicId?: string | null; userId?: string; + workspaceId?: string; } const buildArchivePath = (topicId: string, toolCallId: string) => @@ -45,6 +46,7 @@ export const archiveToolResultIfNeeded = async ({ toolCallId, topicId, userId, + workspaceId, }: ArchiveToolResultParams): Promise => { if (identifier && ARCHIVE_BYPASS_IDENTIFIERS.has(identifier)) { return { archived: false, content }; @@ -65,12 +67,12 @@ export const archiveToolResultIfNeeded = async ({ const archivePath = buildArchivePath(topicId, toolCallId); try { - const vfsService = new AgentDocumentVfsService(serverDB, userId); + const vfsService = new AgentDocumentVfsService(serverDB, userId, workspaceId); await vfsService.mkdir(TOOL_RESULTS_DIR, { agentId, topicId }, { recursive: true }); const stats = await vfsService.write(archivePath, content, { agentId, topicId }); if (stats.documentId) { - const topicDocumentModel = new TopicDocumentModel(serverDB, userId); + const topicDocumentModel = new TopicDocumentModel(serverDB, userId, workspaceId); const associated = await topicDocumentModel.isAssociated(stats.documentId, topicId); if (!associated) { await topicDocumentModel.associate({ diff --git a/src/server/services/toolExecution/builtin.ts b/src/server/services/toolExecution/builtin.ts index 53d1ec8abc..45cc7bc55b 100644 --- a/src/server/services/toolExecution/builtin.ts +++ b/src/server/services/toolExecution/builtin.ts @@ -84,6 +84,7 @@ export class BuiltinToolsExecutor implements IToolExecutor { args, identifier, toolName: apiName, + workspaceId: context.workspaceId, }); } diff --git a/src/server/services/toolExecution/index.ts b/src/server/services/toolExecution/index.ts index cea571185c..e20b2c0684 100644 --- a/src/server/services/toolExecution/index.ts +++ b/src/server/services/toolExecution/index.ts @@ -92,6 +92,7 @@ export class ToolExecutionService { context.userId, identifier, apiName, + context.workspaceId, ); if (permission === ConnectorToolPermission.disabled) { log('Tool %s:%s is disabled by user — blocking execution', identifier, apiName); diff --git a/src/server/services/toolExecution/serverRuntimes/__tests__/agentDocuments.test.ts b/src/server/services/toolExecution/serverRuntimes/__tests__/agentDocuments.test.ts index 6ff2f4f166..e94e671f98 100644 --- a/src/server/services/toolExecution/serverRuntimes/__tests__/agentDocuments.test.ts +++ b/src/server/services/toolExecution/serverRuntimes/__tests__/agentDocuments.test.ts @@ -64,12 +64,20 @@ describe('agentDocumentsRuntime auto-pin to task', () => { vi.mocked(TaskModel).mockImplementation(() => ({ pinDocument }) as any); }); - const buildContext = (taskId?: string) => ({ - serverDB: {} as never, - taskId, - toolManifestMap: {}, - userId: 'user-1', - }); + const buildContext = (taskId?: string) => { + // Mock the workspace lookup chain that `pinToTask` runs against the task + // row. Returning `workspaceId: null` reproduces personal-mode behavior. + const limit = vi.fn().mockResolvedValue([{ workspaceId: null }]); + const where = vi.fn().mockReturnValue({ limit }); + const from = vi.fn().mockReturnValue({ where }); + const select = vi.fn().mockReturnValue({ from }); + return { + serverDB: { select } as never, + taskId, + toolManifestMap: {}, + userId: 'user-1', + }; + }; it('pins newly created document when taskId is in context', async () => { const runtime = agentDocumentsRuntime.factory(buildContext('task-1')); diff --git a/src/server/services/toolExecution/serverRuntimes/__tests__/agentManagement.test.ts b/src/server/services/toolExecution/serverRuntimes/__tests__/agentManagement.test.ts index 7b25872325..7e4d7ae837 100644 --- a/src/server/services/toolExecution/serverRuntimes/__tests__/agentManagement.test.ts +++ b/src/server/services/toolExecution/serverRuntimes/__tests__/agentManagement.test.ts @@ -1,5 +1,8 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'; +import { AgentModel } from '@/database/models/agent'; +import { PluginModel } from '@/database/models/plugin'; + import { agentManagementRuntime } from '../agentManagement'; const { mockCountAgents, mockGetAssistantList, mockQueryAgents } = vi.hoisted(() => ({ @@ -32,6 +35,14 @@ const createRuntime = () => userId: 'user-1', }); +const createWorkspaceRuntime = () => + agentManagementRuntime.factory({ + serverDB: {} as never, + toolManifestMap: {}, + userId: 'user-1', + workspaceId: 'workspace-1', + }); + const makeAgents = (count: number, startIndex = 0) => Array.from({ length: count }, (_, i) => ({ avatar: null, @@ -56,6 +67,13 @@ describe('agentManagementRuntime', () => { ); }); + it('scopes agent and plugin models to workspace context', () => { + createWorkspaceRuntime(); + + expect(AgentModel).toHaveBeenCalledWith(expect.anything(), 'user-1', 'workspace-1'); + expect(PluginModel).toHaveBeenCalledWith(expect.anything(), 'user-1', 'workspace-1'); + }); + describe('searchAgent', () => { it('reports the real total and a pagination hint when more agents exist', async () => { mockQueryAgents.mockResolvedValue(makeAgents(20)); diff --git a/src/server/services/toolExecution/serverRuntimes/__tests__/agentSignalSkillManagement.test.ts b/src/server/services/toolExecution/serverRuntimes/__tests__/agentSignalSkillManagement.test.ts new file mode 100644 index 0000000000..754b23f285 --- /dev/null +++ b/src/server/services/toolExecution/serverRuntimes/__tests__/agentSignalSkillManagement.test.ts @@ -0,0 +1,32 @@ +// @vitest-environment node +import { describe, expect, it, vi } from 'vitest'; + +import { SkillManagementDocumentService } from '@/server/services/skillManagement'; + +import { agentSignalSkillManagementRuntime } from '../agentSignalSkillManagement'; + +vi.mock('@/server/services/skillManagement'); + +describe('agentSignalSkillManagementRuntime', () => { + it('throws if required server context is missing', () => { + expect(() => + agentSignalSkillManagementRuntime.factory({ + serverDB: {} as never, + toolManifestMap: {}, + userId: 'user-1', + }), + ).toThrow('agent-signal-skill-management requires agentId, userId and serverDB'); + }); + + it('threads the workspaceId into the skill document service so writes stay workspace-scoped', () => { + agentSignalSkillManagementRuntime.factory({ + agentId: 'agent-1', + serverDB: {} as never, + toolManifestMap: {}, + userId: 'user-1', + workspaceId: 'ws-1', + }); + + expect(SkillManagementDocumentService).toHaveBeenCalledWith({}, 'user-1', 'ws-1'); + }); +}); diff --git a/src/server/services/toolExecution/serverRuntimes/__tests__/lobeAgent.test.ts b/src/server/services/toolExecution/serverRuntimes/__tests__/lobeAgent.test.ts index 134f6fe0b1..cd395b8a43 100644 --- a/src/server/services/toolExecution/serverRuntimes/__tests__/lobeAgent.test.ts +++ b/src/server/services/toolExecution/serverRuntimes/__tests__/lobeAgent.test.ts @@ -205,6 +205,22 @@ describe('lobeAgentRuntime', () => { ); }); + it('should pass workspaceId when initializing visual model runtime', async () => { + const runtime = lobeAgentRuntime.factory({ ...baseContext, workspaceId: 'workspace-1' }); + + await runtime.analyzeVisualMedia({ + question: 'what is this?', + urls: ['https://example.com/generated.png'], + }); + + expect(mockInitModelRuntimeFromDB).toHaveBeenCalledWith( + baseContext.serverDB, + 'user-1', + 'test-provider', + 'workspace-1', + ); + }); + it('should accumulate text content_part chunks from the visual model', async () => { mockChat.mockImplementationOnce(async (_payload, options) => { options?.callback?.onContentPart?.({ diff --git a/src/server/services/toolExecution/serverRuntimes/__tests__/lobeAgentPlan.test.ts b/src/server/services/toolExecution/serverRuntimes/__tests__/lobeAgentPlan.test.ts new file mode 100644 index 0000000000..68459a52b0 --- /dev/null +++ b/src/server/services/toolExecution/serverRuntimes/__tests__/lobeAgentPlan.test.ts @@ -0,0 +1,29 @@ +import { describe, expect, it, vi } from 'vitest'; + +import { DocumentModel } from '@/database/models/document'; +import { TopicDocumentModel } from '@/database/models/topicDocument'; + +import { createServerPlanRuntimeService } from '../lobeAgentPlan'; + +vi.mock('@/database/models/document', () => ({ + DocumentModel: vi.fn(() => ({ + findById: vi.fn(), + })), +})); + +vi.mock('@/database/models/topicDocument', () => ({ + TopicDocumentModel: vi.fn(() => ({ + findByTopicId: vi.fn(), + })), +})); + +describe('createServerPlanRuntimeService', () => { + it('scopes document models to workspace context', () => { + const serverDB = {} as never; + + createServerPlanRuntimeService(serverDB, 'user-1', 'workspace-1'); + + expect(DocumentModel).toHaveBeenCalledWith(serverDB, 'user-1', 'workspace-1'); + expect(TopicDocumentModel).toHaveBeenCalledWith(serverDB, 'user-1', 'workspace-1'); + }); +}); diff --git a/src/server/services/toolExecution/serverRuntimes/__tests__/message.test.ts b/src/server/services/toolExecution/serverRuntimes/__tests__/message.test.ts index 6f95b48bf7..617cf85f00 100644 --- a/src/server/services/toolExecution/serverRuntimes/__tests__/message.test.ts +++ b/src/server/services/toolExecution/serverRuntimes/__tests__/message.test.ts @@ -880,7 +880,7 @@ describe('messageRuntime', () => { }); expect(result.success).toBe(true); - expect(mockLinkSetActiveAgent).toHaveBeenCalledWith('slack', 'agent_x', 'T1'); + expect(mockLinkSetActiveAgent).toHaveBeenCalledWith('slack', 'agent_x', null, 'T1'); }); it('rejects when the agent does not belong to the caller', async () => { @@ -906,7 +906,7 @@ describe('messageRuntime', () => { }); expect(result.success).toBe(true); - expect(mockLinkSetActiveAgent).toHaveBeenCalledWith('telegram', null, undefined); + expect(mockLinkSetActiveAgent).toHaveBeenCalledWith('telegram', null, null, undefined); }); it('returns NOT_FOUND when the link does not exist', async () => { diff --git a/src/server/services/toolExecution/serverRuntimes/__tests__/skillManagement.test.ts b/src/server/services/toolExecution/serverRuntimes/__tests__/skillManagement.test.ts index 2c3809b734..2a4672a10d 100644 --- a/src/server/services/toolExecution/serverRuntimes/__tests__/skillManagement.test.ts +++ b/src/server/services/toolExecution/serverRuntimes/__tests__/skillManagement.test.ts @@ -33,14 +33,15 @@ describe('skillManagementRuntime', () => { * @example * The registration factory creates a package-level runtime backed by SkillManagementDocumentService. */ - it('constructs a SkillMaintainerExecutionRuntime', () => { + it('constructs a SkillMaintainerExecutionRuntime backed by a workspace-scoped document service', () => { const runtime = skillManagementRuntime.factory({ serverDB: {} as never, toolManifestMap: {}, userId: 'user-1', + workspaceId: 'ws-1', }); expect(runtime).toBeInstanceOf(SkillMaintainerExecutionRuntime); - expect(SkillManagementDocumentService).toHaveBeenCalledWith({}, 'user-1'); + expect(SkillManagementDocumentService).toHaveBeenCalledWith({}, 'user-1', 'ws-1'); }); }); diff --git a/src/server/services/toolExecution/serverRuntimes/__tests__/topicReference.test.ts b/src/server/services/toolExecution/serverRuntimes/__tests__/topicReference.test.ts index dc464140a0..4e3bb2b7d6 100644 --- a/src/server/services/toolExecution/serverRuntimes/__tests__/topicReference.test.ts +++ b/src/server/services/toolExecution/serverRuntimes/__tests__/topicReference.test.ts @@ -20,6 +20,8 @@ vi.mock('@/database/models/message', () => ({ })); // Import after mock setup +const { MessageModel } = await import('@/database/models/message'); +const { TopicModel } = await import('@/database/models/topic'); const { topicReferenceRuntime } = await import('../topicReference'); describe('topicReferenceRuntime', () => { @@ -62,6 +64,24 @@ describe('topicReferenceRuntime', () => { expect(runtime).toBeDefined(); expect(typeof runtime.getTopicContext).toBe('function'); }); + + it('should scope models to workspace context', async () => { + const serverDB = {} as any; + const runtime = topicReferenceRuntime.factory({ + serverDB, + toolManifestMap: {}, + userId: 'user-1', + workspaceId: 'workspace-1', + }); + + mockTopicModelFindById.mockResolvedValue({ id: 'topic-1', title: 'Topic' }); + mockMessageModelQuery.mockResolvedValue([]); + + await runtime.getTopicContext({ topicId: 'topic-1' }); + + expect(TopicModel).toHaveBeenCalledWith(serverDB, 'user-1', 'workspace-1'); + expect(MessageModel).toHaveBeenCalledWith(serverDB, 'user-1', 'workspace-1'); + }); }); describe('getTopicContext', () => { diff --git a/src/server/services/toolExecution/serverRuntimes/activator.ts b/src/server/services/toolExecution/serverRuntimes/activator.ts index 686d407159..f4a022a06a 100644 --- a/src/server/services/toolExecution/serverRuntimes/activator.ts +++ b/src/server/services/toolExecution/serverRuntimes/activator.ts @@ -67,7 +67,7 @@ export const activatorRuntime: ServerRuntimeRegistration = { // Create SkillsExecutionRuntime for activateSkill delegation let skillsRuntime: SkillsExecutionRuntime | undefined; if (context.serverDB && context.userId) { - const skillModel = new AgentSkillModel(context.serverDB, context.userId); + const skillModel = new AgentSkillModel(context.serverDB, context.userId, context.workspaceId); skillsRuntime = new SkillsExecutionRuntime({ builtinSkills: filterBuiltinSkills(builtinSkills), service: { diff --git a/src/server/services/toolExecution/serverRuntimes/agentBuilder.ts b/src/server/services/toolExecution/serverRuntimes/agentBuilder.ts index 6ce6ff7596..f4c011364a 100644 --- a/src/server/services/toolExecution/serverRuntimes/agentBuilder.ts +++ b/src/server/services/toolExecution/serverRuntimes/agentBuilder.ts @@ -31,8 +31,8 @@ export const agentBuilderRuntime: ServerRuntimeRegistration = { throw new Error('userId and serverDB are required for Agent Builder execution'); } - const agentModel = new AgentModel(context.serverDB, context.userId); - const pluginModel = new PluginModel(context.serverDB, context.userId); + const agentModel = new AgentModel(context.serverDB, context.userId, context.workspaceId); + const pluginModel = new PluginModel(context.serverDB, context.userId, context.workspaceId); const aiInfraRepos = new AiInfraRepos(context.serverDB, context.userId, {}); const discoverService = new DiscoverService(); diff --git a/src/server/services/toolExecution/serverRuntimes/agentDocuments.ts b/src/server/services/toolExecution/serverRuntimes/agentDocuments.ts index c3960b4e19..084142a1f7 100644 --- a/src/server/services/toolExecution/serverRuntimes/agentDocuments.ts +++ b/src/server/services/toolExecution/serverRuntimes/agentDocuments.ts @@ -1,8 +1,10 @@ import type { DocumentLoadRule } from '@lobechat/agent-templates'; import { AgentDocumentsIdentifier } from '@lobechat/builtin-tool-agent-documents'; import { AgentDocumentsExecutionRuntime } from '@lobechat/builtin-tool-agent-documents/executionRuntime'; +import { eq } from 'drizzle-orm'; import { TaskModel } from '@/database/models/task'; +import { tasks } from '@/database/schemas'; import { AgentDocumentsService } from '@/server/services/agentDocuments'; import { emitAgentDocumentToolOutcomeSafely } from '@/server/services/agentDocuments/toolOutcome'; @@ -14,9 +16,9 @@ export const agentDocumentsRuntime: ServerRuntimeRegistration = { throw new Error('userId and serverDB are required for Agent Documents execution'); } - const service = new AgentDocumentsService(context.serverDB, context.userId); + const db = context.serverDB; const userId = context.userId; - const taskModel = new TaskModel(context.serverDB, context.userId); + const service = new AgentDocumentsService(db, userId, context.workspaceId); const { taskId } = context; const emitDocumentOutcome = async (input: { agentId?: string; @@ -90,6 +92,18 @@ export const agentDocumentsRuntime: ServerRuntimeRegistration = { const pinToTask = async (doc: T): Promise => { if (taskId && doc?.documentId) { + // Prefer the workspaceId already threaded through the pipeline; fall + // back to the owning task row for legacy callers. + let wsId = context.workspaceId; + if (!wsId) { + const [row] = await db + .select({ workspaceId: tasks.workspaceId }) + .from(tasks) + .where(eq(tasks.id, taskId)) + .limit(1); + wsId = row?.workspaceId ?? undefined; + } + const taskModel = new TaskModel(db, userId, wsId); await taskModel.pinDocument(taskId, doc.documentId, 'agent'); } return doc; diff --git a/src/server/services/toolExecution/serverRuntimes/agentManagement.ts b/src/server/services/toolExecution/serverRuntimes/agentManagement.ts index f183985247..5b16924d53 100644 --- a/src/server/services/toolExecution/serverRuntimes/agentManagement.ts +++ b/src/server/services/toolExecution/serverRuntimes/agentManagement.ts @@ -32,8 +32,8 @@ export const agentManagementRuntime: ServerRuntimeRegistration = { throw new Error('userId and serverDB are required for Agent Management execution'); } - const agentModel = new AgentModel(context.serverDB, context.userId); - const pluginModel = new PluginModel(context.serverDB, context.userId); + const agentModel = new AgentModel(context.serverDB, context.userId, context.workspaceId); + const pluginModel = new PluginModel(context.serverDB, context.userId, context.workspaceId); const discoverService = new DiscoverService(); return { diff --git a/src/server/services/toolExecution/serverRuntimes/agentSignalFeedbackIntent.ts b/src/server/services/toolExecution/serverRuntimes/agentSignalFeedbackIntent.ts index 96e1916e83..c69a538f10 100644 --- a/src/server/services/toolExecution/serverRuntimes/agentSignalFeedbackIntent.ts +++ b/src/server/services/toolExecution/serverRuntimes/agentSignalFeedbackIntent.ts @@ -22,7 +22,7 @@ import type { ServerRuntimeRegistration } from './types'; */ export const agentSignalFeedbackIntentRuntime: ServerRuntimeRegistration = { factory: (context) => { - const { agentId, serverDB, userId } = context; + const { agentId, serverDB, userId, workspaceId } = context; if (!agentId || !userId || !serverDB) { throw new Error('agent-signal-feedback-intent requires agentId, userId and serverDB'); } @@ -32,8 +32,9 @@ export const agentSignalFeedbackIntentRuntime: ServerRuntimeRegistration = { db: serverDB, memoryReason: (count) => `Agent Signal self-feedback intent memory candidate from ${count} evidence refs.`, - skillDocumentService: new SkillManagementDocumentService(serverDB, userId), + skillDocumentService: new SkillManagementDocumentService(serverDB, userId, workspaceId), userId, + workspaceId, }); return new AgentSignalToolExecutionRuntime({ diff --git a/src/server/services/toolExecution/serverRuntimes/agentSignalReflection.ts b/src/server/services/toolExecution/serverRuntimes/agentSignalReflection.ts index 562018e87e..3247231e25 100644 --- a/src/server/services/toolExecution/serverRuntimes/agentSignalReflection.ts +++ b/src/server/services/toolExecution/serverRuntimes/agentSignalReflection.ts @@ -18,7 +18,7 @@ import type { ServerRuntimeRegistration } from './types'; */ export const agentSignalReflectionRuntime: ServerRuntimeRegistration = { factory: (context) => { - const { agentId, serverDB, userId } = context; + const { agentId, serverDB, userId, workspaceId } = context; if (!agentId || !userId || !serverDB) { throw new Error('agent-signal-reflection requires agentId, userId and serverDB'); } @@ -28,8 +28,9 @@ export const agentSignalReflectionRuntime: ServerRuntimeRegistration = { db: serverDB, memoryReason: (count) => `Agent Signal self-reflection memory candidate from ${count} evidence refs.`, - skillDocumentService: new SkillManagementDocumentService(serverDB, userId), + skillDocumentService: new SkillManagementDocumentService(serverDB, userId, workspaceId), userId, + workspaceId, }); return new AgentSignalToolExecutionRuntime({ diff --git a/src/server/services/toolExecution/serverRuntimes/agentSignalReview.ts b/src/server/services/toolExecution/serverRuntimes/agentSignalReview.ts index 2860a91f2e..a4d4af05d5 100644 --- a/src/server/services/toolExecution/serverRuntimes/agentSignalReview.ts +++ b/src/server/services/toolExecution/serverRuntimes/agentSignalReview.ts @@ -35,12 +35,14 @@ const resolveBriefTextTranslator = async (db: LobeChatDatabase, userId: string) */ export const agentSignalReviewRuntime: ServerRuntimeRegistration = { factory: async (context) => { - const { agentId, operationId, serverDB, userId } = context; + const { agentId, operationId, serverDB, userId, workspaceId } = context; if (!agentId || !userId || !operationId || !serverDB) { throw new Error('agent-signal-review requires agentId, userId, operationId and serverDB'); } - const operation = await new AgentOperationModel(serverDB, userId).findById(operationId); + const operation = await new AgentOperationModel(serverDB, userId, workspaceId).findById( + operationId, + ); const marker = readAgentSignalMarker(operation?.metadata); const reviewWindowEnd = marker?.reviewWindowEnd ?? new Date(0).toISOString(); @@ -50,16 +52,17 @@ export const agentSignalReviewRuntime: ServerRuntimeRegistration = { const service = createReviewRuntimePrimitives({ agentId, - briefModel: new BriefModel(serverDB, userId), + briefModel: new BriefModel(serverDB, userId, workspaceId), briefTextTranslator: await resolveBriefTextTranslator(serverDB, userId), db: serverDB, localDate, - proposalBriefWriter: createServerSelfReviewBriefWriter(serverDB, userId), + proposalBriefWriter: createServerSelfReviewBriefWriter(serverDB, userId, workspaceId), reviewWindowEnd, reviewWindowStart, - skillDocumentService: new SkillManagementDocumentService(serverDB, userId), + skillDocumentService: new SkillManagementDocumentService(serverDB, userId, workspaceId), sourceId, userId, + workspaceId, }); return new AgentSignalToolExecutionRuntime({ diff --git a/src/server/services/toolExecution/serverRuntimes/agentSignalSkillManagement.ts b/src/server/services/toolExecution/serverRuntimes/agentSignalSkillManagement.ts index 390f3dc64a..c9be07e7b6 100644 --- a/src/server/services/toolExecution/serverRuntimes/agentSignalSkillManagement.ts +++ b/src/server/services/toolExecution/serverRuntimes/agentSignalSkillManagement.ts @@ -18,7 +18,7 @@ import type { ServerRuntimeRegistration } from './types'; */ export const agentSignalSkillManagementRuntime: ServerRuntimeRegistration = { factory: (context) => { - const { agentId, serverDB, userId } = context; + const { agentId, serverDB, userId, workspaceId } = context; if (!agentId || !userId || !serverDB) { throw new Error('agent-signal-skill-management requires agentId, userId and serverDB'); } @@ -28,7 +28,7 @@ export const agentSignalSkillManagementRuntime: ServerRuntimeRegistration = { db: serverDB, memoryReason: (count) => `Agent Signal skill-management memory candidate from ${count} evidence refs.`, - skillDocumentService: new SkillManagementDocumentService(serverDB, userId), + skillDocumentService: new SkillManagementDocumentService(serverDB, userId, workspaceId), userId, }); diff --git a/src/server/services/toolExecution/serverRuntimes/brief.ts b/src/server/services/toolExecution/serverRuntimes/brief.ts index 3172015d80..ddd0f482b6 100644 --- a/src/server/services/toolExecution/serverRuntimes/brief.ts +++ b/src/server/services/toolExecution/serverRuntimes/brief.ts @@ -1,75 +1,31 @@ import { BriefIdentifier } from '@lobechat/builtin-tool-brief'; +import type { LobeChatDatabase } from '@lobechat/database'; import { formatBriefCreated, formatCheckpointCreated } from '@lobechat/prompts'; import { DEFAULT_BRIEF_ACTIONS } from '@lobechat/types'; +import { eq } from 'drizzle-orm'; import { BriefModel } from '@/database/models/brief'; import { TaskModel } from '@/database/models/task'; +import { tasks } from '@/database/schemas'; import { type ServerRuntimeRegistration } from './types'; -const createBriefRuntime = ({ - agentId, - briefModel, - taskId, - taskModel, -}: { - agentId?: string; - briefModel: BriefModel; - taskId?: string; - taskModel: TaskModel; -}) => ({ - createBrief: async (args: { - actions?: Array<{ key: string; label: string; type: string }>; - priority?: string; - summary: string; - title: string; - type: string; - }) => { - // 'result' briefs are terminal — the UI hardcodes a single approve action - // and routes it through BriefService.resolve to complete the task. Custom - // actions on result briefs would be ignored, so reject them at the source. - const actions = - args.type === 'result' ? null : args.actions || DEFAULT_BRIEF_ACTIONS[args.type] || []; - - const brief = await briefModel.create({ - actions, - agentId, - priority: args.priority || 'info', - summary: args.summary, - taskId, - title: args.title, - type: args.type, - }); - - return { - content: formatBriefCreated({ - id: brief.id, - priority: args.priority || 'info', - summary: args.summary, - title: args.title, - type: args.type, - }), - success: true, - }; - }, - - requestCheckpoint: async (args: { reason: string }) => { - if (taskId) { - await taskModel.updateStatus(taskId, 'paused'); - } - - await briefModel.create({ - agentId, - priority: 'normal', - summary: args.reason, - taskId, - title: 'Checkpoint requested', - type: 'decision', - }); - - return { content: formatCheckpointCreated(args.reason), success: true }; - }, -}); +// Row-level fallback: the agent-runtime hasn't threaded `workspaceId` into +// `ToolExecutionContext` yet, so we resolve it from the task row when the +// runtime fires inside a task. Falls back to undefined (personal mode) when +// there is no task association. +const resolveWorkspaceId = async ( + db: LobeChatDatabase, + taskId: string | undefined, +): Promise => { + if (!taskId) return undefined; + const [row] = await db + .select({ workspaceId: tasks.workspaceId }) + .from(tasks) + .where(eq(tasks.id, taskId)) + .limit(1); + return row?.workspaceId ?? undefined; +}; export const briefRuntime: ServerRuntimeRegistration = { factory: (context) => { @@ -77,15 +33,73 @@ export const briefRuntime: ServerRuntimeRegistration = { throw new Error('userId and serverDB are required for Brief tool execution'); } - const briefModel = new BriefModel(context.serverDB, context.userId); - const taskModel = new TaskModel(context.serverDB, context.userId); + const db = context.serverDB; + const userId = context.userId; + const { agentId, taskId } = context; + // Prefer the workspaceId threaded through the pipeline. Fall back to the + // owning task row when an older caller still doesn't populate it. + const resolveWs = async () => context.workspaceId ?? (await resolveWorkspaceId(db, taskId)); - return createBriefRuntime({ - agentId: context.agentId, - briefModel, - taskId: context.taskId, - taskModel, - }); + return { + createBrief: async (args: { + actions?: Array<{ key: string; label: string; type: string }>; + priority?: string; + summary: string; + title: string; + type: string; + }) => { + // 'result' briefs are terminal — the UI hardcodes a single approve action + // and routes it through BriefService.resolve to complete the task. Custom + // actions on result briefs would be ignored, so reject them at the source. + const actions = + args.type === 'result' ? null : args.actions || DEFAULT_BRIEF_ACTIONS[args.type] || []; + + const workspaceId = await resolveWs(); + const briefModel = new BriefModel(db, userId, workspaceId); + + const brief = await briefModel.create({ + actions, + agentId, + priority: args.priority || 'info', + summary: args.summary, + taskId, + title: args.title, + type: args.type, + }); + + return { + content: formatBriefCreated({ + id: brief.id, + priority: args.priority || 'info', + summary: args.summary, + title: args.title, + type: args.type, + }), + success: true, + }; + }, + + requestCheckpoint: async (args: { reason: string }) => { + const workspaceId = await resolveWs(); + const briefModel = new BriefModel(db, userId, workspaceId); + const taskModel = new TaskModel(db, userId, workspaceId); + + if (taskId) { + await taskModel.updateStatus(taskId, 'paused'); + } + + await briefModel.create({ + agentId, + priority: 'normal', + summary: args.reason, + taskId, + title: 'Checkpoint requested', + type: 'decision', + }); + + return { content: formatCheckpointCreated(args.reason), success: true }; + }, + }; }, identifier: BriefIdentifier, }; diff --git a/src/server/services/toolExecution/serverRuntimes/cloudSandbox.ts b/src/server/services/toolExecution/serverRuntimes/cloudSandbox.ts index 024b044493..c2eef901f4 100644 --- a/src/server/services/toolExecution/serverRuntimes/cloudSandbox.ts +++ b/src/server/services/toolExecution/serverRuntimes/cloudSandbox.ts @@ -24,7 +24,7 @@ export const cloudSandboxRuntime: ServerRuntimeRegistration = { } const marketService = new MarketService({ userInfo: { userId: context.userId } }); - const fileService = new FileService(context.serverDB, context.userId); + const fileService = new FileService(context.serverDB, context.userId, context.workspaceId); const sandboxService = createSandboxService({ fileService, marketService, diff --git a/src/server/services/toolExecution/serverRuntimes/knowledgeBase.ts b/src/server/services/toolExecution/serverRuntimes/knowledgeBase.ts index 5739aa2f98..606f59c8bc 100644 --- a/src/server/services/toolExecution/serverRuntimes/knowledgeBase.ts +++ b/src/server/services/toolExecution/serverRuntimes/knowledgeBase.ts @@ -13,18 +13,18 @@ import { type ServerRuntimeRegistration } from './types'; export const knowledgeBaseRuntime: ServerRuntimeRegistration = { factory: (context) => { - const { userId, serverDB, agentId } = context; + const { userId, serverDB, agentId, workspaceId } = context; if (!userId || !serverDB) { throw new Error('userId and serverDB are required for Knowledge Base execution'); } - const fileModel = new FileModel(serverDB, userId); - const knowledgeBaseModel = new KnowledgeBaseModel(serverDB, userId); - const knowledgeRepo = new KnowledgeRepo(serverDB, userId); - const documentService = new DocumentService(serverDB, userId); - const fileService = new FileService(serverDB, userId); - const searchService = new KnowledgeBaseSearchService(serverDB, userId); - const agentModel = agentId ? new AgentModel(serverDB, userId) : null; + const fileModel = new FileModel(serverDB, userId, workspaceId); + const knowledgeBaseModel = new KnowledgeBaseModel(serverDB, userId, workspaceId); + const knowledgeRepo = new KnowledgeRepo(serverDB, userId, workspaceId); + const documentService = new DocumentService(serverDB, userId, workspaceId); + const fileService = new FileService(serverDB, userId, workspaceId); + const searchService = new KnowledgeBaseSearchService(serverDB, userId, workspaceId); + const agentModel = agentId ? new AgentModel(serverDB, userId, workspaceId) : null; const resolveAgentKnowledgeBaseIds = async (override?: string[]): Promise => { if (override && override.length > 0) return override; diff --git a/src/server/services/toolExecution/serverRuntimes/lobeAgent.ts b/src/server/services/toolExecution/serverRuntimes/lobeAgent.ts index e14e8a6cd8..4b019e1d85 100644 --- a/src/server/services/toolExecution/serverRuntimes/lobeAgent.ts +++ b/src/server/services/toolExecution/serverRuntimes/lobeAgent.ts @@ -46,6 +46,7 @@ interface LobeAgentRuntimeContext { threadId?: string | null; topicId?: string; userId: string; + workspaceId?: string; } const buildError = (content: string, code: string): BuiltinServerRuntimeOutput => ({ @@ -82,6 +83,7 @@ class LobeAgentExecutionRuntime { private threadId?: string | null; private topicId?: string; private planRuntime: PlanExecutionRuntime; + private workspaceId?: string; constructor(context: LobeAgentRuntimeContext) { this.agentId = context.agentId; @@ -92,8 +94,9 @@ class LobeAgentExecutionRuntime { this.threadId = context.threadId; this.topicId = context.topicId; this.userId = context.userId; + this.workspaceId = context.workspaceId; this.planRuntime = new PlanExecutionRuntime( - createServerPlanRuntimeService(context.serverDB, context.userId), + createServerPlanRuntimeService(context.serverDB, context.userId, context.workspaceId), ); } @@ -238,8 +241,8 @@ class LobeAgentExecutionRuntime { let selectedRefItems: VisualFileItem[] = []; if (requestedRefs.length > 0) { - const fileService = new FileService(this.db, this.userId); - const messageModel = new MessageModel(this.db, this.userId); + const fileService = new FileService(this.db, this.userId, this.workspaceId); + const messageModel = new MessageModel(this.db, this.userId, this.workspaceId); const postProcessUrl = ( path: string | null, file: { fileType: string; id?: string | null }, @@ -317,7 +320,7 @@ class LobeAgentExecutionRuntime { let content = ''; let usage: unknown; - const runtime = await initModelRuntimeFromDB(this.db, this.userId, provider); + const runtime = await initModelRuntimeFromDB(this.db, this.userId, provider, this.workspaceId); const payload = { messages: [ { @@ -383,6 +386,7 @@ export const lobeAgentRuntime: ServerRuntimeRegistration = { threadId: context.threadId, topicId: context.topicId, userId: context.userId, + workspaceId: context.workspaceId, }); }, identifier: LobeAgentIdentifier, diff --git a/src/server/services/toolExecution/serverRuntimes/lobeAgentPlan.ts b/src/server/services/toolExecution/serverRuntimes/lobeAgentPlan.ts index c87e93915d..2e604920a2 100644 --- a/src/server/services/toolExecution/serverRuntimes/lobeAgentPlan.ts +++ b/src/server/services/toolExecution/serverRuntimes/lobeAgentPlan.ts @@ -15,9 +15,10 @@ import { TopicDocumentModel } from '@/database/models/topicDocument'; export const createServerPlanRuntimeService = ( serverDB: LobeChatDatabase, userId: string, + workspaceId?: string, ): PlanRuntimeService => { - const documentModel = new DocumentModel(serverDB, userId); - const topicDocumentModel = new TopicDocumentModel(serverDB, userId); + const documentModel = new DocumentModel(serverDB, userId, workspaceId); + const topicDocumentModel = new TopicDocumentModel(serverDB, userId, workspaceId); const toPlanDocument = (doc: { content: string | null; diff --git a/src/server/services/toolExecution/serverRuntimes/memory.ts b/src/server/services/toolExecution/serverRuntimes/memory.ts index fd90b58341..551f24f1a4 100644 --- a/src/server/services/toolExecution/serverRuntimes/memory.ts +++ b/src/server/services/toolExecution/serverRuntimes/memory.ts @@ -87,7 +87,11 @@ const applySearchLimitsByEffort = ( }; }; -const getEmbeddingRuntime = async (serverDB: LobeChatDatabase, userId: string) => { +const getEmbeddingRuntime = async ( + serverDB: LobeChatDatabase, + userId: string, + workspaceId?: string, +) => { const { provider, model: embeddingModel } = getServerDefaultFilesConfig().embeddingModel || DEFAULT_USER_MEMORY_EMBEDDING_MODEL_ITEM; @@ -95,6 +99,7 @@ const getEmbeddingRuntime = async (serverDB: LobeChatDatabase, userId: string) = serverDB, userId, ENABLE_BUSINESS_FEATURES ? BRANDING_PROVIDER : provider, + workspaceId, ); return { agentRuntime, embeddingModel }; @@ -133,6 +138,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { private memoryEffort: MemoryEffort; private memoryEmbeddingRuntime?: ToolExecutionMemoryEmbeddingRuntime; private userId: string; + private workspaceId?: string; constructor(options: { agentId?: string; @@ -147,6 +153,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { toolCallId?: string; topicId?: string; userId: string; + workspaceId?: string; }) { this.agentId = options.agentId; this.emitOutcome = options.emitOutcome; @@ -160,6 +167,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { this.memoryEffort = options.memoryEffort; this.memoryEmbeddingRuntime = options.memoryEmbeddingRuntime; this.userId = options.userId; + this.workspaceId = options.workspaceId; } private emitUserMemoryOutcome = async (input: { @@ -212,7 +220,12 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { this.memoryEmbeddingRuntime.payload, { userId: this.userId }, ) - : await initModelRuntimeFromDB(this.serverDB, this.userId, defaultEmbeddingConfig.provider); + : await initModelRuntimeFromDB( + this.serverDB, + this.userId, + defaultEmbeddingConfig.provider, + this.workspaceId, + ); const normalizedQueries = [ ...new Set((normalizedParams.queries ?? []).map((query) => query.trim()).filter(Boolean)), ]; @@ -263,6 +276,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { const { agentRuntime, embeddingModel } = await getEmbeddingRuntime( this.serverDB, this.userId, + this.workspaceId, ); const embed = createEmbedder(agentRuntime, embeddingModel, this.userId); @@ -334,6 +348,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { const { agentRuntime, embeddingModel } = await getEmbeddingRuntime( this.serverDB, this.userId, + this.workspaceId, ); const embed = createEmbedder(agentRuntime, embeddingModel, this.userId); @@ -412,6 +427,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { const { agentRuntime, embeddingModel } = await getEmbeddingRuntime( this.serverDB, this.userId, + this.workspaceId, ); const embed = createEmbedder(agentRuntime, embeddingModel, this.userId); @@ -484,6 +500,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { const { agentRuntime, embeddingModel } = await getEmbeddingRuntime( this.serverDB, this.userId, + this.workspaceId, ); const embed = createEmbedder(agentRuntime, embeddingModel, this.userId); @@ -568,6 +585,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { const { agentRuntime, embeddingModel } = await getEmbeddingRuntime( this.serverDB, this.userId, + this.workspaceId, ); const embed = createEmbedder(agentRuntime, embeddingModel, this.userId); @@ -644,6 +662,7 @@ class MemoryServerRuntimeService implements MemoryRuntimeService { const { agentRuntime, embeddingModel } = await getEmbeddingRuntime( this.serverDB, this.userId, + this.workspaceId, ); const embed = createEmbedder(agentRuntime, embeddingModel, this.userId); @@ -873,6 +892,7 @@ export const memoryRuntime: ServerRuntimeRegistration = { toolCallId: context.toolCallId, topicId: context.topicId, userId: context.userId, + workspaceId: context.workspaceId, }); return new MemoryExecutionRuntime({ diff --git a/src/server/services/toolExecution/serverRuntimes/message/index.ts b/src/server/services/toolExecution/serverRuntimes/message/index.ts index 6fbfb9a910..4aca6d9e2a 100644 --- a/src/server/services/toolExecution/serverRuntimes/message/index.ts +++ b/src/server/services/toolExecution/serverRuntimes/message/index.ts @@ -346,10 +346,7 @@ export const messageRuntime: ServerRuntimeRegistration = { ((row.metadata as Record | null)?.tenantName as string) ?? '', })); - const telegramView = await maybeSynthesizeTelegramInstall( - context.serverDB, - context.userId, - ); + const telegramView = await maybeSynthesizeTelegramInstall(context.serverDB, context.userId); if (telegramView) installations.push(telegramView); return installations; @@ -506,9 +503,13 @@ export const messageRuntime: ServerRuntimeRegistration = { } } const linkModel = new MessengerAccountLinkModel(context.serverDB, context.userId); + // This in-chat tool only resolves personal-owned agents (validated + // above via `agents.userId === userId`), so the active scope is always + // personal (`workspaceId: null`). const updated = await linkModel.setActiveAgent( params.platform, params.agentId, + null, params.tenantId, ); if (!updated) { diff --git a/src/server/services/toolExecution/serverRuntimes/notebook.ts b/src/server/services/toolExecution/serverRuntimes/notebook.ts index bd542d7d56..27ed33ed00 100644 --- a/src/server/services/toolExecution/serverRuntimes/notebook.ts +++ b/src/server/services/toolExecution/serverRuntimes/notebook.ts @@ -18,6 +18,7 @@ export const notebookRuntime: ServerRuntimeRegistration = { const notebookService = new NotebookRuntimeService({ serverDB: context.serverDB, userId: context.userId, + workspaceId: context.workspaceId, }); return new NotebookExecutionRuntime(notebookService); diff --git a/src/server/services/toolExecution/serverRuntimes/skillManagement.ts b/src/server/services/toolExecution/serverRuntimes/skillManagement.ts index d8e3a87dc4..da4cec4b19 100644 --- a/src/server/services/toolExecution/serverRuntimes/skillManagement.ts +++ b/src/server/services/toolExecution/serverRuntimes/skillManagement.ts @@ -33,7 +33,11 @@ export const skillManagementRuntime: ServerRuntimeRegistration = { throw new Error('userId and serverDB are required for Skill Management execution'); } - const service = new SkillManagementDocumentService(context.serverDB, context.userId); + const service = new SkillManagementDocumentService( + context.serverDB, + context.userId, + context.workspaceId, + ); const runtimeService: SkillMaintainerRuntimeService = { createSkill: (params: CreateSkillArgs & { agentId: string }) => service.createSkill(params), diff --git a/src/server/services/toolExecution/serverRuntimes/skills.ts b/src/server/services/toolExecution/serverRuntimes/skills.ts index abf61776e4..ef732eca81 100644 --- a/src/server/services/toolExecution/serverRuntimes/skills.ts +++ b/src/server/services/toolExecution/serverRuntimes/skills.ts @@ -276,14 +276,18 @@ export const skillsRuntime: ServerRuntimeRegistration = { log('Failed to fetch market accessToken for user %s: %O', context.userId, error); } - const skillModel = new AgentSkillModel(context.serverDB, context.userId); - const resourceService = new SkillResourceService(context.serverDB, context.userId); + const skillModel = new AgentSkillModel(context.serverDB, context.userId, context.workspaceId); + const resourceService = new SkillResourceService( + context.serverDB, + context.userId, + context.workspaceId, + ); const marketService = new MarketService({ accessToken: marketAccessToken, userInfo: { userId: context.userId }, }); - const fileService = new FileService(context.serverDB, context.userId); - const fileModel = new FileModel(context.serverDB, context.userId); + const fileService = new FileService(context.serverDB, context.userId, context.workspaceId); + const fileModel = new FileModel(context.serverDB, context.userId, context.workspaceId); const service = new SkillServerRuntimeService({ fileModel, @@ -307,7 +311,7 @@ export const skillsRuntime: ServerRuntimeRegistration = { // result based on the identifier prefix so the inspector can show // "Activate Agent Skill" + the friendly `title`. const agentSkillBuiltins: BuiltinSkill[] = context.agentId - ? await new AgentDocumentsService(context.serverDB, context.userId) + ? await new AgentDocumentsService(context.serverDB, context.userId, context.workspaceId) .getAgentSkills(context.agentId) .then((skills) => skills.map((skill) => ({ diff --git a/src/server/services/toolExecution/serverRuntimes/task.ts b/src/server/services/toolExecution/serverRuntimes/task.ts index d0fdc989dc..9bbaf9f7d0 100644 --- a/src/server/services/toolExecution/serverRuntimes/task.ts +++ b/src/server/services/toolExecution/serverRuntimes/task.ts @@ -1,4 +1,5 @@ import { normalizeListTasksParams, TaskIdentifier } from '@lobechat/builtin-tool-task'; +import type { LobeChatDatabase } from '@lobechat/database'; import { formatDependencyAdded, formatDependencyRemoved, @@ -10,35 +11,56 @@ import { priorityLabel, } from '@lobechat/prompts'; import type { TaskAutomationMode, TaskStatus } from '@lobechat/types'; +import { eq } from 'drizzle-orm'; import { AgentModel } from '@/database/models/agent'; import { TaskModel } from '@/database/models/task'; +import { tasks } from '@/database/schemas'; import { taskRouter } from '@/server/routers/lambda/task'; import { TaskService } from '@/server/services/task'; import { type ServerRuntimeRegistration } from './types'; -export const createTaskRuntime = ({ - agentModel, - agentId, - scope, - taskId, - taskCaller, - taskModel, - taskService, -}: { - agentModel: AgentModel; +// Row-level workspace resolution: the agent runtime hasn't threaded +// `workspaceId` into `ToolExecutionContext` yet. When the tool fires inside a +// task we derive the workspace from that task row; otherwise we fall back to +// personal mode. +const resolveWorkspaceId = async ( + db: LobeChatDatabase, + taskId: string | undefined, +): Promise => { + if (!taskId) return undefined; + const [row] = await db + .select({ workspaceId: tasks.workspaceId }) + .from(tasks) + .where(eq(tasks.id, taskId)) + .limit(1); + return row?.workspaceId ?? undefined; +}; + +export interface TaskRuntimeDeps { agentId?: string; + agentModel: AgentModel; scope?: string | null; - taskId?: string; taskCaller: ReturnType; + taskId?: string; taskModel: TaskModel; taskService: TaskService; -}) => { +} + +export const createTaskRuntime = (deps: TaskRuntimeDeps) => { + const { agentId, scope, taskId } = deps; + // Models are read through `deps` (not destructured) so callers can swap them + // in lazily — e.g. after async workspace resolution in the runtime factory. + const agentModel = () => deps.agentModel; + const taskModel = () => deps.taskModel; + const taskService = () => deps.taskService; + const taskCaller = () => deps.taskCaller; + const resolveAssigneeAgent = async (assigneeAgentId?: string | null) => { if (!assigneeAgentId) return { success: true } as const; - const exists = await agentModel.existsById(assigneeAgentId); + const exists = await agentModel().existsById(assigneeAgentId); if (exists) return { success: true } as const; return { @@ -65,7 +87,7 @@ export const createTaskRuntime = ({ // and label, and pass the resolved id straight through to the service. let parentTaskId: string | undefined; if (args.parentIdentifier) { - const parent = await taskModel.resolve(args.parentIdentifier); + const parent = await taskModel().resolve(args.parentIdentifier); if (!parent) return { content: `Parent task not found: ${args.parentIdentifier}`, success: false }; parentTaskId = parent.id; @@ -75,7 +97,7 @@ export const createTaskRuntime = ({ const assigneeResult = await resolveAssigneeAgent(args.assigneeAgentId); if (!assigneeResult.success) return { content: assigneeResult.content, success: false }; - const task = await taskService.createTask({ + const task = await taskService().createTask({ assigneeAgentId: args.assigneeAgentId ?? (scope === 'task' ? undefined : agentId), createdByAgentId: agentId, instruction: args.instruction, @@ -110,7 +132,7 @@ export const createTaskRuntime = ({ } try { - const result = await taskCaller.addComment({ + const result = await taskCaller().addComment({ authorAgentId: agentId, content: args.content, id, @@ -174,10 +196,10 @@ export const createTaskRuntime = ({ }, deleteTask: async (args: { identifier: string }) => { - const task = await taskModel.resolve(args.identifier); + const task = await taskModel().resolve(args.identifier); if (!task) return { content: `Task not found: ${args.identifier}`, success: false }; - await taskModel.delete(task.id); + await taskModel().delete(task.id); return { content: formatTaskDeleted(task.identifier, task.name), @@ -187,7 +209,7 @@ export const createTaskRuntime = ({ deleteTaskComment: async (args: { commentId: string }) => { try { - await taskCaller.deleteComment({ commentId: args.commentId }); + await taskCaller().deleteComment({ commentId: args.commentId }); return { content: `Comment ${args.commentId} deleted.`, success: true, @@ -210,7 +232,7 @@ export const createTaskRuntime = ({ priority?: number; removeDependencies?: string[]; }) => { - const task = await taskModel.resolve(args.identifier); + const task = await taskModel().resolve(args.identifier); if (!task) return { content: `Task not found: ${args.identifier}`, success: false }; const updateData: { @@ -256,7 +278,7 @@ export const createTaskRuntime = ({ } if (Object.keys(updateData).length > 0) { - ops.push(taskCaller.update({ id: task.id, ...updateData })); + ops.push(taskCaller().update({ id: task.id, ...updateData })); } const applyDeps = async ( @@ -265,7 +287,11 @@ export const createTaskRuntime = ({ onChange: (depIdentifier: string) => void, ): Promise => { const resolved = await Promise.all( - ids.map((id) => taskModel.resolve(id).then((r) => ({ id, resolved: r }))), + ids.map((id) => + taskModel() + .resolve(id) + .then((r) => ({ id, resolved: r })), + ), ); const missing = resolved.find((r) => !r.resolved); if (missing) return `Dependency task not found: ${missing.id}`; @@ -279,7 +305,7 @@ export const createTaskRuntime = ({ depResults.push( applyDeps( args.addDependencies, - (depId) => taskModel.addDependency(task.id, depId), + (depId) => taskModel().addDependency(task.id, depId), (depIdentifier) => changes.push(formatDependencyAdded(task.identifier, depIdentifier)), ), ); @@ -288,7 +314,7 @@ export const createTaskRuntime = ({ depResults.push( applyDeps( args.removeDependencies, - (depId) => taskModel.removeDependency(task.id, depId), + (depId) => taskModel().removeDependency(task.id, depId), (depIdentifier) => changes.push(formatDependencyRemoved(task.identifier, depIdentifier)), ), @@ -316,7 +342,7 @@ export const createTaskRuntime = ({ }); try { - const result = await taskCaller.list(normalized.query); + const result = await taskCaller().list(normalized.query); return { content: formatTaskList(result.data, normalized.displayFilters), @@ -340,7 +366,7 @@ export const createTaskRuntime = ({ schedulePattern?: string | null; scheduleTimezone?: string | null; }) => { - const task = await taskModel.resolve(args.identifier); + const task = await taskModel().resolve(args.identifier); if (!task) return { content: `Task not found: ${args.identifier}`, success: false }; const changes: string[] = []; @@ -386,12 +412,12 @@ export const createTaskRuntime = ({ ); } if (Object.keys(scheduleUpdate).length > 0) { - ops.push(taskCaller.update({ id: task.id, ...scheduleUpdate })); + ops.push(taskCaller().update({ id: task.id, ...scheduleUpdate })); } if (args.maxExecutions !== undefined) { ops.push( - taskCaller.updateConfig({ + taskCaller().updateConfig({ config: { schedule: { maxExecutions: args.maxExecutions } }, id: task.id, }), @@ -422,7 +448,7 @@ export const createTaskRuntime = ({ } try { - const result = await taskCaller.run({ + const result = await taskCaller().run({ continueTopicId: args.continueTopicId, id, prompt: args.prompt, @@ -456,7 +482,7 @@ export const createTaskRuntime = ({ for (const [index, identifier] of identifiers.entries()) { try { - const result = await taskCaller.run({ id: identifier }); + const result = await taskCaller().run({ id: identifier }); const topicId = (result as { topicId?: string } | undefined)?.topicId; succeeded += 1; lines.push( @@ -482,7 +508,7 @@ export const createTaskRuntime = ({ updateTaskComment: async (args: { commentId: string; content: string }) => { try { - await taskCaller.updateComment({ commentId: args.commentId, content: args.content }); + await taskCaller().updateComment({ commentId: args.commentId, content: args.content }); return { content: `Comment ${args.commentId} updated.`, success: true, @@ -504,7 +530,7 @@ export const createTaskRuntime = ({ } try { - const result = await taskCaller.updateStatus({ + const result = await taskCaller().updateStatus({ error: args.error, id, status: args.status, @@ -536,7 +562,7 @@ export const createTaskRuntime = ({ }; } - const detail = await taskService.getTaskDetail(id); + const detail = await taskService().getTaskDetail(id); if (!detail) return { content: `Task not found: ${id}`, success: false }; return { @@ -553,20 +579,53 @@ export const taskRuntime: ServerRuntimeRegistration = { throw new Error('userId and serverDB are required for Task tool execution'); } - const agentModel = new AgentModel(context.serverDB, context.userId); - const taskModel = new TaskModel(context.serverDB, context.userId); - const taskService = new TaskService(context.serverDB, context.userId); - const taskCaller = taskRouter.createCaller({ userId: context.userId }); + const db = context.serverDB; + const userId = context.userId; + const { agentId, taskId, scope } = context; - return createTaskRuntime({ - agentModel, - agentId: context.agentId, - scope: context.scope, - taskCaller, - taskId: context.taskId, - taskModel, - taskService, - }); + // Models are wired in lazily after the workspaceId is resolved from the + // owning task row. `createTaskRuntime` reads them through this shared + // `deps` object, so re-assigning the fields below propagates into every + // method without re-creating the runtime. + const deps = { + agentId, + scope, + taskId, + // Initial personal-mode models cover the no-task-context case. Replaced + // before the first call when `taskId` is set. + agentModel: new AgentModel(db, userId), + taskModel: new TaskModel(db, userId), + taskService: new TaskService(db, userId), + taskCaller: taskRouter.createCaller({ userId }), + } as TaskRuntimeDeps; + + let resolved = false; + const ensureModels = async () => { + if (resolved) return; + resolved = true; + // Prefer pipeline-threaded `context.workspaceId`. Fall back to looking + // up the owning task row for callers that pre-date the propagation work + // and still construct `ToolExecutionContext` without `workspaceId`. + const wsId = context.workspaceId ?? (await resolveWorkspaceId(db, taskId)); + deps.agentModel = new AgentModel(db, userId, wsId); + deps.taskModel = new TaskModel(db, userId, wsId); + deps.taskService = new TaskService(db, userId, wsId); + deps.taskCaller = taskRouter.createCaller({ userId, workspaceId: wsId }); + }; + + const baseRuntime = createTaskRuntime(deps); + + // Wrap every method so that workspaceId + models are resolved before the + // delegate runs. Preserves the existing tool API shape. + return Object.fromEntries( + Object.entries(baseRuntime).map(([name, fn]) => [ + name, + async (...args: unknown[]) => { + await ensureModels(); + return (fn as (...a: unknown[]) => unknown)(...args); + }, + ]), + ) as typeof baseRuntime; }, identifier: TaskIdentifier, }; diff --git a/src/server/services/toolExecution/serverRuntimes/topicReference.ts b/src/server/services/toolExecution/serverRuntimes/topicReference.ts index 7ae959648c..7215f49346 100644 --- a/src/server/services/toolExecution/serverRuntimes/topicReference.ts +++ b/src/server/services/toolExecution/serverRuntimes/topicReference.ts @@ -16,10 +16,12 @@ interface GetTopicContextParams { class TopicReferenceExecutionRuntime { private db: LobeChatDatabase; private userId: string; + private workspaceId?: string; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; + this.workspaceId = workspaceId; } getTopicContext = async (params: GetTopicContextParams): Promise => { @@ -30,7 +32,7 @@ class TopicReferenceExecutionRuntime { } try { - const topicModel = new TopicModel(this.db, this.userId); + const topicModel = new TopicModel(this.db, this.userId, this.workspaceId); const topic = await topicModel.findById(topicId); if (!topic) { @@ -51,7 +53,7 @@ class TopicReferenceExecutionRuntime { // Fallback: fetch recent messages // Must pass agentId/groupId from topic, otherwise query filters by isNull(sessionId/groupId) - const messageModel = new MessageModel(this.db, this.userId); + const messageModel = new MessageModel(this.db, this.userId, this.workspaceId); const messages = await messageModel.query({ agentId: topic.agentId ?? undefined, groupId: topic.groupId ?? undefined, @@ -87,7 +89,11 @@ export const topicReferenceRuntime: ServerRuntimeRegistration = { if (!context.userId) { throw new Error('userId is required for TopicReference execution'); } - return new TopicReferenceExecutionRuntime(context.serverDB, context.userId); + return new TopicReferenceExecutionRuntime( + context.serverDB, + context.userId, + context.workspaceId, + ); }, identifier: TopicReferenceIdentifier, }; diff --git a/src/server/services/toolExecution/serverRuntimes/webBrowsing.ts b/src/server/services/toolExecution/serverRuntimes/webBrowsing.ts index 4c74d97ec5..708fb97507 100644 --- a/src/server/services/toolExecution/serverRuntimes/webBrowsing.ts +++ b/src/server/services/toolExecution/serverRuntimes/webBrowsing.ts @@ -16,14 +16,14 @@ export const webBrowsingRuntime: ServerRuntimeRegistration = { documentService: canSaveDocuments ? { associateDocument: async (documentId) => { - const service = new AgentDocumentsService(serverDB, userId); + const service = new AgentDocumentsService(serverDB, userId, context.workspaceId); await service.associateDocument(agentId, documentId); }, createDocument: async (params) => { // Same service the client trpc procedure uses — dedupe by URL, // short-circuit on byte-identical content, write a history // snapshot when content actually changed (). - const service = new WebBrowsingDocumentService(serverDB, userId); + const service = new WebBrowsingDocumentService(serverDB, userId, context.workspaceId); return service.upsertCrawledDocument(params); }, } diff --git a/src/server/services/toolExecution/serverRuntimes/webOnboarding.ts b/src/server/services/toolExecution/serverRuntimes/webOnboarding.ts index 223f09ce05..a310816ab1 100644 --- a/src/server/services/toolExecution/serverRuntimes/webOnboarding.ts +++ b/src/server/services/toolExecution/serverRuntimes/webOnboarding.ts @@ -13,7 +13,11 @@ export const webOnboardingRuntime: ServerRuntimeRegistration = { } const onboardingService = new OnboardingService(context.serverDB, context.userId); - const docService = new AgentDocumentsService(context.serverDB, context.userId); + const docService = new AgentDocumentsService( + context.serverDB, + context.userId, + context.workspaceId, + ); return new WebOnboardingExecutionRuntime({ finishOnboarding: () => onboardingService.finishOnboarding(), diff --git a/src/server/services/toolExecution/types.ts b/src/server/services/toolExecution/types.ts index 62d9e1fe02..ba3b548bfe 100644 --- a/src/server/services/toolExecution/types.ts +++ b/src/server/services/toolExecution/types.ts @@ -116,6 +116,14 @@ export interface ToolExecutionContext { /** Topic ID for sandbox session management */ topicId?: string; userId?: string; + /** + * Workspace ID that scopes ownership for any model/service the runtime + * instantiates. When unset the runtime falls back to personal mode + * (`workspace_id IS NULL`). Threaded from the chat/task router through + * `state.metadata.workspaceId` so tool side-effects (createBrief, pinTask, + * etc.) land in the same workspace the request originated from. + */ + workspaceId?: string; } export interface ToolExecutionResult { diff --git a/src/server/services/usage/index.ts b/src/server/services/usage/index.ts index fbed7c16f9..b3c4c5d37f 100644 --- a/src/server/services/usage/index.ts +++ b/src/server/services/usage/index.ts @@ -5,6 +5,7 @@ import { desc, eq } from 'drizzle-orm'; import { messages } from '@/database/schemas'; import { type LobeChatDatabase } from '@/database/type'; import { genRangeWhere, genWhere } from '@/database/utils/genWhere'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; import { type MessageMetadata } from '@/types/message'; import { type UsageLog, type UsageRecordItem } from '@/types/usage/usageRecord'; import { formatDate } from '@/utils/format'; @@ -13,9 +14,11 @@ const log = debug('lobe-usage:service'); export class UsageRecordService { private userId: string; + private workspaceId?: string; private db: LobeChatDatabase; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.userId = userId; + this.workspaceId = workspaceId; this.db = db; } @@ -38,7 +41,10 @@ export class UsageRecordService { .from(messages) .where( genWhere([ - eq(messages.userId, this.userId), + buildWorkspaceWhere( + { userId: this.userId, workspaceId: this.workspaceId }, + { userId: messages.userId, workspaceId: messages.workspaceId }, + ), eq(messages.role, 'assistant'), genRangeWhere([startAt, endAt], messages.createdAt, (date) => date.toDate()), ]), diff --git a/src/server/services/webBrowsing/index.ts b/src/server/services/webBrowsing/index.ts index 939ef10529..abf40dddbb 100644 --- a/src/server/services/webBrowsing/index.ts +++ b/src/server/services/webBrowsing/index.ts @@ -46,19 +46,21 @@ const hashContent = (content: string): string => Md5.hashStr(content); export class WebBrowsingDocumentService { private readonly db: LobeChatDatabase; private readonly userId: string; + private readonly workspaceId?: string; private readonly documentModel: DocumentModel; private readonly topicDocumentModel: TopicDocumentModel; private documentServiceInstance?: DocumentService; - constructor(db: LobeChatDatabase, userId: string) { + constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) { this.db = db; this.userId = userId; - this.documentModel = new DocumentModel(db, userId); - this.topicDocumentModel = new TopicDocumentModel(db, userId); + this.workspaceId = workspaceId; + this.documentModel = new DocumentModel(db, userId, workspaceId); + this.topicDocumentModel = new TopicDocumentModel(db, userId, workspaceId); } private get documentService() { - this.documentServiceInstance ??= new DocumentService(this.db, this.userId); + this.documentServiceInstance ??= new DocumentService(this.db, this.userId, this.workspaceId); return this.documentServiceInstance; } diff --git a/src/server/workflows-hono/memory-user-memory/workflows/processTopic.ts b/src/server/workflows-hono/memory-user-memory/workflows/processTopic.ts index c8a6f6fda2..0e4194d662 100644 --- a/src/server/workflows-hono/memory-user-memory/workflows/processTopic.ts +++ b/src/server/workflows-hono/memory-user-memory/workflows/processTopic.ts @@ -66,9 +66,11 @@ const processTopicRoute = async (context: WorkflowContext getServerDB().then((db) => - new AsyncTaskModel(db, userId).isUserMemoryExtractionCancellationRequested( - payload.asyncTaskId!, - ), + new AsyncTaskModel( + db, + userId, + payload.workspaceId, + ).isUserMemoryExtractionCancellationRequested(payload.asyncTaskId!), ), ); if (cancelled) { @@ -98,6 +100,7 @@ const processTopicRoute = async (context: WorkflowContext getServerDB().then((db) => - new AsyncTaskModel(db, userId).isUserMemoryExtractionCancellationRequested( - payload.asyncTaskId!, - ), + new AsyncTaskModel( + db, + userId, + payload.workspaceId, + ).isUserMemoryExtractionCancellationRequested(payload.asyncTaskId!), ), ); if (cancelled) { @@ -142,6 +147,7 @@ const processTopicRoute = async (context: WorkflowContext getServerDB().then((db) => - new AsyncTaskModel(db, userId).incrementUserMemoryExtractionProgress( - payload.asyncTaskId!, - ), + new AsyncTaskModel( + db, + userId, + payload.workspaceId, + ).incrementUserMemoryExtractionProgress(payload.asyncTaskId!), ), ); } @@ -199,7 +207,7 @@ export const processTopicWorkflow = createWorkflow getServerDB().then((db) => - new AsyncTaskModel(db, userId).isUserMemoryExtractionCancellationRequested( - payload.asyncTaskId!, - ), + new AsyncTaskModel( + db, + userId, + payload.workspaceId, + ).isUserMemoryExtractionCancellationRequested(payload.asyncTaskId!), ), ); if (cancelled) { diff --git a/src/server/workflows-hono/memory-user-memory/workflows/processUserTopics.ts b/src/server/workflows-hono/memory-user-memory/workflows/processUserTopics.ts index 727db2f944..b23e40ae5a 100644 --- a/src/server/workflows-hono/memory-user-memory/workflows/processUserTopics.ts +++ b/src/server/workflows-hono/memory-user-memory/workflows/processUserTopics.ts @@ -59,9 +59,11 @@ export const processUserTopicsHandler = async ( `memory:user-memory:extract:users:${userId}:cancel-check`, () => getServerDB().then((db) => - new AsyncTaskModel(db, userId).isUserMemoryExtractionCancellationRequested( - params.asyncTaskId!, - ), + new AsyncTaskModel( + db, + userId, + params.workspaceId, + ).isUserMemoryExtractionCancellationRequested(params.asyncTaskId!), ), ); if (cancelled) { @@ -82,7 +84,11 @@ export const processUserTopicsHandler = async ( ? await context.run( `memory:user-memory:extract:users:${userId}:filter-topic-ids`, async () => { - const filtered = await executor.filterTopicIdsForUser(userId, params.topicIds); + const filtered = await executor.filterTopicIdsForUser( + userId, + params.topicIds, + params.workspaceId, + ); return filtered.length > 0 ? filtered : undefined; }, ) @@ -102,6 +108,7 @@ export const processUserTopicsHandler = async ( from: params.from, to: params.to, userId, + workspaceId: params.workspaceId, }, TOPIC_PAGE_SIZE, ), diff --git a/src/server/workflows-hono/memory-user-memory/workflows/processUsers.ts b/src/server/workflows-hono/memory-user-memory/workflows/processUsers.ts index 6e79cbc159..d85aded757 100644 --- a/src/server/workflows-hono/memory-user-memory/workflows/processUsers.ts +++ b/src/server/workflows-hono/memory-user-memory/workflows/processUsers.ts @@ -29,9 +29,11 @@ export const processUsersHandler = async ( // If root task has cancelRequestedAt, this stage stops scheduling child workflows. const cancelled = await context.run('memory:user-memory:extract:cancel-check:root', () => getServerDB().then((db) => - new AsyncTaskModel(db, params.userIds[0]!).isUserMemoryExtractionCancellationRequested( - params.asyncTaskId!, - ), + new AsyncTaskModel( + db, + params.userIds[0]!, + params.workspaceId, + ).isUserMemoryExtractionCancellationRequested(params.asyncTaskId!), ), ); if (cancelled) { diff --git a/src/server/workflows-hono/task/handlers/onTopicComplete.ts b/src/server/workflows-hono/task/handlers/onTopicComplete.ts index 45072b39fe..9e2304cb8e 100644 --- a/src/server/workflows-hono/task/handlers/onTopicComplete.ts +++ b/src/server/workflows-hono/task/handlers/onTopicComplete.ts @@ -1,6 +1,8 @@ import debug from 'debug'; +import { and, eq } from 'drizzle-orm'; import type { Context } from 'hono'; +import { tasks } from '@/database/schemas'; import { getServerDB } from '@/database/server'; import { TaskLifecycleService } from '@/server/services/taskLifecycle'; @@ -46,7 +48,15 @@ export async function onTopicComplete(c: Context) { ); const db = await getServerDB(); - const taskLifecycle = new TaskLifecycleService(db, userId); + // System-level callback: derive workspace from the task row so the + // lifecycle service writes briefs / status into the correct workspace. + const [taskRow] = await db + .select({ workspaceId: tasks.workspaceId }) + .from(tasks) + .where(and(eq(tasks.id, taskId), eq(tasks.createdByUserId, userId))) + .limit(1); + const wsId = taskRow?.workspaceId ?? undefined; + const taskLifecycle = new TaskLifecycleService(db, userId, wsId); await taskLifecycle.onTopicComplete({ errorMessage, diff --git a/src/server/workflows-hono/task/handlers/watchdog.ts b/src/server/workflows-hono/task/handlers/watchdog.ts index b86c84b908..1be6be6f3f 100644 --- a/src/server/workflows-hono/task/handlers/watchdog.ts +++ b/src/server/workflows-hono/task/handlers/watchdog.ts @@ -23,13 +23,14 @@ export async function watchdog(c: Context) { const failed: string[] = []; for (const task of stuckTasks) { - const taskModel = new TaskModel(db, task.createdByUserId); + const wsId = task.workspaceId ?? undefined; + const taskModel = new TaskModel(db, task.createdByUserId, wsId); await taskModel.updateStatus(task.id, 'failed', { completedAt: new Date(), error: 'Heartbeat timeout', }); - const briefModel = new BriefModel(db, task.createdByUserId); + const briefModel = new BriefModel(db, task.createdByUserId, wsId); await briefModel.create({ agentId: task.assigneeAgentId || undefined, priority: 'urgent', diff --git a/src/server/workflows/agentEvalRun/index.ts b/src/server/workflows/agentEvalRun/index.ts index d1575e5985..d419407e9a 100644 --- a/src/server/workflows/agentEvalRun/index.ts +++ b/src/server/workflows/agentEvalRun/index.ts @@ -235,12 +235,12 @@ export class AgentEvalRunWorkflow { */ static async filterTestCasesNeedingExecution( db: LobeChatDatabase, - params: { runId: string; testCaseIds: string[]; userId: string }, + params: { runId: string; testCaseIds: string[]; userId: string; workspaceId?: string }, ): Promise { - const { runId, testCaseIds, userId } = params; + const { runId, testCaseIds, userId, workspaceId } = params; if (testCaseIds.length === 0) return []; - const agentEvalRunTopicModel = new AgentEvalRunTopicModel(db, userId); + const agentEvalRunTopicModel = new AgentEvalRunTopicModel(db, userId, workspaceId); // Get existing RunTopics for this run const existingRunTopics = await agentEvalRunTopicModel.findByRunId(runId); diff --git a/src/server/workflows/agentEvalRun/utils.ts b/src/server/workflows/agentEvalRun/utils.ts new file mode 100644 index 0000000000..a5b9d6c219 --- /dev/null +++ b/src/server/workflows/agentEvalRun/utils.ts @@ -0,0 +1,24 @@ +import { eq } from 'drizzle-orm'; + +import { agentEvalRuns } from '@/database/schemas'; +import type { LobeChatDatabase } from '@/database/type'; + +/** + * System-level workspace resolver for agent-eval-run workflow handlers. + * + * These workflow endpoints are server-to-server callbacks dispatched from + * QStash and do not carry a workspace context. We derive the workspace from + * the `runId` row so downstream `AgentEvalXxxModel` / `AgentEvalRunService` + * instances ownership-filter to the correct workspace. + */ +export const resolveAgentEvalRunWorkspace = async ( + db: LobeChatDatabase, + runId: string, +): Promise => { + const [row] = await db + .select({ workspaceId: agentEvalRuns.workspaceId }) + .from(agentEvalRuns) + .where(eq(agentEvalRuns.id, runId)) + .limit(1); + return row?.workspaceId ?? undefined; +}; diff --git a/src/server/workflows/agentSignal/run.ts b/src/server/workflows/agentSignal/run.ts index d3b971cdaa..c913f2eed6 100644 --- a/src/server/workflows/agentSignal/run.ts +++ b/src/server/workflows/agentSignal/run.ts @@ -25,6 +25,7 @@ import { and, desc, eq, isNull, lte } from 'drizzle-orm'; import { MessageModel } from '@/database/models/message'; import { getServerDB } from '@/database/server'; +import { buildWorkspaceWhere } from '@/database/utils/workspace'; import { extractTraceContext } from '@/libs/observability/traceparent'; import { isAgentSignalEnabledForUser } from '@/server/services/agentSignal/featureGate'; import { toAgentSignalTraceEvents } from '@/server/services/agentSignal/observability/traceEvents'; @@ -283,7 +284,12 @@ const persistWorkflowHydrationSkippedSnapshot = async ( const buildFeedbackSourceSerializedContext = async ( sourceEvent: SourceEventAgentUserMessage, - input: { contextEndAt?: Date; db: Awaited>; userId: string }, + input: { + contextEndAt?: Date; + db: Awaited>; + userId: string; + workspaceId?: string; + }, ): Promise => { if (typeof sourceEvent.payload.serializedContext === 'string') { return sourceEvent.payload.serializedContext; @@ -291,7 +297,7 @@ const buildFeedbackSourceSerializedContext = async ( if (typeof sourceEvent.payload.topicId !== 'string') return undefined; - const messageModel = new MessageModel(input.db, input.userId); + const messageModel = new MessageModel(input.db, input.userId, input.workspaceId); const anchorMessage = await messageModel.findById(sourceEvent.payload.messageId); if (!anchorMessage?.createdAt) return undefined; @@ -306,7 +312,7 @@ const buildFeedbackSourceSerializedContext = async ( limit: 10, orderBy: [desc(messages.createdAt)], where: and( - eq(messages.userId, input.userId), + buildWorkspaceWhere({ userId: input.userId, workspaceId: input.workspaceId }, messages), eq(messages.topicId, sourceEvent.payload.topicId), lte(messages.createdAt, contextEndAt), threadScopeFilter, @@ -328,7 +334,12 @@ const buildFeedbackSourceSerializedContext = async ( const enrichFeedbackSourceSerializedContext = async ( sourceEvent: AgentSignalWorkflowRunPayload['sourceEvent'], - input: { contextEndAt?: Date; db: Awaited>; userId: string }, + input: { + contextEndAt?: Date; + db: Awaited>; + userId: string; + workspaceId?: string; + }, ): Promise => { if (!isAgentUserMessageSource(sourceEvent)) return sourceEvent; @@ -370,7 +381,7 @@ const enrichFeedbackSourceSerializedContext = async ( */ const normalizeWorkflowSourceEvent = async ( sourceEvent: AgentSignalWorkflowRunPayload['sourceEvent'], - input: { db: Awaited>; userId: string }, + input: { db: Awaited>; userId: string; workspaceId?: string }, ): Promise<{ hydration?: WorkflowHydrationDiagnostic; sourceEvent: AgentSignalWorkflowRunPayload['sourceEvent']; @@ -519,6 +530,7 @@ export const runAgentSignalWorkflow = async ( const result = await normalizeWorkflowSourceEvent(payload.sourceEvent, { db, userId: payload.userId, + workspaceId: payload.workspaceId, }); if (result.hydration) { hydrationDiagnostic = result.hydration; @@ -567,6 +579,7 @@ export const runAgentSignalWorkflow = async ( db, selfIterationEnabled, userId: payload.userId, + workspaceId: payload.workspaceId, }) : undefined; const selfReflection = isSelfReflectionSource(normalizedSourceEvent) @@ -575,6 +588,7 @@ export const runAgentSignalWorkflow = async ( db, selfIterationEnabled, userId: payload.userId, + workspaceId: payload.workspaceId, }) : undefined; const selfFeedbackIntent = isSelfFeedbackIntentSource(normalizedSourceEvent) @@ -583,6 +597,7 @@ export const runAgentSignalWorkflow = async ( db, selfIterationEnabled, userId: payload.userId, + workspaceId: payload.workspaceId, }) : undefined; const procedure = isToolOutcomeSource(normalizedSourceEvent) @@ -591,6 +606,7 @@ export const runAgentSignalWorkflow = async ( db, selfIterationEnabled, userId: payload.userId, + workspaceId: payload.workspaceId, }) : undefined; const executionResult = await context.run( @@ -602,6 +618,7 @@ export const runAgentSignalWorkflow = async ( agentId: payload.agentId, db, userId: payload.userId, + workspaceId: payload.workspaceId, }, { policyOptions: { diff --git a/src/server/workflows/agentSignal/types.ts b/src/server/workflows/agentSignal/types.ts index cc5376c71a..fbb1dc63fa 100644 --- a/src/server/workflows/agentSignal/types.ts +++ b/src/server/workflows/agentSignal/types.ts @@ -19,4 +19,10 @@ export interface AgentSignalWorkflowRunPayload { sourceEvent: AgentSignalWorkflowSourceEventInput; /** Owner of the source event and all database lookups performed by the workflow worker. */ userId: string; + /** + * Workspace id when the source event originated inside a team workspace. + * Forwarded into the action handler context so workspace-scoped writes + * land in the correct workspace. + */ + workspaceId?: string; } diff --git a/src/types/transferError.ts b/src/types/transferError.ts new file mode 100644 index 0000000000..eebe6ea014 --- /dev/null +++ b/src/types/transferError.ts @@ -0,0 +1,10 @@ +export const TransferErrorCode = { + FileStorageLimitExceeded: 'FILE_STORAGE_LIMIT_EXCEEDED', + NoPermission: 'NO_PERMISSION', + OwnerOnly: 'OWNER_ONLY', + ResourceNotFound: 'RESOURCE_NOT_FOUND', + SameWorkspace: 'SAME_WORKSPACE', + TargetNoWriteAccess: 'TARGET_NO_WRITE_ACCESS', +} as const; + +export type TransferErrorCode = (typeof TransferErrorCode)[keyof typeof TransferErrorCode]; diff --git a/src/types/workspaceSettings.ts b/src/types/workspaceSettings.ts new file mode 100644 index 0000000000..c22c3dfa36 --- /dev/null +++ b/src/types/workspaceSettings.ts @@ -0,0 +1,24 @@ +/** + * Tab identifiers for the workspace-scoped settings surface + * (`/:workspaceSlug/settings/*`). + * + * Intentionally separate from `SettingsTabs` (personal settings) — the two + * surfaces evolve independently and must not share enum members. + */ +export enum WorkspaceSettingsTabs { + APIKey = 'apikey', + Billing = 'billing', + Credits = 'credits', + Creds = 'creds', + General = 'general', + Members = 'members', + Plans = 'plans', + Provider = 'provider', + ServiceModel = 'service-model', + Skill = 'skill', + Stats = 'stats', + Storage = 'storage', + Usage = 'usage', +} + +export const DEFAULT_WORKSPACE_SETTINGS_TAB = WorkspaceSettingsTabs.General; diff --git a/src/utils/dayjsLocale.test.ts b/src/utils/dayjsLocale.test.ts new file mode 100644 index 0000000000..0ab03860ef --- /dev/null +++ b/src/utils/dayjsLocale.test.ts @@ -0,0 +1,15 @@ +import { describe, expect, it } from 'vitest'; + +import { normalizeDayjsLocale } from './dayjsLocale'; + +describe('normalizeDayjsLocale', () => { + it('should normalize simplified Chinese script locales to zh-cn', () => { + expect(normalizeDayjsLocale('zh-Hans')).toBe('zh-cn'); + expect(normalizeDayjsLocale('zh-Hans-CN')).toBe('zh-cn'); + }); + + it('should normalize traditional Chinese script locales to zh-tw', () => { + expect(normalizeDayjsLocale('zh-Hant')).toBe('zh-tw'); + expect(normalizeDayjsLocale('zh-Hant-TW')).toBe('zh-tw'); + }); +}); diff --git a/src/utils/dayjsLocale.ts b/src/utils/dayjsLocale.ts index f29faa5392..7969bc9029 100644 --- a/src/utils/dayjsLocale.ts +++ b/src/utils/dayjsLocale.ts @@ -8,5 +8,8 @@ const DAYJS_LOCALE_ALIASES: Record = { export const normalizeDayjsLocale = (lang: string): string => { const lower = lang.toLowerCase(); + if (lower.startsWith('zh-hans')) return 'zh-cn'; + if (lower.startsWith('zh-hant')) return 'zh-tw'; + return DAYJS_LOCALE_ALIASES[lower] ?? lower; }; diff --git a/src/utils/locale.test.ts b/src/utils/locale.test.ts index 2735b102e8..b1a9335513 100644 --- a/src/utils/locale.test.ts +++ b/src/utils/locale.test.ts @@ -1,7 +1,21 @@ import { describe, expect, it } from 'vitest'; +import { normalizeLocale } from '@/locales/resources'; + import { parseBrowserLanguage } from './locale'; +describe('normalizeLocale', () => { + it('should normalize simplified Chinese script locales to zh-CN', () => { + expect(normalizeLocale('zh-Hans')).toBe('zh-CN'); + expect(normalizeLocale('zh-Hans-CN')).toBe('zh-CN'); + }); + + it('should normalize traditional Chinese script locales to zh-TW', () => { + expect(normalizeLocale('zh-Hant')).toBe('zh-TW'); + expect(normalizeLocale('zh-Hant-TW')).toBe('zh-TW'); + }); +}); + describe('parseBrowserLanguage', () => { // Helper function to create Headers with accept-language const createHeaders = (acceptLanguage?: string) => { @@ -38,6 +52,16 @@ describe('parseBrowserLanguage', () => { // This expectation might need to be adjusted based on your locales configuration expect(parseBrowserLanguage(headers)).toBe('zh-CN'); }); + + it('should normalize simplified Chinese script language preferences', () => { + const headers = createHeaders('zh-Hans-CN,zh-Hans;q=0.9,en;q=0.8'); + expect(parseBrowserLanguage(headers)).toBe('zh-CN'); + }); + + it('should normalize traditional Chinese script language preferences', () => { + const headers = createHeaders('zh-Hant-TW,zh-Hant;q=0.9,en;q=0.8'); + expect(parseBrowserLanguage(headers)).toBe('zh-TW'); + }); }); describe('when DEFAULT_LANG is not en-US', () => { diff --git a/src/utils/locale.ts b/src/utils/locale.ts index 93f89f3346..5dbaf4408b 100644 --- a/src/utils/locale.ts +++ b/src/utils/locale.ts @@ -6,6 +6,15 @@ import { locales, normalizeLocale } from '@/locales/resources'; import { RouteVariants } from './server/routeVariants'; +const normalizeAcceptLanguageHeader = (acceptLanguage: string) => + acceptLanguage + .replaceAll(/zh-Hans(?:-[a-z]{2})?/gi, 'zh-CN') + .replaceAll(/zh-Hant(?:-[a-z]{2})?/gi, 'zh-TW'); + +const supportedAcceptLanguageLocales = locales.map((locale) => + locale === 'ar' ? 'ar-EG' : locale, +); + export const getAntdLocale = async (lang?: string) => { let normalLang: any = normalizeLocale(lang); @@ -36,16 +45,16 @@ export const parseBrowserLanguage = (headers: Headers, defaultLang: string = DEF * 3) The default locale. */ let browserLang: string = resolveAcceptLanguage( - headers.get('accept-language') || '', + normalizeAcceptLanguageHeader(headers.get('accept-language') || ''), // Invalid locale identifier 'ar'. A valid locale should follow the BCP 47 'language-country' format. - locales.map((locale) => (locale === 'ar' ? 'ar-EG' : locale)), + supportedAcceptLanguageLocales, defaultLang, ); // if match the ar-EG then fallback to ar if (browserLang === 'ar-EG') browserLang = 'ar'; - return browserLang; + return normalizeLocale(browserLang); }; /** diff --git a/tests/mocks/storeWorkspace.ts b/tests/mocks/storeWorkspace.ts new file mode 100644 index 0000000000..3229e34566 --- /dev/null +++ b/tests/mocks/storeWorkspace.ts @@ -0,0 +1,42 @@ +// Stub for `@/store/workspace`. The real workspace store lives in the cloud +// repo; submodule-only tests don't have it on disk, so vite import-analysis +// fails without this alias. +// +// Returns a "no active workspace" state, which makes WorkspaceLink / +// useWorkspaceAwareNavigate behave like plain react-router. Tests that need +// an active workspace can spy on these selectors with `vi.spyOn`. + +export const workspaceSelectors = { + activeWorkspace: () => null, + activeWorkspaceId: () => null, + hasActiveWorkspace: () => false, + isContextReady: () => true, + isLoading: () => false, + isMember: () => false, + isOwner: () => false, + isSwitchingWorkspace: () => false, + isViewer: () => false, + members: () => [], + myRole: () => null, + primaryOwnerId: () => null, + workspaces: () => [], +}; + +const noopState = { + activeWorkspaceId: null, + isSwitchingWorkspace: false, + isWorkspaceLoading: false, + members: [], + myRole: null, + workspaces: [], +}; + +type Selector = (state: typeof noopState) => T; + +export function useWorkspaceStore(selector?: Selector): T | typeof noopState { + return selector ? selector(noopState) : noopState; +} + +useWorkspaceStore.getState = () => noopState; +useWorkspaceStore.setState = () => undefined; +useWorkspaceStore.subscribe = () => () => undefined; diff --git a/vitest.config.mts b/vitest.config.mts index baf39610d9..58490f559f 100644 --- a/vitest.config.mts +++ b/vitest.config.mts @@ -36,6 +36,10 @@ const alias = { '@/utils/electron': resolve(__dirname, './src/utils/electron'), '@/utils/markdownToTxt': resolve(__dirname, './src/utils/markdownToTxt'), '@/utils/sanitizeFileName': resolve(__dirname, './src/utils/sanitizeFileName'), + // Workspace store lives in the cloud repo; submodule-only tests get a stub + // that reports no active workspace so workspace-aware nav helpers behave + // like plain react-router. + '@/store/workspace': resolve(__dirname, './tests/mocks/storeWorkspace.ts'), '~test-utils': resolve(__dirname, './tests/utils.tsx'), 'lru_map': resolve(__dirname, './tests/mocks/lru_map'), };