Implement zerocopy writes (#990)

This commit is contained in:
J. Nick Koston 2024-11-01 11:46:10 -05:00 committed by GitHub
parent 4bea46b201
commit ba05d38602
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 78 additions and 54 deletions

View File

@ -11,7 +11,7 @@ cdef class APIFrameHelper:
cdef object _loop cdef object _loop
cdef APIConnection _connection cdef APIConnection _connection
cdef object _transport cdef object _transport
cdef public object _writer cdef public object _writelines
cdef public object ready_future cdef public object ready_future
cdef bytes _buffer cdef bytes _buffer
cdef unsigned int _buffer_len cdef unsigned int _buffer_len

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from abc import abstractmethod from abc import abstractmethod
import asyncio import asyncio
from collections.abc import Iterable
import logging import logging
from typing import TYPE_CHECKING, Callable, cast from typing import TYPE_CHECKING, Callable, cast
@ -31,7 +32,7 @@ class APIFrameHelper:
"_loop", "_loop",
"_connection", "_connection",
"_transport", "_transport",
"_writer", "_writelines",
"ready_future", "ready_future",
"_buffer", "_buffer",
"_buffer_len", "_buffer_len",
@ -51,7 +52,9 @@ class APIFrameHelper:
self._loop = loop self._loop = loop
self._connection = connection self._connection = connection
self._transport: asyncio.Transport | None = None self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None self._writelines: (
None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None])
) = None
self.ready_future = self._loop.create_future() self.ready_future = self._loop.create_future()
self._buffer: bytes | None = None self._buffer: bytes | None = None
self._buffer_len = 0 self._buffer_len = 0
@ -146,7 +149,7 @@ class APIFrameHelper:
def connection_made(self, transport: asyncio.BaseTransport) -> None: def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection.""" """Handle a new connection."""
self._transport = cast(asyncio.Transport, transport) self._transport = cast(asyncio.Transport, transport)
self._writer = self._transport.write self._writelines = self._transport.writelines
def _handle_error_and_close(self, exc: Exception) -> None: def _handle_error_and_close(self, exc: Exception) -> None:
self._handle_error(exc) self._handle_error(exc)
@ -172,7 +175,7 @@ class APIFrameHelper:
if self._transport: if self._transport:
self._transport.close() self._transport.close()
self._transport = None self._transport = None
self._writer = None self._writelines = None
def pause_writing(self) -> None: def pause_writing(self) -> None:
"""Stub.""" """Stub."""
@ -180,12 +183,14 @@ class APIFrameHelper:
def resume_writing(self) -> None: def resume_writing(self) -> None:
"""Stub.""" """Stub."""
def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None: def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None:
"""Write bytes to the socket.""" """Write bytes to the socket."""
if debug_enabled: if debug_enabled:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex()) _LOGGER.debug(
"%s: Sending frame: [%s]", self._log_name, b"".join(data).hex()
)
if TYPE_CHECKING: if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set" assert self._writelines is not None, "Writer is not set"
self._writer(data) self._writelines(data)

View File

@ -218,7 +218,7 @@ class APINoiseFrameHelper(APIFrameHelper):
frame_len = len(handshake_frame) + 1 frame_len = len(handshake_frame) + 1
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
self._write_bytes( self._write_bytes(
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)), (NOISE_HELLO, header, b"\x00", handshake_frame),
_LOGGER.isEnabledFor(logging.DEBUG), _LOGGER.isEnabledFor(logging.DEBUG),
) )
@ -346,7 +346,7 @@ class APINoiseFrameHelper(APIFrameHelper):
out.append(header) out.append(header)
out.append(frame) out.append(frame)
self._write_bytes(b"".join(out), debug_enabled) self._write_bytes(out, debug_enabled)
def _handle_frame(self, frame: bytes) -> None: def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame.""" """Handle an incoming frame."""

View File

