This commit is contained in:
J. Nick Koston 2023-12-12 10:16:03 -10:00
parent 44aee612a4
commit e762ec3e1e
No known key found for this signature in database
3 changed files with 25 additions and 19 deletions

View File

@ -188,25 +188,31 @@ async def async_resolve_host(
zc_error: Exception | None = None
for host in hosts:
host_addrs: list[AddrInfo] = []
host_is_name = host_is_name_part(host) or address_is_local(host)
if host_is_name:
name = host.partition(".")[0]
try:
host_addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_manager=zeroconf_manager
)
)
except ResolveAPIError as err:
zc_error = err
if not host_is_name:
try:
addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
host_addrs.extend(_async_ip_address_to_addrs(ip_address(host), port))
except ValueError:
# Not an IP address
continue
name = host.partition(".")[0]
try:
addrs.extend(
await _async_resolve_host_zeroconf(
name, port, zeroconf_manager=zeroconf_manager
)
)
except ResolveAPIError as err:
zc_error = err
addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
if not host_addrs:
host_addrs.extend(await _async_resolve_host_getaddrinfo(host, port))
addrs.extend(host_addrs)
if not addrs:
if zc_error:

View File

@ -71,7 +71,7 @@ def patchable_api_client() -> APIClient:
def get_mock_connection_params() -> ConnectionParams:
return ConnectionParams(
address="fake.address",
addresses=["fake.address"],
port=6052,
password=None,
client_info="Tests client",

View File

@ -99,7 +99,7 @@ async def test_resolve_host_zeroconf_fails_end_to_end(async_zeroconf: AsyncZeroc
"aioesphomeapi.host_resolver.AsyncServiceInfo.async_request",
side_effect=Exception("no buffers"),
), pytest.raises(ResolveAPIError, match="no buffers"):
await hr.async_resolve_host("asdf.local", 6052)
await hr.async_resolve_host(["asdf.local"], 6052)
@pytest.mark.asyncio
@ -140,7 +140,7 @@ async def test_resolve_host_getaddrinfo_oserror(event_loop):
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052)
ret = await hr.async_resolve_host(["example.local"], 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_not_called()
@ -153,7 +153,7 @@ async def test_resolve_host_mdns(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
resolve_zc.return_value = []
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("example.local", 6052)
ret = await hr.async_resolve_host(["example.local"], 6052)
resolve_zc.assert_called_once_with("example", 6052, zeroconf_manager=None)
resolve_addr.assert_called_once_with("example.local", 6052)
@ -166,7 +166,7 @@ async def test_resolve_host_mdns_empty(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
resolve_addr.return_value = addr_infos
with pytest.raises(ResolveAPIError):
await hr.async_resolve_host("example.local", 6052)
await hr.async_resolve_host(["example.local"], 6052)
@pytest.mark.asyncio
@ -174,7 +174,7 @@ async def test_resolve_host_mdns_no_results(resolve_addr, addr_infos):
@patch("aioesphomeapi.host_resolver._async_resolve_host_getaddrinfo")
async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("example.com", 6052)
ret = await hr.async_resolve_host(["example.local"], 6052)
resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052)
@ -187,7 +187,7 @@ async def test_resolve_host_addrinfo(resolve_addr, resolve_zc, addr_infos):
async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos):
resolve_addr.return_value = []
with pytest.raises(APIConnectionError):
await hr.async_resolve_host("example.com", 6052)
await hr.async_resolve_host(["example.local"], 6052)
resolve_zc.assert_not_called()
resolve_addr.assert_called_once_with("example.com", 6052)
@ -199,7 +199,7 @@ async def test_resolve_host_addrinfo_empty(resolve_addr, resolve_zc, addr_infos)
async def test_resolve_host_with_address(resolve_addr, resolve_zc):
resolve_zc.return_value = []
resolve_addr.return_value = addr_infos
ret = await hr.async_resolve_host("127.0.0.1", 6052)
ret = await hr.async_resolve_host(["127.0.0.1"], 6052)
resolve_zc.assert_not_called()
resolve_addr.assert_not_called()