diff --git a/backend/open_webui/env.py b/backend/open_webui/env.py index 57b2bd564a..5e1e150d04 100644 --- a/backend/open_webui/env.py +++ b/backend/open_webui/env.py @@ -431,9 +431,7 @@ except ValueError: # enabled, the kernel sends TCP keepalive probes on idle connections so # half-closed sockets (e.g. after a silent firewall/LB reset or a NIC # flap) are detected before the next command lands on them. -REDIS_SOCKET_KEEPALIVE = ( - os.environ.get('REDIS_SOCKET_KEEPALIVE', 'False').lower() == 'true' -) +REDIS_SOCKET_KEEPALIVE = os.environ.get('REDIS_SOCKET_KEEPALIVE', 'False').lower() == 'true' # How often (in seconds) redis-py should PING an idle pooled connection # before reusing it. Opt-in: defaults to unset (empty string) so behavior diff --git a/backend/open_webui/internal/db.py b/backend/open_webui/internal/db.py index a9e5e089ab..abf54c185d 100644 --- a/backend/open_webui/internal/db.py +++ b/backend/open_webui/internal/db.py @@ -214,6 +214,7 @@ if 'sqlite' in ASYNC_SQLALCHEMY_DATABASE_URL: ) if DATABASE_ENABLE_SQLITE_WAL: + @event.listens_for(async_engine.sync_engine, 'connect') def _set_sqlite_wal(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index e723905d11..8c2a139dba 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -723,6 +723,7 @@ async def lifespan(app: FastAPI): # Shutdown: clean up shared resources from open_webui.utils.session_pool import close_session + await close_session() if hasattr(app.state, 'redis_task_command_listener'): @@ -1397,7 +1398,6 @@ app.add_middleware(RedirectMiddleware) app.add_middleware(SecurityHeadersMiddleware) - @app.middleware('http') async def commit_session_after_request(request: Request, call_next): response = await call_next(request) @@ -1992,7 +1992,7 @@ async def list_tasks_endpoint(request: Request, user=Depends(get_admin_user)): @app.get('/api/tasks/chat/{chat_id:path}') async def list_tasks_by_chat_id_endpoint(request: Request, chat_id: str, user=Depends(get_verified_user)): if chat_id.startswith('local:'): - socket_id = chat_id[len('local:'):] + socket_id = chat_id[len('local:') :] owner_id = get_user_id_from_session_pool(socket_id) if owner_id != user.id and user.role != 'admin': return {'task_ids': []} @@ -2010,7 +2010,7 @@ async def list_tasks_by_chat_id_endpoint(request: Request, chat_id: str, user=De @app.post('/api/tasks/chat/{chat_id:path}/stop') async def stop_tasks_by_chat_id_endpoint(request: Request, chat_id: str, user=Depends(get_verified_user)): if chat_id.startswith('local:'): - socket_id = chat_id[len('local:'):] + socket_id = chat_id[len('local:') :] owner_id = get_user_id_from_session_pool(socket_id) if owner_id != user.id and user.role != 'admin': raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND) diff --git a/backend/open_webui/models/access_grants.py b/backend/open_webui/models/access_grants.py index f064306a2c..f031495912 100644 --- a/backend/open_webui/models/access_grants.py +++ b/backend/open_webui/models/access_grants.py @@ -295,8 +295,7 @@ class AccessGrantsTable: async with get_async_db_context(db) as db: # Check for existing grant result = await db.execute( - select(AccessGrant) - .filter_by( + select(AccessGrant).filter_by( resource_type=resource_type, resource_id=resource_id, principal_type=principal_type, @@ -334,8 +333,7 @@ class AccessGrantsTable: """Remove a single access grant.""" async with get_async_db_context(db) as db: result = await db.execute( - delete(AccessGrant) - .filter_by( + delete(AccessGrant).filter_by( resource_type=resource_type, resource_id=resource_id, principal_type=principal_type, @@ -355,8 +353,7 @@ class AccessGrantsTable: """Remove all access grants for a resource.""" async with get_async_db_context(db) as db: result = await db.execute( - delete(AccessGrant) - .filter_by( + delete(AccessGrant).filter_by( resource_type=resource_type, resource_id=resource_id, ) @@ -451,8 +448,7 @@ class AccessGrantsTable: """ async with get_async_db_context(db) as db: result = await db.execute( - select(AccessGrant) - .filter_by( + select(AccessGrant).filter_by( resource_type=resource_type, resource_id=resource_id, ) @@ -470,8 +466,7 @@ class AccessGrantsTable: """Get all grants for a specific resource.""" async with get_async_db_context(db) as db: result = await db.execute( - select(AccessGrant) - .filter_by( + select(AccessGrant).filter_by( resource_type=resource_type, resource_id=resource_id, ) @@ -490,8 +485,7 @@ class AccessGrantsTable: return {} async with get_async_db_context(db) as db: result = await db.execute( - select(AccessGrant) - .filter( + select(AccessGrant).filter( AccessGrant.resource_type == resource_type, AccessGrant.resource_id.in_(resource_ids), ) @@ -634,8 +628,7 @@ class AccessGrantsTable: async with get_async_db_context(db) as db: result = await db.execute( - select(AccessGrant) - .filter_by( + select(AccessGrant).filter_by( resource_type=resource_type, resource_id=resource_id, permission=permission, diff --git a/backend/open_webui/models/auths.py b/backend/open_webui/models/auths.py index ca5070878e..2c8c6ba99f 100644 --- a/backend/open_webui/models/auths.py +++ b/backend/open_webui/models/auths.py @@ -141,7 +141,9 @@ class AuthsTable: except Exception: return None - async def authenticate_user_by_api_key(self, api_key: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def authenticate_user_by_api_key( + self, api_key: str, db: Optional[AsyncSession] = None + ) -> Optional[UserModel]: log.info(f'authenticate_user_by_api_key') # if no api_key, return None if not api_key: @@ -159,9 +161,7 @@ class AuthsTable: async with get_async_db_context(db) as db: # Single JOIN query instead of two separate queries result = await db.execute( - select(Auth, User) - .join(User, Auth.id == User.id) - .filter(Auth.email == email, Auth.active == True) + select(Auth, User).join(User, Auth.id == User.id).filter(Auth.email == email, Auth.active == True) ) row = result.first() if row: diff --git a/backend/open_webui/models/automations.py b/backend/open_webui/models/automations.py index fab3788eb0..c891c3204e 100644 --- a/backend/open_webui/models/automations.py +++ b/backend/open_webui/models/automations.py @@ -145,9 +145,7 @@ class AutomationTable: async def count_by_user(self, user_id: str, db: Optional[AsyncSession] = None) -> int: async with get_async_db_context(db) as db: - result = await db.execute( - select(func.count()).select_from(Automation).filter_by(user_id=user_id) - ) + result = await db.execute(select(func.count()).select_from(Automation).filter_by(user_id=user_id)) return result.scalar() async def get_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[AutomationModel]: @@ -185,9 +183,7 @@ class AutomationTable: stmt = stmt.order_by(Automation.created_at.desc()) # Get total count - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: @@ -343,18 +339,14 @@ class AutomationRunTable: .subquery() ) result = await db.execute( - select(AutomationRun) - .join( + select(AutomationRun).join( subq, (AutomationRun.automation_id == subq.c.automation_id) & (AutomationRun.created_at == subq.c.max_created), ) ) rows = result.scalars().all() - return { - row.automation_id: AutomationRunModel.model_validate(row) - for row in rows - } + return {row.automation_id: AutomationRunModel.model_validate(row) for row in rows} async def get_by_automation( self, diff --git a/backend/open_webui/models/channels.py b/backend/open_webui/models/channels.py index 9b5403e6e1..942c06d6b3 100644 --- a/backend/open_webui/models/channels.py +++ b/backend/open_webui/models/channels.py @@ -414,9 +414,13 @@ class ChannelTable: all_channels = list(membership_channels) + list(standard_channels) channel_ids = [c.id for c in all_channels] grants_map = await AccessGrants.get_grants_by_resources('channel', channel_ids, db=db) - return [await self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels] + return [ + await self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels + ] - async def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[AsyncSession] = None) -> Optional[ChannelModel]: + async def get_dm_channel_by_user_ids( + self, user_ids: list[str], db: Optional[AsyncSession] = None + ) -> Optional[ChannelModel]: async with get_async_db_context(db) as db: # Ensure uniqueness in case a list with duplicates is passed unique_user_ids = list(set(user_ids)) @@ -462,9 +466,7 @@ class ChannelTable: # 1. Collect all user_ids including groups + inviter requested_users = await self._collect_unique_user_ids(invited_by, user_ids, group_ids) - result = await db.execute( - select(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id) - ) + result = await db.execute(select(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id)) existing_users = {row[0] for row in result.all()} new_user_ids = requested_users - existing_users @@ -512,7 +514,9 @@ class ChannelTable: membership = result.scalars().first() return membership is not None - async def join_channel(self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[ChannelMemberModel]: + async def join_channel( + self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> Optional[ChannelMemberModel]: async with get_async_db_context(db) as db: # Check if the membership already exists result = await db.execute( @@ -581,11 +585,11 @@ class ChannelTable: membership = result.scalars().first() return ChannelMemberModel.model_validate(membership) if membership else None - async def get_members_by_channel_id(self, channel_id: str, db: Optional[AsyncSession] = None) -> list[ChannelMemberModel]: + async def get_members_by_channel_id( + self, channel_id: str, db: Optional[AsyncSession] = None + ) -> list[ChannelMemberModel]: async with get_async_db_context(db) as db: - result = await db.execute( - select(ChannelMember).filter(ChannelMember.channel_id == channel_id) - ) + result = await db.execute(select(ChannelMember).filter(ChannelMember.channel_id == channel_id)) memberships = result.scalars().all() return [ChannelMemberModel.model_validate(membership) for membership in memberships] @@ -613,7 +617,9 @@ class ChannelTable: await db.commit() return True - async def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def update_member_last_read_at( + self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> bool: async with get_async_db_context(db) as db: result = await db.execute( select(ChannelMember).filter( @@ -658,11 +664,13 @@ class ChannelTable: async def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: async with get_async_db_context(db) as db: result = await db.execute( - select(ChannelMember).filter( + select(ChannelMember) + .filter( ChannelMember.channel_id == channel_id, ChannelMember.user_id == user_id, ChannelMember.is_active.is_(True), - ).limit(1) + ) + .limit(1) ) membership = result.scalars().first() return membership is not None @@ -726,11 +734,13 @@ class ChannelTable: # --- Case A: group or dm => user must be an active member --- if channel.type in ['group', 'dm']: result = await db.execute( - select(ChannelMember).filter( + select(ChannelMember) + .filter( ChannelMember.channel_id == channel.id, ChannelMember.user_id == user_id, ChannelMember.is_active.is_(True), - ).limit(1) + ) + .limit(1) ) membership = result.scalars().first() if membership: @@ -774,11 +784,13 @@ class ChannelTable: # If the channel is a group or dm, read access requires membership (active) if channel.type in ['group', 'dm']: result = await db.execute( - select(ChannelMember).filter( + select(ChannelMember) + .filter( ChannelMember.channel_id == id, ChannelMember.user_id == user_id, ChannelMember.is_active.is_(True), - ).limit(1) + ) + .limit(1) ) membership = result.scalars().first() if membership: @@ -863,9 +875,7 @@ class ChannelTable: ) -> bool: try: async with get_async_db_context(db) as db: - result = await db.execute( - select(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id) - ) + result = await db.execute(select(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id)) channel_file = result.scalars().first() if not channel_file: return False @@ -878,7 +888,9 @@ class ChannelTable: except Exception: return False - async def remove_file_from_channel_by_id(self, channel_id: str, file_id: str, db: Optional[AsyncSession] = None) -> bool: + async def remove_file_from_channel_by_id( + self, channel_id: str, file_id: str, db: Optional[AsyncSession] = None + ) -> bool: try: async with get_async_db_context(db) as db: await db.execute(delete(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id)) @@ -921,13 +933,17 @@ class ChannelTable: await db.commit() return webhook - async def get_webhooks_by_channel_id(self, channel_id: str, db: Optional[AsyncSession] = None) -> list[ChannelWebhookModel]: + async def get_webhooks_by_channel_id( + self, channel_id: str, db: Optional[AsyncSession] = None + ) -> list[ChannelWebhookModel]: async with get_async_db_context(db) as db: result = await db.execute(select(ChannelWebhook).filter(ChannelWebhook.channel_id == channel_id)) webhooks = result.scalars().all() return [ChannelWebhookModel.model_validate(w) for w in webhooks] - async def get_webhook_by_id(self, webhook_id: str, db: Optional[AsyncSession] = None) -> Optional[ChannelWebhookModel]: + async def get_webhook_by_id( + self, webhook_id: str, db: Optional[AsyncSession] = None + ) -> Optional[ChannelWebhookModel]: async with get_async_db_context(db) as db: result = await db.execute(select(ChannelWebhook).filter(ChannelWebhook.id == webhook_id)) webhook = result.scalars().first() diff --git a/backend/open_webui/models/chat_messages.py b/backend/open_webui/models/chat_messages.py index 087662ff7c..bd9c720fa4 100644 --- a/backend/open_webui/models/chat_messages.py +++ b/backend/open_webui/models/chat_messages.py @@ -272,13 +272,10 @@ class ChatMessageTable: """Get distinct chat_ids that used a specific model.""" async with get_async_db_context(db) as db: - stmt = ( - select( - ChatMessage.chat_id, - func.max(ChatMessage.created_at).label('last_message_at'), - ) - .filter(ChatMessage.model_id == model_id) - ) + stmt = select( + ChatMessage.chat_id, + func.max(ChatMessage.created_at).label('last_message_at'), + ).filter(ChatMessage.model_id == model_id) if start_date: stmt = stmt.filter(ChatMessage.created_at >= start_date) if end_date: @@ -313,13 +310,10 @@ class ChatMessageTable: async with get_async_db_context(db) as db: from open_webui.models.groups import GroupMember - stmt = ( - select(ChatMessage.model_id, func.count(ChatMessage.id).label('count')) - .filter( - ChatMessage.role == 'assistant', - ChatMessage.model_id.isnot(None), - ~ChatMessage.user_id.like('shared-%'), - ) + stmt = select(ChatMessage.model_id, func.count(ChatMessage.id).label('count')).filter( + ChatMessage.role == 'assistant', + ChatMessage.model_id.isnot(None), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -365,19 +359,16 @@ class ChatMessageTable: else: raise NotImplementedError(f'Unsupported dialect: {dialect}') - stmt = ( - select( - ChatMessage.model_id, - func.coalesce(func.sum(input_tokens), 0).label('input_tokens'), - func.coalesce(func.sum(output_tokens), 0).label('output_tokens'), - func.count(ChatMessage.id).label('message_count'), - ) - .filter( - ChatMessage.role == 'assistant', - ChatMessage.model_id.isnot(None), - ChatMessage.usage.isnot(None), - ~ChatMessage.user_id.like('shared-%'), - ) + stmt = select( + ChatMessage.model_id, + func.coalesce(func.sum(input_tokens), 0).label('input_tokens'), + func.coalesce(func.sum(output_tokens), 0).label('output_tokens'), + func.count(ChatMessage.id).label('message_count'), + ).filter( + ChatMessage.role == 'assistant', + ChatMessage.model_id.isnot(None), + ChatMessage.usage.isnot(None), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -430,19 +421,16 @@ class ChatMessageTable: else: raise NotImplementedError(f'Unsupported dialect: {dialect}') - stmt = ( - select( - ChatMessage.user_id, - func.coalesce(func.sum(input_tokens), 0).label('input_tokens'), - func.coalesce(func.sum(output_tokens), 0).label('output_tokens'), - func.count(ChatMessage.id).label('message_count'), - ) - .filter( - ChatMessage.role == 'assistant', - ChatMessage.user_id.isnot(None), - ChatMessage.usage.isnot(None), - ~ChatMessage.user_id.like('shared-%'), - ) + stmt = select( + ChatMessage.user_id, + func.coalesce(func.sum(input_tokens), 0).label('input_tokens'), + func.coalesce(func.sum(output_tokens), 0).label('output_tokens'), + func.count(ChatMessage.id).label('message_count'), + ).filter( + ChatMessage.role == 'assistant', + ChatMessage.user_id.isnot(None), + ChatMessage.usage.isnot(None), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -476,9 +464,8 @@ class ChatMessageTable: async with get_async_db_context(db) as db: from open_webui.models.groups import GroupMember - stmt = ( - select(ChatMessage.user_id, func.count(ChatMessage.id).label('count')) - .filter(~ChatMessage.user_id.like('shared-%')) + stmt = select(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter( + ~ChatMessage.user_id.like('shared-%') ) if start_date: @@ -503,9 +490,8 @@ class ChatMessageTable: async with get_async_db_context(db) as db: from open_webui.models.groups import GroupMember - stmt = ( - select(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')) - .filter(~ChatMessage.user_id.like('shared-%')) + stmt = select(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')).filter( + ~ChatMessage.user_id.like('shared-%') ) if start_date: @@ -532,13 +518,10 @@ class ChatMessageTable: from datetime import datetime, timedelta from open_webui.models.groups import GroupMember - stmt = ( - select(ChatMessage.created_at, ChatMessage.model_id) - .filter( - ChatMessage.role == 'assistant', - ChatMessage.model_id.isnot(None), - ~ChatMessage.user_id.like('shared-%'), - ) + stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter( + ChatMessage.role == 'assistant', + ChatMessage.model_id.isnot(None), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: @@ -582,13 +565,10 @@ class ChatMessageTable: async with get_async_db_context(db) as db: from datetime import datetime, timedelta - stmt = ( - select(ChatMessage.created_at, ChatMessage.model_id) - .filter( - ChatMessage.role == 'assistant', - ChatMessage.model_id.isnot(None), - ~ChatMessage.user_id.like('shared-%'), - ) + stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter( + ChatMessage.role == 'assistant', + ChatMessage.model_id.isnot(None), + ~ChatMessage.user_id.like('shared-%'), ) if start_date: diff --git a/backend/open_webui/models/chats.py b/backend/open_webui/models/chats.py index 77bc0a5614..b13edbf0bf 100644 --- a/backend/open_webui/models/chats.py +++ b/backend/open_webui/models/chats.py @@ -292,7 +292,9 @@ class ChatTable: return changed - async def insert_new_chat(self, user_id: str, form_data: ChatForm, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def insert_new_chat( + self, user_id: str, form_data: ChatForm, db: Optional[AsyncSession] = None + ) -> Optional[ChatModel]: async with get_async_db_context(db) as db: id = str(uuid.uuid4()) chat = ChatModel( @@ -551,7 +553,9 @@ class ChatTable: await self.update_chat_by_id(id, chat, db=db) return message_files - async def insert_shared_chat_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def insert_shared_chat_by_chat_id( + self, chat_id: str, db: Optional[AsyncSession] = None + ) -> Optional[ChatModel]: async with get_async_db_context(db) as db: # Get the existing chat to share chat = await db.get(Chat, chat_id) @@ -585,7 +589,9 @@ class ChatTable: await db.commit() return shared_chat if shared_result else None - async def update_shared_chat_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def update_shared_chat_by_chat_id( + self, chat_id: str, db: Optional[AsyncSession] = None + ) -> Optional[ChatModel]: try: async with get_async_db_context(db) as db: chat = await db.get(Chat, chat_id) @@ -689,7 +695,9 @@ class ChatTable: db: Optional[AsyncSession] = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: - stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by(user_id=user_id, archived=True) + stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by( + user_id=user_id, archived=True + ) if filter: query_key = filter.get('query') @@ -740,7 +748,11 @@ class ChatTable: db: Optional[AsyncSession] = None, ) -> list[SharedChatResponse]: async with get_async_db_context(db) as db: - stmt = select(Chat.id, Chat.title, Chat.share_id, Chat.updated_at, Chat.created_at).filter_by(user_id=user_id).filter(Chat.share_id.isnot(None)) + stmt = ( + select(Chat.id, Chat.title, Chat.share_id, Chat.updated_at, Chat.created_at) + .filter_by(user_id=user_id) + .filter(Chat.share_id.isnot(None)) + ) if filter: query_key = filter.get('query') @@ -793,7 +805,9 @@ class ChatTable: db: Optional[AsyncSession] = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: - stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(user_id=user_id) + stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( + user_id=user_id + ) if not include_archived: stmt = stmt.filter_by(archived=False) @@ -846,7 +860,9 @@ class ChatTable: db: Optional[AsyncSession] = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: - stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(user_id=user_id) + stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( + user_id=user_id + ) if not include_folders: stmt = stmt.filter_by(folder_id=None) @@ -889,10 +905,7 @@ class ChatTable: ) -> list[ChatModel]: async with get_async_db_context(db) as db: result = await db.execute( - select(Chat) - .filter(Chat.id.in_(chat_ids)) - .filter_by(archived=False) - .order_by(Chat.updated_at.desc()) + select(Chat).filter(Chat.id.in_(chat_ids)).filter_by(archived=False).order_by(Chat.updated_at.desc()) ) all_chats = result.scalars().all() return [ChatModel.model_validate(chat) for chat in all_chats] @@ -925,7 +938,9 @@ class ChatTable: except Exception: return None - async def get_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]: + async def get_chat_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> Optional[ChatModel]: try: async with get_async_db_context(db) as db: result = await db.execute(select(Chat).filter_by(id=id, user_id=user_id)) @@ -941,9 +956,7 @@ class ChatTable: """ try: async with get_async_db_context(db) as db: - result = await db.execute( - select(exists().where(and_(Chat.id == id, Chat.user_id == user_id))) - ) + result = await db.execute(select(exists().where(and_(Chat.id == id, Chat.user_id == user_id)))) return result.scalar() except Exception: return False @@ -997,9 +1010,7 @@ class ChatTable: else: stmt = stmt.order_by(Chat.updated_at.desc(), Chat.id) - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip is not None: @@ -1017,7 +1028,9 @@ class ChatTable: } ) - async def get_pinned_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[ChatTitleIdResponse]: + async def get_pinned_chats_by_user_id( + self, user_id: str, db: Optional[AsyncSession] = None + ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: result = await db.execute( select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at) @@ -1060,7 +1073,9 @@ class ChatTable: search_text = sanitize_text_for_db(search_text).lower().strip() if not search_text: - return await self.get_chat_list_by_user_id(user_id, include_archived, filter={}, skip=skip, limit=limit, db=db) + return await self.get_chat_list_by_user_id( + user_id, include_archived, filter={}, skip=skip, limit=limit, db=db + ) search_text_words = search_text.split(' ') @@ -1305,7 +1320,9 @@ class ChatTable: except Exception: return None - async def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> list[TagModel]: + async def get_chat_tags_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> list[TagModel]: async with get_async_db_context(db) as db: chat = await db.get(Chat, id) tag_ids = chat.meta.get('tags', []) @@ -1320,7 +1337,9 @@ class ChatTable: db: Optional[AsyncSession] = None, ) -> list[ChatTitleIdResponse]: async with get_async_db_context(db) as db: - stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(user_id=user_id) + stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by( + user_id=user_id + ) tag_id = tag_name.replace(' ', '_').lower() bind = await db.connection() @@ -1378,7 +1397,9 @@ class ChatTable: except Exception: return None - async def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str, db: Optional[AsyncSession] = None) -> int: + async def count_chats_by_tag_name_and_user_id( + self, tag_name: str, user_id: str, db: Optional[AsyncSession] = None + ) -> int: async with get_async_db_context(db) as db: stmt = select(func.count(Chat.id)).filter_by(user_id=user_id, archived=False) tag_id = tag_name.replace(' ', '_').lower() @@ -1424,11 +1445,11 @@ class ChatTable: orphans.append(tag_id) await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db) - async def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str, db: Optional[AsyncSession] = None) -> int: + async def count_chats_by_folder_id_and_user_id( + self, folder_id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> int: async with get_async_db_context(db) as db: - result = await db.execute( - select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id) - ) + result = await db.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id)) count = result.scalar() log.info(f"Count of chats for folder '{folder_id}': {count}") @@ -1470,9 +1491,7 @@ class ChatTable: async def delete_chat_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool: try: async with get_async_db_context(db) as db: - await db.execute( - update(AutomationRun).filter_by(chat_id=id).values(chat_id=None) - ) + await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)) await db.execute(delete(ChatMessage).filter_by(chat_id=id)) await db.execute(delete(Chat).filter_by(id=id)) await db.commit() @@ -1484,9 +1503,7 @@ class ChatTable: async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool: try: async with get_async_db_context(db) as db: - await db.execute( - update(AutomationRun).filter_by(chat_id=id).values(chat_id=None) - ) + await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)) await db.execute(delete(ChatMessage).filter_by(chat_id=id)) await db.execute(delete(Chat).filter_by(id=id, user_id=user_id)) await db.commit() @@ -1502,7 +1519,9 @@ class ChatTable: chat_id_subquery = select(Chat.id).filter_by(user_id=user_id).scalar_subquery() await db.execute( - update(AutomationRun).filter(AutomationRun.chat_id.in_(select(Chat.id).filter_by(user_id=user_id))).values(chat_id=None) + update(AutomationRun) + .filter(AutomationRun.chat_id.in_(select(Chat.id).filter_by(user_id=user_id))) + .values(chat_id=None) ) await db.execute( delete(ChatMessage).filter(ChatMessage.chat_id.in_(select(Chat.id).filter_by(user_id=user_id))) @@ -1514,16 +1533,16 @@ class ChatTable: except Exception: return False - async def delete_chats_by_user_id_and_folder_id(self, user_id: str, folder_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_chats_by_user_id_and_folder_id( + self, user_id: str, folder_id: str, db: Optional[AsyncSession] = None + ) -> bool: try: async with get_async_db_context(db) as db: chat_ids_stmt = select(Chat.id).filter_by(user_id=user_id, folder_id=folder_id) await db.execute( update(AutomationRun).filter(AutomationRun.chat_id.in_(chat_ids_stmt)).values(chat_id=None) ) - await db.execute( - delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt)) - ) + await db.execute(delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt))) await db.execute(delete(Chat).filter_by(user_id=user_id, folder_id=folder_id)) await db.commit() @@ -1619,9 +1638,7 @@ class ChatTable: ) -> list[ChatFileModel]: async with get_async_db_context(db) as db: result = await db.execute( - select(ChatFile) - .filter_by(chat_id=chat_id, message_id=message_id) - .order_by(ChatFile.created_at.asc()) + select(ChatFile).filter_by(chat_id=chat_id, message_id=message_id).order_by(ChatFile.created_at.asc()) ) all_chat_files = result.scalars().all() return [ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files] diff --git a/backend/open_webui/models/feedbacks.py b/backend/open_webui/models/feedbacks.py index 61124619b5..02f61f82ee 100644 --- a/backend/open_webui/models/feedbacks.py +++ b/backend/open_webui/models/feedbacks.py @@ -251,9 +251,7 @@ class FeedbackTable: stmt = stmt.order_by(Feedback.created_at.desc()) # Count BEFORE pagination - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: @@ -280,8 +278,9 @@ class FeedbackTable: async def get_all_feedback_ids(self, db: Optional[AsyncSession] = None) -> list[FeedbackIdResponse]: async with get_async_db_context(db) as db: result = await db.execute( - select(Feedback.id, Feedback.user_id, Feedback.created_at, Feedback.updated_at) - .order_by(Feedback.updated_at.desc()) + select(Feedback.id, Feedback.user_id, Feedback.created_at, Feedback.updated_at).order_by( + Feedback.updated_at.desc() + ) ) return [ FeedbackIdResponse( @@ -378,16 +377,12 @@ class FeedbackTable: async def get_feedbacks_by_type(self, type: str, db: Optional[AsyncSession] = None) -> list[FeedbackModel]: async with get_async_db_context(db) as db: - result = await db.execute( - select(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc()) - ) + result = await db.execute(select(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc())) return [FeedbackModel.model_validate(feedback) for feedback in result.scalars().all()] async def get_feedbacks_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[FeedbackModel]: async with get_async_db_context(db) as db: - result = await db.execute( - select(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc()) - ) + result = await db.execute(select(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc())) return [FeedbackModel.model_validate(feedback) for feedback in result.scalars().all()] async def update_feedback_by_id( diff --git a/backend/open_webui/models/files.py b/backend/open_webui/models/files.py index f79255f50b..cfdcfbc2d9 100644 --- a/backend/open_webui/models/files.py +++ b/backend/open_webui/models/files.py @@ -125,7 +125,9 @@ class FileUpdateForm(BaseModel): class FilesTable: - async def insert_new_file(self, user_id: str, form_data: FileForm, db: Optional[AsyncSession] = None) -> Optional[FileModel]: + async def insert_new_file( + self, user_id: str, form_data: FileForm, db: Optional[AsyncSession] = None + ) -> Optional[FileModel]: async with get_async_db_context(db) as db: file_data = form_data.model_dump() @@ -167,7 +169,9 @@ class FilesTable: except Exception: return None - async def get_file_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[FileModel]: + async def get_file_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> Optional[FileModel]: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id, user_id=user_id)) @@ -179,7 +183,9 @@ class FilesTable: except Exception: return None - async def get_file_metadata_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[FileMetadataResponse]: + async def get_file_metadata_by_id( + self, id: str, db: Optional[AsyncSession] = None + ) -> Optional[FileMetadataResponse]: async with get_async_db_context(db) as db: try: file = await db.get(File, id) @@ -211,12 +217,12 @@ class FilesTable: async def get_files_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FileModel]: async with get_async_db_context(db) as db: - result = await db.execute( - select(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()) - ) + result = await db.execute(select(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc())) return [FileModel.model_validate(file) for file in result.scalars().all()] - async def get_file_metadatas_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FileMetadataResponse]: + async def get_file_metadatas_by_ids( + self, ids: list[str], db: Optional[AsyncSession] = None + ) -> list[FileMetadataResponse]: async with get_async_db_context(db) as db: result = await db.execute( select(File.id, File.hash, File.meta, File.created_at, File.updated_at) @@ -251,18 +257,11 @@ class FilesTable: if user_id: stmt = stmt.filter_by(user_id=user_id) - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() - result = await db.execute( - stmt.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit) - ) - items = [ - FileModelResponse.model_validate(file, from_attributes=True) - for file in result.scalars().all() - ] + result = await db.execute(stmt.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit)) + items = [FileModelResponse.model_validate(file, from_attributes=True) for file in result.scalars().all()] return FileListResponse(items=items, total=total) @@ -320,9 +319,7 @@ class FilesTable: if pattern != '%': stmt = stmt.filter(File.filename.ilike(pattern, escape='\\')) - result = await db.execute( - stmt.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit) - ) + result = await db.execute(stmt.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit)) return [FileModel.model_validate(file) for file in result.scalars().all()] async def update_file_by_id( @@ -349,7 +346,9 @@ class FilesTable: log.exception(f'Error updating file completely by id: {e}') return None - async def update_file_hash_by_id(self, id: str, hash: Optional[str], db: Optional[AsyncSession] = None) -> Optional[FileModel]: + async def update_file_hash_by_id( + self, id: str, hash: Optional[str], db: Optional[AsyncSession] = None + ) -> Optional[FileModel]: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id)) @@ -362,7 +361,9 @@ class FilesTable: except Exception: return None - async def update_file_data_by_id(self, id: str, data: dict, db: Optional[AsyncSession] = None) -> Optional[FileModel]: + async def update_file_data_by_id( + self, id: str, data: dict, db: Optional[AsyncSession] = None + ) -> Optional[FileModel]: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id)) @@ -374,7 +375,9 @@ class FilesTable: except Exception as e: return None - async def update_file_metadata_by_id(self, id: str, meta: dict, db: Optional[AsyncSession] = None) -> Optional[FileModel]: + async def update_file_metadata_by_id( + self, id: str, meta: dict, db: Optional[AsyncSession] = None + ) -> Optional[FileModel]: async with get_async_db_context(db) as db: try: result = await db.execute(select(File).filter_by(id=id)) diff --git a/backend/open_webui/models/folders.py b/backend/open_webui/models/folders.py index 4e2a4e9f38..47dbe195ab 100644 --- a/backend/open_webui/models/folders.py +++ b/backend/open_webui/models/folders.py @@ -171,9 +171,7 @@ class FolderTable: async with get_async_db_context(db) as db: # Check if folder exists result = await db.execute( - select(Folder) - .filter_by(parent_id=parent_id, user_id=user_id) - .filter(Folder.name.ilike(name)) + select(Folder).filter_by(parent_id=parent_id, user_id=user_id).filter(Folder.name.ilike(name)) ) folder = result.scalars().first() @@ -235,8 +233,7 @@ class FolderTable: form_data = form_data.model_dump(exclude_unset=True) existing_result = await db.execute( - select(Folder) - .filter_by( + select(Folder).filter_by( name=form_data.get('name'), parent_id=folder.parent_id, user_id=user_id, @@ -289,7 +286,9 @@ class FolderTable: log.error(f'update_folder: {e}') return - async def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> list[str]: + async def delete_folder_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> list[str]: try: folder_ids = [] async with get_async_db_context(db) as db: diff --git a/backend/open_webui/models/functions.py b/backend/open_webui/models/functions.py index db34454b43..ddac317863 100644 --- a/backend/open_webui/models/functions.py +++ b/backend/open_webui/models/functions.py @@ -160,7 +160,9 @@ class FunctionsTable: for func in functions: if func.id in existing_ids: await db.execute( - update(Function).filter_by(id=func.id).values( + update(Function) + .filter_by(id=func.id) + .values( **func.model_dump(), user_id=user_id, updated_at=int(time.time()), @@ -233,9 +235,7 @@ class FunctionsTable: async def get_function_list(self, db: Optional[AsyncSession] = None) -> list[FunctionUserResponse]: async with get_async_db_context(db) as db: - result = await db.execute( - select(Function).order_by(Function.updated_at.desc()) - ) + result = await db.execute(select(Function).order_by(Function.updated_at.desc())) functions = result.scalars().all() user_ids = list(set(func.user_id for func in functions)) @@ -261,7 +261,9 @@ class FunctionsTable: for func in functions ] - async def get_functions_by_type(self, type: str, active_only=False, db: Optional[AsyncSession] = None) -> list[FunctionModel]: + async def get_functions_by_type( + self, type: str, active_only=False, db: Optional[AsyncSession] = None + ) -> list[FunctionModel]: async with get_async_db_context(db) as db: if active_only: result = await db.execute(select(Function).filter_by(type=type, is_active=True)) @@ -342,7 +344,9 @@ class FunctionsTable: log.exception(f'Error updating function metadata by id {id}: {e}') return None - async def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[dict]: + async def get_user_valves_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> Optional[dict]: try: user = await Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} @@ -381,11 +385,15 @@ class FunctionsTable: log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}') return None - async def update_function_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[FunctionModel]: + async def update_function_by_id( + self, id: str, updated: dict, db: Optional[AsyncSession] = None + ) -> Optional[FunctionModel]: async with get_async_db_context(db) as db: try: await db.execute( - update(Function).filter_by(id=id).values( + update(Function) + .filter_by(id=id) + .values( **updated, updated_at=int(time.time()), ) diff --git a/backend/open_webui/models/groups.py b/backend/open_webui/models/groups.py index bca9908580..bc199fac5b 100644 --- a/backend/open_webui/models/groups.py +++ b/backend/open_webui/models/groups.py @@ -261,12 +261,10 @@ class GroupTable: if 'share' in filter: share_value = filter['share'] - stmt = stmt.filter(Group.data.op('->>') ('share') == str(share_value)) + stmt = stmt.filter(Group.data.op('->>')('share') == str(share_value)) # Get total count - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() member_count = ( @@ -348,7 +346,9 @@ class GroupTable: return [m[0] for m in members] - async def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[AsyncSession] = None) -> dict[str, list[str]]: + async def get_group_user_ids_by_ids( + self, group_ids: list[str], db: Optional[AsyncSession] = None + ) -> dict[str, list[str]]: async with get_async_db_context(db) as db: result = await db.execute( select(GroupMember.group_id, GroupMember.user_id).filter(GroupMember.group_id.in_(group_ids)) @@ -362,7 +362,9 @@ class GroupTable: return group_user_ids - async def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[AsyncSession] = None) -> None: + async def set_group_user_ids_by_id( + self, group_id: str, user_ids: list[str], db: Optional[AsyncSession] = None + ) -> None: async with get_async_db_context(db) as db: # Delete existing members await db.execute(delete(GroupMember).filter(GroupMember.group_id == group_id)) @@ -411,7 +413,9 @@ class GroupTable: try: async with get_async_db_context(db) as db: await db.execute( - update(Group).filter_by(id=id).values( + update(Group) + .filter_by(id=id) + .values( **form_data.model_dump(exclude_none=True), updated_at=int(time.time()), ) @@ -455,14 +459,10 @@ class GroupTable: # Remove the user from each group for group in groups: await db.execute( - delete(GroupMember).filter( - GroupMember.group_id == group.id, GroupMember.user_id == user_id - ) + delete(GroupMember).filter(GroupMember.group_id == group.id, GroupMember.user_id == user_id) ) - await db.execute( - update(Group).filter_by(id=group.id).values(updated_at=int(time.time())) - ) + await db.execute(update(Group).filter_by(id=group.id).values(updated_at=int(time.time()))) await db.commit() return True @@ -507,7 +507,9 @@ class GroupTable: continue return new_groups - async def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[AsyncSession] = None) -> bool: + async def sync_groups_by_group_names( + self, user_id: str, group_names: list[str], db: Optional[AsyncSession] = None + ) -> bool: async with get_async_db_context(db) as db: try: now = int(time.time()) @@ -538,9 +540,7 @@ class GroupTable: ) ) - await db.execute( - update(Group).filter(Group.id.in_(groups_to_remove)).values(updated_at=now) - ) + await db.execute(update(Group).filter(Group.id.in_(groups_to_remove)).values(updated_at=now)) # 5. Bulk insert missing memberships for group_id in groups_to_add: @@ -555,9 +555,7 @@ class GroupTable: ) if groups_to_add: - await db.execute( - update(Group).filter(Group.id.in_(groups_to_add)).values(updated_at=now) - ) + await db.execute(update(Group).filter(Group.id.in_(groups_to_add)).values(updated_at=now)) await db.commit() return True diff --git a/backend/open_webui/models/knowledge.py b/backend/open_webui/models/knowledge.py index 68cee36c20..2750ef6058 100644 --- a/backend/open_webui/models/knowledge.py +++ b/backend/open_webui/models/knowledge.py @@ -196,11 +196,13 @@ class KnowledgeTable: knowledge_bases.append( KnowledgeUserModel.model_validate( { - **(await self._to_knowledge_model( - knowledge, - access_grants=grants_map.get(knowledge.id, []), - db=db, - )).model_dump(), + **( + await self._to_knowledge_model( + knowledge, + access_grants=grants_map.get(knowledge.id, []), + db=db, + ) + ).model_dump(), 'user': user.model_dump() if user else None, } ) @@ -249,9 +251,7 @@ class KnowledgeTable: stmt = stmt.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc()) - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: stmt = stmt.offset(skip) @@ -269,11 +269,13 @@ class KnowledgeTable: knowledge_bases.append( KnowledgeUserModel.model_validate( { - **(await self._to_knowledge_model( - knowledge_base, - access_grants=grants_map.get(knowledge_base.id, []), - db=db, - )).model_dump(), + **( + await self._to_knowledge_model( + knowledge_base, + access_grants=grants_map.get(knowledge_base.id, []), + db=db, + ) + ).model_dump(), 'user': (UserModel.model_validate(user).model_dump() if user else None), } ) @@ -321,9 +323,7 @@ class KnowledgeTable: stmt = stmt.order_by(File.updated_at.desc(), File.id.asc()) # Count before pagination - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: @@ -490,9 +490,7 @@ class KnowledgeTable: stmt = stmt.order_by(primary_sort, File.id.asc()) # Count BEFORE pagination - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: @@ -530,7 +528,9 @@ class KnowledgeTable: except Exception: return [] - async def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[AsyncSession] = None) -> list[FileMetadataResponse]: + async def get_file_metadatas_by_id( + self, knowledge_id: str, db: Optional[AsyncSession] = None + ) -> list[FileMetadataResponse]: try: files = await self.get_files_by_id(knowledge_id, db=db) return [FileMetadataResponse(**file.model_dump()) for file in files] @@ -579,7 +579,9 @@ class KnowledgeTable: except Exception: return False - async def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[AsyncSession] = None) -> bool: + async def remove_file_from_knowledge_by_id( + self, knowledge_id: str, file_id: str, db: Optional[AsyncSession] = None + ) -> bool: try: async with get_async_db_context(db) as db: await db.execute(delete(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id)) @@ -596,9 +598,7 @@ class KnowledgeTable: await db.commit() # Update the knowledge entry's updated_at timestamp - await db.execute( - update(Knowledge).filter_by(id=id).values(updated_at=int(time.time())) - ) + await db.execute(update(Knowledge).filter_by(id=id).values(updated_at=int(time.time()))) await db.commit() return await self.get_knowledge_by_id(id=id, db=db) @@ -616,7 +616,9 @@ class KnowledgeTable: try: async with get_async_db_context(db) as db: await db.execute( - update(Knowledge).filter_by(id=id).values( + update(Knowledge) + .filter_by(id=id) + .values( **form_data.model_dump(exclude={'access_grants'}), updated_at=int(time.time()), ) @@ -635,7 +637,9 @@ class KnowledgeTable: try: async with get_async_db_context(db) as db: await db.execute( - update(Knowledge).filter_by(id=id).values( + update(Knowledge) + .filter_by(id=id) + .values( data=data, updated_at=int(time.time()), ) diff --git a/backend/open_webui/models/messages.py b/backend/open_webui/models/messages.py index c9af45ebf5..7f33a72eff 100644 --- a/backend/open_webui/models/messages.py +++ b/backend/open_webui/models/messages.py @@ -250,11 +250,11 @@ class MessageTable: } return None - async def get_thread_replies_by_message_id(self, id: str, db: Optional[AsyncSession] = None) -> list[MessageReplyToResponse]: + async def get_thread_replies_by_message_id( + self, id: str, db: Optional[AsyncSession] = None + ) -> list[MessageReplyToResponse]: async with get_async_db_context(db) as db: - result = await db.execute( - select(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()) - ) + result = await db.execute(select(Message).filter_by(parent_id=id).order_by(Message.created_at.desc())) all_messages = result.scalars().all() messages = [] @@ -369,7 +369,9 @@ class MessageTable: ) return messages - async def get_last_message_by_channel_id(self, channel_id: str, db: Optional[AsyncSession] = None) -> Optional[MessageModel]: + async def get_last_message_by_channel_id( + self, channel_id: str, db: Optional[AsyncSession] = None + ) -> Optional[MessageModel]: async with get_async_db_context(db) as db: result = await db.execute( select(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).limit(1) @@ -453,9 +455,7 @@ class MessageTable: ) -> Optional[MessageReactionModel]: async with get_async_db_context(db) as db: # check for existing reaction - result = await db.execute( - select(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name) - ) + result = await db.execute(select(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name)) existing_reaction = result.scalars().first() if existing_reaction: return MessageReactionModel.model_validate(existing_reaction) diff --git a/backend/open_webui/models/models.py b/backend/open_webui/models/models.py index 4664a71b85..36b0792a6b 100755 --- a/backend/open_webui/models/models.py +++ b/backend/open_webui/models/models.py @@ -200,7 +200,8 @@ class ModelsTable: model_ids = [model.id for model in all_models] grants_map = await AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ - await self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models + await self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) + for model in all_models ] async def get_models(self, db: Optional[AsyncSession] = None) -> list[ModelUserResponse]: @@ -221,11 +222,13 @@ class ModelsTable: models.append( ModelUserResponse.model_validate( { - **(await self._to_model_model( - model, - access_grants=grants_map.get(model.id, []), - db=db, - )).model_dump(), + **( + await self._to_model_model( + model, + access_grants=grants_map.get(model.id, []), + db=db, + ) + ).model_dump(), 'user': user.model_dump() if user else None, } ) @@ -239,7 +242,8 @@ class ModelsTable: model_ids = [model.id for model in all_models] grants_map = await AccessGrants.get_grants_by_resources('model', model_ids, db=db) return [ - await self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models + await self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) + for model in all_models ] async def get_models_by_user_id( @@ -342,9 +346,7 @@ class ModelsTable: stmt = stmt.order_by(Model.created_at.desc()) # Count BEFORE pagination - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: @@ -362,11 +364,13 @@ class ModelsTable: for model, user in items: models.append( ModelUserResponse( - **(await self._to_model_model( - model, - access_grants=grants_map.get(model.id, []), - db=db, - )).model_dump(), + **( + await self._to_model_model( + model, + access_grants=grants_map.get(model.id, []), + db=db, + ) + ).model_dump(), user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) @@ -416,7 +420,9 @@ class ModelsTable: except Exception: return None - async def update_model_by_id(self, id: str, model: ModelForm, db: Optional[AsyncSession] = None) -> Optional[ModelModel]: + async def update_model_by_id( + self, id: str, model: ModelForm, db: Optional[AsyncSession] = None + ) -> Optional[ModelModel]: try: async with get_async_db_context(db) as db: # update only the fields that are present in the model @@ -473,7 +479,9 @@ class ModelsTable: except Exception: return False - async def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[AsyncSession] = None) -> list[ModelModel]: + async def sync_models( + self, user_id: str, models: list[ModelModel], db: Optional[AsyncSession] = None + ) -> list[ModelModel]: try: async with get_async_db_context(db) as db: # Get existing models @@ -488,7 +496,9 @@ class ModelsTable: for model in models: if model.id in existing_ids: await db.execute( - update(Model).filter_by(id=model.id).values( + update(Model) + .filter_by(id=model.id) + .values( **model.model_dump(exclude={'access_grants'}), user_id=user_id, updated_at=int(time.time()), diff --git a/backend/open_webui/models/notes.py b/backend/open_webui/models/notes.py index b06465c7ad..be42b5d850 100644 --- a/backend/open_webui/models/notes.py +++ b/backend/open_webui/models/notes.py @@ -113,7 +113,9 @@ class NoteTable: permission=permission, ) - async def insert_new_note(self, user_id: str, form_data: NoteForm, db: Optional[AsyncSession] = None) -> Optional[NoteModel]: + async def insert_new_note( + self, user_id: str, form_data: NoteForm, db: Optional[AsyncSession] = None + ) -> Optional[NoteModel]: async with get_async_db_context(db) as db: note = NoteModel( **{ @@ -216,9 +218,7 @@ class NoteTable: stmt = stmt.order_by(Note.updated_at.desc()) # Count BEFORE pagination - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: @@ -236,11 +236,13 @@ class NoteTable: for note, user in items: notes.append( NoteUserResponse( - **(await self._to_note_model( - note, - access_grants=grants_map.get(note.id, []), - db=db, - )).model_dump(), + **( + await self._to_note_model( + note, + access_grants=grants_map.get(note.id, []), + db=db, + ) + ).model_dump(), user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) diff --git a/backend/open_webui/models/oauth_sessions.py b/backend/open_webui/models/oauth_sessions.py index c8ff569f27..84e5c66560 100644 --- a/backend/open_webui/models/oauth_sessions.py +++ b/backend/open_webui/models/oauth_sessions.py @@ -151,7 +151,9 @@ class OAuthSessionTable: log.error(f'Error creating OAuth session: {e}') return None - async def get_session_by_id(self, session_id: str, db: Optional[AsyncSession] = None) -> Optional[OAuthSessionModel]: + async def get_session_by_id( + self, session_id: str, db: Optional[AsyncSession] = None + ) -> Optional[OAuthSessionModel]: """Get OAuth session by ID""" try: async with get_async_db_context(db) as db: @@ -235,15 +237,17 @@ class OAuthSessionTable: results = [] for session in sessions: try: - results.append(OAuthSessionModel( - id=session.id, - user_id=session.user_id, - provider=session.provider, - token=self._decrypt_token(session.token), - expires_at=session.expires_at, - created_at=session.created_at, - updated_at=session.updated_at, - )) + results.append( + OAuthSessionModel( + id=session.id, + user_id=session.user_id, + provider=session.provider, + token=self._decrypt_token(session.token), + expires_at=session.expires_at, + created_at=session.created_at, + updated_at=session.updated_at, + ) + ) except Exception as e: log.warning( f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}' @@ -266,7 +270,9 @@ class OAuthSessionTable: current_time = int(time.time()) await db.execute( - update(OAuthSession).filter_by(id=session_id).values( + update(OAuthSession) + .filter_by(id=session_id) + .values( token=self._encrypt_token(token), expires_at=token.get('expires_at'), updated_at=current_time, diff --git a/backend/open_webui/models/prompts.py b/backend/open_webui/models/prompts.py index 7250d1901e..64f66bec86 100644 --- a/backend/open_webui/models/prompts.py +++ b/backend/open_webui/models/prompts.py @@ -213,11 +213,13 @@ class PromptsTable: prompts.append( PromptUserResponse.model_validate( { - **(await self._to_prompt_model( - prompt, - access_grants=grants_map.get(prompt.id, []), - db=db, - )).model_dump(), + **( + await self._to_prompt_model( + prompt, + access_grants=grants_map.get(prompt.id, []), + db=db, + ) + ).model_dump(), 'user': user.model_dump() if user else None, } ) @@ -319,9 +321,7 @@ class PromptsTable: stmt = stmt.order_by(Prompt.updated_at.desc()) # Count BEFORE pagination - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: @@ -339,11 +339,13 @@ class PromptsTable: for prompt, user in items: prompts.append( PromptUserResponse( - **(await self._to_prompt_model( - prompt, - access_grants=grants_map.get(prompt.id, []), - db=db, - )).model_dump(), + **( + await self._to_prompt_model( + prompt, + access_grants=grants_map.get(prompt.id, []), + db=db, + ) + ).model_dump(), user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) diff --git a/backend/open_webui/models/skills.py b/backend/open_webui/models/skills.py index 55ba204135..0fc6dfc52d 100644 --- a/backend/open_webui/models/skills.py +++ b/backend/open_webui/models/skills.py @@ -184,11 +184,13 @@ class SkillsTable: skills.append( SkillUserModel.model_validate( { - **(await self._to_skill_model( - skill, - access_grants=grants_map.get(skill.id, []), - db=db, - )).model_dump(), + **( + await self._to_skill_model( + skill, + access_grants=grants_map.get(skill.id, []), + db=db, + ) + ).model_dump(), 'user': user.model_dump() if user else None, } ) @@ -262,9 +264,7 @@ class SkillsTable: stmt = stmt.order_by(Skill.updated_at.desc()) # Count BEFORE pagination - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() if skip: @@ -282,11 +282,13 @@ class SkillsTable: for skill, user in items: skills.append( SkillUserResponse( - **(await self._to_skill_model( - skill, - access_grants=grants_map.get(skill.id, []), - db=db, - )).model_dump(), + **( + await self._to_skill_model( + skill, + access_grants=grants_map.get(skill.id, []), + db=db, + ) + ).model_dump(), user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None), ) ) @@ -296,7 +298,9 @@ class SkillsTable: log.exception(f'Error searching skills: {e}') return SkillListResponse(items=[], total=0) - async def update_skill_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[SkillModel]: + async def update_skill_by_id( + self, id: str, updated: dict, db: Optional[AsyncSession] = None + ) -> Optional[SkillModel]: try: async with get_async_db_context(db) as db: access_grants = updated.pop('access_grants', None) diff --git a/backend/open_webui/models/tags.py b/backend/open_webui/models/tags.py index 95b97b9cc1..ee2baefc01 100644 --- a/backend/open_webui/models/tags.py +++ b/backend/open_webui/models/tags.py @@ -71,7 +71,9 @@ class TagTable: log.exception(f'Error inserting a new tag: {e}') return None - async def get_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[TagModel]: + async def get_tag_by_name_and_user_id( + self, name: str, user_id: str, db: Optional[AsyncSession] = None + ) -> Optional[TagModel]: try: id = name.replace(' ', '_').lower() async with get_async_db_context(db) as db: @@ -86,7 +88,9 @@ class TagTable: result = await db.execute(select(Tag).filter_by(user_id=user_id)) return [TagModel.model_validate(tag) for tag in result.scalars().all()] - async def get_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None) -> list[TagModel]: + async def get_tags_by_ids_and_user_id( + self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None + ) -> list[TagModel]: async with get_async_db_context(db) as db: result = await db.execute(select(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id)) return [TagModel.model_validate(tag) for tag in result.scalars().all()] @@ -103,7 +107,9 @@ class TagTable: log.error(f'delete_tag: {e}') return False - async def delete_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None) -> bool: + async def delete_tags_by_ids_and_user_id( + self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None + ) -> bool: """Delete all tags whose id is in *ids* for the given user, in one query.""" if not ids: return True diff --git a/backend/open_webui/models/tools.py b/backend/open_webui/models/tools.py index fe772c4443..70035121aa 100644 --- a/backend/open_webui/models/tools.py +++ b/backend/open_webui/models/tools.py @@ -172,11 +172,13 @@ class ToolsTable: tools.append( ToolUserModel.model_validate( { - **(await self._to_tool_model( - tool, - access_grants=grants_map.get(tool.id, []), - db=db, - )).model_dump(), + **( + await self._to_tool_model( + tool, + access_grants=grants_map.get(tool.id, []), + db=db, + ) + ).model_dump(), 'user': user.model_dump() if user else None, } ) @@ -218,18 +220,20 @@ class ToolsTable: log.exception(f'Error getting tool valves by id {id}') return None - async def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[AsyncSession] = None) -> Optional[ToolValves]: + async def update_tool_valves_by_id( + self, id: str, valves: dict, db: Optional[AsyncSession] = None + ) -> Optional[ToolValves]: try: async with get_async_db_context(db) as db: - await db.execute( - update(Tool).filter_by(id=id).values(valves=valves, updated_at=int(time.time())) - ) + await db.execute(update(Tool).filter_by(id=id).values(valves=valves, updated_at=int(time.time()))) await db.commit() return await self.get_tool_by_id(id, db=db) except Exception: return None - async def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[dict]: + async def get_user_valves_by_id_and_user_id( + self, id: str, user_id: str, db: Optional[AsyncSession] = None + ) -> Optional[dict]: try: user = await Users.get_user_by_id(user_id, db=db) user_settings = user.settings.model_dump() if user.settings else {} @@ -272,9 +276,7 @@ class ToolsTable: try: async with get_async_db_context(db) as db: access_grants = updated.pop('access_grants', None) - await db.execute( - update(Tool).filter_by(id=id).values(**updated, updated_at=int(time.time())) - ) + await db.execute(update(Tool).filter_by(id=id).values(**updated, updated_at=int(time.time()))) await db.commit() if access_grants is not None: await AccessGrants.set_access_grants('tool', id, access_grants, db=db) diff --git a/backend/open_webui/models/users.py b/backend/open_webui/models/users.py index a5a43c27b8..e06e6957f9 100644 --- a/backend/open_webui/models/users.py +++ b/backend/open_webui/models/users.py @@ -31,11 +31,13 @@ import datetime # daily bread of every session. Let none go hungry. #################### + class UserSettings(BaseModel): ui: Optional[dict] = {} model_config = ConfigDict(extra='allow') pass + class User(Base): __tablename__ = 'user' @@ -69,6 +71,7 @@ class User(Base): updated_at = Column(BigInteger) created_at = Column(BigInteger) + class UserModel(BaseModel): id: str @@ -109,11 +112,13 @@ class UserModel(BaseModel): self.profile_image_url = f'/api/v1/users/{self.id}/profile/image' return self + class UserStatusModel(UserModel): is_active: bool = False model_config = ConfigDict(from_attributes=True) + class ApiKey(Base): __tablename__ = 'api_key' @@ -126,6 +131,7 @@ class ApiKey(Base): created_at = Column(BigInteger, nullable=False) updated_at = Column(BigInteger, nullable=False) + class ApiKeyModel(BaseModel): id: str user_id: str @@ -138,10 +144,12 @@ class ApiKeyModel(BaseModel): model_config = ConfigDict(from_attributes=True) + #################### # Forms #################### + class UpdateProfileForm(BaseModel): profile_image_url: str name: str @@ -154,25 +162,31 @@ class UpdateProfileForm(BaseModel): def check_profile_image_url(cls, v: str) -> str: return validate_profile_image_url(v) + class UserGroupIdsModel(UserModel): group_ids: list[str] = [] + class UserModelResponse(UserModel): model_config = ConfigDict(extra='allow') + class UserListResponse(BaseModel): users: list[UserModelResponse] total: int + class UserGroupIdsListResponse(BaseModel): users: list[UserGroupIdsModel] total: int + class UserStatus(BaseModel): status_emoji: Optional[str] = None status_message: Optional[str] = None status_expires_at: Optional[int] = None + class UserInfoResponse(UserStatus): id: str name: str @@ -182,39 +196,48 @@ class UserInfoResponse(UserStatus): groups: Optional[list] = [] is_active: bool = False + class UserIdNameResponse(BaseModel): id: str name: str + class UserIdNameStatusResponse(UserStatus): id: str name: str is_active: Optional[bool] = None + class UserInfoListResponse(BaseModel): users: list[UserInfoResponse] total: int + class UserIdNameListResponse(BaseModel): users: list[UserIdNameResponse] total: int + class UserNameResponse(BaseModel): id: str name: str role: str + class UserResponse(UserNameResponse): email: str + class UserProfileImageResponse(UserNameResponse): email: str profile_image_url: str + class UserRoleUpdateForm(BaseModel): id: str role: str + class UserUpdateForm(BaseModel): role: str name: str @@ -227,6 +250,7 @@ class UserUpdateForm(BaseModel): def check_profile_image_url(cls, v: str) -> str: return validate_profile_image_url(v) + class UsersTable: async def insert_new_user( self, @@ -292,7 +316,9 @@ class UsersTable: except Exception: return None - async def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def get_user_by_oauth_sub( + self, provider: str, sub: str, db: Optional[AsyncSession] = None + ) -> Optional[UserModel]: try: async with get_async_db_context(db) as db: dialect_name = db.bind.dialect.name @@ -457,9 +483,7 @@ class UsersTable: stmt = stmt.order_by(User.created_at.desc()) # Count BEFORE pagination - count_result = await db.execute( - select(func.count()).select_from(stmt.subquery()) - ) + count_result = await db.execute(select(func.count()).select_from(stmt.subquery())) total = count_result.scalar() # correct pagination logic @@ -478,20 +502,18 @@ class UsersTable: async def get_users_by_group_id(self, group_id: str, db: Optional[AsyncSession] = None) -> list[UserModel]: async with get_async_db_context(db) as db: from open_webui.models.groups import GroupMember + result = await db.execute( - select(User) - - .join(GroupMember, User.id == GroupMember.user_id) - .filter(GroupMember.group_id == group_id) + select(User).join(GroupMember, User.id == GroupMember.user_id).filter(GroupMember.group_id == group_id) ) users = result.scalars().all() return [UserModel.model_validate(user) for user in users] - async def get_users_by_user_ids(self, user_ids: list[str], db: Optional[AsyncSession] = None) -> list[UserStatusModel]: + async def get_users_by_user_ids( + self, user_ids: list[str], db: Optional[AsyncSession] = None + ) -> list[UserStatusModel]: async with get_async_db_context(db) as db: - result = await db.execute( - select(User).filter(User.id.in_(user_ids)) - ) + result = await db.execute(select(User).filter(User.id.in_(user_ids))) users = result.scalars().all() return [UserModel.model_validate(user) for user in users] @@ -536,7 +558,9 @@ class UsersTable: ) return result.scalar() - async def update_user_role_by_id(self, id: str, role: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def update_user_role_by_id( + self, id: str, role: str, db: Optional[AsyncSession] = None + ) -> Optional[UserModel]: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -674,7 +698,9 @@ class UsersTable: print(e) return None - async def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[UserModel]: + async def update_user_settings_by_id( + self, id: str, updated: dict, db: Optional[AsyncSession] = None + ) -> Optional[UserModel]: try: async with get_async_db_context(db) as db: result = await db.execute(select(User).filter_by(id=id)) @@ -802,4 +828,5 @@ class UsersTable: return user.last_active_at >= three_minutes_ago return False + Users = UsersTable() diff --git a/backend/open_webui/routers/analytics.py b/backend/open_webui/routers/analytics.py index 8636444a5d..fd045f79e7 100644 --- a/backend/open_webui/routers/analytics.py +++ b/backend/open_webui/routers/analytics.py @@ -62,7 +62,9 @@ async def get_model_analytics( db: AsyncSession = Depends(get_async_session), ): """Get message counts per model.""" - counts = await ChatMessages.get_message_count_by_model(start_date=start_date, end_date=end_date, group_id=group_id, db=db) + counts = await ChatMessages.get_message_count_by_model( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) models = [ ModelAnalyticsEntry(model_id=model_id, count=count) for model_id, count in sorted(counts.items(), key=lambda x: -x[1]) @@ -80,7 +82,9 @@ async def get_user_analytics( db: AsyncSession = Depends(get_async_session), ): """Get message counts and token usage per user with user info.""" - counts = await ChatMessages.get_message_count_by_user(start_date=start_date, end_date=end_date, group_id=group_id, db=db) + counts = await ChatMessages.get_message_count_by_user( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) token_usage = await ChatMessages.get_token_usage_by_user( start_date=start_date, end_date=end_date, group_id=group_id, db=db ) @@ -227,7 +231,9 @@ async def get_token_usage( db: AsyncSession = Depends(get_async_session), ): """Get token usage aggregated by model.""" - usage = await ChatMessages.get_token_usage_by_model(start_date=start_date, end_date=end_date, group_id=group_id, db=db) + usage = await ChatMessages.get_token_usage_by_model( + start_date=start_date, end_date=end_date, group_id=group_id, db=db + ) models = [ TokenUsageEntry(model_id=model_id, **data) diff --git a/backend/open_webui/routers/audio.py b/backend/open_webui/routers/audio.py index 5a26d04e0f..69744e7219 100644 --- a/backend/open_webui/routers/audio.py +++ b/backend/open_webui/routers/audio.py @@ -330,7 +330,9 @@ async def speech(request: Request, user=Depends(get_verified_user)): detail=ERROR_MESSAGES.NOT_FOUND, ) - if user.role != 'admin' and not await has_permission(user.id, 'chat.tts', request.app.state.config.USER_PERMISSIONS): + if user.role != 'admin' and not await has_permission( + user.id, 'chat.tts', request.app.state.config.USER_PERMISSIONS + ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -630,6 +632,7 @@ async def speech(request: Request, user=Depends(get_verified_user)): detail=detail, ) + def transcription_handler(request, file_path, metadata, user=None): filename = os.path.basename(file_path) file_dir = os.path.dirname(file_path) @@ -1214,7 +1217,9 @@ async def transcription( language: Optional[str] = Form(None), user=Depends(get_verified_user), ): - if user.role != 'admin' and not await has_permission(user.id, 'chat.stt', request.app.state.config.USER_PERMISSIONS): + if user.role != 'admin' and not await has_permission( + user.id, 'chat.stt', request.app.state.config.USER_PERMISSIONS + ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, diff --git a/backend/open_webui/routers/auths.py b/backend/open_webui/routers/auths.py index 6652ebf44d..cfc05160c1 100644 --- a/backend/open_webui/routers/auths.py +++ b/backend/open_webui/routers/auths.py @@ -97,7 +97,9 @@ log = logging.getLogger(__name__) signin_rate_limiter = RateLimiter(redis_client=get_redis_client(), limit=5 * 3, window=60 * 3) -async def create_session_response(request: Request, user, db, response: Response = None, set_cookie: bool = False) -> dict: +async def create_session_response( + request: Request, user, db, response: Response = None, set_cookie: bool = False +) -> dict: """ Create JWT token and build session response for a user. Shared helper for signin, signup, ldap_auth, add_user, and token_exchange endpoints. @@ -918,7 +920,9 @@ async def add_user( @router.get('/admin/details') -async def get_admin_details(request: Request, user=Depends(get_current_user), db: AsyncSession = Depends(get_async_session)): +async def get_admin_details( + request: Request, user=Depends(get_current_user), db: AsyncSession = Depends(get_async_session) +): if request.app.state.config.SHOW_ADMIN_DETAILS: admin_email = request.app.state.config.ADMIN_EMAIL admin_name = None @@ -1182,7 +1186,9 @@ async def update_ldap_config(request: Request, form_data: LdapConfigForm, user=D # create api key @router.post('/api_key', response_model=ApiKey) -async def generate_api_key(request: Request, user=Depends(get_current_user), db: AsyncSession = Depends(get_async_session)): +async def generate_api_key( + request: Request, user=Depends(get_current_user), db: AsyncSession = Depends(get_async_session) +): if not request.app.state.config.ENABLE_API_KEYS or ( user.role != 'admin' and not await has_permission(user.id, 'features.api_keys', request.app.state.config.USER_PERMISSIONS) diff --git a/backend/open_webui/routers/channels.py b/backend/open_webui/routers/channels.py index 5c2ab9dcac..b6eee93eac 100644 --- a/backend/open_webui/routers/channels.py +++ b/backend/open_webui/routers/channels.py @@ -94,7 +94,9 @@ async def channel_has_access( return False -async def get_channel_users_with_access(channel: ChannelModel, permission: str = 'read', db: Optional[AsyncSession] = None): +async def get_channel_users_with_access( + channel: ChannelModel, permission: str = 'read', db: Optional[AsyncSession] = None +): return await AccessGrants.get_users_with_access( resource_type='channel', resource_id=channel.id, @@ -893,11 +895,13 @@ async def model_response_handler(request, channel, message, user, db=None): if model: try: # reverse to get in chronological order - thread_messages = (await Messages.get_messages_by_parent_id( - channel.id, - message.parent_id if message.parent_id else message.id, - db=db, - ))[::-1] + thread_messages = ( + await Messages.get_messages_by_parent_id( + channel.id, + message.parent_id if message.parent_id else message.id, + db=db, + ) + )[::-1] response_message, channel = await new_message_handler( request, @@ -1120,7 +1124,9 @@ async def post_new_message( try: if files := message.data.get('files', []): for file in files: - await Channels.set_file_message_id_in_channel_by_id(channel.id, file.get('id', ''), message.id, db=db) + await Channels.set_file_message_id_in_channel_by_id( + channel.id, file.get('id', ''), message.id, db=db + ) except Exception as e: log.debug(e) diff --git a/backend/open_webui/routers/chats.py b/backend/open_webui/routers/chats.py index ba07937ed1..979b9388cf 100644 --- a/backend/open_webui/routers/chats.py +++ b/backend/open_webui/routers/chats.py @@ -493,7 +493,9 @@ async def delete_all_user_chats( user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): - if user.role == 'user' and not await has_permission(user.id, 'chat.delete', request.app.state.config.USER_PERMISSIONS): + if user.role == 'user' and not await has_permission( + user.id, 'chat.delete', request.app.state.config.USER_PERMISSIONS + ): raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @@ -538,7 +540,9 @@ async def get_user_chat_list_by_user_id( if direction: filter['direction'] = direction - return await Chats.get_chat_list_by_user_id(user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db) + return await Chats.get_chat_list_by_user_id( + user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db + ) ############################ @@ -620,7 +624,9 @@ async def search_user_chats( @router.get('/folder/{folder_id}', response_model=list[ChatResponse]) -async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_chats_by_folder_id( + folder_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): folder_ids = [folder_id] children_folders = await Folders.get_children_folders_by_id_and_user_id(folder_id, user.id, db=db) if children_folders: @@ -815,7 +821,9 @@ async def get_shared_session_user_chat_list( @router.get('/share/{share_id}', response_model=Optional[ChatResponse]) -async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_shared_chat_by_id( + share_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): if user.role == 'pending': raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND) @@ -851,7 +859,9 @@ async def get_user_chat_list_by_tag_name( user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): - chats = await Chats.get_chat_list_by_user_id_and_tag_name(user.id, form_data.name, form_data.skip, form_data.limit, db=db) + chats = await Chats.get_chat_list_by_user_id_and_tag_name( + user.id, form_data.name, form_data.skip, form_data.limit, db=db + ) if len(chats) == 0: await Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db) @@ -1056,7 +1066,9 @@ async def delete_chat_by_id( @router.get('/{id}/pinned', response_model=Optional[bool]) -async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_pinned_status_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: return chat.pinned @@ -1137,7 +1149,9 @@ async def clone_chat_by_id( @router.post('/{id}/clone/shared', response_model=Optional[ChatResponse]) -async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def clone_shared_chat_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): if user.role == 'admin': chat = await Chats.get_chat_by_id(id, db=db) else: @@ -1250,7 +1264,9 @@ async def share_chat_by_id( @router.delete('/{id}/share', response_model=Optional[bool]) -async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def delete_shared_chat_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: if not chat.share_id: @@ -1371,7 +1387,9 @@ async def delete_tag_by_id_and_tag_name( @router.delete('/{id}/tags/all', response_model=Optional[bool]) -async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def delete_all_tags_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db) if chat: old_tags = chat.meta.get('tags', []) diff --git a/backend/open_webui/routers/evaluations.py b/backend/open_webui/routers/evaluations.py index 6a847c22a5..072c7fa732 100644 --- a/backend/open_webui/routers/evaluations.py +++ b/backend/open_webui/routers/evaluations.py @@ -415,7 +415,9 @@ async def update_feedback_by_id( @router.delete('/feedback/{id}') -async def delete_feedback_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def delete_feedback_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): if user.role == 'admin': success = await Feedbacks.delete_feedback_by_id(id=id, db=db) else: diff --git a/backend/open_webui/routers/files.py b/backend/open_webui/routers/files.py index 70e0f468f9..b5e8ea83ce 100644 --- a/backend/open_webui/routers/files.py +++ b/backend/open_webui/routers/files.py @@ -495,7 +495,9 @@ async def get_file_process_status( @router.get('/{id}/data/content') -async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_file_data_content_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): file = await Files.get_file_by_id(id, db=db) if not file: @@ -646,7 +648,9 @@ async def get_file_content_by_id( @router.get('/{id}/content/html') -async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_html_file_content_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): file = await Files.get_file_by_id(id, db=db) if not file: @@ -693,7 +697,9 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user), @router.get('/{id}/content/{file_name}') -async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_file_content_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): file = await Files.get_file_by_id(id, db=db) if not file: diff --git a/backend/open_webui/routers/folders.py b/backend/open_webui/routers/folders.py index 9938d1eca1..ebd0c0cb17 100644 --- a/backend/open_webui/routers/folders.py +++ b/backend/open_webui/routers/folders.py @@ -89,7 +89,9 @@ async def get_folders( valid_files.append(file) folder.data['files'] = valid_files - await Folders.update_folder_by_id_and_user_id(folder.id, user.id, FolderUpdateForm(data=folder.data), db=db) + await Folders.update_folder_by_id_and_user_id( + folder.id, user.id, FolderUpdateForm(data=folder.data), db=db + ) folder_list.append(FolderNameIdResponse(**folder.model_dump())) @@ -107,7 +109,9 @@ async def create_folder( user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session), ): - folder = await Folders.get_folder_by_parent_id_and_user_id_and_name(form_data.parent_id, user.id, form_data.name, db=db) + folder = await Folders.get_folder_by_parent_id_and_user_id_and_name( + form_data.parent_id, user.id, form_data.name, db=db + ) if folder: raise HTTPException( @@ -250,7 +254,9 @@ async def update_folder_is_expanded_by_id( folder = await Folders.get_folder_by_id_and_user_id(id, user.id, db=db) if folder: try: - folder = await Folders.update_folder_is_expanded_by_id_and_user_id(id, user.id, form_data.is_expanded, db=db) + folder = await Folders.update_folder_is_expanded_by_id_and_user_id( + id, user.id, form_data.is_expanded, db=db + ) return folder except Exception as e: log.exception(e) diff --git a/backend/open_webui/routers/functions.py b/backend/open_webui/routers/functions.py index 371079bed2..de09aa05a1 100644 --- a/backend/open_webui/routers/functions.py +++ b/backend/open_webui/routers/functions.py @@ -373,7 +373,9 @@ async def delete_function_by_id( @router.get('/id/{id}/valves', response_model=Optional[dict]) -async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)): +async def get_function_valves_by_id( + id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session) +): function = await Functions.get_function_by_id(id, db=db) if function: try: @@ -473,7 +475,9 @@ async def update_function_valves_by_id( @router.get('/id/{id}/valves/user', response_model=Optional[dict]) -async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_function_user_valves_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): function = await Functions.get_function_by_id(id, db=db) if function: try: diff --git a/backend/open_webui/routers/knowledge.py b/backend/open_webui/routers/knowledge.py index 3022763b49..c6f8ce5ecd 100644 --- a/backend/open_webui/routers/knowledge.py +++ b/backend/open_webui/routers/knowledge.py @@ -858,7 +858,9 @@ async def remove_file_from_knowledge_by_id( @router.delete('/{id}/delete', response_model=bool) -async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def delete_knowledge_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): knowledge = await Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( @@ -931,7 +933,9 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user), db: A @router.post('/{id}/reset', response_model=Optional[KnowledgeResponse]) -async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def reset_knowledge_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): knowledge = await Knowledges.get_knowledge_by_id(id=id, db=db) if not knowledge: raise HTTPException( diff --git a/backend/open_webui/routers/ollama.py b/backend/open_webui/routers/ollama.py index 93b34f5e66..0c6fe73abf 100644 --- a/backend/open_webui/routers/ollama.py +++ b/backend/open_webui/routers/ollama.py @@ -138,7 +138,10 @@ async def send_request( headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get('chat_id') r = await session.request( - method, url, data=payload, headers=headers, + method, + url, + data=payload, + headers=headers, ssl=AIOHTTP_CLIENT_SESSION_SSL, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT), ) @@ -782,7 +785,8 @@ async def delete_model( key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS) await send_request( - f'{url}/api/delete', 'DELETE', + f'{url}/api/delete', + 'DELETE', payload=json.dumps(form_data), key=key, user=user, @@ -1650,9 +1654,7 @@ async def upload_model( url = f'{ollama_url}/api/blobs/sha256:{file_hash}' upload_timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT) async with aiohttp.ClientSession(timeout=upload_timeout, trust_env=True) as upload_session: - async with upload_session.post( - url, data=blob_data, ssl=AIOHTTP_CLIENT_SESSION_SSL - ) as response: + async with upload_session.post(url, data=blob_data, ssl=AIOHTTP_CLIENT_SESSION_SSL) as response: if not response.ok: raise Exception('Ollama: Could not create blob, Please try again.') diff --git a/backend/open_webui/routers/openai.py b/backend/open_webui/routers/openai.py index 05e6d1ea59..09db29e6d6 100644 --- a/backend/open_webui/routers/openai.py +++ b/backend/open_webui/routers/openai.py @@ -1211,7 +1211,7 @@ async def generate_chat_completion( except json.JSONDecodeError: return JSONResponse( status_code=r.status, - content={"error": {"message": error_body, "code": r.status}}, + content={'error': {'message': error_body, 'code': r.status}}, ) streaming = True diff --git a/backend/open_webui/routers/prompts.py b/backend/open_webui/routers/prompts.py index ed9b69af06..a4a75754f0 100644 --- a/backend/open_webui/routers/prompts.py +++ b/backend/open_webui/routers/prompts.py @@ -198,7 +198,9 @@ async def create_new_prompt( @router.get('/command/{command}', response_model=Optional[PromptAccessResponse]) -async def get_prompt_by_command(command: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_prompt_by_command( + command: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): prompt = await Prompts.get_prompt_by_command(command, db=db) if prompt: @@ -240,7 +242,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user), d @router.get('/id/{prompt_id}', response_model=Optional[PromptAccessResponse]) -async def get_prompt_by_id(prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_prompt_by_id( + prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): prompt = await Prompts.get_prompt_by_id(prompt_id, db=db) if prompt: @@ -388,7 +392,9 @@ async def update_prompt_metadata( detail=f"Command '/{form_data.command}' is already in use", ) - updated_prompt = await Prompts.update_prompt_metadata(prompt.id, form_data.name, form_data.command, form_data.tags, db=db) + updated_prompt = await Prompts.update_prompt_metadata( + prompt.id, form_data.name, form_data.command, form_data.tags, db=db + ) if updated_prompt: return updated_prompt else: @@ -497,7 +503,9 @@ async def update_prompt_access_by_id( @router.post('/id/{prompt_id}/toggle', response_model=Optional[PromptModel]) -async def toggle_prompt_active(prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def toggle_prompt_active( + prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): prompt = await Prompts.get_prompt_by_id(prompt_id, db=db) if not prompt: @@ -537,7 +545,9 @@ async def toggle_prompt_active(prompt_id: str, user=Depends(get_verified_user), @router.delete('/id/{prompt_id}/delete', response_model=bool) -async def delete_prompt_by_id(prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def delete_prompt_by_id( + prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): prompt = await Prompts.get_prompt_by_id(prompt_id, db=db) if not prompt: diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py index edf9c8b5ef..041a866e37 100644 --- a/backend/open_webui/routers/tools.py +++ b/backend/open_webui/routers/tools.py @@ -622,7 +622,9 @@ async def delete_tools_by_id( @router.get('/id/{id}/valves', response_model=Optional[dict]) -async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_tools_valves_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): tools = await Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( @@ -775,7 +777,9 @@ async def update_tools_valves_by_id( @router.get('/id/{id}/valves/user', response_model=Optional[dict]) -async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_tools_user_valves_by_id( + id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): tools = await Tools.get_tool_by_id(id, db=db) if not tools: raise HTTPException( diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py index 3253ab3707..336e54c44b 100644 --- a/backend/open_webui/routers/users.py +++ b/backend/open_webui/routers/users.py @@ -273,7 +273,9 @@ async def update_default_user_permissions(request: Request, form_data: UserPermi @router.get('/user/settings', response_model=Optional[UserSettings]) -async def get_user_settings_by_session_user(user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_user_settings_by_session_user( + user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): user = await Users.get_user_by_id(user.id, db=db) if user: return user.settings @@ -468,7 +470,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_admin_user), db: AsyncSe @router.get('/{user_id}/info', response_model=UserInfoResponse) -async def get_user_info_by_id(user_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)): +async def get_user_info_by_id( + user_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session) +): user = await Users.get_user_by_id(user_id, db=db) if user: groups = await Groups.get_groups_by_member_id(user_id, db=db) @@ -487,7 +491,9 @@ async def get_user_info_by_id(user_id: str, user=Depends(get_verified_user), db: @router.get('/{user_id}/oauth/sessions') -async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)): +async def get_user_oauth_sessions_by_id( + user_id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session) +): sessions = await OAuthSessions.get_sessions_by_user_id(user_id, db=db) if sessions and len(sessions) > 0: return sessions @@ -685,5 +691,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user), db: Asyn @router.get('/{user_id}/groups') -async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)): +async def get_user_groups_by_id( + user_id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session) +): return await Groups.get_groups_by_member_id(user_id, db=db) diff --git a/backend/open_webui/utils/auth.py b/backend/open_webui/utils/auth.py index 9ac7524411..e0f331a9df 100644 --- a/backend/open_webui/utils/auth.py +++ b/backend/open_webui/utils/auth.py @@ -20,7 +20,6 @@ from pytz import UTC from typing import Optional, Union, List, Dict - from open_webui.utils.access_control import has_permission from open_webui.models.users import Users from open_webui.models.auths import Auths @@ -238,9 +237,7 @@ async def is_valid_token(request, decoded) -> bool: # Per-user revocation (OIDC back-channel logout) user_id = decoded.get('id') if user_id: - revoked_at = await request.app.state.redis.get( - f'{REDIS_KEY_PREFIX}:auth:user:{user_id}:revoked_at' - ) + revoked_at = await request.app.state.redis.get(f'{REDIS_KEY_PREFIX}:auth:user:{user_id}:revoked_at') if revoked_at: try: revoked_at_ts = int(revoked_at) @@ -385,6 +382,7 @@ async def get_current_user( # Refresh the user's last active timestamp # Fire-and-forget via asyncio.create_task to avoid blocking import asyncio + asyncio.create_task(Users.update_last_active_by_id(user.id)) return user else: @@ -432,15 +430,10 @@ async def get_current_user_by_api_key(request, api_key: str): # (Authorization header, cookie, x-api-key header, etc.). if request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS: allowed_paths = [ - path.strip() - for path in str(request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS).split(',') - if path.strip() + path.strip() for path in str(request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS).split(',') if path.strip() ] request_path = request.url.path - is_allowed = any( - request_path == allowed or request_path.startswith(allowed + '/') - for allowed in allowed_paths - ) + is_allowed = any(request_path == allowed or request_path.startswith(allowed + '/') for allowed in allowed_paths) if not is_allowed: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/backend/open_webui/utils/files.py b/backend/open_webui/utils/files.py index ef7900a1ce..21e3af5752 100644 --- a/backend/open_webui/utils/files.py +++ b/backend/open_webui/utils/files.py @@ -88,7 +88,7 @@ async def convert_markdown_base64_images(request, content: str, metadata, user): last_end = 0 for match in MARKDOWN_IMAGE_URL_PATTERN.finditer(content): - result_parts.append(content[last_end:match.start()]) + result_parts.append(content[last_end : match.start()]) base64_string = match.group(2) if len(base64_string) > MIN_REPLACEMENT_URL_LENGTH: url = await get_image_url_from_base64(request, base64_string, metadata, user) diff --git a/backend/open_webui/utils/oauth.py b/backend/open_webui/utils/oauth.py index 535adca5ec..835b8a0035 100644 --- a/backend/open_webui/utils/oauth.py +++ b/backend/open_webui/utils/oauth.py @@ -1785,7 +1785,10 @@ class OAuthManager: log.warning(f'Back-channel logout: no configured provider matches issuer {token_issuer}') return JSONResponse( status_code=400, - content={'error': 'invalid_request', 'error_description': 'No configured provider matches token issuer'}, + content={ + 'error': 'invalid_request', + 'error_description': 'No configured provider matches token issuer', + }, ) # 4. Validate the logout_token signature and claims @@ -1886,5 +1889,7 @@ class OAuthManager: f'(email={user.email}, provider={matched_provider}, sessions_deleted={len(sessions)})' ) - log.info(f'Back-channel logout: completed for {len(users_to_logout)} user(s), {revoked_count} revocation(s) set') + log.info( + f'Back-channel logout: completed for {len(users_to_logout)} user(s), {revoked_count} revocation(s) set' + ) return JSONResponse(status_code=200, content={}) diff --git a/backend/open_webui/utils/redis.py b/backend/open_webui/utils/redis.py index 7a114393b0..cb570cb45a 100644 --- a/backend/open_webui/utils/redis.py +++ b/backend/open_webui/utils/redis.py @@ -194,20 +194,12 @@ def get_redis_connection( connection = None connect_timeout_kwargs = ( - {'socket_connect_timeout': REDIS_SOCKET_CONNECT_TIMEOUT} - if REDIS_SOCKET_CONNECT_TIMEOUT is not None - else {} + {'socket_connect_timeout': REDIS_SOCKET_CONNECT_TIMEOUT} if REDIS_SOCKET_CONNECT_TIMEOUT is not None else {} ) - keepalive_kwargs = ( - {'socket_keepalive': True} if REDIS_SOCKET_KEEPALIVE else {} - ) + keepalive_kwargs = {'socket_keepalive': True} if REDIS_SOCKET_KEEPALIVE else {} - health_check_kwargs = ( - {'health_check_interval': REDIS_HEALTH_CHECK_INTERVAL} - if REDIS_HEALTH_CHECK_INTERVAL - else {} - ) + health_check_kwargs = {'health_check_interval': REDIS_HEALTH_CHECK_INTERVAL} if REDIS_HEALTH_CHECK_INTERVAL else {} if async_mode: import redis.asyncio as redis diff --git a/backend/open_webui/utils/session_pool.py b/backend/open_webui/utils/session_pool.py index 86ffa6cd9d..e91579f4af 100644 --- a/backend/open_webui/utils/session_pool.py +++ b/backend/open_webui/utils/session_pool.py @@ -64,8 +64,7 @@ async def get_session() -> aiohttp.ClientSession: trust_env=True, ) log.info( - 'Created shared aiohttp session pool ' - '(limit=%s, per_host=%s, dns_ttl=%d)', + 'Created shared aiohttp session pool (limit=%s, per_host=%s, dns_ttl=%d)', AIOHTTP_POOL_CONNECTIONS or 'unlimited', AIOHTTP_POOL_CONNECTIONS_PER_HOST or 'unlimited', AIOHTTP_POOL_DNS_TTL, diff --git a/backend/open_webui/utils/tools.py b/backend/open_webui/utils/tools.py index fb2ec8ad30..b071a48b26 100644 --- a/backend/open_webui/utils/tools.py +++ b/backend/open_webui/utils/tools.py @@ -101,7 +101,9 @@ log = logging.getLogger(__name__) # Let no function be called without need, and let what # it yields justify the cost of running it. -async def get_async_tool_function_and_apply_extra_params(function: Callable, extra_params: dict) -> Callable[..., Awaitable]: +async def get_async_tool_function_and_apply_extra_params( + function: Callable, extra_params: dict +) -> Callable[..., Awaitable]: sig = inspect.signature(function) extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters} partial_func = partial(function, **extra_params) @@ -544,7 +546,9 @@ async def get_builtin_tools( # Automation tools - create and manage scheduled automations from chat if is_builtin_tool_enabled('automations') and await has_user_permission('automations'): - builtin_functions.extend([create_automation, update_automation, list_automations, toggle_automation, delete_automation]) + builtin_functions.extend( + [create_automation, update_automation, list_automations, toggle_automation, delete_automation] + ) for func in builtin_functions: callable = await get_async_tool_function_and_apply_extra_params( diff --git a/backend/open_webui/utils/validate.py b/backend/open_webui/utils/validate.py index 4decac5eb7..1e98b41105 100644 --- a/backend/open_webui/utils/validate.py +++ b/backend/open_webui/utils/validate.py @@ -13,19 +13,19 @@ _USER_PROFILE_IMAGE_RE = re.compile(r'^/api/v1/users/[^/?#]+/profile/image$') # regex across megabytes of data on every Pydantic instantiation for zero # security benefit (corrupt base64 simply renders a broken image, same as # a 404 URL). SVG is intentionally excluded: it can carry embedded scripts. -_SAFE_DATA_URI_RE = re.compile( - r'^data:image/(png|jpeg|gif|webp);base64,', re.IGNORECASE -) +_SAFE_DATA_URI_RE = re.compile(r'^data:image/(png|jpeg|gif|webp);base64,', re.IGNORECASE) # Exact relative paths accepted as profile images. These are the only # static-asset paths OWUI itself assigns; no prefix/wildcard matching is # used so that arbitrary relative paths cannot trigger authenticated GETs # against internal endpoints when rendered as ```` sources. -_SAFE_STATIC_PATHS = frozenset({ - '/user.png', - '/favicon.png', - '/static/favicon.png', -}) +_SAFE_STATIC_PATHS = frozenset( + { + '/user.png', + '/favicon.png', + '/static/favicon.png', + } +) def validate_profile_image_url(url: str) -> str: @@ -67,9 +67,7 @@ def validate_profile_image_url(url: str) -> str: # for a URL like http://:80/path with no actual host). if parsed.scheme in ('http', 'https'): if not parsed.hostname: - raise ValueError( - 'Invalid profile image URL: HTTP(S) URLs must include a host.' - ) + raise ValueError('Invalid profile image URL: HTTP(S) URLs must include a host.') return url # Base64-encoded raster images uploaded via the frontend. diff --git a/src/lib/components/automations/AutomationEditor.svelte b/src/lib/components/automations/AutomationEditor.svelte index cb5859286a..463cb0fa70 100644 --- a/src/lib/components/automations/AutomationEditor.svelte +++ b/src/lib/components/automations/AutomationEditor.svelte @@ -20,7 +20,6 @@ type AutomationRunModel } from '$lib/apis/automations'; - import Spinner from '$lib/components/common/Spinner.svelte'; import Tooltip from '$lib/components/common/Tooltip.svelte'; import DeleteConfirmDialog from '$lib/components/common/ConfirmDialog.svelte'; @@ -42,7 +41,6 @@ let model_id = ''; let is_active = true; - let loading = false; let saving = false; let showDeleteConfirm = false; @@ -335,7 +333,6 @@ {$i18n.t('Model')} - diff --git a/src/lib/components/chat/ChatPlaceholder.svelte b/src/lib/components/chat/ChatPlaceholder.svelte index ce54ecd551..e497b6b24d 100644 --- a/src/lib/components/chat/ChatPlaceholder.svelte +++ b/src/lib/components/chat/ChatPlaceholder.svelte @@ -46,11 +46,13 @@ }} > ') - ))} + content={DOMPurify.sanitize( + marked.parse( + sanitizeResponseContent( + models[selectedModelIdx]?.info?.meta?.description ?? '' + ).replaceAll('\n', '
') + ) + )} placement="right" > - {@html DOMPurify.sanitize(marked.parse( - sanitizeResponseContent( - models[selectedModelIdx]?.info?.meta?.description - ).replaceAll('\n', '
') - ))} + {@html DOMPurify.sanitize( + marked.parse( + sanitizeResponseContent( + models[selectedModelIdx]?.info?.meta?.description + ).replaceAll('\n', '
') + ) + )} {#if models[selectedModelIdx]?.info?.meta?.user}
diff --git a/src/lib/components/chat/FileNav/FilePreview.svelte b/src/lib/components/chat/FileNav/FilePreview.svelte index b5a3a30cc4..48a1ccdb88 100644 --- a/src/lib/components/chat/FileNav/FilePreview.svelte +++ b/src/lib/components/chat/FileNav/FilePreview.svelte @@ -399,7 +399,8 @@ {/if}