Fix disconnecting while handshake is in process (#428)

This commit is contained in:
J. Nick Koston 2023-05-04 12:47:03 -05:00 committed by GitHub
parent 910d197906
commit 8261700bdd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 14 deletions

View File

@ -344,6 +344,8 @@ class APINoiseFrameHelper(APIFrameHelper):
def write_packet(self, type_: int, data: bytes) -> None:
"""Write a packet to the socket."""
if self._state != NoiseConnectionState.READY:
raise HandshakeAPIError("Noise connection is not ready")
self._write_frame(
self._proto.encrypt(
(

View File

@ -69,7 +69,16 @@ CONNECT_REQUEST_TIMEOUT = 30.0
# The connect timeout should be the maximum time we expect the esp to take
# to reboot and connect to the network/WiFi.
CONNECT_TIMEOUT = 60.0
TCP_CONNECT_TIMEOUT = 60.0
# The maximum time for the whole connect process to complete
CONNECT_AND_SETUP_TIMEOUT = 120.0
# How long to wait for an existing connection to finish being
# setup when requesting a disconnect so we can try to disconnect
# gracefully without closing the socket out from under the
# the esp device
DISCONNECT_WAIT_CONNECT_TIMEOUT = 5.0
in_do_connect: contextvars.ContextVar[Optional[bool]] = contextvars.ContextVar(
"in_do_connect"
@ -240,7 +249,7 @@ class APIConnection:
try:
coro = asyncio.get_event_loop().sock_connect(self._socket, sockaddr)
async with async_timeout.timeout(CONNECT_TIMEOUT):
async with async_timeout.timeout(TCP_CONNECT_TIMEOUT):
await coro
except OSError as err:
raise SocketAPIError(f"Error connecting to {sockaddr}: {err}") from err
@ -411,10 +420,10 @@ class APIConnection:
)
try:
# Allow 2 minutes for connect; this is only as a last measure
# Allow 2 minutes for connect and setup; this is only as a last measure
# to protect from issues if some part of the connect process mistakenly
# does not have a timeout
async with async_timeout.timeout(120.0):
async with async_timeout.timeout(CONNECT_AND_SETUP_TIMEOUT):
await self._connect_task
except asyncio.CancelledError:
# If the task was cancelled, we need to clean up the connection
@ -428,6 +437,7 @@ class APIConnection:
self._cleanup()
raise
self._connect_task = None
self._connection_state = ConnectionState.CONNECTED
self._connect_complete = True
@ -697,20 +707,25 @@ class APIConnection:
self.send_message(resp)
async def disconnect(self) -> None:
if not self._is_socket_open or not self._frame_helper:
"""Disconnect from the API."""
if self._connect_task:
# Try to wait for the handshake to finish so we can send
# a disconnect request. If it doesn't finish in time
# we will just close the socket.
await asyncio.wait([self._connect_task], timeout=5.0)
self._expected_disconnect = True
if self._is_socket_open and self._frame_helper:
# We still want to send a disconnect request even
# if the hello phase isn't finished to ensure we
# the esp will clean up the connection as soon
# as possible.
return
self._expected_disconnect = True
try:
await self.send_message_await_response(
DisconnectRequest(), DisconnectResponse
)
except APIConnectionError:
pass
try:
await self.send_message_await_response(
DisconnectRequest(), DisconnectResponse
)
except APIConnectionError:
pass
self._connection_state = ConnectionState.CLOSED
self._cleanup()