esphome/esphome/espota2.py

293 lines
9.7 KiB
Python

import hashlib
import logging
import random
import socket
import sys
import time
from esphome.core import EsphomeError
from esphome.helpers import is_ip_address, resolve_ip_address
RESPONSE_OK = 0
RESPONSE_REQUEST_AUTH = 1
RESPONSE_HEADER_OK = 64
RESPONSE_AUTH_OK = 65
RESPONSE_UPDATE_PREPARE_OK = 66
RESPONSE_BIN_MD5_OK = 67
RESPONSE_RECEIVE_OK = 68
RESPONSE_UPDATE_END_OK = 69
RESPONSE_ERROR_MAGIC = 128
RESPONSE_ERROR_UPDATE_PREPARE = 129
RESPONSE_ERROR_AUTH_INVALID = 130
RESPONSE_ERROR_WRITING_FLASH = 131
RESPONSE_ERROR_UPDATE_END = 132
RESPONSE_ERROR_INVALID_BOOTSTRAPPING = 133
RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG = 134
RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG = 135
RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE = 136
RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE = 137
RESPONSE_ERROR_UNKNOWN = 255
OTA_VERSION_1_0 = 1
MAGIC_BYTES = [0x6C, 0x26, 0xF7, 0x5C, 0x45]
_LOGGER = logging.getLogger(__name__)
class ProgressBar:
def __init__(self):
self.last_progress = None
def update(self, progress):
bar_length = 60
status = ""
if progress >= 1:
progress = 1
status = "Done...\r\n"
new_progress = int(progress * 100)
if new_progress == self.last_progress:
return
self.last_progress = new_progress
block = int(round(bar_length * progress))
text = "\rUploading: [{0}] {1}% {2}".format("=" * block + " " * (bar_length - block),
new_progress, status)
sys.stderr.write(text)
sys.stderr.flush()
# pylint: disable=no-self-use
def done(self):
sys.stderr.write('\n')
sys.stderr.flush()
class OTAError(EsphomeError):
pass
def recv_decode(sock, amount, decode=True):
data = sock.recv(amount)
if not decode:
return data
return list(data)
def receive_exactly(sock, amount, msg, expect, decode=True):
if decode:
data = []
else:
data = b''
try:
data += recv_decode(sock, 1, decode=decode)
except OSError as err:
raise OTAError(f"Error receiving acknowledge {msg}: {err}")
try:
check_error(data, expect)
except OTAError as err:
sock.close()
raise OTAError(f"Error {msg}: {err}")
while len(data) < amount:
try:
data += recv_decode(sock, amount - len(data), decode=decode)
except OSError as err:
raise OTAError(f"Error receiving {msg}: {err}")
return data
def check_error(data, expect):
if not expect:
return
dat = data[0]
if dat == RESPONSE_ERROR_MAGIC:
raise OTAError("Error: Invalid magic byte")
if dat == RESPONSE_ERROR_UPDATE_PREPARE:
raise OTAError("Error: Couldn't prepare flash memory for update. Is the binary too big? "
"Please try restarting the ESP.")
if dat == RESPONSE_ERROR_AUTH_INVALID:
raise OTAError("Error: Authentication invalid. Is the password correct?")
if dat == RESPONSE_ERROR_WRITING_FLASH:
raise OTAError("Error: Wring OTA data to flash memory failed. See USB logs for more "
"information.")
if dat == RESPONSE_ERROR_UPDATE_END:
raise OTAError("Error: Finishing update failed. See the MQTT/USB logs for more "
"information.")
if dat == RESPONSE_ERROR_INVALID_BOOTSTRAPPING:
raise OTAError("Error: Please press the reset button on the ESP. A manual reset is "
"required on the first OTA-Update after flashing via USB.")
if dat == RESPONSE_ERROR_WRONG_CURRENT_FLASH_CONFIG:
raise OTAError("Error: ESP has been flashed with wrong flash size. Please choose the "
"correct 'board' option (esp01_1m always works) and then flash over USB.")
if dat == RESPONSE_ERROR_WRONG_NEW_FLASH_CONFIG:
raise OTAError("Error: ESP does not have the requested flash size (wrong board). Please "
"choose the correct 'board' option (esp01_1m always works) and try "
"uploading again.")
if dat == RESPONSE_ERROR_ESP8266_NOT_ENOUGH_SPACE:
raise OTAError("Error: ESP does not have enough space to store OTA file. Please try "
"flashing a minimal firmware (remove everything except ota)")
if dat == RESPONSE_ERROR_ESP32_NOT_ENOUGH_SPACE:
raise OTAError("Error: The OTA partition on the ESP is too small. ESPHome needs to resize "
"this partition, please flash over USB.")
if dat == RESPONSE_ERROR_UNKNOWN:
raise OTAError("Unknown error from ESP")
if not isinstance(expect, (list, tuple)):
expect = [expect]
if dat not in expect:
raise OTAError("Unexpected response from ESP: 0x{:02X}".format(data[0]))
def send_check(sock, data, msg):
try:
if isinstance(data, (list, tuple)):
data = bytes(data)
elif isinstance(data, int):
data = bytes([data])
elif isinstance(data, str):
data = data.encode('utf8')
sock.sendall(data)
except OSError as err:
raise OTAError(f"Error sending {msg}: {err}")
def perform_ota(sock, password, file_handle, filename):
file_md5 = hashlib.md5(file_handle.read()).hexdigest()
file_size = file_handle.tell()
_LOGGER.info('Uploading %s (%s bytes)', filename, file_size)
file_handle.seek(0)
_LOGGER.debug("MD5 of binary is %s", file_md5)
# Enable nodelay, we need it for phase 1
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
send_check(sock, MAGIC_BYTES, 'magic bytes')
_, version = receive_exactly(sock, 2, 'version', RESPONSE_OK)
if version != OTA_VERSION_1_0:
raise OTAError(f"Unsupported OTA version {version}")
# Features
send_check(sock, 0x00, 'features')
receive_exactly(sock, 1, 'features', RESPONSE_HEADER_OK)
auth, = receive_exactly(sock, 1, 'auth', [RESPONSE_REQUEST_AUTH, RESPONSE_AUTH_OK])
if auth == RESPONSE_REQUEST_AUTH:
if not password:
raise OTAError("ESP requests password, but no password given!")
nonce = receive_exactly(sock, 32, 'authentication nonce', [], decode=False).decode()
_LOGGER.debug("Auth: Nonce is %s", nonce)
cnonce = hashlib.md5(str(random.random()).encode()).hexdigest()
_LOGGER.debug("Auth: CNonce is %s", cnonce)
send_check(sock, cnonce, 'auth cnonce')
result_md5 = hashlib.md5()
result_md5.update(password.encode('utf-8'))
result_md5.update(nonce.encode())
result_md5.update(cnonce.encode())
result = result_md5.hexdigest()
_LOGGER.debug("Auth: Result is %s", result)
send_check(sock, result, 'auth result')
receive_exactly(sock, 1, 'auth result', RESPONSE_AUTH_OK)
file_size_encoded = [
(file_size >> 24) & 0xFF,
(file_size >> 16) & 0xFF,
(file_size >> 8) & 0xFF,
(file_size >> 0) & 0xFF,
]
send_check(sock, file_size_encoded, 'binary size')
receive_exactly(sock, 1, 'binary size', RESPONSE_UPDATE_PREPARE_OK)
send_check(sock, file_md5, 'file checksum')
receive_exactly(sock, 1, 'file checksum', RESPONSE_BIN_MD5_OK)
# Disable nodelay for transfer
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 0)
# Limit send buffer (usually around 100kB) in order to have progress bar
# show the actual progress
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 8192)
# Set higher timeout during upload
sock.settimeout(20.0)
offset = 0
progress = ProgressBar()
while True:
chunk = file_handle.read(1024)
if not chunk:
break
offset += len(chunk)
try:
sock.sendall(chunk)
except OSError as err:
sys.stderr.write('\n')
raise OTAError(f"Error sending data: {err}")
progress.update(offset / float(file_size))
progress.done()
# Enable nodelay for last checks
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
_LOGGER.info("Waiting for result...")
receive_exactly(sock, 1, 'receive OK', RESPONSE_RECEIVE_OK)
receive_exactly(sock, 1, 'Update end', RESPONSE_UPDATE_END_OK)
send_check(sock, RESPONSE_OK, 'end acknowledgement')
_LOGGER.info("OTA successful")
# Do not connect logs until it is fully on
time.sleep(1)
def run_ota_impl_(remote_host, remote_port, password, filename):
if is_ip_address(remote_host):
_LOGGER.info("Connecting to %s", remote_host)
ip = remote_host
else:
_LOGGER.info("Resolving IP address of %s", remote_host)
try:
ip = resolve_ip_address(remote_host)
except EsphomeError as err:
_LOGGER.error("Error resolving IP address of %s. Is it connected to WiFi?",
remote_host)
_LOGGER.error("(If this error persists, please set a static IP address: "
"https://esphome.io/components/wifi.html#manual-ips)")
raise OTAError(err)
_LOGGER.info(" -> %s", ip)
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(10.0)
try:
sock.connect((ip, remote_port))
except OSError as err:
sock.close()
_LOGGER.error("Connecting to %s:%s failed: %s", remote_host, remote_port, err)
return 1
file_handle = open(filename, 'rb')
try:
perform_ota(sock, password, file_handle, filename)
except OTAError as err:
_LOGGER.error(str(err))
return 1
finally:
sock.close()
file_handle.close()
return 0
def run_ota(remote_host, remote_port, password, filename):
try:
return run_ota_impl_(remote_host, remote_port, password, filename)
except OTAError as err:
_LOGGER.error(err)
return 1