From e5051eefbcf1fecf2f02904c40b862bf899ec8f4 Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Wed, 8 Sep 2021 23:22:47 +0200 Subject: [PATCH] API encryption (#2254) --- esphome/components/api/__init__.py | 33 ++ esphome/components/api/api_connection.cpp | 6 + esphome/components/api/api_frame_helper.cpp | 612 ++++++++++++++++++++ esphome/components/api/api_frame_helper.h | 76 +++ esphome/components/api/api_noise_context.h | 23 + esphome/components/api/api_server.h | 10 + esphome/core/defines.h | 3 + esphome/core/helpers.cpp | 9 + esphome/core/helpers.h | 2 + platformio.ini | 1 + tests/test3.yaml | 2 + 11 files changed, 777 insertions(+) create mode 100644 esphome/components/api/api_noise_context.h diff --git a/esphome/components/api/__init__.py b/esphome/components/api/__init__.py index fc140dc7d2..3705f0d7ca 100644 --- a/esphome/components/api/__init__.py +++ b/esphome/components/api/__init__.py @@ -1,3 +1,5 @@ +import base64 + import esphome.codegen as cg import esphome.config_validation as cv from esphome import automation @@ -6,6 +8,7 @@ from esphome.const import ( CONF_DATA, CONF_DATA_TEMPLATE, CONF_ID, + CONF_KEY, CONF_PASSWORD, CONF_PORT, CONF_REBOOT_TIMEOUT, @@ -41,6 +44,22 @@ SERVICE_ARG_NATIVE_TYPES = { "float[]": cg.std_vector.template(float), "string[]": cg.std_vector.template(cg.std_string), } +CONF_ENCRYPTION = "encryption" + + +def validate_encryption_key(value): + value = cv.string_strict(value) + try: + decoded = base64.b64decode(value, validate=True) + except ValueError as err: + raise cv.Invalid("Invalid key format, please check it's using base64") from err + + if len(decoded) != 32: + raise cv.Invalid("Encryption key must be base64 and 32 bytes long") + + # Return original data for roundtrip conversion + return value + CONFIG_SCHEMA = cv.Schema( { @@ -63,6 +82,11 @@ CONFIG_SCHEMA = cv.Schema( ), } ), + cv.Optional(CONF_ENCRYPTION): cv.Schema( + { + cv.Required(CONF_KEY): validate_encryption_key, + } + ), } ).extend(cv.COMPONENT_SCHEMA) @@ -92,6 +116,15 @@ async def to_code(config): cg.add(var.register_user_service(trigger)) await automation.build_automation(trigger, func_args, conf) + if CONF_ENCRYPTION in config: + conf = config[CONF_ENCRYPTION] + decoded = base64.b64decode(conf[CONF_KEY]) + cg.add(var.set_noise_psk(list(decoded))) + cg.add_define("USE_API_NOISE") + cg.add_library("esphome/noise-c", "0.1.1") + else: + cg.add_define("USE_API_PLAINTEXT") + cg.add_define("USE_API") cg.add_global(api_ns.using) diff --git a/esphome/components/api/api_connection.cpp b/esphome/components/api/api_connection.cpp index bce0b0bab8..650f4f6f6e 100644 --- a/esphome/components/api/api_connection.cpp +++ b/esphome/components/api/api_connection.cpp @@ -23,7 +23,13 @@ APIConnection::APIConnection(std::unique_ptr sock, APIServer *pa : parent_(parent), initial_state_iterator_(parent, this), list_entities_iterator_(parent, this) { this->proto_write_buffer_.reserve(64); +#if defined(USE_API_PLAINTEXT) helper_ = std::unique_ptr{new APIPlaintextFrameHelper(std::move(sock))}; +#elif defined(USE_API_NOISE) + helper_ = std::unique_ptr{new APINoiseFrameHelper(std::move(sock), parent->get_noise_ctx())}; +#else +#error "No frame helper defined" +#endif } void APIConnection::start() { this->last_traffic_ = millis(); diff --git a/esphome/components/api/api_frame_helper.cpp b/esphome/components/api/api_frame_helper.cpp index f903ab8656..26fbf1269f 100644 --- a/esphome/components/api/api_frame_helper.cpp +++ b/esphome/components/api/api_frame_helper.cpp @@ -19,6 +19,617 @@ bool is_would_block(ssize_t ret) { #define HELPER_LOG(msg, ...) ESP_LOGVV(TAG, "%s: " msg, info_.c_str(), ##__VA_ARGS__) +#ifdef USE_API_NOISE +static const char *const PROLOGUE_INIT = "NoiseAPIInit"; + +/// Convert a noise error code to a readable error +std::string noise_err_to_str(int err) { + if (err == NOISE_ERROR_NO_MEMORY) + return "NO_MEMORY"; + if (err == NOISE_ERROR_UNKNOWN_ID) + return "UNKNOWN_ID"; + if (err == NOISE_ERROR_UNKNOWN_NAME) + return "UNKNOWN_NAME"; + if (err == NOISE_ERROR_MAC_FAILURE) + return "MAC_FAILURE"; + if (err == NOISE_ERROR_NOT_APPLICABLE) + return "NOT_APPLICABLE"; + if (err == NOISE_ERROR_SYSTEM) + return "SYSTEM"; + if (err == NOISE_ERROR_REMOTE_KEY_REQUIRED) + return "REMOTE_KEY_REQUIRED"; + if (err == NOISE_ERROR_LOCAL_KEY_REQUIRED) + return "LOCAL_KEY_REQUIRED"; + if (err == NOISE_ERROR_PSK_REQUIRED) + return "PSK_REQUIRED"; + if (err == NOISE_ERROR_INVALID_LENGTH) + return "INVALID_LENGTH"; + if (err == NOISE_ERROR_INVALID_PARAM) + return "INVALID_PARAM"; + if (err == NOISE_ERROR_INVALID_STATE) + return "INVALID_STATE"; + if (err == NOISE_ERROR_INVALID_NONCE) + return "INVALID_NONCE"; + if (err == NOISE_ERROR_INVALID_PRIVATE_KEY) + return "INVALID_PRIVATE_KEY"; + if (err == NOISE_ERROR_INVALID_PUBLIC_KEY) + return "INVALID_PUBLIC_KEY"; + if (err == NOISE_ERROR_INVALID_FORMAT) + return "INVALID_FORMAT"; + if (err == NOISE_ERROR_INVALID_SIGNATURE) + return "INVALID_SIGNATURE"; + return to_string(err); +} + +/// Initialize the frame helper, returns OK if successful. +APIError APINoiseFrameHelper::init() { + if (state_ != State::INITIALIZE || socket_ == nullptr) { + HELPER_LOG("Bad state for init %d", (int) state_); + return APIError::BAD_STATE; + } + int err = socket_->setblocking(false); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("Setting nonblocking failed with errno %d", errno); + return APIError::TCP_NONBLOCKING_FAILED; + } + int enable = 1; + err = socket_->setsockopt(IPPROTO_TCP, TCP_NODELAY, &enable, sizeof(int)); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("Setting nodelay failed with errno %d", errno); + return APIError::TCP_NODELAY_FAILED; + } + + // init prologue + prologue_.insert(prologue_.end(), PROLOGUE_INIT, PROLOGUE_INIT + strlen(PROLOGUE_INIT)); + + state_ = State::CLIENT_HELLO; + return APIError::OK; +} +/// Run through handshake messages (if in that phase) +APIError APINoiseFrameHelper::loop() { + APIError err = state_action_(); + if (err == APIError::WOULD_BLOCK) + return APIError::OK; + if (err != APIError::OK) + return err; + if (!tx_buf_.empty()) { + err = try_send_tx_buf_(); + if (err != APIError::OK) { + return err; + } + } + return APIError::OK; +} + +/** Read a packet into the rx_buf_. If successful, stores frame data in the frame parameter + * + * @param frame: The struct to hold the frame information in. + * msg_start: points to the start of the payload - this pointer is only valid until the next + * try_receive_raw_ call + * + * @return 0 if a full packet is in rx_buf_ + * @return -1 if error, check errno. + * + * errno EWOULDBLOCK: Packet could not be read without blocking. Try again later. + * errno ENOMEM: Not enough memory for reading packet. + * errno API_ERROR_BAD_INDICATOR: Bad indicator byte at start of frame. + * errno API_ERROR_HANDSHAKE_PACKET_LEN: Packet too big for this phase. + */ +APIError APINoiseFrameHelper::try_read_frame_(ParsedFrame *frame) { + int err; + APIError aerr; + + if (frame == nullptr) { + HELPER_LOG("Bad argument for try_read_frame_"); + return APIError::BAD_ARG; + } + + // read header + if (rx_header_buf_len_ < 3) { + // no header information yet + size_t to_read = 3 - rx_header_buf_len_; + ssize_t received = socket_->read(&rx_header_buf_[rx_header_buf_len_], to_read); + if (is_would_block(received)) { + return APIError::WOULD_BLOCK; + } else if (received == -1) { + state_ = State::FAILED; + HELPER_LOG("Socket read failed with errno %d", errno); + return APIError::SOCKET_READ_FAILED; + } + rx_header_buf_len_ += received; + if (received != to_read) { + // not a full read + return APIError::WOULD_BLOCK; + } + + // header reading done + } + + // read body + uint8_t indicator = rx_header_buf_[0]; + if (indicator != 0x01) { + state_ = State::FAILED; + HELPER_LOG("Bad indicator byte %u", indicator); + return APIError::BAD_INDICATOR; + } + + uint16_t msg_size = (((uint16_t) rx_header_buf_[1]) << 8) | rx_header_buf_[2]; + + if (state_ != State::DATA && msg_size > 128) { + // for handshake message only permit up to 128 bytes + state_ = State::FAILED; + HELPER_LOG("Bad packet len for handshake: %d", msg_size); + return APIError::BAD_HANDSHAKE_PACKET_LEN; + } + + // reserve space for body + if (rx_buf_.size() != msg_size) { + rx_buf_.resize(msg_size); + } + + if (rx_buf_len_ < msg_size) { + // more data to read + size_t to_read = msg_size - rx_buf_len_; + ssize_t received = socket_->read(&rx_buf_[rx_buf_len_], to_read); + if (is_would_block(received)) { + return APIError::WOULD_BLOCK; + } else if (received == -1) { + state_ = State::FAILED; + HELPER_LOG("Socket read failed with errno %d", errno); + return APIError::SOCKET_READ_FAILED; + } + rx_buf_len_ += received; + if (received != to_read) { + // not all read + return APIError::WOULD_BLOCK; + } + } + + // uncomment for even more debugging + // ESP_LOGVV(TAG, "Received frame: %s", hexencode(rx_buf_).c_str()); + frame->msg = std::move(rx_buf_); + // consume msg + rx_buf_ = {}; + rx_buf_len_ = 0; + rx_header_buf_len_ = 0; + return APIError::OK; +} + +/** To be called from read/write methods. + * + * This method runs through the internal handshake methods, if in that state. + * + * If the handshake is still active when this method returns and a read/write can't take place at + * the moment, returns WOULD_BLOCK. + * If an error occured, returns that error. Only returns OK if the transport is ready for data + * traffic. + */ +APIError APINoiseFrameHelper::state_action_() { + int err; + APIError aerr; + if (state_ == State::INITIALIZE) { + HELPER_LOG("Bad state for method: %d", (int) state_); + return APIError::BAD_STATE; + } + if (state_ == State::CLIENT_HELLO) { + // waiting for client hello + ParsedFrame frame; + aerr = try_read_frame_(&frame); + if (aerr != APIError::OK) + return aerr; + // ignore contents, may be used in future for flags + prologue_.push_back((uint8_t)(frame.msg.size() >> 8)); + prologue_.push_back((uint8_t) frame.msg.size()); + prologue_.insert(prologue_.end(), frame.msg.begin(), frame.msg.end()); + + state_ = State::SERVER_HELLO; + } + if (state_ == State::SERVER_HELLO) { + // send server hello + uint8_t msg[1]; + msg[0] = 0x01; // chosen proto + aerr = write_frame_(msg, 1); + if (aerr != APIError::OK) + return aerr; + + // start handshake + aerr = init_handshake_(); + if (aerr != APIError::OK) + return aerr; + + state_ = State::HANDSHAKE; + } + if (state_ == State::HANDSHAKE) { + int action = noise_handshakestate_get_action(handshake_); + if (action == NOISE_ACTION_READ_MESSAGE) { + // waiting for handshake msg + ParsedFrame frame; + aerr = try_read_frame_(&frame); + if (aerr == APIError::BAD_INDICATOR) { + send_explicit_handshake_reject_("Bad indicator byte"); + return aerr; + } + if (aerr == APIError::BAD_HANDSHAKE_PACKET_LEN) { + send_explicit_handshake_reject_("Bad handshake packet len"); + return aerr; + } + if (aerr != APIError::OK) + return aerr; + + if (frame.msg.empty()) { + send_explicit_handshake_reject_("Empty handshake message"); + return APIError::BAD_HANDSHAKE_PACKET_LEN; + } else if (frame.msg[0] != 0x00) { + HELPER_LOG("Bad handshake error byte: %u", frame.msg[0]); + send_explicit_handshake_reject_("Bad handshake error byte"); + return APIError::BAD_HANDSHAKE_PACKET_LEN; + } + + NoiseBuffer mbuf; + noise_buffer_init(mbuf); + noise_buffer_set_input(mbuf, frame.msg.data() + 1, frame.msg.size() - 1); + err = noise_handshakestate_read_message(handshake_, &mbuf, nullptr); + if (err != 0) { + // TODO: explicit rejection + state_ = State::FAILED; + HELPER_LOG("noise_handshakestate_read_message failed: %s", noise_err_to_str(err).c_str()); + if (err == NOISE_ERROR_MAC_FAILURE) { + send_explicit_handshake_reject_("Handshake MAC failure"); + } else { + send_explicit_handshake_reject_("Handshake error"); + } + return APIError::HANDSHAKESTATE_READ_FAILED; + } + + aerr = check_handshake_finished_(); + if (aerr != APIError::OK) + return aerr; + } else if (action == NOISE_ACTION_WRITE_MESSAGE) { + uint8_t buffer[65]; + NoiseBuffer mbuf; + noise_buffer_init(mbuf); + noise_buffer_set_output(mbuf, buffer + 1, sizeof(buffer) - 1); + + err = noise_handshakestate_write_message(handshake_, &mbuf, nullptr); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("noise_handshakestate_write_message failed: %s", noise_err_to_str(err).c_str()); + return APIError::HANDSHAKESTATE_WRITE_FAILED; + } + buffer[0] = 0x00; // success + + aerr = write_frame_(buffer, mbuf.size + 1); + if (aerr != APIError::OK) + return aerr; + aerr = check_handshake_finished_(); + if (aerr != APIError::OK) + return aerr; + } else { + // bad state for action + state_ = State::FAILED; + HELPER_LOG("Bad action for handshake: %d", action); + return APIError::HANDSHAKESTATE_BAD_STATE; + } + } + if (state_ == State::CLOSED || state_ == State::FAILED) { + return APIError::BAD_STATE; + } + return APIError::OK; +} +void APINoiseFrameHelper::send_explicit_handshake_reject_(const std::string &reason) { + std::vector data; + data.reserve(reason.size() + 1); + data[0] = 0x01; // failure + for (size_t i = 0; i < reason.size(); i++) { + data[i + 1] = (uint8_t) reason[i]; + } + write_frame_(data.data(), data.size()); +} + +APIError APINoiseFrameHelper::read_packet(ReadPacketBuffer *buffer) { + int err; + APIError aerr; + aerr = state_action_(); + if (aerr != APIError::OK) { + return aerr; + } + + if (state_ != State::DATA) { + return APIError::WOULD_BLOCK; + } + + ParsedFrame frame; + aerr = try_read_frame_(&frame); + if (aerr != APIError::OK) + return aerr; + + NoiseBuffer mbuf; + noise_buffer_init(mbuf); + noise_buffer_set_inout(mbuf, frame.msg.data(), frame.msg.size(), frame.msg.size()); + err = noise_cipherstate_decrypt(recv_cipher_, &mbuf); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("noise_cipherstate_decrypt failed: %s", noise_err_to_str(err).c_str()); + return APIError::CIPHERSTATE_DECRYPT_FAILED; + } + + size_t msg_size = mbuf.size; + uint8_t *msg_data = frame.msg.data(); + if (msg_size < 4) { + state_ = State::FAILED; + HELPER_LOG("Bad data packet: size %d too short", msg_size); + return APIError::BAD_DATA_PACKET; + } + + // uint16_t type; + // uint16_t data_len; + // uint8_t *data; + // uint8_t *padding; zero or more bytes to fill up the rest of the packet + uint16_t type = (((uint16_t) msg_data[0]) << 8) | msg_data[1]; + uint16_t data_len = (((uint16_t) msg_data[2]) << 8) | msg_data[3]; + if (data_len > msg_size - 4) { + state_ = State::FAILED; + HELPER_LOG("Bad data packet: data_len %u greater than msg_size %u", data_len, msg_size); + return APIError::BAD_DATA_PACKET; + } + + buffer->container = std::move(frame.msg); + buffer->data_offset = 4; + buffer->data_len = data_len; + buffer->type = type; + return APIError::OK; +} +bool APINoiseFrameHelper::can_write_without_blocking() { return state_ == State::DATA && tx_buf_.empty(); } +APIError APINoiseFrameHelper::write_packet(uint16_t type, const uint8_t *payload, size_t payload_len) { + int err; + APIError aerr; + aerr = state_action_(); + if (aerr != APIError::OK) { + return aerr; + } + + if (state_ != State::DATA) { + return APIError::WOULD_BLOCK; + } + + size_t padding = 0; + size_t msg_len = 4 + payload_len + padding; + size_t frame_len = 3 + msg_len + noise_cipherstate_get_mac_length(send_cipher_); + auto tmpbuf = std::unique_ptr{new (std::nothrow) uint8_t[frame_len]}; + if (tmpbuf == nullptr) { + HELPER_LOG("Could not allocate for writing packet"); + return APIError::OUT_OF_MEMORY; + } + + tmpbuf[0] = 0x01; // indicator + // tmpbuf[1], tmpbuf[2] to be set later + const uint8_t msg_offset = 3; + const uint8_t payload_offset = msg_offset + 4; + tmpbuf[msg_offset + 0] = (uint8_t)(type >> 8); // type + tmpbuf[msg_offset + 1] = (uint8_t) type; + tmpbuf[msg_offset + 2] = (uint8_t)(payload_len >> 8); // data_len + tmpbuf[msg_offset + 3] = (uint8_t) payload_len; + // copy data + std::copy(payload, payload + payload_len, &tmpbuf[payload_offset]); + // fill padding with zeros + std::fill(&tmpbuf[payload_offset + payload_len], &tmpbuf[frame_len], 0); + + NoiseBuffer mbuf; + noise_buffer_init(mbuf); + noise_buffer_set_inout(mbuf, &tmpbuf[msg_offset], msg_len, frame_len - msg_offset); + err = noise_cipherstate_encrypt(send_cipher_, &mbuf); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("noise_cipherstate_encrypt failed: %s", noise_err_to_str(err).c_str()); + return APIError::CIPHERSTATE_ENCRYPT_FAILED; + } + + size_t total_len = 3 + mbuf.size; + tmpbuf[1] = (uint8_t)(mbuf.size >> 8); + tmpbuf[2] = (uint8_t) mbuf.size; + // write raw to not have two packets sent if NAGLE disabled + aerr = write_raw_(&tmpbuf[0], total_len); + if (aerr != APIError::OK) { + return aerr; + } + return APIError::OK; +} +APIError APINoiseFrameHelper::try_send_tx_buf_() { + // try send from tx_buf + while (state_ != State::CLOSED && !tx_buf_.empty()) { + ssize_t sent = socket_->write(tx_buf_.data(), tx_buf_.size()); + if (sent == -1) { + if (errno == EWOULDBLOCK || errno == EAGAIN) + break; + state_ = State::FAILED; + HELPER_LOG("Socket write failed with errno %d", errno); + return APIError::SOCKET_WRITE_FAILED; + } else if (sent == 0) { + break; + } + // TODO: inefficient if multiple packets in txbuf + // replace with deque of buffers + tx_buf_.erase(tx_buf_.begin(), tx_buf_.begin() + sent); + } + + return APIError::OK; +} +/** Write the data to the socket, or buffer it a write would block + * + * @param data The data to write + * @param len The length of data + */ +APIError APINoiseFrameHelper::write_raw_(const uint8_t *data, size_t len) { + if (len == 0) + return APIError::OK; + int err; + APIError aerr; + + // uncomment for even more debugging + // ESP_LOGVV(TAG, "Sending raw: %s", hexencode(data, len).c_str()); + + if (!tx_buf_.empty()) { + // try to empty tx_buf_ first + aerr = try_send_tx_buf_(); + if (aerr != APIError::OK && aerr != APIError::WOULD_BLOCK) + return aerr; + } + + if (!tx_buf_.empty()) { + // tx buf not empty, can't write now because then stream would be inconsistent + tx_buf_.insert(tx_buf_.end(), data, data + len); + return APIError::OK; + } + + ssize_t sent = socket_->write(data, len); + if (is_would_block(sent)) { + // operation would block, add buffer to tx_buf + tx_buf_.insert(tx_buf_.end(), data, data + len); + return APIError::OK; + } else if (sent == -1) { + // an error occured + state_ = State::FAILED; + HELPER_LOG("Socket write failed with errno %d", errno); + return APIError::SOCKET_WRITE_FAILED; + } else if (sent != len) { + // partially sent, add end to tx_buf + tx_buf_.insert(tx_buf_.end(), data + sent, data + len); + return APIError::OK; + } + // fully sent + return APIError::OK; +} +APIError APINoiseFrameHelper::write_frame_(const uint8_t *data, size_t len) { + APIError aerr; + + uint8_t header[3]; + header[0] = 0x01; // indicator + header[1] = (uint8_t)(len >> 8); + header[2] = (uint8_t) len; + + aerr = write_raw_(header, 3); + if (aerr != APIError::OK) + return aerr; + aerr = write_raw_(data, len); + return aerr; +} + +/** Initiate the data structures for the handshake. + * + * @return 0 on success, -1 on error (check errno) + */ +APIError APINoiseFrameHelper::init_handshake_() { + int err; + memset(&nid_, 0, sizeof(nid_)); + // const char *proto = "Noise_NNpsk0_25519_ChaChaPoly_SHA256"; + // err = noise_protocol_name_to_id(&nid_, proto, strlen(proto)); + nid_.pattern_id = NOISE_PATTERN_NN; + nid_.cipher_id = NOISE_CIPHER_CHACHAPOLY; + nid_.dh_id = NOISE_DH_CURVE25519; + nid_.prefix_id = NOISE_PREFIX_STANDARD; + nid_.hybrid_id = NOISE_DH_NONE; + nid_.hash_id = NOISE_HASH_SHA256; + nid_.modifier_ids[0] = NOISE_MODIFIER_PSK0; + + err = noise_handshakestate_new_by_id(&handshake_, &nid_, NOISE_ROLE_RESPONDER); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("noise_handshakestate_new_by_id failed: %s", noise_err_to_str(err).c_str()); + return APIError::HANDSHAKESTATE_SETUP_FAILED; + } + + const auto &psk = ctx_->get_psk(); + err = noise_handshakestate_set_pre_shared_key(handshake_, psk.data(), psk.size()); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("noise_handshakestate_set_pre_shared_key failed: %s", noise_err_to_str(err).c_str()); + return APIError::HANDSHAKESTATE_SETUP_FAILED; + } + + err = noise_handshakestate_set_prologue(handshake_, prologue_.data(), prologue_.size()); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("noise_handshakestate_set_prologue failed: %s", noise_err_to_str(err).c_str()); + return APIError::HANDSHAKESTATE_SETUP_FAILED; + } + // set_prologue copies it into handshakestate, so we can get rid of it now + prologue_ = {}; + + err = noise_handshakestate_start(handshake_); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("noise_handshakestate_start failed: %s", noise_err_to_str(err).c_str()); + return APIError::HANDSHAKESTATE_SETUP_FAILED; + } + return APIError::OK; +} + +APIError APINoiseFrameHelper::check_handshake_finished_() { + assert(state_ == State::HANDSHAKE); + + int action = noise_handshakestate_get_action(handshake_); + if (action == NOISE_ACTION_READ_MESSAGE || action == NOISE_ACTION_WRITE_MESSAGE) + return APIError::OK; + if (action != NOISE_ACTION_SPLIT) { + state_ = State::FAILED; + HELPER_LOG("Bad action for handshake: %d", action); + return APIError::HANDSHAKESTATE_BAD_STATE; + } + int err = noise_handshakestate_split(handshake_, &send_cipher_, &recv_cipher_); + if (err != 0) { + state_ = State::FAILED; + HELPER_LOG("noise_handshakestate_split failed: %s", noise_err_to_str(err).c_str()); + return APIError::HANDSHAKESTATE_SPLIT_FAILED; + } + + HELPER_LOG("Handshake complete!"); + noise_handshakestate_free(handshake_); + handshake_ = nullptr; + state_ = State::DATA; + return APIError::OK; +} + +APINoiseFrameHelper::~APINoiseFrameHelper() { + if (handshake_ != nullptr) { + noise_handshakestate_free(handshake_); + handshake_ = nullptr; + } + if (send_cipher_ != nullptr) { + noise_cipherstate_free(send_cipher_); + send_cipher_ = nullptr; + } + if (recv_cipher_ != nullptr) { + noise_cipherstate_free(recv_cipher_); + recv_cipher_ = nullptr; + } +} + +APIError APINoiseFrameHelper::close() { + state_ = State::CLOSED; + int err = socket_->close(); + if (err == -1) + return APIError::CLOSE_FAILED; + return APIError::OK; +} +APIError APINoiseFrameHelper::shutdown(int how) { + int err = socket_->shutdown(how); + if (err == -1) + return APIError::SHUTDOWN_FAILED; + if (how == SHUT_RDWR) { + state_ = State::CLOSED; + } + return APIError::OK; +} +extern "C" { +// declare how noise generates random bytes (here with a good HWRNG based on the RF system) +void noise_rand_bytes(void *output, size_t len) { esphome::fill_random(reinterpret_cast(output), len); } +} +#endif // USE_API_NOISE + +#ifdef USE_API_PLAINTEXT + /// Initialize the frame helper, returns OK if successful. APIError APIPlaintextFrameHelper::init() { if (state_ != State::INITIALIZE || socket_ == nullptr) { @@ -289,6 +900,7 @@ APIError APIPlaintextFrameHelper::shutdown(int how) { } return APIError::OK; } +#endif // USE_API_PLAINTEXT } // namespace api } // namespace esphome diff --git a/esphome/components/api/api_frame_helper.h b/esphome/components/api/api_frame_helper.h index 14a0760c25..7189bc4b4b 100644 --- a/esphome/components/api/api_frame_helper.h +++ b/esphome/components/api/api_frame_helper.h @@ -5,7 +5,12 @@ #include "esphome/core/defines.h" +#ifdef USE_API_NOISE +#include "noise/protocol.h" +#endif + #include "esphome/components/socket/socket.h" +#include "api_noise_context.h" namespace esphome { namespace api { @@ -27,6 +32,7 @@ struct PacketBuffer { enum class APIError : int { OK = 0, WOULD_BLOCK = 1001, + BAD_HANDSHAKE_PACKET_LEN = 1002, BAD_INDICATOR = 1003, BAD_DATA_PACKET = 1004, TCP_NODELAY_FAILED = 1005, @@ -37,7 +43,14 @@ enum class APIError : int { BAD_ARG = 1010, SOCKET_READ_FAILED = 1011, SOCKET_WRITE_FAILED = 1012, + HANDSHAKESTATE_READ_FAILED = 1013, + HANDSHAKESTATE_WRITE_FAILED = 1014, + HANDSHAKESTATE_BAD_STATE = 1015, + CIPHERSTATE_DECRYPT_FAILED = 1016, + CIPHERSTATE_ENCRYPT_FAILED = 1017, OUT_OF_MEMORY = 1018, + HANDSHAKESTATE_SETUP_FAILED = 1019, + HANDSHAKESTATE_SPLIT_FAILED = 1020, }; class APIFrameHelper { @@ -53,6 +66,68 @@ class APIFrameHelper { // Give this helper a name for logging virtual void set_log_info(std::string info) = 0; }; + +#ifdef USE_API_NOISE +class APINoiseFrameHelper : public APIFrameHelper { + public: + APINoiseFrameHelper(std::unique_ptr socket, std::shared_ptr ctx) + : socket_(std::move(socket)), ctx_(ctx) {} + ~APINoiseFrameHelper(); + APIError init() override; + APIError loop() override; + APIError read_packet(ReadPacketBuffer *buffer) override; + bool can_write_without_blocking() override; + APIError write_packet(uint16_t type, const uint8_t *payload, size_t len) override; + std::string getpeername() override { return socket_->getpeername(); } + APIError close() override; + APIError shutdown(int how) override; + // Give this helper a name for logging + void set_log_info(std::string info) override { info_ = std::move(info); } + + protected: + struct ParsedFrame { + std::vector msg; + }; + + APIError state_action_(); + APIError try_read_frame_(ParsedFrame *frame); + APIError try_send_tx_buf_(); + APIError write_frame_(const uint8_t *data, size_t len); + APIError write_raw_(const uint8_t *data, size_t len); + APIError init_handshake_(); + APIError check_handshake_finished_(); + void send_explicit_handshake_reject_(const std::string &reason); + + std::unique_ptr socket_; + + std::string info_; + uint8_t rx_header_buf_[3]; + size_t rx_header_buf_len_ = 0; + std::vector rx_buf_; + size_t rx_buf_len_ = 0; + + std::vector tx_buf_; + std::vector prologue_; + + std::shared_ptr ctx_; + NoiseHandshakeState *handshake_ = nullptr; + NoiseCipherState *send_cipher_ = nullptr; + NoiseCipherState *recv_cipher_ = nullptr; + NoiseProtocolId nid_; + + enum class State { + INITIALIZE = 1, + CLIENT_HELLO = 2, + SERVER_HELLO = 3, + HANDSHAKE = 4, + DATA = 5, + CLOSED = 6, + FAILED = 7, + } state_ = State::INITIALIZE; +}; +#endif // USE_API_NOISE + +#ifdef USE_API_PLAINTEXT class APIPlaintextFrameHelper : public APIFrameHelper { public: APIPlaintextFrameHelper(std::unique_ptr socket) : socket_(std::move(socket)) {} @@ -98,6 +173,7 @@ class APIPlaintextFrameHelper : public APIFrameHelper { FAILED = 4, } state_ = State::INITIALIZE; }; +#endif } // namespace api } // namespace esphome diff --git a/esphome/components/api/api_noise_context.h b/esphome/components/api/api_noise_context.h new file mode 100644 index 0000000000..fba6b65a26 --- /dev/null +++ b/esphome/components/api/api_noise_context.h @@ -0,0 +1,23 @@ +#pragma once +#include +#include +#include "esphome/core/defines.h" + +namespace esphome { +namespace api { + +#ifdef USE_API_NOISE +using psk_t = std::array; + +class APINoiseContext { + public: + void set_psk(psk_t psk) { psk_ = std::move(psk); } + const psk_t &get_psk() const { return psk_; } + + protected: + psk_t psk_; +}; +#endif // USE_API_NOISE + +} // namespace api +} // namespace esphome diff --git a/esphome/components/api/api_server.h b/esphome/components/api/api_server.h index 7c42fe7dd5..e3fa6b18c9 100644 --- a/esphome/components/api/api_server.h +++ b/esphome/components/api/api_server.h @@ -11,6 +11,7 @@ #include "list_entities.h" #include "subscribe_state.h" #include "user_services.h" +#include "api_noise_context.h" namespace esphome { namespace api { @@ -30,6 +31,11 @@ class APIServer : public Component, public Controller { void set_password(const std::string &password); void set_reboot_timeout(uint32_t reboot_timeout); +#ifdef USE_API_NOISE + void set_noise_psk(psk_t psk) { noise_ctx_->set_psk(std::move(psk)); } + std::shared_ptr get_noise_ctx() { return noise_ctx_; } +#endif // USE_API_NOISE + void handle_disconnect(APIConnection *conn); #ifdef USE_BINARY_SENSOR void on_binary_sensor_update(binary_sensor::BinarySensor *obj, bool state) override; @@ -89,6 +95,10 @@ class APIServer : public Component, public Controller { std::string password_; std::vector state_subs_; std::vector user_services_; + +#ifdef USE_API_NOISE + std::shared_ptr noise_ctx_ = std::make_shared(); +#endif // USE_API_NOISE }; extern APIServer *global_api_server; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) diff --git a/esphome/core/defines.h b/esphome/core/defines.h index d73f7e9d00..3cca6445b5 100644 --- a/esphome/core/defines.h +++ b/esphome/core/defines.h @@ -54,5 +54,8 @@ #define USE_SOCKET_IMPL_BSD_SOCKETS #endif +#define USE_API_PLAINTEXT +#define USE_API_NOISE + // Disabled feature flags //#define USE_BSEC // Requires a library with proprietary license. diff --git a/esphome/core/helpers.cpp b/esphome/core/helpers.cpp index 9e9c775899..c5ff0102c3 100644 --- a/esphome/core/helpers.cpp +++ b/esphome/core/helpers.cpp @@ -55,6 +55,15 @@ double random_double() { return random_uint32() / double(UINT32_MAX); } float random_float() { return float(random_double()); } +void fill_random(uint8_t *data, size_t len) { +#ifdef ARDUINO_ARCH_ESP32 + esp_fill_random(data, len); +#else + int err = os_get_random(data, len); + assert(err == 0); +#endif +} + static uint32_t fast_random_seed = 0; // NOLINT(cppcoreguidelines-avoid-non-const-global-variables) void fast_random_set_seed(uint32_t seed) { fast_random_seed = seed; } diff --git a/esphome/core/helpers.h b/esphome/core/helpers.h index 5868918cd6..60bc7a9ad3 100644 --- a/esphome/core/helpers.h +++ b/esphome/core/helpers.h @@ -109,6 +109,8 @@ double random_double(); /// Returns a random float between 0 and 1. Essentially just casts random_double() to a float. float random_float(); +void fill_random(uint8_t *data, size_t len); + void fast_random_set_seed(uint32_t seed); uint32_t fast_random_32(); uint16_t fast_random_16(); diff --git a/platformio.ini b/platformio.ini index 88b1000d1d..f4dea3fcb9 100644 --- a/platformio.ini +++ b/platformio.ini @@ -36,6 +36,7 @@ lib_deps = 6306@1.0.3 ; HM3301 glmnet/Dsmr@0.3 ; used by dsmr rweather/Crypto@0.2.0 ; used by dsmr + esphome/noise-c@0.1.1 ; used by api dudanov/MideaUART@1.1.0 ; used by midea build_flags = diff --git a/tests/test3.yaml b/tests/test3.yaml index c012871125..5602481c36 100644 --- a/tests/test3.yaml +++ b/tests/test3.yaml @@ -22,6 +22,8 @@ api: port: 8000 password: 'pwd' reboot_timeout: 0min + encryption: + key: 'bOFFzzvfpg5DB94DuBGLXD/hMnhpDKgP9UQyBulwWVU=' services: - service: hello_world variables: