Files
open-webui/backend/open_webui/retrieval/vector/dbs/chroma.py
T

190 lines
6.8 KiB
Python
Raw Normal View History

2024-09-10 02:27:50 +01:00
import chromadb
import logging
2024-09-10 02:27:50 +01:00
from chromadb import Settings
2024-09-10 04:37:06 +01:00
from chromadb.utils.batch_utils import create_batches
2024-09-10 02:27:50 +01:00
2024-09-10 04:37:06 +01:00
from typing import Optional
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
2025-09-28 20:17:27 -05:00
from open_webui.retrieval.vector.utils import process_metadata
2025-07-31 17:45:06 +04:00
2024-09-10 02:27:50 +01:00
from open_webui.config import (
CHROMA_DATA_PATH,
CHROMA_HTTP_HOST,
CHROMA_HTTP_PORT,
CHROMA_HTTP_HEADERS,
CHROMA_HTTP_SSL,
CHROMA_TENANT,
CHROMA_DATABASE,
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
2024-09-10 02:27:50 +01:00
)
# Module logger; ChromaClient.delete uses it to report errors it deliberately ignores.
log = logging.getLogger(__name__)
2024-09-10 02:27:50 +01:00
class ChromaClient(VectorDBBase):
    """VectorDBBase implementation backed by ChromaDB.

    Uses a remote ``chromadb.HttpClient`` when ``CHROMA_HTTP_HOST`` is non-empty,
    otherwise a local ``chromadb.PersistentClient`` rooted at ``CHROMA_DATA_PATH``.
    Collections created by this class use the cosine HNSW space.
    """

    def __init__(self):
        settings_dict = {
            # Needed so that reset() below (client.reset()) is permitted.
            'allow_reset': True,
            'anonymized_telemetry': False,
        }
        if CHROMA_CLIENT_AUTH_PROVIDER is not None:
            settings_dict['chroma_client_auth_provider'] = CHROMA_CLIENT_AUTH_PROVIDER
        if CHROMA_CLIENT_AUTH_CREDENTIALS is not None:
            settings_dict['chroma_client_auth_credentials'] = CHROMA_CLIENT_AUTH_CREDENTIALS

        if CHROMA_HTTP_HOST != '':
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=Settings(**settings_dict),
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=Settings(**settings_dict),
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    @staticmethod
    def _unpack_items(items: list[VectorItem]) -> tuple[list, list, list, list]:
        """Split VectorItems into parallel (ids, documents, embeddings, metadatas) lists."""
        ids = [item['id'] for item in items]
        documents = [item['text'] for item in items]
        embeddings = [item['vector'] for item in items]
        metadatas = [process_metadata(item['metadata']) for item in items]
        return ids, documents, embeddings, metadatas

    @staticmethod
    def _to_get_result(result) -> GetResult:
        """Build a GetResult from a chromadb ``collection.get()`` result.

        Each field is wrapped in an outer list so it has the same nested
        (list-per-query) shape as the fields returned by search().
        """
        return GetResult(
            **{
                'ids': [result['ids']],
                'documents': [result['documents']],
                'metadatas': [result['metadatas']],
            }
        )

    def has_collection(self, collection_name: str) -> bool:
        """Return True if a collection named ``collection_name`` exists.

        Depending on the chromadb version, ``list_collections()`` yields either
        plain names or ``Collection`` objects; both are normalised to names
        before the membership test.
        """
        existing_names = {getattr(collection, 'name', collection) for collection in self.client.list_collections()}
        return collection_name in existing_names

    def delete_collection(self, collection_name: str):
        """Delete the collection named ``collection_name`` and return the client's result."""
        return self.client.delete_collection(name=collection_name)

    def search(
        self,
        collection_name: str,
        vectors: list[list[float | int]],
        filter: Optional[dict] = None,
        limit: int = 10,
    ) -> Optional[SearchResult]:
        """Return up to ``limit`` nearest neighbours for each vector in ``vectors``.

        ``filter`` is passed to chromadb as the ``where`` clause. The returned
        ``distances`` are rescaled scores (see below). Returns None if the
        collection cannot be fetched or the query fails; the error is logged
        at debug level.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if collection:
                result = collection.query(
                    query_embeddings=vectors,
                    n_results=limit,
                    where=filter,
                )

                # chromadb has cosine distance, 2 (worst) -> 0 (best). Re-ordering to 0 -> 1
                # https://docs.trychroma.com/docs/collections/configure cosine equation
                # Convert every row (one per query vector) so that `distances` stays
                # aligned with `ids`, `documents` and `metadatas`, which keep all rows.
                distances = [[(2 - dist) / 2 for dist in row] for row in result['distances']]

                return SearchResult(
                    **{
                        'ids': result['ids'],
                        'distances': distances,
                        'documents': result['documents'],
                        'metadatas': result['metadatas'],
                    }
                )
            return None
        except Exception as e:
            # Best-effort: callers receive None, but the cause is no longer lost.
            log.debug(f'Search failed for collection {collection_name}: {e}')
            return None

    def query(self, collection_name: str, filter: dict, limit: Optional[int] = None) -> Optional[GetResult]:
        """Return items matching the ``where`` filter, at most ``limit`` if given.

        Returns None if the collection cannot be fetched or the lookup fails;
        the error is logged at debug level.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if collection:
                result = collection.get(
                    where=filter,
                    limit=limit,
                )
                return self._to_get_result(result)
            return None
        except Exception as e:
            log.debug(f'Query failed for collection {collection_name}: {e}')
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Return every item in the collection.

        Unlike search()/query(), no exception is caught here: any error raised
        by ``get_collection`` (e.g. for a missing collection) propagates.
        """
        collection = self.client.get_collection(name=collection_name)
        if collection:
            result = collection.get()
            return self._to_get_result(result)
        return None

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to the collection, creating it (cosine space) if absent."""
        collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
        ids, documents, embeddings, metadatas = self._unpack_items(items)

        # Chunk the payload according to the client's batch limits.
        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Insert or update ``items``, creating the collection (cosine space) if absent."""
        collection = self.client.get_or_create_collection(name=collection_name, metadata={'hnsw:space': 'cosine'})
        ids, documents, embeddings, metadatas = self._unpack_items(items)

        # Batch exactly like insert() so large upserts do not exceed the
        # client's maximum batch size in a single request.
        for batch in create_batches(
            api=self.client,
            documents=documents,
            embeddings=embeddings,
            ids=ids,
            metadatas=metadatas,
        ):
            collection.upsert(*batch)

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete items by ``ids`` or, when no ids are given, by the ``where`` ``filter``.

        If neither is provided nothing is deleted. Any error (including a missing
        collection) is logged at debug level and ignored.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if collection:
                if ids:
                    collection.delete(ids=ids)
                elif filter:
                    collection.delete(where=filter)
        except Exception as e:
            # Deletion is best-effort: a missing collection means there is nothing
            # to delete, and other failures are intentionally not propagated.
            log.debug(f'Could not delete from collection {collection_name} (it may not exist): {e}. Ignoring.')

    def reset(self):
        """Reset the database, deleting all collections and their entries."""
        return self.client.reset()