Skip to content

Commit

Permalink
Better typing for requests / stream (#474)
Browse files Browse the repository at this point in the history
* typing for requests / stream

* fix lint fail

* unpack not found

* Fake unpack

* remove wrong unpack from init
  • Loading branch information
novitae authored Jan 12, 2025
1 parent 8cfad55 commit 139f705
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 107 deletions.
121 changes: 40 additions & 81 deletions curl_cffi/requests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,15 @@
"ProxySpec",
]

from functools import partial
from io import BytesIO
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from typing import Optional, TYPE_CHECKING, TypedDict

from ..const import CurlHttpVersion, CurlWsFlag
from ..curl import CurlMime
from ..const import CurlWsFlag
from .cookies import Cookies, CookieTypes
from .errors import RequestsError
from .headers import Headers, HeaderTypes
from .impersonate import BrowserType, BrowserTypeLiteral, ExtraFingerprints, ExtraFpDict
from .impersonate import BrowserType, BrowserTypeLiteral, ExtraFingerprints
from .models import Request, Response
from .session import AsyncSession, HttpMethod, ProxySpec, Session, ThreadType
from .session import AsyncSession, HttpMethod, ProxySpec, Session, ThreadType, RequestParams, Unpack
from .websockets import (
AsyncWebSocket,
WebSocket,
Expand All @@ -50,43 +47,21 @@
WsCloseCode,
)

if TYPE_CHECKING:
class SessionRequestParams(RequestParams):
thread: Optional[ThreadType]
curl_options: Optional[dict]
debug: Optional[bool]
else:
SessionRequestParams = TypedDict

def request(
method: HttpMethod,
url: str,
params: Optional[Union[Dict, List, Tuple]] = None,
data: Optional[Union[Dict[str, str], List[Tuple], str, BytesIO, bytes]] = None,
json: Optional[dict] = None,
headers: Optional[HeaderTypes] = None,
cookies: Optional[CookieTypes] = None,
files: Optional[Dict] = None,
auth: Optional[Tuple[str, str]] = None,
timeout: Union[float, Tuple[float, float]] = 30,
allow_redirects: bool = True,
max_redirects: int = 30,
proxies: Optional[ProxySpec] = None,
proxy: Optional[str] = None,
proxy_auth: Optional[Tuple[str, str]] = None,
verify: Optional[bool] = None,
referer: Optional[str] = None,
accept_encoding: Optional[str] = "gzip, deflate, br, zstd",
content_callback: Optional[Callable] = None,
impersonate: Optional[BrowserTypeLiteral] = None,
ja3: Optional[str] = None,
akamai: Optional[str] = None,
extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]] = None,
thread: Optional[ThreadType] = None,
default_headers: Optional[bool] = None,
default_encoding: Union[str, Callable[[bytes], str]] = "utf-8",
quote: Union[str, Literal[False]] = "",
curl_options: Optional[dict] = None,
http_version: Optional[CurlHttpVersion] = None,
debug: bool = False,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
stream: bool = False,
max_recv_speed: int = 0,
multipart: Optional[CurlMime] = None,
debug: Optional[bool] = None,
**kwargs: Unpack[RequestParams],
) -> Response:
"""Send an http request.
Expand Down Expand Up @@ -139,49 +114,33 @@ def request(
Returns:
A ``Response`` object.
"""
debug = False if debug is None else debug
with Session(thread=thread, curl_options=curl_options, debug=debug) as s:
return s.request(
method=method,
url=url,
params=params,
data=data,
json=json,
headers=headers,
cookies=cookies,
files=files,
auth=auth,
timeout=timeout,
allow_redirects=allow_redirects,
max_redirects=max_redirects,
proxies=proxies,
proxy=proxy,
proxy_auth=proxy_auth,
verify=verify,
referer=referer,
accept_encoding=accept_encoding,
content_callback=content_callback,
impersonate=impersonate,
ja3=ja3,
akamai=akamai,
extra_fp=extra_fp,
default_headers=default_headers,
default_encoding=default_encoding,
quote=quote,
http_version=http_version,
interface=interface,
cert=cert,
stream=stream,
max_recv_speed=max_recv_speed,
multipart=multipart,
)
return s.request(method=method, url=url, **kwargs)

def head(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="HEAD", url=url, **kwargs)

head = partial(request, "HEAD")
get = partial(request, "GET")
post = partial(request, "POST")
put = partial(request, "PUT")
patch = partial(request, "PATCH")
delete = partial(request, "DELETE")
options = partial(request, "OPTIONS")
trace = partial(request, "TRACE")
query = partial(request, "QUERY")
def get(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="GET", url=url, **kwargs)

def post(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="POST", url=url, **kwargs)

def put(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="PUT", url=url, **kwargs)

def patch(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="PATCH", url=url, **kwargs)

def delete(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="DELETE", url=url, **kwargs)

def options(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="OPTIONS", url=url, **kwargs)

def trace(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="TRACE", url=url, **kwargs)

def query(url: str, **kwargs: Unpack[SessionRequestParams]):
return request(method="QUERY", url=url, **kwargs)
133 changes: 108 additions & 25 deletions curl_cffi/requests/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import warnings
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager, suppress
from functools import partialmethod
from io import BytesIO
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -79,9 +78,49 @@ class BaseSessionParams(TypedDict, total=False):
cert: Optional[Union[str, Tuple[str, str]]]
response_class: Optional[Type[Response]]

