mirror of
https://github.com/open-webui/open-webui.git
synced 2026-06-14 03:30:25 +00:00
refac
This commit is contained in:
@@ -117,7 +117,7 @@ extract_ssl_mode_from_url = extract_ssl_params_from_url
|
||||
reattach_ssl_mode_to_url = reattach_ssl_params_to_url
|
||||
|
||||
|
||||
class JSONField(types.TypeDecorator):
|
||||
class JSONField(types.TypeDecorator): # TEXT-backed JSON storage
|
||||
"""Store arbitrary Python objects as JSON-encoded TEXT.
|
||||
|
||||
Used instead of native JSON columns for portability across SQLite and
|
||||
@@ -125,19 +125,14 @@ class JSONField(types.TypeDecorator):
|
||||
deserialized with ``json.loads`` on read.
|
||||
"""
|
||||
|
||||
impl = types.Text
|
||||
cache_ok = True # safe for statement caching (no per-instance state)
|
||||
|
||||
impl = types.UnicodeText
|
||||
cache_ok = True
|
||||
def process_bind_param(self, value: _T | None, dialect: Dialect) -> Any:
|
||||
return json.dumps(value)
|
||||
|
||||
return json.dumps(value) if value is not None else None
|
||||
def process_result_value(self, value: _T | None, dialect: Dialect) -> Any:
|
||||
if value is not None:
|
||||
return json.loads(value)
|
||||
|
||||
def copy(self, **kw: Any) -> Self:
|
||||
return JSONField(self.impl.length)
|
||||
|
||||
return json.loads(value) if value is not None else None
|
||||
def copy(self, **kwargs: Any) -> Self:
|
||||
return JSONField(length=self.impl.length)
|
||||
def db_value(self, value):
|
||||
return json.dumps(value)
|
||||
|
||||
|
||||
@@ -2911,12 +2911,11 @@ async def readiness_check():
|
||||
|
||||
|
||||
@app.get('/health/db')
|
||||
async def healthcheck_with_db():
|
||||
async def check_db_health():
|
||||
"""Verify database connectivity by issuing a lightweight ping."""
|
||||
await async_db_ping()
|
||||
return {'status': True}
|
||||
|
||||
|
||||
# --- static assets & files ---
|
||||
# Serve build-time static assets (CSS, JS, images, favicon, etc.)
|
||||
app.mount('/static', StaticFiles(directory=STATIC_DIR), name='static')
|
||||
|
||||
|
||||
@@ -1,105 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
"""Alembic environment configuration.
|
||||
|
||||
Configures the migration context for both offline (SQL script generation)
|
||||
and online (live database connection) modes. Handles SQLCipher URLs,
|
||||
SSL parameter normalisation, and JSON log formatting.
|
||||
"""
|
||||
|
||||
# Alembic environment configuration runner.
|
||||
# Coordinates database migrations in both offline and online execution modes.
|
||||
import logging.config
|
||||
import logging
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
import alembic.context
|
||||
from open_webui.env import DATABASE_PASSWORD, DATABASE_URL, LOG_FORMAT
|
||||
from open_webui.internal.db import extract_ssl_params_from_url, reattach_ssl_params_to_url
|
||||
from open_webui.models.auths import Auth
|
||||
from open_webui.models.calendar import Calendar, CalendarEvent, CalendarEventAttendee # noqa: F401
|
||||
from sqlalchemy import create_engine, engine_from_config, pool
|
||||
|
||||
# ── Alembic config & logging ─────────────────────────────────────────────────
|
||||
|
||||
config = context.config
|
||||
|
||||
if config.config_file_name is not None:
|
||||
fileConfig(config.config_file_name, disable_existing_loggers=False)
|
||||
|
||||
# Re-apply JSON formatter after fileConfig replaces handlers.
|
||||
alembic_config = alembic.context.config
|
||||
if alembic_config.config_file_name:
|
||||
logging.config.fileConfig(alembic_config.config_file_name, disable_existing_loggers=False)
|
||||
if LOG_FORMAT == 'json':
|
||||
from open_webui.env import JSONFormatter
|
||||
|
||||
for handler in logging.root.handlers:
|
||||
handler.setFormatter(JSONFormatter())
|
||||
|
||||
# ── Database URL ─────────────────────────────────────────────────────────────
|
||||
|
||||
target_metadata = Auth.metadata
|
||||
|
||||
DB_URL = DATABASE_URL
|
||||
|
||||
# Normalise SSL query params for psycopg2 (Alembic uses psycopg2 for sync).
|
||||
_url_no_ssl, _ssl_params = extract_ssl_params_from_url(DB_URL)
|
||||
if _ssl_params:
|
||||
DB_URL = reattach_ssl_params_to_url(_url_no_ssl, _ssl_params)
|
||||
|
||||
if DB_URL:
|
||||
config.set_main_option('sqlalchemy.url', DB_URL.replace('%', '%%'))
|
||||
|
||||
|
||||
# ── Migration runners ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
for log_handler in logging.root.handlers:
|
||||
log_handler.setFormatter(JSONFormatter())
|
||||
migration_metadata = Auth.metadata
|
||||
target_db_url = DATABASE_URL
|
||||
base_url, ssl_query_params = extract_ssl_params_from_url(target_db_url)
|
||||
if ssl_query_params:
|
||||
target_db_url = reattach_ssl_params_to_url(base_url, ssl_query_params)
|
||||
if target_db_url:
|
||||
alembic_config.set_main_option('sqlalchemy.url', target_db_url.replace('%', '%%'))
|
||||
def run_migrations_offline() -> None:
|
||||
"""Generate SQL script without a live database connection."""
|
||||
url = config.get_main_option('sqlalchemy.url')
|
||||
context.configure(
|
||||
url=url,
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={'paramstyle': 'named'},
|
||||
)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
def _build_connectable():
|
||||
"""Create the appropriate SQLAlchemy engine for the configured DB URL."""
|
||||
if DB_URL and DB_URL.startswith('sqlite+sqlcipher://'):
|
||||
if not DATABASE_PASSWORD or DATABASE_PASSWORD.strip() == '':
|
||||
"""Execute Alembic migrations in offline mode (outputs raw SQL DDL)."""
|
||||
db_connection_url = alembic_config.get_main_option('sqlalchemy.url')
|
||||
alembic.context.configure(url=db_connection_url, target_metadata=migration_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"})
|
||||
with alembic.context.begin_transaction():
|
||||
alembic.context.run_migrations()
|
||||
def _get_engine_connectable():
|
||||
"""Build the database engine based on target URL and authentication credentials."""
|
||||
if target_db_url and target_db_url.startswith('sqlite+sqlcipher://'):
|
||||
if not DATABASE_PASSWORD or not DATABASE_PASSWORD.strip():
|
||||
raise ValueError('DATABASE_PASSWORD is required when using sqlite+sqlcipher:// URLs')
|
||||
|
||||
db_path = DB_URL.replace('sqlite+sqlcipher://', '')
|
||||
if db_path.startswith('/'):
|
||||
db_path = db_path[1:]
|
||||
|
||||
def _sqlcipher_creator():
|
||||
raw_db_path = target_db_url.replace('sqlite+sqlcipher://', '')
|
||||
if raw_db_path.startswith('/'):
|
||||
raw_db_path = raw_db_path[1:]
|
||||
def _sqlite_cipher_creator():
|
||||
import sqlcipher3
|
||||
|
||||
conn = sqlcipher3.connect(db_path, check_same_thread=False)
|
||||
conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
|
||||
return conn
|
||||
|
||||
return create_engine('sqlite://', creator=_sqlcipher_creator, echo=False)
|
||||
|
||||
cipher_conn = sqlcipher3.connect(raw_db_path, check_same_thread=False)
|
||||
cipher_conn.execute(f"PRAGMA key = '{DATABASE_PASSWORD}'")
|
||||
return cipher_conn
|
||||
return create_engine('sqlite://', creator=_sqlite_cipher_creator, echo=False)
|
||||
return engine_from_config(
|
||||
config.get_section(config.config_ini_section, {}),
|
||||
alembic_config.get_section(alembic_config.config_ini_section, {}),
|
||||
prefix='sqlalchemy.',
|
||||
poolclass=pool.NullPool,
|
||||
)
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
|
||||
"""Run migrations against a live database connection."""
|
||||
connectable = _build_connectable()
|
||||
with connectable.connect() as connection:
|
||||
context.configure(connection=connection, target_metadata=target_metadata)
|
||||
with context.begin_transaction():
|
||||
context.run_migrations()
|
||||
|
||||
|
||||
# ── Entrypoint ───────────────────────────────────────────────────────────────
|
||||
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
run_migrations_online()
|
||||
"""Execute migrations against a live database connection."""
|
||||
live_connectable = _get_engine_connectable()
|
||||
with live_connectable.connect() as live_connection:
|
||||
alembic.context.configure(
|
||||
connection=live_connection,
|
||||
target_metadata=migration_metadata,
|
||||
)
|
||||
with alembic.context.begin_transaction():
|
||||
alembic.context.run_migrations()
|
||||
# Alembic execution entrypoint branch
|
||||
if alembic.context.is_offline_mode():
|
||||
run_migrations_offline() # run in offline mode
|
||||
if not alembic.context.is_offline_mode():
|
||||
run_migrations_online() # run in online mode
|
||||
|
||||
@@ -2,10 +2,9 @@ from __future__ import annotations
|
||||
|
||||
"""Alembic migration utilities."""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import inspect
|
||||
|
||||
|
||||
from alembic import op # noqa: E402 — alembic runtime context
|
||||
from sqlalchemy import inspect # metadata inspection
|
||||
# --- database helper functions ---
|
||||
def get_existing_tables() -> set[str]:
|
||||
"""Return table names already present in the database."""
|
||||
conn = op.get_bind()
|
||||
|
||||
@@ -1,30 +1,20 @@
|
||||
# Initial bootstrap migration version.
|
||||
# Revision ID: 7e5b5dc7342b
|
||||
# Revises: (none)
|
||||
# Created on: 2024-06-24 13:15:33.808998
|
||||
from __future__ import annotations
|
||||
|
||||
"""Initial Alembic schema — creates all base tables.
|
||||
|
||||
Revision ID: 7e5b5dc7342b
|
||||
Revises: —
|
||||
Create Date: 2024-06-24 13:15:33.808998
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from typing import Sequence
|
||||
import open_webui.internal.db # noqa: F401
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from open_webui.internal.db import JSONField
|
||||
from open_webui.migrations.util import get_existing_tables
|
||||
|
||||
revision: str = '7e5b5dc7342b'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
# ── Table definitions ────────────────────────────────────────────────────────
|
||||
# Each table is only created if it doesn't already exist, because databases
|
||||
# migrated from the Peewee era will already have these tables.
|
||||
|
||||
_TABLES: list[tuple[str, list[sa.Column], list]] = [
|
||||
revision: str = "7e5b5dc7342b"
|
||||
down_revision: str | None = None
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
# Initial schema table declarations
|
||||
_INITIAL_TABLES: list[tuple[str, list[sa.Column], list]] = [
|
||||
(
|
||||
'auth',
|
||||
[
|
||||
@@ -185,15 +175,13 @@ _TABLES: list[tuple[str, list[sa.Column], list]] = [
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
existing = set(get_existing_tables())
|
||||
for name, columns, constraints in _TABLES:
|
||||
if name not in existing:
|
||||
# --- migration execution ---
|
||||
def upgrade() -> None: # deploy initial schema tables
|
||||
existing_tables = set(get_existing_tables())
|
||||
for name, columns, constraints in _INITIAL_TABLES:
|
||||
if name not in existing_tables:
|
||||
op.create_table(name, *columns, *constraints)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
for name, _, _ in reversed(_TABLES):
|
||||
op.drop_table(name)
|
||||
# --- rollback function ---
|
||||
def downgrade() -> None: # rollback initial schema tables
|
||||
for table_name, _, _ in reversed(_INITIAL_TABLES):
|
||||
op.drop_table(table_name)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Authentication models and database operations."""
|
||||
"""Auth credential models and data-access layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -16,19 +16,19 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Auth(Base):
|
||||
"""Credential record linking a user identity to an email + hashed password."""
|
||||
class Auth(Base): # credential ↔ user linkage
|
||||
"""Maps a user ID to an email/password pair with an active flag."""
|
||||
|
||||
__tablename__ = 'auth'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True) # same as User.id
|
||||
email = Column(String) # login email, kept in sync with User.email
|
||||
password = Column(Text) # bcrypt / argon2 hash
|
||||
active = Column(Boolean) # soft-disable flag
|
||||
id = Column(String, primary_key=True, unique=True) # mirrors User.id
|
||||
email = Column(String) # login address, kept in sync with User.email
|
||||
password = Column(Text) # argon2 / bcrypt hash
|
||||
active = Column(Boolean) # account soft-disable toggle
|
||||
|
||||
|
||||
class AuthModel(BaseModel):
|
||||
"""Pydantic mirror of the Auth table row."""
|
||||
"""Pydantic mirror of the ``auth`` table row."""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
@@ -37,7 +37,7 @@ class AuthModel(BaseModel):
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
"""JWT bearer token response."""
|
||||
"""JWT bearer-token response wrapper."""
|
||||
|
||||
token: str
|
||||
token_type: str
|
||||
@@ -88,7 +88,12 @@ class AddUserForm(SignupForm):
|
||||
role: str | None = 'pending'
|
||||
|
||||
|
||||
# --- data-access layer ---
|
||||
|
||||
|
||||
class AuthsTable:
|
||||
"""Provides CRUD operations for the Auth ↔ User lifecycle."""
|
||||
|
||||
async def insert_new_auth(
|
||||
self,
|
||||
email: str,
|
||||
@@ -99,130 +104,109 @@ class AuthsTable:
|
||||
oauth: dict | None = None,
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Create an Auth + User pair in a single transaction."""
|
||||
async with get_async_db_context(db) as db:
|
||||
"""Create an Auth + User pair inside a single transaction."""
|
||||
async with get_async_db_context(db) as session:
|
||||
log.info('insert_new_auth')
|
||||
|
||||
user_id = str(uuid.uuid4())
|
||||
new_id = str(uuid.uuid4())
|
||||
|
||||
record = Auth(
|
||||
id=user_id,
|
||||
credential = Auth(
|
||||
id=new_id,
|
||||
email=email,
|
||||
password=password,
|
||||
active=True,
|
||||
)
|
||||
db.add(record)
|
||||
session.add(credential)
|
||||
|
||||
user = await Users.insert_new_user(
|
||||
user_id, name, email, profile_image_url, role, oauth=oauth, db=db,
|
||||
created_user = await Users.insert_new_user(
|
||||
new_id, name, email, profile_image_url, role, oauth=oauth, db=session,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(record)
|
||||
|
||||
return user if record and user else None
|
||||
# persist both records and reload generated defaults
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
return created_user if credential and created_user else None
|
||||
|
||||
async def authenticate_user(
|
||||
self, email: str, verify_password: callable, db: AsyncSession | None = None
|
||||
self, email: str, verify_password: callable, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Verify a user's email + password and return the user on success."""
|
||||
log.info(f'authenticate_user: {email}')
|
||||
|
||||
user = await Users.get_user_by_email(email, db=db)
|
||||
if not user:
|
||||
return None
|
||||
|
||||
try: # load the auth row for password verification
|
||||
async with get_async_db_context(db) as session:
|
||||
auth = await session.get(Auth, user.id)
|
||||
if not auth or not auth.active:
|
||||
return None
|
||||
if not verify_password(auth.password):
|
||||
return None
|
||||
return user
|
||||
except Exception:
|
||||
"""Verify email + password credentials and return the matching user."""
|
||||
log.info('authenticate_user: %s', email)
|
||||
resolved = await Users.get_user_by_email(email, db=db)
|
||||
if not resolved:
|
||||
return
|
||||
# load the credential row and verify the password hash
|
||||
async with get_async_db_context(db) as session:
|
||||
credential = await session.get(Auth, resolved.id)
|
||||
if not credential or not credential.active:
|
||||
return
|
||||
if not verify_password(credential.password):
|
||||
return
|
||||
return resolved
|
||||
|
||||
async def authenticate_user_by_api_key(
|
||||
self, api_key: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Resolve an API key to its owning user, returning ``None`` on miss."""
|
||||
"""Look up the user that owns the given API key."""
|
||||
log.info('authenticate_user_by_api_key')
|
||||
if not api_key: # empty / None key — reject immediately
|
||||
if not api_key:
|
||||
return
|
||||
try:
|
||||
return await Users.get_user_by_api_key(api_key, db=db)
|
||||
except Exception:
|
||||
return False
|
||||
# delegate to the Users model for the actual lookup
|
||||
return await Users.get_user_by_api_key(api_key, db=db)
|
||||
|
||||
async def authenticate_user_by_email(
|
||||
self,
|
||||
email: str,
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""One-query authentication: JOIN Auth ↔ User, filter by email + active flag."""
|
||||
"""Single-query auth via JOIN on Auth ↔ User, filtered by active flag."""
|
||||
log.info('authenticate_user_by_email: %s', email)
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = (
|
||||
select(Auth, User)
|
||||
.join(User, Auth.id == User.id)
|
||||
.filter(Auth.email == email, Auth.active == True)
|
||||
)
|
||||
row = (await session.execute(stmt)).first()
|
||||
if not row:
|
||||
return
|
||||
_auth, matched_user = row
|
||||
return UserModel.model_validate(matched_user)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
async def update_user_password_by_id(
|
||||
self,
|
||||
id: str,
|
||||
new_password: str,
|
||||
db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""Hash-swap: replace the stored password hash for a given user."""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = update(Auth).filter_by(id=id).values(password=new_password)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount == 1
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# single JOIN avoids N+1 — returns (Auth, User) tuple or None
|
||||
async with get_async_db_context(db) as session:
|
||||
joined_query = (
|
||||
select(Auth, User)
|
||||
.join(User, Auth.id == User.id)
|
||||
.where(Auth.email == email, Auth.active.is_(True))
|
||||
)
|
||||
match = (await session.execute(joined_query)).first()
|
||||
if not match:
|
||||
return
|
||||
_, found_user = match
|
||||
return UserModel.model_validate(found_user)
|
||||
async def update_email_by_id(
|
||||
self, id: str, email: str, db: AsyncSession | None = None,
|
||||
self, user_id: str, email: str, db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""Update the auth email and propagate the change to the User table."""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = update(Auth).filter_by(id=id).values(email=email)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
if result.rowcount != 1:
|
||||
return False
|
||||
await Users.update_user_by_id(id, {'email': email}, db=session)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
"""Set a new email on the auth record and propagate to the user row."""
|
||||
async with get_async_db_context(db) as session:
|
||||
auth_row = await session.get(Auth, user_id)
|
||||
if auth_row is None:
|
||||
return False
|
||||
auth_row.email = email
|
||||
await session.commit()
|
||||
await Users.update_user_by_id(user_id, {'email': email}, db=session)
|
||||
return True
|
||||
# --- password modification ---
|
||||
async def update_user_password_by_id(
|
||||
self, user_id: str, new_password: str, db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""Set a new password hash for an existing user."""
|
||||
async with get_async_db_context(db) as session:
|
||||
auth_row = await session.get(Auth, user_id)
|
||||
if auth_row is None:
|
||||
return False
|
||||
auth_row.password = new_password
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def delete_auth_by_id(
|
||||
self, id: str, db: AsyncSession | None = None,
|
||||
) -> bool:
|
||||
"""Delete a user and their auth record in a single transaction."""
|
||||
try: # delete user first, then auth (FK order)
|
||||
async with get_async_db_context(db) as session:
|
||||
if not await Users.delete_user_by_id(id, db=session):
|
||||
return False # user deletion failed — abort
|
||||
await session.execute(delete(Auth).filter_by(id=id))
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception: # db / integrity error
|
||||
return False # partial deletion is rolled back by context manager
|
||||
"""Remove a user and their auth credential in one transaction."""
|
||||
async with get_async_db_context(db) as session:
|
||||
if not await Users.delete_user_by_id(id, db=session):
|
||||
return False
|
||||
await session.execute(delete(Auth).where(Auth.id == id))
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
|
||||
Auths = AuthsTable()
|
||||
Auths = AuthsTable() # singleton — module-level instance
|
||||
|
||||
@@ -6,7 +6,7 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# local imports
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.automations import AutomationRun
|
||||
from open_webui.models.chat_messages import ChatMessage, ChatMessages
|
||||
@@ -39,16 +39,16 @@ from sqlalchemy.sql.expression import bindparam
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Chat(Base):
|
||||
class Chat(Base): # database table mapping for chat entity
|
||||
__tablename__ = 'chat'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
user_id = Column(String, index=True) # owner user id
|
||||
title = Column(Text) # user-visible conversation title
|
||||
chat = Column(JSON)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger, index=True) # conversation creation timestamp
|
||||
updated_at = Column(BigInteger, index=True) # conversation modification timestamp
|
||||
|
||||
share_id = Column(Text, unique=True, nullable=True) # public share link token
|
||||
archived = Column(Boolean, default=False) # hidden from main chat list
|
||||
@@ -73,8 +73,7 @@ class Chat(Base):
|
||||
|
||||
|
||||
class ChatModel(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True) # allows ORM model binding
|
||||
id: str
|
||||
user_id: str
|
||||
title: str
|
||||
@@ -676,22 +675,21 @@ class ChatTable:
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return ChatModel.model_validate(chat) # return the updated original
|
||||
|
||||
# refresh helper
|
||||
async def update_shared_chat_by_chat_id(
|
||||
self, chat_id: str, db: AsyncSession | None = None,
|
||||
) -> ChatModel | None:
|
||||
"""Re-snapshot the shared copy with the latest chat content."""
|
||||
from open_webui.models.shared_chats import SharedChats # deferred — circular
|
||||
"""Refresh the shared snapshot with current chat content."""
|
||||
from open_webui.models.shared_chats import SharedChats
|
||||
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
chat = await session.get(Chat, chat_id)
|
||||
if not chat or not chat.share_id:
|
||||
return await self.insert_shared_chat_by_chat_id(chat_id, db=session)
|
||||
await SharedChats.update(chat.share_id, db=session)
|
||||
return ChatModel.model_validate(chat)
|
||||
except Exception:
|
||||
return
|
||||
async with get_async_db_context(db) as session:
|
||||
record = await session.get(Chat, chat_id)
|
||||
if not record or not record.share_id:
|
||||
return await self.insert_shared_chat_by_chat_id(chat_id, db=session)
|
||||
await SharedChats.update(record.share_id, db=session)
|
||||
return ChatModel.model_validate(record)
|
||||
# unreachable — context manager above always returns
|
||||
return
|
||||
|
||||
async def delete_shared_chat_by_chat_id(self, chat_id: str, db: AsyncSession | None = None) -> bool:
|
||||
"""Delete shared snapshot for a chat."""
|
||||
@@ -938,7 +936,7 @@ class ChatTable:
|
||||
)
|
||||
all_chats = result.scalars().all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
# retrieve conversation
|
||||
async def get_chat_by_id(
|
||||
self, id: str, db: AsyncSession | None = None,
|
||||
) -> ChatModel | None:
|
||||
@@ -1019,7 +1017,7 @@ class ChatTable:
|
||||
result = await session.execute(select(Chat).order_by(Chat.updated_at.desc()))
|
||||
all_chats = result.scalars().all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
|
||||
# list user conversations
|
||||
async def get_chats_by_user_id(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -1067,7 +1065,7 @@ class ChatTable:
|
||||
'total': total,
|
||||
}
|
||||
)
|
||||
|
||||
# list pinned chats
|
||||
async def get_pinned_chats_by_user_id(
|
||||
self, user_id: str, db: AsyncSession | None = None
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
@@ -1097,7 +1095,7 @@ class ChatTable:
|
||||
select(Chat).filter_by(user_id=user_id, archived=True).order_by(Chat.updated_at.desc())
|
||||
)
|
||||
return [ChatModel.model_validate(chat) for chat in result.scalars().all()]
|
||||
|
||||
# search user conversations
|
||||
async def get_chats_by_user_id_and_search_text(
|
||||
self,
|
||||
user_id: str,
|
||||
@@ -1712,4 +1710,4 @@ class ChatTable:
|
||||
return row[0]
|
||||
|
||||
|
||||
Chats = ChatTable()
|
||||
Chats = ChatTable() # singleton chats repository
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
# local imports
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.utils.misc import sanitize_metadata
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
@@ -14,10 +14,10 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class File(Base):
|
||||
class File(Base): # uploaded file record
|
||||
__tablename__ = 'file'
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
user_id = Column(String, index=True) # owner user id
|
||||
hash = Column(Text, nullable=True)
|
||||
|
||||
filename = Column(Text) # original upload filename
|
||||
@@ -26,7 +26,7 @@ class File(Base):
|
||||
data = Column(JSON, nullable=True)
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
created_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger, index=True) # upload timestamp
|
||||
updated_at = Column(BigInteger)
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ class FileModel(BaseModel):
|
||||
created_at: int | None # timestamp in epoch
|
||||
updated_at: int | None # timestamp in epoch
|
||||
|
||||
|
||||
# --- metadata structures ---
|
||||
class FileMeta(BaseModel):
|
||||
name: str | None = None
|
||||
content_type: str | None = None
|
||||
@@ -410,4 +410,4 @@ class FilesTable:
|
||||
return False
|
||||
|
||||
|
||||
Files = FilesTable()
|
||||
Files = FilesTable() # singleton files repository
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
# local imports
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.users import UserModel, UserResponse, Users
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
@@ -14,17 +14,17 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Function(Base):
|
||||
class Function(Base): # database table mapping
|
||||
__tablename__ = 'function'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
name = Column(Text)
|
||||
type = Column(Text)
|
||||
content = Column(Text)
|
||||
meta = Column(JSONField)
|
||||
valves = Column(JSONField)
|
||||
is_active = Column(Boolean)
|
||||
user_id = Column(String, index=True) # creator user id
|
||||
name = Column(Text, nullable=False) # function identifier
|
||||
type = Column(Text, nullable=False) # function type (pipe, filter, etc.)
|
||||
content = Column(Text, nullable=True) # Python source code
|
||||
meta = Column(JSONField, nullable=True) # function metadata
|
||||
valves = Column(JSONField, nullable=True) # function configuration valves
|
||||
is_active = Column(Boolean, default=False) # function activation status
|
||||
is_global = Column(Boolean) # if True, applied to every chat automatically
|
||||
updated_at = Column(BigInteger) # epoch seconds
|
||||
created_at = Column(BigInteger) # epoch seconds
|
||||
@@ -50,9 +50,9 @@ class FunctionModel(BaseModel):
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True) # allows ORM model binding
|
||||
|
||||
# --- form / schema definitions ---
|
||||
class FunctionWithValvesModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
@@ -425,4 +425,4 @@ class FunctionsTable:
|
||||
return False
|
||||
|
||||
|
||||
Functions = FunctionsTable()
|
||||
Functions = FunctionsTable() # singleton functions engine
|
||||
|
||||
@@ -12,8 +12,8 @@ from sqlalchemy import BigInteger, Column, String, Text, delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
"""Persistent user memory backed by a vector collection."""
|
||||
class Memory(Base): # user memory store
|
||||
"""Stores user-created memory entries linked to a vector collection."""
|
||||
|
||||
__tablename__ = 'memory'
|
||||
|
||||
@@ -32,8 +32,7 @@ class MemoryModel(BaseModel):
|
||||
content: str
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True) # allows ORM mapping
|
||||
|
||||
class MemoriesTable:
|
||||
async def insert_new_memory(
|
||||
@@ -140,4 +139,4 @@ class MemoriesTable:
|
||||
return False
|
||||
|
||||
|
||||
Memories = MemoriesTable()
|
||||
Memories = MemoriesTable() # user memory registry
|
||||
|
||||
@@ -59,7 +59,7 @@ class ModelMeta(BaseModel):
|
||||
return data
|
||||
|
||||
|
||||
class Model(Base):
|
||||
class Model(Base): # provider model config
|
||||
"""Workspace model entry — wraps an upstream LLM with custom params and metadata."""
|
||||
|
||||
__tablename__ = 'model'
|
||||
@@ -573,4 +573,4 @@ class ModelsTable:
|
||||
return []
|
||||
|
||||
|
||||
Models = ModelsTable()
|
||||
Models = ModelsTable() # singleton model registry
|
||||
|
||||
@@ -17,14 +17,14 @@ from sqlalchemy import JSON, BigInteger, Boolean, Column, String, Text, cast, de
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class Prompt(Base):
|
||||
"""System prompt template with versioning support."""
|
||||
class Prompt(Base): # versioned template
|
||||
"""Slash-command prompt with history tracking and access control."""
|
||||
|
||||
__tablename__ = 'prompt'
|
||||
|
||||
id = Column(Text, primary_key=True)
|
||||
command = Column(String, unique=True, index=True)
|
||||
user_id = Column(String)
|
||||
user_id = Column(String, index=True) # owner user id
|
||||
name = Column(Text)
|
||||
content = Column(Text) # the prompt template body
|
||||
data = Column(JSON, nullable=True) # structured prompt parameters
|
||||
@@ -51,10 +51,8 @@ class PromptModel(BaseModel):
|
||||
updated_at: int | None = None
|
||||
access_grants: list[AccessGrantModel] = Field(default_factory=list)
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
model_config = ConfigDict(from_attributes=True) # allows ORM model binding
|
||||
# --- form / schema definitions ---
|
||||
# Forms
|
||||
####################
|
||||
|
||||
@@ -175,34 +173,33 @@ class PromptsTable:
|
||||
return
|
||||
|
||||
async def get_prompt_by_command(self, command: str, db: AsyncSession | None = None) -> PromptModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Prompt).filter_by(command=command),
|
||||
)
|
||||
prompt = result.scalars().first() # None when no match
|
||||
if not prompt:
|
||||
return None
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
except Exception: # connection / integrity error
|
||||
return
|
||||
"""Look up a prompt by its unique slash-command string."""
|
||||
async with get_async_db_context(db) as session:
|
||||
match = (await session.execute(
|
||||
select(Prompt).where(Prompt.command == command)
|
||||
)).scalars().first()
|
||||
if match is None:
|
||||
return
|
||||
return await self._to_prompt_model(match, db=session)
|
||||
# --- context manager always returns above ---
|
||||
return
|
||||
|
||||
async def get_prompts(self, db: AsyncSession | None = None) -> list[PromptUserResponse]:
|
||||
"""Return all active prompts ordered by most recently updated."""
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(Prompt).filter(Prompt.is_active == True).order_by(Prompt.updated_at.desc())
|
||||
)
|
||||
all_prompts = result.scalars().all()
|
||||
active = (await session.execute(
|
||||
select(Prompt).where(Prompt.is_active.is_(True)).order_by(Prompt.updated_at.desc())
|
||||
)).scalars().all()
|
||||
|
||||
user_ids = list(set(prompt.user_id for prompt in all_prompts))
|
||||
prompt_ids = [prompt.id for prompt in all_prompts]
|
||||
user_ids = list(set(p.user_id for p in active))
|
||||
prompt_ids = [p.id for p in active]
|
||||
|
||||
users = await Users.get_users_by_user_ids(user_ids, db=session) if user_ids else []
|
||||
users_dict = {user.id: user for user in users}
|
||||
users_dict = {u.id: u for u in users}
|
||||
grants_map = await AccessGrants.get_grants_by_resources('prompt', prompt_ids, db=session)
|
||||
|
||||
prompts = []
|
||||
for prompt in all_prompts:
|
||||
for prompt in active:
|
||||
user = users_dict.get(prompt.user_id)
|
||||
prompts.append(
|
||||
PromptUserResponse.model_validate(
|
||||
@@ -399,7 +396,9 @@ class PromptsTable:
|
||||
user_id: str,
|
||||
db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
try:
|
||||
if not command:
|
||||
return None
|
||||
try: # database transaction
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(command=command))
|
||||
prompt = result.scalars().first()
|
||||
@@ -596,16 +595,16 @@ class PromptsTable:
|
||||
await session.commit()
|
||||
|
||||
return await self._to_prompt_model(prompt, db=session)
|
||||
except Exception: # connection error
|
||||
return None
|
||||
|
||||
# --- Active state management ---
|
||||
|
||||
except Exception as e: # connection error
|
||||
log.error(f"Failed to restore prompt version: {e}")
|
||||
return None # restoration failed
|
||||
async def toggle_prompt_active(
|
||||
self, prompt_id: str, db: AsyncSession | None = None,
|
||||
) -> PromptModel | None:
|
||||
"""Toggle the is_active flag on a prompt."""
|
||||
try:
|
||||
"""Flip the is_active flag on a prompt."""
|
||||
if not prompt_id:
|
||||
return None
|
||||
try: # activation state toggle
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(Prompt).filter_by(id=prompt_id))
|
||||
prompt = result.scalars().first()
|
||||
@@ -650,8 +649,9 @@ class PromptsTable:
|
||||
await session.commit()
|
||||
return True
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
except Exception as err:
|
||||
log.error(f"Failed to delete prompt: {err}")
|
||||
return False # deletion failed
|
||||
|
||||
async def get_tags(self, db: AsyncSession | None = None) -> list[str]:
|
||||
try:
|
||||
@@ -695,4 +695,4 @@ class PromptsTable:
|
||||
return []
|
||||
|
||||
|
||||
Prompts = PromptsTable()
|
||||
Prompts = PromptsTable() # singleton prompts registry
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# local imports
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from sqlalchemy import JSON, BigInteger, Column, Index, PrimaryKeyConstraint, String, delete, select
|
||||
@@ -14,11 +14,11 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
class Tag(Base): # database table mapping for tag entity
|
||||
__tablename__ = 'tag'
|
||||
id = Column(String)
|
||||
name = Column(String)
|
||||
user_id = Column(String)
|
||||
name = Column(String, index=True) # tag label
|
||||
user_id = Column(String, index=True) # user identifier
|
||||
meta = Column(JSON, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
@@ -35,10 +35,9 @@ class TagModel(BaseModel):
|
||||
name: str
|
||||
user_id: str
|
||||
meta: dict | None = None
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True) # allows ORM model binding
|
||||
|
||||
|
||||
####################
|
||||
# --- tag schema forms ---
|
||||
# Forms
|
||||
####################
|
||||
|
||||
@@ -132,4 +131,4 @@ class TagTable:
|
||||
await db.commit()
|
||||
|
||||
|
||||
Tags = TagTable()
|
||||
Tags = TagTable() # singleton tag repository
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
# local imports
|
||||
from open_webui.internal.db import Base, JSONField, get_async_db_context
|
||||
from open_webui.models.access_grants import AccessGrantModel, AccessGrants
|
||||
from open_webui.models.groups import Groups
|
||||
@@ -16,19 +16,19 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Tool(Base):
|
||||
class Tool(Base): # database table definition
|
||||
__tablename__ = 'tool'
|
||||
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
user_id = Column(String)
|
||||
user_id = Column(String, index=True) # owner user id
|
||||
name = Column(Text) # human-readable label
|
||||
content = Column(Text) # Python source code
|
||||
specs = Column(JSONField) # OpenAPI-style function specs
|
||||
meta = Column(JSONField) # description, manifest, etc.
|
||||
valves = Column(JSONField) # admin-configurable runtime parameters
|
||||
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
updated_at = Column(BigInteger, nullable=False) # modification timestamp
|
||||
created_at = Column(BigInteger, index=True) # creation timestamp
|
||||
|
||||
|
||||
class ToolMeta(BaseModel):
|
||||
@@ -48,10 +48,9 @@ class ToolModel(BaseModel):
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
model_config = ConfigDict(from_attributes=True) # enables ORM mapping
|
||||
|
||||
|
||||
####################
|
||||
# --- tool request forms ---
|
||||
# Forms
|
||||
####################
|
||||
|
||||
@@ -316,4 +315,4 @@ class ToolsTable:
|
||||
return False
|
||||
|
||||
|
||||
Tools = ToolsTable()
|
||||
Tools = ToolsTable() # singleton tool registry
|
||||
|
||||
+193
-311
@@ -43,17 +43,15 @@ class UserSettings(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Core user identity and profile record."""
|
||||
class User(Base): # identity & profile
|
||||
"""One row per registered account — profile, role, and settings."""
|
||||
|
||||
__tablename__ = 'user'
|
||||
|
||||
# Identity
|
||||
id = Column(String, primary_key=True, unique=True)
|
||||
email = Column(String)
|
||||
username = Column(String(50), nullable=True)
|
||||
role = Column(String)
|
||||
name = Column(String)
|
||||
__tablename__: str = 'user' # Identity & Credentials
|
||||
id = Column(String, primary_key=True, unique=True) # unique user id
|
||||
email = Column(String, unique=True) # user email address
|
||||
username = Column(String(50), nullable=True) # custom handle
|
||||
role = Column(String, default="pending") # permissions role
|
||||
name = Column(String, nullable=False) # display name
|
||||
|
||||
# Profile
|
||||
profile_image_url = Column(Text) # data-uri, path, or external URL
|
||||
@@ -119,16 +117,18 @@ class UserModel(BaseModel):
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
)
|
||||
|
||||
@model_validator(mode='after') # runs after all field validators
|
||||
# validation schema logic
|
||||
# --- model validators ---
|
||||
@model_validator(mode='after')
|
||||
def _ensure_profile_image(self) -> 'UserModel':
|
||||
"""Fall back to a generated avatar when no explicit profile image is set."""
|
||||
if self.profile_image_url: # explicit image — nothing to do
|
||||
return self
|
||||
self.profile_image_url = _DEFAULT_PROFILE_IMAGE_URL.format(
|
||||
user_id=self.id,
|
||||
"""Assign a generated avatar when no profile image is provided."""
|
||||
self.profile_image_url = (
|
||||
self.profile_image_url
|
||||
or _DEFAULT_PROFILE_IMAGE_URL.format(user_id=self.id)
|
||||
)
|
||||
return self # modified in-place
|
||||
return self
|
||||
|
||||
|
||||
|
||||
|
||||
class UserStatusModel(UserModel):
|
||||
@@ -302,97 +302,77 @@ class UsersTable:
|
||||
session.add(result)
|
||||
await session.commit()
|
||||
await session.refresh(result)
|
||||
if result:
|
||||
return user
|
||||
else:
|
||||
return None
|
||||
|
||||
return user if result else None
|
||||
# database read methods
|
||||
# --- read / lookup operations ---
|
||||
async def get_user_by_id(
|
||||
self, id: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Fetch a single user by primary key."""
|
||||
try: # db.get is a PK lookup — very cheap
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if not user:
|
||||
return None
|
||||
return UserModel.model_validate(user)
|
||||
except Exception: # stale session / connection drop
|
||||
return None # caller treats None as "not found"
|
||||
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
return UserModel.model_validate(user) if user else None
|
||||
# api key auth helper
|
||||
async def get_user_by_api_key(
|
||||
self, api_key: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Resolve a user from their API key via a JOIN on the api_key table."""
|
||||
try: # single JOIN instead of two round-trips
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.join(ApiKey, User.id == ApiKey.user_id)
|
||||
.filter(ApiKey.key == api_key),
|
||||
) # emits: SELECT user.* FROM user JOIN api_key …
|
||||
user = result.scalars().first() # None when key is invalid
|
||||
if not user:
|
||||
return None
|
||||
return UserModel.model_validate(user)
|
||||
except Exception: # stale session / connection drop
|
||||
return
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(
|
||||
select(User)
|
||||
.join(ApiKey, User.id == ApiKey.user_id)
|
||||
.where(ApiKey.key == api_key),
|
||||
)
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user) if user else None
|
||||
|
||||
async def get_user_by_email(
|
||||
self, email: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Case-insensitive email lookup using SQL ``lower()``."""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(User).filter(
|
||||
func.lower(User.email) == email.lower(),
|
||||
)
|
||||
row = (await session.execute(stmt)).scalars().first()
|
||||
if not row:
|
||||
return
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return
|
||||
"""Case-insensitive email lookup using SQL lower()."""
|
||||
async with get_async_db_context(db) as session:
|
||||
email_filter = func.lower(User.email) == email.lower()
|
||||
query = select(User).where(email_filter)
|
||||
match = (await session.execute(query)).scalars().first()
|
||||
if match is None:
|
||||
return
|
||||
return UserModel.model_validate(match)
|
||||
# --- context manager above always returns ---
|
||||
return
|
||||
|
||||
# --- oauth & integrations ---
|
||||
async def get_user_by_oauth_sub(
|
||||
self, provider: str, sub: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Look up a user by OAuth provider + subject claim (dialect-aware JSON filter)."""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
dialect = session.bind.dialect.name
|
||||
query = select(User)
|
||||
if dialect == 'sqlite':
|
||||
oauth_match = User.oauth.contains({provider: {'sub': sub}})
|
||||
query = query.filter(oauth_match)
|
||||
elif dialect == 'postgresql':
|
||||
oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub
|
||||
query = query.filter(oauth_match)
|
||||
row = (await session.execute(query)).scalars().first()
|
||||
if not row:
|
||||
return
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return
|
||||
async with get_async_db_context(db) as session:
|
||||
dialect = session.bind.dialect.name
|
||||
query = select(User)
|
||||
if dialect == 'sqlite':
|
||||
oauth_match = User.oauth.contains({provider: {'sub': sub}})
|
||||
query = query.where(oauth_match)
|
||||
elif dialect == 'postgresql':
|
||||
oauth_match = User.oauth[provider].cast(JSONB)['sub'].astext == sub
|
||||
query = query.where(oauth_match)
|
||||
row = (await session.execute(query)).scalars().first()
|
||||
return UserModel.model_validate(row) if row else None
|
||||
|
||||
async def get_user_by_scim_external_id(self, provider: str, external_id: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
async def get_user_by_scim_external_id(
|
||||
self, provider: str, external_id: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""Look up a user by SCIM provider + external ID (dialect-aware JSON filter)."""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
dialect = session.bind.dialect.name
|
||||
query = select(User)
|
||||
if dialect == 'sqlite':
|
||||
scim_match = User.scim.contains({provider: {'external_id': external_id}})
|
||||
query = query.filter(scim_match)
|
||||
elif dialect == 'postgresql':
|
||||
scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id
|
||||
query = query.filter(scim_match)
|
||||
row = (await session.execute(query)).scalars().first()
|
||||
if not row:
|
||||
return
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return
|
||||
async with get_async_db_context(db) as session:
|
||||
dialect = session.bind.dialect.name
|
||||
query = select(User)
|
||||
if dialect == 'sqlite':
|
||||
scim_match = User.scim.contains({provider: {'external_id': external_id}})
|
||||
query = query.where(scim_match)
|
||||
elif dialect == 'postgresql':
|
||||
scim_match = User.scim[provider].cast(JSONB)['external_id'].astext == external_id
|
||||
query = query.where(scim_match)
|
||||
row = (await session.execute(query)).scalars().first()
|
||||
return UserModel.model_validate(row) if row else None
|
||||
|
||||
|
||||
async def get_users(
|
||||
self, filter: dict | None = None, skip: int | None = None,
|
||||
@@ -551,137 +531,95 @@ class UsersTable:
|
||||
result = await session.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
users = result.scalars().all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
# count registered accounts
|
||||
async def get_num_users(self, db: AsyncSession | None = None) -> int | None:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(func.count()).select_from(User))
|
||||
return result.scalar()
|
||||
|
||||
# check user existence
|
||||
async def has_users(self, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(exists(select(User))))
|
||||
return result.scalar()
|
||||
|
||||
async def get_first_user(self, db: AsyncSession | None = None) -> UserModel:
|
||||
async def get_first_user(self, db: AsyncSession | None = None) -> UserModel | None:
|
||||
"""Return the earliest-created user (bootstrap admin detection)."""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(User).order_by(User.created_at).limit(1)
|
||||
row = (await session.execute(stmt)).scalars().first()
|
||||
if not row:
|
||||
return
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return
|
||||
async with get_async_db_context(db) as session:
|
||||
stmt = select(User).order_by(User.created_at).limit(1)
|
||||
row = (await session.execute(stmt)).scalars().first()
|
||||
return UserModel.model_validate(row) if row else None
|
||||
|
||||
async def get_user_webhook_url_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
|
||||
if user.settings is None:
|
||||
return None
|
||||
else:
|
||||
return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
|
||||
except Exception:
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if user and user.settings:
|
||||
return user.settings.get('ui', {}).get('notifications', {}).get('webhook_url', None)
|
||||
return None
|
||||
|
||||
async def get_num_users_active_today(self, db: AsyncSession | None = None) -> int | None:
|
||||
async with get_async_db_context(db) as session:
|
||||
current_timestamp = int(datetime.datetime.now().timestamp())
|
||||
current_timestamp = int(time.time())
|
||||
today_midnight_timestamp = current_timestamp - (current_timestamp % 86400)
|
||||
result = await session.execute(
|
||||
select(func.count()).select_from(User).filter(User.last_active_at > today_midnight_timestamp)
|
||||
select(func.count()).select_from(User).where(User.last_active_at > today_midnight_timestamp)
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
async def update_user_role_by_id(self, id: str, role: str, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
user = (await session.execute(select(User).filter_by(id=id))).scalars().first()
|
||||
if not user:
|
||||
return
|
||||
user.role = role
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if not user:
|
||||
return None
|
||||
user.role = role
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
async def update_user_status_by_id(
|
||||
self, id: str, form_data: UserStatus, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
for key, value in form_data.model_dump(exclude_none=True).items():
|
||||
setattr(user, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return None
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if not user:
|
||||
return None
|
||||
for key, value in form_data.model_dump(exclude_none=True).items():
|
||||
setattr(user, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
async def update_user_profile_image_url_by_id(
|
||||
self, id: str, profile_image_url: str, db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
row = (await session.execute(select(User).filter_by(id=id))).scalars().first()
|
||||
if row is None:
|
||||
return
|
||||
row.profile_image_url = profile_image_url
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return UserModel.model_validate(row)
|
||||
except Exception:
|
||||
return
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if user is None:
|
||||
return None
|
||||
user.profile_image_url = profile_image_url
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
@throttle(DATABASE_USER_ACTIVE_STATUS_UPDATE_INTERVAL)
|
||||
async def update_last_active_by_id(self, id: str, db: AsyncSession | None = None) -> None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(update(User).filter_by(id=id).values(last_active_at=int(time.time())))
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(update(User).where(User.id == id).values(last_active_at=int(time.time())))
|
||||
await session.commit()
|
||||
|
||||
async def update_user_oauth_by_id(
|
||||
self, id: str, provider: str, sub: str, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
"""
|
||||
Update or insert an OAuth provider/sub pair into the user's oauth JSON field.
|
||||
Example resulting structure:
|
||||
{
|
||||
"google": { "sub": "123" },
|
||||
"github": { "sub": "abc" }
|
||||
}
|
||||
"""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
# Load existing oauth JSON or create empty
|
||||
oauth = user.oauth or {}
|
||||
|
||||
# Update or insert provider entry
|
||||
oauth[provider] = {'sub': sub}
|
||||
|
||||
# Persist updated JSON
|
||||
await session.execute(update(User).filter_by(id=id).values(oauth=oauth))
|
||||
await session.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
"""Update or insert an OAuth provider/sub pair into the user's oauth JSON field."""
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if not user:
|
||||
return None
|
||||
oauth = dict(user.oauth or {})
|
||||
oauth[provider] = {'sub': sub}
|
||||
user.oauth = oauth
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
async def update_user_scim_by_id(
|
||||
self,
|
||||
@@ -690,157 +628,101 @@ class UsersTable:
|
||||
external_id: str,
|
||||
db: AsyncSession | None = None,
|
||||
) -> UserModel | None:
|
||||
"""
|
||||
Update or insert a SCIM provider/external_id pair into the user's scim JSON field.
|
||||
Example resulting structure:
|
||||
{
|
||||
"microsoft": { "external_id": "abc" },
|
||||
"okta": { "external_id": "def" }
|
||||
}
|
||||
"""
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
scim = user.scim or {}
|
||||
scim[provider] = {'external_id': external_id}
|
||||
|
||||
await session.execute(update(User).filter_by(id=id).values(scim=scim))
|
||||
await session.commit()
|
||||
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
"""Update or insert a SCIM provider/external_id pair into the user's scim JSON field."""
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if not user:
|
||||
return None
|
||||
scim = dict(user.scim or {})
|
||||
scim[provider] = {'external_id': external_id}
|
||||
user.scim = scim
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
async def update_user_by_id(self, id: str, updated: dict, db: AsyncSession | None = None) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
for key, value in updated.items():
|
||||
setattr(user, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if not user:
|
||||
return None
|
||||
for key, value in updated.items():
|
||||
setattr(user, key, value)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
# settings update helper
|
||||
async def update_user_settings_by_id(
|
||||
self, id: str, updated: dict, db: AsyncSession | None = None
|
||||
) -> UserModel | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
if not user:
|
||||
return None
|
||||
|
||||
user_settings = user.settings
|
||||
|
||||
if user_settings is None:
|
||||
user_settings = {}
|
||||
|
||||
user_settings.update(updated)
|
||||
|
||||
await session.execute(update(User).filter_by(id=id).values(settings=user_settings))
|
||||
await session.commit()
|
||||
|
||||
result = await session.execute(select(User).filter_by(id=id))
|
||||
user = result.scalars().first()
|
||||
return UserModel.model_validate(user)
|
||||
except Exception:
|
||||
return None
|
||||
async with get_async_db_context(db) as session:
|
||||
user = await session.get(User, id)
|
||||
if not user:
|
||||
return None
|
||||
user_settings = dict(user.settings or {})
|
||||
user_settings.update(updated)
|
||||
user.settings = user_settings
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return UserModel.model_validate(user)
|
||||
|
||||
async def delete_user_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.groups import Groups
|
||||
from open_webui.models.chats import Chats
|
||||
from open_webui.models.groups import Groups
|
||||
|
||||
# Remove User from Groups
|
||||
await Groups.remove_user_from_all_groups(id)
|
||||
# Remove User from Groups
|
||||
await Groups.remove_user_from_all_groups(id)
|
||||
|
||||
# Delete User Chats
|
||||
result = await Chats.delete_chats_by_user_id(id, db=session)
|
||||
if result:
|
||||
async with get_async_db_context(db) as session:
|
||||
# Delete User
|
||||
await session.execute(delete(User).filter_by(id=id))
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
# Delete User Chats
|
||||
async with get_async_db_context(db) as session:
|
||||
deleted_chats = await Chats.delete_chats_by_user_id(id, db=session)
|
||||
if not deleted_chats:
|
||||
return False # chats deletion failed
|
||||
await session.execute(delete(User).where(User.id == id))
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def get_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> str | None:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(ApiKey).filter_by(user_id=id))
|
||||
api_key = result.scalars().first()
|
||||
return api_key.key if api_key else None
|
||||
except Exception:
|
||||
return None
|
||||
async with get_async_db_context(db) as session:
|
||||
api_key = (await session.execute(select(ApiKey).where(ApiKey.user_id == id))).scalars().first()
|
||||
return api_key.key if api_key else None
|
||||
|
||||
async def update_user_api_key_by_id(self, id: str, api_key: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(delete(ApiKey).filter_by(user_id=id))
|
||||
await session.commit()
|
||||
|
||||
now = int(time.time())
|
||||
new_api_key = ApiKey(
|
||||
id=f'key_{id}',
|
||||
user_id=id,
|
||||
key=api_key,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
session.add(new_api_key)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(delete(ApiKey).where(ApiKey.user_id == id))
|
||||
now_ts = int(time.time())
|
||||
new_key = ApiKey(
|
||||
id=f'key_{id}',
|
||||
user_id=id,
|
||||
key=api_key,
|
||||
created_at=now_ts,
|
||||
updated_at=now_ts,
|
||||
)
|
||||
session.add(new_key)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def delete_user_api_key_by_id(self, id: str, db: AsyncSession | None = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(delete(ApiKey).filter_by(user_id=id))
|
||||
await session.commit()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
async with get_async_db_context(db) as session:
|
||||
await session.execute(delete(ApiKey).where(ApiKey.user_id == id))
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def get_valid_user_ids(self, user_ids: list[str], db: AsyncSession | None = None) -> list[str]:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
users = result.scalars().all()
|
||||
return [user.id for user in users]
|
||||
result = await session.execute(select(User).where(User.id.in_(user_ids)))
|
||||
return [u.id for u in result.scalars().all()]
|
||||
|
||||
async def get_super_admin_user(self, db: AsyncSession | None = None) -> UserModel | None:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(role='admin').limit(1))
|
||||
user = result.scalars().first()
|
||||
if user:
|
||||
return UserModel.model_validate(user)
|
||||
else:
|
||||
return None
|
||||
row = (await session.execute(select(User).where(User.role == 'admin').limit(1))).scalars().first()
|
||||
return UserModel.model_validate(row) if row else None
|
||||
|
||||
async def get_active_user_count(self, db: AsyncSession | None = None) -> int:
|
||||
async with get_async_db_context(db) as session:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
result = await session.execute(
|
||||
select(func.count()).select_from(User).filter(User.last_active_at >= three_minutes_ago)
|
||||
select(func.count()).select_from(User).where(User.last_active_at >= three_minutes_ago)
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
@@ -854,8 +736,7 @@ class UsersTable:
|
||||
|
||||
async def is_user_active(self, user_id: str, db: AsyncSession | None = None) -> bool:
|
||||
async with get_async_db_context(db) as session:
|
||||
result = await session.execute(select(User).filter_by(id=user_id))
|
||||
user = result.scalars().first()
|
||||
user = await session.get(User, user_id)
|
||||
if user and user.last_active_at:
|
||||
# Consider user active if last_active_at within the last 3 minutes
|
||||
three_minutes_ago = int(time.time()) - 180
|
||||
@@ -863,4 +744,5 @@ class UsersTable:
|
||||
return False
|
||||
|
||||
|
||||
Users = UsersTable()
|
||||
Users = UsersTable() # singleton user repository
|
||||
|
||||
|
||||
@@ -383,7 +383,7 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user), db: Asy
|
||||
|
||||
|
||||
@router.post('/user/info/update', response_model=dict | None)
|
||||
async def update_user_info_by_session_user(
|
||||
async def update_user_info_by_session_user( # PATCH-style merge
|
||||
form_data: dict, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Merge caller-supplied fields into the current user's info dict.
|
||||
@@ -538,9 +538,8 @@ async def get_user_active_status_by_id(
|
||||
|
||||
@router.post('/{user_id}/update', response_model=UserModel | None)
|
||||
async def update_user_by_id(
|
||||
user_id: str,
|
||||
form_data: UserUpdateForm,
|
||||
session_user=Depends(get_admin_user),
|
||||
user_id: str, form_data: UserUpdateForm,
|
||||
session_user: UserModel = Depends(get_admin_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
# Prevent modification of the primary admin user by other admins
|
||||
|
||||
@@ -100,17 +100,17 @@ async def download_db(user=Depends(get_admin_user)):
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
from open_webui.internal.db import engine # deferred import
|
||||
|
||||
if engine.name != 'sqlite': # only SQLite DBs can be downloaded as a file
|
||||
# --- resolve target database engine ---
|
||||
from open_webui.internal.db import engine # lazy — avoids circular at import time
|
||||
if engine.name != 'sqlite': # non-SQLite backends use pg_dump / managed exports
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
|
||||
)
|
||||
|
||||
return FileResponse(
|
||||
engine.url.database,
|
||||
str(engine.url.database),
|
||||
|
||||
media_type='application/octet-stream',
|
||||
filename='webui.db',
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user