Files
open-webui/backend/open_webui/retrieval/utils.py
T

1483 lines
54 KiB
Python
Raw Normal View History

import logging
2024-08-28 00:10:27 +02:00
import os
2025-11-22 21:33:14 -05:00
from typing import Awaitable, Optional, Union
2024-03-08 19:26:39 -08:00
2024-08-28 00:10:27 +02:00
import requests
import aiohttp
import asyncio
2025-02-26 23:51:39 -08:00
import hashlib
2025-03-31 16:43:37 +02:00
from concurrent.futures import ThreadPoolExecutor
2025-05-19 22:58:04 -04:00
import time
2025-10-04 02:02:26 -05:00
import re
2024-09-10 02:27:50 +01:00
from urllib.parse import quote
2024-04-25 07:49:59 -05:00
from huggingface_hub import snapshot_download
2025-12-20 22:50:44 +09:00
from langchain_classic.retrievers import (
ContextualCompressionRetriever,
EnsembleRetriever,
)
2024-04-22 18:36:46 -05:00
from langchain_community.retrievers import BM25Retriever
2024-08-28 00:10:27 +02:00
from langchain_core.documents import Document
2024-09-10 02:27:50 +01:00
2025-01-08 00:21:50 -08:00
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
2025-02-20 11:02:45 -08:00
2025-10-04 02:02:26 -05:00
from open_webui.models.users import UserModel
2025-02-26 15:42:19 -08:00
from open_webui.models.files import Files
2025-07-11 12:00:21 +04:00
from open_webui.models.knowledge import Knowledges
2025-09-14 10:26:46 +02:00
from open_webui.models.chats import Chats
2025-07-09 01:17:25 +04:00
from open_webui.models.notes import Notes
2026-02-08 21:24:38 -06:00
from open_webui.models.access_grants import AccessGrants
2026-03-26 19:01:33 -05:00
from open_webui.utils.access_control.files import has_access_to_file
2024-04-14 19:48:15 -04:00
2025-03-30 20:48:22 -07:00
from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.headers import include_user_info_headers
2025-09-14 10:26:46 +02:00
from open_webui.utils.misc import get_message_list
2025-03-30 20:48:22 -07:00
2025-10-04 02:02:26 -05:00
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader
2025-02-05 15:15:24 -08:00
2025-02-05 00:07:45 -08:00
from open_webui.env import (
2026-01-08 00:42:29 +04:00
AIOHTTP_CLIENT_TIMEOUT,
2025-02-05 00:07:45 -08:00
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
2026-01-01 01:27:07 +04:00
AIOHTTP_CLIENT_SESSION_SSL,
2025-02-05 00:07:45 -08:00
)
2025-02-04 13:04:36 -08:00
from open_webui.config import (
2025-03-30 21:55:15 -07:00
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
2025-02-04 13:04:36 -08:00
)
2024-09-10 02:27:50 +01:00
log = logging.getLogger(__name__)
2024-03-08 19:26:39 -08:00
2024-09-10 04:37:06 +01:00
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
2025-10-04 02:02:26 -05:00
def is_youtube_url(url: str) -> bool:
    """Report whether ``url`` points at ``youtube.com`` or ``youtu.be``.

    The scheme and a leading ``www.`` are both optional, but at least one
    character must follow the slash after the host.
    """
    pattern = r'^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$'
    return bool(re.match(pattern, url))
def get_loader(request, url: str):
    """Pick the document loader for ``url``.

    URLs recognised by ``is_youtube_url`` get a ``YoutubeLoader``; every other
    URL gets the web loader. Both are configured from
    ``request.app.state.config``.
    """
    wants_youtube = is_youtube_url(url)
    settings = request.app.state.config

    if wants_youtube:
        return YoutubeLoader(
            url,
            language=settings.YOUTUBE_LOADER_LANGUAGE,
            proxy_url=settings.YOUTUBE_LOADER_PROXY_URL,
        )

    return get_web_loader(
        url,
        verify_ssl=settings.ENABLE_WEB_LOADER_SSL_VERIFICATION,
        requests_per_second=settings.WEB_LOADER_CONCURRENT_REQUESTS,
        trust_env=settings.WEB_SEARCH_TRUST_ENV,
    )
2026-04-21 15:47:32 +09:00
def build_loader_from_config(request):
    """Create a ``Loader`` from the application's content-extraction settings.

    ``CONTENT_EXTRACTION_ENGINE`` is passed as ``engine``; every other option
    listed below is read from ``request.app.state.config`` and forwarded as a
    keyword argument of the same name.
    """
    # Function-scope import, kept local to this helper.
    from open_webui.retrieval.loaders.main import Loader

    settings = request.app.state.config
    engine = settings.CONTENT_EXTRACTION_ENGINE

    # Config attributes forwarded 1:1 to Loader(**kwargs).
    forwarded_options = (
        'DATALAB_MARKER_API_KEY',
        'DATALAB_MARKER_API_BASE_URL',
        'DATALAB_MARKER_ADDITIONAL_CONFIG',
        'DATALAB_MARKER_SKIP_CACHE',
        'DATALAB_MARKER_FORCE_OCR',
        'DATALAB_MARKER_PAGINATE',
        'DATALAB_MARKER_STRIP_EXISTING_OCR',
        'DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION',
        'DATALAB_MARKER_FORMAT_LINES',
        'DATALAB_MARKER_USE_LLM',
        'DATALAB_MARKER_OUTPUT_FORMAT',
        'EXTERNAL_DOCUMENT_LOADER_URL',
        'EXTERNAL_DOCUMENT_LOADER_API_KEY',
        'TIKA_SERVER_URL',
        'DOCLING_SERVER_URL',
        'DOCLING_API_KEY',
        'DOCLING_PARAMS',
        'PDF_EXTRACT_IMAGES',
        'PDF_LOADER_MODE',
        'DOCUMENT_INTELLIGENCE_ENDPOINT',
        'DOCUMENT_INTELLIGENCE_KEY',
        'DOCUMENT_INTELLIGENCE_MODEL',
        'MISTRAL_OCR_API_BASE_URL',
        'MISTRAL_OCR_API_KEY',
        'PADDLEOCR_VL_BASE_URL',
        'PADDLEOCR_VL_TOKEN',
        'MINERU_API_MODE',
        'MINERU_API_URL',
        'MINERU_API_KEY',
        'MINERU_API_TIMEOUT',
        'MINERU_PARAMS',
    )
    loader_kwargs = {option: getattr(settings, option) for option in forwarded_options}
    return Loader(engine=engine, **loader_kwargs)
2026-04-21 15:52:00 +09:00
def _guess_download_filename(response: requests.Response, url: str, content_type: str) -> str:
    """Pick a filename (with an extension when possible) for a downloaded body.

    Preference order: the last segment of the URL path, then the
    ``Content-Disposition`` filename, then ``download`` plus an extension
    guessed from ``content_type``.
    """
    import mimetypes
    import urllib.parse
    from email.message import Message

    url_path = urllib.parse.urlparse(url).path
    filename = os.path.basename(url_path) if url_path else ''

    if not filename or '.' not in filename:
        disposition = response.headers.get('Content-Disposition', '')
        if disposition:
            # Parse the header properly so trailing parameters
            # (e.g. `; filename*=...` or `; size=...`) do not end up in the name.
            header = Message()
            header['Content-Disposition'] = disposition
            declared = header.get_filename() or ''
            # The value is server-controlled: keep only the final path component.
            declared = os.path.basename(declared.replace('\\', '/'))
            if declared:
                filename = declared

    if not filename or '.' not in filename:
        ext = mimetypes.guess_extension(content_type) or ''
        filename = f'download{ext}'

    return filename


