Implement sensible error handling

This commit is contained in:
Grigori Goronzy 2014-01-06 22:27:15 +01:00
parent cdbb6eee7c
commit 7b748dcf22

172
stcgal.py
View File

@ -25,8 +25,6 @@
"""
TODO:
- Utils class?
- error/exception handling
- some more documentation / comments
- private member naming, other style issues
@ -41,14 +39,34 @@ import argparse
DEBUG = False
class Utils:
"""make sensible boolean from string or other type value"""
@classmethod
def to_bool(self, val):
"""make sensible boolean from string or other type value"""
if isinstance(val, bool): return val
if isinstance(val, int): return bool(val)
if len(val) == 0: return False
return True if val[0].lower() == "t" or val[0] == "1" else False
@classmethod
def to_int(self, val):
"""make int from any value, nice error message if not possible"""
try: return int(val, 0)
except: raise ValueError("invalid integer")
class BaudType:
"""Check baud rate for validity"""
def __call__(self, string):
baud = int(string)
if baud not in serial.Serial.BAUDRATES:
raise argparse.ArgumentTypeError("illegal baudrate")
return baud
def __repr__(self): return "baudrate"
class Stc12Option:
"""Manipulate STC10/11/12 series option bytes"""
@ -113,9 +131,10 @@ class Stc12Option:
return 2 ** (((self.msr[0] >> 4) & 0x03) + 12)
def set_osc_stable_delay(self, val):
val = int(val, 0)
val = Utils.to_int(val)
osc_vals = {4096: 0, 8192: 1, 16384: 2, 32768: 3}
if val not in osc_vals.keys(): raise ValueError
if val not in osc_vals.keys():
raise ValueError("must be one of %s" % list(osc_vals.keys()))
self.msr[0] &= 0x8f
self.msr[0] |= osc_vals[val] << 4
@ -125,7 +144,8 @@ class Stc12Option:
def set_por_delay(self, val):
delays = {"short": 1, "long": 0}
if val not in delays.keys(): raise ValueError
if val not in delays.keys():
raise ValueError("must be one of %s" % list(delays.keys()))
self.msr[1] &= 0x7f
self.msr[1] |= delays[val] << 7
@ -135,7 +155,8 @@ class Stc12Option:
def set_clock_gain(self, val):
gains = {"low": 0, "high": 1}
if val not in gains.keys(): raise ValueError
if val not in gains.keys():
raise ValueError("must be one of %s" % list(gains.keys()))
self.msr[1] &= 0xbf
self.msr[1] |= gains[val] << 6
@ -145,7 +166,8 @@ class Stc12Option:
def set_clock_source(self, val):
sources = {"internal": 0, "external": 1}
if val not in sources.keys(): raise ValueError
if val not in sources.keys():
raise ValueError("must be one of %s" % list(sources.keys()))
self.msr[1] &= 0xfd
self.msr[1] |= sources[val] << 1
@ -169,9 +191,10 @@ class Stc12Option:
return 2 ** (((self.msr[2]) & 0x07) + 1)
def set_watchdog_prescale(self, val):
val = int(val, 0)
val = Utils.to_int(val)
wd_vals = {2: 0, 4: 1, 8: 2, 16: 3, 32: 4, 64: 5, 128: 6, 256: 7}
if val not in wd_vals.keys(): raise ValueError
if val not in wd_vals.keys():
raise ValueError("must be one of %s" % list(wd_vals.keys()))
self.msr[2] &= 0xf8
self.msr[2] |= wd_vals[val]
@ -245,18 +268,14 @@ class Stc12Protocol:
packet = bytes()
packet += self.ser.read(2)
if packet[0:2] != self.PACKET_START:
print("Wrong magic (%s), discarding packet!" %
packet[0:2], file=sys.stderr)
self.dump_packet(packet)
return None
raise RuntimeError("wrong packet frame start")
# read direction and length
packet += self.ser.read(3)
if packet[2] != self.PACKET_MCU[0]:
print("Wrong direction (%s), discarding packet!" %
hex(packet[3]), file=sys.stderr)
self.dump_packet(packet)
return None
raise RuntimeError("wrong packet direction magic")
# read packet data
packet_len, = struct.unpack(">H", packet[3:5])
@ -264,19 +283,15 @@ class Stc12Protocol:
# verify end code
if packet[packet_len+1] != self.PACKET_END[0]:
print("Wrong end code (%s), discarding packet!" %
hex(packet[packet_len+1]), file=sys.stderr)
self.dump_packet(packet)
return None
raise RuntimeError("wrong packet frame end")
# verify checksum
packet_csum, = struct.unpack(">H", packet[packet_len-1:packet_len+1])
calc_csum = sum(packet[2:packet_len-1]) & 0xffff
if packet_csum != calc_csum:
print("Wrong checksum (%s, expected %s), discarding packet!" %
(hex(packet_csum), hex(calc_csum)), file=sys.stderr)
self.dump_packet(packet)
return None
raise RuntimeError("packet checksum mismatch")
self.dump_packet(packet, receive=True)
@ -331,7 +346,8 @@ class Stc12Protocol:
# baudrate is directly controlled by programming the MCU's BRT register
brt = 256 - round((self.mcu_clock_hz) / (self.baud_transfer * 16))
brt_csum = (2 * (256 - brt)) & 0xff
baud_actual = (self.mcu_clock_hz) / (16 * (256 - brt))
try: baud_actual = (self.mcu_clock_hz) / (16 * (256 - brt))
except ZeroDivisionError: raise RuntimeError("baudrate too high")
baud_error = (abs(self.baud_transfer - baud_actual) * 100.0) / self.baud_transfer
if baud_error > 5.0:
print("WARNING: baud rate error is %.2f%%. You may need to set a slower rate." %
@ -380,16 +396,13 @@ class Stc12Protocol:
# read status packet
status_packet = self.read_packet()
if status_packet == None or status_packet[0] != 0x50:
print("Error receiving status packet, aborting!", file=sys.stderr)
return False
if status_packet[0] != 0x50:
raise RuntimeError("wrong magic in status packet")
self.decode_status_packet(status_packet)
self.print_mcu_info()
self.options = Stc12Option(status_packet[23:27])
self.options.print()
return True
def handshake(self):
"""Do baudrate handshake
@ -404,9 +417,8 @@ class Stc12Protocol:
packet += struct.pack(">H", self.mcu_magic)
self.write_packet(packet)
response = self.read_packet()
if response == None or response[0] != 0x8f:
print("Error receiving handshake packet, aborting!", file=sys.stderr)
return False
if response[0] != 0x8f:
raise RuntimeError("wrong magic in handshake packet")
# test new settings
print("testing...", end="")
@ -417,9 +429,8 @@ class Stc12Protocol:
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
self.ser.baudrate = self.baud_handshake
if response == None or response[0] != 0x8f:
print("Error receiving handshake packet, aborting!", file=sys.stderr)
return False
if response[0] != 0x8f:
raise RuntimeError("wrong magic in handshake packet")
# switch to the settings
print("setting...", end="")
@ -429,12 +440,10 @@ class Stc12Protocol:
time.sleep(0.2)
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
if response == None or response[0] != 0x84:
print("Error receiving handshake packet, aborting!", file=sys.stderr)
return False
print("done")
if response[0] != 0x84:
raise RuntimeError("wrong magic in handshake packet")
return True
print("done")
def erase_flash(self, erase_size, flash_size):
"""Erase the MCU's flash memory.
@ -451,10 +460,8 @@ class Stc12Protocol:
for i in range(0x80, 0x0d, -1): packet += bytes([i])
self.write_packet(packet)
response = self.read_packet()
if response == None or response[0] != 0x00:
print("Error receiving erase response, aborting!", file=sys.stderr)
return False
return True
if response[0] != 0x00:
raise RuntimeError("wrong magic in erase packet")
def program_flash(self, addr, data):
"""Program the MCU's flash memory.
@ -474,12 +481,10 @@ class Stc12Protocol:
csum = sum(packet[7:]) & 0xff
self.write_packet(packet)
response = self.read_packet()
if response == None or response[0] != 0x00:
print("Error receiving program response packet, aborting!", file=sys.stderr)
return False
if response[0] != 0x00:
raise RuntimeError("wrong magic in write packet")
elif response[1] != csum:
print("Wrong checksum in program response (%s, expected %s), aborting!" %
(hex(response[1]), hex(csum)), file=sys.stderr)
raise RuntimeError("verification checksum mismatch")
print(".", end="")
sys.stdout.flush()
print()
@ -488,13 +493,10 @@ class Stc12Protocol:
packet += struct.pack(">H", self.mcu_magic)
self.write_packet(packet)
response = self.read_packet()
if response == None or response[0] != 0x8d:
print("Error receiving program finish response packet, aborting!", file=sys.stderr)
return False
if response[0] != 0x8d:
raise RuntimeError("wrong magic in finish packet")
print("Finished writing flash!")
return True
def set_option(self, name, value):
self.options.set_option(name, value)
@ -509,16 +511,13 @@ class Stc12Protocol:
packet += struct.pack(">I", int(self.mcu_clock_hz))
self.write_packet(packet)
response = self.read_packet()
if response == None or response[0] != 0x50:
print("Error receiving set options response packet, aborting!", file=sys.stderr)
return False
if response[0] != 0x50:
raise RuntimeError("wrong magic in option packet")
print("Target UID: %02x%02x%02x%02x%02x%02x%02x" %
(response[18], response[19], response[20], response[21],
response[22], response[23], response[24]))
return True
def disconnect(self):
"""Disconnect from MCU"""
@ -535,31 +534,62 @@ class StcGal:
self.opts = opts
self.protocol = Stc12Protocol(opts.port, opts.handshake, opts.baud)
def emit_options(self, options):
for o in options:
try:
k, v = o.split("=", 1)
self.protocol.set_option(k, v)
except ValueError as e:
raise NameError("invalid option '%s' (%s)" % (k, e))
def run(self):
self.protocol.connect()
try: self.protocol.connect()
except KeyboardInterrupt:
print("interrupted")
return 2
except RuntimeError as e:
print("Communication error: %s" % e, file=sys.stderr)
return 1
except serial.serialutil.SerialException as e:
print("Serial communication error: %s" % e, file=sys.stderr)
return 1
if opts.binary:
bindata = opts.binary.read()
try:
bindata = opts.binary.read()
if opts.option:
for o in opts.option:
k, v = o.split("=", 1)
self.protocol.set_option(k, v)
if opts.option: self.emit_options(opts.option)
self.protocol.handshake()
self.protocol.erase_flash(len(bindata), 0xf0 * 256)
self.protocol.program_flash(0, bindata)
self.protocol.program_options()
self.protocol.handshake()
self.protocol.erase_flash(len(bindata), 0xf0 * 256)
self.protocol.program_flash(0, bindata)
self.protocol.program_options()
self.protocol.disconnect()
return 0
except NameError as e:
print("Option error: %s" % e, file=sys.stderr)
self.protocol.disconnect()
return 1
except RuntimeError as e:
print("Communication error: %s" % e, file=sys.stderr)
self.protocol.disconnect()
return 1
except KeyboardInterrupt:
print("interrupted")
self.protocol.disconnect()
return 2
except serial.serialutil.SerialException as e:
print("Serial communication error: %s" % e, file=sys.stderr)
return 1
self.protocol.disconnect()
if __name__ == "__main__":
# check arguments
parser = argparse.ArgumentParser(description="STC10/11/12 series MCU ISP flash tool")
parser.add_argument("binary", help="binary file to flash", type=argparse.FileType("rb"), nargs='?')
parser.add_argument("-p", "--port", help="serial port device", default="/dev/ttyUSB0")
parser.add_argument("-b", "--baud", help="transfer baud rate (default: 19200)", type=int, default=19200)
parser.add_argument("-l", "--handshake", help="handshake baud rate (default: 2400)", type=int, default=2400)
parser.add_argument("-b", "--baud", help="transfer baud rate (default: 19200)", type=BaudType(), default=19200)
parser.add_argument("-l", "--handshake", help="handshake baud rate (default: 2400)", type=BaudType(), default=2400)
parser.add_argument("-o", "--option", help="set option (can be used multiple times)", action="append")
opts = parser.parse_args()