Add optional basic cython implementation for frame_helper (#564)

This commit is contained in:
J. Nick Koston 2023-10-12 08:12:39 -10:00 committed by GitHub
parent 275ca3a660
commit 2c6f3d40ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 187 additions and 39 deletions

View File

@ -15,7 +15,7 @@ concurrency:
jobs: jobs:
ci: ci:
name: ${{ matrix.name }} py ${{ matrix.python-version }} on ${{ matrix.os }} name: ${{ matrix.name }} py ${{ matrix.python-version }} on ${{ matrix.os }} (${{ matrix.extension }})
runs-on: ${{ matrix.os }} runs-on: ${{ matrix.os }}
strategy: strategy:
fail-fast: false fail-fast: false
@ -26,7 +26,10 @@ jobs:
- "3.11" - "3.11"
- "3.12" - "3.12"
os: os:
- ubuntu-latest - ubuntu-latest
extension:
- "skip_cython"
- "use_cython"
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
@ -43,10 +46,20 @@ jobs:
uses: actions/cache@v3 uses: actions/cache@v3
with: with:
path: ${{ steps.pip-cache.outputs.dir }} path: ${{ steps.pip-cache.outputs.dir }}
key: pip-${{ steps.python.outputs.python-version }}-${{ hashFiles('requirements.txt', 'requirements_test.txt') }} key: pip-${{ steps.python.outputs.python-version }}-${{ matrix.extension }}-${{ hashFiles('requirements.txt', 'requirements_test.txt') }}
restore-keys: | restore-keys: |
pip-${{ steps.python.outputs.python-version }}- pip-${{ steps.python.outputs.python-version }}-${{ matrix.extension }}-
- name: Set up Python environment - name: Set up Python environment (no cython)
if: ${{ matrix.extension == 'skip_cython' }}
env:
SKIP_CYTHON: 1
run: |
pip3 install -r requirements.txt -r requirements_test.txt
pip3 install -e .
- name: Set up Python environment (cython)
if: ${{ matrix.extension == 'use_cython' }}
env:
REQUIRE_CYTHON: 1
run: | run: |
pip3 install -r requirements.txt -r requirements_test.txt pip3 install -r requirements.txt -r requirements_test.txt
pip3 install -e . pip3 install -e .
@ -60,19 +73,19 @@ jobs:
- run: flake8 aioesphomeapi - run: flake8 aioesphomeapi
name: Lint with flake8 name: Lint with flake8
if: ${{ matrix.python-version == '3.11' }} if: ${{ matrix.python-version == '3.11' && matrix.extension == 'skip_cython' }}
- run: pylint aioesphomeapi - run: pylint aioesphomeapi
name: Lint with pylint name: Lint with pylint
if: ${{ matrix.python-version == '3.11' }} if: ${{ matrix.python-version == '3.11' && matrix.extension == 'skip_cython' }}
- run: black --check --diff --color aioesphomeapi tests - run: black --check --diff --color aioesphomeapi tests
name: Check formatting with black name: Check formatting with black
if: ${{ matrix.python-version == '3.11' }} if: ${{ matrix.python-version == '3.11' && matrix.extension == 'skip_cython' }}
- run: isort --check --diff aioesphomeapi tests - run: isort --check --diff aioesphomeapi tests
name: Check import order with isort name: Check import order with isort
if: ${{ matrix.python-version == '3.11' }} if: ${{ matrix.python-version == '3.11' && matrix.extension == 'skip_cython' }}
- run: mypy aioesphomeapi - run: mypy aioesphomeapi
name: Check typing with mypy name: Check typing with mypy
if: ${{ matrix.python-version == '3.11' }} if: ${{ matrix.python-version == '3.11' && matrix.extension == 'skip_cython' }}
- run: pytest -vv --tb=native tests - run: pytest -vv --tb=native tests
name: Run tests with pytest name: Run tests with pytest
- run: | - run: |
@ -86,4 +99,4 @@ jobs:
exit 1 exit 1
fi fi
name: Check protobuf files match name: Check protobuf files match
if: ${{ matrix.python-version == '3.11' }} if: ${{ matrix.python-version == '3.11' && matrix.extension == 'skip_cython' }}

View File

@ -12,6 +12,12 @@ The module is available from the `Python Package Index <https://pypi.python.org/
$ pip3 install aioesphomeapi $ pip3 install aioesphomeapi
An optional cython extension is available for better performance, and the module will try to build it automatically.
The extension requires a C compiler and Python development headers. The module will fall back to the pure Python implementation if they are unavailable.
Building the extension can be forcefully disabled by setting the environment variable ``SKIP_CYTHON`` to ``1``.
Usage Usage
----- -----

View File

@ -0,0 +1,21 @@
import cython
cdef class APIFrameHelper:
cdef object _loop
cdef object _on_pkt
cdef object _on_error
cdef object _transport
cdef public object _writer
cdef public object _ready_future
cdef bytearray _buffer
cdef cython.uint _buffer_len
cdef cython.uint _pos
cdef object _client_info
cdef str _log_name
cdef object _debug_enabled
@cython.locals(original_pos=cython.uint, new_pos=cython.uint)
cdef _read_exactly(self, int length)

View File

@ -19,8 +19,10 @@ SOCKET_ERRORS = (
WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError) WRITE_EXCEPTIONS = (RuntimeError, ConnectionResetError, OSError)
_int = int
class APIFrameHelper(asyncio.Protocol):
class APIFrameHelper:
"""Helper class to handle the API frame protocol.""" """Helper class to handle the API frame protocol."""
__slots__ = ( __slots__ = (
@ -64,7 +66,7 @@ class APIFrameHelper(asyncio.Protocol):
if not self._ready_future.done(): if not self._ready_future.done():
self._ready_future.set_exception(exc) self._ready_future.set_exception(exc)
def _read_exactly(self, length: int) -> bytearray | None: def _read_exactly(self, length: _int) -> bytearray | None:
"""Read exactly length bytes from the buffer or None if all the bytes are not yet available.""" """Read exactly length bytes from the buffer or None if all the bytes are not yet available."""
original_pos = self._pos original_pos = self._pos
new_pos = original_pos + length new_pos = original_pos + length
@ -106,14 +108,15 @@ class APIFrameHelper(asyncio.Protocol):
self._on_error(exc) self._on_error(exc)
def connection_lost(self, exc: Exception | None) -> None: def connection_lost(self, exc: Exception | None) -> None:
"""Handle the connection being lost."""
self._handle_error( self._handle_error(
exc or SocketClosedAPIError(f"{self._log_name}: Connection lost") exc or SocketClosedAPIError(f"{self._log_name}: Connection lost")
) )
return super().connection_lost(exc)
def eof_received(self) -> bool | None: def eof_received(self) -> bool | None:
"""Handle EOF received."""
self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received")) self._handle_error(SocketClosedAPIError(f"{self._log_name}: EOF received"))
return super().eof_received() return False
def close(self) -> None: def close(self) -> None:
"""Close the connection.""" """Close the connection."""
@ -121,3 +124,9 @@ class APIFrameHelper(asyncio.Protocol):
self._transport.close() self._transport.close()
self._transport = None self._transport = None
self._writer = None self._writer = None
def pause_writing(self) -> None:
"""Stub."""
def resume_writing(self) -> None:
"""Stub."""

View File

@ -0,0 +1,27 @@
import cython
from .base cimport APIFrameHelper
cdef object TYPE_CHECKING
cdef class APINoiseFrameHelper(APIFrameHelper):
cdef object _noise_psk
cdef object _expected_name
cdef object _state
cdef object _dispatch
cdef object _server_name
cdef object _proto
cdef object _decrypt
cdef object _encrypt
cdef bint _is_ready
@cython.locals(
header=bytearray,
preamble=cython.uint,
msg_size_high=cython.uint,
msg_size_low=cython.uint,
end_of_frame_pos=cython.uint,
)
cpdef data_received(self, bytes data)

View File

@ -144,7 +144,9 @@ class APINoiseFrameHelper(APIFrameHelper):
header = self._read_exactly(3) header = self._read_exactly(3)
if header is None: if header is None:
return return
preamble, msg_size_high, msg_size_low = header preamble = header[0]
msg_size_high = header[1]
msg_size_low = header[2]
if preamble != 0x01: if preamble != 0x01:
self._handle_error_and_close( self._handle_error_and_close(
ProtocolAPIError( ProtocolAPIError(

View File

@ -0,0 +1,23 @@
import cython
from .base cimport APIFrameHelper
cdef object TYPE_CHECKING
cdef object WRITE_EXCEPTIONS
cdef object bytes_to_varuint, varuint_to_bytes
cdef class APIPlaintextFrameHelper(APIFrameHelper):
@cython.locals(
msg_type=bytes,
length=bytes,
init_bytes=bytearray,
add_length=bytearray,
end_of_frame_pos=cython.uint,
length_int=cython.uint,
preamble=cython.uint,
length_high=cython.uint,
maybe_msg_type=cython.uint
)
cpdef data_received(self, bytes data)

View File

@ -50,8 +50,10 @@ class APIPlaintextFrameHelper(APIFrameHelper):
if init_bytes is None: if init_bytes is None:
return return
msg_type_int: int | None = None msg_type_int: int | None = None
length_int: int | None = None length_int = 0
preamble, length_high, maybe_msg_type = init_bytes preamble = init_bytes[0]
length_high = init_bytes[1]
maybe_msg_type = init_bytes[2]
if preamble != 0x00: if preamble != 0x00:
if preamble == 0x01: if preamble == 0x01:
self._handle_error_and_close( self._handle_error_and_close(
@ -88,7 +90,7 @@ class APIPlaintextFrameHelper(APIFrameHelper):
if add_length is None: if add_length is None:
return return
length += add_length length += add_length
length_int = bytes_to_varuint(length) length_int = bytes_to_varuint(length) or 0
# Since the length is longer than 1 byte we do not have the # Since the length is longer than 1 byte we do not have the
# message type yet. # message type yet.
msg_type = b"" msg_type = b""
@ -105,7 +107,6 @@ class APIPlaintextFrameHelper(APIFrameHelper):
msg_type_int = bytes_to_varuint(msg_type) msg_type_int = bytes_to_varuint(msg_type)
if TYPE_CHECKING: if TYPE_CHECKING:
assert length_int is not None
assert msg_type_int is not None assert msg_type_int is not None
if length_int == 0: if length_int == 0:

View File

@ -325,7 +325,7 @@ class APIConnection:
assert self._socket is not None assert self._socket is not None
if self._params.noise_psk is None: if self._params.noise_psk is None:
_, fh = await loop.create_connection( _, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APIPlaintextFrameHelper( lambda: APIPlaintextFrameHelper(
on_pkt=process_packet, on_pkt=process_packet,
on_error=self._report_fatal_error, on_error=self._report_fatal_error,
@ -337,7 +337,7 @@ class APIConnection:
else: else:
noise_psk = self._params.noise_psk noise_psk = self._params.noise_psk
assert noise_psk is not None assert noise_psk is not None
_, fh = await loop.create_connection( _, fh = await loop.create_connection( # type: ignore[type-var]
lambda: APINoiseFrameHelper( lambda: APINoiseFrameHelper(
noise_psk=noise_psk, noise_psk=noise_psk,
expected_name=self._params.expected_name, expected_name=self._params.expected_name,

View File

@ -29,3 +29,7 @@ disable = [
"duplicate-code", "duplicate-code",
"too-many-lines", "too-many-lines",
] ]
[build-system]
requires = ['setuptools>=65.4.1', 'wheel', 'Cython>=3.0.2']

View File

@ -3,6 +3,9 @@
import os import os
from setuptools import find_packages, setup from setuptools import find_packages, setup
import os
from distutils.command.build_ext import build_ext
here = os.path.abspath(os.path.dirname(__file__)) here = os.path.abspath(os.path.dirname(__file__))
@ -31,20 +34,59 @@ DOWNLOAD_URL = "{}/archive/{}.zip".format(GITHUB_URL, VERSION)
with open(os.path.join(here, "requirements.txt")) as requirements_txt: with open(os.path.join(here, "requirements.txt")) as requirements_txt:
REQUIRES = requirements_txt.read().splitlines() REQUIRES = requirements_txt.read().splitlines()
setup(
name=PROJECT_PACKAGE_NAME, setup_kwargs = {
version=VERSION, "name": PROJECT_PACKAGE_NAME,
url=PROJECT_URL, "version": VERSION,
download_url=DOWNLOAD_URL, "url": PROJECT_URL,
author=PROJECT_AUTHOR, "download_url": DOWNLOAD_URL,
author_email=PROJECT_EMAIL, "author": PROJECT_AUTHOR,
description="Python API for interacting with ESPHome devices.", "author_email": PROJECT_EMAIL,
long_description=long_description, "description": "Python API for interacting with ESPHome devices.",
license=PROJECT_LICENSE, "long_description": long_description,
packages=find_packages(exclude=["tests", "tests.*"]), "license": PROJECT_LICENSE,
include_package_data=True, "packages": find_packages(exclude=["tests", "tests.*"]),
zip_safe=False, "include_package_data": True,
install_requires=REQUIRES, "zip_safe": False,
python_requires=">=3.9", "install_requires": REQUIRES,
test_suite="tests", "python_requires": ">=3.9",
) "test_suite": "tests",
}
class OptionalBuildExt(build_ext):
def build_extensions(self):
try:
super().build_extensions()
except Exception:
pass
def cythonize_if_available(setup_kwargs):
if os.environ.get("SKIP_CYTHON", False):
return
try:
from Cython.Build import cythonize
setup_kwargs.update(
dict(
ext_modules=cythonize(
[
"aioesphomeapi/_frame_helper/plain_text.py",
"aioesphomeapi/_frame_helper/noise.py",
"aioesphomeapi/_frame_helper/base.py",
],
compiler_directives={"language_level": "3"}, # Python 3
),
cmdclass=dict(build_ext=OptionalBuildExt),
)
)
except Exception:
if os.environ.get("REQUIRE_CYTHON"):
raise
pass
cythonize_if_available(setup_kwargs)
setup(**setup_kwargs)