mirror of
https://github.com/open-webui/open-webui.git
synced 2026-06-13 19:20:05 +00:00
refac
This commit is contained in:
@@ -118,8 +118,15 @@ reattach_ssl_mode_to_url = reattach_ssl_params_to_url
|
||||
|
||||
|
||||
class JSONField(types.TypeDecorator):
|
||||
"""Store arbitrary Python objects as JSON-encoded TEXT.
|
||||
|
||||
Used instead of native JSON columns for portability across SQLite and
|
||||
PostgreSQL. Values are serialized with ``json.dumps`` on write and
|
||||
deserialized with ``json.loads`` on read.
|
||||
"""
|
||||
|
||||
impl = types.Text
|
||||
cache_ok = True
|
||||
cache_ok = True # safe for statement caching (no per-instance state)
|
||||
|
||||
def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any:
|
||||
return json.dumps(value)
|
||||
|
||||
@@ -2912,10 +2912,12 @@ async def readiness_check():
|
||||
|
||||
@app.get('/health/db')
|
||||
async def healthcheck_with_db():
|
||||
"""Verify database connectivity by issuing a lightweight ping."""
|
||||
await async_db_ping()
|
||||
return {'status': True}
|
||||
|
||||
|
||||
# Serve build-time static assets (CSS, JS, images, favicon, etc.)
|
||||
app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static')
|
||||
|
||||
|
||||
@@ -2924,8 +2926,13 @@ async def serve_cache_file(
|
||||
path: str,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
"""Serve cached files (e.g. tool outputs) with path-traversal protection.
|
||||
|
||||
Only ``image/*``, ``audio/*``, and ``video/*`` MIME types are served inline;
|
||||
everything else gets a ``Content-Disposition: attachment`` header to prevent
|
||||
XSS from user-generated HTML stored in the cache directory.
|
||||
"""
|
||||
file_path = os.path.abspath(os.path.join(CACHE_DIR, path))
|
||||
# prevent path traversal
|
||||
if not file_path.startswith(os.path.abspath(CACHE_DIR)):
|
||||
raise HTTPException(status_code=404, detail='File not found')
|
||||
if not os.path.isfile(file_path):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Authentication models and database operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
@@ -13,33 +15,30 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
####################
|
||||
# DB MODEL
|
||||
####################
|
||||
|
||||
|
||||
class Auth(Base):
|
||||
"""Credential record linking a user identity to an email + hashed password."""
|
||||
|
||||
__tablename__ = 'auth'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
email = Column(String)
|
||||
password = Column(Text)
|
||||
active = Column(Boolean)
|
||||
id = Column(String, primary_key=True, unique=True) # same as User.id
|
||||
email = Column(String) # login email, kept in sync with User.email
|
||||
password = Column(Text) # bcrypt / argon2 hash
|
||||
active = Column(Boolean) # soft-disable flag
|
||||
|
||||
|
||||
class AuthModel(BaseModel):
|
||||
"""Pydantic mirror of the Auth table row."""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
password: str
|
||||
active: bool = True
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
"""JWT bearer token response."""
|
||||
|
||||
token: str
|
||||
token_type: str
|
||||
|
||||
@@ -100,112 +99,130 @@ class AuthsTable:
|
||||
oauth: dict | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Create an Auth + User pair in a single transaction."""
|
||||
async with get_async_db_context(db) as db:
|
||||
log.info('insert_new_auth')
|
||||
|
||||
id = str(uuid.uuid4())
|
||||
user_id = str(uuid.uuid4())
|
||||
|
||||
auth = AuthModel(**{'id': id, 'email': email, 'password': password, 'active': True})
|
||||
result = Auth(**auth.model_dump())
|
||||
db.add(result)
|
||||
record = Auth(
|
||||
id=user_id,
|
||||
email=email,
|
||||
password=password,
|
||||
active=True,
|
||||
)
|
||||
db.add(record)
|
||||
|
||||
user = await Users.insert_new_user(id, name, email, profile_image_url, role, oauth=oauth, db=db)
|
||||
user = await Users.insert_new_user(
|
||||
user_id, name, email, profile_image_url, role, oauth=oauth, db=db,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
await db.refresh(record)
|
||||
|
||||
if result and user:
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
return user if record and user else None
|
||||
|
||||
async def authenticate_user(
|
||||
self, email: str, verify_password: callable, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
"""Verify a user's email + password and return the user on success."""
|
||||
log.info(f'authenticate_user: {email}')
|
||||
|
||||
user = await Users.get_user_by_email(email, db=db)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Auth).filter_by(id=user.id, active=True))
|
||||
auth = result.scalars().first()
|
||||
if auth:
|
||||
if verify_password(auth.password):
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
try: # load the auth row for password verification
|
||||
async with get_async_db_context(db) as session:
|
||||
auth = await session.get(Auth, user.id)
|
||||
if not auth or not auth.active:
|
||||
return None
|
||||
if not verify_password(auth.password):
|
||||
return None
|
||||
return user
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def authenticate_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
log.info(f'authenticate_user_by_api_key')
|
||||
# if no api_key, return None
|
||||
if not api_key:
|
||||
return None
|
||||
return
|
||||
|
||||
async def authenticate_user_by_api_key(
|
||||
self, api_key: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Resolve an API key to its owning user, returning ``None`` on miss."""
|
||||
log.info('authenticate_user_by_api_key')
|
||||
if not api_key: # empty / None key — reject immediately
|
||||
return
|
||||
try:
|
||||
user = await Users.get_user_by_api_key(api_key, db=db)
|
||||
return user if user else None
|
||||
return await Users.get_user_by_api_key(api_key, db=db)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def authenticate_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
log.info(f'authenticate_user_by_email: {email}')
|
||||
async def authenticate_user_by_email(
|
||||
self,
|
||||
email: str,
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""One-query authentication: JOIN Auth ↔ User, filter by email + active flag."""
|
||||
log.info('authenticate_user_by_email: %s', email)
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Single JOIN query instead of two separate queries
|
||||
result = await db.execute(
|
||||
select(Auth, User).join(User, Auth.id == User.id).filter(Auth.email == email, Auth.active == True)
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = (
|
||||
select(Auth, User)
|
||||
.join(User, Auth.id == User.id)
|
||||
.filter(Auth.email == email, Auth.active == True)
|
||||
)
|
||||
row = result.first()
|
||||
if row:
|
||||
_, user = row
|
||||
return UserModel.model_validate(user)
|
||||
return None
|
||||
row = (await session.execute(stmt)).first()
|
||||
if not row:
|
||||
return
|
||||
_auth, matched_user = row
|
||||
return UserModel.model_validate(matched_user)
|
||||
except Exception:
|
||||
return None
|
||||
return
|
||||
|
||||
async def update_user_password_by_id(self, id: str, new_password: str, db: AsyncSession | None = None) -> bool:
|
||||
async def update_user_password_by_id(
|
||||
self,
|
||||
id: str,
|
||||
new_password: str,
|
||||
db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""Hash-swap: replace the stored password hash for a given user."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(update(Auth).filter_by(id=id).values(password=new_password))
|
||||
await db.commit()
|
||||
return True if result.rowcount == 1 else False
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = update(Auth).filter_by(id=id).values(password=new_password)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount == 1
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def update_email_by_id(self, id: str, email: str, db: AsyncSession | None = None) -> bool:
|
||||
async def update_email_by_id(
|
||||
self, id: str, email: str, db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""Update the auth email and propagate the change to the User table."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(update(Auth).filter_by(id=id).values(email=email))
|
||||
await db.commit()
|
||||
if result.rowcount == 1:
|
||||
await Users.update_user_by_id(id, {'email': email}, db=db)
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_auth_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Delete User
|
||||
result = await Users.delete_user_by_id(id, db=db)
|
||||
|
||||
if result:
|
||||
await db.execute(delete(Auth).filter_by(id=id))
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
else:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = update(Auth).filter_by(id=id).values(email=email)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
if result.rowcount != 1:
|
||||
return False
|
||||
await Users.update_user_by_id(id, {'email': email}, db=session)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_auth_by_id(
|
||||
self, id: str, db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""Delete a user and their auth record in a single transaction."""
|
||||
try: # delete user first, then auth (FK order)
|
||||
async with get_async_db_context(db) as session:
|
||||
if not await Users.delete_user_by_id(id, db=session):
|
||||
return False # user deletion failed — abort
|
||||
await session.execute(delete(Auth).filter_by(id=id))
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception: # db / integrity error
|
||||
return False # partial deletion is rolled back by context manager
|
||||
|
||||
|
||||
Auths = AuthsTable()
|
||||
|
||||
+180
-178
@@ -1,10 +1,11 @@
|
||||
"""Chat models, forms, and database operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.automations import AutomationRun
|
||||
@@ -35,12 +36,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.sql import exists
|
||||
from sqlalchemy.sql.expression import bindparam
|
||||
|
||||
####################
|
||||
# Chat DB Schema
|
||||
# Let no word spoken in this house be lost, and when the
|
||||
# record is read again, let it still serve the one who spoke.
|
||||
####################
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -49,14 +44,14 @@ class Chat(Base):
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
title = Column(Text)
|
||||
title = Column(Text) # user-visible conversation title
|
||||
chat = Column(JSON)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
share_id = Column(Text, unique=True, nullable=True)
|
||||
archived = Column(Boolean, default=False)
|
||||
share_id = Column(Text, unique=True, nullable=True) # public share link token
|
||||
archived = Column(Boolean, default=False) # hidden from main chat list
|
||||
pinned = Column(Boolean, default=False, nullable=True)
|
||||
|
||||
meta = Column(JSON, server_default='{}')
|
||||
@@ -302,7 +297,7 @@ class ChatTable:
|
||||
async def insert_new_chat(
|
||||
self, id: str, user_id: str, form_data: ChatForm, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = ChatModel(
|
||||
**{
|
||||
'id': id,
|
||||
@@ -318,9 +313,9 @@ class ChatTable:
|
||||
)
|
||||
|
||||
chat_item = Chat(**chat.model_dump())
|
||||
db.add(chat_item)
|
||||
await db.commit()
|
||||
await db.refresh(chat_item)
|
||||
session.add(chat_item)
|
||||
await session.commit()
|
||||
await session.refresh(chat_item)
|
||||
|
||||
# Dual-write initial messages to chat_message table
|
||||
try:
|
||||
@@ -362,7 +357,7 @@ class ChatTable:
|
||||
chat_import_forms: list[ChatImportForm],
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
chats = []
|
||||
|
||||
for form_data in chat_import_forms:
|
||||
@@ -370,7 +365,7 @@ class ChatTable:
|
||||
chats.append(Chat(**chat.model_dump()))
|
||||
|
||||
db.add_all(chats)
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
# Dual-write messages to chat_message table
|
||||
for form_data, chat_obj in zip(chat_import_forms, chats):
|
||||
@@ -390,10 +385,13 @@ class ChatTable:
|
||||
|
||||
return [ChatModel.model_validate(chat) for chat in chats]
|
||||
|
||||
async def update_chat_by_id(self, id: str, chat: dict, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat_item = await db.get(Chat, id)
|
||||
async def update_chat_by_id(
|
||||
self, id: str, chat: dict, db: AsyncSession | None = None,
|
||||
) -> ChatModel | None:
|
||||
"""Persist updated chat content, sanitizing null bytes."""
|
||||
try: # load the chat record for in-place mutation
|
||||
async with get_async_db_context(db) as session:
|
||||
chat_item = await session.get(Chat, id)
|
||||
if chat_item is None:
|
||||
return None
|
||||
|
||||
@@ -402,19 +400,19 @@ class ChatTable:
|
||||
|
||||
chat_item.updated_at = int(time.time())
|
||||
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
return ChatModel.model_validate(chat_item)
|
||||
except Exception:
|
||||
return None
|
||||
return
|
||||
|
||||
async def update_chat_last_read_at_by_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, id)
|
||||
if chat and chat.user_id == user_id:
|
||||
chat.last_read_at = int(time.time())
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
@@ -423,22 +421,22 @@ class ChatTable:
|
||||
async def update_chat_title_by_id(self, id: str, title: str) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context() as db:
|
||||
chat_item = await db.get(Chat, id)
|
||||
chat_item = await session.get(Chat, id)
|
||||
if chat_item is None:
|
||||
return None
|
||||
clean_title = self._clean_null_bytes(title)
|
||||
chat_item.title = clean_title
|
||||
chat_item.chat = {**(chat_item.chat or {}), 'title': clean_title}
|
||||
chat_item.updated_at = int(time.time())
|
||||
await db.commit()
|
||||
await db.refresh(chat_item)
|
||||
await session.commit()
|
||||
await session.refresh(chat_item)
|
||||
return ChatModel.model_validate(chat_item)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_chat_tags_by_id(self, id: str, tags: list[str], user) -> ChatModel | None:
|
||||
async with get_async_db_context() as db:
|
||||
chat = await db.get(Chat, id)
|
||||
chat = await session.get(Chat, id)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
@@ -448,22 +446,22 @@ class ChatTable:
|
||||
|
||||
# Single meta update
|
||||
chat.meta = {**chat.meta, 'tags': new_tag_ids}
|
||||
await db.commit()
|
||||
await db.refresh(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
|
||||
# Batch-create any missing tag rows
|
||||
await Tags.ensure_tags_exist(new_tags, user.id, db=db)
|
||||
await Tags.ensure_tags_exist(new_tags, user.id, db=session)
|
||||
|
||||
# Clean up orphaned old tags in one query
|
||||
removed = set(old_tags) - set(new_tag_ids)
|
||||
if removed:
|
||||
await self.delete_orphan_tags_for_user(list(removed), user.id, db=db)
|
||||
await self.delete_orphan_tags_for_user(list(removed), user.id, db=session)
|
||||
|
||||
return ChatModel.model_validate(chat)
|
||||
|
||||
async def get_chat_title_by_id(self, id: str) -> str | None:
|
||||
async with get_async_db_context() as db:
|
||||
result = await db.execute(select(Chat.title).filter_by(id=id))
|
||||
result = await session.execute(select(Chat.title).filter_by(id=id))
|
||||
row = result.first()
|
||||
if row is None:
|
||||
return None
|
||||
@@ -638,7 +636,7 @@ class ChatTable:
|
||||
|
||||
async def add_message_files_by_id_and_message_id(self, id: str, message_id: str, files: list[dict]) -> list[dict]:
|
||||
async with get_async_db_context() as db:
|
||||
chat = await self.get_chat_by_id(id, db=db)
|
||||
chat = await self.get_chat_by_id(id, db=session)
|
||||
if chat is None:
|
||||
return None
|
||||
|
||||
@@ -653,61 +651,62 @@ class ChatTable:
|
||||
history['messages'][message_id]['files'] = message_files
|
||||
|
||||
chat['history'] = history
|
||||
await self.update_chat_by_id(id, chat, db=db)
|
||||
await self.update_chat_by_id(id, chat, db=session)
|
||||
return message_files
|
||||
|
||||
async def insert_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
"""Create a shared snapshot for a chat. Returns the original chat with share_id set."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, chat_id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, chat_id)
|
||||
if not chat:
|
||||
return None
|
||||
|
||||
# If already shared, just update the existing snapshot
|
||||
if chat.share_id:
|
||||
return await self.update_shared_chat_by_chat_id(chat_id, db=db)
|
||||
return await self.update_shared_chat_by_chat_id(chat_id, db=session)
|
||||
|
||||
shared = await SharedChats.create(chat_id, chat.user_id, db=db)
|
||||
shared = await SharedChats.create(chat_id, chat.user_id, db=session)
|
||||
if not shared:
|
||||
return None
|
||||
|
||||
# Set share_id on the original chat
|
||||
chat.share_id = shared.id
|
||||
await db.commit()
|
||||
await db.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return ChatModel.model_validate(chat) # return the updated original
|
||||
|
||||
async def update_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
"""Re-snapshot the shared chat with current chat data."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
async def update_shared_chat_by_chat_id(
|
||||
self, chat_id: str, db: AsyncSession | None = None,
|
||||
) -> ChatModel | None:
|
||||
"""Re-snapshot the shared copy with the latest chat content."""
|
||||
from open_webui.models.shared_chats import SharedChats # deferred — circular
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, chat_id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, chat_id)
|
||||
if not chat or not chat.share_id:
|
||||
return await self.insert_shared_chat_by_chat_id(chat_id, db=db)
|
||||
|
||||
await SharedChats.update(chat.share_id, db=db)
|
||||
return await self.insert_shared_chat_by_chat_id(chat_id, db=session)
|
||||
await SharedChats.update(chat.share_id, db=session)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
return
|
||||
|
||||
async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool:
|
||||
"""Delete shared snapshot for a chat."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
try:
|
||||
return await SharedChats.delete_by_chat_id(chat_id, db=db)
|
||||
return await SharedChats.delete_by_chat_id(chat_id, db=session)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def unarchive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=False))
|
||||
await db.commit()
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(update(Chat).filter_by(user_id=user_id).values(archived=False))
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -716,45 +715,45 @@ class ChatTable:
|
||||
self, id: str, share_id: str | None, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, id)
|
||||
chat.share_id = share_id
|
||||
await db.commit()
|
||||
await db.refresh(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def toggle_chat_pinned_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, id)
|
||||
chat.pinned = not chat.pinned
|
||||
chat.updated_at = int(time.time())
|
||||
await db.commit()
|
||||
await db.refresh(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def toggle_chat_archive_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, id)
|
||||
chat.archived = not chat.archived
|
||||
chat.folder_id = None
|
||||
chat.updated_at = int(time.time())
|
||||
await db.commit()
|
||||
await db.refresh(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def archive_all_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(Chat).filter_by(user_id=user_id).values(archived=True))
|
||||
await db.commit()
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(update(Chat).filter_by(user_id=user_id).values(archived=True))
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
@@ -767,7 +766,7 @@ class ChatTable:
|
||||
limit: int = 50,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by(
|
||||
user_id=user_id, archived=True
|
||||
)
|
||||
@@ -798,7 +797,7 @@ class ChatTable:
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
all_chats = result.all()
|
||||
return [
|
||||
ChatTitleIdResponse.model_validate(
|
||||
@@ -823,7 +822,7 @@ class ChatTable:
|
||||
"""Delegate to SharedChats for listing shared chats by user."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
return await SharedChats.get_by_user_id(user_id, filter=filter, skip=skip, limit=limit, db=db)
|
||||
return await SharedChats.get_by_user_id(user_id, filter=filter, skip=skip, limit=limit, db=session)
|
||||
|
||||
async def get_chat_list_by_user_id(
|
||||
self,
|
||||
@@ -834,7 +833,7 @@ class ChatTable:
|
||||
limit: int = 50,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
user_id=user_id
|
||||
)
|
||||
@@ -864,7 +863,7 @@ class ChatTable:
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
all_chats = result.all()
|
||||
return [
|
||||
ChatTitleIdResponse.model_validate(
|
||||
@@ -889,7 +888,7 @@ class ChatTable:
|
||||
limit: int | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
user_id=user_id
|
||||
)
|
||||
@@ -910,7 +909,7 @@ class ChatTable:
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
all_chats = result.all()
|
||||
|
||||
return [
|
||||
@@ -933,23 +932,26 @@ class ChatTable:
|
||||
limit: int = 50,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Chat).filter(Chat.id.in_(chat_ids)).filter_by(archived=False).order_by(Chat.updated_at.desc())
|
||||
)
|
||||
all_chats = result.scalars().all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
async def get_chat_by_id(self, id: str, db: AsyncSession | None = None) -> ChatModel | None:
|
||||
async def get_chat_by_id(
|
||||
self, id: str, db: AsyncSession | None = None,
|
||||
) -> ChatModel | None:
|
||||
"""Fetch a chat by PK, auto-sanitizing null bytes on read."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat_item = await db.get(Chat, id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat_item = await session.get(Chat, id)
|
||||
if chat_item is None:
|
||||
return None
|
||||
|
||||
if self._sanitize_chat_row(chat_item):
|
||||
await db.commit()
|
||||
await db.refresh(chat_item)
|
||||
await session.commit()
|
||||
await session.refresh(chat_item)
|
||||
|
||||
return ChatModel.model_validate(chat_item)
|
||||
except Exception:
|
||||
@@ -960,7 +962,7 @@ class ChatTable:
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
try:
|
||||
shared = await SharedChats.get_by_id(id, db=db)
|
||||
shared = await SharedChats.get_by_id(id, db=session)
|
||||
if shared:
|
||||
# Return a ChatModel-compatible view of the snapshot
|
||||
return ChatModel(
|
||||
@@ -980,8 +982,8 @@ class ChatTable:
|
||||
self, id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Chat).filter_by(id=id, user_id=user_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Chat).filter_by(id=id, user_id=user_id))
|
||||
chat = result.scalars().first()
|
||||
return ChatModel.model_validate(chat) if chat else None
|
||||
except Exception:
|
||||
@@ -993,8 +995,8 @@ class ChatTable:
|
||||
the full Chat row (which includes the potentially large JSON blob).
|
||||
"""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(exists().where(and_(Chat.id == id, Chat.user_id == user_id))))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(exists().where(and_(Chat.id == id, Chat.user_id == user_id))))
|
||||
return result.scalar()
|
||||
except Exception:
|
||||
return False
|
||||
@@ -1005,16 +1007,16 @@ class ChatTable:
|
||||
JSON blob. Returns None if chat doesn't exist or doesn't belong to user.
|
||||
"""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Chat.folder_id).filter_by(id=id, user_id=user_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Chat.folder_id).filter_by(id=id, user_id=user_id))
|
||||
row = result.first()
|
||||
return row[0] if row else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_chats(self, skip: int = 0, limit: int = 50, db: AsyncSession | None = None) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Chat).order_by(Chat.updated_at.desc()))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Chat).order_by(Chat.updated_at.desc()))
|
||||
all_chats = result.scalars().all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
@@ -1026,7 +1028,7 @@ class ChatTable:
|
||||
limit: int | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> ChatListResponse:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(Chat).filter_by(user_id=user_id)
|
||||
|
||||
if filter:
|
||||
@@ -1048,7 +1050,7 @@ class ChatTable:
|
||||
else:
|
||||
stmt = stmt.order_by(Chat.updated_at.desc(), Chat.id)
|
||||
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
count_result = await session.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip is not None:
|
||||
@@ -1056,7 +1058,7 @@ class ChatTable:
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
all_chats = result.scalars().all()
|
||||
|
||||
return ChatListResponse(
|
||||
@@ -1069,8 +1071,8 @@ class ChatTable:
|
||||
async def get_pinned_chats_by_user_id(
|
||||
self, user_id: str, db: AsyncSession | None = None
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at)
|
||||
.filter_by(user_id=user_id, pinned=True, archived=False)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
@@ -1090,8 +1092,8 @@ class ChatTable:
|
||||
]
|
||||
|
||||
async def get_archived_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in result.scalars().all()]
|
||||
@@ -1161,7 +1163,7 @@ class ChatTable:
|
||||
|
||||
search_text = ' '.join(search_text_words)
|
||||
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(Chat).filter(Chat.user_id == user_id)
|
||||
|
||||
if is_archived is not None:
|
||||
@@ -1184,7 +1186,7 @@ class ChatTable:
|
||||
stmt = stmt.order_by(Chat.updated_at.desc(), Chat.id)
|
||||
|
||||
# Check if the database dialect is either 'sqlite' or 'postgresql'
|
||||
bind = await db.connection()
|
||||
bind = await session.connection()
|
||||
dialect_name = bind.dialect.name
|
||||
if dialect_name == 'sqlite':
|
||||
# SQLite case: using JSON1 extension for JSON searching
|
||||
@@ -1282,7 +1284,7 @@ class ChatTable:
|
||||
|
||||
# Perform pagination at the SQL level
|
||||
stmt = stmt.offset(skip).limit(limit)
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
all_chats = result.scalars().all()
|
||||
|
||||
log.info(f'The number of chats: {len(all_chats)}')
|
||||
@@ -1298,7 +1300,7 @@ class ChatTable:
|
||||
limit: int = 60,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = (
|
||||
select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at)
|
||||
.filter_by(folder_id=folder_id, user_id=user_id)
|
||||
@@ -1312,7 +1314,7 @@ class ChatTable:
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
all_chats = result.all()
|
||||
return [
|
||||
ChatTitleIdResponse.model_validate(
|
||||
@@ -1330,7 +1332,7 @@ class ChatTable:
|
||||
async def get_chats_by_folder_ids_and_user_id(
|
||||
self, folder_ids: list[str], user_id: str, db: AsyncSession | None = None
|
||||
) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = (
|
||||
select(Chat)
|
||||
.filter(Chat.folder_id.in_(folder_ids), Chat.user_id == user_id)
|
||||
@@ -1339,7 +1341,7 @@ class ChatTable:
|
||||
.order_by(Chat.updated_at.desc())
|
||||
)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
all_chats = result.scalars().all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
@@ -1347,13 +1349,13 @@ class ChatTable:
|
||||
self, id: str, user_id: str, folder_id: str, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, id)
|
||||
chat.folder_id = folder_id
|
||||
chat.updated_at = int(time.time())
|
||||
chat.pinned = False
|
||||
await db.commit()
|
||||
await db.refresh(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -1361,12 +1363,12 @@ class ChatTable:
|
||||
async def get_chat_tags_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> list[TagModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(Chat.meta).where(Chat.id == id)
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
meta = result.scalar_one_or_none()
|
||||
tag_ids = (meta or {}).get('tags', [])
|
||||
return await Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=db)
|
||||
return await Tags.get_tags_by_ids_and_user_id(tag_ids, user_id, db=session)
|
||||
|
||||
async def get_chat_list_by_user_id_and_tag_name(
|
||||
self,
|
||||
@@ -1376,13 +1378,13 @@ class ChatTable:
|
||||
limit: int = 50,
|
||||
db: AsyncSession | None = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
user_id=user_id
|
||||
)
|
||||
tag_id = tag_name.replace(' ', '_').lower()
|
||||
|
||||
bind = await db.connection()
|
||||
bind = await session.connection()
|
||||
dialect_name = bind.dialect.name
|
||||
log.info(f'DB dialect name: {dialect_name}')
|
||||
if dialect_name == 'sqlite':
|
||||
@@ -1403,7 +1405,7 @@ class ChatTable:
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
all_chats = result.all()
|
||||
return [
|
||||
ChatTitleIdResponse.model_validate(
|
||||
@@ -1422,17 +1424,17 @@ class ChatTable:
|
||||
self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None
|
||||
) -> ChatModel | None:
|
||||
tag_id = tag_name.replace(' ', '_').lower()
|
||||
await Tags.ensure_tags_exist([tag_name], user_id, db=db)
|
||||
await Tags.ensure_tags_exist([tag_name], user_id, db=session)
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, id)
|
||||
if tag_id not in chat.meta.get('tags', []):
|
||||
chat.meta = {
|
||||
**chat.meta,
|
||||
'tags': list(set(chat.meta.get('tags', []) + [tag_id])),
|
||||
}
|
||||
await db.commit()
|
||||
await db.refresh(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -1440,11 +1442,11 @@ class ChatTable:
|
||||
async def count_chats_by_tag_name_and_user_id(
|
||||
self, tag_name: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(func.count(Chat.id)).filter_by(user_id=user_id, archived=False)
|
||||
tag_id = tag_name.replace(' ', '_').lower()
|
||||
|
||||
bind = await db.connection()
|
||||
bind = await session.connection()
|
||||
dialect_name = bind.dialect.name
|
||||
if dialect_name == 'sqlite':
|
||||
stmt = stmt.filter(
|
||||
@@ -1457,7 +1459,7 @@ class ChatTable:
|
||||
else:
|
||||
raise NotImplementedError(f'Unsupported dialect: {dialect_name}')
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar()
|
||||
|
||||
async def delete_orphan_tags_for_user(
|
||||
@@ -1477,19 +1479,19 @@ class ChatTable:
|
||||
"""
|
||||
if not tag_ids:
|
||||
return
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
orphans = []
|
||||
for tag_id in tag_ids:
|
||||
count = await self.count_chats_by_tag_name_and_user_id(tag_id, user_id, db=db)
|
||||
count = await self.count_chats_by_tag_name_and_user_id(tag_id, user_id, db=session)
|
||||
if count <= threshold:
|
||||
orphans.append(tag_id)
|
||||
await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db)
|
||||
await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=session)
|
||||
|
||||
async def count_chats_by_folder_id_and_user_id(
|
||||
self, folder_id: str, user_id: str, db: AsyncSession | None = None
|
||||
) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id))
|
||||
count = result.scalar()
|
||||
|
||||
log.info(f"Count of chats for folder '{folder_id}': {count}")
|
||||
@@ -1499,8 +1501,8 @@ class ChatTable:
|
||||
self, id: str, user_id: str, tag_name: str, db: AsyncSession | None = None
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, id)
|
||||
tags = chat.meta.get('tags', [])
|
||||
tag_id = tag_name.replace(' ', '_').lower()
|
||||
|
||||
@@ -1509,51 +1511,51 @@ class ChatTable:
|
||||
**chat.meta,
|
||||
'tags': list(set(tags)),
|
||||
}
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_chat_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
|
||||
await db.execute(delete(ChatMessage).filter_by(chat_id=id))
|
||||
await db.execute(delete(Chat).filter_by(id=id))
|
||||
await db.commit()
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
|
||||
await session.execute(delete(ChatMessage).filter_by(chat_id=id))
|
||||
await session.execute(delete(Chat).filter_by(id=id))
|
||||
await session.commit()
|
||||
|
||||
return True and await self.delete_shared_chat_by_chat_id(id, db=db)
|
||||
return True and await self.delete_shared_chat_by_chat_id(id, db=session)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
|
||||
await db.execute(delete(ChatMessage).filter_by(chat_id=id))
|
||||
await db.execute(delete(Chat).filter_by(id=id, user_id=user_id))
|
||||
await db.commit()
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
|
||||
await session.execute(delete(ChatMessage).filter_by(chat_id=id))
|
||||
await session.execute(delete(Chat).filter_by(id=id, user_id=user_id))
|
||||
await session.commit()
|
||||
|
||||
return True and await self.delete_shared_chat_by_chat_id(id, db=db)
|
||||
return True and await self.delete_shared_chat_by_chat_id(id, db=session)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_chats_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await self.delete_shared_chats_by_user_id(user_id, db=db)
|
||||
async with get_async_db_context(db) as session:
|
||||
await self.delete_shared_chats_by_user_id(user_id, db=session)
|
||||
|
||||
chat_id_subquery = select(Chat.id).filter_by(user_id=user_id).scalar_subquery()
|
||||
await db.execute(
|
||||
await session.execute(
|
||||
update(AutomationRun)
|
||||
.filter(AutomationRun.chat_id.in_(select(Chat.id).filter_by(user_id=user_id)))
|
||||
.values(chat_id=None)
|
||||
)
|
||||
await db.execute(
|
||||
await session.execute(
|
||||
delete(ChatMessage).filter(ChatMessage.chat_id.in_(select(Chat.id).filter_by(user_id=user_id)))
|
||||
)
|
||||
await db.execute(delete(Chat).filter_by(user_id=user_id))
|
||||
await db.commit()
|
||||
await session.execute(delete(Chat).filter_by(user_id=user_id))
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
@@ -1563,14 +1565,14 @@ class ChatTable:
|
||||
self, user_id: str, folder_id: str, db: AsyncSession | None = None
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
chat_ids_stmt = select(Chat.id).filter_by(user_id=user_id, folder_id=folder_id)
|
||||
await db.execute(
|
||||
await session.execute(
|
||||
update(AutomationRun).filter(AutomationRun.chat_id.in_(chat_ids_stmt)).values(chat_id=None)
|
||||
)
|
||||
await db.execute(delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt)))
|
||||
await db.execute(delete(Chat).filter_by(user_id=user_id, folder_id=folder_id))
|
||||
await db.commit()
|
||||
await session.execute(delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt)))
|
||||
await session.execute(delete(Chat).filter_by(user_id=user_id, folder_id=folder_id))
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
@@ -1584,11 +1586,11 @@ class ChatTable:
|
||||
db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(
|
||||
update(Chat).filter_by(user_id=user_id, folder_id=folder_id).values(folder_id=new_folder_id)
|
||||
)
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
@@ -1600,13 +1602,13 @@ class ChatTable:
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
# Delete shared_chat rows for this user's chats
|
||||
await db.execute(delete(SharedChatTable).filter_by(user_id=user_id))
|
||||
await session.execute(delete(SharedChatTable).filter_by(user_id=user_id))
|
||||
|
||||
# Clear share_id on all of this user's chats
|
||||
await db.execute(update(Chat).filter_by(user_id=user_id).values(share_id=None))
|
||||
await db.commit()
|
||||
await session.execute(update(Chat).filter_by(user_id=user_id).values(share_id=None))
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
@@ -1624,7 +1626,7 @@ class ChatTable:
|
||||
return None
|
||||
|
||||
chat_message_file_ids = {
|
||||
item.id for item in await self.get_chat_files_by_chat_id_and_message_id(chat_id, message_id, db=db)
|
||||
item.id for item in await self.get_chat_files_by_chat_id_and_message_id(chat_id, message_id, db=session)
|
||||
}
|
||||
# Remove duplicates and existing file_ids
|
||||
file_ids = list({file_id for file_id in file_ids if file_id and file_id not in chat_message_file_ids})
|
||||
@@ -1632,7 +1634,7 @@ class ChatTable:
|
||||
return None
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
now = int(time.time())
|
||||
|
||||
chat_files = [
|
||||
@@ -1651,7 +1653,7 @@ class ChatTable:
|
||||
results = [ChatFile(**chat_file.model_dump()) for chat_file in chat_files]
|
||||
|
||||
db.add_all(results)
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
return chat_files
|
||||
except Exception:
|
||||
@@ -1660,8 +1662,8 @@ class ChatTable:
|
||||
async def get_chat_files_by_chat_id_and_message_id(
|
||||
self, chat_id: str, message_id: str, db: AsyncSession | None = None
|
||||
) -> list[ChatFileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(ChatFile).filter_by(chat_id=chat_id, message_id=message_id).order_by(ChatFile.created_at.asc())
|
||||
)
|
||||
all_chat_files = result.scalars().all()
|
||||
@@ -1669,17 +1671,17 @@ class ChatTable:
|
||||
|
||||
async def delete_chat_file(self, chat_id: str, file_id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id))
|
||||
await db.commit()
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(delete(ChatFile).filter_by(chat_id=chat_id, file_id=file_id))
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_shared_chat_ids_by_file_id(self, file_id: str, db: AsyncSession | None = None) -> list[str]:
|
||||
"""Return IDs of chats that contain this file and have an active share link."""
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Chat.id)
|
||||
.join(ChatFile, Chat.id == ChatFile.chat_id)
|
||||
.filter(ChatFile.file_id == file_id, Chat.share_id.isnot(None))
|
||||
@@ -1690,12 +1692,12 @@ class ChatTable:
|
||||
"""Update the tasks list on a chat."""
|
||||
try:
|
||||
async with get_async_db_context() as db:
|
||||
chat = await db.get(Chat, id)
|
||||
chat = await session.get(Chat, id)
|
||||
if chat is None:
|
||||
return None
|
||||
chat.tasks = tasks
|
||||
await db.commit()
|
||||
await db.refresh(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return None
|
||||
@@ -1703,7 +1705,7 @@ class ChatTable:
|
||||
async def get_chat_tasks_by_id(self, id: str) -> list[dict]:
|
||||
"""Read the tasks list from a chat (lightweight column query)."""
|
||||
async with get_async_db_context() as db:
|
||||
result = await db.execute(select(Chat.tasks).filter_by(id=id))
|
||||
result = await session.execute(select(Chat.tasks).filter_by(id=id))
|
||||
row = result.first()
|
||||
if row is None or row[0] is None:
|
||||
return []
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""File upload models, forms, and database operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.utils.misc import sanitize_metadata
|
||||
@@ -12,12 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
####################
|
||||
# Files DB Schema
|
||||
# What is written here bears witness. Let the testimony
|
||||
# remain as it was given, and let none tamper with it.
|
||||
####################
|
||||
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = 'file'
|
||||
@@ -25,7 +20,7 @@ class File(Base):
|
||||
user_id = Column(String)
|
||||
hash = Column(Text, nullable=True)
|
||||
|
||||
filename = Column(Text)
|
||||
filename = Column(Text) # original upload filename
|
||||
path = Column(Text, nullable=True)
|
||||
|
||||
data = Column(JSON, nullable=True)
|
||||
@@ -52,11 +47,6 @@ class FileModel(BaseModel):
|
||||
updated_at: int | None # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FileMeta(BaseModel):
|
||||
name: str | None = None
|
||||
content_type: str | None = None
|
||||
@@ -157,16 +147,18 @@ class FilesTable:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f'Error inserting a new file: {e}')
|
||||
return None
|
||||
return None # insertion failed
|
||||
|
||||
async def get_file_by_id(self, id: str, db: AsyncSession | None = None) -> FileModel | None:
|
||||
async def get_file_by_id(
|
||||
self, id: str, db: AsyncSession | None = None,
|
||||
) -> FileModel | None:
|
||||
"""Look up a file by its primary key."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
file = await db.get(File, id)
|
||||
return FileModel.model_validate(file) if file else None
|
||||
except Exception:
|
||||
file = await db.get(File, id)
|
||||
if not file:
|
||||
return None
|
||||
return FileModel.model_validate(file)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Function (filter/action/pipe) models, forms, and database operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.users import UserModel, UserResponse, Users
|
||||
@@ -12,12 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
####################
|
||||
# Functions DB Schema
|
||||
# Each function here is a promise made. Let no promise
|
||||
# go unkept, and let none be called who cannot answer.
|
||||
####################
|
||||
|
||||
|
||||
class Function(Base):
|
||||
__tablename__ = 'function'
|
||||
@@ -30,11 +25,11 @@ class Function(Base):
|
||||
meta = Column(JSONField)
|
||||
valves = Column(JSONField)
|
||||
is_active = Column(Boolean)
|
||||
is_global = Column(Boolean)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
is_global = Column(Boolean) # if True, applied to every chat automatically
|
||||
updated_at = Column(BigInteger) # epoch seconds
|
||||
created_at = Column(BigInteger) # epoch seconds
|
||||
|
||||
__table_args__ = (Index('is_global_idx', 'is_global'),)
|
||||
__table_args__ = (Index('is_global_idx', 'is_global'),) # speed up global-function lookups
|
||||
|
||||
|
||||
class FunctionMeta(BaseModel):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Long-term memory storage for per-user context recall."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
@@ -9,38 +11,30 @@ from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import BigInteger, Column, String, Text, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# Memory DB Schema
|
||||
# What was learned at cost should not need to be paid
|
||||
# for again. Let the memory hold.
|
||||
####################
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
"""Persistent user memory backed by a vector collection."""
|
||||
|
||||
__tablename__ = 'memory'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String, index=True)
|
||||
content = Column(Text)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
content = Column(Text) # free-form text learned from conversation
|
||||
updated_at = Column(BigInteger) # epoch seconds
|
||||
created_at = Column(BigInteger) # epoch seconds
|
||||
|
||||
|
||||
class MemoryModel(BaseModel):
|
||||
"""Pydantic mirror of the Memory table row."""
|
||||
|
||||
id: str
|
||||
user_id: str
|
||||
content: str
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class MemoriesTable:
|
||||
async def insert_new_memory(
|
||||
self,
|
||||
@@ -48,26 +42,20 @@ class MemoriesTable:
|
||||
content: str,
|
||||
db: AsyncSession | None = None,
|
||||
) -> MemoryModel | None:
|
||||
"""Persist a new memory entry and return the created model."""
|
||||
async with get_async_db_context(db) as db:
|
||||
id = str(uuid.uuid4())
|
||||
|
||||
memory = MemoryModel(
|
||||
**{
|
||||
'id': id,
|
||||
'user_id': user_id,
|
||||
'content': content,
|
||||
'created_at': int(time.time()),
|
||||
'updated_at': int(time.time()),
|
||||
}
|
||||
now = int(time.time())
|
||||
record = Memory(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content=content,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
result = Memory(**memory.model_dump())
|
||||
db.add(result)
|
||||
db.add(record)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
if result:
|
||||
return MemoryModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
await db.refresh(record)
|
||||
return MemoryModel.model_validate(record) if record else None
|
||||
|
||||
async def update_memory_by_id_and_user_id(
|
||||
self,
|
||||
@@ -143,12 +131,10 @@ class MemoriesTable:
|
||||
try:
|
||||
memory = await db.get(Memory, id)
|
||||
if not memory or memory.user_id != user_id:
|
||||
return None
|
||||
return False
|
||||
|
||||
# Delete the memory
|
||||
await db.delete(memory)
|
||||
await db.commit()
|
||||
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -60,38 +60,19 @@ class ModelMeta(BaseModel):
|
||||
|
||||
|
||||
class Model(Base):
|
||||
"""Workspace model entry — wraps an upstream LLM with custom params and metadata."""
|
||||
|
||||
__tablename__ = 'model'
|
||||
|
||||
id = Column(Text, primary_key=True, unique=True)
|
||||
"""
|
||||
The model's id as used in the API. If set to an existing model, it will override the model.
|
||||
"""
|
||||
user_id = Column(Text)
|
||||
|
||||
base_model_id = Column(Text, nullable=True)
|
||||
"""
|
||||
An optional pointer to the actual model that should be used when proxying requests.
|
||||
"""
|
||||
|
||||
name = Column(Text)
|
||||
"""
|
||||
The human-readable display name of the model.
|
||||
"""
|
||||
|
||||
params = Column(JSONField)
|
||||
"""
|
||||
Holds a JSON encoded blob of parameters, see `ModelParams`.
|
||||
"""
|
||||
|
||||
meta = Column(JSONField)
|
||||
"""
|
||||
Holds a JSON encoded blob of metadata, see `ModelMeta`.
|
||||
"""
|
||||
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
id = Column(Text, primary_key=True, unique=True) # API model identifier; overrides built-in when matching
|
||||
user_id = Column(Text) # owner
|
||||
base_model_id = Column(Text, nullable=True) # actual upstream model for proxied requests
|
||||
name = Column(Text) # human-readable display name
|
||||
params = Column(JSONField) # see ModelParams
|
||||
meta = Column(JSONField) # see ModelMeta
|
||||
is_active = Column(Boolean, default=True) # soft-disable toggle
|
||||
updated_at = Column(BigInteger) # epoch seconds
|
||||
created_at = Column(BigInteger) # epoch seconds
|
||||
|
||||
|
||||
class ModelModel(BaseModel):
|
||||
@@ -109,12 +90,9 @@ class ModelModel(BaseModel):
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
)
|
||||
|
||||
|
||||
class ModelUserResponse(ModelModel):
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""Prompt template models, forms, and database operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
@@ -14,23 +16,19 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, delete, func, or_, select, text, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
####################
|
||||
# Prompts DB Schema
|
||||
# Every word here was weighed before it was set down.
|
||||
# Let the weight not be wasted when it is spoken aloud.
|
||||
####################
|
||||
|
||||
|
||||
class Prompt(Base):
|
||||
"""System prompt template with versioning support."""
|
||||
|
||||
__tablename__ = 'prompt'
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
command = Column(String, unique=True, index=True)
|
||||
user_id = Column(String)
|
||||
name = Column(Text)
|
||||
content = Column(Text)
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
content = Column(Text) # the prompt template body
|
||||
data = Column(JSON, nullable=True) # structured prompt parameters
|
||||
meta = Column(JSON, nullable=True) # freeform metadata (description, etc.)
|
||||
tags = Column(JSON, nullable=True)
|
||||
is_active = Column(Boolean, default=True)
|
||||
version_id = Column(Text, nullable=True) # Points to active history entry
|
||||
@@ -94,7 +92,7 @@ class PromptForm(BaseModel):
|
||||
|
||||
class PromptsTable:
|
||||
async def _get_access_grants(self, prompt_id: str, db: AsyncSession | None = None) -> list[AccessGrantModel]:
|
||||
return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=db)
|
||||
return await AccessGrants.get_grants_by_resource('prompt', prompt_id, db=session)
|
||||
|
||||
async def _to_prompt_model(
|
||||
self,
|
||||
@@ -104,98 +102,94 @@ class PromptsTable:
|
||||
) -> PromptModel:
|
||||
prompt_data = PromptModel.model_validate(prompt).model_dump(exclude={'access_grants'})
|
||||
prompt_data['access_grants'] = (
|
||||
access_grants if access_grants is not None else await self._get_access_grants(prompt_data['id'], db=db)
|
||||
access_grants if access_grants is not None else await self._get_access_grants(prompt_data['id'], db=session)
|
||||
)
|
||||
return PromptModel.model_validate(prompt_data)
|
||||
|
||||
async def insert_new_prompt(
|
||||
self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None
|
||||
) -> PromptModel | None:
|
||||
async def insert_new_prompt(self, user_id: str, form_data: PromptForm, db: AsyncSession | None = None) -> PromptModel | None:
|
||||
now = int(time.time())
|
||||
prompt_id = str(uuid.uuid4())
|
||||
|
||||
prompt = PromptModel(
|
||||
id=prompt_id,
|
||||
user_id=user_id,
|
||||
command=form_data.command,
|
||||
name=form_data.name,
|
||||
content=form_data.content,
|
||||
data=form_data.data or {},
|
||||
meta=form_data.meta or {},
|
||||
tags=form_data.tags or [],
|
||||
access_grants=[],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
async with get_async_db_context(db) as session:
|
||||
try:
|
||||
record = Prompt(
|
||||
id=prompt_id, user_id=user_id,
|
||||
command=form_data.command, name=form_data.name,
|
||||
content=form_data.content,
|
||||
data=form_data.data or {}, meta=form_data.meta or {},
|
||||
tags=form_data.tags or [], is_active=True,
|
||||
created_at=now, updated_at=now,
|
||||
)
|
||||
session.add(record)
|
||||
await session.commit()
|
||||
await session.refresh(record) # populate generated defaults
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = Prompt(**prompt.model_dump(exclude={'access_grants'}))
|
||||
db.add(result)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
await AccessGrants.set_access_grants('prompt', prompt_id, form_data.access_grants, db=db)
|
||||
await AccessGrants.set_access_grants(
|
||||
'prompt', prompt_id, form_data.access_grants, db=session,
|
||||
) # persist sharing rules
|
||||
|
||||
if result:
|
||||
current_access_grants = await self._get_access_grants(prompt_id, db=db)
|
||||
snapshot = {
|
||||
'name': form_data.name,
|
||||
'content': form_data.content,
|
||||
'command': form_data.command,
|
||||
'data': form_data.data or {},
|
||||
'meta': form_data.meta or {},
|
||||
'tags': form_data.tags or [],
|
||||
'access_grants': [grant.model_dump() for grant in current_access_grants],
|
||||
}
|
||||
|
||||
history_entry = await PromptHistories.create_history_entry(
|
||||
prompt_id=prompt_id,
|
||||
snapshot=snapshot,
|
||||
user_id=user_id,
|
||||
parent_id=None, # Initial commit has no parent
|
||||
commit_message=form_data.commit_message or 'Initial version',
|
||||
db=db,
|
||||
)
|
||||
|
||||
# Set the initial version as the production version
|
||||
if history_entry:
|
||||
result.version_id = history_entry.id
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
|
||||
return await self._to_prompt_model(result, db=db)
|
||||
else:
|
||||
if not record: # shouldn't happen, but guard anyway
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# Build the initial version snapshot.
|
||||
grants = await self._get_access_grants(prompt_id, db=session)
|
||||
snapshot = {
|
||||
'name': form_data.name,
|
||||
'content': form_data.content,
|
||||
'command': form_data.command,
|
||||
'data': form_data.data or {},
|
||||
'meta': form_data.meta or {},
|
||||
'tags': form_data.tags or [],
|
||||
'access_grants': [g.model_dump() for g in grants],
|
||||
}
|
||||
|
||||
history_entry = await PromptHistories.create_history_entry(
|
||||
prompt_id=prompt_id, snapshot=snapshot,
|
||||
user_id=user_id, parent_id=None,
|
||||
commit_message=form_data.commit_message or 'Initial version',
|
||||
db=session,
|
||||
) # creates the first version entry
|
||||
|
||||
# Pin the initial history entry as the production version.
|
||||
if history_entry:
|
||||
record.version_id = history_entry.id
|
||||
await session.commit()
|
||||
await session.refresh(record) # re-read version_id
|
||||
|
||||
return await self._to_prompt_model(record, db=session)
|
||||
except Exception as e:
|
||||
log.exception('Error creating prompt: %s', e)
|
||||
return None
|
||||
|
||||
async def get_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None:
|
||||
"""Get prompt by UUID."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
prompt = result.scalars().first()
|
||||
if prompt:
|
||||
return await self._to_prompt_model(prompt, db=db)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Prompt).filter_by(id=prompt_id),
|
||||
)
|
||||
prompt = result.scalars().first() # None when not found
|
||||
if not prompt:
|
||||
return None
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
except Exception: # connection / integrity error
|
||||
return
|
||||
|
||||
async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(command=command))
|
||||
prompt = result.scalars().first()
|
||||
if prompt:
|
||||
return await self._to_prompt_model(prompt, db=db)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Prompt).filter_by(command=command),
|
||||
)
|
||||
prompt = result.scalars().first() # None when no match
|
||||
if not prompt:
|
||||
return None
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
except Exception: # connection / integrity error
|
||||
return
|
||||
|
||||
async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc())
|
||||
)
|
||||
all_prompts = result.scalars().all()
|
||||
@@ -203,9 +197,9 @@ class PromptsTable:
|
||||
user_ids = list(set(prompt.user_id for prompt in all_prompts))
|
||||
prompt_ids = [prompt.id for prompt in all_prompts]
|
||||
|
||||
users = await Users.get_users_by_user_ids(user_ids, db=db) if user_ids else []
|
||||
users = await Users.get_users_by_user_ids(user_ids, db=session) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
|
||||
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session)
|
||||
|
||||
prompts = []
|
||||
for prompt in all_prompts:
|
||||
@@ -217,7 +211,7 @@ class PromptsTable:
|
||||
await self._to_prompt_model(
|
||||
prompt,
|
||||
access_grants=grants_map.get(prompt.id, []),
|
||||
db=db,
|
||||
db=session,
|
||||
)
|
||||
).model_dump(),
|
||||
'user': user.model_dump() if user else None,
|
||||
@@ -230,8 +224,8 @@ class PromptsTable:
|
||||
async def get_prompts_by_user_id(
|
||||
self, user_id: str, permission: str = 'write', db: AsyncSession | None = None
|
||||
) -> list[PromptUserResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
async with get_async_db_context(db) as session:
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=session)
|
||||
user_group_ids = [group.id for group in user_groups]
|
||||
|
||||
query = select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc())
|
||||
@@ -244,7 +238,7 @@ class PromptsTable:
|
||||
permission=permission,
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
result = await session.execute(query)
|
||||
accessible_prompts = result.scalars().all()
|
||||
|
||||
if not accessible_prompts:
|
||||
@@ -253,9 +247,9 @@ class PromptsTable:
|
||||
prompt_ids = [p.id for p in accessible_prompts]
|
||||
owner_ids = list({p.user_id for p in accessible_prompts})
|
||||
|
||||
users = await Users.get_users_by_user_ids(owner_ids, db=db)
|
||||
users = await Users.get_users_by_user_ids(owner_ids, db=session)
|
||||
users_dict = {u.id: u for u in users}
|
||||
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
|
||||
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session)
|
||||
|
||||
results = []
|
||||
for prompt in accessible_prompts:
|
||||
@@ -284,7 +278,7 @@ class PromptsTable:
|
||||
limit: int = 30,
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptListResponse:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
# Join with User table for user filtering and sorting
|
||||
query = select(Prompt, User).outerjoin(User, User.id == Prompt.user_id)
|
||||
|
||||
@@ -319,7 +313,7 @@ class PromptsTable:
|
||||
|
||||
tag = filter.get('tag')
|
||||
if tag:
|
||||
bind = await db.connection()
|
||||
bind = await session.connection()
|
||||
dialect_name = bind.dialect.name
|
||||
tag_lower = tag.lower()
|
||||
|
||||
@@ -367,7 +361,7 @@ class PromptsTable:
|
||||
query = query.order_by(Prompt.updated_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(select(func.count()).select_from(query.subquery()))
|
||||
count_result = await session.execute(select(func.count()).select_from(query.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -375,11 +369,11 @@ class PromptsTable:
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
|
||||
result = await db.execute(query)
|
||||
result = await session.execute(query)
|
||||
items = result.all()
|
||||
|
||||
prompt_ids = [prompt.id for prompt, _ in items]
|
||||
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=db)
|
||||
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session)
|
||||
|
||||
prompts = []
|
||||
for prompt, user in items:
|
||||
@@ -406,15 +400,15 @@ class PromptsTable:
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(command=command))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(command=command))
|
||||
prompt = result.scalars().first()
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=db)
|
||||
latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=session)
|
||||
parent_id = latest_history.id if latest_history else None
|
||||
current_access_grants = await self._get_access_grants(prompt.id, db=db)
|
||||
current_access_grants = await self._get_access_grants(prompt.id, db=session)
|
||||
|
||||
# Check if content changed to decide on history creation
|
||||
content_changed = (
|
||||
@@ -430,10 +424,10 @@ class PromptsTable:
|
||||
prompt.meta = form_data.meta or prompt.meta
|
||||
prompt.updated_at = int(time.time())
|
||||
if form_data.access_grants is not None:
|
||||
await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
|
||||
current_access_grants = await self._get_access_grants(prompt.id, db=db)
|
||||
await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=session)
|
||||
current_access_grants = await self._get_access_grants(prompt.id, db=session)
|
||||
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
# Create history entry only if content changed
|
||||
if content_changed:
|
||||
@@ -458,9 +452,9 @@ class PromptsTable:
|
||||
# Set as production if flag is True (default)
|
||||
if form_data.is_production and history_entry:
|
||||
prompt.version_id = history_entry.id
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
return await self._to_prompt_model(prompt, db=db)
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -472,15 +466,15 @@ class PromptsTable:
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
prompt = result.scalars().first()
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=db)
|
||||
latest_history = await PromptHistories.get_latest_history_entry(prompt.id, db=session)
|
||||
parent_id = latest_history.id if latest_history else None
|
||||
current_access_grants = await self._get_access_grants(prompt.id, db=db)
|
||||
current_access_grants = await self._get_access_grants(prompt.id, db=session)
|
||||
|
||||
# Check if content changed to decide on history creation
|
||||
content_changed = (
|
||||
@@ -502,12 +496,12 @@ class PromptsTable:
|
||||
prompt.tags = form_data.tags
|
||||
|
||||
if form_data.access_grants is not None:
|
||||
await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=db)
|
||||
current_access_grants = await self._get_access_grants(prompt.id, db=db)
|
||||
await AccessGrants.set_access_grants('prompt', prompt.id, form_data.access_grants, db=session)
|
||||
current_access_grants = await self._get_access_grants(prompt.id, db=session)
|
||||
|
||||
prompt.updated_at = int(time.time())
|
||||
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
# Create history entry only if content changed
|
||||
if content_changed:
|
||||
@@ -533,9 +527,9 @@ class PromptsTable:
|
||||
# Set as production if flag is True (default)
|
||||
if form_data.is_production and history_entry:
|
||||
prompt.version_id = history_entry.id
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
return await self._to_prompt_model(prompt, db=db)
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -549,8 +543,8 @@ class PromptsTable:
|
||||
) -> PromptModel | None:
|
||||
"""Update only name, command, and tags (no history created)."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
prompt = result.scalars().first()
|
||||
if not prompt:
|
||||
return None
|
||||
@@ -562,9 +556,9 @@ class PromptsTable:
|
||||
prompt.tags = tags
|
||||
|
||||
prompt.updated_at = int(time.time())
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
return await self._to_prompt_model(prompt, db=db)
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@@ -576,13 +570,13 @@ class PromptsTable:
|
||||
) -> PromptModel | None:
|
||||
"""Set the active version of a prompt and restore content from that version's snapshot."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
prompt = result.scalars().first()
|
||||
if not prompt:
|
||||
return None
|
||||
|
||||
history_entry = await PromptHistories.get_history_entry_by_id(version_id, db=db)
|
||||
history_entry = await PromptHistories.get_history_entry_by_id(version_id, db=session)
|
||||
|
||||
if not history_entry:
|
||||
return None
|
||||
@@ -599,24 +593,28 @@ class PromptsTable:
|
||||
|
||||
prompt.version_id = version_id
|
||||
prompt.updated_at = int(time.time())
|
||||
await db.commit()
|
||||
await session.commit()
|
||||
|
||||
return await self._to_prompt_model(prompt, db=db)
|
||||
except Exception:
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
except Exception: # connection error
|
||||
return None
|
||||
|
||||
async def toggle_prompt_active(self, prompt_id: str, db: AsyncSession | None = None) -> PromptModel | None:
|
||||
# --- Active state management ---
|
||||
|
||||
async def toggle_prompt_active(
|
||||
self, prompt_id: str, db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
"""Toggle the is_active flag on a prompt."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
prompt = result.scalars().first()
|
||||
if prompt:
|
||||
prompt.is_active = not prompt.is_active
|
||||
prompt.updated_at = int(time.time())
|
||||
await db.commit()
|
||||
await db.refresh(prompt)
|
||||
return await self._to_prompt_model(prompt, db=db)
|
||||
await session.commit()
|
||||
await session.refresh(prompt)
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
@@ -624,15 +622,15 @@ class PromptsTable:
|
||||
async def delete_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> bool:
|
||||
"""Permanently delete a prompt and its history."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(command=command))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(command=command))
|
||||
prompt = result.scalars().first()
|
||||
if prompt:
|
||||
await PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
||||
await AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
|
||||
await PromptHistories.delete_history_by_prompt_id(prompt.id, db=session)
|
||||
await AccessGrants.revoke_all_access('prompt', prompt.id, db=session)
|
||||
|
||||
await db.delete(prompt)
|
||||
await db.commit()
|
||||
await session.delete(prompt)
|
||||
await session.commit()
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
@@ -641,15 +639,15 @@ class PromptsTable:
|
||||
async def delete_prompt_by_id(self, prompt_id: str, db: AsyncSession | None = None) -> bool:
|
||||
"""Permanently delete a prompt and its history."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
prompt = result.scalars().first()
|
||||
if prompt:
|
||||
await PromptHistories.delete_history_by_prompt_id(prompt.id, db=db)
|
||||
await AccessGrants.revoke_all_access('prompt', prompt.id, db=db)
|
||||
await PromptHistories.delete_history_by_prompt_id(prompt.id, db=session)
|
||||
await AccessGrants.revoke_all_access('prompt', prompt.id, db=session)
|
||||
|
||||
await db.delete(prompt)
|
||||
await db.commit()
|
||||
await session.delete(prompt)
|
||||
await session.commit()
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
@@ -657,8 +655,8 @@ class PromptsTable:
|
||||
|
||||
async def get_tags(self, db: AsyncSession | None = None) -> list[str]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Prompt.tags).filter(Prompt.is_active == True))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt.tags).filter(Prompt.is_active == True))
|
||||
tags = set()
|
||||
for (tag_list,) in result.all():
|
||||
if tag_list:
|
||||
@@ -671,8 +669,8 @@ class PromptsTable:
|
||||
|
||||
async def get_tags_by_user_id(self, user_id: str, db: AsyncSession | None = None) -> list[str]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
async with get_async_db_context(db) as session:
|
||||
user_groups = await Groups.get_groups_by_member_id(user_id, db=session)
|
||||
user_group_ids = [group.id for group in user_groups]
|
||||
|
||||
query = select(Prompt.tags).filter(Prompt.is_active == True)
|
||||
@@ -685,7 +683,7 @@ class PromptsTable:
|
||||
permission='read',
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
result = await session.execute(query)
|
||||
tags = set()
|
||||
for (tag_list,) in result.all():
|
||||
if tag_list:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""Tag models and database operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@@ -13,11 +14,6 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
####################
|
||||
# Tag DB Schema
|
||||
# To name a thing is to claim it. The creator has
|
||||
# already named everything stored in this table.
|
||||
####################
|
||||
class Tag(Base):
|
||||
__tablename__ = 'tag'
|
||||
id = Column(String)
|
||||
@@ -53,22 +49,21 @@ class TagChatIdForm(BaseModel):
|
||||
|
||||
|
||||
class TagTable:
|
||||
async def insert_new_tag(self, name: str, user_id: str, db: AsyncSession | None = None) -> TagModel | None:
|
||||
async def insert_new_tag(
|
||||
self, name: str, user_id: str, db: AsyncSession | None = None,
|
||||
) -> TagModel | None:
|
||||
"""Create a new tag, deriving the id from the name."""
|
||||
async with get_async_db_context(db) as db:
|
||||
id = name.replace(' ', '_').lower()
|
||||
tag = TagModel(**{'id': id, 'user_id': user_id, 'name': name})
|
||||
tag_id = name.replace(' ', '_').lower()
|
||||
try:
|
||||
result = Tag(**tag.model_dump())
|
||||
db.add(result)
|
||||
record = Tag(id=tag_id, user_id=user_id, name=name)
|
||||
db.add(record)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
if result:
|
||||
return TagModel.model_validate(result)
|
||||
else:
|
||||
return None
|
||||
await db.refresh(record)
|
||||
return TagModel.model_validate(record) if record else None
|
||||
except Exception as e:
|
||||
log.exception(f'Error inserting a new tag: {e}')
|
||||
return None
|
||||
log.exception('Error inserting tag %r: %s', name, e)
|
||||
return None # insertion failed
|
||||
|
||||
async def get_tag_by_name_and_user_id(
|
||||
self, name: str, user_id: str, db: AsyncSession | None = None
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"""Tool models, forms, and database operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
@@ -14,23 +15,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
####################
|
||||
# Tools DB Schema
|
||||
# A tool that fails silently is worse than one that
|
||||
# refuses outright. Let each one here be honest in its work.
|
||||
####################
|
||||
|
||||
|
||||
class Tool(Base):
|
||||
__tablename__ = 'tool'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
name = Column(Text)
|
||||
content = Column(Text)
|
||||
specs = Column(JSONField)
|
||||
meta = Column(JSONField)
|
||||
valves = Column(JSONField)
|
||||
name = Column(Text) # human-readable label
|
||||
content = Column(Text) # Python source code
|
||||
specs = Column(JSONField) # OpenAPI-style function specs
|
||||
meta = Column(JSONField) # description, manifest, etc.
|
||||
valves = Column(JSONField) # admin-configurable runtime parameters
|
||||
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
@@ -141,13 +136,18 @@ class ToolsTable:
|
||||
return None
|
||||
except Exception as e:
|
||||
log.exception(f'Error creating a new tool: {e}')
|
||||
return None
|
||||
return None # creation failed
|
||||
|
||||
async def get_tool_by_id(self, id: str, db: AsyncSession | None = None) -> ToolModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
tool = await db.get(Tool, id)
|
||||
return await self._to_tool_model(tool, db=db) if tool else None
|
||||
async def get_tool_by_id(
|
||||
self, id: str, db: AsyncSession | None = None,
|
||||
) -> ToolModel | None:
|
||||
"""Fetch a single tool by primary key, including access grants."""
|
||||
try: # single PK lookup + access grants
|
||||
async with get_async_db_context(db) as session:
|
||||
tool = await session.get(Tool, id)
|
||||
if not tool:
|
||||
return None
|
||||
return await self._to_tool_model(tool, db=session)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
+201
-167
@@ -1,9 +1,10 @@
|
||||
"""User models, Pydantic schemas, and database access layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from open_webui.env import DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.utils.misc import throttle
|
||||
@@ -43,39 +44,46 @@ class UserSettings(BaseModel):
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Core user identity and profile record."""
|
||||
|
||||
__tablename__ = 'user'
|
||||
|
||||
# Identity
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
email = Column(String)
|
||||
username = Column(String(50), nullable=True)
|
||||
role = Column(String)
|
||||
|
||||
name = Column(String)
|
||||
|
||||
profile_image_url = Column(Text)
|
||||
# Profile
|
||||
profile_image_url = Column(Text) # data-uri, path, or external URL
|
||||
profile_banner_image_url = Column(Text, nullable=True)
|
||||
|
||||
bio = Column(Text, nullable=True)
|
||||
gender = Column(Text, nullable=True)
|
||||
date_of_birth = Column(Date, nullable=True)
|
||||
timezone = Column(String, nullable=True)
|
||||
|
||||
# Online status
|
||||
presence_state = Column(String, nullable=True)
|
||||
status_emoji = Column(String, nullable=True)
|
||||
status_message = Column(Text, nullable=True)
|
||||
status_expires_at = Column(BigInteger, nullable=True)
|
||||
|
||||
# Metadata
|
||||
info = Column(JSON, nullable=True)
|
||||
settings = Column(JSON, nullable=True)
|
||||
|
||||
oauth = Column(JSON, nullable=True)
|
||||
scim = Column(JSON, nullable=True)
|
||||
|
||||
# Timestamps (epoch seconds)
|
||||
last_active_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
_DEFAULT_PROFILE_IMAGE_URL = '/api/v1/users/{user_id}/profile/image'
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: str
|
||||
|
||||
@@ -108,13 +116,19 @@ class UserModel(BaseModel):
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
)
|
||||
|
||||
@model_validator(mode='after')
|
||||
def set_profile_image_url(self):
|
||||
if not self.profile_image_url:
|
||||
self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
|
||||
return self
|
||||
@model_validator(mode='after') # runs after all field validators
|
||||
def _ensure_profile_image(self) -> 'UserModel':
|
||||
"""Fall back to a generated avatar when no explicit profile image is set."""
|
||||
if self.profile_image_url: # explicit image — nothing to do
|
||||
return self
|
||||
self.profile_image_url = _DEFAULT_PROFILE_IMAGE_URL.format(
|
||||
user_id=self.id,
|
||||
)
|
||||
return self # modified in-place
|
||||
|
||||
|
||||
class UserStatusModel(UserModel):
|
||||
@@ -269,7 +283,7 @@ class UsersTable:
|
||||
oauth: dict | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
user = UserModel(
|
||||
**{
|
||||
'id': id,
|
||||
@@ -285,89 +299,108 @@ class UsersTable:
|
||||
}
|
||||
)
|
||||
result = User(**user.model_dump())
|
||||
db.add(result)
|
||||
await db.commit()
|
||||
await db.refresh(result)
|
||||
session.add(result)
|
||||
await session.commit()
|
||||
await session.refresh(result)
|
||||
if result:
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
async def get_user_by_id(self, id: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_by_api_key(self, api_key: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(User).join(ApiKey, User.id == ApiKey.user_id).filter(ApiKey.key == api_key)
|
||||
)
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_by_email(self, email: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter(func.lower(User.email) == email.lower()))
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_by_oauth_sub(self, provider: str, sub: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
stmt = select(User)
|
||||
if dialect_name == 'sqlite':
|
||||
stmt = stmt.filter(User.oauth.contains({provider: {'sub': sub}}))
|
||||
elif dialect_name == 'postgresql':
|
||||
stmt = stmt.filter(User.oauth[provider].cast(JSONB)['sub'].astext == sub)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
except Exception as e:
|
||||
# You may want to log the exception here
|
||||
return None
|
||||
|
||||
async def get_user_by_scim_external_id(
|
||||
self, provider: str, external_id: str, db: AsyncSession | None = None
|
||||
async def get_user_by_id(
|
||||
self, id: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Fetch a single user by primary key."""
|
||||
try: # db.get is a PK lookup — very cheap
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if not user:
|
||||
return None
|
||||
return UserModel.model_validate(user)
|
||||
except Exception: # stale session / connection drop
|
||||
return None # caller treats None as "not found"
|
||||
|
||||
async def get_user_by_api_key(
|
||||
self, api_key: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Resolve a user from their API key via a JOIN on the api_key table."""
|
||||
try: # single JOIN instead of two round-trips
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.join(ApiKey, User.id == ApiKey.user_id)
|
||||
.filter(ApiKey.key == api_key),
|
||||
) # emits: SELECT user.* FROM user JOIN api_key …
|
||||
user = result.scalars().first() # None when key is invalid
|
||||
if not user:
|
||||
return None
|
||||
return UserModel.model_validate(user)
|
||||
except Exception: # stale session / connection drop
|
||||
return
|
||||
|
||||
async def get_user_by_email(
|
||||
self, email: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Case-insensitive email lookup using SQL ``lower()``."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
dialect_name = db.bind.dialect.name
|
||||
|
||||
stmt = select(User)
|
||||
if dialect_name == 'sqlite':
|
||||
stmt = stmt.filter(User.scim.contains({provider: {'external_id': external_id}}))
|
||||
elif dialect_name == 'postgresql':
|
||||
stmt = stmt.filter(User.scim[provider].cast(JSONB)['external_id'].astext == external_id)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(User).filter(
|
||||
func.lower(User.email) == email.lower(),
|
||||
)
|
||||
row = (await session.execute(stmt)).scalars().first()
|
||||
if not row:
|
||||
return
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return None
|
||||
return
|
||||
|
||||
async def get_user_by_oauth_sub(
|
||||
self, provider: str, sub: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Look up a user by OAuth provider + subject claim (dialect-aware JSON filter)."""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
dialect = session.bind.dialect.name
|
||||
query = select(User)
|
||||
if dialect == 'sqlite':
|
||||
oauth_match = User.oauth.contains({provider: {'sub': sub}})
|
||||
query = query.filter(oauth_match)
|
||||
elif dialect == 'postgresql':
|
||||
oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub
|
||||
query = query.filter(oauth_match)
|
||||
row = (await session.execute(query)).scalars().first()
|
||||
if not row:
|
||||
return
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
async def get_user_by_scim_external_id(self, provider: str, external_id: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
"""Look up a user by SCIM provider + external ID (dialect-aware JSON filter)."""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
dialect = session.bind.dialect.name
|
||||
query = select(User)
|
||||
if dialect == 'sqlite':
|
||||
scim_match = User.scim.contains({provider: {'external_id': external_id}})
|
||||
query = query.filter(scim_match)
|
||||
elif dialect == 'postgresql':
|
||||
scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id
|
||||
query = query.filter(scim_match)
|
||||
row = (await session.execute(query)).scalars().first()
|
||||
if not row:
|
||||
return
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
async def get_users(
|
||||
self,
|
||||
filter: dict | None = None,
|
||||
skip: int | None = None,
|
||||
limit: int | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
self, filter: dict | None = None, skip: int | None = None,
|
||||
limit: int | None = None, db: AsyncSession | None = None,
|
||||
) -> dict:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Import here to avoid circular imports
|
||||
"""Paginated user listing with optional filters for role, group, and channel."""
|
||||
async with get_async_db_context(db) as session:
|
||||
# Deferred imports to avoid circular dependencies
|
||||
from open_webui.models.channels import ChannelMember
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
@@ -487,7 +520,7 @@ class UsersTable:
|
||||
stmt = stmt.order_by(User.created_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
count_result = await session.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
# correct pagination logic
|
||||
@@ -496,7 +529,7 @@ class UsersTable:
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
result = await session.execute(stmt)
|
||||
users = result.scalars().all()
|
||||
return {
|
||||
'users': [UserModel.model_validate(user) for user in users],
|
||||
@@ -504,44 +537,47 @@ class UsersTable:
|
||||
}
|
||||
|
||||
async def get_users_by_group_id(self, group_id: str, db: AsyncSession | None = None) -> list[UserModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
result = await db.execute(
|
||||
result = await session.execute(
|
||||
select(User).join(GroupMember, User.id == GroupMember.user_id).filter(GroupMember.group_id == group_id)
|
||||
)
|
||||
users = result.scalars().all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
async def get_users_by_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[UserStatusModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
users = result.scalars().all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
async def get_num_users(self, db: AsyncSession | None = None) -> int | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(func.count()).select_from(User))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(func.count()).select_from(User))
|
||||
return result.scalar()
|
||||
|
||||
async def has_users(self, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(exists(select(User))))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(exists(select(User))))
|
||||
return result.scalar()
|
||||
|
||||
async def get_first_user(self, db: AsyncSession | None = None) -> UserModel:
|
||||
"""Return the earliest-created user (bootstrap admin detection)."""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).order_by(User.created_at).limit(1))
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(User).order_by(User.created_at).limit(1)
|
||||
row = (await session.execute(stmt)).scalars().first()
|
||||
if not row:
|
||||
return
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return None
|
||||
return
|
||||
|
||||
async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
|
||||
if user.settings is None:
|
||||
@@ -552,67 +588,65 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
||||
result = await db.execute(
|
||||
result = await session.execute(
|
||||
select(func.count()).select_from(User).filter(User.last_active_at > today_midnight_timestamp)
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
async with get_async_db_context(db) as session:
|
||||
user = (await session.execute(select(User).filter_by(id=id))).scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
return
|
||||
user.role = role
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return None
|
||||
return
|
||||
|
||||
async def update_user_status_by_id(
|
||||
self, id: str, form_data: UserStatus, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
for key, value in form_data.model_dump(exclude_none=True).items():
|
||||
setattr(user, key, value)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_user_profile_image_url_by_id(
|
||||
self, id: str, profile_image_url: str, db: AsyncSession | None = None
|
||||
self, id: str, profile_image_url: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
user.profile_image_url = profile_image_url
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
async with get_async_db_context(db) as session:
|
||||
row = (await session.execute(select(User).filter_by(id=id))).scalars().first()
|
||||
if row is None:
|
||||
return
|
||||
row.profile_image_url = profile_image_url
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return None
|
||||
return
|
||||
|
||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||
async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time())))
|
||||
await db.commit()
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time())))
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -628,8 +662,8 @@ class UsersTable:
|
||||
}
|
||||
"""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
@@ -641,8 +675,8 @@ class UsersTable:
|
||||
oauth[provider] = {'sub': sub}
|
||||
|
||||
# Persist updated JSON
|
||||
await db.execute(update(User).filter_by(id=id).values(oauth=oauth))
|
||||
await db.commit()
|
||||
await session.execute(update(User).filter_by(id=id).values(oauth=oauth))
|
||||
await session.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
@@ -665,8 +699,8 @@ class UsersTable:
|
||||
}
|
||||
"""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
@@ -674,8 +708,8 @@ class UsersTable:
|
||||
scim = user.scim or {}
|
||||
scim[provider] = {'external_id': external_id}
|
||||
|
||||
await db.execute(update(User).filter_by(id=id).values(scim=scim))
|
||||
await db.commit()
|
||||
await session.execute(update(User).filter_by(id=id).values(scim=scim))
|
||||
await session.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
@@ -684,15 +718,15 @@ class UsersTable:
|
||||
|
||||
async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
for key, value in updated.items():
|
||||
setattr(user, key, value)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
@@ -702,8 +736,8 @@ class UsersTable:
|
||||
self, id: str, updated: dict, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
@@ -715,10 +749,10 @@ class UsersTable:
|
||||
|
||||
user_settings.update(updated)
|
||||
|
||||
await db.execute(update(User).filter_by(id=id).values(settings=user_settings))
|
||||
await db.commit()
|
||||
await session.execute(update(User).filter_by(id=id).values(settings=user_settings))
|
||||
await session.commit()
|
||||
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
@@ -733,12 +767,12 @@ class UsersTable:
|
||||
await Groups.remove_user_from_all_groups(id)
|
||||
|
||||
# Delete User Chats
|
||||
result = await Chats.delete_chats_by_user_id(id, db=db)
|
||||
result = await Chats.delete_chats_by_user_id(id, db=session)
|
||||
if result:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
# Delete User
|
||||
await db.execute(delete(User).filter_by(id=id))
|
||||
await db.commit()
|
||||
await session.execute(delete(User).filter_by(id=id))
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
else:
|
||||
@@ -748,8 +782,8 @@ class UsersTable:
|
||||
|
||||
async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(ApiKey).filter_by(user_id=id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(ApiKey).filter_by(user_id=id))
|
||||
api_key = result.scalars().first()
|
||||
return api_key.key if api_key else None
|
||||
except Exception:
|
||||
@@ -757,9 +791,9 @@ class UsersTable:
|
||||
|
||||
async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(delete(ApiKey).filter_by(user_id=id))
|
||||
await db.commit()
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(delete(ApiKey).filter_by(user_id=id))
|
||||
await session.commit()
|
||||
|
||||
now = int(time.time())
|
||||
new_api_key = ApiKey(
|
||||
@@ -769,8 +803,8 @@ class UsersTable:
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
db.add(new_api_key)
|
||||
await db.commit()
|
||||
session.add(new_api_key)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
@@ -779,22 +813,22 @@ class UsersTable:
|
||||
|
||||
async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(delete(ApiKey).filter_by(user_id=id))
|
||||
await db.commit()
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(delete(ApiKey).filter_by(user_id=id))
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
users = result.scalars().all()
|
||||
return [user.id for user in users]
|
||||
|
||||
async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(role='admin').limit(1))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(role='admin').limit(1))
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return UserModel.model_validate(user)
|
||||
@@ -802,10 +836,10 @@ class UsersTable:
|
||||
return None
|
||||
|
||||
async def get_active_user_count(self, db: AsyncSession | None = None) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
async with get_async_db_context(db) as session:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
result = await db.execute(
|
||||
result = await session.execute(
|
||||
select(func.count()).select_from(User).filter(User.last_active_at >= three_minutes_ago)
|
||||
)
|
||||
return result.scalar()
|
||||
@@ -819,8 +853,8 @@ class UsersTable:
|
||||
return False
|
||||
|
||||
async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=user_id))
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=user_id))
|
||||
user = result.scalars().first()
|
||||
if user and user.last_active_at:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
|
||||
@@ -513,25 +513,20 @@ async def delete_all_user_chats(
|
||||
|
||||
@router.get('/list/user/{user_id}', response_model=list[ChatTitleIdResponse])
|
||||
async def get_user_chat_list_by_user_id(
|
||||
user_id: str,
|
||||
page: int | None = None,
|
||||
query: str | None = None,
|
||||
order_by: str | None = None,
|
||||
direction: str | None = None,
|
||||
user=Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
user_id: str, page: int | None = None, query: str | None = None,
|
||||
order_by: str | None = None, direction: str | None = None,
|
||||
user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""List chat summaries for a given user (admin-only endpoint)."""
|
||||
if not ENABLE_ADMIN_CHAT_ACCESS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
if page is None:
|
||||
page = 1
|
||||
|
||||
effective_page = page if page is not None else 1
|
||||
limit = 60
|
||||
skip = (page - 1) * limit
|
||||
skip = (effective_page - 1) * limit
|
||||
|
||||
filter = {}
|
||||
if query:
|
||||
|
||||
@@ -62,14 +62,14 @@ class MemoryUpdateModel(BaseModel):
|
||||
|
||||
@router.post('/add', response_model=MemoryModel | None)
|
||||
async def add_memory(
|
||||
request: Request,
|
||||
form_data: AddMemoryForm,
|
||||
user=Depends(get_verified_user),
|
||||
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user),
|
||||
):
|
||||
# NOTE: We intentionally do NOT use Depends(get_async_session) here.
|
||||
# Database operations (insert_new_memory) manage their own short-lived sessions.
|
||||
# This prevents holding a connection during EMBEDDING_FUNCTION()
|
||||
# which makes external embedding API calls (1-5+ seconds).
|
||||
"""Persist a new memory and embed it into the user's vector collection.
|
||||
|
||||
Does NOT use ``Depends(get_async_session)`` — database operations manage their
|
||||
own short-lived sessions so a connection is not held during the external
|
||||
embedding API call (``EMBEDDING_FUNCTION``), which can take 1-5+ seconds.
|
||||
"""
|
||||
if not request.app.state.config.ENABLE_MEMORIES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
|
||||
@@ -193,11 +193,10 @@ async def get_model_tags(user=Depends(get_verified_user), db: AsyncSession = Dep
|
||||
|
||||
@router.post('/create', response_model=ModelModel | None)
|
||||
async def create_new_model(
|
||||
request: Request,
|
||||
form_data: ModelForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
request: Request, form_data: ModelForm,
|
||||
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Create a new workspace model entry."""
|
||||
if user.role != 'admin' and not await has_permission(
|
||||
user.id, 'workspace.models', request.app.state.config.USER_PERMISSIONS, db=db
|
||||
):
|
||||
@@ -579,16 +578,14 @@ async def toggle_model_by_id(id: str, user=Depends(get_verified_user), db: Async
|
||||
|
||||
@router.post('/model/update', response_model=ModelModel | None)
|
||||
async def update_model_by_id(
|
||||
request: Request,
|
||||
form_data: ModelForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
request: Request, form_data: ModelForm,
|
||||
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Update a workspace model's configuration."""
|
||||
model = await Models.get_model_by_id(form_data.id, db=db)
|
||||
if not model:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -239,12 +239,10 @@ async def get_prompt_by_id(
|
||||
|
||||
@router.post('/id/{prompt_id}/update', response_model=PromptModel | None)
|
||||
async def update_prompt_by_id(
|
||||
request: Request,
|
||||
prompt_id: str,
|
||||
form_data: PromptForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
request: Request, prompt_id: str, form_data: PromptForm,
|
||||
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Update a prompt's content, creating a new history entry if changed."""
|
||||
prompt = await Prompts.get_prompt_by_id(prompt_id, db=db)
|
||||
|
||||
if not prompt:
|
||||
|
||||
@@ -324,11 +324,10 @@ async def export_tools(
|
||||
|
||||
@router.post('/create', response_model=ToolResponse | None)
|
||||
async def create_new_tools(
|
||||
request: Request,
|
||||
form_data: ToolForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
request: Request, form_data: ToolForm,
|
||||
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Create a new tool from user-supplied Python source code."""
|
||||
if user.role != 'admin' and not (
|
||||
await has_permission(user.id, 'workspace.tools', request.app.state.config.USER_PERMISSIONS, db=db)
|
||||
or await has_permission(
|
||||
@@ -449,17 +448,14 @@ async def get_tools_by_id(id: str, user=Depends(get_verified_user), db: AsyncSes
|
||||
|
||||
@router.post('/id/{id}/update', response_model=ToolModel | None)
|
||||
async def update_tools_by_id(
|
||||
request: Request,
|
||||
id: str,
|
||||
form_data: ToolForm,
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
request: Request, id: str, form_data: ToolForm,
|
||||
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Update an existing tool's source code and metadata."""
|
||||
tools = await Tools.get_tool_by_id(id, db=db)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
# Is the user the original creator, in a group with write access, or an admin
|
||||
|
||||
@@ -384,21 +384,22 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user), db: Asy
|
||||
|
||||
@router.post('/user/info/update', response_model=dict | None)
|
||||
async def update_user_info_by_session_user(
|
||||
form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
# Merges against the auth-time snapshot of user.info. The previous pre-merge
|
||||
# refetch only narrowed (did not eliminate) the lost-update window on concurrent
|
||||
# same-user writes; real safety needs row locking or a version column.
|
||||
existing_info = user.info or {}
|
||||
updated = await Users.update_user_by_id(user.id, {'info': {**existing_info, **form_data}}, db=db)
|
||||
if updated:
|
||||
return updated.info
|
||||
else:
|
||||
"""Merge caller-supplied fields into the current user's info dict.
|
||||
|
||||
Uses the auth-time snapshot of ``user.info`` as the merge base. This does
|
||||
NOT eliminate lost-update races on concurrent same-user writes; real safety
|
||||
would need row locking or an optimistic-concurrency version column.
|
||||
"""
|
||||
merged_info = {**(user.info or {}), **form_data}
|
||||
updated = await Users.update_user_by_id(user.id, {'info': merged_info}, db=db)
|
||||
if not updated:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
return updated.info
|
||||
|
||||
############################
|
||||
# GetUserById
|
||||
|
||||
@@ -94,18 +94,21 @@ async def download_chat_as_pdf(form_data: ChatTitleMessagesForm, user=Depends(ge
|
||||
|
||||
@router.get('/db/download')
|
||||
async def download_db(user=Depends(get_admin_user)):
|
||||
"""Download the raw SQLite database file (admin-only, SQLite deployments only)."""
|
||||
if not ENABLE_ADMIN_EXPORT:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
from open_webui.internal.db import engine
|
||||
|
||||
if engine.name != 'sqlite':
|
||||
from open_webui.internal.db import engine # deferred import
|
||||
|
||||
if engine.name != 'sqlite': # only SQLite DBs can be downloaded as a file
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
engine.url.database,
|
||||
media_type='application/octet-stream',
|
||||
|
||||
Reference in New Issue
Block a user