@ -57,9 +57,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
out.append(b"\0") out.append(b"\0")
out.append(varuint_to_bytes(len(data))) out.append(varuint_to_bytes(len(data)))
out.append(varuint_to_bytes(type_)) out.append(varuint_to_bytes(type_))
out.append(data) if data:
out.append(data)
self._write_bytes(b"".join(out), debug_enabled) self._write_bytes(out, debug_enabled)
def data_received(self, data: bytes | bytearray | memoryview) -> None: def data_received(self, data: bytes | bytearray | memoryview) -> None:
self._add_to_buffer(data) self._add_to_buffer(data)

View File

@ -65,15 +65,19 @@ class Estr(str):
"""A subclassed string.""" """A subclassed string."""
def generate_plaintext_packet(msg: message.Message) -> bytes: def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]:
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__] type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
bytes_ = msg.SerializeToString() bytes_ = msg.SerializeToString()
return ( return [
b"\0" b"\0",
+ _cached_varuint_to_bytes(len(bytes_)) _cached_varuint_to_bytes(len(bytes_)),
+ _cached_varuint_to_bytes(type_) _cached_varuint_to_bytes(type_),
+ bytes_ bytes_,
) ]
def generate_plaintext_packet(msg: message.Message) -> bytes:
return b"".join(generate_split_plaintext_packet(msg))
def as_utc(dattim: datetime) -> datetime: def as_utc(dattim: datetime) -> datetime:

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
from collections.abc import Iterable
import sys import sys
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@ -132,7 +133,7 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
"""Swallow args.""" """Swallow args."""
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
transport = MagicMock() transport = MagicMock()
transport.write = writer or MagicMock() transport.writelines = writer or MagicMock()
self.__transport = transport self.__transport = transport
self.connection_made(transport) self.connection_made(transport)
@ -147,7 +148,7 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
frame_len = len(frame) frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF)) header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
try: try:
self._writer(header + frame) self._writelines([header, frame])
except (RuntimeError, ConnectionResetError, OSError) as err: except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketClosedAPIError( raise SocketClosedAPIError(
f"{self._log_name}: Error while writing data: {err}" f"{self._log_name}: Error while writing data: {err}"
@ -437,8 +438,8 @@ async def test_noise_frame_helper_handshake_failure():
psk_bytes = base64.b64decode(noise_psk) psk_bytes = base64.b64decode(noise_psk)
writes = [] writes = []
def _writer(data: bytes): def _writelines(data: Iterable[bytes]):
writes.append(data) writes.append(b"".join(data))
connection, _ = _make_mock_connection() connection, _ = _make_mock_connection()
@ -448,7 +449,7 @@ async def test_noise_frame_helper_handshake_failure():
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",
log_name="test", log_name="test",
writer=_writer, writer=_writelines,
) )
proto = _mock_responder_proto(psk_bytes) proto = _mock_responder_proto(psk_bytes)
@ -486,8 +487,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
psk_bytes = base64.b64decode(noise_psk) psk_bytes = base64.b64decode(noise_psk)
writes = [] writes = []
def _writer(data: bytes): def _writelines(data: Iterable[bytes]):
writes.append(data) writes.append(b"".join(data))
connection, packets = _make_mock_connection() connection, packets = _make_mock_connection()
@ -497,7 +498,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",
log_name="test", log_name="test",
writer=_writer, writer=_writelines,
) )
proto = _mock_responder_proto(psk_bytes) proto = _mock_responder_proto(psk_bytes)
@ -548,8 +549,8 @@ async def test_noise_frame_helper_bad_encryption(
psk_bytes = base64.b64decode(noise_psk) psk_bytes = base64.b64decode(noise_psk)
writes = [] writes = []
def _writer(data: bytes): def _writelines(data: Iterable[bytes]):
writes.append(data) writes.append(b"".join(data))
connection, packets = _make_mock_connection() connection, packets = _make_mock_connection()
@ -559,7 +560,7 @@ async def test_noise_frame_helper_bad_encryption(
expected_name="servicetest", expected_name="servicetest",
client_info="my client", client_info="my client",
log_name="test", log_name="test",
writer=_writer, writer=_writelines,
) )
proto = _mock_responder_proto(psk_bytes) proto = _mock_responder_proto(psk_bytes)

View File

