Skip to content

Commit

Permalink
WIP: using a queue to implement stream
Browse files Browse the repository at this point in the history
  • Loading branch information
perklet committed Oct 1, 2023
1 parent c56e96a commit 3475645
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 11 deletions.
4 changes: 1 addition & 3 deletions curl_cffi/requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ def request(
Returns:
A [Response](/api/curl_cffi.requests#curl_cffi.requests.Response) object.
"""
with Session(
thread=thread, curl_options=curl_options, debug=debug
) as s:
with Session(thread=thread, curl_options=curl_options, debug=debug) as s:
return s.request(
method=method,
url=url,
Expand Down
52 changes: 52 additions & 0 deletions curl_cffi/requests/models.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import warnings
from json import loads
from typing import Optional
import queue

from .. import Curl
from .headers import Headers
from .cookies import Cookies
from .errors import RequestsError


def clear_queue(q: queue.Queue) -> None:
    """Discard every pending item in ``q`` and reset its task accounting.

    ``queue.Queue`` offers no public way to empty itself, so this works on
    its internals while holding ``q.mutex``, the lock that guards them.

    Args:
        q: The queue to empty in place.
    """
    with q.mutex:
        q.queue.clear()
        # Nothing is outstanding any more, so anyone blocked in join() may go.
        q.unfinished_tasks = 0
        q.all_tasks_done.notify_all()
        # Slots just became free: wake producers blocked in put() on a bounded
        # queue, otherwise they would keep sleeping until some later get().
        q.not_full.notify_all()


class Request:
def __init__(self, url: str, headers: Headers, method: str):
self.url = url
Expand All @@ -34,6 +42,7 @@ class Response:
http_version: http version used.
history: history redirections, only headers are available.
"""

def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = None):
self.curl = curl
self.request = request
Expand All @@ -51,6 +60,7 @@ def __init__(self, curl: Optional[Curl] = None, request: Optional[Request] = Non
self.redirect_url = ""
self.http_version = 0
self.history = []
self.queue: Optional[queue.Queue] = None

@property
def text(self) -> str:
Expand All @@ -60,6 +70,48 @@ def raise_for_status(self):
if not self.ok:
raise RequestsError(f"HTTP Error {self.status_code}: {self.reason}")

def iter_lines(self, chunk_size=None, decode_unicode=False, delimiter=None):
    """Iterate over the response body one line at a time.

    Chunks come from ``iter_content``. A trailing piece that was not closed
    by a line break is held back and prepended to the next chunk, so a line
    split across two chunks is yielded once, in full.

    Copied from: https://requests.readthedocs.io/en/latest/_modules/requests/models/
    which is under the License: Apache 2.0

    Args:
        chunk_size: Forwarded to ``iter_content``.
        decode_unicode: Forwarded to ``iter_content``.
        delimiter: Separator to split on; when falsy, ``splitlines()`` is used.

    Yields:
        The lines of the body, without their separators.
    """
    carry = None
    stream = self.iter_content(chunk_size=chunk_size, decode_unicode=decode_unicode)

    for piece in stream:
        data = piece if carry is None else carry + piece
        parts = data.split(delimiter) if delimiter else data.splitlines()

        # If the data ends with the last part's final character, no separator
        # closed that part: it may continue in the next chunk, so hold it back.
        carry = None
        if parts and parts[-1] and data and parts[-1][-1] == data[-1]:
            carry = parts.pop()

        yield from parts

    if carry is not None:
        yield carry

def iter_content(self, chunk_size=None, decode_unicode=False):
    """Yield the body of a streamed response chunk by chunk.

    Chunks are taken from ``self.queue`` in the order they were put there;
    a ``None`` item marks the end of the stream.

    Args:
        chunk_size: Ignored; a warning is emitted when it is truthy.
        decode_unicode: Not supported.

    Yields:
        Every item taken from the queue until ``None`` is received.

    Raises:
        NotImplementedError: If ``decode_unicode`` is truthy.
        RequestsError: If the response has no stream queue attached.
    """
    if chunk_size:
        warnings.warn("chunk_size is ignored, there is no way to tell curl that.")
    if decode_unicode:
        raise NotImplementedError()

    chunks = self.queue
    if chunks is None:
        # Fail with a clear message instead of an AttributeError on None, and
        # leave the curl handle alone: nothing was streamed, nothing to free.
        raise RequestsError(
            "iter_content() needs a response with a stream queue attached"
        )

    try:
        # iter(callable, sentinel) keeps calling get() until it returns None,
        # the end-of-stream marker.
        for chunk in iter(chunks.get, None):
            yield chunk
    finally:
        # Runs on normal exhaustion, on error, and when the generator is
        # closed early: always free the memory.
        if self.curl is not None:
            self.curl.reset()
        clear_queue(chunks)

def json(self, **kw):
    """Parse ``self.content`` as JSON and return the decoded value.

    Args:
        **kw: Extra keyword arguments passed straight to ``json.loads``.
    """
    payload = self.content
    return loads(payload, **kw)

Expand Down
53 changes: 45 additions & 8 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import re
import threading
import warnings
import queue
from enum import Enum
from functools import partialmethod
from io import BytesIO
from json import dumps
from typing import Callable, Dict, List, Optional, Tuple, Union, cast
from urllib.parse import ParseResult, parse_qsl, unquote, urlencode, urlparse
from concurrent.futures import ThreadPoolExecutor

from .. import AsyncCurl, Curl, CurlError, CurlInfo, CurlOpt, CurlHttpVersion
from .cookies import Cookies, CookieTypes, CurlMorsel
Expand Down Expand Up @@ -185,6 +187,7 @@ def _set_curl_options(
default_headers: Optional[bool] = None,
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
stream: bool = False,
):
c = curl

Expand Down Expand Up @@ -355,12 +358,14 @@ def _set_curl_options(
for k, v in self.curl_options.items():
c.setopt(k, v)

if content_callback is None:
buffer = None
if stream:
c.setopt(CurlOpt.WRITEFUNCTION, self.queue.put) # type: ignore
elif content_callback is not None:
c.setopt(CurlOpt.WRITEFUNCTION, content_callback)
else:
buffer = BytesIO()
c.setopt(CurlOpt.WRITEDATA, buffer)
else:
buffer = None
c.setopt(CurlOpt.WRITEFUNCTION, content_callback)
header_buffer = BytesIO()
c.setopt(CurlOpt.HEADERDATA, header_buffer)

Expand Down Expand Up @@ -470,6 +475,8 @@ def __init__(
super().__init__(**kwargs)
self._thread = thread
self._use_thread_local_curl = use_thread_local_curl
self._queue = None
self._executor = None
if use_thread_local_curl:
self._local = threading.local()
if curl:
Expand All @@ -485,13 +492,30 @@ def __init__(
def curl(self):
if self._use_thread_local_curl:
if self._is_customized_curl:
warnings.warn("Creating fresh curl in different thread.")
warnings.warn("Creating fresh curl handle in different thread.")
if not getattr(self._local, "curl", None):
self._local.curl = Curl(debug=self.debug)
return self._local.curl
else:
return self._curl

@property
def executor(self):
    """Thread pool of this session, created on first access and then reused."""
    pool = self._executor
    if pool is None:
        pool = self._executor = ThreadPoolExecutor()
    return pool

@property
def queue(self):
    """Queue belonging to this session, created on first access.

    With thread-local curl handles the queue is stored on ``self._local``,
    so each thread gets its own; otherwise one queue is kept on the session.
    """
    if not self._use_thread_local_curl:
        if self._queue is None:
            self._queue = queue.Queue()
        return self._queue

    local = self._local
    if getattr(local, "queue", None) is None:
        local.queue = queue.Queue()
    return local.queue

def __enter__(self):
    """Enter the ``with`` block; the session itself is the managed object."""
    return self

Expand Down Expand Up @@ -525,6 +549,7 @@ def request(
default_headers: Optional[bool] = None,
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
stream: bool = False,
) -> Response:
"""Send the request, see [curl_cffi.requests.request](/api/curl_cffi.requests/#curl_cffi.requests.request) for details on parameters."""
c = self.curl
Expand All @@ -551,6 +576,7 @@ def request(
default_headers=default_headers,
http_version=http_version,
interface=interface,
stream=stream,
)
try:
if self._thread == "eventlet":
Expand All @@ -560,17 +586,28 @@ def request(
# see: https://www.gevent.org/api/gevent.threadpool.html
gevent.get_hub().threadpool.spawn(c.perform).get()
else:
c.perform()
if stream:
queue = self.queue # using queue from current thread

def perform():
c.perform()
queue.put(None) # sentinel

self.executor.submit(perform)
else:
c.perform()
except CurlError as e:
rsp = self._parse_response(c, buffer, header_buffer)
rsp.request = req
raise RequestsError(str(e), e.code, rsp) from e
else:
rsp = self._parse_response(c, buffer, header_buffer)
rsp.request = req
rsp.queue = self.queue
return rsp
finally:
self.curl.reset()
if not stream:
self.curl.reset()

head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
Expand Down Expand Up @@ -626,7 +663,7 @@ def __init__(
self.init_pool()
if sys.version_info >= (3, 8) and sys.platform.lower().startswith("win"):
if isinstance(
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy
asyncio.get_event_loop_policy(), asyncio.WindowsProactorEventLoopPolicy # type: ignore
):
warnings.warn(WINDOWS_WARN)

Expand Down
File renamed without changes.
13 changes: 13 additions & 0 deletions examples/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from curl_cffi import requests


# Talk to httpbin directly; both demos below use the same endpoint.
STREAM_URL = "https://httpbin.org/stream/20"

# Consume the response body chunk by chunk as it arrives.
with requests.Session() as s:
    r = s.get(STREAM_URL, stream=True)
    for chunk in r.iter_content():
        print("CHUNK", chunk)


# Consume the same endpoint line by line; the lines are decoded before
# printing.
with requests.Session() as s:
    r = s.get(STREAM_URL, stream=True)
    for line in r.iter_lines():
        print("LINE", line.decode())
5 changes: 5 additions & 0 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,8 @@ def test_session_with_headers(server):
r = s.get(str(server.url), headers={"Foo": "bar"})
r = s.get(str(server.url), headers={"Foo": "baz"})
assert r.status_code == 200


def test_stream(server):
    """Placeholder for the streaming test: it only builds a Session so far."""
    # TODO: issue a stream=True request against `server` and assert on the
    # chunks once the streaming implementation is finished.
    session = requests.Session()

0 comments on commit 3475645

Please sign in to comment.