Implement sensible error handling
This commit is contained in:
		
							
								
								
									
										168
									
								
								stcgal.py
									
									
									
									
									
								
							
							
						
						
									
										168
									
								
								stcgal.py
									
									
									
									
									
								
							@ -25,8 +25,6 @@
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
TODO:
 | 
			
		||||
- Utils class?
 | 
			
		||||
- error/exception handling
 | 
			
		||||
- some more documentation / comments
 | 
			
		||||
- private member naming, other style issues
 | 
			
		||||
 | 
			
		||||
@ -41,14 +39,34 @@ import argparse
 | 
			
		||||
DEBUG = False
 | 
			
		||||
 | 
			
		||||
class Utils:
 | 
			
		||||
    """make sensible boolean from string or other type value"""
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def to_bool(self, val):
 | 
			
		||||
        """make sensible boolean from string or other type value"""
 | 
			
		||||
 | 
			
		||||
        if isinstance(val, bool): return val
 | 
			
		||||
        if isinstance(val, int): return bool(val)
 | 
			
		||||
        if len(val) == 0: return False
 | 
			
		||||
        return True if val[0].lower() == "t" or val[0] == "1" else False
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def to_int(self, val):
 | 
			
		||||
        """make int from any value, nice error message if not possible"""
 | 
			
		||||
 | 
			
		||||
        try: return int(val, 0)
 | 
			
		||||
        except: raise ValueError("invalid integer")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class BaudType:
 | 
			
		||||
    """Check baud rate for validity"""
 | 
			
		||||
 | 
			
		||||
    def __call__(self, string):
 | 
			
		||||
        baud = int(string)
 | 
			
		||||
        if baud not in serial.Serial.BAUDRATES:
 | 
			
		||||
            raise argparse.ArgumentTypeError("illegal baudrate")
 | 
			
		||||
        return baud
 | 
			
		||||
 | 
			
		||||
    def __repr__(self): return "baudrate"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Stc12Option:
 | 
			
		||||
    """Manipulate STC10/11/12 series option bytes"""
 | 
			
		||||
@ -113,9 +131,10 @@ class Stc12Option:
 | 
			
		||||
        return 2 ** (((self.msr[0] >> 4) & 0x03) + 12)
 | 
			
		||||
 | 
			
		||||
    def set_osc_stable_delay(self, val):
 | 
			
		||||
        val = int(val, 0)
 | 
			
		||||
        val = Utils.to_int(val)
 | 
			
		||||
        osc_vals = {4096: 0, 8192: 1, 16384: 2, 32768: 3}
 | 
			
		||||
        if val not in osc_vals.keys(): raise ValueError
 | 
			
		||||
        if val not in osc_vals.keys():
 | 
			
		||||
            raise ValueError("must be one of %s" % list(osc_vals.keys()))
 | 
			
		||||
        self.msr[0] &= 0x8f
 | 
			
		||||
        self.msr[0] |= osc_vals[val] << 4
 | 
			
		||||
 | 
			
		||||
@ -125,7 +144,8 @@ class Stc12Option:
 | 
			
		||||
 | 
			
		||||
    def set_por_delay(self, val):
 | 
			
		||||
        delays = {"short": 1, "long": 0}
 | 
			
		||||
        if val not in delays.keys(): raise ValueError
 | 
			
		||||
        if val not in delays.keys():
 | 
			
		||||
            raise ValueError("must be one of %s" % list(delays.keys()))
 | 
			
		||||
        self.msr[1] &= 0x7f
 | 
			
		||||
        self.msr[1] |= delays[val] << 7
 | 
			
		||||
 | 
			
		||||
@ -135,7 +155,8 @@ class Stc12Option:
 | 
			
		||||
 | 
			
		||||
    def set_clock_gain(self, val):
 | 
			
		||||
        gains = {"low": 0, "high": 1}
 | 
			
		||||
        if val not in gains.keys(): raise ValueError
 | 
			
		||||
        if val not in gains.keys():
 | 
			
		||||
            raise ValueError("must be one of %s" % list(gains.keys()))
 | 
			
		||||
        self.msr[1] &= 0xbf
 | 
			
		||||
        self.msr[1] |= gains[val] << 6
 | 
			
		||||
 | 
			
		||||