@ -126,6 +126,7 @@ from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import ( from .common import (
Estr, Estr,
generate_plaintext_packet, generate_plaintext_packet,
generate_split_plaintext_packet,
get_mock_zeroconf, get_mock_zeroconf,
mock_data_received, mock_data_received,
) )
@ -1439,7 +1440,12 @@ async def test_bluetooth_gatt_write_without_response(
) )
await asyncio.sleep(0) await asyncio.sleep(0)
await write_task await write_task
assert transport.mock_calls[0][1][0] == b'\x00\x0cK\x08\xd2\t\x10\xd2\t"\x041234' assert transport.mock_calls[0][1][0] == [
b"\x00",
b"\x0c",
b"K",
b'\x08\xd2\t\x10\xd2\t"\x041234',
]
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"): with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0) await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0)
@ -1484,7 +1490,12 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
) )
await asyncio.sleep(0) await asyncio.sleep(0)
await write_task await write_task
assert transport.mock_calls[0][1][0] == b"\x00\x0cM\x08\xd2\t\x10\xd2\t\x1a\x041234" assert transport.mock_calls[0][1][0] == [
b"\x00",
b"\x0c",
b"M",
b"\x08\xd2\t\x10\xd2\t\x1a\x041234",
]
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"): with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0) await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
@ -2042,8 +2053,8 @@ async def test_bluetooth_device_connect(
cancel = await connect_task cancel = await connect_task
assert states == [(True, 23, 0)] assert states == [(True, 23, 0)]
transport.write.assert_called_once_with( transport.writelines.assert_called_once_with(
generate_plaintext_packet( generate_split_plaintext_packet(
BluetoothDeviceRequest( BluetoothDeviceRequest(
address=1234, address=1234,
request_type=method, request_type=method,
@ -2133,13 +2144,13 @@ async def test_bluetooth_device_connect_times_out_disconnect_ok(
) )
await asyncio.sleep(0) await asyncio.sleep(0)
# The connect request should be written # The connect request should be written
assert len(transport.write.mock_calls) == 1 assert len(transport.writelines.mock_calls) == 1
await asyncio.sleep(0) await asyncio.sleep(0)
await asyncio.sleep(0) await asyncio.sleep(0)
await asyncio.sleep(0) await asyncio.sleep(0)
# Now that we timed out, the disconnect # Now that we timed out, the disconnect
# request should be written # request should be written
assert len(transport.write.mock_calls) == 2 assert len(transport.writelines.mock_calls) == 2
response: message.Message = BluetoothDeviceConnectionResponse( response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=8 address=1234, connected=False, mtu=23, error=8
) )
@ -2177,7 +2188,7 @@ async def test_bluetooth_device_connect_cancelled(
) )
await asyncio.sleep(0) await asyncio.sleep(0)
# The connect request should be written # The connect request should be written
assert len(transport.write.mock_calls) == 1 assert len(transport.writelines.mock_calls) == 1
connect_task.cancel() connect_task.cancel()
with pytest.raises(asyncio.CancelledError): with pytest.raises(asyncio.CancelledError):
await connect_task await connect_task

View File

@ -115,7 +115,7 @@ async def test_timeout_sending_message(
with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0): with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0):
await conn.disconnect() await conn.disconnect()
transport.write.assert_called_with(b"\x00\x00\x05") transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"])
assert "disconnect request failed" in caplog.text assert "disconnect request failed" in caplog.text
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
@ -152,7 +152,7 @@ async def test_disconnect_when_not_fully_connected(
): ):
await connect_task await connect_task
transport.write.assert_called_with(b"\x00\x00\x05") transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"])
assert "disconnect request failed" in caplog.text assert "disconnect request failed" in caplog.text
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
@ -506,7 +506,7 @@ async def test_plaintext_connection_fails_handshake(
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]: ) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
protocol: APIPlaintextFrameHelperHandshakeException = create_func() protocol: APIPlaintextFrameHelperHandshakeException = create_func()
protocol._transport = cast(asyncio.Transport, transport) protocol._transport = cast(asyncio.Transport, transport)
protocol._writer = transport.write protocol._writelines = transport.writelines
protocol.ready_future.set_exception(exception) protocol.ready_future.set_exception(exception)
connected.set() connected.set()
return transport, protocol return transport, protocol
@ -549,7 +549,9 @@ async def test_plaintext_connection_fails_handshake(
connect_task = asyncio.create_task(connect(conn, login=False)) connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait() await connected.wait()
with (pytest.raises(raised_exception),): with (
pytest.raises(raised_exception),
):
await asyncio.sleep(0) await asyncio.sleep(0)
await connect_task await connect_task
@ -646,7 +648,7 @@ async def test_force_disconnect_fails(
await connect_task await connect_task
assert conn.is_connected assert conn.is_connected
with patch.object(protocol, "_writer", side_effect=OSError): with patch.object(protocol, "_writelines", side_effect=OSError):
conn.force_disconnect() conn.force_disconnect()
assert "Failed to send (forced) disconnect request" in caplog.text assert "Failed to send (forced) disconnect request" in caplog.text
await asyncio.sleep(0) await asyncio.sleep(0)
@ -822,7 +824,7 @@ async def test_disconnect_fails_to_send_response(
await connect_task await connect_task
assert client._connection.is_connected assert client._connection.is_connected
with patch.object(protocol, "_writer", side_effect=OSError): with patch.object(protocol, "_writelines", side_effect=OSError):
disconnect_request = DisconnectRequest() disconnect_request = DisconnectRequest()
mock_data_received(protocol, generate_plaintext_packet(disconnect_request)) mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
@ -893,7 +895,7 @@ async def test_ping_disconnects_after_no_responses(
await connect_task await connect_task
ping_request_bytes = b"\x00\x00\x07" ping_request_bytes = [b"\x00", b"\x00", b"\x07"]
assert conn.is_connected assert conn.is_connected
transport.reset_mock() transport.reset_mock()
@ -904,9 +906,9 @@ async def test_ping_disconnects_after_no_responses(
async_fire_time_changed( async_fire_time_changed(
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count) start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
) )
assert transport.write.call_count == count assert transport.writelines.call_count == count
expected_calls.append(call(ping_request_bytes)) expected_calls.append(call(ping_request_bytes))
assert transport.write.mock_calls == expected_calls assert transport.writelines.mock_calls == expected_calls
assert conn.is_connected is True assert conn.is_connected is True
@ -915,7 +917,7 @@ async def test_ping_disconnects_after_no_responses(
start_time start_time
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1)) + timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
) )
assert transport.write.call_count == max_pings_to_disconnect_after + 1 assert transport.writelines.call_count == max_pings_to_disconnect_after + 1
assert conn.is_connected is False assert conn.is_connected is False
@ -932,7 +934,7 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
send_plaintext_connect_response(protocol, False) send_plaintext_connect_response(protocol, False)
await connect_task await connect_task
ping_request_bytes = b"\x00\x00\x07" ping_request_bytes = [b"\x00", b"\x00", b"\x07"]
assert conn.is_connected assert conn.is_connected
transport.reset_mock() transport.reset_mock()
@ -945,8 +947,8 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
send_ping_response(protocol) send_ping_response(protocol)
# We should only send 1 ping request if we are getting responses # We should only send 1 ping request if we are getting responses
assert transport.write.call_count == 1 assert transport.writelines.call_count == 1
assert transport.write.mock_calls == [call(ping_request_bytes)] assert transport.writelines.mock_calls == [call(ping_request_bytes)]
# We should disconnect if we are getting ping responses # We should disconnect if we are getting ping responses
assert conn.is_connected is True assert conn.is_connected is True
@ -976,9 +978,9 @@ async def test_respond_to_ping_request(
transport.reset_mock() transport.reset_mock()
send_ping_request(protocol) send_ping_request(protocol)
# We should respond to ping requests # We should respond to ping requests
ping_response_bytes = b"\x00\x00\x08" ping_response_bytes = [b"\x00", b"\x00", b"\x08"]
assert transport.write.call_count == 1 assert transport.writelines.call_count == 1
assert transport.write.mock_calls == [call(ping_response_bytes)] assert transport.writelines.mock_calls == [call(ping_response_bytes)]
@pytest.mark.asyncio @pytest.mark.asyncio