Skip to content

Commit

Permalink
Add support for using Name: None to explictly disable a header item
Browse files Browse the repository at this point in the history
  • Loading branch information
lexiforest committed Dec 30, 2024
1 parent bc09d4f commit 9e8e90d
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 39 deletions.
79 changes: 42 additions & 37 deletions curl_cffi/requests/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,7 @@

AbcMapping = Mapping # type: ignore[misc]

from typing import (
Any,
AnyStr,
Dict,
List,
Optional,
Tuple,
Union,
)
from typing import Any, AnyStr, Dict, List, Optional, Tuple, Union, cast

HeaderTypes = Union[
"Headers",
Expand All @@ -56,19 +48,25 @@ def to_str(value: Union[str, bytes], encoding: str = "utf-8") -> str:
return value if isinstance(value, str) else value.decode(encoding)


def to_bytes_or_str(value: str, match_type_of: AnyStr) -> AnyStr:
return value if isinstance(match_type_of, str) else value.encode() # pyright: ignore [reportGeneralTypeIssues]
def to_bytes_or_str_or_none(value: Optional[str], match_type_of: AnyStr) -> Optional[AnyStr]:
if value is None:
return value

if isinstance(match_type_of, str):
return value

return value.encode()


SENSITIVE_HEADERS = {"authorization", "proxy-authorization"}


def obfuscate_sensitive_headers(
items: Iterable[Tuple[AnyStr, AnyStr]],
) -> Iterator[Tuple[AnyStr, AnyStr]]:
) -> Iterator[Tuple[AnyStr, Optional[AnyStr]]]:
for k, v in items:
if to_str(k.lower()) in SENSITIVE_HEADERS:
v = to_bytes_or_str("[secure]", match_type_of=v)
v = to_bytes_or_str_or_none("[secure]", match_type_of=v)
yield k, v


Expand All @@ -85,30 +83,32 @@ def normalize_header_key(
return bytes_value.lower() if lower else bytes_value


def normalize_header_value(value: Union[str, bytes, int], encoding: Optional[str] = None) -> bytes:
def normalize_header_value(
value: Union[str, bytes, int, None], encoding: Optional[str] = None
) -> Union[bytes, None]:
"""
Coerce str/bytes into a strictly byte-wise HTTP header value.
"""
if value is None:
return None

if isinstance(value, bytes):
return value

# The default encoding for header value should be latin-1
# See: RFC and https://github.com/python/cpython/blob/bc264eac3ad14dab748e33b3d714c2674872791f/Lib/http/client.py#L1309
if isinstance(value, int):
return str(value).encode()
else:
return value.encode(encoding or "latin-1")

return cast(str, value).encode(encoding or "latin-1")

class Headers(MutableMapping[str, str]):

class Headers(MutableMapping[str, Optional[str]]):
"""
HTTP headers, as a case-insensitive multi-dict.
"""

def __init__(
self,
headers: Optional[HeaderTypes] = None,
encoding: Optional[str] = None,
) -> None:
def __init__(self, headers: Optional[HeaderTypes] = None, encoding: Optional[str] = None):
if not headers:
self._list = [] # type: List[Tuple[bytes, bytes, bytes]]
elif isinstance(headers, Headers):
Expand All @@ -121,15 +121,16 @@ def __init__(
normalize_header_value(v, encoding),
)
for k, v in headers.items()
if v is not None
]
elif isinstance(headers, list):
# list of "Name: Value" pairs
if isinstance(headers[0], (str, bytes)):
sep = ":" if isinstance(headers[0], str) else b":"
h = []
for line in headers:
k, v = line.split(sep, maxsplit=1) # pyright: ignore
h.append((k, v.strip()))
# list of (Name, Value) pairs
elif isinstance(headers[0], tuple):
h = headers
self._list = [
Expand All @@ -154,7 +155,7 @@ def encoding(self) -> str:
for key, value in self.raw:
try:
key.decode(encoding)
value.decode(encoding)
value.decode(encoding) if value is not None else value
except UnicodeDecodeError:
break
else:
Expand All @@ -173,7 +174,7 @@ def encoding(self, value: str) -> None:
self._encoding = value

@property
def raw(self) -> List[Tuple[bytes, bytes]]:
def raw(self) -> List[Tuple[bytes, Optional[bytes]]]:
"""
Returns a list of the raw header items, as byte pairs.
"""
Expand All @@ -182,40 +183,41 @@ def raw(self) -> List[Tuple[bytes, bytes]]:
def keys(self) -> KeysView[str]:
return {key.decode(self.encoding): None for _, key, _ in self._list}.keys()

def values(self) -> ValuesView[str]:
def values(self) -> ValuesView[Optional[str]]:
values_dict: Dict[str, str] = {}
for _, key, value in self._list:
str_key = key.decode(self.encoding)
str_value = value.decode(self.encoding)
str_value = value.decode(self.encoding) if value is not None else value
if str_key in values_dict:
values_dict[str_key] += f", {str_value}"
else:
values_dict[str_key] = str_value
return values_dict.values()

def items(self) -> ItemsView[str, str]:
def items(self) -> ItemsView[str, Optional[str]]:
"""
Return `(key, value)` items of headers. Concatenate headers
into a single comma separated value when a key occurs multiple times.
"""
values_dict: Dict[str, str] = {}
for _, key, value in self._list:
str_key = key.decode(self.encoding)
str_value = value.decode(self.encoding)
str_value = value.decode(self.encoding) if value is not None else value
if str_key in values_dict:
values_dict[str_key] += f", {str_value}"
else:
values_dict[str_key] = str_value
return values_dict.items()

def multi_items(self) -> List[Tuple[str, str]]:
def multi_items(self) -> List[Tuple[str, Optional[str]]]:
"""
Return a list of `(key, value)` pairs of headers. Allow multiple
occurrences of the same key without concatenating into a single
comma separated value.
"""
return [
(key.decode(self.encoding), value.decode(self.encoding)) for key, _, value in self._list
(key.decode(self.encoding), value.decode(self.encoding) if value is not None else value)
for key, _, value in self._list
]

def get(self, key: str, default: Any = None) -> Any:
Expand All @@ -228,7 +230,7 @@ def get(self, key: str, default: Any = None) -> Any:
except KeyError:
return default

def get_list(self, key: str, split_commas: bool = False) -> List[str]:
def get_list(self, key: str, split_commas: bool = False) -> List[Optional[str]]:
"""
Return a list of all header values for a given key.
If `split_commas=True` is passed, then any comma separated header
Expand All @@ -237,7 +239,7 @@ def get_list(self, key: str, split_commas: bool = False) -> List[str]:
get_header_key = key.lower().encode(self.encoding)

values = [
item_value.decode(self.encoding)
item_value.decode(self.encoding) if item_value is not None else item_value
for _, item_key, item_value in self._list
if item_key.lower() == get_header_key
]
Expand All @@ -260,7 +262,7 @@ def update(self, headers: Optional[HeaderTypes] = None) -> None: # type: ignore
def copy(self) -> "Headers":
return Headers(self, encoding=self.encoding)

def __getitem__(self, key: str) -> str:
def __getitem__(self, key: str) -> Optional[str]:
"""
Return a single header value.
If there are multiple headers with the same key, then we concatenate
Expand All @@ -269,23 +271,26 @@ def __getitem__(self, key: str) -> str:
normalized_key = key.lower().encode(self.encoding)

items = [
header_value.decode(self.encoding)
header_value.decode(self.encoding) if header_value is not None else header_value
for _, header_key, header_value in self._list
if header_key == normalized_key
]

if items == [None]:
return None

if items:
return ", ".join(items)

raise KeyError(key)

def __setitem__(self, key: str, value: str) -> None:
def __setitem__(self, key: str, value: Optional[str]) -> None:
"""
Set the header `key` to `value`, removing any duplicate entries.
Retains insertion order.
"""
set_key = key.encode(self._encoding or "utf-8")
set_value = value.encode(self._encoding or "utf-8")
set_value = value.encode(self._encoding or "utf-8") if value is not None else value
lookup_key = set_key.lower()

found_indexes = [
Expand Down
9 changes: 7 additions & 2 deletions curl_cffi/requests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,11 +415,16 @@ def set_curl_options(
# See: https://stackoverflow.com/a/32911474/1061155
header_lines = []
for k, v in h.multi_items():
header_lines.append(f"{k}: {v}" if v else f"{k};")
if v is None:
header_lines.append(f"{k}:") # Explictly disable this header
elif v == "":
header_lines.append(f"{k};") # Add an empty valued header
else:
header_lines.append(f"{k}: {v}")

# Add content-type if missing
if json is not None:
update_header_line(header_lines, "Content-Type", "application/json", replace=True)
update_header_line(header_lines, "Content-Type", "application/json")
if isinstance(data, dict) and method != "POST":
update_header_line(header_lines, "Content-Type", "application/x-www-form-urlencoded")
if isinstance(data, (str, bytes)):
Expand Down
7 changes: 7 additions & 0 deletions tests/unittest/test_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,10 @@ def test_replace_header():
assert header_lines == ["Content-Type: application/json"]
update_header_line(header_lines, "Host", "example.com", replace=True)
assert header_lines == ["Content-Type: application/json", "Host: example.com"]


def test_none_headers():
"""Allow using None to explictly remove headers"""
headers = Headers({"Content-Type": None})
assert headers["content-type"] is None

13 changes: 13 additions & 0 deletions tests/unittest/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,19 @@ def test_empty_header_included(server):
assert headers["Xxx"][0] == ""


def test_explict_remove_header(server):
r = requests.get(str(server.url.copy_with(path="/echo_headers")), json={"foo": "bar"})
headers = r.json()
assert headers["Content-type"][0] == "application/json"
r = requests.get(
str(server.url.copy_with(path="/echo_headers")),
json={"foo": "bar"},
headers={"Content-Type": None},
)
headers = r.json()
assert "Content-type" not in headers


def test_expect_header_omitted(server):
r = requests.get(str(server.url.copy_with(path="/echo_headers")), headers={"expect": "100"})
headers = r.json()
Expand Down

0 comments on commit 9e8e90d

Please sign in to comment.