@ -145,7 +166,8 @@ class Stc12Option:
 | 
			
		||||
 | 
			
		||||
    def set_clock_source(self, val):
 | 
			
		||||
        sources = {"internal": 0, "external": 1}
 | 
			
		||||
        if val not in sources.keys(): raise ValueError
 | 
			
		||||
        if val not in sources.keys():
 | 
			
		||||
            raise ValueError("must be one of %s" % list(sources.keys()))
 | 
			
		||||
        self.msr[1] &= 0xfd
 | 
			
		||||
        self.msr[1] |= sources[val] << 1
 | 
			
		||||
 | 
			
		||||
@ -169,9 +191,10 @@ class Stc12Option:
 | 
			
		||||
        return 2 ** (((self.msr[2]) & 0x07) + 1)
 | 
			
		||||
 | 
			
		||||
    def set_watchdog_prescale(self, val):
 | 
			
		||||
        val = int(val, 0)
 | 
			
		||||
        val = Utils.to_int(val)
 | 
			
		||||
        wd_vals = {2: 0, 4: 1, 8: 2, 16: 3, 32: 4, 64: 5, 128: 6, 256: 7}
 | 
			
		||||
        if val not in wd_vals.keys(): raise ValueError
 | 
			
		||||
        if val not in wd_vals.keys():
 | 
			
		||||
            raise ValueError("must be one of %s" % list(wd_vals.keys()))
 | 
			
		||||
        self.msr[2] &= 0xf8
 | 
			
		||||
        self.msr[2] |= wd_vals[val]
 | 
			
		||||
 | 
			
		||||
@ -245,18 +268,14 @@ class Stc12Protocol:
 | 
			
		||||
        packet = bytes()
 | 
			
		||||
        packet += self.ser.read(2)
 | 
			
		||||
        if packet[0:2] != self.PACKET_START:
 | 
			
		||||
            print("Wrong magic (%s), discarding packet!" %
 | 
			
		||||
                  packet[0:2], file=sys.stderr)
 | 
			
		||||
            self.dump_packet(packet)
 | 
			
		||||
            return None
 | 
			
		||||
            raise RuntimeError("wrong packet frame start")
 | 
			
		||||
 | 
			
		||||
        # read direction and length
 | 
			
		||||
        packet += self.ser.read(3)
 | 
			
		||||
        if packet[2] != self.PACKET_MCU[0]:
 | 
			
		||||
            print("Wrong direction (%s), discarding packet!" %
 | 
			
		||||
                  hex(packet[3]), file=sys.stderr)
 | 
			
		||||
            self.dump_packet(packet)
 | 
			
		||||
            return None
 | 
			
		||||
            raise RuntimeError("wrong packet direction magic")
 | 
			
		||||
 | 
			
		||||
        # read packet data
 | 
			
		||||
        packet_len, = struct.unpack(">H", packet[3:5])
 | 
			
		||||
@ -264,19 +283,15 @@ class Stc12Protocol:
 | 
			
		||||
 | 
			
		||||
        # verify end code
 | 
			
		||||
        if packet[packet_len+1] != self.PACKET_END[0]:
 | 
			
		||||
            print("Wrong end code (%s), discarding packet!" %
 | 
			
		||||
                  hex(packet[packet_len+1]), file=sys.stderr)
 | 
			
		||||
            self.dump_packet(packet)
 | 
			
		||||
            return None
 | 
			
		||||
            raise RuntimeError("wrong packet frame end")
 | 
			
		||||
 | 
			
		||||
        # verify checksum
 | 
			
		||||
        packet_csum, = struct.unpack(">H", packet[packet_len-1:packet_len+1])
 | 
			
		||||
        calc_csum = sum(packet[2:packet_len-1]) & 0xffff
 | 
			
		||||
        if packet_csum != calc_csum:
 | 
			
		||||
            print("Wrong checksum (%s, expected %s), discarding packet!" %
 | 
			
		||||
                  (hex(packet_csum), hex(calc_csum)), file=sys.stderr)
 | 
			
		||||
            self.dump_packet(packet)
 | 
			
		||||
            return None
 | 
			
		||||
            raise RuntimeError("packet checksum mismatch")
 | 
			
		||||
 | 
			
		||||
        self.dump_packet(packet, receive=True)
 | 
			
		||||
 | 
			
		||||
