workaround pytest env max length on win32

This commit is contained in:
J. Nick Koston 2024-08-29 10:48:05 -10:00
parent b8551b35d5
commit 428621cf0b
No known key found for this signature in database
4 changed files with 147 additions and 119 deletions

View File

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import time from collections.abc import Awaitable
from datetime import datetime, timezone from datetime import datetime, timezone
from functools import partial from functools import partial
from typing import Awaitable, Callable import time
from typing import Callable
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from google.protobuf import message from google.protobuf import message

View File

@ -3,11 +3,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import socket
from dataclasses import replace from dataclasses import replace
from functools import partial from functools import partial
from typing import Callable import socket
from unittest.mock import MagicMock, create_autospec, patch from typing import Any, Callable
from unittest.mock import AsyncMock, MagicMock, create_autospec, patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
@ -181,7 +181,12 @@ async def plaintext_connect_task_no_login_with_expected_name(
connect(conn_with_expected_name, login=False) connect(conn_with_expected_name, login=False)
) )
await connected.wait() await connected.wait()
yield conn_with_expected_name, transport, conn_with_expected_name._frame_helper, connect_task yield (
conn_with_expected_name,
transport,
conn_with_expected_name._frame_helper,
connect_task,
)
@pytest_asyncio.fixture(name="plaintext_connect_task_with_login") @pytest_asyncio.fixture(name="plaintext_connect_task_with_login")
@ -201,7 +206,12 @@ async def plaintext_connect_task_with_login(
): ):
connect_task = asyncio.create_task(connect(conn_with_password, login=True)) connect_task = asyncio.create_task(connect(conn_with_password, login=True))
await connected.wait() await connected.wait()
yield conn_with_password, transport, conn_with_password._frame_helper, connect_task yield (
conn_with_password,
transport,
conn_with_password._frame_helper,
connect_task,
)
@pytest_asyncio.fixture(name="api_client") @pytest_asyncio.fixture(name="api_client")
@ -234,3 +244,21 @@ async def api_client(
await connect_task await connect_task
transport.reset_mock() transport.reset_mock()
yield client, conn, transport, protocol yield client, conn, transport, protocol
def make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
"""Make a mock connection."""
packets: list[tuple[int, bytes]] = []
class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args."""
super().__init__(
get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
)
def process_packet(self, type_: int, data: bytes):
packets.append((type_, data))
connection = MockConnection()
return connection, packets

View File

@ -2,9 +2,8 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import sys
from typing import Any from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import MagicMock, patch
from noise.connection import NoiseConnection # type: ignore[import-untyped] from noise.connection import NoiseConnection # type: ignore[import-untyped]
import pytest import pytest
@ -28,7 +27,7 @@ from aioesphomeapi.core import (
) )
from .common import get_mock_protocol, mock_data_received from .common import get_mock_protocol, mock_data_received
from .conftest import get_mock_connection_params from .conftest import make_mock_connection
PREAMBLE = b"\x00" PREAMBLE = b"\x00"
@ -109,24 +108,6 @@ def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection:
return proto return proto
def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
"""Make a mock connection."""
packets: list[tuple[int, bytes]] = []
class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args."""
super().__init__(
get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
)
def process_packet(self, type_: int, data: bytes):
packets.append((type_, data))
connection = MockConnection()
return connection, packets
class MockAPINoiseFrameHelper(APINoiseFrameHelper): class MockAPINoiseFrameHelper(APINoiseFrameHelper):
def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None: def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None:
"""Swallow args.""" """Swallow args."""
@ -154,86 +135,6 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
) from err ) from err
@pytest.mark.parametrize(
"in_bytes, pkt_data, pkt_type",
[
(PREAMBLE + varuint_to_bytes(0) + varuint_to_bytes(1), b"", 1),
(
PREAMBLE + varuint_to_bytes(192) + varuint_to_bytes(1) + (b"\x42" * 192),
(b"\x42" * 192),
1,
),
(
PREAMBLE + varuint_to_bytes(192) + varuint_to_bytes(100) + (b"\x42" * 192),
(b"\x42" * 192),
100,
),
(
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4),
(b"\x42" * 4),
100,
),
(
PREAMBLE
+ varuint_to_bytes(8192)
+ varuint_to_bytes(8192)
+ (b"\x42" * 8192),
(b"\x42" * 8192),
8192,
),
(
PREAMBLE + varuint_to_bytes(256) + varuint_to_bytes(256) + (b"\x42" * 256),
(b"\x42" * 256),
256,
),
(
PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42",
b"\x42",
32768,
),
(
PREAMBLE
+ varuint_to_bytes(32768)
+ varuint_to_bytes(32768)
+ (b"\x42" * 32768),
(b"\x42" * 32768),
32768,
),
],
)
@pytest.mark.skipif(
sys.platform == "win32", reason="Fails on Windows due to pytest internals"
)
@pytest.mark.asyncio
async def test_plaintext_frame_helper(
in_bytes: bytes, pkt_data: bytes, pkt_type: int
) -> None:
for _ in range(3):
connection, packets = _make_mock_connection()
helper = APIPlaintextFrameHelper(
connection=connection, client_info="my client", log_name="test"
)
mock_data_received(helper, in_bytes)
pkt = packets.pop()
type_, data = pkt
assert type_ == pkt_type
assert data == pkt_data
# Make sure we correctly handle fragments
for i in range(len(in_bytes)):
mock_data_received(helper, in_bytes[i : i + 1])
pkt = packets.pop()
type_, data = pkt
assert type_ == pkt_type
assert data == pkt_data
helper.close()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"byte_type", "byte_type",
(bytes, bytearray, memoryview), (bytes, bytearray, memoryview),
@ -247,7 +148,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
https://github.com/esphome/issues/issues/5117 https://github.com/esphome/issues/issues/5117
""" """
for _ in range(3): for _ in range(3):
connection, packets = _make_mock_connection() connection, packets = make_mock_connection()
helper = APIPlaintextFrameHelper( helper = APIPlaintextFrameHelper(
connection=connection, client_info="my client", log_name="test" connection=connection, client_info="my client", log_name="test"
) )
@ -295,7 +196,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
"01000d01736572766963657465737400", "01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265", "0100160148616e647368616b65204d4143206661696c757265",
] ]
connection, _ = _make_mock_connection() connection, _ = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
@ -326,7 +227,7 @@ async def test_noise_frame_helper_incorrect_key():
"01000d01736572766963657465737400", "01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265", "0100160148616e647368616b65204d4143206661696c757265",
] ]
connection, _ = _make_mock_connection() connection, _ = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
@ -357,7 +258,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
"01000d01736572766963657465737400", "01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265", "0100160148616e647368616b65204d4143206661696c757265",
] ]
connection, _ = _make_mock_connection() connection, _ = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
@ -390,7 +291,7 @@ async def test_noise_incorrect_name():
"01000d01736572766963657465737400", "01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265", "0100160148616e647368616b65204d4143206661696c757265",
] ]
connection, _ = _make_mock_connection() connection, _ = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
@ -436,7 +337,7 @@ async def test_noise_frame_helper_handshake_failure():
def _writer(data: bytes): def _writer(data: bytes):
writes.append(data) writes.append(data)
connection, _ = _make_mock_connection() connection, _ = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
@ -485,7 +386,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
def _writer(data: bytes): def _writer(data: bytes):
writes.append(data) writes.append(data)
connection, packets = _make_mock_connection() connection, packets = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
@ -547,7 +448,7 @@ async def test_noise_frame_helper_bad_encryption(
def _writer(data: bytes): def _writer(data: bytes):
writes.append(data) writes.append(data)
connection, packets = _make_mock_connection() connection, packets = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
@ -648,7 +549,7 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_noise_frame_helper_empty_hello(): async def test_noise_frame_helper_empty_hello():
"""Test empty hello with noise.""" """Test empty hello with noise."""
connection, _ = _make_mock_connection() connection, _ = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
@ -668,7 +569,7 @@ async def test_noise_frame_helper_empty_hello():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_noise_frame_helper_wrong_protocol(): async def test_noise_frame_helper_wrong_protocol():
"""Test noise with the wrong protocol.""" """Test noise with the wrong protocol."""
connection, _ = _make_mock_connection() connection, _ = make_mock_connection()
helper = MockAPINoiseFrameHelper( helper = MockAPINoiseFrameHelper(
connection=connection, connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=", noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
@ -765,7 +666,7 @@ async def test_connection_lost_closes_connection_and_logs(
) )
async def test_noise_bad_psks(bad_psk: str, error: str) -> None: async def test_noise_bad_psks(bad_psk: str, error: str) -> None:
"""Test we raise on bad psks.""" """Test we raise on bad psks."""
connection, _ = _make_mock_connection() connection, _ = make_mock_connection()
with pytest.raises(InvalidEncryptionKeyAPIError, match=error): with pytest.raises(InvalidEncryptionKeyAPIError, match=error):
MockAPINoiseFrameHelper( MockAPINoiseFrameHelper(
connection=connection, connection=connection,

View File

@ -0,0 +1,98 @@
from __future__ import annotations
import sys
import pytest
from aioesphomeapi._frame_helper import APIPlaintextFrameHelper
from aioesphomeapi._frame_helper.plain_text import _varuint_to_bytes as varuint_to_bytes
from .common import mock_data_received
from .conftest import make_mock_connection
PREAMBLE = b"\x00"
NOISE_HELLO = b"\x01\x00\x00"
def test_skip_win32():
if sys.platform == "win32":
pytest.skip("Skip on Windows", allow_module_level=True)
assert True
@pytest.mark.parametrize(
"in_bytes, pkt_data, pkt_type",
[
(PREAMBLE + varuint_to_bytes(0) + varuint_to_bytes(1), b"", 1),
(
PREAMBLE + varuint_to_bytes(192) + varuint_to_bytes(1) + (b"\x42" * 192),
(b"\x42" * 192),
1,
),
(
PREAMBLE + varuint_to_bytes(192) + varuint_to_bytes(100) + (b"\x42" * 192),
(b"\x42" * 192),
100,
),
(
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4),
(b"\x42" * 4),
100,
),
(
PREAMBLE
+ varuint_to_bytes(8192)
+ varuint_to_bytes(8192)
+ (b"\x42" * 8192),
(b"\x42" * 8192),
8192,
),
(
PREAMBLE + varuint_to_bytes(256) + varuint_to_bytes(256) + (b"\x42" * 256),
(b"\x42" * 256),
256,
),
(
PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42",
b"\x42",
32768,
),
(
PREAMBLE
+ varuint_to_bytes(32768)
+ varuint_to_bytes(32768)
+ (b"\x42" * 32768),
(b"\x42" * 32768),
32768,
),
],
)
@pytest.mark.asyncio
async def test_plaintext_frame_helper(
in_bytes: bytes, pkt_data: bytes, pkt_type: int
) -> None:
for _ in range(3):
connection, packets = make_mock_connection()
helper = APIPlaintextFrameHelper(
connection=connection, client_info="my client", log_name="test"
)
mock_data_received(helper, in_bytes)
pkt = packets.pop()
type_, data = pkt
assert type_ == pkt_type
assert data == pkt_data
# Make sure we correctly handle fragments
for i in range(len(in_bytes)):
mock_data_received(helper, in_bytes[i : i + 1])
pkt = packets.pop()
type_, data = pkt
assert type_ == pkt_type
assert data == pkt_data
helper.close()