Files
open-webui/backend/open_webui/retrieval/utils.py
T

1302 lines
46 KiB
Python
Raw Normal View History

import logging
2024-08-28 00:10:27 +02:00
import os
2025-11-22 21:33:14 -05:00
from typing import Awaitable, Optional, Union
2024-03-08 19:26:39 -08:00
2024-08-28 00:10:27 +02:00
import requests
import aiohttp
import asyncio
2025-02-26 23:51:39 -08:00
import hashlib
2025-03-31 16:43:37 +02:00
from concurrent.futures import ThreadPoolExecutor
2025-05-19 22:58:04 -04:00
import time
2025-10-04 02:02:26 -05:00
import re
2024-09-10 02:27:50 +01:00
from urllib.parse import quote
2024-04-25 07:49:59 -05:00
from huggingface_hub import snapshot_download
2025-12-20 22:50:44 +09:00
from langchain_classic.retrievers import (
ContextualCompressionRetriever,
EnsembleRetriever,
)
2024-04-22 18:36:46 -05:00
from langchain_community.retrievers import BM25Retriever
2024-08-28 00:10:27 +02:00
from langchain_core.documents import Document
2024-09-10 02:27:50 +01:00
2025-01-08 00:21:50 -08:00
from open_webui.config import VECTOR_DB
from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT
2025-02-20 11:02:45 -08:00
2025-10-04 02:02:26 -05:00
from open_webui.models.users import UserModel
2025-02-26 15:42:19 -08:00
from open_webui.models.files import Files
2025-07-11 12:00:21 +04:00
from open_webui.models.knowledge import Knowledges
2025-09-14 10:26:46 +02:00
from open_webui.models.chats import Chats
2025-07-09 01:17:25 +04:00
from open_webui.models.notes import Notes
2026-02-08 21:24:38 -06:00
from open_webui.models.access_grants import AccessGrants
2024-04-14 19:48:15 -04:00
2025-03-30 20:48:22 -07:00
from open_webui.retrieval.vector.main import GetResult
from open_webui.utils.headers import include_user_info_headers
2025-09-14 10:26:46 +02:00
from open_webui.utils.misc import get_message_list
2025-03-30 20:48:22 -07:00
2025-10-04 02:02:26 -05:00
from open_webui.retrieval.web.utils import get_web_loader
from open_webui.retrieval.loaders.youtube import YoutubeLoader
2025-02-05 15:15:24 -08:00
2025-02-05 00:07:45 -08:00
from open_webui.env import (
2026-01-08 00:42:29 +04:00
AIOHTTP_CLIENT_TIMEOUT,
2025-02-05 00:07:45 -08:00
OFFLINE_MODE,
ENABLE_FORWARD_USER_INFO_HEADERS,
2026-01-01 01:27:07 +04:00
AIOHTTP_CLIENT_SESSION_SSL,
2025-02-05 00:07:45 -08:00
)
2025-02-04 13:04:36 -08:00
from open_webui.config import (
2025-03-30 21:55:15 -07:00
RAG_EMBEDDING_QUERY_PREFIX,
RAG_EMBEDDING_CONTENT_PREFIX,
RAG_EMBEDDING_PREFIX_FIELD_NAME,
2025-02-04 13:04:36 -08:00
)
2024-09-10 02:27:50 +01:00
log = logging.getLogger(__name__)
2024-03-08 19:26:39 -08:00
2024-09-10 04:37:06 +01:00
from typing import Any
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.retrievers import BaseRetriever
2025-10-04 02:02:26 -05:00
def is_youtube_url(url: str) -> bool:
    """Return True when *url* is a youtube.com or youtu.be link.

    The scheme (http/https) and a leading ``www.`` are optional, but a non-empty
    path after the host is required.
    """
    pattern = r'^(https?://)?(www\.)?(youtube\.com|youtu\.be)/.+$'
    matched = re.match(pattern, url)
    return bool(matched)
def get_loader(request, url: str):
    """Choose the document loader for *url* using the app configuration.

    YouTube links get a ``YoutubeLoader`` configured with the app's language and
    proxy settings. Every other URL goes through ``get_web_loader`` with the app's
    SSL-verification, request-rate and trust-env settings.
    """
    use_youtube = is_youtube_url(url)
    config = request.app.state.config

    if not use_youtube:
        return get_web_loader(
            url,
            verify_ssl=config.ENABLE_WEB_LOADER_SSL_VERIFICATION,
            requests_per_second=config.WEB_LOADER_CONCURRENT_REQUESTS,
            trust_env=config.WEB_SEARCH_TRUST_ENV,
        )

    return YoutubeLoader(
        url,
        language=config.YOUTUBE_LOADER_LANGUAGE,
        proxy_url=config.YOUTUBE_LOADER_PROXY_URL,
    )
def get_content_from_url(request, url: str) -> tuple[str, list]:
    """Load *url* with the loader chosen by ``get_loader`` and return its text.

    Returns:
        A ``(content, docs)`` tuple. ``content`` is every loaded document's
        ``page_content`` joined with single spaces. ``docs`` is the list
        returned by ``loader.load()``.
    """
    loader = get_loader(request, url)
    docs = loader.load()
    content = ' '.join(doc.page_content for doc in docs)
    return content, docs
