Implement zerocopy writes (#990)

This commit is contained in:
J. Nick Koston 2024-11-01 11:46:10 -05:00 committed by GitHub
parent 4bea46b201
commit ba05d38602
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 78 additions and 54 deletions

View File

@ -11,7 +11,7 @@ cdef class APIFrameHelper:
cdef object _loop
cdef APIConnection _connection
cdef object _transport
cdef public object _writer
cdef public object _writelines
cdef public object ready_future
cdef bytes _buffer
cdef unsigned int _buffer_len

View File

@ -2,6 +2,7 @@ from __future__ import annotations
from abc import abstractmethod
import asyncio
from collections.abc import Iterable
import logging
from typing import TYPE_CHECKING, Callable, cast
@ -31,7 +32,7 @@ class APIFrameHelper:
"_loop",
"_connection",
"_transport",
"_writer",
"_writelines",
"ready_future",
"_buffer",
"_buffer_len",
@ -51,7 +52,9 @@ class APIFrameHelper:
self._loop = loop
self._connection = connection
self._transport: asyncio.Transport | None = None
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
self._writelines: (
None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None])
) = None
self.ready_future = self._loop.create_future()
self._buffer: bytes | None = None
self._buffer_len = 0
@ -146,7 +149,7 @@ class APIFrameHelper:
def connection_made(self, transport: asyncio.BaseTransport) -> None:
"""Handle a new connection."""
self._transport = cast(asyncio.Transport, transport)
self._writer = self._transport.write
self._writelines = self._transport.writelines
def _handle_error_and_close(self, exc: Exception) -> None:
self._handle_error(exc)
@ -172,7 +175,7 @@ class APIFrameHelper:
if self._transport:
self._transport.close()
self._transport = None
self._writer = None
self._writelines = None
def pause_writing(self) -> None:
"""Stub."""
@ -180,12 +183,14 @@ class APIFrameHelper:
def resume_writing(self) -> None:
"""Stub."""
def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None:
def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None:
"""Write bytes to the socket."""
if debug_enabled:
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
_LOGGER.debug(
"%s: Sending frame: [%s]", self._log_name, b"".join(data).hex()
)
if TYPE_CHECKING:
assert self._writer is not None, "Writer is not set"
assert self._writelines is not None, "Writer is not set"
self._writer(data)
self._writelines(data)

View File

@ -218,7 +218,7 @@ class APINoiseFrameHelper(APIFrameHelper):
frame_len = len(handshake_frame) + 1
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
self._write_bytes(
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
(NOISE_HELLO, header, b"\x00", handshake_frame),
_LOGGER.isEnabledFor(logging.DEBUG),
)
@ -346,7 +346,7 @@ class APINoiseFrameHelper(APIFrameHelper):
out.append(header)
out.append(frame)
self._write_bytes(b"".join(out), debug_enabled)
self._write_bytes(out, debug_enabled)
def _handle_frame(self, frame: bytes) -> None:
"""Handle an incoming frame."""

View File

@ -57,9 +57,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
out.append(b"\0")
out.append(varuint_to_bytes(len(data)))
out.append(varuint_to_bytes(type_))
out.append(data)
if data:
out.append(data)
self._write_bytes(b"".join(out), debug_enabled)
self._write_bytes(out, debug_enabled)
def data_received(self, data: bytes | bytearray | memoryview) -> None:
self._add_to_buffer(data)

View File

@ -65,15 +65,19 @@ class Estr(str):
"""A subclassed string."""
def generate_plaintext_packet(msg: message.Message) -> bytes:
def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]:
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
bytes_ = msg.SerializeToString()
return (
b"\0"
+ _cached_varuint_to_bytes(len(bytes_))
+ _cached_varuint_to_bytes(type_)
+ bytes_
)
return [
b"\0",
_cached_varuint_to_bytes(len(bytes_)),
_cached_varuint_to_bytes(type_),
bytes_,
]
def generate_plaintext_packet(msg: message.Message) -> bytes:
return b"".join(generate_split_plaintext_packet(msg))
def as_utc(dattim: datetime) -> datetime:

View File

@ -2,6 +2,7 @@ from __future__ import annotations
import asyncio
import base64
from collections.abc import Iterable
import sys
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@ -132,7 +133,7 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
"""Swallow args."""
super().__init__(*args, **kwargs)
transport = MagicMock()
transport.write = writer or MagicMock()
transport.writelines = writer or MagicMock()
self.__transport = transport
self.connection_made(transport)
@ -147,7 +148,7 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
frame_len = len(frame)
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
try:
self._writer(header + frame)
self._writelines([header, frame])
except (RuntimeError, ConnectionResetError, OSError) as err:
raise SocketClosedAPIError(
f"{self._log_name}: Error while writing data: {err}"
@ -437,8 +438,8 @@ async def test_noise_frame_helper_handshake_failure():
psk_bytes = base64.b64decode(noise_psk)
writes = []
def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))
connection, _ = _make_mock_connection()
@ -448,7 +449,7 @@ async def test_noise_frame_helper_handshake_failure():
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)
proto = _mock_responder_proto(psk_bytes)
@ -486,8 +487,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
psk_bytes = base64.b64decode(noise_psk)
writes = []
def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))
connection, packets = _make_mock_connection()
@ -497,7 +498,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)
proto = _mock_responder_proto(psk_bytes)
@ -548,8 +549,8 @@ async def test_noise_frame_helper_bad_encryption(
psk_bytes = base64.b64decode(noise_psk)
writes = []
def _writer(data: bytes):
writes.append(data)
def _writelines(data: Iterable[bytes]):
writes.append(b"".join(data))
connection, packets = _make_mock_connection()
@ -559,7 +560,7 @@ async def test_noise_frame_helper_bad_encryption(
expected_name="servicetest",
client_info="my client",
log_name="test",
writer=_writer,
writer=_writelines,
)
proto = _mock_responder_proto(psk_bytes)

View File

@ -126,6 +126,7 @@ from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
from .common import (
Estr,
generate_plaintext_packet,
generate_split_plaintext_packet,
get_mock_zeroconf,
mock_data_received,
)
@ -1439,7 +1440,12 @@ async def test_bluetooth_gatt_write_without_response(
)
await asyncio.sleep(0)
await write_task
assert transport.mock_calls[0][1][0] == b'\x00\x0cK\x08\xd2\t\x10\xd2\t"\x041234'
assert transport.mock_calls[0][1][0] == [
b"\x00",
b"\x0c",
b"K",
b'\x08\xd2\t\x10\xd2\t"\x041234',
]
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0)
@ -1484,7 +1490,12 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
)
await asyncio.sleep(0)
await write_task
assert transport.mock_calls[0][1][0] == b"\x00\x0cM\x08\xd2\t\x10\xd2\t\x1a\x041234"
assert transport.mock_calls[0][1][0] == [
b"\x00",
b"\x0c",
b"M",
b"\x08\xd2\t\x10\xd2\t\x1a\x041234",
]
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
@ -2042,8 +2053,8 @@ async def test_bluetooth_device_connect(
cancel = await connect_task
assert states == [(True, 23, 0)]
transport.write.assert_called_once_with(
generate_plaintext_packet(
transport.writelines.assert_called_once_with(
generate_split_plaintext_packet(
BluetoothDeviceRequest(
address=1234,
request_type=method,
@ -2133,13 +2144,13 @@ async def test_bluetooth_device_connect_times_out_disconnect_ok(
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
assert len(transport.writelines.mock_calls) == 1
await asyncio.sleep(0)
await asyncio.sleep(0)
await asyncio.sleep(0)
# Now that we timed out, the disconnect
# request should be written
assert len(transport.write.mock_calls) == 2
assert len(transport.writelines.mock_calls) == 2
response: message.Message = BluetoothDeviceConnectionResponse(
address=1234, connected=False, mtu=23, error=8
)
@ -2177,7 +2188,7 @@ async def test_bluetooth_device_connect_cancelled(
)
await asyncio.sleep(0)
# The connect request should be written
assert len(transport.write.mock_calls) == 1
assert len(transport.writelines.mock_calls) == 1
connect_task.cancel()
with pytest.raises(asyncio.CancelledError):
await connect_task

View File

@ -115,7 +115,7 @@ async def test_timeout_sending_message(
with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0):
await conn.disconnect()
transport.write.assert_called_with(b"\x00\x00\x05")
transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"])
assert "disconnect request failed" in caplog.text
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
@ -152,7 +152,7 @@ async def test_disconnect_when_not_fully_connected(
):
await connect_task
transport.write.assert_called_with(b"\x00\x00\x05")
transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"])
assert "disconnect request failed" in caplog.text
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
@ -506,7 +506,7 @@ async def test_plaintext_connection_fails_handshake(
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
protocol: APIPlaintextFrameHelperHandshakeException = create_func()
protocol._transport = cast(asyncio.Transport, transport)
protocol._writer = transport.write
protocol._writelines = transport.writelines
protocol.ready_future.set_exception(exception)
connected.set()
return transport, protocol
@ -549,7 +549,9 @@ async def test_plaintext_connection_fails_handshake(
connect_task = asyncio.create_task(connect(conn, login=False))
await connected.wait()
with (pytest.raises(raised_exception),):
with (
pytest.raises(raised_exception),
):
await asyncio.sleep(0)
await connect_task
@ -646,7 +648,7 @@ async def test_force_disconnect_fails(
await connect_task
assert conn.is_connected
with patch.object(protocol, "_writer", side_effect=OSError):
with patch.object(protocol, "_writelines", side_effect=OSError):
conn.force_disconnect()
assert "Failed to send (forced) disconnect request" in caplog.text
await asyncio.sleep(0)
@ -822,7 +824,7 @@ async def test_disconnect_fails_to_send_response(
await connect_task
assert client._connection.is_connected
with patch.object(protocol, "_writer", side_effect=OSError):
with patch.object(protocol, "_writelines", side_effect=OSError):
disconnect_request = DisconnectRequest()
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
@ -893,7 +895,7 @@ async def test_ping_disconnects_after_no_responses(
await connect_task
ping_request_bytes = b"\x00\x00\x07"
ping_request_bytes = [b"\x00", b"\x00", b"\x07"]
assert conn.is_connected
transport.reset_mock()
@ -904,9 +906,9 @@ async def test_ping_disconnects_after_no_responses(
async_fire_time_changed(
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
)
assert transport.write.call_count == count
assert transport.writelines.call_count == count
expected_calls.append(call(ping_request_bytes))
assert transport.write.mock_calls == expected_calls
assert transport.writelines.mock_calls == expected_calls
assert conn.is_connected is True
@ -915,7 +917,7 @@ async def test_ping_disconnects_after_no_responses(
start_time
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
)
assert transport.write.call_count == max_pings_to_disconnect_after + 1
assert transport.writelines.call_count == max_pings_to_disconnect_after + 1
assert conn.is_connected is False
@ -932,7 +934,7 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
send_plaintext_connect_response(protocol, False)
await connect_task
ping_request_bytes = b"\x00\x00\x07"
ping_request_bytes = [b"\x00", b"\x00", b"\x07"]
assert conn.is_connected
transport.reset_mock()
@ -945,8 +947,8 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
send_ping_response(protocol)
# We should only send 1 ping request if we are getting responses
assert transport.write.call_count == 1
assert transport.write.mock_calls == [call(ping_request_bytes)]
assert transport.writelines.call_count == 1
assert transport.writelines.mock_calls == [call(ping_request_bytes)]
# We should disconnect if we are getting ping responses
assert conn.is_connected is True
@ -976,9 +978,9 @@ async def test_respond_to_ping_request(
transport.reset_mock()
send_ping_request(protocol)
# We should respond to ping requests
ping_response_bytes = b"\x00\x00\x08"
assert transport.write.call_count == 1
assert transport.write.mock_calls == [call(ping_response_bytes)]
ping_response_bytes = [b"\x00", b"\x00", b"\x08"]
assert transport.writelines.call_count == 1
assert transport.writelines.mock_calls == [call(ping_response_bytes)]
@pytest.mark.asyncio