mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-21 11:55:11 +01:00
Implement zerocopy writes (#990)
This commit is contained in:
parent
4bea46b201
commit
ba05d38602
@ -11,7 +11,7 @@ cdef class APIFrameHelper:
|
|||||||
cdef object _loop
|
cdef object _loop
|
||||||
cdef APIConnection _connection
|
cdef APIConnection _connection
|
||||||
cdef object _transport
|
cdef object _transport
|
||||||
cdef public object _writer
|
cdef public object _writelines
|
||||||
cdef public object ready_future
|
cdef public object ready_future
|
||||||
cdef bytes _buffer
|
cdef bytes _buffer
|
||||||
cdef unsigned int _buffer_len
|
cdef unsigned int _buffer_len
|
||||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Iterable
|
||||||
import logging
|
import logging
|
||||||
from typing import TYPE_CHECKING, Callable, cast
|
from typing import TYPE_CHECKING, Callable, cast
|
||||||
|
|
||||||
@ -31,7 +32,7 @@ class APIFrameHelper:
|
|||||||
"_loop",
|
"_loop",
|
||||||
"_connection",
|
"_connection",
|
||||||
"_transport",
|
"_transport",
|
||||||
"_writer",
|
"_writelines",
|
||||||
"ready_future",
|
"ready_future",
|
||||||
"_buffer",
|
"_buffer",
|
||||||
"_buffer_len",
|
"_buffer_len",
|
||||||
@ -51,7 +52,9 @@ class APIFrameHelper:
|
|||||||
self._loop = loop
|
self._loop = loop
|
||||||
self._connection = connection
|
self._connection = connection
|
||||||
self._transport: asyncio.Transport | None = None
|
self._transport: asyncio.Transport | None = None
|
||||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
self._writelines: (
|
||||||
|
None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None])
|
||||||
|
) = None
|
||||||
self.ready_future = self._loop.create_future()
|
self.ready_future = self._loop.create_future()
|
||||||
self._buffer: bytes | None = None
|
self._buffer: bytes | None = None
|
||||||
self._buffer_len = 0
|
self._buffer_len = 0
|
||||||
@ -146,7 +149,7 @@ class APIFrameHelper:
|
|||||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||||
"""Handle a new connection."""
|
"""Handle a new connection."""
|
||||||
self._transport = cast(asyncio.Transport, transport)
|
self._transport = cast(asyncio.Transport, transport)
|
||||||
self._writer = self._transport.write
|
self._writelines = self._transport.writelines
|
||||||
|
|
||||||
def _handle_error_and_close(self, exc: Exception) -> None:
|
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||||
self._handle_error(exc)
|
self._handle_error(exc)
|
||||||
@ -172,7 +175,7 @@ class APIFrameHelper:
|
|||||||
if self._transport:
|
if self._transport:
|
||||||
self._transport.close()
|
self._transport.close()
|
||||||
self._transport = None
|
self._transport = None
|
||||||
self._writer = None
|
self._writelines = None
|
||||||
|
|
||||||
def pause_writing(self) -> None:
|
def pause_writing(self) -> None:
|
||||||
"""Stub."""
|
"""Stub."""
|
||||||
@ -180,12 +183,14 @@ class APIFrameHelper:
|
|||||||
def resume_writing(self) -> None:
|
def resume_writing(self) -> None:
|
||||||
"""Stub."""
|
"""Stub."""
|
||||||
|
|
||||||
def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None:
|
def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None:
|
||||||
"""Write bytes to the socket."""
|
"""Write bytes to the socket."""
|
||||||
if debug_enabled:
|
if debug_enabled:
|
||||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
|
_LOGGER.debug(
|
||||||
|
"%s: Sending frame: [%s]", self._log_name, b"".join(data).hex()
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
assert self._writer is not None, "Writer is not set"
|
assert self._writelines is not None, "Writer is not set"
|
||||||
|
|
||||||
self._writer(data)
|
self._writelines(data)
|
||||||
|
@ -218,7 +218,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
frame_len = len(handshake_frame) + 1
|
frame_len = len(handshake_frame) + 1
|
||||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||||
self._write_bytes(
|
self._write_bytes(
|
||||||
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
|
(NOISE_HELLO, header, b"\x00", handshake_frame),
|
||||||
_LOGGER.isEnabledFor(logging.DEBUG),
|
_LOGGER.isEnabledFor(logging.DEBUG),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -346,7 +346,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
out.append(header)
|
out.append(header)
|
||||||
out.append(frame)
|
out.append(frame)
|
||||||
|
|
||||||
self._write_bytes(b"".join(out), debug_enabled)
|
self._write_bytes(out, debug_enabled)
|
||||||
|
|
||||||
def _handle_frame(self, frame: bytes) -> None:
|
def _handle_frame(self, frame: bytes) -> None:
|
||||||
"""Handle an incoming frame."""
|
"""Handle an incoming frame."""
|
||||||
|
@ -57,9 +57,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
|||||||
out.append(b"\0")
|
out.append(b"\0")
|
||||||
out.append(varuint_to_bytes(len(data)))
|
out.append(varuint_to_bytes(len(data)))
|
||||||
out.append(varuint_to_bytes(type_))
|
out.append(varuint_to_bytes(type_))
|
||||||
out.append(data)
|
if data:
|
||||||
|
out.append(data)
|
||||||
|
|
||||||
self._write_bytes(b"".join(out), debug_enabled)
|
self._write_bytes(out, debug_enabled)
|
||||||
|
|
||||||
def data_received(self, data: bytes | bytearray | memoryview) -> None:
|
def data_received(self, data: bytes | bytearray | memoryview) -> None:
|
||||||
self._add_to_buffer(data)
|
self._add_to_buffer(data)
|
||||||
|
@ -65,15 +65,19 @@ class Estr(str):
|
|||||||
"""A subclassed string."""
|
"""A subclassed string."""
|
||||||
|
|
||||||
|
|
||||||
def generate_plaintext_packet(msg: message.Message) -> bytes:
|
def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]:
|
||||||
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
|
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
|
||||||
bytes_ = msg.SerializeToString()
|
bytes_ = msg.SerializeToString()
|
||||||
return (
|
return [
|
||||||
b"\0"
|
b"\0",
|
||||||
+ _cached_varuint_to_bytes(len(bytes_))
|
_cached_varuint_to_bytes(len(bytes_)),
|
||||||
+ _cached_varuint_to_bytes(type_)
|
_cached_varuint_to_bytes(type_),
|
||||||
+ bytes_
|
bytes_,
|
||||||
)
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_plaintext_packet(msg: message.Message) -> bytes:
|
||||||
|
return b"".join(generate_split_plaintext_packet(msg))
|
||||||
|
|
||||||
|
|
||||||
def as_utc(dattim: datetime) -> datetime:
|
def as_utc(dattim: datetime) -> datetime:
|
||||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
from collections.abc import Iterable
|
||||||
import sys
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
@ -132,7 +133,7 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
|||||||
"""Swallow args."""
|
"""Swallow args."""
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
transport = MagicMock()
|
transport = MagicMock()
|
||||||
transport.write = writer or MagicMock()
|
transport.writelines = writer or MagicMock()
|
||||||
self.__transport = transport
|
self.__transport = transport
|
||||||
self.connection_made(transport)
|
self.connection_made(transport)
|
||||||
|
|
||||||
@ -147,7 +148,7 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
|||||||
frame_len = len(frame)
|
frame_len = len(frame)
|
||||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||||
try:
|
try:
|
||||||
self._writer(header + frame)
|
self._writelines([header, frame])
|
||||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||||
raise SocketClosedAPIError(
|
raise SocketClosedAPIError(
|
||||||
f"{self._log_name}: Error while writing data: {err}"
|
f"{self._log_name}: Error while writing data: {err}"
|
||||||
@ -437,8 +438,8 @@ async def test_noise_frame_helper_handshake_failure():
|
|||||||
psk_bytes = base64.b64decode(noise_psk)
|
psk_bytes = base64.b64decode(noise_psk)
|
||||||
writes = []
|
writes = []
|
||||||
|
|
||||||
def _writer(data: bytes):
|
def _writelines(data: Iterable[bytes]):
|
||||||
writes.append(data)
|
writes.append(b"".join(data))
|
||||||
|
|
||||||
connection, _ = _make_mock_connection()
|
connection, _ = _make_mock_connection()
|
||||||
|
|
||||||
@ -448,7 +449,7 @@ async def test_noise_frame_helper_handshake_failure():
|
|||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
writer=_writer,
|
writer=_writelines,
|
||||||
)
|
)
|
||||||
|
|
||||||
proto = _mock_responder_proto(psk_bytes)
|
proto = _mock_responder_proto(psk_bytes)
|
||||||
@ -486,8 +487,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||||||
psk_bytes = base64.b64decode(noise_psk)
|
psk_bytes = base64.b64decode(noise_psk)
|
||||||
writes = []
|
writes = []
|
||||||
|
|
||||||
def _writer(data: bytes):
|
def _writelines(data: Iterable[bytes]):
|
||||||
writes.append(data)
|
writes.append(b"".join(data))
|
||||||
|
|
||||||
connection, packets = _make_mock_connection()
|
connection, packets = _make_mock_connection()
|
||||||
|
|
||||||
@ -497,7 +498,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
|||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
writer=_writer,
|
writer=_writelines,
|
||||||
)
|
)
|
||||||
|
|
||||||
proto = _mock_responder_proto(psk_bytes)
|
proto = _mock_responder_proto(psk_bytes)
|
||||||
@ -548,8 +549,8 @@ async def test_noise_frame_helper_bad_encryption(
|
|||||||
psk_bytes = base64.b64decode(noise_psk)
|
psk_bytes = base64.b64decode(noise_psk)
|
||||||
writes = []
|
writes = []
|
||||||
|
|
||||||
def _writer(data: bytes):
|
def _writelines(data: Iterable[bytes]):
|
||||||
writes.append(data)
|
writes.append(b"".join(data))
|
||||||
|
|
||||||
connection, packets = _make_mock_connection()
|
connection, packets = _make_mock_connection()
|
||||||
|
|
||||||
@ -559,7 +560,7 @@ async def test_noise_frame_helper_bad_encryption(
|
|||||||
expected_name="servicetest",
|
expected_name="servicetest",
|
||||||
client_info="my client",
|
client_info="my client",
|
||||||
log_name="test",
|
log_name="test",
|
||||||
writer=_writer,
|
writer=_writelines,
|
||||||
)
|
)
|
||||||
|
|
||||||
proto = _mock_responder_proto(psk_bytes)
|
proto = _mock_responder_proto(psk_bytes)
|
||||||
|
@ -126,6 +126,7 @@ from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
|||||||
from .common import (
|
from .common import (
|
||||||
Estr,
|
Estr,
|
||||||
generate_plaintext_packet,
|
generate_plaintext_packet,
|
||||||
|
generate_split_plaintext_packet,
|
||||||
get_mock_zeroconf,
|
get_mock_zeroconf,
|
||||||
mock_data_received,
|
mock_data_received,
|
||||||
)
|
)
|
||||||
@ -1439,7 +1440,12 @@ async def test_bluetooth_gatt_write_without_response(
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await write_task
|
await write_task
|
||||||
assert transport.mock_calls[0][1][0] == b'\x00\x0cK\x08\xd2\t\x10\xd2\t"\x041234'
|
assert transport.mock_calls[0][1][0] == [
|
||||||
|
b"\x00",
|
||||||
|
b"\x0c",
|
||||||
|
b"K",
|
||||||
|
b'\x08\xd2\t\x10\xd2\t"\x041234',
|
||||||
|
]
|
||||||
|
|
||||||
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
|
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
|
||||||
await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0)
|
await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0)
|
||||||
@ -1484,7 +1490,12 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await write_task
|
await write_task
|
||||||
assert transport.mock_calls[0][1][0] == b"\x00\x0cM\x08\xd2\t\x10\xd2\t\x1a\x041234"
|
assert transport.mock_calls[0][1][0] == [
|
||||||
|
b"\x00",
|
||||||
|
b"\x0c",
|
||||||
|
b"M",
|
||||||
|
b"\x08\xd2\t\x10\xd2\t\x1a\x041234",
|
||||||
|
]
|
||||||
|
|
||||||
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
|
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
|
||||||
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
|
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
|
||||||
@ -2042,8 +2053,8 @@ async def test_bluetooth_device_connect(
|
|||||||
|
|
||||||
cancel = await connect_task
|
cancel = await connect_task
|
||||||
assert states == [(True, 23, 0)]
|
assert states == [(True, 23, 0)]
|
||||||
transport.write.assert_called_once_with(
|
transport.writelines.assert_called_once_with(
|
||||||
generate_plaintext_packet(
|
generate_split_plaintext_packet(
|
||||||
BluetoothDeviceRequest(
|
BluetoothDeviceRequest(
|
||||||
address=1234,
|
address=1234,
|
||||||
request_type=method,
|
request_type=method,
|
||||||
@ -2133,13 +2144,13 @@ async def test_bluetooth_device_connect_times_out_disconnect_ok(
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
# The connect request should be written
|
# The connect request should be written
|
||||||
assert len(transport.write.mock_calls) == 1
|
assert len(transport.writelines.mock_calls) == 1
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
# Now that we timed out, the disconnect
|
# Now that we timed out, the disconnect
|
||||||
# request should be written
|
# request should be written
|
||||||
assert len(transport.write.mock_calls) == 2
|
assert len(transport.writelines.mock_calls) == 2
|
||||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||||
address=1234, connected=False, mtu=23, error=8
|
address=1234, connected=False, mtu=23, error=8
|
||||||
)
|
)
|
||||||
@ -2177,7 +2188,7 @@ async def test_bluetooth_device_connect_cancelled(
|
|||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
# The connect request should be written
|
# The connect request should be written
|
||||||
assert len(transport.write.mock_calls) == 1
|
assert len(transport.writelines.mock_calls) == 1
|
||||||
connect_task.cancel()
|
connect_task.cancel()
|
||||||
with pytest.raises(asyncio.CancelledError):
|
with pytest.raises(asyncio.CancelledError):
|
||||||
await connect_task
|
await connect_task
|
||||||
|
@ -115,7 +115,7 @@ async def test_timeout_sending_message(
|
|||||||
with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0):
|
with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0):
|
||||||
await conn.disconnect()
|
await conn.disconnect()
|
||||||
|
|
||||||
transport.write.assert_called_with(b"\x00\x00\x05")
|
transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"])
|
||||||
|
|
||||||
assert "disconnect request failed" in caplog.text
|
assert "disconnect request failed" in caplog.text
|
||||||
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
|
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
|
||||||
@ -152,7 +152,7 @@ async def test_disconnect_when_not_fully_connected(
|
|||||||
):
|
):
|
||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
transport.write.assert_called_with(b"\x00\x00\x05")
|
transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"])
|
||||||
|
|
||||||
assert "disconnect request failed" in caplog.text
|
assert "disconnect request failed" in caplog.text
|
||||||
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
|
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
|
||||||
@ -506,7 +506,7 @@ async def test_plaintext_connection_fails_handshake(
|
|||||||
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
|
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
|
||||||
protocol: APIPlaintextFrameHelperHandshakeException = create_func()
|
protocol: APIPlaintextFrameHelperHandshakeException = create_func()
|
||||||
protocol._transport = cast(asyncio.Transport, transport)
|
protocol._transport = cast(asyncio.Transport, transport)
|
||||||
protocol._writer = transport.write
|
protocol._writelines = transport.writelines
|
||||||
protocol.ready_future.set_exception(exception)
|
protocol.ready_future.set_exception(exception)
|
||||||
connected.set()
|
connected.set()
|
||||||
return transport, protocol
|
return transport, protocol
|
||||||
@ -549,7 +549,9 @@ async def test_plaintext_connection_fails_handshake(
|
|||||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||||
await connected.wait()
|
await connected.wait()
|
||||||
|
|
||||||
with (pytest.raises(raised_exception),):
|
with (
|
||||||
|
pytest.raises(raised_exception),
|
||||||
|
):
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
@ -646,7 +648,7 @@ async def test_force_disconnect_fails(
|
|||||||
await connect_task
|
await connect_task
|
||||||
assert conn.is_connected
|
assert conn.is_connected
|
||||||
|
|
||||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
with patch.object(protocol, "_writelines", side_effect=OSError):
|
||||||
conn.force_disconnect()
|
conn.force_disconnect()
|
||||||
assert "Failed to send (forced) disconnect request" in caplog.text
|
assert "Failed to send (forced) disconnect request" in caplog.text
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
@ -822,7 +824,7 @@ async def test_disconnect_fails_to_send_response(
|
|||||||
await connect_task
|
await connect_task
|
||||||
assert client._connection.is_connected
|
assert client._connection.is_connected
|
||||||
|
|
||||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
with patch.object(protocol, "_writelines", side_effect=OSError):
|
||||||
disconnect_request = DisconnectRequest()
|
disconnect_request = DisconnectRequest()
|
||||||
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
|
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
|
||||||
|
|
||||||
@ -893,7 +895,7 @@ async def test_ping_disconnects_after_no_responses(
|
|||||||
|
|
||||||
await connect_task
|
await connect_task
|
||||||
|
|
||||||
ping_request_bytes = b"\x00\x00\x07"
|
ping_request_bytes = [b"\x00", b"\x00", b"\x07"]
|
||||||
|
|
||||||
assert conn.is_connected
|
assert conn.is_connected
|
||||||
transport.reset_mock()
|
transport.reset_mock()
|
||||||
@ -904,9 +906,9 @@ async def test_ping_disconnects_after_no_responses(
|
|||||||
async_fire_time_changed(
|
async_fire_time_changed(
|
||||||
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
|
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
|
||||||
)
|
)
|
||||||
assert transport.write.call_count == count
|
assert transport.writelines.call_count == count
|
||||||
expected_calls.append(call(ping_request_bytes))
|
expected_calls.append(call(ping_request_bytes))
|
||||||
assert transport.write.mock_calls == expected_calls
|
assert transport.writelines.mock_calls == expected_calls
|
||||||
|
|
||||||
assert conn.is_connected is True
|
assert conn.is_connected is True
|
||||||
|
|
||||||
@ -915,7 +917,7 @@ async def test_ping_disconnects_after_no_responses(
|
|||||||
start_time
|
start_time
|
||||||
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
|
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
|
||||||
)
|
)
|
||||||
assert transport.write.call_count == max_pings_to_disconnect_after + 1
|
assert transport.writelines.call_count == max_pings_to_disconnect_after + 1
|
||||||
|
|
||||||
assert conn.is_connected is False
|
assert conn.is_connected is False
|
||||||
|
|
||||||
@ -932,7 +934,7 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
|
|||||||
send_plaintext_connect_response(protocol, False)
|
send_plaintext_connect_response(protocol, False)
|
||||||
|
|
||||||
await connect_task
|
await connect_task
|
||||||
ping_request_bytes = b"\x00\x00\x07"
|
ping_request_bytes = [b"\x00", b"\x00", b"\x07"]
|
||||||
|
|
||||||
assert conn.is_connected
|
assert conn.is_connected
|
||||||
transport.reset_mock()
|
transport.reset_mock()
|
||||||
@ -945,8 +947,8 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
|
|||||||
send_ping_response(protocol)
|
send_ping_response(protocol)
|
||||||
|
|
||||||
# We should only send 1 ping request if we are getting responses
|
# We should only send 1 ping request if we are getting responses
|
||||||
assert transport.write.call_count == 1
|
assert transport.writelines.call_count == 1
|
||||||
assert transport.write.mock_calls == [call(ping_request_bytes)]
|
assert transport.writelines.mock_calls == [call(ping_request_bytes)]
|
||||||
|
|
||||||
# We should disconnect if we are getting ping responses
|
# We should disconnect if we are getting ping responses
|
||||||
assert conn.is_connected is True
|
assert conn.is_connected is True
|
||||||
@ -976,9 +978,9 @@ async def test_respond_to_ping_request(
|
|||||||
transport.reset_mock()
|
transport.reset_mock()
|
||||||
send_ping_request(protocol)
|
send_ping_request(protocol)
|
||||||
# We should respond to ping requests
|
# We should respond to ping requests
|
||||||
ping_response_bytes = b"\x00\x00\x08"
|
ping_response_bytes = [b"\x00", b"\x00", b"\x08"]
|
||||||
assert transport.write.call_count == 1
|
assert transport.writelines.call_count == 1
|
||||||
assert transport.write.mock_calls == [call(ping_response_bytes)]
|
assert transport.writelines.mock_calls == [call(ping_response_bytes)]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
Reference in New Issue
Block a user