From 056a28906ba9de193da4aa78f6034984654bc4a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kuba=20Szczodrzy=C5=84ski?= Date: Thu, 21 Sep 2023 00:09:23 +0200 Subject: [PATCH] Wizard: fix colored text in input prompts (#5313) --- esphome/util.py | 14 ++++++++++---- esphome/wizard.py | 18 ++++++++++-------- tests/unit_tests/test_wizard.py | 12 ++++++------ 3 files changed, 26 insertions(+), 18 deletions(-) diff --git a/esphome/util.py b/esphome/util.py index 0d60212f50..480618aca0 100644 --- a/esphome/util.py +++ b/esphome/util.py @@ -57,7 +57,7 @@ class SimpleRegistry(dict): return decorator -def safe_print(message=""): +def safe_print(message="", end="\n"): from esphome.core import CORE if CORE.dashboard: @@ -67,20 +67,26 @@ def safe_print(message=""): pass try: - print(message) + print(message, end=end) return except UnicodeEncodeError: pass try: - print(message.encode("utf-8", "backslashreplace")) + print(message.encode("utf-8", "backslashreplace"), end=end) except UnicodeEncodeError: try: - print(message.encode("ascii", "backslashreplace")) + print(message.encode("ascii", "backslashreplace"), end=end) except UnicodeEncodeError: print("Cannot print line because of invalid locale!") +def safe_input(prompt=""): + if prompt: + safe_print(prompt, end="") + return input() + + def shlex_quote(s): if not s: return "''" diff --git a/esphome/wizard.py b/esphome/wizard.py index aa05e513a7..1308338ad0 100644 --- a/esphome/wizard.py +++ b/esphome/wizard.py @@ -11,7 +11,7 @@ from esphome.core import CORE from esphome.helpers import get_bool_env, write_file from esphome.log import Fore, color from esphome.storage_json import StorageJSON, ext_storage_path -from esphome.util import safe_print +from esphome.util import safe_input, safe_print CORE_BIG = r""" _____ ____ _____ ______ / ____/ __ \| __ \| ____| @@ -252,7 +252,7 @@ def safe_print_step(step, big): def default_input(text, default): safe_print() safe_print(f"Press ENTER for default ({default})") - return input(text.format(default)) or default + return safe_input(text.format(default)) or default # From https://stackoverflow.com/a/518232/8924614 @@ -306,7 +306,7 @@ def wizard(path): ) safe_print() sleep(1) - name = input(color(Fore.BOLD_WHITE, "(name): ")) + name = safe_input(color(Fore.BOLD_WHITE, "(name): ")) while True: try: @@ -343,7 +343,9 @@ def wizard(path): while True: sleep(0.5) safe_print() - platform = input(color(Fore.BOLD_WHITE, f"({'/'.join(wizard_platforms)}): ")) + platform = safe_input( + color(Fore.BOLD_WHITE, f"({'/'.join(wizard_platforms)}): ") + ) try: platform = vol.All(vol.Upper, vol.Any(*wizard_platforms))(platform.upper()) break @@ -397,7 +399,7 @@ def wizard(path): boards.append(board_id) while True: - board = input(color(Fore.BOLD_WHITE, "(board): ")) + board = safe_input(color(Fore.BOLD_WHITE, "(board): ")) try: board = vol.All(vol.Lower, vol.Any(*boards))(board) break @@ -423,7 +425,7 @@ def wizard(path): sleep(1.5) safe_print(f"For example \"{color(Fore.BOLD_WHITE, 'Abraham Linksys')}\".") while True: - ssid = input(color(Fore.BOLD_WHITE, "(ssid): ")) + ssid = safe_input(color(Fore.BOLD_WHITE, "(ssid): ")) try: ssid = cv.ssid(ssid) break @@ -449,7 +451,7 @@ def wizard(path): safe_print() safe_print(f"For example \"{color(Fore.BOLD_WHITE, 'PASSWORD42')}\"") sleep(0.5) - psk = input(color(Fore.BOLD_WHITE, "(PSK): ")) + psk = safe_input(color(Fore.BOLD_WHITE, "(PSK): ")) safe_print( "Perfect! WiFi is now set up (you can create static IPs and so on later)." ) @@ -466,7 +468,7 @@ def wizard(path): safe_print() sleep(0.25) safe_print("Press ENTER for no password") - password = input(color(Fore.BOLD_WHITE, "(password): ")) + password = safe_input(color(Fore.BOLD_WHITE, "(password): ")) if not wizard_write( path=path, diff --git a/tests/unit_tests/test_wizard.py b/tests/unit_tests/test_wizard.py index 8bbce08ae5..46700a3ba8 100644 --- a/tests/unit_tests/test_wizard.py +++ b/tests/unit_tests/test_wizard.py @@ -319,7 +319,7 @@ def test_wizard_accepts_default_answers_esp8266(tmpdir, monkeypatch, wizard_answ config_file = tmpdir.join("test.yaml") input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) - monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) + monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "wizard_write", MagicMock()) @@ -341,7 +341,7 @@ def test_wizard_accepts_default_answers_esp32(tmpdir, monkeypatch, wizard_answer config_file = tmpdir.join("test.yaml") input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) - monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) + monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "wizard_write", MagicMock()) @@ -371,7 +371,7 @@ def test_wizard_offers_better_node_name(tmpdir, monkeypatch, wizard_answers): config_file = tmpdir.join("test.yaml") input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) - monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) + monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "wizard_write", MagicMock()) @@ -394,7 +394,7 @@ def test_wizard_requires_correct_platform(tmpdir, monkeypatch, wizard_answers): config_file = tmpdir.join("test.yaml") input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) - monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) + monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "wizard_write", MagicMock()) @@ -416,7 +416,7 @@ def test_wizard_requires_correct_board(tmpdir, monkeypatch, wizard_answers): config_file = tmpdir.join("test.yaml") input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) - monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) + monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "wizard_write", MagicMock()) @@ -438,7 +438,7 @@ def test_wizard_requires_valid_ssid(tmpdir, monkeypatch, wizard_answers): config_file = tmpdir.join("test.yaml") input_mock = MagicMock(side_effect=wizard_answers) monkeypatch.setattr("builtins.input", input_mock) - monkeypatch.setattr(wz, "safe_print", lambda t=None: 0) + monkeypatch.setattr(wz, "safe_print", lambda t=None, end=None: 0) monkeypatch.setattr(wz, "sleep", lambda _: 0) monkeypatch.setattr(wz, "wizard_write", MagicMock())