mirror of
https://github.com/dokploy/dokploy.git
synced 2026-06-14 03:19:49 +00:00
feat: implement embeddings for AI chat and enhance tool retrieval
- Introduced a new embeddings system for AI chat, allowing for improved context understanding and response accuracy. - Added functionality to retrieve relevant endpoints based on user queries, enhancing the AI's ability to provide precise information. - Updated the chat panel to restore messages from local storage and persist chat history, improving user experience. - Enhanced error handling and added semantic hints for API parameters, ensuring clearer guidance for users. These changes significantly improve the AI chat capabilities and overall interaction quality within the Dokploy platform.
This commit is contained in:
File diff suppressed because one or more lines are too long
@@ -133,6 +133,7 @@ export function ChatPanel() {
|
||||
const enabledProviders = providers ?? [];
|
||||
|
||||
const STORAGE_KEY = "dokploy-chat-messages";
|
||||
const restoredRef = useRef(false);
|
||||
|
||||
const { messages, sendMessage, status, setMessages, addToolApprovalResponse } = useChat({
|
||||
id: "dokploy-chat",
|
||||
@@ -143,23 +144,32 @@ export function ChatPanel() {
|
||||
context: contextRef.current,
|
||||
}),
|
||||
}),
|
||||
initialMessages: () => {
|
||||
try {
|
||||
const stored = localStorage.getItem(STORAGE_KEY);
|
||||
return stored ? JSON.parse(stored) : [];
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
const isLoading = status === "streaming" || status === "submitted";
|
||||
|
||||
// Restore messages from localStorage on mount
|
||||
useEffect(() => {
|
||||
if (restoredRef.current) return;
|
||||
restoredRef.current = true;
|
||||
try {
|
||||
const stored = localStorage.getItem(STORAGE_KEY);
|
||||
if (stored) {
|
||||
const parsed = JSON.parse(stored);
|
||||
if (Array.isArray(parsed) && parsed.length > 0) {
|
||||
setMessages(parsed);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
}, [setMessages]);
|
||||
|
||||
// Persist messages to localStorage
|
||||
useEffect(() => {
|
||||
if (!restoredRef.current) return;
|
||||
if (messages.length > 0) {
|
||||
try {
|
||||
// Keep only last 50 messages to avoid localStorage bloat
|
||||
const toStore = messages.slice(-50);
|
||||
localStorage.setItem(STORAGE_KEY, JSON.stringify(toStore));
|
||||
} catch {
|
||||
@@ -378,6 +388,7 @@ export function ChatPanel() {
|
||||
toolCallId={part.toolCallId}
|
||||
toolName={part.toolName}
|
||||
state={part.state}
|
||||
input={(part as any).input}
|
||||
output={
|
||||
part.state === "output-available"
|
||||
? part.output
|
||||
@@ -484,6 +495,7 @@ function ToolCallDisplay({
|
||||
toolCallId,
|
||||
toolName,
|
||||
state,
|
||||
input,
|
||||
output,
|
||||
onApprove,
|
||||
onDeny,
|
||||
@@ -491,6 +503,7 @@ function ToolCallDisplay({
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
state: string;
|
||||
input?: unknown;
|
||||
output?: unknown;
|
||||
onApprove?: (id: string) => void;
|
||||
onDeny?: (id: string) => void;
|
||||
@@ -508,19 +521,47 @@ function ToolCallDisplay({
|
||||
: JSON.stringify(output, null, 2)
|
||||
: null;
|
||||
|
||||
const displayName = toolName
|
||||
.split("-")
|
||||
.map((w) => w.charAt(0).toUpperCase() + w.slice(1))
|
||||
.join(" ");
|
||||
// Extract operationId and params from input
|
||||
const inputData = input as { operationId?: string; params?: Record<string, unknown> } | undefined;
|
||||
const operationId = inputData?.operationId;
|
||||
const params = inputData?.params;
|
||||
|
||||
// Format: "compose-one" → "compose → one"
|
||||
const displayLabel = operationId
|
||||
? operationId.replace("-", " → ")
|
||||
: toolName;
|
||||
|
||||
// Determine HTTP method hint from operationId
|
||||
const isReadOp = operationId?.match(/^(.*-)?(one|all|get|list|read|search|by)/i);
|
||||
|
||||
const StatusIcon = isRunning
|
||||
? () => <Loader2 className="h-3.5 w-3.5 animate-spin text-blue-500 shrink-0" />
|
||||
: isDone
|
||||
? () => <Check className="h-3.5 w-3.5 text-green-500 shrink-0" />
|
||||
: isError
|
||||
? () => <X className="h-3.5 w-3.5 text-red-500 shrink-0" />
|
||||
: () => <Wrench className="h-3.5 w-3.5 text-muted-foreground shrink-0" />;
|
||||
|
||||
if (needsApproval) {
|
||||
return (
|
||||
<div className="flex items-center justify-between gap-2 text-xs">
|
||||
<div className="flex items-center gap-1.5">
|
||||
<Wrench className="h-3 w-3 text-muted-foreground shrink-0" />
|
||||
<span>{displayName}</span>
|
||||
<div className="space-y-2">
|
||||
<div className="flex items-center gap-2 text-xs">
|
||||
<Wrench className="h-3.5 w-3.5 text-yellow-500 shrink-0" />
|
||||
<code className="font-mono text-xs font-medium">{displayLabel}</code>
|
||||
<Badge variant="outline" className="text-[10px] px-1 py-0 h-4 font-normal">
|
||||
write
|
||||
</Badge>
|
||||
</div>
|
||||
<div className="flex gap-1.5 shrink-0">
|
||||
{params && Object.keys(params).length > 0 && (
|
||||
<div className="ml-5.5 flex flex-wrap gap-1">
|
||||
{Object.entries(params).map(([key, value]) => (
|
||||
<span key={key} className="text-[10px] bg-muted px-1.5 py-0.5 rounded font-mono">
|
||||
{key}={typeof value === "string" ? `"${value}"` : String(value)}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
<div className="flex gap-1.5 ml-5.5">
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
@@ -545,43 +586,44 @@ function ToolCallDisplay({
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex items-start gap-1.5 text-xs">
|
||||
{isRunning ? (
|
||||
<Loader2 className="h-3 w-3 animate-spin text-muted-foreground mt-0.5 shrink-0" />
|
||||
) : isDone ? (
|
||||
<Check className="h-3 w-3 text-muted-foreground mt-0.5 shrink-0" />
|
||||
) : isError ? (
|
||||
<X className="h-3 w-3 text-destructive mt-0.5 shrink-0" />
|
||||
) : (
|
||||
<Wrench className="h-3 w-3 text-muted-foreground mt-0.5 shrink-0" />
|
||||
)}
|
||||
|
||||
{outputText ? (
|
||||
<Collapsible open={isOpen} onOpenChange={setIsOpen}>
|
||||
<CollapsibleTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className="flex items-center gap-1 text-muted-foreground hover:text-foreground transition-colors"
|
||||
>
|
||||
<span>{displayName}</span>
|
||||
<div className="space-y-1">
|
||||
<Collapsible open={isOpen} onOpenChange={setIsOpen}>
|
||||
<CollapsibleTrigger asChild>
|
||||
<button
|
||||
type="button"
|
||||
className="flex items-center gap-2 text-xs w-full hover:bg-muted/50 rounded -mx-1 px-1 py-0.5 transition-colors"
|
||||
>
|
||||
<StatusIcon />
|
||||
<code className="font-mono text-xs font-medium">{displayLabel}</code>
|
||||
{isReadOp && (
|
||||
<Badge variant="secondary" className="text-[10px] px-1 py-0 h-4 font-normal">
|
||||
read
|
||||
</Badge>
|
||||
)}
|
||||
{params && Object.keys(params).length > 0 && (
|
||||
<span className="text-[10px] text-muted-foreground truncate">
|
||||
{Object.entries(params)
|
||||
.slice(0, 3)
|
||||
.map(([k, v]) => `${k}=${typeof v === "string" ? `"${String(v).slice(0, 20)}"` : String(v)}`)
|
||||
.join(", ")}
|
||||
{Object.keys(params).length > 3 ? ` +${Object.keys(params).length - 3}` : ""}
|
||||
</span>
|
||||
)}
|
||||
{(outputText || isRunning) && (
|
||||
<ChevronDown
|
||||
className={`h-3 w-3 transition-transform ${isOpen ? "rotate-180" : ""}`}
|
||||
className={`h-3 w-3 ml-auto text-muted-foreground transition-transform shrink-0 ${isOpen ? "rotate-180" : ""}`}
|
||||
/>
|
||||
</button>
|
||||
</CollapsibleTrigger>
|
||||
)}
|
||||
</button>
|
||||
</CollapsibleTrigger>
|
||||
{outputText && (
|
||||
<CollapsibleContent>
|
||||
<pre className="mt-1 p-2 bg-muted/50 rounded text-[10px] overflow-x-auto max-h-[150px] overflow-y-auto leading-tight">
|
||||
{outputText.length > 2000
|
||||
? `${outputText.slice(0, 2000)}\n... (truncated)`
|
||||
: outputText}
|
||||
<pre className="mt-1 ml-5.5 p-2 bg-muted/50 rounded text-[10px] overflow-x-auto max-h-[200px] overflow-y-auto leading-tight whitespace-pre-wrap break-words">
|
||||
{outputText}
|
||||
</pre>
|
||||
</CollapsibleContent>
|
||||
</Collapsible>
|
||||
) : (
|
||||
<span className="text-muted-foreground">
|
||||
{isRunning ? `${displayName}...` : displayName}
|
||||
</span>
|
||||
)}
|
||||
)}
|
||||
</Collapsible>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -10,13 +10,16 @@ import {
|
||||
buildEndpointCatalog,
|
||||
createApiTool,
|
||||
} from "@dokploy/server/utils/ai/api-tool";
|
||||
import {
|
||||
getOrCreateEmbeddings,
|
||||
retrieveRelevantEndpoints,
|
||||
} from "@dokploy/server/utils/ai/tool-retrieval";
|
||||
import { selectAIProvider } from "@dokploy/server/utils/ai/select-ai-provider";
|
||||
import { createAnthropic } from "@ai-sdk/anthropic";
|
||||
import { convertToModelMessages, stepCountIs, streamText } from "ai";
|
||||
import type { NextApiRequest, NextApiResponse } from "next";
|
||||
|
||||
let cachedSpec: any = null;
|
||||
const cachedCatalogs = new Map<string, { catalog: string; count: number; operationIds: Set<string> }>();
|
||||
|
||||
function getOpenApiSpec() {
|
||||
if (!cachedSpec) {
|
||||
@@ -30,14 +33,6 @@ function getOpenApiSpec() {
|
||||
return cachedSpec;
|
||||
}
|
||||
|
||||
function getEndpointCatalog(spec: any, contextType: ChatContext["type"]) {
|
||||
const cached = cachedCatalogs.get(contextType);
|
||||
if (cached) return cached;
|
||||
const result = buildEndpointCatalog(spec, contextType);
|
||||
cachedCatalogs.set(contextType, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
function buildContextBlock(context: ChatContext): string {
|
||||
if (context.type === "general") {
|
||||
return "CONTEXT: The user is on the general dashboard (no specific resource selected). Use project-all to list their projects if needed.";
|
||||
@@ -86,15 +81,17 @@ BEHAVIOR:
|
||||
- When something fails → read the error, figure out the fix, and apply it. Don't stop to explain the error — fix it.
|
||||
- EVERY capability you need is in the ENDPOINT CATALOG below. If you think you can't do something, you're wrong — scan ALL sections again.
|
||||
- You already have all the IDs you need from the context above. NEVER ask the user for IDs, paths, or information you can discover by calling endpoints.
|
||||
- For destructive actions only (delete, stop): briefly confirm. Everything else: just do it.
|
||||
- NEVER ask for confirmation or permission. The only exception is deleting a service entirely. For everything else (read, update, deploy, stop, start, restart) → just do it immediately.
|
||||
|
||||
KEY PATTERN: When you need to explore files, find paths, or check repository structure → use the "patch" section endpoints to browse directories and read files. NEVER ask the user for file paths.
|
||||
|
||||
DATA MODEL: Project → Environment → Services (application, compose, postgres, mysql, redis, mongo, mariadb, libsql). Each service has deployments with build logs.
|
||||
|
||||
TOOL: You have one tool "call_api". Pass operationId + params from the catalog.
|
||||
- ALWAYS pass required params (*) in the "params" object. Example: { "operationId": "domain-byComposeId", "params": { "composeId": "abc123" } }
|
||||
- Params: * = required, ? = optional, [a|b|c] = allowed values
|
||||
- GET = read-only (auto-executed). POST/PUT/DELETE = write (user approves).
|
||||
- If a call fails, read the error message and fix the params. NEVER retry the same call with the same params.
|
||||
|
||||
RESPONSE STYLE:
|
||||
- 2-3 sentences max. No walls of text.
|
||||
@@ -186,10 +183,44 @@ export default async function handler(
|
||||
const spec = getOpenApiSpec();
|
||||
|
||||
if (spec) {
|
||||
const { catalog, count, operationIds } = getEndpointCatalog(spec, context.type);
|
||||
const voyageApiKey = process.env.VOYAGE_API_KEY;
|
||||
if (!voyageApiKey) {
|
||||
return res.status(400).json({ error: "VOYAGE_API_KEY is required" });
|
||||
}
|
||||
|
||||
const embeddingsPath = join(process.cwd(), ".tool-embeddings.json");
|
||||
const allEmbeddings = await getOrCreateEmbeddings(
|
||||
spec,
|
||||
voyageApiKey,
|
||||
embeddingsPath,
|
||||
);
|
||||
|
||||
const userQuery = getUserMessages(messages).trim();
|
||||
const { operationIds: tagFilteredIds } = buildEndpointCatalog(spec, context.type);
|
||||
|
||||
let relevantIds: Set<string> | undefined;
|
||||
|
||||
if (userQuery && allEmbeddings.length > 0) {
|
||||
const topIds = await retrieveRelevantEndpoints(
|
||||
userQuery,
|
||||
allEmbeddings,
|
||||
voyageApiKey,
|
||||
{ allowedOperationIds: tagFilteredIds, topK: 25 },
|
||||
);
|
||||
|
||||
if (topIds.length > 0) {
|
||||
relevantIds = new Set(topIds);
|
||||
}
|
||||
}
|
||||
|
||||
const { catalog, count, operationIds } = buildEndpointCatalog(
|
||||
spec,
|
||||
context.type,
|
||||
relevantIds,
|
||||
);
|
||||
catalogText = catalog;
|
||||
endpointCount = count;
|
||||
tools = createApiTool(spec, toolConfig, operationIds, 2000);
|
||||
tools = createApiTool(spec, toolConfig, operationIds, 8000);
|
||||
} else {
|
||||
tools = getAllTools(context, toolConfig);
|
||||
}
|
||||
|
||||
@@ -206,6 +206,7 @@ export interface CatalogResult {
|
||||
export function buildEndpointCatalog(
|
||||
spec: OpenApiSpec,
|
||||
contextType: ChatContext["type"] = "general",
|
||||
relevantOperationIds?: Set<string>,
|
||||
): CatalogResult {
|
||||
const operationIds = new Set<string>();
|
||||
const allowedTags = getAllowedTags(contextType);
|
||||
@@ -216,6 +217,7 @@ export function buildEndpointCatalog(
|
||||
if (!op.operationId || op.deprecated) continue;
|
||||
if (op.tags?.some((t) => EXCLUDED_TAGS.has(t))) continue;
|
||||
if (allowedTags && !op.tags?.some((t) => allowedTags.has(t))) continue;
|
||||
if (relevantOperationIds && !relevantOperationIds.has(op.operationId)) continue;
|
||||
|
||||
operationIds.add(op.operationId);
|
||||
|
||||
@@ -381,7 +383,7 @@ export function createApiTool(
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
return `API error (${response.status}): ${errorText.slice(0, 500)}`;
|
||||
return `API error (${response.status}): ${errorText.slice(0, 500)}\n\nHint: Check the ENDPOINT CATALOG for required parameters (*). You called "${operationId}" with params: ${JSON.stringify(params ?? {})}`;
|
||||
}
|
||||
|
||||
const json = JSON.stringify(await response.json(), null, 2);
|
||||
|
||||
@@ -0,0 +1,244 @@
|
||||
import { readFileSync, writeFileSync, existsSync } from "node:fs";
|
||||
import { join } from "node:path";
|
||||
|
||||
/** One API endpoint from the OpenAPI spec together with its embedding vector. */
interface EndpointEmbedding {
|
||||
/** The operation's `operationId` as found in the spec. */
operationId: string;
|
||||
/** The text produced by `buildEndpointText` that was sent to the embedder. */
text: string;
|
||||
/** The operation's tags (`[]` when the operation declares none). */
tags: string[];
|
||||
/** The vector returned by the embedding API for `text`. */
embedding: number[];
|
||||
}
|
||||
|
||||
// Model name sent as `model` in every embeddings request.
const VOYAGE_MODEL = "voyage-3-lite";
|
||||
// URL the embeddings requests are POSTed to.
const VOYAGE_API = "https://api.voyageai.com/v1/embeddings";
|
||||
// Maximum number of texts sent in a single embeddings request.
const BATCH_SIZE = 128;
|
||||
|
||||
/**
 * Embed an array of texts with the Voyage AI embeddings endpoint.
 *
 * Texts are sent sequentially in batches of at most `BATCH_SIZE`. The result
 * contains exactly one vector per input text, in input order.
 *
 * @param texts - Texts to embed. An empty array performs no request.
 * @param apiKey - Voyage API key, sent as a Bearer token.
 * @param inputType - Forwarded to the API as `input_type`.
 * @returns One embedding vector per entry of `texts`.
 * @throws Error when the API answers with a non-2xx status, or when a
 *   response does not contain exactly one vector per text of the batch.
 */
async function embedTexts(
  texts: string[],
  apiKey: string,
  inputType: "document" | "query" = "document",
): Promise<number[][]> {
  const results: number[][] = [];

  for (let i = 0; i < texts.length; i += BATCH_SIZE) {
    const batch = texts.slice(i, i + BATCH_SIZE);
    const response = await fetch(VOYAGE_API, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${apiKey}`,
      },
      body: JSON.stringify({
        model: VOYAGE_MODEL,
        input: batch,
        input_type: inputType,
      }),
    });

    if (!response.ok) {
      throw new Error(
        `Voyage API error: ${response.status} ${await response.text()}`,
      );
    }

    const payload = (await response.json()) as {
      data?: { embedding?: unknown; index?: unknown }[];
    } | null;
    const items = payload?.data;

    // Callers pair result[i] with texts[i]; a short or malformed response
    // would silently shift every following vector onto the wrong text.
    if (!Array.isArray(items) || items.length !== batch.length) {
      const received = Array.isArray(items) ? items.length : 0;
      throw new Error(
        `Voyage API error: expected ${batch.length} embeddings, received ${received}`,
      );
    }

    // Each item carries its position in the request batch as `index`; honour
    // it when present so an out-of-order response cannot misalign vectors.
    const ordered = items.every((item) => Number.isInteger(item?.index))
      ? [...items].sort(
          (left, right) => (left.index as number) - (right.index as number),
        )
      : items;

    for (const item of ordered) {
      if (!Array.isArray(item?.embedding)) {
        throw new Error(
          "Voyage API error: response item is missing its embedding vector",
        );
      }
      results.push(item.embedding as number[]);
    }
  }

  return results;
}
|
||||
|
||||
/**
 * Cosine similarity between two vectors.
 *
 * @param a - First vector.
 * @param b - Second vector; must have the same length as `a` to be comparable.
 * @returns A value in [-1, 1]. Returns 0 when the vectors cannot be
 *   meaningfully compared: different lengths, empty input, a zero-magnitude
 *   vector, or non-finite components. The result is never NaN, so it is
 *   always safe to use in a sort comparator.
 */
function cosineSimilarity(a: number[], b: number[]): number {
  // Vectors of different dimensionality come from different embedding spaces.
  if (a.length !== b.length || a.length === 0) return 0;

  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    const x = a[i] ?? 0;
    const y = b[i] ?? 0;
    dot += x * y;
    normA += x * x;
    normB += y * y;
  }

  const denominator = Math.sqrt(normA) * Math.sqrt(normB);
  // A zero-magnitude vector would produce 0/0 = NaN.
  if (denominator === 0) return 0;

  const similarity = dot / denominator;
  return Number.isFinite(similarity) ? similarity : 0;
}
|
||||
|
||||
// In-memory cache of endpoint embeddings. `null` until getOrCreateEmbeddings
// has loaded or generated them; afterwards it is returned on every call.
|
||||
let cachedEmbeddings: EndpointEmbedding[] | null = null;
|
||||
|
||||
/**
 * Extract enum values from a JSON Schema property (handles anyOf wrappers).
 *
 * Looks at the property's own `enum` first, then at the first `anyOf`
 * variant that declares one. Only array-valued `enum` keywords are accepted,
 * so callers can safely `.join()` the result.
 *
 * @param prop - A JSON Schema property object; any other value yields `null`.
 * @returns The enum values as found in the schema, or `null` when there are none.
 */
function extractEnum(prop: unknown): string[] | null {
  // Read an array-valued `enum` keyword from a schema-like value.
  const readEnum = (schema: unknown): string[] | null => {
    if (typeof schema !== "object" || schema === null) return null;
    const values = (schema as { enum?: unknown }).enum;
    return Array.isArray(values) ? (values as string[]) : null;
  };

  const direct = readEnum(prop);
  if (direct) return direct;

  if (typeof prop === "object" && prop !== null) {
    const variants = (prop as { anyOf?: unknown }).anyOf;
    if (Array.isArray(variants)) {
      for (const variant of variants) {
        const nested = readEnum(variant);
        if (nested) return nested;
      }
    }
  }

  return null;
}
|
||||
|
||||
/**
 * Build a rich text representation for an endpoint (used for embedding).
 *
 * The text is the following parts joined with ". ":
 *   1. `<operationId> [<METHOD> <path>]`
 *   2. `Tags: a, b` (when the operation has tags)
 *   3. the summary and the description (when present)
 *   4. `Parameters: name (required|optional) [v1|v2], ...` covering the
 *      non-header entries of `parameters` and the top-level properties of
 *      the `application/json` request body, each with its enum values.
 *
 * @param op - OpenAPI operation object.
 * @param method - HTTP method of the operation (upper-cased in the output).
 * @param path - Path template of the operation.
 * @returns The text to embed for this endpoint.
 */
function buildEndpointText(
  // biome-ignore lint/suspicious/noExplicitAny: operation objects come from an untyped OpenAPI document
  op: any,
  method: string,
  path: string,
): string {
  const parts: string[] = [
    `${op.operationId} [${method.toUpperCase()} ${path}]`,
  ];

  if (op.tags?.length) {
    parts.push(`Tags: ${op.tags.join(", ")}`);
  }

  if (op.summary) parts.push(op.summary);
  if (op.description) parts.push(op.description);

  // Shared formatter so parameters and body properties render identically.
  const describeParam = (
    name: string,
    required: boolean,
    schema: unknown,
  ): string => {
    const enumVals = extractEnum(schema);
    const enumStr = enumVals ? ` [${enumVals.join("|")}]` : "";
    return `${name} (${required ? "required" : "optional"})${enumStr}`;
  };

  const params: string[] = [];

  if (Array.isArray(op.parameters)) {
    for (const p of op.parameters) {
      // Entries without a name (e.g. unresolved `$ref`s) would otherwise
      // render as "undefined (optional)"; header parameters are skipped.
      if (!p?.name || p.in === "header") continue;
      params.push(describeParam(p.name, Boolean(p.required), p.schema));
    }
  }

  const bodySchema = op.requestBody?.content?.["application/json"]?.schema;
  if (bodySchema?.properties) {
    const requiredSet = new Set<string>(
      Array.isArray(bodySchema.required) ? bodySchema.required : [],
    );
    for (const [key, prop] of Object.entries(
      bodySchema.properties as Record<string, unknown>,
    )) {
      params.push(describeParam(key, requiredSet.has(key), prop));
    }
  }

  if (params.length > 0) {
    parts.push(`Parameters: ${params.join(", ")}`);
  }

  return parts.join(". ");
}
|
||||
|
||||
/**
 * Generate or load embeddings for all endpoints in the OpenAPI spec.
 *
 * Lookup order:
 *   1. the in-memory cache, when already populated;
 *   2. the JSON cache file, but only when it describes exactly the endpoints
 *      of `spec` (same operationIds and same embedded texts, in order), so a
 *      cache written for an older spec is never reused;
 *   3. a fresh embedding run, whose result is kept in memory and written to
 *      the cache file (best effort).
 *
 * Operations without an `operationId` and deprecated operations are ignored.
 *
 * @param spec - Parsed OpenAPI document; only `spec.paths` is read.
 * @param voyageApiKey - API key forwarded to the embedder.
 * @param cachePath - Cache file location; defaults to `.tool-embeddings.json`
 *   in the current working directory.
 * @returns One entry per endpoint (an empty array when the spec has none).
 * @throws Error when the embedder fails or returns a different number of
 *   vectors than endpoints.
 */
export async function getOrCreateEmbeddings(
  // biome-ignore lint/suspicious/noExplicitAny: the OpenAPI document is untyped at this boundary
  spec: any,
  voyageApiKey: string,
  cachePath?: string,
): Promise<EndpointEmbedding[]> {
  // Return from memory cache
  if (cachedEmbeddings) return cachedEmbeddings;

  const filePath = cachePath || join(process.cwd(), ".tool-embeddings.json");

  // Describe the endpoints first: the descriptions are needed both to
  // validate the file cache and to generate new embeddings.
  const endpoints: { operationId: string; text: string; tags: string[] }[] =
    [];

  for (const [path, methods] of Object.entries(spec?.paths ?? {})) {
    if (!methods || typeof methods !== "object") continue;
    for (const [method, op] of Object.entries(
      methods as Record<string, any>,
    )) {
      if (!op?.operationId || op.deprecated) continue;
      endpoints.push({
        operationId: op.operationId,
        text: buildEndpointText(op, method, path),
        tags: op.tags ?? [],
      });
    }
  }

  if (endpoints.length === 0) {
    cachedEmbeddings = [];
    return cachedEmbeddings;
  }

  // Try loading from file cache
  if (existsSync(filePath)) {
    try {
      const data: unknown = JSON.parse(readFileSync(filePath, "utf-8"));
      const matchesSpec =
        Array.isArray(data) &&
        data.length === endpoints.length &&
        data.every(
          (entry, i) =>
            entry?.operationId === endpoints[i]?.operationId &&
            entry?.text === endpoints[i]?.text &&
            Array.isArray(entry?.embedding) &&
            entry.embedding.length > 0,
        );
      if (matchesSpec) {
        cachedEmbeddings = data as EndpointEmbedding[];
        return cachedEmbeddings;
      }
      // Stale or incomplete cache: fall through and regenerate.
    } catch {
      // Unreadable or corrupted file: regenerate.
    }
  }

  const embeddings = await embedTexts(
    endpoints.map((endpoint) => endpoint.text),
    voyageApiKey,
    "document",
  );

  // Never cache (or persist) entries whose vector is missing.
  if (embeddings.length !== endpoints.length) {
    throw new Error(
      `Embedding count mismatch: expected ${endpoints.length}, received ${embeddings.length}`,
    );
  }

  const generated: EndpointEmbedding[] = endpoints.map((endpoint, i) => ({
    ...endpoint,
    embedding: embeddings[i] ?? [],
  }));
  cachedEmbeddings = generated;

  // Persist to file
  try {
    writeFileSync(filePath, JSON.stringify(generated));
  } catch {
    // Non-critical: the embeddings are regenerated by the next process.
  }

  return generated;
}
|
||||
|
||||
/**
 * Retrieve the top-K most relevant endpoints for a user query,
 * optionally filtered to a pre-computed set of allowed operationIds.
 *
 * The query is embedded once and compared against every candidate with
 * cosine similarity; candidates are returned best match first.
 *
 * @param query - Free-text user query. A blank query yields `[]`.
 * @param allEmbeddings - Endpoint embeddings to search.
 * @param voyageApiKey - API key used to embed the query.
 * @param options - `allowedOperationIds` restricts the candidates;
 *   `topK` caps the number of results (default 20; values below 1 or NaN
 *   yield `[]`).
 * @returns Up to `topK` operationIds ordered by descending similarity.
 *   Candidates whose similarity cannot be computed are left out.
 * @throws Error when embedding the query fails.
 */
export async function retrieveRelevantEndpoints(
  query: string,
  allEmbeddings: EndpointEmbedding[],
  voyageApiKey: string,
  options?: {
    allowedOperationIds?: Set<string>;
    topK?: number;
  },
): Promise<string[]> {
  const { allowedOperationIds, topK = 20 } = options ?? {};

  // A negative limit would make slice() drop items from the tail instead of
  // capping the result; bail out before paying for an embedding request.
  const limit = Number.isNaN(topK) ? 0 : Math.floor(topK);
  if (limit <= 0 || query.trim().length === 0) return [];

  // Filter to allowed operationIds (from tag filtering) and to entries that
  // actually carry a vector.
  const candidates = allEmbeddings.filter(
    (entry) =>
      Array.isArray(entry.embedding) &&
      (!allowedOperationIds || allowedOperationIds.has(entry.operationId)),
  );

  if (candidates.length === 0) return [];

  // Embed the user query
  const [queryEmbedding] = await embedTexts([query], voyageApiKey, "query");
  if (!queryEmbedding) return [];

  // Score and rank. Non-finite scores are removed first because a NaN
  // difference makes the sort comparator inconsistent.
  const scored = candidates
    .map((entry) => ({
      operationId: entry.operationId,
      score: cosineSimilarity(queryEmbedding, entry.embedding),
    }))
    .filter((entry) => Number.isFinite(entry.score));

  scored.sort((left, right) => right.score - left.score);

  return scored.slice(0, limit).map((entry) => entry.operationId);
}
|
||||
Reference in New Issue
Block a user