From d8b5b9fa79fccaabf78040deeafaaf0b0967a28f Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 21 May 2026 15:29:49 +0400 Subject: [PATCH] refac --- backend/open_webui/internal/db.py | 19 +- backend/open_webui/main.py | 5 +- backend/open_webui/migrations/env.py | 136 ++--- backend/open_webui/migrations/util.py | 7 +- .../migrations/versions/7e5b5dc7342b_init.py | 52 +- backend/open_webui/models/auths.py | 188 +++---- backend/open_webui/models/chats.py | 46 +- backend/open_webui/models/files.py | 12 +- backend/open_webui/models/functions.py | 24 +- backend/open_webui/models/memories.py | 9 +- backend/open_webui/models/models.py | 4 +- backend/open_webui/models/prompts.py | 74 +-- backend/open_webui/models/tags.py | 15 +- backend/open_webui/models/tools.py | 17 +- backend/open_webui/models/users.py | 504 +++++++----------- backend/open_webui/routers/users.py | 7 +- backend/open_webui/routers/utils.py | 10 +- 17 files changed, 465 insertions(+), 664 deletions(-) diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index 735aaf1ba4..e0b82dce03 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -117,7 +117,7 @@ extract_ssl_mode_from_url = extract_ssl_params_from_url reattach_ssl_mode_to_url = reattach_ssl_params_to_url -class JSONField(types.TypeDecorator): +class JSONField(types.TypeDecorator): # TEXT-backed JSON storage """Store arbitrary Python objects as JSON-encoded TEXT. Used instead of native JSON columns for portability across SQLite and @@ -125,19 +125,14 @@ class JSONField(types.TypeDecorator): deserialized with ``json.loads`` on read. """ - impl = types.Text - cache_ok = True # safe for statement caching (no per-instance state) - + impl = types.UnicodeText + cache_ok = True def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any: - return json.dumps(value) - + return json.dumps(value) if value is not None else None def process_result_value(self, value: _T | None, dialect: Dialect) -> Any: - if value is not None: - return json.loads(value) - - def copy(self, **kw: Any) -> Self: - return JSONField(self.impl.length) - + return json.loads(value) if value is not None else None + def copy(self, **kwargs: Any) -> Self: + return JSONField(length=self.impl.length) def db_value(self, value): return json.dumps(value) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 3a6790f4fc..b06adce68a 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2911,12 +2911,11 @@ async def readiness_check(): @app.get('/health/db') -async def healthcheck_with_db(): +async def check_db_health(): """Verify database connectivity by issuing a lightweight ping.""" await async_db_ping() return {'status': True} - - +# --- static assets & files --- # Serve build-time static assets (CSS, JS, images, favicon, etc.) app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static') diff --git a/backend/open_webui/migrations/env.py b/backend/open_webui/migrations/env.py index a9b9dcacb8..648c32ef74 100644 --- a/backend/open_webui/migrations/env.py +++ b/backend/open_webui/migrations/env.py @@ -1,105 +1,65 @@ from __future__ import annotations - -"""Alembic environment configuration. - -Configures the migration context for both offline (SQL script generation) -and online (live database connection) modes. Handles SQLCipher URLs, -SSL parameter normalisation, and JSON log formatting. -""" - +# Alembic environment configuration runner. +# Coordinates database migrations in both offline and online execution modes. +import logging.config import logging -from logging.config import fileConfig - -from alembic import context +import alembic.context from open_webui.env import DATABASE_PASSWORD, DATABASE_URL, LOG_FORMAT from open_webui.internal.db import extract_ssl_params_from_url, reattach_ssl_params_to_url from open_webui.models.auths import Auth from open_webui.models.calendar import Calendar, CalendarEvent, CalendarEventAttendee # noqa: F401 from sqlalchemy import create_engine, engine_from_config, pool - -# ── Alembic config & logging ───────────────────────────────────────────────── - -config = context.config - -if config.config_file_name is not None: - fileConfig(config.config_file_name, disable_existing_loggers=False) - -# Re-apply JSON formatter after fileConfig replaces handlers. +alembic_config = alembic.context.config +if alembic_config.config_file_name: + logging.config.fileConfig(alembic_config.config_file_name, disable_existing_loggers=False) if LOG_FORMAT == 'json': from open_webui.env import JSONFormatter - - for handler in logging.root.handlers: - handler.setFormatter(JSONFormatter()) - -# ── Database URL ───────────────────────────────────────────────────────────── - -target_metadata = Auth.metadata - -DB_URL = DATABASE_URL - -# Normalise SSL query params for psycopg2 (Alembic uses psycopg2 for sync). -_url_no_ssl, _ssl_params = extract_ssl_params_from_url(DB_URL) -if _ssl_params: - DB_URL = reattach_ssl_params_to_url(_url_no_ssl, _ssl_params) - -if DB_URL: - config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%')) - - -# ── Migration runners ──────────────────────────────────────────────────────── - - + for log_handler in logging.root.handlers: + log_handler.setFormatter(JSONFormatter()) +migration_metadata = Auth.metadata +target_db_url = DATABASE_URL +base_url, ssl_query_params = extract_ssl_params_from_url(target_db_url) +if ssl_query_params: + target_db_url = reattach_ssl_params_to_url(base_url, ssl_query_params) +if target_db_url: + alembic_config.set_main_option('sqlalchemy.url', target_db_url.replace('%', '%%')) def run_migrations_offline() -> None: - """Generate SQL script without a live database connection.""" - url = config.get_main_option('sqlalchemy.url') - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={'paramstyle': 'named'}, - ) - with context.begin_transaction(): - context.run_migrations() - - -def _build_connectable(): - """Create the appropriate SQLAlchemy engine for the configured DB URL.""" - if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'): - if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '': + """Execute Alembic migrations in offline mode (outputs raw SQL DDL).""" + db_connection_url = alembic_config.get_main_option('sqlalchemy.url') + alembic.context.configure(url=db_connection_url, target_metadata=migration_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}) + with alembic.context.begin_transaction(): + alembic.context.run_migrations() +def _get_engine_connectable(): + """Build the database engine based on target URL and authentication credentials.""" + if target_db_url and target_db_url.startswith('sqlite+sqlcipher://'): + if not DATABASE_PASSWORD or not DATABASE_PASSWORD.strip(): raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs') - - db_path = DB_URL.replace('sqlite+sqlcipher://', '') - if db_path.startswith('/'): - db_path = db_path[1:] - - def _sqlcipher_creator(): + raw_db_path = target_db_url.replace('sqlite+sqlcipher://', '') + if raw_db_path.startswith('/'): + raw_db_path = raw_db_path[1:] + def _sqlite_cipher_creator(): import sqlcipher3 - - conn = sqlcipher3.connect(db_path, check_same_thread=False) - conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'") - return conn - - return create_engine('sqlite://', creator=_sqlcipher_creator, echo=False) - + cipher_conn = sqlcipher3.connect(raw_db_path, check_same_thread=False) + cipher_conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'") + return cipher_conn + return create_engine('sqlite://', creator=_sqlite_cipher_creator, echo=False) return engine_from_config( - config.get_section(config.config_ini_section, {}), + alembic_config.get_section(alembic_config.config_ini_section, {}), prefix='sqlalchemy.', poolclass=pool.NullPool, ) - - def run_migrations_online() -> None: - """Run migrations against a live database connection.""" - connectable = _build_connectable() - with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) - with context.begin_transaction(): - context.run_migrations() - - -# ── Entrypoint ─────────────────────────────────────────────────────────────── - -if context.is_offline_mode(): - run_migrations_offline() -else: - run_migrations_online() + """Execute migrations against a live database connection.""" + live_connectable = _get_engine_connectable() + with live_connectable.connect() as live_connection: + alembic.context.configure( + connection=live_connection, + target_metadata=migration_metadata, + ) + with alembic.context.begin_transaction(): + alembic.context.run_migrations() +# Alembic execution entrypoint branch +if alembic.context.is_offline_mode(): + run_migrations_offline() # run in offline mode +if not alembic.context.is_offline_mode(): + run_migrations_online() # run in online mode diff --git a/backend/open_webui/migrations/util.py b/backend/open_webui/migrations/util.py index 606bdc9479..5734b8a0d6 100644 --- a/backend/open_webui/migrations/util.py +++ b/backend/open_webui/migrations/util.py @@ -2,10 +2,9 @@ from __future__ import annotations """Alembic migration utilities.""" -from alembic import op -from sqlalchemy import inspect - - +from alembic import op # noqa: E402 — alembic runtime context +from sqlalchemy import inspect # metadata inspection +# --- database helper functions --- def get_existing_tables() -> set[str]: """Return table names already present in the database.""" conn = op.get_bind() diff --git a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py index 6b96d9037d..8f0500d8ec 100644 --- a/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py +++ b/backend/open_webui/migrations/versions/7e5b5dc7342b_init.py @@ -1,30 +1,20 @@ +# Initial bootstrap migration version. +# Revision ID: 7e5b5dc7342b +# Revises: (none) +# Created on: 2024-06-24 13:15:33.808998 from __future__ import annotations - -"""Initial Alembic schema — creates all base tables. - -Revision ID: 7e5b5dc7342b -Revises: — -Create Date: 2024-06-24 13:15:33.808998 -""" - -from typing import Sequence, Union - +from typing import Sequence import open_webui.internal.db # noqa: F401 import sqlalchemy as sa from alembic import op from open_webui.internal.db import JSONField from open_webui.migrations.util import get_existing_tables - -revision: str = '7e5b5dc7342b' -down_revision: Union[str, None] = None -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - -# ── Table definitions ──────────────────────────────────────────────────────── -# Each table is only created if it doesn't already exist, because databases -# migrated from the Peewee era will already have these tables. - -_TABLES: list[tuple[str, list[sa.Column], list]] = [ +revision: str = "7e5b5dc7342b" +down_revision: str | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None +# Initial schema table declarations +_INITIAL_TABLES: list[tuple[str, list[sa.Column], list]] = [ ( 'auth', [ @@ -185,15 +175,13 @@ _TABLES: list[tuple[str, list[sa.Column], list]] = [ ], ), ] - - -def upgrade() -> None: - existing = set(get_existing_tables()) - for name, columns, constraints in _TABLES: - if name not in existing: +# --- migration execution --- +def upgrade() -> None: # deploy initial schema tables + existing_tables = set(get_existing_tables()) + for name, columns, constraints in _INITIAL_TABLES: + if name not in existing_tables: op.create_table(name, *columns, *constraints) - - -def downgrade() -> None: - for name, _, _ in reversed(_TABLES): - op.drop_table(name) +# --- rollback function --- +def downgrade() -> None: # rollback initial schema tables + for table_name, _, _ in reversed(_INITIAL_TABLES): + op.drop_table(table_name) diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 01b42263f9..a46826d207 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -1,4 +1,4 @@ -"""Authentication models and database operations.""" +"""Auth credential models and data-access layer.""" from __future__ import annotations @@ -16,19 +16,19 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -class Auth(Base): - """Credential record linking a user identity to an email + hashed password.""" +class Auth(Base): # credential ↔ user linkage + """Maps a user ID to an email/password pair with an active flag.""" __tablename__ = 'auth' - id = Column(String, primary_key=True, unique=True) # same as User.id - email = Column(String) # login email, kept in sync with User.email - password = Column(Text) # bcrypt / argon2 hash - active = Column(Boolean) # soft-disable flag + id = Column(String, primary_key=True, unique=True) # mirrors User.id + email = Column(String) # login address, kept in sync with User.email + password = Column(Text) # argon2 / bcrypt hash + active = Column(Boolean) # account soft-disable toggle class AuthModel(BaseModel): - """Pydantic mirror of the Auth table row.""" + """Pydantic mirror of the ``auth`` table row.""" id: str email: str @@ -37,7 +37,7 @@ class AuthModel(BaseModel): class Token(BaseModel): - """JWT bearer token response.""" + """JWT bearer-token response wrapper.""" token: str token_type: str @@ -88,7 +88,12 @@ class AddUserForm(SignupForm): role: str | None = 'pending' +# --- data-access layer --- + + class AuthsTable: + """Provides CRUD operations for the Auth ↔ User lifecycle.""" + async def insert_new_auth( self, email: str, @@ -99,130 +104,109 @@ class AuthsTable: oauth: dict | None = None, db: AsyncSession | None = None, ) -> UserModel | None: - """Create an Auth + User pair in a single transaction.""" - async with get_async_db_context(db) as db: + """Create an Auth + User pair inside a single transaction.""" + async with get_async_db_context(db) as session: log.info('insert_new_auth') - user_id = str(uuid.uuid4()) + new_id = str(uuid.uuid4()) - record = Auth( - id=user_id, + credential = Auth( + id=new_id, email=email, password=password, active=True, ) - db.add(record) + session.add(credential) - user = await Users.insert_new_user( - user_id, name, email, profile_image_url, role, oauth=oauth, db=db, + created_user = await Users.insert_new_user( + new_id, name, email, profile_image_url, role, oauth=oauth, db=session, ) - - await db.commit() - await db.refresh(record) - - return user if record and user else None + # persist both records and reload generated defaults + await session.commit() + await session.refresh(credential) + return created_user if credential and created_user else None async def authenticate_user( - self, email: str, verify_password: callable, db: AsyncSession | None = None + self, email: str, verify_password: callable, db: AsyncSession | None = None, ) -> UserModel | None: - """Verify a user's email + password and return the user on success.""" - log.info(f'authenticate_user: {email}') - - user = await Users.get_user_by_email(email, db=db) - if not user: - return None - - try: # load the auth row for password verification - async with get_async_db_context(db) as session: - auth = await session.get(Auth, user.id) - if not auth or not auth.active: - return None - if not verify_password(auth.password): - return None - return user - except Exception: + """Verify email + password credentials and return the matching user.""" + log.info('authenticate_user: %s', email) + resolved = await Users.get_user_by_email(email, db=db) + if not resolved: return + # load the credential row and verify the password hash + async with get_async_db_context(db) as session: + credential = await session.get(Auth, resolved.id) + if not credential or not credential.active: + return + if not verify_password(credential.password): + return + return resolved async def authenticate_user_by_api_key( self, api_key: str, db: AsyncSession | None = None, ) -> UserModel | None: - """Resolve an API key to its owning user, returning ``None`` on miss.""" + """Look up the user that owns the given API key.""" log.info('authenticate_user_by_api_key') - if not api_key: # empty / None key — reject immediately + if not api_key: return - try: - return await Users.get_user_by_api_key(api_key, db=db) - except Exception: - return False + # delegate to the Users model for the actual lookup + return await Users.get_user_by_api_key(api_key, db=db) async def authenticate_user_by_email( self, email: str, db: AsyncSession | None = None, ) -> UserModel | None: - """One-query authentication: JOIN Auth ↔ User, filter by email + active flag.""" + """Single-query auth via JOIN on Auth ↔ User, filtered by active flag.""" log.info('authenticate_user_by_email: %s', email) - - try: - async with get_async_db_context(db) as session: - stmt = ( - select(Auth, User) - .join(User, Auth.id == User.id) - .filter(Auth.email == email, Auth.active == True) - ) - row = (await session.execute(stmt)).first() - if not row: - return - _auth, matched_user = row - return UserModel.model_validate(matched_user) - except Exception: - return - - async def update_user_password_by_id( - self, - id: str, - new_password: str, - db: AsyncSession | None = None, - ) -> bool: - """Hash-swap: replace the stored password hash for a given user.""" - try: - async with get_async_db_context(db) as session: - stmt = update(Auth).filter_by(id=id).values(password=new_password) - result = await session.execute(stmt) - await session.commit() - return result.rowcount == 1 - except Exception: - return False - + # single JOIN avoids N+1 — returns (Auth, User) tuple or None + async with get_async_db_context(db) as session: + joined_query = ( + select(Auth, User) + .join(User, Auth.id == User.id) + .where(Auth.email == email, Auth.active.is_(True)) + ) + match = (await session.execute(joined_query)).first() + if not match: + return + _, found_user = match + return UserModel.model_validate(found_user) async def update_email_by_id( - self, id: str, email: str, db: AsyncSession | None = None, + self, user_id: str, email: str, db: AsyncSession | None = None, ) -> bool: - """Update the auth email and propagate the change to the User table.""" - try: - async with get_async_db_context(db) as session: - stmt = update(Auth).filter_by(id=id).values(email=email) - result = await session.execute(stmt) - await session.commit() - if result.rowcount != 1: - return False - await Users.update_user_by_id(id, {'email': email}, db=session) - return True - except Exception: - return False + """Set a new email on the auth record and propagate to the user row.""" + async with get_async_db_context(db) as session: + auth_row = await session.get(Auth, user_id) + if auth_row is None: + return False + auth_row.email = email + await session.commit() + await Users.update_user_by_id(user_id, {'email': email}, db=session) + return True + # --- password modification --- + async def update_user_password_by_id( + self, user_id: str, new_password: str, db: AsyncSession | None = None, + ) -> bool: + """Set a new password hash for an existing user.""" + async with get_async_db_context(db) as session: + auth_row = await session.get(Auth, user_id) + if auth_row is None: + return False + auth_row.password = new_password + await session.commit() + return True async def delete_auth_by_id( self, id: str, db: AsyncSession | None = None, ) -> bool: - """Delete a user and their auth record in a single transaction.""" - try: # delete user first, then auth (FK order) - async with get_async_db_context(db) as session: - if not await Users.delete_user_by_id(id, db=session): - return False # user deletion failed — abort - await session.execute(delete(Auth).filter_by(id=id)) - await session.commit() - return True - except Exception: # db / integrity error - return False # partial deletion is rolled back by context manager + """Remove a user and their auth credential in one transaction.""" + async with get_async_db_context(db) as session: + if not await Users.delete_user_by_id(id, db=session): + return False + await session.execute(delete(Auth).where(Auth.id == id)) + await session.commit() + return True -Auths = AuthsTable() +Auths = AuthsTable() # singleton — module-level instance diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 002d3d50dc..3087350f50 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -6,7 +6,7 @@ import json import logging import time import uuid - +# local imports from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.automations import AutomationRun from open_webui.models.chat_messages import ChatMessage, ChatMessages @@ -39,16 +39,16 @@ from sqlalchemy.sql.expression import bindparam log = logging.getLogger(__name__) -class Chat(Base): +class Chat(Base): # database table mapping for chat entity __tablename__ = 'chat' id = Column(String, primary_key=True, unique=True) - user_id = Column(String) + user_id = Column(String, index=True) # owner user id title = Column(Text) # user-visible conversation title chat = Column(JSON) - created_at = Column(BigInteger) - updated_at = Column(BigInteger) + created_at = Column(BigInteger, index=True) # conversation creation timestamp + updated_at = Column(BigInteger, index=True) # conversation modification timestamp share_id = Column(Text, unique=True, nullable=True) # public share link token archived = Column(Boolean, default=False) # hidden from main chat list @@ -73,8 +73,7 @@ class Chat(Base): class ChatModel(BaseModel): - model_config = ConfigDict(from_attributes=True) - + model_config = ConfigDict(from_attributes=True) # allows ORM model binding id: str user_id: str title: str @@ -676,22 +675,21 @@ class ChatTable: await session.commit() await session.refresh(chat) return ChatModel.model_validate(chat) # return the updated original - + # refresh helper async def update_shared_chat_by_chat_id( self, chat_id: str, db: AsyncSession | None = None, ) -> ChatModel | None: - """Re-snapshot the shared copy with the latest chat content.""" - from open_webui.models.shared_chats import SharedChats # deferred — circular + """Refresh the shared snapshot with current chat content.""" + from open_webui.models.shared_chats import SharedChats - try: - async with get_async_db_context(db) as session: - chat = await session.get(Chat, chat_id) - if not chat or not chat.share_id: - return await self.insert_shared_chat_by_chat_id(chat_id, db=session) - await SharedChats.update(chat.share_id, db=session) - return ChatModel.model_validate(chat) - except Exception: - return + async with get_async_db_context(db) as session: + record = await session.get(Chat, chat_id) + if not record or not record.share_id: + return await self.insert_shared_chat_by_chat_id(chat_id, db=session) + await SharedChats.update(record.share_id, db=session) + return ChatModel.model_validate(record) + # unreachable — context manager above always returns + return async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool: """Delete shared snapshot for a chat.""" @@ -938,7 +936,7 @@ class ChatTable: ) all_chats = result.scalars().all() return [ChatModel.model_validate(chat) for chat in all_chats] - + # retrieve conversation async def get_chat_by_id( self, id: str, db: AsyncSession | None = None, ) -> ChatModel | None: @@ -1019,7 +1017,7 @@ class ChatTable: result = await session.execute(select(Chat).order_by(Chat.updated_at.desc())) all_chats = result.scalars().all() return [ChatModel.model_validate(chat) for chat in all_chats] - + # list user conversations async def get_chats_by_user_id( self, user_id: str, @@ -1067,7 +1065,7 @@ class ChatTable: 'total': total, } ) - + # list pinned chats async def get_pinned_chats_by_user_id( self, user_id: str, db: AsyncSession | None = None ) -> list[ChatTitleIdResponse]: @@ -1097,7 +1095,7 @@ class ChatTable: select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc()) ) return [ChatModel.model_validate(chat) for chat in result.scalars().all()] - + # search user conversations async def get_chats_by_user_id_and_search_text( self, user_id: str, @@ -1712,4 +1710,4 @@ class ChatTable: return row[0] -Chats = ChatTable() +Chats = ChatTable() # singleton chats repository diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 9f2451dda7..5e17e35f3f 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging import time - +# local imports from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.utils.misc import sanitize_metadata from pydantic import BaseModel, ConfigDict, model_validator @@ -14,10 +14,10 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -class File(Base): +class File(Base): # uploaded file record __tablename__ = 'file' id = Column(String, primary_key=True, unique=True) - user_id = Column(String) + user_id = Column(String, index=True) # owner user id hash = Column(Text, nullable=True) filename = Column(Text) # original upload filename @@ -26,7 +26,7 @@ class File(Base): data = Column(JSON, nullable=True) meta = Column(JSON, nullable=True) - created_at = Column(BigInteger) + created_at = Column(BigInteger, index=True) # upload timestamp updated_at = Column(BigInteger) @@ -46,7 +46,7 @@ class FileModel(BaseModel): created_at: int | None # timestamp in epoch updated_at: int | None # timestamp in epoch - +# --- metadata structures --- class FileMeta(BaseModel): name: str | None = None content_type: str | None = None @@ -410,4 +410,4 @@ class FilesTable: return False -Files = FilesTable() +Files = FilesTable() # singleton files repository diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 84afa1dda6..99677a62e6 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging import time - +# local imports from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.users import UserModel, UserResponse, Users from pydantic import BaseModel, ConfigDict @@ -14,17 +14,17 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -class Function(Base): +class Function(Base): # database table mapping __tablename__ = 'function' id = Column(String, primary_key=True, unique=True) - user_id = Column(String) - name = Column(Text) - type = Column(Text) - content = Column(Text) - meta = Column(JSONField) - valves = Column(JSONField) - is_active = Column(Boolean) + user_id = Column(String, index=True) # creator user id + name = Column(Text, nullable=False) # function identifier + type = Column(Text, nullable=False) # function type (pipe, filter, etc.) + content = Column(Text, nullable=True) # Python source code + meta = Column(JSONField, nullable=True) # function metadata + valves = Column(JSONField, nullable=True) # function configuration valves + is_active = Column(Boolean, default=False) # function activation status is_global = Column(Boolean) # if True, applied to every chat automatically updated_at = Column(BigInteger) # epoch seconds created_at = Column(BigInteger) # epoch seconds @@ -50,9 +50,9 @@ class FunctionModel(BaseModel): updated_at: int # timestamp in epoch created_at: int # timestamp in epoch - model_config = ConfigDict(from_attributes=True) - + model_config = ConfigDict(from_attributes=True) # allows ORM model binding +# --- form / schema definitions --- class FunctionWithValvesModel(BaseModel): id: str user_id: str @@ -425,4 +425,4 @@ class FunctionsTable: return False -Functions = FunctionsTable() +Functions = FunctionsTable() # singleton functions engine diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index 30cd6267b2..668d8c005f 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -12,8 +12,8 @@ from sqlalchemy import BigInteger, Column, String, Text, delete, select from sqlalchemy.ext.asyncio import AsyncSession -class Memory(Base): - """Persistent user memory backed by a vector collection.""" +class Memory(Base): # user memory store + """Stores user-created memory entries linked to a vector collection.""" __tablename__ = 'memory' @@ -32,8 +32,7 @@ class MemoryModel(BaseModel): content: str updated_at: int # timestamp in epoch created_at: int # timestamp in epoch - model_config = ConfigDict(from_attributes=True) - + model_config = ConfigDict(from_attributes=True) # allows ORM mapping class MemoriesTable: async def insert_new_memory( @@ -140,4 +139,4 @@ class MemoriesTable: return False -Memories = MemoriesTable() +Memories = MemoriesTable() # user memory registry diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 0e34f81378..896442d1c4 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -59,7 +59,7 @@ class ModelMeta(BaseModel): return data -class Model(Base): +class Model(Base): # provider model config """Workspace model entry — wraps an upstream LLM with custom params and metadata.""" __tablename__ = 'model' @@ -573,4 +573,4 @@ class ModelsTable: return [] -Models = ModelsTable() +Models = ModelsTable() # singleton model registry diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index ef305f7353..ac0d386743 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -17,14 +17,14 @@ from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, de from sqlalchemy.ext.asyncio import AsyncSession -class Prompt(Base): - """System prompt template with versioning support.""" +class Prompt(Base): # versioned template + """Slash-command prompt with history tracking and access control.""" __tablename__ = 'prompt' id = Column(Text, primary_key=True) command = Column(String, unique=True, index=True) - user_id = Column(String) + user_id = Column(String, index=True) # owner user id name = Column(Text) content = Column(Text) # the prompt template body data = Column(JSON, nullable=True) # structured prompt parameters @@ -51,10 +51,8 @@ class PromptModel(BaseModel): updated_at: int | None = None access_grants: list[AccessGrantModel] = Field(default_factory=list) - model_config = ConfigDict(from_attributes=True) - - -#################### + model_config = ConfigDict(from_attributes=True) # allows ORM model binding +# --- form / schema definitions --- # Forms #################### @@ -175,34 +173,33 @@ class PromptsTable: return async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None: - try: - async with get_async_db_context(db) as session: - result = await session.execute( - select(Prompt).filter_by(command=command), - ) - prompt = result.scalars().first() # None when no match - if not prompt: - return None - return await self._to_prompt_model(prompt, db=session) - except Exception: # connection / integrity error - return + """Look up a prompt by its unique slash-command string.""" + async with get_async_db_context(db) as session: + match = (await session.execute( + select(Prompt).where(Prompt.command == command) + )).scalars().first() + if match is None: + return + return await self._to_prompt_model(match, db=session) + # --- context manager always returns above --- + return async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]: + """Return all active prompts ordered by most recently updated.""" async with get_async_db_context(db) as session: - result = await session.execute( - select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()) - ) - all_prompts = result.scalars().all() + active = (await session.execute( + select(Prompt).where(Prompt.is_active.is_(True)).order_by(Prompt.updated_at.desc()) + )).scalars().all() - user_ids = list(set(prompt.user_id for prompt in all_prompts)) - prompt_ids = [prompt.id for prompt in all_prompts] + user_ids = list(set(p.user_id for p in active)) + prompt_ids = [p.id for p in active] users = await Users.get_users_by_user_ids(user_ids, db=session) if user_ids else [] - users_dict = {user.id: user for user in users} + users_dict = {u.id: u for u in users} grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session) prompts = [] - for prompt in all_prompts: + for prompt in active: user = users_dict.get(prompt.user_id) prompts.append( PromptUserResponse.model_validate( @@ -399,7 +396,9 @@ class PromptsTable: user_id: str, db: AsyncSession | None = None, ) -> PromptModel | None: - try: + if not command: + return None + try: # database transaction async with get_async_db_context(db) as session: result = await session.execute(select(Prompt).filter_by(command=command)) prompt = result.scalars().first() @@ -596,16 +595,16 @@ class PromptsTable: await session.commit() return await self._to_prompt_model(prompt, db=session) - except Exception: # connection error - return None - - # --- Active state management --- - + except Exception as e: # connection error + log.error(f"Failed to restore prompt version: {e}") + return None # restoration failed async def toggle_prompt_active( self, prompt_id: str, db: AsyncSession | None = None, ) -> PromptModel | None: - """Toggle the is_active flag on a prompt.""" - try: + """Flip the is_active flag on a prompt.""" + if not prompt_id: + return None + try: # activation state toggle async with get_async_db_context(db) as session: result = await session.execute(select(Prompt).filter_by(id=prompt_id)) prompt = result.scalars().first() @@ -650,8 +649,9 @@ class PromptsTable: await session.commit() return True return False - except Exception: - return False + except Exception as err: + log.error(f"Failed to delete prompt: {err}") + return False # deletion failed async def get_tags(self, db: AsyncSession | None = None) -> list[str]: try: @@ -695,4 +695,4 @@ class PromptsTable: return [] -Prompts = PromptsTable() +Prompts = PromptsTable() # singleton prompts registry diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index b4a2cff02c..5eb320dfe2 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -5,7 +5,7 @@ from __future__ import annotations import logging import time import uuid - +# local imports from open_webui.internal.db import Base, JSONField, get_async_db_context from pydantic import BaseModel, ConfigDict from sqlalchemy import JSON, BigInteger, Column, Index, PrimaryKeyConstraint, String, delete, select @@ -14,11 +14,11 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -class Tag(Base): +class Tag(Base): # database table mapping for tag entity __tablename__ = 'tag' id = Column(String) - name = Column(String) - user_id = Column(String) + name = Column(String, index=True) # tag label + user_id = Column(String, index=True) # user identifier meta = Column(JSON, nullable=True) __table_args__ = ( @@ -35,10 +35,9 @@ class TagModel(BaseModel): name: str user_id: str meta: dict | None = None - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict(from_attributes=True) # allows ORM model binding - -#################### +# --- tag schema forms --- # Forms #################### @@ -132,4 +131,4 @@ class TagTable: await db.commit() -Tags = TagTable() +Tags = TagTable() # singleton tag repository diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index 69f3a8d57d..3384d20807 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -4,7 +4,7 @@ from __future__ import annotations import logging import time - +# local imports from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.access_grants import AccessGrantModel, AccessGrants from open_webui.models.groups import Groups @@ -16,19 +16,19 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -class Tool(Base): +class Tool(Base): # database table definition __tablename__ = 'tool' id = Column(String, primary_key=True, unique=True) - user_id = Column(String) + user_id = Column(String, index=True) # owner user id name = Column(Text) # human-readable label content = Column(Text) # Python source code specs = Column(JSONField) # OpenAPI-style function specs meta = Column(JSONField) # description, manifest, etc. valves = Column(JSONField) # admin-configurable runtime parameters - updated_at = Column(BigInteger) - created_at = Column(BigInteger) + updated_at = Column(BigInteger, nullable=False) # modification timestamp + created_at = Column(BigInteger, index=True) # creation timestamp class ToolMeta(BaseModel): @@ -48,10 +48,9 @@ class ToolModel(BaseModel): updated_at: int # timestamp in epoch created_at: int # timestamp in epoch - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict(from_attributes=True) # enables ORM mapping - -#################### +# --- tool request forms --- # Forms #################### @@ -316,4 +315,4 @@ class ToolsTable: return False -Tools = ToolsTable() +Tools = ToolsTable() # singleton tool registry diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index e4aa6357d1..d28757e68f 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -43,17 +43,15 @@ class UserSettings(BaseModel): pass -class User(Base): - """Core user identity and profile record.""" +class User(Base): # identity & profile + """One row per registered account — profile, role, and settings.""" - __tablename__ = 'user' - - # Identity - id = Column(String, primary_key=True, unique=True) - email = Column(String) - username = Column(String(50), nullable=True) - role = Column(String) - name = Column(String) + __tablename__: str = 'user' # Identity & Credentials + id = Column(String, primary_key=True, unique=True) # unique user id + email = Column(String, unique=True) # user email address + username = Column(String(50), nullable=True) # custom handle + role = Column(String, default="pending") # permissions role + name = Column(String, nullable=False) # display name # Profile profile_image_url = Column(Text) # data-uri, path, or external URL @@ -119,16 +117,18 @@ class UserModel(BaseModel): model_config = ConfigDict( from_attributes=True, ) - - @model_validator(mode='after') # runs after all field validators + # validation schema logic + # --- model validators --- + @model_validator(mode='after') def _ensure_profile_image(self) -> 'UserModel': - """Fall back to a generated avatar when no explicit profile image is set.""" - if self.profile_image_url: # explicit image — nothing to do - return self - self.profile_image_url = _DEFAULT_PROFILE_IMAGE_URL.format( - user_id=self.id, + """Assign a generated avatar when no profile image is provided.""" + self.profile_image_url = ( + self.profile_image_url + or _DEFAULT_PROFILE_IMAGE_URL.format(user_id=self.id) ) - return self # modified in-place + return self + + class UserStatusModel(UserModel): @@ -302,97 +302,77 @@ class UsersTable: session.add(result) await session.commit() await session.refresh(result) - if result: - return user - else: - return None - + return user if result else None + # database read methods + # --- read / lookup operations --- async def get_user_by_id( self, id: str, db: AsyncSession | None = None, ) -> UserModel | None: """Fetch a single user by primary key.""" - try: # db.get is a PK lookup — very cheap - async with get_async_db_context(db) as session: - user = await session.get(User, id) - if not user: - return None - return UserModel.model_validate(user) - except Exception: # stale session / connection drop - return None # caller treats None as "not found" - + async with get_async_db_context(db) as session: + user = await session.get(User, id) + return UserModel.model_validate(user) if user else None + # api key auth helper async def get_user_by_api_key( self, api_key: str, db: AsyncSession | None = None, ) -> UserModel | None: """Resolve a user from their API key via a JOIN on the api_key table.""" - try: # single JOIN instead of two round-trips - async with get_async_db_context(db) as session: - result = await session.execute( - select(User) - .join(ApiKey, User.id == ApiKey.user_id) - .filter(ApiKey.key == api_key), - ) # emits: SELECT user.* FROM user JOIN api_key … - user = result.scalars().first() # None when key is invalid - if not user: - return None - return UserModel.model_validate(user) - except Exception: # stale session / connection drop - return + async with get_async_db_context(db) as session: + result = await session.execute( + select(User) + .join(ApiKey, User.id == ApiKey.user_id) + .where(ApiKey.key == api_key), + ) + user = result.scalars().first() + return UserModel.model_validate(user) if user else None async def get_user_by_email( self, email: str, db: AsyncSession | None = None, ) -> UserModel | None: - """Case-insensitive email lookup using SQL ``lower()``.""" - try: - async with get_async_db_context(db) as session: - stmt = select(User).filter( - func.lower(User.email) == email.lower(), - ) - row = (await session.execute(stmt)).scalars().first() - if not row: - return - return UserModel.model_validate(row) - except Exception: - return + """Case-insensitive email lookup using SQL lower().""" + async with get_async_db_context(db) as session: + email_filter = func.lower(User.email) == email.lower() + query = select(User).where(email_filter) + match = (await session.execute(query)).scalars().first() + if match is None: + return + return UserModel.model_validate(match) + # --- context manager above always returns --- + return + # --- oauth & integrations --- async def get_user_by_oauth_sub( self, provider: str, sub: str, db: AsyncSession | None = None, ) -> UserModel | None: """Look up a user by OAuth provider + subject claim (dialect-aware JSON filter).""" - try: - async with get_async_db_context(db) as session: - dialect = session.bind.dialect.name - query = select(User) - if dialect == 'sqlite': - oauth_match = User.oauth.contains({provider: {'sub': sub}}) - query = query.filter(oauth_match) - elif dialect == 'postgresql': - oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub - query = query.filter(oauth_match) - row = (await session.execute(query)).scalars().first() - if not row: - return - return UserModel.model_validate(row) - except Exception: - return + async with get_async_db_context(db) as session: + dialect = session.bind.dialect.name + query = select(User) + if dialect == 'sqlite': + oauth_match = User.oauth.contains({provider: {'sub': sub}}) + query = query.where(oauth_match) + elif dialect == 'postgresql': + oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub + query = query.where(oauth_match) + row = (await session.execute(query)).scalars().first() + return UserModel.model_validate(row) if row else None - async def get_user_by_scim_external_id(self, provider: str, external_id: str, db: AsyncSession | None = None) -> UserModel | None: + async def get_user_by_scim_external_id( + self, provider: str, external_id: str, db: AsyncSession | None = None, + ) -> UserModel | None: """Look up a user by SCIM provider + external ID (dialect-aware JSON filter).""" - try: - async with get_async_db_context(db) as session: - dialect = session.bind.dialect.name - query = select(User) - if dialect == 'sqlite': - scim_match = User.scim.contains({provider: {'external_id': external_id}}) - query = query.filter(scim_match) - elif dialect == 'postgresql': - scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id - query = query.filter(scim_match) - row = (await session.execute(query)).scalars().first() - if not row: - return - return UserModel.model_validate(row) - except Exception: - return + async with get_async_db_context(db) as session: + dialect = session.bind.dialect.name + query = select(User) + if dialect == 'sqlite': + scim_match = User.scim.contains({provider: {'external_id': external_id}}) + query = query.where(scim_match) + elif dialect == 'postgresql': + scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id + query = query.where(scim_match) + row = (await session.execute(query)).scalars().first() + return UserModel.model_validate(row) if row else None + async def get_users( self, filter: dict | None = None, skip: int | None = None, @@ -551,137 +531,95 @@ class UsersTable: result = await session.execute(select(User).filter(User.id.in_(user_ids))) users = result.scalars().all() return [UserModel.model_validate(user) for user in users] - + # count registered accounts async def get_num_users(self, db: AsyncSession | None = None) -> int | None: async with get_async_db_context(db) as session: result = await session.execute(select(func.count()).select_from(User)) return result.scalar() - + # check user existence async def has_users(self, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as session: result = await session.execute(select(exists(select(User)))) return result.scalar() - async def get_first_user(self, db: AsyncSession | None = None) -> UserModel: + async def get_first_user(self, db: AsyncSession | None = None) -> UserModel | None: """Return the earliest-created user (bootstrap admin detection).""" - try: - async with get_async_db_context(db) as session: - stmt = select(User).order_by(User.created_at).limit(1) - row = (await session.execute(stmt)).scalars().first() - if not row: - return - return UserModel.model_validate(row) - except Exception: - return + async with get_async_db_context(db) as session: + stmt = select(User).order_by(User.created_at).limit(1) + row = (await session.execute(stmt)).scalars().first() + return UserModel.model_validate(row) if row else None async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None: - try: - async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - - if user.settings is None: - return None - else: - return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None) - except Exception: + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if user and user.settings: + return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None) return None async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None: async with get_async_db_context(db) as session: - current_timestamp = int(datetime.datetime.now().timestamp()) + current_timestamp = int(time.time()) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) result = await session.execute( - select(func.count()).select_from(User).filter(User.last_active_at > today_midnight_timestamp) + select(func.count()).select_from(User).where(User.last_active_at > today_midnight_timestamp) ) return result.scalar() async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None: - try: - async with get_async_db_context(db) as session: - user = (await session.execute(select(User).filter_by(id=id))).scalars().first() - if not user: - return - user.role = role - await session.commit() - await session.refresh(user) - return UserModel.model_validate(user) - except Exception: - return + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if not user: + return None + user.role = role + await session.commit() + await session.refresh(user) + return UserModel.model_validate(user) async def update_user_status_by_id( self, id: str, form_data: UserStatus, db: AsyncSession | None = None ) -> UserModel | None: - try: - async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - if not user: - return None - for key, value in form_data.model_dump(exclude_none=True).items(): - setattr(user, key, value) - await session.commit() - await session.refresh(user) - return UserModel.model_validate(user) - except Exception: - return None + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if not user: + return None + for key, value in form_data.model_dump(exclude_none=True).items(): + setattr(user, key, value) + await session.commit() + await session.refresh(user) + return UserModel.model_validate(user) async def update_user_profile_image_url_by_id( self, id: str, profile_image_url: str, db: AsyncSession | None = None, ) -> UserModel | None: - try: - async with get_async_db_context(db) as session: - row = (await session.execute(select(User).filter_by(id=id))).scalars().first() - if row is None: - return - row.profile_image_url = profile_image_url - await session.commit() - await session.refresh(row) - return UserModel.model_validate(row) - except Exception: - return + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if user is None: + return None + user.profile_image_url = profile_image_url + await session.commit() + await session.refresh(user) + return UserModel.model_validate(user) @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None: - try: - async with get_async_db_context(db) as session: - await session.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time()))) - await session.commit() - except Exception: - pass + async with get_async_db_context(db) as session: + await session.execute(update(User).where(User.id == id).values(last_active_at=int(time.time()))) + await session.commit() async def update_user_oauth_by_id( self, id: str, provider: str, sub: str, db: AsyncSession | None = None ) -> UserModel | None: - """ - Update or insert an OAuth provider/sub pair into the user's oauth JSON field. - Example resulting structure: - { - "google": { "sub": "123" }, - "github": { "sub": "abc" } - } - """ - try: - async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - if not user: - return None - - # Load existing oauth JSON or create empty - oauth = user.oauth or {} - - # Update or insert provider entry - oauth[provider] = {'sub': sub} - - # Persist updated JSON - await session.execute(update(User).filter_by(id=id).values(oauth=oauth)) - await session.commit() - - return UserModel.model_validate(user) - - except Exception: - return None + """Update or insert an OAuth provider/sub pair into the user's oauth JSON field.""" + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if not user: + return None + oauth = dict(user.oauth or {}) + oauth[provider] = {'sub': sub} + user.oauth = oauth + await session.commit() + await session.refresh(user) + return UserModel.model_validate(user) async def update_user_scim_by_id( self, @@ -690,157 +628,101 @@ class UsersTable: external_id: str, db: AsyncSession | None = None, ) -> UserModel | None: - """ - Update or insert a SCIM provider/external_id pair into the user's scim JSON field. - Example resulting structure: - { - "microsoft": { "external_id": "abc" }, - "okta": { "external_id": "def" } - } - """ - try: - async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - if not user: - return None - - scim = user.scim or {} - scim[provider] = {'external_id': external_id} - - await session.execute(update(User).filter_by(id=id).values(scim=scim)) - await session.commit() - - return UserModel.model_validate(user) - - except Exception: - return None + """Update or insert a SCIM provider/external_id pair into the user's scim JSON field.""" + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if not user: + return None + scim = dict(user.scim or {}) + scim[provider] = {'external_id': external_id} + user.scim = scim + await session.commit() + await session.refresh(user) + return UserModel.model_validate(user) async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None: - try: - async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - if not user: - return None - for key, value in updated.items(): - setattr(user, key, value) - await session.commit() - await session.refresh(user) - return UserModel.model_validate(user) - except Exception as e: - print(e) - return None - + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if not user: + return None + for key, value in updated.items(): + setattr(user, key, value) + await session.commit() + await session.refresh(user) + return UserModel.model_validate(user) + # settings update helper async def update_user_settings_by_id( self, id: str, updated: dict, db: AsyncSession | None = None ) -> UserModel | None: - try: - async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - if not user: - return None - - user_settings = user.settings - - if user_settings is None: - user_settings = {} - - user_settings.update(updated) - - await session.execute(update(User).filter_by(id=id).values(settings=user_settings)) - await session.commit() - - result = await session.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - return UserModel.model_validate(user) - except Exception: - return None + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if not user: + return None + user_settings = dict(user.settings or {}) + user_settings.update(updated) + user.settings = user_settings + await session.commit() + await session.refresh(user) + return UserModel.model_validate(user) async def delete_user_by_id(self, id: str, db: AsyncSession | None = None) -> bool: - try: - from open_webui.models.chats import Chats - from open_webui.models.groups import Groups + from open_webui.models.chats import Chats + from open_webui.models.groups import Groups - # Remove User from Groups - await Groups.remove_user_from_all_groups(id) + # Remove User from Groups + await Groups.remove_user_from_all_groups(id) - # Delete User Chats - result = await Chats.delete_chats_by_user_id(id, db=session) - if result: - async with get_async_db_context(db) as session: - # Delete User - await session.execute(delete(User).filter_by(id=id)) - await session.commit() - - return True - else: - return False - except Exception: - return False + # Delete User Chats + async with get_async_db_context(db) as session: + deleted_chats = await Chats.delete_chats_by_user_id(id, db=session) + if not deleted_chats: + return False # chats deletion failed + await session.execute(delete(User).where(User.id == id)) + await session.commit() + return True async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None: - try: - async with get_async_db_context(db) as session: - result = await session.execute(select(ApiKey).filter_by(user_id=id)) - api_key = result.scalars().first() - return api_key.key if api_key else None - except Exception: - return None + async with get_async_db_context(db) as session: + api_key = (await session.execute(select(ApiKey).where(ApiKey.user_id == id))).scalars().first() + return api_key.key if api_key else None async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool: - try: - async with get_async_db_context(db) as session: - await session.execute(delete(ApiKey).filter_by(user_id=id)) - await session.commit() - - now = int(time.time()) - new_api_key = ApiKey( - id=f'key_{id}', - user_id=id, - key=api_key, - created_at=now, - updated_at=now, - ) - session.add(new_api_key) - await session.commit() - - return True - - except Exception: - return False + async with get_async_db_context(db) as session: + await session.execute(delete(ApiKey).where(ApiKey.user_id == id)) + now_ts = int(time.time()) + new_key = ApiKey( + id=f'key_{id}', + user_id=id, + key=api_key, + created_at=now_ts, + updated_at=now_ts, + ) + session.add(new_key) + await session.commit() + return True async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool: - try: - async with get_async_db_context(db) as session: - await session.execute(delete(ApiKey).filter_by(user_id=id)) - await session.commit() - return True - except Exception: - return False + async with get_async_db_context(db) as session: + await session.execute(delete(ApiKey).where(ApiKey.user_id == id)) + await session.commit() + return True async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]: async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter(User.id.in_(user_ids))) - users = result.scalars().all() - return [user.id for user in users] + result = await session.execute(select(User).where(User.id.in_(user_ids))) + return [u.id for u in result.scalars().all()] async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None: async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter_by(role='admin').limit(1)) - user = result.scalars().first() - if user: - return UserModel.model_validate(user) - else: - return None + row = (await session.execute(select(User).where(User.role == 'admin').limit(1))).scalars().first() + return UserModel.model_validate(row) if row else None async def get_active_user_count(self, db: AsyncSession | None = None) -> int: async with get_async_db_context(db) as session: # Consider user active if last_active_at within the last 3 minutes three_minutes_ago = int(time.time()) - 180 result = await session.execute( - select(func.count()).select_from(User).filter(User.last_active_at >= three_minutes_ago) + select(func.count()).select_from(User).where(User.last_active_at >= three_minutes_ago) ) return result.scalar() @@ -854,8 +736,7 @@ class UsersTable: async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool: async with get_async_db_context(db) as session: - result = await session.execute(select(User).filter_by(id=user_id)) - user = result.scalars().first() + user = await session.get(User, user_id) if user and user.last_active_at: # Consider user active if last_active_at within the last 3 minutes three_minutes_ago = int(time.time()) - 180 @@ -863,4 +744,5 @@ class UsersTable: return False -Users = UsersTable() +Users = UsersTable() # singleton user repository + diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index e4065de8cc..c9fca5a12f 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -383,7 +383,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user), db: Asy @router.post('/user/info/update', response_model=dict | None) -async def update_user_info_by_session_user( +async def update_user_info_by_session_user( # PATCH-style merge form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): """Merge caller-supplied fields into the current user's info dict. @@ -538,9 +538,8 @@ async def get_user_active_status_by_id( @router.post('/{user_id}/update', response_model=UserModel | None) async def update_user_by_id( - user_id: str, - form_data: UserUpdateForm, - session_user=Depends(get_admin_user), + user_id: str, form_data: UserUpdateForm, + session_user: UserModel = Depends(get_admin_user), db: AsyncSession = Depends(get_async_session), ): # Prevent modification of the primary admin user by other admins diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index 73a4811ab7..9295d624c9 100644 --- a/backend/open_webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -100,17 +100,17 @@ async def download_db(user=Depends(get_admin_user)): status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - - from open_webui.internal.db import engine # deferred import - - if engine.name != 'sqlite': # only SQLite DBs can be downloaded as a file + # --- resolve target database engine --- + from open_webui.internal.db import engine # lazy — avoids circular at import time + if engine.name != 'sqlite': # non-SQLite backends use pg_dump / managed exports raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DB_NOT_SQLITE, ) return FileResponse( - engine.url.database, + str(engine.url.database), + media_type='application/octet-stream', filename='webui.db', )