This commit is contained in:
Timothy Jaeryang Baek
2026-05-21 14:01:57 +04:00
parent cac4c6da2e
commit 260ead64da
19 changed files with 772 additions and 771 deletions
+8 -1
View File
@@ -118,8 +118,15 @@ reattach_ssl_mode_to_url = reattach_ssl_params_to_url
class JSONField(types.TypeDecorator):
"""Store arbitrary Python objects as JSON-encoded TEXT.
Used instead of native JSON columns for portability across SQLite and
PostgreSQL. Values are serialized with ``json.dumps`` on write and
deserialized with ``json.loads`` on read.
"""
impl = types.Text
cache_ok = True
cache_ok = True # safe for statement caching (no per-instance state)
def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any:
return json.dumps(value)
+8 -1
View File
@@ -2912,10 +2912,12 @@ async def readiness_check():
@app.get('/health/db')
async def healthcheck_with_db():
"""Verify database connectivity by issuing a lightweight ping."""
await async_db_ping()
return {'status': True}
# Serve build-time static assets (CSS, JS, images, favicon, etc.)
app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static')
@@ -2924,8 +2926,13 @@ async def serve_cache_file(
path: str,
user=Depends(get_verified_user),
):
"""Serve cached files (e.g. tool outputs) with path-traversal protection.
Only ``image/*``, ``audio/*``, and ``video/*`` MIME types are served inline;
everything else gets a ``Content-Disposition: attachment`` header to prevent
XSS from user-generated HTML stored in the cache directory.
"""
file_path = os.path.abspath(os.path.join(CACHE_DIR, path))
# prevent path traversal
if not file_path.startswith(os.path.abspath(CACHE_DIR)):
raise HTTPException(status_code=404, detail='File not found')
if not os.path.isfile(file_path):
+99 -82
View File
@@ -1,3 +1,5 @@
"""Authentication models and database operations."""
from __future__ import annotations
import logging
@@ -13,33 +15,30 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
####################
# DB MODEL
####################
class Auth(Base):
"""Credential record linking a user identity to an email + hashed password."""
__tablename__ = 'auth'
id = Column(String, primary_key=True, unique=True)
email = Column(String)
password = Column(Text)
active = Column(Boolean)
id = Column(String, primary_key=True, unique=True) # same as User.id
email = Column(String) # login email, kept in sync with User.email
password = Column(Text) # bcrypt / argon2 hash
active = Column(Boolean) # soft-disable flag
class AuthModel(BaseModel):
"""Pydantic mirror of the Auth table row."""
id: str
email: str
password: str
active: bool = True
####################
# Forms
####################
class Token(BaseModel):
"""JWT bearer token response."""
token: str
token_type: str
@@ -100,112 +99,130 @@ class AuthsTable:
oauth: dict | None = None,
db: AsyncSession | None = None,
) -> UserModel | None:
"""Create an Auth + User pair in a single transaction."""
async with get_async_db_context(db) as db:
log.info('insert_new_auth')
id = str(uuid.uuid4())
user_id = str(uuid.uuid4())
auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True})
result = Auth(**auth.model_dump())
db.add(result)
record = Auth(
id=user_id,
email=email,
password=password,
active=True,
)
db.add(record)
user = await Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db)
user = await Users.insert_new_user(
user_id, name, email, profile_image_url, role, oauth=oauth, db=db,
)
await db.commit()
await db.refresh(result)
await db.refresh(record)
if result and user:
return user
else:
return None
return user if record and user else None
async def authenticate_user(
self, email: str, verify_password: callable, db: AsyncSession | None = None
) -> UserModel | None:
"""Verify a user's email + password and return the user on success."""
log.info(f'authenticate_user: {email}')
user = await Users.get_user_by_email(email, db=db)
if not user:
return None
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Auth).filter_by(id=user.id, active=True))
auth = result.scalars().first()
if auth:
if verify_password(auth.password):
return user
else:
return None
else:
try: # load the auth row for password verification
async with get_async_db_context(db) as session:
auth = await session.get(Auth, user.id)
if not auth or not auth.active:
return None
if not verify_password(auth.password):
return None
return user
except Exception:
return None
async def authenticate_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None:
log.info(f'authenticate_user_by_api_key')
# if no api_key, return None
if not api_key:
return None
return
async def authenticate_user_by_api_key(
self, api_key: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Resolve an API key to its owning user, returning ``None`` on miss."""
log.info('authenticate_user_by_api_key')
if not api_key: # empty / None key — reject immediately
return
try:
user = await Users.get_user_by_api_key(api_key, db=db)
return user if user else None
return await Users.get_user_by_api_key(api_key, db=db)
except Exception:
return False
async def authenticate_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None:
log.info(f'authenticate_user_by_email: {email}')
async def authenticate_user_by_email(
self,
email: str,
db: AsyncSession | None = None,
) -> UserModel | None:
"""One-query authentication: JOIN Auth ↔ User, filter by email + active flag."""
log.info('authenticate_user_by_email: %s', email)
try:
async with get_async_db_context(db) as db:
# Single JOIN query instead of two separate queries
result = await db.execute(
select(Auth, User).join(User, Auth.id == User.id).filter(Auth.email == email, Auth.active == True)
async with get_async_db_context(db) as session:
stmt = (
select(Auth, User)
.join(User, Auth.id == User.id)
.filter(Auth.email == email, Auth.active == True)
)
row = result.first()
if row:
_, user = row
return UserModel.model_validate(user)
return None
row = (await session.execute(stmt)).first()
if not row:
return
_auth, matched_user = row
return UserModel.model_validate(matched_user)
except Exception:
return None
return
async def update_user_password_by_id(self, id: str, new_password: str, db: AsyncSession | None = None) -> bool:
async def update_user_password_by_id(
self,
id: str,
new_password: str,
db: AsyncSession | None = None,
) -> bool:
"""Hash-swap: replace the stored password hash for a given user."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(update(Auth).filter_by(id=id).values(password=new_password))
await db.commit()
return True if result.rowcount == 1 else False
async with get_async_db_context(db) as session:
stmt = update(Auth).filter_by(id=id).values(password=new_password)
result = await session.execute(stmt)
await session.commit()
return result.rowcount == 1
except Exception:
return False
async def update_email_by_id(self, id: str, email: str, db: AsyncSession | None = None) -> bool:
async def update_email_by_id(
self, id: str, email: str, db: AsyncSession | None = None,
) -> bool:
"""Update the auth email and propagate the change to the User table."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(update(Auth).filter_by(id=id).values(email=email))
await db.commit()
if result.rowcount == 1:
await Users.update_user_by_id(id, {'email': email}, db=db)
return True
return False
except Exception:
return False
async def delete_auth_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
# Delete User
result = await Users.delete_user_by_id(id, db=db)
if result:
await db.execute(delete(Auth).filter_by(id=id))
await db.commit()
return True
else:
async with get_async_db_context(db) as session:
stmt = update(Auth).filter_by(id=id).values(email=email)
result = await session.execute(stmt)
await session.commit()
if result.rowcount != 1:
return False
await Users.update_user_by_id(id, {'email': email}, db=session)
return True
except Exception:
return False
async def delete_auth_by_id(
self, id: str, db: AsyncSession | None = None,
) -> bool:
"""Delete a user and their auth record in a single transaction."""
try: # delete user first, then auth (FK order)
async with get_async_db_context(db) as session:
if not await Users.delete_user_by_id(id, db=session):
return False # user deletion failed — abort
await session.execute(delete(Auth).filter_by(id=id))
await session.commit()
return True
except Exception: # db / integrity error
return False # partial deletion is rolled back by context manager
Auths = AuthsTable()
+180 -178
View File
@@ -1,10 +1,11 @@
"""Chat models, forms, and database operations."""
from __future__ import annotations
import json
import logging
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.automations import AutomationRun
@@ -35,12 +36,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import exists
from sqlalchemy.sql.expression import bindparam
####################
# Chat DB Schema
# Let no word spoken in this house be lost, and when the
# record is read again, let it still serve the one who spoke.
####################
log = logging.getLogger(__name__)
@@ -49,14 +44,14 @@ class Chat(Base):
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
title = Column(Text)
title = Column(Text) # user-visible conversation title
chat = Column(JSON)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False)
share_id = Column(Text, unique=True, nullable=True) # public share link token
archived = Column(Boolean, default=False) # hidden from main chat list
pinned = Column(Boolean, default=False, nullable=True)
meta = Column(JSON, server_default='{}')
@@ -302,7 +297,7 @@ class ChatTable:
async def insert_new_chat(
self, id: str, user_id: str, form_data: ChatForm, db: AsyncSession | None = None
) -> ChatModel | None:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
chat = ChatModel(
**{
'id': id,
@@ -318,9 +313,9 @@ class ChatTable:
)
chat_item = Chat(**chat.model_dump())
db.add(chat_item)
await db.commit()
await db.refresh(chat_item)
session.add(chat_item)
await session.commit()
await session.refresh(chat_item)
# Dual-write initial messages to chat_message table
try:
@@ -362,7 +357,7 @@ class ChatTable:
chat_import_forms: list[ChatImportForm],
db: AsyncSession | None = None,
) -> list[ChatModel]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
chats = []
for form_data in chat_import_forms:
@@ -370,7 +365,7 @@ class ChatTable:
chats.append(Chat(**chat.model_dump()))
db.add_all(chats)
await db.commit()
await session.commit()
# Dual-write messages to chat_message table
for form_data, chat_obj in zip(chat_import_forms, chats):
@@ -390,10 +385,13 @@ class ChatTable:
return [ChatModel.model_validate(chat) for chat in chats]
async def update_chat_by_id(self, id: str, chat: dict, db: AsyncSession | None = None) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat_item = await db.get(Chat, id)
async def update_chat_by_id(
self, id: str, chat: dict, db: AsyncSession | None = None,
) -> ChatModel | None:
"""Persist updated chat content, sanitizing null bytes."""
try: # load the chat record for in-place mutation
async with get_async_db_context(db) as session:
chat_item = await session.get(Chat, id)
if chat_item is None:
return None
@@ -402,19 +400,19 @@ class ChatTable:
chat_item.updated_at = int(time.time())
await db.commit()
await session.commit()
return ChatModel.model_validate(chat_item)
except Exception:
return None
return
async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, id)
if chat and chat.user_id == user_id:
chat.last_read_at = int(time.time())
await db.commit()
await session.commit()
return True
return False
except Exception:
@@ -423,22 +421,22 @@ class ChatTable:
async def update_chat_title_by_id(self, id: str, title: str) -> ChatModel | None:
try:
async with get_async_db_context() as db:
chat_item = await db.get(Chat, id)
chat_item = await session.get(Chat, id)
if chat_item is None:
return None
clean_title = self._clean_null_bytes(title)
chat_item.title = clean_title
chat_item.chat = {**(chat_item.chat or {}), 'title': clean_title}
chat_item.updated_at = int(time.time())
await db.commit()
await db.refresh(chat_item)
await session.commit()
await session.refresh(chat_item)
return ChatModel.model_validate(chat_item)
except Exception:
return None
async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> ChatModel | None:
async with get_async_db_context() as db:
chat = await db.get(Chat, id)
chat = await session.get(Chat, id)
if chat is None:
return None
@@ -448,22 +446,22 @@ class ChatTable:
# Single meta update
chat.meta = {**chat.meta, 'tags': new_tag_ids}
await db.commit()
await db.refresh(chat)
await session.commit()
await session.refresh(chat)
# Batch-create any missing tag rows
await Tags.ensure_tags_exist(new_tags, user.id, db=db)
await Tags.ensure_tags_exist(new_tags, user.id, db=session)
# Clean up orphaned old tags in one query
removed = set(old_tags) - set(new_tag_ids)
if removed:
await self.delete_orphan_tags_for_user(list(removed), user.id, db=db)
await self.delete_orphan_tags_for_user(list(removed), user.id, db=session)
return ChatModel.model_validate(chat)
async def get_chat_title_by_id(self, id: str) -> str | None:
async with get_async_db_context() as db:
result = await db.execute(select(Chat.title).filter_by(id=id))
result = await session.execute(select(Chat.title).filter_by(id=id))
row = result.first()
if row is None:
return None
@@ -638,7 +636,7 @@ class ChatTable:
async def add_message_files_by_id_and_message_id(self, id: str, message_id: str, files: list[dict]) -> list[dict]:
async with get_async_db_context() as db:
chat = await self.get_chat_by_id(id, db=db)
chat = await self.get_chat_by_id(id, db=session)
if chat is None:
return None
@@ -653,61 +651,62 @@ class ChatTable:
history['messages'][message_id]['files'] = message_files
chat['history'] = history
await self.update_chat_by_id(id, chat, db=db)
await self.update_chat_by_id(id, chat, db=session)
return message_files
async def insert_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None:
"""Create a shared snapshot for a chat. Returns the original chat with share_id set."""
from open_webui.models.shared_chats import SharedChats
async with get_async_db_context(db) as db:
chat = await db.get(Chat, chat_id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, chat_id)
if not chat:
return None
# If already shared, just update the existing snapshot
if chat.share_id:
return await self.update_shared_chat_by_chat_id(chat_id, db=db)
return await self.update_shared_chat_by_chat_id(chat_id, db=session)
shared = await SharedChats.create(chat_id, chat.user_id, db=db)
shared = await SharedChats.create(chat_id, chat.user_id, db=session)
if not shared:
return None
# Set share_id on the original chat
chat.share_id = shared.id
await db.commit()
await db.refresh(chat)
return ChatModel.model_validate(chat)
await session.commit()
await session.refresh(chat)
return ChatModel.model_validate(chat) # return the updated original
async def update_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None:
"""Re-snapshot the shared chat with current chat data."""
from open_webui.models.shared_chats import SharedChats
async def update_shared_chat_by_chat_id(
self, chat_id: str, db: AsyncSession | None = None,
) -> ChatModel | None:
"""Re-snapshot the shared copy with the latest chat content."""
from open_webui.models.shared_chats import SharedChats # deferred — circular
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, chat_id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, chat_id)
if not chat or not chat.share_id:
return await self.insert_shared_chat_by_chat_id(chat_id, db=db)
await SharedChats.update(chat.share_id, db=db)
return await self.insert_shared_chat_by_chat_id(chat_id, db=session)
await SharedChats.update(chat.share_id, db=session)
return ChatModel.model_validate(chat)
except Exception:
return None
return
async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool:
"""Delete shared snapshot for a chat."""
from open_webui.models.shared_chats import SharedChats
try:
return await SharedChats.delete_by_chat_id(chat_id, db=db)
return await SharedChats.delete_by_chat_id(chat_id, db=session)
except Exception:
return False
async def unarchive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=False))
await db.commit()
async with get_async_db_context(db) as session:
await session.execute(update(Chat).filter_by(user_id=user_id).values(archived=False))
await session.commit()
return True
except Exception:
return False
@@ -716,45 +715,45 @@ class ChatTable:
self, id: str, share_id: str | None, db: AsyncSession | None = None
) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, id)
chat.share_id = share_id
await db.commit()
await db.refresh(chat)
await session.commit()
await session.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
async def toggle_chat_pinned_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, id)
chat.pinned = not chat.pinned
chat.updated_at = int(time.time())
await db.commit()
await db.refresh(chat)
await session.commit()
await session.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
async def toggle_chat_archive_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, id)
chat.archived = not chat.archived
chat.folder_id = None
chat.updated_at = int(time.time())
await db.commit()
await db.refresh(chat)
await session.commit()
await session.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
async def archive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=True))
await db.commit()
async with get_async_db_context(db) as session:
await session.execute(update(Chat).filter_by(user_id=user_id).values(archived=True))
await session.commit()
return True
except Exception:
return False
@@ -767,7 +766,7 @@ class ChatTable:
limit: int = 50,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by(
user_id=user_id, archived=True
)
@@ -798,7 +797,7 @@ class ChatTable:
if limit:
stmt = stmt.limit(limit)
result = await db.execute(stmt)
result = await session.execute(stmt)
all_chats = result.all()
return [
ChatTitleIdResponse.model_validate(
@@ -823,7 +822,7 @@ class ChatTable:
"""Delegate to SharedChats for listing shared chats by user."""
from open_webui.models.shared_chats import SharedChats
return await SharedChats.get_by_user_id(user_id, filter=filter, skip=skip, limit=limit, db=db)
return await SharedChats.get_by_user_id(user_id, filter=filter, skip=skip, limit=limit, db=session)
async def get_chat_list_by_user_id(
self,
@@ -834,7 +833,7 @@ class ChatTable:
limit: int = 50,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
user_id=user_id
)
@@ -864,7 +863,7 @@ class ChatTable:
if limit:
stmt = stmt.limit(limit)
result = await db.execute(stmt)
result = await session.execute(stmt)
all_chats = result.all()
return [
ChatTitleIdResponse.model_validate(
@@ -889,7 +888,7 @@ class ChatTable:
limit: int | None = None,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
user_id=user_id
)
@@ -910,7 +909,7 @@ class ChatTable:
if limit:
stmt = stmt.limit(limit)
result = await db.execute(stmt)
result = await session.execute(stmt)
all_chats = result.all()
return [
@@ -933,23 +932,26 @@ class ChatTable:
limit: int = 50,
db: AsyncSession | None = None,
) -> list[ChatModel]:
async with get_async_db_context(db) as db:
result = await db.execute(
async with get_async_db_context(db) as session:
result = await session.execute(
select(Chat).filter(Chat.id.in_(chat_ids)).filter_by(archived=False).order_by(Chat.updated_at.desc())
)
all_chats = result.scalars().all()
return [ChatModel.model_validate(chat) for chat in all_chats]
async def get_chat_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
async def get_chat_by_id(
self, id: str, db: AsyncSession | None = None,
) -> ChatModel | None:
"""Fetch a chat by PK, auto-sanitizing null bytes on read."""
try:
async with get_async_db_context(db) as db:
chat_item = await db.get(Chat, id)
async with get_async_db_context(db) as session:
chat_item = await session.get(Chat, id)
if chat_item is None:
return None
if self._sanitize_chat_row(chat_item):
await db.commit()
await db.refresh(chat_item)
await session.commit()
await session.refresh(chat_item)
return ChatModel.model_validate(chat_item)
except Exception:
@@ -960,7 +962,7 @@ class ChatTable:
from open_webui.models.shared_chats import SharedChats
try:
shared = await SharedChats.get_by_id(id, db=db)
shared = await SharedChats.get_by_id(id, db=session)
if shared:
# Return a ChatModel-compatible view of the snapshot
return ChatModel(
@@ -980,8 +982,8 @@ class ChatTable:
self, id: str, user_id: str, db: AsyncSession | None = None
) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Chat).filter_by(id=id, user_id=user_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(Chat).filter_by(id=id, user_id=user_id))
chat = result.scalars().first()
return ChatModel.model_validate(chat) if chat else None
except Exception:
@@ -993,8 +995,8 @@ class ChatTable:
the full Chat row (which includes the potentially large JSON blob).
"""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(exists().where(and_(Chat.id == id, Chat.user_id == user_id))))
async with get_async_db_context(db) as session:
result = await session.execute(select(exists().where(and_(Chat.id == id, Chat.user_id == user_id))))
return result.scalar()
except Exception:
return False
@@ -1005,16 +1007,16 @@ class ChatTable:
JSON blob. Returns None if chat doesn't exist or doesn't belong to user.
"""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Chat.folder_id).filter_by(id=id, user_id=user_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(Chat.folder_id).filter_by(id=id, user_id=user_id))
row = result.first()
return row[0] if row else None
except Exception:
return None
async def get_chats(self, skip: int = 0, limit: int = 50, db: AsyncSession | None = None) -> list[ChatModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(Chat).order_by(Chat.updated_at.desc()))
async with get_async_db_context(db) as session:
result = await session.execute(select(Chat).order_by(Chat.updated_at.desc()))
all_chats = result.scalars().all()
return [ChatModel.model_validate(chat) for chat in all_chats]
@@ -1026,7 +1028,7 @@ class ChatTable:
limit: int | None = None,
db: AsyncSession | None = None,
) -> ChatListResponse:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = select(Chat).filter_by(user_id=user_id)
if filter:
@@ -1048,7 +1050,7 @@ class ChatTable:
else:
stmt = stmt.order_by(Chat.updated_at.desc(), Chat.id)
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
count_result = await session.execute(select(func.count()).select_from(stmt.subquery()))
total = count_result.scalar()
if skip is not None:
@@ -1056,7 +1058,7 @@ class ChatTable:
if limit is not None:
stmt = stmt.limit(limit)
result = await db.execute(stmt)
result = await session.execute(stmt)
all_chats = result.scalars().all()
return ChatListResponse(
@@ -1069,8 +1071,8 @@ class ChatTable:
async def get_pinned_chats_by_user_id(
self, user_id: str, db: AsyncSession | None = None
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
result = await db.execute(
async with get_async_db_context(db) as session:
result = await session.execute(
select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at)
.filter_by(user_id=user_id, pinned=True, archived=False)
.order_by(Chat.updated_at.desc())
@@ -1090,8 +1092,8 @@ class ChatTable:
]
async def get_archived_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[ChatModel]:
async with get_async_db_context(db) as db:
result = await db.execute(
async with get_async_db_context(db) as session:
result = await session.execute(
select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in result.scalars().all()]
@@ -1161,7 +1163,7 @@ class ChatTable:
search_text = ' '.join(search_text_words)
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = select(Chat).filter(Chat.user_id == user_id)
if is_archived is not None:
@@ -1184,7 +1186,7 @@ class ChatTable:
stmt = stmt.order_by(Chat.updated_at.desc(), Chat.id)
# Check if the database dialect is either 'sqlite' or 'postgresql'
bind = await db.connection()
bind = await session.connection()
dialect_name = bind.dialect.name
if dialect_name == 'sqlite':
# SQLite case: using JSON1 extension for JSON searching
@@ -1282,7 +1284,7 @@ class ChatTable:
# Perform pagination at the SQL level
stmt = stmt.offset(skip).limit(limit)
result = await db.execute(stmt)
result = await session.execute(stmt)
all_chats = result.scalars().all()
log.info(f'The number of chats: {len(all_chats)}')
@@ -1298,7 +1300,7 @@ class ChatTable:
limit: int = 60,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = (
select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at)
.filter_by(folder_id=folder_id, user_id=user_id)
@@ -1312,7 +1314,7 @@ class ChatTable:
if limit:
stmt = stmt.limit(limit)
result = await db.execute(stmt)
result = await session.execute(stmt)
all_chats = result.all()
return [
ChatTitleIdResponse.model_validate(
@@ -1330,7 +1332,7 @@ class ChatTable:
async def get_chats_by_folder_ids_and_user_id(
self, folder_ids: list[str], user_id: str, db: AsyncSession | None = None
) -> list[ChatModel]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = (
select(Chat)
.filter(Chat.folder_id.in_(folder_ids), Chat.user_id == user_id)
@@ -1339,7 +1341,7 @@ class ChatTable:
.order_by(Chat.updated_at.desc())
)
result = await db.execute(stmt)
result = await session.execute(stmt)
all_chats = result.scalars().all()
return [ChatModel.model_validate(chat) for chat in all_chats]
@@ -1347,13 +1349,13 @@ class ChatTable:
self, id: str, user_id: str, folder_id: str, db: AsyncSession | None = None
) -> ChatModel | None:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, id)
chat.folder_id = folder_id
chat.updated_at = int(time.time())
chat.pinned = False
await db.commit()
await db.refresh(chat)
await session.commit()
await session.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
@@ -1361,12 +1363,12 @@ class ChatTable:
async def get_chat_tags_by_id_and_user_id(
self, id: str, user_id: str, db: AsyncSession | None = None
) -> list[TagModel]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = select(Chat.meta).where(Chat.id == id)
result = await db.execute(stmt)
result = await session.execute(stmt)
meta = result.scalar_one_or_none()
tag_ids = (meta or {}).get('tags', [])
return await Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=db)
return await Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=session)
async def get_chat_list_by_user_id_and_tag_name(
self,
@@ -1376,13 +1378,13 @@ class ChatTable:
limit: int = 50,
db: AsyncSession | None = None,
) -> list[ChatTitleIdResponse]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
user_id=user_id
)
tag_id = tag_name.replace(' ', '_').lower()
bind = await db.connection()
bind = await session.connection()
dialect_name = bind.dialect.name
log.info(f'DB dialect name: {dialect_name}')
if dialect_name == 'sqlite':
@@ -1403,7 +1405,7 @@ class ChatTable:
if limit:
stmt = stmt.limit(limit)
result = await db.execute(stmt)
result = await session.execute(stmt)
all_chats = result.all()
return [
ChatTitleIdResponse.model_validate(
@@ -1422,17 +1424,17 @@ class ChatTable:
self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None
) -> ChatModel | None:
tag_id = tag_name.replace(' ', '_').lower()
await Tags.ensure_tags_exist([tag_name], user_id, db=db)
await Tags.ensure_tags_exist([tag_name], user_id, db=session)
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, id)
if tag_id not in chat.meta.get('tags', []):
chat.meta = {
**chat.meta,
'tags': list(set(chat.meta.get('tags', []) + [tag_id])),
}
await db.commit()
await db.refresh(chat)
await session.commit()
await session.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
@@ -1440,11 +1442,11 @@ class ChatTable:
async def count_chats_by_tag_name_and_user_id(
self, tag_name: str, user_id: str, db: AsyncSession | None = None
) -> int:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
stmt = select(func.count(Chat.id)).filter_by(user_id=user_id, archived=False)
tag_id = tag_name.replace(' ', '_').lower()
bind = await db.connection()
bind = await session.connection()
dialect_name = bind.dialect.name
if dialect_name == 'sqlite':
stmt = stmt.filter(
@@ -1457,7 +1459,7 @@ class ChatTable:
else:
raise NotImplementedError(f'Unsupported dialect: {dialect_name}')
result = await db.execute(stmt)
result = await session.execute(stmt)
return result.scalar()
async def delete_orphan_tags_for_user(
@@ -1477,19 +1479,19 @@ class ChatTable:
"""
if not tag_ids:
return
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
orphans = []
for tag_id in tag_ids:
count = await self.count_chats_by_tag_name_and_user_id(tag_id, user_id, db=db)
count = await self.count_chats_by_tag_name_and_user_id(tag_id, user_id, db=session)
if count <= threshold:
orphans.append(tag_id)
await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db)
await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=session)
async def count_chats_by_folder_id_and_user_id(
self, folder_id: str, user_id: str, db: AsyncSession | None = None
) -> int:
async with get_async_db_context(db) as db:
result = await db.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id))
count = result.scalar()
log.info(f"Count of chats for folder '{folder_id}': {count}")
@@ -1499,8 +1501,8 @@ class ChatTable:
self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None
) -> bool:
try:
async with get_async_db_context(db) as db:
chat = await db.get(Chat, id)
async with get_async_db_context(db) as session:
chat = await session.get(Chat, id)
tags = chat.meta.get('tags', [])
tag_id = tag_name.replace(' ', '_').lower()
@@ -1509,51 +1511,51 @@ class ChatTable:
**chat.meta,
'tags': list(set(tags)),
}
await db.commit()
await session.commit()
return True
except Exception:
return False
async def delete_chat_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
await db.execute(delete(ChatMessage).filter_by(chat_id=id))
await db.execute(delete(Chat).filter_by(id=id))
await db.commit()
async with get_async_db_context(db) as session:
await session.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
await session.execute(delete(ChatMessage).filter_by(chat_id=id))
await session.execute(delete(Chat).filter_by(id=id))
await session.commit()
return True and await self.delete_shared_chat_by_chat_id(id, db=db)
return True and await self.delete_shared_chat_by_chat_id(id, db=session)
except Exception:
return False
async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
await db.execute(delete(ChatMessage).filter_by(chat_id=id))
await db.execute(delete(Chat).filter_by(id=id, user_id=user_id))
await db.commit()
async with get_async_db_context(db) as session:
await session.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
await session.execute(delete(ChatMessage).filter_by(chat_id=id))
await session.execute(delete(Chat).filter_by(id=id, user_id=user_id))
await session.commit()
return True and await self.delete_shared_chat_by_chat_id(id, db=db)
return True and await self.delete_shared_chat_by_chat_id(id, db=session)
except Exception:
return False
async def delete_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await self.delete_shared_chats_by_user_id(user_id, db=db)
async with get_async_db_context(db) as session:
await self.delete_shared_chats_by_user_id(user_id, db=session)
chat_id_subquery = select(Chat.id).filter_by(user_id=user_id).scalar_subquery()
await db.execute(
await session.execute(
update(AutomationRun)
.filter(AutomationRun.chat_id.in_(select(Chat.id).filter_by(user_id=user_id)))
.values(chat_id=None)
)
await db.execute(
await session.execute(
delete(ChatMessage).filter(ChatMessage.chat_id.in_(select(Chat.id).filter_by(user_id=user_id)))
)
await db.execute(delete(Chat).filter_by(user_id=user_id))
await db.commit()
await session.execute(delete(Chat).filter_by(user_id=user_id))
await session.commit()
return True
except Exception:
@@ -1563,14 +1565,14 @@ class ChatTable:
self, user_id: str, folder_id: str, db: AsyncSession | None = None
) -> bool:
try:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
chat_ids_stmt = select(Chat.id).filter_by(user_id=user_id, folder_id=folder_id)
await db.execute(
await session.execute(
update(AutomationRun).filter(AutomationRun.chat_id.in_(chat_ids_stmt)).values(chat_id=None)
)
await db.execute(delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt)))
await db.execute(delete(Chat).filter_by(user_id=user_id, folder_id=folder_id))
await db.commit()
await session.execute(delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt)))
await session.execute(delete(Chat).filter_by(user_id=user_id, folder_id=folder_id))
await session.commit()
return True
except Exception:
@@ -1584,11 +1586,11 @@ class ChatTable:
db: AsyncSession | None = None,
) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(
async with get_async_db_context(db) as session:
await session.execute(
update(Chat).filter_by(user_id=user_id, folder_id=folder_id).values(folder_id=new_folder_id)
)
await db.commit()
await session.commit()
return True
except Exception:
@@ -1600,13 +1602,13 @@ class ChatTable:
from open_webui.models.shared_chats import SharedChats
try:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
# Delete shared_chat rows for this user's chats
await db.execute(delete(SharedChatTable).filter_by(user_id=user_id))
await session.execute(delete(SharedChatTable).filter_by(user_id=user_id))
# Clear share_id on all of this user's chats
await db.execute(update(Chat).filter_by(user_id=user_id).values(share_id=None))
await db.commit()
await session.execute(update(Chat).filter_by(user_id=user_id).values(share_id=None))
await session.commit()
return True
except Exception:
@@ -1624,7 +1626,7 @@ class ChatTable:
return None
chat_message_file_ids = {
item.id for item in await self.get_chat_files_by_chat_id_and_message_id(chat_id, message_id, db=db)
item.id for item in await self.get_chat_files_by_chat_id_and_message_id(chat_id, message_id, db=session)
}
# Remove duplicates and existing file_ids
file_ids = list({file_id for file_id in file_ids if file_id and file_id not in chat_message_file_ids})
@@ -1632,7 +1634,7 @@ class ChatTable:
return None
try:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
now = int(time.time())
chat_files = [
@@ -1651,7 +1653,7 @@ class ChatTable:
results = [ChatFile(**chat_file.model_dump()) for chat_file in chat_files]
db.add_all(results)
await db.commit()
await session.commit()
return chat_files
except Exception:
@@ -1660,8 +1662,8 @@ class ChatTable:
async def get_chat_files_by_chat_id_and_message_id(
self, chat_id: str, message_id: str, db: AsyncSession | None = None
) -> list[ChatFileModel]:
async with get_async_db_context(db) as db:
result = await db.execute(
async with get_async_db_context(db) as session:
result = await session.execute(
select(ChatFile).filter_by(chat_id=chat_id, message_id=message_id).order_by(ChatFile.created_at.asc())
)
all_chat_files = result.scalars().all()
@@ -1669,17 +1671,17 @@ class ChatTable:
async def delete_chat_file(self, chat_id: str, file_id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id))
await db.commit()
async with get_async_db_context(db) as session:
await session.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id))
await session.commit()
return True
except Exception:
return False
async def get_shared_chat_ids_by_file_id(self, file_id: str, db: AsyncSession | None = None) -> list[str]:
"""Return IDs of chats that contain this file and have an active share link."""
async with get_async_db_context(db) as db:
result = await db.execute(
async with get_async_db_context(db) as session:
result = await session.execute(
select(Chat.id)
.join(ChatFile, Chat.id == ChatFile.chat_id)
.filter(ChatFile.file_id == file_id, Chat.share_id.isnot(None))
@@ -1690,12 +1692,12 @@ class ChatTable:
"""Update the tasks list on a chat."""
try:
async with get_async_db_context() as db:
chat = await db.get(Chat, id)
chat = await session.get(Chat, id)
if chat is None:
return None
chat.tasks = tasks
await db.commit()
await db.refresh(chat)
await session.commit()
await session.refresh(chat)
return ChatModel.model_validate(chat)
except Exception:
return None
@@ -1703,7 +1705,7 @@ class ChatTable:
async def get_chat_tasks_by_id(self, id: str) -> list[dict]:
"""Read the tasks list from a chat (lightweight column query)."""
async with get_async_db_context() as db:
result = await db.execute(select(Chat.tasks).filter_by(id=id))
result = await session.execute(select(Chat.tasks).filter_by(id=id))
row = result.first()
if row is None or row[0] is None:
return []
+11 -19
View File
@@ -1,8 +1,9 @@
"""File upload models, forms, and database operations."""
from __future__ import annotations
import logging
import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.utils.misc import sanitize_metadata
@@ -12,12 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
####################
# Files DB Schema
# What is written here bears witness. Let the testimony
# remain as it was given, and let none tamper with it.
####################
class File(Base):
__tablename__ = 'file'
@@ -25,7 +20,7 @@ class File(Base):
user_id = Column(String)
hash = Column(Text, nullable=True)
filename = Column(Text)
filename = Column(Text) # original upload filename
path = Column(Text, nullable=True)
data = Column(JSON, nullable=True)
@@ -52,11 +47,6 @@ class FileModel(BaseModel):
updated_at: int | None # timestamp in epoch
####################
# Forms
####################
class FileMeta(BaseModel):
name: str | None = None
content_type: str | None = None
@@ -157,16 +147,18 @@ class FilesTable:
return None
except Exception as e:
log.exception(f'Error inserting a new file: {e}')
return None
return None # insertion failed
async def get_file_by_id(self, id: str, db: AsyncSession | None = None) -> FileModel | None:
async def get_file_by_id(
self, id: str, db: AsyncSession | None = None,
) -> FileModel | None:
"""Look up a file by its primary key."""
try:
async with get_async_db_context(db) as db:
try:
file = await db.get(File, id)
return FileModel.model_validate(file) if file else None
except Exception:
file = await db.get(File, id)
if not file:
return None
return FileModel.model_validate(file)
except Exception:
return None
+6 -11
View File
@@ -1,8 +1,9 @@
"""Function (filter/action/pipe) models, forms, and database operations."""
from __future__ import annotations
import logging
import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.users import UserModel, UserResponse, Users
@@ -12,12 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
####################
# Functions DB Schema
# Each function here is a promise made. Let no promise
# go unkept, and let none be called who cannot answer.
####################
class Function(Base):
__tablename__ = 'function'
@@ -30,11 +25,11 @@ class Function(Base):
meta = Column(JSONField)
valves = Column(JSONField)
is_active = Column(Boolean)
is_global = Column(Boolean)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
is_global = Column(Boolean) # if True, applied to every chat automatically
updated_at = Column(BigInteger) # epoch seconds
created_at = Column(BigInteger) # epoch seconds
__table_args__ = (Index('is_global_idx', 'is_global'),)
__table_args__ = (Index('is_global_idx', 'is_global'),) # speed up global-function lookups
class FunctionMeta(BaseModel):
+21 -35
View File
@@ -1,3 +1,5 @@
"""Long-term memory storage for per-user context recall."""
from __future__ import annotations
import time
@@ -9,38 +11,30 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Column, String, Text, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
####################
# Memory DB Schema
# What was learned at cost should not need to be paid
# for again. Let the memory hold.
####################
class Memory(Base):
"""Persistent user memory backed by a vector collection."""
__tablename__ = 'memory'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String, index=True)
content = Column(Text)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
content = Column(Text) # free-form text learned from conversation
updated_at = Column(BigInteger) # epoch seconds
created_at = Column(BigInteger) # epoch seconds
class MemoryModel(BaseModel):
"""Pydantic mirror of the Memory table row."""
id: str
user_id: str
content: str
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
class MemoriesTable:
async def insert_new_memory(
self,
@@ -48,26 +42,20 @@ class MemoriesTable:
content: str,
db: AsyncSession | None = None,
) -> MemoryModel | None:
"""Persist a new memory entry and return the created model."""
async with get_async_db_context(db) as db:
id = str(uuid.uuid4())
memory = MemoryModel(
**{
'id': id,
'user_id': user_id,
'content': content,
'created_at': int(time.time()),
'updated_at': int(time.time()),
}
now = int(time.time())
record = Memory(
id=str(uuid.uuid4()),
user_id=user_id,
content=content,
created_at=now,
updated_at=now,
)
result = Memory(**memory.model_dump())
db.add(result)
db.add(record)
await db.commit()
await db.refresh(result)
if result:
return MemoryModel.model_validate(result)
else:
return None
await db.refresh(record)
return MemoryModel.model_validate(record) if record else None
async def update_memory_by_id_and_user_id(
self,
@@ -143,12 +131,10 @@ class MemoriesTable:
try:
memory = await db.get(Memory, id)
if not memory or memory.user_id != user_id:
return None
return False
# Delete the memory
await db.delete(memory)
await db.commit()
return True
except Exception:
return False
+14 -36
View File
@@ -60,38 +60,19 @@ class ModelMeta(BaseModel):
class Model(Base):
"""Workspace model entry — wraps an upstream LLM with custom params and metadata."""
__tablename__ = 'model'
id = Column(Text, primary_key=True, unique=True)
"""
The model's id as used in the API. If set to an existing model, it will override the model.
"""
user_id = Column(Text)
base_model_id = Column(Text, nullable=True)
"""
An optional pointer to the actual model that should be used when proxying requests.
"""
name = Column(Text)
"""
The human-readable display name of the model.
"""
params = Column(JSONField)
"""
Holds a JSON encoded blob of parameters, see `ModelParams`.
"""
meta = Column(JSONField)
"""
Holds a JSON encoded blob of metadata, see `ModelMeta`.
"""
is_active = Column(Boolean, default=True)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
id = Column(Text, primary_key=True, unique=True) # API model identifier; overrides built-in when matching
user_id = Column(Text) # owner
base_model_id = Column(Text, nullable=True) # actual upstream model for proxied requests
name = Column(Text) # human-readable display name
params = Column(JSONField) # see ModelParams
meta = Column(JSONField) # see ModelMeta
is_active = Column(Boolean, default=True) # soft-disable toggle
updated_at = Column(BigInteger) # epoch seconds
created_at = Column(BigInteger) # epoch seconds
class ModelModel(BaseModel):
@@ -109,12 +90,9 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
####################
model_config = ConfigDict(
from_attributes=True,
)
class ModelUserResponse(ModelModel):
+147 -149
View File
@@ -1,3 +1,5 @@
"""Prompt template models, forms, and database operations."""
from __future__ import annotations
import json
@@ -14,23 +16,19 @@ from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, text, update
from sqlalchemy.ext.asyncio import AsyncSession
####################
# Prompts DB Schema
# Every word here was weighed before it was set down.
# Let the weight not be wasted when it is spoken aloud.
####################
class Prompt(Base):
"""System prompt template with versioning support."""
__tablename__ = 'prompt'
id = Column(Text, primary_key=True)
command = Column(String, unique=True, index=True)
user_id = Column(String)
name = Column(Text)
content = Column(Text)
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
content = Column(Text) # the prompt template body
data = Column(JSON, nullable=True) # structured prompt parameters
meta = Column(JSON, nullable=True) # freeform metadata (description, etc.)
tags = Column(JSON, nullable=True)
is_active = Column(Boolean, default=True)
version_id = Column(Text, nullable=True) # Points to active history entry
@@ -94,7 +92,7 @@ class PromptForm(BaseModel):
class PromptsTable:
async def _get_access_grants(self, prompt_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]:
return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db)
return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=session)
async def _to_prompt_model(
self,
@@ -104,98 +102,94 @@ class PromptsTable:
) -> PromptModel:
prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'})
prompt_data['access_grants'] = (
access_grants if access_grants is not None else await self._get_access_grants(prompt_data['id'], db=db)
access_grants if access_grants is not None else await self._get_access_grants(prompt_data['id'], db=session)
)
return PromptModel.model_validate(prompt_data)
async def insert_new_prompt(
self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None
) -> PromptModel | None:
async def insert_new_prompt(self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None) -> PromptModel | None:
now = int(time.time())
prompt_id = str(uuid.uuid4())
prompt = PromptModel(
id=prompt_id,
user_id=user_id,
command=form_data.command,
name=form_data.name,
content=form_data.content,
data=form_data.data or {},
meta=form_data.meta or {},
tags=form_data.tags or [],
access_grants=[],
is_active=True,
created_at=now,
updated_at=now,
)
async with get_async_db_context(db) as session:
try:
record = Prompt(
id=prompt_id, user_id=user_id,
command=form_data.command, name=form_data.name,
content=form_data.content,
data=form_data.data or {}, meta=form_data.meta or {},
tags=form_data.tags or [], is_active=True,
created_at=now, updated_at=now,
)
session.add(record)
await session.commit()
await session.refresh(record) # populate generated defaults
try:
async with get_async_db_context(db) as db:
result = Prompt(**prompt.model_dump(exclude={'access_grants'}))
db.add(result)
await db.commit()
await db.refresh(result)
await AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db)
await AccessGrants.set_access_grants(
'prompt', prompt_id, form_data.access_grants, db=session,
) # persist sharing rules
if result:
current_access_grants = await self._get_access_grants(prompt_id, db=db)
snapshot = {
'name': form_data.name,
'content': form_data.content,
'command': form_data.command,
'data': form_data.data or {},
'meta': form_data.meta or {},
'tags': form_data.tags or [],
'access_grants': [grant.model_dump() for grant in current_access_grants],
}
history_entry = await PromptHistories.create_history_entry(
prompt_id=prompt_id,
snapshot=snapshot,
user_id=user_id,
parent_id=None, # Initial commit has no parent
commit_message=form_data.commit_message or 'Initial version',
db=db,
)
# Set the initial version as the production version
if history_entry:
result.version_id = history_entry.id
await db.commit()
await db.refresh(result)
return await self._to_prompt_model(result, db=db)
else:
if not record: # shouldn't happen, but guard anyway
return None
except Exception:
return None
# Build the initial version snapshot.
grants = await self._get_access_grants(prompt_id, db=session)
snapshot = {
'name': form_data.name,
'content': form_data.content,
'command': form_data.command,
'data': form_data.data or {},
'meta': form_data.meta or {},
'tags': form_data.tags or [],
'access_grants': [g.model_dump() for g in grants],
}
history_entry = await PromptHistories.create_history_entry(
prompt_id=prompt_id, snapshot=snapshot,
user_id=user_id, parent_id=None,
commit_message=form_data.commit_message or 'Initial version',
db=session,
) # creates the first version entry
# Pin the initial history entry as the production version.
if history_entry:
record.version_id = history_entry.id
await session.commit()
await session.refresh(record) # re-read version_id
return await self._to_prompt_model(record, db=session)
except Exception as e:
log.exception('Error creating prompt: %s', e)
return None
async def get_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None:
"""Get prompt by UUID."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
prompt = result.scalars().first()
if prompt:
return await self._to_prompt_model(prompt, db=db)
return None
except Exception:
return None
async with get_async_db_context(db) as session:
result = await session.execute(
select(Prompt).filter_by(id=prompt_id),
)
prompt = result.scalars().first() # None when not found
if not prompt:
return None
return await self._to_prompt_model(prompt, db=session)
except Exception: # connection / integrity error
return
async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(command=command))
prompt = result.scalars().first()
if prompt:
return await self._to_prompt_model(prompt, db=db)
return None
except Exception:
return None
async with get_async_db_context(db) as session:
result = await session.execute(
select(Prompt).filter_by(command=command),
)
prompt = result.scalars().first() # None when no match
if not prompt:
return None
return await self._to_prompt_model(prompt, db=session)
except Exception: # connection / integrity error
return
async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]:
async with get_async_db_context(db) as db:
result = await db.execute(
async with get_async_db_context(db) as session:
result = await session.execute(
select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc())
)
all_prompts = result.scalars().all()
@@ -203,9 +197,9 @@ class PromptsTable:
user_ids = list(set(prompt.user_id for prompt in all_prompts))
prompt_ids = [prompt.id for prompt in all_prompts]
users = await Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
users = await Users.get_users_by_user_ids(user_ids, db=session) if user_ids else []
users_dict = {user.id: user for user in users}
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session)
prompts = []
for prompt in all_prompts:
@@ -217,7 +211,7 @@ class PromptsTable:
await self._to_prompt_model(
prompt,
access_grants=grants_map.get(prompt.id, []),
db=db,
db=session,
)
).model_dump(),
'user': user.model_dump() if user else None,
@@ -230,8 +224,8 @@ class PromptsTable:
async def get_prompts_by_user_id(
self, user_id: str, permission: str = 'write', db: AsyncSession | None = None
) -> list[PromptUserResponse]:
async with get_async_db_context(db) as db:
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
async with get_async_db_context(db) as session:
user_groups = await Groups.get_groups_by_member_id(user_id, db=session)
user_group_ids = [group.id for group in user_groups]
query = select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc())
@@ -244,7 +238,7 @@ class PromptsTable:
permission=permission,
)
result = await db.execute(query)
result = await session.execute(query)
accessible_prompts = result.scalars().all()
if not accessible_prompts:
@@ -253,9 +247,9 @@ class PromptsTable:
prompt_ids = [p.id for p in accessible_prompts]
owner_ids = list({p.user_id for p in accessible_prompts})
users = await Users.get_users_by_user_ids(owner_ids, db=db)
users = await Users.get_users_by_user_ids(owner_ids, db=session)
users_dict = {u.id: u for u in users}
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session)
results = []
for prompt in accessible_prompts:
@@ -284,7 +278,7 @@ class PromptsTable:
limit: int = 30,
db: AsyncSession | None = None,
) -> PromptListResponse:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
# Join with User table for user filtering and sorting
query = select(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
@@ -319,7 +313,7 @@ class PromptsTable:
tag = filter.get('tag')
if tag:
bind = await db.connection()
bind = await session.connection()
dialect_name = bind.dialect.name
tag_lower = tag.lower()
@@ -367,7 +361,7 @@ class PromptsTable:
query = query.order_by(Prompt.updated_at.desc())
# Count BEFORE pagination
count_result = await db.execute(select(func.count()).select_from(query.subquery()))
count_result = await session.execute(select(func.count()).select_from(query.subquery()))
total = count_result.scalar()
if skip:
@@ -375,11 +369,11 @@ class PromptsTable:
if limit:
query = query.limit(limit)
result = await db.execute(query)
result = await session.execute(query)
items = result.all()
prompt_ids = [prompt.id for prompt, _ in items]
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session)
prompts = []
for prompt, user in items:
@@ -406,15 +400,15 @@ class PromptsTable:
db: AsyncSession | None = None,
) -> PromptModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(command=command))
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(command=command))
prompt = result.scalars().first()
if not prompt:
return None
latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=db)
latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=session)
parent_id = latest_history.id if latest_history else None
current_access_grants = await self._get_access_grants(prompt.id, db=db)
current_access_grants = await self._get_access_grants(prompt.id, db=session)
# Check if content changed to decide on history creation
content_changed = (
@@ -430,10 +424,10 @@ class PromptsTable:
prompt.meta = form_data.meta or prompt.meta
prompt.updated_at = int(time.time())
if form_data.access_grants is not None:
await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
current_access_grants = await self._get_access_grants(prompt.id, db=db)
await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=session)
current_access_grants = await self._get_access_grants(prompt.id, db=session)
await db.commit()
await session.commit()
# Create history entry only if content changed
if content_changed:
@@ -458,9 +452,9 @@ class PromptsTable:
# Set as production if flag is True (default)
if form_data.is_production and history_entry:
prompt.version_id = history_entry.id
await db.commit()
await session.commit()
return await self._to_prompt_model(prompt, db=db)
return await self._to_prompt_model(prompt, db=session)
except Exception:
return None
@@ -472,15 +466,15 @@ class PromptsTable:
db: AsyncSession | None = None,
) -> PromptModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
prompt = result.scalars().first()
if not prompt:
return None
latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=db)
latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=session)
parent_id = latest_history.id if latest_history else None
current_access_grants = await self._get_access_grants(prompt.id, db=db)
current_access_grants = await self._get_access_grants(prompt.id, db=session)
# Check if content changed to decide on history creation
content_changed = (
@@ -502,12 +496,12 @@ class PromptsTable:
prompt.tags = form_data.tags
if form_data.access_grants is not None:
await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
current_access_grants = await self._get_access_grants(prompt.id, db=db)
await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=session)
current_access_grants = await self._get_access_grants(prompt.id, db=session)
prompt.updated_at = int(time.time())
await db.commit()
await session.commit()
# Create history entry only if content changed
if content_changed:
@@ -533,9 +527,9 @@ class PromptsTable:
# Set as production if flag is True (default)
if form_data.is_production and history_entry:
prompt.version_id = history_entry.id
await db.commit()
await session.commit()
return await self._to_prompt_model(prompt, db=db)
return await self._to_prompt_model(prompt, db=session)
except Exception:
return None
@@ -549,8 +543,8 @@ class PromptsTable:
) -> PromptModel | None:
"""Update only name, command, and tags (no history created)."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
prompt = result.scalars().first()
if not prompt:
return None
@@ -562,9 +556,9 @@ class PromptsTable:
prompt.tags = tags
prompt.updated_at = int(time.time())
await db.commit()
await session.commit()
return await self._to_prompt_model(prompt, db=db)
return await self._to_prompt_model(prompt, db=session)
except Exception:
return None
@@ -576,13 +570,13 @@ class PromptsTable:
) -> PromptModel | None:
"""Set the active version of a prompt and restore content from that version's snapshot."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
prompt = result.scalars().first()
if not prompt:
return None
history_entry = await PromptHistories.get_history_entry_by_id(version_id, db=db)
history_entry = await PromptHistories.get_history_entry_by_id(version_id, db=session)
if not history_entry:
return None
@@ -599,24 +593,28 @@ class PromptsTable:
prompt.version_id = version_id
prompt.updated_at = int(time.time())
await db.commit()
await session.commit()
return await self._to_prompt_model(prompt, db=db)
except Exception:
return await self._to_prompt_model(prompt, db=session)
except Exception: # connection error
return None
async def toggle_prompt_active(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None:
# --- Active state management ---
async def toggle_prompt_active(
self, prompt_id: str, db: AsyncSession | None = None,
) -> PromptModel | None:
"""Toggle the is_active flag on a prompt."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
prompt = result.scalars().first()
if prompt:
prompt.is_active = not prompt.is_active
prompt.updated_at = int(time.time())
await db.commit()
await db.refresh(prompt)
return await self._to_prompt_model(prompt, db=db)
await session.commit()
await session.refresh(prompt)
return await self._to_prompt_model(prompt, db=session)
return None
except Exception:
return None
@@ -624,15 +622,15 @@ class PromptsTable:
async def delete_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> bool:
"""Permanently delete a prompt and its history."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(command=command))
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(command=command))
prompt = result.scalars().first()
if prompt:
await PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
await AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
await PromptHistories.delete_history_by_prompt_id(prompt.id, db=session)
await AccessGrants.revoke_all_access('prompt', prompt.id, db=session)
await db.delete(prompt)
await db.commit()
await session.delete(prompt)
await session.commit()
return True
return False
except Exception:
@@ -641,15 +639,15 @@ class PromptsTable:
async def delete_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> bool:
"""Permanently delete a prompt and its history."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
prompt = result.scalars().first()
if prompt:
await PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
await AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
await PromptHistories.delete_history_by_prompt_id(prompt.id, db=session)
await AccessGrants.revoke_all_access('prompt', prompt.id, db=session)
await db.delete(prompt)
await db.commit()
await session.delete(prompt)
await session.commit()
return True
return False
except Exception:
@@ -657,8 +655,8 @@ class PromptsTable:
async def get_tags(self, db: AsyncSession | None = None) -> list[str]:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(Prompt.tags).filter(Prompt.is_active == True))
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt.tags).filter(Prompt.is_active == True))
tags = set()
for (tag_list,) in result.all():
if tag_list:
@@ -671,8 +669,8 @@ class PromptsTable:
async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[str]:
try:
async with get_async_db_context(db) as db:
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
async with get_async_db_context(db) as session:
user_groups = await Groups.get_groups_by_member_id(user_id, db=session)
user_group_ids = [group.id for group in user_groups]
query = select(Prompt.tags).filter(Prompt.is_active == True)
@@ -685,7 +683,7 @@ class PromptsTable:
permission='read',
)
result = await db.execute(query)
result = await session.execute(query)
tags = set()
for (tag_list,) in result.all():
if tag_list:
+13 -18
View File
@@ -1,9 +1,10 @@
"""Tag models and database operations."""
from __future__ import annotations
import logging
import time
import uuid
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_async_db_context
from pydantic import BaseModel, ConfigDict
@@ -13,11 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
####################
# Tag DB Schema
# To name a thing is to claim it. The creator has
# already named everything stored in this table.
####################
class Tag(Base):
__tablename__ = 'tag'
id = Column(String)
@@ -53,22 +49,21 @@ class TagChatIdForm(BaseModel):
class TagTable:
async def insert_new_tag(self, name: str, user_id: str, db: AsyncSession | None = None) -> TagModel | None:
async def insert_new_tag(
self, name: str, user_id: str, db: AsyncSession | None = None,
) -> TagModel | None:
"""Create a new tag, deriving the id from the name."""
async with get_async_db_context(db) as db:
id = name.replace(' ', '_').lower()
tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
tag_id = name.replace(' ', '_').lower()
try:
result = Tag(**tag.model_dump())
db.add(result)
record = Tag(id=tag_id, user_id=user_id, name=name)
db.add(record)
await db.commit()
await db.refresh(result)
if result:
return TagModel.model_validate(result)
else:
return None
await db.refresh(record)
return TagModel.model_validate(record) if record else None
except Exception as e:
log.exception(f'Error inserting a new tag: {e}')
return None
log.exception('Error inserting tag %r: %s', name, e)
return None # insertion failed
async def get_tag_by_name_and_user_id(
self, name: str, user_id: str, db: AsyncSession | None = None
+18 -18
View File
@@ -1,8 +1,9 @@
"""Tool models, forms, and database operations."""
from __future__ import annotations
import logging
import time
from typing import Optional
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
@@ -14,23 +15,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
####################
# Tools DB Schema
# A tool that fails silently is worse than one that
# refuses outright. Let each one here be honest in its work.
####################
class Tool(Base):
__tablename__ = 'tool'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
name = Column(Text)
content = Column(Text)
specs = Column(JSONField)
meta = Column(JSONField)
valves = Column(JSONField)
name = Column(Text) # human-readable label
content = Column(Text) # Python source code
specs = Column(JSONField) # OpenAPI-style function specs
meta = Column(JSONField) # description, manifest, etc.
valves = Column(JSONField) # admin-configurable runtime parameters
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
@@ -141,13 +136,18 @@ class ToolsTable:
return None
except Exception as e:
log.exception(f'Error creating a new tool: {e}')
return None
return None # creation failed
async def get_tool_by_id(self, id: str, db: AsyncSession | None = None) -> ToolModel | None:
try:
async with get_async_db_context(db) as db:
tool = await db.get(Tool, id)
return await self._to_tool_model(tool, db=db) if tool else None
async def get_tool_by_id(
self, id: str, db: AsyncSession | None = None,
) -> ToolModel | None:
"""Fetch a single tool by primary key, including access grants."""
try: # single PK lookup + access grants
async with get_async_db_context(db) as session:
tool = await session.get(Tool, id)
if not tool:
return None
return await self._to_tool_model(tool, db=session)
except Exception:
return None
+201 -167
View File
@@ -1,9 +1,10 @@
"""User models, Pydantic schemas, and database access layer."""
from __future__ import annotations
import datetime
import time
from typing import Optional
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.utils.misc import throttle
@@ -43,39 +44,46 @@ class UserSettings(BaseModel):
class User(Base):
"""Core user identity and profile record."""
__tablename__ = 'user'
# Identity
id = Column(String, primary_key=True, unique=True)
email = Column(String)
username = Column(String(50), nullable=True)
role = Column(String)
name = Column(String)
profile_image_url = Column(Text)
# Profile
profile_image_url = Column(Text) # data-uri, path, or external URL
profile_banner_image_url = Column(Text, nullable=True)
bio = Column(Text, nullable=True)
gender = Column(Text, nullable=True)
date_of_birth = Column(Date, nullable=True)
timezone = Column(String, nullable=True)
# Online status
presence_state = Column(String, nullable=True)
status_emoji = Column(String, nullable=True)
status_message = Column(Text, nullable=True)
status_expires_at = Column(BigInteger, nullable=True)
# Metadata
info = Column(JSON, nullable=True)
settings = Column(JSON, nullable=True)
oauth = Column(JSON, nullable=True)
scim = Column(JSON, nullable=True)
# Timestamps (epoch seconds)
last_active_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
_DEFAULT_PROFILE_IMAGE_URL = '/api/v1/users/{user_id}/profile/image'
class UserModel(BaseModel):
id: str
@@ -108,13 +116,19 @@ class UserModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(
from_attributes=True,
)
@model_validator(mode='after')
def set_profile_image_url(self):
if not self.profile_image_url:
self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
return self
@model_validator(mode='after') # runs after all field validators
def _ensure_profile_image(self) -> 'UserModel':
"""Fall back to a generated avatar when no explicit profile image is set."""
if self.profile_image_url: # explicit image — nothing to do
return self
self.profile_image_url = _DEFAULT_PROFILE_IMAGE_URL.format(
user_id=self.id,
)
return self # modified in-place
class UserStatusModel(UserModel):
@@ -269,7 +283,7 @@ class UsersTable:
oauth: dict | None = None,
db: AsyncSession | None = None,
) -> UserModel | None:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
user = UserModel(
**{
'id': id,
@@ -285,89 +299,108 @@ class UsersTable:
}
)
result = User(**user.model_dump())
db.add(result)
await db.commit()
await db.refresh(result)
session.add(result)
await session.commit()
await session.refresh(result)
if result:
return user
else:
return None
async def get_user_by_id(self, id: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
user = result.scalars().first()
return UserModel.model_validate(user) if user else None
except Exception:
return None
async def get_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(
select(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key)
)
user = result.scalars().first()
return UserModel.model_validate(user) if user else None
except Exception:
return None
async def get_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter(func.lower(User.email) == email.lower()))
user = result.scalars().first()
return UserModel.model_validate(user) if user else None
except Exception:
return None
async def get_user_by_oauth_sub(self, provider: str, sub: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
dialect_name = db.bind.dialect.name
stmt = select(User)
if dialect_name == 'sqlite':
stmt = stmt.filter(User.oauth.contains({provider: {'sub': sub}}))
elif dialect_name == 'postgresql':
stmt = stmt.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub)
result = await db.execute(stmt)
user = result.scalars().first()
return UserModel.model_validate(user) if user else None
except Exception as e:
# You may want to log the exception here
return None
async def get_user_by_scim_external_id(
self, provider: str, external_id: str, db: AsyncSession | None = None
async def get_user_by_id(
self, id: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Fetch a single user by primary key."""
try: # db.get is a PK lookup — very cheap
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if not user:
return None
return UserModel.model_validate(user)
except Exception: # stale session / connection drop
return None # caller treats None as "not found"
async def get_user_by_api_key(
self, api_key: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Resolve a user from their API key via a JOIN on the api_key table."""
try: # single JOIN instead of two round-trips
async with get_async_db_context(db) as session:
result = await session.execute(
select(User)
.join(ApiKey, User.id == ApiKey.user_id)
.filter(ApiKey.key == api_key),
) # emits: SELECT user.* FROM user JOIN api_key …
user = result.scalars().first() # None when key is invalid
if not user:
return None
return UserModel.model_validate(user)
except Exception: # stale session / connection drop
return
async def get_user_by_email(
self, email: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Case-insensitive email lookup using SQL ``lower()``."""
try:
async with get_async_db_context(db) as db:
dialect_name = db.bind.dialect.name
stmt = select(User)
if dialect_name == 'sqlite':
stmt = stmt.filter(User.scim.contains({provider: {'external_id': external_id}}))
elif dialect_name == 'postgresql':
stmt = stmt.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id)
result = await db.execute(stmt)
user = result.scalars().first()
return UserModel.model_validate(user) if user else None
async with get_async_db_context(db) as session:
stmt = select(User).filter(
func.lower(User.email) == email.lower(),
)
row = (await session.execute(stmt)).scalars().first()
if not row:
return
return UserModel.model_validate(row)
except Exception:
return None
return
async def get_user_by_oauth_sub(
self, provider: str, sub: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Look up a user by OAuth provider + subject claim (dialect-aware JSON filter)."""
try:
async with get_async_db_context(db) as session:
dialect = session.bind.dialect.name
query = select(User)
if dialect == 'sqlite':
oauth_match = User.oauth.contains({provider: {'sub': sub}})
query = query.filter(oauth_match)
elif dialect == 'postgresql':
oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub
query = query.filter(oauth_match)
row = (await session.execute(query)).scalars().first()
if not row:
return
return UserModel.model_validate(row)
except Exception:
return
async def get_user_by_scim_external_id(self, provider: str, external_id: str, db: AsyncSession | None = None) -> UserModel | None:
"""Look up a user by SCIM provider + external ID (dialect-aware JSON filter)."""
try:
async with get_async_db_context(db) as session:
dialect = session.bind.dialect.name
query = select(User)
if dialect == 'sqlite':
scim_match = User.scim.contains({provider: {'external_id': external_id}})
query = query.filter(scim_match)
elif dialect == 'postgresql':
scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id
query = query.filter(scim_match)
row = (await session.execute(query)).scalars().first()
if not row:
return
return UserModel.model_validate(row)
except Exception:
return
async def get_users(
self,
filter: dict | None = None,
skip: int | None = None,
limit: int | None = None,
db: AsyncSession | None = None,
self, filter: dict | None = None, skip: int | None = None,
limit: int | None = None, db: AsyncSession | None = None,
) -> dict:
async with get_async_db_context(db) as db:
# Import here to avoid circular imports
"""Paginated user listing with optional filters for role, group, and channel."""
async with get_async_db_context(db) as session:
# Deferred imports to avoid circular dependencies
from open_webui.models.channels import ChannelMember
from open_webui.models.groups import GroupMember
@@ -487,7 +520,7 @@ class UsersTable:
stmt = stmt.order_by(User.created_at.desc())
# Count BEFORE pagination
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
count_result = await session.execute(select(func.count()).select_from(stmt.subquery()))
total = count_result.scalar()
# correct pagination logic
@@ -496,7 +529,7 @@ class UsersTable:
if limit is not None:
stmt = stmt.limit(limit)
result = await db.execute(stmt)
result = await session.execute(stmt)
users = result.scalars().all()
return {
'users': [UserModel.model_validate(user) for user in users],
@@ -504,44 +537,47 @@ class UsersTable:
}
async def get_users_by_group_id(self, group_id: str, db: AsyncSession | None = None) -> list[UserModel]:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
from open_webui.models.groups import GroupMember
result = await db.execute(
result = await session.execute(
select(User).join(GroupMember, User.id == GroupMember.user_id).filter(GroupMember.group_id == group_id)
)
users = result.scalars().all()
return [UserModel.model_validate(user) for user in users]
async def get_users_by_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[UserStatusModel]:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter(User.id.in_(user_ids)))
users = result.scalars().all()
return [UserModel.model_validate(user) for user in users]
async def get_num_users(self, db: AsyncSession | None = None) -> int | None:
async with get_async_db_context(db) as db:
result = await db.execute(select(func.count()).select_from(User))
async with get_async_db_context(db) as session:
result = await session.execute(select(func.count()).select_from(User))
return result.scalar()
async def has_users(self, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
result = await db.execute(select(exists(select(User))))
async with get_async_db_context(db) as session:
result = await session.execute(select(exists(select(User))))
return result.scalar()
async def get_first_user(self, db: AsyncSession | None = None) -> UserModel:
"""Return the earliest-created user (bootstrap admin detection)."""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).order_by(User.created_at).limit(1))
user = result.scalars().first()
return UserModel.model_validate(user) if user else None
async with get_async_db_context(db) as session:
stmt = select(User).order_by(User.created_at).limit(1)
row = (await session.execute(stmt)).scalars().first()
if not row:
return
return UserModel.model_validate(row)
except Exception:
return None
return
async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if user.settings is None:
@@ -552,67 +588,65 @@ class UsersTable:
return None
async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
current_timestamp = int(datetime.datetime.now().timestamp())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
result = await db.execute(
result = await session.execute(
select(func.count()).select_from(User).filter(User.last_active_at > today_midnight_timestamp)
)
return result.scalar()
async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
user = result.scalars().first()
async with get_async_db_context(db) as session:
user = (await session.execute(select(User).filter_by(id=id))).scalars().first()
if not user:
return None
return
user.role = role
await db.commit()
await db.refresh(user)
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
except Exception:
return None
return
async def update_user_status_by_id(
self, id: str, form_data: UserStatus, db: AsyncSession | None = None
) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
for key, value in form_data.model_dump(exclude_none=True).items():
setattr(user, key, value)
await db.commit()
await db.refresh(user)
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
except Exception:
return None
async def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str, db: AsyncSession | None = None
self, id: str, profile_image_url: str, db: AsyncSession | None = None,
) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
user.profile_image_url = profile_image_url
await db.commit()
await db.refresh(user)
return UserModel.model_validate(user)
async with get_async_db_context(db) as session:
row = (await session.execute(select(User).filter_by(id=id))).scalars().first()
if row is None:
return
row.profile_image_url = profile_image_url
await session.commit()
await session.refresh(row)
return UserModel.model_validate(row)
except Exception:
return None
return
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None:
try:
async with get_async_db_context(db) as db:
await db.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time())))
await db.commit()
async with get_async_db_context(db) as session:
await session.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time())))
await session.commit()
except Exception:
pass
@@ -628,8 +662,8 @@ class UsersTable:
}
"""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
@@ -641,8 +675,8 @@ class UsersTable:
oauth[provider] = {'sub': sub}
# Persist updated JSON
await db.execute(update(User).filter_by(id=id).values(oauth=oauth))
await db.commit()
await session.execute(update(User).filter_by(id=id).values(oauth=oauth))
await session.commit()
return UserModel.model_validate(user)
@@ -665,8 +699,8 @@ class UsersTable:
}
"""
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
@@ -674,8 +708,8 @@ class UsersTable:
scim = user.scim or {}
scim[provider] = {'external_id': external_id}
await db.execute(update(User).filter_by(id=id).values(scim=scim))
await db.commit()
await session.execute(update(User).filter_by(id=id).values(scim=scim))
await session.commit()
return UserModel.model_validate(user)
@@ -684,15 +718,15 @@ class UsersTable:
async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
for key, value in updated.items():
setattr(user, key, value)
await db.commit()
await db.refresh(user)
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
except Exception as e:
print(e)
@@ -702,8 +736,8 @@ class UsersTable:
self, id: str, updated: dict, db: AsyncSession | None = None
) -> UserModel | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=id))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
@@ -715,10 +749,10 @@ class UsersTable:
user_settings.update(updated)
await db.execute(update(User).filter_by(id=id).values(settings=user_settings))
await db.commit()
await session.execute(update(User).filter_by(id=id).values(settings=user_settings))
await session.commit()
result = await db.execute(select(User).filter_by(id=id))
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
return UserModel.model_validate(user)
except Exception:
@@ -733,12 +767,12 @@ class UsersTable:
await Groups.remove_user_from_all_groups(id)
# Delete User Chats
result = await Chats.delete_chats_by_user_id(id, db=db)
result = await Chats.delete_chats_by_user_id(id, db=session)
if result:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
# Delete User
await db.execute(delete(User).filter_by(id=id))
await db.commit()
await session.execute(delete(User).filter_by(id=id))
await session.commit()
return True
else:
@@ -748,8 +782,8 @@ class UsersTable:
async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
try:
async with get_async_db_context(db) as db:
result = await db.execute(select(ApiKey).filter_by(user_id=id))
async with get_async_db_context(db) as session:
result = await session.execute(select(ApiKey).filter_by(user_id=id))
api_key = result.scalars().first()
return api_key.key if api_key else None
except Exception:
@@ -757,9 +791,9 @@ class UsersTable:
async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(delete(ApiKey).filter_by(user_id=id))
await db.commit()
async with get_async_db_context(db) as session:
await session.execute(delete(ApiKey).filter_by(user_id=id))
await session.commit()
now = int(time.time())
new_api_key = ApiKey(
@@ -769,8 +803,8 @@ class UsersTable:
created_at=now,
updated_at=now,
)
db.add(new_api_key)
await db.commit()
session.add(new_api_key)
await session.commit()
return True
@@ -779,22 +813,22 @@ class UsersTable:
async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as db:
await db.execute(delete(ApiKey).filter_by(user_id=id))
await db.commit()
async with get_async_db_context(db) as session:
await session.execute(delete(ApiKey).filter_by(user_id=id))
await session.commit()
return True
except Exception:
return False
async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter(User.id.in_(user_ids)))
users = result.scalars().all()
return [user.id for user in users]
async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(role='admin').limit(1))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(role='admin').limit(1))
user = result.scalars().first()
if user:
return UserModel.model_validate(user)
@@ -802,10 +836,10 @@ class UsersTable:
return None
async def get_active_user_count(self, db: AsyncSession | None = None) -> int:
async with get_async_db_context(db) as db:
async with get_async_db_context(db) as session:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
result = await db.execute(
result = await session.execute(
select(func.count()).select_from(User).filter(User.last_active_at >= three_minutes_ago)
)
return result.scalar()
@@ -819,8 +853,8 @@ class UsersTable:
return False
async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as db:
result = await db.execute(select(User).filter_by(id=user_id))
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=user_id))
user = result.scalars().first()
if user and user.last_active_at:
# Consider user active if last_active_at within the last 3 minutes
+6 -11
View File
@@ -513,25 +513,20 @@ async def delete_all_user_chats(
@router.get('/list/user/{user_id}', response_model=list[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str,
page: int | None = None,
query: str | None = None,
order_by: str | None = None,
direction: str | None = None,
user=Depends(get_admin_user),
db: AsyncSession = Depends(get_async_session),
user_id: str, page: int | None = None, query: str | None = None,
order_by: str | None = None, direction: str | None = None,
user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session),
):
"""List chat summaries for a given user (admin-only endpoint)."""
if not ENABLE_ADMIN_CHAT_ACCESS:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if page is None:
page = 1
effective_page = page if page is not None else 1
limit = 60
skip = (page - 1) * limit
skip = (effective_page - 1) * limit
filter = {}
if query:
+7 -7
View File
@@ -62,14 +62,14 @@ class MemoryUpdateModel(BaseModel):
@router.post('/add', response_model=MemoryModel | None)
async def add_memory(
request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user),
):
# NOTE: We intentionally do NOT use Depends(get_async_session) here.
# Database operations (insert_new_memory) manage their own short-lived sessions.
# This prevents holding a connection during EMBEDDING_FUNCTION()
# which makes external embedding API calls (1-5+ seconds).
"""Persist a new memory and embed it into the user's vector collection.
Does NOT use ``Depends(get_async_session)`` database operations manage their
own short-lived sessions so a connection is not held during the external
embedding API call (``EMBEDDING_FUNCTION``), which can take 1-5+ seconds.
"""
if not request.app.state.config.ENABLE_MEMORIES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
+7 -10
View File
@@ -193,11 +193,10 @@ async def get_model_tags(user=Depends(get_verified_user), db: AsyncSession = Dep
@router.post('/create', response_model=ModelModel | None)
async def create_new_model(
request: Request,
form_data: ModelForm,
user=Depends(get_verified_user),
db: AsyncSession = Depends(get_async_session),
request: Request, form_data: ModelForm,
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
):
"""Create a new workspace model entry."""
if user.role != 'admin' and not await has_permission(
user.id, 'workspace.models', request.app.state.config.USER_PERMISSIONS, db=db
):
@@ -579,16 +578,14 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Async
@router.post('/model/update', response_model=ModelModel | None)
async def update_model_by_id(
request: Request,
form_data: ModelForm,
user=Depends(get_verified_user),
db: AsyncSession = Depends(get_async_session),
request: Request, form_data: ModelForm,
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
):
"""Update a workspace model's configuration."""
model = await Models.get_model_by_id(form_data.id, db=db)
if not model:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND,
)
if (
+3 -5
View File
@@ -239,12 +239,10 @@ async def get_prompt_by_id(
@router.post('/id/{prompt_id}/update', response_model=PromptModel | None)
async def update_prompt_by_id(
request: Request,
prompt_id: str,
form_data: PromptForm,
user=Depends(get_verified_user),
db: AsyncSession = Depends(get_async_session),
request: Request, prompt_id: str, form_data: PromptForm,
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
):
"""Update a prompt's content, creating a new history entry if changed."""
prompt = await Prompts.get_prompt_by_id(prompt_id, db=db)
if not prompt:
+7 -11
View File
@@ -324,11 +324,10 @@ async def export_tools(
@router.post('/create', response_model=ToolResponse | None)
async def create_new_tools(
request: Request,
form_data: ToolForm,
user=Depends(get_verified_user),
db: AsyncSession = Depends(get_async_session),
request: Request, form_data: ToolForm,
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
):
"""Create a new tool from user-supplied Python source code."""
if user.role != 'admin' and not (
await has_permission(user.id, 'workspace.tools', request.app.state.config.USER_PERMISSIONS, db=db)
or await has_permission(
@@ -449,17 +448,14 @@ async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: AsyncSes
@router.post('/id/{id}/update', response_model=ToolModel | None)
async def update_tools_by_id(
request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_verified_user),
db: AsyncSession = Depends(get_async_session),
request: Request, id: str, form_data: ToolForm,
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
):
"""Update an existing tool's source code and metadata."""
tools = await Tools.get_tool_by_id(id, db=db)
if not tools:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND,
)
# Is the user the original creator, in a group with write access, or an admin
+11 -10
View File
@@ -384,21 +384,22 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user), db: Asy
@router.post('/user/info/update', response_model=dict | None)
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
):
# Merges against the auth-time snapshot of user.info. The previous pre-merge
# refetch only narrowed (did not eliminate) the lost-update window on concurrent
# same-user writes; real safety needs row locking or a version column.
existing_info = user.info or {}
updated = await Users.update_user_by_id(user.id, {'info': {**existing_info, **form_data}}, db=db)
if updated:
return updated.info
else:
"""Merge caller-supplied fields into the current user's info dict.
Uses the auth-time snapshot of ``user.info`` as the merge base. This does
NOT eliminate lost-update races on concurrent same-user writes; real safety
would need row locking or an optimistic-concurrency version column.
"""
merged_info = {**(user.info or {}), **form_data}
updated = await Users.update_user_by_id(user.id, {'info': merged_info}, db=db)
if not updated:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
return updated.info
############################
# GetUserById
+5 -2
View File
@@ -94,18 +94,21 @@ async def download_chat_as_pdf(form_data: ChatTitleMessagesForm, user=Depends(ge
@router.get('/db/download')
async def download_db(user=Depends(get_admin_user)):
"""Download the raw SQLite database file (admin-only, SQLite deployments only)."""
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from open_webui.internal.db import engine
if engine.name != 'sqlite':
from open_webui.internal.db import engine # deferred import
if engine.name != 'sqlite': # only SQLite DBs can be downloaded as a file
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse(
engine.url.database,
media_type='application/octet-stream',