class StreamRequestParams(TypedDict, total=False):
params: Optional[Union[Dict, List, Tuple]]
data: Optional[Union[Dict[str, str], List[Tuple], str, BytesIO, bytes]]
json: Optional[dict]
headers: Optional[HeaderTypes]
cookies: Optional[CookieTypes]
files: Optional[Dict]
auth: Optional[Tuple[str, str]]
timeout: Optional[Union[float, Tuple[float, float], object]]
allow_redirects: Optional[bool]
max_redirects: Optional[int]
proxies: Optional[ProxySpec]
proxy: Optional[str]
proxy_auth: Optional[Tuple[str, str]]
verify: Optional[bool]
referer: Optional[str]
accept_encoding: Optional[str]
content_callback: Optional[Callable]
impersonate: Optional[BrowserTypeLiteral]
ja3: Optional[str]
akamai: Optional[str]
extra_fp: Optional[Union[ExtraFingerprints, ExtraFpDict]]
default_headers: Optional[bool]
default_encoding: Union[str, Callable[[bytes], str]]
quote: Union[str, Literal[False]]
http_version: Optional[CurlHttpVersion]
interface: Optional[str]
cert: Optional[Union[str, Tuple[str, str]]]
max_recv_speed: int
multipart: Optional[CurlMime]

class RequestParams(StreamRequestParams):
stream: Optional[bool]

else:
class _Unpack:
@staticmethod
def __getitem__(*args, **kwargs): pass
Unpack = _Unpack()

ProxySpec = Dict[str, str]
BaseSessionParams = TypedDict
StreamRequestParams, RequestParams = TypedDict, TypedDict

ThreadType = Literal["eventlet", "gevent"]
HttpMethod = Literal["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD", "TRACE", "PATCH", "QUERY"]
Expand Down Expand Up @@ -348,9 +387,14 @@ def close(self) -> None:
self.curl.close()

@contextmanager
def stream(self, *args, **kwargs):
def stream(
self,
method: HttpMethod,
url: str,
**kwargs: Unpack[StreamRequestParams],
):
"""Equivalent to ``with request(..., stream=True) as r:``"""
rsp = self.request(*args, **kwargs, stream=True)
rsp = self.request(method=method, url=url, **kwargs, stream=True)
try:
yield rsp
finally:
Expand Down Expand Up @@ -422,7 +466,7 @@ def request(
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
stream: bool = False,
stream: Optional[bool] = None,
max_recv_speed: int = 0,
multipart: Optional[CurlMime] = None,
) -> Response:
Expand Down Expand Up @@ -537,15 +581,32 @@ def cleanup(fut):
finally:
c.reset()

head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
post = partialmethod(request, "POST")
put = partialmethod(request, "PUT")
patch = partialmethod(request, "PATCH")
delete = partialmethod(request, "DELETE")
options = partialmethod(request, "OPTIONS")
trace = partialmethod(request, "TRACE")
query = partialmethod(request, "QUERY")
def head(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="HEAD", url=url, **kwargs)

def get(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="GET", url=url, **kwargs)

def post(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="POST", url=url, **kwargs)

def put(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="PUT", url=url, **kwargs)

def patch(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="PATCH", url=url, **kwargs)

def delete(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="DELETE", url=url, **kwargs)

def options(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="OPTIONS", url=url, **kwargs)

def trace(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="TRACE", url=url, **kwargs)

def query(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="QUERY", url=url, **kwargs)


class AsyncSession(BaseSession):
Expand Down Expand Up @@ -674,9 +735,14 @@ def release_curl(self, curl):
curl.close()

@asynccontextmanager
async def stream(self, *args, **kwargs):
async def stream(
self,
method: HttpMethod,
url: str,
**kwargs: Unpack[StreamRequestParams],
):
"""Equivalent to ``async with request(..., stream=True) as r:``"""
rsp = await self.request(*args, **kwargs, stream=True)
rsp = await self.request(method=method, url=url, **kwargs, stream=True)
try:
yield rsp
finally:
Expand Down Expand Up @@ -815,7 +881,7 @@ async def request(
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
stream: bool = False,
stream: Optional[bool] = None,
max_recv_speed: int = 0,
multipart: Optional[CurlMime] = None,
):
Expand Down Expand Up @@ -917,12 +983,29 @@ def cleanup(fut):
finally:
self.release_curl(curl)

head = partialmethod(request, "HEAD")
get = partialmethod(request, "GET")
post = partialmethod(request, "POST")
put = partialmethod(request, "PUT")
patch = partialmethod(request, "PATCH")
delete = partialmethod(request, "DELETE")
options = partialmethod(request, "OPTIONS")
trace = partialmethod(request, "TRACE")
query = partialmethod(request, "QUERY")
def head(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="HEAD", url=url, **kwargs)

def get(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="GET", url=url, **kwargs)

def post(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="POST", url=url, **kwargs)

def put(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="PUT", url=url, **kwargs)

def patch(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="PATCH", url=url, **kwargs)

def delete(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="DELETE", url=url, **kwargs)

def options(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="OPTIONS", url=url, **kwargs)

def trace(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="TRACE", url=url, **kwargs)

def query(self, url: str, **kwargs: Unpack[RequestParams]):
return self.request(method="QUERY", url=url, **kwargs)
2 changes: 1 addition & 1 deletion curl_cffi/requests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def set_curl_options(
http_version: Optional[CurlHttpVersion] = None,
interface: Optional[str] = None,
cert: Optional[Union[str, Tuple[str, str]]] = None,
stream: bool = False,
stream: Optional[bool] = None,
max_recv_speed: int = 0,
multipart: Optional[CurlMime] = None,
queue_class: Any = None,
Expand Down

0 comments on commit 139f705

Please sign in to comment.