This commit is contained in:
Timothy Jaeryang Baek
2026-05-21 15:29:49 +04:00
parent 260ead64da
commit d8b5b9fa79
17 changed files with 465 additions and 664 deletions
+7 -12
View File
@@ -117,7 +117,7 @@ extract_ssl_mode_from_url = extract_ssl_params_from_url
reattach_ssl_mode_to_url = reattach_ssl_params_to_url
class JSONField(types.TypeDecorator):
class JSONField(types.TypeDecorator): # TEXT-backed JSON storage
"""Store arbitrary Python objects as JSON-encoded TEXT.
Used instead of native JSON columns for portability across SQLite and
@@ -125,19 +125,14 @@ class JSONField(types.TypeDecorator):
deserialized with ``json.loads`` on read.
"""
impl = types.Text
cache_ok = True # safe for statement caching (no per-instance state)
impl = types.UnicodeText
cache_ok = True
def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any:
return json.dumps(value)
return json.dumps(value) if value is not None else None
def process_result_value(self, value: _T | None, dialect: Dialect) -> Any:
if value is not None:
return json.loads(value)
def copy(self, **kw: Any) -> Self:
return JSONField(self.impl.length)
return json.loads(value) if value is not None else None
def copy(self, **kwargs: Any) -> Self:
return JSONField(length=self.impl.length)
def db_value(self, value):
return json.dumps(value)
+2 -3
View File
@@ -2911,12 +2911,11 @@ async def readiness_check():
@app.get('/health/db')
async def healthcheck_with_db():
async def check_db_health():
"""Verify database connectivity by issuing a lightweight ping."""
await async_db_ping()
return {'status': True}
# --- static assets & files ---
# Serve build-time static assets (CSS, JS, images, favicon, etc.)
app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static')
+48 -88
View File
@@ -1,105 +1,65 @@
from __future__ import annotations
"""Alembic environment configuration.
Configures the migration context for both offline (SQL script generation)
and online (live database connection) modes. Handles SQLCipher URLs,
SSL parameter normalisation, and JSON log formatting.
"""
# Alembic environment configuration runner.
# Coordinates database migrations in both offline and online execution modes.
import logging.config
import logging
from logging.config import fileConfig
from alembic import context
import alembic.context
from open_webui.env import DATABASE_PASSWORD, DATABASE_URL, LOG_FORMAT
from open_webui.internal.db import extract_ssl_params_from_url, reattach_ssl_params_to_url
from open_webui.models.auths import Auth
from open_webui.models.calendar import Calendar, CalendarEvent, CalendarEventAttendee # noqa: F401
from sqlalchemy import create_engine, engine_from_config, pool
# ── Alembic config & logging ─────────────────────────────────────────────────
config = context.config
if config.config_file_name is not None:
fileConfig(config.config_file_name, disable_existing_loggers=False)
# Re-apply JSON formatter after fileConfig replaces handlers.
alembic_config = alembic.context.config
if alembic_config.config_file_name:
logging.config.fileConfig(alembic_config.config_file_name, disable_existing_loggers=False)
if LOG_FORMAT == 'json':
from open_webui.env import JSONFormatter
for handler in logging.root.handlers:
handler.setFormatter(JSONFormatter())
# ── Database URL ─────────────────────────────────────────────────────────────
target_metadata = Auth.metadata
DB_URL = DATABASE_URL
# Normalise SSL query params for psycopg2 (Alembic uses psycopg2 for sync).
_url_no_ssl, _ssl_params = extract_ssl_params_from_url(DB_URL)
if _ssl_params:
DB_URL = reattach_ssl_params_to_url(_url_no_ssl, _ssl_params)
if DB_URL:
config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))
# ── Migration runners ────────────────────────────────────────────────────────
for log_handler in logging.root.handlers:
log_handler.setFormatter(JSONFormatter())
migration_metadata = Auth.metadata
target_db_url = DATABASE_URL
base_url, ssl_query_params = extract_ssl_params_from_url(target_db_url)
if ssl_query_params:
target_db_url = reattach_ssl_params_to_url(base_url, ssl_query_params)
if target_db_url:
alembic_config.set_main_option('sqlalchemy.url', target_db_url.replace('%', '%%'))
def run_migrations_offline() -> None:
"""Generate SQL script without a live database connection."""
url = config.get_main_option('sqlalchemy.url')
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={'paramstyle': 'named'},
)
with context.begin_transaction():
context.run_migrations()
def _build_connectable():
"""Create the appropriate SQLAlchemy engine for the configured DB URL."""
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '':
"""Execute Alembic migrations in offline mode (outputs raw SQL DDL)."""
db_connection_url = alembic_config.get_main_option('sqlalchemy.url')
alembic.context.configure(url=db_connection_url, target_metadata=migration_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"})
with alembic.context.begin_transaction():
alembic.context.run_migrations()
def _get_engine_connectable():
"""Build the database engine based on target URL and authentication credentials."""
if target_db_url and target_db_url.startswith('sqlite+sqlcipher://'):
if not DATABASE_PASSWORD or not DATABASE_PASSWORD.strip():
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
db_path = DB_URL.replace('sqlite+sqlcipher://', '')
if db_path.startswith('/'):
db_path = db_path[1:]
def _sqlcipher_creator():
raw_db_path = target_db_url.replace('sqlite+sqlcipher://', '')
if raw_db_path.startswith('/'):
raw_db_path = raw_db_path[1:]
def _sqlite_cipher_creator():
import sqlcipher3
conn = sqlcipher3.connect(db_path, check_same_thread=False)
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
return conn
return create_engine('sqlite://', creator=_sqlcipher_creator, echo=False)
cipher_conn = sqlcipher3.connect(raw_db_path, check_same_thread=False)
cipher_conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
return cipher_conn
return create_engine('sqlite://', creator=_sqlite_cipher_creator, echo=False)
return engine_from_config(
config.get_section(config.config_ini_section, {}),
alembic_config.get_section(alembic_config.config_ini_section, {}),
prefix='sqlalchemy.',
poolclass=pool.NullPool,
)
def run_migrations_online() -> None:
"""Run migrations against a live database connection."""
connectable = _build_connectable()
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
# ── Entrypoint ───────────────────────────────────────────────────────────────
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
"""Execute migrations against a live database connection."""
live_connectable = _get_engine_connectable()
with live_connectable.connect() as live_connection:
alembic.context.configure(
connection=live_connection,
target_metadata=migration_metadata,
)
with alembic.context.begin_transaction():
alembic.context.run_migrations()
# Alembic execution entrypoint branch
if alembic.context.is_offline_mode():
run_migrations_offline() # run in offline mode
if not alembic.context.is_offline_mode():
run_migrations_online() # run in online mode
+3 -4
View File
@@ -2,10 +2,9 @@ from __future__ import annotations
"""Alembic migration utilities."""
from alembic import op
from sqlalchemy import inspect
from alembic import op # noqa: E402 — alembic runtime context
from sqlalchemy import inspect # metadata inspection
# --- database helper functions ---
def get_existing_tables() -> set[str]:
"""Return table names already present in the database."""
conn = op.get_bind()
@@ -1,30 +1,20 @@
# Initial bootstrap migration version.
# Revision ID: 7e5b5dc7342b
# Revises: (none)
# Created on: 2024-06-24 13:15:33.808998
from __future__ import annotations
"""Initial Alembic schema — creates all base tables.
Revision ID: 7e5b5dc7342b
Revises: —
Create Date: 2024-06-24 13:15:33.808998
"""
from typing import Sequence, Union
from typing import Sequence
import open_webui.internal.db # noqa: F401
import sqlalchemy as sa
from alembic import op
from open_webui.internal.db import JSONField
from open_webui.migrations.util import get_existing_tables
revision: str = '7e5b5dc7342b'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
# ── Table definitions ────────────────────────────────────────────────────────
# Each table is only created if it doesn't already exist, because databases
# migrated from the Peewee era will already have these tables.
_TABLES: list[tuple[str, list[sa.Column], list]] = [
revision: str = "7e5b5dc7342b"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
# Initial schema table declarations
_INITIAL_TABLES: list[tuple[str, list[sa.Column], list]] = [
(
'auth',
[
@@ -185,15 +175,13 @@ _TABLES: list[tuple[str, list[sa.Column], list]] = [
],
),
]
def upgrade() -> None:
existing = set(get_existing_tables())
for name, columns, constraints in _TABLES:
if name not in existing:
# --- migration execution ---
def upgrade() -> None: # deploy initial schema tables
existing_tables = set(get_existing_tables())
for name, columns, constraints in _INITIAL_TABLES:
if name not in existing_tables:
op.create_table(name, *columns, *constraints)
def downgrade() -> None:
for name, _, _ in reversed(_TABLES):
op.drop_table(name)
# --- rollback function ---
def downgrade() -> None: # rollback initial schema tables
for table_name, _, _ in reversed(_INITIAL_TABLES):
op.drop_table(table_name)
+86 -102
View File
@@ -1,4 +1,4 @@
"""Authentication models and database operations."""
"""Auth credential models and data-access layer."""
from __future__ import annotations
@@ -16,19 +16,19 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
class Auth(Base):
"""Credential record linking a user identity to an email + hashed password."""
class Auth(Base): # credential ↔ user linkage
"""Maps a user ID to an email/password pair with an active flag."""
__tablename__ = 'auth'
id = Column(String, primary_key=True, unique=True) # same as User.id
email = Column(String) # login email, kept in sync with User.email
password = Column(Text) # bcrypt / argon2 hash
active = Column(Boolean) # soft-disable flag
id = Column(String, primary_key=True, unique=True) # mirrors User.id
email = Column(String) # login address, kept in sync with User.email
password = Column(Text) # argon2 / bcrypt hash
active = Column(Boolean) # account soft-disable toggle
class AuthModel(BaseModel):
"""Pydantic mirror of the Auth table row."""
"""Pydantic mirror of the ``auth`` table row."""
id: str
email: str
@@ -37,7 +37,7 @@ class AuthModel(BaseModel):
class Token(BaseModel):
"""JWT bearer token response."""
"""JWT bearer-token response wrapper."""
token: str
token_type: str
@@ -88,7 +88,12 @@ class AddUserForm(SignupForm):
role: str | None = 'pending'
# --- data-access layer ---
class AuthsTable:
"""Provides CRUD operations for the Auth ↔ User lifecycle."""
async def insert_new_auth(
self,
email: str,
@@ -99,130 +104,109 @@ class AuthsTable:
oauth: dict | None = None,
db: AsyncSession | None = None,
) -> UserModel | None:
"""Create an Auth + User pair in a single transaction."""
async with get_async_db_context(db) as db:
"""Create an Auth + User pair inside a single transaction."""
async with get_async_db_context(db) as session:
log.info('insert_new_auth')
user_id = str(uuid.uuid4())
new_id = str(uuid.uuid4())
record = Auth(
id=user_id,
credential = Auth(
id=new_id,
email=email,
password=password,
active=True,
)
db.add(record)
session.add(credential)
user = await Users.insert_new_user(
user_id, name, email, profile_image_url, role, oauth=oauth, db=db,
created_user = await Users.insert_new_user(
new_id, name, email, profile_image_url, role, oauth=oauth, db=session,
)
await db.commit()
await db.refresh(record)
return user if record and user else None
# persist both records and reload generated defaults
await session.commit()
await session.refresh(credential)
return created_user if credential and created_user else None
async def authenticate_user(
self, email: str, verify_password: callable, db: AsyncSession | None = None
self, email: str, verify_password: callable, db: AsyncSession | None = None,
) -> UserModel | None:
"""Verify a user's email + password and return the user on success."""
log.info(f'authenticate_user: {email}')
user = await Users.get_user_by_email(email, db=db)
if not user:
return None
try: # load the auth row for password verification
async with get_async_db_context(db) as session:
auth = await session.get(Auth, user.id)
if not auth or not auth.active:
return None
if not verify_password(auth.password):
return None
return user
except Exception:
"""Verify email + password credentials and return the matching user."""
log.info('authenticate_user: %s', email)
resolved = await Users.get_user_by_email(email, db=db)
if not resolved:
return
# load the credential row and verify the password hash
async with get_async_db_context(db) as session:
credential = await session.get(Auth, resolved.id)
if not credential or not credential.active:
return
if not verify_password(credential.password):
return
return resolved
async def authenticate_user_by_api_key(
self, api_key: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Resolve an API key to its owning user, returning ``None`` on miss."""
"""Look up the user that owns the given API key."""
log.info('authenticate_user_by_api_key')
if not api_key: # empty / None key — reject immediately
if not api_key:
return
try:
return await Users.get_user_by_api_key(api_key, db=db)
except Exception:
return False
# delegate to the Users model for the actual lookup
return await Users.get_user_by_api_key(api_key, db=db)
async def authenticate_user_by_email(
self,
email: str,
db: AsyncSession | None = None,
) -> UserModel | None:
"""One-query authentication: JOIN Auth ↔ User, filter by email + active flag."""
"""Single-query auth via JOIN on Auth ↔ User, filtered by active flag."""
log.info('authenticate_user_by_email: %s', email)
try:
async with get_async_db_context(db) as session:
stmt = (
select(Auth, User)
.join(User, Auth.id == User.id)
.filter(Auth.email == email, Auth.active == True)
)
row = (await session.execute(stmt)).first()
if not row:
return
_auth, matched_user = row
return UserModel.model_validate(matched_user)
except Exception:
return
async def update_user_password_by_id(
self,
id: str,
new_password: str,
db: AsyncSession | None = None,
) -> bool:
"""Hash-swap: replace the stored password hash for a given user."""
try:
async with get_async_db_context(db) as session:
stmt = update(Auth).filter_by(id=id).values(password=new_password)
result = await session.execute(stmt)
await session.commit()
return result.rowcount == 1
except Exception:
return False
# single JOIN avoids N+1 — returns (Auth, User) tuple or None
async with get_async_db_context(db) as session:
joined_query = (
select(Auth, User)
.join(User, Auth.id == User.id)
.where(Auth.email == email, Auth.active.is_(True))
)
match = (await session.execute(joined_query)).first()
if not match:
return
_, found_user = match
return UserModel.model_validate(found_user)
async def update_email_by_id(
self, id: str, email: str, db: AsyncSession | None = None,
self, user_id: str, email: str, db: AsyncSession | None = None,
) -> bool:
"""Update the auth email and propagate the change to the User table."""
try:
async with get_async_db_context(db) as session:
stmt = update(Auth).filter_by(id=id).values(email=email)
result = await session.execute(stmt)
await session.commit()
if result.rowcount != 1:
return False
await Users.update_user_by_id(id, {'email': email}, db=session)
return True
except Exception:
return False
"""Set a new email on the auth record and propagate to the user row."""
async with get_async_db_context(db) as session:
auth_row = await session.get(Auth, user_id)
if auth_row is None:
return False
auth_row.email = email
await session.commit()
await Users.update_user_by_id(user_id, {'email': email}, db=session)
return True
# --- password modification ---
async def update_user_password_by_id(
self, user_id: str, new_password: str, db: AsyncSession | None = None,
) -> bool:
"""Set a new password hash for an existing user."""
async with get_async_db_context(db) as session:
auth_row = await session.get(Auth, user_id)
if auth_row is None:
return False
auth_row.password = new_password
await session.commit()
return True
async def delete_auth_by_id(
self, id: str, db: AsyncSession | None = None,
) -> bool:
"""Delete a user and their auth record in a single transaction."""
try: # delete user first, then auth (FK order)
async with get_async_db_context(db) as session:
if not await Users.delete_user_by_id(id, db=session):
return False # user deletion failed — abort
await session.execute(delete(Auth).filter_by(id=id))
await session.commit()
return True
except Exception: # db / integrity error
return False # partial deletion is rolled back by context manager
"""Remove a user and their auth credential in one transaction."""
async with get_async_db_context(db) as session:
if not await Users.delete_user_by_id(id, db=session):
return False
await session.execute(delete(Auth).where(Auth.id == id))
await session.commit()
return True
Auths = AuthsTable()
Auths = AuthsTable() # singleton — module-level instance
+22 -24
View File
@@ -6,7 +6,7 @@ import json
import logging
import time
import uuid
# local imports
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.automations import AutomationRun
from open_webui.models.chat_messages import ChatMessage, ChatMessages
@@ -39,16 +39,16 @@ from sqlalchemy.sql.expression import bindparam
log = logging.getLogger(__name__)
class Chat(Base):
class Chat(Base): # database table mapping for chat entity
__tablename__ = 'chat'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
user_id = Column(String, index=True) # owner user id
title = Column(Text) # user-visible conversation title
chat = Column(JSON)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
created_at = Column(BigInteger, index=True) # conversation creation timestamp
updated_at = Column(BigInteger, index=True) # conversation modification timestamp
share_id = Column(Text, unique=True, nullable=True) # public share link token
archived = Column(Boolean, default=False) # hidden from main chat list
@@ -73,8 +73,7 @@ class Chat(Base):
class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True) # allows ORM model binding
id: str
user_id: str
title: str
@@ -676,22 +675,21 @@ class ChatTable:
await session.commit()
await session.refresh(chat)
return ChatModel.model_validate(chat) # return the updated original
# refresh helper
async def update_shared_chat_by_chat_id(
self, chat_id: str, db: AsyncSession | None = None,
) -> ChatModel | None:
"""Re-snapshot the shared copy with the latest chat content."""
from open_webui.models.shared_chats import SharedChats # deferred — circular
"""Refresh the shared snapshot with current chat content."""
from open_webui.models.shared_chats import SharedChats
try:
async with get_async_db_context(db) as session:
chat = await session.get(Chat, chat_id)
if not chat or not chat.share_id:
return await self.insert_shared_chat_by_chat_id(chat_id, db=session)
await SharedChats.update(chat.share_id, db=session)
return ChatModel.model_validate(chat)
except Exception:
return
async with get_async_db_context(db) as session:
record = await session.get(Chat, chat_id)
if not record or not record.share_id:
return await self.insert_shared_chat_by_chat_id(chat_id, db=session)
await SharedChats.update(record.share_id, db=session)
return ChatModel.model_validate(record)
# unreachable — context manager above always returns
return
async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool:
"""Delete shared snapshot for a chat."""
@@ -938,7 +936,7 @@ class ChatTable:
)
all_chats = result.scalars().all()
return [ChatModel.model_validate(chat) for chat in all_chats]
# retrieve conversation
async def get_chat_by_id(
self, id: str, db: AsyncSession | None = None,
) -> ChatModel | None:
@@ -1019,7 +1017,7 @@ class ChatTable:
result = await session.execute(select(Chat).order_by(Chat.updated_at.desc()))
all_chats = result.scalars().all()
return [ChatModel.model_validate(chat) for chat in all_chats]
# list user conversations
async def get_chats_by_user_id(
self,
user_id: str,
@@ -1067,7 +1065,7 @@ class ChatTable:
'total': total,
}
)
# list pinned chats
async def get_pinned_chats_by_user_id(
self, user_id: str, db: AsyncSession | None = None
) -> list[ChatTitleIdResponse]:
@@ -1097,7 +1095,7 @@ class ChatTable:
select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in result.scalars().all()]
# search user conversations
async def get_chats_by_user_id_and_search_text(
self,
user_id: str,
@@ -1712,4 +1710,4 @@ class ChatTable:
return row[0]
Chats = ChatTable()
Chats = ChatTable() # singleton chats repository
+6 -6
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import time
# local imports
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.utils.misc import sanitize_metadata
from pydantic import BaseModel, ConfigDict, model_validator
@@ -14,10 +14,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
class File(Base):
class File(Base): # uploaded file record
__tablename__ = 'file'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
user_id = Column(String, index=True) # owner user id
hash = Column(Text, nullable=True)
filename = Column(Text) # original upload filename
@@ -26,7 +26,7 @@ class File(Base):
data = Column(JSON, nullable=True)
meta = Column(JSON, nullable=True)
created_at = Column(BigInteger)
created_at = Column(BigInteger, index=True) # upload timestamp
updated_at = Column(BigInteger)
@@ -46,7 +46,7 @@ class FileModel(BaseModel):
created_at: int | None # timestamp in epoch
updated_at: int | None # timestamp in epoch
# --- metadata structures ---
class FileMeta(BaseModel):
name: str | None = None
content_type: str | None = None
@@ -410,4 +410,4 @@ class FilesTable:
return False
Files = FilesTable()
Files = FilesTable() # singleton files repository
+12 -12
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import time
# local imports
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.users import UserModel, UserResponse, Users
from pydantic import BaseModel, ConfigDict
@@ -14,17 +14,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
class Function(Base):
class Function(Base): # database table mapping
__tablename__ = 'function'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
name = Column(Text)
type = Column(Text)
content = Column(Text)
meta = Column(JSONField)
valves = Column(JSONField)
is_active = Column(Boolean)
user_id = Column(String, index=True) # creator user id
name = Column(Text, nullable=False) # function identifier
type = Column(Text, nullable=False) # function type (pipe, filter, etc.)
content = Column(Text, nullable=True) # Python source code
meta = Column(JSONField, nullable=True) # function metadata
valves = Column(JSONField, nullable=True) # function configuration valves
is_active = Column(Boolean, default=False) # function activation status
is_global = Column(Boolean) # if True, applied to every chat automatically
updated_at = Column(BigInteger) # epoch seconds
created_at = Column(BigInteger) # epoch seconds
@@ -50,9 +50,9 @@ class FunctionModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True) # allows ORM model binding
# --- form / schema definitions ---
class FunctionWithValvesModel(BaseModel):
id: str
user_id: str
@@ -425,4 +425,4 @@ class FunctionsTable:
return False
Functions = FunctionsTable()
Functions = FunctionsTable() # singleton functions engine
+4 -5
View File
@@ -12,8 +12,8 @@ from sqlalchemy import BigInteger, Column, String, Text, delete, select
from sqlalchemy.ext.asyncio import AsyncSession
class Memory(Base):
"""Persistent user memory backed by a vector collection."""
class Memory(Base): # user memory store
"""Stores user-created memory entries linked to a vector collection."""
__tablename__ = 'memory'
@@ -32,8 +32,7 @@ class MemoryModel(BaseModel):
content: str
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True) # allows ORM mapping
class MemoriesTable:
async def insert_new_memory(
@@ -140,4 +139,4 @@ class MemoriesTable:
return False
Memories = MemoriesTable()
Memories = MemoriesTable() # user memory registry
+2 -2
View File
@@ -59,7 +59,7 @@ class ModelMeta(BaseModel):
return data
class Model(Base):
class Model(Base): # provider model config
"""Workspace model entry — wraps an upstream LLM with custom params and metadata."""
__tablename__ = 'model'
@@ -573,4 +573,4 @@ class ModelsTable:
return []
Models = ModelsTable()
Models = ModelsTable() # singleton model registry
+37 -37
View File
@@ -17,14 +17,14 @@ from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, de
from sqlalchemy.ext.asyncio import AsyncSession
class Prompt(Base):
"""System prompt template with versioning support."""
class Prompt(Base): # versioned template
"""Slash-command prompt with history tracking and access control."""
__tablename__ = 'prompt'
id = Column(Text, primary_key=True)
command = Column(String, unique=True, index=True)
user_id = Column(String)
user_id = Column(String, index=True) # owner user id
name = Column(Text)
content = Column(Text) # the prompt template body
data = Column(JSON, nullable=True) # structured prompt parameters
@@ -51,10 +51,8 @@ class PromptModel(BaseModel):
updated_at: int | None = None
access_grants: list[AccessGrantModel] = Field(default_factory=list)
model_config = ConfigDict(from_attributes=True)
####################
model_config = ConfigDict(from_attributes=True) # allows ORM model binding
# --- form / schema definitions ---
# Forms
####################
@@ -175,34 +173,33 @@ class PromptsTable:
return
async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None:
try:
async with get_async_db_context(db) as session:
result = await session.execute(
select(Prompt).filter_by(command=command),
)
prompt = result.scalars().first() # None when no match
if not prompt:
return None
return await self._to_prompt_model(prompt, db=session)
except Exception: # connection / integrity error
return
"""Look up a prompt by its unique slash-command string."""
async with get_async_db_context(db) as session:
match = (await session.execute(
select(Prompt).where(Prompt.command == command)
)).scalars().first()
if match is None:
return
return await self._to_prompt_model(match, db=session)
# --- context manager always returns above ---
return
async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]:
"""Return all active prompts ordered by most recently updated."""
async with get_async_db_context(db) as session:
result = await session.execute(
select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc())
)
all_prompts = result.scalars().all()
active = (await session.execute(
select(Prompt).where(Prompt.is_active.is_(True)).order_by(Prompt.updated_at.desc())
)).scalars().all()
user_ids = list(set(prompt.user_id for prompt in all_prompts))
prompt_ids = [prompt.id for prompt in all_prompts]
user_ids = list(set(p.user_id for p in active))
prompt_ids = [p.id for p in active]
users = await Users.get_users_by_user_ids(user_ids, db=session) if user_ids else []
users_dict = {user.id: user for user in users}
users_dict = {u.id: u for u in users}
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session)
prompts = []
for prompt in all_prompts:
for prompt in active:
user = users_dict.get(prompt.user_id)
prompts.append(
PromptUserResponse.model_validate(
@@ -399,7 +396,9 @@ class PromptsTable:
user_id: str,
db: AsyncSession | None = None,
) -> PromptModel | None:
try:
if not command:
return None
try: # database transaction
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(command=command))
prompt = result.scalars().first()
@@ -596,16 +595,16 @@ class PromptsTable:
await session.commit()
return await self._to_prompt_model(prompt, db=session)
except Exception: # connection error
return None
# --- Active state management ---
except Exception as e: # connection error
log.error(f"Failed to restore prompt version: {e}")
return None # restoration failed
async def toggle_prompt_active(
self, prompt_id: str, db: AsyncSession | None = None,
) -> PromptModel | None:
"""Toggle the is_active flag on a prompt."""
try:
"""Flip the is_active flag on a prompt."""
if not prompt_id:
return None
try: # activation state toggle
async with get_async_db_context(db) as session:
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
prompt = result.scalars().first()
@@ -650,8 +649,9 @@ class PromptsTable:
await session.commit()
return True
return False
except Exception:
return False
except Exception as err:
log.error(f"Failed to delete prompt: {err}")
return False # deletion failed
async def get_tags(self, db: AsyncSession | None = None) -> list[str]:
try:
@@ -695,4 +695,4 @@ class PromptsTable:
return []
Prompts = PromptsTable()
Prompts = PromptsTable() # singleton prompts registry
+7 -8
View File
@@ -5,7 +5,7 @@ from __future__ import annotations
import logging
import time
import uuid
# local imports
from open_webui.internal.db import Base, JSONField, get_async_db_context
from pydantic import BaseModel, ConfigDict
from sqlalchemy import JSON, BigInteger, Column, Index, PrimaryKeyConstraint, String, delete, select
@@ -14,11 +14,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
class Tag(Base):
class Tag(Base): # database table mapping for tag entity
__tablename__ = 'tag'
id = Column(String)
name = Column(String)
user_id = Column(String)
name = Column(String, index=True) # tag label
user_id = Column(String, index=True) # user identifier
meta = Column(JSON, nullable=True)
__table_args__ = (
@@ -35,10 +35,9 @@ class TagModel(BaseModel):
name: str
user_id: str
meta: dict | None = None
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True) # allows ORM model binding
####################
# --- tag schema forms ---
# Forms
####################
@@ -132,4 +131,4 @@ class TagTable:
await db.commit()
Tags = TagTable()
Tags = TagTable() # singleton tag repository
+8 -9
View File
@@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import time
# local imports
from open_webui.internal.db import Base, JSONField, get_async_db_context
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
from open_webui.models.groups import Groups
@@ -16,19 +16,19 @@ from sqlalchemy.ext.asyncio import AsyncSession
log = logging.getLogger(__name__)
class Tool(Base):
class Tool(Base): # database table definition
__tablename__ = 'tool'
id = Column(String, primary_key=True, unique=True)
user_id = Column(String)
user_id = Column(String, index=True) # owner user id
name = Column(Text) # human-readable label
content = Column(Text) # Python source code
specs = Column(JSONField) # OpenAPI-style function specs
meta = Column(JSONField) # description, manifest, etc.
valves = Column(JSONField) # admin-configurable runtime parameters
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
updated_at = Column(BigInteger, nullable=False) # modification timestamp
created_at = Column(BigInteger, index=True) # creation timestamp
class ToolMeta(BaseModel):
@@ -48,10 +48,9 @@ class ToolModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
model_config = ConfigDict(from_attributes=True) # enables ORM mapping
####################
# --- tool request forms ---
# Forms
####################
@@ -316,4 +315,4 @@ class ToolsTable:
return False
Tools = ToolsTable()
Tools = ToolsTable() # singleton tool registry
+193 -311
View File
@@ -43,17 +43,15 @@ class UserSettings(BaseModel):
pass
class User(Base):
"""Core user identity and profile record."""
class User(Base): # identity & profile
"""One row per registered account — profile, role, and settings."""
__tablename__ = 'user'
# Identity
id = Column(String, primary_key=True, unique=True)
email = Column(String)
username = Column(String(50), nullable=True)
role = Column(String)
name = Column(String)
__tablename__: str = 'user' # Identity & Credentials
id = Column(String, primary_key=True, unique=True) # unique user id
email = Column(String, unique=True) # user email address
username = Column(String(50), nullable=True) # custom handle
role = Column(String, default="pending") # permissions role
name = Column(String, nullable=False) # display name
# Profile
profile_image_url = Column(Text) # data-uri, path, or external URL
@@ -119,16 +117,18 @@ class UserModel(BaseModel):
model_config = ConfigDict(
from_attributes=True,
)
@model_validator(mode='after') # runs after all field validators
# validation schema logic
# --- model validators ---
@model_validator(mode='after')
def _ensure_profile_image(self) -> 'UserModel':
"""Fall back to a generated avatar when no explicit profile image is set."""
if self.profile_image_url: # explicit image — nothing to do
return self
self.profile_image_url = _DEFAULT_PROFILE_IMAGE_URL.format(
user_id=self.id,
"""Assign a generated avatar when no profile image is provided."""
self.profile_image_url = (
self.profile_image_url
or _DEFAULT_PROFILE_IMAGE_URL.format(user_id=self.id)
)
return self # modified in-place
return self
class UserStatusModel(UserModel):
@@ -302,97 +302,77 @@ class UsersTable:
session.add(result)
await session.commit()
await session.refresh(result)
if result:
return user
else:
return None
return user if result else None
# database read methods
# --- read / lookup operations ---
async def get_user_by_id(
self, id: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Fetch a single user by primary key."""
try: # db.get is a PK lookup — very cheap
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if not user:
return None
return UserModel.model_validate(user)
except Exception: # stale session / connection drop
return None # caller treats None as "not found"
async with get_async_db_context(db) as session:
user = await session.get(User, id)
return UserModel.model_validate(user) if user else None
# api key auth helper
async def get_user_by_api_key(
self, api_key: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Resolve a user from their API key via a JOIN on the api_key table."""
try: # single JOIN instead of two round-trips
async with get_async_db_context(db) as session:
result = await session.execute(
select(User)
.join(ApiKey, User.id == ApiKey.user_id)
.filter(ApiKey.key == api_key),
) # emits: SELECT user.* FROM user JOIN api_key …
user = result.scalars().first() # None when key is invalid
if not user:
return None
return UserModel.model_validate(user)
except Exception: # stale session / connection drop
return
async with get_async_db_context(db) as session:
result = await session.execute(
select(User)
.join(ApiKey, User.id == ApiKey.user_id)
.where(ApiKey.key == api_key),
)
user = result.scalars().first()
return UserModel.model_validate(user) if user else None
async def get_user_by_email(
self, email: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Case-insensitive email lookup using SQL ``lower()``."""
try:
async with get_async_db_context(db) as session:
stmt = select(User).filter(
func.lower(User.email) == email.lower(),
)
row = (await session.execute(stmt)).scalars().first()
if not row:
return
return UserModel.model_validate(row)
except Exception:
return
"""Case-insensitive email lookup using SQL lower()."""
async with get_async_db_context(db) as session:
email_filter = func.lower(User.email) == email.lower()
query = select(User).where(email_filter)
match = (await session.execute(query)).scalars().first()
if match is None:
return
return UserModel.model_validate(match)
# --- context manager above always returns ---
return
# --- oauth & integrations ---
async def get_user_by_oauth_sub(
self, provider: str, sub: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Look up a user by OAuth provider + subject claim (dialect-aware JSON filter)."""
try:
async with get_async_db_context(db) as session:
dialect = session.bind.dialect.name
query = select(User)
if dialect == 'sqlite':
oauth_match = User.oauth.contains({provider: {'sub': sub}})
query = query.filter(oauth_match)
elif dialect == 'postgresql':
oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub
query = query.filter(oauth_match)
row = (await session.execute(query)).scalars().first()
if not row:
return
return UserModel.model_validate(row)
except Exception:
return
async with get_async_db_context(db) as session:
dialect = session.bind.dialect.name
query = select(User)
if dialect == 'sqlite':
oauth_match = User.oauth.contains({provider: {'sub': sub}})
query = query.where(oauth_match)
elif dialect == 'postgresql':
oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub
query = query.where(oauth_match)
row = (await session.execute(query)).scalars().first()
return UserModel.model_validate(row) if row else None
async def get_user_by_scim_external_id(self, provider: str, external_id: str, db: AsyncSession | None = None) -> UserModel | None:
async def get_user_by_scim_external_id(
self, provider: str, external_id: str, db: AsyncSession | None = None,
) -> UserModel | None:
"""Look up a user by SCIM provider + external ID (dialect-aware JSON filter)."""
try:
async with get_async_db_context(db) as session:
dialect = session.bind.dialect.name
query = select(User)
if dialect == 'sqlite':
scim_match = User.scim.contains({provider: {'external_id': external_id}})
query = query.filter(scim_match)
elif dialect == 'postgresql':
scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id
query = query.filter(scim_match)
row = (await session.execute(query)).scalars().first()
if not row:
return
return UserModel.model_validate(row)
except Exception:
return
async with get_async_db_context(db) as session:
dialect = session.bind.dialect.name
query = select(User)
if dialect == 'sqlite':
scim_match = User.scim.contains({provider: {'external_id': external_id}})
query = query.where(scim_match)
elif dialect == 'postgresql':
scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id
query = query.where(scim_match)
row = (await session.execute(query)).scalars().first()
return UserModel.model_validate(row) if row else None
async def get_users(
self, filter: dict | None = None, skip: int | None = None,
@@ -551,137 +531,95 @@ class UsersTable:
result = await session.execute(select(User).filter(User.id.in_(user_ids)))
users = result.scalars().all()
return [UserModel.model_validate(user) for user in users]
# count registered accounts
async def get_num_users(self, db: AsyncSession | None = None) -> int | None:
async with get_async_db_context(db) as session:
result = await session.execute(select(func.count()).select_from(User))
return result.scalar()
# check user existence
async def has_users(self, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as session:
result = await session.execute(select(exists(select(User))))
return result.scalar()
async def get_first_user(self, db: AsyncSession | None = None) -> UserModel:
async def get_first_user(self, db: AsyncSession | None = None) -> UserModel | None:
"""Return the earliest-created user (bootstrap admin detection)."""
try:
async with get_async_db_context(db) as session:
stmt = select(User).order_by(User.created_at).limit(1)
row = (await session.execute(stmt)).scalars().first()
if not row:
return
return UserModel.model_validate(row)
except Exception:
return
async with get_async_db_context(db) as session:
stmt = select(User).order_by(User.created_at).limit(1)
row = (await session.execute(stmt)).scalars().first()
return UserModel.model_validate(row) if row else None
async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
try:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if user.settings is None:
return None
else:
return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
except Exception:
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if user and user.settings:
return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
return None
async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None:
async with get_async_db_context(db) as session:
current_timestamp = int(datetime.datetime.now().timestamp())
current_timestamp = int(time.time())
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
result = await session.execute(
select(func.count()).select_from(User).filter(User.last_active_at > today_midnight_timestamp)
select(func.count()).select_from(User).where(User.last_active_at > today_midnight_timestamp)
)
return result.scalar()
async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as session:
user = (await session.execute(select(User).filter_by(id=id))).scalars().first()
if not user:
return
user.role = role
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
except Exception:
return
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if not user:
return None
user.role = role
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
async def update_user_status_by_id(
self, id: str, form_data: UserStatus, db: AsyncSession | None = None
) -> UserModel | None:
try:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
for key, value in form_data.model_dump(exclude_none=True).items():
setattr(user, key, value)
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
except Exception:
return None
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if not user:
return None
for key, value in form_data.model_dump(exclude_none=True).items():
setattr(user, key, value)
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
async def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str, db: AsyncSession | None = None,
) -> UserModel | None:
try:
async with get_async_db_context(db) as session:
row = (await session.execute(select(User).filter_by(id=id))).scalars().first()
if row is None:
return
row.profile_image_url = profile_image_url
await session.commit()
await session.refresh(row)
return UserModel.model_validate(row)
except Exception:
return
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if user is None:
return None
user.profile_image_url = profile_image_url
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None:
try:
async with get_async_db_context(db) as session:
await session.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time())))
await session.commit()
except Exception:
pass
async with get_async_db_context(db) as session:
await session.execute(update(User).where(User.id == id).values(last_active_at=int(time.time())))
await session.commit()
async def update_user_oauth_by_id(
self, id: str, provider: str, sub: str, db: AsyncSession | None = None
) -> UserModel | None:
"""
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
Example resulting structure:
{
"google": { "sub": "123" },
"github": { "sub": "abc" }
}
"""
try:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
# Load existing oauth JSON or create empty
oauth = user.oauth or {}
# Update or insert provider entry
oauth[provider] = {'sub': sub}
# Persist updated JSON
await session.execute(update(User).filter_by(id=id).values(oauth=oauth))
await session.commit()
return UserModel.model_validate(user)
except Exception:
return None
"""Update or insert an OAuth provider/sub pair into the user's oauth JSON field."""
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if not user:
return None
oauth = dict(user.oauth or {})
oauth[provider] = {'sub': sub}
user.oauth = oauth
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
async def update_user_scim_by_id(
self,
@@ -690,157 +628,101 @@ class UsersTable:
external_id: str,
db: AsyncSession | None = None,
) -> UserModel | None:
"""
Update or insert a SCIM provider/external_id pair into the user's scim JSON field.
Example resulting structure:
{
"microsoft": { "external_id": "abc" },
"okta": { "external_id": "def" }
}
"""
try:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
scim = user.scim or {}
scim[provider] = {'external_id': external_id}
await session.execute(update(User).filter_by(id=id).values(scim=scim))
await session.commit()
return UserModel.model_validate(user)
except Exception:
return None
"""Update or insert a SCIM provider/external_id pair into the user's scim JSON field."""
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if not user:
return None
scim = dict(user.scim or {})
scim[provider] = {'external_id': external_id}
user.scim = scim
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None:
try:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
for key, value in updated.items():
setattr(user, key, value)
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
except Exception as e:
print(e)
return None
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if not user:
return None
for key, value in updated.items():
setattr(user, key, value)
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
# settings update helper
async def update_user_settings_by_id(
self, id: str, updated: dict, db: AsyncSession | None = None
) -> UserModel | None:
try:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
if not user:
return None
user_settings = user.settings
if user_settings is None:
user_settings = {}
user_settings.update(updated)
await session.execute(update(User).filter_by(id=id).values(settings=user_settings))
await session.commit()
result = await session.execute(select(User).filter_by(id=id))
user = result.scalars().first()
return UserModel.model_validate(user)
except Exception:
return None
async with get_async_db_context(db) as session:
user = await session.get(User, id)
if not user:
return None
user_settings = dict(user.settings or {})
user_settings.update(updated)
user.settings = user_settings
await session.commit()
await session.refresh(user)
return UserModel.model_validate(user)
async def delete_user_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
from open_webui.models.chats import Chats
from open_webui.models.groups import Groups
# Remove User from Groups
await Groups.remove_user_from_all_groups(id)
# Remove User from Groups
await Groups.remove_user_from_all_groups(id)
# Delete User Chats
result = await Chats.delete_chats_by_user_id(id, db=session)
if result:
async with get_async_db_context(db) as session:
# Delete User
await session.execute(delete(User).filter_by(id=id))
await session.commit()
return True
else:
return False
except Exception:
return False
# Delete User Chats
async with get_async_db_context(db) as session:
deleted_chats = await Chats.delete_chats_by_user_id(id, db=session)
if not deleted_chats:
return False # chats deletion failed
await session.execute(delete(User).where(User.id == id))
await session.commit()
return True
async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
try:
async with get_async_db_context(db) as session:
result = await session.execute(select(ApiKey).filter_by(user_id=id))
api_key = result.scalars().first()
return api_key.key if api_key else None
except Exception:
return None
async with get_async_db_context(db) as session:
api_key = (await session.execute(select(ApiKey).where(ApiKey.user_id == id))).scalars().first()
return api_key.key if api_key else None
async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as session:
await session.execute(delete(ApiKey).filter_by(user_id=id))
await session.commit()
now = int(time.time())
new_api_key = ApiKey(
id=f'key_{id}',
user_id=id,
key=api_key,
created_at=now,
updated_at=now,
)
session.add(new_api_key)
await session.commit()
return True
except Exception:
return False
async with get_async_db_context(db) as session:
await session.execute(delete(ApiKey).where(ApiKey.user_id == id))
now_ts = int(time.time())
new_key = ApiKey(
id=f'key_{id}',
user_id=id,
key=api_key,
created_at=now_ts,
updated_at=now_ts,
)
session.add(new_key)
await session.commit()
return True
async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
try:
async with get_async_db_context(db) as session:
await session.execute(delete(ApiKey).filter_by(user_id=id))
await session.commit()
return True
except Exception:
return False
async with get_async_db_context(db) as session:
await session.execute(delete(ApiKey).where(ApiKey.user_id == id))
await session.commit()
return True
async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter(User.id.in_(user_ids)))
users = result.scalars().all()
return [user.id for user in users]
result = await session.execute(select(User).where(User.id.in_(user_ids)))
return [u.id for u in result.scalars().all()]
async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(role='admin').limit(1))
user = result.scalars().first()
if user:
return UserModel.model_validate(user)
else:
return None
row = (await session.execute(select(User).where(User.role == 'admin').limit(1))).scalars().first()
return UserModel.model_validate(row) if row else None
async def get_active_user_count(self, db: AsyncSession | None = None) -> int:
async with get_async_db_context(db) as session:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
result = await session.execute(
select(func.count()).select_from(User).filter(User.last_active_at >= three_minutes_ago)
select(func.count()).select_from(User).where(User.last_active_at >= three_minutes_ago)
)
return result.scalar()
@@ -854,8 +736,7 @@ class UsersTable:
async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool:
async with get_async_db_context(db) as session:
result = await session.execute(select(User).filter_by(id=user_id))
user = result.scalars().first()
user = await session.get(User, user_id)
if user and user.last_active_at:
# Consider user active if last_active_at within the last 3 minutes
three_minutes_ago = int(time.time()) - 180
@@ -863,4 +744,5 @@ class UsersTable:
return False
Users = UsersTable()
Users = UsersTable() # singleton user repository
+3 -4
View File
@@ -383,7 +383,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user), db: Asy
@router.post('/user/info/update', response_model=dict | None)
async def update_user_info_by_session_user(
async def update_user_info_by_session_user( # PATCH-style merge
form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
):
"""Merge caller-supplied fields into the current user's info dict.
@@ -538,9 +538,8 @@ async def get_user_active_status_by_id(
@router.post('/{user_id}/update', response_model=UserModel | None)
async def update_user_by_id(
user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
user_id: str, form_data: UserUpdateForm,
session_user: UserModel = Depends(get_admin_user),
db: AsyncSession = Depends(get_async_session),
):
# Prevent modification of the primary admin user by other admins
+5 -5
View File
@@ -100,17 +100,17 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from open_webui.internal.db import engine # deferred import
if engine.name != 'sqlite': # only SQLite DBs can be downloaded as a file
# --- resolve target database engine ---
from open_webui.internal.db import engine # lazy — avoids circular at import time
if engine.name != 'sqlite': # non-SQLite backends use pg_dump / managed exports
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse(
engine.url.database,
str(engine.url.database),
media_type='application/octet-stream',
filename='webui.db',
)