diff options
Diffstat (limited to 'searx/network/__init__.py')
-rw-r--r-- | searx/network/__init__.py | 128 |
1 files changed, 91 insertions, 37 deletions
diff --git a/searx/network/__init__.py b/searx/network/__init__.py index 06c9f75a4..8622e9731 100644 --- a/searx/network/__init__.py +++ b/searx/network/__init__.py @@ -8,7 +8,8 @@ import concurrent.futures from queue import SimpleQueue from types import MethodType from timeit import default_timer -from typing import Iterable, Tuple +from typing import Iterable, NamedTuple, Tuple, List, Dict, Union +from contextlib import contextmanager import httpx import anyio @@ -48,9 +49,23 @@ def get_context_network(): return THREADLOCAL.__dict__.get('network') or get_network() -def request(method, url, **kwargs): - """same as requests/requests/api.py request(...)""" +@contextmanager +def _record_http_time(): + # pylint: disable=too-many-branches time_before_request = default_timer() + start_time = getattr(THREADLOCAL, 'start_time', time_before_request) + try: + yield start_time + finally: + # update total_time. + # See get_time_for_thread() and reset_time_for_thread() + if hasattr(THREADLOCAL, 'total_time'): + time_after_request = default_timer() + THREADLOCAL.total_time += time_after_request - time_before_request + + +def _get_timeout(start_time, kwargs): + # pylint: disable=too-many-branches # timeout (httpx) if 'timeout' in kwargs: @@ -65,45 +80,84 @@ def request(method, url, **kwargs): # ajdust actual timeout timeout += 0.2 # overhead - start_time = getattr(THREADLOCAL, 'start_time', time_before_request) if start_time: timeout -= default_timer() - start_time - # raise_for_error - check_for_httperror = True - if 'raise_for_httperror' in kwargs: - check_for_httperror = kwargs['raise_for_httperror'] - del kwargs['raise_for_httperror'] + return timeout - # requests compatibility - if isinstance(url, bytes): - url = url.decode() - # network - network = get_context_network() - - # do request - future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop()) - try: - response = future.result(timeout) - except concurrent.futures.TimeoutError as e: - raise httpx.TimeoutException('Timeout', request=None) from e - - # requests compatibility - # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses - response.ok = not response.is_error - - # update total_time. - # See get_time_for_thread() and reset_time_for_thread() - if hasattr(THREADLOCAL, 'total_time'): - time_after_request = default_timer() - THREADLOCAL.total_time += time_after_request - time_before_request - - # raise an exception - if check_for_httperror: - raise_for_httperror(response) - - return response +def request(method, url, **kwargs): + """same as requests/requests/api.py request(...)""" + with _record_http_time() as start_time: + network = get_context_network() + timeout = _get_timeout(start_time, kwargs) + future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop()) + try: + return future.result(timeout) + except concurrent.futures.TimeoutError as e: + raise httpx.TimeoutException('Timeout', request=None) from e + + +def multi_requests(request_list: List["Request"]) -> List[Union[httpx.Response, Exception]]: + """send multiple HTTP requests in parallel. Wait for all requests to finish.""" + with _record_http_time() as start_time: + # send the requests + network = get_context_network() + loop = get_loop() + future_list = [] + for request_desc in request_list: + timeout = _get_timeout(start_time, request_desc.kwargs) + future = asyncio.run_coroutine_threadsafe( + network.request(request_desc.method, request_desc.url, **request_desc.kwargs), loop + ) + future_list.append((future, timeout)) + + # read the responses + responses = [] + for future, timeout in future_list: + try: + responses.append(future.result(timeout)) + except concurrent.futures.TimeoutError: + responses.append(httpx.TimeoutException('Timeout', request=None)) + except Exception as e: # pylint: disable=broad-except + responses.append(e) + return responses + + +class Request(NamedTuple): + """Request description for the multi_requests function""" + + method: str + url: str + kwargs: Dict[str, str] = {} + + @staticmethod + def get(url, **kwargs): + return Request('GET', url, kwargs) + + @staticmethod + def options(url, **kwargs): + return Request('OPTIONS', url, kwargs) + + @staticmethod + def head(url, **kwargs): + return Request('HEAD', url, kwargs) + + @staticmethod + def post(url, **kwargs): + return Request('POST', url, kwargs) + + @staticmethod + def put(url, **kwargs): + return Request('PUT', url, kwargs) + + @staticmethod + def patch(url, **kwargs): + return Request('PATCH', url, kwargs) + + @staticmethod + def delete(url, **kwargs): + return Request('DELETE', url, kwargs) def get(url, **kwargs): |