From 260ead64da38cc0ba47e48ad352df50e215ea38b Mon Sep 17 00:00:00 2001 From: Timothy Jaeryang Baek Date: Thu, 21 May 2026 14:01:57 +0400 Subject: [PATCH] refac --- backend/open_webui/internal/db.py | 9 +- backend/open_webui/main.py | 9 +- backend/open_webui/models/auths.py | 181 ++++++------ backend/open_webui/models/chats.py | 358 ++++++++++++------------ backend/open_webui/models/files.py | 30 +- backend/open_webui/models/functions.py | 17 +- backend/open_webui/models/memories.py | 56 ++-- backend/open_webui/models/models.py | 50 +--- backend/open_webui/models/prompts.py | 296 ++++++++++---------- backend/open_webui/models/tags.py | 31 +-- backend/open_webui/models/tools.py | 36 +-- backend/open_webui/models/users.py | 368 ++++++++++++++----------- backend/open_webui/routers/chats.py | 17 +- backend/open_webui/routers/memories.py | 14 +- backend/open_webui/routers/models.py | 17 +- backend/open_webui/routers/prompts.py | 8 +- backend/open_webui/routers/tools.py | 18 +- backend/open_webui/routers/users.py | 21 +- backend/open_webui/routers/utils.py | 7 +- 19 files changed, 772 insertions(+), 771 deletions(-) diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index 69050160a8..735aaf1ba4 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -118,8 +118,15 @@ reattach_ssl_mode_to_url = reattach_ssl_params_to_url class JSONField(types.TypeDecorator): + """Store arbitrary Python objects as JSON-encoded TEXT. + + Used instead of native JSON columns for portability across SQLite and + PostgreSQL. Values are serialized with ``json.dumps`` on write and + deserialized with ``json.loads`` on read. + """ + impl = types.Text - cache_ok = True + cache_ok = True # safe for statement caching (no per-instance state) def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any: return json.dumps(value) diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 21ad95e1a6..3a6790f4fc 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -2912,10 +2912,12 @@ async def readiness_check(): @app.get('/health/db') async def healthcheck_with_db(): + """Verify database connectivity by issuing a lightweight ping.""" await async_db_ping() return {'status': True} +# Serve build-time static assets (CSS, JS, images, favicon, etc.) app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static') @@ -2924,8 +2926,13 @@ async def serve_cache_file( path: str, user=Depends(get_verified_user), ): + """Serve cached files (e.g. tool outputs) with path-traversal protection. + + Only ``image/*``, ``audio/*``, and ``video/*`` MIME types are served inline; + everything else gets a ``Content-Disposition: attachment`` header to prevent + XSS from user-generated HTML stored in the cache directory. + """ file_path = os.path.abspath(os.path.join(CACHE_DIR, path)) - # prevent path traversal if not file_path.startswith(os.path.abspath(CACHE_DIR)): raise HTTPException(status_code=404, detail='File not found') if not os.path.isfile(file_path): diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index 253c662029..01b42263f9 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -1,3 +1,5 @@ +"""Authentication models and database operations.""" + from __future__ import annotations import logging @@ -13,33 +15,30 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -#################### -# DB MODEL -#################### - class Auth(Base): + """Credential record linking a user identity to an email + hashed password.""" + __tablename__ = 'auth' - id = Column(String, primary_key=True, unique=True) - email = Column(String) - password = Column(Text) - active = Column(Boolean) + id = Column(String, primary_key=True, unique=True) # same as User.id + email = Column(String) # login email, kept in sync with User.email + password = Column(Text) # bcrypt / argon2 hash + active = Column(Boolean) # soft-disable flag class AuthModel(BaseModel): + """Pydantic mirror of the Auth table row.""" + id: str email: str password: str active: bool = True -#################### -# Forms -#################### - - class Token(BaseModel): + """JWT bearer token response.""" + token: str token_type: str @@ -100,112 +99,130 @@ class AuthsTable: oauth: dict | None = None, db: AsyncSession | None = None, ) -> UserModel | None: + """Create an Auth + User pair in a single transaction.""" async with get_async_db_context(db) as db: log.info('insert_new_auth') - id = str(uuid.uuid4()) + user_id = str(uuid.uuid4()) - auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True}) - result = Auth(**auth.model_dump()) - db.add(result) + record = Auth( + id=user_id, + email=email, + password=password, + active=True, + ) + db.add(record) - user = await Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db) + user = await Users.insert_new_user( + user_id, name, email, profile_image_url, role, oauth=oauth, db=db, + ) await db.commit() - await db.refresh(result) + await db.refresh(record) - if result and user: - return user - else: - return None + return user if record and user else None async def authenticate_user( self, email: str, verify_password: callable, db: AsyncSession | None = None ) -> UserModel | None: + """Verify a user's email + password and return the user on success.""" log.info(f'authenticate_user: {email}') user = await Users.get_user_by_email(email, db=db) if not user: return None - try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Auth).filter_by(id=user.id, active=True)) - auth = result.scalars().first() - if auth: - if verify_password(auth.password): - return user - else: - return None - else: + try: # load the auth row for password verification + async with get_async_db_context(db) as session: + auth = await session.get(Auth, user.id) + if not auth or not auth.active: return None + if not verify_password(auth.password): + return None + return user except Exception: - return None - - async def authenticate_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None: - log.info(f'authenticate_user_by_api_key') - # if no api_key, return None - if not api_key: - return None + return + async def authenticate_user_by_api_key( + self, api_key: str, db: AsyncSession | None = None, + ) -> UserModel | None: + """Resolve an API key to its owning user, returning ``None`` on miss.""" + log.info('authenticate_user_by_api_key') + if not api_key: # empty / None key — reject immediately + return try: - user = await Users.get_user_by_api_key(api_key, db=db) - return user if user else None + return await Users.get_user_by_api_key(api_key, db=db) except Exception: return False - async def authenticate_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None: - log.info(f'authenticate_user_by_email: {email}') + async def authenticate_user_by_email( + self, + email: str, + db: AsyncSession | None = None, + ) -> UserModel | None: + """One-query authentication: JOIN Auth ↔ User, filter by email + active flag.""" + log.info('authenticate_user_by_email: %s', email) + try: - async with get_async_db_context(db) as db: - # Single JOIN query instead of two separate queries - result = await db.execute( - select(Auth, User).join(User, Auth.id == User.id).filter(Auth.email == email, Auth.active == True) + async with get_async_db_context(db) as session: + stmt = ( + select(Auth, User) + .join(User, Auth.id == User.id) + .filter(Auth.email == email, Auth.active == True) ) - row = result.first() - if row: - _, user = row - return UserModel.model_validate(user) - return None + row = (await session.execute(stmt)).first() + if not row: + return + _auth, matched_user = row + return UserModel.model_validate(matched_user) except Exception: - return None + return - async def update_user_password_by_id(self, id: str, new_password: str, db: AsyncSession | None = None) -> bool: + async def update_user_password_by_id( + self, + id: str, + new_password: str, + db: AsyncSession | None = None, + ) -> bool: + """Hash-swap: replace the stored password hash for a given user.""" try: - async with get_async_db_context(db) as db: - result = await db.execute(update(Auth).filter_by(id=id).values(password=new_password)) - await db.commit() - return True if result.rowcount == 1 else False + async with get_async_db_context(db) as session: + stmt = update(Auth).filter_by(id=id).values(password=new_password) + result = await session.execute(stmt) + await session.commit() + return result.rowcount == 1 except Exception: return False - async def update_email_by_id(self, id: str, email: str, db: AsyncSession | None = None) -> bool: + async def update_email_by_id( + self, id: str, email: str, db: AsyncSession | None = None, + ) -> bool: + """Update the auth email and propagate the change to the User table.""" try: - async with get_async_db_context(db) as db: - result = await db.execute(update(Auth).filter_by(id=id).values(email=email)) - await db.commit() - if result.rowcount == 1: - await Users.update_user_by_id(id, {'email': email}, db=db) - return True - return False - except Exception: - return False - - async def delete_auth_by_id(self, id: str, db: AsyncSession | None = None) -> bool: - try: - async with get_async_db_context(db) as db: - # Delete User - result = await Users.delete_user_by_id(id, db=db) - - if result: - await db.execute(delete(Auth).filter_by(id=id)) - await db.commit() - - return True - else: + async with get_async_db_context(db) as session: + stmt = update(Auth).filter_by(id=id).values(email=email) + result = await session.execute(stmt) + await session.commit() + if result.rowcount != 1: return False + await Users.update_user_by_id(id, {'email': email}, db=session) + return True except Exception: return False + async def delete_auth_by_id( + self, id: str, db: AsyncSession | None = None, + ) -> bool: + """Delete a user and their auth record in a single transaction.""" + try: # delete user first, then auth (FK order) + async with get_async_db_context(db) as session: + if not await Users.delete_user_by_id(id, db=session): + return False # user deletion failed — abort + await session.execute(delete(Auth).filter_by(id=id)) + await session.commit() + return True + except Exception: # db / integrity error + return False # partial deletion is rolled back by context manager + Auths = AuthsTable() diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 4a40a049a0..002d3d50dc 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -1,10 +1,11 @@ +"""Chat models, forms, and database operations.""" + from __future__ import annotations import json import logging import time import uuid -from typing import Optional from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.automations import AutomationRun @@ -35,12 +36,6 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import exists from sqlalchemy.sql.expression import bindparam -#################### -# Chat DB Schema -# Let no word spoken in this house be lost, and when the -# record is read again, let it still serve the one who spoke. -#################### - log = logging.getLogger(__name__) @@ -49,14 +44,14 @@ class Chat(Base): id = Column(String, primary_key=True, unique=True) user_id = Column(String) - title = Column(Text) + title = Column(Text) # user-visible conversation title chat = Column(JSON) created_at = Column(BigInteger) updated_at = Column(BigInteger) - share_id = Column(Text, unique=True, nullable=True) - archived = Column(Boolean, default=False) + share_id = Column(Text, unique=True, nullable=True) # public share link token + archived = Column(Boolean, default=False) # hidden from main chat list pinned = Column(Boolean, default=False, nullable=True) meta = Column(JSON, server_default='{}') @@ -302,7 +297,7 @@ class ChatTable: async def insert_new_chat( self, id: str, user_id: str, form_data: ChatForm, db: AsyncSession | None = None ) -> ChatModel | None: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: chat = ChatModel( **{ 'id': id, @@ -318,9 +313,9 @@ class ChatTable: ) chat_item = Chat(**chat.model_dump()) - db.add(chat_item) - await db.commit() - await db.refresh(chat_item) + session.add(chat_item) + await session.commit() + await session.refresh(chat_item) # Dual-write initial messages to chat_message table try: @@ -362,7 +357,7 @@ class ChatTable: chat_import_forms: list[ChatImportForm], db: AsyncSession | None = None, ) -> list[ChatModel]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: chats = [] for form_data in chat_import_forms: @@ -370,7 +365,7 @@ class ChatTable: chats.append(Chat(**chat.model_dump())) db.add_all(chats) - await db.commit() + await session.commit() # Dual-write messages to chat_message table for form_data, chat_obj in zip(chat_import_forms, chats): @@ -390,10 +385,13 @@ class ChatTable: return [ChatModel.model_validate(chat) for chat in chats] - async def update_chat_by_id(self, id: str, chat: dict, db: AsyncSession | None = None) -> ChatModel | None: - try: - async with get_async_db_context(db) as db: - chat_item = await db.get(Chat, id) + async def update_chat_by_id( + self, id: str, chat: dict, db: AsyncSession | None = None, + ) -> ChatModel | None: + """Persist updated chat content, sanitizing null bytes.""" + try: # load the chat record for in-place mutation + async with get_async_db_context(db) as session: + chat_item = await session.get(Chat, id) if chat_item is None: return None @@ -402,19 +400,19 @@ class ChatTable: chat_item.updated_at = int(time.time()) - await db.commit() + await session.commit() return ChatModel.model_validate(chat_item) except Exception: - return None + return async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - chat = await db.get(Chat, id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, id) if chat and chat.user_id == user_id: chat.last_read_at = int(time.time()) - await db.commit() + await session.commit() return True return False except Exception: @@ -423,22 +421,22 @@ class ChatTable: async def update_chat_title_by_id(self, id: str, title: str) -> ChatModel | None: try: async with get_async_db_context() as db: - chat_item = await db.get(Chat, id) + chat_item = await session.get(Chat, id) if chat_item is None: return None clean_title = self._clean_null_bytes(title) chat_item.title = clean_title chat_item.chat = {**(chat_item.chat or {}), 'title': clean_title} chat_item.updated_at = int(time.time()) - await db.commit() - await db.refresh(chat_item) + await session.commit() + await session.refresh(chat_item) return ChatModel.model_validate(chat_item) except Exception: return None async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> ChatModel | None: async with get_async_db_context() as db: - chat = await db.get(Chat, id) + chat = await session.get(Chat, id) if chat is None: return None @@ -448,22 +446,22 @@ class ChatTable: # Single meta update chat.meta = {**chat.meta, 'tags': new_tag_ids} - await db.commit() - await db.refresh(chat) + await session.commit() + await session.refresh(chat) # Batch-create any missing tag rows - await Tags.ensure_tags_exist(new_tags, user.id, db=db) + await Tags.ensure_tags_exist(new_tags, user.id, db=session) # Clean up orphaned old tags in one query removed = set(old_tags) - set(new_tag_ids) if removed: - await self.delete_orphan_tags_for_user(list(removed), user.id, db=db) + await self.delete_orphan_tags_for_user(list(removed), user.id, db=session) return ChatModel.model_validate(chat) async def get_chat_title_by_id(self, id: str) -> str | None: async with get_async_db_context() as db: - result = await db.execute(select(Chat.title).filter_by(id=id)) + result = await session.execute(select(Chat.title).filter_by(id=id)) row = result.first() if row is None: return None @@ -638,7 +636,7 @@ class ChatTable: async def add_message_files_by_id_and_message_id(self, id: str, message_id: str, files: list[dict]) -> list[dict]: async with get_async_db_context() as db: - chat = await self.get_chat_by_id(id, db=db) + chat = await self.get_chat_by_id(id, db=session) if chat is None: return None @@ -653,61 +651,62 @@ class ChatTable: history['messages'][message_id]['files'] = message_files chat['history'] = history - await self.update_chat_by_id(id, chat, db=db) + await self.update_chat_by_id(id, chat, db=session) return message_files async def insert_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None: """Create a shared snapshot for a chat. Returns the original chat with share_id set.""" from open_webui.models.shared_chats import SharedChats - async with get_async_db_context(db) as db: - chat = await db.get(Chat, chat_id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, chat_id) if not chat: return None # If already shared, just update the existing snapshot if chat.share_id: - return await self.update_shared_chat_by_chat_id(chat_id, db=db) + return await self.update_shared_chat_by_chat_id(chat_id, db=session) - shared = await SharedChats.create(chat_id, chat.user_id, db=db) + shared = await SharedChats.create(chat_id, chat.user_id, db=session) if not shared: return None # Set share_id on the original chat chat.share_id = shared.id - await db.commit() - await db.refresh(chat) - return ChatModel.model_validate(chat) + await session.commit() + await session.refresh(chat) + return ChatModel.model_validate(chat) # return the updated original - async def update_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None: - """Re-snapshot the shared chat with current chat data.""" - from open_webui.models.shared_chats import SharedChats + async def update_shared_chat_by_chat_id( + self, chat_id: str, db: AsyncSession | None = None, + ) -> ChatModel | None: + """Re-snapshot the shared copy with the latest chat content.""" + from open_webui.models.shared_chats import SharedChats # deferred — circular try: - async with get_async_db_context(db) as db: - chat = await db.get(Chat, chat_id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, chat_id) if not chat or not chat.share_id: - return await self.insert_shared_chat_by_chat_id(chat_id, db=db) - - await SharedChats.update(chat.share_id, db=db) + return await self.insert_shared_chat_by_chat_id(chat_id, db=session) + await SharedChats.update(chat.share_id, db=session) return ChatModel.model_validate(chat) except Exception: - return None + return async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool: """Delete shared snapshot for a chat.""" from open_webui.models.shared_chats import SharedChats try: - return await SharedChats.delete_by_chat_id(chat_id, db=db) + return await SharedChats.delete_by_chat_id(chat_id, db=session) except Exception: return False async def unarchive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=False)) - await db.commit() + async with get_async_db_context(db) as session: + await session.execute(update(Chat).filter_by(user_id=user_id).values(archived=False)) + await session.commit() return True except Exception: return False @@ -716,45 +715,45 @@ class ChatTable: self, id: str, share_id: str | None, db: AsyncSession | None = None ) -> ChatModel | None: try: - async with get_async_db_context(db) as db: - chat = await db.get(Chat, id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, id) chat.share_id = share_id - await db.commit() - await db.refresh(chat) + await session.commit() + await session.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None async def toggle_chat_pinned_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None: try: - async with get_async_db_context(db) as db: - chat = await db.get(Chat, id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, id) chat.pinned = not chat.pinned chat.updated_at = int(time.time()) - await db.commit() - await db.refresh(chat) + await session.commit() + await session.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None async def toggle_chat_archive_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None: try: - async with get_async_db_context(db) as db: - chat = await db.get(Chat, id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, id) chat.archived = not chat.archived chat.folder_id = None chat.updated_at = int(time.time()) - await db.commit() - await db.refresh(chat) + await session.commit() + await session.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None async def archive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=True)) - await db.commit() + async with get_async_db_context(db) as session: + await session.execute(update(Chat).filter_by(user_id=user_id).values(archived=True)) + await session.commit() return True except Exception: return False @@ -767,7 +766,7 @@ class ChatTable: limit: int = 50, db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by( user_id=user_id, archived=True ) @@ -798,7 +797,7 @@ class ChatTable: if limit: stmt = stmt.limit(limit) - result = await db.execute(stmt) + result = await session.execute(stmt) all_chats = result.all() return [ ChatTitleIdResponse.model_validate( @@ -823,7 +822,7 @@ class ChatTable: """Delegate to SharedChats for listing shared chats by user.""" from open_webui.models.shared_chats import SharedChats - return await SharedChats.get_by_user_id(user_id, filter=filter, skip=skip, limit=limit, db=db) + return await SharedChats.get_by_user_id(user_id, filter=filter, skip=skip, limit=limit, db=session) async def get_chat_list_by_user_id( self, @@ -834,7 +833,7 @@ class ChatTable: limit: int = 50, db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( user_id=user_id ) @@ -864,7 +863,7 @@ class ChatTable: if limit: stmt = stmt.limit(limit) - result = await db.execute(stmt) + result = await session.execute(stmt) all_chats = result.all() return [ ChatTitleIdResponse.model_validate( @@ -889,7 +888,7 @@ class ChatTable: limit: int | None = None, db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( user_id=user_id ) @@ -910,7 +909,7 @@ class ChatTable: if limit: stmt = stmt.limit(limit) - result = await db.execute(stmt) + result = await session.execute(stmt) all_chats = result.all() return [ @@ -933,23 +932,26 @@ class ChatTable: limit: int = 50, db: AsyncSession | None = None, ) -> list[ChatModel]: - async with get_async_db_context(db) as db: - result = await db.execute( + async with get_async_db_context(db) as session: + result = await session.execute( select(Chat).filter(Chat.id.in_(chat_ids)).filter_by(archived=False).order_by(Chat.updated_at.desc()) ) all_chats = result.scalars().all() return [ChatModel.model_validate(chat) for chat in all_chats] - async def get_chat_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None: + async def get_chat_by_id( + self, id: str, db: AsyncSession | None = None, + ) -> ChatModel | None: + """Fetch a chat by PK, auto-sanitizing null bytes on read.""" try: - async with get_async_db_context(db) as db: - chat_item = await db.get(Chat, id) + async with get_async_db_context(db) as session: + chat_item = await session.get(Chat, id) if chat_item is None: return None if self._sanitize_chat_row(chat_item): - await db.commit() - await db.refresh(chat_item) + await session.commit() + await session.refresh(chat_item) return ChatModel.model_validate(chat_item) except Exception: @@ -960,7 +962,7 @@ class ChatTable: from open_webui.models.shared_chats import SharedChats try: - shared = await SharedChats.get_by_id(id, db=db) + shared = await SharedChats.get_by_id(id, db=session) if shared: # Return a ChatModel-compatible view of the snapshot return ChatModel( @@ -980,8 +982,8 @@ class ChatTable: self, id: str, user_id: str, db: AsyncSession | None = None ) -> ChatModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Chat).filter_by(id=id, user_id=user_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Chat).filter_by(id=id, user_id=user_id)) chat = result.scalars().first() return ChatModel.model_validate(chat) if chat else None except Exception: @@ -993,8 +995,8 @@ class ChatTable: the full Chat row (which includes the potentially large JSON blob). """ try: - async with get_async_db_context(db) as db: - result = await db.execute(select(exists().where(and_(Chat.id == id, Chat.user_id == user_id)))) + async with get_async_db_context(db) as session: + result = await session.execute(select(exists().where(and_(Chat.id == id, Chat.user_id == user_id)))) return result.scalar() except Exception: return False @@ -1005,16 +1007,16 @@ class ChatTable: JSON blob. Returns None if chat doesn't exist or doesn't belong to user. """ try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Chat.folder_id).filter_by(id=id, user_id=user_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Chat.folder_id).filter_by(id=id, user_id=user_id)) row = result.first() return row[0] if row else None except Exception: return None async def get_chats(self, skip: int = 0, limit: int = 50, db: AsyncSession | None = None) -> list[ChatModel]: - async with get_async_db_context(db) as db: - result = await db.execute(select(Chat).order_by(Chat.updated_at.desc())) + async with get_async_db_context(db) as session: + result = await session.execute(select(Chat).order_by(Chat.updated_at.desc())) all_chats = result.scalars().all() return [ChatModel.model_validate(chat) for chat in all_chats] @@ -1026,7 +1028,7 @@ class ChatTable: limit: int | None = None, db: AsyncSession | None = None, ) -> ChatListResponse: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = select(Chat).filter_by(user_id=user_id) if filter: @@ -1048,7 +1050,7 @@ class ChatTable: else: stmt = stmt.order_by(Chat.updated_at.desc(), Chat.id) - count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) + count_result = await session.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip is not None: @@ -1056,7 +1058,7 @@ class ChatTable: if limit is not None: stmt = stmt.limit(limit) - result = await db.execute(stmt) + result = await session.execute(stmt) all_chats = result.scalars().all() return ChatListResponse( @@ -1069,8 +1071,8 @@ class ChatTable: async def get_pinned_chats_by_user_id( self, user_id: str, db: AsyncSession | None = None ) -> list[ChatTitleIdResponse]: - async with get_async_db_context(db) as db: - result = await db.execute( + async with get_async_db_context(db) as session: + result = await session.execute( select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at) .filter_by(user_id=user_id, pinned=True, archived=False) .order_by(Chat.updated_at.desc()) @@ -1090,8 +1092,8 @@ class ChatTable: ] async def get_archived_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[ChatModel]: - async with get_async_db_context(db) as db: - result = await db.execute( + async with get_async_db_context(db) as session: + result = await session.execute( select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc()) ) return [ChatModel.model_validate(chat) for chat in result.scalars().all()] @@ -1161,7 +1163,7 @@ class ChatTable: search_text = ' '.join(search_text_words) - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = select(Chat).filter(Chat.user_id == user_id) if is_archived is not None: @@ -1184,7 +1186,7 @@ class ChatTable: stmt = stmt.order_by(Chat.updated_at.desc(), Chat.id) # Check if the database dialect is either 'sqlite' or 'postgresql' - bind = await db.connection() + bind = await session.connection() dialect_name = bind.dialect.name if dialect_name == 'sqlite': # SQLite case: using JSON1 extension for JSON searching @@ -1282,7 +1284,7 @@ class ChatTable: # Perform pagination at the SQL level stmt = stmt.offset(skip).limit(limit) - result = await db.execute(stmt) + result = await session.execute(stmt) all_chats = result.scalars().all() log.info(f'The number of chats: {len(all_chats)}') @@ -1298,7 +1300,7 @@ class ChatTable: limit: int = 60, db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = ( select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at) .filter_by(folder_id=folder_id, user_id=user_id) @@ -1312,7 +1314,7 @@ class ChatTable: if limit: stmt = stmt.limit(limit) - result = await db.execute(stmt) + result = await session.execute(stmt) all_chats = result.all() return [ ChatTitleIdResponse.model_validate( @@ -1330,7 +1332,7 @@ class ChatTable: async def get_chats_by_folder_ids_and_user_id( self, folder_ids: list[str], user_id: str, db: AsyncSession | None = None ) -> list[ChatModel]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = ( select(Chat) .filter(Chat.folder_id.in_(folder_ids), Chat.user_id == user_id) @@ -1339,7 +1341,7 @@ class ChatTable: .order_by(Chat.updated_at.desc()) ) - result = await db.execute(stmt) + result = await session.execute(stmt) all_chats = result.scalars().all() return [ChatModel.model_validate(chat) for chat in all_chats] @@ -1347,13 +1349,13 @@ class ChatTable: self, id: str, user_id: str, folder_id: str, db: AsyncSession | None = None ) -> ChatModel | None: try: - async with get_async_db_context(db) as db: - chat = await db.get(Chat, id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, id) chat.folder_id = folder_id chat.updated_at = int(time.time()) chat.pinned = False - await db.commit() - await db.refresh(chat) + await session.commit() + await session.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None @@ -1361,12 +1363,12 @@ class ChatTable: async def get_chat_tags_by_id_and_user_id( self, id: str, user_id: str, db: AsyncSession | None = None ) -> list[TagModel]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = select(Chat.meta).where(Chat.id == id) - result = await db.execute(stmt) + result = await session.execute(stmt) meta = result.scalar_one_or_none() tag_ids = (meta or {}).get('tags', []) - return await Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=db) + return await Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=session) async def get_chat_list_by_user_id_and_tag_name( self, @@ -1376,13 +1378,13 @@ class ChatTable: limit: int = 50, db: AsyncSession | None = None, ) -> list[ChatTitleIdResponse]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( user_id=user_id ) tag_id = tag_name.replace(' ', '_').lower() - bind = await db.connection() + bind = await session.connection() dialect_name = bind.dialect.name log.info(f'DB dialect name: {dialect_name}') if dialect_name == 'sqlite': @@ -1403,7 +1405,7 @@ class ChatTable: if limit: stmt = stmt.limit(limit) - result = await db.execute(stmt) + result = await session.execute(stmt) all_chats = result.all() return [ ChatTitleIdResponse.model_validate( @@ -1422,17 +1424,17 @@ class ChatTable: self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None ) -> ChatModel | None: tag_id = tag_name.replace(' ', '_').lower() - await Tags.ensure_tags_exist([tag_name], user_id, db=db) + await Tags.ensure_tags_exist([tag_name], user_id, db=session) try: - async with get_async_db_context(db) as db: - chat = await db.get(Chat, id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, id) if tag_id not in chat.meta.get('tags', []): chat.meta = { **chat.meta, 'tags': list(set(chat.meta.get('tags', []) + [tag_id])), } - await db.commit() - await db.refresh(chat) + await session.commit() + await session.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None @@ -1440,11 +1442,11 @@ class ChatTable: async def count_chats_by_tag_name_and_user_id( self, tag_name: str, user_id: str, db: AsyncSession | None = None ) -> int: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: stmt = select(func.count(Chat.id)).filter_by(user_id=user_id, archived=False) tag_id = tag_name.replace(' ', '_').lower() - bind = await db.connection() + bind = await session.connection() dialect_name = bind.dialect.name if dialect_name == 'sqlite': stmt = stmt.filter( @@ -1457,7 +1459,7 @@ class ChatTable: else: raise NotImplementedError(f'Unsupported dialect: {dialect_name}') - result = await db.execute(stmt) + result = await session.execute(stmt) return result.scalar() async def delete_orphan_tags_for_user( @@ -1477,19 +1479,19 @@ class ChatTable: """ if not tag_ids: return - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: orphans = [] for tag_id in tag_ids: - count = await self.count_chats_by_tag_name_and_user_id(tag_id, user_id, db=db) + count = await self.count_chats_by_tag_name_and_user_id(tag_id, user_id, db=session) if count <= threshold: orphans.append(tag_id) - await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db) + await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=session) async def count_chats_by_folder_id_and_user_id( self, folder_id: str, user_id: str, db: AsyncSession | None = None ) -> int: - async with get_async_db_context(db) as db: - result = await db.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id)) count = result.scalar() log.info(f"Count of chats for folder '{folder_id}': {count}") @@ -1499,8 +1501,8 @@ class ChatTable: self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None ) -> bool: try: - async with get_async_db_context(db) as db: - chat = await db.get(Chat, id) + async with get_async_db_context(db) as session: + chat = await session.get(Chat, id) tags = chat.meta.get('tags', []) tag_id = tag_name.replace(' ', '_').lower() @@ -1509,51 +1511,51 @@ class ChatTable: **chat.meta, 'tags': list(set(tags)), } - await db.commit() + await session.commit() return True except Exception: return False async def delete_chat_by_id(self, id: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)) - await db.execute(delete(ChatMessage).filter_by(chat_id=id)) - await db.execute(delete(Chat).filter_by(id=id)) - await db.commit() + async with get_async_db_context(db) as session: + await session.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)) + await session.execute(delete(ChatMessage).filter_by(chat_id=id)) + await session.execute(delete(Chat).filter_by(id=id)) + await session.commit() - return True and await self.delete_shared_chat_by_chat_id(id, db=db) + return True and await self.delete_shared_chat_by_chat_id(id, db=session) except Exception: return False async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)) - await db.execute(delete(ChatMessage).filter_by(chat_id=id)) - await db.execute(delete(Chat).filter_by(id=id, user_id=user_id)) - await db.commit() + async with get_async_db_context(db) as session: + await session.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)) + await session.execute(delete(ChatMessage).filter_by(chat_id=id)) + await session.execute(delete(Chat).filter_by(id=id, user_id=user_id)) + await session.commit() - return True and await self.delete_shared_chat_by_chat_id(id, db=db) + return True and await self.delete_shared_chat_by_chat_id(id, db=session) except Exception: return False async def delete_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - await self.delete_shared_chats_by_user_id(user_id, db=db) + async with get_async_db_context(db) as session: + await self.delete_shared_chats_by_user_id(user_id, db=session) chat_id_subquery = select(Chat.id).filter_by(user_id=user_id).scalar_subquery() - await db.execute( + await session.execute( update(AutomationRun) .filter(AutomationRun.chat_id.in_(select(Chat.id).filter_by(user_id=user_id))) .values(chat_id=None) ) - await db.execute( + await session.execute( delete(ChatMessage).filter(ChatMessage.chat_id.in_(select(Chat.id).filter_by(user_id=user_id))) ) - await db.execute(delete(Chat).filter_by(user_id=user_id)) - await db.commit() + await session.execute(delete(Chat).filter_by(user_id=user_id)) + await session.commit() return True except Exception: @@ -1563,14 +1565,14 @@ class ChatTable: self, user_id: str, folder_id: str, db: AsyncSession | None = None ) -> bool: try: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: chat_ids_stmt = select(Chat.id).filter_by(user_id=user_id, folder_id=folder_id) - await db.execute( + await session.execute( update(AutomationRun).filter(AutomationRun.chat_id.in_(chat_ids_stmt)).values(chat_id=None) ) - await db.execute(delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt))) - await db.execute(delete(Chat).filter_by(user_id=user_id, folder_id=folder_id)) - await db.commit() + await session.execute(delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt))) + await session.execute(delete(Chat).filter_by(user_id=user_id, folder_id=folder_id)) + await session.commit() return True except Exception: @@ -1584,11 +1586,11 @@ class ChatTable: db: AsyncSession | None = None, ) -> bool: try: - async with get_async_db_context(db) as db: - await db.execute( + async with get_async_db_context(db) as session: + await session.execute( update(Chat).filter_by(user_id=user_id, folder_id=folder_id).values(folder_id=new_folder_id) ) - await db.commit() + await session.commit() return True except Exception: @@ -1600,13 +1602,13 @@ class ChatTable: from open_webui.models.shared_chats import SharedChats try: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: # Delete shared_chat rows for this user's chats - await db.execute(delete(SharedChatTable).filter_by(user_id=user_id)) + await session.execute(delete(SharedChatTable).filter_by(user_id=user_id)) # Clear share_id on all of this user's chats - await db.execute(update(Chat).filter_by(user_id=user_id).values(share_id=None)) - await db.commit() + await session.execute(update(Chat).filter_by(user_id=user_id).values(share_id=None)) + await session.commit() return True except Exception: @@ -1624,7 +1626,7 @@ class ChatTable: return None chat_message_file_ids = { - item.id for item in await self.get_chat_files_by_chat_id_and_message_id(chat_id, message_id, db=db) + item.id for item in await self.get_chat_files_by_chat_id_and_message_id(chat_id, message_id, db=session) } # Remove duplicates and existing file_ids file_ids = list({file_id for file_id in file_ids if file_id and file_id not in chat_message_file_ids}) @@ -1632,7 +1634,7 @@ class ChatTable: return None try: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: now = int(time.time()) chat_files = [ @@ -1651,7 +1653,7 @@ class ChatTable: results = [ChatFile(**chat_file.model_dump()) for chat_file in chat_files] db.add_all(results) - await db.commit() + await session.commit() return chat_files except Exception: @@ -1660,8 +1662,8 @@ class ChatTable: async def get_chat_files_by_chat_id_and_message_id( self, chat_id: str, message_id: str, db: AsyncSession | None = None ) -> list[ChatFileModel]: - async with get_async_db_context(db) as db: - result = await db.execute( + async with get_async_db_context(db) as session: + result = await session.execute( select(ChatFile).filter_by(chat_id=chat_id, message_id=message_id).order_by(ChatFile.created_at.asc()) ) all_chat_files = result.scalars().all() @@ -1669,17 +1671,17 @@ class ChatTable: async def delete_chat_file(self, chat_id: str, file_id: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - await db.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id)) - await db.commit() + async with get_async_db_context(db) as session: + await session.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id)) + await session.commit() return True except Exception: return False async def get_shared_chat_ids_by_file_id(self, file_id: str, db: AsyncSession | None = None) -> list[str]: """Return IDs of chats that contain this file and have an active share link.""" - async with get_async_db_context(db) as db: - result = await db.execute( + async with get_async_db_context(db) as session: + result = await session.execute( select(Chat.id) .join(ChatFile, Chat.id == ChatFile.chat_id) .filter(ChatFile.file_id == file_id, Chat.share_id.isnot(None)) @@ -1690,12 +1692,12 @@ class ChatTable: """Update the tasks list on a chat.""" try: async with get_async_db_context() as db: - chat = await db.get(Chat, id) + chat = await session.get(Chat, id) if chat is None: return None chat.tasks = tasks - await db.commit() - await db.refresh(chat) + await session.commit() + await session.refresh(chat) return ChatModel.model_validate(chat) except Exception: return None @@ -1703,7 +1705,7 @@ class ChatTable: async def get_chat_tasks_by_id(self, id: str) -> list[dict]: """Read the tasks list from a chat (lightweight column query).""" async with get_async_db_context() as db: - result = await db.execute(select(Chat.tasks).filter_by(id=id)) + result = await session.execute(select(Chat.tasks).filter_by(id=id)) row = result.first() if row is None or row[0] is None: return [] diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index 17fe85b294..9f2451dda7 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -1,8 +1,9 @@ +"""File upload models, forms, and database operations.""" + from __future__ import annotations import logging import time -from typing import Optional from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.utils.misc import sanitize_metadata @@ -12,12 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -#################### -# Files DB Schema -# What is written here bears witness. Let the testimony -# remain as it was given, and let none tamper with it. -#################### - class File(Base): __tablename__ = 'file' @@ -25,7 +20,7 @@ class File(Base): user_id = Column(String) hash = Column(Text, nullable=True) - filename = Column(Text) + filename = Column(Text) # original upload filename path = Column(Text, nullable=True) data = Column(JSON, nullable=True) @@ -52,11 +47,6 @@ class FileModel(BaseModel): updated_at: int | None # timestamp in epoch -#################### -# Forms -#################### - - class FileMeta(BaseModel): name: str | None = None content_type: str | None = None @@ -157,16 +147,18 @@ class FilesTable: return None except Exception as e: log.exception(f'Error inserting a new file: {e}') - return None + return None # insertion failed - async def get_file_by_id(self, id: str, db: AsyncSession | None = None) -> FileModel | None: + async def get_file_by_id( + self, id: str, db: AsyncSession | None = None, + ) -> FileModel | None: + """Look up a file by its primary key.""" try: async with get_async_db_context(db) as db: - try: - file = await db.get(File, id) - return FileModel.model_validate(file) if file else None - except Exception: + file = await db.get(File, id) + if not file: return None + return FileModel.model_validate(file) except Exception: return None diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index 3a47d402cf..84afa1dda6 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -1,8 +1,9 @@ +"""Function (filter/action/pipe) models, forms, and database operations.""" + from __future__ import annotations import logging import time -from typing import Optional from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.users import UserModel, UserResponse, Users @@ -12,12 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -#################### -# Functions DB Schema -# Each function here is a promise made. Let no promise -# go unkept, and let none be called who cannot answer. -#################### - class Function(Base): __tablename__ = 'function' @@ -30,11 +25,11 @@ class Function(Base): meta = Column(JSONField) valves = Column(JSONField) is_active = Column(Boolean) - is_global = Column(Boolean) - updated_at = Column(BigInteger) - created_at = Column(BigInteger) + is_global = Column(Boolean) # if True, applied to every chat automatically + updated_at = Column(BigInteger) # epoch seconds + created_at = Column(BigInteger) # epoch seconds - __table_args__ = (Index('is_global_idx', 'is_global'),) + __table_args__ = (Index('is_global_idx', 'is_global'),) # speed up global-function lookups class FunctionMeta(BaseModel): diff --git a/backend/open_webui/models/memories.py b/backend/open_webui/models/memories.py index d7f51e2c9e..30cd6267b2 100644 --- a/backend/open_webui/models/memories.py +++ b/backend/open_webui/models/memories.py @@ -1,3 +1,5 @@ +"""Long-term memory storage for per-user context recall.""" + from __future__ import annotations import time @@ -9,38 +11,30 @@ from pydantic import BaseModel, ConfigDict from sqlalchemy import BigInteger, Column, String, Text, delete, select from sqlalchemy.ext.asyncio import AsyncSession -#################### -# Memory DB Schema -# What was learned at cost should not need to be paid -# for again. Let the memory hold. -#################### - class Memory(Base): + """Persistent user memory backed by a vector collection.""" + __tablename__ = 'memory' id = Column(String, primary_key=True, unique=True) user_id = Column(String, index=True) - content = Column(Text) - updated_at = Column(BigInteger) - created_at = Column(BigInteger) + content = Column(Text) # free-form text learned from conversation + updated_at = Column(BigInteger) # epoch seconds + created_at = Column(BigInteger) # epoch seconds class MemoryModel(BaseModel): + """Pydantic mirror of the Memory table row.""" + id: str user_id: str content: str updated_at: int # timestamp in epoch created_at: int # timestamp in epoch - model_config = ConfigDict(from_attributes=True) -#################### -# Forms -#################### - - class MemoriesTable: async def insert_new_memory( self, @@ -48,26 +42,20 @@ class MemoriesTable: content: str, db: AsyncSession | None = None, ) -> MemoryModel | None: + """Persist a new memory entry and return the created model.""" async with get_async_db_context(db) as db: - id = str(uuid.uuid4()) - - memory = MemoryModel( - **{ - 'id': id, - 'user_id': user_id, - 'content': content, - 'created_at': int(time.time()), - 'updated_at': int(time.time()), - } + now = int(time.time()) + record = Memory( + id=str(uuid.uuid4()), + user_id=user_id, + content=content, + created_at=now, + updated_at=now, ) - result = Memory(**memory.model_dump()) - db.add(result) + db.add(record) await db.commit() - await db.refresh(result) - if result: - return MemoryModel.model_validate(result) - else: - return None + await db.refresh(record) + return MemoryModel.model_validate(record) if record else None async def update_memory_by_id_and_user_id( self, @@ -143,12 +131,10 @@ class MemoriesTable: try: memory = await db.get(Memory, id) if not memory or memory.user_id != user_id: - return None + return False - # Delete the memory await db.delete(memory) await db.commit() - return True except Exception: return False diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 443cd4530d..0e34f81378 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -60,38 +60,19 @@ class ModelMeta(BaseModel): class Model(Base): + """Workspace model entry — wraps an upstream LLM with custom params and metadata.""" + __tablename__ = 'model' - id = Column(Text, primary_key=True, unique=True) - """ - The model's id as used in the API. If set to an existing model, it will override the model. - """ - user_id = Column(Text) - - base_model_id = Column(Text, nullable=True) - """ - An optional pointer to the actual model that should be used when proxying requests. - """ - - name = Column(Text) - """ - The human-readable display name of the model. - """ - - params = Column(JSONField) - """ - Holds a JSON encoded blob of parameters, see `ModelParams`. - """ - - meta = Column(JSONField) - """ - Holds a JSON encoded blob of metadata, see `ModelMeta`. - """ - - is_active = Column(Boolean, default=True) - - updated_at = Column(BigInteger) - created_at = Column(BigInteger) + id = Column(Text, primary_key=True, unique=True) # API model identifier; overrides built-in when matching + user_id = Column(Text) # owner + base_model_id = Column(Text, nullable=True) # actual upstream model for proxied requests + name = Column(Text) # human-readable display name + params = Column(JSONField) # see ModelParams + meta = Column(JSONField) # see ModelMeta + is_active = Column(Boolean, default=True) # soft-disable toggle + updated_at = Column(BigInteger) # epoch seconds + created_at = Column(BigInteger) # epoch seconds class ModelModel(BaseModel): @@ -109,12 +90,9 @@ class ModelModel(BaseModel): updated_at: int # timestamp in epoch created_at: int # timestamp in epoch - model_config = ConfigDict(from_attributes=True) - - -#################### -# Forms -#################### + model_config = ConfigDict( + from_attributes=True, + ) class ModelUserResponse(ModelModel): diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 732c868227..ef305f7353 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -1,3 +1,5 @@ +"""Prompt template models, forms, and database operations.""" + from __future__ import annotations import json @@ -14,23 +16,19 @@ from pydantic import BaseModel, ConfigDict, Field from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, text, update from sqlalchemy.ext.asyncio import AsyncSession -#################### -# Prompts DB Schema -# Every word here was weighed before it was set down. -# Let the weight not be wasted when it is spoken aloud. -#################### - class Prompt(Base): + """System prompt template with versioning support.""" + __tablename__ = 'prompt' id = Column(Text, primary_key=True) command = Column(String, unique=True, index=True) user_id = Column(String) name = Column(Text) - content = Column(Text) - data = Column(JSON, nullable=True) - meta = Column(JSON, nullable=True) + content = Column(Text) # the prompt template body + data = Column(JSON, nullable=True) # structured prompt parameters + meta = Column(JSON, nullable=True) # freeform metadata (description, etc.) tags = Column(JSON, nullable=True) is_active = Column(Boolean, default=True) version_id = Column(Text, nullable=True) # Points to active history entry @@ -94,7 +92,7 @@ class PromptForm(BaseModel): class PromptsTable: async def _get_access_grants(self, prompt_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]: - return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db) + return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=session) async def _to_prompt_model( self, @@ -104,98 +102,94 @@ class PromptsTable: ) -> PromptModel: prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'}) prompt_data['access_grants'] = ( - access_grants if access_grants is not None else await self._get_access_grants(prompt_data['id'], db=db) + access_grants if access_grants is not None else await self._get_access_grants(prompt_data['id'], db=session) ) return PromptModel.model_validate(prompt_data) - async def insert_new_prompt( - self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None - ) -> PromptModel | None: + async def insert_new_prompt(self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None) -> PromptModel | None: now = int(time.time()) prompt_id = str(uuid.uuid4()) - prompt = PromptModel( - id=prompt_id, - user_id=user_id, - command=form_data.command, - name=form_data.name, - content=form_data.content, - data=form_data.data or {}, - meta=form_data.meta or {}, - tags=form_data.tags or [], - access_grants=[], - is_active=True, - created_at=now, - updated_at=now, - ) + async with get_async_db_context(db) as session: + try: + record = Prompt( + id=prompt_id, user_id=user_id, + command=form_data.command, name=form_data.name, + content=form_data.content, + data=form_data.data or {}, meta=form_data.meta or {}, + tags=form_data.tags or [], is_active=True, + created_at=now, updated_at=now, + ) + session.add(record) + await session.commit() + await session.refresh(record) # populate generated defaults - try: - async with get_async_db_context(db) as db: - result = Prompt(**prompt.model_dump(exclude={'access_grants'})) - db.add(result) - await db.commit() - await db.refresh(result) - await AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db) + await AccessGrants.set_access_grants( + 'prompt', prompt_id, form_data.access_grants, db=session, + ) # persist sharing rules - if result: - current_access_grants = await self._get_access_grants(prompt_id, db=db) - snapshot = { - 'name': form_data.name, - 'content': form_data.content, - 'command': form_data.command, - 'data': form_data.data or {}, - 'meta': form_data.meta or {}, - 'tags': form_data.tags or [], - 'access_grants': [grant.model_dump() for grant in current_access_grants], - } - - history_entry = await PromptHistories.create_history_entry( - prompt_id=prompt_id, - snapshot=snapshot, - user_id=user_id, - parent_id=None, # Initial commit has no parent - commit_message=form_data.commit_message or 'Initial version', - db=db, - ) - - # Set the initial version as the production version - if history_entry: - result.version_id = history_entry.id - await db.commit() - await db.refresh(result) - - return await self._to_prompt_model(result, db=db) - else: + if not record: # shouldn't happen, but guard anyway return None - except Exception: - return None + + # Build the initial version snapshot. + grants = await self._get_access_grants(prompt_id, db=session) + snapshot = { + 'name': form_data.name, + 'content': form_data.content, + 'command': form_data.command, + 'data': form_data.data or {}, + 'meta': form_data.meta or {}, + 'tags': form_data.tags or [], + 'access_grants': [g.model_dump() for g in grants], + } + + history_entry = await PromptHistories.create_history_entry( + prompt_id=prompt_id, snapshot=snapshot, + user_id=user_id, parent_id=None, + commit_message=form_data.commit_message or 'Initial version', + db=session, + ) # creates the first version entry + + # Pin the initial history entry as the production version. + if history_entry: + record.version_id = history_entry.id + await session.commit() + await session.refresh(record) # re-read version_id + + return await self._to_prompt_model(record, db=session) + except Exception as e: + log.exception('Error creating prompt: %s', e) + return None async def get_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None: - """Get prompt by UUID.""" try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(id=prompt_id)) - prompt = result.scalars().first() - if prompt: - return await self._to_prompt_model(prompt, db=db) - return None - except Exception: - return None + async with get_async_db_context(db) as session: + result = await session.execute( + select(Prompt).filter_by(id=prompt_id), + ) + prompt = result.scalars().first() # None when not found + if not prompt: + return None + return await self._to_prompt_model(prompt, db=session) + except Exception: # connection / integrity error + return async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(command=command)) - prompt = result.scalars().first() - if prompt: - return await self._to_prompt_model(prompt, db=db) - return None - except Exception: - return None + async with get_async_db_context(db) as session: + result = await session.execute( + select(Prompt).filter_by(command=command), + ) + prompt = result.scalars().first() # None when no match + if not prompt: + return None + return await self._to_prompt_model(prompt, db=session) + except Exception: # connection / integrity error + return async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]: - async with get_async_db_context(db) as db: - result = await db.execute( + async with get_async_db_context(db) as session: + result = await session.execute( select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()) ) all_prompts = result.scalars().all() @@ -203,9 +197,9 @@ class PromptsTable: user_ids = list(set(prompt.user_id for prompt in all_prompts)) prompt_ids = [prompt.id for prompt in all_prompts] - users = await Users.get_users_by_user_ids(user_ids, db=db) if user_ids else [] + users = await Users.get_users_by_user_ids(user_ids, db=session) if user_ids else [] users_dict = {user.id: user for user in users} - grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db) + grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session) prompts = [] for prompt in all_prompts: @@ -217,7 +211,7 @@ class PromptsTable: await self._to_prompt_model( prompt, access_grants=grants_map.get(prompt.id, []), - db=db, + db=session, ) ).model_dump(), 'user': user.model_dump() if user else None, @@ -230,8 +224,8 @@ class PromptsTable: async def get_prompts_by_user_id( self, user_id: str, permission: str = 'write', db: AsyncSession | None = None ) -> list[PromptUserResponse]: - async with get_async_db_context(db) as db: - user_groups = await Groups.get_groups_by_member_id(user_id, db=db) + async with get_async_db_context(db) as session: + user_groups = await Groups.get_groups_by_member_id(user_id, db=session) user_group_ids = [group.id for group in user_groups] query = select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc()) @@ -244,7 +238,7 @@ class PromptsTable: permission=permission, ) - result = await db.execute(query) + result = await session.execute(query) accessible_prompts = result.scalars().all() if not accessible_prompts: @@ -253,9 +247,9 @@ class PromptsTable: prompt_ids = [p.id for p in accessible_prompts] owner_ids = list({p.user_id for p in accessible_prompts}) - users = await Users.get_users_by_user_ids(owner_ids, db=db) + users = await Users.get_users_by_user_ids(owner_ids, db=session) users_dict = {u.id: u for u in users} - grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db) + grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session) results = [] for prompt in accessible_prompts: @@ -284,7 +278,7 @@ class PromptsTable: limit: int = 30, db: AsyncSession | None = None, ) -> PromptListResponse: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: # Join with User table for user filtering and sorting query = select(Prompt, User).outerjoin(User, User.id == Prompt.user_id) @@ -319,7 +313,7 @@ class PromptsTable: tag = filter.get('tag') if tag: - bind = await db.connection() + bind = await session.connection() dialect_name = bind.dialect.name tag_lower = tag.lower() @@ -367,7 +361,7 @@ class PromptsTable: query = query.order_by(Prompt.updated_at.desc()) # Count BEFORE pagination - count_result = await db.execute(select(func.count()).select_from(query.subquery())) + count_result = await session.execute(select(func.count()).select_from(query.subquery())) total = count_result.scalar() if skip: @@ -375,11 +369,11 @@ class PromptsTable: if limit: query = query.limit(limit) - result = await db.execute(query) + result = await session.execute(query) items = result.all() prompt_ids = [prompt.id for prompt, _ in items] - grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db) + grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session) prompts = [] for prompt, user in items: @@ -406,15 +400,15 @@ class PromptsTable: db: AsyncSession | None = None, ) -> PromptModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(command=command)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Prompt).filter_by(command=command)) prompt = result.scalars().first() if not prompt: return None - latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=db) + latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=session) parent_id = latest_history.id if latest_history else None - current_access_grants = await self._get_access_grants(prompt.id, db=db) + current_access_grants = await self._get_access_grants(prompt.id, db=session) # Check if content changed to decide on history creation content_changed = ( @@ -430,10 +424,10 @@ class PromptsTable: prompt.meta = form_data.meta or prompt.meta prompt.updated_at = int(time.time()) if form_data.access_grants is not None: - await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db) - current_access_grants = await self._get_access_grants(prompt.id, db=db) + await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=session) + current_access_grants = await self._get_access_grants(prompt.id, db=session) - await db.commit() + await session.commit() # Create history entry only if content changed if content_changed: @@ -458,9 +452,9 @@ class PromptsTable: # Set as production if flag is True (default) if form_data.is_production and history_entry: prompt.version_id = history_entry.id - await db.commit() + await session.commit() - return await self._to_prompt_model(prompt, db=db) + return await self._to_prompt_model(prompt, db=session) except Exception: return None @@ -472,15 +466,15 @@ class PromptsTable: db: AsyncSession | None = None, ) -> PromptModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(id=prompt_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Prompt).filter_by(id=prompt_id)) prompt = result.scalars().first() if not prompt: return None - latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=db) + latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=session) parent_id = latest_history.id if latest_history else None - current_access_grants = await self._get_access_grants(prompt.id, db=db) + current_access_grants = await self._get_access_grants(prompt.id, db=session) # Check if content changed to decide on history creation content_changed = ( @@ -502,12 +496,12 @@ class PromptsTable: prompt.tags = form_data.tags if form_data.access_grants is not None: - await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db) - current_access_grants = await self._get_access_grants(prompt.id, db=db) + await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=session) + current_access_grants = await self._get_access_grants(prompt.id, db=session) prompt.updated_at = int(time.time()) - await db.commit() + await session.commit() # Create history entry only if content changed if content_changed: @@ -533,9 +527,9 @@ class PromptsTable: # Set as production if flag is True (default) if form_data.is_production and history_entry: prompt.version_id = history_entry.id - await db.commit() + await session.commit() - return await self._to_prompt_model(prompt, db=db) + return await self._to_prompt_model(prompt, db=session) except Exception: return None @@ -549,8 +543,8 @@ class PromptsTable: ) -> PromptModel | None: """Update only name, command, and tags (no history created).""" try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(id=prompt_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Prompt).filter_by(id=prompt_id)) prompt = result.scalars().first() if not prompt: return None @@ -562,9 +556,9 @@ class PromptsTable: prompt.tags = tags prompt.updated_at = int(time.time()) - await db.commit() + await session.commit() - return await self._to_prompt_model(prompt, db=db) + return await self._to_prompt_model(prompt, db=session) except Exception: return None @@ -576,13 +570,13 @@ class PromptsTable: ) -> PromptModel | None: """Set the active version of a prompt and restore content from that version's snapshot.""" try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(id=prompt_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Prompt).filter_by(id=prompt_id)) prompt = result.scalars().first() if not prompt: return None - history_entry = await PromptHistories.get_history_entry_by_id(version_id, db=db) + history_entry = await PromptHistories.get_history_entry_by_id(version_id, db=session) if not history_entry: return None @@ -599,24 +593,28 @@ class PromptsTable: prompt.version_id = version_id prompt.updated_at = int(time.time()) - await db.commit() + await session.commit() - return await self._to_prompt_model(prompt, db=db) - except Exception: + return await self._to_prompt_model(prompt, db=session) + except Exception: # connection error return None - async def toggle_prompt_active(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None: + # --- Active state management --- + + async def toggle_prompt_active( + self, prompt_id: str, db: AsyncSession | None = None, + ) -> PromptModel | None: """Toggle the is_active flag on a prompt.""" try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(id=prompt_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Prompt).filter_by(id=prompt_id)) prompt = result.scalars().first() if prompt: prompt.is_active = not prompt.is_active prompt.updated_at = int(time.time()) - await db.commit() - await db.refresh(prompt) - return await self._to_prompt_model(prompt, db=db) + await session.commit() + await session.refresh(prompt) + return await self._to_prompt_model(prompt, db=session) return None except Exception: return None @@ -624,15 +622,15 @@ class PromptsTable: async def delete_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> bool: """Permanently delete a prompt and its history.""" try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(command=command)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Prompt).filter_by(command=command)) prompt = result.scalars().first() if prompt: - await PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) - await AccessGrants.revoke_all_access('prompt', prompt.id, db=db) + await PromptHistories.delete_history_by_prompt_id(prompt.id, db=session) + await AccessGrants.revoke_all_access('prompt', prompt.id, db=session) - await db.delete(prompt) - await db.commit() + await session.delete(prompt) + await session.commit() return True return False except Exception: @@ -641,15 +639,15 @@ class PromptsTable: async def delete_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> bool: """Permanently delete a prompt and its history.""" try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt).filter_by(id=prompt_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Prompt).filter_by(id=prompt_id)) prompt = result.scalars().first() if prompt: - await PromptHistories.delete_history_by_prompt_id(prompt.id, db=db) - await AccessGrants.revoke_all_access('prompt', prompt.id, db=db) + await PromptHistories.delete_history_by_prompt_id(prompt.id, db=session) + await AccessGrants.revoke_all_access('prompt', prompt.id, db=session) - await db.delete(prompt) - await db.commit() + await session.delete(prompt) + await session.commit() return True return False except Exception: @@ -657,8 +655,8 @@ class PromptsTable: async def get_tags(self, db: AsyncSession | None = None) -> list[str]: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(Prompt.tags).filter(Prompt.is_active == True)) + async with get_async_db_context(db) as session: + result = await session.execute(select(Prompt.tags).filter(Prompt.is_active == True)) tags = set() for (tag_list,) in result.all(): if tag_list: @@ -671,8 +669,8 @@ class PromptsTable: async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[str]: try: - async with get_async_db_context(db) as db: - user_groups = await Groups.get_groups_by_member_id(user_id, db=db) + async with get_async_db_context(db) as session: + user_groups = await Groups.get_groups_by_member_id(user_id, db=session) user_group_ids = [group.id for group in user_groups] query = select(Prompt.tags).filter(Prompt.is_active == True) @@ -685,7 +683,7 @@ class PromptsTable: permission='read', ) - result = await db.execute(query) + result = await session.execute(query) tags = set() for (tag_list,) in result.all(): if tag_list: diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index defd424d59..b4a2cff02c 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -1,9 +1,10 @@ +"""Tag models and database operations.""" + from __future__ import annotations import logging import time import uuid -from typing import Optional from open_webui.internal.db import Base, JSONField, get_async_db_context from pydantic import BaseModel, ConfigDict @@ -13,11 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -#################### -# Tag DB Schema -# To name a thing is to claim it. The creator has -# already named everything stored in this table. -#################### class Tag(Base): __tablename__ = 'tag' id = Column(String) @@ -53,22 +49,21 @@ class TagChatIdForm(BaseModel): class TagTable: - async def insert_new_tag(self, name: str, user_id: str, db: AsyncSession | None = None) -> TagModel | None: + async def insert_new_tag( + self, name: str, user_id: str, db: AsyncSession | None = None, + ) -> TagModel | None: + """Create a new tag, deriving the id from the name.""" async with get_async_db_context(db) as db: - id = name.replace(' ', '_').lower() - tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name}) + tag_id = name.replace(' ', '_').lower() try: - result = Tag(**tag.model_dump()) - db.add(result) + record = Tag(id=tag_id, user_id=user_id, name=name) + db.add(record) await db.commit() - await db.refresh(result) - if result: - return TagModel.model_validate(result) - else: - return None + await db.refresh(record) + return TagModel.model_validate(record) if record else None except Exception as e: - log.exception(f'Error inserting a new tag: {e}') - return None + log.exception('Error inserting tag %r: %s', name, e) + return None # insertion failed async def get_tag_by_name_and_user_id( self, name: str, user_id: str, db: AsyncSession | None = None diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index c73a2d7bb7..69f3a8d57d 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -1,8 +1,9 @@ +"""Tool models, forms, and database operations.""" + from __future__ import annotations import logging import time -from typing import Optional from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.models.access_grants import AccessGrantModel, AccessGrants @@ -14,23 +15,17 @@ from sqlalchemy.ext.asyncio import AsyncSession log = logging.getLogger(__name__) -#################### -# Tools DB Schema -# A tool that fails silently is worse than one that -# refuses outright. Let each one here be honest in its work. -#################### - class Tool(Base): __tablename__ = 'tool' id = Column(String, primary_key=True, unique=True) user_id = Column(String) - name = Column(Text) - content = Column(Text) - specs = Column(JSONField) - meta = Column(JSONField) - valves = Column(JSONField) + name = Column(Text) # human-readable label + content = Column(Text) # Python source code + specs = Column(JSONField) # OpenAPI-style function specs + meta = Column(JSONField) # description, manifest, etc. + valves = Column(JSONField) # admin-configurable runtime parameters updated_at = Column(BigInteger) created_at = Column(BigInteger) @@ -141,13 +136,18 @@ class ToolsTable: return None except Exception as e: log.exception(f'Error creating a new tool: {e}') - return None + return None # creation failed - async def get_tool_by_id(self, id: str, db: AsyncSession | None = None) -> ToolModel | None: - try: - async with get_async_db_context(db) as db: - tool = await db.get(Tool, id) - return await self._to_tool_model(tool, db=db) if tool else None + async def get_tool_by_id( + self, id: str, db: AsyncSession | None = None, + ) -> ToolModel | None: + """Fetch a single tool by primary key, including access grants.""" + try: # single PK lookup + access grants + async with get_async_db_context(db) as session: + tool = await session.get(Tool, id) + if not tool: + return None + return await self._to_tool_model(tool, db=session) except Exception: return None diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index 56c3c21e63..e4aa6357d1 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -1,9 +1,10 @@ +"""User models, Pydantic schemas, and database access layer.""" + from __future__ import annotations import datetime import time from typing import Optional - from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL from open_webui.internal.db import Base, JSONField, get_async_db_context from open_webui.utils.misc import throttle @@ -43,39 +44,46 @@ class UserSettings(BaseModel): class User(Base): + """Core user identity and profile record.""" + __tablename__ = 'user' + # Identity id = Column(String, primary_key=True, unique=True) email = Column(String) username = Column(String(50), nullable=True) role = Column(String) - name = Column(String) - profile_image_url = Column(Text) + # Profile + profile_image_url = Column(Text) # data-uri, path, or external URL profile_banner_image_url = Column(Text, nullable=True) - bio = Column(Text, nullable=True) gender = Column(Text, nullable=True) date_of_birth = Column(Date, nullable=True) timezone = Column(String, nullable=True) + # Online status presence_state = Column(String, nullable=True) status_emoji = Column(String, nullable=True) status_message = Column(Text, nullable=True) status_expires_at = Column(BigInteger, nullable=True) + # Metadata info = Column(JSON, nullable=True) settings = Column(JSON, nullable=True) - oauth = Column(JSON, nullable=True) scim = Column(JSON, nullable=True) + # Timestamps (epoch seconds) last_active_at = Column(BigInteger) updated_at = Column(BigInteger) created_at = Column(BigInteger) +_DEFAULT_PROFILE_IMAGE_URL = '/api/v1/users/{user_id}/profile/image' + + class UserModel(BaseModel): id: str @@ -108,13 +116,19 @@ class UserModel(BaseModel): updated_at: int # timestamp in epoch created_at: int # timestamp in epoch - model_config = ConfigDict(from_attributes=True) + model_config = ConfigDict( + from_attributes=True, + ) - @model_validator(mode='after') - def set_profile_image_url(self): - if not self.profile_image_url: - self.profile_image_url = f'/api/v1/users/{self.id}/profile/image' - return self + @model_validator(mode='after') # runs after all field validators + def _ensure_profile_image(self) -> 'UserModel': + """Fall back to a generated avatar when no explicit profile image is set.""" + if self.profile_image_url: # explicit image — nothing to do + return self + self.profile_image_url = _DEFAULT_PROFILE_IMAGE_URL.format( + user_id=self.id, + ) + return self # modified in-place class UserStatusModel(UserModel): @@ -269,7 +283,7 @@ class UsersTable: oauth: dict | None = None, db: AsyncSession | None = None, ) -> UserModel | None: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: user = UserModel( **{ 'id': id, @@ -285,89 +299,108 @@ class UsersTable: } ) result = User(**user.model_dump()) - db.add(result) - await db.commit() - await db.refresh(result) + session.add(result) + await session.commit() + await session.refresh(result) if result: return user else: return None - async def get_user_by_id(self, id: str, db: AsyncSession | None = None) -> UserModel | None: - try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - return UserModel.model_validate(user) if user else None - except Exception: - return None - - async def get_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None: - try: - async with get_async_db_context(db) as db: - result = await db.execute( - select(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key) - ) - user = result.scalars().first() - return UserModel.model_validate(user) if user else None - except Exception: - return None - - async def get_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None: - try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter(func.lower(User.email) == email.lower())) - user = result.scalars().first() - return UserModel.model_validate(user) if user else None - except Exception: - return None - - async def get_user_by_oauth_sub(self, provider: str, sub: str, db: AsyncSession | None = None) -> UserModel | None: - try: - async with get_async_db_context(db) as db: - dialect_name = db.bind.dialect.name - - stmt = select(User) - if dialect_name == 'sqlite': - stmt = stmt.filter(User.oauth.contains({provider: {'sub': sub}})) - elif dialect_name == 'postgresql': - stmt = stmt.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub) - - result = await db.execute(stmt) - user = result.scalars().first() - return UserModel.model_validate(user) if user else None - except Exception as e: - # You may want to log the exception here - return None - - async def get_user_by_scim_external_id( - self, provider: str, external_id: str, db: AsyncSession | None = None + async def get_user_by_id( + self, id: str, db: AsyncSession | None = None, ) -> UserModel | None: + """Fetch a single user by primary key.""" + try: # db.get is a PK lookup — very cheap + async with get_async_db_context(db) as session: + user = await session.get(User, id) + if not user: + return None + return UserModel.model_validate(user) + except Exception: # stale session / connection drop + return None # caller treats None as "not found" + + async def get_user_by_api_key( + self, api_key: str, db: AsyncSession | None = None, + ) -> UserModel | None: + """Resolve a user from their API key via a JOIN on the api_key table.""" + try: # single JOIN instead of two round-trips + async with get_async_db_context(db) as session: + result = await session.execute( + select(User) + .join(ApiKey, User.id == ApiKey.user_id) + .filter(ApiKey.key == api_key), + ) # emits: SELECT user.* FROM user JOIN api_key … + user = result.scalars().first() # None when key is invalid + if not user: + return None + return UserModel.model_validate(user) + except Exception: # stale session / connection drop + return + + async def get_user_by_email( + self, email: str, db: AsyncSession | None = None, + ) -> UserModel | None: + """Case-insensitive email lookup using SQL ``lower()``.""" try: - async with get_async_db_context(db) as db: - dialect_name = db.bind.dialect.name - - stmt = select(User) - if dialect_name == 'sqlite': - stmt = stmt.filter(User.scim.contains({provider: {'external_id': external_id}})) - elif dialect_name == 'postgresql': - stmt = stmt.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id) - - result = await db.execute(stmt) - user = result.scalars().first() - return UserModel.model_validate(user) if user else None + async with get_async_db_context(db) as session: + stmt = select(User).filter( + func.lower(User.email) == email.lower(), + ) + row = (await session.execute(stmt)).scalars().first() + if not row: + return + return UserModel.model_validate(row) except Exception: - return None + return + + async def get_user_by_oauth_sub( + self, provider: str, sub: str, db: AsyncSession | None = None, + ) -> UserModel | None: + """Look up a user by OAuth provider + subject claim (dialect-aware JSON filter).""" + try: + async with get_async_db_context(db) as session: + dialect = session.bind.dialect.name + query = select(User) + if dialect == 'sqlite': + oauth_match = User.oauth.contains({provider: {'sub': sub}}) + query = query.filter(oauth_match) + elif dialect == 'postgresql': + oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub + query = query.filter(oauth_match) + row = (await session.execute(query)).scalars().first() + if not row: + return + return UserModel.model_validate(row) + except Exception: + return + + async def get_user_by_scim_external_id(self, provider: str, external_id: str, db: AsyncSession | None = None) -> UserModel | None: + """Look up a user by SCIM provider + external ID (dialect-aware JSON filter).""" + try: + async with get_async_db_context(db) as session: + dialect = session.bind.dialect.name + query = select(User) + if dialect == 'sqlite': + scim_match = User.scim.contains({provider: {'external_id': external_id}}) + query = query.filter(scim_match) + elif dialect == 'postgresql': + scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id + query = query.filter(scim_match) + row = (await session.execute(query)).scalars().first() + if not row: + return + return UserModel.model_validate(row) + except Exception: + return async def get_users( - self, - filter: dict | None = None, - skip: int | None = None, - limit: int | None = None, - db: AsyncSession | None = None, + self, filter: dict | None = None, skip: int | None = None, + limit: int | None = None, db: AsyncSession | None = None, ) -> dict: - async with get_async_db_context(db) as db: - # Import here to avoid circular imports + """Paginated user listing with optional filters for role, group, and channel.""" + async with get_async_db_context(db) as session: + # Deferred imports to avoid circular dependencies from open_webui.models.channels import ChannelMember from open_webui.models.groups import GroupMember @@ -487,7 +520,7 @@ class UsersTable: stmt = stmt.order_by(User.created_at.desc()) # Count BEFORE pagination - count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) + count_result = await session.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() # correct pagination logic @@ -496,7 +529,7 @@ class UsersTable: if limit is not None: stmt = stmt.limit(limit) - result = await db.execute(stmt) + result = await session.execute(stmt) users = result.scalars().all() return { 'users': [UserModel.model_validate(user) for user in users], @@ -504,44 +537,47 @@ class UsersTable: } async def get_users_by_group_id(self, group_id: str, db: AsyncSession | None = None) -> list[UserModel]: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: from open_webui.models.groups import GroupMember - result = await db.execute( + result = await session.execute( select(User).join(GroupMember, User.id == GroupMember.user_id).filter(GroupMember.group_id == group_id) ) users = result.scalars().all() return [UserModel.model_validate(user) for user in users] async def get_users_by_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[UserStatusModel]: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter(User.id.in_(user_ids))) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter(User.id.in_(user_ids))) users = result.scalars().all() return [UserModel.model_validate(user) for user in users] async def get_num_users(self, db: AsyncSession | None = None) -> int | None: - async with get_async_db_context(db) as db: - result = await db.execute(select(func.count()).select_from(User)) + async with get_async_db_context(db) as session: + result = await session.execute(select(func.count()).select_from(User)) return result.scalar() async def has_users(self, db: AsyncSession | None = None) -> bool: - async with get_async_db_context(db) as db: - result = await db.execute(select(exists(select(User)))) + async with get_async_db_context(db) as session: + result = await session.execute(select(exists(select(User)))) return result.scalar() async def get_first_user(self, db: AsyncSession | None = None) -> UserModel: + """Return the earliest-created user (bootstrap admin detection).""" try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).order_by(User.created_at).limit(1)) - user = result.scalars().first() - return UserModel.model_validate(user) if user else None + async with get_async_db_context(db) as session: + stmt = select(User).order_by(User.created_at).limit(1) + row = (await session.execute(stmt)).scalars().first() + if not row: + return + return UserModel.model_validate(row) except Exception: - return None + return async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter_by(id=id)) user = result.scalars().first() if user.settings is None: @@ -552,67 +588,65 @@ class UsersTable: return None async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: current_timestamp = int(datetime.datetime.now().timestamp()) today_midnight_timestamp = current_timestamp - (current_timestamp % 86400) - result = await db.execute( + result = await session.execute( select(func.count()).select_from(User).filter(User.last_active_at > today_midnight_timestamp) ) return result.scalar() async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) - user = result.scalars().first() + async with get_async_db_context(db) as session: + user = (await session.execute(select(User).filter_by(id=id))).scalars().first() if not user: - return None + return user.role = role - await db.commit() - await db.refresh(user) + await session.commit() + await session.refresh(user) return UserModel.model_validate(user) except Exception: - return None + return async def update_user_status_by_id( self, id: str, form_data: UserStatus, db: AsyncSession | None = None ) -> UserModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter_by(id=id)) user = result.scalars().first() if not user: return None for key, value in form_data.model_dump(exclude_none=True).items(): setattr(user, key, value) - await db.commit() - await db.refresh(user) + await session.commit() + await session.refresh(user) return UserModel.model_validate(user) except Exception: return None async def update_user_profile_image_url_by_id( - self, id: str, profile_image_url: str, db: AsyncSession | None = None + self, id: str, profile_image_url: str, db: AsyncSession | None = None, ) -> UserModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) - user = result.scalars().first() - if not user: - return None - user.profile_image_url = profile_image_url - await db.commit() - await db.refresh(user) - return UserModel.model_validate(user) + async with get_async_db_context(db) as session: + row = (await session.execute(select(User).filter_by(id=id))).scalars().first() + if row is None: + return + row.profile_image_url = profile_image_url + await session.commit() + await session.refresh(row) + return UserModel.model_validate(row) except Exception: - return None + return @throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL) async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None: try: - async with get_async_db_context(db) as db: - await db.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time()))) - await db.commit() + async with get_async_db_context(db) as session: + await session.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time()))) + await session.commit() except Exception: pass @@ -628,8 +662,8 @@ class UsersTable: } """ try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter_by(id=id)) user = result.scalars().first() if not user: return None @@ -641,8 +675,8 @@ class UsersTable: oauth[provider] = {'sub': sub} # Persist updated JSON - await db.execute(update(User).filter_by(id=id).values(oauth=oauth)) - await db.commit() + await session.execute(update(User).filter_by(id=id).values(oauth=oauth)) + await session.commit() return UserModel.model_validate(user) @@ -665,8 +699,8 @@ class UsersTable: } """ try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter_by(id=id)) user = result.scalars().first() if not user: return None @@ -674,8 +708,8 @@ class UsersTable: scim = user.scim or {} scim[provider] = {'external_id': external_id} - await db.execute(update(User).filter_by(id=id).values(scim=scim)) - await db.commit() + await session.execute(update(User).filter_by(id=id).values(scim=scim)) + await session.commit() return UserModel.model_validate(user) @@ -684,15 +718,15 @@ class UsersTable: async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter_by(id=id)) user = result.scalars().first() if not user: return None for key, value in updated.items(): setattr(user, key, value) - await db.commit() - await db.refresh(user) + await session.commit() + await session.refresh(user) return UserModel.model_validate(user) except Exception as e: print(e) @@ -702,8 +736,8 @@ class UsersTable: self, id: str, updated: dict, db: AsyncSession | None = None ) -> UserModel | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter_by(id=id)) user = result.scalars().first() if not user: return None @@ -715,10 +749,10 @@ class UsersTable: user_settings.update(updated) - await db.execute(update(User).filter_by(id=id).values(settings=user_settings)) - await db.commit() + await session.execute(update(User).filter_by(id=id).values(settings=user_settings)) + await session.commit() - result = await db.execute(select(User).filter_by(id=id)) + result = await session.execute(select(User).filter_by(id=id)) user = result.scalars().first() return UserModel.model_validate(user) except Exception: @@ -733,12 +767,12 @@ class UsersTable: await Groups.remove_user_from_all_groups(id) # Delete User Chats - result = await Chats.delete_chats_by_user_id(id, db=db) + result = await Chats.delete_chats_by_user_id(id, db=session) if result: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: # Delete User - await db.execute(delete(User).filter_by(id=id)) - await db.commit() + await session.execute(delete(User).filter_by(id=id)) + await session.commit() return True else: @@ -748,8 +782,8 @@ class UsersTable: async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None: try: - async with get_async_db_context(db) as db: - result = await db.execute(select(ApiKey).filter_by(user_id=id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(ApiKey).filter_by(user_id=id)) api_key = result.scalars().first() return api_key.key if api_key else None except Exception: @@ -757,9 +791,9 @@ class UsersTable: async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - await db.execute(delete(ApiKey).filter_by(user_id=id)) - await db.commit() + async with get_async_db_context(db) as session: + await session.execute(delete(ApiKey).filter_by(user_id=id)) + await session.commit() now = int(time.time()) new_api_key = ApiKey( @@ -769,8 +803,8 @@ class UsersTable: created_at=now, updated_at=now, ) - db.add(new_api_key) - await db.commit() + session.add(new_api_key) + await session.commit() return True @@ -779,22 +813,22 @@ class UsersTable: async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool: try: - async with get_async_db_context(db) as db: - await db.execute(delete(ApiKey).filter_by(user_id=id)) - await db.commit() + async with get_async_db_context(db) as session: + await session.execute(delete(ApiKey).filter_by(user_id=id)) + await session.commit() return True except Exception: return False async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter(User.id.in_(user_ids))) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter(User.id.in_(user_ids))) users = result.scalars().all() return [user.id for user in users] async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(role='admin').limit(1)) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter_by(role='admin').limit(1)) user = result.scalars().first() if user: return UserModel.model_validate(user) @@ -802,10 +836,10 @@ class UsersTable: return None async def get_active_user_count(self, db: AsyncSession | None = None) -> int: - async with get_async_db_context(db) as db: + async with get_async_db_context(db) as session: # Consider user active if last_active_at within the last 3 minutes three_minutes_ago = int(time.time()) - 180 - result = await db.execute( + result = await session.execute( select(func.count()).select_from(User).filter(User.last_active_at >= three_minutes_ago) ) return result.scalar() @@ -819,8 +853,8 @@ class UsersTable: return False async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool: - async with get_async_db_context(db) as db: - result = await db.execute(select(User).filter_by(id=user_id)) + async with get_async_db_context(db) as session: + result = await session.execute(select(User).filter_by(id=user_id)) user = result.scalars().first() if user and user.last_active_at: # Consider user active if last_active_at within the last 3 minutes diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index 4fdde37200..66e4f402c8 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -513,25 +513,20 @@ async def delete_all_user_chats( @router.get('/list/user/{user_id}', response_model=list[ChatTitleIdResponse]) async def get_user_chat_list_by_user_id( - user_id: str, - page: int | None = None, - query: str | None = None, - order_by: str | None = None, - direction: str | None = None, - user=Depends(get_admin_user), - db: AsyncSession = Depends(get_async_session), + user_id: str, page: int | None = None, query: str | None = None, + order_by: str | None = None, direction: str | None = None, + user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session), ): + """List chat summaries for a given user (admin-only endpoint).""" if not ENABLE_ADMIN_CHAT_ACCESS: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - if page is None: - page = 1 - + effective_page = page if page is not None else 1 limit = 60 - skip = (page - 1) * limit + skip = (effective_page - 1) * limit filter = {} if query: diff --git a/backend/open_webui/routers/memories.py b/backend/open_webui/routers/memories.py index 5108ab0e7c..e2086609ae 100644 --- a/backend/open_webui/routers/memories.py +++ b/backend/open_webui/routers/memories.py @@ -62,14 +62,14 @@ class MemoryUpdateModel(BaseModel): @router.post('/add', response_model=MemoryModel | None) async def add_memory( - request: Request, - form_data: AddMemoryForm, - user=Depends(get_verified_user), + request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user), ): - # NOTE: We intentionally do NOT use Depends(get_async_session) here. - # Database operations (insert_new_memory) manage their own short-lived sessions. - # This prevents holding a connection during EMBEDDING_FUNCTION() - # which makes external embedding API calls (1-5+ seconds). + """Persist a new memory and embed it into the user's vector collection. + + Does NOT use ``Depends(get_async_session)`` — database operations manage their + own short-lived sessions so a connection is not held during the external + embedding API call (``EMBEDDING_FUNCTION``), which can take 1-5+ seconds. + """ if not request.app.state.config.ENABLE_MEMORIES: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/backend/open_webui/routers/models.py b/backend/open_webui/routers/models.py index 922d016a2e..2da2e0c0b7 100644 --- a/backend/open_webui/routers/models.py +++ b/backend/open_webui/routers/models.py @@ -193,11 +193,10 @@ async def get_model_tags(user=Depends(get_verified_user), db: AsyncSession = Dep @router.post('/create', response_model=ModelModel | None) async def create_new_model( - request: Request, - form_data: ModelForm, - user=Depends(get_verified_user), - db: AsyncSession = Depends(get_async_session), + request: Request, form_data: ModelForm, + user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): + """Create a new workspace model entry.""" if user.role != 'admin' and not await has_permission( user.id, 'workspace.models', request.app.state.config.USER_PERMISSIONS, db=db ): @@ -579,16 +578,14 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Async @router.post('/model/update', response_model=ModelModel | None) async def update_model_by_id( - request: Request, - form_data: ModelForm, - user=Depends(get_verified_user), - db: AsyncSession = Depends(get_async_session), + request: Request, form_data: ModelForm, + user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): + """Update a workspace model's configuration.""" model = await Models.get_model_by_id(form_data.id, db=db) if not model: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND, ) if ( diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index 35ef52ecc0..6d428d250e 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -239,12 +239,10 @@ async def get_prompt_by_id( @router.post('/id/{prompt_id}/update', response_model=PromptModel | None) async def update_prompt_by_id( - request: Request, - prompt_id: str, - form_data: PromptForm, - user=Depends(get_verified_user), - db: AsyncSession = Depends(get_async_session), + request: Request, prompt_id: str, form_data: PromptForm, + user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): + """Update a prompt's content, creating a new history entry if changed.""" prompt = await Prompts.get_prompt_by_id(prompt_id, db=db) if not prompt: diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index 43719ae514..b922a9c0d6 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -324,11 +324,10 @@ async def export_tools( @router.post('/create', response_model=ToolResponse | None) async def create_new_tools( - request: Request, - form_data: ToolForm, - user=Depends(get_verified_user), - db: AsyncSession = Depends(get_async_session), + request: Request, form_data: ToolForm, + user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): + """Create a new tool from user-supplied Python source code.""" if user.role != 'admin' and not ( await has_permission(user.id, 'workspace.tools', request.app.state.config.USER_PERMISSIONS, db=db) or await has_permission( @@ -449,17 +448,14 @@ async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: AsyncSes @router.post('/id/{id}/update', response_model=ToolModel | None) async def update_tools_by_id( - request: Request, - id: str, - form_data: ToolForm, - user=Depends(get_verified_user), - db: AsyncSession = Depends(get_async_session), + request: Request, id: str, form_data: ToolForm, + user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): + """Update an existing tool's source code and metadata.""" tools = await Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail=ERROR_MESSAGES.NOT_FOUND, + status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND, ) # Is the user the original creator, in a group with write access, or an admin diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 3c215c8f16..e4065de8cc 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -384,21 +384,22 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user), db: Asy @router.post('/user/info/update', response_model=dict | None) async def update_user_info_by_session_user( - form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) + form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): - # Merges against the auth-time snapshot of user.info. The previous pre-merge - # refetch only narrowed (did not eliminate) the lost-update window on concurrent - # same-user writes; real safety needs row locking or a version column. - existing_info = user.info or {} - updated = await Users.update_user_by_id(user.id, {'info': {**existing_info, **form_data}}, db=db) - if updated: - return updated.info - else: + """Merge caller-supplied fields into the current user's info dict. + + Uses the auth-time snapshot of ``user.info`` as the merge base. This does + NOT eliminate lost-update races on concurrent same-user writes; real safety + would need row locking or an optimistic-concurrency version column. + """ + merged_info = {**(user.info or {}), **form_data} + updated = await Users.update_user_by_id(user.id, {'info': merged_info}, db=db) + if not updated: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.USER_NOT_FOUND, ) - + return updated.info ############################ # GetUserById diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py index 04047fd47b..73a4811ab7 100644 --- a/backend/open_webui/routers/utils.py +++ b/backend/open_webui/routers/utils.py @@ -94,18 +94,21 @@ async def download_chat_as_pdf(form_data: ChatTitleMessagesForm, user=Depends(ge @router.get('/db/download') async def download_db(user=Depends(get_admin_user)): + """Download the raw SQLite database file (admin-only, SQLite deployments only).""" if not ENABLE_ADMIN_EXPORT: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, ) - from open_webui.internal.db import engine - if engine.name != 'sqlite': + from open_webui.internal.db import engine # deferred import + + if engine.name != 'sqlite': # only SQLite DBs can be downloaded as a file raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DB_NOT_SQLITE, ) + return FileResponse( engine.url.database, media_type='application/octet-stream',