feat(memory-user-memory): added LoCoMo dataset loader & converter & exporter (#10923)

This commit is contained in:
Neko
2025-12-24 12:14:56 +08:00
committed by arvinxx
parent 03342a76e3
commit a5dd785dca
5 changed files with 232 additions and 4 deletions
@@ -0,0 +1 @@
export * from './locomo'
@@ -0,0 +1,154 @@
import { readFileSync } from 'node:fs';
import { resolve } from 'node:path';
import type { MemorySourceType } from '@lobechat/types';
import { MemorySourceType as MemorySourceTypeEnum } from '@lobechat/types';
export type LocomoQASample = {
conversation: Record<string, unknown>;
qa: unknown[];
sample_id: string;
};
export type LocomoTurn = {
blip_caption?: string | string[];
dia_id?: string;
img_url?: string | string[];
query?: string;
speaker: string;
text: string;
};
export type LocomoSession = {
dateTime?: string;
id: string;
turns: LocomoTurn[];
};
export type IngestTurnPayload = {
createdAt?: string;
diaId?: string;
imageCaption?: string;
imageUrls?: string[];
role: string;
speaker: string;
text: string;
};
export type IngestSessionPayload = {
sessionId: string;
timestamp?: string;
turns: IngestTurnPayload[];
};
export type IngestPayload = {
force: boolean;
layers: string[];
sampleId: string;
sessions: IngestSessionPayload[];
source: MemorySourceType;
topicId: string;
};
export type BuildIngestOptions = {
includeImageCaptions?: boolean;
layers?: string[];
source?: MemorySourceType;
speakerRoles?: {
defaultRole?: string;
speakerA?: string;
speakerB?: string;
};
topicIdPrefix?: string;
};
const SESSION_KEY_REGEX = /^session_(\d+)$/;
const normalizeArray = (value?: string | string[]) => {
if (!value) return [];
return Array.isArray(value) ? value.filter(Boolean) : [value];
};
const buildTurnText = (turn: LocomoTurn, includeImageCaptions?: boolean) => {
const captions = normalizeArray(turn.blip_caption);
if (!includeImageCaptions || captions.length === 0) return turn.text;
const suffix = captions.map((caption) => `[Image: ${caption}]`).join('\n');
return `${turn.text}${turn.text.endsWith('\n') ? '' : '\n'}${suffix}`;
};
const extractSessions = (conversation: Record<string, unknown>): LocomoSession[] => {
const sessions: LocomoSession[] = [];
Object.entries(conversation).forEach(([key, value]) => {
const match = key.match(SESSION_KEY_REGEX);
if (!match || !Array.isArray(value)) return;
const dateTime = conversation[`${key}_date_time`] as string | undefined;
const turns = (value as unknown[]).filter(Boolean) as LocomoTurn[];
sessions.push({ dateTime, id: key, turns });
});
return sessions.sort((a, b) => a.id.localeCompare(b.id));
};
const resolveRole = (
speaker: string,
speakerAName: string | undefined,
speakerBName: string | undefined,
roles?: BuildIngestOptions['speakerRoles'],
) => {
if (speakerAName && speaker === speakerAName) return roles?.speakerA || 'user';
if (speakerBName && speaker === speakerBName) return roles?.speakerB || 'assistant';
return roles?.defaultRole || 'user';
};
export const buildIngestPayload = (
sample: LocomoQASample,
options: BuildIngestOptions,
): IngestPayload => {
const speakerA = sample.conversation['speaker_a'] as string | undefined;
const speakerB = sample.conversation['speaker_b'] as string | undefined;
const sessions = extractSessions(sample.conversation);
const sessionPayloads: IngestSessionPayload[] = sessions.map((session) => ({
sessionId: session.id,
timestamp: session.dateTime,
turns: session.turns.map((turn) => ({
createdAt: session.dateTime,
diaId: turn.dia_id,
imageCaption: normalizeArray(turn.blip_caption).join('\n') || undefined,
imageUrls: normalizeArray(turn.img_url).length ? normalizeArray(turn.img_url) : undefined,
role: resolveRole(turn.speaker, speakerA, speakerB, options.speakerRoles),
speaker: turn.speaker,
text: buildTurnText(turn, options.includeImageCaptions),
})),
}));
return {
force: true,
layers: options.layers ?? [],
sampleId: sample.sample_id,
sessions: sessionPayloads,
source: options.source ?? MemorySourceTypeEnum.BenchmarkLocomo,
topicId: `${options.topicIdPrefix ?? 'sample'}_${sample.sample_id}`,
};
};
export const loadLocomoFile = (filePath: string): LocomoQASample[] => {
const absPath = resolve(filePath);
const raw = readFileSync(absPath, 'utf8');
const parsed = JSON.parse(raw);
if (!Array.isArray(parsed)) {
throw new Error('Expected LoCoMo JSON to be an array of samples');
}
return parsed as LocomoQASample[];
};
export const convertLocomoFile = (
filePath: string,
options: BuildIngestOptions,
): IngestPayload[] => loadLocomoFile(filePath).map((sample) => buildIngestPayload(sample, options));
+1
View File
@@ -1,3 +1,4 @@
export * from './converters';
export * from './extractors';
export * from './providers';
export * from './schemas';
@@ -0,0 +1,71 @@
import { u } from 'unist-builder';
import { toXml } from 'xast-util-to-xml';
import type { Child } from 'xastscript';
import { x } from 'xastscript';
import type { BuiltContext, MemoryContextProvider, MemoryExtractionJob } from '../types';
export interface BenchmarkLocomoPart {
content: string;
createdAt?: string | Date;
metadata?: Record<string, unknown> | null;
partIndex: number;
sessionId?: string | null;
speaker?: string | null;
}
export interface BenchmarkLocomoContextProviderOptions {
parts: BenchmarkLocomoPart[];
sampleId: string;
sourceId: string;
userId: string;
}
export class BenchmarkLocomoContextProvider
implements MemoryContextProvider<Record<string, unknown>, Record<string, unknown>>
{
private readonly options: BenchmarkLocomoContextProviderOptions;
constructor(options: BenchmarkLocomoContextProviderOptions) {
this.options = options;
}
private buildMessageNode(part: BenchmarkLocomoPart, index: number) {
const attributes: Record<string, string> = {
index: index.toString(),
};
if (part.speaker) attributes.speaker = part.speaker;
if (part.createdAt) attributes.created_at = new Date(part.createdAt).toISOString();
if (part.sessionId) attributes.session_id = part.sessionId;
const metadata = part.metadata ? JSON.stringify(part.metadata) : undefined;
return x('message', attributes, part.content, metadata ? `\n[metadata:${metadata}]` : '');
}
async buildContext(job: MemoryExtractionJob): Promise<BuiltContext<Record<string, unknown>>> {
const messageChildren: Child[] = this.options.parts.map((part, index) =>
this.buildMessageNode(part, index),
);
const root = u('root', [
x(
'benchmark_locomo',
{
sample_id: this.options.sampleId,
source_id: this.options.sourceId,
user_id: this.options.userId,
},
...messageChildren,
),
]);
return {
context: toXml(root),
metadata: {},
sourceId: this.options.sourceId,
userId: job.userId,
};
}
}
@@ -1,5 +1,6 @@
export type * from './benchmarkLocomo';
export { BenchmarkLocomoContextProvider } from './benchmarkLocomo';
export type * from './chatTopic';
export { LobeChatTopicContextProvider, LobeChatTopicResultRecorder } from './chatTopic';
export {
RetrievalUserMemoryContextProvider,
RetrievalUserMemoryIdentitiesProvider,
} from './existingUserMemory';
export type * from './existingUserMemory';
export { RetrievalUserMemoryContextProvider, RetrievalUserMemoryIdentitiesProvider } from './existingUserMemory';