diff --git a/backend/open_webui/main.py b/backend/open_webui/main.py index 12541ff66e..7611ca454b 100644 --- a/backend/open_webui/main.py +++ b/backend/open_webui/main.py @@ -470,8 +470,11 @@ from open_webui.env import ( WEBUI_SESSION_COOKIE_SECURE, ) from open_webui.internal.db import ScopedSession, engine, get_async_session +from open_webui.models.access_grants import AccessGrants +from open_webui.models.channels import Channels from open_webui.models.chats import ChatForm, Chats from open_webui.models.functions import Functions +from open_webui.models.messages import Messages from open_webui.models.models import Models from open_webui.models.users import UserModel, Users from open_webui.routers import ( @@ -1802,6 +1805,44 @@ async def chat_completion( if metadata.get('chat_id') and user: chat_id = metadata['chat_id'] + + # Gate channel: branch — caller needs write access on the channel + # and the supplied message_id must belong to that channel. + if chat_id.startswith('channel:'): + channel_id = chat_id.removeprefix('channel:') + channel = await Channels.get_channel_by_id(channel_id) + if not channel: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=ERROR_MESSAGES.NOT_FOUND, + ) + if user.role != 'admin': + if channel.type in ['group', 'dm']: + if not await Channels.is_user_channel_member(channel.id, user.id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.DEFAULT(), + ) + else: + if not await AccessGrants.has_access( + user_id=user.id, + resource_type='channel', + resource_id=channel.id, + permission='write', + ): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.DEFAULT(), + ) + target_message_id = list(message_ids.values())[0] if message_ids else None + if target_message_id: + target_message = await Messages.get_message_by_id(target_message_id) + if target_message and target_message.channel_id != channel.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=ERROR_MESSAGES.DEFAULT(), + ) + if not chat_id.startswith('local:') and not chat_id.startswith( 'channel:' ): # temporary/channel chats are not stored