mirror of
https://github.com/lobehub/lobe-chat.git
synced 2026-06-14 11:40:07 +00:00
Compare commits
28 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 70a82787f3 | |||
| 68762fc4ae | |||
| 1a58d530fb | |||
| ca01385666 | |||
| 5231bbbcac | |||
| 496b10f5c0 | |||
| 1800110748 | |||
| b068c427d4 | |||
| d5eec83a72 | |||
| 6c9cbb07ee | |||
| b92ee0ade5 | |||
| 3327b293d6 | |||
| d7e5d4645d | |||
| 918e4a8fa1 | |||
| f58015bb23 | |||
| e6244aaea6 | |||
| e9d43cb43f | |||
| 5b03f009ee | |||
| 25cf3bfafd | |||
| 3cb7206d90 | |||
| e364b9a516 | |||
| a7e3d198df | |||
| 14cd81b624 | |||
| bd345d35a8 | |||
| 40d0825d79 | |||
| ea725aca9e | |||
| 306691b4d7 | |||
| 11318f8ab9 |
@@ -24,64 +24,241 @@ Two approaches for local testing on macOS:
|
||||
|
||||
Use `agent-browser` to automate Chromium-based apps via Chrome DevTools Protocol.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- `agent-browser` CLI installed globally (`agent-browser --version`)
|
||||
Install via `npm i -g agent-browser`, `brew install agent-browser`, or `cargo install agent-browser`. Run `agent-browser install` to download Chrome. Run `agent-browser upgrade` to update.
|
||||
|
||||
## Core Workflow
|
||||
|
||||
### 1. Snapshot → Find Elements
|
||||
Every browser automation follows this pattern:
|
||||
|
||||
1. **Navigate**: `agent-browser open <url>`
|
||||
2. **Snapshot**: `agent-browser snapshot -i` (get element refs like `@e1`, `@e2`)
|
||||
3. **Interact**: Use refs to click, fill, select
|
||||
4. **Re-snapshot**: After navigation or DOM changes, get fresh refs
|
||||
|
||||
```bash
|
||||
agent-browser --cdp -i < PORT > snapshot # Interactive elements only
|
||||
agent-browser --cdp -i -C < PORT > snapshot # Include contenteditable elements
|
||||
agent-browser open https://example.com/form
|
||||
agent-browser snapshot -i
|
||||
# Output: @e1 [input type="email"], @e2 [input type="password"], @e3 [button] "Submit"
|
||||
|
||||
agent-browser fill @e1 "user@example.com"
|
||||
agent-browser fill @e2 "password123"
|
||||
agent-browser click @e3
|
||||
agent-browser wait --load networkidle
|
||||
agent-browser snapshot -i # Check result
|
||||
```
|
||||
|
||||
Returns element refs like `@e1`, `@e2`. **Refs are ephemeral** — re-snapshot after any page change.
|
||||
|
||||
### 2. Interact
|
||||
## Command Chaining
|
||||
|
||||
```bash
|
||||
agent-browser --cdp @e5 < PORT > click
|
||||
agent-browser --cdp @e3 "text" < PORT > type # Character by character (contenteditable)
|
||||
agent-browser --cdp @e3 "text" < PORT > fill # Bulk fill (regular inputs)
|
||||
agent-browser --cdp Enter < PORT > press
|
||||
agent-browser --cdp down 500 < PORT > scroll
|
||||
# Chain open + wait + snapshot in one call
|
||||
agent-browser open https://example.com && agent-browser wait --load networkidle && agent-browser snapshot -i
|
||||
```
|
||||
|
||||
### 3. Wait
|
||||
Use `&&` when you don't need to read intermediate output. Run commands separately when you need to parse output first (e.g., snapshot to discover refs, then interact).
|
||||
|
||||
## Essential Commands
|
||||
|
||||
```bash
|
||||
agent-browser --cdp 2000 < PORT > wait # Wait ms
|
||||
agent-browser --cdp --load networkidle < PORT > wait # Wait for network
|
||||
# Navigation
|
||||
agent-browser open <url> # Navigate (aliases: goto, navigate)
|
||||
agent-browser close # Close browser
|
||||
agent-browser close --all # Close all active sessions
|
||||
|
||||
# Snapshot
|
||||
agent-browser snapshot -i # Interactive elements with refs (recommended)
|
||||
agent-browser snapshot -s "#selector" # Scope to CSS selector
|
||||
|
||||
# Interaction (use @refs from snapshot)
|
||||
agent-browser click @e1 # Click element
|
||||
agent-browser click @e1 --new-tab # Click and open in new tab
|
||||
agent-browser fill @e2 "text" # Clear and type text
|
||||
agent-browser type @e2 "text" # Type without clearing
|
||||
agent-browser select @e1 "option" # Select dropdown option
|
||||
agent-browser check @e1 # Check checkbox
|
||||
agent-browser press Enter # Press key
|
||||
agent-browser keyboard type "text" # Type at current focus (no selector)
|
||||
agent-browser keyboard inserttext "text" # Insert without key events
|
||||
agent-browser scroll down 500 # Scroll page
|
||||
agent-browser scroll down 500 --selector "div.content" # Scroll within container
|
||||
|
||||
# Get information
|
||||
agent-browser get text @e1 # Get element text
|
||||
agent-browser get url # Get current URL
|
||||
agent-browser get title # Get page title
|
||||
agent-browser get cdp-url # Get CDP WebSocket URL
|
||||
|
||||
# Wait
|
||||
agent-browser wait @e1 # Wait for element
|
||||
agent-browser wait --load networkidle # Wait for network idle
|
||||
agent-browser wait --url "**/page" # Wait for URL pattern
|
||||
agent-browser wait 2000 # Wait milliseconds
|
||||
agent-browser wait --text "Welcome" # Wait for text to appear
|
||||
agent-browser wait --fn "!document.body.innerText.includes('Loading...')" # Wait for text to disappear
|
||||
agent-browser wait "#spinner" --state hidden # Wait for element to disappear
|
||||
|
||||
# Downloads
|
||||
agent-browser download @e1 ./file.pdf # Click element to trigger download
|
||||
agent-browser wait --download ./output.zip # Wait for any download to complete
|
||||
|
||||
# Network
|
||||
agent-browser network requests # Inspect tracked requests
|
||||
agent-browser network requests --type xhr,fetch # Filter by resource type
|
||||
agent-browser network requests --method POST # Filter by HTTP method
|
||||
agent-browser network route "**/api/*" --abort # Block matching requests
|
||||
agent-browser network har start # Start HAR recording
|
||||
agent-browser network har stop ./capture.har # Stop and save HAR file
|
||||
|
||||
# Viewport & Device Emulation
|
||||
agent-browser set viewport 1920 1080 # Set viewport size (default: 1280x720)
|
||||
agent-browser set viewport 1920 1080 2 # 2x retina
|
||||
agent-browser set device "iPhone 14" # Emulate device (viewport + user agent)
|
||||
|
||||
# Capture
|
||||
agent-browser screenshot # Screenshot to temp dir
|
||||
agent-browser screenshot --full # Full page screenshot
|
||||
agent-browser screenshot --annotate # Annotated screenshot with numbered element labels
|
||||
agent-browser pdf output.pdf # Save as PDF
|
||||
|
||||
# Clipboard
|
||||
agent-browser clipboard read # Read text from clipboard
|
||||
agent-browser clipboard write "text" # Write text to clipboard
|
||||
agent-browser clipboard copy # Copy current selection
|
||||
agent-browser clipboard paste # Paste from clipboard
|
||||
|
||||
# Dialogs (alert, confirm, prompt, beforeunload)
|
||||
agent-browser dialog accept # Accept dialog
|
||||
agent-browser dialog accept "input" # Accept prompt dialog with text
|
||||
agent-browser dialog dismiss # Dismiss/cancel dialog
|
||||
agent-browser dialog status # Check if dialog is open
|
||||
|
||||
# Diff (compare page states)
|
||||
agent-browser diff snapshot # Compare current vs last snapshot
|
||||
agent-browser diff screenshot --baseline before.png # Visual pixel diff
|
||||
agent-browser diff url <url1> <url2> # Compare two pages
|
||||
|
||||
# Streaming
|
||||
agent-browser stream enable # Start WebSocket streaming
|
||||
agent-browser stream status # Inspect streaming state
|
||||
agent-browser stream disable # Stop streaming
|
||||
```
|
||||
|
||||
For waits >30s, use `sleep N` in bash instead — `agent-browser wait` blocks the daemon.
|
||||
|
||||
### 4. Screenshot & Verify
|
||||
## Batch Execution
|
||||
|
||||
```bash
|
||||
agent-browser --cdp < PORT > screenshot # Save to ~/.agent-browser/tmp/screenshots/
|
||||
agent-browser --cdp text @e1 < PORT > get # Get element text
|
||||
agent-browser --cdp url < PORT > get # Get current URL
|
||||
echo '[
|
||||
["open", "https://example.com"],
|
||||
["snapshot", "-i"],
|
||||
["click", "@e1"],
|
||||
["screenshot", "result.png"]
|
||||
]' | agent-browser batch --json
|
||||
```
|
||||
|
||||
Read screenshots with the `Read` tool for visual verification.
|
||||
|
||||
### 5. Evaluate JavaScript
|
||||
## Authentication
|
||||
|
||||
```bash
|
||||
agent-browser --cdp "document.title" < PORT > eval
|
||||
# Option 1: Auth vault (credentials stored encrypted)
|
||||
echo "$PASSWORD" | agent-browser auth save myapp --url https://app.example.com/login --username user --password-stdin
|
||||
agent-browser auth login myapp
|
||||
|
||||
# Option 2: Session name (auto-save/restore cookies + localStorage)
|
||||
agent-browser --session-name myapp open https://app.example.com/login
|
||||
agent-browser close # State auto-saved
|
||||
agent-browser --session-name myapp open https://app.example.com/dashboard # Auto-restored
|
||||
|
||||
# Option 3: Persistent profile
|
||||
agent-browser --profile ~/.myapp open https://app.example.com/login
|
||||
|
||||
# Option 4: State file
|
||||
agent-browser state save auth.json
|
||||
agent-browser state load auth.json
|
||||
```
|
||||
|
||||
For multi-line JS, use `--stdin`:
|
||||
## Semantic Locators (Alternative to Refs)
|
||||
|
||||
```bash
|
||||
agent-browser --cdp --stdin < PORT > eval << 'EVALEOF'
|
||||
(function() {
|
||||
return JSON.stringify({ title: document.title, url: location.href });
|
||||
})()
|
||||
agent-browser find text "Sign In" click
|
||||
agent-browser find label "Email" fill "user@test.com"
|
||||
agent-browser find role button click --name "Submit"
|
||||
agent-browser find placeholder "Search" type "query"
|
||||
agent-browser find testid "submit-btn" click
|
||||
```
|
||||
|
||||
## JavaScript Evaluation (eval)
|
||||
|
||||
```bash
|
||||
# Simple expressions
|
||||
agent-browser eval 'document.title'
|
||||
|
||||
# Complex JS: use --stdin with heredoc (RECOMMENDED)
|
||||
agent-browser eval --stdin <<'EVALEOF'
|
||||
JSON.stringify(
|
||||
Array.from(document.querySelectorAll("img"))
|
||||
.filter(i => !i.alt)
|
||||
.map(i => ({ src: i.src.split("/").pop(), width: i.width }))
|
||||
)
|
||||
EVALEOF
|
||||
|
||||
# Base64 encoding (avoids all shell escaping issues)
|
||||
agent-browser eval -b "$(echo -n 'document.title' | base64)"
|
||||
```
|
||||
|
||||
## Ref Lifecycle
|
||||
|
||||
Refs (`@e1`, `@e2`, etc.) are invalidated when the page changes. Always re-snapshot after clicking links/buttons that navigate, form submissions, or dynamic content loading.
|
||||
|
||||
## Annotated Screenshots (Vision Mode)
|
||||
|
||||
```bash
|
||||
agent-browser screenshot --annotate
|
||||
# Output includes the image path and a legend:
|
||||
# [1] @e1 button "Submit"
|
||||
# [2] @e2 link "Home"
|
||||
agent-browser click @e2 # Click using ref from annotated screenshot
|
||||
```
|
||||
|
||||
## Parallel Sessions
|
||||
|
||||
```bash
|
||||
agent-browser --session site1 open https://site-a.com
|
||||
agent-browser --session site2 open https://site-b.com
|
||||
agent-browser session list
|
||||
```
|
||||
|
||||
## Connect to Existing Chrome
|
||||
|
||||
```bash
|
||||
agent-browser --auto-connect snapshot # Auto-discover running Chrome
|
||||
agent-browser --cdp 9222 snapshot # Explicit CDP port
|
||||
```
|
||||
|
||||
## iOS Simulator (Mobile Safari)
|
||||
|
||||
```bash
|
||||
agent-browser device list
|
||||
agent-browser -p ios --device "iPhone 16 Pro" open https://example.com
|
||||
agent-browser -p ios snapshot -i
|
||||
agent-browser -p ios tap @e1
|
||||
agent-browser -p ios swipe up
|
||||
agent-browser -p ios screenshot mobile.png
|
||||
agent-browser -p ios close
|
||||
```
|
||||
|
||||
## Observability Dashboard
|
||||
|
||||
```bash
|
||||
agent-browser dashboard install
|
||||
agent-browser dashboard start # Background server on port 4848
|
||||
agent-browser dashboard stop
|
||||
```
|
||||
|
||||
## Cloud Providers
|
||||
|
||||
Use `-p <provider>` to run against cloud browsers: `agentcore`, `browserbase`, `browserless`, `browseruse`, `kernel`.
|
||||
|
||||
## Browser Engine Selection
|
||||
|
||||
```bash
|
||||
agent-browser --engine lightpanda open example.com # 10x faster, 10x less memory
|
||||
```
|
||||
|
||||
## Electron (LobeHub Desktop)
|
||||
@@ -187,6 +364,9 @@ agent-browser --cdp 9222 eval "JSON.stringify(window.__CAPTURED_ERRORS)"
|
||||
"<URL>" &
|
||||
sleep 5
|
||||
agent-browser --cdp 9222 snapshot -i
|
||||
|
||||
# Or auto-discover running Chrome with remote debugging
|
||||
agent-browser --auto-connect snapshot -i
|
||||
```
|
||||
|
||||
---
|
||||
@@ -907,12 +1087,14 @@ The script automatically:
|
||||
|
||||
### agent-browser
|
||||
|
||||
- **Daemon can get stuck** — if commands hang, `pkill -f agent-browser` to reset
|
||||
- **`agent-browser wait` blocks the daemon** — for waits >30s, use bash `sleep`
|
||||
- **Daemon can get stuck** — if commands hang, `agent-browser close --all` or `pkill -f agent-browser` to reset
|
||||
- **HMR invalidates everything** — after code changes, refs break. Re-snapshot or restart
|
||||
- **`snapshot -i` doesn't find contenteditable** — use `snapshot -i -C` for rich text editors
|
||||
- **`fill` doesn't work on contenteditable** — use `type` for chat inputs
|
||||
- **Screenshots go to `~/.agent-browser/tmp/screenshots/`** — read them with the `Read` tool
|
||||
- **Dialogs block all commands** — if commands time out, check `agent-browser dialog status`
|
||||
- **Default timeout is 25s** — override with `AGENT_BROWSER_DEFAULT_TIMEOUT` (ms) or use explicit waits
|
||||
- **Shell quoting corrupts eval** — use `eval --stdin <<'EVALEOF'` for complex JS
|
||||
|
||||
### Electron-specific
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
.\" Code generated by `npm run man:generate`; DO NOT EDIT.
|
||||
.\" Manual command details come from the Commander command tree.
|
||||
.TH LH 1 "" "@lobehub/cli 0.0.1\-canary.15" "User Commands"
|
||||
.TH LH 1 "" "@lobehub/cli 0.0.3" "User Commands"
|
||||
.SH NAME
|
||||
lh \- LobeHub CLI \- manage and connect to LobeHub services
|
||||
.SH SYNOPSIS
|
||||
@@ -115,6 +115,9 @@ View usage statistics
|
||||
.TP
|
||||
.B eval
|
||||
Manage evaluation workflows
|
||||
.TP
|
||||
.B migrate
|
||||
Migrate data from external tools (OpenClaw, ChatGPT, Claude, etc.)
|
||||
.SH OPTIONS
|
||||
.TP
|
||||
.B \-V, \-\-version
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lobehub/cli",
|
||||
"version": "0.0.1-canary.15",
|
||||
"version": "0.0.3",
|
||||
"type": "module",
|
||||
"bin": {
|
||||
"lh": "./dist/index.js",
|
||||
@@ -27,6 +27,9 @@
|
||||
"test:coverage": "bunx vitest run --config vitest.config.mts --coverage",
|
||||
"type-check": "tsc --noEmit"
|
||||
},
|
||||
"dependencies": {
|
||||
"ignore": "^7.0.5"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@lobechat/device-gateway-client": "workspace:*",
|
||||
"@lobechat/local-file-shell": "workspace:*",
|
||||
|
||||
@@ -39,7 +39,9 @@ async function getAuthAndServer() {
|
||||
|
||||
const result = await getValidToken();
|
||||
if (!result) {
|
||||
log.error(`No authentication found. Run 'lh login' first, or set ${CLI_API_KEY_ENV}.`);
|
||||
log.error(
|
||||
`No authentication found. Run 'lh login' (or 'npx -y @lobehub/cli login') first, or set ${CLI_API_KEY_ENV}.`,
|
||||
);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
|
||||
@@ -3,29 +3,9 @@ import { CLI_API_KEY_ENV } from '../constants/auth';
|
||||
import { resolveServerUrl } from '../settings';
|
||||
import { log } from '../utils/logger';
|
||||
|
||||
// Must match the server's SECRET_XOR_KEY (src/envs/auth.ts)
|
||||
const SECRET_XOR_KEY = 'LobeHub · LobeHub';
|
||||
|
||||
/**
|
||||
* XOR-obfuscate a payload and encode as Base64.
|
||||
* The /webapi/* routes require `X-lobe-chat-auth` with this encoding.
|
||||
*/
|
||||
function obfuscatePayloadWithXOR(payload: Record<string, any>): string {
|
||||
const jsonString = JSON.stringify(payload);
|
||||
const dataBytes = new TextEncoder().encode(jsonString);
|
||||
const keyBytes = new TextEncoder().encode(SECRET_XOR_KEY);
|
||||
|
||||
const result = new Uint8Array(dataBytes.length);
|
||||
for (let i = 0; i < dataBytes.length; i++) {
|
||||
result[i] = dataBytes[i] ^ keyBytes[i % keyBytes.length];
|
||||
}
|
||||
|
||||
return btoa(String.fromCharCode(...result));
|
||||
}
|
||||
|
||||
export interface AuthInfo {
|
||||
accessToken: string;
|
||||
/** Headers required for /webapi/* endpoints (includes both X-lobe-chat-auth and Oidc-Auth) */
|
||||
/** Headers required for /webapi/* endpoints (Oidc-Auth for authentication) */
|
||||
headers: Record<string, string>;
|
||||
serverUrl: string;
|
||||
}
|
||||
@@ -52,7 +32,6 @@ export async function getAuthInfo(): Promise<AuthInfo> {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Oidc-Auth': accessToken,
|
||||
'X-lobe-chat-auth': obfuscatePayloadWithXOR({}),
|
||||
},
|
||||
serverUrl,
|
||||
};
|
||||
|
||||
@@ -5,7 +5,12 @@ import pc from 'picocolors';
|
||||
|
||||
import { getTrpcClient } from '../api/client';
|
||||
import { getAgentStreamAuthInfo } from '../api/http';
|
||||
import { replayAgentEvents, streamAgentEvents } from '../utils/agentStream';
|
||||
import { resolveAgentGatewayUrl } from '../settings';
|
||||
import {
|
||||
replayAgentEvents,
|
||||
streamAgentEvents,
|
||||
streamAgentEventsViaWebSocket,
|
||||
} from '../utils/agentStream';
|
||||
import { resolveLocalDeviceId } from '../utils/device';
|
||||
import { confirm, outputJson, printTable, truncate } from '../utils/format';
|
||||
import { log, setVerbose } from '../utils/logger';
|
||||
@@ -256,6 +261,7 @@ export function registerAgentCommand(program: Command) {
|
||||
.option('--json', 'Output full JSON event stream')
|
||||
.option('-v, --verbose', 'Show detailed tool call info')
|
||||
.option('--replay <file>', 'Replay events from a saved JSON file (offline)')
|
||||
.option('--sse', 'Force SSE stream instead of WebSocket gateway')
|
||||
.action(
|
||||
async (options: {
|
||||
agentId?: string;
|
||||
@@ -265,6 +271,7 @@ export function registerAgentCommand(program: Command) {
|
||||
prompt?: string;
|
||||
replay?: string;
|
||||
slug?: string;
|
||||
sse?: boolean;
|
||||
topicId?: string;
|
||||
verbose?: boolean;
|
||||
}) => {
|
||||
@@ -347,14 +354,26 @@ export function registerAgentCommand(program: Command) {
|
||||
log.info(`Operation: ${pc.dim(operationId)} · Topic: ${pc.dim(r.topicId || 'n/a')}`);
|
||||
}
|
||||
|
||||
// 2. Connect to SSE stream
|
||||
// 2. Connect to stream (WebSocket via Gateway, or fallback to SSE)
|
||||
const { serverUrl, headers } = await getAgentStreamAuthInfo();
|
||||
const streamUrl = `${serverUrl}/api/agent/stream?operationId=${encodeURIComponent(operationId)}`;
|
||||
const agentGatewayUrl = options.sse ? undefined : resolveAgentGatewayUrl();
|
||||
|
||||
await streamAgentEvents(streamUrl, headers, {
|
||||
json: options.json,
|
||||
verbose: options.verbose,
|
||||
});
|
||||
if (agentGatewayUrl) {
|
||||
const token = headers['Oidc-Auth'] || headers['X-API-Key'] || '';
|
||||
await streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: agentGatewayUrl,
|
||||
json: options.json,
|
||||
operationId,
|
||||
token,
|
||||
verbose: options.verbose,
|
||||
});
|
||||
} else {
|
||||
const streamUrl = `${serverUrl}/api/agent/stream?operationId=${encodeURIComponent(operationId)}`;
|
||||
await streamAgentEvents(streamUrl, headers, {
|
||||
json: options.json,
|
||||
verbose: options.verbose,
|
||||
});
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
|
||||
@@ -61,7 +61,6 @@ describe('generate command', () => {
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Oidc-Auth': 'test-token',
|
||||
'X-lobe-chat-auth': 'test-xor-token',
|
||||
},
|
||||
serverUrl: 'https://app.lobehub.com',
|
||||
});
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
import type { Command } from 'commander';
|
||||
|
||||
import { registerOpenClawMigration } from './openclaw';
|
||||
|
||||
export function registerMigrateCommand(program: Command) {
|
||||
const migrate = program
|
||||
.command('migrate')
|
||||
.description('Migrate data from external tools (OpenClaw, ChatGPT, Claude, etc.)');
|
||||
|
||||
registerOpenClawMigration(migrate);
|
||||
}
|
||||
@@ -0,0 +1,588 @@
|
||||
import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
import path from 'node:path';
|
||||
|
||||
import { Command } from 'commander';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
// ── Mocks ──────────────────────────────────────────────
|
||||
|
||||
const { mockTrpcClient } = vi.hoisted(() => ({
|
||||
mockTrpcClient: {
|
||||
agent: {
|
||||
createAgent: { mutate: vi.fn() },
|
||||
getBuiltinAgent: { query: vi.fn() },
|
||||
},
|
||||
agentDocument: {
|
||||
upsertDocument: { mutate: vi.fn() },
|
||||
},
|
||||
},
|
||||
}));
|
||||
|
||||
const { getTrpcClient: mockGetTrpcClient } = vi.hoisted(() => ({
|
||||
getTrpcClient: vi.fn(),
|
||||
}));
|
||||
|
||||
const { mockConfirm } = vi.hoisted(() => ({
|
||||
mockConfirm: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('../../api/client', () => ({
|
||||
getTrpcClient: mockGetTrpcClient,
|
||||
}));
|
||||
|
||||
vi.mock('../../settings', () => ({
|
||||
resolveServerUrl: () => 'https://app.lobehub.com',
|
||||
}));
|
||||
|
||||
vi.mock('../../utils/format', async (importOriginal) => {
|
||||
const actual = await importOriginal<Record<string, unknown>>();
|
||||
return { ...actual, confirm: mockConfirm };
|
||||
});
|
||||
|
||||
vi.mock('../../utils/logger', () => ({
|
||||
log: {
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
},
|
||||
setVerbose: vi.fn(),
|
||||
}));
|
||||
|
||||
// eslint-disable-next-line import-x/first
|
||||
import { log } from '../../utils/logger';
|
||||
// eslint-disable-next-line import-x/first
|
||||
import { registerOpenClawMigration } from './openclaw';
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────
|
||||
|
||||
let tmpDir: string;
|
||||
|
||||
function createProgram() {
|
||||
const program = new Command();
|
||||
program.exitOverride();
|
||||
const migrate = program.command('migrate');
|
||||
registerOpenClawMigration(migrate);
|
||||
return program;
|
||||
}
|
||||
|
||||
function writeFile(relativePath: string, content: string) {
|
||||
const fullPath = path.join(tmpDir, relativePath);
|
||||
fs.mkdirSync(path.dirname(fullPath), { recursive: true });
|
||||
fs.writeFileSync(fullPath, content);
|
||||
}
|
||||
|
||||
// ── Setup / teardown ───────────────────────────────────
|
||||
|
||||
let exitSpy: ReturnType<typeof vi.spyOn>;
|
||||
let consoleSpy: ReturnType<typeof vi.spyOn>;
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
tmpDir = fs.mkdtempSync(path.join(os.tmpdir(), 'openclaw-test-'));
|
||||
exitSpy = vi.spyOn(process, 'exit').mockImplementation((() => {
|
||||
throw new Error('process.exit');
|
||||
}) as any);
|
||||
consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
mockGetTrpcClient.mockResolvedValue(mockTrpcClient);
|
||||
mockConfirm.mockResolvedValue(true);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
exitSpy.mockRestore();
|
||||
consoleSpy.mockRestore();
|
||||
fs.rmSync(tmpDir, { recursive: true, force: true });
|
||||
});
|
||||
|
||||
// ── Tests ──────────────────────────────────────────────
|
||||
|
||||
describe('migrate openclaw', () => {
|
||||
// ── Profile parsing ────────────────────────────────
|
||||
|
||||
describe('agent profile from workspace', () => {
|
||||
it('should read name, description, and emoji from IDENTITY.md', async () => {
|
||||
writeFile(
|
||||
'IDENTITY.md',
|
||||
['# IDENTITY.md', '- **Name:** 龙虾', '- **Creature:** AI 助手', '- **Emoji:** 🦞'].join(
|
||||
'\n',
|
||||
),
|
||||
);
|
||||
writeFile('hello.md', 'hello');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agent.createAgent.mutate).toHaveBeenCalledWith({
|
||||
config: {
|
||||
avatar: '🦞',
|
||||
description: 'AI 助手',
|
||||
title: '龙虾',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should filter out placeholder emoji like (待定)', async () => {
|
||||
writeFile(
|
||||
'IDENTITY.md',
|
||||
['# IDENTITY.md', '- **Name:** TestBot', '- **Emoji:**', ' _(待定)_'].join('\n'),
|
||||
);
|
||||
writeFile('hello.md', 'hello');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agent.createAgent.mutate).toHaveBeenCalledWith({
|
||||
config: {
|
||||
avatar: undefined,
|
||||
description: undefined,
|
||||
title: 'TestBot',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should fall back to "OpenClaw" when no identity files exist', async () => {
|
||||
writeFile('doc.md', 'content');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agent.createAgent.mutate).toHaveBeenCalledWith({
|
||||
config: {
|
||||
avatar: undefined,
|
||||
description: undefined,
|
||||
title: 'OpenClaw',
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// ── File filtering ─────────────────────────────────
|
||||
|
||||
describe('file collection and filtering', () => {
|
||||
it('should exclude common directories like node_modules and .git', async () => {
|
||||
writeFile('README.md', 'readme');
|
||||
writeFile('node_modules/pkg/index.js', 'module');
|
||||
writeFile('.git/config', 'git');
|
||||
writeFile('.idea/workspace.xml', 'ide');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledTimes(1);
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ filename: 'README.md' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should exclude files matching glob patterns like *.pyc and *.log', async () => {
|
||||
writeFile('main.py', 'print("hi")');
|
||||
writeFile('main.pyc', 'bytecode');
|
||||
writeFile('app.log', 'log data');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledTimes(1);
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ filename: 'main.py' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should respect workspace .gitignore', async () => {
|
||||
writeFile('.gitignore', 'secret.txt\ndata/\n');
|
||||
writeFile('README.md', 'readme');
|
||||
writeFile('secret.txt', 'password');
|
||||
writeFile('data/dump.sql', 'sql');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
const filenames = mockTrpcClient.agentDocument.upsertDocument.mutate.mock.calls.map(
|
||||
(c: any[]) => c[0].filename,
|
||||
);
|
||||
expect(filenames).toContain('README.md');
|
||||
expect(filenames).not.toContain('secret.txt');
|
||||
expect(filenames).not.toContain('data/dump.sql');
|
||||
});
|
||||
|
||||
it('should skip binary files during import', async () => {
|
||||
writeFile('readme.md', 'text content');
|
||||
// Write a file with null bytes (binary)
|
||||
const binPath = path.join(tmpDir, 'image.dat');
|
||||
fs.writeFileSync(binPath, Buffer.from([0x89, 0x50, 0x4e, 0x47, 0x00, 0x00, 0x01]));
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
// Only the text file should be upserted
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledTimes(1);
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ filename: 'readme.md' }),
|
||||
);
|
||||
// Binary file should show as skipped in output
|
||||
const allOutput = consoleSpy.mock.calls.map((c: any[]) => c[0]).join('\n');
|
||||
expect(allOutput).toContain('skipped');
|
||||
});
|
||||
|
||||
it('should exclude database files by extension', async () => {
|
||||
writeFile('data.md', 'notes');
|
||||
writeFile('local.sqlite', 'fake-sqlite');
|
||||
writeFile('app.db', 'fake-db');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledTimes(1);
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ filename: 'data.md' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should collect files in subdirectories', async () => {
|
||||
writeFile('docs/guide.md', 'guide');
|
||||
writeFile('docs/api.md', 'api');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
const filenames = mockTrpcClient.agentDocument.upsertDocument.mutate.mock.calls
|
||||
.map((c: any[]) => c[0].filename)
|
||||
.sort();
|
||||
expect(filenames).toEqual(['docs/api.md', 'docs/guide.md']);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Dry run ────────────────────────────────────────
|
||||
|
||||
describe('--dry-run', () => {
|
||||
it('should list files without calling API', async () => {
|
||||
writeFile('file.md', 'content');
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--dry-run',
|
||||
]);
|
||||
|
||||
expect(mockGetTrpcClient).not.toHaveBeenCalled();
|
||||
expect(mockTrpcClient.agent.createAgent.mutate).not.toHaveBeenCalled();
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).not.toHaveBeenCalled();
|
||||
expect(log.info).toHaveBeenCalledWith(expect.stringContaining('Dry run'));
|
||||
});
|
||||
});
|
||||
|
||||
// ── Agent resolution ───────────────────────────────
|
||||
|
||||
describe('agent resolution', () => {
|
||||
it('should use --agent-id directly when provided', async () => {
|
||||
writeFile('file.md', 'content');
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--agent-id',
|
||||
'agt_existing',
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agent.createAgent.mutate).not.toHaveBeenCalled();
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ agentId: 'agt_existing' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should resolve agent by --slug', async () => {
|
||||
writeFile('file.md', 'content');
|
||||
mockTrpcClient.agent.getBuiltinAgent.query.mockResolvedValue({ id: 'agt_inbox' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--slug',
|
||||
'inbox',
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agent.getBuiltinAgent.query).toHaveBeenCalledWith({ slug: 'inbox' });
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ agentId: 'agt_inbox' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should create a new agent by default', async () => {
|
||||
writeFile('file.md', 'content');
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_new' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agent.createAgent.mutate).toHaveBeenCalledTimes(1);
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ agentId: 'agt_new' }),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
// ── Confirmation ───────────────────────────────────
|
||||
|
||||
describe('confirmation', () => {
|
||||
it('should cancel when user declines', async () => {
|
||||
writeFile('file.md', 'content');
|
||||
mockConfirm.mockResolvedValue(false);
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync(['node', 'test', 'migrate', 'openclaw', '--source', tmpDir]);
|
||||
|
||||
expect(mockTrpcClient.agent.createAgent.mutate).not.toHaveBeenCalled();
|
||||
expect(consoleSpy).toHaveBeenCalledWith('Cancelled.');
|
||||
});
|
||||
|
||||
it('should skip confirmation with --yes', async () => {
|
||||
writeFile('file.md', 'content');
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockConfirm).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
// ── Error handling ─────────────────────────────────
|
||||
|
||||
describe('error handling', () => {
|
||||
it('should exit when source path does not exist', async () => {
|
||||
const program = createProgram();
|
||||
await program
|
||||
.parseAsync(['node', 'test', 'migrate', 'openclaw', '--source', '/nonexistent/path'])
|
||||
.catch(() => {}); // process.exit throws
|
||||
|
||||
expect(exitSpy).toHaveBeenCalledWith(1);
|
||||
expect(log.error).toHaveBeenCalledWith(expect.stringContaining('not found'));
|
||||
});
|
||||
|
||||
it('should report failed files without aborting', async () => {
|
||||
writeFile('a.md', 'ok');
|
||||
writeFile('b.md', 'fail');
|
||||
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
// Files are iterated in readdir order; mock first success then failure
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate
|
||||
.mockResolvedValueOnce({})
|
||||
.mockRejectedValueOnce(new Error('upload error'));
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
expect(mockTrpcClient.agentDocument.upsertDocument.mutate).toHaveBeenCalledTimes(2);
|
||||
const allOutput = consoleSpy.mock.calls.map((c: any[]) => c[0]).join('\n');
|
||||
expect(allOutput).toContain('1 imported');
|
||||
expect(allOutput).toContain('1 failed');
|
||||
});
|
||||
|
||||
it('should show no files message for empty workspace', async () => {
|
||||
// Only excluded items
|
||||
writeFile('.git/config', 'git');
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--dry-run',
|
||||
]);
|
||||
|
||||
expect(log.info).toHaveBeenCalledWith('No files found in workspace.');
|
||||
});
|
||||
});
|
||||
|
||||
// ── Output ─────────────────────────────────────────
|
||||
|
||||
describe('output', () => {
|
||||
it('should print agent URL on completion', async () => {
|
||||
writeFile('file.md', 'content');
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_abc123' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
const allOutput = consoleSpy.mock.calls.map((c: any[]) => c[0]).join('\n');
|
||||
expect(allOutput).toContain('https://app.lobehub.com/agent/agt_abc123');
|
||||
});
|
||||
|
||||
it('should show friendly completion message on success', async () => {
|
||||
writeFile('file.md', 'content');
|
||||
mockTrpcClient.agent.createAgent.mutate.mockResolvedValue({ agentId: 'agt_test' });
|
||||
mockTrpcClient.agentDocument.upsertDocument.mutate.mockResolvedValue({});
|
||||
|
||||
const program = createProgram();
|
||||
await program.parseAsync([
|
||||
'node',
|
||||
'test',
|
||||
'migrate',
|
||||
'openclaw',
|
||||
'--source',
|
||||
tmpDir,
|
||||
'--yes',
|
||||
]);
|
||||
|
||||
const allOutput = consoleSpy.mock.calls.map((c: any[]) => c[0]).join('\n');
|
||||
expect(allOutput).toContain('Migration complete');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,466 @@
|
||||
import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
import path from 'node:path';
|
||||
|
||||
import type { Command } from 'commander';
|
||||
import ignore from 'ignore';
|
||||
import pc from 'picocolors';
|
||||
|
||||
import type { TrpcClient } from '../../api/client';
|
||||
import { getTrpcClient } from '../../api/client';
|
||||
import { resolveServerUrl } from '../../settings';
|
||||
import { confirm } from '../../utils/format';
|
||||
import { log } from '../../utils/logger';
|
||||
|
||||
const DEFAULT_AGENT_NAME = 'OpenClaw';
|
||||
|
||||
// Files to look for agent identity (tried in order)
|
||||
const IDENTITY_FILES = ['IDENTITY.md', 'SOUL.md'];
|
||||
|
||||
// Default ignore rules (gitignore syntax) applied when no .gitignore is found
|
||||
const DEFAULT_IGNORE_RULES = [
|
||||
// VCS
|
||||
'.git',
|
||||
'.svn',
|
||||
'.hg',
|
||||
|
||||
// OpenClaw internal
|
||||
'.openclaw',
|
||||
|
||||
// OS artifacts
|
||||
'.DS_Store',
|
||||
'Thumbs.db',
|
||||
'desktop.ini',
|
||||
|
||||
// IDE / editor
|
||||
'.idea',
|
||||
'.vscode',
|
||||
'.fleet',
|
||||
'.cursor',
|
||||
'.zed',
|
||||
'*.swp',
|
||||
'*.swo',
|
||||
'*~',
|
||||
|
||||
// Dependencies
|
||||
'node_modules',
|
||||
'.pnp',
|
||||
'.yarn',
|
||||
'bower_components',
|
||||
'vendor',
|
||||
'jspm_packages',
|
||||
|
||||
// Python
|
||||
'.venv',
|
||||
'venv',
|
||||
'env',
|
||||
'__pycache__',
|
||||
'*.pyc',
|
||||
'*.pyo',
|
||||
'.mypy_cache',
|
||||
'.ruff_cache',
|
||||
'.pytest_cache',
|
||||
'.tox',
|
||||
'.eggs',
|
||||
'*.egg-info',
|
||||
|
||||
// Ruby
|
||||
'.bundle',
|
||||
|
||||
// Rust
|
||||
'target',
|
||||
|
||||
// Go
|
||||
'go.sum',
|
||||
|
||||
// Java / JVM
|
||||
'.gradle',
|
||||
'.m2',
|
||||
|
||||
// .NET
|
||||
'bin',
|
||||
'obj',
|
||||
'packages',
|
||||
|
||||
// Build / cache / output
|
||||
'.cache',
|
||||
'.parcel-cache',
|
||||
'.next',
|
||||
'.nuxt',
|
||||
'.turbo',
|
||||
'.output',
|
||||
'dist',
|
||||
'build',
|
||||
'out',
|
||||
'.sass-cache',
|
||||
|
||||
// Env / secrets
|
||||
'.env',
|
||||
'.env.*',
|
||||
|
||||
// Test / coverage
|
||||
'coverage',
|
||||
'.nyc_output',
|
||||
|
||||
// Infra
|
||||
'.terraform',
|
||||
|
||||
// Temp
|
||||
'tmp',
|
||||
'.tmp',
|
||||
|
||||
// Logs
|
||||
'*.log',
|
||||
'logs',
|
||||
|
||||
// Databases
|
||||
'*.sqlite',
|
||||
'*.sqlite3',
|
||||
'*.db',
|
||||
'*.db-shm',
|
||||
'*.db-wal',
|
||||
'*.ldb',
|
||||
'*.mdb',
|
||||
'*.accdb',
|
||||
|
||||
// Archives / binaries
|
||||
'*.zip',
|
||||
'*.tar',
|
||||
'*.tar.gz',
|
||||
'*.tgz',
|
||||
'*.gz',
|
||||
'*.bz2',
|
||||
'*.xz',
|
||||
'*.rar',
|
||||
'*.7z',
|
||||
'*.jar',
|
||||
'*.war',
|
||||
'*.dll',
|
||||
'*.so',
|
||||
'*.dylib',
|
||||
'*.exe',
|
||||
'*.bin',
|
||||
'*.o',
|
||||
'*.a',
|
||||
'*.lib',
|
||||
'*.class',
|
||||
|
||||
// Images / media / fonts
|
||||
'*.png',
|
||||
'*.jpg',
|
||||
'*.jpeg',
|
||||
'*.gif',
|
||||
'*.bmp',
|
||||
'*.ico',
|
||||
'*.webp',
|
||||
'*.svg',
|
||||
'*.mp3',
|
||||
'*.mp4',
|
||||
'*.wav',
|
||||
'*.avi',
|
||||
'*.mov',
|
||||
'*.mkv',
|
||||
'*.flac',
|
||||
'*.ogg',
|
||||
'*.pdf',
|
||||
'*.woff',
|
||||
'*.woff2',
|
||||
'*.ttf',
|
||||
'*.otf',
|
||||
'*.eot',
|
||||
|
||||
// Lock files
|
||||
'package-lock.json',
|
||||
'yarn.lock',
|
||||
'pnpm-lock.yaml',
|
||||
'Gemfile.lock',
|
||||
'Cargo.lock',
|
||||
'poetry.lock',
|
||||
'composer.lock',
|
||||
];
|
||||
|
||||
interface AgentProfile {
|
||||
avatar?: string;
|
||||
description?: string;
|
||||
title: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to extract the agent name, description, and avatar emoji from
|
||||
* IDENTITY.md or SOUL.md. Falls back to "OpenClaw" if neither file
|
||||
* exists or parsing fails.
|
||||
*/
|
||||
function readAgentProfile(workspacePath: string): AgentProfile {
|
||||
for (const filename of IDENTITY_FILES) {
|
||||
const filePath = path.join(workspacePath, filename);
|
||||
if (!fs.existsSync(filePath)) continue;
|
||||
|
||||
const content = fs.readFileSync(filePath, 'utf8');
|
||||
|
||||
// Try to extract **Name:** value
|
||||
const nameMatch = content.match(/\*{0,2}Name:?\*{0,2}\s*(.+)/i);
|
||||
const title = nameMatch ? nameMatch[1].trim() : DEFAULT_AGENT_NAME;
|
||||
|
||||
// Try to extract **Creature:** or **Vibe:** or **Description:** as description
|
||||
const descMatch = content.match(/\*{0,2}(?:Creature|Vibe|Description):?\*{0,2}\s*(.+)/i);
|
||||
const description = descMatch ? descMatch[1].trim() : undefined;
|
||||
|
||||
// Try to extract **Emoji:** value (single emoji)
|
||||
const emojiMatch = content.match(/\*{0,2}Emoji:?\*{0,2}\s*(.+)/i);
|
||||
const rawAvatar = emojiMatch ? emojiMatch[1].trim() : undefined;
|
||||
// Filter out placeholder text like (待定), _(待定)_, (TBD), N/A, etc.
|
||||
const isPlaceholder =
|
||||
rawAvatar && /^[_*((].*[))_*]$|^(?:tbd|todo|n\/?a|none|待定|未定)$/i.test(rawAvatar);
|
||||
const avatar = rawAvatar && !isPlaceholder ? rawAvatar : undefined;
|
||||
|
||||
return { avatar, description, title };
|
||||
}
|
||||
|
||||
return { title: DEFAULT_AGENT_NAME };
|
||||
}
|
||||
|
||||
/**
|
||||
* Build an ignore filter for the workspace. Uses .gitignore if present,
|
||||
* otherwise falls back to a comprehensive default rule set.
|
||||
*/
|
||||
function buildIgnoreFilter(workspacePath: string) {
|
||||
const ig = ignore();
|
||||
|
||||
const gitignorePath = path.join(workspacePath, '.gitignore');
|
||||
if (fs.existsSync(gitignorePath)) {
|
||||
ig.add(fs.readFileSync(gitignorePath, 'utf8'));
|
||||
}
|
||||
|
||||
// Always apply default rules on top
|
||||
ig.add(DEFAULT_IGNORE_RULES);
|
||||
|
||||
return ig;
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively collect all files under `dir`, filtered by ignore rules.
|
||||
* Returns paths relative to `baseDir`.
|
||||
*/
|
||||
function collectFiles(dir: string, baseDir: string, ig: ReturnType<typeof ignore>): string[] {
|
||||
const results: string[] = [];
|
||||
|
||||
for (const entry of fs.readdirSync(dir, { withFileTypes: true })) {
|
||||
const relativePath = path.relative(baseDir, path.join(dir, entry.name));
|
||||
|
||||
// Directories need a trailing slash for ignore to match correctly
|
||||
const testPath = entry.isDirectory() ? `${relativePath}/` : relativePath;
|
||||
if (ig.ignores(testPath)) continue;
|
||||
|
||||
const fullPath = path.join(dir, entry.name);
|
||||
|
||||
if (entry.isDirectory()) {
|
||||
results.push(...collectFiles(fullPath, baseDir, ig));
|
||||
} else if (entry.isFile()) {
|
||||
results.push(relativePath);
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
/**
|
||||
* Quick check: read the first 8KB and look for null bytes.
|
||||
* If found, the file is likely binary and should be skipped.
|
||||
*/
|
||||
function isBinaryFile(filePath: string): boolean {
|
||||
const fd = fs.openSync(filePath, 'r');
|
||||
try {
|
||||
const buf = Buffer.alloc(8192);
|
||||
const bytesRead = fs.readSync(fd, buf, 0, 8192, 0);
|
||||
for (let i = 0; i < bytesRead; i++) {
|
||||
if (buf[i] === 0) return true;
|
||||
}
|
||||
return false;
|
||||
} finally {
|
||||
fs.closeSync(fd);
|
||||
}
|
||||
}
|
||||
|
||||
function formatAgentLabel(profile: AgentProfile): string {
|
||||
return profile.avatar ? `${profile.avatar} ${profile.title}` : profile.title;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve the target agent ID.
|
||||
* Priority: --agent-id > --slug > create new agent from workspace profile.
|
||||
*/
|
||||
async function resolveAgentId(
|
||||
client: TrpcClient,
|
||||
opts: { agentId?: string; slug?: string },
|
||||
profile: AgentProfile,
|
||||
): Promise<string> {
|
||||
if (opts.agentId) return opts.agentId;
|
||||
|
||||
if (opts.slug) {
|
||||
const agent = await client.agent.getBuiltinAgent.query({ slug: opts.slug });
|
||||
if (!agent) {
|
||||
log.error(`Agent not found for slug: ${opts.slug}`);
|
||||
process.exit(1);
|
||||
}
|
||||
return agent.id;
|
||||
}
|
||||
|
||||
const label = formatAgentLabel(profile);
|
||||
log.info(`Creating new agent ${pc.bold(label)}...`);
|
||||
const result = await client.agent.createAgent.mutate({
|
||||
config: {
|
||||
avatar: profile.avatar,
|
||||
description: profile.description,
|
||||
title: profile.title,
|
||||
},
|
||||
});
|
||||
|
||||
const id = result.agentId;
|
||||
if (!id) {
|
||||
log.error('Failed to create agent — no agentId returned.');
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
console.log(`${pc.green('✓')} Agent created: ${pc.bold(label)}`);
|
||||
return id;
|
||||
}
|
||||
|
||||
export function registerOpenClawMigration(migrate: Command) {
|
||||
migrate
|
||||
.command('openclaw')
|
||||
.description('Import OpenClaw workspace files as agent documents')
|
||||
.option(
|
||||
'--source <path>',
|
||||
'Path to OpenClaw workspace',
|
||||
path.join(os.homedir(), '.openclaw', 'workspace'),
|
||||
)
|
||||
.option('--agent-id <id>', 'Import into an existing agent by ID')
|
||||
.option('--slug <slug>', 'Import into an existing agent by slug (e.g. "inbox")')
|
||||
.option('--dry-run', 'Preview files without importing')
|
||||
.option('--yes', 'Skip confirmation prompt')
|
||||
.action(
|
||||
async (options: {
|
||||
agentId?: string;
|
||||
dryRun?: boolean;
|
||||
slug?: string;
|
||||
source: string;
|
||||
yes?: boolean;
|
||||
}) => {
|
||||
// Check auth early so users don't scan files only to find out they're not logged in
|
||||
if (!options.dryRun) {
|
||||
await getTrpcClient();
|
||||
}
|
||||
|
||||
const workspacePath = path.resolve(options.source);
|
||||
|
||||
// Validate source directory
|
||||
if (!fs.existsSync(workspacePath)) {
|
||||
log.error(`OpenClaw workspace not found: ${workspacePath}`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
if (!fs.statSync(workspacePath).isDirectory()) {
|
||||
log.error(`Not a directory: ${workspacePath}`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// Read agent profile from workspace identity files
|
||||
const profile = readAgentProfile(workspacePath);
|
||||
const label = formatAgentLabel(profile);
|
||||
|
||||
// Collect files (respects .gitignore + default rules)
|
||||
const ig = buildIgnoreFilter(workspacePath);
|
||||
const files = collectFiles(workspacePath, workspacePath, ig);
|
||||
|
||||
if (files.length === 0) {
|
||||
log.info('No files found in workspace.');
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(
|
||||
`Found ${pc.bold(String(files.length))} file(s) in ${pc.dim(workspacePath)}:\n`,
|
||||
);
|
||||
for (const f of files) {
|
||||
console.log(` ${pc.dim('•')} ${f}`);
|
||||
}
|
||||
console.log();
|
||||
|
||||
if (options.dryRun) {
|
||||
log.info('Dry run — no changes made.');
|
||||
return;
|
||||
}
|
||||
|
||||
// Confirm
|
||||
if (!options.yes) {
|
||||
const target = options.agentId
|
||||
? `agent ${pc.bold(options.agentId)}`
|
||||
: options.slug
|
||||
? `agent slug "${pc.bold(options.slug)}"`
|
||||
: `a new ${pc.bold(label)} agent`;
|
||||
const confirmed = await confirm(
|
||||
`Import ${files.length} file(s) as agent documents into ${target}?`,
|
||||
);
|
||||
if (!confirmed) {
|
||||
console.log('Cancelled.');
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
const client = await getTrpcClient();
|
||||
|
||||
// Create or reuse agent
|
||||
const agentId = await resolveAgentId(client, options, profile);
|
||||
|
||||
console.log(`\nImporting to ${pc.bold(label)}...\n`);
|
||||
|
||||
let success = 0;
|
||||
let failed = 0;
|
||||
|
||||
let skipped = 0;
|
||||
|
||||
for (const relativePath of files) {
|
||||
const fullPath = path.join(workspacePath, relativePath);
|
||||
|
||||
try {
|
||||
// Skip binary files that slipped through the extension filter
|
||||
if (isBinaryFile(fullPath)) {
|
||||
console.log(` ${pc.dim('○')} ${relativePath} ${pc.dim('(binary, skipped)')}`);
|
||||
skipped++;
|
||||
continue;
|
||||
}
|
||||
|
||||
const content = fs.readFileSync(fullPath, 'utf8');
|
||||
const stat = fs.statSync(fullPath);
|
||||
|
||||
await client.agentDocument.upsertDocument.mutate({
|
||||
agentId,
|
||||
content,
|
||||
createdAt: stat.birthtime,
|
||||
filename: relativePath,
|
||||
updatedAt: stat.mtime,
|
||||
});
|
||||
console.log(` ${pc.green('✓')} ${relativePath}`);
|
||||
success++;
|
||||
} catch (err: any) {
|
||||
console.log(` ${pc.red('✗')} ${relativePath} — ${err.message || err}`);
|
||||
failed++;
|
||||
}
|
||||
}
|
||||
|
||||
const agentUrl = `${resolveServerUrl()}/agent/${agentId}`;
|
||||
const skippedInfo = skipped > 0 ? `, ${skipped} skipped` : '';
|
||||
console.log();
|
||||
if (failed === 0) {
|
||||
console.log(
|
||||
`${pc.green('✓')} Migration complete! ${pc.bold(String(success))} file(s) imported to ${pc.bold(label)}.${skippedInfo}`,
|
||||
);
|
||||
} else {
|
||||
console.log(
|
||||
`${pc.yellow('⚠')} Migration finished with issues: ${pc.bold(String(success))} imported, ${pc.red(String(failed))} failed${skippedInfo}.`,
|
||||
);
|
||||
}
|
||||
console.log(`\n ${pc.dim('→')} ${pc.underline(agentUrl)}`);
|
||||
console.log();
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -1,2 +1,3 @@
|
||||
export const OFFICIAL_AGENT_GATEWAY_URL = 'https://agent-gateway.lobehub.com';
|
||||
export const OFFICIAL_SERVER_URL = 'https://app.lobehub.com';
|
||||
export const OFFICIAL_GATEWAY_URL = 'https://device-gateway.lobehub.com';
|
||||
|
||||
@@ -20,6 +20,7 @@ import { registerLogoutCommand } from './commands/logout';
|
||||
import { registerManCommand } from './commands/man';
|
||||
import { registerMemoryCommand } from './commands/memory';
|
||||
import { registerMessageCommand } from './commands/message';
|
||||
import { registerMigrateCommand } from './commands/migrate';
|
||||
import { registerModelCommand } from './commands/model';
|
||||
import { registerPluginCommand } from './commands/plugin';
|
||||
import { registerProviderCommand } from './commands/provider';
|
||||
@@ -72,6 +73,7 @@ export function createProgram() {
|
||||
registerUserCommand(program);
|
||||
registerConfigCommand(program);
|
||||
registerEvalCommand(program);
|
||||
registerMigrateCommand(program);
|
||||
|
||||
return program;
|
||||
}
|
||||
|
||||
@@ -2,10 +2,11 @@ import fs from 'node:fs';
|
||||
import os from 'node:os';
|
||||
import path from 'node:path';
|
||||
|
||||
import { OFFICIAL_SERVER_URL } from '../constants/urls';
|
||||
import { OFFICIAL_AGENT_GATEWAY_URL, OFFICIAL_SERVER_URL } from '../constants/urls';
|
||||
import { log } from '../utils/logger';
|
||||
|
||||
export interface StoredSettings {
|
||||
agentGatewayUrl?: string;
|
||||
gatewayUrl?: string;
|
||||
serverUrl?: string;
|
||||
}
|
||||
@@ -25,15 +26,24 @@ export function resolveServerUrl(): string {
|
||||
return envServerUrl || settingsServerUrl || OFFICIAL_SERVER_URL;
|
||||
}
|
||||
|
||||
export function resolveAgentGatewayUrl(): string | undefined {
|
||||
const envUrl = normalizeUrl(process.env.AGENT_GATEWAY_URL);
|
||||
const settingsUrl = normalizeUrl(loadSettings()?.agentGatewayUrl);
|
||||
|
||||
return envUrl || settingsUrl || OFFICIAL_AGENT_GATEWAY_URL;
|
||||
}
|
||||
|
||||
export function saveSettings(settings: StoredSettings): void {
|
||||
const serverUrl = normalizeUrl(settings.serverUrl);
|
||||
const agentGatewayUrl = normalizeUrl(settings.agentGatewayUrl);
|
||||
const gatewayUrl = normalizeUrl(settings.gatewayUrl);
|
||||
const serverUrl = normalizeUrl(settings.serverUrl);
|
||||
const normalized: StoredSettings = {
|
||||
agentGatewayUrl: agentGatewayUrl === OFFICIAL_AGENT_GATEWAY_URL ? undefined : agentGatewayUrl,
|
||||
gatewayUrl,
|
||||
serverUrl: serverUrl === OFFICIAL_SERVER_URL ? undefined : serverUrl,
|
||||
};
|
||||
|
||||
if (!normalized.serverUrl && !normalized.gatewayUrl) {
|
||||
if (!normalized.serverUrl && !normalized.gatewayUrl && !normalized.agentGatewayUrl) {
|
||||
try {
|
||||
fs.unlinkSync(SETTINGS_FILE);
|
||||
} catch {}
|
||||
@@ -50,14 +60,16 @@ export function loadSettings(): StoredSettings | null {
|
||||
try {
|
||||
const data = fs.readFileSync(SETTINGS_FILE, 'utf8');
|
||||
const parsed = JSON.parse(data) as StoredSettings;
|
||||
const agentGatewayUrl = normalizeUrl(parsed.agentGatewayUrl);
|
||||
const gatewayUrl = normalizeUrl(parsed.gatewayUrl);
|
||||
const serverUrl = normalizeUrl(parsed.serverUrl);
|
||||
const normalized: StoredSettings = {
|
||||
agentGatewayUrl: agentGatewayUrl === OFFICIAL_AGENT_GATEWAY_URL ? undefined : agentGatewayUrl,
|
||||
gatewayUrl,
|
||||
serverUrl: serverUrl === OFFICIAL_SERVER_URL ? undefined : serverUrl,
|
||||
};
|
||||
|
||||
if (!normalized.serverUrl && !normalized.gatewayUrl) return null;
|
||||
if (!normalized.serverUrl && !normalized.gatewayUrl && !normalized.agentGatewayUrl) return null;
|
||||
|
||||
return normalized;
|
||||
} catch {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { streamAgentEvents } from './agentStream';
|
||||
import { streamAgentEvents, streamAgentEventsViaWebSocket } from './agentStream';
|
||||
|
||||
vi.mock('./logger', () => ({
|
||||
log: {
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
heartbeat: vi.fn(),
|
||||
info: vi.fn(),
|
||||
@@ -193,3 +194,391 @@ describe('streamAgentEvents', () => {
|
||||
exitSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
// ── WebSocket stream tests ──────────────────────────────
|
||||
|
||||
let capturedWs: MockWebSocket | undefined;
|
||||
|
||||
class MockWebSocket {
|
||||
static OPEN = 1;
|
||||
static CONNECTING = 0;
|
||||
static CLOSED = 3;
|
||||
|
||||
readyState = MockWebSocket.CONNECTING;
|
||||
onopen: ((ev: any) => void) | null = null;
|
||||
onmessage: ((ev: any) => void) | null = null;
|
||||
onerror: ((ev: any) => void) | null = null;
|
||||
onclose: ((ev: any) => void) | null = null;
|
||||
|
||||
sent: string[] = [];
|
||||
private autoAuthSuccess = true;
|
||||
|
||||
constructor(
|
||||
public url: string,
|
||||
autoAuth = true,
|
||||
) {
|
||||
this.autoAuthSuccess = autoAuth;
|
||||
capturedWs = this; // eslint-disable-line @typescript-eslint/no-this-alias
|
||||
// Trigger onopen on next microtask (after handlers are assigned)
|
||||
queueMicrotask(() => {
|
||||
this.readyState = MockWebSocket.OPEN;
|
||||
this.onopen?.({ type: 'open' });
|
||||
});
|
||||
}
|
||||
|
||||
send(data: string) {
|
||||
this.sent.push(data);
|
||||
const msg = JSON.parse(data);
|
||||
|
||||
if (msg.type === 'auth' && this.autoAuthSuccess) {
|
||||
queueMicrotask(() => {
|
||||
this.onmessage?.({ data: JSON.stringify({ type: 'auth_success' }) });
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
close() {
|
||||
this.readyState = MockWebSocket.CLOSED;
|
||||
// Async like real WebSocket — fires after current microtask
|
||||
queueMicrotask(() => this.onclose?.({ code: 1000, reason: '' }));
|
||||
}
|
||||
|
||||
simulateMessage(msg: Record<string, unknown>) {
|
||||
this.onmessage?.({ data: JSON.stringify(msg) });
|
||||
}
|
||||
}
|
||||
|
||||
describe('streamAgentEventsViaWebSocket', () => {
|
||||
let stdoutSpy: ReturnType<typeof vi.spyOn>;
|
||||
let consoleSpy: ReturnType<typeof vi.spyOn>;
|
||||
const originalWebSocket = globalThis.WebSocket;
|
||||
|
||||
beforeEach(() => {
|
||||
capturedWs = undefined;
|
||||
stdoutSpy = vi.spyOn(process.stdout, 'write').mockImplementation(() => true);
|
||||
consoleSpy = vi.spyOn(console, 'log').mockImplementation(() => {});
|
||||
(globalThis as any).WebSocket = MockWebSocket;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
stdoutSpy.mockRestore();
|
||||
consoleSpy.mockRestore();
|
||||
globalThis.WebSocket = originalWebSocket;
|
||||
});
|
||||
|
||||
/** Wait for microtasks + short delay so WS open/auth cycle completes */
|
||||
const flush = () => new Promise((r) => setTimeout(r, 20));
|
||||
|
||||
it('should connect, authenticate, and send resume', async () => {
|
||||
const promise = streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: 'https://gw.test.com',
|
||||
operationId: 'op-1',
|
||||
token: 'test-token',
|
||||
});
|
||||
|
||||
await flush();
|
||||
|
||||
const ws = capturedWs!;
|
||||
expect(ws.sent.map((s) => JSON.parse(s))).toEqual([
|
||||
{ token: 'test-token', type: 'auth' },
|
||||
{ lastEventId: '', type: 'resume' },
|
||||
]);
|
||||
|
||||
ws.simulateMessage({ id: '1', type: 'session_complete' });
|
||||
await promise;
|
||||
});
|
||||
|
||||
it('should render agent_event messages using existing renderEvent', async () => {
|
||||
const promise = streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: 'https://gw.test.com',
|
||||
operationId: 'op-1',
|
||||
token: 'test-token',
|
||||
});
|
||||
|
||||
await flush();
|
||||
const ws = capturedWs!;
|
||||
|
||||
ws.simulateMessage({
|
||||
event: { data: null, operationId: 'op-1', stepIndex: 0, timestamp: 1, type: 'step_start' },
|
||||
id: '1',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { chunkType: 'text', content: 'Hello WS!' },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 2,
|
||||
type: 'stream_chunk',
|
||||
},
|
||||
id: '2',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { stepCount: 1 },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 3,
|
||||
type: 'agent_runtime_end',
|
||||
},
|
||||
id: '3',
|
||||
type: 'agent_event',
|
||||
});
|
||||
|
||||
await promise;
|
||||
expect(stdoutSpy).toHaveBeenCalledWith('Hello WS!');
|
||||
});
|
||||
|
||||
it('should output JSON when json option is set', async () => {
|
||||
const promise = streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: 'https://gw.test.com',
|
||||
json: true,
|
||||
operationId: 'op-1',
|
||||
token: 'test-token',
|
||||
});
|
||||
|
||||
await flush();
|
||||
const ws = capturedWs!;
|
||||
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: null,
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 1,
|
||||
type: 'agent_runtime_init',
|
||||
},
|
||||
id: '1',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { stepCount: 1 },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 2,
|
||||
type: 'agent_runtime_end',
|
||||
},
|
||||
id: '2',
|
||||
type: 'agent_event',
|
||||
});
|
||||
|
||||
await promise;
|
||||
|
||||
expect(consoleSpy).toHaveBeenCalledWith(expect.stringContaining('"agent_runtime_init"'));
|
||||
expect(consoleSpy).toHaveBeenCalledWith(expect.stringContaining('"agent_runtime_end"'));
|
||||
});
|
||||
|
||||
it('should reject on auth failure', async () => {
|
||||
// Override mock to return auth_failed instead of auth_success
|
||||
(globalThis as any).WebSocket = class extends MockWebSocket {
|
||||
constructor(url: string) {
|
||||
super(url, false); // disable auto auth_success
|
||||
capturedWs = this; // eslint-disable-line @typescript-eslint/no-this-alias
|
||||
}
|
||||
|
||||
override send(data: string) {
|
||||
this.sent.push(data);
|
||||
const msg = JSON.parse(data);
|
||||
if (msg.type === 'auth') {
|
||||
queueMicrotask(() => {
|
||||
this.onmessage?.({
|
||||
data: JSON.stringify({ reason: 'invalid token', type: 'auth_failed' }),
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
await expect(
|
||||
streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: 'https://gw.test.com',
|
||||
operationId: 'op-1',
|
||||
token: 'bad-token',
|
||||
}),
|
||||
).rejects.toThrow('Gateway auth failed');
|
||||
});
|
||||
|
||||
it('should resolve on session_complete', async () => {
|
||||
const promise = streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: 'https://gw.test.com',
|
||||
operationId: 'op-1',
|
||||
token: 'test-token',
|
||||
});
|
||||
|
||||
await flush();
|
||||
capturedWs!.simulateMessage({ id: '1', summary: 'All done', type: 'session_complete' });
|
||||
|
||||
await expect(promise).resolves.toBeUndefined();
|
||||
});
|
||||
|
||||
it('should ignore heartbeat_ack messages', async () => {
|
||||
const promise = streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: 'https://gw.test.com',
|
||||
operationId: 'op-1',
|
||||
token: 'test-token',
|
||||
});
|
||||
|
||||
await flush();
|
||||
const ws = capturedWs!;
|
||||
|
||||
ws.simulateMessage({ type: 'heartbeat_ack' });
|
||||
expect(stdoutSpy).not.toHaveBeenCalled();
|
||||
|
||||
ws.simulateMessage({ id: '1', type: 'session_complete' });
|
||||
await promise;
|
||||
});
|
||||
|
||||
it('should construct correct WebSocket URL from HTTPS gateway URL', async () => {
|
||||
const promise = streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: 'https://agent-gateway.lobehub.com',
|
||||
operationId: 'op-123',
|
||||
token: 'tok',
|
||||
});
|
||||
|
||||
await flush();
|
||||
expect(capturedWs!.url).toBe('wss://agent-gateway.lobehub.com/ws?operationId=op-123');
|
||||
|
||||
capturedWs!.simulateMessage({ id: '1', type: 'session_complete' });
|
||||
await promise;
|
||||
});
|
||||
|
||||
it('should render a multi-step agent run with tool calls', async () => {
|
||||
const promise = streamAgentEventsViaWebSocket({
|
||||
gatewayUrl: 'https://gw.test.com',
|
||||
operationId: 'op-1',
|
||||
token: 'tok',
|
||||
verbose: true,
|
||||
});
|
||||
|
||||
await flush();
|
||||
const ws = capturedWs!;
|
||||
const { log } = await import('./logger');
|
||||
|
||||
// Step 1: thinking + text + tool call
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: null,
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 1,
|
||||
type: 'agent_runtime_init',
|
||||
},
|
||||
id: '1',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: { data: null, operationId: 'op-1', stepIndex: 0, timestamp: 2, type: 'step_start' },
|
||||
id: '2',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { chunkType: 'reasoning', reasoning: 'Let me search...' },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 3,
|
||||
type: 'stream_chunk',
|
||||
},
|
||||
id: '3',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { chunkType: 'text', content: 'Searching for news.' },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 4,
|
||||
type: 'stream_chunk',
|
||||
},
|
||||
id: '4',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { toolCalling: { apiName: 'search', id: 'tc-1' } },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 5,
|
||||
type: 'tool_start',
|
||||
},
|
||||
id: '5',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: { data: null, operationId: 'op-1', stepIndex: 0, timestamp: 6, type: 'stream_end' },
|
||||
id: '6',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { stepIndex: 0 },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
timestamp: 7,
|
||||
type: 'step_complete',
|
||||
},
|
||||
id: '7',
|
||||
type: 'agent_event',
|
||||
});
|
||||
|
||||
// Step 2: tool result + final text
|
||||
ws.simulateMessage({
|
||||
event: { data: null, operationId: 'op-1', stepIndex: 1, timestamp: 8, type: 'step_start' },
|
||||
id: '8',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: {
|
||||
isSuccess: true,
|
||||
payload: { toolCalling: { id: 'tc-1' } },
|
||||
result: { content: 'Results...' },
|
||||
},
|
||||
operationId: 'op-1',
|
||||
stepIndex: 1,
|
||||
timestamp: 9,
|
||||
type: 'tool_end',
|
||||
},
|
||||
id: '9',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { chunkType: 'text', content: 'Here are the results.' },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 1,
|
||||
timestamp: 10,
|
||||
type: 'stream_chunk',
|
||||
},
|
||||
id: '10',
|
||||
type: 'agent_event',
|
||||
});
|
||||
ws.simulateMessage({
|
||||
event: {
|
||||
data: { cost: { total: 0.05 }, stepCount: 2, usage: { total_tokens: 500 } },
|
||||
operationId: 'op-1',
|
||||
stepIndex: 1,
|
||||
timestamp: 11,
|
||||
type: 'agent_runtime_end',
|
||||
},
|
||||
id: '11',
|
||||
type: 'agent_event',
|
||||
});
|
||||
|
||||
await promise;
|
||||
|
||||
// Verify reasoning was rendered (dim)
|
||||
expect(stdoutSpy).toHaveBeenCalledWith(expect.stringContaining('Let me search...'));
|
||||
// Verify text chunks
|
||||
expect(stdoutSpy).toHaveBeenCalledWith('Searching for news.');
|
||||
expect(stdoutSpy).toHaveBeenCalledWith('Here are the results.');
|
||||
// Verify tool call was logged
|
||||
expect(log.toolCall).toHaveBeenCalledWith('search', 'tc-1', undefined);
|
||||
// Verify tool result was logged
|
||||
expect(log.toolResult).toHaveBeenCalled();
|
||||
// Verify finish line
|
||||
expect(consoleSpy).toHaveBeenCalledWith(expect.stringContaining('Agent finished'));
|
||||
});
|
||||
});
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import pc from 'picocolors';
|
||||
import urlJoin from 'url-join';
|
||||
|
||||
import { log } from './logger';
|
||||
|
||||
@@ -16,6 +17,12 @@ interface StreamOptions {
|
||||
verbose?: boolean;
|
||||
}
|
||||
|
||||
interface WebSocketStreamOptions extends StreamOptions {
|
||||
gatewayUrl: string;
|
||||
operationId: string;
|
||||
token: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to the agent SSE stream and render events to the terminal.
|
||||
* Resolves when the stream ends (agent_runtime_end or connection close).
|
||||
@@ -152,6 +159,126 @@ export function replayAgentEvents(events: AgentStreamEvent[], options: StreamOpt
|
||||
}
|
||||
}
|
||||
|
||||
const HEARTBEAT_INTERVAL = 30_000;
|
||||
|
||||
/**
|
||||
* Connect to the Agent Gateway via WebSocket and render events to the terminal.
|
||||
* Resolves when the session completes or the connection closes.
|
||||
*/
|
||||
export async function streamAgentEventsViaWebSocket(
|
||||
options: WebSocketStreamOptions,
|
||||
): Promise<void> {
|
||||
const { gatewayUrl, operationId, token, ...streamOpts } = options;
|
||||
const wsUrl = urlJoin(
|
||||
gatewayUrl.replace(/^http/, 'ws'),
|
||||
`/ws?operationId=${encodeURIComponent(operationId)}`,
|
||||
);
|
||||
|
||||
log.debug(`Connecting to gateway: ${wsUrl}`);
|
||||
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
const ws = new WebSocket(wsUrl);
|
||||
const jsonEvents: AgentStreamEvent[] = [];
|
||||
const ctx = createRenderContext();
|
||||
let lastEventId = '';
|
||||
let heartbeatTimer: ReturnType<typeof setInterval> | undefined;
|
||||
let jsonPrinted = false;
|
||||
|
||||
const cleanup = () => {
|
||||
if (heartbeatTimer) clearInterval(heartbeatTimer);
|
||||
if (ws.readyState === WebSocket.OPEN || ws.readyState === WebSocket.CONNECTING) {
|
||||
ws.close();
|
||||
}
|
||||
};
|
||||
|
||||
ws.onopen = () => {
|
||||
ws.send(JSON.stringify({ token, type: 'auth' }));
|
||||
};
|
||||
|
||||
ws.onmessage = (event) => {
|
||||
const msg = JSON.parse(event.data as string);
|
||||
|
||||
if (msg.type === 'auth_success') {
|
||||
log.debug('Gateway authenticated');
|
||||
// Request all buffered events (covers events pushed before WS connected)
|
||||
ws.send(JSON.stringify({ lastEventId: '', type: 'resume' }));
|
||||
heartbeatTimer = setInterval(() => {
|
||||
if (ws.readyState === WebSocket.OPEN) {
|
||||
ws.send(JSON.stringify({ type: 'heartbeat' }));
|
||||
}
|
||||
}, HEARTBEAT_INTERVAL);
|
||||
return;
|
||||
}
|
||||
|
||||
if (msg.type === 'auth_failed') {
|
||||
cleanup();
|
||||
reject(new Error(`Gateway auth failed: ${msg.reason}`));
|
||||
return;
|
||||
}
|
||||
|
||||
if (msg.type === 'heartbeat_ack') return;
|
||||
|
||||
if (msg.type === 'agent_event') {
|
||||
const agentEvent: AgentStreamEvent = msg.event;
|
||||
if (msg.id) lastEventId = msg.id;
|
||||
|
||||
if (streamOpts.json) {
|
||||
jsonEvents.push(agentEvent);
|
||||
} else {
|
||||
renderEvent(agentEvent, ctx, streamOpts);
|
||||
}
|
||||
|
||||
if (agentEvent.type === 'agent_runtime_end') {
|
||||
if (streamOpts.json && !jsonPrinted) {
|
||||
jsonPrinted = true;
|
||||
console.log(JSON.stringify(jsonEvents, null, 2));
|
||||
} else if (!streamOpts.json) {
|
||||
renderEnd(agentEvent);
|
||||
}
|
||||
cleanup();
|
||||
resolve();
|
||||
return;
|
||||
}
|
||||
|
||||
if (agentEvent.type === 'error') {
|
||||
if (streamOpts.json && !jsonPrinted) {
|
||||
jsonPrinted = true;
|
||||
console.log(JSON.stringify(jsonEvents, null, 2));
|
||||
}
|
||||
log.error(
|
||||
`Agent error: ${agentEvent.data?.message || agentEvent.data?.error || 'Unknown error'}`,
|
||||
);
|
||||
cleanup();
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
if (msg.type === 'session_complete') {
|
||||
if (streamOpts.json && jsonEvents.length > 0 && !jsonPrinted) {
|
||||
jsonPrinted = true;
|
||||
console.log(JSON.stringify(jsonEvents, null, 2));
|
||||
}
|
||||
cleanup();
|
||||
resolve();
|
||||
}
|
||||
};
|
||||
|
||||
ws.onerror = (err) => {
|
||||
cleanup();
|
||||
reject(err);
|
||||
};
|
||||
|
||||
ws.onclose = () => {
|
||||
if (heartbeatTimer) clearInterval(heartbeatTimer);
|
||||
if (streamOpts.json && jsonEvents.length > 0 && !jsonPrinted) {
|
||||
jsonPrinted = true;
|
||||
console.log(JSON.stringify(jsonEvents, null, 2));
|
||||
}
|
||||
resolve();
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// ── Render helpers ──────────────────────────────────────
|
||||
|
||||
interface RenderContext {
|
||||
|
||||
@@ -68,7 +68,7 @@
|
||||
"cookie": "^1.1.1",
|
||||
"cross-env": "^10.1.0",
|
||||
"diff": "^8.0.4",
|
||||
"electron": "41.0.3",
|
||||
"electron": "41.1.0",
|
||||
"electron-builder": "^26.8.1",
|
||||
"electron-devtools-installer": "4.0.0",
|
||||
"electron-is": "^3.0.0",
|
||||
|
||||
@@ -5,7 +5,7 @@ import path from 'node:path';
|
||||
import { pipeline } from 'node:stream/promises';
|
||||
import { fileURLToPath } from 'node:url';
|
||||
|
||||
const VERSION = '0.20.1';
|
||||
const VERSION = '0.24.0';
|
||||
|
||||
const __dirname = path.dirname(fileURLToPath(import.meta.url));
|
||||
const binDir = path.join(__dirname, '..', 'resources', 'bin');
|
||||
|
||||
@@ -9,7 +9,7 @@ import { tagWhite, writeJSON } from './utils';
|
||||
export const genDefaultLocale = () => {
|
||||
consola.info(`默认语言为 ${i18nConfig.entryLocale}...`);
|
||||
|
||||
// 确保入口语言目录存在
|
||||
// Ensure entry locale directory exists
|
||||
const entryLocaleDir = localeDir(i18nConfig.entryLocale);
|
||||
if (!existsSync(entryLocaleDir)) {
|
||||
mkdirSync(entryLocaleDir, { recursive: true });
|
||||
@@ -23,7 +23,7 @@ export const genDefaultLocale = () => {
|
||||
for (const [ns, value] of data) {
|
||||
const filepath = entryLocaleJsonFilepath(`${ns}.json`);
|
||||
|
||||
// 确保目录存在
|
||||
// Ensure directory exists
|
||||
const dir = dirname(filepath);
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true });
|
||||
|
||||
@@ -5,7 +5,7 @@ import { genDefaultLocale } from './genDefaultLocale';
|
||||
import { genDiff } from './genDiff';
|
||||
import { split } from './utils';
|
||||
|
||||
// 确保所有语言目录存在
|
||||
// Ensure all locale directories exist
|
||||
const ensureLocalesDirs = () => {
|
||||
[i18nConfig.entryLocale, ...i18nConfig.outputLocales].forEach((locale) => {
|
||||
const dir = localeDir(locale);
|
||||
@@ -15,20 +15,20 @@ const ensureLocalesDirs = () => {
|
||||
});
|
||||
};
|
||||
|
||||
// 运行工作流
|
||||
// Run workflow
|
||||
const run = async () => {
|
||||
// 确保目录存在
|
||||
// Ensure directories exist
|
||||
ensureLocalesDirs();
|
||||
|
||||
// 差异分析
|
||||
// Diff analysis
|
||||
split('差异分析');
|
||||
genDiff();
|
||||
|
||||
// 生成默认语言文件
|
||||
// Generate default locale files
|
||||
split('生成默认语言文件');
|
||||
genDefaultLocale();
|
||||
|
||||
// 生成国际化文件
|
||||
// Generate i18n files
|
||||
split('生成国际化文件');
|
||||
};
|
||||
|
||||
|
||||
@@ -21,6 +21,10 @@ tags:
|
||||
|
||||
Channels allow you to connect your LobeHub agents to external messaging platforms. Once connected, users can interact with your AI assistant directly in the chat apps they already use — no need to visit LobeHub.
|
||||
|
||||
> [!NOTE]
|
||||
>
|
||||
> WeChat currently requires an active subscription. If you are using the community edition without a subscription, the WeChat channel option may not appear in the Channels settings yet.
|
||||
|
||||
## Supported Platforms
|
||||
|
||||
| Platform | Description |
|
||||
@@ -29,7 +33,7 @@ Channels allow you to connect your LobeHub agents to external messaging platform
|
||||
| [Slack](/docs/usage/channels/slack) | Connect to Slack for channel and direct message conversations |
|
||||
| [Telegram](/docs/usage/channels/telegram) | Connect to Telegram for private and group conversations |
|
||||
| [QQ](/docs/usage/channels/qq) | Connect to QQ for group chats and direct messages |
|
||||
| [WeChat (微信)](/docs/usage/channels/wechat) | Connect to WeChat via iLink Bot for private and group chats |
|
||||
| [WeChat (微信)](/docs/usage/channels/wechat) | Connect to WeChat via iLink Bot for private and group chats (requires an active subscription) |
|
||||
| [Feishu (飞书)](/docs/usage/channels/feishu) | Connect to Feishu for team collaboration (Chinese version) |
|
||||
| [Lark](/docs/usage/channels/lark) | Connect to Lark for team collaboration (international version) |
|
||||
|
||||
@@ -53,6 +57,8 @@ Each channel integration works by linking a bot account on the target platform t
|
||||
- [Feishu (飞书)](/docs/usage/channels/feishu)
|
||||
- [Lark](/docs/usage/channels/lark)
|
||||
|
||||
If you do not see **WeChat** in the channel list, check that your account has an active subscription first.
|
||||
|
||||
## Feature Support
|
||||
|
||||
Text messages are supported across all platforms. Some features vary by platform:
|
||||
|
||||
@@ -20,6 +20,10 @@ tags:
|
||||
|
||||
渠道功能允许您将 LobeHub 代理连接到外部消息平台。一旦连接,用户可以直接在他们已经使用的聊天应用中与您的 AI 助手互动,无需访问 LobeHub。
|
||||
|
||||
> [!NOTE]
|
||||
>
|
||||
> 微信渠道目前需要有效订阅。如果您使用的是没有订阅的社区版,**渠道**设置中可能暂时不会显示微信选项。
|
||||
|
||||
## 支持的平台
|
||||
|
||||
| 平台 | 描述 |
|
||||
@@ -28,7 +32,7 @@ tags:
|
||||
| [Slack](/docs/usage/channels/slack) | 连接到 Slack,用于频道和私信对话 |
|
||||
| [Telegram](/docs/usage/channels/telegram) | 连接到 Telegram,用于私人和群组对话 |
|
||||
| [QQ](/docs/usage/channels/qq) | 连接到 QQ,用于群聊和私信 |
|
||||
| [微信](/docs/usage/channels/wechat) | 通过 iLink Bot 连接到微信,用于私聊和群聊 |
|
||||
| [微信](/docs/usage/channels/wechat) | 通过 iLink Bot 连接到微信,用于私聊和群聊(需要有效订阅) |
|
||||
| [飞书](/docs/usage/channels/feishu) | 连接到飞书,用于团队协作(中国版) |
|
||||
| [Lark](/docs/usage/channels/lark) | 连接到 Lark,用于团队协作(国际版) |
|
||||
|
||||
@@ -52,6 +56,8 @@ tags:
|
||||
- [飞书](/docs/usage/channels/feishu)
|
||||
- [Lark](/docs/usage/channels/lark)
|
||||
|
||||
如果您在渠道列表中看不到 **微信**,请先确认当前账户是否拥有有效订阅。
|
||||
|
||||
## 功能支持
|
||||
|
||||
所有平台均支持文本消息。某些功能因平台而异:
|
||||
|
||||
@@ -705,6 +705,8 @@
|
||||
"skillStore.tabs.community": "Community",
|
||||
"skillStore.tabs.custom": "Custom",
|
||||
"skillStore.tabs.lobehub": "LobeHub",
|
||||
"skillStore.tabs.mcp": "MCP",
|
||||
"skillStore.tabs.skills": "Skills",
|
||||
"skillStore.title": "Skill Store",
|
||||
"skillStore.wantMore.action": "Submit a request →",
|
||||
"skillStore.wantMore.feedback.message": "## Skill Name\n[Please fill in]\n\n## Use Case\nWhen I am ___, I need ___\n\n## Expected Features\n1.\n2.\n3.\n\n## Reference Examples\n(Optional) Are there any similar tools or features for reference?\n\n---\n💡 Tip: The more specific your description, the better we can meet your needs",
|
||||
|
||||
@@ -705,6 +705,8 @@
|
||||
"skillStore.tabs.community": "社区",
|
||||
"skillStore.tabs.custom": "自定义",
|
||||
"skillStore.tabs.lobehub": "LobeHub",
|
||||
"skillStore.tabs.mcp": "MCP",
|
||||
"skillStore.tabs.skills": "技能",
|
||||
"skillStore.title": "技能商店",
|
||||
"skillStore.wantMore.action": "提交申请 →",
|
||||
"skillStore.wantMore.feedback.message": "## 技能名称\n[请填写]\n\n## 使用场景\n当我在___时,我需要___\n\n## 期望功能\n1.\n2.\n3.\n\n## 参考示例\n(可选)是否有类似的工具或功能可供参考?\n\n---\n💡 提示:描述越具体,我们就越能满足您的需求",
|
||||
|
||||
+1
-1
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@lobehub/lobehub",
|
||||
"version": "2.1.46",
|
||||
"version": "2.1.47",
|
||||
"description": "LobeHub - an open-source,comprehensive AI Agent framework that supports speech synthesis, multimodal, and extensible Function Call plugin system. Supports one-click free deployment of your private ChatGPT/LLM web application.",
|
||||
"keywords": [
|
||||
"framework",
|
||||
|
||||
@@ -38,6 +38,7 @@ export enum DocumentLoadFormat {
|
||||
export enum PolicyLoad {
|
||||
ALWAYS = 'always',
|
||||
DISABLED = 'disabled',
|
||||
PROGRESSIVE = 'progressive',
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -7,6 +7,7 @@ import content from './SKILL.md';
|
||||
export const TaskIdentifier = 'task';
|
||||
|
||||
export const TaskSkill: BuiltinSkill = {
|
||||
avatar: '📋',
|
||||
content,
|
||||
description: 'Task management and execution — create, track, review, and complete tasks via CLI.',
|
||||
identifier: TaskIdentifier,
|
||||
|
||||
@@ -3,7 +3,8 @@
|
||||
"version": "1.0.0",
|
||||
"private": true,
|
||||
"exports": {
|
||||
".": "./src/index.ts"
|
||||
".": "./src/index.ts",
|
||||
"./executor": "./src/executor/index.ts"
|
||||
},
|
||||
"main": "./src/index.ts",
|
||||
"devDependencies": {
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
import type { BuiltinToolContext, BuiltinToolResult } from '@lobechat/types';
|
||||
import { BaseExecutor } from '@lobechat/types';
|
||||
|
||||
import { TaskIdentifier } from '../manifest';
|
||||
import { TaskApiName } from '../types';
|
||||
|
||||
class TaskExecutor extends BaseExecutor<typeof TaskApiName> {
|
||||
readonly identifier = TaskIdentifier;
|
||||
protected readonly apiEnum = TaskApiName;
|
||||
|
||||
// TODO (LOBE-6597): wire to store.createTask()
|
||||
createTask = async (_params: any, _ctx?: BuiltinToolContext): Promise<BuiltinToolResult> => {
|
||||
return { content: 'Not implemented: createTask', success: false };
|
||||
};
|
||||
|
||||
// TODO (LOBE-6597): wire to store.deleteTask()
|
||||
deleteTask = async (_params: any, _ctx?: BuiltinToolContext): Promise<BuiltinToolResult> => {
|
||||
return { content: 'Not implemented: deleteTask', success: false };
|
||||
};
|
||||
|
||||
// TODO (LOBE-6597): wire to store.updateTask() + addDependency/removeDependency
|
||||
editTask = async (_params: any, _ctx?: BuiltinToolContext): Promise<BuiltinToolResult> => {
|
||||
return { content: 'Not implemented: editTask', success: false };
|
||||
};
|
||||
|
||||
// TODO (LOBE-6597): wire to service.list() or store.tasks
|
||||
listTasks = async (_params: any, _ctx?: BuiltinToolContext): Promise<BuiltinToolResult> => {
|
||||
return { content: 'Not implemented: listTasks', success: false };
|
||||
};
|
||||
|
||||
// TODO (LOBE-6597): wire to lifecycle slice actions (runTask/pauseTask/cancelTask etc.)
|
||||
updateTaskStatus = async (
|
||||
_params: any,
|
||||
_ctx?: BuiltinToolContext,
|
||||
): Promise<BuiltinToolResult> => {
|
||||
return { content: 'Not implemented: updateTaskStatus', success: false };
|
||||
};
|
||||
|
||||
// TODO (LOBE-6597): wire to service.detail() or store.taskDetailMap
|
||||
viewTask = async (_params: any, _ctx?: BuiltinToolContext): Promise<BuiltinToolResult> => {
|
||||
return { content: 'Not implemented: viewTask', success: false };
|
||||
};
|
||||
}
|
||||
|
||||
export const taskExecutor = new TaskExecutor();
|
||||
@@ -24,7 +24,7 @@ export const DEFAULT_QUERY_REWRITE_SYSTEM_AGENT_ITEM: QueryRewriteSystemAgent =
|
||||
};
|
||||
|
||||
export const DEFAULT_INPUT_COMPLETION_SYSTEM_AGENT_ITEM: SystemAgentItem = {
|
||||
enabled: true,
|
||||
enabled: false,
|
||||
model: DEFAULT_MINI_SYSTEM_AGENT_ITEM.model,
|
||||
provider: DEFAULT_MINI_SYSTEM_AGENT_ITEM.provider,
|
||||
};
|
||||
|
||||
@@ -23,11 +23,13 @@ export type AgentDocumentLoadFormat = 'file' | 'raw';
|
||||
|
||||
export interface AgentContextDocument {
|
||||
content?: string;
|
||||
description?: string;
|
||||
filename: string;
|
||||
id?: string;
|
||||
loadPosition?: AgentDocumentInjectionPosition;
|
||||
loadRules?: AgentDocumentLoadRules;
|
||||
policyId?: string | null;
|
||||
policyLoad?: 'always' | 'progressive';
|
||||
policyLoadFormat?: AgentDocumentLoadFormat;
|
||||
title?: string;
|
||||
}
|
||||
@@ -104,13 +106,43 @@ export function formatDocument(
|
||||
}
|
||||
|
||||
/**
|
||||
* Combine multiple documents into a single string
|
||||
* Format a single progressive document as an index entry
|
||||
*/
|
||||
function formatProgressiveEntry(doc: AgentContextDocument): string {
|
||||
const parts: string[] = [];
|
||||
if (doc.id) parts.push(`[${doc.id}]`);
|
||||
parts.push(doc.filename);
|
||||
if (doc.title && doc.title !== doc.filename) parts.push(`— "${doc.title}"`);
|
||||
if (doc.description) parts.push(`: ${doc.description}`);
|
||||
return `- ${parts.join(' ')}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Combine multiple documents into a single string.
|
||||
* Progressive documents are grouped into a lightweight index block;
|
||||
* full-content documents are formatted individually.
|
||||
*/
|
||||
export function combineDocuments(
|
||||
docs: AgentContextDocument[],
|
||||
context: AgentDocumentFilterContext,
|
||||
): string {
|
||||
return docs.map((doc) => formatDocument(doc, context)).join('\n\n');
|
||||
const fullDocs = docs.filter((d) => d.policyLoad !== 'progressive');
|
||||
const progressiveDocs = docs.filter((d) => d.policyLoad === 'progressive');
|
||||
|
||||
const parts: string[] = [];
|
||||
|
||||
if (fullDocs.length > 0) {
|
||||
parts.push(fullDocs.map((doc) => formatDocument(doc, context)).join('\n\n'));
|
||||
}
|
||||
|
||||
if (progressiveDocs.length > 0) {
|
||||
const entries = progressiveDocs.map(formatProgressiveEntry).join('\n');
|
||||
parts.push(
|
||||
`<agent_documents_index>\nThe following documents are available. Use readDocument tool to access full content.\n${entries}\n</agent_documents_index>`,
|
||||
);
|
||||
}
|
||||
|
||||
return parts.join('\n\n');
|
||||
}
|
||||
|
||||
function approximateTokenTruncate(content: string, maxTokens: number): string {
|
||||
|
||||
@@ -173,6 +173,80 @@ describe('AgentDocumentInjector', () => {
|
||||
expect(result.messages[0].content).toContain('File mode content');
|
||||
expect(result.messages[0].content).toContain('</agent_document>');
|
||||
});
|
||||
|
||||
it('should inject progressive documents as index instead of full content', async () => {
|
||||
const provider = new AgentDocumentContextInjector({
|
||||
documents: [
|
||||
{
|
||||
content: 'Full content that should NOT appear',
|
||||
description: 'Core safety rules',
|
||||
filename: 'guardrails.md',
|
||||
id: 'doc-1',
|
||||
loadPosition: 'before-first-user',
|
||||
loadRules: { rule: 'always' },
|
||||
policyLoad: 'progressive',
|
||||
title: 'Guardrails',
|
||||
},
|
||||
{
|
||||
content: 'Another full content that should NOT appear',
|
||||
filename: 'notes.txt',
|
||||
id: 'doc-2',
|
||||
loadPosition: 'before-first-user',
|
||||
loadRules: { rule: 'always' },
|
||||
policyLoad: 'progressive',
|
||||
title: 'Notes',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const context = createContext([{ content: 'Hello', id: 'user-1', role: 'user' }]);
|
||||
const result = await provider.process(context);
|
||||
|
||||
const injected = result.messages[0].content;
|
||||
expect(injected).toContain('<agent_documents_index>');
|
||||
expect(injected).toContain('[doc-1]');
|
||||
expect(injected).toContain('guardrails.md');
|
||||
expect(injected).toContain('"Guardrails"');
|
||||
expect(injected).toContain('Core safety rules');
|
||||
expect(injected).toContain('[doc-2]');
|
||||
expect(injected).toContain('notes.txt');
|
||||
expect(injected).not.toContain('Full content that should NOT appear');
|
||||
expect(injected).not.toContain('Another full content that should NOT appear');
|
||||
expect(injected).toContain('</agent_documents_index>');
|
||||
});
|
||||
|
||||
it('should mix full-content and progressive documents', async () => {
|
||||
const provider = new AgentDocumentContextInjector({
|
||||
documents: [
|
||||
{
|
||||
content: 'Always-loaded full content',
|
||||
filename: 'full.md',
|
||||
loadPosition: 'before-first-user',
|
||||
loadRules: { rule: 'always' },
|
||||
policyLoad: 'always',
|
||||
},
|
||||
{
|
||||
content: 'Progressive content hidden',
|
||||
description: 'A summary doc',
|
||||
filename: 'summary.md',
|
||||
id: 'doc-p',
|
||||
loadPosition: 'before-first-user',
|
||||
loadRules: { rule: 'always' },
|
||||
policyLoad: 'progressive',
|
||||
title: 'Summary',
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
const context = createContext([{ content: 'Hello', id: 'user-1', role: 'user' }]);
|
||||
const result = await provider.process(context);
|
||||
|
||||
const injected = result.messages[0].content;
|
||||
expect(injected).toContain('Always-loaded full content');
|
||||
expect(injected).toContain('<agent_documents_index>');
|
||||
expect(injected).toContain('summary.md');
|
||||
expect(injected).not.toContain('Progressive content hidden');
|
||||
});
|
||||
});
|
||||
|
||||
describe('AgentDocumentBeforeSystemInjector (before-system)', () => {
|
||||
|
||||
@@ -80,7 +80,7 @@ describe('AgentDocumentModel', () => {
|
||||
expect(result.policy?.context?.position).toBe(DocumentLoadPosition.BEFORE_FIRST_USER);
|
||||
expect(result.policy?.context?.rule).toBe(DocumentLoadRule.ALWAYS);
|
||||
expect(result.policyLoadFormat).toBe(DocumentLoadFormat.RAW);
|
||||
expect(result.policyLoad).toBe(PolicyLoad.ALWAYS);
|
||||
expect(result.policyLoad).toBe(PolicyLoad.PROGRESSIVE);
|
||||
expect(result.accessShared).toBe(0);
|
||||
expect(result.accessPublic).toBe(0);
|
||||
});
|
||||
@@ -326,6 +326,20 @@ describe('AgentDocumentModel', () => {
|
||||
expect(context).not.toContain('--- manual.md ---');
|
||||
});
|
||||
|
||||
it('should preserve progressive policyLoad when updating load rule without mode', async () => {
|
||||
const doc = await agentDocumentModel.create(agentId, 'progressive.md', 'content');
|
||||
expect(doc.policyLoad).toBe(PolicyLoad.PROGRESSIVE);
|
||||
|
||||
const updated = await agentDocumentModel.updateToolLoadRule(doc.id, {
|
||||
rule: 'by-keywords',
|
||||
keywords: ['test'],
|
||||
});
|
||||
|
||||
expect(updated?.policyLoad).toBe(PolicyLoad.PROGRESSIVE);
|
||||
expect(updated?.policy?.context?.keywords).toEqual(['test']);
|
||||
expect(updated?.policyLoadRule).toBe(DocumentLoadRule.BY_KEYWORDS);
|
||||
});
|
||||
|
||||
it('should group docs by position and sort by priority ascending', async () => {
|
||||
await agentDocumentModel.create(
|
||||
agentId,
|
||||
|
||||
@@ -94,6 +94,8 @@ export class AgentDocumentModel {
|
||||
templateId?: string,
|
||||
metadata?: Record<string, any>,
|
||||
policy?: AgentDocumentPolicy,
|
||||
createdAt?: Date,
|
||||
updatedAt?: Date,
|
||||
): Promise<AgentDocument> {
|
||||
const title = filename.replace(/\.[^.]+$/, '');
|
||||
const stats = this.getDocumentStats(content);
|
||||
@@ -102,6 +104,7 @@ export class AgentDocumentModel {
|
||||
return this.db.transaction(async (trx) => {
|
||||
const documentPayload: NewDocument = {
|
||||
content,
|
||||
createdAt,
|
||||
description: metadata?.description,
|
||||
fileType: 'agent/document',
|
||||
filename,
|
||||
@@ -111,6 +114,7 @@ export class AgentDocumentModel {
|
||||
title,
|
||||
totalCharCount: stats.totalCharCount,
|
||||
totalLineCount: stats.totalLineCount,
|
||||
updatedAt: updatedAt ?? createdAt,
|
||||
userId: this.userId,
|
||||
};
|
||||
|
||||
@@ -126,7 +130,8 @@ export class AgentDocumentModel {
|
||||
AgentAccess.DELETE,
|
||||
accessShared: 0,
|
||||
agentId,
|
||||
policyLoad: PolicyLoad.ALWAYS,
|
||||
createdAt,
|
||||
policyLoad: PolicyLoad.PROGRESSIVE,
|
||||
deleteReason: null,
|
||||
deletedAt: null,
|
||||
deletedByAgentId: null,
|
||||
@@ -138,6 +143,7 @@ export class AgentDocumentModel {
|
||||
normalizedPolicy.context?.position || DocumentLoadPosition.BEFORE_FIRST_USER,
|
||||
policyLoadRule: normalizedPolicy.context?.rule || DocumentLoadRule.ALWAYS,
|
||||
templateId,
|
||||
updatedAt: updatedAt ?? createdAt,
|
||||
userId: this.userId,
|
||||
};
|
||||
|
||||
@@ -266,7 +272,7 @@ export class AgentDocumentModel {
|
||||
): Promise<AgentDocument | undefined> {
|
||||
const existing = await this.findById(documentId);
|
||||
if (!existing) return undefined;
|
||||
const composedPolicy = composeToolPolicyUpdate(existing.policy, rule);
|
||||
const composedPolicy = composeToolPolicyUpdate(existing.policy, rule, existing.policyLoad);
|
||||
|
||||
await this.db
|
||||
.update(agentDocuments)
|
||||
@@ -315,6 +321,8 @@ export class AgentDocumentModel {
|
||||
templateId?: string,
|
||||
metadata?: Record<string, any>,
|
||||
policy?: AgentDocumentPolicy,
|
||||
createdAt?: Date,
|
||||
updatedAt?: Date,
|
||||
): Promise<AgentDocument> {
|
||||
const existing = await this.findByFilename(agentId, filename);
|
||||
|
||||
@@ -339,6 +347,8 @@ export class AgentDocumentModel {
|
||||
templateId,
|
||||
metadata,
|
||||
policy,
|
||||
createdAt,
|
||||
updatedAt,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -110,4 +110,66 @@ describe('agentDocuments checks', () => {
|
||||
|
||||
expect(isLoadableDocument(noReadDoc)).toBe(false);
|
||||
});
|
||||
|
||||
it('treats progressive policyLoad as auto-loadable', () => {
|
||||
const progressiveDoc = {
|
||||
accessSelf:
|
||||
AgentAccess.EXECUTE |
|
||||
AgentAccess.LIST |
|
||||
AgentAccess.READ |
|
||||
AgentAccess.WRITE |
|
||||
AgentAccess.DELETE,
|
||||
policyLoad: PolicyLoad.PROGRESSIVE,
|
||||
};
|
||||
|
||||
expect(canAutoLoadDocument(progressiveDoc)).toBe(true);
|
||||
expect(isLoadableDocument(progressiveDoc)).toBe(true);
|
||||
});
|
||||
|
||||
it('composes tool policy update with progressive mode', () => {
|
||||
const composed = composeToolPolicyUpdate(null, {
|
||||
mode: 'progressive',
|
||||
rule: 'always',
|
||||
});
|
||||
|
||||
expect(composed.policyLoad).toBe(PolicyLoad.PROGRESSIVE);
|
||||
});
|
||||
|
||||
it('preserves existing policyLoad when rule.mode is omitted', () => {
|
||||
const composed = composeToolPolicyUpdate(
|
||||
{ context: { loadMode: undefined } },
|
||||
{ rule: 'by-keywords', keywords: ['test'] },
|
||||
PolicyLoad.PROGRESSIVE,
|
||||
);
|
||||
|
||||
expect(composed.policyLoad).toBe(PolicyLoad.PROGRESSIVE);
|
||||
expect(composed.policyLoadRule).toBe(DocumentLoadRule.BY_KEYWORDS);
|
||||
});
|
||||
|
||||
it('preserves existing progressive loadMode in policy context', () => {
|
||||
const composed = composeToolPolicyUpdate(
|
||||
{ context: { loadMode: 'progressive' } },
|
||||
{ rule: 'by-keywords', keywords: ['test'] },
|
||||
);
|
||||
|
||||
expect(composed.policyLoad).toBe(PolicyLoad.PROGRESSIVE);
|
||||
expect(composed.policy.context?.loadMode).toBe('progressive');
|
||||
});
|
||||
|
||||
it('overrides policyLoad when rule.mode is explicitly set', () => {
|
||||
const composed = composeToolPolicyUpdate(
|
||||
{ context: { loadMode: 'progressive' } },
|
||||
{ mode: 'always', rule: 'always' },
|
||||
PolicyLoad.PROGRESSIVE,
|
||||
);
|
||||
|
||||
expect(composed.policyLoad).toBe(PolicyLoad.ALWAYS);
|
||||
expect(composed.policy.context?.loadMode).toBe('always');
|
||||
});
|
||||
|
||||
it('defaults to ALWAYS when no mode, no context, no existingPolicyLoad', () => {
|
||||
const composed = composeToolPolicyUpdate(null, { rule: 'always' });
|
||||
|
||||
expect(composed.policyLoad).toBe(PolicyLoad.ALWAYS);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -22,7 +22,7 @@ export const canDeleteDocument = (doc: Pick<AgentDocument, 'accessSelf'>): boole
|
||||
};
|
||||
|
||||
export const canAutoLoadDocument = (doc: Pick<AgentDocument, 'policyLoad'>): boolean => {
|
||||
return doc.policyLoad === PolicyLoad.ALWAYS;
|
||||
return doc.policyLoad === PolicyLoad.ALWAYS || doc.policyLoad === PolicyLoad.PROGRESSIVE;
|
||||
};
|
||||
|
||||
export const isLoadableDocument = (
|
||||
|
||||
@@ -35,6 +35,7 @@ export interface ToolPolicyCompositionResult {
|
||||
export const composeToolPolicyUpdate = (
|
||||
existingPolicy: AgentDocumentPolicy | null,
|
||||
rule: ToolUpdateLoadRule,
|
||||
existingPolicyLoad?: PolicyLoad,
|
||||
): ToolPolicyCompositionResult => {
|
||||
const resolvePolicyLoadFormat = (format?: string): DocumentLoadFormat => {
|
||||
if (format === 'file') {
|
||||
@@ -45,8 +46,7 @@ export const composeToolPolicyUpdate = (
|
||||
|
||||
const currentPolicy = existingPolicy || {};
|
||||
const existingContext = currentPolicy.context || {};
|
||||
const loadMode =
|
||||
rule.mode ?? (existingContext.loadMode as ToolUpdateLoadRule['mode']) ?? 'always';
|
||||
const loadMode = rule.mode ?? (existingContext.loadMode as ToolUpdateLoadRule['mode']);
|
||||
const policyLoadFormat = resolvePolicyLoadFormat(
|
||||
rule.policyLoadFormat ??
|
||||
(existingContext.policyLoadFormat as DocumentLoadFormat | undefined) ??
|
||||
@@ -60,7 +60,7 @@ export const composeToolPolicyUpdate = (
|
||||
...currentPolicy,
|
||||
context: {
|
||||
...existingContext,
|
||||
loadMode,
|
||||
loadMode: loadMode ?? existingContext.loadMode,
|
||||
keywordMatchMode: rule.keywordMatchMode ?? existingContext.keywordMatchMode,
|
||||
keywords: rule.keywords ?? existingContext.keywords,
|
||||
policyLoadFormat,
|
||||
@@ -75,7 +75,13 @@ export const composeToolPolicyUpdate = (
|
||||
} satisfies AgentDocumentPolicy;
|
||||
|
||||
return {
|
||||
policyLoad: loadMode === 'always' ? PolicyLoad.ALWAYS : PolicyLoad.DISABLED,
|
||||
policyLoad: loadMode
|
||||
? loadMode === 'always'
|
||||
? PolicyLoad.ALWAYS
|
||||
: loadMode === 'progressive'
|
||||
? PolicyLoad.PROGRESSIVE
|
||||
: PolicyLoad.DISABLED
|
||||
: (existingPolicyLoad ?? PolicyLoad.ALWAYS),
|
||||
policy,
|
||||
policyLoadFormat,
|
||||
policyLoadRule: documentLoadRule,
|
||||
|
||||
@@ -58,7 +58,7 @@ export interface ToolUpdateLoadRule {
|
||||
keywords?: string[];
|
||||
maxDocuments?: number;
|
||||
maxTokens?: number;
|
||||
mode?: 'always' | 'manual' | 'on-demand';
|
||||
mode?: 'always' | 'manual' | 'on-demand' | 'progressive';
|
||||
pinnedDocumentIds?: string[];
|
||||
policyLoadFormat?: 'file' | 'raw';
|
||||
priority?: number;
|
||||
|
||||
@@ -1406,7 +1406,6 @@ describe('google contextBuilders', () => {
|
||||
expect(result.parameters?.properties).toEqual({
|
||||
query: { type: 'string' },
|
||||
timeIntent: {
|
||||
additionalProperties: false,
|
||||
properties: {
|
||||
selector: { enum: ['today', 'yesterday', 'month'], type: 'string' },
|
||||
date: { format: 'date-time', type: 'string' },
|
||||
@@ -1545,6 +1544,77 @@ describe('google contextBuilders', () => {
|
||||
field: { description: 'some field' },
|
||||
});
|
||||
});
|
||||
|
||||
it('should strip additionalProperties from schemas', () => {
|
||||
const tool: ChatCompletionTool = {
|
||||
function: {
|
||||
description: 'A tool with additionalProperties',
|
||||
name: 'apTool',
|
||||
parameters: {
|
||||
properties: {
|
||||
config: {
|
||||
additionalProperties: false,
|
||||
properties: {
|
||||
nested: {
|
||||
additionalProperties: { type: 'string' },
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
type: 'function',
|
||||
};
|
||||
|
||||
const result = buildGoogleTool(tool);
|
||||
|
||||
expect(result.parameters?.properties).toEqual({
|
||||
config: {
|
||||
properties: {
|
||||
nested: {
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
type: 'object',
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
it('should strip remaining $ref when resolveRefs exceeds depth limit', () => {
|
||||
// Build a deeply recursive schema that exceeds depth limit of 10
|
||||
const tool: ChatCompletionTool = {
|
||||
function: {
|
||||
description: 'A tool with deep recursive $ref',
|
||||
name: 'deepRefTool',
|
||||
parameters: {
|
||||
definitions: {
|
||||
node: {
|
||||
properties: {
|
||||
child: { oneOf: [{ type: 'string' }, { $ref: '#/definitions/node' }] },
|
||||
},
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
properties: {
|
||||
root: { $ref: '#/definitions/node' },
|
||||
},
|
||||
type: 'object',
|
||||
},
|
||||
},
|
||||
type: 'function',
|
||||
};
|
||||
|
||||
const result = buildGoogleTool(tool);
|
||||
|
||||
// Verify no $ref remains anywhere in the output
|
||||
const json = JSON.stringify(result);
|
||||
expect(json).not.toContain('"$ref"');
|
||||
// Also verify no additionalProperties
|
||||
expect(json).not.toContain('"additionalProperties"');
|
||||
});
|
||||
});
|
||||
|
||||
describe('buildGoogleTools', () => {
|
||||
|
||||
@@ -252,7 +252,7 @@ export const buildGoogleMessages = async (messages: OpenAIChatMessage[]): Promis
|
||||
* JSON Schema keywords that cause Google GenAI / Vertex AI SDK validation errors.
|
||||
* Other unsupported keywords are silently ignored by the API, so only strip these.
|
||||
*/
|
||||
const UNSUPPORTED_SCHEMA_KEYS = new Set(['examples', 'default']);
|
||||
const UNSUPPORTED_SCHEMA_KEYS = new Set(['examples', 'default', 'additionalProperties', '$ref']);
|
||||
|
||||
/**
|
||||
* Resolve all `$ref` pointers in a JSON Schema tree by inlining definitions.
|
||||
|
||||
@@ -25,6 +25,7 @@ const TOKEN_EXCHANGE_URL = 'https://api.github.com/copilot_internal/v2/token';
|
||||
|
||||
const MAX_TOTAL_ATTEMPTS = 5;
|
||||
const MAX_RATE_LIMIT_RETRIES = 3;
|
||||
const QUOTA_EXHAUSTION_THRESHOLD_MS = 5 * 60 * 1000; // 5 minutes
|
||||
|
||||
const debugParams = {
|
||||
chatCompletion: () => process.env.DEBUG_GITHUBCOPILOT_CHAT_COMPLETION === '1',
|
||||
@@ -457,6 +458,11 @@ export class LobeGithubCopilotAI implements LobeRuntimeAI {
|
||||
rateLimitAttempts++;
|
||||
const retryAfter = this.getRetryAfterMs(error) ?? 1000 * Math.pow(2, rateLimitAttempts);
|
||||
|
||||
// If retry-after exceeds the quota exhaustion threshold, surface immediately
|
||||
if (retryAfter > QUOTA_EXHAUSTION_THRESHOLD_MS) {
|
||||
throw this.mapError(error);
|
||||
}
|
||||
|
||||
await new Promise<void>((resolve) => {
|
||||
setTimeout(resolve, Math.min(retryAfter, 10_000));
|
||||
});
|
||||
|
||||
@@ -0,0 +1,209 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
// Now import the real service — only the stubs above are faked
|
||||
import { ResponsesService } from '../responses.service';
|
||||
|
||||
// Stub external dependencies so ResponsesService can be imported in isolation
|
||||
vi.mock('@/server/modules/AgentRuntime/InMemoryStreamEventManager', () => ({
|
||||
InMemoryStreamEventManager: class {},
|
||||
}));
|
||||
vi.mock('@/server/modules/AgentRuntime/StreamEventManager', () => ({}));
|
||||
vi.mock('@/server/services/agentRuntime', () => ({ AgentRuntimeService: class {} }));
|
||||
vi.mock('@/server/services/aiAgent', () => ({ AiAgentService: class {} }));
|
||||
vi.mock('../../common/base.service', () => ({
|
||||
BaseService: class {
|
||||
db: any;
|
||||
userId = '';
|
||||
constructor() {}
|
||||
log() {}
|
||||
},
|
||||
}));
|
||||
|
||||
// Helper: call the private extractOutputItems via bracket notation
|
||||
const callExtractOutputItems = (messages: any[], responseId: string) => {
|
||||
const svc = new (ResponsesService as any)(null, null);
|
||||
return svc['extractOutputItems']({ messages }, responseId);
|
||||
};
|
||||
|
||||
describe('ResponsesService.extractOutputItems', () => {
|
||||
describe('assistant message with tool_calls should still emit message item', () => {
|
||||
it('should include both message and function_call when assistant has text + tool_calls', () => {
|
||||
const messages = [
|
||||
{
|
||||
content: '好的,我来在沙箱中随机生成一个散点图!',
|
||||
role: 'assistant',
|
||||
tool_calls: [
|
||||
{
|
||||
function: {
|
||||
arguments: '{"code":"import matplotlib.pyplot as plt\\nprint(1)"}',
|
||||
name: 'lobe-cloud-sandbox____executeCode____builtin',
|
||||
},
|
||||
id: 'call_abc123',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const { output } = callExtractOutputItems(messages, 'tpc_test');
|
||||
|
||||
expect(output).toHaveLength(2);
|
||||
|
||||
expect(output[0]).toMatchObject({
|
||||
content: [
|
||||
{
|
||||
text: '好的,我来在沙箱中随机生成一个散点图!',
|
||||
type: 'output_text',
|
||||
},
|
||||
],
|
||||
role: 'assistant',
|
||||
status: 'completed',
|
||||
type: 'message',
|
||||
});
|
||||
|
||||
expect(output[1]).toMatchObject({
|
||||
status: 'completed',
|
||||
type: 'function_call',
|
||||
});
|
||||
});
|
||||
|
||||
it('should still work for assistant messages without tool_calls', () => {
|
||||
const messages = [{ content: 'Hello, how can I help?', role: 'assistant' }];
|
||||
|
||||
const { output, outputText } = callExtractOutputItems(messages, 'tpc_test');
|
||||
|
||||
expect(output).toHaveLength(1);
|
||||
expect(output[0].type).toBe('message');
|
||||
expect(outputText).toBe('Hello, how can I help?');
|
||||
});
|
||||
|
||||
it('should not emit message for assistant with empty content + tool_calls', () => {
|
||||
const messages = [
|
||||
{
|
||||
content: '',
|
||||
role: 'assistant',
|
||||
tool_calls: [{ function: { arguments: '{}', name: 'my-plugin____myApi' }, id: 'call_1' }],
|
||||
},
|
||||
];
|
||||
|
||||
const { output } = callExtractOutputItems(messages, 'tpc_test');
|
||||
|
||||
expect(output).toHaveLength(1);
|
||||
expect(output[0].type).toBe('function_call');
|
||||
});
|
||||
});
|
||||
|
||||
describe('function_call name should be decoded from internal ____-separated format', () => {
|
||||
it('should decode builtin tool names: identifier____apiName____builtin → identifier/apiName', () => {
|
||||
const messages = [
|
||||
{
|
||||
content: '',
|
||||
role: 'assistant',
|
||||
tool_calls: [
|
||||
{
|
||||
function: {
|
||||
arguments: '{"code":"print(1)"}',
|
||||
name: 'lobe-cloud-sandbox____executeCode____builtin',
|
||||
},
|
||||
id: 'call_abc123',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const { output } = callExtractOutputItems(messages, 'tpc_test');
|
||||
|
||||
const fc = output.find((item: any) => item.type === 'function_call');
|
||||
expect(fc.name).toBe('lobe-cloud-sandbox/executeCode');
|
||||
});
|
||||
|
||||
it('should strip lobe-client-fn prefix correctly', () => {
|
||||
const messages = [
|
||||
{
|
||||
content: '',
|
||||
role: 'assistant',
|
||||
tool_calls: [
|
||||
{
|
||||
function: { arguments: '{}', name: 'lobe-client-fn____get_weather' },
|
||||
id: 'call_xyz',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const { output } = callExtractOutputItems(messages, 'tpc_test');
|
||||
|
||||
const fc = output.find((item: any) => item.type === 'function_call');
|
||||
expect(fc.name).toBe('get_weather');
|
||||
});
|
||||
|
||||
it('should decode default type tools: identifier____apiName → identifier/apiName', () => {
|
||||
const messages = [
|
||||
{
|
||||
content: '',
|
||||
role: 'assistant',
|
||||
tool_calls: [
|
||||
{ function: { arguments: '{}', name: 'my-plugin____myApi' }, id: 'call_def' },
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const { output } = callExtractOutputItems(messages, 'tpc_test');
|
||||
|
||||
const fc = output.find((item: any) => item.type === 'function_call');
|
||||
expect(fc.name).toBe('my-plugin/myApi');
|
||||
});
|
||||
|
||||
it('should return raw name when no separator is present', () => {
|
||||
const messages = [
|
||||
{
|
||||
content: '',
|
||||
role: 'assistant',
|
||||
tool_calls: [{ function: { arguments: '{}', name: 'simple_tool' }, id: 'call_simple' }],
|
||||
},
|
||||
];
|
||||
|
||||
const { output } = callExtractOutputItems(messages, 'tpc_test');
|
||||
|
||||
const fc = output.find((item: any) => item.type === 'function_call');
|
||||
expect(fc.name).toBe('simple_tool');
|
||||
});
|
||||
});
|
||||
|
||||
describe('function_call id should match streaming output_index', () => {
|
||||
it('should assign index 1 to function_call when message (index 0) precedes it', () => {
|
||||
const messages = [
|
||||
{
|
||||
content: '好的,我来执行代码!',
|
||||
role: 'assistant',
|
||||
tool_calls: [
|
||||
{
|
||||
function: {
|
||||
arguments: '{"code":"1+1"}',
|
||||
name: 'lobe-cloud-sandbox____executeCode____builtin',
|
||||
},
|
||||
id: 'call_abc',
|
||||
},
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
const { output } = callExtractOutputItems(messages, 'tpc_test');
|
||||
|
||||
expect(output[0].id).toBe('msg_tpc_test_0');
|
||||
expect(output[1].id).toBe('fc_tpc_test_1');
|
||||
});
|
||||
|
||||
it('should assign index 0 to function_call when no message content', () => {
|
||||
const messages = [
|
||||
{
|
||||
content: '',
|
||||
role: 'assistant',
|
||||
tool_calls: [{ function: { arguments: '{}', name: 'plugin____api' }, id: 'call_1' }],
|
||||
},
|
||||
];
|
||||
|
||||
const { output } = callExtractOutputItems(messages, 'tpc_test');
|
||||
expect(output[0].id).toBe('fc_tpc_test_0');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -180,14 +180,26 @@ export class ResponsesService extends BaseService {
|
||||
if (msg.role === 'assistant') {
|
||||
const hasToolCalls = msg.tool_calls && msg.tool_calls.length > 0;
|
||||
|
||||
// Emit message item for assistant text content (even when tool_calls are present)
|
||||
const content = typeof msg.content === 'string' ? msg.content : '';
|
||||
if (content) {
|
||||
outputText = content;
|
||||
output.push({
|
||||
content: [
|
||||
{ annotations: [], logprobs: [], text: content, type: 'output_text' as const },
|
||||
],
|
||||
id: `msg_${responseId}_${itemCounter++}`,
|
||||
role: 'assistant' as const,
|
||||
status: 'completed' as const,
|
||||
type: 'message' as const,
|
||||
});
|
||||
}
|
||||
|
||||
// Handle tool_calls from assistant
|
||||
if (hasToolCalls) {
|
||||
for (const toolCall of msg.tool_calls) {
|
||||
// Convert internal tool names: lobe-client-fn____get_weather → get_weather
|
||||
let fnName = toolCall.function?.name ?? '';
|
||||
if (fnName.startsWith('lobe-client-fn____')) {
|
||||
fnName = fnName.slice('lobe-client-fn____'.length);
|
||||
}
|
||||
// Decode internal tool name format back to display name
|
||||
const fnName = this.decodeToolName(toolCall.function?.name ?? '');
|
||||
output.push({
|
||||
arguments: toolCall.function?.arguments ?? '{}',
|
||||
call_id: toolCall.id ?? `call_${itemCounter}`,
|
||||
@@ -198,23 +210,6 @@ export class ResponsesService extends BaseService {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Only emit message item for assistant messages WITHOUT tool_calls (i.e., final text response)
|
||||
if (!hasToolCalls) {
|
||||
const content = typeof msg.content === 'string' ? msg.content : '';
|
||||
if (content) {
|
||||
outputText = content;
|
||||
output.push({
|
||||
content: [
|
||||
{ annotations: [], logprobs: [], text: content, type: 'output_text' as const },
|
||||
],
|
||||
id: `msg_${responseId}_${itemCounter++}`,
|
||||
role: 'assistant' as const,
|
||||
status: 'completed' as const,
|
||||
type: 'message' as const,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (msg.role === 'tool') {
|
||||
output.push({
|
||||
call_id: msg.tool_call_id ?? '',
|
||||
@@ -229,6 +224,25 @@ export class ResponsesService extends BaseService {
|
||||
return { output, outputText };
|
||||
}
|
||||
|
||||
/**
|
||||
* Decode internal tool name format to display name.
|
||||
* - lobe-client-fn____get_weather → get_weather
|
||||
* - lobe-cloud-sandbox____executeCode____builtin → lobe-cloud-sandbox/executeCode
|
||||
* - my-plugin____myApi → my-plugin/myApi
|
||||
*/
|
||||
private decodeToolName(rawName: string): string {
|
||||
const SEPARATOR = '____';
|
||||
if (rawName.startsWith(`lobe-client-fn${SEPARATOR}`)) {
|
||||
return rawName.slice(`lobe-client-fn${SEPARATOR}`.length);
|
||||
}
|
||||
const parts = rawName.split(SEPARATOR);
|
||||
if (parts.length >= 2) {
|
||||
// parts[0] = identifier, parts[1] = apiName, parts[2+] = type (ignored for display)
|
||||
return `${parts[0]}/${parts[1]}`;
|
||||
}
|
||||
return rawName;
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract usage from AgentState
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
import { resolve } from 'node:path';
|
||||
|
||||
import { defineConfig } from 'vitest/config';
|
||||
|
||||
export default defineConfig({
|
||||
resolve: {
|
||||
alias: {
|
||||
'@/': resolve(__dirname, '../../src') + '/',
|
||||
},
|
||||
},
|
||||
test: {
|
||||
environment: 'node',
|
||||
},
|
||||
});
|
||||
@@ -82,7 +82,7 @@ export interface TaskDetailData {
|
||||
dependencies?: Array<{ dependsOn: string; type: string }>;
|
||||
description?: string | null;
|
||||
error?: string | null;
|
||||
// heartbeat.interval: 周期执行间隔 | heartbeat.timeout+lastAt: watchdog 监控(检测卡死)
|
||||
// heartbeat.interval: periodic execution interval | heartbeat.timeout+lastAt: watchdog monitoring (detects stuck tasks)
|
||||
heartbeat?: {
|
||||
interval?: number | null;
|
||||
lastAt?: string | null;
|
||||
|
||||
@@ -1,370 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { SECRET_XOR_KEY } from '@/envs/auth';
|
||||
|
||||
import { obfuscatePayloadWithXOR } from './xor-obfuscation';
|
||||
|
||||
describe('xor-obfuscation', () => {
|
||||
describe('obfuscatePayloadWithXOR', () => {
|
||||
it('应该对简单字符串进行混淆并返回Base64字符串', () => {
|
||||
const payload = 'hello world';
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
|
||||
// 验证结果长度大于0
|
||||
expect(result.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('应该对JSON对象进行混淆', () => {
|
||||
const payload = { name: 'test', value: 123, active: true };
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该对数组进行混淆', () => {
|
||||
const payload = [1, 2, 3, 'test', { nested: true }];
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该对复杂嵌套对象进行混淆', () => {
|
||||
const payload = {
|
||||
user: {
|
||||
id: 123,
|
||||
profile: {
|
||||
name: 'John Doe',
|
||||
settings: {
|
||||
theme: 'dark',
|
||||
notifications: true,
|
||||
preferences: ['email', 'sms'],
|
||||
},
|
||||
},
|
||||
},
|
||||
tokens: ['abc123', 'def456'],
|
||||
metadata: null,
|
||||
};
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('相同的输入应该产生相同的输出', () => {
|
||||
const payload = { test: 'consistent' };
|
||||
const result1 = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
const result2 = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
expect(result1).toBe(result2);
|
||||
});
|
||||
|
||||
it('不同的输入应该产生不同的输出', () => {
|
||||
const payload1 = { test: 'value1' };
|
||||
const payload2 = { test: 'value2' };
|
||||
|
||||
const result1 = obfuscatePayloadWithXOR(payload1, SECRET_XOR_KEY);
|
||||
const result2 = obfuscatePayloadWithXOR(payload2, SECRET_XOR_KEY);
|
||||
|
||||
expect(result1).not.toBe(result2);
|
||||
});
|
||||
|
||||
it('应该处理包含特殊字符的字符串', () => {
|
||||
const payload = 'Hello! @#$%^&*()_+-=[]{}|;:,.<>?/~`"\'\\';
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理包含Unicode字符的字符串', () => {
|
||||
const payload = '你好世界 🌍 émojis 日本語 한국어';
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理空字符串', () => {
|
||||
const payload = '';
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理空对象', () => {
|
||||
const payload = {};
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理空数组', () => {
|
||||
const result = obfuscatePayloadWithXOR([], SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理null值', () => {
|
||||
const payload = null;
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理数字', () => {
|
||||
const payload = 42;
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理布尔值', () => {
|
||||
const payloadTrue = true;
|
||||
const payloadFalse = false;
|
||||
|
||||
const resultTrue = obfuscatePayloadWithXOR(payloadTrue, SECRET_XOR_KEY);
|
||||
const resultFalse = obfuscatePayloadWithXOR(payloadFalse, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof resultTrue).toBe('string');
|
||||
expect(typeof resultFalse).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(resultTrue)).not.toThrow();
|
||||
expect(() => atob(resultFalse)).not.toThrow();
|
||||
|
||||
// 验证不同布尔值产生不同结果
|
||||
expect(resultTrue).not.toBe(resultFalse);
|
||||
});
|
||||
|
||||
it('应该处理包含特殊JSON字符的对象', () => {
|
||||
const payload = {
|
||||
quotes: '"double quotes"',
|
||||
singleQuotes: "'single quotes'",
|
||||
backslash: 'back\\slash',
|
||||
newline: 'line1\nline2',
|
||||
tab: 'col1\tcol2',
|
||||
unicode: '\u0041\u0042\u0043',
|
||||
};
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理很长的字符串', () => {
|
||||
const payload = 'a'.repeat(10000);
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
|
||||
// 验证结果长度合理(Base64编码后长度应该大约是原始长度的4/3)
|
||||
expect(result.length).toBeGreaterThan(0);
|
||||
});
|
||||
|
||||
it('应该产生不同长度输入的不同输出长度', () => {
|
||||
const shortPayload = 'short';
|
||||
const longPayload = 'this is a much longer string that should produce different output';
|
||||
|
||||
const shortResult = obfuscatePayloadWithXOR(shortPayload, SECRET_XOR_KEY);
|
||||
const longResult = obfuscatePayloadWithXOR(longPayload, SECRET_XOR_KEY);
|
||||
|
||||
// 较长的输入应该产生较长的输出
|
||||
expect(longResult.length).toBeGreaterThan(shortResult.length);
|
||||
});
|
||||
|
||||
it('应该验证输出是有效的Base64格式', () => {
|
||||
const payload = { test: 'base64 validation' };
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证Base64格式的正则表达式
|
||||
const base64Regex = /^[\d+/a-z]*={0,2}$/i;
|
||||
expect(base64Regex.test(result)).toBe(true);
|
||||
});
|
||||
|
||||
it('应该处理包含循环引用的对象(通过JSON.stringify处理)', () => {
|
||||
// JSON.stringify 会抛出错误处理循环引用,但我们测试正常情况
|
||||
const payload = {
|
||||
id: 1,
|
||||
name: 'test',
|
||||
nested: {
|
||||
back: 'reference',
|
||||
},
|
||||
};
|
||||
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
expect(typeof result).toBe('string');
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该对undefined值进行处理', () => {
|
||||
const payload = undefined;
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
|
||||
// 验证返回值是字符串
|
||||
expect(typeof result).toBe('string');
|
||||
|
||||
// 验证返回值是有效的Base64字符串
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该对包含函数的对象进行处理(函数会被JSON.stringify忽略)', () => {
|
||||
const payload = {
|
||||
name: 'test',
|
||||
fn: function () {
|
||||
return 'test';
|
||||
},
|
||||
arrow: () => 'arrow',
|
||||
value: 123,
|
||||
};
|
||||
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
expect(typeof result).toBe('string');
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该确保XOR操作的确定性', () => {
|
||||
const payload = 'deterministic test';
|
||||
const results: any[] = [];
|
||||
|
||||
// 多次运行相同输入
|
||||
for (let i = 0; i < 10; i++) {
|
||||
results.push(obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY));
|
||||
}
|
||||
|
||||
// 所有结果应该相同
|
||||
expect(results.every((result) => result === results[0])).toBe(true);
|
||||
});
|
||||
|
||||
it('应该处理包含日期对象的数据', () => {
|
||||
const payload = {
|
||||
timestamp: new Date('2024-01-01T00:00:00Z'),
|
||||
created: new Date(),
|
||||
name: 'date test',
|
||||
};
|
||||
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
expect(typeof result).toBe('string');
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该处理包含Symbol的对象(Symbol会被JSON.stringify忽略)', () => {
|
||||
const sym = Symbol('test');
|
||||
const payload = {
|
||||
name: 'symbol test',
|
||||
[sym]: 'symbol value',
|
||||
normalKey: 'normal value',
|
||||
};
|
||||
|
||||
const result = obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
expect(typeof result).toBe('string');
|
||||
expect(() => atob(result)).not.toThrow();
|
||||
});
|
||||
|
||||
it('应该验证混淆后的数据长度合理性', () => {
|
||||
const originalPayload = { test: 'length check' };
|
||||
const originalJSON = JSON.stringify(originalPayload);
|
||||
const result = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
|
||||
// Base64 编码后的长度通常是原始长度的 4/3 倍(向上取整到4的倍数)
|
||||
const expectedMinLength = Math.ceil((originalJSON.length * 4) / 3 / 4) * 4;
|
||||
expect(result.length).toBeGreaterThanOrEqual(expectedMinLength - 4); // 允许一些误差
|
||||
});
|
||||
|
||||
it('应该验证XOR操作的正确性(通过逆向操作)', () => {
|
||||
const originalPayload = { message: 'XOR test', value: 42 };
|
||||
const obfuscatedResult = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
|
||||
// 手动实现逆向操作来验证 XOR 操作的正确性
|
||||
const base64Decoded = atob(obfuscatedResult);
|
||||
const xoredBytes = new Uint8Array(base64Decoded.length);
|
||||
for (let i = 0; i < base64Decoded.length; i++) {
|
||||
xoredBytes[i] = base64Decoded.charCodeAt(i);
|
||||
}
|
||||
|
||||
// 使用相同的密钥进行逆向 XOR 操作
|
||||
const keyBytes = new TextEncoder().encode(SECRET_XOR_KEY);
|
||||
const decodedBytes = new Uint8Array(xoredBytes.length);
|
||||
for (let i = 0; i < xoredBytes.length; i++) {
|
||||
decodedBytes[i] = xoredBytes[i] ^ keyBytes[i % keyBytes.length];
|
||||
}
|
||||
|
||||
// 将结果转换回字符串
|
||||
const decodedString = new TextDecoder().decode(decodedBytes);
|
||||
const decodedPayload = JSON.parse(decodedString);
|
||||
|
||||
// 验证解码后的数据与原始数据相同
|
||||
expect(decodedPayload).toEqual(originalPayload);
|
||||
});
|
||||
|
||||
it('应该验证不同输入产生不同的Base64输出', () => {
|
||||
const payloads = [
|
||||
'test1',
|
||||
'test2',
|
||||
{ key: 'value1' },
|
||||
{ key: 'value2' },
|
||||
[1, 2, 3],
|
||||
[4, 5, 6],
|
||||
];
|
||||
|
||||
const results = payloads.map((payload) => obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY));
|
||||
|
||||
// 验证所有结果都不相同
|
||||
for (let i = 0; i < results.length; i++) {
|
||||
for (let j = i + 1; j < results.length; j++) {
|
||||
expect(results[i]).not.toBe(results[j]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,38 +0,0 @@
|
||||
/**
|
||||
* Convert string to Uint8Array (UTF-8 encoding)
|
||||
*/
|
||||
const stringToUint8Array = (str: string): Uint8Array => {
|
||||
return new TextEncoder().encode(str);
|
||||
};
|
||||
|
||||
/**
|
||||
* Perform XOR operation on Uint8Array
|
||||
* @param data The Uint8Array to process
|
||||
* @param key The key used for XOR operation (Uint8Array)
|
||||
* @returns The Uint8Array after XOR operation
|
||||
*/
|
||||
const xorProcess = (data: Uint8Array, key: Uint8Array): Uint8Array => {
|
||||
const result = new Uint8Array(data.length);
|
||||
for (const [i, datum] of data.entries()) {
|
||||
result[i] = datum ^ key[i % key.length]; // Key is used cyclically
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
/**
|
||||
* Obfuscate payload with XOR and encode to Base64
|
||||
* @param payload The JSON object to obfuscate
|
||||
* @param secretKey The key used for XOR obfuscation
|
||||
* @returns The obfuscated string encoded in Base64
|
||||
*/
|
||||
export const obfuscatePayloadWithXOR = <T>(payload: T, secretKey: string): string => {
|
||||
const jsonString = JSON.stringify(payload);
|
||||
const dataBytes = stringToUint8Array(jsonString);
|
||||
const keyBytes = stringToUint8Array(secretKey);
|
||||
|
||||
const xoredBytes = xorProcess(dataBytes, keyBytes);
|
||||
|
||||
// Convert Uint8Array to Base64 string
|
||||
// In browser environment, btoa can only handle Latin-1 characters, so we need to convert to a format suitable for btoa first
|
||||
return btoa(String.fromCharCode(...xoredBytes));
|
||||
};
|
||||
@@ -3,4 +3,3 @@ export * from './auth';
|
||||
export * from './response';
|
||||
export * from './responsive';
|
||||
export * from './sse';
|
||||
export * from './xor';
|
||||
|
||||
@@ -1,123 +0,0 @@
|
||||
import { describe, expect, it } from 'vitest';
|
||||
|
||||
import { SECRET_XOR_KEY } from '@/envs/auth';
|
||||
|
||||
import { obfuscatePayloadWithXOR } from '../client/xor-obfuscation';
|
||||
import { getXorPayload } from './xor';
|
||||
|
||||
describe('getXorPayload', () => {
|
||||
it('should correctly decode XOR obfuscated payload with user data', () => {
|
||||
const originalPayload = {
|
||||
userId: '001362c3-48c5-4635-bd3b-837bfff58fc0',
|
||||
apiKey: 'test-api-key',
|
||||
baseURL: 'https://api.example.com',
|
||||
};
|
||||
|
||||
// 使用客户端的混淆函数生成token
|
||||
const obfuscatedToken = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
|
||||
// 使用服务端的解码函数解码
|
||||
const decodedPayload = getXorPayload(obfuscatedToken);
|
||||
|
||||
expect(decodedPayload).toEqual(originalPayload);
|
||||
});
|
||||
|
||||
it('should correctly decode XOR obfuscated payload with minimal data', () => {
|
||||
const originalPayload = {
|
||||
userId: '12345',
|
||||
};
|
||||
|
||||
const obfuscatedToken = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
const decodedPayload = getXorPayload(obfuscatedToken);
|
||||
|
||||
expect(decodedPayload).toEqual(originalPayload);
|
||||
});
|
||||
|
||||
it('should correctly decode XOR obfuscated payload with AWS credentials', () => {
|
||||
const originalPayload = {
|
||||
userId: 'aws-user-123',
|
||||
awsAccessKeyId: 'AKIAIOSFODNN7EXAMPLE',
|
||||
awsSecretAccessKey: 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY',
|
||||
awsRegion: 'us-east-1',
|
||||
awsSessionToken: 'session-token-example',
|
||||
};
|
||||
|
||||
const obfuscatedToken = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
const decodedPayload = getXorPayload(obfuscatedToken);
|
||||
|
||||
expect(decodedPayload).toEqual(originalPayload);
|
||||
});
|
||||
|
||||
it('should correctly decode XOR obfuscated payload with Azure data', () => {
|
||||
const originalPayload = {
|
||||
userId: 'azure-user-456',
|
||||
apiKey: 'azure-api-key',
|
||||
baseURL: 'https://your-resource.openai.azure.com',
|
||||
azureApiVersion: '2024-02-15-preview',
|
||||
};
|
||||
|
||||
const obfuscatedToken = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
const decodedPayload = getXorPayload(obfuscatedToken);
|
||||
|
||||
expect(decodedPayload).toEqual(originalPayload);
|
||||
});
|
||||
|
||||
it('should correctly decode XOR obfuscated payload with Cloudflare data', () => {
|
||||
const originalPayload = {
|
||||
userId: 'cf-user-789',
|
||||
apiKey: 'cloudflare-api-key',
|
||||
cloudflareBaseURLOrAccountID: 'account-id-example',
|
||||
};
|
||||
|
||||
const obfuscatedToken = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
const decodedPayload = getXorPayload(obfuscatedToken);
|
||||
|
||||
expect(decodedPayload).toEqual(originalPayload);
|
||||
});
|
||||
|
||||
it('should handle empty payload correctly', () => {
|
||||
const originalPayload = {};
|
||||
|
||||
const obfuscatedToken = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
const decodedPayload = getXorPayload(obfuscatedToken);
|
||||
|
||||
expect(decodedPayload).toEqual(originalPayload);
|
||||
});
|
||||
|
||||
it('should handle payload with undefined values', () => {
|
||||
const originalPayload = {
|
||||
userId: 'test-user',
|
||||
baseURL: undefined,
|
||||
apiKey: 'test-key',
|
||||
};
|
||||
|
||||
const obfuscatedToken = obfuscatePayloadWithXOR(originalPayload, SECRET_XOR_KEY);
|
||||
const decodedPayload = getXorPayload(obfuscatedToken);
|
||||
|
||||
expect(decodedPayload).toEqual(originalPayload);
|
||||
});
|
||||
|
||||
it('should throw error for invalid base64 token', () => {
|
||||
const invalidToken = 'invalid-base64-token!@#';
|
||||
|
||||
expect(() => getXorPayload(invalidToken)).toThrow(SyntaxError);
|
||||
});
|
||||
|
||||
it('should throw error for token that cannot be parsed as JSON', () => {
|
||||
// 创建一个能正确base64解码但不是有效JSON的token
|
||||
const invalidJsonString = 'this is not json';
|
||||
const invalidJsonBytes = new TextEncoder().encode(invalidJsonString);
|
||||
const keyBytes = new TextEncoder().encode('LobeHub · LobeHub');
|
||||
|
||||
// 进行XOR处理
|
||||
const result = new Uint8Array(invalidJsonBytes.length);
|
||||
for (const [i, datum] of invalidJsonBytes.entries()) {
|
||||
result[i] = datum ^ keyBytes[i % keyBytes.length];
|
||||
}
|
||||
|
||||
// 转换为base64
|
||||
const invalidToken = Buffer.from(result).toString('base64');
|
||||
|
||||
expect(() => getXorPayload(invalidToken)).toThrow(SyntaxError);
|
||||
});
|
||||
});
|
||||
@@ -1,44 +0,0 @@
|
||||
import type { ClientSecretPayload } from '@lobechat/types';
|
||||
|
||||
import { SECRET_XOR_KEY } from '@/envs/auth';
|
||||
|
||||
/**
|
||||
* Convert Base64 string to Uint8Array
|
||||
*/
|
||||
const base64ToUint8Array = (base64: string): Uint8Array => {
|
||||
// Use Buffer directly in Node.js environment
|
||||
return Buffer.from(base64, 'base64');
|
||||
};
|
||||
|
||||
/**
|
||||
* Perform XOR operation on Uint8Array (same as the client-side xorProcess function)
|
||||
*/
|
||||
const xorProcess = (data: Uint8Array, key: Uint8Array): Uint8Array => {
|
||||
const result = new Uint8Array(data.length);
|
||||
for (const [i, datum] of data.entries()) {
|
||||
result[i] = datum ^ key[i % key.length];
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert Uint8Array to string (UTF-8 decoding)
|
||||
*/
|
||||
const uint8ArrayToString = (arr: Uint8Array): string => {
|
||||
return new TextDecoder().decode(arr);
|
||||
};
|
||||
|
||||
export const getXorPayload = (token: string): ClientSecretPayload => {
|
||||
const keyBytes = new TextEncoder().encode(SECRET_XOR_KEY);
|
||||
|
||||
// 1. Base64 decoding
|
||||
const base64DecodedBytes = base64ToUint8Array(token);
|
||||
|
||||
// 2. XOR deobfuscation
|
||||
const xorDecryptedBytes = xorProcess(base64DecodedBytes, keyBytes);
|
||||
|
||||
// 3. Convert to string and parse JSON
|
||||
const decodedJsonString = uint8ArrayToString(xorDecryptedBytes);
|
||||
|
||||
return JSON.parse(decodedJsonString) as ClientSecretPayload;
|
||||
};
|
||||
@@ -1,9 +1,7 @@
|
||||
import { AgentRuntimeError } from '@lobechat/model-runtime';
|
||||
import { ChatErrorType } from '@lobechat/types';
|
||||
import { getXorPayload } from '@lobechat/utils/server';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type * as EnvsAuthModule from '@/envs/auth';
|
||||
import { createErrorResponse } from '@/utils/errorResponse';
|
||||
|
||||
import { type RequestHandler } from './index';
|
||||
@@ -18,17 +16,6 @@ vi.mock('./utils', () => ({
|
||||
checkAuthMethod: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@lobechat/utils/server', () => ({
|
||||
getXorPayload: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@/envs/auth', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof EnvsAuthModule>();
|
||||
return {
|
||||
...actual,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('@/auth', () => ({
|
||||
auth: {
|
||||
api: {
|
||||
@@ -50,34 +37,8 @@ describe('checkAuth', () => {
|
||||
vi.resetAllMocks();
|
||||
});
|
||||
|
||||
it('should return unauthorized error if no authorization header', async () => {
|
||||
await checkAuth(mockHandler)(mockRequest, mockOptions);
|
||||
|
||||
expect(createErrorResponse).toHaveBeenCalledWith(ChatErrorType.Unauthorized, {
|
||||
error: AgentRuntimeError.createError(ChatErrorType.Unauthorized),
|
||||
provider: 'mock',
|
||||
});
|
||||
expect(mockHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return error response on getJWTPayload error', async () => {
|
||||
it('should return error response on checkAuthMethod error (no session)', async () => {
|
||||
const mockError = AgentRuntimeError.createError(ChatErrorType.Unauthorized);
|
||||
mockRequest.headers.set('Authorization', 'invalid');
|
||||
vi.mocked(getXorPayload).mockRejectedValueOnce(mockError);
|
||||
|
||||
await checkAuth(mockHandler)(mockRequest, mockOptions);
|
||||
|
||||
expect(createErrorResponse).toHaveBeenCalledWith(ChatErrorType.Unauthorized, {
|
||||
error: mockError,
|
||||
provider: 'mock',
|
||||
});
|
||||
expect(mockHandler).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it('should return error response on checkAuthMethod error', async () => {
|
||||
const mockError = AgentRuntimeError.createError(ChatErrorType.Unauthorized);
|
||||
mockRequest.headers.set('Authorization', 'valid');
|
||||
vi.mocked(getXorPayload).mockResolvedValueOnce({});
|
||||
vi.mocked(checkAuthMethod).mockImplementationOnce(() => {
|
||||
throw mockError;
|
||||
});
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
import { type ChatCompletionErrorPayload, type ModelRuntime } from '@lobechat/model-runtime';
|
||||
import { type ChatCompletionErrorPayload } from '@lobechat/model-runtime';
|
||||
import { AgentRuntimeError } from '@lobechat/model-runtime';
|
||||
import { context as otContext } from '@lobechat/observability-otel/api';
|
||||
import { type ClientSecretPayload } from '@lobechat/types';
|
||||
import { ChatErrorType } from '@lobechat/types';
|
||||
import { getXorPayload } from '@lobechat/utils/server';
|
||||
|
||||
import { auth } from '@/auth';
|
||||
import { getServerDB } from '@/database/core/db-adaptor';
|
||||
import { type LobeChatDatabase } from '@/database/type';
|
||||
import { LOBE_CHAT_AUTH_HEADER, LOBE_CHAT_OIDC_AUTH_HEADER } from '@/envs/auth';
|
||||
import { LOBE_CHAT_OIDC_AUTH_HEADER } from '@/envs/auth';
|
||||
import { extractTraceContext, injectActiveTraceHeaders } from '@/libs/observability/traceparent';
|
||||
import { validateOIDCJWT } from '@/libs/oidc-provider/jwt';
|
||||
import { createErrorResponse } from '@/utils/errorResponse';
|
||||
|
||||
import { checkAuthMethod } from './utils';
|
||||
|
||||
type CreateRuntime = (jwtPayload: ClientSecretPayload) => ModelRuntime;
|
||||
type RequestOptions = { createRuntime?: CreateRuntime; params: Promise<{ provider?: string }> };
|
||||
type RequestOptions = { params: Promise<{ provider?: string }> };
|
||||
|
||||
export type RequestHandler = (
|
||||
req: Request,
|
||||
@@ -48,41 +44,26 @@ export const checkAuth =
|
||||
});
|
||||
}
|
||||
|
||||
let jwtPayload: ClientSecretPayload;
|
||||
let userId: string;
|
||||
|
||||
try {
|
||||
// get Authorization from header
|
||||
const authorization = req.headers.get(LOBE_CHAT_AUTH_HEADER);
|
||||
|
||||
// better auth handler
|
||||
const session = await auth.api.getSession({
|
||||
headers: req.headers,
|
||||
});
|
||||
|
||||
const betterAuthAuthorized = !!session?.user?.id;
|
||||
|
||||
if (!authorization) throw AgentRuntimeError.createError(ChatErrorType.Unauthorized);
|
||||
|
||||
jwtPayload = getXorPayload(authorization);
|
||||
|
||||
// OIDC authentication (CLI)
|
||||
const oidcAuthorization = req.headers.get(LOBE_CHAT_OIDC_AUTH_HEADER);
|
||||
let isUseOidcAuth = false;
|
||||
if (!!oidcAuthorization) {
|
||||
if (oidcAuthorization) {
|
||||
const oidc = await validateOIDCJWT(oidcAuthorization);
|
||||
|
||||
isUseOidcAuth = true;
|
||||
|
||||
jwtPayload = {
|
||||
...jwtPayload,
|
||||
userId: oidc.userId,
|
||||
};
|
||||
}
|
||||
|
||||
if (!isUseOidcAuth)
|
||||
checkAuthMethod({
|
||||
apiKey: jwtPayload.apiKey,
|
||||
betterAuthAuthorized,
|
||||
userId = oidc.userId;
|
||||
} else {
|
||||
// Better Auth session authentication (web)
|
||||
const session = await auth.api.getSession({
|
||||
headers: req.headers,
|
||||
});
|
||||
|
||||
if (!session?.user?.id) {
|
||||
throw AgentRuntimeError.createError(ChatErrorType.Unauthorized);
|
||||
}
|
||||
|
||||
userId = session.user.id;
|
||||
}
|
||||
} catch (e) {
|
||||
const params = await options.params;
|
||||
|
||||
@@ -110,7 +91,7 @@ export const checkAuth =
|
||||
return createErrorResponse(errorType, { error, ...res, provider: params?.provider });
|
||||
}
|
||||
|
||||
const userId = jwtPayload.userId || '';
|
||||
const jwtPayload: ClientSecretPayload = { userId };
|
||||
|
||||
const extractedContext = extractTraceContext(req.headers);
|
||||
|
||||
|
||||
@@ -15,19 +15,11 @@ describe('checkAuthMethod', () => {
|
||||
).not.toThrow();
|
||||
});
|
||||
|
||||
it('should pass with valid API key', () => {
|
||||
expect(() =>
|
||||
checkAuthMethod({
|
||||
apiKey: 'someApiKey',
|
||||
}),
|
||||
).not.toThrow();
|
||||
});
|
||||
|
||||
it('should throw Unauthorized with no auth params', () => {
|
||||
expect(() => checkAuthMethod({})).toThrow();
|
||||
});
|
||||
|
||||
it('should throw Unauthorized when betterAuthAuthorized is false and no apiKey', () => {
|
||||
it('should throw Unauthorized when betterAuthAuthorized is false', () => {
|
||||
expect(() =>
|
||||
checkAuthMethod({
|
||||
betterAuthAuthorized: false,
|
||||
|
||||
@@ -2,25 +2,17 @@ import { AgentRuntimeError } from '@lobechat/model-runtime';
|
||||
import { ChatErrorType } from '@lobechat/types';
|
||||
|
||||
interface CheckAuthParams {
|
||||
apiKey?: string;
|
||||
betterAuthAuthorized?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if authentication is valid based on various auth methods.
|
||||
*
|
||||
* @param {CheckAuthParams} params - Authentication parameters extracted from headers.
|
||||
* @param {string} [params.apiKey] - The user API key.
|
||||
* @param {boolean} [params.betterAuthAuthorized] - Whether the Better Auth session exists.
|
||||
* @throws {AgentRuntimeError} If no valid authentication method is found.
|
||||
* Check if authentication is valid.
|
||||
* Only accepts a verified server-side session (Better Auth).
|
||||
*/
|
||||
export const checkAuthMethod = (params: CheckAuthParams) => {
|
||||
const { apiKey, betterAuthAuthorized } = params;
|
||||
const { betterAuthAuthorized } = params;
|
||||
|
||||
// if better auth session exists
|
||||
if (betterAuthAuthorized) return;
|
||||
|
||||
// if apiKey exist
|
||||
if (apiKey) return;
|
||||
|
||||
throw AgentRuntimeError.createError(ChatErrorType.Unauthorized);
|
||||
};
|
||||
|
||||
@@ -2,11 +2,9 @@
|
||||
import { type LobeRuntimeAI } from '@lobechat/model-runtime';
|
||||
import { ModelRuntime } from '@lobechat/model-runtime';
|
||||
import { ChatErrorType } from '@lobechat/types';
|
||||
import { getXorPayload } from '@lobechat/utils/server';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type * as EnvsAuthModule from '@/envs/auth';
|
||||
import { LOBE_CHAT_AUTH_HEADER } from '@/envs/auth';
|
||||
import { auth } from '@/auth';
|
||||
import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
|
||||
|
||||
import { POST } from './route';
|
||||
@@ -15,22 +13,11 @@ vi.mock('@/app/(backend)/middleware/auth/utils', () => ({
|
||||
checkAuthMethod: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@lobechat/utils/server', () => ({
|
||||
getXorPayload: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@/server/modules/ModelRuntime', () => ({
|
||||
initModelRuntimeFromDB: vi.fn(),
|
||||
createTraceOptions: vi.fn().mockReturnValue({}),
|
||||
}));
|
||||
|
||||
vi.mock('@/envs/auth', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof EnvsAuthModule>();
|
||||
return {
|
||||
...actual,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('@/auth', () => ({
|
||||
auth: {
|
||||
api: {
|
||||
@@ -43,31 +30,26 @@ vi.mock('@/auth', () => ({
|
||||
let request: Request;
|
||||
beforeEach(() => {
|
||||
request = new Request(new URL('https://test.com'), {
|
||||
headers: {
|
||||
[LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token',
|
||||
},
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ model: 'test-model' }),
|
||||
});
|
||||
|
||||
// Default: valid session
|
||||
vi.mocked(auth.api.getSession).mockResolvedValue({
|
||||
session: {} as any,
|
||||
user: { id: 'test-user-id' } as any,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
// 清除模拟调用历史
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
describe('POST handler', () => {
|
||||
describe('init chat model', () => {
|
||||
it('should initialize ModelRuntime correctly with valid authorization', async () => {
|
||||
it('should initialize ModelRuntime correctly with valid session', async () => {
|
||||
const mockParams = Promise.resolve({ provider: 'test-provider' });
|
||||
|
||||
// 设置 getJWTPayload 的模拟返回值
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
azureApiVersion: 'v1',
|
||||
});
|
||||
|
||||
// chat mock 需要返回一个 Response 对象,否则中间件访问 res.headers 会报错
|
||||
const mockChatResponse = new Response(JSON.stringify({ success: true }), {
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
});
|
||||
@@ -76,71 +58,33 @@ describe('POST handler', () => {
|
||||
chat: vi.fn().mockResolvedValue(mockChatResponse),
|
||||
};
|
||||
|
||||
// Mock initModelRuntimeFromDB
|
||||
vi.mocked(initModelRuntimeFromDB).mockResolvedValue(new ModelRuntime(mockRuntime));
|
||||
|
||||
// 调用 POST 函数
|
||||
await POST(request as unknown as Request, { params: mockParams });
|
||||
|
||||
// 验证是否正确调用了模拟函数
|
||||
expect(getXorPayload).toHaveBeenCalledWith('Bearer some-valid-token');
|
||||
expect(initModelRuntimeFromDB).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
expect.any(String),
|
||||
'test-user-id',
|
||||
'test-provider',
|
||||
);
|
||||
});
|
||||
|
||||
it('should return Unauthorized error when LOBE_CHAT_AUTH_HEADER is missing', async () => {
|
||||
it('should return Unauthorized error when no session exists', async () => {
|
||||
vi.mocked(auth.api.getSession).mockResolvedValue(null);
|
||||
|
||||
const mockParams = Promise.resolve({ provider: 'test-provider' });
|
||||
const requestWithoutAuthHeader = new Request(new URL('https://test.com'), {
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ model: 'test-model' }),
|
||||
});
|
||||
|
||||
const response = await POST(requestWithoutAuthHeader, { params: mockParams });
|
||||
|
||||
expect(response.status).toBe(401);
|
||||
expect(await response.json()).toEqual({
|
||||
body: {
|
||||
error: { errorType: 401 },
|
||||
provider: 'test-provider',
|
||||
},
|
||||
errorType: 401,
|
||||
});
|
||||
});
|
||||
|
||||
it('should return InternalServerError error when throw a unknown error', async () => {
|
||||
const mockParams = Promise.resolve({ provider: 'test-provider' });
|
||||
vi.mocked(getXorPayload).mockImplementationOnce(() => {
|
||||
throw new Error('unknown error');
|
||||
});
|
||||
|
||||
const response = await POST(request, { params: mockParams });
|
||||
|
||||
expect(response.status).toBe(500);
|
||||
expect(await response.json()).toEqual({
|
||||
body: {
|
||||
error: {},
|
||||
provider: 'test-provider',
|
||||
},
|
||||
errorType: 500,
|
||||
});
|
||||
expect(response.status).toBe(401);
|
||||
});
|
||||
});
|
||||
|
||||
describe('chat', () => {
|
||||
it('should correctly handle chat completion with valid payload', async () => {
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
azureApiVersion: 'v1',
|
||||
userId: 'abc',
|
||||
});
|
||||
|
||||
const mockParams = Promise.resolve({ provider: 'test-provider' });
|
||||
const mockChatPayload = { message: 'Hello, world!' };
|
||||
request = new Request(new URL('https://test.com'), {
|
||||
headers: { [LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token' },
|
||||
method: 'POST',
|
||||
body: JSON.stringify(mockChatPayload),
|
||||
});
|
||||
@@ -157,21 +101,15 @@ describe('POST handler', () => {
|
||||
|
||||
expect(response).toEqual(mockChatResponse);
|
||||
expect(mockRuntime.chat).toHaveBeenCalledWith(mockChatPayload, {
|
||||
user: expect.any(String),
|
||||
user: 'test-user-id',
|
||||
signal: expect.anything(),
|
||||
});
|
||||
});
|
||||
|
||||
it('should return an error response when chat completion fails', async () => {
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
azureApiVersion: 'v1',
|
||||
});
|
||||
|
||||
const mockParams = Promise.resolve({ provider: 'test-provider' });
|
||||
const mockChatPayload = { message: 'Hello, world!' };
|
||||
request = new Request(new URL('https://test.com'), {
|
||||
headers: { [LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token' },
|
||||
method: 'POST',
|
||||
body: JSON.stringify(mockChatPayload),
|
||||
});
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type ChatCompletionErrorPayload, type ModelRuntime } from '@lobechat/model-runtime';
|
||||
import { type ChatCompletionErrorPayload } from '@lobechat/model-runtime';
|
||||
import { AGENT_RUNTIME_ERROR_SET } from '@lobechat/model-runtime';
|
||||
import { ChatErrorType } from '@lobechat/types';
|
||||
|
||||
@@ -12,53 +12,44 @@ import { getTracePayload } from '@/utils/trace';
|
||||
// this enforce user to enable fluid compute
|
||||
export const maxDuration = 300;
|
||||
|
||||
export const POST = checkAuth(
|
||||
async (req: Request, { params, userId, serverDB, createRuntime, jwtPayload }) => {
|
||||
const provider = (await params)!.provider!;
|
||||
export const POST = checkAuth(async (req: Request, { params, userId, serverDB }) => {
|
||||
const provider = (await params)!.provider!;
|
||||
|
||||
try {
|
||||
// ============ 1. init chat model ============ //
|
||||
let modelRuntime: ModelRuntime;
|
||||
if (createRuntime) {
|
||||
// Legacy support for custom runtime creation
|
||||
modelRuntime = createRuntime(jwtPayload);
|
||||
} else {
|
||||
// Read user's provider config from database
|
||||
modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);
|
||||
}
|
||||
try {
|
||||
// ============ 1. init chat model ============ //
|
||||
const modelRuntime = await initModelRuntimeFromDB(serverDB, userId, provider);
|
||||
|
||||
// ============ 2. create chat completion ============ //
|
||||
// ============ 2. create chat completion ============ //
|
||||
|
||||
const data = (await req.json()) as ChatStreamPayload;
|
||||
const data = (await req.json()) as ChatStreamPayload;
|
||||
|
||||
const tracePayload = getTracePayload(req);
|
||||
const tracePayload = getTracePayload(req);
|
||||
|
||||
let traceOptions = {};
|
||||
// If user enable trace
|
||||
if (tracePayload?.enabled) {
|
||||
traceOptions = createTraceOptions(data, { provider, trace: tracePayload });
|
||||
}
|
||||
|
||||
return await modelRuntime.chat(data, {
|
||||
user: userId,
|
||||
...traceOptions,
|
||||
signal: req.signal,
|
||||
});
|
||||
} catch (e) {
|
||||
const {
|
||||
errorType = ChatErrorType.InternalServerError,
|
||||
error: errorContent,
|
||||
...res
|
||||
} = e as ChatCompletionErrorPayload;
|
||||
|
||||
const error = errorContent || e;
|
||||
|
||||
const logMethod = AGENT_RUNTIME_ERROR_SET.has(errorType as string) ? 'warn' : 'error';
|
||||
// track the error at server side
|
||||
// eslint-disable-next-line no-console
|
||||
console[logMethod](`Route: [${provider}] ${errorType}:`, error);
|
||||
|
||||
return createErrorResponse(errorType, { error, ...res, provider });
|
||||
let traceOptions = {};
|
||||
// If user enable trace
|
||||
if (tracePayload?.enabled) {
|
||||
traceOptions = createTraceOptions(data, { provider, trace: tracePayload });
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
return await modelRuntime.chat(data, {
|
||||
user: userId,
|
||||
...traceOptions,
|
||||
signal: req.signal,
|
||||
});
|
||||
} catch (e) {
|
||||
const {
|
||||
errorType = ChatErrorType.InternalServerError,
|
||||
error: errorContent,
|
||||
...res
|
||||
} = e as ChatCompletionErrorPayload;
|
||||
|
||||
const error = errorContent || e;
|
||||
|
||||
const logMethod = AGENT_RUNTIME_ERROR_SET.has(errorType as string) ? 'warn' : 'error';
|
||||
// track the error at server side
|
||||
// eslint-disable-next-line no-console
|
||||
console[logMethod](`Route: [${provider}] ${errorType}:`, error);
|
||||
|
||||
return createErrorResponse(errorType, { error, ...res, provider });
|
||||
}
|
||||
});
|
||||
|
||||
@@ -2,11 +2,9 @@
|
||||
import { type LobeRuntimeAI } from '@lobechat/model-runtime';
|
||||
import { ModelRuntime } from '@lobechat/model-runtime';
|
||||
import { ChatErrorType } from '@lobechat/types';
|
||||
import { getXorPayload } from '@lobechat/utils/server';
|
||||
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import type * as EnvsAuthModule from '@/envs/auth';
|
||||
import { LOBE_CHAT_AUTH_HEADER } from '@/envs/auth';
|
||||
import { auth } from '@/auth';
|
||||
import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
|
||||
|
||||
import { GET } from './route';
|
||||
@@ -15,17 +13,6 @@ vi.mock('@/app/(backend)/middleware/auth/utils', () => ({
|
||||
checkAuthMethod: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@lobechat/utils/server', () => ({
|
||||
getXorPayload: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock('@/envs/auth', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof EnvsAuthModule>();
|
||||
return {
|
||||
...actual,
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock('@/auth', () => ({
|
||||
auth: {
|
||||
api: {
|
||||
@@ -42,11 +29,14 @@ let request: Request;
|
||||
|
||||
beforeEach(() => {
|
||||
request = new Request(new URL('https://test.com'), {
|
||||
headers: {
|
||||
[LOBE_CHAT_AUTH_HEADER]: 'Bearer some-valid-token',
|
||||
},
|
||||
method: 'GET',
|
||||
});
|
||||
|
||||
// Default: valid session
|
||||
vi.mocked(auth.api.getSession).mockResolvedValue({
|
||||
session: {} as any,
|
||||
user: { id: 'test-user-id' } as any,
|
||||
});
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
@@ -58,10 +48,6 @@ describe('GET handler', () => {
|
||||
it('should not expose stack trace when an Error is thrown', async () => {
|
||||
const mockParams = Promise.resolve({ provider: 'google' });
|
||||
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
const errorWithStack = new Error('Something went wrong');
|
||||
errorWithStack.stack =
|
||||
'Error: Something went wrong\n at Object.<anonymous> (/path/to/file.ts:10:15)';
|
||||
@@ -76,14 +62,10 @@ describe('GET handler', () => {
|
||||
const response = await GET(request, { params: mockParams });
|
||||
const responseBody = await response.json();
|
||||
|
||||
// Should contain error name and message
|
||||
expect(responseBody.body.error.name).toBe('Error');
|
||||
expect(responseBody.body.error.message).toBe('Something went wrong');
|
||||
|
||||
// Should NOT contain stack trace
|
||||
expect(responseBody.body.error.stack).toBeUndefined();
|
||||
|
||||
// Verify JSON stringified response doesn't contain stack
|
||||
const responseText = JSON.stringify(responseBody);
|
||||
expect(responseText).not.toContain('/path/to/file.ts');
|
||||
expect(responseText).not.toContain('at Object');
|
||||
@@ -92,10 +74,6 @@ describe('GET handler', () => {
|
||||
it('should preserve error name for custom error types', async () => {
|
||||
const mockParams = Promise.resolve({ provider: 'google' });
|
||||
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
class CustomError extends Error {
|
||||
constructor(message: string) {
|
||||
super(message);
|
||||
@@ -124,10 +102,6 @@ describe('GET handler', () => {
|
||||
it('should pass through structured error objects as-is', async () => {
|
||||
const mockParams = Promise.resolve({ provider: 'google' });
|
||||
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
const structuredError = {
|
||||
errorType: ChatErrorType.InternalServerError,
|
||||
error: { code: 'PROVIDER_ERROR', details: 'API limit exceeded' },
|
||||
@@ -143,7 +117,6 @@ describe('GET handler', () => {
|
||||
const response = await GET(request, { params: mockParams });
|
||||
const responseBody = await response.json();
|
||||
|
||||
// Structured error should be passed through
|
||||
expect(responseBody.body.error.code).toBe('PROVIDER_ERROR');
|
||||
expect(responseBody.body.error.details).toBe('API limit exceeded');
|
||||
});
|
||||
@@ -151,10 +124,6 @@ describe('GET handler', () => {
|
||||
it('should return correct status code for errors', async () => {
|
||||
const mockParams = Promise.resolve({ provider: 'google' });
|
||||
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
const mockRuntime: LobeRuntimeAI = {
|
||||
baseURL: 'abc',
|
||||
chat: vi.fn(),
|
||||
@@ -170,10 +139,6 @@ describe('GET handler', () => {
|
||||
it('should include provider in error response', async () => {
|
||||
const mockParams = Promise.resolve({ provider: 'openai' });
|
||||
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
const mockRuntime: LobeRuntimeAI = {
|
||||
baseURL: 'abc',
|
||||
chat: vi.fn(),
|
||||
@@ -192,10 +157,6 @@ describe('GET handler', () => {
|
||||
it('should return model list on success', async () => {
|
||||
const mockParams = Promise.resolve({ provider: 'openai' });
|
||||
|
||||
vi.mocked(getXorPayload).mockReturnValueOnce({
|
||||
apiKey: 'test-api-key',
|
||||
});
|
||||
|
||||
const mockModelList = [
|
||||
{ id: 'gpt-4', name: 'GPT-4' },
|
||||
{ id: 'gpt-3.5-turbo', name: 'GPT-3.5 Turbo' },
|
||||
|
||||
@@ -75,6 +75,8 @@ export const getAppConfig = () => {
|
||||
*/
|
||||
MARKET_TRUSTED_CLIENT_ID: z.string().optional(),
|
||||
|
||||
AGENT_GATEWAY_SERVICE_TOKEN: z.string().optional(),
|
||||
AGENT_GATEWAY_URL: z.string().url(),
|
||||
/**
|
||||
* Enable Queue-based Agent Runtime
|
||||
* When true, use QStash for async agent execution (production)
|
||||
@@ -118,6 +120,8 @@ export const getAppConfig = () => {
|
||||
MARKET_TRUSTED_CLIENT_SECRET: process.env.MARKET_TRUSTED_CLIENT_SECRET,
|
||||
MARKET_TRUSTED_CLIENT_ID: process.env.MARKET_TRUSTED_CLIENT_ID,
|
||||
|
||||
AGENT_GATEWAY_SERVICE_TOKEN: process.env.AGENT_GATEWAY_SERVICE_TOKEN,
|
||||
AGENT_GATEWAY_URL: process.env.AGENT_GATEWAY_URL || 'https://agent-gateway.lobehub.com',
|
||||
enableQueueAgentRuntime: process.env.AGENT_RUNTIME_MODE === 'queue',
|
||||
TELEMETRY_DISABLED: process.env.TELEMETRY_DISABLED === '1',
|
||||
},
|
||||
|
||||
@@ -298,4 +298,3 @@ export const authEnv = getAuthConfig();
|
||||
// Auth headers and constants
|
||||
export const LOBE_CHAT_AUTH_HEADER = 'X-lobe-chat-auth';
|
||||
export const LOBE_CHAT_OIDC_AUTH_HEADER = 'Oidc-Auth';
|
||||
export const SECRET_XOR_KEY = 'LobeHub · LobeHub';
|
||||
|
||||
@@ -159,6 +159,9 @@ const InputEditor = memo<{ defaultRows?: number }>(({ defaultRows = 2 }) => {
|
||||
input: string;
|
||||
selectionType: string;
|
||||
}): Promise<string | null> => {
|
||||
// Skip autocomplete during IME composition (e.g. Chinese input method)
|
||||
if (isComposingRef.current) return null;
|
||||
|
||||
if (!input.trim()) return null;
|
||||
|
||||
const { enabled: _, ...config } = systemAgentSelectors.inputCompletion(
|
||||
@@ -188,7 +191,7 @@ const InputEditor = memo<{ defaultRows?: number }>(({ defaultRows = 2 }) => {
|
||||
|
||||
if (abortSignal.aborted) return null;
|
||||
|
||||
return result || null;
|
||||
return result.trimEnd() || null;
|
||||
},
|
||||
[],
|
||||
);
|
||||
@@ -204,6 +207,51 @@ const InputEditor = memo<{ defaultRows?: number }>(({ defaultRows = 2 }) => {
|
||||
[isAutoCompleteEnabled, handleAutoComplete],
|
||||
);
|
||||
|
||||
// --- Stable mentionOption & slashOption to prevent infinite re-render on paste ---
|
||||
const mentionMarkdownWriter = useCallback((mention: any) => {
|
||||
if (mention.metadata?.type === 'topic') {
|
||||
return `<refer_topic name="${mention.metadata.topicTitle}" id="${mention.metadata.topicId}" />`;
|
||||
}
|
||||
return `<mention name="${mention.label}" id="${mention.metadata.id}" />`;
|
||||
}, []);
|
||||
|
||||
const mentionOnSelect = useCallback((editor: any, option: any) => {
|
||||
if (option.metadata?.type === 'topic') {
|
||||
editor.dispatchCommand(INSERT_REFER_TOPIC_COMMAND, {
|
||||
topicId: option.metadata.topicId as string,
|
||||
topicTitle: String(option.metadata.topicTitle ?? option.label),
|
||||
});
|
||||
} else if (option.metadata?.type === 'skill' || option.metadata?.type === 'tool') {
|
||||
const payload: InsertActionTagPayload = {
|
||||
category: option.metadata.actionCategory as 'skill' | 'tool',
|
||||
label: String(option.label),
|
||||
type: String(option.metadata.actionType),
|
||||
};
|
||||
editor.dispatchCommand(INSERT_ACTION_TAG_COMMAND, payload);
|
||||
} else {
|
||||
editor.dispatchCommand(INSERT_MENTION_COMMAND, {
|
||||
label: String(option.label),
|
||||
metadata: option.metadata,
|
||||
});
|
||||
}
|
||||
}, []);
|
||||
|
||||
const mentionOption = useMemo(
|
||||
() =>
|
||||
enableMention
|
||||
? {
|
||||
items: mentionItemsFn,
|
||||
markdownWriter: mentionMarkdownWriter,
|
||||
maxLength: 50,
|
||||
onSelect: mentionOnSelect,
|
||||
renderComp: MentionMenuComp,
|
||||
}
|
||||
: undefined,
|
||||
[enableMention, mentionItemsFn, mentionMarkdownWriter, mentionOnSelect, MentionMenuComp],
|
||||
);
|
||||
|
||||
const slashOption = useMemo(() => ({ items: slashItems }), [slashItems]);
|
||||
|
||||
const richRenderProps = useMemo(() => {
|
||||
const basePlugins = !enableRichRender
|
||||
? CHAT_INPUT_EMBED_PLUGINS
|
||||
@@ -233,47 +281,11 @@ const InputEditor = memo<{ defaultRows?: number }>(({ defaultRows = 2 }) => {
|
||||
editor={editor}
|
||||
{...{ slashPlacement }}
|
||||
{...richRenderProps}
|
||||
mentionOption={mentionOption}
|
||||
placeholder={<Placeholder />}
|
||||
slashOption={slashOption}
|
||||
type={'text'}
|
||||
variant={'chat'}
|
||||
mentionOption={
|
||||
enableMention
|
||||
? {
|
||||
items: mentionItemsFn,
|
||||
markdownWriter: (mention) => {
|
||||
if (mention.metadata?.type === 'topic') {
|
||||
return `<refer_topic name="${mention.metadata.topicTitle}" id="${mention.metadata.topicId}" />`;
|
||||
}
|
||||
return `<mention name="${mention.label}" id="${mention.metadata.id}" />`;
|
||||
},
|
||||
maxLength: 50,
|
||||
onSelect: (editor, option) => {
|
||||
if (option.metadata?.type === 'topic') {
|
||||
editor.dispatchCommand(INSERT_REFER_TOPIC_COMMAND, {
|
||||
topicId: option.metadata.topicId as string,
|
||||
topicTitle: String(option.metadata.topicTitle ?? option.label),
|
||||
});
|
||||
} else if (option.metadata?.type === 'skill' || option.metadata?.type === 'tool') {
|
||||
const payload: InsertActionTagPayload = {
|
||||
category: option.metadata.actionCategory as 'skill' | 'tool',
|
||||
label: String(option.label),
|
||||
type: String(option.metadata.actionType),
|
||||
};
|
||||
editor.dispatchCommand(INSERT_ACTION_TAG_COMMAND, payload);
|
||||
} else {
|
||||
editor.dispatchCommand(INSERT_MENTION_COMMAND, {
|
||||
label: String(option.label),
|
||||
metadata: option.metadata,
|
||||
});
|
||||
}
|
||||
},
|
||||
renderComp: MentionMenuComp,
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
slashOption={{
|
||||
items: slashItems,
|
||||
}}
|
||||
style={{
|
||||
minHeight: defaultRows > 1 ? defaultRows * 23 : undefined,
|
||||
}}
|
||||
|
||||
@@ -11,15 +11,14 @@ import { SkillStoreTab } from '../SkillStoreContent';
|
||||
interface SearchProps {
|
||||
activeTab: SkillStoreTab;
|
||||
onLobeHubSearch: (keywords: string) => void;
|
||||
onSkillSearch: (keywords: string) => void;
|
||||
}
|
||||
|
||||
export const Search = memo<SearchProps>(({ activeTab, onLobeHubSearch }) => {
|
||||
export const Search = memo<SearchProps>(({ activeTab, onLobeHubSearch, onSkillSearch }) => {
|
||||
const { t } = useTranslation('setting');
|
||||
const mcpKeywords = useToolStore((s) => s.mcpSearchKeywords);
|
||||
|
||||
const isCustomTab = activeTab === SkillStoreTab.Custom;
|
||||
|
||||
const keywords = activeTab === SkillStoreTab.Community ? mcpKeywords : '';
|
||||
const keywords = activeTab === SkillStoreTab.MCP ? mcpKeywords : '';
|
||||
|
||||
return (
|
||||
<Flexbox horizontal align={'center'} gap={8} justify={'space-between'}>
|
||||
@@ -30,9 +29,11 @@ export const Search = memo<SearchProps>(({ activeTab, onLobeHubSearch }) => {
|
||||
placeholder={t('skillStore.search')}
|
||||
variant="outlined"
|
||||
onSearch={(keywords: string) => {
|
||||
if (activeTab === SkillStoreTab.Community) {
|
||||
if (activeTab === SkillStoreTab.MCP) {
|
||||
useToolStore.setState({ mcpSearchKeywords: keywords, searchLoading: true });
|
||||
} else if (isCustomTab) {
|
||||
} else if (activeTab === SkillStoreTab.Skills) {
|
||||
onSkillSearch(keywords);
|
||||
} else if (activeTab === SkillStoreTab.Custom) {
|
||||
useToolStore.setState({ customPluginSearchKeywords: keywords });
|
||||
} else {
|
||||
onLobeHubSearch(keywords);
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
'use client';
|
||||
|
||||
import { ActionIcon, Avatar, Block, DropdownMenu, Flexbox, Icon, Modal, Tag } from '@lobehub/ui';
|
||||
import { SkillsIcon } from '@lobehub/ui/icons';
|
||||
import { App } from 'antd';
|
||||
import { createStaticStyles, cssVar } from 'antd-style';
|
||||
import { DownloadIcon, Loader2, MoreVerticalIcon, Plus, Trash2 } from 'lucide-react';
|
||||
import { lazy, memo, Suspense, useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { agentSkillService } from '@/services/skill';
|
||||
import { useToolStore } from '@/store/tool';
|
||||
import { agentSkillsSelectors } from '@/store/tool/selectors';
|
||||
import { type DiscoverSkillItem } from '@/types/discover';
|
||||
import { downloadFile } from '@/utils/client/downloadFile';
|
||||
|
||||
import { itemStyles } from '../style';
|
||||
|
||||
const MarketSkillDetail = lazy(() => import('../MarketSkills/MarketSkillDetail'));
|
||||
|
||||
const styles = createStaticStyles(({ css }) => ({
|
||||
title: css`
|
||||
cursor: pointer;
|
||||
|
||||
overflow: hidden;
|
||||
|
||||
font-size: 14px;
|
||||
font-weight: 500;
|
||||
color: ${cssVar.colorText};
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
|
||||
&:hover {
|
||||
color: ${cssVar.colorPrimary};
|
||||
}
|
||||
`,
|
||||
}));
|
||||
|
||||
const MarketSkillItem = memo<DiscoverSkillItem>(({ name, icon, description, identifier }) => {
|
||||
const { t } = useTranslation('plugin');
|
||||
const { t: tc } = useTranslation('common');
|
||||
const [detailOpen, setDetailOpen] = useState(false);
|
||||
const [installing, setInstalling] = useState(false);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const { modal } = App.useApp();
|
||||
|
||||
const installed = useToolStore(agentSkillsSelectors.isAgentSkill(identifier));
|
||||
const installedSkill = useToolStore(agentSkillsSelectors.getAgentSkillByIdentifier(identifier));
|
||||
const [refreshAgentSkills, deleteAgentSkill] = useToolStore((s) => [
|
||||
s.refreshAgentSkills,
|
||||
s.deleteAgentSkill,
|
||||
]);
|
||||
|
||||
const handleInstall = useCallback(async () => {
|
||||
if (installing || installed) return;
|
||||
setInstalling(true);
|
||||
try {
|
||||
await agentSkillService.importFromMarket(identifier);
|
||||
await refreshAgentSkills();
|
||||
} catch {
|
||||
// silently fail
|
||||
} finally {
|
||||
setInstalling(false);
|
||||
}
|
||||
}, [identifier, installing, installed, refreshAgentSkills]);
|
||||
|
||||
const handleUninstall = useCallback(() => {
|
||||
if (!installedSkill) return;
|
||||
modal.confirm({
|
||||
centered: true,
|
||||
okButtonProps: { danger: true },
|
||||
onOk: async () => {
|
||||
await deleteAgentSkill(installedSkill.id);
|
||||
},
|
||||
title: t('store.actions.confirmUninstall'),
|
||||
type: 'error',
|
||||
});
|
||||
}, [installedSkill, deleteAgentSkill, modal, t]);
|
||||
|
||||
const handleDownload = useCallback(async () => {
|
||||
if (!installedSkill?.zipFileHash) return;
|
||||
setLoading(true);
|
||||
try {
|
||||
const result = await agentSkillService.getZipUrl(installedSkill.id);
|
||||
if (result.url) {
|
||||
await downloadFile(result.url, `${result.name || name}.zip`);
|
||||
}
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [installedSkill, name]);
|
||||
|
||||
const renderAction = () => {
|
||||
if (installed) {
|
||||
return (
|
||||
<DropdownMenu
|
||||
nativeButton={false}
|
||||
placement="bottomRight"
|
||||
items={[
|
||||
...(installedSkill?.zipFileHash
|
||||
? [
|
||||
{
|
||||
icon: <Icon icon={DownloadIcon} />,
|
||||
key: 'download',
|
||||
label: tc('download'),
|
||||
onClick: handleDownload,
|
||||
},
|
||||
{ type: 'divider' as const },
|
||||
]
|
||||
: []),
|
||||
{
|
||||
danger: true,
|
||||
icon: <Icon icon={Trash2} />,
|
||||
key: 'uninstall',
|
||||
label: t('store.actions.uninstall'),
|
||||
onClick: handleUninstall,
|
||||
},
|
||||
]}
|
||||
>
|
||||
<ActionIcon icon={MoreVerticalIcon} loading={loading} />
|
||||
</DropdownMenu>
|
||||
);
|
||||
}
|
||||
|
||||
if (installing) return <ActionIcon loading icon={Loader2} />;
|
||||
|
||||
return <ActionIcon icon={Plus} title={t('store.actions.install')} onClick={handleInstall} />;
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
<Flexbox className={itemStyles.container} gap={0}>
|
||||
<Block
|
||||
horizontal
|
||||
align={'center'}
|
||||
gap={12}
|
||||
paddingBlock={12}
|
||||
paddingInline={12}
|
||||
variant={'outlined'}
|
||||
>
|
||||
<Avatar avatar={icon || name} shape={'square'} size={40} style={{ flex: 'none' }} />
|
||||
<Flexbox flex={1} gap={4} style={{ minWidth: 0, overflow: 'hidden' }}>
|
||||
<Flexbox horizontal align="center" gap={8}>
|
||||
<span className={styles.title} onClick={() => setDetailOpen(true)}>
|
||||
{name}
|
||||
</span>
|
||||
<Tag icon={<Icon icon={SkillsIcon} />} size={'small'} />
|
||||
</Flexbox>
|
||||
{description && <span className={itemStyles.description}>{description}</span>}
|
||||
</Flexbox>
|
||||
{renderAction()}
|
||||
</Block>
|
||||
</Flexbox>
|
||||
<Modal
|
||||
destroyOnHidden
|
||||
footer={null}
|
||||
open={detailOpen}
|
||||
styles={{ body: { height: 'calc(100dvh - 200px)', overflow: 'hidden', padding: 0 } }}
|
||||
title={t('dev.title.skillDetails')}
|
||||
width={960}
|
||||
onCancel={() => setDetailOpen(false)}
|
||||
>
|
||||
<Suspense fallback={<div style={{ height: '100%' }} />}>
|
||||
<MarketSkillDetail identifier={identifier} />
|
||||
</Suspense>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
});
|
||||
|
||||
MarketSkillItem.displayName = 'MarketSkillItem';
|
||||
|
||||
export default MarketSkillItem;
|
||||
@@ -0,0 +1,98 @@
|
||||
'use client';
|
||||
|
||||
import { Center, Icon, Text } from '@lobehub/ui';
|
||||
import { ServerCrash } from 'lucide-react';
|
||||
import { memo, useEffect, useRef } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
|
||||
import { useToolStore } from '@/store/tool';
|
||||
|
||||
import Item from '../Community/Item';
|
||||
import Empty from '../Empty';
|
||||
import Loading from '../Loading';
|
||||
import { virtuosoGridStyles } from '../style';
|
||||
import VirtuosoLoading from '../VirtuosoLoading';
|
||||
import WantMoreSkills from '../WantMoreSkills';
|
||||
|
||||
export const MCPList = memo(() => {
|
||||
const { t } = useTranslation('setting');
|
||||
|
||||
const [
|
||||
keywords,
|
||||
isMcpListInit,
|
||||
allItems,
|
||||
currentPage,
|
||||
totalPages,
|
||||
searchLoading,
|
||||
useFetchMCPPluginList,
|
||||
loadMoreMCPPlugins,
|
||||
resetMCPPluginList,
|
||||
] = useToolStore((s) => [
|
||||
s.mcpSearchKeywords,
|
||||
s.isMcpListInit,
|
||||
s.mcpPluginItems,
|
||||
s.currentPage,
|
||||
s.totalPages,
|
||||
s.searchLoading,
|
||||
s.useFetchMCPPluginList,
|
||||
s.loadMoreMCPPlugins,
|
||||
s.resetMCPPluginList,
|
||||
]);
|
||||
|
||||
const prevKeywordsRef = useRef(keywords);
|
||||
|
||||
useEffect(() => {
|
||||
if (prevKeywordsRef.current !== keywords) {
|
||||
prevKeywordsRef.current = keywords;
|
||||
resetMCPPluginList(keywords);
|
||||
}
|
||||
}, [keywords, resetMCPPluginList]);
|
||||
|
||||
const { isLoading, error } = useFetchMCPPluginList({
|
||||
page: currentPage,
|
||||
pageSize: 20,
|
||||
q: keywords,
|
||||
});
|
||||
|
||||
const hasSearchKeywords = Boolean(keywords && keywords.trim());
|
||||
|
||||
if (searchLoading || !isMcpListInit || (isLoading && allItems.length === 0)) return <Loading />;
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<Center gap={12} padding={40}>
|
||||
<Icon icon={ServerCrash} size={80} />
|
||||
<Text type={'secondary'}>{t('skillStore.networkError')}</Text>
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
if (allItems.length === 0) return <Empty search={hasSearchKeywords} />;
|
||||
|
||||
const hasReachedEnd = totalPages !== undefined && currentPage >= totalPages;
|
||||
|
||||
const renderFooter = () => {
|
||||
if (isLoading) return <VirtuosoLoading />;
|
||||
if (hasReachedEnd) return <WantMoreSkills />;
|
||||
return <div style={{ height: 16 }} />;
|
||||
};
|
||||
|
||||
return (
|
||||
<VirtuosoGrid
|
||||
components={{ Footer: renderFooter }}
|
||||
data={allItems}
|
||||
endReached={loadMoreMCPPlugins}
|
||||
increaseViewportBy={typeof window !== 'undefined' ? window.innerHeight : 0}
|
||||
itemClassName={virtuosoGridStyles.item}
|
||||
itemContent={(_, item) => <Item {...item} />}
|
||||
listClassName={virtuosoGridStyles.list}
|
||||
overscan={24}
|
||||
style={{ height: '60vh', width: '100%' }}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
MCPList.displayName = 'MCPList';
|
||||
|
||||
export default MCPList;
|
||||
@@ -0,0 +1,260 @@
|
||||
'use client';
|
||||
|
||||
import { type SkillResourceTreeNode } from '@lobechat/types';
|
||||
import { Github } from '@lobehub/icons';
|
||||
import { ActionIcon, Avatar, Flexbox, Icon } from '@lobehub/ui';
|
||||
import { Skeleton } from 'antd';
|
||||
import { createStaticStyles, cssVar } from 'antd-style';
|
||||
import { unzip } from 'fflate';
|
||||
import { DotIcon, ExternalLinkIcon } from 'lucide-react';
|
||||
import { memo, useEffect, useMemo, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import PublishedTime from '@/components/PublishedTime';
|
||||
import { marketApiService } from '@/services/marketApi';
|
||||
import { useDiscoverStore } from '@/store/discover';
|
||||
import { useToolStore } from '@/store/tool';
|
||||
import { agentSkillsSelectors } from '@/store/tool/selectors';
|
||||
import { type DiscoverSkillDetail as DiscoverSkillDetailType } from '@/types/discover';
|
||||
|
||||
import ContentViewer from '../../../AgentSkillDetail/ContentViewer';
|
||||
import FileTree from '../../../AgentSkillDetail/FileTree';
|
||||
|
||||
const styles = createStaticStyles(({ css, cssVar }) => ({
|
||||
description: css`
|
||||
overflow: hidden;
|
||||
|
||||
margin: 0;
|
||||
|
||||
font-size: 13px;
|
||||
line-height: 1.5;
|
||||
color: ${cssVar.colorTextSecondary};
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
`,
|
||||
divider: css`
|
||||
flex-shrink: 0;
|
||||
width: 1px;
|
||||
background: ${cssVar.colorBorderSecondary};
|
||||
`,
|
||||
left: css`
|
||||
overflow-y: auto;
|
||||
flex-shrink: 0;
|
||||
width: 240px;
|
||||
padding: 8px;
|
||||
`,
|
||||
meta: css`
|
||||
flex-shrink: 0;
|
||||
padding: 16px;
|
||||
border-block-end: 1px solid ${cssVar.colorBorderSecondary};
|
||||
`,
|
||||
name: css`
|
||||
font-size: 16px;
|
||||
font-weight: 500;
|
||||
line-height: 1.4;
|
||||
color: ${cssVar.colorText};
|
||||
`,
|
||||
right: css`
|
||||
container-type: size;
|
||||
overflow: auto;
|
||||
flex: 1;
|
||||
`,
|
||||
}));
|
||||
|
||||
interface MarketSkillDetailProps {
|
||||
identifier: string;
|
||||
}
|
||||
|
||||
const buildContentMap = (nodes: SkillResourceTreeNode[]): Record<string, string> => {
|
||||
const map: Record<string, string> = {};
|
||||
const walk = (items: SkillResourceTreeNode[]) => {
|
||||
for (const node of items) {
|
||||
if (node.type === 'file' && node.content !== undefined) {
|
||||
map[node.path] = node.content;
|
||||
} else if (node.children) {
|
||||
walk(node.children);
|
||||
}
|
||||
}
|
||||
};
|
||||
walk(nodes);
|
||||
return map;
|
||||
};
|
||||
|
||||
const buildMarketResourceTree = (
|
||||
resources?: DiscoverSkillDetailType['resources'],
|
||||
): { name: string; path: string; type: 'file' }[] => {
|
||||
if (!resources) return [];
|
||||
return Object.keys(resources)
|
||||
.sort()
|
||||
.map((path) => ({
|
||||
name: path.split('/').pop() || path,
|
||||
path,
|
||||
type: 'file' as const,
|
||||
}));
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetch zip from downloadUrl and extract text file contents
|
||||
*/
|
||||
const fetchZipContents = async (
|
||||
url: string,
|
||||
): Promise<{ contentMap: Record<string, string>; tree: SkillResourceTreeNode[] }> => {
|
||||
const res = await fetch(url);
|
||||
const buf = await res.arrayBuffer();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
unzip(new Uint8Array(buf), (err, files) => {
|
||||
if (err) return reject(err);
|
||||
|
||||
const contentMap: Record<string, string> = {};
|
||||
const tree: SkillResourceTreeNode[] = [];
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
for (const [rawPath, data] of Object.entries(files)) {
|
||||
if (rawPath.endsWith('/') || rawPath.includes('__MACOSX')) continue;
|
||||
|
||||
// Strip the top-level directory prefix (e.g. "skill-name/")
|
||||
const slashIdx = rawPath.indexOf('/');
|
||||
const path = slashIdx >= 0 ? rawPath.slice(slashIdx + 1) : rawPath;
|
||||
if (!path || path === 'SKILL.md') continue;
|
||||
|
||||
const content = decoder.decode(data);
|
||||
contentMap[path] = content;
|
||||
tree.push({
|
||||
content,
|
||||
name: path.split('/').pop() || path,
|
||||
path,
|
||||
type: 'file',
|
||||
});
|
||||
}
|
||||
|
||||
tree.sort((a, b) => a.path.localeCompare(b.path));
|
||||
resolve({ contentMap, tree });
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
const MarketSkillDetail = memo<MarketSkillDetailProps>(({ identifier }) => {
|
||||
const { t } = useTranslation('setting');
|
||||
const [selectedFile, setSelectedFile] = useState('SKILL.md');
|
||||
|
||||
// Market data (always fetched for header info + icon)
|
||||
const useFetchSkillDetail = useDiscoverStore((s) => s.useFetchSkillDetail);
|
||||
const { data, isLoading } = useFetchSkillDetail({ identifier });
|
||||
|
||||
// Installed skill data (for full file content)
|
||||
const installedSkill = useToolStore(agentSkillsSelectors.getAgentSkillByIdentifier(identifier));
|
||||
const { data: installedData } = useToolStore((s) => s.useFetchAgentSkillDetail)(
|
||||
installedSkill?.id,
|
||||
);
|
||||
|
||||
// Zip-based content for uninstalled skills
|
||||
const [zipContentMap, setZipContentMap] = useState<Record<string, string>>({});
|
||||
const [zipTree, setZipTree] = useState<SkillResourceTreeNode[]>([]);
|
||||
|
||||
const downloadUrl = marketApiService.getSkillDownloadUrl(encodeURIComponent(identifier));
|
||||
|
||||
useEffect(() => {
|
||||
if (installedSkill) return;
|
||||
|
||||
fetchZipContents(downloadUrl)
|
||||
.then(({ contentMap, tree }) => {
|
||||
setZipContentMap(contentMap);
|
||||
setZipTree(tree);
|
||||
})
|
||||
.catch(() => {
|
||||
// fall back to metadata-only view
|
||||
});
|
||||
}, [downloadUrl, installedSkill]);
|
||||
|
||||
const installedResourceTree = useMemo(
|
||||
() => installedData?.resourceTree ?? [],
|
||||
[installedData?.resourceTree],
|
||||
);
|
||||
const installedContentMap = useMemo(
|
||||
() => buildContentMap(installedResourceTree),
|
||||
[installedResourceTree],
|
||||
);
|
||||
|
||||
// Pick the best content source: installed > zip > market metadata
|
||||
const contentMap = installedResourceTree.length > 0 ? installedContentMap : zipContentMap;
|
||||
const resourceTree = useMemo(() => {
|
||||
if (installedResourceTree.length > 0) return installedResourceTree;
|
||||
if (zipTree.length > 0) return zipTree;
|
||||
return buildMarketResourceTree(data?.resources);
|
||||
}, [installedResourceTree, zipTree, data?.resources]);
|
||||
|
||||
if (isLoading || !data) {
|
||||
return <Skeleton active paragraph={{ rows: 8 }} style={{ padding: 16 }} />;
|
||||
}
|
||||
|
||||
const { name, icon, version, description, homepage, github } = data;
|
||||
|
||||
const skillDetailForViewer = {
|
||||
content: installedData?.skillDetail?.content || data.content,
|
||||
} as any;
|
||||
|
||||
return (
|
||||
<Flexbox style={{ height: '100%', overflow: 'hidden' }}>
|
||||
<div className={styles.meta}>
|
||||
<Flexbox horizontal align={'center'} gap={12}>
|
||||
<Avatar avatar={icon || name} shape={'square'} size={40} style={{ flex: 'none' }} />
|
||||
<Flexbox flex={1} gap={4} style={{ overflow: 'hidden' }}>
|
||||
<Flexbox horizontal align={'center'} gap={8} justify={'space-between'}>
|
||||
<Flexbox horizontal align={'center'} className={styles.description} gap={4}>
|
||||
<span className={styles.name}>{name}</span>
|
||||
{version && (
|
||||
<>
|
||||
<Icon icon={DotIcon} />
|
||||
<span>v{version}</span>
|
||||
</>
|
||||
)}
|
||||
<Icon icon={DotIcon} />
|
||||
{t('agentSkillDetail.updatedAt')}{' '}
|
||||
<PublishedTime date={data.updatedAt} template={'MMM DD, YYYY'} />
|
||||
</Flexbox>
|
||||
<Flexbox horizontal align={'center'} gap={2} style={{ flexShrink: 0 }}>
|
||||
{github?.url && (
|
||||
<a href={github.url} rel="noreferrer" target={'_blank'}>
|
||||
<ActionIcon
|
||||
fill={cssVar.colorTextDescription}
|
||||
icon={Github}
|
||||
title={t('agentSkillDetail.repository')}
|
||||
/>
|
||||
</a>
|
||||
)}
|
||||
{homepage && (
|
||||
<a href={homepage} rel="noreferrer" target={'_blank'}>
|
||||
<ActionIcon icon={ExternalLinkIcon} title={t('agentSkillDetail.sourceUrl')} />
|
||||
</a>
|
||||
)}
|
||||
</Flexbox>
|
||||
</Flexbox>
|
||||
{description && <p className={styles.description}>{description}</p>}
|
||||
</Flexbox>
|
||||
</Flexbox>
|
||||
</div>
|
||||
<Flexbox horizontal style={{ flex: 1, overflow: 'hidden' }}>
|
||||
<div className={styles.left}>
|
||||
<FileTree
|
||||
resourceTree={resourceTree}
|
||||
selectedFile={selectedFile}
|
||||
onSelectFile={setSelectedFile}
|
||||
/>
|
||||
</div>
|
||||
<div className={styles.divider} />
|
||||
<div className={styles.right} key={selectedFile}>
|
||||
<ContentViewer
|
||||
contentMap={contentMap}
|
||||
selectedFile={selectedFile}
|
||||
skillDetail={skillDetailForViewer}
|
||||
/>
|
||||
</div>
|
||||
</Flexbox>
|
||||
</Flexbox>
|
||||
);
|
||||
});
|
||||
|
||||
MarketSkillDetail.displayName = 'MarketSkillDetail';
|
||||
|
||||
export default MarketSkillDetail;
|
||||
@@ -0,0 +1,119 @@
|
||||
'use client';
|
||||
|
||||
import { Center, Icon, Text } from '@lobehub/ui';
|
||||
import { uniqBy } from 'es-toolkit/compat';
|
||||
import { ServerCrash } from 'lucide-react';
|
||||
import { memo, useCallback, useEffect, useRef, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { VirtuosoGrid } from 'react-virtuoso';
|
||||
|
||||
import { useClientDataSWR } from '@/libs/swr';
|
||||
import { discoverService } from '@/services/discover';
|
||||
import { globalHelpers } from '@/store/global/helpers';
|
||||
import { useToolStore } from '@/store/tool';
|
||||
import { type DiscoverSkillItem, SkillSorts } from '@/types/discover';
|
||||
|
||||
import MarketSkillItem from '../Community/MarketSkillItem';
|
||||
import Empty from '../Empty';
|
||||
import Loading from '../Loading';
|
||||
import { virtuosoGridStyles } from '../style';
|
||||
import VirtuosoLoading from '../VirtuosoLoading';
|
||||
import WantMoreSkills from '../WantMoreSkills';
|
||||
|
||||
interface MarketSkillListProps {
|
||||
keywords?: string;
|
||||
}
|
||||
|
||||
const MarketSkillList = memo<MarketSkillListProps>(({ keywords }) => {
|
||||
const { t } = useTranslation('setting');
|
||||
|
||||
// Ensure agent skills are fetched so install status is available
|
||||
const useFetchAgentSkills = useToolStore((s) => s.useFetchAgentSkills);
|
||||
useFetchAgentSkills(true);
|
||||
|
||||
// Market skills pagination state
|
||||
const [page, setPage] = useState(1);
|
||||
const [items, setItems] = useState<DiscoverSkillItem[]>([]);
|
||||
const [totalPages, setTotalPages] = useState<number>();
|
||||
|
||||
const locale = globalHelpers.getCurrentLanguage();
|
||||
const { data, isLoading, error } = useClientDataSWR(
|
||||
['skill-store-market-skills', locale, keywords || '', page].filter(Boolean).join('-'),
|
||||
() =>
|
||||
discoverService.getSkillList({
|
||||
page,
|
||||
pageSize: 20,
|
||||
q: keywords || undefined,
|
||||
sort: SkillSorts.InstallCount,
|
||||
}),
|
||||
{ revalidateOnFocus: false },
|
||||
);
|
||||
|
||||
// Accumulate items across pages
|
||||
useEffect(() => {
|
||||
if (!data) return;
|
||||
setTotalPages(data.totalPages);
|
||||
|
||||
if (page === 1) {
|
||||
setItems(data.items);
|
||||
} else {
|
||||
setItems((prev) => uniqBy([...prev, ...data.items], (i) => i.identifier));
|
||||
}
|
||||
}, [data, page]);
|
||||
|
||||
// Reset on keyword change
|
||||
const prevKeywordsRef = useRef(keywords);
|
||||
useEffect(() => {
|
||||
if (prevKeywordsRef.current !== keywords) {
|
||||
prevKeywordsRef.current = keywords;
|
||||
setPage(1);
|
||||
setItems([]);
|
||||
setTotalPages(undefined);
|
||||
}
|
||||
}, [keywords]);
|
||||
|
||||
const loadMore = useCallback(() => {
|
||||
if (totalPages === undefined || page < totalPages) {
|
||||
setPage((p) => p + 1);
|
||||
}
|
||||
}, [page, totalPages]);
|
||||
|
||||
if (isLoading && items.length === 0) return <Loading />;
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<Center gap={12} padding={40}>
|
||||
<Icon icon={ServerCrash} size={80} />
|
||||
<Text type={'secondary'}>{t('skillStore.networkError')}</Text>
|
||||
</Center>
|
||||
);
|
||||
}
|
||||
|
||||
if (items.length === 0) return <Empty search={Boolean(keywords?.trim())} />;
|
||||
|
||||
const hasReachedEnd = totalPages !== undefined && page >= totalPages;
|
||||
|
||||
const renderFooter = () => {
|
||||
if (isLoading) return <VirtuosoLoading />;
|
||||
if (hasReachedEnd) return <WantMoreSkills />;
|
||||
return <div style={{ height: 16 }} />;
|
||||
};
|
||||
|
||||
return (
|
||||
<VirtuosoGrid
|
||||
components={{ Footer: renderFooter }}
|
||||
data={items}
|
||||
endReached={loadMore}
|
||||
increaseViewportBy={typeof window !== 'undefined' ? window.innerHeight : 0}
|
||||
itemClassName={virtuosoGridStyles.item}
|
||||
itemContent={(_, item) => <MarketSkillItem {...item} />}
|
||||
listClassName={virtuosoGridStyles.list}
|
||||
overscan={24}
|
||||
style={{ height: '60vh', width: '100%' }}
|
||||
/>
|
||||
);
|
||||
});
|
||||
|
||||
MarketSkillList.displayName = 'MarketSkillList';
|
||||
|
||||
export default MarketSkillList;
|
||||
@@ -7,29 +7,34 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import Search from './Search';
|
||||
import AddSkillButton from './SkillList/AddSkillButton';
|
||||
import CommunityList from './SkillList/Community';
|
||||
import CustomList from './SkillList/Custom';
|
||||
import LobeHubList from './SkillList/LobeHub';
|
||||
import MarketSkillList from './SkillList/MarketSkills';
|
||||
import MCPList from './SkillList/MCP';
|
||||
|
||||
export enum SkillStoreTab {
|
||||
Community = 'community',
|
||||
Custom = 'custom',
|
||||
LobeHub = 'lobehub',
|
||||
MCP = 'mcp',
|
||||
Skills = 'skills',
|
||||
}
|
||||
|
||||
export const SkillStoreContent = () => {
|
||||
const { t } = useTranslation('setting');
|
||||
const [activeTab, setActiveTab] = useState<SkillStoreTab>(SkillStoreTab.LobeHub);
|
||||
const [lobehubKeywords, setLobehubKeywords] = useState('');
|
||||
const [skillKeywords, setSkillKeywords] = useState('');
|
||||
|
||||
const options: SegmentedOptions = [
|
||||
{ label: t('skillStore.tabs.lobehub'), value: SkillStoreTab.LobeHub },
|
||||
{ label: t('skillStore.tabs.community'), value: SkillStoreTab.Community },
|
||||
{ label: 'Skills', value: SkillStoreTab.Skills },
|
||||
{ label: t('skillStore.tabs.mcp'), value: SkillStoreTab.MCP },
|
||||
{ label: t('skillStore.tabs.custom'), value: SkillStoreTab.Custom },
|
||||
];
|
||||
|
||||
const isLobeHub = activeTab === SkillStoreTab.LobeHub;
|
||||
const isCommunity = activeTab === SkillStoreTab.Community;
|
||||
const isSkills = activeTab === SkillStoreTab.Skills;
|
||||
const isMCP = activeTab === SkillStoreTab.MCP;
|
||||
const isCustom = activeTab === SkillStoreTab.Custom;
|
||||
|
||||
return (
|
||||
@@ -46,14 +51,21 @@ export const SkillStoreContent = () => {
|
||||
/>
|
||||
<AddSkillButton />
|
||||
</Flexbox>
|
||||
<Search activeTab={activeTab} onLobeHubSearch={setLobehubKeywords} />
|
||||
<Search
|
||||
activeTab={activeTab}
|
||||
onLobeHubSearch={setLobehubKeywords}
|
||||
onSkillSearch={setSkillKeywords}
|
||||
/>
|
||||
</Flexbox>
|
||||
<Flexbox height={496}>
|
||||
<Flexbox flex={1} style={{ display: isLobeHub ? 'flex' : 'none', overflow: 'auto' }}>
|
||||
<LobeHubList keywords={lobehubKeywords} />
|
||||
</Flexbox>
|
||||
<Flexbox flex={1} style={{ display: isCommunity ? 'flex' : 'none', overflow: 'auto' }}>
|
||||
<CommunityList />
|
||||
<Flexbox flex={1} style={{ display: isSkills ? 'flex' : 'none', overflow: 'auto' }}>
|
||||
<MarketSkillList keywords={skillKeywords} />
|
||||
</Flexbox>
|
||||
<Flexbox flex={1} style={{ display: isMCP ? 'flex' : 'none', overflow: 'auto' }}>
|
||||
<MCPList />
|
||||
</Flexbox>
|
||||
<Flexbox flex={1} style={{ display: isCustom ? 'flex' : 'none', overflow: 'auto' }}>
|
||||
<CustomList />
|
||||
|
||||
@@ -71,7 +71,6 @@ describe('createContextInner', () => {
|
||||
const context = await createContextInner();
|
||||
|
||||
expect(context).toMatchObject({
|
||||
authorizationHeader: undefined,
|
||||
marketAccessToken: undefined,
|
||||
oidcAuth: undefined,
|
||||
userAgent: undefined,
|
||||
@@ -86,14 +85,6 @@ describe('createContextInner', () => {
|
||||
expect(context.userId).toBe('user-123');
|
||||
});
|
||||
|
||||
it('should create context with authorization header', async () => {
|
||||
const context = await createContextInner({
|
||||
authorizationHeader: 'Bearer token-abc',
|
||||
});
|
||||
|
||||
expect(context.authorizationHeader).toBe('Bearer token-abc');
|
||||
});
|
||||
|
||||
it('should create context with user agent', async () => {
|
||||
const context = await createContextInner({
|
||||
userAgent: 'Mozilla/5.0',
|
||||
@@ -123,7 +114,6 @@ describe('createContextInner', () => {
|
||||
|
||||
it('should create context with all parameters combined', async () => {
|
||||
const params = {
|
||||
authorizationHeader: 'Bearer token',
|
||||
userId: 'user-123',
|
||||
userAgent: 'Test Agent',
|
||||
marketAccessToken: 'mp-token',
|
||||
@@ -136,7 +126,6 @@ describe('createContextInner', () => {
|
||||
const context = await createContextInner(params);
|
||||
|
||||
expect(context).toMatchObject({
|
||||
authorizationHeader: 'Bearer token',
|
||||
userId: 'user-123',
|
||||
userAgent: 'Test Agent',
|
||||
marketAccessToken: 'mp-token',
|
||||
|
||||
@@ -7,7 +7,7 @@ import { type NextRequest } from 'next/server';
|
||||
import { auth } from '@/auth';
|
||||
import { getServerDB } from '@/database/core/db-adaptor';
|
||||
import { ApiKeyModel } from '@/database/models/apiKey';
|
||||
import { authEnv, LOBE_CHAT_AUTH_HEADER, LOBE_CHAT_OIDC_AUTH_HEADER } from '@/envs/auth';
|
||||
import { authEnv, LOBE_CHAT_OIDC_AUTH_HEADER } from '@/envs/auth';
|
||||
import { extractTraceContext } from '@/libs/observability/traceparent';
|
||||
import { validateOIDCJWT } from '@/libs/oidc-provider/jwt';
|
||||
import { isApiKeyExpired, validateApiKeyFormat } from '@/utils/apiKey';
|
||||
@@ -64,7 +64,6 @@ export interface OIDCAuth {
|
||||
}
|
||||
|
||||
export interface AuthContext {
|
||||
authorizationHeader?: string | null;
|
||||
clientIp?: string | null;
|
||||
jwtPayload?: ClientSecretPayload | null;
|
||||
marketAccessToken?: string;
|
||||
@@ -81,7 +80,6 @@ export interface AuthContext {
|
||||
* This is useful for testing when we don't want to mock Next.js' request/response
|
||||
*/
|
||||
export const createContextInner = async (params?: {
|
||||
authorizationHeader?: string | null;
|
||||
clientIp?: string | null;
|
||||
marketAccessToken?: string;
|
||||
oidcAuth?: OIDCAuth | null;
|
||||
@@ -93,7 +91,6 @@ export const createContextInner = async (params?: {
|
||||
const responseHeaders = new Headers();
|
||||
|
||||
return {
|
||||
authorizationHeader: params?.authorizationHeader,
|
||||
clientIp: params?.clientIp,
|
||||
marketAccessToken: params?.marketAccessToken,
|
||||
oidcAuth: params?.oidcAuth,
|
||||
@@ -118,7 +115,6 @@ export const createLambdaContext = async (request: NextRequest): Promise<LambdaC
|
||||
|
||||
if (process.env.NODE_ENV === 'development' && (isDebugApi || isMockUser)) {
|
||||
return createContextInner({
|
||||
authorizationHeader: request.headers.get(LOBE_CHAT_AUTH_HEADER),
|
||||
userId: process.env.MOCK_DEV_USER_ID,
|
||||
});
|
||||
}
|
||||
@@ -126,7 +122,6 @@ export const createLambdaContext = async (request: NextRequest): Promise<LambdaC
|
||||
log('createLambdaContext called for request');
|
||||
// for API-response caching see https://trpc.io/docs/v11/caching
|
||||
|
||||
const authorization = request.headers.get(LOBE_CHAT_AUTH_HEADER);
|
||||
const userAgent = request.headers.get('user-agent') || undefined;
|
||||
const clientIp = extractClientIp(request);
|
||||
|
||||
@@ -139,12 +134,10 @@ export const createLambdaContext = async (request: NextRequest): Promise<LambdaC
|
||||
|
||||
log('marketAccessToken from cookie:', marketAccessToken ? '[HIDDEN]' : 'undefined');
|
||||
const commonContext = {
|
||||
authorizationHeader: authorization,
|
||||
clientIp,
|
||||
marketAccessToken,
|
||||
userAgent,
|
||||
};
|
||||
log('LobeChat Authorization header: %s', authorization ? 'exists' : 'not found');
|
||||
|
||||
const apiKeyToken = request.headers.get(LOBE_CHAT_API_KEY_HEADER)?.trim();
|
||||
log('X-API-Key header: %s', apiKeyToken ? 'exists' : 'not found');
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
export * from './keyVaults';
|
||||
export * from './marketSDK';
|
||||
export * from './marketUserInfo';
|
||||
export * from './serverDatabase';
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
import { getXorPayload } from '@lobechat/utils/server';
|
||||
import { TRPCError } from '@trpc/server';
|
||||
|
||||
import { trpc } from '../init';
|
||||
|
||||
export const keyVaults = trpc.middleware(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
if (!ctx.authorizationHeader) throw new TRPCError({ code: 'UNAUTHORIZED' });
|
||||
|
||||
try {
|
||||
const jwtPayload = getXorPayload(ctx.authorizationHeader);
|
||||
|
||||
return opts.next({ ctx: { jwtPayload } });
|
||||
} catch (e) {
|
||||
throw new TRPCError({ code: 'UNAUTHORIZED', message: (e as Error).message });
|
||||
}
|
||||
});
|
||||
@@ -796,6 +796,8 @@ export default {
|
||||
'skillStore.tabs.community': 'Community',
|
||||
'skillStore.tabs.custom': 'Custom',
|
||||
'skillStore.tabs.lobehub': 'LobeHub',
|
||||
'skillStore.tabs.mcp': 'MCP',
|
||||
'skillStore.tabs.skills': 'Skills',
|
||||
'skillStore.title': 'Skill Store',
|
||||
'skillStore.wantMore.action': 'Submit a request →',
|
||||
'skillStore.wantMore.feedback.message': `## Skill Name
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
/**
|
||||
* @vitest-environment happy-dom
|
||||
*/
|
||||
import { render } from '@testing-library/react';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { initialState as initialChatState } from '@/store/chat/initialState';
|
||||
import { PortalViewType } from '@/store/chat/slices/portal/initialState';
|
||||
import { useChatStore } from '@/store/chat/store';
|
||||
|
||||
import AgentIdSync from './AgentIdSync';
|
||||
|
||||
const useParamsMock = vi.hoisted(() => vi.fn());
|
||||
const useSearchParamsMock = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('react-router-dom', async () => {
|
||||
// eslint-disable-next-line @typescript-eslint/consistent-type-imports
|
||||
const actual = (await vi.importActual('react-router-dom')) as typeof import('react-router-dom');
|
||||
|
||||
return {
|
||||
...actual,
|
||||
useParams: useParamsMock,
|
||||
useSearchParams: useSearchParamsMock,
|
||||
};
|
||||
});
|
||||
|
||||
describe('AgentIdSync', () => {
|
||||
beforeEach(() => {
|
||||
useParamsMock.mockReset();
|
||||
useSearchParamsMock.mockReset();
|
||||
|
||||
useChatStore.setState(
|
||||
{
|
||||
...initialChatState,
|
||||
activeAgentId: 'agent-1',
|
||||
activeTopicId: 'topic-1',
|
||||
portalStack: [{ type: PortalViewType.Home }],
|
||||
showPortal: true,
|
||||
},
|
||||
false,
|
||||
);
|
||||
});
|
||||
|
||||
it('clears portal state when switching to another agent without a topic in the URL', () => {
|
||||
useParamsMock.mockReturnValue({ aid: 'agent-1' });
|
||||
useSearchParamsMock.mockReturnValue([new URLSearchParams(''), vi.fn()]);
|
||||
|
||||
const { rerender } = render(<AgentIdSync />);
|
||||
|
||||
expect(useChatStore.getState().showPortal).toBe(true);
|
||||
|
||||
useParamsMock.mockReturnValue({ aid: 'agent-2' });
|
||||
rerender(<AgentIdSync />);
|
||||
|
||||
expect(useChatStore.getState().activeTopicId).toBeNull();
|
||||
expect(useChatStore.getState().portalStack).toEqual([]);
|
||||
expect(useChatStore.getState().showPortal).toBe(false);
|
||||
});
|
||||
|
||||
it('still clears portal state when the destination URL already has a topic', () => {
|
||||
useParamsMock.mockReturnValue({ aid: 'agent-1' });
|
||||
useSearchParamsMock.mockReturnValue([new URLSearchParams('topic=topic-2'), vi.fn()]);
|
||||
|
||||
const { rerender } = render(<AgentIdSync />);
|
||||
|
||||
useParamsMock.mockReturnValue({ aid: 'agent-2' });
|
||||
rerender(<AgentIdSync />);
|
||||
|
||||
expect(useChatStore.getState().portalStack).toEqual([]);
|
||||
expect(useChatStore.getState().showPortal).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -23,6 +23,8 @@ const AgentIdSync = () => {
|
||||
useEffect(() => {
|
||||
// Only reset topic when switching between agents (not on initial mount)
|
||||
if (prevAgentId !== undefined && prevAgentId !== params.aid) {
|
||||
useChatStore.getState().clearPortalStack();
|
||||
|
||||
// Preserve topic if the URL already carries one (e.g. tab navigation)
|
||||
const topicFromUrl = searchParamsRef.current.get('topic');
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ const StarterList = memo(() => {
|
||||
}
|
||||
|
||||
if (key === 'image') {
|
||||
navigate?.('/image?model=gemini-3.1-flash-image-preview:image');
|
||||
navigate?.('/image');
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ const PurgeButton = memo<Props>(({ iconOnly }) => {
|
||||
const handleClick = () => {
|
||||
modal.confirm({
|
||||
cancelText: translate('cancel', { ns: 'common' }),
|
||||
content: translate('purge.confirm'),
|
||||
content: translate('purge.confirm', { ns: 'memory' }),
|
||||
okButtonProps: { danger: true, loading },
|
||||
okText: translate('confirm', { ns: 'common' }),
|
||||
onOk: async () => {
|
||||
@@ -47,22 +47,22 @@ const PurgeButton = memo<Props>(({ iconOnly }) => {
|
||||
}
|
||||
|
||||
setSearchParams(nextSearchParams, { replace: true });
|
||||
message.success(translate('purge.success'));
|
||||
message.success(translate('purge.success', { ns: 'memory' }));
|
||||
} catch {
|
||||
message.error(translate('purge.error'));
|
||||
message.error(translate('purge.error', { ns: 'memory' }));
|
||||
throw new Error('Failed to purge memories');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
},
|
||||
title: translate('purge.title'),
|
||||
title: translate('purge.title', { ns: 'memory' }),
|
||||
type: 'warning',
|
||||
});
|
||||
};
|
||||
|
||||
if (iconOnly) {
|
||||
return (
|
||||
<Tooltip title={translate('purge.action')}>
|
||||
<Tooltip title={translate('purge.action', { ns: 'memory' })}>
|
||||
<ActionIcon
|
||||
danger
|
||||
icon={Trash2Icon}
|
||||
@@ -85,7 +85,7 @@ const PurgeButton = memo<Props>(({ iconOnly }) => {
|
||||
type={'primary'}
|
||||
onClick={handleClick}
|
||||
>
|
||||
{translate('purge.action')}
|
||||
{translate('purge.action', { ns: 'memory' })}
|
||||
</Button>
|
||||
);
|
||||
});
|
||||
|
||||
@@ -0,0 +1,174 @@
|
||||
import debug from 'debug';
|
||||
import urlJoin from 'url-join';
|
||||
|
||||
import type { StreamChunkData, StreamEvent } from './StreamEventManager';
|
||||
import type { IStreamEventManager } from './types';
|
||||
|
||||
const log = debug('lobe-server:agent-runtime:gateway-notifier');
|
||||
|
||||
const POST_TIMEOUT = 5000; // 5s per request
|
||||
const MAX_INFLIGHT = 20; // bounded concurrency
|
||||
|
||||
/**
|
||||
* Decorator that wraps an IStreamEventManager and additionally
|
||||
* pushes events to the Agent Gateway via HTTP (fire-and-forget).
|
||||
*
|
||||
* Redis SSE remains the primary event storage / subscription mechanism.
|
||||
* The Gateway is an additional push channel for WebSocket delivery.
|
||||
*/
|
||||
export class GatewayStreamNotifier implements IStreamEventManager {
|
||||
private inflight = 0;
|
||||
|
||||
constructor(
|
||||
private inner: IStreamEventManager,
|
||||
private gatewayUrl: string,
|
||||
private serviceToken: string,
|
||||
) {
|
||||
log('Gateway notifier initialized: %s', gatewayUrl);
|
||||
}
|
||||
|
||||
// ─── Publish methods: delegate to inner + notify gateway ───
|
||||
|
||||
async publishStreamEvent(
|
||||
operationId: string,
|
||||
event: Omit<StreamEvent, 'operationId' | 'timestamp'>,
|
||||
): Promise<string> {
|
||||
const result = await this.inner.publishStreamEvent(operationId, event);
|
||||
this.pushEvent(operationId, { ...event, operationId, timestamp: Date.now() });
|
||||
return result;
|
||||
}
|
||||
|
||||
async publishStreamChunk(
|
||||
operationId: string,
|
||||
stepIndex: number,
|
||||
chunkData: StreamChunkData,
|
||||
): Promise<string> {
|
||||
const result = await this.inner.publishStreamChunk(operationId, stepIndex, chunkData);
|
||||
this.pushEvent(operationId, {
|
||||
data: chunkData,
|
||||
operationId,
|
||||
stepIndex,
|
||||
timestamp: Date.now(),
|
||||
type: 'stream_chunk',
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
async publishAgentRuntimeInit(operationId: string, initialState: any): Promise<string> {
|
||||
const result = await this.inner.publishAgentRuntimeInit(operationId, initialState);
|
||||
|
||||
this.httpPost('/api/operations/init', {
|
||||
operationId,
|
||||
userId: initialState?.userId || 'unknown',
|
||||
});
|
||||
|
||||
this.pushEvent(operationId, {
|
||||
data: initialState,
|
||||
operationId,
|
||||
stepIndex: 0,
|
||||
timestamp: Date.now(),
|
||||
type: 'agent_runtime_init',
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async publishAgentRuntimeEnd(
|
||||
operationId: string,
|
||||
stepIndex: number,
|
||||
finalState: any,
|
||||
reason?: string,
|
||||
reasonDetail?: string,
|
||||
): Promise<string> {
|
||||
const result = await this.inner.publishAgentRuntimeEnd(
|
||||
operationId,
|
||||
stepIndex,
|
||||
finalState,
|
||||
reason,
|
||||
reasonDetail,
|
||||
);
|
||||
|
||||
this.pushEvent(operationId, {
|
||||
data: { finalState, reason, reasonDetail },
|
||||
operationId,
|
||||
stepIndex,
|
||||
timestamp: Date.now(),
|
||||
type: 'agent_runtime_end',
|
||||
});
|
||||
|
||||
const status =
|
||||
reason === 'error' ? 'error' : reason === 'interrupted' ? 'interrupted' : 'completed';
|
||||
this.httpPost('/api/operations/update-status', {
|
||||
operationId,
|
||||
status,
|
||||
summary: reasonDetail,
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ─── Read / subscribe methods: delegate directly to inner ───
|
||||
|
||||
async subscribeStreamEvents(
|
||||
operationId: string,
|
||||
lastEventId: string,
|
||||
onEvents: (events: StreamEvent[]) => void,
|
||||
signal?: AbortSignal,
|
||||
): Promise<void> {
|
||||
return this.inner.subscribeStreamEvents(operationId, lastEventId, onEvents, signal);
|
||||
}
|
||||
|
||||
async getStreamHistory(operationId: string, count?: number): Promise<StreamEvent[]> {
|
||||
return this.inner.getStreamHistory(operationId, count);
|
||||
}
|
||||
|
||||
async cleanupOperation(operationId: string): Promise<void> {
|
||||
return this.inner.cleanupOperation(operationId);
|
||||
}
|
||||
|
||||
async getActiveOperationsCount(): Promise<number> {
|
||||
return this.inner.getActiveOperationsCount();
|
||||
}
|
||||
|
||||
async disconnect(): Promise<void> {
|
||||
return this.inner.disconnect();
|
||||
}
|
||||
|
||||
// ─── Gateway HTTP helpers ───
|
||||
|
||||
private pushEvent(operationId: string, event: Record<string, unknown>) {
|
||||
this.httpPost('/api/operations/push-event', { event, operationId }).catch(() => {});
|
||||
}
|
||||
|
||||
private async httpPost(path: string, body: Record<string, unknown>): Promise<void> {
|
||||
if (this.inflight >= MAX_INFLIGHT) {
|
||||
log('Gateway %s dropped: max inflight (%d) reached', path, MAX_INFLIGHT);
|
||||
return;
|
||||
}
|
||||
|
||||
this.inflight++;
|
||||
const controller = new AbortController();
|
||||
const timer = setTimeout(() => controller.abort(), POST_TIMEOUT);
|
||||
|
||||
try {
|
||||
const res = await fetch(urlJoin(this.gatewayUrl, path), {
|
||||
body: JSON.stringify(body),
|
||||
headers: {
|
||||
'Authorization': `Bearer ${this.serviceToken}`,
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
method: 'POST',
|
||||
signal: controller.signal,
|
||||
});
|
||||
|
||||
if (!res.ok) {
|
||||
log('Gateway %s returned %d: %s', path, res.status, await res.text());
|
||||
}
|
||||
} catch (error) {
|
||||
log('Gateway %s failed: %O', path, error);
|
||||
} finally {
|
||||
clearTimeout(timer);
|
||||
this.inflight--;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,10 +146,7 @@ const executeToolWithRetry = async (
|
||||
throw new Error('Tool execution retry loop exited unexpectedly');
|
||||
};
|
||||
|
||||
const buildToolDiscoveryConfig = (
|
||||
operationToolSet: OperationToolSet,
|
||||
enabledToolIds: string[],
|
||||
) => {
|
||||
const buildToolDiscoveryConfig = (operationToolSet: OperationToolSet, enabledToolIds: string[]) => {
|
||||
const enabledToolSet = new Set(enabledToolIds);
|
||||
|
||||
if (!enabledToolSet.has(LobeActivatorIdentifier)) return undefined;
|
||||
@@ -164,7 +161,7 @@ const buildToolDiscoveryConfig = (
|
||||
|
||||
if (availableTools.length === 0) return undefined;
|
||||
|
||||
return { availableTools }
|
||||
return { availableTools };
|
||||
};
|
||||
|
||||
const formatErrorEventData = (error: unknown, phase: string) => {
|
||||
@@ -385,6 +382,7 @@ export const createRuntimeExecutors = (
|
||||
if (docs.length > 0) {
|
||||
agentDocuments = docs.map((doc) => ({
|
||||
content: doc.content,
|
||||
description: doc.description ?? undefined,
|
||||
filename: doc.filename,
|
||||
id: doc.id,
|
||||
loadPosition: normalizeDocumentPosition(
|
||||
@@ -392,6 +390,7 @@ export const createRuntimeExecutors = (
|
||||
),
|
||||
loadRules: doc.loadRules,
|
||||
policyId: doc.templateId,
|
||||
policyLoad: doc.policyLoad as 'always' | 'progressive',
|
||||
policyLoadFormat: doc.policy?.context?.policyLoadFormat || doc.policyLoadFormat,
|
||||
title: doc.title,
|
||||
}));
|
||||
|
||||
@@ -0,0 +1,320 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { GatewayStreamNotifier } from '../GatewayStreamNotifier';
|
||||
import type { StreamChunkData } from '../StreamEventManager';
|
||||
import type { IStreamEventManager } from '../types';
|
||||
|
||||
// Mock global fetch
|
||||
const mockFetch = vi.fn().mockResolvedValue({ ok: true, text: () => Promise.resolve('') });
|
||||
vi.stubGlobal('fetch', mockFetch);
|
||||
|
||||
function createMockInner(): IStreamEventManager & { calls: Record<string, any[][]> } {
|
||||
const calls: Record<string, any[][]> = {};
|
||||
|
||||
const track = (name: string) => {
|
||||
calls[name] = [];
|
||||
return (...args: any[]) => {
|
||||
calls[name].push(args);
|
||||
return Promise.resolve(`${name}-result`);
|
||||
};
|
||||
};
|
||||
|
||||
return {
|
||||
calls,
|
||||
cleanupOperation: track('cleanupOperation') as any,
|
||||
disconnect: track('disconnect') as any,
|
||||
getActiveOperationsCount: track('getActiveOperationsCount') as any,
|
||||
getStreamHistory: track('getStreamHistory') as any,
|
||||
publishAgentRuntimeEnd: track('publishAgentRuntimeEnd') as any,
|
||||
publishAgentRuntimeInit: track('publishAgentRuntimeInit') as any,
|
||||
publishStreamChunk: track('publishStreamChunk') as any,
|
||||
publishStreamEvent: track('publishStreamEvent') as any,
|
||||
subscribeStreamEvents: track('subscribeStreamEvents') as any,
|
||||
};
|
||||
}
|
||||
|
||||
describe('GatewayStreamNotifier', () => {
|
||||
let inner: ReturnType<typeof createMockInner>;
|
||||
let notifier: GatewayStreamNotifier;
|
||||
const gatewayUrl = 'https://gateway.test.com';
|
||||
const serviceToken = 'test-token';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
inner = createMockInner();
|
||||
notifier = new GatewayStreamNotifier(inner, gatewayUrl, serviceToken);
|
||||
});
|
||||
|
||||
// ─── Publish methods: must always call inner first ───
|
||||
|
||||
describe('publishStreamEvent', () => {
|
||||
it('delegates to inner and returns its result', async () => {
|
||||
const event = { data: { foo: 'bar' }, stepIndex: 0, type: 'step_start' as const };
|
||||
|
||||
const result = await notifier.publishStreamEvent('op-1', event);
|
||||
|
||||
expect(result).toBe('publishStreamEvent-result');
|
||||
expect(inner.calls.publishStreamEvent).toHaveLength(1);
|
||||
expect(inner.calls.publishStreamEvent[0]).toEqual(['op-1', event]);
|
||||
});
|
||||
|
||||
it('pushes event to gateway via HTTP', async () => {
|
||||
await notifier.publishStreamEvent('op-1', {
|
||||
data: {},
|
||||
stepIndex: 0,
|
||||
type: 'step_start' as const,
|
||||
});
|
||||
|
||||
// Wait for fire-and-forget
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
`${gatewayUrl}/api/operations/push-event`,
|
||||
expect.objectContaining({
|
||||
headers: expect.objectContaining({
|
||||
Authorization: `Bearer ${serviceToken}`,
|
||||
}),
|
||||
method: 'POST',
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('still returns inner result even if gateway fails', async () => {
|
||||
mockFetch.mockRejectedValueOnce(new Error('network error'));
|
||||
|
||||
const result = await notifier.publishStreamEvent('op-1', {
|
||||
data: {},
|
||||
stepIndex: 0,
|
||||
type: 'step_start' as const,
|
||||
});
|
||||
|
||||
expect(result).toBe('publishStreamEvent-result');
|
||||
expect(inner.calls.publishStreamEvent).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('publishStreamChunk', () => {
|
||||
it('delegates to inner and returns its result', async () => {
|
||||
const chunkData: StreamChunkData = { chunkType: 'text', content: 'hello' };
|
||||
|
||||
const result = await notifier.publishStreamChunk('op-1', 0, chunkData);
|
||||
|
||||
expect(result).toBe('publishStreamChunk-result');
|
||||
expect(inner.calls.publishStreamChunk).toHaveLength(1);
|
||||
expect(inner.calls.publishStreamChunk[0]).toEqual(['op-1', 0, chunkData]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('publishAgentRuntimeInit', () => {
|
||||
it('delegates to inner and returns its result', async () => {
|
||||
const initialState = { userId: 'user-1' };
|
||||
|
||||
const result = await notifier.publishAgentRuntimeInit('op-1', initialState);
|
||||
|
||||
expect(result).toBe('publishAgentRuntimeInit-result');
|
||||
expect(inner.calls.publishAgentRuntimeInit).toHaveLength(1);
|
||||
expect(inner.calls.publishAgentRuntimeInit[0]).toEqual(['op-1', initialState]);
|
||||
});
|
||||
|
||||
it('calls gateway init and push-event endpoints', async () => {
|
||||
await notifier.publishAgentRuntimeInit('op-1', { userId: 'user-1' });
|
||||
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
|
||||
const urls = mockFetch.mock.calls.map((c: any[]) => c[0]);
|
||||
expect(urls).toContain(`${gatewayUrl}/api/operations/init`);
|
||||
expect(urls).toContain(`${gatewayUrl}/api/operations/push-event`);
|
||||
});
|
||||
});
|
||||
|
||||
describe('publishAgentRuntimeEnd', () => {
|
||||
it('delegates to inner and returns its result', async () => {
|
||||
const finalState = { status: 'done' };
|
||||
|
||||
const result = await notifier.publishAgentRuntimeEnd('op-1', 2, finalState, 'completed');
|
||||
|
||||
expect(result).toBe('publishAgentRuntimeEnd-result');
|
||||
expect(inner.calls.publishAgentRuntimeEnd).toHaveLength(1);
|
||||
expect(inner.calls.publishAgentRuntimeEnd[0]).toEqual([
|
||||
'op-1',
|
||||
2,
|
||||
finalState,
|
||||
'completed',
|
||||
undefined,
|
||||
]);
|
||||
});
|
||||
|
||||
it('calls gateway push-event and update-status endpoints', async () => {
|
||||
await notifier.publishAgentRuntimeEnd('op-1', 2, {}, 'completed', 'All done');
|
||||
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
|
||||
const urls = mockFetch.mock.calls.map((c: any[]) => c[0]);
|
||||
expect(urls).toContain(`${gatewayUrl}/api/operations/push-event`);
|
||||
expect(urls).toContain(`${gatewayUrl}/api/operations/update-status`);
|
||||
});
|
||||
|
||||
it('maps error reason to error status', async () => {
|
||||
await notifier.publishAgentRuntimeEnd('op-1', 0, {}, 'error', 'Something broke');
|
||||
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
|
||||
const statusCall = mockFetch.mock.calls.find(
|
||||
(c: any[]) => c[0] === `${gatewayUrl}/api/operations/update-status`,
|
||||
);
|
||||
expect(statusCall).toBeDefined();
|
||||
const body = JSON.parse(statusCall![1].body);
|
||||
expect(body.status).toBe('error');
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Read/subscribe methods: must delegate directly to inner ───
|
||||
|
||||
describe('subscribeStreamEvents', () => {
|
||||
it('delegates directly to inner', async () => {
|
||||
const onEvents = vi.fn();
|
||||
const signal = new AbortController().signal;
|
||||
|
||||
await notifier.subscribeStreamEvents('op-1', '0', onEvents, signal);
|
||||
|
||||
expect(inner.calls.subscribeStreamEvents).toHaveLength(1);
|
||||
expect(inner.calls.subscribeStreamEvents[0]).toEqual(['op-1', '0', onEvents, signal]);
|
||||
});
|
||||
|
||||
it('does not call gateway', async () => {
|
||||
await notifier.subscribeStreamEvents('op-1', '0', vi.fn());
|
||||
|
||||
expect(mockFetch).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getStreamHistory', () => {
|
||||
it('delegates directly to inner', async () => {
|
||||
await notifier.getStreamHistory('op-1', 50);
|
||||
|
||||
expect(inner.calls.getStreamHistory).toHaveLength(1);
|
||||
expect(inner.calls.getStreamHistory[0]).toEqual(['op-1', 50]);
|
||||
});
|
||||
});
|
||||
|
||||
describe('cleanupOperation', () => {
|
||||
it('delegates directly to inner', async () => {
|
||||
await notifier.cleanupOperation('op-1');
|
||||
|
||||
expect(inner.calls.cleanupOperation).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('getActiveOperationsCount', () => {
|
||||
it('delegates directly to inner', async () => {
|
||||
await notifier.getActiveOperationsCount();
|
||||
|
||||
expect(inner.calls.getActiveOperationsCount).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
describe('disconnect', () => {
|
||||
it('delegates directly to inner', async () => {
|
||||
await notifier.disconnect();
|
||||
|
||||
expect(inner.calls.disconnect).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Gateway failure resilience ───
|
||||
|
||||
describe('gateway failure does not affect inner', () => {
|
||||
it('publishStreamEvent succeeds when gateway is unreachable', async () => {
|
||||
mockFetch.mockRejectedValue(new Error('connection refused'));
|
||||
|
||||
const result = await notifier.publishStreamEvent('op-1', {
|
||||
data: {},
|
||||
stepIndex: 0,
|
||||
type: 'step_start' as const,
|
||||
});
|
||||
|
||||
expect(result).toBe('publishStreamEvent-result');
|
||||
expect(inner.calls.publishStreamEvent).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('publishAgentRuntimeInit succeeds when gateway returns 500', async () => {
|
||||
mockFetch.mockResolvedValue({ ok: false, status: 500, text: () => 'Internal Error' });
|
||||
|
||||
const result = await notifier.publishAgentRuntimeInit('op-1', { userId: 'u1' });
|
||||
|
||||
expect(result).toBe('publishAgentRuntimeInit-result');
|
||||
expect(inner.calls.publishAgentRuntimeInit).toHaveLength(1);
|
||||
});
|
||||
|
||||
it('publishAgentRuntimeEnd succeeds when gateway times out', async () => {
|
||||
mockFetch.mockImplementation(
|
||||
() => new Promise((_, reject) => setTimeout(() => reject(new Error('timeout')), 10)),
|
||||
);
|
||||
|
||||
const result = await notifier.publishAgentRuntimeEnd('op-1', 0, {}, 'completed');
|
||||
|
||||
expect(result).toBe('publishAgentRuntimeEnd-result');
|
||||
expect(inner.calls.publishAgentRuntimeEnd).toHaveLength(1);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── Timeout and concurrency ───
|
||||
|
||||
describe('timeout and concurrency control', () => {
|
||||
it('passes AbortSignal to fetch', async () => {
|
||||
await notifier.publishStreamEvent('op-1', {
|
||||
data: {},
|
||||
stepIndex: 0,
|
||||
type: 'step_start' as const,
|
||||
});
|
||||
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
|
||||
const fetchCall = mockFetch.mock.calls[0];
|
||||
expect(fetchCall[1].signal).toBeInstanceOf(AbortSignal);
|
||||
});
|
||||
|
||||
it('drops requests when max inflight is reached', async () => {
|
||||
// Hold all fetches pending
|
||||
const resolvers: Array<() => void> = [];
|
||||
mockFetch.mockImplementation(
|
||||
() =>
|
||||
new Promise<{ ok: boolean }>((resolve) => {
|
||||
resolvers.push(() => resolve({ ok: true }));
|
||||
}),
|
||||
);
|
||||
|
||||
// Fire 25 events (max inflight is 20)
|
||||
for (let i = 0; i < 25; i++) {
|
||||
notifier.publishStreamEvent(`op-${i}`, {
|
||||
data: {},
|
||||
stepIndex: 0,
|
||||
type: 'step_start' as const,
|
||||
});
|
||||
}
|
||||
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
|
||||
// Only 20 should have actually called fetch
|
||||
expect(mockFetch).toHaveBeenCalledTimes(20);
|
||||
|
||||
// Release all pending
|
||||
for (const r of resolvers) r();
|
||||
});
|
||||
|
||||
it('uses url-join for URL construction', async () => {
|
||||
await notifier.publishStreamEvent('op-1', {
|
||||
data: {},
|
||||
stepIndex: 0,
|
||||
type: 'step_start' as const,
|
||||
});
|
||||
|
||||
await new Promise((r) => setTimeout(r, 50));
|
||||
|
||||
const url = mockFetch.mock.calls[0][0];
|
||||
expect(url).toBe(`${gatewayUrl}/api/operations/push-event`);
|
||||
// No double slashes
|
||||
expect(url).not.toContain('//api');
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -4,6 +4,7 @@ import { createAgentStateManager, createStreamEventManager, isRedisAvailable } f
|
||||
|
||||
const {
|
||||
MockAgentStateManager,
|
||||
MockGatewayStreamNotifier,
|
||||
MockStreamEventManager,
|
||||
mockAppEnv,
|
||||
mockGetAgentRuntimeRedisClient,
|
||||
@@ -11,8 +12,16 @@ const {
|
||||
mockInMemoryStreamEventManager,
|
||||
} = vi.hoisted(() => ({
|
||||
MockAgentStateManager: vi.fn(() => ({ kind: 'redis-state-manager' })),
|
||||
MockGatewayStreamNotifier: vi.fn((inner: any, url: string, token: string) => ({
|
||||
inner,
|
||||
kind: 'gateway-stream-notifier',
|
||||
token,
|
||||
url,
|
||||
})),
|
||||
MockStreamEventManager: vi.fn(() => ({ kind: 'redis-stream-event-manager' })),
|
||||
mockAppEnv: {
|
||||
AGENT_GATEWAY_SERVICE_TOKEN: undefined as string | undefined,
|
||||
AGENT_GATEWAY_URL: 'https://agent-gateway.lobehub.com',
|
||||
enableQueueAgentRuntime: false,
|
||||
},
|
||||
mockGetAgentRuntimeRedisClient: vi.fn(),
|
||||
@@ -44,6 +53,10 @@ vi.mock('../StreamEventManager', () => ({
|
||||
StreamEventManager: MockStreamEventManager,
|
||||
}));
|
||||
|
||||
vi.mock('../GatewayStreamNotifier', () => ({
|
||||
GatewayStreamNotifier: MockGatewayStreamNotifier,
|
||||
}));
|
||||
|
||||
describe('AgentRuntime factory', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
@@ -89,6 +102,11 @@ describe('AgentRuntime factory', () => {
|
||||
});
|
||||
|
||||
describe('createStreamEventManager', () => {
|
||||
beforeEach(() => {
|
||||
mockAppEnv.AGENT_GATEWAY_SERVICE_TOKEN = undefined;
|
||||
mockAppEnv.AGENT_GATEWAY_URL = 'https://agent-gateway.lobehub.com';
|
||||
});
|
||||
|
||||
it('prefers Redis-backed streams when Redis is available in local mode', () => {
|
||||
mockGetAgentRuntimeRedisClient.mockReturnValue({ ping: vi.fn() });
|
||||
|
||||
@@ -108,5 +126,46 @@ describe('AgentRuntime factory', () => {
|
||||
'Redis is required when AGENT_RUNTIME_MODE=queue. Please configure `REDIS_URL`.',
|
||||
);
|
||||
});
|
||||
|
||||
it('wraps with GatewayStreamNotifier when AGENT_GATEWAY_SERVICE_TOKEN is set', () => {
|
||||
mockAppEnv.AGENT_GATEWAY_SERVICE_TOKEN = 'my-token';
|
||||
mockGetAgentRuntimeRedisClient.mockReturnValue({ ping: vi.fn() });
|
||||
|
||||
const result = createStreamEventManager() as any;
|
||||
|
||||
expect(result.kind).toBe('gateway-stream-notifier');
|
||||
expect(result.inner).toEqual({ kind: 'redis-stream-event-manager' });
|
||||
expect(result.token).toBe('my-token');
|
||||
expect(result.url).toBe('https://agent-gateway.lobehub.com');
|
||||
});
|
||||
|
||||
it('uses custom AGENT_GATEWAY_URL when set', () => {
|
||||
mockAppEnv.AGENT_GATEWAY_SERVICE_TOKEN = 'my-token';
|
||||
mockAppEnv.AGENT_GATEWAY_URL = 'https://custom-gateway.example.com';
|
||||
mockGetAgentRuntimeRedisClient.mockReturnValue({ ping: vi.fn() });
|
||||
|
||||
const result = createStreamEventManager() as any;
|
||||
|
||||
expect(result.kind).toBe('gateway-stream-notifier');
|
||||
expect(result.url).toBe('https://custom-gateway.example.com');
|
||||
});
|
||||
|
||||
it('wraps in-memory manager with gateway when no Redis', () => {
|
||||
mockAppEnv.AGENT_GATEWAY_SERVICE_TOKEN = 'my-token';
|
||||
|
||||
const result = createStreamEventManager() as any;
|
||||
|
||||
expect(result.kind).toBe('gateway-stream-notifier');
|
||||
expect(result.inner).toBe(mockInMemoryStreamEventManager);
|
||||
});
|
||||
|
||||
it('does not wrap when AGENT_GATEWAY_SERVICE_TOKEN is not set', () => {
|
||||
mockGetAgentRuntimeRedisClient.mockReturnValue({ ping: vi.fn() });
|
||||
|
||||
const result = createStreamEventManager() as any;
|
||||
|
||||
expect(result.kind).toBe('redis-stream-event-manager');
|
||||
expect(MockGatewayStreamNotifier).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -3,6 +3,7 @@ import debug from 'debug';
|
||||
import { appEnv } from '@/envs/app';
|
||||
|
||||
import { AgentStateManager } from './AgentStateManager';
|
||||
import { GatewayStreamNotifier } from './GatewayStreamNotifier';
|
||||
import { inMemoryAgentStateManager } from './InMemoryAgentStateManager';
|
||||
import { inMemoryStreamEventManager } from './InMemoryStreamEventManager';
|
||||
import { getAgentRuntimeRedisClient } from './redis';
|
||||
@@ -54,17 +55,31 @@ export const createAgentStateManager = (): IAgentStateManager => {
|
||||
* - If Redis is unavailable and enableQueueAgentRuntime=true: throw
|
||||
*/
|
||||
export const createStreamEventManager = (): IStreamEventManager => {
|
||||
let manager: IStreamEventManager;
|
||||
|
||||
// Prefer Redis whenever it is available so the runtime worker and SSE route
|
||||
// can communicate through the same stream bus even in local mode.
|
||||
if (isRedisAvailable()) {
|
||||
log('Redis available, using StreamEventManager');
|
||||
return new StreamEventManager();
|
||||
}
|
||||
|
||||
if (!isQueueModeEnabled()) {
|
||||
manager = new StreamEventManager();
|
||||
} else if (!isQueueModeEnabled()) {
|
||||
log('Redis unavailable and queue mode disabled, using InMemoryStreamEventManager');
|
||||
return inMemoryStreamEventManager;
|
||||
manager = inMemoryStreamEventManager;
|
||||
} else {
|
||||
throw new Error(
|
||||
'Redis is required when AGENT_RUNTIME_MODE=queue. Please configure `REDIS_URL`.',
|
||||
);
|
||||
}
|
||||
|
||||
throw new Error('Redis is required when AGENT_RUNTIME_MODE=queue. Please configure `REDIS_URL`.');
|
||||
// Wrap with Gateway notifier when configured
|
||||
if (appEnv.AGENT_GATEWAY_SERVICE_TOKEN) {
|
||||
log('Wrapping with GatewayStreamNotifier (%s)', appEnv.AGENT_GATEWAY_URL);
|
||||
return new GatewayStreamNotifier(
|
||||
manager,
|
||||
appEnv.AGENT_GATEWAY_URL,
|
||||
appEnv.AGENT_GATEWAY_SERVICE_TOKEN,
|
||||
);
|
||||
}
|
||||
|
||||
return manager;
|
||||
};
|
||||
|
||||
@@ -2,6 +2,7 @@ export type { AgentRuntimeCoordinatorOptions } from './AgentRuntimeCoordinator';
|
||||
export { AgentRuntimeCoordinator } from './AgentRuntimeCoordinator';
|
||||
export { AgentStateManager } from './AgentStateManager';
|
||||
export { createAgentStateManager, createStreamEventManager, isRedisAvailable } from './factory';
|
||||
export { GatewayStreamNotifier } from './GatewayStreamNotifier';
|
||||
export { InMemoryAgentStateManager } from './InMemoryAgentStateManager';
|
||||
export { InMemoryStreamEventManager } from './InMemoryStreamEventManager';
|
||||
export { createRuntimeExecutors } from './RuntimeExecutors';
|
||||
|
||||
@@ -85,16 +85,20 @@ export const agentDocumentRouter = router({
|
||||
z.object({
|
||||
agentId: z.string(),
|
||||
content: z.string(),
|
||||
createdAt: z.date().optional(),
|
||||
filename: z.string(),
|
||||
metadata: metadataSchema.optional(),
|
||||
updatedAt: z.date().optional(),
|
||||
}),
|
||||
)
|
||||
.mutation(async ({ ctx, input }) => {
|
||||
return ctx.agentDocumentService.upsertDocument({
|
||||
agentId: input.agentId,
|
||||
content: input.content,
|
||||
createdAt: input.createdAt,
|
||||
filename: input.filename,
|
||||
metadata: input.metadata,
|
||||
updatedAt: input.updatedAt,
|
||||
});
|
||||
}),
|
||||
|
||||
|
||||
@@ -14,31 +14,28 @@ import { FileModel } from '@/database/models/file';
|
||||
import { MessageModel } from '@/database/models/message';
|
||||
import { knowledgeBaseFiles } from '@/database/schemas';
|
||||
import { authedProcedure, router } from '@/libs/trpc/lambda';
|
||||
import { keyVaults, serverDatabase } from '@/libs/trpc/lambda/middleware';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda/middleware';
|
||||
import { getServerDefaultFilesConfig } from '@/server/globalConfig';
|
||||
import { initModelRuntimeFromDB } from '@/server/modules/ModelRuntime';
|
||||
import { ChunkService } from '@/server/services/chunk';
|
||||
import { DocumentService } from '@/server/services/document';
|
||||
|
||||
const chunkProcedure = authedProcedure
|
||||
.use(serverDatabase)
|
||||
.use(keyVaults)
|
||||
.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
const chunkProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
|
||||
chunkService: new ChunkService(ctx.serverDB, ctx.userId),
|
||||
documentModel: new DocumentModel(ctx.serverDB, ctx.userId),
|
||||
documentService: new DocumentService(ctx.serverDB, ctx.userId),
|
||||
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
return opts.next({
|
||||
ctx: {
|
||||
asyncTaskModel: new AsyncTaskModel(ctx.serverDB, ctx.userId),
|
||||
chunkModel: new ChunkModel(ctx.serverDB, ctx.userId),
|
||||
chunkService: new ChunkService(ctx.serverDB, ctx.userId),
|
||||
documentModel: new DocumentModel(ctx.serverDB, ctx.userId),
|
||||
documentService: new DocumentService(ctx.serverDB, ctx.userId),
|
||||
embeddingModel: new EmbeddingModel(ctx.serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
messageModel: new MessageModel(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
/**
|
||||
* Group chunks by file and calculate relevance scores
|
||||
|
||||
@@ -24,27 +24,24 @@ import {
|
||||
EvaluationRecordModel,
|
||||
} from '@/database/models/ragEval';
|
||||
import { authedProcedure, router } from '@/libs/trpc/lambda';
|
||||
import { keyVaults, serverDatabase } from '@/libs/trpc/lambda/middleware';
|
||||
import { serverDatabase } from '@/libs/trpc/lambda/middleware';
|
||||
import { createAsyncCaller } from '@/server/routers/async';
|
||||
import { FileService } from '@/server/services/file';
|
||||
|
||||
const ragEvalProcedure = authedProcedure
|
||||
.use(serverDatabase)
|
||||
.use(keyVaults)
|
||||
.use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
const ragEvalProcedure = authedProcedure.use(serverDatabase).use(async (opts) => {
|
||||
const { ctx } = opts;
|
||||
|
||||
return opts.next({
|
||||
ctx: {
|
||||
datasetModel: new EvalDatasetModel(ctx.serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
datasetRecordModel: new EvalDatasetRecordModel(ctx.serverDB, ctx.userId),
|
||||
evaluationModel: new EvalEvaluationModel(ctx.serverDB, ctx.userId),
|
||||
evaluationRecordModel: new EvaluationRecordModel(ctx.serverDB, ctx.userId),
|
||||
fileService: new FileService(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
return opts.next({
|
||||
ctx: {
|
||||
datasetModel: new EvalDatasetModel(ctx.serverDB, ctx.userId),
|
||||
fileModel: new FileModel(ctx.serverDB, ctx.userId),
|
||||
datasetRecordModel: new EvalDatasetRecordModel(ctx.serverDB, ctx.userId),
|
||||
evaluationModel: new EvalEvaluationModel(ctx.serverDB, ctx.userId),
|
||||
evaluationRecordModel: new EvaluationRecordModel(ctx.serverDB, ctx.userId),
|
||||
fileService: new FileService(ctx.serverDB, ctx.userId),
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
export const ragEvalRouter = router({
|
||||
createDataset: ragEvalProcedure
|
||||
|
||||
@@ -37,7 +37,7 @@ vi.mock('@/database/models/userMemory', async (importOriginal) => {
|
||||
});
|
||||
|
||||
const embeddingsMock = vi.fn();
|
||||
const mockCtx = { authorizationHeader: 'Bearer mock-token', userId: 'test-user' };
|
||||
const mockCtx = { userId: 'test-user' };
|
||||
const makeServerDBMock = (query: Record<string, any> = {}) => ({
|
||||
query: {
|
||||
userSettings: {
|
||||
|
||||
@@ -7,10 +7,7 @@ import { SearXNGClient } from '@/server/services/search/impls/searxng/client';
|
||||
|
||||
import { searchRouter } from './search';
|
||||
|
||||
// Mock JWT verification
|
||||
vi.mock('@lobechat/utils/server', () => ({
|
||||
getXorPayload: vi.fn().mockReturnValue({ userId: '1' }),
|
||||
}));
|
||||
// Mock removed: XOR payload is no longer used for authentication
|
||||
|
||||
vi.mock('@lobechat/web-crawler', () => ({
|
||||
Crawler: vi.fn().mockImplementation(() => ({
|
||||
|
||||
@@ -20,12 +20,14 @@ const MAX_UNIQUE_FILENAME_ATTEMPTS = 1000;
|
||||
interface UpsertDocumentParams {
|
||||
agentId: string;
|
||||
content: string;
|
||||
createdAt?: Date;
|
||||
filename: string;
|
||||
loadPosition?: DocumentLoadPosition;
|
||||
loadRules?: DocumentLoadRules;
|
||||
metadata?: Record<string, any>;
|
||||
policy?: AgentDocumentPolicy;
|
||||
templateId?: string;
|
||||
updatedAt?: Date;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -207,6 +209,8 @@ export class AgentDocumentsService {
|
||||
templateId,
|
||||
metadata,
|
||||
policy,
|
||||
createdAt,
|
||||
updatedAt,
|
||||
}: UpsertDocumentParams) {
|
||||
return this.agentDocumentModel.upsert(
|
||||
agentId,
|
||||
@@ -217,6 +221,8 @@ export class AgentDocumentsService {
|
||||
templateId,
|
||||
metadata,
|
||||
policy,
|
||||
createdAt,
|
||||
updatedAt,
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -338,7 +338,6 @@ describe('AgentRuntimeService', () => {
|
||||
...mockParams,
|
||||
hooks: [{ handler: vi.fn(), id: 'hook-1', type: 'onComplete' }],
|
||||
signal: controller.signal,
|
||||
stepCallbacks: { onComplete: vi.fn() },
|
||||
}),
|
||||
).rejects.toMatchObject({
|
||||
message: 'startup aborted',
|
||||
@@ -347,7 +346,6 @@ describe('AgentRuntimeService', () => {
|
||||
|
||||
expect(mockQueueService.scheduleMessage).not.toHaveBeenCalled();
|
||||
expect(mockCoordinator.deleteAgentOperation).toHaveBeenCalledWith('test-operation-1');
|
||||
expect(service.getStepCallbacks('test-operation-1')).toBeUndefined();
|
||||
expect(hookDispatcher.hasHooks('test-operation-1')).toBe(false);
|
||||
});
|
||||
});
|
||||
@@ -464,22 +462,19 @@ describe('AgentRuntimeService', () => {
|
||||
});
|
||||
});
|
||||
|
||||
it('should call onComplete with error in finalState when execution fails', async () => {
|
||||
it('should dispatch onComplete hook with error in finalState when execution fails', async () => {
|
||||
const error = new Error('Runtime error');
|
||||
const mockRuntime = { step: vi.fn().mockRejectedValue(error) };
|
||||
vi.spyOn(service as any, 'createAgentRuntime').mockReturnValue({ runtime: mockRuntime });
|
||||
|
||||
// Register onComplete callback
|
||||
const mockOnComplete = vi.fn();
|
||||
service.registerStepCallbacks('test-operation-1', {
|
||||
onComplete: mockOnComplete,
|
||||
});
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
await expect(service.executeStep(mockParams)).rejects.toThrow('Runtime error');
|
||||
|
||||
// Verify onComplete is called with error in finalState as ChatMessageError
|
||||
// ChatErrorType.InternalServerError = 500
|
||||
expect(mockOnComplete).toHaveBeenCalledWith(
|
||||
// Verify onComplete hooks dispatched with error in finalState as ChatMessageError
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'test-operation-1',
|
||||
'onComplete',
|
||||
expect.objectContaining({
|
||||
operationId: 'test-operation-1',
|
||||
reason: 'error',
|
||||
@@ -491,10 +486,13 @@ describe('AgentRuntimeService', () => {
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should call onComplete with ChatCompletionErrorPayload in finalState', async () => {
|
||||
it('should dispatch onComplete hook with ChatCompletionErrorPayload in finalState', async () => {
|
||||
// Simulate LLM error format: { errorType: 'InvalidProviderAPIKey', error: { ... } }
|
||||
const llmError = {
|
||||
errorType: 'InvalidProviderAPIKey',
|
||||
@@ -504,16 +502,14 @@ describe('AgentRuntimeService', () => {
|
||||
const mockRuntime = { step: vi.fn().mockRejectedValue(llmError) };
|
||||
vi.spyOn(service as any, 'createAgentRuntime').mockReturnValue({ runtime: mockRuntime });
|
||||
|
||||
// Register onComplete callback
|
||||
const mockOnComplete = vi.fn();
|
||||
service.registerStepCallbacks('test-operation-1', {
|
||||
onComplete: mockOnComplete,
|
||||
});
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
await expect(service.executeStep(mockParams)).rejects.toEqual(llmError);
|
||||
|
||||
// Verify error is formatted correctly with type from errorType
|
||||
expect(mockOnComplete).toHaveBeenCalledWith(
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'test-operation-1',
|
||||
'onComplete',
|
||||
expect.objectContaining({
|
||||
operationId: 'test-operation-1',
|
||||
reason: 'error',
|
||||
@@ -525,7 +521,10 @@ describe('AgentRuntimeService', () => {
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should save error state to coordinator for later retrieval (inMemory mode fix)', async () => {
|
||||
@@ -665,8 +664,7 @@ describe('AgentRuntimeService', () => {
|
||||
});
|
||||
|
||||
it('should extract tool output from data field for single tool_result', async () => {
|
||||
const mockOnAfterStep = vi.fn();
|
||||
service.registerStepCallbacks('test-operation-1', { onAfterStep: mockOnAfterStep });
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
const mockStepResult = {
|
||||
newState: { ...mockState, stepCount: 2, status: 'running' },
|
||||
@@ -694,7 +692,9 @@ describe('AgentRuntimeService', () => {
|
||||
|
||||
await service.executeStep(mockParams);
|
||||
|
||||
expect(mockOnAfterStep).toHaveBeenCalledWith(
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'test-operation-1',
|
||||
'afterStep',
|
||||
expect.objectContaining({
|
||||
toolsResult: [
|
||||
expect.objectContaining({
|
||||
@@ -704,12 +704,14 @@ describe('AgentRuntimeService', () => {
|
||||
}),
|
||||
],
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should extract tool output from data field for tools_batch_result', async () => {
|
||||
const mockOnAfterStep = vi.fn();
|
||||
service.registerStepCallbacks('test-operation-1', { onAfterStep: mockOnAfterStep });
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
const mockStepResult = {
|
||||
newState: { ...mockState, stepCount: 2, status: 'running' },
|
||||
@@ -750,7 +752,9 @@ describe('AgentRuntimeService', () => {
|
||||
|
||||
await service.executeStep(mockParams);
|
||||
|
||||
expect(mockOnAfterStep).toHaveBeenCalledWith(
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'test-operation-1',
|
||||
'afterStep',
|
||||
expect.objectContaining({
|
||||
toolsResult: [
|
||||
expect.objectContaining({
|
||||
@@ -765,12 +769,14 @@ describe('AgentRuntimeService', () => {
|
||||
}),
|
||||
],
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should handle tool result with undefined data', async () => {
|
||||
const mockOnAfterStep = vi.fn();
|
||||
service.registerStepCallbacks('test-operation-1', { onAfterStep: mockOnAfterStep });
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
const mockStepResult = {
|
||||
newState: { ...mockState, stepCount: 2, status: 'running' },
|
||||
@@ -796,7 +802,9 @@ describe('AgentRuntimeService', () => {
|
||||
|
||||
await service.executeStep(mockParams);
|
||||
|
||||
expect(mockOnAfterStep).toHaveBeenCalledWith(
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'test-operation-1',
|
||||
'afterStep',
|
||||
expect.objectContaining({
|
||||
toolsResult: [
|
||||
expect.objectContaining({
|
||||
@@ -806,7 +814,10 @@ describe('AgentRuntimeService', () => {
|
||||
}),
|
||||
],
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ import {
|
||||
type StartExecutionParams,
|
||||
type StartExecutionResult,
|
||||
type StepCompletionReason,
|
||||
type StepLifecycleCallbacks,
|
||||
type StepPresentationData,
|
||||
} from './types';
|
||||
|
||||
@@ -127,11 +126,6 @@ export class AgentRuntimeService {
|
||||
private queueService: QueueService | null;
|
||||
private snapshotStore: ISnapshotStore | null;
|
||||
private toolExecutionService: ToolExecutionService;
|
||||
/**
|
||||
* Step lifecycle callback registry
|
||||
* key: operationId, value: callbacks
|
||||
*/
|
||||
private stepCallbacks: Map<string, StepLifecycleCallbacks> = new Map();
|
||||
private get baseURL() {
|
||||
const baseUrl = process.env.AGENT_RUNTIME_BASE_URL || appEnv.APP_URL || 'http://localhost:3010';
|
||||
|
||||
@@ -186,35 +180,6 @@ export class AgentRuntimeService {
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== Step Lifecycle Callbacks ====================
|
||||
|
||||
/**
|
||||
* Register step lifecycle callbacks
|
||||
* @param operationId - Operation ID
|
||||
* @param callbacks - Callback function collection
|
||||
*/
|
||||
registerStepCallbacks(operationId: string, callbacks: StepLifecycleCallbacks): void {
|
||||
this.stepCallbacks.set(operationId, callbacks);
|
||||
log('[%s] Registered step callbacks', operationId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove step lifecycle callbacks
|
||||
* @param operationId - Operation ID
|
||||
*/
|
||||
unregisterStepCallbacks(operationId: string): void {
|
||||
this.stepCallbacks.delete(operationId);
|
||||
log('[%s] Unregistered step callbacks', operationId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get step lifecycle callbacks
|
||||
* @param operationId - Operation ID
|
||||
*/
|
||||
getStepCallbacks(operationId: string): StepLifecycleCallbacks | undefined {
|
||||
return this.stepCallbacks.get(operationId);
|
||||
}
|
||||
|
||||
// ==================== Operation Interruption ====================
|
||||
|
||||
/**
|
||||
@@ -260,14 +225,10 @@ export class AgentRuntimeService {
|
||||
initialMessages = [],
|
||||
appContext,
|
||||
toolSet,
|
||||
stepCallbacks,
|
||||
hooks,
|
||||
userInterventionConfig,
|
||||
completionWebhook,
|
||||
stepWebhook,
|
||||
queueRetries,
|
||||
queueRetryDelay,
|
||||
webhookDelivery,
|
||||
botPlatformContext,
|
||||
discordContext,
|
||||
evalContext,
|
||||
@@ -282,7 +243,6 @@ export class AgentRuntimeService {
|
||||
|
||||
const operationToolSet = toolSet;
|
||||
let operationCreated = false;
|
||||
let stepCallbacksRegistered = false;
|
||||
let hooksRegistered = false;
|
||||
|
||||
try {
|
||||
@@ -315,7 +275,6 @@ export class AgentRuntimeService {
|
||||
activeDeviceId,
|
||||
agentConfig,
|
||||
botPlatformContext,
|
||||
completionWebhook,
|
||||
deviceSystemInfo,
|
||||
discordContext,
|
||||
evalContext,
|
||||
@@ -323,13 +282,11 @@ export class AgentRuntimeService {
|
||||
modelRuntimeConfig,
|
||||
queueRetries,
|
||||
queueRetryDelay,
|
||||
stepWebhook,
|
||||
stream,
|
||||
operationSkillSet,
|
||||
userId,
|
||||
userMemory,
|
||||
userTimezone,
|
||||
webhookDelivery,
|
||||
workingDirectory: agentConfig?.chatConfig?.runtimeEnv?.workingDirectory,
|
||||
...appContext,
|
||||
},
|
||||
@@ -359,12 +316,6 @@ export class AgentRuntimeService {
|
||||
// Save initial state
|
||||
await this.coordinator.saveAgentState(operationId, initialState as any);
|
||||
|
||||
// Register step lifecycle callbacks
|
||||
if (stepCallbacks) {
|
||||
this.registerStepCallbacks(operationId, stepCallbacks);
|
||||
stepCallbacksRegistered = true;
|
||||
}
|
||||
|
||||
// Register external hooks
|
||||
if (hooks && hooks.length > 0) {
|
||||
hookDispatcher.register(operationId, hooks);
|
||||
@@ -416,10 +367,6 @@ export class AgentRuntimeService {
|
||||
return { autoStarted, messageId, operationId, success: true };
|
||||
} catch (error) {
|
||||
if (isAbortError(error)) {
|
||||
if (stepCallbacksRegistered) {
|
||||
this.unregisterStepCallbacks(operationId);
|
||||
}
|
||||
|
||||
if (hooksRegistered) {
|
||||
hookDispatcher.unregister(operationId);
|
||||
}
|
||||
@@ -455,8 +402,6 @@ export class AgentRuntimeService {
|
||||
externalRetryCount = 0,
|
||||
} = params;
|
||||
|
||||
const callbacks = this.getStepCallbacks(operationId);
|
||||
|
||||
// ===== Distributed lock: prevent duplicate execution from QStash retries =====
|
||||
const claimed = await this.coordinator.tryClaimStep(operationId, stepIndex, 35);
|
||||
if (!claimed) {
|
||||
@@ -527,19 +472,8 @@ export class AgentRuntimeService {
|
||||
|
||||
const reason = this.determineCompletionReason(agentState);
|
||||
|
||||
// Trigger completion callback so eval run can finalize properly
|
||||
if (callbacks?.onComplete) {
|
||||
try {
|
||||
await callbacks.onComplete({
|
||||
finalState: agentState,
|
||||
operationId,
|
||||
reason,
|
||||
});
|
||||
this.unregisterStepCallbacks(operationId);
|
||||
} catch (callbackError) {
|
||||
log('[%s] onComplete callback error: %O', operationId, callbackError);
|
||||
}
|
||||
}
|
||||
// Dispatch completion hooks so consumers (e.g., bot local-mode promise) can finalize
|
||||
await this.dispatchCompletionHooks(operationId, agentState, reason);
|
||||
|
||||
return {
|
||||
nextStepScheduled: false,
|
||||
@@ -549,20 +483,6 @@ export class AgentRuntimeService {
|
||||
};
|
||||
}
|
||||
|
||||
// Call onBeforeStep callback (legacy)
|
||||
if (callbacks?.onBeforeStep) {
|
||||
try {
|
||||
await callbacks.onBeforeStep({
|
||||
context,
|
||||
operationId,
|
||||
state: agentState,
|
||||
stepIndex,
|
||||
});
|
||||
} catch (callbackError) {
|
||||
log('[%s] onBeforeStep callback error: %O', operationId, callbackError);
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch beforeStep hooks
|
||||
try {
|
||||
const beforeStepMetadata = agentState?.metadata || {};
|
||||
@@ -823,36 +743,43 @@ export class AgentRuntimeService {
|
||||
totalTokens: totalTokensNum,
|
||||
};
|
||||
|
||||
// Call onAfterStep callback with presentation data (legacy)
|
||||
if (callbacks?.onAfterStep) {
|
||||
try {
|
||||
await callbacks.onAfterStep({
|
||||
...stepPresentationData,
|
||||
operationId,
|
||||
shouldContinue,
|
||||
state: stepResult.newState,
|
||||
stepIndex,
|
||||
stepResult,
|
||||
});
|
||||
} catch (callbackError) {
|
||||
log('[%s] onAfterStep callback error: %O', operationId, callbackError);
|
||||
}
|
||||
}
|
||||
|
||||
// Dispatch afterStep hooks
|
||||
// Dispatch afterStep hooks (enriched with step presentation + tracking data)
|
||||
try {
|
||||
const metadata = stepResult.newState?.metadata || {};
|
||||
const tracking = metadata._stepTracking || {};
|
||||
const elapsedMs = stepResult.newState?.createdAt
|
||||
? Date.now() - new Date(stepResult.newState.createdAt).getTime()
|
||||
: undefined;
|
||||
await hookDispatcher.dispatch(
|
||||
operationId,
|
||||
'afterStep',
|
||||
{
|
||||
agentId: metadata?.agentId || '',
|
||||
content,
|
||||
elapsedMs,
|
||||
executionTimeMs: stepPresentationData.executionTimeMs,
|
||||
finalState: stepResult.newState,
|
||||
lastLLMContent: tracking.lastLLMContent,
|
||||
lastToolsCalling: tracking.lastToolsCalling,
|
||||
operationId,
|
||||
reasoning: stepPresentationData.reasoning,
|
||||
shouldContinue,
|
||||
status: stepResult.newState?.status,
|
||||
stepCost: stepPresentationData.stepCost,
|
||||
stepIndex,
|
||||
stepType: stepPresentationData.stepType,
|
||||
steps: stepResult.newState?.stepCount || 0,
|
||||
thinking: stepPresentationData.thinking,
|
||||
toolCalls: stepResult.newState?.usage?.tools?.totalCalls,
|
||||
toolsCalling: stepPresentationData.toolsCalling,
|
||||
toolsResult: stepPresentationData.toolsResult,
|
||||
topicId: metadata?.topicId,
|
||||
totalCost: stepPresentationData.totalCost,
|
||||
totalInputTokens: stepPresentationData.totalInputTokens,
|
||||
totalOutputTokens: stepPresentationData.totalOutputTokens,
|
||||
totalSteps: stepPresentationData.totalSteps,
|
||||
totalTokens: stepPresentationData.totalTokens,
|
||||
totalToolCalls: (tracking.totalToolCalls ?? 0) + (toolsCalling?.length ?? 0),
|
||||
userId: metadata?.userId || this.userId,
|
||||
},
|
||||
metadata._hooks,
|
||||
@@ -965,12 +892,15 @@ export class AgentRuntimeService {
|
||||
}
|
||||
}
|
||||
|
||||
// Update step tracking in state metadata and trigger step webhook
|
||||
if (stepResult.newState.metadata?.stepWebhook) {
|
||||
// Update step tracking in state metadata for afterStep hooks (cross-step accumulator)
|
||||
const hasAfterStepHooks = stepResult.newState.metadata?._hooks?.some(
|
||||
(h: { type: string }) => h.type === 'afterStep',
|
||||
);
|
||||
if (hasAfterStepHooks && stepResult.newState.metadata) {
|
||||
const prevTracking = stepResult.newState.metadata._stepTracking || {};
|
||||
const newTotalToolCalls = (prevTracking.totalToolCalls ?? 0) + (toolsCalling?.length ?? 0);
|
||||
|
||||
// Truncate content to 1800 chars to match Discord message limits
|
||||
// Truncate content to 1800 chars to keep state small
|
||||
const truncatedContent = content
|
||||
? content.length > 1800
|
||||
? content.slice(0, 1800) + '...'
|
||||
@@ -986,13 +916,6 @@ export class AgentRuntimeService {
|
||||
// Persist tracking state for next step
|
||||
stepResult.newState.metadata._stepTracking = updatedTracking;
|
||||
await this.coordinator.saveAgentState(operationId, stepResult.newState);
|
||||
|
||||
// Fire step webhook (include shouldContinue so the callback knows
|
||||
// whether the agent is still running or about to complete)
|
||||
await this.triggerStepWebhook(stepResult.newState, operationId, {
|
||||
...stepPresentationData,
|
||||
shouldContinue,
|
||||
} as unknown as Record<string, unknown>);
|
||||
}
|
||||
|
||||
if (shouldContinue && stepResult.nextContext && this.queueService) {
|
||||
@@ -1025,27 +948,9 @@ export class AgentRuntimeService {
|
||||
if (!shouldContinue) {
|
||||
const reason = this.determineCompletionReason(stepResult.newState);
|
||||
|
||||
// Trigger completion webhook (fire-and-forget)
|
||||
await this.triggerCompletionWebhook(stepResult.newState, operationId, reason);
|
||||
|
||||
// Dispatch onComplete hooks
|
||||
await this.dispatchCompletionHooks(operationId, stepResult.newState, reason);
|
||||
|
||||
// Call onComplete callback (legacy)
|
||||
if (callbacks?.onComplete) {
|
||||
try {
|
||||
await callbacks.onComplete({
|
||||
finalState: stepResult.newState,
|
||||
operationId,
|
||||
reason,
|
||||
});
|
||||
// Clean up callbacks after operation completes
|
||||
this.unregisterStepCallbacks(operationId);
|
||||
} catch (callbackError) {
|
||||
log('[%s] onComplete callback error: %O', operationId, callbackError);
|
||||
}
|
||||
}
|
||||
|
||||
// Finalize tracing snapshot via injected snapshot store
|
||||
if (this.snapshotStore) {
|
||||
try {
|
||||
@@ -1157,30 +1062,9 @@ export class AgentRuntimeService {
|
||||
log('[%s] Failed to save error state (infra may be down): %O', operationId, saveError);
|
||||
}
|
||||
|
||||
// Trigger completion webhook on error (fire-and-forget)
|
||||
try {
|
||||
await this.triggerCompletionWebhook(finalStateWithError, operationId, 'error');
|
||||
} catch (webhookError) {
|
||||
log('[%s] Failed to trigger completion webhook: %O', operationId, webhookError);
|
||||
}
|
||||
|
||||
// Dispatch onComplete + onError hooks
|
||||
await this.dispatchCompletionHooks(operationId, finalStateWithError, 'error');
|
||||
|
||||
// Also call onComplete callback when execution fails (legacy)
|
||||
if (callbacks?.onComplete) {
|
||||
try {
|
||||
await callbacks.onComplete({
|
||||
finalState: finalStateWithError,
|
||||
operationId,
|
||||
reason: 'error',
|
||||
});
|
||||
this.unregisterStepCallbacks(operationId);
|
||||
} catch (callbackError) {
|
||||
log('[%s] onComplete callback error in catch: %O', operationId, callbackError);
|
||||
}
|
||||
}
|
||||
|
||||
throw error;
|
||||
} finally {
|
||||
// Release lock so legitimate retries or next operations can proceed.
|
||||
@@ -1648,41 +1532,6 @@ export class AgentRuntimeService {
|
||||
return { newState: state, nextContext: undefined };
|
||||
}
|
||||
|
||||
/**
|
||||
* Deliver a webhook payload via fetch or QStash.
|
||||
* Fire-and-forget: errors are logged but never thrown.
|
||||
*/
|
||||
private async deliverWebhook(
|
||||
url: string,
|
||||
payload: Record<string, unknown>,
|
||||
delivery: 'fetch' | 'qstash' = 'fetch',
|
||||
operationId: string,
|
||||
): Promise<void> {
|
||||
try {
|
||||
if (delivery === 'qstash') {
|
||||
const { Client } = await import('@upstash/qstash');
|
||||
const client = new Client({ token: process.env.QSTASH_TOKEN! });
|
||||
await client.publishJSON({
|
||||
body: payload,
|
||||
headers: {
|
||||
...(process.env.VERCEL_AUTOMATION_BYPASS_SECRET && {
|
||||
'x-vercel-protection-bypass': process.env.VERCEL_AUTOMATION_BYPASS_SECRET,
|
||||
}),
|
||||
},
|
||||
url,
|
||||
});
|
||||
} else {
|
||||
await fetch(url, {
|
||||
body: JSON.stringify(payload),
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
method: 'POST',
|
||||
});
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[%s] Webhook delivery failed (%s → %s):', operationId, delivery, url, error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispatch onComplete (and onError) hooks via HookDispatcher.
|
||||
* Fire-and-forget: errors are logged but never thrown.
|
||||
@@ -1695,7 +1544,7 @@ export class AgentRuntimeService {
|
||||
try {
|
||||
const metadata = state?.metadata || {};
|
||||
|
||||
// Extract last assistant content (same as triggerCompletionWebhook)
|
||||
// Extract last assistant content from state messages
|
||||
const lastAssistantContent = state?.messages
|
||||
?.slice()
|
||||
.reverse()
|
||||
@@ -1742,101 +1591,6 @@ export class AgentRuntimeService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Trigger completion webhook if configured in state metadata.
|
||||
* Fire-and-forget: errors are logged but never thrown.
|
||||
*/
|
||||
private async triggerCompletionWebhook(
|
||||
state: any,
|
||||
operationId: string,
|
||||
reason: StepCompletionReason,
|
||||
): Promise<void> {
|
||||
const webhook = state.metadata?.completionWebhook;
|
||||
if (!webhook?.url) return;
|
||||
|
||||
log('[%s] Triggering completion webhook: %s', operationId, webhook.url);
|
||||
|
||||
const duration = state.createdAt ? Date.now() - new Date(state.createdAt).getTime() : undefined;
|
||||
|
||||
// Extract last assistant content from state messages
|
||||
const lastAssistantContent = state.messages
|
||||
?.slice()
|
||||
.reverse()
|
||||
.find(
|
||||
(m: { content?: string; role: string }) => m.role === 'assistant' && m.content,
|
||||
)?.content;
|
||||
|
||||
// Extract first user prompt for downstream consumers (e.g., topic title summarization)
|
||||
const userPrompt = state.messages?.find(
|
||||
(m: { content?: string; role: string }) => m.role === 'user',
|
||||
)?.content;
|
||||
|
||||
const delivery = state.metadata?.webhookDelivery || 'fetch';
|
||||
|
||||
await this.deliverWebhook(
|
||||
webhook.url,
|
||||
{
|
||||
...webhook.body,
|
||||
cost: state.cost?.total,
|
||||
duration,
|
||||
errorDetail: state.error,
|
||||
errorMessage: this.extractErrorMessage(state.error),
|
||||
lastAssistantContent,
|
||||
llmCalls: state.usage?.llm?.apiCalls,
|
||||
operationId,
|
||||
reason,
|
||||
status: state.status,
|
||||
steps: state.stepCount,
|
||||
toolCalls: state.usage?.tools?.totalCalls,
|
||||
topicId: state.metadata?.topicId,
|
||||
totalTokens: state.usage?.llm?.tokens?.total,
|
||||
type: 'completion',
|
||||
userId: state.metadata?.userId,
|
||||
userPrompt,
|
||||
},
|
||||
delivery,
|
||||
operationId,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Trigger step webhook if configured in state metadata.
|
||||
* Reads accumulated step tracking data and fires webhook with step presentation data.
|
||||
* Fire-and-forget: errors are logged but never thrown.
|
||||
*/
|
||||
private async triggerStepWebhook(
|
||||
state: any,
|
||||
operationId: string,
|
||||
presentationData: Record<string, unknown>,
|
||||
): Promise<void> {
|
||||
const webhook = state.metadata?.stepWebhook;
|
||||
if (!webhook?.url) return;
|
||||
|
||||
log('[%s] Triggering step webhook: %s', operationId, webhook.url);
|
||||
|
||||
const tracking = state.metadata?._stepTracking || {};
|
||||
const delivery = state.metadata?.webhookDelivery || 'fetch';
|
||||
const elapsedMs = state.createdAt
|
||||
? Date.now() - new Date(state.createdAt).getTime()
|
||||
: undefined;
|
||||
|
||||
await this.deliverWebhook(
|
||||
webhook.url,
|
||||
{
|
||||
...webhook.body,
|
||||
...presentationData,
|
||||
elapsedMs,
|
||||
lastLLMContent: tracking.lastLLMContent,
|
||||
lastToolsCalling: tracking.lastToolsCalling,
|
||||
operationId,
|
||||
totalToolCalls: tracking.totalToolCalls ?? 0,
|
||||
type: 'step',
|
||||
},
|
||||
delivery,
|
||||
operationId,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a human-readable error message from the agent state error object.
|
||||
* Handles both raw ChatCompletionErrorPayload (from runtime.step catch) and
|
||||
@@ -2048,20 +1802,10 @@ export class AgentRuntimeService {
|
||||
|
||||
if (stepIndex >= maxSteps) {
|
||||
log('[%s] Sync execution stopped: reached maxSteps (%d)', operationId, maxSteps);
|
||||
// If stopped due to executeSync's maxSteps limit, need to manually call onComplete
|
||||
// If stopped due to executeSync's maxSteps limit, need to manually dispatch onComplete hooks
|
||||
// Note: If stopped due to state.maxSteps being reached, onComplete has already been called in executeStep
|
||||
const callbacks = this.getStepCallbacks(operationId);
|
||||
if (callbacks?.onComplete && state.status !== 'done' && state.status !== 'error') {
|
||||
try {
|
||||
await callbacks.onComplete({
|
||||
finalState: state,
|
||||
operationId,
|
||||
reason: 'max_steps',
|
||||
});
|
||||
this.unregisterStepCallbacks(operationId);
|
||||
} catch (callbackError) {
|
||||
log('[%s] onComplete callback error in executeSync: %O', operationId, callbackError);
|
||||
}
|
||||
if (state.status !== 'done' && state.status !== 'error') {
|
||||
await this.dispatchCompletionHooks(operationId, state, 'max_steps');
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ vi.mock('@/server/services/toolExecution/builtin', () => ({
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
describe('AgentRuntimeService - Completion Hooks via createOperation', () => {
|
||||
let service: AgentRuntimeService;
|
||||
let stateManager: InMemoryAgentStateManager;
|
||||
let streamEventManager: InMemoryStreamEventManager;
|
||||
@@ -91,19 +91,26 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe('createOperation persists completionWebhook', () => {
|
||||
it('should persist completionWebhook in state metadata', async () => {
|
||||
const operationId = 'webhook-op-1';
|
||||
const completionWebhook = {
|
||||
body: { runId: 'run-1', testCaseId: 'tc-1' },
|
||||
url: 'https://example.com/webhook',
|
||||
};
|
||||
describe('createOperation persists hooks in metadata', () => {
|
||||
it('should persist hooks in state metadata._hooks', async () => {
|
||||
const operationId = 'hook-op-1';
|
||||
const hooks = [
|
||||
{
|
||||
handler: vi.fn(),
|
||||
id: 'test-completion',
|
||||
type: 'onComplete' as const,
|
||||
webhook: {
|
||||
body: { runId: 'run-1', testCaseId: 'tc-1' },
|
||||
url: 'https://example.com/webhook',
|
||||
},
|
||||
},
|
||||
];
|
||||
|
||||
await service.createOperation({
|
||||
agentConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
appContext: { agentId: 'test-agent' },
|
||||
autoStart: false,
|
||||
completionWebhook,
|
||||
hooks,
|
||||
initialContext: makeContext(operationId),
|
||||
initialMessages: [{ content: 'Hello', role: 'user' }],
|
||||
modelRuntimeConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
@@ -113,11 +120,20 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
});
|
||||
|
||||
const state = await stateManager.loadAgentState(operationId);
|
||||
expect(state?.metadata?.completionWebhook).toEqual(completionWebhook);
|
||||
expect(state?.metadata?._hooks).toEqual([
|
||||
expect.objectContaining({
|
||||
id: 'test-completion',
|
||||
type: 'onComplete',
|
||||
webhook: {
|
||||
body: { runId: 'run-1', testCaseId: 'tc-1' },
|
||||
url: 'https://example.com/webhook',
|
||||
},
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
it('should not have completionWebhook in metadata when not provided', async () => {
|
||||
const operationId = 'webhook-op-2';
|
||||
it('should not have _hooks in metadata when no hooks provided', async () => {
|
||||
const operationId = 'hook-op-2';
|
||||
|
||||
await service.createOperation({
|
||||
agentConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
@@ -132,18 +148,18 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
});
|
||||
|
||||
const state = await stateManager.loadAgentState(operationId);
|
||||
expect(state?.metadata?.completionWebhook).toBeUndefined();
|
||||
expect(state?.metadata?._hooks).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('executeStep triggers webhook', () => {
|
||||
describe('webhook delivery through hooks', () => {
|
||||
const fetchSpy = vi.fn().mockResolvedValue({ ok: true });
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', fetchSpy);
|
||||
});
|
||||
|
||||
const createOperationWithWebhook = async (
|
||||
const createOperationWithHook = async (
|
||||
operationId: string,
|
||||
webhookUrl: string,
|
||||
webhookBody?: Record<string, unknown>,
|
||||
@@ -152,7 +168,14 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
agentConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
appContext: { agentId: 'test-agent' },
|
||||
autoStart: false,
|
||||
completionWebhook: { body: webhookBody, url: webhookUrl },
|
||||
hooks: [
|
||||
{
|
||||
handler: vi.fn(),
|
||||
id: 'test-completion',
|
||||
type: 'onComplete' as const,
|
||||
webhook: { body: webhookBody, url: webhookUrl },
|
||||
},
|
||||
],
|
||||
initialContext: makeContext(operationId),
|
||||
initialMessages: [{ content: 'Hello', role: 'user' }],
|
||||
modelRuntimeConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
@@ -162,12 +185,12 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
});
|
||||
};
|
||||
|
||||
it('should trigger webhook when operation completes normally', async () => {
|
||||
const operationId = 'webhook-complete-1';
|
||||
it('should persist webhook hook config for later delivery on completion', async () => {
|
||||
const operationId = 'hook-complete-1';
|
||||
const webhookUrl = 'https://example.com/on-complete';
|
||||
const webhookBody = { runId: 'run-1', testCaseId: 'tc-1' };
|
||||
|
||||
await createOperationWithWebhook(operationId, webhookUrl, webhookBody);
|
||||
await createOperationWithHook(operationId, webhookUrl, webhookBody);
|
||||
|
||||
// Manually set state to simulate a step that produces 'done' status
|
||||
const state = await stateManager.loadAgentState(operationId);
|
||||
@@ -176,21 +199,22 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
status: 'done',
|
||||
});
|
||||
|
||||
// executeStep will call triggerCompletionWebhook when !shouldContinue
|
||||
// We need the step to actually produce a done state, but since we can't
|
||||
// easily mock the full runtime.step, we test the metadata persistence above
|
||||
// and verify the webhook method is correct through the type + metadata test.
|
||||
|
||||
// Verify the webhook config is persisted for later use
|
||||
// Verify the hook config is persisted for later use
|
||||
const updatedState = await stateManager.loadAgentState(operationId);
|
||||
expect(updatedState?.metadata?.completionWebhook).toEqual({
|
||||
body: webhookBody,
|
||||
url: webhookUrl,
|
||||
});
|
||||
expect(updatedState?.metadata?._hooks).toEqual([
|
||||
expect.objectContaining({
|
||||
id: 'test-completion',
|
||||
type: 'onComplete',
|
||||
webhook: {
|
||||
body: webhookBody,
|
||||
url: webhookUrl,
|
||||
},
|
||||
}),
|
||||
]);
|
||||
});
|
||||
|
||||
it('should NOT trigger webhook when no completionWebhook is configured', async () => {
|
||||
const operationId = 'webhook-none-1';
|
||||
it('should NOT have hook config when no hooks are configured', async () => {
|
||||
const operationId = 'hook-none-1';
|
||||
|
||||
await service.createOperation({
|
||||
agentConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
@@ -205,41 +229,33 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
});
|
||||
|
||||
const state = await stateManager.loadAgentState(operationId);
|
||||
expect(state?.metadata?.completionWebhook).toBeUndefined();
|
||||
|
||||
// fetch should not be called for webhook since there's no webhook config
|
||||
// (It may still be called for other reasons in real execution)
|
||||
expect(state?.metadata?._hooks).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should not throw when webhook fetch fails', async () => {
|
||||
const operationId = 'webhook-fail-1';
|
||||
const operationId = 'hook-fail-1';
|
||||
const webhookUrl = 'https://example.com/failing-webhook';
|
||||
|
||||
// Make fetch throw
|
||||
fetchSpy.mockRejectedValueOnce(new Error('Network error'));
|
||||
|
||||
await createOperationWithWebhook(operationId, webhookUrl, { runId: 'run-1' });
|
||||
await createOperationWithHook(operationId, webhookUrl, { runId: 'run-1' });
|
||||
|
||||
// Verify the webhook is stored — the triggerCompletionWebhook method
|
||||
// catches errors internally and doesn't throw
|
||||
// Verify the hook is stored -- the hook dispatch catches errors internally
|
||||
const state = await stateManager.loadAgentState(operationId);
|
||||
expect(state?.metadata?.completionWebhook?.url).toBe(webhookUrl);
|
||||
expect(state?.metadata?._hooks?.[0]?.webhook?.url).toBe(webhookUrl);
|
||||
});
|
||||
});
|
||||
|
||||
describe('triggerCompletionWebhook integration via executeSync', () => {
|
||||
describe('hook payload structure', () => {
|
||||
const fetchSpy = vi.fn().mockResolvedValue({ ok: true });
|
||||
|
||||
beforeEach(() => {
|
||||
vi.stubGlobal('fetch', fetchSpy);
|
||||
});
|
||||
|
||||
it('should include webhook body fields plus operationId/reason/status in POST payload', async () => {
|
||||
// This test verifies the contract of what triggerCompletionWebhook sends.
|
||||
// Since triggerCompletionWebhook is private, we verify through the metadata
|
||||
// and the expected fetch call shape.
|
||||
|
||||
const operationId = 'webhook-payload-test';
|
||||
it('should include webhook body fields in the persisted hook config', async () => {
|
||||
const operationId = 'hook-payload-test';
|
||||
const webhookUrl = 'https://example.com/webhook';
|
||||
const webhookBody = { runId: 'run-123', testCaseId: 'tc-456', userId: 'user-789' };
|
||||
|
||||
@@ -247,7 +263,14 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
agentConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
appContext: { agentId: 'test-agent' },
|
||||
autoStart: false,
|
||||
completionWebhook: { body: webhookBody, url: webhookUrl },
|
||||
hooks: [
|
||||
{
|
||||
handler: vi.fn(),
|
||||
id: 'test-completion',
|
||||
type: 'onComplete' as const,
|
||||
webhook: { body: webhookBody, url: webhookUrl },
|
||||
},
|
||||
],
|
||||
initialContext: makeContext(operationId),
|
||||
initialMessages: [{ content: 'Hello', role: 'user' }],
|
||||
modelRuntimeConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
@@ -256,12 +279,13 @@ describe('AgentRuntimeService - Completion Webhook', () => {
|
||||
userId,
|
||||
});
|
||||
|
||||
// Verify the persisted webhook contains the right structure
|
||||
// Verify the persisted hook contains the right structure
|
||||
const state = await stateManager.loadAgentState(operationId);
|
||||
const webhook = state?.metadata?.completionWebhook;
|
||||
expect(webhook).toBeDefined();
|
||||
expect(webhook.url).toBe(webhookUrl);
|
||||
expect(webhook.body).toEqual(webhookBody);
|
||||
const hooks = state?.metadata?._hooks;
|
||||
expect(hooks).toBeDefined();
|
||||
expect(hooks).toHaveLength(1);
|
||||
expect(hooks[0].webhook.url).toBe(webhookUrl);
|
||||
expect(hooks[0].webhook.body).toEqual(webhookBody);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { AgentRuntimeService } from '../AgentRuntimeService';
|
||||
import { hookDispatcher } from '../hooks';
|
||||
|
||||
// Mock all heavy dependencies to isolate executeStep logic
|
||||
vi.mock('@/envs/app', () => ({ appEnv: { APP_URL: 'http://localhost:3010' } }));
|
||||
@@ -81,7 +82,7 @@ describe('AgentRuntimeService.executeStep - early exit on terminal state', () =>
|
||||
});
|
||||
}
|
||||
|
||||
it('should call onComplete callback when skipping interrupted operation', async () => {
|
||||
it('should dispatch onComplete hook when skipping interrupted operation', async () => {
|
||||
const service = createService();
|
||||
|
||||
const coordinator = (service as any).coordinator;
|
||||
@@ -91,8 +92,7 @@ describe('AgentRuntimeService.executeStep - early exit on terminal state', () =>
|
||||
lastModified: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const onComplete = vi.fn();
|
||||
service.registerStepCallbacks('op-123', { onComplete });
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
await service.executeStep({
|
||||
operationId: 'op-123',
|
||||
@@ -100,14 +100,20 @@ describe('AgentRuntimeService.executeStep - early exit on terminal state', () =>
|
||||
context: { phase: 'user_input' } as any,
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith({
|
||||
finalState: expect.objectContaining({ status: 'interrupted' }),
|
||||
operationId: 'op-123',
|
||||
reason: 'interrupted',
|
||||
});
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'op-123',
|
||||
'onComplete',
|
||||
expect.objectContaining({
|
||||
operationId: 'op-123',
|
||||
reason: 'interrupted',
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should call onComplete with reason "done" when skipping done operation', async () => {
|
||||
it('should dispatch onComplete hook with reason "done" when skipping done operation', async () => {
|
||||
const service = createService();
|
||||
|
||||
const coordinator = (service as any).coordinator;
|
||||
@@ -117,8 +123,7 @@ describe('AgentRuntimeService.executeStep - early exit on terminal state', () =>
|
||||
lastModified: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const onComplete = vi.fn();
|
||||
service.registerStepCallbacks('op-456', { onComplete });
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
await service.executeStep({
|
||||
operationId: 'op-456',
|
||||
@@ -126,14 +131,20 @@ describe('AgentRuntimeService.executeStep - early exit on terminal state', () =>
|
||||
context: { phase: 'user_input' } as any,
|
||||
});
|
||||
|
||||
expect(onComplete).toHaveBeenCalledWith({
|
||||
finalState: expect.objectContaining({ status: 'done' }),
|
||||
operationId: 'op-456',
|
||||
reason: 'done',
|
||||
});
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'op-456',
|
||||
'onComplete',
|
||||
expect.objectContaining({
|
||||
operationId: 'op-456',
|
||||
reason: 'done',
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should unregister callbacks after onComplete is called on early exit', async () => {
|
||||
it('should unregister hooks after onComplete is dispatched on early exit', async () => {
|
||||
const service = createService();
|
||||
|
||||
const coordinator = (service as any).coordinator;
|
||||
@@ -143,8 +154,8 @@ describe('AgentRuntimeService.executeStep - early exit on terminal state', () =>
|
||||
lastModified: new Date().toISOString(),
|
||||
});
|
||||
|
||||
const onComplete = vi.fn();
|
||||
service.registerStepCallbacks('op-789', { onComplete });
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
const unregisterSpy = vi.spyOn(hookDispatcher, 'unregister');
|
||||
|
||||
await service.executeStep({
|
||||
operationId: 'op-789',
|
||||
@@ -152,8 +163,11 @@ describe('AgentRuntimeService.executeStep - early exit on terminal state', () =>
|
||||
context: { phase: 'user_input' } as any,
|
||||
});
|
||||
|
||||
// Callbacks should be unregistered after onComplete
|
||||
expect(service.getStepCallbacks('op-789')).toBeUndefined();
|
||||
// Hooks should be unregistered after completion dispatch
|
||||
expect(unregisterSpy).toHaveBeenCalledWith('op-789');
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
unregisterSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should NOT skip step when operation status is "running"', async () => {
|
||||
@@ -301,7 +315,7 @@ describe('AgentRuntimeService.executeStep - Redis failure in error handler', ()
|
||||
return service;
|
||||
};
|
||||
|
||||
it('should still call onComplete when Redis fails in catch block (ECONNRESET scenario)', async () => {
|
||||
it('should still dispatch onComplete hooks when Redis fails in catch block (ECONNRESET scenario)', async () => {
|
||||
const service = createService();
|
||||
const coordinator = (service as any).coordinator;
|
||||
const streamManager = (service as any).streamManager;
|
||||
@@ -336,10 +350,9 @@ describe('AgentRuntimeService.executeStep - Redis failure in error handler', ()
|
||||
// saveAgentState fails (Redis is down)
|
||||
coordinator.saveAgentState = vi.fn().mockRejectedValue(new Error('Redis ECONNRESET'));
|
||||
|
||||
const onComplete = vi.fn();
|
||||
service.registerStepCallbacks('op-redis-fail', { onComplete });
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
// executeStep re-throws the original error after running callbacks
|
||||
// executeStep re-throws the original error after running hooks
|
||||
await expect(
|
||||
service.executeStep({
|
||||
operationId: 'op-redis-fail',
|
||||
@@ -348,16 +361,21 @@ describe('AgentRuntimeService.executeStep - Redis failure in error handler', ()
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
// onComplete MUST be called even when Redis is completely down
|
||||
expect(onComplete).toHaveBeenCalledWith(
|
||||
// onComplete hooks MUST be dispatched even when Redis is completely down
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'op-redis-fail',
|
||||
'onComplete',
|
||||
expect.objectContaining({
|
||||
operationId: 'op-redis-fail',
|
||||
reason: 'error',
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should still trigger completion webhook when Redis fails in catch block', async () => {
|
||||
it('should still dispatch onError hooks when Redis fails in catch block', async () => {
|
||||
const service = createService();
|
||||
const coordinator = (service as any).coordinator;
|
||||
const streamManager = (service as any).streamManager;
|
||||
@@ -388,12 +406,9 @@ describe('AgentRuntimeService.executeStep - Redis failure in error handler', ()
|
||||
|
||||
coordinator.saveAgentState = vi.fn().mockRejectedValue(new Error('Redis ECONNRESET'));
|
||||
|
||||
// Spy on triggerCompletionWebhook
|
||||
const triggerSpy = vi
|
||||
.spyOn(service as any, 'triggerCompletionWebhook')
|
||||
.mockResolvedValue(undefined);
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
// executeStep re-throws the original error after running callbacks
|
||||
// executeStep re-throws the original error after running hooks
|
||||
await expect(
|
||||
service.executeStep({
|
||||
operationId: 'op-redis-webhook',
|
||||
@@ -402,12 +417,18 @@ describe('AgentRuntimeService.executeStep - Redis failure in error handler', ()
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
// Completion webhook MUST be triggered even when Redis is down
|
||||
expect(triggerSpy).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ status: 'error' }),
|
||||
// Both onComplete and onError hooks MUST be dispatched when reason is error
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'op-redis-webhook',
|
||||
'error',
|
||||
'onError',
|
||||
expect.objectContaining({
|
||||
operationId: 'op-redis-webhook',
|
||||
reason: 'error',
|
||||
}),
|
||||
undefined,
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should include stepCount in fallback error state when state reload fails', async () => {
|
||||
@@ -511,15 +532,23 @@ describe('AgentRuntimeService.executeStep - Redis failure in error handler', ()
|
||||
|
||||
coordinator.tryClaimStep = vi.fn().mockResolvedValue(true);
|
||||
|
||||
const stateWithWebhook = {
|
||||
const stateWithHooks = {
|
||||
status: 'running',
|
||||
stepCount: 5,
|
||||
lastModified: new Date().toISOString(),
|
||||
metadata: { completionWebhook: 'https://example.com/webhook' },
|
||||
metadata: {
|
||||
_hooks: [
|
||||
{
|
||||
id: 'test-hook',
|
||||
type: 'onComplete',
|
||||
webhook: { url: 'https://example.com/webhook' },
|
||||
},
|
||||
],
|
||||
},
|
||||
};
|
||||
|
||||
// loadAgentState always succeeds (returns state with webhook metadata)
|
||||
coordinator.loadAgentState = vi.fn().mockResolvedValue(stateWithWebhook);
|
||||
// loadAgentState always succeeds (returns state with hook metadata)
|
||||
coordinator.loadAgentState = vi.fn().mockResolvedValue(stateWithHooks);
|
||||
|
||||
// saveAgentState fails (write-only Redis failure)
|
||||
coordinator.saveAgentState = vi.fn().mockRejectedValue(new Error('Redis write failed'));
|
||||
@@ -532,8 +561,7 @@ describe('AgentRuntimeService.executeStep - Redis failure in error handler', ()
|
||||
return Promise.reject(new Error('Redis ECONNRESET'));
|
||||
});
|
||||
|
||||
const onComplete = vi.fn();
|
||||
service.registerStepCallbacks('op-save-fail', { onComplete });
|
||||
const dispatchSpy = vi.spyOn(hookDispatcher, 'dispatch').mockResolvedValue(undefined);
|
||||
|
||||
await expect(
|
||||
service.executeStep({
|
||||
@@ -543,18 +571,28 @@ describe('AgentRuntimeService.executeStep - Redis failure in error handler', ()
|
||||
}),
|
||||
).rejects.toThrow();
|
||||
|
||||
// onComplete must receive the full state with metadata (not a minimal fallback)
|
||||
expect(onComplete).toHaveBeenCalledWith(
|
||||
// onComplete hooks must be dispatched with the full state including metadata
|
||||
expect(dispatchSpy).toHaveBeenCalledWith(
|
||||
'op-save-fail',
|
||||
'onComplete',
|
||||
expect.objectContaining({
|
||||
operationId: 'op-save-fail',
|
||||
reason: 'error',
|
||||
finalState: expect.objectContaining({
|
||||
metadata: expect.objectContaining({
|
||||
completionWebhook: 'https://example.com/webhook',
|
||||
_hooks: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
id: 'test-hook',
|
||||
webhook: { url: 'https://example.com/webhook' },
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
status: 'error',
|
||||
}),
|
||||
operationId: 'op-save-fail',
|
||||
reason: 'error',
|
||||
}),
|
||||
expect.anything(),
|
||||
);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -0,0 +1,380 @@
|
||||
// @vitest-environment node
|
||||
/**
|
||||
* Integration test: hooks e2e chain
|
||||
*
|
||||
* Verifies the full data flow from AgentRuntimeService.executeStep
|
||||
* through HookDispatcher to hook handlers — with enriched step
|
||||
* presentation data that bot consumers depend on.
|
||||
*
|
||||
* This catches payload format regressions that unit tests miss because
|
||||
* they mock the dispatch layer.
|
||||
*/
|
||||
import { describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import { AgentRuntimeService } from '../AgentRuntimeService';
|
||||
import { hookDispatcher } from '../hooks';
|
||||
import type { AgentHookEvent } from '../hooks/types';
|
||||
|
||||
// ── Mocks ──────────────────────────────────────────
|
||||
vi.mock('@/envs/app', () => ({ appEnv: { APP_URL: 'http://localhost:3010' } }));
|
||||
vi.mock('@/database/models/message', () => ({
|
||||
MessageModel: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
vi.mock('@/server/modules/AgentRuntime', () => ({
|
||||
AgentRuntimeCoordinator: vi.fn().mockImplementation(() => ({
|
||||
createAgentOperation: vi.fn(),
|
||||
getOperationMetadata: vi.fn(),
|
||||
loadAgentState: vi.fn(),
|
||||
releaseStepLock: vi.fn().mockResolvedValue(undefined),
|
||||
saveAgentState: vi.fn(),
|
||||
saveStepResult: vi.fn(),
|
||||
tryClaimStep: vi.fn().mockResolvedValue(true),
|
||||
})),
|
||||
createStreamEventManager: vi.fn(() => ({
|
||||
cleanupOperation: vi.fn(),
|
||||
publishAgentRuntimeEnd: vi.fn(),
|
||||
publishAgentRuntimeInit: vi.fn(),
|
||||
publishStreamEvent: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
vi.mock('@/server/modules/AgentRuntime/RuntimeExecutors', () => ({
|
||||
createRuntimeExecutors: vi.fn(() => ({})),
|
||||
}));
|
||||
vi.mock('@/server/services/mcp', () => ({ mcpService: {} }));
|
||||
vi.mock('@/server/services/queue', () => ({
|
||||
QueueService: vi.fn().mockImplementation(() => ({
|
||||
getImpl: vi.fn(() => ({})),
|
||||
scheduleMessage: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
vi.mock('@/server/services/queue/impls', () => ({
|
||||
LocalQueueServiceImpl: class {},
|
||||
isQueueAgentRuntimeEnabled: vi.fn().mockReturnValue(false),
|
||||
}));
|
||||
vi.mock('@/server/services/toolExecution', () => ({
|
||||
ToolExecutionService: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
vi.mock('@/server/services/toolExecution/builtin', () => ({
|
||||
BuiltinToolsExecutor: vi.fn().mockImplementation(() => ({})),
|
||||
}));
|
||||
vi.mock('@lobechat/builtin-tools/dynamicInterventionAudits', () => ({
|
||||
dynamicInterventionAudits: [],
|
||||
}));
|
||||
|
||||
describe('Hooks integration — afterStep event carries step presentation data', () => {
|
||||
const createService = () => new AgentRuntimeService({} as any, 'user-1', { queueService: null });
|
||||
|
||||
it('should include content, stepType, totalTokens, toolsCalling in afterStep event', async () => {
|
||||
const service = createService();
|
||||
const coordinator = (service as any).coordinator;
|
||||
|
||||
// Simulate a running operation with afterStep hooks in metadata
|
||||
coordinator.loadAgentState.mockResolvedValue({
|
||||
createdAt: new Date().toISOString(),
|
||||
lastModified: new Date().toISOString(),
|
||||
messages: [{ content: 'Hello', role: 'user' }],
|
||||
metadata: {
|
||||
_hooks: [{ id: 'bot-step', type: 'afterStep', webhook: { url: '/test' } }],
|
||||
agentId: 'agent-1',
|
||||
userId: 'user-1',
|
||||
},
|
||||
operationId: 'op-1',
|
||||
status: 'running',
|
||||
stepCount: 0,
|
||||
usage: { llm: { tokens: { total: 150 } }, tools: { totalCalls: 0 } },
|
||||
});
|
||||
|
||||
// Mock runtime.step to return an LLM step with content
|
||||
// nextContext.phase is NOT tool_result, so content is extracted from llm_result event
|
||||
const stepResult = {
|
||||
events: [{ result: { content: 'Let me search for that.' }, type: 'llm_result' }],
|
||||
newState: {
|
||||
cost: { total: 0.01 },
|
||||
createdAt: new Date().toISOString(),
|
||||
messages: [
|
||||
{ content: 'Hello', role: 'user' },
|
||||
{ content: 'Let me search for that.', role: 'assistant' },
|
||||
],
|
||||
metadata: {
|
||||
_hooks: [{ id: 'bot-step', type: 'afterStep', webhook: { url: '/test' } }],
|
||||
agentId: 'agent-1',
|
||||
topicId: 'topic-1',
|
||||
userId: 'user-1',
|
||||
},
|
||||
status: 'running',
|
||||
stepCount: 1,
|
||||
usage: {
|
||||
llm: {
|
||||
apiCalls: 1,
|
||||
tokens: { input: 50, output: 100, total: 150 },
|
||||
},
|
||||
tools: { totalCalls: 0 },
|
||||
},
|
||||
},
|
||||
nextContext: {
|
||||
payload: { message: [{ content: 'Let me search for that.' }] },
|
||||
phase: 'user_input',
|
||||
session: { sessionId: 'op-1', status: 'running', stepCount: 1 },
|
||||
},
|
||||
};
|
||||
|
||||
vi.spyOn(service as any, 'createAgentRuntime').mockReturnValue({
|
||||
runtime: { step: vi.fn().mockResolvedValue(stepResult) },
|
||||
});
|
||||
|
||||
// Capture the actual hook event
|
||||
const capturedEvents: AgentHookEvent[] = [];
|
||||
const dispatchSpy = vi
|
||||
.spyOn(hookDispatcher, 'dispatch')
|
||||
.mockImplementation(async (_opId, type, event) => {
|
||||
if (type === 'afterStep') capturedEvents.push(event);
|
||||
});
|
||||
|
||||
await service.executeStep({
|
||||
context: { phase: 'user_input' } as any,
|
||||
operationId: 'op-1',
|
||||
stepIndex: 0,
|
||||
});
|
||||
|
||||
expect(capturedEvents).toHaveLength(1);
|
||||
const event = capturedEvents[0];
|
||||
|
||||
// ── Core identification ──
|
||||
expect(event.operationId).toBe('op-1');
|
||||
expect(event.agentId).toBe('agent-1');
|
||||
expect(event.userId).toBe('user-1');
|
||||
|
||||
// ── Step presentation data (what bot renderers need) ──
|
||||
expect(event.content).toBe('Let me search for that.');
|
||||
expect(event.stepType).toMatch(/call_llm|call_tool/);
|
||||
expect(typeof event.executionTimeMs).toBe('number');
|
||||
expect(event.totalTokens).toBe(150);
|
||||
expect(event.totalCost).toBe(0.01);
|
||||
expect(event.totalSteps).toBe(1);
|
||||
expect(event.shouldContinue).toBe(true);
|
||||
expect(event.topicId).toBe('topic-1');
|
||||
|
||||
// ── Tracking data (cross-step accumulator for bot progress) ──
|
||||
expect(typeof event.totalToolCalls).toBe('number');
|
||||
// elapsedMs should be calculated from state.createdAt
|
||||
expect(typeof event.elapsedMs).toBe('number');
|
||||
|
||||
// ── Full state available for local mode consumers ──
|
||||
expect(event.finalState).toBeDefined();
|
||||
expect(event.finalState.status).toBe('running');
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
|
||||
it('should include toolsResult for tool_result phase', async () => {
|
||||
const service = createService();
|
||||
const coordinator = (service as any).coordinator;
|
||||
|
||||
coordinator.loadAgentState.mockResolvedValue({
|
||||
createdAt: new Date().toISOString(),
|
||||
lastModified: new Date().toISOString(),
|
||||
messages: [],
|
||||
metadata: {
|
||||
_hooks: [{ id: 'bot-step', type: 'afterStep', webhook: { url: '/test' } }],
|
||||
_stepTracking: { lastLLMContent: 'previous content', totalToolCalls: 1 },
|
||||
agentId: 'agent-1',
|
||||
userId: 'user-1',
|
||||
},
|
||||
operationId: 'op-2',
|
||||
status: 'running',
|
||||
stepCount: 1,
|
||||
});
|
||||
|
||||
// stepResult.nextContext has tool_result phase — this is where toolsResult is extracted from
|
||||
const stepResult = {
|
||||
events: [{ type: 'done' }],
|
||||
newState: {
|
||||
createdAt: new Date().toISOString(),
|
||||
messages: [],
|
||||
metadata: {
|
||||
_hooks: [{ id: 'bot-step', type: 'afterStep', webhook: { url: '/test' } }],
|
||||
_stepTracking: { lastLLMContent: 'previous content', totalToolCalls: 1 },
|
||||
agentId: 'agent-1',
|
||||
userId: 'user-1',
|
||||
},
|
||||
status: 'running',
|
||||
stepCount: 2,
|
||||
usage: { llm: { tokens: { total: 200 } }, tools: { totalCalls: 1 } },
|
||||
},
|
||||
nextContext: {
|
||||
payload: {
|
||||
data: 'Search found 3 results',
|
||||
toolCall: { apiName: 'search', id: 'tc-1', identifier: 'lobe-web-browsing' },
|
||||
toolCallId: 'tc-1',
|
||||
},
|
||||
phase: 'tool_result',
|
||||
session: { sessionId: 'op-2', status: 'running', stepCount: 2 },
|
||||
},
|
||||
};
|
||||
|
||||
vi.spyOn(service as any, 'createAgentRuntime').mockReturnValue({
|
||||
runtime: { step: vi.fn().mockResolvedValue(stepResult) },
|
||||
});
|
||||
|
||||
const capturedEvents: AgentHookEvent[] = [];
|
||||
const dispatchSpy = vi
|
||||
.spyOn(hookDispatcher, 'dispatch')
|
||||
.mockImplementation(async (_opId, type, event) => {
|
||||
if (type === 'afterStep') capturedEvents.push(event);
|
||||
});
|
||||
|
||||
await service.executeStep({
|
||||
context: { phase: 'user_input' } as any,
|
||||
operationId: 'op-2',
|
||||
stepIndex: 1,
|
||||
});
|
||||
|
||||
expect(capturedEvents).toHaveLength(1);
|
||||
const event = capturedEvents[0];
|
||||
|
||||
// Tool result extracted from stepResult.nextContext.payload
|
||||
expect(event.toolsResult).toBeDefined();
|
||||
expect(event.toolsResult).toEqual([
|
||||
expect.objectContaining({
|
||||
apiName: 'search',
|
||||
identifier: 'lobe-web-browsing',
|
||||
output: 'Search found 3 results',
|
||||
}),
|
||||
]);
|
||||
|
||||
// Tracking data carries forward from previous steps
|
||||
expect(event.lastLLMContent).toBe('previous content');
|
||||
// totalToolCalls includes current step (1 previous + 0 new tool calls in this step)
|
||||
expect(event.totalToolCalls).toBe(1);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Hooks integration — onComplete event for early-terminal states', () => {
|
||||
const createService = () => new AgentRuntimeService({} as any, 'user-1', { queueService: null });
|
||||
|
||||
it('should dispatch onComplete with correct reason when operation is interrupted', async () => {
|
||||
const service = createService();
|
||||
const coordinator = (service as any).coordinator;
|
||||
|
||||
coordinator.loadAgentState.mockResolvedValue({
|
||||
createdAt: new Date().toISOString(),
|
||||
lastModified: new Date().toISOString(),
|
||||
messages: [
|
||||
{ content: 'Hello', role: 'user' },
|
||||
{ content: 'I was working on it...', role: 'assistant' },
|
||||
],
|
||||
metadata: { agentId: 'agent-1', userId: 'user-1' },
|
||||
status: 'interrupted',
|
||||
stepCount: 3,
|
||||
usage: { llm: { apiCalls: 2, tokens: { total: 500 } }, tools: { totalCalls: 1 } },
|
||||
});
|
||||
|
||||
const capturedEvents: AgentHookEvent[] = [];
|
||||
const dispatchSpy = vi
|
||||
.spyOn(hookDispatcher, 'dispatch')
|
||||
.mockImplementation(async (_opId, type, event) => {
|
||||
if (type === 'onComplete') capturedEvents.push(event);
|
||||
});
|
||||
|
||||
await service.executeStep({
|
||||
context: { phase: 'user_input' } as any,
|
||||
operationId: 'op-interrupted',
|
||||
stepIndex: 4,
|
||||
});
|
||||
|
||||
expect(capturedEvents).toHaveLength(1);
|
||||
const event = capturedEvents[0];
|
||||
|
||||
expect(event.reason).toBe('interrupted');
|
||||
expect(event.operationId).toBe('op-interrupted');
|
||||
expect(event.lastAssistantContent).toBe('I was working on it...');
|
||||
expect(event.finalState).toBeDefined();
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe('Hooks integration — afterStep event is compatible with renderStepProgress', () => {
|
||||
const createService = () => new AgentRuntimeService({} as any, 'user-1', { queueService: null });
|
||||
|
||||
it('afterStep event fields map to RenderStepParams without undefined required fields', async () => {
|
||||
const service = createService();
|
||||
const coordinator = (service as any).coordinator;
|
||||
|
||||
coordinator.loadAgentState.mockResolvedValue({
|
||||
createdAt: new Date().toISOString(),
|
||||
lastModified: new Date().toISOString(),
|
||||
messages: [],
|
||||
metadata: {
|
||||
_hooks: [{ id: 'bot-step', type: 'afterStep', webhook: { url: '/test' } }],
|
||||
agentId: 'agent-1',
|
||||
userId: 'user-1',
|
||||
},
|
||||
operationId: 'op-compat',
|
||||
status: 'running',
|
||||
stepCount: 0,
|
||||
});
|
||||
|
||||
const stepResult = {
|
||||
events: [{ type: 'done' }],
|
||||
newState: {
|
||||
createdAt: new Date().toISOString(),
|
||||
messages: [{ content: 'Result', role: 'assistant' }],
|
||||
metadata: {
|
||||
_hooks: [{ id: 'bot-step', type: 'afterStep', webhook: { url: '/test' } }],
|
||||
agentId: 'agent-1',
|
||||
userId: 'user-1',
|
||||
},
|
||||
status: 'done',
|
||||
stepCount: 1,
|
||||
usage: { llm: { tokens: { total: 100 } } },
|
||||
},
|
||||
nextContext: null,
|
||||
};
|
||||
|
||||
vi.spyOn(service as any, 'createAgentRuntime').mockReturnValue({
|
||||
runtime: { step: vi.fn().mockResolvedValue(stepResult) },
|
||||
});
|
||||
|
||||
const capturedEvents: AgentHookEvent[] = [];
|
||||
const dispatchSpy = vi
|
||||
.spyOn(hookDispatcher, 'dispatch')
|
||||
.mockImplementation(async (_opId, type, event) => {
|
||||
if (type === 'afterStep') capturedEvents.push(event);
|
||||
});
|
||||
|
||||
await service.executeStep({
|
||||
context: { phase: 'user_input' } as any,
|
||||
operationId: 'op-compat',
|
||||
stepIndex: 0,
|
||||
});
|
||||
|
||||
expect(capturedEvents).toHaveLength(1);
|
||||
const event = capturedEvents[0];
|
||||
|
||||
// Verify all fields needed by renderStepProgress are present and typed correctly
|
||||
// These map to RenderStepParams = StepPresentationData + { elapsedMs, lastContent, lastToolsCalling, totalToolCalls }
|
||||
expect(event.stepType).toBeDefined();
|
||||
expect(['call_llm', 'call_tool']).toContain(event.stepType);
|
||||
expect(typeof event.executionTimeMs).toBe('number');
|
||||
expect(typeof event.totalSteps).toBe('number');
|
||||
expect(typeof event.totalTokens).toBe('number');
|
||||
expect(typeof event.totalCost).toBe('number');
|
||||
expect(typeof event.totalInputTokens).toBe('number');
|
||||
expect(typeof event.totalOutputTokens).toBe('number');
|
||||
expect(typeof event.thinking).toBe('boolean');
|
||||
// These can be undefined but must be present as keys
|
||||
expect('content' in event).toBe(true);
|
||||
expect('reasoning' in event).toBe(true);
|
||||
expect('toolsCalling' in event).toBe(true);
|
||||
expect('toolsResult' in event).toBe(true);
|
||||
expect('elapsedMs' in event).toBe(true);
|
||||
expect('lastLLMContent' in event).toBe(true);
|
||||
expect('lastToolsCalling' in event).toBe(true);
|
||||
expect('totalToolCalls' in event).toBe(true);
|
||||
|
||||
dispatchSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
@@ -1,346 +0,0 @@
|
||||
import { type AgentRuntimeContext } from '@lobechat/agent-runtime';
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest';
|
||||
|
||||
import {
|
||||
InMemoryAgentStateManager,
|
||||
InMemoryStreamEventManager,
|
||||
} from '@/server/modules/AgentRuntime';
|
||||
|
||||
import { AgentRuntimeService } from '../AgentRuntimeService';
|
||||
import { type StepCompletionReason, type StepLifecycleCallbacks } from '../types';
|
||||
|
||||
// Mock database models
|
||||
vi.mock('@/database/models/message', () => ({
|
||||
MessageModel: vi.fn().mockImplementation(() => ({
|
||||
create: vi.fn().mockResolvedValue({ id: 'msg-1' }),
|
||||
query: vi.fn().mockResolvedValue([]),
|
||||
update: vi.fn().mockResolvedValue({}),
|
||||
})),
|
||||
}));
|
||||
|
||||
// Mock ModelRuntime
|
||||
vi.mock('@/server/modules/ModelRuntime', () => ({
|
||||
ApiKeyManager: vi.fn().mockImplementation(() => ({
|
||||
getAllApiKeys: vi.fn(),
|
||||
getApiKey: vi.fn(),
|
||||
})),
|
||||
initializeRuntimeOptions: vi.fn(),
|
||||
initModelRuntimeFromDB: vi.fn().mockResolvedValue({
|
||||
chat: vi.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
// Mock search service
|
||||
vi.mock('@/server/services/search', () => ({
|
||||
searchService: {
|
||||
search: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock MCP service
|
||||
vi.mock('@/server/services/mcp', () => ({
|
||||
mcpService: {
|
||||
executeCommand: vi.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
// Mock tool execution service
|
||||
vi.mock('@/server/services/toolExecution', () => ({
|
||||
ToolExecutionService: vi.fn().mockImplementation(() => ({
|
||||
executeToolCall: vi.fn().mockResolvedValue({ result: 'success' }),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('@/server/services/toolExecution/builtin', () => ({
|
||||
BuiltinToolsExecutor: vi.fn().mockImplementation(() => ({
|
||||
execute: vi.fn(),
|
||||
})),
|
||||
}));
|
||||
|
||||
describe('AgentRuntimeService - Step Lifecycle Callbacks', () => {
|
||||
let service: AgentRuntimeService;
|
||||
let stateManager: InMemoryAgentStateManager;
|
||||
let streamEventManager: InMemoryStreamEventManager;
|
||||
|
||||
const mockDb = {} as any;
|
||||
const userId = 'test-user-id';
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Create in-memory managers
|
||||
stateManager = new InMemoryAgentStateManager();
|
||||
streamEventManager = new InMemoryStreamEventManager();
|
||||
|
||||
// Create service with in-memory implementations and no queue
|
||||
service = new AgentRuntimeService(mockDb, userId, {
|
||||
coordinatorOptions: {
|
||||
stateManager,
|
||||
streamEventManager,
|
||||
},
|
||||
queueService: null, // Disable queue for sync execution
|
||||
streamEventManager,
|
||||
});
|
||||
});
|
||||
|
||||
describe('registerStepCallbacks', () => {
|
||||
it('should register callbacks for an operation', () => {
|
||||
const operationId = 'test-op-1';
|
||||
const callbacks: StepLifecycleCallbacks = {
|
||||
onAfterStep: vi.fn(),
|
||||
onBeforeStep: vi.fn(),
|
||||
onComplete: vi.fn(),
|
||||
};
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBe(callbacks);
|
||||
});
|
||||
|
||||
it('should overwrite existing callbacks if registered again', () => {
|
||||
const operationId = 'test-op-2';
|
||||
const callbacks1: StepLifecycleCallbacks = { onBeforeStep: vi.fn() };
|
||||
const callbacks2: StepLifecycleCallbacks = { onAfterStep: vi.fn() };
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks1);
|
||||
service.registerStepCallbacks(operationId, callbacks2);
|
||||
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBe(callbacks2);
|
||||
});
|
||||
});
|
||||
|
||||
describe('unregisterStepCallbacks', () => {
|
||||
it('should remove registered callbacks', () => {
|
||||
const operationId = 'test-op-3';
|
||||
const callbacks: StepLifecycleCallbacks = { onBeforeStep: vi.fn() };
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
service.unregisterStepCallbacks(operationId);
|
||||
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should not throw when unregistering non-existent callbacks', () => {
|
||||
expect(() => {
|
||||
service.unregisterStepCallbacks('non-existent-op');
|
||||
}).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe('getStepCallbacks', () => {
|
||||
it('should return undefined for non-existent operation', () => {
|
||||
const registered = service.getStepCallbacks('non-existent-op');
|
||||
expect(registered).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('createOperation with stepCallbacks', () => {
|
||||
it('should register callbacks when provided in createOperation params', async () => {
|
||||
const operationId = 'test-op-with-callbacks';
|
||||
const callbacks: StepLifecycleCallbacks = {
|
||||
onAfterStep: vi.fn(),
|
||||
onBeforeStep: vi.fn(),
|
||||
onComplete: vi.fn(),
|
||||
};
|
||||
|
||||
const initialContext: AgentRuntimeContext = {
|
||||
payload: { message: [{ content: 'Hello' }] },
|
||||
phase: 'user_input',
|
||||
session: {
|
||||
messageCount: 1,
|
||||
sessionId: operationId,
|
||||
status: 'idle',
|
||||
stepCount: 0,
|
||||
},
|
||||
};
|
||||
|
||||
await service.createOperation({
|
||||
agentConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
appContext: { agentId: 'test-agent' },
|
||||
autoStart: false,
|
||||
initialContext,
|
||||
initialMessages: [{ content: 'Hello', role: 'user' }],
|
||||
modelRuntimeConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
operationId,
|
||||
stepCallbacks: callbacks,
|
||||
toolSet: { manifestMap: {}, tools: [] },
|
||||
userId,
|
||||
});
|
||||
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBe(callbacks);
|
||||
});
|
||||
|
||||
it('should not register callbacks when not provided', async () => {
|
||||
const operationId = 'test-op-no-callbacks';
|
||||
|
||||
const initialContext: AgentRuntimeContext = {
|
||||
payload: { message: [{ content: 'Hello' }] },
|
||||
phase: 'user_input',
|
||||
session: {
|
||||
messageCount: 1,
|
||||
sessionId: operationId,
|
||||
status: 'idle',
|
||||
stepCount: 0,
|
||||
},
|
||||
};
|
||||
|
||||
await service.createOperation({
|
||||
agentConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
appContext: { agentId: 'test-agent' },
|
||||
autoStart: false,
|
||||
initialContext,
|
||||
initialMessages: [{ content: 'Hello', role: 'user' }],
|
||||
modelRuntimeConfig: { model: 'gpt-4o', provider: 'openai' },
|
||||
operationId,
|
||||
toolSet: { manifestMap: {}, tools: [] },
|
||||
userId,
|
||||
});
|
||||
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('callback invocation tracking', () => {
|
||||
it('should track callback calls with correct parameters', async () => {
|
||||
const operationId = 'callback-tracking-test';
|
||||
|
||||
const onBeforeStepCalls: Array<{ operationId: string; stepIndex: number }> = [];
|
||||
const onAfterStepCalls: Array<{
|
||||
operationId: string;
|
||||
shouldContinue: boolean;
|
||||
stepIndex: number;
|
||||
}> = [];
|
||||
const onCompleteCalls: Array<{ operationId: string; reason: StepCompletionReason }> = [];
|
||||
|
||||
const callbacks: StepLifecycleCallbacks = {
|
||||
onAfterStep: async (params) => {
|
||||
onAfterStepCalls.push({
|
||||
operationId: params.operationId,
|
||||
shouldContinue: params.shouldContinue,
|
||||
stepIndex: params.stepIndex,
|
||||
});
|
||||
},
|
||||
onBeforeStep: async (params) => {
|
||||
onBeforeStepCalls.push({
|
||||
operationId: params.operationId,
|
||||
stepIndex: params.stepIndex,
|
||||
});
|
||||
},
|
||||
onComplete: async (params) => {
|
||||
onCompleteCalls.push({
|
||||
operationId: params.operationId,
|
||||
reason: params.reason,
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
// Verify callbacks structure is correct
|
||||
expect(callbacks.onBeforeStep).toBeDefined();
|
||||
expect(callbacks.onAfterStep).toBeDefined();
|
||||
expect(callbacks.onComplete).toBeDefined();
|
||||
|
||||
// Register callbacks
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
|
||||
// Verify they are registered
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBe(callbacks);
|
||||
});
|
||||
});
|
||||
|
||||
describe('callback error handling', () => {
|
||||
it('should not throw when onBeforeStep callback throws', async () => {
|
||||
const operationId = 'error-test-before';
|
||||
const callbacks: StepLifecycleCallbacks = {
|
||||
onBeforeStep: async () => {
|
||||
throw new Error('onBeforeStep error');
|
||||
},
|
||||
};
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
|
||||
// The callback is registered, verify it exists
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBe(callbacks);
|
||||
expect(registered?.onBeforeStep).toBeDefined();
|
||||
});
|
||||
|
||||
it('should not throw when onAfterStep callback throws', async () => {
|
||||
const operationId = 'error-test-after';
|
||||
const callbacks: StepLifecycleCallbacks = {
|
||||
onAfterStep: async () => {
|
||||
throw new Error('onAfterStep error');
|
||||
},
|
||||
};
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
|
||||
// The callback is registered, verify it exists
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBe(callbacks);
|
||||
expect(registered?.onAfterStep).toBeDefined();
|
||||
});
|
||||
|
||||
it('should not throw when onComplete callback throws', async () => {
|
||||
const operationId = 'error-test-complete';
|
||||
const callbacks: StepLifecycleCallbacks = {
|
||||
onComplete: async () => {
|
||||
throw new Error('onComplete error');
|
||||
},
|
||||
};
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
|
||||
// The callback is registered, verify it exists
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered).toBe(callbacks);
|
||||
expect(registered?.onComplete).toBeDefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe('partial callbacks', () => {
|
||||
it('should work with only onBeforeStep callback', async () => {
|
||||
const operationId = 'partial-before';
|
||||
const onBeforeStep = vi.fn();
|
||||
const callbacks: StepLifecycleCallbacks = { onBeforeStep };
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered?.onBeforeStep).toBe(onBeforeStep);
|
||||
expect(registered?.onAfterStep).toBeUndefined();
|
||||
expect(registered?.onComplete).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should work with only onAfterStep callback', async () => {
|
||||
const operationId = 'partial-after';
|
||||
const onAfterStep = vi.fn();
|
||||
const callbacks: StepLifecycleCallbacks = { onAfterStep };
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered?.onBeforeStep).toBeUndefined();
|
||||
expect(registered?.onAfterStep).toBe(onAfterStep);
|
||||
expect(registered?.onComplete).toBeUndefined();
|
||||
});
|
||||
|
||||
it('should work with only onComplete callback', async () => {
|
||||
const operationId = 'partial-complete';
|
||||
const onComplete = vi.fn();
|
||||
const callbacks: StepLifecycleCallbacks = { onComplete };
|
||||
|
||||
service.registerStepCallbacks(operationId, callbacks);
|
||||
|
||||
const registered = service.getStepCallbacks(operationId);
|
||||
expect(registered?.onBeforeStep).toBeUndefined();
|
||||
expect(registered?.onAfterStep).toBeUndefined();
|
||||
expect(registered?.onComplete).toBe(onComplete);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -56,13 +56,19 @@ export interface AgentHookWebhook {
|
||||
export interface AgentHookEvent {
|
||||
// Identification
|
||||
agentId: string;
|
||||
/** LLM text output (afterStep only) */
|
||||
content?: string;
|
||||
// Statistics
|
||||
cost?: number;
|
||||
duration?: number;
|
||||
/** Elapsed time since operation started in ms (afterStep only) */
|
||||
elapsedMs?: number;
|
||||
// Content
|
||||
errorDetail?: string;
|
||||
|
||||
errorMessage?: string;
|
||||
/** Step execution time in ms (afterStep only) */
|
||||
executionTimeMs?: number;
|
||||
|
||||
/**
|
||||
* Full AgentState — only available in local mode.
|
||||
@@ -71,6 +77,10 @@ export interface AgentHookEvent {
|
||||
*/
|
||||
finalState?: any;
|
||||
lastAssistantContent?: string;
|
||||
/** Last LLM content from previous steps — for showing context during tool execution (afterStep only) */
|
||||
lastLLMContent?: string;
|
||||
/** Last tools calling from previous steps (afterStep only) */
|
||||
lastToolsCalling?: any;
|
||||
|
||||
llmCalls?: number;
|
||||
// Caller-provided metadata (from webhook.body)
|
||||
@@ -78,17 +88,37 @@ export interface AgentHookEvent {
|
||||
operationId: string;
|
||||
// Execution result
|
||||
reason?: string; // 'done' | 'error' | 'interrupted' | 'max_steps' | 'cost_limit'
|
||||
/** LLM reasoning / thinking content (afterStep only) */
|
||||
reasoning?: string;
|
||||
// Step-specific (for beforeStep/afterStep)
|
||||
shouldContinue?: boolean;
|
||||
status?: string; // 'done' | 'error' | 'interrupted' | 'waiting_for_human'
|
||||
/** Step cost (afterStep only, LLM steps) */
|
||||
stepCost?: number;
|
||||
|
||||
stepIndex?: number;
|
||||
steps?: number;
|
||||
stepType?: string; // 'call_llm' | 'call_tool'
|
||||
/** Whether next step is LLM thinking (afterStep only) */
|
||||
thinking?: boolean;
|
||||
|
||||
toolCalls?: number;
|
||||
/** Tools the LLM decided to call (afterStep only) */
|
||||
toolsCalling?: any;
|
||||
/** Results from tool execution (afterStep only) */
|
||||
toolsResult?: any;
|
||||
topicId?: string;
|
||||
/** Cumulative total cost (afterStep only) */
|
||||
totalCost?: number;
|
||||
/** Cumulative input tokens (afterStep only) */
|
||||
totalInputTokens?: number;
|
||||
/** Cumulative output tokens (afterStep only) */
|
||||
totalOutputTokens?: number;
|
||||
/** Total steps executed so far (afterStep only) */
|
||||
totalSteps?: number;
|
||||
totalTokens?: number;
|
||||
/** Running total of tool calls across all steps (afterStep only) */
|
||||
totalToolCalls?: number;
|
||||
|
||||
userId: string;
|
||||
}
|
||||
|
||||
@@ -144,15 +144,6 @@ export interface OperationCreationParams {
|
||||
autoStart?: boolean;
|
||||
/** Bot platform context for injecting platform capabilities (e.g. markdown support) */
|
||||
botPlatformContext?: any;
|
||||
/**
|
||||
* Completion webhook configuration
|
||||
* When set, an HTTP POST will be fired when the operation completes (success or error).
|
||||
* The webhook is persisted in Redis state so it survives across QStash step boundaries.
|
||||
*/
|
||||
completionWebhook?: {
|
||||
body?: Record<string, unknown>;
|
||||
url: string;
|
||||
};
|
||||
/** Device system info for placeholder variable replacement in Local System systemRole */
|
||||
deviceSystemInfo?: Record<string, string>;
|
||||
/** Discord context for injecting channel/guild info into agent system message */
|
||||
@@ -176,20 +167,6 @@ export interface OperationCreationParams {
|
||||
queueRetryDelay?: string;
|
||||
/** Abort startup before the first step is scheduled */
|
||||
signal?: AbortSignal;
|
||||
/**
|
||||
* Step lifecycle callbacks
|
||||
* Used to inject custom logic at different stages of step execution
|
||||
*/
|
||||
stepCallbacks?: StepLifecycleCallbacks;
|
||||
/**
|
||||
* Step webhook configuration
|
||||
* When set, an HTTP POST will be fired after each step completes.
|
||||
* Persisted in Redis state so it survives across QStash step boundaries.
|
||||
*/
|
||||
stepWebhook?: {
|
||||
body?: Record<string, unknown>;
|
||||
url: string;
|
||||
};
|
||||
/**
|
||||
* Whether the LLM call should use streaming.
|
||||
* Defaults to true. Set to false for non-streaming scenarios (e.g., bot integrations).
|
||||
@@ -207,12 +184,6 @@ export interface OperationCreationParams {
|
||||
userMemory?: ServerUserMemoryConfig;
|
||||
/** User's timezone from settings (e.g. 'Asia/Shanghai') */
|
||||
userTimezone?: string;
|
||||
/**
|
||||
* Webhook delivery method.
|
||||
* - 'fetch': plain HTTP POST (default)
|
||||
* - 'qstash': deliver via QStash publishJSON for guaranteed delivery
|
||||
*/
|
||||
webhookDelivery?: 'fetch' | 'qstash';
|
||||
}
|
||||
|
||||
export interface OperationCreationResult {
|
||||
|
||||
@@ -94,14 +94,6 @@ interface InternalExecAgentParams extends ExecAgentParams {
|
||||
botContext?: ChatTopicBotContext;
|
||||
/** Bot platform context for injecting platform capabilities (e.g. markdown support) */
|
||||
botPlatformContext?: any;
|
||||
/**
|
||||
* Completion webhook configuration
|
||||
* Persisted in Redis state, triggered via HTTP POST when the operation completes.
|
||||
*/
|
||||
completionWebhook?: {
|
||||
body?: Record<string, unknown>;
|
||||
url: string;
|
||||
};
|
||||
/** Cron job ID that triggered this execution (if trigger is 'cron') */
|
||||
cronJobId?: string;
|
||||
/** Disable all tools (no plugins, no system manifests). Useful for eval/benchmark scenarios. */
|
||||
@@ -136,16 +128,6 @@ interface InternalExecAgentParams extends ExecAgentParams {
|
||||
resume?: boolean;
|
||||
/** Abort startup before the agent runtime operation is created */
|
||||
signal?: AbortSignal;
|
||||
/** Step lifecycle callbacks for operation tracking (server-side only) */
|
||||
stepCallbacks?: StepLifecycleCallbacks;
|
||||
/**
|
||||
* Step webhook configuration
|
||||
* Persisted in Redis state, triggered via HTTP POST after each step completes.
|
||||
*/
|
||||
stepWebhook?: {
|
||||
body?: Record<string, unknown>;
|
||||
url: string;
|
||||
};
|
||||
/**
|
||||
* Whether the LLM call should use streaming.
|
||||
* Defaults to true. Set to false for non-streaming scenarios (e.g., bot integrations).
|
||||
@@ -166,12 +148,6 @@ interface InternalExecAgentParams extends ExecAgentParams {
|
||||
* Use { approvalMode: 'headless' } for async tasks that should never wait for human approval
|
||||
*/
|
||||
userInterventionConfig?: UserInterventionConfig;
|
||||
/**
|
||||
* Webhook delivery method.
|
||||
* - 'fetch': plain HTTP POST (default)
|
||||
* - 'qstash': deliver via QStash publishJSON for guaranteed delivery
|
||||
*/
|
||||
webhookDelivery?: 'fetch' | 'qstash';
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -247,7 +223,6 @@ export class AiAgentService {
|
||||
instructions,
|
||||
model: modelOverride,
|
||||
provider: providerOverride,
|
||||
stepCallbacks,
|
||||
stream,
|
||||
title,
|
||||
trigger,
|
||||
@@ -258,11 +233,8 @@ export class AiAgentService {
|
||||
initialStepCount,
|
||||
signal,
|
||||
userInterventionConfig,
|
||||
completionWebhook,
|
||||
queueRetries,
|
||||
queueRetryDelay,
|
||||
stepWebhook,
|
||||
webhookDelivery,
|
||||
parentMessageId,
|
||||
resume,
|
||||
} = params;
|
||||
@@ -1026,7 +998,6 @@ export class AiAgentService {
|
||||
},
|
||||
autoStart,
|
||||
botPlatformContext,
|
||||
completionWebhook,
|
||||
discordContext,
|
||||
evalContext,
|
||||
initialContext,
|
||||
@@ -1037,8 +1008,6 @@ export class AiAgentService {
|
||||
hooks,
|
||||
operationId,
|
||||
signal,
|
||||
stepCallbacks,
|
||||
stepWebhook,
|
||||
queueRetries,
|
||||
queueRetryDelay,
|
||||
stream,
|
||||
@@ -1052,7 +1021,6 @@ export class AiAgentService {
|
||||
userId: this.userId,
|
||||
userInterventionConfig,
|
||||
userMemory,
|
||||
webhookDelivery,
|
||||
});
|
||||
|
||||
log('execAgent: created operation %s (autoStarted: %s)', operationId, result.autoStarted);
|
||||
|
||||
@@ -3,12 +3,10 @@ import { RequestTrigger } from '@lobechat/types';
|
||||
import type { Message, SentMessage, Thread } from 'chat';
|
||||
import { emoji } from 'chat';
|
||||
import debug from 'debug';
|
||||
import urlJoin from 'url-join';
|
||||
|
||||
import { TopicModel } from '@/database/models/topic';
|
||||
import { UserModel } from '@/database/models/user';
|
||||
import type { LobeChatDatabase } from '@/database/type';
|
||||
import { appEnv } from '@/envs/app';
|
||||
import { createAbortError, isAbortError } from '@/server/services/agentRuntime/abort';
|
||||
import { AiAgentService } from '@/server/services/aiAgent';
|
||||
import { isQueueAgentRuntimeEnabled } from '@/server/services/queue/impls';
|
||||
@@ -30,6 +28,10 @@ const log = debug('lobe-server:bot:agent-bridge');
|
||||
|
||||
const EXECUTION_TIMEOUT = 30 * 60 * 1000; // 30 minutes
|
||||
|
||||
// If the last activity in a bot topic is older than this threshold,
|
||||
// create a new topic instead of continuing in the stale one.
|
||||
const TOPIC_STALE_THRESHOLD = 4 * 60 * 60 * 1000; // 4 hours
|
||||
|
||||
// PostgreSQL error code for foreign key constraint violations.
|
||||
// See: https://www.postgresql.org/docs/current/errcodes-appendix.html
|
||||
const PG_FOREIGN_KEY_VIOLATION = '23503';
|
||||
@@ -340,7 +342,11 @@ export class AgentBridgeService {
|
||||
return this.handleMention(thread, message, opts);
|
||||
}
|
||||
|
||||
// Skip if there's already an active execution for this thread
|
||||
// Skip if there's already an active execution for this thread.
|
||||
// This must run before the stale-topic check to prevent a race where
|
||||
// a concurrent message clears topicId (stale reset) and then no-ops
|
||||
// in handleMention because the thread is active — dropping the message
|
||||
// but leaving state cleared so the next message starts a fresh topic.
|
||||
if (AgentBridgeService.activeThreads.has(thread.id)) {
|
||||
log(
|
||||
'handleSubscribedMessage: skipping, thread=%s already has an active execution',
|
||||
@@ -349,6 +355,33 @@ export class AgentBridgeService {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if the topic is stale (no activity for 4+ hours).
|
||||
// If so, clear the cached topicId and start a fresh conversation.
|
||||
// Wrapped in try/catch so transient DB errors fall through to the
|
||||
// existing topicId rather than rejecting before the guarded section.
|
||||
try {
|
||||
const topicModel = new TopicModel(this.db, this.userId);
|
||||
const existingTopic = await topicModel.findById(topicId);
|
||||
if (existingTopic) {
|
||||
const elapsed = Date.now() - new Date(existingTopic.updatedAt).getTime();
|
||||
if (elapsed > TOPIC_STALE_THRESHOLD) {
|
||||
log(
|
||||
'handleSubscribedMessage: topic=%s is stale (%.1fh since last activity), creating new topic',
|
||||
topicId,
|
||||
elapsed / (60 * 60 * 1000),
|
||||
);
|
||||
await thread.setState({ ...threadState, topicId: undefined });
|
||||
return this.handleMention(thread, message, opts);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
log(
|
||||
'handleSubscribedMessage: stale-topic lookup failed, continuing with existing topicId=%s: %O',
|
||||
topicId,
|
||||
error,
|
||||
);
|
||||
}
|
||||
|
||||
AgentBridgeService.activeThreads.add(thread.id);
|
||||
|
||||
// Read cached channel context from thread state
|
||||
@@ -408,7 +441,10 @@ export class AgentBridgeService {
|
||||
}
|
||||
|
||||
/**
|
||||
* Dispatch to queue-mode webhooks or local in-memory callbacks based on runtime mode.
|
||||
* Execute agent with unified hooks — auto-adapts to local or queue mode.
|
||||
*
|
||||
* Local mode: hooks run in-process, Promise resolves when agent completes.
|
||||
* Queue mode: hooks deliver via webhooks, returns immediately after startup.
|
||||
*/
|
||||
private async executeWithCallback(
|
||||
thread: Thread<ThreadState>,
|
||||
@@ -436,35 +472,9 @@ export class AgentBridgeService {
|
||||
}
|
||||
}
|
||||
|
||||
const optsWithPlatform = { ...opts, botPlatformContext };
|
||||
|
||||
if (isQueueAgentRuntimeEnabled()) {
|
||||
return this.executeWithWebhooks(thread, userMessage, optsWithPlatform);
|
||||
}
|
||||
return this.executeWithInMemoryCallbacks(thread, userMessage, optsWithPlatform);
|
||||
}
|
||||
|
||||
/**
|
||||
* Queue mode: post initial message, configure step/completion webhooks,
|
||||
* then return immediately. Progress updates and final reply are handled
|
||||
* by the bot-callback webhook endpoint.
|
||||
*/
|
||||
private async executeWithWebhooks(
|
||||
thread: Thread<ThreadState>,
|
||||
userMessage: Message,
|
||||
opts: {
|
||||
agentId: string;
|
||||
botContext?: ChatTopicBotContext;
|
||||
botPlatformContext?: { platformName: string; supportsMarkdown: boolean };
|
||||
channelContext?: DiscordChannelContext;
|
||||
client?: PlatformClient;
|
||||
topicId?: string;
|
||||
trigger?: string;
|
||||
},
|
||||
): Promise<{ reply: string; topicId: string }> {
|
||||
const { agentId, botContext, botPlatformContext, channelContext, client, topicId, trigger } =
|
||||
opts;
|
||||
const { agentId, botContext, channelContext, charLimit, client, displayToolCalls, topicId, trigger } = opts;
|
||||
|
||||
const queueMode = isQueueAgentRuntimeEnabled();
|
||||
const aiAgentService = new AiAgentService(this.db, this.userId);
|
||||
const timezone = await this.loadTimezone();
|
||||
|
||||
@@ -474,38 +484,100 @@ export class AgentBridgeService {
|
||||
try {
|
||||
progressMessage = await thread.post(renderStart(userMessage.text, { timezone }));
|
||||
} catch (error) {
|
||||
log('executeWithWebhooks: failed to post initial placeholder message: %O', error);
|
||||
log('executeWithCallback: failed to post initial placeholder message: %O', error);
|
||||
}
|
||||
|
||||
const progressMessageId: string | undefined = progressMessage?.id;
|
||||
|
||||
// Build webhook URL for bot-callback endpoint
|
||||
// Prefer INTERNAL_APP_URL for server-to-server calls (bypasses CDN/proxy)
|
||||
const baseURL = appEnv.INTERNAL_APP_URL || appEnv.APP_URL;
|
||||
if (!baseURL) {
|
||||
throw new Error('APP_URL is required for queue mode bot webhooks');
|
||||
}
|
||||
const callbackUrl = urlJoin(baseURL, '/api/agent/webhooks/bot-callback');
|
||||
|
||||
const webhookBody = {
|
||||
applicationId: botContext?.applicationId,
|
||||
platformThreadId: botContext?.platformThreadId,
|
||||
progressMessageId,
|
||||
userMessageId: userMessage.id,
|
||||
};
|
||||
|
||||
const files = this.extractFiles(userMessage);
|
||||
const prompt = this.formatPrompt(userMessage, client);
|
||||
|
||||
// Build webhook config for production mode
|
||||
const callbackUrl = '/api/agent/webhooks/bot-callback';
|
||||
const webhookBody = {
|
||||
applicationId: botContext?.applicationId,
|
||||
platformThreadId: botContext?.platformThreadId,
|
||||
progressMessageId: progressMessage?.id,
|
||||
userMessageId: userMessage.id,
|
||||
};
|
||||
|
||||
log(
|
||||
'executeWithWebhooks: agentId=%s, callbackUrl=%s, progressMessageId=%s, prompt=%s, files=%d',
|
||||
'executeWithCallback: agentId=%s, queueMode=%s, prompt=%s, files=%d',
|
||||
agentId,
|
||||
callbackUrl,
|
||||
progressMessageId,
|
||||
queueMode,
|
||||
prompt.slice(0, 100),
|
||||
files?.length ?? 0,
|
||||
);
|
||||
|
||||
// In queue mode, return immediately after startup — hooks handle the rest via webhooks
|
||||
if (queueMode) {
|
||||
return this.executeWithHooksQueueMode(thread, userMessage, aiAgentService, {
|
||||
agentId,
|
||||
botContext,
|
||||
botPlatformContext,
|
||||
callbackUrl,
|
||||
channelContext,
|
||||
files,
|
||||
progressMessage,
|
||||
prompt,
|
||||
topicId,
|
||||
trigger,
|
||||
webhookBody,
|
||||
});
|
||||
}
|
||||
|
||||
// In local mode, wrap in a Promise — hook handlers resolve/reject it in-process
|
||||
return this.executeWithHooksLocalMode(thread, aiAgentService, {
|
||||
agentId,
|
||||
botContext,
|
||||
botPlatformContext,
|
||||
callbackUrl,
|
||||
charLimit,
|
||||
channelContext,
|
||||
client,
|
||||
displayToolCalls,
|
||||
files,
|
||||
progressMessage,
|
||||
prompt,
|
||||
topicId,
|
||||
trigger,
|
||||
webhookBody,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Queue mode: register hooks with webhook config, start agent, return immediately.
|
||||
*/
|
||||
private async executeWithHooksQueueMode(
|
||||
thread: Thread<ThreadState>,
|
||||
userMessage: Message,
|
||||
aiAgentService: AiAgentService,
|
||||
opts: {
|
||||
agentId: string;
|
||||
botContext?: ChatTopicBotContext;
|
||||
botPlatformContext?: { platformName: string; supportsMarkdown: boolean };
|
||||
callbackUrl: string;
|
||||
channelContext?: DiscordChannelContext;
|
||||
files?: any;
|
||||
progressMessage?: SentMessage;
|
||||
prompt: string;
|
||||
topicId?: string;
|
||||
trigger?: string;
|
||||
webhookBody: Record<string, unknown>;
|
||||
},
|
||||
): Promise<{ reply: string; topicId: string }> {
|
||||
const {
|
||||
agentId,
|
||||
botContext,
|
||||
botPlatformContext,
|
||||
callbackUrl,
|
||||
channelContext,
|
||||
files,
|
||||
progressMessage,
|
||||
prompt,
|
||||
topicId,
|
||||
trigger,
|
||||
webhookBody,
|
||||
} = opts;
|
||||
|
||||
let result: ExecAgentResult;
|
||||
try {
|
||||
result = await AgentBridgeService.runWithStartupSignal(thread.id, (signal) =>
|
||||
@@ -515,25 +587,46 @@ export class AgentBridgeService {
|
||||
autoStart: true,
|
||||
botContext,
|
||||
botPlatformContext,
|
||||
completionWebhook: { body: webhookBody, url: callbackUrl },
|
||||
discordContext: channelContext
|
||||
? { channel: channelContext.channel, guild: channelContext.guild }
|
||||
: undefined,
|
||||
files,
|
||||
hooks: [
|
||||
{
|
||||
handler: async () => {
|
||||
/* local handler not used in queue mode */
|
||||
},
|
||||
id: 'bot-step-progress',
|
||||
type: 'afterStep',
|
||||
webhook: {
|
||||
body: { ...webhookBody, type: 'step' },
|
||||
delivery: 'qstash',
|
||||
url: callbackUrl,
|
||||
},
|
||||
},
|
||||
{
|
||||
handler: async () => {
|
||||
/* local handler not used in queue mode */
|
||||
},
|
||||
id: 'bot-completion',
|
||||
type: 'onComplete',
|
||||
webhook: {
|
||||
body: { ...webhookBody, type: 'completion', userPrompt: prompt },
|
||||
delivery: 'qstash',
|
||||
url: callbackUrl,
|
||||
},
|
||||
},
|
||||
],
|
||||
prompt,
|
||||
signal,
|
||||
stepWebhook: { body: webhookBody, url: callbackUrl },
|
||||
title: '',
|
||||
trigger,
|
||||
userInterventionConfig: { approvalMode: 'headless' },
|
||||
webhookDelivery: 'qstash',
|
||||
}),
|
||||
);
|
||||
} catch (error) {
|
||||
log('executeWithWebhooks: execAgent failed: %O', error);
|
||||
log('executeWithCallback[queue]: execAgent failed: %O', error);
|
||||
|
||||
// For stale topicId FK violations, re-throw so handleSubscribedMessage can clear
|
||||
// the cached topicId and retry as a fresh mention instead of showing a DB error.
|
||||
const errMsg = error instanceof Error ? error.message : String(error);
|
||||
if (errMsg.includes('Failed query') && errMsg.includes('topic_id')) {
|
||||
throw error;
|
||||
@@ -560,12 +653,11 @@ export class AgentBridgeService {
|
||||
}
|
||||
|
||||
log(
|
||||
'executeWithWebhooks: operationId=%s, topicId=%s (webhook mode, returning immediately)',
|
||||
'executeWithCallback[queue]: operationId=%s, topicId=%s (returning immediately)',
|
||||
result.operationId,
|
||||
result.topicId,
|
||||
);
|
||||
|
||||
// Track operationId so /stop can interrupt this execution
|
||||
if (result.operationId) {
|
||||
AgentBridgeService.activeOperations.set(thread.id, result.operationId);
|
||||
|
||||
@@ -574,67 +666,57 @@ export class AgentBridgeService {
|
||||
await this.interruptTrackedOperation(thread.id, result.operationId);
|
||||
} catch (error) {
|
||||
log(
|
||||
'executeWithWebhooks: deferred stop failed for thread=%s, operationId=%s: %O',
|
||||
'executeWithCallback[queue]: deferred stop failed for thread=%s: %O',
|
||||
thread.id,
|
||||
result.operationId,
|
||||
error,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Return immediately — progress/completion handled by webhooks
|
||||
return { reply: '', topicId: result.topicId };
|
||||
}
|
||||
|
||||
/**
|
||||
* Local mode: use in-memory step callbacks and wait for completion via Promise.
|
||||
* Local mode: register hooks with in-process handlers, wait for completion via Promise.
|
||||
*/
|
||||
private async executeWithInMemoryCallbacks(
|
||||
private async executeWithHooksLocalMode(
|
||||
thread: Thread<ThreadState>,
|
||||
userMessage: Message,
|
||||
aiAgentService: AiAgentService,
|
||||
opts: {
|
||||
agentId: string;
|
||||
botContext?: ChatTopicBotContext;
|
||||
botPlatformContext?: { platformName: string; supportsMarkdown: boolean };
|
||||
channelContext?: DiscordChannelContext;
|
||||
callbackUrl: string;
|
||||
charLimit?: number;
|
||||
channelContext?: DiscordChannelContext;
|
||||
client?: PlatformClient;
|
||||
displayToolCalls?: boolean;
|
||||
files?: any;
|
||||
progressMessage?: SentMessage;
|
||||
prompt: string;
|
||||
topicId?: string;
|
||||
trigger?: string;
|
||||
webhookBody: Record<string, unknown>;
|
||||
},
|
||||
): Promise<{ reply: string; topicId: string }> {
|
||||
const {
|
||||
agentId,
|
||||
botContext,
|
||||
botPlatformContext,
|
||||
channelContext,
|
||||
callbackUrl,
|
||||
charLimit,
|
||||
channelContext,
|
||||
client,
|
||||
displayToolCalls,
|
||||
files,
|
||||
prompt,
|
||||
topicId,
|
||||
trigger,
|
||||
webhookBody,
|
||||
} = opts;
|
||||
|
||||
const aiAgentService = new AiAgentService(this.db, this.userId);
|
||||
const timezone = await this.loadTimezone();
|
||||
|
||||
await thread.startTyping();
|
||||
|
||||
let progressMessage: SentMessage | undefined;
|
||||
try {
|
||||
progressMessage = await thread.post(renderStart(userMessage.text, { timezone }));
|
||||
} catch (error) {
|
||||
log('executeWithInMemoryCallbacks: failed to post initial placeholder message: %O', error);
|
||||
}
|
||||
|
||||
// Track the last LLM content and tool calls for showing during tool execution
|
||||
let lastLLMContent = '';
|
||||
let lastToolsCalling:
|
||||
| Array<{ apiName: string; arguments?: string; identifier: string }>
|
||||
| undefined;
|
||||
let totalToolCalls = 0;
|
||||
let { progressMessage } = opts;
|
||||
let operationStartTime = 0;
|
||||
|
||||
return new Promise<{ reply: string; topicId: string }>((resolve, reject) => {
|
||||
@@ -642,21 +724,10 @@ export class AgentBridgeService {
|
||||
reject(new Error(`Agent execution timed out`));
|
||||
}, EXECUTION_TIMEOUT);
|
||||
|
||||
let assistantMessageId = '';
|
||||
let resolvedTopicId = topicId ?? '';
|
||||
|
||||
const getElapsedMs = () => (operationStartTime > 0 ? Date.now() - operationStartTime : 0);
|
||||
|
||||
const files = this.extractFiles(userMessage);
|
||||
const prompt = this.formatPrompt(userMessage, client);
|
||||
|
||||
log(
|
||||
'executeWithInMemoryCallbacks: agentId=%s, prompt=%s, files=%d',
|
||||
agentId,
|
||||
prompt.slice(0, 100),
|
||||
files?.length ?? 0,
|
||||
);
|
||||
|
||||
AgentBridgeService.runWithStartupSignal(thread.id, (signal) =>
|
||||
aiAgentService.execAgent({
|
||||
agentId,
|
||||
@@ -668,163 +739,177 @@ export class AgentBridgeService {
|
||||
? { channel: channelContext.channel, guild: channelContext.guild }
|
||||
: undefined,
|
||||
files,
|
||||
prompt,
|
||||
signal,
|
||||
title: '',
|
||||
stepCallbacks: {
|
||||
onAfterStep: async (stepData) => {
|
||||
const { content, shouldContinue, toolsCalling } = stepData;
|
||||
if (!shouldContinue || !progressMessage || displayToolCalls === false) return;
|
||||
hooks: [
|
||||
{
|
||||
handler: async (event) => {
|
||||
if (!event.shouldContinue || !progressMessage || displayToolCalls === false) return;
|
||||
|
||||
if (toolsCalling) totalToolCalls += toolsCalling.length;
|
||||
const msgBody = renderStepProgress({
|
||||
content: event.content,
|
||||
elapsedMs: event.elapsedMs ?? getElapsedMs(),
|
||||
executionTimeMs: event.executionTimeMs ?? 0,
|
||||
lastContent: event.lastLLMContent,
|
||||
lastToolsCalling: event.lastToolsCalling,
|
||||
reasoning: event.reasoning,
|
||||
stepType: (event.stepType as 'call_llm' | 'call_tool') ?? 'call_llm',
|
||||
thinking: event.thinking ?? false,
|
||||
toolsCalling: event.toolsCalling,
|
||||
toolsResult: event.toolsResult,
|
||||
totalCost: event.totalCost ?? 0,
|
||||
totalInputTokens: event.totalInputTokens ?? 0,
|
||||
totalOutputTokens: event.totalOutputTokens ?? 0,
|
||||
totalSteps: event.totalSteps ?? 0,
|
||||
totalTokens: event.totalTokens ?? 0,
|
||||
totalToolCalls: event.totalToolCalls ?? 0,
|
||||
});
|
||||
|
||||
const msgBody = renderStepProgress({
|
||||
...stepData,
|
||||
elapsedMs: getElapsedMs(),
|
||||
lastContent: lastLLMContent,
|
||||
lastToolsCalling,
|
||||
totalToolCalls,
|
||||
});
|
||||
const stats = {
|
||||
elapsedMs: event.elapsedMs ?? getElapsedMs(),
|
||||
totalCost: event.totalCost ?? 0,
|
||||
totalTokens: event.totalTokens ?? 0,
|
||||
};
|
||||
const formatted = client?.formatMarkdown?.(msgBody) ?? msgBody;
|
||||
const progressText = client?.formatReply?.(formatted, stats) ?? formatted;
|
||||
|
||||
const stats = {
|
||||
elapsedMs: getElapsedMs(),
|
||||
totalCost: stepData.totalCost ?? 0,
|
||||
totalTokens: stepData.totalTokens ?? 0,
|
||||
};
|
||||
const formatted = client?.formatMarkdown?.(msgBody) ?? msgBody;
|
||||
const progressText = client?.formatReply?.(formatted, stats) ?? formatted;
|
||||
|
||||
if (content) lastLLMContent = content;
|
||||
if (toolsCalling) lastToolsCalling = toolsCalling;
|
||||
|
||||
try {
|
||||
progressMessage = await progressMessage.edit(progressText);
|
||||
} catch (error) {
|
||||
log('executeWithInMemoryCallbacks: failed to edit progress message: %O', error);
|
||||
}
|
||||
},
|
||||
|
||||
onComplete: async ({ finalState, reason }) => {
|
||||
clearTimeout(timeout);
|
||||
|
||||
log('onComplete: reason=%s, assistantMessageId=%s', reason, assistantMessageId);
|
||||
|
||||
if (reason === 'error') {
|
||||
const errorMsg = extractErrorMessage(finalState.error);
|
||||
try {
|
||||
const errorText = renderError(errorMsg);
|
||||
if (progressMessage) {
|
||||
await progressMessage.edit(errorText);
|
||||
} else {
|
||||
await thread.post(errorText);
|
||||
}
|
||||
} catch {
|
||||
// ignore send failure
|
||||
progressMessage = await progressMessage.edit(progressText);
|
||||
} catch (error) {
|
||||
log('executeWithCallback[local]: failed to edit progress message: %O', error);
|
||||
}
|
||||
reject(new Error(errorMsg));
|
||||
return;
|
||||
}
|
||||
|
||||
if (reason === 'interrupted') {
|
||||
if (progressMessage) {
|
||||
try {
|
||||
await progressMessage.edit(renderStopped());
|
||||
} catch {
|
||||
// ignore edit failure
|
||||
}
|
||||
}
|
||||
resolve({ reply: '', topicId: resolvedTopicId });
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// Extract reply from finalState.messages (accumulated across all steps)
|
||||
const lastAssistantContent = finalState.messages
|
||||
?.slice()
|
||||
.reverse()
|
||||
.find(
|
||||
(m: { content?: string; role: string }) => m.role === 'assistant' && m.content,
|
||||
)?.content;
|
||||
|
||||
if (lastAssistantContent) {
|
||||
const replyBody = renderFinalReply(lastAssistantContent);
|
||||
const replyStats = {
|
||||
elapsedMs: getElapsedMs(),
|
||||
llmCalls: finalState.usage?.llm?.apiCalls ?? 0,
|
||||
toolCalls: finalState.usage?.tools?.totalCalls ?? 0,
|
||||
totalCost: finalState.cost?.total ?? 0,
|
||||
totalTokens: finalState.usage?.llm?.tokens?.total ?? 0,
|
||||
};
|
||||
const formattedBody = client?.formatMarkdown?.(replyBody) ?? replyBody;
|
||||
const finalText =
|
||||
client?.formatReply?.(formattedBody, replyStats) ?? formattedBody;
|
||||
|
||||
const chunks = splitMessage(finalText, charLimit);
|
||||
},
|
||||
id: 'bot-step-progress',
|
||||
type: 'afterStep' as const,
|
||||
webhook: {
|
||||
body: { ...webhookBody, type: 'step' },
|
||||
delivery: 'qstash' as const,
|
||||
url: callbackUrl,
|
||||
},
|
||||
},
|
||||
{
|
||||
handler: async (event) => {
|
||||
clearTimeout(timeout);
|
||||
|
||||
const reason = event.reason;
|
||||
log('onComplete: reason=%s', reason);
|
||||
|
||||
if (reason === 'error') {
|
||||
const errorMsg = event.errorMessage || 'Agent execution failed';
|
||||
try {
|
||||
const errorText = renderError(errorMsg);
|
||||
if (progressMessage) {
|
||||
await progressMessage.edit(chunks[0]);
|
||||
// Post overflow chunks as follow-up messages
|
||||
for (let i = 1; i < chunks.length; i++) {
|
||||
await thread.post(chunks[i]);
|
||||
}
|
||||
await progressMessage.edit(errorText);
|
||||
} else {
|
||||
// No progress message (non-editable platform) — post all chunks as new messages
|
||||
for (const chunk of chunks) {
|
||||
await thread.post(chunk);
|
||||
}
|
||||
await thread.post(errorText);
|
||||
}
|
||||
} catch (error) {
|
||||
log('executeWithInMemoryCallbacks: failed to send final message: %O', error);
|
||||
} catch {
|
||||
// ignore send failure
|
||||
}
|
||||
|
||||
log(
|
||||
'executeWithInMemoryCallbacks: got response from finalState (%d chars, %d chunks)',
|
||||
lastAssistantContent.length,
|
||||
chunks.length,
|
||||
);
|
||||
resolve({ reply: lastAssistantContent, topicId: resolvedTopicId });
|
||||
|
||||
// Fire-and-forget: summarize topic title in DB (no Discord rename in local mode)
|
||||
if (resolvedTopicId && prompt) {
|
||||
const topicModel = new TopicModel(this.db, this.userId);
|
||||
topicModel
|
||||
.findById(resolvedTopicId)
|
||||
.then(async (topic) => {
|
||||
if (topic?.title) return;
|
||||
|
||||
const systemAgent = new SystemAgentService(this.db, this.userId);
|
||||
const title = await systemAgent.generateTopicTitle({
|
||||
lastAssistantContent,
|
||||
userPrompt: prompt,
|
||||
});
|
||||
if (!title) return;
|
||||
|
||||
await topicModel.update(resolvedTopicId, { title });
|
||||
})
|
||||
.catch((error) => {
|
||||
log(
|
||||
'executeWithInMemoryCallbacks: topic title summarization failed: %O',
|
||||
error,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
reject(new Error(errorMsg));
|
||||
return;
|
||||
}
|
||||
|
||||
reject(new Error('Agent completed but no response content found'));
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
}
|
||||
if (reason === 'interrupted') {
|
||||
if (progressMessage) {
|
||||
try {
|
||||
await progressMessage.edit(renderStopped());
|
||||
} catch {
|
||||
// ignore edit failure
|
||||
}
|
||||
}
|
||||
resolve({ reply: '', topicId: resolvedTopicId });
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const lastAssistantContent = event.lastAssistantContent;
|
||||
|
||||
if (lastAssistantContent) {
|
||||
const replyBody = renderFinalReply(lastAssistantContent);
|
||||
const replyStats = {
|
||||
elapsedMs: event.duration ?? getElapsedMs(),
|
||||
llmCalls: event.llmCalls ?? 0,
|
||||
toolCalls: event.toolCalls ?? 0,
|
||||
totalCost: event.cost ?? 0,
|
||||
totalTokens: event.totalTokens ?? 0,
|
||||
};
|
||||
const formattedBody = client?.formatMarkdown?.(replyBody) ?? replyBody;
|
||||
const finalText =
|
||||
client?.formatReply?.(formattedBody, replyStats) ?? formattedBody;
|
||||
|
||||
const chunks = splitMessage(finalText, charLimit);
|
||||
|
||||
try {
|
||||
if (progressMessage) {
|
||||
await progressMessage.edit(chunks[0]);
|
||||
for (let i = 1; i < chunks.length; i++) {
|
||||
await thread.post(chunks[i]);
|
||||
}
|
||||
} else {
|
||||
for (const chunk of chunks) {
|
||||
await thread.post(chunk);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
log('executeWithCallback[local]: failed to send final message: %O', error);
|
||||
}
|
||||
|
||||
log(
|
||||
'executeWithCallback[local]: got response (%d chars, %d chunks)',
|
||||
lastAssistantContent.length,
|
||||
chunks.length,
|
||||
);
|
||||
resolve({ reply: lastAssistantContent, topicId: resolvedTopicId });
|
||||
|
||||
// Fire-and-forget: summarize topic title in DB
|
||||
if (resolvedTopicId && prompt) {
|
||||
const topicModel = new TopicModel(this.db, this.userId);
|
||||
topicModel
|
||||
.findById(resolvedTopicId)
|
||||
.then(async (topic) => {
|
||||
if (topic?.title) return;
|
||||
|
||||
const systemAgent = new SystemAgentService(this.db, this.userId);
|
||||
const title = await systemAgent.generateTopicTitle({
|
||||
lastAssistantContent,
|
||||
userPrompt: prompt,
|
||||
});
|
||||
if (!title) return;
|
||||
|
||||
await topicModel.update(resolvedTopicId, { title });
|
||||
})
|
||||
.catch((error) => {
|
||||
log(
|
||||
'executeWithCallback[local]: topic title summarization failed: %O',
|
||||
error,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
reject(new Error('Agent completed but no response content found'));
|
||||
} catch (error) {
|
||||
reject(error);
|
||||
}
|
||||
},
|
||||
id: 'bot-completion',
|
||||
type: 'onComplete' as const,
|
||||
webhook: {
|
||||
body: { ...webhookBody, type: 'completion', userPrompt: prompt },
|
||||
delivery: 'qstash' as const,
|
||||
url: callbackUrl,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
prompt,
|
||||
signal,
|
||||
title: '',
|
||||
trigger,
|
||||
userInterventionConfig: { approvalMode: 'headless' },
|
||||
}),
|
||||
)
|
||||
.then(async (result) => {
|
||||
assistantMessageId = result.assistantMessageId;
|
||||
resolvedTopicId = result.topicId;
|
||||
operationStartTime = new Date(result.createdAt).getTime();
|
||||
|
||||
@@ -837,7 +922,7 @@ export class AgentBridgeService {
|
||||
renderError(result.error || 'Agent operation failed to start'),
|
||||
);
|
||||
} catch (error) {
|
||||
log('executeWithInMemoryCallbacks: failed to edit startup error: %O', error);
|
||||
log('executeWithCallback[local]: failed to edit startup error: %O', error);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -845,7 +930,6 @@ export class AgentBridgeService {
|
||||
return;
|
||||
}
|
||||
|
||||
// Track operationId so /stop can interrupt this execution
|
||||
if (result.operationId) {
|
||||
AgentBridgeService.activeOperations.set(thread.id, result.operationId);
|
||||
|
||||
@@ -854,9 +938,8 @@ export class AgentBridgeService {
|
||||
await this.interruptTrackedOperation(thread.id, result.operationId);
|
||||
} catch (error) {
|
||||
log(
|
||||
'executeWithInMemoryCallbacks: deferred stop failed for thread=%s, operationId=%s: %O',
|
||||
'executeWithCallback[local]: deferred stop failed for thread=%s: %O',
|
||||
thread.id,
|
||||
result.operationId,
|
||||
error,
|
||||
);
|
||||
}
|
||||
@@ -864,9 +947,8 @@ export class AgentBridgeService {
|
||||
}
|
||||
|
||||
log(
|
||||
'executeWithInMemoryCallbacks: operationId=%s, assistantMessageId=%s, topicId=%s',
|
||||
'executeWithCallback[local]: operationId=%s, topicId=%s',
|
||||
result.operationId,
|
||||
result.assistantMessageId,
|
||||
result.topicId,
|
||||
);
|
||||
})
|
||||
@@ -878,7 +960,7 @@ export class AgentBridgeService {
|
||||
try {
|
||||
await progressMessage.edit(renderStopped(error.message));
|
||||
} catch (editError) {
|
||||
log('executeWithInMemoryCallbacks: failed to edit stopped message: %O', editError);
|
||||
log('executeWithCallback[local]: failed to edit stopped message: %O', editError);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -890,7 +972,7 @@ export class AgentBridgeService {
|
||||
try {
|
||||
await progressMessage.edit(renderError(extractErrorMessage(error)));
|
||||
} catch (editError) {
|
||||
log('executeWithInMemoryCallbacks: failed to edit startup error: %O', editError);
|
||||
log('executeWithCallback[local]: failed to edit startup error: %O', editError);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,6 +30,10 @@ export interface BotCallbackBody {
|
||||
elapsedMs?: number;
|
||||
errorMessage?: string;
|
||||
executionTimeMs?: number;
|
||||
/** Hook ID from HookDispatcher (e.g. 'bot-step-progress', 'bot-completion') */
|
||||
hookId?: string;
|
||||
/** Hook type from HookDispatcher (e.g. 'afterStep', 'onComplete') */
|
||||
hookType?: string;
|
||||
lastAssistantContent?: string;
|
||||
lastLLMContent?: string;
|
||||
lastToolsCalling?: any;
|
||||
|
||||
@@ -7,7 +7,9 @@ const mockGetPlatform = vi.hoisted(() => vi.fn());
|
||||
const mockIsQueueAgentRuntimeEnabled = vi.hoisted(() => vi.fn());
|
||||
|
||||
vi.mock('@/database/models/topic', () => ({
|
||||
TopicModel: vi.fn(),
|
||||
TopicModel: vi.fn().mockImplementation(() => ({
|
||||
findById: vi.fn().mockResolvedValue(undefined),
|
||||
})),
|
||||
}));
|
||||
|
||||
vi.mock('@/database/models/user', () => ({
|
||||
@@ -112,7 +114,7 @@ describe('AgentBridgeService', () => {
|
||||
mockIsQueueAgentRuntimeEnabled.mockReturnValue(true);
|
||||
});
|
||||
|
||||
it('cleans up received reaction when queue-mode mention setup fails before callback handoff', async () => {
|
||||
it('calls execAgent with hooks in queue mode for mention', async () => {
|
||||
const service = new AgentBridgeService(FAKE_DB, USER_ID);
|
||||
const thread = createThread();
|
||||
const message = createMessage();
|
||||
@@ -124,15 +126,19 @@ describe('AgentBridgeService', () => {
|
||||
client,
|
||||
});
|
||||
|
||||
const [mentionReactionThreadId, mentionReactionMessageId, mentionReactionEmoji] =
|
||||
thread.adapter.removeReaction.mock.calls[0];
|
||||
expect(mentionReactionThreadId).toBe(THREAD_ID);
|
||||
expect(mentionReactionMessageId).toBe(MESSAGE_ID);
|
||||
expect(mentionReactionEmoji).toBeDefined();
|
||||
expect(mockExecAgent).not.toHaveBeenCalled();
|
||||
// execAgent should be called with hooks (afterStep + onComplete)
|
||||
expect(mockExecAgent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
agentId: 'agent-1',
|
||||
hooks: expect.arrayContaining([
|
||||
expect.objectContaining({ id: 'bot-step-progress', type: 'afterStep' }),
|
||||
expect.objectContaining({ id: 'bot-completion', type: 'onComplete' }),
|
||||
]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('cleans up received reaction when queue-mode subscribed-message setup fails before callback handoff', async () => {
|
||||
it('calls execAgent with hooks in queue mode for subscribed message', async () => {
|
||||
const service = new AgentBridgeService(FAKE_DB, USER_ID);
|
||||
const thread = createThread({ topicId: 'topic-1' });
|
||||
const message = createMessage();
|
||||
@@ -144,11 +150,26 @@ describe('AgentBridgeService', () => {
|
||||
client,
|
||||
});
|
||||
|
||||
const [replyReactionThreadId, replyReactionMessageId, replyReactionEmoji] =
|
||||
thread.adapter.removeReaction.mock.calls[0];
|
||||
expect(replyReactionThreadId).toBe(THREAD_ID);
|
||||
expect(replyReactionMessageId).toBe(MESSAGE_ID);
|
||||
expect(replyReactionEmoji).toBeDefined();
|
||||
expect(mockExecAgent).not.toHaveBeenCalled();
|
||||
// execAgent should be called with hooks containing webhook config
|
||||
expect(mockExecAgent).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
hooks: expect.arrayContaining([
|
||||
expect.objectContaining({
|
||||
id: 'bot-step-progress',
|
||||
type: 'afterStep',
|
||||
webhook: expect.objectContaining({
|
||||
body: expect.objectContaining({ type: 'step', platformThreadId: THREAD_ID }),
|
||||
}),
|
||||
}),
|
||||
expect.objectContaining({
|
||||
id: 'bot-completion',
|
||||
type: 'onComplete',
|
||||
webhook: expect.objectContaining({
|
||||
body: expect.objectContaining({ type: 'completion', platformThreadId: THREAD_ID }),
|
||||
}),
|
||||
}),
|
||||
]),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -759,4 +759,73 @@ describe('BotCallbackService', () => {
|
||||
expect(mockFindById).not.toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
|
||||
describe('hook-based webhook payload compatibility', () => {
|
||||
// These tests verify that payloads from HookDispatcher (which include
|
||||
// hookId/hookType fields) are handled correctly by BotCallbackService.
|
||||
// This is the critical contract between the hooks framework and the bot callback.
|
||||
|
||||
it('should handle step payload with hookId and hookType fields', async () => {
|
||||
const body = makeBody({
|
||||
content: 'thinking...',
|
||||
executionTimeMs: 100,
|
||||
hookId: 'bot-step-progress',
|
||||
hookType: 'afterStep',
|
||||
shouldContinue: true,
|
||||
stepType: 'call_llm' as const,
|
||||
thinking: true,
|
||||
totalCost: 0.01,
|
||||
totalInputTokens: 100,
|
||||
totalOutputTokens: 50,
|
||||
totalSteps: 1,
|
||||
totalTokens: 150,
|
||||
type: 'step',
|
||||
});
|
||||
|
||||
await service.handleCallback(body);
|
||||
|
||||
expect(mockEditMessage).toHaveBeenCalledWith('progress-msg-1', expect.any(String));
|
||||
});
|
||||
|
||||
it('should handle completion payload with hookId and hookType fields', async () => {
|
||||
const body = makeBody({
|
||||
cost: 0.05,
|
||||
duration: 5000,
|
||||
hookId: 'bot-completion',
|
||||
hookType: 'onComplete',
|
||||
lastAssistantContent: 'Here is the answer',
|
||||
llmCalls: 3,
|
||||
reason: 'done',
|
||||
toolCalls: 2,
|
||||
totalTokens: 500,
|
||||
type: 'completion',
|
||||
userId: 'user-1',
|
||||
userPrompt: 'test question',
|
||||
});
|
||||
|
||||
await service.handleCallback(body);
|
||||
|
||||
expect(mockEditMessage).toHaveBeenCalledWith(
|
||||
'progress-msg-1',
|
||||
expect.stringContaining('Here is the answer'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle completion error payload from hooks', async () => {
|
||||
const body = makeBody({
|
||||
errorMessage: 'Rate limit exceeded',
|
||||
hookId: 'bot-completion',
|
||||
hookType: 'onComplete',
|
||||
reason: 'error',
|
||||
type: 'completion',
|
||||
});
|
||||
|
||||
await service.handleCallback(body);
|
||||
|
||||
expect(mockEditMessage).toHaveBeenCalledWith(
|
||||
'progress-msg-1',
|
||||
expect.stringContaining('Rate limit exceeded'),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
+1
-26
@@ -1,7 +1,6 @@
|
||||
import {
|
||||
type AWSBedrockKeyVault,
|
||||
type AzureOpenAIKeyVault,
|
||||
type ClientSecretPayload,
|
||||
type CloudflareKeyVault,
|
||||
type ComfyUIKeyVault,
|
||||
type OpenAICompatibleKeyVault,
|
||||
@@ -10,11 +9,7 @@ import {
|
||||
import { clientApiKeyManager } from '@lobechat/utils/client';
|
||||
import { ModelProvider } from 'model-bank';
|
||||
|
||||
import { LOBE_CHAT_AUTH_HEADER, SECRET_XOR_KEY } from '@/envs/auth';
|
||||
import { aiProviderSelectors, useAiInfraStore } from '@/store/aiInfra';
|
||||
import { useUserStore } from '@/store/user';
|
||||
import { userProfileSelectors } from '@/store/user/selectors';
|
||||
import { obfuscatePayloadWithXOR } from '@/utils/client/xor-obfuscation';
|
||||
|
||||
import { resolveRuntimeProvider } from './chat/helper';
|
||||
|
||||
@@ -104,15 +99,8 @@ export const getProviderAuthPayload = (
|
||||
}
|
||||
};
|
||||
|
||||
const createAuthTokenWithPayload = (payload = {}) => {
|
||||
const userId = userProfileSelectors.userId(useUserStore.getState());
|
||||
|
||||
return obfuscatePayloadWithXOR<ClientSecretPayload>({ userId, ...payload }, SECRET_XOR_KEY);
|
||||
};
|
||||
|
||||
interface AuthParams {
|
||||
headers?: HeadersInit;
|
||||
payload?: Record<string, any>;
|
||||
provider?: string;
|
||||
}
|
||||
|
||||
@@ -128,19 +116,6 @@ export const createPayloadWithKeyVaults = (provider: string) => {
|
||||
};
|
||||
};
|
||||
|
||||
export const createXorKeyVaultsPayload = (provider: string) => {
|
||||
const payload = createPayloadWithKeyVaults(provider);
|
||||
return obfuscatePayloadWithXOR(payload, SECRET_XOR_KEY);
|
||||
};
|
||||
|
||||
export const createHeaderWithAuth = async (params?: AuthParams): Promise<HeadersInit> => {
|
||||
let payload = params?.payload || {};
|
||||
|
||||
if (params?.provider) {
|
||||
payload = { ...payload, ...createPayloadWithKeyVaults(params?.provider) };
|
||||
}
|
||||
|
||||
const token = createAuthTokenWithPayload(payload);
|
||||
|
||||
return { ...params?.headers, [LOBE_CHAT_AUTH_HEADER]: token };
|
||||
return { ...params?.headers };
|
||||
};
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user