mirror of
https://github.com/open-webui/open-webui.git
synced 2026-06-14 03:30:25 +00:00
refac
This commit is contained in:
@@ -1,20 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.utils.session_pool import get_session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Brave free-tier rate limit: 1 request per second.
|
||||
_RATE_LIMIT_RETRY_DELAY = 1.0
|
||||
|
||||
def search_brave(api_key: str, query: str, count: int, filter_list: list[str | None] = None) -> list[SearchResult]:
|
||||
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Brave Search API key
|
||||
query (str): The query to search for
|
||||
async def search_brave(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: list[str | None] | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Query the Brave Web Search API and return normalised results.
|
||||
|
||||
Retries once on HTTP 429 (rate-limit) after a short delay.
|
||||
"""
|
||||
url = 'https://api.search.brave.com/res/v1/web/search'
|
||||
headers = {
|
||||
@@ -24,27 +30,27 @@ def search_brave(api_key: str, query: str, count: int, filter_list: list[str | N
|
||||
}
|
||||
params = {'q': query, 'count': count}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
session = await get_session()
|
||||
async with session.get(url, headers=headers, params=params) as response:
|
||||
if response.status == 429:
|
||||
log.info('Brave Search rate-limited (429); retrying after %.1fs', _RATE_LIMIT_RETRY_DELAY)
|
||||
await asyncio.sleep(_RATE_LIMIT_RETRY_DELAY)
|
||||
async with session.get(url, headers=headers, params=params) as retry_resp:
|
||||
retry_resp.raise_for_status()
|
||||
payload = await retry_resp.json()
|
||||
else:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
|
||||
# Handle 429 rate limiting - Brave free tier allows 1 request/second
|
||||
# If rate limited, wait 1 second and retry once before failing
|
||||
if response.status_code == 429:
|
||||
log.info('Brave Search API rate limited (429), retrying after 1 second...')
|
||||
time.sleep(1)
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get('web', {}).get('results', [])
|
||||
web_results = payload.get('web', {}).get('results', [])
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
web_results = get_filtered_results(web_results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result['url'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('description'),
|
||||
link=item.get('url', ''),
|
||||
title=item.get('title'),
|
||||
snippet=item.get('description'),
|
||||
)
|
||||
for result in results[:count]
|
||||
for item in web_results[:count]
|
||||
]
|
||||
|
||||
@@ -2,70 +2,65 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.utils.session_pool import get_session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def search_google_pse(
|
||||
async def search_google_pse(
|
||||
api_key: str,
|
||||
search_engine_id: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: list[str | None] = None,
|
||||
filter_list: list[str | None] | None = None,
|
||||
referer: str | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
|
||||
Handles pagination for counts greater than 10.
|
||||
"""Query Google Programmable Search Engine with automatic pagination.
|
||||
|
||||
Args:
|
||||
api_key (str): A Programmable Search Engine API key
|
||||
search_engine_id (str): A Programmable Search Engine ID
|
||||
query (str): The query to search for
|
||||
count (int): The number of results to return (max 100, as PSE max results per query is 10 and max page is 10)
|
||||
filter_list (list[str | None], optional): A list of keywords to filter out from results. Defaults to None.
|
||||
|
||||
Returns:
|
||||
list[SearchResult]: A list of SearchResult objects.
|
||||
The PSE API returns at most 10 results per request, so this function
|
||||
issues multiple requests when ``count > 10``.
|
||||
"""
|
||||
url = 'https://www.googleapis.com/customsearch/v1'
|
||||
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
headers: dict[str, str] = {'Content-Type': 'application/json'}
|
||||
if referer:
|
||||
headers['Referer'] = referer
|
||||
|
||||
all_results = []
|
||||
start_index = 1 # Google PSE start parameter is 1-based
|
||||
all_items: list[dict] = []
|
||||
start_index = 1 # PSE uses 1-based pagination
|
||||
|
||||
while count > 0:
|
||||
num_results_this_page = min(count, 10) # Google PSE max results per page is 10
|
||||
session = await get_session()
|
||||
remaining = count
|
||||
while remaining > 0:
|
||||
page_size = min(remaining, 10)
|
||||
params = {
|
||||
'cx': search_engine_id,
|
||||
'q': query,
|
||||
'key': api_key,
|
||||
'num': num_results_this_page,
|
||||
'start': start_index,
|
||||
'num': str(page_size),
|
||||
'start': str(start_index),
|
||||
}
|
||||
response = requests.request('GET', url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
json_response = response.json()
|
||||
results = json_response.get('items', [])
|
||||
if results: # check if results are returned. If not, no more pages to fetch.
|
||||
all_results.extend(results)
|
||||
count -= len(results) # Decrement count by the number of results fetched in this page.
|
||||
start_index += 10 # Increment start index for the next page
|
||||
else:
|
||||
break # No more results from Google PSE, break the loop
|
||||
|
||||
async with session.get(url, headers=headers, params=params) as response:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
|
||||
items = payload.get('items', [])
|
||||
if not items:
|
||||
break
|
||||
|
||||
all_items.extend(items)
|
||||
remaining -= len(items)
|
||||
start_index += 10
|
||||
|
||||
if filter_list:
|
||||
all_results = get_filtered_results(all_results, filter_list)
|
||||
all_items = get_filtered_results(all_items, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('snippet'),
|
||||
link=item.get('link', ''),
|
||||
title=item.get('title'),
|
||||
snippet=item.get('snippet'),
|
||||
)
|
||||
for result in all_results
|
||||
for item in all_items
|
||||
]
|
||||
|
||||
@@ -2,87 +2,65 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.utils.session_pool import get_session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# SearXNG request headers — identifies the bot to instance operators.
|
||||
_SEARXNG_HEADERS = {
|
||||
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
|
||||
'Accept': 'text/html',
|
||||
'Accept-Encoding': 'gzip, deflate',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Connection': 'keep-alive',
|
||||
}
|
||||
|
||||
def search_searxng( # noqa: PLR0913
|
||||
|
||||
async def search_searxng(
|
||||
query_url: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: list[str | None] = None,
|
||||
filter_list: list[str | None] | None = None,
|
||||
**kwargs,
|
||||
) -> list[SearchResult]:
|
||||
"""Query a SearXNG instance and return results sorted by relevance score.
|
||||
|
||||
Optional keyword arguments (language, safesearch, time_range, categories)
|
||||
are forwarded directly as SearXNG query parameters.
|
||||
"""
|
||||
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
|
||||
|
||||
The function allows passing additional parameters such as language or time_range to tailor the search result.
|
||||
|
||||
Args:
|
||||
query_url (str): The base URL of the SearXNG server.
|
||||
query (str): The search term or question to find in the SearXNG database.
|
||||
count (int): The maximum number of results to retrieve from the search.
|
||||
|
||||
Keyword Args:
|
||||
language (str): Language filter for the search results; e.g., "all", "en-US", "es". Defaults to "all".
|
||||
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
|
||||
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
|
||||
categories: (list[str | None]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
|
||||
|
||||
Returns:
|
||||
list[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
|
||||
|
||||
Raise:
|
||||
requests.exceptions.RequestException: If a request error occurs during the search process.
|
||||
"""
|
||||
|
||||
# Default values for optional parameters are provided as empty strings or None when not specified.
|
||||
language = kwargs.get('language', 'all').strip().rstrip(',')
|
||||
safesearch = kwargs.get('safesearch', '1')
|
||||
time_range = kwargs.get('time_range', '')
|
||||
categories = ''.join(kwargs.get('categories', []))
|
||||
# Normalise legacy ``<query>``-style URLs by stripping any query string.
|
||||
if '<query>' in query_url:
|
||||
query_url = query_url.split('?')[0]
|
||||
|
||||
params = {
|
||||
'q': query,
|
||||
'format': 'json',
|
||||
'pageno': 1,
|
||||
'safesearch': safesearch,
|
||||
'language': language,
|
||||
'time_range': time_range,
|
||||
'categories': categories,
|
||||
'safesearch': kwargs.get('safesearch', '1'),
|
||||
'language': kwargs.get('language', 'all').strip().rstrip(','),
|
||||
'time_range': kwargs.get('time_range', ''),
|
||||
'categories': ''.join(kwargs.get('categories', [])),
|
||||
'theme': 'simple',
|
||||
'image_proxy': 0,
|
||||
}
|
||||
|
||||
# Legacy query format
|
||||
if '<query>' in query_url:
|
||||
# Strip all query parameters from the URL
|
||||
query_url = query_url.split('?')[0]
|
||||
log.debug('searching %s', query_url)
|
||||
|
||||
log.debug(f'searching {query_url}')
|
||||
session = await get_session()
|
||||
async with session.get(query_url, headers=_SEARXNG_HEADERS, params=params) as response:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
|
||||
response = requests.get(
|
||||
query_url,
|
||||
headers={
|
||||
'User-Agent': 'Open WebUI (https://github.com/open-webui/open-webui) RAG Bot',
|
||||
'Accept': 'text/html',
|
||||
'Accept-Encoding': 'gzip, deflate',
|
||||
'Accept-Language': 'en-US,en;q=0.5',
|
||||
'Connection': 'keep-alive',
|
||||
},
|
||||
params=params,
|
||||
)
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors.
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get('results', [])
|
||||
sorted_results = sorted(results, key=lambda x: x.get('score', 0), reverse=True)
|
||||
results = sorted(payload.get('results', []), key=lambda x: x.get('score', 0), reverse=True)
|
||||
if filter_list:
|
||||
sorted_results = get_filtered_results(sorted_results, filter_list)
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('content'))
|
||||
for result in sorted_results[:count]
|
||||
SearchResult(
|
||||
link=item.get('url', ''),
|
||||
title=item.get('title'),
|
||||
snippet=item.get('content'),
|
||||
)
|
||||
for item in results[:count]
|
||||
]
|
||||
|
||||
@@ -3,36 +3,39 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.utils.session_pool import get_session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def search_serper(api_key: str, query: str, count: int, filter_list: list[str | None] = None) -> list[SearchResult]:
|
||||
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
|
||||
async def search_serper(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: list[str | None] | None = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Query the serper.dev Google Search API and return normalised results.
|
||||
|
||||
Args:
|
||||
api_key (str): A serper.dev API key
|
||||
query (str): The query to search for
|
||||
Results are sorted by their position field before truncation.
|
||||
"""
|
||||
url = 'https://google.serper.dev/search'
|
||||
|
||||
payload = json.dumps({'q': query})
|
||||
headers = {'X-API-KEY': api_key, 'Content-Type': 'application/json'}
|
||||
|
||||
response = requests.request('POST', url, headers=headers, data=payload)
|
||||
response.raise_for_status()
|
||||
session = await get_session()
|
||||
async with session.post(url, headers=headers, data=json.dumps({'q': query})) as response:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
|
||||
json_response = response.json()
|
||||
results = sorted(json_response.get('organic', []), key=lambda x: x.get('position', 0))
|
||||
organic = sorted(payload.get('organic', []), key=lambda item: item.get('position', 0))
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
organic = get_filtered_results(organic, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result['link'],
|
||||
title=result.get('title'),
|
||||
snippet=result.get('snippet'),
|
||||
link=item.get('link', ''),
|
||||
title=item.get('title'),
|
||||
snippet=item.get('snippet'),
|
||||
)
|
||||
for result in results[:count]
|
||||
for item in organic[:count]
|
||||
]
|
||||
|
||||
@@ -2,42 +2,41 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from open_webui.retrieval.web.main import SearchResult, get_filtered_results
|
||||
from open_webui.utils.session_pool import get_session
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def search_serpstack(
|
||||
async def search_serpstack(
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: list[str | None] = None,
|
||||
filter_list: list[str | None] | None = None,
|
||||
https_enabled: bool = True,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serpstack.com's and return the results as a list of SearchResult objects.
|
||||
"""Query the serpstack.com API and return normalised results.
|
||||
|
||||
Args:
|
||||
api_key (str): A serpstack.com API key
|
||||
query (str): The query to search for
|
||||
https_enabled (bool): Whether to use HTTPS or HTTP for the API request
|
||||
Uses HTTPS by default; set ``https_enabled=False`` for free-tier HTTP access.
|
||||
"""
|
||||
url = f'{"https" if https_enabled else "http"}://api.serpstack.com/search'
|
||||
scheme = 'https' if https_enabled else 'http'
|
||||
url = f'{scheme}://api.serpstack.com/search'
|
||||
params = {'access_key': api_key, 'query': query}
|
||||
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
params = {
|
||||
'access_key': api_key,
|
||||
'query': query,
|
||||
}
|
||||
session = await get_session()
|
||||
async with session.get(url, params=params) as response:
|
||||
response.raise_for_status()
|
||||
payload = await response.json()
|
||||
|
||||
response = requests.request('POST', url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
results = sorted(json_response.get('organic_results', []), key=lambda x: x.get('position', 0))
|
||||
organic = sorted(payload.get('organic_results', []), key=lambda x: x.get('position', 0))
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
organic = get_filtered_results(organic, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(link=result['url'], title=result.get('title'), snippet=result.get('snippet'))
|
||||
for result in results[:count]
|
||||
SearchResult(
|
||||
link=item.get('url', ''),
|
||||
title=item.get('title'),
|
||||
snippet=item.get('snippet'),
|
||||
)
|
||||
for item in organic[:count]
|
||||
]
|
||||
|
||||
@@ -1866,33 +1866,18 @@ async def process_web(
|
||||
)
|
||||
|
||||
|
||||
def search_web(request: Request, engine: str, query: str, user=None) -> list[SearchResult]:
|
||||
"""Search the web using a search engine and return the results as a list of SearchResult objects.
|
||||
Will look for a search engine API key in environment variables in the following order:
|
||||
- SEARXNG_QUERY_URL
|
||||
- YACY_QUERY_URL + YACY_USERNAME + YACY_PASSWORD
|
||||
- GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
|
||||
- BRAVE_SEARCH_API_KEY
|
||||
- KAGI_SEARCH_API_KEY
|
||||
- MOJEEK_SEARCH_API_KEY
|
||||
- BOCHA_SEARCH_API_KEY
|
||||
- SERPSTACK_API_KEY
|
||||
- SERPER_API_KEY
|
||||
- SERPLY_API_KEY
|
||||
- TAVILY_API_KEY
|
||||
- EXA_API_KEY
|
||||
- PERPLEXITY_API_KEY
|
||||
- SOUGOU_API_SID + SOUGOU_API_SK
|
||||
- SEARCHAPI_API_KEY + SEARCHAPI_ENGINE (by default `google`)
|
||||
- SERPAPI_API_KEY + SERPAPI_ENGINE (by default `google`)
|
||||
- LINKUP_API_KEY
|
||||
Args:
|
||||
query (str): The query to search for
|
||||
async def search_web(request: Request, engine: str, query: str, user=None) -> list[SearchResult]:
|
||||
"""Dispatch a web search query to the configured engine and return results.
|
||||
|
||||
Providers that have been migrated to async (aiohttp) are awaited natively.
|
||||
Legacy sync providers are offloaded via ``asyncio.to_thread`` to avoid
|
||||
blocking the event loop.
|
||||
"""
|
||||
|
||||
# TODO: add playwright to search the web
|
||||
if engine == 'ollama_cloud':
|
||||
return search_ollama_cloud(
|
||||
return await asyncio.to_thread(
|
||||
search_ollama_cloud,
|
||||
'https://ollama.com',
|
||||
request.app.state.config.OLLAMA_CLOUD_WEB_SEARCH_API_KEY,
|
||||
query,
|
||||
@@ -1901,7 +1886,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
)
|
||||
elif engine == 'perplexity_search':
|
||||
if request.app.state.config.PERPLEXITY_API_KEY:
|
||||
return search_perplexity_search(
|
||||
return await asyncio.to_thread(
|
||||
search_perplexity_search,
|
||||
request.app.state.config.PERPLEXITY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -1914,7 +1900,7 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
elif engine == 'searxng':
|
||||
if request.app.state.config.SEARXNG_QUERY_URL:
|
||||
searxng_kwargs = {'language': request.app.state.config.SEARXNG_LANGUAGE}
|
||||
return search_searxng(
|
||||
return await search_searxng(
|
||||
request.app.state.config.SEARXNG_QUERY_URL,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -1925,7 +1911,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No SEARXNG_QUERY_URL found in environment variables')
|
||||
elif engine == 'yacy':
|
||||
if request.app.state.config.YACY_QUERY_URL:
|
||||
return search_yacy(
|
||||
return await asyncio.to_thread(
|
||||
search_yacy,
|
||||
request.app.state.config.YACY_QUERY_URL,
|
||||
request.app.state.config.YACY_USERNAME,
|
||||
request.app.state.config.YACY_PASSWORD,
|
||||
@@ -1937,7 +1924,7 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No YACY_QUERY_URL found in environment variables')
|
||||
elif engine == 'google_pse':
|
||||
if request.app.state.config.GOOGLE_PSE_API_KEY and request.app.state.config.GOOGLE_PSE_ENGINE_ID:
|
||||
return search_google_pse(
|
||||
return await search_google_pse(
|
||||
request.app.state.config.GOOGLE_PSE_API_KEY,
|
||||
request.app.state.config.GOOGLE_PSE_ENGINE_ID,
|
||||
query,
|
||||
@@ -1949,7 +1936,7 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables')
|
||||
elif engine == 'brave':
|
||||
if request.app.state.config.BRAVE_SEARCH_API_KEY:
|
||||
return search_brave(
|
||||
return await search_brave(
|
||||
request.app.state.config.BRAVE_SEARCH_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -1959,7 +1946,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No BRAVE_SEARCH_API_KEY found in environment variables')
|
||||
elif engine == 'brave_llm_context':
|
||||
if request.app.state.config.BRAVE_SEARCH_API_KEY:
|
||||
return search_brave_llm_context(
|
||||
return await asyncio.to_thread(
|
||||
search_brave_llm_context,
|
||||
request.app.state.config.BRAVE_SEARCH_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -1970,7 +1958,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No BRAVE_SEARCH_API_KEY found in environment variables')
|
||||
elif engine == 'kagi':
|
||||
if request.app.state.config.KAGI_SEARCH_API_KEY:
|
||||
return search_kagi(
|
||||
return await asyncio.to_thread(
|
||||
search_kagi,
|
||||
request.app.state.config.KAGI_SEARCH_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -1980,7 +1969,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No KAGI_SEARCH_API_KEY found in environment variables')
|
||||
elif engine == 'mojeek':
|
||||
if request.app.state.config.MOJEEK_SEARCH_API_KEY:
|
||||
return search_mojeek(
|
||||
return await asyncio.to_thread(
|
||||
search_mojeek,
|
||||
request.app.state.config.MOJEEK_SEARCH_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -1990,7 +1980,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No MOJEEK_SEARCH_API_KEY found in environment variables')
|
||||
elif engine == 'bocha':
|
||||
if request.app.state.config.BOCHA_SEARCH_API_KEY:
|
||||
return search_bocha(
|
||||
return await asyncio.to_thread(
|
||||
search_bocha,
|
||||
request.app.state.config.BOCHA_SEARCH_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2000,7 +1991,7 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No BOCHA_SEARCH_API_KEY found in environment variables')
|
||||
elif engine == 'serpstack':
|
||||
if request.app.state.config.SERPSTACK_API_KEY:
|
||||
return search_serpstack(
|
||||
return await search_serpstack(
|
||||
request.app.state.config.SERPSTACK_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2011,7 +2002,7 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No SERPSTACK_API_KEY found in environment variables')
|
||||
elif engine == 'serper':
|
||||
if request.app.state.config.SERPER_API_KEY:
|
||||
return search_serper(
|
||||
return await search_serper(
|
||||
request.app.state.config.SERPER_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2021,7 +2012,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No SERPER_API_KEY found in environment variables')
|
||||
elif engine == 'serply':
|
||||
if request.app.state.config.SERPLY_API_KEY:
|
||||
return search_serply(
|
||||
return await asyncio.to_thread(
|
||||
search_serply,
|
||||
request.app.state.config.SERPLY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2030,7 +2022,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
else:
|
||||
raise Exception('No SERPLY_API_KEY found in environment variables')
|
||||
elif engine == 'duckduckgo':
|
||||
return search_duckduckgo(
|
||||
return await asyncio.to_thread(
|
||||
search_duckduckgo,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
@@ -2039,7 +2032,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
)
|
||||
elif engine == 'tavily':
|
||||
if request.app.state.config.TAVILY_API_KEY:
|
||||
return search_tavily(
|
||||
return await asyncio.to_thread(
|
||||
search_tavily,
|
||||
request.app.state.config.TAVILY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2049,7 +2043,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No TAVILY_API_KEY found in environment variables')
|
||||
elif engine == 'exa':
|
||||
if request.app.state.config.EXA_API_KEY:
|
||||
return search_exa(
|
||||
return await asyncio.to_thread(
|
||||
search_exa,
|
||||
request.app.state.config.EXA_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2059,7 +2054,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No EXA_API_KEY found in environment variables')
|
||||
elif engine == 'searchapi':
|
||||
if request.app.state.config.SEARCHAPI_API_KEY:
|
||||
return search_searchapi(
|
||||
return await asyncio.to_thread(
|
||||
search_searchapi,
|
||||
request.app.state.config.SEARCHAPI_API_KEY,
|
||||
request.app.state.config.SEARCHAPI_ENGINE,
|
||||
query,
|
||||
@@ -2070,7 +2066,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No SEARCHAPI_API_KEY found in environment variables')
|
||||
elif engine == 'serpapi':
|
||||
if request.app.state.config.SERPAPI_API_KEY:
|
||||
return search_serpapi(
|
||||
return await asyncio.to_thread(
|
||||
search_serpapi,
|
||||
request.app.state.config.SERPAPI_API_KEY,
|
||||
request.app.state.config.SERPAPI_ENGINE,
|
||||
query,
|
||||
@@ -2080,14 +2077,16 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
else:
|
||||
raise Exception('No SERPAPI_API_KEY found in environment variables')
|
||||
elif engine == 'jina':
|
||||
return search_jina(
|
||||
return await asyncio.to_thread(
|
||||
search_jina,
|
||||
request.app.state.config.JINA_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.JINA_API_BASE_URL,
|
||||
)
|
||||
elif engine == 'bing':
|
||||
return search_bing(
|
||||
return await asyncio.to_thread(
|
||||
search_bing,
|
||||
request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
request.app.state.config.BING_SEARCH_V7_ENDPOINT,
|
||||
str(DEFAULT_LOCALE),
|
||||
@@ -2101,7 +2100,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
and request.app.state.config.AZURE_AI_SEARCH_ENDPOINT
|
||||
and request.app.state.config.AZURE_AI_SEARCH_INDEX_NAME
|
||||
):
|
||||
return search_azure(
|
||||
return await asyncio.to_thread(
|
||||
search_azure,
|
||||
request.app.state.config.AZURE_AI_SEARCH_API_KEY,
|
||||
request.app.state.config.AZURE_AI_SEARCH_ENDPOINT,
|
||||
request.app.state.config.AZURE_AI_SEARCH_INDEX_NAME,
|
||||
@@ -2113,15 +2113,9 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception(
|
||||
'AZURE_AI_SEARCH_API_KEY, AZURE_AI_SEARCH_ENDPOINT, and AZURE_AI_SEARCH_INDEX_NAME are required for Azure AI Search'
|
||||
)
|
||||
elif engine == 'exa':
|
||||
return search_exa(
|
||||
request.app.state.config.EXA_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == 'perplexity':
|
||||
return search_perplexity(
|
||||
return await asyncio.to_thread(
|
||||
search_perplexity,
|
||||
request.app.state.config.PERPLEXITY_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2131,7 +2125,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
)
|
||||
elif engine == 'sougou':
|
||||
if request.app.state.config.SOUGOU_API_SID and request.app.state.config.SOUGOU_API_SK:
|
||||
return search_sougou(
|
||||
return await asyncio.to_thread(
|
||||
search_sougou,
|
||||
request.app.state.config.SOUGOU_API_SID,
|
||||
request.app.state.config.SOUGOU_API_SK,
|
||||
query,
|
||||
@@ -2141,7 +2136,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
else:
|
||||
raise Exception('No SOUGOU_API_SID or SOUGOU_API_SK found in environment variables')
|
||||
elif engine == 'firecrawl':
|
||||
return search_firecrawl(
|
||||
return await asyncio.to_thread(
|
||||
search_firecrawl,
|
||||
request.app.state.config.FIRECRAWL_API_BASE_URL,
|
||||
request.app.state.config.FIRECRAWL_API_KEY,
|
||||
query,
|
||||
@@ -2149,7 +2145,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
request.app.state.config.WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == 'external':
|
||||
return search_external(
|
||||
return await asyncio.to_thread(
|
||||
search_external,
|
||||
request,
|
||||
request.app.state.config.EXTERNAL_WEB_SEARCH_URL,
|
||||
request.app.state.config.EXTERNAL_WEB_SEARCH_API_KEY,
|
||||
@@ -2159,7 +2156,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
user=user,
|
||||
)
|
||||
elif engine == 'yandex':
|
||||
return search_yandex(
|
||||
return await asyncio.to_thread(
|
||||
search_yandex,
|
||||
request,
|
||||
request.app.state.config.YANDEX_WEB_SEARCH_URL,
|
||||
request.app.state.config.YANDEX_WEB_SEARCH_API_KEY,
|
||||
@@ -2170,7 +2168,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
user=user,
|
||||
)
|
||||
elif engine == 'youcom':
|
||||
return search_youcom(
|
||||
return await asyncio.to_thread(
|
||||
search_youcom,
|
||||
request.app.state.config.YOUCOM_API_KEY,
|
||||
query,
|
||||
request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2178,7 +2177,8 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
)
|
||||
elif engine == 'linkup':
|
||||
if request.app.state.config.LINKUP_API_KEY:
|
||||
return search_linkup(
|
||||
return await asyncio.to_thread(
|
||||
search_linkup,
|
||||
api_key=request.app.state.config.LINKUP_API_KEY,
|
||||
query=query,
|
||||
count=request.app.state.config.WEB_SEARCH_RESULT_COUNT,
|
||||
@@ -2191,6 +2191,7 @@ def search_web(request: Request, engine: str, query: str, user=None) -> list[Sea
|
||||
raise Exception('No search engine API key found in environment variables')
|
||||
|
||||
|
||||
|
||||
@router.post('/process/web/search')
|
||||
async def process_web_search(request: Request, form_data: SearchForm, user=Depends(get_verified_user)):
|
||||
if not request.app.state.config.ENABLE_WEB_SEARCH:
|
||||
@@ -2224,8 +2225,7 @@ async def process_web_search(request: Request, form_data: SearchForm, user=Depen
|
||||
|
||||
async def search_query_with_semaphore(query):
|
||||
async with semaphore:
|
||||
return await run_in_threadpool(
|
||||
search_web,
|
||||
return await search_web(
|
||||
request,
|
||||
request.app.state.config.WEB_SEARCH_ENGINE,
|
||||
query,
|
||||
@@ -2234,10 +2234,9 @@ async def process_web_search(request: Request, form_data: SearchForm, user=Depen
|
||||
|
||||
search_tasks = [search_query_with_semaphore(query) for query in form_data.queries]
|
||||
else:
|
||||
# Unlimited parallel execution (previous behavior)
|
||||
# Unlimited parallel execution
|
||||
search_tasks = [
|
||||
run_in_threadpool(
|
||||
search_web,
|
||||
search_web(
|
||||
request,
|
||||
request.app.state.config.WEB_SEARCH_ENGINE,
|
||||
query,
|
||||
|
||||
@@ -230,7 +230,7 @@ async def search_web(
|
||||
max_count = 5 if configured is None else configured
|
||||
count = max(1, min(count, max_count)) if count is not None else max_count
|
||||
|
||||
results = await asyncio.to_thread(_search_web, __request__, engine, query, user)
|
||||
results = await _search_web(__request__, engine, query, user)
|
||||
|
||||
# Limit results
|
||||
results = results[:count] if results else []
|
||||
|
||||
Reference in New Issue
Block a user