mirror of
https://github.com/open-webui/open-webui.git
synced 2026-06-13 19:20:05 +00:00
refac: modernize type annotations (PEP 604 / PEP 585)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@@ -20,29 +22,36 @@ from open_webui.env import (
|
||||
DATABASE_URL,
|
||||
ENABLE_DB_MIGRATIONS,
|
||||
ENV,
|
||||
REDIS_URL,
|
||||
REDIS_KEY_PREFIX,
|
||||
REDIS_SENTINEL_HOSTS,
|
||||
REDIS_SENTINEL_PORT,
|
||||
FRONTEND_BUILD_DIR,
|
||||
OFFLINE_MODE,
|
||||
OPEN_WEBUI_DIR,
|
||||
REDIS_KEY_PREFIX,
|
||||
REDIS_SENTINEL_HOSTS,
|
||||
REDIS_SENTINEL_PORT,
|
||||
REDIS_URL,
|
||||
WEBUI_AUTH,
|
||||
WEBUI_FAVICON_URL,
|
||||
WEBUI_NAME,
|
||||
log,
|
||||
)
|
||||
from open_webui.internal.config import (
|
||||
STATE as _state,
|
||||
)
|
||||
from open_webui.internal.config import (
|
||||
AppConfig,
|
||||
ConfigVar,
|
||||
)
|
||||
|
||||
# ── Persistent configuration layer ──────────────────────────────────────────
|
||||
|
||||
from open_webui.internal.config import ( # noqa: F401
|
||||
ConfigTable as Config,
|
||||
ConfigVar,
|
||||
AppConfig,
|
||||
STATE as _state,
|
||||
initialize as _initialize_config,
|
||||
)
|
||||
from open_webui.internal.config import (
|
||||
_all_configs as PERSISTENT_CONFIG_REGISTRY,
|
||||
)
|
||||
from open_webui.internal.config import (
|
||||
initialize as _initialize_config,
|
||||
)
|
||||
|
||||
|
||||
def get_config():
|
||||
@@ -1288,9 +1297,7 @@ RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
|
||||
not OFFLINE_MODE and os.getenv('RAG_EMBEDDING_MODEL_AUTO_UPDATE', 'True').lower() == 'true'
|
||||
)
|
||||
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
|
||||
os.getenv('RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true'
|
||||
)
|
||||
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = os.getenv('RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true'
|
||||
|
||||
RAG_EMBEDDING_BATCH_SIZE = ConfigVar(
|
||||
'RAG_EMBEDDING_BATCH_SIZE',
|
||||
@@ -1335,9 +1342,7 @@ RAG_RERANKING_MODEL_AUTO_UPDATE = (
|
||||
not OFFLINE_MODE and os.getenv('RAG_RERANKING_MODEL_AUTO_UPDATE', 'True').lower() == 'true'
|
||||
)
|
||||
|
||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
|
||||
os.getenv('RAG_RERANKING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true'
|
||||
)
|
||||
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = os.getenv('RAG_RERANKING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true'
|
||||
|
||||
RAG_RERANKING_BATCH_SIZE = ConfigVar(
|
||||
'RAG_RERANKING_BATCH_SIZE',
|
||||
@@ -2709,9 +2714,7 @@ USER_PERMISSIONS_CHAT_DELETE = os.getenv('USER_PERMISSIONS_CHAT_DELETE', 'True')
|
||||
|
||||
USER_PERMISSIONS_CHAT_DELETE_MESSAGE = os.getenv('USER_PERMISSIONS_CHAT_DELETE_MESSAGE', 'True').lower() == 'true'
|
||||
|
||||
USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE = (
|
||||
os.getenv('USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE', 'True').lower() == 'true'
|
||||
)
|
||||
USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE = os.getenv('USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE', 'True').lower() == 'true'
|
||||
|
||||
USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE = (
|
||||
os.getenv('USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE', 'True').lower() == 'true'
|
||||
@@ -2735,9 +2738,7 @@ USER_PERMISSIONS_CHAT_TTS = os.getenv('USER_PERMISSIONS_CHAT_TTS', 'True').lower
|
||||
|
||||
USER_PERMISSIONS_CHAT_CALL = os.getenv('USER_PERMISSIONS_CHAT_CALL', 'True').lower() == 'true'
|
||||
|
||||
USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = (
|
||||
os.getenv('USER_PERMISSIONS_CHAT_MULTIPLE_MODELS', 'True').lower() == 'true'
|
||||
)
|
||||
USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = os.getenv('USER_PERMISSIONS_CHAT_MULTIPLE_MODELS', 'True').lower() == 'true'
|
||||
|
||||
USER_PERMISSIONS_CHAT_TEMPORARY = os.getenv('USER_PERMISSIONS_CHAT_TEMPORARY', 'True').lower() == 'true'
|
||||
|
||||
@@ -2770,9 +2771,7 @@ USER_PERMISSIONS_FEATURES_API_KEYS = os.getenv('USER_PERMISSIONS_FEATURES_API_KE
|
||||
|
||||
USER_PERMISSIONS_FEATURES_MEMORIES = os.getenv('USER_PERMISSIONS_FEATURES_MEMORIES', 'True').lower() == 'true'
|
||||
|
||||
USER_PERMISSIONS_FEATURES_AUTOMATIONS = (
|
||||
os.getenv('USER_PERMISSIONS_FEATURES_AUTOMATIONS', 'False').lower() == 'true'
|
||||
)
|
||||
USER_PERMISSIONS_FEATURES_AUTOMATIONS = os.getenv('USER_PERMISSIONS_FEATURES_AUTOMATIONS', 'False').lower() == 'true'
|
||||
|
||||
USER_PERMISSIONS_FEATURES_CALENDAR = os.getenv('USER_PERMISSIONS_FEATURES_CALENDAR', 'True').lower() == 'true'
|
||||
|
||||
@@ -2940,9 +2939,7 @@ WEBHOOK_URL = ConfigVar('WEBHOOK_URL', 'webhook_url', os.getenv('WEBHOOK_URL', '
|
||||
|
||||
ENABLE_ADMIN_EXPORT = os.getenv('ENABLE_ADMIN_EXPORT', 'True').lower() == 'true'
|
||||
|
||||
ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = (
|
||||
os.getenv('ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS', 'True').lower() == 'true'
|
||||
)
|
||||
ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = os.getenv('ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS', 'True').lower() == 'true'
|
||||
|
||||
BYPASS_ADMIN_ACCESS_CONTROL = (
|
||||
os.getenv(
|
||||
@@ -3024,7 +3021,7 @@ else:
|
||||
class BannerModel(BaseModel):
|
||||
id: str
|
||||
type: str
|
||||
title: Optional[str] = None
|
||||
title: str | None = None
|
||||
content: str
|
||||
dismissible: bool
|
||||
timestamp: int
|
||||
@@ -3705,9 +3702,7 @@ OAUTH_ALLOWED_ROLES = ConfigVar(
|
||||
'oauth.allowed_roles',
|
||||
[
|
||||
role.strip()
|
||||
for role in os.getenv('OAUTH_ALLOWED_ROLES', f'user{OAUTH_ROLES_SEPARATOR}admin').split(
|
||||
OAUTH_ROLES_SEPARATOR
|
||||
)
|
||||
for role in os.getenv('OAUTH_ALLOWED_ROLES', f'user{OAUTH_ROLES_SEPARATOR}admin').split(OAUTH_ROLES_SEPARATOR)
|
||||
if role
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
|
||||
@@ -530,9 +530,7 @@ except (ValueError, TypeError):
|
||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10
|
||||
|
||||
|
||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
|
||||
os.getenv('AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL', 'True').lower() == 'true'
|
||||
)
|
||||
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = os.getenv('AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL', 'True').lower() == 'true'
|
||||
|
||||
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER = os.getenv('AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER', '')
|
||||
|
||||
@@ -779,9 +777,7 @@ FORWARD_USER_INFO_HEADER_USER_NAME = os.getenv('FORWARD_USER_INFO_HEADER_USER_NA
|
||||
FORWARD_USER_INFO_HEADER_USER_ID = os.getenv('FORWARD_USER_INFO_HEADER_USER_ID', 'X-OpenWebUI-User-Id')
|
||||
FORWARD_USER_INFO_HEADER_USER_EMAIL = os.getenv('FORWARD_USER_INFO_HEADER_USER_EMAIL', 'X-OpenWebUI-User-Email')
|
||||
FORWARD_USER_INFO_HEADER_USER_ROLE = os.getenv('FORWARD_USER_INFO_HEADER_USER_ROLE', 'X-OpenWebUI-User-Role')
|
||||
FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.getenv(
|
||||
'FORWARD_SESSION_INFO_HEADER_MESSAGE_ID', 'X-OpenWebUI-Message-Id'
|
||||
)
|
||||
FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.getenv('FORWARD_SESSION_INFO_HEADER_MESSAGE_ID', 'X-OpenWebUI-Message-Id')
|
||||
FORWARD_SESSION_INFO_HEADER_CHAT_ID = os.getenv('FORWARD_SESSION_INFO_HEADER_CHAT_ID', 'X-OpenWebUI-Chat-Id')
|
||||
|
||||
####################################
|
||||
@@ -897,9 +893,7 @@ if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == '':
|
||||
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = 'torch'
|
||||
|
||||
|
||||
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.getenv(
|
||||
'SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS', ''
|
||||
)
|
||||
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.getenv('SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS', '')
|
||||
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == '':
|
||||
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
|
||||
else:
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import logging
|
||||
import sys
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
import logging
|
||||
import sys
|
||||
from typing import AsyncGenerator, Generator, Iterator
|
||||
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
@@ -16,40 +15,35 @@ from fastapi import (
|
||||
UploadFile,
|
||||
status,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from starlette.responses import Response, StreamingResponse
|
||||
|
||||
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.socket.main import (
|
||||
get_event_call,
|
||||
get_event_emitter,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.functions import Functions
|
||||
from open_webui.models.models import Models
|
||||
|
||||
from open_webui.utils.plugin import (
|
||||
load_function_module_by_id,
|
||||
get_function_module_from_cache,
|
||||
)
|
||||
from open_webui.utils.access_control import check_model_access
|
||||
|
||||
from open_webui.env import GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
|
||||
|
||||
from open_webui.utils.misc import (
|
||||
add_or_update_system_message,
|
||||
get_last_user_message,
|
||||
prepend_to_first_user_message_content,
|
||||
openai_chat_chunk_message_template,
|
||||
openai_chat_completion_message_template,
|
||||
prepend_to_first_user_message_content,
|
||||
)
|
||||
from open_webui.utils.payload import (
|
||||
apply_model_params_to_body_openai,
|
||||
apply_system_prompt_to_body,
|
||||
)
|
||||
from open_webui.utils.plugin import (
|
||||
get_function_module_from_cache,
|
||||
load_function_module_by_id,
|
||||
)
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,10 +10,9 @@ from functools import reduce
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import redis
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, func, select
|
||||
|
||||
from open_webui.internal.db import Base, get_async_db, get_db
|
||||
from open_webui.utils.redis import get_redis_connection
|
||||
from sqlalchemy import JSON, Column, DateTime, Integer, func, select
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -22,7 +21,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ConfigTable(Base):
|
||||
__tablename__ = "config"
|
||||
__tablename__ = 'config'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(JSON, nullable=False)
|
||||
@@ -37,7 +36,7 @@ class ConfigTable(Base):
|
||||
class ConfigState:
|
||||
"""In-memory mirror of the single-row config JSON blob."""
|
||||
|
||||
__slots__ = ("_data",)
|
||||
__slots__ = ('_data',)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[str, Any] = {}
|
||||
@@ -49,12 +48,12 @@ class ConfigState:
|
||||
def read(self, path: str) -> Any:
|
||||
return reduce(
|
||||
lambda n, k: n.get(k) if isinstance(n, dict) else None,
|
||||
path.split("."),
|
||||
path.split('.'),
|
||||
self._data,
|
||||
)
|
||||
|
||||
def write(self, path: str, value: Any) -> None:
|
||||
keys = path.split(".")
|
||||
keys = path.split('.')
|
||||
reduce(lambda d, k: d.setdefault(k, {}), keys[:-1], self._data)[keys[-1]] = value
|
||||
|
||||
def replace(self, data: dict) -> None:
|
||||
@@ -63,7 +62,7 @@ class ConfigState:
|
||||
def load(self) -> dict:
|
||||
with get_db() as db:
|
||||
row = db.query(ConfigTable).order_by(ConfigTable.id.desc()).first()
|
||||
self._data = row.data if row else {"version": 0, "ui": {}}
|
||||
self._data = row.data if row else {'version': 0, 'ui': {}}
|
||||
return self._data
|
||||
|
||||
def persist(self, data: dict | None = None) -> None:
|
||||
@@ -123,7 +122,7 @@ def initialize(*, enable_persistent: bool = True, enable_oauth_persistent: bool
|
||||
|
||||
|
||||
class ConfigVar:
|
||||
__slots__ = ("env_name", "config_path", "env_value", "config_value", "value")
|
||||
__slots__ = ('env_name', 'config_path', 'env_value', 'config_value', 'value')
|
||||
|
||||
def __init__(self, env_name: str, config_path: str, env_value: Any) -> None:
|
||||
self.env_name = env_name
|
||||
@@ -132,7 +131,7 @@ class ConfigVar:
|
||||
self.config_value = STATE.read(config_path)
|
||||
|
||||
if self.config_value is not None and _persist_enabled:
|
||||
if config_path.startswith("oauth.") and not _oauth_persist_enabled:
|
||||
if config_path.startswith('oauth.') and not _oauth_persist_enabled:
|
||||
log.info("Skipping DB value for '%s' (OAuth persistence disabled)", env_name)
|
||||
self.value = env_value
|
||||
else:
|
||||
@@ -147,22 +146,22 @@ class ConfigVar:
|
||||
return str(self.value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<ConfigVar {self.env_name}={self.value!r}>"
|
||||
return f'<ConfigVar {self.env_name}={self.value!r}>'
|
||||
|
||||
@property
|
||||
def __dict__(self): # type: ignore[override]
|
||||
raise TypeError(f"ConfigVar('{self.env_name}') cannot be cast to dict; use .value")
|
||||
|
||||
def __getattribute__(self, item: str):
|
||||
if item == "__dict__":
|
||||
raise TypeError("ConfigVar cannot be cast to dict; use .value")
|
||||
if item == '__dict__':
|
||||
raise TypeError('ConfigVar cannot be cast to dict; use .value')
|
||||
return super().__getattribute__(item)
|
||||
|
||||
def refresh(self) -> None:
|
||||
current = STATE.read(self.config_path)
|
||||
if current is not None:
|
||||
self.value = current
|
||||
log.info("Refreshed %s → %s", self.env_name, self.value)
|
||||
log.info('Refreshed %s → %s', self.env_name, self.value)
|
||||
|
||||
def commit(self) -> None:
|
||||
log.info("Persisting '%s'", self.env_name)
|
||||
@@ -189,18 +188,18 @@ class AppConfig:
|
||||
redis_url: Optional[str] = None,
|
||||
redis_sentinels: Optional[list] = None,
|
||||
redis_cluster: bool = False,
|
||||
redis_key_prefix: str = "open-webui",
|
||||
redis_key_prefix: str = 'open-webui',
|
||||
) -> None:
|
||||
super().__setattr__("_entries", {})
|
||||
super().__setattr__("_key_prefix", redis_key_prefix)
|
||||
super().__setattr__('_entries', {})
|
||||
super().__setattr__('_key_prefix', redis_key_prefix)
|
||||
|
||||
rc: Union[redis.Redis, redis.cluster.RedisCluster, None] = None
|
||||
if redis_url:
|
||||
rc = get_redis_connection(redis_url, redis_sentinels or [], redis_cluster, decode_responses=True)
|
||||
super().__setattr__("_rc", rc)
|
||||
super().__setattr__('_rc', rc)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
entries: dict = super().__getattribute__("_entries")
|
||||
entries: dict = super().__getattribute__('_entries')
|
||||
|
||||
if isinstance(value, ConfigVar):
|
||||
entries[name] = value
|
||||
@@ -213,11 +212,11 @@ class AppConfig:
|
||||
except RuntimeError:
|
||||
entries[name].commit()
|
||||
|
||||
rc = super().__getattribute__("_rc")
|
||||
rc = super().__getattribute__('_rc')
|
||||
if rc and _persist_enabled:
|
||||
prefix = super().__getattribute__("_key_prefix")
|
||||
prefix = super().__getattribute__('_key_prefix')
|
||||
try:
|
||||
rc.set(f"{prefix}:config:{name}", json.dumps(entries[name].value))
|
||||
rc.set(f'{prefix}:config:{name}', json.dumps(entries[name].value))
|
||||
except Exception as exc:
|
||||
log.error("Redis write failed for '%s': %s", name, exc)
|
||||
|
||||
@@ -228,15 +227,15 @@ class AppConfig:
|
||||
log.error("Async persist failed for '%s': %s", name, exc)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
entries = super().__getattribute__("_entries")
|
||||
entries = super().__getattribute__('_entries')
|
||||
if name not in entries:
|
||||
raise AttributeError(f"No config key '{name}'")
|
||||
|
||||
rc = super().__getattribute__("_rc")
|
||||
rc = super().__getattribute__('_rc')
|
||||
if rc and _persist_enabled:
|
||||
prefix = super().__getattribute__("_key_prefix")
|
||||
prefix = super().__getattribute__('_key_prefix')
|
||||
try:
|
||||
raw = rc.get(f"{prefix}:config:{name}")
|
||||
raw = rc.get(f'{prefix}:config:{name}')
|
||||
if raw is not None:
|
||||
decoded = json.loads(raw)
|
||||
if entries[name].value != decoded:
|
||||
@@ -248,12 +247,12 @@ class AppConfig:
|
||||
return entries[name].value
|
||||
|
||||
def _sync_to_redis(self) -> None:
|
||||
rc = super().__getattribute__("_rc")
|
||||
rc = super().__getattribute__('_rc')
|
||||
if not rc or not _persist_enabled:
|
||||
return
|
||||
prefix = super().__getattribute__("_key_prefix")
|
||||
for name, s in super().__getattribute__("_entries").items():
|
||||
prefix = super().__getattribute__('_key_prefix')
|
||||
for name, s in super().__getattribute__('_entries').items():
|
||||
try:
|
||||
rc.set(f"{prefix}:config:{name}", json.dumps(s.value))
|
||||
rc.set(f'{prefix}:config:{name}', json.dumps(s.value))
|
||||
except Exception as exc:
|
||||
log.error("Redis sync failed for '%s': %s", name, exc)
|
||||
|
||||
@@ -1,35 +1,36 @@
|
||||
import os
|
||||
import sys
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from typing import Any, Optional
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
|
||||
from open_webui.env import (
|
||||
OPEN_WEBUI_DIR,
|
||||
DATABASE_URL,
|
||||
DATABASE_SCHEMA,
|
||||
DATABASE_ENABLE_SESSION_SHARING,
|
||||
DATABASE_ENABLE_SQLITE_WAL,
|
||||
DATABASE_POOL_MAX_OVERFLOW,
|
||||
DATABASE_POOL_RECYCLE,
|
||||
DATABASE_POOL_SIZE,
|
||||
DATABASE_POOL_TIMEOUT,
|
||||
DATABASE_ENABLE_SQLITE_WAL,
|
||||
DATABASE_ENABLE_SESSION_SHARING,
|
||||
DATABASE_SQLITE_PRAGMA_SYNCHRONOUS,
|
||||
DATABASE_SCHEMA,
|
||||
DATABASE_SQLITE_PRAGMA_BUSY_TIMEOUT,
|
||||
DATABASE_SQLITE_PRAGMA_CACHE_SIZE,
|
||||
DATABASE_SQLITE_PRAGMA_TEMP_STORE,
|
||||
DATABASE_SQLITE_PRAGMA_MMAP_SIZE,
|
||||
DATABASE_SQLITE_PRAGMA_JOURNAL_SIZE_LIMIT,
|
||||
DATABASE_SQLITE_PRAGMA_MMAP_SIZE,
|
||||
DATABASE_SQLITE_PRAGMA_SYNCHRONOUS,
|
||||
DATABASE_SQLITE_PRAGMA_TEMP_STORE,
|
||||
DATABASE_URL,
|
||||
ENABLE_DB_MIGRATIONS,
|
||||
OPEN_WEBUI_DIR,
|
||||
)
|
||||
from sqlalchemy import Dialect, create_engine, MetaData, event, types
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy import Dialect, MetaData, create_engine, event, types
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import scoped_session, sessionmaker, Session
|
||||
from sqlalchemy.pool import QueuePool, NullPool
|
||||
from sqlalchemy.orm import Session, scoped_session, sessionmaker
|
||||
from sqlalchemy.pool import NullPool, QueuePool
|
||||
from sqlalchemy.sql.type_api import _T
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -120,10 +121,10 @@ class JSONField(types.TypeDecorator):
|
||||
impl = types.Text
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
|
||||
def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any:
|
||||
return json.dumps(value)
|
||||
|
||||
def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
|
||||
def process_result_value(self, value: _T | None, dialect: Dialect) -> Any:
|
||||
if value is not None:
|
||||
return json.loads(value)
|
||||
|
||||
@@ -374,7 +375,7 @@ async def get_async_db():
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_async_db_context(db: Optional[AsyncSession] = None):
|
||||
async def get_async_db_context(db: AsyncSession | None = None):
|
||||
"""Async context manager that reuses an existing session if provided and session sharing is enabled."""
|
||||
if isinstance(db, AsyncSession) and DATABASE_ENABLE_SESSION_SHARING:
|
||||
yield db
|
||||
|
||||
+523
-538
File diff suppressed because it is too large
Load Diff
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Alembic environment configuration.
|
||||
|
||||
Configures the migration context for both offline (SQL script generation)
|
||||
@@ -23,7 +25,7 @@ if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
|
||||
# Re-apply JSON formatter after fileConfig replaces handlers.
|
||||
if LOG_FORMAT == "json":
|
||||
if LOG_FORMAT == 'json':
|
||||
from open_webui.env import JSONFormatter
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
@@ -41,7 +43,7 @@ if _ssl_params:
|
||||
DB_URL = reattach_ssl_params_to_url(_url_no_ssl, _ssl_params)
|
||||
|
||||
if DB_URL:
|
||||
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
|
||||
config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))
|
||||
|
||||
|
||||
# ── Migration runners ────────────────────────────────────────────────────────
|
||||
@@ -49,12 +51,12 @@ if DB_URL:
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Generate SQL script without a live database connection."""
|
||||
url = config.get_main_option("sqlalchemy.url")
|
||||
url = config.get_main_option('sqlalchemy.url')
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
dialect_opts={'paramstyle': 'named'},
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
@@ -62,14 +64,12 @@ def run_migrations_offline() -> None:
|
||||
|
||||
def _build_connectable():
|
||||
"""Create the appropriate SQLAlchemy engine for the configured DB URL."""
|
||||
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
|
||||
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
|
||||
raise ValueError(
|
||||
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
|
||||
)
|
||||
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
|
||||
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '':
|
||||
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||
|
||||
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
|
||||
if db_path.startswith("/"):
|
||||
db_path = DB_URL.replace('sqlite+sqlcipher://', '')
|
||||
if db_path.startswith('/'):
|
||||
db_path = db_path[1:]
|
||||
|
||||
def _sqlcipher_creator():
|
||||
@@ -79,11 +79,11 @@ def _build_connectable():
|
||||
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
|
||||
return conn
|
||||
|
||||
return create_engine("sqlite://", creator=_sqlcipher_creator, echo=False)
|
||||
return create_engine('sqlite://', creator=_sqlcipher_creator, echo=False)
|
||||
|
||||
return engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
prefix="sqlalchemy.",
|
||||
prefix='sqlalchemy.',
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Alembic migration utilities."""
|
||||
|
||||
from alembic import op
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2025-08-13 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '018012973d35'
|
||||
down_revision = 'd31026856c01'
|
||||
|
||||
@@ -6,13 +6,13 @@ Create Date: 2024-10-09 21:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, select, update, column
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
import json
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql import column, select, table, update
|
||||
|
||||
revision = '1af9b942657b'
|
||||
down_revision = '242a2047eae0'
|
||||
branch_labels = None
|
||||
|
||||
@@ -6,12 +6,12 @@ Create Date: 2024-10-09 21:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, select, update
|
||||
|
||||
import json
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.sql import select, table, update
|
||||
|
||||
revision = '242a2047eae0'
|
||||
down_revision = '6a39f3d8e55c'
|
||||
branch_labels = None
|
||||
|
||||
+2
-2
@@ -8,9 +8,9 @@ Create Date: 2025-11-27 03:07:56.200231
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '2f1211949ecc'
|
||||
|
||||
@@ -6,11 +6,11 @@ Create Date: 2026-01-23 17:15:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
import uuid
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = '374d2f66af06'
|
||||
down_revision: Union[str, None] = 'c440947495f3'
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2024-12-30 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '3781e22d8b01'
|
||||
down_revision = '7826ab40b532'
|
||||
|
||||
@@ -6,13 +6,13 @@ Create Date: 2025-11-17 03:45:25.123939
|
||||
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import time
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '37f288994c47'
|
||||
|
||||
@@ -8,8 +8,8 @@ Create Date: 2025-09-08 14:19:59.583921
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '38d63c18f30f'
|
||||
|
||||
@@ -6,13 +6,13 @@ Create Date: 2024-10-09 21:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, select, update, column
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
import json
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql import column, select, table, update
|
||||
|
||||
revision = '3ab32c4b8f59'
|
||||
down_revision = '1af9b942657b'
|
||||
branch_labels = None
|
||||
|
||||
@@ -8,8 +8,8 @@ Create Date: 2025-08-21 02:07:18.078283
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '3af16a1c9fb6'
|
||||
|
||||
@@ -6,16 +6,15 @@ Create Date: 2025-12-02 06:54:19.401334
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import inspect
|
||||
import open_webui.internal.db
|
||||
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy import inspect
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '3e0e00844bb0'
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2024-10-23 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '4ace53fd72c8'
|
||||
down_revision = 'af906e964978'
|
||||
|
||||
@@ -8,10 +8,9 @@ Create Date: 2026-05-09 04:29:27.651341
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '4de81c2a3af1'
|
||||
@@ -20,10 +19,11 @@ branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
import uuid
|
||||
import time
|
||||
from sqlalchemy import select, update, insert
|
||||
from sqlalchemy.sql import table, column
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import insert, select, update
|
||||
from sqlalchemy.sql import column, table
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
|
||||
@@ -8,9 +8,8 @@ Create Date: 2026-04-19 16:20:58.162045
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '56359461a091'
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2024-12-22 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '57c599a3cb57'
|
||||
down_revision = '922e7a387820'
|
||||
|
||||
@@ -8,9 +8,9 @@ Create Date: 2025-12-10 15:11:39.424601
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '6283dc0e4d8d'
|
||||
|
||||
@@ -6,11 +6,12 @@ Create Date: 2024-10-01 14:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.sql import table, column, select
|
||||
import json
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.sql import column, select, table
|
||||
|
||||
revision = '6a39f3d8e55c'
|
||||
down_revision = 'c0fbf31ca0db'
|
||||
branch_labels = None
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2024-12-23 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '7826ab40b532'
|
||||
down_revision = '57c599a3cb57'
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Initial Alembic schema — creates all base tables.
|
||||
|
||||
Revision ID: 7e5b5dc7342b
|
||||
@@ -7,14 +9,13 @@ Create Date: 2024-06-24 13:15:33.808998
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import open_webui.internal.db # noqa: F401
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import open_webui.internal.db # noqa: F401
|
||||
from open_webui.internal.db import JSONField
|
||||
from open_webui.migrations.util import get_existing_tables
|
||||
|
||||
revision: str = "7e5b5dc7342b"
|
||||
revision: str = '7e5b5dc7342b'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
@@ -25,162 +26,162 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_TABLES: list[tuple[str, list[sa.Column], list]] = [
|
||||
(
|
||||
"auth",
|
||||
'auth',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("password", sa.Text(), nullable=True),
|
||||
sa.Column("active", sa.Boolean(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=True),
|
||||
sa.Column('password', sa.Text(), nullable=True),
|
||||
sa.Column('active', sa.Boolean(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id")],
|
||||
[sa.PrimaryKeyConstraint('id')],
|
||||
),
|
||||
(
|
||||
"chat",
|
||||
'chat',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("chat", sa.Text(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("share_id", sa.Text(), nullable=True),
|
||||
sa.Column("archived", sa.Boolean(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.Text(), nullable=True),
|
||||
sa.Column('chat', sa.Text(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('share_id', sa.Text(), nullable=True),
|
||||
sa.Column('archived', sa.Boolean(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("share_id")],
|
||||
[sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('share_id')],
|
||||
),
|
||||
(
|
||||
"chatidtag",
|
||||
'chatidtag',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("tag_name", sa.String(), nullable=True),
|
||||
sa.Column("chat_id", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('tag_name', sa.String(), nullable=True),
|
||||
sa.Column('chat_id', sa.String(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id")],
|
||||
[sa.PrimaryKeyConstraint('id')],
|
||||
),
|
||||
(
|
||||
"document",
|
||||
'document',
|
||||
[
|
||||
sa.Column("collection_name", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("filename", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.Column('collection_name', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.Text(), nullable=True),
|
||||
sa.Column('filename', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("collection_name"), sa.UniqueConstraint("name")],
|
||||
[sa.PrimaryKeyConstraint('collection_name'), sa.UniqueConstraint('name')],
|
||||
),
|
||||
(
|
||||
"file",
|
||||
'file',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("filename", sa.Text(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('filename', sa.Text(), nullable=True),
|
||||
sa.Column('meta', JSONField(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id")],
|
||||
[sa.PrimaryKeyConstraint('id')],
|
||||
),
|
||||
(
|
||||
"function",
|
||||
'function',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("type", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("valves", JSONField(), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=True),
|
||||
sa.Column("is_global", sa.Boolean(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('name', sa.Text(), nullable=True),
|
||||
sa.Column('type', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('meta', JSONField(), nullable=True),
|
||||
sa.Column('valves', JSONField(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=True),
|
||||
sa.Column('is_global', sa.Boolean(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id")],
|
||||
[sa.PrimaryKeyConstraint('id')],
|
||||
),
|
||||
(
|
||||
"memory",
|
||||
'memory',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id")],
|
||||
[sa.PrimaryKeyConstraint('id')],
|
||||
),
|
||||
(
|
||||
"model",
|
||||
'model',
|
||||
[
|
||||
sa.Column("id", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.Text(), nullable=True),
|
||||
sa.Column("base_model_id", sa.Text(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("params", JSONField(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column('id', sa.Text(), nullable=False),
|
||||
sa.Column('user_id', sa.Text(), nullable=True),
|
||||
sa.Column('base_model_id', sa.Text(), nullable=True),
|
||||
sa.Column('name', sa.Text(), nullable=True),
|
||||
sa.Column('params', JSONField(), nullable=True),
|
||||
sa.Column('meta', JSONField(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id")],
|
||||
[sa.PrimaryKeyConstraint('id')],
|
||||
),
|
||||
(
|
||||
"prompt",
|
||||
'prompt',
|
||||
[
|
||||
sa.Column("command", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("title", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("timestamp", sa.BigInteger(), nullable=True),
|
||||
sa.Column('command', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('title', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('timestamp', sa.BigInteger(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("command")],
|
||||
[sa.PrimaryKeyConstraint('command')],
|
||||
),
|
||||
(
|
||||
"tag",
|
||||
'tag',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("data", sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('data', sa.Text(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id")],
|
||||
[sa.PrimaryKeyConstraint('id')],
|
||||
),
|
||||
(
|
||||
"tool",
|
||||
'tool',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=True),
|
||||
sa.Column("name", sa.Text(), nullable=True),
|
||||
sa.Column("content", sa.Text(), nullable=True),
|
||||
sa.Column("specs", JSONField(), nullable=True),
|
||||
sa.Column("meta", JSONField(), nullable=True),
|
||||
sa.Column("valves", JSONField(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('user_id', sa.String(), nullable=True),
|
||||
sa.Column('name', sa.Text(), nullable=True),
|
||||
sa.Column('content', sa.Text(), nullable=True),
|
||||
sa.Column('specs', JSONField(), nullable=True),
|
||||
sa.Column('meta', JSONField(), nullable=True),
|
||||
sa.Column('valves', JSONField(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
],
|
||||
[sa.PrimaryKeyConstraint("id")],
|
||||
[sa.PrimaryKeyConstraint('id')],
|
||||
),
|
||||
(
|
||||
"user",
|
||||
'user',
|
||||
[
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=True),
|
||||
sa.Column("email", sa.String(), nullable=True),
|
||||
sa.Column("role", sa.String(), nullable=True),
|
||||
sa.Column("profile_image_url", sa.Text(), nullable=True),
|
||||
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("updated_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("created_at", sa.BigInteger(), nullable=True),
|
||||
sa.Column("api_key", sa.String(), nullable=True),
|
||||
sa.Column("settings", JSONField(), nullable=True),
|
||||
sa.Column("info", JSONField(), nullable=True),
|
||||
sa.Column("oauth_sub", sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=True),
|
||||
sa.Column('email', sa.String(), nullable=True),
|
||||
sa.Column('role', sa.String(), nullable=True),
|
||||
sa.Column('profile_image_url', sa.Text(), nullable=True),
|
||||
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('updated_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created_at', sa.BigInteger(), nullable=True),
|
||||
sa.Column('api_key', sa.String(), nullable=True),
|
||||
sa.Column('settings', JSONField(), nullable=True),
|
||||
sa.Column('info', JSONField(), nullable=True),
|
||||
sa.Column('oauth_sub', sa.Text(), nullable=True),
|
||||
],
|
||||
[
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("api_key"),
|
||||
sa.UniqueConstraint("oauth_sub"),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('api_key'),
|
||||
sa.UniqueConstraint('oauth_sub'),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
+2
-2
@@ -8,9 +8,9 @@ Create Date: 2025-12-10 16:07:58.001282
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '81cc2ce44d79'
|
||||
|
||||
@@ -6,13 +6,13 @@ Create Date: 2026-02-01 04:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
+2
-2
@@ -8,9 +8,9 @@ Create Date: 2025-11-30 06:33:38.790341
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import open_webui.internal.db
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '90ef40d4714e'
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2024-11-14 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '922e7a387820'
|
||||
down_revision = '4ace53fd72c8'
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2025-05-03 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = '9f0c9cd09105'
|
||||
down_revision = '3781e22d8b01'
|
||||
|
||||
@@ -8,9 +8,8 @@ Create Date: 2026-02-11 09:30:00.000000
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from open_webui.migrations.util import get_existing_tables
|
||||
|
||||
revision: str = 'a1b2c3d4e5f6'
|
||||
|
||||
+1
-1
@@ -8,8 +8,8 @@ Create Date: 2026-03-29 22:15:00.000000
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a3dd5bedd151'
|
||||
|
||||
+1
-1
@@ -8,8 +8,8 @@ Create Date: 2025-09-27 02:24:18.058455
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'a5c220713937'
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2024-10-20 17:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# Revision identifiers, used by Alembic.
|
||||
revision = 'af906e964978'
|
||||
|
||||
@@ -6,15 +6,13 @@ Create Date: 2025-11-28 04:55:31.737538
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
import open_webui.internal.db
|
||||
import json
|
||||
import time
|
||||
from typing import Sequence, Union
|
||||
|
||||
import open_webui.internal.db
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b10670c03dd5'
|
||||
|
||||
+1
-1
@@ -8,8 +8,8 @@ Create Date: 2026-02-13 14:19:00.000000
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'b2c3d4e5f6a7'
|
||||
|
||||
@@ -6,9 +6,8 @@ Create Date: 2026-04-01 04:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b7c8d9e0f1a2'
|
||||
|
||||
@@ -9,8 +9,8 @@ Create Date: 2026-04-16 23:00:00.000000
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = 'c1d2e3f4a5b6'
|
||||
down_revision = 'e1f2a3b4c5d6'
|
||||
|
||||
@@ -6,11 +6,12 @@ Create Date: 2024-10-20 17:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
from sqlalchemy.sql import table, column
|
||||
from sqlalchemy import String, Text, JSON, and_
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy import JSON, String, Text, and_
|
||||
from sqlalchemy.sql import column, table
|
||||
|
||||
revision = 'c29facfe716b'
|
||||
down_revision = 'c69f45358db4'
|
||||
|
||||
@@ -8,8 +8,8 @@ Create Date: 2025-12-21 20:27:41.694897
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'c440947495f3'
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2024-10-16 02:02:35.241684
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = 'c69f45358db4'
|
||||
down_revision = '3ab32c4b8f59'
|
||||
|
||||
@@ -10,8 +10,8 @@ from typing import Sequence, Union
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "ca81bd47c050"
|
||||
down_revision: Union[str, None] = "7e5b5dc7342b"
|
||||
revision: str = 'ca81bd47c050'
|
||||
down_revision: Union[str, None] = '7e5b5dc7342b'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
@@ -19,18 +19,18 @@ depends_on: Union[str, Sequence[str], None] = None
|
||||
def upgrade() -> None:
|
||||
"""Create a key-value config table with versioning."""
|
||||
op.create_table(
|
||||
"config",
|
||||
sa.Column("id", sa.Integer, primary_key=True),
|
||||
sa.Column("data", sa.JSON(), nullable=False),
|
||||
sa.Column("version", sa.Integer, nullable=False),
|
||||
'config',
|
||||
sa.Column('id', sa.Integer, primary_key=True),
|
||||
sa.Column('data', sa.JSON(), nullable=False),
|
||||
sa.Column('version', sa.Integer, nullable=False),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
'created_at',
|
||||
sa.DateTime(),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
'updated_at',
|
||||
sa.DateTime(),
|
||||
nullable=True,
|
||||
server_default=sa.func.now(),
|
||||
@@ -41,4 +41,4 @@ def upgrade() -> None:
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Drop the config table."""
|
||||
op.drop_table("config")
|
||||
op.drop_table('config')
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2025-07-13 03:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = 'd31026856c01'
|
||||
down_revision = '9f0c9cd09105'
|
||||
|
||||
@@ -7,8 +7,8 @@ Create Date: 2026-03-30
|
||||
|
||||
from typing import Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = 'd4e5f6a7b8c9'
|
||||
down_revision: Union[str, None] = 'a3dd5bedd151'
|
||||
|
||||
@@ -6,8 +6,8 @@ Create Date: 2026-04-14 22:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = 'e1f2a3b4c5d6'
|
||||
down_revision = 'b7c8d9e0f1a2'
|
||||
|
||||
@@ -11,13 +11,12 @@ Access control semantics:
|
||||
- {read: {...}, write: {...}}: Custom permissions -> insert specific grants
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
import time
|
||||
import uuid
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from open_webui.migrations.util import get_existing_tables
|
||||
|
||||
revision: str = 'f1e2d3c4b5a6'
|
||||
|
||||
@@ -3,13 +3,11 @@ import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, UniqueConstraint, or_, and_
|
||||
from sqlalchemy import BigInteger, Column, Text, UniqueConstraint, and_, delete, or_, select
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -623,8 +621,8 @@ class AccessGrantsTable:
|
||||
Get all users who have the specified permission on a resource.
|
||||
Returns a list of UserModel instances.
|
||||
"""
|
||||
from open_webui.models.users import Users, UserModel
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import UserModel, Users
|
||||
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.users import User, UserModel, UserProfileImageResponse, Users
|
||||
from open_webui.utils.validate import validate_profile_image_url
|
||||
from pydantic import BaseModel, field_validator
|
||||
from sqlalchemy import Boolean, Column, String, Text
|
||||
from sqlalchemy import Boolean, Column, String, Text, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,7 +45,7 @@ class Token(BaseModel):
|
||||
|
||||
|
||||
class ApiKey(BaseModel):
|
||||
api_key: Optional[str] = None
|
||||
api_key: str | None = None
|
||||
|
||||
|
||||
class SigninResponse(Token, UserProfileImageResponse):
|
||||
@@ -74,18 +75,18 @@ class SignupForm(BaseModel):
|
||||
name: str
|
||||
email: str
|
||||
password: str
|
||||
profile_image_url: Optional[str] = '/user.png'
|
||||
profile_image_url: str | None = '/user.png'
|
||||
|
||||
@field_validator('profile_image_url')
|
||||
@classmethod
|
||||
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
def check_profile_image_url(cls, v: str | None) -> str | None:
|
||||
if v is not None:
|
||||
return validate_profile_image_url(v)
|
||||
return v
|
||||
|
||||
|
||||
class AddUserForm(SignupForm):
|
||||
role: Optional[str] = 'pending'
|
||||
role: str | None = 'pending'
|
||||
|
||||
|
||||
class AuthsTable:
|
||||
@@ -96,9 +97,9 @@ class AuthsTable:
|
||||
name: str,
|
||||
profile_image_url: str = '/user.png',
|
||||
role: str = 'pending',
|
||||
oauth: Optional[dict] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[UserModel]:
|
||||
oauth: dict | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
log.info('insert_new_auth')
|
||||
|
||||
@@ -119,8 +120,8 @@ class AuthsTable:
|
||||
return None
|
||||
|
||||
async def authenticate_user(
|
||||
self, email: str, verify_password: callable, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
self, email: str, verify_password: callable, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
log.info(f'authenticate_user: {email}')
|
||||
|
||||
user = await Users.get_user_by_email(email, db=db)
|
||||
@@ -141,9 +142,7 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def authenticate_user_by_api_key(
|
||||
self, api_key: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
async def authenticate_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
log.info(f'authenticate_user_by_api_key')
|
||||
# if no api_key, return None
|
||||
if not api_key:
|
||||
@@ -155,7 +154,7 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def authenticate_user_by_email(self, email: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def authenticate_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
log.info(f'authenticate_user_by_email: {email}')
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -171,7 +170,7 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_user_password_by_id(self, id: str, new_password: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def update_user_password_by_id(self, id: str, new_password: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(update(Auth).filter_by(id=id).values(password=new_password))
|
||||
@@ -180,7 +179,7 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def update_email_by_id(self, id: str, email: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def update_email_by_id(self, id: str, email: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(update(Auth).filter_by(id=id).values(email=email))
|
||||
@@ -192,7 +191,7 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_auth_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_auth_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Delete User
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import time
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import Column, Text, JSON, Boolean, BigInteger, Index, select, or_, func, cast, String, delete, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, Index, String, Text, cast, delete, func, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,30 +1,29 @@
|
||||
import time
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Text,
|
||||
JSON,
|
||||
Boolean,
|
||||
BigInteger,
|
||||
Index,
|
||||
UniqueConstraint,
|
||||
select,
|
||||
or_,
|
||||
exists,
|
||||
func,
|
||||
delete,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import User, UserModel, UserResponse
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
Index,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
delete,
|
||||
exists,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -4,31 +4,33 @@ import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.utils.validate import validate_profile_image_url
|
||||
|
||||
from sqlalchemy import select, delete, update, func, case, or_, and_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.access_grants import (
|
||||
AccessGrantModel,
|
||||
AccessGrants,
|
||||
)
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.utils.validate import validate_profile_image_url
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
ForeignKey,
|
||||
String,
|
||||
Text,
|
||||
JSON,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
case,
|
||||
delete,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# Channel DB Schema
|
||||
|
||||
@@ -3,21 +3,24 @@ import time
|
||||
import uuid
|
||||
from typing import Any, Optional
|
||||
|
||||
from sqlalchemy import select, delete, func, cast, Integer
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from open_webui.utils.response import normalize_usage
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
ForeignKey,
|
||||
Text,
|
||||
JSON,
|
||||
Index,
|
||||
Integer,
|
||||
Text,
|
||||
cast,
|
||||
delete,
|
||||
func,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# Helpers
|
||||
@@ -579,6 +582,7 @@ class ChatMessageTable:
|
||||
"""Get message counts grouped by day and model."""
|
||||
async with get_async_db_context(db) as db:
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter(
|
||||
|
||||
+116
-112
@@ -1,32 +1,39 @@
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update, func, or_, and_, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import exists
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
from open_webui.models.folders import Folders
|
||||
from open_webui.models.chat_messages import ChatMessage, ChatMessages
|
||||
from open_webui.models.automations import AutomationRun
|
||||
from open_webui.models.chat_messages import ChatMessage, ChatMessages
|
||||
from open_webui.models.folders import Folders
|
||||
from open_webui.models.tags import Tag, TagModel, Tags
|
||||
from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
ForeignKey,
|
||||
Index,
|
||||
String,
|
||||
Text,
|
||||
JSON,
|
||||
Index,
|
||||
UniqueConstraint,
|
||||
and_,
|
||||
delete,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
text,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import exists
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
|
||||
####################
|
||||
# Chat DB Schema
|
||||
@@ -81,17 +88,17 @@ class ChatModel(BaseModel):
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
share_id: Optional[str] = None
|
||||
share_id: str | None = None
|
||||
archived: bool = False
|
||||
pinned: Optional[bool] = False
|
||||
pinned: bool | None = False
|
||||
|
||||
meta: dict = {}
|
||||
folder_id: Optional[str] = None
|
||||
folder_id: str | None = None
|
||||
|
||||
tasks: Optional[list] = None
|
||||
summary: Optional[str] = None
|
||||
tasks: list | None = None
|
||||
summary: str | None = None
|
||||
|
||||
last_read_at: Optional[int] = None
|
||||
last_read_at: int | None = None
|
||||
|
||||
|
||||
class ChatFile(Base):
|
||||
@@ -115,7 +122,7 @@ class ChatFileModel(BaseModel):
|
||||
user_id: str
|
||||
|
||||
chat_id: str
|
||||
message_id: Optional[str] = None
|
||||
message_id: str | None = None
|
||||
file_id: str
|
||||
|
||||
created_at: int
|
||||
@@ -131,14 +138,14 @@ class ChatFileModel(BaseModel):
|
||||
|
||||
class ChatForm(BaseModel):
|
||||
chat: dict
|
||||
folder_id: Optional[str] = None
|
||||
folder_id: str | None = None
|
||||
|
||||
|
||||
class ChatImportForm(ChatForm):
|
||||
meta: Optional[dict] = {}
|
||||
pinned: Optional[bool] = False
|
||||
created_at: Optional[int] = None
|
||||
updated_at: Optional[int] = None
|
||||
meta: dict | None = {}
|
||||
pinned: bool | None = False
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
|
||||
|
||||
class ChatsImportForm(BaseModel):
|
||||
@@ -161,14 +168,14 @@ class ChatResponse(BaseModel):
|
||||
chat: dict
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
share_id: Optional[str] = None # id of the chat to be shared
|
||||
share_id: str | None = None # id of the chat to be shared
|
||||
archived: bool
|
||||
pinned: Optional[bool] = False
|
||||
pinned: bool | None = False
|
||||
meta: dict = {}
|
||||
folder_id: Optional[str] = None
|
||||
folder_id: str | None = None
|
||||
|
||||
tasks: Optional[list] = None
|
||||
summary: Optional[str] = None
|
||||
tasks: list | None = None
|
||||
summary: str | None = None
|
||||
|
||||
|
||||
class ChatTitleIdResponse(BaseModel):
|
||||
@@ -176,13 +183,13 @@ class ChatTitleIdResponse(BaseModel):
|
||||
title: str
|
||||
updated_at: int
|
||||
created_at: int
|
||||
last_read_at: Optional[int] = None
|
||||
last_read_at: int | None = None
|
||||
|
||||
|
||||
class SharedChatResponse(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
share_id: Optional[str] = None
|
||||
share_id: str | None = None
|
||||
updated_at: int
|
||||
created_at: int
|
||||
|
||||
@@ -225,17 +232,17 @@ class ChatUsageStatsListResponse(BaseModel):
|
||||
class MessageStats(BaseModel):
|
||||
id: str
|
||||
role: str
|
||||
model: Optional[str] = None
|
||||
model: str | None = None
|
||||
content_length: int
|
||||
token_count: Optional[int] = None
|
||||
timestamp: Optional[int] = None
|
||||
rating: Optional[int] = None # Derived from message.annotation.rating
|
||||
tags: Optional[list[str]] = None # Derived from message.annotation.tags
|
||||
token_count: int | None = None
|
||||
timestamp: int | None = None
|
||||
rating: int | None = None # Derived from message.annotation.rating
|
||||
tags: list[str | None] = None # Derived from message.annotation.tags
|
||||
|
||||
|
||||
class ChatHistoryStats(BaseModel):
|
||||
messages: dict[str, MessageStats]
|
||||
currentId: Optional[str] = None
|
||||
currentId: str | None = None
|
||||
|
||||
|
||||
class ChatBody(BaseModel):
|
||||
@@ -293,8 +300,8 @@ class ChatTable:
|
||||
return changed
|
||||
|
||||
async def insert_new_chat(
|
||||
self, id: str, user_id: str, form_data: ChatForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
self, id: str, user_id: str, form_data: ChatForm, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = ChatModel(
|
||||
**{
|
||||
@@ -353,7 +360,7 @@ class ChatTable:
|
||||
self,
|
||||
user_id: str,
|
||||
chat_import_forms: list[ChatImportForm],
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
chats = []
|
||||
@@ -383,7 +390,7 @@ class ChatTable:
|
||||
|
||||
return [ChatModel.model_validate(chat) for chat in chats]
|
||||
|
||||
async def update_chat_by_id(self, id: str, chat: dict, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def update_chat_by_id(self, id: str, chat: dict, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat_item = await db.get(Chat, id)
|
||||
@@ -398,7 +405,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
@@ -410,7 +417,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
|
||||
async def update_chat_title_by_id(self, id: str, title: str) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context() as db:
|
||||
chat_item = await db.get(Chat, id)
|
||||
@@ -426,7 +433,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> Optional[ChatModel]:
|
||||
async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> ChatModel | None:
|
||||
async with get_async_db_context() as db:
|
||||
chat = await db.get(Chat, id)
|
||||
if chat is None:
|
||||
@@ -451,7 +458,7 @@ class ChatTable:
|
||||
|
||||
return ChatModel.model_validate(chat)
|
||||
|
||||
async def get_chat_title_by_id(self, id: str) -> Optional[str]:
|
||||
async def get_chat_title_by_id(self, id: str) -> str | None:
|
||||
async with get_async_db_context() as db:
|
||||
result = await db.execute(select(Chat.title).filter_by(id=id))
|
||||
row = result.first()
|
||||
@@ -488,7 +495,7 @@ class ChatTable:
|
||||
except Exception as e:
|
||||
log.warning('Backfill failed for message %s in chat %s: %s', message_id, chat_id, e)
|
||||
|
||||
async def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]:
|
||||
async def get_messages_map_by_chat_id(self, id: str) -> dict | None:
|
||||
"""Message map for walking history (see ``get_message_list``).
|
||||
|
||||
Prefer ``chat_message`` rows to avoid loading the large embedded
|
||||
@@ -541,7 +548,7 @@ class ChatTable:
|
||||
|
||||
return history_messages
|
||||
|
||||
async def get_message_by_id_and_message_id(self, id: str, message_id: str) -> Optional[dict]:
|
||||
async def get_message_by_id_and_message_id(self, id: str, message_id: str) -> dict | None:
|
||||
chat = await self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
@@ -550,7 +557,7 @@ class ChatTable:
|
||||
|
||||
async def upsert_message_to_chat_by_id_and_message_id(
|
||||
self, id: str, message_id: str, message: dict
|
||||
) -> Optional[ChatModel]:
|
||||
) -> ChatModel | None:
|
||||
chat = await self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
@@ -590,7 +597,7 @@ class ChatTable:
|
||||
|
||||
async def add_message_status_to_chat_by_id_and_message_id(
|
||||
self, id: str, message_id: str, status: dict
|
||||
) -> Optional[ChatModel]:
|
||||
) -> ChatModel | None:
|
||||
chat = await self.get_chat_by_id(id)
|
||||
if chat is None:
|
||||
return None
|
||||
@@ -626,9 +633,7 @@ class ChatTable:
|
||||
await self.update_chat_by_id(id, chat, db=db)
|
||||
return message_files
|
||||
|
||||
async def insert_shared_chat_by_chat_id(
|
||||
self, chat_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
async def insert_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
"""Create a shared snapshot for a chat. Returns the original chat with share_id set."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
@@ -651,9 +656,7 @@ class ChatTable:
|
||||
await db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
|
||||
async def update_shared_chat_by_chat_id(
|
||||
self, chat_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
async def update_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
"""Re-snapshot the shared chat with current chat data."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
@@ -668,7 +671,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def delete_shared_chat_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool:
|
||||
"""Delete shared snapshot for a chat."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
@@ -677,7 +680,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def unarchive_all_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def unarchive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=False))
|
||||
@@ -687,8 +690,8 @@ class ChatTable:
|
||||
return False
|
||||
|
||||
async def update_chat_share_id_by_id(
|
||||
self, id: str, share_id: Optional[str], db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
self, id: str, share_id: str | None, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
@@ -699,7 +702,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def toggle_chat_pinned_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def toggle_chat_pinned_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
@@ -711,7 +714,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def toggle_chat_archive_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def toggle_chat_archive_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
@@ -724,7 +727,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def archive_all_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def archive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=True))
|
||||
@@ -736,10 +739,10 @@ class ChatTable:
|
||||
async def get_archived_chat_list_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by(
|
||||
@@ -789,10 +792,10 @@ class ChatTable:
|
||||
async def get_shared_chat_list_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[SharedChatResponse]:
|
||||
"""Delegate to SharedChats for listing shared chats by user."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
@@ -803,10 +806,10 @@ class ChatTable:
|
||||
self,
|
||||
user_id: str,
|
||||
include_archived: bool = False,
|
||||
filter: Optional[dict] = None,
|
||||
filter: dict | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
@@ -859,9 +862,9 @@ class ChatTable:
|
||||
include_archived: bool = False,
|
||||
include_folders: bool = False,
|
||||
include_pinned: bool = False,
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
skip: int | None = None,
|
||||
limit: int | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
@@ -905,7 +908,7 @@ class ChatTable:
|
||||
chat_ids: list[str],
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
@@ -914,7 +917,7 @@ class ChatTable:
|
||||
all_chats = result.scalars().all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
async def get_chat_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def get_chat_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat_item = await db.get(Chat, id)
|
||||
@@ -929,7 +932,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_chat_by_share_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def get_chat_by_share_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
"""Look up a shared chat snapshot by its share token."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
@@ -951,8 +954,8 @@ class ChatTable:
|
||||
return None
|
||||
|
||||
async def get_chat_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
self, id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Chat).filter_by(id=id, user_id=user_id))
|
||||
@@ -961,7 +964,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def is_chat_owner(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def is_chat_owner(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
"""
|
||||
Lightweight ownership check — uses EXISTS subquery instead of loading
|
||||
the full Chat row (which includes the potentially large JSON blob).
|
||||
@@ -973,7 +976,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_chat_folder_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[str]:
|
||||
async def get_chat_folder_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> str | None:
|
||||
"""
|
||||
Fetch only the folder_id column for a chat, without loading the full
|
||||
JSON blob. Returns None if chat doesn't exist or doesn't belong to user.
|
||||
@@ -986,7 +989,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_chats(self, skip: int = 0, limit: int = 50, db: Optional[AsyncSession] = None) -> list[ChatModel]:
|
||||
async def get_chats(self, skip: int = 0, limit: int = 50, db: AsyncSession | None = None) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Chat).order_by(Chat.updated_at.desc()))
|
||||
all_chats = result.scalars().all()
|
||||
@@ -995,10 +998,10 @@ class ChatTable:
|
||||
async def get_chats_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
filter: Optional[dict] = None,
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
filter: dict | None = None,
|
||||
skip: int | None = None,
|
||||
limit: int | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> ChatListResponse:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat).filter_by(user_id=user_id)
|
||||
@@ -1041,7 +1044,7 @@ class ChatTable:
|
||||
)
|
||||
|
||||
async def get_pinned_chats_by_user_id(
|
||||
self, user_id: str, db: Optional[AsyncSession] = None
|
||||
self, user_id: str, db: AsyncSession | None = None
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
@@ -1063,7 +1066,7 @@ class ChatTable:
|
||||
for chat in all_chats
|
||||
]
|
||||
|
||||
async def get_archived_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[ChatModel]:
|
||||
async def get_archived_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc())
|
||||
@@ -1077,7 +1080,7 @@ class ChatTable:
|
||||
include_archived: bool = False,
|
||||
skip: int = 0,
|
||||
limit: int = 60,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatModel]:
|
||||
"""
|
||||
Filters chats based on a search query using Python, allowing pagination using skip and limit.
|
||||
@@ -1270,7 +1273,7 @@ class ChatTable:
|
||||
user_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 60,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = (
|
||||
@@ -1302,7 +1305,7 @@ class ChatTable:
|
||||
]
|
||||
|
||||
async def get_chats_by_folder_ids_and_user_id(
|
||||
self, folder_ids: list[str], user_id: str, db: Optional[AsyncSession] = None
|
||||
self, folder_ids: list[str], user_id: str, db: AsyncSession | None = None
|
||||
) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = (
|
||||
@@ -1318,8 +1321,8 @@ class ChatTable:
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
async def update_chat_folder_id_by_id_and_user_id(
|
||||
self, id: str, user_id: str, folder_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
self, id: str, user_id: str, folder_id: str, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
@@ -1333,7 +1336,7 @@ class ChatTable:
|
||||
return None
|
||||
|
||||
async def get_chat_tags_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
self, id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> list[TagModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.meta).where(Chat.id == id)
|
||||
@@ -1348,7 +1351,7 @@ class ChatTable:
|
||||
tag_name: str,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
@@ -1393,8 +1396,8 @@ class ChatTable:
|
||||
]
|
||||
|
||||
async def add_chat_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
tag_id = tag_name.replace(' ', '_').lower()
|
||||
await Tags.ensure_tags_exist([tag_name], user_id, db=db)
|
||||
try:
|
||||
@@ -1412,7 +1415,7 @@ class ChatTable:
|
||||
return None
|
||||
|
||||
async def count_chats_by_tag_name_and_user_id(
|
||||
self, tag_name: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
self, tag_name: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(func.count(Chat.id)).filter_by(user_id=user_id, archived=False)
|
||||
@@ -1439,7 +1442,7 @@ class ChatTable:
|
||||
tag_ids: list[str],
|
||||
user_id: str,
|
||||
threshold: int = 0,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> None:
|
||||
"""Delete tag rows from *tag_ids* that appear in at most *threshold*
|
||||
non-archived chats for *user_id*. One query to find orphans, one to
|
||||
@@ -1460,7 +1463,7 @@ class ChatTable:
|
||||
await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db)
|
||||
|
||||
async def count_chats_by_folder_id_and_user_id(
|
||||
self, folder_id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
self, folder_id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id))
|
||||
@@ -1470,7 +1473,7 @@ class ChatTable:
|
||||
return count
|
||||
|
||||
async def delete_tag_by_id_and_user_id_and_tag_name(
|
||||
self, id: str, user_id: str, tag_name: str, db: Optional[AsyncSession] = None
|
||||
self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -1488,7 +1491,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
@@ -1502,7 +1505,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_chat_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_chat_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
|
||||
@@ -1514,7 +1517,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
|
||||
@@ -1526,7 +1529,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await self.delete_shared_chats_by_user_id(user_id, db=db)
|
||||
@@ -1548,7 +1551,7 @@ class ChatTable:
|
||||
return False
|
||||
|
||||
async def delete_chats_by_user_id_and_folder_id(
|
||||
self, user_id: str, folder_id: str, db: Optional[AsyncSession] = None
|
||||
self, user_id: str, folder_id: str, db: AsyncSession | None = None
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -1568,8 +1571,8 @@ class ChatTable:
|
||||
self,
|
||||
user_id: str,
|
||||
folder_id: str,
|
||||
new_folder_id: Optional[str],
|
||||
db: Optional[AsyncSession] = None,
|
||||
new_folder_id: str | None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -1582,9 +1585,10 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_shared_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_shared_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
"""Delete all shared chat snapshots created by a user."""
|
||||
from open_webui.models.shared_chats import SharedChats, SharedChat as SharedChatTable
|
||||
from open_webui.models.shared_chats import SharedChat as SharedChatTable
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -1605,8 +1609,8 @@ class ChatTable:
|
||||
message_id: str,
|
||||
file_ids: list[str],
|
||||
user_id: str,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[list[ChatFileModel]]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatFileModel | None]:
|
||||
if not file_ids:
|
||||
return None
|
||||
|
||||
@@ -1645,7 +1649,7 @@ class ChatTable:
|
||||
return None
|
||||
|
||||
async def get_chat_files_by_chat_id_and_message_id(
|
||||
self, chat_id: str, message_id: str, db: Optional[AsyncSession] = None
|
||||
self, chat_id: str, message_id: str, db: AsyncSession | None = None
|
||||
) -> list[ChatFileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
@@ -1654,7 +1658,7 @@ class ChatTable:
|
||||
all_chat_files = result.scalars().all()
|
||||
return [ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files]
|
||||
|
||||
async def delete_chat_file(self, chat_id: str, file_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_chat_file(self, chat_id: str, file_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id))
|
||||
@@ -1663,7 +1667,7 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_shared_chat_ids_by_file_id(self, file_id: str, db: Optional[AsyncSession] = None) -> list[str]:
|
||||
async def get_shared_chat_ids_by_file_id(self, file_id: str, db: AsyncSession | None = None) -> list[str]:
|
||||
"""Return IDs of chats that contain this file and have an active share link."""
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
@@ -1673,7 +1677,7 @@ class ChatTable:
|
||||
)
|
||||
return [row[0] for row in result.all()]
|
||||
|
||||
async def update_chat_tasks_by_id(self, id: str, tasks: list[dict]) -> Optional[ChatModel]:
|
||||
async def update_chat_tasks_by_id(self, id: str, tasks: list[dict]) -> ChatModel | None:
|
||||
"""Update the tasks list on a chat."""
|
||||
try:
|
||||
async with get_async_db_context() as db:
|
||||
|
||||
@@ -3,13 +3,11 @@ import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.users import User, UserModel
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, Text, delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -319,8 +317,8 @@ class FeedbackTable:
|
||||
If days=0, returns all time data starting from first feedback.
|
||||
Returns: [{"date": "2026-01-08", "won": 5, "lost": 2}, ...]
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
async with get_async_db_context(db) as db:
|
||||
if days == 0:
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.utils.misc import sanitize_metadata
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from sqlalchemy import BigInteger, Column, String, Text, JSON
|
||||
from sqlalchemy import JSON, BigInteger, Column, String, Text, delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,16 +40,16 @@ class FileModel(BaseModel):
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
hash: Optional[str] = None
|
||||
hash: str | None = None
|
||||
|
||||
filename: str
|
||||
path: Optional[str] = None
|
||||
path: str | None = None
|
||||
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
data: dict | None = None
|
||||
meta: dict | None = None
|
||||
|
||||
created_at: Optional[int] # timestamp in epoch
|
||||
updated_at: Optional[int] # timestamp in epoch
|
||||
created_at: int | None # timestamp in epoch
|
||||
updated_at: int | None # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
@@ -57,9 +58,9 @@ class FileModel(BaseModel):
|
||||
|
||||
|
||||
class FileMeta(BaseModel):
|
||||
name: Optional[str] = None
|
||||
content_type: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
name: str | None = None
|
||||
content_type: str | None = None
|
||||
size: int | None = None
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
@@ -84,22 +85,22 @@ class FileMeta(BaseModel):
|
||||
class FileModelResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
hash: Optional[str] = None
|
||||
hash: str | None = None
|
||||
|
||||
filename: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[FileMeta] = None
|
||||
data: dict | None = None
|
||||
meta: FileMeta | None = None
|
||||
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files
|
||||
updated_at: int | None = None # timestamp in epoch, optional for legacy files
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class FileMetadataResponse(BaseModel):
|
||||
id: str
|
||||
hash: Optional[str] = None
|
||||
meta: Optional[dict] = None
|
||||
hash: str | None = None
|
||||
meta: dict | None = None
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
@@ -111,7 +112,7 @@ class FileListResponse(BaseModel):
|
||||
|
||||
class FileForm(BaseModel):
|
||||
id: str
|
||||
hash: Optional[str] = None
|
||||
hash: str | None = None
|
||||
filename: str
|
||||
path: str
|
||||
data: dict = {}
|
||||
@@ -119,15 +120,15 @@ class FileForm(BaseModel):
|
||||
|
||||
|
||||
class FileUpdateForm(BaseModel):
|
||||
hash: Optional[str] = None
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
hash: str | None = None
|
||||
data: dict | None = None
|
||||
meta: dict | None = None
|
||||
|
||||
|
||||
class FilesTable:
|
||||
async def insert_new_file(
|
||||
self, user_id: str, form_data: FileForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
self, user_id: str, form_data: FileForm, db: AsyncSession | None = None
|
||||
) -> FileModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
file_data = form_data.model_dump()
|
||||
|
||||
@@ -158,7 +159,7 @@ class FilesTable:
|
||||
log.exception(f'Error inserting a new file: {e}')
|
||||
return None
|
||||
|
||||
async def get_file_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[FileModel]:
|
||||
async def get_file_by_id(self, id: str, db: AsyncSession | None = None) -> FileModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
@@ -170,8 +171,8 @@ class FilesTable:
|
||||
return None
|
||||
|
||||
async def get_file_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
self, id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> FileModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id, user_id=user_id))
|
||||
@@ -183,9 +184,7 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_file_metadata_by_id(
|
||||
self, id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileMetadataResponse]:
|
||||
async def get_file_metadata_by_id(self, id: str, db: AsyncSession | None = None) -> FileMetadataResponse | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
file = await db.get(File, id)
|
||||
@@ -201,12 +200,12 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_files(self, db: Optional[AsyncSession] = None) -> list[FileModel]:
|
||||
async def get_files(self, db: AsyncSession | None = None) -> list[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(File))
|
||||
return [FileModel.model_validate(file) for file in result.scalars().all()]
|
||||
|
||||
async def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[AsyncSession] = None) -> bool:
|
||||
async def check_access_by_user_id(self, id, user_id, permission='write', db: AsyncSession | None = None) -> bool:
|
||||
file = await self.get_file_by_id(id, db=db)
|
||||
if not file:
|
||||
return False
|
||||
@@ -215,13 +214,13 @@ class FilesTable:
|
||||
# Implement additional access control logic here as needed
|
||||
return False
|
||||
|
||||
async def get_files_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FileModel]:
|
||||
async def get_files_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()))
|
||||
return [FileModel.model_validate(file) for file in result.scalars().all()]
|
||||
|
||||
async def get_file_metadatas_by_ids(
|
||||
self, ids: list[str], db: Optional[AsyncSession] = None
|
||||
self, ids: list[str], db: AsyncSession | None = None
|
||||
) -> list[FileMetadataResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
@@ -240,17 +239,17 @@ class FilesTable:
|
||||
for row in result.all()
|
||||
]
|
||||
|
||||
async def get_files_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[FileModel]:
|
||||
async def get_files_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(File).filter_by(user_id=user_id))
|
||||
return [FileModel.model_validate(file) for file in result.scalars().all()]
|
||||
|
||||
async def get_file_list(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
user_id: str | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> 'FileListResponse':
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(File)
|
||||
@@ -290,11 +289,11 @@ class FilesTable:
|
||||
|
||||
async def search_files(
|
||||
self,
|
||||
user_id: Optional[str] = None,
|
||||
user_id: str | None = None,
|
||||
filename: str = '*',
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[FileModel]:
|
||||
"""
|
||||
Search files with glob pattern matching, optional user filter, and pagination.
|
||||
@@ -323,8 +322,8 @@ class FilesTable:
|
||||
return [FileModel.model_validate(file) for file in result.scalars().all()]
|
||||
|
||||
async def update_file_by_id(
|
||||
self, id: str, form_data: FileUpdateForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
self, id: str, form_data: FileUpdateForm, db: AsyncSession | None = None
|
||||
) -> FileModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id))
|
||||
@@ -347,8 +346,8 @@ class FilesTable:
|
||||
return None
|
||||
|
||||
async def update_file_hash_by_id(
|
||||
self, id: str, hash: Optional[str], db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
self, id: str, hash: str | None, db: AsyncSession | None = None
|
||||
) -> FileModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id))
|
||||
@@ -361,9 +360,7 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_file_data_by_id(
|
||||
self, id: str, data: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
async def update_file_data_by_id(self, id: str, data: dict, db: AsyncSession | None = None) -> FileModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id))
|
||||
@@ -375,9 +372,7 @@ class FilesTable:
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
async def update_file_metadata_by_id(
|
||||
self, id: str, meta: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
async def update_file_metadata_by_id(self, id: str, meta: dict, db: AsyncSession | None = None) -> FileModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id))
|
||||
@@ -389,7 +384,7 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def delete_file_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_file_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
await db.execute(delete(File).filter_by(id=id))
|
||||
@@ -399,7 +394,7 @@ class FilesTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_all_files(self, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_all_files(self, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
await db.execute(delete(File))
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
import re
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func, select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, Text, delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.users import Users, UserModel, UserResponse
|
||||
from open_webui.models.users import UserModel, UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
|
||||
from sqlalchemy import BigInteger, Boolean, Column, Index, String, Text, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,8 +38,8 @@ class Function(Base):
|
||||
|
||||
|
||||
class FunctionMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
manifest: Optional[dict] = {}
|
||||
description: str | None = None
|
||||
manifest: dict | None = {}
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
@@ -64,7 +65,7 @@ class FunctionWithValvesModel(BaseModel):
|
||||
type: str
|
||||
content: str
|
||||
meta: FunctionMeta
|
||||
valves: Optional[dict] = None
|
||||
valves: dict | None = None
|
||||
is_active: bool = False
|
||||
is_global: bool = False
|
||||
updated_at: int # timestamp in epoch
|
||||
@@ -93,7 +94,7 @@ class FunctionResponse(BaseModel):
|
||||
|
||||
|
||||
class FunctionUserResponse(FunctionResponse):
|
||||
user: Optional[UserResponse] = None
|
||||
user: UserResponse | None = None
|
||||
|
||||
|
||||
class FunctionForm(BaseModel):
|
||||
@@ -104,7 +105,7 @@ class FunctionForm(BaseModel):
|
||||
|
||||
|
||||
class FunctionValves(BaseModel):
|
||||
valves: Optional[dict] = None
|
||||
valves: dict | None = None
|
||||
|
||||
|
||||
class FunctionsTable:
|
||||
@@ -113,8 +114,8 @@ class FunctionsTable:
|
||||
user_id: str,
|
||||
type: str,
|
||||
form_data: FunctionForm,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[FunctionModel]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> FunctionModel | None:
|
||||
function = FunctionModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
@@ -143,7 +144,7 @@ class FunctionsTable:
|
||||
self,
|
||||
user_id: str,
|
||||
functions: list[FunctionWithValvesModel],
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[FunctionWithValvesModel]:
|
||||
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
|
||||
try:
|
||||
@@ -191,7 +192,7 @@ class FunctionsTable:
|
||||
log.exception(f'Error syncing functions for user {user_id}: {e}')
|
||||
return []
|
||||
|
||||
async def get_function_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[FunctionModel]:
|
||||
async def get_function_by_id(self, id: str, db: AsyncSession | None = None) -> FunctionModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
function = await db.get(Function, id)
|
||||
@@ -199,7 +200,7 @@ class FunctionsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_functions_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FunctionModel]:
|
||||
async def get_functions_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[FunctionModel]:
|
||||
"""
|
||||
Batch fetch multiple functions by their IDs in a single query.
|
||||
Returns functions in the same order as the input IDs (None entries filtered out).
|
||||
@@ -218,7 +219,7 @@ class FunctionsTable:
|
||||
return []
|
||||
|
||||
async def get_functions(
|
||||
self, active_only=False, include_valves=False, db: Optional[AsyncSession] = None
|
||||
self, active_only=False, include_valves=False, db: AsyncSession | None = None
|
||||
) -> list[FunctionModel | FunctionWithValvesModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
if active_only:
|
||||
@@ -233,7 +234,7 @@ class FunctionsTable:
|
||||
else:
|
||||
return [FunctionModel.model_validate(function) for function in functions]
|
||||
|
||||
async def get_function_list(self, db: Optional[AsyncSession] = None) -> list[FunctionUserResponse]:
|
||||
async def get_function_list(self, db: AsyncSession | None = None) -> list[FunctionUserResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Function).order_by(Function.updated_at.desc()))
|
||||
functions = result.scalars().all()
|
||||
@@ -262,7 +263,7 @@ class FunctionsTable:
|
||||
]
|
||||
|
||||
async def get_functions_by_type(
|
||||
self, type: str, active_only=False, db: Optional[AsyncSession] = None
|
||||
self, type: str, active_only=False, db: AsyncSession | None = None
|
||||
) -> list[FunctionModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
if active_only:
|
||||
@@ -271,17 +272,17 @@ class FunctionsTable:
|
||||
result = await db.execute(select(Function).filter_by(type=type))
|
||||
return [FunctionModel.model_validate(function) for function in result.scalars().all()]
|
||||
|
||||
async def get_global_filter_functions(self, db: Optional[AsyncSession] = None) -> list[FunctionModel]:
|
||||
async def get_global_filter_functions(self, db: AsyncSession | None = None) -> list[FunctionModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Function).filter_by(type='filter', is_active=True, is_global=True))
|
||||
return [FunctionModel.model_validate(function) for function in result.scalars().all()]
|
||||
|
||||
async def get_global_action_functions(self, db: Optional[AsyncSession] = None) -> list[FunctionModel]:
|
||||
async def get_global_action_functions(self, db: AsyncSession | None = None) -> list[FunctionModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Function).filter_by(type='action', is_active=True, is_global=True))
|
||||
return [FunctionModel.model_validate(function) for function in result.scalars().all()]
|
||||
|
||||
async def get_function_valves_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[dict]:
|
||||
async def get_function_valves_by_id(self, id: str, db: AsyncSession | None = None) -> dict | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
function = await db.get(Function, id)
|
||||
@@ -290,7 +291,7 @@ class FunctionsTable:
|
||||
log.exception(f'Error getting function valves by id {id}: {e}')
|
||||
return None
|
||||
|
||||
async def get_function_valves_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> dict[str, dict]:
|
||||
async def get_function_valves_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> dict[str, dict]:
|
||||
"""
|
||||
Batch fetch valves for multiple functions in a single query.
|
||||
Returns a dict mapping function_id -> valves dict.
|
||||
@@ -308,8 +309,8 @@ class FunctionsTable:
|
||||
return {}
|
||||
|
||||
async def update_function_valves_by_id(
|
||||
self, id: str, valves: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FunctionValves]:
|
||||
self, id: str, valves: dict, db: AsyncSession | None = None
|
||||
) -> FunctionValves | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
function = await db.get(Function, id)
|
||||
@@ -322,8 +323,8 @@ class FunctionsTable:
|
||||
return None
|
||||
|
||||
async def update_function_metadata_by_id(
|
||||
self, id: str, metadata: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FunctionModel]:
|
||||
self, id: str, metadata: dict, db: AsyncSession | None = None
|
||||
) -> FunctionModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
function = await db.get(Function, id)
|
||||
@@ -345,8 +346,8 @@ class FunctionsTable:
|
||||
return None
|
||||
|
||||
async def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[dict]:
|
||||
self, id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> dict | None:
|
||||
try:
|
||||
user = await Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
@@ -363,8 +364,8 @@ class FunctionsTable:
|
||||
return None
|
||||
|
||||
async def update_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, valves: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[dict]:
|
||||
self, id: str, user_id: str, valves: dict, db: AsyncSession | None = None
|
||||
) -> dict | None:
|
||||
try:
|
||||
user = await Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
@@ -386,8 +387,8 @@ class FunctionsTable:
|
||||
return None
|
||||
|
||||
async def update_function_by_id(
|
||||
self, id: str, updated: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FunctionModel]:
|
||||
self, id: str, updated: dict, db: AsyncSession | None = None
|
||||
) -> FunctionModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
await db.execute(
|
||||
@@ -404,7 +405,7 @@ class FunctionsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def deactivate_all_functions(self, db: Optional[AsyncSession] = None) -> Optional[bool]:
|
||||
async def deactivate_all_functions(self, db: AsyncSession | None = None) -> bool | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
await db.execute(
|
||||
@@ -418,7 +419,7 @@ class FunctionsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def delete_function_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_function_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
await db.execute(delete(Function).filter_by(id=id))
|
||||
|
||||
@@ -1,25 +1,29 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update, func, and_, or_, cast, String
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.env import DEFAULT_GROUP_SHARE_PERMISSION
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.files import FileMetadataResponse
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
Text,
|
||||
JSON,
|
||||
ForeignKey,
|
||||
String,
|
||||
Text,
|
||||
and_,
|
||||
cast,
|
||||
delete,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,34 +1,36 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update, or_, func, cast
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
from open_webui.models.files import (
|
||||
File,
|
||||
FileModel,
|
||||
FileMetadataResponse,
|
||||
FileModel,
|
||||
FileModelResponse,
|
||||
)
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
|
||||
|
||||
from open_webui.models.users import User, UserModel, UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
Column,
|
||||
ForeignKey,
|
||||
String,
|
||||
Text,
|
||||
JSON,
|
||||
UniqueConstraint,
|
||||
cast,
|
||||
delete,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
from sqlalchemy import BigInteger, Column, String, Text, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# Memory DB Schema
|
||||
@@ -45,8 +46,8 @@ class MemoriesTable:
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[MemoryModel]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> MemoryModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
@@ -73,8 +74,8 @@ class MemoriesTable:
|
||||
id: str,
|
||||
user_id: str,
|
||||
content: str,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[MemoryModel]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> MemoryModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
memory = await db.get(Memory, id)
|
||||
@@ -90,7 +91,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_memories(self, db: Optional[AsyncSession] = None) -> list[MemoryModel]:
|
||||
async def get_memories(self, db: AsyncSession | None = None) -> list[MemoryModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(Memory))
|
||||
@@ -99,7 +100,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_memories_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[MemoryModel]:
|
||||
async def get_memories_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[MemoryModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(Memory).filter_by(user_id=user_id))
|
||||
@@ -108,7 +109,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_memory_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[MemoryModel]:
|
||||
async def get_memory_by_id(self, id: str, db: AsyncSession | None = None) -> MemoryModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
memory = await db.get(Memory, id)
|
||||
@@ -116,7 +117,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def delete_memory_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_memory_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
await db.execute(delete(Memory).filter_by(id=id))
|
||||
@@ -127,7 +128,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_memories_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_memories_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
await db.execute(delete(Memory).filter_by(user_id=user_id))
|
||||
@@ -137,7 +138,7 @@ class MemoriesTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
memory = await db.get(Memory, id)
|
||||
|
||||
@@ -3,17 +3,13 @@ import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.tags import TagModel, Tag, Tags
|
||||
from open_webui.models.users import Users, User, UserNameResponse
|
||||
from open_webui.models.channels import Channels, ChannelMember
|
||||
|
||||
|
||||
from open_webui.models.channels import ChannelMember, Channels
|
||||
from open_webui.models.tags import Tag, TagModel, Tags
|
||||
from open_webui.models.users import User, UserNameResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
|
||||
from sqlalchemy import or_, func, and_, text
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, and_, delete, func, or_, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import exists
|
||||
|
||||
####################
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update, or_, func, String, cast
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import User, UserModel, UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, update
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy import BigInteger, Column, Text, Boolean
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -35,14 +32,14 @@ class ModelParams(BaseModel):
|
||||
|
||||
# ModelMeta is a model for the data stored in the meta field of the Model table
|
||||
class ModelMeta(BaseModel):
|
||||
profile_image_url: Optional[str] = '/static/favicon.png'
|
||||
profile_image_url: str | None = '/static/favicon.png'
|
||||
|
||||
description: Optional[str] = None
|
||||
description: str | None = None
|
||||
"""
|
||||
User-facing description of the model.
|
||||
"""
|
||||
|
||||
capabilities: Optional[dict] = None
|
||||
capabilities: dict | None = None
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
@@ -100,7 +97,7 @@ class Model(Base):
|
||||
class ModelModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
base_model_id: Optional[str] = None
|
||||
base_model_id: str | None = None
|
||||
|
||||
name: str
|
||||
params: ModelParams
|
||||
@@ -121,11 +118,11 @@ class ModelModel(BaseModel):
|
||||
|
||||
|
||||
class ModelUserResponse(ModelModel):
|
||||
user: Optional[UserResponse] = None
|
||||
user: UserResponse | None = None
|
||||
|
||||
|
||||
class ModelAccessResponse(ModelUserResponse):
|
||||
write_access: Optional[bool] = False
|
||||
write_access: bool | None = False
|
||||
|
||||
|
||||
class ModelResponse(ModelModel):
|
||||
@@ -146,23 +143,23 @@ class ModelForm(BaseModel):
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
id: str
|
||||
base_model_id: Optional[str] = None
|
||||
base_model_id: str | None = None
|
||||
name: str
|
||||
meta: ModelMeta
|
||||
params: ModelParams
|
||||
access_grants: Optional[list[dict]] = None
|
||||
access_grants: list[dict | None] = None
|
||||
is_active: bool = True
|
||||
|
||||
|
||||
class ModelsTable:
|
||||
async def _get_access_grants(self, model_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]:
|
||||
async def _get_access_grants(self, model_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]:
|
||||
return await AccessGrants.get_grants_by_resource('model', model_id, db=db)
|
||||
|
||||
async def _to_model_model(
|
||||
self,
|
||||
model: Model,
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
access_grants: list[AccessGrantModel | None] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> ModelModel:
|
||||
model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'})
|
||||
model_data['access_grants'] = (
|
||||
@@ -171,8 +168,8 @@ class ModelsTable:
|
||||
return ModelModel.model_validate(model_data)
|
||||
|
||||
async def insert_new_model(
|
||||
self, form_data: ModelForm, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ModelModel]:
|
||||
self, form_data: ModelForm, user_id: str, db: AsyncSession | None = None
|
||||
) -> ModelModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = Model(
|
||||
@@ -196,7 +193,7 @@ class ModelsTable:
|
||||
log.exception(f'Failed to insert a new model: {e}')
|
||||
return None
|
||||
|
||||
async def get_all_models(self, db: Optional[AsyncSession] = None) -> list[ModelModel]:
|
||||
async def get_all_models(self, db: AsyncSession | None = None) -> list[ModelModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Model))
|
||||
all_models = result.scalars().all()
|
||||
@@ -207,7 +204,7 @@ class ModelsTable:
|
||||
for model in all_models
|
||||
]
|
||||
|
||||
async def get_models(self, db: Optional[AsyncSession] = None) -> list[ModelUserResponse]:
|
||||
async def get_models(self, db: AsyncSession | None = None) -> list[ModelUserResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Model).filter(Model.base_model_id != None))
|
||||
all_models = result.scalars().all()
|
||||
@@ -238,7 +235,7 @@ class ModelsTable:
|
||||
)
|
||||
return models
|
||||
|
||||
async def get_base_models(self, db: Optional[AsyncSession] = None) -> list[ModelModel]:
|
||||
async def get_base_models(self, db: AsyncSession | None = None) -> list[ModelModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Model).filter(Model.base_model_id == None))
|
||||
all_models = result.scalars().all()
|
||||
@@ -250,7 +247,7 @@ class ModelsTable:
|
||||
]
|
||||
|
||||
async def get_models_by_user_id(
|
||||
self, user_id: str, permission: str = 'write', db: Optional[AsyncSession] = None
|
||||
self, user_id: str, permission: str = 'write', db: AsyncSession | None = None
|
||||
) -> list[ModelUserResponse]:
|
||||
models = await self.get_models(db=db)
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
@@ -287,7 +284,7 @@ class ModelsTable:
|
||||
filter: dict = {},
|
||||
skip: int = 0,
|
||||
limit: int = 30,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> ModelListResponse:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Model, User).outerjoin(User, User.id == Model.user_id)
|
||||
@@ -391,7 +388,7 @@ class ModelsTable:
|
||||
|
||||
return ModelListResponse(items=models, total=total)
|
||||
|
||||
async def get_model_meta_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[tuple[dict, int]]:
|
||||
async def get_model_meta_by_id(self, id: str, db: AsyncSession | None = None) -> tuple[dict, int | None]:
|
||||
"""Return (meta, updated_at) for a model, skipping access grant resolution."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -404,7 +401,7 @@ class ModelsTable:
|
||||
self,
|
||||
user_id: str,
|
||||
is_admin: bool = False,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> set[str]:
|
||||
"""Extract unique tag names from model meta, querying only the meta column."""
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -437,7 +434,7 @@ class ModelsTable:
|
||||
|
||||
return tags_set
|
||||
|
||||
async def get_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
|
||||
async def get_model_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
model = await db.get(Model, id)
|
||||
@@ -445,7 +442,7 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_models_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[ModelModel]:
|
||||
async def get_models_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[ModelModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Model).filter(Model.id.in_(ids)))
|
||||
@@ -463,7 +460,7 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def toggle_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
|
||||
async def toggle_model_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(Model).filter_by(id=id))
|
||||
@@ -480,9 +477,7 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_model_by_id(
|
||||
self, id: str, model: ModelForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ModelModel]:
|
||||
async def update_model_by_id(self, id: str, model: ModelForm, db: AsyncSession | None = None) -> ModelModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
# update only the fields that are present in the model
|
||||
@@ -499,7 +494,7 @@ class ModelsTable:
|
||||
log.exception(f'Failed to update the model by id {id}: {e}')
|
||||
return None
|
||||
|
||||
async def update_model_updated_at_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
|
||||
async def update_model_updated_at_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Model).filter_by(id=id))
|
||||
@@ -514,7 +509,7 @@ class ModelsTable:
|
||||
log.exception(f'Failed to update the model updated_at by id {id}: {e}')
|
||||
return None
|
||||
|
||||
async def delete_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_model_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await AccessGrants.revoke_all_access('model', id, db=db)
|
||||
@@ -525,7 +520,7 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_all_models(self, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_all_models(self, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Model.id))
|
||||
@@ -540,7 +535,7 @@ class ModelsTable:
|
||||
return False
|
||||
|
||||
async def sync_models(
|
||||
self, user_id: str, models: list[ModelModel], db: Optional[AsyncSession] = None
|
||||
self, user_id: str, models: list[ModelModel], db: AsyncSession | None = None
|
||||
) -> list[ModelModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
|
||||
@@ -1,19 +1,16 @@
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Boolean, select, delete, update, or_, func, cast
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import User, UserModel, Users, UserResponse
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import User, UserModel, UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, ForeignKey
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, ForeignKey, Text, cast, delete, func, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# Note DB Schema
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
import time
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Optional, List
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from sqlalchemy import select, delete, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
|
||||
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, Index
|
||||
from sqlalchemy import BigInteger, Column, Index, String, Text, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,18 +1,16 @@
|
||||
"""Prompt history model for version tracking."""
|
||||
|
||||
import difflib
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
import json
|
||||
import difflib
|
||||
|
||||
from sqlalchemy import select, delete, func
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
|
||||
from open_webui.models.users import UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, Text, JSON, Index
|
||||
from sqlalchemy import JSON, BigInteger, Column, Index, Text, delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# PromptHistory DB Schema
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update, or_, func, text, cast, String
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import Users, User, UserModel, UserResponse
|
||||
from open_webui.models.prompt_history import PromptHistories
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.prompt_history import PromptHistories
|
||||
from open_webui.models.users import User, UserModel, UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import BigInteger, Boolean, Column, Text, JSON
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, text, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# Prompts DB Schema
|
||||
@@ -40,18 +39,18 @@ class Prompt(Base):
|
||||
|
||||
|
||||
class PromptModel(BaseModel):
|
||||
id: Optional[str] = None
|
||||
id: str | None = None
|
||||
command: str
|
||||
user_id: str
|
||||
name: str
|
||||
content: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
tags: Optional[list[str]] = None
|
||||
is_active: Optional[bool] = True
|
||||
version_id: Optional[str] = None
|
||||
created_at: Optional[int] = None
|
||||
updated_at: Optional[int] = None
|
||||
data: dict | None = None
|
||||
meta: dict | None = None
|
||||
tags: list[str | None] = None
|
||||
is_active: bool | None = True
|
||||
version_id: str | None = None
|
||||
created_at: int | None = None
|
||||
updated_at: int | None = None
|
||||
access_grants: list[AccessGrantModel] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
@@ -63,11 +62,11 @@ class PromptModel(BaseModel):
|
||||
|
||||
|
||||
class PromptUserResponse(PromptModel):
|
||||
user: Optional[UserResponse] = None
|
||||
user: UserResponse | None = None
|
||||
|
||||
|
||||
class PromptAccessResponse(PromptUserResponse):
|
||||
write_access: Optional[bool] = False
|
||||
write_access: bool | None = False
|
||||
|
||||
|
||||
class PromptListResponse(BaseModel):
|
||||
@@ -84,24 +83,24 @@ class PromptForm(BaseModel):
|
||||
command: str
|
||||
name: str # Changed from title
|
||||
content: str
|
||||
data: Optional[dict] = None
|
||||
meta: Optional[dict] = None
|
||||
tags: Optional[list[str]] = None
|
||||
access_grants: Optional[list[dict]] = None
|
||||
version_id: Optional[str] = None # Active version
|
||||
commit_message: Optional[str] = None # For history tracking
|
||||
is_production: Optional[bool] = True # Whether to set new version as production
|
||||
data: dict | None = None
|
||||
meta: dict | None = None
|
||||
tags: list[str | None] = None
|
||||
access_grants: list[dict | None] = None
|
||||
version_id: str | None = None # Active version
|
||||
commit_message: str | None = None # For history tracking
|
||||
is_production: bool | None = True # Whether to set new version as production
|
||||
|
||||
|
||||
class PromptsTable:
|
||||
async def _get_access_grants(self, prompt_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]:
|
||||
async def _get_access_grants(self, prompt_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]:
|
||||
return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db)
|
||||
|
||||
async def _to_prompt_model(
|
||||
self,
|
||||
prompt: Prompt,
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
access_grants: list[AccessGrantModel | None] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptModel:
|
||||
prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'})
|
||||
prompt_data['access_grants'] = (
|
||||
@@ -110,8 +109,8 @@ class PromptsTable:
|
||||
return PromptModel.model_validate(prompt_data)
|
||||
|
||||
async def insert_new_prompt(
|
||||
self, user_id: str, form_data: PromptForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[PromptModel]:
|
||||
self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None
|
||||
) -> PromptModel | None:
|
||||
now = int(time.time())
|
||||
prompt_id = str(uuid.uuid4())
|
||||
|
||||
@@ -171,7 +170,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_prompt_by_id(self, prompt_id: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]:
|
||||
async def get_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None:
|
||||
"""Get prompt by UUID."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -183,7 +182,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_prompt_by_command(self, command: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]:
|
||||
async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(command=command))
|
||||
@@ -194,7 +193,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_prompts(self, db: Optional[AsyncSession] = None) -> list[PromptUserResponse]:
|
||||
async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc())
|
||||
@@ -229,7 +228,7 @@ class PromptsTable:
|
||||
return prompts
|
||||
|
||||
async def get_prompts_by_user_id(
|
||||
self, user_id: str, permission: str = 'write', db: Optional[AsyncSession] = None
|
||||
self, user_id: str, permission: str = 'write', db: AsyncSession | None = None
|
||||
) -> list[PromptUserResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
@@ -283,7 +282,7 @@ class PromptsTable:
|
||||
filter: dict = {},
|
||||
skip: int = 0,
|
||||
limit: int = 30,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptListResponse:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Join with User table for user filtering and sorting
|
||||
@@ -404,8 +403,8 @@ class PromptsTable:
|
||||
command: str,
|
||||
form_data: PromptForm,
|
||||
user_id: str,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[PromptModel]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(command=command))
|
||||
@@ -470,8 +469,8 @@ class PromptsTable:
|
||||
prompt_id: str,
|
||||
form_data: PromptForm,
|
||||
user_id: str,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[PromptModel]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
@@ -545,9 +544,9 @@ class PromptsTable:
|
||||
prompt_id: str,
|
||||
name: str,
|
||||
command: str,
|
||||
tags: Optional[list[str]] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[PromptModel]:
|
||||
tags: list[str | None] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
"""Update only name, command, and tags (no history created)."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -573,8 +572,8 @@ class PromptsTable:
|
||||
self,
|
||||
prompt_id: str,
|
||||
version_id: str,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[PromptModel]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
"""Set the active version of a prompt and restore content from that version's snapshot."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -606,7 +605,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def toggle_prompt_active(self, prompt_id: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]:
|
||||
async def toggle_prompt_active(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None:
|
||||
"""Toggle the is_active flag on a prompt."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -622,7 +621,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def delete_prompt_by_command(self, command: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> bool:
|
||||
"""Permanently delete a prompt and its history."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -639,7 +638,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_prompt_by_id(self, prompt_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> bool:
|
||||
"""Permanently delete a prompt and its history."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -656,7 +655,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_tags(self, db: Optional[AsyncSession] = None) -> list[str]:
|
||||
async def get_tags(self, db: AsyncSession | None = None) -> list[str]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt.tags).filter(Prompt.is_active == True))
|
||||
@@ -670,7 +669,7 @@ class PromptsTable:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def get_tags_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[str]:
|
||||
async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[str]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
||||
@@ -3,12 +3,10 @@ import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, ForeignKey, Text, JSON
|
||||
from sqlalchemy import JSON, BigInteger, Column, ForeignKey, Text, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,15 +2,13 @@ import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update, or_
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, get_async_db_context
|
||||
from open_webui.models.users import Users, User, UserModel, UserResponse
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import User, UserModel, UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, func
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, delete, func, or_, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,15 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index
|
||||
from sqlalchemy import JSON, BigInteger, Column, Index, PrimaryKeyConstraint, String, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,7 +38,7 @@ class TagModel(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
user_id: str
|
||||
meta: Optional[dict] = None
|
||||
meta: dict | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
@@ -54,7 +53,7 @@ class TagChatIdForm(BaseModel):
|
||||
|
||||
|
||||
class TagTable:
|
||||
async def insert_new_tag(self, name: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[TagModel]:
|
||||
async def insert_new_tag(self, name: str, user_id: str, db: AsyncSession | None = None) -> TagModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
id = name.replace(' ', '_').lower()
|
||||
tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
|
||||
@@ -72,8 +71,8 @@ class TagTable:
|
||||
return None
|
||||
|
||||
async def get_tag_by_name_and_user_id(
|
||||
self, name: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[TagModel]:
|
||||
self, name: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> TagModel | None:
|
||||
try:
|
||||
id = name.replace(' ', '_').lower()
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -83,19 +82,19 @@ class TagTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_tags_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[TagModel]:
|
||||
async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[TagModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Tag).filter_by(user_id=user_id))
|
||||
return [TagModel.model_validate(tag) for tag in result.scalars().all()]
|
||||
|
||||
async def get_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None
|
||||
self, ids: list[str], user_id: str, db: AsyncSession | None = None
|
||||
) -> list[TagModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id))
|
||||
return [TagModel.model_validate(tag) for tag in result.scalars().all()]
|
||||
|
||||
async def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
id = name.replace(' ', '_').lower()
|
||||
@@ -108,7 +107,7 @@ class TagTable:
|
||||
return False
|
||||
|
||||
async def delete_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None
|
||||
self, ids: list[str], user_id: str, db: AsyncSession | None = None
|
||||
) -> bool:
|
||||
"""Delete all tags whose id is in *ids* for the given user, in one query."""
|
||||
if not ids:
|
||||
@@ -122,7 +121,7 @@ class TagTable:
|
||||
log.error(f'delete_tags_by_ids: {e}')
|
||||
return False
|
||||
|
||||
async def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[AsyncSession] = None) -> None:
|
||||
async def ensure_tags_exist(self, names: list[str], user_id: str, db: AsyncSession | None = None) -> None:
|
||||
"""Create tag rows for any *names* that don't already exist for *user_id*."""
|
||||
if not names:
|
||||
return
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.users import Users, UserResponse
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.users import UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import BigInteger, Column, String, Text
|
||||
from sqlalchemy import BigInteger, Column, String, Text, delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -37,8 +37,8 @@ class Tool(Base):
|
||||
|
||||
|
||||
class ToolMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
manifest: Optional[dict] = {}
|
||||
description: str | None = None
|
||||
manifest: dict | None = {}
|
||||
|
||||
|
||||
class ToolModel(BaseModel):
|
||||
@@ -62,7 +62,7 @@ class ToolModel(BaseModel):
|
||||
|
||||
|
||||
class ToolUserModel(ToolModel):
|
||||
user: Optional[UserResponse] = None
|
||||
user: UserResponse | None = None
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
@@ -76,13 +76,13 @@ class ToolResponse(BaseModel):
|
||||
|
||||
|
||||
class ToolUserResponse(ToolResponse):
|
||||
user: Optional[UserResponse] = None
|
||||
user: UserResponse | None = None
|
||||
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class ToolAccessResponse(ToolUserResponse):
|
||||
write_access: Optional[bool] = False
|
||||
write_access: bool | None = False
|
||||
|
||||
|
||||
class ToolForm(BaseModel):
|
||||
@@ -90,22 +90,22 @@ class ToolForm(BaseModel):
|
||||
name: str
|
||||
content: str
|
||||
meta: ToolMeta
|
||||
access_grants: Optional[list[dict]] = None
|
||||
access_grants: list[dict | None] = None
|
||||
|
||||
|
||||
class ToolValves(BaseModel):
|
||||
valves: Optional[dict] = None
|
||||
valves: dict | None = None
|
||||
|
||||
|
||||
class ToolsTable:
|
||||
async def _get_access_grants(self, tool_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]:
|
||||
async def _get_access_grants(self, tool_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]:
|
||||
return await AccessGrants.get_grants_by_resource('tool', tool_id, db=db)
|
||||
|
||||
async def _to_tool_model(
|
||||
self,
|
||||
tool: Tool,
|
||||
access_grants: Optional[list[AccessGrantModel]] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
access_grants: list[AccessGrantModel | None] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> ToolModel:
|
||||
tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'})
|
||||
tool_data['access_grants'] = (
|
||||
@@ -118,8 +118,8 @@ class ToolsTable:
|
||||
user_id: str,
|
||||
form_data: ToolForm,
|
||||
specs: list[dict],
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[ToolModel]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> ToolModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = Tool(
|
||||
@@ -143,7 +143,7 @@ class ToolsTable:
|
||||
log.exception(f'Error creating a new tool: {e}')
|
||||
return None
|
||||
|
||||
async def get_tool_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ToolModel]:
|
||||
async def get_tool_by_id(self, id: str, db: AsyncSession | None = None) -> ToolModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
tool = await db.get(Tool, id)
|
||||
@@ -151,7 +151,7 @@ class ToolsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_tools(self, defer_content: bool = False, db: Optional[AsyncSession] = None) -> list[ToolUserModel]:
|
||||
async def get_tools(self, defer_content: bool = False, db: AsyncSession | None = None) -> list[ToolUserModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Tool).order_by(Tool.updated_at.desc())
|
||||
if defer_content:
|
||||
@@ -190,7 +190,7 @@ class ToolsTable:
|
||||
user_id: str,
|
||||
permission: str = 'write',
|
||||
defer_content: bool = False,
|
||||
db: Optional[AsyncSession] = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ToolUserModel]:
|
||||
tools = await self.get_tools(defer_content=defer_content, db=db)
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
@@ -211,7 +211,7 @@ class ToolsTable:
|
||||
result.append(tool)
|
||||
return result
|
||||
|
||||
async def get_tool_valves_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[dict]:
|
||||
async def get_tool_valves_by_id(self, id: str, db: AsyncSession | None = None) -> dict | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
tool = await db.get(Tool, id)
|
||||
@@ -221,8 +221,8 @@ class ToolsTable:
|
||||
return None
|
||||
|
||||
async def update_tool_valves_by_id(
|
||||
self, id: str, valves: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ToolValves]:
|
||||
self, id: str, valves: dict, db: AsyncSession | None = None
|
||||
) -> ToolValves | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(Tool).filter_by(id=id).values(valves=valves, updated_at=int(time.time())))
|
||||
@@ -232,8 +232,8 @@ class ToolsTable:
|
||||
return None
|
||||
|
||||
async def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[dict]:
|
||||
self, id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> dict | None:
|
||||
try:
|
||||
user = await Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
@@ -250,8 +250,8 @@ class ToolsTable:
|
||||
return None
|
||||
|
||||
async def update_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, valves: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[dict]:
|
||||
self, id: str, user_id: str, valves: dict, db: AsyncSession | None = None
|
||||
) -> dict | None:
|
||||
try:
|
||||
user = await Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
@@ -272,7 +272,7 @@ class ToolsTable:
|
||||
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
|
||||
return None
|
||||
|
||||
async def update_tool_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[ToolModel]:
|
||||
async def update_tool_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> ToolModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
access_grants = updated.pop('access_grants', None)
|
||||
@@ -287,7 +287,7 @@ class ToolsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def delete_tool_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_tool_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await AccessGrants.revoke_all_access('tool', id, db=db)
|
||||
|
||||
@@ -1,29 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select, delete, update, func, or_, case, exists
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
|
||||
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.utils.misc import throttle
|
||||
from open_webui.utils.validate import validate_profile_image_url
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
JSON,
|
||||
Column,
|
||||
String,
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Text,
|
||||
Column,
|
||||
Date,
|
||||
String,
|
||||
Text,
|
||||
case,
|
||||
cast,
|
||||
delete,
|
||||
exists,
|
||||
func,
|
||||
or_,
|
||||
select,
|
||||
update,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
|
||||
import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# User DB Schema
|
||||
@@ -33,7 +37,7 @@ import datetime
|
||||
|
||||
|
||||
class UserSettings(BaseModel):
|
||||
ui: Optional[dict] = {}
|
||||
ui: dict | None = {}
|
||||
model_config = ConfigDict(extra='allow')
|
||||
pass
|
||||
|
||||
@@ -76,29 +80,29 @@ class UserModel(BaseModel):
|
||||
id: str
|
||||
|
||||
email: str
|
||||
username: Optional[str] = None
|
||||
username: str | None = None
|
||||
role: str = 'pending'
|
||||
|
||||
name: str
|
||||
|
||||
profile_image_url: Optional[str] = None
|
||||
profile_banner_image_url: Optional[str] = None
|
||||
profile_image_url: str | None = None
|
||||
profile_banner_image_url: str | None = None
|
||||
|
||||
bio: Optional[str] = None
|
||||
gender: Optional[str] = None
|
||||
date_of_birth: Optional[datetime.date] = None
|
||||
timezone: Optional[str] = None
|
||||
bio: str | None = None
|
||||
gender: str | None = None
|
||||
date_of_birth: datetime.date | None = None
|
||||
timezone: str | None = None
|
||||
|
||||
presence_state: Optional[str] = None
|
||||
status_emoji: Optional[str] = None
|
||||
status_message: Optional[str] = None
|
||||
status_expires_at: Optional[int] = None
|
||||
presence_state: str | None = None
|
||||
status_emoji: str | None = None
|
||||
status_message: str | None = None
|
||||
status_expires_at: int | None = None
|
||||
|
||||
info: Optional[dict] = None
|
||||
settings: Optional[UserSettings] = None
|
||||
info: dict | None = None
|
||||
settings: UserSettings | None = None
|
||||
|
||||
oauth: Optional[dict] = None
|
||||
scim: Optional[dict] = None
|
||||
oauth: dict | None = None
|
||||
scim: dict | None = None
|
||||
|
||||
last_active_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
@@ -136,9 +140,9 @@ class ApiKeyModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
key: str
|
||||
data: Optional[dict] = None
|
||||
expires_at: Optional[int] = None
|
||||
last_used_at: Optional[int] = None
|
||||
data: dict | None = None
|
||||
expires_at: int | None = None
|
||||
last_used_at: int | None = None
|
||||
created_at: int # timestamp in epoch
|
||||
updated_at: int # timestamp in epoch
|
||||
|
||||
@@ -153,9 +157,9 @@ class ApiKeyModel(BaseModel):
|
||||
class UpdateProfileForm(BaseModel):
|
||||
profile_image_url: str
|
||||
name: str
|
||||
bio: Optional[str] = None
|
||||
gender: Optional[str] = None
|
||||
date_of_birth: Optional[datetime.date] = None
|
||||
bio: str | None = None
|
||||
gender: str | None = None
|
||||
date_of_birth: datetime.date | None = None
|
||||
|
||||
@field_validator('profile_image_url')
|
||||
@classmethod
|
||||
@@ -182,9 +186,9 @@ class UserGroupIdsListResponse(BaseModel):
|
||||
|
||||
|
||||
class UserStatus(BaseModel):
|
||||
status_emoji: Optional[str] = None
|
||||
status_message: Optional[str] = None
|
||||
status_expires_at: Optional[int] = None
|
||||
status_emoji: str | None = None
|
||||
status_message: str | None = None
|
||||
status_expires_at: int | None = None
|
||||
|
||||
|
||||
class UserInfoResponse(UserStatus):
|
||||
@@ -192,8 +196,8 @@ class UserInfoResponse(UserStatus):
|
||||
name: str
|
||||
email: str
|
||||
role: str
|
||||
bio: Optional[str] = None
|
||||
groups: Optional[list] = []
|
||||
bio: str | None = None
|
||||
groups: list | None = []
|
||||
is_active: bool = False
|
||||
|
||||
|
||||
@@ -205,7 +209,7 @@ class UserIdNameResponse(BaseModel):
|
||||
class UserIdNameStatusResponse(UserStatus):
|
||||
id: str
|
||||
name: str
|
||||
is_active: Optional[bool] = None
|
||||
is_active: bool | None = None
|
||||
|
||||
|
||||
class UserInfoListResponse(BaseModel):
|
||||
@@ -239,15 +243,15 @@ class UserRoleUpdateForm(BaseModel):
|
||||
|
||||
|
||||
class UserUpdateForm(BaseModel):
|
||||
role: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
profile_image_url: Optional[str] = None
|
||||
password: Optional[str] = None
|
||||
role: str | None = None
|
||||
name: str | None = None
|
||||
email: str | None = None
|
||||
profile_image_url: str | None = None
|
||||
password: str | None = None
|
||||
|
||||
@field_validator('profile_image_url', mode='before')
|
||||
@classmethod
|
||||
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
def check_profile_image_url(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
return validate_profile_image_url(v)
|
||||
@@ -261,10 +265,10 @@ class UsersTable:
|
||||
email: str,
|
||||
profile_image_url: str = '/user.png',
|
||||
role: str = 'pending',
|
||||
username: Optional[str] = None,
|
||||
oauth: Optional[dict] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[UserModel]:
|
||||
username: str | None = None,
|
||||
oauth: dict | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
user = UserModel(
|
||||
**{
|
||||
@@ -289,7 +293,7 @@ class UsersTable:
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_user_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def get_user_by_id(self, id: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -298,7 +302,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_by_api_key(self, api_key: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def get_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
@@ -309,7 +313,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def get_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter(func.lower(User.email) == email.lower()))
|
||||
@@ -318,9 +322,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_by_oauth_sub(
|
||||
self, provider: str, sub: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
async def get_user_by_oauth_sub(self, provider: str, sub: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
dialect_name = db.bind.dialect.name
|
||||
@@ -339,8 +341,8 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
async def get_user_by_scim_external_id(
|
||||
self, provider: str, external_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
self, provider: str, external_id: str, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
dialect_name = db.bind.dialect.name
|
||||
@@ -359,15 +361,15 @@ class UsersTable:
|
||||
|
||||
async def get_users(
|
||||
self,
|
||||
filter: Optional[dict] = None,
|
||||
skip: Optional[int] = None,
|
||||
limit: Optional[int] = None,
|
||||
db: Optional[AsyncSession] = None,
|
||||
filter: dict | None = None,
|
||||
skip: int | None = None,
|
||||
limit: int | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> dict:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Import here to avoid circular imports
|
||||
from open_webui.models.groups import GroupMember
|
||||
from open_webui.models.channels import ChannelMember
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
# Join GroupMember so we can order by group_id when requested
|
||||
stmt = select(User)
|
||||
@@ -501,7 +503,7 @@ class UsersTable:
|
||||
'total': total,
|
||||
}
|
||||
|
||||
async def get_users_by_group_id(self, group_id: str, db: Optional[AsyncSession] = None) -> list[UserModel]:
|
||||
async def get_users_by_group_id(self, group_id: str, db: AsyncSession | None = None) -> list[UserModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
@@ -511,25 +513,23 @@ class UsersTable:
|
||||
users = result.scalars().all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
async def get_users_by_user_ids(
|
||||
self, user_ids: list[str], db: Optional[AsyncSession] = None
|
||||
) -> list[UserStatusModel]:
|
||||
async def get_users_by_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[UserStatusModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
users = result.scalars().all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
async def get_num_users(self, db: Optional[AsyncSession] = None) -> Optional[int]:
|
||||
async def get_num_users(self, db: AsyncSession | None = None) -> int | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(func.count()).select_from(User))
|
||||
return result.scalar()
|
||||
|
||||
async def has_users(self, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def has_users(self, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(exists(select(User))))
|
||||
return result.scalar()
|
||||
|
||||
async def get_first_user(self, db: Optional[AsyncSession] = None) -> UserModel:
|
||||
async def get_first_user(self, db: AsyncSession | None = None) -> UserModel:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).order_by(User.created_at).limit(1))
|
||||
@@ -538,7 +538,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_webhook_url_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[str]:
|
||||
async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -551,7 +551,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_num_users_active_today(self, db: Optional[AsyncSession] = None) -> Optional[int]:
|
||||
async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
||||
@@ -560,9 +560,7 @@ class UsersTable:
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
async def update_user_role_by_id(
|
||||
self, id: str, role: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -577,8 +575,8 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
async def update_user_status_by_id(
|
||||
self, id: str, form_data: UserStatus, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
self, id: str, form_data: UserStatus, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -594,8 +592,8 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
async def update_user_profile_image_url_by_id(
|
||||
self, id: str, profile_image_url: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
self, id: str, profile_image_url: str, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -610,7 +608,7 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||
async def update_last_active_by_id(self, id: str, db: Optional[AsyncSession] = None) -> None:
|
||||
async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time())))
|
||||
@@ -619,8 +617,8 @@ class UsersTable:
|
||||
pass
|
||||
|
||||
async def update_user_oauth_by_id(
|
||||
self, id: str, provider: str, sub: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
self, id: str, provider: str, sub: str, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
"""
|
||||
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
|
||||
Example resulting structure:
|
||||
@@ -656,8 +654,8 @@ class UsersTable:
|
||||
id: str,
|
||||
provider: str,
|
||||
external_id: str,
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> Optional[UserModel]:
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""
|
||||
Update or insert a SCIM provider/external_id pair into the user's scim JSON field.
|
||||
Example resulting structure:
|
||||
@@ -684,7 +682,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_user_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -701,8 +699,8 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
async def update_user_settings_by_id(
|
||||
self, id: str, updated: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
self, id: str, updated: dict, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -726,10 +724,10 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def delete_user_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_user_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
# Remove User from Groups
|
||||
await Groups.remove_user_from_all_groups(id)
|
||||
@@ -748,7 +746,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_user_api_key_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[str]:
|
||||
async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(ApiKey).filter_by(user_id=id))
|
||||
@@ -757,7 +755,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(delete(ApiKey).filter_by(user_id=id))
|
||||
@@ -779,7 +777,7 @@ class UsersTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_user_api_key_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(delete(ApiKey).filter_by(user_id=id))
|
||||
@@ -788,13 +786,13 @@ class UsersTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_valid_user_ids(self, user_ids: list[str], db: Optional[AsyncSession] = None) -> list[str]:
|
||||
async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
users = result.scalars().all()
|
||||
return [user.id for user in users]
|
||||
|
||||
async def get_super_admin_user(self, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(role='admin').limit(1))
|
||||
user = result.scalars().first()
|
||||
@@ -803,7 +801,7 @@ class UsersTable:
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_active_user_count(self, db: Optional[AsyncSession] = None) -> int:
|
||||
async def get_active_user_count(self, db: AsyncSession | None = None) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
@@ -820,7 +818,7 @@ class UsersTable:
|
||||
return user.last_active_at >= three_minutes_ago
|
||||
return False
|
||||
|
||||
async def is_user_active(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=user_id))
|
||||
user = result.scalars().first()
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException, status
|
||||
from langchain_core.documents import Document
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import requests
|
||||
import logging, os
|
||||
import logging
|
||||
import os
|
||||
from typing import Iterator, List, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
from langchain_core.document_loaders import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.utils.headers import include_user_info_headers
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import requests
|
||||
import logging
|
||||
from typing import Iterator, List, Union
|
||||
|
||||
import requests
|
||||
from langchain_core.document_loaders import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import requests
|
||||
import logging
|
||||
import ftfy
|
||||
import sys
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import ftfy
|
||||
import requests
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from langchain_community.document_loaders import (
|
||||
AzureAIDocumentIntelligenceLoader,
|
||||
@@ -17,16 +17,13 @@ from langchain_community.document_loaders import (
|
||||
YoutubeLoader,
|
||||
)
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
|
||||
|
||||
from open_webui.retrieval.loaders.mistral import MistralLoader
|
||||
from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, GLOBAL_LOG_LEVEL, REQUESTS_VERIFY
|
||||
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
|
||||
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
|
||||
from open_webui.retrieval.loaders.mineru import MinerULoader
|
||||
from open_webui.retrieval.loaders.mistral import MistralLoader
|
||||
from open_webui.retrieval.loaders.paddleocr_vl import PaddleOCRVLLoader
|
||||
|
||||
from open_webui.env import GLOBAL_LOG_LEVEL, REQUESTS_VERIFY, AIOHTTP_CLIENT_SESSION_SSL
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
import zipfile
|
||||
from typing import List, Optional
|
||||
from langchain_core.documents import Document
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException, status
|
||||
from langchain_core.documents import Document
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import requests
|
||||
import aiohttp
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from typing import List, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.env import GLOBAL_LOG_LEVEL, AIOHTTP_CLIENT_SESSION_SSL
|
||||
from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, GLOBAL_LOG_LEVEL
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import base64
|
||||
import os
|
||||
import requests
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from langchain_core.documents import Document
|
||||
from open_webui.env import GLOBAL_LOG_LEVEL
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import requests
|
||||
import logging
|
||||
from typing import Iterator, List, Literal, Union
|
||||
|
||||
import requests
|
||||
from langchain_core.document_loaders import BaseLoader
|
||||
from langchain_core.documents import Document
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from xml.etree.ElementTree import ParseError
|
||||
|
||||
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from xml.etree.ElementTree import ParseError
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
|
||||
class BaseReranker(ABC):
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
import logging
|
||||
import torch
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from colbert.infra import ColBERTConfig
|
||||
from colbert.modeling.checkpoint import Checkpoint
|
||||
|
||||
|
||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import logging
|
||||
import requests
|
||||
from typing import Optional, List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
|
||||
import requests
|
||||
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, REQUESTS_VERIFY
|
||||
from open_webui.retrieval.models.base_reranker import BaseReranker
|
||||
from open_webui.utils.headers import include_user_info_headers
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
@@ -516,8 +518,11 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
|
||||
|
||||
|
||||
async def query_collection(
|
||||
request, collection_names: list[str], queries: list[str],
|
||||
embedding_function, k: int,
|
||||
request,
|
||||
collection_names: list[str],
|
||||
queries: list[str],
|
||||
embedding_function,
|
||||
k: int,
|
||||
) -> dict:
|
||||
# When request is provided, try hybrid search + reranking if enabled
|
||||
if request and request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
|
||||
@@ -1109,7 +1114,7 @@ async def get_sources_from_items(
|
||||
hybrid_bm25_weight,
|
||||
hybrid_search,
|
||||
full_context=False,
|
||||
user: Optional[UserModel] = None,
|
||||
user: UserModel | None = None,
|
||||
):
|
||||
log.debug(f'items: {items} {queries} {embedding_function} {reranking_function} {full_context}')
|
||||
|
||||
@@ -1421,7 +1426,7 @@ class RerankCompressor(BaseDocumentCompressor):
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
callbacks: Callbacks | None = None,
|
||||
) -> Sequence[Document]:
|
||||
"""Compress retrieved documents given the query context.
|
||||
|
||||
@@ -1440,7 +1445,7 @@ class RerankCompressor(BaseDocumentCompressor):
|
||||
self,
|
||||
documents: Sequence[Document],
|
||||
query: str,
|
||||
callbacks: Optional[Callbacks] = None,
|
||||
callbacks: Callbacks | None = None,
|
||||
) -> Sequence[Document]:
|
||||
reranking = self.reranking_function is not None
|
||||
|
||||
|
||||
@@ -1,29 +1,27 @@
|
||||
import chromadb
|
||||
import logging
|
||||
from chromadb import Settings
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
|
||||
import chromadb
|
||||
from chromadb import Settings
|
||||
from chromadb.utils.batch_utils import create_batches
|
||||
from open_webui.config import (
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
||||
CHROMA_CLIENT_AUTH_PROVIDER,
|
||||
CHROMA_DATA_PATH,
|
||||
CHROMA_DATABASE,
|
||||
CHROMA_HTTP_HEADERS,
|
||||
CHROMA_HTTP_HOST,
|
||||
CHROMA_HTTP_PORT,
|
||||
CHROMA_HTTP_HEADERS,
|
||||
CHROMA_HTTP_SSL,
|
||||
CHROMA_TENANT,
|
||||
CHROMA_DATABASE,
|
||||
CHROMA_CLIENT_AUTH_PROVIDER,
|
||||
CHROMA_CLIENT_AUTH_CREDENTIALS,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,28 +2,28 @@
|
||||
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
|
||||
"""
|
||||
|
||||
from elasticsearch import Elasticsearch, BadRequestError
|
||||
from typing import Optional
|
||||
import ssl
|
||||
from elasticsearch.helpers import bulk, scan
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from elasticsearch import BadRequestError, Elasticsearch
|
||||
from elasticsearch.helpers import bulk, scan
|
||||
from open_webui.config import (
|
||||
ELASTICSEARCH_URL,
|
||||
ELASTICSEARCH_CA_CERTS,
|
||||
ELASTICSEARCH_API_KEY,
|
||||
ELASTICSEARCH_USERNAME,
|
||||
ELASTICSEARCH_PASSWORD,
|
||||
ELASTICSEARCH_CA_CERTS,
|
||||
ELASTICSEARCH_CLOUD_ID,
|
||||
ELASTICSEARCH_INDEX_PREFIX,
|
||||
ELASTICSEARCH_PASSWORD,
|
||||
ELASTICSEARCH_URL,
|
||||
ELASTICSEARCH_USERNAME,
|
||||
SSL_ASSERT_FINGERPRINT,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
|
||||
|
||||
class ElasticsearchClient(VectorDBBase):
|
||||
|
||||
@@ -13,18 +13,15 @@ import sys
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import NullPool, QueuePool
|
||||
|
||||
from open_webui.config import (
|
||||
MARIADB_VECTOR_DB_URL,
|
||||
MARIADB_VECTOR_DISTANCE_STRATEGY,
|
||||
MARIADB_VECTOR_INDEX_M,
|
||||
MARIADB_VECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
|
||||
MARIADB_VECTOR_POOL_SIZE,
|
||||
MARIADB_VECTOR_POOL_MAX_OVERFLOW,
|
||||
MARIADB_VECTOR_POOL_TIMEOUT,
|
||||
MARIADB_VECTOR_POOL_RECYCLE,
|
||||
MARIADB_VECTOR_POOL_SIZE,
|
||||
MARIADB_VECTOR_POOL_TIMEOUT,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
@@ -33,6 +30,8 @@ from open_webui.retrieval.vector.main import (
|
||||
VectorItem,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import NullPool, QueuePool
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -2,33 +2,31 @@
|
||||
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
|
||||
"""
|
||||
|
||||
from pymilvus import MilvusClient as Client
|
||||
from pymilvus import FieldSchema, DataType
|
||||
from pymilvus import connections, Collection
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
MILVUS_DB,
|
||||
MILVUS_TOKEN,
|
||||
MILVUS_INDEX_TYPE,
|
||||
MILVUS_METRIC_TYPE,
|
||||
MILVUS_HNSW_M,
|
||||
MILVUS_HNSW_EFCONSTRUCTION,
|
||||
MILVUS_IVF_FLAT_NLIST,
|
||||
MILVUS_DISKANN_MAX_DEGREE,
|
||||
MILVUS_DISKANN_SEARCH_LIST_SIZE,
|
||||
MILVUS_HNSW_EFCONSTRUCTION,
|
||||
MILVUS_HNSW_M,
|
||||
MILVUS_INDEX_TYPE,
|
||||
MILVUS_IVF_FLAT_NLIST,
|
||||
MILVUS_METRIC_TYPE,
|
||||
MILVUS_TOKEN,
|
||||
MILVUS_URI,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from pymilvus import Collection, DataType, FieldSchema, connections
|
||||
from pymilvus import MilvusClient as Client
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,18 +3,18 @@ NOTE: This vector database integration is community-supported and maintained on
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from open_webui.config import (
|
||||
MILVUS_URI,
|
||||
MILVUS_TOKEN,
|
||||
MILVUS_DB,
|
||||
MILVUS_COLLECTION_PREFIX,
|
||||
MILVUS_INDEX_TYPE,
|
||||
MILVUS_METRIC_TYPE,
|
||||
MILVUS_HNSW_M,
|
||||
MILVUS_DB,
|
||||
MILVUS_HNSW_EFCONSTRUCTION,
|
||||
MILVUS_HNSW_M,
|
||||
MILVUS_INDEX_TYPE,
|
||||
MILVUS_IVF_FLAT_NLIST,
|
||||
MILVUS_METRIC_TYPE,
|
||||
MILVUS_TOKEN,
|
||||
MILVUS_URI,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
@@ -23,12 +23,12 @@ from open_webui.retrieval.vector.main import (
|
||||
VectorItem,
|
||||
)
|
||||
from pymilvus import (
|
||||
connections,
|
||||
utility,
|
||||
Collection,
|
||||
CollectionSchema,
|
||||
FieldSchema,
|
||||
DataType,
|
||||
FieldSchema,
|
||||
connections,
|
||||
utility,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -2,37 +2,36 @@
|
||||
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy import (
|
||||
func,
|
||||
literal,
|
||||
Column,
|
||||
Integer,
|
||||
LargeBinary,
|
||||
MetaData,
|
||||
Table,
|
||||
Text,
|
||||
cast,
|
||||
column,
|
||||
create_engine,
|
||||
Column,
|
||||
Integer,
|
||||
MetaData,
|
||||
LargeBinary,
|
||||
func,
|
||||
literal,
|
||||
select,
|
||||
text,
|
||||
Text,
|
||||
Table,
|
||||
values,
|
||||
)
|
||||
from sqlalchemy.sql import true
|
||||
from sqlalchemy.pool import NullPool, QueuePool
|
||||
|
||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||
from sqlalchemy.dialects.postgresql import JSONB, array
|
||||
from pgvector.sqlalchemy import Vector
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
||||
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
|
||||
from sqlalchemy.dialects import registry
|
||||
from sqlalchemy.dialects.postgresql import JSONB, array
|
||||
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||
from sqlalchemy.pool import NullPool, QueuePool
|
||||
from sqlalchemy.sql import true
|
||||
|
||||
|
||||
class OpenGaussDialect(PGDialect_psycopg2):
|
||||
@@ -56,23 +55,22 @@ class OpenGaussDialect(PGDialect_psycopg2):
|
||||
# Register dialect
|
||||
registry.register('opengauss', __name__, 'OpenGaussDialect')
|
||||
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
OPENGAUSS_DB_URL,
|
||||
OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH,
|
||||
OPENGAUSS_POOL_SIZE,
|
||||
OPENGAUSS_POOL_MAX_OVERFLOW,
|
||||
OPENGAUSS_POOL_TIMEOUT,
|
||||
OPENGAUSS_POOL_RECYCLE,
|
||||
OPENGAUSS_POOL_SIZE,
|
||||
OPENGAUSS_POOL_TIMEOUT,
|
||||
)
|
||||
|
||||
from open_webui.env import SRC_LOG_LEVELS
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
|
||||
VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
Base = declarative_base()
|
||||
|
||||
@@ -2,24 +2,24 @@
|
||||
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
|
||||
"""
|
||||
|
||||
from opensearchpy import OpenSearch
|
||||
from opensearchpy.helpers import bulk
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from open_webui.config import (
|
||||
OPENSEARCH_CERT_VERIFY,
|
||||
OPENSEARCH_PASSWORD,
|
||||
OPENSEARCH_SSL,
|
||||
OPENSEARCH_URI,
|
||||
OPENSEARCH_USERNAME,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
OPENSEARCH_URI,
|
||||
OPENSEARCH_SSL,
|
||||
OPENSEARCH_CERT_VERIFY,
|
||||
OPENSEARCH_USERNAME,
|
||||
OPENSEARCH_PASSWORD,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from opensearchpy import OpenSearch
|
||||
from opensearchpy.helpers import bulk
|
||||
|
||||
|
||||
class OpenSearchClient(VectorDBBase):
|
||||
|
||||
@@ -28,34 +28,33 @@ ORACLE_DB_POOL_MAX = 10
|
||||
ORACLE_DB_POOL_INCREMENT = 1
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from decimal import Decimal
|
||||
import array
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import json
|
||||
import array
|
||||
from decimal import Decimal
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import oracledb
|
||||
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
|
||||
from open_webui.config import (
|
||||
ORACLE_DB_DSN,
|
||||
ORACLE_DB_PASSWORD,
|
||||
ORACLE_DB_POOL_INCREMENT,
|
||||
ORACLE_DB_POOL_MAX,
|
||||
ORACLE_DB_POOL_MIN,
|
||||
ORACLE_DB_USE_WALLET,
|
||||
ORACLE_DB_USER,
|
||||
ORACLE_DB_PASSWORD,
|
||||
ORACLE_DB_DSN,
|
||||
ORACLE_VECTOR_LENGTH,
|
||||
ORACLE_WALLET_DIR,
|
||||
ORACLE_WALLET_PASSWORD,
|
||||
ORACLE_VECTOR_LENGTH,
|
||||
ORACLE_DB_POOL_MIN,
|
||||
ORACLE_DB_POOL_MAX,
|
||||
ORACLE_DB_POOL_INCREMENT,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -1,56 +1,54 @@
|
||||
from typing import Optional, List, Dict, Any, Tuple
|
||||
import logging
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from open_webui.config import (
|
||||
PGVECTOR_CREATE_EXTENSION,
|
||||
PGVECTOR_DB_URL,
|
||||
PGVECTOR_HNSW_EF_CONSTRUCTION,
|
||||
PGVECTOR_HNSW_M,
|
||||
PGVECTOR_INDEX_METHOD,
|
||||
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
|
||||
PGVECTOR_IVFFLAT_LISTS,
|
||||
PGVECTOR_PGCRYPTO,
|
||||
PGVECTOR_PGCRYPTO_KEY,
|
||||
PGVECTOR_POOL_MAX_OVERFLOW,
|
||||
PGVECTOR_POOL_RECYCLE,
|
||||
PGVECTOR_POOL_SIZE,
|
||||
PGVECTOR_POOL_TIMEOUT,
|
||||
PGVECTOR_USE_HALFVEC,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from open_webui.utils.misc import sanitize_text_for_db
|
||||
from pgvector.sqlalchemy import HALFVEC, Vector
|
||||
from sqlalchemy import (
|
||||
func,
|
||||
literal,
|
||||
Column,
|
||||
Integer,
|
||||
LargeBinary,
|
||||
MetaData,
|
||||
Table,
|
||||
Text,
|
||||
cast,
|
||||
column,
|
||||
create_engine,
|
||||
Column,
|
||||
Integer,
|
||||
MetaData,
|
||||
LargeBinary,
|
||||
func,
|
||||
literal,
|
||||
select,
|
||||
text,
|
||||
Text,
|
||||
Table,
|
||||
values,
|
||||
)
|
||||
from sqlalchemy.sql import true
|
||||
from sqlalchemy.pool import NullPool, QueuePool
|
||||
|
||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||
from sqlalchemy.dialects.postgresql import JSONB, array
|
||||
from pgvector.sqlalchemy import Vector, HALFVEC
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.exc import NoSuchTableError
|
||||
|
||||
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.utils.misc import sanitize_text_for_db
|
||||
from open_webui.config import (
|
||||
PGVECTOR_DB_URL,
|
||||
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
|
||||
PGVECTOR_CREATE_EXTENSION,
|
||||
PGVECTOR_PGCRYPTO,
|
||||
PGVECTOR_PGCRYPTO_KEY,
|
||||
PGVECTOR_POOL_SIZE,
|
||||
PGVECTOR_POOL_MAX_OVERFLOW,
|
||||
PGVECTOR_POOL_TIMEOUT,
|
||||
PGVECTOR_POOL_RECYCLE,
|
||||
PGVECTOR_INDEX_METHOD,
|
||||
PGVECTOR_HNSW_M,
|
||||
PGVECTOR_HNSW_EF_CONSTRUCTION,
|
||||
PGVECTOR_IVFFLAT_LISTS,
|
||||
PGVECTOR_USE_HALFVEC,
|
||||
)
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
|
||||
from sqlalchemy.pool import NullPool, QueuePool
|
||||
from sqlalchemy.sql import true
|
||||
|
||||
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
|
||||
USE_HALFVEC = PGVECTOR_USE_HALFVEC
|
||||
|
||||
@@ -2,9 +2,10 @@
|
||||
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
import logging
|
||||
import time # for measuring elapsed time
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pinecone import Pinecone, ServerlessSpec
|
||||
|
||||
# Add gRPC support for better performance (Pinecone best practice)
|
||||
@@ -16,24 +17,23 @@ except ImportError:
|
||||
GRPC_AVAILABLE = False
|
||||
|
||||
import asyncio # for async upserts
|
||||
import functools # for partial binding in async tasks
|
||||
|
||||
import concurrent.futures # for parallel batch upserts
|
||||
import functools # for partial binding in async tasks
|
||||
import random # for jitter in retry backoff
|
||||
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
PINECONE_API_KEY,
|
||||
PINECONE_CLOUD,
|
||||
PINECONE_DIMENSION,
|
||||
PINECONE_ENVIRONMENT,
|
||||
PINECONE_INDEX_NAME,
|
||||
PINECONE_DIMENSION,
|
||||
PINECONE_METRIC,
|
||||
PINECONE_CLOUD,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
|
||||
|
||||
@@ -2,31 +2,30 @@
|
||||
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from open_webui.config import (
|
||||
QDRANT_API_KEY,
|
||||
QDRANT_COLLECTION_PREFIX,
|
||||
QDRANT_GRPC_PORT,
|
||||
QDRANT_HNSW_M,
|
||||
QDRANT_ON_DISK,
|
||||
QDRANT_PREFER_GRPC,
|
||||
QDRANT_TIMEOUT,
|
||||
QDRANT_URI,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from qdrant_client import QdrantClient as Qclient
|
||||
from qdrant_client.http.models import PointStruct
|
||||
from qdrant_client.models import models
|
||||
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
SearchResult,
|
||||
GetResult,
|
||||
)
|
||||
from open_webui.config import (
|
||||
QDRANT_URI,
|
||||
QDRANT_API_KEY,
|
||||
QDRANT_ON_DISK,
|
||||
QDRANT_GRPC_PORT,
|
||||
QDRANT_PREFER_GRPC,
|
||||
QDRANT_COLLECTION_PREFIX,
|
||||
QDRANT_TIMEOUT,
|
||||
QDRANT_HNSW_M,
|
||||
)
|
||||
|
||||
NO_LIMIT = 999999999
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -3,19 +3,19 @@ NOTE: This vector database integration is community-supported and maintained on
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple, List, Dict, Any
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import grpc
|
||||
from open_webui.config import (
|
||||
QDRANT_API_KEY,
|
||||
QDRANT_COLLECTION_PREFIX,
|
||||
QDRANT_GRPC_PORT,
|
||||
QDRANT_HNSW_M,
|
||||
QDRANT_ON_DISK,
|
||||
QDRANT_PREFER_GRPC,
|
||||
QDRANT_URI,
|
||||
QDRANT_COLLECTION_PREFIX,
|
||||
QDRANT_TIMEOUT,
|
||||
QDRANT_HNSW_M,
|
||||
QDRANT_URI,
|
||||
)
|
||||
from open_webui.retrieval.vector.main import (
|
||||
GetResult,
|
||||
|
||||
@@ -2,17 +2,18 @@
|
||||
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
|
||||
"""
|
||||
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import boto3
|
||||
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
|
||||
from open_webui.retrieval.vector.main import (
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
GetResult,
|
||||
SearchResult,
|
||||
VectorDBBase,
|
||||
VectorItem,
|
||||
)
|
||||
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
import logging
|
||||
import boto3
|
||||
from open_webui.retrieval.vector.utils import process_metadata
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user