@ -331,7 +346,8 @@ class Stc12Protocol:
 | 
			
		||||
        # baudrate is directly controlled by programming the MCU's BRT register
 | 
			
		||||
        brt = 256 - round((self.mcu_clock_hz) / (self.baud_transfer * 16))
 | 
			
		||||
        brt_csum = (2 * (256 - brt)) & 0xff
 | 
			
		||||
        baud_actual = (self.mcu_clock_hz) / (16 * (256 - brt))
 | 
			
		||||
        try: baud_actual = (self.mcu_clock_hz) / (16 * (256 - brt))
 | 
			
		||||
        except ZeroDivisionError: raise RuntimeError("baudrate too high")
 | 
			
		||||
        baud_error = (abs(self.baud_transfer - baud_actual) * 100.0) / self.baud_transfer
 | 
			
		||||
        if baud_error > 5.0:
 | 
			
		||||
            print("WARNING: baud rate error is %.2f%%. You may need to set a slower rate." %
 | 
			
		||||
@ -380,16 +396,13 @@ class Stc12Protocol:
 | 
			
		||||
 | 
			
		||||
        # read status packet
 | 
			
		||||
        status_packet = self.read_packet()
 | 
			
		||||
        if status_packet == None or status_packet[0] != 0x50:
 | 
			
		||||
            print("Error receiving status packet, aborting!", file=sys.stderr)
 | 
			
		||||
            return False
 | 
			
		||||
        if status_packet[0] != 0x50:
 | 
			
		||||
            raise RuntimeError("wrong magic in status packet")
 | 
			
		||||
        self.decode_status_packet(status_packet)
 | 
			
		||||
        self.print_mcu_info()
 | 
			
		||||
        self.options = Stc12Option(status_packet[23:27])
 | 
			
		||||
        self.options.print()
 | 
			
		||||
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def handshake(self):
 | 
			
		||||
        """Do baudrate handshake
 | 
			
		||||
 | 
			
		||||
@ -404,9 +417,8 @@ class Stc12Protocol:
 | 
			
		||||
        packet += struct.pack(">H", self.mcu_magic)
 | 
			
		||||
        self.write_packet(packet)
 | 
			
		||||
        response = self.read_packet()
 | 
			
		||||
        if response == None or response[0] != 0x8f:
 | 
			
		||||
            print("Error receiving handshake packet, aborting!", file=sys.stderr)
 | 
			
		||||
            return False
 | 
			
		||||
        if response[0] != 0x8f:
 | 
			
		||||
            raise RuntimeError("wrong magic in handshake packet")
 | 
			
		||||
 | 
			
		||||
        # test new settings
 | 
			
		||||
        print("testing...", end="")
 | 
			
		||||
@ -417,9 +429,8 @@ class Stc12Protocol:
 | 
			
		||||
        self.ser.baudrate = self.baud_transfer
 | 
			
		||||
        response = self.read_packet()
 | 
			
		||||
        self.ser.baudrate = self.baud_handshake
 | 
			
		||||
        if response == None or response[0] != 0x8f:
 | 
			
		||||
            print("Error receiving handshake packet, aborting!", file=sys.stderr)
 | 
			
		||||
            return False
 | 
			
		||||
        if response[0] != 0x8f:
 | 
			
		||||
            raise RuntimeError("wrong magic in handshake packet")
 | 
			
		||||
 | 
			
		||||
        # switch to the settings
 | 
			
		||||
        print("setting...", end="")
 | 
			
		||||
@ -429,12 +440,10 @@ class Stc12Protocol:
 | 
			
		||||
        time.sleep(0.2)
 | 
			
		||||
        self.ser.baudrate = self.baud_transfer
 | 
			
		||||
        response = self.read_packet()
 | 
			
		||||
        if response == None or response[0] != 0x84:
 | 
			
		||||
            print("Error receiving handshake packet, aborting!", file=sys.stderr)
 | 
			
		||||
            return False
 | 
			
		||||
        print("done")
 | 
			
		||||
        if response[0] != 0x84:
 | 
			
		||||
            raise RuntimeError("wrong magic in handshake packet")
 | 
			
		||||
 | 
			
		||||
        return True
 | 
			
		||||
        print("done")
 | 
			
		||||
 | 
			
		||||
    def erase_flash(self, erase_size, flash_size):
 | 
			
		||||
        """Erase the MCU's flash memory.
 | 
			
		||||
@ -451,10 +460,8 @@ class Stc12Protocol:
 | 
			
		||||
        for i in range(0x80, 0x0d, -1): packet += bytes([i])
 | 
			
		||||
        self.write_packet(packet)
 | 
			
		||||
        response = self.read_packet()
 | 
			
		||||
        if response == None or response[0] != 0x00:
 | 
			
		||||
            print("Error receiving erase response, aborting!", file=sys.stderr)
 | 
			
		||||
            return False
 | 
			
		||||
        return True
 | 
			
		||||
        if response[0] != 0x00:
 | 
			
		||||
            raise RuntimeError("wrong magic in erase packet")
 | 
			
		||||
 | 
			
		||||
    def program_flash(self, addr, data):
 | 
			
		||||
        """Program the MCU's flash memory.
 | 
			
		||||
@ -474,12 +481,10 @@ class Stc12Protocol:
 | 
			
		||||
            csum = sum(packet[7:]) & 0xff
 | 
			
		||||
            self.write_packet(packet)
 | 
			
		||||
            response = self.read_packet()
 | 
			
		||||
            if response == None or response[0] != 0x00:
 | 
			
		||||
                print("Error receiving program response packet, aborting!", file=sys.stderr)
 | 
			
		||||
                return False
 | 
			
		||||
            if response[0] != 0x00:
 | 
			
		||||
                raise RuntimeError("wrong magic in write packet")
 | 
			
		||||
            elif response[1] != csum:
 | 
			
		||||
                print("Wrong checksum in program response (%s, expected %s), aborting!" %
 | 
			
		||||
                      (hex(response[1]), hex(csum)), file=sys.stderr)
 | 
			
		||||
                raise RuntimeError("verification checksum mismatch")
 | 
			
		||||
            print(".", end="")
 | 
			
		||||
            sys.stdout.flush()
 | 
			
		||||
        print()
 | 
			
		||||
@ -488,13 +493,10 @@ class Stc12Protocol:
 | 
			
		||||
        packet += struct.pack(">H", self.mcu_magic)
 | 
			
		||||
        self.write_packet(packet)
 | 
			
		||||
        response = self.read_packet()
 | 
			
		||||
        if response == None or response[0] != 0x8d:
 | 
			
		||||
            print("Error receiving program finish response packet, aborting!", file=sys.stderr)
 | 
			
		||||
            return False
 | 
			
		||||
        if response[0] != 0x8d:
 | 
			
		||||
            raise RuntimeError("wrong magic in finish packet")
 | 
			
		||||
        print("Finished writing flash!")
 | 
			
		||||
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def set_option(self, name, value):
 | 
			
		||||
        self.options.set_option(name, value)
 | 
			
		||||
 | 
			
		||||
@ -509,16 +511,13 @@ class Stc12Protocol:
 | 
			
		||||
        packet += struct.pack(">I", int(self.mcu_clock_hz))
 | 
			
		||||
        self.write_packet(packet)
 | 
			
		||||
        response = self.read_packet()
 | 
			
		||||
        if response == None or response[0] != 0x50:
 | 
			
		||||
            print("Error receiving set options response packet, aborting!", file=sys.stderr)
 | 
			
		||||
            return False
 | 
			
		||||
        if response[0] != 0x50:
 | 
			
		||||
            raise RuntimeError("wrong magic in option packet")
 | 
			
		||||
 | 
			
		||||
        print("Target UID: %02x%02x%02x%02x%02x%02x%02x" %
 | 
			
		||||
              (response[18], response[19], response[20], response[21],
 | 
			
		||||
               response[22], response[23], response[24]))
 | 
			
		||||
 | 
			
		||||
        return True
 | 
			
		||||
 | 
			
		||||
    def disconnect(self):
 | 
			
		||||
        """Disconnect from MCU"""
 | 
			
		||||
 | 
			
		||||
@ -535,31 +534,62 @@ class StcGal:
 | 
			
		||||
        self.opts = opts
 | 
			
		||||
        self.protocol = Stc12Protocol(opts.port, opts.handshake, opts.baud)
 | 
			
		||||
 | 
			
		||||
    def run(self):
 | 
			
		||||
        self.protocol.connect()
 | 
			
		||||
 | 
			
		||||
        if opts.binary:
 | 
			
		||||
            bindata = opts.binary.read()
 | 
			
		||||
 | 
			
		||||
            if opts.option:
 | 
			
		||||
                for o in opts.option:
 | 
			
		||||
    def emit_options(self, options):
 | 
			
		||||
        for o in options:
 | 
			
		||||
            try:
 | 
			
		||||
                k, v = o.split("=", 1)
 | 
			
		||||
                self.protocol.set_option(k, v)
 | 
			
		||||
            except ValueError as e:
 | 
			
		||||
                raise NameError("invalid option '%s' (%s)" % (k, e))
 | 
			
		||||
 | 
			
		||||
    def run(self):
 | 
			
		||||
        try: self.protocol.connect()
 | 
			
		||||
        except KeyboardInterrupt:
 | 
			
		||||
            print("interrupted")
 | 
			
		||||
            return 2
 | 
			
		||||
        except RuntimeError as e:
 | 
			
		||||
            print("Communication error: %s" % e, file=sys.stderr)
 | 
			
		||||
            return 1
 | 
			
		||||
        except serial.serialutil.SerialException as e:
 | 
			
		||||
            print("Serial communication error: %s" % e, file=sys.stderr)
 | 
			
		||||
            return 1
 | 
			
		||||
 | 
			
		||||
        if opts.binary:
 | 
			
		||||
            try:
 | 
			
		||||
                bindata = opts.binary.read()
 | 
			
		||||
 | 
			
		||||
                if opts.option: self.emit_options(opts.option)
 | 
			
		||||
 | 
			
		||||
                self.protocol.handshake()
 | 
			
		||||
                self.protocol.erase_flash(len(bindata), 0xf0 * 256)
 | 
			
		||||
                self.protocol.program_flash(0, bindata)
 | 
			
		||||
                self.protocol.program_options()
 | 
			
		||||
 | 
			
		||||
                self.protocol.disconnect()
 | 
			
		||||
                return 0
 | 
			
		||||
            except NameError as e:
 | 
			
		||||
                print("Option error: %s" % e, file=sys.stderr)
 | 
			
		||||
                self.protocol.disconnect()
 | 
			
		||||
                return 1
 | 
			
		||||
            except RuntimeError as e:
 | 
			
		||||
                print("Communication error: %s" % e, file=sys.stderr)
 | 
			
		||||
                self.protocol.disconnect()
 | 
			
		||||
                return 1
 | 
			
		||||
            except KeyboardInterrupt:
 | 
			
		||||
                print("interrupted")
 | 
			
		||||
                self.protocol.disconnect()
 | 
			
		||||
                return 2
 | 
			
		||||
            except serial.serialutil.SerialException as e:
 | 
			
		||||
                print("Serial communication error: %s" % e, file=sys.stderr)
 | 
			
		||||
                return 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
    # check arguments
 | 
			
		||||
    parser = argparse.ArgumentParser(description="STC10/11/12 series MCU ISP flash tool")
 | 
			
		||||
    parser.add_argument("binary", help="binary file to flash", type=argparse.FileType("rb"), nargs='?')
 | 
			
		||||
    parser.add_argument("-p", "--port", help="serial port device", default="/dev/ttyUSB0")
 | 
			
		||||
    parser.add_argument("-b", "--baud", help="transfer baud rate (default: 19200)", type=int, default=19200)
 | 
			
		||||
    parser.add_argument("-l", "--handshake", help="handshake baud rate (default: 2400)", type=int, default=2400)
 | 
			
		||||
    parser.add_argument("-b", "--baud", help="transfer baud rate (default: 19200)", type=BaudType(), default=19200)
 | 
			
		||||
    parser.add_argument("-l", "--handshake", help="handshake baud rate (default: 2400)", type=BaudType(), default=2400)
 | 
			
		||||
    parser.add_argument("-o", "--option", help="set option (can be used multiple times)", action="append")
 | 
			
		||||
    opts = parser.parse_args()
 | 
			
		||||
    
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user