feat: workspace backend service slice (#15560)

Backend-only slice of the workspace feature (server routers/services, database models with workspaceId threading, openapi middleware, business/server stubs, const/types). Excludes all UI (features/routes/store/hooks). Deploys dark behind the workspace feature flag.

Includes open-source stub fixes: workspaceCreds router stub, ChargeParams workspaceId, usage.ts null-coalesce, DBMessageItem.workspaceId.

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Rdmclin2
2026-06-09 15:54:26 +08:00
committed by GitHub
parent 082481c35d
commit ccb33fa48c
465 changed files with 17609 additions and 3891 deletions
+3 -1
View File
@@ -48,6 +48,8 @@
"build-migrate-db": "bun run db:migrate",
"build-sitemap": "tsx ./scripts/buildSitemapIndex/index.ts",
"clean:node_modules": "bash -lc 'set -e; echo \"Removing all node_modules...\"; rm -rf node_modules; pnpm -r exec rm -rf node_modules; rm -rf apps/desktop/node_modules; echo \"All node_modules removed.\"'",
"codemod:workspace-nav": "tsx ./scripts/codemodWorkspaceNav.ts",
"codemod:workspace-nav:check": "tsx ./scripts/codemodWorkspaceNav.ts --check",
"db:generate": "drizzle-kit generate && npm run workflow:dbml",
"db:migrate": "cross-env MIGRATION_DB=1 tsx ./scripts/migrateServerDB/index.ts",
"db:studio": "drizzle-kit studio",
@@ -287,7 +289,7 @@
"@lobehub/desktop-ipc-typings": "workspace:*",
"@lobehub/editor": "^4.17.0",
"@lobehub/icons": "^5.0.0",
"@lobehub/market-sdk": "0.33.3",
"@lobehub/market-sdk": "0.34.0",
"@lobehub/tts": "^5.1.2",
"@lobehub/ui": "^5.15.10",
"@modelcontextprotocol/sdk": "^1.26.0",
+6
View File
@@ -5,6 +5,12 @@ export interface AgentSignalScope {
taskId?: string;
topicId?: string;
userId: string;
/**
* Workspace identifier when the chain runs inside a team workspace. Omitted
* for personal-mode chains. Action handlers that write workspace-scoped
* tables (messages, memories) must honor this when present.
*/
workspaceId?: string;
}
/** Causal chain metadata for source, signal, and action nodes. */
+1
View File
@@ -7,6 +7,7 @@
"./currency": "./src/currency.ts",
"./desktopGlobalShortcuts": "./src/desktopGlobalShortcuts.ts",
"./hotkeys": "./src/hotkeys.ts",
"./rbac": "./src/rbac.ts",
"./visualRef": "./src/visualRef.ts"
},
"main": "./src/index.ts",
+1
View File
@@ -28,3 +28,4 @@ export * from './url';
export * from './user';
export * from './userMemory';
export * from './version';
export * from './workspace';
+236 -2
View File
@@ -1,5 +1,3 @@
/* eslint-disable sort-keys-fix/sort-keys-fix */
/**
* RBAC Permission Actions Definition
* Defines all executable permission action types in the system
@@ -153,6 +151,40 @@ export const PERMISSION_ACTIONS = {
USER_READ: 'user:read',
USER_UPDATE: 'user:update',
// ==================== Workspace Management ====================
WORKSPACE_READ: 'workspace:read',
WORKSPACE_UPDATE: 'workspace:update',
WORKSPACE_DELETE: 'workspace:delete',
WORKSPACE_SETTINGS_UPDATE: 'workspace:settings_update',
WORKSPACE_BILLING_READ: 'workspace:billing_read',
WORKSPACE_BILLING_MANAGE: 'workspace:billing_manage',
// ==================== Workspace Member Management ====================
WORKSPACE_MEMBER_READ: 'workspace_member:read',
WORKSPACE_MEMBER_INVITE: 'workspace_member:invite',
WORKSPACE_MEMBER_REMOVE: 'workspace_member:remove',
WORKSPACE_MEMBER_UPDATE_ROLE: 'workspace_member:update_role',
// ==================== Workspace Audit ====================
WORKSPACE_AUDIT_READ: 'workspace_audit:read',
// ==================== Workspace Role Management ====================
WORKSPACE_ROLE_READ: 'workspace_role:read',
WORKSPACE_ROLE_CREATE: 'workspace_role:create',
WORKSPACE_ROLE_UPDATE: 'workspace_role:update',
WORKSPACE_ROLE_DELETE: 'workspace_role:delete',
} as const;
/**
@@ -176,6 +208,11 @@ export const getAllowedScopesForAction = (
// RBAC resources: ALL only (system-level resource)
if (resource === 'rbac') return ['ALL'];
// Workspace-scoped resources: ALL only. The workspace itself is the isolation
// boundary, so an "OWNER" sub-scope (resource-author-only) is redundant —
// workspace_member.role + assigned permissions already pin who can do what.
if (resource.startsWith('workspace')) return ['ALL'];
// user resource nuance: create/delete without OWNER; read/update allow OWNER
if (resource === 'user') {
if (action === 'create' || action === 'delete') return ['ALL'];
@@ -236,3 +273,200 @@ export const SYSTEM_DEFAULT_ROLES = {
export const ROLE_DESCRIPTIONS = {
[SYSTEM_DEFAULT_ROLES.SUPER_ADMIN]: 'Administrator with all system permissions',
} as const;
/**
* Built-in role names for workspace-scoped RBAC. Each workspace is seeded with
* exactly these three system roles on creation; their `workspace_id` is the
* owning workspace, distinguishing them from the global `super_admin` role.
*/
export const WORKSPACE_SYSTEM_ROLES = {
OWNER: 'workspace_owner',
MEMBER: 'workspace_member',
VIEWER: 'workspace_viewer',
} as const;
export type WorkspaceSystemRoleName =
(typeof WORKSPACE_SYSTEM_ROLES)[keyof typeof WORKSPACE_SYSTEM_ROLES];
const action = (key: keyof typeof PERMISSION_ACTIONS): string => PERMISSION_ACTIONS[key];
/**
* Permission codes granted to each built-in workspace role. The lists are the
* source of truth used both by `seedWorkspaceRoles` (DB seeding) and the
* migration backfill SQL — keep them aligned.
*
* Scope semantics:
* - `workspace_owner` — every workspace-domain permission + every content
* permission (`:all`) so they can manage other members' resources too.
* - `workspace_member` — read workspace + members; create/update/delete their
* own content (`:owner`) on every content resource.
* - `workspace_viewer` — strict read-only on workspace + members + content.
* No model invocation: chat without SESSION/MESSAGE write grants would
* either burn workspace budget without persisting history or require
* special-case bypasses. Use `workspace_member` if "can chat" is needed.
*/
export const WORKSPACE_ROLE_PERMISSIONS: Record<WorkspaceSystemRoleName, readonly string[]> = {
[WORKSPACE_SYSTEM_ROLES.OWNER]: [
// Workspace
`${action('WORKSPACE_READ')}:all`,
`${action('WORKSPACE_UPDATE')}:all`,
`${action('WORKSPACE_DELETE')}:all`,
`${action('WORKSPACE_SETTINGS_UPDATE')}:all`,
`${action('WORKSPACE_BILLING_READ')}:all`,
`${action('WORKSPACE_BILLING_MANAGE')}:all`,
// Members
`${action('WORKSPACE_MEMBER_READ')}:all`,
`${action('WORKSPACE_MEMBER_INVITE')}:all`,
`${action('WORKSPACE_MEMBER_REMOVE')}:all`,
`${action('WORKSPACE_MEMBER_UPDATE_ROLE')}:all`,
// Audit
`${action('WORKSPACE_AUDIT_READ')}:all`,
// Custom roles
`${action('WORKSPACE_ROLE_READ')}:all`,
`${action('WORKSPACE_ROLE_CREATE')}:all`,
`${action('WORKSPACE_ROLE_UPDATE')}:all`,
`${action('WORKSPACE_ROLE_DELETE')}:all`,
// Content — owner can read/write everyone's resources
`${action('AGENT_READ')}:all`,
`${action('AGENT_CREATE')}:all`,
`${action('AGENT_UPDATE')}:all`,
`${action('AGENT_DELETE')}:all`,
`${action('AGENT_FORK')}:all`,
`${action('SESSION_READ')}:all`,
`${action('SESSION_CREATE')}:all`,
`${action('SESSION_UPDATE')}:all`,
`${action('SESSION_DELETE')}:all`,
`${action('SESSION_GROUP_READ')}:all`,
`${action('SESSION_GROUP_CREATE')}:all`,
`${action('SESSION_GROUP_UPDATE')}:all`,
`${action('SESSION_GROUP_DELETE')}:all`,
`${action('MESSAGE_READ')}:all`,
`${action('MESSAGE_CREATE')}:all`,
`${action('MESSAGE_UPDATE')}:all`,
`${action('MESSAGE_DELETE')}:all`,
`${action('TOPIC_READ')}:all`,
`${action('TOPIC_CREATE')}:all`,
`${action('TOPIC_UPDATE')}:all`,
`${action('TOPIC_DELETE')}:all`,
`${action('FILE_READ')}:all`,
`${action('FILE_UPLOAD')}:all`,
`${action('FILE_UPDATE')}:all`,
`${action('FILE_DELETE')}:all`,
`${action('DOCUMENT_READ')}:all`,
`${action('DOCUMENT_CREATE')}:all`,
`${action('DOCUMENT_UPDATE')}:all`,
`${action('DOCUMENT_DELETE')}:all`,
`${action('KNOWLEDGE_BASE_READ')}:all`,
`${action('KNOWLEDGE_BASE_CREATE')}:all`,
`${action('KNOWLEDGE_BASE_UPDATE')}:all`,
`${action('KNOWLEDGE_BASE_DELETE')}:all`,
`${action('AI_MODEL_READ')}:all`,
`${action('AI_MODEL_INVOKE')}:all`,
`${action('AI_MODEL_CREATE')}:all`,
`${action('AI_MODEL_UPDATE')}:all`,
`${action('AI_MODEL_DELETE')}:all`,
`${action('AI_PROVIDER_READ')}:all`,
`${action('AI_PROVIDER_CREATE')}:all`,
`${action('AI_PROVIDER_UPDATE')}:all`,
`${action('AI_PROVIDER_DELETE')}:all`,
`${action('API_KEY_READ')}:all`,
`${action('API_KEY_CREATE')}:all`,
`${action('API_KEY_UPDATE')}:all`,
`${action('API_KEY_DELETE')}:all`,
],
[WORKSPACE_SYSTEM_ROLES.MEMBER]: [
// Workspace — read only
`${action('WORKSPACE_READ')}:all`,
`${action('WORKSPACE_MEMBER_READ')}:all`,
// Content — can write own
`${action('AGENT_READ')}:all`,
`${action('AGENT_CREATE')}:owner`,
`${action('AGENT_UPDATE')}:owner`,
`${action('AGENT_DELETE')}:owner`,
`${action('AGENT_FORK')}:owner`,
`${action('SESSION_READ')}:all`,
`${action('SESSION_CREATE')}:owner`,
`${action('SESSION_UPDATE')}:owner`,
`${action('SESSION_DELETE')}:owner`,
`${action('SESSION_GROUP_READ')}:all`,
`${action('SESSION_GROUP_CREATE')}:owner`,
`${action('SESSION_GROUP_UPDATE')}:owner`,
`${action('SESSION_GROUP_DELETE')}:owner`,
`${action('MESSAGE_READ')}:all`,
`${action('MESSAGE_CREATE')}:owner`,
`${action('MESSAGE_UPDATE')}:owner`,
`${action('MESSAGE_DELETE')}:owner`,
`${action('TOPIC_READ')}:all`,
`${action('TOPIC_CREATE')}:owner`,
`${action('TOPIC_UPDATE')}:owner`,
`${action('TOPIC_DELETE')}:owner`,
`${action('FILE_READ')}:all`,
`${action('FILE_UPLOAD')}:owner`,
`${action('FILE_UPDATE')}:owner`,
`${action('FILE_DELETE')}:owner`,
`${action('DOCUMENT_READ')}:all`,
`${action('DOCUMENT_CREATE')}:owner`,
`${action('DOCUMENT_UPDATE')}:owner`,
`${action('DOCUMENT_DELETE')}:owner`,
`${action('KNOWLEDGE_BASE_READ')}:all`,
`${action('KNOWLEDGE_BASE_CREATE')}:owner`,
`${action('KNOWLEDGE_BASE_UPDATE')}:owner`,
`${action('KNOWLEDGE_BASE_DELETE')}:owner`,
`${action('AI_MODEL_READ')}:all`,
`${action('AI_MODEL_INVOKE')}:all`,
`${action('AI_PROVIDER_READ')}:all`,
`${action('API_KEY_READ')}:owner`,
`${action('API_KEY_CREATE')}:owner`,
`${action('API_KEY_UPDATE')}:owner`,
`${action('API_KEY_DELETE')}:owner`,
],
[WORKSPACE_SYSTEM_ROLES.VIEWER]: [
// Read-only across the board
`${action('WORKSPACE_READ')}:all`,
`${action('WORKSPACE_MEMBER_READ')}:all`,
`${action('AGENT_READ')}:all`,
`${action('SESSION_READ')}:all`,
`${action('SESSION_GROUP_READ')}:all`,
`${action('MESSAGE_READ')}:all`,
`${action('TOPIC_READ')}:all`,
`${action('FILE_READ')}:all`,
`${action('DOCUMENT_READ')}:all`,
`${action('KNOWLEDGE_BASE_READ')}:all`,
`${action('AI_MODEL_READ')}:all`,
`${action('AI_PROVIDER_READ')}:all`,
],
};
export const WORKSPACE_ROLE_DESCRIPTIONS: Record<WorkspaceSystemRoleName, string> = {
[WORKSPACE_SYSTEM_ROLES.OWNER]: 'Full access including billing, members, and all content.',
[WORKSPACE_SYSTEM_ROLES.MEMBER]: 'Can create and edit own content, read shared content.',
[WORKSPACE_SYSTEM_ROLES.VIEWER]: 'Read-only access to workspace content.',
};
export const WORKSPACE_ROLE_DISPLAY_NAMES: Record<WorkspaceSystemRoleName, string> = {
[WORKSPACE_SYSTEM_ROLES.OWNER]: 'Owner',
[WORKSPACE_SYSTEM_ROLES.MEMBER]: 'Member',
[WORKSPACE_SYSTEM_ROLES.VIEWER]: 'Viewer',
};
/**
* Translate a legacy `workspace_members.role` text value to its corresponding
* built-in role name. Used by the migration backfill and member CRUD code that
* still double-writes to `workspace_members.role` for label/UI purposes.
*/
export const legacyRoleToWorkspaceRole = (role: string): WorkspaceSystemRoleName | null => {
switch (role) {
case 'owner': {
return WORKSPACE_SYSTEM_ROLES.OWNER;
}
case 'member': {
return WORKSPACE_SYSTEM_ROLES.MEMBER;
}
case 'viewer': {
return WORKSPACE_SYSTEM_ROLES.VIEWER;
}
default: {
return null;
}
}
};
+13
View File
@@ -63,6 +63,19 @@ export const TASK_TEMPLATE_FALLBACK_CATEGORIES: TaskTemplateCategory[] = [
'learning-research',
];
/**
* Categories that only make sense in a personal context. When the recommendation
* is requested from inside a workspace, every template under these categories
* is removed from the candidate pool — both matched and fallback — so a team
* dashboard never surfaces "bedtime gratitude" / "weekly family finance" etc.
*/
export const TASK_TEMPLATE_PERSONAL_ONLY_CATEGORIES: TaskTemplateCategory[] = [
'parenting',
'health',
'hobbies',
'personal-life',
];
export const TASK_TEMPLATE_RECOMMEND_COUNT = 3;
export const taskTemplates: TaskTemplate[] = [
+10
View File
@@ -0,0 +1,10 @@
/**
* Number of days a workspace invitation token stays valid before it expires.
* Shared by `WorkspaceMemberModel.createInvitation` (sets `expiresAt`) and the
* cloud invite-email template (renders the human-facing expiry copy), so the
* actual TTL and what we promise to recipients can't drift apart.
*
* If you change this, also update the "expire after 1 week" copy in
* `lobehub/src/locales/default/setting.ts` (`workspace.members.invite.modal.expiryWarning`).
*/
export const INVITATION_EXPIRY_DAYS = 7;
+2
View File
@@ -2,3 +2,5 @@ export * from './core/db-adaptor';
export * from './repositories/compression';
export * from './type';
export * from './utils/idGenerator';
export * from './utils/seedWorkspaceRoles';
export * from './utils/workspace';
@@ -17,6 +17,7 @@ import {
sessions,
topics,
users,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { AgentModel } from '../agent';
@@ -1309,6 +1310,52 @@ describe('AgentModel', () => {
expect(result?.virtual).toBe(true);
});
});
describe('workspace mode', () => {
it('should create workspace-scoped inbox agent', async () => {
const [workspace] = await serverDB
.insert(workspaces)
.values({ name: 'ws', primaryOwnerId: userId, slug: 'ws-slug' })
.returning();
const wsAgentModel = new AgentModel(serverDB, userId, workspace.id);
const result = await wsAgentModel.getBuiltinAgent(INBOX_SESSION_ID);
expect(result).toBeDefined();
expect(result?.slug).toBe(INBOX_SESSION_ID);
expect(result?.workspaceId).toBe(workspace.id);
expect(result?.userId).toBe(userId);
});
it('should allow workspace inbox to coexist with personal inbox for the same user', async () => {
const personal = await agentModel.getBuiltinAgent(INBOX_SESSION_ID);
expect(personal?.workspaceId).toBeNull();
const [workspace] = await serverDB
.insert(workspaces)
.values({ name: 'ws2', primaryOwnerId: userId, slug: 'ws2-slug' })
.returning();
const wsAgentModel = new AgentModel(serverDB, userId, workspace.id);
const ws = await wsAgentModel.getBuiltinAgent(INBOX_SESSION_ID);
expect(ws?.id).not.toBe(personal?.id);
expect(ws?.workspaceId).toBe(workspace.id);
});
it('should be idempotent in workspace mode', async () => {
const [workspace] = await serverDB
.insert(workspaces)
.values({ name: 'ws3', primaryOwnerId: userId, slug: 'ws3-slug' })
.returning();
const wsAgentModel = new AgentModel(serverDB, userId, workspace.id);
const first = await wsAgentModel.getBuiltinAgent(INBOX_SESSION_ID);
const second = await wsAgentModel.getBuiltinAgent(INBOX_SESSION_ID);
expect(first?.id).toBe(second?.id);
});
});
});
describe('batchDelete', () => {
@@ -2,7 +2,7 @@
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import { agentBotProviders, agents, users } from '../../schemas';
import { agentBotProviders, agents, users, workspaces } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { AgentBotProviderModel } from '../agentBotProvider';
@@ -337,6 +337,117 @@ describe('AgentBotProviderModel', () => {
});
});
describe('findEnabledByPlatformAndAppId (static)', () => {
it('should find an enabled provider that lives in a workspace (system-wide, ignores ownership scope)', async () => {
// Regression: workspace-scoped bots could not be connected because the
// gateway looked them up in personal scope (workspace_id IS NULL).
const workspaceId = 'bot-provider-test-workspace';
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Test WS',
primaryOwnerId: userId,
slug: 'test-ws',
});
const wsModel = new AgentBotProviderModel(serverDB, userId, mockGateKeeper, workspaceId);
await wsModel.create({
agentId,
applicationId: 'ws-app',
credentials: { botToken: 'ws-tok' },
platform: 'discord',
});
// The personal-scope instance lookup misses the workspace row — this is
// the exact failure the static method exists to avoid.
const personalModel = new AgentBotProviderModel(serverDB, userId, mockGateKeeper);
expect(await personalModel.findEnabledByApplicationId('discord', 'ws-app')).toBeNull();
// The system-wide static lookup finds it and decrypts credentials.
const result = await AgentBotProviderModel.findEnabledByPlatformAndAppId(
serverDB,
'discord',
'ws-app',
mockGateKeeper,
);
expect(result).not.toBeNull();
expect(result!.applicationId).toBe('ws-app');
expect(result!.workspaceId).toBe(workspaceId);
expect(result!.credentials.botToken).toBe('ws-tok');
});
it('should find a provider owned by any user', async () => {
const model2 = new AgentBotProviderModel(serverDB, userId2);
await model2.create({
agentId: agentId2,
applicationId: 'other-user-app',
credentials: { botToken: 'tok' },
platform: 'slack',
});
const result = await AgentBotProviderModel.findEnabledByPlatformAndAppId(
serverDB,
'slack',
'other-user-app',
);
expect(result).not.toBeNull();
expect(result!.applicationId).toBe('other-user-app');
});
it('should return null for a disabled provider', async () => {
const model = new AgentBotProviderModel(serverDB, userId);
const created = await model.create({
agentId,
applicationId: 'disabled-app',
credentials: { botToken: 'tok' },
platform: 'discord',
});
await model.update(created.id, { enabled: false });
const result = await AgentBotProviderModel.findEnabledByPlatformAndAppId(
serverDB,
'discord',
'disabled-app',
);
expect(result).toBeNull();
});
it('should return null for a non-existent combination', async () => {
const result = await AgentBotProviderModel.findEnabledByPlatformAndAppId(
serverDB,
'discord',
'no-such-app',
);
expect(result).toBeNull();
});
});
describe('findByAgentId (static)', () => {
it('should return all providers for an agent regardless of ownership scope, decrypted', async () => {
const model = new AgentBotProviderModel(serverDB, userId, mockGateKeeper);
await model.create({
agentId,
applicationId: 'agent-app-1',
credentials: { botToken: 'tok-1' },
platform: 'discord',
});
const disabled = await model.create({
agentId,
applicationId: 'agent-app-2',
credentials: { botToken: 'tok-2' },
platform: 'slack',
});
await model.update(disabled.id, { enabled: false });
const results = await AgentBotProviderModel.findByAgentId(serverDB, agentId, mockGateKeeper);
// Returns both enabled and disabled rows (caller filters by `enabled`).
expect(results).toHaveLength(2);
const byApp = Object.fromEntries(results.map((r) => [r.applicationId, r]));
expect(byApp['agent-app-1'].credentials.botToken).toBe('tok-1');
expect(byApp['agent-app-2'].credentials.botToken).toBe('tok-2');
});
});
describe('findEnabledByPlatform (static)', () => {
it('should return Discord providers with botToken', async () => {
const model = new AgentBotProviderModel(serverDB, userId);
@@ -0,0 +1,89 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import { agentDocuments, agents, documents, users, workspaces } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { AgentDocumentModel } from '../agentDocuments';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'agent-document-workspace-user';
const workspaceId = 'agent-document-workspace';
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Agent Document Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await serverDB.insert(agents).values([
{ id: 'personal-agent-document-agent', title: 'Personal Agent', userId, workspaceId: null },
{ id: 'workspace-agent-document-agent', title: 'Workspace Agent', userId, workspaceId },
]);
});
afterEach(async () => {
await serverDB.delete(users);
});
describe('AgentDocumentModel workspace scope', () => {
it('isolates document reads and deletes between personal and workspace scopes', async () => {
const personalModel = new AgentDocumentModel(serverDB, userId);
const workspaceModel = new AgentDocumentModel(serverDB, userId, workspaceId);
const personalDoc = await personalModel.create(
'personal-agent-document-agent',
'README.md',
'# Personal',
);
const workspaceDoc = await workspaceModel.create(
'workspace-agent-document-agent',
'README.md',
'# Workspace',
);
await expect(personalModel.findById(workspaceDoc.id)).resolves.toBeUndefined();
await expect(workspaceModel.findById(personalDoc.id)).resolves.toBeUndefined();
await expect(
serverDB.query.agentDocuments.findFirst({
where: eq(agentDocuments.id, personalDoc.id),
}),
).resolves.toMatchObject({ id: personalDoc.id, workspaceId: null });
await expect(
serverDB.query.agentDocuments.findFirst({
where: eq(agentDocuments.id, workspaceDoc.id),
}),
).resolves.toMatchObject({ id: workspaceDoc.id, workspaceId });
await expect(personalModel.findByAgent('personal-agent-document-agent')).resolves.toEqual([
expect.objectContaining({ id: personalDoc.id }),
]);
await expect(workspaceModel.findByAgent('workspace-agent-document-agent')).resolves.toEqual([
expect.objectContaining({ id: workspaceDoc.id }),
]);
await personalModel.deleteByAgent('personal-agent-document-agent');
await expect(personalModel.findById(personalDoc.id)).resolves.toBeUndefined();
await expect(workspaceModel.findById(workspaceDoc.id)).resolves.toMatchObject({
id: workspaceDoc.id,
});
await personalModel.permanentlyDelete(workspaceDoc.id);
await expect(workspaceModel.findById(workspaceDoc.id)).resolves.toMatchObject({
id: workspaceDoc.id,
});
});
});
afterEach(async () => {
await serverDB.delete(agentDocuments).where(eq(agentDocuments.userId, userId));
await serverDB.delete(documents).where(eq(documents.userId, userId));
await serverDB.delete(agents).where(eq(agents.userId, userId));
});
@@ -0,0 +1,184 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import {
agentEvalBenchmarks,
agentEvalDatasets,
agentEvalRuns,
agentEvalRunTopics,
agentEvalTestCases,
topics,
users,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import {
AgentEvalBenchmarkModel,
AgentEvalDatasetModel,
AgentEvalRunModel,
AgentEvalRunTopicModel,
AgentEvalTestCaseModel,
} from '../agentEval';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'agent-eval-workspace-user';
const workspaceId = 'agent-eval-workspace';
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Agent Eval Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
});
afterEach(async () => {
await serverDB.delete(users);
});
describe('Agent eval workspace scope', () => {
it('isolates benchmarks, datasets, test cases, runs, and run topics', async () => {
const personalBenchmarkModel = new AgentEvalBenchmarkModel(serverDB, userId);
const workspaceBenchmarkModel = new AgentEvalBenchmarkModel(serverDB, userId, workspaceId);
const personalDatasetModel = new AgentEvalDatasetModel(serverDB, userId);
const workspaceDatasetModel = new AgentEvalDatasetModel(serverDB, userId, workspaceId);
const personalTestCaseModel = new AgentEvalTestCaseModel(serverDB, userId);
const workspaceTestCaseModel = new AgentEvalTestCaseModel(serverDB, userId, workspaceId);
const personalRunModel = new AgentEvalRunModel(serverDB, userId);
const workspaceRunModel = new AgentEvalRunModel(serverDB, userId, workspaceId);
const personalRunTopicModel = new AgentEvalRunTopicModel(serverDB, userId);
const workspaceRunTopicModel = new AgentEvalRunTopicModel(serverDB, userId, workspaceId);
const personalBenchmark = await personalBenchmarkModel.create({
identifier: 'shared-benchmark',
isSystem: false,
name: 'Personal benchmark',
rubrics: [],
});
const workspaceBenchmark = await workspaceBenchmarkModel.create({
identifier: 'shared-benchmark',
isSystem: false,
name: 'Workspace benchmark',
rubrics: [],
});
await expect(
personalBenchmarkModel.findByIdentifier('shared-benchmark'),
).resolves.toMatchObject({
id: personalBenchmark.id,
workspaceId: null,
});
await expect(
workspaceBenchmarkModel.findByIdentifier('shared-benchmark'),
).resolves.toMatchObject({
id: workspaceBenchmark.id,
workspaceId,
});
const personalDataset = await personalDatasetModel.create({
benchmarkId: personalBenchmark.id,
identifier: 'shared-dataset',
name: 'Personal dataset',
});
const workspaceDataset = await workspaceDatasetModel.create({
benchmarkId: workspaceBenchmark.id,
identifier: 'shared-dataset',
name: 'Workspace dataset',
});
await expect(personalDatasetModel.query(personalBenchmark.id)).resolves.toEqual([
expect.objectContaining({ id: personalDataset.id }),
]);
await expect(workspaceDatasetModel.query(workspaceBenchmark.id)).resolves.toEqual([
expect.objectContaining({ id: workspaceDataset.id }),
]);
await expect(personalDatasetModel.findById(personalDataset.id)).resolves.toMatchObject({
id: personalDataset.id,
workspaceId: null,
});
await expect(workspaceDatasetModel.findById(workspaceDataset.id)).resolves.toMatchObject({
id: workspaceDataset.id,
workspaceId,
});
const personalTestCase = await personalTestCaseModel.create({
content: { expected: 'personal', input: 'question' },
datasetId: personalDataset.id,
});
const workspaceTestCase = await workspaceTestCaseModel.create({
content: { expected: 'workspace', input: 'question' },
datasetId: workspaceDataset.id,
});
await expect(personalTestCaseModel.findById(workspaceTestCase.id)).resolves.toBeUndefined();
await expect(workspaceTestCaseModel.findById(personalTestCase.id)).resolves.toBeUndefined();
const personalRun = await personalRunModel.create({
datasetId: personalDataset.id,
name: 'Personal run',
});
const workspaceRun = await workspaceRunModel.create({
datasetId: workspaceDataset.id,
name: 'Workspace run',
});
await expect(personalRunModel.findById(workspaceRun.id)).resolves.toBeUndefined();
await expect(workspaceRunModel.findById(personalRun.id)).resolves.toBeUndefined();
await serverDB.insert(topics).values([
{ id: 'agent-eval-personal-topic', title: 'Personal topic', userId, workspaceId: null },
{ id: 'agent-eval-workspace-topic', title: 'Workspace topic', userId, workspaceId },
]);
await personalRunTopicModel.batchCreate([
{
runId: personalRun.id,
status: 'completed',
testCaseId: personalTestCase.id,
topicId: 'agent-eval-personal-topic',
},
]);
const [workspaceRunTopic] = await workspaceRunTopicModel.batchCreate([
{
runId: workspaceRun.id,
status: 'completed',
testCaseId: workspaceTestCase.id,
topicId: 'agent-eval-workspace-topic',
},
]);
expect(workspaceRunTopic).toMatchObject({ runId: workspaceRun.id, workspaceId });
await expect(personalRunTopicModel.findByRunId(workspaceRun.id)).resolves.toEqual([]);
await expect(workspaceRunTopicModel.findByRunId(personalRun.id)).resolves.toEqual([]);
await expect(workspaceRunTopicModel.findByRunId(workspaceRun.id)).resolves.toEqual([
expect.objectContaining({
runId: workspaceRun.id,
topicId: 'agent-eval-workspace-topic',
}),
]);
await personalBenchmarkModel.delete(personalBenchmark.id);
await expect(personalBenchmarkModel.findById(personalBenchmark.id)).resolves.toBeUndefined();
await expect(workspaceBenchmarkModel.findById(workspaceBenchmark.id)).resolves.toMatchObject({
id: workspaceBenchmark.id,
workspaceId,
});
});
});
afterEach(async () => {
await serverDB.delete(agentEvalRunTopics).where(eq(agentEvalRunTopics.userId, userId));
await serverDB.delete(topics).where(eq(topics.userId, userId));
await serverDB.delete(agentEvalRuns).where(eq(agentEvalRuns.userId, userId));
await serverDB.delete(agentEvalTestCases).where(eq(agentEvalTestCases.userId, userId));
await serverDB.delete(agentEvalDatasets).where(eq(agentEvalDatasets.userId, userId));
await serverDB.delete(agentEvalBenchmarks).where(eq(agentEvalBenchmarks.userId, userId));
});
@@ -0,0 +1,108 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import { agentDocuments, agents, documents, users, workspaces } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { AgentDocumentModel } from '../agentDocuments';
import { AgentSignalReviewContextModel } from '../agentSignal/reviewContext';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'agent-signal-review-workspace-user';
const workspaceId = 'agent-signal-review-workspace';
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Agent Signal Review Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await serverDB.insert(agents).values([
{
chatConfig: { selfIteration: { enabled: true } },
id: 'personal-review-agent',
title: 'Personal Review Agent',
userId,
virtual: false,
workspaceId: null,
},
{
chatConfig: { selfIteration: { enabled: true } },
id: 'workspace-review-agent',
title: 'Workspace Review Agent',
userId,
virtual: false,
workspaceId,
},
]);
});
afterEach(async () => {
await serverDB.delete(users);
});
describe('AgentSignalReviewContextModel workspace scope', () => {
it('isolates self-iteration checks and document activity by workspace', async () => {
const personalContext = new AgentSignalReviewContextModel(serverDB, userId);
const workspaceContext = new AgentSignalReviewContextModel(serverDB, userId, workspaceId);
const personalDocumentModel = new AgentDocumentModel(serverDB, userId);
const workspaceDocumentModel = new AgentDocumentModel(serverDB, userId, workspaceId);
const personalDoc = await personalDocumentModel.create(
'personal-review-agent',
'personal.md',
'# Personal',
);
const workspaceDoc = await workspaceDocumentModel.create(
'workspace-review-agent',
'workspace.md',
'# Workspace',
);
await expect(personalContext.canAgentRunSelfIteration('personal-review-agent')).resolves.toBe(
true,
);
await expect(personalContext.canAgentRunSelfIteration('workspace-review-agent')).resolves.toBe(
false,
);
await expect(workspaceContext.canAgentRunSelfIteration('personal-review-agent')).resolves.toBe(
false,
);
await expect(workspaceContext.canAgentRunSelfIteration('workspace-review-agent')).resolves.toBe(
true,
);
const window = {
agentId: 'workspace-review-agent',
windowEnd: new Date('2100-01-01'),
windowStart: new Date('2000-01-01'),
};
await expect(personalContext.listDocumentActivity(window)).resolves.toEqual([]);
await expect(workspaceContext.listDocumentActivity(window)).resolves.toEqual([
expect.objectContaining({
agentDocumentId: workspaceDoc.id,
documentId: workspaceDoc.documentId,
}),
]);
await expect(
workspaceContext.listDocumentActivity({
...window,
agentId: 'personal-review-agent',
}),
).resolves.toEqual([]);
await expect(personalDocumentModel.findById(personalDoc.id)).resolves.toBeDefined();
});
});
afterEach(async () => {
await serverDB.delete(agentDocuments).where(eq(agentDocuments.userId, userId));
await serverDB.delete(documents).where(eq(documents.userId, userId));
await serverDB.delete(agents).where(eq(agents.userId, userId));
});
@@ -0,0 +1,189 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import {
agentBotProviders,
agents,
agentsToSessions,
chatGroups,
chatGroupsAgents,
messages,
sessions,
topics,
users,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { AgentModel } from '../agent';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'transfer-test-user';
const wsId1 = 'transfer-test-ws-1';
const wsId2 = 'transfer-test-ws-2';
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values([{ id: userId }]);
await serverDB.insert(workspaces).values([
{ id: wsId1, name: 'WS 1', slug: 'ws-1', primaryOwnerId: userId },
{ id: wsId2, name: 'WS 2', slug: 'ws-2', primaryOwnerId: userId },
]);
});
afterEach(async () => {
await serverDB.delete(users);
});
describe('AgentModel.transferAgent', () => {
it('should transfer agent from personal to workspace', async () => {
const model = new AgentModel(serverDB, userId);
const agent = await model.create({ title: 'Test Agent', slug: 'test-agent' });
const result = await model.transferAgent(agent.id, wsId1, userId);
expect(result.agentId).toBe(agent.id);
const updated = await serverDB.query.agents.findFirst({
where: eq(agents.id, agent.id),
});
expect(updated?.workspaceId).toBe(wsId1);
expect(updated?.userId).toBe(userId);
});
it('should transfer agent from workspace to personal', async () => {
const model = new AgentModel(serverDB, userId, wsId1);
const agent = await model.create({ title: 'WS Agent', slug: 'ws-agent' });
const result = await model.transferAgent(agent.id, null, userId);
expect(result.agentId).toBe(agent.id);
const updated = await serverDB.query.agents.findFirst({
where: eq(agents.id, agent.id),
});
expect(updated?.workspaceId).toBeNull();
expect(updated?.userId).toBe(userId);
});
it('should transfer agent between workspaces', async () => {
const model = new AgentModel(serverDB, userId, wsId1);
const agent = await model.create({ title: 'WS1 Agent', slug: 'ws1-agent' });
const result = await model.transferAgent(agent.id, wsId2, userId);
expect(result.agentId).toBe(agent.id);
const updated = await serverDB.query.agents.findFirst({
where: eq(agents.id, agent.id),
});
expect(updated?.workspaceId).toBe(wsId2);
});
it('should handle slug conflict by appending suffix', async () => {
const model = new AgentModel(serverDB, userId, wsId1);
const agent1 = await model.create({ title: 'Agent', slug: 'my-agent' });
// Create an agent with the same slug in target workspace
const model2 = new AgentModel(serverDB, userId, wsId2);
await model2.create({ title: 'Existing Agent', slug: 'my-agent' });
const result = await model.transferAgent(agent1.id, wsId2, userId);
expect(result.slug).toBe('my-agent-1');
const updated = await serverDB.query.agents.findFirst({
where: eq(agents.id, agent1.id),
});
expect(updated?.slug).toBe('my-agent-1');
});
it('should update related sessions and agentsToSessions', async () => {
const model = new AgentModel(serverDB, userId);
const agent = await model.create({ title: 'Agent' });
// Create a session linked to the agent
await serverDB.insert(sessions).values({ id: 'sess-1', userId, type: 'agent' });
await serverDB
.insert(agentsToSessions)
.values({ agentId: agent.id, sessionId: 'sess-1', userId });
await model.transferAgent(agent.id, wsId1, userId);
const [session] = await serverDB.select().from(sessions).where(eq(sessions.id, 'sess-1'));
expect(session.workspaceId).toBe(wsId1);
const [link] = await serverDB
.select()
.from(agentsToSessions)
.where(eq(agentsToSessions.agentId, agent.id));
expect(link.workspaceId).toBe(wsId1);
});
it('should update topics and messages', async () => {
const model = new AgentModel(serverDB, userId);
const agent = await model.create({ title: 'Agent' });
await serverDB.insert(topics).values({ id: 'topic-1', agentId: agent.id, userId });
await serverDB
.insert(messages)
.values({ id: 'msg-1', agentId: agent.id, userId, role: 'assistant' });
await model.transferAgent(agent.id, wsId1, userId);
const [topic] = await serverDB.select().from(topics).where(eq(topics.id, 'topic-1'));
expect(topic.workspaceId).toBe(wsId1);
const [msg] = await serverDB.select().from(messages).where(eq(messages.id, 'msg-1'));
expect(msg.workspaceId).toBe(wsId1);
});
it('should update bot providers', async () => {
const model = new AgentModel(serverDB, userId);
const agent = await model.create({ title: 'Agent' });
await serverDB.insert(agentBotProviders).values({
agentId: agent.id,
userId,
platform: 'discord',
applicationId: 'app-1',
credentials: 'encrypted-creds',
});
await model.transferAgent(agent.id, wsId1, userId);
const [bot] = await serverDB
.select()
.from(agentBotProviders)
.where(eq(agentBotProviders.agentId, agent.id));
expect(bot.workspaceId).toBe(wsId1);
expect(bot.userId).toBe(userId);
});
it('should remove chat group associations', async () => {
const model = new AgentModel(serverDB, userId);
const agent = await model.create({ title: 'Agent' });
await serverDB.insert(chatGroups).values({ id: 'group-1', userId });
await serverDB
.insert(chatGroupsAgents)
.values({ chatGroupId: 'group-1', agentId: agent.id, userId });
await model.transferAgent(agent.id, wsId1, userId);
const groupLinks = await serverDB
.select()
.from(chatGroupsAgents)
.where(eq(chatGroupsAgents.agentId, agent.id));
expect(groupLinks).toHaveLength(0);
});
it('should throw when agent not found', async () => {
const model = new AgentModel(serverDB, userId);
await expect(model.transferAgent('nonexistent', wsId1, userId)).rejects.toThrow(
'Agent not found',
);
});
});
@@ -7,11 +7,18 @@ import type { LobeChatDatabase } from '@/database/type';
import { getTestDB } from '../../core/getTestDB';
import type { NewChatGroup } from '../../schemas';
import { agents as agentsTable, chatGroups, chatGroupsAgents, users } from '../../schemas';
import {
agents as agentsTable,
chatGroups,
chatGroupsAgents,
users,
workspaces,
} from '../../schemas';
import { ChatGroupModel } from '../chatGroup';
const userId = 'test-user';
const otherUserId = 'other-user';
const workspaceId = 'chat-group-workspace';
const serverDB: LobeChatDatabase = await getTestDB();
@@ -26,11 +33,18 @@ type RelationAgent = {
const toRelationAgents = (agents: unknown): RelationAgent[] => agents as RelationAgent[];
const chatGroupModel = new ChatGroupModel(serverDB, userId);
const workspaceChatGroupModel = new ChatGroupModel(serverDB, otherUserId, workspaceId);
beforeEach(async () => {
await serverDB.delete(users);
// Create test users
await serverDB.insert(users).values([{ id: userId }, { id: otherUserId }]);
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Chat Group Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
});
afterEach(async () => {
@@ -983,5 +997,32 @@ describe('ChatGroupModel', () => {
expect(result[0].id).toBe('user-group');
expect(result[0].userId).toBe(userId);
});
it('should return workspace groups for members even when rows were created by another user', async () => {
await serverDB.transaction(async (trx) => {
await trx.insert(chatGroups).values({
id: 'workspace-group',
title: 'Workspace Group',
userId,
workspaceId,
});
await trx.insert(agentsTable).values({
id: 'workspace-agent',
title: 'Workspace Agent',
userId,
workspaceId,
});
await trx.insert(chatGroupsAgents).values({
agentId: 'workspace-agent',
chatGroupId: 'workspace-group',
userId,
workspaceId,
});
});
const result = await workspaceChatGroupModel.getGroupsWithAgents(['workspace-agent']);
expect(result).toEqual([expect.objectContaining({ id: 'workspace-group', workspaceId })]);
});
});
});
@@ -5,7 +5,15 @@ import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { uuid } from '@/utils/uuid';
import { getTestDB } from '../../core/getTestDB';
import { chunks, embeddings, fileChunks, files, unstructuredChunks, users } from '../../schemas';
import {
chunks,
embeddings,
fileChunks,
files,
unstructuredChunks,
users,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { ChunkModel } from '../chunk';
import { codeEmbedding, designThinkingQuery, designThinkingQuery2 } from './fixtures/embedding';
@@ -13,6 +21,7 @@ import { codeEmbedding, designThinkingQuery, designThinkingQuery2 } from './fixt
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'chunk-model-test-user-id';
const workspaceId = 'chunk-model-workspace';
const chunkModel = new ChunkModel(serverDB, userId);
const sharedFileList = [
{
@@ -44,6 +53,12 @@ const sharedFileList = [
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values([{ id: userId }]);
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Chunk Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await serverDB.insert(files).values(sharedFileList);
});
@@ -382,6 +397,27 @@ describe('ChunkModel', () => {
expect(result).toHaveLength(0);
});
it('should not count workspace chunks from personal scope', async () => {
await serverDB.insert(files).values({
id: 'workspace-file',
name: 'workspace.pdf',
url: 'https://example.com/workspace.pdf',
size: 1000,
fileType: 'application/pdf',
userId,
workspaceId,
});
const [chunk] = await serverDB
.insert(chunks)
.values({ text: 'Workspace Chunk', userId, workspaceId })
.returning();
await serverDB
.insert(fileChunks)
.values({ chunkId: chunk.id, fileId: 'workspace-file', userId, workspaceId });
await expect(chunkModel.countByFileIds(['workspace-file'])).resolves.toHaveLength(0);
});
});
describe('countByFileId', () => {
@@ -0,0 +1,206 @@
// @vitest-environment node
import { eq, inArray } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import { DOCUMENT_FOLDER_TYPE, documents, files, users, workspaces } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { DocumentModel } from '../document';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'doc-transfer-test-user';
const wsId1 = 'doc-transfer-test-ws-1';
const wsId2 = 'doc-transfer-test-ws-2';
const createFolder = async (
model: DocumentModel,
filename: string,
slug: string,
parentId?: string,
) =>
model.create({
content: '',
fileType: DOCUMENT_FOLDER_TYPE,
filename,
parentId,
slug,
source: '',
sourceType: 'api',
title: filename,
totalCharCount: 0,
totalLineCount: 0,
});
const createPage = async (
model: DocumentModel,
filename: string,
slug: string,
parentId?: string,
) =>
model.create({
content: 'hello',
fileType: 'page',
filename,
parentId,
slug,
source: '',
sourceType: 'api',
title: filename,
totalCharCount: 5,
totalLineCount: 1,
});
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values([{ id: userId }]);
await serverDB.insert(workspaces).values([
{ id: wsId1, name: 'Doc WS 1', slug: 'doc-ws-1', primaryOwnerId: userId },
{ id: wsId2, name: 'Doc WS 2', slug: 'doc-ws-2', primaryOwnerId: userId },
]);
});
afterEach(async () => {
await serverDB.delete(users);
});
describe('DocumentModel.transferTo', () => {
it('transfers a single page from personal to workspace', async () => {
const model = new DocumentModel(serverDB, userId);
const page = await createPage(model, 'My Page', 'my-page');
const result = await model.transferTo(page.id, wsId1, userId);
expect(result.documentIds).toEqual([page.id]);
const updated = await serverDB.query.documents.findFirst({ where: eq(documents.id, page.id) });
expect(updated?.workspaceId).toBe(wsId1);
expect(updated?.userId).toBe(userId);
});
it('transfers a folder and all descendants', async () => {
const model = new DocumentModel(serverDB, userId);
const folder = await createFolder(model, 'Folder', 'folder-1');
const child = await createPage(model, 'Child', 'child-1', folder.id);
const subFolder = await createFolder(model, 'Sub', 'sub-1', folder.id);
const grandchild = await createPage(model, 'Grand', 'grand-1', subFolder.id);
const result = await model.transferTo(folder.id, wsId1, userId);
expect(result.documentIds.sort()).toEqual(
[folder.id, child.id, subFolder.id, grandchild.id].sort(),
);
const rows = await serverDB
.select({ id: documents.id, workspaceId: documents.workspaceId })
.from(documents)
.where(inArray(documents.id, result.documentIds));
for (const row of rows) expect(row.workspaceId).toBe(wsId1);
});
it('resolves slug conflicts by suffixing', async () => {
const ws1 = new DocumentModel(serverDB, userId, wsId1);
await createPage(ws1, 'Existing', 'shared-slug');
const personal = new DocumentModel(serverDB, userId);
const mine = await createPage(personal, 'Mine', 'shared-slug');
await personal.transferTo(mine.id, wsId1, userId);
const updated = await serverDB.query.documents.findFirst({ where: eq(documents.id, mine.id) });
expect(updated?.slug).toBe('shared-slug-1');
expect(updated?.workspaceId).toBe(wsId1);
});
it('moves files anchored to documents in the transferred subtree', async () => {
const model = new DocumentModel(serverDB, userId);
const folder = await createFolder(model, 'Folder', 'transfer-folder');
await serverDB.insert(files).values({
id: 'file-x',
userId,
fileType: 'image/png',
name: 'pic.png',
size: 10,
url: 'http://x',
parentId: folder.id,
});
await model.transferTo(folder.id, wsId1, userId);
const [file] = await serverDB.select().from(files).where(eq(files.id, 'file-x'));
expect(file.workspaceId).toBe(wsId1);
expect(file.userId).toBe(userId);
});
it('transfers from workspace back to personal', async () => {
const ws = new DocumentModel(serverDB, userId, wsId1);
const page = await createPage(ws, 'In WS', 'in-ws');
await ws.transferTo(page.id, null, userId);
const updated = await serverDB.query.documents.findFirst({ where: eq(documents.id, page.id) });
expect(updated?.workspaceId).toBeNull();
});
});
describe('DocumentModel.copyToWorkspace', () => {
it('clones a single page into the target workspace with a fresh id', async () => {
const model = new DocumentModel(serverDB, userId);
const page = await createPage(model, 'Page', 'page-x');
const { rootId } = await model.copyToWorkspace(page.id, wsId1, userId);
expect(rootId).not.toBe(page.id);
const clone = await serverDB.query.documents.findFirst({ where: eq(documents.id, rootId) });
expect(clone?.workspaceId).toBe(wsId1);
expect(clone?.title).toBe('Page');
expect(clone?.content).toBe('hello');
// Original untouched
const original = await serverDB.query.documents.findFirst({ where: eq(documents.id, page.id) });
expect(original?.workspaceId).toBeNull();
});
it('clones a folder + descendants preserving the parent topology', async () => {
const model = new DocumentModel(serverDB, userId);
const folder = await createFolder(model, 'Folder', 'copy-folder');
const child = await createPage(model, 'Child', 'copy-child', folder.id);
const sub = await createFolder(model, 'Sub', 'copy-sub', folder.id);
const grand = await createPage(model, 'Grand', 'copy-grand', sub.id);
const { rootId } = await model.copyToWorkspace(folder.id, wsId1, userId);
const cloned = await serverDB.select().from(documents).where(eq(documents.workspaceId, wsId1));
expect(cloned).toHaveLength(4);
const root = cloned.find((d) => d.id === rootId)!;
expect(root.parentId).toBeNull();
const childrenOfRoot = cloned.filter((d) => d.parentId === rootId);
expect(childrenOfRoot).toHaveLength(2);
// Locate cloned sub folder, then grandchild beneath it
const clonedSub = childrenOfRoot.find((d) => d.title === 'Sub')!;
const clonedGrand = cloned.find((d) => d.parentId === clonedSub.id)!;
expect(clonedGrand.title).toBe('Grand');
// Verify originals untouched
const originals = await serverDB
.select()
.from(documents)
.where(inArray(documents.id, [folder.id, child.id, sub.id, grand.id]));
for (const row of originals) expect(row.workspaceId).toBeNull();
});
it('reassigns slug on conflict in target scope', async () => {
const ws1 = new DocumentModel(serverDB, userId, wsId1);
await createPage(ws1, 'Existing', 'dupe-slug');
const personal = new DocumentModel(serverDB, userId);
const mine = await createPage(personal, 'Mine', 'dupe-slug');
const { rootId } = await personal.copyToWorkspace(mine.id, wsId1, userId);
const clone = await serverDB.query.documents.findFirst({ where: eq(documents.id, rootId) });
expect(clone?.slug).toBe('dupe-slug-1');
});
});
@@ -13,6 +13,7 @@ import {
generations,
generationTopics,
users,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { GenerationModel } from '../generation';
@@ -37,6 +38,7 @@ vi.mock('../file', () => ({
const userId = 'generation-test-user-id';
const otherUserId = 'other-user-id';
const workspaceId = 'generation-workspace';
const generationModel = new GenerationModel(serverDB, userId);
// Test data
@@ -101,6 +103,12 @@ beforeEach(async () => {
// Clear database and create test users
await serverDB.delete(users);
await serverDB.insert(users).values([{ id: userId }, { id: otherUserId }]);
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Generation Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
// Create test topic
await serverDB.insert(generationTopics).values(testTopic);
@@ -956,5 +964,36 @@ describe('GenerationModel', () => {
);
expect(result).toBeUndefined();
});
it('should not return workspace generation from personal scope', async () => {
const workspaceAsyncTaskId = '550e8400-e29b-41d4-a716-446655440111';
await serverDB.insert(generationTopics).values({
...testTopic,
id: 'workspace-topic-id',
workspaceId,
});
await serverDB.insert(generationBatches).values({
...testBatch,
id: 'workspace-batch-id',
generationTopicId: 'workspace-topic-id',
workspaceId,
});
await serverDB.insert(asyncTasks).values({
...testAsyncTask,
id: workspaceAsyncTaskId,
workspaceId,
});
await serverDB.insert(generations).values({
...testGeneration,
asyncTaskId: workspaceAsyncTaskId,
generationBatchId: 'workspace-batch-id',
userId,
workspaceId,
});
await expect(
generationModel.findByAsyncTaskId(workspaceAsyncTaskId),
).resolves.toBeUndefined();
});
});
});
@@ -13,6 +13,7 @@ import {
knowledgeBaseFiles,
knowledgeBases,
users,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { KnowledgeBaseModel } from '../knowledgeBase';
@@ -156,6 +157,15 @@ describe('KnowledgeBaseModel', () => {
},
];
const createWorkspace = async (id: string, slug: string) => {
await serverDB.insert(workspaces).values({
id,
name: slug,
primaryOwnerId: userId,
slug,
});
};
describe('addFilesToKnowledgeBase', () => {
it('should add files to a knowledge base', async () => {
await serverDB.insert(globalFiles).values([
@@ -683,30 +693,26 @@ describe('KnowledgeBaseModel', () => {
});
it('should return empty array when all files are shared', async () => {
await serverDB
.insert(globalFiles)
.values([
{
hashId: 'hash1',
url: 'https://example.com/a.pdf',
size: 100,
fileType: 'application/pdf',
creator: userId,
},
]);
await serverDB
.insert(files)
.values([
{
id: 'file1',
name: 'a.pdf',
url: 'https://example.com/a.pdf',
fileHash: 'hash1',
size: 100,
fileType: 'application/pdf',
userId,
},
]);
await serverDB.insert(globalFiles).values([
{
hashId: 'hash1',
url: 'https://example.com/a.pdf',
size: 100,
fileType: 'application/pdf',
creator: userId,
},
]);
await serverDB.insert(files).values([
{
id: 'file1',
name: 'a.pdf',
url: 'https://example.com/a.pdf',
fileHash: 'hash1',
size: 100,
fileType: 'application/pdf',
userId,
},
]);
const { id: kb1 } = await knowledgeBaseModel.create({ name: 'KB1' });
const { id: kb2 } = await knowledgeBaseModel.create({ name: 'KB2' });
await knowledgeBaseModel.addFilesToKnowledgeBase(kb1, ['file1']);
@@ -767,30 +773,26 @@ describe('KnowledgeBaseModel', () => {
describe('deleteWithFiles', () => {
it('should delete KB and its exclusive files', async () => {
await serverDB
.insert(globalFiles)
.values([
{
hashId: 'hash1',
url: 'https://example.com/a.pdf',
size: 100,
fileType: 'application/pdf',
creator: userId,
},
]);
await serverDB
.insert(files)
.values([
{
id: 'file1',
name: 'a.pdf',
url: 'https://example.com/a.pdf',
fileHash: 'hash1',
size: 100,
fileType: 'application/pdf',
userId,
},
]);
await serverDB.insert(globalFiles).values([
{
hashId: 'hash1',
url: 'https://example.com/a.pdf',
size: 100,
fileType: 'application/pdf',
creator: userId,
},
]);
await serverDB.insert(files).values([
{
id: 'file1',
name: 'a.pdf',
url: 'https://example.com/a.pdf',
fileHash: 'hash1',
size: 100,
fileType: 'application/pdf',
userId,
},
]);
const { id: kbId } = await knowledgeBaseModel.create({ name: 'KB1' });
await knowledgeBaseModel.addFilesToKnowledgeBase(kbId, ['file1']);
const result = await knowledgeBaseModel.deleteWithFiles(kbId);
@@ -931,30 +933,26 @@ describe('KnowledgeBaseModel', () => {
});
it('should delete shared file when both KBs sharing it are deleted', async () => {
await serverDB
.insert(globalFiles)
.values([
{
hashId: 'hash1',
url: 'https://example.com/a.pdf',
size: 100,
fileType: 'application/pdf',
creator: userId,
},
]);
await serverDB
.insert(files)
.values([
{
id: 'file1',
name: 'a.pdf',
url: 'https://example.com/a.pdf',
fileHash: 'hash1',
size: 100,
fileType: 'application/pdf',
userId,
},
]);
await serverDB.insert(globalFiles).values([
{
hashId: 'hash1',
url: 'https://example.com/a.pdf',
size: 100,
fileType: 'application/pdf',
creator: userId,
},
]);
await serverDB.insert(files).values([
{
id: 'file1',
name: 'a.pdf',
url: 'https://example.com/a.pdf',
fileHash: 'hash1',
size: 100,
fileType: 'application/pdf',
userId,
},
]);
const { id: kb1 } = await knowledgeBaseModel.create({ name: 'KB1' });
const { id: kb2 } = await knowledgeBaseModel.create({ name: 'KB2' });
await knowledgeBaseModel.addFilesToKnowledgeBase(kb1, ['file1']);
@@ -976,6 +974,189 @@ describe('KnowledgeBaseModel', () => {
});
});
describe('transferTo', () => {
it('should transfer a knowledge base and its resources to another workspace', async () => {
await createWorkspace('workspace-target', 'workspace-target');
await serverDB.insert(globalFiles).values([
{
hashId: 'hash-transfer',
url: 'https://example.com/transfer.pdf',
size: 1000,
fileType: 'application/pdf',
creator: userId,
},
]);
await serverDB.insert(files).values({
id: 'file-transfer',
name: 'transfer.pdf',
url: 'https://example.com/transfer.pdf',
fileHash: 'hash-transfer',
size: 1000,
fileType: 'application/pdf',
userId,
});
const { id: knowledgeBaseId } = await knowledgeBaseModel.create({ name: 'Transfer KB' });
await serverDB.insert(documents).values({
id: 'docs_transfer_folder',
title: 'Folder',
content: '',
fileType: 'custom/folder',
totalCharCount: 0,
totalLineCount: 0,
sourceType: 'api',
source: '',
knowledgeBaseId,
userId,
});
await knowledgeBaseModel.addFilesToKnowledgeBase(knowledgeBaseId, ['file-transfer']);
await knowledgeBaseModel.transferTo(knowledgeBaseId, 'workspace-target', userId);
const transferredKb = await serverDB.query.knowledgeBases.findFirst({
where: eq(knowledgeBases.id, knowledgeBaseId),
});
const transferredFile = await serverDB.query.files.findFirst({
where: eq(files.id, 'file-transfer'),
});
const transferredDocument = await serverDB.query.documents.findFirst({
where: eq(documents.id, 'docs_transfer_folder'),
});
const transferredLink = await serverDB.query.knowledgeBaseFiles.findFirst({
where: eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId),
});
expect(transferredKb?.workspaceId).toBe('workspace-target');
expect(transferredFile?.workspaceId).toBe('workspace-target');
expect(transferredDocument?.workspaceId).toBe('workspace-target');
expect(transferredLink?.workspaceId).toBe('workspace-target');
});
it('should rename the transferred knowledge base when the target has the same name', async () => {
await createWorkspace('workspace-rename-target', 'workspace-rename-target');
const targetModel = new KnowledgeBaseModel(serverDB, userId, 'workspace-rename-target');
await targetModel.create({ name: 'Shared KB' });
const { id: knowledgeBaseId } = await knowledgeBaseModel.create({ name: 'Shared KB' });
await knowledgeBaseModel.transferTo(knowledgeBaseId, 'workspace-rename-target', userId);
const transferredKb = await serverDB.query.knowledgeBases.findFirst({
where: eq(knowledgeBases.id, knowledgeBaseId),
});
expect(transferredKb?.name).toBe('Shared KB (1)');
});
});
describe('copyToWorkspace', () => {
it('should copy a knowledge base with files and document hierarchy to another workspace', async () => {
await createWorkspace('workspace-copy-target', 'workspace-copy-target');
await serverDB.insert(globalFiles).values([
{
hashId: 'hash-copy',
url: 'https://example.com/copy.pdf',
size: 1000,
fileType: 'application/pdf',
creator: userId,
},
]);
const { id: knowledgeBaseId } = await knowledgeBaseModel.create({ name: 'Copy KB' });
await serverDB.insert(documents).values([
{
id: 'docs_copy_folder',
title: 'Folder',
content: '',
fileType: 'custom/folder',
totalCharCount: 0,
totalLineCount: 0,
sourceType: 'api',
source: '',
knowledgeBaseId,
userId,
},
{
id: 'docs_copy_note',
title: 'Note',
content: 'note content',
fileType: 'custom/document',
totalCharCount: 12,
totalLineCount: 1,
sourceType: 'api',
source: '',
knowledgeBaseId,
parentId: 'docs_copy_folder',
userId,
},
]);
await serverDB.insert(files).values({
id: 'file-copy',
name: 'copy.pdf',
url: 'https://example.com/copy.pdf',
fileHash: 'hash-copy',
size: 1000,
fileType: 'application/pdf',
parentId: 'docs_copy_folder',
userId,
});
await knowledgeBaseModel.addFilesToKnowledgeBase(knowledgeBaseId, ['file-copy']);
const result = await knowledgeBaseModel.copyToWorkspace(
knowledgeBaseId,
'workspace-copy-target',
userId,
);
expect(result.id).not.toBe(knowledgeBaseId);
const copiedKb = await serverDB.query.knowledgeBases.findFirst({
where: eq(knowledgeBases.id, result.id),
});
const copiedLinks = await serverDB.query.knowledgeBaseFiles.findMany({
where: eq(knowledgeBaseFiles.knowledgeBaseId, result.id),
});
const copiedDocs = await serverDB.query.documents.findMany({
where: eq(documents.knowledgeBaseId, result.id),
});
const originalKb = await serverDB.query.knowledgeBases.findFirst({
where: eq(knowledgeBases.id, knowledgeBaseId),
});
expect(copiedKb).toMatchObject({
name: 'Copy KB',
workspaceId: 'workspace-copy-target',
});
expect(copiedLinks).toHaveLength(1);
expect(copiedLinks[0].fileId).not.toBe('file-copy');
expect(copiedLinks[0].workspaceId).toBe('workspace-copy-target');
expect(copiedDocs).toHaveLength(2);
expect(copiedDocs.every((doc) => doc.workspaceId === 'workspace-copy-target')).toBe(true);
expect(copiedDocs.find((doc) => doc.title === 'Note')?.parentId).toBe(
copiedDocs.find((doc) => doc.title === 'Folder')?.id,
);
expect(originalKb?.workspaceId).toBeNull();
});
it('should rename the copied knowledge base when the target has the same name', async () => {
await createWorkspace('workspace-copy-rename-target', 'workspace-copy-rename-target');
const targetModel = new KnowledgeBaseModel(serverDB, userId, 'workspace-copy-rename-target');
await targetModel.create({ name: 'Shared KB' });
const { id: knowledgeBaseId } = await knowledgeBaseModel.create({ name: 'Shared KB' });
const result = await knowledgeBaseModel.copyToWorkspace(
knowledgeBaseId,
'workspace-copy-rename-target',
userId,
);
const copiedKb = await serverDB.query.knowledgeBases.findFirst({
where: eq(knowledgeBases.id, result.id),
});
expect(copiedKb?.name).toBe('Shared KB (1)');
});
});
describe('static findById', () => {
it('should find a knowledge base by id without user restriction', async () => {
const { id } = await knowledgeBaseModel.create({ name: 'Test Group' });
@@ -15,6 +15,7 @@ import {
sessions,
topics,
users,
workspaces,
} from '../../../schemas';
import type { LobeChatDatabase } from '../../../type';
import { MessageModel } from '../../message';
@@ -24,7 +25,9 @@ const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'message-update-test';
const otherUserId = 'message-update-test-other';
const workspaceId = 'message-update-workspace';
const messageModel = new MessageModel(serverDB, userId);
const workspaceMessageModel = new MessageModel(serverDB, otherUserId, workspaceId);
const embeddingsId = uuid();
beforeEach(async () => {
@@ -33,6 +36,12 @@ beforeEach(async () => {
await trx.delete(users).where(eq(users.id, userId));
await trx.delete(users).where(eq(users.id, otherUserId));
await trx.insert(users).values([{ id: userId }, { id: otherUserId }]);
await trx.insert(workspaces).values({
id: workspaceId,
name: 'Message Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await trx.insert(sessions).values([
// { id: 'session1', userId },
@@ -950,6 +959,30 @@ describe('MessageModel Update Tests', () => {
expect(dbResult[0].metadata).toEqual({ originalKey: 'originalValue' });
});
it('should update workspace messages even when created by another user', async () => {
await serverDB.insert(messages).values({
id: 'msg-workspace-metadata',
userId,
workspaceId,
role: 'user',
content: 'test message',
metadata: { originalKey: 'originalValue' },
});
await workspaceMessageModel.updateMetadata('msg-workspace-metadata', {
workspaceKey: 'workspaceValue',
});
const dbResult = await serverDB
.select()
.from(messages)
.where(eq(messages.id, 'msg-workspace-metadata'));
expect(dbResult[0].metadata).toEqual({
originalKey: 'originalValue',
workspaceKey: 'workspaceValue',
});
});
it('should handle complex nested metadata updates', async () => {
// Create test data
await serverDB.insert(messages).values({
@@ -1273,6 +1306,33 @@ describe('MessageModel Update Tests', () => {
expect(result[0].content).toBe('translated message 1');
});
it('should insert workspaceId for workspace translate records', async () => {
await serverDB.insert(messages).values({
id: 'workspace-translate',
userId,
workspaceId,
role: 'user',
content: 'message 1',
});
await workspaceMessageModel.updateTranslate('workspace-translate', {
content: 'translated message 1',
from: 'en',
to: 'zh',
});
const result = await serverDB
.select()
.from(messageTranslates)
.where(eq(messageTranslates.id, 'workspace-translate'));
expect(result[0]).toMatchObject({
id: 'workspace-translate',
userId: otherUserId,
workspaceId,
});
});
it('should update the corresponding fields if message exists in messageTranslates table', async () => {
// Create test data
await serverDB.transaction(async (trx) => {
@@ -1314,6 +1374,29 @@ describe('MessageModel Update Tests', () => {
expect(result[0].voice).toBe('voice1');
});
it('should insert workspaceId for workspace TTS records', async () => {
await serverDB.insert(messages).values({
id: 'workspace-tts',
userId,
workspaceId,
role: 'user',
content: 'message 1',
});
await workspaceMessageModel.updateTTS('workspace-tts', {
contentMd5: 'md5',
file: 'f1',
voice: 'voice1',
});
const result = await serverDB
.select()
.from(messageTTS)
.where(eq(messageTTS.id, 'workspace-tts'));
expect(result[0]).toMatchObject({ id: 'workspace-tts', userId: otherUserId, workspaceId });
});
it('should update the corresponding fields if message exists in messageTTS table', async () => {
// Create test data
await serverDB.transaction(async (trx) => {
@@ -0,0 +1,73 @@
// @vitest-environment node
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../../core/getTestDB';
import { messages, sessions, topics, users, workspaces } from '../../../schemas';
import type { LobeChatDatabase } from '../../../type';
import { MessageModel } from '../../message';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'message-workspace-user';
const workspaceId = 'message-workspace';
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await serverDB.insert(sessions).values([
{ id: 'personal-session', userId, workspaceId: null },
{ id: 'workspace-session', userId, workspaceId },
]);
await serverDB.insert(topics).values([
{ id: 'personal-topic', sessionId: 'personal-session', userId, workspaceId: null },
{ id: 'workspace-topic', sessionId: 'workspace-session', userId, workspaceId },
]);
});
afterEach(async () => {
await serverDB.delete(users);
});
describe('MessageModel workspace scope', () => {
it('isolates personal and workspace messages for the same user', async () => {
await serverDB.insert(messages).values([
{
content: 'personal',
id: 'personal-message',
role: 'user',
sessionId: 'personal-session',
topicId: 'personal-topic',
userId,
workspaceId: null,
},
{
content: 'workspace',
id: 'workspace-message',
role: 'user',
sessionId: 'workspace-session',
topicId: 'workspace-topic',
userId,
workspaceId,
},
]);
await expect(
new MessageModel(serverDB, userId).query({
sessionId: 'personal-session',
topicId: 'personal-topic',
}),
).resolves.toEqual([expect.objectContaining({ id: 'personal-message' })]);
await expect(
new MessageModel(serverDB, userId, workspaceId).query({
sessionId: 'workspace-session',
topicId: 'workspace-topic',
}),
).resolves.toEqual([expect.objectContaining({ id: 'workspace-message' })]);
});
});
@@ -2,7 +2,7 @@
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import { agents, messengerAccountLinks, users } from '../../schemas';
import { agents, messengerAccountLinks, users, workspaces } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import {
MessengerAccountLinkConflictError,
@@ -16,19 +16,29 @@ const userA = 'msg-link-user-a';
const userB = 'msg-link-user-b';
const agentA = 'msg-link-agent-a';
const agentB = 'msg-link-agent-b';
const workspaceA = 'msg-link-workspace-a';
const workspaceAgentA = 'msg-link-agent-workspace-a';
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values([{ id: userA }, { id: userB }]);
await serverDB.insert(workspaces).values({
id: workspaceA,
name: 'Workspace A',
primaryOwnerId: userA,
slug: 'workspace-a',
});
await serverDB.insert(agents).values([
{ id: agentA, userId: userA },
{ id: agentB, userId: userB },
{ id: workspaceAgentA, userId: userA, workspaceId: workspaceA },
]);
});
afterEach(async () => {
await serverDB.delete(messengerAccountLinks);
await serverDB.delete(agents);
await serverDB.delete(workspaces);
await serverDB.delete(users);
});
@@ -225,6 +235,56 @@ describe('MessengerAccountLinkModel', () => {
});
});
describe('active scope (workspaceId)', () => {
// A given IM identity has exactly one link; `workspaceId` on it is the
// *active scope* derived from the active agent (personal → null), not part
// of the link's identity. Switching scope reuses the same row.
it('persists the active scope passed at upsert time', async () => {
const model = new MessengerAccountLinkModel(serverDB, userA);
const personal = await model.upsertForPlatform({
activeAgentId: agentA,
platform: 'telegram',
platformUserId: 'tg-scope',
workspaceId: null,
});
expect(personal.workspaceId).toBeNull();
// Re-asserting the same identity with a workspace agent flips the active
// scope on the same row — no relink, single identity link.
const switched = await model.upsertForPlatform({
activeAgentId: workspaceAgentA,
platform: 'telegram',
platformUserId: 'tg-scope',
workspaceId: workspaceA,
});
expect(switched.id).toBe(personal.id);
expect(switched.workspaceId).toBe(workspaceA);
expect(switched.activeAgentId).toBe(workspaceAgentA);
});
it('setActiveAgent updates both the active agent and the derived scope', async () => {
const model = new MessengerAccountLinkModel(serverDB, userA);
await model.upsertForPlatform({
activeAgentId: agentA,
platform: 'telegram',
platformUserId: 'tg-switch',
workspaceId: null,
});
// Switch into a workspace agent.
await model.setActiveAgent('telegram', workspaceAgentA, workspaceA);
let link = await model.findByPlatform('telegram');
expect(link?.activeAgentId).toBe(workspaceAgentA);
expect(link?.workspaceId).toBe(workspaceA);
// Switch back to personal.
await model.setActiveAgent('telegram', agentA, null);
link = await model.findByPlatform('telegram');
expect(link?.activeAgentId).toBe(agentA);
expect(link?.workspaceId).toBeNull();
});
});
describe('setActiveAgent', () => {
it('only updates the targeted (platform, tenant) row', async () => {
const model = new MessengerAccountLinkModel(serverDB, userA);
@@ -241,7 +301,7 @@ describe('MessengerAccountLinkModel', () => {
tenantId: 'T_BETA',
});
await model.setActiveAgent('slack', null, 'T_ACME');
await model.setActiveAgent('slack', null, null, 'T_ACME');
const acme = await model.findByPlatform('slack', 'T_ACME');
const beta = await model.findByPlatform('slack', 'T_BETA');
@@ -0,0 +1,43 @@
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { NotificationModel } from '../../models/notification';
import { notifications } from '../../schemas/notification';
import type { LobeChatDatabase } from '../../type';
describe('NotificationModel', () => {
const returning = vi.fn();
const onConflictDoNothing = vi.fn(() => ({ returning }));
const values = vi.fn((_payload?: unknown) => ({ onConflictDoNothing }));
const insert = vi.fn(() => ({ values }));
const db = { insert } as unknown as LobeChatDatabase;
beforeEach(() => {
vi.clearAllMocks();
returning.mockResolvedValue([{ id: 'notification-1' }]);
});
describe('create', () => {
it('creates user-scoped notifications without persisting workspace context', async () => {
const model = new NotificationModel(db, 'user-1');
await model.create({
category: 'workspace',
content: 'You have been removed from the workspace.',
dedupeKey: 'member_removed_workspace-1_user-1',
title: 'Removed from workspace',
type: 'workspace_member_removed',
});
const [payload] = values.mock.calls[0];
expect(payload).toMatchObject({
dedupeKey: 'member_removed_workspace-1_user-1',
userId: 'user-1',
});
expect(payload).not.toHaveProperty('workspaceId');
expect(onConflictDoNothing).toHaveBeenCalledWith({
target: [notifications.userId, notifications.dedupeKey],
});
});
});
});
@@ -0,0 +1,204 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import {
evalDatasetRecords,
evalDatasets,
evalEvaluation,
evaluationRecords,
knowledgeBases,
users,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import {
EvalDatasetModel,
EvalDatasetRecordModel,
EvalEvaluationModel,
EvaluationRecordModel,
} from '../ragEval';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'rag-eval-workspace-user';
const workspaceId = 'rag-eval-workspace';
const personalKnowledgeBaseId = 'rag-eval-personal-kb';
const workspaceKnowledgeBaseId = 'rag-eval-workspace-kb';
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'RAG Eval Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await serverDB.insert(knowledgeBases).values([
{
id: personalKnowledgeBaseId,
name: 'Personal KB',
userId,
workspaceId: null,
},
{
id: workspaceKnowledgeBaseId,
name: 'Workspace KB',
userId,
workspaceId,
},
]);
});
afterEach(async () => {
await serverDB.delete(users);
});
describe('RAG eval workspace scope', () => {
it('isolates datasets and dataset records between personal and workspace scopes', async () => {
const personalDatasetModel = new EvalDatasetModel(serverDB, userId);
const workspaceDatasetModel = new EvalDatasetModel(serverDB, userId, workspaceId);
const personalDataset = await personalDatasetModel.create({
knowledgeBaseId: personalKnowledgeBaseId,
name: 'Personal dataset',
});
const workspaceDataset = await workspaceDatasetModel.create({
knowledgeBaseId: workspaceKnowledgeBaseId,
name: 'Workspace dataset',
});
await expect(personalDatasetModel.query(personalKnowledgeBaseId)).resolves.toEqual([
expect.objectContaining({ id: personalDataset.id }),
]);
await expect(workspaceDatasetModel.query(workspaceKnowledgeBaseId)).resolves.toEqual([
expect.objectContaining({ id: workspaceDataset.id }),
]);
await expect(personalDatasetModel.findById(personalDataset.id)).resolves.toMatchObject({
id: personalDataset.id,
workspaceId: null,
});
await expect(workspaceDatasetModel.findById(workspaceDataset.id)).resolves.toMatchObject({
id: workspaceDataset.id,
workspaceId,
});
const personalRecordModel = new EvalDatasetRecordModel(serverDB, userId);
const workspaceRecordModel = new EvalDatasetRecordModel(serverDB, userId, workspaceId);
const personalRecord = await personalRecordModel.create({
datasetId: personalDataset.id,
question: 'Personal question',
});
const workspaceRecord = await workspaceRecordModel.create({
datasetId: workspaceDataset.id,
question: 'Workspace question',
});
await expect(personalRecordModel.findById(workspaceRecord.id)).resolves.toBeUndefined();
await expect(workspaceRecordModel.findById(personalRecord.id)).resolves.toBeUndefined();
await personalRecordModel.update(personalRecord.id, { question: 'Updated personal question' });
await expect(personalRecordModel.findById(personalRecord.id)).resolves.toMatchObject({
question: 'Updated personal question',
workspaceId: null,
});
await personalDatasetModel.delete(personalDataset.id);
await expect(personalDatasetModel.findById(personalDataset.id)).resolves.toBeUndefined();
await expect(workspaceDatasetModel.findById(workspaceDataset.id)).resolves.toMatchObject({
id: workspaceDataset.id,
workspaceId,
});
});
it('isolates evaluations and evaluation records between personal and workspace scopes', async () => {
const personalDatasetModel = new EvalDatasetModel(serverDB, userId);
const workspaceDatasetModel = new EvalDatasetModel(serverDB, userId, workspaceId);
const personalRecordModel = new EvalDatasetRecordModel(serverDB, userId);
const workspaceRecordModel = new EvalDatasetRecordModel(serverDB, userId, workspaceId);
const personalEvaluationModel = new EvalEvaluationModel(serverDB, userId);
const workspaceEvaluationModel = new EvalEvaluationModel(serverDB, userId, workspaceId);
const personalEvaluationRecordModel = new EvaluationRecordModel(serverDB, userId);
const workspaceEvaluationRecordModel = new EvaluationRecordModel(serverDB, userId, workspaceId);
const personalDataset = await personalDatasetModel.create({
knowledgeBaseId: personalKnowledgeBaseId,
name: 'Personal dataset',
});
const workspaceDataset = await workspaceDatasetModel.create({
knowledgeBaseId: workspaceKnowledgeBaseId,
name: 'Workspace dataset',
});
const personalDatasetRecord = await personalRecordModel.create({
datasetId: personalDataset.id,
question: 'Personal question',
});
const workspaceDatasetRecord = await workspaceRecordModel.create({
datasetId: workspaceDataset.id,
question: 'Workspace question',
});
const personalEvaluation = await personalEvaluationModel.create({
datasetId: personalDataset.id,
knowledgeBaseId: personalKnowledgeBaseId,
name: 'Personal evaluation',
});
const workspaceEvaluation = await workspaceEvaluationModel.create({
datasetId: workspaceDataset.id,
knowledgeBaseId: workspaceKnowledgeBaseId,
name: 'Workspace evaluation',
});
await expect(
personalEvaluationModel.queryByKnowledgeBaseId(personalKnowledgeBaseId),
).resolves.toEqual([expect.objectContaining({ id: personalEvaluation.id })]);
await expect(
workspaceEvaluationModel.queryByKnowledgeBaseId(workspaceKnowledgeBaseId),
).resolves.toEqual([expect.objectContaining({ id: workspaceEvaluation.id })]);
await expect(personalEvaluationModel.findById(personalEvaluation.id)).resolves.toMatchObject({
id: personalEvaluation.id,
workspaceId: null,
});
await expect(workspaceEvaluationModel.findById(workspaceEvaluation.id)).resolves.toMatchObject({
id: workspaceEvaluation.id,
workspaceId,
});
const personalEvaluationRecord = await personalEvaluationRecordModel.create({
datasetRecordId: personalDatasetRecord.id,
evaluationId: personalEvaluation.id,
question: 'Personal eval question',
});
const workspaceEvaluationRecord = await workspaceEvaluationRecordModel.create({
datasetRecordId: workspaceDatasetRecord.id,
evaluationId: workspaceEvaluation.id,
question: 'Workspace eval question',
});
await expect(
personalEvaluationRecordModel.findById(workspaceEvaluationRecord.id),
).resolves.toBeUndefined();
await expect(
workspaceEvaluationRecordModel.findById(personalEvaluationRecord.id),
).resolves.toBeUndefined();
await personalEvaluationRecordModel.delete(personalEvaluationRecord.id);
await expect(personalEvaluationRecordModel.query(personalEvaluation.id)).resolves.toEqual([]);
await expect(workspaceEvaluationRecordModel.query(workspaceEvaluation.id)).resolves.toEqual([
expect.objectContaining({ id: workspaceEvaluationRecord.id, workspaceId }),
]);
});
});
afterEach(async () => {
await serverDB.delete(evaluationRecords).where(eq(evaluationRecords.userId, userId));
await serverDB.delete(evalEvaluation).where(eq(evalEvaluation.userId, userId));
await serverDB.delete(evalDatasetRecords).where(eq(evalDatasetRecords.userId, userId));
await serverDB.delete(evalDatasets).where(eq(evalDatasets.userId, userId));
await serverDB.delete(knowledgeBases).where(eq(knowledgeBases.userId, userId));
});
@@ -0,0 +1,205 @@
// @vitest-environment node
import {
PERMISSION_ACTIONS,
WORKSPACE_ROLE_PERMISSIONS,
WORKSPACE_SYSTEM_ROLES,
} from '@lobechat/const/rbac';
import { and, eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import { permissions, rolePermissions, roles, userRoles, users, workspaces } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { seedWorkspaceRoles } from '../../utils/seedWorkspaceRoles';
import { RbacModel } from '../rbac';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'rbac-model-test-user-id';
const otherUserId = 'rbac-model-test-other-user-id';
const workspaceAId = 'rbac-ws-a';
const workspaceBId = 'rbac-ws-b';
const cleanup = async () => {
// userRoles + rolePermissions cascade via FK, but workspace-scoped roles only
// cascade when the workspace itself is deleted — so do it explicitly here.
await serverDB.delete(userRoles);
await serverDB.delete(rolePermissions);
await serverDB.delete(roles);
await serverDB.delete(permissions);
await serverDB.delete(workspaces);
await serverDB.delete(users);
};
beforeEach(async () => {
await cleanup();
await serverDB.insert(users).values([{ id: userId }, { id: otherUserId }]);
await serverDB.insert(workspaces).values([
{ id: workspaceAId, name: 'A', primaryOwnerId: userId, slug: 'ws-a' },
{ id: workspaceBId, name: 'B', primaryOwnerId: userId, slug: 'ws-b' },
]);
await seedWorkspaceRoles(serverDB, workspaceAId);
await seedWorkspaceRoles(serverDB, workspaceBId);
});
afterEach(async () => {
await cleanup();
});
describe('RbacModel — workspace scope', () => {
const ownerCode = `${PERMISSION_ACTIONS.WORKSPACE_UPDATE}:all`;
const memberCode = `${PERMISSION_ACTIONS.WORKSPACE_READ}:all`;
describe('assignWorkspaceRole / hasPermission with workspaceId', () => {
it('returns true for a permission granted via the assigned role in that workspace', async () => {
const rbac = new RbacModel(serverDB, userId);
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceAId,
});
expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceAId })).toBe(true);
});
it('returns false for a permission the assigned role does not include', async () => {
const rbac = new RbacModel(serverDB, userId);
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.VIEWER,
userId,
workspaceId: workspaceAId,
});
// viewer never gets workspace:update:all (only owner does).
expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceAId })).toBe(false);
// but viewer does have workspace:read:all.
expect(await rbac.hasPermission(memberCode, { workspaceId: workspaceAId })).toBe(true);
});
it('does not leak permissions across workspaces', async () => {
const rbac = new RbacModel(serverDB, userId);
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceAId,
});
expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceAId })).toBe(true);
expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceBId })).toBe(false);
});
it('is idempotent', async () => {
const rbac = new RbacModel(serverDB, userId);
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceAId,
});
// Re-assigning is a no-op thanks to the (userId, roleId, workspaceId)
// unique index — must not throw.
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceAId,
});
const grants = await serverDB.query.userRoles.findMany({
where: and(eq(userRoles.userId, userId), eq(userRoles.workspaceId, workspaceAId)),
});
expect(grants).toHaveLength(1);
});
});
describe('revokeWorkspaceRole', () => {
it('drops every grant in the named workspace and leaves others untouched', async () => {
const rbac = new RbacModel(serverDB, userId);
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceAId,
});
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceBId,
});
await rbac.revokeWorkspaceRole({ userId, workspaceId: workspaceAId });
expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceAId })).toBe(false);
expect(await rbac.hasPermission(ownerCode, { workspaceId: workspaceBId })).toBe(true);
});
it('is a no-op when the user has no grants in the workspace', async () => {
const rbac = new RbacModel(serverDB, userId);
await expect(
rbac.revokeWorkspaceRole({ userId, workspaceId: workspaceAId }),
).resolves.not.toThrow();
});
});
describe('getUserPermissions with workspaceId', () => {
it('returns scoped codes for the named workspace, de-duped', async () => {
const rbac = new RbacModel(serverDB, userId);
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceAId,
});
const codes = await rbac.getUserPermissions({ workspaceId: workspaceAId });
const expected = new Set(WORKSPACE_ROLE_PERMISSIONS[WORKSPACE_SYSTEM_ROLES.OWNER]);
// every code the owner role grants should appear in the result
for (const code of expected) {
expect(codes).toContain(code);
}
// ...and no duplicates
expect(codes).toHaveLength(new Set(codes).size);
});
it('does not include workspace B permissions when scoped to workspace A', async () => {
const rbac = new RbacModel(serverDB, userId);
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceBId,
});
// user has no grant in workspaceA
const codes = await rbac.getUserPermissions({ workspaceId: workspaceAId });
expect(codes).toEqual([]);
});
});
describe('listWorkspaceRoles', () => {
it('lists the three built-in roles seeded for that workspace', async () => {
const rbac = new RbacModel(serverDB, userId);
const list = await rbac.listWorkspaceRoles(workspaceAId);
const names = list.map((r) => r.name).sort();
expect(names).toEqual(
[
WORKSPACE_SYSTEM_ROLES.MEMBER,
WORKSPACE_SYSTEM_ROLES.OWNER,
WORKSPACE_SYSTEM_ROLES.VIEWER,
].sort(),
);
expect(list.every((r) => r.workspaceId === workspaceAId)).toBe(true);
});
});
describe('back-compat: no workspaceId', () => {
it('still matches workspace-scoped grants when no workspaceId is given (legacy behavior)', async () => {
// Hono routes call `hasPermission(code)` without workspaceId. This must
// keep returning true for users whose only grant is workspace-scoped,
// otherwise every Hono content route regresses on workspace users.
const rbac = new RbacModel(serverDB, userId);
await rbac.assignWorkspaceRole({
roleName: WORKSPACE_SYSTEM_ROLES.OWNER,
userId,
workspaceId: workspaceAId,
});
expect(await rbac.hasPermission(ownerCode)).toBe(true);
});
});
});
@@ -0,0 +1,79 @@
// @vitest-environment node
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import { agents, agentsToSessions, sessions, users, workspaces } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { SessionModel } from '../session';
const serverDB: LobeChatDatabase = await getTestDB();
const userId = 'session-workspace-user';
const workspaceId = 'session-workspace';
beforeEach(async () => {
await serverDB.delete(users);
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
});
afterEach(async () => {
await serverDB.delete(users);
});
describe('SessionModel workspace scope', () => {
it('isolates personal and workspace sessions for the same user', async () => {
await serverDB.insert(sessions).values([
{ id: 'personal-session', updatedAt: new Date('2023-01-01'), userId, workspaceId: null },
{
id: 'workspace-session',
updatedAt: new Date('2023-02-01'),
userId,
workspaceId,
},
]);
await expect(new SessionModel(serverDB, userId).query()).resolves.toEqual([
expect.objectContaining({ id: 'personal-session' }),
]);
await expect(new SessionModel(serverDB, userId, workspaceId).query()).resolves.toEqual([
expect.objectContaining({ id: 'workspace-session' }),
]);
});
it('deleteAll on personal scope does not delete workspace sessions or links', async () => {
await serverDB.transaction(async (trx) => {
await trx.insert(sessions).values([
{ id: 'personal-session', updatedAt: new Date('2023-01-01'), userId, workspaceId: null },
{
id: 'workspace-session',
updatedAt: new Date('2023-02-01'),
userId,
workspaceId,
},
]);
await trx.insert(agents).values([
{ id: 'personal-agent', userId, title: 'Personal Agent', workspaceId: null },
{ id: 'workspace-agent', userId, title: 'Workspace Agent', workspaceId },
]);
await trx.insert(agentsToSessions).values([
{ agentId: 'personal-agent', sessionId: 'personal-session', userId, workspaceId: null },
{ agentId: 'workspace-agent', sessionId: 'workspace-session', userId, workspaceId },
]);
});
await new SessionModel(serverDB, userId).deleteAll();
await expect(serverDB.select().from(sessions)).resolves.toEqual([
expect.objectContaining({ id: 'workspace-session', workspaceId }),
]);
await expect(serverDB.select().from(agentsToSessions)).resolves.toEqual([
expect.objectContaining({ agentId: 'workspace-agent', sessionId: 'workspace-session' }),
]);
});
});
@@ -2,7 +2,15 @@ import { asc, eq, inArray } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../../core/getTestDB';
import { agents, messagePlugins, messages, sessions, topics, users } from '../../../schemas';
import {
agents,
messagePlugins,
messages,
sessions,
topics,
users,
workspaces,
} from '../../../schemas';
import type { LobeChatDatabase } from '../../../type';
import type { CreateTopicParams } from '../../topic';
import { TopicModel } from '../../topic';
@@ -95,6 +103,50 @@ describe('TopicModel - Create', () => {
expect(unassociatedMessage[0].topicId).toBeNull();
});
it('should associate workspace messages created by another member', async () => {
const workspaceId = 'topic-create-workspace';
const workspaceSessionId = 'topic-create-workspace-session';
const workspaceTopicModel = new TopicModel(serverDB, userId, workspaceId);
await serverDB.transaction(async (tx) => {
await tx.insert(workspaces).values({
id: workspaceId,
name: 'Topic Create Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await tx.insert(sessions).values({
id: workspaceSessionId,
userId,
workspaceId,
});
await tx.insert(messages).values({
id: 'workspace-message-other-member',
role: 'user',
sessionId: workspaceSessionId,
userId: userId2,
workspaceId,
});
});
const createdTopic = await workspaceTopicModel.create(
{
messages: ['workspace-message-other-member'],
sessionId: workspaceSessionId,
title: 'Workspace Topic',
},
'workspace-topic-created',
);
const [updatedMessage] = await serverDB
.select()
.from(messages)
.where(eq(messages.id, 'workspace-message-other-member'));
expect(createdTopic.workspaceId).toBe(workspaceId);
expect(updatedMessage.topicId).toBe(createdTopic.id);
});
it('should create a new topic without associating messages', async () => {
const topicData = {
title: 'New Topic',
@@ -230,6 +282,62 @@ describe('TopicModel - Create', () => {
expect(updatedMessages[2].topicId).toBe(createdTopics[1].id);
});
it('should batch associate workspace messages created by other members', async () => {
const workspaceId = 'topic-batch-workspace';
const workspaceSessionId = 'topic-batch-workspace-session';
const workspaceTopicModel = new TopicModel(serverDB, userId, workspaceId);
await serverDB.transaction(async (tx) => {
await tx.insert(workspaces).values({
id: workspaceId,
name: 'Topic Batch Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await tx.insert(sessions).values({
id: workspaceSessionId,
userId,
workspaceId,
});
await tx.insert(messages).values([
{
id: 'workspace-batch-message-1',
role: 'user',
sessionId: workspaceSessionId,
userId: userId2,
workspaceId,
},
{
id: 'workspace-batch-message-2',
role: 'assistant',
sessionId: workspaceSessionId,
userId,
workspaceId,
},
]);
});
const createdTopics = await workspaceTopicModel.batchCreate([
{
messages: ['workspace-batch-message-1', 'workspace-batch-message-2'],
sessionId: workspaceSessionId,
title: 'Workspace Batch Topic',
},
]);
const updatedMessages = await serverDB
.select()
.from(messages)
.where(inArray(messages.id, ['workspace-batch-message-1', 'workspace-batch-message-2']))
.orderBy(asc(messages.id));
expect(createdTopics[0].workspaceId).toBe(workspaceId);
expect(updatedMessages.map((message) => message.topicId)).toEqual([
createdTopics[0].id,
createdTopics[0].id,
]);
});
it('should generate topic IDs if not provided', async () => {
const topicParams = [
{ title: 'Topic 1', favorite: true, sessionId },
@@ -309,6 +417,56 @@ describe('TopicModel - Create', () => {
expect(duplicatedMessages[1].content).toBe('Assistant message');
});
it('should duplicate workspace messages created by other members', async () => {
const workspaceId = 'topic-duplicate-workspace';
const topicId = 'workspace-topic-duplicate';
const workspaceTopicModel = new TopicModel(serverDB, userId, workspaceId);
await serverDB.transaction(async (tx) => {
await tx.insert(workspaces).values({
id: workspaceId,
name: 'Topic Duplicate Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await tx.insert(topics).values({
id: topicId,
title: 'Workspace Original Topic',
userId: userId2,
workspaceId,
});
await tx.insert(messages).values([
{
content: 'Other member user message',
id: 'workspace-duplicate-message-1',
role: 'user',
topicId,
userId: userId2,
workspaceId,
},
{
content: 'Current member assistant message',
id: 'workspace-duplicate-message-2',
role: 'assistant',
topicId,
userId,
workspaceId,
},
]);
});
const { topic: duplicatedTopic, messages: duplicatedMessages } =
await workspaceTopicModel.duplicate(topicId, 'Workspace Duplicated Topic');
expect(duplicatedTopic.workspaceId).toBe(workspaceId);
expect(duplicatedMessages).toHaveLength(2);
expect(duplicatedMessages.map((message) => message.content).sort()).toEqual([
'Current member assistant message',
'Other member user message',
]);
expect(duplicatedMessages.every((message) => message.workspaceId === workspaceId)).toBe(true);
});
it('should correctly map parentId references when duplicating messages', async () => {
const topicId = 'topic-with-parent-refs';
@@ -9,6 +9,7 @@ import {
sessions,
topics,
users,
workspaces,
} from '../../../schemas';
import type { LobeChatDatabase } from '../../../type';
import { TopicModel } from '../../topic';
@@ -53,6 +54,49 @@ describe('TopicModel - Query', () => {
expect(result.items[2].id).toBe('4');
});
it('should isolate personal and workspace topics for the same user', async () => {
await serverDB.insert(workspaces).values({
id: 'topic-workspace',
name: 'Workspace',
primaryOwnerId: userId,
slug: 'topic-workspace',
});
await serverDB.insert(sessions).values({
id: 'topic-workspace-session',
userId,
workspaceId: 'topic-workspace',
});
await serverDB.insert(topics).values([
{
id: 'personal-topic',
sessionId,
updatedAt: new Date('2023-01-01'),
userId,
workspaceId: null,
},
{
id: 'workspace-topic',
sessionId: 'topic-workspace-session',
updatedAt: new Date('2023-02-01'),
userId,
workspaceId: 'topic-workspace',
},
]);
await expect(topicModel.query({ containerId: sessionId })).resolves.toMatchObject({
items: [expect.objectContaining({ id: 'personal-topic' })],
total: 1,
});
await expect(
new TopicModel(serverDB, userId, 'topic-workspace').query({
containerId: 'topic-workspace-session',
}),
).resolves.toMatchObject({
items: [expect.objectContaining({ id: 'workspace-topic' })],
total: 1,
});
});
it('should order by status priority when sortBy is "status"', async () => {
await serverDB.insert(topics).values([
// favorite floats to the top regardless of its (lower-priority) status
@@ -0,0 +1,282 @@
// @vitest-environment node
import { eq } from 'drizzle-orm';
import { afterEach, beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import {
users,
workspaceAuditLogs,
workspaceInvitations,
workspaceMembers,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { WorkspaceModel } from '../workspace';
import { WorkspaceAuditLogModel } from '../workspaceAuditLog';
import { WorkspaceMemberModel } from '../workspaceMember';
const serverDB: LobeChatDatabase = await getTestDB();
const ownerId = 'workspace-model-owner';
const memberId = 'workspace-model-member';
const secondOwnerId = 'workspace-model-second-owner';
const outsiderId = 'workspace-model-outsider';
const cleanup = async () => {
await serverDB.delete(workspaceAuditLogs);
await serverDB.delete(workspaceInvitations);
await serverDB.delete(workspaceMembers);
await serverDB.delete(workspaces);
await serverDB.delete(users);
};
const createWorkspace = async (id = 'workspace-model-ws') => {
await serverDB.insert(workspaces).values({
id,
name: id,
primaryOwnerId: ownerId,
settings: { gracePeriodUntil: 123, keep: true },
slug: id,
});
await serverDB.insert(workspaceMembers).values([
{ role: 'owner', userId: ownerId, workspaceId: id },
{ role: 'member', userId: memberId, workspaceId: id },
{ role: 'owner', userId: secondOwnerId, workspaceId: id },
]);
return id;
};
beforeEach(async () => {
await cleanup();
await serverDB
.insert(users)
.values([{ id: ownerId }, { id: memberId }, { id: secondOwnerId }, { id: outsiderId }]);
});
afterEach(async () => {
await cleanup();
});
describe('WorkspaceModel', () => {
it('creates the workspace and inserts the creator as owner member', async () => {
const model = new WorkspaceModel(serverDB, ownerId);
const workspace = await model.create({
avatar: 'avatar.png',
description: 'Team workspace',
name: 'Acme',
slug: 'acme',
});
expect(workspace.primaryOwnerId).toBe(ownerId);
expect(workspace.slug).toBe('acme');
const membership = await serverDB.query.workspaceMembers.findFirst({
where: eq(workspaceMembers.workspaceId, workspace.id),
});
expect(membership).toMatchObject({
role: 'owner',
userId: ownerId,
workspaceId: workspace.id,
});
});
it('lists active memberships with their workspace roles and skips deleted memberships', async () => {
const workspaceId = await createWorkspace();
await serverDB
.update(workspaceMembers)
.set({ deletedAt: new Date() })
.where(eq(workspaceMembers.userId, memberId));
const ownerWorkspaces = await new WorkspaceModel(serverDB, ownerId).listUserWorkspaces();
const memberWorkspaces = await new WorkspaceModel(serverDB, memberId).listUserWorkspaces();
expect(ownerWorkspaces).toEqual([expect.objectContaining({ id: workspaceId, role: 'owner' })]);
expect(memberWorkspaces).toEqual([]);
});
it('does not delete workspaces owned by another primary owner', async () => {
const workspaceId = await createWorkspace();
await new WorkspaceModel(serverDB, outsiderId).delete(workspaceId);
const workspace = await serverDB.query.workspaces.findFirst({
where: eq(workspaces.id, workspaceId),
});
expect(workspace).toBeDefined();
});
it('transfers primary ownership only to an active owner member', async () => {
const workspaceId = await createWorkspace();
const model = new WorkspaceModel(serverDB, ownerId);
await expect(model.transferPrimaryOwnership(workspaceId, memberId)).rejects.toThrow(
'Target user must already be an owner',
);
await expect(model.transferPrimaryOwnership(workspaceId, secondOwnerId)).resolves.toEqual({
newPrimaryOwnerUserId: secondOwnerId,
previousPrimaryOwnerUserId: ownerId,
workspaceId,
});
const workspace = await serverDB.query.workspaces.findFirst({
where: eq(workspaces.id, workspaceId),
});
expect(workspace?.primaryOwnerId).toBe(secondOwnerId);
});
it('downgrades to solo by removing non-primary members and clearing grace period', async () => {
const workspaceId = await createWorkspace();
const result = await new WorkspaceModel(serverDB, ownerId).downgradeToSolo(workspaceId);
expect(result.removedUserIds.sort()).toEqual([memberId, secondOwnerId].sort());
expect(result.workspace.settings).toEqual({ keep: true });
const activeMembers = await serverDB.query.workspaceMembers.findMany({
where: eq(workspaceMembers.workspaceId, workspaceId),
});
expect(activeMembers).toEqual(
expect.arrayContaining([
expect.objectContaining({ deletedAt: null, userId: ownerId }),
expect.objectContaining({ userId: memberId }),
expect.objectContaining({ userId: secondOwnerId }),
]),
);
expect(
activeMembers.filter((member) => !member.deletedAt).map((member) => member.userId),
).toEqual([ownerId]);
});
it('sets and clears grace period without dropping unrelated settings', async () => {
const workspaceId = await createWorkspace();
const model = new WorkspaceModel(serverDB, ownerId);
await model.setGracePeriod(workspaceId, 456);
await expect(model.getSettings(workspaceId)).resolves.toEqual({
gracePeriodUntil: 456,
keep: true,
});
await model.setGracePeriod(workspaceId, null);
await expect(model.getSettings(workspaceId)).resolves.toEqual({ keep: true });
});
});
describe('WorkspaceMemberModel', () => {
it('revives a deleted member on addMember and applies the new role', async () => {
const workspaceId = await createWorkspace();
const model = new WorkspaceMemberModel(serverDB, ownerId);
await model.removeMember(workspaceId, memberId);
const revived = await model.addMember({ role: 'viewer', userId: memberId, workspaceId });
expect(revived).toMatchObject({
deletedAt: null,
role: 'viewer',
userId: memberId,
workspaceId,
});
});
it('lists only active members unless includeDeleted is requested', async () => {
const workspaceId = await createWorkspace();
const model = new WorkspaceMemberModel(serverDB, ownerId);
await model.removeMember(workspaceId, memberId);
const active = await model.listMembers(workspaceId);
const all = await model.listMembers(workspaceId, { includeDeleted: true });
expect(active.map((member) => member.userId).sort()).toEqual([ownerId, secondOwnerId].sort());
expect(all.map((member) => member.userId).sort()).toEqual(
[ownerId, memberId, secondOwnerId].sort(),
);
});
it('creates pending invitations with a default member role and expiry', async () => {
const workspaceId = await createWorkspace();
const before = new Date();
const invitation = await new WorkspaceMemberModel(serverDB, ownerId).createInvitation({
email: 'new@example.com',
workspaceId,
});
expect(invitation).toMatchObject({
email: 'new@example.com',
inviterId: ownerId,
role: 'member',
status: 'pending',
workspaceId,
});
expect(invitation.token).toHaveLength(32);
expect(invitation.expiresAt.getTime()).toBeGreaterThan(
before.getTime() + 6 * 24 * 60 * 60 * 1000,
);
});
});
describe('WorkspaceAuditLogModel', () => {
it('creates logs with empty metadata by default', async () => {
const workspaceId = await createWorkspace();
const log = await new WorkspaceAuditLogModel(serverDB).create({
action: 'workspace.created',
userId: ownerId,
workspaceId,
});
expect(log).toMatchObject({
action: 'workspace.created',
metadata: {},
userId: ownerId,
workspaceId,
});
});
it('lists logs by workspace and action with cursor pagination', async () => {
const workspaceId = await createWorkspace();
await serverDB.insert(workspaceAuditLogs).values([
{
action: 'workspace.created',
createdAt: new Date('2026-01-01T00:00:00.000Z'),
resourceId: 'old',
userId: ownerId,
workspaceId,
},
{
action: 'workspace.updated',
createdAt: new Date('2026-01-02T00:00:00.000Z'),
resourceId: 'middle',
userId: ownerId,
workspaceId,
},
{
action: 'workspace.updated',
createdAt: new Date('2026-01-03T00:00:00.000Z'),
resourceId: 'new',
userId: ownerId,
workspaceId,
},
]);
const result = await new WorkspaceAuditLogModel(serverDB).list({
action: 'workspace.updated',
limit: 1,
workspaceId,
});
expect(result.items.map((item) => item.resourceId)).toEqual(['new']);
expect(result.nextCursor).toBe('2026-01-03T00:00:00.000Z');
const next = await new WorkspaceAuditLogModel(serverDB).list({
action: 'workspace.updated',
cursor: new Date(result.nextCursor!),
limit: 1,
workspaceId,
});
expect(next.items.map((item) => item.resourceId)).toEqual(['middle']);
});
});
+274 -111
View File
@@ -8,25 +8,33 @@ import { merge } from '@/utils/merge';
import type { AgentItem } from '../schemas';
import {
agentBotProviders,
agentCronJobs,
agents,
agentsFiles,
agentsKnowledgeBases,
agentsToSessions,
chatGroupsAgents,
documents,
files,
knowledgeBases,
messages,
sessions,
threads,
topics,
} from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class AgentModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
/**
@@ -45,21 +53,44 @@ export class AgentModel {
})
.from(agents)
.leftJoin(topics, eq(topics.agentId, agents.id))
.where(
and(
eq(agents.userId, this.userId),
or(eq(agents.slug, INBOX_SESSION_ID), ne(agents.virtual, true)),
),
)
.where(and(this.ownership(), or(eq(agents.slug, INBOX_SESSION_ID), ne(agents.virtual, true))))
.groupBy(agents.id)
.having(({ count }) => gt(count, 0))
.orderBy(desc(sql`count`))
.limit(limit);
};
/**
* Compat-mode ownership predicate for the `agents` table.
* - team mode (workspaceId set): `workspace_id = ?` (every member sees the same agents)
* - personal mode: `user_id = ? AND workspace_id IS NULL`
*/
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agents);
/** Same predicate but for the `sessions` table (used in delete cascade). */
private sessionsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, sessions);
/** Ownership predicates for the agent join/related tables. */
private documentsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents);
private agentsFilesOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentsFiles);
private agentsKnowledgeBasesOwnership = () =>
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
agentsKnowledgeBases,
);
private agentsToSessionsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentsToSessions);
getAgentConfigById = async (id: string) => {
const agent = await this.db.query.agents.findFirst({
where: and(eq(agents.id, id), eq(agents.userId, this.userId)),
where: and(eq(agents.id, id), this.ownership()),
});
if (!agent) return null;
@@ -71,7 +102,7 @@ export class AgentModel {
const rows = await this.db
.select({ id: agents.id })
.from(agents)
.where(and(eq(agents.id, id), eq(agents.userId, this.userId)))
.where(and(eq(agents.id, id), this.ownership()))
.limit(1);
return rows.length > 0;
@@ -90,9 +121,7 @@ export class AgentModel {
const rows = await this.db
.select({ model: agents.model, provider: agents.provider })
.from(agents)
.where(
and(eq(agents.userId, this.userId), or(eq(agents.id, idOrSlug), eq(agents.slug, idOrSlug))),
)
.where(and(this.ownership(), or(eq(agents.id, idOrSlug), eq(agents.slug, idOrSlug))))
.limit(1);
const row = rows[0];
@@ -107,7 +136,7 @@ export class AgentModel {
private buildQueryAgentsWhere = (keyword?: string) => {
// Include agents where virtual is false OR null (legacy data without virtual field)
const baseConditions = and(
eq(agents.userId, this.userId),
this.ownership(),
or(eq(agents.virtual, false), isNull(agents.virtual)),
);
@@ -173,7 +202,7 @@ export class AgentModel {
title: agents.title,
})
.from(agents)
.where(and(eq(agents.userId, this.userId), inArray(agents.id, ids)));
.where(and(this.ownership(), inArray(agents.id, ids)));
return rows.map(({ slug, ...row }) => ({
...row,
@@ -186,12 +215,15 @@ export class AgentModel {
* Get agent config by ID or slug (single query with OR condition)
*/
getAgentConfig = async (idOrSlug: string) => {
const agent = await this.db.query.agents.findFirst({
where: and(
eq(agents.userId, this.userId),
or(eq(agents.id, idOrSlug), eq(agents.slug, idOrSlug)),
),
});
// Prefer an exact ID match over a slug match. The combined `or(id, slug)`
// query has no inherent ordering, so resolve ID first for determinism.
const agent =
(await this.db.query.agents.findFirst({
where: and(this.ownership(), eq(agents.id, idOrSlug)),
})) ??
(await this.db.query.agents.findFirst({
where: and(this.ownership(), eq(agents.slug, idOrSlug)),
}));
if (!agent) return null;
@@ -214,7 +246,7 @@ export class AgentModel {
if (enabledFileIds.length > 0) {
const documentsData = await this.db.query.documents.findMany({
where: and(eq(documents.userId, this.userId), inArray(documents.fileId, enabledFileIds)),
where: and(this.documentsOwnership(), inArray(documents.fileId, enabledFileIds)),
});
const documentMap = new Map(documentsData.map((doc) => [doc.fileId, doc.content]));
@@ -234,15 +266,13 @@ export class AgentModel {
this.db
.select({ enabled: agentsKnowledgeBases.enabled, knowledgeBases })
.from(agentsKnowledgeBases)
.where(
and(eq(agentsKnowledgeBases.agentId, id), eq(agentsKnowledgeBases.userId, this.userId)),
)
.where(and(eq(agentsKnowledgeBases.agentId, id), this.agentsKnowledgeBasesOwnership()))
.orderBy(desc(agentsKnowledgeBases.createdAt))
.leftJoin(knowledgeBases, eq(knowledgeBases.id, agentsKnowledgeBases.knowledgeBaseId)),
this.db
.select({ enabled: agentsFiles.enabled, files })
.from(agentsFiles)
.where(and(eq(agentsFiles.agentId, id), eq(agentsFiles.userId, this.userId)))
.where(and(eq(agentsFiles.agentId, id), this.agentsFilesOwnership()))
.orderBy(desc(agentsFiles.createdAt))
.leftJoin(files, eq(files.id, agentsFiles.fileId)),
]);
@@ -264,10 +294,7 @@ export class AgentModel {
*/
findBySessionId = async (sessionId: string) => {
const item = await this.db.query.agentsToSessions.findFirst({
where: and(
eq(agentsToSessions.sessionId, sessionId),
eq(agentsToSessions.userId, this.userId),
),
where: and(eq(agentsToSessions.sessionId, sessionId), this.agentsToSessionsOwnership()),
});
if (!item) return;
@@ -282,12 +309,14 @@ export class AgentModel {
knowledgeBaseId: string,
enabled: boolean = true,
) => {
return this.db.insert(agentsKnowledgeBases).values({
agentId,
enabled,
knowledgeBaseId,
userId: this.userId,
});
return this.db
.insert(agentsKnowledgeBases)
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ agentId, enabled, knowledgeBaseId },
),
);
};
deleteAgentKnowledgeBase = async (agentId: string, knowledgeBaseId: string) => {
@@ -297,7 +326,7 @@ export class AgentModel {
and(
eq(agentsKnowledgeBases.agentId, agentId),
eq(agentsKnowledgeBases.knowledgeBaseId, knowledgeBaseId),
eq(agentsKnowledgeBases.userId, this.userId),
this.agentsKnowledgeBasesOwnership(),
),
);
};
@@ -310,7 +339,7 @@ export class AgentModel {
and(
eq(agentsKnowledgeBases.agentId, agentId),
eq(agentsKnowledgeBases.knowledgeBaseId, knowledgeBaseId),
eq(agentsKnowledgeBases.userId, this.userId),
this.agentsKnowledgeBasesOwnership(),
),
);
};
@@ -323,7 +352,7 @@ export class AgentModel {
.where(
and(
eq(agentsFiles.agentId, agentId),
eq(agentsFiles.userId, this.userId),
this.agentsFilesOwnership(),
inArray(agentsFiles.fileId, fileIds),
),
);
@@ -337,7 +366,12 @@ export class AgentModel {
return this.db
.insert(agentsFiles)
.values(
needToInsertFileIds.map((fileId) => ({ agentId, enabled, fileId, userId: this.userId })),
needToInsertFileIds.map((fileId) =>
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ agentId, enabled, fileId },
),
),
);
};
@@ -348,7 +382,7 @@ export class AgentModel {
and(
eq(agentsFiles.agentId, agentId),
eq(agentsFiles.fileId, fileId),
eq(agentsFiles.userId, this.userId),
this.agentsFilesOwnership(),
),
);
};
@@ -363,28 +397,24 @@ export class AgentModel {
const links = await trx
.select({ sessionId: agentsToSessions.sessionId })
.from(agentsToSessions)
.where(
and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)),
);
.where(and(eq(agentsToSessions.agentId, agentId), this.agentsToSessionsOwnership()));
const sessionIds = links.map((link) => link.sessionId);
// 2. Delete links in agentsToSessions
await trx
.delete(agentsToSessions)
.where(
and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)),
);
.where(and(eq(agentsToSessions.agentId, agentId), this.agentsToSessionsOwnership()));
// 3. Delete associated sessions (this will cascade delete messages, topics, etc.)
if (sessionIds.length > 0) {
await trx
.delete(sessions)
.where(and(inArray(sessions.id, sessionIds), eq(sessions.userId, this.userId)));
.where(and(inArray(sessions.id, sessionIds), this.sessionsOwnership()));
}
// 4. Delete the agent itself
return trx.delete(agents).where(and(eq(agents.id, agentId), eq(agents.userId, this.userId)));
return trx.delete(agents).where(and(eq(agents.id, agentId), this.ownership()));
});
};
@@ -396,9 +426,7 @@ export class AgentModel {
batchDelete = async (agentIds: string[]) => {
if (agentIds.length === 0) return;
return this.db
.delete(agents)
.where(and(eq(agents.userId, this.userId), inArray(agents.id, agentIds)));
return this.db.delete(agents).where(and(this.ownership(), inArray(agents.id, agentIds)));
};
toggleFile = async (agentId: string, fileId: string, enabled?: boolean) => {
@@ -409,7 +437,7 @@ export class AgentModel {
and(
eq(agentsFiles.agentId, agentId),
eq(agentsFiles.fileId, fileId),
eq(agentsFiles.userId, this.userId),
this.agentsFilesOwnership(),
),
);
};
@@ -422,11 +450,13 @@ export class AgentModel {
const [result] = await this.db
.insert(agents)
.values([
{
...config,
model: typeof config.model === 'string' ? config.model : null,
userId: this.userId,
},
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...config,
model: typeof config.model === 'string' ? config.model : null,
},
),
])
.returning();
@@ -443,11 +473,15 @@ export class AgentModel {
return this.db
.insert(agents)
.values(
configs.map((config) => ({
...config,
model: typeof config.model === 'string' ? config.model : null,
userId: this.userId,
})),
configs.map((config) =>
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...config,
model: typeof config.model === 'string' ? config.model : null,
},
),
),
)
.returning();
};
@@ -456,7 +490,7 @@ export class AgentModel {
return this.db
.update(agents)
.set({ ...data, updatedAt: new Date() })
.where(and(eq(agents.id, agentId), eq(agents.userId, this.userId)));
.where(and(eq(agents.id, agentId), this.ownership()));
};
touchUpdatedAt = async (agentId: string) => {
@@ -469,7 +503,7 @@ export class AgentModel {
*/
checkByMarketIdentifier = async (marketIdentifier: string): Promise<boolean> => {
const result = await this.db.query.agents.findFirst({
where: and(eq(agents.marketIdentifier, marketIdentifier), eq(agents.userId, this.userId)),
where: and(eq(agents.marketIdentifier, marketIdentifier), this.ownership()),
});
return !!result;
};
@@ -483,7 +517,7 @@ export class AgentModel {
const result = await this.db.query.agents.findFirst({
columns: { id: true },
orderBy: (agents, { desc }) => [desc(agents.updatedAt)],
where: and(eq(agents.marketIdentifier, marketIdentifier), eq(agents.userId, this.userId)),
where: and(eq(agents.marketIdentifier, marketIdentifier), this.ownership()),
});
return result?.id ?? null;
};
@@ -498,7 +532,7 @@ export class AgentModel {
columns: { id: true },
orderBy: (agents, { desc }) => [desc(agents.updatedAt)],
where: and(
eq(agents.userId, this.userId),
this.ownership(),
sql`${agents.params}->>'forkedFromIdentifier' = ${forkedFromIdentifier}`,
),
});
@@ -509,7 +543,7 @@ export class AgentModel {
if (!data || Object.keys(data).length === 0) return;
const agent = await this.db.query.agents.findFirst({
where: and(eq(agents.id, agentId), eq(agents.userId, this.userId)),
where: and(eq(agents.id, agentId), this.ownership()),
});
if (!agent) return;
@@ -562,7 +596,7 @@ export class AgentModel {
return this.db
.update(agents)
.set(updateData)
.where(and(eq(agents.id, agentId), eq(agents.userId, this.userId)));
.where(and(eq(agents.id, agentId), this.ownership()));
};
/**
@@ -572,7 +606,7 @@ export class AgentModel {
const result = await this.db
.update(agents)
.set({ sessionGroupId, updatedAt: new Date() })
.where(and(eq(agents.id, agentId), eq(agents.userId, this.userId)))
.where(and(eq(agents.id, agentId), this.ownership()))
.returning();
return result[0];
@@ -585,7 +619,7 @@ export class AgentModel {
duplicate = async (agentId: string, newTitle?: string): Promise<{ agentId: string } | null> => {
// Get the source agent
const sourceAgent = await this.db.query.agents.findFirst({
where: and(eq(agents.id, agentId), eq(agents.userId, this.userId)),
where: and(eq(agents.id, agentId), this.ownership()),
});
if (!sourceAgent) return null;
@@ -593,32 +627,35 @@ export class AgentModel {
// Create new agent with explicit include fields
const [newAgent] = await this.db
.insert(agents)
.values({
avatar: sourceAgent.avatar,
backgroundColor: sourceAgent.backgroundColor,
chatConfig: sourceAgent.chatConfig,
description: sourceAgent.description,
fewShots: sourceAgent.fewShots,
model: sourceAgent.model,
openingMessage: sourceAgent.openingMessage,
openingQuestions: sourceAgent.openingQuestions,
params: sourceAgent.params,
pinned: sourceAgent.pinned,
// Config
plugins: sourceAgent.plugins,
provider: sourceAgent.provider,
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
avatar: sourceAgent.avatar,
backgroundColor: sourceAgent.backgroundColor,
chatConfig: sourceAgent.chatConfig,
description: sourceAgent.description,
fewShots: sourceAgent.fewShots,
model: sourceAgent.model,
openingMessage: sourceAgent.openingMessage,
openingQuestions: sourceAgent.openingQuestions,
params: sourceAgent.params,
pinned: sourceAgent.pinned,
// Config
plugins: sourceAgent.plugins,
provider: sourceAgent.provider,
// Session group
sessionGroupId: sourceAgent.sessionGroupId,
systemRole: sourceAgent.systemRole,
// Session group
sessionGroupId: sourceAgent.sessionGroupId,
systemRole: sourceAgent.systemRole,
tags: sourceAgent.tags,
// Metadata
title: newTitle || (sourceAgent.title ? `${sourceAgent.title} (Copy)` : 'Copy'),
tts: sourceAgent.tts,
// User
userId: this.userId,
})
tags: sourceAgent.tags,
// Metadata
title: newTitle || (sourceAgent.title ? `${sourceAgent.title} (Copy)` : 'Copy'),
tts: sourceAgent.tts,
},
),
)
.returning();
return { agentId: newAgent.id };
@@ -632,7 +669,7 @@ export class AgentModel {
getBuiltinAgent = async (slug: string): Promise<AgentItem | null> => {
// 1. First try to find existing agent by slug
const existing = await this.db.query.agents.findFirst({
where: and(eq(agents.slug, slug), eq(agents.userId, this.userId)),
where: and(eq(agents.slug, slug), this.ownership()),
});
if (existing) return existing;
@@ -647,7 +684,7 @@ export class AgentModel {
.from(sessions)
.innerJoin(agentsToSessions, eq(sessions.id, agentsToSessions.sessionId))
.innerJoin(agents, eq(agentsToSessions.agentId, agents.id))
.where(and(eq(sessions.slug, INBOX_SESSION_ID), eq(sessions.userId, this.userId)))
.where(and(eq(sessions.slug, INBOX_SESSION_ID), this.sessionsOwnership()))
.limit(1);
if (result.length > 0 && result[0].agent) {
@@ -673,30 +710,156 @@ export class AgentModel {
// `onConflictDoNothing`, the loser hits the `agents_slug_user_id_unique`
// constraint; with it, the loser's `.returning()` is empty and we re-read
// the row that won.
// `agents_slug_user_id_unique` is a partial index (WHERE workspace_id IS
// NULL) since migration 0109, so the conflict arbiter must carry the same
// predicate; builtin agents are always workspace-less (workspace_id NULL).
// Bare `onConflictDoNothing()` (no target) does NOT pin an arbiter index,
// so it works whether `agents_slug_user_id_unique` is the legacy full
// unique or the migration-0109 partial (WHERE workspace_id IS NULL) — this
// is the transition-safe form while 0109 rolls out. Tighten back to a
// partitioned { target, where } once 0109 has flipped the index in every
// environment. Payload still carries workspaceId so workspace-scoped
// builtin agents land in the right workspace.
const result = await this.db
.insert(agents)
.values({
model: persistConfig.model,
provider: persistConfig.provider,
slug: persistConfig.slug,
userId: this.userId,
virtual: true,
})
.onConflictDoNothing({
target: [agents.slug, agents.userId],
where: isNull(agents.workspaceId),
})
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
model: persistConfig.model,
provider: persistConfig.provider,
slug: persistConfig.slug,
virtual: true,
},
),
)
.onConflictDoNothing()
.returning();
if (result[0]) return result[0];
return (
(await this.db.query.agents.findFirst({
where: and(eq(agents.slug, slug), eq(agents.userId, this.userId)),
where: and(eq(agents.slug, slug), this.ownership()),
})) ?? null
);
};
/**
* Transfer an agent and all its associated data to a different workspace or personal account.
* Runs in a single transaction to ensure atomicity.
*/
transferAgent = async (
agentId: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ agentId: string; slug: string | null }> => {
return this.db.transaction(async (trx) => {
// 1. Verify agent exists and belongs to current scope
const agent = await trx.query.agents.findFirst({
where: and(eq(agents.id, agentId), this.ownership()),
});
if (!agent) throw new Error('Agent not found');
// 2. Handle slug conflict in target scope
let slug = agent.slug;
if (slug) {
const buildConflictCheck = (candidate: string) =>
targetWorkspaceId
? and(eq(agents.slug, candidate), eq(agents.workspaceId, targetWorkspaceId))
: and(
eq(agents.slug, candidate),
eq(agents.userId, targetUserId),
isNull(agents.workspaceId),
);
const existing = await trx.query.agents.findFirst({
where: buildConflictCheck(slug),
});
if (existing) {
let suffix = 1;
while (suffix < 100) {
const candidate = `${slug}-${suffix}`;
const conflict = await trx.query.agents.findFirst({
where: buildConflictCheck(candidate),
});
if (!conflict) {
slug = candidate;
break;
}
suffix++;
}
}
}
// 3. Build ownership update payload
const ownershipUpdate = {
userId: targetUserId,
workspaceId: targetWorkspaceId,
};
// 4. Update the agent record
await trx
.update(agents)
.set({ ...ownershipUpdate, slug, updatedAt: new Date() })
.where(eq(agents.id, agentId));
// 5. Update sessions linked via agentsToSessions
const links = await trx
.select({ sessionId: agentsToSessions.sessionId })
.from(agentsToSessions)
.where(eq(agentsToSessions.agentId, agentId));
const sessionIds = links.map((l) => l.sessionId);
if (sessionIds.length > 0) {
await trx.update(sessions).set(ownershipUpdate).where(inArray(sessions.id, sessionIds));
}
await trx
.update(agentsToSessions)
.set(ownershipUpdate)
.where(eq(agentsToSessions.agentId, agentId));
// 6. Update topics (linked via sessionId or agentId)
const topicCondition =
sessionIds.length > 0
? or(inArray(topics.sessionId, sessionIds), eq(topics.agentId, agentId))
: eq(topics.agentId, agentId);
await trx.update(topics).set(ownershipUpdate).where(topicCondition!);
// 7. Update messages (linked via sessionId or agentId)
const messageCondition =
sessionIds.length > 0
? or(inArray(messages.sessionId, sessionIds), eq(messages.agentId, agentId))
: eq(messages.agentId, agentId);
await trx.update(messages).set(ownershipUpdate).where(messageCondition!);
// 8. Update threads (linked via agentId)
await trx.update(threads).set(ownershipUpdate).where(eq(threads.agentId, agentId));
// 9. Update agent files associations
await trx.update(agentsFiles).set(ownershipUpdate).where(eq(agentsFiles.agentId, agentId));
// 10. Update agent knowledge base associations
await trx
.update(agentsKnowledgeBases)
.set(ownershipUpdate)
.where(eq(agentsKnowledgeBases.agentId, agentId));
// 11. Update agent cron jobs
await trx
.update(agentCronJobs)
.set(ownershipUpdate)
.where(eq(agentCronJobs.agentId, agentId));
// 12. Update agent bot providers (transfer, not delete)
await trx
.update(agentBotProviders)
.set(ownershipUpdate)
.where(eq(agentBotProviders.agentId, agentId));
// 13. Remove chat group associations (groups belong to source workspace context)
await trx.delete(chatGroupsAgents).where(eq(chatGroupsAgents.agentId, agentId));
return { agentId, slug };
});
};
}
@@ -3,6 +3,7 @@ import { and, desc, eq } from 'drizzle-orm';
import type { AgentBotProviderItem, NewAgentBotProvider } from '../schemas';
import { agentBotProviders } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
interface GateKeeper {
decrypt: (ciphertext: string) => Promise<{ plaintext: string }>;
@@ -16,14 +17,19 @@ export interface DecryptedBotProvider extends Omit<AgentBotProviderItem, 'creden
export class AgentBotProviderModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
private gateKeeper?: GateKeeper;
constructor(db: LobeChatDatabase, userId: string, gateKeeper?: GateKeeper) {
constructor(db: LobeChatDatabase, userId: string, gateKeeper?: GateKeeper, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
this.gateKeeper = gateKeeper;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentBotProviders);
// --------------- User-scoped CRUD ---------------
create = async (
@@ -35,7 +41,12 @@ export class AgentBotProviderModel {
const [result] = await this.db
.insert(agentBotProviders)
.values({ ...params, credentials, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params, credentials },
),
)
.returning();
return result;
@@ -44,11 +55,11 @@ export class AgentBotProviderModel {
delete = async (id: string) => {
return this.db
.delete(agentBotProviders)
.where(and(eq(agentBotProviders.id, id), eq(agentBotProviders.userId, this.userId)));
.where(and(eq(agentBotProviders.id, id), this.ownership()));
};
query = async (params?: { agentId?: string; platform?: string }) => {
const conditions = [eq(agentBotProviders.userId, this.userId)];
const conditions = [this.ownership()];
if (params?.agentId) {
conditions.push(eq(agentBotProviders.agentId, params.agentId));
@@ -70,7 +81,7 @@ export class AgentBotProviderModel {
const [result] = await this.db
.select()
.from(agentBotProviders)
.where(and(eq(agentBotProviders.id, id), eq(agentBotProviders.userId, this.userId)))
.where(and(eq(agentBotProviders.id, id), this.ownership()))
.limit(1);
if (!result) return result;
@@ -82,7 +93,7 @@ export class AgentBotProviderModel {
const results = await this.db
.select()
.from(agentBotProviders)
.where(and(eq(agentBotProviders.agentId, agentId), eq(agentBotProviders.userId, this.userId)))
.where(and(eq(agentBotProviders.agentId, agentId), this.ownership()))
.orderBy(desc(agentBotProviders.updatedAt));
return Promise.all(results.map((r) => this.decryptRow(r)));
@@ -104,7 +115,7 @@ export class AgentBotProviderModel {
return this.db
.update(agentBotProviders)
.set({ ...updateValue, updatedAt: new Date() })
.where(and(eq(agentBotProviders.id, id), eq(agentBotProviders.userId, this.userId)));
.where(and(eq(agentBotProviders.id, id), this.ownership()));
};
// --------------- System-wide static methods ---------------
@@ -139,7 +150,7 @@ export class AgentBotProviderModel {
and(
eq(agentBotProviders.platform, platform),
eq(agentBotProviders.applicationId, applicationId),
eq(agentBotProviders.userId, this.userId),
this.ownership(),
eq(agentBotProviders.enabled, true),
),
)
@@ -152,6 +163,88 @@ export class AgentBotProviderModel {
// --------------- System-wide static methods ---------------
/**
* System-wide lookup of an enabled provider by platform + applicationId.
*
* `(platform, applicationId)` is globally unique, so this returns the single
* matching row regardless of which user / workspace owns it. Use only from
* post-authorization runtime layers (gateway service / manager / connect-queue
* cron) where the caller has already been authorized at the router boundary —
* never as an authorization check itself.
*/
static findEnabledByPlatformAndAppId = async (
db: LobeChatDatabase,
platform: string,
applicationId: string,
gateKeeper?: GateKeeper,
): Promise<DecryptedBotProvider | null> => {
const [result] = await db
.select()
.from(agentBotProviders)
.where(
and(
eq(agentBotProviders.platform, platform),
eq(agentBotProviders.applicationId, applicationId),
eq(agentBotProviders.enabled, true),
),
)
.limit(1);
if (!result) return null;
if (!result.credentials) return { ...result, credentials: {} };
try {
const credentials = gateKeeper
? JSON.parse((await gateKeeper.decrypt(result.credentials)).plaintext)
: JSON.parse(result.credentials);
return { ...result, credentials };
} catch {
return { ...result, credentials: {} };
}
};
/**
* System-wide lookup of all providers under an agent.
*
* An agent belongs to a single owner / workspace, so this returns every row
* for the agent regardless of scope. Same authorization caveat as
* {@link findEnabledByPlatformAndAppId}: runtime-layer use only.
*/
static findByAgentId = async (
db: LobeChatDatabase,
agentId: string,
gateKeeper?: GateKeeper,
): Promise<DecryptedBotProvider[]> => {
const results = await db
.select()
.from(agentBotProviders)
.where(eq(agentBotProviders.agentId, agentId))
.orderBy(desc(agentBotProviders.updatedAt));
const decrypted: DecryptedBotProvider[] = [];
for (const r of results) {
if (!r.credentials) {
decrypted.push({ ...r, credentials: {} });
continue;
}
try {
const credentials = gateKeeper
? JSON.parse((await gateKeeper.decrypt(r.credentials)).plaintext)
: JSON.parse(r.credentials);
decrypted.push({ ...r, credentials });
} catch {
decrypted.push({ ...r, credentials: {} });
}
}
return decrypted;
};
static findEnabledByPlatform = async (
db: LobeChatDatabase,
platform: string,
+29 -20
View File
@@ -8,27 +8,36 @@ import type {
} from '../schemas/agentCronJob';
import { agentCronJobs } from '../schemas/agentCronJob';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class AgentCronJobModel {
private readonly userId: string;
private readonly db: LobeChatDatabase;
private readonly workspaceId?: string;
constructor(db: LobeChatDatabase, userId?: string) {
constructor(db: LobeChatDatabase, userId?: string, workspaceId?: string) {
this.db = db;
this.userId = userId!;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentCronJobs);
// Create a new cron job
async create(data: CreateAgentCronJobData): Promise<AgentCronJob> {
const cronJob = await this.db
.insert(agentCronJobs)
.values({
...data,
// Initialize remaining executions to match max executions
remainingExecutions: data.maxExecutions,
userId: this.userId,
} as NewAgentCronJob)
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...data,
// Initialize remaining executions to match max executions
remainingExecutions: data.maxExecutions,
},
) as NewAgentCronJob,
)
.returning();
return cronJob[0];
@@ -39,7 +48,7 @@ export class AgentCronJobModel {
const result = await this.db
.select()
.from(agentCronJobs)
.where(and(eq(agentCronJobs.id, id), eq(agentCronJobs.userId, this.userId)))
.where(and(eq(agentCronJobs.id, id), this.ownership()))
.limit(1);
return result[0] || null;
@@ -50,7 +59,7 @@ export class AgentCronJobModel {
return this.db
.select()
.from(agentCronJobs)
.where(and(eq(agentCronJobs.agentId, agentId), eq(agentCronJobs.userId, this.userId)))
.where(and(eq(agentCronJobs.agentId, agentId), this.ownership()))
.orderBy(desc(agentCronJobs.createdAt));
}
@@ -59,7 +68,7 @@ export class AgentCronJobModel {
return this.db
.select()
.from(agentCronJobs)
.where(eq(agentCronJobs.userId, this.userId))
.where(this.ownership())
.orderBy(desc(agentCronJobs.lastExecutedAt));
}
@@ -109,7 +118,7 @@ export class AgentCronJobModel {
const result = await this.db
.update(agentCronJobs)
.set(updateData)
.where(and(eq(agentCronJobs.id, id), eq(agentCronJobs.userId, this.userId)))
.where(and(eq(agentCronJobs.id, id), this.ownership()))
.returning();
return result[0] || null;
@@ -119,7 +128,7 @@ export class AgentCronJobModel {
async delete(id: string): Promise<boolean> {
const result = await this.db
.delete(agentCronJobs)
.where(and(eq(agentCronJobs.id, id), eq(agentCronJobs.userId, this.userId)))
.where(and(eq(agentCronJobs.id, id), this.ownership()))
.returning();
return result.length > 0;
@@ -181,7 +190,7 @@ export class AgentCronJobModel {
totalExecutions: 0,
updatedAt: new Date(),
})
.where(and(eq(agentCronJobs.id, id), eq(agentCronJobs.userId, this.userId)))
.where(and(eq(agentCronJobs.id, id), this.ownership()))
.returning();
return result[0] || null;
@@ -194,7 +203,7 @@ export class AgentCronJobModel {
.from(agentCronJobs)
.where(
and(
eq(agentCronJobs.userId, this.userId),
this.ownership(),
eq(agentCronJobs.enabled, true),
gt(agentCronJobs.remainingExecutions, 0),
sql`${agentCronJobs.remainingExecutions} <= ${threshold}`,
@@ -208,7 +217,7 @@ export class AgentCronJobModel {
return this.db
.select()
.from(agentCronJobs)
.where(and(eq(agentCronJobs.userId, this.userId), eq(agentCronJobs.enabled, enabled)))
.where(and(this.ownership(), eq(agentCronJobs.enabled, enabled)))
.orderBy(desc(agentCronJobs.updatedAt));
}
@@ -232,7 +241,7 @@ export class AgentCronJobModel {
totalJobs: sql<number>`count(*)`,
})
.from(agentCronJobs)
.where(eq(agentCronJobs.userId, this.userId));
.where(this.ownership());
const stats = result[0];
return {
@@ -251,7 +260,7 @@ export class AgentCronJobModel {
enabled,
updatedAt: new Date(),
})
.where(and(inArray(agentCronJobs.id, ids), eq(agentCronJobs.userId, this.userId)))
.where(and(inArray(agentCronJobs.id, ids), this.ownership()))
.returning();
return result.length;
@@ -262,7 +271,7 @@ export class AgentCronJobModel {
const result = await this.db
.select({ count: sql<number>`count(*)` })
.from(agentCronJobs)
.where(and(eq(agentCronJobs.agentId, agentId), eq(agentCronJobs.userId, this.userId)));
.where(and(eq(agentCronJobs.agentId, agentId), this.ownership()));
return Number(result[0].count);
}
@@ -276,7 +285,7 @@ export class AgentCronJobModel {
}): Promise<{ jobs: AgentCronJob[]; total: number }> {
const { agentId, enabled, limit = 20, offset = 0 } = options;
const whereConditions = [eq(agentCronJobs.userId, this.userId)];
const whereConditions = [this.ownership()];
if (agentId) {
whereConditions.push(eq(agentCronJobs.agentId, agentId));
@@ -4,6 +4,7 @@ import { and, asc, desc, eq, inArray, isNotNull, isNull, like, or, sql } from 'd
import type { DocumentItem, NewAgentDocument, NewDocument } from '../../schemas';
import { AGENT_SKILL_TEMPLATE_ID, agentDocuments, documents } from '../../schemas';
import type { LobeChatDatabase, Transaction } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
import { deriveAgentDocumentFields } from './deriveFields';
import { buildDocumentFilename } from './filename';
import {
@@ -71,13 +72,31 @@ interface ConvertAgentDocumentToSkillIndexParams {
export class AgentDocumentModel {
private userId: string;
private workspaceId?: string;
private db: LobeChatDatabase;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.workspaceId = workspaceId;
this.db = db;
}
/**
* Workspace-aware ownership predicate for the `agent_documents` binding table.
* Personal mode → `user_id = ? AND workspace_id IS NULL`; workspace mode → `workspace_id = ?`.
*/
private agentDocOwnership() {
return buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
agentDocuments,
);
}
/** Workspace-aware ownership predicate for the backing `documents` rows. */
private documentOwnership() {
return buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents);
}
private getDocumentStats(content: string) {
if (!content) return { totalCharCount: 0, totalLineCount: 0 };
@@ -175,7 +194,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
inArray(documents.parentId, parentIds),
...this.buildDeletedAtFilters(options),
@@ -212,7 +231,7 @@ export class AgentDocumentModel {
const [doc] = await trx
.select()
.from(documents)
.where(and(eq(documents.id, documentId), eq(documents.userId, this.userId)))
.where(and(eq(documents.id, documentId), this.documentOwnership()))
.limit(1);
if (!doc) return { id: '' };
@@ -235,6 +254,7 @@ export class AgentDocumentModel {
policyLoadPosition: DocumentLoadPosition.BEFORE_FIRST_USER,
policyLoadRule: DocumentLoadRule.ALWAYS,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})
.onConflictDoNothing()
.returning({ id: agentDocuments.id });
@@ -332,6 +352,7 @@ export class AgentDocumentModel {
totalLineCount: stats.totalLineCount,
updatedAt: updatedAt ?? createdAt,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
};
const [insertedDocument] = await trx.insert(documents).values(documentPayload).returning();
@@ -361,6 +382,7 @@ export class AgentDocumentModel {
templateId,
updatedAt: updatedAt ?? createdAt,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
};
const [settings] = await trx.insert(agentDocuments).values(newDoc).returning();
@@ -414,7 +436,7 @@ export class AgentDocumentModel {
.where(
and(
eq(agentDocuments.id, params.agentDocumentId),
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
isNull(agentDocuments.deletedAt),
),
)
@@ -445,7 +467,7 @@ export class AgentDocumentModel {
totalLineCount: stats.totalLineCount,
updatedAt,
})
.where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId)));
.where(and(eq(documents.id, existing.documentId), this.documentOwnership()));
await trx
.update(agentDocuments)
@@ -457,7 +479,7 @@ export class AgentDocumentModel {
.where(
and(
eq(agentDocuments.id, params.agentDocumentId),
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
isNull(agentDocuments.deletedAt),
),
);
@@ -469,7 +491,7 @@ export class AgentDocumentModel {
.where(
and(
eq(agentDocuments.id, params.agentDocumentId),
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
isNull(agentDocuments.deletedAt),
),
)
@@ -559,13 +581,13 @@ export class AgentDocumentModel {
await trx
.update(documents)
.set(documentUpdate)
.where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId)));
.where(and(eq(documents.id, existing.documentId), this.documentOwnership()));
}
await trx
.update(agentDocuments)
.set(settingsUpdate)
.where(and(eq(agentDocuments.id, documentId), eq(agentDocuments.userId, this.userId)));
.where(and(eq(agentDocuments.id, documentId), this.agentDocOwnership()));
});
}
@@ -616,7 +638,7 @@ export class AgentDocumentModel {
...(params.parentId !== undefined && { parentId: params.parentId }),
...(params.title !== undefined && { title: params.title }),
})
.where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId)));
.where(and(eq(documents.id, existing.documentId), this.documentOwnership()));
return this.findById(agentDocumentId);
}
@@ -658,7 +680,7 @@ export class AgentDocumentModel {
source,
title,
})
.where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId)));
.where(and(eq(documents.id, existing.documentId), this.documentOwnership()));
});
return this.findById(documentId);
@@ -701,7 +723,7 @@ export class AgentDocumentModel {
source,
title: filename,
})
.where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId)));
.where(and(eq(documents.id, existing.documentId), this.documentOwnership()));
});
return this.findById(documentId);
@@ -749,7 +771,7 @@ export class AgentDocumentModel {
.where(
and(
eq(agentDocuments.id, documentId),
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
isNull(agentDocuments.deletedAt),
),
);
@@ -775,7 +797,7 @@ export class AgentDocumentModel {
.where(
and(
eq(agentDocuments.id, documentId),
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
...this.buildDeletedAtFilters(options),
),
)
@@ -871,7 +893,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
isNull(agentDocuments.deletedAt),
),
@@ -895,7 +917,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
isNull(agentDocuments.deletedAt),
or(
@@ -958,7 +980,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
isNull(agentDocuments.deletedAt),
),
@@ -1013,7 +1035,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
inArray(agentDocuments.documentId, documentIds),
isNull(agentDocuments.deletedAt),
@@ -1037,7 +1059,7 @@ export class AgentDocumentModel {
.from(agentDocuments)
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
isNull(agentDocuments.deletedAt),
),
@@ -1054,7 +1076,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
eq(agentDocuments.templateId, templateId),
isNull(agentDocuments.deletedAt),
@@ -1083,7 +1105,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
eq(documents.filename, filename),
...this.buildDeletedAtFilters(options),
@@ -1109,7 +1131,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
eq(documents.filename, filename),
parentId ? eq(documents.parentId, parentId) : isNull(documents.parentId),
@@ -1149,7 +1171,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
eq(documents.filename, filename),
parentId ? eq(documents.parentId, parentId) : isNull(documents.parentId),
@@ -1174,7 +1196,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
eq(agentDocuments.documentId, documentId),
...this.buildDeletedAtFilters(options),
@@ -1199,7 +1221,7 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
parentId ? eq(documents.parentId, parentId) : isNull(documents.parentId),
...this.buildDeletedAtFilters(options),
@@ -1219,9 +1241,9 @@ export class AgentDocumentModel {
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
eq(agentDocuments.agentId, agentId),
eq(documents.userId, this.userId),
this.documentOwnership(),
isNotNull(agentDocuments.deletedAt),
),
)
@@ -1270,7 +1292,7 @@ export class AgentDocumentModel {
.where(
and(
eq(agentDocuments.id, documentId),
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
isNull(agentDocuments.deletedAt),
),
);
@@ -1296,7 +1318,7 @@ export class AgentDocumentModel {
})
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
inArray(
agentDocuments.id,
subtree.map((item) => item.id),
@@ -1336,7 +1358,7 @@ export class AgentDocumentModel {
deletedByUserId: null,
policyLoad: PolicyLoad.PROGRESSIVE,
})
.where(and(eq(agentDocuments.id, documentId), eq(agentDocuments.userId, this.userId)));
.where(and(eq(agentDocuments.id, documentId), this.agentDocOwnership()));
});
}
@@ -1358,7 +1380,7 @@ export class AgentDocumentModel {
})
.where(
and(
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
inArray(
agentDocuments.id,
subtree.map((item) => item.id),
@@ -1375,11 +1397,11 @@ export class AgentDocumentModel {
await this.db.transaction(async (trx) => {
await trx
.delete(agentDocuments)
.where(and(eq(agentDocuments.id, documentId), eq(agentDocuments.userId, this.userId)));
.where(and(eq(agentDocuments.id, documentId), this.agentDocOwnership()));
await trx
.delete(documents)
.where(and(eq(documents.id, existing.documentId), eq(documents.userId, this.userId)));
.where(and(eq(documents.id, existing.documentId), this.documentOwnership()));
});
}
@@ -1399,13 +1421,11 @@ export class AgentDocumentModel {
await this.db.transaction(async (trx) => {
await trx
.delete(agentDocuments)
.where(
and(eq(agentDocuments.userId, this.userId), inArray(agentDocuments.id, agentDocumentIds)),
);
.where(and(this.agentDocOwnership(), inArray(agentDocuments.id, agentDocumentIds)));
await trx
.delete(documents)
.where(and(eq(documents.userId, this.userId), inArray(documents.id, documentIds)));
.where(and(this.documentOwnership(), inArray(documents.id, documentIds)));
});
}
@@ -1423,7 +1443,7 @@ export class AgentDocumentModel {
.where(
and(
eq(agentDocuments.agentId, agentId),
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
isNull(agentDocuments.deletedAt),
),
);
@@ -1447,7 +1467,7 @@ export class AgentDocumentModel {
and(
eq(agentDocuments.agentId, agentId),
eq(agentDocuments.templateId, templateId),
eq(agentDocuments.userId, this.userId),
this.agentDocOwnership(),
isNull(agentDocuments.deletedAt),
),
);
@@ -8,23 +8,46 @@ import {
type NewAgentEvalBenchmark,
} from '../../schemas';
import { type LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class AgentEvalBenchmarkModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
/**
* Ownership predicate: rows the current actor can see/edit. Includes
* workspace-scoped or personal rows AND system rows (`userId IS NULL`).
*/
private ownership = () =>
or(
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
agentEvalBenchmarks,
),
isNull(agentEvalBenchmarks.userId),
);
/** Mutate-only predicate excluding system rows. */
private mutableOwnership = () =>
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
agentEvalBenchmarks,
);
/**
* Create a new benchmark
*/
create = async (params: NewAgentEvalBenchmark) => {
const [result] = await this.db
.insert(agentEvalBenchmarks)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return result;
};
@@ -39,23 +62,20 @@ export class AgentEvalBenchmarkModel {
and(
eq(agentEvalBenchmarks.id, id),
eq(agentEvalBenchmarks.isSystem, false),
eq(agentEvalBenchmarks.userId, this.userId),
this.mutableOwnership(),
),
);
};
/**
* Query benchmarks (system + user-created)
* Query benchmarks (system + user/workspace-created)
* @param includeSystem - Whether to include system benchmarks (default: true)
*/
query = async (includeSystem = true) => {
const userCondition = or(
eq(agentEvalBenchmarks.userId, this.userId),
isNull(agentEvalBenchmarks.userId),
);
const userCondition = this.ownership();
const conditions = includeSystem
? userCondition
: and(eq(agentEvalBenchmarks.isSystem, false), userCondition);
: and(eq(agentEvalBenchmarks.isSystem, false), this.ownership());
const datasetCountSq = this.db
.select({
@@ -63,6 +83,12 @@ export class AgentEvalBenchmarkModel {
count: count().as('dataset_count'),
})
.from(agentEvalDatasets)
.where(
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
agentEvalDatasets,
),
)
.groupBy(agentEvalDatasets.benchmarkId)
.as('dc');
@@ -73,6 +99,12 @@ export class AgentEvalBenchmarkModel {
})
.from(agentEvalTestCases)
.innerJoin(agentEvalDatasets, eq(agentEvalTestCases.datasetId, agentEvalDatasets.id))
.where(
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
agentEvalDatasets,
),
)
.groupBy(agentEvalDatasets.benchmarkId)
.as('tc');
@@ -83,7 +115,9 @@ export class AgentEvalBenchmarkModel {
})
.from(agentEvalRuns)
.innerJoin(agentEvalDatasets, eq(agentEvalRuns.datasetId, agentEvalDatasets.id))
.where(eq(agentEvalRuns.userId, this.userId))
.where(
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalRuns),
)
.groupBy(agentEvalDatasets.benchmarkId)
.as('rc');
@@ -109,7 +143,13 @@ export class AgentEvalBenchmarkModel {
.from(agentEvalRuns)
.innerJoin(agentEvalDatasets, eq(agentEvalRuns.datasetId, agentEvalDatasets.id))
.where(
and(eq(agentEvalDatasets.benchmarkId, row.id), eq(agentEvalRuns.userId, this.userId)),
and(
eq(agentEvalDatasets.benchmarkId, row.id),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
agentEvalRuns,
),
),
)
.orderBy(desc(agentEvalRuns.createdAt))
.limit(5);
@@ -144,12 +184,7 @@ export class AgentEvalBenchmarkModel {
const [result] = await this.db
.select()
.from(agentEvalBenchmarks)
.where(
and(
eq(agentEvalBenchmarks.id, id),
or(eq(agentEvalBenchmarks.userId, this.userId), isNull(agentEvalBenchmarks.userId)),
),
)
.where(and(eq(agentEvalBenchmarks.id, id), this.ownership()))
.limit(1);
return result;
};
@@ -161,12 +196,7 @@ export class AgentEvalBenchmarkModel {
const [result] = await this.db
.select()
.from(agentEvalBenchmarks)
.where(
and(
eq(agentEvalBenchmarks.identifier, identifier),
or(eq(agentEvalBenchmarks.userId, this.userId), isNull(agentEvalBenchmarks.userId)),
),
)
.where(and(eq(agentEvalBenchmarks.identifier, identifier), this.ownership()))
.limit(1);
return result;
};
@@ -182,7 +212,7 @@ export class AgentEvalBenchmarkModel {
and(
eq(agentEvalBenchmarks.id, id),
eq(agentEvalBenchmarks.isSystem, false),
eq(agentEvalBenchmarks.userId, this.userId),
this.mutableOwnership(),
),
)
.returning();
@@ -2,23 +2,40 @@ import { and, asc, count, desc, eq, isNull, or } from 'drizzle-orm';
import { agentEvalDatasets, agentEvalTestCases, type NewAgentEvalDataset } from '../../schemas';
import { type LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class AgentEvalDatasetModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
/** Includes system datasets (`userId IS NULL`) on read. */
private ownership = () =>
or(
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
agentEvalDatasets,
),
isNull(agentEvalDatasets.userId),
);
/** Mutate-only predicate excluding system rows. */
private mutableOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalDatasets);
/**
* Create a new dataset
*/
create = async (params: NewAgentEvalDataset) => {
const [result] = await this.db
.insert(agentEvalDatasets)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return result;
};
@@ -29,17 +46,15 @@ export class AgentEvalDatasetModel {
delete = async (id: string) => {
return this.db
.delete(agentEvalDatasets)
.where(and(eq(agentEvalDatasets.id, id), eq(agentEvalDatasets.userId, this.userId)));
.where(and(eq(agentEvalDatasets.id, id), this.mutableOwnership()));
};
/**
* Query datasets (system + user-owned) with test case counts
* Query datasets (system + user/workspace-owned) with test case counts
* @param benchmarkId - Optional benchmark filter
*/
query = async (benchmarkId?: string) => {
const conditions = [
or(eq(agentEvalDatasets.userId, this.userId), isNull(agentEvalDatasets.userId)),
];
const conditions = [this.ownership()];
if (benchmarkId) {
conditions.push(eq(agentEvalDatasets.benchmarkId, benchmarkId));
@@ -74,12 +89,7 @@ export class AgentEvalDatasetModel {
const [dataset] = await this.db
.select()
.from(agentEvalDatasets)
.where(
and(
eq(agentEvalDatasets.id, id),
or(eq(agentEvalDatasets.userId, this.userId), isNull(agentEvalDatasets.userId)),
),
)
.where(and(eq(agentEvalDatasets.id, id), this.ownership()))
.limit(1);
if (!dataset) return undefined;
@@ -100,7 +110,7 @@ export class AgentEvalDatasetModel {
const [result] = await this.db
.update(agentEvalDatasets)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(agentEvalDatasets.id, id), eq(agentEvalDatasets.userId, this.userId)))
.where(and(eq(agentEvalDatasets.id, id), this.mutableOwnership()))
.returning();
return result;
};
+13 -9
View File
@@ -2,23 +2,29 @@ import { and, count, desc, eq, inArray } from 'drizzle-orm';
import { agentEvalDatasets, agentEvalRuns, type NewAgentEvalRun } from '../../schemas';
import { type LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class AgentEvalRunModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalRuns);
/**
* Create a new run
*/
create = async (params: Omit<NewAgentEvalRun, 'userId'>) => {
const [result] = await this.db
.insert(agentEvalRuns)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return result;
};
@@ -33,7 +39,7 @@ export class AgentEvalRunModel {
offset?: number;
status?: 'idle' | 'pending' | 'running' | 'completed' | 'failed' | 'aborted' | 'external';
}) => {
const conditions = [eq(agentEvalRuns.userId, this.userId)];
const conditions = [this.ownership()];
if (filter?.datasetId) {
conditions.push(eq(agentEvalRuns.datasetId, filter.datasetId));
@@ -77,7 +83,7 @@ export class AgentEvalRunModel {
const [result] = await this.db
.select()
.from(agentEvalRuns)
.where(and(eq(agentEvalRuns.id, id), eq(agentEvalRuns.userId, this.userId)))
.where(and(eq(agentEvalRuns.id, id), this.ownership()))
.limit(1);
return result;
};
@@ -89,7 +95,7 @@ export class AgentEvalRunModel {
const [result] = await this.db
.update(agentEvalRuns)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(agentEvalRuns.id, id), eq(agentEvalRuns.userId, this.userId)))
.where(and(eq(agentEvalRuns.id, id), this.ownership()))
.returning();
return result;
};
@@ -98,9 +104,7 @@ export class AgentEvalRunModel {
* Delete run (only user-created runs)
*/
delete = async (id: string) => {
return this.db
.delete(agentEvalRuns)
.where(and(eq(agentEvalRuns.id, id), eq(agentEvalRuns.userId, this.userId)));
return this.db.delete(agentEvalRuns).where(and(eq(agentEvalRuns.id, id), this.ownership()));
};
/**
@@ -110,7 +114,7 @@ export class AgentEvalRunModel {
const result = await this.db
.select({ value: count() })
.from(agentEvalRuns)
.where(and(eq(agentEvalRuns.datasetId, datasetId), eq(agentEvalRuns.userId, this.userId)));
.where(and(eq(agentEvalRuns.datasetId, datasetId), this.ownership()));
return Number(result[0]?.value) || 0;
};
}
@@ -9,22 +9,32 @@ import {
topics,
} from '../../schemas';
import { type LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class AgentEvalRunTopicModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalRunTopics);
/**
* Batch create run-topic associations
*/
batchCreate = async (items: Omit<NewAgentEvalRunTopic, 'userId'>[]) => {
if (items.length === 0) return [];
const withUserId = items.map((item) => ({ ...item, userId: this.userId }));
const withUserId = items.map((item) => ({
...item,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
}));
return this.db.insert(agentEvalRunTopics).values(withUserId).returning();
};
@@ -48,7 +58,7 @@ export class AgentEvalRunTopicModel {
.from(agentEvalRunTopics)
.leftJoin(agentEvalTestCases, eq(agentEvalRunTopics.testCaseId, agentEvalTestCases.id))
.leftJoin(topics, eq(agentEvalRunTopics.topicId, topics.id))
.where(and(eq(agentEvalRunTopics.runId, runId), eq(agentEvalRunTopics.userId, this.userId)))
.where(and(eq(agentEvalRunTopics.runId, runId), this.ownership()))
.orderBy(asc(agentEvalTestCases.sortOrder));
return rows;
@@ -60,7 +70,7 @@ export class AgentEvalRunTopicModel {
deleteByRunId = async (runId: string) => {
return this.db
.delete(agentEvalRunTopics)
.where(and(eq(agentEvalRunTopics.runId, runId), eq(agentEvalRunTopics.userId, this.userId)));
.where(and(eq(agentEvalRunTopics.runId, runId), this.ownership()));
};
/**
@@ -82,12 +92,7 @@ export class AgentEvalRunTopicModel {
.from(agentEvalRunTopics)
.leftJoin(agentEvalRuns, eq(agentEvalRunTopics.runId, agentEvalRuns.id))
.leftJoin(topics, eq(agentEvalRunTopics.topicId, topics.id))
.where(
and(
eq(agentEvalRunTopics.testCaseId, testCaseId),
eq(agentEvalRunTopics.userId, this.userId),
),
)
.where(and(eq(agentEvalRunTopics.testCaseId, testCaseId), this.ownership()))
.orderBy(desc(agentEvalRunTopics.createdAt));
return rows;
@@ -117,7 +122,7 @@ export class AgentEvalRunTopicModel {
and(
eq(agentEvalRunTopics.runId, runId),
eq(agentEvalRunTopics.testCaseId, testCaseId),
eq(agentEvalRunTopics.userId, this.userId),
this.ownership(),
),
)
.limit(1);
@@ -136,7 +141,7 @@ export class AgentEvalRunTopicModel {
.set({ status: 'error', evalResult: { error: 'Aborted' } })
.where(
and(
eq(agentEvalRunTopics.userId, this.userId),
this.ownership(),
eq(agentEvalRunTopics.runId, runId),
or(eq(agentEvalRunTopics.status, 'pending'), eq(agentEvalRunTopics.status, 'running')),
),
@@ -151,7 +156,7 @@ export class AgentEvalRunTopicModel {
.set({ status: 'timeout' })
.where(
and(
eq(agentEvalRunTopics.userId, this.userId),
this.ownership(),
eq(agentEvalRunTopics.runId, runId),
eq(agentEvalRunTopics.status, 'running'),
lt(agentEvalRunTopics.createdAt, deadline),
@@ -165,7 +170,7 @@ export class AgentEvalRunTopicModel {
.delete(agentEvalRunTopics)
.where(
and(
eq(agentEvalRunTopics.userId, this.userId),
this.ownership(),
eq(agentEvalRunTopics.runId, runId),
eq(agentEvalRunTopics.testCaseId, testCaseId),
),
@@ -181,7 +186,7 @@ export class AgentEvalRunTopicModel {
.delete(agentEvalRunTopics)
.where(
and(
eq(agentEvalRunTopics.userId, this.userId),
this.ownership(),
eq(agentEvalRunTopics.runId, runId),
or(eq(agentEvalRunTopics.status, 'error'), eq(agentEvalRunTopics.status, 'timeout')),
),
@@ -205,7 +210,7 @@ export class AgentEvalRunTopicModel {
.set(value)
.where(
and(
eq(agentEvalRunTopics.userId, this.userId),
this.ownership(),
eq(agentEvalRunTopics.runId, runId),
eq(agentEvalRunTopics.topicId, topicId),
),
@@ -2,21 +2,31 @@ import { and, count, eq, sql } from 'drizzle-orm';
import { agentEvalTestCases, type NewAgentEvalTestCase } from '../../schemas';
import { type LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class AgentEvalTestCaseModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentEvalTestCases);
/**
* Create a single test case
*/
create = async (params: Omit<NewAgentEvalTestCase, 'userId'>) => {
let finalParams: NewAgentEvalTestCase = { ...params, userId: this.userId };
let finalParams: NewAgentEvalTestCase = {
...params,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
};
if (finalParams.sortOrder === undefined || finalParams.sortOrder === null) {
const [maxResult] = await this.db
@@ -35,7 +45,11 @@ export class AgentEvalTestCaseModel {
* Batch create test cases
*/
batchCreate = async (cases: Omit<NewAgentEvalTestCase, 'userId'>[]) => {
const withUserId = cases.map((c) => ({ ...c, userId: this.userId }));
const withUserId = cases.map((c) => ({
...c,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
}));
return this.db.insert(agentEvalTestCases).values(withUserId).returning();
};
@@ -45,7 +59,7 @@ export class AgentEvalTestCaseModel {
delete = async (id: string) => {
return this.db
.delete(agentEvalTestCases)
.where(and(eq(agentEvalTestCases.id, id), eq(agentEvalTestCases.userId, this.userId)));
.where(and(eq(agentEvalTestCases.id, id), this.ownership()));
};
/**
@@ -55,7 +69,7 @@ export class AgentEvalTestCaseModel {
const [result] = await this.db
.select()
.from(agentEvalTestCases)
.where(and(eq(agentEvalTestCases.id, id), eq(agentEvalTestCases.userId, this.userId)))
.where(and(eq(agentEvalTestCases.id, id), this.ownership()))
.limit(1);
return result;
};
@@ -67,12 +81,7 @@ export class AgentEvalTestCaseModel {
const query = this.db
.select()
.from(agentEvalTestCases)
.where(
and(
eq(agentEvalTestCases.datasetId, datasetId),
eq(agentEvalTestCases.userId, this.userId),
),
)
.where(and(eq(agentEvalTestCases.datasetId, datasetId), this.ownership()))
.orderBy(agentEvalTestCases.sortOrder);
if (limit !== undefined) {
@@ -92,12 +101,7 @@ export class AgentEvalTestCaseModel {
const result = await this.db
.select({ value: count() })
.from(agentEvalTestCases)
.where(
and(
eq(agentEvalTestCases.datasetId, datasetId),
eq(agentEvalTestCases.userId, this.userId),
),
);
.where(and(eq(agentEvalTestCases.datasetId, datasetId), this.ownership()));
return Number(result[0]?.value) || 0;
};
@@ -108,7 +112,7 @@ export class AgentEvalTestCaseModel {
const [result] = await this.db
.update(agentEvalTestCases)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(agentEvalTestCases.id, id), eq(agentEvalTestCases.userId, this.userId)))
.where(and(eq(agentEvalTestCases.id, id), this.ownership()))
.returning();
return result;
};
+11 -4
View File
@@ -11,6 +11,7 @@ import type {
} from '../schemas/agentOperations';
import { agentOperations } from '../schemas/agentOperations';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
/** Verify rollup states, mirrors the `verify_status` enum column. */
export type VerifyStatus =
@@ -80,12 +81,17 @@ export interface RecordOperationCompletionParams {
export class AgentOperationModel {
private readonly db: LobeChatDatabase;
private readonly userId: string;
private readonly workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentOperations);
/**
* Insert the initial row when an operation is created. Idempotent via
* `onConflictDoNothing` on the primary key so resumed operations don't
@@ -110,6 +116,7 @@ export class AgentOperationModel {
topicId: params.topicId ?? null,
trigger: params.trigger,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
};
await this.db.insert(agentOperations).values(values).onConflictDoNothing();
@@ -151,14 +158,14 @@ export class AgentOperationModel {
await this.db
.update(agentOperations)
.set(updates)
.where(and(eq(agentOperations.id, operationId), eq(agentOperations.userId, this.userId)));
.where(and(eq(agentOperations.id, operationId), this.ownership()));
}
async findById(operationId: string) {
const [row] = await this.db
.select()
.from(agentOperations)
.where(and(eq(agentOperations.id, operationId), eq(agentOperations.userId, this.userId)))
.where(and(eq(agentOperations.id, operationId), this.ownership()))
.limit(1);
return row ?? null;
}
@@ -182,7 +189,7 @@ export class AgentOperationModel {
.from(agentOperations)
.where(
and(
eq(agentOperations.userId, this.userId),
this.ownership(),
isNotNull(agentOperations.startedAt),
isNotNull(agentOperations.completedAt),
gte(agentOperations.createdAt, startDate),
@@ -192,16 +192,27 @@ export class AgentSignalNightlyReviewModel {
topicCount: countDistinct(messages.topicId),
})
.from(messages)
.leftJoin(topics, and(eq(topics.id, messages.topicId), eq(topics.userId, userId)))
.innerJoin(agents, and(eq(agents.id, effectiveAgentId), eq(agents.userId, userId)))
.leftJoin(
topics,
and(eq(topics.id, messages.topicId), eq(topics.userId, userId), isNull(topics.workspaceId)),
)
.innerJoin(
agents,
and(eq(agents.id, effectiveAgentId), eq(agents.userId, userId), isNull(agents.workspaceId)),
)
.leftJoin(userSettings, eq(userSettings.id, userId))
.leftJoin(
messagePlugins,
and(eq(messagePlugins.id, messages.id), eq(messagePlugins.userId, userId)),
and(
eq(messagePlugins.id, messages.id),
eq(messagePlugins.userId, userId),
isNull(messagePlugins.workspaceId),
),
)
.where(
and(
eq(messages.userId, userId),
isNull(messages.workspaceId),
agentFilter,
gte(messages.createdAt, options.windowStart),
lte(messages.createdAt, options.windowEnd),
@@ -1,5 +1,6 @@
import { INBOX_SESSION_ID } from '@lobechat/const';
import { and, count, desc, eq, gte, isNull, lte, or, sql } from 'drizzle-orm';
import type { AnyPgColumn } from 'drizzle-orm/pg-core';
import {
agentDocuments,
@@ -11,6 +12,7 @@ import {
userMemories,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
const parseAggregateTimestamp = (value: Date | string) =>
value instanceof Date ? value : new Date(value);
@@ -99,12 +101,17 @@ export interface AgentSignalDocumentActivityRow {
export class AgentSignalReviewContextModel {
private readonly db: LobeChatDatabase;
private readonly userId: string;
private readonly workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ws = (cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }) =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, cols);
/** Checks agent ownership, virtual status, and self-iteration opt-in. */
canAgentRunSelfIteration = async (agentId: string) => {
const [agent] = await this.db
@@ -113,7 +120,7 @@ export class AgentSignalReviewContextModel {
.where(
and(
eq(agents.id, agentId),
eq(agents.userId, this.userId),
this.ws(agents),
or(eq(agents.virtual, false), isNull(agents.virtual), eq(agents.slug, INBOX_SESSION_ID)),
or(
eq(agents.slug, INBOX_SESSION_ID),
@@ -183,14 +190,11 @@ export class AgentSignalReviewContextModel {
totalCount: count(messagePlugins.id),
})
.from(messagePlugins)
.innerJoin(
messages,
and(eq(messages.id, messagePlugins.id), eq(messages.userId, this.userId)),
)
.leftJoin(topics, and(eq(topics.id, messages.topicId), eq(topics.userId, this.userId)))
.innerJoin(messages, and(eq(messages.id, messagePlugins.id), this.ws(messages)))
.leftJoin(topics, and(eq(topics.id, messages.topicId), this.ws(topics)))
.where(
and(
eq(messagePlugins.userId, this.userId),
this.ws(messagePlugins),
eq(effectiveAgentId, options.agentId),
gte(messages.createdAt, options.windowStart),
lte(messages.createdAt, options.windowEnd),
@@ -220,13 +224,10 @@ export class AgentSignalReviewContextModel {
updatedAt: agentDocuments.updatedAt,
})
.from(agentDocuments)
.innerJoin(
documents,
and(eq(documents.id, agentDocuments.documentId), eq(documents.userId, this.userId)),
)
.innerJoin(documents, and(eq(documents.id, agentDocuments.documentId), this.ws(documents)))
.where(
and(
eq(agentDocuments.userId, this.userId),
this.ws(agentDocuments),
eq(agentDocuments.agentId, options.agentId),
isNull(agentDocuments.deletedAt),
gte(agentDocuments.updatedAt, options.windowStart),
@@ -285,14 +286,11 @@ export class AgentSignalReviewContextModel {
topicId: topics.id,
})
.from(messages)
.leftJoin(topics, and(eq(topics.id, messages.topicId), eq(topics.userId, this.userId)))
.leftJoin(
messagePlugins,
and(eq(messagePlugins.id, messages.id), eq(messagePlugins.userId, this.userId)),
)
.leftJoin(topics, and(eq(topics.id, messages.topicId), this.ws(topics)))
.leftJoin(messagePlugins, and(eq(messagePlugins.id, messages.id), this.ws(messagePlugins)))
.where(
and(
eq(messages.userId, this.userId),
this.ws(messages),
eq(effectiveAgentId, options.agentId),
gte(messages.createdAt, options.windowStart),
lte(messages.createdAt, options.windowEnd),
@@ -349,14 +347,11 @@ export class AgentSignalReviewContextModel {
topicId: topics.id,
})
.from(messages)
.leftJoin(topics, and(eq(topics.id, messages.topicId), eq(topics.userId, this.userId)))
.leftJoin(
messagePlugins,
and(eq(messagePlugins.id, messages.id), eq(messagePlugins.userId, this.userId)),
)
.leftJoin(topics, and(eq(topics.id, messages.topicId), this.ws(topics)))
.leftJoin(messagePlugins, and(eq(messagePlugins.id, messages.id), this.ws(messagePlugins)))
.where(
and(
eq(messages.userId, this.userId),
this.ws(messages),
eq(messages.agentId, options.agentId),
gte(messages.createdAt, options.windowStart),
lte(messages.createdAt, options.windowEnd),
+19 -13
View File
@@ -2,9 +2,10 @@ import type { SkillItem, SkillListItem } from '@lobechat/types';
import { merge } from '@lobechat/utils';
import { and, desc, eq, ilike, inArray, or } from 'drizzle-orm';
import type {NewAgentSkill } from '../schemas';
import type { NewAgentSkill } from '../schemas';
import { agentSkills } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
const skillItemColumns = {
content: agentSkills.content,
@@ -35,19 +36,24 @@ const skillListColumns = {
export class AgentSkillModel {
private userId: string;
private workspaceId?: string;
private db: LobeChatDatabase;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private scopeWhere = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentSkills);
// ========== Create ==========
create = async (data: Omit<NewAgentSkill, 'userId'>): Promise<SkillItem> => {
create = async (data: Omit<NewAgentSkill, 'userId' | 'workspaceId'>): Promise<SkillItem> => {
const [result] = await this.db
.insert(agentSkills)
.values({ ...data, userId: this.userId })
.values(buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, data))
.returning(skillItemColumns);
return result;
};
@@ -58,7 +64,7 @@ export class AgentSkillModel {
const [result] = await this.db
.select(skillItemColumns)
.from(agentSkills)
.where(and(eq(agentSkills.id, id), eq(agentSkills.userId, this.userId)))
.where(and(eq(agentSkills.id, id), this.scopeWhere()))
.limit(1);
return result;
};
@@ -67,7 +73,7 @@ export class AgentSkillModel {
const [result] = await this.db
.select(skillItemColumns)
.from(agentSkills)
.where(and(eq(agentSkills.identifier, identifier), eq(agentSkills.userId, this.userId)))
.where(and(eq(agentSkills.identifier, identifier), this.scopeWhere()))
.limit(1);
return result;
};
@@ -76,7 +82,7 @@ export class AgentSkillModel {
const [result] = await this.db
.select(skillItemColumns)
.from(agentSkills)
.where(and(eq(agentSkills.name, name), eq(agentSkills.userId, this.userId)))
.where(and(eq(agentSkills.name, name), this.scopeWhere()))
.limit(1);
return result;
};
@@ -85,7 +91,7 @@ export class AgentSkillModel {
const data = await this.db
.select(skillListColumns)
.from(agentSkills)
.where(eq(agentSkills.userId, this.userId))
.where(this.scopeWhere())
.orderBy(desc(agentSkills.updatedAt));
return { data, total: data.length };
@@ -96,7 +102,7 @@ export class AgentSkillModel {
return this.db
.select(skillItemColumns)
.from(agentSkills)
.where(and(inArray(agentSkills.id, ids), eq(agentSkills.userId, this.userId)));
.where(and(inArray(agentSkills.id, ids), this.scopeWhere()));
};
listBySource = async (
@@ -105,7 +111,7 @@ export class AgentSkillModel {
const data = await this.db
.select(skillListColumns)
.from(agentSkills)
.where(and(eq(agentSkills.source, source), eq(agentSkills.userId, this.userId)))
.where(and(eq(agentSkills.source, source), this.scopeWhere()))
.orderBy(desc(agentSkills.updatedAt));
return { data, total: data.length };
@@ -117,7 +123,7 @@ export class AgentSkillModel {
.from(agentSkills)
.where(
and(
eq(agentSkills.userId, this.userId),
this.scopeWhere(),
or(ilike(agentSkills.name, `%${query}%`), ilike(agentSkills.description, `%${query}%`)),
),
)
@@ -136,7 +142,7 @@ export class AgentSkillModel {
const [result] = await this.db
.update(agentSkills)
.set(updateData)
.where(and(eq(agentSkills.id, id), eq(agentSkills.userId, this.userId)))
.where(and(eq(agentSkills.id, id), this.scopeWhere()))
.returning(skillItemColumns);
return result;
};
@@ -146,7 +152,7 @@ export class AgentSkillModel {
delete = async (id: string): Promise<{ success: boolean }> => {
const result = await this.db
.delete(agentSkills)
.where(and(eq(agentSkills.id, id), eq(agentSkills.userId, this.userId)));
.where(and(eq(agentSkills.id, id), this.scopeWhere()));
return { success: (result.rowCount ?? 0) > 0 };
};
+19 -8
View File
@@ -7,6 +7,7 @@ import { KeyVaultsGateKeeper } from '@/server/modules/KeyVaultsEncrypt';
import type { ApiKeyItem, NewApiKeyItem } from '../schemas';
import { apiKeys } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class ApiKeyModel {
static findByKey = async (db: LobeChatDatabase, key: string) => {
@@ -22,13 +23,18 @@ export class ApiKeyModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
private gateKeeperPromise: Promise<KeyVaultsGateKeeper> | null = null;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, apiKeys);
private async getGateKeeper() {
if (!this.gateKeeperPromise) {
this.gateKeeperPromise = KeyVaultsGateKeeper.initWithEnvKey();
@@ -45,24 +51,29 @@ export class ApiKeyModel {
const [result] = await this.db
.insert(apiKeys)
.values({ ...params, key: encryptedKey, keyHash, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params, key: encryptedKey, keyHash },
),
)
.returning();
return result;
};
delete = async (id: string) => {
return this.db.delete(apiKeys).where(and(eq(apiKeys.id, id), eq(apiKeys.userId, this.userId)));
return this.db.delete(apiKeys).where(and(eq(apiKeys.id, id), this.ownership()));
};
deleteAll = async () => {
return this.db.delete(apiKeys).where(eq(apiKeys.userId, this.userId));
return this.db.delete(apiKeys).where(this.ownership());
};
query = async () => {
const results = await this.db.query.apiKeys.findMany({
orderBy: [desc(apiKeys.updatedAt)],
where: eq(apiKeys.userId, this.userId),
where: this.ownership(),
});
const gateKeeper = await this.getGateKeeper();
@@ -103,12 +114,12 @@ export class ApiKeyModel {
return this.db
.update(apiKeys)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(apiKeys.id, id), eq(apiKeys.userId, this.userId)));
.where(and(eq(apiKeys.id, id), this.ownership()));
};
findById = async (id: string) => {
return this.db.query.apiKeys.findFirst({
where: and(eq(apiKeys.id, id), eq(apiKeys.userId, this.userId)),
where: and(eq(apiKeys.id, id), this.ownership()),
});
};
@@ -116,6 +127,6 @@ export class ApiKeyModel {
return this.db
.update(apiKeys)
.set({ lastUsedAt: new Date() })
.where(and(eq(apiKeys.id, id), eq(apiKeys.userId, this.userId)));
.where(and(eq(apiKeys.id, id), this.ownership()));
};
}
+28 -13
View File
@@ -11,35 +11,46 @@ import { and, eq, inArray, lt, or, sql } from 'drizzle-orm';
import type { AsyncTaskSelectItem, NewAsyncTaskItem } from '../schemas';
import { asyncTasks } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class AsyncTaskModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, asyncTasks);
create = async (
params: Pick<NewAsyncTaskItem, 'type' | 'status' | 'metadata' | 'parentId'>,
): Promise<string> => {
const data = await this.db
.insert(asyncTasks)
.values({ ...params, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params },
),
)
.returning();
return data[0].id;
};
delete = async (id: string) => {
return this.db
.delete(asyncTasks)
.where(and(eq(asyncTasks.id, id), eq(asyncTasks.userId, this.userId)));
return this.db.delete(asyncTasks).where(and(eq(asyncTasks.id, id), this.ownership()));
};
findById = async (id: string) => {
return this.db.query.asyncTasks.findFirst({ where: and(eq(asyncTasks.id, id)) });
return this.db.query.asyncTasks.findFirst({
where: and(eq(asyncTasks.id, id), this.ownership()),
});
};
static findByInferenceId = async (db: LobeChatDatabase, inferenceId: string) => {
@@ -52,13 +63,13 @@ export class AsyncTaskModel {
return this.db
.update(asyncTasks)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(asyncTasks.id, taskId)));
.where(and(eq(asyncTasks.id, taskId), this.ownership()));
}
findActiveByType = async (type: AsyncTaskType) => {
return this.db.query.asyncTasks.findFirst({
where: and(
eq(asyncTasks.userId, this.userId),
this.ownership(),
eq(asyncTasks.type, type),
inArray(asyncTasks.status, [AsyncTaskStatus.Pending, AsyncTaskStatus.Processing]),
),
@@ -98,7 +109,7 @@ export class AsyncTaskModel {
`,
updatedAt: new Date(),
})
.where(and(eq(asyncTasks.id, taskId), eq(asyncTasks.userId, this.userId)))
.where(and(eq(asyncTasks.id, taskId), this.ownership()))
.returning({ metadata: asyncTasks.metadata, status: asyncTasks.status });
return result[0];
@@ -110,7 +121,7 @@ export class AsyncTaskModel {
if (taskIds.length > 0) {
await this.checkTimeoutTasks(taskIds);
chunkTasks = await this.db.query.asyncTasks.findMany({
where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type)),
where: and(inArray(asyncTasks.id, taskIds), eq(asyncTasks.type, type), this.ownership()),
});
}
@@ -138,6 +149,7 @@ export class AsyncTaskModel {
.where(
and(
inArray(asyncTasks.id, ids),
this.ownership(),
or(
eq(asyncTasks.status, AsyncTaskStatus.Pending),
eq(asyncTasks.status, AsyncTaskStatus.Processing),
@@ -157,9 +169,12 @@ export class AsyncTaskModel {
status: AsyncTaskStatus.Error,
})
.where(
inArray(
asyncTasks.id,
tasks.map((item) => item.id),
and(
inArray(
asyncTasks.id,
tasks.map((item) => item.id),
),
this.ownership(),
),
);
}
+21 -13
View File
@@ -4,6 +4,7 @@ import { agents } from '../schemas/agent';
import type { BriefItem, NewBrief } from '../schemas/task';
import { briefs, tasks } from '../schemas/task';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export interface UnresolvedBriefRow {
agentAvatar: string | null;
@@ -17,16 +18,23 @@ export interface UnresolvedBriefRow {
export class BriefModel {
private readonly userId: string;
private readonly db: LobeChatDatabase;
private readonly workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, briefs);
async create(data: Omit<NewBrief, 'id' | 'userId'>): Promise<BriefItem> {
const result = await this.db
.insert(briefs)
.values({ ...data, userId: this.userId })
.values(
buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, { ...data }),
)
.returning();
return result[0];
@@ -36,7 +44,7 @@ export class BriefModel {
const result = await this.db
.select()
.from(briefs)
.where(and(eq(briefs.id, id), eq(briefs.userId, this.userId)))
.where(and(eq(briefs.id, id), this.ownership()))
.limit(1);
return result[0] || null;
@@ -49,7 +57,7 @@ export class BriefModel {
}): Promise<{ briefs: BriefItem[]; total: number }> {
const { type, limit = 50, offset = 0 } = options || {};
const conditions = [eq(briefs.userId, this.userId)];
const conditions = [this.ownership()];
if (type) conditions.push(eq(briefs.type, type));
const where = and(...conditions);
@@ -90,7 +98,7 @@ export class BriefModel {
.from(briefs)
.leftJoin(agents, eq(briefs.agentId, agents.id))
.leftJoin(tasks, eq(briefs.taskId, tasks.id))
.where(and(eq(briefs.userId, this.userId), isNull(briefs.resolvedAt)))
.where(and(this.ownership(), isNull(briefs.resolvedAt)))
.orderBy(
sql`CASE
WHEN ${briefs.priority} = 'urgent' THEN 0
@@ -130,7 +138,7 @@ export class BriefModel {
.from(briefs)
.where(
and(
eq(briefs.userId, this.userId),
this.ownership(),
eq(briefs.agentId, agentId),
eq(briefs.trigger, trigger),
isNull(briefs.resolvedAt),
@@ -144,7 +152,7 @@ export class BriefModel {
return this.db
.select()
.from(briefs)
.where(and(eq(briefs.taskId, taskId), eq(briefs.userId, this.userId)))
.where(and(eq(briefs.taskId, taskId), this.ownership()))
.orderBy(desc(briefs.createdAt));
}
@@ -159,7 +167,7 @@ export class BriefModel {
): Promise<boolean> {
const excludeTypes = options?.excludeTypes ?? [];
const conditions = [
eq(briefs.userId, this.userId),
this.ownership(),
eq(briefs.taskId, taskId),
eq(briefs.priority, 'urgent'),
isNull(briefs.resolvedAt),
@@ -180,7 +188,7 @@ export class BriefModel {
return this.db
.select()
.from(briefs)
.where(and(eq(briefs.cronJobId, cronJobId), eq(briefs.userId, this.userId)))
.where(and(eq(briefs.cronJobId, cronJobId), this.ownership()))
.orderBy(desc(briefs.createdAt));
}
@@ -188,7 +196,7 @@ export class BriefModel {
const result = await this.db
.update(briefs)
.set({ readAt: new Date() })
.where(and(eq(briefs.id, id), eq(briefs.userId, this.userId)))
.where(and(eq(briefs.id, id), this.ownership()))
.returning();
return result[0] || null;
@@ -206,7 +214,7 @@ export class BriefModel {
resolvedAt: new Date(),
resolvedComment: options?.comment,
})
.where(and(eq(briefs.id, id), eq(briefs.userId, this.userId)))
.where(and(eq(briefs.id, id), this.ownership()))
.returning();
return result[0] || null;
@@ -229,7 +237,7 @@ export class BriefModel {
const result = await this.db
.update(briefs)
.set({ metadata })
.where(and(eq(briefs.id, id), eq(briefs.userId, this.userId)))
.where(and(eq(briefs.id, id), this.ownership()))
.returning();
return result[0] || null;
@@ -238,7 +246,7 @@ export class BriefModel {
async delete(id: string): Promise<boolean> {
const result = await this.db
.delete(briefs)
.where(and(eq(briefs.id, id), eq(briefs.userId, this.userId)))
.where(and(eq(briefs.id, id), this.ownership()))
.returning();
return result.length > 0;
+55 -19
View File
@@ -8,20 +8,27 @@ import type {
} from '../schemas';
import { chatGroups, chatGroupsAgents } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class ChatGroupModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, chatGroups);
// ******* Query Methods ******* //
async findById(id: string): Promise<ChatGroupItem | undefined> {
const item = await this.db.query.chatGroups.findFirst({
where: and(eq(chatGroups.id, id), eq(chatGroups.userId, this.userId)),
where: and(eq(chatGroups.id, id), this.ownership()),
});
return item;
@@ -30,7 +37,7 @@ export class ChatGroupModel {
async query(): Promise<ChatGroupItem[]> {
return this.db.query.chatGroups.findMany({
orderBy: [desc(chatGroups.updatedAt)],
where: eq(chatGroups.userId, this.userId),
where: this.ownership(),
});
}
@@ -44,7 +51,7 @@ export class ChatGroupModel {
columns: { id: true },
orderBy: [desc(chatGroups.updatedAt)],
where: and(
eq(chatGroups.userId, this.userId),
this.ownership(),
sql`${chatGroups.config}->>'forkedFromIdentifier' = ${forkedFromIdentifier}`,
),
});
@@ -58,7 +65,7 @@ export class ChatGroupModel {
const groupIds = groups.map((g) => g.id);
const groupAgents = await this.db.query.chatGroupsAgents.findMany({
where: inArray(chatGroupsAgents.chatGroupId, groupIds),
where: and(inArray(chatGroupsAgents.chatGroupId, groupIds), this.agentsOwnership()),
with: { agent: true },
});
@@ -87,7 +94,7 @@ export class ChatGroupModel {
const agents = await this.db.query.chatGroupsAgents.findMany({
orderBy: [chatGroupsAgents.order],
where: eq(chatGroupsAgents.chatGroupId, groupId),
where: and(eq(chatGroupsAgents.chatGroupId, groupId), this.agentsOwnership()),
});
return { agents, group };
@@ -98,7 +105,12 @@ export class ChatGroupModel {
async create(params: Omit<NewChatGroup, 'userId'>): Promise<ChatGroupItem> {
const [result] = await this.db
.insert(chatGroups)
.values({ ...params, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params },
),
)
.returning();
return result;
@@ -119,6 +131,7 @@ export class ChatGroupModel {
chatGroupId: group.id,
order: index,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
}));
const agents = await this.db.insert(chatGroupsAgents).values(agentParams).returning();
@@ -132,7 +145,7 @@ export class ChatGroupModel {
const [result] = await this.db
.update(chatGroups)
.set(value)
.where(and(eq(chatGroups.id, id), eq(chatGroups.userId, this.userId)))
.where(and(eq(chatGroups.id, id), this.ownership()))
.returning();
if (!result) {
@@ -153,6 +166,7 @@ export class ChatGroupModel {
order: options?.order || 0,
role: options?.role || 'assistant',
userId: this.userId,
workspaceId: this.workspaceId ?? null,
};
const [result] = await this.db.insert(chatGroupsAgents).values(params).returning();
@@ -189,6 +203,7 @@ export class ChatGroupModel {
chatGroupId: groupId,
enabled: true,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
}));
const added = await this.db.insert(chatGroupsAgents).values(newAgents).returning();
@@ -196,10 +211,19 @@ export class ChatGroupModel {
return { added, existing: existingIds };
}
private agentsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, chatGroupsAgents);
async removeAgentFromGroup(groupId: string, agentId: string): Promise<void> {
await this.db
.delete(chatGroupsAgents)
.where(and(eq(chatGroupsAgents.chatGroupId, groupId), eq(chatGroupsAgents.agentId, agentId)));
.where(
and(
eq(chatGroupsAgents.chatGroupId, groupId),
eq(chatGroupsAgents.agentId, agentId),
this.agentsOwnership(),
),
);
}
/**
@@ -212,7 +236,11 @@ export class ChatGroupModel {
await this.db
.delete(chatGroupsAgents)
.where(
and(eq(chatGroupsAgents.chatGroupId, groupId), inArray(chatGroupsAgents.agentId, agentIds)),
and(
eq(chatGroupsAgents.chatGroupId, groupId),
inArray(chatGroupsAgents.agentId, agentIds),
this.agentsOwnership(),
),
);
}
@@ -224,7 +252,13 @@ export class ChatGroupModel {
const [result] = await this.db
.update(chatGroupsAgents)
.set({ ...updates, updatedAt: new Date() })
.where(and(eq(chatGroupsAgents.chatGroupId, groupId), eq(chatGroupsAgents.agentId, agentId)))
.where(
and(
eq(chatGroupsAgents.chatGroupId, groupId),
eq(chatGroupsAgents.agentId, agentId),
this.agentsOwnership(),
),
)
.returning();
return result;
@@ -236,7 +270,7 @@ export class ChatGroupModel {
// Agents are automatically deleted due to CASCADE constraint
const [result] = await this.db
.delete(chatGroups)
.where(and(eq(chatGroups.id, id), eq(chatGroups.userId, this.userId)))
.where(and(eq(chatGroups.id, id), this.ownership()))
.returning();
if (!result) {
@@ -247,7 +281,7 @@ export class ChatGroupModel {
}
async deleteAll(): Promise<void> {
await this.db.delete(chatGroups).where(eq(chatGroups.userId, this.userId));
await this.db.delete(chatGroups).where(this.ownership());
}
// ******* Agent Query Methods ******* //
@@ -255,14 +289,18 @@ export class ChatGroupModel {
async getGroupAgents(groupId: string): Promise<ChatGroupAgentItem[]> {
return this.db.query.chatGroupsAgents.findMany({
orderBy: [chatGroupsAgents.order],
where: eq(chatGroupsAgents.chatGroupId, groupId),
where: and(eq(chatGroupsAgents.chatGroupId, groupId), this.agentsOwnership()),
});
}
async getEnabledGroupAgents(groupId: string): Promise<ChatGroupAgentItem[]> {
return this.db.query.chatGroupsAgents.findMany({
orderBy: [chatGroupsAgents.order],
where: and(eq(chatGroupsAgents.chatGroupId, groupId), eq(chatGroupsAgents.enabled, true)),
where: and(
eq(chatGroupsAgents.chatGroupId, groupId),
eq(chatGroupsAgents.enabled, true),
this.agentsOwnership(),
),
});
}
@@ -275,9 +313,7 @@ export class ChatGroupModel {
const groupIds = await this.db
.selectDistinct({ chatGroupId: chatGroupsAgents.chatGroupId })
.from(chatGroupsAgents)
.where(
and(eq(chatGroupsAgents.userId, this.userId), inArray(chatGroupsAgents.agentId, agentIds)),
);
.where(and(this.agentsOwnership(), inArray(chatGroupsAgents.agentId, agentIds)));
if (groupIds.length === 0) return [];
@@ -288,7 +324,7 @@ export class ChatGroupModel {
chatGroups.id,
groupIds.map((g) => g.chatGroupId),
),
eq(chatGroups.userId, this.userId),
this.ownership(),
),
});
}
+38 -11
View File
@@ -5,27 +5,44 @@ import { chunk } from 'es-toolkit/compat';
import type { NewChunkItem, NewUnstructuredChunkItem } from '../schemas';
import { chunks, embeddings, fileChunks, files, unstructuredChunks } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class ChunkModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, chunks);
private fileChunksOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, fileChunks);
bulkCreate = async (params: NewChunkItem[], fileId: string) => {
return this.db.transaction(async (trx) => {
if (params.length === 0) return [];
const result = await trx.insert(chunks).values(params).returning();
const result = await trx
.insert(chunks)
.values(
params.map((p) =>
buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, p),
),
)
.returning();
const fileChunksData = result.map((chunk) => ({
chunkId: chunk.id,
fileId,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
}));
if (fileChunksData.length > 0) {
@@ -37,11 +54,15 @@ export class ChunkModel {
};
bulkCreateUnstructuredChunks = async (params: NewUnstructuredChunkItem[]) => {
return this.db.insert(unstructuredChunks).values(params);
return this.db
.insert(unstructuredChunks)
.values(
params.map((p) => ({ ...p, workspaceId: p.workspaceId ?? this.workspaceId ?? null })),
);
};
delete = async (id: string) => {
return this.db.delete(chunks).where(and(eq(chunks.id, id), eq(chunks.userId, this.userId)));
return this.db.delete(chunks).where(and(eq(chunks.id, id), this.ownership()));
};
deleteOrphanChunks = async () => {
@@ -67,7 +88,7 @@ export class ChunkModel {
findById = async (id: string) => {
return this.db.query.chunks.findFirst({
where: and(eq(chunks.id, id)),
where: and(eq(chunks.id, id), this.ownership()),
});
};
@@ -85,7 +106,7 @@ export class ChunkModel {
})
.from(chunks)
.innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
.where(and(eq(fileChunks.fileId, id), eq(chunks.userId, this.userId)))
.where(and(eq(fileChunks.fileId, id), this.ownership(), this.fileChunksOwnership()))
.limit(20)
.offset(page * 20)
.orderBy(asc(chunks.index));
@@ -102,7 +123,7 @@ export class ChunkModel {
.select()
.from(chunks)
.innerJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
.where(eq(fileChunks.fileId, id));
.where(and(eq(fileChunks.fileId, id), this.ownership(), this.fileChunksOwnership()));
return data
.map((item) => item.chunks)
@@ -119,7 +140,7 @@ export class ChunkModel {
id: fileChunks.fileId,
})
.from(fileChunks)
.where(inArray(fileChunks.fileId, ids))
.where(and(inArray(fileChunks.fileId, ids), this.fileChunksOwnership()))
.groupBy(fileChunks.fileId);
};
@@ -130,7 +151,7 @@ export class ChunkModel {
id: fileChunks.fileId,
})
.from(fileChunks)
.where(eq(fileChunks.fileId, ids))
.where(and(eq(fileChunks.fileId, ids), this.fileChunksOwnership()))
.groupBy(fileChunks.fileId);
return data[0]?.count ?? 0;
@@ -161,7 +182,13 @@ export class ChunkModel {
.leftJoin(embeddings, eq(chunks.id, embeddings.chunkId))
.leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
.leftJoin(files, eq(fileChunks.fileId, files.id))
.where(fileIds ? inArray(fileChunks.fileId, fileIds) : undefined)
.where(
and(
this.ownership(),
fileIds ? this.fileChunksOwnership() : undefined,
fileIds ? inArray(fileChunks.fileId, fileIds) : undefined,
),
)
.orderBy((t) => desc(t.similarity))
.limit(30);
@@ -202,7 +229,7 @@ export class ChunkModel {
.leftJoin(embeddings, eq(chunks.id, embeddings.chunkId))
.leftJoin(fileChunks, eq(chunks.id, fileChunks.chunkId))
.leftJoin(files, eq(files.id, fileChunks.fileId))
.where(inArray(fileChunks.fileId, fileIds))
.where(and(inArray(fileChunks.fileId, fileIds), this.ownership(), this.fileChunksOwnership()))
.orderBy((t) => desc(t.similarity))
// Relaxed to 15 for now
.limit(topK);
+19 -18
View File
@@ -8,6 +8,7 @@ import type {
} from '../schemas';
import { userConnectors } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
interface GateKeeper {
decrypt: (ciphertext: string) => Promise<{ plaintext: string }>;
@@ -28,13 +29,18 @@ export class ConnectorModel {
private userId: string;
private db: LobeChatDatabase;
private gateKeeper?: GateKeeper;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string, gateKeeper?: GateKeeper) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string, gateKeeper?: GateKeeper) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
this.gateKeeper = gateKeeper;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, userConnectors);
create = async (
params: CreateConnectorParams,
gateKeeper: GateKeeper | undefined = this.gateKeeper,
@@ -45,25 +51,25 @@ export class ConnectorModel {
const [result] = await this.db
.insert(userConnectors)
.values({ ...params, credentials, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params, credentials },
),
)
.returning();
return result;
};
delete = async (id: string): Promise<void> => {
await this.db
.delete(userConnectors)
.where(and(eq(userConnectors.id, id), eq(userConnectors.userId, this.userId)));
await this.db.delete(userConnectors).where(and(eq(userConnectors.id, id), this.ownership()));
};
query = async (
gateKeeper: GateKeeper | undefined = this.gateKeeper,
): Promise<DecryptedConnector[]> => {
const rows = await this.db
.select()
.from(userConnectors)
.where(eq(userConnectors.userId, this.userId));
const rows = await this.db.select().from(userConnectors).where(this.ownership());
return Promise.all(rows.map((r) => decryptRow(r, gateKeeper)));
};
@@ -77,12 +83,7 @@ export class ConnectorModel {
const rows = await this.db
.select()
.from(userConnectors)
.where(
and(
eq(userConnectors.userId, this.userId),
inArray(userConnectors.identifier, identifiers),
),
);
.where(and(this.ownership(), inArray(userConnectors.identifier, identifiers)));
return Promise.all(rows.map((r) => decryptRow(r, gateKeeper)));
};
@@ -94,7 +95,7 @@ export class ConnectorModel {
const [row] = await this.db
.select()
.from(userConnectors)
.where(and(eq(userConnectors.id, id), eq(userConnectors.userId, this.userId)))
.where(and(eq(userConnectors.id, id), this.ownership()))
.limit(1);
if (!row) return null;
@@ -120,14 +121,14 @@ export class ConnectorModel {
await this.db
.update(userConnectors)
.set(set)
.where(and(eq(userConnectors.id, id), eq(userConnectors.userId, this.userId)));
.where(and(eq(userConnectors.id, id), this.ownership()));
};
updateStatus = async (id: string, status: ConnectorStatus): Promise<void> => {
await this.db
.update(userConnectors)
.set({ status, updatedAt: new Date() })
.where(and(eq(userConnectors.id, id), eq(userConnectors.userId, this.userId)));
.where(and(eq(userConnectors.id, id), this.ownership()));
};
}
+29 -31
View File
@@ -8,6 +8,7 @@ import type {
} from '../schemas';
import { ConnectorToolPermission as Permission, userConnectorTools } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export interface SyncToolInput {
crudType: ToolCRUDType;
@@ -24,12 +25,17 @@ export interface SyncToolInput {
export class ConnectorToolModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, userConnectorTools);
/**
* Batch-upsert tools from a manifest sync.
*
@@ -41,19 +47,23 @@ export class ConnectorToolModel {
upsertMany = async (userConnectorId: string, tools: SyncToolInput[]): Promise<void> => {
if (tools.length === 0) return;
const values: NewUserConnectorTool[] = tools.map((t) => ({
crudType: t.crudType,
description: t.description ?? null,
displayName: t.displayName ?? null,
inputSchema: t.inputSchema ?? null,
isWorkArtifact: false,
outputSchema: t.outputSchema ?? null,
permission: t.defaultPermission ?? Permission.auto,
renderConfig: t.renderConfig ?? null,
toolName: t.toolName,
userConnectorId,
userId: this.userId,
}));
const values: NewUserConnectorTool[] = tools.map((t) =>
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
crudType: t.crudType,
description: t.description ?? null,
displayName: t.displayName ?? null,
inputSchema: t.inputSchema ?? null,
isWorkArtifact: false,
outputSchema: t.outputSchema ?? null,
permission: t.defaultPermission ?? Permission.auto,
renderConfig: t.renderConfig ?? null,
toolName: t.toolName,
userConnectorId,
},
),
);
await this.db
.insert(userConnectorTools)
@@ -80,19 +90,14 @@ export class ConnectorToolModel {
await this.db
.update(userConnectorTools)
.set({ permission, updatedAt: new Date() })
.where(and(eq(userConnectorTools.id, toolId), eq(userConnectorTools.userId, this.userId)));
.where(and(eq(userConnectorTools.id, toolId), this.ownership()));
};
queryByConnector = async (userConnectorId: string): Promise<UserConnectorToolItem[]> => {
return this.db
.select()
.from(userConnectorTools)
.where(
and(
eq(userConnectorTools.userConnectorId, userConnectorId),
eq(userConnectorTools.userId, this.userId),
),
);
.where(and(eq(userConnectorTools.userConnectorId, userConnectorId), this.ownership()));
};
/**
@@ -107,7 +112,7 @@ export class ConnectorToolModel {
.from(userConnectorTools)
.where(
and(
eq(userConnectorTools.userId, this.userId),
this.ownership(),
inArray(userConnectorTools.userConnectorId, connectorIds),
ne(userConnectorTools.permission, Permission.disabled),
),
@@ -124,12 +129,7 @@ export class ConnectorToolModel {
return this.db
.select()
.from(userConnectorTools)
.where(
and(
eq(userConnectorTools.userId, this.userId),
inArray(userConnectorTools.userConnectorId, connectorIds),
),
);
.where(and(this.ownership(), inArray(userConnectorTools.userConnectorId, connectorIds)));
};
/**
@@ -140,9 +140,7 @@ export class ConnectorToolModel {
const results = await this.db
.select()
.from(userConnectorTools)
.where(
and(eq(userConnectorTools.userId, this.userId), eq(userConnectorTools.toolName, toolName)),
)
.where(and(this.ownership(), eq(userConnectorTools.toolName, toolName)))
.limit(1);
return results[0];
};
+11
View File
@@ -19,6 +19,17 @@ export interface UpdateDeviceParams {
workingDirs?: WorkingDirEntry[];
}
/**
* Devices are intentionally USER-LEVEL, not workspace-scoped.
*
* Even though the `devices` table carries a nullable `workspace_id` column, a
* physical machine belongs to the user across every workspace they're in (the
* unique key is `(userId, deviceId)`). This model therefore scopes all reads
* and writes by `userId` only and deliberately does NOT take a `workspaceId`
* argument or use `buildWorkspaceWhere` / `buildWorkspacePayload`. Switching it
* to workspace-scoped lookups would hide a user's own device inside their
* workspaces. See the matching note on `devices.workspaceId` in the schema.
*/
export class DeviceModel {
private userId: string;
private db: LobeChatDatabase;
+237 -15
View File
@@ -1,8 +1,9 @@
import { and, count, desc, eq, inArray, isNull, notInArray } from 'drizzle-orm';
import { and, count, desc, eq, inArray, isNull, notInArray, sum } from 'drizzle-orm';
import type { DocumentItem, NewDocument } from '../schemas';
import { DOCUMENT_FOLDER_TYPE, documents } from '../schemas';
import { DOCUMENT_FOLDER_TYPE, documents, files } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export interface QueryDocumentParams {
current?: number;
@@ -14,16 +15,21 @@ export interface QueryDocumentParams {
export class DocumentModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents);
findOrCreateFolder = async (name: string, parentId?: string): Promise<DocumentItem> => {
const existing = await this.db.query.documents.findFirst({
where: and(
eq(documents.userId, this.userId),
this.ownership(),
eq(documents.fileType, DOCUMENT_FOLDER_TYPE),
eq(documents.filename, name),
parentId ? eq(documents.parentId, parentId) : isNull(documents.parentId),
@@ -48,20 +54,23 @@ export class DocumentModel {
create = async (params: Omit<NewDocument, 'userId'>): Promise<DocumentItem> => {
const result = (await this.db
.insert(documents)
.values({ ...params, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params },
),
)
.returning()) as DocumentItem[];
return result[0]!;
};
delete = async (id: string) => {
return this.db
.delete(documents)
.where(and(eq(documents.id, id), eq(documents.userId, this.userId)));
return this.db.delete(documents).where(and(eq(documents.id, id), this.ownership()));
};
deleteAll = async () => {
return this.db.delete(documents).where(eq(documents.userId, this.userId));
return this.db.delete(documents).where(this.ownership());
};
query = async ({
@@ -74,7 +83,7 @@ export class DocumentModel {
total: number;
}> => {
const offset = current * pageSize;
const conditions = [eq(documents.userId, this.userId)];
const conditions = [this.ownership()];
if (fileTypes?.length) {
conditions.push(inArray(documents.fileType, fileTypes));
@@ -141,19 +150,19 @@ export class DocumentModel {
findById = async (id: string): Promise<DocumentItem | undefined> => {
return this.db.query.documents.findFirst({
where: and(eq(documents.userId, this.userId), eq(documents.id, id)),
where: and(this.ownership(), eq(documents.id, id)),
});
};
findByFileId = async (fileId: string) => {
return this.db.query.documents.findFirst({
where: and(eq(documents.userId, this.userId), eq(documents.fileId, fileId)),
where: and(this.ownership(), eq(documents.fileId, fileId)),
});
};
findBySlug = async (slug: string): Promise<DocumentItem | undefined> => {
return this.db.query.documents.findFirst({
where: and(eq(documents.userId, this.userId), eq(documents.slug, slug)),
where: and(this.ownership(), eq(documents.slug, slug)),
});
};
@@ -170,7 +179,7 @@ export class DocumentModel {
): Promise<DocumentItem | undefined> => {
return this.db.query.documents.findFirst({
where: and(
eq(documents.userId, this.userId),
this.ownership(),
eq(documents.source, source),
eq(documents.sourceType, sourceType),
),
@@ -181,6 +190,219 @@ export class DocumentModel {
return this.db
.update(documents)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(documents.userId, this.userId), eq(documents.id, id)));
.where(and(this.ownership(), eq(documents.id, id)));
};
/**
* Collect a document and all its descendants (folders + leaves) via BFS.
* Honors the current ownership scope.
*/
private collectSubtree = async (
rootId: string,
runner: LobeChatDatabase = this.db,
): Promise<DocumentItem[]> => {
const root = await runner.query.documents.findFirst({
where: and(this.ownership(), eq(documents.id, rootId)),
});
if (!root) return [];
const collected: DocumentItem[] = [root];
let frontier: string[] = [root.id];
while (frontier.length > 0) {
const children = await runner.query.documents.findMany({
where: and(this.ownership(), inArray(documents.parentId, frontier)),
});
if (children.length === 0) break;
collected.push(...children);
frontier = children.map((c) => c.id);
}
return collected;
};
countFileUsageInSubtree = async (
rootId: string,
runner: LobeChatDatabase = this.db,
): Promise<number> => {
const subtree = await this.collectSubtree(rootId, runner);
if (subtree.length === 0) return 0;
const ids = subtree.map((d) => d.id);
const result = await runner
.select({ totalSize: sum(files.size) })
.from(files)
.where(
and(
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files),
inArray(files.parentId, ids),
),
);
return parseInt(result[0]?.totalSize ?? '0') || 0;
};
/**
* Transfer a document (and its subtree) to another workspace / personal scope.
* Files anchored to documents in the subtree are also re-homed so the
* resource manager view stays consistent.
*/
transferTo = async (
documentId: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ documentIds: string[] }> => {
return this.db.transaction(async (trx) => {
const scopedTrx = new DocumentModel(trx as LobeChatDatabase, this.userId, this.workspaceId);
const subtree = await scopedTrx.collectSubtree(documentId, trx as LobeChatDatabase);
if (subtree.length === 0) throw new Error('Document not found');
const ids = subtree.map((d) => d.id);
const ownershipUpdate = { userId: targetUserId, workspaceId: targetWorkspaceId };
// Resolve slug conflicts in the target scope
for (const doc of subtree) {
if (!doc.slug) continue;
const slug = await this.findAvailableSlug(
trx as LobeChatDatabase,
doc.slug,
targetWorkspaceId,
targetUserId,
doc.id,
);
if (slug !== doc.slug) {
await (trx as LobeChatDatabase)
.update(documents)
.set({ slug })
.where(eq(documents.id, doc.id));
}
}
await (trx as LobeChatDatabase)
.update(documents)
.set({ ...ownershipUpdate, updatedAt: new Date() })
.where(inArray(documents.id, ids));
// Move files anchored to these documents
await (trx as LobeChatDatabase)
.update(files)
.set(ownershipUpdate)
.where(inArray(files.parentId, ids));
return { documentIds: ids };
});
};
/**
* Deep clone a document (and its subtree) into another workspace / personal
* scope. Generates fresh ids and preserves the parent/child topology.
*/
copyToWorkspace = async (
documentId: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ rootId: string }> => {
return this.db.transaction(async (trx) => {
const scopedTrx = new DocumentModel(trx as LobeChatDatabase, this.userId, this.workspaceId);
const subtree = await scopedTrx.collectSubtree(documentId, trx as LobeChatDatabase);
if (subtree.length === 0) throw new Error('Document not found');
// BFS clone: parents are inserted before children so we always know the
// new parent id by the time we get to the child.
const idMap = new Map<string, string>();
const byId = new Map(subtree.map((d) => [d.id, d]));
const queue: string[] = [documentId];
const seen = new Set<string>();
while (queue.length > 0) {
const currentId = queue.shift()!;
if (seen.has(currentId)) continue;
seen.add(currentId);
const original = byId.get(currentId);
if (!original) continue;
const newParentId =
currentId === documentId ? null : (idMap.get(original.parentId!) ?? null);
let newSlug = original.slug;
if (newSlug) {
newSlug = await this.findAvailableSlug(
trx as LobeChatDatabase,
newSlug,
targetWorkspaceId,
targetUserId,
);
}
const inserted = (await (trx as LobeChatDatabase)
.insert(documents)
.values({
accessedAt: original.accessedAt,
clientId: null,
content: original.content,
editorData: original.editorData,
fileId: null,
fileType: original.fileType,
filename: original.filename,
knowledgeBaseId: null,
metadata: { ...original.metadata, duplicatedFrom: original.id },
pages: original.pages,
parentId: newParentId,
slug: newSlug,
source: original.source,
sourceType: original.sourceType,
title: original.title,
totalCharCount: original.totalCharCount,
totalLineCount: original.totalLineCount,
userId: targetUserId,
workspaceId: targetWorkspaceId,
} as NewDocument)
.returning({ id: documents.id })) as { id: string }[];
idMap.set(original.id, inserted[0]!.id);
for (const c of subtree) {
if (c.parentId === original.id) queue.push(c.id);
}
}
return { rootId: idMap.get(documentId)! };
});
};
/**
* Find a slug not already taken in the target (workspaceId, userId) scope.
* Tries `slug`, `slug-1`, , `slug-99`. Mirrors the agent transfer behavior.
*/
private findAvailableSlug = async (
runner: LobeChatDatabase,
baseSlug: string,
targetWorkspaceId: string | null,
targetUserId: string,
ignoreDocumentId?: string,
): Promise<string> => {
const buildWhere = (candidate: string) =>
targetWorkspaceId
? and(eq(documents.slug, candidate), eq(documents.workspaceId, targetWorkspaceId))
: and(
eq(documents.slug, candidate),
eq(documents.userId, targetUserId),
isNull(documents.workspaceId),
);
const isFree = async (candidate: string): Promise<boolean> => {
const existing = await runner.query.documents.findFirst({ where: buildWhere(candidate) });
if (!existing) return true;
return ignoreDocumentId !== undefined && existing.id === ignoreDocumentId;
};
if (await isFree(baseSlug)) return baseSlug;
for (let suffix = 1; suffix < 100; suffix++) {
const candidate = `${baseSlug}-${suffix}`;
if (await isFree(candidate)) return candidate;
}
// Fallback: append timestamp to guarantee uniqueness
return `${baseSlug}-${Date.now()}`;
};
}
+24 -22
View File
@@ -3,6 +3,7 @@ import { and, desc, eq, lt, or } from 'drizzle-orm';
import type { DocumentHistoryItem, NewDocumentHistory } from '../schemas';
import { documentHistories, documents } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export interface QueryDocumentHistoryParams {
beforeId?: string;
@@ -13,18 +14,32 @@ export interface QueryDocumentHistoryParams {
export class DocumentHistoryModel {
private userId: string;
private workspaceId?: string;
private db: LobeChatDatabase;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.workspaceId = workspaceId;
this.db = db;
}
private ownership() {
return buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
documentHistories,
);
}
create = async (params: Omit<NewDocumentHistory, 'userId'>): Promise<DocumentHistoryItem> => {
const [document] = await this.db
.select({ id: documents.id })
.from(documents)
.where(and(eq(documents.id, params.documentId), eq(documents.userId, this.userId)))
.where(
and(
eq(documents.id, params.documentId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
),
)
.limit(1);
if (!document) {
@@ -33,7 +48,7 @@ export class DocumentHistoryModel {
const [result] = await this.db
.insert(documentHistories)
.values({ ...params, userId: this.userId })
.values(buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, params))
.returning();
return result!;
@@ -42,29 +57,24 @@ export class DocumentHistoryModel {
delete = async (id: string) => {
return this.db
.delete(documentHistories)
.where(and(eq(documentHistories.id, id), eq(documentHistories.userId, this.userId)));
.where(and(eq(documentHistories.id, id), this.ownership()));
};
deleteByDocumentId = async (documentId: string) => {
return this.db
.delete(documentHistories)
.where(
and(
eq(documentHistories.documentId, documentId),
eq(documentHistories.userId, this.userId),
),
);
.where(and(eq(documentHistories.documentId, documentId), this.ownership()));
};
deleteAll = async () => {
return this.db.delete(documentHistories).where(eq(documentHistories.userId, this.userId));
return this.db.delete(documentHistories).where(this.ownership());
};
findById = async (id: string): Promise<DocumentHistoryItem | undefined> => {
const [result] = await this.db
.select()
.from(documentHistories)
.where(and(eq(documentHistories.id, id), eq(documentHistories.userId, this.userId)))
.where(and(eq(documentHistories.id, id), this.ownership()))
.limit(1);
return result;
@@ -74,12 +84,7 @@ export class DocumentHistoryModel {
const [result] = await this.db
.select()
.from(documentHistories)
.where(
and(
eq(documentHistories.documentId, documentId),
eq(documentHistories.userId, this.userId),
),
)
.where(and(eq(documentHistories.documentId, documentId), this.ownership()))
.orderBy(desc(documentHistories.savedAt), desc(documentHistories.id))
.limit(1);
@@ -92,10 +97,7 @@ export class DocumentHistoryModel {
documentId,
limit = 50,
}: QueryDocumentHistoryParams): Promise<DocumentHistoryItem[]> => {
const conditions = [
eq(documentHistories.documentId, documentId),
eq(documentHistories.userId, this.userId),
];
const conditions = [eq(documentHistories.documentId, documentId), this.ownership()];
if (beforeSavedAt !== undefined) {
if (beforeId !== undefined) {
+54 -12
View File
@@ -4,6 +4,7 @@ import { and, eq, sql } from 'drizzle-orm';
import { documents, documentShares, users } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export interface DocumentShareAccessResult {
document: typeof documents.$inferSelect;
@@ -17,10 +18,12 @@ export interface DocumentShareAccessResult {
export class DocumentShareModel {
private userId: string;
private workspaceId?: string;
private db: LobeChatDatabase;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.workspaceId = workspaceId;
this.db = db;
}
@@ -34,7 +37,12 @@ export class DocumentShareModel {
const [doc] = await this.db
.select({ id: documents.id })
.from(documents)
.where(and(eq(documents.id, documentId), eq(documents.userId, this.userId)))
.where(
and(
eq(documents.id, documentId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
),
)
.limit(1);
if (!doc) {
@@ -43,12 +51,16 @@ export class DocumentShareModel {
const [result] = await this.db
.insert(documentShares)
.values({
documentId,
permission: params.permission ?? 'read',
userId: this.userId,
visibility: params.visibility ?? 'private',
})
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
documentId,
permission: params.permission ?? 'read',
visibility: params.visibility ?? 'private',
},
),
)
.onConflictDoNothing({ target: documentShares.documentId })
.returning();
@@ -63,7 +75,15 @@ export class DocumentShareModel {
const [result] = await this.db
.update(documentShares)
.set({ updatedAt: new Date(), visibility })
.where(and(eq(documentShares.documentId, documentId), eq(documentShares.userId, this.userId)))
.where(
and(
eq(documentShares.documentId, documentId),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
documentShares,
),
),
)
.returning();
return result || null;
@@ -73,7 +93,15 @@ export class DocumentShareModel {
const [result] = await this.db
.update(documentShares)
.set({ permission, updatedAt: new Date() })
.where(and(eq(documentShares.documentId, documentId), eq(documentShares.userId, this.userId)))
.where(
and(
eq(documentShares.documentId, documentId),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
documentShares,
),
),
)
.returning();
return result || null;
@@ -83,7 +111,13 @@ export class DocumentShareModel {
return this.db
.delete(documentShares)
.where(
and(eq(documentShares.documentId, documentId), eq(documentShares.userId, this.userId)),
and(
eq(documentShares.documentId, documentId),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
documentShares,
),
),
);
};
@@ -98,7 +132,15 @@ export class DocumentShareModel {
visibility: documentShares.visibility,
})
.from(documentShares)
.where(and(eq(documentShares.documentId, documentId), eq(documentShares.userId, this.userId)))
.where(
and(
eq(documentShares.documentId, documentId),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
documentShares,
),
),
)
.limit(1);
return result[0] || null;
+19 -9
View File
@@ -3,20 +3,26 @@ import { and, count, eq } from 'drizzle-orm';
import type { NewEmbeddingsItem } from '../schemas';
import { embeddings } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
export class EmbeddingModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, embeddings);
create = async (value: Omit<NewEmbeddingsItem, 'userId'>) => {
const [item] = await this.db
.insert(embeddings)
.values({ ...value, userId: this.userId })
.values({ ...value, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return item.id as string;
@@ -25,27 +31,31 @@ export class EmbeddingModel {
bulkCreate = async (values: Omit<NewEmbeddingsItem, 'userId'>[]) => {
return this.db
.insert(embeddings)
.values(values.map((item) => ({ ...item, userId: this.userId })))
.values(
values.map((item) => ({
...item,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
)
.onConflictDoNothing({
target: [embeddings.chunkId],
});
};
delete = async (id: string) => {
return this.db
.delete(embeddings)
.where(and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)));
return this.db.delete(embeddings).where(and(eq(embeddings.id, id), this.ownership()));
};
query = async () => {
return this.db.query.embeddings.findMany({
where: eq(embeddings.userId, this.userId),
where: this.ownership(),
});
};
findById = async (id: string) => {
return this.db.query.embeddings.findFirst({
where: and(eq(embeddings.id, id), eq(embeddings.userId, this.userId)),
where: and(eq(embeddings.id, id), this.ownership()),
});
};
@@ -55,7 +65,7 @@ export class EmbeddingModel {
count: count(),
})
.from(embeddings)
.where(eq(embeddings.userId, this.userId));
.where(this.ownership());
return result[0].count;
};
+112 -25
View File
@@ -20,6 +20,7 @@ import {
topics,
} from '../schemas';
import type { LobeChatDatabase, Transaction } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
/**
* Minimal file descriptor used to bootstrap user-uploaded files into a sandbox.
@@ -36,12 +37,17 @@ export interface SandboxInitFileItem {
export class FileModel {
private readonly userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files);
/**
* Get file by ID without userId filter (public access)
* Use this for scenarios like file proxy where file should be accessible by ID alone
@@ -82,17 +88,26 @@ export class FileModel {
const result = (await tx
.insert(files)
.values({ ...params, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params },
),
)
.returning()) as FileItem[];
const item = result[0]!;
if (params.knowledgeBaseId) {
await tx.insert(knowledgeBaseFiles).values({
fileId: item.id,
knowledgeBaseId: params.knowledgeBaseId,
userId: this.userId,
});
await tx.insert(knowledgeBaseFiles).values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
fileId: item.id,
knowledgeBaseId: params.knowledgeBaseId,
},
),
);
}
return item;
@@ -150,7 +165,7 @@ export class FileModel {
.where(
and(
eq(documents.fileId, id),
eq(documents.userId, this.userId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
eq(documents.sourceType, 'file'),
),
);
@@ -166,7 +181,7 @@ export class FileModel {
}
// 4. Delete file record
await tx.delete(files).where(and(eq(files.id, id), eq(files.userId, this.userId)));
await tx.delete(files).where(and(eq(files.id, id), this.ownership()));
if (!fileHash) return;
@@ -200,7 +215,7 @@ export class FileModel {
totalSize: sum(files.size),
})
.from(files)
.where(eq(files.userId, this.userId));
.where(this.ownership());
return parseInt(result[0].totalSize!) || 0;
};
@@ -211,7 +226,7 @@ export class FileModel {
return await this.db.transaction(async (trx) => {
// 1. First get the file list to return the deleted files
const fileList = await trx.query.files.findMany({
where: and(inArray(files.id, ids), eq(files.userId, this.userId)),
where: and(inArray(files.id, ids), this.ownership()),
});
if (fileList.length === 0) return [];
@@ -229,7 +244,7 @@ export class FileModel {
.where(
and(
inArray(documents.fileId, ids),
eq(documents.userId, this.userId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
eq(documents.sourceType, 'file'),
),
);
@@ -243,7 +258,7 @@ export class FileModel {
}
// 5. Delete file records
await trx.delete(files).where(and(inArray(files.id, ids), eq(files.userId, this.userId)));
await trx.delete(files).where(and(inArray(files.id, ids), this.ownership()));
// If global files don't need to be deleted, no storage object should be removed.
if (!removeGlobalFile || hashList.length === 0) return [];
@@ -275,7 +290,7 @@ export class FileModel {
};
clear = async () => {
return this.db.delete(files).where(eq(files.userId, this.userId));
return this.db.delete(files).where(this.ownership());
};
query = async ({
@@ -287,10 +302,7 @@ export class FileModel {
showFilesInKnowledgeBase,
}: QueryFileListParams = {}) => {
// 1. Build where clause
let whereClause = and(
q ? ilike(files.name, `%${q}%`) : undefined,
eq(files.userId, this.userId),
);
let whereClause = and(q ? ilike(files.name, `%${q}%`) : undefined, this.ownership());
if (category && category !== FilesTabs.All && category !== FilesTabs.Home) {
const fileTypePrefix = this.getFileTypePrefix(category as FilesTabs);
if (Array.isArray(fileTypePrefix)) {
@@ -333,6 +345,7 @@ export class FileModel {
size: files.size,
updatedAt: files.updatedAt,
url: files.url,
userId: files.userId,
})
.from(files);
@@ -365,14 +378,14 @@ export class FileModel {
findByIds = async (ids: string[]) => {
return this.db.query.files.findMany({
where: and(inArray(files.id, ids), eq(files.userId, this.userId)),
where: and(inArray(files.id, ids), this.ownership()),
});
};
findById = async (id: string, trx?: Transaction) => {
const database = trx || this.db;
return database.query.files.findFirst({
where: and(eq(files.id, id), eq(files.userId, this.userId)),
where: and(eq(files.id, id), this.ownership()),
});
};
@@ -429,7 +442,7 @@ export class FileModel {
this.db
.update(files)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(files.id, id), eq(files.userId, this.userId)));
.where(and(eq(files.id, id), this.ownership()));
/**
* get the corresponding file type prefix according to FilesTabs
@@ -459,10 +472,7 @@ export class FileModel {
findByNames = async (fileNames: string[]) =>
this.db.query.files.findMany({
where: and(
or(...fileNames.map((name) => like(files.name, `${name}%`))),
eq(files.userId, this.userId),
),
where: and(or(...fileNames.map((name) => like(files.name, `${name}%`))), this.ownership()),
});
// Abstract common method for deleting chunks
@@ -514,4 +524,81 @@ export class FileModel {
return chunkIds;
};
// ========== Transfer / Copy ==========
/**
* Transfer a single file (not a folder folders live in `documents` and are
* handled by `DocumentModel.transferTo`, which already cascades into `files`
* via `parentId`). Updates ownership + knowledgeBaseFiles linkage so the
* file remains visible in the target scope's resource manager.
*/
transferTo = async (
fileId: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ fileId: string }> => {
return this.db.transaction(async (trx) => {
const file = await trx.query.files.findFirst({
where: and(eq(files.id, fileId), this.ownership()),
});
if (!file) throw new Error('File not found');
const ownershipUpdate = { userId: targetUserId, workspaceId: targetWorkspaceId };
await trx
.update(files)
.set({ ...ownershipUpdate, updatedAt: new Date() })
.where(eq(files.id, fileId));
// Knowledge base links are scoped per-user; keep them pointed at the new owner.
await trx
.update(knowledgeBaseFiles)
.set({ userId: targetUserId })
.where(eq(knowledgeBaseFiles.fileId, fileId));
return { fileId };
});
};
/**
* Clone a file record into another workspace / personal scope. The physical
* blob is shared via `fileHash` `globalFiles`, so we only copy the row. AI
* index references (`chunkTaskId` / `embeddingTaskId`) are reset; the new
* scope is expected to re-index lazily.
*/
copyToWorkspace = async (
fileId: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ fileId: string }> => {
return this.db.transaction(async (trx) => {
const file = await trx.query.files.findFirst({
where: and(eq(files.id, fileId), this.ownership()),
});
if (!file) throw new Error('File not found');
const inserted = (await trx
.insert(files)
.values({
chunkTaskId: null,
clientId: null,
embeddingTaskId: null,
fileHash: file.fileHash,
fileType: file.fileType,
metadata: { ...(file.metadata as Record<string, unknown>), duplicatedFrom: file.id },
name: file.name,
// parentId would dangle in target scope; the user can drag it under a folder later.
parentId: null,
size: file.size,
source: file.source,
url: file.url,
userId: targetUserId,
workspaceId: targetWorkspaceId,
} as NewFile)
.returning({ id: files.id })) as { id: string }[];
return { fileId: inserted[0]!.id };
});
};
}
+16 -8
View File
@@ -16,6 +16,7 @@ import type { NewFile } from '../schemas';
import type { GenerationItem, GenerationWithAsyncTask, NewGeneration } from '../schemas/generation';
import { generations } from '../schemas/generation';
import type { LobeChatDatabase, Transaction } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
import { FileModel } from './file';
// Create debug logger
@@ -24,16 +25,21 @@ const log = debug('lobe-image:generation-model');
export class GenerationModel {
private db: LobeChatDatabase;
private userId: string;
private workspaceId?: string;
private fileModel: FileModel;
private fileService: FileService;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.fileModel = new FileModel(db, userId);
this.workspaceId = workspaceId;
this.fileModel = new FileModel(db, userId, workspaceId);
this.fileService = new FileService(db, userId);
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, generations);
async create(value: Omit<NewGeneration, 'userId'>): Promise<GenerationItem> {
log('Creating generation: %O', {
generationBatchId: value.generationBatchId,
@@ -42,7 +48,9 @@ export class GenerationModel {
const [result] = await this.db
.insert(generations)
.values({ ...value, userId: this.userId })
.values(
buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, { ...value }),
)
.returning();
log('Generation created successfully: %s', result.id);
@@ -53,7 +61,7 @@ export class GenerationModel {
log('Finding generation by ID: %s for user: %s', id, this.userId);
const result = await this.db.query.generations.findFirst({
where: and(eq(generations.id, id), eq(generations.userId, this.userId)),
where: and(eq(generations.id, id), this.ownership()),
});
log('Generation %s: %s', id, result ? 'found' : 'not found');
@@ -64,7 +72,7 @@ export class GenerationModel {
log('Finding generation by ID: %s for user: %s', id, this.userId);
const result = await this.db.query.generations.findFirst({
where: and(eq(generations.id, id), eq(generations.userId, this.userId)),
where: and(eq(generations.id, id), this.ownership()),
with: {
asyncTask: true,
},
@@ -84,7 +92,7 @@ export class GenerationModel {
return await tx
.update(generations)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(generations.id, id), eq(generations.userId, this.userId)));
.where(and(eq(generations.id, id), this.ownership()));
};
const result = await (trx ? executeUpdate(trx) : this.db.transaction(executeUpdate));
@@ -136,7 +144,7 @@ export class GenerationModel {
log('Finding generation by asyncTaskId: %s', asyncTaskId);
return this.db.query.generations.findFirst({
where: eq(generations.asyncTaskId, asyncTaskId),
where: and(eq(generations.asyncTaskId, asyncTaskId), this.ownership()),
});
}
@@ -146,7 +154,7 @@ export class GenerationModel {
const executeDelete = async (tx: Transaction) => {
return await tx
.delete(generations)
.where(and(eq(generations.id, id), eq(generations.userId, this.userId)))
.where(and(eq(generations.id, id), this.ownership()))
.returning();
};
+16 -14
View File
@@ -17,6 +17,7 @@ import type {
} from '../schemas/generation';
import { generationBatches } from '../schemas/generation';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
import { GenerationModel } from './generation';
const log = debug('lobe-image:generation-batch-model');
@@ -24,16 +25,21 @@ const log = debug('lobe-image:generation-batch-model');
export class GenerationBatchModel {
private db: LobeChatDatabase;
private userId: string;
private workspaceId?: string;
private fileService: FileService;
private generationModel: GenerationModel;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
this.fileService = new FileService(db, userId);
this.generationModel = new GenerationModel(db, userId);
this.generationModel = new GenerationModel(db, userId, workspaceId);
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, generationBatches);
async create(value: NewGenerationBatch): Promise<GenerationBatchItem> {
log('Creating generation batch: %O', {
topicId: value.generationTopicId,
@@ -42,7 +48,9 @@ export class GenerationBatchModel {
const [result] = await this.db
.insert(generationBatches)
.values({ ...value, userId: this.userId })
.values(
buildWorkspacePayload({ userId: this.userId, workspaceId: this.workspaceId }, { ...value }),
)
.returning();
log('Generation batch created successfully: %s', result.id);
@@ -53,7 +61,7 @@ export class GenerationBatchModel {
log('Finding generation batch by ID: %s for user: %s', id, this.userId);
const result = await this.db.query.generationBatches.findFirst({
where: and(eq(generationBatches.id, id), eq(generationBatches.userId, this.userId)),
where: and(eq(generationBatches.id, id), this.ownership()),
});
log('Generation batch %s: %s', id, result ? 'found' : 'not found');
@@ -65,10 +73,7 @@ export class GenerationBatchModel {
const results = await this.db.query.generationBatches.findMany({
orderBy: (table, { desc }) => [desc(table.createdAt)],
where: and(
eq(generationBatches.generationTopicId, topicId),
eq(generationBatches.userId, this.userId),
),
where: and(eq(generationBatches.generationTopicId, topicId), this.ownership()),
});
log('Found %d generation batches for topic %s', results.length, topicId);
@@ -87,10 +92,7 @@ export class GenerationBatchModel {
const results = await this.db.query.generationBatches.findMany({
orderBy: (table, { asc }) => [asc(table.createdAt)],
where: and(
eq(generationBatches.generationTopicId, topicId),
eq(generationBatches.userId, this.userId),
),
where: and(eq(generationBatches.generationTopicId, topicId), this.ownership()),
with: {
generations: {
orderBy: (table, { asc }) => [asc(table.createdAt), asc(table.id)],
@@ -184,7 +186,7 @@ export class GenerationBatchModel {
// 1. First, get generations with their assets to collect file URLs for cleanup
const batchWithGenerations = await this.db.query.generationBatches.findFirst({
where: and(eq(generationBatches.id, id), eq(generationBatches.userId, this.userId)),
where: and(eq(generationBatches.id, id), this.ownership()),
with: {
generations: {
columns: {
@@ -215,7 +217,7 @@ export class GenerationBatchModel {
// 3. Delete the batch record (this will cascade delete all associated generations)
const [deletedBatch] = await this.db
.delete(generationBatches)
.where(and(eq(generationBatches.id, id), eq(generationBatches.userId, this.userId)))
.where(and(eq(generationBatches.id, id), this.ownership()))
.returning();
log(
+20 -10
View File
@@ -11,20 +11,26 @@ import type { GenerationTopicItem } from '../schemas/generation';
import { generationTopics } from '../schemas/generation';
import type { LobeChatDatabase } from '../type';
import type { GenerationTopicType } from '../types/generation';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class GenerationTopicModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
private fileService: FileService;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
this.fileService = new FileService(db, userId);
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, generationTopics);
queryAll = async (type?: GenerationTopicType) => {
const conditions = [eq(generationTopics.userId, this.userId)];
const conditions = [this.ownership()];
if (type) {
conditions.push(eq(generationTopics.type, type));
}
@@ -51,11 +57,15 @@ export class GenerationTopicModel {
create = async (title: string, type?: GenerationTopicType) => {
const [newGenerationTopic] = await this.db
.insert(generationTopics)
.values({
title,
type: type ?? 'image',
userId: this.userId,
})
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
title,
type: type ?? 'image',
},
),
)
.returning();
return newGenerationTopic;
@@ -68,7 +78,7 @@ export class GenerationTopicModel {
const [updatedTopic] = await this.db
.update(generationTopics)
.set({ ...data, updatedAt: new Date() })
.where(and(eq(generationTopics.id, id), eq(generationTopics.userId, this.userId)))
.where(and(eq(generationTopics.id, id), this.ownership()))
.returning();
return updatedTopic;
@@ -90,7 +100,7 @@ export class GenerationTopicModel {
): Promise<{ deletedTopic: GenerationTopicItem; filesToDelete: string[] } | undefined> => {
// 1. First, get the topic with all its batches and generations to collect file URLs
const topicWithBatches = await this.db.query.generationTopics.findFirst({
where: and(eq(generationTopics.id, id), eq(generationTopics.userId, this.userId)),
where: and(eq(generationTopics.id, id), this.ownership()),
with: {
batches: {
with: {
@@ -134,7 +144,7 @@ export class GenerationTopicModel {
// 3. Delete the topic record (this will cascade delete all batches and generations)
const [deletedTopic] = await this.db
.delete(generationTopics)
.where(and(eq(generationTopics.id, id), eq(generationTopics.userId, this.userId)))
.where(and(eq(generationTopics.id, id), this.ownership()))
.returning();
return {
+341 -28
View File
@@ -1,26 +1,37 @@
import type { KnowledgeBaseItem } from '@lobechat/types';
import { and, count, desc, eq, inArray } from 'drizzle-orm';
import { and, count, desc, eq, inArray, or, sum } from 'drizzle-orm';
import type { NewKnowledgeBase } from '../schemas';
import { documents, knowledgeBaseFiles, knowledgeBases } from '../schemas';
import type { NewDocument, NewFile, NewKnowledgeBase } from '../schemas';
import { documents, files, knowledgeBaseFiles, knowledgeBases } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
import { FileModel } from './file';
export class KnowledgeBaseModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, knowledgeBases);
// create
create = async (params: Omit<NewKnowledgeBase, 'userId'>) => {
const [result] = await this.db
.insert(knowledgeBases)
.values({ ...params, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params },
),
)
.returning();
return result;
@@ -29,7 +40,7 @@ export class KnowledgeBaseModel {
addFilesToKnowledgeBase = async (id: string, fileIds: string[]) => {
// Verify the target knowledge base belongs to the current user
const kb = await this.db.query.knowledgeBases.findFirst({
where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)),
where: and(eq(knowledgeBases.id, id), this.ownership()),
});
if (!kb) return [];
@@ -43,7 +54,12 @@ export class KnowledgeBaseModel {
const docsWithFiles = await this.db
.select({ fileId: documents.fileId })
.from(documents)
.where(and(inArray(documents.id, documentIds), eq(documents.userId, this.userId)));
.where(
and(
inArray(documents.id, documentIds),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
),
);
const mirrorFileIds = docsWithFiles
.map((doc) => doc.fileId)
@@ -54,7 +70,12 @@ export class KnowledgeBaseModel {
await this.db
.update(documents)
.set({ knowledgeBaseId: id })
.where(and(inArray(documents.id, documentIds), eq(documents.userId, this.userId)));
.where(
and(
inArray(documents.id, documentIds),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
),
);
}
// Insert using resolved file IDs
@@ -65,20 +86,23 @@ export class KnowledgeBaseModel {
return this.db
.insert(knowledgeBaseFiles)
.values(
resolvedFileIds.map((fileId) => ({ fileId, knowledgeBaseId: id, userId: this.userId })),
resolvedFileIds.map((fileId) => ({
fileId,
knowledgeBaseId: id,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
)
.returning();
};
// delete
delete = async (id: string) => {
return this.db
.delete(knowledgeBases)
.where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)));
return this.db.delete(knowledgeBases).where(and(eq(knowledgeBases.id, id), this.ownership()));
};
deleteAll = async () => {
return this.db.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId));
return this.db.delete(knowledgeBases).where(this.ownership());
};
removeFilesFromKnowledgeBase = async (knowledgeBaseId: string, ids: string[]) => {
@@ -92,7 +116,12 @@ export class KnowledgeBaseModel {
const docsWithFiles = await this.db
.select({ fileId: documents.fileId })
.from(documents)
.where(and(inArray(documents.id, documentIds), eq(documents.userId, this.userId)));
.where(
and(
inArray(documents.id, documentIds),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
),
);
const mirrorFileIds = docsWithFiles
.map((doc) => doc.fileId)
@@ -106,7 +135,7 @@ export class KnowledgeBaseModel {
.where(
and(
inArray(documents.id, documentIds),
eq(documents.userId, this.userId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
eq(documents.knowledgeBaseId, knowledgeBaseId),
),
);
@@ -121,7 +150,10 @@ export class KnowledgeBaseModel {
.delete(knowledgeBaseFiles)
.where(
and(
eq(knowledgeBaseFiles.userId, this.userId),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
knowledgeBaseFiles,
),
eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId),
inArray(knowledgeBaseFiles.fileId, resolvedFileIds),
),
@@ -142,7 +174,7 @@ export class KnowledgeBaseModel {
updatedAt: knowledgeBases.updatedAt,
})
.from(knowledgeBases)
.where(eq(knowledgeBases.userId, this.userId))
.where(this.ownership())
.orderBy(desc(knowledgeBases.updatedAt));
return data as KnowledgeBaseItem[];
@@ -150,16 +182,288 @@ export class KnowledgeBaseModel {
findById = async (id: string) => {
return this.db.query.knowledgeBases.findFirst({
where: and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)),
where: and(eq(knowledgeBases.id, id), this.ownership()),
});
};
countFileUsage = async (id: string): Promise<number> => {
const result = await this.db
.select({ totalSize: sum(files.size) })
.from(knowledgeBaseFiles)
.innerJoin(files, eq(files.id, knowledgeBaseFiles.fileId))
.where(
and(
eq(knowledgeBaseFiles.knowledgeBaseId, id),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
knowledgeBaseFiles,
),
),
);
return parseInt(result[0]?.totalSize ?? '0') || 0;
};
// update
update = async (id: string, value: Partial<KnowledgeBaseItem>) =>
this.db
.update(knowledgeBases)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)));
.where(and(eq(knowledgeBases.id, id), this.ownership()));
private resolveAvailableName = async (
db: LobeChatDatabase,
name: string,
targetWorkspaceId: string | null,
targetUserId: string,
excludeId?: string,
): Promise<string> => {
const existingKnowledgeBases = await db
.select({ id: knowledgeBases.id, name: knowledgeBases.name })
.from(knowledgeBases)
.where(
buildWorkspaceWhere(
{ userId: targetUserId, workspaceId: targetWorkspaceId ?? undefined },
knowledgeBases,
),
);
const existingNames = new Set(
existingKnowledgeBases
.filter((knowledgeBase) => knowledgeBase.id !== excludeId)
.map((knowledgeBase) => knowledgeBase.name),
);
if (!existingNames.has(name)) return name;
let index = 1;
let candidate = `${name} (${index})`;
while (existingNames.has(candidate)) {
index += 1;
candidate = `${name} (${index})`;
}
return candidate;
};
transferTo = async (
id: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ id: string }> => {
return this.db.transaction(async (trx) => {
const [knowledgeBase] = await trx
.select()
.from(knowledgeBases)
.where(and(eq(knowledgeBases.id, id), this.ownership()))
.limit(1);
if (!knowledgeBase) throw new Error('Knowledge base not found');
const fileLinks = await trx
.select({ fileId: knowledgeBaseFiles.fileId })
.from(knowledgeBaseFiles)
.where(eq(knowledgeBaseFiles.knowledgeBaseId, id));
const fileIds = fileLinks.map((item) => item.fileId);
const now = new Date();
const ownershipUpdate = { userId: targetUserId, workspaceId: targetWorkspaceId };
const targetName = await this.resolveAvailableName(
trx as LobeChatDatabase,
knowledgeBase.name,
targetWorkspaceId,
targetUserId,
id,
);
await trx
.update(knowledgeBases)
.set({ ...ownershipUpdate, name: targetName, updatedAt: now })
.where(eq(knowledgeBases.id, id));
await trx
.update(knowledgeBaseFiles)
.set(ownershipUpdate)
.where(eq(knowledgeBaseFiles.knowledgeBaseId, id));
if (fileIds.length > 0) {
await trx
.update(files)
.set({ ...ownershipUpdate, updatedAt: now })
.where(inArray(files.id, fileIds));
}
const documentWhere =
fileIds.length > 0
? or(eq(documents.knowledgeBaseId, id), inArray(documents.fileId, fileIds))
: eq(documents.knowledgeBaseId, id);
await trx
.update(documents)
.set({ ...ownershipUpdate, updatedAt: now })
.where(documentWhere);
return { id };
});
};
copyToWorkspace = async (
id: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ id: string }> => {
return this.db.transaction(async (trx) => {
const [knowledgeBase] = await trx
.select()
.from(knowledgeBases)
.where(and(eq(knowledgeBases.id, id), this.ownership()))
.limit(1);
if (!knowledgeBase) throw new Error('Knowledge base not found');
const targetName = await this.resolveAvailableName(
trx as LobeChatDatabase,
knowledgeBase.name,
targetWorkspaceId,
targetUserId,
);
const [copiedKnowledgeBase] = await trx
.insert(knowledgeBases)
.values({
avatar: knowledgeBase.avatar,
description: knowledgeBase.description,
isPublic: knowledgeBase.isPublic,
name: targetName,
settings: knowledgeBase.settings,
type: knowledgeBase.type,
userId: targetUserId,
workspaceId: targetWorkspaceId,
} as NewKnowledgeBase)
.returning();
const fileLinks = await trx
.select({ fileId: knowledgeBaseFiles.fileId })
.from(knowledgeBaseFiles)
.where(eq(knowledgeBaseFiles.knowledgeBaseId, id));
const fileIds = fileLinks.map((item) => item.fileId);
const documentWhere =
fileIds.length > 0
? or(eq(documents.knowledgeBaseId, id), inArray(documents.fileId, fileIds))
: eq(documents.knowledgeBaseId, id);
const sourceDocuments = await trx.select().from(documents).where(documentWhere);
const sourceDocumentIds = new Set(sourceDocuments.map((item) => item.id));
const documentIdMap = new Map<string, string>();
let pendingDocuments = [...sourceDocuments];
while (pendingDocuments.length > 0) {
const readyDocuments = pendingDocuments.filter(
(document) =>
!document.parentId ||
!sourceDocumentIds.has(document.parentId) ||
documentIdMap.has(document.parentId),
);
const documentsToCopy = readyDocuments.length > 0 ? readyDocuments : pendingDocuments;
for (const document of documentsToCopy) {
const metadata =
document.metadata && typeof document.metadata === 'object'
? { ...document.metadata, duplicatedFrom: document.id }
: { duplicatedFrom: document.id };
const [copiedDocument] = await trx
.insert(documents)
.values({
clientId: null,
content: document.content,
description: document.description,
editorData: document.editorData,
fileId: null,
fileType: document.fileType,
filename: document.filename,
knowledgeBaseId:
document.knowledgeBaseId === id ? copiedKnowledgeBase.id : document.knowledgeBaseId,
metadata,
pages: document.pages,
parentId: document.parentId ? (documentIdMap.get(document.parentId) ?? null) : null,
source: document.source,
sourceType: document.sourceType,
title: document.title,
totalCharCount: document.totalCharCount,
totalLineCount: document.totalLineCount,
userId: targetUserId,
workspaceId: targetWorkspaceId,
} as NewDocument)
.returning({ id: documents.id });
documentIdMap.set(document.id, copiedDocument.id);
}
const copiedIds = new Set(documentsToCopy.map((document) => document.id));
pendingDocuments = pendingDocuments.filter((document) => !copiedIds.has(document.id));
}
const fileIdMap = new Map<string, string>();
if (fileIds.length > 0) {
const sourceFiles = await trx.select().from(files).where(inArray(files.id, fileIds));
for (const file of sourceFiles) {
const metadata =
file.metadata && typeof file.metadata === 'object'
? { ...file.metadata, duplicatedFrom: file.id }
: { duplicatedFrom: file.id };
const [copiedFile] = await trx
.insert(files)
.values({
chunkTaskId: null,
clientId: null,
embeddingTaskId: null,
fileHash: file.fileHash,
fileType: file.fileType,
metadata,
name: file.name,
parentId: file.parentId ? (documentIdMap.get(file.parentId) ?? null) : null,
size: file.size,
source: file.source,
url: file.url,
userId: targetUserId,
workspaceId: targetWorkspaceId,
} as NewFile)
.returning({ id: files.id });
fileIdMap.set(file.id, copiedFile.id);
}
const copiedLinks = fileLinks.flatMap((link) => {
const fileId = fileIdMap.get(link.fileId);
if (!fileId) return [];
return [
{
fileId,
knowledgeBaseId: copiedKnowledgeBase.id,
userId: targetUserId,
workspaceId: targetWorkspaceId,
},
];
});
if (copiedLinks.length > 0) {
await trx.insert(knowledgeBaseFiles).values(copiedLinks);
}
}
for (const document of sourceDocuments) {
if (!document.fileId) continue;
const copiedDocumentId = documentIdMap.get(document.id);
const copiedFileId = fileIdMap.get(document.fileId);
if (!copiedDocumentId || !copiedFileId) continue;
await trx
.update(documents)
.set({ fileId: copiedFileId })
.where(eq(documents.id, copiedDocumentId));
}
return { id: copiedKnowledgeBase.id };
});
};
findExclusiveFileIds = async (knowledgeBaseId: string): Promise<string[]> => {
const kbFiles = await this.db
@@ -168,7 +472,10 @@ export class KnowledgeBaseModel {
.where(
and(
eq(knowledgeBaseFiles.knowledgeBaseId, knowledgeBaseId),
eq(knowledgeBaseFiles.userId, this.userId),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
knowledgeBaseFiles,
),
),
);
const fileIds = kbFiles.map((f) => f.fileId);
@@ -183,7 +490,10 @@ export class KnowledgeBaseModel {
.where(
and(
inArray(knowledgeBaseFiles.fileId, fileIds),
eq(knowledgeBaseFiles.userId, this.userId),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
knowledgeBaseFiles,
),
),
)
.groupBy(knowledgeBaseFiles.fileId);
@@ -196,14 +506,12 @@ export class KnowledgeBaseModel {
let deletedFiles: Array<{ id: string; url: string | null }> = [];
if (exclusiveFileIds.length > 0) {
const fileModel = new FileModel(this.db, this.userId);
const fileModel = new FileModel(this.db, this.userId, this.workspaceId);
const result = await fileModel.deleteMany(exclusiveFileIds, removeGlobalFile);
deletedFiles = (result || []).map((f) => ({ id: f.id, url: f.url }));
}
await this.db
.delete(knowledgeBases)
.where(and(eq(knowledgeBases.id, id), eq(knowledgeBases.userId, this.userId)));
await this.db.delete(knowledgeBases).where(and(eq(knowledgeBases.id, id), this.ownership()));
return { deletedFiles };
};
@@ -212,18 +520,23 @@ export class KnowledgeBaseModel {
const allKbFileIds = await this.db
.select({ fileId: knowledgeBaseFiles.fileId })
.from(knowledgeBaseFiles)
.where(eq(knowledgeBaseFiles.userId, this.userId));
.where(
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
knowledgeBaseFiles,
),
);
const fileIds = [...new Set(allKbFileIds.map((f) => f.fileId))];
let deletedFiles: Array<{ id: string; url: string | null }> = [];
if (fileIds.length > 0) {
const fileModel = new FileModel(this.db, this.userId);
const fileModel = new FileModel(this.db, this.userId, this.workspaceId);
const result = await fileModel.deleteMany(fileIds, removeGlobalFile);
deletedFiles = (result || []).map((f) => ({ id: f.id, url: f.url }));
}
await this.db.delete(knowledgeBases).where(eq(knowledgeBases.userId, this.userId));
await this.db.delete(knowledgeBases).where(this.ownership());
return { deletedFiles };
};
@@ -7,6 +7,7 @@ import type {
} from '../schemas/llmGenerationTracing';
import { llmGenerationTracing } from '../schemas/llmGenerationTracing';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
export interface RecordLlmGenerationParams {
agentId?: string | null;
@@ -52,10 +53,19 @@ export interface UpdateLlmGenerationFeedbackParams {
export class LlmGenerationTracingModel {
private readonly db: LobeChatDatabase;
private readonly userId: string;
private readonly workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership() {
return buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
llmGenerationTracing,
);
}
async record(params: RecordLlmGenerationParams): Promise<{ id: string }> {
@@ -86,6 +96,7 @@ export class LlmGenerationTracingModel {
trigger: params.trigger ?? null,
userId: this.userId,
validationFailed: params.validationFailed ?? false,
workspaceId: this.workspaceId ?? null,
};
const [row] = await this.db
@@ -116,7 +127,7 @@ export class LlmGenerationTracingModel {
feedbackSource: params.source,
feedbackUpdatedAt: new Date(),
})
.where(and(eq(llmGenerationTracing.id, id), eq(llmGenerationTracing.userId, this.userId)))
.where(and(eq(llmGenerationTracing.id, id), this.ownership()))
.returning({ id: llmGenerationTracing.id });
return { updated: rows.length > 0 };
}
@@ -125,7 +136,7 @@ export class LlmGenerationTracingModel {
const [row] = await this.db
.select()
.from(llmGenerationTracing)
.where(and(eq(llmGenerationTracing.id, id), eq(llmGenerationTracing.userId, this.userId)))
.where(and(eq(llmGenerationTracing.id, id), this.ownership()))
.limit(1);
return row ?? null;
}
@@ -134,7 +145,7 @@ export class LlmGenerationTracingModel {
return this.db
.select()
.from(llmGenerationTracing)
.where(eq(llmGenerationTracing.userId, this.userId))
.where(this.ownership())
.orderBy(desc(llmGenerationTracing.createdAt))
.limit(limit);
}
+161 -105
View File
@@ -74,6 +74,7 @@ import type { LobeChatDatabase, Transaction } from '../type';
import { sanitizeBm25Query } from '../utils/bm25';
import { genEndDateWhere, genRangeWhere, genStartDateWhere, genWhere } from '../utils/genWhere';
import { idGenerator } from '../utils/idGenerator';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
import { recomputeTopicUsage } from './topicUsage';
/**
@@ -195,12 +196,29 @@ interface SplitCreateMessageParams {
export class MessageModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages);
private pluginsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messagePlugins);
private translatesOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messageTranslates);
private ttsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messageTTS);
private agentsToSessionsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentsToSessions);
/**
* Touch topics' updatedAt timestamp within a transaction
*/
@@ -209,7 +227,12 @@ export class MessageModel {
await trx
.update(topics)
.set({ updatedAt: new Date() })
.where(and(inArray(topics.id, topicIds), eq(topics.userId, this.userId)));
.where(
and(
inArray(topics.id, topicIds),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics),
),
);
}
// **************** Query *************** //
@@ -414,7 +437,7 @@ export class MessageModel {
.from(messages)
.where(
and(
eq(messages.userId, this.userId),
this.ownership(),
// Filter out messages that belong to MessageGroups
isNull(messages.messageGroupId),
where,
@@ -785,7 +808,10 @@ export class MessageModel {
})
.from(threads)
.where(
and(eq(threads.userId, this.userId), inArray(threads.sourceMessageId, taskMessageIds)),
and(
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, threads),
inArray(threads.sourceMessageId, taskMessageIds),
),
),
{ taskMessageCount: taskMessageIds.length },
);
@@ -889,7 +915,7 @@ export class MessageModel {
ttsVoice: messageTTS.voice,
})
.from(messages)
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)))
.where(and(this.ownership(), inArray(messages.id, messageIds)))
.leftJoin(messagePlugins, eq(messagePlugins.id, messages.id))
.leftJoin(messageTranslates, eq(messageTranslates.id, messages.id))
.leftJoin(messageTTS, eq(messageTTS.id, messages.id))
@@ -957,7 +983,10 @@ export class MessageModel {
.from(threads)
.where(
and(
eq(threads.userId, this.userId),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
threads,
),
inArray(threads.sourceMessageId, taskMessageIds),
),
)
@@ -1104,7 +1133,7 @@ export class MessageModel {
): Promise<UIChatMessage[]> => {
// 1. Query MessageGroups for this topic, optionally filtered by time range
const whereConditions = [
eq(messageGroups.userId, this.userId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messageGroups),
eq(messageGroups.topicId, topicId),
];
@@ -1145,7 +1174,7 @@ export class MessageModel {
messageGroupId: messages.messageGroupId,
})
.from(messages)
.where(and(eq(messages.userId, this.userId), inArray(messages.messageGroupId, groupIds)))
.where(and(this.ownership(), inArray(messages.messageGroupId, groupIds)))
.orderBy(asc(messages.createdAt)),
{ groupCount: groupIds.length },
);
@@ -1247,7 +1276,10 @@ export class MessageModel {
private buildThreadQueryCondition = async (threadId: string): Promise<SQL | undefined> => {
// Fetch the thread info to get sourceMessageId and type
const thread = await this.db.query.threads.findFirst({
where: and(eq(threads.id, threadId), eq(threads.userId, this.userId)),
where: and(
eq(threads.id, threadId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, threads),
),
});
if (!thread?.sourceMessageId || !thread?.topicId) {
@@ -1280,7 +1312,7 @@ export class MessageModel {
const agentSession = await this.db
.select({ sessionId: agentsToSessions.sessionId })
.from(agentsToSessions)
.where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)))
.where(and(eq(agentsToSessions.agentId, agentId), this.agentsToSessionsOwnership()))
.limit(1);
const associatedSessionId = agentSession[0]?.sessionId;
@@ -1293,7 +1325,7 @@ export class MessageModel {
findById = async (id: string) => {
return this.db.query.messages.findFirst({
where: and(eq(messages.id, id), eq(messages.userId, this.userId)),
where: and(eq(messages.id, id), this.ownership()),
});
};
@@ -1340,7 +1372,7 @@ export class MessageModel {
// For Standalone type, only return the source message
if (threadType === ThreadType.Standalone) {
const sourceMessage = await this.db.query.messages.findFirst({
where: and(eq(messages.id, sourceMessageId), eq(messages.userId, this.userId)),
where: and(eq(messages.id, sourceMessageId), this.ownership()),
});
return sourceMessage ? [sourceMessage as DBMessageItem] : [];
@@ -1348,7 +1380,7 @@ export class MessageModel {
// For Continuation type, get the source message first to know its createdAt
const sourceMessage = await this.db.query.messages.findFirst({
where: and(eq(messages.id, sourceMessageId), eq(messages.userId, this.userId)),
where: and(eq(messages.id, sourceMessageId), this.ownership()),
});
if (!sourceMessage) return [];
@@ -1361,7 +1393,7 @@ export class MessageModel {
.from(messages)
.where(
and(
eq(messages.userId, this.userId),
this.ownership(),
eq(messages.topicId, topicId),
isNull(messages.threadId), // Only main conversation messages (not in any thread)
or(
@@ -1400,7 +1432,7 @@ export class MessageModel {
const result = await this.db
.select()
.from(messages)
.where(eq(messages.userId, this.userId))
.where(and(this.ownership()))
.orderBy(desc(messages.createdAt))
.limit(pageSize)
.offset(offset);
@@ -1411,7 +1443,7 @@ export class MessageModel {
queryBySessionId = async (sessionId?: string | null) => {
const result = await this.db.query.messages.findMany({
orderBy: [asc(messages.createdAt)],
where: and(eq(messages.userId, this.userId), this.matchSession(sessionId)),
where: and(this.ownership(), this.matchSession(sessionId)),
});
return result as DBMessageItem[];
@@ -1424,7 +1456,7 @@ export class MessageModel {
const result = await this.db
.select()
.from(messages)
.where(and(eq(messages.userId, this.userId), sql`${messages.content} @@@ ${bm25Query}`))
.where(and(this.ownership(), sql`${messages.content} @@@ ${bm25Query}`))
.orderBy(desc(messages.createdAt));
return result as DBMessageItem[];
@@ -1442,7 +1474,7 @@ export class MessageModel {
.from(messages)
.where(
genWhere([
eq(messages.userId, this.userId),
this.ownership(),
params?.range
? genRangeWhere(params.range, messages.createdAt, (date) => date.toDate())
: undefined,
@@ -1462,7 +1494,7 @@ export class MessageModel {
const rows = await this.db
.select({ id: messages.id })
.from(messages)
.where(and(eq(messages.userId, this.userId), eq(messages.topicId, topicId)))
.where(and(eq(messages.topicId, topicId), this.ownership()))
.limit(1);
return rows.length > 0;
@@ -1472,13 +1504,7 @@ export class MessageModel {
const rows = (await this.db
.select()
.from(messages)
.where(
and(
eq(messages.userId, this.userId),
eq(messages.topicId, topicId),
eq(messages.role, 'assistant'),
),
)
.where(and(eq(messages.topicId, topicId), eq(messages.role, 'assistant'), this.ownership()))
.orderBy(asc(messages.createdAt))
.limit(1)) as DBMessageItem[];
@@ -1497,7 +1523,7 @@ export class MessageModel {
.from(messages)
.where(
genWhere([
eq(messages.userId, this.userId),
this.ownership(),
params?.range
? genRangeWhere(params.range, messages.createdAt, (date) => date.toDate())
: undefined,
@@ -1520,7 +1546,7 @@ export class MessageModel {
id: messages.model,
})
.from(messages)
.where(and(eq(messages.userId, this.userId), isNotNull(messages.model)))
.where(and(this.ownership(), isNotNull(messages.model)))
.having(({ count }) => gt(count, 0))
.groupBy(messages.model)
.orderBy(desc(sql`count`), asc(messages.model))
@@ -1539,7 +1565,7 @@ export class MessageModel {
.from(messages)
.where(
genWhere([
eq(messages.userId, this.userId),
this.ownership(),
genRangeWhere(
[startDate.format('YYYY-MM-DD'), endDate.add(1, 'day').format('YYYY-MM-DD')],
messages.createdAt,
@@ -1606,7 +1632,7 @@ export class MessageModel {
.from(messages)
.where(
genWhere([
eq(messages.userId, this.userId),
this.ownership(),
eq(messages.role, 'assistant'),
genRangeWhere(
[startDate.format('YYYY-MM-DD'), endDate.add(1, 'day').format('YYYY-MM-DD')],
@@ -1658,7 +1684,7 @@ export class MessageModel {
const result = await this.db
.select({ id: messages.id })
.from(messages)
.where(eq(messages.userId, this.userId))
.where(and(this.ownership()))
.limit(n + 1);
return result.length > n;
@@ -1671,7 +1697,7 @@ export class MessageModel {
const result = await this.db
.select({ id: messages.id })
.from(messages)
.where(eq(messages.userId, this.userId))
.where(and(this.ownership()))
.limit(n);
return result.length;
@@ -1716,23 +1742,25 @@ export class MessageModel {
// Ensure group message does not populate sessionId
const normalizedMessage = message.groupId ? { ...message, sessionId: null } : message;
return {
...normalizedMessage,
// Sanitize content to strip null bytes that PostgreSQL rejects
content: sanitizeNullBytes(normalizedMessage.content),
// TODO: remove this when the client is updated
createdAt: createdAt ? new Date(createdAt) : undefined,
id,
model: fromModel,
provider: fromProvider,
updatedAt: updatedAt ? new Date(updatedAt) : undefined,
// Promote token usage into the dedicated `usage` column, preferring a
// top-level `usage` over the legacy `metadata.usage`.
usage:
normalizedMessage.usage ??
(normalizedMessage.metadata as { usage?: ModelUsage } | undefined)?.usage,
userId: this.userId,
};
return buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...normalizedMessage,
// Sanitize content to strip null bytes that PostgreSQL rejects
content: sanitizeNullBytes(normalizedMessage.content),
// TODO: remove this when the client is updated
createdAt: createdAt ? new Date(createdAt) : undefined,
id,
model: fromModel,
provider: fromProvider,
updatedAt: updatedAt ? new Date(updatedAt) : undefined,
// Promote token usage into the dedicated `usage` column, preferring a
// top-level `usage` over the legacy `metadata.usage`.
usage:
normalizedMessage.usage ??
(normalizedMessage.metadata as { usage?: ModelUsage } | undefined)?.usage,
},
);
};
private insertMessageRelationsInTransaction = async (
@@ -1763,6 +1791,7 @@ export class MessageModel {
toolCallId: message.tool_call_id,
type: plugin?.type,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
}),
);
}
@@ -1772,9 +1801,14 @@ export class MessageModel {
timing,
`${timingPrefix}.files.insert`,
() =>
trx
.insert(messagesFiles)
.values(files.map((file) => ({ fileId: file, messageId: id, userId: this.userId }))),
trx.insert(messagesFiles).values(
files.map((file) => ({
fileId: file,
messageId: id,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
),
{ fileCount: files.length },
);
}
@@ -1791,6 +1825,7 @@ export class MessageModel {
queryId: ragQueryId,
similarity: chunk.similarity?.toString(),
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
),
{ chunkCount: fileChunks.length },
@@ -1956,10 +1991,13 @@ export class MessageModel {
};
batchCreate = async (newMessages: DBMessageItem[]) => {
const messagesToInsert = newMessages.map((m) => {
// TODO: need a better way to handle this
return { ...m, role: m.role as any, userId: this.userId };
});
const messagesToInsert = newMessages.map((m) =>
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
// TODO: need a better way to handle this
{ ...m, role: m.role as any },
),
);
const topicIds = [...new Set(newMessages.map((m) => m.topicId).filter(Boolean))] as string[];
@@ -1975,7 +2013,7 @@ export class MessageModel {
createMessageQuery = async (params: NewMessageQueryParams) => {
const result = await this.db
.insert(messageQueries)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return result[0];
@@ -2017,6 +2055,7 @@ export class MessageModel {
fileId: file.id,
messageId: id,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
),
{ imageCount: imageList.length },
@@ -2034,7 +2073,7 @@ export class MessageModel {
trx
.select({ metadata: messages.metadata })
.from(messages)
.where(and(eq(messages.id, id), eq(messages.userId, this.userId))),
.where(and(eq(messages.id, id), this.ownership())),
);
mergedMetadata = merge(existingMessage?.metadata || {}, metadataPatch);
}
@@ -2050,7 +2089,7 @@ export class MessageModel {
...(mergedMetadata && { metadata: mergedMetadata }),
...(usageToWrite && { usage: usageToWrite }),
})
.where(and(eq(messages.id, id), eq(messages.userId, this.userId)))
.where(and(eq(messages.id, id), this.ownership()))
.returning({ topicId: messages.topicId }),
{ hasMetadata: !!metadataPatch, valueKeys: Object.keys(message) },
);
@@ -2072,7 +2111,7 @@ export class MessageModel {
await runTimedStage(
timing,
'db.message.update.topic.recomputeUsage',
() => recomputeTopicUsage(trx, this.userId, updated.topicId!),
() => recomputeTopicUsage(trx, this.userId, updated.topicId!, this.workspaceId),
{ topicCount: 1 },
);
}
@@ -2094,7 +2133,7 @@ export class MessageModel {
updateMetadata = async (id: string, metadata: Record<string, any>) => {
const item = await this.db.query.messages.findFirst({
where: and(eq(messages.id, id), eq(messages.userId, this.userId)),
where: and(eq(messages.id, id), this.ownership()),
});
if (!item) return;
@@ -2107,28 +2146,31 @@ export class MessageModel {
return this.db
.update(messages)
.set({ metadata: mergedMetadata, ...(usageToWrite && { usage: usageToWrite }) })
.where(and(eq(messages.userId, this.userId), eq(messages.id, id)));
.where(and(eq(messages.id, id), this.ownership()));
};
updatePluginState = async (id: string, state: Record<string, any>): Promise<void> => {
const item = await this.db.query.messagePlugins.findFirst({
where: eq(messagePlugins.id, id),
where: and(eq(messagePlugins.id, id), this.pluginsOwnership()),
});
if (!item) throw new Error('Plugin not found');
await this.db
.update(messagePlugins)
.set({ state: merge(item.state || {}, state) })
.where(eq(messagePlugins.id, id));
.where(and(eq(messagePlugins.id, id), this.pluginsOwnership()));
};
updateMessagePlugin = async (id: string, value: Partial<MessagePluginItem>) => {
const item = await this.db.query.messagePlugins.findFirst({
where: eq(messagePlugins.id, id),
where: and(eq(messagePlugins.id, id), this.pluginsOwnership()),
});
if (!item) throw new Error('Plugin not found');
return this.db.update(messagePlugins).set(value).where(eq(messagePlugins.id, id));
return this.db
.update(messagePlugins)
.set(value)
.where(and(eq(messagePlugins.id, id), this.pluginsOwnership()));
};
/**
@@ -2144,7 +2186,7 @@ export class MessageModel {
*/
findMessagePlugin = async (messageId: string): Promise<MessagePluginItem | undefined> => {
const row = await this.db.query.messagePlugins.findFirst({
where: eq(messagePlugins.id, messageId),
where: and(eq(messagePlugins.id, messageId), this.pluginsOwnership()),
});
if (!row) return undefined;
return {
@@ -2185,7 +2227,7 @@ export class MessageModel {
})
.from(messagePlugins)
.innerJoin(messages, eq(messagePlugins.id, messages.id))
.where(and(eq(messages.topicId, topicId), eq(messages.userId, this.userId)))
.where(and(eq(messages.topicId, topicId), this.ownership(), this.pluginsOwnership()))
.orderBy(asc(messages.createdAt), asc(messages.id));
return rows.map((row) => ({
@@ -2231,7 +2273,7 @@ export class MessageModel {
if (metadata !== undefined) {
// Need to merge with existing metadata
const existingMessage = await trx.query.messages.findFirst({
where: and(eq(messages.id, id), eq(messages.userId, this.userId)),
where: and(eq(messages.id, id), this.ownership()),
});
messageUpdateData.metadata = merge(existingMessage?.metadata || {}, metadata);
}
@@ -2240,14 +2282,14 @@ export class MessageModel {
await trx
.update(messages)
.set(messageUpdateData)
.where(and(eq(messages.id, id), eq(messages.userId, this.userId)));
.where(and(eq(messages.id, id), this.ownership()));
}
}
// Update messagePlugins table (pluginState, pluginError)
if (pluginState !== undefined || pluginError !== undefined) {
const pluginItem = await trx.query.messagePlugins.findFirst({
where: eq(messagePlugins.id, id),
where: and(eq(messagePlugins.id, id), this.pluginsOwnership()),
});
if (pluginItem) {
@@ -2265,7 +2307,7 @@ export class MessageModel {
await trx
.update(messagePlugins)
.set(pluginUpdateData)
.where(eq(messagePlugins.id, id));
.where(and(eq(messagePlugins.id, id), this.pluginsOwnership()));
}
}
}
@@ -2300,7 +2342,7 @@ export class MessageModel {
})
.from(messagePlugins)
.innerJoin(messages, eq(messages.id, messagePlugins.id))
.where(and(eq(messagePlugins.toolCallId, toolCallId), eq(messages.userId, this.userId)))
.where(and(eq(messagePlugins.toolCallId, toolCallId), this.ownership()))
.limit(1);
if (!toolResult?.parentId) {
@@ -2311,7 +2353,7 @@ export class MessageModel {
const [parentMessage] = await trx
.select({ id: messages.id, tools: messages.tools })
.from(messages)
.where(eq(messages.id, toolResult.parentId))
.where(and(eq(messages.id, toolResult.parentId), this.ownership()))
.limit(1);
if (!parentMessage?.tools) {
@@ -2334,12 +2376,12 @@ export class MessageModel {
trx
.update(messagePlugins)
.set({ arguments: args })
.where(eq(messagePlugins.id, toolResult.toolPluginId)),
.where(and(eq(messagePlugins.id, toolResult.toolPluginId), this.pluginsOwnership())),
// Update parent assistant message's tools
trx
.update(messages)
.set({ tools: updatedTools })
.where(eq(messages.id, parentMessage.id)),
.where(and(eq(messages.id, parentMessage.id), this.ownership())),
]);
});
@@ -2352,21 +2394,29 @@ export class MessageModel {
updateTranslate = async (id: string, translate: Partial<ChatTranslate>) => {
const result = await this.db.query.messageTranslates.findFirst({
where: and(eq(messageTranslates.id, id)),
where: and(eq(messageTranslates.id, id), this.translatesOwnership()),
});
// If the message does not exist in the translate table, insert it
if (!result) {
return this.db.insert(messageTranslates).values({ ...translate, id, userId: this.userId });
return this.db.insert(messageTranslates).values({
...translate,
id,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
});
}
// or just update the existing one
return this.db.update(messageTranslates).set(translate).where(eq(messageTranslates.id, id));
return this.db
.update(messageTranslates)
.set(translate)
.where(and(eq(messageTranslates.id, id), this.translatesOwnership()));
};
updateTTS = async (id: string, tts: Partial<ChatTTS>) => {
const result = await this.db.query.messageTTS.findFirst({
where: and(eq(messageTTS.id, id)),
where: and(eq(messageTTS.id, id), this.ttsOwnership()),
});
// If the message does not exist in the translate table, insert it
@@ -2377,6 +2427,7 @@ export class MessageModel {
id,
userId: this.userId,
voice: tts.voice,
workspaceId: this.workspaceId ?? null,
});
}
@@ -2384,7 +2435,7 @@ export class MessageModel {
return this.db
.update(messageTTS)
.set({ contentMd5: tts.contentMd5, fileId: tts.file, voice: tts.voice })
.where(eq(messageTTS.id, id));
.where(and(eq(messageTTS.id, id), this.ttsOwnership()));
};
async updateMessageRAG(id: string, { ragQueryId, fileChunks }: UpdateMessageRAGParams) {
@@ -2395,6 +2446,7 @@ export class MessageModel {
queryId: ragQueryId,
similarity: chunk.similarity?.toString(),
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
);
}
@@ -2407,7 +2459,7 @@ export class MessageModel {
const message = await tx
.select()
.from(messages)
.where(and(eq(messages.id, id), eq(messages.userId, this.userId)))
.where(and(eq(messages.id, id), this.ownership()))
.limit(1);
// If the message to be deleted is not found, return directly
@@ -2418,7 +2470,7 @@ export class MessageModel {
await tx
.update(messages)
.set({ parentId: message[0].parentId })
.where(and(eq(messages.parentId, id), eq(messages.userId, this.userId)));
.where(and(eq(messages.parentId, id), this.ownership()));
// 3. Check if the message contains tools
const toolCallIds = (message[0].tools as ChatToolPayload[])
@@ -2443,12 +2495,12 @@ export class MessageModel {
// 6. Delete all related messages
await tx
.delete(messages)
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIdsToDelete)));
.where(and(this.ownership(), inArray(messages.id, messageIdsToDelete)));
// 7. Keep the topic's usage rollup in sync (pure derived — a removed
// assistant message must drop out of the topic totals).
if (message[0].topicId) {
await recomputeTopicUsage(tx, this.userId, message[0].topicId);
await recomputeTopicUsage(tx, this.userId, message[0].topicId, this.workspaceId);
}
});
};
@@ -2461,7 +2513,7 @@ export class MessageModel {
const toDelete = await tx
.select({ id: messages.id, parentId: messages.parentId, topicId: messages.topicId })
.from(messages)
.where(and(eq(messages.userId, this.userId), inArray(messages.id, ids)));
.where(and(this.ownership(), inArray(messages.id, ids)));
if (toDelete.length === 0) return;
@@ -2506,30 +2558,27 @@ export class MessageModel {
.select({ id: messages.id, parentId: messages.parentId })
.from(messages)
.where(
and(
eq(messages.userId, this.userId),
inArray(messages.parentId, ids),
not(inArray(messages.id, ids)),
),
and(this.ownership(), inArray(messages.parentId, ids), not(inArray(messages.id, ids))),
);
// 5. Update each child's parentId to the final ancestor
for (const child of children) {
const newParentId = finalAncestorMap.get(child.parentId!) ?? null;
await tx.update(messages).set({ parentId: newParentId }).where(eq(messages.id, child.id));
await tx
.update(messages)
.set({ parentId: newParentId })
.where(and(eq(messages.id, child.id), this.ownership()));
}
// 6. Delete the messages
await tx
.delete(messages)
.where(and(eq(messages.userId, this.userId), inArray(messages.id, ids)));
await tx.delete(messages).where(and(this.ownership(), inArray(messages.id, ids)));
// 7. Recompute the usage rollup for every affected topic (pure derived).
const affectedTopicIds = [
...new Set(toDelete.map((m) => m.topicId).filter(Boolean) as string[]),
];
for (const topicId of affectedTopicIds) {
await recomputeTopicUsage(tx, this.userId, topicId);
await recomputeTopicUsage(tx, this.userId, topicId, this.workspaceId);
}
});
};
@@ -2547,6 +2596,7 @@ export class MessageModel {
fileId,
messageId,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
);
return { success: true };
@@ -2559,17 +2609,23 @@ export class MessageModel {
deleteMessageTranslate = async (id: string) =>
this.db
.delete(messageTranslates)
.where(and(eq(messageTranslates.id, id), eq(messageTranslates.userId, this.userId)));
.where(and(eq(messageTranslates.id, id), this.translatesOwnership()));
deleteMessageTTS = async (id: string) =>
this.db
.delete(messageTTS)
.where(and(eq(messageTTS.id, id), eq(messageTTS.userId, this.userId)));
this.db.delete(messageTTS).where(and(eq(messageTTS.id, id), this.ttsOwnership()));
deleteMessageQuery = async (id: string) =>
this.db
.delete(messageQueries)
.where(and(eq(messageQueries.id, id), eq(messageQueries.userId, this.userId)));
.where(
and(
eq(messageQueries.id, id),
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
messageQueries,
),
),
);
deleteMessagesBySession = async (
sessionId?: string | null,
@@ -2580,7 +2636,7 @@ export class MessageModel {
.delete(messages)
.where(
and(
eq(messages.userId, this.userId),
this.ownership(),
this.matchSession(sessionId),
this.matchTopic(topicId),
this.matchGroup(groupId),
@@ -2588,7 +2644,7 @@ export class MessageModel {
);
deleteAllMessages = async () => {
return this.db.delete(messages).where(eq(messages.userId, this.userId));
return this.db.delete(messages).where(and(this.ownership()));
};
/**
@@ -2602,7 +2658,7 @@ export class MessageModel {
const agentSession = await this.db
.select({ sessionId: agentsToSessions.sessionId })
.from(agentsToSessions)
.where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)))
.where(and(eq(agentsToSessions.agentId, agentId), this.agentsToSessionsOwnership()))
.limit(1);
const associatedSessionId = agentSession[0]?.sessionId;
@@ -2612,7 +2668,7 @@ export class MessageModel {
? or(eq(messages.agentId, agentId), eq(messages.sessionId, associatedSessionId))
: eq(messages.agentId, agentId);
return this.db.delete(messages).where(and(eq(messages.userId, this.userId), agentCondition));
return this.db.delete(messages).where(and(this.ownership(), agentCondition));
};
// **************** Helper *************** //
@@ -50,6 +50,13 @@ export class MessengerAccountLinkModel {
this.db = db;
}
// A given IM identity maps to exactly one link per `(userId, platform,
// tenantId)` — the unique index already enforces this — so ownership is
// purely by `userId`. `workspaceId` on the row is the *active scope* (derived
// from the active agent), NOT part of the link's identity, so it must not
// scope lookups; otherwise switching scope would orphan the existing link.
private ownership = (): SQL => eq(messengerAccountLinks.userId, this.userId);
// --------------- User-scoped CRUD ---------------
/**
@@ -87,6 +94,7 @@ export class MessengerAccountLinkModel {
tenantId,
updatedAt: now,
userId: this.userId,
workspaceId: params.workspaceId ?? null,
})
.onConflictDoNothing({
target: [
@@ -131,6 +139,7 @@ export class MessengerAccountLinkModel {
activeAgentId: params.activeAgentId ?? byIdentity.activeAgentId,
platformUsername: params.platformUsername ?? null,
updatedAt: now,
workspaceId: params.workspaceId ?? null,
})
.where(eq(messengerAccountLinks.id, byIdentity.id))
.returning();
@@ -149,6 +158,7 @@ export class MessengerAccountLinkModel {
activeAgentId: params.activeAgentId ?? existingForUser.activeAgentId,
platformUsername: params.platformUsername ?? null,
updatedAt: now,
workspaceId: params.workspaceId ?? null,
})
.where(eq(messengerAccountLinks.id, existingForUser.id))
.returning();
@@ -161,14 +171,11 @@ export class MessengerAccountLinkModel {
delete = async (id: string) => {
return this.db
.delete(messengerAccountLinks)
.where(and(eq(messengerAccountLinks.id, id), eq(messengerAccountLinks.userId, this.userId)));
.where(and(eq(messengerAccountLinks.id, id), this.ownership()));
};
deleteByPlatform = async (platform: string, tenantId?: string) => {
const conditions: SQL[] = [
eq(messengerAccountLinks.userId, this.userId),
eq(messengerAccountLinks.platform, platform),
];
const conditions: SQL[] = [this.ownership(), eq(messengerAccountLinks.platform, platform)];
if (tenantId !== undefined) {
conditions.push(eq(messengerAccountLinks.tenantId, tenantId));
}
@@ -176,10 +183,7 @@ export class MessengerAccountLinkModel {
};
list = async (): Promise<MessengerAccountLinkItem[]> => {
return this.db
.select()
.from(messengerAccountLinks)
.where(eq(messengerAccountLinks.userId, this.userId));
return this.db.select().from(messengerAccountLinks).where(this.ownership());
};
/**
@@ -192,10 +196,7 @@ export class MessengerAccountLinkModel {
platform: string,
tenantId?: string,
): Promise<MessengerAccountLinkItem | undefined> => {
const conditions: SQL[] = [
eq(messengerAccountLinks.userId, this.userId),
eq(messengerAccountLinks.platform, platform),
];
const conditions: SQL[] = [this.ownership(), eq(messengerAccountLinks.platform, platform)];
if (tenantId !== undefined) {
conditions.push(eq(messengerAccountLinks.tenantId, tenantId));
}
@@ -208,23 +209,25 @@ export class MessengerAccountLinkModel {
return result;
};
/** Update which agent the IM session is currently routed to. */
/**
* Update which agent the IM session is currently routed to, together with
* the active scope (`workspaceId`) derived from that agent. Passing
* `agentId: null` clears the active agent and resets the scope to personal.
*/
setActiveAgent = async (
platform: string,
agentId: string | null,
workspaceId: string | null,
tenantId?: string,
): Promise<MessengerAccountLinkItem | undefined> => {
const conditions: SQL[] = [
eq(messengerAccountLinks.userId, this.userId),
eq(messengerAccountLinks.platform, platform),
];
const conditions: SQL[] = [this.ownership(), eq(messengerAccountLinks.platform, platform)];
if (tenantId !== undefined) {
conditions.push(eq(messengerAccountLinks.tenantId, tenantId));
}
const [updated] = await this.db
.update(messengerAccountLinks)
.set({ activeAgentId: agentId, updatedAt: new Date() })
.set({ activeAgentId: agentId, updatedAt: new Date(), workspaceId })
.where(and(...conditions))
.returning();
@@ -276,4 +279,26 @@ export class MessengerAccountLinkModel {
.returning();
return updated;
};
/**
* Static scope switch used by IM `/switch`. Moves the link to a new active
* scope (personal `null`, or a workspace id) and sets the active agent to
* `agentId` callers pass the scope's default agent (inbox/LobeAI) so
* switching never leaves the session agent-less; pass `null` only when the
* target scope has no agents. Caller must authorize access to the target
* scope first.
*/
static setActiveScope = async (
db: LobeChatDatabase,
linkId: string,
workspaceId: string | null,
agentId: string | null = null,
): Promise<MessengerAccountLinkItem | undefined> => {
const [updated] = await db
.update(messengerAccountLinks)
.set({ activeAgentId: agentId, updatedAt: new Date(), workspaceId })
.where(eq(messengerAccountLinks.id, linkId))
.returning();
return updated;
};
}
+9 -15
View File
@@ -13,12 +13,14 @@ export class NotificationModel {
this.userId = userId;
}
private ownership = () => eq(notifications.userId, this.userId);
async list(
opts: { category?: string; cursor?: string; limit?: number; unreadOnly?: boolean } = {},
) {
const { cursor, limit = 20, category, unreadOnly } = opts;
const conditions = [eq(notifications.userId, this.userId), eq(notifications.isArchived, false)];
const conditions = [this.ownership(), eq(notifications.isArchived, false)];
if (unreadOnly) {
conditions.push(eq(notifications.isRead, false));
@@ -32,7 +34,7 @@ export class NotificationModel {
const cursorRow = await this.db
.select({ createdAt: notifications.createdAt, id: notifications.id })
.from(notifications)
.where(and(eq(notifications.id, cursor), eq(notifications.userId, this.userId)))
.where(and(eq(notifications.id, cursor), this.ownership()))
.limit(1);
if (cursorRow[0]) {
@@ -60,11 +62,7 @@ export class NotificationModel {
.select({ count: count() })
.from(notifications)
.where(
and(
eq(notifications.userId, this.userId),
eq(notifications.isRead, false),
eq(notifications.isArchived, false),
),
and(this.ownership(), eq(notifications.isRead, false), eq(notifications.isArchived, false)),
);
return result?.count ?? 0;
@@ -76,7 +74,7 @@ export class NotificationModel {
return this.db
.update(notifications)
.set({ isRead: true, updatedAt: new Date() })
.where(and(eq(notifications.userId, this.userId), inArray(notifications.id, ids)));
.where(and(this.ownership(), inArray(notifications.id, ids)));
}
async markAllAsRead() {
@@ -84,11 +82,7 @@ export class NotificationModel {
.update(notifications)
.set({ isRead: true, updatedAt: new Date() })
.where(
and(
eq(notifications.userId, this.userId),
eq(notifications.isRead, false),
eq(notifications.isArchived, false),
),
and(this.ownership(), eq(notifications.isRead, false), eq(notifications.isArchived, false)),
);
}
@@ -96,14 +90,14 @@ export class NotificationModel {
return this.db
.update(notifications)
.set({ isArchived: true, updatedAt: new Date() })
.where(and(eq(notifications.id, id), eq(notifications.userId, this.userId)));
.where(and(eq(notifications.id, id), this.ownership()));
}
async archiveAll() {
return this.db
.update(notifications)
.set({ isArchived: true, updatedAt: new Date() })
.where(and(eq(notifications.userId, this.userId), eq(notifications.isArchived, false)));
.where(and(this.ownership(), eq(notifications.isArchived, false)));
}
// ─── Write-side (used by NotificationService in cloud) ─────────
+16 -14
View File
@@ -4,16 +4,25 @@ import { and, desc, eq } from 'drizzle-orm';
import type { InstalledPluginItem, NewInstalledPlugin } from '../schemas';
import { userInstalledPlugins } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
export class PluginModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
userInstalledPlugins,
);
create = async (
params: Pick<
NewInstalledPlugin,
@@ -22,7 +31,7 @@ export class PluginModel {
) => {
const [result] = await this.db
.insert(userInstalledPlugins)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.onConflictDoUpdate({
set: { ...params, updatedAt: new Date() },
target: [userInstalledPlugins.identifier, userInstalledPlugins.userId],
@@ -35,13 +44,11 @@ export class PluginModel {
delete = async (id: string) => {
return this.db
.delete(userInstalledPlugins)
.where(
and(eq(userInstalledPlugins.identifier, id), eq(userInstalledPlugins.userId, this.userId)),
);
.where(and(eq(userInstalledPlugins.identifier, id), this.ownership()));
};
deleteAll = async () => {
return this.db.delete(userInstalledPlugins).where(eq(userInstalledPlugins.userId, this.userId));
return this.db.delete(userInstalledPlugins).where(this.ownership());
};
query = async () => {
@@ -57,7 +64,7 @@ export class PluginModel {
updatedAt: userInstalledPlugins.updatedAt,
})
.from(userInstalledPlugins)
.where(eq(userInstalledPlugins.userId, this.userId))
.where(this.ownership())
.orderBy(desc(userInstalledPlugins.createdAt));
return data.map<LobeTool>((item) => ({
@@ -68,10 +75,7 @@ export class PluginModel {
findById = async (id: string) => {
return this.db.query.userInstalledPlugins.findFirst({
where: and(
eq(userInstalledPlugins.identifier, id),
eq(userInstalledPlugins.userId, this.userId),
),
where: and(eq(userInstalledPlugins.identifier, id), this.ownership()),
});
};
@@ -79,8 +83,6 @@ export class PluginModel {
return this.db
.update(userInstalledPlugins)
.set({ ...value, updatedAt: new Date() })
.where(
and(eq(userInstalledPlugins.identifier, id), eq(userInstalledPlugins.userId, this.userId)),
);
.where(and(eq(userInstalledPlugins.identifier, id), this.ownership()));
};
}
+13 -14
View File
@@ -1,31 +1,35 @@
import type { RAGEvalDataSetItem } from '@lobechat/types';
import { and, desc, eq } from 'drizzle-orm';
import type {NewEvalDatasetsItem } from '../../schemas';
import type { NewEvalDatasetsItem } from '../../schemas';
import { evalDatasets } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class EvalDatasetModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, evalDatasets);
create = async (params: NewEvalDatasetsItem) => {
const [result] = await this.db
.insert(evalDatasets)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return result;
};
delete = async (id: string) => {
return this.db
.delete(evalDatasets)
.where(and(eq(evalDatasets.id, id), eq(evalDatasets.userId, this.userId)));
return this.db.delete(evalDatasets).where(and(eq(evalDatasets.id, id), this.ownership()));
};
query = async (knowledgeBaseId: string): Promise<RAGEvalDataSetItem[]> => {
@@ -38,18 +42,13 @@ export class EvalDatasetModel {
updatedAt: evalDatasets.updatedAt,
})
.from(evalDatasets)
.where(
and(
eq(evalDatasets.userId, this.userId),
eq(evalDatasets.knowledgeBaseId, knowledgeBaseId),
),
)
.where(and(this.ownership(), eq(evalDatasets.knowledgeBaseId, knowledgeBaseId)))
.orderBy(desc(evalDatasets.createdAt));
};
findById = async (id: string) => {
return this.db.query.evalDatasets.findFirst({
where: and(eq(evalDatasets.id, id), eq(evalDatasets.userId, this.userId)),
where: and(eq(evalDatasets.id, id), this.ownership()),
});
};
@@ -57,6 +56,6 @@ export class EvalDatasetModel {
return this.db
.update(evalDatasets)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(evalDatasets.id, id), eq(evalDatasets.userId, this.userId)));
.where(and(eq(evalDatasets.id, id), this.ownership()));
};
}
@@ -1,23 +1,32 @@
import type { EvalDatasetRecordRefFile } from '@lobechat/types';
import { and, eq, inArray } from 'drizzle-orm';
import type {NewEvalDatasetRecordsItem } from '../../schemas';
import type { NewEvalDatasetRecordsItem } from '../../schemas';
import { evalDatasetRecords, files } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class EvalDatasetRecordModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, evalDatasetRecords);
private filesOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files);
create = async (params: NewEvalDatasetRecordsItem) => {
const [result] = await this.db
.insert(evalDatasetRecords)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return result;
};
@@ -25,7 +34,13 @@ export class EvalDatasetRecordModel {
batchCreate = async (params: NewEvalDatasetRecordsItem[]) => {
const [result] = await this.db
.insert(evalDatasetRecords)
.values(params.map((item) => ({ ...item, userId: this.userId })))
.values(
params.map((item) => ({
...item,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
)
.returning();
return result;
@@ -34,22 +49,19 @@ export class EvalDatasetRecordModel {
delete = async (id: string) => {
return this.db
.delete(evalDatasetRecords)
.where(and(eq(evalDatasetRecords.id, id), eq(evalDatasetRecords.userId, this.userId)));
.where(and(eq(evalDatasetRecords.id, id), this.ownership()));
};
query = async (datasetId: string) => {
const list = await this.db.query.evalDatasetRecords.findMany({
where: and(
eq(evalDatasetRecords.datasetId, datasetId),
eq(evalDatasetRecords.userId, this.userId),
),
where: and(eq(evalDatasetRecords.datasetId, datasetId), this.ownership()),
});
const fileList = list.flatMap((item) => item.referenceFiles).filter(Boolean) as string[];
const fileItems = await this.db
.select({ fileType: files.fileType, id: files.id, name: files.name })
.from(files)
.where(and(inArray(files.id, fileList), eq(files.userId, this.userId)));
.where(and(inArray(files.id, fileList), this.filesOwnership()));
return list.map((item) => {
return {
@@ -63,16 +75,13 @@ export class EvalDatasetRecordModel {
findByDatasetId = async (datasetId: string) => {
return this.db.query.evalDatasetRecords.findMany({
where: and(
eq(evalDatasetRecords.datasetId, datasetId),
eq(evalDatasetRecords.userId, this.userId),
),
where: and(eq(evalDatasetRecords.datasetId, datasetId), this.ownership()),
});
};
findById = async (id: string) => {
return this.db.query.evalDatasetRecords.findFirst({
where: and(eq(evalDatasetRecords.id, id), eq(evalDatasetRecords.userId, this.userId)),
where: and(eq(evalDatasetRecords.id, id), this.ownership()),
});
};
@@ -80,6 +89,6 @@ export class EvalDatasetRecordModel {
return this.db
.update(evalDatasetRecords)
.set(value)
.where(and(eq(evalDatasetRecords.id, id), eq(evalDatasetRecords.userId, this.userId)));
.where(and(eq(evalDatasetRecords.id, id), this.ownership()));
};
}
@@ -3,36 +3,35 @@ import { EvalEvaluationStatus } from '@lobechat/types';
import type { SQL } from 'drizzle-orm';
import { and, count, desc, eq, inArray } from 'drizzle-orm';
import type {
NewEvalEvaluationItem} from '../../schemas';
import {
evalDatasets,
evalEvaluation,
evaluationRecords
} from '../../schemas';
import type { NewEvalEvaluationItem } from '../../schemas';
import { evalDatasets, evalEvaluation, evaluationRecords } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class EvalEvaluationModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, evalEvaluation);
create = async (params: NewEvalEvaluationItem) => {
const [result] = await this.db
.insert(evalEvaluation)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return result;
};
delete = async (id: string) => {
return this.db
.delete(evalEvaluation)
.where(and(eq(evalEvaluation.id, id), eq(evalEvaluation.userId, this.userId)));
return this.db.delete(evalEvaluation).where(and(eq(evalEvaluation.id, id), this.ownership()));
};
queryByKnowledgeBaseId = async (knowledgeBaseId: string) => {
@@ -52,12 +51,7 @@ export class EvalEvaluationModel {
.from(evalEvaluation)
.leftJoin(evalDatasets, eq(evalDatasets.id, evalEvaluation.datasetId))
.orderBy(desc(evalEvaluation.createdAt))
.where(
and(
eq(evalEvaluation.userId, this.userId),
eq(evalEvaluation.knowledgeBaseId, knowledgeBaseId),
),
);
.where(and(this.ownership(), eq(evalEvaluation.knowledgeBaseId, knowledgeBaseId)));
// Then query record statistics for each evaluation
const evaluationIds = evaluations.map((evals) => evals.id);
@@ -88,7 +82,7 @@ export class EvalEvaluationModel {
findById = async (id: string) => {
return this.db.query.evalEvaluation.findFirst({
where: and(eq(evalEvaluation.id, id), eq(evalEvaluation.userId, this.userId)),
where: and(eq(evalEvaluation.id, id), this.ownership()),
});
};
@@ -96,6 +90,6 @@ export class EvalEvaluationModel {
return this.db
.update(evalEvaluation)
.set(value)
.where(and(eq(evalEvaluation.id, id), eq(evalEvaluation.userId, this.userId)));
.where(and(eq(evalEvaluation.id, id), this.ownership()));
};
}
@@ -1,22 +1,28 @@
import { and, eq } from 'drizzle-orm';
import type {NewEvaluationRecordsItem } from '../../schemas';
import type { NewEvaluationRecordsItem } from '../../schemas';
import { evaluationRecords } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export class EvaluationRecordModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, evaluationRecords);
create = async (params: NewEvaluationRecordsItem) => {
const [result] = await this.db
.insert(evaluationRecords)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
return result;
};
@@ -24,37 +30,37 @@ export class EvaluationRecordModel {
batchCreate = async (params: NewEvaluationRecordsItem[]) => {
return this.db
.insert(evaluationRecords)
.values(params.map((item) => ({ ...item, userId: this.userId })))
.values(
params.map((item) => ({
...item,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
)
.returning();
};
delete = async (id: string) => {
return this.db
.delete(evaluationRecords)
.where(and(eq(evaluationRecords.id, id), eq(evaluationRecords.userId, this.userId)));
.where(and(eq(evaluationRecords.id, id), this.ownership()));
};
query = async (reportId: string) => {
return this.db.query.evaluationRecords.findMany({
where: and(
eq(evaluationRecords.evaluationId, reportId),
eq(evaluationRecords.userId, this.userId),
),
where: and(eq(evaluationRecords.evaluationId, reportId), this.ownership()),
});
};
findById = async (id: string) => {
return this.db.query.evaluationRecords.findFirst({
where: and(eq(evaluationRecords.id, id), eq(evaluationRecords.userId, this.userId)),
where: and(eq(evaluationRecords.id, id), this.ownership()),
});
};
findByEvaluationId = async (evaluationId: string) => {
return this.db.query.evaluationRecords.findMany({
where: and(
eq(evaluationRecords.evaluationId, evaluationId),
eq(evaluationRecords.userId, this.userId),
),
where: and(eq(evaluationRecords.evaluationId, evaluationId), this.ownership()),
});
};
@@ -62,6 +68,6 @@ export class EvaluationRecordModel {
return this.db
.update(evaluationRecords)
.set(value)
.where(and(eq(evaluationRecords.id, id), eq(evaluationRecords.userId, this.userId)));
.where(and(eq(evaluationRecords.id, id), this.ownership()));
};
}
+132 -39
View File
@@ -1,8 +1,14 @@
import { and, eq, inArray, sql } from 'drizzle-orm';
import type { WorkspaceSystemRoleName } from '@lobechat/const/rbac';
import { and, eq, inArray, isNull, or, sql } from 'drizzle-orm';
import { LobeChatDatabase } from '@/database/type';
import type { LobeChatDatabase } from '@/database/type';
import { RoleItem, permissions, rolePermissions, roles, userRoles } from '../schemas/rbac';
import type { RoleItem } from '../schemas/rbac';
import { permissions, rolePermissions, roles, userRoles } from '../schemas/rbac';
import {
assignWorkspaceRoleToUser,
revokeWorkspaceRolesForUser,
} from '../utils/seedWorkspaceRoles';
export interface UserPermissionInfo {
category: string;
@@ -11,6 +17,48 @@ export interface UserPermissionInfo {
roleName: string;
}
/**
* Optional scope for a permission/role query.
*
* - `workspaceId: 'xxx'` match grants in that workspace plus globally-granted
* roles (`rbac_user_roles.workspace_id IS NULL`, e.g. `super_admin`). This is
* what tRPC `withRbacPermission` uses inside a workspace request.
* - `workspaceId` omitted match **any** grant, regardless of workspace. This
* preserves backward-compat with pre-workspace-scope callers (Hono routes
* that just check `agent:read:all` against the whole user, with workspace
* isolation enforced by the resource-level query elsewhere).
*
* Callers that want to assert "only globally-granted roles count" must do that
* filter themselves on the result set; we don't expose a third mode here
* because no production caller needs it today.
*/
export interface RbacScopeOptions {
userId?: string;
workspaceId?: string;
}
/**
* Build the `WHERE rbac_user_roles.workspace_id ...` predicate used by every
* permission/role lookup. Encodes the rule above in one place so the four
* query methods don't drift. Returns `undefined` when no workspace scope
* filter should be applied (legacy behavior).
*/
const buildScopeWhere = (workspaceId: string | undefined) =>
workspaceId
? or(eq(userRoles.workspaceId, workspaceId), isNull(userRoles.workspaceId))
: undefined;
/**
* Back-compat shim: existing call sites pass a bare `userId` string as the
* second arg. New call sites pass `{ userId?, workspaceId? }`. Normalise both
* forms into the option object.
*/
const normalizeScope = (arg: string | RbacScopeOptions | undefined): RbacScopeOptions => {
if (!arg) return {};
if (typeof arg === 'string') return { userId: arg };
return arg;
};
export class RbacModel {
private userId: string;
private db: LobeChatDatabase;
@@ -21,12 +69,14 @@ export class RbacModel {
}
/**
* Get all permissions for a specific user
* @param userId - User ID to query permissions for
* @returns Array of permission codes that the user has
* Get all permissions for a specific user. Accepts either a plain `userId`
* string (legacy global-scope check) or `{ userId?, workspaceId? }`
* (workspace-aware). Permission codes returned include the `:all`/`:owner`
* scope suffix as stored in `rbac_permissions.code`.
*/
getUserPermissions = async (userId?: string): Promise<string[]> => {
const targetUserId = userId || this.userId;
getUserPermissions = async (arg?: string | RbacScopeOptions): Promise<string[]> => {
const opts = normalizeScope(arg);
const targetUserId = opts.userId || this.userId;
const result = await this.db
.select({
@@ -41,21 +91,26 @@ export class RbacModel {
eq(userRoles.userId, targetUserId),
eq(roles.isActive, true),
eq(permissions.isActive, true),
buildScopeWhere(opts.workspaceId),
// Check if role assignment is not expired
sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`,
),
);
return result.map((row) => row.permissionCode);
// De-dupe — the same code can come from multiple roles (e.g. owner +
// member if a user somehow ends up with both).
return [...new Set(result.map((row) => row.permissionCode))];
};
/**
* Get detailed permission information for a user
* @param userId - User ID to query permissions for
* @returns Array of detailed permission information
* Get detailed permission information for a user. Same scope rules as
* `getUserPermissions`.
*/
getUserPermissionDetails = async (userId?: string): Promise<UserPermissionInfo[]> => {
const targetUserId = userId || this.userId;
getUserPermissionDetails = async (
arg?: string | RbacScopeOptions,
): Promise<UserPermissionInfo[]> => {
const opts = normalizeScope(arg);
const targetUserId = opts.userId || this.userId;
return await this.db
.select({
@@ -73,6 +128,7 @@ export class RbacModel {
eq(userRoles.userId, targetUserId),
eq(roles.isActive, true),
eq(permissions.isActive, true),
buildScopeWhere(opts.workspaceId),
// Check if role assignment is not expired
sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`,
),
@@ -81,13 +137,15 @@ export class RbacModel {
};
/**
* Check if user has a specific permission
* @param permissionCode - Permission code to check
* @param userId - User ID to check (optional, defaults to instance userId)
* @returns Boolean indicating if user has the permission
* Check if user has a specific permission. Pass `{ workspaceId }` to scope
* the check to a workspace (global grants still apply).
*/
hasPermission = async (permissionCode: string, userId?: string): Promise<boolean> => {
const targetUserId = userId || this.userId;
hasPermission = async (
permissionCode: string,
arg?: string | RbacScopeOptions,
): Promise<boolean> => {
const opts = normalizeScope(arg);
const targetUserId = opts.userId || this.userId;
const result = await this.db
.select({ count: sql<number>`count(*)` })
@@ -101,6 +159,7 @@ export class RbacModel {
inArray(permissions.code, [permissionCode]),
eq(roles.isActive, true),
eq(permissions.isActive, true),
buildScopeWhere(opts.workspaceId),
// Check if role assignment is not expired
sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`,
),
@@ -110,15 +169,16 @@ export class RbacModel {
};
/**
* Check if user has any of the specified permissions (OR logic)
* @param permissionCodes - Array of permission codes to check
* @param userId - User ID to check (optional, defaults to instance userId)
* @returns Boolean indicating if user has at least one of the permissions
* Check if user has any of the specified permissions (OR logic).
*/
hasAnyPermission = async (permissionCodes: string[], userId?: string): Promise<boolean> => {
hasAnyPermission = async (
permissionCodes: string[],
arg?: string | RbacScopeOptions,
): Promise<boolean> => {
if (permissionCodes.length === 0) return false;
const targetUserId = userId || this.userId;
const opts = normalizeScope(arg);
const targetUserId = opts.userId || this.userId;
const result = await this.db
.select({ count: sql<number>`count(*)` })
@@ -132,6 +192,7 @@ export class RbacModel {
inArray(permissions.code, permissionCodes),
eq(roles.isActive, true),
eq(permissions.isActive, true),
buildScopeWhere(opts.workspaceId),
// Check if role assignment is not expired
sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`,
),
@@ -141,27 +202,24 @@ export class RbacModel {
};
/**
* Check if user has all of the specified permissions (AND logic)
* @param permissionCodes - Array of permission codes to check
* @param userId - User ID to check (optional, defaults to instance userId)
* @returns Boolean indicating if user has all of the permissions
* Check if user has all of the specified permissions (AND logic).
*/
hasAllPermissions = async (permissionCodes: string[], userId?: string): Promise<boolean> => {
hasAllPermissions = async (
permissionCodes: string[],
arg?: string | RbacScopeOptions,
): Promise<boolean> => {
if (permissionCodes.length === 0) return true;
const checks = await Promise.all(
permissionCodes.map((code) => this.hasPermission(code, userId)),
);
const checks = await Promise.all(permissionCodes.map((code) => this.hasPermission(code, arg)));
return checks.every(Boolean);
};
/**
* Get user's active roles
* @param userId - User ID to query roles for
* @returns Array of role information
* Get user's active roles. Same scope rules as `hasPermission`.
*/
getUserRoles = async (userId?: string): Promise<RoleItem[]> => {
const targetUserId = userId || this.userId;
getUserRoles = async (arg?: string | RbacScopeOptions): Promise<RoleItem[]> => {
const opts = normalizeScope(arg);
const targetUserId = opts.userId || this.userId;
return await this.db
.select({
@@ -183,6 +241,7 @@ export class RbacModel {
and(
eq(userRoles.userId, targetUserId),
eq(roles.isActive, true),
buildScopeWhere(opts.workspaceId),
// Check if role assignment is not expired
sql`(${userRoles.expiresAt} IS NULL OR ${userRoles.expiresAt} > NOW())`,
),
@@ -190,6 +249,40 @@ export class RbacModel {
.orderBy(userRoles.createdAt);
};
/**
* List all roles defined inside a workspace (both built-in and custom).
* Used by the upcoming custom-role admin UI (LOBE-9193) and any client that
* wants to show available roles for a workspace.
*/
listWorkspaceRoles = async (workspaceId: string): Promise<RoleItem[]> => {
return this.db.query.roles.findMany({
orderBy: (table, { asc }) => [asc(table.isSystem), asc(table.name)],
where: and(eq(roles.workspaceId, workspaceId), eq(roles.isActive, true)),
});
};
/**
* Grant a built-in workspace role (`workspace_owner` | `workspace_member` |
* `workspace_viewer`) to a user inside a workspace. Delegates to the seed
* util so the onConflict + role-lookup logic lives in one place.
*/
assignWorkspaceRole = async (params: {
roleName: WorkspaceSystemRoleName;
userId: string;
workspaceId: string;
}): Promise<void> => {
await assignWorkspaceRoleToUser(this.db, params);
};
/**
* Revoke every workspace-scoped role this user holds in `workspaceId`.
* Idempotent. Used by member removal/leave flows and by `updateRole` before
* granting the new role.
*/
revokeWorkspaceRole = async (params: { userId: string; workspaceId: string }): Promise<void> => {
await revokeWorkspaceRolesForUser(this.db, params);
};
/**
* Update user roles using a transaction to ensure atomicity
* @param userId User ID
+15 -9
View File
@@ -4,6 +4,7 @@ import { unionAll } from 'drizzle-orm/pg-core';
import { agents, DOCUMENT_FOLDER_TYPE, documents, tasks, topics } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
export interface RecentDbItem {
id: string;
@@ -30,14 +31,24 @@ const TASK_FINAL_STATUSES = ['completed', 'canceled'];
export class RecentModel {
private userId: string;
private workspaceId?: string;
private db: LobeChatDatabase;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
queryRecent = async (limit: number = 10): Promise<RecentDbItem[]> => {
const scope = { userId: this.userId, workspaceId: this.workspaceId };
// `tasks` uses `createdByUserId` instead of `userId`, so apply the
// workspace-aware predicate inline.
const taskScopeWhere = this.workspaceId
? eq(tasks.workspaceId, this.workspaceId)
: and(eq(tasks.createdByUserId, this.userId), isNull(tasks.workspaceId));
const topicArm = this.db
.select({
id: topics.id,
@@ -53,7 +64,7 @@ export class RecentModel {
.leftJoin(agents, eq(topics.agentId, agents.id))
.where(
and(
eq(topics.userId, this.userId),
buildWorkspaceWhere(scope, topics),
or(
isNotNull(topics.groupId),
eq(agents.slug, 'inbox'),
@@ -80,7 +91,7 @@ export class RecentModel {
.from(documents)
.where(
and(
eq(documents.userId, this.userId),
buildWorkspaceWhere(scope, documents),
not(inArray(documents.sourceType, TOOL_DOCUMENT_SOURCE_TYPES)),
isNull(documents.knowledgeBaseId),
ne(documents.fileType, DOCUMENT_FOLDER_TYPE),
@@ -101,12 +112,7 @@ export class RecentModel {
updatedAt: tasks.updatedAt,
})
.from(tasks)
.where(
and(
eq(tasks.createdByUserId, this.userId),
not(inArray(tasks.status, TASK_FINAL_STATUSES)),
),
);
.where(and(taskScopeWhere, not(inArray(tasks.status, TASK_FINAL_STATUSES))));
const rows = await unionAll(topicArm, documentArm, taskArm)
.orderBy(desc(sql`updated_at`))
+96 -85
View File
@@ -16,15 +16,27 @@ import type { LobeChatDatabase } from '../type';
import { sanitizeBm25Query } from '../utils/bm25';
import { genEndDateWhere, genRangeWhere, genStartDateWhere, genWhere } from '../utils/genWhere';
import { idGenerator } from '../utils/idGenerator';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class SessionModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, sessions);
private agentsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agents);
private agentsToSessionsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agentsToSessions);
// **************** Query *************** //
query = async ({ current = 0, pageSize = 9999 } = {}) => {
@@ -44,7 +56,7 @@ export class SessionModel {
.leftJoin(agentsToSessions, eq(sessions.id, agentsToSessions.sessionId))
.leftJoin(agents, eq(agentsToSessions.agentId, agents.id))
.leftJoin(sessionGroups, eq(sessions.groupId, sessionGroups.id))
.where(and(eq(sessions.userId, this.userId), not(eq(sessions.slug, INBOX_SESSION_ID))))
.where(and(this.ownership(), not(eq(sessions.slug, INBOX_SESSION_ID))))
.orderBy(desc(sessions.updatedAt))
.limit(pageSize)
.offset(offset);
@@ -76,7 +88,7 @@ export class SessionModel {
const groups = await this.db.query.sessionGroups.findMany({
orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)],
where: eq(sessions.userId, this.userId),
where: and(this.ownership()),
});
const mappedSessions = result.map((item) => this.mapSessionItem(item as any));
@@ -108,12 +120,7 @@ export class SessionModel {
session: sessions,
})
.from(sessions)
.where(
and(
or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)),
eq(sessions.userId, this.userId),
),
)
.where(and(or(eq(sessions.id, idOrSlug), eq(sessions.slug, idOrSlug)), this.ownership()))
.leftJoin(agentsToSessions, eq(sessions.id, agentsToSessions.sessionId))
.leftJoin(agents, eq(agentsToSessions.agentId, agents.id))
.leftJoin(sessionGroups, eq(sessions.groupId, sessionGroups.id))
@@ -136,7 +143,7 @@ export class SessionModel {
.from(sessions)
.where(
genWhere([
eq(sessions.userId, this.userId),
this.ownership(),
params?.range
? genRangeWhere(params.range, sessions.createdAt, (date) => date.toDate())
: undefined,
@@ -156,7 +163,7 @@ export class SessionModel {
const result = await this.db
.select({ id: sessions.id })
.from(sessions)
.where(eq(sessions.userId, this.userId))
.where(and(this.ownership()))
.limit(n + 1);
return result.length > n;
@@ -184,7 +191,7 @@ export class SessionModel {
return this.db.transaction(async (trx) => {
if (slug) {
const existResult = await trx.query.sessions.findFirst({
where: and(eq(sessions.slug, slug), eq(sessions.userId, this.userId)),
where: and(eq(sessions.slug, slug), this.ownership()),
});
if (existResult) return existResult;
@@ -220,15 +227,19 @@ export class SessionModel {
if (type === 'group') {
const result = await trx
.insert(sessions)
.values({
...session,
createdAt: new Date(),
id,
slug,
type,
updatedAt: new Date(),
userId: this.userId,
})
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...session,
createdAt: new Date(),
id,
slug,
type,
updatedAt: new Date(),
},
),
)
.returning();
return result[0];
@@ -236,48 +247,57 @@ export class SessionModel {
const newAgents = await trx
.insert(agents)
.values({
avatar,
backgroundColor,
chatConfig: chatConfig || {},
createdAt: new Date(),
description,
editorData: editorData || null,
fewShots: examples || null, // Map examples to fewShots field
id: idGenerator('agents'),
marketIdentifier: identifier || marketIdentifier,
model: typeof model === 'string' ? model : null,
openingMessage,
openingQuestions,
params: params || {},
plugins,
provider,
systemRole,
tags,
title,
tts: tts || {},
updatedAt: new Date(),
userId: this.userId,
})
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
avatar,
backgroundColor,
chatConfig: chatConfig || {},
createdAt: new Date(),
description,
editorData: editorData || null,
fewShots: examples || null, // Map examples to fewShots field
id: idGenerator('agents'),
marketIdentifier: identifier || marketIdentifier,
model: typeof model === 'string' ? model : null,
openingMessage,
openingQuestions,
params: params || {},
plugins,
provider,
systemRole,
tags,
title,
tts: tts || {},
updatedAt: new Date(),
},
),
)
.returning();
const result = await trx
.insert(sessions)
.values({
...session,
createdAt: new Date(),
id,
slug,
type,
updatedAt: new Date(),
userId: this.userId,
})
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...session,
createdAt: new Date(),
id,
slug,
type,
updatedAt: new Date(),
},
),
)
.returning();
await trx.insert(agentsToSessions).values({
agentId: newAgents[0].id,
sessionId: id,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
});
return result[0];
@@ -286,7 +306,7 @@ export class SessionModel {
createInbox = async (defaultAgentConfig: PartialDeep<LobeAgentConfig>) => {
const item = await this.db.query.sessions.findFirst({
where: and(eq(sessions.userId, this.userId), eq(sessions.slug, INBOX_SESSION_ID)),
where: and(this.ownership(), eq(sessions.slug, INBOX_SESSION_ID)),
});
if (item) return;
@@ -299,13 +319,15 @@ export class SessionModel {
};
batchCreate = async (newSessions: NewSession[]) => {
const sessionsToInsert = newSessions.map((s) => {
return {
...s,
id: this.genId(),
userId: this.userId,
};
});
const sessionsToInsert = newSessions.map((s) =>
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...s,
id: this.genId(),
},
),
);
return this.db.insert(sessions).values(sessionsToInsert);
};
@@ -343,19 +365,17 @@ export class SessionModel {
const links = await trx
.select({ agentId: agentsToSessions.agentId })
.from(agentsToSessions)
.where(and(eq(agentsToSessions.sessionId, id), eq(agentsToSessions.userId, this.userId)));
.where(and(eq(agentsToSessions.sessionId, id), this.agentsToSessionsOwnership()));
const agentIds = links.map((link) => link.agentId);
// Delete links in agentsToSessions
await trx
.delete(agentsToSessions)
.where(and(eq(agentsToSessions.sessionId, id), eq(agentsToSessions.userId, this.userId)));
.where(and(eq(agentsToSessions.sessionId, id), this.agentsToSessionsOwnership()));
// Delete the session (this will cascade delete messages, topics, etc.)
const result = await trx
.delete(sessions)
.where(and(eq(sessions.id, id), eq(sessions.userId, this.userId)));
const result = await trx.delete(sessions).where(and(eq(sessions.id, id), this.ownership()));
// Delete orphaned agents
await this.clearOrphanAgent(agentIds, trx);
@@ -375,23 +395,19 @@ export class SessionModel {
const links = await trx
.select({ agentId: agentsToSessions.agentId })
.from(agentsToSessions)
.where(
and(inArray(agentsToSessions.sessionId, ids), eq(agentsToSessions.userId, this.userId)),
);
.where(and(inArray(agentsToSessions.sessionId, ids), this.agentsToSessionsOwnership()));
const agentIds = [...new Set(links.map((link) => link.agentId))];
// Delete links in agentsToSessions
await trx
.delete(agentsToSessions)
.where(
and(inArray(agentsToSessions.sessionId, ids), eq(agentsToSessions.userId, this.userId)),
);
.where(and(inArray(agentsToSessions.sessionId, ids), this.agentsToSessionsOwnership()));
// Delete the sessions
const result = await trx
.delete(sessions)
.where(and(inArray(sessions.id, ids), eq(sessions.userId, this.userId)));
.where(and(inArray(sessions.id, ids), this.ownership()));
// Delete orphaned agents
await this.clearOrphanAgent(agentIds, trx);
@@ -405,14 +421,9 @@ export class SessionModel {
*/
deleteAll = async () => {
return this.db.transaction(async (trx) => {
// Delete all agentsToSessions for this user
await trx.delete(agentsToSessions).where(eq(agentsToSessions.userId, this.userId));
// Delete all agents that were only used by this user's sessions
await trx.delete(agents).where(eq(agents.userId, this.userId));
// Delete all sessions for this user
return trx.delete(sessions).where(eq(sessions.userId, this.userId));
await trx.delete(agentsToSessions).where(this.agentsToSessionsOwnership());
await trx.delete(agents).where(this.agentsOwnership());
return trx.delete(sessions).where(this.ownership());
});
};
@@ -435,7 +446,7 @@ export class SessionModel {
if (orphanedAgentIds.length > 0) {
await trx
.delete(agents)
.where(and(inArray(agents.id, orphanedAgentIds), eq(agents.userId, this.userId)));
.where(and(inArray(agents.id, orphanedAgentIds), this.agentsOwnership()));
}
};
@@ -445,7 +456,7 @@ export class SessionModel {
return this.db
.update(sessions)
.set(data)
.where(and(eq(sessions.id, id), eq(sessions.userId, this.userId)))
.where(and(eq(sessions.id, id), this.ownership()))
.returning();
};
@@ -505,7 +516,7 @@ export class SessionModel {
return this.db
.update(agents)
.set(mergedValue)
.where(and(eq(agents.id, session.agent.id), eq(agents.userId, this.userId)));
.where(and(eq(agents.id, session.agent.id), this.agentsOwnership()));
};
// **************** Helper *************** //
@@ -598,7 +609,7 @@ export class SessionModel {
// Keep deterministic ordering for keyword search results
orderBy: [asc(agents.id)],
where: and(
eq(agents.userId, this.userId),
this.agentsOwnership(),
sql`(${agents.title} @@@ ${bm25Query} OR ${agents.description} @@@ ${bm25Query})`,
),
with: { agentsToSessions: { columns: {}, with: { session: true } } },
+19 -10
View File
@@ -4,45 +4,54 @@ import type { SessionGroupItem } from '../schemas';
import { sessionGroups } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { idGenerator } from '../utils/idGenerator';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
export class SessionGroupModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, sessionGroups);
create = async (params: { name: string; sort?: number }) => {
const [result] = await this.db
.insert(sessionGroups)
.values({ ...params, id: this.genId(), userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ ...params, id: this.genId() },
),
)
.returning();
return result;
};
delete = async (id: string) => {
return this.db
.delete(sessionGroups)
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
return this.db.delete(sessionGroups).where(and(eq(sessionGroups.id, id), this.ownership()));
};
deleteAll = async () => {
return this.db.delete(sessionGroups).where(eq(sessionGroups.userId, this.userId));
return this.db.delete(sessionGroups).where(this.ownership());
};
query = async () => {
return this.db.query.sessionGroups.findMany({
orderBy: [asc(sessionGroups.sort), desc(sessionGroups.createdAt)],
where: eq(sessionGroups.userId, this.userId),
where: this.ownership(),
});
};
findById = async (id: string) => {
return this.db.query.sessionGroups.findFirst({
where: and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)),
where: and(eq(sessionGroups.id, id), this.ownership()),
});
};
@@ -50,7 +59,7 @@ export class SessionGroupModel {
return this.db
.update(sessionGroups)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
.where(and(eq(sessionGroups.id, id), this.ownership()));
};
updateOrder = async (sortMap: { id: string; sort: number }[]) => {
@@ -59,7 +68,7 @@ export class SessionGroupModel {
return tx
.update(sessionGroups)
.set({ sort, updatedAt: new Date() })
.where(and(eq(sessionGroups.id, id), eq(sessionGroups.userId, this.userId)));
.where(and(eq(sessionGroups.id, id), this.ownership()));
});
await Promise.all(updates);
+338 -44
View File
@@ -7,6 +7,7 @@ import type {
WorkspaceTreeNode,
} from '@lobechat/types';
import { and, desc, eq, gte, inArray, isNotNull, isNull, ne, notInArray, sql } from 'drizzle-orm';
import type { AnyPgColumn } from 'drizzle-orm/pg-core';
import { merge } from '@/utils/merge';
@@ -14,16 +15,50 @@ import { documents } from '../schemas/file';
import type { NewTaskComment, TaskCommentItem } from '../schemas/task';
import { taskComments, taskDependencies, taskDocuments, tasks } from '../schemas/task';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
export class TaskModel {
private readonly userId: string;
private readonly db: LobeChatDatabase;
private readonly workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
/**
* Compat-mode ownership predicate for the `tasks` table.
* `tasks` uses `createdByUserId` instead of `userId`.
*/
private ownership = () =>
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
{ userId: tasks.createdByUserId, workspaceId: tasks.workspaceId },
);
/**
* Ownership predicate for task child tables (deps / docs / comments) that
* use a `userId` column instead of `createdByUserId`.
*/
private childOwnership = (cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }) =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, cols);
/**
* Raw-SQL ownership clause for use inside `db.execute(sql...)` CTEs that
* can't easily compose with drizzle's `and(...)` helpers. Mirrors
* `buildWorkspaceWhere` semantics:
* - workspace mode `workspace_id = $ws`
* - personal mode `created_by_user_id = $userId AND workspace_id IS NULL`
*/
private ownershipSql = (alias?: string) => {
const prefix = alias ? sql.raw(`${alias}.`) : sql.raw('');
return this.workspaceId
? sql`${prefix}workspace_id = ${this.workspaceId}`
: sql`${prefix}created_by_user_id = ${this.userId} AND ${prefix}workspace_id IS NULL`;
};
// ========== CRUD ==========
async create(
@@ -37,10 +72,13 @@ export class TaskModel {
const maxRetries = 5;
for (let attempt = 0; attempt < maxRetries; attempt++) {
try {
// Seq is allocated per ownership scope: workspace-wide in team mode,
// user-private in personal mode. This keeps `T-N` identifiers stable
// within the surface the user actually sees.
const seqResult = await this.db
.select({ maxSeq: sql<number>`COALESCE(MAX(${tasks.seq}), 0)` })
.from(tasks)
.where(eq(tasks.createdByUserId, this.userId));
.where(this.ownership());
const nextSeq = Number(seqResult[0].maxSeq) + 1;
const identifier = `${identifierPrefix}-${nextSeq}`;
@@ -52,6 +90,7 @@ export class TaskModel {
createdByUserId: this.userId,
identifier,
seq: nextSeq,
workspaceId: this.workspaceId ?? null,
} as NewTask)
.returning();
@@ -79,7 +118,7 @@ export class TaskModel {
const result = await this.db
.select()
.from(tasks)
.where(and(eq(tasks.id, id), eq(tasks.createdByUserId, this.userId)))
.where(and(eq(tasks.id, id), this.ownership()))
.limit(1);
return result[0] || null;
@@ -90,7 +129,7 @@ export class TaskModel {
return this.db
.select()
.from(tasks)
.where(and(inArray(tasks.id, ids), eq(tasks.createdByUserId, this.userId)));
.where(and(inArray(tasks.id, ids), this.ownership()));
}
// Resolve id or identifier (e.g. 'T-1') to a task
@@ -103,7 +142,7 @@ export class TaskModel {
const result = await this.db
.select()
.from(tasks)
.where(and(eq(tasks.identifier, identifier), eq(tasks.createdByUserId, this.userId)))
.where(and(eq(tasks.identifier, identifier), this.ownership()))
.limit(1);
return result[0] || null;
@@ -118,7 +157,7 @@ export class TaskModel {
const updated = await this.db
.update(tasks)
.set({ ...data, updatedAt: new Date() })
.where(and(eq(tasks.id, id), eq(tasks.createdByUserId, this.userId)))
.where(and(eq(tasks.id, id), this.ownership()))
.returning();
return updated[0] || null;
}
@@ -126,17 +165,14 @@ export class TaskModel {
async delete(id: string): Promise<boolean> {
const result = await this.db
.delete(tasks)
.where(and(eq(tasks.id, id), eq(tasks.createdByUserId, this.userId)))
.where(and(eq(tasks.id, id), this.ownership()))
.returning();
return result.length > 0;
}
async deleteAll(): Promise<number> {
const result = await this.db
.delete(tasks)
.where(eq(tasks.createdByUserId, this.userId))
.returning();
const result = await this.db.delete(tasks).where(this.ownership()).returning();
return result.length;
}
@@ -164,7 +200,7 @@ export class TaskModel {
> {
const { groups, assigneeAgentId, parentTaskId } = options;
const baseConditions = [eq(tasks.createdByUserId, this.userId)];
const baseConditions = [this.ownership()];
if (assigneeAgentId) baseConditions.push(eq(tasks.assigneeAgentId, assigneeAgentId));
if (parentTaskId === null) {
baseConditions.push(isNull(tasks.parentTaskId));
@@ -232,7 +268,7 @@ export class TaskModel {
offset = 0,
} = options || {};
const conditions = [eq(tasks.createdByUserId, this.userId)];
const conditions = [this.ownership()];
if (statuses?.length) conditions.push(inArray(tasks.status, statuses));
if (priorities?.length) conditions.push(inArray(tasks.priority, priorities));
@@ -271,7 +307,7 @@ export class TaskModel {
await this.db
.update(tasks)
.set({ sortOrder: item.sortOrder, updatedAt: new Date() })
.where(and(eq(tasks.id, item.id), eq(tasks.createdByUserId, this.userId)));
.where(and(eq(tasks.id, item.id), this.ownership()));
}
}
@@ -279,7 +315,7 @@ export class TaskModel {
return this.db
.select()
.from(tasks)
.where(and(eq(tasks.parentTaskId, parentTaskId), eq(tasks.createdByUserId, this.userId)))
.where(and(eq(tasks.parentTaskId, parentTaskId), this.ownership()))
.orderBy(tasks.sortOrder, tasks.seq);
}
@@ -295,7 +331,7 @@ export class TaskModel {
const children = await this.db
.select()
.from(tasks)
.where(and(inArray(tasks.parentTaskId, parentIds), eq(tasks.createdByUserId, this.userId)))
.where(and(inArray(tasks.parentTaskId, parentIds), this.ownership()))
.orderBy(tasks.sortOrder, tasks.seq);
if (children.length === 0) break;
@@ -309,9 +345,10 @@ export class TaskModel {
// Recursive query to get full task tree
async getTaskTree(rootTaskId: string): Promise<TaskItem[]> {
const ownership = this.ownershipSql();
const result = await this.db.execute(sql`
WITH RECURSIVE task_tree AS (
SELECT * FROM tasks WHERE id = ${rootTaskId} AND created_by_user_id = ${this.userId}
SELECT * FROM tasks WHERE id = ${rootTaskId} AND ${ownership}
UNION ALL
SELECT t.* FROM tasks t
JOIN task_tree tt ON t.parent_task_id = tt.id
@@ -333,18 +370,20 @@ export class TaskModel {
const taskIdParams = taskIds.map((id) => sql`${id}`);
const taskIdList = sql.join(taskIdParams, sql`, `);
const ownershipBare = this.ownershipSql();
const ownershipAliased = this.ownershipSql('t');
const result = await this.db.execute(sql`
WITH RECURSIVE
ancestors AS (
SELECT id AS origin_id, id, parent_task_id
FROM tasks
WHERE id IN (${taskIdList})
AND created_by_user_id = ${this.userId}
AND ${ownershipBare}
UNION ALL
SELECT a.origin_id, t.id, t.parent_task_id
FROM tasks t
JOIN ancestors a ON t.id = a.parent_task_id
WHERE t.created_by_user_id = ${this.userId}
WHERE ${ownershipAliased}
),
roots AS (
SELECT DISTINCT ON (origin_id) origin_id, id AS root_id
@@ -355,12 +394,12 @@ export class TaskModel {
SELECT r.origin_id, t.id, t.assignee_agent_id, t.created_by_agent_id
FROM tasks t
JOIN roots r ON t.id = r.root_id
WHERE t.created_by_user_id = ${this.userId}
WHERE ${ownershipAliased}
UNION ALL
SELECT d.origin_id, t.id, t.assignee_agent_id, t.created_by_agent_id
FROM tasks t
JOIN descendants d ON t.parent_task_id = d.id
WHERE t.created_by_user_id = ${this.userId}
WHERE ${ownershipAliased}
)
SELECT origin_id, assignee_agent_id, created_by_agent_id
FROM descendants
@@ -392,7 +431,7 @@ export class TaskModel {
const result = await this.db
.update(tasks)
.set({ status, updatedAt: new Date() })
.where(and(inArray(tasks.id, ids), eq(tasks.createdByUserId, this.userId)))
.where(and(inArray(tasks.id, ids), this.ownership()))
.returning();
return result.length;
@@ -476,7 +515,7 @@ export class TaskModel {
await this.db
.update(tasks)
.set({ lastHeartbeatAt: new Date(), updatedAt: new Date() })
.where(eq(tasks.id, id));
.where(and(eq(tasks.id, id), this.ownership()));
}
// Tasks eligible for cron-based dispatch.
@@ -513,10 +552,22 @@ export class TaskModel {
// ========== Dependencies ==========
private depsOwnership = () =>
this.childOwnership({
userId: taskDependencies.userId,
workspaceId: taskDependencies.workspaceId,
});
async addDependency(taskId: string, dependsOnId: string, type: string = 'blocks'): Promise<void> {
await this.db
.insert(taskDependencies)
.values({ dependsOnId, taskId, type, userId: this.userId })
.values({
dependsOnId,
taskId,
type,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})
.onConflictDoNothing();
}
@@ -524,21 +575,34 @@ export class TaskModel {
await this.db
.delete(taskDependencies)
.where(
and(eq(taskDependencies.taskId, taskId), eq(taskDependencies.dependsOnId, dependsOnId)),
and(
eq(taskDependencies.taskId, taskId),
eq(taskDependencies.dependsOnId, dependsOnId),
this.depsOwnership(),
),
);
}
async getDependencies(taskId: string) {
return this.db.select().from(taskDependencies).where(eq(taskDependencies.taskId, taskId));
return this.db
.select()
.from(taskDependencies)
.where(and(eq(taskDependencies.taskId, taskId), this.depsOwnership()));
}
async getDependenciesByTaskIds(taskIds: string[]) {
if (taskIds.length === 0) return [];
return this.db.select().from(taskDependencies).where(inArray(taskDependencies.taskId, taskIds));
return this.db
.select()
.from(taskDependencies)
.where(and(inArray(taskDependencies.taskId, taskIds), this.depsOwnership()));
}
async getDependents(taskId: string) {
return this.db.select().from(taskDependencies).where(eq(taskDependencies.dependsOnId, taskId));
return this.db
.select()
.from(taskDependencies)
.where(and(eq(taskDependencies.dependsOnId, taskId), this.depsOwnership()));
}
// Check if all dependencies of a task are completed
@@ -552,6 +616,7 @@ export class TaskModel {
eq(taskDependencies.taskId, taskId),
eq(taskDependencies.type, 'blocks'),
ne(tasks.status, 'completed'),
this.depsOwnership(),
),
);
@@ -587,11 +652,7 @@ export class TaskModel {
.select({ count: sql<number>`count(*)` })
.from(tasks)
.where(
and(
eq(tasks.parentTaskId, parentTaskId),
ne(tasks.status, 'completed'),
eq(tasks.createdByUserId, this.userId),
),
and(eq(tasks.parentTaskId, parentTaskId), ne(tasks.status, 'completed'), this.ownership()),
);
return Number(result[0].count) === 0;
@@ -599,24 +660,42 @@ export class TaskModel {
// ========== Documents (MVP Workspace) ==========
private docsOwnership = () =>
this.childOwnership({
userId: taskDocuments.userId,
workspaceId: taskDocuments.workspaceId,
});
async pinDocument(taskId: string, documentId: string, pinnedBy: string = 'agent'): Promise<void> {
await this.db
.insert(taskDocuments)
.values({ documentId, pinnedBy, taskId, userId: this.userId })
.values({
documentId,
pinnedBy,
taskId,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})
.onConflictDoNothing();
}
async unpinDocument(taskId: string, documentId: string): Promise<void> {
await this.db
.delete(taskDocuments)
.where(and(eq(taskDocuments.taskId, taskId), eq(taskDocuments.documentId, documentId)));
.where(
and(
eq(taskDocuments.taskId, taskId),
eq(taskDocuments.documentId, documentId),
this.docsOwnership(),
),
);
}
async getPinnedDocuments(taskId: string) {
return this.db
.select()
.from(taskDocuments)
.where(eq(taskDocuments.taskId, taskId))
.where(and(eq(taskDocuments.taskId, taskId), this.docsOwnership()))
.orderBy(taskDocuments.createdAt);
}
@@ -642,7 +721,7 @@ export class TaskModel {
.where(
and(
eq(taskDocuments.taskId, taskId),
eq(taskDocuments.userId, this.userId),
this.docsOwnership(),
gte(taskDocuments.createdAt, since),
),
);
@@ -656,12 +735,18 @@ export class TaskModel {
// Get all pinned docs from a task tree (recursive), returns nodeMap + tree structure
async getTreePinnedDocuments(rootTaskId: string): Promise<WorkspaceData> {
const rootOwnership = this.ownershipSql();
const recursiveOwnership = this.ownershipSql('t');
const docsOwnership = this.workspaceId
? sql`td.workspace_id = ${this.workspaceId}`
: sql`td.user_id = ${this.userId} AND td.workspace_id IS NULL`;
const result = await this.db.execute(sql`
WITH RECURSIVE task_tree AS (
SELECT id, identifier FROM tasks WHERE id = ${rootTaskId}
SELECT id, identifier FROM tasks WHERE id = ${rootTaskId} AND ${rootOwnership}
UNION ALL
SELECT t.id, t.identifier FROM tasks t
JOIN task_tree tt ON t.parent_task_id = tt.id
WHERE ${recursiveOwnership}
)
SELECT td.*, tt.id as source_task_id, tt.identifier as source_task_identifier,
d.title as document_title, d.file_type as document_file_type, d.parent_id as document_parent_id,
@@ -669,6 +754,7 @@ export class TaskModel {
FROM task_documents td
JOIN task_tree tt ON td.task_id = tt.id
LEFT JOIN documents d ON td.document_id = d.id
WHERE ${docsOwnership}
ORDER BY td.created_at
`);
@@ -725,20 +811,29 @@ export class TaskModel {
totalTopics: sql`${tasks.totalTopics} + 1`,
updatedAt: new Date(),
})
.where(eq(tasks.id, id));
.where(and(eq(tasks.id, id), this.ownership()));
}
async updateCurrentTopic(id: string, topicId: string): Promise<void> {
await this.db
.update(tasks)
.set({ currentTopicId: topicId, updatedAt: new Date() })
.where(eq(tasks.id, id));
.where(and(eq(tasks.id, id), this.ownership()));
}
// ========== Comments ==========
private commentsOwnership = () =>
this.childOwnership({
userId: taskComments.userId,
workspaceId: taskComments.workspaceId,
});
async addComment(data: Omit<NewTaskComment, 'id'>): Promise<TaskCommentItem> {
const [comment] = await this.db.insert(taskComments).values(data).returning();
const [comment] = await this.db
.insert(taskComments)
.values({ ...data, workspaceId: this.workspaceId ?? null })
.returning();
return comment;
}
@@ -746,14 +841,14 @@ export class TaskModel {
return this.db
.select()
.from(taskComments)
.where(eq(taskComments.taskId, taskId))
.where(and(eq(taskComments.taskId, taskId), this.commentsOwnership()))
.orderBy(taskComments.createdAt);
}
async deleteComment(id: string): Promise<boolean> {
const result = await this.db
.delete(taskComments)
.where(and(eq(taskComments.id, id), eq(taskComments.userId, this.userId)))
.where(and(eq(taskComments.id, id), this.commentsOwnership()))
.returning();
return result.length > 0;
}
@@ -770,8 +865,207 @@ export class TaskModel {
...(opts?.editorData !== undefined ? { editorData: opts.editorData as never } : {}),
updatedAt: new Date(),
})
.where(and(eq(taskComments.id, id), eq(taskComments.userId, this.userId)))
.where(and(eq(taskComments.id, id), this.commentsOwnership()))
.returning();
return comment;
}
// ========== Transfer / Copy ==========
/**
* Collect a task and all its descendants (parentTaskId-linked) via BFS.
* Honors the current ownership scope.
*/
private async collectTaskSubtree(rootId: string, runner: LobeChatDatabase): Promise<TaskItem[]> {
const [root] = await runner
.select()
.from(tasks)
.where(and(eq(tasks.id, rootId), this.ownership()))
.limit(1);
if (!root) return [];
const collected: TaskItem[] = [root];
let frontier: string[] = [root.id];
while (frontier.length > 0) {
const children = await runner
.select()
.from(tasks)
.where(and(inArray(tasks.parentTaskId, frontier), this.ownership()));
if (children.length === 0) break;
collected.push(...children);
frontier = children.map((c) => c.id);
}
return collected;
}
/**
* Allocate a contiguous block of seq numbers + identifiers in the target
* scope. Returns the next available seq baseline.
*/
private async nextSeqIn(
runner: LobeChatDatabase,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<number> {
const where = targetWorkspaceId
? eq(tasks.workspaceId, targetWorkspaceId)
: and(eq(tasks.createdByUserId, targetUserId), isNull(tasks.workspaceId));
const [{ maxSeq }] = await runner
.select({ maxSeq: sql<number>`COALESCE(MAX(${tasks.seq}), 0)` })
.from(tasks)
.where(where!);
return Number(maxSeq) + 1;
}
/**
* Transfer a task subtree to another workspace / personal scope. Reallocates
* `identifier`/`seq` in the target scope and rewrites every dependent child
* table (`task_dependencies`, `task_documents`, `task_topics`,
* `task_comments`, `briefs`) so the ownership predicates remain consistent.
*
* Cross-scope references that may no longer be valid are cleared:
* - `assigneeAgentId` (workspace move: agent likely doesn't exist there)
* - `currentTopicId` (topic ownership is also moving but the link is
* reset to avoid surfacing a stale active topic in the new scope)
*/
async transferTo(
taskId: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ taskIds: string[] }> {
return this.db.transaction(async (trx) => {
const scoped = new TaskModel(trx as LobeChatDatabase, this.userId, this.workspaceId);
const subtree = await scoped.collectTaskSubtree(taskId, trx as LobeChatDatabase);
if (subtree.length === 0) throw new Error('Task not found');
const ids = subtree.map((t) => t.id);
// Reallocate identifier + seq in target scope to avoid collisions.
const baseSeq = await this.nextSeqIn(
trx as LobeChatDatabase,
targetWorkspaceId,
targetUserId,
);
// Update each task individually because identifier/seq are per-row.
for (const [idx, task] of subtree.entries()) {
const seq = baseSeq + idx;
const identifier = `T-${seq}`;
await (trx as LobeChatDatabase)
.update(tasks)
.set({
// Clear cross-scope refs: agent / topic may be invalid in new scope.
assigneeAgentId: targetWorkspaceId === this.workspaceId ? task.assigneeAgentId : null,
createdByUserId: targetUserId,
currentTopicId: null,
identifier,
seq,
updatedAt: new Date(),
workspaceId: targetWorkspaceId,
})
.where(eq(tasks.id, task.id));
}
// Update child tables that key off taskId.
const ownershipUpdate = { userId: targetUserId, workspaceId: targetWorkspaceId };
await (trx as LobeChatDatabase)
.update(taskDependencies)
.set(ownershipUpdate)
.where(inArray(taskDependencies.taskId, ids));
await (trx as LobeChatDatabase)
.update(taskDocuments)
.set(ownershipUpdate)
.where(inArray(taskDocuments.taskId, ids));
await (trx as LobeChatDatabase)
.update(taskComments)
.set(ownershipUpdate)
.where(inArray(taskComments.taskId, ids));
return { taskIds: ids };
});
}
/**
* Deep clone a task subtree into another workspace / personal scope. Fresh
* ids, fresh identifiers, preserved parent/child topology. Cross-scope refs
* (agent / topic / brief / current topic) are cleared on the clones so the
* copies start clean in the new scope.
*/
async copyToWorkspace(
taskId: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ rootId: string }> {
return this.db.transaction(async (trx) => {
const scoped = new TaskModel(trx as LobeChatDatabase, this.userId, this.workspaceId);
const subtree = await scoped.collectTaskSubtree(taskId, trx as LobeChatDatabase);
if (subtree.length === 0) throw new Error('Task not found');
// BFS clone — parent inserted before children, so we always know the
// new parentTaskId by the time we reach the child.
const idMap = new Map<string, string>();
const byId = new Map(subtree.map((t) => [t.id, t]));
const queue: string[] = [taskId];
const seen = new Set<string>();
let seq = await this.nextSeqIn(trx as LobeChatDatabase, targetWorkspaceId, targetUserId);
while (queue.length > 0) {
const currentId = queue.shift()!;
if (seen.has(currentId)) continue;
seen.add(currentId);
const original = byId.get(currentId);
if (!original) continue;
const newParentId =
currentId === taskId ? null : (idMap.get(original.parentTaskId!) ?? null);
const identifier = `T-${seq}`;
const inserted = (await (trx as LobeChatDatabase)
.insert(tasks)
.values({
assigneeAgentId: null,
assigneeUserId: null,
automationMode: original.automationMode,
config: original.config ?? {},
context: {
...(original.context as Record<string, unknown>),
duplicatedFrom: original.id,
},
createdByAgentId: null,
createdByUserId: targetUserId,
currentTopicId: null,
description: original.description,
error: null,
heartbeatInterval: original.heartbeatInterval,
heartbeatTimeout: original.heartbeatTimeout,
identifier,
instruction: original.instruction,
maxTopics: original.maxTopics,
name: original.name,
parentTaskId: newParentId,
priority: original.priority,
schedulePattern: original.schedulePattern,
scheduleTimezone: original.scheduleTimezone,
seq,
sortOrder: original.sortOrder,
// Reset lifecycle: copy starts fresh, not mid-run.
status: 'backlog',
totalTopics: 0,
workspaceId: targetWorkspaceId,
} as NewTask)
.returning({ id: tasks.id })) as { id: string }[];
idMap.set(original.id, inserted[0]!.id);
seq++;
for (const c of subtree) {
if (c.parentTaskId === original.id) queue.push(c.id);
}
}
return { rootId: idMap.get(taskId)! };
});
}
}
+27 -57
View File
@@ -5,18 +5,24 @@ import type { TaskTopicItem } from '../schemas/task';
import { tasks, taskTopics } from '../schemas/task';
import { topics } from '../schemas/topic';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
const TERMINAL_TOPIC_STATUSES = new Set(['canceled', 'completed', 'failed', 'timeout']);
export class TaskTopicModel {
private readonly userId: string;
private readonly db: LobeChatDatabase;
private readonly workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, taskTopics);
/**
* Mirror a terminal taskTopic transition onto the underlying topic record:
* stamp `topics.completedAt` so duration can be computed at read time, and
@@ -29,7 +35,12 @@ export class TaskTopicModel {
await this.db
.update(topics)
.set(setClause)
.where(and(eq(topics.id, topicId), eq(topics.userId, this.userId)));
.where(
and(
eq(topics.id, topicId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics),
),
);
}
async add(
@@ -45,6 +56,7 @@ export class TaskTopicModel {
taskId,
topicId,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})
.onConflictDoNothing();
}
@@ -53,13 +65,7 @@ export class TaskTopicModel {
await this.db
.update(taskTopics)
.set({ status })
.where(
and(
eq(taskTopics.taskId, taskId),
eq(taskTopics.topicId, topicId),
eq(taskTopics.userId, this.userId),
),
);
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership()));
if (TERMINAL_TOPIC_STATUSES.has(status)) {
await this.markTopicEnded(topicId, status);
@@ -79,7 +85,7 @@ export class TaskTopicModel {
eq(taskTopics.taskId, taskId),
eq(taskTopics.topicId, topicId),
eq(taskTopics.status, 'running'),
eq(taskTopics.userId, this.userId),
this.ownership(),
),
)
.returning();
@@ -93,26 +99,14 @@ export class TaskTopicModel {
await this.db
.update(taskTopics)
.set({ operationId })
.where(
and(
eq(taskTopics.taskId, taskId),
eq(taskTopics.topicId, topicId),
eq(taskTopics.userId, this.userId),
),
);
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership()));
}
async updateHandoff(taskId: string, topicId: string, handoff: TaskTopicHandoff): Promise<void> {
await this.db
.update(taskTopics)
.set({ handoff })
.where(
and(
eq(taskTopics.taskId, taskId),
eq(taskTopics.topicId, topicId),
eq(taskTopics.userId, this.userId),
),
);
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership()));
}
/**
@@ -131,13 +125,7 @@ export class TaskTopicModel {
.set({
handoff: sql`jsonb_set(COALESCE(${taskTopics.handoff}, '{}'::jsonb), '{briefDecision}', ${JSON.stringify(decision)}::jsonb)`,
})
.where(
and(
eq(taskTopics.taskId, taskId),
eq(taskTopics.topicId, topicId),
eq(taskTopics.userId, this.userId),
),
);
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership()));
}
async updateReview(
@@ -159,26 +147,14 @@ export class TaskTopicModel {
reviewScores: review.scores,
reviewedAt: new Date(),
})
.where(
and(
eq(taskTopics.taskId, taskId),
eq(taskTopics.topicId, topicId),
eq(taskTopics.userId, this.userId),
),
);
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership()));
}
async timeoutRunning(taskId: string): Promise<number> {
const result = await this.db
.update(taskTopics)
.set({ status: 'timeout' })
.where(
and(
eq(taskTopics.taskId, taskId),
eq(taskTopics.status, 'running'),
eq(taskTopics.userId, this.userId),
),
)
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.status, 'running'), this.ownership()))
.returning({ topicId: taskTopics.topicId });
await Promise.all(
@@ -195,13 +171,13 @@ export class TaskTopicModel {
const result = await this.db
.select()
.from(taskTopics)
.where(and(eq(taskTopics.topicId, topicId), eq(taskTopics.userId, this.userId)))
.where(and(eq(taskTopics.topicId, topicId), this.ownership()))
.limit(1);
return result[0] || null;
}
async countByTask(taskId: string, options?: { since?: Date }): Promise<number> {
const conditions = [eq(taskTopics.taskId, taskId), eq(taskTopics.userId, this.userId)];
const conditions = [eq(taskTopics.taskId, taskId), this.ownership()];
if (options?.since) conditions.push(gte(taskTopics.createdAt, options.since));
const rows = await this.db
@@ -215,7 +191,7 @@ export class TaskTopicModel {
return this.db
.select()
.from(taskTopics)
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.userId, this.userId)))
.where(and(eq(taskTopics.taskId, taskId), this.ownership()))
.orderBy(desc(taskTopics.seq));
}
@@ -239,7 +215,7 @@ export class TaskTopicModel {
})
.from(taskTopics)
.innerJoin(topics, eq(taskTopics.topicId, topics.id))
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.userId, this.userId)))
.where(and(eq(taskTopics.taskId, taskId), this.ownership()))
.orderBy(desc(taskTopics.seq));
}
@@ -258,7 +234,7 @@ export class TaskTopicModel {
})
.from(taskTopics)
.leftJoin(topics, eq(taskTopics.topicId, topics.id))
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.userId, this.userId)))
.where(and(eq(taskTopics.taskId, taskId), this.ownership()))
.orderBy(desc(taskTopics.seq))
.limit(limit);
}
@@ -266,13 +242,7 @@ export class TaskTopicModel {
async remove(taskId: string, topicId: string): Promise<boolean> {
const result = await this.db
.delete(taskTopics)
.where(
and(
eq(taskTopics.taskId, taskId),
eq(taskTopics.topicId, topicId),
eq(taskTopics.userId, this.userId),
),
)
.where(and(eq(taskTopics.taskId, taskId), eq(taskTopics.topicId, topicId), this.ownership()))
.returning();
if (result.length > 0) {
+19 -8
View File
@@ -5,6 +5,7 @@ import { and, desc, eq, sql } from 'drizzle-orm';
import type { ThreadItem } from '../schemas';
import { messages, threads } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
/**
* Per-thread subagent metrics, derived from the child messages at read time
@@ -67,17 +68,27 @@ const queryColumns = {
export class ThreadModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, threads);
create = async (params: CreateThreadParams) => {
// @ts-ignore
const [result] = await this.db
.insert(threads)
.values({ status: ThreadStatus.Active, ...params, userId: this.userId })
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{ status: ThreadStatus.Active, ...params },
),
)
.onConflictDoNothing()
.returning();
@@ -85,18 +96,18 @@ export class ThreadModel {
};
delete = async (id: string) => {
return this.db.delete(threads).where(and(eq(threads.id, id), eq(threads.userId, this.userId)));
return this.db.delete(threads).where(and(eq(threads.id, id), this.ownership()));
};
deleteAll = async () => {
return this.db.delete(threads).where(eq(threads.userId, this.userId));
return this.db.delete(threads).where(this.ownership());
};
query = async () => {
const data = await this.db
.select(queryColumns)
.from(threads)
.where(eq(threads.userId, this.userId))
.where(this.ownership())
.orderBy(desc(threads.updatedAt));
return data as ThreadItem[];
@@ -110,7 +121,7 @@ export class ThreadModel {
.select({ ...queryColumns, ...subagentMetricColumns })
.from(threads)
.leftJoin(messages, eq(messages.threadId, threads.id))
.where(and(eq(threads.topicId, topicId), eq(threads.userId, this.userId)))
.where(and(eq(threads.topicId, topicId), this.ownership()))
.groupBy(threads.id)
.orderBy(desc(threads.updatedAt));
@@ -119,7 +130,7 @@ export class ThreadModel {
findById = async (id: string) => {
return this.db.query.threads.findFirst({
where: and(eq(threads.id, id), eq(threads.userId, this.userId)),
where: and(eq(threads.id, id), this.ownership()),
});
};
@@ -127,6 +138,6 @@ export class ThreadModel {
return this.db
.update(threads)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(threads.id, id), eq(threads.userId, this.userId)));
.where(and(eq(threads.id, id), this.ownership()));
};
}
+73 -66
View File
@@ -35,6 +35,7 @@ import type { LobeChatDatabase } from '../type';
import { sanitizeBm25Query } from '../utils/bm25';
import { genEndDateWhere, genRangeWhere, genStartDateWhere, genWhere } from '../utils/genWhere';
import { idGenerator } from '../utils/idGenerator';
import { buildWorkspacePayload, buildWorkspaceWhere } from '../utils/workspace';
import { recomputeTopicUsage } from './topicUsage';
type OnboardingSessionMetadataPatch = Partial<NonNullable<ChatTopicMetadata['onboardingSession']>>;
@@ -141,11 +142,18 @@ const buildTopicOrderBy = (sortBy?: TopicQuerySortBy): SQL[] =>
export class TopicModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics);
private messageOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages);
// **************** Query *************** //
query = async ({
@@ -234,7 +242,7 @@ export class TopicModel {
// If groupId is provided, query topics by groupId directly
if (groupId) {
const whereCondition = and(
eq(topics.userId, this.userId),
this.ownership(),
eq(topics.groupId, groupId),
includeTriggerCondition,
excludeTriggerCondition,
@@ -299,7 +307,7 @@ export class TopicModel {
: eq(topics.agentId, agentId);
const agentWhere = and(
eq(topics.userId, this.userId),
this.ownership(),
agentCondition,
includeTriggerCondition,
excludeTriggerCondition,
@@ -355,7 +363,7 @@ export class TopicModel {
// Fallback to containerId-based query (backward compatibility)
const whereCondition = and(
eq(topics.userId, this.userId),
this.ownership(),
this.matchContainer(containerId),
includeTriggerCondition,
excludeTriggerCondition,
@@ -414,16 +422,12 @@ export class TopicModel {
findById = async (id: string) => {
return this.db.query.topics.findFirst({
where: and(eq(topics.id, id), eq(topics.userId, this.userId)),
where: and(eq(topics.id, id), this.ownership()),
});
};
queryAll = async (): Promise<TopicItem[]> => {
return this.db
.select()
.from(topics)
.orderBy(topics.updatedAt)
.where(eq(topics.userId, this.userId));
return this.db.select().from(topics).orderBy(topics.updatedAt).where(and(this.ownership()));
};
queryByKeyword = async (keyword: string, containerId?: string | null): Promise<TopicItem[]> => {
@@ -439,7 +443,7 @@ export class TopicModel {
.from(topics)
.where(
and(
eq(topics.userId, this.userId),
this.ownership(),
this.matchContainer(containerId),
sql`${topics.title} @@@ ${bm25Query}`,
),
@@ -452,9 +456,9 @@ export class TopicModel {
.innerJoin(topics, eq(messages.topicId, topics.id))
.where(
and(
eq(messages.userId, this.userId),
this.messageOwnership(),
sql`${messages.content} @@@ ${bm25Query}`,
eq(topics.userId, this.userId),
this.ownership(),
this.matchContainer(containerId),
),
)
@@ -472,7 +476,7 @@ export class TopicModel {
const topicsByMessages = await this.db.query.topics.findMany({
orderBy: [desc(topics.updatedAt)],
where: and(eq(topics.userId, this.userId), inArray(topics.id, topicIds)),
where: and(this.ownership(), inArray(topics.id, topicIds)),
});
// Merge results and deduplicate
@@ -509,7 +513,7 @@ export class TopicModel {
.from(topics)
.where(
genWhere([
eq(topics.userId, this.userId),
this.ownership(),
agentCondition,
params?.containerId ? this.matchContainer(params.containerId) : undefined,
params?.range
@@ -536,7 +540,7 @@ export class TopicModel {
title: topics.title,
})
.from(topics)
.where(and(eq(topics.userId, this.userId)))
.where(and(this.ownership()))
.leftJoin(messages, eq(topics.id, messages.topicId))
.groupBy(topics.id)
.orderBy(desc(sql`count`))
@@ -565,7 +569,7 @@ export class TopicModel {
.leftJoin(agents, eq(topics.agentId, agents.id))
.where(
and(
eq(topics.userId, this.userId),
this.ownership(),
or(
// Group topics: has groupId
not(isNull(topics.groupId)),
@@ -592,14 +596,16 @@ export class TopicModel {
id: string = this.genId(),
timing?: ModelTimingContext,
): Promise<TopicItem> => {
const insertData = {
...params,
agentId: params.agentId || null,
groupId: params.groupId || null,
id,
sessionId: params.sessionId || null,
userId: this.userId,
};
const insertData = buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...params,
agentId: params.agentId || null,
groupId: params.groupId || null,
id,
sessionId: params.sessionId || null,
},
);
const insertMeta = {
hasAgentId: !!params.agentId,
hasGroupId: !!params.groupId,
@@ -638,7 +644,7 @@ export class TopicModel {
tx
.update(messages)
.set({ topicId: topic.id })
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds))),
.where(and(this.messageOwnership(), inArray(messages.id, messageIds))),
{ messageCount: messageIds.length },
);
@@ -660,16 +666,20 @@ export class TopicModel {
const createdTopics = await tx
.insert(topics)
.values(
topicParams.map((params) => ({
agentId: params.agentId || null,
favorite: params.favorite,
groupId: params.sessionId ? null : params.groupId,
id: params.id || this.genId(),
sessionId: params.groupId ? null : params.sessionId,
title: params.title,
trigger: params.trigger,
userId: this.userId,
})),
topicParams.map((params) =>
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
agentId: params.agentId || null,
favorite: params.favorite,
groupId: params.sessionId ? null : params.groupId,
id: params.id || this.genId(),
sessionId: params.groupId ? null : params.sessionId,
title: params.title,
trigger: params.trigger,
},
),
),
)
.returning();
@@ -681,7 +691,7 @@ export class TopicModel {
await tx
.update(messages)
.set({ topicId: topic.id })
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
.where(and(this.messageOwnership(), inArray(messages.id, messageIds)));
}
}),
);
@@ -694,7 +704,7 @@ export class TopicModel {
return this.db.transaction(async (tx) => {
// find original topic
const originalTopic = await tx.query.topics.findFirst({
where: and(eq(topics.id, topicId), eq(topics.userId, this.userId)),
where: and(eq(topics.id, topicId), this.ownership()),
});
if (!originalTopic) {
@@ -704,19 +714,24 @@ export class TopicModel {
// copy topic
const [duplicatedTopic] = await tx
.insert(topics)
.values({
...originalTopic,
clientId: null,
id: this.genId(),
title: newTitle || originalTopic?.title,
})
.values(
buildWorkspacePayload(
{ userId: this.userId, workspaceId: this.workspaceId },
{
...originalTopic,
clientId: null,
id: this.genId(),
title: newTitle || originalTopic?.title,
},
),
)
.returning();
// Find messages associated with the original topic, ordered by createdAt
const originalMessages = await tx
.select()
.from(messages)
.where(and(eq(messages.topicId, topicId), eq(messages.userId, this.userId)))
.where(and(eq(messages.topicId, topicId), this.messageOwnership()))
.orderBy(messages.createdAt);
// Find all messagePlugins for this topic
@@ -800,47 +815,39 @@ export class TopicModel {
* Delete a session, also delete all messages and topics associated with it.
*/
delete = async (id: string) => {
return this.db.delete(topics).where(and(eq(topics.id, id), eq(topics.userId, this.userId)));
return this.db.delete(topics).where(and(eq(topics.id, id), this.ownership()));
};
/**
* Deletes multiple topics based on the sessionId.
*/
batchDeleteBySessionId = async (sessionId?: string | null) => {
return this.db
.delete(topics)
.where(and(this.matchSession(sessionId), eq(topics.userId, this.userId)));
return this.db.delete(topics).where(and(this.matchSession(sessionId), this.ownership()));
};
/**
* Deletes multiple topics based on the groupId.
*/
batchDeleteByGroupId = async (groupId?: string | null) => {
return this.db
.delete(topics)
.where(and(this.matchGroup(groupId), eq(topics.userId, this.userId)));
return this.db.delete(topics).where(and(this.matchGroup(groupId), this.ownership()));
};
/**
* Deletes all topics matching the given agentId (`topics.agentId`).
*/
batchDeleteByAgentId = async (agentId: string) => {
return this.db
.delete(topics)
.where(and(eq(topics.userId, this.userId), eq(topics.agentId, agentId)));
return this.db.delete(topics).where(and(this.ownership(), eq(topics.agentId, agentId)));
};
/**
* Deletes multiple topics and all messages associated with them in a transaction.
*/
batchDelete = async (ids: string[]) => {
return this.db
.delete(topics)
.where(and(inArray(topics.id, ids), eq(topics.userId, this.userId)));
return this.db.delete(topics).where(and(inArray(topics.id, ids), this.ownership()));
};
deleteAll = async () => {
return this.db.delete(topics).where(eq(topics.userId, this.userId));
return this.db.delete(topics).where(and(this.ownership()));
};
// **************** Update *************** //
@@ -849,7 +856,7 @@ export class TopicModel {
return this.db
.update(topics)
.set({ ...data, updatedAt: new Date() })
.where(and(eq(topics.id, id), eq(topics.userId, this.userId)))
.where(and(eq(topics.id, id), this.ownership()))
.returning();
};
@@ -860,7 +867,7 @@ export class TopicModel {
* external callers use this wrapper. Runs in a transaction for consistency.
*/
recomputeUsage = async (id: string) =>
this.db.transaction((trx) => recomputeTopicUsage(trx, this.userId, id));
this.db.transaction((trx) => recomputeTopicUsage(trx, this.userId, id, this.workspaceId));
/**
* Update topic metadata with merge logic
@@ -870,7 +877,7 @@ export class TopicModel {
// Get existing topic to merge metadata
const existing = await this.db.query.topics.findFirst({
columns: { metadata: true },
where: and(eq(topics.id, id), eq(topics.userId, this.userId)),
where: and(eq(topics.id, id), this.ownership()),
});
const mergedOnboardingSession =
@@ -890,7 +897,7 @@ export class TopicModel {
return this.db
.update(topics)
.set({ metadata: mergedMetadata })
.where(and(eq(topics.id, id), eq(topics.userId, this.userId)))
.where(and(eq(topics.id, id), this.ownership()))
.returning();
};
@@ -902,7 +909,7 @@ export class TopicModel {
.from(topics)
.where(
and(
eq(topics.userId, this.userId),
this.ownership(),
eq(topics.agentId, agentId),
eq(topics.trigger, 'cron'),
sql`(${topics.metadata}->>'cronJobId') IS NOT NULL`,
@@ -970,7 +977,7 @@ export class TopicModel {
limit: options.limit,
orderBy: (fields, { asc }) => [asc(fields.createdAt), asc(fields.id)],
where: and(
eq(topics.userId, this.userId),
this.ownership(),
options.startDate ? gte(topics.createdAt, options.startDate) : undefined,
options.endDate ? lte(topics.createdAt, options.endDate) : undefined,
options.ignoreExtracted
@@ -996,7 +1003,7 @@ export class TopicModel {
.from(topics)
.where(
and(
eq(topics.userId, this.userId),
this.ownership(),
options.startDate ? gte(topics.createdAt, options.startDate) : undefined,
options.endDate ? lte(topics.createdAt, options.endDate) : undefined,
options.ignoreExtracted
+14 -12
View File
@@ -3,6 +3,7 @@ import { and, desc, eq } from 'drizzle-orm';
import type { DocumentItem, NewTopicDocument } from '../schemas';
import { documents, topicDocuments } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
export interface TopicDocumentWithDetails extends DocumentItem {
associatedAt: Date;
@@ -11,12 +12,17 @@ export interface TopicDocumentWithDetails extends DocumentItem {
export class TopicDocumentModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topicDocuments);
/**
* Associate a document with a topic.
*
@@ -30,7 +36,7 @@ export class TopicDocumentModel {
): Promise<{ documentId: string; topicId: string }> => {
await this.db
.insert(topicDocuments)
.values({ ...params, userId: this.userId })
.values({ ...params, userId: this.userId, workspaceId: this.workspaceId ?? null })
.onConflictDoNothing();
return { documentId: params.documentId, topicId: params.topicId };
@@ -46,7 +52,7 @@ export class TopicDocumentModel {
and(
eq(topicDocuments.documentId, documentId),
eq(topicDocuments.topicId, topicId),
eq(topicDocuments.userId, this.userId),
this.ownership(),
),
);
};
@@ -68,7 +74,7 @@ export class TopicDocumentModel {
.where(
and(
eq(topicDocuments.topicId, topicId),
eq(topicDocuments.userId, this.userId),
this.ownership(),
filter?.type ? eq(documents.fileType, filter.type) : undefined,
),
)
@@ -87,9 +93,7 @@ export class TopicDocumentModel {
const results = await this.db
.select({ topicId: topicDocuments.topicId })
.from(topicDocuments)
.where(
and(eq(topicDocuments.documentId, documentId), eq(topicDocuments.userId, this.userId)),
);
.where(and(eq(topicDocuments.documentId, documentId), this.ownership()));
return results.map((r) => r.topicId);
};
@@ -102,7 +106,7 @@ export class TopicDocumentModel {
where: and(
eq(topicDocuments.documentId, documentId),
eq(topicDocuments.topicId, topicId),
eq(topicDocuments.userId, this.userId),
this.ownership(),
),
});
@@ -115,7 +119,7 @@ export class TopicDocumentModel {
deleteByTopicId = async (topicId: string) => {
return this.db
.delete(topicDocuments)
.where(and(eq(topicDocuments.topicId, topicId), eq(topicDocuments.userId, this.userId)));
.where(and(eq(topicDocuments.topicId, topicId), this.ownership()));
};
/**
@@ -124,8 +128,6 @@ export class TopicDocumentModel {
deleteByDocumentId = async (documentId: string) => {
return this.db
.delete(topicDocuments)
.where(
and(eq(topicDocuments.documentId, documentId), eq(topicDocuments.userId, this.userId)),
);
.where(and(eq(topicDocuments.documentId, documentId), this.ownership()));
};
}
+16 -6
View File
@@ -4,6 +4,7 @@ import { and, asc, eq, sql } from 'drizzle-orm';
import { agents, chatGroups, chatGroupsAgents, topics, topicShares } from '../schemas';
import type { LobeChatDatabase } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
export type TopicShareData = NonNullable<
Awaited<ReturnType<(typeof TopicShareModel)['findByShareId']>>
@@ -12,21 +13,29 @@ export type TopicShareData = NonNullable<
export class TopicShareModel {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ownership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topicShares);
/**
* Create or get existing share for a topic.
* Each topic can only have one share record (enforced by unique constraint).
* If record already exists, returns the existing one.
*/
create = async (topicId: string, visibility: ShareVisibility = 'private') => {
// First verify the topic belongs to the user
// First verify the topic belongs to the user (or workspace).
const topic = await this.db.query.topics.findFirst({
where: and(eq(topics.id, topicId), eq(topics.userId, this.userId)),
where: and(
eq(topics.id, topicId),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics),
),
});
if (!topic) {
@@ -39,6 +48,7 @@ export class TopicShareModel {
topicId,
userId: this.userId,
visibility,
workspaceId: this.workspaceId ?? null,
})
.onConflictDoNothing({ target: topicShares.topicId })
.returning();
@@ -58,7 +68,7 @@ export class TopicShareModel {
const [result] = await this.db
.update(topicShares)
.set({ updatedAt: new Date(), visibility })
.where(and(eq(topicShares.topicId, topicId), eq(topicShares.userId, this.userId)))
.where(and(eq(topicShares.topicId, topicId), this.ownership()))
.returning();
return result || null;
@@ -70,7 +80,7 @@ export class TopicShareModel {
deleteByTopicId = async (topicId: string) => {
return this.db
.delete(topicShares)
.where(and(eq(topicShares.topicId, topicId), eq(topicShares.userId, this.userId)));
.where(and(eq(topicShares.topicId, topicId), this.ownership()));
};
/**
@@ -85,7 +95,7 @@ export class TopicShareModel {
visibility: topicShares.visibility,
})
.from(topicShares)
.where(and(eq(topicShares.topicId, topicId), eq(topicShares.userId, this.userId)))
.where(and(eq(topicShares.topicId, topicId), this.ownership()))
.limit(1);
return result[0] || null;
+12 -3
View File
@@ -2,6 +2,7 @@ import { and, eq, sql } from 'drizzle-orm';
import { topics } from '../schemas';
import type { Transaction } from '../type';
import { buildWorkspaceWhere } from '../utils/workspace';
/**
* ModelUsage numeric fields summed per (provider, model) to build the
@@ -60,6 +61,7 @@ export const recomputeTopicUsage = async (
trx: Transaction,
userId: string,
topicId: string,
workspaceId?: string,
): Promise<void> => {
// Reads prefer the dedicated `usage` column, falling back to legacy
// `metadata->'usage'` for rows written before the migration.
@@ -67,6 +69,13 @@ export const recomputeTopicUsage = async (
(f) => `sum((COALESCE(usage, metadata->'usage')->>'${f}')::numeric) AS "${f}"`,
).join(',\n ');
// Workspace-aware ownership predicate for the raw messages aggregate: in team
// mode rows are scoped by workspace_id (creator user_id is not part of the
// filter); in personal mode by user_id with workspace_id IS NULL.
const messageOwnership = workspaceId
? sql`workspace_id = ${workspaceId}`
: sql`user_id = ${userId} AND workspace_id IS NULL`;
const { rows } = await trx.execute(sql`
SELECT
provider,
@@ -77,7 +86,7 @@ export const recomputeTopicUsage = async (
${sql.raw(fieldSelects)}
FROM messages
WHERE topic_id = ${topicId}
AND user_id = ${userId}
AND ${messageOwnership}
AND role = 'assistant'
AND (usage IS NOT NULL OR metadata ? 'usage')
GROUP BY provider, model
@@ -99,7 +108,7 @@ export const recomputeTopicUsage = async (
totalTokens: null,
usage: null,
})
.where(and(eq(topics.id, topicId), eq(topics.userId, userId)));
.where(and(eq(topics.id, topicId), buildWorkspaceWhere({ userId, workspaceId }, topics)));
return;
}
@@ -191,5 +200,5 @@ export const recomputeTopicUsage = async (
totalTokens,
usage,
})
.where(and(eq(topics.id, topicId), eq(topics.userId, userId)));
.where(and(eq(topics.id, topicId), buildWorkspaceWhere({ userId, workspaceId }, topics)));
};
@@ -16,6 +16,10 @@ export class UserMemoryActivityModel {
this.db = db;
}
private memoryWhere(table: { userId: any }) {
return eq(table.userId, this.userId);
}
create = async (params: Omit<NewUserMemoryActivity, 'userId'>) => {
const [result] = await this.db
.insert(userMemoriesActivities)
@@ -28,10 +32,7 @@ export class UserMemoryActivityModel {
delete = async (id: string) => {
return this.db.transaction(async (tx) => {
const activity = await tx.query.userMemoriesActivities.findFirst({
where: and(
eq(userMemoriesActivities.id, id),
eq(userMemoriesActivities.userId, this.userId),
),
where: and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities)),
});
if (!activity || !activity.userMemoryId) {
@@ -40,25 +41,21 @@ export class UserMemoryActivityModel {
await tx
.delete(userMemories)
.where(
and(eq(userMemories.id, activity.userMemoryId), eq(userMemories.userId, this.userId)),
);
.where(and(eq(userMemories.id, activity.userMemoryId), this.memoryWhere(userMemories)));
return { success: true };
});
};
deleteAll = async () => {
return this.db
.delete(userMemoriesActivities)
.where(eq(userMemoriesActivities.userId, this.userId));
return this.db.delete(userMemoriesActivities).where(this.memoryWhere(userMemoriesActivities));
};
query = async (limit = 50) => {
return this.db.query.userMemoriesActivities.findMany({
limit,
orderBy: [desc(userMemoriesActivities.createdAt)],
where: eq(userMemoriesActivities.userId, this.userId),
where: this.memoryWhere(userMemoriesActivities),
});
};
@@ -74,7 +71,7 @@ export class UserMemoryActivityModel {
: '';
const conditions: Array<SQL | undefined> = [
eq(userMemoriesActivities.userId, this.userId),
this.memoryWhere(userMemoriesActivities),
normalizedQuery
? sql`(${userMemories.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('title', ${bm25MatchQuery}, conjunction_mode => true)]) OR ${userMemoriesActivities.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('narrative', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('notes', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('feedback', ${bm25MatchQuery}, conjunction_mode => true)]))`
: undefined,
@@ -108,7 +105,7 @@ export class UserMemoryActivityModel {
const joinCondition = and(
eq(userMemories.id, userMemoriesActivities.userMemoryId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemories),
);
const [rows, totalResult] = await Promise.all([
@@ -151,7 +148,7 @@ export class UserMemoryActivityModel {
findById = async (id: string) => {
return this.db.query.userMemoriesActivities.findFirst({
where: and(eq(userMemoriesActivities.id, id), eq(userMemoriesActivities.userId, this.userId)),
where: and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities)),
});
};
@@ -159,8 +156,6 @@ export class UserMemoryActivityModel {
return this.db
.update(userMemoriesActivities)
.set({ ...value, updatedAt: new Date() })
.where(
and(eq(userMemoriesActivities.id, id), eq(userMemoriesActivities.userId, this.userId)),
);
.where(and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities)));
};
}
@@ -13,6 +13,10 @@ export class UserMemoryContextModel {
this.db = db;
}
private memoryWhere(table: { userId: any }) {
return eq(table.userId, this.userId);
}
create = async (params: Omit<NewUserMemoryContext, 'userId'>) => {
const [result] = await this.db
.insert(userMemoriesContexts)
@@ -25,7 +29,7 @@ export class UserMemoryContextModel {
delete = async (id: string) => {
return this.db.transaction(async (tx) => {
const context = await tx.query.userMemoriesContexts.findFirst({
where: and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId)),
where: and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts)),
});
if (!context) {
@@ -41,34 +45,34 @@ export class UserMemoryContextModel {
for (const memoryId of memoryIds) {
await tx
.delete(userMemories)
.where(and(eq(userMemories.id, memoryId), eq(userMemories.userId, this.userId)));
.where(and(eq(userMemories.id, memoryId), this.memoryWhere(userMemories)));
}
}
// Delete the context entry
await tx
.delete(userMemoriesContexts)
.where(and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId)));
.where(and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts)));
return { success: true };
});
};
deleteAll = async () => {
return this.db.delete(userMemoriesContexts).where(eq(userMemoriesContexts.userId, this.userId));
return this.db.delete(userMemoriesContexts).where(this.memoryWhere(userMemoriesContexts));
};
query = async (limit = 50) => {
return this.db.query.userMemoriesContexts.findMany({
limit,
orderBy: [desc(userMemoriesContexts.createdAt)],
where: eq(userMemoriesContexts.userId, this.userId),
where: this.memoryWhere(userMemoriesContexts),
});
};
findById = async (id: string) => {
return this.db.query.userMemoriesContexts.findFirst({
where: and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId)),
where: and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts)),
});
};
@@ -76,6 +80,6 @@ export class UserMemoryContextModel {
return this.db
.update(userMemoriesContexts)
.set({ ...value, updatedAt: new Date() })
.where(and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId)));
.where(and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts)));
};
}
@@ -16,6 +16,10 @@ export class UserMemoryExperienceModel {
this.db = db;
}
private memoryWhere(table: { userId: any }) {
return eq(table.userId, this.userId);
}
create = async (params: Omit<NewUserMemoryExperience, 'userId'>) => {
const [result] = await this.db
.insert(userMemoriesExperiences)
@@ -28,10 +32,7 @@ export class UserMemoryExperienceModel {
delete = async (id: string) => {
return this.db.transaction(async (tx) => {
const experience = await tx.query.userMemoriesExperiences.findFirst({
where: and(
eq(userMemoriesExperiences.id, id),
eq(userMemoriesExperiences.userId, this.userId),
),
where: and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences)),
});
if (!experience || !experience.userMemoryId) {
@@ -41,25 +42,21 @@ export class UserMemoryExperienceModel {
// Delete the base user memory (cascade will handle the experience)
await tx
.delete(userMemories)
.where(
and(eq(userMemories.id, experience.userMemoryId), eq(userMemories.userId, this.userId)),
);
.where(and(eq(userMemories.id, experience.userMemoryId), this.memoryWhere(userMemories)));
return { success: true };
});
};
deleteAll = async () => {
return this.db
.delete(userMemoriesExperiences)
.where(eq(userMemoriesExperiences.userId, this.userId));
return this.db.delete(userMemoriesExperiences).where(this.memoryWhere(userMemoriesExperiences));
};
query = async (limit = 50) => {
return this.db.query.userMemoriesExperiences.findMany({
limit,
orderBy: [desc(userMemoriesExperiences.createdAt)],
where: eq(userMemoriesExperiences.userId, this.userId),
where: this.memoryWhere(userMemoriesExperiences),
});
};
@@ -80,7 +77,7 @@ export class UserMemoryExperienceModel {
// Build WHERE conditions
const conditions: Array<SQL | undefined> = [
eq(userMemoriesExperiences.userId, this.userId),
this.memoryWhere(userMemoriesExperiences),
// Full-text search across title, situation, keyLearning, action
normalizedQuery
? sql`(${userMemories.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('title', ${bm25MatchQuery}, conjunction_mode => true)]) OR ${userMemoriesExperiences.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('situation', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('key_learning', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('action', ${bm25MatchQuery}, conjunction_mode => true)]))`
@@ -110,7 +107,7 @@ export class UserMemoryExperienceModel {
// JOIN condition
const joinCondition = and(
eq(userMemories.id, userMemoriesExperiences.userMemoryId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemories),
);
// Execute queries in parallel
@@ -152,10 +149,7 @@ export class UserMemoryExperienceModel {
findById = async (id: string) => {
return this.db.query.userMemoriesExperiences.findFirst({
where: and(
eq(userMemoriesExperiences.id, id),
eq(userMemoriesExperiences.userId, this.userId),
),
where: and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences)),
});
};
@@ -163,8 +157,6 @@ export class UserMemoryExperienceModel {
return this.db
.update(userMemoriesExperiences)
.set({ ...value, updatedAt: new Date() })
.where(
and(eq(userMemoriesExperiences.id, id), eq(userMemoriesExperiences.userId, this.userId)),
);
.where(and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences)));
};
}
@@ -17,6 +17,10 @@ export class UserMemoryIdentityModel {
this.db = db;
}
private memoryWhere(table: { userId: any }) {
return eq(table.userId, this.userId);
}
create = async (params: Omit<NewUserMemoryIdentity, 'userId'>) => {
const [result] = await this.db
.insert(userMemoriesIdentities)
@@ -29,10 +33,7 @@ export class UserMemoryIdentityModel {
delete = async (id: string) => {
return this.db.transaction(async (tx) => {
const identity = await tx.query.userMemoriesIdentities.findFirst({
where: and(
eq(userMemoriesIdentities.id, id),
eq(userMemoriesIdentities.userId, this.userId),
),
where: and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities)),
});
if (!identity || !identity.userMemoryId) {
@@ -42,25 +43,21 @@ export class UserMemoryIdentityModel {
// Delete the base user memory (cascade will handle the identity)
await tx
.delete(userMemories)
.where(
and(eq(userMemories.id, identity.userMemoryId), eq(userMemories.userId, this.userId)),
);
.where(and(eq(userMemories.id, identity.userMemoryId), this.memoryWhere(userMemories)));
return { success: true };
});
};
deleteAll = async () => {
return this.db
.delete(userMemoriesIdentities)
.where(eq(userMemoriesIdentities.userId, this.userId));
return this.db.delete(userMemoriesIdentities).where(this.memoryWhere(userMemoriesIdentities));
};
query = async (limit = 50) => {
return this.db.query.userMemoriesIdentities.findMany({
limit,
orderBy: [desc(userMemoriesIdentities.capturedAt)],
where: eq(userMemoriesIdentities.userId, this.userId),
where: this.memoryWhere(userMemoriesIdentities),
});
};
@@ -81,7 +78,7 @@ export class UserMemoryIdentityModel {
// Build WHERE conditions
const conditions: Array<SQL | undefined> = [
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
// Full-text search across title, description, role
normalizedQuery
? sql`(${userMemories.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('title', ${bm25MatchQuery}, conjunction_mode => true)]) OR ${userMemoriesIdentities.id} @@@ paradedb.boolean(should => ARRAY[paradedb.match('description', ${bm25MatchQuery}, conjunction_mode => true), paradedb.match('role', ${bm25MatchQuery}, conjunction_mode => true)]))`
@@ -113,7 +110,7 @@ export class UserMemoryIdentityModel {
// JOIN condition
const joinCondition = and(
eq(userMemories.id, userMemoriesIdentities.userMemoryId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemories),
);
// Execute queries in parallel
@@ -155,7 +152,7 @@ export class UserMemoryIdentityModel {
findById = async (id: string) => {
return this.db.query.userMemoriesIdentities.findFirst({
where: and(eq(userMemoriesIdentities.id, id), eq(userMemoriesIdentities.userId, this.userId)),
where: and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities)),
});
};
@@ -163,9 +160,7 @@ export class UserMemoryIdentityModel {
return this.db
.update(userMemoriesIdentities)
.set({ ...value, updatedAt: new Date() })
.where(
and(eq(userMemoriesIdentities.id, id), eq(userMemoriesIdentities.userId, this.userId)),
);
.where(and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities)));
};
/**
@@ -187,7 +182,7 @@ export class UserMemoryIdentityModel {
.from(userMemoriesIdentities)
.where(
and(
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
// Only include self identities (relationship is 'self' or null/not set)
or(
eq(userMemoriesIdentities.relationship, RelationshipEnum.Self),
@@ -246,7 +246,7 @@ export interface UserMemorySearchAggregatedResult {
preferences: UserMemoryPreferenceWithoutVectors[];
}
const pickSingleSearchType = (types?: string[]) => (types?.length === 1 ? types[0] : undefined);
const _pickSingleSearchType = (types?: string[]) => (types?.length === 1 ? types[0] : undefined);
export interface UpdateUserMemoryVectorsParams {
detailsVector1024?: number[] | null;
@@ -555,6 +555,10 @@ export class UserMemoryModel {
this.topicModel = new TopicModel(db, userId);
}
private memoryWhere(table: { userId: any }) {
return eq(table.userId, this.userId);
}
private extractSourceMetadata(metadata?: Record<string, unknown> | null): {
sourceId?: string;
sourceType?: MemorySourceType;
@@ -830,7 +834,7 @@ export class UserMemoryModel {
const { layers, page = 1, size = 10 } = params;
const offset = (page - 1) * size;
const conditions = [eq(userMemories.userId, this.userId)];
const conditions = [this.memoryWhere(userMemories)];
if (layers && layers.length > 0) {
conditions.push(inArray(userMemories.memoryLayer, layers));
}
@@ -867,7 +871,7 @@ export class UserMemoryModel {
const offset = (page - 1) * size;
const identityConditions = [
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
eq(userMemoriesIdentities.relationship, RelationshipEnum.Self),
];
@@ -976,7 +980,7 @@ export class UserMemoryModel {
const supportsBm25 = !isPGliteDatabase(this.db);
const conditions: Array<SQL | undefined> = [
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemories),
categories && categories.length > 0
? inArray(userMemories.memoryCategory, categories)
: undefined,
@@ -1142,7 +1146,7 @@ export class UserMemoryModel {
);
const joinCondition = and(
eq(userMemories.id, userMemoriesActivities.userMemoryId),
eq(userMemoriesActivities.userId, this.userId),
this.memoryWhere(userMemoriesActivities),
);
const activityFilters: Array<SQL | undefined> = [
@@ -1257,7 +1261,7 @@ export class UserMemoryModel {
);
const joinCondition = and(
eq(userMemories.id, userMemoriesExperiences.userMemoryId),
eq(userMemoriesExperiences.userId, this.userId),
this.memoryWhere(userMemoriesExperiences),
);
const experienceFilters: Array<SQL | undefined> = [
@@ -1352,7 +1356,7 @@ export class UserMemoryModel {
);
const joinCondition = and(
eq(userMemories.id, userMemoriesIdentities.userMemoryId),
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
);
const identityFilters: Array<SQL | undefined> = [
@@ -1450,7 +1454,7 @@ export class UserMemoryModel {
);
const joinCondition = and(
eq(userMemories.id, userMemoriesPreferences.userMemoryId),
eq(userMemoriesPreferences.userId, this.userId),
this.memoryWhere(userMemoriesPreferences),
);
const preferenceFilters: Array<SQL | undefined> = [
@@ -1585,7 +1589,7 @@ export class UserMemoryModel {
const activitySelection = selectNonVectorColumns(userMemoriesActivities);
const baseConditions: Array<SQL | undefined> = [
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemories),
eq(userMemories.memoryLayer, layer),
];
const baseWhere = baseConditions.filter(Boolean) as SQL[];
@@ -1638,7 +1642,7 @@ export class UserMemoryModel {
);
const joinCondition = and(
eq(userMemories.id, userMemoriesExperiences.userMemoryId),
eq(userMemoriesExperiences.userId, this.userId),
this.memoryWhere(userMemoriesExperiences),
);
const experienceFilters: Array<SQL | undefined> = [
@@ -1671,7 +1675,7 @@ export class UserMemoryModel {
);
const joinCondition = and(
eq(userMemories.id, userMemoriesIdentities.userMemoryId),
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
);
const identityFilters: Array<SQL | undefined> = [
@@ -1704,7 +1708,7 @@ export class UserMemoryModel {
);
const joinCondition = and(
eq(userMemories.id, userMemoriesPreferences.userMemoryId),
eq(userMemoriesPreferences.userId, this.userId),
this.memoryWhere(userMemoriesPreferences),
);
const preferenceFilters: Array<SQL | undefined> = [
@@ -1763,7 +1767,7 @@ export class UserMemoryModel {
userMemoryIds: userMemoriesContexts.userMemoryIds,
})
.from(userMemoriesContexts)
.where(and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId)))
.where(and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts)))
.limit(1);
if (!context) {
return undefined;
@@ -1822,9 +1826,7 @@ export class UserMemoryModel {
userMemoryId: userMemoriesActivities.userMemoryId,
})
.from(userMemoriesActivities)
.where(
and(eq(userMemoriesActivities.id, id), eq(userMemoriesActivities.userId, this.userId)),
)
.where(and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities)))
.limit(1);
if (!activity?.userMemoryId) {
return undefined;
@@ -1869,12 +1871,7 @@ export class UserMemoryModel {
userMemoryId: userMemoriesExperiences.userMemoryId,
})
.from(userMemoriesExperiences)
.where(
and(
eq(userMemoriesExperiences.id, id),
eq(userMemoriesExperiences.userId, this.userId),
),
)
.where(and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences)))
.limit(1);
if (!experience?.userMemoryId) {
return undefined;
@@ -1918,9 +1915,7 @@ export class UserMemoryModel {
userMemoryId: userMemoriesIdentities.userMemoryId,
})
.from(userMemoriesIdentities)
.where(
and(eq(userMemoriesIdentities.id, id), eq(userMemoriesIdentities.userId, this.userId)),
)
.where(and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities)))
.limit(1);
if (!identity?.userMemoryId) {
return undefined;
@@ -1963,12 +1958,7 @@ export class UserMemoryModel {
userMemoryId: userMemoriesPreferences.userMemoryId,
})
.from(userMemoriesPreferences)
.where(
and(
eq(userMemoriesPreferences.id, id),
eq(userMemoriesPreferences.userId, this.userId),
),
)
.where(and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences)))
.limit(1);
if (!preference?.userMemoryId) {
return undefined;
@@ -2020,7 +2010,7 @@ export class UserMemoryModel {
userId: userMemories.userId,
})
.from(userMemories)
.where(and(eq(userMemories.id, memoryId), eq(userMemories.userId, this.userId)))
.where(and(eq(userMemories.id, memoryId), this.memoryWhere(userMemories)))
.limit(1);
if (!memory) {
return undefined;
@@ -2031,7 +2021,7 @@ export class UserMemoryModel {
findById = async (id: string): Promise<UserMemoryItem | undefined> => {
const result = await this.db.query.userMemories.findFirst({
where: and(eq(userMemories.id, id), eq(userMemories.userId, this.userId)),
where: and(eq(userMemories.id, id), this.memoryWhere(userMemories)),
});
if (result) {
@@ -2045,7 +2035,7 @@ export class UserMemoryModel {
await this.db
.update(userMemories)
.set({ ...params, updatedAt: new Date() })
.where(and(eq(userMemories.id, id), eq(userMemories.userId, this.userId)));
.where(and(eq(userMemories.id, id), this.memoryWhere(userMemories)));
};
updateUserMemoryVectors = async (
@@ -2070,7 +2060,7 @@ export class UserMemoryModel {
...vectorUpdates,
updatedAt: new Date(),
})
.where(and(eq(userMemories.id, id), eq(userMemories.userId, this.userId)));
.where(and(eq(userMemories.id, id), this.memoryWhere(userMemories)));
};
updateContextVectors = async (id: string, vectors: UpdateContextVectorsParams): Promise<void> => {
@@ -2088,7 +2078,7 @@ export class UserMemoryModel {
...vectorUpdates,
updatedAt: new Date(),
})
.where(and(eq(userMemoriesContexts.id, id), eq(userMemoriesContexts.userId, this.userId)));
.where(and(eq(userMemoriesContexts.id, id), this.memoryWhere(userMemoriesContexts)));
};
updatePreferenceVectors = async (
@@ -2110,9 +2100,7 @@ export class UserMemoryModel {
...vectorUpdates,
updatedAt: new Date(),
})
.where(
and(eq(userMemoriesPreferences.id, id), eq(userMemoriesPreferences.userId, this.userId)),
);
.where(and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences)));
};
updateIdentityVectors = async (
@@ -2134,9 +2122,7 @@ export class UserMemoryModel {
...vectorUpdates,
updatedAt: new Date(),
})
.where(
and(eq(userMemoriesIdentities.id, id), eq(userMemoriesIdentities.userId, this.userId)),
);
.where(and(eq(userMemoriesIdentities.id, id), this.memoryWhere(userMemoriesIdentities)));
};
updateExperienceVectors = async (
@@ -2164,9 +2150,7 @@ export class UserMemoryModel {
...vectorUpdates,
updatedAt: new Date(),
})
.where(
and(eq(userMemoriesExperiences.id, id), eq(userMemoriesExperiences.userId, this.userId)),
);
.where(and(eq(userMemoriesExperiences.id, id), this.memoryWhere(userMemoriesExperiences)));
};
updateActivityVectors = async (
@@ -2191,9 +2175,7 @@ export class UserMemoryModel {
...vectorUpdates,
updatedAt: new Date(),
})
.where(
and(eq(userMemoriesActivities.id, id), eq(userMemoriesActivities.userId, this.userId)),
);
.where(and(eq(userMemoriesActivities.id, id), this.memoryWhere(userMemoriesActivities)));
};
addIdentityEntry = async (params: AddIdentityEntryParams): Promise<AddIdentityEntryResult> => {
@@ -2269,7 +2251,7 @@ export class UserMemoryModel {
const identity = await tx.query.userMemoriesIdentities.findFirst({
where: and(
eq(userMemoriesIdentities.id, params.identityId),
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
),
});
if (!identity || !identity.userMemoryId) {
@@ -2290,9 +2272,7 @@ export class UserMemoryModel {
await tx
.update(userMemories)
.set(baseUpdate)
.where(
and(eq(userMemories.id, identity.userMemoryId), eq(userMemories.userId, this.userId)),
);
.where(and(eq(userMemories.id, identity.userMemoryId), this.memoryWhere(userMemories)));
}
}
@@ -2356,7 +2336,7 @@ export class UserMemoryModel {
.where(
and(
eq(userMemoriesIdentities.id, params.identityId),
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
),
);
}
@@ -2371,7 +2351,7 @@ export class UserMemoryModel {
const identity = await tx.query.userMemoriesIdentities.findFirst({
where: and(
eq(userMemoriesIdentities.id, identityId),
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
),
});
@@ -2381,9 +2361,7 @@ export class UserMemoryModel {
await tx
.delete(userMemories)
.where(
and(eq(userMemories.id, identity.userMemoryId), eq(userMemories.userId, this.userId)),
);
.where(and(eq(userMemories.id, identity.userMemoryId), this.memoryWhere(userMemories)));
return true;
});
@@ -2392,10 +2370,7 @@ export class UserMemoryModel {
removeContextEntry = async (contextId: string): Promise<boolean> => {
return this.db.transaction(async (tx) => {
const context = await tx.query.userMemoriesContexts.findFirst({
where: and(
eq(userMemoriesContexts.id, contextId),
eq(userMemoriesContexts.userId, this.userId),
),
where: and(eq(userMemoriesContexts.id, contextId), this.memoryWhere(userMemoriesContexts)),
});
if (!context) {
@@ -2409,15 +2384,13 @@ export class UserMemoryModel {
if (memoryIds.length > 0) {
await tx
.delete(userMemories)
.where(and(inArray(userMemories.id, memoryIds), eq(userMemories.userId, this.userId)));
.where(and(inArray(userMemories.id, memoryIds), this.memoryWhere(userMemories)));
}
// Delete the context entry
await tx
.delete(userMemoriesContexts)
.where(
and(eq(userMemoriesContexts.id, contextId), eq(userMemoriesContexts.userId, this.userId)),
);
.where(and(eq(userMemoriesContexts.id, contextId), this.memoryWhere(userMemoriesContexts)));
return true;
});
@@ -2428,7 +2401,7 @@ export class UserMemoryModel {
const experience = await tx.query.userMemoriesExperiences.findFirst({
where: and(
eq(userMemoriesExperiences.id, experienceId),
eq(userMemoriesExperiences.userId, this.userId),
this.memoryWhere(userMemoriesExperiences),
),
});
@@ -2439,9 +2412,7 @@ export class UserMemoryModel {
// Delete the base user memory (cascade will handle the experience)
await tx
.delete(userMemories)
.where(
and(eq(userMemories.id, experience.userMemoryId), eq(userMemories.userId, this.userId)),
);
.where(and(eq(userMemories.id, experience.userMemoryId), this.memoryWhere(userMemories)));
return true;
});
@@ -2452,7 +2423,7 @@ export class UserMemoryModel {
const preference = await tx.query.userMemoriesPreferences.findFirst({
where: and(
eq(userMemoriesPreferences.id, preferenceId),
eq(userMemoriesPreferences.userId, this.userId),
this.memoryWhere(userMemoriesPreferences),
),
});
@@ -2463,9 +2434,7 @@ export class UserMemoryModel {
// Delete the base user memory (cascade will handle the preference)
await tx
.delete(userMemories)
.where(
and(eq(userMemories.id, preference.userMemoryId), eq(userMemories.userId, this.userId)),
);
.where(and(eq(userMemories.id, preference.userMemoryId), this.memoryWhere(userMemories)));
return true;
});
@@ -2474,11 +2443,11 @@ export class UserMemoryModel {
delete = async (id: string): Promise<void> => {
await this.db
.delete(userMemories)
.where(and(eq(userMemories.id, id), eq(userMemories.userId, this.userId)));
.where(and(eq(userMemories.id, id), this.memoryWhere(userMemories)));
};
deleteAll = async (): Promise<void> => {
await this.db.delete(userMemories).where(eq(userMemories.userId, this.userId));
await this.db.delete(userMemories).where(this.memoryWhere(userMemories));
};
searchActivities = async (params: {
@@ -2520,7 +2489,7 @@ export class UserMemoryModel {
.from(userMemoriesActivities)
.$dynamic();
const conditions = [eq(userMemoriesActivities.userId, this.userId)];
const conditions = [this.memoryWhere(userMemoriesActivities)];
if (type) {
conditions.push(eq(userMemoriesActivities.type, type));
}
@@ -2571,7 +2540,7 @@ export class UserMemoryModel {
.from(userMemoriesContexts)
.$dynamic();
const conditions = [eq(userMemoriesContexts.userId, this.userId)];
const conditions = [this.memoryWhere(userMemoriesContexts)];
if (type) {
conditions.push(eq(userMemoriesContexts.type, type));
}
@@ -2622,7 +2591,7 @@ export class UserMemoryModel {
.from(userMemoriesExperiences)
.$dynamic();
const conditions = [eq(userMemoriesExperiences.userId, this.userId)];
const conditions = [this.memoryWhere(userMemoriesExperiences)];
if (type) {
conditions.push(eq(userMemoriesExperiences.type, type));
}
@@ -2669,7 +2638,7 @@ export class UserMemoryModel {
.from(userMemoriesPreferences)
.$dynamic();
const conditions = [eq(userMemoriesPreferences.userId, this.userId)];
const conditions = [this.memoryWhere(userMemoriesPreferences)];
if (type) {
conditions.push(eq(userMemoriesPreferences.type, type));
}
@@ -2688,7 +2657,7 @@ export class UserMemoryModel {
const res = await this.db
.select(selectNonVectorColumns(userMemoriesIdentities))
.from(userMemoriesIdentities)
.where(eq(userMemoriesIdentities.userId, this.userId))
.where(this.memoryWhere(userMemoriesIdentities))
.orderBy(desc(userMemoriesIdentities.capturedAt), desc(userMemoriesIdentities.createdAt));
return res;
@@ -2702,7 +2671,7 @@ export class UserMemoryModel {
})
.from(userMemoriesIdentities)
.innerJoin(userMemories, eq(userMemories.id, userMemoriesIdentities.userMemoryId))
.where(eq(userMemoriesIdentities.userId, this.userId))
.where(this.memoryWhere(userMemoriesIdentities))
.orderBy(desc(userMemoriesIdentities.capturedAt), desc(userMemoriesIdentities.createdAt));
return res;
@@ -2712,9 +2681,7 @@ export class UserMemoryModel {
const res = await this.db
.select(selectNonVectorColumns(userMemoriesIdentities))
.from(userMemoriesIdentities)
.where(
and(eq(userMemoriesIdentities.userId, this.userId), eq(userMemoriesIdentities.type, type)),
)
.where(and(this.memoryWhere(userMemoriesIdentities), eq(userMemoriesIdentities.type, type)))
.orderBy(desc(userMemoriesIdentities.capturedAt), desc(userMemoriesIdentities.createdAt));
return res;
@@ -2744,7 +2711,7 @@ export class UserMemoryModel {
accessedCount: sql`${userMemories.accessedCount} + 1`,
lastAccessedAt: now,
})
.where(and(eq(userMemories.userId, this.userId), eq(userMemories.id, memoryId)));
.where(and(this.memoryWhere(userMemories), eq(userMemories.id, memoryId)));
}
const memories = await tx
@@ -2753,9 +2720,7 @@ export class UserMemoryModel {
layer: userMemories.memoryLayer,
})
.from(userMemories)
.where(
and(eq(userMemories.userId, this.userId), inArray(userMemories.id, orderedMemoryIds)),
);
.where(and(this.memoryWhere(userMemories), inArray(userMemories.id, orderedMemoryIds)));
const experienceIds = memories
.filter((memory) => memory.layer === 'experience')
@@ -2766,7 +2731,7 @@ export class UserMemoryModel {
.set({ accessedAt: now })
.where(
and(
eq(userMemoriesExperiences.userId, this.userId),
this.memoryWhere(userMemoriesExperiences),
inArray(userMemoriesExperiences.userMemoryId, experienceIds),
),
);
@@ -2781,7 +2746,7 @@ export class UserMemoryModel {
.set({ accessedAt: now })
.where(
and(
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
inArray(userMemoriesIdentities.userMemoryId, identityIds),
),
);
@@ -2796,7 +2761,7 @@ export class UserMemoryModel {
.set({ accessedAt: now })
.where(
and(
eq(userMemoriesPreferences.userId, this.userId),
this.memoryWhere(userMemoriesPreferences),
inArray(userMemoriesPreferences.userMemoryId, preferenceIds),
),
);
@@ -2809,7 +2774,7 @@ export class UserMemoryModel {
.set({ accessedAt: now })
.where(
and(
eq(userMemoriesContexts.userId, this.userId),
this.memoryWhere(userMemoriesContexts),
inArray(userMemoriesContexts.id, orderedContextIds),
),
);
@@ -13,6 +13,10 @@ export class UserMemoryPreferenceModel {
this.db = db;
}
private memoryWhere(table: { userId: any }) {
return eq(table.userId, this.userId);
}
create = async (params: Omit<NewUserMemoryPreference, 'userId'>) => {
const [result] = await this.db
.insert(userMemoriesPreferences)
@@ -25,10 +29,7 @@ export class UserMemoryPreferenceModel {
delete = async (id: string) => {
return this.db.transaction(async (tx) => {
const preference = await tx.query.userMemoriesPreferences.findFirst({
where: and(
eq(userMemoriesPreferences.id, id),
eq(userMemoriesPreferences.userId, this.userId),
),
where: and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences)),
});
if (!preference || !preference.userMemoryId) {
@@ -38,34 +39,27 @@ export class UserMemoryPreferenceModel {
// Delete the base user memory (cascade will handle the preference)
await tx
.delete(userMemories)
.where(
and(eq(userMemories.id, preference.userMemoryId), eq(userMemories.userId, this.userId)),
);
.where(and(eq(userMemories.id, preference.userMemoryId), this.memoryWhere(userMemories)));
return { success: true };
});
};
deleteAll = async () => {
return this.db
.delete(userMemoriesPreferences)
.where(eq(userMemoriesPreferences.userId, this.userId));
return this.db.delete(userMemoriesPreferences).where(this.memoryWhere(userMemoriesPreferences));
};
query = async (limit = 50) => {
return this.db.query.userMemoriesPreferences.findMany({
limit,
orderBy: [desc(userMemoriesPreferences.createdAt)],
where: eq(userMemoriesPreferences.userId, this.userId),
where: this.memoryWhere(userMemoriesPreferences),
});
};
findById = async (id: string) => {
return this.db.query.userMemoriesPreferences.findFirst({
where: and(
eq(userMemoriesPreferences.id, id),
eq(userMemoriesPreferences.userId, this.userId),
),
where: and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences)),
});
};
@@ -73,8 +67,6 @@ export class UserMemoryPreferenceModel {
return this.db
.update(userMemoriesPreferences)
.set({ ...value, updatedAt: new Date() })
.where(
and(eq(userMemoriesPreferences.id, id), eq(userMemoriesPreferences.userId, this.userId)),
);
.where(and(eq(userMemoriesPreferences.id, id), this.memoryWhere(userMemoriesPreferences)));
};
}
@@ -664,6 +664,10 @@ export class UserMemoryQueryModel {
private readonly userId: string,
) {}
private memoryWhere(table: { userId: any }) {
return eq(table.userId, this.userId);
}
/**
* Hybrid memory retrieval pipeline for the five heterogeneous memory layers.
*
@@ -1118,7 +1122,7 @@ export class UserMemoryQueryModel {
updatedAt: userMemories.updatedAt,
})
.from(userMemories)
.where(and(eq(userMemories.userId, this.userId), inArray(userMemories.id, memoryIds)));
.where(and(this.memoryWhere(userMemories), inArray(userMemories.id, memoryIds)));
const baseMemoryMap = new Map(
baseMemories.map((memory) => [
@@ -1215,7 +1219,7 @@ export class UserMemoryQueryModel {
}): Promise<QueryTaxonomyOptionsResult['categories']> {
const { column, layers, limit, q, timeRange } = params;
const conditions = [
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemories),
layers?.length ? inArray(userMemories.memoryLayer, layers) : undefined,
this.buildTimeRangeCondition(
{
@@ -1255,7 +1259,7 @@ export class UserMemoryQueryModel {
}): Promise<QueryTaxonomyOptionsResult['tags']> {
const { column, layers, limit, q, timeRange } = params;
const conditions = [
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemories),
layers?.length ? inArray(userMemories.memoryLayer, layers) : undefined,
this.buildTimeRangeCondition(
{
@@ -1557,7 +1561,7 @@ export class UserMemoryQueryModel {
.from(userMemoriesIdentities)
.where(
and(
eq(userMemoriesIdentities.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
this.buildTimeRangeCondition(
{
capturedAt: userMemoriesIdentities.capturedAt,
@@ -1765,8 +1769,8 @@ export class UserMemoryQueryModel {
params: SearchMemoryParams,
) {
const conditions = [
eq(userMemoriesActivities.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesActivities),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -1825,8 +1829,8 @@ export class UserMemoryQueryModel {
params: SearchMemoryParams,
) {
const conditions = [
eq(userMemoriesContexts.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesContexts),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -1925,8 +1929,8 @@ export class UserMemoryQueryModel {
params: SearchMemoryParams,
) {
const conditions = [
eq(userMemoriesExperiences.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesExperiences),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -1978,8 +1982,8 @@ export class UserMemoryQueryModel {
params: SearchMemoryParams,
) {
const conditions = [
eq(userMemoriesPreferences.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesPreferences),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -2030,8 +2034,8 @@ export class UserMemoryQueryModel {
params: SearchMemoryParams,
): Promise<UserMemoryIdentitiesWithoutVectors[]> {
const conditions = [
eq(userMemoriesIdentities.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -2086,8 +2090,8 @@ export class UserMemoryQueryModel {
) {
const normalizedQuery = typeof query === 'string' ? query.trim() : '';
const conditions = [
eq(userMemoriesActivities.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesActivities),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -2154,8 +2158,8 @@ export class UserMemoryQueryModel {
) {
const normalizedQuery = typeof query === 'string' ? query.trim() : '';
const conditions = [
eq(userMemoriesContexts.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesContexts),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -2258,8 +2262,8 @@ export class UserMemoryQueryModel {
) {
const normalizedQuery = typeof query === 'string' ? query.trim() : '';
const conditions = [
eq(userMemoriesExperiences.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesExperiences),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -2319,8 +2323,8 @@ export class UserMemoryQueryModel {
) {
const normalizedQuery = typeof query === 'string' ? query.trim() : '';
const conditions = [
eq(userMemoriesPreferences.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesPreferences),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
@@ -2377,8 +2381,8 @@ export class UserMemoryQueryModel {
) {
const normalizedQuery = typeof query === 'string' ? query.trim() : '';
const conditions = [
eq(userMemoriesIdentities.userId, this.userId),
eq(userMemories.userId, this.userId),
this.memoryWhere(userMemoriesIdentities),
this.memoryWhere(userMemories),
params.categories?.length
? inArray(userMemories.memoryCategory, params.categories)
: undefined,
+334
View File
@@ -0,0 +1,334 @@
import { and, count, desc, eq, isNull, ne } from 'drizzle-orm';
import {
type NewWorkspace,
type WorkspaceItem,
workspaceMembers,
workspaces,
} from '../schemas/workspace';
import type { LobeChatDatabase } from '../type';
export class WorkspaceModel {
protected readonly db: LobeChatDatabase;
protected readonly userId: string;
constructor(db: LobeChatDatabase, userId: string) {
this.db = db;
this.userId = userId;
}
create = async (params: {
avatar?: string;
description?: string;
name: string;
slug: string;
}) => {
return this.db.transaction(async (tx) => {
const [workspace] = await tx
.insert(workspaces)
.values({
avatar: params.avatar,
description: params.description,
name: params.name,
primaryOwnerId: this.userId,
slug: params.slug,
} satisfies NewWorkspace)
.returning();
await tx.insert(workspaceMembers).values({
role: 'owner',
userId: this.userId,
workspaceId: workspace.id,
});
return workspace;
});
};
delete = async (id: string) => {
return this.db
.delete(workspaces)
.where(and(eq(workspaces.id, id), eq(workspaces.primaryOwnerId, this.userId)));
};
findById = async (id: string) => {
return this.db.query.workspaces.findFirst({
where: eq(workspaces.id, id),
});
};
findBySlug = async (slug: string) => {
return this.db.query.workspaces.findFirst({
where: eq(workspaces.slug, slug),
});
};
/**
* List ids of workspaces where this user is the primary (Stripe-bound) owner.
* Cloud callers combine with subscription-status data to enforce the Free
* workspace cap; OSS callers can use the raw count.
*/
listOwnedWorkspaceIds = async (): Promise<string[]> => {
const owned = await this.db.query.workspaces.findMany({
columns: { id: true },
where: eq(workspaces.primaryOwnerId, this.userId),
});
return owned.map((w) => w.id);
};
getSettings = async (id: string) => {
const workspace = await this.db.query.workspaces.findFirst({
columns: { settings: true },
where: eq(workspaces.id, id),
});
return workspace?.settings ?? {};
};
/**
* Count every workspace this user belongs to owned + joined. Reads the
* membership table directly because owners are always inserted as members on
* `create`, so a single count covers both shapes.
*/
countUserMemberships = async (): Promise<number> => {
const result = await this.db
.select({ count: count() })
.from(workspaceMembers)
.where(and(eq(workspaceMembers.userId, this.userId), isNull(workspaceMembers.deletedAt)));
return result[0]?.count ?? 0;
};
listUserWorkspaces = async () => {
const memberships = await this.db.query.workspaceMembers.findMany({
where: and(eq(workspaceMembers.userId, this.userId), isNull(workspaceMembers.deletedAt)),
});
if (memberships.length === 0) return [];
const workspaceIds = memberships.map((m) => m.workspaceId);
const results = await this.db.query.workspaces.findMany({
orderBy: [desc(workspaces.updatedAt)],
where: (ws, { inArray }) => inArray(ws.id, workspaceIds),
});
return results.map((ws) => ({
...ws,
role: memberships.find((m) => m.workspaceId === ws.id)?.role ?? 'viewer',
}));
};
update = async (
id: string,
value: Partial<Pick<WorkspaceItem, 'avatar' | 'description' | 'name' | 'slug'>>,
) => {
return this.db
.update(workspaces)
.set({ ...value, updatedAt: new Date() })
.where(eq(workspaces.id, id));
};
updateSettings = async (id: string, settings: Record<string, any>) => {
return this.db
.update(workspaces)
.set({ settings, updatedAt: new Date() })
.where(eq(workspaces.id, id));
};
/**
* Transfer the Stripe binding (primary owner) to another existing `owner`
* member. Both users keep role='owner' afterwards only the Stripe binding
* moves. Use `promoteToOwner` first if the target isn't already an owner.
*/
transferPrimaryOwnership = async (id: string, newPrimaryOwnerUserId: string) => {
if (newPrimaryOwnerUserId === this.userId)
throw new Error('New primary owner must be a different user');
return this.db.transaction(async (tx) => {
const current = await tx.query.workspaces.findFirst({
where: eq(workspaces.id, id),
});
if (!current) throw new Error('Workspace not found');
if (current.primaryOwnerId !== this.userId)
throw new Error('Only the primary owner can transfer primary ownership');
const targetMembership = await tx.query.workspaceMembers.findFirst({
where: and(
eq(workspaceMembers.workspaceId, id),
eq(workspaceMembers.userId, newPrimaryOwnerUserId),
isNull(workspaceMembers.deletedAt),
),
});
if (!targetMembership)
throw new Error('Target user must already be a member of the workspace');
if (targetMembership.role !== 'owner')
throw new Error('Target user must already be an owner — promote them first');
await tx
.update(workspaces)
.set({ primaryOwnerId: newPrimaryOwnerUserId, updatedAt: new Date() })
.where(eq(workspaces.id, id));
return {
newPrimaryOwnerUserId,
previousPrimaryOwnerUserId: this.userId,
workspaceId: id,
};
});
};
promoteToOwner = async (id: string, targetUserId: string) => {
return this.db.transaction(async (tx) => {
const actor = await tx.query.workspaceMembers.findFirst({
where: and(
eq(workspaceMembers.workspaceId, id),
eq(workspaceMembers.userId, this.userId),
isNull(workspaceMembers.deletedAt),
),
});
if (actor?.role !== 'owner')
throw new Error('Only an owner can promote other members to owner');
const target = await tx.query.workspaceMembers.findFirst({
where: and(
eq(workspaceMembers.workspaceId, id),
eq(workspaceMembers.userId, targetUserId),
isNull(workspaceMembers.deletedAt),
),
});
if (!target) throw new Error('Target user is not a member of this workspace');
if (target.role === 'owner') return target;
await tx
.update(workspaceMembers)
.set({ role: 'owner' })
.where(
and(eq(workspaceMembers.workspaceId, id), eq(workspaceMembers.userId, targetUserId)),
);
return { ...target, role: 'owner' };
});
};
demoteFromOwner = async (id: string, targetUserId: string) => {
return this.db.transaction(async (tx) => {
const workspace = await tx.query.workspaces.findFirst({
where: eq(workspaces.id, id),
});
if (!workspace) throw new Error('Workspace not found');
if (workspace.primaryOwnerId === targetUserId)
throw new Error(
'Cannot demote the primary owner — transfer primary ownership to another owner first',
);
const actor = await tx.query.workspaceMembers.findFirst({
where: and(
eq(workspaceMembers.workspaceId, id),
eq(workspaceMembers.userId, this.userId),
isNull(workspaceMembers.deletedAt),
),
});
if (actor?.role !== 'owner') throw new Error('Only an owner can demote other owners');
const target = await tx.query.workspaceMembers.findFirst({
where: and(
eq(workspaceMembers.workspaceId, id),
eq(workspaceMembers.userId, targetUserId),
isNull(workspaceMembers.deletedAt),
),
});
if (!target) throw new Error('Target user is not a member of this workspace');
if (target.role !== 'owner') return target;
await tx
.update(workspaceMembers)
.set({ role: 'member' })
.where(
and(eq(workspaceMembers.workspaceId, id), eq(workspaceMembers.userId, targetUserId)),
);
return { ...target, role: 'member' };
});
};
countOtherOwners = async (workspaceId: string, excludeUserId: string): Promise<number> => {
const result = await this.db
.select({ count: count() })
.from(workspaceMembers)
.where(
and(
eq(workspaceMembers.workspaceId, workspaceId),
eq(workspaceMembers.role, 'owner'),
ne(workspaceMembers.userId, excludeUserId),
isNull(workspaceMembers.deletedAt),
),
);
return result[0]?.count ?? 0;
};
/**
* Demote the workspace to single-owner: remove every non-owner member and
* clear the grace-period marker. Called when a subscription is cancelled.
* Workspace-scoped resources (agents/sessions/etc.) stay attached to the
* workspace and remain accessible to the primary owner.
*/
downgradeToSolo = async (id: string) => {
return this.db.transaction(async (tx) => {
const current = await tx.query.workspaces.findFirst({
where: eq(workspaces.id, id),
});
if (!current) throw new Error('Workspace not found');
if (current.primaryOwnerId !== this.userId)
throw new Error('Only the primary owner can downgrade this workspace');
const removedMembers = await tx
.update(workspaceMembers)
.set({ deletedAt: new Date() })
.where(
and(
eq(workspaceMembers.workspaceId, id),
ne(workspaceMembers.userId, current.primaryOwnerId),
isNull(workspaceMembers.deletedAt),
),
)
.returning();
const currentSettings = (current.settings as Record<string, any> | null) ?? {};
const { gracePeriodUntil: _drop, ...restSettings } = currentSettings;
const [updated] = await tx
.update(workspaces)
.set({
settings: restSettings,
updatedAt: new Date(),
})
.where(eq(workspaces.id, id))
.returning();
return {
removedUserIds: removedMembers.map((m) => m.userId),
workspace: updated,
};
});
};
setGracePeriod = async (id: string, gracePeriodUntil: number | null) => {
const current = await this.db.query.workspaces.findFirst({
columns: { settings: true },
where: eq(workspaces.id, id),
});
if (!current) throw new Error('Workspace not found');
const prev = (current.settings as Record<string, any> | null) ?? {};
const next =
gracePeriodUntil === null
? Object.fromEntries(Object.entries(prev).filter(([k]) => k !== 'gracePeriodUntil'))
: { ...prev, gracePeriodUntil };
await this.db
.update(workspaces)
.set({ settings: next, updatedAt: new Date() })
.where(eq(workspaces.id, id));
};
}
@@ -0,0 +1,99 @@
import { and, desc, eq, gte, lt, lte } from 'drizzle-orm';
import { workspaceAuditLogs } from '../schemas/workspace';
import type { LobeChatDatabase } from '../type';
export type WorkspaceAuditAction =
| 'workspace.created'
| 'workspace.updated'
| 'workspace.upgraded'
| 'workspace.downgraded'
| 'workspace.primary_ownership_transferred'
| 'workspace.deleted'
| 'workspace.cleanup_triggered'
| 'workspace.account_upgraded'
| 'workspace.data_cleared'
| 'workspace.settings_reset'
| 'member.invited'
| 'member.removed'
| 'member.role_updated'
| 'member.joined'
| 'member.left'
| 'member.promoted_to_owner'
| 'member.demoted_from_owner'
| 'invitation.revoked'
| 'invitation.resent'
| 'subscription.activated'
| 'subscription.updated'
| 'subscription.cancelled'
| 'subscription.cancellation_scheduled'
| 'subscription.cancellation_resumed'
| 'subscription.grace_period_started'
| 'billing.portal_session_created'
| 'billing.payment_method_added'
| 'billing.payment_method_removed'
| 'billing.default_payment_method_changed';
interface CreateAuditLogParams {
action: WorkspaceAuditAction;
ipAddress?: string;
metadata?: Record<string, any>;
resourceId?: string;
resourceType?: string;
userId: string | null;
workspaceId: string;
}
interface ListAuditLogParams {
action?: WorkspaceAuditAction;
cursor?: Date;
endDate?: Date;
limit?: number;
startDate?: Date;
workspaceId: string;
}
export class WorkspaceAuditLogModel {
private readonly db: LobeChatDatabase;
constructor(db: LobeChatDatabase) {
this.db = db;
}
create = async (params: CreateAuditLogParams) => {
const [row] = await this.db
.insert(workspaceAuditLogs)
.values({
action: params.action,
ipAddress: params.ipAddress,
metadata: params.metadata ?? {},
resourceId: params.resourceId,
resourceType: params.resourceType,
userId: params.userId,
workspaceId: params.workspaceId,
})
.returning();
return row;
};
list = async (params: ListAuditLogParams) => {
const { workspaceId, action, startDate, endDate, cursor, limit = 50 } = params;
const conditions = [eq(workspaceAuditLogs.workspaceId, workspaceId)];
if (action) conditions.push(eq(workspaceAuditLogs.action, action));
if (startDate) conditions.push(gte(workspaceAuditLogs.createdAt, startDate));
if (endDate) conditions.push(lte(workspaceAuditLogs.createdAt, endDate));
if (cursor) conditions.push(lt(workspaceAuditLogs.createdAt, cursor));
const rows = await this.db.query.workspaceAuditLogs.findMany({
limit: limit + 1,
orderBy: [desc(workspaceAuditLogs.createdAt)],
where: and(...conditions),
});
const hasMore = rows.length > limit;
const items = hasMore ? rows.slice(0, limit) : rows;
const nextCursor = hasMore ? items.at(-1)?.createdAt?.toISOString() : null;
return { items, nextCursor };
};
}
@@ -0,0 +1,133 @@
import { INVITATION_EXPIRY_DAYS } from '@lobechat/const';
import { and, eq, isNull } from 'drizzle-orm';
import { nanoid } from 'nanoid/non-secure';
import { workspaceInvitations, workspaceMembers } from '../schemas/workspace';
import type { LobeChatDatabase } from '../type';
type MemberRole = 'member' | 'owner' | 'viewer';
export class WorkspaceMemberModel {
private readonly db: LobeChatDatabase;
private readonly userId: string;
constructor(db: LobeChatDatabase, userId: string) {
this.db = db;
this.userId = userId;
}
// ===== Members ===== //
addMember = async (params: { role?: MemberRole; userId: string; workspaceId: string }) => {
const [result] = await this.db
.insert(workspaceMembers)
.values({
role: params.role ?? 'member',
userId: params.userId,
workspaceId: params.workspaceId,
})
.onConflictDoUpdate({
set: {
deletedAt: null,
joinedAt: new Date(),
role: params.role ?? 'member',
},
target: [workspaceMembers.workspaceId, workspaceMembers.userId],
})
.returning();
return result;
};
getMember = async (workspaceId: string, userId: string) => {
return this.db.query.workspaceMembers.findFirst({
where: and(
eq(workspaceMembers.workspaceId, workspaceId),
eq(workspaceMembers.userId, userId),
isNull(workspaceMembers.deletedAt),
),
});
};
listMembers = async (workspaceId: string, options: { includeDeleted?: boolean } = {}) => {
return this.db.query.workspaceMembers.findMany({
where: options.includeDeleted
? eq(workspaceMembers.workspaceId, workspaceId)
: and(eq(workspaceMembers.workspaceId, workspaceId), isNull(workspaceMembers.deletedAt)),
});
};
removeMember = async (workspaceId: string, userId: string) => {
return this.db
.update(workspaceMembers)
.set({ deletedAt: new Date() })
.where(
and(
eq(workspaceMembers.workspaceId, workspaceId),
eq(workspaceMembers.userId, userId),
isNull(workspaceMembers.deletedAt),
),
);
};
updateMemberRole = async (workspaceId: string, userId: string, role: MemberRole) => {
return this.db
.update(workspaceMembers)
.set({ role })
.where(
and(
eq(workspaceMembers.workspaceId, workspaceId),
eq(workspaceMembers.userId, userId),
isNull(workspaceMembers.deletedAt),
),
);
};
// ===== Invitations ===== //
createInvitation = async (params: { email?: string; role?: MemberRole; workspaceId: string }) => {
const expiresAt = new Date();
expiresAt.setDate(expiresAt.getDate() + INVITATION_EXPIRY_DAYS);
const [result] = await this.db
.insert(workspaceInvitations)
.values({
email: params.email,
expiresAt,
inviterId: this.userId,
role: params.role ?? 'member',
token: nanoid(32),
workspaceId: params.workspaceId,
})
.returning();
return result;
};
findInvitationByToken = async (token: string) => {
return this.db.query.workspaceInvitations.findFirst({
where: eq(workspaceInvitations.token, token),
});
};
listPendingInvitations = async (workspaceId: string) => {
return this.db.query.workspaceInvitations.findMany({
where: and(
eq(workspaceInvitations.workspaceId, workspaceId),
eq(workspaceInvitations.status, 'pending'),
),
});
};
revokeInvitation = async (id: string) => {
return this.db
.update(workspaceInvitations)
.set({ status: 'revoked' })
.where(eq(workspaceInvitations.id, id));
};
updateInvitationStatus = async (id: string, status: 'accepted' | 'expired' | 'revoked') => {
return this.db
.update(workspaceInvitations)
.set({ status })
.where(eq(workspaceInvitations.id, id));
};
}
@@ -3,9 +3,13 @@ import { BUILTIN_AGENT_SLUGS } from '@lobechat/builtin-agents';
import { beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../core/getTestDB';
import { ChatGroupModel } from '../../models/chatGroup';
import { agents } from '../../schemas/agent';
import { chatGroups, chatGroupsAgents } from '../../schemas/chatGroup';
import { messagePlugins, messages } from '../../schemas/message';
import { threads, topics } from '../../schemas/topic';
import { users } from '../../schemas/user';
import { workspaces } from '../../schemas/workspace';
import type { LobeChatDatabase } from '../../type';
import { AgentGroupRepository } from './index';
@@ -1258,4 +1262,593 @@ describe('AgentGroupRepository', () => {
);
});
});
describe('workspace scoping', () => {
const workspaceId = 'agent-group-test-ws';
beforeEach(async () => {
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Test Workspace',
primaryOwnerId: userId,
slug: 'agent-group-test-ws',
});
});
it('stamps workspaceId on the group, supervisor agent, and junction rows', async () => {
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
const result = await wsRepo.createGroupWithSupervisor({ title: 'WS Group' });
// group row carries the workspace id
expect(result.group.workspaceId).toBe(workspaceId);
// supervisor agent carries the workspace id
const supervisor = await serverDB.query.agents.findFirst({
where: (a, { eq }) => eq(a.id, result.supervisorAgentId),
});
expect(supervisor!.workspaceId).toBe(workspaceId);
// junction rows carry the workspace id
const junctions = await serverDB.query.chatGroupsAgents.findMany({
where: (cga, { eq }) => eq(cga.chatGroupId, result.group.id),
});
expect(junctions.every((j) => j.workspaceId === workspaceId)).toBe(true);
});
// Regression for "群组设定 system prompt won't save": a group created inside a
// workspace must be updatable through the workspace-scoped ChatGroupModel.
// Previously create wrote workspace_id = NULL, so the workspace-scoped UPDATE
// matched 0 rows and threw "not found or access denied".
it('allows the workspace-scoped ChatGroupModel to update a workspace-created group', async () => {
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
const { group } = await wsRepo.createGroupWithSupervisor({ title: 'WS Group' });
const chatGroupModel = new ChatGroupModel(serverDB, userId, workspaceId);
const updated = await chatGroupModel.update(group.id, {
config: { systemPrompt: 'You are a helpful team.' } as any,
});
expect(updated.config).toMatchObject({ systemPrompt: 'You are a helpful team.' });
});
it('isolates workspace groups from personal-mode reads', async () => {
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
const { group } = await wsRepo.createGroupWithSupervisor({ title: 'WS Group' });
// personal-mode repo (no workspaceId) must not see the workspace group
const personalRepo = new AgentGroupRepository(serverDB, userId);
expect(await personalRepo.findByIdWithAgents(group.id)).toBeNull();
// workspace repo sees it
expect(await wsRepo.findByIdWithAgents(group.id)).not.toBeNull();
});
it('keeps personal groups out of workspace-scoped reads', async () => {
const personalRepo = new AgentGroupRepository(serverDB, userId);
const { group } = await personalRepo.createGroupWithSupervisor({ title: 'Personal Group' });
expect(group.workspaceId).toBeNull();
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
expect(await wsRepo.findByIdWithAgents(group.id)).toBeNull();
});
it('transfers a workspace group with members and conversation data to the target scope', async () => {
const targetWorkspaceId = 'agent-group-target-ws';
await serverDB.insert(workspaces).values({
id: targetWorkspaceId,
name: 'Target Workspace',
primaryOwnerId: userId,
slug: 'agent-group-target-ws',
});
await serverDB.insert(chatGroups).values({
id: 'transfer-group',
title: 'Transfer Group',
userId,
workspaceId,
});
await serverDB.insert(agents).values([
{
id: 'transfer-supervisor',
title: 'Supervisor',
userId,
virtual: true,
workspaceId,
},
{
id: 'transfer-member',
title: 'Member',
userId,
virtual: false,
workspaceId,
},
]);
await serverDB.insert(chatGroupsAgents).values([
{
agentId: 'transfer-supervisor',
chatGroupId: 'transfer-group',
order: -1,
role: 'supervisor',
userId,
workspaceId,
},
{
agentId: 'transfer-member',
chatGroupId: 'transfer-group',
order: 0,
role: 'participant',
userId,
workspaceId,
},
]);
await serverDB.insert(topics).values({
groupId: 'transfer-group',
id: 'transfer-topic',
title: 'Group Topic',
userId,
workspaceId,
});
await serverDB.insert(threads).values({
agentId: 'transfer-member',
id: 'transfer-thread',
topicId: 'transfer-topic',
type: 'continuation',
userId,
workspaceId,
});
await serverDB.insert(messages).values({
content: 'hello',
groupId: 'transfer-group',
id: 'transfer-message',
role: 'user',
topicId: 'transfer-topic',
userId,
workspaceId,
});
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
const result = await wsRepo.transferToWorkspace('transfer-group', targetWorkspaceId, userId);
expect(result).toEqual({ groupId: 'transfer-group' });
const group = await serverDB.query.chatGroups.findFirst({
where: (cg, { eq }) => eq(cg.id, 'transfer-group'),
});
expect(group!.workspaceId).toBe(targetWorkspaceId);
const memberAgents = await serverDB.query.agents.findMany({
where: (a, { inArray }) => inArray(a.id, ['transfer-supervisor', 'transfer-member']),
});
expect(memberAgents.every((agent) => agent.workspaceId === targetWorkspaceId)).toBe(true);
const junctions = await serverDB.query.chatGroupsAgents.findMany({
where: (cga, { eq }) => eq(cga.chatGroupId, 'transfer-group'),
});
expect(junctions.every((junction) => junction.workspaceId === targetWorkspaceId)).toBe(true);
const topic = await serverDB.query.topics.findFirst({
where: (t, { eq }) => eq(t.id, 'transfer-topic'),
});
const thread = await serverDB.query.threads.findFirst({
where: (t, { eq }) => eq(t.id, 'transfer-thread'),
});
const message = await serverDB.query.messages.findFirst({
where: (m, { eq }) => eq(m.id, 'transfer-message'),
});
expect(topic!.workspaceId).toBe(targetWorkspaceId);
expect(thread!.workspaceId).toBe(targetWorkspaceId);
expect(message!.workspaceId).toBe(targetWorkspaceId);
});
it('copies a workspace group and all members into the target scope', async () => {
const targetWorkspaceId = 'agent-group-copy-target-ws';
await serverDB.insert(workspaces).values({
id: targetWorkspaceId,
name: 'Copy Target Workspace',
primaryOwnerId: userId,
slug: 'agent-group-copy-target-ws',
});
await serverDB.insert(chatGroups).values({
avatar: 'group-avatar',
id: 'copy-group',
title: 'Copy Group',
userId,
workspaceId,
});
await serverDB.insert(agents).values([
{
id: 'copy-supervisor',
model: 'gpt-4o',
provider: 'openai',
title: 'Supervisor',
userId,
virtual: true,
workspaceId,
},
{
id: 'copy-member',
model: 'claude-3',
provider: 'anthropic',
title: 'Member',
userId,
virtual: false,
workspaceId,
},
]);
await serverDB.insert(chatGroupsAgents).values([
{
agentId: 'copy-supervisor',
chatGroupId: 'copy-group',
order: -1,
role: 'supervisor',
userId,
workspaceId,
},
{
agentId: 'copy-member',
chatGroupId: 'copy-group',
order: 0,
role: 'participant',
userId,
workspaceId,
},
]);
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
const result = await wsRepo.copyToWorkspace('copy-group', targetWorkspaceId, userId);
expect(result).not.toBeNull();
expect(result!.groupId).not.toBe('copy-group');
expect(result!.supervisorAgentId).not.toBe('copy-supervisor');
const copiedGroup = await serverDB.query.chatGroups.findFirst({
where: (cg, { eq }) => eq(cg.id, result!.groupId),
});
expect(copiedGroup).toEqual(
expect.objectContaining({
avatar: 'group-avatar',
title: 'Copy Group (Copy)',
userId,
workspaceId: targetWorkspaceId,
}),
);
const copiedJunctions = await serverDB.query.chatGroupsAgents.findMany({
where: (cga, { eq }) => eq(cga.chatGroupId, result!.groupId),
});
expect(copiedJunctions).toHaveLength(2);
expect(copiedJunctions.every((junction) => junction.workspaceId === targetWorkspaceId)).toBe(
true,
);
expect(copiedJunctions.some((junction) => junction.agentId === 'copy-member')).toBe(false);
const copiedAgentIds = copiedJunctions.map((junction) => junction.agentId);
const copiedAgents = await serverDB.query.agents.findMany({
where: (a, { inArray }) => inArray(a.id, copiedAgentIds),
});
expect(copiedAgents.every((agent) => agent.workspaceId === targetWorkspaceId)).toBe(true);
expect(copiedAgents.map((agent) => agent.title).sort()).toEqual(['Member', 'Supervisor']);
});
it('copies group topics and messages when conversation history is selected', async () => {
const targetWorkspaceId = 'agent-group-copy-history-target-ws';
await serverDB.insert(workspaces).values({
id: targetWorkspaceId,
name: 'Copy History Target Workspace',
primaryOwnerId: userId,
slug: 'agent-group-copy-history-target-ws',
});
await serverDB.insert(chatGroups).values({
id: 'copy-history-group',
title: 'Copy History Group',
userId,
workspaceId,
});
await serverDB.insert(agents).values([
{
id: 'copy-history-supervisor',
model: 'gpt-4o',
provider: 'openai',
title: 'Supervisor',
userId,
virtual: true,
workspaceId,
},
{
id: 'copy-history-member',
model: 'claude-3',
provider: 'anthropic',
title: 'Member',
userId,
virtual: false,
workspaceId,
},
]);
await serverDB.insert(chatGroupsAgents).values([
{
agentId: 'copy-history-supervisor',
chatGroupId: 'copy-history-group',
order: -1,
role: 'supervisor',
userId,
workspaceId,
},
{
agentId: 'copy-history-member',
chatGroupId: 'copy-history-group',
order: 0,
role: 'participant',
userId,
workspaceId,
},
]);
await serverDB.insert(topics).values({
groupId: 'copy-history-group',
id: 'copy-history-topic',
title: 'Group topic',
userId,
workspaceId,
});
await serverDB.insert(threads).values({
agentId: 'copy-history-member',
groupId: 'copy-history-group',
id: 'copy-history-thread',
sourceMessageId: 'copy-history-message-user',
topicId: 'copy-history-topic',
type: 'standalone',
userId,
workspaceId,
});
await serverDB.insert(messages).values([
{
content: 'Hello group',
groupId: 'copy-history-group',
id: 'copy-history-message-user',
role: 'user',
targetId: 'copy-history-member',
topicId: 'copy-history-topic',
userId,
workspaceId,
},
{
agentId: 'copy-history-member',
content: 'Hello user',
groupId: 'copy-history-group',
id: 'copy-history-message-assistant',
parentId: 'copy-history-message-user',
role: 'assistant',
threadId: 'copy-history-thread',
tools: [{ id: 'toolu_old', type: 'builtin' }],
topicId: 'copy-history-topic',
userId,
workspaceId,
},
]);
await serverDB.insert(messagePlugins).values({
apiName: 'search',
arguments: '{}',
id: 'copy-history-message-assistant',
toolCallId: 'toolu_old',
userId,
workspaceId,
});
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
const result = await wsRepo.copyToWorkspace('copy-history-group', targetWorkspaceId, userId, {
includeConversationHistory: true,
});
expect(result).not.toBeNull();
const copiedJunctions = await serverDB.query.chatGroupsAgents.findMany({
where: (cga, { eq }) => eq(cga.chatGroupId, result!.groupId),
});
const copiedMember = copiedJunctions.find((junction) => junction.role === 'participant');
expect(copiedMember?.agentId).toBeDefined();
expect(copiedMember?.agentId).not.toBe('copy-history-member');
const copiedTopics = await serverDB.query.topics.findMany({
where: (topic, { eq }) => eq(topic.groupId, result!.groupId),
});
expect(copiedTopics).toHaveLength(1);
expect(copiedTopics[0]).toEqual(
expect.objectContaining({
clientId: null,
sessionId: null,
title: 'Group topic',
userId,
workspaceId: targetWorkspaceId,
}),
);
const copiedMessages = await serverDB.query.messages.findMany({
where: (message, { eq }) => eq(message.groupId, result!.groupId),
});
expect(copiedMessages).toHaveLength(2);
expect(copiedMessages.some((message) => message.id === 'copy-history-message-user')).toBe(
false,
);
const copiedAssistantMessage = copiedMessages.find((message) => message.role === 'assistant');
const copiedUserMessage = copiedMessages.find((message) => message.role === 'user');
expect(copiedUserMessage?.targetId).toBe(copiedMember!.agentId);
expect(copiedAssistantMessage).toEqual(
expect.objectContaining({
agentId: copiedMember!.agentId,
clientId: null,
targetId: null,
userId,
workspaceId: targetWorkspaceId,
}),
);
expect(copiedAssistantMessage?.tools).not.toEqual([{ id: 'toolu_old', type: 'builtin' }]);
const copiedPlugin = await serverDB.query.messagePlugins.findFirst({
where: (plugin, { eq }) => eq(plugin.id, copiedAssistantMessage!.id),
});
expect(copiedPlugin?.toolCallId).not.toBe('toolu_old');
expect(copiedPlugin?.workspaceId).toBe(targetWorkspaceId);
});
it('removes workspace virtual agents created by another member', async () => {
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
await serverDB.insert(chatGroups).values({
id: 'remove-cross-member-group',
title: 'Remove Cross Member Group',
userId,
workspaceId,
});
await serverDB.insert(agents).values({
id: 'remove-cross-member-virtual',
title: 'Virtual From Other Member',
userId: otherUserId,
virtual: true,
workspaceId,
});
await serverDB.insert(chatGroupsAgents).values({
agentId: 'remove-cross-member-virtual',
chatGroupId: 'remove-cross-member-group',
role: 'participant',
userId,
workspaceId,
});
const result = await wsRepo.removeAgentsFromGroup('remove-cross-member-group', [
'remove-cross-member-virtual',
]);
expect(result).toEqual({
deletedVirtualAgentIds: ['remove-cross-member-virtual'],
removedFromGroup: 1,
});
const relation = await serverDB.query.chatGroupsAgents.findFirst({
where: (cga, { eq }) => eq(cga.agentId, 'remove-cross-member-virtual'),
});
expect(relation).toBeUndefined();
const deletedAgent = await serverDB.query.agents.findFirst({
where: (agent, { eq }) => eq(agent.id, 'remove-cross-member-virtual'),
});
expect(deletedAgent).toBeUndefined();
});
it('copies workspace group history created by another member', async () => {
const targetWorkspaceId = 'agent-group-copy-member-history-target-ws';
await serverDB.insert(workspaces).values({
id: targetWorkspaceId,
name: 'Copy Member History Target Workspace',
primaryOwnerId: userId,
slug: 'agent-group-copy-member-history-target-ws',
});
await serverDB.insert(chatGroups).values({
id: 'copy-member-history-group',
title: 'Copy Member History Group',
userId,
workspaceId,
});
await serverDB.insert(agents).values([
{
id: 'copy-member-history-supervisor',
title: 'Supervisor',
userId,
virtual: true,
workspaceId,
},
{
id: 'copy-member-history-agent',
title: 'Member Agent',
userId,
virtual: false,
workspaceId,
},
]);
await serverDB.insert(chatGroupsAgents).values([
{
agentId: 'copy-member-history-supervisor',
chatGroupId: 'copy-member-history-group',
order: -1,
role: 'supervisor',
userId,
workspaceId,
},
{
agentId: 'copy-member-history-agent',
chatGroupId: 'copy-member-history-group',
order: 0,
role: 'participant',
userId,
workspaceId,
},
]);
await serverDB.insert(topics).values({
groupId: 'copy-member-history-group',
id: 'copy-member-history-topic',
title: 'Topic From Other Member',
userId: otherUserId,
workspaceId,
});
await serverDB.insert(threads).values({
agentId: 'copy-member-history-agent',
groupId: 'copy-member-history-group',
id: 'copy-member-history-thread',
topicId: 'copy-member-history-topic',
type: 'standalone',
userId: otherUserId,
workspaceId,
});
await serverDB.insert(messages).values({
agentId: 'copy-member-history-agent',
content: 'created by another workspace member',
groupId: 'copy-member-history-group',
id: 'copy-member-history-message',
role: 'assistant',
threadId: 'copy-member-history-thread',
topicId: 'copy-member-history-topic',
userId: otherUserId,
workspaceId,
});
const wsRepo = new AgentGroupRepository(serverDB, userId, workspaceId);
const result = await wsRepo.copyToWorkspace(
'copy-member-history-group',
targetWorkspaceId,
userId,
{ includeConversationHistory: true },
);
expect(result).not.toBeNull();
const copiedTopics = await serverDB.query.topics.findMany({
where: (topic, { eq }) => eq(topic.groupId, result!.groupId),
});
expect(copiedTopics).toHaveLength(1);
expect(copiedTopics[0]).toEqual(
expect.objectContaining({
title: 'Topic From Other Member',
userId,
workspaceId: targetWorkspaceId,
}),
);
const copiedMessages = await serverDB.query.messages.findMany({
where: (message, { eq }) => eq(message.groupId, result!.groupId),
});
expect(copiedMessages).toHaveLength(1);
expect(copiedMessages[0]).toEqual(
expect.objectContaining({
content: 'created by another workspace member',
userId,
workspaceId: targetWorkspaceId,
}),
);
});
});
});
@@ -1,11 +1,32 @@
import { BUILTIN_AGENT_SLUGS } from '@lobechat/builtin-agents';
import type { AgentGroupDetail, AgentGroupMember } from '@lobechat/types';
import { cleanObject } from '@lobechat/utils';
import { and, eq, inArray } from 'drizzle-orm';
import { and, eq, inArray, not } from 'drizzle-orm';
import type { AgentItem, ChatGroupItem, NewChatGroup, NewChatGroupAgent } from '../../schemas';
import { agents, chatGroups, chatGroupsAgents } from '../../schemas';
import type {
AgentItem,
ChatGroupItem,
NewAgent,
NewChatGroup,
NewChatGroupAgent,
} from '../../schemas';
import {
agents,
chatGroups,
chatGroupsAgents,
messagePlugins,
messages,
threads,
topics,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { idGenerator } from '../../utils/idGenerator';
import { buildWorkspaceWhere } from '../../utils/workspace';
interface CopyAgentGroupToWorkspaceOptions {
includeConversationHistory?: boolean;
newTitle?: string;
}
export interface SupervisorAgentConfig {
avatar?: string;
@@ -53,12 +74,230 @@ export interface CreateGroupWithSupervisorResult {
export class AgentGroupRepository {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
/**
* Workspace-aware ownership predicate for the `chat_groups` table. In personal
* mode (`workspaceId` absent) matches `user_id = ? AND workspace_id IS NULL`;
* in team mode matches `workspace_id = ?` (shared with all members).
*/
private groupOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, chatGroups);
private agentOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, agents);
private topicOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, topics);
private threadOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, threads);
private messageOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages);
private messagePluginOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messagePlugins);
private buildCopiedAgent = (
source: AgentItem | undefined,
targetWorkspaceId: string | null,
targetUserId: string,
fallbackTitle: string,
): NewAgent => ({
agencyConfig: source?.agencyConfig,
avatar: source?.avatar,
backgroundColor: source?.backgroundColor,
chatConfig: source?.chatConfig,
description: source?.description,
editorData: source?.editorData,
fewShots: source?.fewShots,
model: source?.model,
openingMessage: source?.openingMessage,
openingQuestions: source?.openingQuestions,
params: source?.params,
pinned: source?.pinned,
plugins: source?.plugins,
provider: source?.provider,
systemRole: source?.systemRole,
tags: source?.tags,
title: source?.title || fallbackTitle,
tts: source?.tts,
userId: targetUserId,
virtual: source?.virtual ?? true,
workspaceId: targetWorkspaceId,
});
private remapToolIds = (tools: unknown, toolIdMap: Map<string, string>) => {
if (!Array.isArray(tools)) return tools;
return tools.map((tool) => {
if (!tool || typeof tool !== 'object') return tool;
const toolRecord = tool as Record<PropertyKey, unknown>;
const toolId = toolRecord.id;
if (typeof toolId !== 'string') return tool;
return {
...toolRecord,
id: toolIdMap.get(toolId) ?? toolId,
};
});
};
private copyGroupConversationHistory = async ({
agentIdMap,
executor,
newGroupId,
sourceGroupId,
targetUserId,
targetWorkspaceId,
}: {
agentIdMap: Map<string, string>;
executor: LobeChatDatabase;
newGroupId: string;
sourceGroupId: string;
targetUserId: string;
targetWorkspaceId: string | null;
}) => {
const mapAgentId = (agentId?: null | string) =>
agentId ? (agentIdMap.get(agentId) ?? null) : null;
const mapTargetId = (targetId?: null | string) => {
if (!targetId || targetId === 'user') return targetId ?? null;
return agentIdMap.get(targetId) ?? null;
};
const sourceTopics = await executor.query.topics.findMany({
orderBy: (topic, { asc }) => [asc(topic.createdAt)],
where: and(this.topicOwnership(), eq(topics.groupId, sourceGroupId)),
});
if (sourceTopics.length === 0) return;
const sourceTopicIds = sourceTopics.map((topic) => topic.id);
const topicIdMap = new Map(sourceTopics.map((topic) => [topic.id, idGenerator('topics')]));
const sourceThreads = await executor.query.threads.findMany({
orderBy: (thread, { asc }) => [asc(thread.createdAt)],
where: and(this.threadOwnership(), inArray(threads.topicId, sourceTopicIds)),
});
const threadIdMap = new Map(
sourceThreads.map((thread) => [thread.id, idGenerator('threads', 16)]),
);
const sourceMessages = await executor.query.messages.findMany({
orderBy: (message, { asc }) => [asc(message.createdAt)],
where: and(this.messageOwnership(), inArray(messages.topicId, sourceTopicIds)),
});
const messageIdMap = new Map(
sourceMessages.map((message) => [message.id, idGenerator('messages')]),
);
const toolIdMap = new Map<string, string>();
for (const message of sourceMessages) {
if (!Array.isArray(message.tools)) continue;
for (const tool of message.tools) {
if (!tool || typeof tool !== 'object') continue;
const toolId = (tool as Record<PropertyKey, unknown>).id;
if (typeof toolId === 'string') {
toolIdMap.set(toolId, `toolu_${idGenerator('messages')}`);
}
}
}
await executor.insert(topics).values(
sourceTopics.map((topic) => ({
...topic,
agentId: mapAgentId(topic.agentId),
clientId: null,
groupId: newGroupId,
id: topicIdMap.get(topic.id),
sessionId: null,
userId: targetUserId,
workspaceId: targetWorkspaceId,
})),
);
if (sourceThreads.length > 0) {
await executor.insert(threads).values(
sourceThreads.map((thread) => ({
...thread,
agentId: mapAgentId(thread.agentId),
clientId: null,
groupId: newGroupId,
id: threadIdMap.get(thread.id),
parentThreadId: thread.parentThreadId
? (threadIdMap.get(thread.parentThreadId) ?? null)
: null,
sourceMessageId: thread.sourceMessageId
? (messageIdMap.get(thread.sourceMessageId) ?? null)
: null,
topicId: topicIdMap.get(thread.topicId),
userId: targetUserId,
workspaceId: targetWorkspaceId,
})),
);
}
if (sourceMessages.length === 0) return;
const sourceMessageIds = sourceMessages.map((message) => message.id);
const sourcePlugins = await executor.query.messagePlugins.findMany({
where: and(this.messagePluginOwnership(), inArray(messagePlugins.id, sourceMessageIds)),
});
const messageRows = sourceMessages.map((message) => {
const newMessageId = messageIdMap.get(message.id)!;
const newTopicId = message.topicId ? (topicIdMap.get(message.topicId) ?? null) : null;
return {
...message,
agentId: mapAgentId(message.agentId),
clientId: null,
groupId: newGroupId,
id: newMessageId,
messageGroupId: null,
parentId: message.parentId ? (messageIdMap.get(message.parentId) ?? null) : null,
quotaId: message.quotaId ? (messageIdMap.get(message.quotaId) ?? null) : null,
sessionId: null,
targetId: mapTargetId(message.targetId),
threadId: message.threadId ? (threadIdMap.get(message.threadId) ?? null) : null,
tools: this.remapToolIds(message.tools, toolIdMap),
topicId: newTopicId,
userId: targetUserId,
workspaceId: targetWorkspaceId,
};
});
await executor.insert(messages).values(messageRows);
if (sourcePlugins.length > 0) {
await executor.insert(messagePlugins).values(
sourcePlugins
.map((plugin) => {
const newMessageId = messageIdMap.get(plugin.id);
if (!newMessageId) return;
return {
...plugin,
clientId: null,
id: newMessageId,
toolCallId: plugin.toolCallId ? (toolIdMap.get(plugin.toolCallId) ?? null) : null,
userId: targetUserId,
workspaceId: targetWorkspaceId,
};
})
.filter((plugin) => !!plugin),
);
}
};
/**
* Find a chat group by ID with its associated agents.
* If no supervisor exists, a virtual supervisor agent is automatically created.
@@ -68,7 +307,7 @@ export class AgentGroupRepository {
async findByIdWithAgents(groupId: string): Promise<AgentGroupDetail | null> {
// 1. Find the group
const group = await this.db.query.chatGroups.findFirst({
where: and(eq(chatGroups.id, groupId), eq(chatGroups.userId, this.userId)),
where: and(eq(chatGroups.id, groupId), this.groupOwnership()),
});
if (!group) return null;
@@ -115,6 +354,7 @@ export class AgentGroupRepository {
title: 'Supervisor',
userId: this.userId,
virtual: true,
workspaceId: this.workspaceId ?? null,
})
.returning();
@@ -125,6 +365,7 @@ export class AgentGroupRepository {
order: -1, // Supervisor always first (negative order)
role: 'supervisor',
userId: this.userId,
workspaceId: this.workspaceId ?? null,
});
supervisorAgentId = supervisorAgent.id;
@@ -178,13 +419,14 @@ export class AgentGroupRepository {
title: supervisorConfig?.title ?? 'Supervisor',
userId: this.userId,
virtual: true,
workspaceId: this.workspaceId ?? null,
})
.returning();
// 2. Create the group
const [group] = await this.db
.insert(chatGroups)
.values({ ...groupParams, userId: this.userId })
.values({ ...groupParams, userId: this.userId, workspaceId: this.workspaceId ?? null })
.returning();
// 3. Add supervisor agent to group with role 'supervisor'
@@ -194,6 +436,7 @@ export class AgentGroupRepository {
order: -1, // Supervisor always first (negative order)
role: 'supervisor',
userId: this.userId,
workspaceId: this.workspaceId ?? null,
};
// 4. Add member agents to group with role 'participant'
@@ -203,6 +446,7 @@ export class AgentGroupRepository {
order: index,
role: 'participant',
userId: this.userId,
workspaceId: this.workspaceId ?? null,
}));
// 5. Insert all group-agent relationships
@@ -245,7 +489,7 @@ export class AgentGroupRepository {
virtual: agents.virtual,
})
.from(agents)
.where(and(eq(agents.userId, this.userId), inArray(agents.id, agentIds)));
.where(and(this.agentOwnership(), inArray(agents.id, agentIds)));
const virtualAgents: RemoveAgentsCheckResult['virtualAgents'] = [];
const nonVirtualAgentIds: string[] = [];
@@ -300,7 +544,7 @@ export class AgentGroupRepository {
if (deleteVirtualAgents && virtualAgentIds.length > 0) {
await this.db
.delete(agents)
.where(and(eq(agents.userId, this.userId), inArray(agents.id, virtualAgentIds)));
.where(and(this.agentOwnership(), inArray(agents.id, virtualAgentIds)));
}
return {
@@ -326,7 +570,7 @@ export class AgentGroupRepository {
): Promise<{ groupId: string; supervisorAgentId: string } | null> {
// 1. Get the source group
const sourceGroup = await this.db.query.chatGroups.findFirst({
where: and(eq(chatGroups.id, groupId), eq(chatGroups.userId, this.userId)),
where: and(eq(chatGroups.id, groupId), this.groupOwnership()),
});
if (!sourceGroup) return null;
@@ -374,6 +618,7 @@ export class AgentGroupRepository {
pinned: sourceGroup.pinned,
title: newTitle || (sourceGroup.title ? `${sourceGroup.title} (Copy)` : 'Copy'),
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})
.returning();
@@ -393,6 +638,7 @@ export class AgentGroupRepository {
title: supervisorAgent?.title || 'Supervisor',
userId: this.userId,
virtual: true,
workspaceId: this.workspaceId ?? null,
})
.returning();
@@ -421,6 +667,7 @@ export class AgentGroupRepository {
// User & virtual flag
userId: this.userId,
virtual: true,
workspaceId: this.workspaceId ?? null,
}));
const newVirtualAgents = await trx.insert(agents).values(virtualAgentConfigs).returning();
@@ -440,6 +687,7 @@ export class AgentGroupRepository {
order: -1,
role: 'supervisor',
userId: this.userId,
workspaceId: this.workspaceId ?? null,
},
// Virtual members (using new copied agents)
...virtualMembers.map((member) => ({
@@ -449,6 +697,7 @@ export class AgentGroupRepository {
order: member.order,
role: member.role || 'participant',
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
// Non-virtual members (referencing same agents - only add relationship)
...nonVirtualMembers.map((member) => ({
@@ -458,6 +707,7 @@ export class AgentGroupRepository {
order: member.order,
role: member.role || 'participant',
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
];
@@ -469,4 +719,202 @@ export class AgentGroupRepository {
};
});
}
async transferToWorkspace(
groupId: string,
targetWorkspaceId: string | null,
targetUserId: string,
): Promise<{ groupId: string } | null> {
const sourceGroup = await this.db.query.chatGroups.findFirst({
where: and(eq(chatGroups.id, groupId), this.groupOwnership()),
});
if (!sourceGroup) return null;
return this.db.transaction(async (trx) => {
const memberRows = await trx
.select({ agentId: chatGroupsAgents.agentId })
.from(chatGroupsAgents)
.where(eq(chatGroupsAgents.chatGroupId, groupId));
const agentIds = memberRows.map((row) => row.agentId);
const ownershipUpdate = {
userId: targetUserId,
workspaceId: targetWorkspaceId,
};
await trx
.update(chatGroups)
.set({ ...ownershipUpdate, updatedAt: new Date() })
.where(eq(chatGroups.id, groupId));
await trx
.update(chatGroupsAgents)
.set(ownershipUpdate)
.where(eq(chatGroupsAgents.chatGroupId, groupId));
if (agentIds.length > 0) {
await trx
.delete(chatGroupsAgents)
.where(
and(
inArray(chatGroupsAgents.agentId, agentIds),
not(eq(chatGroupsAgents.chatGroupId, groupId)),
),
);
await trx
.update(agents)
.set({ ...ownershipUpdate, updatedAt: new Date() })
.where(inArray(agents.id, agentIds));
}
const groupTopics = await trx
.select({ id: topics.id })
.from(topics)
.where(eq(topics.groupId, groupId));
const groupTopicIds = groupTopics.map((topic) => topic.id);
await trx.update(topics).set(ownershipUpdate).where(eq(topics.groupId, groupId));
await trx.update(threads).set(ownershipUpdate).where(eq(threads.groupId, groupId));
await trx.update(messages).set(ownershipUpdate).where(eq(messages.groupId, groupId));
if (groupTopicIds.length > 0) {
await trx
.update(threads)
.set(ownershipUpdate)
.where(inArray(threads.topicId, groupTopicIds));
await trx
.update(messages)
.set(ownershipUpdate)
.where(inArray(messages.topicId, groupTopicIds));
}
return { groupId };
});
}
async copyToWorkspace(
groupId: string,
targetWorkspaceId: string | null,
targetUserId: string,
optionsOrNewTitle?: CopyAgentGroupToWorkspaceOptions | string,
): Promise<{ groupId: string; supervisorAgentId: string } | null> {
const options =
typeof optionsOrNewTitle === 'string'
? { newTitle: optionsOrNewTitle }
: (optionsOrNewTitle ?? {});
const sourceGroup = await this.db.query.chatGroups.findFirst({
where: and(eq(chatGroups.id, groupId), this.groupOwnership()),
});
if (!sourceGroup) return null;
const groupAgentsWithDetails = await this.db
.select({
agent: agents,
enabled: chatGroupsAgents.enabled,
order: chatGroupsAgents.order,
role: chatGroupsAgents.role,
})
.from(chatGroupsAgents)
.innerJoin(agents, eq(chatGroupsAgents.agentId, agents.id))
.where(eq(chatGroupsAgents.chatGroupId, groupId))
.orderBy(chatGroupsAgents.order);
const sourceSupervisor = groupAgentsWithDetails.find((row) => row.role === 'supervisor');
const sourceMembers = groupAgentsWithDetails.filter((row) => row.role !== 'supervisor');
return this.db.transaction(async (trx) => {
const [newGroup] = await trx
.insert(chatGroups)
.values({
avatar: sourceGroup.avatar,
backgroundColor: sourceGroup.backgroundColor,
config: sourceGroup.config,
content: sourceGroup.content,
description: sourceGroup.description,
editorData: sourceGroup.editorData,
pinned: sourceGroup.pinned,
title: options.newTitle || (sourceGroup.title ? `${sourceGroup.title} (Copy)` : 'Copy'),
userId: targetUserId,
workspaceId: targetWorkspaceId,
})
.returning();
const [newSupervisor] = await trx
.insert(agents)
.values(
this.buildCopiedAgent(
sourceSupervisor?.agent,
targetWorkspaceId,
targetUserId,
'Supervisor',
),
)
.returning();
const memberAgentIdMap = new Map<string, string>();
if (sourceMembers.length > 0) {
const newMembers = await trx
.insert(agents)
.values(
sourceMembers.map((member) =>
this.buildCopiedAgent(member.agent, targetWorkspaceId, targetUserId, 'Agent'),
),
)
.returning({ id: agents.id });
for (const [index, member] of sourceMembers.entries()) {
memberAgentIdMap.set(member.agent.id, newMembers[index].id);
}
}
const groupAgentValues: NewChatGroupAgent[] = [
{
agentId: newSupervisor.id,
chatGroupId: newGroup.id,
order: -1,
role: 'supervisor',
userId: targetUserId,
workspaceId: targetWorkspaceId,
},
...sourceMembers.map((member) => ({
agentId: memberAgentIdMap.get(member.agent.id)!,
chatGroupId: newGroup.id,
enabled: member.enabled,
order: member.order,
role: member.role || 'participant',
userId: targetUserId,
workspaceId: targetWorkspaceId,
})),
];
await trx.insert(chatGroupsAgents).values(groupAgentValues);
const agentIdMap = new Map<string, string>();
if (sourceSupervisor?.agent.id) {
agentIdMap.set(sourceSupervisor.agent.id, newSupervisor.id);
}
for (const [sourceAgentId, newAgentId] of memberAgentIdMap) {
agentIdMap.set(sourceAgentId, newAgentId);
}
if (options.includeConversationHistory) {
await this.copyGroupConversationHistory({
agentIdMap,
executor: trx,
newGroupId: newGroup.id,
sourceGroupId: groupId,
targetUserId,
targetWorkspaceId,
});
}
return {
groupId: newGroup.id,
supervisorAgentId: newSupervisor.id,
};
});
}
}
@@ -1,7 +1,9 @@
import { and, eq, inArray, isNotNull, isNull } from 'drizzle-orm';
import type { AnyPgColumn } from 'drizzle-orm/pg-core';
import { agents, agentsToSessions, messages, sessions, topics } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
type MigrateBySessionParams = { agentId: string; sessionId: string };
type MigrateInboxParams = { agentId: string; isInbox: true; sessionId?: string | null };
@@ -16,12 +18,17 @@ type MigrateAgentIdParams = MigrateBySessionParams | MigrateInboxParams;
export class AgentMigrationRepo {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private ws = (cols: { userId: AnyPgColumn; workspaceId: AnyPgColumn }) =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, cols);
/**
* Runtime migration: backfill agentId for all legacy topics and messages
* Used for progressive migration so future queries don't need agentsToSessions lookup
@@ -57,7 +64,7 @@ export class AgentMigrationRepo {
.from(topics)
.where(
and(
eq(topics.userId, this.userId),
this.ws(topics),
isNull(topics.sessionId),
isNull(topics.groupId),
isNull(topics.agentId),
@@ -74,7 +81,7 @@ export class AgentMigrationRepo {
.set({ agentId, updatedAt: topics.updatedAt })
.where(
and(
eq(topics.userId, this.userId),
this.ws(topics),
isNull(topics.sessionId),
isNull(topics.groupId),
isNull(topics.agentId),
@@ -85,13 +92,7 @@ export class AgentMigrationRepo {
await tx
.update(messages)
.set({ agentId, updatedAt: messages.updatedAt })
.where(
and(
eq(messages.userId, this.userId),
inArray(messages.topicId, topicIds),
isNull(messages.agentId),
),
);
.where(and(this.ws(messages), inArray(messages.topicId, topicIds), isNull(messages.agentId)));
// 4. Also update messages without topicId but in inbox (sessionId IS NULL) - preserve original updatedAt
await tx
@@ -99,7 +100,7 @@ export class AgentMigrationRepo {
.set({ agentId, updatedAt: messages.updatedAt })
.where(
and(
eq(messages.userId, this.userId),
this.ws(messages),
isNull(messages.sessionId),
isNull(messages.topicId),
isNull(messages.agentId),
@@ -118,13 +119,7 @@ export class AgentMigrationRepo {
const legacyTopics = await tx
.select({ id: topics.id })
.from(topics)
.where(
and(
eq(topics.userId, this.userId),
eq(topics.sessionId, sessionId),
isNull(topics.agentId),
),
);
.where(and(this.ws(topics), eq(topics.sessionId, sessionId), isNull(topics.agentId)));
const topicIds = legacyTopics.map((t) => t.id);
@@ -132,13 +127,7 @@ export class AgentMigrationRepo {
await tx
.update(topics)
.set({ agentId, updatedAt: topics.updatedAt })
.where(
and(
eq(topics.userId, this.userId),
eq(topics.sessionId, sessionId),
isNull(topics.agentId),
),
);
.where(and(this.ws(topics), eq(topics.sessionId, sessionId), isNull(topics.agentId)));
// 3. Update associated messages within these topics
if (topicIds.length > 0) {
@@ -146,11 +135,7 @@ export class AgentMigrationRepo {
.update(messages)
.set({ agentId, updatedAt: messages.updatedAt })
.where(
and(
eq(messages.userId, this.userId),
inArray(messages.topicId, topicIds),
isNull(messages.agentId),
),
and(this.ws(messages), inArray(messages.topicId, topicIds), isNull(messages.agentId)),
);
}
@@ -160,7 +145,7 @@ export class AgentMigrationRepo {
.set({ agentId, updatedAt: messages.updatedAt })
.where(
and(
eq(messages.userId, this.userId),
this.ws(messages),
eq(messages.sessionId, sessionId),
isNull(messages.topicId),
isNull(messages.agentId),
@@ -175,7 +160,7 @@ export class AgentMigrationRepo {
const result = await this.db
.select({ sessionId: agentsToSessions.sessionId })
.from(agentsToSessions)
.where(and(eq(agentsToSessions.agentId, agentId), eq(agentsToSessions.userId, this.userId)))
.where(and(eq(agentsToSessions.agentId, agentId), this.ws(agentsToSessions)))
.limit(1);
return result[0]?.sessionId ?? null;
@@ -202,13 +187,7 @@ export class AgentMigrationRepo {
.from(agents)
.innerJoin(agentsToSessions, eq(agents.id, agentsToSessions.agentId))
.innerJoin(sessions, eq(agentsToSessions.sessionId, sessions.id))
.where(
and(
eq(agents.userId, this.userId),
isNull(agents.sessionGroupId),
isNotNull(sessions.groupId),
),
);
.where(and(this.ws(agents), isNull(agents.sessionGroupId), isNotNull(sessions.groupId)));
if (agentsToMigrate.length === 0) return;
@@ -220,7 +199,7 @@ export class AgentMigrationRepo {
await this.db
.update(agents)
.set({ sessionGroupId: item.sessionGroupId, updatedAt: agents.updatedAt })
.where(and(eq(agents.id, item.agentId), eq(agents.userId, this.userId)));
.where(and(eq(agents.id, item.agentId), this.ws(agents)));
}
};
}
@@ -5,6 +5,7 @@ import { and, eq, inArray, isNull } from 'drizzle-orm';
import type { MessageGroupItem } from '../../schemas';
import { messageGroups, messages } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export interface CreateCompressionGroupParams {
content: string;
@@ -31,12 +32,20 @@ export interface CompressionGroupResult {
export class CompressionRepository {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.workspaceId = workspaceId;
}
private groupsOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messageGroups);
private messagesOwnership = () =>
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, messages);
/**
* Create a compression group and mark messages as compressed
*/
@@ -56,6 +65,7 @@ export class CompressionRepository {
topicId,
type: MessageGroupType.Compression,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})
.returning()) as MessageGroupItem[];
@@ -78,7 +88,7 @@ export class CompressionRepository {
.from(messageGroups)
.where(
and(
eq(messageGroups.userId, this.userId),
this.groupsOwnership(),
eq(messageGroups.topicId, topicId),
eq(messageGroups.type, MessageGroupType.Compression),
),
@@ -118,7 +128,7 @@ export class CompressionRepository {
const existing = await this.db
.select({ description: messageGroups.description })
.from(messageGroups)
.where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId)));
.where(and(eq(messageGroups.id, groupId), this.groupsOwnership()));
const existingMetadata = existing[0]?.description ? JSON.parse(existing[0].description) : {};
updateData.description = JSON.stringify({ ...existingMetadata, ...metadata });
@@ -127,7 +137,7 @@ export class CompressionRepository {
await this.db
.update(messageGroups)
.set(updateData)
.where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId)));
.where(and(eq(messageGroups.id, groupId), this.groupsOwnership()));
}
/**
@@ -141,7 +151,7 @@ export class CompressionRepository {
const existing = await this.db
.select({ metadata: messageGroups.metadata })
.from(messageGroups)
.where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId)));
.where(and(eq(messageGroups.id, groupId), this.groupsOwnership()));
const existingData = (existing[0]?.metadata as Record<string, unknown>) || {};
const newMetadata = { ...existingData, ...metadata };
@@ -149,7 +159,7 @@ export class CompressionRepository {
await this.db
.update(messageGroups)
.set({ metadata: newMetadata, updatedAt: new Date() })
.where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId)));
.where(and(eq(messageGroups.id, groupId), this.groupsOwnership()));
}
/**
@@ -161,7 +171,7 @@ export class CompressionRepository {
await this.db
.update(messages)
.set({ messageGroupId: groupId })
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
.where(and(this.messagesOwnership(), inArray(messages.id, messageIds)));
}
/**
@@ -173,7 +183,7 @@ export class CompressionRepository {
await this.db
.update(messages)
.set({ messageGroupId: null })
.where(and(eq(messages.userId, this.userId), inArray(messages.id, messageIds)));
.where(and(this.messagesOwnership(), inArray(messages.id, messageIds)));
}
/**
@@ -184,7 +194,7 @@ export class CompressionRepository {
const [message] = await this.db
.select({ metadata: messages.metadata })
.from(messages)
.where(and(eq(messages.id, messageId), eq(messages.userId, this.userId)));
.where(and(eq(messages.id, messageId), this.messagesOwnership()));
if (!message) return;
@@ -194,7 +204,7 @@ export class CompressionRepository {
await this.db
.update(messages)
.set({ metadata: newMetadata })
.where(and(eq(messages.id, messageId), eq(messages.userId, this.userId)));
.where(and(eq(messages.id, messageId), this.messagesOwnership()));
}
/**
@@ -206,7 +216,7 @@ export class CompressionRepository {
.from(messages)
.where(
and(
eq(messages.userId, this.userId),
this.messagesOwnership(),
eq(messages.topicId, topicId),
isNull(messages.messageGroupId),
),
@@ -221,7 +231,7 @@ export class CompressionRepository {
return this.db
.select()
.from(messages)
.where(and(eq(messages.userId, this.userId), eq(messages.messageGroupId, groupId)))
.where(and(this.messagesOwnership(), eq(messages.messageGroupId, groupId)))
.orderBy(messages.createdAt);
}
@@ -233,11 +243,11 @@ export class CompressionRepository {
await this.db
.update(messages)
.set({ messageGroupId: null })
.where(and(eq(messages.userId, this.userId), eq(messages.messageGroupId, groupId)));
.where(and(this.messagesOwnership(), eq(messages.messageGroupId, groupId)));
// 2. Delete the group
await this.db
.delete(messageGroups)
.where(and(eq(messageGroups.id, groupId), eq(messageGroups.userId, this.userId)));
.where(and(eq(messageGroups.id, groupId), this.groupsOwnership()));
}
}
@@ -16,6 +16,7 @@ import {
topics,
users,
userSettings,
workspaces,
} from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { DATA_EXPORT_CONFIG, DataExporterRepos } from './index';
@@ -361,5 +362,140 @@ describe('DataExporterRepos', () => {
expect(result.sessions[0]).not.toHaveProperty('userId', anotherUserId);
expect(result.sessions[0]).toHaveProperty('id', 'another-session-id');
});
it('should not include workspace-scoped rows in personal export', async () => {
const workspaceId = 'workspace-export-filter';
await db.transaction(async (trx) => {
await trx.insert(workspaces).values({
id: workspaceId,
name: 'Workspace Export Filter',
primaryOwnerId: userId,
slug: workspaceId,
});
await trx.insert(agents).values({
id: 'workspace-agent-id',
title: 'Workspace Agent',
userId,
workspaceId,
});
await trx.insert(sessions).values({
id: 'workspace-session-id',
slug: 'workspace-session',
title: 'Workspace Session',
userId,
workspaceId,
});
await trx.insert(topics).values({
id: 'workspace-topic-id',
sessionId: 'workspace-session-id',
title: 'Workspace Topic',
userId,
workspaceId,
});
await trx.insert(messages).values({
content: 'Workspace message',
id: 'workspace-message-id',
role: 'user',
sessionId: 'workspace-session-id',
topicId: 'workspace-topic-id',
userId,
workspaceId,
});
});
const result = await new DataExporterRepos(db, userId).export();
expect(result.agents.map((agent) => agent.id)).toEqual([testIds.agentId]);
expect(result.sessions.map((session) => session.id)).toEqual([testIds.sessionId]);
expect(result.topics.map((topic) => topic.id)).toEqual([testIds.topicId]);
expect(result.messages.map((message) => message.id)).toEqual([testIds.messageId]);
});
it('should export only the selected workspace scope when workspaceId is provided', async () => {
const workspaceId = 'workspace-export-scope';
const otherWorkspaceId = 'workspace-export-other';
await db.transaction(async (trx) => {
await trx.insert(workspaces).values([
{
id: workspaceId,
name: 'Workspace Export Scope',
primaryOwnerId: userId,
slug: workspaceId,
},
{
id: otherWorkspaceId,
name: 'Other Workspace Export Scope',
primaryOwnerId: userId,
slug: otherWorkspaceId,
},
]);
await trx.insert(agents).values([
{
id: 'workspace-agent-id',
title: 'Workspace Agent',
userId,
workspaceId,
},
{
id: 'other-workspace-agent-id',
title: 'Other Workspace Agent',
userId,
workspaceId: otherWorkspaceId,
},
]);
await trx.insert(sessions).values([
{
id: 'workspace-session-id',
slug: 'workspace-session',
title: 'Workspace Session',
userId,
workspaceId,
},
{
id: 'other-workspace-session-id',
slug: 'other-workspace-session',
title: 'Other Workspace Session',
userId,
workspaceId: otherWorkspaceId,
},
]);
await trx.insert(agentsToSessions).values({
agentId: 'workspace-agent-id',
sessionId: 'workspace-session-id',
userId,
});
await trx.insert(topics).values({
id: 'workspace-topic-id',
sessionId: 'workspace-session-id',
title: 'Workspace Topic',
userId,
workspaceId,
});
await trx.insert(messages).values({
content: 'Workspace message',
id: 'workspace-message-id',
role: 'user',
sessionId: 'workspace-session-id',
topicId: 'workspace-topic-id',
userId,
workspaceId,
});
});
const result = await new DataExporterRepos(db, userId, workspaceId).export();
expect(result.userSettings).toEqual([]);
expect(result.agents.map((agent) => agent.id)).toEqual(['workspace-agent-id']);
expect(result.sessions.map((session) => session.id)).toEqual(['workspace-session-id']);
expect(result.topics.map((topic) => topic.id)).toEqual(['workspace-topic-id']);
expect(result.messages.map((message) => message.id)).toEqual(['workspace-message-id']);
expect(result.agentsToSessions).toHaveLength(1);
expect(result.agentsToSessions[0]).toMatchObject({
agentId: 'workspace-agent-id',
sessionId: 'workspace-session-id',
});
});
});
});
@@ -3,6 +3,7 @@ import pMap from 'p-map';
import * as EXPORT_TABLES from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
interface BaseTableConfig {
table: keyof typeof EXPORT_TABLES;
@@ -83,10 +84,12 @@ export const DATA_EXPORT_CONFIG = {
export class DataExporterRepos {
private userId: string;
private db: LobeChatDatabase;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.db = db;
this.userId = userId;
this.workspaceId = workspaceId;
}
private removeUserId(data: any[]) {
@@ -110,7 +113,7 @@ export class DataExporterRepos {
// If source data is empty, this table may not be able to query any data
if (sourceData.length === 0) {
console.log(
console.info(
`Source table ${relation.sourceTable} has no data, skipping query for ${table}`,
);
return [];
@@ -120,7 +123,12 @@ export class DataExporterRepos {
conditions.push(inArray(tableObj[relation.field], sourceIds));
}
// If table has userId field and is not the users table, add user filter
// If table has userId field and is not the users table, add user filter.
// workspace-audit: this branch only runs for non-relation tables; relation
// tables (which carry workspace_id) are already constrained by the FK
// `inArray(sourceIds)` above, where sourceIds come from base tables that ARE
// workspace-scoped (see queryBaseTables / buildWorkspaceWhere) — so relation
// rows are transitively workspace-scoped and need no userId/workspaceId filter here.
if ('userId' in tableObj && table !== 'users' && !config.relations) {
conditions.push(eq(tableObj.userId, this.userId));
}
@@ -132,7 +140,7 @@ export class DataExporterRepos {
const result = await this.db.query[table].findMany({ where });
// Only remove userId field for tables queried with userId
console.log(`Successfully exported table: ${table}, count: ${result.length}`);
console.info(`Successfully exported table: ${table}, count: ${result.length}`);
return config.relations ? result : this.removeUserId(result);
} catch (error) {
console.error(`Error querying table ${table}:`, error);
@@ -146,17 +154,24 @@ export class DataExporterRepos {
if (!tableObj) throw new Error(`Table ${table} not found`);
try {
if (this.workspaceId && !('workspaceId' in tableObj)) {
return [];
}
// If there's relation config, use relation query
// Default to querying with userId, use userField for special cases
const userField = config.userField || 'userId';
const where = eq(tableObj[userField], this.userId);
const where =
'workspaceId' in tableObj
? buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, tableObj)
: eq(tableObj[userField], this.userId);
// @ts-expect-error query
const result = await this.db.query[table].findMany({ where });
// Only remove userId field for tables queried with userId
console.log(`Successfully exported table: ${table}, count: ${result.length}`);
console.info(`Successfully exported table: ${table}, count: ${result.length}`);
return this.removeUserId(result);
} catch (error) {
console.error(`Error querying table ${table}:`, error);
@@ -168,7 +183,7 @@ export class DataExporterRepos {
const result: Record<string, any[]> = {};
// 1. First query all base tables concurrently
console.log('Querying base tables...');
console.info('Querying base tables...');
const baseResults = await pMap(
DATA_EXPORT_CONFIG.baseTables,
async (config) => ({ data: await this.queryBaseTables(config), table: config.table }),
@@ -191,7 +206,7 @@ export class DataExporterRepos {
);
if (!allSourcesHaveData) {
console.log(`Skipping table ${config.table} as some source tables have no data`);
console.info(`Skipping table ${config.table} as some source tables have no data`);
return { data: [], table: config.table };
}
@@ -208,8 +223,6 @@ export class DataExporterRepos {
result[table] = data;
});
console.log('finalResults:', result);
return result;
}
}
@@ -1,5 +1,5 @@
import type { ImporterEntryData } from '@lobechat/types';
import { and, eq, inArray, sql } from 'drizzle-orm';
import { and, inArray, sql } from 'drizzle-orm';
import { sanitizeUTF8 } from '@/utils/sanitizeUTF8';
@@ -14,6 +14,7 @@ import {
topics,
} from '../../../schemas';
import type { LobeChatDatabase } from '../../../type';
import { buildWorkspaceWhere } from '../../../utils/workspace';
interface ImportResult {
added: number;
@@ -24,6 +25,7 @@ interface ImportResult {
export class DeprecatedDataImporterRepos {
private userId: string;
private workspaceId?: string;
private db: LobeChatDatabase;
/**
@@ -31,11 +33,17 @@ export class DeprecatedDataImporterRepos {
*/
supportVersion = 7;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.workspaceId = workspaceId;
this.db = db;
}
/** Helper: scope predicate for workspace-aware tables. */
private workspaceWhere(table: { userId: any; workspaceId: any }) {
return buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, table);
}
importData = async (data: ImporterEntryData) => {
if (data.version > this.supportVersion) throw new Error('Unsupported version');
@@ -53,7 +61,7 @@ export class DeprecatedDataImporterRepos {
if (data.sessionGroups && data.sessionGroups.length > 0) {
const query = await trx.query.sessionGroups.findMany({
where: and(
eq(sessionGroups.userId, this.userId),
this.workspaceWhere(sessionGroups),
inArray(
sessionGroups.clientId,
data.sessionGroups.map(({ id }) => id),
@@ -72,6 +80,7 @@ export class DeprecatedDataImporterRepos {
createdAt: new Date(createdAt),
updatedAt: new Date(updatedAt),
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
)
.onConflictDoUpdate({
@@ -89,7 +98,7 @@ export class DeprecatedDataImporterRepos {
if (data.sessions && data.sessions.length > 0) {
const query = await trx.query.sessions.findMany({
where: and(
eq(sessions.userId, this.userId),
this.workspaceWhere(sessions),
inArray(
sessions.clientId,
data.sessions.map(({ id }) => id),
@@ -109,6 +118,7 @@ export class DeprecatedDataImporterRepos {
groupId: group ? sessionGroupIdMap[group] : null,
updatedAt: new Date(updatedAt),
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
)
.onConflictDoUpdate({
@@ -136,6 +146,7 @@ export class DeprecatedDataImporterRepos {
...config,
...meta,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
)
.returning({ id: agents.id });
@@ -145,6 +156,7 @@ export class DeprecatedDataImporterRepos {
agentId: agentMapArray[index].id,
sessionId: sessionIdMap[id],
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
);
}
@@ -154,7 +166,7 @@ export class DeprecatedDataImporterRepos {
if (data.topics && data.topics.length > 0) {
const skipQuery = await trx.query.topics.findMany({
where: and(
eq(topics.userId, this.userId),
this.workspaceWhere(topics),
inArray(
topics.clientId,
data.topics.map(({ id }) => id),
@@ -174,6 +186,7 @@ export class DeprecatedDataImporterRepos {
sessionId: sessionId ? sessionIdMap[sessionId] : null,
updatedAt: new Date(updatedAt),
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
)
.onConflictDoUpdate({
@@ -190,17 +203,15 @@ export class DeprecatedDataImporterRepos {
// import messages
if (data.messages && data.messages.length > 0) {
// 1. find skip ones
console.time('find messages');
const skipQuery = await trx.query.messages.findMany({
where: and(
eq(messages.userId, this.userId),
this.workspaceWhere(messages),
inArray(
messages.clientId,
data.messages.map(({ id }) => id),
),
),
});
console.timeEnd('find messages');
messageResult.skips = skipQuery.length;
@@ -224,10 +235,10 @@ export class DeprecatedDataImporterRepos {
topicId: topicId ? topicIdMap[topicId] : null, // Temporarily set to NULL
updatedAt: new Date(updatedAt),
userId: this.userId,
workspaceId: this.workspaceId ?? null,
}),
);
console.time('insert messages');
const BATCH_SIZE = 100; // Number of records to insert per batch
for (let i = 0; i < inertValues.length; i += BATCH_SIZE) {
@@ -235,14 +246,12 @@ export class DeprecatedDataImporterRepos {
await trx.insert(messages).values(batch);
}
console.timeEnd('insert messages');
const messageIdArray = await trx
.select({ clientId: messages.clientId, id: messages.id })
.from(messages)
.where(
and(
eq(messages.userId, this.userId),
this.workspaceWhere(messages),
inArray(
messages.clientId,
data.messages.map(({ id }) => id),
@@ -255,7 +264,6 @@ export class DeprecatedDataImporterRepos {
);
// 3. update parentId for messages
console.time('execute updates parentId');
const parentIdUpdates = shouldInsertMessages
.filter((msg) => msg.parentId) // Only process messages with parentId
.map((msg) => {
@@ -284,7 +292,6 @@ export class DeprecatedDataImporterRepos {
// console.log('sql:', SQL.sql);
// console.log('params:', SQL.params);
}
console.timeEnd('execute updates parentId');
// 4. insert message plugins
const pluginInserts = shouldInsertMessages.filter((msg) => msg.plugin);
@@ -299,6 +306,7 @@ export class DeprecatedDataImporterRepos {
toolCallId: msg.tool_call_id,
type: msg.plugin?.type,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
);
}
@@ -311,6 +319,7 @@ export class DeprecatedDataImporterRepos {
id: messageIdMap[msg.id],
...msg.extra?.translate,
userId: this.userId,
workspaceId: this.workspaceId ?? null,
})),
);
}
@@ -5,6 +5,7 @@ import { uuid } from '@/utils/uuid';
import * as EXPORT_TABLES from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
import { DeprecatedDataImporterRepos } from './deprecated';
interface ImportResult {
@@ -256,15 +257,17 @@ const IMPORT_TABLE_CONFIG: TableImportConfig[] = [
export class DataImporterRepos {
private userId: string;
private workspaceId?: string;
private db: LobeChatDatabase;
private deprecatedDataImporterRepos: DeprecatedDataImporterRepos;
private idMaps: Record<string, Record<string, string>> = {};
private conflictRecords: Record<string, { field: string; value: any }[]> = {};
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.workspaceId = workspaceId;
this.db = db;
this.deprecatedDataImporterRepos = new DeprecatedDataImporterRepos(db, userId);
this.deprecatedDataImporterRepos = new DeprecatedDataImporterRepos(db, userId, workspaceId);
}
importData = async (data: ImporterEntryData): Promise<ImportResultData> => {
@@ -301,7 +304,7 @@ export class DataImporterRepos {
// Use unified import method
const result = await this.importTableData(trx, config, tableData, conflictStrategy);
console.log(`imported table: ${tableName}, records: ${tableData.length}`);
console.info(`imported table: ${tableName}, records: ${tableData.length}`);
if (Object.values(result).some((value) => value > 0)) {
results[tableName] = result;
@@ -381,8 +384,15 @@ export class DataImporterRepos {
const clientIds = tableData.map((item) => item.clientId || item.id).filter(Boolean);
if (clientIds.length > 0) {
const workspaceFilter =
'workspaceId' in table
? buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
table as any,
)
: eq(table.userId, this.userId);
existingRecords = await trx.query[tableName].findMany({
where: and(eq(table.userId, this.userId), inArray(table.clientId, clientIds)),
where: and(workspaceFilter, inArray(table.clientId, clientIds)),
});
}
}
@@ -453,7 +463,7 @@ export class DataImporterRepos {
if (item.accessedAt) dateFields.accessedAt = new Date(item.accessedAt);
// Create new record object
let newRecord: any = {};
let newRecord: any;
// Decide how to process based on whether it's composite key and whether to preserve ID
if (isCompositeKey) {
@@ -465,6 +475,7 @@ export class DataImporterRepos {
...dateFields,
clientId: item.clientId || item.id,
userId: this.userId,
...('workspaceId' in table ? { workspaceId: this.workspaceId ?? null } : {}),
};
} else {
// Non-composite key table processing
@@ -473,6 +484,7 @@ export class DataImporterRepos {
...dateFields,
clientId: item.clientId || item.id,
userId: this.userId,
...('workspaceId' in table ? { workspaceId: this.workspaceId ?? null } : {}),
};
}
@@ -526,9 +538,18 @@ export class DataImporterRepos {
.filter((field) => record.newRecord[field] !== undefined)
.map((field) => eq(table[field], record.newRecord[field]));
// Add userId condition (if table has userId field)
// Add userId/workspaceId condition (if table has these fields)
if ('userId' in table) {
whereConditions.push(eq(table.userId, this.userId));
if ('workspaceId' in table) {
whereConditions.push(
buildWorkspaceWhere(
{ userId: this.userId, workspaceId: this.workspaceId },
table as any,
),
);
} else {
whereConditions.push(eq(table.userId, this.userId));
}
}
if (whereConditions.length > 0) {
@@ -16,6 +16,7 @@ import {
} from '../../schemas';
import { type LobeChatDatabase } from '../../type';
import { sanitizeBm25Query } from '../../utils/bm25';
import { buildWorkspaceWhere } from '../../utils/workspace';
// Re-export types for backward compatibility
export type {
@@ -30,13 +31,19 @@ export type {
*/
export class HomeRepository {
private userId: string;
private workspaceId?: string;
private db: LobeChatDatabase;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.workspaceId = workspaceId;
this.db = db;
}
private get scope() {
return { userId: this.userId, workspaceId: this.workspaceId };
}
/**
* Get sidebar agent list with pinned, grouped, and ungrouped items
*/
@@ -60,7 +67,7 @@ export class HomeRepository {
.from(agents)
.leftJoin(agentsToSessions, eq(agents.id, agentsToSessions.agentId))
.leftJoin(sessions, eq(agentsToSessions.sessionId, sessions.id))
.where(and(eq(agents.userId, this.userId), not(eq(agents.virtual, true))))
.where(and(buildWorkspaceWhere(this.scope, agents), not(eq(agents.virtual, true))))
.orderBy(desc(agents.updatedAt));
// 2. Query all chatGroups (group chats)
@@ -76,7 +83,7 @@ export class HomeRepository {
updatedAt: chatGroups.updatedAt,
})
.from(chatGroups)
.where(eq(chatGroups.userId, this.userId))
.where(buildWorkspaceWhere(this.scope, chatGroups))
.orderBy(desc(chatGroups.updatedAt));
// 2.1 Query member avatars for each chat group
@@ -90,7 +97,7 @@ export class HomeRepository {
sort: sessionGroups.sort,
})
.from(sessionGroups)
.where(eq(sessionGroups.userId, this.userId))
.where(buildWorkspaceWhere(this.scope, sessionGroups))
.orderBy(sessionGroups.sort);
// 4. Process and categorize
@@ -225,7 +232,7 @@ export class HomeRepository {
.leftJoin(sessions, eq(agentsToSessions.sessionId, sessions.id))
.where(
and(
eq(agents.userId, this.userId),
buildWorkspaceWhere(this.scope, agents),
not(eq(agents.virtual, true)),
sql`(${agents.title} @@@ ${bm25Query} OR ${agents.description} @@@ ${bm25Query})`,
),
@@ -245,7 +252,7 @@ export class HomeRepository {
.from(chatGroups)
.where(
and(
eq(chatGroups.userId, this.userId),
buildWorkspaceWhere(this.scope, chatGroups),
sql`(${chatGroups.title} @@@ ${bm25Query} OR ${chatGroups.description} @@@ ${bm25Query})`,
),
)
@@ -8,6 +8,7 @@ import { documents, files } from '../../schemas/file';
import { chunks, embeddings } from '../../schemas/rag';
import { fileChunks } from '../../schemas/relations';
import { users } from '../../schemas/user';
import { workspaces } from '../../schemas/workspace';
import type { LobeChatDatabase } from '../../type';
import { KnowledgeRepo } from './index';
@@ -257,6 +258,78 @@ describe('KnowledgeRepo', () => {
});
});
describe('query - workspace isolation', () => {
const workspaceId = 'knowledge-workspace';
beforeEach(async () => {
await serverDB.insert(workspaces).values({
id: workspaceId,
name: 'Knowledge Workspace',
primaryOwnerId: userId,
slug: workspaceId,
});
await serverDB.insert(files).values([
{
fileType: 'application/pdf',
name: 'workspace-owner-file.pdf',
size: 1024,
url: 'workspace-owner-file-url',
userId,
workspaceId,
},
{
fileType: 'application/pdf',
name: 'viewer-personal-file.pdf',
size: 1024,
url: 'viewer-personal-file-url',
userId: otherUserId,
},
]);
await serverDB.insert(documents).values([
{
content: 'Workspace owner document',
fileType: 'application/pdf',
filename: 'workspace-owner-doc.pdf',
source: 'workspace-owner-source',
sourceType: 'api',
totalCharCount: 100,
totalLineCount: 10,
userId,
workspaceId,
},
{
content: 'Viewer personal document',
fileType: 'application/pdf',
filename: 'viewer-personal-doc.pdf',
source: 'viewer-personal-source',
sourceType: 'api',
totalCharCount: 100,
totalLineCount: 10,
userId: otherUserId,
},
]);
});
it('should return workspace items regardless of the creator user', async () => {
const workspaceRepo = new KnowledgeRepo(serverDB, otherUserId, workspaceId);
const results = await workspaceRepo.query({ category: FilesTabs.All });
const names = results.map((item) => item.name).sort();
expect(names).toEqual(['workspace-owner-doc.pdf', 'workspace-owner-file.pdf']);
});
it('should not return workspace items in personal mode', async () => {
const results = await knowledgeRepo.query({ category: FilesTabs.All });
const names = results.map((item) => item.name).sort();
expect(names).not.toContain('workspace-owner-doc.pdf');
expect(names).not.toContain('workspace-owner-file.pdf');
});
});
describe('query - search filtering', () => {
beforeEach(async () => {
await serverDB.insert(files).values([
@@ -6,6 +6,7 @@ import { DocumentModel } from '../../models/document';
import { FileModel } from '../../models/file';
import { DOCUMENT_FOLDER_TYPE, documents, files, knowledgeBaseFiles } from '../../schemas';
import type { LobeChatDatabase } from '../../type';
import { buildWorkspaceWhere } from '../../utils/workspace';
export interface KnowledgeItem {
chunkTaskId?: string | null;
@@ -39,14 +40,26 @@ export class KnowledgeRepo {
private db: LobeChatDatabase;
private fileModel: FileModel;
private documentModel: DocumentModel;
private workspaceId?: string;
constructor(db: LobeChatDatabase, userId: string) {
constructor(db: LobeChatDatabase, userId: string, workspaceId?: string) {
this.userId = userId;
this.db = db;
this.fileModel = new FileModel(db, userId);
this.documentModel = new DocumentModel(db, userId);
this.workspaceId = workspaceId;
this.fileModel = new FileModel(db, userId, workspaceId);
this.documentModel = new DocumentModel(db, userId, workspaceId);
}
private fileOwnershipSql = (alias: 'f' = 'f') =>
this.workspaceId
? sql`${sql.raw(`${alias}.workspace_id`)} = ${this.workspaceId}`
: sql`${sql.raw(`${alias}.user_id`)} = ${this.userId} AND ${sql.raw(`${alias}.workspace_id`)} IS NULL`;
private documentOwnershipSql = (alias: 'd' | 'documents' = 'd') =>
this.workspaceId
? sql`${sql.raw(`${alias}.workspace_id`)} = ${this.workspaceId}`
: sql`${sql.raw(`${alias}.user_id`)} = ${this.userId} AND ${sql.raw(`${alias}.workspace_id`)} IS NULL`;
/**
* Query combined results from files and documents tables
*/
@@ -183,7 +196,7 @@ export class KnowledgeRepo {
FROM ${files} f
LEFT JOIN ${documents} d
ON f.id = d.file_id
WHERE f.user_id = ${this.userId}
WHERE ${this.fileOwnershipSql('f')}
AND NOT EXISTS (
SELECT 1 FROM ${knowledgeBaseFiles}
WHERE ${knowledgeBaseFiles.fileId} = f.id
@@ -209,7 +222,7 @@ export class KnowledgeRepo {
metadata,
'document' as source_type
FROM ${documents}
WHERE user_id = ${this.userId}
WHERE ${this.documentOwnershipSql('documents')}
AND source_type != ${'file'}
AND knowledge_base_id IS NULL
`;
@@ -315,7 +328,10 @@ export class KnowledgeRepo {
if (document.fileType === DOCUMENT_FOLDER_TYPE) {
const children = await this.db.query.documents.findMany({
where: and(eq(documents.parentId, id), eq(documents.userId, this.userId)),
where: and(
eq(documents.parentId, id),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, documents),
),
});
for (const child of children) {
@@ -323,7 +339,10 @@ export class KnowledgeRepo {
}
const childFiles = await this.db.query.files.findMany({
where: and(eq(files.parentId, id), eq(files.userId, this.userId)),
where: and(
eq(files.parentId, id),
buildWorkspaceWhere({ userId: this.userId, workspaceId: this.workspaceId }, files),
),
});
for (const file of childFiles) {
@@ -345,7 +364,7 @@ export class KnowledgeRepo {
showFilesInKnowledgeBase,
parentId,
}: QueryFileListParams = {}): ReturnType<typeof sql> {
const whereConditions: any[] = [sql`f.user_id = ${this.userId}`];
const whereConditions: any[] = [this.fileOwnershipSql('f')];
// Parent ID filter
if (parentId !== undefined) {
@@ -376,7 +395,7 @@ export class KnowledgeRepo {
// Knowledge base filter
if (knowledgeBaseId) {
// Build where conditions using proper table references (f.column instead of files.column)
const kbWhereConditions: any[] = [sql`f.user_id = ${this.userId}`];
const kbWhereConditions: any[] = [this.fileOwnershipSql('f')];
// Parent ID filter
if (parentId !== undefined) {
@@ -477,7 +496,7 @@ export class KnowledgeRepo {
parentId,
}: QueryFileListParams = {}): ReturnType<typeof sql> {
const whereConditions: any[] = [
sql`${documents.userId} = ${this.userId}`,
this.documentOwnershipSql('documents'),
sql`${documents.sourceType} != ${'file'}`,
];
@@ -542,7 +561,7 @@ export class KnowledgeRepo {
// Documents are linked to knowledge bases through files table via fileId
if (knowledgeBaseId) {
// Build where conditions using proper table references (d.column instead of documents.column)
const kbWhereConditions: any[] = [sql`d.user_id = ${this.userId}`];
const kbWhereConditions: any[] = [this.documentOwnershipSql('d')];
// Parent ID filter
if (parentId !== undefined) {

Some files were not shown because too many files have changed in this diff Show More