mirror of
https://github.com/esphome/aioesphomeapi.git
synced 2024-11-22 12:05:12 +01:00
Expect a name for connections (#122)
Co-authored-by: Jesse Hills <3060199+jesserockz@users.noreply.github.com>
This commit is contained in:
parent
fe298c1f27
commit
9964034f18
@ -8,6 +8,7 @@ from typing import Optional
|
|||||||
from noise.connection import NoiseConnection # type: ignore
|
from noise.connection import NoiseConnection # type: ignore
|
||||||
|
|
||||||
from .core import (
|
from .core import (
|
||||||
|
BadNameAPIError,
|
||||||
HandshakeAPIError,
|
HandshakeAPIError,
|
||||||
InvalidEncryptionKeyAPIError,
|
InvalidEncryptionKeyAPIError,
|
||||||
ProtocolAPIError,
|
ProtocolAPIError,
|
||||||
@ -178,18 +179,35 @@ class APINoiseFrameHelper(APIFrameHelper):
|
|||||||
_LOGGER.debug("Received frame %s", frame.hex())
|
_LOGGER.debug("Received frame %s", frame.hex())
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
async def perform_handshake(self) -> None:
|
async def perform_handshake(self, expected_name: Optional[str]) -> None:
|
||||||
await self._write_frame(b"") # ClientHello
|
await self._write_frame(b"") # ClientHello
|
||||||
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
prologue = b"NoiseAPIInit" + b"\x00\x00"
|
||||||
|
|
||||||
server_hello = await self._read_frame() # ServerHello
|
server_hello = await self._read_frame() # ServerHello
|
||||||
if not server_hello:
|
if not server_hello:
|
||||||
raise HandshakeAPIError("ServerHello is empty")
|
raise HandshakeAPIError("ServerHello is empty")
|
||||||
|
|
||||||
|
# First byte of server hello is the protocol the server chose
|
||||||
|
# for this session. Currently only 0x01 (Noise_NNpsk0_25519_ChaChaPoly_SHA256)
|
||||||
|
# exists.
|
||||||
chosen_proto = server_hello[0]
|
chosen_proto = server_hello[0]
|
||||||
if chosen_proto != 0x01:
|
if chosen_proto != 0x01:
|
||||||
raise HandshakeAPIError(
|
raise HandshakeAPIError(
|
||||||
f"Unknown protocol selected by client {chosen_proto}"
|
f"Unknown protocol selected by client {chosen_proto}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check name matches expected name (for noise sessions, this is done
|
||||||
|
# during hello phase before a connection is set up)
|
||||||
|
# Server name is encoded as a string followed by a zero byte after the chosen proto byte
|
||||||
|
server_name_i = server_hello.find(b"\0", 1)
|
||||||
|
if server_name_i != -1:
|
||||||
|
# server name found, this extension was added in 2022.2
|
||||||
|
server_name = server_hello[1:server_name_i].decode()
|
||||||
|
if expected_name is not None and expected_name != server_name:
|
||||||
|
raise BadNameAPIError(
|
||||||
|
f"Server sent a different name '{server_name}'", server_name
|
||||||
|
)
|
||||||
|
|
||||||
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
self._proto = NoiseConnection.from_name(b"Noise_NNpsk0_25519_ChaChaPoly_SHA256")
|
||||||
self._proto.set_as_initiator()
|
self._proto.set_as_initiator()
|
||||||
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
|
self._proto.set_psks(_decode_noise_psk(self._noise_psk))
|
||||||
|
@ -97,6 +97,9 @@ message HelloResponse {
|
|||||||
// and only exists for debugging/logging purposes.
|
// and only exists for debugging/logging purposes.
|
||||||
// For example "ESPHome v1.10.0 on ESP8266"
|
// For example "ESPHome v1.10.0 on ESP8266"
|
||||||
string server_info = 3;
|
string server_info = 3;
|
||||||
|
|
||||||
|
// The name of the server (App.get_name())
|
||||||
|
string name = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Message sent at the beginning of each connection to authenticate the client
|
// Message sent at the beginning of each connection to authenticate the client
|
||||||
|
File diff suppressed because one or more lines are too long
@ -137,7 +137,23 @@ class APIClient:
|
|||||||
keepalive: float = 15.0,
|
keepalive: float = 15.0,
|
||||||
zeroconf_instance: ZeroconfInstanceType = None,
|
zeroconf_instance: ZeroconfInstanceType = None,
|
||||||
noise_psk: Optional[str] = None,
|
noise_psk: Optional[str] = None,
|
||||||
|
expected_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
"""Create a client, this object is shared across sessions.
|
||||||
|
|
||||||
|
:param address: The address to connect to; for example an IP address
|
||||||
|
or .local name for mDNS lookup.
|
||||||
|
:param port: The port to connect to
|
||||||
|
:param password: Optional password to send to the device for authentication
|
||||||
|
:param client_info: User Agent string to send.
|
||||||
|
:param keepalive: The keepalive time in seconds (ping interval) for detecting stale connections.
|
||||||
|
Every keepalive seconds a ping is sent, if no pong is received the connection is closed.
|
||||||
|
:param zeroconf_instance: Pass a zeroconf instance to use if an mDNS lookup is necessary.
|
||||||
|
:param noise_psk: Encryption preshared key for noise transport encrypted sessions.
|
||||||
|
:param expected_name: Require the devices name to match the given expected name.
|
||||||
|
Can be used to prevent accidentally connecting to a different device if
|
||||||
|
IP passed as address but DHCP reassigned IP.
|
||||||
|
"""
|
||||||
self._params = ConnectionParams(
|
self._params = ConnectionParams(
|
||||||
address=address,
|
address=address,
|
||||||
port=port,
|
port=port,
|
||||||
@ -147,10 +163,19 @@ class APIClient:
|
|||||||
zeroconf_instance=zeroconf_instance,
|
zeroconf_instance=zeroconf_instance,
|
||||||
# treat empty psk string as missing (like password)
|
# treat empty psk string as missing (like password)
|
||||||
noise_psk=noise_psk or None,
|
noise_psk=noise_psk or None,
|
||||||
|
expected_name=expected_name,
|
||||||
)
|
)
|
||||||
self._connection: Optional[APIConnection] = None
|
self._connection: Optional[APIConnection] = None
|
||||||
self._cached_name: Optional[str] = None
|
self._cached_name: Optional[str] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def expected_name(self) -> Optional[str]:
|
||||||
|
return self._params.expected_name
|
||||||
|
|
||||||
|
@expected_name.setter
|
||||||
|
def expected_name(self, value: Optional[str]) -> None:
|
||||||
|
self._params.expected_name = value
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def address(self) -> str:
|
def address(self) -> str:
|
||||||
return self._params.address
|
return self._params.address
|
||||||
|
@ -32,6 +32,7 @@ from .api_pb2 import ( # type: ignore
|
|||||||
from .core import (
|
from .core import (
|
||||||
MESSAGE_TYPE_TO_PROTO,
|
MESSAGE_TYPE_TO_PROTO,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
|
BadNameAPIError,
|
||||||
InvalidAuthAPIError,
|
InvalidAuthAPIError,
|
||||||
PingFailedAPIError,
|
PingFailedAPIError,
|
||||||
ProtocolAPIError,
|
ProtocolAPIError,
|
||||||
@ -55,6 +56,7 @@ class ConnectionParams:
|
|||||||
keepalive: float
|
keepalive: float
|
||||||
zeroconf_instance: hr.ZeroconfInstanceType
|
zeroconf_instance: hr.ZeroconfInstanceType
|
||||||
noise_psk: Optional[str]
|
noise_psk: Optional[str]
|
||||||
|
expected_name: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
class ConnectionState(enum.Enum):
|
class ConnectionState(enum.Enum):
|
||||||
@ -174,7 +176,7 @@ class APIConnection:
|
|||||||
fh = self._frame_helper = APINoiseFrameHelper(
|
fh = self._frame_helper = APINoiseFrameHelper(
|
||||||
reader, writer, self._params.noise_psk
|
reader, writer, self._params.noise_psk
|
||||||
)
|
)
|
||||||
await fh.perform_handshake()
|
await fh.perform_handshake(self._params.expected_name)
|
||||||
|
|
||||||
self._connection_state = ConnectionState.SOCKET_OPENED
|
self._connection_state = ConnectionState.SOCKET_OPENED
|
||||||
|
|
||||||
@ -206,6 +208,15 @@ class APIConnection:
|
|||||||
)
|
)
|
||||||
raise APIConnectionError("Incompatible API version.")
|
raise APIConnectionError("Incompatible API version.")
|
||||||
|
|
||||||
|
if (
|
||||||
|
self._params.expected_name is not None
|
||||||
|
and resp.name != ""
|
||||||
|
and resp.name != self._params.expected_name
|
||||||
|
):
|
||||||
|
raise BadNameAPIError(
|
||||||
|
f"Server sent a different name '{resp.name}'", resp.name
|
||||||
|
)
|
||||||
|
|
||||||
self._connection_state = ConnectionState.CONNECTED
|
self._connection_state = ConnectionState.CONNECTED
|
||||||
|
|
||||||
async def _connect_start_ping(self) -> None:
|
async def _connect_start_ping(self) -> None:
|
||||||
|
@ -96,6 +96,14 @@ class HandshakeAPIError(APIConnectionError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BadNameAPIError(APIConnectionError):
|
||||||
|
"""Raised when a name received from the remote but does not much the expected name."""
|
||||||
|
|
||||||
|
def __init__(self, msg: str, received_name: str) -> None:
|
||||||
|
super().__init__(msg)
|
||||||
|
self.received_name = received_name
|
||||||
|
|
||||||
|
|
||||||
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
|
class InvalidEncryptionKeyAPIError(HandshakeAPIError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -20,6 +20,7 @@ def connection_params() -> ConnectionParams:
|
|||||||
keepalive=15.0,
|
keepalive=15.0,
|
||||||
zeroconf_instance=None,
|
zeroconf_instance=None,
|
||||||
noise_psk=None,
|
noise_psk=None,
|
||||||
|
expected_name=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user