🐛 fix: reduce agent document context latency (#15436)

This commit is contained in:
YuTengjing
2026-06-04 16:23:51 +08:00
committed by GitHub
parent 1e2c1aacd5
commit bab3ff4a7a
48 changed files with 2355 additions and 422 deletions
+11 -58
View File
@@ -29,10 +29,9 @@ Standard workflow for verifying backend changes using the LobeHub CLI (`lh`) aga
## Quick Reference
All CLI dev commands run from `lobehub/apps/cli/`:
All CLI dev commands run from `lobehub/apps/cli/`. Subsequent examples use `$CLI`:
```bash
# Shorthand for all commands below
CLI="LOBEHUB_CLI_HOME=.lobehub-dev bun src/index.ts"
```
@@ -40,17 +39,14 @@ CLI="LOBEHUB_CLI_HOME=.lobehub-dev bun src/index.ts"
### Step 1: Ensure Dev Server is Running
Check if the dev server is already running:
```bash
curl -s -o /dev/null -w '%{http_code}' http://localhost:3011/ 2> /dev/null
```
- **If reachable** (returns any HTTP status): server is running. Skip to Step 2.
- **If unreachable**: start the server:
- **If reachable**: skip to Step 2.
- **If unreachable**: start from cloud repo root:
```bash
# From cloud repo root
pnpm run dev:next
```
@@ -65,37 +61,33 @@ pnpm run dev:next
### Step 2: Check CLI Authentication
Check if dev credentials already exist:
```bash
cat lobehub/apps/cli/.lobehub-dev/settings.json 2> /dev/null
```
- **If file exists and contains `"serverUrl": "http://localhost:3011"`**: already authenticated. Skip to Step 3.
- **If file missing or points to wrong server**: login is needed. Ask the user to run:
- **If file exists and contains `"serverUrl": "http://localhost:3011"`**: skip to Step 3.
- **If missing or wrong server**: ask the user to run:
```bash
! cd lobehub/apps/cli && LOBEHUB_CLI_HOME=.lobehub-dev bun src/index.ts login --server http://localhost:3011
```
> Login requires interactive browser authorization (OIDC Device Code Flow), so the user must run it themselves via `!` prefix. After login, credentials are saved to `lobehub/apps/cli/.lobehub-dev/` and persist across sessions.
> Login requires interactive browser authorization (OIDC Device Code Flow), so the user must run it themselves via `!` prefix. Credentials persist in `lobehub/apps/cli/.lobehub-dev/`.
### Step 3: Test with CLI Commands
CLI runs from source (`bun src/index.ts`), so CLI-side code changes take effect immediately without rebuilding.
CLI runs from source, so CLI-side code changes take effect immediately without rebuilding.
```bash
cd lobehub/apps/cli
LOBEHUB_CLI_HOME=.lobehub-dev bun src/index.ts <command>
$CLI <command>
```
### Step 4: Clean Up Test Data
Delete any test data created during verification:
```bash
LOBEHUB_CLI_HOME=.lobehub-dev bun src/index.ts task delete < id > -y
LOBEHUB_CLI_HOME=.lobehub-dev bun src/index.ts agent delete < id > -y
$CLI task delete < id > -y
$CLI agent delete < id > -y
```
## Common Testing Patterns
@@ -103,51 +95,30 @@ LOBEHUB_CLI_HOME=.lobehub-dev bun src/index.ts agent delete < id > -y
### Task System
```bash
# List tasks
$CLI task list
# Create test data with nesting
$CLI task create -n "Root Task" -i "Test instruction"
$CLI task create -n "Child Task" -i "Sub instruction" --parent T-1
# View task detail (tests getTaskDetail service)
$CLI task view T-1
# View task tree
$CLI task tree T-1
# Test lifecycle
$CLI task edit T-1 --status running
$CLI task comment T-1 -m "Test comment"
# Clean up
$CLI task delete T-1 -y
```
### Agent System
```bash
# List agents
$CLI agent list
# View agent detail
$CLI agent view <agent-id>
# Run agent (tests agent execution pipeline)
$CLI agent run <agent-id> -m "Test prompt"
```
### Document & Knowledge Base
```bash
# List documents
$CLI doc list
# Create and view
$CLI doc create -t "Test Doc" -c "Content here"
$CLI doc view <doc-id>
# Knowledge base
$CLI kb list
$CLI kb tree <kb-id>
```
@@ -155,18 +126,13 @@ $CLI kb tree <kb-id>
### Model & Provider
```bash
# List models and providers
$CLI model list
$CLI provider list
# Test provider connectivity
$CLI provider test <provider-id>
```
## Dev-Test Cycle
The standard cycle for backend development:
```
1. Make code changes (service/model/router/type)
|
@@ -177,7 +143,7 @@ The standard cycle for backend development:
lsof -ti:3011 | xargs kill && pnpm run dev:next
|
4. CLI verification (end-to-end)
LOBEHUB_CLI_HOME=.lobehub-dev bun src/index.ts <command>
$CLI <command>
|
5. Clean up test data
```
@@ -193,10 +159,6 @@ The standard cycle for backend development:
| `lobehub/apps/cli/` (CLI code) | No |
| `src/` (cloud overrides) | Yes |
### When Server Restart is NOT Needed
CLI runs from source via `bun src/index.ts`, so any changes to `lobehub/apps/cli/src/` take effect immediately on next command invocation.
## Troubleshooting
| Issue | Solution |
@@ -207,12 +169,3 @@ CLI runs from source via `bun src/index.ts`, so any changes to `lobehub/apps/cli
| CLI shows old data/behavior | Server needs restart to pick up code changes |
| `EADDRINUSE` on port 3011 | Server already running; kill with `lsof -ti:3011 \| xargs kill` |
| Login opens wrong server | Must use `--server http://localhost:3011` flag (env var doesn't work) |
## Credential Isolation
| Mode | Credential Dir | Server |
| ---------- | -------------------------------- | ----------------- |
| Dev | `lobehub/apps/cli/.lobehub-dev/` | `localhost:3011` |
| Production | `~/.lobehub/` | `app.lobehub.com` |
The two environments are completely isolated. Dev mode credentials are gitignored.
+3 -3
View File
@@ -9,13 +9,13 @@ user-invocable: false
## Configuration
- Config: `drizzle.config.ts`
- Schemas: `src/database/schemas/`
- Migrations: `src/database/migrations/`
- Schemas: `packages/database/src/schemas/`
- Migrations: `packages/database/migrations/`
- Dialect: `postgresql` with `strict: true`
## Helper Functions
Location: `src/database/schemas/_helpers.ts`
Location: `packages/database/src/schemas/_helpers.ts`
- `timestamptz(name)`: Timestamp with timezone
- `createdAt()`, `updatedAt()`, `accessedAt()`: Standard timestamp columns
+4 -40
View File
@@ -177,29 +177,12 @@ export const chatGroupAction: StateCreator<
### Slices That Don't Currently Need `set`
When a slice doesn't write local state at the moment — e.g. it reads context
from `#get()` and forwards calls to another store, or just runs hooks — drop
the `#set` field. Otherwise ESLint's `no-unused-vars` flags the unused private
field.
Mark the constructor's `set` param as `_set` and `void _set` it to keep the
`(set, get, api)` shape aligned with `StateCreator`. This is **a snapshot of
the current need, not a permanent contract** — if a later change needs `set`,
restore the `#set` field and use it; do not invent a workaround to keep the
"unused" form.
When a slice doesn't write local state (e.g. it delegates to another store or just runs hooks), drop `#set` and mark the constructor param as `_set` with `void _set` to keep the `(set, get, api)` shape:
```ts
type Setter = StoreSetter<ConversationStore>;
export const toolSlice = (set: Setter, get: () => ConversationStore, _api?: unknown) =>
new ToolActionImpl(set, get, _api);
export class ToolActionImpl {
readonly #get: () => ConversationStore;
// Mark unused params with `_` prefix and `void _x` so the constructor still
// matches StateCreator's `(set, get, api)` shape without triggering unused
// diagnostics.
constructor(_set: Setter, get: () => ConversationStore, _api?: unknown) {
void _set;
void _api;
@@ -212,27 +195,8 @@ export class ToolActionImpl {
hooks.onToolCallComplete?.(id, undefined);
};
}
export type ToolAction = Pick<ToolActionImpl, keyof ToolActionImpl>;
```
Rules of thumb:
- If a slice doesn't currently call `set`, drop `#set` (use `_set` + `void _set`
in the constructor). When a later edit needs `set`, restore `#set` and use it.
- Don't add `setNamespace` for slices that don't write state. Add it when the
slice starts writing state.
- Never leave `#set` declared but unused "for future use" — lint will fail and
re-adding it later costs nothing.
### Do / Don't
- **Do**: keep constructor signature aligned with `StateCreator` params `(set, get, api)`.
- **Do**: use `#private` to avoid `set/get` being exposed.
- **Do**: use `flattenActions` instead of spreading class instances.
- **Do**: drop `#set` (and use `_set` + `void _set` in the constructor) for
delegate-only slices that never write state — keeps lint green without
breaking the `(set, get, api)` shape.
- **Don't**: keep both old slice objects and class actions active at the same time.
- **Don't**: keep an unused `#set` field "for future use" — it fails ESLint and
re-adding it later costs nothing.
- Drop `#set` when unused; restore it when a later edit needs `set` — re-adding costs nothing.
- Don't add `setNamespace` for slices that don't write state.
- Don't keep both old slice objects and class actions active at the same time during migration.
+2 -1
View File
@@ -1,6 +1,7 @@
# Add directories or file patterns to ignore during indexing (e.g. foo/ or *.csv)
locales/
apps/desktop/resources/locales/
**/__snapshots__/
**/fixtures/
src/database/migrations/
packages/database/migrations/
+6 -2
View File
@@ -115,8 +115,12 @@ cd packages/database && bunx vitest run --silent='passed-only' '[file]'
```
- Prefer `vi.spyOn` over `vi.mock`
- Tests must pass type check: `bun run type-check`
- After 2 failed fix attempts, stop and ask for help
### Type Checking
```bash
bun run type-check
```
### i18n
@@ -244,7 +244,7 @@ export const todoWriteStress = defineCase({
),
callSubAgent(
`${table} 表添加索引`,
`检查 src/database/schemas/${table}.ts 的表结构,添加 createdAt 性能索引,生成迁移 SQL`,
`检查 packages/database/src/schemas/${table}.ts 的表结构,添加 createdAt 性能索引,生成迁移 SQL`,
),
updateTodos(
[{ type: 'complete', index: i }],
@@ -277,7 +277,7 @@ export const todoWriteStress = defineCase({
),
callSubAgent(
`${table} 表添加索引`,
`检查 src/database/schemas/${table}.ts 的表结构,添加 createdAt 性能索引,生成迁移 SQL`,
`检查 packages/database/src/schemas/${table}.ts 的表结构,添加 createdAt 性能索引,生成迁移 SQL`,
),
updateTodos(
[{ type: 'complete', index: i }],
@@ -337,7 +337,7 @@ export const todoWriteStress = defineCase({
'compression',
'file',
'notification',
].flatMap((slice, i) => [
].flatMap((slice) => [
createTodos([`迁移 ${slice} store slice 到 SWR 模式`]),
updateTodos(
[{ type: 'processing', index: 0 }],
@@ -516,7 +516,7 @@ export const todoWriteStress = defineCase({
'file',
'knowledge',
'share',
].flatMap((ns, i) => {
].flatMap((ns) => {
return [
createTodos([`提取 ${ns} 命名空间的硬编码字符串`]),
updateTodos(
@@ -28,6 +28,7 @@ export type AgentDocumentSourceType = 'agent' | 'agent-signal' | 'api' | 'file'
export interface AgentContextDocument {
content?: string;
contentCharCount?: number;
description?: string;
filename: string;
id?: string;
@@ -117,8 +118,8 @@ export function formatDocument(
* Format the size of a document content as a short human-readable token string.
* Empty content is rendered as "empty" so the LLM does not retry reading it.
*/
function formatSize(content: string | undefined): string {
const len = content?.length ?? 0;
function formatSize(doc: Pick<AgentContextDocument, 'content' | 'contentCharCount'>): string {
const len = doc.contentCharCount ?? doc.content?.length ?? 0;
if (len === 0) return 'empty';
if (len < 1000) return String(len);
if (len < 10_000) return `${(len / 1000).toFixed(1)}k`;
@@ -171,7 +172,7 @@ function buildIndexTable(
const now = context.currentTime ?? new Date();
const rows = docs.map((d) => ({
id: d.id ?? '',
size: formatSize(d.content),
size: formatSize(d),
title: truncate(pickRowTitle(d), TITLE_MAX_WIDTH),
updated: formatRelative(d.updatedAt, now),
}));
@@ -225,6 +225,33 @@ describe('AgentDocumentInjector', () => {
expect(result.messages[0].content).not.toContain('Full content that should NOT appear');
});
it('should render progressive index sizes from contentCharCount when content is omitted', async () => {
const provider = new AgentDocumentContextInjector({
currentTime: new Date('2026-04-29T00:00:00.000Z'),
documents: [
{
content: '',
contentCharCount: 12_000,
filename: 'large-note.txt',
id: 'note-1',
loadPosition: 'before-first-user',
loadRules: { rule: 'always' },
policyLoad: 'progressive',
sourceType: 'file',
title: 'Large Note',
updatedAt: new Date('2026-04-27T00:00:00.000Z'),
},
],
});
const context = createContext([{ content: 'Hello', id: 'user-1', role: 'user' }]);
const result = await provider.process(context);
expect(result.messages[0].content).toContain('Large Note');
expect(result.messages[0].content).toContain('12k');
expect(result.messages[0].content).not.toContain('empty');
});
it('should hide web-crawled docs from the index and surface the count', async () => {
const provider = new AgentDocumentContextInjector({
currentTime: new Date('2026-04-29T00:00:00.000Z'),
@@ -4,7 +4,12 @@ import { beforeEach, describe, expect, it } from 'vitest';
import { getTestDB } from '../../../core/getTestDB';
import { agentDocuments, agents, documents, users } from '../../../schemas';
import { DOCUMENT_FOLDER_TYPE } from '../../../schemas/file';
import {
AGENT_SKILL_TEMPLATE_ID,
DOCUMENT_FOLDER_TYPE,
SKILL_BUNDLE_FILE_TYPE,
SKILL_INDEX_FILE_TYPE,
} from '../../../schemas/file';
import type { LobeChatDatabase } from '../../../type';
import {
AgentDocumentModel,
@@ -704,6 +709,58 @@ describe('AgentDocumentModel', () => {
expect(byTemplate).toHaveLength(2);
expect(byTemplate.every((item) => item.templateId === 'claw')).toBe(true);
});
it('should return only skill-managed docs for skill registry assembly', async () => {
const bundle = await agentDocumentModel.create(agentId, 'bug-triage', 'bundle body', {
fileType: SKILL_BUNDLE_FILE_TYPE,
templateId: AGENT_SKILL_TEMPLATE_ID,
});
await agentDocumentModel.create(agentId, 'SKILL.md', 'skill body', {
fileType: SKILL_INDEX_FILE_TYPE,
parentId: bundle.documentId,
templateId: AGENT_SKILL_TEMPLATE_ID,
});
await agentDocumentModel.create(agentId, 'ordinary.md', 'ordinary body');
await agentDocumentModel.create(agentId, 'web-page', 'web body', {
fileType: 'article',
sourceType: 'web',
});
const result = await agentDocumentModel.findSkillDocsByAgent(agentId);
expect(result.map((item) => item.filename).sort()).toEqual(['SKILL.md', 'bug-triage']);
expect(result.every((item) => item.category === 'skill')).toBe(true);
});
it('should omit progressive document content for chat context hydration', async () => {
await agentDocumentModel.create(agentId, 'always.md', 'always body', {
editorData: { root: { children: [{ text: 'always body' }] } },
policyLoad: PolicyLoad.ALWAYS,
});
await agentDocumentModel.create(agentId, 'progressive.md', 'progressive body', {
editorData: { root: { children: [{ text: 'progressive body' }] } },
policyLoad: PolicyLoad.PROGRESSIVE,
});
await agentDocumentModel.create(agentId, 'web-page', 'web body', {
fileType: 'article',
policyLoad: PolicyLoad.PROGRESSIVE,
sourceType: 'web',
});
const result = await agentDocumentModel.findContextByAgent(agentId);
const byFilename = Object.fromEntries(result.map((item) => [item.filename, item]));
expect(byFilename['always.md']?.content).toBe('always body');
expect(byFilename['always.md']?.contentCharCount).toBe('always body'.length);
expect(byFilename['always.md']?.editorData).toEqual({
root: { children: [{ text: 'always body' }] },
});
expect(byFilename['progressive.md']?.content).toBe('');
expect(byFilename['progressive.md']?.contentCharCount).toBe('progressive body'.length);
expect(byFilename['progressive.md']?.editorData).toBeNull();
expect(byFilename['web-page']?.content).toBe('');
expect(byFilename['web-page']?.contentCharCount).toBe('web body'.length);
});
});
describe('hasByAgent', () => {
@@ -1,7 +1,7 @@
import { and, asc, desc, eq, inArray, isNotNull, isNull } from 'drizzle-orm';
import { and, asc, desc, eq, inArray, isNotNull, isNull, like, or, sql } from 'drizzle-orm';
import type { DocumentItem, NewAgentDocument, NewDocument } from '../../schemas';
import { agentDocuments, documents } from '../../schemas';
import { AGENT_SKILL_TEMPLATE_ID, agentDocuments, documents } from '../../schemas';
import type { LobeChatDatabase, Transaction } from '../../type';
import { deriveAgentDocumentFields } from './deriveFields';
import { buildDocumentFilename } from './filename';
@@ -15,6 +15,7 @@ import {
} from './policy';
import type {
AgentDocument,
AgentDocumentContextRow,
AgentDocumentPolicy,
AgentDocumentSourceType,
AgentDocumentWithRules,
@@ -882,6 +883,119 @@ export class AgentDocumentModel {
});
}
async findSkillDocsByAgent(agentId: string): Promise<AgentDocumentWithRules[]> {
const results = await this.db
.select({ doc: documents, settings: agentDocuments })
.from(agentDocuments)
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
eq(agentDocuments.agentId, agentId),
isNull(agentDocuments.deletedAt),
or(
eq(agentDocuments.templateId, AGENT_SKILL_TEMPLATE_ID),
like(documents.fileType, 'skills/%'),
),
),
)
.orderBy(desc(agentDocuments.updatedAt));
return results.map(({ settings, doc }) => {
const item = this.toAgentDocument(settings, doc);
return {
...item,
...deriveAgentDocumentFields(item),
loadRules: parseLoadRules(item),
};
});
}
async findContextByAgent(agentId: string): Promise<AgentDocumentContextRow[]> {
const results = await this.db
.select({
doc: {
content: sql<string>`
CASE
WHEN ${agentDocuments.policyLoad} = ${PolicyLoad.ALWAYS}
THEN COALESCE(${documents.content}, '')
ELSE ''
END
`.as('content'),
description: documents.description,
editorData: sql<Record<string, any> | null>`
CASE
WHEN ${agentDocuments.policyLoad} = ${PolicyLoad.ALWAYS} THEN ${documents.editorData}
ELSE NULL
END
`.as('editor_data'),
filename: documents.filename,
fileType: documents.fileType,
parentId: documents.parentId,
sourceType: documents.sourceType,
title: documents.title,
totalCharCount: documents.totalCharCount,
},
settings: {
agentId: agentDocuments.agentId,
documentId: agentDocuments.documentId,
id: agentDocuments.id,
policy: agentDocuments.policy,
policyLoad: agentDocuments.policyLoad,
policyLoadFormat: agentDocuments.policyLoadFormat,
policyLoadPosition: agentDocuments.policyLoadPosition,
policyLoadRule: agentDocuments.policyLoadRule,
templateId: agentDocuments.templateId,
updatedAt: agentDocuments.updatedAt,
},
})
.from(agentDocuments)
.innerJoin(documents, eq(agentDocuments.documentId, documents.id))
.where(
and(
eq(agentDocuments.userId, this.userId),
eq(agentDocuments.agentId, agentId),
isNull(agentDocuments.deletedAt),
),
)
.orderBy(desc(agentDocuments.updatedAt));
return results.map(({ settings, doc }) => {
const policy = (settings.policy as AgentDocumentPolicy | null) ?? null;
const item: Omit<
AgentDocumentContextRow,
'category' | 'isFolder' | 'isSkillBundle' | 'isSkillIndex' | 'loadRules'
> = {
content: doc.content,
contentCharCount: doc.totalCharCount,
description: doc.description ?? null,
documentId: settings.documentId,
editorData: doc.editorData ?? null,
filename: doc.filename ?? '',
fileType: doc.fileType,
id: settings.id,
parentId: doc.parentId ?? null,
policy,
policyLoad: settings.policyLoad as PolicyLoad,
policyLoadFormat:
(settings.policyLoadFormat as DocumentLoadFormat | null) ??
policy?.context?.policyLoadFormat ??
DocumentLoadFormat.RAW,
policyLoadPosition: settings.policyLoadPosition,
policyLoadRule: settings.policyLoadRule,
sourceType: doc.sourceType,
templateId: settings.templateId ?? null,
title: doc.title ?? doc.filename ?? '',
updatedAt: settings.updatedAt,
};
return {
...item,
...deriveAgentDocumentFields(item),
loadRules: parseLoadRules(item),
};
});
}
async findByDocumentIds(
agentId: string,
documentIds: string[],
@@ -81,6 +81,46 @@ export interface AgentDocumentWithRules extends AgentDocument, AgentDocumentDeri
loadRules: DocumentLoadRules;
}
export interface AgentDocumentContextRow extends AgentDocumentDerivedFields {
content: string;
contentCharCount?: number;
description: string | null;
documentId: string;
editorData: Record<string, any> | null;
filename: string;
fileType: string;
id: string;
loadRules: DocumentLoadRules;
parentId: string | null;
policy: AgentDocumentPolicy | null;
policyLoad: PolicyLoad;
policyLoadFormat: DocumentLoadFormat;
policyLoadPosition: string;
policyLoadRule: string;
sourceType: AgentDocumentSourceType;
templateId: string | null;
title: string;
updatedAt: Date;
}
export interface AgentDocumentContextPayload {
content: string;
contentCharCount?: number;
description: string | null;
filename: string;
id: string;
isFolder: boolean;
loadRules: DocumentLoadRules;
policy: AgentDocumentPolicy | null;
policyLoad: PolicyLoad;
policyLoadFormat: DocumentLoadFormat;
policyLoadPosition: string;
sourceType: AgentDocumentSourceType;
templateId: string | null;
title: string;
updatedAt: Date;
}
export interface ToolUpdateLoadRule {
keywordMatchMode?: 'all' | 'any';
keywords?: string[];
@@ -41,6 +41,7 @@ describe('sanitizeSVGContent', () => {
const maliciousSvg = `
<svg xmlns="http://www.w3.org/2000/svg">
<circle cx="50" cy="50" r="40" fill="red" onclick="alert('click')" onload="alert('load')" />
<rect width="10" height="10" onMouseOver='alert("hover")' onfocus=alert(1) />
</svg>
`;
@@ -48,6 +49,8 @@ describe('sanitizeSVGContent', () => {
expect(sanitized).not.toContain('onclick');
expect(sanitized).not.toContain('onload');
expect(sanitized).not.toContain('onMouseOver');
expect(sanitized).not.toContain('onfocus');
expect(sanitized).toContain('<circle');
expect(sanitized).toContain('fill="red"');
});
+65
View File
@@ -0,0 +1,65 @@
import { describe, expect, it, vi } from 'vitest';
import {
createTimingHelpers,
markTimingSinkStageDone,
markTimingStageDone,
type TimingLogger,
type TimingSink,
} from './timing';
describe('timing utilities', () => {
const context = { requestId: 'req-1', startedAt: Date.now() };
describe('markTimingStageDone', () => {
it('should emit a done marker with zero stage duration', () => {
const logger = vi.fn<TimingLogger>();
markTimingStageDone(logger, context, 'lambda.aiChat.messagesAndTopics.fastResponse', {
messageCount: 2,
reason: 'simple-existing-topic-turn',
});
expect(logger).toHaveBeenCalledWith(
'[%s] %s totalMs=%d %O',
'req-1',
'lambda.aiChat.messagesAndTopics.fastResponse:done',
expect.any(Number),
{
messageCount: 2,
reason: 'simple-existing-topic-turn',
stageMs: 0,
},
);
});
it('should skip logging without timing context', () => {
const logger = vi.fn<TimingLogger>();
markTimingStageDone(logger, undefined, 'lambda.aiChat.messagesAndTopics.fastResponse');
expect(logger).not.toHaveBeenCalled();
});
});
describe('markTimingSinkStageDone', () => {
it('should emit a done marker through a timing sink', () => {
const timing = { log: vi.fn<TimingSink['log']>() };
markTimingSinkStageDone(timing, 'db.message.query.cacheHit', { rowCount: 2 });
expect(timing.log).toHaveBeenCalledWith('db.message.query.cacheHit:done', {
rowCount: 2,
stageMs: 0,
});
});
});
describe('createTimingHelpers', () => {
it('should expose markStageDone on the helper facade', () => {
const helpers = createTimingHelpers('lobe-server:test');
expect(helpers.markStageDone).toBeTypeOf('function');
});
});
});
+27
View File
@@ -78,6 +78,31 @@ export const logTimingSink = (
timing?.log(event, metadata);
};
export const markTimingStageDone = (
logger: TimingLogger,
context: TimingContext | undefined,
stage: string,
metadata?: TimingMetadata,
) => {
if (!context) return;
logTiming(logger, context, `${stage}:done`, {
...metadata,
stageMs: 0,
});
};
export const markTimingSinkStageDone = (
timing: TimingSink | undefined,
stage: string,
metadata?: TimingMetadata,
) => {
logTimingSink(timing, `${stage}:done`, {
...metadata,
stageMs: 0,
});
};
export const runTimedStage = async <T>(
logger: TimingLogger,
context: TimingContext | undefined,
@@ -161,6 +186,8 @@ export const createTimingHelpers = (namespace: string) => {
logger,
logTiming: (context: TimingContext | undefined, event: string, metadata?: TimingMetadata) =>
logTiming(logger, context, event, metadata),
markStageDone: (context: TimingContext | undefined, stage: string, metadata?: TimingMetadata) =>
markTimingStageDone(logger, context, stage, metadata),
runTimedStage: <T>(
context: TimingContext | undefined,
stage: string,
+1 -1
View File
@@ -109,7 +109,7 @@ const ChatList = memo<ChatListProps>(
topicId: canShowAgentSignalReceipts ? context.topicId : undefined,
});
// Fetch notebook documents when topic is selected (skip for share pages)
// Fetch conversation context data when a conversation is visible (skip for share pages)
useFetchAgentDocuments(isSharePage ? undefined : activeAgentId);
useFetchNotebookDocuments(isSharePage ? undefined : context.topicId!);
useFetchTopicMemories(enableUserMemories && !isSharePage ? context.topicId : undefined);
@@ -6,6 +6,7 @@ import isEqual from 'fast-deep-equal';
import { type ReactNode } from 'react';
import { memo, useMemo } from 'react';
import { useFetchAvailableAgents } from '@/hooks/useFetchAvailableAgents';
import { messageMapKey } from '@/store/chat/utils/messageMapKey';
import AssistantTurnSettledWatcher from './AssistantTurnSettledWatcher';
@@ -20,6 +21,18 @@ import {
const log = debug('lobe-render:features:Conversation');
interface ConversationContextPrefetcherProps {
context: ConversationContext;
}
const ConversationContextPrefetcher = memo<ConversationContextPrefetcherProps>(({ context }) => {
useFetchAvailableAgents(!context.topicShareId && !!context.agentId);
return null;
});
ConversationContextPrefetcher.displayName = 'ConversationContextPrefetcher';
export interface ConversationProviderProps {
/**
* Actions bar configuration by message type
@@ -105,6 +118,7 @@ export const ConversationProvider = memo<ConversationProviderProps>(
onMessagesChange={onMessagesChange}
/>
<AssistantTurnSettledWatcher />
<ConversationContextPrefetcher context={context} />
{children}
</Provider>
);
+7
View File
@@ -0,0 +1,7 @@
import { useAgentStore } from '@/store/agent';
export const useFetchAvailableAgents = (enabled: boolean) => {
const useFetchAvailableAgents = useAgentStore((s) => s.useFetchAvailableAgents);
useFetchAvailableAgents(enabled);
};
@@ -574,7 +574,7 @@ export const createRuntimeExecutors = (
if (agentId && ctx.serverDB && ctx.userId) {
try {
const agentDocService = new AgentDocumentsService(ctx.serverDB, ctx.userId);
const docs = await agentDocService.getAgentDocuments(agentId);
const docs = await agentDocService.getAgentContextDocuments(agentId);
if (docs.length > 0) {
agentDocuments = toAgentContextDocuments(docs);
log('Resolved %d agent documents for agent %s', agentDocuments.length, agentId);
@@ -27,6 +27,56 @@ vi.mock('@/server/modules/ModelRuntime', () => ({
describe('aiChatRouter', () => {
const mockCtx = { userId: 'u1' };
const createMessageItem = (overrides: Record<string, any>) => ({
agentId: null,
clientId: null,
content: '',
createdAt: new Date('2024-01-01T00:00:00.000Z'),
error: null,
favorite: false,
id: 'm1',
metadata: null,
model: null,
observationId: null,
parentId: null,
provider: null,
quotaId: null,
reasoning: null,
role: 'user',
search: null,
sessionId: 's1',
threadId: null,
tools: null,
topicId: 't1',
traceId: null,
updatedAt: new Date('2024-01-01T00:00:00.000Z'),
userId: 'u1',
...overrides,
});
const createSimpleNewTopicTurnResult = (overrides: Record<string, any> = {}) => {
const { assistantMessage, userMessage, ...resultOverrides } = overrides;
return {
assistantMessage: createMessageItem({
content: 'loading',
id: 'm-assistant',
model: 'gpt-4o',
parentId: 'm-user',
provider: 'openai',
role: 'assistant',
...assistantMessage,
}),
resolvedSessionId: 's1',
topicId: 't1',
userMessage: createMessageItem({
content: 'hi',
id: 'm-user',
role: 'user',
...userMessage,
}),
...resultOverrides,
};
};
const mockMessageModel = (mockCreateMessage: ReturnType<typeof vi.fn>) => {
const mockCreateUserAndAssistantMessages = vi.fn(
async (
@@ -138,15 +188,85 @@ describe('aiChatRouter', () => {
expect(res.topics?.total).toBe(1);
});
it('should reuse existing topic when topicId provided', async () => {
const mockCreateMessage = vi
it('should skip messages and topics query for simple new topic payload', async () => {
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
const mockCreateMessage = vi.fn();
const mockCreateSimpleNewTopicTurn = vi
.fn()
.mockResolvedValueOnce({ id: 'm-user' })
.mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
.mockResolvedValue(createSimpleNewTopicTurnResult());
const mockGet = vi.fn();
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(
() =>
({
createSimpleNewTopicTurn: mockCreateSimpleNewTopicTurn,
getMessagesAndTopics: mockGet,
}) as any,
);
const caller = aiChatRouter.createCaller(mockCtx as any);
const res = await caller.sendMessageInServer({
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newTopic: { title: 'T' },
newUserMessage: { content: 'hi' },
sessionId: 's1',
} as any);
expect(mockCreateSimpleNewTopicTurn).toHaveBeenCalledWith(
expect.objectContaining({
assistantMessage: expect.objectContaining({
model: 'gpt-4o',
provider: 'openai',
}),
sessionId: 's1',
topic: expect.objectContaining({ title: 'T' }),
userMessage: expect.objectContaining({ content: 'hi' }),
}),
);
expect(mockCreateTopic).not.toHaveBeenCalled();
expect(mockCreateUserAndAssistantMessages).not.toHaveBeenCalled();
expect(mockGet).not.toHaveBeenCalled();
expect(res.messages).toEqual([
expect.objectContaining({
content: 'hi',
createdAt: new Date('2024-01-01T00:00:00.000Z').getTime(),
id: 'm-user',
role: 'user',
topicId: 't1',
}),
expect.objectContaining({
extra: { model: 'gpt-4o', provider: 'openai' },
id: 'm-assistant',
parentId: 'm-user',
role: 'assistant',
topicId: 't1',
}),
]);
expect(res.topics).toBeUndefined();
});
it('should reuse existing topic when topicId provided', async () => {
const mockCreateMessage = vi.fn();
const mockCreateSimpleExistingTopicTurn = vi.fn().mockResolvedValue(
createSimpleNewTopicTurnResult({
assistantMessage: { topicId: 't-exist' },
topicId: 't-exist',
userMessage: { topicId: 't-exist' },
}),
);
const mockGet = vi.fn();
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
vi.mocked(AiChatService).mockImplementation(
() =>
({
createSimpleExistingTopicTurn: mockCreateSimpleExistingTopicTurn,
getMessagesAndTopics: mockGet,
}) as any,
);
const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -157,18 +277,32 @@ describe('aiChatRouter', () => {
topicId: 't-exist',
} as any);
expect(mockCreateMessage).toHaveBeenCalled();
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({ touchTopicUpdatedAt: true }),
);
expect(mockGet).toHaveBeenCalledWith(
expect(mockCreateSimpleExistingTopicTurn).toHaveBeenCalledWith(
expect.objectContaining({
includeTopic: false,
assistantMessage: expect.objectContaining({
model: 'gpt-4o',
provider: 'openai',
}),
sessionId: 's1',
topicId: 't-exist',
userMessage: expect.objectContaining({ content: 'hi' }),
}),
);
expect(mockCreateUserAndAssistantMessages).not.toHaveBeenCalled();
expect(mockGet).not.toHaveBeenCalled();
expect(res.messages).toEqual([
expect.objectContaining({
id: 'm-user',
role: 'user',
topicId: 't-exist',
}),
expect.objectContaining({
id: 'm-assistant',
parentId: 'm-user',
role: 'assistant',
topicId: 't-exist',
}),
]);
expect(res.isCreateNewTopic).toBe(false);
expect(res.topicId).toBe('t-exist');
});
@@ -461,6 +595,7 @@ describe('aiChatRouter', () => {
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newUserMessage: { content: 'hi' },
sessionId: 's1',
threadId: 'thread-existing',
topicId: 't1',
} as any);
@@ -468,17 +603,23 @@ describe('aiChatRouter', () => {
});
describe('groupId support', () => {
it('should pass groupId to topic creation when both newTopic and groupId exist', async () => {
it('should pass groupId to simple new topic service when both newTopic and groupId exist', async () => {
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
const mockCreateMessage = vi
const mockCreateMessage = vi.fn();
const mockCreateSimpleNewTopicTurn = vi
.fn()
.mockResolvedValueOnce({ id: 'm-user' })
.mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
.mockResolvedValue(createSimpleNewTopicTurnResult());
const mockGet = vi.fn();
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(
() =>
({
createSimpleNewTopicTurn: mockCreateSimpleNewTopicTurn,
getMessagesAndTopics: mockGet,
}) as any,
);
const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -490,27 +631,34 @@ describe('aiChatRouter', () => {
sessionId: 's1',
} as any);
// Verify groupId is passed to topic creation
expect(mockCreateTopic).toHaveBeenCalledWith(
expect(mockCreateSimpleNewTopicTurn).toHaveBeenCalledWith(
expect.objectContaining({
groupId: 'group-123',
sessionId: 's1',
title: 'New Topic',
topic: expect.objectContaining({ title: 'New Topic' }),
}),
);
expect(mockCreateTopic).not.toHaveBeenCalled();
expect(mockCreateUserAndAssistantMessages).not.toHaveBeenCalled();
});
it('should set groupId to null when newTopic exists but groupId is not provided', async () => {
it('should pass undefined groupId to simple new topic service when groupId is not provided', async () => {
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
const mockCreateMessage = vi
const mockCreateMessage = vi.fn();
const mockCreateSimpleNewTopicTurn = vi
.fn()
.mockResolvedValueOnce({ id: 'm-user' })
.mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
.mockResolvedValue(createSimpleNewTopicTurnResult());
const mockGet = vi.fn();
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(
() =>
({
createSimpleNewTopicTurn: mockCreateSimpleNewTopicTurn,
getMessagesAndTopics: mockGet,
}) as any,
);
const caller = aiChatRouter.createCaller(mockCtx as any);
@@ -522,14 +670,15 @@ describe('aiChatRouter', () => {
sessionId: 's1',
} as any);
// Verify groupId is undefined (which will be treated as null in the database)
expect(mockCreateTopic).toHaveBeenCalledWith(
expect(mockCreateSimpleNewTopicTurn).toHaveBeenCalledWith(
expect.objectContaining({
groupId: undefined,
sessionId: 's1',
title: 'New Topic',
topic: expect.objectContaining({ title: 'New Topic' }),
}),
);
expect(mockCreateTopic).not.toHaveBeenCalled();
expect(mockCreateUserAndAssistantMessages).not.toHaveBeenCalled();
});
it('should pass groupId to both user and assistant message creation', async () => {
@@ -550,6 +699,7 @@ describe('aiChatRouter', () => {
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newUserMessage: { content: 'Analyze weather data' },
sessionId: 's1',
threadId: 'thread-123',
topicId: 't1',
} as any);
@@ -598,6 +748,7 @@ describe('aiChatRouter', () => {
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newUserMessage: { content: 'hi' },
sessionId: 's1',
threadId: 'thread-123',
topicId: 't1',
} as any);
@@ -630,6 +781,7 @@ describe('aiChatRouter', () => {
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newUserMessage: { content: 'hi' },
sessionId: 's1',
threadId: 'thread-123',
topicId: 't1',
} as any);
@@ -681,6 +833,7 @@ describe('aiChatRouter', () => {
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newUserMessage: { content: 'hi' },
sessionId: 's1',
threadId: 'thread-123',
topicId: 't1',
} as any);
@@ -717,18 +870,24 @@ describe('aiChatRouter', () => {
);
});
it('should pass agentId to topic creation when provided', async () => {
it('should pass agentId to simple new topic service when provided', async () => {
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
const mockCreateMessage = vi
const mockCreateMessage = vi.fn();
const mockCreateSimpleNewTopicTurn = vi
.fn()
.mockResolvedValueOnce({ id: 'm-user' })
.mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: [{}] });
.mockResolvedValue(createSimpleNewTopicTurnResult());
const mockGet = vi.fn();
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
vi.mocked(TopicModel).mockImplementation(() => ({ create: mockCreateTopic }) as any);
mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(
() =>
({
createSimpleNewTopicTurn: mockCreateSimpleNewTopicTurn,
getMessagesAndTopics: mockGet,
}) as any,
);
vi.mocked(AgentModel).mockImplementation(
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
);
@@ -743,14 +902,16 @@ describe('aiChatRouter', () => {
sessionId: 's1',
} as any);
// Verify agentId is passed to topic creation
expect(mockCreateTopic).toHaveBeenCalledWith(
expect(mockCreateSimpleNewTopicTurn).toHaveBeenCalledWith(
expect.objectContaining({
agentId: 'agent-1',
sessionId: 's1',
title: 'New Topic',
topic: expect.objectContaining({ title: 'New Topic' }),
}),
);
expect(mockCreateTopic).not.toHaveBeenCalled();
expect(mockCreateUserAndAssistantMessages).not.toHaveBeenCalled();
expect(mockTouchUpdatedAt).not.toHaveBeenCalled();
});
it('should touch agent updatedAt when creating new topic with agentId', async () => {
@@ -774,7 +935,7 @@ describe('aiChatRouter', () => {
await caller.sendMessageInServer({
agentId: 'agent-1',
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newTopic: { title: 'New Topic' },
newTopic: { title: 'New Topic', topicMessageIds: ['seed'] },
newUserMessage: { content: 'hi' },
sessionId: 's1',
} as any);
@@ -812,7 +973,7 @@ describe('aiChatRouter', () => {
const res = await caller.sendMessageInServer({
agentId: 'agent-1',
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newTopic: { title: 'New Topic' },
newTopic: { title: 'New Topic', topicMessageIds: ['seed'] },
newUserMessage: { content: 'hi' },
sessionId: 's1',
} as any);
@@ -820,6 +981,7 @@ describe('aiChatRouter', () => {
expect(res.userMessageId).toBe('m-user');
expect(res.assistantMessageId).toBe('m-assistant');
expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1');
await flushAsyncTasks();
expect(consoleErrorSpy).toHaveBeenCalledWith(
'[aiChat] Failed to touch agent updatedAt:',
touchError,
@@ -829,7 +991,7 @@ describe('aiChatRouter', () => {
}
});
it('should create messages while agent updatedAt touch is still pending', async () => {
it('should return the message response while agent updatedAt touch is still pending', async () => {
const mockCreateTopic = vi.fn().mockResolvedValue({ id: 't1' });
const mockCreateMessage = vi
.fn()
@@ -854,7 +1016,7 @@ describe('aiChatRouter', () => {
const request = caller.sendMessageInServer({
agentId: 'agent-1',
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newTopic: { title: 'New Topic' },
newTopic: { title: 'New Topic', topicMessageIds: ['seed'] },
newUserMessage: { content: 'hi' },
sessionId: 's1',
} as any);
@@ -864,6 +1026,12 @@ describe('aiChatRouter', () => {
try {
expect(mockTouchUpdatedAt).toHaveBeenCalledWith('agent-1');
expect(mockCreateUserAndAssistantMessages).toHaveBeenCalledTimes(1);
await expect(
Promise.race([
request.then(() => 'resolved' as const),
flushAsyncTasks().then(() => 'blocked' as const),
]),
).resolves.toBe('resolved');
} finally {
resolveTouchUpdatedAt();
}
@@ -892,7 +1060,7 @@ describe('aiChatRouter', () => {
await caller.sendMessageInServer({
// no agentId provided
newAssistantMessage: { model: 'gpt-4o', provider: 'openai' },
newTopic: { title: 'New Topic' },
newTopic: { title: 'New Topic', topicMessageIds: ['seed'] },
newUserMessage: { content: 'hi' },
sessionId: 's1',
} as any);
@@ -902,15 +1070,25 @@ describe('aiChatRouter', () => {
});
it('should not touch agent updatedAt when using existing topic', async () => {
const mockCreateMessage = vi
.fn()
.mockResolvedValueOnce({ id: 'm-user' })
.mockResolvedValueOnce({ id: 'm-assistant' });
const mockGet = vi.fn().mockResolvedValue({ messages: [], topics: undefined });
const mockCreateMessage = vi.fn();
const mockCreateSimpleExistingTopicTurn = vi.fn().mockResolvedValue(
createSimpleNewTopicTurnResult({
assistantMessage: { topicId: 't-exist' },
topicId: 't-exist',
userMessage: { topicId: 't-exist' },
}),
);
const mockGet = vi.fn();
const mockTouchUpdatedAt = vi.fn().mockResolvedValue(undefined);
mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(() => ({ getMessagesAndTopics: mockGet }) as any);
const mockCreateUserAndAssistantMessages = mockMessageModel(mockCreateMessage);
vi.mocked(AiChatService).mockImplementation(
() =>
({
createSimpleExistingTopicTurn: mockCreateSimpleExistingTopicTurn,
getMessagesAndTopics: mockGet,
}) as any,
);
vi.mocked(AgentModel).mockImplementation(
() => ({ touchUpdatedAt: mockTouchUpdatedAt }) as any,
);
@@ -926,6 +1104,9 @@ describe('aiChatRouter', () => {
} as any);
// Verify touchUpdatedAt was NOT called since no new topic was created
expect(mockCreateSimpleExistingTopicTurn).toHaveBeenCalled();
expect(mockCreateUserAndAssistantMessages).not.toHaveBeenCalled();
expect(mockGet).not.toHaveBeenCalled();
expect(mockTouchUpdatedAt).not.toHaveBeenCalled();
});
});
@@ -221,6 +221,15 @@ export const agentDocumentRouter = router({
return ctx.agentDocumentService.getAgentDocuments(input.agentId);
}),
/**
* Get documents for chat context injection.
*/
getContextDocuments: agentDocumentProcedure
.input(z.object({ agentId: z.string() }))
.query(async ({ ctx, input }) => {
return ctx.agentDocumentService.getAgentContextDocuments(input.agentId);
}),
/**
* Get a specific document by filename
*/
+250 -41
View File
@@ -1,6 +1,11 @@
import { randomUUID } from 'node:crypto';
import type { CreateMessageParams, SendMessageServerResponse } from '@lobechat/types';
import type {
CreateMessageParams,
DBMessageItem,
SendMessageServerResponse,
UIChatMessage,
} from '@lobechat/types';
import { AiSendMessageServerSchema, RequestTrigger, StructureOutputSchema } from '@lobechat/types';
import { createTimingHelpers, createTimingRequestId } from '@lobechat/utils';
import debug from 'debug';
@@ -20,9 +25,100 @@ import { FileService } from '@/server/services/file';
import { archiveToolResultIfNeeded } from '@/server/services/toolExecution/archiveToolResult';
const log = debug('lobe-lambda-router:ai-chat');
const { createPrefixedTimingContext, logTiming, runTimedStage } = createTimingHelpers(
'lobe-server:chat:lobehub:timing',
);
const { createPrefixedTimingContext, logTiming, markStageDone, runTimedStage } =
createTimingHelpers('lobe-server:chat:lobehub:timing');
type SendMessageServerResponseWithPartial = SendMessageServerResponse & {
__isPartialMessages?: boolean;
};
type CreatedMessageItem = DBMessageItem & {
editorData?: Record<string, any> | null;
groupId?: string | null;
targetId?: string | null;
usage?: UIChatMessage['usage'] | null;
};
const toCreatedUIChatMessage = ({
agentId,
content,
createdAt,
editorData,
error,
groupId,
id,
metadata,
model,
observationId,
parentId,
provider,
quotaId,
reasoning,
role,
search,
sessionId,
targetId,
threadId,
tools,
topicId,
traceId,
updatedAt,
usage,
}: CreatedMessageItem): UIChatMessage => ({
agentId: agentId ?? undefined,
content: content ?? '',
createdAt: createdAt instanceof Date ? createdAt.getTime() : Date.now(),
editorData,
error,
extra: { model: model ?? undefined, provider: provider ?? undefined },
groupId: groupId ?? undefined,
id,
metadata,
model,
observationId: observationId ?? undefined,
parentId: parentId ?? undefined,
provider,
quotaId: quotaId ?? undefined,
reasoning,
role: role as UIChatMessage['role'],
search,
sessionId: sessionId ?? undefined,
targetId: targetId ?? undefined,
threadId,
tools,
topicId: topicId ?? undefined,
traceId: traceId ?? undefined,
updatedAt: updatedAt instanceof Date ? updatedAt.getTime() : Date.now(),
usage: usage ?? undefined,
});
const canUseCreatedMessagesFastPath = (input: z.infer<typeof AiSendMessageServerSchema>) =>
!!input.newTopic &&
!input.topicId &&
!input.newTopic.topicMessageIds?.length &&
!input.newThread &&
!input.preloadMessages?.length &&
!input.newUserMessage.files?.length;
const canUseExistingTopicFastPath = (input: z.infer<typeof AiSendMessageServerSchema>) =>
!!input.topicId &&
!input.newTopic &&
!input.newThread &&
!input.threadId &&
!input.preloadMessages?.length &&
!input.newUserMessage.files?.length;
const getUserMessageMetadata = (
newUserMessage: z.infer<typeof AiSendMessageServerSchema>['newUserMessage'],
) =>
newUserMessage.metadata || newUserMessage.pageSelections?.length
? {
...newUserMessage.metadata,
...(newUserMessage.pageSelections?.length
? { pageSelections: newUserMessage.pageSelections }
: undefined),
}
: undefined;
const aiChatProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
const { ctx } = opts;
@@ -82,6 +178,16 @@ export const aiChatRouter = router({
input.newAssistantMessage.provider === 'lobehub'
? { requestId: createTimingRequestId(), startedAt: Date.now() }
: undefined;
const runServerPersistStage = async <T>(
stage: string,
task: () => T | Promise<T>,
metadata: Record<string, unknown> = {},
): Promise<Awaited<T>> => {
return runTimedStage(timingContext, `lambda.aiChat.${stage}`, task, metadata);
};
const logFastPathMessagesAndTopics = (metadata: Record<string, unknown>) => {
markStageDone(timingContext, 'lambda.aiChat.messagesAndTopics.fastResponse', metadata);
};
logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:start', {
hasNewThread: !!input.newThread,
hasNewTopic: !!input.newTopic,
@@ -96,11 +202,129 @@ export const aiChatRouter = router({
input.newTopic,
input.newThread,
);
if (canUseCreatedMessagesFastPath(input)) {
const result = await runServerPersistStage(
'simpleNewTopicTurn.create',
() =>
ctx.aiChatService.createSimpleNewTopicTurn({
agentId: input.agentId,
assistantMessage: {
content: LOADING_FLAT,
metadata: input.newAssistantMessage.metadata,
model: input.newAssistantMessage.model,
provider: input.newAssistantMessage.provider,
},
groupId: input.groupId,
sessionId: input.sessionId,
topic: {
metadata: input.newTopic!.metadata,
title: input.newTopic!.title,
trigger: input.newTopic!.trigger,
},
userMessage: {
content: input.newUserMessage.content,
editorData: input.newUserMessage.editorData,
metadata: getUserMessageMetadata(input.newUserMessage),
},
}),
{
hasAgentId: !!input.agentId,
hasGroupId: !!input.groupId,
hasSessionId: !!input.sessionId,
},
);
const messages = [
toCreatedUIChatMessage(result.userMessage as CreatedMessageItem),
toCreatedUIChatMessage(result.assistantMessage as CreatedMessageItem),
];
logFastPathMessagesAndTopics({
isCreateNewTopic: true,
messageCount: messages.length,
reason: 'simple-new-topic-turn',
topicCount: 0,
});
logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:done', {
isCreateNewTopic: true,
messageCount: messages.length,
topicCount: 0,
});
const response: SendMessageServerResponseWithPartial = {
assistantMessageId: result.assistantMessage.id,
isCreateNewTopic: true,
messages,
topicId: result.topicId,
userMessageId: result.userMessage.id,
};
return response;
}
if (canUseExistingTopicFastPath(input)) {
const result = await runServerPersistStage(
'simpleExistingTopicTurn.create',
() =>
ctx.aiChatService.createSimpleExistingTopicTurn({
agentId: input.agentId,
assistantMessage: {
content: LOADING_FLAT,
metadata: input.newAssistantMessage.metadata,
model: input.newAssistantMessage.model,
provider: input.newAssistantMessage.provider,
},
groupId: input.groupId,
sessionId: input.sessionId,
topicId: input.topicId!,
userMessage: {
content: input.newUserMessage.content,
editorData: input.newUserMessage.editorData,
metadata: getUserMessageMetadata(input.newUserMessage),
parentId: input.newUserMessage.parentId,
},
}),
{
hasAgentId: !!input.agentId,
hasGroupId: !!input.groupId,
hasParentId: !!input.newUserMessage.parentId,
hasSessionId: !!input.sessionId,
topicId: input.topicId,
},
);
const messages = [
toCreatedUIChatMessage(result.userMessage as CreatedMessageItem),
toCreatedUIChatMessage(result.assistantMessage as CreatedMessageItem),
];
logFastPathMessagesAndTopics({
isCreateNewTopic: false,
messageCount: messages.length,
reason: 'simple-existing-topic-turn',
topicCount: 0,
});
logTiming(timingContext, 'lambda.aiChat.sendMessageInServer:done', {
isCreateNewTopic: false,
messageCount: messages.length,
topicCount: 0,
});
const response: SendMessageServerResponseWithPartial = {
__isPartialMessages: true,
assistantMessageId: result.assistantMessage.id,
isCreateNewTopic: false,
messages,
topicId: result.topicId,
userMessageId: result.userMessage.id,
};
return response;
}
let sessionId = input.sessionId;
if (!sessionId) {
const context = await runTimedStage(
timingContext,
'lambda.aiChat.resolveContext',
const context = await runServerPersistStage(
'resolveContext',
() => resolveContext(input, ctx.serverDB, ctx.userId),
{ hasAgentId: !!input.agentId },
);
@@ -112,14 +336,12 @@ export const aiChatRouter = router({
let createdThreadId: string | undefined;
let isCreateNewTopic = false;
let agentTouchUpdatedAtTask: Promise<void> | undefined;
// create topic if there should be a new topic
if (input.newTopic) {
log('creating new topic with title: %s', input.newTopic.title);
const topicItem = await runTimedStage(
timingContext,
'lambda.aiChat.topic.create',
const topicItem = await runServerPersistStage(
'topic.create',
() => {
const payload = {
agentId: input.agentId,
@@ -149,9 +371,8 @@ export const aiChatRouter = router({
// update agent's updatedAt to reflect new activity
if (input.agentId) {
agentTouchUpdatedAtTask = runTimedStage(
timingContext,
'lambda.aiChat.agent.touchUpdatedAt',
void runServerPersistStage(
'agent.touchUpdatedAt',
async () => {
await ctx.agentModel.touchUpdatedAt(input.agentId!);
},
@@ -170,9 +391,8 @@ export const aiChatRouter = router({
input.newThread.sourceMessageId,
input.newThread.type,
);
const threadItem = await runTimedStage(
timingContext,
'lambda.aiChat.thread.create',
const threadItem = await runServerPersistStage(
'thread.create',
() =>
ctx.threadModel.create({
parentThreadId: input.newThread!.parentThreadId,
@@ -195,9 +415,8 @@ export const aiChatRouter = router({
if (input.preloadMessages?.length) {
log('creating %d preload messages before user message', input.preloadMessages.length);
parentId = await runTimedStage(
timingContext,
'lambda.aiChat.preloadMessages.create',
parentId = await runServerPersistStage(
'preloadMessages.create',
async () => {
let latestParentId = parentId;
for (const preloadMessage of input.preloadMessages!) {
@@ -235,19 +454,10 @@ export const aiChatRouter = router({
log('creating user message with content length: %d', input.newUserMessage.content.length);
// Build user message metadata with pageSelections if present
const userMessageMetadata =
input.newUserMessage.metadata || input.newUserMessage.pageSelections?.length
? {
...input.newUserMessage.metadata,
...(input.newUserMessage.pageSelections?.length
? { pageSelections: input.newUserMessage.pageSelections }
: undefined),
}
: undefined;
const userMessageMetadata = getUserMessageMetadata(input.newUserMessage);
const createMessagePairPromise = runTimedStage(
timingContext,
'lambda.aiChat.messages.createUserAndAssistant',
const createMessagePairPromise = runServerPersistStage(
'messages.createUserAndAssistant',
() => {
const userMessage = {
agentId: input.agentId,
@@ -294,9 +504,7 @@ export const aiChatRouter = router({
},
);
const { assistantMessage: assistantMessageItem, userMessage: userMessageItem } =
agentTouchUpdatedAtTask
? (await Promise.all([createMessagePairPromise, agentTouchUpdatedAtTask]))[0]
: await createMessagePairPromise;
await createMessagePairPromise;
const messageId = userMessageItem.id;
log('user message created with id: %s', messageId);
@@ -305,9 +513,8 @@ export const aiChatRouter = router({
// retrieve latest messages and topic with
log('retrieving messages and topics');
const { messages, topics } = await runTimedStage(
timingContext,
'lambda.aiChat.messagesAndTopics.query',
const { messages, topics } = await runServerPersistStage(
'messagesAndTopics.query',
() =>
ctx.aiChatService.getMessagesAndTopics({
agentId: input.agentId,
@@ -335,15 +542,17 @@ export const aiChatRouter = router({
topicCount: topics?.items?.length ?? 0,
});
return {
const response: SendMessageServerResponseWithPartial = {
assistantMessageId: assistantMessageItem.id,
createdThreadId,
isCreateNewTopic,
messages,
topicId,
topics,
topics: topics as SendMessageServerResponse['topics'],
userMessageId: messageId,
} as SendMessageServerResponse;
};
return response;
}),
archiveToolResult: aiChatProcedure
+164 -106
View File
@@ -26,6 +26,7 @@ vi.mock('@/database/models/agentDocuments', () => ({
BEFORE_FIRST_USER: 'before_first_user',
},
buildDocumentFilename: vi.fn(),
deriveAgentDocumentFields: vi.fn(() => ({})),
extractMarkdownH1Title: vi.fn((content: string) => ({ content })),
}));
@@ -91,8 +92,10 @@ describe('AgentDocumentsService', () => {
create: vi.fn(),
findById: vi.fn(),
findByAgent: vi.fn(),
findContextByAgent: vi.fn(),
findByDocumentIds: vi.fn(),
findByFilename: vi.fn(),
findSkillDocsByAgent: vi.fn(),
hasByAgent: vi.fn(),
rename: vi.fn(),
update: vi.fn(),
@@ -670,6 +673,70 @@ describe('AgentDocumentsService', () => {
});
});
describe('getAgentContextDocuments', () => {
it('should use the context-optimized model query and project only always-loaded docs', async () => {
mockModel.findContextByAgent.mockResolvedValue([
{
content: 'raw content',
contentCharCount: 11,
description: 'Always loaded',
editorData: { root: { children: [] } },
fileType: 'text/markdown',
filename: 'always.md',
id: 'always-doc',
isFolder: false,
loadRules: {},
metadata: { unused: true },
parentId: null,
policy: null,
policyLoad: 'always',
policyLoadFormat: 'raw',
policyLoadPosition: 'before-system',
sourceType: 'file',
templateId: null,
title: 'Always',
updatedAt: new Date('2026-01-01T00:00:00.000Z'),
userId: 'user-1',
},
{
content: '',
contentCharCount: 12_000,
description: null,
documentId: 'doc-2',
editorData: { root: { children: [{ text: 'unused' }] } },
fileType: 'text/markdown',
filename: 'progressive.md',
id: 'progressive-doc',
isFolder: false,
loadRules: {},
metadata: { unused: true },
parentId: null,
policy: null,
policyLoad: 'progressive',
policyLoadFormat: 'raw',
policyLoadPosition: 'before-system',
sourceType: 'file',
templateId: null,
title: 'Progressive',
updatedAt: new Date('2026-01-01T00:00:00.000Z'),
userId: 'user-1',
},
]);
const service = new AgentDocumentsService(db, userId);
const result = await service.getAgentContextDocuments('agent-1');
expect(mockModel.findContextByAgent).toHaveBeenCalledWith('agent-1');
expect(result).toMatchObject([
{ content: 'raw content', id: 'always-doc' },
{ content: '', contentCharCount: 12_000, id: 'progressive-doc' },
]);
expect(result[0]).not.toHaveProperty('editorData');
expect(result[0]).not.toHaveProperty('metadata');
expect(result[0]).not.toHaveProperty('userId');
});
});
describe('associateDocument', () => {
it('should delegate to agentDocumentModel.associate', async () => {
mockModel.associate.mockResolvedValue({ id: 'ad-1' });
@@ -697,51 +764,50 @@ describe('AgentDocumentsService', () => {
title: null,
...doc,
}));
const mockSkillDocs = (docs: Array<Partial<any>>) =>
mockModel.findSkillDocsByAgent.mockResolvedValue(stubDocs(docs));
it('returns an empty list when the agent has no skill bundles', async () => {
const service = new AgentDocumentsService(db, userId);
vi.spyOn(service, 'getAgentDocuments').mockResolvedValue(
stubDocs([
{ documentId: 'doc-1', filename: 'note.md', isSkillBundle: false },
{ documentId: 'doc-2', filename: 'web.md', isSkillBundle: false },
]),
);
mockSkillDocs([
{ documentId: 'doc-1', filename: 'note.md', isSkillBundle: false },
{ documentId: 'doc-2', filename: 'web.md', isSkillBundle: false },
]);
const service = new AgentDocumentsService(db, userId);
const result = await service.getAgentSkills('agent-1');
expect(service.getAgentDocuments).toHaveBeenCalledWith('agent-1');
expect(mockModel.findSkillDocsByAgent).toHaveBeenCalledWith('agent-1');
expect(mockModel.findByAgent).not.toHaveBeenCalled();
expect(result).toEqual([]);
});
it('prefixes the identifier with `agent-skills:` and pulls content from the SKILL.md index child', async () => {
const service = new AgentDocumentsService(db, userId);
vi.spyOn(service, 'getAgentDocuments').mockResolvedValue(
stubDocs([
{
content: '',
description: 'Triage workflow',
documentId: 'bundle-1',
filename: 'bug-triage',
isSkillBundle: true,
title: 'Bug Triage',
},
{
content: '# Bug triage\n\nbody',
documentId: 'index-1',
filename: 'SKILL.md',
isSkillIndex: true,
parentId: 'bundle-1',
},
// Sibling non-index child — must be ignored.
{
content: 'reference',
documentId: 'asset-1',
filename: 'reference.md',
parentId: 'bundle-1',
},
]),
);
mockSkillDocs([
{
content: '',
description: 'Triage workflow',
documentId: 'bundle-1',
filename: 'bug-triage',
isSkillBundle: true,
title: 'Bug Triage',
},
{
content: '# Bug triage\n\nbody',
documentId: 'index-1',
filename: 'SKILL.md',
isSkillIndex: true,
parentId: 'bundle-1',
},
// Sibling non-index child — must be ignored.
{
content: 'reference',
documentId: 'asset-1',
filename: 'reference.md',
parentId: 'bundle-1',
},
]);
const service = new AgentDocumentsService(db, userId);
const result = await service.getAgentSkills('agent-1');
expect(result).toEqual([
@@ -757,20 +823,18 @@ describe('AgentDocumentsService', () => {
});
it('falls back to the bundle row content when the index child is missing', async () => {
const service = new AgentDocumentsService(db, userId);
vi.spyOn(service, 'getAgentDocuments').mockResolvedValue(
stubDocs([
{
content: 'orphan body',
description: null,
documentId: 'orphan-1',
filename: 'orphan-skill',
isSkillBundle: true,
title: 'Orphan',
},
]),
);
mockSkillDocs([
{
content: 'orphan body',
description: null,
documentId: 'orphan-1',
filename: 'orphan-skill',
isSkillBundle: true,
title: 'Orphan',
},
]);
const service = new AgentDocumentsService(db, userId);
const result = await service.getAgentSkills('agent-1');
expect(result).toEqual([
@@ -786,19 +850,17 @@ describe('AgentDocumentsService', () => {
});
it('emits empty content for a bundle with no index child and no body', async () => {
const service = new AgentDocumentsService(db, userId);
vi.spyOn(service, 'getAgentDocuments').mockResolvedValue(
stubDocs([
{
content: '',
documentId: 'empty-1',
filename: 'empty',
isSkillBundle: true,
title: 'Empty',
},
]),
);
mockSkillDocs([
{
content: '',
documentId: 'empty-1',
filename: 'empty',
isSkillBundle: true,
title: 'Empty',
},
]);
const service = new AgentDocumentsService(db, userId);
const [skill] = await service.getAgentSkills('agent-1');
expect(skill.content).toBe('');
@@ -806,38 +868,36 @@ describe('AgentDocumentsService', () => {
});
it('returns one entry per skill bundle and ignores non-bundle docs', async () => {
const service = new AgentDocumentsService(db, userId);
vi.spyOn(service, 'getAgentDocuments').mockResolvedValue(
stubDocs([
{
documentId: 'b-1',
filename: 'one',
isSkillBundle: true,
title: 'One',
},
{
content: 'one body',
documentId: 'b-1-idx',
isSkillIndex: true,
parentId: 'b-1',
},
{
documentId: 'b-2',
filename: 'two',
isSkillBundle: true,
title: 'Two',
},
{
content: 'two body',
documentId: 'b-2-idx',
isSkillIndex: true,
parentId: 'b-2',
},
// Unrelated regular doc.
{ documentId: 'note', filename: 'note.md' },
]),
);
mockSkillDocs([
{
documentId: 'b-1',
filename: 'one',
isSkillBundle: true,
title: 'One',
},
{
content: 'one body',
documentId: 'b-1-idx',
isSkillIndex: true,
parentId: 'b-1',
},
{
documentId: 'b-2',
filename: 'two',
isSkillBundle: true,
title: 'Two',
},
{
content: 'two body',
documentId: 'b-2-idx',
isSkillIndex: true,
parentId: 'b-2',
},
// Unrelated regular doc.
{ documentId: 'note', filename: 'note.md' },
]);
const service = new AgentDocumentsService(db, userId);
const result = await service.getAgentSkills('agent-1');
expect(result.map((s) => s.identifier)).toEqual(['agent-skills:one', 'agent-skills:two']);
@@ -845,22 +905,20 @@ describe('AgentDocumentsService', () => {
});
it('matches index children strictly by parentId — does not leak across bundles', async () => {
const service = new AgentDocumentsService(db, userId);
vi.spyOn(service, 'getAgentDocuments').mockResolvedValue(
stubDocs([
{ documentId: 'b-1', filename: 'first', isSkillBundle: true },
{ documentId: 'b-2', filename: 'second', isSkillBundle: true },
// Only b-2 has an index child; b-1 must fall back to its own (empty)
// content rather than borrow b-2's content.
{
content: 'second body',
documentId: 'b-2-idx',
isSkillIndex: true,
parentId: 'b-2',
},
]),
);
mockSkillDocs([
{ documentId: 'b-1', filename: 'first', isSkillBundle: true },
{ documentId: 'b-2', filename: 'second', isSkillBundle: true },
// Only b-2 has an index child; b-1 must fall back to its own (empty)
// content rather than borrow b-2's content.
{
content: 'second body',
documentId: 'b-2-idx',
isSkillIndex: true,
parentId: 'b-2',
},
]);
const service = new AgentDocumentsService(db, userId);
const result = await service.getAgentSkills('agent-1');
expect(result).toHaveLength(2);
+48 -8
View File
@@ -3,15 +3,16 @@ import type {
DOCUMENT_TEMPLATES,
DocumentLoadRules,
DocumentTemplateSet,
PolicyLoad,
} from '@lobechat/agent-templates';
import { DocumentLoadPosition, getDocumentTemplate } from '@lobechat/agent-templates';
import { DocumentLoadPosition, getDocumentTemplate, PolicyLoad } from '@lobechat/agent-templates';
import { buildAgentSkillIdentifier } from '@lobechat/const';
import type { LobeChatDatabase } from '@lobechat/database';
import { DOCUMENT_FOLDER_TYPE } from '@lobechat/database/schemas';
import type {
AgentDocument,
AgentDocumentContextPayload,
AgentDocumentContextRow,
AgentDocumentWithRules,
ToolUpdateLoadRule,
} from '@/database/models/agentDocuments';
@@ -63,6 +64,10 @@ interface CreateAgentDocumentOptions {
}
type AgentDocumentWithLiteXML = AgentDocument & { litexml?: string };
type ProjectableAgentDocument = Pick<
AgentDocument,
'content' | 'editorData' | 'fileType' | 'templateId'
>;
/**
* Hide the auto-created `.tool-results/` archive (root folder + its children)
@@ -92,6 +97,26 @@ const excludeArchivedToolResults = <
);
};
const toAgentDocumentContextPayload = (
doc: AgentDocumentContextRow,
): AgentDocumentContextPayload => ({
content: doc.content,
contentCharCount: doc.contentCharCount,
description: doc.description,
filename: doc.filename,
id: doc.id,
isFolder: doc.isFolder,
loadRules: doc.loadRules,
policy: doc.policy,
policyLoad: doc.policyLoad,
policyLoadFormat: doc.policyLoadFormat,
policyLoadPosition: doc.policyLoadPosition,
sourceType: doc.sourceType,
templateId: doc.templateId,
title: doc.title,
updatedAt: doc.updatedAt,
});
/**
* Service for managing agent documents with reusable template sets.
* Document-level policy controls runtime behavior (context rendering/retrieval).
@@ -107,13 +132,11 @@ export class AgentDocumentsService {
this.topicDocumentModel = new TopicDocumentModel(db, userId);
}
private async projectDocumentContent<T extends AgentDocument | AgentDocumentWithRules>(
doc: T,
): Promise<T>;
private async projectDocumentContent<T extends AgentDocument | AgentDocumentWithRules>(
private async projectDocumentContent<T extends ProjectableAgentDocument>(doc: T): Promise<T>;
private async projectDocumentContent<T extends ProjectableAgentDocument>(
doc: T | undefined,
): Promise<T | undefined>;
private async projectDocumentContent<T extends AgentDocument | AgentDocumentWithRules>(
private async projectDocumentContent<T extends ProjectableAgentDocument>(
doc: T | undefined,
): Promise<T | undefined> {
if (!doc?.editorData) return doc;
@@ -274,6 +297,23 @@ export class AgentDocumentsService {
return this.projectDocuments(excludeArchivedToolResults(docs));
}
async getAgentContextDocuments(agentId: string): Promise<AgentDocumentContextPayload[]> {
const docs = excludeArchivedToolResults(
await this.agentDocumentModel.findContextByAgent(agentId),
);
const projectedDocs = await Promise.all(
docs.map(async (doc) => {
if (doc.policyLoad !== PolicyLoad.ALWAYS) return doc;
const projected = await this.projectDocumentContent(doc);
return { ...projected, ...deriveAgentDocumentFields(projected) };
}),
);
return projectedDocs.map(toAgentDocumentContextPayload);
}
/**
* Return this agent's skill-bundle documents in a shape ready for the
* homogeneous skill runtime: identifier is prefixed
@@ -295,7 +335,7 @@ export class AgentDocumentsService {
title: string | null;
}>
> {
const docs = await this.getAgentDocuments(agentId);
const docs = await this.agentDocumentModel.findSkillDocsByAgent(agentId);
const childrenByParent = new Map<string, AgentDocumentWithRules[]>();
for (const doc of docs) {
+1 -2
View File
@@ -1195,8 +1195,7 @@ export class AiAgentService {
) ?? false;
try {
const docs = await this.agentDocumentsService.getAgentDocuments(resolvedAgentId);
hasAgentDocuments = docs.length > 0;
hasAgentDocuments = await this.agentDocumentsService.hasDocuments(resolvedAgentId);
} catch {
// Agent documents check is non-critical
}
+289 -1
View File
@@ -1,8 +1,21 @@
// @vitest-environment node
import type { LobeChatDatabase } from '@lobechat/database';
import { describe, expect, it, vi } from 'vitest';
import { getTestDB } from '@lobechat/database/test-utils';
import { eq } from 'drizzle-orm';
import { beforeEach, describe, expect, it, vi } from 'vitest';
import { MessageModel } from '@/database/models/message';
import { TopicModel } from '@/database/models/topic';
import {
agents,
agentsToSessions,
chatGroups,
messages,
sessions,
threads,
topics,
users,
} from '@/database/schemas';
import { FileService } from '@/server/services/file';
import { AiChatService } from '.';
@@ -11,7 +24,282 @@ vi.mock('@/database/models/message');
vi.mock('@/database/models/topic');
vi.mock('@/server/services/file');
const userId = 'ai-chat-service-test-user';
const sessionId = 'ai-chat-service-session';
const agentId = 'ai-chat-service-agent';
const groupId = 'ai-chat-service-group';
const existingTopicId = 'ai-chat-service-topic';
const threadId = 'ai-chat-service-thread';
const serverDB: LobeChatDatabase = await getTestDB();
describe('AiChatService', () => {
const seedBase = async () => {
await serverDB.insert(users).values({ id: userId });
await serverDB.insert(sessions).values({ id: sessionId, title: 'Session', userId });
await serverDB.insert(agents).values({ id: agentId, title: 'Agent', userId });
await serverDB.insert(agentsToSessions).values({ agentId, sessionId, userId });
};
const seedGroup = async () => {
await serverDB.insert(chatGroups).values({ id: groupId, title: 'Group', userId });
};
const getMessagesByTopicId = async (topicId: string) => {
const rows = await serverDB.select().from(messages).where(eq(messages.topicId, topicId));
return rows.toSorted((a, b) => a.createdAt.getTime() - b.createdAt.getTime());
};
beforeEach(async () => {
vi.clearAllMocks();
await serverDB.delete(users);
});
it('createSimpleNewTopicTurn should persist the simple turn through the Drizzle CTE', async () => {
await seedBase();
const service = new AiChatService(serverDB, userId);
const res = await service.createSimpleNewTopicTurn({
agentId,
assistantMessage: {
content: 'loading',
metadata: {},
model: 'gpt-4o',
provider: 'openai',
},
topic: { title: 'T' },
userMessage: {
content: 'hi',
editorData: { type: 'doc' },
metadata: {},
},
});
const [createdTopic] = await serverDB.select().from(topics).where(eq(topics.id, res.topicId));
const createdMessages = await getMessagesByTopicId(res.topicId);
const [updatedAgent] = await serverDB.select().from(agents).where(eq(agents.id, agentId));
expect(res.topicId).toMatch(/^tpc_/);
expect(res.resolvedSessionId).toBe(sessionId);
expect(createdTopic).toEqual(
expect.objectContaining({
agentId,
sessionId,
title: 'T',
userId,
}),
);
expect(createdMessages).toHaveLength(2);
expect(res.userMessage).toEqual(
expect.objectContaining({
content: 'hi',
editorData: { type: 'doc' },
sessionId,
role: 'user',
topicId: res.topicId,
userId,
}),
);
expect(res.assistantMessage).toEqual(
expect.objectContaining({
content: 'loading',
model: 'gpt-4o',
parentId: res.userMessage.id,
provider: 'openai',
role: 'assistant',
sessionId,
}),
);
expect(createdMessages.map((message) => message.id)).toEqual([
res.userMessage.id,
res.assistantMessage.id,
]);
expect(updatedAgent.updatedAt.getTime()).toBeGreaterThan(updatedAgent.createdAt.getTime());
});
it('createSimpleNewTopicTurn should keep group messages detached from session rows', async () => {
await seedBase();
await seedGroup();
const service = new AiChatService(serverDB, userId);
const res = await service.createSimpleNewTopicTurn({
agentId,
assistantMessage: { content: 'loading' },
groupId,
topic: { title: 'T' },
userMessage: { content: 'hi' },
});
const createdMessages = await getMessagesByTopicId(res.topicId);
expect(res.resolvedSessionId).toBe(sessionId);
expect(res.userMessage.sessionId).toBeNull();
expect(res.assistantMessage.sessionId).toBeNull();
expect(createdMessages).toHaveLength(2);
expect(createdMessages).toEqual([
expect.objectContaining({ groupId, id: res.userMessage.id, sessionId: null }),
expect.objectContaining({ groupId, id: res.assistantMessage.id, sessionId: null }),
]);
});
it('createSimpleExistingTopicTurn should persist the simple turn through the Drizzle CTE', async () => {
await seedBase();
await serverDB.insert(topics).values({
agentId,
id: existingTopicId,
sessionId,
title: 'Existing Topic',
userId,
});
await serverDB.insert(messages).values({
content: 'parent',
id: 'm-parent',
role: 'user',
sessionId,
topicId: existingTopicId,
userId,
});
const service = new AiChatService(serverDB, userId);
const res = await service.createSimpleExistingTopicTurn({
agentId,
assistantMessage: {
content: 'loading',
metadata: {},
model: 'gpt-4o',
provider: 'openai',
},
topicId: existingTopicId,
userMessage: {
content: 'hi',
editorData: { type: 'doc' },
metadata: {},
parentId: 'm-parent',
},
});
const createdMessages = (await getMessagesByTopicId(existingTopicId)).filter(
(message) => message.id !== 'm-parent',
);
const [updatedTopic] = await serverDB
.select()
.from(topics)
.where(eq(topics.id, existingTopicId));
expect(res.topicId).toBe(existingTopicId);
expect(res.resolvedSessionId).toBe(sessionId);
expect(updatedTopic.updatedAt.getTime()).toBeGreaterThan(updatedTopic.createdAt.getTime());
expect(createdMessages).toHaveLength(2);
expect(res.userMessage).toEqual(
expect.objectContaining({
content: 'hi',
parentId: 'm-parent',
role: 'user',
sessionId,
topicId: existingTopicId,
}),
);
expect(res.assistantMessage).toEqual(
expect.objectContaining({
content: 'loading',
model: 'gpt-4o',
parentId: res.userMessage.id,
provider: 'openai',
role: 'assistant',
sessionId,
}),
);
});
it('createSimpleExistingTopicTurn should throw when the topic does not exist for the user', async () => {
await seedBase();
const service = new AiChatService(serverDB, userId);
await expect(
service.createSimpleExistingTopicTurn({
assistantMessage: { content: 'loading' },
topicId: 't1',
userMessage: { content: 'hi' },
}),
).rejects.toThrow('Failed to create simple existing topic turn');
});
it('createSimpleExistingTopicTurn should persist the thread id on both messages', async () => {
await seedBase();
await serverDB.insert(topics).values({
agentId,
id: existingTopicId,
sessionId,
title: 'Existing Topic',
userId,
});
await serverDB.insert(threads).values({
id: threadId,
title: 'Thread',
topicId: existingTopicId,
type: 'continuation',
userId,
});
const service = new AiChatService(serverDB, userId);
const res = await service.createSimpleExistingTopicTurn({
agentId,
assistantMessage: { content: 'loading' },
threadId,
topicId: existingTopicId,
userMessage: { content: 'hi' },
});
const createdMessages = await getMessagesByTopicId(existingTopicId);
expect(res.userMessage.threadId).toBe(threadId);
expect(res.assistantMessage.threadId).toBe(threadId);
expect(createdMessages).toEqual([
expect.objectContaining({ id: res.userMessage.id, threadId }),
expect.objectContaining({ id: res.assistantMessage.id, threadId }),
]);
});
it('createSimpleExistingTopicTurn should keep group messages detached from session rows', async () => {
await seedBase();
await seedGroup();
await serverDB.insert(topics).values({
agentId,
groupId,
id: existingTopicId,
sessionId,
title: 'Existing Topic',
userId,
});
const service = new AiChatService(serverDB, userId);
const res = await service.createSimpleExistingTopicTurn({
agentId,
assistantMessage: { content: 'loading' },
groupId,
topicId: existingTopicId,
userMessage: { content: 'hi' },
});
const createdMessages = await getMessagesByTopicId(existingTopicId);
expect(res.resolvedSessionId).toBe(sessionId);
expect(res.userMessage.sessionId).toBeNull();
expect(res.assistantMessage.sessionId).toBeNull();
expect(createdMessages).toHaveLength(2);
expect(createdMessages).toEqual([
expect.objectContaining({ groupId, id: res.userMessage.id, sessionId: null }),
expect.objectContaining({ groupId, id: res.assistantMessage.id, sessionId: null }),
]);
});
it('getMessagesAndTopics should fetch messages and topics concurrently', async () => {
const serverDB = {} as unknown as LobeChatDatabase;
+571
View File
@@ -1,9 +1,15 @@
import type { LobeChatDatabase } from '@lobechat/database';
import { idGenerator } from '@lobechat/database';
import type { CreateMessageParams, DBMessageItem } from '@lobechat/types';
import { createTimingHelpers } from '@lobechat/utils';
import { and, eq, sql } from 'drizzle-orm';
import { MessageModel } from '@/database/models/message';
import type { CreateTopicParams } from '@/database/models/topic';
import { TopicModel } from '@/database/models/topic';
import { agents, agentsToSessions, messages, topics } from '@/database/schemas';
import { FileService } from '@/server/services/file';
import { sanitizeNullBytes } from '@/utils/sanitizeNullBytes';
const { createPrefixedTimingContext, runTimedStage, toTimingContext } = createTimingHelpers(
'lobe-server:chat:lobehub:timing',
@@ -28,20 +34,585 @@ interface GetMessagesAndTopicsParams {
topicPageSize?: number;
}
interface SimpleTurnMessage extends DBMessageItem {
editorData?: CreateMessageParams['editorData'];
groupId?: string | null;
targetId?: string | null;
usage?: CreateMessageParams['usage'] | null;
}
interface SimpleTurnMessageRow extends Omit<SimpleTurnMessage, 'createdAt' | 'updatedAt'> {
createdAt: Date | string;
resolvedSessionId: string | null;
resolvedTopicId: string;
updatedAt: Date | string;
}
interface CreateSimpleNewTopicTurnParams {
agentId?: string | null;
assistantMessage: Pick<CreateMessageParams, 'metadata' | 'model' | 'provider'> & {
content: string;
};
groupId?: string | null;
sessionId?: string | null;
topic: Pick<CreateTopicParams, 'metadata' | 'title' | 'trigger'>;
touchAgentUpdatedAt?: boolean;
userMessage: Pick<CreateMessageParams, 'content' | 'editorData' | 'metadata'>;
}
interface CreateSimpleNewTopicTurnResult {
assistantMessage: SimpleTurnMessage;
resolvedSessionId: string | null;
topicId: string;
userMessage: SimpleTurnMessage;
}
interface CreateSimpleExistingTopicTurnParams {
agentId?: string | null;
assistantMessage: Pick<CreateMessageParams, 'metadata' | 'model' | 'provider'> & {
content: string;
};
groupId?: string | null;
sessionId?: string | null;
threadId?: string | null;
topicId: string;
userMessage: Pick<CreateMessageParams, 'content' | 'editorData' | 'metadata' | 'parentId'>;
}
interface CreateSimpleExistingTopicTurnResult {
assistantMessage: SimpleTurnMessage;
resolvedSessionId: string | null;
topicId: string;
userMessage: SimpleTurnMessage;
}
const stringifyJsonParam = (value: unknown) =>
value === undefined ? null : JSON.stringify(sanitizeNullBytes(value));
const toMessageItem = ({
createdAt,
resolvedSessionId: _resolvedSessionId,
resolvedTopicId: _resolvedTopicId,
updatedAt,
...message
}: SimpleTurnMessageRow): SimpleTurnMessage => ({
...message,
createdAt: createdAt instanceof Date ? createdAt : new Date(createdAt),
updatedAt: updatedAt instanceof Date ? updatedAt : new Date(updatedAt),
});
const getCreatedTurnMessages = (
rows: SimpleTurnMessageRow[],
userMessageId: string,
assistantMessageId: string,
) => {
const userMessage = rows.find((row) => row.id === userMessageId);
const assistantMessage = rows.find((row) => row.id === assistantMessageId);
return { assistantMessage, userMessage };
};
export class AiChatService {
private userId: string;
private serverDB: LobeChatDatabase;
private messageModel: MessageModel;
private fileService: FileService;
private topicModel: TopicModel;
constructor(serverDB: LobeChatDatabase, userId: string) {
this.userId = userId;
this.serverDB = serverDB;
this.messageModel = new MessageModel(serverDB, userId);
this.topicModel = new TopicModel(serverDB, userId);
this.fileService = new FileService(serverDB, userId);
}
async createSimpleNewTopicTurn({
agentId,
assistantMessage,
groupId,
sessionId,
topic,
touchAgentUpdatedAt = true,
userMessage,
}: CreateSimpleNewTopicTurnParams): Promise<CreateSimpleNewTopicTurnResult> {
const normalizedAgentId = agentId ?? null;
const normalizedGroupId = groupId ?? null;
const normalizedSessionId = sessionId ?? null;
const topicId = idGenerator('topics');
const userMessageId = idGenerator('messages');
const assistantMessageId = idGenerator('messages');
const createdAt = Date.now();
const userCreatedAt = new Date(createdAt);
const assistantCreatedAt = new Date(createdAt + 1);
const topicTitle = topic.title ?? null;
const topicTrigger = topic.trigger ?? null;
const userMetadata = stringifyJsonParam(userMessage.metadata);
const userEditorData = stringifyJsonParam(userMessage.editorData);
const assistantMetadata = stringifyJsonParam(assistantMessage.metadata);
const topicMetadata = stringifyJsonParam(topic.metadata);
const resolvedContext = this.serverDB.$with('resolved_context', {
resolvedSessionId: sql<string | null>`"resolvedSessionId"`.as('resolvedSessionId'),
}).as(sql`
SELECT COALESCE(
${normalizedSessionId}::text,
(
SELECT ${agentsToSessions.sessionId}
FROM ${agentsToSessions}
WHERE ${agentsToSessions.agentId} = ${normalizedAgentId}
AND ${agentsToSessions.userId} = ${this.userId}
LIMIT 1
)
)::text AS "resolvedSessionId"
`);
const createdTopic = this.serverDB.$with('created_topic').as(
this.serverDB
.insert(topics)
.select((qb) =>
qb
.select({
id: sql<string>`${topicId}::text`.as('id'),
title: sql<string | null>`${topicTitle}::text`.as('title'),
favorite: sql<boolean>`false`.as('favorite'),
sessionId: resolvedContext.resolvedSessionId,
content: sql<string | null>`NULL::text`.as('content'),
editorData: sql<unknown | null>`NULL::jsonb`.as('editorData'),
agentId: sql<string | null>`${normalizedAgentId}::text`.as('agentId'),
groupId: sql<string | null>`${normalizedGroupId}::text`.as('groupId'),
userId: sql<string>`${this.userId}::text`.as('userId'),
clientId: sql<string | null>`NULL::text`.as('clientId'),
description: sql<string | null>`NULL::text`.as('description'),
historySummary: sql<string | null>`NULL::text`.as('historySummary'),
metadata: sql<CreateTopicParams['metadata'] | null>`${topicMetadata}::jsonb`.as(
'metadata',
),
trigger: sql<CreateTopicParams['trigger'] | null>`${topicTrigger}::text`.as(
'trigger',
),
mode: sql<string | null>`NULL::text`.as('mode'),
status: sql<string | null>`NULL::text`.as('status'),
completedAt: sql<Date | null>`NULL::timestamp with time zone`.as('completedAt'),
totalCost: sql<number | null>`NULL::numeric`.as('totalCost'),
totalInputTokens: sql<number | null>`NULL::integer`.as('totalInputTokens'),
totalOutputTokens: sql<number | null>`NULL::integer`.as('totalOutputTokens'),
totalTokens: sql<number | null>`NULL::integer`.as('totalTokens'),
cost: sql<Record<string, unknown> | null>`NULL::jsonb`.as('cost'),
usage: sql<Record<string, unknown> | null>`NULL::jsonb`.as('usage'),
model: sql<string | null>`NULL::text`.as('model'),
provider: sql<string | null>`NULL::text`.as('provider'),
senderId: sql<string | null>`NULL::text`.as('senderId'),
accessedAt: sql<Date>`NOW()`.as('accessedAt'),
createdAt: sql<Date>`NOW()`.as('createdAt'),
updatedAt: sql<Date>`NOW()`.as('updatedAt'),
})
.from(resolvedContext),
)
.returning({ topicId: topics.id }),
);
const messagePayload = this.serverDB.$with('message_payload', {
payloadContent: sql<string>`"payloadContent"`.as('payloadContent'),
payloadCreatedAt: sql<Date>`"payloadCreatedAt"`.as('payloadCreatedAt'),
payloadEditorData: sql<CreateMessageParams['editorData'] | null>`"payloadEditorData"`.as(
'payloadEditorData',
),
payloadId: sql<string>`"payloadId"`.as('payloadId'),
payloadMetadata: sql<CreateMessageParams['metadata'] | null>`"payloadMetadata"`.as(
'payloadMetadata',
),
payloadModel: sql<string | null>`"payloadModel"`.as('payloadModel'),
payloadParentId: sql<string | null>`"payloadParentId"`.as('payloadParentId'),
payloadProvider: sql<string | null>`"payloadProvider"`.as('payloadProvider'),
payloadRole: sql<string>`"payloadRole"`.as('payloadRole'),
payloadUpdatedAt: sql<Date>`"payloadUpdatedAt"`.as('payloadUpdatedAt'),
}).as(sql`
SELECT *
FROM (
VALUES
(
${userMessageId}::text,
'user'::varchar,
${sanitizeNullBytes(userMessage.content)}::text,
${userEditorData}::jsonb,
${userMetadata}::jsonb,
NULL::text,
NULL::text,
NULL::text,
${userCreatedAt}::timestamp with time zone,
${userCreatedAt}::timestamp with time zone
),
(
${assistantMessageId}::text,
'assistant'::varchar,
${sanitizeNullBytes(assistantMessage.content)}::text,
NULL::jsonb,
${assistantMetadata}::jsonb,
${assistantMessage.model ?? null}::text,
${assistantMessage.provider ?? null}::text,
${userMessageId}::text,
${assistantCreatedAt}::timestamp with time zone,
${assistantCreatedAt}::timestamp with time zone
)
) AS "payload" (
"payloadId",
"payloadRole",
"payloadContent",
"payloadEditorData",
"payloadMetadata",
"payloadModel",
"payloadProvider",
"payloadParentId",
"payloadCreatedAt",
"payloadUpdatedAt"
)
`);
const createdMessages = this.serverDB.$with('created_messages').as(
this.serverDB
.insert(messages)
.select((qb) =>
qb
.select({
id: messagePayload.payloadId,
role: messagePayload.payloadRole,
content: messagePayload.payloadContent,
editorData: messagePayload.payloadEditorData,
summary: sql<string | null>`NULL::text`.as('summary'),
reasoning: sql<unknown | null>`NULL::jsonb`.as('reasoning'),
search: sql<unknown | null>`NULL::jsonb`.as('search'),
metadata: messagePayload.payloadMetadata,
usage: sql<CreateMessageParams['usage'] | null>`NULL::jsonb`.as('usage'),
model: messagePayload.payloadModel,
provider: messagePayload.payloadProvider,
favorite: sql<boolean>`false`.as('favorite'),
error: sql<unknown | null>`NULL::jsonb`.as('error'),
tools: sql<unknown | null>`NULL::jsonb`.as('tools'),
traceId: sql<string | null>`NULL::text`.as('traceId'),
observationId: sql<string | null>`NULL::text`.as('observationId'),
clientId: sql<string | null>`NULL::text`.as('clientId'),
userId: sql<string>`${this.userId}::text`.as('userId'),
sessionId: sql<string | null>`
CASE
WHEN ${normalizedGroupId}::text IS NOT NULL THEN NULL
ELSE ${resolvedContext.resolvedSessionId}
END
`.as('sessionId'),
topicId: createdTopic.topicId,
threadId: sql<string | null>`NULL::text`.as('threadId'),
parentId: messagePayload.payloadParentId,
quotaId: sql<string | null>`NULL::text`.as('quotaId'),
agentId: sql<string | null>`${normalizedAgentId}::text`.as('agentId'),
groupId: sql<string | null>`${normalizedGroupId}::text`.as('groupId'),
targetId: sql<string | null>`NULL::text`.as('targetId'),
messageGroupId: sql<string | null>`NULL::text`.as('messageGroupId'),
accessedAt: sql<Date>`NOW()`.as('accessedAt'),
createdAt: messagePayload.payloadCreatedAt,
updatedAt: messagePayload.payloadUpdatedAt,
})
.from(messagePayload)
.crossJoin(resolvedContext)
.crossJoin(createdTopic),
)
.returning(),
);
const touchedAgent = this.serverDB.$with('touched_agent').as(
this.serverDB
.update(agents)
// accessedAt has $onUpdate; keep it unchanged to preserve the previous raw SQL behavior.
.set({ accessedAt: agents.accessedAt, updatedAt: sql`NOW()` })
.where(
sql`${touchAgentUpdatedAt} AND ${normalizedAgentId}::text IS NOT NULL AND ${agents.id} = ${normalizedAgentId} AND ${agents.userId} = ${this.userId}`,
)
.returning({ id: agents.id }),
);
const rows = await this.serverDB
.with(resolvedContext, createdTopic, messagePayload, createdMessages, touchedAgent)
.select({
agentId: createdMessages.agentId,
clientId: createdMessages.clientId,
content: sql<SimpleTurnMessage['content']>`${createdMessages.content}`.as('content'),
createdAt: createdMessages.createdAt,
editorData: sql<SimpleTurnMessage['editorData']>`${createdMessages.editorData}`.as(
'editorData',
),
error: sql<SimpleTurnMessage['error']>`${createdMessages.error}`.as('error'),
favorite: createdMessages.favorite,
groupId: createdMessages.groupId,
id: createdMessages.id,
metadata: sql<SimpleTurnMessage['metadata']>`${createdMessages.metadata}`.as('metadata'),
model: createdMessages.model,
observationId: createdMessages.observationId,
parentId: createdMessages.parentId,
provider: createdMessages.provider,
quotaId: createdMessages.quotaId,
reasoning: sql<SimpleTurnMessage['reasoning']>`${createdMessages.reasoning}`.as(
'reasoning',
),
role: sql<SimpleTurnMessage['role']>`${createdMessages.role}`.as('role'),
search: sql<SimpleTurnMessage['search']>`${createdMessages.search}`.as('search'),
sessionId: createdMessages.sessionId,
targetId: createdMessages.targetId,
threadId: createdMessages.threadId,
tools: sql<SimpleTurnMessage['tools']>`${createdMessages.tools}`.as('tools'),
topicId: createdMessages.topicId,
traceId: createdMessages.traceId,
updatedAt: createdMessages.updatedAt,
usage: sql<SimpleTurnMessage['usage']>`${createdMessages.usage}`.as('usage'),
userId: createdMessages.userId,
resolvedSessionId: resolvedContext.resolvedSessionId,
resolvedTopicId: createdTopic.topicId,
})
.from(createdMessages)
.crossJoin(resolvedContext)
.crossJoin(createdTopic);
const { assistantMessage: assistantMessageRow, userMessage: userMessageRow } =
getCreatedTurnMessages(rows, userMessageId, assistantMessageId);
if (!userMessageRow || !assistantMessageRow) {
throw new Error('Failed to create simple new topic turn');
}
return {
assistantMessage: toMessageItem(assistantMessageRow),
resolvedSessionId: userMessageRow.resolvedSessionId,
topicId: userMessageRow.resolvedTopicId,
userMessage: toMessageItem(userMessageRow),
};
}
async createSimpleExistingTopicTurn({
agentId,
assistantMessage,
groupId,
sessionId,
threadId,
topicId,
userMessage,
}: CreateSimpleExistingTopicTurnParams): Promise<CreateSimpleExistingTopicTurnResult> {
const normalizedAgentId = agentId ?? null;
const normalizedGroupId = groupId ?? null;
const normalizedSessionId = sessionId ?? null;
const normalizedThreadId = threadId ?? null;
const userParentId = userMessage.parentId ?? null;
const userMessageId = idGenerator('messages');
const assistantMessageId = idGenerator('messages');
const createdAt = Date.now();
const userCreatedAt = new Date(createdAt);
const assistantCreatedAt = new Date(createdAt + 1);
const userMetadata = stringifyJsonParam(userMessage.metadata);
const userEditorData = stringifyJsonParam(userMessage.editorData);
const assistantMetadata = stringifyJsonParam(assistantMessage.metadata);
const existingTopic = this.serverDB.$with('existing_topic').as(
this.serverDB
.select({
existingSessionId: topics.sessionId,
existingTopicId: topics.id,
})
.from(topics)
.where(and(eq(topics.id, topicId), eq(topics.userId, this.userId)))
.limit(1),
);
const resolvedContext = this.serverDB.$with('resolved_context').as(
this.serverDB
.select({
resolvedSessionId: sql<string | null>`
COALESCE(
${normalizedSessionId}::text,
${existingTopic.existingSessionId},
(
SELECT ${agentsToSessions.sessionId}
FROM ${agentsToSessions}
WHERE ${agentsToSessions.agentId} = ${normalizedAgentId}
AND ${agentsToSessions.userId} = ${this.userId}
LIMIT 1
)
)::text
`.as('resolvedSessionId'),
resolvedTopicId: existingTopic.existingTopicId,
})
.from(existingTopic),
);
const updatedTopic = this.serverDB.$with('updated_topic').as(
this.serverDB
.update(topics)
// accessedAt has $onUpdate; keep it unchanged to preserve the previous raw SQL behavior.
.set({ accessedAt: topics.accessedAt, updatedAt: sql`NOW()` })
.from(resolvedContext)
.where(and(eq(topics.id, resolvedContext.resolvedTopicId), eq(topics.userId, this.userId)))
.returning({ topicId: topics.id }),
);
const messagePayload = this.serverDB.$with('message_payload', {
payloadContent: sql<string>`"payloadContent"`.as('payloadContent'),
payloadCreatedAt: sql<Date>`"payloadCreatedAt"`.as('payloadCreatedAt'),
payloadEditorData: sql<CreateMessageParams['editorData'] | null>`"payloadEditorData"`.as(
'payloadEditorData',
),
payloadId: sql<string>`"payloadId"`.as('payloadId'),
payloadMetadata: sql<CreateMessageParams['metadata'] | null>`"payloadMetadata"`.as(
'payloadMetadata',
),
payloadModel: sql<string | null>`"payloadModel"`.as('payloadModel'),
payloadParentId: sql<string | null>`"payloadParentId"`.as('payloadParentId'),
payloadProvider: sql<string | null>`"payloadProvider"`.as('payloadProvider'),
payloadRole: sql<string>`"payloadRole"`.as('payloadRole'),
payloadUpdatedAt: sql<Date>`"payloadUpdatedAt"`.as('payloadUpdatedAt'),
}).as(sql`
SELECT *
FROM (
VALUES
(
${userMessageId}::text,
'user'::varchar,
${sanitizeNullBytes(userMessage.content)}::text,
${userEditorData}::jsonb,
${userMetadata}::jsonb,
NULL::text,
NULL::text,
${userParentId}::text,
${userCreatedAt}::timestamp with time zone,
${userCreatedAt}::timestamp with time zone
),
(
${assistantMessageId}::text,
'assistant'::varchar,
${sanitizeNullBytes(assistantMessage.content)}::text,
NULL::jsonb,
${assistantMetadata}::jsonb,
${assistantMessage.model ?? null}::text,
${assistantMessage.provider ?? null}::text,
${userMessageId}::text,
${assistantCreatedAt}::timestamp with time zone,
${assistantCreatedAt}::timestamp with time zone
)
) AS "payload" (
"payloadId",
"payloadRole",
"payloadContent",
"payloadEditorData",
"payloadMetadata",
"payloadModel",
"payloadProvider",
"payloadParentId",
"payloadCreatedAt",
"payloadUpdatedAt"
)
`);
const createdMessages = this.serverDB.$with('created_messages').as(
this.serverDB
.insert(messages)
.select((qb) =>
qb
.select({
id: messagePayload.payloadId,
role: messagePayload.payloadRole,
content: messagePayload.payloadContent,
editorData: messagePayload.payloadEditorData,
summary: sql<string | null>`NULL::text`.as('summary'),
reasoning: sql<unknown | null>`NULL::jsonb`.as('reasoning'),
search: sql<unknown | null>`NULL::jsonb`.as('search'),
metadata: messagePayload.payloadMetadata,
usage: sql<CreateMessageParams['usage'] | null>`NULL::jsonb`.as('usage'),
model: messagePayload.payloadModel,
provider: messagePayload.payloadProvider,
favorite: sql<boolean>`false`.as('favorite'),
error: sql<unknown | null>`NULL::jsonb`.as('error'),
tools: sql<unknown | null>`NULL::jsonb`.as('tools'),
traceId: sql<string | null>`NULL::text`.as('traceId'),
observationId: sql<string | null>`NULL::text`.as('observationId'),
clientId: sql<string | null>`NULL::text`.as('clientId'),
userId: sql<string>`${this.userId}::text`.as('userId'),
sessionId: sql<string | null>`
CASE
WHEN ${normalizedGroupId}::text IS NOT NULL THEN NULL
ELSE ${resolvedContext.resolvedSessionId}
END
`.as('sessionId'),
topicId: updatedTopic.topicId,
threadId: sql<string | null>`${normalizedThreadId}::text`.as('threadId'),
parentId: messagePayload.payloadParentId,
quotaId: sql<string | null>`NULL::text`.as('quotaId'),
agentId: sql<string | null>`${normalizedAgentId}::text`.as('agentId'),
groupId: sql<string | null>`${normalizedGroupId}::text`.as('groupId'),
targetId: sql<string | null>`NULL::text`.as('targetId'),
messageGroupId: sql<string | null>`NULL::text`.as('messageGroupId'),
accessedAt: sql<Date>`NOW()`.as('accessedAt'),
createdAt: messagePayload.payloadCreatedAt,
updatedAt: messagePayload.payloadUpdatedAt,
})
.from(messagePayload)
.crossJoin(resolvedContext)
.crossJoin(updatedTopic),
)
.returning(),
);
const rows = await this.serverDB
.with(existingTopic, resolvedContext, updatedTopic, messagePayload, createdMessages)
.select({
agentId: createdMessages.agentId,
clientId: createdMessages.clientId,
content: sql<SimpleTurnMessage['content']>`${createdMessages.content}`.as('content'),
createdAt: createdMessages.createdAt,
editorData: sql<SimpleTurnMessage['editorData']>`${createdMessages.editorData}`.as(
'editorData',
),
error: sql<SimpleTurnMessage['error']>`${createdMessages.error}`.as('error'),
favorite: createdMessages.favorite,
groupId: createdMessages.groupId,
id: createdMessages.id,
metadata: sql<SimpleTurnMessage['metadata']>`${createdMessages.metadata}`.as('metadata'),
model: createdMessages.model,
observationId: createdMessages.observationId,
parentId: createdMessages.parentId,
provider: createdMessages.provider,
quotaId: createdMessages.quotaId,
reasoning: sql<SimpleTurnMessage['reasoning']>`${createdMessages.reasoning}`.as(
'reasoning',
),
role: sql<SimpleTurnMessage['role']>`${createdMessages.role}`.as('role'),
search: sql<SimpleTurnMessage['search']>`${createdMessages.search}`.as('search'),
sessionId: createdMessages.sessionId,
targetId: createdMessages.targetId,
threadId: createdMessages.threadId,
tools: sql<SimpleTurnMessage['tools']>`${createdMessages.tools}`.as('tools'),
topicId: createdMessages.topicId,
traceId: createdMessages.traceId,
updatedAt: createdMessages.updatedAt,
usage: sql<SimpleTurnMessage['usage']>`${createdMessages.usage}`.as('usage'),
userId: createdMessages.userId,
resolvedSessionId: resolvedContext.resolvedSessionId,
resolvedTopicId: updatedTopic.topicId,
})
.from(createdMessages)
.crossJoin(resolvedContext)
.crossJoin(updatedTopic);
const { assistantMessage: assistantMessageRow, userMessage: userMessageRow } =
getCreatedTurnMessages(rows, userMessageId, assistantMessageId);
if (!userMessageRow || !assistantMessageRow) {
throw new Error('Failed to create simple existing topic turn');
}
return {
assistantMessage: toMessageItem(assistantMessageRow),
resolvedSessionId: userMessageRow.resolvedSessionId,
topicId: userMessageRow.resolvedTopicId,
userMessage: toMessageItem(userMessageRow),
};
}
async getMessagesAndTopics(params: GetMessagesAndTopicsParams) {
const { topicFilter, topicPageSize, timingRequestId, timingStartedAt, ...messageParams } =
params;
+16 -1
View File
@@ -3,6 +3,17 @@ import { type PartialDeep } from 'type-fest';
import { lambdaClient } from '@/libs/trpc/client';
export const AVAILABLE_AGENTS_CONTEXT_LIMIT = 10;
export const AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT = AVAILABLE_AGENTS_CONTEXT_LIMIT + 2;
export interface AvailableAgentItem {
avatar: string | null;
backgroundColor: string | null;
description: string | null;
id: string;
title: string | null;
}
/**
* Market agent model can be either a string or an object with model details
*/
@@ -211,7 +222,11 @@ class AgentService {
* Query non-virtual agents with optional keyword filter.
* Returns agents with minimal info (id, title, description, avatar, backgroundColor).
*/
queryAgents = async (params?: { keyword?: string; limit?: number; offset?: number }) => {
queryAgents = async (params?: {
keyword?: string;
limit?: number;
offset?: number;
}): Promise<AvailableAgentItem[]> => {
return lambdaClient.agent.queryAgents.query(params);
};
+15 -7
View File
@@ -4,7 +4,8 @@ import { mutate } from '@/libs/swr';
import { agentDocumentService, resolveAgentDocumentsContext } from './agentDocument';
const { queryMock } = vi.hoisted(() => ({
const { contextDocumentsQueryMock, queryMock } = vi.hoisted(() => ({
contextDocumentsQueryMock: vi.fn(),
queryMock: vi.fn(),
}));
@@ -17,6 +18,7 @@ vi.mock('@/libs/trpc/client', () => ({
agentDocument: {
copyDocument: { mutate: queryMock },
createDocument: { mutate: queryMock },
getContextDocuments: { query: contextDocumentsQueryMock },
getDocuments: { query: queryMock },
getTemplates: { query: queryMock },
initializeFromTemplate: { mutate: queryMock },
@@ -33,8 +35,10 @@ vi.mock('@/libs/trpc/client', () => ({
describe('AgentDocumentService', () => {
beforeEach(() => {
queryMock.mockResolvedValue({ ok: true });
contextDocumentsQueryMock.mockResolvedValue({ ok: true });
vi.mocked(mutate).mockClear();
queryMock.mockClear();
contextDocumentsQueryMock.mockClear();
});
afterEach(() => {
@@ -79,12 +83,13 @@ describe('AgentDocumentService', () => {
});
it('should fetch target agent documents when cache is missing', async () => {
queryMock.mockResolvedValueOnce([
contextDocumentsQueryMock.mockResolvedValueOnce([
{
content: 'Target agent setup',
contentCharCount: 'Target agent setup'.length,
filename: 'setup.md',
id: 'doc-1',
loadRules: [],
loadRules: {},
policy: null,
policyLoadFormat: null,
policyLoadPosition: null,
@@ -98,19 +103,21 @@ describe('AgentDocumentService', () => {
agentId: 'target-agent',
}),
).resolves.toEqual([
{
expect.objectContaining({
content: 'Target agent setup',
contentCharCount: 'Target agent setup'.length,
filename: 'setup.md',
id: 'doc-1',
loadPosition: undefined,
loadRules: [],
loadRules: {},
policyId: null,
policyLoadFormat: undefined,
title: 'Setup',
},
}),
]);
expect(queryMock).toHaveBeenCalledWith({ agentId: 'target-agent' });
expect(contextDocumentsQueryMock).toHaveBeenCalledWith({ agentId: 'target-agent' });
expect(queryMock).not.toHaveBeenCalled();
});
it('should reuse cached agent documents without refetching', async () => {
@@ -130,6 +137,7 @@ describe('AgentDocumentService', () => {
}),
).resolves.toBe(cachedDocuments);
expect(contextDocumentsQueryMock).not.toHaveBeenCalled();
expect(queryMock).not.toHaveBeenCalled();
});
});
+5 -1
View File
@@ -45,6 +45,10 @@ class AgentDocumentService {
return lambdaClient.agentDocument.getDocuments.query(params);
};
getContextDocuments = async (params: { agentId: string }) => {
return lambdaClient.agentDocument.getContextDocuments.query(params);
};
initializeFromTemplate = async (params: { agentId: string; templateSet: string }) => {
const result = await lambdaClient.agentDocument.initializeFromTemplate.mutate(params);
await revalidateAgentDocuments(params.agentId);
@@ -282,7 +286,7 @@ export const resolveAgentDocumentsContext = async (params: {
if (cachedDocuments !== undefined) return cachedDocuments;
if (!agentId) return undefined;
const documents = await agentDocumentService.getDocuments({ agentId });
const documents = await agentDocumentService.getContextDocuments({ agentId });
return toAgentContextDocuments(documents);
};
+10 -6
View File
@@ -1584,7 +1584,7 @@ describe('ChatService', () => {
.spyOn(mechaModule, 'contextEngineering')
.mockResolvedValue([]);
vi.spyOn(chatService, 'getChatCompletion').mockResolvedValue(new Response(''));
vi.spyOn(agentDocumentService, 'getDocuments').mockResolvedValue([
vi.spyOn(agentDocumentService, 'getContextDocuments').mockResolvedValue([
{
content: 'Project setup steps',
filename: 'setup.md',
@@ -1604,7 +1604,9 @@ describe('ChatService', () => {
resolvedAgentConfig: createMockResolvedConfig(),
});
expect(agentDocumentService.getDocuments).toHaveBeenCalledWith({ agentId: 'agent-1' });
expect(agentDocumentService.getContextDocuments).toHaveBeenCalledWith({
agentId: 'agent-1',
});
expect(contextEngineeringSpy).toHaveBeenCalledWith(
expect.objectContaining({
agentDocuments: [
@@ -1623,7 +1625,7 @@ describe('ChatService', () => {
.spyOn(mechaModule, 'contextEngineering')
.mockResolvedValue([]);
vi.spyOn(chatService, 'getChatCompletion').mockResolvedValue(new Response(''));
vi.spyOn(agentDocumentService, 'getDocuments').mockResolvedValue([
vi.spyOn(agentDocumentService, 'getContextDocuments').mockResolvedValue([
{
content: 'Edited agent setup',
filename: 'builder-target.md',
@@ -1647,7 +1649,9 @@ describe('ChatService', () => {
}),
});
expect(agentDocumentService.getDocuments).toHaveBeenCalledWith({ agentId: 'edited-agent' });
expect(agentDocumentService.getContextDocuments).toHaveBeenCalledWith({
agentId: 'edited-agent',
});
expect(contextEngineeringSpy).toHaveBeenCalledWith(
expect.objectContaining({
agentDocuments: [
@@ -1827,7 +1831,7 @@ describe('ChatService', () => {
describe('fetchPresetTaskResult', () => {
it('should not wait for agent documents on preset task chains', async () => {
vi.spyOn(chatService, 'getChatCompletion').mockResolvedValue(new Response(''));
vi.spyOn(agentDocumentService, 'getDocuments').mockResolvedValue([]);
vi.spyOn(agentDocumentService, 'getContextDocuments').mockResolvedValue([]);
await chatService.fetchPresetTaskResult({
abortController: new AbortController(),
@@ -1838,7 +1842,7 @@ describe('ChatService', () => {
},
});
expect(agentDocumentService.getDocuments).not.toHaveBeenCalled();
expect(agentDocumentService.getContextDocuments).not.toHaveBeenCalled();
});
it('should handle successful chat completion response', async () => {
@@ -1,8 +1,10 @@
import { type UIChatMessage } from '@lobechat/types';
import { afterEach, describe, expect, it, vi } from 'vitest';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import * as isCanUseFCModule from '@/helpers/isCanUseFC';
import { agentService } from '@/services/agent';
import { agentDocumentService } from '@/services/agentDocument';
import { useAgentStore } from '@/store/agent';
import * as helpers from '../helper';
import { contextEngineering } from './contextEngineering';
@@ -46,6 +48,14 @@ vi.mock('@/services/agentDocument', () => ({
},
}));
vi.mock('@/services/agent', () => ({
AVAILABLE_AGENTS_CONTEXT_LIMIT: 10,
AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT: 12,
agentService: {
queryAgents: vi.fn(),
},
}));
// 默认设置 isServerMode 为 false
let isServerMode = false;
@@ -61,6 +71,14 @@ vi.mock('@lobechat/const', async (importOriginal) => {
};
});
beforeEach(() => {
vi.mocked(agentService.queryAgents).mockResolvedValue([]);
useAgentStore.setState({
agentMap: {},
availableAgents: undefined,
});
});
afterEach(() => {
vi.resetModules();
vi.clearAllMocks();
@@ -125,6 +143,47 @@ describe('contextEngineering', () => {
});
});
it('should use cached available agents without querying during context engineering', async () => {
useAgentStore.setState({
availableAgents: [
{
avatar: null,
backgroundColor: null,
description: null,
id: 'agent-1',
title: 'Current Agent',
},
{
avatar: null,
backgroundColor: null,
description: 'Helps with setup',
id: 'agent-2',
title: 'Setup Agent',
},
],
});
await contextEngineering({
agentId: 'agent-1',
messages: [{ content: 'Hello', role: 'user' }] as UIChatMessage[],
model: 'gpt-4',
provider: 'openai',
});
expect(agentService.queryAgents).not.toHaveBeenCalled();
});
it('should query available agents when the prefetch cache is missing', async () => {
await contextEngineering({
agentId: 'agent-1',
messages: [{ content: 'Hello', role: 'user' }] as UIChatMessage[],
model: 'gpt-4',
provider: 'openai',
});
expect(agentService.queryAgents).toHaveBeenCalledWith({ limit: 12 });
});
describe('handle with files content in server mode', () => {
it('should includes files', async () => {
isServerMode = true;
+26 -24
View File
@@ -40,6 +40,11 @@ import debug from 'debug';
import { isCanUseFC } from '@/helpers/isCanUseFC';
import { VARIABLE_GENERATORS } from '@/helpers/parserPlaceholder';
import { lambdaClient } from '@/libs/trpc/client';
import {
agentService,
AVAILABLE_AGENTS_CONTEXT_LIMIT,
AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT,
} from '@/services/agent';
import { notebookService } from '@/services/notebook';
import { getAgentStoreState } from '@/store/agent';
import { agentChatConfigSelectors, agentSelectors } from '@/store/agent/selectors';
@@ -457,13 +462,9 @@ export const contextEngineering = async ({
if (shouldInjectAvailableAgents) {
try {
// Over-fetch by 2: +1 reserved for the current agent (filtered out below
// so the model has no exposure to its own id and cannot self-delegate)
// and +1 to detect overflow for the `hasMore` flag.
const AVAILABLE_AGENTS_LIMIT = 10;
const recentAgents = await lambdaClient.agent.queryAgents.query({
limit: AVAILABLE_AGENTS_LIMIT + 2,
});
const recentAgents =
agentStoreState.availableAgents ??
(await agentService.queryAgents({ limit: AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT }));
// Exclude current agent from `availableAgents`. The model is the current
// agent — its identity/persona is already established by `systemRole`, so
@@ -471,8 +472,8 @@ export const contextEngineering = async ({
// model never sees its own id in the agent-management context (so it
// cannot accidentally call itself via `callAgent`).
const otherAgents = agentId ? recentAgents.filter((a) => a.id !== agentId) : recentAgents;
const hasMoreAgents = otherAgents.length > AVAILABLE_AGENTS_LIMIT;
const availableAgents = otherAgents.slice(0, AVAILABLE_AGENTS_LIMIT).map((a) => ({
const hasMoreAgents = otherAgents.length > AVAILABLE_AGENTS_CONTEXT_LIMIT;
const availableAgents = otherAgents.slice(0, AVAILABLE_AGENTS_CONTEXT_LIMIT).map((a) => ({
description: a.description ?? undefined,
id: a.id,
title: a.title ?? 'Untitled',
@@ -605,21 +606,22 @@ export const contextEngineering = async ({
}
// Resolve topic references from messages containing <refer_topic> tags
const topicReferences = await resolveTopicReferences(
messages,
async (topicId: string) => {
const topic = topicSelectors.getTopicById(topicId)(getChatStoreState());
return topic ?? null;
},
async (topicId: string) => {
const { messageService } = await import('@/services/message');
const msgs = await messageService.getMessages({ agentId, groupId, topicId });
return msgs.map((m) => ({
content: typeof m.content === 'string' ? m.content : '',
role: m.role,
}));
},
);
const topicReferences =
(await resolveTopicReferences(
messages,
async (topicId: string) => {
const topic = topicSelectors.getTopicById(topicId)(getChatStoreState());
return topic ?? null;
},
async (topicId: string) => {
const { messageService } = await import('@/services/message');
const msgs = await messageService.getMessages({ agentId, groupId, topicId });
return msgs.map((m) => ({
content: typeof m.content === 'string' ? m.content : '',
role: m.role,
}));
},
)) ?? [];
// Build onboarding context if this is the web-onboarding agent.
// Single combined trpc call — server runs state/soul/persona DB queries in parallel.
+107 -27
View File
@@ -3,7 +3,7 @@ import { act, renderHook, waitFor } from '@testing-library/react';
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import { agentService } from '@/services/agent';
import { agentDocumentService } from '@/services/agentDocument';
import { resolveAgentDocumentsContext } from '@/services/agentDocument';
import { type LobeAgentConfig } from '@/types/agent';
import { withSWR } from '~test-utils';
@@ -14,9 +14,12 @@ vi.mock('zustand/traditional');
// Mock agentService
vi.mock('@/services/agent', () => ({
AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT: 12,
agentService: {
createAgent: vi.fn(),
getAgentConfigById: vi.fn(),
getSessionConfig: vi.fn(),
queryAgents: vi.fn(),
updateAgentConfig: vi.fn(),
updateAgentMeta: vi.fn(),
},
@@ -26,23 +29,7 @@ vi.mock('@/services/agentDocument', () => ({
agentDocumentSWRKeys: {
documents: (agentId: string) => ['agent-documents', agentId] as const,
},
agentDocumentService: {
getDocuments: vi.fn(),
},
}));
vi.mock('@/utils/agentDocumentContextMapping', () => ({
toAgentContextDocuments: (documents: any[]) =>
documents.map((doc) => ({
content: doc.content,
filename: doc.filename,
id: doc.id,
loadPosition: undefined,
loadRules: doc.loadRules,
policyId: doc.templateId,
policyLoadFormat: undefined,
title: doc.title,
})),
resolveAgentDocumentsContext: vi.fn(),
}));
// Mock sessionStore
@@ -69,6 +56,7 @@ beforeEach(() => {
activeAgentId: undefined,
agentMap: {},
builtinAgentIdMap: {},
availableAgents: undefined,
updateAgentConfigSignal: undefined,
agentDocumentsMap: {},
updateAgentMetaSignal: undefined,
@@ -80,21 +68,46 @@ afterEach(() => {
});
describe('AgentSlice Actions', () => {
describe('createAgent', () => {
it('should invalidate cached available agents after creating an agent', async () => {
vi.mocked(agentService.createAgent).mockResolvedValue({ agentId: 'agent-2' });
const { result } = renderHook(() => useAgentStore());
act(() => {
useAgentStore.setState({
availableAgents: [
{
avatar: null,
backgroundColor: null,
description: 'stale',
id: 'agent-1',
title: 'Stale Agent',
},
],
});
});
await act(async () => {
await result.current.createAgent({ config: { title: 'New Agent' } });
});
expect(result.current.availableAgents).toBeUndefined();
});
});
describe('useFetchAgentDocuments', () => {
it('should sync fetched agent documents into store cache', async () => {
vi.mocked(agentDocumentService.getDocuments).mockResolvedValue([
vi.mocked(resolveAgentDocumentsContext).mockResolvedValue([
{
content: 'setup steps',
filename: 'setup.md',
id: 'doc-1',
loadRules: [],
policy: null,
policyLoadFormat: null,
policyLoadPosition: null,
templateId: null,
loadRules: {},
policyId: null,
policyLoadFormat: undefined,
title: 'Setup',
},
] as any);
]);
const { result } = renderHook(() => useAgentStore(), { wrapper: withSWR });
@@ -106,14 +119,71 @@ describe('AgentSlice Actions', () => {
content: 'setup steps',
filename: 'setup.md',
id: 'doc-1',
loadPosition: undefined,
loadRules: [],
loadRules: {},
policyId: null,
policyLoadFormat: undefined,
title: 'Setup',
},
]);
});
expect(resolveAgentDocumentsContext).toHaveBeenCalledWith({ agentId: 'agent-1' });
});
});
describe('useFetchAvailableAgents', () => {
it('should sync fetched available agents into store cache', async () => {
vi.mocked(agentService.queryAgents).mockResolvedValue([
{
avatar: null,
backgroundColor: null,
description: 'Helps with setup',
id: 'agent-1',
title: 'Setup',
},
]);
const { result } = renderHook(() => useAgentStore(), { wrapper: withSWR });
renderHook(() => result.current.useFetchAvailableAgents(true), { wrapper: withSWR });
await waitFor(() => {
expect(result.current.availableAgents).toEqual([
{
avatar: null,
backgroundColor: null,
description: 'Helps with setup',
id: 'agent-1',
title: 'Setup',
},
]);
});
expect(agentService.queryAgents).toHaveBeenCalledWith({ limit: 12 });
});
});
describe('invalidateAvailableAgents', () => {
it('should clear cached available agents', () => {
const { result } = renderHook(() => useAgentStore());
act(() => {
useAgentStore.setState({
availableAgents: [
{
avatar: null,
backgroundColor: null,
description: 'stale',
id: 'agent-1',
title: 'Stale Agent',
},
],
});
});
act(() => {
result.current.invalidateAvailableAgents();
});
expect(result.current.availableAgents).toBeUndefined();
});
});
@@ -399,6 +469,15 @@ describe('AgentSlice Actions', () => {
useAgentStore.setState({
activeAgentId: 'agent-1',
agentMap: { 'agent-1': { title: 'Old Title' } as any },
availableAgents: [
{
avatar: null,
backgroundColor: null,
description: 'Old Desc',
id: 'agent-1',
title: 'Old Title',
},
],
});
});
@@ -410,6 +489,7 @@ describe('AgentSlice Actions', () => {
description: 'New Desc',
title: 'New Title',
});
expect(result.current.availableAgents).toBeUndefined();
});
// Note: refreshSessions is no longer called after optimistic update
+30 -10
View File
@@ -9,13 +9,9 @@ import type { PartialDeep } from 'type-fest';
import { MESSAGE_CANCEL_FLAT } from '@/const/message';
import { mutate, useClientDataSWRWithSync } from '@/libs/swr';
import type { CreateAgentParams, CreateAgentResult } from '@/services/agent';
import { agentService } from '@/services/agent';
import {
agentDocumentService,
agentDocumentSWRKeys,
resolveAgentDocumentsContext,
} from '@/services/agentDocument';
import type { AvailableAgentItem, CreateAgentParams, CreateAgentResult } from '@/services/agent';
import { agentService, AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT } from '@/services/agent';
import { agentDocumentSWRKeys, resolveAgentDocumentsContext } from '@/services/agentDocument';
import type { StoreSetter } from '@/store/types';
import { getUserStoreState } from '@/store/user';
import { userProfileSelectors } from '@/store/user/selectors';
@@ -25,7 +21,6 @@ import type {
LobeAgentConfig,
RuntimeEnvConfig,
} from '@/types/agent';
import { toAgentContextDocuments } from '@/utils/agentDocumentContextMapping';
import { merge } from '@/utils/merge';
import type { AgentStore } from '../../store';
@@ -33,6 +28,11 @@ import { setLocalAgentWorkingDirectory } from '../../utils/localAgentWorkingDire
import type { AgentSliceState, LoadingState, SaveStatus } from './initialState';
const FETCH_AGENT_CONFIG_KEY = 'FETCH_AGENT_CONFIG';
const FETCH_AVAILABLE_AGENTS_KEY = 'FETCH_AVAILABLE_AGENTS';
const FETCH_AVAILABLE_AGENTS_SWR_KEY = [
FETCH_AVAILABLE_AGENTS_KEY,
AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT,
] as const;
type AgentMetaUpdate = Partial<
Pick<
AgentItem,
@@ -80,6 +80,7 @@ export class AgentSliceActionImpl {
createAgent = async (params: CreateAgentParams): Promise<CreateAgentResult> => {
const result = await agentService.createAgent(params);
this.#get().invalidateAvailableAgents();
// Track new agent creation analytics
const analytics = getSingletonAnalyticsOptional();
@@ -324,8 +325,7 @@ export class AgentSliceActionImpl {
useFetchAgentDocuments = (agentId?: string | null): SWRResponse<AgentContextDocument[]> => {
return useClientDataSWRWithSync<AgentContextDocument[]>(
agentId ? agentDocumentSWRKeys.documents(agentId) : null,
async () =>
toAgentContextDocuments(await agentDocumentService.getDocuments({ agentId: agentId! })),
async () => (await resolveAgentDocumentsContext({ agentId: agentId! })) ?? [],
{
onData: (data) => {
if (!agentId) return;
@@ -337,6 +337,24 @@ export class AgentSliceActionImpl {
);
};
useFetchAvailableAgents = (enabled: boolean): SWRResponse<AvailableAgentItem[]> => {
return useClientDataSWRWithSync<AvailableAgentItem[]>(
enabled ? FETCH_AVAILABLE_AGENTS_SWR_KEY : null,
() => agentService.queryAgents({ limit: AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT }),
{
onData: (data) => {
this.#set({ availableAgents: data }, false, 'useFetchAvailableAgents');
},
revalidateOnFocus: false,
},
);
};
invalidateAvailableAgents = (): void => {
this.#set({ availableAgents: undefined }, false, 'invalidateAvailableAgents');
void mutate(FETCH_AVAILABLE_AGENTS_SWR_KEY);
};
ensureAgentDocuments = async (
agentId?: string | null,
): Promise<AgentContextDocument[] | undefined> => {
@@ -397,6 +415,7 @@ export class AgentSliceActionImpl {
// 3. Use returned data directly (no refetch needed!)
if (result?.success && result.agent) {
internal_dispatchAgentMap(id, result.agent);
this.#get().invalidateAvailableAgents();
}
updateSaveStatus('saved');
} catch (error: any) {
@@ -427,6 +446,7 @@ export class AgentSliceActionImpl {
// 3. Use returned data directly (no refetch needed!)
if (result?.success && result.agent) {
internal_dispatchAgentMap(id, result.agent);
this.#get().invalidateAvailableAgents();
}
updateSaveStatus('saved');
} catch (error: any) {
@@ -2,6 +2,7 @@ import type { AgentContextDocument } from '@lobechat/context-engine';
import type { PartialDeep } from 'type-fest';
import { type AgentSettingsInstance } from '@/features/AgentSetting';
import { type AvailableAgentItem } from '@/services/agent';
import { type AgentItem } from '@/types/agent';
import { type MetaData } from '@/types/meta';
@@ -15,6 +16,7 @@ export interface AgentSliceState {
agentDocumentsMap: Record<string, AgentContextDocument[]>;
agentMap: Record<string, PartialDeep<AgentItem>>;
agentSettingInstance?: AgentSettingsInstance | null;
availableAgents?: AvailableAgentItem[];
/**
* Whether the agent panel is pinned (UI state)
*/
@@ -53,6 +55,7 @@ export interface AgentSliceState {
export const initialAgentSliceState: AgentSliceState = {
agentDocumentsMap: {},
agentMap: {},
availableAgents: undefined,
isAgentPinned: false,
lastUpdatedTime: null,
localAgentWorkingDirectoryMap: readAllLocalAgentWorkingDirectories(),
@@ -12,6 +12,7 @@ vi.mock('zustand/traditional');
// Mock agentService
vi.mock('@/services/agent', () => ({
AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT: 12,
agentService: {
createAgentFiles: vi.fn(),
createAgentKnowledgeBase: vi.fn(),
@@ -10,6 +10,7 @@ vi.mock('zustand/traditional');
// Mock agentService
vi.mock('@/services/agent', () => ({
AVAILABLE_AGENTS_CONTEXT_QUERY_LIMIT: 12,
agentService: {
updateAgentConfig: vi.fn(),
},
@@ -524,6 +524,63 @@ describe('ConversationLifecycle actions', () => {
expect(result.current.executeClientAgent).toHaveBeenCalled();
});
it('should merge partial persisted messages into existing topic history', async () => {
const { result } = renderHook(() => useChatStore());
const agentId = TEST_IDS.SESSION_ID;
const topicId = TEST_IDS.TOPIC_ID;
const context = { agentId, threadId: null, topicId };
const key = messageMapKey(context);
const existingMessages = [
createMockMessage({ id: 'existing-user', role: 'user', topicId }),
createMockMessage({ id: 'existing-assistant', role: 'assistant', topicId }),
];
const persistedUserMessage = createMockMessage({
id: TEST_IDS.USER_MESSAGE_ID,
role: 'user',
topicId,
});
const persistedAssistantMessage = createMockMessage({
id: TEST_IDS.ASSISTANT_MESSAGE_ID,
parentId: TEST_IDS.USER_MESSAGE_ID,
role: 'assistant',
topicId,
});
act(() => {
useChatStore.setState({
dbMessagesMap: { [key]: existingMessages },
messagesMap: { [key]: existingMessages },
});
});
vi.spyOn(aiChatService, 'sendMessageInServer').mockResolvedValue({
__isPartialMessages: true,
assistantMessageId: TEST_IDS.ASSISTANT_MESSAGE_ID,
isCreateNewTopic: false,
messages: [persistedUserMessage, persistedAssistantMessage],
topicId,
topics: undefined,
userMessageId: TEST_IDS.USER_MESSAGE_ID,
} as any);
await act(async () => {
await result.current.sendMessage({
context,
message: TEST_CONTENT.USER_MESSAGE,
});
});
expect(result.current.messagesMap[key].map((message) => message.id)).toEqual([
'existing-user',
'existing-assistant',
TEST_IDS.USER_MESSAGE_ID,
TEST_IDS.ASSISTANT_MESSAGE_ID,
]);
expect(
result.current.messagesMap[key].some((message) => message.id.startsWith('tmp_')),
).toBe(false);
});
it('should preserve editorData when enqueueing a queued message', async () => {
const { result } = renderHook(() => useChatStore());
const context = createTestContext();
@@ -8,6 +8,7 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
import * as toolEngineering from '@/helpers/toolEngineering';
import { chatService } from '@/services/chat';
import * as agentConfigResolver from '@/services/chat/mecha/agentConfigResolver';
import { useAgentStore } from '@/store/agent';
import { useAiInfraStore } from '@/store/aiInfra';
import { pageAgentRuntime } from '@/store/tool/slices/builtin/executors/lobe-page-agent';
@@ -120,6 +121,7 @@ beforeEach(() => {
serverConfigMock.enableVisualUnderstanding = false;
act(() => {
useAgentStore.setState({ availableAgents: [] });
useChatStore.setState({
refreshMessages: vi.fn(),
executeClientAgent: vi.fn(),
@@ -111,6 +111,10 @@ export interface SendMessageResult {
userMessageId: string;
}
type SendMessageServerResponseMeta = SendMessageServerResponse & {
__isPartialMessages?: boolean;
};
/**
* Actions managing the complete lifecycle of conversations including sending,
* regenerating, and resending messages
@@ -154,6 +158,22 @@ const attachSendTimeMetadataToUserMessage = (
return changed ? nextMessages : messages;
};
const mergePartialPersistedMessages = (
currentMessages: UIChatMessage[],
persistedMessages: UIChatMessage[],
replacedMessageIds: string[],
): UIChatMessage[] => {
const replacedIdSet = new Set(replacedMessageIds);
const persistedIdSet = new Set(persistedMessages.map((message) => message.id));
return [
...currentMessages.filter(
(message) => !replacedIdSet.has(message.id) && !persistedIdSet.has(message.id),
),
...persistedMessages,
];
};
export class ConversationLifecycleActionImpl {
readonly #get: () => ChatStore;
@@ -555,9 +575,18 @@ export class ConversationLifecycleActionImpl {
...operationContext,
topicId: heteroData.topicId ?? operationContext.topicId,
};
const heteroResponseMeta = heteroData as SendMessageServerResponseMeta;
const heteroMessageKey = messageMapKey(heteroContext);
const heteroMessages = heteroResponseMeta.__isPartialMessages
? mergePartialPersistedMessages(
this.#get().messagesMap[heteroMessageKey] || [],
heteroData.messages,
[tempId, tempAssistantId],
)
: heteroData.messages;
// Replace optimistic messages with persisted ones
this.#get().replaceMessages(heteroData.messages, {
this.#get().replaceMessages(heteroMessages, {
action: 'sendMessage/serverResponse',
context: heteroContext,
});
@@ -770,6 +799,7 @@ export class ConversationLifecycleActionImpl {
},
abortController,
);
const responseMeta = data as SendMessageServerResponseMeta;
// Use created topicId/threadId if available, otherwise use original from context
let finalTopicId = data.topicId ?? operationContext.topicId;
const finalThreadId = data.createdThreadId ?? operationContext.threadId;
@@ -806,6 +836,7 @@ export class ConversationLifecycleActionImpl {
'sendMessage/createTopicPlaceholder',
);
this.#get().updateOperationMetadata(operationId, { createdTopicId: data.topicId });
void Promise.resolve(this.#get().refreshTopic()).catch(console.error);
} else if (operationContext.topicId) {
// Optimistically update topic's updatedAt so sidebar re-groups immediately
this.#get().internal_dispatchTopic({
@@ -838,13 +869,21 @@ export class ConversationLifecycleActionImpl {
// Create final context with updated topicId/threadId from server response
const finalContext = { ...operationContext, topicId: finalTopicId, threadId: finalThreadId };
const persistedMessages = attachSendTimeMetadataToUserMessage(
data.messages,
data.userMessageId,
userMessageMetadata,
);
const finalMessageKey = messageMapKey(finalContext);
data = {
...data,
messages: attachSendTimeMetadataToUserMessage(
data.messages,
data.userMessageId,
userMessageMetadata,
),
messages: responseMeta.__isPartialMessages
? mergePartialPersistedMessages(
this.#get().messagesMap[finalMessageKey] || [],
persistedMessages,
[tempId, tempAssistantId],
)
: persistedMessages,
};
this.#get().replaceMessages(data.messages, {
@@ -152,12 +152,10 @@ export const streamingExecutor = (set: Setter, get: () => ChatStore, _api?: unkn
export class StreamingExecutorActionImpl {
readonly #get: () => ChatStore;
// eslint-disable-next-line no-unused-private-class-members
readonly #set: Setter;
constructor(set: Setter, get: () => ChatStore, _api?: unknown) {
void set;
void _api;
this.#set = set;
this.#get = get;
}
@@ -496,7 +494,6 @@ export class StreamingExecutorActionImpl {
// Extract values from context
const { agentId, topicId, threadId, subAgentId, groupId, scope } = context;
// Determine effectiveAgentId for agent config retrieval:
// - subAgentId is used when present (behavior depends on scope)
// - agentId: Default
@@ -8,6 +8,7 @@ import { mutate } from '@/libs/swr';
import { chatService } from '@/services/chat';
import { messageService } from '@/services/message';
import { topicService } from '@/services/topic';
import { useAgentStore } from '@/store/agent';
import { PortalViewType } from '@/store/chat/slices/portal/initialState';
import { messageMapKey } from '@/store/chat/utils/messageMapKey';
import { topicMapKey } from '@/store/chat/utils/topicMapKey';
@@ -80,6 +81,7 @@ beforeEach(() => {
},
false,
);
useAgentStore.setState({ agentDocumentsMap: {} });
useSessionStore.setState(
{
activeId: 'inbox',
@@ -4,6 +4,7 @@ import { type SWRResponse } from 'swr';
import { type SidebarAgentItem, type SidebarAgentListResponse } from '@/database/repositories/home';
import { mutate, useClientDataSWR, useClientDataSWRWithSync } from '@/libs/swr';
import { homeService } from '@/services/home';
import { getAgentStoreState } from '@/store/agent';
import { type HomeStore } from '@/store/home/store';
import { type StoreSetter } from '@/store/types';
import { setNamespace } from '@/utils/storeDebug';
@@ -38,6 +39,7 @@ export class AgentListActionImpl {
};
refreshAgentList = async (): Promise<void> => {
getAgentStoreState().invalidateAvailableAgents();
await mutate([FETCH_AGENT_LIST_KEY, true]);
};
@@ -40,6 +40,7 @@ vi.mock('@/store/agent', async (importOriginal) => {
return {
...actual,
getAgentStoreState: vi.fn(() => ({
invalidateAvailableAgents: vi.fn(),
setActiveAgentId: vi.fn(),
})),
useAgentStore: actual.useAgentStore,
@@ -162,6 +163,7 @@ describe('createSidebarUISlice', () => {
const mockSetActiveAgentId = vi.fn();
vi.mocked(getAgentStoreState).mockReturnValue({
invalidateAvailableAgents: vi.fn(),
setActiveAgentId: mockSetActiveAgentId,
} as any);
@@ -200,6 +202,7 @@ describe('createSidebarUISlice', () => {
const mockNewAgentId = 'new-agent-456';
vi.mocked(getAgentStoreState).mockReturnValue({
invalidateAvailableAgents: vi.fn(),
setActiveAgentId: vi.fn(),
} as any);
@@ -80,6 +80,7 @@ describe('toAgentContextDocument', () => {
expect(toAgentContextDocument(doc)).toEqual({
content: 'body',
contentCharCount: 4,
description: 'web-crawled article',
filename: 'crawl.md',
id: 'agent-doc-2',
+6 -3
View File
@@ -4,7 +4,7 @@ import {
type AgentDocumentInjectionPosition,
} from '@lobechat/context-engine';
import type { AgentDocumentWithRules } from '@/database/models/agentDocuments';
import type { AgentDocumentContextPayload } from '@/database/models/agentDocuments';
const VALID_DOCUMENT_POSITIONS = new Set<AgentDocumentInjectionPosition>(
AGENT_DOCUMENT_INJECTION_POSITIONS,
@@ -31,8 +31,9 @@ export const normalizeAgentDocumentPosition = (
* added on the client only, which broke the "hide web crawls from the
* progressive index" filter on every server-driven chat ().
*/
export const toAgentContextDocument = (doc: AgentDocumentWithRules): AgentContextDocument => ({
export const toAgentContextDocument = (doc: AgentDocumentContextPayload): AgentContextDocument => ({
content: doc.content,
contentCharCount: doc.contentCharCount ?? doc.content.length,
description: doc.description ?? undefined,
filename: doc.filename,
id: doc.id,
@@ -60,5 +61,7 @@ export const toAgentContextDocument = (doc: AgentDocumentWithRules): AgentContex
* field, so the folder check has to happen here, at the DBcontext boundary,
* where the derived `isFolder` flag is still available.
*/
export const toAgentContextDocuments = (docs: AgentDocumentWithRules[]): AgentContextDocument[] =>
export const toAgentContextDocuments = (
docs: AgentDocumentContextPayload[],
): AgentContextDocument[] =>
docs.filter((doc) => !doc.isFolder).map((doc) => toAgentContextDocument(doc));
+1 -1
View File
@@ -17,7 +17,7 @@
"incremental": true,
"types": ["vitest/globals"],
"paths": {
"@/database/*": ["./packages/database/src/*", "./src/database/*"],
"@/database/*": ["./packages/database/src/*"],
"@/const/*": ["./packages/const/src/*", "./src/const/*"],
"@/utils/*": ["./packages/utils/src/*", "./src/utils/*"],
"@/types/*": ["./packages/types/src/*", "./src/types/*"],
-2
View File
@@ -26,7 +26,6 @@ const alias = {
),
'@emoji-mart/data': resolve(__dirname, './tests/mocks/emojiMartData.ts'),
'@emoji-mart/react': resolve(__dirname, './tests/mocks/emojiMartReact.tsx'),
'@/database/_deprecated': resolve(__dirname, './src/database/_deprecated'),
'@/utils/client/switchLang': resolve(__dirname, './src/utils/client/switchLang'),
'@/const/locale': resolve(__dirname, './src/const/locale'),
// TODO: after refactor the errorResponse, we can remove it
@@ -110,7 +109,6 @@ export default defineConfig({
// just ignore the migration code
// we will use pglite in the future
// so the coverage of this file is not important
'src/database/client/core/db.ts',
'src/utils/fetch/fetchEventSource/*.ts',
],
provider: 'v8',