mirror of
https://github.com/open-webui/open-webui.git
synced 2026-06-14 03:30:25 +00:00
chore: format
This commit is contained in:
@@ -431,9 +431,7 @@ except ValueError:
|
||||
# enabled, the kernel sends TCP keepalive probes on idle connections so
|
||||
# half-closed sockets (e.g. after a silent firewall/LB reset or a NIC
|
||||
# flap) are detected before the next command lands on them.
|
||||
REDIS_SOCKET_KEEPALIVE = (
|
||||
os.environ.get('REDIS_SOCKET_KEEPALIVE', 'False').lower() == 'true'
|
||||
)
|
||||
REDIS_SOCKET_KEEPALIVE = os.environ.get('REDIS_SOCKET_KEEPALIVE', 'False').lower() == 'true'
|
||||
|
||||
# How often (in seconds) redis-py should PING an idle pooled connection
|
||||
# before reusing it. Opt-in: defaults to unset (empty string) so behavior
|
||||
|
||||
@@ -214,6 +214,7 @@ if 'sqlite' in ASYNC_SQLALCHEMY_DATABASE_URL:
|
||||
)
|
||||
|
||||
if DATABASE_ENABLE_SQLITE_WAL:
|
||||
|
||||
@event.listens_for(async_engine.sync_engine, 'connect')
|
||||
def _set_sqlite_wal(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
|
||||
@@ -723,6 +723,7 @@ async def lifespan(app: FastAPI):
|
||||
|
||||
# Shutdown: clean up shared resources
|
||||
from open_webui.utils.session_pool import close_session
|
||||
|
||||
await close_session()
|
||||
|
||||
if hasattr(app.state, 'redis_task_command_listener'):
|
||||
@@ -1397,7 +1398,6 @@ app.add_middleware(RedirectMiddleware)
|
||||
app.add_middleware(SecurityHeadersMiddleware)
|
||||
|
||||
|
||||
|
||||
@app.middleware('http')
|
||||
async def commit_session_after_request(request: Request, call_next):
|
||||
response = await call_next(request)
|
||||
@@ -1992,7 +1992,7 @@ async def list_tasks_endpoint(request: Request, user=Depends(get_admin_user)):
|
||||
@app.get('/api/tasks/chat/{chat_id:path}')
|
||||
async def list_tasks_by_chat_id_endpoint(request: Request, chat_id: str, user=Depends(get_verified_user)):
|
||||
if chat_id.startswith('local:'):
|
||||
socket_id = chat_id[len('local:'):]
|
||||
socket_id = chat_id[len('local:') :]
|
||||
owner_id = get_user_id_from_session_pool(socket_id)
|
||||
if owner_id != user.id and user.role != 'admin':
|
||||
return {'task_ids': []}
|
||||
@@ -2010,7 +2010,7 @@ async def list_tasks_by_chat_id_endpoint(request: Request, chat_id: str, user=De
|
||||
@app.post('/api/tasks/chat/{chat_id:path}/stop')
|
||||
async def stop_tasks_by_chat_id_endpoint(request: Request, chat_id: str, user=Depends(get_verified_user)):
|
||||
if chat_id.startswith('local:'):
|
||||
socket_id = chat_id[len('local:'):]
|
||||
socket_id = chat_id[len('local:') :]
|
||||
owner_id = get_user_id_from_session_pool(socket_id)
|
||||
if owner_id != user.id and user.role != 'admin':
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.NOT_FOUND)
|
||||
|
||||
@@ -295,8 +295,7 @@ class AccessGrantsTable:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Check for existing grant
|
||||
result = await db.execute(
|
||||
select(AccessGrant)
|
||||
.filter_by(
|
||||
select(AccessGrant).filter_by(
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
principal_type=principal_type,
|
||||
@@ -334,8 +333,7 @@ class AccessGrantsTable:
|
||||
"""Remove a single access grant."""
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
delete(AccessGrant)
|
||||
.filter_by(
|
||||
delete(AccessGrant).filter_by(
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
principal_type=principal_type,
|
||||
@@ -355,8 +353,7 @@ class AccessGrantsTable:
|
||||
"""Remove all access grants for a resource."""
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
delete(AccessGrant)
|
||||
.filter_by(
|
||||
delete(AccessGrant).filter_by(
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
)
|
||||
@@ -451,8 +448,7 @@ class AccessGrantsTable:
|
||||
"""
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(AccessGrant)
|
||||
.filter_by(
|
||||
select(AccessGrant).filter_by(
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
)
|
||||
@@ -470,8 +466,7 @@ class AccessGrantsTable:
|
||||
"""Get all grants for a specific resource."""
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(AccessGrant)
|
||||
.filter_by(
|
||||
select(AccessGrant).filter_by(
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
)
|
||||
@@ -490,8 +485,7 @@ class AccessGrantsTable:
|
||||
return {}
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(AccessGrant)
|
||||
.filter(
|
||||
select(AccessGrant).filter(
|
||||
AccessGrant.resource_type == resource_type,
|
||||
AccessGrant.resource_id.in_(resource_ids),
|
||||
)
|
||||
@@ -634,8 +628,7 @@ class AccessGrantsTable:
|
||||
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(AccessGrant)
|
||||
.filter_by(
|
||||
select(AccessGrant).filter_by(
|
||||
resource_type=resource_type,
|
||||
resource_id=resource_id,
|
||||
permission=permission,
|
||||
|
||||
@@ -141,7 +141,9 @@ class AuthsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def authenticate_user_by_api_key(self, api_key: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def authenticate_user_by_api_key(
|
||||
self, api_key: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
log.info(f'authenticate_user_by_api_key')
|
||||
# if no api_key, return None
|
||||
if not api_key:
|
||||
@@ -159,9 +161,7 @@ class AuthsTable:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Single JOIN query instead of two separate queries
|
||||
result = await db.execute(
|
||||
select(Auth, User)
|
||||
.join(User, Auth.id == User.id)
|
||||
.filter(Auth.email == email, Auth.active == True)
|
||||
select(Auth, User).join(User, Auth.id == User.id).filter(Auth.email == email, Auth.active == True)
|
||||
)
|
||||
row = result.first()
|
||||
if row:
|
||||
|
||||
@@ -145,9 +145,7 @@ class AutomationTable:
|
||||
|
||||
async def count_by_user(self, user_id: str, db: Optional[AsyncSession] = None) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(func.count()).select_from(Automation).filter_by(user_id=user_id)
|
||||
)
|
||||
result = await db.execute(select(func.count()).select_from(Automation).filter_by(user_id=user_id))
|
||||
return result.scalar()
|
||||
|
||||
async def get_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[AutomationModel]:
|
||||
@@ -185,9 +183,7 @@ class AutomationTable:
|
||||
stmt = stmt.order_by(Automation.created_at.desc())
|
||||
|
||||
# Get total count
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -343,18 +339,14 @@ class AutomationRunTable:
|
||||
.subquery()
|
||||
)
|
||||
result = await db.execute(
|
||||
select(AutomationRun)
|
||||
.join(
|
||||
select(AutomationRun).join(
|
||||
subq,
|
||||
(AutomationRun.automation_id == subq.c.automation_id)
|
||||
& (AutomationRun.created_at == subq.c.max_created),
|
||||
)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
return {
|
||||
row.automation_id: AutomationRunModel.model_validate(row)
|
||||
for row in rows
|
||||
}
|
||||
return {row.automation_id: AutomationRunModel.model_validate(row) for row in rows}
|
||||
|
||||
async def get_by_automation(
|
||||
self,
|
||||
|
||||
@@ -414,9 +414,13 @@ class ChannelTable:
|
||||
all_channels = list(membership_channels) + list(standard_channels)
|
||||
channel_ids = [c.id for c in all_channels]
|
||||
grants_map = await AccessGrants.get_grants_by_resources('channel', channel_ids, db=db)
|
||||
return [await self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels]
|
||||
return [
|
||||
await self._to_channel_model(c, access_grants=grants_map.get(c.id, []), db=db) for c in all_channels
|
||||
]
|
||||
|
||||
async def get_dm_channel_by_user_ids(self, user_ids: list[str], db: Optional[AsyncSession] = None) -> Optional[ChannelModel]:
|
||||
async def get_dm_channel_by_user_ids(
|
||||
self, user_ids: list[str], db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChannelModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Ensure uniqueness in case a list with duplicates is passed
|
||||
unique_user_ids = list(set(user_ids))
|
||||
@@ -462,9 +466,7 @@ class ChannelTable:
|
||||
# 1. Collect all user_ids including groups + inviter
|
||||
requested_users = await self._collect_unique_user_ids(invited_by, user_ids, group_ids)
|
||||
|
||||
result = await db.execute(
|
||||
select(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id)
|
||||
)
|
||||
result = await db.execute(select(ChannelMember.user_id).filter(ChannelMember.channel_id == channel_id))
|
||||
existing_users = {row[0] for row in result.all()}
|
||||
|
||||
new_user_ids = requested_users - existing_users
|
||||
@@ -512,7 +514,9 @@ class ChannelTable:
|
||||
membership = result.scalars().first()
|
||||
return membership is not None
|
||||
|
||||
async def join_channel(self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[ChannelMemberModel]:
|
||||
async def join_channel(
|
||||
self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChannelMemberModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Check if the membership already exists
|
||||
result = await db.execute(
|
||||
@@ -581,11 +585,11 @@ class ChannelTable:
|
||||
membership = result.scalars().first()
|
||||
return ChannelMemberModel.model_validate(membership) if membership else None
|
||||
|
||||
async def get_members_by_channel_id(self, channel_id: str, db: Optional[AsyncSession] = None) -> list[ChannelMemberModel]:
|
||||
async def get_members_by_channel_id(
|
||||
self, channel_id: str, db: Optional[AsyncSession] = None
|
||||
) -> list[ChannelMemberModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(ChannelMember).filter(ChannelMember.channel_id == channel_id)
|
||||
)
|
||||
result = await db.execute(select(ChannelMember).filter(ChannelMember.channel_id == channel_id))
|
||||
memberships = result.scalars().all()
|
||||
return [ChannelMemberModel.model_validate(membership) for membership in memberships]
|
||||
|
||||
@@ -613,7 +617,9 @@ class ChannelTable:
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
async def update_member_last_read_at(self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def update_member_last_read_at(
|
||||
self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(ChannelMember).filter(
|
||||
@@ -658,11 +664,13 @@ class ChannelTable:
|
||||
async def is_user_channel_member(self, channel_id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(ChannelMember).filter(
|
||||
select(ChannelMember)
|
||||
.filter(
|
||||
ChannelMember.channel_id == channel_id,
|
||||
ChannelMember.user_id == user_id,
|
||||
ChannelMember.is_active.is_(True),
|
||||
).limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
membership = result.scalars().first()
|
||||
return membership is not None
|
||||
@@ -726,11 +734,13 @@ class ChannelTable:
|
||||
# --- Case A: group or dm => user must be an active member ---
|
||||
if channel.type in ['group', 'dm']:
|
||||
result = await db.execute(
|
||||
select(ChannelMember).filter(
|
||||
select(ChannelMember)
|
||||
.filter(
|
||||
ChannelMember.channel_id == channel.id,
|
||||
ChannelMember.user_id == user_id,
|
||||
ChannelMember.is_active.is_(True),
|
||||
).limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
membership = result.scalars().first()
|
||||
if membership:
|
||||
@@ -774,11 +784,13 @@ class ChannelTable:
|
||||
# If the channel is a group or dm, read access requires membership (active)
|
||||
if channel.type in ['group', 'dm']:
|
||||
result = await db.execute(
|
||||
select(ChannelMember).filter(
|
||||
select(ChannelMember)
|
||||
.filter(
|
||||
ChannelMember.channel_id == id,
|
||||
ChannelMember.user_id == user_id,
|
||||
ChannelMember.is_active.is_(True),
|
||||
).limit(1)
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
membership = result.scalars().first()
|
||||
if membership:
|
||||
@@ -863,9 +875,7 @@ class ChannelTable:
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id)
|
||||
)
|
||||
result = await db.execute(select(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id))
|
||||
channel_file = result.scalars().first()
|
||||
if not channel_file:
|
||||
return False
|
||||
@@ -878,7 +888,9 @@ class ChannelTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def remove_file_from_channel_by_id(self, channel_id: str, file_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def remove_file_from_channel_by_id(
|
||||
self, channel_id: str, file_id: str, db: Optional[AsyncSession] = None
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(delete(ChannelFile).filter_by(channel_id=channel_id, file_id=file_id))
|
||||
@@ -921,13 +933,17 @@ class ChannelTable:
|
||||
await db.commit()
|
||||
return webhook
|
||||
|
||||
async def get_webhooks_by_channel_id(self, channel_id: str, db: Optional[AsyncSession] = None) -> list[ChannelWebhookModel]:
|
||||
async def get_webhooks_by_channel_id(
|
||||
self, channel_id: str, db: Optional[AsyncSession] = None
|
||||
) -> list[ChannelWebhookModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(ChannelWebhook).filter(ChannelWebhook.channel_id == channel_id))
|
||||
webhooks = result.scalars().all()
|
||||
return [ChannelWebhookModel.model_validate(w) for w in webhooks]
|
||||
|
||||
async def get_webhook_by_id(self, webhook_id: str, db: Optional[AsyncSession] = None) -> Optional[ChannelWebhookModel]:
|
||||
async def get_webhook_by_id(
|
||||
self, webhook_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChannelWebhookModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(ChannelWebhook).filter(ChannelWebhook.id == webhook_id))
|
||||
webhook = result.scalars().first()
|
||||
|
||||
@@ -272,13 +272,10 @@ class ChatMessageTable:
|
||||
"""Get distinct chat_ids that used a specific model."""
|
||||
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = (
|
||||
select(
|
||||
ChatMessage.chat_id,
|
||||
func.max(ChatMessage.created_at).label('last_message_at'),
|
||||
)
|
||||
.filter(ChatMessage.model_id == model_id)
|
||||
)
|
||||
stmt = select(
|
||||
ChatMessage.chat_id,
|
||||
func.max(ChatMessage.created_at).label('last_message_at'),
|
||||
).filter(ChatMessage.model_id == model_id)
|
||||
if start_date:
|
||||
stmt = stmt.filter(ChatMessage.created_at >= start_date)
|
||||
if end_date:
|
||||
@@ -313,13 +310,10 @@ class ChatMessageTable:
|
||||
async with get_async_db_context(db) as db:
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage.model_id, func.count(ChatMessage.id).label('count'))
|
||||
.filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
stmt = select(ChatMessage.model_id, func.count(ChatMessage.id).label('count')).filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -365,19 +359,16 @@ class ChatMessageTable:
|
||||
else:
|
||||
raise NotImplementedError(f'Unsupported dialect: {dialect}')
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
ChatMessage.model_id,
|
||||
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
|
||||
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
|
||||
func.count(ChatMessage.id).label('message_count'),
|
||||
)
|
||||
.filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
ChatMessage.usage.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
stmt = select(
|
||||
ChatMessage.model_id,
|
||||
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
|
||||
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
|
||||
func.count(ChatMessage.id).label('message_count'),
|
||||
).filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
ChatMessage.usage.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -430,19 +421,16 @@ class ChatMessageTable:
|
||||
else:
|
||||
raise NotImplementedError(f'Unsupported dialect: {dialect}')
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
ChatMessage.user_id,
|
||||
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
|
||||
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
|
||||
func.count(ChatMessage.id).label('message_count'),
|
||||
)
|
||||
.filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.user_id.isnot(None),
|
||||
ChatMessage.usage.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
stmt = select(
|
||||
ChatMessage.user_id,
|
||||
func.coalesce(func.sum(input_tokens), 0).label('input_tokens'),
|
||||
func.coalesce(func.sum(output_tokens), 0).label('output_tokens'),
|
||||
func.count(ChatMessage.id).label('message_count'),
|
||||
).filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.user_id.isnot(None),
|
||||
ChatMessage.usage.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -476,9 +464,8 @@ class ChatMessageTable:
|
||||
async with get_async_db_context(db) as db:
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage.user_id, func.count(ChatMessage.id).label('count'))
|
||||
.filter(~ChatMessage.user_id.like('shared-%'))
|
||||
stmt = select(ChatMessage.user_id, func.count(ChatMessage.id).label('count')).filter(
|
||||
~ChatMessage.user_id.like('shared-%')
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -503,9 +490,8 @@ class ChatMessageTable:
|
||||
async with get_async_db_context(db) as db:
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage.chat_id, func.count(ChatMessage.id).label('count'))
|
||||
.filter(~ChatMessage.user_id.like('shared-%'))
|
||||
stmt = select(ChatMessage.chat_id, func.count(ChatMessage.id).label('count')).filter(
|
||||
~ChatMessage.user_id.like('shared-%')
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -532,13 +518,10 @@ class ChatMessageTable:
|
||||
from datetime import datetime, timedelta
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage.created_at, ChatMessage.model_id)
|
||||
.filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
@@ -582,13 +565,10 @@ class ChatMessageTable:
|
||||
async with get_async_db_context(db) as db:
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage.created_at, ChatMessage.model_id)
|
||||
.filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
stmt = select(ChatMessage.created_at, ChatMessage.model_id).filter(
|
||||
ChatMessage.role == 'assistant',
|
||||
ChatMessage.model_id.isnot(None),
|
||||
~ChatMessage.user_id.like('shared-%'),
|
||||
)
|
||||
|
||||
if start_date:
|
||||
|
||||
@@ -292,7 +292,9 @@ class ChatTable:
|
||||
|
||||
return changed
|
||||
|
||||
async def insert_new_chat(self, user_id: str, form_data: ChatForm, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def insert_new_chat(
|
||||
self, user_id: str, form_data: ChatForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
id = str(uuid.uuid4())
|
||||
chat = ChatModel(
|
||||
@@ -551,7 +553,9 @@ class ChatTable:
|
||||
await self.update_chat_by_id(id, chat, db=db)
|
||||
return message_files
|
||||
|
||||
async def insert_shared_chat_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def insert_shared_chat_by_chat_id(
|
||||
self, chat_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Get the existing chat to share
|
||||
chat = await db.get(Chat, chat_id)
|
||||
@@ -585,7 +589,9 @@ class ChatTable:
|
||||
await db.commit()
|
||||
return shared_chat if shared_result else None
|
||||
|
||||
async def update_shared_chat_by_chat_id(self, chat_id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def update_shared_chat_by_chat_id(
|
||||
self, chat_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, chat_id)
|
||||
@@ -689,7 +695,9 @@ class ChatTable:
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by(user_id=user_id, archived=True)
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at).filter_by(
|
||||
user_id=user_id, archived=True
|
||||
)
|
||||
|
||||
if filter:
|
||||
query_key = filter.get('query')
|
||||
@@ -740,7 +748,11 @@ class ChatTable:
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> list[SharedChatResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.share_id, Chat.updated_at, Chat.created_at).filter_by(user_id=user_id).filter(Chat.share_id.isnot(None))
|
||||
stmt = (
|
||||
select(Chat.id, Chat.title, Chat.share_id, Chat.updated_at, Chat.created_at)
|
||||
.filter_by(user_id=user_id)
|
||||
.filter(Chat.share_id.isnot(None))
|
||||
)
|
||||
|
||||
if filter:
|
||||
query_key = filter.get('query')
|
||||
@@ -793,7 +805,9 @@ class ChatTable:
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(user_id=user_id)
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
user_id=user_id
|
||||
)
|
||||
if not include_archived:
|
||||
stmt = stmt.filter_by(archived=False)
|
||||
|
||||
@@ -846,7 +860,9 @@ class ChatTable:
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(user_id=user_id)
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if not include_folders:
|
||||
stmt = stmt.filter_by(folder_id=None)
|
||||
@@ -889,10 +905,7 @@ class ChatTable:
|
||||
) -> list[ChatModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Chat)
|
||||
.filter(Chat.id.in_(chat_ids))
|
||||
.filter_by(archived=False)
|
||||
.order_by(Chat.updated_at.desc())
|
||||
select(Chat).filter(Chat.id.in_(chat_ids)).filter_by(archived=False).order_by(Chat.updated_at.desc())
|
||||
)
|
||||
all_chats = result.scalars().all()
|
||||
return [ChatModel.model_validate(chat) for chat in all_chats]
|
||||
@@ -925,7 +938,9 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[ChatModel]:
|
||||
async def get_chat_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ChatModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Chat).filter_by(id=id, user_id=user_id))
|
||||
@@ -941,9 +956,7 @@ class ChatTable:
|
||||
"""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(exists().where(and_(Chat.id == id, Chat.user_id == user_id)))
|
||||
)
|
||||
result = await db.execute(select(exists().where(and_(Chat.id == id, Chat.user_id == user_id))))
|
||||
return result.scalar()
|
||||
except Exception:
|
||||
return False
|
||||
@@ -997,9 +1010,7 @@ class ChatTable:
|
||||
else:
|
||||
stmt = stmt.order_by(Chat.updated_at.desc(), Chat.id)
|
||||
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip is not None:
|
||||
@@ -1017,7 +1028,9 @@ class ChatTable:
|
||||
}
|
||||
)
|
||||
|
||||
async def get_pinned_chats_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[ChatTitleIdResponse]:
|
||||
async def get_pinned_chats_by_user_id(
|
||||
self, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at)
|
||||
@@ -1060,7 +1073,9 @@ class ChatTable:
|
||||
search_text = sanitize_text_for_db(search_text).lower().strip()
|
||||
|
||||
if not search_text:
|
||||
return await self.get_chat_list_by_user_id(user_id, include_archived, filter={}, skip=skip, limit=limit, db=db)
|
||||
return await self.get_chat_list_by_user_id(
|
||||
user_id, include_archived, filter={}, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
search_text_words = search_text.split(' ')
|
||||
|
||||
@@ -1305,7 +1320,9 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_chat_tags_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> list[TagModel]:
|
||||
async def get_chat_tags_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> list[TagModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat = await db.get(Chat, id)
|
||||
tag_ids = chat.meta.get('tags', [])
|
||||
@@ -1320,7 +1337,9 @@ class ChatTable:
|
||||
db: Optional[AsyncSession] = None,
|
||||
) -> list[ChatTitleIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(user_id=user_id)
|
||||
stmt = select(Chat.id, Chat.title, Chat.updated_at, Chat.created_at, Chat.last_read_at).filter_by(
|
||||
user_id=user_id
|
||||
)
|
||||
tag_id = tag_name.replace(' ', '_').lower()
|
||||
|
||||
bind = await db.connection()
|
||||
@@ -1378,7 +1397,9 @@ class ChatTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def count_chats_by_tag_name_and_user_id(self, tag_name: str, user_id: str, db: Optional[AsyncSession] = None) -> int:
|
||||
async def count_chats_by_tag_name_and_user_id(
|
||||
self, tag_name: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
stmt = select(func.count(Chat.id)).filter_by(user_id=user_id, archived=False)
|
||||
tag_id = tag_name.replace(' ', '_').lower()
|
||||
@@ -1424,11 +1445,11 @@ class ChatTable:
|
||||
orphans.append(tag_id)
|
||||
await Tags.delete_tags_by_ids_and_user_id(orphans, user_id, db=db)
|
||||
|
||||
async def count_chats_by_folder_id_and_user_id(self, folder_id: str, user_id: str, db: Optional[AsyncSession] = None) -> int:
|
||||
async def count_chats_by_folder_id_and_user_id(
|
||||
self, folder_id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> int:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id)
|
||||
)
|
||||
result = await db.execute(select(func.count(Chat.id)).filter_by(user_id=user_id, folder_id=folder_id))
|
||||
count = result.scalar()
|
||||
|
||||
log.info(f"Count of chats for folder '{folder_id}': {count}")
|
||||
@@ -1470,9 +1491,7 @@ class ChatTable:
|
||||
async def delete_chat_by_id(self, id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(
|
||||
update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)
|
||||
)
|
||||
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
|
||||
await db.execute(delete(ChatMessage).filter_by(chat_id=id))
|
||||
await db.execute(delete(Chat).filter_by(id=id))
|
||||
await db.commit()
|
||||
@@ -1484,9 +1503,7 @@ class ChatTable:
|
||||
async def delete_chat_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(
|
||||
update(AutomationRun).filter_by(chat_id=id).values(chat_id=None)
|
||||
)
|
||||
await db.execute(update(AutomationRun).filter_by(chat_id=id).values(chat_id=None))
|
||||
await db.execute(delete(ChatMessage).filter_by(chat_id=id))
|
||||
await db.execute(delete(Chat).filter_by(id=id, user_id=user_id))
|
||||
await db.commit()
|
||||
@@ -1502,7 +1519,9 @@ class ChatTable:
|
||||
|
||||
chat_id_subquery = select(Chat.id).filter_by(user_id=user_id).scalar_subquery()
|
||||
await db.execute(
|
||||
update(AutomationRun).filter(AutomationRun.chat_id.in_(select(Chat.id).filter_by(user_id=user_id))).values(chat_id=None)
|
||||
update(AutomationRun)
|
||||
.filter(AutomationRun.chat_id.in_(select(Chat.id).filter_by(user_id=user_id)))
|
||||
.values(chat_id=None)
|
||||
)
|
||||
await db.execute(
|
||||
delete(ChatMessage).filter(ChatMessage.chat_id.in_(select(Chat.id).filter_by(user_id=user_id)))
|
||||
@@ -1514,16 +1533,16 @@ class ChatTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def delete_chats_by_user_id_and_folder_id(self, user_id: str, folder_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_chats_by_user_id_and_folder_id(
|
||||
self, user_id: str, folder_id: str, db: Optional[AsyncSession] = None
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
chat_ids_stmt = select(Chat.id).filter_by(user_id=user_id, folder_id=folder_id)
|
||||
await db.execute(
|
||||
update(AutomationRun).filter(AutomationRun.chat_id.in_(chat_ids_stmt)).values(chat_id=None)
|
||||
)
|
||||
await db.execute(
|
||||
delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt))
|
||||
)
|
||||
await db.execute(delete(ChatMessage).filter(ChatMessage.chat_id.in_(chat_ids_stmt)))
|
||||
await db.execute(delete(Chat).filter_by(user_id=user_id, folder_id=folder_id))
|
||||
await db.commit()
|
||||
|
||||
@@ -1619,9 +1638,7 @@ class ChatTable:
|
||||
) -> list[ChatFileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(ChatFile)
|
||||
.filter_by(chat_id=chat_id, message_id=message_id)
|
||||
.order_by(ChatFile.created_at.asc())
|
||||
select(ChatFile).filter_by(chat_id=chat_id, message_id=message_id).order_by(ChatFile.created_at.asc())
|
||||
)
|
||||
all_chat_files = result.scalars().all()
|
||||
return [ChatFileModel.model_validate(chat_file) for chat_file in all_chat_files]
|
||||
|
||||
@@ -251,9 +251,7 @@ class FeedbackTable:
|
||||
stmt = stmt.order_by(Feedback.created_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -280,8 +278,9 @@ class FeedbackTable:
|
||||
async def get_all_feedback_ids(self, db: Optional[AsyncSession] = None) -> list[FeedbackIdResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Feedback.id, Feedback.user_id, Feedback.created_at, Feedback.updated_at)
|
||||
.order_by(Feedback.updated_at.desc())
|
||||
select(Feedback.id, Feedback.user_id, Feedback.created_at, Feedback.updated_at).order_by(
|
||||
Feedback.updated_at.desc()
|
||||
)
|
||||
)
|
||||
return [
|
||||
FeedbackIdResponse(
|
||||
@@ -378,16 +377,12 @@ class FeedbackTable:
|
||||
|
||||
async def get_feedbacks_by_type(self, type: str, db: Optional[AsyncSession] = None) -> list[FeedbackModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc())
|
||||
)
|
||||
result = await db.execute(select(Feedback).filter_by(type=type).order_by(Feedback.updated_at.desc()))
|
||||
return [FeedbackModel.model_validate(feedback) for feedback in result.scalars().all()]
|
||||
|
||||
async def get_feedbacks_by_user_id(self, user_id: str, db: Optional[AsyncSession] = None) -> list[FeedbackModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc())
|
||||
)
|
||||
result = await db.execute(select(Feedback).filter_by(user_id=user_id).order_by(Feedback.updated_at.desc()))
|
||||
return [FeedbackModel.model_validate(feedback) for feedback in result.scalars().all()]
|
||||
|
||||
async def update_feedback_by_id(
|
||||
|
||||
@@ -125,7 +125,9 @@ class FileUpdateForm(BaseModel):
|
||||
|
||||
|
||||
class FilesTable:
|
||||
async def insert_new_file(self, user_id: str, form_data: FileForm, db: Optional[AsyncSession] = None) -> Optional[FileModel]:
|
||||
async def insert_new_file(
|
||||
self, user_id: str, form_data: FileForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
file_data = form_data.model_dump()
|
||||
|
||||
@@ -167,7 +169,9 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_file_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[FileModel]:
|
||||
async def get_file_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id, user_id=user_id))
|
||||
@@ -179,7 +183,9 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_file_metadata_by_id(self, id: str, db: Optional[AsyncSession] = None) -> Optional[FileMetadataResponse]:
|
||||
async def get_file_metadata_by_id(
|
||||
self, id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileMetadataResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
file = await db.get(File, id)
|
||||
@@ -211,12 +217,12 @@ class FilesTable:
|
||||
|
||||
async def get_files_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc())
|
||||
)
|
||||
result = await db.execute(select(File).filter(File.id.in_(ids)).order_by(File.updated_at.desc()))
|
||||
return [FileModel.model_validate(file) for file in result.scalars().all()]
|
||||
|
||||
async def get_file_metadatas_by_ids(self, ids: list[str], db: Optional[AsyncSession] = None) -> list[FileMetadataResponse]:
|
||||
async def get_file_metadatas_by_ids(
|
||||
self, ids: list[str], db: Optional[AsyncSession] = None
|
||||
) -> list[FileMetadataResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(File.id, File.hash, File.meta, File.created_at, File.updated_at)
|
||||
@@ -251,18 +257,11 @@ class FilesTable:
|
||||
if user_id:
|
||||
stmt = stmt.filter_by(user_id=user_id)
|
||||
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
result = await db.execute(
|
||||
stmt.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit)
|
||||
)
|
||||
items = [
|
||||
FileModelResponse.model_validate(file, from_attributes=True)
|
||||
for file in result.scalars().all()
|
||||
]
|
||||
result = await db.execute(stmt.order_by(File.updated_at.desc(), File.id.desc()).offset(skip).limit(limit))
|
||||
items = [FileModelResponse.model_validate(file, from_attributes=True) for file in result.scalars().all()]
|
||||
|
||||
return FileListResponse(items=items, total=total)
|
||||
|
||||
@@ -320,9 +319,7 @@ class FilesTable:
|
||||
if pattern != '%':
|
||||
stmt = stmt.filter(File.filename.ilike(pattern, escape='\\'))
|
||||
|
||||
result = await db.execute(
|
||||
stmt.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit)
|
||||
)
|
||||
result = await db.execute(stmt.order_by(File.created_at.desc(), File.id.desc()).offset(skip).limit(limit))
|
||||
return [FileModel.model_validate(file) for file in result.scalars().all()]
|
||||
|
||||
async def update_file_by_id(
|
||||
@@ -349,7 +346,9 @@ class FilesTable:
|
||||
log.exception(f'Error updating file completely by id: {e}')
|
||||
return None
|
||||
|
||||
async def update_file_hash_by_id(self, id: str, hash: Optional[str], db: Optional[AsyncSession] = None) -> Optional[FileModel]:
|
||||
async def update_file_hash_by_id(
|
||||
self, id: str, hash: Optional[str], db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id))
|
||||
@@ -362,7 +361,9 @@ class FilesTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_file_data_by_id(self, id: str, data: dict, db: Optional[AsyncSession] = None) -> Optional[FileModel]:
|
||||
async def update_file_data_by_id(
|
||||
self, id: str, data: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id))
|
||||
@@ -374,7 +375,9 @@ class FilesTable:
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
async def update_file_metadata_by_id(self, id: str, meta: dict, db: Optional[AsyncSession] = None) -> Optional[FileModel]:
|
||||
async def update_file_metadata_by_id(
|
||||
self, id: str, meta: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FileModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
result = await db.execute(select(File).filter_by(id=id))
|
||||
|
||||
@@ -171,9 +171,7 @@ class FolderTable:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Check if folder exists
|
||||
result = await db.execute(
|
||||
select(Folder)
|
||||
.filter_by(parent_id=parent_id, user_id=user_id)
|
||||
.filter(Folder.name.ilike(name))
|
||||
select(Folder).filter_by(parent_id=parent_id, user_id=user_id).filter(Folder.name.ilike(name))
|
||||
)
|
||||
folder = result.scalars().first()
|
||||
|
||||
@@ -235,8 +233,7 @@ class FolderTable:
|
||||
form_data = form_data.model_dump(exclude_unset=True)
|
||||
|
||||
existing_result = await db.execute(
|
||||
select(Folder)
|
||||
.filter_by(
|
||||
select(Folder).filter_by(
|
||||
name=form_data.get('name'),
|
||||
parent_id=folder.parent_id,
|
||||
user_id=user_id,
|
||||
@@ -289,7 +286,9 @@ class FolderTable:
|
||||
log.error(f'update_folder: {e}')
|
||||
return
|
||||
|
||||
async def delete_folder_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> list[str]:
|
||||
async def delete_folder_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> list[str]:
|
||||
try:
|
||||
folder_ids = []
|
||||
async with get_async_db_context(db) as db:
|
||||
|
||||
@@ -160,7 +160,9 @@ class FunctionsTable:
|
||||
for func in functions:
|
||||
if func.id in existing_ids:
|
||||
await db.execute(
|
||||
update(Function).filter_by(id=func.id).values(
|
||||
update(Function)
|
||||
.filter_by(id=func.id)
|
||||
.values(
|
||||
**func.model_dump(),
|
||||
user_id=user_id,
|
||||
updated_at=int(time.time()),
|
||||
@@ -233,9 +235,7 @@ class FunctionsTable:
|
||||
|
||||
async def get_function_list(self, db: Optional[AsyncSession] = None) -> list[FunctionUserResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Function).order_by(Function.updated_at.desc())
|
||||
)
|
||||
result = await db.execute(select(Function).order_by(Function.updated_at.desc()))
|
||||
functions = result.scalars().all()
|
||||
user_ids = list(set(func.user_id for func in functions))
|
||||
|
||||
@@ -261,7 +261,9 @@ class FunctionsTable:
|
||||
for func in functions
|
||||
]
|
||||
|
||||
async def get_functions_by_type(self, type: str, active_only=False, db: Optional[AsyncSession] = None) -> list[FunctionModel]:
|
||||
async def get_functions_by_type(
|
||||
self, type: str, active_only=False, db: Optional[AsyncSession] = None
|
||||
) -> list[FunctionModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
if active_only:
|
||||
result = await db.execute(select(Function).filter_by(type=type, is_active=True))
|
||||
@@ -342,7 +344,9 @@ class FunctionsTable:
|
||||
log.exception(f'Error updating function metadata by id {id}: {e}')
|
||||
return None
|
||||
|
||||
async def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[dict]:
|
||||
async def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = await Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
@@ -381,11 +385,15 @@ class FunctionsTable:
|
||||
log.exception(f'Error updating user valves by id {id} and user_id {user_id}: {e}')
|
||||
return None
|
||||
|
||||
async def update_function_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[FunctionModel]:
|
||||
async def update_function_by_id(
|
||||
self, id: str, updated: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[FunctionModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
await db.execute(
|
||||
update(Function).filter_by(id=id).values(
|
||||
update(Function)
|
||||
.filter_by(id=id)
|
||||
.values(
|
||||
**updated,
|
||||
updated_at=int(time.time()),
|
||||
)
|
||||
|
||||
@@ -261,12 +261,10 @@ class GroupTable:
|
||||
|
||||
if 'share' in filter:
|
||||
share_value = filter['share']
|
||||
stmt = stmt.filter(Group.data.op('->>') ('share') == str(share_value))
|
||||
stmt = stmt.filter(Group.data.op('->>')('share') == str(share_value))
|
||||
|
||||
# Get total count
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
member_count = (
|
||||
@@ -348,7 +346,9 @@ class GroupTable:
|
||||
|
||||
return [m[0] for m in members]
|
||||
|
||||
async def get_group_user_ids_by_ids(self, group_ids: list[str], db: Optional[AsyncSession] = None) -> dict[str, list[str]]:
|
||||
async def get_group_user_ids_by_ids(
|
||||
self, group_ids: list[str], db: Optional[AsyncSession] = None
|
||||
) -> dict[str, list[str]]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(GroupMember.group_id, GroupMember.user_id).filter(GroupMember.group_id.in_(group_ids))
|
||||
@@ -362,7 +362,9 @@ class GroupTable:
|
||||
|
||||
return group_user_ids
|
||||
|
||||
async def set_group_user_ids_by_id(self, group_id: str, user_ids: list[str], db: Optional[AsyncSession] = None) -> None:
|
||||
async def set_group_user_ids_by_id(
|
||||
self, group_id: str, user_ids: list[str], db: Optional[AsyncSession] = None
|
||||
) -> None:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Delete existing members
|
||||
await db.execute(delete(GroupMember).filter(GroupMember.group_id == group_id))
|
||||
@@ -411,7 +413,9 @@ class GroupTable:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(
|
||||
update(Group).filter_by(id=id).values(
|
||||
update(Group)
|
||||
.filter_by(id=id)
|
||||
.values(
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
updated_at=int(time.time()),
|
||||
)
|
||||
@@ -455,14 +459,10 @@ class GroupTable:
|
||||
# Remove the user from each group
|
||||
for group in groups:
|
||||
await db.execute(
|
||||
delete(GroupMember).filter(
|
||||
GroupMember.group_id == group.id, GroupMember.user_id == user_id
|
||||
)
|
||||
delete(GroupMember).filter(GroupMember.group_id == group.id, GroupMember.user_id == user_id)
|
||||
)
|
||||
|
||||
await db.execute(
|
||||
update(Group).filter_by(id=group.id).values(updated_at=int(time.time()))
|
||||
)
|
||||
await db.execute(update(Group).filter_by(id=group.id).values(updated_at=int(time.time())))
|
||||
|
||||
await db.commit()
|
||||
return True
|
||||
@@ -507,7 +507,9 @@ class GroupTable:
|
||||
continue
|
||||
return new_groups
|
||||
|
||||
async def sync_groups_by_group_names(self, user_id: str, group_names: list[str], db: Optional[AsyncSession] = None) -> bool:
|
||||
async def sync_groups_by_group_names(
|
||||
self, user_id: str, group_names: list[str], db: Optional[AsyncSession] = None
|
||||
) -> bool:
|
||||
async with get_async_db_context(db) as db:
|
||||
try:
|
||||
now = int(time.time())
|
||||
@@ -538,9 +540,7 @@ class GroupTable:
|
||||
)
|
||||
)
|
||||
|
||||
await db.execute(
|
||||
update(Group).filter(Group.id.in_(groups_to_remove)).values(updated_at=now)
|
||||
)
|
||||
await db.execute(update(Group).filter(Group.id.in_(groups_to_remove)).values(updated_at=now))
|
||||
|
||||
# 5. Bulk insert missing memberships
|
||||
for group_id in groups_to_add:
|
||||
@@ -555,9 +555,7 @@ class GroupTable:
|
||||
)
|
||||
|
||||
if groups_to_add:
|
||||
await db.execute(
|
||||
update(Group).filter(Group.id.in_(groups_to_add)).values(updated_at=now)
|
||||
)
|
||||
await db.execute(update(Group).filter(Group.id.in_(groups_to_add)).values(updated_at=now))
|
||||
|
||||
await db.commit()
|
||||
return True
|
||||
|
||||
@@ -196,11 +196,13 @@ class KnowledgeTable:
|
||||
knowledge_bases.append(
|
||||
KnowledgeUserModel.model_validate(
|
||||
{
|
||||
**(await self._to_knowledge_model(
|
||||
knowledge,
|
||||
access_grants=grants_map.get(knowledge.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_knowledge_model(
|
||||
knowledge,
|
||||
access_grants=grants_map.get(knowledge.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
@@ -249,9 +251,7 @@ class KnowledgeTable:
|
||||
|
||||
stmt = stmt.order_by(Knowledge.updated_at.desc(), Knowledge.id.asc())
|
||||
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
if skip:
|
||||
stmt = stmt.offset(skip)
|
||||
@@ -269,11 +269,13 @@ class KnowledgeTable:
|
||||
knowledge_bases.append(
|
||||
KnowledgeUserModel.model_validate(
|
||||
{
|
||||
**(await self._to_knowledge_model(
|
||||
knowledge_base,
|
||||
access_grants=grants_map.get(knowledge_base.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_knowledge_model(
|
||||
knowledge_base,
|
||||
access_grants=grants_map.get(knowledge_base.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
'user': (UserModel.model_validate(user).model_dump() if user else None),
|
||||
}
|
||||
)
|
||||
@@ -321,9 +323,7 @@ class KnowledgeTable:
|
||||
stmt = stmt.order_by(File.updated_at.desc(), File.id.asc())
|
||||
|
||||
# Count before pagination
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -490,9 +490,7 @@ class KnowledgeTable:
|
||||
stmt = stmt.order_by(primary_sort, File.id.asc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -530,7 +528,9 @@ class KnowledgeTable:
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def get_file_metadatas_by_id(self, knowledge_id: str, db: Optional[AsyncSession] = None) -> list[FileMetadataResponse]:
|
||||
async def get_file_metadatas_by_id(
|
||||
self, knowledge_id: str, db: Optional[AsyncSession] = None
|
||||
) -> list[FileMetadataResponse]:
|
||||
try:
|
||||
files = await self.get_files_by_id(knowledge_id, db=db)
|
||||
return [FileMetadataResponse(**file.model_dump()) for file in files]
|
||||
@@ -579,7 +579,9 @@ class KnowledgeTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def remove_file_from_knowledge_by_id(self, knowledge_id: str, file_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def remove_file_from_knowledge_by_id(
|
||||
self, knowledge_id: str, file_id: str, db: Optional[AsyncSession] = None
|
||||
) -> bool:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(delete(KnowledgeFile).filter_by(knowledge_id=knowledge_id, file_id=file_id))
|
||||
@@ -596,9 +598,7 @@ class KnowledgeTable:
|
||||
await db.commit()
|
||||
|
||||
# Update the knowledge entry's updated_at timestamp
|
||||
await db.execute(
|
||||
update(Knowledge).filter_by(id=id).values(updated_at=int(time.time()))
|
||||
)
|
||||
await db.execute(update(Knowledge).filter_by(id=id).values(updated_at=int(time.time())))
|
||||
await db.commit()
|
||||
|
||||
return await self.get_knowledge_by_id(id=id, db=db)
|
||||
@@ -616,7 +616,9 @@ class KnowledgeTable:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(
|
||||
update(Knowledge).filter_by(id=id).values(
|
||||
update(Knowledge)
|
||||
.filter_by(id=id)
|
||||
.values(
|
||||
**form_data.model_dump(exclude={'access_grants'}),
|
||||
updated_at=int(time.time()),
|
||||
)
|
||||
@@ -635,7 +637,9 @@ class KnowledgeTable:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(
|
||||
update(Knowledge).filter_by(id=id).values(
|
||||
update(Knowledge)
|
||||
.filter_by(id=id)
|
||||
.values(
|
||||
data=data,
|
||||
updated_at=int(time.time()),
|
||||
)
|
||||
|
||||
@@ -250,11 +250,11 @@ class MessageTable:
|
||||
}
|
||||
return None
|
||||
|
||||
async def get_thread_replies_by_message_id(self, id: str, db: Optional[AsyncSession] = None) -> list[MessageReplyToResponse]:
|
||||
async def get_thread_replies_by_message_id(
|
||||
self, id: str, db: Optional[AsyncSession] = None
|
||||
) -> list[MessageReplyToResponse]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Message).filter_by(parent_id=id).order_by(Message.created_at.desc())
|
||||
)
|
||||
result = await db.execute(select(Message).filter_by(parent_id=id).order_by(Message.created_at.desc()))
|
||||
all_messages = result.scalars().all()
|
||||
|
||||
messages = []
|
||||
@@ -369,7 +369,9 @@ class MessageTable:
|
||||
)
|
||||
return messages
|
||||
|
||||
async def get_last_message_by_channel_id(self, channel_id: str, db: Optional[AsyncSession] = None) -> Optional[MessageModel]:
|
||||
async def get_last_message_by_channel_id(
|
||||
self, channel_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[MessageModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(Message).filter_by(channel_id=channel_id).order_by(Message.created_at.desc()).limit(1)
|
||||
@@ -453,9 +455,7 @@ class MessageTable:
|
||||
) -> Optional[MessageReactionModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
# check for existing reaction
|
||||
result = await db.execute(
|
||||
select(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name)
|
||||
)
|
||||
result = await db.execute(select(MessageReaction).filter_by(message_id=id, user_id=user_id, name=name))
|
||||
existing_reaction = result.scalars().first()
|
||||
if existing_reaction:
|
||||
return MessageReactionModel.model_validate(existing_reaction)
|
||||
|
||||
@@ -200,7 +200,8 @@ class ModelsTable:
|
||||
model_ids = [model.id for model in all_models]
|
||||
grants_map = await AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||
return [
|
||||
await self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
|
||||
await self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db)
|
||||
for model in all_models
|
||||
]
|
||||
|
||||
async def get_models(self, db: Optional[AsyncSession] = None) -> list[ModelUserResponse]:
|
||||
@@ -221,11 +222,13 @@ class ModelsTable:
|
||||
models.append(
|
||||
ModelUserResponse.model_validate(
|
||||
{
|
||||
**(await self._to_model_model(
|
||||
model,
|
||||
access_grants=grants_map.get(model.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_model_model(
|
||||
model,
|
||||
access_grants=grants_map.get(model.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
@@ -239,7 +242,8 @@ class ModelsTable:
|
||||
model_ids = [model.id for model in all_models]
|
||||
grants_map = await AccessGrants.get_grants_by_resources('model', model_ids, db=db)
|
||||
return [
|
||||
await self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db) for model in all_models
|
||||
await self._to_model_model(model, access_grants=grants_map.get(model.id, []), db=db)
|
||||
for model in all_models
|
||||
]
|
||||
|
||||
async def get_models_by_user_id(
|
||||
@@ -342,9 +346,7 @@ class ModelsTable:
|
||||
stmt = stmt.order_by(Model.created_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -362,11 +364,13 @@ class ModelsTable:
|
||||
for model, user in items:
|
||||
models.append(
|
||||
ModelUserResponse(
|
||||
**(await self._to_model_model(
|
||||
model,
|
||||
access_grants=grants_map.get(model.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_model_model(
|
||||
model,
|
||||
access_grants=grants_map.get(model.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
@@ -416,7 +420,9 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def update_model_by_id(self, id: str, model: ModelForm, db: Optional[AsyncSession] = None) -> Optional[ModelModel]:
|
||||
async def update_model_by_id(
|
||||
self, id: str, model: ModelForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ModelModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
# update only the fields that are present in the model
|
||||
@@ -473,7 +479,9 @@ class ModelsTable:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def sync_models(self, user_id: str, models: list[ModelModel], db: Optional[AsyncSession] = None) -> list[ModelModel]:
|
||||
async def sync_models(
|
||||
self, user_id: str, models: list[ModelModel], db: Optional[AsyncSession] = None
|
||||
) -> list[ModelModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
# Get existing models
|
||||
@@ -488,7 +496,9 @@ class ModelsTable:
|
||||
for model in models:
|
||||
if model.id in existing_ids:
|
||||
await db.execute(
|
||||
update(Model).filter_by(id=model.id).values(
|
||||
update(Model)
|
||||
.filter_by(id=model.id)
|
||||
.values(
|
||||
**model.model_dump(exclude={'access_grants'}),
|
||||
user_id=user_id,
|
||||
updated_at=int(time.time()),
|
||||
|
||||
@@ -113,7 +113,9 @@ class NoteTable:
|
||||
permission=permission,
|
||||
)
|
||||
|
||||
async def insert_new_note(self, user_id: str, form_data: NoteForm, db: Optional[AsyncSession] = None) -> Optional[NoteModel]:
|
||||
async def insert_new_note(
|
||||
self, user_id: str, form_data: NoteForm, db: Optional[AsyncSession] = None
|
||||
) -> Optional[NoteModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
note = NoteModel(
|
||||
**{
|
||||
@@ -216,9 +218,7 @@ class NoteTable:
|
||||
stmt = stmt.order_by(Note.updated_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -236,11 +236,13 @@ class NoteTable:
|
||||
for note, user in items:
|
||||
notes.append(
|
||||
NoteUserResponse(
|
||||
**(await self._to_note_model(
|
||||
note,
|
||||
access_grants=grants_map.get(note.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_note_model(
|
||||
note,
|
||||
access_grants=grants_map.get(note.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -151,7 +151,9 @@ class OAuthSessionTable:
|
||||
log.error(f'Error creating OAuth session: {e}')
|
||||
return None
|
||||
|
||||
async def get_session_by_id(self, session_id: str, db: Optional[AsyncSession] = None) -> Optional[OAuthSessionModel]:
|
||||
async def get_session_by_id(
|
||||
self, session_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[OAuthSessionModel]:
|
||||
"""Get OAuth session by ID"""
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -235,15 +237,17 @@ class OAuthSessionTable:
|
||||
results = []
|
||||
for session in sessions:
|
||||
try:
|
||||
results.append(OAuthSessionModel(
|
||||
id=session.id,
|
||||
user_id=session.user_id,
|
||||
provider=session.provider,
|
||||
token=self._decrypt_token(session.token),
|
||||
expires_at=session.expires_at,
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at,
|
||||
))
|
||||
results.append(
|
||||
OAuthSessionModel(
|
||||
id=session.id,
|
||||
user_id=session.user_id,
|
||||
provider=session.provider,
|
||||
token=self._decrypt_token(session.token),
|
||||
expires_at=session.expires_at,
|
||||
created_at=session.created_at,
|
||||
updated_at=session.updated_at,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
log.warning(
|
||||
f'Skipping OAuth session {session.id} due to decryption failure, deleting corrupted session: {type(e).__name__}: {e}'
|
||||
@@ -266,7 +270,9 @@ class OAuthSessionTable:
|
||||
current_time = int(time.time())
|
||||
|
||||
await db.execute(
|
||||
update(OAuthSession).filter_by(id=session_id).values(
|
||||
update(OAuthSession)
|
||||
.filter_by(id=session_id)
|
||||
.values(
|
||||
token=self._encrypt_token(token),
|
||||
expires_at=token.get('expires_at'),
|
||||
updated_at=current_time,
|
||||
|
||||
@@ -213,11 +213,13 @@ class PromptsTable:
|
||||
prompts.append(
|
||||
PromptUserResponse.model_validate(
|
||||
{
|
||||
**(await self._to_prompt_model(
|
||||
prompt,
|
||||
access_grants=grants_map.get(prompt.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_prompt_model(
|
||||
prompt,
|
||||
access_grants=grants_map.get(prompt.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
@@ -319,9 +321,7 @@ class PromptsTable:
|
||||
stmt = stmt.order_by(Prompt.updated_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -339,11 +339,13 @@ class PromptsTable:
|
||||
for prompt, user in items:
|
||||
prompts.append(
|
||||
PromptUserResponse(
|
||||
**(await self._to_prompt_model(
|
||||
prompt,
|
||||
access_grants=grants_map.get(prompt.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_prompt_model(
|
||||
prompt,
|
||||
access_grants=grants_map.get(prompt.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -184,11 +184,13 @@ class SkillsTable:
|
||||
skills.append(
|
||||
SkillUserModel.model_validate(
|
||||
{
|
||||
**(await self._to_skill_model(
|
||||
skill,
|
||||
access_grants=grants_map.get(skill.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_skill_model(
|
||||
skill,
|
||||
access_grants=grants_map.get(skill.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
@@ -262,9 +264,7 @@ class SkillsTable:
|
||||
stmt = stmt.order_by(Skill.updated_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
if skip:
|
||||
@@ -282,11 +282,13 @@ class SkillsTable:
|
||||
for skill, user in items:
|
||||
skills.append(
|
||||
SkillUserResponse(
|
||||
**(await self._to_skill_model(
|
||||
skill,
|
||||
access_grants=grants_map.get(skill.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_skill_model(
|
||||
skill,
|
||||
access_grants=grants_map.get(skill.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
user=(UserResponse(**UserModel.model_validate(user).model_dump()) if user else None),
|
||||
)
|
||||
)
|
||||
@@ -296,7 +298,9 @@ class SkillsTable:
|
||||
log.exception(f'Error searching skills: {e}')
|
||||
return SkillListResponse(items=[], total=0)
|
||||
|
||||
async def update_skill_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[SkillModel]:
|
||||
async def update_skill_by_id(
|
||||
self, id: str, updated: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[SkillModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
access_grants = updated.pop('access_grants', None)
|
||||
|
||||
@@ -71,7 +71,9 @@ class TagTable:
|
||||
log.exception(f'Error inserting a new tag: {e}')
|
||||
return None
|
||||
|
||||
async def get_tag_by_name_and_user_id(self, name: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[TagModel]:
|
||||
async def get_tag_by_name_and_user_id(
|
||||
self, name: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[TagModel]:
|
||||
try:
|
||||
id = name.replace(' ', '_').lower()
|
||||
async with get_async_db_context(db) as db:
|
||||
@@ -86,7 +88,9 @@ class TagTable:
|
||||
result = await db.execute(select(Tag).filter_by(user_id=user_id))
|
||||
return [TagModel.model_validate(tag) for tag in result.scalars().all()]
|
||||
|
||||
async def get_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None) -> list[TagModel]:
|
||||
async def get_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> list[TagModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(Tag).filter(Tag.id.in_(ids), Tag.user_id == user_id))
|
||||
return [TagModel.model_validate(tag) for tag in result.scalars().all()]
|
||||
@@ -103,7 +107,9 @@ class TagTable:
|
||||
log.error(f'delete_tag: {e}')
|
||||
return False
|
||||
|
||||
async def delete_tags_by_ids_and_user_id(self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None) -> bool:
|
||||
async def delete_tags_by_ids_and_user_id(
|
||||
self, ids: list[str], user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> bool:
|
||||
"""Delete all tags whose id is in *ids* for the given user, in one query."""
|
||||
if not ids:
|
||||
return True
|
||||
|
||||
@@ -172,11 +172,13 @@ class ToolsTable:
|
||||
tools.append(
|
||||
ToolUserModel.model_validate(
|
||||
{
|
||||
**(await self._to_tool_model(
|
||||
tool,
|
||||
access_grants=grants_map.get(tool.id, []),
|
||||
db=db,
|
||||
)).model_dump(),
|
||||
**(
|
||||
await self._to_tool_model(
|
||||
tool,
|
||||
access_grants=grants_map.get(tool.id, []),
|
||||
db=db,
|
||||
)
|
||||
).model_dump(),
|
||||
'user': user.model_dump() if user else None,
|
||||
}
|
||||
)
|
||||
@@ -218,18 +220,20 @@ class ToolsTable:
|
||||
log.exception(f'Error getting tool valves by id {id}')
|
||||
return None
|
||||
|
||||
async def update_tool_valves_by_id(self, id: str, valves: dict, db: Optional[AsyncSession] = None) -> Optional[ToolValves]:
|
||||
async def update_tool_valves_by_id(
|
||||
self, id: str, valves: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[ToolValves]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
await db.execute(
|
||||
update(Tool).filter_by(id=id).values(valves=valves, updated_at=int(time.time()))
|
||||
)
|
||||
await db.execute(update(Tool).filter_by(id=id).values(valves=valves, updated_at=int(time.time())))
|
||||
await db.commit()
|
||||
return await self.get_tool_by_id(id, db=db)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_valves_by_id_and_user_id(self, id: str, user_id: str, db: Optional[AsyncSession] = None) -> Optional[dict]:
|
||||
async def get_user_valves_by_id_and_user_id(
|
||||
self, id: str, user_id: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[dict]:
|
||||
try:
|
||||
user = await Users.get_user_by_id(user_id, db=db)
|
||||
user_settings = user.settings.model_dump() if user.settings else {}
|
||||
@@ -272,9 +276,7 @@ class ToolsTable:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
access_grants = updated.pop('access_grants', None)
|
||||
await db.execute(
|
||||
update(Tool).filter_by(id=id).values(**updated, updated_at=int(time.time()))
|
||||
)
|
||||
await db.execute(update(Tool).filter_by(id=id).values(**updated, updated_at=int(time.time())))
|
||||
await db.commit()
|
||||
if access_grants is not None:
|
||||
await AccessGrants.set_access_grants('tool', id, access_grants, db=db)
|
||||
|
||||
@@ -31,11 +31,13 @@ import datetime
|
||||
# daily bread of every session. Let none go hungry.
|
||||
####################
|
||||
|
||||
|
||||
class UserSettings(BaseModel):
|
||||
ui: Optional[dict] = {}
|
||||
model_config = ConfigDict(extra='allow')
|
||||
pass
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
|
||||
@@ -69,6 +71,7 @@ class User(Base):
|
||||
updated_at = Column(BigInteger)
|
||||
created_at = Column(BigInteger)
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
id: str
|
||||
|
||||
@@ -109,11 +112,13 @@ class UserModel(BaseModel):
|
||||
self.profile_image_url = f'/api/v1/users/{self.id}/profile/image'
|
||||
return self
|
||||
|
||||
|
||||
class UserStatusModel(UserModel):
|
||||
is_active: bool = False
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class ApiKey(Base):
|
||||
__tablename__ = 'api_key'
|
||||
|
||||
@@ -126,6 +131,7 @@ class ApiKey(Base):
|
||||
created_at = Column(BigInteger, nullable=False)
|
||||
updated_at = Column(BigInteger, nullable=False)
|
||||
|
||||
|
||||
class ApiKeyModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
@@ -138,10 +144,12 @@ class ApiKeyModel(BaseModel):
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class UpdateProfileForm(BaseModel):
|
||||
profile_image_url: str
|
||||
name: str
|
||||
@@ -154,25 +162,31 @@ class UpdateProfileForm(BaseModel):
|
||||
def check_profile_image_url(cls, v: str) -> str:
|
||||
return validate_profile_image_url(v)
|
||||
|
||||
|
||||
class UserGroupIdsModel(UserModel):
|
||||
group_ids: list[str] = []
|
||||
|
||||
|
||||
class UserModelResponse(UserModel):
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class UserListResponse(BaseModel):
|
||||
users: list[UserModelResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class UserGroupIdsListResponse(BaseModel):
|
||||
users: list[UserGroupIdsModel]
|
||||
total: int
|
||||
|
||||
|
||||
class UserStatus(BaseModel):
|
||||
status_emoji: Optional[str] = None
|
||||
status_message: Optional[str] = None
|
||||
status_expires_at: Optional[int] = None
|
||||
|
||||
|
||||
class UserInfoResponse(UserStatus):
|
||||
id: str
|
||||
name: str
|
||||
@@ -182,39 +196,48 @@ class UserInfoResponse(UserStatus):
|
||||
groups: Optional[list] = []
|
||||
is_active: bool = False
|
||||
|
||||
|
||||
class UserIdNameResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class UserIdNameStatusResponse(UserStatus):
|
||||
id: str
|
||||
name: str
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class UserInfoListResponse(BaseModel):
|
||||
users: list[UserInfoResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class UserIdNameListResponse(BaseModel):
|
||||
users: list[UserIdNameResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class UserNameResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
role: str
|
||||
|
||||
|
||||
class UserResponse(UserNameResponse):
|
||||
email: str
|
||||
|
||||
|
||||
class UserProfileImageResponse(UserNameResponse):
|
||||
email: str
|
||||
profile_image_url: str
|
||||
|
||||
|
||||
class UserRoleUpdateForm(BaseModel):
|
||||
id: str
|
||||
role: str
|
||||
|
||||
|
||||
class UserUpdateForm(BaseModel):
|
||||
role: str
|
||||
name: str
|
||||
@@ -227,6 +250,7 @@ class UserUpdateForm(BaseModel):
|
||||
def check_profile_image_url(cls, v: str) -> str:
|
||||
return validate_profile_image_url(v)
|
||||
|
||||
|
||||
class UsersTable:
|
||||
async def insert_new_user(
|
||||
self,
|
||||
@@ -292,7 +316,9 @@ class UsersTable:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_user_by_oauth_sub(self, provider: str, sub: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def get_user_by_oauth_sub(
|
||||
self, provider: str, sub: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
dialect_name = db.bind.dialect.name
|
||||
@@ -457,9 +483,7 @@ class UsersTable:
|
||||
stmt = stmt.order_by(User.created_at.desc())
|
||||
|
||||
# Count BEFORE pagination
|
||||
count_result = await db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
)
|
||||
count_result = await db.execute(select(func.count()).select_from(stmt.subquery()))
|
||||
total = count_result.scalar()
|
||||
|
||||
# correct pagination logic
|
||||
@@ -478,20 +502,18 @@ class UsersTable:
|
||||
async def get_users_by_group_id(self, group_id: str, db: Optional[AsyncSession] = None) -> list[UserModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
from open_webui.models.groups import GroupMember
|
||||
|
||||
result = await db.execute(
|
||||
select(User)
|
||||
|
||||
.join(GroupMember, User.id == GroupMember.user_id)
|
||||
.filter(GroupMember.group_id == group_id)
|
||||
select(User).join(GroupMember, User.id == GroupMember.user_id).filter(GroupMember.group_id == group_id)
|
||||
)
|
||||
users = result.scalars().all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
async def get_users_by_user_ids(self, user_ids: list[str], db: Optional[AsyncSession] = None) -> list[UserStatusModel]:
|
||||
async def get_users_by_user_ids(
|
||||
self, user_ids: list[str], db: Optional[AsyncSession] = None
|
||||
) -> list[UserStatusModel]:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(
|
||||
select(User).filter(User.id.in_(user_ids))
|
||||
)
|
||||
result = await db.execute(select(User).filter(User.id.in_(user_ids)))
|
||||
users = result.scalars().all()
|
||||
return [UserModel.model_validate(user) for user in users]
|
||||
|
||||
@@ -536,7 +558,9 @@ class UsersTable:
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
async def update_user_role_by_id(self, id: str, role: str, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def update_user_role_by_id(
|
||||
self, id: str, role: str, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -674,7 +698,9 @@ class UsersTable:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
async def update_user_settings_by_id(self, id: str, updated: dict, db: Optional[AsyncSession] = None) -> Optional[UserModel]:
|
||||
async def update_user_settings_by_id(
|
||||
self, id: str, updated: dict, db: Optional[AsyncSession] = None
|
||||
) -> Optional[UserModel]:
|
||||
try:
|
||||
async with get_async_db_context(db) as db:
|
||||
result = await db.execute(select(User).filter_by(id=id))
|
||||
@@ -802,4 +828,5 @@ class UsersTable:
|
||||
return user.last_active_at >= three_minutes_ago
|
||||
return False
|
||||
|
||||
|
||||
Users = UsersTable()
|
||||
|
||||
@@ -62,7 +62,9 @@ async def get_model_analytics(
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Get message counts per model."""
|
||||
counts = await ChatMessages.get_message_count_by_model(start_date=start_date, end_date=end_date, group_id=group_id, db=db)
|
||||
counts = await ChatMessages.get_message_count_by_model(
|
||||
start_date=start_date, end_date=end_date, group_id=group_id, db=db
|
||||
)
|
||||
models = [
|
||||
ModelAnalyticsEntry(model_id=model_id, count=count)
|
||||
for model_id, count in sorted(counts.items(), key=lambda x: -x[1])
|
||||
@@ -80,7 +82,9 @@ async def get_user_analytics(
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Get message counts and token usage per user with user info."""
|
||||
counts = await ChatMessages.get_message_count_by_user(start_date=start_date, end_date=end_date, group_id=group_id, db=db)
|
||||
counts = await ChatMessages.get_message_count_by_user(
|
||||
start_date=start_date, end_date=end_date, group_id=group_id, db=db
|
||||
)
|
||||
token_usage = await ChatMessages.get_token_usage_by_user(
|
||||
start_date=start_date, end_date=end_date, group_id=group_id, db=db
|
||||
)
|
||||
@@ -227,7 +231,9 @@ async def get_token_usage(
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
"""Get token usage aggregated by model."""
|
||||
usage = await ChatMessages.get_token_usage_by_model(start_date=start_date, end_date=end_date, group_id=group_id, db=db)
|
||||
usage = await ChatMessages.get_token_usage_by_model(
|
||||
start_date=start_date, end_date=end_date, group_id=group_id, db=db
|
||||
)
|
||||
|
||||
models = [
|
||||
TokenUsageEntry(model_id=model_id, **data)
|
||||
|
||||
@@ -330,7 +330,9 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
if user.role != 'admin' and not await has_permission(user.id, 'chat.tts', request.app.state.config.USER_PERMISSIONS):
|
||||
if user.role != 'admin' and not await has_permission(
|
||||
user.id, 'chat.tts', request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -630,6 +632,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
detail=detail,
|
||||
)
|
||||
|
||||
|
||||
def transcription_handler(request, file_path, metadata, user=None):
|
||||
filename = os.path.basename(file_path)
|
||||
file_dir = os.path.dirname(file_path)
|
||||
@@ -1214,7 +1217,9 @@ async def transcription(
|
||||
language: Optional[str] = Form(None),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if user.role != 'admin' and not await has_permission(user.id, 'chat.stt', request.app.state.config.USER_PERMISSIONS):
|
||||
if user.role != 'admin' and not await has_permission(
|
||||
user.id, 'chat.stt', request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
|
||||
@@ -97,7 +97,9 @@ log = logging.getLogger(__name__)
|
||||
signin_rate_limiter = RateLimiter(redis_client=get_redis_client(), limit=5 * 3, window=60 * 3)
|
||||
|
||||
|
||||
async def create_session_response(request: Request, user, db, response: Response = None, set_cookie: bool = False) -> dict:
|
||||
async def create_session_response(
|
||||
request: Request, user, db, response: Response = None, set_cookie: bool = False
|
||||
) -> dict:
|
||||
"""
|
||||
Create JWT token and build session response for a user.
|
||||
Shared helper for signin, signup, ldap_auth, add_user, and token_exchange endpoints.
|
||||
@@ -918,7 +920,9 @@ async def add_user(
|
||||
|
||||
|
||||
@router.get('/admin/details')
|
||||
async def get_admin_details(request: Request, user=Depends(get_current_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_admin_details(
|
||||
request: Request, user=Depends(get_current_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
if request.app.state.config.SHOW_ADMIN_DETAILS:
|
||||
admin_email = request.app.state.config.ADMIN_EMAIL
|
||||
admin_name = None
|
||||
@@ -1182,7 +1186,9 @@ async def update_ldap_config(request: Request, form_data: LdapConfigForm, user=D
|
||||
|
||||
# create api key
|
||||
@router.post('/api_key', response_model=ApiKey)
|
||||
async def generate_api_key(request: Request, user=Depends(get_current_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def generate_api_key(
|
||||
request: Request, user=Depends(get_current_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
if not request.app.state.config.ENABLE_API_KEYS or (
|
||||
user.role != 'admin'
|
||||
and not await has_permission(user.id, 'features.api_keys', request.app.state.config.USER_PERMISSIONS)
|
||||
|
||||
@@ -94,7 +94,9 @@ async def channel_has_access(
|
||||
return False
|
||||
|
||||
|
||||
async def get_channel_users_with_access(channel: ChannelModel, permission: str = 'read', db: Optional[AsyncSession] = None):
|
||||
async def get_channel_users_with_access(
|
||||
channel: ChannelModel, permission: str = 'read', db: Optional[AsyncSession] = None
|
||||
):
|
||||
return await AccessGrants.get_users_with_access(
|
||||
resource_type='channel',
|
||||
resource_id=channel.id,
|
||||
@@ -893,11 +895,13 @@ async def model_response_handler(request, channel, message, user, db=None):
|
||||
if model:
|
||||
try:
|
||||
# reverse to get in chronological order
|
||||
thread_messages = (await Messages.get_messages_by_parent_id(
|
||||
channel.id,
|
||||
message.parent_id if message.parent_id else message.id,
|
||||
db=db,
|
||||
))[::-1]
|
||||
thread_messages = (
|
||||
await Messages.get_messages_by_parent_id(
|
||||
channel.id,
|
||||
message.parent_id if message.parent_id else message.id,
|
||||
db=db,
|
||||
)
|
||||
)[::-1]
|
||||
|
||||
response_message, channel = await new_message_handler(
|
||||
request,
|
||||
@@ -1120,7 +1124,9 @@ async def post_new_message(
|
||||
try:
|
||||
if files := message.data.get('files', []):
|
||||
for file in files:
|
||||
await Channels.set_file_message_id_in_channel_by_id(channel.id, file.get('id', ''), message.id, db=db)
|
||||
await Channels.set_file_message_id_in_channel_by_id(
|
||||
channel.id, file.get('id', ''), message.id, db=db
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug(e)
|
||||
|
||||
|
||||
@@ -493,7 +493,9 @@ async def delete_all_user_chats(
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
if user.role == 'user' and not await has_permission(user.id, 'chat.delete', request.app.state.config.USER_PERMISSIONS):
|
||||
if user.role == 'user' and not await has_permission(
|
||||
user.id, 'chat.delete', request.app.state.config.USER_PERMISSIONS
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
@@ -538,7 +540,9 @@ async def get_user_chat_list_by_user_id(
|
||||
if direction:
|
||||
filter['direction'] = direction
|
||||
|
||||
return await Chats.get_chat_list_by_user_id(user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db)
|
||||
return await Chats.get_chat_list_by_user_id(
|
||||
user_id, include_archived=True, filter=filter, skip=skip, limit=limit, db=db
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
@@ -620,7 +624,9 @@ async def search_user_chats(
|
||||
|
||||
|
||||
@router.get('/folder/{folder_id}', response_model=list[ChatResponse])
|
||||
async def get_chats_by_folder_id(folder_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_chats_by_folder_id(
|
||||
folder_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
folder_ids = [folder_id]
|
||||
children_folders = await Folders.get_children_folders_by_id_and_user_id(folder_id, user.id, db=db)
|
||||
if children_folders:
|
||||
@@ -815,7 +821,9 @@ async def get_shared_session_user_chat_list(
|
||||
|
||||
|
||||
@router.get('/share/{share_id}', response_model=Optional[ChatResponse])
|
||||
async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_shared_chat_by_id(
|
||||
share_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
if user.role == 'pending':
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND)
|
||||
|
||||
@@ -851,7 +859,9 @@ async def get_user_chat_list_by_tag_name(
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
chats = await Chats.get_chat_list_by_user_id_and_tag_name(user.id, form_data.name, form_data.skip, form_data.limit, db=db)
|
||||
chats = await Chats.get_chat_list_by_user_id_and_tag_name(
|
||||
user.id, form_data.name, form_data.skip, form_data.limit, db=db
|
||||
)
|
||||
if len(chats) == 0:
|
||||
await Tags.delete_tag_by_name_and_user_id(form_data.name, user.id, db=db)
|
||||
|
||||
@@ -1056,7 +1066,9 @@ async def delete_chat_by_id(
|
||||
|
||||
|
||||
@router.get('/{id}/pinned', response_model=Optional[bool])
|
||||
async def get_pinned_status_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_pinned_status_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
return chat.pinned
|
||||
@@ -1137,7 +1149,9 @@ async def clone_chat_by_id(
|
||||
|
||||
|
||||
@router.post('/{id}/clone/shared', response_model=Optional[ChatResponse])
|
||||
async def clone_shared_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def clone_shared_chat_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
if user.role == 'admin':
|
||||
chat = await Chats.get_chat_by_id(id, db=db)
|
||||
else:
|
||||
@@ -1250,7 +1264,9 @@ async def share_chat_by_id(
|
||||
|
||||
|
||||
@router.delete('/{id}/share', response_model=Optional[bool])
|
||||
async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def delete_shared_chat_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
if not chat.share_id:
|
||||
@@ -1371,7 +1387,9 @@ async def delete_tag_by_id_and_tag_name(
|
||||
|
||||
|
||||
@router.delete('/{id}/tags/all', response_model=Optional[bool])
|
||||
async def delete_all_tags_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def delete_all_tags_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
chat = await Chats.get_chat_by_id_and_user_id(id, user.id, db=db)
|
||||
if chat:
|
||||
old_tags = chat.meta.get('tags', [])
|
||||
|
||||
@@ -415,7 +415,9 @@ async def update_feedback_by_id(
|
||||
|
||||
|
||||
@router.delete('/feedback/{id}')
|
||||
async def delete_feedback_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def delete_feedback_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
if user.role == 'admin':
|
||||
success = await Feedbacks.delete_feedback_by_id(id=id, db=db)
|
||||
else:
|
||||
|
||||
@@ -495,7 +495,9 @@ async def get_file_process_status(
|
||||
|
||||
|
||||
@router.get('/{id}/data/content')
|
||||
async def get_file_data_content_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_file_data_content_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
file = await Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
@@ -646,7 +648,9 @@ async def get_file_content_by_id(
|
||||
|
||||
|
||||
@router.get('/{id}/content/html')
|
||||
async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_html_file_content_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
file = await Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
@@ -693,7 +697,9 @@ async def get_html_file_content_by_id(id: str, user=Depends(get_verified_user),
|
||||
|
||||
|
||||
@router.get('/{id}/content/{file_name}')
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_file_content_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
file = await Files.get_file_by_id(id, db=db)
|
||||
|
||||
if not file:
|
||||
|
||||
@@ -89,7 +89,9 @@ async def get_folders(
|
||||
valid_files.append(file)
|
||||
|
||||
folder.data['files'] = valid_files
|
||||
await Folders.update_folder_by_id_and_user_id(folder.id, user.id, FolderUpdateForm(data=folder.data), db=db)
|
||||
await Folders.update_folder_by_id_and_user_id(
|
||||
folder.id, user.id, FolderUpdateForm(data=folder.data), db=db
|
||||
)
|
||||
|
||||
folder_list.append(FolderNameIdResponse(**folder.model_dump()))
|
||||
|
||||
@@ -107,7 +109,9 @@ async def create_folder(
|
||||
user=Depends(get_verified_user),
|
||||
db: AsyncSession = Depends(get_async_session),
|
||||
):
|
||||
folder = await Folders.get_folder_by_parent_id_and_user_id_and_name(form_data.parent_id, user.id, form_data.name, db=db)
|
||||
folder = await Folders.get_folder_by_parent_id_and_user_id_and_name(
|
||||
form_data.parent_id, user.id, form_data.name, db=db
|
||||
)
|
||||
|
||||
if folder:
|
||||
raise HTTPException(
|
||||
@@ -250,7 +254,9 @@ async def update_folder_is_expanded_by_id(
|
||||
folder = await Folders.get_folder_by_id_and_user_id(id, user.id, db=db)
|
||||
if folder:
|
||||
try:
|
||||
folder = await Folders.update_folder_is_expanded_by_id_and_user_id(id, user.id, form_data.is_expanded, db=db)
|
||||
folder = await Folders.update_folder_is_expanded_by_id_and_user_id(
|
||||
id, user.id, form_data.is_expanded, db=db
|
||||
)
|
||||
return folder
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
@@ -373,7 +373,9 @@ async def delete_function_by_id(
|
||||
|
||||
|
||||
@router.get('/id/{id}/valves', response_model=Optional[dict])
|
||||
async def get_function_valves_by_id(id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_function_valves_by_id(
|
||||
id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
function = await Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
try:
|
||||
@@ -473,7 +475,9 @@ async def update_function_valves_by_id(
|
||||
|
||||
|
||||
@router.get('/id/{id}/valves/user', response_model=Optional[dict])
|
||||
async def get_function_user_valves_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_function_user_valves_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
function = await Functions.get_function_by_id(id, db=db)
|
||||
if function:
|
||||
try:
|
||||
|
||||
@@ -858,7 +858,9 @@ async def remove_file_from_knowledge_by_id(
|
||||
|
||||
|
||||
@router.delete('/{id}/delete', response_model=bool)
|
||||
async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def delete_knowledge_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
knowledge = await Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
@@ -931,7 +933,9 @@ async def delete_knowledge_by_id(id: str, user=Depends(get_verified_user), db: A
|
||||
|
||||
|
||||
@router.post('/{id}/reset', response_model=Optional[KnowledgeResponse])
|
||||
async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def reset_knowledge_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
knowledge = await Knowledges.get_knowledge_by_id(id=id, db=db)
|
||||
if not knowledge:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -138,7 +138,10 @@ async def send_request(
|
||||
headers[FORWARD_SESSION_INFO_HEADER_CHAT_ID] = metadata.get('chat_id')
|
||||
|
||||
r = await session.request(
|
||||
method, url, data=payload, headers=headers,
|
||||
method,
|
||||
url,
|
||||
data=payload,
|
||||
headers=headers,
|
||||
ssl=AIOHTTP_CLIENT_SESSION_SSL,
|
||||
timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT),
|
||||
)
|
||||
@@ -782,7 +785,8 @@ async def delete_model(
|
||||
key = get_api_key(url_idx, url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||
|
||||
await send_request(
|
||||
f'{url}/api/delete', 'DELETE',
|
||||
f'{url}/api/delete',
|
||||
'DELETE',
|
||||
payload=json.dumps(form_data),
|
||||
key=key,
|
||||
user=user,
|
||||
@@ -1650,9 +1654,7 @@ async def upload_model(
|
||||
url = f'{ollama_url}/api/blobs/sha256:{file_hash}'
|
||||
upload_timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
async with aiohttp.ClientSession(timeout=upload_timeout, trust_env=True) as upload_session:
|
||||
async with upload_session.post(
|
||||
url, data=blob_data, ssl=AIOHTTP_CLIENT_SESSION_SSL
|
||||
) as response:
|
||||
async with upload_session.post(url, data=blob_data, ssl=AIOHTTP_CLIENT_SESSION_SSL) as response:
|
||||
if not response.ok:
|
||||
raise Exception('Ollama: Could not create blob, Please try again.')
|
||||
|
||||
|
||||
@@ -1211,7 +1211,7 @@ async def generate_chat_completion(
|
||||
except json.JSONDecodeError:
|
||||
return JSONResponse(
|
||||
status_code=r.status,
|
||||
content={"error": {"message": error_body, "code": r.status}},
|
||||
content={'error': {'message': error_body, 'code': r.status}},
|
||||
)
|
||||
|
||||
streaming = True
|
||||
|
||||
@@ -198,7 +198,9 @@ async def create_new_prompt(
|
||||
|
||||
|
||||
@router.get('/command/{command}', response_model=Optional[PromptAccessResponse])
|
||||
async def get_prompt_by_command(command: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_prompt_by_command(
|
||||
command: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
prompt = await Prompts.get_prompt_by_command(command, db=db)
|
||||
|
||||
if prompt:
|
||||
@@ -240,7 +242,9 @@ async def get_prompt_by_command(command: str, user=Depends(get_verified_user), d
|
||||
|
||||
|
||||
@router.get('/id/{prompt_id}', response_model=Optional[PromptAccessResponse])
|
||||
async def get_prompt_by_id(prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_prompt_by_id(
|
||||
prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
prompt = await Prompts.get_prompt_by_id(prompt_id, db=db)
|
||||
|
||||
if prompt:
|
||||
@@ -388,7 +392,9 @@ async def update_prompt_metadata(
|
||||
detail=f"Command '/{form_data.command}' is already in use",
|
||||
)
|
||||
|
||||
updated_prompt = await Prompts.update_prompt_metadata(prompt.id, form_data.name, form_data.command, form_data.tags, db=db)
|
||||
updated_prompt = await Prompts.update_prompt_metadata(
|
||||
prompt.id, form_data.name, form_data.command, form_data.tags, db=db
|
||||
)
|
||||
if updated_prompt:
|
||||
return updated_prompt
|
||||
else:
|
||||
@@ -497,7 +503,9 @@ async def update_prompt_access_by_id(
|
||||
|
||||
|
||||
@router.post('/id/{prompt_id}/toggle', response_model=Optional[PromptModel])
|
||||
async def toggle_prompt_active(prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def toggle_prompt_active(
|
||||
prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
prompt = await Prompts.get_prompt_by_id(prompt_id, db=db)
|
||||
|
||||
if not prompt:
|
||||
@@ -537,7 +545,9 @@ async def toggle_prompt_active(prompt_id: str, user=Depends(get_verified_user),
|
||||
|
||||
|
||||
@router.delete('/id/{prompt_id}/delete', response_model=bool)
|
||||
async def delete_prompt_by_id(prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def delete_prompt_by_id(
|
||||
prompt_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
prompt = await Prompts.get_prompt_by_id(prompt_id, db=db)
|
||||
|
||||
if not prompt:
|
||||
|
||||
@@ -622,7 +622,9 @@ async def delete_tools_by_id(
|
||||
|
||||
|
||||
@router.get('/id/{id}/valves', response_model=Optional[dict])
|
||||
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_tools_valves_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
tools = await Tools.get_tool_by_id(id, db=db)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
@@ -775,7 +777,9 @@ async def update_tools_valves_by_id(
|
||||
|
||||
|
||||
@router.get('/id/{id}/valves/user', response_model=Optional[dict])
|
||||
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_tools_user_valves_by_id(
|
||||
id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
tools = await Tools.get_tool_by_id(id, db=db)
|
||||
if not tools:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -273,7 +273,9 @@ async def update_default_user_permissions(request: Request, form_data: UserPermi
|
||||
|
||||
|
||||
@router.get('/user/settings', response_model=Optional[UserSettings])
|
||||
async def get_user_settings_by_session_user(user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_user_settings_by_session_user(
|
||||
user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
user = await Users.get_user_by_id(user.id, db=db)
|
||||
if user:
|
||||
return user.settings
|
||||
@@ -468,7 +470,9 @@ async def get_user_by_id(user_id: str, user=Depends(get_admin_user), db: AsyncSe
|
||||
|
||||
|
||||
@router.get('/{user_id}/info', response_model=UserInfoResponse)
|
||||
async def get_user_info_by_id(user_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_user_info_by_id(
|
||||
user_id: str, user=Depends(get_verified_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
user = await Users.get_user_by_id(user_id, db=db)
|
||||
if user:
|
||||
groups = await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
@@ -487,7 +491,9 @@ async def get_user_info_by_id(user_id: str, user=Depends(get_verified_user), db:
|
||||
|
||||
|
||||
@router.get('/{user_id}/oauth/sessions')
|
||||
async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_user_oauth_sessions_by_id(
|
||||
user_id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
sessions = await OAuthSessions.get_sessions_by_user_id(user_id, db=db)
|
||||
if sessions and len(sessions) > 0:
|
||||
return sessions
|
||||
@@ -685,5 +691,7 @@ async def delete_user_by_id(user_id: str, user=Depends(get_admin_user), db: Asyn
|
||||
|
||||
|
||||
@router.get('/{user_id}/groups')
|
||||
async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)):
|
||||
async def get_user_groups_by_id(
|
||||
user_id: str, user=Depends(get_admin_user), db: AsyncSession = Depends(get_async_session)
|
||||
):
|
||||
return await Groups.get_groups_by_member_id(user_id, db=db)
|
||||
|
||||
@@ -20,7 +20,6 @@ from pytz import UTC
|
||||
from typing import Optional, Union, List, Dict
|
||||
|
||||
|
||||
|
||||
from open_webui.utils.access_control import has_permission
|
||||
from open_webui.models.users import Users
|
||||
from open_webui.models.auths import Auths
|
||||
@@ -238,9 +237,7 @@ async def is_valid_token(request, decoded) -> bool:
|
||||
# Per-user revocation (OIDC back-channel logout)
|
||||
user_id = decoded.get('id')
|
||||
if user_id:
|
||||
revoked_at = await request.app.state.redis.get(
|
||||
f'{REDIS_KEY_PREFIX}:auth:user:{user_id}:revoked_at'
|
||||
)
|
||||
revoked_at = await request.app.state.redis.get(f'{REDIS_KEY_PREFIX}:auth:user:{user_id}:revoked_at')
|
||||
if revoked_at:
|
||||
try:
|
||||
revoked_at_ts = int(revoked_at)
|
||||
@@ -385,6 +382,7 @@ async def get_current_user(
|
||||
# Refresh the user's last active timestamp
|
||||
# Fire-and-forget via asyncio.create_task to avoid blocking
|
||||
import asyncio
|
||||
|
||||
asyncio.create_task(Users.update_last_active_by_id(user.id))
|
||||
return user
|
||||
else:
|
||||
@@ -432,15 +430,10 @@ async def get_current_user_by_api_key(request, api_key: str):
|
||||
# (Authorization header, cookie, x-api-key header, etc.).
|
||||
if request.app.state.config.ENABLE_API_KEYS_ENDPOINT_RESTRICTIONS:
|
||||
allowed_paths = [
|
||||
path.strip()
|
||||
for path in str(request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS).split(',')
|
||||
if path.strip()
|
||||
path.strip() for path in str(request.app.state.config.API_KEYS_ALLOWED_ENDPOINTS).split(',') if path.strip()
|
||||
]
|
||||
request_path = request.url.path
|
||||
is_allowed = any(
|
||||
request_path == allowed or request_path.startswith(allowed + '/')
|
||||
for allowed in allowed_paths
|
||||
)
|
||||
is_allowed = any(request_path == allowed or request_path.startswith(allowed + '/') for allowed in allowed_paths)
|
||||
if not is_allowed:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -88,7 +88,7 @@ async def convert_markdown_base64_images(request, content: str, metadata, user):
|
||||
last_end = 0
|
||||
|
||||
for match in MARKDOWN_IMAGE_URL_PATTERN.finditer(content):
|
||||
result_parts.append(content[last_end:match.start()])
|
||||
result_parts.append(content[last_end : match.start()])
|
||||
base64_string = match.group(2)
|
||||
if len(base64_string) > MIN_REPLACEMENT_URL_LENGTH:
|
||||
url = await get_image_url_from_base64(request, base64_string, metadata, user)
|
||||
|
||||
@@ -1785,7 +1785,10 @@ class OAuthManager:
|
||||
log.warning(f'Back-channel logout: no configured provider matches issuer {token_issuer}')
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={'error': 'invalid_request', 'error_description': 'No configured provider matches token issuer'},
|
||||
content={
|
||||
'error': 'invalid_request',
|
||||
'error_description': 'No configured provider matches token issuer',
|
||||
},
|
||||
)
|
||||
|
||||
# 4. Validate the logout_token signature and claims
|
||||
@@ -1886,5 +1889,7 @@ class OAuthManager:
|
||||
f'(email={user.email}, provider={matched_provider}, sessions_deleted={len(sessions)})'
|
||||
)
|
||||
|
||||
log.info(f'Back-channel logout: completed for {len(users_to_logout)} user(s), {revoked_count} revocation(s) set')
|
||||
log.info(
|
||||
f'Back-channel logout: completed for {len(users_to_logout)} user(s), {revoked_count} revocation(s) set'
|
||||
)
|
||||
return JSONResponse(status_code=200, content={})
|
||||
|
||||
@@ -194,20 +194,12 @@ def get_redis_connection(
|
||||
connection = None
|
||||
|
||||
connect_timeout_kwargs = (
|
||||
{'socket_connect_timeout': REDIS_SOCKET_CONNECT_TIMEOUT}
|
||||
if REDIS_SOCKET_CONNECT_TIMEOUT is not None
|
||||
else {}
|
||||
{'socket_connect_timeout': REDIS_SOCKET_CONNECT_TIMEOUT} if REDIS_SOCKET_CONNECT_TIMEOUT is not None else {}
|
||||
)
|
||||
|
||||
keepalive_kwargs = (
|
||||
{'socket_keepalive': True} if REDIS_SOCKET_KEEPALIVE else {}
|
||||
)
|
||||
keepalive_kwargs = {'socket_keepalive': True} if REDIS_SOCKET_KEEPALIVE else {}
|
||||
|
||||
health_check_kwargs = (
|
||||
{'health_check_interval': REDIS_HEALTH_CHECK_INTERVAL}
|
||||
if REDIS_HEALTH_CHECK_INTERVAL
|
||||
else {}
|
||||
)
|
||||
health_check_kwargs = {'health_check_interval': REDIS_HEALTH_CHECK_INTERVAL} if REDIS_HEALTH_CHECK_INTERVAL else {}
|
||||
|
||||
if async_mode:
|
||||
import redis.asyncio as redis
|
||||
|
||||
@@ -64,8 +64,7 @@ async def get_session() -> aiohttp.ClientSession:
|
||||
trust_env=True,
|
||||
)
|
||||
log.info(
|
||||
'Created shared aiohttp session pool '
|
||||
'(limit=%s, per_host=%s, dns_ttl=%d)',
|
||||
'Created shared aiohttp session pool (limit=%s, per_host=%s, dns_ttl=%d)',
|
||||
AIOHTTP_POOL_CONNECTIONS or 'unlimited',
|
||||
AIOHTTP_POOL_CONNECTIONS_PER_HOST or 'unlimited',
|
||||
AIOHTTP_POOL_DNS_TTL,
|
||||
|
||||
@@ -101,7 +101,9 @@ log = logging.getLogger(__name__)
|
||||
|
||||
# Let no function be called without need, and let what
|
||||
# it yields justify the cost of running it.
|
||||
async def get_async_tool_function_and_apply_extra_params(function: Callable, extra_params: dict) -> Callable[..., Awaitable]:
|
||||
async def get_async_tool_function_and_apply_extra_params(
|
||||
function: Callable, extra_params: dict
|
||||
) -> Callable[..., Awaitable]:
|
||||
sig = inspect.signature(function)
|
||||
extra_params = {k: v for k, v in extra_params.items() if k in sig.parameters}
|
||||
partial_func = partial(function, **extra_params)
|
||||
@@ -544,7 +546,9 @@ async def get_builtin_tools(
|
||||
|
||||
# Automation tools - create and manage scheduled automations from chat
|
||||
if is_builtin_tool_enabled('automations') and await has_user_permission('automations'):
|
||||
builtin_functions.extend([create_automation, update_automation, list_automations, toggle_automation, delete_automation])
|
||||
builtin_functions.extend(
|
||||
[create_automation, update_automation, list_automations, toggle_automation, delete_automation]
|
||||
)
|
||||
|
||||
for func in builtin_functions:
|
||||
callable = await get_async_tool_function_and_apply_extra_params(
|
||||
|
||||
@@ -13,19 +13,19 @@ _USER_PROFILE_IMAGE_RE = re.compile(r'^/api/v1/users/[^/?#]+/profile/image$')
|
||||
# regex across megabytes of data on every Pydantic instantiation for zero
|
||||
# security benefit (corrupt base64 simply renders a broken image, same as
|
||||
# a 404 URL). SVG is intentionally excluded: it can carry embedded scripts.
|
||||
_SAFE_DATA_URI_RE = re.compile(
|
||||
r'^data:image/(png|jpeg|gif|webp);base64,', re.IGNORECASE
|
||||
)
|
||||
_SAFE_DATA_URI_RE = re.compile(r'^data:image/(png|jpeg|gif|webp);base64,', re.IGNORECASE)
|
||||
|
||||
# Exact relative paths accepted as profile images. These are the only
|
||||
# static-asset paths OWUI itself assigns; no prefix/wildcard matching is
|
||||
# used so that arbitrary relative paths cannot trigger authenticated GETs
|
||||
# against internal endpoints when rendered as ``<img>`` sources.
|
||||
_SAFE_STATIC_PATHS = frozenset({
|
||||
'/user.png',
|
||||
'/favicon.png',
|
||||
'/static/favicon.png',
|
||||
})
|
||||
_SAFE_STATIC_PATHS = frozenset(
|
||||
{
|
||||
'/user.png',
|
||||
'/favicon.png',
|
||||
'/static/favicon.png',
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_profile_image_url(url: str) -> str:
|
||||
@@ -67,9 +67,7 @@ def validate_profile_image_url(url: str) -> str:
|
||||
# for a URL like http://:80/path with no actual host).
|
||||
if parsed.scheme in ('http', 'https'):
|
||||
if not parsed.hostname:
|
||||
raise ValueError(
|
||||
'Invalid profile image URL: HTTP(S) URLs must include a host.'
|
||||
)
|
||||
raise ValueError('Invalid profile image URL: HTTP(S) URLs must include a host.')
|
||||
return url
|
||||
|
||||
# Base64-encoded raster images uploaded via the frontend.
|
||||
|
||||
@@ -20,7 +20,6 @@
|
||||
type AutomationRunModel
|
||||
} from '$lib/apis/automations';
|
||||
|
||||
|
||||
import Spinner from '$lib/components/common/Spinner.svelte';
|
||||
import Tooltip from '$lib/components/common/Tooltip.svelte';
|
||||
import DeleteConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
|
||||
@@ -42,7 +41,6 @@
|
||||
let model_id = '';
|
||||
let is_active = true;
|
||||
|
||||
|
||||
let loading = false;
|
||||
let saving = false;
|
||||
let showDeleteConfirm = false;
|
||||
@@ -335,7 +333,6 @@
|
||||
<span class="text-gray-600 dark:text-gray-400">{$i18n.t('Model')}</span>
|
||||
<ModelDropdown bind:model_id side="bottom" align="end" onChange={markDirty} />
|
||||
</div>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -46,11 +46,13 @@
|
||||
}}
|
||||
>
|
||||
<Tooltip
|
||||
content={DOMPurify.sanitize(marked.parse(
|
||||
sanitizeResponseContent(
|
||||
models[selectedModelIdx]?.info?.meta?.description ?? ''
|
||||
).replaceAll('\n', '<br>')
|
||||
))}
|
||||
content={DOMPurify.sanitize(
|
||||
marked.parse(
|
||||
sanitizeResponseContent(
|
||||
models[selectedModelIdx]?.info?.meta?.description ?? ''
|
||||
).replaceAll('\n', '<br>')
|
||||
)
|
||||
)}
|
||||
placement="right"
|
||||
>
|
||||
<img
|
||||
@@ -97,11 +99,13 @@
|
||||
<div
|
||||
class="mt-0.5 text-base font-normal text-gray-500 dark:text-gray-400 line-clamp-3 markdown"
|
||||
>
|
||||
{@html DOMPurify.sanitize(marked.parse(
|
||||
sanitizeResponseContent(
|
||||
models[selectedModelIdx]?.info?.meta?.description
|
||||
).replaceAll('\n', '<br>')
|
||||
))}
|
||||
{@html DOMPurify.sanitize(
|
||||
marked.parse(
|
||||
sanitizeResponseContent(
|
||||
models[selectedModelIdx]?.info?.meta?.description
|
||||
).replaceAll('\n', '<br>')
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
{#if models[selectedModelIdx]?.info?.meta?.user}
|
||||
<div class="mt-0.5 text-sm font-normal text-gray-400 dark:text-gray-500">
|
||||
|
||||
@@ -399,7 +399,8 @@
|
||||
{/if}
|
||||
<iframe
|
||||
src={serveUrl}
|
||||
sandbox="allow-scripts allow-same-origin allow-downloads{($settings?.iframeSandboxAllowForms ?? false)
|
||||
sandbox="allow-scripts allow-same-origin allow-downloads{($settings?.iframeSandboxAllowForms ??
|
||||
false)
|
||||
? ' allow-forms'
|
||||
: ''}"
|
||||
class="w-full h-full border-none bg-white"
|
||||
|
||||
@@ -102,9 +102,7 @@
|
||||
<Cloud className="size-3.5" strokeWidth="2" />
|
||||
|
||||
{#if $selectedTerminalId && selectedLabel}
|
||||
<span class="truncate text-[13px] max-w-[100px] sm:max-w-[150px]"
|
||||
>{selectedLabel}</span
|
||||
>
|
||||
<span class="truncate text-[13px] max-w-[100px] sm:max-w-[150px]">{selectedLabel}</span>
|
||||
{/if}
|
||||
</button>
|
||||
</Tooltip>
|
||||
|
||||
@@ -165,21 +165,25 @@
|
||||
{#if models[selectedModelIdx]?.info?.meta?.description ?? null}
|
||||
<Tooltip
|
||||
className=" w-fit"
|
||||
content={DOMPurify.sanitize(marked.parse(
|
||||
sanitizeResponseContent(
|
||||
models[selectedModelIdx]?.info?.meta?.description ?? ''
|
||||
).replaceAll('\n', '<br>')
|
||||
))}
|
||||
content={DOMPurify.sanitize(
|
||||
marked.parse(
|
||||
sanitizeResponseContent(
|
||||
models[selectedModelIdx]?.info?.meta?.description ?? ''
|
||||
).replaceAll('\n', '<br>')
|
||||
)
|
||||
)}
|
||||
placement="top"
|
||||
>
|
||||
<div
|
||||
class="mt-0.5 px-2 text-sm font-normal text-gray-500 dark:text-gray-400 line-clamp-2 max-w-xl markdown"
|
||||
>
|
||||
{@html DOMPurify.sanitize(marked.parse(
|
||||
sanitizeResponseContent(
|
||||
models[selectedModelIdx]?.info?.meta?.description ?? ''
|
||||
).replaceAll('\n', '<br>')
|
||||
))}
|
||||
{@html DOMPurify.sanitize(
|
||||
marked.parse(
|
||||
sanitizeResponseContent(
|
||||
models[selectedModelIdx]?.info?.meta?.description ?? ''
|
||||
).replaceAll('\n', '<br>')
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
</Tooltip>
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
|
||||
import { splitStream } from '$lib/utils';
|
||||
import Spinner from '$lib/components/common/Spinner.svelte';
|
||||
|
||||
|
||||
const i18n = getContext('i18n');
|
||||
|
||||
|
||||
@@ -101,9 +101,7 @@ export const tokenizeDisplayMath = (
|
||||
() => !requireBlockBoundary || isBlockBoundary(src, afterClose)
|
||||
];
|
||||
|
||||
return validators.every((v) => v())
|
||||
? { type, raw, text, displayMode: true }
|
||||
: undefined;
|
||||
return validators.every((v) => v()) ? { type, raw, text, displayMode: true } : undefined;
|
||||
};
|
||||
|
||||
export default function (options = {}) {
|
||||
|
||||
@@ -41,7 +41,11 @@ export function mentionExtension(opts: MentionOptions = {}) {
|
||||
// mentionStart fires on every '<' in the document, making the tokenizer a hot path.
|
||||
const trigger = opts.triggerChar ?? '@';
|
||||
const re = new RegExp(`^<\\${trigger}([\\w.\\-:/]+)(?:\\|([^>]*))?>`);
|
||||
const snapshot: MentionOptions = { triggerChar: trigger, className: opts.className, extraAttrs: opts.extraAttrs };
|
||||
const snapshot: MentionOptions = {
|
||||
triggerChar: trigger,
|
||||
className: opts.className,
|
||||
extraAttrs: opts.extraAttrs
|
||||
};
|
||||
|
||||
return {
|
||||
name: 'mention',
|
||||
|
||||
@@ -728,7 +728,8 @@
|
||||
|
||||
// Apply theme classes (mirrors logic from chat/Settings/General.svelte)
|
||||
const themes = ['dark', 'light', 'oled-dark'];
|
||||
let themeToApply = newTheme === 'oled-dark' ? 'dark' : newTheme === 'her' ? 'light' : newTheme;
|
||||
let themeToApply =
|
||||
newTheme === 'oled-dark' ? 'dark' : newTheme === 'her' ? 'light' : newTheme;
|
||||
if (newTheme === 'system') {
|
||||
themeToApply = window.matchMedia('(prefers-color-scheme: dark)').matches ? 'dark' : 'light';
|
||||
}
|
||||
@@ -988,13 +989,15 @@
|
||||
console.error('Error refreshing backend config:', error);
|
||||
}
|
||||
|
||||
// Relay auth token to desktop app for API access
|
||||
if (window.electronAPI?.send) {
|
||||
window.electronAPI.send({
|
||||
type: 'token:update',
|
||||
token: localStorage.token
|
||||
}).catch(() => {});
|
||||
}
|
||||
// Relay auth token to desktop app for API access
|
||||
if (window.electronAPI?.send) {
|
||||
window.electronAPI
|
||||
.send({
|
||||
type: 'token:update',
|
||||
token: localStorage.token
|
||||
})
|
||||
.catch(() => {});
|
||||
}
|
||||
} else {
|
||||
// Redirect Invalid Session User to /auth Page
|
||||
localStorage.removeItem('token');
|
||||
|
||||
Reference in New Issue
Block a user