firmware_flasher: switch bootloader comms to fixed msg len

main
Ilya Zhuravlev 2021-01-06 10:59:48 -05:00
parent 1dfe2519ed
commit b0854a5e3d
3 changed files with 18 additions and 10 deletions

View File

@ -13,7 +13,7 @@ from PyQt5.QtWidgets import QHBoxLayout, QLineEdit, QToolButton, QPlainTextEdit,
from basic_editor import BasicEditor from basic_editor import BasicEditor
from unlocker import Unlocker from unlocker import Unlocker
from util import tr, chunks, find_vial_devices from util import tr, chunks, find_vial_devices, pad_for_vibl
from vial_device import VialBootloader, VialKeyboard from vial_device import VialBootloader, VialKeyboard
@ -23,6 +23,9 @@ BL_SUPPORTED_VERSION = 0
def send_retries(dev, data, retries=20): def send_retries(dev, data, retries=20):
""" Sends usb packet up to 'retries' times, returns True if success, False if failed """ """ Sends usb packet up to 'retries' times, returns True if success, False if failed """
if len(data) != 64:
raise RuntimeError("sending invalid data length")
for x in range(retries): for x in range(retries):
ret = dev.send(data) ret = dev.send(data)
if ret == len(data) + 1: if ret == len(data) + 1:
@ -55,13 +58,13 @@ def cmd_flash(device, firmware, enable_insecure, log_cb, progress_cb, complete_c
)) ))
# Check bootloader is correct version # Check bootloader is correct version
device.send(b"VC\x00") send_retries(device, pad_for_vibl(b"VC\x00"))
ver = device.recv(8)[0] ver = device.recv(8)[0]
log_cb("* Bootloader version: {}".format(ver)) log_cb("* Bootloader version: {}".format(ver))
if ver != BL_SUPPORTED_VERSION: if ver != BL_SUPPORTED_VERSION:
return error_cb("Error: Unsupported bootloader version") return error_cb("Error: Unsupported bootloader version")
device.send(b"VC\x01") send_retries(device, pad_for_vibl(b"VC\x01"))
uid = device.recv(8) uid = device.recv(8)
log_cb("* Vial ID: {}".format(uid.hex())) log_cb("* Vial ID: {}".format(uid.hex()))
@ -81,11 +84,9 @@ def cmd_flash(device, firmware, enable_insecure, log_cb, progress_cb, complete_c
# Flash # Flash
log_cb("Flashing...") log_cb("Flashing...")
device.send(b"VC\x02" + struct.pack("<H", len(fw_payload) // CHUNK)) send_retries(device, pad_for_vibl(b"VC\x02" + struct.pack("<H", len(fw_payload) // CHUNK)))
total = 0 total = 0
for part in chunks(fw_payload, CHUNK): for part in chunks(fw_payload, CHUNK):
if len(part) < CHUNK:
part += b"\x00" * (CHUNK - len(part))
if not send_retries(device, part): if not send_retries(device, part):
return error_cb("Error while sending data, firmware is corrupted") return error_cb("Error while sending data, firmware is corrupted")
total += len(part) total += len(part)
@ -95,8 +96,8 @@ def cmd_flash(device, firmware, enable_insecure, log_cb, progress_cb, complete_c
log_cb("Rebooting...") log_cb("Rebooting...")
# enable insecure mode on first boot in order to restore keymap/macros # enable insecure mode on first boot in order to restore keymap/macros
if enable_insecure: if enable_insecure:
device.send(b"VC\x04") send_retries(device, pad_for_vibl(b"VC\x04"))
device.send(b"VC\x03") send_retries(device, pad_for_vibl(b"VC\x03"))
complete_cb("Done!") complete_cb("Done!")

View File

@ -50,3 +50,10 @@ def find_vial_devices(sideload_vid=None, sideload_pid=None):
def chunks(data, sz): def chunks(data, sz):
for i in range(0, len(data), sz): for i in range(0, len(data), sz):
yield data[i:i+sz] yield data[i:i+sz]
def pad_for_vibl(msg):
""" Pads message to vibl fixed 64-byte length """
if len(msg) > 64:
raise RuntimeError("vibl message too long")
return msg + b"\x00" * (64 - len(msg))

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: GPL-2.0-or-later # SPDX-License-Identifier: GPL-2.0-or-later
from hidproxy import hid from hidproxy import hid
from keyboard_comm import Keyboard from keyboard_comm import Keyboard
from util import MSG_LEN from util import MSG_LEN, pad_for_vibl
class VialDevice: class VialDevice:
@ -66,7 +66,7 @@ class VialBootloader(VialDevice):
super().open() super().open()
except OSError: except OSError:
return b"" return b""
self.send(b"VC\x01") self.send(pad_for_vibl(b"VC\x01"))
data = self.recv(8, timeout_ms=500) data = self.recv(8, timeout_ms=500)
super().close() super().close()
return data return data