summaryrefslogtreecommitdiff
path: root/searx/network/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'searx/network/__init__.py')
-rw-r--r--searx/network/__init__.py128
1 files changed, 91 insertions, 37 deletions
diff --git a/searx/network/__init__.py b/searx/network/__init__.py
index 06c9f75a4..8622e9731 100644
--- a/searx/network/__init__.py
+++ b/searx/network/__init__.py
@@ -8,7 +8,8 @@ import concurrent.futures
from queue import SimpleQueue
from types import MethodType
from timeit import default_timer
-from typing import Iterable, Tuple
+from typing import Iterable, NamedTuple, Tuple, List, Dict, Union
+from contextlib import contextmanager
import httpx
import anyio
@@ -48,9 +49,23 @@ def get_context_network():
return THREADLOCAL.__dict__.get('network') or get_network()
-def request(method, url, **kwargs):
- """same as requests/requests/api.py request(...)"""
+@contextmanager
+def _record_http_time():
+ # pylint: disable=too-many-branches
time_before_request = default_timer()
+ start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
+ try:
+ yield start_time
+ finally:
+ # update total_time.
+ # See get_time_for_thread() and reset_time_for_thread()
+ if hasattr(THREADLOCAL, 'total_time'):
+ time_after_request = default_timer()
+ THREADLOCAL.total_time += time_after_request - time_before_request
+
+
+def _get_timeout(start_time, kwargs):
+ # pylint: disable=too-many-branches
# timeout (httpx)
if 'timeout' in kwargs:
@@ -65,45 +80,84 @@ def request(method, url, **kwargs):
# ajdust actual timeout
timeout += 0.2 # overhead
- start_time = getattr(THREADLOCAL, 'start_time', time_before_request)
if start_time:
timeout -= default_timer() - start_time
- # raise_for_error
- check_for_httperror = True
- if 'raise_for_httperror' in kwargs:
- check_for_httperror = kwargs['raise_for_httperror']
- del kwargs['raise_for_httperror']
+ return timeout
- # requests compatibility
- if isinstance(url, bytes):
- url = url.decode()
- # network
- network = get_context_network()
-
- # do request
- future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
- try:
- response = future.result(timeout)
- except concurrent.futures.TimeoutError as e:
- raise httpx.TimeoutException('Timeout', request=None) from e
-
- # requests compatibility
- # see also https://www.python-httpx.org/compatibility/#checking-for-4xx5xx-responses
- response.ok = not response.is_error
-
- # update total_time.
- # See get_time_for_thread() and reset_time_for_thread()
- if hasattr(THREADLOCAL, 'total_time'):
- time_after_request = default_timer()
- THREADLOCAL.total_time += time_after_request - time_before_request
-
- # raise an exception
- if check_for_httperror:
- raise_for_httperror(response)
-
- return response
+def request(method, url, **kwargs):
+ """same as requests/requests/api.py request(...)"""
+ with _record_http_time() as start_time:
+ network = get_context_network()
+ timeout = _get_timeout(start_time, kwargs)
+ future = asyncio.run_coroutine_threadsafe(network.request(method, url, **kwargs), get_loop())
+ try:
+ return future.result(timeout)
+ except concurrent.futures.TimeoutError as e:
+ raise httpx.TimeoutException('Timeout', request=None) from e
+
+
+def multi_requests(request_list: List["Request"]) -> List[Union[httpx.Response, Exception]]:
+ """send multiple HTTP requests in parallel. Wait for all requests to finish."""
+ with _record_http_time() as start_time:
+ # send the requests
+ network = get_context_network()
+ loop = get_loop()
+ future_list = []
+ for request_desc in request_list:
+ timeout = _get_timeout(start_time, request_desc.kwargs)
+ future = asyncio.run_coroutine_threadsafe(
+ network.request(request_desc.method, request_desc.url, **request_desc.kwargs), loop
+ )
+ future_list.append((future, timeout))
+
+ # read the responses
+ responses = []
+ for future, timeout in future_list:
+ try:
+ responses.append(future.result(timeout))
+ except concurrent.futures.TimeoutError:
+ responses.append(httpx.TimeoutException('Timeout', request=None))
+ except Exception as e: # pylint: disable=broad-except
+ responses.append(e)
+ return responses
+
+
+class Request(NamedTuple):
+ """Request description for the multi_requests function"""
+
+ method: str
+ url: str
+ kwargs: Dict[str, str] = {}
+
+ @staticmethod
+ def get(url, **kwargs):
+ return Request('GET', url, kwargs)
+
+ @staticmethod
+ def options(url, **kwargs):
+ return Request('OPTIONS', url, kwargs)
+
+ @staticmethod
+ def head(url, **kwargs):
+ return Request('HEAD', url, kwargs)
+
+ @staticmethod
+ def post(url, **kwargs):
+ return Request('POST', url, kwargs)
+
+ @staticmethod
+ def put(url, **kwargs):
+ return Request('PUT', url, kwargs)
+
+ @staticmethod
+ def patch(url, **kwargs):
+ return Request('PATCH', url, kwargs)
+
+ @staticmethod
+ def delete(url, **kwargs):
+ return Request('DELETE', url, kwargs)
def get(url, **kwargs):