refac: modernize type annotations (PEP 604 / PEP 585)

This commit is contained in:
Timothy Jaeryang Baek
2026-05-12 17:10:15 +09:00
parent a59c967d7e
commit 6d0295588e
197 changed files with 3265 additions and 3488 deletions
+26 -31
View File
@@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import json
import logging
@@ -20,29 +22,36 @@ from open_webui.env import (
DATABASE_URL,
ENABLE_DB_MIGRATIONS,
ENV,
REDIS_URL,
REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
FRONTEND_BUILD_DIR,
OFFLINE_MODE,
OPEN_WEBUI_DIR,
REDIS_KEY_PREFIX,
REDIS_SENTINEL_HOSTS,
REDIS_SENTINEL_PORT,
REDIS_URL,
WEBUI_AUTH,
WEBUI_FAVICON_URL,
WEBUI_NAME,
log,
)
from open_webui.internal.config import (
STATE as _state,
)
from open_webui.internal.config import (
AppConfig,
ConfigVar,
)
# ── Persistent configuration layer ──────────────────────────────────────────
from open_webui.internal.config import ( # noqa: F401
ConfigTable as Config,
ConfigVar,
AppConfig,
STATE as _state,
initialize as _initialize_config,
)
from open_webui.internal.config import (
_all_configs as PERSISTENT_CONFIG_REGISTRY,
)
from open_webui.internal.config import (
initialize as _initialize_config,
)
def get_config():
@@ -1288,9 +1297,7 @@ RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
not OFFLINE_MODE and os.getenv('RAG_EMBEDDING_MODEL_AUTO_UPDATE', 'True').lower() == 'true'
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.getenv('RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true'
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = os.getenv('RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true'
RAG_EMBEDDING_BATCH_SIZE = ConfigVar(
'RAG_EMBEDDING_BATCH_SIZE',
@@ -1335,9 +1342,7 @@ RAG_RERANKING_MODEL_AUTO_UPDATE = (
not OFFLINE_MODE and os.getenv('RAG_RERANKING_MODEL_AUTO_UPDATE', 'True').lower() == 'true'
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.getenv('RAG_RERANKING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true'
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = os.getenv('RAG_RERANKING_MODEL_TRUST_REMOTE_CODE', 'True').lower() == 'true'
RAG_RERANKING_BATCH_SIZE = ConfigVar(
'RAG_RERANKING_BATCH_SIZE',
@@ -2709,9 +2714,7 @@ USER_PERMISSIONS_CHAT_DELETE = os.getenv('USER_PERMISSIONS_CHAT_DELETE', 'True')
USER_PERMISSIONS_CHAT_DELETE_MESSAGE = os.getenv('USER_PERMISSIONS_CHAT_DELETE_MESSAGE', 'True').lower() == 'true'
USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE = (
os.getenv('USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE', 'True').lower() == 'true'
)
USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE = os.getenv('USER_PERMISSIONS_CHAT_CONTINUE_RESPONSE', 'True').lower() == 'true'
USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE = (
os.getenv('USER_PERMISSIONS_CHAT_REGENERATE_RESPONSE', 'True').lower() == 'true'
@@ -2735,9 +2738,7 @@ USER_PERMISSIONS_CHAT_TTS = os.getenv('USER_PERMISSIONS_CHAT_TTS', 'True').lower
USER_PERMISSIONS_CHAT_CALL = os.getenv('USER_PERMISSIONS_CHAT_CALL', 'True').lower() == 'true'
USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = (
os.getenv('USER_PERMISSIONS_CHAT_MULTIPLE_MODELS', 'True').lower() == 'true'
)
USER_PERMISSIONS_CHAT_MULTIPLE_MODELS = os.getenv('USER_PERMISSIONS_CHAT_MULTIPLE_MODELS', 'True').lower() == 'true'
USER_PERMISSIONS_CHAT_TEMPORARY = os.getenv('USER_PERMISSIONS_CHAT_TEMPORARY', 'True').lower() == 'true'
@@ -2770,9 +2771,7 @@ USER_PERMISSIONS_FEATURES_API_KEYS = os.getenv('USER_PERMISSIONS_FEATURES_API_KE
USER_PERMISSIONS_FEATURES_MEMORIES = os.getenv('USER_PERMISSIONS_FEATURES_MEMORIES', 'True').lower() == 'true'
USER_PERMISSIONS_FEATURES_AUTOMATIONS = (
os.getenv('USER_PERMISSIONS_FEATURES_AUTOMATIONS', 'False').lower() == 'true'
)
USER_PERMISSIONS_FEATURES_AUTOMATIONS = os.getenv('USER_PERMISSIONS_FEATURES_AUTOMATIONS', 'False').lower() == 'true'
USER_PERMISSIONS_FEATURES_CALENDAR = os.getenv('USER_PERMISSIONS_FEATURES_CALENDAR', 'True').lower() == 'true'
@@ -2940,9 +2939,7 @@ WEBHOOK_URL = ConfigVar('WEBHOOK_URL', 'webhook_url', os.getenv('WEBHOOK_URL', '
ENABLE_ADMIN_EXPORT = os.getenv('ENABLE_ADMIN_EXPORT', 'True').lower() == 'true'
ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = (
os.getenv('ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS', 'True').lower() == 'true'
)
ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS = os.getenv('ENABLE_ADMIN_WORKSPACE_CONTENT_ACCESS', 'True').lower() == 'true'
BYPASS_ADMIN_ACCESS_CONTROL = (
os.getenv(
@@ -3024,7 +3021,7 @@ else:
class BannerModel(BaseModel):
id: str
type: str
title: Optional[str] = None
title: str | None = None
content: str
dismissible: bool
timestamp: int
@@ -3705,9 +3702,7 @@ OAUTH_ALLOWED_ROLES = ConfigVar(
'oauth.allowed_roles',
[
role.strip()
for role in os.getenv('OAUTH_ALLOWED_ROLES', f'user{OAUTH_ROLES_SEPARATOR}admin').split(
OAUTH_ROLES_SEPARATOR
)
for role in os.getenv('OAUTH_ALLOWED_ROLES', f'user{OAUTH_ROLES_SEPARATOR}admin').split(OAUTH_ROLES_SEPARATOR)
if role
],
)
+2
View File
@@ -1,3 +1,5 @@
from __future__ import annotations
from enum import Enum
+3 -9
View File
@@ -530,9 +530,7 @@ except (ValueError, TypeError):
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER_DATA = 10
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = (
os.getenv('AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL', 'True').lower() == 'true'
)
AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL = os.getenv('AIOHTTP_CLIENT_SESSION_TOOL_SERVER_SSL', 'True').lower() == 'true'
AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER = os.getenv('AIOHTTP_CLIENT_TIMEOUT_TOOL_SERVER', '')
@@ -779,9 +777,7 @@ FORWARD_USER_INFO_HEADER_USER_NAME = os.getenv('FORWARD_USER_INFO_HEADER_USER_NA
FORWARD_USER_INFO_HEADER_USER_ID = os.getenv('FORWARD_USER_INFO_HEADER_USER_ID', 'X-OpenWebUI-User-Id')
FORWARD_USER_INFO_HEADER_USER_EMAIL = os.getenv('FORWARD_USER_INFO_HEADER_USER_EMAIL', 'X-OpenWebUI-User-Email')
FORWARD_USER_INFO_HEADER_USER_ROLE = os.getenv('FORWARD_USER_INFO_HEADER_USER_ROLE', 'X-OpenWebUI-User-Role')
FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.getenv(
'FORWARD_SESSION_INFO_HEADER_MESSAGE_ID', 'X-OpenWebUI-Message-Id'
)
FORWARD_SESSION_INFO_HEADER_MESSAGE_ID = os.getenv('FORWARD_SESSION_INFO_HEADER_MESSAGE_ID', 'X-OpenWebUI-Message-Id')
FORWARD_SESSION_INFO_HEADER_CHAT_ID = os.getenv('FORWARD_SESSION_INFO_HEADER_CHAT_ID', 'X-OpenWebUI-Chat-Id')
####################################
@@ -897,9 +893,7 @@ if SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND == '':
SENTENCE_TRANSFORMERS_CROSS_ENCODER_BACKEND = 'torch'
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.getenv(
'SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS', ''
)
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = os.getenv('SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS', '')
if SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS == '':
SENTENCE_TRANSFORMERS_CROSS_ENCODER_MODEL_KWARGS = None
else:
+15 -21
View File
@@ -1,11 +1,10 @@
import logging
import sys
import asyncio
import inspect
import json
import asyncio
from pydantic import BaseModel
import logging
import sys
from typing import AsyncGenerator, Generator, Iterator
from fastapi import (
Depends,
FastAPI,
@@ -16,40 +15,35 @@ from fastapi import (
UploadFile,
status,
)
from pydantic import BaseModel
from starlette.responses import Response, StreamingResponse
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
from open_webui.constants import ERROR_MESSAGES
from open_webui.env import BYPASS_MODEL_ACCESS_CONTROL, GLOBAL_LOG_LEVEL
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.models.users import UserModel
from open_webui.socket.main import (
get_event_call,
get_event_emitter,
)
from open_webui.models.users import UserModel
from open_webui.models.functions import Functions
from open_webui.models.models import Models
from open_webui.utils.plugin import (
load_function_module_by_id,
get_function_module_from_cache,
)
from open_webui.utils.access_control import check_model_access
from open_webui.env import GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
from open_webui.config import BYPASS_ADMIN_ACCESS_CONTROL
from open_webui.utils.misc import (
add_or_update_system_message,
get_last_user_message,
prepend_to_first_user_message_content,
openai_chat_chunk_message_template,
openai_chat_completion_message_template,
prepend_to_first_user_message_content,
)
from open_webui.utils.payload import (
apply_model_params_to_body_openai,
apply_system_prompt_to_body,
)
from open_webui.utils.plugin import (
get_function_module_from_cache,
load_function_module_by_id,
)
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
+28 -29
View File
@@ -10,10 +10,9 @@ from functools import reduce
from typing import Any, Optional, Union
import redis
from sqlalchemy import JSON, Column, DateTime, Integer, func, select
from open_webui.internal.db import Base, get_async_db, get_db
from open_webui.utils.redis import get_redis_connection
from sqlalchemy import JSON, Column, DateTime, Integer, func, select
log = logging.getLogger(__name__)
@@ -22,7 +21,7 @@ log = logging.getLogger(__name__)
class ConfigTable(Base):
__tablename__ = "config"
__tablename__ = 'config'
id = Column(Integer, primary_key=True)
data = Column(JSON, nullable=False)
@@ -37,7 +36,7 @@ class ConfigTable(Base):
class ConfigState:
"""In-memory mirror of the single-row config JSON blob."""
__slots__ = ("_data",)
__slots__ = ('_data',)
def __init__(self) -> None:
self._data: dict[str, Any] = {}
@@ -49,12 +48,12 @@ class ConfigState:
def read(self, path: str) -> Any:
return reduce(
lambda n, k: n.get(k) if isinstance(n, dict) else None,
path.split("."),
path.split('.'),
self._data,
)
def write(self, path: str, value: Any) -> None:
keys = path.split(".")
keys = path.split('.')
reduce(lambda d, k: d.setdefault(k, {}), keys[:-1], self._data)[keys[-1]] = value
def replace(self, data: dict) -> None:
@@ -63,7 +62,7 @@ class ConfigState:
def load(self) -> dict:
with get_db() as db:
row = db.query(ConfigTable).order_by(ConfigTable.id.desc()).first()
self._data = row.data if row else {"version": 0, "ui": {}}
self._data = row.data if row else {'version': 0, 'ui': {}}
return self._data
def persist(self, data: dict | None = None) -> None:
@@ -123,7 +122,7 @@ def initialize(*, enable_persistent: bool = True, enable_oauth_persistent: bool
class ConfigVar:
__slots__ = ("env_name", "config_path", "env_value", "config_value", "value")
__slots__ = ('env_name', 'config_path', 'env_value', 'config_value', 'value')
def __init__(self, env_name: str, config_path: str, env_value: Any) -> None:
self.env_name = env_name
@@ -132,7 +131,7 @@ class ConfigVar:
self.config_value = STATE.read(config_path)
if self.config_value is not None and _persist_enabled:
if config_path.startswith("oauth.") and not _oauth_persist_enabled:
if config_path.startswith('oauth.') and not _oauth_persist_enabled:
log.info("Skipping DB value for '%s' (OAuth persistence disabled)", env_name)
self.value = env_value
else:
@@ -147,22 +146,22 @@ class ConfigVar:
return str(self.value)
def __repr__(self) -> str:
return f"<ConfigVar {self.env_name}={self.value!r}>"
return f'<ConfigVar {self.env_name}={self.value!r}>'
@property
def __dict__(self): # type: ignore[override]
raise TypeError(f"ConfigVar('{self.env_name}') cannot be cast to dict; use .value")
def __getattribute__(self, item: str):
if item == "__dict__":
raise TypeError("ConfigVar cannot be cast to dict; use .value")
if item == '__dict__':
raise TypeError('ConfigVar cannot be cast to dict; use .value')
return super().__getattribute__(item)
def refresh(self) -> None:
current = STATE.read(self.config_path)
if current is not None:
self.value = current
log.info("Refreshed %s%s", self.env_name, self.value)
log.info('Refreshed %s%s', self.env_name, self.value)
def commit(self) -> None:
log.info("Persisting '%s'", self.env_name)
@@ -189,18 +188,18 @@ class AppConfig:
redis_url: Optional[str] = None,
redis_sentinels: Optional[list] = None,
redis_cluster: bool = False,
redis_key_prefix: str = "open-webui",
redis_key_prefix: str = 'open-webui',
) -> None:
super().__setattr__("_entries", {})
super().__setattr__("_key_prefix", redis_key_prefix)
super().__setattr__('_entries', {})
super().__setattr__('_key_prefix', redis_key_prefix)
rc: Union[redis.Redis, redis.cluster.RedisCluster, None] = None
if redis_url:
rc = get_redis_connection(redis_url, redis_sentinels or [], redis_cluster, decode_responses=True)
super().__setattr__("_rc", rc)
super().__setattr__('_rc', rc)
def __setattr__(self, name: str, value: Any) -> None:
entries: dict = super().__getattribute__("_entries")
entries: dict = super().__getattribute__('_entries')
if isinstance(value, ConfigVar):
entries[name] = value
@@ -213,11 +212,11 @@ class AppConfig:
except RuntimeError:
entries[name].commit()
rc = super().__getattribute__("_rc")
rc = super().__getattribute__('_rc')
if rc and _persist_enabled:
prefix = super().__getattribute__("_key_prefix")
prefix = super().__getattribute__('_key_prefix')
try:
rc.set(f"{prefix}:config:{name}", json.dumps(entries[name].value))
rc.set(f'{prefix}:config:{name}', json.dumps(entries[name].value))
except Exception as exc:
log.error("Redis write failed for '%s': %s", name, exc)
@@ -228,15 +227,15 @@ class AppConfig:
log.error("Async persist failed for '%s': %s", name, exc)
def __getattr__(self, name: str) -> Any:
entries = super().__getattribute__("_entries")
entries = super().__getattribute__('_entries')
if name not in entries:
raise AttributeError(f"No config key '{name}'")
rc = super().__getattribute__("_rc")
rc = super().__getattribute__('_rc')
if rc and _persist_enabled:
prefix = super().__getattribute__("_key_prefix")
prefix = super().__getattribute__('_key_prefix')
try:
raw = rc.get(f"{prefix}:config:{name}")
raw = rc.get(f'{prefix}:config:{name}')
if raw is not None:
decoded = json.loads(raw)
if entries[name].value != decoded:
@@ -248,12 +247,12 @@ class AppConfig:
return entries[name].value
def _sync_to_redis(self) -> None:
rc = super().__getattribute__("_rc")
rc = super().__getattribute__('_rc')
if not rc or not _persist_enabled:
return
prefix = super().__getattribute__("_key_prefix")
for name, s in super().__getattribute__("_entries").items():
prefix = super().__getattribute__('_key_prefix')
for name, s in super().__getattribute__('_entries').items():
try:
rc.set(f"{prefix}:config:{name}", json.dumps(s.value))
rc.set(f'{prefix}:config:{name}', json.dumps(s.value))
except Exception as exc:
log.error("Redis sync failed for '%s': %s", name, exc)
+19 -18
View File
@@ -1,35 +1,36 @@
import os
import sys
from __future__ import annotations
import json
import logging
import os
import sys
from contextlib import asynccontextmanager, contextmanager
from typing import Any, Optional
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
from open_webui.env import (
OPEN_WEBUI_DIR,
DATABASE_URL,
DATABASE_SCHEMA,
DATABASE_ENABLE_SESSION_SHARING,
DATABASE_ENABLE_SQLITE_WAL,
DATABASE_POOL_MAX_OVERFLOW,
DATABASE_POOL_RECYCLE,
DATABASE_POOL_SIZE,
DATABASE_POOL_TIMEOUT,
DATABASE_ENABLE_SQLITE_WAL,
DATABASE_ENABLE_SESSION_SHARING,
DATABASE_SQLITE_PRAGMA_SYNCHRONOUS,
DATABASE_SCHEMA,
DATABASE_SQLITE_PRAGMA_BUSY_TIMEOUT,
DATABASE_SQLITE_PRAGMA_CACHE_SIZE,
DATABASE_SQLITE_PRAGMA_TEMP_STORE,
DATABASE_SQLITE_PRAGMA_MMAP_SIZE,
DATABASE_SQLITE_PRAGMA_JOURNAL_SIZE_LIMIT,
DATABASE_SQLITE_PRAGMA_MMAP_SIZE,
DATABASE_SQLITE_PRAGMA_SYNCHRONOUS,
DATABASE_SQLITE_PRAGMA_TEMP_STORE,
DATABASE_URL,
ENABLE_DB_MIGRATIONS,
OPEN_WEBUI_DIR,
)
from sqlalchemy import Dialect, create_engine, MetaData, event, types
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import Dialect, MetaData, create_engine, event, types
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker, Session
from sqlalchemy.pool import QueuePool, NullPool
from sqlalchemy.orm import Session, scoped_session, sessionmaker
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.sql.type_api import _T
from typing_extensions import Self
@@ -120,10 +121,10 @@ class JSONField(types.TypeDecorator):
impl = types.Text
cache_ok = True
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any:
return json.dumps(value)
def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
def process_result_value(self, value: _T | None, dialect: Dialect) -> Any:
if value is not None:
return json.loads(value)
@@ -374,7 +375,7 @@ async def get_async_db():
@asynccontextmanager
async def get_async_db_context(db: Optional[AsyncSession] = None):
async def get_async_db_context(db: AsyncSession | None = None):
"""Async context manager that reuses an existing session if provided and session sharing is enabled."""
if isinstance(db, AsyncSession) and DATABASE_ENABLE_SESSION_SHARING:
yield db
+523 -538
View File
File diff suppressed because it is too large Load Diff
+13 -13
View File
@@ -1,3 +1,5 @@
from __future__ import annotations
"""Alembic environment configuration.
Configures the migration context for both offline (SQL script generation)
@@ -23,7 +25,7 @@ if config.config_file_name is not None:
fileConfig(config.config_file_name, disable_existing_loggers=False)
# Re-apply JSON formatter after fileConfig replaces handlers.
if LOG_FORMAT == "json":
if LOG_FORMAT == 'json':
from open_webui.env import JSONFormatter
for handler in logging.root.handlers:
@@ -41,7 +43,7 @@ if _ssl_params:
DB_URL = reattach_ssl_params_to_url(_url_no_ssl, _ssl_params)
if DB_URL:
config.set_main_option("sqlalchemy.url", DB_URL.replace("%", "%%"))
config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))
# ── Migration runners ────────────────────────────────────────────────────────
@@ -49,12 +51,12 @@ if DB_URL:
def run_migrations_offline() -> None:
"""Generate SQL script without a live database connection."""
url = config.get_main_option("sqlalchemy.url")
url = config.get_main_option('sqlalchemy.url')
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
dialect_opts={'paramstyle': 'named'},
)
with context.begin_transaction():
context.run_migrations()
@@ -62,14 +64,12 @@ def run_migrations_offline() -> None:
def _build_connectable():
"""Create the appropriate SQLAlchemy engine for the configured DB URL."""
if DB_URL and DB_URL.startswith("sqlite+sqlcipher://"):
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == "":
raise ValueError(
"DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs"
)
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '':
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
db_path = DB_URL.replace("sqlite+sqlcipher://", "")
if db_path.startswith("/"):
db_path = DB_URL.replace('sqlite+sqlcipher://', '')
if db_path.startswith('/'):
db_path = db_path[1:]
def _sqlcipher_creator():
@@ -79,11 +79,11 @@ def _build_connectable():
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
return conn
return create_engine("sqlite://", creator=_sqlcipher_creator, echo=False)
return create_engine('sqlite://', creator=_sqlcipher_creator, echo=False)
return engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
prefix='sqlalchemy.',
poolclass=pool.NullPool,
)
+2
View File
@@ -1,3 +1,5 @@
from __future__ import annotations
"""Alembic migration utilities."""
from alembic import op
@@ -6,8 +6,8 @@ Create Date: 2025-08-13 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = '018012973d35'
down_revision = 'd31026856c01'
@@ -6,13 +6,13 @@ Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
import sqlalchemy as sa
from alembic import op
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import column, select, table, update
revision = '1af9b942657b'
down_revision = '242a2047eae0'
branch_labels = None
@@ -6,12 +6,12 @@ Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update
import json
import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import select, table, update
revision = '242a2047eae0'
down_revision = '6a39f3d8e55c'
branch_labels = None
@@ -8,9 +8,9 @@ Create Date: 2025-11-27 03:07:56.200231
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '2f1211949ecc'
@@ -6,11 +6,11 @@ Create Date: 2026-01-23 17:15:00.000000
"""
from typing import Sequence, Union
import uuid
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
revision: str = '374d2f66af06'
down_revision: Union[str, None] = 'c440947495f3'
@@ -6,8 +6,8 @@ Create Date: 2024-12-30 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = '3781e22d8b01'
down_revision = '7826ab40b532'
@@ -6,13 +6,13 @@ Create Date: 2025-11-17 03:45:25.123939
"""
import uuid
import time
import json
import time
import uuid
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '37f288994c47'
@@ -8,8 +8,8 @@ Create Date: 2025-09-08 14:19:59.583921
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '38d63c18f30f'
@@ -6,13 +6,13 @@ Create Date: 2024-10-09 21:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, select, update, column
from sqlalchemy.engine.reflection import Inspector
import json
import sqlalchemy as sa
from alembic import op
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql import column, select, table, update
revision = '3ab32c4b8f59'
down_revision = '1af9b942657b'
branch_labels = None
@@ -8,8 +8,8 @@ Create Date: 2025-08-21 02:07:18.078283
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '3af16a1c9fb6'
@@ -6,16 +6,15 @@ Create Date: 2025-12-02 06:54:19.401334
"""
import json
import time
import uuid
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy import inspect
import open_webui.internal.db
import time
import json
import uuid
import sqlalchemy as sa
from alembic import op
from sqlalchemy import inspect
# revision identifiers, used by Alembic.
revision: str = '3e0e00844bb0'
@@ -6,8 +6,8 @@ Create Date: 2024-10-23 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = '4ace53fd72c8'
down_revision = 'af906e964978'
@@ -8,10 +8,9 @@ Create Date: 2026-05-09 04:29:27.651341
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '4de81c2a3af1'
@@ -20,10 +19,11 @@ branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
import uuid
import time
from sqlalchemy import select, update, insert
from sqlalchemy.sql import table, column
import uuid
from sqlalchemy import insert, select, update
from sqlalchemy.sql import column, table
def upgrade() -> None:
@@ -8,9 +8,8 @@ Create Date: 2026-04-19 16:20:58.162045
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '56359461a091'
@@ -6,8 +6,8 @@ Create Date: 2024-12-22 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = '57c599a3cb57'
down_revision = '922e7a387820'
@@ -8,9 +8,9 @@ Create Date: 2025-12-10 15:11:39.424601
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '6283dc0e4d8d'
@@ -6,11 +6,12 @@ Create Date: 2024-10-01 14:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select
import json
import sqlalchemy as sa
from alembic import op
from sqlalchemy.sql import column, select, table
revision = '6a39f3d8e55c'
down_revision = 'c0fbf31ca0db'
branch_labels = None
@@ -6,8 +6,8 @@ Create Date: 2024-12-23 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = '7826ab40b532'
down_revision = '57c599a3cb57'
@@ -1,3 +1,5 @@
from __future__ import annotations
"""Initial Alembic schema — creates all base tables.
Revision ID: 7e5b5dc7342b
@@ -7,14 +9,13 @@ Create Date: 2024-06-24 13:15:33.808998
from typing import Sequence, Union
import open_webui.internal.db # noqa: F401
import sqlalchemy as sa
from alembic import op
import open_webui.internal.db # noqa: F401
from open_webui.internal.db import JSONField
from open_webui.migrations.util import get_existing_tables
revision: str = "7e5b5dc7342b"
revision: str = '7e5b5dc7342b'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -25,162 +26,162 @@ depends_on: Union[str, Sequence[str], None] = None
_TABLES: list[tuple[str, list[sa.Column], list]] = [
(
"auth",
'auth',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.Text(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('email', sa.String(), nullable=True),
sa.Column('password', sa.Text(), nullable=True),
sa.Column('active', sa.Boolean(), nullable=True),
],
[sa.PrimaryKeyConstraint("id")],
[sa.PrimaryKeyConstraint('id')],
),
(
"chat",
'chat',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("chat", sa.Text(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.Text(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('chat', sa.Text(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('share_id', sa.Text(), nullable=True),
sa.Column('archived', sa.Boolean(), nullable=True),
],
[sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("share_id")],
[sa.PrimaryKeyConstraint('id'), sa.UniqueConstraint('share_id')],
),
(
"chatidtag",
'chatidtag',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('tag_name', sa.String(), nullable=True),
sa.Column('chat_id', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
],
[sa.PrimaryKeyConstraint("id")],
[sa.PrimaryKeyConstraint('id')],
),
(
"document",
'document',
[
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.Column('collection_name', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('filename', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
],
[sa.PrimaryKeyConstraint("collection_name"), sa.UniqueConstraint("name")],
[sa.PrimaryKeyConstraint('collection_name'), sa.UniqueConstraint('name')],
),
(
"file",
'file',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("filename", sa.Text(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.Text(), nullable=True),
sa.Column('meta', JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
],
[sa.PrimaryKeyConstraint("id")],
[sa.PrimaryKeyConstraint('id')],
),
(
"function",
'function',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("type", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("valves", JSONField(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("is_global", sa.Boolean(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', JSONField(), nullable=True),
sa.Column('valves', JSONField(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('is_global', sa.Boolean(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
],
[sa.PrimaryKeyConstraint("id")],
[sa.PrimaryKeyConstraint('id')],
),
(
"memory",
'memory',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
],
[sa.PrimaryKeyConstraint("id")],
[sa.PrimaryKeyConstraint('id')],
),
(
"model",
'model',
[
sa.Column("id", sa.Text(), nullable=False),
sa.Column("user_id", sa.Text(), nullable=True),
sa.Column("base_model_id", sa.Text(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("params", JSONField(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column('id', sa.Text(), nullable=False),
sa.Column('user_id', sa.Text(), nullable=True),
sa.Column('base_model_id', sa.Text(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('params', JSONField(), nullable=True),
sa.Column('meta', JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
],
[sa.PrimaryKeyConstraint("id")],
[sa.PrimaryKeyConstraint('id')],
),
(
"prompt",
'prompt',
[
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.Column('command', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
],
[sa.PrimaryKeyConstraint("command")],
[sa.PrimaryKeyConstraint('command')],
),
(
"tag",
'tag',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.Text(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('data', sa.Text(), nullable=True),
],
[sa.PrimaryKeyConstraint("id")],
[sa.PrimaryKeyConstraint('id')],
),
(
"tool",
'tool',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.Text(), nullable=True),
sa.Column("content", sa.Text(), nullable=True),
sa.Column("specs", JSONField(), nullable=True),
sa.Column("meta", JSONField(), nullable=True),
sa.Column("valves", JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('specs', JSONField(), nullable=True),
sa.Column('meta', JSONField(), nullable=True),
sa.Column('valves', JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
],
[sa.PrimaryKeyConstraint("id")],
[sa.PrimaryKeyConstraint('id')],
),
(
"user",
'user',
[
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.Text(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", JSONField(), nullable=True),
sa.Column("info", JSONField(), nullable=True),
sa.Column("oauth_sub", sa.Text(), nullable=True),
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('email', sa.String(), nullable=True),
sa.Column('role', sa.String(), nullable=True),
sa.Column('profile_image_url', sa.Text(), nullable=True),
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('api_key', sa.String(), nullable=True),
sa.Column('settings', JSONField(), nullable=True),
sa.Column('info', JSONField(), nullable=True),
sa.Column('oauth_sub', sa.Text(), nullable=True),
],
[
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
sa.UniqueConstraint("oauth_sub"),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('api_key'),
sa.UniqueConstraint('oauth_sub'),
],
),
]
@@ -8,9 +8,9 @@ Create Date: 2025-12-10 16:07:58.001282
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '81cc2ce44d79'
@@ -6,13 +6,13 @@ Create Date: 2026-02-01 04:00:00.000000
"""
import time
import json
import logging
import time
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
log = logging.getLogger(__name__)
@@ -8,9 +8,9 @@ Create Date: 2025-11-30 06:33:38.790341
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = '90ef40d4714e'
@@ -6,8 +6,8 @@ Create Date: 2024-11-14 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = '922e7a387820'
down_revision = '4ace53fd72c8'
@@ -6,8 +6,8 @@ Create Date: 2025-05-03 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = '9f0c9cd09105'
down_revision = '3781e22d8b01'
@@ -8,9 +8,8 @@ Create Date: 2026-02-11 09:30:00.000000
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
from open_webui.migrations.util import get_existing_tables
revision: str = 'a1b2c3d4e5f6'
@@ -8,8 +8,8 @@ Create Date: 2026-03-29 22:15:00.000000
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'a3dd5bedd151'
@@ -8,8 +8,8 @@ Create Date: 2025-09-27 02:24:18.058455
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'a5c220713937'
@@ -6,8 +6,8 @@ Create Date: 2024-10-20 17:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
# Revision identifiers, used by Alembic.
revision = 'af906e964978'
@@ -6,15 +6,13 @@ Create Date: 2025-11-28 04:55:31.737538
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import open_webui.internal.db
import json
import time
from typing import Sequence, Union
import open_webui.internal.db
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'b10670c03dd5'
@@ -8,8 +8,8 @@ Create Date: 2026-02-13 14:19:00.000000
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'b2c3d4e5f6a7'
@@ -6,9 +6,8 @@ Create Date: 2026-04-01 04:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = 'b7c8d9e0f1a2'
@@ -9,8 +9,8 @@ Create Date: 2026-04-16 23:00:00.000000
import time
import uuid
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = 'c1d2e3f4a5b6'
down_revision = 'e1f2a3b4c5d6'
@@ -6,11 +6,12 @@ Create Date: 2024-10-20 17:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
import json
from sqlalchemy.sql import table, column
from sqlalchemy import String, Text, JSON, and_
import sqlalchemy as sa
from alembic import op
from sqlalchemy import JSON, String, Text, and_
from sqlalchemy.sql import column, table
revision = 'c29facfe716b'
down_revision = 'c69f45358db4'
@@ -8,8 +8,8 @@ Create Date: 2025-12-21 20:27:41.694897
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = 'c440947495f3'
@@ -6,8 +6,8 @@ Create Date: 2024-10-16 02:02:35.241684
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = 'c69f45358db4'
down_revision = '3ab32c4b8f59'
@@ -10,8 +10,8 @@ from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "ca81bd47c050"
down_revision: Union[str, None] = "7e5b5dc7342b"
revision: str = 'ca81bd47c050'
down_revision: Union[str, None] = '7e5b5dc7342b'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
@@ -19,18 +19,18 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Create a key-value config table with versioning."""
op.create_table(
"config",
sa.Column("id", sa.Integer, primary_key=True),
sa.Column("data", sa.JSON(), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
'config',
sa.Column('id', sa.Integer, primary_key=True),
sa.Column('data', sa.JSON(), nullable=False),
sa.Column('version', sa.Integer, nullable=False),
sa.Column(
"created_at",
'created_at',
sa.DateTime(),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
'updated_at',
sa.DateTime(),
nullable=True,
server_default=sa.func.now(),
@@ -41,4 +41,4 @@ def upgrade() -> None:
def downgrade() -> None:
"""Drop the config table."""
op.drop_table("config")
op.drop_table('config')
@@ -6,8 +6,8 @@ Create Date: 2025-07-13 03:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = 'd31026856c01'
down_revision = '9f0c9cd09105'
@@ -7,8 +7,8 @@ Create Date: 2026-03-30
from typing import Union
from alembic import op
import sqlalchemy as sa
from alembic import op
revision: str = 'd4e5f6a7b8c9'
down_revision: Union[str, None] = 'a3dd5bedd151'
@@ -6,8 +6,8 @@ Create Date: 2026-04-14 22:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
revision = 'e1f2a3b4c5d6'
down_revision = 'b7c8d9e0f1a2'
@@ -11,13 +11,12 @@ Access control semantics:
- {read: {...}, write: {...}}: Custom permissions -> insert specific grants
"""
from typing import Sequence, Union
import time
import uuid
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from alembic import op
from open_webui.migrations.util import get_existing_tables
revision: str = 'f1e2d3c4b5a6'
+3 -5
View File
@@ -3,13 +3,11 @@ import time
import uuid
from typing import Optional
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, UniqueConstraint, or_, and_
from sqlalchemy import BigInteger, Column, Text, UniqueConstraint, and_, delete, or_, select
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
@@ -623,8 +621,8 @@ class AccessGrantsTable:
Get all users who have the specified permission on a resource.
Returns a list of UserModel instances.
"""
from open_webui.models.users import Users, UserModel
from open_webui.models.groups import Groups
from open_webui.models.users import UserModel, Users
async with get_async_db_context(db) as db:
result = await db.execute(
+18 -19
View File
@@ -1,14 +1,15 @@
from __future__ import annotations
import logging
import uuid
from typing import Optional
from sqlalchemy import select, delete, update
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.users import User, UserModel, UserProfileImageResponse, Users
from open_webui.utils.validate import validate_profile_image_url
from pydantic import BaseModel, field_validator
from sqlalchemy import Boolean, Column, String, Text
from sqlalchemy import Boolean, Column, String, Text, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
@@ -44,7 +45,7 @@ class Token(BaseModel):
class ApiKey(BaseModel):
api_key: Optional[str] = None
api_key: str | None = None
class SigninResponse(Token, UserProfileImageResponse):
@@ -74,18 +75,18 @@ class SignupForm(BaseModel):
name: str
email: str
password: str
profile_image_url: Optional[str] = '/user.png'
profile_image_url: str | None = '/user.png'
@field_validator('profile_image_url')
@classmethod
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
def check_profile_image_url(cls, v: str | None) -> str | None:
if v is not None:
return validate_profile_image_url(v)
return v
class AddUserForm(SignupForm):
role: Optional[str] = 'pending'
role: str | None = 'pending'
class AuthsTable:
@@ -96,9 +97,9 @@ class AuthsTable:
name: str,
profile_image_url: str = '/user.png',
role: str = 'pending',
oauth: Optional[dict] = None,
db: Optional[AsyncSession] = None,
) -> Optional[UserModel]:
oauth: dict | None = None,
db: AsyncSession | None = None,
) -> UserModel | None:
async with get_async_db_context(db) as db:
log.info('insert_new_auth')
@@ -119,8 +120,8 @@ class AuthsTable:
return None
async def authenticate_user(
self, email: str, verify_password: callable, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
self, email: str, verify_password: callable, db: AsyncSession | None = None
) -> UserModel | None:
log.info(f'authenticate_user: {email}')
user = await Users.get_user_by_email(email, db=db)
@@ -141,9 +142,7 @@ class AuthsTable:
except Exception:
return None
async def authenticate_user_by_api_key(
self, api_key: str, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
async def authenticate_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None:
log.info(f'authenticate_user_by_api_key')
# if no api_key, return None
if not api_key:
@@ -155,7 +154,7 @@ class AuthsTable:
except Exception:
return False
async def authenticate_user_by_email(self, email: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
async def authenticate_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None:
log.info(f'authenticate_user_by_email: {email}')
try:
async with get_async_db_context(db) as db:
@@ -171,7 +170,7 @@ class AuthsTable:
except Exception:
return None
async def update_user_password_by_id(self, id: str, new_password: str, db: Optional[AsyncSession] = None) -> bool:
async def update_user_password_by_id(self, id: str, new_password: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
result = await db.execute(update(Auth).filter_by(id=id).values(password=new_password))
@@ -180,7 +179,7 @@ class AuthsTable:
except Exception:
return False
async def update_email_by_id(self, id: str, email: str, db: Optional[AsyncSession] = None) -> bool:
async def update_email_by_id(self, id: str, email: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
result = await db.execute(update(Auth).filter_by(id=id).values(email=email))
@@ -192,7 +191,7 @@ class AuthsTable:
except Exception:
return False
async def delete_auth_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_auth_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
# Delete User
+4 -5
View File
@@ -1,13 +1,12 @@
import time
import logging
import time
from typing import Optional
from uuid import uuid4
from pydantic import BaseModel, ConfigDict
from sqlalchemy import Column, Text, JSON, Boolean, BigInteger, Index, select, or_, func, cast, String, delete, update
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import JSON, BigInteger, Boolean, Column, Index, String, Text, cast, delete, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
+18 -19
View File
@@ -1,30 +1,29 @@
import time
import logging
import time
from typing import Optional
from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import (
Column,
Text,
JSON,
Boolean,
BigInteger,
Index,
UniqueConstraint,
select,
or_,
exists,
func,
delete,
update,
)
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, UserResponse
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
Column,
Index,
Text,
UniqueConstraint,
delete,
exists,
func,
or_,
select,
update,
)
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
+12 -10
View File
@@ -4,31 +4,33 @@ import time
import uuid
from typing import Optional
from open_webui.utils.validate import validate_profile_image_url
from sqlalchemy import select, delete, update, func, case, or_, and_
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.groups import Groups
from open_webui.models.access_grants import (
AccessGrantModel,
AccessGrants,
)
from open_webui.models.groups import Groups
from open_webui.utils.validate import validate_profile_image_url
from pydantic import BaseModel, ConfigDict, Field, field_validator
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
Column,
ForeignKey,
String,
Text,
JSON,
UniqueConstraint,
and_,
case,
delete,
func,
or_,
select,
update,
)
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
####################
# Channel DB Schema
+9 -5
View File
@@ -3,21 +3,24 @@ import time
import uuid
from typing import Any, Optional
from sqlalchemy import select, delete, func, cast, Integer
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from open_webui.utils.response import normalize_usage
from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
Column,
ForeignKey,
Text,
JSON,
Index,
Integer,
Text,
cast,
delete,
func,
select,
)
from sqlalchemy.ext.asyncio import AsyncSession
####################
# Helpers
@@ -579,6 +582,7 @@ class ChatMessageTable:
"""Get message counts grouped by day and model."""
async with get_async_db_context(db) as db:
from datetime import datetime, timedelta
from open_webui.models.groups import GroupMember
stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter(
+116 -112
View File
@@ -1,32 +1,39 @@
import logging
from __future__ import annotations
import json
import logging
import time
import uuid
from typing import Optional
from sqlalchemy import select, delete, update, func, or_, and_, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import exists
from sqlalchemy.sql.expression import bindparam
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.folders import Folders
from open_webui.models.chat_messages import ChatMessage, ChatMessages
from open_webui.models.automations import AutomationRun
from open_webui.models.chat_messages import ChatMessage, ChatMessages
from open_webui.models.folders import Folders
from open_webui.models.tags import Tag, TagModel, Tags
from open_webui.utils.misc import sanitize_data_for_db, sanitize_text_for_db
from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
JSON,
BigInteger,
Boolean,
Column,
ForeignKey,
Index,
String,
Text,
JSON,
Index,
UniqueConstraint,
and_,
delete,
func,
or_,
select,
text,
update,
)
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import exists
from sqlalchemy.sql.expression import bindparam
####################
# Chat DB Schema
@@ -81,17 +88,17 @@ class ChatModel(BaseModel):
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
share_id: Optional[str] = None
share_id: str | None = None
archived: bool = False
pinned: Optional[bool] = False
pinned: bool | None = False
meta: dict = {}
folder_id: Optional[str] = None
folder_id: str | None = None
tasks: Optional[list] = None
summary: Optional[str] = None
tasks: list | None = None
summary: str | None = None
last_read_at: Optional[int] = None
last_read_at: int | None = None
class ChatFile(Base):
@@ -115,7 +122,7 @@ class ChatFileModel(BaseModel):
user_id: str
chat_id: str
message_id: Optional[str] = None
message_id: str | None = None
file_id: str
created_at: int
@@ -131,14 +138,14 @@ class ChatFileModel(BaseModel):
class ChatForm(BaseModel):
chat: dict
folder_id: Optional[str] = None
folder_id: str | None = None
class ChatImportForm(ChatForm):
meta: Optional[dict] = {}
pinned: Optional[bool] = False
created_at: Optional[int] = None
updated_at: Optional[int] = None
meta: dict | None = {}
pinned: bool | None = False
created_at: int | None = None
updated_at: int | None = None
class ChatsImportForm(BaseModel):
@@ -161,14 +168,14 @@ class ChatResponse(BaseModel):
chat: dict
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
share_id: str | None = None # id of the chat to be shared
archived: bool
pinned: Optional[bool] = False
pinned: bool | None = False
meta: dict = {}
folder_id: Optional[str] = None
folder_id: str | None = None
tasks: Optional[list] = None
summary: Optional[str] = None
tasks: list | None = None
summary: str | None = None
class ChatTitleIdResponse(BaseModel):
@@ -176,13 +183,13 @@ class ChatTitleIdResponse(BaseModel):
title: str
updated_at: int
created_at: int
last_read_at: Optional[int] = None
last_read_at: int | None = None
class SharedChatResponse(BaseModel):
id: str
title: str
share_id: Optional[str] = None
share_id: str | None = None
updated_at: int
created_at: int
@@ -225,17 +232,17 @@ class ChatUsageStatsListResponse(BaseModel):
class MessageStats(BaseModel):
id: str
role: str
model: Optional[str] = None
model: str | None = None
content_length: int
token_count: Optional[int] = None
timestamp: Optional[int] = None
rating: Optional[int] = None # Derived from message.annotation.rating
tags: Optional[list[str]] = None # Derived from message.annotation.tags
token_count: int | None = None
timestamp: int | None = None
rating: int | None = None # Derived from message.annotation.rating
tags: list[str | None] = None # Derived from message.annotation.tags
class ChatHistoryStats(BaseModel):
messages: dict[str, MessageStats]
currentId: Optional[str] = None
currentId: str | None = None
class ChatBody(BaseModel):
@@ -293,8 +300,8 @@ class ChatTable:
return changed
async def insert_new_chat(
self, id: str, user_id: str, form_data: ChatForm, db: Optional[AsyncSession] = None
) -> Optional[ChatModel]:
self, id: str, user_id: str, form_data: ChatForm, db: AsyncSession | None = None
) -> ChatModel | None:
async with get_async_db_context(db) as db:
chat = ChatModel(
**{
@@ -353,7 +360,7 @@ class ChatTable:
self,
user_id: str,
chat_import_forms: list[ChatImportForm],
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[ChatModel]:
async with get_async_db_context(db) as db:
chats = []
@@ -383,7 +390,7 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in chats]
async def update_chat_by_id(self, id: str, chat: dict, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
async def update_chat_by_id(self, id: str, chat: dict, db: AsyncSession | None = None) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat_item = await db.get(Chat, id)
@@ -398,7 +405,7 @@ class ChatTable:
except Exception:
return None
async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
@@ -410,7 +417,7 @@ class ChatTable:
except Exception:
return False
async def update_chat_title_by_id(self, id: str, title: str) -> Optional[ChatModel]:
async def update_chat_title_by_id(self, id: str, title: str) -> ChatModel | None:
try:
async with get_async_db_context() as db:
chat_item = await db.get(Chat, id)
@@ -426,7 +433,7 @@ class ChatTable:
except Exception:
return None
async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> Optional[ChatModel]:
async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> ChatModel | None:
async with get_async_db_context() as db:
chat = await db.get(Chat, id)
if chat is None:
@@ -451,7 +458,7 @@ class ChatTable:
return ChatModel.model_validate(chat)
async def get_chat_title_by_id(self, id: str) -> Optional[str]:
async def get_chat_title_by_id(self, id: str) -> str | None:
async with get_async_db_context() as db:
result = await db.execute(select(Chat.title).filter_by(id=id))
row = result.first()
@@ -488,7 +495,7 @@ class ChatTable:
except Exception as e:
log.warning('Backfill failed for message %s in chat %s: %s', message_id, chat_id, e)
async def get_messages_map_by_chat_id(self, id: str) -> Optional[dict]:
async def get_messages_map_by_chat_id(self, id: str) -> dict | None:
"""Message map for walking history (see ``get_message_list``).
Prefer ``chat_message`` rows to avoid loading the large embedded
@@ -541,7 +548,7 @@ class ChatTable:
return history_messages
async def get_message_by_id_and_message_id(self, id: str, message_id: str) -> Optional[dict]:
async def get_message_by_id_and_message_id(self, id: str, message_id: str) -> dict | None:
chat = await self.get_chat_by_id(id)
if chat is None:
return None
@@ -550,7 +557,7 @@ class ChatTable:
async def upsert_message_to_chat_by_id_and_message_id(
self, id: str, message_id: str, message: dict
) -> Optional[ChatModel]:
) -> ChatModel | None:
chat = await self.get_chat_by_id(id)
if chat is None:
return None
@@ -590,7 +597,7 @@ class ChatTable:
async def add_message_status_to_chat_by_id_and_message_id(
self, id: str, message_id: str, status: dict
) -> Optional[ChatModel]:
) -> ChatModel | None:
chat = await self.get_chat_by_id(id)
if chat is None:
return None
@@ -626,9 +633,7 @@ class ChatTable:
await self.update_chat_by_id(id, chat, db=db)
return message_files
async def insert_shared_chat_by_chat_id(
self, chat_id: str, db: Optional[AsyncSession] = None
) -> Optional[ChatModel]:
async def insert_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None:
"""Create a shared snapshot for a chat. Returns the original chat with share_id set."""
from open_webui.models.shared_chats import SharedChats
@@ -651,9 +656,7 @@ class ChatTable:
await db.refresh(chat)
return ChatModel.model_validate(chat)
async def update_shared_chat_by_chat_id(
self, chat_id: str, db: Optional[AsyncSession] = None
) -> Optional[ChatModel]:
async def update_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None:
"""Re-snapshot the shared chat with current chat data."""
from open_webui.models.shared_chats import SharedChats
@@ -668,7 +671,7 @@ class ChatTable:
except Exception:
return None
async def delete_shared_chat_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool:
"""Delete shared snapshot for a chat."""
from open_webui.models.shared_chats import SharedChats
@@ -677,7 +680,7 @@ class ChatTable:
except Exception:
return False
async def unarchive_all_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def unarchive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=False))
@@ -687,8 +690,8 @@ class ChatTable:
return False
async def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str], db: Optional[AsyncSession] = None
) -> Optional[ChatModel]:
self, id: str, share_id: str | None, db: AsyncSession | None = None
) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
@@ -699,7 +702,7 @@ class ChatTable:
except Exception:
return None
async def toggle_chat_pinned_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
async def toggle_chat_pinned_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
@@ -711,7 +714,7 @@ class ChatTable:
except Exception:
return None
async def toggle_chat_archive_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
async def toggle_chat_archive_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
@@ -724,7 +727,7 @@ class ChatTable:
except Exception:
return None
async def archive_all_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def archive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=True))
@@ -736,10 +739,10 @@ class ChatTable:
async def get_archived_chat_list_by_user_id(
self,
user_id: str,
filter: Optional[dict] = None,
filter: dict | None = None,
skip: int = 0,
limit: int = 50,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by(
@@ -789,10 +792,10 @@ class ChatTable:
async def get_shared_chat_list_by_user_id(
self,
user_id: str,
filter: Optional[dict] = None,
filter: dict | None = None,
skip: int = 0,
limit: int = 50,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[SharedChatResponse]:
"""Delegate to SharedChats for listing shared chats by user."""
from open_webui.models.shared_chats import SharedChats
@@ -803,10 +806,10 @@ class ChatTable:
self,
user_id: str,
include_archived: bool = False,
filter: Optional[dict] = None,
filter: dict | None = None,
skip: int = 0,
limit: int = 50,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
@@ -859,9 +862,9 @@ class ChatTable:
include_archived: bool = False,
include_folders: bool = False,
include_pinned: bool = False,
skip: Optional[int] = None,
limit: Optional[int] = None,
db: Optional[AsyncSession] = None,
skip: int | None = None,
limit: int | None = None,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
@@ -905,7 +908,7 @@ class ChatTable:
chat_ids: list[str],
skip: int = 0,
limit: int = 50,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[ChatModel]:
async with get_async_db_context(db) as db:
result = await db.execute(
@@ -914,7 +917,7 @@ class ChatTable:
all_chats = result.scalars().all()
return [ChatModel.model_validate(chat) for chat in all_chats]
async def get_chat_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
async def get_chat_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat_item = await db.get(Chat, id)
@@ -929,7 +932,7 @@ class ChatTable:
except Exception:
return None
async def get_chat_by_share_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
async def get_chat_by_share_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
"""Look up a shared chat snapshot by its share token."""
from open_webui.models.shared_chats import SharedChats
@@ -951,8 +954,8 @@ class ChatTable:
return None
async def get_chat_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[AsyncSession] = None
) -> Optional[ChatModel]:
self, id: str, user_id: str, db: AsyncSession | None = None
) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Chat).filter_by(id=id, user_id=user_id))
@@ -961,7 +964,7 @@ class ChatTable:
except Exception:
return None
async def is_chat_owner(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def is_chat_owner(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
"""
Lightweight ownership check uses EXISTS subquery instead of loading
the full Chat row (which includes the potentially large JSON blob).
@@ -973,7 +976,7 @@ class ChatTable:
except Exception:
return False
async def get_chat_folder_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[str]:
async def get_chat_folder_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> str | None:
"""
Fetch only the folder_id column for a chat, without loading the full
JSON blob. Returns None if chat doesn't exist or doesn't belong to user.
@@ -986,7 +989,7 @@ class ChatTable:
except Exception:
return None
async def get_chats(self, skip: int = 0, limit: int = 50, db: Optional[AsyncSession] = None) -> list[ChatModel]:
async def get_chats(self, skip: int = 0, limit: int = 50, db: AsyncSession | None = None) -> list[ChatModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Chat).order_by(Chat.updated_at.desc()))
all_chats = result.scalars().all()
@@ -995,10 +998,10 @@ class ChatTable:
async def get_chats_by_user_id(
self,
user_id: str,
filter: Optional[dict] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
db: Optional[AsyncSession] = None,
filter: dict | None = None,
skip: int | None = None,
limit: int | None = None,
db: AsyncSession | None = None,
) -> ChatListResponse:
async with get_async_db_context(db) as db:
stmt = select(Chat).filter_by(user_id=user_id)
@@ -1041,7 +1044,7 @@ class ChatTable:
)
async def get_pinned_chats_by_user_id(
self, user_id: str, db: Optional[AsyncSession] = None
self, user_id: str, db: AsyncSession | None = None
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
result = await db.execute(
@@ -1063,7 +1066,7 @@ class ChatTable:
for chat in all_chats
]
async def get_archived_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[ChatModel]:
async def get_archived_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[ChatModel]:
async with get_async_db_context(db) as db:
result = await db.execute(
select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc())
@@ -1077,7 +1080,7 @@ class ChatTable:
include_archived: bool = False,
skip: int = 0,
limit: int = 60,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[ChatModel]:
"""
Filters chats based on a search query using Python, allowing pagination using skip and limit.
@@ -1270,7 +1273,7 @@ class ChatTable:
user_id: str,
skip: int = 0,
limit: int = 60,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
stmt = (
@@ -1302,7 +1305,7 @@ class ChatTable:
]
async def get_chats_by_folder_ids_and_user_id(
self, folder_ids: list[str], user_id: str, db: Optional[AsyncSession] = None
self, folder_ids: list[str], user_id: str, db: AsyncSession | None = None
) -> list[ChatModel]:
async with get_async_db_context(db) as db:
stmt = (
@@ -1318,8 +1321,8 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in all_chats]
async def update_chat_folder_id_by_id_and_user_id(
self, id: str, user_id: str, folder_id: str, db: Optional[AsyncSession] = None
) -> Optional[ChatModel]:
self, id: str, user_id: str, folder_id: str, db: AsyncSession | None = None
) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
@@ -1333,7 +1336,7 @@ class ChatTable:
return None
async def get_chat_tags_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[AsyncSession] = None
self, id: str, user_id: str, db: AsyncSession | None = None
) -> list[TagModel]:
async with get_async_db_context(db) as db:
stmt = select(Chat.meta).where(Chat.id == id)
@@ -1348,7 +1351,7 @@ class ChatTable:
tag_name: str,
skip: int = 0,
limit: int = 50,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
@@ -1393,8 +1396,8 @@ class ChatTable:
]
async def add_chat_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str, db: Optional[AsyncSession] = None
) -> Optional[ChatModel]:
self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None
) -> ChatModel | None:
tag_id = tag_name.replace(' ', '_').lower()
await Tags.ensure_tags_exist([tag_name], user_id, db=db)
try:
@@ -1412,7 +1415,7 @@ class ChatTable:
return None
async def count_chats_by_tag_name_and_user_id(
self, tag_name: str, user_id: str, db: Optional[AsyncSession] = None
self, tag_name: str, user_id: str, db: AsyncSession | None = None
) -> int:
async with get_async_db_context(db) as db:
stmt = select(func.count(Chat.id)).filter_by(user_id=user_id, archived=False)
@@ -1439,7 +1442,7 @@ class ChatTable:
tag_ids: list[str],
user_id: str,
threshold: int = 0,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> None:
"""Delete tag rows from *tag_ids* that appear in at most *threshold*
non-archived chats for *user_id*. One query to find orphans, one to
@@ -1460,7 +1463,7 @@ class ChatTable:
await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db)
async def count_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str, db: Optional[AsyncSession] = None
self, folder_id: str, user_id: str, db: AsyncSession | None = None
) -> int:
async with get_async_db_context(db) as db:
result = await db.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id))
@@ -1470,7 +1473,7 @@ class ChatTable:
return count
async def delete_tag_by_id_and_user_id_and_tag_name(
self, id: str, user_id: str, tag_name: str, db: Optional[AsyncSession] = None
self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None
) -> bool:
try:
async with get_async_db_context(db) as db:
@@ -1488,7 +1491,7 @@ class ChatTable:
except Exception:
return False
async def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_all_tags_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
@@ -1502,7 +1505,7 @@ class ChatTable:
except Exception:
return False
async def delete_chat_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_chat_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
@@ -1514,7 +1517,7 @@ class ChatTable:
except Exception:
return False
async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
@@ -1526,7 +1529,7 @@ class ChatTable:
except Exception:
return False
async def delete_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await self.delete_shared_chats_by_user_id(user_id, db=db)
@@ -1548,7 +1551,7 @@ class ChatTable:
return False
async def delete_chats_by_user_id_and_folder_id(
self, user_id: str, folder_id: str, db: Optional[AsyncSession] = None
self, user_id: str, folder_id: str, db: AsyncSession | None = None
) -> bool:
try:
async with get_async_db_context(db) as db:
@@ -1568,8 +1571,8 @@ class ChatTable:
self,
user_id: str,
folder_id: str,
new_folder_id: Optional[str],
db: Optional[AsyncSession] = None,
new_folder_id: str | None,
db: AsyncSession | None = None,
) -> bool:
try:
async with get_async_db_context(db) as db:
@@ -1582,9 +1585,10 @@ class ChatTable:
except Exception:
return False
async def delete_shared_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_shared_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
"""Delete all shared chat snapshots created by a user."""
from open_webui.models.shared_chats import SharedChats, SharedChat as SharedChatTable
from open_webui.models.shared_chats import SharedChat as SharedChatTable
from open_webui.models.shared_chats import SharedChats
try:
async with get_async_db_context(db) as db:
@@ -1605,8 +1609,8 @@ class ChatTable:
message_id: str,
file_ids: list[str],
user_id: str,
db: Optional[AsyncSession] = None,
) -> Optional[list[ChatFileModel]]:
db: AsyncSession | None = None,
) -> list[ChatFileModel | None]:
if not file_ids:
return None
@@ -1645,7 +1649,7 @@ class ChatTable:
return None
async def get_chat_files_by_chat_id_and_message_id(
self, chat_id: str, message_id: str, db: Optional[AsyncSession] = None
self, chat_id: str, message_id: str, db: AsyncSession | None = None
) -> list[ChatFileModel]:
async with get_async_db_context(db) as db:
result = await db.execute(
@@ -1654,7 +1658,7 @@ class ChatTable:
all_chat_files = result.scalars().all()
return [ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files]
async def delete_chat_file(self, chat_id: str, file_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_chat_file(self, chat_id: str, file_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id))
@@ -1663,7 +1667,7 @@ class ChatTable:
except Exception:
return False
async def get_shared_chat_ids_by_file_id(self, file_id: str, db: Optional[AsyncSession] = None) -> list[str]:
async def get_shared_chat_ids_by_file_id(self, file_id: str, db: AsyncSession | None = None) -> list[str]:
"""Return IDs of chats that contain this file and have an active share link."""
async with get_async_db_context(db) as db:
result = await db.execute(
@@ -1673,7 +1677,7 @@ class ChatTable:
)
return [row[0] for row in result.all()]
async def update_chat_tasks_by_id(self, id: str, tasks: list[dict]) -> Optional[ChatModel]:
async def update_chat_tasks_by_id(self, id: str, tasks: list[dict]) -> ChatModel | None:
"""Update the tasks list on a chat."""
try:
async with get_async_db_context() as db:
+3 -5
View File
@@ -3,13 +3,11 @@ import time
import uuid
from typing import Optional
from sqlalchemy import select, delete, func
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.users import User, UserModel
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean
from sqlalchemy import JSON, BigInteger, Boolean, Column, Text, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
@@ -319,8 +317,8 @@ class FeedbackTable:
If days=0, returns all time data starting from first feedback.
Returns: [{"date": "2026-01-08", "won": 5, "lost": 2}, ...]
"""
from datetime import datetime, timedelta
from collections import defaultdict
from datetime import datetime, timedelta
async with get_async_db_context(db) as db:
if days == 0:
+46 -51
View File
@@ -1,13 +1,14 @@
from __future__ import annotations
import logging
import time
from typing import Optional
from sqlalchemy import select, delete, func
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.utils.misc import sanitize_metadata
from pydantic import BaseModel, ConfigDict, model_validator
from sqlalchemy import BigInteger, Column, String, Text, JSON
from sqlalchemy import JSON, BigInteger, Column, String, Text, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
@@ -39,16 +40,16 @@ class FileModel(BaseModel):
id: str
user_id: str
hash: Optional[str] = None
hash: str | None = None
filename: str
path: Optional[str] = None
path: str | None = None
data: Optional[dict] = None
meta: Optional[dict] = None
data: dict | None = None
meta: dict | None = None
created_at: Optional[int] # timestamp in epoch
updated_at: Optional[int] # timestamp in epoch
created_at: int | None # timestamp in epoch
updated_at: int | None # timestamp in epoch
####################
@@ -57,9 +58,9 @@ class FileModel(BaseModel):
class FileMeta(BaseModel):
name: Optional[str] = None
content_type: Optional[str] = None
size: Optional[int] = None
name: str | None = None
content_type: str | None = None
size: int | None = None
model_config = ConfigDict(extra='allow')
@@ -84,22 +85,22 @@ class FileMeta(BaseModel):
class FileModelResponse(BaseModel):
id: str
user_id: str
hash: Optional[str] = None
hash: str | None = None
filename: str
data: Optional[dict] = None
meta: Optional[FileMeta] = None
data: dict | None = None
meta: FileMeta | None = None
created_at: int # timestamp in epoch
updated_at: Optional[int] = None # timestamp in epoch, optional for legacy files
updated_at: int | None = None # timestamp in epoch, optional for legacy files
model_config = ConfigDict(extra='allow')
class FileMetadataResponse(BaseModel):
id: str
hash: Optional[str] = None
meta: Optional[dict] = None
hash: str | None = None
meta: dict | None = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@@ -111,7 +112,7 @@ class FileListResponse(BaseModel):
class FileForm(BaseModel):
id: str
hash: Optional[str] = None
hash: str | None = None
filename: str
path: str
data: dict = {}
@@ -119,15 +120,15 @@ class FileForm(BaseModel):
class FileUpdateForm(BaseModel):
hash: Optional[str] = None
data: Optional[dict] = None
meta: Optional[dict] = None
hash: str | None = None
data: dict | None = None
meta: dict | None = None
class FilesTable:
async def insert_new_file(
self, user_id: str, form_data: FileForm, db: Optional[AsyncSession] = None
) -> Optional[FileModel]:
self, user_id: str, form_data: FileForm, db: AsyncSession | None = None
) -> FileModel | None:
async with get_async_db_context(db) as db:
file_data = form_data.model_dump()
@@ -158,7 +159,7 @@ class FilesTable:
log.exception(f'Error inserting a new file: {e}')
return None
async def get_file_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[FileModel]:
async def get_file_by_id(self, id: str, db: AsyncSession | None = None) -> FileModel | None:
try:
async with get_async_db_context(db) as db:
try:
@@ -170,8 +171,8 @@ class FilesTable:
return None
async def get_file_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[AsyncSession] = None
) -> Optional[FileModel]:
self, id: str, user_id: str, db: AsyncSession | None = None
) -> FileModel | None:
async with get_async_db_context(db) as db:
try:
result = await db.execute(select(File).filter_by(id=id, user_id=user_id))
@@ -183,9 +184,7 @@ class FilesTable:
except Exception:
return None
async def get_file_metadata_by_id(
self, id: str, db: Optional[AsyncSession] = None
) -> Optional[FileMetadataResponse]:
async def get_file_metadata_by_id(self, id: str, db: AsyncSession | None = None) -> FileMetadataResponse | None:
async with get_async_db_context(db) as db:
try:
file = await db.get(File, id)
@@ -201,12 +200,12 @@ class FilesTable:
except Exception:
return None
async def get_files(self, db: Optional[AsyncSession] = None) -> list[FileModel]:
async def get_files(self, db: AsyncSession | None = None) -> list[FileModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(File))
return [FileModel.model_validate(file) for file in result.scalars().all()]
async def check_access_by_user_id(self, id, user_id, permission='write', db: Optional[AsyncSession] = None) -> bool:
async def check_access_by_user_id(self, id, user_id, permission='write', db: AsyncSession | None = None) -> bool:
file = await self.get_file_by_id(id, db=db)
if not file:
return False
@@ -215,13 +214,13 @@ class FilesTable:
# Implement additional access control logic here as needed
return False
async def get_files_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FileModel]:
async def get_files_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[FileModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()))
return [FileModel.model_validate(file) for file in result.scalars().all()]
async def get_file_metadatas_by_ids(
self, ids: list[str], db: Optional[AsyncSession] = None
self, ids: list[str], db: AsyncSession | None = None
) -> list[FileMetadataResponse]:
async with get_async_db_context(db) as db:
result = await db.execute(
@@ -240,17 +239,17 @@ class FilesTable:
for row in result.all()
]
async def get_files_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[FileModel]:
async def get_files_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[FileModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(File).filter_by(user_id=user_id))
return [FileModel.model_validate(file) for file in result.scalars().all()]
async def get_file_list(
self,
user_id: Optional[str] = None,
user_id: str | None = None,
skip: int = 0,
limit: int = 50,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> 'FileListResponse':
async with get_async_db_context(db) as db:
stmt = select(File)
@@ -290,11 +289,11 @@ class FilesTable:
async def search_files(
self,
user_id: Optional[str] = None,
user_id: str | None = None,
filename: str = '*',
skip: int = 0,
limit: int = 100,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[FileModel]:
"""
Search files with glob pattern matching, optional user filter, and pagination.
@@ -323,8 +322,8 @@ class FilesTable:
return [FileModel.model_validate(file) for file in result.scalars().all()]
async def update_file_by_id(
self, id: str, form_data: FileUpdateForm, db: Optional[AsyncSession] = None
) -> Optional[FileModel]:
self, id: str, form_data: FileUpdateForm, db: AsyncSession | None = None
) -> FileModel | None:
async with get_async_db_context(db) as db:
try:
result = await db.execute(select(File).filter_by(id=id))
@@ -347,8 +346,8 @@ class FilesTable:
return None
async def update_file_hash_by_id(
self, id: str, hash: Optional[str], db: Optional[AsyncSession] = None
) -> Optional[FileModel]:
self, id: str, hash: str | None, db: AsyncSession | None = None
) -> FileModel | None:
async with get_async_db_context(db) as db:
try:
result = await db.execute(select(File).filter_by(id=id))
@@ -361,9 +360,7 @@ class FilesTable:
except Exception:
return None
async def update_file_data_by_id(
self, id: str, data: dict, db: Optional[AsyncSession] = None
) -> Optional[FileModel]:
async def update_file_data_by_id(self, id: str, data: dict, db: AsyncSession | None = None) -> FileModel | None:
async with get_async_db_context(db) as db:
try:
result = await db.execute(select(File).filter_by(id=id))
@@ -375,9 +372,7 @@ class FilesTable:
except Exception as e:
return None
async def update_file_metadata_by_id(
self, id: str, meta: dict, db: Optional[AsyncSession] = None
) -> Optional[FileModel]:
async def update_file_metadata_by_id(self, id: str, meta: dict, db: AsyncSession | None = None) -> FileModel | None:
async with get_async_db_context(db) as db:
try:
result = await db.execute(select(File).filter_by(id=id))
@@ -389,7 +384,7 @@ class FilesTable:
except Exception:
return None
async def delete_file_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_file_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
try:
await db.execute(delete(File).filter_by(id=id))
@@ -399,7 +394,7 @@ class FilesTable:
except Exception:
return False
async def delete_all_files(self, db: Optional[AsyncSession] = None) -> bool:
async def delete_all_files(self, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
try:
await db.execute(delete(File))
+4 -6
View File
@@ -1,15 +1,13 @@
import logging
import re
import time
import uuid
from typing import Optional
import re
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Boolean, func, select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import JSON, BigInteger, Boolean, Column, Text, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
+34 -33
View File
@@ -1,13 +1,14 @@
from __future__ import annotations
import logging
import time
from typing import Optional
from sqlalchemy import select, delete, update
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.users import Users, UserModel, UserResponse
from open_webui.models.users import UserModel, UserResponse, Users
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text, Index
from sqlalchemy import BigInteger, Boolean, Column, Index, String, Text, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
@@ -37,8 +38,8 @@ class Function(Base):
class FunctionMeta(BaseModel):
description: Optional[str] = None
manifest: Optional[dict] = {}
description: str | None = None
manifest: dict | None = {}
model_config = ConfigDict(extra='allow')
@@ -64,7 +65,7 @@ class FunctionWithValvesModel(BaseModel):
type: str
content: str
meta: FunctionMeta
valves: Optional[dict] = None
valves: dict | None = None
is_active: bool = False
is_global: bool = False
updated_at: int # timestamp in epoch
@@ -93,7 +94,7 @@ class FunctionResponse(BaseModel):
class FunctionUserResponse(FunctionResponse):
user: Optional[UserResponse] = None
user: UserResponse | None = None
class FunctionForm(BaseModel):
@@ -104,7 +105,7 @@ class FunctionForm(BaseModel):
class FunctionValves(BaseModel):
valves: Optional[dict] = None
valves: dict | None = None
class FunctionsTable:
@@ -113,8 +114,8 @@ class FunctionsTable:
user_id: str,
type: str,
form_data: FunctionForm,
db: Optional[AsyncSession] = None,
) -> Optional[FunctionModel]:
db: AsyncSession | None = None,
) -> FunctionModel | None:
function = FunctionModel(
**{
**form_data.model_dump(),
@@ -143,7 +144,7 @@ class FunctionsTable:
self,
user_id: str,
functions: list[FunctionWithValvesModel],
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[FunctionWithValvesModel]:
# Synchronize functions for a user by updating existing ones, inserting new ones, and removing those that are no longer present.
try:
@@ -191,7 +192,7 @@ class FunctionsTable:
log.exception(f'Error syncing functions for user {user_id}: {e}')
return []
async def get_function_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[FunctionModel]:
async def get_function_by_id(self, id: str, db: AsyncSession | None = None) -> FunctionModel | None:
try:
async with get_async_db_context(db) as db:
function = await db.get(Function, id)
@@ -199,7 +200,7 @@ class FunctionsTable:
except Exception:
return None
async def get_functions_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FunctionModel]:
async def get_functions_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[FunctionModel]:
"""
Batch fetch multiple functions by their IDs in a single query.
Returns functions in the same order as the input IDs (None entries filtered out).
@@ -218,7 +219,7 @@ class FunctionsTable:
return []
async def get_functions(
self, active_only=False, include_valves=False, db: Optional[AsyncSession] = None
self, active_only=False, include_valves=False, db: AsyncSession | None = None
) -> list[FunctionModel | FunctionWithValvesModel]:
async with get_async_db_context(db) as db:
if active_only:
@@ -233,7 +234,7 @@ class FunctionsTable:
else:
return [FunctionModel.model_validate(function) for function in functions]
async def get_function_list(self, db: Optional[AsyncSession] = None) -> list[FunctionUserResponse]:
async def get_function_list(self, db: AsyncSession | None = None) -> list[FunctionUserResponse]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Function).order_by(Function.updated_at.desc()))
functions = result.scalars().all()
@@ -262,7 +263,7 @@ class FunctionsTable:
]
async def get_functions_by_type(
self, type: str, active_only=False, db: Optional[AsyncSession] = None
self, type: str, active_only=False, db: AsyncSession | None = None
) -> list[FunctionModel]:
async with get_async_db_context(db) as db:
if active_only:
@@ -271,17 +272,17 @@ class FunctionsTable:
result = await db.execute(select(Function).filter_by(type=type))
return [FunctionModel.model_validate(function) for function in result.scalars().all()]
async def get_global_filter_functions(self, db: Optional[AsyncSession] = None) -> list[FunctionModel]:
async def get_global_filter_functions(self, db: AsyncSession | None = None) -> list[FunctionModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Function).filter_by(type='filter', is_active=True, is_global=True))
return [FunctionModel.model_validate(function) for function in result.scalars().all()]
async def get_global_action_functions(self, db: Optional[AsyncSession] = None) -> list[FunctionModel]:
async def get_global_action_functions(self, db: AsyncSession | None = None) -> list[FunctionModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Function).filter_by(type='action', is_active=True, is_global=True))
return [FunctionModel.model_validate(function) for function in result.scalars().all()]
async def get_function_valves_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[dict]:
async def get_function_valves_by_id(self, id: str, db: AsyncSession | None = None) -> dict | None:
async with get_async_db_context(db) as db:
try:
function = await db.get(Function, id)
@@ -290,7 +291,7 @@ class FunctionsTable:
log.exception(f'Error getting function valves by id {id}: {e}')
return None
async def get_function_valves_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> dict[str, dict]:
async def get_function_valves_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> dict[str, dict]:
"""
Batch fetch valves for multiple functions in a single query.
Returns a dict mapping function_id -> valves dict.
@@ -308,8 +309,8 @@ class FunctionsTable:
return {}
async def update_function_valves_by_id(
self, id: str, valves: dict, db: Optional[AsyncSession] = None
) -> Optional[FunctionValves]:
self, id: str, valves: dict, db: AsyncSession | None = None
) -> FunctionValves | None:
async with get_async_db_context(db) as db:
try:
function = await db.get(Function, id)
@@ -322,8 +323,8 @@ class FunctionsTable:
return None
async def update_function_metadata_by_id(
self, id: str, metadata: dict, db: Optional[AsyncSession] = None
) -> Optional[FunctionModel]:
self, id: str, metadata: dict, db: AsyncSession | None = None
) -> FunctionModel | None:
async with get_async_db_context(db) as db:
try:
function = await db.get(Function, id)
@@ -345,8 +346,8 @@ class FunctionsTable:
return None
async def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[AsyncSession] = None
) -> Optional[dict]:
self, id: str, user_id: str, db: AsyncSession | None = None
) -> dict | None:
try:
user = await Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {}
@@ -363,8 +364,8 @@ class FunctionsTable:
return None
async def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict, db: Optional[AsyncSession] = None
) -> Optional[dict]:
self, id: str, user_id: str, valves: dict, db: AsyncSession | None = None
) -> dict | None:
try:
user = await Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {}
@@ -386,8 +387,8 @@ class FunctionsTable:
return None
async def update_function_by_id(
self, id: str, updated: dict, db: Optional[AsyncSession] = None
) -> Optional[FunctionModel]:
self, id: str, updated: dict, db: AsyncSession | None = None
) -> FunctionModel | None:
async with get_async_db_context(db) as db:
try:
await db.execute(
@@ -404,7 +405,7 @@ class FunctionsTable:
except Exception:
return None
async def deactivate_all_functions(self, db: Optional[AsyncSession] = None) -> Optional[bool]:
async def deactivate_all_functions(self, db: AsyncSession | None = None) -> bool | None:
async with get_async_db_context(db) as db:
try:
await db.execute(
@@ -418,7 +419,7 @@ class FunctionsTable:
except Exception:
return None
async def delete_function_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_function_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
try:
await db.execute(delete(Function).filter_by(id=id))
+13 -9
View File
@@ -1,25 +1,29 @@
import json
import logging
import time
from typing import Optional
import uuid
from typing import Optional
from sqlalchemy import select, delete, update, func, and_, or_, cast, String
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.env import DEFAULT_GROUP_SHARE_PERMISSION
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.files import FileMetadataResponse
from pydantic import BaseModel, ConfigDict
from sqlalchemy import (
JSON,
BigInteger,
Column,
Text,
JSON,
ForeignKey,
String,
Text,
and_,
cast,
delete,
func,
or_,
select,
update,
)
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
+12 -10
View File
@@ -1,34 +1,36 @@
import json
import logging
import time
from typing import Optional
import uuid
from typing import Optional
from sqlalchemy import select, delete, update, or_, func, cast
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.files import (
File,
FileModel,
FileMetadataResponse,
FileModel,
FileModelResponse,
)
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, Users, UserResponse
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.users import User, UserModel, UserResponse, Users
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import (
JSON,
BigInteger,
Column,
ForeignKey,
String,
Text,
JSON,
UniqueConstraint,
cast,
delete,
func,
or_,
select,
update,
)
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
+14 -13
View File
@@ -1,12 +1,13 @@
from __future__ import annotations
import time
import uuid
from typing import Optional
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, Text, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
####################
# Memory DB Schema
@@ -45,8 +46,8 @@ class MemoriesTable:
self,
user_id: str,
content: str,
db: Optional[AsyncSession] = None,
) -> Optional[MemoryModel]:
db: AsyncSession | None = None,
) -> MemoryModel | None:
async with get_async_db_context(db) as db:
id = str(uuid.uuid4())
@@ -73,8 +74,8 @@ class MemoriesTable:
id: str,
user_id: str,
content: str,
db: Optional[AsyncSession] = None,
) -> Optional[MemoryModel]:
db: AsyncSession | None = None,
) -> MemoryModel | None:
async with get_async_db_context(db) as db:
try:
memory = await db.get(Memory, id)
@@ -90,7 +91,7 @@ class MemoriesTable:
except Exception:
return None
async def get_memories(self, db: Optional[AsyncSession] = None) -> list[MemoryModel]:
async def get_memories(self, db: AsyncSession | None = None) -> list[MemoryModel]:
async with get_async_db_context(db) as db:
try:
result = await db.execute(select(Memory))
@@ -99,7 +100,7 @@ class MemoriesTable:
except Exception:
return None
async def get_memories_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[MemoryModel]:
async def get_memories_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[MemoryModel]:
async with get_async_db_context(db) as db:
try:
result = await db.execute(select(Memory).filter_by(user_id=user_id))
@@ -108,7 +109,7 @@ class MemoriesTable:
except Exception:
return None
async def get_memory_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[MemoryModel]:
async def get_memory_by_id(self, id: str, db: AsyncSession | None = None) -> MemoryModel | None:
async with get_async_db_context(db) as db:
try:
memory = await db.get(Memory, id)
@@ -116,7 +117,7 @@ class MemoriesTable:
except Exception:
return None
async def delete_memory_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_memory_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
try:
await db.execute(delete(Memory).filter_by(id=id))
@@ -127,7 +128,7 @@ class MemoriesTable:
except Exception:
return False
async def delete_memories_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_memories_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
try:
await db.execute(delete(Memory).filter_by(user_id=user_id))
@@ -137,7 +138,7 @@ class MemoriesTable:
except Exception:
return False
async def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_memory_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
try:
memory = await db.get(Memory, id)
+5 -9
View File
@@ -3,17 +3,13 @@ import time
import uuid
from typing import Optional
from sqlalchemy import select, delete, func
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.tags import TagModel, Tag, Tags
from open_webui.models.users import Users, User, UserNameResponse
from open_webui.models.channels import Channels, ChannelMember
from open_webui.models.channels import ChannelMember, Channels
from open_webui.models.tags import Tag, TagModel, Tags
from open_webui.models.users import User, UserNameResponse, Users
from pydantic import BaseModel, ConfigDict, field_validator
from sqlalchemy import BigInteger, Boolean, Column, String, Text, JSON
from sqlalchemy import or_, func, and_, text
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, and_, delete, func, or_, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import exists
####################
+34 -39
View File
@@ -1,21 +1,18 @@
from __future__ import annotations
import json
import logging
import time
from typing import Optional
from sqlalchemy import select, delete, update, or_, func, String, cast
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, Users, UserResponse
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, UserResponse, Users
from pydantic import BaseModel, ConfigDict, Field, model_validator
from sqlalchemy import BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, update
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import BigInteger, Column, Text, Boolean
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
@@ -35,14 +32,14 @@ class ModelParams(BaseModel):
# ModelMeta is a model for the data stored in the meta field of the Model table
class ModelMeta(BaseModel):
profile_image_url: Optional[str] = '/static/favicon.png'
profile_image_url: str | None = '/static/favicon.png'
description: Optional[str] = None
description: str | None = None
"""
User-facing description of the model.
"""
capabilities: Optional[dict] = None
capabilities: dict | None = None
model_config = ConfigDict(extra='allow')
@@ -100,7 +97,7 @@ class Model(Base):
class ModelModel(BaseModel):
id: str
user_id: str
base_model_id: Optional[str] = None
base_model_id: str | None = None
name: str
params: ModelParams
@@ -121,11 +118,11 @@ class ModelModel(BaseModel):
class ModelUserResponse(ModelModel):
user: Optional[UserResponse] = None
user: UserResponse | None = None
class ModelAccessResponse(ModelUserResponse):
write_access: Optional[bool] = False
write_access: bool | None = False
class ModelResponse(ModelModel):
@@ -146,23 +143,23 @@ class ModelForm(BaseModel):
model_config = ConfigDict(extra='ignore')
id: str
base_model_id: Optional[str] = None
base_model_id: str | None = None
name: str
meta: ModelMeta
params: ModelParams
access_grants: Optional[list[dict]] = None
access_grants: list[dict | None] = None
is_active: bool = True
class ModelsTable:
async def _get_access_grants(self, model_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]:
async def _get_access_grants(self, model_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]:
return await AccessGrants.get_grants_by_resource('model', model_id, db=db)
async def _to_model_model(
self,
model: Model,
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[AsyncSession] = None,
access_grants: list[AccessGrantModel | None] = None,
db: AsyncSession | None = None,
) -> ModelModel:
model_data = ModelModel.model_validate(model).model_dump(exclude={'access_grants'})
model_data['access_grants'] = (
@@ -171,8 +168,8 @@ class ModelsTable:
return ModelModel.model_validate(model_data)
async def insert_new_model(
self, form_data: ModelForm, user_id: str, db: Optional[AsyncSession] = None
) -> Optional[ModelModel]:
self, form_data: ModelForm, user_id: str, db: AsyncSession | None = None
) -> ModelModel | None:
try:
async with get_async_db_context(db) as db:
result = Model(
@@ -196,7 +193,7 @@ class ModelsTable:
log.exception(f'Failed to insert a new model: {e}')
return None
async def get_all_models(self, db: Optional[AsyncSession] = None) -> list[ModelModel]:
async def get_all_models(self, db: AsyncSession | None = None) -> list[ModelModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Model))
all_models = result.scalars().all()
@@ -207,7 +204,7 @@ class ModelsTable:
for model in all_models
]
async def get_models(self, db: Optional[AsyncSession] = None) -> list[ModelUserResponse]:
async def get_models(self, db: AsyncSession | None = None) -> list[ModelUserResponse]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Model).filter(Model.base_model_id != None))
all_models = result.scalars().all()
@@ -238,7 +235,7 @@ class ModelsTable:
)
return models
async def get_base_models(self, db: Optional[AsyncSession] = None) -> list[ModelModel]:
async def get_base_models(self, db: AsyncSession | None = None) -> list[ModelModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Model).filter(Model.base_model_id == None))
all_models = result.scalars().all()
@@ -250,7 +247,7 @@ class ModelsTable:
]
async def get_models_by_user_id(
self, user_id: str, permission: str = 'write', db: Optional[AsyncSession] = None
self, user_id: str, permission: str = 'write', db: AsyncSession | None = None
) -> list[ModelUserResponse]:
models = await self.get_models(db=db)
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
@@ -287,7 +284,7 @@ class ModelsTable:
filter: dict = {},
skip: int = 0,
limit: int = 30,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> ModelListResponse:
async with get_async_db_context(db) as db:
stmt = select(Model, User).outerjoin(User, User.id == Model.user_id)
@@ -391,7 +388,7 @@ class ModelsTable:
return ModelListResponse(items=models, total=total)
async def get_model_meta_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[tuple[dict, int]]:
async def get_model_meta_by_id(self, id: str, db: AsyncSession | None = None) -> tuple[dict, int | None]:
"""Return (meta, updated_at) for a model, skipping access grant resolution."""
try:
async with get_async_db_context(db) as db:
@@ -404,7 +401,7 @@ class ModelsTable:
self,
user_id: str,
is_admin: bool = False,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> set[str]:
"""Extract unique tag names from model meta, querying only the meta column."""
async with get_async_db_context(db) as db:
@@ -437,7 +434,7 @@ class ModelsTable:
return tags_set
async def get_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
async def get_model_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None:
try:
async with get_async_db_context(db) as db:
model = await db.get(Model, id)
@@ -445,7 +442,7 @@ class ModelsTable:
except Exception:
return None
async def get_models_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[ModelModel]:
async def get_models_by_ids(self, ids: list[str], db: AsyncSession | None = None) -> list[ModelModel]:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Model).filter(Model.id.in_(ids)))
@@ -463,7 +460,7 @@ class ModelsTable:
except Exception:
return []
async def toggle_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
async def toggle_model_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None:
async with get_async_db_context(db) as db:
try:
result = await db.execute(select(Model).filter_by(id=id))
@@ -480,9 +477,7 @@ class ModelsTable:
except Exception:
return None
async def update_model_by_id(
self, id: str, model: ModelForm, db: Optional[AsyncSession] = None
) -> Optional[ModelModel]:
async def update_model_by_id(self, id: str, model: ModelForm, db: AsyncSession | None = None) -> ModelModel | None:
try:
async with get_async_db_context(db) as db:
# update only the fields that are present in the model
@@ -499,7 +494,7 @@ class ModelsTable:
log.exception(f'Failed to update the model by id {id}: {e}')
return None
async def update_model_updated_at_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
async def update_model_updated_at_by_id(self, id: str, db: AsyncSession | None = None) -> ModelModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Model).filter_by(id=id))
@@ -514,7 +509,7 @@ class ModelsTable:
log.exception(f'Failed to update the model updated_at by id {id}: {e}')
return None
async def delete_model_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_model_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await AccessGrants.revoke_all_access('model', id, db=db)
@@ -525,7 +520,7 @@ class ModelsTable:
except Exception:
return False
async def delete_all_models(self, db: Optional[AsyncSession] = None) -> bool:
async def delete_all_models(self, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Model.id))
@@ -540,7 +535,7 @@ class ModelsTable:
return False
async def sync_models(
self, user_id: str, models: list[ModelModel], db: Optional[AsyncSession] = None
self, user_id: str, models: list[ModelModel], db: AsyncSession | None = None
) -> list[ModelModel]:
try:
async with get_async_db_context(db) as db:
+5 -8
View File
@@ -1,19 +1,16 @@
import json
import time
import uuid
from typing import Optional
from functools import lru_cache
from typing import Optional
from sqlalchemy import Boolean, select, delete, update, or_, func, cast
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, Users, UserResponse
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, UserResponse, Users
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import BigInteger, Column, Text, JSON, ForeignKey
from sqlalchemy import JSON, BigInteger, Boolean, Column, ForeignKey, Text, cast, delete, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
####################
# Note DB Schema
+7 -10
View File
@@ -1,20 +1,17 @@
import time
import logging
import uuid
from typing import Optional, List
import base64
import hashlib
import json
import logging
import time
import uuid
from typing import List, Optional
from cryptography.fernet import Fernet
from sqlalchemy import select, delete, update
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from open_webui.env import OAUTH_SESSION_TOKEN_ENCRYPTION_KEY
from open_webui.internal.db import Base, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, Index
from sqlalchemy import BigInteger, Column, Index, String, Text, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
+5 -7
View File
@@ -1,18 +1,16 @@
"""Prompt history model for version tracking."""
import difflib
import json
import time
import uuid
from typing import Optional
import json
import difflib
from sqlalchemy import select, delete, func
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from open_webui.models.users import Users, UserResponse
from open_webui.models.users import UserResponse, Users
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, Text, JSON, Index
from sqlalchemy import JSON, BigInteger, Column, Index, Text, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
####################
# PromptHistory DB Schema
+48 -49
View File
@@ -1,19 +1,18 @@
from __future__ import annotations
import json
import time
import uuid
from typing import Optional
from sqlalchemy import select, delete, update, or_, func, text, cast, String
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.groups import Groups
from open_webui.models.users import Users, User, UserModel, UserResponse
from open_webui.models.prompt_history import PromptHistories
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.groups import Groups
from open_webui.models.prompt_history import PromptHistories
from open_webui.models.users import User, UserModel, UserResponse, Users
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import BigInteger, Boolean, Column, Text, JSON
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession
####################
# Prompts DB Schema
@@ -40,18 +39,18 @@ class Prompt(Base):
class PromptModel(BaseModel):
id: Optional[str] = None
id: str | None = None
command: str
user_id: str
name: str
content: str
data: Optional[dict] = None
meta: Optional[dict] = None
tags: Optional[list[str]] = None
is_active: Optional[bool] = True
version_id: Optional[str] = None
created_at: Optional[int] = None
updated_at: Optional[int] = None
data: dict | None = None
meta: dict | None = None
tags: list[str | None] = None
is_active: bool | None = True
version_id: str | None = None
created_at: int | None = None
updated_at: int | None = None
access_grants: list[AccessGrantModel] = Field(default_factory=list)
model_config = ConfigDict(from_attributes=True)
@@ -63,11 +62,11 @@ class PromptModel(BaseModel):
class PromptUserResponse(PromptModel):
user: Optional[UserResponse] = None
user: UserResponse | None = None
class PromptAccessResponse(PromptUserResponse):
write_access: Optional[bool] = False
write_access: bool | None = False
class PromptListResponse(BaseModel):
@@ -84,24 +83,24 @@ class PromptForm(BaseModel):
command: str
name: str # Changed from title
content: str
data: Optional[dict] = None
meta: Optional[dict] = None
tags: Optional[list[str]] = None
access_grants: Optional[list[dict]] = None
version_id: Optional[str] = None # Active version
commit_message: Optional[str] = None # For history tracking
is_production: Optional[bool] = True # Whether to set new version as production
data: dict | None = None
meta: dict | None = None
tags: list[str | None] = None
access_grants: list[dict | None] = None
version_id: str | None = None # Active version
commit_message: str | None = None # For history tracking
is_production: bool | None = True # Whether to set new version as production
class PromptsTable:
async def _get_access_grants(self, prompt_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]:
async def _get_access_grants(self, prompt_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]:
return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db)
async def _to_prompt_model(
self,
prompt: Prompt,
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[AsyncSession] = None,
access_grants: list[AccessGrantModel | None] = None,
db: AsyncSession | None = None,
) -> PromptModel:
prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'})
prompt_data['access_grants'] = (
@@ -110,8 +109,8 @@ class PromptsTable:
return PromptModel.model_validate(prompt_data)
async def insert_new_prompt(
self, user_id: str, form_data: PromptForm, db: Optional[AsyncSession] = None
) -> Optional[PromptModel]:
self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None
) -> PromptModel | None:
now = int(time.time())
prompt_id = str(uuid.uuid4())
@@ -171,7 +170,7 @@ class PromptsTable:
except Exception:
return None
async def get_prompt_by_id(self, prompt_id: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]:
async def get_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None:
"""Get prompt by UUID."""
try:
async with get_async_db_context(db) as db:
@@ -183,7 +182,7 @@ class PromptsTable:
except Exception:
return None
async def get_prompt_by_command(self, command: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]:
async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(command=command))
@@ -194,7 +193,7 @@ class PromptsTable:
except Exception:
return None
async def get_prompts(self, db: Optional[AsyncSession] = None) -> list[PromptUserResponse]:
async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]:
async with get_async_db_context(db) as db:
result = await db.execute(
select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc())
@@ -229,7 +228,7 @@ class PromptsTable:
return prompts
async def get_prompts_by_user_id(
self, user_id: str, permission: str = 'write', db: Optional[AsyncSession] = None
self, user_id: str, permission: str = 'write', db: AsyncSession | None = None
) -> list[PromptUserResponse]:
async with get_async_db_context(db) as db:
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
@@ -283,7 +282,7 @@ class PromptsTable:
filter: dict = {},
skip: int = 0,
limit: int = 30,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> PromptListResponse:
async with get_async_db_context(db) as db:
# Join with User table for user filtering and sorting
@@ -404,8 +403,8 @@ class PromptsTable:
command: str,
form_data: PromptForm,
user_id: str,
db: Optional[AsyncSession] = None,
) -> Optional[PromptModel]:
db: AsyncSession | None = None,
) -> PromptModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(command=command))
@@ -470,8 +469,8 @@ class PromptsTable:
prompt_id: str,
form_data: PromptForm,
user_id: str,
db: Optional[AsyncSession] = None,
) -> Optional[PromptModel]:
db: AsyncSession | None = None,
) -> PromptModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
@@ -545,9 +544,9 @@ class PromptsTable:
prompt_id: str,
name: str,
command: str,
tags: Optional[list[str]] = None,
db: Optional[AsyncSession] = None,
) -> Optional[PromptModel]:
tags: list[str | None] = None,
db: AsyncSession | None = None,
) -> PromptModel | None:
"""Update only name, command, and tags (no history created)."""
try:
async with get_async_db_context(db) as db:
@@ -573,8 +572,8 @@ class PromptsTable:
self,
prompt_id: str,
version_id: str,
db: Optional[AsyncSession] = None,
) -> Optional[PromptModel]:
db: AsyncSession | None = None,
) -> PromptModel | None:
"""Set the active version of a prompt and restore content from that version's snapshot."""
try:
async with get_async_db_context(db) as db:
@@ -606,7 +605,7 @@ class PromptsTable:
except Exception:
return None
async def toggle_prompt_active(self, prompt_id: str, db: Optional[AsyncSession] = None) -> Optional[PromptModel]:
async def toggle_prompt_active(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None:
"""Toggle the is_active flag on a prompt."""
try:
async with get_async_db_context(db) as db:
@@ -622,7 +621,7 @@ class PromptsTable:
except Exception:
return None
async def delete_prompt_by_command(self, command: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> bool:
"""Permanently delete a prompt and its history."""
try:
async with get_async_db_context(db) as db:
@@ -639,7 +638,7 @@ class PromptsTable:
except Exception:
return False
async def delete_prompt_by_id(self, prompt_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> bool:
"""Permanently delete a prompt and its history."""
try:
async with get_async_db_context(db) as db:
@@ -656,7 +655,7 @@ class PromptsTable:
except Exception:
return False
async def get_tags(self, db: Optional[AsyncSession] = None) -> list[str]:
async def get_tags(self, db: AsyncSession | None = None) -> list[str]:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt.tags).filter(Prompt.is_active == True))
@@ -670,7 +669,7 @@ class PromptsTable:
except Exception:
return []
async def get_tags_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[str]:
async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[str]:
try:
async with get_async_db_context(db) as db:
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
+2 -4
View File
@@ -3,12 +3,10 @@ import time
import uuid
from typing import Optional
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, ForeignKey, Text, JSON
from sqlalchemy import JSON, BigInteger, Column, ForeignKey, Text, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
+4 -6
View File
@@ -2,15 +2,13 @@ import logging
import time
from typing import Optional
from sqlalchemy import select, delete, update, or_
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, get_async_db_context
from open_webui.models.users import Users, User, UserModel, UserResponse
from open_webui.models.groups import Groups
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.groups import Groups
from open_webui.models.users import User, UserModel, UserResponse, Users
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, func
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, delete, func, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
+13 -14
View File
@@ -1,15 +1,14 @@
from __future__ import annotations
import logging
import time
import uuid
from typing import Optional
from sqlalchemy import select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, JSON, PrimaryKeyConstraint, Index
from sqlalchemy import JSON, BigInteger, Column, Index, PrimaryKeyConstraint, String, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
@@ -39,7 +38,7 @@ class TagModel(BaseModel):
id: str
name: str
user_id: str
meta: Optional[dict] = None
meta: dict | None = None
model_config = ConfigDict(from_attributes=True)
@@ -54,7 +53,7 @@ class TagChatIdForm(BaseModel):
class TagTable:
async def insert_new_tag(self, name: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[TagModel]:
async def insert_new_tag(self, name: str, user_id: str, db: AsyncSession | None = None) -> TagModel | None:
async with get_async_db_context(db) as db:
id = name.replace(' ', '_').lower()
tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
@@ -72,8 +71,8 @@ class TagTable:
return None
async def get_tag_by_name_and_user_id(
self, name: str, user_id: str, db: Optional[AsyncSession] = None
) -> Optional[TagModel]:
self, name: str, user_id: str, db: AsyncSession | None = None
) -> TagModel | None:
try:
id = name.replace(' ', '_').lower()
async with get_async_db_context(db) as db:
@@ -83,19 +82,19 @@ class TagTable:
except Exception:
return None
async def get_tags_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[TagModel]:
async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[TagModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Tag).filter_by(user_id=user_id))
return [TagModel.model_validate(tag) for tag in result.scalars().all()]
async def get_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None
self, ids: list[str], user_id: str, db: AsyncSession | None = None
) -> list[TagModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id))
return [TagModel.model_validate(tag) for tag in result.scalars().all()]
async def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_tag_by_name_and_user_id(self, name: str, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
id = name.replace(' ', '_').lower()
@@ -108,7 +107,7 @@ class TagTable:
return False
async def delete_tags_by_ids_and_user_id(
self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None
self, ids: list[str], user_id: str, db: AsyncSession | None = None
) -> bool:
"""Delete all tags whose id is in *ids* for the given user, in one query."""
if not ids:
@@ -122,7 +121,7 @@ class TagTable:
log.error(f'delete_tags_by_ids: {e}')
return False
async def ensure_tags_exist(self, names: list[str], user_id: str, db: Optional[AsyncSession] = None) -> None:
async def ensure_tags_exist(self, names: list[str], user_id: str, db: AsyncSession | None = None) -> None:
"""Create tag rows for any *names* that don't already exist for *user_id*."""
if not names:
return
+30 -30
View File
@@ -1,16 +1,16 @@
from __future__ import annotations
import logging
import time
from typing import Optional
from sqlalchemy import select, delete, update
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.users import Users, UserResponse
from open_webui.models.groups import Groups
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.groups import Groups
from open_webui.models.users import UserResponse, Users
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import BigInteger, Column, String, Text
from sqlalchemy import BigInteger, Column, String, Text, delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
@@ -37,8 +37,8 @@ class Tool(Base):
class ToolMeta(BaseModel):
description: Optional[str] = None
manifest: Optional[dict] = {}
description: str | None = None
manifest: dict | None = {}
class ToolModel(BaseModel):
@@ -62,7 +62,7 @@ class ToolModel(BaseModel):
class ToolUserModel(ToolModel):
user: Optional[UserResponse] = None
user: UserResponse | None = None
class ToolResponse(BaseModel):
@@ -76,13 +76,13 @@ class ToolResponse(BaseModel):
class ToolUserResponse(ToolResponse):
user: Optional[UserResponse] = None
user: UserResponse | None = None
model_config = ConfigDict(extra='allow')
class ToolAccessResponse(ToolUserResponse):
write_access: Optional[bool] = False
write_access: bool | None = False
class ToolForm(BaseModel):
@@ -90,22 +90,22 @@ class ToolForm(BaseModel):
name: str
content: str
meta: ToolMeta
access_grants: Optional[list[dict]] = None
access_grants: list[dict | None] = None
class ToolValves(BaseModel):
valves: Optional[dict] = None
valves: dict | None = None
class ToolsTable:
async def _get_access_grants(self, tool_id: str, db: Optional[AsyncSession] = None) -> list[AccessGrantModel]:
async def _get_access_grants(self, tool_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]:
return await AccessGrants.get_grants_by_resource('tool', tool_id, db=db)
async def _to_tool_model(
self,
tool: Tool,
access_grants: Optional[list[AccessGrantModel]] = None,
db: Optional[AsyncSession] = None,
access_grants: list[AccessGrantModel | None] = None,
db: AsyncSession | None = None,
) -> ToolModel:
tool_data = ToolModel.model_validate(tool).model_dump(exclude={'access_grants'})
tool_data['access_grants'] = (
@@ -118,8 +118,8 @@ class ToolsTable:
user_id: str,
form_data: ToolForm,
specs: list[dict],
db: Optional[AsyncSession] = None,
) -> Optional[ToolModel]:
db: AsyncSession | None = None,
) -> ToolModel | None:
async with get_async_db_context(db) as db:
try:
result = Tool(
@@ -143,7 +143,7 @@ class ToolsTable:
log.exception(f'Error creating a new tool: {e}')
return None
async def get_tool_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[ToolModel]:
async def get_tool_by_id(self, id: str, db: AsyncSession | None = None) -> ToolModel | None:
try:
async with get_async_db_context(db) as db:
tool = await db.get(Tool, id)
@@ -151,7 +151,7 @@ class ToolsTable:
except Exception:
return None
async def get_tools(self, defer_content: bool = False, db: Optional[AsyncSession] = None) -> list[ToolUserModel]:
async def get_tools(self, defer_content: bool = False, db: AsyncSession | None = None) -> list[ToolUserModel]:
async with get_async_db_context(db) as db:
stmt = select(Tool).order_by(Tool.updated_at.desc())
if defer_content:
@@ -190,7 +190,7 @@ class ToolsTable:
user_id: str,
permission: str = 'write',
defer_content: bool = False,
db: Optional[AsyncSession] = None,
db: AsyncSession | None = None,
) -> list[ToolUserModel]:
tools = await self.get_tools(defer_content=defer_content, db=db)
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
@@ -211,7 +211,7 @@ class ToolsTable:
result.append(tool)
return result
async def get_tool_valves_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[dict]:
async def get_tool_valves_by_id(self, id: str, db: AsyncSession | None = None) -> dict | None:
try:
async with get_async_db_context(db) as db:
tool = await db.get(Tool, id)
@@ -221,8 +221,8 @@ class ToolsTable:
return None
async def update_tool_valves_by_id(
self, id: str, valves: dict, db: Optional[AsyncSession] = None
) -> Optional[ToolValves]:
self, id: str, valves: dict, db: AsyncSession | None = None
) -> ToolValves | None:
try:
async with get_async_db_context(db) as db:
await db.execute(update(Tool).filter_by(id=id).values(valves=valves, updated_at=int(time.time())))
@@ -232,8 +232,8 @@ class ToolsTable:
return None
async def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str, db: Optional[AsyncSession] = None
) -> Optional[dict]:
self, id: str, user_id: str, db: AsyncSession | None = None
) -> dict | None:
try:
user = await Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {}
@@ -250,8 +250,8 @@ class ToolsTable:
return None
async def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict, db: Optional[AsyncSession] = None
) -> Optional[dict]:
self, id: str, user_id: str, valves: dict, db: AsyncSession | None = None
) -> dict | None:
try:
user = await Users.get_user_by_id(user_id, db=db)
user_settings = user.settings.model_dump() if user.settings else {}
@@ -272,7 +272,7 @@ class ToolsTable:
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
return None
async def update_tool_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[ToolModel]:
async def update_tool_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> ToolModel | None:
try:
async with get_async_db_context(db) as db:
access_grants = updated.pop('access_grants', None)
@@ -287,7 +287,7 @@ class ToolsTable:
except Exception:
return None
async def delete_tool_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_tool_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await AccessGrants.revoke_all_access('tool', id, db=db)
+94 -96
View File
@@ -1,29 +1,33 @@
from __future__ import annotations
import datetime
import time
from typing import Optional
from sqlalchemy import select, delete, update, func, or_, case, exists
from sqlalchemy.ext.asyncio import AsyncSession
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.utils.misc import throttle
from open_webui.utils.validate import validate_profile_image_url
from pydantic import BaseModel, ConfigDict, field_validator, model_validator
from sqlalchemy import (
BigInteger,
JSON,
Column,
String,
BigInteger,
Boolean,
Text,
Column,
Date,
String,
Text,
case,
cast,
delete,
exists,
func,
or_,
select,
update,
)
from sqlalchemy.dialects.postgresql import JSONB
import datetime
from sqlalchemy.ext.asyncio import AsyncSession
####################
# User DB Schema
@@ -33,7 +37,7 @@ import datetime
class UserSettings(BaseModel):
ui: Optional[dict] = {}
ui: dict | None = {}
model_config = ConfigDict(extra='allow')
pass
@@ -76,29 +80,29 @@ class UserModel(BaseModel):
id: str
email: str
username: Optional[str] = None
username: str | None = None
role: str = 'pending'
name: str
profile_image_url: Optional[str] = None
profile_banner_image_url: Optional[str] = None
profile_image_url: str | None = None
profile_banner_image_url: str | None = None
bio: Optional[str] = None
gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None
timezone: Optional[str] = None
bio: str | None = None
gender: str | None = None
date_of_birth: datetime.date | None = None
timezone: str | None = None
presence_state: Optional[str] = None
status_emoji: Optional[str] = None
status_message: Optional[str] = None
status_expires_at: Optional[int] = None
presence_state: str | None = None
status_emoji: str | None = None
status_message: str | None = None
status_expires_at: int | None = None
info: Optional[dict] = None
settings: Optional[UserSettings] = None
info: dict | None = None
settings: UserSettings | None = None
oauth: Optional[dict] = None
scim: Optional[dict] = None
oauth: dict | None = None
scim: dict | None = None
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@@ -136,9 +140,9 @@ class ApiKeyModel(BaseModel):
id: str
user_id: str
key: str
data: Optional[dict] = None
expires_at: Optional[int] = None
last_used_at: Optional[int] = None
data: dict | None = None
expires_at: int | None = None
last_used_at: int | None = None
created_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
@@ -153,9 +157,9 @@ class ApiKeyModel(BaseModel):
class UpdateProfileForm(BaseModel):
profile_image_url: str
name: str
bio: Optional[str] = None
gender: Optional[str] = None
date_of_birth: Optional[datetime.date] = None
bio: str | None = None
gender: str | None = None
date_of_birth: datetime.date | None = None
@field_validator('profile_image_url')
@classmethod
@@ -182,9 +186,9 @@ class UserGroupIdsListResponse(BaseModel):
class UserStatus(BaseModel):
status_emoji: Optional[str] = None
status_message: Optional[str] = None
status_expires_at: Optional[int] = None
status_emoji: str | None = None
status_message: str | None = None
status_expires_at: int | None = None
class UserInfoResponse(UserStatus):
@@ -192,8 +196,8 @@ class UserInfoResponse(UserStatus):
name: str
email: str
role: str
bio: Optional[str] = None
groups: Optional[list] = []
bio: str | None = None
groups: list | None = []
is_active: bool = False
@@ -205,7 +209,7 @@ class UserIdNameResponse(BaseModel):
class UserIdNameStatusResponse(UserStatus):
id: str
name: str
is_active: Optional[bool] = None
is_active: bool | None = None
class UserInfoListResponse(BaseModel):
@@ -239,15 +243,15 @@ class UserRoleUpdateForm(BaseModel):
class UserUpdateForm(BaseModel):
role: Optional[str] = None
name: Optional[str] = None
email: Optional[str] = None
profile_image_url: Optional[str] = None
password: Optional[str] = None
role: str | None = None
name: str | None = None
email: str | None = None
profile_image_url: str | None = None
password: str | None = None
@field_validator('profile_image_url', mode='before')
@classmethod
def check_profile_image_url(cls, v: Optional[str]) -> Optional[str]:
def check_profile_image_url(cls, v: str | None) -> str | None:
if v is None:
return v
return validate_profile_image_url(v)
@@ -261,10 +265,10 @@ class UsersTable:
email: str,
profile_image_url: str = '/user.png',
role: str = 'pending',
username: Optional[str] = None,
oauth: Optional[dict] = None,
db: Optional[AsyncSession] = None,
) -> Optional[UserModel]:
username: str | None = None,
oauth: dict | None = None,
db: AsyncSession | None = None,
) -> UserModel | None:
async with get_async_db_context(db) as db:
user = UserModel(
**{
@@ -289,7 +293,7 @@ class UsersTable:
else:
return None
async def get_user_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
async def get_user_by_id(self, id: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
@@ -298,7 +302,7 @@ class UsersTable:
except Exception:
return None
async def get_user_by_api_key(self, api_key: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
async def get_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(
@@ -309,7 +313,7 @@ class UsersTable:
except Exception:
return None
async def get_user_by_email(self, email: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
async def get_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter(func.lower(User.email) == email.lower()))
@@ -318,9 +322,7 @@ class UsersTable:
except Exception:
return None
async def get_user_by_oauth_sub(
self, provider: str, sub: str, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
async def get_user_by_oauth_sub(self, provider: str, sub: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
dialect_name = db.bind.dialect.name
@@ -339,8 +341,8 @@ class UsersTable:
return None
async def get_user_by_scim_external_id(
self, provider: str, external_id: str, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
self, provider: str, external_id: str, db: AsyncSession | None = None
) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
dialect_name = db.bind.dialect.name
@@ -359,15 +361,15 @@ class UsersTable:
async def get_users(
self,
filter: Optional[dict] = None,
skip: Optional[int] = None,
limit: Optional[int] = None,
db: Optional[AsyncSession] = None,
filter: dict | None = None,
skip: int | None = None,
limit: int | None = None,
db: AsyncSession | None = None,
) -> dict:
async with get_async_db_context(db) as db:
# Import here to avoid circular imports
from open_webui.models.groups import GroupMember
from open_webui.models.channels import ChannelMember
from open_webui.models.groups import GroupMember
# Join GroupMember so we can order by group_id when requested
stmt = select(User)
@@ -501,7 +503,7 @@ class UsersTable:
'total': total,
}
async def get_users_by_group_id(self, group_id: str, db: Optional[AsyncSession] = None) -> list[UserModel]:
async def get_users_by_group_id(self, group_id: str, db: AsyncSession | None = None) -> list[UserModel]:
async with get_async_db_context(db) as db:
from open_webui.models.groups import GroupMember
@@ -511,25 +513,23 @@ class UsersTable:
users = result.scalars().all()
return [UserModel.model_validate(user) for user in users]
async def get_users_by_user_ids(
self, user_ids: list[str], db: Optional[AsyncSession] = None
) -> list[UserStatusModel]:
async def get_users_by_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[UserStatusModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
users = result.scalars().all()
return [UserModel.model_validate(user) for user in users]
async def get_num_users(self, db: Optional[AsyncSession] = None) -> Optional[int]:
async def get_num_users(self, db: AsyncSession | None = None) -> int | None:
async with get_async_db_context(db) as db:
result = await db.execute(select(func.count()).select_from(User))
return result.scalar()
async def has_users(self, db: Optional[AsyncSession] = None) -> bool:
async def has_users(self, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
result = await db.execute(select(exists(select(User))))
return result.scalar()
async def get_first_user(self, db: Optional[AsyncSession] = None) -> UserModel:
async def get_first_user(self, db: AsyncSession | None = None) -> UserModel:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).order_by(User.created_at).limit(1))
@@ -538,7 +538,7 @@ class UsersTable:
except Exception:
return None
async def get_user_webhook_url_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[str]:
async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
@@ -551,7 +551,7 @@ class UsersTable:
except Exception:
return None
async def get_num_users_active_today(self, db: Optional[AsyncSession] = None) -> Optional[int]:
async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None:
async with get_async_db_context(db) as db:
current_timestamp = int(datetime.datetime.now().timestamp())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
@@ -560,9 +560,7 @@ class UsersTable:
)
return result.scalar()
async def update_user_role_by_id(
self, id: str, role: str, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
@@ -577,8 +575,8 @@ class UsersTable:
return None
async def update_user_status_by_id(
self, id: str, form_data: UserStatus, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
self, id: str, form_data: UserStatus, db: AsyncSession | None = None
) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
@@ -594,8 +592,8 @@ class UsersTable:
return None
async def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
self, id: str, profile_image_url: str, db: AsyncSession | None = None
) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
@@ -610,7 +608,7 @@ class UsersTable:
return None
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
async def update_last_active_by_id(self, id: str, db: Optional[AsyncSession] = None) -> None:
async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None:
try:
async with get_async_db_context(db) as db:
await db.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time())))
@@ -619,8 +617,8 @@ class UsersTable:
pass
async def update_user_oauth_by_id(
self, id: str, provider: str, sub: str, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
self, id: str, provider: str, sub: str, db: AsyncSession | None = None
) -> UserModel | None:
"""
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
Example resulting structure:
@@ -656,8 +654,8 @@ class UsersTable:
id: str,
provider: str,
external_id: str,
db: Optional[AsyncSession] = None,
) -> Optional[UserModel]:
db: AsyncSession | None = None,
) -> UserModel | None:
"""
Update or insert a SCIM provider/external_id pair into the user's scim JSON field.
Example resulting structure:
@@ -684,7 +682,7 @@ class UsersTable:
except Exception:
return None
async def update_user_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
@@ -701,8 +699,8 @@ class UsersTable:
return None
async def update_user_settings_by_id(
self, id: str, updated: dict, db: Optional[AsyncSession] = None
) -> Optional[UserModel]:
self, id: str, updated: dict, db: AsyncSession | None = None
) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
@@ -726,10 +724,10 @@ class UsersTable:
except Exception:
return None
async def delete_user_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_user_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
# Remove User from Groups
await Groups.remove_user_from_all_groups(id)
@@ -748,7 +746,7 @@ class UsersTable:
except Exception:
return False
async def get_user_api_key_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[str]:
async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(ApiKey).filter_by(user_id=id))
@@ -757,7 +755,7 @@ class UsersTable:
except Exception:
return None
async def update_user_api_key_by_id(self, id: str, api_key: str, db: Optional[AsyncSession] = None) -> bool:
async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(delete(ApiKey).filter_by(user_id=id))
@@ -779,7 +777,7 @@ class UsersTable:
except Exception:
return False
async def delete_user_api_key_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(delete(ApiKey).filter_by(user_id=id))
@@ -788,13 +786,13 @@ class UsersTable:
except Exception:
return False
async def get_valid_user_ids(self, user_ids: list[str], db: Optional[AsyncSession] = None) -> list[str]:
async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
users = result.scalars().all()
return [user.id for user in users]
async def get_super_admin_user(self, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(role='admin').limit(1))
user = result.scalars().first()
@@ -803,7 +801,7 @@ class UsersTable:
else:
return None
async def get_active_user_count(self, db: Optional[AsyncSession] = None) -> int:
async def get_active_user_count(self, db: AsyncSession | None = None) -> int:
async with get_async_db_context(db) as db:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
@@ -820,7 +818,7 @@ class UsersTable:
return user.last_active_at >= three_minutes_ago
return False
async def is_user_active(self, user_id: str, db: Optional[AsyncSession] = None) -> bool:
async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=user_id))
user = result.scalars().first()
@@ -1,11 +1,12 @@
import json
import logging
import os
import time
import requests
import logging
import json
from typing import List, Optional
from langchain_core.documents import Document
import requests
from fastapi import HTTPException, status
from langchain_core.documents import Document
log = logging.getLogger(__name__)
@@ -1,8 +1,9 @@
import requests
import logging, os
import logging
import os
from typing import Iterator, List, Union
from urllib.parse import quote
import requests
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
from open_webui.utils.headers import include_user_info_headers
@@ -1,7 +1,7 @@
import requests
import logging
from typing import Iterator, List, Union
import requests
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
+7 -10
View File
@@ -1,10 +1,10 @@
import asyncio
import requests
import logging
import ftfy
import sys
import json
import logging
import sys
import ftfy
import requests
from azure.identity import DefaultAzureCredential
from langchain_community.document_loaders import (
AzureAIDocumentIntelligenceLoader,
@@ -17,16 +17,13 @@ from langchain_community.document_loaders import (
YoutubeLoader,
)
from langchain_core.documents import Document
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, GLOBAL_LOG_LEVEL, REQUESTS_VERIFY
from open_webui.retrieval.loaders.datalab_marker import DatalabMarkerLoader
from open_webui.retrieval.loaders.external_document import ExternalDocumentLoader
from open_webui.retrieval.loaders.mineru import MinerULoader
from open_webui.retrieval.loaders.mistral import MistralLoader
from open_webui.retrieval.loaders.paddleocr_vl import PaddleOCRVLLoader
from open_webui.env import GLOBAL_LOG_LEVEL, REQUESTS_VERIFY, AIOHTTP_CLIENT_SESSION_SSL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
@@ -1,12 +1,13 @@
import os
import time
import requests
import logging
import os
import tempfile
import time
import zipfile
from typing import List, Optional
from langchain_core.documents import Document
import requests
from fastapi import HTTPException, status
from langchain_core.documents import Document
log = logging.getLogger(__name__)
@@ -1,15 +1,15 @@
import requests
import aiohttp
import asyncio
import logging
import os
import sys
import time
from typing import List, Dict, Any
from contextlib import asynccontextmanager
from typing import Any, Dict, List
import aiohttp
import requests
from langchain_core.documents import Document
from open_webui.env import GLOBAL_LOG_LEVEL, AIOHTTP_CLIENT_SESSION_SSL
from open_webui.env import AIOHTTP_CLIENT_SESSION_SSL, GLOBAL_LOG_LEVEL
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
log = logging.getLogger(__name__)
@@ -1,10 +1,10 @@
import base64
import os
import requests
import logging
import os
import sys
from typing import List
import requests
from langchain_core.documents import Document
from open_webui.env import GLOBAL_LOG_LEVEL
@@ -1,7 +1,7 @@
import requests
import logging
from typing import Iterator, List, Literal, Union
import requests
from langchain_core.document_loaders import BaseLoader
from langchain_core.documents import Document
@@ -1,8 +1,8 @@
import logging
from xml.etree.ElementTree import ParseError
from typing import Any, Dict, Generator, List, Optional, Sequence, Union
from urllib.parse import parse_qs, urlparse
from xml.etree.ElementTree import ParseError
from langchain_core.documents import Document
log = logging.getLogger(__name__)
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional, List, Tuple
from typing import List, Optional, Tuple
class BaseReranker(ABC):
@@ -1,11 +1,10 @@
import os
import logging
import torch
import os
import numpy as np
import torch
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint
from open_webui.retrieval.models.base_reranker import BaseReranker
log = logging.getLogger(__name__)
@@ -1,9 +1,8 @@
import logging
import requests
from typing import Optional, List, Tuple
from typing import List, Optional, Tuple
from urllib.parse import quote
import requests
from open_webui.env import ENABLE_FORWARD_USER_INFO_HEADERS, REQUESTS_VERIFY
from open_webui.retrieval.models.base_reranker import BaseReranker
from open_webui.utils.headers import include_user_info_headers
+10 -5
View File
@@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import hashlib
import logging
@@ -516,8 +518,11 @@ def get_all_items_from_collections(collection_names: list[str]) -> dict:
async def query_collection(
request, collection_names: list[str], queries: list[str],
embedding_function, k: int,
request,
collection_names: list[str],
queries: list[str],
embedding_function,
k: int,
) -> dict:
# When request is provided, try hybrid search + reranking if enabled
if request and request.app.state.config.ENABLE_RAG_HYBRID_SEARCH:
@@ -1109,7 +1114,7 @@ async def get_sources_from_items(
hybrid_bm25_weight,
hybrid_search,
full_context=False,
user: Optional[UserModel] = None,
user: UserModel | None = None,
):
log.debug(f'items: {items} {queries} {embedding_function} {reranking_function} {full_context}')
@@ -1421,7 +1426,7 @@ class RerankCompressor(BaseDocumentCompressor):
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
"""Compress retrieved documents given the query context.
@@ -1440,7 +1445,7 @@ class RerankCompressor(BaseDocumentCompressor):
self,
documents: Sequence[Document],
query: str,
callbacks: Optional[Callbacks] = None,
callbacks: Callbacks | None = None,
) -> Sequence[Document]:
reranking = self.reranking_function is not None
@@ -1,29 +1,27 @@
import chromadb
import logging
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
from typing import Optional
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.retrieval.vector.utils import process_metadata
import chromadb
from chromadb import Settings
from chromadb.utils.batch_utils import create_batches
from open_webui.config import (
CHROMA_CLIENT_AUTH_CREDENTIALS,
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_DATA_PATH,
CHROMA_DATABASE,
CHROMA_HTTP_HEADERS,
CHROMA_HTTP_HOST,
CHROMA_HTTP_PORT,
CHROMA_HTTP_HEADERS,
CHROMA_HTTP_SSL,
CHROMA_TENANT,
CHROMA_DATABASE,
CHROMA_CLIENT_AUTH_PROVIDER,
CHROMA_CLIENT_AUTH_CREDENTIALS,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from open_webui.retrieval.vector.utils import process_metadata
log = logging.getLogger(__name__)
@@ -2,28 +2,28 @@
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
"""
from elasticsearch import Elasticsearch, BadRequestError
from typing import Optional
import ssl
from elasticsearch.helpers import bulk, scan
from typing import Optional
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from elasticsearch import BadRequestError, Elasticsearch
from elasticsearch.helpers import bulk, scan
from open_webui.config import (
ELASTICSEARCH_URL,
ELASTICSEARCH_CA_CERTS,
ELASTICSEARCH_API_KEY,
ELASTICSEARCH_USERNAME,
ELASTICSEARCH_PASSWORD,
ELASTICSEARCH_CA_CERTS,
ELASTICSEARCH_CLOUD_ID,
ELASTICSEARCH_INDEX_PREFIX,
ELASTICSEARCH_PASSWORD,
ELASTICSEARCH_URL,
ELASTICSEARCH_USERNAME,
SSL_ASSERT_FINGERPRINT,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from open_webui.retrieval.vector.utils import process_metadata
class ElasticsearchClient(VectorDBBase):
@@ -13,18 +13,15 @@ import sys
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool, QueuePool
from open_webui.config import (
MARIADB_VECTOR_DB_URL,
MARIADB_VECTOR_DISTANCE_STRATEGY,
MARIADB_VECTOR_INDEX_M,
MARIADB_VECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
MARIADB_VECTOR_POOL_SIZE,
MARIADB_VECTOR_POOL_MAX_OVERFLOW,
MARIADB_VECTOR_POOL_TIMEOUT,
MARIADB_VECTOR_POOL_RECYCLE,
MARIADB_VECTOR_POOL_SIZE,
MARIADB_VECTOR_POOL_TIMEOUT,
)
from open_webui.retrieval.vector.main import (
GetResult,
@@ -33,6 +30,8 @@ from open_webui.retrieval.vector.main import (
VectorItem,
)
from open_webui.retrieval.vector.utils import process_metadata
from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool, QueuePool
log = logging.getLogger(__name__)
@@ -2,33 +2,31 @@
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
"""
from pymilvus import MilvusClient as Client
from pymilvus import FieldSchema, DataType
from pymilvus import connections, Collection
import json
import logging
from typing import Optional
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
MILVUS_URI,
MILVUS_DB,
MILVUS_TOKEN,
MILVUS_INDEX_TYPE,
MILVUS_METRIC_TYPE,
MILVUS_HNSW_M,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_IVF_FLAT_NLIST,
MILVUS_DISKANN_MAX_DEGREE,
MILVUS_DISKANN_SEARCH_LIST_SIZE,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_HNSW_M,
MILVUS_INDEX_TYPE,
MILVUS_IVF_FLAT_NLIST,
MILVUS_METRIC_TYPE,
MILVUS_TOKEN,
MILVUS_URI,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from open_webui.retrieval.vector.utils import process_metadata
from pymilvus import Collection, DataType, FieldSchema, connections
from pymilvus import MilvusClient as Client
log = logging.getLogger(__name__)
@@ -3,18 +3,18 @@ NOTE: This vector database integration is community-supported and maintained on
"""
import logging
from typing import Optional, Tuple, List, Dict, Any
from typing import Any, Dict, List, Optional, Tuple
from open_webui.config import (
MILVUS_URI,
MILVUS_TOKEN,
MILVUS_DB,
MILVUS_COLLECTION_PREFIX,
MILVUS_INDEX_TYPE,
MILVUS_METRIC_TYPE,
MILVUS_HNSW_M,
MILVUS_DB,
MILVUS_HNSW_EFCONSTRUCTION,
MILVUS_HNSW_M,
MILVUS_INDEX_TYPE,
MILVUS_IVF_FLAT_NLIST,
MILVUS_METRIC_TYPE,
MILVUS_TOKEN,
MILVUS_URI,
)
from open_webui.retrieval.vector.main import (
GetResult,
@@ -23,12 +23,12 @@ from open_webui.retrieval.vector.main import (
VectorItem,
)
from pymilvus import (
connections,
utility,
Collection,
CollectionSchema,
FieldSchema,
DataType,
FieldSchema,
connections,
utility,
)
log = logging.getLogger(__name__)
@@ -2,37 +2,36 @@
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
"""
from typing import Optional, List, Dict, Any
import json
import logging
import re
import json
from typing import Any, Dict, List, Optional
from pgvector.sqlalchemy import Vector
from sqlalchemy import (
func,
literal,
Column,
Integer,
LargeBinary,
MetaData,
Table,
Text,
cast,
column,
create_engine,
Column,
Integer,
MetaData,
LargeBinary,
func,
literal,
select,
text,
Text,
Table,
values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.dialects import registry
from sqlalchemy.dialects.postgresql import JSONB, array
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.sql import true
class OpenGaussDialect(PGDialect_psycopg2):
@@ -56,23 +55,22 @@ class OpenGaussDialect(PGDialect_psycopg2):
# Register dialect
registry.register('opengauss', __name__, 'OpenGaussDialect')
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
OPENGAUSS_DB_URL,
OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH,
OPENGAUSS_POOL_SIZE,
OPENGAUSS_POOL_MAX_OVERFLOW,
OPENGAUSS_POOL_TIMEOUT,
OPENGAUSS_POOL_RECYCLE,
OPENGAUSS_POOL_SIZE,
OPENGAUSS_POOL_TIMEOUT,
)
from open_webui.env import SRC_LOG_LEVELS
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from open_webui.retrieval.vector.utils import process_metadata
VECTOR_LENGTH = OPENGAUSS_INITIALIZE_MAX_VECTOR_LENGTH
Base = declarative_base()
@@ -2,24 +2,24 @@
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
"""
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
from typing import Optional
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.config import (
OPENSEARCH_CERT_VERIFY,
OPENSEARCH_PASSWORD,
OPENSEARCH_SSL,
OPENSEARCH_URI,
OPENSEARCH_USERNAME,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
OPENSEARCH_URI,
OPENSEARCH_SSL,
OPENSEARCH_CERT_VERIFY,
OPENSEARCH_USERNAME,
OPENSEARCH_PASSWORD,
)
from open_webui.retrieval.vector.utils import process_metadata
from opensearchpy import OpenSearch
from opensearchpy.helpers import bulk
class OpenSearchClient(VectorDBBase):
@@ -28,34 +28,33 @@ ORACLE_DB_POOL_MAX = 10
ORACLE_DB_POOL_INCREMENT = 1
"""
from typing import Optional, List, Dict, Any, Union
from decimal import Decimal
import array
import json
import logging
import os
import threading
import time
import json
import array
from decimal import Decimal
from typing import Any, Dict, List, Optional, Union
import oracledb
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
ORACLE_DB_DSN,
ORACLE_DB_PASSWORD,
ORACLE_DB_POOL_INCREMENT,
ORACLE_DB_POOL_MAX,
ORACLE_DB_POOL_MIN,
ORACLE_DB_USE_WALLET,
ORACLE_DB_USER,
ORACLE_DB_PASSWORD,
ORACLE_DB_DSN,
ORACLE_VECTOR_LENGTH,
ORACLE_WALLET_DIR,
ORACLE_WALLET_PASSWORD,
ORACLE_VECTOR_LENGTH,
ORACLE_DB_POOL_MIN,
ORACLE_DB_POOL_MAX,
ORACLE_DB_POOL_INCREMENT,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
log = logging.getLogger(__name__)
@@ -1,56 +1,54 @@
from typing import Optional, List, Dict, Any, Tuple
import logging
import json
import logging
from typing import Any, Dict, List, Optional, Tuple
from open_webui.config import (
PGVECTOR_CREATE_EXTENSION,
PGVECTOR_DB_URL,
PGVECTOR_HNSW_EF_CONSTRUCTION,
PGVECTOR_HNSW_M,
PGVECTOR_INDEX_METHOD,
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
PGVECTOR_IVFFLAT_LISTS,
PGVECTOR_PGCRYPTO,
PGVECTOR_PGCRYPTO_KEY,
PGVECTOR_POOL_MAX_OVERFLOW,
PGVECTOR_POOL_RECYCLE,
PGVECTOR_POOL_SIZE,
PGVECTOR_POOL_TIMEOUT,
PGVECTOR_USE_HALFVEC,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.utils.misc import sanitize_text_for_db
from pgvector.sqlalchemy import HALFVEC, Vector
from sqlalchemy import (
func,
literal,
Column,
Integer,
LargeBinary,
MetaData,
Table,
Text,
cast,
column,
create_engine,
Column,
Integer,
MetaData,
LargeBinary,
func,
literal,
select,
text,
Text,
Table,
values,
)
from sqlalchemy.sql import true
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.dialects.postgresql import JSONB, array
from pgvector.sqlalchemy import Vector, HALFVEC
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.exc import NoSuchTableError
from open_webui.retrieval.vector.utils import process_metadata
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.utils.misc import sanitize_text_for_db
from open_webui.config import (
PGVECTOR_DB_URL,
PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH,
PGVECTOR_CREATE_EXTENSION,
PGVECTOR_PGCRYPTO,
PGVECTOR_PGCRYPTO_KEY,
PGVECTOR_POOL_SIZE,
PGVECTOR_POOL_MAX_OVERFLOW,
PGVECTOR_POOL_TIMEOUT,
PGVECTOR_POOL_RECYCLE,
PGVECTOR_INDEX_METHOD,
PGVECTOR_HNSW_M,
PGVECTOR_HNSW_EF_CONSTRUCTION,
PGVECTOR_IVFFLAT_LISTS,
PGVECTOR_USE_HALFVEC,
)
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker
from sqlalchemy.pool import NullPool, QueuePool
from sqlalchemy.sql import true
VECTOR_LENGTH = PGVECTOR_INITIALIZE_MAX_VECTOR_LENGTH
USE_HALFVEC = PGVECTOR_USE_HALFVEC
@@ -2,9 +2,10 @@
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
"""
from typing import Optional, List, Dict, Any, Union
import logging
import time # for measuring elapsed time
from typing import Any, Dict, List, Optional, Union
from pinecone import Pinecone, ServerlessSpec
# Add gRPC support for better performance (Pinecone best practice)
@@ -16,24 +17,23 @@ except ImportError:
GRPC_AVAILABLE = False
import asyncio # for async upserts
import functools # for partial binding in async tasks
import concurrent.futures # for parallel batch upserts
import functools # for partial binding in async tasks
import random # for jitter in retry backoff
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
PINECONE_API_KEY,
PINECONE_CLOUD,
PINECONE_DIMENSION,
PINECONE_ENVIRONMENT,
PINECONE_INDEX_NAME,
PINECONE_DIMENSION,
PINECONE_METRIC,
PINECONE_CLOUD,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from open_webui.retrieval.vector.utils import process_metadata
@@ -2,31 +2,30 @@
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
"""
from typing import Optional
import logging
from typing import Optional
from urllib.parse import urlparse
from open_webui.config import (
QDRANT_API_KEY,
QDRANT_COLLECTION_PREFIX,
QDRANT_GRPC_PORT,
QDRANT_HNSW_M,
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_TIMEOUT,
QDRANT_URI,
)
from open_webui.retrieval.vector.main import (
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from qdrant_client import QdrantClient as Qclient
from qdrant_client.http.models import PointStruct
from qdrant_client.models import models
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
SearchResult,
GetResult,
)
from open_webui.config import (
QDRANT_URI,
QDRANT_API_KEY,
QDRANT_ON_DISK,
QDRANT_GRPC_PORT,
QDRANT_PREFER_GRPC,
QDRANT_COLLECTION_PREFIX,
QDRANT_TIMEOUT,
QDRANT_HNSW_M,
)
NO_LIMIT = 999999999
log = logging.getLogger(__name__)
@@ -3,19 +3,19 @@ NOTE: This vector database integration is community-supported and maintained on
"""
import logging
from typing import Optional, Tuple, List, Dict, Any
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import grpc
from open_webui.config import (
QDRANT_API_KEY,
QDRANT_COLLECTION_PREFIX,
QDRANT_GRPC_PORT,
QDRANT_HNSW_M,
QDRANT_ON_DISK,
QDRANT_PREFER_GRPC,
QDRANT_URI,
QDRANT_COLLECTION_PREFIX,
QDRANT_TIMEOUT,
QDRANT_HNSW_M,
QDRANT_URI,
)
from open_webui.retrieval.vector.main import (
GetResult,
@@ -2,17 +2,18 @@
NOTE: This vector database integration is community-supported and maintained on a best-effort basis.
"""
from open_webui.retrieval.vector.utils import process_metadata
import logging
from typing import Any, Dict, List, Optional, Union
import boto3
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
from open_webui.retrieval.vector.main import (
VectorDBBase,
VectorItem,
GetResult,
SearchResult,
VectorDBBase,
VectorItem,
)
from open_webui.config import S3_VECTOR_BUCKET_NAME, S3_VECTOR_REGION
from typing import List, Optional, Dict, Any, Union
import logging
import boto3
from open_webui.retrieval.vector.utils import process_metadata
log = logging.getLogger(__name__)

Some files were not shown because too many files have changed in this diff Show More