2 Commits

Author SHA1 Message Date
a414bfb864 protocols.py: Increase clock_safety_factor to 2.5
This fixes cp2102 and ch341a baudswitch on mac sierra

Signed-off-by: Andrew 'ncrmnt' Andrianov <andrew@ncrmnt.org>
2017-10-14 21:08:26 +03:00
68d19f7b88 Use calculated delays
Some serial drivers don't handle draining the transmit buffer
correctly. This has been handled with a long delay so far, which might
be problematic. There's a race condition with some protocol versions.

Until STC15, the baud rate switch is initiated with a command sent by
stcgal, which is replied to by the MCU with the new baud rate. So the
switch of the baud rate has to be done after the command has finished
transmission, but before the MCU has started to transmit the response.

This change calculates the minimum delay needed (with some tolerance
added) so that it's unlikely that the baud rate switch will happen
too late.
2017-10-10 22:11:07 +02:00
18 changed files with 231 additions and 451 deletions

8
.gitignore vendored
View File

@ -1,12 +1,8 @@
*~
*.pyc
*.egg-info
*.eggs/
*.pybuild/
__pycache__/
__pycache__
/build
/dist
/deb_dist
/debian/stcgal*
/debian/files
/.vscode
/.vscode

View File

@ -27,7 +27,7 @@ from setuptools import setup, find_packages
setup(
name = "stcgal",
version = stcgal.__version__,
packages = find_packages(exclude=["doc", "tests"]),
packages = find_packages(exclude=["doc", "test"]),
install_requires = ["pyserial"],
extras_require = {
"usb": ["pyusb>=1.0.0"]
@ -55,6 +55,6 @@ setup(
"Topic :: Software Development :: Embedded Systems",
"Topic :: Software Development",
],
test_suite = "tests",
test_suite = "test",
tests_require = ["PyYAML"],
)

View File

@ -32,10 +32,7 @@ class StcGal:
def __init__(self, opts):
self.opts = opts
self.initialize_protocol(opts)
def initialize_protocol(self, opts):
"""Initialize protocol backend"""
if opts.protocol == "stc89":
self.protocol = Stc89Protocol(opts.port, opts.handshake, opts.baud)
elif opts.protocol == "stc12a":
@ -53,7 +50,8 @@ class StcGal:
elif opts.protocol == "usb15":
self.protocol = StcUsb15Protocol()
else:
self.protocol = StcAutoProtocol(opts.port, opts.handshake, opts.baud)
self.protocol = StcBaseProtocol(opts.port, opts.handshake, opts.baud)
self.protocol.debug = opts.debug
def emit_options(self, options):
@ -135,14 +133,14 @@ class StcGal:
try:
self.protocol.connect(autoreset=self.opts.autoreset, resetcmd=self.opts.resetcmd)
if isinstance(self.protocol, StcAutoProtocol):
if self.opts.protocol == "auto":
if not self.protocol.protocol_name:
raise StcProtocolException("cannot detect protocol")
base_protocol = self.protocol
self.opts.protocol = self.protocol.protocol_name
print("Protocol detected: %s" % self.opts.protocol)
# recreate self.protocol with proper protocol class
self.initialize_protocol(self.opts)
self.__init__(self.opts)
else:
base_protocol = None

View File

@ -5,214 +5,201 @@
import struct
import codecs
class IHex(object):
@classmethod
def read(cls, lines):
ihex = cls()
class IHex:
"""Intel HEX parser and writer"""
segbase = 0
for line in lines:
line = line.strip()
if not line: continue
@classmethod
def read(cls, lines):
"""Read Intel HEX data from string or lines"""
ihex = cls()
t, a, d = ihex.parse_line(line)
if t == 0x00:
ihex.insert_data(segbase + a, d)
segbase = 0
for line in lines:
line = line.strip()
if not line:
continue
elif t == 0x01:
break # Should we check for garbage after this?
t, a, d = ihex.parse_line(line)
if t == 0x00:
ihex.insert_data(segbase + a, d)
elif t == 0x02:
ihex.set_mode(16)
segbase = struct.unpack(">H", d[0:2])[0] << 4
elif t == 0x01:
break # Should we check for garbage after this?
elif t == 0x03:
ihex.set_mode(16)
elif t == 0x02:
ihex.set_mode(16)
segbase = struct.unpack(">H", d[0:2])[0] << 4
cs, ip = struct.unpack(">2H", d[0:2])
ihex.set_start((cs, ip))
elif t == 0x03:
ihex.set_mode(16)
elif t == 0x04:
ihex.set_mode(32)
segbase = struct.unpack(">H", d[0:2])[0] << 16
cs, ip = struct.unpack(">2H", d[0:2])
ihex.set_start((cs, ip))
elif t == 0x05:
ihex.set_mode(32)
ihex.set_start(struct.unpack(">I", d[0:4])[0])
elif t == 0x04:
ihex.set_mode(32)
segbase = struct.unpack(">H", d[0:2])[0] << 16
else:
raise ValueError("Invalid type byte")
elif t == 0x05:
ihex.set_mode(32)
ihex.set_start(struct.unpack(">I", d[0:4])[0])
return ihex
else:
raise ValueError("Invalid type byte")
@classmethod
def read_file(cls, fname):
f = open(fname, "rb")
ihex = cls.read(f)
f.close()
return ihex
return ihex
def __init__(self):
self.areas = {}
self.start = None
self.mode = 8
self.row_bytes = 16
@classmethod
def read_file(cls, fname):
"""Read Intel HEX data from file"""
f = open(fname, "rb")
ihex = cls.read(f)
f.close()
return ihex
def set_row_bytes(self, row_bytes):
"""Set output hex file row width (bytes represented per row)."""
if row_bytes < 1 or row_bytes > 0xff:
raise ValueError("Value out of range: (%r)" % row_bytes)
self.row_bytes = row_bytes
def extract_data(self, start=None, end=None):
if start is None:
start = 0
if end is None:
result = bytearray()
for addr, data in self.areas.items():
if addr >= start:
if len(result) < (addr - start):
result[len(result):addr-start] = bytes(addr-start-len(result))
result[addr-start:addr-start+len(data)] = data
return bytes(result)
else:
result = bytearray()
for addr, data in self.areas.items():
if addr >= start and addr < end:
data = data[:end-addr]
if len(result) < (addr - start):
result[len(result):addr-start] = bytes(addr-start-len(result))
result[addr-start:addr-start+len(data)] = data
return bytes(result)
def set_start(self, start=None):
self.start = start
def __init__(self):
self.areas = {}
self.start = None
self.mode = 8
self.row_bytes = 16
def set_mode(self, mode):
self.mode = mode
def set_row_bytes(self, row_bytes):
"""Set output hex file row width (bytes represented per row)."""
if row_bytes < 1 or row_bytes > 0xff:
raise ValueError("Value out of range: (%r)" % row_bytes)
self.row_bytes = row_bytes
def get_area(self, addr):
for start, data in self.areas.items():
end = start + len(data)
if addr >= start and addr <= end:
return start
def extract_data(self, start=None, end=None):
"""Extract binary data"""
if start is None:
start = 0
return None
if end is None:
result = bytearray()
def insert_data(self, istart, idata):
iend = istart + len(idata)
for addr, data in self.areas.items():
if addr >= start:
if len(result) < (addr - start):
result[len(result):addr - start] = bytes(
addr - start - len(result))
result[addr - start:addr - start + len(data)] = data
area = self.get_area(istart)
if area is None:
self.areas[istart] = idata
return bytes(result)
else:
data = self.areas[area]
# istart - iend + len(idata) + len(data)
self.areas[area] = data[:istart-area] + idata + data[iend-area:]
else:
result = bytearray()
def calc_checksum(self, bytes):
total = sum(bytes)
return (-total) & 0xFF
for addr, data in self.areas.items():
if addr >= start and addr < end:
data = data[:end - addr]
if len(result) < (addr - start):
result[len(result):addr - start] = bytes(
addr - start - len(result))
result[addr - start:addr - start + len(data)] = data
def parse_line(self, rawline):
if rawline[0:1] != b":":
raise ValueError("Invalid line start character (%r)" % rawline[0])
return bytes(result)
try:
#line = rawline[1:].decode("hex")
line = codecs.decode(rawline[1:], "hex_codec")
except:
raise ValueError("Invalid hex data")
def set_start(self, start=None):
self.start = start
length, addr, type = struct.unpack(">BHB", line[:4])
def set_mode(self, mode):
self.mode = mode
dataend = length + 4
data = line[4:dataend]
def get_area(self, addr):
for start, data in self.areas.items():
end = start + len(data)
if addr >= start and addr <= end:
return start
#~ print line[dataend:dataend + 2], repr(line)
cs1 = line[dataend]
cs2 = self.calc_checksum(line[:dataend])
return None
if cs1 != cs2:
raise ValueError("Checksums do not match")
def insert_data(self, istart, idata):
iend = istart + len(idata)
return (type, addr, data)
area = self.get_area(istart)
if area is None:
self.areas[istart] = idata
def make_line(self, type, addr, data):
line = struct.pack(">BHB", len(data), addr, type)
line += data
line += chr(self.calc_checksum(line))
#~ return ":" + line.encode("hex")
return ":" + line.encode("hex").upper() + "\r\n"
else:
data = self.areas[area]
# istart - iend + len(idata) + len(data)
self.areas[area] = data[
:istart - area] + idata + data[iend - area:]
def write(self):
output = ""
for start, data in sorted(self.areas.items()):
i = 0
segbase = 0
def calc_checksum(self, data):
total = sum(data)
return (-total) & 0xFF
while i < len(data):
chunk = data[i:i + self.row_bytes]
def parse_line(self, rawline):
if rawline[0:1] != b":":
raise ValueError("Invalid line start character (%r)" % rawline[0])
addr = start
newsegbase = segbase
try:
line = codecs.decode(rawline[1:], "hex_codec")
except:
raise ValueError("Invalid hex data")
if self.mode == 8:
addr = addr & 0xFFFF
length, addr, line_type = struct.unpack(">BHB", line[:4])
elif self.mode == 16:
t = addr & 0xFFFF
newsegbase = (addr - t) >> 4
addr = t
dataend = length + 4
data = line[4:dataend]
if newsegbase != segbase:
output += self.make_line(0x02, 0, struct.pack(">H", newsegbase))
segbase = newsegbase
cs1 = line[dataend]
cs2 = self.calc_checksum(line[:dataend])
elif self.mode == 32:
newsegbase = addr >> 16
addr = addr & 0xFFFF
if cs1 != cs2:
raise ValueError("Checksums do not match")
if newsegbase != segbase:
output += self.make_line(0x04, 0, struct.pack(">H", newsegbase))
segbase = newsegbase
return (line_type, addr, data)
output += self.make_line(0x00, addr, chunk)
def make_line(self, line_type, addr, data):
line = struct.pack(">BHB", len(data), addr, line_type)
line += data
line += chr(self.calc_checksum(line))
return ":" + line.encode("hex").upper() + "\r\n"
i += self.row_bytes
start += self.row_bytes
def write(self):
"""Write Intel HEX data to string"""
output = ""
if self.start is not None:
if self.mode == 16:
output += self.make_line(0x03, 0, struct.pack(">2H", self.start[0], self.start[1]))
elif self.mode == 32:
output += self.make_line(0x05, 0, struct.pack(">I", self.start))
for start, data in sorted(self.areas.items()):
i = 0
segbase = 0
output += self.make_line(0x01, 0, "")
return output
while i < len(data):
chunk = data[i:i + self.row_bytes]
addr = start
newsegbase = segbase
if self.mode == 8:
addr = addr & 0xFFFF
elif self.mode == 16:
t = addr & 0xFFFF
newsegbase = (addr - t) >> 4
addr = t
if newsegbase != segbase:
output += self.make_line(
0x02, 0, struct.pack(">H", newsegbase))
segbase = newsegbase
elif self.mode == 32:
newsegbase = addr >> 16
addr = addr & 0xFFFF
if newsegbase != segbase:
output += self.make_line(
0x04, 0, struct.pack(">H", newsegbase))
segbase = newsegbase
output += self.make_line(0x00, addr, chunk)
i += self.row_bytes
start += self.row_bytes
if self.start is not None:
if self.mode == 16:
output += self.make_line(
0x03, 0, struct.pack(">2H", self.start[0], self.start[1]))
elif self.mode == 32:
output += self.make_line(
0x05, 0, struct.pack(">I", self.start))
output += self.make_line(0x01, 0, "")
return output
def write_file(self, fname):
"""Write Intel HEX data to file"""
f = open(fname, "w")
f.write(self.write())
f.close()
def write_file(self, fname):
f = open(fname, "w")
f.write(self.write())
f.close()

View File

@ -21,24 +21,15 @@
#
import struct
from abc import ABC
from stcgal.utils import Utils
class BaseOption(ABC):
"""Base class for options"""
def __init__(self):
self.options = ()
self.msr = None
class BaseOption:
def print(self):
"""Print current configuration to standard output"""
print("Target options:")
for name, get_func, _ in self.options:
print(" %s=%s" % (name, get_func()))
def set_option(self, name, value):
"""Set value of a specific option"""
for opt, _, set_func in self.options:
if opt == name:
print("Option %s=%s" % (name, value))
@ -47,14 +38,12 @@ class BaseOption(ABC):
raise ValueError("unknown")
def get_option(self, name):
"""Get option value for a specific option"""
for opt, get_func, _ in self.options:
if opt == name:
return get_func(name)
raise ValueError("unknown")
def get_msr(self):
"""Get array of model-specific configuration registers"""
return bytes(self.msr)
@ -62,7 +51,6 @@ class Stc89Option(BaseOption):
"""Manipulation STC89 series option byte"""
def __init__(self, msr):
super().__init__()
self.msr = msr
self.options = (
("cpu_6t_enabled", self.get_t6, self.set_t6),
@ -141,7 +129,6 @@ class Stc12AOption(BaseOption):
"""Manipulate STC12A series option bytes"""
def __init__(self, msr):
super().__init__()
assert len(msr) == 4
self.msr = bytearray(msr)
@ -226,7 +213,6 @@ class Stc12Option(BaseOption):
"""Manipulate STC10/11/12 series option bytes"""
def __init__(self, msr):
super().__init__()
assert len(msr) == 4
self.msr = bytearray(msr)
@ -351,7 +337,6 @@ class Stc12Option(BaseOption):
class Stc15AOption(BaseOption):
def __init__(self, msr):
super().__init__()
assert len(msr) == 13
self.msr = bytearray(msr)
@ -450,7 +435,6 @@ class Stc15AOption(BaseOption):
class Stc15Option(BaseOption):
def __init__(self, msr):
super().__init__()
assert len(msr) >= 4
self.msr = bytearray(msr)

View File

@ -21,18 +21,12 @@
#
import serial
import sys
import os
import time
import struct
import re
import errno
import sys, os, time, struct, re, errno
import argparse
import collections
from stcgal.models import MCUModelDatabase
from stcgal.utils import Utils
from stcgal.options import Stc89Option, Stc12Option, Stc12AOption, Stc15Option, Stc15AOption
from abc import ABC, abstractmethod
from stcgal.options import *
import functools
try:
@ -52,7 +46,7 @@ class StcProtocolException(Exception):
pass
class StcBaseProtocol(ABC):
class StcBaseProtocol:
"""Basic functionality for STC BSL protocols"""
"""magic word that starts a packet"""
@ -109,10 +103,6 @@ class StcBaseProtocol(ABC):
return packet[5:-1]
@abstractmethod
def write_packet(self, packet_data):
pass
def read_packet(self):
"""Read and check packet from MCU.
@ -202,6 +192,20 @@ class StcBaseProtocol(ABC):
mcu_name += "E" if self.status_packet[17] < 0x70 else "W"
self.model = self.model._replace(name = mcu_name)
protocol_database = [("stc89", r"STC(89|90)(C|LE)\d"),
("stc12a", r"STC12(C|LE)\d052"),
("stc12b", r"STC12(C|LE)(52|56)"),
("stc12", r"(STC|IAP)(10|11|12)\D"),
("stc15a", r"(STC|IAP)15[FL][01]0\d(E|EA|)$"),
("stc15", r"(STC|IAP|IRC)15\D")]
for protocol_name, pattern in protocol_database:
if re.match(pattern, self.model.name):
self.protocol_name = protocol_name
break
else:
self.protocol_name = None
def get_status_packet(self):
"""Read and decode status packet"""
@ -234,6 +238,20 @@ class StcBaseProtocol(ABC):
return iap_wait
def delay_safely_written(self, length):
"""
Delay until data has been safely written and sent to device.
Some buggy serial drivers don't implement tcdrain/flush correctly.
That is, they wait until all data has been written to USB, but they
do not wait until the data has actually finished transmission.
Add additional delay to work around.
"""
bit_time = 1.0 / self.ser.baudrate
byte_time = bit_time * 11.0 # start, 8 data bits, stop, parity
clock_safety_factor = 2.5 # additional delay in case clock is slow
time.sleep(byte_time * length * clock_safety_factor)
def set_option(self, name, value):
self.options.set_option(name, value)
@ -283,8 +301,6 @@ class StcBaseProtocol(ABC):
try:
self.pulse()
self.status_packet = self.get_status_packet()
if len(self.status_packet) < 23:
raise StcProtocolException("status packet too short")
except (StcFramingException, serial.SerialTimeoutException): pass
print("done")
@ -294,21 +310,7 @@ class StcBaseProtocol(ABC):
self.initialize_model()
@abstractmethod
def initialize_status(self, status_packet):
"""Initialize internal state from status packet"""
pass
@abstractmethod
def initialize_options(self, status_packet):
"""Initialize options from status packet"""
pass
def initialize(self, base_protocol=None):
"""
Initialize from another instance. This is an alternative for calling
connect() and is used by protocol autodetection.
"""
def initialize(self, base_protocol = None):
if base_protocol:
self.ser = base_protocol.ser
self.ser.parity = self.PARITY
@ -336,39 +338,6 @@ class StcBaseProtocol(ABC):
print("Disconnected!")
class StcAutoProtocol(StcBaseProtocol):
"""
Protocol handler for autodetection of protocols. Does not implement full
functionality for any device class.
"""
def initialize_model(self):
super().initialize_model()
protocol_database = [("stc89", r"STC(89|90)(C|LE)\d"),
("stc12a", r"STC12(C|LE)\d052"),
("stc12b", r"STC12(C|LE)(52|56)"),
("stc12", r"(STC|IAP)(10|11|12)\D"),
("stc15a", r"(STC|IAP)15[FL][01]0\d(E|EA|)$"),
("stc15", r"(STC|IAP|IRC)15\D")]
for protocol_name, pattern in protocol_database:
if re.match(pattern, self.model.name):
self.protocol_name = protocol_name
break
else:
self.protocol_name = None
def initialize_options(self, status_packet):
raise NotImplementedError
def initialize_status(self, status_packet):
raise NotImplementedError
def write_packet(self, packet_data):
raise NotImplementedError
class Stc89Protocol(StcBaseProtocol):
"""Protocol handler for STC 89/90 series"""
@ -429,9 +398,6 @@ class Stc89Protocol(StcBaseProtocol):
def initialize_options(self, status_packet):
"""Initialize options"""
if len(status_packet) < 20:
raise StcProtocolException("invalid options in status packet")
self.options = Stc89Option(status_packet[19])
self.options.print()
@ -496,7 +462,7 @@ class Stc89Protocol(StcBaseProtocol):
packet += struct.pack(">H", brt)
packet += bytes([0xff - (brt >> 8), brt_csum, delay, iap])
self.write_packet(packet)
time.sleep(0.2)
self.delay_safely_written(len(packet))
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
self.ser.baudrate = self.baud_handshake
@ -510,7 +476,7 @@ class Stc89Protocol(StcBaseProtocol):
packet += struct.pack(">H", brt)
packet += bytes([0xff - (brt >> 8), brt_csum, delay])
self.write_packet(packet)
time.sleep(0.2)
self.delay_safely_written(len(packet))
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
if response[0] != 0x8e:
@ -564,9 +530,9 @@ class Stc89Protocol(StcBaseProtocol):
csum = sum(packet[7:]) & 0xff
self.write_packet(packet)
response = self.read_packet()
if len(response) < 1 or response[0] != 0x80:
if response[0] != 0x80:
raise StcProtocolException("incorrect magic in write packet")
elif len(response) < 2 or response[1] != csum:
elif response[1] != csum:
raise StcProtocolException("verification checksum mismatch")
print(".", end="")
sys.stdout.flush()
@ -668,9 +634,6 @@ class Stc12AProtocol(Stc12AOptionsMixIn, Stc89Protocol):
def initialize_options(self, status_packet):
"""Initialize options"""
if len(status_packet) < 31:
raise StcProtocolException("invalid options in status packet")
# create option state
self.options = Stc12AOption(status_packet[23:26] + status_packet[29:30])
self.options.print()
@ -689,7 +652,7 @@ class Stc12AProtocol(Stc12AOptionsMixIn, Stc89Protocol):
sys.stdout.flush()
packet = bytes([0x8f, 0xc0, brt, 0x3f, brt_csum, delay, iap])
self.write_packet(packet)
time.sleep(0.2)
self.delay_safely_written(len(packet))
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
self.ser.baudrate = self.baud_handshake
@ -701,7 +664,7 @@ class Stc12AProtocol(Stc12AOptionsMixIn, Stc89Protocol):
sys.stdout.flush()
packet = bytes([0x8e, 0xc0, brt, 0x3f, brt_csum, delay])
self.write_packet(packet)
time.sleep(0.2)
self.delay_safely_written(len(packet))
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
if response[0] != 0x8e:
@ -860,9 +823,6 @@ class Stc12BaseProtocol(StcBaseProtocol):
def initialize_options(self, status_packet):
"""Initialize options"""
if len(status_packet) < 29:
raise StcProtocolException("invalid options in status packet")
# create option state
self.options = Stc12Option(status_packet[23:26] + status_packet[27:28])
self.options.print()
@ -889,7 +849,7 @@ class Stc12BaseProtocol(StcBaseProtocol):
sys.stdout.flush()
packet = bytes([0x8f, 0xc0, brt, 0x3f, brt_csum, delay, iap])
self.write_packet(packet)
time.sleep(0.2)
self.delay_safely_written(len(packet))
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
self.ser.baudrate = self.baud_handshake
@ -901,7 +861,7 @@ class Stc12BaseProtocol(StcBaseProtocol):
sys.stdout.flush()
packet = bytes([0x8e, 0xc0, brt, 0x3f, brt_csum, delay])
self.write_packet(packet)
time.sleep(0.2)
self.delay_safely_written(len(packet))
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
if response[0] != 0x84:
@ -997,9 +957,6 @@ class Stc15AProtocol(Stc12Protocol):
def initialize_options(self, status_packet):
"""Initialize options"""
if len(status_packet) < 37:
raise StcProtocolException("invalid options in status packet")
# create option state
self.options = Stc15AOption(status_packet[23:36])
self.options.print()
@ -1112,19 +1069,15 @@ class Stc15AProtocol(Stc12Protocol):
self.write_packet(packet)
self.pulse(timeout=1.0)
response = self.read_packet()
if len(response) < 36 or response[0] != 0x65:
if response[0] != 0x65:
raise StcProtocolException("incorrect magic in handshake packet")
# determine programming speed trim value
target_trim_a, target_count_a = struct.unpack(">HH", response[28:32])
target_trim_b, target_count_b = struct.unpack(">HH", response[32:36])
if target_count_a == target_count_b:
raise StcProtocolException("frequency trimming failed")
m = (target_trim_b - target_trim_a) / (target_count_b - target_count_a)
n = target_trim_a - m * target_count_a
program_trim = round(m * program_count + n)
if program_trim > 65535 or program_trim < 0:
raise StcProtocolException("frequency trimming failed")
# determine trim trials for second round
trim_a, count_a = struct.unpack(">HH", response[12:16])
@ -1143,14 +1096,10 @@ class Stc15AProtocol(Stc12Protocol):
target_count_a = count_a
target_count_b = count_b
# linear interpolate to find range to try next
if target_count_a == target_count_b:
raise StcProtocolException("frequency trimming failed")
m = (target_trim_b - target_trim_a) / (target_count_b - target_count_a)
n = target_trim_a - m * target_count_a
target_trim = round(m * user_count + n)
target_trim_start = min(max(target_trim - 5, target_trim_a), target_trim_b)
if target_trim_start + 11 > 65535 or target_trim_start < 0:
raise StcProtocolException("frequency trimming failed")
# trim challenge-response, second round
packet = bytes([0x65])
@ -1162,7 +1111,7 @@ class Stc15AProtocol(Stc12Protocol):
self.write_packet(packet)
self.pulse(timeout=1.0)
response = self.read_packet()
if len(response) < 56 or response[0] != 0x65:
if response[0] != 0x65:
raise StcProtocolException("incorrect magic in handshake packet")
# determine best trim value
@ -1186,7 +1135,7 @@ class Stc15AProtocol(Stc12Protocol):
packet += struct.pack(">B", 230400 // self.baud_transfer)
packet += bytes([0xa1, 0x64, 0xb8, 0x00, iap_wait, 0x20, 0xff, 0x00])
self.write_packet(packet)
time.sleep(0.2)
self.delay_safely_written(len(packet))
self.ser.baudrate = self.baud_transfer
response = self.read_packet()
if response[0] != 0x84:
@ -1221,11 +1170,7 @@ class Stc15Protocol(Stc15AProtocol):
def initialize_options(self, status_packet):
"""Initialize options"""
if len(status_packet) < 14:
raise StcProtocolException("invalid options in status packet")
# create option state
# XXX: check how option bytes are concatenated here
self.options = Stc15Option(status_packet[5:8] + status_packet[12:13] + status_packet[37:38])
self.options.print()
@ -1270,8 +1215,6 @@ class Stc15Protocol(Stc15AProtocol):
calib_data = response[2:]
challenge_data = packet[2:]
calib_len = response[1]
if len(calib_data) < 2 * calib_len:
raise StcProtocolException("range calibration data missing")
for i in range(calib_len - 1):
count_a, count_b = struct.unpack(">HH", calib_data[2*i:2*i+4])
@ -1281,8 +1224,6 @@ class Stc15Protocol(Stc15AProtocol):
m = (trim_b - trim_a) / (count_b - count_a)
n = trim_a - m * count_a
target_trim = round(m * target_count + n)
if target_trim > 65536 or target_trim < 0:
raise StcProtocolException("frequency trimming failed")
return (target_trim, trim_range)
return None
@ -1294,8 +1235,6 @@ class Stc15Protocol(Stc15AProtocol):
calib_data = response[2:]
challenge_data = packet[2:]
calib_len = response[1]
if len(calib_data) < 2 * calib_len:
raise StcProtocolException("trim calibration data missing")
best = None
best_count = sys.maxsize
@ -1306,9 +1245,6 @@ class Stc15Protocol(Stc15AProtocol):
best_count = abs(count - target_count)
best = (trim_adj, trim_range), count
if not best:
raise StcProtocolException("frequency trimming failed")
return best
def calibrate(self):
@ -1338,7 +1274,7 @@ class Stc15Protocol(Stc15AProtocol):
self.write_packet(packet)
self.pulse(b"\xfe", timeout=1.0)
response = self.read_packet()
if len(response) < 2 or response[0] != 0x00:
if response[0] != 0x00:
raise StcProtocolException("incorrect magic in handshake packet")
# select ranges and trim values
@ -1357,7 +1293,7 @@ class Stc15Protocol(Stc15AProtocol):
self.write_packet(packet)
self.pulse(b"\xfe", timeout=1.0)
response = self.read_packet()
if len(response) < 2 or response[0] != 0x00:
if response[0] != 0x00:
raise StcProtocolException("incorrect magic in handshake packet")
# select final values
@ -1383,9 +1319,8 @@ class Stc15Protocol(Stc15AProtocol):
packet += bytes([iap_wait])
self.write_packet(packet)
response = self.read_packet()
if len(response) < 1 or response[0] != 0x01:
if response[0] != 0x01:
raise StcProtocolException("incorrect magic in handshake packet")
time.sleep(0.2)
self.ser.baudrate = self.baud_transfer
def switch_baud_ext(self):
@ -1400,9 +1335,8 @@ class Stc15Protocol(Stc15AProtocol):
packet += bytes([0x00, 0x00, iap_wait])
self.write_packet(packet)
response = self.read_packet()
if len(response) < 1 or response[0] != 0x01:
if response[0] != 0x01:
raise StcProtocolException("incorrect magic in handshake packet")
time.sleep(0.2)
self.ser.baudrate = self.baud_transfer
# for switching back to RC, program factory values
@ -1426,9 +1360,9 @@ class Stc15Protocol(Stc15AProtocol):
packet += bytes([0x00, 0x00, 0x5a, 0xa5])
self.write_packet(packet)
response = self.read_packet()
if len(response) == 1 and response[0] == 0x0f:
if response[0] == 0x0f:
raise StcProtocolException("MCU is locked")
if len(response) < 1 or response[0] != 0x05:
if response[0] != 0x05:
raise StcProtocolException("incorrect magic in handshake packet")
print("done")
@ -1449,17 +1383,13 @@ class Stc15Protocol(Stc15AProtocol):
packet += bytes([0x00, 0x5a, 0xa5])
self.write_packet(packet)
response = self.read_packet()
if len(response) < 1 or response[0] != 0x03:
if response[0] != 0x03:
raise StcProtocolException("incorrect magic in handshake packet")
print("done")
if len(response) >= 8:
self.uid = response[1:8]
# we should have a UID at this point
if not self.uid:
raise StcProtocolException("UID is missing")
def program_flash(self, data):
"""Program the MCU's flash memory."""
@ -1474,7 +1404,7 @@ class Stc15Protocol(Stc15AProtocol):
while len(packet) < self.PROGRAM_BLOCKSIZE + 3: packet += b"\x00"
self.write_packet(packet)
response = self.read_packet()
if len(response) < 2 or response[0] != 0x02 or response[1] != 0x54:
if response[0] != 0x02 or response[1] != 0x54:
raise StcProtocolException("incorrect magic in write packet")
print(".", end="")
sys.stdout.flush()
@ -1487,7 +1417,7 @@ class Stc15Protocol(Stc15AProtocol):
packet = bytes([0x07, 0x00, 0x00, 0x5a, 0xa5])
self.write_packet(packet)
response = self.read_packet()
if len(response) < 2 or response[0] != 0x07 or response[1] != 0x54:
if response[0] != 0x07 or response[1] != 0x54:
raise StcProtocolException("incorrect magic in finish packet")
print("done")
@ -1526,7 +1456,7 @@ class Stc15Protocol(Stc15AProtocol):
packet += self.build_options()
self.write_packet(packet)
response = self.read_packet()
if len(response) < 2 or response[0] != 0x04 or response[1] != 0x54:
if response[0] != 0x04 or response[1] != 0x54:
raise StcProtocolException("incorrect magic in option packet")
print("done")

View File

@ -29,12 +29,14 @@ class Utils:
def to_bool(cls, val):
"""make sensible boolean from string or other type value"""
if not val:
if val is None:
return False
if isinstance(val, bool):
return val
elif isinstance(val, int):
return bool(val)
elif len(val) == 0:
return False
else:
return True if val[0].lower() == "t" or val[0] == "1" else False

View File

@ -57,7 +57,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout")
def test_program_stc89(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC89 protocol"""
self._program_yml("./tests/stc89c52rc.yml", serial_mock, read_mock)
self._program_yml("./test/stc89c52rc.yml", serial_mock, read_mock)
@patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet")
@ -66,7 +66,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout")
def test_program_stc12(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC12 protocol"""
self._program_yml("./tests/stc12c5a60s2.yml", serial_mock, read_mock)
self._program_yml("./test/stc12c5a60s2.yml", serial_mock, read_mock)
@patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet")
@ -75,7 +75,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout")
def test_program_stc12a(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC12A protocol"""
self._program_yml("./tests/stc12c2052ad.yml", serial_mock, read_mock)
self._program_yml("./test/stc12c2052ad.yml", serial_mock, read_mock)
def test_program_stc12b(self):
"""Test a programming cycle with STC12B protocol"""
@ -88,7 +88,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout")
def test_program_stc15f2(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC15 protocol, F2 series"""
self._program_yml("./tests/iap15f2k61s2.yml", serial_mock, read_mock)
self._program_yml("./test/iap15f2k61s2.yml", serial_mock, read_mock)
@patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet")
@ -97,7 +97,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout")
def test_program_stc15w4(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC15 protocol, W4 series"""
self._program_yml("./tests/stc15w4k56s4.yml", serial_mock, read_mock)
self._program_yml("./test/stc15w4k56s4.yml", serial_mock, read_mock)
@unittest.skip("trace is broken")
@patch("stcgal.protocols.StcBaseProtocol.read_packet")
@ -107,7 +107,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout")
def test_program_stc15a(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC15A protocol"""
self._program_yml("./tests/stc15f104e.yml", serial_mock, read_mock)
self._program_yml("./test/stc15f104e.yml", serial_mock, read_mock)
@patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet")
@ -116,7 +116,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout")
def test_program_stc15l1(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC15 protocol, L1 series"""
self._program_yml("./tests/stc15l104w.yml", serial_mock, read_mock)
self._program_yml("./test/stc15l104w.yml", serial_mock, read_mock)
def test_program_stc15w4_usb(self):
"""Test a programming cycle with STC15W4 USB protocol"""
@ -133,3 +133,4 @@ class ProgramTests(unittest.TestCase):
read_mock.side_effect = convert_to_bytes(test_data["responses"])
gal = stcgal.frontend.StcGal(opts)
self.assertEqual(gal.run(), 0)

View File

@ -24,6 +24,7 @@
import argparse
import unittest
from unittest.mock import patch
from stcgal.utils import Utils, BaudType
class TestUtils(unittest.TestCase):

View File

@ -1,119 +0,0 @@
#
# Copyright (c) 2017 Grigori Goronzy <greg@chown.ath.cx>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Tests with fuzzing of input data"""
import random
import sys
import unittest
from unittest.mock import patch
import yaml
import stcgal.frontend
import stcgal.protocols
from tests.test_program import get_default_opts, convert_to_bytes
class ByteArrayFuzzer:
"""Fuzzer for byte arrays"""
def __init__(self):
self.rng = random.Random()
self.cut_propability = 0.01 # probability for cutting off an array early
self.cut_min = 0 # minimum cut amount
self.cut_max = sys.maxsize # maximum cut amount
self.bitflip_probability = 0.0001 # probability for flipping a bit
self.randomize_probability = 0.001 # probability for randomizing a char
def fuzz(self, inp):
"""Fuzz an array of bytes according to predefined settings"""
arr = bytearray(inp)
arr = self.cut_off(arr)
self.randomize(arr)
return bytes(arr)
def randomize(self, arr):
"""Randomize array contents with bitflips and random bytes"""
for i, _ in enumerate(arr):
for j in range(8):
if self.rng.random() < self.bitflip_probability:
arr[i] ^= (1 << j)
if self.rng.random() < self.randomize_probability:
arr[i] = self.rng.getrandbits(8)
def cut_off(self, arr):
"""Cut off data from end of array"""
if self.rng.random() < self.cut_propability:
cut_limit = min(len(arr), self.cut_max)
cut_len = self.rng.randrange(self.cut_min, cut_limit)
arr = arr[0:len(arr) - cut_len]
return arr
class TestProgramFuzzed(unittest.TestCase):
"""Special programming cycle tests that use a fuzzing approach"""
@patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet")
@patch("stcgal.protocols.serial.Serial", autospec=True)
@patch("stcgal.protocols.time.sleep")
@patch("sys.stdout")
@patch("sys.stderr")
def test_program_fuzz(self, err, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test programming cycles with fuzzing enabled"""
yml = [
"./tests/iap15f2k61s2.yml",
"./tests/stc12c2052ad.yml",
"./tests/stc15w4k56s4.yml",
"./tests/stc12c5a60s2.yml",
"./tests/stc89c52rc.yml",
"./tests/stc15l104w.yml",
"./tests/stc15f104e.yml",
]
fuzzer = ByteArrayFuzzer()
fuzzer.cut_propability = 0.01
fuzzer.bitflip_probability = 0.005
fuzzer.rng = random.Random(1)
for y in yml:
with self.subTest(msg="trace {}".format(y)):
self.single_fuzz(y, serial_mock, fuzzer, read_mock, err, out,
sleep_mock, write_mock)
def single_fuzz(self, yml, serial_mock, fuzzer, read_mock, err, out, sleep_mock, write_mock):
"""Test a single programming cycle with fuzzing"""
with open(yml) as test_file:
test_data = yaml.load(test_file.read())
for _ in range(1000):
with self.subTest():
opts = get_default_opts()
opts.protocol = test_data["protocol"]
opts.code_image.read.return_value = bytes(test_data["code_data"])
serial_mock.return_value.inWaiting.return_value = 1
fuzzed_responses = []
for arr in convert_to_bytes(test_data["responses"]):
fuzzed_responses.append(fuzzer.fuzz(arr))
read_mock.side_effect = fuzzed_responses
gal = stcgal.frontend.StcGal(opts)
self.assertGreaterEqual(gal.run(), 0)
err.reset_mock()
out.reset_mock()
sleep_mock.reset_mock()
serial_mock.reset_mock()
write_mock.reset_mock()
read_mock.reset_mock()