diff --git a/backend/open_webui/config.py b/backend/open_webui/config.py index b353a1b7ea..da29f90105 100644 --- a/backend/open_webui/config.py +++ b/backend/open_webui/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import json import logging @@ -20,29 +22,36 @@ from open_webui.env import ( DATABASE_URL, ENABLE_DB_MIGRATIONS, ENV, - REDIS_URL, - REDIS_KEY_PREFIX, - REDIS_SENTINEL_HOSTS, - REDIS_SENTINEL_PORT, FRONTEND_BUILD_DIR, OFFLINE_MODE, OPEN_WEBUI_DIR, + REDIS_KEY_PREFIX, + REDIS_SENTINEL_HOSTS, + REDIS_SENTINEL_PORT, + REDIS_URL, WEBUI_AUTH, WEBUI_FAVICON_URL, WEBUI_NAME, log, ) +from open_webui.internal.config import ( + STATE as _state, +) +from open_webui.internal.config import ( + AppConfig, + ConfigVar, +) # ── Persistent configuration layer ────────────────────────────────────────── - from open_webui.internal.config import ( # noqa: F401 ConfigTable as Config, - ConfigVar, - AppConfig, - STATE as _state, - initialize as _initialize_config, +) +from open_webui.internal.config import ( _all_configs as PERSISTENT_CONFIG_REGISTRY, ) +from open_webui.internal.config import ( + initialize as _initialize_config, +) def get_config(): @@ -1288,9 +1297,7 @@ RAG_EMBEDDING_MODEL_AUTO_UPDATE = ( not OFFLINE_MODE and os.getenv('RAG_EMBEDDING_MODEL_AUTO_UPDATE', 'True').lower() == 'true' ) -RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( - os.getenv('RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true' -) +RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = os.getenv('RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true' RAG_EMBEDDING_BATCH_SIZE = ConfigVar( 'RAG_EMBEDDING_BATCH_SIZE', @@ -1335,9 +1342,7 @@ RAG_RERANKING_MODEL_AUTO_UPDATE = ( not OFFLINE_MODE and os.getenv('RAG_RERANKING_MODEL_AUTO_UPDATE', 'True').lower() == 'true' ) -RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = ( - os.getenv('RAG_RERANKING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true' -) +RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = os.getenv('RAG_RERANKING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true' RAG_RERANKING_BATCH_SIZE = ConfigVar( 'RAG_RERANKING_BATCH_SIZE', @@ -2709,9 +2714,7 @@ USER_PERMISSIONS_CHAT_DELETE = os.getenv('USER_PERMISSIONS_CHAT_DELETE', 'True') USER_PERMISSIONS_CHAT_DELETE_MESSAGE = os.getenv('USER_PERMISSIONS_CHAT_DELETE_MESSAGE', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE = ( - os.getenv('USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE', 'True').lower() == 'true' -) +USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE = os.getenv('USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE', 'True').lower() == 'true' USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE = ( os.getenv('USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE', 'True').lower() == 'true' @@ -2735,9 +2738,7 @@ USER_PERMISSIONS_CHAT_TTS = os.getenv('USER_PERMISSIONS_CHAT_TTS', 'True').lower USER_PERMISSIONS_CHAT_CALL = os.getenv('USER_PERMISSIONS_CHAT_CALL', 'True').lower() == 'true' -USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = ( - os.getenv('USER_PERMISSIONS_CHAT_MULTIPLE_MODELS', 'True').lower() == 'true' -) +USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = os.getenv('USER_PERMISSIONS_CHAT_MULTIPLE_MODELS', 'True').lower() == 'true' USER_PERMISSIONS_CHAT_TEMPORARY = os.getenv('USER_PERMISSIONS_CHAT_TEMPORARY', 'True').lower() == 'true' @@ -2770,9 +2771,7 @@ USER_PERMISSIONS_FEATURES_API_KEYS = os.getenv('USER_PERMISSIONS_FEATURES_API_KE USER_PERMISSIONS_FEATURES_MEMORIES = os.getenv('USER_PERMISSIONS_FEATURES_MEMORIES', 'True').lower() == 'true' -USER_PERMISSIONS_FEATURES_AUTOMATIONS = ( - os.getenv('USER_PERMISSIONS_FEATURES_AUTOMATIONS', 'False').lower() == 'true' -) +USER_PERMISSIONS_FEATURES_AUTOMATIONS = os.getenv('USER_PERMISSIONS_FEATURES_AUTOMATIONS', 'False').lower() == 'true' USER_PERMISSIONS_FEATURES_CALENDAR = os.getenv('USER_PERMISSIONS_FEATURES_CALENDAR', 'True').lower() == 'true' @@ -2940,9 +2939,7 @@ WEBHOOK_URL = ConfigVar('WEBHOOK_URL', 'webhook_url', os.getenv('WEBHOOK_URL', ' ENABLE_ADMIN_EXPORT = os.getenv('ENABLE_ADMIN_EXPORT', 'True').lower() == 'true' -ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = ( - os.getenv('ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS', 'True').lower() == 'true' -) +ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = os.getenv('ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS', 'True').lower() == 'true' BYPASS_ADMIN_ACCESS_CONTROL = ( os.getenv( @@ -3024,7 +3021,7 @@ else: class BannerModel(BaseModel): id: str type: str - title: Optional[str] = None + title: str | None = None content: str dismissible: bool timestamp: int @@ -3705,9 +3702,7 @@ OAUTH_ALLOWED_ROLES = ConfigVar( 'oauth.allowed_roles', [ role.strip() - for role in os.getenv('OAUTH_ALLOWED_ROLES', f'user{OAUTH_ROLES_SEPARATOR}admin').split( - OAUTH_ROLES_SEPARATOR - ) + for role in os.getenv('OAUTH_ALLOWED_ROLES', f'user{OAUTH_ROLES_SEPARATOR}admin').split(OAUTH_ROLES_SEPARATOR) if role ], ) diff --git a/backend/open_webui/constants.py b/backend/open_webui/constants.py index c767e72520..132f3ac19a 100644 --- a/backend/open_webui/constants.py +++ b/backend/open_webui/constants.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index d1c3bdee9f..5b830fd152 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -530,9 +530,7 @@ except (ValueError, TypeError): AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10 -AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = ( - os.getenv('AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL', 'True').lower() == 'true' -) +AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = os.getenv('AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL', 'True').lower() == 'true' AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER = os.getenv('AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER', '') @@ -779,9 +777,7 @@ FORWARD_USER_INFO_HEADER_USER_NAME = os.getenv('FORWARD_USER_INFO_HEADER_USER_NA FORWARD_USER_INFO_HEADER_USER_ID = os.getenv('FORWARD_USER_INFO_HEADER_USER_ID', 'X-OpenWebUI-User-Id') FORWARD_USER_INFO_HEADER_USER_EMAIL = os.getenv('FORWARD_USER_INFO_HEADER_USER_EMAIL', 'X-OpenWebUI-User-Email') FORWARD_USER_INFO_HEADER_USER_ROLE = os.getenv('FORWARD_USER_INFO_HEADER_USER_ROLE', 'X-OpenWebUI-User-Role') -FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.getenv( - 'FORWARD_SESSION_INFO_HEADER_MESSAGE_ID', 'X-OpenWebUI-Message-Id' -) +FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.getenv('FORWARD_SESSION_INFO_HEADER_MESSAGE_ID', 'X-OpenWebUI-Message-Id') FORWARD_SESSION_INFO_HEADER_CHAT_ID = os.getenv('FORWARD_SESSION_INFO_HEADER_CHAT_ID', 'X-OpenWebUI-Chat-Id') #################################### @@ -897,9 +893,7 @@ if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == '': SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = 'torch' -SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.getenv( - 'SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS', '' -) +SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.getenv('SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS', '') if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == '': SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None else: diff --git a/backend/open_webui/functions.py b/backend/open_webui/functions.py index a3f99bb182..7c50ca3d2d 100644 --- a/backend/open_webui/functions.py +++ b/backend/open_webui/functions.py @@ -1,11 +1,10 @@ -import logging -import sys +import asyncio import inspect import json -import asyncio - -from pydantic import BaseModel +import logging +import sys from typing import AsyncGenerator, Generator, Iterator + from fastapi import ( Depends, FastAPI, @@ -16,40 +15,35 @@ from fastapi import ( UploadFile, status, ) +from pydantic import BaseModel from starlette.responses import Response, StreamingResponse - +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.constants import ERROR_MESSAGES +from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL +from open_webui.models.functions import Functions +from open_webui.models.models import Models +from open_webui.models.users import UserModel from open_webui.socket.main import ( get_event_call, get_event_emitter, ) - - -from open_webui.models.users import UserModel -from open_webui.models.functions import Functions -from open_webui.models.models import Models - -from open_webui.utils.plugin import ( - load_function_module_by_id, - get_function_module_from_cache, -) from open_webui.utils.access_control import check_model_access - -from open_webui.env import GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL -from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL - from open_webui.utils.misc import ( add_or_update_system_message, get_last_user_message, - prepend_to_first_user_message_content, openai_chat_chunk_message_template, openai_chat_completion_message_template, + prepend_to_first_user_message_content, ) from open_webui.utils.payload import ( apply_model_params_to_body_openai, apply_system_prompt_to_body, ) +from open_webui.utils.plugin import ( + get_function_module_from_cache, + load_function_module_by_id, +) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) diff --git a/backend/open_webui/internal/config.py b/backend/open_webui/internal/config.py index f110052731..294144ae76 100644 --- a/backend/open_webui/internal/config.py +++ b/backend/open_webui/internal/config.py @@ -10,10 +10,9 @@ from functools import reduce from typing import Any, Optional, Union import redis -from sqlalchemy import JSON, Column, DateTime, Integer, func, select - from open_webui.internal.db import Base, get_async_db, get_db from open_webui.utils.redis import get_redis_connection +from sqlalchemy import JSON, Column, DateTime, Integer, func, select log = logging.getLogger(__name__) @@ -22,7 +21,7 @@ log = logging.getLogger(__name__) class ConfigTable(Base): - __tablename__ = "config" + __tablename__ = 'config' id = Column(Integer, primary_key=True) data = Column(JSON, nullable=False) @@ -37,7 +36,7 @@ class ConfigTable(Base): class ConfigState: """In-memory mirror of the single-row config JSON blob.""" - __slots__ = ("_data",) + __slots__ = ('_data',) def __init__(self) -> None: self._data: dict[str, Any] = {} @@ -49,12 +48,12 @@ class ConfigState: def read(self, path: str) -> Any: return reduce( lambda n, k: n.get(k) if isinstance(n, dict) else None, - path.split("."), + path.split('.'), self._data, ) def write(self, path: str, value: Any) -> None: - keys = path.split(".") + keys = path.split('.') reduce(lambda d, k: d.setdefault(k, {}), keys[:-1], self._data)[keys[-1]] = value def replace(self, data: dict) -> None: @@ -63,7 +62,7 @@ class ConfigState: def load(self) -> dict: with get_db() as db: row = db.query(ConfigTable).order_by(ConfigTable.id.desc()).first() - self._data = row.data if row else {"version": 0, "ui": {}} + self._data = row.data if row else {'version': 0, 'ui': {}} return self._data def persist(self, data: dict | None = None) -> None: @@ -123,7 +122,7 @@ def initialize(*, enable_persistent: bool = True, enable_oauth_persistent: bool class ConfigVar: - __slots__ = ("env_name", "config_path", "env_value", "config_value", "value") + __slots__ = ('env_name', 'config_path', 'env_value', 'config_value', 'value') def __init__(self, env_name: str, config_path: str, env_value: Any) -> None: self.env_name = env_name @@ -132,7 +131,7 @@ class ConfigVar: self.config_value = STATE.read(config_path) if self.config_value is not None and _persist_enabled: - if config_path.startswith("oauth.") and not _oauth_persist_enabled: + if config_path.startswith('oauth.') and not _oauth_persist_enabled: log.info("Skipping DB value for '%s' (OAuth persistence disabled)", env_name) self.value = env_value else: @@ -147,22 +146,22 @@ class ConfigVar: return str(self.value) def __repr__(self) -> str: - return f"" + return f'' @property def __dict__(self): # type: ignore[override] raise TypeError(f"ConfigVar('{self.env_name}') cannot be cast to dict; use .value") def __getattribute__(self, item: str): - if item == "__dict__": - raise TypeError("ConfigVar cannot be cast to dict; use .value") + if item == '__dict__': + raise TypeError('ConfigVar cannot be cast to dict; use .value') return super().__getattribute__(item) def refresh(self) -> None: current = STATE.read(self.config_path) if current is not None: self.value = current - log.info("Refreshed %s → %s", self.env_name, self.value) + log.info('Refreshed %s → %s', self.env_name, self.value) def commit(self) -> None: log.info("Persisting '%s'", self.env_name) @@ -189,18 +188,18 @@ class AppConfig: redis_url: Optional[str] = None, redis_sentinels: Optional[list] = None, redis_cluster: bool = False, - redis_key_prefix: str = "open-webui", + redis_key_prefix: str = 'open-webui', ) -> None: - super().__setattr__("_entries", {}) - super().__setattr__("_key_prefix", redis_key_prefix) + super().__setattr__('_entries', {}) + super().__setattr__('_key_prefix', redis_key_prefix) rc: Union[redis.Redis, redis.cluster.RedisCluster, None] = None if redis_url: rc = get_redis_connection(redis_url, redis_sentinels or [], redis_cluster, decode_responses=True) - super().__setattr__("_rc", rc) + super().__setattr__('_rc', rc) def __setattr__(self, name: str, value: Any) -> None: - entries: dict = super().__getattribute__("_entries") + entries: dict = super().__getattribute__('_entries') if isinstance(value, ConfigVar): entries[name] = value @@ -213,11 +212,11 @@ class AppConfig: except RuntimeError: entries[name].commit() - rc = super().__getattribute__("_rc") + rc = super().__getattribute__('_rc') if rc and _persist_enabled: - prefix = super().__getattribute__("_key_prefix") + prefix = super().__getattribute__('_key_prefix') try: - rc.set(f"{prefix}:config:{name}", json.dumps(entries[name].value)) + rc.set(f'{prefix}:config:{name}', json.dumps(entries[name].value)) except Exception as exc: log.error("Redis write failed for '%s': %s", name, exc) @@ -228,15 +227,15 @@ class AppConfig: log.error("Async persist failed for '%s': %s", name, exc) def __getattr__(self, name: str) -> Any: - entries = super().__getattribute__("_entries") + entries = super().__getattribute__('_entries') if name not in entries: raise AttributeError(f"No config key '{name}'") - rc = super().__getattribute__("_rc") + rc = super().__getattribute__('_rc') if rc and _persist_enabled: - prefix = super().__getattribute__("_key_prefix") + prefix = super().__getattribute__('_key_prefix') try: - raw = rc.get(f"{prefix}:config:{name}") + raw = rc.get(f'{prefix}:config:{name}') if raw is not None: decoded = json.loads(raw) if entries[name].value != decoded: @@ -248,12 +247,12 @@ class AppConfig: return entries[name].value def _sync_to_redis(self) -> None: - rc = super().__getattribute__("_rc") + rc = super().__getattribute__('_rc') if not rc or not _persist_enabled: return - prefix = super().__getattribute__("_key_prefix") - for name, s in super().__getattribute__("_entries").items(): + prefix = super().__getattribute__('_key_prefix') + for name, s in super().__getattribute__('_entries').items(): try: - rc.set(f"{prefix}:config:{name}", json.dumps(s.value)) + rc.set(f'{prefix}:config:{name}', json.dumps(s.value)) except Exception as exc: log.error("Redis sync failed for '%s': %s", name, exc) diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index bf97fec1d0..69050160a8 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -1,35 +1,36 @@ -import os -import sys +from __future__ import annotations + import json import logging +import os +import sys from contextlib import asynccontextmanager, contextmanager from typing import Any, Optional from urllib.parse import parse_qs, urlencode, urlparse, urlunparse - from open_webui.env import ( - OPEN_WEBUI_DIR, - DATABASE_URL, - DATABASE_SCHEMA, + DATABASE_ENABLE_SESSION_SHARING, + DATABASE_ENABLE_SQLITE_WAL, DATABASE_POOL_MAX_OVERFLOW, DATABASE_POOL_RECYCLE, DATABASE_POOL_SIZE, DATABASE_POOL_TIMEOUT, - DATABASE_ENABLE_SQLITE_WAL, - DATABASE_ENABLE_SESSION_SHARING, - DATABASE_SQLITE_PRAGMA_SYNCHRONOUS, + DATABASE_SCHEMA, DATABASE_SQLITE_PRAGMA_BUSY_TIMEOUT, DATABASE_SQLITE_PRAGMA_CACHE_SIZE, - DATABASE_SQLITE_PRAGMA_TEMP_STORE, - DATABASE_SQLITE_PRAGMA_MMAP_SIZE, DATABASE_SQLITE_PRAGMA_JOURNAL_SIZE_LIMIT, + DATABASE_SQLITE_PRAGMA_MMAP_SIZE, + DATABASE_SQLITE_PRAGMA_SYNCHRONOUS, + DATABASE_SQLITE_PRAGMA_TEMP_STORE, + DATABASE_URL, ENABLE_DB_MIGRATIONS, + OPEN_WEBUI_DIR, ) -from sqlalchemy import Dialect, create_engine, MetaData, event, types -from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker +from sqlalchemy import Dialect, MetaData, create_engine, event, types +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import scoped_session, sessionmaker, Session -from sqlalchemy.pool import QueuePool, NullPool +from sqlalchemy.orm import Session, scoped_session, sessionmaker +from sqlalchemy.pool import NullPool, QueuePool from sqlalchemy.sql.type_api import _T from typing_extensions import Self @@ -120,10 +121,10 @@ class JSONField(types.TypeDecorator): impl = types.Text cache_ok = True - def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any: + def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any: return json.dumps(value) - def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any: + def process_result_value(self, value: _T | None, dialect: Dialect) -> Any: if value is not None: return json.loads(value) @@ -374,7 +375,7 @@ async def get_async_db(): @asynccontextmanager -async def get_async_db_context(db: Optional[AsyncSession] = None): +async def get_async_db_context(db: AsyncSession | None = None): """Async context manager that reuses an existing session if provided and session sharing is enabled.""" if isinstance(db, AsyncSession) and DATABASE_ENABLE_SESSION_SHARING: yield db diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 90afaf1bd2..a8e651a321 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -1,31 +1,26 @@ +from __future__ import annotations + import asyncio import inspect import json import logging import mimetypes import os +import random +import re import shutil import sys import time -import random -import re +from contextlib import asynccontextmanager +from typing import Optional +from urllib.parse import parse_qs, urlencode, urlparse from uuid import uuid4 - -from contextlib import asynccontextmanager -from urllib.parse import urlencode, parse_qs, urlparse -from pydantic import BaseModel -from sqlalchemy import text - -from typing import Optional -from aiocache import cached import aiohttp import anyio.to_thread - -from redis import Redis - - +from aiocache import cached from fastapi import ( + BackgroundTasks, Depends, FastAPI, File, @@ -33,30 +28,510 @@ from fastapi import ( HTTPException, Request, UploadFile, - status, applications, - BackgroundTasks, + status, ) -from fastapi.openapi.docs import get_swagger_ui_html - from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.docs import get_swagger_ui_html from fastapi.responses import FileResponse, JSONResponse, RedirectResponse from fastapi.staticfiles import StaticFiles - -from starlette_compress import CompressMiddleware - +from pydantic import BaseModel +from redis import Redis +from sqlalchemy import text +from sqlalchemy.ext.asyncio import AsyncSession +from starlette.datastructures import Headers from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.middleware.sessions import SessionMiddleware from starlette.responses import Response, StreamingResponse -from starlette.datastructures import Headers - +from starlette_compress import CompressMiddleware +from starsessions import ( + SessionAutoloadMiddleware, +) from starsessions import ( SessionMiddleware as StarSessionsMiddleware, - SessionAutoloadMiddleware, ) from starsessions.stores.redis import RedisStore +from open_webui.config import ( + ADMIN_EMAIL, + API_KEYS_ALLOWED_ENDPOINTS, + AUDIO_STT_ALLOWED_EXTENSIONS, + AUDIO_STT_AZURE_API_KEY, + AUDIO_STT_AZURE_BASE_URL, + AUDIO_STT_AZURE_LOCALES, + AUDIO_STT_AZURE_MAX_SPEAKERS, + AUDIO_STT_AZURE_REGION, + # Audio + AUDIO_STT_ENGINE, + AUDIO_STT_MISTRAL_API_BASE_URL, + AUDIO_STT_MISTRAL_API_KEY, + AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS, + AUDIO_STT_MODEL, + AUDIO_STT_OPENAI_API_BASE_URL, + AUDIO_STT_OPENAI_API_KEY, + AUDIO_STT_SUPPORTED_CONTENT_TYPES, + AUDIO_TTS_API_KEY, + AUDIO_TTS_AZURE_SPEECH_BASE_URL, + AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, + AUDIO_TTS_AZURE_SPEECH_REGION, + AUDIO_TTS_ENGINE, + AUDIO_TTS_MISTRAL_API_BASE_URL, + AUDIO_TTS_MISTRAL_API_KEY, + AUDIO_TTS_MODEL, + AUDIO_TTS_OPENAI_API_BASE_URL, + AUDIO_TTS_OPENAI_API_KEY, + AUDIO_TTS_OPENAI_PARAMS, + AUDIO_TTS_SPLIT_ON, + AUDIO_TTS_VOICE, + AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, + AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, + # Image + AUTOMATIC1111_API_AUTH, + AUTOMATIC1111_BASE_URL, + AUTOMATIC1111_PARAMS, + AUTOMATION_MAX_COUNT, + AUTOMATION_MIN_INTERVAL, + BING_SEARCH_V7_ENDPOINT, + BING_SEARCH_V7_SUBSCRIPTION_KEY, + BOCHA_SEARCH_API_KEY, + BRAVE_SEARCH_API_KEY, + BRAVE_SEARCH_CONTEXT_TOKENS, + BYPASS_ADMIN_ACCESS_CONTROL, + BYPASS_EMBEDDING_AND_RETRIEVAL, + BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, + BYPASS_WEB_SEARCH_WEB_LOADER, + CACHE_DIR, + CHUNK_MIN_SIZE_TARGET, + CHUNK_OVERLAP, + CHUNK_SIZE, + CODE_EXECUTION_ENGINE, + CODE_EXECUTION_JUPYTER_AUTH, + CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, + CODE_EXECUTION_JUPYTER_AUTH_TOKEN, + CODE_EXECUTION_JUPYTER_TIMEOUT, + CODE_EXECUTION_JUPYTER_URL, + CODE_INTERPRETER_ENGINE, + CODE_INTERPRETER_JUPYTER_AUTH, + CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, + CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, + CODE_INTERPRETER_JUPYTER_TIMEOUT, + CODE_INTERPRETER_JUPYTER_URL, + CODE_INTERPRETER_PROMPT_TEMPLATE, + COMFYUI_API_KEY, + COMFYUI_BASE_URL, + COMFYUI_WORKFLOW, + COMFYUI_WORKFLOW_NODES, + CONTENT_EXTRACTION_ENGINE, + CORS_ALLOW_ORIGIN, + DATALAB_MARKER_ADDITIONAL_CONFIG, + DATALAB_MARKER_API_BASE_URL, + DATALAB_MARKER_API_KEY, + DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, + DATALAB_MARKER_FORCE_OCR, + DATALAB_MARKER_FORMAT_LINES, + DATALAB_MARKER_OUTPUT_FORMAT, + DATALAB_MARKER_PAGINATE, + DATALAB_MARKER_SKIP_CACHE, + DATALAB_MARKER_STRIP_EXISTING_OCR, + DATALAB_MARKER_USE_LLM, + DDGS_BACKEND, + DEEPGRAM_API_KEY, + DEFAULT_ARENA_MODEL, + DEFAULT_GROUP_ID, + DEFAULT_LOCALE, + DEFAULT_MODEL_METADATA, + DEFAULT_MODEL_PARAMS, + DEFAULT_MODELS, + DEFAULT_PINNED_MODELS, + DEFAULT_PROMPT_SUGGESTIONS, + DEFAULT_RAG_TEMPLATE, + DEFAULT_USER_ROLE, + DOCLING_API_KEY, + DOCLING_PARAMS, + DOCLING_SERVER_URL, + DOCUMENT_INTELLIGENCE_ENDPOINT, + DOCUMENT_INTELLIGENCE_KEY, + DOCUMENT_INTELLIGENCE_MODEL, + ENABLE_ADMIN_ANALYTICS, + # Admin + ENABLE_ADMIN_CHAT_ACCESS, + ENABLE_ADMIN_EXPORT, + ENABLE_API_KEYS, + ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, + ENABLE_ASYNC_EMBEDDING, + ENABLE_AUTOCOMPLETE_GENERATION, + ENABLE_AUTOMATIONS, + # Model list + ENABLE_BASE_MODELS_CACHE, + ENABLE_CALENDAR, + ENABLE_CHANNELS, + # Code Execution + ENABLE_CODE_EXECUTION, + ENABLE_CODE_INTERPRETER, + ENABLE_COMMUNITY_SHARING, + # Direct Connections + ENABLE_DIRECT_CONNECTIONS, + ENABLE_EVALUATION_ARENA_MODELS, + ENABLE_FOLDERS, + ENABLE_FOLLOW_UP_GENERATION, + ENABLE_GOOGLE_DRIVE_INTEGRATION, + ENABLE_IMAGE_EDIT, + ENABLE_IMAGE_GENERATION, + ENABLE_IMAGE_PROMPT_GENERATION, + # WebUI (LDAP) + ENABLE_LDAP, + ENABLE_LDAP_GROUP_CREATION, + # LDAP Group Management + ENABLE_LDAP_GROUP_MANAGEMENT, + ENABLE_LOGIN_FORM, + ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, + ENABLE_MEMORIES, + ENABLE_MESSAGE_RATING, + ENABLE_NOTES, + # WebUI (OAuth) + ENABLE_OAUTH_ROLE_MANAGEMENT, + # Ollama + ENABLE_OLLAMA_API, + ENABLE_ONEDRIVE_BUSINESS, + ENABLE_ONEDRIVE_INTEGRATION, + ENABLE_ONEDRIVE_PERSONAL, + # OpenAI + ENABLE_OPENAI_API, + ENABLE_PASSWORD_CHANGE_FORM, + ENABLE_RAG_HYBRID_SEARCH, + ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS, + ENABLE_RAG_LOCAL_WEB_FETCH, + ENABLE_RETRIEVAL_QUERY_GENERATION, + ENABLE_SEARCH_QUERY_GENERATION, + ENABLE_SIGNUP, + ENABLE_TAGS_GENERATION, + ENABLE_TITLE_GENERATION, + ENABLE_USER_STATUS, + ENABLE_USER_WEBHOOKS, + ENABLE_VOICE_MODE_PROMPT, + ENABLE_WEB_LOADER_SSL_VERIFICATION, + # Retrieval (Web Search) + ENABLE_WEB_SEARCH, + # Misc + ENV, + EVALUATION_ARENA_MODELS, + EXA_API_KEY, + EXTERNAL_DOCUMENT_LOADER_API_KEY, + EXTERNAL_DOCUMENT_LOADER_URL, + EXTERNAL_WEB_LOADER_API_KEY, + EXTERNAL_WEB_LOADER_URL, + EXTERNAL_WEB_SEARCH_API_KEY, + EXTERNAL_WEB_SEARCH_URL, + FILE_IMAGE_COMPRESSION_HEIGHT, + FILE_IMAGE_COMPRESSION_WIDTH, + FIRECRAWL_API_BASE_URL, + FIRECRAWL_API_KEY, + FIRECRAWL_TIMEOUT, + FOLDER_MAX_FILE_COUNT, + FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, + FRONTEND_BUILD_DIR, + GOOGLE_DRIVE_API_KEY, + GOOGLE_DRIVE_CLIENT_ID, + GOOGLE_PSE_API_KEY, + GOOGLE_PSE_ENGINE_ID, + IFRAME_CSP, + IMAGE_EDIT_ENGINE, + IMAGE_EDIT_MODEL, + IMAGE_EDIT_SIZE, + IMAGE_GENERATION_ENGINE, + IMAGE_GENERATION_MODEL, + IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, + IMAGE_SIZE, + IMAGE_STEPS, + IMAGES_EDIT_COMFYUI_API_KEY, + IMAGES_EDIT_COMFYUI_BASE_URL, + IMAGES_EDIT_COMFYUI_WORKFLOW, + IMAGES_EDIT_COMFYUI_WORKFLOW_NODES, + IMAGES_EDIT_GEMINI_API_BASE_URL, + IMAGES_EDIT_GEMINI_API_KEY, + IMAGES_EDIT_OPENAI_API_BASE_URL, + IMAGES_EDIT_OPENAI_API_KEY, + IMAGES_EDIT_OPENAI_API_VERSION, + IMAGES_GEMINI_API_BASE_URL, + IMAGES_GEMINI_API_KEY, + IMAGES_GEMINI_ENDPOINT_METHOD, + IMAGES_OPENAI_API_BASE_URL, + IMAGES_OPENAI_API_KEY, + IMAGES_OPENAI_API_PARAMS, + IMAGES_OPENAI_API_VERSION, + JINA_API_BASE_URL, + JINA_API_KEY, + JWT_EXPIRES_IN, + KAGI_SEARCH_API_KEY, + LDAP_APP_DN, + LDAP_APP_PASSWORD, + LDAP_ATTRIBUTE_FOR_GROUPS, + LDAP_ATTRIBUTE_FOR_MAIL, + LDAP_ATTRIBUTE_FOR_USERNAME, + LDAP_CA_CERT_FILE, + LDAP_CIPHERS, + LDAP_SEARCH_BASE, + LDAP_SEARCH_FILTERS, + LDAP_SERVER_HOST, + LDAP_SERVER_LABEL, + LDAP_SERVER_PORT, + LDAP_USE_TLS, + LDAP_VALIDATE_CERT, + MINERU_API_KEY, + MINERU_API_MODE, + MINERU_API_TIMEOUT, + MINERU_API_URL, + MINERU_PARAMS, + MISTRAL_OCR_API_BASE_URL, + MISTRAL_OCR_API_KEY, + MODEL_ORDER_LIST, + MOJEEK_SEARCH_API_KEY, + OAUTH_ADMIN_ROLES, + OAUTH_ALLOWED_ROLES, + OAUTH_EMAIL_CLAIM, + OAUTH_PICTURE_CLAIM, + OAUTH_PROVIDERS, + OAUTH_ROLES_CLAIM, + OAUTH_SUB_CLAIM, + OAUTH_USERNAME_CLAIM, + OLLAMA_API_CONFIGS, + OLLAMA_BASE_URLS, + OLLAMA_CLOUD_WEB_SEARCH_API_KEY, + ONEDRIVE_CLIENT_ID_BUSINESS, + ONEDRIVE_CLIENT_ID_PERSONAL, + ONEDRIVE_SHAREPOINT_TENANT_ID, + ONEDRIVE_SHAREPOINT_URL, + OPENAI_API_BASE_URLS, + OPENAI_API_CONFIGS, + OPENAI_API_KEYS, + PADDLEOCR_VL_BASE_URL, + PADDLEOCR_VL_TOKEN, + PDF_EXTRACT_IMAGES, + PDF_LOADER_MODE, + PENDING_USER_OVERLAY_CONTENT, + PENDING_USER_OVERLAY_TITLE, + PERPLEXITY_API_KEY, + PERPLEXITY_MODEL, + PERPLEXITY_SEARCH_API_URL, + PERPLEXITY_SEARCH_CONTEXT_USAGE, + PLAYWRIGHT_TIMEOUT, + PLAYWRIGHT_WS_URL, + QUERY_GENERATION_PROMPT_TEMPLATE, + RAG_ALLOWED_FILE_EXTENSIONS, + RAG_AZURE_OPENAI_API_KEY, + RAG_AZURE_OPENAI_API_VERSION, + RAG_AZURE_OPENAI_BASE_URL, + RAG_EMBEDDING_BATCH_SIZE, + RAG_EMBEDDING_CONCURRENT_REQUESTS, + RAG_EMBEDDING_ENGINE, + RAG_EMBEDDING_MODEL, + RAG_EMBEDDING_MODEL_AUTO_UPDATE, + RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_EXTERNAL_RERANKER_API_KEY, + RAG_EXTERNAL_RERANKER_TIMEOUT, + RAG_EXTERNAL_RERANKER_URL, + RAG_FILE_MAX_COUNT, + RAG_FILE_MAX_SIZE, + RAG_FULL_CONTEXT, + RAG_HYBRID_BM25_WEIGHT, + RAG_OLLAMA_API_KEY, + RAG_OLLAMA_BASE_URL, + RAG_OPENAI_API_BASE_URL, + RAG_OPENAI_API_KEY, + RAG_RELEVANCE_THRESHOLD, + RAG_RERANKING_BATCH_SIZE, + RAG_RERANKING_ENGINE, + RAG_RERANKING_MODEL, + RAG_RERANKING_MODEL_AUTO_UPDATE, + RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, + # Retrieval + RAG_TEMPLATE, + RAG_TEXT_SPLITTER, + RAG_TOP_K, + RAG_TOP_K_RERANKER, + RESPONSE_WATERMARK, + SEARCHAPI_API_KEY, + SEARCHAPI_ENGINE, + SEARXNG_LANGUAGE, + SEARXNG_QUERY_URL, + SERPAPI_API_KEY, + SERPAPI_ENGINE, + SERPER_API_KEY, + SERPLY_API_KEY, + SERPSTACK_API_KEY, + SERPSTACK_HTTPS, + SHOW_ADMIN_DETAILS, + SOUGOU_API_SID, + SOUGOU_API_SK, + STATIC_DIR, + TAGS_GENERATION_PROMPT_TEMPLATE, + # Tasks + TASK_MODEL, + TASK_MODEL_EXTERNAL, + TAVILY_API_KEY, + TAVILY_EXTRACT_DEPTH, + # Terminal Server + TERMINAL_SERVER_CONNECTIONS, + # Thread pool size for FastAPI/AnyIO + THREAD_POOL_SIZE, + TIKA_SERVER_URL, + TIKTOKEN_ENCODING_NAME, + TITLE_GENERATION_PROMPT_TEMPLATE, + # Tool Server Configs + TOOL_SERVER_CONNECTIONS, + TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + UPLOAD_DIR, + USER_PERMISSIONS, + VOICE_MODE_PROMPT_TEMPLATE, + WEB_FETCH_MAX_CONTENT_LENGTH, + WEB_LOADER_CONCURRENT_REQUESTS, + WEB_LOADER_ENGINE, + WEB_LOADER_TIMEOUT, + WEB_SEARCH_CONCURRENT_REQUESTS, + WEB_SEARCH_DOMAIN_FILTER_LIST, + WEB_SEARCH_ENGINE, + WEB_SEARCH_RESULT_COUNT, + WEB_SEARCH_TRUST_ENV, + WEBHOOK_URL, + # WebUI + WEBUI_AUTH, + WEBUI_BANNERS, + WEBUI_NAME, + WEBUI_URL, + WHISPER_LANGUAGE, + WHISPER_MODEL, + WHISPER_MODEL_AUTO_UPDATE, + WHISPER_MODEL_DIR, + WHISPER_VAD_FILTER, + YACY_PASSWORD, + YACY_QUERY_URL, + YACY_USERNAME, + YANDEX_WEB_SEARCH_API_KEY, + YANDEX_WEB_SEARCH_CONFIG, + YANDEX_WEB_SEARCH_URL, + YOUCOM_API_KEY, + YOUTUBE_LOADER_LANGUAGE, + YOUTUBE_LOADER_PROXY_URL, + AppConfig, + async_reset_config, + reset_config, +) +from open_webui.constants import ERROR_MESSAGES, TASKS +from open_webui.env import ( + AIOHTTP_CLIENT_SESSION_SSL, + AUDIT_EXCLUDED_PATHS, + AUDIT_INCLUDED_PATHS, + AUDIT_LOG_LEVEL, + BYPASS_MODEL_ACCESS_CONTROL, + CHANGELOG, + DEPLOYMENT_ID, + ENABLE_AUDIT_GET_REQUESTS, + ENABLE_COMPRESSION_MIDDLEWARE, + ENABLE_CUSTOM_MODEL_FALLBACK, + ENABLE_EASTER_EGGS, + # OAuth Back-Channel Logout + ENABLE_OAUTH_BACKCHANNEL_LOGOUT, + ENABLE_OTEL, + ENABLE_PUBLIC_ACTIVE_USERS_COUNT, + # SCIM + ENABLE_SCIM, + ENABLE_SIGNUP_PASSWORD_CONFIRMATION, + ENABLE_STAR_SESSIONS_MIDDLEWARE, + ENABLE_VERSION_UPDATE_CHECK, + ENABLE_WEBSOCKET_SUPPORT, + EXTERNAL_PWA_MANIFEST_URL, + GLOBAL_LOG_LEVEL, + INSTANCE_ID, + LICENSE_KEY, + LOG_FORMAT, + MAX_BODY_LOG_SIZE, + REDIS_CLUSTER, + REDIS_KEY_PREFIX, + REDIS_SENTINEL_HOSTS, + REDIS_SENTINEL_PORT, + REDIS_URL, + RESET_CONFIG_ON_START, + SAFE_MODE, + SCIM_TOKEN, + VERSION, + # Admin Account Runtime Creation + WEBUI_ADMIN_EMAIL, + WEBUI_ADMIN_NAME, + WEBUI_ADMIN_PASSWORD, + WEBUI_AUTH_SIGNOUT_REDIRECT_URL, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_BUILD_HASH, + WEBUI_SECRET_KEY, + WEBUI_SESSION_COOKIE_SAME_SITE, + WEBUI_SESSION_COOKIE_SECURE, +) +from open_webui.internal.db import ScopedSession, engine, get_async_session +from open_webui.models.chats import ChatForm, Chats +from open_webui.models.functions import Functions +from open_webui.models.models import Models +from open_webui.models.users import UserModel, Users +from open_webui.routers import ( + analytics, + audio, + auths, + automations, + calendar, + channels, + chats, + configs, + evaluations, + files, + folders, + functions, + groups, + images, + knowledge, + memories, + models, + notes, + ollama, + openai, + pipelines, + prompts, + retrieval, + scim, + skills, + tasks, + terminals, + tools, + users, + utils, +) +from open_webui.routers.retrieval import ( + get_ef, + get_embedding_function, + get_reranking_function, + get_rf, +) +from open_webui.socket.main import ( + MODELS, + get_event_emitter, + get_models_in_use, + get_user_id_from_session_pool, + periodic_session_pool_cleanup, + periodic_usage_pool_cleanup, +) +from open_webui.socket.main import ( + app as socket_app, +) +from open_webui.tasks import ( + cleanup_task, + create_task, + has_active_tasks, + list_task_ids_by_item_id, + list_tasks, + redis_task_command_listener, + stop_item_tasks, + stop_task, +) # Import from tasks.py from open_webui.utils import logger +from open_webui.utils.actions import chat_action as chat_action_handler from open_webui.utils.asgi_middleware import ( AuthTokenMiddleware, CommitSessionMiddleware, @@ -64,538 +539,48 @@ from open_webui.utils.asgi_middleware import ( WebsocketUpgradeGuardMiddleware, ) from open_webui.utils.audit import AuditLevel, AuditLoggingMiddleware -from open_webui.utils.logger import start_logger -from open_webui.utils.session_pool import get_session -from open_webui.socket.main import ( - MODELS, - app as socket_app, - periodic_usage_pool_cleanup, - periodic_session_pool_cleanup, - get_event_emitter, - get_models_in_use, - get_user_id_from_session_pool, +from open_webui.utils.auth import ( + create_admin_user, + decode_token, + get_admin_user, + get_http_authorization_cred, + get_license_data, + get_verified_user, ) -from open_webui.routers import ( - analytics, - audio, - images, - ollama, - openai, - retrieval, - pipelines, - tasks, - auths, - channels, - chats, - notes, - folders, - configs, - groups, - files, - functions, - memories, - models, - knowledge, - prompts, - evaluations, - skills, - tools, - users, - utils, - scim, - terminals, - automations, - calendar, -) - -from open_webui.routers.retrieval import ( - get_embedding_function, - get_reranking_function, - get_ef, - get_rf, -) - - -from sqlalchemy.ext.asyncio import AsyncSession -from open_webui.internal.db import ScopedSession, engine, get_async_session - -from open_webui.models.functions import Functions -from open_webui.models.models import Models -from open_webui.models.users import UserModel, Users -from open_webui.models.chats import Chats, ChatForm - -from open_webui.config import ( - # Ollama - ENABLE_OLLAMA_API, - OLLAMA_BASE_URLS, - OLLAMA_API_CONFIGS, - # OpenAI - ENABLE_OPENAI_API, - OPENAI_API_BASE_URLS, - OPENAI_API_KEYS, - OPENAI_API_CONFIGS, - # Direct Connections - ENABLE_DIRECT_CONNECTIONS, - # Model list - ENABLE_BASE_MODELS_CACHE, - # Thread pool size for FastAPI/AnyIO - THREAD_POOL_SIZE, - # Tool Server Configs - TOOL_SERVER_CONNECTIONS, - # Terminal Server - TERMINAL_SERVER_CONNECTIONS, - # Code Execution - ENABLE_CODE_EXECUTION, - CODE_EXECUTION_ENGINE, - CODE_EXECUTION_JUPYTER_URL, - CODE_EXECUTION_JUPYTER_AUTH, - CODE_EXECUTION_JUPYTER_AUTH_TOKEN, - CODE_EXECUTION_JUPYTER_AUTH_PASSWORD, - CODE_EXECUTION_JUPYTER_TIMEOUT, - ENABLE_CODE_INTERPRETER, - CODE_INTERPRETER_ENGINE, - CODE_INTERPRETER_PROMPT_TEMPLATE, - CODE_INTERPRETER_JUPYTER_URL, - CODE_INTERPRETER_JUPYTER_AUTH, - CODE_INTERPRETER_JUPYTER_AUTH_TOKEN, - CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD, - CODE_INTERPRETER_JUPYTER_TIMEOUT, - ENABLE_MEMORIES, - # Image - AUTOMATIC1111_API_AUTH, - AUTOMATIC1111_BASE_URL, - AUTOMATIC1111_PARAMS, - COMFYUI_BASE_URL, - COMFYUI_API_KEY, - COMFYUI_WORKFLOW, - COMFYUI_WORKFLOW_NODES, - ENABLE_IMAGE_GENERATION, - ENABLE_IMAGE_PROMPT_GENERATION, - IMAGE_GENERATION_ENGINE, - IMAGE_GENERATION_MODEL, - IMAGE_SIZE, - IMAGE_STEPS, - IMAGES_OPENAI_API_BASE_URL, - IMAGES_OPENAI_API_VERSION, - IMAGES_OPENAI_API_KEY, - IMAGES_OPENAI_API_PARAMS, - IMAGES_GEMINI_API_BASE_URL, - IMAGES_GEMINI_API_KEY, - IMAGES_GEMINI_ENDPOINT_METHOD, - ENABLE_IMAGE_EDIT, - IMAGE_EDIT_ENGINE, - IMAGE_EDIT_MODEL, - IMAGE_EDIT_SIZE, - IMAGES_EDIT_OPENAI_API_BASE_URL, - IMAGES_EDIT_OPENAI_API_KEY, - IMAGES_EDIT_OPENAI_API_VERSION, - IMAGES_EDIT_GEMINI_API_BASE_URL, - IMAGES_EDIT_GEMINI_API_KEY, - IMAGES_EDIT_COMFYUI_BASE_URL, - IMAGES_EDIT_COMFYUI_API_KEY, - IMAGES_EDIT_COMFYUI_WORKFLOW, - IMAGES_EDIT_COMFYUI_WORKFLOW_NODES, - # Audio - AUDIO_STT_ENGINE, - AUDIO_STT_MODEL, - AUDIO_STT_SUPPORTED_CONTENT_TYPES, - AUDIO_STT_ALLOWED_EXTENSIONS, - AUDIO_STT_OPENAI_API_BASE_URL, - AUDIO_STT_OPENAI_API_KEY, - AUDIO_STT_AZURE_API_KEY, - AUDIO_STT_AZURE_REGION, - AUDIO_STT_AZURE_LOCALES, - AUDIO_STT_AZURE_BASE_URL, - AUDIO_STT_AZURE_MAX_SPEAKERS, - AUDIO_STT_MISTRAL_API_KEY, - AUDIO_STT_MISTRAL_API_BASE_URL, - AUDIO_STT_MISTRAL_USE_CHAT_COMPLETIONS, - AUDIO_TTS_ENGINE, - AUDIO_TTS_MODEL, - AUDIO_TTS_VOICE, - AUDIO_TTS_OPENAI_API_BASE_URL, - AUDIO_TTS_OPENAI_API_KEY, - AUDIO_TTS_OPENAI_PARAMS, - AUDIO_TTS_API_KEY, - AUDIO_TTS_SPLIT_ON, - AUDIO_TTS_AZURE_SPEECH_REGION, - AUDIO_TTS_AZURE_SPEECH_BASE_URL, - AUDIO_TTS_AZURE_SPEECH_OUTPUT_FORMAT, - AUDIO_TTS_MISTRAL_API_KEY, - AUDIO_TTS_MISTRAL_API_BASE_URL, - PLAYWRIGHT_WS_URL, - PLAYWRIGHT_TIMEOUT, - FIRECRAWL_API_BASE_URL, - FIRECRAWL_API_KEY, - FIRECRAWL_TIMEOUT, - WEB_LOADER_ENGINE, - WEB_LOADER_CONCURRENT_REQUESTS, - WEB_LOADER_TIMEOUT, - WHISPER_MODEL, - WHISPER_VAD_FILTER, - WHISPER_LANGUAGE, - DEEPGRAM_API_KEY, - WHISPER_MODEL_AUTO_UPDATE, - WHISPER_MODEL_DIR, - # Retrieval - RAG_TEMPLATE, - DEFAULT_RAG_TEMPLATE, - RAG_FULL_CONTEXT, - BYPASS_EMBEDDING_AND_RETRIEVAL, - RAG_EMBEDDING_MODEL, - RAG_EMBEDDING_MODEL_AUTO_UPDATE, - RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, - RAG_RERANKING_ENGINE, - RAG_RERANKING_MODEL, - RAG_EXTERNAL_RERANKER_URL, - RAG_EXTERNAL_RERANKER_API_KEY, - RAG_EXTERNAL_RERANKER_TIMEOUT, - RAG_RERANKING_BATCH_SIZE, - RAG_RERANKING_MODEL_AUTO_UPDATE, - RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, - RAG_EMBEDDING_ENGINE, - RAG_EMBEDDING_BATCH_SIZE, - ENABLE_ASYNC_EMBEDDING, - RAG_EMBEDDING_CONCURRENT_REQUESTS, - RAG_TOP_K, - RAG_TOP_K_RERANKER, - RAG_RELEVANCE_THRESHOLD, - RAG_HYBRID_BM25_WEIGHT, - RAG_ALLOWED_FILE_EXTENSIONS, - RAG_FILE_MAX_COUNT, - RAG_FILE_MAX_SIZE, - FILE_IMAGE_COMPRESSION_WIDTH, - FILE_IMAGE_COMPRESSION_HEIGHT, - RAG_OPENAI_API_BASE_URL, - RAG_OPENAI_API_KEY, - RAG_AZURE_OPENAI_BASE_URL, - RAG_AZURE_OPENAI_API_KEY, - RAG_AZURE_OPENAI_API_VERSION, - RAG_OLLAMA_BASE_URL, - RAG_OLLAMA_API_KEY, - CHUNK_OVERLAP, - CHUNK_MIN_SIZE_TARGET, - CHUNK_SIZE, - CONTENT_EXTRACTION_ENGINE, - DATALAB_MARKER_API_KEY, - DATALAB_MARKER_API_BASE_URL, - DATALAB_MARKER_ADDITIONAL_CONFIG, - DATALAB_MARKER_SKIP_CACHE, - DATALAB_MARKER_FORCE_OCR, - DATALAB_MARKER_PAGINATE, - DATALAB_MARKER_STRIP_EXISTING_OCR, - DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION, - DATALAB_MARKER_FORMAT_LINES, - DATALAB_MARKER_OUTPUT_FORMAT, - MINERU_API_MODE, - MINERU_API_URL, - MINERU_API_KEY, - MINERU_API_TIMEOUT, - MINERU_PARAMS, - DATALAB_MARKER_USE_LLM, - EXTERNAL_DOCUMENT_LOADER_URL, - EXTERNAL_DOCUMENT_LOADER_API_KEY, - TIKA_SERVER_URL, - DOCLING_SERVER_URL, - DOCLING_API_KEY, - DOCLING_PARAMS, - DOCUMENT_INTELLIGENCE_ENDPOINT, - DOCUMENT_INTELLIGENCE_KEY, - DOCUMENT_INTELLIGENCE_MODEL, - MISTRAL_OCR_API_BASE_URL, - MISTRAL_OCR_API_KEY, - PADDLEOCR_VL_BASE_URL, - PADDLEOCR_VL_TOKEN, - RAG_TEXT_SPLITTER, - ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER, - TIKTOKEN_ENCODING_NAME, - PDF_EXTRACT_IMAGES, - PDF_LOADER_MODE, - YOUTUBE_LOADER_LANGUAGE, - YOUTUBE_LOADER_PROXY_URL, - # Retrieval (Web Search) - ENABLE_WEB_SEARCH, - WEB_SEARCH_ENGINE, - BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL, - BYPASS_WEB_SEARCH_WEB_LOADER, - WEB_SEARCH_RESULT_COUNT, - WEB_SEARCH_CONCURRENT_REQUESTS, - WEB_FETCH_MAX_CONTENT_LENGTH, - WEB_SEARCH_TRUST_ENV, - WEB_SEARCH_DOMAIN_FILTER_LIST, - OLLAMA_CLOUD_WEB_SEARCH_API_KEY, - JINA_API_KEY, - JINA_API_BASE_URL, - SEARCHAPI_API_KEY, - SEARCHAPI_ENGINE, - SERPAPI_API_KEY, - SERPAPI_ENGINE, - SEARXNG_QUERY_URL, - SEARXNG_LANGUAGE, - YACY_QUERY_URL, - YACY_USERNAME, - YACY_PASSWORD, - SERPER_API_KEY, - SERPLY_API_KEY, - DDGS_BACKEND, - SERPSTACK_API_KEY, - SERPSTACK_HTTPS, - TAVILY_API_KEY, - TAVILY_EXTRACT_DEPTH, - BING_SEARCH_V7_ENDPOINT, - BING_SEARCH_V7_SUBSCRIPTION_KEY, - BRAVE_SEARCH_API_KEY, - BRAVE_SEARCH_CONTEXT_TOKENS, - EXA_API_KEY, - PERPLEXITY_API_KEY, - PERPLEXITY_MODEL, - PERPLEXITY_SEARCH_CONTEXT_USAGE, - PERPLEXITY_SEARCH_API_URL, - SOUGOU_API_SID, - SOUGOU_API_SK, - KAGI_SEARCH_API_KEY, - MOJEEK_SEARCH_API_KEY, - BOCHA_SEARCH_API_KEY, - GOOGLE_PSE_API_KEY, - GOOGLE_PSE_ENGINE_ID, - GOOGLE_DRIVE_CLIENT_ID, - GOOGLE_DRIVE_API_KEY, - ENABLE_ONEDRIVE_INTEGRATION, - ONEDRIVE_CLIENT_ID_PERSONAL, - ONEDRIVE_CLIENT_ID_BUSINESS, - ONEDRIVE_SHAREPOINT_URL, - ONEDRIVE_SHAREPOINT_TENANT_ID, - ENABLE_ONEDRIVE_PERSONAL, - ENABLE_ONEDRIVE_BUSINESS, - ENABLE_RAG_HYBRID_SEARCH, - ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS, - ENABLE_RAG_LOCAL_WEB_FETCH, - ENABLE_WEB_LOADER_SSL_VERIFICATION, - ENABLE_GOOGLE_DRIVE_INTEGRATION, - UPLOAD_DIR, - EXTERNAL_WEB_SEARCH_URL, - EXTERNAL_WEB_SEARCH_API_KEY, - EXTERNAL_WEB_LOADER_URL, - EXTERNAL_WEB_LOADER_API_KEY, - YANDEX_WEB_SEARCH_URL, - YANDEX_WEB_SEARCH_API_KEY, - YANDEX_WEB_SEARCH_CONFIG, - YOUCOM_API_KEY, - # WebUI - WEBUI_AUTH, - WEBUI_NAME, - WEBUI_BANNERS, - WEBHOOK_URL, - ADMIN_EMAIL, - SHOW_ADMIN_DETAILS, - JWT_EXPIRES_IN, - ENABLE_SIGNUP, - ENABLE_LOGIN_FORM, - ENABLE_PASSWORD_CHANGE_FORM, - ENABLE_API_KEYS, - ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS, - API_KEYS_ALLOWED_ENDPOINTS, - ENABLE_FOLDERS, - FOLDER_MAX_FILE_COUNT, - ENABLE_AUTOMATIONS, - AUTOMATION_MAX_COUNT, - AUTOMATION_MIN_INTERVAL, - ENABLE_CHANNELS, - ENABLE_CALENDAR, - ENABLE_NOTES, - ENABLE_USER_STATUS, - ENABLE_COMMUNITY_SHARING, - ENABLE_MESSAGE_RATING, - ENABLE_USER_WEBHOOKS, - ENABLE_EVALUATION_ARENA_MODELS, - BYPASS_ADMIN_ACCESS_CONTROL, - USER_PERMISSIONS, - DEFAULT_USER_ROLE, - DEFAULT_GROUP_ID, - PENDING_USER_OVERLAY_CONTENT, - PENDING_USER_OVERLAY_TITLE, - DEFAULT_PROMPT_SUGGESTIONS, - DEFAULT_MODELS, - DEFAULT_PINNED_MODELS, - DEFAULT_ARENA_MODEL, - MODEL_ORDER_LIST, - DEFAULT_MODEL_METADATA, - DEFAULT_MODEL_PARAMS, - EVALUATION_ARENA_MODELS, - # WebUI (OAuth) - ENABLE_OAUTH_ROLE_MANAGEMENT, - OAUTH_SUB_CLAIM, - OAUTH_ROLES_CLAIM, - OAUTH_EMAIL_CLAIM, - OAUTH_PICTURE_CLAIM, - OAUTH_USERNAME_CLAIM, - OAUTH_ALLOWED_ROLES, - OAUTH_ADMIN_ROLES, - # WebUI (LDAP) - ENABLE_LDAP, - LDAP_SERVER_LABEL, - LDAP_SERVER_HOST, - LDAP_SERVER_PORT, - LDAP_ATTRIBUTE_FOR_MAIL, - LDAP_ATTRIBUTE_FOR_USERNAME, - LDAP_SEARCH_FILTERS, - LDAP_SEARCH_BASE, - LDAP_APP_DN, - LDAP_APP_PASSWORD, - LDAP_USE_TLS, - LDAP_CA_CERT_FILE, - LDAP_VALIDATE_CERT, - LDAP_CIPHERS, - # LDAP Group Management - ENABLE_LDAP_GROUP_MANAGEMENT, - ENABLE_LDAP_GROUP_CREATION, - LDAP_ATTRIBUTE_FOR_GROUPS, - # Misc - ENV, - CACHE_DIR, - STATIC_DIR, - FRONTEND_BUILD_DIR, - CORS_ALLOW_ORIGIN, - DEFAULT_LOCALE, - OAUTH_PROVIDERS, - WEBUI_URL, - RESPONSE_WATERMARK, - IFRAME_CSP, - # Admin - ENABLE_ADMIN_CHAT_ACCESS, - ENABLE_ADMIN_ANALYTICS, - BYPASS_ADMIN_ACCESS_CONTROL, - ENABLE_ADMIN_EXPORT, - # Tasks - TASK_MODEL, - TASK_MODEL_EXTERNAL, - ENABLE_TAGS_GENERATION, - ENABLE_TITLE_GENERATION, - ENABLE_FOLLOW_UP_GENERATION, - ENABLE_SEARCH_QUERY_GENERATION, - ENABLE_RETRIEVAL_QUERY_GENERATION, - ENABLE_AUTOCOMPLETE_GENERATION, - TITLE_GENERATION_PROMPT_TEMPLATE, - FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, - TAGS_GENERATION_PROMPT_TEMPLATE, - IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, - TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - VOICE_MODE_PROMPT_TEMPLATE, - ENABLE_VOICE_MODE_PROMPT, - QUERY_GENERATION_PROMPT_TEMPLATE, - AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, - AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH, - AppConfig, - reset_config, - async_reset_config, -) -from open_webui.env import ( - ENABLE_CUSTOM_MODEL_FALLBACK, - LICENSE_KEY, - AUDIT_EXCLUDED_PATHS, - AUDIT_INCLUDED_PATHS, - ENABLE_AUDIT_GET_REQUESTS, - AUDIT_LOG_LEVEL, - CHANGELOG, - REDIS_URL, - REDIS_CLUSTER, - REDIS_KEY_PREFIX, - REDIS_SENTINEL_HOSTS, - REDIS_SENTINEL_PORT, - GLOBAL_LOG_LEVEL, - MAX_BODY_LOG_SIZE, - SAFE_MODE, - VERSION, - DEPLOYMENT_ID, - INSTANCE_ID, - WEBUI_BUILD_HASH, - WEBUI_SECRET_KEY, - WEBUI_SESSION_COOKIE_SAME_SITE, - WEBUI_SESSION_COOKIE_SECURE, - ENABLE_SIGNUP_PASSWORD_CONFIRMATION, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, - WEBUI_AUTH_SIGNOUT_REDIRECT_URL, - # SCIM - ENABLE_SCIM, - SCIM_TOKEN, - ENABLE_COMPRESSION_MIDDLEWARE, - ENABLE_WEBSOCKET_SUPPORT, - BYPASS_MODEL_ACCESS_CONTROL, - RESET_CONFIG_ON_START, - ENABLE_VERSION_UPDATE_CHECK, - ENABLE_OTEL, - EXTERNAL_PWA_MANIFEST_URL, - AIOHTTP_CLIENT_SESSION_SSL, - ENABLE_STAR_SESSIONS_MIDDLEWARE, - ENABLE_PUBLIC_ACTIVE_USERS_COUNT, - # Admin Account Runtime Creation - WEBUI_ADMIN_EMAIL, - WEBUI_ADMIN_PASSWORD, - WEBUI_ADMIN_NAME, - ENABLE_EASTER_EGGS, - LOG_FORMAT, - # OAuth Back-Channel Logout - ENABLE_OAUTH_BACKCHANNEL_LOGOUT, -) - - -from open_webui.utils.models import ( - get_all_models, - get_all_base_models, - check_model_access, - get_filtered_models, +from open_webui.utils.chat import ( + chat_completed as chat_completed_handler, ) from open_webui.utils.chat import ( generate_chat_completion as chat_completion_handler, - chat_completed as chat_completed_handler, ) -from open_webui.utils.actions import chat_action as chat_action_handler from open_webui.utils.embeddings import generate_embeddings +from open_webui.utils.logger import start_logger from open_webui.utils.middleware import ( build_chat_response_context, process_chat_payload, process_chat_response, ) -from open_webui.utils.tools import set_tool_servers, set_terminal_servers - -from open_webui.utils.auth import ( - get_license_data, - get_http_authorization_cred, - decode_token, - get_admin_user, - get_verified_user, - create_admin_user, +from open_webui.utils.models import ( + check_model_access, + get_all_base_models, + get_all_models, + get_filtered_models, ) -from open_webui.utils.plugin import install_tool_and_function_dependencies from open_webui.utils.oauth import ( + OAuthClientInformationFull, + OAuthClientManager, + OAuthManager, + decrypt_data, + encrypt_data, get_oauth_client_info_with_dynamic_client_registration, get_oauth_client_info_with_static_credentials, - encrypt_data, - decrypt_data, resolve_oauth_client_info, - OAuthManager, - OAuthClientManager, - OAuthClientInformationFull, ) +from open_webui.utils.plugin import install_tool_and_function_dependencies +from open_webui.utils.redis import get_redis_connection, get_sentinels_from_env from open_webui.utils.security_headers import SecurityHeadersMiddleware -from open_webui.utils.redis import get_redis_connection - -from open_webui.tasks import ( - redis_task_command_listener, - list_task_ids_by_item_id, - has_active_tasks, - cleanup_task, - create_task, - stop_task, - stop_item_tasks, - list_tasks, -) # Import from tasks.py - -from open_webui.utils.redis import get_sentinels_from_env - - -from open_webui.constants import ERROR_MESSAGES, TASKS +from open_webui.utils.session_pool import get_session +from open_webui.utils.tools import set_terminal_servers, set_tool_servers if SAFE_MODE: print('SAFE MODE ENABLED') diff --git a/backend/open_webui/migrations/env.py b/backend/open_webui/migrations/env.py index 00c9e9569e..a9b9dcacb8 100644 --- a/backend/open_webui/migrations/env.py +++ b/backend/open_webui/migrations/env.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """Alembic environment configuration. Configures the migration context for both offline (SQL script generation) @@ -23,7 +25,7 @@ if config.config_file_name is not None: fileConfig(config.config_file_name, disable_existing_loggers=False) # Re-apply JSON formatter after fileConfig replaces handlers. -if LOG_FORMAT == "json": +if LOG_FORMAT == 'json': from open_webui.env import JSONFormatter for handler in logging.root.handlers: @@ -41,7 +43,7 @@ if _ssl_params: DB_URL = reattach_ssl_params_to_url(_url_no_ssl, _ssl_params) if DB_URL: - config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%")) + config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%')) # ── Migration runners ──────────────────────────────────────────────────────── @@ -49,12 +51,12 @@ if DB_URL: def run_migrations_offline() -> None: """Generate SQL script without a live database connection.""" - url = config.get_main_option("sqlalchemy.url") + url = config.get_main_option('sqlalchemy.url') context.configure( url=url, target_metadata=target_metadata, literal_binds=True, - dialect_opts={"paramstyle": "named"}, + dialect_opts={'paramstyle': 'named'}, ) with context.begin_transaction(): context.run_migrations() @@ -62,14 +64,12 @@ def run_migrations_offline() -> None: def _build_connectable(): """Create the appropriate SQLAlchemy engine for the configured DB URL.""" - if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"): - if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "": - raise ValueError( - "DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs" - ) + if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'): + if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '': + raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs') - db_path = DB_URL.replace("sqlite+sqlcipher://", "") - if db_path.startswith("/"): + db_path = DB_URL.replace('sqlite+sqlcipher://', '') + if db_path.startswith('/'): db_path = db_path[1:] def _sqlcipher_creator(): @@ -79,11 +79,11 @@ def _build_connectable(): conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'") return conn - return create_engine("sqlite://", creator=_sqlcipher_creator, echo=False) + return create_engine('sqlite://', creator=_sqlcipher_creator, echo=False) return engine_from_config( config.get_section(config.config_ini_section, {}), - prefix="sqlalchemy.", + prefix='sqlalchemy.', poolclass=pool.NullPool, ) diff --git a/backend/open_webui/migrations/util.py b/backend/open_webui/migrations/util.py index 807baad4ba..606bdc9479 100644 --- a/backend/open_webui/migrations/util.py +++ b/backend/open_webui/migrations/util.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """Alembic migration utilities.""" from alembic import op diff --git a/backend/open_webui/migrations/versions/018012973d35_add_indexes.py b/backend/open_webui/migrations/versions/018012973d35_add_indexes.py index c5016e1a8b..b9f124f775 100644 --- a/backend/open_webui/migrations/versions/018012973d35_add_indexes.py +++ b/backend/open_webui/migrations/versions/018012973d35_add_indexes.py @@ -6,8 +6,8 @@ Create Date: 2025-08-13 03:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = '018012973d35' down_revision = 'd31026856c01' diff --git a/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py index caffb7e3b4..7f7b901765 100644 --- a/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py +++ b/backend/open_webui/migrations/versions/1af9b942657b_migrate_tags.py @@ -6,13 +6,13 @@ Create Date: 2024-10-09 21:02:35.241684 """ -from alembic import op -import sqlalchemy as sa -from sqlalchemy.sql import table, select, update, column -from sqlalchemy.engine.reflection import Inspector - import json +import sqlalchemy as sa +from alembic import op +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.sql import column, select, table, update + revision = '1af9b942657b' down_revision = '242a2047eae0' branch_labels = None diff --git a/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py index 7fadb05a92..6f9be66342 100644 --- a/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py +++ b/backend/open_webui/migrations/versions/242a2047eae0_update_chat_table.py @@ -6,12 +6,12 @@ Create Date: 2024-10-09 21:02:35.241684 """ -from alembic import op -import sqlalchemy as sa -from sqlalchemy.sql import table, select, update - import json +import sqlalchemy as sa +from alembic import op +from sqlalchemy.sql import select, table, update + revision = '242a2047eae0' down_revision = '6a39f3d8e55c' branch_labels = None diff --git a/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py b/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py index 51a8e329f1..e64f8ee91f 100644 --- a/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py +++ b/backend/open_webui/migrations/versions/2f1211949ecc_update_message_and_channel_member_table.py @@ -8,9 +8,9 @@ Create Date: 2025-11-27 03:07:56.200231 from typing import Sequence, Union -from alembic import op -import sqlalchemy as sa import open_webui.internal.db +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '2f1211949ecc' diff --git a/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py b/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py index c412107032..0a3521077e 100644 --- a/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py +++ b/backend/open_webui/migrations/versions/374d2f66af06_add_prompt_history_table.py @@ -6,11 +6,11 @@ Create Date: 2026-01-23 17:15:00.000000 """ -from typing import Sequence, Union import uuid +from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op revision: str = '374d2f66af06' down_revision: Union[str, None] = 'c440947495f3' diff --git a/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py b/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py index 170137f23c..a3cc5b9f90 100644 --- a/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py +++ b/backend/open_webui/migrations/versions/3781e22d8b01_update_message_table.py @@ -6,8 +6,8 @@ Create Date: 2024-12-30 03:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = '3781e22d8b01' down_revision = '7826ab40b532' diff --git a/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py b/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py index 4bf24d3b46..c92773dab7 100644 --- a/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py +++ b/backend/open_webui/migrations/versions/37f288994c47_add_group_member_table.py @@ -6,13 +6,13 @@ Create Date: 2025-11-17 03:45:25.123939 """ -import uuid -import time import json +import time +import uuid from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '37f288994c47' diff --git a/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py index d415f500f3..d5442026e3 100644 --- a/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py +++ b/backend/open_webui/migrations/versions/38d63c18f30f_add_oauth_session_table.py @@ -8,8 +8,8 @@ Create Date: 2025-09-08 14:19:59.583921 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '38d63c18f30f' diff --git a/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py b/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py index 31bd355ede..8aaa4d4d47 100644 --- a/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py +++ b/backend/open_webui/migrations/versions/3ab32c4b8f59_update_tags.py @@ -6,13 +6,13 @@ Create Date: 2024-10-09 21:02:35.241684 """ -from alembic import op -import sqlalchemy as sa -from sqlalchemy.sql import table, select, update, column -from sqlalchemy.engine.reflection import Inspector - import json +import sqlalchemy as sa +from alembic import op +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.sql import column, select, table, update + revision = '3ab32c4b8f59' down_revision = '1af9b942657b' branch_labels = None diff --git a/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py b/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py index 629c1c8c24..b678b5393b 100644 --- a/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py +++ b/backend/open_webui/migrations/versions/3af16a1c9fb6_update_user_table.py @@ -8,8 +8,8 @@ Create Date: 2025-08-21 02:07:18.078283 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '3af16a1c9fb6' diff --git a/backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py b/backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py index f772987a44..da60344cdd 100644 --- a/backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py +++ b/backend/open_webui/migrations/versions/3e0e00844bb0_add_knowledge_file_table.py @@ -6,16 +6,15 @@ Create Date: 2025-12-02 06:54:19.401334 """ +import json +import time +import uuid from typing import Sequence, Union -from alembic import op -import sqlalchemy as sa -from sqlalchemy import inspect import open_webui.internal.db - -import time -import json -import uuid +import sqlalchemy as sa +from alembic import op +from sqlalchemy import inspect # revision identifiers, used by Alembic. revision: str = '3e0e00844bb0' diff --git a/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py b/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py index 91e0dce0be..9088cb3772 100644 --- a/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py +++ b/backend/open_webui/migrations/versions/4ace53fd72c8_update_folder_table_datetime.py @@ -6,8 +6,8 @@ Create Date: 2024-10-23 03:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = '4ace53fd72c8' down_revision = 'af906e964978' diff --git a/backend/open_webui/migrations/versions/4de81c2a3af1_add_pinned_note_table.py b/backend/open_webui/migrations/versions/4de81c2a3af1_add_pinned_note_table.py index 858c9b1541..eaaaaed18a 100644 --- a/backend/open_webui/migrations/versions/4de81c2a3af1_add_pinned_note_table.py +++ b/backend/open_webui/migrations/versions/4de81c2a3af1_add_pinned_note_table.py @@ -8,10 +8,9 @@ Create Date: 2026-05-09 04:29:27.651341 from typing import Sequence, Union -from alembic import op -import sqlalchemy as sa import open_webui.internal.db - +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '4de81c2a3af1' @@ -20,10 +19,11 @@ branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None -import uuid import time -from sqlalchemy import select, update, insert -from sqlalchemy.sql import table, column +import uuid + +from sqlalchemy import insert, select, update +from sqlalchemy.sql import column, table def upgrade() -> None: diff --git a/backend/open_webui/migrations/versions/56359461a091_add_calendar_tables.py b/backend/open_webui/migrations/versions/56359461a091_add_calendar_tables.py index e556440f56..648d11b880 100644 --- a/backend/open_webui/migrations/versions/56359461a091_add_calendar_tables.py +++ b/backend/open_webui/migrations/versions/56359461a091_add_calendar_tables.py @@ -8,9 +8,8 @@ Create Date: 2026-04-19 16:20:58.162045 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision: str = '56359461a091' diff --git a/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py b/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py index 79f0e8827e..fc08c60861 100644 --- a/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py +++ b/backend/open_webui/migrations/versions/57c599a3cb57_add_channel_table.py @@ -6,8 +6,8 @@ Create Date: 2024-12-22 03:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = '57c599a3cb57' down_revision = '922e7a387820' diff --git a/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py b/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py index 2bd2d9fd60..b4d9d4c534 100644 --- a/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py +++ b/backend/open_webui/migrations/versions/6283dc0e4d8d_add_channel_file_table.py @@ -8,9 +8,9 @@ Create Date: 2025-12-10 15:11:39.424601 from typing import Sequence, Union -from alembic import op -import sqlalchemy as sa import open_webui.internal.db +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '6283dc0e4d8d' diff --git a/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py b/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py index c65ca01415..470d8bf399 100644 --- a/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py +++ b/backend/open_webui/migrations/versions/6a39f3d8e55c_add_knowledge_table.py @@ -6,11 +6,12 @@ Create Date: 2024-10-01 14:02:35.241684 """ -from alembic import op -import sqlalchemy as sa -from sqlalchemy.sql import table, column, select import json +import sqlalchemy as sa +from alembic import op +from sqlalchemy.sql import column, select, table + revision = '6a39f3d8e55c' down_revision = 'c0fbf31ca0db' branch_labels = None diff --git a/backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py b/backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py index 4211c6642e..81d31248b5 100644 --- a/backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py +++ b/backend/open_webui/migrations/versions/7826ab40b532_update_file_table.py @@ -6,8 +6,8 @@ Create Date: 2024-12-23 03:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = '7826ab40b532' down_revision = '57c599a3cb57' diff --git a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py index d93c007671..6b96d9037d 100644 --- a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """Initial Alembic schema — creates all base tables. Revision ID: 7e5b5dc7342b @@ -7,14 +9,13 @@ Create Date: 2024-06-24 13:15:33.808998 from typing import Sequence, Union +import open_webui.internal.db # noqa: F401 import sqlalchemy as sa from alembic import op - -import open_webui.internal.db # noqa: F401 from open_webui.internal.db import JSONField from open_webui.migrations.util import get_existing_tables -revision: str = "7e5b5dc7342b" +revision: str = '7e5b5dc7342b' down_revision: Union[str, None] = None branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -25,162 +26,162 @@ depends_on: Union[str, Sequence[str], None] = None _TABLES: list[tuple[str, list[sa.Column], list]] = [ ( - "auth", + 'auth', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("email", sa.String(), nullable=True), - sa.Column("password", sa.Text(), nullable=True), - sa.Column("active", sa.Boolean(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('email', sa.String(), nullable=True), + sa.Column('password', sa.Text(), nullable=True), + sa.Column('active', sa.Boolean(), nullable=True), ], - [sa.PrimaryKeyConstraint("id")], + [sa.PrimaryKeyConstraint('id')], ), ( - "chat", + 'chat', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("chat", sa.Text(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("share_id", sa.Text(), nullable=True), - sa.Column("archived", sa.Boolean(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('chat', sa.Text(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('share_id', sa.Text(), nullable=True), + sa.Column('archived', sa.Boolean(), nullable=True), ], - [sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("share_id")], + [sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('share_id')], ), ( - "chatidtag", + 'chatidtag', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("tag_name", sa.String(), nullable=True), - sa.Column("chat_id", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('tag_name', sa.String(), nullable=True), + sa.Column('chat_id', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), ], - [sa.PrimaryKeyConstraint("id")], + [sa.PrimaryKeyConstraint('id')], ), ( - "document", + 'document', [ - sa.Column("collection_name", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("filename", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.Column('collection_name', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('filename', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), ], - [sa.PrimaryKeyConstraint("collection_name"), sa.UniqueConstraint("name")], + [sa.PrimaryKeyConstraint('collection_name'), sa.UniqueConstraint('name')], ), ( - "file", + 'file', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("filename", sa.Text(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('filename', sa.Text(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), ], - [sa.PrimaryKeyConstraint("id")], + [sa.PrimaryKeyConstraint('id')], ), ( - "function", + 'function', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("type", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("valves", JSONField(), nullable=True), - sa.Column("is_active", sa.Boolean(), nullable=True), - sa.Column("is_global", sa.Boolean(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('type', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('valves', JSONField(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=True), + sa.Column('is_global', sa.Boolean(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), ], - [sa.PrimaryKeyConstraint("id")], + [sa.PrimaryKeyConstraint('id')], ), ( - "memory", + 'memory', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), ], - [sa.PrimaryKeyConstraint("id")], + [sa.PrimaryKeyConstraint('id')], ), ( - "model", + 'model', [ - sa.Column("id", sa.Text(), nullable=False), - sa.Column("user_id", sa.Text(), nullable=True), - sa.Column("base_model_id", sa.Text(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("params", JSONField(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column('id', sa.Text(), nullable=False), + sa.Column('user_id', sa.Text(), nullable=True), + sa.Column('base_model_id', sa.Text(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('params', JSONField(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), ], - [sa.PrimaryKeyConstraint("id")], + [sa.PrimaryKeyConstraint('id')], ), ( - "prompt", + 'prompt', [ - sa.Column("command", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("title", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("timestamp", sa.BigInteger(), nullable=True), + sa.Column('command', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('title', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('timestamp', sa.BigInteger(), nullable=True), ], - [sa.PrimaryKeyConstraint("command")], + [sa.PrimaryKeyConstraint('command')], ), ( - "tag", + 'tag', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("data", sa.Text(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('data', sa.Text(), nullable=True), ], - [sa.PrimaryKeyConstraint("id")], + [sa.PrimaryKeyConstraint('id')], ), ( - "tool", + 'tool', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("user_id", sa.String(), nullable=True), - sa.Column("name", sa.Text(), nullable=True), - sa.Column("content", sa.Text(), nullable=True), - sa.Column("specs", JSONField(), nullable=True), - sa.Column("meta", JSONField(), nullable=True), - sa.Column("valves", JSONField(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('user_id', sa.String(), nullable=True), + sa.Column('name', sa.Text(), nullable=True), + sa.Column('content', sa.Text(), nullable=True), + sa.Column('specs', JSONField(), nullable=True), + sa.Column('meta', JSONField(), nullable=True), + sa.Column('valves', JSONField(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), ], - [sa.PrimaryKeyConstraint("id")], + [sa.PrimaryKeyConstraint('id')], ), ( - "user", + 'user', [ - sa.Column("id", sa.String(), nullable=False), - sa.Column("name", sa.String(), nullable=True), - sa.Column("email", sa.String(), nullable=True), - sa.Column("role", sa.String(), nullable=True), - sa.Column("profile_image_url", sa.Text(), nullable=True), - sa.Column("last_active_at", sa.BigInteger(), nullable=True), - sa.Column("updated_at", sa.BigInteger(), nullable=True), - sa.Column("created_at", sa.BigInteger(), nullable=True), - sa.Column("api_key", sa.String(), nullable=True), - sa.Column("settings", JSONField(), nullable=True), - sa.Column("info", JSONField(), nullable=True), - sa.Column("oauth_sub", sa.Text(), nullable=True), + sa.Column('id', sa.String(), nullable=False), + sa.Column('name', sa.String(), nullable=True), + sa.Column('email', sa.String(), nullable=True), + sa.Column('role', sa.String(), nullable=True), + sa.Column('profile_image_url', sa.Text(), nullable=True), + sa.Column('last_active_at', sa.BigInteger(), nullable=True), + sa.Column('updated_at', sa.BigInteger(), nullable=True), + sa.Column('created_at', sa.BigInteger(), nullable=True), + sa.Column('api_key', sa.String(), nullable=True), + sa.Column('settings', JSONField(), nullable=True), + sa.Column('info', JSONField(), nullable=True), + sa.Column('oauth_sub', sa.Text(), nullable=True), ], [ - sa.PrimaryKeyConstraint("id"), - sa.UniqueConstraint("api_key"), - sa.UniqueConstraint("oauth_sub"), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('api_key'), + sa.UniqueConstraint('oauth_sub'), ], ), ] diff --git a/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py b/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py index e45a2443df..e069b2d1d6 100644 --- a/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py +++ b/backend/open_webui/migrations/versions/81cc2ce44d79_update_channel_file_and_knowledge_table.py @@ -8,9 +8,9 @@ Create Date: 2025-12-10 16:07:58.001282 from typing import Sequence, Union -from alembic import op -import sqlalchemy as sa import open_webui.internal.db +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '81cc2ce44d79' diff --git a/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py b/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py index 3254b57858..d37ec0016e 100644 --- a/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py +++ b/backend/open_webui/migrations/versions/8452d01d26d7_add_chat_message_table.py @@ -6,13 +6,13 @@ Create Date: 2026-02-01 04:00:00.000000 """ -import time import json import logging +import time from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op log = logging.getLogger(__name__) diff --git a/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py b/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py index 9d115b1e5c..baad674612 100644 --- a/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py +++ b/backend/open_webui/migrations/versions/90ef40d4714e_update_channel_and_channel_members_table.py @@ -8,9 +8,9 @@ Create Date: 2025-11-30 06:33:38.790341 from typing import Sequence, Union -from alembic import op -import sqlalchemy as sa import open_webui.internal.db +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = '90ef40d4714e' diff --git a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py index 5e617be1e6..1f035f5d4f 100644 --- a/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py +++ b/backend/open_webui/migrations/versions/922e7a387820_add_group_table.py @@ -6,8 +6,8 @@ Create Date: 2024-11-14 03:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = '922e7a387820' down_revision = '4ace53fd72c8' diff --git a/backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py b/backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py index c75db04ca5..555224566b 100644 --- a/backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py +++ b/backend/open_webui/migrations/versions/9f0c9cd09105_add_note_table.py @@ -6,8 +6,8 @@ Create Date: 2025-05-03 03:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = '9f0c9cd09105' down_revision = '3781e22d8b01' diff --git a/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py b/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py index f11f7d8d1b..1ffce041e1 100644 --- a/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py +++ b/backend/open_webui/migrations/versions/a1b2c3d4e5f6_add_skill_table.py @@ -8,9 +8,8 @@ Create Date: 2026-02-11 09:30:00.000000 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op from open_webui.migrations.util import get_existing_tables revision: str = 'a1b2c3d4e5f6' diff --git a/backend/open_webui/migrations/versions/a3dd5bedd151_add_tasks_and_summary_to_chat.py b/backend/open_webui/migrations/versions/a3dd5bedd151_add_tasks_and_summary_to_chat.py index 20a3152cfe..088277cd1a 100644 --- a/backend/open_webui/migrations/versions/a3dd5bedd151_add_tasks_and_summary_to_chat.py +++ b/backend/open_webui/migrations/versions/a3dd5bedd151_add_tasks_and_summary_to_chat.py @@ -8,8 +8,8 @@ Create Date: 2026-03-29 22:15:00.000000 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = 'a3dd5bedd151' diff --git a/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py index 29157baa07..d54ae24edb 100644 --- a/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py +++ b/backend/open_webui/migrations/versions/a5c220713937_add_reply_to_id_column_to_message.py @@ -8,8 +8,8 @@ Create Date: 2025-09-27 02:24:18.058455 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = 'a5c220713937' diff --git a/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py b/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py index 4d8fd63e80..5c7b33a452 100644 --- a/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py +++ b/backend/open_webui/migrations/versions/af906e964978_add_feedback_table.py @@ -6,8 +6,8 @@ Create Date: 2024-10-20 17:02:35.241684 """ -from alembic import op import sqlalchemy as sa +from alembic import op # Revision identifiers, used by Alembic. revision = 'af906e964978' diff --git a/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py b/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py index 623289d885..556b3a8eac 100644 --- a/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py +++ b/backend/open_webui/migrations/versions/b10670c03dd5_update_user_table.py @@ -6,15 +6,13 @@ Create Date: 2025-11-28 04:55:31.737538 """ -from typing import Sequence, Union - -from alembic import op -import sqlalchemy as sa - - -import open_webui.internal.db import json import time +from typing import Sequence, Union + +import open_webui.internal.db +import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = 'b10670c03dd5' diff --git a/backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py b/backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py index e3668d3b6e..8cb62af600 100644 --- a/backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py +++ b/backend/open_webui/migrations/versions/b2c3d4e5f6a7_add_scim_column_to_user_table.py @@ -8,8 +8,8 @@ Create Date: 2026-02-13 14:19:00.000000 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = 'b2c3d4e5f6a7' diff --git a/backend/open_webui/migrations/versions/b7c8d9e0f1a2_add_last_read_at_to_chat.py b/backend/open_webui/migrations/versions/b7c8d9e0f1a2_add_last_read_at_to_chat.py index fb254432f6..bacd4cf90f 100644 --- a/backend/open_webui/migrations/versions/b7c8d9e0f1a2_add_last_read_at_to_chat.py +++ b/backend/open_webui/migrations/versions/b7c8d9e0f1a2_add_last_read_at_to_chat.py @@ -6,9 +6,8 @@ Create Date: 2026-04-01 04:00:00.000000 """ -from alembic import op import sqlalchemy as sa - +from alembic import op # revision identifiers, used by Alembic. revision = 'b7c8d9e0f1a2' diff --git a/backend/open_webui/migrations/versions/c1d2e3f4a5b6_add_shared_chat_table.py b/backend/open_webui/migrations/versions/c1d2e3f4a5b6_add_shared_chat_table.py index 2451f50ae2..8f4fec8deb 100644 --- a/backend/open_webui/migrations/versions/c1d2e3f4a5b6_add_shared_chat_table.py +++ b/backend/open_webui/migrations/versions/c1d2e3f4a5b6_add_shared_chat_table.py @@ -9,8 +9,8 @@ Create Date: 2026-04-16 23:00:00.000000 import time import uuid -from alembic import op import sqlalchemy as sa +from alembic import op revision = 'c1d2e3f4a5b6' down_revision = 'e1f2a3b4c5d6' diff --git a/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py b/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py index 37fe63ef15..9836708908 100644 --- a/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py +++ b/backend/open_webui/migrations/versions/c29facfe716b_update_file_table_path.py @@ -6,11 +6,12 @@ Create Date: 2024-10-20 17:02:35.241684 """ -from alembic import op -import sqlalchemy as sa import json -from sqlalchemy.sql import table, column -from sqlalchemy import String, Text, JSON, and_ + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import JSON, String, Text, and_ +from sqlalchemy.sql import column, table revision = 'c29facfe716b' down_revision = 'c69f45358db4' diff --git a/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py b/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py index 0eae928b91..2b88ece98e 100644 --- a/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py +++ b/backend/open_webui/migrations/versions/c440947495f3_add_chat_file_table.py @@ -8,8 +8,8 @@ Create Date: 2025-12-21 20:27:41.694897 from typing import Sequence, Union -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. revision: str = 'c440947495f3' diff --git a/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py b/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py index c9572fe7a3..2440b395d6 100644 --- a/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py +++ b/backend/open_webui/migrations/versions/c69f45358db4_add_folder_table.py @@ -6,8 +6,8 @@ Create Date: 2024-10-16 02:02:35.241684 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = 'c69f45358db4' down_revision = '3ab32c4b8f59' diff --git a/backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py b/backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py index 9f6a2541ea..93eecc974e 100644 --- a/backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py +++ b/backend/open_webui/migrations/versions/ca81bd47c050_add_config_table.py @@ -10,8 +10,8 @@ from typing import Sequence, Union import sqlalchemy as sa from alembic import op -revision: str = "ca81bd47c050" -down_revision: Union[str, None] = "7e5b5dc7342b" +revision: str = 'ca81bd47c050' +down_revision: Union[str, None] = '7e5b5dc7342b' branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None @@ -19,18 +19,18 @@ depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: """Create a key-value config table with versioning.""" op.create_table( - "config", - sa.Column("id", sa.Integer, primary_key=True), - sa.Column("data", sa.JSON(), nullable=False), - sa.Column("version", sa.Integer, nullable=False), + 'config', + sa.Column('id', sa.Integer, primary_key=True), + sa.Column('data', sa.JSON(), nullable=False), + sa.Column('version', sa.Integer, nullable=False), sa.Column( - "created_at", + 'created_at', sa.DateTime(), nullable=False, server_default=sa.func.now(), ), sa.Column( - "updated_at", + 'updated_at', sa.DateTime(), nullable=True, server_default=sa.func.now(), @@ -41,4 +41,4 @@ def upgrade() -> None: def downgrade() -> None: """Drop the config table.""" - op.drop_table("config") + op.drop_table('config') diff --git a/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py index 444e131db7..c30cd62aac 100644 --- a/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py +++ b/backend/open_webui/migrations/versions/d31026856c01_update_folder_table_data.py @@ -6,8 +6,8 @@ Create Date: 2025-07-13 03:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = 'd31026856c01' down_revision = '9f0c9cd09105' diff --git a/backend/open_webui/migrations/versions/d4e5f6a7b8c9_add_automation_tables.py b/backend/open_webui/migrations/versions/d4e5f6a7b8c9_add_automation_tables.py index fc90dc417f..c6f3ccb593 100644 --- a/backend/open_webui/migrations/versions/d4e5f6a7b8c9_add_automation_tables.py +++ b/backend/open_webui/migrations/versions/d4e5f6a7b8c9_add_automation_tables.py @@ -7,8 +7,8 @@ Create Date: 2026-03-30 from typing import Union -from alembic import op import sqlalchemy as sa +from alembic import op revision: str = 'd4e5f6a7b8c9' down_revision: Union[str, None] = 'a3dd5bedd151' diff --git a/backend/open_webui/migrations/versions/e1f2a3b4c5d6_add_is_pinned_to_note.py b/backend/open_webui/migrations/versions/e1f2a3b4c5d6_add_is_pinned_to_note.py index 0d80558746..e086a333da 100644 --- a/backend/open_webui/migrations/versions/e1f2a3b4c5d6_add_is_pinned_to_note.py +++ b/backend/open_webui/migrations/versions/e1f2a3b4c5d6_add_is_pinned_to_note.py @@ -6,8 +6,8 @@ Create Date: 2026-04-14 22:00:00.000000 """ -from alembic import op import sqlalchemy as sa +from alembic import op revision = 'e1f2a3b4c5d6' down_revision = 'b7c8d9e0f1a2' diff --git a/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py b/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py index 5ed572cf7a..5ed0b543e4 100644 --- a/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py +++ b/backend/open_webui/migrations/versions/f1e2d3c4b5a6_add_access_grant_table.py @@ -11,13 +11,12 @@ Access control semantics: - {read: {...}, write: {...}}: Custom permissions -> insert specific grants """ -from typing import Sequence, Union import time import uuid +from typing import Sequence, Union -from alembic import op import sqlalchemy as sa - +from alembic import op from open_webui.migrations.util import get_existing_tables revision: str = 'f1e2d3c4b5a6' diff --git a/backend/open_webui/models/access_grants.py b/backend/open_webui/models/access_grants.py index f031495912..1c86dc08e7 100644 --- a/backend/open_webui/models/access_grants.py +++ b/backend/open_webui/models/access_grants.py @@ -3,13 +3,11 @@ import time import uuid from typing import Optional -from sqlalchemy import select, delete -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, get_async_db_context - from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, Text, UniqueConstraint, or_, and_ +from sqlalchemy import BigInteger, Column, Text, UniqueConstraint, and_, delete, or_, select from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -623,8 +621,8 @@ class AccessGrantsTable: Get all users who have the specified permission on a resource. Returns a list of UserModel instances. """ - from open_webui.models.users import Users, UserModel from open_webui.models.groups import Groups + from open_webui.models.users import UserModel, Users async with get_async_db_context(db) as db: result = await db.execute( diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 2c8c6ba99f..253c662029 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -1,14 +1,15 @@ +from __future__ import annotations + import logging import uuid from typing import Optional -from sqlalchemy import select, delete, update -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.users import User, UserModel, UserProfileImageResponse, Users from open_webui.utils.validate import validate_profile_image_url from pydantic import BaseModel, field_validator -from sqlalchemy import Boolean, Column, String, Text +from sqlalchemy import Boolean, Column, String, Text, delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Token(BaseModel): class ApiKey(BaseModel): - api_key: Optional[str] = None + api_key: str | None = None class SigninResponse(Token, UserProfileImageResponse): @@ -74,18 +75,18 @@ class SignupForm(BaseModel): name: str email: str password: str - profile_image_url: Optional[str] = '/user.png' + profile_image_url: str | None = '/user.png' @field_validator('profile_image_url') @classmethod - def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]: + def check_profile_image_url(cls, v: str | None) -> str | None: if v is not None: return validate_profile_image_url(v) return v class AddUserForm(SignupForm): - role: Optional[str] = 'pending' + role: str | None = 'pending' class AuthsTable: @@ -96,9 +97,9 @@ class AuthsTable: name: str, profile_image_url: str = '/user.png', role: str = 'pending', - oauth: Optional[dict] = None, - db: Optional[AsyncSession] = None, - ) -> Optional[UserModel]: + oauth: dict | None = None, + db: AsyncSession | None = None, + ) -> UserModel | None: async with get_async_db_context(db) as db: log.info('insert_new_auth') @@ -119,8 +120,8 @@ class AuthsTable: return None async def authenticate_user( - self, email: str, verify_password: callable, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + self, email: str, verify_password: callable, db: AsyncSession | None = None + ) -> UserModel | None: log.info(f'authenticate_user: {email}') user = await Users.get_user_by_email(email, db=db) @@ -141,9 +142,7 @@ class AuthsTable: except Exception: return None - async def authenticate_user_by_api_key( - self, api_key: str, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + async def authenticate_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None: log.info(f'authenticate_user_by_api_key') # if no api_key, return None if not api_key: @@ -155,7 +154,7 @@ class AuthsTable: except Exception: return False - async def authenticate_user_by_email(self, email: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def authenticate_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None: log.info(f'authenticate_user_by_email: {email}') try: async with get_async_db_context(db) as db: @@ -171,7 +170,7 @@ class AuthsTable: except Exception: return None - async def update_user_password_by_id(self, id: str, new_password: str, db: Optional[AsyncSession] = None) -> bool: + async def update_user_password_by_id(self, id: str, new_password: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: result = await db.execute(update(Auth).filter_by(id=id).values(password=new_password)) @@ -180,7 +179,7 @@ class AuthsTable: except Exception: return False - async def update_email_by_id(self, id: str, email: str, db: Optional[AsyncSession] = None) -> bool: + async def update_email_by_id(self, id: str, email: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: result = await db.execute(update(Auth).filter_by(id=id).values(email=email)) @@ -192,7 +191,7 @@ class AuthsTable: except Exception: return False - async def delete_auth_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_auth_by_id(self, id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: # Delete User diff --git a/backend/open_webui/models/automations.py b/backend/open_webui/models/automations.py index 05f449ad13..4038a3bdbe 100644 --- a/backend/open_webui/models/automations.py +++ b/backend/open_webui/models/automations.py @@ -1,13 +1,12 @@ -import time import logging +import time from typing import Optional from uuid import uuid4 -from pydantic import BaseModel, ConfigDict -from sqlalchemy import Column, Text, JSON, Boolean, BigInteger, Index, select, or_, func, cast, String, delete, update -from sqlalchemy.ext.asyncio import AsyncSession - from open_webui.internal.db import Base, get_async_db_context +from pydantic import BaseModel, ConfigDict +from sqlalchemy import JSON, BigInteger, Boolean, Column, Index, String, Text, cast, delete, func, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/calendar.py b/backend/open_webui/models/calendar.py index 48b28d8e9a..f2f9ed464c 100644 --- a/backend/open_webui/models/calendar.py +++ b/backend/open_webui/models/calendar.py @@ -1,30 +1,29 @@ -import time import logging +import time from typing import Optional from uuid import uuid4 -from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import ( - Column, - Text, - JSON, - Boolean, - BigInteger, - Index, - UniqueConstraint, - select, - or_, - exists, - func, - delete, - update, -) -from sqlalchemy.ext.asyncio import AsyncSession - from open_webui.internal.db import Base, get_async_db_context from open_webui.models.access_grants import AccessGrantModel, AccessGrants from open_webui.models.groups import Groups from open_webui.models.users import User, UserModel, UserResponse +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import ( + JSON, + BigInteger, + Boolean, + Column, + Index, + Text, + UniqueConstraint, + delete, + exists, + func, + or_, + select, + update, +) +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index adeaeaf9da..9d5f130355 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -4,31 +4,33 @@ import time import uuid from typing import Optional -from open_webui.utils.validate import validate_profile_image_url - -from sqlalchemy import select, delete, update, func, case, or_, and_ -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context -from open_webui.models.groups import Groups from open_webui.models.access_grants import ( AccessGrantModel, AccessGrants, ) - +from open_webui.models.groups import Groups +from open_webui.utils.validate import validate_profile_image_url from pydantic import BaseModel, ConfigDict, Field, field_validator -from sqlalchemy.dialects.postgresql import JSONB - - from sqlalchemy import ( + JSON, BigInteger, Boolean, Column, ForeignKey, String, Text, - JSON, UniqueConstraint, + and_, + case, + delete, + func, + or_, + select, + update, ) +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.ext.asyncio import AsyncSession #################### # Channel DB Schema diff --git a/backend/open_webui/models/chat_messages.py b/backend/open_webui/models/chat_messages.py index a7d875c9dc..b33be1f7eb 100644 --- a/backend/open_webui/models/chat_messages.py +++ b/backend/open_webui/models/chat_messages.py @@ -3,21 +3,24 @@ import time import uuid from typing import Any, Optional -from sqlalchemy import select, delete, func, cast, Integer -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, get_async_db_context from open_webui.utils.response import normalize_usage - from pydantic import BaseModel, ConfigDict from sqlalchemy import ( + JSON, BigInteger, Boolean, Column, ForeignKey, - Text, - JSON, Index, + Integer, + Text, + cast, + delete, + func, + select, ) +from sqlalchemy.ext.asyncio import AsyncSession #################### # Helpers @@ -579,6 +582,7 @@ class ChatMessageTable: """Get message counts grouped by day and model.""" async with get_async_db_context(db) as db: from datetime import datetime, timedelta + from open_webui.models.groups import GroupMember stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter( diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 957492d817..954d8640f0 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -1,32 +1,39 @@ -import logging +from __future__ import annotations + import json +import logging import time import uuid from typing import Optional -from sqlalchemy import select, delete, update, func, or_, and_, text -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql import exists -from sqlalchemy.sql.expression import bindparam from open_webui.internal.db import Base, JSONField, get_async_db_context -from open_webui.models.tags import TagModel, Tag, Tags -from open_webui.models.folders import Folders -from open_webui.models.chat_messages import ChatMessage, ChatMessages from open_webui.models.automations import AutomationRun +from open_webui.models.chat_messages import ChatMessage, ChatMessages +from open_webui.models.folders import Folders +from open_webui.models.tags import Tag, TagModel, Tags from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db - from pydantic import BaseModel, ConfigDict from sqlalchemy import ( + JSON, BigInteger, Boolean, Column, ForeignKey, + Index, String, Text, - JSON, - Index, UniqueConstraint, + and_, + delete, + func, + or_, + select, + text, + update, ) +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.sql import exists +from sqlalchemy.sql.expression import bindparam #################### # Chat DB Schema @@ -81,17 +88,17 @@ class ChatModel(BaseModel): created_at: int # timestamp in epoch updated_at: int # timestamp in epoch - share_id: Optional[str] = None + share_id: str | None = None archived: bool = False - pinned: Optional[bool] = False + pinned: bool | None = False meta: dict = {} - folder_id: Optional[str] = None + folder_id: str | None = None - tasks: Optional[list] = None - summary: Optional[str] = None + tasks: list | None = None + summary: str | None = None - last_read_at: Optional[int] = None + last_read_at: int | None = None class ChatFile(Base): @@ -115,7 +122,7 @@ class ChatFileModel(BaseModel): user_id: str chat_id: str - message_id: Optional[str] = None + message_id: str | None = None file_id: str created_at: int @@ -131,14 +138,14 @@ class ChatFileModel(BaseModel): class ChatForm(BaseModel): chat: dict - folder_id: Optional[str] = None + folder_id: str | None = None class ChatImportForm(ChatForm): - meta: Optional[dict] = {} - pinned: Optional[bool] = False - created_at: Optional[int] = None - updated_at: Optional[int] = None + meta: dict | None = {} + pinned: bool | None = False + created_at: int | None = None + updated_at: int | None = None class ChatsImportForm(BaseModel): @@ -161,14 +168,14 @@ class ChatResponse(BaseModel): chat: dict updated_at: int # timestamp in epoch created_at: int # timestamp in epoch - share_id: Optional[str] = None # id of the chat to be shared + share_id: str | None = None # id of the chat to be shared archived: bool - pinned: Optional[bool] = False + pinned: bool | None = False meta: dict = {} - folder_id: Optional[str] = None + folder_id: str | None = None - tasks: Optional[list] = None - summary: Optional[str] = None + tasks: list | None = None + summary: str | None = None class ChatTitleIdResponse(BaseModel): @@ -176,13 +183,13 @@ class ChatTitleIdResponse(BaseModel): title: str updated_at: int created_at: int - last_read_at: Optional[int] = None + last_read_at: int | None = None class SharedChatResponse(BaseModel): id: str title: str - share_id: Optional[str] = None + share_id: str | None = None updated_at: int created_at: int @@ -225,17 +232,17 @@ class ChatUsageStatsListResponse(BaseModel): class MessageStats(BaseModel): id: str role: str - model: Optional[str] = None + model: str | None = None content_length: int - token_count: Optional[int] = None - timestamp: Optional[int] = None - rating: Optional[int] = None # Derived from message.annotation.rating - tags: Optional[list[str]] = None # Derived from message.annotation.tags + token_count: int | None = None + timestamp: int | None = None + rating: int | None = None # Derived from message.annotation.rating + tags: list[str | None] = None # Derived from message.annotation.tags class ChatHistoryStats(BaseModel): messages: dict[str, MessageStats] - currentId: Optional[str] = None + currentId: str | None = None class ChatBody(BaseModel): @@ -293,8 +300,8 @@ class ChatTable: return changed async def insert_new_chat( - self, id: str, user_id: str, form_data: ChatForm, db: Optional[AsyncSession] = None - ) -> Optional[ChatModel]: + self, id: str, user_id: str, form_data: ChatForm, db: AsyncSession | None = None + ) -> ChatModel | None: async with get_async_db_context(db) as db: chat = ChatModel( **{ @@ -353,7 +360,7 @@ class ChatTable: self, user_id: str, chat_import_forms: list[ChatImportForm], - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[ChatModel]: async with get_async_db_context(db) as db: chats = [] @@ -383,7 +390,7 @@ class ChatTable: return [ChatModel.model_validate(chat) for chat in chats] - async def update_chat_by_id(self, id: str, chat: dict, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def update_chat_by_id(self, id: str, chat: dict, db: AsyncSession | None = None) -> ChatModel | None: try: async with get_async_db_context(db) as db: chat_item = await db.get(Chat, id) @@ -398,7 +405,7 @@ class ChatTable: except Exception: return None - async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: chat = await db.get(Chat, id) @@ -410,7 +417,7 @@ class ChatTable: except Exception: return False - async def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]: + async def update_chat_title_by_id(self, id: str, title: str) -> ChatModel | None: try: async with get_async_db_context() as db: chat_item = await db.get(Chat, id) @@ -426,7 +433,7 @@ class ChatTable: except Exception: return None - async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> Optional[ChatModel]: + async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> ChatModel | None: async with get_async_db_context() as db: chat = await db.get(Chat, id) if chat is None: @@ -451,7 +458,7 @@ class ChatTable: return ChatModel.model_validate(chat) - async def get_chat_title_by_id(self, id: str) -> Optional[str]: + async def get_chat_title_by_id(self, id: str) -> str | None: async with get_async_db_context() as db: result = await db.execute(select(Chat.title).filter_by(id=id)) row = result.first() @@ -488,7 +495,7 @@ class ChatTable: except Exception as e: log.warning('Backfill failed for message %s in chat %s: %s', message_id, chat_id, e) - async def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]: + async def get_messages_map_by_chat_id(self, id: str) -> dict | None: """Message map for walking history (see ``get_message_list``). Prefer ``chat_message`` rows to avoid loading the large embedded @@ -541,7 +548,7 @@ class ChatTable: return history_messages - async def get_message_by_id_and_message_id(self, id: str, message_id: str) -> Optional[dict]: + async def get_message_by_id_and_message_id(self, id: str, message_id: str) -> dict | None: chat = await self.get_chat_by_id(id) if chat is None: return None @@ -550,7 +557,7 @@ class ChatTable: async def upsert_message_to_chat_by_id_and_message_id( self, id: str, message_id: str, message: dict - ) -> Optional[ChatModel]: + ) -> ChatModel | None: chat = await self.get_chat_by_id(id) if chat is None: return None @@ -590,7 +597,7 @@ class ChatTable: async def add_message_status_to_chat_by_id_and_message_id( self, id: str, message_id: str, status: dict - ) -> Optional[ChatModel]: + ) -> ChatModel | None: chat = await self.get_chat_by_id(id) if chat is None: return None @@ -626,9 +633,7 @@ class ChatTable: await self.update_chat_by_id(id, chat, db=db) return message_files - async def insert_shared_chat_by_chat_id( - self, chat_id: str, db: Optional[AsyncSession] = None - ) -> Optional[ChatModel]: + async def insert_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None: """Create a shared snapshot for a chat. Returns the original chat with share_id set.""" from open_webui.models.shared_chats import SharedChats @@ -651,9 +656,7 @@ class ChatTable: await db.refresh(chat) return ChatModel.model_validate(chat) - async def update_shared_chat_by_chat_id( - self, chat_id: str, db: Optional[AsyncSession] = None - ) -> Optional[ChatModel]: + async def update_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None: """Re-snapshot the shared chat with current chat data.""" from open_webui.models.shared_chats import SharedChats @@ -668,7 +671,7 @@ class ChatTable: except Exception: return None - async def delete_shared_chat_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool: """Delete shared snapshot for a chat.""" from open_webui.models.shared_chats import SharedChats @@ -677,7 +680,7 @@ class ChatTable: except Exception: return False - async def unarchive_all_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def unarchive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=False)) @@ -687,8 +690,8 @@ class ChatTable: return False async def update_chat_share_id_by_id( - self, id: str, share_id: Optional[str], db: Optional[AsyncSession] = None - ) -> Optional[ChatModel]: + self, id: str, share_id: str | None, db: AsyncSession | None = None + ) -> ChatModel | None: try: async with get_async_db_context(db) as db: chat = await db.get(Chat, id) @@ -699,7 +702,7 @@ class ChatTable: except Exception: return None - async def toggle_chat_pinned_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def toggle_chat_pinned_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None: try: async with get_async_db_context(db) as db: chat = await db.get(Chat, id) @@ -711,7 +714,7 @@ class ChatTable: except Exception: return None - async def toggle_chat_archive_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def toggle_chat_archive_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None: try: async with get_async_db_context(db) as db: chat = await db.get(Chat, id) @@ -724,7 +727,7 @@ class ChatTable: except Exception: return None - async def archive_all_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def archive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=True)) @@ -736,10 +739,10 @@ class ChatTable: async def get_archived_chat_list_by_user_id( self, user_id: str, - filter: Optional[dict] = None, + filter: dict | None = None, skip: int = 0, limit: int = 50, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by( @@ -789,10 +792,10 @@ class ChatTable: async def get_shared_chat_list_by_user_id( self, user_id: str, - filter: Optional[dict] = None, + filter: dict | None = None, skip: int = 0, limit: int = 50, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[SharedChatResponse]: """Delegate to SharedChats for listing shared chats by user.""" from open_webui.models.shared_chats import SharedChats @@ -803,10 +806,10 @@ class ChatTable: self, user_id: str, include_archived: bool = False, - filter: Optional[dict] = None, + filter: dict | None = None, skip: int = 0, limit: int = 50, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( @@ -859,9 +862,9 @@ class ChatTable: include_archived: bool = False, include_folders: bool = False, include_pinned: bool = False, - skip: Optional[int] = None, - limit: Optional[int] = None, - db: Optional[AsyncSession] = None, + skip: int | None = None, + limit: int | None = None, + db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( @@ -905,7 +908,7 @@ class ChatTable: chat_ids: list[str], skip: int = 0, limit: int = 50, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[ChatModel]: async with get_async_db_context(db) as db: result = await db.execute( @@ -914,7 +917,7 @@ class ChatTable: all_chats = result.scalars().all() return [ChatModel.model_validate(chat) for chat in all_chats] - async def get_chat_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def get_chat_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None: try: async with get_async_db_context(db) as db: chat_item = await db.get(Chat, id) @@ -929,7 +932,7 @@ class ChatTable: except Exception: return None - async def get_chat_by_share_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def get_chat_by_share_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None: """Look up a shared chat snapshot by its share token.""" from open_webui.models.shared_chats import SharedChats @@ -951,8 +954,8 @@ class ChatTable: return None async def get_chat_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[AsyncSession] = None - ) -> Optional[ChatModel]: + self, id: str, user_id: str, db: AsyncSession | None = None + ) -> ChatModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(Chat).filter_by(id=id, user_id=user_id)) @@ -961,7 +964,7 @@ class ChatTable: except Exception: return None - async def is_chat_owner(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def is_chat_owner(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool: """ Lightweight ownership check — uses EXISTS subquery instead of loading the full Chat row (which includes the potentially large JSON blob). @@ -973,7 +976,7 @@ class ChatTable: except Exception: return False - async def get_chat_folder_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[str]: + async def get_chat_folder_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> str | None: """ Fetch only the folder_id column for a chat, without loading the full JSON blob. Returns None if chat doesn't exist or doesn't belong to user. @@ -986,7 +989,7 @@ class ChatTable: except Exception: return None - async def get_chats(self, skip: int = 0, limit: int = 50, db: Optional[AsyncSession] = None) -> list[ChatModel]: + async def get_chats(self, skip: int = 0, limit: int = 50, db: AsyncSession | None = None) -> list[ChatModel]: async with get_async_db_context(db) as db: result = await db.execute(select(Chat).order_by(Chat.updated_at.desc())) all_chats = result.scalars().all() @@ -995,10 +998,10 @@ class ChatTable: async def get_chats_by_user_id( self, user_id: str, - filter: Optional[dict] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - db: Optional[AsyncSession] = None, + filter: dict | None = None, + skip: int | None = None, + limit: int | None = None, + db: AsyncSession | None = None, ) -> ChatListResponse: async with get_async_db_context(db) as db: stmt = select(Chat).filter_by(user_id=user_id) @@ -1041,7 +1044,7 @@ class ChatTable: ) async def get_pinned_chats_by_user_id( - self, user_id: str, db: Optional[AsyncSession] = None + self, user_id: str, db: AsyncSession | None = None ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: result = await db.execute( @@ -1063,7 +1066,7 @@ class ChatTable: for chat in all_chats ] - async def get_archived_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[ChatModel]: + async def get_archived_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[ChatModel]: async with get_async_db_context(db) as db: result = await db.execute( select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc()) @@ -1077,7 +1080,7 @@ class ChatTable: include_archived: bool = False, skip: int = 0, limit: int = 60, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[ChatModel]: """ Filters chats based on a search query using Python, allowing pagination using skip and limit. @@ -1270,7 +1273,7 @@ class ChatTable: user_id: str, skip: int = 0, limit: int = 60, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: stmt = ( @@ -1302,7 +1305,7 @@ class ChatTable: ] async def get_chats_by_folder_ids_and_user_id( - self, folder_ids: list[str], user_id: str, db: Optional[AsyncSession] = None + self, folder_ids: list[str], user_id: str, db: AsyncSession | None = None ) -> list[ChatModel]: async with get_async_db_context(db) as db: stmt = ( @@ -1318,8 +1321,8 @@ class ChatTable: return [ChatModel.model_validate(chat) for chat in all_chats] async def update_chat_folder_id_by_id_and_user_id( - self, id: str, user_id: str, folder_id: str, db: Optional[AsyncSession] = None - ) -> Optional[ChatModel]: + self, id: str, user_id: str, folder_id: str, db: AsyncSession | None = None + ) -> ChatModel | None: try: async with get_async_db_context(db) as db: chat = await db.get(Chat, id) @@ -1333,7 +1336,7 @@ class ChatTable: return None async def get_chat_tags_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[AsyncSession] = None + self, id: str, user_id: str, db: AsyncSession | None = None ) -> list[TagModel]: async with get_async_db_context(db) as db: stmt = select(Chat.meta).where(Chat.id == id) @@ -1348,7 +1351,7 @@ class ChatTable: tag_name: str, skip: int = 0, limit: int = 50, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( @@ -1393,8 +1396,8 @@ class ChatTable: ] async def add_chat_tag_by_id_and_user_id_and_tag_name( - self, id: str, user_id: str, tag_name: str, db: Optional[AsyncSession] = None - ) -> Optional[ChatModel]: + self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None + ) -> ChatModel | None: tag_id = tag_name.replace(' ', '_').lower() await Tags.ensure_tags_exist([tag_name], user_id, db=db) try: @@ -1412,7 +1415,7 @@ class ChatTable: return None async def count_chats_by_tag_name_and_user_id( - self, tag_name: str, user_id: str, db: Optional[AsyncSession] = None + self, tag_name: str, user_id: str, db: AsyncSession | None = None ) -> int: async with get_async_db_context(db) as db: stmt = select(func.count(Chat.id)).filter_by(user_id=user_id, archived=False) @@ -1439,7 +1442,7 @@ class ChatTable: tag_ids: list[str], user_id: str, threshold: int = 0, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> None: """Delete tag rows from *tag_ids* that appear in at most *threshold* non-archived chats for *user_id*. One query to find orphans, one to @@ -1460,7 +1463,7 @@ class ChatTable: await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db) async def count_chats_by_folder_id_and_user_id( - self, folder_id: str, user_id: str, db: Optional[AsyncSession] = None + self, folder_id: str, user_id: str, db: AsyncSession | None = None ) -> int: async with get_async_db_context(db) as db: result = await db.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id)) @@ -1470,7 +1473,7 @@ class ChatTable: return count async def delete_tag_by_id_and_user_id_and_tag_name( - self, id: str, user_id: str, tag_name: str, db: Optional[AsyncSession] = None + self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None ) -> bool: try: async with get_async_db_context(db) as db: @@ -1488,7 +1491,7 @@ class ChatTable: except Exception: return False - async def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: chat = await db.get(Chat, id) @@ -1502,7 +1505,7 @@ class ChatTable: except Exception: return False - async def delete_chat_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_chat_by_id(self, id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)) @@ -1514,7 +1517,7 @@ class ChatTable: except Exception: return False - async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)) @@ -1526,7 +1529,7 @@ class ChatTable: except Exception: return False - async def delete_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await self.delete_shared_chats_by_user_id(user_id, db=db) @@ -1548,7 +1551,7 @@ class ChatTable: return False async def delete_chats_by_user_id_and_folder_id( - self, user_id: str, folder_id: str, db: Optional[AsyncSession] = None + self, user_id: str, folder_id: str, db: AsyncSession | None = None ) -> bool: try: async with get_async_db_context(db) as db: @@ -1568,8 +1571,8 @@ class ChatTable: self, user_id: str, folder_id: str, - new_folder_id: Optional[str], - db: Optional[AsyncSession] = None, + new_folder_id: str | None, + db: AsyncSession | None = None, ) -> bool: try: async with get_async_db_context(db) as db: @@ -1582,9 +1585,10 @@ class ChatTable: except Exception: return False - async def delete_shared_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_shared_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool: """Delete all shared chat snapshots created by a user.""" - from open_webui.models.shared_chats import SharedChats, SharedChat as SharedChatTable + from open_webui.models.shared_chats import SharedChat as SharedChatTable + from open_webui.models.shared_chats import SharedChats try: async with get_async_db_context(db) as db: @@ -1605,8 +1609,8 @@ class ChatTable: message_id: str, file_ids: list[str], user_id: str, - db: Optional[AsyncSession] = None, - ) -> Optional[list[ChatFileModel]]: + db: AsyncSession | None = None, + ) -> list[ChatFileModel | None]: if not file_ids: return None @@ -1645,7 +1649,7 @@ class ChatTable: return None async def get_chat_files_by_chat_id_and_message_id( - self, chat_id: str, message_id: str, db: Optional[AsyncSession] = None + self, chat_id: str, message_id: str, db: AsyncSession | None = None ) -> list[ChatFileModel]: async with get_async_db_context(db) as db: result = await db.execute( @@ -1654,7 +1658,7 @@ class ChatTable: all_chat_files = result.scalars().all() return [ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files] - async def delete_chat_file(self, chat_id: str, file_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_chat_file(self, chat_id: str, file_id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await db.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id)) @@ -1663,7 +1667,7 @@ class ChatTable: except Exception: return False - async def get_shared_chat_ids_by_file_id(self, file_id: str, db: Optional[AsyncSession] = None) -> list[str]: + async def get_shared_chat_ids_by_file_id(self, file_id: str, db: AsyncSession | None = None) -> list[str]: """Return IDs of chats that contain this file and have an active share link.""" async with get_async_db_context(db) as db: result = await db.execute( @@ -1673,7 +1677,7 @@ class ChatTable: ) return [row[0] for row in result.all()] - async def update_chat_tasks_by_id(self, id: str, tasks: list[dict]) -> Optional[ChatModel]: + async def update_chat_tasks_by_id(self, id: str, tasks: list[dict]) -> ChatModel | None: """Update the tasks list on a chat.""" try: async with get_async_db_context() as db: diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index d8ae4dc9b1..00bb6f1588 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -3,13 +3,11 @@ import time import uuid from typing import Optional -from sqlalchemy import select, delete, func -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.users import User, UserModel - from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, Text, JSON, Boolean +from sqlalchemy import JSON, BigInteger, Boolean, Column, Text, delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -319,8 +317,8 @@ class FeedbackTable: If days=0, returns all time data starting from first feedback. Returns: [{"date": "2026-01-08", "won": 5, "lost": 2}, ...] """ - from datetime import datetime, timedelta from collections import defaultdict + from datetime import datetime, timedelta async with get_async_db_context(db) as db: if days == 0: diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index cfdcfbc2d9..dc45671280 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -1,13 +1,14 @@ +from __future__ import annotations + import logging import time from typing import Optional -from sqlalchemy import select, delete, func -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.utils.misc import sanitize_metadata from pydantic import BaseModel, ConfigDict, model_validator -from sqlalchemy import BigInteger, Column, String, Text, JSON +from sqlalchemy import JSON, BigInteger, Column, String, Text, delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -39,16 +40,16 @@ class FileModel(BaseModel): id: str user_id: str - hash: Optional[str] = None + hash: str | None = None filename: str - path: Optional[str] = None + path: str | None = None - data: Optional[dict] = None - meta: Optional[dict] = None + data: dict | None = None + meta: dict | None = None - created_at: Optional[int] # timestamp in epoch - updated_at: Optional[int] # timestamp in epoch + created_at: int | None # timestamp in epoch + updated_at: int | None # timestamp in epoch #################### @@ -57,9 +58,9 @@ class FileModel(BaseModel): class FileMeta(BaseModel): - name: Optional[str] = None - content_type: Optional[str] = None - size: Optional[int] = None + name: str | None = None + content_type: str | None = None + size: int | None = None model_config = ConfigDict(extra='allow') @@ -84,22 +85,22 @@ class FileMeta(BaseModel): class FileModelResponse(BaseModel): id: str user_id: str - hash: Optional[str] = None + hash: str | None = None filename: str - data: Optional[dict] = None - meta: Optional[FileMeta] = None + data: dict | None = None + meta: FileMeta | None = None created_at: int # timestamp in epoch - updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files + updated_at: int | None = None # timestamp in epoch, optional for legacy files model_config = ConfigDict(extra='allow') class FileMetadataResponse(BaseModel): id: str - hash: Optional[str] = None - meta: Optional[dict] = None + hash: str | None = None + meta: dict | None = None created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -111,7 +112,7 @@ class FileListResponse(BaseModel): class FileForm(BaseModel): id: str - hash: Optional[str] = None + hash: str | None = None filename: str path: str data: dict = {} @@ -119,15 +120,15 @@ class FileForm(BaseModel): class FileUpdateForm(BaseModel): - hash: Optional[str] = None - data: Optional[dict] = None - meta: Optional[dict] = None + hash: str | None = None + data: dict | None = None + meta: dict | None = None class FilesTable: async def insert_new_file( - self, user_id: str, form_data: FileForm, db: Optional[AsyncSession] = None - ) -> Optional[FileModel]: + self, user_id: str, form_data: FileForm, db: AsyncSession | None = None + ) -> FileModel | None: async with get_async_db_context(db) as db: file_data = form_data.model_dump() @@ -158,7 +159,7 @@ class FilesTable: log.exception(f'Error inserting a new file: {e}') return None - async def get_file_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[FileModel]: + async def get_file_by_id(self, id: str, db: AsyncSession | None = None) -> FileModel | None: try: async with get_async_db_context(db) as db: try: @@ -170,8 +171,8 @@ class FilesTable: return None async def get_file_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[AsyncSession] = None - ) -> Optional[FileModel]: + self, id: str, user_id: str, db: AsyncSession | None = None + ) -> FileModel | None: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id, user_id=user_id)) @@ -183,9 +184,7 @@ class FilesTable: except Exception: return None - async def get_file_metadata_by_id( - self, id: str, db: Optional[AsyncSession] = None - ) -> Optional[FileMetadataResponse]: + async def get_file_metadata_by_id(self, id: str, db: AsyncSession | None = None) -> FileMetadataResponse | None: async with get_async_db_context(db) as db: try: file = await db.get(File, id) @@ -201,12 +200,12 @@ class FilesTable: except Exception: return None - async def get_files(self, db: Optional[AsyncSession] = None) -> list[FileModel]: + async def get_files(self, db: AsyncSession | None = None) -> list[FileModel]: async with get_async_db_context(db) as db: result = await db.execute(select(File)) return [FileModel.model_validate(file) for file in result.scalars().all()] - async def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[AsyncSession] = None) -> bool: + async def check_access_by_user_id(self, id, user_id, permission='write', db: AsyncSession | None = None) -> bool: file = await self.get_file_by_id(id, db=db) if not file: return False @@ -215,13 +214,13 @@ class FilesTable: # Implement additional access control logic here as needed return False - async def get_files_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FileModel]: + async def get_files_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[FileModel]: async with get_async_db_context(db) as db: result = await db.execute(select(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc())) return [FileModel.model_validate(file) for file in result.scalars().all()] async def get_file_metadatas_by_ids( - self, ids: list[str], db: Optional[AsyncSession] = None + self, ids: list[str], db: AsyncSession | None = None ) -> list[FileMetadataResponse]: async with get_async_db_context(db) as db: result = await db.execute( @@ -240,17 +239,17 @@ class FilesTable: for row in result.all() ] - async def get_files_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[FileModel]: + async def get_files_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[FileModel]: async with get_async_db_context(db) as db: result = await db.execute(select(File).filter_by(user_id=user_id)) return [FileModel.model_validate(file) for file in result.scalars().all()] async def get_file_list( self, - user_id: Optional[str] = None, + user_id: str | None = None, skip: int = 0, limit: int = 50, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> 'FileListResponse': async with get_async_db_context(db) as db: stmt = select(File) @@ -290,11 +289,11 @@ class FilesTable: async def search_files( self, - user_id: Optional[str] = None, + user_id: str | None = None, filename: str = '*', skip: int = 0, limit: int = 100, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[FileModel]: """ Search files with glob pattern matching, optional user filter, and pagination. @@ -323,8 +322,8 @@ class FilesTable: return [FileModel.model_validate(file) for file in result.scalars().all()] async def update_file_by_id( - self, id: str, form_data: FileUpdateForm, db: Optional[AsyncSession] = None - ) -> Optional[FileModel]: + self, id: str, form_data: FileUpdateForm, db: AsyncSession | None = None + ) -> FileModel | None: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id)) @@ -347,8 +346,8 @@ class FilesTable: return None async def update_file_hash_by_id( - self, id: str, hash: Optional[str], db: Optional[AsyncSession] = None - ) -> Optional[FileModel]: + self, id: str, hash: str | None, db: AsyncSession | None = None + ) -> FileModel | None: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id)) @@ -361,9 +360,7 @@ class FilesTable: except Exception: return None - async def update_file_data_by_id( - self, id: str, data: dict, db: Optional[AsyncSession] = None - ) -> Optional[FileModel]: + async def update_file_data_by_id(self, id: str, data: dict, db: AsyncSession | None = None) -> FileModel | None: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id)) @@ -375,9 +372,7 @@ class FilesTable: except Exception as e: return None - async def update_file_metadata_by_id( - self, id: str, meta: dict, db: Optional[AsyncSession] = None - ) -> Optional[FileModel]: + async def update_file_metadata_by_id(self, id: str, meta: dict, db: AsyncSession | None = None) -> FileModel | None: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id)) @@ -389,7 +384,7 @@ class FilesTable: except Exception: return None - async def delete_file_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_file_by_id(self, id: str, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as db: try: await db.execute(delete(File).filter_by(id=id)) @@ -399,7 +394,7 @@ class FilesTable: except Exception: return False - async def delete_all_files(self, db: Optional[AsyncSession] = None) -> bool: + async def delete_all_files(self, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as db: try: await db.execute(delete(File)) diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index c553239482..1688b8bd46 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -1,15 +1,13 @@ import logging +import re import time import uuid from typing import Optional -import re - - -from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func, select, delete -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context +from pydantic import BaseModel, ConfigDict +from sqlalchemy import JSON, BigInteger, Boolean, Column, Text, delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index ddac317863..3a47d402cf 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -1,13 +1,14 @@ +from __future__ import annotations + import logging import time from typing import Optional -from sqlalchemy import select, delete, update -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context -from open_webui.models.users import Users, UserModel, UserResponse +from open_webui.models.users import UserModel, UserResponse, Users from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index +from sqlalchemy import BigInteger, Boolean, Column, Index, String, Text, delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -37,8 +38,8 @@ class Function(Base): class FunctionMeta(BaseModel): - description: Optional[str] = None - manifest: Optional[dict] = {} + description: str | None = None + manifest: dict | None = {} model_config = ConfigDict(extra='allow') @@ -64,7 +65,7 @@ class FunctionWithValvesModel(BaseModel): type: str content: str meta: FunctionMeta - valves: Optional[dict] = None + valves: dict | None = None is_active: bool = False is_global: bool = False updated_at: int # timestamp in epoch @@ -93,7 +94,7 @@ class FunctionResponse(BaseModel): class FunctionUserResponse(FunctionResponse): - user: Optional[UserResponse] = None + user: UserResponse | None = None class FunctionForm(BaseModel): @@ -104,7 +105,7 @@ class FunctionForm(BaseModel): class FunctionValves(BaseModel): - valves: Optional[dict] = None + valves: dict | None = None class FunctionsTable: @@ -113,8 +114,8 @@ class FunctionsTable: user_id: str, type: str, form_data: FunctionForm, - db: Optional[AsyncSession] = None, - ) -> Optional[FunctionModel]: + db: AsyncSession | None = None, + ) -> FunctionModel | None: function = FunctionModel( **{ **form_data.model_dump(), @@ -143,7 +144,7 @@ class FunctionsTable: self, user_id: str, functions: list[FunctionWithValvesModel], - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[FunctionWithValvesModel]: # Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present. try: @@ -191,7 +192,7 @@ class FunctionsTable: log.exception(f'Error syncing functions for user {user_id}: {e}') return [] - async def get_function_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[FunctionModel]: + async def get_function_by_id(self, id: str, db: AsyncSession | None = None) -> FunctionModel | None: try: async with get_async_db_context(db) as db: function = await db.get(Function, id) @@ -199,7 +200,7 @@ class FunctionsTable: except Exception: return None - async def get_functions_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FunctionModel]: + async def get_functions_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[FunctionModel]: """ Batch fetch multiple functions by their IDs in a single query. Returns functions in the same order as the input IDs (None entries filtered out). @@ -218,7 +219,7 @@ class FunctionsTable: return [] async def get_functions( - self, active_only=False, include_valves=False, db: Optional[AsyncSession] = None + self, active_only=False, include_valves=False, db: AsyncSession | None = None ) -> list[FunctionModel | FunctionWithValvesModel]: async with get_async_db_context(db) as db: if active_only: @@ -233,7 +234,7 @@ class FunctionsTable: else: return [FunctionModel.model_validate(function) for function in functions] - async def get_function_list(self, db: Optional[AsyncSession] = None) -> list[FunctionUserResponse]: + async def get_function_list(self, db: AsyncSession | None = None) -> list[FunctionUserResponse]: async with get_async_db_context(db) as db: result = await db.execute(select(Function).order_by(Function.updated_at.desc())) functions = result.scalars().all() @@ -262,7 +263,7 @@ class FunctionsTable: ] async def get_functions_by_type( - self, type: str, active_only=False, db: Optional[AsyncSession] = None + self, type: str, active_only=False, db: AsyncSession | None = None ) -> list[FunctionModel]: async with get_async_db_context(db) as db: if active_only: @@ -271,17 +272,17 @@ class FunctionsTable: result = await db.execute(select(Function).filter_by(type=type)) return [FunctionModel.model_validate(function) for function in result.scalars().all()] - async def get_global_filter_functions(self, db: Optional[AsyncSession] = None) -> list[FunctionModel]: + async def get_global_filter_functions(self, db: AsyncSession | None = None) -> list[FunctionModel]: async with get_async_db_context(db) as db: result = await db.execute(select(Function).filter_by(type='filter', is_active=True, is_global=True)) return [FunctionModel.model_validate(function) for function in result.scalars().all()] - async def get_global_action_functions(self, db: Optional[AsyncSession] = None) -> list[FunctionModel]: + async def get_global_action_functions(self, db: AsyncSession | None = None) -> list[FunctionModel]: async with get_async_db_context(db) as db: result = await db.execute(select(Function).filter_by(type='action', is_active=True, is_global=True)) return [FunctionModel.model_validate(function) for function in result.scalars().all()] - async def get_function_valves_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[dict]: + async def get_function_valves_by_id(self, id: str, db: AsyncSession | None = None) -> dict | None: async with get_async_db_context(db) as db: try: function = await db.get(Function, id) @@ -290,7 +291,7 @@ class FunctionsTable: log.exception(f'Error getting function valves by id {id}: {e}') return None - async def get_function_valves_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> dict[str, dict]: + async def get_function_valves_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> dict[str, dict]: """ Batch fetch valves for multiple functions in a single query. Returns a dict mapping function_id -> valves dict. @@ -308,8 +309,8 @@ class FunctionsTable: return {} async def update_function_valves_by_id( - self, id: str, valves: dict, db: Optional[AsyncSession] = None - ) -> Optional[FunctionValves]: + self, id: str, valves: dict, db: AsyncSession | None = None + ) -> FunctionValves | None: async with get_async_db_context(db) as db: try: function = await db.get(Function, id) @@ -322,8 +323,8 @@ class FunctionsTable: return None async def update_function_metadata_by_id( - self, id: str, metadata: dict, db: Optional[AsyncSession] = None - ) -> Optional[FunctionModel]: + self, id: str, metadata: dict, db: AsyncSession | None = None + ) -> FunctionModel | None: async with get_async_db_context(db) as db: try: function = await db.get(Function, id) @@ -345,8 +346,8 @@ class FunctionsTable: return None async def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[AsyncSession] = None - ) -> Optional[dict]: + self, id: str, user_id: str, db: AsyncSession | None = None + ) -> dict | None: try: user = await Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} @@ -363,8 +364,8 @@ class FunctionsTable: return None async def update_user_valves_by_id_and_user_id( - self, id: str, user_id: str, valves: dict, db: Optional[AsyncSession] = None - ) -> Optional[dict]: + self, id: str, user_id: str, valves: dict, db: AsyncSession | None = None + ) -> dict | None: try: user = await Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} @@ -386,8 +387,8 @@ class FunctionsTable: return None async def update_function_by_id( - self, id: str, updated: dict, db: Optional[AsyncSession] = None - ) -> Optional[FunctionModel]: + self, id: str, updated: dict, db: AsyncSession | None = None + ) -> FunctionModel | None: async with get_async_db_context(db) as db: try: await db.execute( @@ -404,7 +405,7 @@ class FunctionsTable: except Exception: return None - async def deactivate_all_functions(self, db: Optional[AsyncSession] = None) -> Optional[bool]: + async def deactivate_all_functions(self, db: AsyncSession | None = None) -> bool | None: async with get_async_db_context(db) as db: try: await db.execute( @@ -418,7 +419,7 @@ class FunctionsTable: except Exception: return None - async def delete_function_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_function_by_id(self, id: str, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as db: try: await db.execute(delete(Function).filter_by(id=id)) diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index bc199fac5b..66785794a7 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -1,25 +1,29 @@ import json import logging import time -from typing import Optional import uuid +from typing import Optional -from sqlalchemy import select, delete, update, func, and_, or_, cast, String -from sqlalchemy.ext.asyncio import AsyncSession -from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.env import DEFAULT_GROUP_SHARE_PERMISSION - +from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.files import FileMetadataResponse - - from pydantic import BaseModel, ConfigDict from sqlalchemy import ( + JSON, BigInteger, Column, - Text, - JSON, ForeignKey, + String, + Text, + and_, + cast, + delete, + func, + or_, + select, + update, ) +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index e08e626981..3a7ceb629e 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -1,34 +1,36 @@ import json import logging import time -from typing import Optional import uuid +from typing import Optional -from sqlalchemy import select, delete, update, or_, func, cast -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context - +from open_webui.models.access_grants import AccessGrantModel, AccessGrants from open_webui.models.files import ( File, - FileModel, FileMetadataResponse, + FileModel, FileModelResponse, ) from open_webui.models.groups import Groups -from open_webui.models.users import User, UserModel, Users, UserResponse -from open_webui.models.access_grants import AccessGrantModel, AccessGrants - - +from open_webui.models.users import User, UserModel, UserResponse, Users from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import ( + JSON, BigInteger, Column, ForeignKey, String, Text, - JSON, UniqueConstraint, + cast, + delete, + func, + or_, + select, + update, ) +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index 1ec52eeb6a..d7f51e2c9e 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -1,12 +1,13 @@ +from __future__ import annotations + import time import uuid from typing import Optional -from sqlalchemy import select, delete -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, get_async_db_context from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, delete, select +from sqlalchemy.ext.asyncio import AsyncSession #################### # Memory DB Schema @@ -45,8 +46,8 @@ class MemoriesTable: self, user_id: str, content: str, - db: Optional[AsyncSession] = None, - ) -> Optional[MemoryModel]: + db: AsyncSession | None = None, + ) -> MemoryModel | None: async with get_async_db_context(db) as db: id = str(uuid.uuid4()) @@ -73,8 +74,8 @@ class MemoriesTable: id: str, user_id: str, content: str, - db: Optional[AsyncSession] = None, - ) -> Optional[MemoryModel]: + db: AsyncSession | None = None, + ) -> MemoryModel | None: async with get_async_db_context(db) as db: try: memory = await db.get(Memory, id) @@ -90,7 +91,7 @@ class MemoriesTable: except Exception: return None - async def get_memories(self, db: Optional[AsyncSession] = None) -> list[MemoryModel]: + async def get_memories(self, db: AsyncSession | None = None) -> list[MemoryModel]: async with get_async_db_context(db) as db: try: result = await db.execute(select(Memory)) @@ -99,7 +100,7 @@ class MemoriesTable: except Exception: return None - async def get_memories_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[MemoryModel]: + async def get_memories_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[MemoryModel]: async with get_async_db_context(db) as db: try: result = await db.execute(select(Memory).filter_by(user_id=user_id)) @@ -108,7 +109,7 @@ class MemoriesTable: except Exception: return None - async def get_memory_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[MemoryModel]: + async def get_memory_by_id(self, id: str, db: AsyncSession | None = None) -> MemoryModel | None: async with get_async_db_context(db) as db: try: memory = await db.get(Memory, id) @@ -116,7 +117,7 @@ class MemoriesTable: except Exception: return None - async def delete_memory_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_memory_by_id(self, id: str, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as db: try: await db.execute(delete(Memory).filter_by(id=id)) @@ -127,7 +128,7 @@ class MemoriesTable: except Exception: return False - async def delete_memories_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_memories_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as db: try: await db.execute(delete(Memory).filter_by(user_id=user_id)) @@ -137,7 +138,7 @@ class MemoriesTable: except Exception: return False - async def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as db: try: memory = await db.get(Memory, id) diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index 7f33a72eff..342abed2f8 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -3,17 +3,13 @@ import time import uuid from typing import Optional -from sqlalchemy import select, delete, func -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context -from open_webui.models.tags import TagModel, Tag, Tags -from open_webui.models.users import Users, User, UserNameResponse -from open_webui.models.channels import Channels, ChannelMember - - +from open_webui.models.channels import ChannelMember, Channels +from open_webui.models.tags import Tag, TagModel, Tags +from open_webui.models.users import User, UserNameResponse, Users from pydantic import BaseModel, ConfigDict, field_validator -from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON -from sqlalchemy import or_, func, and_, text +from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, and_, delete, func, or_, select, text +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import exists #################### diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 79c13153ac..443cd4530d 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -1,21 +1,18 @@ +from __future__ import annotations + import json import logging import time from typing import Optional -from sqlalchemy import select, delete, update, or_, func, String, cast -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context - -from open_webui.models.groups import Groups -from open_webui.models.users import User, UserModel, Users, UserResponse from open_webui.models.access_grants import AccessGrantModel, AccessGrants - - +from open_webui.models.groups import Groups +from open_webui.models.users import User, UserModel, UserResponse, Users from pydantic import BaseModel, ConfigDict, Field, model_validator - +from sqlalchemy import BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, update from sqlalchemy.dialects.postgresql import JSONB -from sqlalchemy import BigInteger, Column, Text, Boolean +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -35,14 +32,14 @@ class ModelParams(BaseModel): # ModelMeta is a model for the data stored in the meta field of the Model table class ModelMeta(BaseModel): - profile_image_url: Optional[str] = '/static/favicon.png' + profile_image_url: str | None = '/static/favicon.png' - description: Optional[str] = None + description: str | None = None """ User-facing description of the model. """ - capabilities: Optional[dict] = None + capabilities: dict | None = None model_config = ConfigDict(extra='allow') @@ -100,7 +97,7 @@ class Model(Base): class ModelModel(BaseModel): id: str user_id: str - base_model_id: Optional[str] = None + base_model_id: str | None = None name: str params: ModelParams @@ -121,11 +118,11 @@ class ModelModel(BaseModel): class ModelUserResponse(ModelModel): - user: Optional[UserResponse] = None + user: UserResponse | None = None class ModelAccessResponse(ModelUserResponse): - write_access: Optional[bool] = False + write_access: bool | None = False class ModelResponse(ModelModel): @@ -146,23 +143,23 @@ class ModelForm(BaseModel): model_config = ConfigDict(extra='ignore') id: str - base_model_id: Optional[str] = None + base_model_id: str | None = None name: str meta: ModelMeta params: ModelParams - access_grants: Optional[list[dict]] = None + access_grants: list[dict | None] = None is_active: bool = True class ModelsTable: - async def _get_access_grants(self, model_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]: + async def _get_access_grants(self, model_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]: return await AccessGrants.get_grants_by_resource('model', model_id, db=db) async def _to_model_model( self, model: Model, - access_grants: Optional[list[AccessGrantModel]] = None, - db: Optional[AsyncSession] = None, + access_grants: list[AccessGrantModel | None] = None, + db: AsyncSession | None = None, ) -> ModelModel: model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'}) model_data['access_grants'] = ( @@ -171,8 +168,8 @@ class ModelsTable: return ModelModel.model_validate(model_data) async def insert_new_model( - self, form_data: ModelForm, user_id: str, db: Optional[AsyncSession] = None - ) -> Optional[ModelModel]: + self, form_data: ModelForm, user_id: str, db: AsyncSession | None = None + ) -> ModelModel | None: try: async with get_async_db_context(db) as db: result = Model( @@ -196,7 +193,7 @@ class ModelsTable: log.exception(f'Failed to insert a new model: {e}') return None - async def get_all_models(self, db: Optional[AsyncSession] = None) -> list[ModelModel]: + async def get_all_models(self, db: AsyncSession | None = None) -> list[ModelModel]: async with get_async_db_context(db) as db: result = await db.execute(select(Model)) all_models = result.scalars().all() @@ -207,7 +204,7 @@ class ModelsTable: for model in all_models ] - async def get_models(self, db: Optional[AsyncSession] = None) -> list[ModelUserResponse]: + async def get_models(self, db: AsyncSession | None = None) -> list[ModelUserResponse]: async with get_async_db_context(db) as db: result = await db.execute(select(Model).filter(Model.base_model_id != None)) all_models = result.scalars().all() @@ -238,7 +235,7 @@ class ModelsTable: ) return models - async def get_base_models(self, db: Optional[AsyncSession] = None) -> list[ModelModel]: + async def get_base_models(self, db: AsyncSession | None = None) -> list[ModelModel]: async with get_async_db_context(db) as db: result = await db.execute(select(Model).filter(Model.base_model_id == None)) all_models = result.scalars().all() @@ -250,7 +247,7 @@ class ModelsTable: ] async def get_models_by_user_id( - self, user_id: str, permission: str = 'write', db: Optional[AsyncSession] = None + self, user_id: str, permission: str = 'write', db: AsyncSession | None = None ) -> list[ModelUserResponse]: models = await self.get_models(db=db) user_groups = await Groups.get_groups_by_member_id(user_id, db=db) @@ -287,7 +284,7 @@ class ModelsTable: filter: dict = {}, skip: int = 0, limit: int = 30, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> ModelListResponse: async with get_async_db_context(db) as db: stmt = select(Model, User).outerjoin(User, User.id == Model.user_id) @@ -391,7 +388,7 @@ class ModelsTable: return ModelListResponse(items=models, total=total) - async def get_model_meta_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[tuple[dict, int]]: + async def get_model_meta_by_id(self, id: str, db: AsyncSession | None = None) -> tuple[dict, int | None]: """Return (meta, updated_at) for a model, skipping access grant resolution.""" try: async with get_async_db_context(db) as db: @@ -404,7 +401,7 @@ class ModelsTable: self, user_id: str, is_admin: bool = False, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> set[str]: """Extract unique tag names from model meta, querying only the meta column.""" async with get_async_db_context(db) as db: @@ -437,7 +434,7 @@ class ModelsTable: return tags_set - async def get_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]: + async def get_model_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None: try: async with get_async_db_context(db) as db: model = await db.get(Model, id) @@ -445,7 +442,7 @@ class ModelsTable: except Exception: return None - async def get_models_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[ModelModel]: + async def get_models_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[ModelModel]: try: async with get_async_db_context(db) as db: result = await db.execute(select(Model).filter(Model.id.in_(ids))) @@ -463,7 +460,7 @@ class ModelsTable: except Exception: return [] - async def toggle_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]: + async def toggle_model_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None: async with get_async_db_context(db) as db: try: result = await db.execute(select(Model).filter_by(id=id)) @@ -480,9 +477,7 @@ class ModelsTable: except Exception: return None - async def update_model_by_id( - self, id: str, model: ModelForm, db: Optional[AsyncSession] = None - ) -> Optional[ModelModel]: + async def update_model_by_id(self, id: str, model: ModelForm, db: AsyncSession | None = None) -> ModelModel | None: try: async with get_async_db_context(db) as db: # update only the fields that are present in the model @@ -499,7 +494,7 @@ class ModelsTable: log.exception(f'Failed to update the model by id {id}: {e}') return None - async def update_model_updated_at_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]: + async def update_model_updated_at_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(Model).filter_by(id=id)) @@ -514,7 +509,7 @@ class ModelsTable: log.exception(f'Failed to update the model updated_at by id {id}: {e}') return None - async def delete_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_model_by_id(self, id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await AccessGrants.revoke_all_access('model', id, db=db) @@ -525,7 +520,7 @@ class ModelsTable: except Exception: return False - async def delete_all_models(self, db: Optional[AsyncSession] = None) -> bool: + async def delete_all_models(self, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: result = await db.execute(select(Model.id)) @@ -540,7 +535,7 @@ class ModelsTable: return False async def sync_models( - self, user_id: str, models: list[ModelModel], db: Optional[AsyncSession] = None + self, user_id: str, models: list[ModelModel], db: AsyncSession | None = None ) -> list[ModelModel]: try: async with get_async_db_context(db) as db: diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index f651d226ca..9a858bd2f6 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -1,19 +1,16 @@ import json import time import uuid -from typing import Optional from functools import lru_cache +from typing import Optional -from sqlalchemy import Boolean, select, delete, update, or_, func, cast -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, get_async_db_context -from open_webui.models.groups import Groups -from open_webui.models.users import User, UserModel, Users, UserResponse from open_webui.models.access_grants import AccessGrantModel, AccessGrants - - +from open_webui.models.groups import Groups +from open_webui.models.users import User, UserModel, UserResponse, Users from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import BigInteger, Column, Text, JSON, ForeignKey +from sqlalchemy import JSON, BigInteger, Boolean, Column, ForeignKey, Text, cast, delete, func, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession #################### # Note DB Schema diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index c43567f670..0619bd574a 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -1,20 +1,17 @@ -import time -import logging -import uuid -from typing import Optional, List import base64 import hashlib import json +import logging +import time +import uuid +from typing import List, Optional from cryptography.fernet import Fernet - -from sqlalchemy import select, delete, update -from sqlalchemy.ext.asyncio import AsyncSession -from open_webui.internal.db import Base, get_async_db_context from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY - +from open_webui.internal.db import Base, get_async_db_context from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, Text, Index +from sqlalchemy import BigInteger, Column, Index, String, Text, delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/prompt_history.py b/backend/open_webui/models/prompt_history.py index 5d0f4a65b2..a73032c4bb 100644 --- a/backend/open_webui/models/prompt_history.py +++ b/backend/open_webui/models/prompt_history.py @@ -1,18 +1,16 @@ """Prompt history model for version tracking.""" +import difflib +import json import time import uuid from typing import Optional -import json -import difflib -from sqlalchemy import select, delete, func -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, get_async_db_context -from open_webui.models.users import Users, UserResponse - +from open_webui.models.users import UserResponse, Users from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, Text, JSON, Index +from sqlalchemy import JSON, BigInteger, Column, Index, Text, delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession #################### # PromptHistory DB Schema diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 5a3e35d23d..732c868227 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -1,19 +1,18 @@ +from __future__ import annotations + import json import time import uuid from typing import Optional -from sqlalchemy import select, delete, update, or_, func, text, cast, String -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context -from open_webui.models.groups import Groups -from open_webui.models.users import Users, User, UserModel, UserResponse -from open_webui.models.prompt_history import PromptHistories from open_webui.models.access_grants import AccessGrantModel, AccessGrants - - +from open_webui.models.groups import Groups +from open_webui.models.prompt_history import PromptHistories +from open_webui.models.users import User, UserModel, UserResponse, Users from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import BigInteger, Boolean, Column, Text, JSON +from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, text, update +from sqlalchemy.ext.asyncio import AsyncSession #################### # Prompts DB Schema @@ -40,18 +39,18 @@ class Prompt(Base): class PromptModel(BaseModel): - id: Optional[str] = None + id: str | None = None command: str user_id: str name: str content: str - data: Optional[dict] = None - meta: Optional[dict] = None - tags: Optional[list[str]] = None - is_active: Optional[bool] = True - version_id: Optional[str] = None - created_at: Optional[int] = None - updated_at: Optional[int] = None + data: dict | None = None + meta: dict | None = None + tags: list[str | None] = None + is_active: bool | None = True + version_id: str | None = None + created_at: int | None = None + updated_at: int | None = None access_grants: list[AccessGrantModel] = Field(default_factory=list) model_config = ConfigDict(from_attributes=True) @@ -63,11 +62,11 @@ class PromptModel(BaseModel): class PromptUserResponse(PromptModel): - user: Optional[UserResponse] = None + user: UserResponse | None = None class PromptAccessResponse(PromptUserResponse): - write_access: Optional[bool] = False + write_access: bool | None = False class PromptListResponse(BaseModel): @@ -84,24 +83,24 @@ class PromptForm(BaseModel): command: str name: str # Changed from title content: str - data: Optional[dict] = None - meta: Optional[dict] = None - tags: Optional[list[str]] = None - access_grants: Optional[list[dict]] = None - version_id: Optional[str] = None # Active version - commit_message: Optional[str] = None # For history tracking - is_production: Optional[bool] = True # Whether to set new version as production + data: dict | None = None + meta: dict | None = None + tags: list[str | None] = None + access_grants: list[dict | None] = None + version_id: str | None = None # Active version + commit_message: str | None = None # For history tracking + is_production: bool | None = True # Whether to set new version as production class PromptsTable: - async def _get_access_grants(self, prompt_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]: + async def _get_access_grants(self, prompt_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]: return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db) async def _to_prompt_model( self, prompt: Prompt, - access_grants: Optional[list[AccessGrantModel]] = None, - db: Optional[AsyncSession] = None, + access_grants: list[AccessGrantModel | None] = None, + db: AsyncSession | None = None, ) -> PromptModel: prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'}) prompt_data['access_grants'] = ( @@ -110,8 +109,8 @@ class PromptsTable: return PromptModel.model_validate(prompt_data) async def insert_new_prompt( - self, user_id: str, form_data: PromptForm, db: Optional[AsyncSession] = None - ) -> Optional[PromptModel]: + self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None + ) -> PromptModel | None: now = int(time.time()) prompt_id = str(uuid.uuid4()) @@ -171,7 +170,7 @@ class PromptsTable: except Exception: return None - async def get_prompt_by_id(self, prompt_id: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]: + async def get_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None: """Get prompt by UUID.""" try: async with get_async_db_context(db) as db: @@ -183,7 +182,7 @@ class PromptsTable: except Exception: return None - async def get_prompt_by_command(self, command: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]: + async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(Prompt).filter_by(command=command)) @@ -194,7 +193,7 @@ class PromptsTable: except Exception: return None - async def get_prompts(self, db: Optional[AsyncSession] = None) -> list[PromptUserResponse]: + async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]: async with get_async_db_context(db) as db: result = await db.execute( select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()) @@ -229,7 +228,7 @@ class PromptsTable: return prompts async def get_prompts_by_user_id( - self, user_id: str, permission: str = 'write', db: Optional[AsyncSession] = None + self, user_id: str, permission: str = 'write', db: AsyncSession | None = None ) -> list[PromptUserResponse]: async with get_async_db_context(db) as db: user_groups = await Groups.get_groups_by_member_id(user_id, db=db) @@ -283,7 +282,7 @@ class PromptsTable: filter: dict = {}, skip: int = 0, limit: int = 30, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> PromptListResponse: async with get_async_db_context(db) as db: # Join with User table for user filtering and sorting @@ -404,8 +403,8 @@ class PromptsTable: command: str, form_data: PromptForm, user_id: str, - db: Optional[AsyncSession] = None, - ) -> Optional[PromptModel]: + db: AsyncSession | None = None, + ) -> PromptModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(Prompt).filter_by(command=command)) @@ -470,8 +469,8 @@ class PromptsTable: prompt_id: str, form_data: PromptForm, user_id: str, - db: Optional[AsyncSession] = None, - ) -> Optional[PromptModel]: + db: AsyncSession | None = None, + ) -> PromptModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(Prompt).filter_by(id=prompt_id)) @@ -545,9 +544,9 @@ class PromptsTable: prompt_id: str, name: str, command: str, - tags: Optional[list[str]] = None, - db: Optional[AsyncSession] = None, - ) -> Optional[PromptModel]: + tags: list[str | None] = None, + db: AsyncSession | None = None, + ) -> PromptModel | None: """Update only name, command, and tags (no history created).""" try: async with get_async_db_context(db) as db: @@ -573,8 +572,8 @@ class PromptsTable: self, prompt_id: str, version_id: str, - db: Optional[AsyncSession] = None, - ) -> Optional[PromptModel]: + db: AsyncSession | None = None, + ) -> PromptModel | None: """Set the active version of a prompt and restore content from that version's snapshot.""" try: async with get_async_db_context(db) as db: @@ -606,7 +605,7 @@ class PromptsTable: except Exception: return None - async def toggle_prompt_active(self, prompt_id: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]: + async def toggle_prompt_active(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None: """Toggle the is_active flag on a prompt.""" try: async with get_async_db_context(db) as db: @@ -622,7 +621,7 @@ class PromptsTable: except Exception: return None - async def delete_prompt_by_command(self, command: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> bool: """Permanently delete a prompt and its history.""" try: async with get_async_db_context(db) as db: @@ -639,7 +638,7 @@ class PromptsTable: except Exception: return False - async def delete_prompt_by_id(self, prompt_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> bool: """Permanently delete a prompt and its history.""" try: async with get_async_db_context(db) as db: @@ -656,7 +655,7 @@ class PromptsTable: except Exception: return False - async def get_tags(self, db: Optional[AsyncSession] = None) -> list[str]: + async def get_tags(self, db: AsyncSession | None = None) -> list[str]: try: async with get_async_db_context(db) as db: result = await db.execute(select(Prompt.tags).filter(Prompt.is_active == True)) @@ -670,7 +669,7 @@ class PromptsTable: except Exception: return [] - async def get_tags_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[str]: + async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[str]: try: async with get_async_db_context(db) as db: user_groups = await Groups.get_groups_by_member_id(user_id, db=db) diff --git a/backend/open_webui/models/shared_chats.py b/backend/open_webui/models/shared_chats.py index 37a3fea852..9132f11301 100644 --- a/backend/open_webui/models/shared_chats.py +++ b/backend/open_webui/models/shared_chats.py @@ -3,12 +3,10 @@ import time import uuid from typing import Optional -from sqlalchemy import select, delete -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context - from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, ForeignKey, Text, JSON +from sqlalchemy import JSON, BigInteger, Column, ForeignKey, Text, delete, select +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/skills.py b/backend/open_webui/models/skills.py index 0fc6dfc52d..5bc8b54efc 100644 --- a/backend/open_webui/models/skills.py +++ b/backend/open_webui/models/skills.py @@ -2,15 +2,13 @@ import logging import time from typing import Optional -from sqlalchemy import select, delete, update, or_ -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, get_async_db_context -from open_webui.models.users import Users, User, UserModel, UserResponse -from open_webui.models.groups import Groups from open_webui.models.access_grants import AccessGrantModel, AccessGrants - +from open_webui.models.groups import Groups +from open_webui.models.users import User, UserModel, UserResponse, Users from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, func +from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, delete, func, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index ee2baefc01..defd424d59 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -1,15 +1,14 @@ +from __future__ import annotations + import logging import time import uuid from typing import Optional -from sqlalchemy import select, delete -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context - - from pydantic import BaseModel, ConfigDict -from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index +from sqlalchemy import JSON, BigInteger, Column, Index, PrimaryKeyConstraint, String, delete, select +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -39,7 +38,7 @@ class TagModel(BaseModel): id: str name: str user_id: str - meta: Optional[dict] = None + meta: dict | None = None model_config = ConfigDict(from_attributes=True) @@ -54,7 +53,7 @@ class TagChatIdForm(BaseModel): class TagTable: - async def insert_new_tag(self, name: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[TagModel]: + async def insert_new_tag(self, name: str, user_id: str, db: AsyncSession | None = None) -> TagModel | None: async with get_async_db_context(db) as db: id = name.replace(' ', '_').lower() tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name}) @@ -72,8 +71,8 @@ class TagTable: return None async def get_tag_by_name_and_user_id( - self, name: str, user_id: str, db: Optional[AsyncSession] = None - ) -> Optional[TagModel]: + self, name: str, user_id: str, db: AsyncSession | None = None + ) -> TagModel | None: try: id = name.replace(' ', '_').lower() async with get_async_db_context(db) as db: @@ -83,19 +82,19 @@ class TagTable: except Exception: return None - async def get_tags_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[TagModel]: + async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[TagModel]: async with get_async_db_context(db) as db: result = await db.execute(select(Tag).filter_by(user_id=user_id)) return [TagModel.model_validate(tag) for tag in result.scalars().all()] async def get_tags_by_ids_and_user_id( - self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None + self, ids: list[str], user_id: str, db: AsyncSession | None = None ) -> list[TagModel]: async with get_async_db_context(db) as db: result = await db.execute(select(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id)) return [TagModel.model_validate(tag) for tag in result.scalars().all()] - async def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: id = name.replace(' ', '_').lower() @@ -108,7 +107,7 @@ class TagTable: return False async def delete_tags_by_ids_and_user_id( - self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None + self, ids: list[str], user_id: str, db: AsyncSession | None = None ) -> bool: """Delete all tags whose id is in *ids* for the given user, in one query.""" if not ids: @@ -122,7 +121,7 @@ class TagTable: log.error(f'delete_tags_by_ids: {e}') return False - async def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[AsyncSession] = None) -> None: + async def ensure_tags_exist(self, names: list[str], user_id: str, db: AsyncSession | None = None) -> None: """Create tag rows for any *names* that don't already exist for *user_id*.""" if not names: return diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index 70035121aa..9a1a202292 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -1,16 +1,16 @@ +from __future__ import annotations + import logging import time from typing import Optional -from sqlalchemy import select, delete, update -from sqlalchemy.ext.asyncio import AsyncSession from open_webui.internal.db import Base, JSONField, get_async_db_context -from open_webui.models.users import Users, UserResponse -from open_webui.models.groups import Groups from open_webui.models.access_grants import AccessGrantModel, AccessGrants - +from open_webui.models.groups import Groups +from open_webui.models.users import UserResponse, Users from pydantic import BaseModel, ConfigDict, Field -from sqlalchemy import BigInteger, Column, String, Text +from sqlalchemy import BigInteger, Column, String, Text, delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -37,8 +37,8 @@ class Tool(Base): class ToolMeta(BaseModel): - description: Optional[str] = None - manifest: Optional[dict] = {} + description: str | None = None + manifest: dict | None = {} class ToolModel(BaseModel): @@ -62,7 +62,7 @@ class ToolModel(BaseModel): class ToolUserModel(ToolModel): - user: Optional[UserResponse] = None + user: UserResponse | None = None class ToolResponse(BaseModel): @@ -76,13 +76,13 @@ class ToolResponse(BaseModel): class ToolUserResponse(ToolResponse): - user: Optional[UserResponse] = None + user: UserResponse | None = None model_config = ConfigDict(extra='allow') class ToolAccessResponse(ToolUserResponse): - write_access: Optional[bool] = False + write_access: bool | None = False class ToolForm(BaseModel): @@ -90,22 +90,22 @@ class ToolForm(BaseModel): name: str content: str meta: ToolMeta - access_grants: Optional[list[dict]] = None + access_grants: list[dict | None] = None class ToolValves(BaseModel): - valves: Optional[dict] = None + valves: dict | None = None class ToolsTable: - async def _get_access_grants(self, tool_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]: + async def _get_access_grants(self, tool_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]: return await AccessGrants.get_grants_by_resource('tool', tool_id, db=db) async def _to_tool_model( self, tool: Tool, - access_grants: Optional[list[AccessGrantModel]] = None, - db: Optional[AsyncSession] = None, + access_grants: list[AccessGrantModel | None] = None, + db: AsyncSession | None = None, ) -> ToolModel: tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'}) tool_data['access_grants'] = ( @@ -118,8 +118,8 @@ class ToolsTable: user_id: str, form_data: ToolForm, specs: list[dict], - db: Optional[AsyncSession] = None, - ) -> Optional[ToolModel]: + db: AsyncSession | None = None, + ) -> ToolModel | None: async with get_async_db_context(db) as db: try: result = Tool( @@ -143,7 +143,7 @@ class ToolsTable: log.exception(f'Error creating a new tool: {e}') return None - async def get_tool_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ToolModel]: + async def get_tool_by_id(self, id: str, db: AsyncSession | None = None) -> ToolModel | None: try: async with get_async_db_context(db) as db: tool = await db.get(Tool, id) @@ -151,7 +151,7 @@ class ToolsTable: except Exception: return None - async def get_tools(self, defer_content: bool = False, db: Optional[AsyncSession] = None) -> list[ToolUserModel]: + async def get_tools(self, defer_content: bool = False, db: AsyncSession | None = None) -> list[ToolUserModel]: async with get_async_db_context(db) as db: stmt = select(Tool).order_by(Tool.updated_at.desc()) if defer_content: @@ -190,7 +190,7 @@ class ToolsTable: user_id: str, permission: str = 'write', defer_content: bool = False, - db: Optional[AsyncSession] = None, + db: AsyncSession | None = None, ) -> list[ToolUserModel]: tools = await self.get_tools(defer_content=defer_content, db=db) user_groups = await Groups.get_groups_by_member_id(user_id, db=db) @@ -211,7 +211,7 @@ class ToolsTable: result.append(tool) return result - async def get_tool_valves_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[dict]: + async def get_tool_valves_by_id(self, id: str, db: AsyncSession | None = None) -> dict | None: try: async with get_async_db_context(db) as db: tool = await db.get(Tool, id) @@ -221,8 +221,8 @@ class ToolsTable: return None async def update_tool_valves_by_id( - self, id: str, valves: dict, db: Optional[AsyncSession] = None - ) -> Optional[ToolValves]: + self, id: str, valves: dict, db: AsyncSession | None = None + ) -> ToolValves | None: try: async with get_async_db_context(db) as db: await db.execute(update(Tool).filter_by(id=id).values(valves=valves, updated_at=int(time.time()))) @@ -232,8 +232,8 @@ class ToolsTable: return None async def get_user_valves_by_id_and_user_id( - self, id: str, user_id: str, db: Optional[AsyncSession] = None - ) -> Optional[dict]: + self, id: str, user_id: str, db: AsyncSession | None = None + ) -> dict | None: try: user = await Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} @@ -250,8 +250,8 @@ class ToolsTable: return None async def update_user_valves_by_id_and_user_id( - self, id: str, user_id: str, valves: dict, db: Optional[AsyncSession] = None - ) -> Optional[dict]: + self, id: str, user_id: str, valves: dict, db: AsyncSession | None = None + ) -> dict | None: try: user = await Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} @@ -272,7 +272,7 @@ class ToolsTable: log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}') return None - async def update_tool_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[ToolModel]: + async def update_tool_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> ToolModel | None: try: async with get_async_db_context(db) as db: access_grants = updated.pop('access_grants', None) @@ -287,7 +287,7 @@ class ToolsTable: except Exception: return None - async def delete_tool_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_tool_by_id(self, id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await AccessGrants.revoke_all_access('tool', id, db=db) diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 025e79bd8a..56c3c21e63 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -1,29 +1,33 @@ +from __future__ import annotations + +import datetime import time from typing import Optional -from sqlalchemy import select, delete, update, func, or_, case, exists -from sqlalchemy.ext.asyncio import AsyncSession -from open_webui.internal.db import Base, JSONField, get_async_db_context - from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL - +from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.utils.misc import throttle from open_webui.utils.validate import validate_profile_image_url - from pydantic import BaseModel, ConfigDict, field_validator, model_validator from sqlalchemy import ( - BigInteger, JSON, - Column, - String, + BigInteger, Boolean, - Text, + Column, Date, + String, + Text, + case, cast, + delete, + exists, + func, + or_, + select, + update, ) from sqlalchemy.dialects.postgresql import JSONB - -import datetime +from sqlalchemy.ext.asyncio import AsyncSession #################### # User DB Schema @@ -33,7 +37,7 @@ import datetime class UserSettings(BaseModel): - ui: Optional[dict] = {} + ui: dict | None = {} model_config = ConfigDict(extra='allow') pass @@ -76,29 +80,29 @@ class UserModel(BaseModel): id: str email: str - username: Optional[str] = None + username: str | None = None role: str = 'pending' name: str - profile_image_url: Optional[str] = None - profile_banner_image_url: Optional[str] = None + profile_image_url: str | None = None + profile_banner_image_url: str | None = None - bio: Optional[str] = None - gender: Optional[str] = None - date_of_birth: Optional[datetime.date] = None - timezone: Optional[str] = None + bio: str | None = None + gender: str | None = None + date_of_birth: datetime.date | None = None + timezone: str | None = None - presence_state: Optional[str] = None - status_emoji: Optional[str] = None - status_message: Optional[str] = None - status_expires_at: Optional[int] = None + presence_state: str | None = None + status_emoji: str | None = None + status_message: str | None = None + status_expires_at: int | None = None - info: Optional[dict] = None - settings: Optional[UserSettings] = None + info: dict | None = None + settings: UserSettings | None = None - oauth: Optional[dict] = None - scim: Optional[dict] = None + oauth: dict | None = None + scim: dict | None = None last_active_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -136,9 +140,9 @@ class ApiKeyModel(BaseModel): id: str user_id: str key: str - data: Optional[dict] = None - expires_at: Optional[int] = None - last_used_at: Optional[int] = None + data: dict | None = None + expires_at: int | None = None + last_used_at: int | None = None created_at: int # timestamp in epoch updated_at: int # timestamp in epoch @@ -153,9 +157,9 @@ class ApiKeyModel(BaseModel): class UpdateProfileForm(BaseModel): profile_image_url: str name: str - bio: Optional[str] = None - gender: Optional[str] = None - date_of_birth: Optional[datetime.date] = None + bio: str | None = None + gender: str | None = None + date_of_birth: datetime.date | None = None @field_validator('profile_image_url') @classmethod @@ -182,9 +186,9 @@ class UserGroupIdsListResponse(BaseModel): class UserStatus(BaseModel): - status_emoji: Optional[str] = None - status_message: Optional[str] = None - status_expires_at: Optional[int] = None + status_emoji: str | None = None + status_message: str | None = None + status_expires_at: int | None = None class UserInfoResponse(UserStatus): @@ -192,8 +196,8 @@ class UserInfoResponse(UserStatus): name: str email: str role: str - bio: Optional[str] = None - groups: Optional[list] = [] + bio: str | None = None + groups: list | None = [] is_active: bool = False @@ -205,7 +209,7 @@ class UserIdNameResponse(BaseModel): class UserIdNameStatusResponse(UserStatus): id: str name: str - is_active: Optional[bool] = None + is_active: bool | None = None class UserInfoListResponse(BaseModel): @@ -239,15 +243,15 @@ class UserRoleUpdateForm(BaseModel): class UserUpdateForm(BaseModel): - role: Optional[str] = None - name: Optional[str] = None - email: Optional[str] = None - profile_image_url: Optional[str] = None - password: Optional[str] = None + role: str | None = None + name: str | None = None + email: str | None = None + profile_image_url: str | None = None + password: str | None = None @field_validator('profile_image_url', mode='before') @classmethod - def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]: + def check_profile_image_url(cls, v: str | None) -> str | None: if v is None: return v return validate_profile_image_url(v) @@ -261,10 +265,10 @@ class UsersTable: email: str, profile_image_url: str = '/user.png', role: str = 'pending', - username: Optional[str] = None, - oauth: Optional[dict] = None, - db: Optional[AsyncSession] = None, - ) -> Optional[UserModel]: + username: str | None = None, + oauth: dict | None = None, + db: AsyncSession | None = None, + ) -> UserModel | None: async with get_async_db_context(db) as db: user = UserModel( **{ @@ -289,7 +293,7 @@ class UsersTable: else: return None - async def get_user_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def get_user_by_id(self, id: str, db: AsyncSession | None = None) -> UserModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -298,7 +302,7 @@ class UsersTable: except Exception: return None - async def get_user_by_api_key(self, api_key: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def get_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None: try: async with get_async_db_context(db) as db: result = await db.execute( @@ -309,7 +313,7 @@ class UsersTable: except Exception: return None - async def get_user_by_email(self, email: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def get_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter(func.lower(User.email) == email.lower())) @@ -318,9 +322,7 @@ class UsersTable: except Exception: return None - async def get_user_by_oauth_sub( - self, provider: str, sub: str, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + async def get_user_by_oauth_sub(self, provider: str, sub: str, db: AsyncSession | None = None) -> UserModel | None: try: async with get_async_db_context(db) as db: dialect_name = db.bind.dialect.name @@ -339,8 +341,8 @@ class UsersTable: return None async def get_user_by_scim_external_id( - self, provider: str, external_id: str, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + self, provider: str, external_id: str, db: AsyncSession | None = None + ) -> UserModel | None: try: async with get_async_db_context(db) as db: dialect_name = db.bind.dialect.name @@ -359,15 +361,15 @@ class UsersTable: async def get_users( self, - filter: Optional[dict] = None, - skip: Optional[int] = None, - limit: Optional[int] = None, - db: Optional[AsyncSession] = None, + filter: dict | None = None, + skip: int | None = None, + limit: int | None = None, + db: AsyncSession | None = None, ) -> dict: async with get_async_db_context(db) as db: # Import here to avoid circular imports - from open_webui.models.groups import GroupMember from open_webui.models.channels import ChannelMember + from open_webui.models.groups import GroupMember # Join GroupMember so we can order by group_id when requested stmt = select(User) @@ -501,7 +503,7 @@ class UsersTable: 'total': total, } - async def get_users_by_group_id(self, group_id: str, db: Optional[AsyncSession] = None) -> list[UserModel]: + async def get_users_by_group_id(self, group_id: str, db: AsyncSession | None = None) -> list[UserModel]: async with get_async_db_context(db) as db: from open_webui.models.groups import GroupMember @@ -511,25 +513,23 @@ class UsersTable: users = result.scalars().all() return [UserModel.model_validate(user) for user in users] - async def get_users_by_user_ids( - self, user_ids: list[str], db: Optional[AsyncSession] = None - ) -> list[UserStatusModel]: + async def get_users_by_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[UserStatusModel]: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter(User.id.in_(user_ids))) users = result.scalars().all() return [UserModel.model_validate(user) for user in users] - async def get_num_users(self, db: Optional[AsyncSession] = None) -> Optional[int]: + async def get_num_users(self, db: AsyncSession | None = None) -> int | None: async with get_async_db_context(db) as db: result = await db.execute(select(func.count()).select_from(User)) return result.scalar() - async def has_users(self, db: Optional[AsyncSession] = None) -> bool: + async def has_users(self, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as db: result = await db.execute(select(exists(select(User)))) return result.scalar() - async def get_first_user(self, db: Optional[AsyncSession] = None) -> UserModel: + async def get_first_user(self, db: AsyncSession | None = None) -> UserModel: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).order_by(User.created_at).limit(1)) @@ -538,7 +538,7 @@ class UsersTable: except Exception: return None - async def get_user_webhook_url_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[str]: + async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -551,7 +551,7 @@ class UsersTable: except Exception: return None - async def get_num_users_active_today(self, db: Optional[AsyncSession] = None) -> Optional[int]: + async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None: async with get_async_db_context(db) as db: current_timestamp = int(datetime.datetime.now().timestamp()) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) @@ -560,9 +560,7 @@ class UsersTable: ) return result.scalar() - async def update_user_role_by_id( - self, id: str, role: str, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -577,8 +575,8 @@ class UsersTable: return None async def update_user_status_by_id( - self, id: str, form_data: UserStatus, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + self, id: str, form_data: UserStatus, db: AsyncSession | None = None + ) -> UserModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -594,8 +592,8 @@ class UsersTable: return None async def update_user_profile_image_url_by_id( - self, id: str, profile_image_url: str, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + self, id: str, profile_image_url: str, db: AsyncSession | None = None + ) -> UserModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -610,7 +608,7 @@ class UsersTable: return None @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) - async def update_last_active_by_id(self, id: str, db: Optional[AsyncSession] = None) -> None: + async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None: try: async with get_async_db_context(db) as db: await db.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time()))) @@ -619,8 +617,8 @@ class UsersTable: pass async def update_user_oauth_by_id( - self, id: str, provider: str, sub: str, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + self, id: str, provider: str, sub: str, db: AsyncSession | None = None + ) -> UserModel | None: """ Update or insert an OAuth provider/sub pair into the user's oauth JSON field. Example resulting structure: @@ -656,8 +654,8 @@ class UsersTable: id: str, provider: str, external_id: str, - db: Optional[AsyncSession] = None, - ) -> Optional[UserModel]: + db: AsyncSession | None = None, + ) -> UserModel | None: """ Update or insert a SCIM provider/external_id pair into the user's scim JSON field. Example resulting structure: @@ -684,7 +682,7 @@ class UsersTable: except Exception: return None - async def update_user_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -701,8 +699,8 @@ class UsersTable: return None async def update_user_settings_by_id( - self, id: str, updated: dict, db: Optional[AsyncSession] = None - ) -> Optional[UserModel]: + self, id: str, updated: dict, db: AsyncSession | None = None + ) -> UserModel | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -726,10 +724,10 @@ class UsersTable: except Exception: return None - async def delete_user_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_user_by_id(self, id: str, db: AsyncSession | None = None) -> bool: try: - from open_webui.models.groups import Groups from open_webui.models.chats import Chats + from open_webui.models.groups import Groups # Remove User from Groups await Groups.remove_user_from_all_groups(id) @@ -748,7 +746,7 @@ class UsersTable: except Exception: return False - async def get_user_api_key_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[str]: + async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None: try: async with get_async_db_context(db) as db: result = await db.execute(select(ApiKey).filter_by(user_id=id)) @@ -757,7 +755,7 @@ class UsersTable: except Exception: return None - async def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[AsyncSession] = None) -> bool: + async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await db.execute(delete(ApiKey).filter_by(user_id=id)) @@ -779,7 +777,7 @@ class UsersTable: except Exception: return False - async def delete_user_api_key_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool: try: async with get_async_db_context(db) as db: await db.execute(delete(ApiKey).filter_by(user_id=id)) @@ -788,13 +786,13 @@ class UsersTable: except Exception: return False - async def get_valid_user_ids(self, user_ids: list[str], db: Optional[AsyncSession] = None) -> list[str]: + async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter(User.id.in_(user_ids))) users = result.scalars().all() return [user.id for user in users] - async def get_super_admin_user(self, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(role='admin').limit(1)) user = result.scalars().first() @@ -803,7 +801,7 @@ class UsersTable: else: return None - async def get_active_user_count(self, db: Optional[AsyncSession] = None) -> int: + async def get_active_user_count(self, db: AsyncSession | None = None) -> int: async with get_async_db_context(db) as db: # Consider user active if last_active_at within the last 3 minutes three_minutes_ago = int(time.time()) - 180 @@ -820,7 +818,7 @@ class UsersTable: return user.last_active_at >= three_minutes_ago return False - async def is_user_active(self, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=user_id)) user = result.scalars().first() diff --git a/backend/open_webui/retrieval/loaders/datalab_marker.py b/backend/open_webui/retrieval/loaders/datalab_marker.py index dd4a763b70..be8cb9baaa 100644 --- a/backend/open_webui/retrieval/loaders/datalab_marker.py +++ b/backend/open_webui/retrieval/loaders/datalab_marker.py @@ -1,11 +1,12 @@ +import json +import logging import os import time -import requests -import logging -import json from typing import List, Optional -from langchain_core.documents import Document + +import requests from fastapi import HTTPException, status +from langchain_core.documents import Document log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/loaders/external_document.py b/backend/open_webui/retrieval/loaders/external_document.py index 77b1abfcd8..ddafc3124b 100644 --- a/backend/open_webui/retrieval/loaders/external_document.py +++ b/backend/open_webui/retrieval/loaders/external_document.py @@ -1,8 +1,9 @@ -import requests -import logging, os +import logging +import os from typing import Iterator, List, Union from urllib.parse import quote +import requests from langchain_core.document_loaders import BaseLoader from langchain_core.documents import Document from open_webui.utils.headers import include_user_info_headers diff --git a/backend/open_webui/retrieval/loaders/external_web.py b/backend/open_webui/retrieval/loaders/external_web.py index 64248427b3..e3fd0b2614 100644 --- a/backend/open_webui/retrieval/loaders/external_web.py +++ b/backend/open_webui/retrieval/loaders/external_web.py @@ -1,7 +1,7 @@ -import requests import logging from typing import Iterator, List, Union +import requests from langchain_core.document_loaders import BaseLoader from langchain_core.documents import Document diff --git a/backend/open_webui/retrieval/loaders/main.py b/backend/open_webui/retrieval/loaders/main.py index 2daa641bf2..b322c84279 100644 --- a/backend/open_webui/retrieval/loaders/main.py +++ b/backend/open_webui/retrieval/loaders/main.py @@ -1,10 +1,10 @@ import asyncio -import requests -import logging -import ftfy -import sys import json +import logging +import sys +import ftfy +import requests from azure.identity import DefaultAzureCredential from langchain_community.document_loaders import ( AzureAIDocumentIntelligenceLoader, @@ -17,16 +17,13 @@ from langchain_community.document_loaders import ( YoutubeLoader, ) from langchain_core.documents import Document - -from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader - -from open_webui.retrieval.loaders.mistral import MistralLoader +from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, GLOBAL_LOG_LEVEL, REQUESTS_VERIFY from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader +from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader from open_webui.retrieval.loaders.mineru import MinerULoader +from open_webui.retrieval.loaders.mistral import MistralLoader from open_webui.retrieval.loaders.paddleocr_vl import PaddleOCRVLLoader -from open_webui.env import GLOBAL_LOG_LEVEL, REQUESTS_VERIFY, AIOHTTP_CLIENT_SESSION_SSL - logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/loaders/mineru.py b/backend/open_webui/retrieval/loaders/mineru.py index 1f0848a613..63608f9bf9 100644 --- a/backend/open_webui/retrieval/loaders/mineru.py +++ b/backend/open_webui/retrieval/loaders/mineru.py @@ -1,12 +1,13 @@ -import os -import time -import requests import logging +import os import tempfile +import time import zipfile from typing import List, Optional -from langchain_core.documents import Document + +import requests from fastapi import HTTPException, status +from langchain_core.documents import Document log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/loaders/mistral.py b/backend/open_webui/retrieval/loaders/mistral.py index b3d274ee7c..465ea5d91e 100644 --- a/backend/open_webui/retrieval/loaders/mistral.py +++ b/backend/open_webui/retrieval/loaders/mistral.py @@ -1,15 +1,15 @@ -import requests -import aiohttp import asyncio import logging import os import sys import time -from typing import List, Dict, Any from contextlib import asynccontextmanager +from typing import Any, Dict, List +import aiohttp +import requests from langchain_core.documents import Document -from open_webui.env import GLOBAL_LOG_LEVEL, AIOHTTP_CLIENT_SESSION_SSL +from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, GLOBAL_LOG_LEVEL logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/loaders/paddleocr_vl.py b/backend/open_webui/retrieval/loaders/paddleocr_vl.py index b89369b2a4..40c185eab6 100644 --- a/backend/open_webui/retrieval/loaders/paddleocr_vl.py +++ b/backend/open_webui/retrieval/loaders/paddleocr_vl.py @@ -1,10 +1,10 @@ import base64 -import os -import requests import logging +import os import sys from typing import List +import requests from langchain_core.documents import Document from open_webui.env import GLOBAL_LOG_LEVEL diff --git a/backend/open_webui/retrieval/loaders/tavily.py b/backend/open_webui/retrieval/loaders/tavily.py index 742ac499cf..bdf70830e4 100644 --- a/backend/open_webui/retrieval/loaders/tavily.py +++ b/backend/open_webui/retrieval/loaders/tavily.py @@ -1,7 +1,7 @@ -import requests import logging from typing import Iterator, List, Literal, Union +import requests from langchain_core.document_loaders import BaseLoader from langchain_core.documents import Document diff --git a/backend/open_webui/retrieval/loaders/youtube.py b/backend/open_webui/retrieval/loaders/youtube.py index 34a1d20740..97c1a0c20c 100644 --- a/backend/open_webui/retrieval/loaders/youtube.py +++ b/backend/open_webui/retrieval/loaders/youtube.py @@ -1,8 +1,8 @@ import logging -from xml.etree.ElementTree import ParseError - from typing import Any, Dict, Generator, List, Optional, Sequence, Union from urllib.parse import parse_qs, urlparse +from xml.etree.ElementTree import ParseError + from langchain_core.documents import Document log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/models/base_reranker.py b/backend/open_webui/retrieval/models/base_reranker.py index 6be7a5649b..78002e087a 100644 --- a/backend/open_webui/retrieval/models/base_reranker.py +++ b/backend/open_webui/retrieval/models/base_reranker.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, List, Tuple +from typing import List, Optional, Tuple class BaseReranker(ABC): diff --git a/backend/open_webui/retrieval/models/colbert.py b/backend/open_webui/retrieval/models/colbert.py index ceb41824e3..d11bde1b09 100644 --- a/backend/open_webui/retrieval/models/colbert.py +++ b/backend/open_webui/retrieval/models/colbert.py @@ -1,11 +1,10 @@ -import os import logging -import torch +import os + import numpy as np +import torch from colbert.infra import ColBERTConfig from colbert.modeling.checkpoint import Checkpoint - - from open_webui.retrieval.models.base_reranker import BaseReranker log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/models/external.py b/backend/open_webui/retrieval/models/external.py index f04583b965..6d2849eb88 100644 --- a/backend/open_webui/retrieval/models/external.py +++ b/backend/open_webui/retrieval/models/external.py @@ -1,9 +1,8 @@ import logging -import requests -from typing import Optional, List, Tuple +from typing import List, Optional, Tuple from urllib.parse import quote - +import requests from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, REQUESTS_VERIFY from open_webui.retrieval.models.base_reranker import BaseReranker from open_webui.utils.headers import include_user_info_headers diff --git a/backend/open_webui/retrieval/utils.py b/backend/open_webui/retrieval/utils.py index 9635f5b2e1..277c24dd39 100644 --- a/backend/open_webui/retrieval/utils.py +++ b/backend/open_webui/retrieval/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import hashlib import logging @@ -516,8 +518,11 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict: async def query_collection( - request, collection_names: list[str], queries: list[str], - embedding_function, k: int, + request, + collection_names: list[str], + queries: list[str], + embedding_function, + k: int, ) -> dict: # When request is provided, try hybrid search + reranking if enabled if request and request.app.state.config.ENABLE_RAG_HYBRID_SEARCH: @@ -1109,7 +1114,7 @@ async def get_sources_from_items( hybrid_bm25_weight, hybrid_search, full_context=False, - user: Optional[UserModel] = None, + user: UserModel | None = None, ): log.debug(f'items: {items} {queries} {embedding_function} {reranking_function} {full_context}') @@ -1421,7 +1426,7 @@ class RerankCompressor(BaseDocumentCompressor): self, documents: Sequence[Document], query: str, - callbacks: Optional[Callbacks] = None, + callbacks: Callbacks | None = None, ) -> Sequence[Document]: """Compress retrieved documents given the query context. @@ -1440,7 +1445,7 @@ class RerankCompressor(BaseDocumentCompressor): self, documents: Sequence[Document], query: str, - callbacks: Optional[Callbacks] = None, + callbacks: Callbacks | None = None, ) -> Sequence[Document]: reranking = self.reranking_function is not None diff --git a/backend/open_webui/retrieval/vector/dbs/chroma.py b/backend/open_webui/retrieval/vector/dbs/chroma.py index 4ace732b2d..cd0a59eeaf 100755 --- a/backend/open_webui/retrieval/vector/dbs/chroma.py +++ b/backend/open_webui/retrieval/vector/dbs/chroma.py @@ -1,29 +1,27 @@ -import chromadb import logging -from chromadb import Settings -from chromadb.utils.batch_utils import create_batches - from typing import Optional -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) -from open_webui.retrieval.vector.utils import process_metadata - +import chromadb +from chromadb import Settings +from chromadb.utils.batch_utils import create_batches from open_webui.config import ( + CHROMA_CLIENT_AUTH_CREDENTIALS, + CHROMA_CLIENT_AUTH_PROVIDER, CHROMA_DATA_PATH, + CHROMA_DATABASE, + CHROMA_HTTP_HEADERS, CHROMA_HTTP_HOST, CHROMA_HTTP_PORT, - CHROMA_HTTP_HEADERS, CHROMA_HTTP_SSL, CHROMA_TENANT, - CHROMA_DATABASE, - CHROMA_CLIENT_AUTH_PROVIDER, - CHROMA_CLIENT_AUTH_CREDENTIALS, ) +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) +from open_webui.retrieval.vector.utils import process_metadata log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py index 201a5e1706..da3e6e93c8 100644 --- a/backend/open_webui/retrieval/vector/dbs/elasticsearch.py +++ b/backend/open_webui/retrieval/vector/dbs/elasticsearch.py @@ -2,28 +2,28 @@ NOTE: This vector database integration is community-supported and maintained on a best-effort basis. """ -from elasticsearch import Elasticsearch, BadRequestError -from typing import Optional import ssl -from elasticsearch.helpers import bulk, scan +from typing import Optional -from open_webui.retrieval.vector.utils import process_metadata -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) +from elasticsearch import BadRequestError, Elasticsearch +from elasticsearch.helpers import bulk, scan from open_webui.config import ( - ELASTICSEARCH_URL, - ELASTICSEARCH_CA_CERTS, ELASTICSEARCH_API_KEY, - ELASTICSEARCH_USERNAME, - ELASTICSEARCH_PASSWORD, + ELASTICSEARCH_CA_CERTS, ELASTICSEARCH_CLOUD_ID, ELASTICSEARCH_INDEX_PREFIX, + ELASTICSEARCH_PASSWORD, + ELASTICSEARCH_URL, + ELASTICSEARCH_USERNAME, SSL_ASSERT_FINGERPRINT, ) +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) +from open_webui.retrieval.vector.utils import process_metadata class ElasticsearchClient(VectorDBBase): diff --git a/backend/open_webui/retrieval/vector/dbs/mariadb_vector.py b/backend/open_webui/retrieval/vector/dbs/mariadb_vector.py index 1cb3563382..a8cf62f7b1 100644 --- a/backend/open_webui/retrieval/vector/dbs/mariadb_vector.py +++ b/backend/open_webui/retrieval/vector/dbs/mariadb_vector.py @@ -13,18 +13,15 @@ import sys from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple -from sqlalchemy import create_engine -from sqlalchemy.pool import NullPool, QueuePool - from open_webui.config import ( MARIADB_VECTOR_DB_URL, MARIADB_VECTOR_DISTANCE_STRATEGY, MARIADB_VECTOR_INDEX_M, MARIADB_VECTOR_INITIALIZE_MAX_VECTOR_LENGTH, - MARIADB_VECTOR_POOL_SIZE, MARIADB_VECTOR_POOL_MAX_OVERFLOW, - MARIADB_VECTOR_POOL_TIMEOUT, MARIADB_VECTOR_POOL_RECYCLE, + MARIADB_VECTOR_POOL_SIZE, + MARIADB_VECTOR_POOL_TIMEOUT, ) from open_webui.retrieval.vector.main import ( GetResult, @@ -33,6 +30,8 @@ from open_webui.retrieval.vector.main import ( VectorItem, ) from open_webui.retrieval.vector.utils import process_metadata +from sqlalchemy import create_engine +from sqlalchemy.pool import NullPool, QueuePool log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/vector/dbs/milvus.py b/backend/open_webui/retrieval/vector/dbs/milvus.py index 2f3d8f3890..9b356b1c69 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus.py @@ -2,33 +2,31 @@ NOTE: This vector database integration is community-supported and maintained on a best-effort basis. """ -from pymilvus import MilvusClient as Client -from pymilvus import FieldSchema, DataType -from pymilvus import connections, Collection - import json import logging from typing import Optional -from open_webui.retrieval.vector.utils import process_metadata -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) from open_webui.config import ( - MILVUS_URI, MILVUS_DB, - MILVUS_TOKEN, - MILVUS_INDEX_TYPE, - MILVUS_METRIC_TYPE, - MILVUS_HNSW_M, - MILVUS_HNSW_EFCONSTRUCTION, - MILVUS_IVF_FLAT_NLIST, MILVUS_DISKANN_MAX_DEGREE, MILVUS_DISKANN_SEARCH_LIST_SIZE, + MILVUS_HNSW_EFCONSTRUCTION, + MILVUS_HNSW_M, + MILVUS_INDEX_TYPE, + MILVUS_IVF_FLAT_NLIST, + MILVUS_METRIC_TYPE, + MILVUS_TOKEN, + MILVUS_URI, ) +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) +from open_webui.retrieval.vector.utils import process_metadata +from pymilvus import Collection, DataType, FieldSchema, connections +from pymilvus import MilvusClient as Client log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py index 93b4a8cbc4..d43e2e61d4 100644 --- a/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/milvus_multitenancy.py @@ -3,18 +3,18 @@ NOTE: This vector database integration is community-supported and maintained on """ import logging -from typing import Optional, Tuple, List, Dict, Any +from typing import Any, Dict, List, Optional, Tuple from open_webui.config import ( - MILVUS_URI, - MILVUS_TOKEN, - MILVUS_DB, MILVUS_COLLECTION_PREFIX, - MILVUS_INDEX_TYPE, - MILVUS_METRIC_TYPE, - MILVUS_HNSW_M, + MILVUS_DB, MILVUS_HNSW_EFCONSTRUCTION, + MILVUS_HNSW_M, + MILVUS_INDEX_TYPE, MILVUS_IVF_FLAT_NLIST, + MILVUS_METRIC_TYPE, + MILVUS_TOKEN, + MILVUS_URI, ) from open_webui.retrieval.vector.main import ( GetResult, @@ -23,12 +23,12 @@ from open_webui.retrieval.vector.main import ( VectorItem, ) from pymilvus import ( - connections, - utility, Collection, CollectionSchema, - FieldSchema, DataType, + FieldSchema, + connections, + utility, ) log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/vector/dbs/opengauss.py b/backend/open_webui/retrieval/vector/dbs/opengauss.py index ac97cf01fa..1c9d35253c 100644 --- a/backend/open_webui/retrieval/vector/dbs/opengauss.py +++ b/backend/open_webui/retrieval/vector/dbs/opengauss.py @@ -2,37 +2,36 @@ NOTE: This vector database integration is community-supported and maintained on a best-effort basis. """ -from typing import Optional, List, Dict, Any +import json import logging import re -import json +from typing import Any, Dict, List, Optional + +from pgvector.sqlalchemy import Vector from sqlalchemy import ( - func, - literal, + Column, + Integer, + LargeBinary, + MetaData, + Table, + Text, cast, column, create_engine, - Column, - Integer, - MetaData, - LargeBinary, + func, + literal, select, text, - Text, - Table, values, ) -from sqlalchemy.sql import true -from sqlalchemy.pool import NullPool, QueuePool - -from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker -from sqlalchemy.dialects.postgresql import JSONB, array -from pgvector.sqlalchemy import Vector -from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.exc import NoSuchTableError - -from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 from sqlalchemy.dialects import registry +from sqlalchemy.dialects.postgresql import JSONB, array +from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 +from sqlalchemy.exc import NoSuchTableError +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker +from sqlalchemy.pool import NullPool, QueuePool +from sqlalchemy.sql import true class OpenGaussDialect(PGDialect_psycopg2): @@ -56,23 +55,22 @@ class OpenGaussDialect(PGDialect_psycopg2): # Register dialect registry.register('opengauss', __name__, 'OpenGaussDialect') -from open_webui.retrieval.vector.utils import process_metadata -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) from open_webui.config import ( OPENGAUSS_DB_URL, OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH, - OPENGAUSS_POOL_SIZE, OPENGAUSS_POOL_MAX_OVERFLOW, - OPENGAUSS_POOL_TIMEOUT, OPENGAUSS_POOL_RECYCLE, + OPENGAUSS_POOL_SIZE, + OPENGAUSS_POOL_TIMEOUT, ) - from open_webui.env import SRC_LOG_LEVELS +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) +from open_webui.retrieval.vector.utils import process_metadata VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH Base = declarative_base() diff --git a/backend/open_webui/retrieval/vector/dbs/opensearch.py b/backend/open_webui/retrieval/vector/dbs/opensearch.py index a08dca7865..798802cd54 100644 --- a/backend/open_webui/retrieval/vector/dbs/opensearch.py +++ b/backend/open_webui/retrieval/vector/dbs/opensearch.py @@ -2,24 +2,24 @@ NOTE: This vector database integration is community-supported and maintained on a best-effort basis. """ -from opensearchpy import OpenSearch -from opensearchpy.helpers import bulk from typing import Optional -from open_webui.retrieval.vector.utils import process_metadata +from open_webui.config import ( + OPENSEARCH_CERT_VERIFY, + OPENSEARCH_PASSWORD, + OPENSEARCH_SSL, + OPENSEARCH_URI, + OPENSEARCH_USERNAME, +) from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, VectorDBBase, VectorItem, - SearchResult, - GetResult, -) -from open_webui.config import ( - OPENSEARCH_URI, - OPENSEARCH_SSL, - OPENSEARCH_CERT_VERIFY, - OPENSEARCH_USERNAME, - OPENSEARCH_PASSWORD, ) +from open_webui.retrieval.vector.utils import process_metadata +from opensearchpy import OpenSearch +from opensearchpy.helpers import bulk class OpenSearchClient(VectorDBBase): diff --git a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py index 9a5bd638d9..b09eacb81d 100644 --- a/backend/open_webui/retrieval/vector/dbs/oracle23ai.py +++ b/backend/open_webui/retrieval/vector/dbs/oracle23ai.py @@ -28,34 +28,33 @@ ORACLE_DB_POOL_MAX = 10 ORACLE_DB_POOL_INCREMENT = 1 """ -from typing import Optional, List, Dict, Any, Union -from decimal import Decimal +import array +import json import logging import os import threading import time -import json -import array +from decimal import Decimal +from typing import Any, Dict, List, Optional, Union + import oracledb - -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) - from open_webui.config import ( + ORACLE_DB_DSN, + ORACLE_DB_PASSWORD, + ORACLE_DB_POOL_INCREMENT, + ORACLE_DB_POOL_MAX, + ORACLE_DB_POOL_MIN, ORACLE_DB_USE_WALLET, ORACLE_DB_USER, - ORACLE_DB_PASSWORD, - ORACLE_DB_DSN, + ORACLE_VECTOR_LENGTH, ORACLE_WALLET_DIR, ORACLE_WALLET_PASSWORD, - ORACLE_VECTOR_LENGTH, - ORACLE_DB_POOL_MIN, - ORACLE_DB_POOL_MAX, - ORACLE_DB_POOL_INCREMENT, +) +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, ) log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/vector/dbs/pgvector.py b/backend/open_webui/retrieval/vector/dbs/pgvector.py index 90e65b9ad0..861d49bc1b 100644 --- a/backend/open_webui/retrieval/vector/dbs/pgvector.py +++ b/backend/open_webui/retrieval/vector/dbs/pgvector.py @@ -1,56 +1,54 @@ -from typing import Optional, List, Dict, Any, Tuple -import logging import json +import logging +from typing import Any, Dict, List, Optional, Tuple + +from open_webui.config import ( + PGVECTOR_CREATE_EXTENSION, + PGVECTOR_DB_URL, + PGVECTOR_HNSW_EF_CONSTRUCTION, + PGVECTOR_HNSW_M, + PGVECTOR_INDEX_METHOD, + PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH, + PGVECTOR_IVFFLAT_LISTS, + PGVECTOR_PGCRYPTO, + PGVECTOR_PGCRYPTO_KEY, + PGVECTOR_POOL_MAX_OVERFLOW, + PGVECTOR_POOL_RECYCLE, + PGVECTOR_POOL_SIZE, + PGVECTOR_POOL_TIMEOUT, + PGVECTOR_USE_HALFVEC, +) +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) +from open_webui.retrieval.vector.utils import process_metadata +from open_webui.utils.misc import sanitize_text_for_db +from pgvector.sqlalchemy import HALFVEC, Vector from sqlalchemy import ( - func, - literal, + Column, + Integer, + LargeBinary, + MetaData, + Table, + Text, cast, column, create_engine, - Column, - Integer, - MetaData, - LargeBinary, + func, + literal, select, text, - Text, - Table, values, ) -from sqlalchemy.sql import true -from sqlalchemy.pool import NullPool, QueuePool - -from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker from sqlalchemy.dialects.postgresql import JSONB, array -from pgvector.sqlalchemy import Vector, HALFVEC -from sqlalchemy.ext.mutable import MutableDict from sqlalchemy.exc import NoSuchTableError - - -from open_webui.retrieval.vector.utils import process_metadata -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) -from open_webui.utils.misc import sanitize_text_for_db -from open_webui.config import ( - PGVECTOR_DB_URL, - PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH, - PGVECTOR_CREATE_EXTENSION, - PGVECTOR_PGCRYPTO, - PGVECTOR_PGCRYPTO_KEY, - PGVECTOR_POOL_SIZE, - PGVECTOR_POOL_MAX_OVERFLOW, - PGVECTOR_POOL_TIMEOUT, - PGVECTOR_POOL_RECYCLE, - PGVECTOR_INDEX_METHOD, - PGVECTOR_HNSW_M, - PGVECTOR_HNSW_EF_CONSTRUCTION, - PGVECTOR_IVFFLAT_LISTS, - PGVECTOR_USE_HALFVEC, -) +from sqlalchemy.ext.mutable import MutableDict +from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker +from sqlalchemy.pool import NullPool, QueuePool +from sqlalchemy.sql import true VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH USE_HALFVEC = PGVECTOR_USE_HALFVEC diff --git a/backend/open_webui/retrieval/vector/dbs/pinecone.py b/backend/open_webui/retrieval/vector/dbs/pinecone.py index 6469ac9172..7e2b6e2dfa 100644 --- a/backend/open_webui/retrieval/vector/dbs/pinecone.py +++ b/backend/open_webui/retrieval/vector/dbs/pinecone.py @@ -2,9 +2,10 @@ NOTE: This vector database integration is community-supported and maintained on a best-effort basis. """ -from typing import Optional, List, Dict, Any, Union import logging import time # for measuring elapsed time +from typing import Any, Dict, List, Optional, Union + from pinecone import Pinecone, ServerlessSpec # Add gRPC support for better performance (Pinecone best practice) @@ -16,24 +17,23 @@ except ImportError: GRPC_AVAILABLE = False import asyncio # for async upserts -import functools # for partial binding in async tasks - import concurrent.futures # for parallel batch upserts +import functools # for partial binding in async tasks import random # for jitter in retry backoff -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) from open_webui.config import ( PINECONE_API_KEY, + PINECONE_CLOUD, + PINECONE_DIMENSION, PINECONE_ENVIRONMENT, PINECONE_INDEX_NAME, - PINECONE_DIMENSION, PINECONE_METRIC, - PINECONE_CLOUD, +) +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, ) from open_webui.retrieval.vector.utils import process_metadata diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant.py b/backend/open_webui/retrieval/vector/dbs/qdrant.py index f050bebeb5..acf4e61993 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant.py @@ -2,31 +2,30 @@ NOTE: This vector database integration is community-supported and maintained on a best-effort basis. """ -from typing import Optional import logging +from typing import Optional from urllib.parse import urlparse +from open_webui.config import ( + QDRANT_API_KEY, + QDRANT_COLLECTION_PREFIX, + QDRANT_GRPC_PORT, + QDRANT_HNSW_M, + QDRANT_ON_DISK, + QDRANT_PREFER_GRPC, + QDRANT_TIMEOUT, + QDRANT_URI, +) +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) from qdrant_client import QdrantClient as Qclient from qdrant_client.http.models import PointStruct from qdrant_client.models import models -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) -from open_webui.config import ( - QDRANT_URI, - QDRANT_API_KEY, - QDRANT_ON_DISK, - QDRANT_GRPC_PORT, - QDRANT_PREFER_GRPC, - QDRANT_COLLECTION_PREFIX, - QDRANT_TIMEOUT, - QDRANT_HNSW_M, -) - NO_LIMIT = 999999999 log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py index c3c2ba41d0..15e72f1d16 100644 --- a/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py +++ b/backend/open_webui/retrieval/vector/dbs/qdrant_multitenancy.py @@ -3,19 +3,19 @@ NOTE: This vector database integration is community-supported and maintained on """ import logging -from typing import Optional, Tuple, List, Dict, Any +from typing import Any, Dict, List, Optional, Tuple from urllib.parse import urlparse import grpc from open_webui.config import ( QDRANT_API_KEY, + QDRANT_COLLECTION_PREFIX, QDRANT_GRPC_PORT, + QDRANT_HNSW_M, QDRANT_ON_DISK, QDRANT_PREFER_GRPC, - QDRANT_URI, - QDRANT_COLLECTION_PREFIX, QDRANT_TIMEOUT, - QDRANT_HNSW_M, + QDRANT_URI, ) from open_webui.retrieval.vector.main import ( GetResult, diff --git a/backend/open_webui/retrieval/vector/dbs/s3vector.py b/backend/open_webui/retrieval/vector/dbs/s3vector.py index 8877d206e6..e0b156931c 100644 --- a/backend/open_webui/retrieval/vector/dbs/s3vector.py +++ b/backend/open_webui/retrieval/vector/dbs/s3vector.py @@ -2,17 +2,18 @@ NOTE: This vector database integration is community-supported and maintained on a best-effort basis. """ -from open_webui.retrieval.vector.utils import process_metadata +import logging +from typing import Any, Dict, List, Optional, Union + +import boto3 +from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, GetResult, SearchResult, + VectorDBBase, + VectorItem, ) -from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION -from typing import List, Optional, Dict, Any, Union -import logging -import boto3 +from open_webui.retrieval.vector.utils import process_metadata log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/vector/dbs/weaviate.py b/backend/open_webui/retrieval/vector/dbs/weaviate.py index 2cf4c135c5..a896d8ed5e 100644 --- a/backend/open_webui/retrieval/vector/dbs/weaviate.py +++ b/backend/open_webui/retrieval/vector/dbs/weaviate.py @@ -2,28 +2,28 @@ NOTE: This vector database integration is community-supported and maintained on a best-effort basis. """ -import weaviate import re import uuid from typing import Any, Dict, List, Optional, Union -from open_webui.retrieval.vector.main import ( - VectorDBBase, - VectorItem, - SearchResult, - GetResult, -) -from open_webui.retrieval.vector.utils import process_metadata +import weaviate from open_webui.config import ( - WEAVIATE_HTTP_HOST, - WEAVIATE_GRPC_HOST, - WEAVIATE_HTTP_PORT, - WEAVIATE_GRPC_PORT, WEAVIATE_API_KEY, - WEAVIATE_HTTP_SECURE, + WEAVIATE_GRPC_HOST, + WEAVIATE_GRPC_PORT, WEAVIATE_GRPC_SECURE, + WEAVIATE_HTTP_HOST, + WEAVIATE_HTTP_PORT, + WEAVIATE_HTTP_SECURE, WEAVIATE_SKIP_INIT_CHECKS, ) +from open_webui.retrieval.vector.main import ( + GetResult, + SearchResult, + VectorDBBase, + VectorItem, +) +from open_webui.retrieval.vector.utils import process_metadata def _convert_uuids_to_strings(obj: Any) -> Any: diff --git a/backend/open_webui/retrieval/vector/factory.py b/backend/open_webui/retrieval/vector/factory.py index 8c0208fd4f..59ebb1c4e4 100644 --- a/backend/open_webui/retrieval/vector/factory.py +++ b/backend/open_webui/retrieval/vector/factory.py @@ -1,10 +1,10 @@ +from open_webui.config import ( + ENABLE_MILVUS_MULTITENANCY_MODE, + ENABLE_QDRANT_MULTITENANCY_MODE, + VECTOR_DB, +) from open_webui.retrieval.vector.main import VectorDBBase from open_webui.retrieval.vector.type import VectorType -from open_webui.config import ( - VECTOR_DB, - ENABLE_QDRANT_MULTITENANCY_MODE, - ENABLE_MILVUS_MULTITENANCY_MODE, -) class Vector: diff --git a/backend/open_webui/retrieval/vector/main.py b/backend/open_webui/retrieval/vector/main.py index f7904baa20..38ea699514 100644 --- a/backend/open_webui/retrieval/vector/main.py +++ b/backend/open_webui/retrieval/vector/main.py @@ -1,7 +1,8 @@ -from pydantic import BaseModel from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Union +from pydantic import BaseModel + class VectorItem(BaseModel): id: str diff --git a/backend/open_webui/retrieval/web/azure.py b/backend/open_webui/retrieval/web/azure.py index 4f74ecc982..5ba57f11c5 100644 --- a/backend/open_webui/retrieval/web/azure.py +++ b/backend/open_webui/retrieval/web/azure.py @@ -1,5 +1,6 @@ import logging from typing import Optional + from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/bing.py b/backend/open_webui/retrieval/web/bing.py index b7cfea89de..b026302b4b 100644 --- a/backend/open_webui/retrieval/web/bing.py +++ b/backend/open_webui/retrieval/web/bing.py @@ -1,10 +1,11 @@ +import argparse import logging import os from pprint import pprint from typing import Optional + import requests from open_webui.retrieval.web.main import SearchResult, get_filtered_results -import argparse log = logging.getLogger(__name__) """ diff --git a/backend/open_webui/retrieval/web/bocha.py b/backend/open_webui/retrieval/web/bocha.py index 3557dcffb9..cb94646310 100644 --- a/backend/open_webui/retrieval/web/bocha.py +++ b/backend/open_webui/retrieval/web/bocha.py @@ -1,8 +1,8 @@ +import json import logging from typing import Optional import requests -import json from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/brave.py b/backend/open_webui/retrieval/web/brave.py index e06d4594fb..eb80094dbb 100644 --- a/backend/open_webui/retrieval/web/brave.py +++ b/backend/open_webui/retrieval/web/brave.py @@ -8,6 +8,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_brave(api_key: str, query: str, count: int, filter_list: list[str | None] = None) -> list[SearchResult]: """Search using Brave's Search API and return the results as a list of SearchResult objects. diff --git a/backend/open_webui/retrieval/web/duckduckgo.py b/backend/open_webui/retrieval/web/duckduckgo.py index 5eb3e73a10..27d56f6934 100644 --- a/backend/open_webui/retrieval/web/duckduckgo.py +++ b/backend/open_webui/retrieval/web/duckduckgo.py @@ -9,6 +9,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_duckduckgo( query: str, count: int, diff --git a/backend/open_webui/retrieval/web/external.py b/backend/open_webui/retrieval/web/external.py index 7f5a2bf2af..4db37cb645 100644 --- a/backend/open_webui/retrieval/web/external.py +++ b/backend/open_webui/retrieval/web/external.py @@ -1,14 +1,11 @@ import logging -from typing import Optional, List +from typing import List, Optional import requests - from fastapi import Request - - +from open_webui.env import FORWARD_SESSION_INFO_HEADER_CHAT_ID from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.utils.headers import include_user_info_headers -from open_webui.env import FORWARD_SESSION_INFO_HEADER_CHAT_ID log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/google_pse.py b/backend/open_webui/retrieval/web/google_pse.py index 5391e2ba8b..ec7b6d5af8 100644 --- a/backend/open_webui/retrieval/web/google_pse.py +++ b/backend/open_webui/retrieval/web/google_pse.py @@ -7,6 +7,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_google_pse( api_key: str, search_engine_id: str, diff --git a/backend/open_webui/retrieval/web/jina_search.py b/backend/open_webui/retrieval/web/jina_search.py index 7d6585ea3d..830beec702 100644 --- a/backend/open_webui/retrieval/web/jina_search.py +++ b/backend/open_webui/retrieval/web/jina_search.py @@ -8,6 +8,7 @@ from yarl import URL log = logging.getLogger(__name__) + def search_jina(api_key: str, query: str, count: int, base_url: str = '') -> list[SearchResult]: """ Search using Jina's Search API and return the results as a list of SearchResult objects. diff --git a/backend/open_webui/retrieval/web/main.py b/backend/open_webui/retrieval/web/main.py index e0dd30c01d..a55c62c8b5 100644 --- a/backend/open_webui/retrieval/web/main.py +++ b/backend/open_webui/retrieval/web/main.py @@ -7,6 +7,7 @@ from open_webui.retrieval.web.utils import resolve_hostname from open_webui.utils.misc import is_string_allowed from pydantic import BaseModel + def get_filtered_results(results, filter_list): if not filter_list: return results @@ -37,6 +38,7 @@ def get_filtered_results(results, filter_list): return filtered_results + class SearchResult(BaseModel): link: str title: str | None diff --git a/backend/open_webui/retrieval/web/mojeek.py b/backend/open_webui/retrieval/web/mojeek.py index 495180a828..da24c2a87c 100644 --- a/backend/open_webui/retrieval/web/mojeek.py +++ b/backend/open_webui/retrieval/web/mojeek.py @@ -7,6 +7,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_mojeek(api_key: str, query: str, count: int, filter_list: list[str | None] = None) -> list[SearchResult]: """Search using Mojeek's Search API and return the results as a list of SearchResult objects. diff --git a/backend/open_webui/retrieval/web/perplexity.py b/backend/open_webui/retrieval/web/perplexity.py index e3d5131dbf..3794057ebc 100644 --- a/backend/open_webui/retrieval/web/perplexity.py +++ b/backend/open_webui/retrieval/web/perplexity.py @@ -1,7 +1,7 @@ import logging -from typing import Optional, Literal -import requests +from typing import Literal, Optional +import requests from open_webui.retrieval.web.main import SearchResult, get_filtered_results MODELS = Literal[ diff --git a/backend/open_webui/retrieval/web/perplexity_search.py b/backend/open_webui/retrieval/web/perplexity_search.py index 9a37087aa4..dede51cadd 100644 --- a/backend/open_webui/retrieval/web/perplexity_search.py +++ b/backend/open_webui/retrieval/web/perplexity_search.py @@ -1,7 +1,7 @@ import logging -from typing import Optional, Literal -import requests +from typing import Literal, Optional +import requests from open_webui.retrieval.web.main import SearchResult, get_filtered_results from open_webui.utils.headers import include_user_info_headers diff --git a/backend/open_webui/retrieval/web/searxng.py b/backend/open_webui/retrieval/web/searxng.py index f0b8b7ab51..67ca8da64a 100644 --- a/backend/open_webui/retrieval/web/searxng.py +++ b/backend/open_webui/retrieval/web/searxng.py @@ -7,6 +7,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_searxng( # noqa: PLR0913 query_url: str, query: str, diff --git a/backend/open_webui/retrieval/web/serper.py b/backend/open_webui/retrieval/web/serper.py index b7fce3b417..91cb78d4f4 100644 --- a/backend/open_webui/retrieval/web/serper.py +++ b/backend/open_webui/retrieval/web/serper.py @@ -8,6 +8,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_serper(api_key: str, query: str, count: int, filter_list: list[str | None] = None) -> list[SearchResult]: """Search using serper.dev's API and return the results as a list of SearchResult objects. diff --git a/backend/open_webui/retrieval/web/serply.py b/backend/open_webui/retrieval/web/serply.py index 8a6b437eee..c4909c14e0 100644 --- a/backend/open_webui/retrieval/web/serply.py +++ b/backend/open_webui/retrieval/web/serply.py @@ -8,6 +8,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_serply( api_key: str, query: str, diff --git a/backend/open_webui/retrieval/web/serpstack.py b/backend/open_webui/retrieval/web/serpstack.py index 93235fbfca..be034e1594 100644 --- a/backend/open_webui/retrieval/web/serpstack.py +++ b/backend/open_webui/retrieval/web/serpstack.py @@ -7,6 +7,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_serpstack( api_key: str, query: str, diff --git a/backend/open_webui/retrieval/web/sougou.py b/backend/open_webui/retrieval/web/sougou.py index b267374d79..3d12e2a57b 100644 --- a/backend/open_webui/retrieval/web/sougou.py +++ b/backend/open_webui/retrieval/web/sougou.py @@ -1,7 +1,6 @@ -import logging import json -from typing import Optional, List - +import logging +from typing import List, Optional from open_webui.retrieval.web.main import SearchResult, get_filtered_results @@ -15,8 +14,8 @@ def search_sougou( count: int, filter_list: Optional[List[str]] = None, ) -> List[SearchResult]: - from tencentcloud.common.common_client import CommonClient from tencentcloud.common import credential + from tencentcloud.common.common_client import CommonClient from tencentcloud.common.exception.tencent_cloud_sdk_exception import ( TencentCloudSDKException, ) diff --git a/backend/open_webui/retrieval/web/tavily.py b/backend/open_webui/retrieval/web/tavily.py index bca2e20990..419bebb05e 100644 --- a/backend/open_webui/retrieval/web/tavily.py +++ b/backend/open_webui/retrieval/web/tavily.py @@ -7,6 +7,7 @@ from open_webui.retrieval.web.main import SearchResult, get_filtered_results log = logging.getLogger(__name__) + def search_tavily( api_key: str, query: str, diff --git a/backend/open_webui/retrieval/web/utils.py b/backend/open_webui/retrieval/web/utils.py index c2ce6bdbd1..4711a5edae 100644 --- a/backend/open_webui/retrieval/web/utils.py +++ b/backend/open_webui/retrieval/web/utils.py @@ -5,8 +5,6 @@ import socket import ssl import urllib.parse import urllib.request - -import requests from datetime import datetime, time, timedelta from typing import ( Any, @@ -14,41 +12,41 @@ from typing import ( Dict, Iterator, List, + Literal, Optional, Sequence, Union, - Literal, ) -from fastapi.concurrency import run_in_threadpool import aiohttp import certifi +import requests import validators +from fastapi.concurrency import run_in_threadpool from langchain_community.document_loaders import PlaywrightURLLoader, WebBaseLoader from langchain_community.document_loaders.base import BaseLoader from langchain_core.documents import Document - -from open_webui.retrieval.loaders.tavily import TavilyLoader -from open_webui.retrieval.loaders.external_web import ExternalWebLoader -from open_webui.retrieval.web.firecrawl import scrape_firecrawl_url -from open_webui.constants import ERROR_MESSAGES from open_webui.config import ( ENABLE_RAG_LOCAL_WEB_FETCH, - PLAYWRIGHT_WS_URL, - PLAYWRIGHT_TIMEOUT, - WEB_LOADER_ENGINE, - WEB_LOADER_TIMEOUT, + EXTERNAL_WEB_LOADER_API_KEY, + EXTERNAL_WEB_LOADER_URL, FIRECRAWL_API_BASE_URL, FIRECRAWL_API_KEY, FIRECRAWL_TIMEOUT, + PLAYWRIGHT_TIMEOUT, + PLAYWRIGHT_WS_URL, TAVILY_API_KEY, TAVILY_EXTRACT_DEPTH, - EXTERNAL_WEB_LOADER_URL, - EXTERNAL_WEB_LOADER_API_KEY, WEB_FETCH_FILTER_LIST, + WEB_LOADER_ENGINE, + WEB_LOADER_TIMEOUT, ) +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import AIOHTTP_CLIENT_ALLOW_REDIRECTS, AIOHTTP_CLIENT_SESSION_SSL +from open_webui.retrieval.loaders.external_web import ExternalWebLoader +from open_webui.retrieval.loaders.tavily import TavilyLoader +from open_webui.retrieval.web.firecrawl import scrape_firecrawl_url from open_webui.utils.misc import is_string_allowed -from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_ALLOW_REDIRECTS log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/yacy.py b/backend/open_webui/retrieval/web/yacy.py index 32ca04f531..969acb7398 100644 --- a/backend/open_webui/retrieval/web/yacy.py +++ b/backend/open_webui/retrieval/web/yacy.py @@ -2,8 +2,8 @@ import logging from typing import Optional import requests -from requests.auth import HTTPDigestAuth from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from requests.auth import HTTPDigestAuth log = logging.getLogger(__name__) diff --git a/backend/open_webui/retrieval/web/yandex.py b/backend/open_webui/retrieval/web/yandex.py index 1fffac8f61..d338db11a3 100644 --- a/backend/open_webui/retrieval/web/yandex.py +++ b/backend/open_webui/retrieval/web/yandex.py @@ -3,19 +3,16 @@ import io import json import logging import os -from typing import Optional, List - -import requests - -from fastapi import Request - -from open_webui.retrieval.web.main import SearchResult, get_filtered_results -from open_webui.utils.headers import include_user_info_headers -from open_webui.env import FORWARD_SESSION_INFO_HEADER_CHAT_ID - +from typing import List, Optional from xml.etree import ElementTree as ET from xml.etree.ElementTree import Element +import requests +from fastapi import Request +from open_webui.env import FORWARD_SESSION_INFO_HEADER_CHAT_ID +from open_webui.retrieval.web.main import SearchResult, get_filtered_results +from open_webui.utils.headers import include_user_info_headers + log = logging.getLogger(__name__) @@ -122,8 +119,8 @@ def search_yandex( if __name__ == '__main__': - from starlette.datastructures import Headers from fastapi import FastAPI + from starlette.datastructures import Headers result = search_yandex( Request( diff --git a/backend/open_webui/retrieval/web/ydc.py b/backend/open_webui/retrieval/web/ydc.py index 21059d8b03..446fa5f16d 100644 --- a/backend/open_webui/retrieval/web/ydc.py +++ b/backend/open_webui/retrieval/web/ydc.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, List +from typing import List, Optional import requests from open_webui.retrieval.web.main import SearchResult, get_filtered_results diff --git a/backend/open_webui/routers/analytics.py b/backend/open_webui/routers/analytics.py index fd045f79e7..fcef30342c 100644 --- a/backend/open_webui/routers/analytics.py +++ b/backend/open_webui/routers/analytics.py @@ -1,17 +1,17 @@ -from typing import Optional -from datetime import datetime, timedelta -from collections import defaultdict import logging -from fastapi import APIRouter, Depends, Query -from pydantic import BaseModel +from collections import defaultdict +from datetime import datetime, timedelta +from typing import Optional -from open_webui.models.chat_messages import ChatMessages, ChatMessageModel +from fastapi import APIRouter, Depends, Query +from open_webui.internal.db import get_async_session +from open_webui.models.chat_messages import ChatMessageModel, ChatMessages from open_webui.models.chats import Chats +from open_webui.models.feedbacks import Feedbacks from open_webui.models.groups import Groups from open_webui.models.users import Users -from open_webui.models.feedbacks import Feedbacks from open_webui.utils.auth import get_admin_user -from open_webui.internal.db import get_async_session +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 6366f0e72b..bba91c2d5a 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -1,24 +1,22 @@ import asyncio -import io +import base64 import hashlib +import html +import io import json import logging +import mimetypes import os import uuid -import html -import base64 -from pydub import AudioSegment -from pydub.silence import split_on_silence from concurrent.futures import ThreadPoolExecutor +from fnmatch import fnmatch from typing import Optional -from fnmatch import fnmatch -import aiohttp import aiofiles +import aiohttp import requests -import mimetypes - from fastapi import ( + APIRouter, Depends, FastAPI, File, @@ -27,38 +25,36 @@ from fastapi import ( Request, UploadFile, status, - APIRouter, ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import FileResponse -from pydantic import BaseModel - - -from open_webui.utils.misc import strict_match_mime_type -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission -from open_webui.utils.headers import include_user_info_headers from open_webui.config import ( - WHISPER_MODEL_AUTO_UPDATE, - WHISPER_COMPUTE_TYPE, - WHISPER_MODEL_DIR, - WHISPER_VAD_FILTER, CACHE_DIR, - WHISPER_LANGUAGE, - WHISPER_MULTILINGUAL, ELEVENLABS_API_BASE_URL, + WHISPER_COMPUTE_TYPE, + WHISPER_LANGUAGE, + WHISPER_MODEL_AUTO_UPDATE, + WHISPER_MODEL_DIR, + WHISPER_MULTILINGUAL, + WHISPER_VAD_FILTER, ) - from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( - ENV, AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, BYPASS_PYDUB_PREPROCESSING, DEVICE_TYPE, ENABLE_FORWARD_USER_INFO_HEADERS, + ENV, ) +from open_webui.utils.access_control import has_permission +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.headers import include_user_info_headers +from open_webui.utils.misc import strict_match_mime_type +from pydantic import BaseModel +from pydub import AudioSegment +from pydub.silence import split_on_silence router = APIRouter() @@ -364,8 +360,8 @@ async def update_audio_config(request: Request, form_data: AudioConfigUpdateForm def load_speech_pipeline(request): - from transformers import pipeline from datasets import load_dataset + from transformers import pipeline if request.app.state.speech_synthesiser is None: request.app.state.speech_synthesiser = pipeline('text-to-speech', 'microsoft/speecht5_tts') @@ -592,8 +588,8 @@ async def speech(request: Request, user=Depends(get_verified_user)): log.exception(e) raise HTTPException(status_code=400, detail='Invalid JSON payload') - import torch import soundfile as sf + import torch load_speech_pipeline(request) diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 5c2be8f22d..6745ce4847 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -1,92 +1,86 @@ +from __future__ import annotations + import asyncio -import re -import uuid -import time import datetime import logging -from aiohttp import ClientSession +import re +import time import urllib +import uuid +from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS +from typing import List, Optional - +from aiohttp import ClientSession +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.responses import JSONResponse, RedirectResponse, Response +from ldap3 import NONE, Connection, Server, Tls +from ldap3.utils.conv import escape_filter_chars +from open_webui.config import ( + ENABLE_LDAP, + ENABLE_OAUTH_SIGNUP, + ENABLE_PASSWORD_AUTH, + OAUTH_MERGE_ACCOUNTS_BY_EMAIL, + OAUTH_PROVIDERS, + OPENID_END_SESSION_ENDPOINT, + OPENID_PROVIDER_URL, +) +from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES +from open_webui.env import ( + AIOHTTP_CLIENT_SESSION_SSL, + ENABLE_INITIAL_ADMIN_SIGNUP, + ENABLE_OAUTH_TOKEN_EXCHANGE, + WEBUI_AUTH, + WEBUI_AUTH_COOKIE_SAME_SITE, + WEBUI_AUTH_COOKIE_SECURE, + WEBUI_AUTH_SIGNOUT_REDIRECT_URL, + WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_AUTH_TRUSTED_GROUPS_HEADER, + WEBUI_AUTH_TRUSTED_NAME_HEADER, + WEBUI_AUTH_TRUSTED_ROLE_HEADER, +) +from open_webui.internal.db import get_async_session from open_webui.models.auths import ( AddUserForm, ApiKey, Auths, - Token, LdapForm, SigninForm, SigninResponse, SignupForm, + Token, UpdatePasswordForm, ) -from open_webui.models.users import ( - UserModel, - UserProfileImageResponse, - Users, - UpdateProfileForm, - UserStatus, -) from open_webui.models.groups import Groups from open_webui.models.oauth_sessions import OAuthSessions - -from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES -from open_webui.env import ( - WEBUI_AUTH, - WEBUI_AUTH_TRUSTED_EMAIL_HEADER, - WEBUI_AUTH_TRUSTED_NAME_HEADER, - WEBUI_AUTH_TRUSTED_GROUPS_HEADER, - WEBUI_AUTH_TRUSTED_ROLE_HEADER, - WEBUI_AUTH_COOKIE_SAME_SITE, - WEBUI_AUTH_COOKIE_SECURE, - WEBUI_AUTH_SIGNOUT_REDIRECT_URL, - ENABLE_INITIAL_ADMIN_SIGNUP, - ENABLE_OAUTH_TOKEN_EXCHANGE, - AIOHTTP_CLIENT_SESSION_SSL, +from open_webui.models.users import ( + UpdateProfileForm, + UserModel, + UserProfileImageResponse, + Users, + UserStatus, ) -from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi.responses import RedirectResponse, Response, JSONResponse -from open_webui.config import ( - OPENID_PROVIDER_URL, - OPENID_END_SESSION_ENDPOINT, - ENABLE_OAUTH_SIGNUP, - ENABLE_LDAP, - ENABLE_PASSWORD_AUTH, - OAUTH_PROVIDERS, - OAUTH_MERGE_ACCOUNTS_BY_EMAIL, -) -from open_webui.utils.oauth import auth_manager_config -from pydantic import BaseModel - -from open_webui.utils.misc import parse_duration, validate_email_format +from open_webui.utils.access_control import get_permissions, has_permission from open_webui.utils.auth import ( - validate_password, - verify_password, - decode_token, - invalidate_token, create_api_key, create_token, + decode_token, get_admin_user, - get_verified_user, get_current_user, - get_password_hash, get_http_authorization_cred, + get_password_hash, + get_verified_user, + invalidate_token, + validate_password, + verify_password, ) -from open_webui.internal.db import get_async_session -from sqlalchemy.ext.asyncio import AsyncSession -from open_webui.utils.webhook import post_webhook -from open_webui.utils.access_control import get_permissions, has_permission from open_webui.utils.groups import apply_default_group_assignment - -from open_webui.utils.redis import get_redis_client +from open_webui.utils.misc import parse_duration, validate_email_format +from open_webui.utils.oauth import auth_manager_config from open_webui.utils.rate_limit import RateLimiter - - -from typing import Optional, List - -from ssl import CERT_NONE, CERT_REQUIRED, PROTOCOL_TLS - -from ldap3 import Server, Connection, NONE, Tls -from ldap3.utils.conv import escape_filter_chars +from open_webui.utils.redis import get_redis_client +from open_webui.utils.webhook import post_webhook +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession router = APIRouter() @@ -155,14 +149,14 @@ async def create_session_response( class SessionUserResponse(Token, UserProfileImageResponse): - expires_at: Optional[int] = None - permissions: Optional[dict] = None + expires_at: int | None = None + permissions: dict | None = None class SessionUserInfoResponse(SessionUserResponse, UserStatus): - bio: Optional[str] = None - gender: Optional[str] = None - date_of_birth: Optional[datetime.date] = None + bio: str | None = None + gender: str | None = None + date_of_birth: datetime.date | None = None @router.get('/', response_model=SessionUserInfoResponse) @@ -1031,7 +1025,7 @@ async def get_admin_config(request: Request, user=Depends(get_admin_user)): class AdminConfig(BaseModel): SHOW_ADMIN_DETAILS: bool - ADMIN_EMAIL: Optional[str] = None + ADMIN_EMAIL: str | None = None WEBUI_URL: str ENABLE_SIGNUP: bool ENABLE_API_KEYS: bool @@ -1043,9 +1037,9 @@ class AdminConfig(BaseModel): ENABLE_COMMUNITY_SHARING: bool ENABLE_MESSAGE_RATING: bool ENABLE_FOLDERS: bool - FOLDER_MAX_FILE_COUNT: Optional[int | str] = None - AUTOMATION_MAX_COUNT: Optional[int | str] = None - AUTOMATION_MIN_INTERVAL: Optional[int | str] = None + FOLDER_MAX_FILE_COUNT: int | str | None = None + AUTOMATION_MAX_COUNT: int | str | None = None + AUTOMATION_MIN_INTERVAL: int | str | None = None ENABLE_AUTOMATIONS: bool ENABLE_CHANNELS: bool ENABLE_CALENDAR: bool @@ -1053,9 +1047,9 @@ class AdminConfig(BaseModel): ENABLE_NOTES: bool ENABLE_USER_WEBHOOKS: bool ENABLE_USER_STATUS: bool - PENDING_USER_OVERLAY_TITLE: Optional[str] = None - PENDING_USER_OVERLAY_CONTENT: Optional[str] = None - RESPONSE_WATERMARK: Optional[str] = None + PENDING_USER_OVERLAY_TITLE: str | None = None + PENDING_USER_OVERLAY_CONTENT: str | None = None + RESPONSE_WATERMARK: str | None = None @router.post('/admin/config') @@ -1140,7 +1134,7 @@ async def update_admin_config(request: Request, form_data: AdminConfig, user=Dep class LdapServerConfig(BaseModel): label: str host: str - port: Optional[int] = None + port: int | None = None attribute_for_mail: str = 'mail' attribute_for_username: str = 'uid' app_dn: str @@ -1148,9 +1142,9 @@ class LdapServerConfig(BaseModel): search_base: str search_filters: str = '' use_tls: bool = True - certificate_path: Optional[str] = None + certificate_path: str | None = None validate_cert: bool = True - ciphers: Optional[str] = 'ALL' + ciphers: str | None = 'ALL' @router.get('/admin/config/ldap/server', response_model=LdapServerConfig) @@ -1223,7 +1217,7 @@ async def get_ldap_config(request: Request, user=Depends(get_admin_user)): class LdapConfigForm(BaseModel): - enable_ldap: Optional[bool] = None + enable_ldap: bool | None = None @router.post('/admin/config/ldap') diff --git a/backend/open_webui/routers/automations.py b/backend/open_webui/routers/automations.py index 4ff66feb97..fced12978a 100644 --- a/backend/open_webui/routers/automations.py +++ b/backend/open_webui/routers/automations.py @@ -1,30 +1,29 @@ import asyncio import logging - from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, Request, status -from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.constants import ERROR_MESSAGES +from open_webui.internal.db import get_async_session from open_webui.models.automations import ( - Automations, - AutomationRuns, AutomationForm, + AutomationListResponse, AutomationModel, AutomationResponse, AutomationRunModel, - AutomationListResponse, + AutomationRuns, + Automations, ) -from open_webui.utils.automations import ( - validate_rrule, - next_run_ns, - next_n_runs_ns, - execute_automation, - rrule_interval_seconds, -) -from open_webui.utils.auth import get_verified_user, get_admin_user from open_webui.utils.access_control import has_permission -from open_webui.internal.db import get_async_session -from open_webui.constants import ERROR_MESSAGES +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.automations import ( + execute_automation, + next_n_runs_ns, + next_run_ns, + rrule_interval_seconds, + validate_rrule, +) +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/calendar.py b/backend/open_webui/routers/calendar.py index bdc06e819b..6c567aa3a8 100644 --- a/backend/open_webui/routers/calendar.py +++ b/backend/open_webui/routers/calendar.py @@ -3,28 +3,27 @@ import time from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request, status - +from open_webui.constants import ERROR_MESSAGES +from open_webui.models.access_grants import AccessGrants from open_webui.models.calendar import ( - Calendars, - CalendarEvents, CalendarEventAttendees, - CalendarForm, - CalendarUpdateForm, CalendarEventForm, - CalendarEventUpdateForm, - CalendarModel, - CalendarEventModel, - CalendarEventUserResponse, CalendarEventListResponse, + CalendarEventModel, + CalendarEvents, + CalendarEventUpdateForm, + CalendarEventUserResponse, + CalendarForm, + CalendarModel, + Calendars, + CalendarUpdateForm, RSVPForm, ) -from open_webui.models.access_grants import AccessGrants from open_webui.models.groups import Groups from open_webui.models.users import UserModel +from open_webui.utils.access_control import filter_allowed_access_grants, has_permission from open_webui.utils.auth import get_verified_user -from open_webui.utils.access_control import has_permission, filter_allowed_access_grants from open_webui.utils.calendar import expand_recurring_event -from open_webui.constants import ERROR_MESSAGES log = logging.getLogger(__name__) @@ -188,7 +187,7 @@ async def get_events( cal_id_list is None or SCHEDULED_TASKS_CALENDAR_ID in cal_id_list ): try: - from open_webui.models.automations import Automations, AutomationRuns + from open_webui.models.automations import AutomationRuns, Automations # Future runs: expand RRULEs for active automations only active_automations = await Automations.get_active_by_user(user.id) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 70eb799ea6..68fc71fe5a 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -1,69 +1,58 @@ -import json -import logging import base64 import io +import json +import logging from typing import Optional - -from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks -from fastapi.responses import Response, StreamingResponse, FileResponse -from pydantic import BaseModel -from pydantic import field_validator - -from open_webui.socket.main import ( - emit_to_users, - enter_room_for_users, - sio, - get_user_ids_from_room, +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status +from fastapi.responses import FileResponse, Response, StreamingResponse +from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import STATIC_DIR +from open_webui.internal.db import get_async_session +from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant, has_public_write_access_grant +from open_webui.models.channels import ( + ChannelForm, + ChannelModel, + ChannelResponse, + Channels, + ChannelWebhookForm, + ChannelWebhookModel, + CreateChannelForm, +) +from open_webui.models.groups import Groups +from open_webui.models.messages import ( + MessageForm, + MessageModel, + MessageResponse, + Messages, + MessageWithReactionsResponse, ) from open_webui.models.users import ( UserIdNameResponse, UserIdNameStatusResponse, UserListResponse, - UserModelResponse, - Users, UserModel, + UserModelResponse, UserNameResponse, + Users, ) - -from open_webui.models.groups import Groups -from open_webui.models.channels import ( - Channels, - ChannelModel, - ChannelForm, - ChannelResponse, - CreateChannelForm, - ChannelWebhookModel, - ChannelWebhookForm, +from open_webui.socket.main import ( + emit_to_users, + enter_room_for_users, + get_user_ids_from_room, + sio, ) -from open_webui.models.access_grants import AccessGrants, has_public_read_access_grant, has_public_write_access_grant -from open_webui.models.messages import ( - Messages, - MessageModel, - MessageResponse, - MessageWithReactionsResponse, - MessageForm, -) - - +from open_webui.utils.access_control import filter_allowed_access_grants, has_permission +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.channels import extract_mentions, replace_mentions from open_webui.utils.files import get_image_base64_from_file_id - -from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT -from open_webui.constants import ERROR_MESSAGES -from open_webui.env import STATIC_DIR - - from open_webui.utils.models import ( get_all_models, get_filtered_models, ) - - -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission, filter_allowed_access_grants from open_webui.utils.webhook import post_webhook -from open_webui.utils.channels import extract_mentions, replace_mentions -from open_webui.internal.db import get_async_session +from pydantic import BaseModel, field_validator from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -980,9 +969,9 @@ async def model_response_handler(request, channel, message, user, db=None): # Resolve model config (same helpers automations use) from open_webui.utils.automations import ( - _resolve_model_tool_ids, _resolve_model_features, _resolve_model_filter_ids, + _resolve_model_tool_ids, ) tool_ids = _resolve_model_tool_ids(request.app, model_id) diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 9c4609477c..c2f3da563d 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -1,43 +1,41 @@ +from __future__ import annotations + +import asyncio import json import logging from typing import Optional from uuid import uuid4 -from sqlalchemy.ext.asyncio import AsyncSession -import asyncio + +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import StreamingResponse - - -from open_webui.utils.misc import get_message_list -from open_webui.utils.middleware import serialize_output -from open_webui.socket.main import get_event_emitter -from open_webui.models.chats import ( - ChatForm, - ChatImportForm, - ChatUsageStatsListResponse, - ChatsImportForm, - ChatResponse, - Chats, - ChatTitleIdResponse, - ChatStatsExport, - AggregateChatStats, - ChatBody, - ChatHistoryStats, - MessageStats, -) -from open_webui.models.shared_chats import SharedChats, SharedChatResponse -from open_webui.models.access_grants import AccessGrants -from open_webui.models.tags import TagModel, Tags -from open_webui.models.folders import Folders -from open_webui.internal.db import get_async_session - from open_webui.config import ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status -from pydantic import BaseModel - - +from open_webui.internal.db import get_async_session +from open_webui.models.access_grants import AccessGrants +from open_webui.models.chats import ( + AggregateChatStats, + ChatBody, + ChatForm, + ChatHistoryStats, + ChatImportForm, + ChatResponse, + Chats, + ChatsImportForm, + ChatStatsExport, + ChatTitleIdResponse, + ChatUsageStatsListResponse, + MessageStats, +) +from open_webui.models.folders import Folders +from open_webui.models.shared_chats import SharedChatResponse, SharedChats +from open_webui.models.tags import TagModel, Tags +from open_webui.socket.main import get_event_emitter +from open_webui.utils.access_control import filter_allowed_access_grants, has_permission from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission, filter_allowed_access_grants +from open_webui.utils.middleware import serialize_output +from open_webui.utils.misc import get_message_list +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -54,9 +52,9 @@ router = APIRouter() @router.get('/list', response_model=list[ChatTitleIdResponse]) async def get_session_user_chat_list( user=Depends(get_verified_user), - page: Optional[int] = None, - include_pinned: Optional[bool] = False, - include_folders: Optional[bool] = False, + page: int | None = None, + include_pinned: bool | None = False, + include_folders: bool | None = False, db: AsyncSession = Depends(get_async_session), ): try: @@ -92,8 +90,8 @@ async def get_session_user_chat_list( @router.get('/stats/usage', response_model=ChatUsageStatsListResponse) async def get_session_user_chat_usage_stats( - items_per_page: Optional[int] = 50, - page: Optional[int] = 1, + items_per_page: int | None = 50, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -210,7 +208,7 @@ class ChatStatsExportList(BaseModel): page: int -def _process_chat_for_export(chat) -> Optional[ChatStatsExport]: +def _process_chat_for_export(chat) -> ChatStatsExport | None: try: def get_message_content_length(message): @@ -395,8 +393,8 @@ async def generate_chat_stats_jsonl_generator(user_id, filter): @router.get('/stats/export', response_model=ChatStatsExportList) async def export_chat_stats( request: Request, - updated_at: Optional[int] = None, - page: Optional[int] = 1, + updated_at: int | None = None, + page: int | None = 1, stream: bool = False, user=Depends(get_verified_user), ): @@ -438,7 +436,7 @@ async def export_chat_stats( ############################ -@router.get('/stats/export/{chat_id}', response_model=Optional[ChatStatsExport]) +@router.get('/stats/export/{chat_id}', response_model=ChatStatsExport | None) async def export_single_chat_stats( request: Request, chat_id: str, @@ -516,10 +514,10 @@ async def delete_all_user_chats( @router.get('/list/user/{user_id}', response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_user_id( user_id: str, - page: Optional[int] = None, - query: Optional[str] = None, - order_by: Optional[str] = None, - direction: Optional[str] = None, + page: int | None = None, + query: str | None = None, + order_by: str | None = None, + direction: str | None = None, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session), ): @@ -553,7 +551,7 @@ async def get_user_chat_list_by_user_id( ############################ -@router.post('/new', response_model=Optional[ChatResponse]) +@router.post('/new', response_model=ChatResponse | None) async def create_new_chat( form_data: ChatForm, user=Depends(get_verified_user), @@ -594,7 +592,7 @@ async def import_chats( @router.get('/search', response_model=list[ChatTitleIdResponse]) async def search_user_chats( text: str, - page: Optional[int] = None, + page: int | None = None, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -644,7 +642,7 @@ async def get_chats_by_folder_id( @router.get('/folder/{folder_id}/list') async def get_chat_list_by_folder_id( folder_id: str, - page: Optional[int] = 1, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -766,10 +764,10 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db: AsyncSessio @router.get('/archived', response_model=list[ChatTitleIdResponse]) async def get_archived_session_user_chat_list( - page: Optional[int] = None, - query: Optional[str] = None, - order_by: Optional[str] = None, - direction: Optional[str] = None, + page: int | None = None, + query: str | None = None, + order_by: str | None = None, + direction: str | None = None, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -823,10 +821,10 @@ async def unarchive_all_chats(user=Depends(get_verified_user), db: AsyncSession @router.get('/shared', response_model=list[SharedChatResponse]) async def get_shared_session_user_chat_list( - page: Optional[int] = None, - query: Optional[str] = None, - order_by: Optional[str] = None, - direction: Optional[str] = None, + page: int | None = None, + query: str | None = None, + order_by: str | None = None, + direction: str | None = None, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -858,7 +856,7 @@ async def get_shared_session_user_chat_list( ############################ -@router.get('/share/{share_id}', response_model=Optional[ChatResponse]) +@router.get('/share/{share_id}', response_model=ChatResponse | None) async def get_shared_chat_by_id( share_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -904,8 +902,8 @@ class TagForm(BaseModel): class TagFilterForm(TagForm): - skip: Optional[int] = 0 - limit: Optional[int] = 50 + skip: int | None = 0 + limit: int | None = 50 @router.post('/tags', response_model=list[ChatTitleIdResponse]) @@ -928,7 +926,7 @@ async def get_user_chat_list_by_tag_name( ############################ -@router.get('/{id}', response_model=Optional[ChatResponse]) +@router.get('/{id}', response_model=ChatResponse | None) async def get_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db) @@ -958,7 +956,7 @@ async def get_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSess ############################ -@router.post('/{id}', response_model=Optional[ChatResponse]) +@router.post('/{id}', response_model=ChatResponse | None) async def update_chat_by_id( id: str, form_data: ChatForm, @@ -992,7 +990,7 @@ class MessageForm(BaseModel): content: str -@router.post('/{id}/messages/{message_id}', response_model=Optional[ChatResponse]) +@router.post('/{id}/messages/{message_id}', response_model=ChatResponse | None) async def update_chat_message_by_id( id: str, message_id: str, @@ -1054,7 +1052,7 @@ class EventForm(BaseModel): data: dict -@router.post('/{id}/messages/{message_id}/event', response_model=Optional[bool]) +@router.post('/{id}/messages/{message_id}/event', response_model=bool | None) async def send_chat_message_event_by_id( id: str, message_id: str, @@ -1142,7 +1140,7 @@ async def delete_chat_by_id( ############################ -@router.get('/{id}/pinned', response_model=Optional[bool]) +@router.get('/{id}/pinned', response_model=bool | None) async def get_pinned_status_by_id( id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -1158,7 +1156,7 @@ async def get_pinned_status_by_id( ############################ -@router.post('/{id}/pin', response_model=Optional[ChatResponse]) +@router.post('/{id}/pin', response_model=ChatResponse | None) async def pin_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: @@ -1174,10 +1172,10 @@ async def pin_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSess class CloneForm(BaseModel): - title: Optional[str] = None + title: str | None = None -@router.post('/{id}/clone', response_model=Optional[ChatResponse]) +@router.post('/{id}/clone', response_model=ChatResponse | None) async def clone_chat_by_id( form_data: CloneForm, id: str, @@ -1225,7 +1223,7 @@ async def clone_chat_by_id( ############################ -@router.post('/{id}/clone/shared', response_model=Optional[ChatResponse]) +@router.post('/{id}/clone/shared', response_model=ChatResponse | None) async def clone_shared_chat_by_id( id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -1294,7 +1292,7 @@ async def clone_shared_chat_by_id( ############################ -@router.post('/{id}/archive', response_model=Optional[ChatResponse]) +@router.post('/{id}/archive', response_model=ChatResponse | None) async def archive_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: @@ -1318,7 +1316,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_verified_user), db: Async ############################ -@router.post('/{id}/share', response_model=Optional[ChatResponse]) +@router.post('/{id}/share', response_model=ChatResponse | None) async def share_chat_by_id( request: Request, id: str, @@ -1372,7 +1370,7 @@ async def share_chat_by_id( ############################ -@router.delete('/{id}/share', response_model=Optional[bool]) +@router.delete('/{id}/share', response_model=bool | None) async def delete_shared_chat_by_id( id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -1404,7 +1402,7 @@ class ChatAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post('/shared/{id}/access/update', response_model=Optional[ChatResponse]) +@router.post('/shared/{id}/access/update', response_model=ChatResponse | None) async def update_shared_chat_access_by_id( request: Request, id: str, @@ -1474,10 +1472,10 @@ async def get_shared_chat_access_by_id( class ChatFolderIdForm(BaseModel): - folder_id: Optional[str] = None + folder_id: str | None = None -@router.post('/{id}/folder', response_model=Optional[ChatResponse]) +@router.post('/{id}/folder', response_model=ChatResponse | None) async def update_chat_folder_id_by_id( id: str, form_data: ChatFolderIdForm, @@ -1571,7 +1569,7 @@ async def delete_tag_by_id_and_tag_name( ############################ -@router.delete('/{id}/tags/all', response_model=Optional[bool]) +@router.delete('/{id}/tags/all', response_model=bool | None) async def delete_all_tags_by_id( id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): diff --git a/backend/open_webui/routers/configs.py b/backend/open_webui/routers/configs.py index 1d55dba75e..c40131f41f 100644 --- a/backend/open_webui/routers/configs.py +++ b/backend/open_webui/routers/configs.py @@ -1,37 +1,34 @@ -import logging -import copy -from fastapi import APIRouter, Depends, Request, HTTPException -from pydantic import BaseModel, ConfigDict -import aiohttp +from __future__ import annotations +import copy +import logging from typing import Optional +import aiohttp +from fastapi import APIRouter, Depends, HTTPException, Request +from mcp.shared.auth import OAuthMetadata +from open_webui.config import BannerModel, async_save_config, get_config, save_config from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT +from open_webui.models.oauth_sessions import OAuthSessions from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.headers import get_custom_headers -from open_webui.config import get_config, save_config, async_save_config -from open_webui.config import BannerModel - -from open_webui.utils.tools import ( - get_tool_server_data, - get_tool_server_url, - set_tool_servers, - set_terminal_servers, -) from open_webui.utils.mcp.client import MCPClient -from open_webui.models.oauth_sessions import OAuthSessions - - from open_webui.utils.oauth import ( + OAuthClientInformationFull, + decrypt_data, + encrypt_data, get_discovery_urls, get_oauth_client_info_with_dynamic_client_registration, get_oauth_client_info_with_static_credentials, - encrypt_data, - decrypt_data, resolve_oauth_client_info, - OAuthClientInformationFull, ) -from mcp.shared.auth import OAuthMetadata +from open_webui.utils.tools import ( + get_tool_server_data, + get_tool_server_url, + set_terminal_servers, + set_tool_servers, +) +from pydantic import BaseModel, ConfigDict router = APIRouter() @@ -102,16 +99,16 @@ async def set_connections_config( class OAuthClientRegistrationForm(BaseModel): url: str client_id: str - client_name: Optional[str] = None - client_secret: Optional[str] = None - oauth_server_url: Optional[str] = None + client_name: str | None = None + client_secret: str | None = None + oauth_server_url: str | None = None @router.post('/oauth/clients/register') async def register_oauth_client( request: Request, form_data: OAuthClientRegistrationForm, - type: Optional[str] = None, + type: str | None = None, user=Depends(get_admin_user), ): try: @@ -154,12 +151,12 @@ async def register_oauth_client( class ToolServerConnection(BaseModel): url: str path: str - type: Optional[str] = 'openapi' # openapi, mcp - auth_type: Optional[str] - headers: Optional[dict | str] = None - key: Optional[str] - config: Optional[dict] - info: Optional[dict] = None + type: str | None = 'openapi' # openapi, mcp + auth_type: str | None + headers: dict | str | None = None + key: str | None + config: dict | None + info: dict | None = None model_config = ConfigDict(extra='allow') @@ -225,23 +222,23 @@ async def set_tool_servers_config( class TerminalServerConnection(BaseModel): - id: Optional[str] = '' - name: Optional[str] = '' + id: str | None = '' + name: str | None = '' - enabled: Optional[bool] = True + enabled: bool | None = True url: str - path: Optional[str] = '/openapi.json' + path: str | None = '/openapi.json' - key: Optional[str] = '' - auth_type: Optional[str] = 'bearer' + key: str | None = '' + auth_type: str | None = 'bearer' - config: Optional[dict] = None + config: dict | None = None # Orchestrator policy fields - server_type: Optional[str] = None # "orchestrator", "terminal" - policy_id: Optional[str] = None - policy: Optional[dict] = None # cached policy data + server_type: str | None = None # "orchestrator", "terminal" + policy_id: str | None = None + policy: dict | None = None # cached policy data model_config = ConfigDict(extra='allow') @@ -325,8 +322,8 @@ async def verify_terminal_server_connection( class TerminalServerPolicyForm(BaseModel): url: str - key: Optional[str] = '' - auth_type: Optional[str] = 'bearer' + key: str | None = '' + auth_type: str | None = 'bearer' policy_id: str policy_data: dict @@ -504,19 +501,19 @@ async def verify_tool_servers_config(request: Request, form_data: ToolServerConn class CodeInterpreterConfigForm(BaseModel): ENABLE_CODE_EXECUTION: bool CODE_EXECUTION_ENGINE: str - CODE_EXECUTION_JUPYTER_URL: Optional[str] - CODE_EXECUTION_JUPYTER_AUTH: Optional[str] - CODE_EXECUTION_JUPYTER_AUTH_TOKEN: Optional[str] - CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: Optional[str] - CODE_EXECUTION_JUPYTER_TIMEOUT: Optional[int] + CODE_EXECUTION_JUPYTER_URL: str | None + CODE_EXECUTION_JUPYTER_AUTH: str | None + CODE_EXECUTION_JUPYTER_AUTH_TOKEN: str | None + CODE_EXECUTION_JUPYTER_AUTH_PASSWORD: str | None + CODE_EXECUTION_JUPYTER_TIMEOUT: int | None ENABLE_CODE_INTERPRETER: bool CODE_INTERPRETER_ENGINE: str - CODE_INTERPRETER_PROMPT_TEMPLATE: Optional[str] - CODE_INTERPRETER_JUPYTER_URL: Optional[str] - CODE_INTERPRETER_JUPYTER_AUTH: Optional[str] - CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: Optional[str] - CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: Optional[str] - CODE_INTERPRETER_JUPYTER_TIMEOUT: Optional[int] + CODE_INTERPRETER_PROMPT_TEMPLATE: str | None + CODE_INTERPRETER_JUPYTER_URL: str | None + CODE_INTERPRETER_JUPYTER_AUTH: str | None + CODE_INTERPRETER_JUPYTER_AUTH_TOKEN: str | None + CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD: str | None + CODE_INTERPRETER_JUPYTER_TIMEOUT: int | None @router.get('/code_execution', response_model=CodeInterpreterConfigForm) @@ -588,11 +585,11 @@ async def set_code_execution_config( # SetDefaultModels ############################ class ModelsConfigForm(BaseModel): - DEFAULT_MODELS: Optional[str] - DEFAULT_PINNED_MODELS: Optional[str] - MODEL_ORDER_LIST: Optional[list[str]] - DEFAULT_MODEL_METADATA: Optional[dict] = None - DEFAULT_MODEL_PARAMS: Optional[dict] = None + DEFAULT_MODELS: str | None + DEFAULT_PINNED_MODELS: str | None + MODEL_ORDER_LIST: list[str | None] + DEFAULT_MODEL_METADATA: dict | None = None + DEFAULT_MODEL_PARAMS: dict | None = None @router.get('/models/defaults') diff --git a/backend/open_webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py index 072c7fa732..cce9d591ee 100644 --- a/backend/open_webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -1,26 +1,25 @@ -from typing import Optional import logging -from fastapi import APIRouter, Depends, HTTPException, status, Request -from fastapi.concurrency import run_in_threadpool -from pydantic import BaseModel +from typing import Optional -from open_webui.models.users import Users, UserModel +from fastapi import APIRouter, Depends, HTTPException, Request, status +from fastapi.concurrency import run_in_threadpool +from open_webui.constants import ERROR_MESSAGES +from open_webui.internal.db import get_async_session from open_webui.models.feedbacks import ( + FeedbackForm, FeedbackIdResponse, + FeedbackListResponse, FeedbackModel, FeedbackResponse, - FeedbackForm, + Feedbacks, FeedbackUserResponse, - FeedbackListResponse, LeaderboardFeedbackData, ModelHistoryEntry, ModelHistoryResponse, - Feedbacks, ) - -from open_webui.constants import ERROR_MESSAGES +from open_webui.models.users import UserModel, Users from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.internal.db import get_async_session +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 86beeec188..74ba2365a6 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -1,34 +1,31 @@ +import asyncio +import json import logging import os import uuid -import json from pathlib import Path from typing import Optional from urllib.parse import quote -import asyncio from fastapi import ( - BackgroundTasks, APIRouter, + BackgroundTasks, Depends, File, Form, HTTPException, + Query, Request, UploadFile, status, - Query, ) - from fastapi.responses import FileResponse, StreamingResponse -from sqlalchemy.ext.asyncio import AsyncSession -from open_webui.internal.db import get_async_session, get_async_db_context - +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STORAGE_LOCAL_CACHE, STORAGE_PROVIDER, UPLOAD_DIR from open_webui.constants import ERROR_MESSAGES -from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT - +from open_webui.internal.db import get_async_db_context, get_async_session +from open_webui.models.access_grants import AccessGrants from open_webui.models.channels import Channels -from open_webui.models.users import Users +from open_webui.models.chats import Chats from open_webui.models.files import ( FileForm, FileListResponse, @@ -36,22 +33,17 @@ from open_webui.models.files import ( FileModelResponse, Files, ) -from open_webui.models.chats import Chats -from open_webui.models.knowledge import Knowledges from open_webui.models.groups import Groups -from open_webui.models.access_grants import AccessGrants - - -from open_webui.routers.retrieval import ProcessFileForm, process_file +from open_webui.models.knowledge import Knowledges +from open_webui.models.users import Users +from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT from open_webui.routers.audio import transcribe - +from open_webui.routers.retrieval import ProcessFileForm, process_file from open_webui.storage.provider import Storage - - -from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, STORAGE_LOCAL_CACHE, STORAGE_PROVIDER, UPLOAD_DIR from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.misc import strict_match_mime_type from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index 7dda918821..8d77de4894 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -1,36 +1,29 @@ import logging +import mimetypes import os import shutil import uuid from pathlib import Path from typing import Optional -from pydantic import BaseModel -import mimetypes - - -from open_webui.models.folders import ( - FolderForm, - FolderUpdateForm, - FolderModel, - FolderNameIdResponse, - Folders, -) -from open_webui.models.chats import Chats - +from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status +from fastapi.responses import FileResponse, StreamingResponse from open_webui.config import UPLOAD_DIR from open_webui.constants import ERROR_MESSAGES from open_webui.internal.db import get_async_session -from sqlalchemy.ext.asyncio import AsyncSession - - -from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status, Request -from fastapi.responses import FileResponse, StreamingResponse - - -from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.models.chats import Chats +from open_webui.models.folders import ( + FolderForm, + FolderModel, + FolderNameIdResponse, + Folders, + FolderUpdateForm, +) from open_webui.utils.access_control import has_permission from open_webui.utils.access_control.files import get_accessible_folder_files +from open_webui.utils.auth import get_admin_user, get_verified_user +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index f40cd1ab82..58ec93657e 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -1,32 +1,33 @@ -import os -import re +from __future__ import annotations import logging -import aiohttp +import os +import re from pathlib import Path from typing import Optional +import aiohttp +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.config import CACHE_DIR +from open_webui.constants import ERROR_MESSAGES from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT +from open_webui.internal.db import get_async_session from open_webui.models.functions import ( FunctionForm, FunctionModel, FunctionResponse, + Functions, FunctionUserResponse, FunctionWithValvesModel, - Functions, ) +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.plugin import ( + get_function_module_from_cache, load_function_module_by_id, replace_imports, - get_function_module_from_cache, resolve_valves_schema_options, ) -from open_webui.config import CACHE_DIR -from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status -from open_webui.utils.auth import get_admin_user, get_verified_user from pydantic import BaseModel, HttpUrl -from open_webui.internal.db import get_async_session from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -91,7 +92,7 @@ def github_url_to_raw_url(url: str) -> str: return url -@router.post('/load/url', response_model=Optional[dict]) +@router.post('/load/url', response_model=dict | None) async def load_function_from_url(request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)): # NOTE: This is NOT a SSRF vulnerability: # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use, @@ -179,7 +180,7 @@ async def sync_functions( ############################ -@router.post('/create', response_model=Optional[FunctionResponse]) +@router.post('/create', response_model=FunctionResponse | None) async def create_new_function( request: Request, form_data: FunctionForm, @@ -240,7 +241,7 @@ async def create_new_function( ############################ -@router.get('/id/{id}', response_model=Optional[FunctionModel]) +@router.get('/id/{id}', response_model=FunctionModel | None) async def get_function_by_id(id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)): function = await Functions.get_function_by_id(id, db=db) @@ -258,7 +259,7 @@ async def get_function_by_id(id: str, user=Depends(get_admin_user), db: AsyncSes ############################ -@router.post('/id/{id}/toggle', response_model=Optional[FunctionModel]) +@router.post('/id/{id}/toggle', response_model=FunctionModel | None) async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)): function = await Functions.get_function_by_id(id, db=db) if function: @@ -283,7 +284,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user), db: Async ############################ -@router.post('/id/{id}/toggle/global', response_model=Optional[FunctionModel]) +@router.post('/id/{id}/toggle/global', response_model=FunctionModel | None) async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)): function = await Functions.get_function_by_id(id, db=db) if function: @@ -308,7 +309,7 @@ async def toggle_global_by_id(id: str, user=Depends(get_admin_user), db: AsyncSe ############################ -@router.post('/id/{id}/update', response_model=Optional[FunctionModel]) +@router.post('/id/{id}/update', response_model=FunctionModel | None) async def update_function_by_id( request: Request, id: str, @@ -374,7 +375,7 @@ async def delete_function_by_id( ############################ -@router.get('/id/{id}/valves', response_model=Optional[dict]) +@router.get('/id/{id}/valves', response_model=dict | None) async def get_function_valves_by_id( id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session) ): @@ -400,7 +401,7 @@ async def get_function_valves_by_id( ############################ -@router.get('/id/{id}/valves/spec', response_model=Optional[dict]) +@router.get('/id/{id}/valves/spec', response_model=dict | None) async def get_function_valves_spec_by_id( request: Request, id: str, @@ -430,7 +431,7 @@ async def get_function_valves_spec_by_id( ############################ -@router.post('/id/{id}/valves/update', response_model=Optional[dict]) +@router.post('/id/{id}/valves/update', response_model=dict | None) async def update_function_valves_by_id( request: Request, id: str, @@ -476,7 +477,7 @@ async def update_function_valves_by_id( ############################ -@router.get('/id/{id}/valves/user', response_model=Optional[dict]) +@router.get('/id/{id}/valves/user', response_model=dict | None) async def get_function_user_valves_by_id( id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -497,7 +498,7 @@ async def get_function_user_valves_by_id( ) -@router.get('/id/{id}/valves/user/spec', response_model=Optional[dict]) +@router.get('/id/{id}/valves/user/spec', response_model=dict | None) async def get_function_user_valves_spec_by_id( request: Request, id: str, @@ -522,7 +523,7 @@ async def get_function_user_valves_spec_by_id( ) -@router.post('/id/{id}/valves/user/update', response_model=Optional[dict]) +@router.post('/id/{id}/valves/user/update', response_model=dict | None) async def update_function_user_valves_by_id( request: Request, id: str, diff --git a/backend/open_webui/routers/groups.py b/backend/open_webui/routers/groups.py index c45690fc3a..ff3a1997f7 100755 --- a/backend/open_webui/routers/groups.py +++ b/backend/open_webui/routers/groups.py @@ -1,26 +1,23 @@ +import logging import os from pathlib import Path from typing import Optional -import logging - -from open_webui.models.users import Users, UserInfoResponse -from open_webui.models.groups import ( - Groups, - GroupForm, - GroupInfoResponse, - GroupUpdateForm, - GroupResponse, - UserIdsForm, -) +from fastapi import APIRouter, Depends, HTTPException, Request, status from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, status - from open_webui.internal.db import get_async_session -from sqlalchemy.ext.asyncio import AsyncSession - +from open_webui.models.groups import ( + GroupForm, + GroupInfoResponse, + GroupResponse, + Groups, + GroupUpdateForm, + UserIdsForm, +) +from open_webui.models.users import UserInfoResponse, Users from open_webui.utils.auth import get_admin_user, get_verified_user +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/images.py b/backend/open_webui/routers/images.py index e55b7c5798..5da73f82ca 100644 --- a/backend/open_webui/routers/images.py +++ b/backend/open_webui/routers/images.py @@ -1,46 +1,45 @@ +from __future__ import annotations + import asyncio import base64 -import uuid import io import json import logging import mimetypes import re +import uuid from pathlib import Path from typing import Optional - from urllib.parse import quote -import aiohttp +import aiohttp from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile from fastapi.responses import FileResponse - from open_webui.config import ( CACHE_DIR, IMAGE_AUTO_SIZE_MODELS_REGEX_PATTERN, IMAGE_URL_RESPONSE_MODELS_REGEX_PATTERN, ) from open_webui.constants import ERROR_MESSAGES -from open_webui.retrieval.web.utils import validate_url -from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_ALLOW_REDIRECTS, ENABLE_FORWARD_USER_INFO_HEADERS -from open_webui.utils.session_pool import get_session - -from open_webui.models.chats import Chats -from open_webui.routers.files import upload_file_handler, get_file_content_by_id -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission -from open_webui.utils.headers import include_user_info_headers +from open_webui.env import AIOHTTP_CLIENT_ALLOW_REDIRECTS, AIOHTTP_CLIENT_SESSION_SSL, ENABLE_FORWARD_USER_INFO_HEADERS from open_webui.internal.db import get_async_session -from sqlalchemy.ext.asyncio import AsyncSession +from open_webui.models.chats import Chats +from open_webui.retrieval.web.utils import validate_url +from open_webui.routers.files import get_file_content_by_id, upload_file_handler +from open_webui.utils.access_control import has_permission +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.headers import include_user_info_headers from open_webui.utils.images.comfyui import ( ComfyUICreateImageForm, ComfyUIEditImageForm, ComfyUIWorkflow, - comfyui_upload_image, comfyui_create_image, comfyui_edit_image, + comfyui_upload_image, ) +from open_webui.utils.session_pool import get_session from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -121,17 +120,17 @@ class ImagesConfig(BaseModel): IMAGE_GENERATION_ENGINE: str IMAGE_GENERATION_MODEL: str - IMAGE_SIZE: Optional[str] - IMAGE_STEPS: Optional[int] + IMAGE_SIZE: str | None + IMAGE_STEPS: int | None IMAGES_OPENAI_API_BASE_URL: str IMAGES_OPENAI_API_KEY: str IMAGES_OPENAI_API_VERSION: str - IMAGES_OPENAI_API_PARAMS: Optional[dict | str] + IMAGES_OPENAI_API_PARAMS: dict | str | None AUTOMATIC1111_BASE_URL: str - AUTOMATIC1111_API_AUTH: Optional[dict | str] - AUTOMATIC1111_PARAMS: Optional[dict | str] + AUTOMATIC1111_API_AUTH: dict | str | None + AUTOMATIC1111_PARAMS: dict | str | None COMFYUI_BASE_URL: str COMFYUI_API_KEY: str @@ -145,7 +144,7 @@ class ImagesConfig(BaseModel): ENABLE_IMAGE_EDIT: bool IMAGE_EDIT_ENGINE: str IMAGE_EDIT_MODEL: str - IMAGE_EDIT_SIZE: Optional[str] + IMAGE_EDIT_SIZE: str | None IMAGES_EDIT_OPENAI_API_BASE_URL: str IMAGES_EDIT_OPENAI_API_KEY: str @@ -428,12 +427,12 @@ async def get_models(request: Request, user=Depends(get_verified_user)): class CreateImageForm(BaseModel): - model: Optional[str] = None + model: str | None = None prompt: str - size: Optional[str] = None + size: str | None = None n: int = 1 - steps: Optional[int] = None - negative_prompt: Optional[str] = None + steps: int | None = None + negative_prompt: str | None = None GenerateImageForm = CreateImageForm # Alias for backward compatibility @@ -528,7 +527,7 @@ async def generate_images(request: Request, form_data: CreateImageForm, user=Dep async def image_generations( request: Request, form_data: CreateImageForm, - metadata: Optional[dict] = None, + metadata: dict | None = None, user=None, ): # if IMAGE_SIZE = 'auto', default WidthxHeight to the 512x512 default @@ -776,18 +775,18 @@ async def image_generations( class EditImageForm(BaseModel): image: str | list[str] # base64-encoded image(s) or URL(s) prompt: str - model: Optional[str] = None - size: Optional[str] = None - n: Optional[int] = None - negative_prompt: Optional[str] = None - background: Optional[str] = None + model: str | None = None + size: str | None = None + n: int | None = None + negative_prompt: str | None = None + background: str | None = None @router.post('/edit') async def image_edits( request: Request, form_data: EditImageForm, - metadata: Optional[dict] = None, + metadata: dict | None = None, user=Depends(get_verified_user), ): size = None diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 8ff987b610..b35f5b793d 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -1,42 +1,40 @@ -from typing import List, Optional -from pydantic import BaseModel -from fastapi import APIRouter, Depends, HTTPException, status, Request, Query -from fastapi.responses import StreamingResponse +from __future__ import annotations -import logging import io +import logging import zipfile +from typing import List, Optional from urllib.parse import quote -from sqlalchemy.ext.asyncio import AsyncSession +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status +from fastapi.responses import StreamingResponse +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.constants import ERROR_MESSAGES from open_webui.internal.db import get_async_session +from open_webui.models.access_grants import AccessGrants +from open_webui.models.files import FileMetadataResponse, FileModel, Files from open_webui.models.groups import Groups from open_webui.models.knowledge import ( KnowledgeFileListResponse, - Knowledges, KnowledgeForm, KnowledgeResponse, + Knowledges, KnowledgeUserResponse, ) -from open_webui.models.files import Files, FileModel, FileMetadataResponse +from open_webui.models.models import ModelForm, Models from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT from open_webui.routers.retrieval import ( - process_file, - ProcessFileForm, - process_files_batch, BatchProcessFilesForm, + ProcessFileForm, + process_file, + process_files_batch, ) from open_webui.storage.provider import Storage - -from open_webui.constants import ERROR_MESSAGES -from open_webui.utils.auth import get_verified_user, get_admin_user -from open_webui.utils.access_control import has_permission, filter_allowed_access_grants +from open_webui.utils.access_control import filter_allowed_access_grants, has_permission from open_webui.utils.access_control.files import has_access_to_file -from open_webui.models.access_grants import AccessGrants - - -from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL -from open_webui.models.models import Models, ModelForm +from open_webui.utils.auth import get_admin_user, get_verified_user +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -100,7 +98,7 @@ async def remove_knowledge_base_metadata_embedding(knowledge_base_id: str) -> bo class KnowledgeAccessResponse(KnowledgeUserResponse): - write_access: Optional[bool] = False + write_access: bool | None = False class KnowledgeAccessListResponse(BaseModel): @@ -110,7 +108,7 @@ class KnowledgeAccessListResponse(BaseModel): @router.get('/', response_model=KnowledgeAccessListResponse) async def get_knowledge_bases( - page: Optional[int] = 1, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -159,9 +157,9 @@ async def get_knowledge_bases( @router.get('/search', response_model=KnowledgeAccessListResponse) async def search_knowledge_bases( - query: Optional[str] = None, - view_option: Optional[str] = None, - page: Optional[int] = 1, + query: str | None = None, + view_option: str | None = None, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -215,8 +213,8 @@ async def search_knowledge_bases( @router.get('/search/files', response_model=KnowledgeFileListResponse) async def search_knowledge_files( - query: Optional[str] = None, - page: Optional[int] = 1, + query: str | None = None, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -242,7 +240,7 @@ async def search_knowledge_files( ############################ -@router.post('/create', response_model=Optional[KnowledgeResponse]) +@router.post('/create', response_model=KnowledgeResponse | None) async def create_new_knowledge( request: Request, form_data: KnowledgeForm, @@ -380,11 +378,11 @@ async def reindex_knowledge_base_metadata_embeddings( class KnowledgeFilesResponse(KnowledgeResponse): - files: Optional[list[FileMetadataResponse]] = None - write_access: Optional[bool] = False + files: list[FileMetadataResponse | None] = None + write_access: bool | None = False -@router.get('/{id}', response_model=Optional[KnowledgeFilesResponse]) +@router.get('/{id}', response_model=KnowledgeFilesResponse | None) async def get_knowledge_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): knowledge = await Knowledges.get_knowledge_by_id(id=id, db=db) @@ -431,7 +429,7 @@ async def get_knowledge_by_id(id: str, user=Depends(get_verified_user), db: Asyn ############################ -@router.post('/{id}/update', response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/update', response_model=KnowledgeFilesResponse | None) async def update_knowledge_by_id( request: Request, id: str, @@ -501,7 +499,7 @@ class KnowledgeAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post('/{id}/access/update', response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/access/update', response_model=KnowledgeFilesResponse | None) async def update_knowledge_access_by_id( request: Request, id: str, @@ -556,11 +554,11 @@ async def update_knowledge_access_by_id( @router.get('/{id}/files', response_model=KnowledgeFileListResponse) async def get_knowledge_files_by_id( id: str, - query: Optional[str] = None, - view_option: Optional[str] = None, - order_by: Optional[str] = None, - direction: Optional[str] = None, - page: Optional[int] = 1, + query: str | None = None, + view_option: str | None = None, + order_by: str | None = None, + direction: str | None = None, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -614,7 +612,7 @@ class KnowledgeFileIdForm(BaseModel): file_id: str -@router.post('/{id}/file/add', response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/file/add', response_model=KnowledgeFilesResponse | None) async def add_file_to_knowledge_by_id( request: Request, id: str, @@ -695,7 +693,7 @@ async def add_file_to_knowledge_by_id( ) -@router.post('/{id}/file/update', response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/file/update', response_model=KnowledgeFilesResponse | None) async def update_file_from_knowledge_by_id( request: Request, id: str, @@ -774,7 +772,7 @@ async def update_file_from_knowledge_by_id( ############################ -@router.post('/{id}/file/remove', response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/file/remove', response_model=KnowledgeFilesResponse | None) async def remove_file_from_knowledge_by_id( id: str, form_data: KnowledgeFileIdForm, @@ -936,7 +934,7 @@ async def delete_knowledge_by_id( ############################ -@router.post('/{id}/reset', response_model=Optional[KnowledgeResponse]) +@router.post('/{id}/reset', response_model=KnowledgeResponse | None) async def reset_knowledge_by_id( id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -978,7 +976,7 @@ async def reset_knowledge_by_id( ############################ -@router.post('/{id}/files/batch/add', response_model=Optional[KnowledgeFilesResponse]) +@router.post('/{id}/files/batch/add', response_model=KnowledgeFilesResponse | None) async def add_files_to_knowledge_batch( request: Request, id: str, diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index 6522118258..5108ab0e7c 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -1,17 +1,18 @@ -from fastapi import APIRouter, Depends, HTTPException, Request, status -from pydantic import BaseModel -import logging +from __future__ import annotations + import asyncio +import logging from typing import Optional +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.constants import ERROR_MESSAGES +from open_webui.internal.db import get_async_session from open_webui.models.memories import Memories, MemoryModel from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT -from open_webui.utils.auth import get_verified_user -from open_webui.internal.db import get_async_session -from sqlalchemy.ext.asyncio import AsyncSession - from open_webui.utils.access_control import has_permission -from open_webui.constants import ERROR_MESSAGES +from open_webui.utils.auth import get_verified_user +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -56,10 +57,10 @@ class AddMemoryForm(BaseModel): class MemoryUpdateModel(BaseModel): - content: Optional[str] = None + content: str | None = None -@router.post('/add', response_model=Optional[MemoryModel]) +@router.post('/add', response_model=MemoryModel | None) async def add_memory( request: Request, form_data: AddMemoryForm, @@ -107,7 +108,7 @@ async def add_memory( class QueryMemoryForm(BaseModel): content: str - k: Optional[int] = 1 + k: int | None = 1 @router.post('/query') @@ -275,7 +276,7 @@ async def delete_memory_by_user_id( ############################ -@router.post('/{memory_id}/update', response_model=Optional[MemoryModel]) +@router.post('/{memory_id}/update', response_model=MemoryModel | None) async def update_memory_by_id( memory_id: str, request: Request, diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 2a78daa94d..922d016a2e 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -1,28 +1,14 @@ -from typing import Optional -import io -import base64 -import json +from __future__ import annotations + import asyncio +import base64 +import io +import json import logging import posixpath +from typing import Optional from urllib.parse import unquote -from open_webui.models.groups import Groups -from open_webui.models.models import ( - ModelForm, - ModelMeta, - ModelModel, - ModelParams, - ModelResponse, - ModelListResponse, - ModelAccessListResponse, - ModelAccessResponse, - Models, -) -from open_webui.models.access_grants import AccessGrants - -from pydantic import BaseModel -from open_webui.constants import ERROR_MESSAGES from fastapi import ( APIRouter, Depends, @@ -32,13 +18,26 @@ from fastapi import ( status, ) from fastapi.responses import RedirectResponse, StreamingResponse - - -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission, filter_allowed_access_grants from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENABLE_PROFILE_IMAGE_URL_FORWARDING from open_webui.internal.db import get_async_session +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups +from open_webui.models.models import ( + ModelAccessListResponse, + ModelAccessResponse, + ModelForm, + ModelListResponse, + ModelMeta, + ModelModel, + ModelParams, + ModelResponse, + Models, +) +from open_webui.utils.access_control import filter_allowed_access_grants, has_permission +from open_webui.utils.auth import get_admin_user, get_verified_user +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -46,7 +45,7 @@ log = logging.getLogger(__name__) router = APIRouter() -def _safe_static_redirect_path(url: str) -> Optional[str]: +def _safe_static_redirect_path(url: str) -> str | None: """ If url is a same-origin static asset path, return a normalized path safe for RedirectResponse Location. Otherwise None (caller should fall back to default). @@ -90,12 +89,12 @@ PAGE_ITEM_COUNT = 30 @router.get('/list', response_model=ModelAccessListResponse) # do NOT use "/" as path, conflicts with main.py async def get_models( - query: Optional[str] = None, - view_option: Optional[str] = None, - tag: Optional[str] = None, - order_by: Optional[str] = None, - direction: Optional[str] = None, - page: Optional[int] = 1, + query: str | None = None, + view_option: str | None = None, + tag: str | None = None, + order_by: str | None = None, + direction: str | None = None, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -192,7 +191,7 @@ async def get_model_tags(user=Depends(get_verified_user), db: AsyncSession = Dep ############################ -@router.post('/create', response_model=Optional[ModelModel]) +@router.post('/create', response_model=ModelModel | None) async def create_new_model( request: Request, form_data: ModelForm, @@ -409,7 +408,7 @@ class ModelIdForm(BaseModel): # Note: We're not using the typical url path param here, but instead using a query parameter to allow '/' in the id -@router.get('/model', response_model=Optional[ModelAccessResponse]) +@router.get('/model', response_model=ModelAccessResponse | None) async def get_model_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): model = await Models.get_model_by_id(id, db=db) if model: @@ -537,7 +536,7 @@ async def get_model_profile_image( ############################ -@router.post('/model/toggle', response_model=Optional[ModelResponse]) +@router.post('/model/toggle', response_model=ModelResponse | None) async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): model = await Models.get_model_by_id(id, db=db) if model: @@ -578,7 +577,7 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Async ############################ -@router.post('/model/update', response_model=Optional[ModelModel]) +@router.post('/model/update', response_model=ModelModel | None) async def update_model_by_id( request: Request, form_data: ModelForm, @@ -627,11 +626,11 @@ async def update_model_by_id( class ModelAccessGrantsForm(BaseModel): id: str - name: Optional[str] = None + name: str | None = None access_grants: list[dict] -@router.post('/model/access/update', response_model=Optional[ModelModel]) +@router.post('/model/access/update', response_model=ModelModel | None) async def update_model_access_by_id( request: Request, form_data: ModelAccessGrantsForm, diff --git a/backend/open_webui/routers/notes.py b/backend/open_webui/routers/notes.py index 5ed46b5d61..6dccc73f6d 100644 --- a/backend/open_webui/routers/notes.py +++ b/backend/open_webui/routers/notes.py @@ -2,39 +2,33 @@ import json import logging from typing import Optional - -from fastapi import APIRouter, Depends, HTTPException, Request, status, BackgroundTasks -from pydantic import BaseModel - -from open_webui.socket.main import sio - -from open_webui.models.groups import Groups -from open_webui.models.users import Users, UserResponse -from open_webui.models.notes import ( - NoteListResponse, - Notes, - NoteModel, - NoteForm, - NoteUserResponse, -) - +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request, status from open_webui.config import ( BYPASS_ADMIN_ACCESS_CONTROL, ENABLE_ADMIN_CHAT_ACCESS, ENABLE_ADMIN_EXPORT, ) from open_webui.constants import ERROR_MESSAGES - - -from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.internal.db import get_async_session +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups +from open_webui.models.notes import ( + NoteForm, + NoteListResponse, + NoteModel, + Notes, + NoteUserResponse, +) +from open_webui.models.users import UserResponse, Users +from open_webui.socket.main import sio from open_webui.utils.access_control import ( + filter_allowed_access_grants, has_permission, has_public_read_access_grant, has_public_write_access_grant, - filter_allowed_access_grants, ) -from open_webui.models.access_grants import AccessGrants -from open_webui.internal.db import get_async_session +from open_webui.utils.auth import get_admin_user, get_verified_user +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 01fbf10f4f..d6348836e8 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -1,7 +1,8 @@ +from __future__ import annotations + # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances. # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin, # least connections, or least response time for better resource utilization and performance optimization. - import asyncio import json import logging @@ -10,70 +11,60 @@ import random import re import time from datetime import datetime - from typing import Optional, Union from urllib.parse import urlparse + import aiohttp from aiocache import cached - - -from open_webui.utils.headers import include_user_info_headers -from open_webui.models.chats import Chats -from open_webui.models.users import UserModel - -from open_webui.env import ( - ENABLE_FORWARD_USER_INFO_HEADERS, - FORWARD_SESSION_INFO_HEADER_CHAT_ID, -) - from fastapi import ( + APIRouter, Depends, FastAPI, File, HTTPException, Request, UploadFile, - APIRouter, ) from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse -from pydantic import BaseModel, ConfigDict, validator - -from sqlalchemy.ext.asyncio import AsyncSession - +from open_webui.config import ( + UPLOAD_DIR, +) +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import ( + AIOHTTP_CLIENT_SESSION_SSL, + AIOHTTP_CLIENT_TIMEOUT, + AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, + BYPASS_MODEL_ACCESS_CONTROL, + ENABLE_FORWARD_USER_INFO_HEADERS, + ENV, + FORWARD_SESSION_INFO_HEADER_CHAT_ID, + MODELS_CACHE_TTL, +) from open_webui.internal.db import get_async_session - - -from open_webui.models.models import Models from open_webui.models.access_grants import AccessGrants +from open_webui.models.chats import Chats from open_webui.models.groups import Groups +from open_webui.models.models import Models +from open_webui.models.users import UserModel from open_webui.utils.access_control import check_model_access +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.headers import include_user_info_headers from open_webui.utils.misc import ( calculate_sha256, ) -from open_webui.utils.session_pool import ( - cleanup_response, - get_session, - stream_wrapper, -) from open_webui.utils.payload import ( apply_model_params_to_body_ollama, apply_model_params_to_body_openai, apply_system_prompt_to_body, ) -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.config import ( - UPLOAD_DIR, +from open_webui.utils.session_pool import ( + cleanup_response, + get_session, + stream_wrapper, ) -from open_webui.env import ( - ENV, - MODELS_CACHE_TTL, - AIOHTTP_CLIENT_SESSION_SSL, - AIOHTTP_CLIENT_TIMEOUT, - AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, - BYPASS_MODEL_ACCESS_CONTROL, -) -from open_webui.constants import ERROR_MESSAGES +from pydantic import BaseModel, ConfigDict, validator +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -126,12 +117,12 @@ async def send_request( url: str, method: str = 'POST', *, - payload: Optional[Union[str, bytes]] = None, - key: Optional[str] = None, + payload: Union[str, bytes | None] = None, + key: str | None = None, user: UserModel = None, stream: bool = False, - content_type: Optional[str] = None, - metadata: Optional[dict] = None, + content_type: str | None = None, + metadata: dict | None = None, ): r = None streaming = False @@ -225,7 +216,7 @@ async def get_status(): class ConnectionVerificationForm(BaseModel): url: str - key: Optional[str] = None + key: str | None = None @router.post('/verify') @@ -279,7 +270,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)): class OllamaConfigForm(BaseModel): - ENABLE_OLLAMA_API: Optional[bool] = None + ENABLE_OLLAMA_API: bool | None = None OLLAMA_BASE_URLS: list[str] OLLAMA_API_CONFIGS: dict @@ -437,7 +428,7 @@ async def get_filtered_models(models, user, db=None): @router.get('/api/tags') @router.get('/api/tags/{url_idx}') -async def get_ollama_tags(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)): +async def get_ollama_tags(request: Request, url_idx: int | None = None, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_OLLAMA_API: raise HTTPException(status_code=503, detail=ERROR_MESSAGES.OLLAMA_API_DISABLED) @@ -514,7 +505,7 @@ async def get_ollama_loaded_models(request: Request, user=Depends(get_admin_user @router.get('/api/version') @router.get('/api/version/{url_idx}') -async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): +async def get_ollama_versions(request: Request, url_idx: int | None = None): if request.app.state.config.ENABLE_OLLAMA_API: if url_idx is None: # returns lowest version @@ -560,7 +551,7 @@ async def get_ollama_versions(request: Request, url_idx: Optional[int] = None): class ModelNameForm(BaseModel): - model: Optional[str] = None + model: str | None = None model_config = ConfigDict( extra='allow', ) @@ -654,8 +645,8 @@ async def pull_model( class PushModelForm(BaseModel): model: str - insecure: Optional[bool] = None - stream: Optional[bool] = None + insecure: bool | None = None + stream: bool | None = None @router.delete('/api/push') @@ -663,7 +654,7 @@ class PushModelForm(BaseModel): async def push_model( request: Request, form_data: PushModelForm, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_admin_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: @@ -694,9 +685,9 @@ async def push_model( class CreateModelForm(BaseModel): - model: Optional[str] = None - stream: Optional[bool] = None - path: Optional[str] = None + model: str | None = None + stream: bool | None = None + path: str | None = None model_config = ConfigDict(extra='allow') @@ -734,7 +725,7 @@ class CopyModelForm(BaseModel): async def copy_model( request: Request, form_data: CopyModelForm, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_admin_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: @@ -769,7 +760,7 @@ async def copy_model( async def delete_model( request: Request, form_data: ModelNameForm, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_admin_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: @@ -843,9 +834,9 @@ async def show_model_info(request: Request, form_data: ModelNameForm, user=Depen class GenerateEmbedForm(BaseModel): model: str input: list[str] | str - truncate: Optional[bool] = None - options: Optional[dict] = None - keep_alive: Optional[Union[int, str]] = None + truncate: bool | None = None + options: dict | None = None + keep_alive: Union[int, str | None] = None model_config = ConfigDict( extra='allow', @@ -857,7 +848,7 @@ class GenerateEmbedForm(BaseModel): async def embed( request: Request, form_data: GenerateEmbedForm, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: @@ -907,8 +898,8 @@ async def embed( class GenerateEmbeddingsForm(BaseModel): model: str prompt: str - options: Optional[dict] = None - keep_alive: Optional[Union[int, str]] = None + options: dict | None = None + keep_alive: Union[int, str | None] = None @router.post('/api/embeddings') @@ -916,7 +907,7 @@ class GenerateEmbeddingsForm(BaseModel): async def embeddings( request: Request, form_data: GenerateEmbeddingsForm, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: @@ -965,17 +956,17 @@ async def embeddings( class GenerateCompletionForm(BaseModel): model: str - prompt: Optional[str] = None - suffix: Optional[str] = None - images: Optional[list[str]] = None - format: Optional[Union[dict, str]] = None - options: Optional[dict] = None - system: Optional[str] = None - template: Optional[str] = None - context: Optional[list[int]] = None - stream: Optional[bool] = True - raw: Optional[bool] = None - keep_alive: Optional[Union[int, str]] = None + prompt: str | None = None + suffix: str | None = None + images: list[str | None] = None + format: Union[dict, str | None] = None + options: dict | None = None + system: str | None = None + template: str | None = None + context: list[int | None] = None + stream: bool | None = True + raw: bool | None = None + keep_alive: Union[int, str | None] = None @router.post('/api/generate') @@ -983,7 +974,7 @@ class GenerateCompletionForm(BaseModel): async def generate_completion( request: Request, form_data: GenerateCompletionForm, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), ): if not request.app.state.config.ENABLE_OLLAMA_API: @@ -1027,9 +1018,9 @@ async def generate_completion( class ChatMessage(BaseModel): role: str - content: Optional[str] = None - tool_calls: Optional[list[dict]] = None - images: Optional[list[str]] = None + content: str | None = None + tool_calls: list[dict | None] = None + images: list[str | None] = None model_config = ConfigDict(extra='allow') @@ -1046,18 +1037,18 @@ class ChatMessage(BaseModel): class GenerateChatCompletionForm(BaseModel): model: str messages: list[ChatMessage] - format: Optional[Union[dict, str]] = None - options: Optional[dict] = None - template: Optional[str] = None - stream: Optional[bool] = True - keep_alive: Optional[Union[int, str]] = None - tools: Optional[list[dict]] = None + format: Union[dict, str | None] = None + options: dict | None = None + template: str | None = None + stream: bool | None = True + keep_alive: Union[int, str | None] = None + tools: list[dict | None] = None model_config = ConfigDict( extra='allow', ) -async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = None): +async def get_ollama_url(request: Request, model: str, url_idx: int | None = None): if url_idx is None: models = request.app.state.OLLAMA_MODELS if model not in models: @@ -1075,7 +1066,7 @@ async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] = async def generate_chat_completion( request: Request, form_data: dict, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), bypass_system_prompt: bool = False, ): @@ -1162,7 +1153,7 @@ class OpenAIChatMessageContent(BaseModel): class OpenAIChatMessage(BaseModel): role: str - content: Union[Optional[str], list[OpenAIChatMessageContent]] + content: Union[str | None, list[OpenAIChatMessageContent]] model_config = ConfigDict(extra='allow') @@ -1186,7 +1177,7 @@ class OpenAICompletionForm(BaseModel): async def generate_openai_completion( request: Request, form_data: dict, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), ): # NOTE: We intentionally do NOT use Depends(get_async_session) here. @@ -1248,7 +1239,7 @@ async def generate_openai_completion( async def generate_openai_chat_completion( request: Request, form_data: dict, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), ): # NOTE: We intentionally do NOT use Depends(get_async_session) here. @@ -1313,7 +1304,7 @@ async def generate_openai_chat_completion( async def generate_anthropic_messages( request: Request, form_data: dict, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), ): """ @@ -1371,7 +1362,7 @@ class ResponsesForm(BaseModel): async def generate_responses( request: Request, form_data: ResponsesForm, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), ): """ @@ -1422,7 +1413,7 @@ async def generate_responses( @router.get('/v1/models/{url_idx}') async def get_openai_models( request: Request, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -1565,7 +1556,7 @@ async def download_file_stream(ollama_url, file_url, file_path, file_name, chunk async def download_model( request: Request, form_data: UrlForm, - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_admin_user), ): allowed_hosts = ['https://huggingface.co/', 'https://github.com/'] @@ -1598,7 +1589,7 @@ async def download_model( async def upload_model( request: Request, file: UploadFile = File(...), - url_idx: Optional[int] = None, + url_idx: int | None = None, user=Depends(get_admin_user), ): if url_idx is None: diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 4404d4d906..900ebc6ad9 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import hashlib import json @@ -8,62 +10,52 @@ from urllib.parse import quote, urlparse import aiohttp from aiocache import cached - - from azure.identity import DefaultAzureCredential, get_bearer_token_provider - -from fastapi import Depends, HTTPException, Request, APIRouter, status +from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import ( FileResponse, - StreamingResponse, JSONResponse, PlainTextResponse, + StreamingResponse, ) -from pydantic import BaseModel, ConfigDict - -from sqlalchemy.ext.asyncio import AsyncSession - -from open_webui.internal.db import get_async_session - -from open_webui.models.models import Models -from open_webui.models.access_grants import AccessGrants -from open_webui.models.groups import Groups -from open_webui.utils.access_control import has_connection_access, check_model_access from open_webui.config import ( CACHE_DIR, ) +from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( - MODELS_CACHE_TTL, AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, - ENABLE_FORWARD_USER_INFO_HEADERS, - FORWARD_SESSION_INFO_HEADER_CHAT_ID, BYPASS_MODEL_ACCESS_CONTROL, + ENABLE_FORWARD_USER_INFO_HEADERS, ENABLE_OPENAI_API_PASSTHROUGH, + FORWARD_SESSION_INFO_HEADER_CHAT_ID, + MODELS_CACHE_TTL, ) +from open_webui.internal.db import get_async_session +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups +from open_webui.models.models import Models from open_webui.models.users import UserModel - -from open_webui.constants import ERROR_MESSAGES - - -from open_webui.utils.payload import ( - apply_model_params_to_body_openai, - apply_system_prompt_to_body, -) +from open_webui.utils.access_control import check_model_access, has_connection_access +from open_webui.utils.anthropic import get_anthropic_models, is_anthropic_url +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.headers import get_custom_headers, include_user_info_headers from open_webui.utils.misc import ( convert_logit_bias_input_to_json, stream_chunks_handler, ) +from open_webui.utils.payload import ( + apply_model_params_to_body_openai, + apply_system_prompt_to_body, +) from open_webui.utils.session_pool import ( cleanup_response, get_session, stream_wrapper, ) - -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.headers import include_user_info_headers, get_custom_headers -from open_webui.utils.anthropic import is_anthropic_url, get_anthropic_models +from pydantic import BaseModel, ConfigDict +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -160,7 +152,7 @@ async def get_headers_and_cookies( url, key=None, config=None, - metadata: Optional[dict] = None, + metadata: dict | None = None, user: UserModel = None, ): cookies = {} @@ -256,7 +248,7 @@ async def get_config(request: Request, user=Depends(get_admin_user)): class OpenAIConfigForm(BaseModel): - ENABLE_OPENAI_API: Optional[bool] = None + ENABLE_OPENAI_API: bool | None = None OPENAI_API_BASE_URLS: list[str] OPENAI_API_KEYS: list[str] OPENAI_API_CONFIGS: dict @@ -493,7 +485,6 @@ async def get_filtered_models(models, user, db=None): return filtered_models - @cached( ttl=MODELS_CACHE_TTL, key=lambda _, user: f'openai_all_models_{user.id}' if user else 'openai_all_models', @@ -571,14 +562,13 @@ async def get_all_models(request: Request, user: UserModel) -> dict[str, list]: models = get_merged_models(map(extract_data, responses)) log.debug(f'models: {models}') - request.app.state.OPENAI_MODELS = models return {'data': list(models.values())} @router.get('/models') @router.get('/models/{url_idx}') -async def get_models(request: Request, url_idx: Optional[int] = None, user=Depends(get_verified_user)): +async def get_models(request: Request, url_idx: int | None = None, user=Depends(get_verified_user)): if not request.app.state.config.ENABLE_OPENAI_API: raise HTTPException(status_code=503, detail='OpenAI API is disabled') @@ -670,7 +660,7 @@ class ConnectionVerificationForm(BaseModel): url: str key: str - config: Optional[dict] = None + config: dict | None = None @router.post('/verify') @@ -1361,20 +1351,20 @@ class ResponsesForm(BaseModel): model_config = ConfigDict(extra='allow') model: str - input: Optional[list | str] = None - instructions: Optional[str] = None - stream: Optional[bool] = None - temperature: Optional[float] = None - max_output_tokens: Optional[int] = None - top_p: Optional[float] = None - tools: Optional[list] = None - tool_choice: Optional[str | dict] = None - text: Optional[dict] = None - truncation: Optional[str] = None - metadata: Optional[dict] = None - store: Optional[bool] = None - reasoning: Optional[dict] = None - previous_response_id: Optional[str] = None + input: list | str | None = None + instructions: str | None = None + stream: bool | None = None + temperature: float | None = None + max_output_tokens: int | None = None + top_p: float | None = None + tools: list | None = None + tool_choice: str | dict | None = None + text: dict | None = None + truncation: str | None = None + metadata: dict | None = None + store: bool | None = None + reasoning: dict | None = None + previous_response_id: str | None = None @router.post('/responses') diff --git a/backend/open_webui/routers/pipelines.py b/backend/open_webui/routers/pipelines.py index 580fb42fb2..5e0d4dc199 100644 --- a/backend/open_webui/routers/pipelines.py +++ b/backend/open_webui/routers/pipelines.py @@ -1,4 +1,11 @@ +import logging +import os +import shutil +from typing import Optional + +import aiohttp from fastapi import ( + APIRouter, Depends, FastAPI, File, @@ -7,24 +14,14 @@ from fastapi import ( Request, UploadFile, status, - APIRouter, ) -import aiohttp -import os -import logging -import shutil -from pydantic import BaseModel -from starlette.responses import FileResponse -from typing import Optional - -from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL from open_webui.config import CACHE_DIR from open_webui.constants import ERROR_MESSAGES - - +from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL from open_webui.routers.openai import get_all_models_responses - from open_webui.utils.auth import get_admin_user +from pydantic import BaseModel +from starlette.responses import FileResponse log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 755034f880..f027dc48c6 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -1,14 +1,11 @@ -from typing import Optional -from fastapi import APIRouter, Depends, HTTPException, status, Request +from __future__ import annotations -from open_webui.models.prompts import ( - PromptForm, - PromptUserResponse, - PromptAccessResponse, - PromptAccessListResponse, - PromptModel, - Prompts, -) +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Request, status +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.constants import ERROR_MESSAGES +from open_webui.internal.db import get_async_session from open_webui.models.access_grants import AccessGrants from open_webui.models.groups import Groups from open_webui.models.prompt_history import ( @@ -16,13 +13,18 @@ from open_webui.models.prompt_history import ( PromptHistoryModel, PromptHistoryResponse, ) -from open_webui.constants import ERROR_MESSAGES +from open_webui.models.prompts import ( + PromptAccessListResponse, + PromptAccessResponse, + PromptForm, + PromptModel, + Prompts, + PromptUserResponse, +) +from open_webui.utils.access_control import filter_allowed_access_grants, has_permission from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission, filter_allowed_access_grants -from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL -from open_webui.internal.db import get_async_session -from sqlalchemy.ext.asyncio import AsyncSession from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession class PromptVersionUpdateForm(BaseModel): @@ -32,7 +34,7 @@ class PromptVersionUpdateForm(BaseModel): class PromptMetadataForm(BaseModel): name: str command: str - tags: Optional[list[str]] = None + tags: list[str | None] = None router = APIRouter() @@ -66,12 +68,12 @@ async def get_prompt_tags(user=Depends(get_verified_user), db: AsyncSession = De @router.get('/list', response_model=PromptAccessListResponse) async def get_prompt_list( - query: Optional[str] = None, - view_option: Optional[str] = None, - tag: Optional[str] = None, - order_by: Optional[str] = None, - direction: Optional[str] = None, - page: Optional[int] = 1, + query: str | None = None, + view_option: str | None = None, + tag: str | None = None, + order_by: str | None = None, + direction: str | None = None, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -136,7 +138,7 @@ async def get_prompt_list( ############################ -@router.post('/create', response_model=Optional[PromptModel]) +@router.post('/create', response_model=PromptModel | None) async def create_new_prompt( request: Request, form_data: PromptForm, @@ -191,7 +193,7 @@ async def create_new_prompt( ############################ -@router.get('/command/{command}', response_model=Optional[PromptAccessResponse]) +@router.get('/command/{command}', response_model=PromptAccessResponse | None) async def get_prompt_by_command( command: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -235,7 +237,7 @@ async def get_prompt_by_command( ############################ -@router.get('/id/{prompt_id}', response_model=Optional[PromptAccessResponse]) +@router.get('/id/{prompt_id}', response_model=PromptAccessResponse | None) async def get_prompt_by_id( prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -279,7 +281,7 @@ async def get_prompt_by_id( ############################ -@router.post('/id/{prompt_id}/update', response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/update', response_model=PromptModel | None) async def update_prompt_by_id( request: Request, prompt_id: str, @@ -345,7 +347,7 @@ async def update_prompt_by_id( ############################ -@router.post('/id/{prompt_id}/update/meta', response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/update/meta', response_model=PromptModel | None) async def update_prompt_metadata( prompt_id: str, form_data: PromptMetadataForm, @@ -398,7 +400,7 @@ async def update_prompt_metadata( ) -@router.post('/id/{prompt_id}/update/version', response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/update/version', response_model=PromptModel | None) async def set_prompt_version( prompt_id: str, form_data: PromptVersionUpdateForm, @@ -447,7 +449,7 @@ class PromptAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post('/id/{prompt_id}/access/update', response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/access/update', response_model=PromptModel | None) async def update_prompt_access_by_id( request: Request, prompt_id: str, @@ -496,7 +498,7 @@ async def update_prompt_access_by_id( ############################ -@router.post('/id/{prompt_id}/toggle', response_model=Optional[PromptModel]) +@router.post('/id/{prompt_id}/toggle', response_model=PromptModel | None) async def toggle_prompt_active( prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): diff --git a/backend/open_webui/routers/retrieval.py b/backend/open_webui/routers/retrieval.py index 201e6a63fb..9a1dfdcc16 100644 --- a/backend/open_webui/routers/retrieval.py +++ b/backend/open_webui/routers/retrieval.py @@ -1,130 +1,122 @@ +from __future__ import annotations + +import asyncio import json import logging import mimetypes import os -import shutil -import asyncio - import re +import shutil import uuid from datetime import datetime from pathlib import Path -from typing import Iterator, List, Optional, Sequence, Union +from typing import Iterator, Optional, Sequence, Union +import tiktoken from fastapi import ( + APIRouter, Depends, FastAPI, - Query, File, Form, HTTPException, - UploadFile, + Query, Request, + UploadFile, status, - APIRouter, ) -from fastapi.middleware.cors import CORSMiddleware from fastapi.concurrency import run_in_threadpool -from pydantic import BaseModel -import tiktoken - - +from fastapi.middleware.cors import CORSMiddleware +from langchain_core.documents import Document from langchain_text_splitters import ( + MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter, TokenTextSplitter, - MarkdownHeaderTextSplitter, ) -from langchain_core.documents import Document - -from open_webui.models.files import FileModel, FileUpdateForm, Files -from open_webui.utils.access_control.files import has_access_to_file -from open_webui.models.knowledge import Knowledges -from open_webui.storage.provider import Storage -from open_webui.internal.db import get_async_db, get_async_session -from sqlalchemy.ext.asyncio import AsyncSession - - -from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT -from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT - -# Document loaders - -from open_webui.retrieval.loaders.youtube import YoutubeLoader - -# Web search engines -from open_webui.retrieval.web.main import SearchResult -from open_webui.retrieval.web.utils import get_web_loader -from open_webui.retrieval.web.ollama import search_ollama_cloud -from open_webui.retrieval.web.perplexity_search import search_perplexity_search -from open_webui.retrieval.web.brave import search_brave -from open_webui.retrieval.web.brave_llm_context import search_brave_llm_context -from open_webui.retrieval.web.kagi import search_kagi -from open_webui.retrieval.web.mojeek import search_mojeek -from open_webui.retrieval.web.bocha import search_bocha -from open_webui.retrieval.web.duckduckgo import search_duckduckgo -from open_webui.retrieval.web.google_pse import search_google_pse -from open_webui.retrieval.web.jina_search import search_jina -from open_webui.retrieval.web.searchapi import search_searchapi -from open_webui.retrieval.web.serpapi import search_serpapi -from open_webui.retrieval.web.searxng import search_searxng -from open_webui.retrieval.web.yacy import search_yacy -from open_webui.retrieval.web.serper import search_serper -from open_webui.retrieval.web.serply import search_serply -from open_webui.retrieval.web.serpstack import search_serpstack -from open_webui.retrieval.web.tavily import search_tavily -from open_webui.retrieval.web.bing import search_bing -from open_webui.retrieval.web.azure import search_azure -from open_webui.retrieval.web.exa import search_exa -from open_webui.retrieval.web.perplexity import search_perplexity -from open_webui.retrieval.web.sougou import search_sougou -from open_webui.retrieval.web.firecrawl import search_firecrawl -from open_webui.retrieval.web.external import search_external -from open_webui.retrieval.web.yandex import search_yandex -from open_webui.retrieval.web.ydc import search_youcom - -from open_webui.retrieval.utils import ( - build_loader_from_config, - filter_accessible_collections, - get_content_from_url, - get_embedding_function, - get_reranking_function, - get_model_path, - query_collection, - query_collection_with_hybrid_search, - query_doc, - query_doc_with_hybrid_search, -) -from open_webui.retrieval.vector.utils import filter_metadata -from open_webui.utils.misc import ( - calculate_sha256_string, - sanitize_text_for_db, -) -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission - from open_webui.config import ( + DEFAULT_LOCALE, ENV, + RAG_EMBEDDING_CONTENT_PREFIX, RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, + RAG_EMBEDDING_QUERY_PREFIX, RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_TRUST_REMOTE_CODE, UPLOAD_DIR, - DEFAULT_LOCALE, - RAG_EMBEDDING_CONTENT_PREFIX, - RAG_EMBEDDING_QUERY_PREFIX, ) +from open_webui.constants import ERROR_MESSAGES from open_webui.env import ( DEVICE_TYPE, DOCKER, RAG_EMBEDDING_TIMEOUT, SENTENCE_TRANSFORMERS_BACKEND, - SENTENCE_TRANSFORMERS_MODEL_KWARGS, SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND, SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS, SENTENCE_TRANSFORMERS_CROSS_ENCODER_SIGMOID_ACTIVATION_FUNCTION, + SENTENCE_TRANSFORMERS_MODEL_KWARGS, ) +from open_webui.internal.db import get_async_db, get_async_session +from open_webui.models.files import FileModel, Files, FileUpdateForm +from open_webui.models.knowledge import Knowledges -from open_webui.constants import ERROR_MESSAGES +# Document loaders +from open_webui.retrieval.loaders.youtube import YoutubeLoader +from open_webui.retrieval.utils import ( + build_loader_from_config, + filter_accessible_collections, + get_content_from_url, + get_embedding_function, + get_model_path, + get_reranking_function, + query_collection, + query_collection_with_hybrid_search, + query_doc, + query_doc_with_hybrid_search, +) +from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT +from open_webui.retrieval.vector.factory import VECTOR_DB_CLIENT +from open_webui.retrieval.vector.utils import filter_metadata +from open_webui.retrieval.web.azure import search_azure +from open_webui.retrieval.web.bing import search_bing +from open_webui.retrieval.web.bocha import search_bocha +from open_webui.retrieval.web.brave import search_brave +from open_webui.retrieval.web.brave_llm_context import search_brave_llm_context +from open_webui.retrieval.web.duckduckgo import search_duckduckgo +from open_webui.retrieval.web.exa import search_exa +from open_webui.retrieval.web.external import search_external +from open_webui.retrieval.web.firecrawl import search_firecrawl +from open_webui.retrieval.web.google_pse import search_google_pse +from open_webui.retrieval.web.jina_search import search_jina +from open_webui.retrieval.web.kagi import search_kagi + +# Web search engines +from open_webui.retrieval.web.main import SearchResult +from open_webui.retrieval.web.mojeek import search_mojeek +from open_webui.retrieval.web.ollama import search_ollama_cloud +from open_webui.retrieval.web.perplexity import search_perplexity +from open_webui.retrieval.web.perplexity_search import search_perplexity_search +from open_webui.retrieval.web.searchapi import search_searchapi +from open_webui.retrieval.web.searxng import search_searxng +from open_webui.retrieval.web.serpapi import search_serpapi +from open_webui.retrieval.web.serper import search_serper +from open_webui.retrieval.web.serply import search_serply +from open_webui.retrieval.web.serpstack import search_serpstack +from open_webui.retrieval.web.sougou import search_sougou +from open_webui.retrieval.web.tavily import search_tavily +from open_webui.retrieval.web.utils import get_web_loader +from open_webui.retrieval.web.yacy import search_yacy +from open_webui.retrieval.web.yandex import search_yandex +from open_webui.retrieval.web.ydc import search_youcom +from open_webui.storage.provider import Storage +from open_webui.utils.access_control import has_permission +from open_webui.utils.access_control.files import has_access_to_file +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.misc import ( + calculate_sha256_string, + sanitize_text_for_db, +) +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -162,7 +154,7 @@ def get_ef( def get_rf( engine: str = '', - reranking_model: Optional[str] = None, + reranking_model: str | None = None, external_reranker_url: str = '', external_reranker_api_key: str = '', external_reranker_timeout: str = '', @@ -249,7 +241,7 @@ router = APIRouter() class CollectionNameForm(BaseModel): - collection_name: Optional[str] = None + collection_name: str | None = None class ProcessUrlForm(CollectionNameForm): @@ -257,7 +249,7 @@ class ProcessUrlForm(CollectionNameForm): class SearchForm(BaseModel): - queries: List[str] + queries: list[str] @router.get('/embedding') @@ -302,14 +294,14 @@ class AzureOpenAIConfigForm(BaseModel): class EmbeddingModelUpdateForm(BaseModel): - openai_config: Optional[OpenAIConfigForm] = None - ollama_config: Optional[OllamaConfigForm] = None - azure_openai_config: Optional[AzureOpenAIConfigForm] = None + openai_config: OpenAIConfigForm | None = None + ollama_config: OllamaConfigForm | None = None + azure_openai_config: AzureOpenAIConfigForm | None = None RAG_EMBEDDING_ENGINE: str RAG_EMBEDDING_MODEL: str - RAG_EMBEDDING_BATCH_SIZE: Optional[int] = 1 - ENABLE_ASYNC_EMBEDDING: Optional[bool] = True - RAG_EMBEDDING_CONCURRENT_REQUESTS: Optional[int] = 0 + RAG_EMBEDDING_BATCH_SIZE: int | None = 1 + ENABLE_ASYNC_EMBEDDING: bool | None = True + RAG_EMBEDDING_CONCURRENT_REQUESTS: int | None = 0 def unload_embedding_model(request: Request): @@ -566,153 +558,153 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)): class WebConfig(BaseModel): - ENABLE_WEB_SEARCH: Optional[bool] = None - WEB_SEARCH_ENGINE: Optional[str] = None - WEB_SEARCH_TRUST_ENV: Optional[bool] = None - WEB_SEARCH_RESULT_COUNT: Optional[int] = None - WEB_SEARCH_CONCURRENT_REQUESTS: Optional[int] = None - WEB_SEARCH_DOMAIN_FILTER_LIST: Optional[List[str]] = [] - WEB_FETCH_MAX_CONTENT_LENGTH: Optional[int] = None - WEB_LOADER_CONCURRENT_REQUESTS: Optional[int] = None - BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None - BYPASS_WEB_SEARCH_WEB_LOADER: Optional[bool] = None - OLLAMA_CLOUD_WEB_SEARCH_API_KEY: Optional[str] = None - SEARXNG_QUERY_URL: Optional[str] = None - SEARXNG_LANGUAGE: Optional[str] = None - YACY_QUERY_URL: Optional[str] = None - YACY_USERNAME: Optional[str] = None - YACY_PASSWORD: Optional[str] = None - GOOGLE_PSE_API_KEY: Optional[str] = None - GOOGLE_PSE_ENGINE_ID: Optional[str] = None - BRAVE_SEARCH_API_KEY: Optional[str] = None - BRAVE_SEARCH_CONTEXT_TOKENS: Optional[int] = None - KAGI_SEARCH_API_KEY: Optional[str] = None - MOJEEK_SEARCH_API_KEY: Optional[str] = None - BOCHA_SEARCH_API_KEY: Optional[str] = None - SERPSTACK_API_KEY: Optional[str] = None - SERPSTACK_HTTPS: Optional[bool] = None - SERPER_API_KEY: Optional[str] = None - SERPLY_API_KEY: Optional[str] = None - DDGS_BACKEND: Optional[str] = None - TAVILY_API_KEY: Optional[str] = None - SEARCHAPI_API_KEY: Optional[str] = None - SEARCHAPI_ENGINE: Optional[str] = None - SERPAPI_API_KEY: Optional[str] = None - SERPAPI_ENGINE: Optional[str] = None - JINA_API_KEY: Optional[str] = None - JINA_API_BASE_URL: Optional[str] = None - BING_SEARCH_V7_ENDPOINT: Optional[str] = None - BING_SEARCH_V7_SUBSCRIPTION_KEY: Optional[str] = None - EXA_API_KEY: Optional[str] = None - PERPLEXITY_API_KEY: Optional[str] = None - PERPLEXITY_MODEL: Optional[str] = None - PERPLEXITY_SEARCH_CONTEXT_USAGE: Optional[str] = None - PERPLEXITY_SEARCH_API_URL: Optional[str] = None - SOUGOU_API_SID: Optional[str] = None - SOUGOU_API_SK: Optional[str] = None - WEB_LOADER_ENGINE: Optional[str] = None - WEB_LOADER_TIMEOUT: Optional[str] = None - ENABLE_WEB_LOADER_SSL_VERIFICATION: Optional[bool] = None - PLAYWRIGHT_WS_URL: Optional[str] = None - PLAYWRIGHT_TIMEOUT: Optional[int] = None - FIRECRAWL_API_KEY: Optional[str] = None - FIRECRAWL_API_BASE_URL: Optional[str] = None - FIRECRAWL_TIMEOUT: Optional[str] = None - TAVILY_EXTRACT_DEPTH: Optional[str] = None - EXTERNAL_WEB_SEARCH_URL: Optional[str] = None - EXTERNAL_WEB_SEARCH_API_KEY: Optional[str] = None - EXTERNAL_WEB_LOADER_URL: Optional[str] = None - EXTERNAL_WEB_LOADER_API_KEY: Optional[str] = None - YOUTUBE_LOADER_LANGUAGE: Optional[List[str]] = None - YOUTUBE_LOADER_PROXY_URL: Optional[str] = None - YOUTUBE_LOADER_TRANSLATION: Optional[str] = None - YANDEX_WEB_SEARCH_URL: Optional[str] = None - YANDEX_WEB_SEARCH_API_KEY: Optional[str] = None - YANDEX_WEB_SEARCH_CONFIG: Optional[str] = None - YOUCOM_API_KEY: Optional[str] = None + ENABLE_WEB_SEARCH: bool | None = None + WEB_SEARCH_ENGINE: str | None = None + WEB_SEARCH_TRUST_ENV: bool | None = None + WEB_SEARCH_RESULT_COUNT: int | None = None + WEB_SEARCH_CONCURRENT_REQUESTS: int | None = None + WEB_SEARCH_DOMAIN_FILTER_LIST: list[str | None] = [] + WEB_FETCH_MAX_CONTENT_LENGTH: int | None = None + WEB_LOADER_CONCURRENT_REQUESTS: int | None = None + BYPASS_WEB_SEARCH_EMBEDDING_AND_RETRIEVAL: bool | None = None + BYPASS_WEB_SEARCH_WEB_LOADER: bool | None = None + OLLAMA_CLOUD_WEB_SEARCH_API_KEY: str | None = None + SEARXNG_QUERY_URL: str | None = None + SEARXNG_LANGUAGE: str | None = None + YACY_QUERY_URL: str | None = None + YACY_USERNAME: str | None = None + YACY_PASSWORD: str | None = None + GOOGLE_PSE_API_KEY: str | None = None + GOOGLE_PSE_ENGINE_ID: str | None = None + BRAVE_SEARCH_API_KEY: str | None = None + BRAVE_SEARCH_CONTEXT_TOKENS: int | None = None + KAGI_SEARCH_API_KEY: str | None = None + MOJEEK_SEARCH_API_KEY: str | None = None + BOCHA_SEARCH_API_KEY: str | None = None + SERPSTACK_API_KEY: str | None = None + SERPSTACK_HTTPS: bool | None = None + SERPER_API_KEY: str | None = None + SERPLY_API_KEY: str | None = None + DDGS_BACKEND: str | None = None + TAVILY_API_KEY: str | None = None + SEARCHAPI_API_KEY: str | None = None + SEARCHAPI_ENGINE: str | None = None + SERPAPI_API_KEY: str | None = None + SERPAPI_ENGINE: str | None = None + JINA_API_KEY: str | None = None + JINA_API_BASE_URL: str | None = None + BING_SEARCH_V7_ENDPOINT: str | None = None + BING_SEARCH_V7_SUBSCRIPTION_KEY: str | None = None + EXA_API_KEY: str | None = None + PERPLEXITY_API_KEY: str | None = None + PERPLEXITY_MODEL: str | None = None + PERPLEXITY_SEARCH_CONTEXT_USAGE: str | None = None + PERPLEXITY_SEARCH_API_URL: str | None = None + SOUGOU_API_SID: str | None = None + SOUGOU_API_SK: str | None = None + WEB_LOADER_ENGINE: str | None = None + WEB_LOADER_TIMEOUT: str | None = None + ENABLE_WEB_LOADER_SSL_VERIFICATION: bool | None = None + PLAYWRIGHT_WS_URL: str | None = None + PLAYWRIGHT_TIMEOUT: int | None = None + FIRECRAWL_API_KEY: str | None = None + FIRECRAWL_API_BASE_URL: str | None = None + FIRECRAWL_TIMEOUT: str | None = None + TAVILY_EXTRACT_DEPTH: str | None = None + EXTERNAL_WEB_SEARCH_URL: str | None = None + EXTERNAL_WEB_SEARCH_API_KEY: str | None = None + EXTERNAL_WEB_LOADER_URL: str | None = None + EXTERNAL_WEB_LOADER_API_KEY: str | None = None + YOUTUBE_LOADER_LANGUAGE: list[str | None] = None + YOUTUBE_LOADER_PROXY_URL: str | None = None + YOUTUBE_LOADER_TRANSLATION: str | None = None + YANDEX_WEB_SEARCH_URL: str | None = None + YANDEX_WEB_SEARCH_API_KEY: str | None = None + YANDEX_WEB_SEARCH_CONFIG: str | None = None + YOUCOM_API_KEY: str | None = None class ConfigForm(BaseModel): # RAG settings - RAG_TEMPLATE: Optional[str] = None - TOP_K: Optional[int] = None - BYPASS_EMBEDDING_AND_RETRIEVAL: Optional[bool] = None - RAG_FULL_CONTEXT: Optional[bool] = None + RAG_TEMPLATE: str | None = None + TOP_K: int | None = None + BYPASS_EMBEDDING_AND_RETRIEVAL: bool | None = None + RAG_FULL_CONTEXT: bool | None = None # Hybrid search settings - ENABLE_RAG_HYBRID_SEARCH: Optional[bool] = None - ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS: Optional[bool] = None - TOP_K_RERANKER: Optional[int] = None - RELEVANCE_THRESHOLD: Optional[float] = None - HYBRID_BM25_WEIGHT: Optional[float] = None + ENABLE_RAG_HYBRID_SEARCH: bool | None = None + ENABLE_RAG_HYBRID_SEARCH_ENRICHED_TEXTS: bool | None = None + TOP_K_RERANKER: int | None = None + RELEVANCE_THRESHOLD: float | None = None + HYBRID_BM25_WEIGHT: float | None = None # Content extraction settings - CONTENT_EXTRACTION_ENGINE: Optional[str] = None - PDF_EXTRACT_IMAGES: Optional[bool] = None - PDF_LOADER_MODE: Optional[str] = None + CONTENT_EXTRACTION_ENGINE: str | None = None + PDF_EXTRACT_IMAGES: bool | None = None + PDF_LOADER_MODE: str | None = None - DATALAB_MARKER_API_KEY: Optional[str] = None - DATALAB_MARKER_API_BASE_URL: Optional[str] = None - DATALAB_MARKER_ADDITIONAL_CONFIG: Optional[str] = None - DATALAB_MARKER_SKIP_CACHE: Optional[bool] = None - DATALAB_MARKER_FORCE_OCR: Optional[bool] = None - DATALAB_MARKER_PAGINATE: Optional[bool] = None - DATALAB_MARKER_STRIP_EXISTING_OCR: Optional[bool] = None - DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: Optional[bool] = None - DATALAB_MARKER_FORMAT_LINES: Optional[bool] = None - DATALAB_MARKER_USE_LLM: Optional[bool] = None - DATALAB_MARKER_OUTPUT_FORMAT: Optional[str] = None + DATALAB_MARKER_API_KEY: str | None = None + DATALAB_MARKER_API_BASE_URL: str | None = None + DATALAB_MARKER_ADDITIONAL_CONFIG: str | None = None + DATALAB_MARKER_SKIP_CACHE: bool | None = None + DATALAB_MARKER_FORCE_OCR: bool | None = None + DATALAB_MARKER_PAGINATE: bool | None = None + DATALAB_MARKER_STRIP_EXISTING_OCR: bool | None = None + DATALAB_MARKER_DISABLE_IMAGE_EXTRACTION: bool | None = None + DATALAB_MARKER_FORMAT_LINES: bool | None = None + DATALAB_MARKER_USE_LLM: bool | None = None + DATALAB_MARKER_OUTPUT_FORMAT: str | None = None - EXTERNAL_DOCUMENT_LOADER_URL: Optional[str] = None - EXTERNAL_DOCUMENT_LOADER_API_KEY: Optional[str] = None + EXTERNAL_DOCUMENT_LOADER_URL: str | None = None + EXTERNAL_DOCUMENT_LOADER_API_KEY: str | None = None - TIKA_SERVER_URL: Optional[str] = None - DOCLING_SERVER_URL: Optional[str] = None - DOCLING_API_KEY: Optional[str] = None - DOCLING_PARAMS: Optional[dict] = None - DOCUMENT_INTELLIGENCE_ENDPOINT: Optional[str] = None - DOCUMENT_INTELLIGENCE_KEY: Optional[str] = None - DOCUMENT_INTELLIGENCE_MODEL: Optional[str] = None - MISTRAL_OCR_API_BASE_URL: Optional[str] = None - MISTRAL_OCR_API_KEY: Optional[str] = None - PADDLEOCR_VL_BASE_URL: Optional[str] = None - PADDLEOCR_VL_TOKEN: Optional[str] = None + TIKA_SERVER_URL: str | None = None + DOCLING_SERVER_URL: str | None = None + DOCLING_API_KEY: str | None = None + DOCLING_PARAMS: dict | None = None + DOCUMENT_INTELLIGENCE_ENDPOINT: str | None = None + DOCUMENT_INTELLIGENCE_KEY: str | None = None + DOCUMENT_INTELLIGENCE_MODEL: str | None = None + MISTRAL_OCR_API_BASE_URL: str | None = None + MISTRAL_OCR_API_KEY: str | None = None + PADDLEOCR_VL_BASE_URL: str | None = None + PADDLEOCR_VL_TOKEN: str | None = None # MinerU settings - MINERU_API_MODE: Optional[str] = None - MINERU_API_URL: Optional[str] = None - MINERU_API_KEY: Optional[str] = None - MINERU_API_TIMEOUT: Optional[str] = None - MINERU_PARAMS: Optional[dict] = None + MINERU_API_MODE: str | None = None + MINERU_API_URL: str | None = None + MINERU_API_KEY: str | None = None + MINERU_API_TIMEOUT: str | None = None + MINERU_PARAMS: dict | None = None # Reranking settings - RAG_RERANKING_MODEL: Optional[str] = None - RAG_RERANKING_ENGINE: Optional[str] = None - RAG_RERANKING_BATCH_SIZE: Optional[int] = None - RAG_EXTERNAL_RERANKER_URL: Optional[str] = None - RAG_EXTERNAL_RERANKER_API_KEY: Optional[str] = None - RAG_EXTERNAL_RERANKER_TIMEOUT: Optional[str] = None + RAG_RERANKING_MODEL: str | None = None + RAG_RERANKING_ENGINE: str | None = None + RAG_RERANKING_BATCH_SIZE: int | None = None + RAG_EXTERNAL_RERANKER_URL: str | None = None + RAG_EXTERNAL_RERANKER_API_KEY: str | None = None + RAG_EXTERNAL_RERANKER_TIMEOUT: str | None = None # Chunking settings - TEXT_SPLITTER: Optional[str] = None - ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER: Optional[bool] = None - CHUNK_SIZE: Optional[int] = None - CHUNK_MIN_SIZE_TARGET: Optional[int] = None - CHUNK_OVERLAP: Optional[int] = None + TEXT_SPLITTER: str | None = None + ENABLE_MARKDOWN_HEADER_TEXT_SPLITTER: bool | None = None + CHUNK_SIZE: int | None = None + CHUNK_MIN_SIZE_TARGET: int | None = None + CHUNK_OVERLAP: int | None = None # File upload settings - FILE_MAX_SIZE: Optional[Union[int, str]] = None - FILE_MAX_COUNT: Optional[Union[int, str]] = None - FILE_IMAGE_COMPRESSION_WIDTH: Optional[Union[int, str]] = None - FILE_IMAGE_COMPRESSION_HEIGHT: Optional[Union[int, str]] = None - ALLOWED_FILE_EXTENSIONS: Optional[List[str]] = None + FILE_MAX_SIZE: Union[int, str | None] = None + FILE_MAX_COUNT: Union[int, str | None] = None + FILE_IMAGE_COMPRESSION_WIDTH: Union[int, str | None] = None + FILE_IMAGE_COMPRESSION_HEIGHT: Union[int, str | None] = None + ALLOWED_FILE_EXTENSIONS: list[str | None] = None # Integration settings - ENABLE_GOOGLE_DRIVE_INTEGRATION: Optional[bool] = None - ENABLE_ONEDRIVE_INTEGRATION: Optional[bool] = None + ENABLE_GOOGLE_DRIVE_INTEGRATION: bool | None = None + ENABLE_ONEDRIVE_INTEGRATION: bool | None = None # Web search settings - web: Optional[WebConfig] = None + web: WebConfig | None = None @router.post('/config/update') @@ -1341,7 +1333,7 @@ def save_docs_to_vector_db( request: Request, docs, collection_name, - metadata: Optional[dict] = None, + metadata: dict | None = None, overwrite: bool = False, split: bool = True, add: bool = False, @@ -1539,8 +1531,8 @@ def save_docs_to_vector_db( class ProcessFileForm(BaseModel): file_id: str - content: Optional[str] = None - collection_name: Optional[str] = None + content: str | None = None + collection_name: str | None = None @router.post('/process/file') @@ -1767,7 +1759,7 @@ async def process_file( class ProcessTextForm(BaseModel): name: str content: str - collection_name: Optional[str] = None + collection_name: str | None = None @router.post('/process/text') @@ -2357,10 +2349,10 @@ async def _validate_collection_access(collection_names: list[str], user, access_ class QueryDocForm(BaseModel): collection_name: str query: str - k: Optional[int] = None - k_reranker: Optional[int] = None - r: Optional[float] = None - hybrid: Optional[bool] = None + k: int | None = None + k_reranker: int | None = None + r: float | None = None + hybrid: bool | None = None @router.post('/query/doc') @@ -2423,12 +2415,12 @@ async def query_doc_handler( class QueryCollectionsForm(BaseModel): collection_names: list[str] query: str - k: Optional[int] = None - k_reranker: Optional[int] = None - r: Optional[float] = None - hybrid: Optional[bool] = None - hybrid_bm25_weight: Optional[float] = None - enable_enriched_texts: Optional[bool] = None + k: int | None = None + k_reranker: int | None = None + r: float | None = None + hybrid: bool | None = None + hybrid_bm25_weight: float | None = None + enable_enriched_texts: bool | None = None @router.post('/query/collection') @@ -2581,24 +2573,24 @@ async def reset_upload_dir(user=Depends(get_admin_user)) -> bool: if ENV == 'dev': @router.get('/ef/{text}') - async def get_embeddings(request: Request, text: Optional[str] = 'Hello World!'): + async def get_embeddings(request: Request, text: str | None = 'Hello World!'): return {'result': await request.app.state.EMBEDDING_FUNCTION(text, prefix=RAG_EMBEDDING_QUERY_PREFIX)} class BatchProcessFilesForm(BaseModel): - files: List[FileModel] + files: list[FileModel] collection_name: str class BatchProcessFilesResult(BaseModel): file_id: str status: str - error: Optional[str] = None + error: str | None = None class BatchProcessFilesResponse(BaseModel): - results: List[BatchProcessFilesResult] - errors: List[BatchProcessFilesResult] + results: list[BatchProcessFilesResult] + errors: list[BatchProcessFilesResult] @router.post('/process/files/batch') @@ -2622,12 +2614,12 @@ async def process_files_batch( if collection_name: await _validate_collection_access([collection_name], user, access_type='write') - file_results: List[BatchProcessFilesResult] = [] - file_errors: List[BatchProcessFilesResult] = [] - file_updates: List[FileUpdateForm] = [] + file_results: list[BatchProcessFilesResult] = [] + file_errors: list[BatchProcessFilesResult] = [] + file_updates: list[FileUpdateForm] = [] # Prepare all documents first - all_docs: List[Document] = [] + all_docs: list[Document] = [] for file in form_data.files: try: @@ -2653,7 +2645,7 @@ async def process_files_batch( continue text_content = file.data.get('content', '') - docs: List[Document] = [ + docs: list[Document] = [ Document( page_content=text_content.replace('
', '\n'), metadata={ diff --git a/backend/open_webui/routers/scim.py b/backend/open_webui/routers/scim.py index 4c440141a9..9292523adc 100644 --- a/backend/open_webui/routers/scim.py +++ b/backend/open_webui/routers/scim.py @@ -7,31 +7,27 @@ NOTE: This is an experimental implementation and may not fully comply with SCIM import hmac import logging -import uuid import time -from typing import Optional, List, Dict, Any +import uuid from datetime import datetime, timezone +from typing import Any, Dict, List, Optional -from fastapi import APIRouter, Depends, HTTPException, Request, Query, Header, status +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Request, status from fastapi.responses import JSONResponse -from pydantic import BaseModel, Field, ConfigDict - -from open_webui.models.users import Users, UserModel -from open_webui.models.groups import Groups, GroupModel +from open_webui.config import OAUTH_PROVIDERS +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import SCIM_AUTH_PROVIDER +from open_webui.internal.db import get_async_session +from open_webui.models.groups import GroupModel, Groups +from open_webui.models.users import UserModel, Users from open_webui.utils.auth import ( + decode_token, get_admin_user, get_current_user, - decode_token, get_verified_user, ) -from open_webui.constants import ERROR_MESSAGES - -from open_webui.config import OAUTH_PROVIDERS -from open_webui.env import SCIM_AUTH_PROVIDER - - +from pydantic import BaseModel, ConfigDict, Field from sqlalchemy.ext.asyncio import AsyncSession -from open_webui.internal.db import get_async_session log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/skills.py b/backend/open_webui/routers/skills.py index ede5afd814..55aa351ff0 100644 --- a/backend/open_webui/routers/skills.py +++ b/backend/open_webui/routers/skills.py @@ -1,28 +1,25 @@ import logging from typing import Optional -from open_webui.models.groups import Groups -from pydantic import BaseModel - from fastapi import APIRouter, Depends, HTTPException, Request, status -from sqlalchemy.ext.asyncio import AsyncSession - +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL +from open_webui.constants import ERROR_MESSAGES from open_webui.internal.db import get_async_session +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups from open_webui.models.skills import ( + SkillAccessListResponse, + SkillAccessResponse, SkillForm, SkillModel, SkillResponse, - SkillUserResponse, - SkillAccessResponse, - SkillAccessListResponse, Skills, + SkillUserResponse, ) -from open_webui.models.access_grants import AccessGrants +from open_webui.utils.access_control import filter_allowed_access_grants, has_permission from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import has_permission, filter_allowed_access_grants - -from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL -from open_webui.constants import ERROR_MESSAGES +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py index 56c7e1d1b0..c64706e1a8 100644 --- a/backend/open_webui/routers/tasks.py +++ b/backend/open_webui/routers/tasks.py @@ -1,40 +1,36 @@ -from fastapi import APIRouter, Depends, HTTPException, Response, status, Request -from fastapi.responses import JSONResponse, RedirectResponse - -from pydantic import BaseModel -from typing import Optional import logging import re +from typing import Optional -from open_webui.utils.chat import generate_chat_completion -from open_webui.utils.task import ( - title_generation_template, - follow_up_generation_template, - query_generation_template, - image_prompt_generation_template, - autocomplete_generation_template, - tags_generation_template, - emoji_generation_template, - moa_response_generation_template, -) -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.constants import ERROR_MESSAGES, TASKS - -from open_webui.routers.pipelines import process_pipeline_inlet_filter - -from open_webui.utils.task import get_task_model_id - +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status +from fastapi.responses import JSONResponse, RedirectResponse from open_webui.config import ( - DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, - DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, - DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, - DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, - DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE, DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE, + DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE, + DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE, DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE, + DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE, + DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE, + DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE, DEFAULT_VOICE_MODE_PROMPT_TEMPLATE, ) +from open_webui.constants import ERROR_MESSAGES, TASKS +from open_webui.routers.pipelines import process_pipeline_inlet_filter +from open_webui.utils.auth import get_admin_user, get_verified_user +from open_webui.utils.chat import generate_chat_completion +from open_webui.utils.task import ( + autocomplete_generation_template, + emoji_generation_template, + follow_up_generation_template, + get_task_model_id, + image_prompt_generation_template, + moa_response_generation_template, + query_generation_template, + tags_generation_template, + title_generation_template, +) +from pydantic import BaseModel log = logging.getLogger(__name__) diff --git a/backend/open_webui/routers/terminals.py b/backend/open_webui/routers/terminals.py index c251b20d48..d2ef5f0ea3 100644 --- a/backend/open_webui/routers/terminals.py +++ b/backend/open_webui/routers/terminals.py @@ -12,14 +12,13 @@ from urllib.parse import unquote import aiohttp from fastapi import APIRouter, Depends, Request, Response, WebSocket from fastapi.responses import JSONResponse, StreamingResponse -from starlette.background import BackgroundTask - -from open_webui.utils.auth import get_verified_user -from open_webui.utils.access_control import has_connection_access -from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL from open_webui.config import TERMINAL_PROXY_HEADERS +from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL from open_webui.models.groups import Groups from open_webui.models.users import Users +from open_webui.utils.access_control import has_connection_access +from open_webui.utils.auth import get_verified_user +from starlette.background import BackgroundTask log = logging.getLogger(__name__) @@ -199,6 +198,7 @@ async def _resolve_authenticated_connection(ws: WebSocket, server_id: str): """ import asyncio import json + from open_webui.utils.auth import decode_token # First-message authentication diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index cd11bcde5e..43719ae514 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -1,44 +1,43 @@ +from __future__ import annotations + import logging +import re +import time from pathlib import Path from typing import Optional -import time -import re + import aiohttp -from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT -from open_webui.models.groups import Groups -from pydantic import BaseModel, HttpUrl from fastapi import APIRouter, Depends, HTTPException, Request, status -from sqlalchemy.ext.asyncio import AsyncSession +from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL, CACHE_DIR +from open_webui.constants import ERROR_MESSAGES +from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT from open_webui.internal.db import get_async_session - - +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups from open_webui.models.oauth_sessions import OAuthSessions from open_webui.models.tools import ( + ToolAccessResponse, ToolForm, ToolModel, ToolResponse, - ToolUserResponse, - ToolAccessResponse, Tools, + ToolUserResponse, ) -from open_webui.models.access_grants import AccessGrants +from open_webui.utils.access_control import ( + filter_allowed_access_grants, + has_access, + has_permission, +) +from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.plugin import ( + get_tool_module_from_cache, load_tool_module_by_id, replace_imports, - get_tool_module_from_cache, resolve_valves_schema_options, ) -from open_webui.utils.tools import get_tool_specs -from open_webui.utils.auth import get_admin_user, get_verified_user -from open_webui.utils.access_control import ( - has_permission, - has_access, - filter_allowed_access_grants, -) -from open_webui.utils.tools import get_tool_servers - -from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL -from open_webui.constants import ERROR_MESSAGES +from open_webui.utils.tools import get_tool_servers, get_tool_specs +from pydantic import BaseModel, HttpUrl +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -248,7 +247,7 @@ def github_url_to_raw_url(url: str) -> str: return url -@router.post('/load/url', response_model=Optional[dict]) +@router.post('/load/url', response_model=dict | None) async def load_tool_from_url(request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)): # NOTE: This is NOT a SSRF vulnerability: # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use, @@ -323,7 +322,7 @@ async def export_tools( ############################ -@router.post('/create', response_model=Optional[ToolResponse]) +@router.post('/create', response_model=ToolResponse | None) async def create_new_tools( request: Request, form_data: ToolForm, @@ -401,7 +400,7 @@ async def create_new_tools( ############################ -@router.get('/id/{id}', response_model=Optional[ToolAccessResponse]) +@router.get('/id/{id}', response_model=ToolAccessResponse | None) async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): tools = await Tools.get_tool_by_id(id, db=db) @@ -448,7 +447,7 @@ async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: AsyncSes ############################ -@router.post('/id/{id}/update', response_model=Optional[ToolModel]) +@router.post('/id/{id}/update', response_model=ToolModel | None) async def update_tools_by_id( request: Request, id: str, @@ -541,7 +540,7 @@ class ToolAccessGrantsForm(BaseModel): access_grants: list[dict] -@router.post('/id/{id}/access/update', response_model=Optional[ToolModel]) +@router.post('/id/{id}/access/update', response_model=ToolModel | None) async def update_tool_access_by_id( request: Request, id: str, @@ -634,7 +633,7 @@ async def delete_tools_by_id( ############################ -@router.get('/id/{id}/valves', response_model=Optional[dict]) +@router.get('/id/{id}/valves', response_model=dict | None) async def get_tools_valves_by_id( id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -676,7 +675,7 @@ async def get_tools_valves_by_id( ############################ -@router.get('/id/{id}/valves/spec', response_model=Optional[dict]) +@router.get('/id/{id}/valves/spec', response_model=dict | None) async def get_tools_valves_spec_by_id( request: Request, id: str, @@ -726,7 +725,7 @@ async def get_tools_valves_spec_by_id( ############################ -@router.post('/id/{id}/valves/update', response_model=Optional[dict]) +@router.post('/id/{id}/valves/update', response_model=dict | None) async def update_tools_valves_by_id( request: Request, id: str, @@ -789,7 +788,7 @@ async def update_tools_valves_by_id( ############################ -@router.get('/id/{id}/valves/user', response_model=Optional[dict]) +@router.get('/id/{id}/valves/user', response_model=dict | None) async def get_tools_user_valves_by_id( id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -826,7 +825,7 @@ async def get_tools_user_valves_by_id( ) -@router.get('/id/{id}/valves/user/spec', response_model=Optional[dict]) +@router.get('/id/{id}/valves/user/spec', response_model=dict | None) async def get_tools_user_valves_spec_by_id( request: Request, id: str, @@ -871,7 +870,7 @@ async def get_tools_user_valves_spec_by_id( return None -@router.post('/id/{id}/valves/user/update', response_model=Optional[dict]) +@router.post('/id/{id}/valves/user/update', response_model=dict | None) async def update_tools_user_valves_by_id( request: Request, id: str, diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 33d1cd425c..c4a54744c5 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -1,46 +1,40 @@ -import logging -from typing import Optional -from sqlalchemy.ext.asyncio import AsyncSession +from __future__ import annotations + import base64 import io - +import logging +from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Request, status -from fastapi.responses import Response, StreamingResponse, FileResponse -from pydantic import BaseModel, ConfigDict - - -from open_webui.models.auths import Auths -from open_webui.models.oauth_sessions import OAuthSessions - -from open_webui.models.groups import Groups - -from open_webui.models.users import ( - UserModel, - UserGroupIdsModel, - UserGroupIdsListResponse, - UserInfoResponse, - UserInfoListResponse, - UserRoleUpdateForm, - UserStatus, - Users, - UserSettings, - UserUpdateForm, -) - +from fastapi.responses import FileResponse, Response, StreamingResponse from open_webui.constants import ERROR_MESSAGES from open_webui.env import ENABLE_PROFILE_IMAGE_URL_FORWARDING, PROFILE_IMAGE_ALLOWED_MIME_TYPES, STATIC_DIR from open_webui.internal.db import get_async_session - - +from open_webui.models.auths import Auths +from open_webui.models.groups import Groups +from open_webui.models.oauth_sessions import OAuthSessions +from open_webui.models.users import ( + UserGroupIdsListResponse, + UserGroupIdsModel, + UserInfoListResponse, + UserInfoResponse, + UserModel, + UserRoleUpdateForm, + Users, + UserSettings, + UserStatus, + UserUpdateForm, +) +from open_webui.socket.main import disconnect_user_sessions +from open_webui.utils.access_control import get_permissions, has_permission from open_webui.utils.auth import ( get_admin_user, get_password_hash, get_verified_user, validate_password, ) -from open_webui.utils.access_control import get_permissions, has_permission -from open_webui.socket.main import disconnect_user_sessions +from pydantic import BaseModel, ConfigDict +from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) @@ -59,10 +53,10 @@ PAGE_ITEM_COUNT = 30 @router.get('/', response_model=UserGroupIdsListResponse) async def get_users( - query: Optional[str] = None, - order_by: Optional[str] = None, - direction: Optional[str] = None, - page: Optional[int] = 1, + query: str | None = None, + order_by: str | None = None, + direction: str | None = None, + page: int | None = 1, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session), ): @@ -114,10 +108,10 @@ async def get_all_users( @router.get('/search', response_model=UserInfoListResponse) async def search_users( - query: Optional[str] = None, - order_by: Optional[str] = None, - direction: Optional[str] = None, - page: Optional[int] = 1, + query: str | None = None, + order_by: str | None = None, + direction: str | None = None, + page: int | None = 1, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): @@ -275,7 +269,7 @@ async def update_default_user_permissions(request: Request, form_data: UserPermi ############################ -@router.get('/user/settings', response_model=Optional[UserSettings]) +@router.get('/user/settings', response_model=UserSettings | None) async def get_user_settings_by_session_user( user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -372,7 +366,7 @@ async def update_user_status_by_session_user( ############################ -@router.get('/user/info', response_model=Optional[dict]) +@router.get('/user/info', response_model=dict | None) async def get_user_info_by_session_user(user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): # user already fetched by get_verified_user — no need to refetch return user.info @@ -383,7 +377,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user), db: Asy ############################ -@router.post('/user/info/update', response_model=Optional[dict]) +@router.post('/user/info/update', response_model=dict | None) async def update_user_info_by_session_user( form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) ): @@ -408,8 +402,8 @@ async def update_user_info_by_session_user( class UserActiveResponse(UserStatus): name: str - profile_image_url: Optional[str] = None - groups: Optional[list] = [] + profile_image_url: str | None = None + groups: list | None = [] is_active: bool model_config = ConfigDict(extra='allow') @@ -536,7 +530,7 @@ async def get_user_active_status_by_id( ############################ -@router.post('/{user_id}/update', response_model=Optional[UserModel]) +@router.post('/{user_id}/update', response_model=UserModel | None) async def update_user_by_id( user_id: str, form_data: UserUpdateForm, diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index 20705c2c44..92caf5fee9 100644 --- a/backend/open_webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -1,19 +1,19 @@ -import black -import logging -import markdown +from __future__ import annotations -from open_webui.models.chats import ChatTitleMessagesForm +import logging + +import black +import markdown +from fastapi import APIRouter, Depends, HTTPException, Request, Response, status from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT from open_webui.constants import ERROR_MESSAGES -from fastapi import APIRouter, Depends, HTTPException, Request, Response, status -from pydantic import BaseModel -from starlette.responses import FileResponse - - -from open_webui.utils.misc import get_gravatar_url -from open_webui.utils.pdf_generator import PDFGenerator +from open_webui.models.chats import ChatTitleMessagesForm from open_webui.utils.auth import get_admin_user, get_verified_user from open_webui.utils.code_interpreter import execute_code_jupyter +from open_webui.utils.misc import get_gravatar_url +from open_webui.utils.pdf_generator import PDFGenerator +from pydantic import BaseModel +from starlette.responses import FileResponse log = logging.getLogger(__name__) diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py index bfb1003678..a3def26a7b 100644 --- a/backend/open_webui/socket/main.py +++ b/backend/open_webui/socket/main.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import asyncio import logging import random import sys import time -from typing import Dict, Set +from typing import Dict import pycrdt as Y import socketio diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py index 16b0cc3855..0cf2a09246 100644 --- a/backend/open_webui/socket/utils.py +++ b/backend/open_webui/socket/utils.py @@ -1,9 +1,10 @@ import json import uuid -from open_webui.utils.redis import get_redis_connection -from open_webui.env import REDIS_KEY_PREFIX -from typing import Optional, List, Tuple +from typing import List, Optional, Tuple + import pycrdt as Y +from open_webui.env import REDIS_KEY_PREFIX +from open_webui.utils.redis import get_redis_connection class RedisLock: diff --git a/backend/open_webui/storage/provider.py b/backend/open_webui/storage/provider.py index f70f3e862b..7edcd56e78 100644 --- a/backend/open_webui/storage/provider.py +++ b/backend/open_webui/storage/provider.py @@ -1,38 +1,38 @@ -import os -import shutil import json import logging +import os import re +import shutil from abc import ABC, abstractmethod -from typing import BinaryIO, Tuple, Dict +from typing import BinaryIO, Dict, Tuple import boto3 +from azure.core.exceptions import ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient from botocore.config import Config from botocore.exceptions import ClientError +from google.cloud import storage +from google.cloud.exceptions import GoogleCloudError, NotFound from open_webui.config import ( + AZURE_STORAGE_CONTAINER_NAME, + AZURE_STORAGE_ENDPOINT, + AZURE_STORAGE_KEY, + GCS_BUCKET_NAME, + GOOGLE_APPLICATION_CREDENTIALS_JSON, S3_ACCESS_KEY_ID, + S3_ADDRESSING_STYLE, S3_BUCKET_NAME, + S3_ENABLE_TAGGING, S3_ENDPOINT_URL, S3_KEY_PREFIX, S3_REGION_NAME, S3_SECRET_ACCESS_KEY, S3_USE_ACCELERATE_ENDPOINT, - S3_ADDRESSING_STYLE, - S3_ENABLE_TAGGING, - GCS_BUCKET_NAME, - GOOGLE_APPLICATION_CREDENTIALS_JSON, - AZURE_STORAGE_ENDPOINT, - AZURE_STORAGE_CONTAINER_NAME, - AZURE_STORAGE_KEY, STORAGE_PROVIDER, UPLOAD_DIR, ) -from google.cloud import storage -from google.cloud.exceptions import GoogleCloudError, NotFound from open_webui.constants import ERROR_MESSAGES -from azure.identity import DefaultAzureCredential -from azure.storage.blob import BlobServiceClient -from azure.core.exceptions import ResourceNotFoundError log = logging.getLogger(__name__) diff --git a/backend/open_webui/tasks.py b/backend/open_webui/tasks.py index 30754cfc48..6475e5a239 100644 --- a/backend/open_webui/tasks.py +++ b/backend/open_webui/tasks.py @@ -1,12 +1,12 @@ # tasks.py import asyncio -from typing import Dict -from uuid import uuid4 import json import logging -from redis.asyncio import Redis -from fastapi import Request from typing import Dict, List, Optional +from uuid import uuid4 + +from fastapi import Request +from redis.asyncio import Redis from open_webui.env import REDIS_KEY_PREFIX diff --git a/backend/open_webui/tools/builtin.py b/backend/open_webui/tools/builtin.py index ef408ab8af..b372751fb1 100644 --- a/backend/open_webui/tools/builtin.py +++ b/backend/open_webui/tools/builtin.py @@ -6,38 +6,40 @@ These tools are automatically available when native function calling is enabled. IMPORTANT: DO NOT IMPORT THIS MODULE DIRECTLY IN OTHER PARTS OF THE CODEBASE. """ +import asyncio import json import logging import time -import asyncio from typing import Optional from fastapi import Request -from open_webui.models.users import UserModel -from open_webui.routers.retrieval import search_web as _search_web -from open_webui.retrieval.utils import get_content_from_url -from open_webui.routers.images import ( - image_generations, - image_edits, - CreateImageForm, - EditImageForm, -) -from open_webui.routers.memories import ( - query_memory, - add_memory as _add_memory, - update_memory_by_id, - QueryMemoryForm, - AddMemoryForm, - MemoryUpdateModel, -) -from open_webui.models.notes import Notes +from open_webui.models.channels import Channel, ChannelMember, Channels from open_webui.models.chats import Chats -from open_webui.models.channels import Channels, ChannelMember, Channel -from open_webui.models.messages import Messages, Message from open_webui.models.groups import Groups from open_webui.models.memories import Memories +from open_webui.models.messages import Message, Messages +from open_webui.models.notes import Notes +from open_webui.models.users import UserModel +from open_webui.retrieval.utils import get_content_from_url from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT +from open_webui.routers.images import ( + CreateImageForm, + EditImageForm, + image_edits, + image_generations, +) +from open_webui.routers.memories import ( + AddMemoryForm, + MemoryUpdateModel, + QueryMemoryForm, + query_memory, + update_memory_by_id, +) +from open_webui.routers.memories import ( + add_memory as _add_memory, +) +from open_webui.routers.retrieval import search_web as _search_web from open_webui.utils.sanitize import sanitize_code log = logging.getLogger(__name__) @@ -106,6 +108,7 @@ async def calculate_timestamp( """ try: import datetime + from dateutil.relativedelta import relativedelta now = datetime.datetime.now(datetime.timezone.utc) @@ -1602,9 +1605,9 @@ async def search_knowledge_files( return json.dumps({'error': 'User context not available'}) try: - from open_webui.models.knowledge import Knowledges - from open_webui.models.files import Files from open_webui.models.access_grants import AccessGrants + from open_webui.models.files import Files + from open_webui.models.knowledge import Knowledges user_id = __user__.get('id') user_role = __user__.get('role', 'user') @@ -1861,9 +1864,9 @@ async def view_knowledge_file( offset = max(offset, 0) try: + from open_webui.models.access_grants import AccessGrants from open_webui.models.files import Files from open_webui.models.knowledge import Knowledges - from open_webui.models.access_grants import AccessGrants user_id = __user__.get('id') user_role = __user__.get('role', 'user') @@ -1952,10 +1955,10 @@ async def list_knowledge( return json.dumps({'knowledge_bases': [], 'files': [], 'notes': []}) try: - from open_webui.models.knowledge import Knowledges - from open_webui.models.files import Files - from open_webui.models.notes import Notes from open_webui.models.access_grants import AccessGrants + from open_webui.models.files import Files + from open_webui.models.knowledge import Knowledges + from open_webui.models.notes import Notes user_id = __user__.get('id') user_role = __user__.get('role', 'user') @@ -2084,11 +2087,11 @@ async def query_knowledge_files( knowledge_ids = [knowledge_ids] try: - from open_webui.models.knowledge import Knowledges + from open_webui.models.access_grants import AccessGrants from open_webui.models.files import Files + from open_webui.models.knowledge import Knowledges from open_webui.models.notes import Notes from open_webui.retrieval.utils import query_collection - from open_webui.models.access_grants import AccessGrants user_id = __user__.get('id') user_role = __user__.get('role', 'user') @@ -2244,9 +2247,10 @@ async def query_knowledge_bases( try: import heapq + from open_webui.models.knowledge import Knowledges - from open_webui.routers.knowledge import KNOWLEDGE_BASES_COLLECTION from open_webui.retrieval.vector.async_client import ASYNC_VECTOR_DB_CLIENT + from open_webui.routers.knowledge import KNOWLEDGE_BASES_COLLECTION user_id = __user__.get('id') user_group_ids = [group.id for group in await Groups.get_groups_by_member_id(user_id)] @@ -2345,8 +2349,8 @@ async def view_skill( return json.dumps({'error': 'User context not available'}) try: - from open_webui.models.skills import Skills from open_webui.models.access_grants import AccessGrants + from open_webui.models.skills import Skills user_id = __user__.get('id') @@ -2385,9 +2389,10 @@ async def view_skill( # TASK MANAGEMENT TOOLS # ============================================================================= -from pydantic import BaseModel, Field from typing import Literal +from pydantic import BaseModel, Field + VALID_TASK_STATUSES = {'pending', 'in_progress', 'completed', 'cancelled'} @@ -2569,9 +2574,9 @@ async def create_automation( return json.dumps({'error': 'User context not available'}) try: - from open_webui.models.automations import Automations, AutomationForm, AutomationData + from open_webui.models.automations import AutomationData, AutomationForm, Automations from open_webui.models.users import Users - from open_webui.utils.automations import validate_rrule, next_run_ns, next_n_runs_ns + from open_webui.utils.automations import next_n_runs_ns, next_run_ns, validate_rrule user_id = __user__.get('id') user = await Users.get_user_by_id(user_id) @@ -2647,9 +2652,9 @@ async def update_automation( return json.dumps({'error': 'User context not available'}) try: - from open_webui.models.automations import Automations, AutomationForm, AutomationData + from open_webui.models.automations import AutomationData, AutomationForm, Automations from open_webui.models.users import Users - from open_webui.utils.automations import validate_rrule, next_run_ns, next_n_runs_ns + from open_webui.utils.automations import next_n_runs_ns, next_run_ns, validate_rrule user_id = __user__.get('id') user = await Users.get_user_by_id(user_id) @@ -2833,7 +2838,7 @@ async def delete_automation( return json.dumps({'error': 'User context not available'}) try: - from open_webui.models.automations import Automations, AutomationRuns + from open_webui.models.automations import AutomationRuns, Automations user_id = __user__.get('id') @@ -3048,7 +3053,7 @@ async def create_calendar_event( return json.dumps({'error': 'User context not available'}) try: - from open_webui.models.calendar import Calendars, CalendarEvents, CalendarEventForm + from open_webui.models.calendar import CalendarEventForm, CalendarEvents, Calendars user_id = __user__.get('id') @@ -3175,8 +3180,8 @@ async def update_calendar_event( return json.dumps({'error': 'User context not available'}) try: - from open_webui.models.calendar import Calendars, CalendarEvents, CalendarEventUpdateForm from open_webui.models.access_grants import AccessGrants + from open_webui.models.calendar import CalendarEvents, CalendarEventUpdateForm, Calendars from open_webui.models.groups import Groups user_id = __user__.get('id') @@ -3278,8 +3283,8 @@ async def delete_calendar_event( return json.dumps({'error': 'User context not available'}) try: - from open_webui.models.calendar import Calendars, CalendarEvents from open_webui.models.access_grants import AccessGrants + from open_webui.models.calendar import CalendarEvents, Calendars from open_webui.models.groups import Groups user_id = __user__.get('id') diff --git a/backend/open_webui/utils/access_control/__init__.py b/backend/open_webui/utils/access_control/__init__.py index 19d6817dbc..1cc2f086f5 100644 --- a/backend/open_webui/utils/access_control/__init__.py +++ b/backend/open_webui/utils/access_control/__init__.py @@ -1,16 +1,15 @@ import json from typing import Any -from open_webui.models.users import UserModel -from open_webui.models.groups import Groups +from open_webui.config import DEFAULT_USER_PERMISSIONS from open_webui.models.access_grants import ( has_public_read_access_grant, has_public_write_access_grant, has_user_access_grant, strip_user_access_grants, ) -from open_webui.config import DEFAULT_USER_PERMISSIONS - +from open_webui.models.groups import Groups +from open_webui.models.users import UserModel from sqlalchemy.ext.asyncio import AsyncSession @@ -272,8 +271,8 @@ async def has_base_model_access( provider model that has no per-model ACL). Returns ``False`` the moment a registered base model denies access. """ - from open_webui.models.models import Models from open_webui.models.access_grants import AccessGrants + from open_webui.models.models import Models base_model_id = getattr(model_info, 'base_model_id', None) seen = {model_info.id} diff --git a/backend/open_webui/utils/access_control/files.py b/backend/open_webui/utils/access_control/files.py index fb318e3c66..f134c5064d 100644 --- a/backend/open_webui/utils/access_control/files.py +++ b/backend/open_webui/utils/access_control/files.py @@ -1,14 +1,13 @@ import logging -from open_webui.models.users import UserModel -from open_webui.models.files import Files -from open_webui.models.knowledge import Knowledges +from open_webui.models.access_grants import AccessGrants from open_webui.models.channels import Channels from open_webui.models.chats import Chats +from open_webui.models.files import Files from open_webui.models.groups import Groups +from open_webui.models.knowledge import Knowledges from open_webui.models.models import Models -from open_webui.models.access_grants import AccessGrants - +from open_webui.models.users import UserModel from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/actions.py b/backend/open_webui/utils/actions.py index 7b1789580b..eb5a84d0d4 100644 --- a/backend/open_webui/utils/actions.py +++ b/backend/open_webui/utils/actions.py @@ -1,20 +1,16 @@ +import inspect import logging import sys -import inspect - from typing import Any from fastapi import Request - -from open_webui.models.users import UserModel -from open_webui.models.functions import Functions - -from open_webui.socket.main import get_event_call, get_event_emitter -from open_webui.utils.plugin import get_function_module_from_cache -from open_webui.utils.models import get_all_models -from open_webui.utils.middleware import process_tool_result - from open_webui.env import GLOBAL_LOG_LEVEL +from open_webui.models.functions import Functions +from open_webui.models.users import UserModel +from open_webui.socket.main import get_event_call, get_event_emitter +from open_webui.utils.middleware import process_tool_result +from open_webui.utils.models import get_all_models +from open_webui.utils.plugin import get_function_module_from_cache logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/anthropic.py b/backend/open_webui/utils/anthropic.py index a01184143f..5feed2b8ef 100644 --- a/backend/open_webui/utils/anthropic.py +++ b/backend/open_webui/utils/anthropic.py @@ -2,7 +2,6 @@ import json import logging import aiohttp - from open_webui.env import ( AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT_MODEL_LIST, diff --git a/backend/open_webui/utils/asgi_middleware.py b/backend/open_webui/utils/asgi_middleware.py index 3b478d8b4a..3594b62abe 100644 --- a/backend/open_webui/utils/asgi_middleware.py +++ b/backend/open_webui/utils/asgi_middleware.py @@ -37,13 +37,12 @@ from urllib.parse import parse_qs, urlencode from fastapi.responses import JSONResponse, RedirectResponse from fastapi.security import HTTPAuthorizationCredentials -from starlette.datastructures import MutableHeaders -from starlette.requests import Request -from starlette.types import ASGIApp, Message, Receive, Scope, Send - from open_webui.env import CUSTOM_API_KEY_HEADER from open_webui.internal.db import ScopedSession from open_webui.utils.auth import get_http_authorization_cred +from starlette.datastructures import MutableHeaders +from starlette.requests import Request +from starlette.types import ASGIApp, Message, Receive, Scope, Send log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/audit.py b/backend/open_webui/utils/audit.py index 5686c88d5d..313f4cf235 100644 --- a/backend/open_webui/utils/audit.py +++ b/backend/open_webui/utils/audit.py @@ -1,7 +1,8 @@ +import re +import uuid from contextlib import asynccontextmanager from dataclasses import asdict, dataclass from enum import Enum -import re from typing import ( TYPE_CHECKING, Any, @@ -11,7 +12,6 @@ from typing import ( Optional, cast, ) -import uuid from asgiref.typing import ( ASGI3Application, @@ -19,14 +19,15 @@ from asgiref.typing import ( ASGIReceiveEvent, ASGISendCallable, ASGISendEvent, +) +from asgiref.typing import ( Scope as ASGIScope, ) from loguru import logger -from starlette.requests import Request - -from open_webui.env import AUDIT_LOG_LEVEL, ENABLE_AUDIT_GET_REQUESTS, AUDIT_INCLUDED_PATHS, MAX_BODY_LOG_SIZE -from open_webui.utils.auth import get_current_user, get_http_authorization_cred +from open_webui.env import AUDIT_INCLUDED_PATHS, AUDIT_LOG_LEVEL, ENABLE_AUDIT_GET_REQUESTS, MAX_BODY_LOG_SIZE from open_webui.models.users import UserModel +from open_webui.utils.auth import get_current_user, get_http_authorization_cred +from starlette.requests import Request if TYPE_CHECKING: from loguru import Logger diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index e0f331a9df..136d216362 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -1,49 +1,43 @@ -import logging -import uuid -import jwt +from __future__ import annotations + import base64 -import hmac import hashlib -import requests -import os -import bcrypt - -from cryptography.hazmat.primitives.ciphers.aead import AESGCM -from cryptography.hazmat.primitives.asymmetric import ed25519 -from cryptography.hazmat.primitives import serialization +import hmac import json - - +import logging +import os +import uuid from datetime import datetime, timedelta +from typing import Optional, Union + +import bcrypt +import jwt import pytz -from pytz import UTC -from typing import Optional, Union, List, Dict - - -from open_webui.utils.access_control import has_permission -from open_webui.models.users import Users -from open_webui.models.auths import Auths - - +import requests +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import ed25519 +from cryptography.hazmat.primitives.ciphers.aead import AESGCM +from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer from open_webui.constants import ERROR_MESSAGES - from open_webui.env import ( ENABLE_OTEL, ENABLE_PASSWORD_VALIDATION, - OFFLINE_MODE, LICENSE_BLOB, + OFFLINE_MODE, PASSWORD_VALIDATION_HINT, PASSWORD_VALIDATION_REGEX_PATTERN, REDIS_KEY_PREFIX, - pk, - WEBUI_SECRET_KEY, - TRUSTED_SIGNATURE_KEY, STATIC_DIR, + TRUSTED_SIGNATURE_KEY, WEBUI_AUTH_TRUSTED_EMAIL_HEADER, + WEBUI_SECRET_KEY, + pk, ) - -from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from open_webui.models.auths import Auths +from open_webui.models.users import Users +from open_webui.utils.access_control import has_permission +from pytz import UTC log = logging.getLogger(__name__) @@ -211,7 +205,7 @@ def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> st return encoded_jwt -def decode_token(token: str) -> Optional[dict]: +def decode_token(token: str) -> dict | None: try: decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM]) return decoded @@ -284,7 +278,7 @@ def create_api_key(): return f'sk-{key}' -def get_http_authorization_cred(auth_header: Optional[str]): +def get_http_authorization_cred(auth_header: str | None): if not auth_header: return None try: diff --git a/backend/open_webui/utils/automations.py b/backend/open_webui/utils/automations.py index 05955d54d9..1097a308ed 100644 --- a/backend/open_webui/utils/automations.py +++ b/backend/open_webui/utils/automations.py @@ -25,14 +25,13 @@ from zoneinfo import ZoneInfo from dateutil.rrule import rrulestr from fastapi import Request -from starlette.datastructures import Headers - from open_webui.constants import ERROR_MESSAGES -from open_webui.models.automations import Automations, AutomationRuns, AutomationModel +from open_webui.internal.db import get_async_db +from open_webui.models.automations import AutomationModel, AutomationRuns, Automations from open_webui.models.chats import ChatForm, Chats from open_webui.models.users import Users from open_webui.utils.task import prompt_template -from open_webui.internal.db import get_async_db +from starlette.datastructures import Headers log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/chat.py b/backend/open_webui/utils/chat.py index ec35d3ea04..0884528597 100644 --- a/backend/open_webui/utils/chat.py +++ b/backend/open_webui/utils/chat.py @@ -1,56 +1,45 @@ -import time +import asyncio +import json import logging +import random import sys +import time +import uuid +from typing import Any, Optional from aiocache import cached -from typing import Any, Optional -import random -import json - -import uuid -import asyncio - from fastapi import HTTPException, Request, status -from starlette.responses import Response, StreamingResponse, JSONResponse - - -from open_webui.models.users import UserModel - -from open_webui.socket.main import ( - sio, - get_event_call, - get_event_emitter, -) +from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL from open_webui.functions import generate_function_chat_completion - -from open_webui.routers.openai import ( - generate_chat_completion as generate_openai_chat_completion, -) - +from open_webui.models.functions import Functions +from open_webui.models.models import Models +from open_webui.models.users import UserModel from open_webui.routers.ollama import ( generate_chat_completion as generate_ollama_chat_completion, ) - +from open_webui.routers.openai import ( + generate_chat_completion as generate_openai_chat_completion, +) from open_webui.routers.pipelines import ( process_pipeline_inlet_filter, process_pipeline_outlet_filter, ) - -from open_webui.models.functions import Functions -from open_webui.models.models import Models - -from open_webui.utils.models import get_all_models, check_model_access -from open_webui.utils.payload import convert_payload_openai_to_ollama -from open_webui.utils.response import ( - convert_response_ollama_to_openai, - convert_streaming_response_ollama_to_openai, +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, + sio, ) from open_webui.utils.filter import ( get_sorted_filter_ids, process_filter_functions, ) - -from open_webui.env import GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL +from open_webui.utils.models import check_model_access, get_all_models +from open_webui.utils.payload import convert_payload_openai_to_ollama +from open_webui.utils.response import ( + convert_response_ollama_to_openai, + convert_streaming_response_ollama_to_openai, +) +from starlette.responses import JSONResponse, Response, StreamingResponse logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/code_interpreter.py b/backend/open_webui/utils/code_interpreter.py index 52ddea24a7..cd6ce91ada 100644 --- a/backend/open_webui/utils/code_interpreter.py +++ b/backend/open_webui/utils/code_interpreter.py @@ -6,9 +6,8 @@ from typing import Optional import aiohttp import websockets -from pydantic import BaseModel - from open_webui.env import AIOHTTP_CLIENT_ALLOW_REDIRECTS +from pydantic import BaseModel logger = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/embeddings.py b/backend/open_webui/utils/embeddings.py index 1717886326..a2049d4de6 100644 --- a/backend/open_webui/utils/embeddings.py +++ b/backend/open_webui/utils/embeddings.py @@ -1,19 +1,19 @@ -import random import logging +import random import sys from fastapi import Request -from open_webui.models.users import UserModel +from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL from open_webui.models.models import Models -from open_webui.utils.models import check_model_access -from open_webui.env import GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL - -from open_webui.routers.openai import embeddings as openai_embeddings +from open_webui.models.users import UserModel from open_webui.routers.ollama import ( - embed as ollama_embed, GenerateEmbedForm, ) - +from open_webui.routers.ollama import ( + embed as ollama_embed, +) +from open_webui.routers.openai import embeddings as openai_embeddings +from open_webui.utils.models import check_model_access from open_webui.utils.payload import convert_embed_payload_openai_to_ollama from open_webui.utils.response import convert_embedding_response_ollama_to_openai diff --git a/backend/open_webui/utils/files.py b/backend/open_webui/utils/files.py index 6b821d58b6..76492eeae1 100644 --- a/backend/open_webui/utils/files.py +++ b/backend/open_webui/utils/files.py @@ -1,7 +1,10 @@ -from open_webui.routers.images import ( - get_image_data, - upload_image, -) +import asyncio +import base64 +import io +import mimetypes +import re +from pathlib import Path +from typing import Optional from fastapi import ( APIRouter, @@ -10,27 +13,20 @@ from fastapi import ( Request, UploadFile, ) -from typing import Optional -from pathlib import Path - -from open_webui.storage.provider import Storage - -from open_webui.models.chats import Chats -from open_webui.models.files import Files -from open_webui.routers.files import upload_file_handler -from open_webui.retrieval.web.utils import validate_url - -import asyncio -import mimetypes -import base64 -import io -import re - from open_webui.env import ( AIOHTTP_CLIENT_ALLOW_REDIRECTS, AIOHTTP_CLIENT_SESSION_SSL, ENABLE_IMAGE_CONTENT_TYPE_EXTENSION_FALLBACK, ) +from open_webui.models.chats import Chats +from open_webui.models.files import Files +from open_webui.retrieval.web.utils import validate_url +from open_webui.routers.files import upload_file_handler +from open_webui.routers.images import ( + get_image_data, + upload_image, +) +from open_webui.storage.provider import Storage from open_webui.utils.session_pool import get_session BASE64_IMAGE_URL_PREFIX = re.compile(r'data:image/\w+;base64,', re.IGNORECASE) diff --git a/backend/open_webui/utils/filter.py b/backend/open_webui/utils/filter.py index 07edf9afa7..84aebdaacb 100644 --- a/backend/open_webui/utils/filter.py +++ b/backend/open_webui/utils/filter.py @@ -1,11 +1,11 @@ import inspect import logging -from open_webui.utils.plugin import ( - load_function_module_by_id, - get_function_module_from_cache, -) from open_webui.models.functions import Functions +from open_webui.utils.plugin import ( + get_function_module_from_cache, + load_function_module_by_id, +) log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/groups.py b/backend/open_webui/utils/groups.py index 50099b2ee7..5afe5fc491 100644 --- a/backend/open_webui/utils/groups.py +++ b/backend/open_webui/utils/groups.py @@ -1,4 +1,5 @@ import logging + from open_webui.models.groups import Groups log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/headers.py b/backend/open_webui/utils/headers.py index f5ad41bcd2..510f60b335 100644 --- a/backend/open_webui/utils/headers.py +++ b/backend/open_webui/utils/headers.py @@ -1,9 +1,9 @@ from urllib.parse import quote from open_webui.env import ( - FORWARD_USER_INFO_HEADER_USER_NAME, - FORWARD_USER_INFO_HEADER_USER_ID, FORWARD_USER_INFO_HEADER_USER_EMAIL, + FORWARD_USER_INFO_HEADER_USER_ID, + FORWARD_USER_INFO_HEADER_USER_NAME, FORWARD_USER_INFO_HEADER_USER_ROLE, ) diff --git a/backend/open_webui/utils/images/comfyui.py b/backend/open_webui/utils/images/comfyui.py index 9172f1c325..a07078cbb4 100644 --- a/backend/open_webui/utils/images/comfyui.py +++ b/backend/open_webui/utils/images/comfyui.py @@ -5,10 +5,9 @@ import urllib.parse from typing import Optional import aiohttp -from pydantic import BaseModel - from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL from open_webui.utils.session_pool import get_session +from pydantic import BaseModel log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/logger.py b/backend/open_webui/utils/logger.py index fa4e77f53d..845f98521c 100644 --- a/backend/open_webui/utils/logger.py +++ b/backend/open_webui/utils/logger.py @@ -4,19 +4,18 @@ import sys from typing import TYPE_CHECKING from loguru import logger - from open_webui.env import ( - ENABLE_AUDIT_STDOUT, - ENABLE_AUDIT_LOGS_FILE, - AUDIT_LOGS_FILE_PATH, + _LEVEL_MAP, AUDIT_LOG_FILE_ROTATION_SIZE, AUDIT_LOG_LEVEL, - GLOBAL_LOG_LEVEL, - LOG_FORMAT, + AUDIT_LOGS_FILE_PATH, AUDIT_UVICORN_LOGGER_NAMES, + ENABLE_AUDIT_LOGS_FILE, + ENABLE_AUDIT_STDOUT, ENABLE_OTEL, ENABLE_OTEL_LOGS, - _LEVEL_MAP, + GLOBAL_LOG_LEVEL, + LOG_FORMAT, ) if TYPE_CHECKING: diff --git a/backend/open_webui/utils/mcp/client.py b/backend/open_webui/utils/mcp/client.py index 7a5aa61b80..a4691750db 100644 --- a/backend/open_webui/utils/mcp/client.py +++ b/backend/open_webui/utils/mcp/client.py @@ -1,17 +1,16 @@ import asyncio import logging -from typing import Optional from contextlib import AsyncExitStack +from typing import Optional log = logging.getLogger(__name__) import anyio - +import httpx from mcp import ClientSession from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.client.streamable_http import streamablehttp_client from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken -import httpx from open_webui.env import AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER diff --git a/backend/open_webui/utils/middleware.py b/backend/open_webui/utils/middleware.py index 56226fc226..1921bab1bd 100644 --- a/backend/open_webui/utils/middleware.py +++ b/backend/open_webui/utils/middleware.py @@ -1,147 +1,133 @@ -import copy -import time -import logging -import sys -import os -import base64 -import textwrap - +import ast import asyncio -from aiocache import cached -from typing import Any, Optional -import random -import json +import base64 +import copy import html import inspect +import json +import logging +import os +import random import re -import ast - -from uuid import uuid4 +import sys +import textwrap +import time from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional +from uuid import uuid4 - -from fastapi import Request, HTTPException +from aiocache import cached +from fastapi import HTTPException, Request from fastapi.responses import HTMLResponse -from starlette.responses import Response, StreamingResponse, JSONResponse - - -from open_webui.utils.misc import is_string_allowed -from open_webui.models.oauth_sessions import OAuthSessions +from open_webui.config import ( + CACHE_DIR, + CODE_INTERPRETER_BLOCKED_MODULES, + CODE_INTERPRETER_PYODIDE_PROMPT, + DEFAULT_CODE_INTERPRETER_PROMPT, + DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, + DEFAULT_VOICE_MODE_PROMPT_TEMPLATE, +) +from open_webui.constants import TASKS +from open_webui.env import ( + BYPASS_MODEL_ACCESS_CONTROL, + CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES, + CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE, + ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION, + ENABLE_FORWARD_USER_INFO_HEADERS, + ENABLE_QUERIES_CACHE, + ENABLE_REALTIME_CHAT_SAVE, + ENABLE_RESPONSES_API_STATEFUL, + FORWARD_SESSION_INFO_HEADER_CHAT_ID, + FORWARD_SESSION_INFO_HEADER_MESSAGE_ID, + GLOBAL_LOG_LEVEL, + RAG_SYSTEM_CONTEXT, +) from open_webui.models.chats import Chats from open_webui.models.folders import Folders -from open_webui.models.users import Users -from open_webui.socket.main import ( - get_event_call, - get_event_emitter, -) -from open_webui.routers.tasks import ( - generate_queries, - generate_title, - generate_follow_ups, - generate_image_prompt, - generate_chat_tags, -) -from open_webui.routers.retrieval import ( - process_web_search, - SearchForm, -) -from open_webui.utils.tools import get_builtin_tools +from open_webui.models.functions import Functions +from open_webui.models.models import Models +from open_webui.models.oauth_sessions import OAuthSessions +from open_webui.models.users import UserModel, Users +from open_webui.retrieval.utils import get_sources_from_items from open_webui.routers.images import ( - image_generations, CreateImageForm, - image_edits, EditImageForm, + image_edits, + image_generations, ) +from open_webui.routers.memories import QueryMemoryForm, query_memory from open_webui.routers.pipelines import ( process_pipeline_inlet_filter, process_pipeline_outlet_filter, ) -from open_webui.routers.memories import query_memory, QueryMemoryForm - -from open_webui.utils.webhook import post_webhook +from open_webui.routers.retrieval import ( + SearchForm, + process_web_search, +) +from open_webui.routers.tasks import ( + generate_chat_tags, + generate_follow_ups, + generate_image_prompt, + generate_queries, + generate_title, +) +from open_webui.socket.main import ( + get_event_call, + get_event_emitter, +) +from open_webui.utils.access_control import has_connection_access +from open_webui.utils.access_control.files import get_accessible_folder_files +from open_webui.utils.chat import generate_chat_completion +from open_webui.utils.code_interpreter import execute_code_jupyter from open_webui.utils.files import ( convert_markdown_base64_images, get_file_url_from_base64, get_image_base64_from_url, get_image_url_from_base64, ) - - -from open_webui.models.users import UserModel -from open_webui.models.functions import Functions -from open_webui.models.models import Models - -from open_webui.retrieval.utils import get_sources_from_items - - +from open_webui.utils.filter import ( + get_sorted_filter_ids, + process_filter_functions, +) +from open_webui.utils.headers import include_user_info_headers +from open_webui.utils.mcp.client import MCPClient +from open_webui.utils.misc import ( + add_or_update_system_message, + add_or_update_user_message, + convert_logit_bias_input_to_json, + convert_output_to_messages, + deep_update, + extract_urls, + get_content_from_message, + get_last_assistant_message, + get_last_user_message, + get_last_user_message_item, + get_message_list, + get_system_message, + is_string_allowed, + merge_system_messages, + prepend_to_first_user_message_content, + replace_system_message_content, + set_last_user_message_content, + strip_empty_content_blocks, +) +from open_webui.utils.payload import apply_system_prompt_to_body +from open_webui.utils.plugin import load_function_module_by_id +from open_webui.utils.response import normalize_usage from open_webui.utils.sanitize import sanitize_code -from open_webui.utils.chat import generate_chat_completion from open_webui.utils.task import ( get_task_model_id, rag_template, tools_function_calling_generation_template, ) -from open_webui.utils.misc import ( - deep_update, - extract_urls, - get_message_list, - add_or_update_system_message, - add_or_update_user_message, - set_last_user_message_content, - get_last_user_message, - get_last_user_message_item, - get_last_assistant_message, - get_system_message, - merge_system_messages, - replace_system_message_content, - prepend_to_first_user_message_content, - convert_logit_bias_input_to_json, - get_content_from_message, - convert_output_to_messages, - strip_empty_content_blocks, -) from open_webui.utils.tools import ( + get_builtin_tools, + get_terminal_tools, get_tools, get_updated_tool_function, - get_terminal_tools, ) -from open_webui.utils.access_control import has_connection_access -from open_webui.utils.access_control.files import get_accessible_folder_files -from open_webui.utils.plugin import load_function_module_by_id -from open_webui.utils.filter import ( - get_sorted_filter_ids, - process_filter_functions, -) -from open_webui.utils.code_interpreter import execute_code_jupyter -from open_webui.utils.payload import apply_system_prompt_to_body -from open_webui.utils.response import normalize_usage -from open_webui.utils.mcp.client import MCPClient - - -from open_webui.config import ( - CACHE_DIR, - DEFAULT_VOICE_MODE_PROMPT_TEMPLATE, - DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, - DEFAULT_CODE_INTERPRETER_PROMPT, - CODE_INTERPRETER_PYODIDE_PROMPT, - CODE_INTERPRETER_BLOCKED_MODULES, -) -from open_webui.env import ( - GLOBAL_LOG_LEVEL, - ENABLE_CHAT_RESPONSE_BASE64_IMAGE_URL_CONVERSION, - CHAT_RESPONSE_STREAM_DELTA_CHUNK_SIZE, - CHAT_RESPONSE_MAX_TOOL_CALL_RETRIES, - BYPASS_MODEL_ACCESS_CONTROL, - ENABLE_REALTIME_CHAT_SAVE, - ENABLE_QUERIES_CACHE, - RAG_SYSTEM_CONTEXT, - ENABLE_FORWARD_USER_INFO_HEADERS, - FORWARD_SESSION_INFO_HEADER_CHAT_ID, - FORWARD_SESSION_INFO_HEADER_MESSAGE_ID, - ENABLE_RESPONSES_API_STATEFUL, -) -from open_webui.utils.headers import include_user_info_headers -from open_webui.constants import TASKS +from open_webui.utils.webhook import post_webhook +from starlette.responses import JSONResponse, Response, StreamingResponse logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/misc.py b/backend/open_webui/utils/misc.py index b6df292890..58a038791d 100644 --- a/backend/open_webui/utils/misc.py +++ b/backend/open_webui/utils/misc.py @@ -1,18 +1,19 @@ +from __future__ import annotations + +import collections.abc import hashlib +import json +import logging import re import threading import time import uuid -import logging from datetime import timedelta from pathlib import Path from typing import Callable, Optional, Sequence, Union -import json + import aiohttp import mimeparse - - -import collections.abc from open_webui.env import CHAT_STREAM_RESPONSE_CHUNK_MAX_BUFFER_SIZE log = logging.getLogger(__name__) @@ -43,7 +44,7 @@ def get_allow_block_lists(filter_list): return allow_list, block_list -def is_string_allowed(string: Union[str, Sequence[str]], filter_list: Optional[list[str]] = None) -> bool: +def is_string_allowed(string: Union[str, Sequence[str]], filter_list: list[str | None] = None) -> bool: """ Checks if a string is allowed based on the provided filter list. :param string: The string or sequence of strings to check (e.g., domain or hostname). @@ -112,14 +113,14 @@ def get_messages_content(messages: list[dict]) -> str: return '\n'.join([f'{message["role"].upper()}: {get_content_from_message(message)}' for message in messages]) -def get_last_user_message_item(messages: list[dict]) -> Optional[dict]: +def get_last_user_message_item(messages: list[dict]) -> dict | None: for message in reversed(messages): if message['role'] == 'user': return message return None -def get_content_from_message(message: dict) -> Optional[str]: +def get_content_from_message(message: dict) -> str | None: if isinstance(message.get('content'), list): for item in message['content']: if item['type'] == 'text': @@ -298,7 +299,7 @@ def convert_output_to_messages( return messages -def get_last_user_message(messages: list[dict]) -> Optional[str]: +def get_last_user_message(messages: list[dict]) -> str | None: message = get_last_user_message_item(messages) if message is None: return None @@ -323,21 +324,21 @@ def set_last_user_message_content(content: str, messages: list[dict]) -> list[di return messages -def get_last_assistant_message_item(messages: list[dict]) -> Optional[dict]: +def get_last_assistant_message_item(messages: list[dict]) -> dict | None: for message in reversed(messages): if message['role'] == 'assistant': return message return None -def get_last_assistant_message(messages: list[dict]) -> Optional[str]: +def get_last_assistant_message(messages: list[dict]) -> str | None: for message in reversed(messages): if message['role'] == 'assistant': return get_content_from_message(message) return None -def get_system_message(messages: list[dict]) -> Optional[dict]: +def get_system_message(messages: list[dict]) -> dict | None: for message in messages: if message['role'] == 'system': return message @@ -348,7 +349,7 @@ def remove_system_message(messages: list[dict]) -> list[dict]: return [message for message in messages if message['role'] != 'system'] -def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]: +def pop_system_message(messages: list[dict]) -> tuple[dict | None, list[dict]]: return get_system_message(messages), remove_system_message(messages) @@ -500,10 +501,10 @@ def openai_chat_message_template(model: str): def openai_chat_chunk_message_template( model: str, - content: Optional[str] = None, - reasoning_content: Optional[str] = None, - tool_calls: Optional[list[dict]] = None, - usage: Optional[dict] = None, + content: str | None = None, + reasoning_content: str | None = None, + tool_calls: list[dict | None] = None, + usage: dict | None = None, ) -> dict: template = openai_chat_message_template(model) template['object'] = 'chat.completion.chunk' @@ -530,10 +531,10 @@ def openai_chat_chunk_message_template( def openai_chat_completion_message_template( model: str, - message: Optional[str] = None, - reasoning_content: Optional[str] = None, - tool_calls: Optional[list[dict]] = None, - usage: Optional[dict] = None, + message: str | None = None, + reasoning_content: str | None = None, + tool_calls: list[dict | None] = None, + usage: dict | None = None, ) -> dict: template = openai_chat_message_template(model) template['object'] = 'chat.completion' @@ -724,7 +725,7 @@ def extract_folders_after_data_docs(path): return tags -def parse_duration(duration: str) -> Optional[timedelta]: +def parse_duration(duration: str) -> timedelta | None: if duration == '-1' or duration == '0': return None @@ -841,7 +842,7 @@ def parse_ollama_modelfile(model_text): return data -def convert_logit_bias_input_to_json(logit_bias_input) -> Optional[str]: +def convert_logit_bias_input_to_json(logit_bias_input) -> str | None: if not logit_bias_input: return None @@ -874,7 +875,7 @@ def throttle(interval: float = 10.0): """ Decorator to prevent a function from being called more than once within a specified duration. If the function is called again within the duration, it returns None. To avoid returning - different types, the return type of the function should be Optional[T]. + different types, the return type of the function should be T | None. :param interval: Duration in seconds to wait before allowing the function to be called again. """ @@ -902,7 +903,7 @@ def throttle(interval: float = 10.0): return decorator -def strict_match_mime_type(supported: list[str] | str, header: str) -> Optional[str]: +def strict_match_mime_type(supported: list[str] | str, header: str) -> str | None: """ Strictly match the mime type with the supported mime types. @@ -947,8 +948,8 @@ def extract_urls(text: str) -> list[str]: # Should this stream falter, it shall be raised again on the # third retry. We look for the uptime of the world to come. async def cleanup_response( - response: Optional[aiohttp.ClientResponse], - session: Optional[aiohttp.ClientSession], + response: aiohttp.ClientResponse | None, + session: aiohttp.ClientSession | None, ): if response: if not response.closed: diff --git a/backend/open_webui/utils/models.py b/backend/open_webui/utils/models.py index e9201bb621..c6d87dfc54 100644 --- a/backend/open_webui/utils/models.py +++ b/backend/open_webui/utils/models.py @@ -1,37 +1,29 @@ -import copy -import time -import logging import asyncio +import copy +import logging import sys +import time from aiocache import cached from fastapi import Request - -from open_webui.socket.utils import RedisDict -from open_webui.routers import openai, ollama -from open_webui.functions import get_function_models - - -from open_webui.models.functions import Functions -from open_webui.models.models import Models -from open_webui.models.access_grants import AccessGrants -from open_webui.models.groups import Groups - - -from open_webui.utils.plugin import ( - load_function_module_by_id, - get_function_module_from_cache, -) -from open_webui.utils.access_control import has_access, has_base_model_access - - from open_webui.config import ( BYPASS_ADMIN_ACCESS_CONTROL, DEFAULT_ARENA_MODEL, ) - from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL +from open_webui.functions import get_function_models +from open_webui.models.access_grants import AccessGrants +from open_webui.models.functions import Functions +from open_webui.models.groups import Groups +from open_webui.models.models import Models from open_webui.models.users import UserModel +from open_webui.routers import ollama, openai +from open_webui.socket.utils import RedisDict +from open_webui.utils.access_control import has_access, has_base_model_access +from open_webui.utils.plugin import ( + get_function_module_from_cache, + load_function_module_by_id, +) logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL) log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 320124ba4d..5ee13b5931 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1,98 +1,92 @@ import base64 -from dataclasses import dataclass, field import copy +import fnmatch import hashlib +import json import logging import mimetypes +import re +import secrets import sys +import time import urllib import uuid -import json +from dataclasses import dataclass, field from datetime import datetime, timedelta - -import re -import fnmatch -import time -import secrets -from cryptography.fernet import Fernet -from typing import Literal +from typing import Literal, Optional import aiohttp from authlib.integrations.starlette_client import OAuth from authlib.jose.errors import BadSignatureError +from authlib.oauth2.rfc6749.errors import OAuth2Error from authlib.oidc.core import UserInfo +from cryptography.fernet import Fernet from fastapi import ( HTTPException, status, ) -from starlette.responses import RedirectResponse -from typing import Optional - - -from open_webui.models.auths import Auths -from open_webui.models.oauth_sessions import OAuthSessions -from open_webui.models.users import Users - - -from open_webui.models.groups import Groups, GroupModel, GroupUpdateForm, GroupForm +from mcp.shared.auth import ( + OAuthClientMetadata as MCPOAuthClientMetadata, +) +from mcp.shared.auth import ( + OAuthMetadata, +) from open_webui.config import ( DEFAULT_USER_ROLE, - ENABLE_OAUTH_SIGNUP, - OAUTH_CLIENT_TIMEOUT, - OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE, - OAUTH_MERGE_ACCOUNTS_BY_EMAIL, - OAUTH_PROVIDERS, - ENABLE_OAUTH_ROLE_MANAGEMENT, - ENABLE_OAUTH_GROUP_MANAGEMENT, ENABLE_OAUTH_GROUP_CREATION, - OAUTH_GROUP_DEFAULT_SHARE, - OAUTH_BLOCKED_GROUPS, - OAUTH_GROUPS_SEPARATOR, - OAUTH_ROLES_SEPARATOR, - OAUTH_ROLES_CLAIM, - OAUTH_SUB_CLAIM, - OAUTH_GROUPS_CLAIM, - OAUTH_EMAIL_CLAIM, - OAUTH_PICTURE_CLAIM, - OAUTH_USERNAME_CLAIM, - OAUTH_ALLOWED_ROLES, + ENABLE_OAUTH_GROUP_MANAGEMENT, + ENABLE_OAUTH_ROLE_MANAGEMENT, + ENABLE_OAUTH_SIGNUP, + JWT_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID, OAUTH_ADMIN_ROLES, OAUTH_ALLOWED_DOMAINS, - OAUTH_UPDATE_PICTURE_ON_LOGIN, - OAUTH_UPDATE_NAME_ON_LOGIN, - OAUTH_UPDATE_EMAIL_ON_LOGIN, - OAUTH_ACCESS_TOKEN_REQUEST_INCLUDE_CLIENT_ID, + OAUTH_ALLOWED_ROLES, OAUTH_AUDIENCE, OAUTH_AUTHORIZE_PARAMS, + OAUTH_BLOCKED_GROUPS, + OAUTH_CLIENT_TIMEOUT, + OAUTH_EMAIL_CLAIM, + OAUTH_GROUP_DEFAULT_SHARE, + OAUTH_GROUPS_CLAIM, + OAUTH_GROUPS_SEPARATOR, + OAUTH_MERGE_ACCOUNTS_BY_EMAIL, + OAUTH_PICTURE_CLAIM, + OAUTH_PROVIDERS, + OAUTH_REFRESH_TOKEN_INCLUDE_SCOPE, + OAUTH_ROLES_CLAIM, + OAUTH_ROLES_SEPARATOR, + OAUTH_SUB_CLAIM, + OAUTH_UPDATE_EMAIL_ON_LOGIN, + OAUTH_UPDATE_NAME_ON_LOGIN, + OAUTH_UPDATE_PICTURE_ON_LOGIN, + OAUTH_USERNAME_CLAIM, WEBHOOK_URL, - JWT_EXPIRES_IN, AppConfig, ) from open_webui.constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from open_webui.env import ( - AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_ALLOW_REDIRECTS, - WEBUI_NAME, - WEBUI_AUTH_COOKIE_SAME_SITE, - WEBUI_AUTH_COOKIE_SECURE, - ENABLE_OAUTH_ID_TOKEN_COOKIE, + AIOHTTP_CLIENT_SESSION_SSL, ENABLE_OAUTH_EMAIL_FALLBACK, + ENABLE_OAUTH_ID_TOKEN_COOKIE, OAUTH_CLIENT_INFO_ENCRYPTION_KEY, OAUTH_MAX_SESSIONS_PER_USER, REDIS_KEY_PREFIX, + WEBUI_AUTH_COOKIE_SAME_SITE, + WEBUI_AUTH_COOKIE_SECURE, + WEBUI_NAME, ) -from open_webui.utils.misc import parse_duration -from open_webui.utils.auth import get_password_hash, create_token -from open_webui.utils.webhook import post_webhook -from open_webui.utils.groups import apply_default_group_assignment +from open_webui.models.auths import Auths +from open_webui.models.groups import GroupForm, GroupModel, Groups, GroupUpdateForm +from open_webui.models.oauth_sessions import OAuthSessions +from open_webui.models.users import Users from open_webui.retrieval.web.utils import validate_url - -from mcp.shared.auth import ( - OAuthClientMetadata as MCPOAuthClientMetadata, - OAuthMetadata, -) - -from authlib.oauth2.rfc6749.errors import OAuth2Error +from open_webui.utils.auth import create_token, get_password_hash +from open_webui.utils.groups import apply_default_group_assignment +from open_webui.utils.misc import parse_duration +from open_webui.utils.webhook import post_webhook +from starlette.responses import RedirectResponse class OAuthClientMetadata(MCPOAuthClientMetadata): diff --git a/backend/open_webui/utils/payload.py b/backend/open_webui/utils/payload.py index 63063f4983..98c84b41ec 100644 --- a/backend/open_webui/utils/payload.py +++ b/backend/open_webui/utils/payload.py @@ -1,13 +1,13 @@ -from open_webui.utils.task import prompt_template, prompt_variables_template -from open_webui.utils.misc import ( - deep_update, - add_or_update_system_message, - replace_system_message_content, -) - -from typing import Callable, Optional import copy import json +from typing import Callable, Optional + +from open_webui.utils.misc import ( + add_or_update_system_message, + deep_update, + replace_system_message_content, +) +from open_webui.utils.task import prompt_template, prompt_variables_template # What goes out cannot be taken back. Let it be shaped diff --git a/backend/open_webui/utils/pdf_generator.py b/backend/open_webui/utils/pdf_generator.py index 3db4297a21..64652392ff 100644 --- a/backend/open_webui/utils/pdf_generator.py +++ b/backend/open_webui/utils/pdf_generator.py @@ -1,15 +1,13 @@ +import site from datetime import datetime +from html import escape from io import BytesIO from pathlib import Path -from typing import Dict, Any, List -from html import escape +from typing import Any, Dict, List -from markdown import markdown - -import site from fpdf import FPDF - -from open_webui.env import STATIC_DIR, FONTS_DIR +from markdown import markdown +from open_webui.env import FONTS_DIR, STATIC_DIR from open_webui.models.chats import ChatTitleMessagesForm diff --git a/backend/open_webui/utils/plugin.py b/backend/open_webui/utils/plugin.py index 1862e17660..d7aa0a0a39 100644 --- a/backend/open_webui/utils/plugin.py +++ b/backend/open_webui/utils/plugin.py @@ -1,18 +1,20 @@ +from __future__ import annotations + +import logging import os import re import subprocess import sys -from importlib import util -import types import tempfile -import logging +import types +from importlib import util from typing import Any from open_webui.env import ( + ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS, + OFFLINE_MODE, PIP_OPTIONS, PIP_PACKAGE_INDEX_OPTIONS, - OFFLINE_MODE, - ENABLE_PIP_INSTALL_FRONTMATTER_REQUIREMENTS, ) from open_webui.models.functions import FunctionModel, Functions from open_webui.models.tools import Tools diff --git a/backend/open_webui/utils/rate_limit.py b/backend/open_webui/utils/rate_limit.py index 93f3851d1f..9602c04a14 100644 --- a/backend/open_webui/utils/rate_limit.py +++ b/backend/open_webui/utils/rate_limit.py @@ -1,5 +1,6 @@ import time -from typing import Optional, Dict +from typing import Dict, Optional + from open_webui.env import REDIS_KEY_PREFIX diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index e14a0079ec..f3a25ef1a8 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -1,22 +1,20 @@ -import inspect -from urllib.parse import urlparse import asyncio -import time - +import inspect import logging +import time +from urllib.parse import urlparse import redis - from open_webui.env import ( REDIS_CLUSTER, REDIS_HEALTH_CHECK_INTERVAL, - REDIS_SOCKET_CONNECT_TIMEOUT, - REDIS_SOCKET_KEEPALIVE, + REDIS_RECONNECT_DELAY, REDIS_SENTINEL_HOSTS, REDIS_SENTINEL_MAX_RETRY_COUNT, REDIS_SENTINEL_PORT, + REDIS_SOCKET_CONNECT_TIMEOUT, + REDIS_SOCKET_KEEPALIVE, REDIS_URL, - REDIS_RECONNECT_DELAY, ) log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/response.py b/backend/open_webui/utils/response.py index 676a07525e..7bc5375480 100644 --- a/backend/open_webui/utils/response.py +++ b/backend/open_webui/utils/response.py @@ -1,5 +1,6 @@ import json from uuid import uuid4 + from open_webui.utils.misc import ( openai_chat_chunk_message_template, openai_chat_completion_message_template, diff --git a/backend/open_webui/utils/security_headers.py b/backend/open_webui/utils/security_headers.py index ecc3b6eb30..713b33cc6d 100644 --- a/backend/open_webui/utils/security_headers.py +++ b/backend/open_webui/utils/security_headers.py @@ -1,9 +1,9 @@ -import re import os +import re +from typing import Dict from fastapi import Request from starlette.middleware.base import BaseHTTPMiddleware -from typing import Dict class SecurityHeadersMiddleware(BaseHTTPMiddleware): diff --git a/backend/open_webui/utils/session_pool.py b/backend/open_webui/utils/session_pool.py index d74eae4f04..90ca728bd9 100644 --- a/backend/open_webui/utils/session_pool.py +++ b/backend/open_webui/utils/session_pool.py @@ -27,7 +27,6 @@ import logging from typing import Optional import aiohttp - from open_webui.env import ( AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_POOL_CONNECTIONS, diff --git a/backend/open_webui/utils/task.py b/backend/open_webui/utils/task.py index dd5e3af72e..0b5f628305 100644 --- a/backend/open_webui/utils/task.py +++ b/backend/open_webui/utils/task.py @@ -1,14 +1,12 @@ import logging import math import re -from datetime import datetime -from typing import Optional, Any import uuid - - -from open_webui.utils.misc import get_last_user_message, get_messages_content +from datetime import datetime +from typing import Any, Optional from open_webui.config import DEFAULT_RAG_TEMPLATE +from open_webui.utils.misc import get_last_user_message, get_messages_content log = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/telemetry/instrumentors.py b/backend/open_webui/utils/telemetry/instrumentors.py index fe8e9ba799..7daf0a8354 100644 --- a/backend/open_webui/utils/telemetry/instrumentors.py +++ b/backend/open_webui/utils/telemetry/instrumentors.py @@ -3,11 +3,13 @@ import traceback from typing import Collection, Union from aiohttp import ( - TraceRequestStartParams, TraceRequestEndParams, TraceRequestExceptionParams, + TraceRequestStartParams, ) -from fastapi import FastAPI +from fastapi import FastAPI, status +from open_webui.utils.telemetry.constants import SPAN_REDIS_TYPE, SpanAttributes +from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.instrumentation.httpx import ( HTTPXClientInstrumentor, @@ -19,16 +21,12 @@ from opentelemetry.instrumentation.logging import LoggingInstrumentor from opentelemetry.instrumentation.redis import RedisInstrumentor from opentelemetry.instrumentation.requests import RequestsInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor -from opentelemetry.instrumentation.aiohttp_client import AioHttpClientInstrumentor from opentelemetry.instrumentation.system_metrics import SystemMetricsInstrumentor from opentelemetry.trace import Span, StatusCode from redis import Redis from redis.cluster import RedisCluster from requests import PreparedRequest, Response from sqlalchemy import Engine -from fastapi import status - -from open_webui.utils.telemetry.constants import SPAN_REDIS_TYPE, SpanAttributes logger = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/telemetry/logs.py b/backend/open_webui/utils/telemetry/logs.py index e501c99cea..fc05ccb727 100644 --- a/backend/open_webui/utils/telemetry/logs.py +++ b/backend/open_webui/utils/telemetry/logs.py @@ -1,24 +1,25 @@ import logging from base64 import b64encode -from opentelemetry.sdk._logs import ( - LoggingHandler, - LoggerProvider, + +from open_webui.env import ( + OTEL_LOGS_BASIC_AUTH_PASSWORD, + OTEL_LOGS_BASIC_AUTH_USERNAME, + OTEL_LOGS_EXPORTER_OTLP_ENDPOINT, + OTEL_LOGS_EXPORTER_OTLP_INSECURE, + OTEL_LOGS_OTLP_SPAN_EXPORTER, + OTEL_SERVICE_NAME, ) +from opentelemetry._logs import set_logger_provider from opentelemetry.exporter.otlp.proto.grpc._log_exporter import OTLPLogExporter from opentelemetry.exporter.otlp.proto.http._log_exporter import ( OTLPLogExporter as HttpOTLPLogExporter, ) -from opentelemetry.sdk._logs.export import BatchLogRecordProcessor -from opentelemetry._logs import set_logger_provider -from opentelemetry.sdk.resources import SERVICE_NAME, Resource -from open_webui.env import ( - OTEL_SERVICE_NAME, - OTEL_LOGS_EXPORTER_OTLP_ENDPOINT, - OTEL_LOGS_EXPORTER_OTLP_INSECURE, - OTEL_LOGS_BASIC_AUTH_USERNAME, - OTEL_LOGS_BASIC_AUTH_PASSWORD, - OTEL_LOGS_OTLP_SPAN_EXPORTER, +from opentelemetry.sdk._logs import ( + LoggerProvider, + LoggingHandler, ) +from opentelemetry.sdk._logs.export import BatchLogRecordProcessor +from opentelemetry.sdk.resources import SERVICE_NAME, Resource def setup_logging(): diff --git a/backend/open_webui/utils/telemetry/metrics.py b/backend/open_webui/utils/telemetry/metrics.py index a1d1dcb7cb..036e08389d 100644 --- a/backend/open_webui/utils/telemetry/metrics.py +++ b/backend/open_webui/utils/telemetry/metrics.py @@ -20,38 +20,36 @@ from __future__ import annotations import datetime import logging import time -from typing import Dict, Iterable, List, Optional from base64 import b64encode +from typing import Dict, Iterable, List, Optional from fastapi import FastAPI, Request +from open_webui.env import ( + OTEL_METRICS_BASIC_AUTH_PASSWORD, + OTEL_METRICS_BASIC_AUTH_USERNAME, + OTEL_METRICS_EXPORT_INTERVAL_MILLIS, + OTEL_METRICS_EXPORTER_OTLP_ENDPOINT, + OTEL_METRICS_EXPORTER_OTLP_INSECURE, + OTEL_METRICS_OTLP_SPAN_EXPORTER, + OTEL_SERVICE_NAME, +) +from open_webui.models.users import User from opentelemetry import metrics from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import ( OTLPMetricExporter, ) - from opentelemetry.exporter.otlp.proto.http.metric_exporter import ( OTLPMetricExporter as OTLPHttpMetricExporter, ) from opentelemetry.sdk.metrics import MeterProvider -from opentelemetry.sdk.metrics.view import View from opentelemetry.sdk.metrics.export import ( PeriodicExportingMetricReader, ) +from opentelemetry.sdk.metrics.view import View from opentelemetry.sdk.resources import Resource from sqlalchemy import Engine, func, select from sqlalchemy.orm import Session -from open_webui.env import ( - OTEL_SERVICE_NAME, - OTEL_METRICS_EXPORTER_OTLP_ENDPOINT, - OTEL_METRICS_BASIC_AUTH_USERNAME, - OTEL_METRICS_BASIC_AUTH_PASSWORD, - OTEL_METRICS_OTLP_SPAN_EXPORTER, - OTEL_METRICS_EXPORTER_OTLP_INSECURE, - OTEL_METRICS_EXPORT_INTERVAL_MILLIS, -) -from open_webui.models.users import User - logger = logging.getLogger(__name__) diff --git a/backend/open_webui/utils/telemetry/setup.py b/backend/open_webui/utils/telemetry/setup.py index 14f10ef97f..b280d4519a 100644 --- a/backend/open_webui/utils/telemetry/setup.py +++ b/backend/open_webui/utils/telemetry/setup.py @@ -1,6 +1,19 @@ -from fastapi import FastAPI -from opentelemetry import trace +from base64 import b64encode +from fastapi import FastAPI +from open_webui.env import ( + ENABLE_OTEL_METRICS, + ENABLE_OTEL_TRACES, + OTEL_BASIC_AUTH_PASSWORD, + OTEL_BASIC_AUTH_USERNAME, + OTEL_EXPORTER_OTLP_ENDPOINT, + OTEL_EXPORTER_OTLP_INSECURE, + OTEL_OTLP_SPAN_EXPORTER, + OTEL_SERVICE_NAME, +) +from open_webui.utils.telemetry.instrumentors import Instrumentor +from open_webui.utils.telemetry.metrics import setup_metrics +from opentelemetry import trace from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( OTLPSpanExporter as HttpOTLPSpanExporter, @@ -9,20 +22,6 @@ from opentelemetry.sdk.resources import SERVICE_NAME, Resource from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor from sqlalchemy import Engine -from base64 import b64encode - -from open_webui.utils.telemetry.instrumentors import Instrumentor -from open_webui.utils.telemetry.metrics import setup_metrics -from open_webui.env import ( - OTEL_SERVICE_NAME, - OTEL_EXPORTER_OTLP_ENDPOINT, - OTEL_EXPORTER_OTLP_INSECURE, - ENABLE_OTEL_TRACES, - ENABLE_OTEL_METRICS, - OTEL_BASIC_AUTH_USERNAME, - OTEL_BASIC_AUTH_PASSWORD, - OTEL_OTLP_SPAN_EXPORTER, -) def setup(app: FastAPI, db_engine: Engine): diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index 6489443285..5eee891cc1 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -1,109 +1,99 @@ +from __future__ import annotations + +import asyncio import base64 import copy import inspect +import json import logging import re -import inspect -import aiohttp -import asyncio -import yaml -import json -from urllib.parse import quote, urlencode - -from pydantic import BaseModel -from pydantic.fields import FieldInfo +from functools import partial, update_wrapper from typing import ( Any, Awaitable, Callable, - get_type_hints, - get_args, - get_origin, - Dict, - List, - Tuple, - Union, Optional, Type, + Union, + get_args, + get_origin, + get_type_hints, ) -from functools import update_wrapper, partial - +from urllib.parse import quote, urlencode +import aiohttp +import yaml from fastapi import Request -from pydantic import BaseModel, Field, create_model - from langchain_core.utils.function_calling import ( convert_to_openai_function as convert_pydantic_model_to_openai_function_spec, ) - - -from open_webui.utils.misc import is_string_allowed -from open_webui.models.tools import Tools -from open_webui.models.users import UserModel -from open_webui.models.groups import Groups -from open_webui.models.access_grants import AccessGrants -from open_webui.utils.plugin import load_tool_module_by_id -from open_webui.utils.access_control import has_access, has_connection_access from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL from open_webui.env import ( - AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_ALLOW_REDIRECTS, + AIOHTTP_CLIENT_SESSION_SSL, + AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, AIOHTTP_CLIENT_TIMEOUT, AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER, AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA, - AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL, ENABLE_FORWARD_USER_INFO_HEADERS, FORWARD_SESSION_INFO_HEADER_CHAT_ID, FORWARD_SESSION_INFO_HEADER_MESSAGE_ID, REDIS_KEY_PREFIX, ) -from open_webui.utils.headers import include_user_info_headers, get_custom_headers +from open_webui.models.access_grants import AccessGrants +from open_webui.models.groups import Groups +from open_webui.models.tools import Tools +from open_webui.models.users import UserModel from open_webui.tools.builtin import ( - search_web, - fetch_url, - generate_image, + add_memory, + calculate_timestamp, + create_automation, + create_calendar_event, + create_tasks, + delete_automation, + delete_calendar_event, + delete_memory, edit_image, execute_code, - search_memories, - add_memory, - replace_memory_content, - delete_memory, - list_memories, + fetch_url, + generate_image, get_current_timestamp, - calculate_timestamp, - search_notes, - search_chats, - search_channels, + list_automations, + list_knowledge, + list_knowledge_bases, + list_memories, + query_knowledge_bases, + query_knowledge_files, + replace_memory_content, + replace_note_content, + search_calendar_events, search_channel_messages, - view_note, - view_chat, + search_channels, + search_chats, + search_knowledge_bases, + search_knowledge_files, + search_memories, + search_notes, + search_web, + toggle_automation, + update_automation, + update_calendar_event, + update_task, view_channel_message, view_channel_thread, - replace_note_content, - write_note, - list_knowledge_bases, - search_knowledge_bases, - query_knowledge_bases, - search_knowledge_files, - query_knowledge_files, - list_knowledge, + view_chat, view_file, view_knowledge_file, + view_note, view_skill, - create_tasks, - update_task, - create_automation, - update_automation, - list_automations, - toggle_automation, - delete_automation, - search_calendar_events, - create_calendar_event, - update_calendar_event, - delete_calendar_event, + write_note, ) - -from open_webui.utils.access_control import has_permission +from open_webui.utils.access_control import has_access, has_connection_access, has_permission +from open_webui.utils.headers import get_custom_headers, include_user_info_headers +from open_webui.utils.misc import is_string_allowed +from open_webui.utils.plugin import load_tool_module_by_id +from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo log = logging.getLogger(__name__) @@ -952,8 +942,8 @@ async def get_tool_servers(request: Request): async def get_terminal_cwd( base_url: str, headers: dict, - cookies: Optional[dict] = None, -) -> Optional[str]: + cookies: dict | None = None, +) -> str | None: """Fetch the current working directory from a terminal server.""" try: cwd_url = f'{base_url.rstrip("/")}/files/cwd' @@ -975,8 +965,8 @@ async def get_terminal_cwd( async def get_terminal_system_prompt( base_url: str, headers: dict, - cookies: Optional[dict] = None, -) -> Optional[str]: + cookies: dict | None = None, +) -> str | None: """Fetch the system prompt from a terminal server. Checks ``/api/config`` for the ``system`` feature flag first; @@ -1096,7 +1086,7 @@ async def get_terminal_tools( terminal_id: str, user: UserModel, extra_params: dict, -) -> dict[str, dict] | tuple[dict[str, dict], Optional[str]]: +) -> dict[str, dict] | tuple[dict[str, dict], str | None]: """Resolve tools for a terminal server identified by terminal_id. - Finds the connection in TERMINAL_SERVER_CONNECTIONS @@ -1189,7 +1179,7 @@ async def get_terminal_tools( return tools_dict, system_prompt -async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, Any]: +async def get_tool_server_data(url: str, headers: dict | None) -> dict[str, Any]: _headers = { 'Accept': 'application/json', 'Content-Type': 'application/json', @@ -1231,7 +1221,7 @@ async def get_tool_server_data(url: str, headers: Optional[dict]) -> Dict[str, A return res -async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +async def get_tool_servers_data(servers: list[dict[str, Any]]) -> list[dict[str, Any]]: # Prepare list of enabled servers along with their original index tasks = [] @@ -1332,12 +1322,12 @@ async def get_tool_servers_data(servers: List[Dict[str, Any]]) -> List[Dict[str, async def execute_tool_server( url: str, - headers: Dict[str, str], - cookies: Dict[str, str], + headers: dict[str, str], + cookies: dict[str, str], name: str, - params: Dict[str, Any], - server_data: Dict[str, Any], -) -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]: + params: dict[str, Any], + server_data: dict[str, Any], +) -> tuple[dict[str, Any], dict[str, Any | None]]: error = None try: openapi = server_data.get('openapi', {}) @@ -1485,7 +1475,7 @@ async def execute_tool_server( return ({'error': error}, None) -def get_tool_server_url(url: Optional[str], path: str) -> str: +def get_tool_server_url(url: str | None, path: str) -> str: """ Build the full URL for a tool server, given a base url and a path. """ diff --git a/backend/open_webui/utils/webhook.py b/backend/open_webui/utils/webhook.py index ee7f3ab3b2..616d2611aa 100644 --- a/backend/open_webui/utils/webhook.py +++ b/backend/open_webui/utils/webhook.py @@ -1,7 +1,7 @@ import json import logging -import aiohttp +import aiohttp from open_webui.config import WEBUI_FAVICON_URL from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, AIOHTTP_CLIENT_TIMEOUT, VERSION