def _extract_text_from_binary_response(request, response: requests.Response, url: str) -> tuple[str, list]:
    """Download the response body to a temp file and extract text via the Loader.

    Args:
        request: Request whose app config is used to build the ``Loader``.
        response: An open ``requests`` response; the caller remains
            responsible for closing it.
        url: Original URL; stored as each document's ``source`` metadata.

    Returns:
        ``(content, docs)`` where ``content`` is the documents' page contents
        joined by single spaces and ``docs`` is the list returned by the loader.
    """
    import tempfile

    content_type = response.headers.get('Content-Type', '').split(';')[0].strip()
    filename = _guess_download_filename(response, url, content_type)

    # Only a plain alphanumeric extension is allowed into the temp-file name,
    # so a hostile filename cannot smuggle path separators into the suffix.
    extension = filename.rsplit('.', 1)[-1].lower() if '.' in filename else ''
    suffix = f'.{extension}' if extension.isalnum() else ''

    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp_path = tmp.name
            # Stream to disk instead of materialising the whole body in memory.
            for chunk in response.iter_content(chunk_size=1024 * 1024):
                tmp.write(chunk)

        loader = build_loader_from_config(request)
        docs = loader.load(filename, content_type, tmp_path)
        for doc in docs:
            doc.metadata['source'] = url
        content = ' '.join(doc.page_content for doc in docs)
        return content, docs
    finally:
        # Covers failures while writing as well as while loading.
        if tmp_path is not None:
            os.remove(tmp_path)
def _is_text_content_type(content_type: str) -> bool:
    """Return True if the content type should be handled by the web loader.

    Parameters after ``;`` are ignored and matching is case-insensitive.
    An empty or missing type is treated as text (assumed HTML), as is any
    ``text/*`` type. Otherwise the subtype is split into tokens on
    non-alphanumeric characters and must contain a textual token such as
    ``xml``, ``json`` or ``javascript``.

    Matching whole tokens rather than substrings matters: Office Open XML
    types (``application/vnd.openxmlformats-officedocument...``) contain the
    letters "xml" but are binary archives and must not be sent to the web
    loader.
    """
    textual_tokens = {'xml', 'json', 'json5', 'jsonl', 'jsonlines', 'ndjson', 'javascript'}

    ct = content_type.split(';')[0].strip().lower()
    if not ct:
        return True  # empty / missing → assume HTML
    if ct.startswith('text/'):
        return True

    # Fall back to the whole value when there is no "type/subtype" slash.
    subtype = ct.partition('/')[2] or ct
    return any(token in textual_tokens for token in re.split(r'[^a-z0-9]+', subtype))
2025-10-04 02:02:26 -05:00
def get_content_from_url(request, url: str) -> tuple[str, list]:
    """Fetch ``url`` and return its text content together with the documents.

    The URL is validated first, then probed with a streamed GET to read the
    ``Content-Type``. Textual or unknown content (and any URL whose probe
    failed) is loaded through ``get_loader``; anything else is downloaded and
    run through the extraction ``Loader``.

    Returns:
        ``(content, docs)`` where ``content`` is the documents' page contents
        joined by single spaces.

    Raises:
        Whatever ``validate_url`` raises for a rejected URL, plus any error
        from the selected loader.
    """
    from open_webui.retrieval.web.utils import validate_url

    # Validate URL before making any request (blocks private IPs, non-HTTP, filter list)
    validate_url(url)

    # NOTE(review): requests follows redirects by default, so the probe may end
    # up on a host other than the validated one — confirm this is acceptable.
    response = None
    content_type = ''
    try:
        # Streamed GET: headers are available without downloading the body.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()
        content_type = response.headers.get('Content-Type', '')
    except Exception as e:
        log.debug(f'get_content_from_url: probe failed for {url}, using web loader: {e}')
        # Release the connection if the request itself succeeded (e.g. 4xx/5xx).
        if response is not None:
            response.close()
        response = None
        content_type = ''

    # Text / HTML / unknown — use the configured web loader
    if response is None or _is_text_content_type(content_type):
        if response is not None:
            response.close()
        loader = get_loader(request, url)
        docs = loader.load()
        content = ' '.join(doc.page_content for doc in docs)
        return content, docs

    # Binary content (PDF, DOCX, XLSX, PPTX, etc.) — download and extract
    try:
        return _extract_text_from_binary_response(request, response, url)
    finally:
        response.close()
2025-10-04 02:02:26 -05:00
2026-03-17 17:58:01 -05:00
# Metadata key under which a chunk's content hash is stored.
CHUNK_HASH_KEY = '_chunk_hash'


def _content_hash(text: str) -> str:
    """Return the hex SHA-256 digest of ``text`` (encoded with the default UTF-8).

    Used as a stable chunk identifier for RRF de-duplication.
    """
    digest = hashlib.sha256()
    digest.update(text.encode())
    return digest.hexdigest()