2026-03-17 17:58:01 -05:00
CHUNK_HASH_KEY = '_chunk_hash'
2026-02-22 18:42:25 -06:00
def _content_hash(text: str) -> str:
"""SHA-256 hash of text, used as a stable chunk identifier for RRF dedup."""
return hashlib.sha256(text.encode()).hexdigest()
2024-09-10 04:37:06 +01:00
class VectorSearchRetriever(BaseRetriever):
    """LangChain retriever backed by a dense vector search on ``VECTOR_DB_CLIENT``.

    Only the async path does real work. Each returned ``Document`` carries a copy
    of the stored metadata with ``CHUNK_HASH_KEY`` set to a hash of the chunk
    text. An ``EnsembleRetriever`` keyed on that field can therefore deduplicate
    identical chunks coming from different retrievers.
    """

    collection_name: Any
    # Awaited as ``embedding_function(query, prefix)`` to produce the query vector.
    embedding_function: Any
    # Maximum number of hits requested from the vector store.
    top_k: int

    def _get_relevant_documents(self, query: str, *, run_manager: CallbackManagerForRetrieverRun) -> list[Document]:
        """Get documents relevant to a query (sync path).

        Always returns an empty list. Retrieval needs to await the embedding
        function, so only the async path (e.g. ``ainvoke``) returns results.

        Args:
            query: String to find relevant documents for.
            run_manager: The callback handler to use.
        Returns:
            An empty list.
        """
        return []

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Embed *query*, fetch up to ``top_k`` hits, and wrap them as Documents.

        Args:
            query: String to find relevant documents for.
            run_manager: The callback handler to use (unused).
        Returns:
            One ``Document`` per hit; an empty list when the store returns nothing.
        """
        embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
        result = VECTOR_DB_CLIENT.search(
            collection_name=self.collection_name,
            vectors=[embedding],
            limit=self.top_k,
        )
        # A missing or empty result set means "no hits"; don't fail on ``ids[0]``.
        if not result or not result.ids or not result.ids[0]:
            return []

        results = []
        for document, metadata in zip(result.documents[0], result.metadatas[0]):
            # Copy so the hash key is not written back into the client's result
            # object, and tolerate chunks stored without metadata.
            enriched = dict(metadata or {})
            enriched[CHUNK_HASH_KEY] = _content_hash(document)
            results.append(Document(metadata=enriched, page_content=document))
        return results
2024-09-10 04:37:06 +01:00
2026-03-17 17:58:01 -05:00
def query_doc(collection_name: str, query_embedding: list[float], k: int, user: UserModel = None):
    """Run a top-*k* vector search for *query_embedding* in *collection_name*.

    Returns whatever ``VECTOR_DB_CLIENT.search`` returns, which may be falsy.
    A truthy result has its ids and metadatas logged at info level. Any error is
    logged together with the collection name and limit, then re-raised.
    ``user`` is accepted but not used here.
    """
    try:
        log.debug(f'query_doc:doc {collection_name}')
        search_result = VECTOR_DB_CLIENT.search(
            collection_name=collection_name,
            vectors=[query_embedding],
            limit=k,
        )
        if not search_result:
            return search_result
        log.info(f'query_doc:result {search_result.ids} {search_result.metadatas}')
        return search_result
    except Exception as e:
        log.exception(f'Error querying doc {collection_name} with limit {k}: {e}')
        raise e
2024-04-25 17:03:00 -04:00
2025-02-18 21:14:58 -08:00
def get_doc(collection_name: str, user: UserModel = None):
    """Fetch every stored item of *collection_name* via ``VECTOR_DB_CLIENT.get``.

    Returns the client's result unchanged, which may be falsy. A truthy result
    has its ids and metadatas logged at info level. Any error is logged and
    re-raised. ``user`` is accepted but not used here.
    """
    try:
        log.debug(f'get_doc:doc {collection_name}')
        result = VECTOR_DB_CLIENT.get(collection_name=collection_name)
        if result:
            log.info(f'get_doc:result {result.ids} {result.metadatas}')
        return result
    except Exception as e:
        log.exception(f'Error getting doc {collection_name}: {e}')
        raise
def get_enriched_texts(collection_result: 'GetResult') -> list[str]:
    """Build BM25 index texts by appending selected metadata to each chunk's text.

    For each text in ``collection_result.documents[0]``, the output string starts
    with that text. Any of the following present in the matching
    ``metadatas[0]`` entry are then appended, space-separated:

    - the filename, plus its ``_``/``-``/``.``-split form twice for BM25 weight;
    - the title;
    - the section headings (when a list), joined with ``' > '``;
    - the source;
    - the snippet.

    A ``None`` metadata entry is treated as empty, so that text passes through
    unchanged.
    """
    enriched_texts = []
    for idx, text in enumerate(collection_result.documents[0]):
        # Some stores may hold chunks without metadata; avoid ``None.get``.
        metadata = collection_result.metadatas[0][idx] or {}
        metadata_parts = [text]

        # Add filename (repeat twice for extra weight in BM25 scoring)
        if metadata.get('name'):
            filename = metadata['name']
            filename_tokens = filename.replace('_', ' ').replace('-', ' ').replace('.', ' ')
            metadata_parts.append(f'Filename: {filename} {filename_tokens} {filename_tokens}')

        # Add title if available
        if metadata.get('title'):
            metadata_parts.append(f'Title: {metadata["title"]}')

        # Add document section headings if available (from markdown splitter)
        if metadata.get('headings') and isinstance(metadata['headings'], list):
            headings = ' > '.join(str(h) for h in metadata['headings'])
            metadata_parts.append(f'Section: {headings}')

        # Add source URL/path if available
        if metadata.get('source'):
            metadata_parts.append(f'Source: {metadata["source"]}')

        # Add snippet for web search results
        if metadata.get('snippet'):
            metadata_parts.append(f'Snippet: {metadata["snippet"]}')

        enriched_texts.append(' '.join(metadata_parts))
    return enriched_texts
async def query_doc_with_hybrid_search(
    collection_name: str,
    collection_result: GetResult,
    query: str,
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
    enable_enriched_texts: bool = False,
) -> dict:
    """Hybrid (BM25 + vector) search over one collection, followed by reranking.

    Args:
        collection_name: Vector-DB collection to search.
        collection_result: Pre-fetched ``GetResult`` for the collection. Its
            ``documents[0]`` and ``metadatas[0]`` feed the BM25 index.
        query: Search query.
        embedding_function: Async embedding callable, used by the vector
            retriever and passed to ``RerankCompressor``.
        k: Candidates requested from each retriever, and the final cap when
            ``k < k_reranker``.
        reranking_function: Passed through to ``RerankCompressor``.
        k_reranker: ``top_n`` for ``RerankCompressor``.
        r: Passed to ``RerankCompressor`` as ``r_score`` (presumably a minimum
            relevance score — see that class).
        hybrid_bm25_weight: BM25 share of the ensemble. ``<= 0`` uses vector
            search only and ``>= 1`` uses BM25 only. Any value in between fuses
            both with weights ``[w, 1 - w]``.
        enable_enriched_texts: Index filename/title/headings/source/snippet
            metadata alongside the chunk text for BM25.

    Returns:
        ``{'distances': [[...]], 'documents': [[...]], 'metadatas': [[...]]}``,
        with distances taken from each result's ``score`` metadata. Returns flat
        empty lists when the collection has no documents.

    Raises:
        Re-raises any exception after logging it.
    """
    try:
        # First check if collection_result has the required attributes
        if (
            not collection_result
            or not hasattr(collection_result, 'documents')
            or not hasattr(collection_result, 'metadatas')
        ):
            log.warning(f'query_doc_with_hybrid_search:no_docs {collection_name}')
            return {'documents': [], 'metadatas': [], 'distances': []}

        # Now safely check the documents content after confirming attributes exist
        if not collection_result.documents or not collection_result.documents[0]:
            log.warning(f'query_doc_with_hybrid_search:no_docs {collection_name}')
            return {'documents': [], 'metadatas': [], 'distances': []}

        log.debug(f'query_doc_with_hybrid_search:doc {collection_name}')

        vector_search_retriever = VectorSearchRetriever(
            collection_name=collection_name,
            embedding_function=embedding_function,
            top_k=k,
        )

        if hybrid_bm25_weight <= 0:
            # Vector-only: don't pay for a BM25 index over the whole collection.
            retrievers = [vector_search_retriever]
            weights = [1.0]
        else:
            original_texts = collection_result.documents[0]
            # Tag BM25 docs with the hash of the *original* text so they dedup
            # against vector hits even when the indexed (enriched) text differs.
            bm25_metadatas = [
                {**(meta or {}), CHUNK_HASH_KEY: _content_hash(original_texts[idx])}
                for idx, meta in enumerate(collection_result.metadatas[0])
            ]
            bm25_texts = get_enriched_texts(collection_result) if enable_enriched_texts else original_texts
            bm25_retriever = BM25Retriever.from_texts(
                texts=bm25_texts,
                metadatas=bm25_metadatas,
            )
            bm25_retriever.k = k

            if hybrid_bm25_weight >= 1:
                retrievers = [bm25_retriever]
                weights = [1.0]
            else:
                retrievers = [bm25_retriever, vector_search_retriever]
                weights = [hybrid_bm25_weight, 1.0 - hybrid_bm25_weight]

        # Use CHUNK_HASH_KEY for dedup so enriched BM25 texts don't defeat RRF
        ensemble_retriever = EnsembleRetriever(
            retrievers=retrievers,
            weights=weights,
            id_key=CHUNK_HASH_KEY,
        )

        compressor = RerankCompressor(
            embedding_function=embedding_function,
            top_n=k_reranker,
            reranking_function=reranking_function,
            r_score=r,
        )
        compression_retriever = ContextualCompressionRetriever(
            base_compressor=compressor, base_retriever=ensemble_retriever
        )

        result = await compression_retriever.ainvoke(query)

        distances = [d.metadata.get('score') for d in result]
        documents = [d.page_content for d in result]
        metadatas = [d.metadata for d in result]

        # retrieve only min(k, k_reranker) items, sort and cut by distance if k < k_reranker
        if k < k_reranker:
            sorted_items = sorted(zip(distances, documents, metadatas), key=lambda x: x[0], reverse=True)
            sorted_items = sorted_items[:k]
            if sorted_items:
                distances, documents, metadatas = map(list, zip(*sorted_items))
            else:
                distances, documents, metadatas = [], [], []

        result = {
            'distances': [distances],
            'documents': [documents],
            'metadatas': [metadatas],
        }

        log.info('query_doc_with_hybrid_search:result ' + f'{result["metadatas"]} {result["distances"]}')
        return result
    except Exception as e:
        log.exception(f'Error querying doc {collection_name} with hybrid search: {e}')
        raise
2025-02-18 21:14:58 -08:00
def merge_get_results(get_results: list[dict]) -> dict:
    """Concatenate several ``GetResult``-style dicts into one.

    Each input holds single-element nested lists under ``'documents'``,
    ``'metadatas'`` and ``'ids'``. The inner lists are concatenated in input
    order and re-wrapped in the same ``[[...]]`` shape.
    """
    merged = {'documents': [], 'metadatas': [], 'ids': []}
    for item in get_results:
        for field in merged:
            merged[field].extend(item[field][0])
    return {field: [values] for field, values in merged.items()}
2025-03-25 19:09:17 +01:00
def merge_and_sort_query_results(query_results: list[dict], k: int) -> dict:
    """Merge several query-result dicts, dedup by document text, and keep the top *k*.

    Each input is expected to hold single-element nested lists under
    ``'distances'``, ``'documents'`` and ``'metadatas'``. Inputs with a missing,
    ``None`` or empty field are skipped. Non-string documents are ignored.

    Duplicate documents are identified by the SHA-256 of their text. Only the
    entry with the highest distance is kept, since higher is treated as better.
    The survivors are sorted by distance, descending, and cut to *k*.

    Returns:
        ``{'distances': [[...]], 'documents': [[...]], 'metadatas': [[...]]}``.
    """
    best_by_hash: dict[str, tuple] = {}

    for data in query_results:
        # ``model_dump()`` emits None for absent fields, so test truthiness rather
        # than ``len()`` (which raises TypeError on None).
        if not data.get('distances') or not data.get('documents') or not data.get('metadatas'):
            continue

        for distance, document, metadata in zip(data['distances'][0], data['documents'][0], data['metadatas'][0]):
            if not isinstance(document, str):
                continue
            doc_hash = hashlib.sha256(document.encode()).hexdigest()  # Compute a hash for uniqueness
            current = best_by_hash.get(doc_hash)
            # Keep the first occurrence unless a later one has a strictly better distance.
            if current is None or distance > current[0]:
                best_by_hash[doc_hash] = (distance, document, metadata)

    # Sort the list based on distances (stable, so ties keep first-seen order)
    ranked = sorted(best_by_hash.values(), key=lambda x: x[0], reverse=True)

    # Slice to keep only the top k elements
    sorted_distances, sorted_documents, sorted_metadatas = zip(*ranked[:k]) if ranked else ([], [], [])

    return {
        'distances': [list(sorted_distances)],
        'documents': [list(sorted_documents)],
        'metadatas': [list(sorted_metadatas)],
    }
2024-03-08 19:26:39 -08:00
2025-02-18 21:14:58 -08:00
def get_all_items_from_collections(collection_names: list[str]) -> dict:
    """Fetch every item from each named collection and merge them into one result.

    Empty (falsy) names are skipped. A fetch that raises is logged and skipped,
    and a ``None`` result is ignored. The ``model_dump()`` of every successful
    fetch is combined with ``merge_get_results``.
    """
    dumped = []
    for name in filter(None, collection_names):
        try:
            fetched = get_doc(collection_name=name)
            if fetched is None:
                continue
            dumped.append(fetched.model_dump())
        except Exception as e:
            log.exception(f'Error when querying the collection: {e}')
    return merge_get_results(dumped)
async def query_collection(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
) -> dict:
    """Vector-search every query against every collection and merge the hits.

    All queries are embedded in one call. Each (query, collection) search then
    runs on a thread pool and is awaited without blocking the event loop.
    Failed searches are logged and skipped. When at least one search failed and
    none succeeded, a warning is logged. Results are combined by
    ``merge_and_sort_query_results``, which keeps the top ``k``.
    """
    results = []
    error = False

    def process_query_collection(collection_name, query_embedding):
        # Returns (dumped_result, None) on success, (None, exception) on failure,
        # and (None, None) for an empty name or a None search result.
        try:
            if collection_name:
                result = query_doc(
                    collection_name=collection_name,
                    k=k,
                    query_embedding=query_embedding,
                )
                if result is not None:
                    return result.model_dump(), None
            return None, None
        except Exception as e:
            log.exception(f'Error when querying the collection: {e}')
            return None, e

    # Generate all query embeddings (in one call)
    query_embeddings = await embedding_function(queries, prefix=RAG_EMBEDDING_QUERY_PREFIX)
    log.debug(f'query_collection: processing {len(queries)} queries across {len(collection_names)} collections')

    # query_doc is blocking. Await the pool futures instead of calling
    # ``future.result()`` so the event loop stays free while the searches run.
    loop = asyncio.get_running_loop()
    with ThreadPoolExecutor() as executor:
        futures = [
            loop.run_in_executor(executor, process_query_collection, collection_name, query_embedding)
            for query_embedding in query_embeddings
            for collection_name in collection_names
        ]
        task_results = await asyncio.gather(*futures)

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        log.warning('All collection queries failed. No results returned.')

    return merge_and_sort_query_results(results, k=k)
2024-04-27 15:38:50 -04:00
async def query_collection_with_hybrid_search(
    collection_names: list[str],
    queries: list[str],
    embedding_function,
    k: int,
    reranking_function,
    k_reranker: int,
    r: float,
    hybrid_bm25_weight: float,
    enable_enriched_texts: bool = False,
) -> dict:
    """Run ``query_doc_with_hybrid_search`` for every (collection, query) pair and merge.

    Each collection is fetched once up front. Collections whose fetch fails are
    logged and skipped. All remaining pairs run concurrently, and successful
    results are merged by ``merge_and_sort_query_results``, which keeps the
    top ``k``.

    Raises:
        Exception: when at least one pair failed and none succeeded, so the
            caller can fall back to non-hybrid search.
    """
    results = []
    error = False

    # Fetch collection data once per collection sequentially
    # Avoid fetching the same data multiple times later
    collection_results = {}
    for collection_name in collection_names:
        try:
            log.debug(f'query_collection_with_hybrid_search:VECTOR_DB_CLIENT.get:collection {collection_name}')
            # ``get`` is synchronous and is called with no filter, i.e. it loads the
            # full collection; run it in a worker thread so the event loop isn't blocked.
            collection_results[collection_name] = await asyncio.to_thread(
                VECTOR_DB_CLIENT.get, collection_name=collection_name
            )
        except Exception as e:
            log.exception(f'Failed to fetch collection {collection_name}: {e}')
            collection_results[collection_name] = None

    log.info(f'Starting hybrid search for {len(queries)} queries in {len(collection_names)} collections...')

    async def process_query(collection_name, query):
        # Returns (result, None) on success and (None, exception) on failure.
        try:
            result = await query_doc_with_hybrid_search(
                collection_name=collection_name,
                collection_result=collection_results[collection_name],
                query=query,
                embedding_function=embedding_function,
                k=k,
                reranking_function=reranking_function,
                k_reranker=k_reranker,
                r=r,
                hybrid_bm25_weight=hybrid_bm25_weight,
                enable_enriched_texts=enable_enriched_texts,
            )
            return result, None
        except Exception as e:
            log.exception(f'Error when querying the collection with hybrid_search: {e}')
            return None, e

    # Prepare tasks for all collections and queries
    # Avoid running any tasks for collections that failed to fetch data (have assigned None)
    tasks = [
        (collection_name, query)
        for collection_name in collection_names
        if collection_results[collection_name] is not None
        for query in queries
    ]

    # Run all queries in parallel using asyncio.gather
    task_results = await asyncio.gather(*[process_query(collection_name, query) for collection_name, query in tasks])

    for result, err in task_results:
        if err is not None:
            error = True
        elif result is not None:
            results.append(result)

    if error and not results:
        raise Exception('Hybrid search failed for all collections. Using Non-hybrid search as fallback.')

    return merge_and_sort_query_results(results, k=k)
2024-04-14 17:55:00 -04:00
2025-03-27 01:40:28 -07:00
2025-11-22 21:33:14 -05:00
def generate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = 'https://api.openai.com/v1',
    key: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Synchronously embed *texts* via an OpenAI-compatible ``POST {url}/embeddings``.

    If ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` and *prefix* are both strings, the
    prefix is sent under that extra request field. User-info headers are added
    when ``ENABLE_FORWARD_USER_INFO_HEADERS`` is set and *user* is given.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list, or
        ``None`` on any error (the error is logged).
    """
    try:
        log.debug(f'generate_openai_batch_embeddings:model {model} batch size: {len(texts)}')
        payload = {'input': texts, 'model': model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            payload[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        request_headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {key}',
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            request_headers = include_user_info_headers(request_headers, user)

        response = requests.post(f'{url}/embeddings', headers=request_headers, json=payload)
        response.raise_for_status()
        body = response.json()
        if 'data' not in body:
            raise ValueError("Unexpected OpenAI embeddings response: missing 'data' key")
        return [entry['embedding'] for entry in body['data']]
    except Exception as e:
        log.exception(f'Error generating openai batch embeddings: {e}')
        return None
async def agenerate_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str = 'https://api.openai.com/v1',
    key: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Asynchronously embed *texts* via an OpenAI-compatible ``POST {url}/embeddings``.

    If ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` and *prefix* are both strings, the
    prefix is sent under that extra request field. User-info headers are added
    when ``ENABLE_FORWARD_USER_INFO_HEADERS`` is set and *user* is given. The
    session uses ``AIOHTTP_CLIENT_TIMEOUT`` and ``AIOHTTP_CLIENT_SESSION_SSL``.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list, or
        ``None`` if the request fails or the response lacks ``data`` (the error
        is logged).
    """
    try:
        log.debug(f'agenerate_openai_batch_embeddings:model {model} batch size: {len(texts)}')
        form_data = {'input': texts, 'model': model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {key}',
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        ) as session:
            async with session.post(
                f'{url}/embeddings',
                headers=headers,
                json=form_data,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                r.raise_for_status()
                data = await r.json()
                if 'data' in data:
                    return [item['embedding'] for item in data['data']]
                raise ValueError("Unexpected OpenAI embeddings response: missing 'data' key")
    except Exception as e:
        log.exception(f'Error generating openai batch embeddings: {e}')
        return None
2025-11-22 21:33:14 -05:00
def generate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = '',
    version: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Synchronously embed *texts* with an Azure OpenAI deployment.

    Posts to ``{url}/openai/deployments/{model}/embeddings?api-version={version}``,
    where *model* is the deployment name. HTTP 429 responses are retried, up
    to five attempts in total. Between attempts it sleeps for the
    ``Retry-After`` seconds, or 1 second when the header is absent or not
    numeric (e.g. an HTTP-date).

    Returns:
        The ``embedding`` of each item in the response's ``data`` list. Returns
        ``None`` if every attempt was rate limited or any error occurred; errors
        and retry exhaustion are logged.
    """
    max_attempts = 5
    try:
        log.debug(f'generate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}')
        json_data = {'input': texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            json_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        endpoint = f'{url}/openai/deployments/{model}/embeddings?api-version={version}'
        # Headers are identical for every attempt, so build them once.
        headers = {
            'Content-Type': 'application/json',
            'api-key': key,
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        for _ in range(max_attempts):
            r = requests.post(
                endpoint,
                headers=headers,
                json=json_data,
            )
            if r.status_code == 429:
                try:
                    retry = float(r.headers.get('Retry-After', '1'))
                except ValueError:
                    # Retry-After may be an HTTP-date rather than seconds; back off 1s
                    # instead of aborting the whole request.
                    retry = 1.0
                time.sleep(max(retry, 0.0))
                continue
            r.raise_for_status()
            data = r.json()
            if 'data' in data:
                return [elem['embedding'] for elem in data['data']]
            raise ValueError("Unexpected Azure OpenAI embeddings response: missing 'data' key")

        log.warning(
            f'generate_azure_openai_batch_embeddings: still rate limited (HTTP 429) after {max_attempts} attempts'
        )
        return None
    except Exception as e:
        log.exception(f'Error generating azure openai batch embeddings: {e}')
        return None
async def agenerate_azure_openai_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = '',
    version: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Asynchronously embed *texts* with an Azure OpenAI deployment.

    Posts to ``{url}/openai/deployments/{model}/embeddings?api-version={version}``,
    where *model* is the deployment name, authenticating with the ``api-key``
    header. User-info headers are added when ``ENABLE_FORWARD_USER_INFO_HEADERS``
    is set and *user* is given.

    Returns:
        The ``embedding`` of each item in the response's ``data`` list, or
        ``None`` if the request fails or the response lacks ``data`` (the error
        is logged).
    """
    try:
        log.debug(f'agenerate_azure_openai_batch_embeddings:deployment {model} batch size: {len(texts)}')
        form_data = {'input': texts}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        full_url = f'{url}/openai/deployments/{model}/embeddings?api-version={version}'

        headers = {
            'Content-Type': 'application/json',
            'api-key': key,
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        ) as session:
            async with session.post(
                full_url,
                headers=headers,
                json=form_data,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                r.raise_for_status()
                data = await r.json()
                if 'data' in data:
                    return [item['embedding'] for item in data['data']]
                raise ValueError("Unexpected Azure OpenAI embeddings response: missing 'data' key")
    except Exception as e:
        log.exception(f'Error generating azure openai batch embeddings: {e}')
        return None
2025-11-22 21:33:14 -05:00
def generate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Synchronously embed *texts* with Ollama's ``POST {url}/api/embed`` endpoint.

    If ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` and *prefix* are both strings, the
    prefix is sent under that extra request field. User-info headers are added
    when ``ENABLE_FORWARD_USER_INFO_HEADERS`` is set and *user* is given.

    Returns:
        The response's ``embeddings`` value, or ``None`` on any error (the error
        is logged).
    """
    try:
        log.debug(f'generate_ollama_batch_embeddings:model {model} batch size: {len(texts)}')
        payload = {'input': texts, 'model': model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            payload[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        request_headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {key}',
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            request_headers = include_user_info_headers(request_headers, user)

        response = requests.post(f'{url}/api/embed', headers=request_headers, json=payload)
        response.raise_for_status()
        body = response.json()
        if 'embeddings' not in body:
            raise ValueError("Unexpected Ollama embeddings response: missing 'embeddings' key")
        return body['embeddings']
    except Exception as e:
        log.exception(f'Error generating ollama batch embeddings: {e}')
        return None
async def agenerate_ollama_batch_embeddings(
    model: str,
    texts: list[str],
    url: str,
    key: str = '',
    prefix: str = None,
    user: UserModel = None,
) -> Optional[list[list[float]]]:
    """Asynchronously embed *texts* with Ollama's ``POST {url}/api/embed`` endpoint.

    If ``RAG_EMBEDDING_PREFIX_FIELD_NAME`` and *prefix* are both strings, the
    prefix is sent under that extra request field. User-info headers are added
    when ``ENABLE_FORWARD_USER_INFO_HEADERS`` is set and *user* is given. The
    session uses ``AIOHTTP_CLIENT_TIMEOUT`` and ``AIOHTTP_CLIENT_SESSION_SSL``.

    Returns:
        The response's ``embeddings`` value, or ``None`` if the request fails or
        the key is missing (the error is logged).
    """
    try:
        log.debug(f'agenerate_ollama_batch_embeddings:model {model} batch size: {len(texts)}')
        form_data = {'input': texts, 'model': model}
        if isinstance(RAG_EMBEDDING_PREFIX_FIELD_NAME, str) and isinstance(prefix, str):
            form_data[RAG_EMBEDDING_PREFIX_FIELD_NAME] = prefix

        headers = {
            'Content-Type': 'application/json',
            'Authorization': f'Bearer {key}',
        }
        if ENABLE_FORWARD_USER_INFO_HEADERS and user:
            headers = include_user_info_headers(headers, user)

        async with aiohttp.ClientSession(
            trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
        ) as session:
            async with session.post(
                f'{url}/api/embed',
                headers=headers,
                json=form_data,
                ssl=AIOHTTP_CLIENT_SESSION_SSL,
            ) as r:
                r.raise_for_status()
                data = await r.json()
                if 'embeddings' in data:
                    return data['embeddings']
                raise ValueError("Unexpected Ollama embeddings response: missing 'embeddings' key")
    except Exception as e:
        log.exception(f'Error generating ollama batch embeddings: {e}')
        return None
2024-04-27 15:38:50 -04:00
def get_embedding_function(
    embedding_engine,
    embedding_model,
    embedding_function,
    url,
    key,
    embedding_batch_size,
    azure_api_version=None,
    enable_async=True,
    concurrent_requests=0,
) -> Awaitable:
    """Return an async ``(query, prefix=None, user=None)`` embedding callable for *embedding_engine*.

    - ``''``: local model. ``embedding_function.encode`` runs in a worker thread
      with ``batch_size=int(embedding_batch_size)``. A truthy prefix is passed
      as ``prompt``, and the result is converted with ``.tolist()``.
    - ``'ollama'``, ``'openai'``, ``'azure_openai'``: delegates to
      ``generate_embeddings``. A list query is split into batches of
      ``int(embedding_batch_size)`` items, and the per-batch results are
      flattened. Batches run concurrently when ``enable_async`` is true
      (bounded by ``concurrent_requests`` when it is non-zero), otherwise
      sequentially. Batches that do not return a list are dropped and logged.

    Raises:
        ValueError: for any other engine.
    """
    if embedding_engine == '':
        # Sentence transformers: CPU-bound sync operation, so run it off the event loop.
        def _encode(query, prefix=None):
            return embedding_function.encode(
                query,
                batch_size=int(embedding_batch_size),
                **({'prompt': prefix} if prefix else {}),
            ).tolist()

        async def async_embedding_function(query, prefix=None, user=None):
            return await asyncio.to_thread(_encode, query, prefix)

        return async_embedding_function

    elif embedding_engine in ['ollama', 'openai', 'azure_openai']:

        def _generate(query, prefix=None, user=None):
            # Returns the (unawaited) coroutine from generate_embeddings.
            return generate_embeddings(
                engine=embedding_engine,
                model=embedding_model,
                text=query,
                prefix=prefix,
                url=url,
                key=key,
                user=user,
                azure_api_version=azure_api_version,
            )

        async def async_embedding_function(query, prefix=None, user=None):
            if not isinstance(query, list):
                return await _generate(query, prefix, user)

            # Create batches (coerced like the local-model branch does)
            batch_size = int(embedding_batch_size)
            batches = [query[i : i + batch_size] for i in range(0, len(query), batch_size)]

            if enable_async:
                log.debug(f'generate_multiple_async: Processing {len(batches)} batches in parallel')
                # Use semaphore to limit concurrent embedding API requests
                # 0 = unlimited (no semaphore)
                if concurrent_requests:
                    semaphore = asyncio.Semaphore(concurrent_requests)

                    async def generate_batch_with_semaphore(batch):
                        async with semaphore:
                            return await _generate(batch, prefix=prefix, user=user)

                    tasks = [generate_batch_with_semaphore(batch) for batch in batches]
                else:
                    tasks = [_generate(batch, prefix=prefix, user=user) for batch in batches]
                batch_results = await asyncio.gather(*tasks)
            else:
                log.debug(f'generate_multiple_async: Processing {len(batches)} batches sequentially')
                batch_results = []
                for batch in batches:
                    batch_results.append(await _generate(batch, prefix=prefix, user=user))

            # Flatten results
            embeddings = []
            failed_batches = 0
            for batch_embeddings in batch_results:
                if isinstance(batch_embeddings, list):
                    embeddings.extend(batch_embeddings)
                else:
                    failed_batches += 1
            if failed_batches:
                # Dropped batches leave fewer embeddings than inputs; make that visible.
                log.error(
                    f'generate_multiple_async: {failed_batches} of {len(batches)} batches failed; '
                    f'returning {len(embeddings)} embeddings for {len(query)} inputs'
                )
            log.debug(
                f'generate_multiple_async: Generated {len(embeddings)} embeddings from {len(batches)} parallel batches'
            )
            return embeddings

        return async_embedding_function
    else:
        raise ValueError(f'Unknown embedding engine: {embedding_engine}')
2024-04-22 15:49:58 -05:00
2025-11-22 21:33:14 -05:00
async def generate_embeddings(
    engine: str,
    model: str,
    text: Union[str, list[str]],
    prefix: Union[str, None] = None,
    **kwargs,
):
    """Embed one text or a list of texts through a remote embedding engine.

    Args:
        engine: 'ollama', 'openai' or 'azure_openai'. Any other value returns None.
        model: Model name passed to the engine helper.
        text: A single string or a list of strings.
        prefix: Optional instruction prefix. When RAG_EMBEDDING_PREFIX_FIELD_NAME
            is None it is prepended to every text. It is also always passed
            through to the engine helper.
        **kwargs: ``url`` and ``key`` (default ''), ``user`` (default None) and,
            for Azure, ``azure_api_version`` (default '').

    Returns:
        For str input, the first element of the helper's result. Otherwise the
        helper's result unchanged.
    """
    url = kwargs.get('url', '')
    key = kwargs.get('key', '')
    user = kwargs.get('user')

    needs_inline_prefix = prefix is not None and RAG_EMBEDDING_PREFIX_FIELD_NAME is None
    if needs_inline_prefix:
        # No dedicated request field for the prefix: bake it into the input text(s).
        if isinstance(text, list):
            text = [f'{prefix}{text_element}' for text_element in text]
        else:
            text = f'{prefix}{text}'

    # Every engine helper takes a list; a single text is wrapped and unwrapped again below.
    batch = text if isinstance(text, list) else [text]
    unwrap_single = isinstance(text, str)

    if engine == 'ollama':
        result = await agenerate_ollama_batch_embeddings(
            model=model,
            texts=batch,
            url=url,
            key=key,
            prefix=prefix,
            user=user,
        )
    elif engine == 'openai':
        result = await agenerate_openai_batch_embeddings(model, batch, url, key, prefix, user)
    elif engine == 'azure_openai':
        api_version = kwargs.get('azure_api_version', '')
        result = await agenerate_azure_openai_batch_embeddings(
            model,
            batch,
            url,
            key,
            api_version,
            prefix,
            user,
        )
    else:
        return None

    if unwrap_single:
        return result[0]
    return result
def get_reranking_function(reranking_engine, reranking_model, reranking_function):
    """Wrap a reranker object as ``fn(query, documents, user=None) -> scores``.

    Each document is paired with the query as ``(query, doc.page_content)`` and
    the pairs are passed to ``reranking_function.predict``. Only the 'external'
    engine receives the ``user`` keyword. Every other engine ignores it.
    ``reranking_model`` is not used. Returns None when there is no reranker.
    """
    if reranking_function is None:
        return None

    forward_user = reranking_engine == 'external'

    def rerank(query, documents, user=None):
        pairs = [(query, doc.page_content) for doc in documents]
        if forward_user:
            return reranking_function.predict(pairs, user=user)
        return reranking_function.predict(pairs)

    return rerank
def _item_access_allowed(user, owner_id, resource_type=None, resource_id=None):
    """Return whether *user* may read a resource owned by *owner_id*.

    Admins and the owner are always allowed. When *resource_type* is given, a
    'read' grant from AccessGrants also suffices. A missing user (None) is denied.
    """
    if user is None:
        return False
    if user.role == 'admin' or owner_id == user.id:
        return True
    if resource_type is None:
        return False
    return AccessGrants.has_access(
        user_id=user.id,
        resource_type=resource_type,
        resource_id=resource_id,
        permission='read',
    )


async def get_sources_from_items(
    request,
    items,
    queries,
    embedding_function,
    k,
    reranking_function,
    k_reranker,
    r,
    hybrid_bm25_weight,
    hybrid_search,
    full_context=False,
    user: Optional[UserModel] = None,
):
    """Resolve attached items into retrieval sources.

    Each item is handled in one of two ways:
    - Its content is used directly. This covers raw text, notes, chats, URLs,
      pre-fetched docs, and files or knowledge bases in full-context/bypass mode.
    - It is mapped to vector collections that are then searched with *queries*.
      With *full_context*, every item is fetched. Otherwise hybrid search is
      used when *hybrid_search* is set, falling back to plain vector search if
      it fails. A collection already searched for an earlier item is skipped.

    Notes, chats and knowledge bases are only used if *user* may read them.

    Returns:
        A list of dicts with 'source' (the originating item; note that a 'data'
        key is deleted from the caller's item dict), 'document', 'metadata' and,
        when present, 'distances'.
    """
    log.debug(f'items: {items} {queries} {embedding_function} {reranking_function} {full_context}')

    extracted_collections = []
    query_results = []

    for item in items:
        query_result = None
        collection_names = []
        item_type = item.get('type')

        if item_type == 'text':
            # Raw Text
            # Used during temporary chat file uploads or web page & youtube attachments
            if item.get('context') == 'full' and item.get('file'):
                # if item has file data, use it
                query_result = {
                    'documents': [[item.get('file', {}).get('data', {}).get('content')]],
                    'metadatas': [[item.get('file', {}).get('meta', {})]],
                }

            if query_result is None:
                if item.get('collection_name'):
                    # If item has a collection name, search it
                    collection_names.append(item.get('collection_name'))
                elif item.get('file'):
                    # If item has file data, use it
                    query_result = {
                        'documents': [[item.get('file', {}).get('data', {}).get('content')]],
                        'metadatas': [[item.get('file', {}).get('meta', {})]],
                    }
                else:
                    # Fallback to item content
                    query_result = {
                        'documents': [[item.get('content')]],
                        'metadatas': [[{'file_id': item.get('id'), 'name': item.get('name')}]],
                    }

        elif item_type == 'note':
            # Note Attached
            note = Notes.get_note_by_id(item.get('id'))
            if note and _item_access_allowed(user, note.user_id, 'note', note.id):
                query_result = {
                    'documents': [[note.data.get('content', {}).get('md', '')]],
                    'metadatas': [[{'file_id': note.id, 'name': note.title}]],
                }

        elif item_type == 'chat':
            # Chat Attached (owner or admin only; no access grants for chats)
            chat = Chats.get_chat_by_id(item.get('id'))
            if chat and _item_access_allowed(user, chat.user_id):
                messages_map = chat.chat.get('history', {}).get('messages', {})
                message_id = chat.chat.get('history', {}).get('currentId')
                if messages_map and message_id:
                    # Reconstruct the message list in order
                    message_list = get_message_list(messages_map, message_id)
                    message_history = '\n'.join(
                        [f'#### {m.get("role", "user").capitalize()}\n{m.get("content")}\n' for m in message_list]
                    )
                    query_result = {
                        'documents': [[message_history]],
                        'metadatas': [[{'file_id': chat.id, 'name': chat.title}]],
                    }

        elif item_type == 'url':
            content, docs = get_content_from_url(request, item.get('url'))
            if docs:
                query_result = {
                    'documents': [[content]],
                    'metadatas': [[{'url': item.get('url'), 'name': item.get('url')}]],
                }

        elif item_type == 'file':
            if item.get('context') == 'full' or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
                file_content = item.get('file', {}).get('data', {}).get('content', '')
                if file_content:
                    # Manual Full Mode Toggle: content was shipped with the item
                    query_result = {
                        'documents': [[file_content]],
                        'metadatas': [
                            [
                                {
                                    'file_id': item.get('id'),
                                    'name': item.get('name'),
                                    **item.get('file').get('data', {}).get('metadata', {}),
                                }
                            ]
                        ],
                    }
                elif item.get('id'):
                    file_object = Files.get_file_by_id(item.get('id'))
                    if file_object:
                        query_result = {
                            'documents': [[file_object.data.get('content', '')]],
                            'metadatas': [
                                [
                                    {
                                        'file_id': item.get('id'),
                                        'name': file_object.filename,
                                        'source': file_object.filename,
                                    }
                                ]
                            ],
                        }
            elif item.get('legacy'):
                # Fallback to collection names
                collection_names.append(f'{item["id"]}')
            else:
                collection_names.append(f'file-{item["id"]}')

        elif item_type == 'collection':
            knowledge_base = Knowledges.get_knowledge_by_id(item.get('id'))
            if knowledge_base and _item_access_allowed(user, knowledge_base.user_id, 'knowledge', knowledge_base.id):
                if item.get('context') == 'full' or request.app.state.config.BYPASS_EMBEDDING_AND_RETRIEVAL:
                    # Manual Full Mode Toggle for Collection: one document per file
                    files = Knowledges.get_files_by_id(knowledge_base.id)
                    documents = []
                    metadatas = []
                    for file in files:
                        documents.append(file.data.get('content', ''))
                        metadatas.append(
                            {
                                'file_id': file.id,
                                'name': file.filename,
                                'source': file.filename,
                            }
                        )
                    query_result = {
                        'documents': [documents],
                        'metadatas': [metadatas],
                    }
                elif item.get('legacy'):
                    # Fallback to collection names
                    collection_names = item.get('collection_names', [])
                else:
                    collection_names.append(item['id'])

        elif item.get('docs'):
            # Pre-fetched docs (BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL)
            query_result = {
                'documents': [[doc.get('content') for doc in item.get('docs')]],
                'metadatas': [[doc.get('metadata') for doc in item.get('docs')]],
            }
        elif item.get('collection_name'):
            # Direct Collection Name
            collection_names.append(item['collection_name'])
        elif item.get('collection_names'):
            # Collection Names List
            collection_names.extend(item['collection_names'])

        # No direct content: vector search (or fully fetch) the collections.
        if query_result is None and collection_names:
            collection_names = set(collection_names).difference(extracted_collections)
            if not collection_names:
                log.debug(f'skipping {item} as it has already been extracted')
                continue

            try:
                if full_context:
                    query_result = get_all_items_from_collections(collection_names)
                else:
                    if hybrid_search:
                        try:
                            query_result = await query_collection_with_hybrid_search(
                                collection_names=collection_names,
                                queries=queries,
                                embedding_function=embedding_function,
                                k=k,
                                reranking_function=reranking_function,
                                k_reranker=k_reranker,
                                r=r,
                                hybrid_bm25_weight=hybrid_bm25_weight,
                                enable_enriched_texts=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS,
                            )
                        except Exception:
                            log.debug(
                                'Error when using hybrid search, using non hybrid search as fallback.',
                                exc_info=True,
                            )

                    # Plain vector search: the default path, and the fallback when hybrid
                    # search failed. (The previous `not hybrid_search and ...` condition
                    # made this fallback unreachable.)
                    if query_result is None:
                        query_result = await query_collection(
                            collection_names=collection_names,
                            queries=queries,
                            embedding_function=embedding_function,
                            k=k,
                        )
            except Exception as e:
                log.exception(e)

            extracted_collections.extend(collection_names)

        if query_result:
            # Drop the raw payload from the item before it is attached as the source.
            if 'data' in item:
                del item['data']
            query_results.append({**query_result, 'file': item})

    sources = []
    for query_result in query_results:
        try:
            if 'documents' in query_result and 'metadatas' in query_result:
                source = {
                    'source': query_result['file'],
                    'document': query_result['documents'][0],
                    'metadata': query_result['metadatas'][0],
                }
                if 'distances' in query_result and query_result['distances']:
                    source['distances'] = query_result['distances'][0]
                sources.append(source)
        except Exception as e:
            log.exception(e)

    return sources
2024-04-04 12:07:42 -06:00
2024-04-25 07:49:59 -05:00
def get_model_path(model: str, update_model: bool = False):
    """Resolve *model* to a local snapshot path via huggingface_hub.

    - An existing filesystem path is returned unchanged.
    - A path-like name (contains a backslash or more than one '/') is also
      returned unchanged when downloads are not allowed.
    - A bare name without '/' is expanded to 'sentence-transformers/<name>'.
    - Otherwise ``snapshot_download`` is called, with ``cache_dir`` taken from
      SENTENCE_TRANSFORMERS_HOME and ``local_files_only`` set unless
      *update_model* is True. OFFLINE_MODE always forces ``local_files_only``.

    On failure the error is logged. In OFFLINE_MODE it is re-raised; otherwise
    the (possibly expanded) model id is returned.
    """
    # OFFLINE_MODE overrides any update request; otherwise only fetch when updating.
    local_files_only = True if OFFLINE_MODE else not update_model

    snapshot_kwargs = {
        'cache_dir': os.getenv('SENTENCE_TRANSFORMERS_HOME'),
        'local_files_only': local_files_only,
    }

    log.debug(f'model: {model}')
    log.debug(f'snapshot_kwargs: {snapshot_kwargs}')

    # Inspiration from upstream sentence_transformers
    looks_like_path = '\\' in model or model.count('/') > 1
    if os.path.exists(model) or (looks_like_path and local_files_only):
        return model

    if '/' not in model:
        # Short name -> canonical sentence-transformers repo id
        model = 'sentence-transformers/' + model

    snapshot_kwargs['repo_id'] = model

    try:
        model_repo_path = snapshot_download(**snapshot_kwargs)
        log.debug(f'model_repo_path: {model_repo_path}')
        return model_repo_path
    except Exception as e:
        log.exception(f'Cannot determine model snapshot path: {e}')
        if OFFLINE_MODE:
            raise
        return model
2024-04-25 07:49:59 -05:00
2024-04-22 18:36:46 -05:00
import operator
from typing import Optional, Sequence
from langchain_core.callbacks import Callbacks
2024-08-28 00:10:27 +02:00
from langchain_core.documents import BaseDocumentCompressor, Document
2024-04-22 18:36:46 -05:00
class RerankCompressor(BaseDocumentCompressor):
    """Score documents against a query and keep the best ``top_n``.

    Scores come from ``reranking_function(query, documents)`` when one is set.
    Otherwise they are the cosine similarity between the query embedding and the
    document embeddings produced by ``embedding_function``. When ``r_score`` is
    truthy, documents scoring below it are dropped. Each returned document is a
    new ``Document`` whose metadata carries its ``score``.
    """

    embedding_function: Any  # awaited as embedding_function(text_or_texts, prefix)
    top_n: int  # maximum number of documents returned
    reranking_function: Any  # optional callable (query, documents) -> scores
    r_score: float  # minimum score to keep a document; falsy disables the filter

    class Config:
        extra = 'forbid'
        arbitrary_types_allowed = True

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Synchronous entry point: not implemented, always returns an empty list.

        Only :meth:`acompress_documents` performs reranking.

        Args:
            documents: The retrieved documents.
            query: The query context.
            callbacks: Optional callbacks to run during compression.

        Returns:
            An empty list.
        """
        return []

    async def acompress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank *documents* for *query* and return at most ``top_n`` of them.

        Documents are returned highest score first. The input documents are not
        modified. If no scores are produced (None), the original documents are
        returned unchanged.
        """
        scores = None
        if self.reranking_function is not None:
            # Run the reranking callable in a worker thread, off the event loop.
            scores = await asyncio.to_thread(self.reranking_function, query, documents)
        else:
            from sentence_transformers import util

            query_embedding = await self.embedding_function(query, RAG_EMBEDDING_QUERY_PREFIX)
            document_embedding = await self.embedding_function(
                [doc.page_content for doc in documents], RAG_EMBEDDING_CONTENT_PREFIX
            )
            scores = util.cos_sim(query_embedding, document_embedding)[0]

        if scores is None:
            log.warning('No valid scores found, check your reranking function. Returning original documents.')
            return documents

        docs_with_scores = list(
            zip(
                documents,
                scores.tolist() if not isinstance(scores, list) else scores,
            )
        )
        if self.r_score:
            docs_with_scores = [(d, s) for d, s in docs_with_scores if s >= self.r_score]

        result = sorted(docs_with_scores, key=operator.itemgetter(1), reverse=True)

        # Build fresh metadata dicts: writing 'score' into doc.metadata would mutate
        # documents still owned by the upstream retriever.
        return [
            Document(
                page_content=doc.page_content,
                metadata={**doc.metadata, 'score': doc_score},
            )
            for doc, doc_score in result[: self.top_n]
        ]