Improv serial/checksum changes (#2731)

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Jesse Hills 2021-11-16 11:02:45 +13:00
parent 024632dbd0
commit b0a0a153f3
No known key found for this signature in database
GPG Key ID: BEAAE804EFD8E83A
3 changed files with 61 additions and 37 deletions

View File

@ -2,30 +2,32 @@
namespace improv { namespace improv {
ImprovCommand parse_improv_data(const std::vector<uint8_t> &data) { ImprovCommand parse_improv_data(const std::vector<uint8_t> &data, bool check_checksum) {
return parse_improv_data(data.data(), data.size()); return parse_improv_data(data.data(), data.size(), check_checksum);
} }
ImprovCommand parse_improv_data(const uint8_t *data, size_t length) { ImprovCommand parse_improv_data(const uint8_t *data, size_t length, bool check_checksum) {
ImprovCommand improv_command; ImprovCommand improv_command;
Command command = (Command) data[0]; Command command = (Command) data[0];
uint8_t data_length = data[1]; uint8_t data_length = data[1];
if (data_length != length - 3) { if (data_length != length - 2 - check_checksum) {
improv_command.command = UNKNOWN; improv_command.command = UNKNOWN;
return improv_command; return improv_command;
} }
uint8_t checksum = data[length - 1]; if (check_checksum) {
uint8_t checksum = data[length - 1];
uint32_t calculated_checksum = 0; uint32_t calculated_checksum = 0;
for (uint8_t i = 0; i < length - 1; i++) { for (uint8_t i = 0; i < length - 1; i++) {
calculated_checksum += data[i]; calculated_checksum += data[i];
} }
if ((uint8_t) calculated_checksum != checksum) { if ((uint8_t) calculated_checksum != checksum) {
improv_command.command = BAD_CHECKSUM; improv_command.command = BAD_CHECKSUM;
return improv_command; return improv_command;
}
} }
if (command == WIFI_SETTINGS) { if (command == WIFI_SETTINGS) {
@ -46,7 +48,7 @@ ImprovCommand parse_improv_data(const uint8_t *data, size_t length) {
return improv_command; return improv_command;
} }
std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::string> &datum) { std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::string> &datum, bool add_checksum) {
std::vector<uint8_t> out; std::vector<uint8_t> out;
uint32_t length = 0; uint32_t length = 0;
out.push_back(command); out.push_back(command);
@ -58,17 +60,19 @@ std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::
} }
out.insert(out.begin() + 1, length); out.insert(out.begin() + 1, length);
uint32_t calculated_checksum = 0; if (add_checksum) {
uint32_t calculated_checksum = 0;
for (uint8_t byte : out) { for (uint8_t byte : out) {
calculated_checksum += byte; calculated_checksum += byte;
}
out.push_back(calculated_checksum);
} }
out.push_back(calculated_checksum);
return out; return out;
} }
#ifdef USE_ARDUINO #ifdef ARDUINO
std::vector<uint8_t> build_rpc_response(Command command, const std::vector<String> &datum) { std::vector<uint8_t> build_rpc_response(Command command, const std::vector<String> &datum, bool add_checksum) {
std::vector<uint8_t> out; std::vector<uint8_t> out;
uint32_t length = 0; uint32_t length = 0;
out.push_back(command); out.push_back(command);
@ -80,14 +84,16 @@ std::vector<uint8_t> build_rpc_response(Command command, const std::vector<Strin
} }
out.insert(out.begin() + 1, length); out.insert(out.begin() + 1, length);
uint32_t calculated_checksum = 0; if (add_checksum) {
uint32_t calculated_checksum = 0;
for (uint8_t byte : out) { for (uint8_t byte : out) {
calculated_checksum += byte; calculated_checksum += byte;
}
out.push_back(calculated_checksum);
} }
out.push_back(calculated_checksum);
return out; return out;
} }
#endif // USE_ARDUINO #endif // ARDUINO
} // namespace improv } // namespace improv

View File

