Skip to content

Commit

Permalink
Don't try to do name resolution when connecting to a bare IP address
Browse files Browse the repository at this point in the history
  • Loading branch information
agronholm committed Oct 5, 2019
1 parent 3d363c6 commit 996ce98
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
26 changes: 18 additions & 8 deletions anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import typing
from contextlib import contextmanager
from importlib import import_module
from ipaddress import ip_address, IPv6Address
from ssl import SSLContext
from typing import TypeVar, Callable, Union, Optional, Awaitable, Coroutine, Any, Dict, List

Expand Down Expand Up @@ -330,21 +331,30 @@ async def try_connect(af: int, addr: str, delay: float):
raise

assert stream is None
stream = _networking.SocketStream(sock, ssl_context, str(address), tls_standard_compatible)
stream = _networking.SocketStream(sock, ssl_context, target_host, tls_standard_compatible)
await tg.cancel_scope.cancel()

asynclib = _get_asynclib()
interface, family = None, 0 # type: Optional[str], int
if bind_host:
interface, family, _v6only = await _networking.get_bind_address(bind_host)

# getaddrinfo() will raise an exception if name resolution fails
addrlist = await run_in_thread(socket.getaddrinfo, str(address), port, family,
socket.SOCK_STREAM)

# Sort the list so that IPv4 addresses are tried last
addresses = sorted(((item[0], item[-1][0]) for item in addrlist),
key=lambda item: item[0] == socket.AF_INET)
target_host = str(address)
try:
addr_obj = ip_address(address)
except ValueError:
# getaddrinfo() will raise an exception if name resolution fails
resolved = await run_in_thread(socket.getaddrinfo, target_host, port, family,
socket.SOCK_STREAM)

# Sort the list so that IPv4 addresses are tried last
addresses = sorted(((item[0], item[-1][0]) for item in resolved),
key=lambda item: item[0] == socket.AF_INET)
else:
if isinstance(addr_obj, IPv6Address):
addresses = [(socket.AF_INET6, addr_obj.compressed)]
else:
addresses = [(socket.AF_INET, addr_obj.compressed)]

oserrors = [] # type: List[OSError]
async with create_task_group() as tg:
Expand Down
25 changes: 15 additions & 10 deletions tests/test_networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ def localhost():
return '::1' if socket.has_ipv6 else '127.0.0.1'


@pytest.fixture
def fake_localhost_dns(monkeypatch):
# Make it return IPv4 addresses first so we can test the IPv6 preference
fake_results = [(socket.AF_INET, socket.SOCK_STREAM, '', ('127.0.0.1', 0)),
(socket.AF_INET6, socket.SOCK_STREAM, '', ('::1', 0))]
monkeypatch.setattr('socket.getaddrinfo', lambda *args: fake_results)


class TestTCPStream:
@pytest.mark.anyio
async def test_receive_some(self, localhost):
Expand Down Expand Up @@ -270,7 +278,7 @@ async def receive_data():
('::1', b'::1')
])
@pytest.mark.anyio
async def test_happy_eyeballs(self, interface, expected_addr, monkeypatch):
async def test_happy_eyeballs(self, interface, expected_addr, fake_localhost_dns):
async def handle_client(stream):
addr, port, *rest = stream._socket._raw_socket.getpeername()
await stream.send_all(addr.encode() + b'\n')
Expand All @@ -279,11 +287,6 @@ async def server():
async for stream in stream_server.accept_connections():
await tg.spawn(handle_client, stream)

# Fake getaddrinfo() to return IPv4 addresses first so we can test the IPv6 preference
fake_results = [(socket.AF_INET, socket.SOCK_STREAM, '', ('127.0.0.1', 0)),
(socket.AF_INET6, socket.SOCK_STREAM, '', ('::1', 0))]
monkeypatch.setattr('socket.getaddrinfo', lambda *args: fake_results)

async with await create_tcp_server(interface=interface) as stream_server:
async with create_task_group() as tg:
await tg.spawn(server)
Expand All @@ -292,13 +295,15 @@ async def server():

await stream_server.close()

@pytest.mark.skipif(not socket.has_ipv6, reason='IPv6 is not available')
@pytest.mark.parametrize('target, exception_class', [
('localhost', ExceptionGroup),
pytest.param(
'localhost', ExceptionGroup,
marks=[pytest.mark.skipif(not socket.has_ipv6, reason='IPv6 is not available')]
),
('127.0.0.1', ConnectionRefusedError)
])
], ids=['multi', 'single'])
@pytest.mark.anyio
async def test_connrefused(self, target, exception_class):
async def test_connrefused(self, target, exception_class, fake_localhost_dns):
dummy_socket = socket.socket(socket.AF_INET6)
dummy_socket.bind(('::', 0))
free_port = dummy_socket.getsockname()[1]
Expand Down

0 comments on commit 996ce98

Please sign in to comment.