revert test changes

This commit is contained in:
J. Nick Koston 2024-08-29 10:55:04 -10:00
parent 428621cf0b
commit 8562b9191c
No known key found for this signature in database
2 changed files with 110 additions and 33 deletions

View File

@ -6,8 +6,8 @@ import asyncio
from dataclasses import replace
from functools import partial
import socket
from typing import Any, Callable
from unittest.mock import AsyncMock, MagicMock, create_autospec, patch
from typing import Callable
from unittest.mock import MagicMock, create_autospec, patch
import pytest
import pytest_asyncio
@ -244,21 +244,3 @@ async def api_client(
await connect_task
transport.reset_mock()
yield client, conn, transport, protocol
def make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
"""Make a mock connection."""
packets: list[tuple[int, bytes]] = []
class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args."""
super().__init__(
get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
)
def process_packet(self, type_: int, data: bytes):
packets.append((type_, data))
connection = MockConnection()
return connection, packets

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import base64
from typing import Any
from unittest.mock import MagicMock, patch
from unittest.mock import AsyncMock, MagicMock, patch
from noise.connection import NoiseConnection # type: ignore[import-untyped]
import pytest
@ -27,7 +27,7 @@ from aioesphomeapi.core import (
)
from .common import get_mock_protocol, mock_data_received
from .conftest import make_mock_connection
from .conftest import get_mock_connection_params
PREAMBLE = b"\x00"
@ -108,6 +108,24 @@ def _mock_responder_proto(psk_bytes: bytes) -> NoiseConnection:
return proto
def _make_mock_connection() -> tuple[APIConnection, list[tuple[int, bytes]]]:
"""Make a mock connection."""
packets: list[tuple[int, bytes]] = []
class MockConnection(APIConnection):
def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Swallow args."""
super().__init__(
get_mock_connection_params(), AsyncMock(), True, None, *args, **kwargs
)
def process_packet(self, type_: int, data: bytes):
packets.append((type_, data))
connection = MockConnection()
return connection, packets
class MockAPINoiseFrameHelper(APINoiseFrameHelper):
def __init__(self, *args: Any, writer: Any | None = None, **kwargs: Any) -> None:
"""Swallow args."""
@ -135,6 +153,83 @@ class MockAPINoiseFrameHelper(APINoiseFrameHelper):
) from err
@pytest.mark.parametrize(
"in_bytes, pkt_data, pkt_type",
[
(PREAMBLE + varuint_to_bytes(0) + varuint_to_bytes(1), b"", 1),
(
PREAMBLE + varuint_to_bytes(192) + varuint_to_bytes(1) + (b"\x42" * 192),
(b"\x42" * 192),
1,
),
(
PREAMBLE + varuint_to_bytes(192) + varuint_to_bytes(100) + (b"\x42" * 192),
(b"\x42" * 192),
100,
),
(
PREAMBLE + varuint_to_bytes(4) + varuint_to_bytes(100) + (b"\x42" * 4),
(b"\x42" * 4),
100,
),
(
PREAMBLE
+ varuint_to_bytes(8192)
+ varuint_to_bytes(8192)
+ (b"\x42" * 8192),
(b"\x42" * 8192),
8192,
),
(
PREAMBLE + varuint_to_bytes(256) + varuint_to_bytes(256) + (b"\x42" * 256),
(b"\x42" * 256),
256,
),
(
PREAMBLE + varuint_to_bytes(1) + varuint_to_bytes(32768) + b"\x42",
b"\x42",
32768,
),
(
PREAMBLE
+ varuint_to_bytes(32768)
+ varuint_to_bytes(32768)
+ (b"\x42" * 32768),
(b"\x42" * 32768),
32768,
),
],
)
@pytest.mark.asyncio
async def test_plaintext_frame_helper(
in_bytes: bytes, pkt_data: bytes, pkt_type: int
) -> None:
for _ in range(3):
connection, packets = _make_mock_connection()
helper = APIPlaintextFrameHelper(
connection=connection, client_info="my client", log_name="test"
)
mock_data_received(helper, in_bytes)
pkt = packets.pop()
type_, data = pkt
assert type_ == pkt_type
assert data == pkt_data
# Make sure we correctly handle fragments
for i in range(len(in_bytes)):
mock_data_received(helper, in_bytes[i : i + 1])
pkt = packets.pop()
type_, data = pkt
assert type_ == pkt_type
assert data == pkt_data
helper.close()
@pytest.mark.parametrize(
"byte_type",
(bytes, bytearray, memoryview),
@ -148,7 +243,7 @@ def test_plaintext_frame_helper_protractor_event_loop(byte_type: Any) -> None:
https://github.com/esphome/issues/issues/5117
"""
for _ in range(3):
connection, packets = make_mock_connection()
connection, packets = _make_mock_connection()
helper = APIPlaintextFrameHelper(
connection=connection, client_info="my client", log_name="test"
)
@ -196,7 +291,7 @@ async def test_noise_protector_event_loop(byte_type: Any) -> None:
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
connection, _ = make_mock_connection()
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
@ -227,7 +322,7 @@ async def test_noise_frame_helper_incorrect_key():
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
connection, _ = make_mock_connection()
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
@ -258,7 +353,7 @@ async def test_noise_frame_helper_incorrect_key_fragments():
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
connection, _ = make_mock_connection()
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
@ -291,7 +386,7 @@ async def test_noise_incorrect_name():
"01000d01736572766963657465737400",
"0100160148616e647368616b65204d4143206661696c757265",
]
connection, _ = make_mock_connection()
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
@ -337,7 +432,7 @@ async def test_noise_frame_helper_handshake_failure():
def _writer(data: bytes):
writes.append(data)
connection, _ = make_mock_connection()
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
@ -386,7 +481,7 @@ async def test_noise_frame_helper_handshake_success_with_single_packet():
def _writer(data: bytes):
writes.append(data)
connection, packets = make_mock_connection()
connection, packets = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
@ -448,7 +543,7 @@ async def test_noise_frame_helper_bad_encryption(
def _writer(data: bytes):
writes.append(data)
connection, packets = make_mock_connection()
connection, packets = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
@ -549,7 +644,7 @@ async def test_init_noise_with_wrong_byte_marker(noise_conn: APIConnection) -> N
@pytest.mark.asyncio
async def test_noise_frame_helper_empty_hello():
"""Test empty hello with noise."""
connection, _ = make_mock_connection()
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
@ -569,7 +664,7 @@ async def test_noise_frame_helper_empty_hello():
@pytest.mark.asyncio
async def test_noise_frame_helper_wrong_protocol():
"""Test noise with the wrong protocol."""
connection, _ = make_mock_connection()
connection, _ = _make_mock_connection()
helper = MockAPINoiseFrameHelper(
connection=connection,
noise_psk="QRTIErOb/fcE9Ukd/5qA3RGYMn0Y+p06U58SCtOXvPc=",
@ -666,7 +761,7 @@ async def test_connection_lost_closes_connection_and_logs(
)
async def test_noise_bad_psks(bad_psk: str, error: str) -> None:
"""Test we raise on bad psks."""
connection, _ = make_mock_connection()
connection, _ = _make_mock_connection()
with pytest.raises(InvalidEncryptionKeyAPIError, match=error):
MockAPINoiseFrameHelper(
connection=connection,