2024-09-10 04:37:06 +01:00
class VectorSearchRetriever(BaseRetriever):
    """Retriever that embeds the query and searches one vector collection.

    Only the async hook performs a search; the sync hook returns no documents.
    """

    # Collection passed to the vector DB client's ``search``.
    collection_name: Any
    # Awaited as ``embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)``.
    embedding_function: Any
    # Passed to the vector DB client as ``limit``.
    top_k: int

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> list[Document]:
        """Synchronous retrieval hook; always returns an empty list.

        Args:
            query: String to find relevant documents for (unused).
            run_manager: The callback handler to use (unused).

        Returns:
            An empty list.
        """
        return []

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Embed ``query`` and return the matching chunks as documents.

        Each document's metadata is a copy of the stored metadata with
        ``CHUNK_HASH_KEY`` set to the hash of the chunk text, so the dicts
        held by the search result are not modified.

        Args:
            query: String to find relevant documents for.
            run_manager: The callback handler to use (unused).

        Returns:
            One ``Document`` per hit in the first result row.
        """
        embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
        result = await ASYNC_VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[embedding],
            limit=self.top_k,
        )

        # A single query vector was sent, so only the first row is relevant.
        documents = result.documents[0]
        metadatas = result.metadatas[0]

        return [
            Document(
                metadata={**(metadata or {}), CHUNK_HASH_KEY: _content_hash(text)},
                page_content=text,
            )
            for text, metadata in zip(documents, metadatas)
        ]
2024-09-10 04:37:06 +01:00
2026-03-17 17:58:01 -05:00
def query_doc(collection_name: str, query_embedding: list[float], k: int, user: UserModel = None):
    """Search one collection with a single query embedding.

    Args:
        collection_name: Collection to search.
        query_embedding: Query vector, sent as the only entry of ``vectors``.
        k: Passed to the vector DB client as ``limit``.
        user: Not used by this function.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.search`` returns (possibly a falsy value).

    Raises:
        Any exception from the search; it is logged and re-raised.
    """
    try:
        log.debug(f'query_doc:doc {collection_name}')
        hits = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )
        if hits:
            log.info(f'query_doc:result {hits.ids} {hits.metadatas}')
        return hits
    except Exception as e:
        log.exception(f'Error querying doc {collection_name} with limit {k}: {e}')
        raise
2024-04-25 17:03:00 -04:00
2025-02-18 21:14:58 -08:00
def get_doc(collection_name: str, user: UserModel = None):
    """Fetch the contents of one collection from the vector DB.

    Args:
        collection_name: Collection to fetch.
        user: Not used by this function.

    Returns:
        Whatever ``VECTOR_DB_CLIENT.get`` returns (possibly a falsy value).

    Raises:
        Any exception from the fetch; it is logged and re-raised.
    """
    try:
        log.debug(f'get_doc:doc {collection_name}')
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
        if result:
            # Prefix names this function (it previously claimed to be query_doc).
            log.info(f'get_doc:result {result.ids} {result.metadatas}')
        return result
    except Exception as e:
        log.exception(f'Error getting doc {collection_name}: {e}')
        raise
def get_enriched_texts(collection_result: GetResult) -> list[str]:
    """Append selected metadata to each chunk's text for BM25 indexing.

    For every text in ``collection_result.documents[0]`` the matching entry of
    ``collection_result.metadatas[0]`` contributes, when present and truthy:
    the file name (plus two copies of its tokenised form), the title, the
    heading path, the source and the snippet. All parts are joined by spaces.
    """
    # '_', '-' and '.' in file names become spaces so the name splits into words.
    separators_to_spaces = str.maketrans('_-.', '   ')

    enriched = []
    for position, body in enumerate(collection_result.documents[0]):
        meta = collection_result.metadatas[0][position]
        parts = [body]

        # File name, with its tokens repeated twice for extra weight in BM25 scoring.
        name = meta.get('name')
        if name:
            tokens = name.translate(separators_to_spaces)
            parts.append(f'Filename: {name} {tokens} {tokens}')

        title = meta.get('title')
        if title:
            parts.append(f'Title: {title}')

        # Section headings are only used when stored as a list.
        headings = meta.get('headings')
        if headings and isinstance(headings, list):
            parts.append('Section: ' + ' > '.join(str(h) for h in headings))

        source = meta.get('source')
        if source:
            parts.append(f'Source: {source}')

        # Snippet, as attached to web search results.
        snippet = meta.get('snippet')
        if snippet:
            parts.append(f'Snippet: {snippet}')

        enriched.append(' '.join(parts))

    return enriched
async def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
    enable_enriched_texts: bool = False,
) -> dict:
    """Query one collection with BM25 and/or vector search, then rerank.

    Args:
        collection_name: Collection searched by the vector retriever.
        collection_result: Pre-fetched collection contents; ``documents[0]``
            and ``metadatas[0]`` feed the BM25 index.
        query: The search query.
        embedding_function: Given to the vector retriever and the reranker.
        k: Hit count for each retriever, and the final cap when
            ``k < k_reranker``.
        reranking_function: Given to ``RerankCompressor``.
        k_reranker: ``top_n`` of the ``RerankCompressor``.
        r: ``r_score`` of the ``RerankCompressor``.
        hybrid_bm25_weight: ``<= 0`` uses vector search only, ``>= 1`` uses
            BM25 only, anything else weights BM25 by this value and vector
            search by ``1 - value``.
        enable_enriched_texts: Index ``get_enriched_texts`` output in BM25
            instead of the raw chunk texts.

    Returns:
        ``{'distances': [...], 'documents': [...], 'metadatas': [...]}`` with
        one inner list per key, or the same keys mapped to empty lists when
        the collection has no documents.

    Raises:
        Any exception from retrieval or reranking; it is logged and re-raised.
    """
    try:
        # Attribute checks come first so the content checks cannot raise.
        if (
            not collection_result
            or not hasattr(collection_result, 'documents')
            or not hasattr(collection_result, 'metadatas')
            or not collection_result.documents
            or not collection_result.documents[0]
        ):
            log.warning(f'query_doc_with_hybrid_search:no_docs {collection_name}')
            return {'documents': [], 'metadatas': [], 'distances': []}

        log.debug(f'query_doc_with_hybrid_search:doc {collection_name}')

        def _bm25_retriever():
            # Hash the ORIGINAL text so enriched BM25 texts still dedup
            # against the vector hits for the same chunk.
            original_texts = collection_result.documents[0]
            bm25_metadatas = [
                {**meta, CHUNK_HASH_KEY: _content_hash(original_texts[idx])}
                for idx, meta in enumerate(collection_result.metadatas[0])
            ]
            bm25_texts = get_enriched_texts(collection_result) if enable_enriched_texts else original_texts
            retriever = BM25Retriever.from_texts(
                texts=bm25_texts,
                metadatas=bm25_metadatas,
            )
            retriever.k = k
            return retriever

        def _vector_retriever():
            return VectorSearchRetriever(
                collection_name=collection_name,
                embedding_function=embedding_function,
                top_k=k,
            )

        # Build only the retrievers that take part; in particular the BM25
        # index over the whole collection is skipped for vector-only search.
        if hybrid_bm25_weight <= 0:
            retrievers, weights = [_vector_retriever()], [1.0]
        elif hybrid_bm25_weight >= 1:
            retrievers, weights = [_bm25_retriever()], [1.0]
        else:
            retrievers = [_bm25_retriever(), _vector_retriever()]
            weights = [hybrid_bm25_weight, 1.0 - hybrid_bm25_weight]

        # Use CHUNK_HASH_KEY for dedup so enriched BM25 texts don't defeat RRF
        ensemble_retriever = EnsembleRetriever(
            retrievers=retrievers,
            weights=weights,
            id_key=CHUNK_HASH_KEY,
        )
        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = await compression_retriever.ainvoke(query)

        distances = [d.metadata.get('score') for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
        if k < k_reranker:
            sorted_items = sorted(zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True)
            sorted_items = sorted_items[:k]
            if sorted_items:
                distances, documents, metadatas = map(list, zip(*sorted_items))
            else:
                distances, documents, metadatas = [], [], []

        result = {
            'distances': [distances],
            'documents': [documents],
            'metadatas': [metadatas],
        }

        log.info('query_doc_with_hybrid_search:result ' + f'{result["metadatas"]} {result["distances"]}')
        return result
    except Exception as e:
        log.exception(f'Error querying doc {collection_name} with hybrid search: {e}')
        raise
2025-02-18 21:14:58 -08:00
def merge_get_results(get_results: list[dict]) -> dict:
    """Concatenate the first row of several ``get`` results into one result.

    Each input dict must provide ``documents``, ``metadatas`` and ``ids``,
    each a list whose first element is the row to merge. The output has the
    same three keys, each holding a single combined row.
    """
    merged = {'documents': [], 'metadatas': [], 'ids': []}

    for entry in get_results:
        for field, bucket in merged.items():
            bucket.extend(entry[field][0])

    # Re-wrap each combined row in an outer list.
    return {field: [bucket] for field, bucket in merged.items()}
2025-03-25 19:09:17 +01:00
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    """Merge query results, de-duplicate by content and keep the top ``k``.

    Args:
        query_results: Dicts with ``distances``, ``documents`` and
            ``metadatas``, each a list whose first element is the row to
            merge. Results where any of the three is missing or empty are
            skipped, as are documents that are not strings.
        k: Maximum number of entries to keep.

    Returns:
        ``{'distances': [[...]], 'documents': [[...]], 'metadatas': [[...]]}``
        sorted by descending distance value. Duplicate documents (same
        SHA-256 of the text) keep the entry with the highest value. The inner
        lists are empty when nothing survives, including when ``k`` is 0.
    """
    best_by_hash = {}  # content hash -> (distance, document, metadata)

    for data in query_results:
        if not data.get('distances') or not data.get('documents') or not data.get('metadatas'):
            continue

        rows = zip(data['distances'][0], data['documents'][0], data['metadatas'][0])
        for distance, document, metadata in rows:
            if not isinstance(document, str):
                continue

            doc_hash = hashlib.sha256(document.encode()).hexdigest()
            current = best_by_hash.get(doc_hash)
            # Keep a new document, or replace an existing one if this score is better.
            if current is None or distance > current[0]:
                best_by_hash[doc_hash] = (distance, document, metadata)

    ranked = sorted(best_by_hash.values(), key=lambda item: item[0], reverse=True)

    # Slice BEFORE unpacking: an empty slice (e.g. k == 0) cannot be unzipped
    # into three columns.
    top = ranked[:k]
    if top:
        distances, documents, metadatas = (list(column) for column in zip(*top))
    else:
        distances, documents, metadatas = [], [], []

    return {
        'distances': [distances],
        'documents': [documents],
        'metadatas': [metadatas],
    }
2024-03-08 19:26:39 -08:00
2025-02-18 21:14:58 -08:00
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    """Fetch every named collection and merge the contents into one result.

    Falsy names are skipped. A collection whose fetch raises is logged and
    left out; the remaining collections are still merged.
    """
    fetched = []

    for name in collection_names:
        if not name:
            continue
        try:
            collection = get_doc(collection_name=name)
            if collection is not None:
                fetched.append(collection.model_dump())
        except Exception as e:
            log.exception(f'Error when querying the collection: {e}')

    return merge_get_results(fetched)
async def query_collection(
    request,
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    """Run every query against every collection and merge the best ``k`` hits.

    When ``request`` is given and hybrid search is enabled in its app config,
    hybrid search with reranking is tried first; if it raises, plain vector
    search is used instead.

    Args:
        request: Request giving access to the app config/state, or a falsy
            value to skip hybrid search.
        collection_names: Collections to search; falsy names yield no result.
        queries: Query strings, embedded in a single call.
        embedding_function: Awaited as ``embedding_function(queries, prefix=...)``.
        k: Per-collection hit limit and size of the merged result.

    Returns:
        The merged result from ``merge_and_sort_query_results``.
    """
    # When request is provided, try hybrid search + reranking if enabled
    if request and request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
        try:
            reranking_function = (
                (lambda query, documents: request.app.state.RERANKING_FUNCTION(query, documents))
                if request.app.state.RERANKING_FUNCTION
                else None
            )
            return await query_collection_with_hybrid_search(
                collection_names=collection_names,
                queries=queries,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=request.app.state.config.TOP_K_RERANKER,
                r=request.app.state.config.RELEVANCE_THRESHOLD,
                hybrid_bm25_weight=request.app.state.config.HYBRID_BM25_WEIGHT,
                enable_enriched_texts=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
            )
        except Exception as e:
            log.debug(f'Hybrid search failed, falling back to vector search: {e}')

    def process_query_collection(collection_name, query_embedding):
        """Search one collection; return ``(result_dict, None)`` or ``(None, error)``."""
        try:
            if collection_name:
                result = query_doc(
                    collection_name=collection_name,
                    k=k,
                    query_embedding=query_embedding,
                )
                if result is not None:
                    return result.model_dump(), None
            return None, None
        except Exception as e:
            log.exception(f'Error when querying the collection: {e}')
            return None, e

    # Generate all query embeddings (in one call)
    query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
    log.debug(f'query_collection: processing {len(queries)} queries across {len(collection_names)} collections')

    # The searches are blocking calls: run them in worker threads and AWAIT
    # them, so the event loop keeps serving other tasks while they complete.
    task_results = await asyncio.gather(
        *(
            asyncio.to_thread(process_query_collection, collection_name, query_embedding)
            for query_embedding in query_embeddings
            for collection_name in collection_names
        )
    )

    results = []
    error = False
    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        log.warning('All collection queries failed. No results returned.')

    return merge_and_sort_query_results(results, k=k)
2024-04-27 15:38:50 -04:00
async def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
    enable_enriched_texts: bool = False,
) -> dict:
    """Run hybrid search for every (collection, query) pair and merge the hits.

    Collections whose contents could not be fetched are skipped. The search
    parameters are forwarded unchanged to ``query_doc_with_hybrid_search``.

    Returns:
        The merged result from ``merge_and_sort_query_results``.

    Raises:
        Exception: if at least one search failed and none produced a result.
    """
    # Each collection's contents are fetched once, concurrently, and then
    # shared by all queries that target that collection.
    log.debug(
        'query_collection_with_hybrid_search: prefetching %d collections',
        len(collection_names),
    )

    async def _fetch_collection(name: str):
        try:
            return name, await ASYNC_VECTOR_DB_CLIENT.get(collection_name=name)
        except Exception as e:
            log.exception(f'Failed to fetch collection {name}: {e}')
            return name, None

    fetches = [_fetch_collection(name) for name in collection_names]
    collection_results = dict(await asyncio.gather(*fetches))

    log.info(f'Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections...')

    async def process_query(collection_name, query):
        """Return ``(result, None)`` on success or ``(None, error)`` on failure."""
        try:
            hits = await query_doc_with_hybrid_search(
                collection_name=collection_name,
                collection_result=collection_results[collection_name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
                hybrid_bm25_weight=hybrid_bm25_weight,
                enable_enriched_texts=enable_enriched_texts,
            )
            return hits, None
        except Exception as e:
            log.exception(f'Error when querying the collection with hybrid_search: {e}')
            return None, e

    # One task per (collection, query); collections that failed to fetch
    # (stored as None) are left out.
    tasks = [
        (collection_name, query)
        for collection_name in collection_names
        if collection_results[collection_name] is not None
        for query in queries
    ]

    # Run all queries concurrently
    outcomes = await asyncio.gather(*[process_query(name, query) for name, query in tasks])

    results = [hits for hits, err in outcomes if err is None and hits is not None]
    error = any(err is not None for _, err in outcomes)

    if error and not results:
        raise Exception('Hybrid search failed for all collections. Using Non-hybrid search as fallback.')

    return merge_and_sort_query_results(results, k=k)
2024-04-14 17:55:00 -04:00
2025-03-27 01:40:28 -07:00
2025-11-22 21:33:14 -05:00
def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = 'https://api.openai.com/v1',
    key: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> list[list[float]]:
    """Request embeddings for ``texts`` from an OpenAI-style ``/embeddings`` endpoint.

    Args:
        model: Model name sent in the request body.
        texts: Inputs to embed, sent as one batch.
        url: API base URL; ``/embeddings`` is appended.
        key: Bearer token.
        prefix: Added to the body under ``RAG_EMBEDDING_PREFIX_FIELD_NAME``
            when both are strings.
        user: When set and header forwarding is enabled, user info headers
            are added to the request.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list.

    Raises:
        requests.HTTPError: on an error status.
        ValueError: if the response has no ``data`` key.
    """
    log.debug(f'generate_openai_batch_embeddings:model {model} batch size: {len(texts)}')

    payload = {'input': texts, 'model': model}
    if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
        payload[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

    request_headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {key}',
    }
    if ENABLE_FORWARD_USER_INFO_HEADERS and user:
        request_headers = include_user_info_headers(request_headers, user)

    # NOTE(review): no timeout is passed, so this call can block indefinitely.
    response = requests.post(
        f'{url}/embeddings',
        headers=request_headers,
        json=payload,
    )
    response.raise_for_status()

    body = response.json()
    if 'data' not in body:
        raise ValueError("Unexpected OpenAI embeddings response: missing 'data' key")
    return [entry['embedding'] for entry in body['data']]
2025-11-22 21:33:14 -05:00
async def agenerate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = 'https://api.openai.com/v1',
    key: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> list[list[float]]:
    """Async variant: request embeddings from an OpenAI-style ``/embeddings`` endpoint.

    Args:
        model: Model name sent in the request body.
        texts: Inputs to embed, sent as one batch.
        url: API base URL; ``/embeddings`` is appended.
        key: Bearer token.
        prefix: Added to the body under ``RAG_EMBEDDING_PREFIX_FIELD_NAME``
            when both are strings.
        user: When set and header forwarding is enabled, user info headers
            are added to the request.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list.

    Raises:
        aiohttp.ClientResponseError: on an error status.
        ValueError: if the response has no ``data`` key.
    """
    log.debug(f'agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}')

    payload = {'input': texts, 'model': model}
    if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
        payload[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

    request_headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {key}',
    }
    if ENABLE_FORWARD_USER_INFO_HEADERS and user:
        request_headers = include_user_info_headers(request_headers, user)

    client_timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
    async with aiohttp.ClientSession(trust_env=True, timeout=client_timeout) as session:
        async with session.post(
            f'{url}/embeddings',
            headers=request_headers,
            json=payload,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        ) as response:
            response.raise_for_status()
            body = await response.json()
            if 'data' not in body:
                raise ValueError("Unexpected OpenAI embeddings response: missing 'data' key")
            return [entry['embedding'] for entry in body['data']]
2025-11-22 21:33:14 -05:00
2026-03-24 20:14:28 -05:00
def _retry_after_seconds(header_value, default: float = 1.0) -> float:
    """Turn a ``Retry-After`` header value into a sleep duration in seconds.

    Falls back to ``default`` when the value is missing, is not a number
    (e.g. the HTTP-date form), or is negative / not finite.
    """
    try:
        seconds = float(header_value)
    except (TypeError, ValueError):
        return default
    if not 0 <= seconds < float('inf'):
        return default
    return seconds


def generate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = '',
    version: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> list[list[float]]:
    """Request embeddings for ``texts`` from an Azure OpenAI deployment.

    The request is attempted up to five times; after a 429 response the
    function sleeps for the ``Retry-After`` duration (1 second if absent or
    unusable) before retrying.

    Args:
        model: Deployment name, used in the request path.
        texts: Inputs to embed, sent as one batch.
        url: Resource base URL.
        key: Sent as the ``api-key`` header.
        version: Sent as the ``api-version`` query parameter.
        prefix: Added to the body under ``RAG_EMBEDDING_PREFIX_FIELD_NAME``
            when both are strings.
        user: When set and header forwarding is enabled, user info headers
            are added to the request.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list.

    Raises:
        requests.HTTPError: on a non-429 error status.
        ValueError: if the response has no ``data`` key.
        Exception: if every attempt was answered with 429.
    """
    log.debug(f'generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}')

    json_data = {'input': texts}
    if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
        json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

    url = f'{url}/openai/deployments/{model}/embeddings?api-version={version}'

    # Headers do not change between attempts, so build them once.
    headers = {
        'Content-Type': 'application/json',
        'api-key': key,
    }
    if ENABLE_FORWARD_USER_INFO_HEADERS and user:
        headers = include_user_info_headers(headers, user)

    for _ in range(5):
        r = requests.post(
            url,
            headers=headers,
            json=json_data,
        )
        if r.status_code == 429:
            time.sleep(_retry_after_seconds(r.headers.get('Retry-After')))
            continue

        r.raise_for_status()
        data = r.json()
        if 'data' not in data:
            raise ValueError("Unexpected Azure OpenAI embeddings response: missing 'data' key")
        return [elem['embedding'] for elem in data['data']]

    raise Exception('Azure OpenAI embedding request failed: max retries (429) exceeded')
2026-03-24 20:14:28 -05:00
async def agenerate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = '',
    version: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> list[list[float]]:
    """Async variant: request embeddings from an Azure OpenAI deployment.

    A single request is made; there is no retry on 429 here.

    Args:
        model: Deployment name, used in the request path.
        texts: Inputs to embed, sent as one batch.
        url: Resource base URL.
        key: Sent as the ``api-key`` header.
        version: Sent as the ``api-version`` query parameter.
        prefix: Added to the body under ``RAG_EMBEDDING_PREFIX_FIELD_NAME``
            when both are strings.
        user: When set and header forwarding is enabled, user info headers
            are added to the request.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list.

    Raises:
        aiohttp.ClientResponseError: on an error status.
        ValueError: if the response has no ``data`` key.
    """
    log.debug(f'agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}')

    payload = {'input': texts}
    if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
        payload[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

    endpoint = f'{url}/openai/deployments/{model}/embeddings?api-version={version}'

    request_headers = {
        'Content-Type': 'application/json',
        'api-key': key,
    }
    if ENABLE_FORWARD_USER_INFO_HEADERS and user:
        request_headers = include_user_info_headers(request_headers, user)

    client_timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
    async with aiohttp.ClientSession(trust_env=True, timeout=client_timeout) as session:
        async with session.post(
            endpoint,
            headers=request_headers,
            json=payload,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        ) as response:
            response.raise_for_status()
            body = await response.json()
            if 'data' not in body:
                raise ValueError("Unexpected Azure OpenAI embeddings response: missing 'data' key")
            return [entry['embedding'] for entry in body['data']]
2025-11-22 21:33:14 -05:00
2026-03-24 20:14:28 -05:00
def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = '',
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> list[list[float]]:
    """
    Embed a batch of texts through Ollama's ``/api/embed`` endpoint (blocking).

    Args:
        model: Ollama model name placed in the request payload.
        texts: Texts to embed in a single request.
        url: Base URL of the Ollama server; ``/api/embed`` is appended.
        key: Bearer token sent in the ``Authorization`` header.
        prefix: Added to the payload under RAG_EMBEDDING_PREFIX_FIELD_NAME
            when both that setting and ``prefix`` are strings.
        user: When set and ENABLE_FORWARD_USER_INFO_HEADERS is truthy, user
            info headers are added to the request.

    Returns:
        The ``embeddings`` list from the JSON response.

    Raises:
        Exception: On any non-200 status; the message carries the status code
            and the server's ``error`` field, or the raw body when the
            response is not a JSON object.
        ValueError: If a 200 response has no ``embeddings`` key.
    """
    log.debug(f'generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}')

    json_data = {'input': texts, 'model': model, 'truncate': True}
    if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
        json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {key}',
    }
    if ENABLE_FORWARD_USER_INFO_HEADERS and user:
        headers = include_user_info_headers(headers, user)

    r = requests.post(
        f'{url}/api/embed',
        headers=headers,
        json=json_data,
    )

    if r.status_code != 200:
        # Error bodies are not guaranteed to be JSON (e.g. an HTML page from a
        # reverse proxy). Decoding failures must not mask the HTTP status, so
        # fall back to the raw text instead of letting the decode error escape.
        try:
            error_body = r.json()
        except ValueError:
            error_body = None
        if isinstance(error_body, dict):
            error_detail = error_body.get('error', r.text)
        else:
            error_detail = r.text
        raise Exception(f'Ollama embed error ({r.status_code}): {error_detail}')

    data = r.json()
    if 'embeddings' in data:
        return data['embeddings']
    raise ValueError("Unexpected Ollama embeddings response: missing 'embeddings' key")
2025-11-22 21:33:14 -05:00
async def agenerate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = '',
    prefix: Optional[str] = None,
    user: Optional[UserModel] = None,
) -> list[list[float]]:
    """
    Asynchronously embed a batch of texts through Ollama's ``/api/embed`` endpoint.

    Args:
        model: Ollama model name placed in the request payload.
        texts: Texts to embed in a single request.
        url: Base URL of the Ollama server; ``/api/embed`` is appended.
        key: Bearer token sent in the ``Authorization`` header.
        prefix: Added to the payload under RAG_EMBEDDING_PREFIX_FIELD_NAME
            when both that setting and ``prefix`` are strings.
        user: When set and ENABLE_FORWARD_USER_INFO_HEADERS is truthy, user
            info headers are added to the request.

    Returns:
        The ``embeddings`` list from the JSON response.

    Raises:
        Exception: On any non-200 status; the message carries the status code
            and the server's ``error`` field, or the raw body when the
            response is not a JSON object.
        ValueError: If a 200 response has no ``embeddings`` key.
    """
    log.debug(f'agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}')

    form_data = {'input': texts, 'model': model, 'truncate': True}
    if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
        form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

    headers = {
        'Content-Type': 'application/json',
        'Authorization': f'Bearer {key}',
    }
    if ENABLE_FORWARD_USER_INFO_HEADERS and user:
        headers = include_user_info_headers(headers, user)

    async with aiohttp.ClientSession(
        trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
    ) as session:
        async with session.post(
            f'{url}/api/embed',
            headers=headers,
            json=form_data,
            ssl=AIOHTTP_CLIENT_SESSION_SSL,
        ) as r:
            if r.status != 200:
                # Error bodies are not guaranteed to be JSON: aiohttp raises
                # ContentTypeError on a non-JSON content type and a ValueError
                # subclass on malformed JSON. Neither may hide the HTTP status,
                # so fall back to the raw body text.
                try:
                    error_data = await r.json()
                except (aiohttp.ContentTypeError, ValueError):
                    error_data = None
                if isinstance(error_data, dict):
                    error_detail = error_data.get('error', str(error_data))
                else:
                    error_detail = await r.text()
                raise Exception(f'Ollama embed error ({r.status}): {error_detail}')

            data = await r.json()
            if 'embeddings' in data:
                return data['embeddings']
            raise ValueError("Unexpected Ollama embeddings response: missing 'embeddings' key")
2024-04-27 15:38:50 -04:00
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
    azure_api_version=None,
    enable_async=True,
    concurrent_requests=0,
) -> Awaitable:
    """
    Build the async embedding callable for the configured engine.

    The returned coroutine function has the signature
    ``(query, prefix=None, user=None)``.

    Args:
        embedding_engine: ``''`` selects the local model object passed as
            ``embedding_function``; ``'ollama'``, ``'openai'`` and
            ``'azure_openai'`` route through ``generate_embeddings``.
        embedding_model: Model (or deployment) name for the HTTP engines.
        embedding_function: Local model exposing ``encode(...)``; only used
            when ``embedding_engine`` is ``''``.
        url: Base URL for the HTTP engines.
        key: API key for the HTTP engines.
        embedding_batch_size: Batch size; coerced with ``int()`` in both
            branches. For the HTTP engines values below 1 are treated as 1.
        azure_api_version: Forwarded to ``generate_embeddings``.
        enable_async: When true, all batches of a list query are awaited
            concurrently via ``asyncio.gather``; otherwise one after another.
        concurrent_requests: Upper bound on in-flight batch requests when
            ``enable_async`` is true; 0 means unlimited (no semaphore).

    Returns:
        The async embedding function described above.

    Raises:
        ValueError: If ``embedding_engine`` is not one of the known values.
    """
    if embedding_engine == '':
        # Sentence transformers: CPU-bound sync operation, so run it in a
        # worker thread to keep the event loop responsive.
        def encode_in_thread(query, prefix=None):
            return embedding_function.encode(
                query,
                batch_size=int(embedding_batch_size),
                **({'prompt': prefix} if prefix else {}),
            ).tolist()

        async def async_embedding_function(query, prefix=None, user=None):
            return await asyncio.to_thread(encode_in_thread, query, prefix)

        return async_embedding_function

    if embedding_engine in ['ollama', 'openai', 'azure_openai']:

        def embed(query, prefix=None, user=None):
            # Returns the coroutine; callers await it.
            return generate_embeddings(
                engine=embedding_engine,
                model=embedding_model,
                text=query,
                prefix=prefix,
                url=url,
                key=key,
                user=user,
                azure_api_version=azure_api_version,
            )

        async def async_embedding_function(query, prefix=None, user=None):
            if not isinstance(query, list):
                return await embed(query, prefix, user)

            # Coerce like the local-model branch does: configuration values
            # may arrive as strings, and a step of 0 (or a negative step)
            # would crash range() or silently produce no batches at all.
            batch_size = max(1, int(embedding_batch_size))
            batches = [query[i : i + batch_size] for i in range(0, len(query), batch_size)]

            if enable_async:
                log.debug(f'generate_multiple_async: Processing {len(batches)} batches in parallel')
                # Use semaphore to limit concurrent embedding API requests
                # 0 = unlimited (no semaphore)
                if concurrent_requests:
                    semaphore = asyncio.Semaphore(concurrent_requests)

                    async def generate_batch_with_semaphore(batch):
                        async with semaphore:
                            return await embed(batch, prefix=prefix, user=user)

                    tasks = [generate_batch_with_semaphore(batch) for batch in batches]
                else:
                    tasks = [embed(batch, prefix=prefix, user=user) for batch in batches]
                batch_results = await asyncio.gather(*tasks)
            else:
                log.debug(f'generate_multiple_async: Processing {len(batches)} batches sequentially')
                batch_results = []
                for batch in batches:
                    batch_results.append(await embed(batch, prefix=prefix, user=user))

            # Flatten results - raise if any batch failed
            embeddings = []
            for i, batch_embeddings in enumerate(batch_results):
                if batch_embeddings is None:
                    raise Exception(f'Embedding generation failed for batch {i + 1}/{len(batches)}')
                embeddings.extend(batch_embeddings)

            log.debug(
                f'generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches'
            )
            return embeddings

        return async_embedding_function

    raise ValueError(f'Unknown embedding engine: {embedding_engine}')
2024-04-22 15:49:58 -05:00
2025-11-22 21:33:14 -05:00
async def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    """
    Dispatch an embedding request to the async batch function for ``engine``.

    ``url``, ``key``, ``user`` and (for Azure) ``azure_api_version`` are read
    from ``kwargs``. When a ``prefix`` is given and
    RAG_EMBEDDING_PREFIX_FIELD_NAME is ``None``, the prefix is prepended to
    every text before the request is made.

    Returns:
        A single embedding when ``text`` is a string, the full list of
        embeddings otherwise, or ``None`` when the batch function returned
        ``None`` or ``engine`` is not ``'ollama'``, ``'openai'`` or
        ``'azure_openai'``.
    """
    url = kwargs.get('url', '')
    key = kwargs.get('key', '')
    user = kwargs.get('user')

    # No dedicated prefix field configured: bake the prefix into the text itself.
    if prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None:
        text = [f'{prefix}{entry}' for entry in text] if isinstance(text, list) else f'{prefix}{text}'

    # The batch functions always take a list, so wrap a lone input.
    batch = text if isinstance(text, list) else [text]

    if engine == 'ollama':
        embeddings = await agenerate_ollama_batch_embeddings(
            model=model,
            texts=batch,
            url=url,
            key=key,
            prefix=prefix,
            user=user,
        )
    elif engine == 'openai':
        embeddings = await agenerate_openai_batch_embeddings(model, batch, url, key, prefix, user)
    elif engine == 'azure_openai':
        azure_api_version = kwargs.get('azure_api_version', '')
        embeddings = await agenerate_azure_openai_batch_embeddings(
            model,
            batch,
            url,
            key,
            azure_api_version,
            prefix,
            user,
        )
    else:
        return None

    if embeddings is None:
        return None
    # Unwrap again when the caller passed a single string.
    return embeddings[0] if isinstance(text, str) else embeddings
def get_reranking_function(reranking_engine, reranking_model, reranking_function, reranking_batch_size=32):
    """
    Wrap a reranker object in a ``(query, documents, user=None)`` callable.

    Returns ``None`` when no reranker is supplied. The returned callable
    scores ``(query, doc.page_content)`` pairs via
    ``reranking_function.predict``: the ``'external'`` engine receives the
    ``user`` keyword, every other engine receives
    ``batch_size=int(reranking_batch_size)`` and ignores ``user``.
    ``reranking_model`` is not used here.
    """
    if reranking_function is None:
        return None

    def build_pairs(query, documents):
        return [(query, doc.page_content) for doc in documents]

    if reranking_engine == 'external':

        def rerank(query, documents, user=None):
            return reranking_function.predict(build_pairs(query, documents), user=user)

    else:

        def rerank(query, documents, user=None):
            return reranking_function.predict(
                build_pairs(query, documents), batch_size=int(reranking_batch_size)
            )

    return rerank
2026-04-17 13:47:21 +09:00
async def filter_accessible_collections(
    collection_names: set[str],
    user: UserModel,
    access_type: str = 'read',
) -> set[str]:
    """
    Narrow a set of collection names down to those ``user`` may access.

    Admins get the input set back unchanged. For everyone else each name is
    judged by its shape:

    - ``knowledge-bases``   → dropped (system meta-collection)
    - ``file-<id>``         → kept if has_access_to_file grants ``access_type``
    - ``user-memory-<id>``  → kept only for the user's own id
    - ``web-search-*``      → kept (ephemeral per-query collections)
    - anything else         → kept if Knowledges.check_access_by_user_id
                              grants ``access_type``, or if no knowledge base
                              with that id exists (legacy/ephemeral name)
    """
    if user.role == 'admin':
        return collection_names

    allowed = set()
    for name in collection_names:
        if name == 'knowledge-bases':
            # System meta-collection — never exposed to non-admins.
            continue

        if name.startswith('file-'):
            if await has_access_to_file(file_id=name.removeprefix('file-'), access_type=access_type, user=user):
                allowed.add(name)
            continue

        if name.startswith('user-memory-'):
            if name == f'user-memory-{user.id}':
                allowed.add(name)
            continue

        if name.startswith('web-search-'):
            # Ephemeral collections created by process_web_search — safe
            # to allow because they contain only transient web-search
            # results scoped to the requesting user's session.
            allowed.add(name)
            continue

        # May be a knowledge-base ID or a legacy/ephemeral collection.
        # A real KB is subject to access control; a name that matches no KB
        # is treated as a non-sensitive collection (e.g. legacy model
        # knowledge, process_text SHA256 collections) and let through.
        if await Knowledges.check_access_by_user_id(name, user.id, permission=access_type):
            allowed.add(name)
        elif not await Knowledges.get_knowledge_by_id(name):
            allowed.add(name)

    return allowed
async def get_sources_from_items(
    request,
    items,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context=False,
    user: Optional[UserModel] = None,
):
    """
    Resolve attached items (text, notes, chats, URLs, files, collections)
    into source records.

    Each item either yields its content directly (full-context items, notes,
    chats, URLs, pre-fetched docs) or contributes collection names that are
    then read from the vector store: completely when ``full_context`` is
    true, otherwise through ``query_collection`` with ``queries`` and ``k``.
    A collection already processed for an earlier item is not fetched again.

    Args:
        request: Request object; ``request.app.state.config`` is consulted
            for BYPASS_EMBEDDING_AND_RETRIEVAL.
        items: Iterable of item dicts. An item's ``'data'`` key is removed
            in place once the item produced a result.
        queries: Queries forwarded to ``query_collection``.
        embedding_function: Forwarded to ``query_collection``.
        k: Result count forwarded to ``query_collection``.
        reranking_function, k_reranker, r, hybrid_bm25_weight, hybrid_search:
            Part of the signature but not used by this function body, apart
            from ``reranking_function`` appearing in the debug log.
        full_context: Fetch whole collections instead of querying them.
        user: Requesting user. Note, chat, stored-file and collection items
            dereference ``user.role``/``user.id`` and therefore need a user;
            collection-name filtering is skipped when ``user`` is falsy.

    Returns:
        A list of dicts with ``source`` (the item), ``document``,
        ``metadata`` and, when available, ``distances``.
    """
    log.debug(f'items: {items} {queries} {embedding_function} {reranking_function} {full_context}')

    extracted_collections = []
    query_results = []

    for item in items:
        query_result = None
        collection_names = []

        if item.get('type') == 'text':
            # Raw Text
            # Used during temporary chat file uploads or web page & youtube attachments
            if item.get('context') == 'full':
                if item.get('file'):
                    # if item has file data, use it
                    query_result = {
                        'documents': [[item.get('file', {}).get('data', {}).get('content')]],
                        'metadatas': [[item.get('file', {}).get('meta', {})]],
                    }

            if query_result is None:
                # Fallback
                if item.get('collection_name'):
                    # If item has a collection name, use it
                    collection_names.append(item.get('collection_name'))
                elif item.get('file'):
                    # If item has file data, use it
                    query_result = {
                        'documents': [[item.get('file', {}).get('data', {}).get('content')]],
                        'metadatas': [[item.get('file', {}).get('meta', {})]],
                    }
                else:
                    # Fallback to item content
                    query_result = {
                        'documents': [[item.get('content')]],
                        'metadatas': [[{'file_id': item.get('id'), 'name': item.get('name')}]],
                    }

        elif item.get('type') == 'note':
            # Note Attached
            note = await Notes.get_note_by_id(item.get('id'))

            if note and (
                user.role == 'admin'
                or note.user_id == user.id
                or await AccessGrants.has_access(
                    user_id=user.id,
                    resource_type='note',
                    resource_id=note.id,
                    permission='read',
                )
            ):
                # User has access to the note
                query_result = {
                    'documents': [[note.data.get('content', {}).get('md', '')]],
                    'metadatas': [[{'file_id': note.id, 'name': note.title}]],
                }

        elif item.get('type') == 'chat':
            # Chat Attached
            chat = await Chats.get_chat_by_id(item.get('id'))

            if chat and (user.role == 'admin' or chat.user_id == user.id):
                messages_map = chat.chat.get('history', {}).get('messages', {})
                message_id = chat.chat.get('history', {}).get('currentId')

                if messages_map and message_id:
                    # Reconstruct the message list in order
                    message_list = get_message_list(messages_map, message_id)
                    message_history = '\n'.join(
                        [f'#### {m.get("role", "user").capitalize()}\n{m.get("content")}\n' for m in message_list]
                    )

                    # User has access to the chat
                    query_result = {
                        'documents': [[message_history]],
                        'metadatas': [[{'file_id': chat.id, 'name': chat.title}]],
                    }

        elif item.get('type') == 'url':
            content, docs = get_content_from_url(request, item.get('url'))
            if docs:
                query_result = {
                    'documents': [[content]],
                    'metadatas': [[{'url': item.get('url'), 'name': item.get('url')}]],
                }

        elif item.get('type') == 'file':
            if item.get('context') == 'full' or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
                if item.get('file', {}).get('data', {}).get('content', ''):
                    # Manual Full Mode Toggle
                    # Used from chat file modal, we can assume that the file content will be available from item.get("file").get("data", {}).get("content")
                    query_result = {
                        'documents': [[item.get('file', {}).get('data', {}).get('content', '')]],
                        'metadatas': [
                            [
                                {
                                    'file_id': item.get('id'),
                                    'name': item.get('name'),
                                    **item.get('file').get('data', {}).get('metadata', {}),
                                }
                            ]
                        ],
                    }
                elif item.get('id'):
                    file_object = await Files.get_file_by_id(item.get('id'))
                    if file_object and (
                        user.role == 'admin'
                        or file_object.user_id == user.id
                        or await has_access_to_file(item.get('id'), 'read', user)
                    ):
                        query_result = {
                            'documents': [[file_object.data.get('content', '')]],
                            'metadatas': [
                                [
                                    {
                                        'file_id': item.get('id'),
                                        'name': file_object.filename,
                                        'source': file_object.filename,
                                    }
                                ]
                            ],
                        }
            else:
                # Fallback to collection names
                if item.get('legacy'):
                    collection_names.append(f'{item["id"]}')
                else:
                    collection_names.append(f'file-{item["id"]}')

        elif item.get('type') == 'collection':
            # Manual Full Mode Toggle for Collection
            knowledge_base = await Knowledges.get_knowledge_by_id(item.get('id'))

            # Access is checked exactly once here; everything below runs only
            # for a knowledge base the user may read.
            if knowledge_base and (
                user.role == 'admin'
                or knowledge_base.user_id == user.id
                or await AccessGrants.has_access(
                    user_id=user.id,
                    resource_type='knowledge',
                    resource_id=knowledge_base.id,
                    permission='read',
                )
            ):
                if item.get('context') == 'full' or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
                    files = await Knowledges.get_files_by_id(knowledge_base.id)

                    documents = []
                    metadatas = []
                    for file in files:
                        documents.append(file.data.get('content', ''))
                        metadatas.append(
                            {
                                'file_id': file.id,
                                'name': file.filename,
                                'source': file.filename,
                            }
                        )

                    query_result = {
                        'documents': [documents],
                        'metadatas': [metadatas],
                    }
                else:
                    # Fallback to collection names
                    if item.get('legacy'):
                        collection_names = item.get('collection_names', [])
                    else:
                        collection_names.append(item['id'])

        elif item.get('docs'):
            # BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL
            query_result = {
                'documents': [[doc.get('content') for doc in item.get('docs')]],
                'metadatas': [[doc.get('metadata') for doc in item.get('docs')]],
            }
        elif item.get('collection_name'):
            # Direct Collection Name
            collection_names.append(item['collection_name'])
        elif item.get('collection_names'):
            # Collection Names List
            collection_names.extend(item['collection_names'])

        # If query_result is None
        # Fallback to collection names and vector search the collections
        if query_result is None and collection_names:
            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f'skipping {item} as it has already been extracted')
                continue

            # Filter out collections the user cannot read
            if user:
                collection_names = await filter_accessible_collections(collection_names, user)
                if not collection_names:
                    log.debug(f'access denied for all collections in item {item}')
                    continue

            try:
                if full_context:
                    # Sync helper makes blocking VECTOR_DB_CLIENT calls;
                    # offload so the async caller's event loop stays free.
                    query_result = await asyncio.to_thread(get_all_items_from_collections, collection_names)
                else:
                    query_result = await query_collection(
                        request,
                        collection_names=collection_names,
                        queries=queries,
                        embedding_function=embedding_function,
                        k=k,
                    )
            except Exception as e:
                # Best effort: a failing collection must not abort the other items.
                log.exception(e)

            extracted_collections.extend(collection_names)

        if query_result:
            if 'data' in item:
                del item['data']
            query_results.append({**query_result, 'file': item})

    sources = []
    for query_result in query_results:
        try:
            if 'documents' in query_result:
                if 'metadatas' in query_result:
                    source = {
                        'source': query_result['file'],
                        'document': query_result['documents'][0],
                        'metadata': query_result['metadatas'][0],
                    }
                    if 'distances' in query_result and query_result['distances']:
                        source['distances'] = query_result['distances'][0]

                    sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources
2024-04-04 12:07:42 -06:00
2024-04-25 07:49:59 -05:00
def get_model_path(model: str, update_model: bool = False):
    """
    Resolve a model name to a local snapshot path via ``huggingface_hub``.

    An existing filesystem path is returned as-is, as is a path-like name
    (contains a backslash or more than one ``/``) when only local files may
    be used. A name without ``/`` is qualified with the
    ``sentence-transformers/`` organisation. Otherwise ``snapshot_download``
    is asked for the snapshot path, using SENTENCE_TRANSFORMERS_HOME as the
    cache directory; downloads are permitted only when ``update_model`` is
    true and OFFLINE_MODE is falsy.

    If the lookup fails the error is logged and ``model`` is returned
    unchanged, except in OFFLINE_MODE where the exception is re-raised.
    """
    # OFFLINE_MODE always wins over an explicit update request.
    local_files_only = True if OFFLINE_MODE else not update_model

    # Construct huggingface_hub kwargs with local_files_only to return the snapshot path
    snapshot_kwargs = {
        'cache_dir': os.getenv('SENTENCE_TRANSFORMERS_HOME'),
        'local_files_only': local_files_only,
    }

    log.debug(f'model: {model}')
    log.debug(f'snapshot_kwargs: {snapshot_kwargs}')

    # Inspiration from upstream sentence_transformers
    looks_like_path = '\\' in model or model.count('/') > 1
    if os.path.exists(model) or (looks_like_path and local_files_only):
        # If fully qualified path exists, return input, else set repo_id
        return model

    if '/' not in model:
        # Set valid repo_id for model short-name
        model = f'sentence-transformers/{model}'

    snapshot_kwargs['repo_id'] = model

    # Attempt to query the huggingface_hub library to determine the local path and/or to update
    try:
        resolved_path = snapshot_download(**snapshot_kwargs)
        log.debug(f'model_repo_path: {resolved_path}')
        return resolved_path
    except Exception as e:
        log.exception(f'Cannot determine model snapshot path: {e}')
        if OFFLINE_MODE:
            raise
        return model
2024-04-25 07:49:59 -05:00
2024-04-22 18:36:46 -05:00
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
2024-08-28 00:10:27 +02:00
from langchain_core.documents import BaseDocumentCompressor, Document
2024-04-22 18:36:46 -05:00
class RerankCompressor(BaseDocumentCompressor):
    """
    Document compressor that scores documents against a query and keeps the
    best ``top_n``.

    Scores come from ``reranking_function`` when one is set; otherwise from
    the cosine similarity between query and document embeddings produced by
    ``embedding_function``.
    """

    # Async callable ``(text_or_texts, prefix)`` used when no reranker is set.
    embedding_function: Any
    # Maximum number of documents returned.
    top_n: int
    # Callable ``(query, documents)`` returning one score per document, or None.
    reranking_function: Any
    # Minimum score a document needs to be kept; a falsy value disables the filter.
    r_score: float

    class Config:
        extra = 'forbid'
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Synchronous entry point; performs no reranking.

        Args:
            documents: The retrieved documents.
            query: The query context.
            callbacks: Optional callbacks to run during compression.

        Returns:
            Always an empty list. The actual scoring is implemented in
            ``acompress_documents``.
        """
        return []

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Score ``documents`` against ``query`` and return the best matches.

        Args:
            documents: The retrieved documents.
            query: The query context.
            callbacks: Accepted for interface compatibility; not used.

        Returns:
            Up to ``top_n`` new ``Document`` objects sorted by descending
            score, each carrying the score under ``metadata['score']``.
            Documents scoring below ``r_score`` are dropped when ``r_score``
            is truthy. If no scores could be obtained, the input documents
            are returned unchanged.
        """
        if not documents:
            # Nothing to score; skip the model call entirely.
            return []

        scores = None
        if self.reranking_function is not None:
            # The reranker is synchronous; keep it off the event loop.
            scores = await asyncio.to_thread(self.reranking_function, query, documents)
        else:
            from sentence_transformers import util

            query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
            document_embedding = await self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        if scores is None:
            log.warning('No valid scores found, check your reranking function. Returning original documents.')
            return documents

        score_values = scores if isinstance(scores, list) else scores.tolist()
        docs_with_scores = list(zip(documents, score_values))
        if self.r_score:
            docs_with_scores = [(d, s) for d, s in docs_with_scores if s >= self.r_score]

        ranked = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        # Build fresh metadata dicts so the caller's documents are not
        # mutated as a side effect of scoring.
        return [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, 'score': doc_score},
            )
            for doc, doc_score in ranked[: self.top_n]
        ]