@ -51,12 +51,13 @@ struct ImprovCommand {
std::string password; std::string password;
}; };
ImprovCommand parse_improv_data(const std::vector<uint8_t> &data); ImprovCommand parse_improv_data(const std::vector<uint8_t> &data, bool check_checksum = true);
ImprovCommand parse_improv_data(const uint8_t *data, size_t length); ImprovCommand parse_improv_data(const uint8_t *data, size_t length, bool check_checksum = true);
std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::string> &datum); std::vector<uint8_t> build_rpc_response(Command command, const std::vector<std::string> &datum,
bool add_checksum = true);
#ifdef ARDUINO #ifdef ARDUINO
std::vector<uint8_t> build_rpc_response(Command command, const std::vector<String> &datum); std::vector<uint8_t> build_rpc_response(Command command, const std::vector<String> &datum, bool add_checksum = true);
#endif // ARDUINO #endif // ARDUINO
} // namespace improv } // namespace improv

View File

@ -98,13 +98,13 @@ std::vector<uint8_t> ImprovSerialComponent::build_rpc_settings_response_(improv:
std::string webserver_url = "http://" + ip.str() + ":" + to_string(WEBSERVER_PORT); std::string webserver_url = "http://" + ip.str() + ":" + to_string(WEBSERVER_PORT);
urls.push_back(webserver_url); urls.push_back(webserver_url);
#endif #endif
std::vector<uint8_t> data = improv::build_rpc_response(command, urls); std::vector<uint8_t> data = improv::build_rpc_response(command, urls, false);
return data; return data;
} }
std::vector<uint8_t> ImprovSerialComponent::build_version_info_() { std::vector<uint8_t> ImprovSerialComponent::build_version_info_() {
std::vector<std::string> infos = {"ESPHome", ESPHOME_VERSION, ESPHOME_VARIANT, App.get_name()}; std::vector<std::string> infos = {"ESPHome", ESPHOME_VERSION, ESPHOME_VARIANT, App.get_name()};
std::vector<uint8_t> data = improv::build_rpc_response(improv::GET_DEVICE_INFO, infos); std::vector<uint8_t> data = improv::build_rpc_response(improv::GET_DEVICE_INFO, infos, false);
return data; return data;
}; };
@ -140,22 +140,33 @@ bool ImprovSerialComponent::parse_improv_serial_byte_(uint8_t byte) {
if (at < 8 + data_len) if (at < 8 + data_len)
return true; return true;
if (at == 8 + data_len) { if (at == 8 + data_len)
return true;
if (at == 8 + data_len + 1) {
uint8_t checksum = 0x00;
for (uint8_t i = 0; i < at; i++)
checksum += raw[i];
if (checksum != byte) {
ESP_LOGW(TAG, "Error decoding Improv payload");
this->set_error_(improv::ERROR_INVALID_RPC);
return false;
}
if (type == TYPE_RPC) { if (type == TYPE_RPC) {
this->set_error_(improv::ERROR_NONE); this->set_error_(improv::ERROR_NONE);
auto command = improv::parse_improv_data(&raw[9], data_len); auto command = improv::parse_improv_data(&raw[9], data_len, false);
return this->parse_improv_payload_(command); return this->parse_improv_payload_(command);
} }
} }
return true;
// If we got here then the command coming is is improv, but not an RPC command
return false;
} }
bool ImprovSerialComponent::parse_improv_payload_(improv::ImprovCommand &command) { bool ImprovSerialComponent::parse_improv_payload_(improv::ImprovCommand &command) {
switch (command.command) { switch (command.command) {
case improv::BAD_CHECKSUM:
ESP_LOGW(TAG, "Error decoding Improv payload");
this->set_error_(improv::ERROR_INVALID_RPC);
return false;
case improv::WIFI_SETTINGS: { case improv::WIFI_SETTINGS: {
wifi::WiFiAP sta{}; wifi::WiFiAP sta{};
sta.set_ssid(command.ssid); sta.set_ssid(command.ssid);
@ -232,6 +243,12 @@ void ImprovSerialComponent::send_response_(std::vector<uint8_t> &response) {
data[7] = TYPE_RPC_RESPONSE; data[7] = TYPE_RPC_RESPONSE;
data[8] = response.size(); data[8] = response.size();
data.insert(data.end(), response.begin(), response.end()); data.insert(data.end(), response.begin(), response.end());
uint8_t checksum = 0x00;
for (uint8_t d : data)
checksum += d;
data.push_back(checksum);
this->write_data_(data); this->write_data_(data);
} }