mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-21 11:55:11 +01:00
Implement zerocopy writes (#990)
This commit is contained in:
parent
4bea46b201
commit
ba05d38602
@ -11,7 +11,7 @@ cdef class APIFrameHelper:
|
||||
cdef object _loop
|
||||
cdef APIConnection _connection
|
||||
cdef object _transport
|
||||
cdef public object _writer
|
||||
cdef public object _writelines
|
||||
cdef public object ready_future
|
||||
cdef bytes _buffer
|
||||
cdef unsigned int _buffer_len
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
from abc import abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Iterable
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Callable, cast
|
||||
|
||||
@ -31,7 +32,7 @@ class APIFrameHelper:
|
||||
"_loop",
|
||||
"_connection",
|
||||
"_transport",
|
||||
"_writer",
|
||||
"_writelines",
|
||||
"ready_future",
|
||||
"_buffer",
|
||||
"_buffer_len",
|
||||
@ -51,7 +52,9 @@ class APIFrameHelper:
|
||||
self._loop = loop
|
||||
self._connection = connection
|
||||
self._transport: asyncio.Transport | None = None
|
||||
self._writer: None | (Callable[[bytes | bytearray | memoryview], None]) = None
|
||||
self._writelines: (
|
||||
None | (Callable[[Iterable[bytes | bytearray | memoryview[int]]], None])
|
||||
) = None
|
||||
self.ready_future = self._loop.create_future()
|
||||
self._buffer: bytes | None = None
|
||||
self._buffer_len = 0
|
||||
@ -146,7 +149,7 @@ class APIFrameHelper:
|
||||
def connection_made(self, transport: asyncio.BaseTransport) -> None:
|
||||
"""Handle a new connection."""
|
||||
self._transport = cast(asyncio.Transport, transport)
|
||||
self._writer = self._transport.write
|
||||
self._writelines = self._transport.writelines
|
||||
|
||||
def _handle_error_and_close(self, exc: Exception) -> None:
|
||||
self._handle_error(exc)
|
||||
@ -172,7 +175,7 @@ class APIFrameHelper:
|
||||
if self._transport:
|
||||
self._transport.close()
|
||||
self._transport = None
|
||||
self._writer = None
|
||||
self._writelines = None
|
||||
|
||||
def pause_writing(self) -> None:
|
||||
"""Stub."""
|
||||
@ -180,12 +183,14 @@ class APIFrameHelper:
|
||||
def resume_writing(self) -> None:
|
||||
"""Stub."""
|
||||
|
||||
def _write_bytes(self, data: _bytes, debug_enabled: bool) -> None:
|
||||
def _write_bytes(self, data: Iterable[_bytes], debug_enabled: bool) -> None:
|
||||
"""Write bytes to the socket."""
|
||||
if debug_enabled:
|
||||
_LOGGER.debug("%s: Sending frame: [%s]", self._log_name, data.hex())
|
||||
_LOGGER.debug(
|
||||
"%s: Sending frame: [%s]", self._log_name, b"".join(data).hex()
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
assert self._writer is not None, "Writer is not set"
|
||||
assert self._writelines is not None, "Writer is not set"
|
||||
|
||||
self._writer(data)
|
||||
self._writelines(data)
|
||||
|
@ -218,7 +218,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
frame_len = len(handshake_frame) + 1
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
self._write_bytes(
|
||||
b"".join((NOISE_HELLO, header, b"\x00", handshake_frame)),
|
||||
(NOISE_HELLO, header, b"\x00", handshake_frame),
|
||||
_LOGGER.isEnabledFor(logging.DEBUG),
|
||||
)
|
||||
|
||||
@ -346,7 +346,7 @@ class APINoiseFrameHelper(APIFrameHelper):
|
||||
out.append(header)
|
||||
out.append(frame)
|
||||
|
||||
self._write_bytes(b"".join(out), debug_enabled)
|
||||
self._write_bytes(out, debug_enabled)
|
||||
|
||||
def _handle_frame(self, frame: bytes) -> None:
|
||||
"""Handle an incoming frame."""
|
||||
|
@ -57,9 +57,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
|
||||
out.append(b"\0")
|
||||
out.append(varuint_to_bytes(len(data)))
|
||||
out.append(varuint_to_bytes(type_))
|
||||
out.append(data)
|
||||
if data:
|
||||
out.append(data)
|
||||
|
||||
self._write_bytes(b"".join(out), debug_enabled)
|
||||
self._write_bytes(out, debug_enabled)
|
||||
|
||||
def data_received(self, data: bytes | bytearray | memoryview) -> None:
|
||||
self._add_to_buffer(data)
|
||||
|
@ -65,15 +65,19 @@ class Estr(str):
|
||||
"""A subclassed string."""
|
||||
|
||||
|
||||
def generate_plaintext_packet(msg: message.Message) -> bytes:
|
||||
def generate_split_plaintext_packet(msg: message.Message) -> list[bytes]:
|
||||
type_ = PROTO_TO_MESSAGE_TYPE[msg.__class__]
|
||||
bytes_ = msg.SerializeToString()
|
||||
return (
|
||||
b"\0"
|
||||
+ _cached_varuint_to_bytes(len(bytes_))
|
||||
+ _cached_varuint_to_bytes(type_)
|
||||
+ bytes_
|
||||
)
|
||||
return [
|
||||
b"\0",
|
||||
_cached_varuint_to_bytes(len(bytes_)),
|
||||
_cached_varuint_to_bytes(type_),
|
||||
bytes_,
|
||||
]
|
||||
|
||||
|
||||
def generate_plaintext_packet(msg: message.Message) -> bytes:
|
||||
return b"".join(generate_split_plaintext_packet(msg))
|
||||
|
||||
|
||||
def as_utc(dattim: datetime) -> datetime:
|
||||
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Iterable
|
||||
import sys
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@ -132,7 +133,7 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
"""Swallow args."""
|
||||
super().__init__(*args, **kwargs)
|
||||
transport = MagicMock()
|
||||
transport.write = writer or MagicMock()
|
||||
transport.writelines = writer or MagicMock()
|
||||
self.__transport = transport
|
||||
self.connection_made(transport)
|
||||
|
||||
@ -147,7 +148,7 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
|
||||
frame_len = len(frame)
|
||||
header = bytes((0x01, (frame_len >> 8) & 0xFF, frame_len & 0xFF))
|
||||
try:
|
||||
self._writer(header + frame)
|
||||
self._writelines([header, frame])
|
||||
except (RuntimeError, ConnectionResetError, OSError) as err:
|
||||
raise SocketClosedAPIError(
|
||||
f"{self._log_name}: Error while writing data: {err}"
|
||||
@ -437,8 +438,8 @@ async def test_noise_frame_helper_handshake_failure():
|
||||
psk_bytes = base64.b64decode(noise_psk)
|
||||
writes = []
|
||||
|
||||
def _writer(data: bytes):
|
||||
writes.append(data)
|
||||
def _writelines(data: Iterable[bytes]):
|
||||
writes.append(b"".join(data))
|
||||
|
||||
connection, _ = _make_mock_connection()
|
||||
|
||||
@ -448,7 +449,7 @@ async def test_noise_frame_helper_handshake_failure():
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
writer=_writer,
|
||||
writer=_writelines,
|
||||
)
|
||||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
@ -486,8 +487,8 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
psk_bytes = base64.b64decode(noise_psk)
|
||||
writes = []
|
||||
|
||||
def _writer(data: bytes):
|
||||
writes.append(data)
|
||||
def _writelines(data: Iterable[bytes]):
|
||||
writes.append(b"".join(data))
|
||||
|
||||
connection, packets = _make_mock_connection()
|
||||
|
||||
@ -497,7 +498,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
writer=_writer,
|
||||
writer=_writelines,
|
||||
)
|
||||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
@ -548,8 +549,8 @@ async def test_noise_frame_helper_bad_encryption(
|
||||
psk_bytes = base64.b64decode(noise_psk)
|
||||
writes = []
|
||||
|
||||
def _writer(data: bytes):
|
||||
writes.append(data)
|
||||
def _writelines(data: Iterable[bytes]):
|
||||
writes.append(b"".join(data))
|
||||
|
||||
connection, packets = _make_mock_connection()
|
||||
|
||||
@ -559,7 +560,7 @@ async def test_noise_frame_helper_bad_encryption(
|
||||
expected_name="servicetest",
|
||||
client_info="my client",
|
||||
log_name="test",
|
||||
writer=_writer,
|
||||
writer=_writelines,
|
||||
)
|
||||
|
||||
proto = _mock_responder_proto(psk_bytes)
|
||||
|
@ -126,6 +126,7 @@ from aioesphomeapi.reconnect_logic import ReconnectLogic, ReconnectLogicState
|
||||
from .common import (
|
||||
Estr,
|
||||
generate_plaintext_packet,
|
||||
generate_split_plaintext_packet,
|
||||
get_mock_zeroconf,
|
||||
mock_data_received,
|
||||
)
|
||||
@ -1439,7 +1440,12 @@ async def test_bluetooth_gatt_write_without_response(
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await write_task
|
||||
assert transport.mock_calls[0][1][0] == b'\x00\x0cK\x08\xd2\t\x10\xd2\t"\x041234'
|
||||
assert transport.mock_calls[0][1][0] == [
|
||||
b"\x00",
|
||||
b"\x0c",
|
||||
b"K",
|
||||
b'\x08\xd2\t\x10\xd2\t"\x041234',
|
||||
]
|
||||
|
||||
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
|
||||
await client.bluetooth_gatt_write(1234, 1234, b"1234", True, timeout=0)
|
||||
@ -1484,7 +1490,12 @@ async def test_bluetooth_gatt_write_descriptor_without_response(
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
await write_task
|
||||
assert transport.mock_calls[0][1][0] == b"\x00\x0cM\x08\xd2\t\x10\xd2\t\x1a\x041234"
|
||||
assert transport.mock_calls[0][1][0] == [
|
||||
b"\x00",
|
||||
b"\x0c",
|
||||
b"M",
|
||||
b"\x08\xd2\t\x10\xd2\t\x1a\x041234",
|
||||
]
|
||||
|
||||
with pytest.raises(TimeoutAPIError, match="BluetoothGATTWriteResponse"):
|
||||
await client.bluetooth_gatt_write_descriptor(1234, 1234, b"1234", timeout=0)
|
||||
@ -2042,8 +2053,8 @@ async def test_bluetooth_device_connect(
|
||||
|
||||
cancel = await connect_task
|
||||
assert states == [(True, 23, 0)]
|
||||
transport.write.assert_called_once_with(
|
||||
generate_plaintext_packet(
|
||||
transport.writelines.assert_called_once_with(
|
||||
generate_split_plaintext_packet(
|
||||
BluetoothDeviceRequest(
|
||||
address=1234,
|
||||
request_type=method,
|
||||
@ -2133,13 +2144,13 @@ async def test_bluetooth_device_connect_times_out_disconnect_ok(
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# The connect request should be written
|
||||
assert len(transport.write.mock_calls) == 1
|
||||
assert len(transport.writelines.mock_calls) == 1
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
await asyncio.sleep(0)
|
||||
# Now that we timed out, the disconnect
|
||||
# request should be written
|
||||
assert len(transport.write.mock_calls) == 2
|
||||
assert len(transport.writelines.mock_calls) == 2
|
||||
response: message.Message = BluetoothDeviceConnectionResponse(
|
||||
address=1234, connected=False, mtu=23, error=8
|
||||
)
|
||||
@ -2177,7 +2188,7 @@ async def test_bluetooth_device_connect_cancelled(
|
||||
)
|
||||
await asyncio.sleep(0)
|
||||
# The connect request should be written
|
||||
assert len(transport.write.mock_calls) == 1
|
||||
assert len(transport.writelines.mock_calls) == 1
|
||||
connect_task.cancel()
|
||||
with pytest.raises(asyncio.CancelledError):
|
||||
await connect_task
|
||||
|
@ -115,7 +115,7 @@ async def test_timeout_sending_message(
|
||||
with patch("aioesphomeapi.connection.DISCONNECT_RESPONSE_TIMEOUT", 0.0):
|
||||
await conn.disconnect()
|
||||
|
||||
transport.write.assert_called_with(b"\x00\x00\x05")
|
||||
transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"])
|
||||
|
||||
assert "disconnect request failed" in caplog.text
|
||||
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
|
||||
@ -152,7 +152,7 @@ async def test_disconnect_when_not_fully_connected(
|
||||
):
|
||||
await connect_task
|
||||
|
||||
transport.write.assert_called_with(b"\x00\x00\x05")
|
||||
transport.writelines.assert_called_with([b"\x00", b"\x00", b"\x05"])
|
||||
|
||||
assert "disconnect request failed" in caplog.text
|
||||
assert " Timeout waiting for DisconnectResponse after 0.0s" in caplog.text
|
||||
@ -506,7 +506,7 @@ async def test_plaintext_connection_fails_handshake(
|
||||
) -> tuple[asyncio.Transport, APIPlaintextFrameHelperHandshakeException]:
|
||||
protocol: APIPlaintextFrameHelperHandshakeException = create_func()
|
||||
protocol._transport = cast(asyncio.Transport, transport)
|
||||
protocol._writer = transport.write
|
||||
protocol._writelines = transport.writelines
|
||||
protocol.ready_future.set_exception(exception)
|
||||
connected.set()
|
||||
return transport, protocol
|
||||
@ -549,7 +549,9 @@ async def test_plaintext_connection_fails_handshake(
|
||||
connect_task = asyncio.create_task(connect(conn, login=False))
|
||||
await connected.wait()
|
||||
|
||||
with (pytest.raises(raised_exception),):
|
||||
with (
|
||||
pytest.raises(raised_exception),
|
||||
):
|
||||
await asyncio.sleep(0)
|
||||
await connect_task
|
||||
|
||||
@ -646,7 +648,7 @@ async def test_force_disconnect_fails(
|
||||
await connect_task
|
||||
assert conn.is_connected
|
||||
|
||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||
with patch.object(protocol, "_writelines", side_effect=OSError):
|
||||
conn.force_disconnect()
|
||||
assert "Failed to send (forced) disconnect request" in caplog.text
|
||||
await asyncio.sleep(0)
|
||||
@ -822,7 +824,7 @@ async def test_disconnect_fails_to_send_response(
|
||||
await connect_task
|
||||
assert client._connection.is_connected
|
||||
|
||||
with patch.object(protocol, "_writer", side_effect=OSError):
|
||||
with patch.object(protocol, "_writelines", side_effect=OSError):
|
||||
disconnect_request = DisconnectRequest()
|
||||
mock_data_received(protocol, generate_plaintext_packet(disconnect_request))
|
||||
|
||||
@ -893,7 +895,7 @@ async def test_ping_disconnects_after_no_responses(
|
||||
|
||||
await connect_task
|
||||
|
||||
ping_request_bytes = b"\x00\x00\x07"
|
||||
ping_request_bytes = [b"\x00", b"\x00", b"\x07"]
|
||||
|
||||
assert conn.is_connected
|
||||
transport.reset_mock()
|
||||
@ -904,9 +906,9 @@ async def test_ping_disconnects_after_no_responses(
|
||||
async_fire_time_changed(
|
||||
start_time + timedelta(seconds=KEEP_ALIVE_INTERVAL * count)
|
||||
)
|
||||
assert transport.write.call_count == count
|
||||
assert transport.writelines.call_count == count
|
||||
expected_calls.append(call(ping_request_bytes))
|
||||
assert transport.write.mock_calls == expected_calls
|
||||
assert transport.writelines.mock_calls == expected_calls
|
||||
|
||||
assert conn.is_connected is True
|
||||
|
||||
@ -915,7 +917,7 @@ async def test_ping_disconnects_after_no_responses(
|
||||
start_time
|
||||
+ timedelta(seconds=KEEP_ALIVE_INTERVAL * (max_pings_to_disconnect_after + 1))
|
||||
)
|
||||
assert transport.write.call_count == max_pings_to_disconnect_after + 1
|
||||
assert transport.writelines.call_count == max_pings_to_disconnect_after + 1
|
||||
|
||||
assert conn.is_connected is False
|
||||
|
||||
@ -932,7 +934,7 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
|
||||
send_plaintext_connect_response(protocol, False)
|
||||
|
||||
await connect_task
|
||||
ping_request_bytes = b"\x00\x00\x07"
|
||||
ping_request_bytes = [b"\x00", b"\x00", b"\x07"]
|
||||
|
||||
assert conn.is_connected
|
||||
transport.reset_mock()
|
||||
@ -945,8 +947,8 @@ async def test_ping_does_not_disconnect_if_we_get_responses(
|
||||
send_ping_response(protocol)
|
||||
|
||||
# We should only send 1 ping request if we are getting responses
|
||||
assert transport.write.call_count == 1
|
||||
assert transport.write.mock_calls == [call(ping_request_bytes)]
|
||||
assert transport.writelines.call_count == 1
|
||||
assert transport.writelines.mock_calls == [call(ping_request_bytes)]
|
||||
|
||||
# We should disconnect if we are getting ping responses
|
||||
assert conn.is_connected is True
|
||||
@ -976,9 +978,9 @@ async def test_respond_to_ping_request(
|
||||
transport.reset_mock()
|
||||
send_ping_request(protocol)
|
||||
# We should respond to ping requests
|
||||
ping_response_bytes = b"\x00\x00\x08"
|
||||
assert transport.write.call_count == 1
|
||||
assert transport.write.mock_calls == [call(ping_response_bytes)]
|
||||
ping_response_bytes = [b"\x00", b"\x00", b"\x08"]
|
||||
assert transport.writelines.call_count == 1
|
||||
assert transport.writelines.mock_calls == [call(ping_response_bytes)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
Loading…
Reference in New Issue
Block a user