From 3475645292a78f7ce08f5d12c4b1044934d7c51f Mon Sep 17 00:00:00 2001 From: Yifei Kong Date: Fri, 22 Sep 2023 11:53:00 +0800 Subject: [PATCH] WIP: using a queue to implement stream --- curl_cffi/requests/__init__.py | 4 +-- curl_cffi/requests/models.py | 52 ++++++++++++++++++++++++++++++ curl_cffi/requests/session.py | 53 ++++++++++++++++++++++++++----- example.py => examples/example.py | 0 examples/stream.py | 13 ++++++++ tests/unittest/test_requests.py | 5 +++ 6 files changed, 116 insertions(+), 11 deletions(-) rename example.py => examples/example.py (100%) create mode 100644 examples/stream.py diff --git a/curl_cffi/requests/__init__.py b/curl_cffi/requests/__init__.py index 28cae6b..24beb15 100644 --- a/curl_cffi/requests/__init__.py +++ b/curl_cffi/requests/__init__.py @@ -88,9 +88,7 @@ def request( Returns: A [Response](/api/curl_cffi.requests#curl_cffi.requests.Response) object. """ - with Session( - thread=thread, curl_options=curl_options, debug=debug - ) as s: + with Session(thread=thread, curl_options=curl_options, debug=debug) as s: return s.request( method=method, url=url, diff --git a/curl_cffi/requests/models.py b/curl_cffi/requests/models.py index f334633..8cdd545 100644 --- a/curl_cffi/requests/models.py +++ b/curl_cffi/requests/models.py @@ -1,6 +1,7 @@ import warnings from json import loads from typing import Optional +import queue from .. import Curl from .headers import Headers @@ -8,6 +9,13 @@ from .errors import RequestsError +def clear_queue(q: queue.Queue): + with q.mutex: + q.queue.clear() + q.all_tasks_done.notify_all() + q.unfinished_tasks = 0 + + class Request: def __init__(self, url: str, headers: Headers, method: str): self.url = url @@ -34,6 +42,7 @@ class Response: http_version: http version used. history: history redirections, only headers are available. """ + def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = None): self.curl = curl self.request = request @@ -51,6 +60,7 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non self.redirect_url = "" self.http_version = 0 self.history = [] + self.queue: Optional[queue.Queue] = None @property def text(self) -> str: @@ -60,6 +70,48 @@ def raise_for_status(self): if not self.ok: raise RequestsError(f"HTTP Error {self.status_code}: {self.reason}") + def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None): + """ + Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/ + which is under the License: Apache 2.0 + """ + pending = None + + for chunk in self.iter_content( + chunk_size=chunk_size, decode_unicode=decode_unicode + ): + if pending is not None: + chunk = pending + chunk + if delimiter: + lines = chunk.split(delimiter) + else: + lines = chunk.splitlines() + if lines and lines[-1] and chunk and lines[-1][-1] == chunk[-1]: + pending = lines.pop() + else: + pending = None + + yield from lines + + if pending is not None: + yield pending + + def iter_content(self, chunk_size=None, decode_unicode=False): + if chunk_size: + warnings.warn("chunk_size is ignored, there is no way to tell curl that.") + if decode_unicode: + raise NotImplementedError() + try: + while True: + chunk = self.queue.get() # type: ignore + if chunk is None: + return + yield chunk + finally: + # If anything happens, always free the memory + self.curl.reset() # type: ignore + clear_queue(self.queue) # type: ignore + def json(self, **kw): return loads(self.content, **kw) diff --git a/curl_cffi/requests/session.py b/curl_cffi/requests/session.py index 1b3323a..ff48952 100644 --- a/curl_cffi/requests/session.py +++ b/curl_cffi/requests/session.py @@ -3,12 +3,14 @@ import re import threading import warnings +import queue from enum import Enum from functools import partialmethod from io import BytesIO from json import dumps from typing import Callable, Dict, List, Optional, Tuple, Union, cast from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urlparse +from concurrent.futures import ThreadPoolExecutor from .. import AsyncCurl, Curl, CurlError, CurlInfo, CurlOpt, CurlHttpVersion from .cookies import Cookies, CookieTypes, CurlMorsel @@ -185,6 +187,7 @@ def _set_curl_options( default_headers: Optional[bool] = None, http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, + stream: bool = False, ): c = curl @@ -355,12 +358,14 @@ def _set_curl_options( for k, v in self.curl_options.items(): c.setopt(k, v) - if content_callback is None: + buffer = None + if stream: + c.setopt(CurlOpt.WRITEFUNCTION, self.queue.put) # type: ignore + elif content_callback is not None: + c.setopt(CurlOpt.WRITEFUNCTION, content_callback) + else: buffer = BytesIO() c.setopt(CurlOpt.WRITEDATA, buffer) - else: - buffer = None - c.setopt(CurlOpt.WRITEFUNCTION, content_callback) header_buffer = BytesIO() c.setopt(CurlOpt.HEADERDATA, header_buffer) @@ -470,6 +475,8 @@ def __init__( super().__init__(**kwargs) self._thread = thread self._use_thread_local_curl = use_thread_local_curl + self._queue = None + self._executor = None if use_thread_local_curl: self._local = threading.local() if curl: @@ -485,13 +492,30 @@ def __init__( def curl(self): if self._use_thread_local_curl: if self._is_customized_curl: - warnings.warn("Creating fresh curl in different thread.") + warnings.warn("Creating fresh curl handle in different thread.") if not getattr(self._local, "curl", None): self._local.curl = Curl(debug=self.debug) return self._local.curl else: return self._curl + @property + def executor(self): + if self._executor is None: + self._executor = ThreadPoolExecutor() + return self._executor + + @property + def queue(self): + if self._use_thread_local_curl: + if getattr(self._local, "queue", None) is None: + self._local.queue = queue.Queue() + return self._local.queue + else: + if self._queue is None: + self._queue = queue.Queue() + return self._queue + def __enter__(self): return self @@ -525,6 +549,7 @@ def request( default_headers: Optional[bool] = None, http_version: Optional[CurlHttpVersion] = None, interface: Optional[str] = None, + stream: bool = False, ) -> Response: """Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters.""" c = self.curl @@ -551,6 +576,7 @@ def request( default_headers=default_headers, http_version=http_version, interface=interface, + stream=stream, ) try: if self._thread == "eventlet": @@ -560,7 +586,16 @@ def request( # see: https://www.gevent.org/api/gevent.threadpool.html gevent.get_hub().threadpool.spawn(c.perform).get() else: - c.perform() + if stream: + queue = self.queue # using queue from current thread + + def perform(): + c.perform() + queue.put(None) # sentinel + + self.executor.submit(perform) + else: + c.perform() except CurlError as e: rsp = self._parse_response(c, buffer, header_buffer) rsp.request = req @@ -568,9 +603,11 @@ def request( else: rsp = self._parse_response(c, buffer, header_buffer) rsp.request = req + rsp.queue = self.queue return rsp finally: - self.curl.reset() + if not stream: + self.curl.reset() head = partialmethod(request, "HEAD") get = partialmethod(request, "GET") @@ -626,7 +663,7 @@ def __init__( self.init_pool() if sys.version_info >= (3, 8) and sys.platform.lower().startswith("win"): if isinstance( - asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy + asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy # type: ignore ): warnings.warn(WINDOWS_WARN) diff --git a/example.py b/examples/example.py similarity index 100% rename from example.py rename to examples/example.py diff --git a/examples/stream.py b/examples/stream.py new file mode 100644 index 0000000..f9f204d --- /dev/null +++ b/examples/stream.py @@ -0,0 +1,13 @@ +from curl_cffi import requests + + +with requests.Session() as s: + r = s.get("https://httpbin.org/stream/20", stream=True) + for chunk in r.iter_content(): + print("CHUNK", chunk) + + +with requests.Session() as s: + r = s.get("https://httpbin.org/stream/20", stream=True) + for line in r.iter_lines(): + print("LINE", line.decode()) diff --git a/tests/unittest/test_requests.py b/tests/unittest/test_requests.py index b4f7d2e..663b7bd 100644 --- a/tests/unittest/test_requests.py +++ b/tests/unittest/test_requests.py @@ -405,3 +405,8 @@ def test_session_with_headers(server): r = s.get(str(server.url), headers={"Foo": "bar"}) r = s.get(str(server.url), headers={"Foo": "baz"}) assert r.status_code == 200 + + +def test_stream(server): + s = requests.Session() +