18 Commits

Author SHA1 Message Date
1a5cf18590 debian: Update Build-Depends and Depends
This closes #32

Signed-off-by: Andrew Andrianov <andrew@ncrmnt.org>
2017-10-25 21:31:14 +03:00
a5e1cc26ee Merge pull request #31 from nekromant/progressbar
Implement progress callback and tqdm progressbar
2017-10-22 15:55:45 +02:00
b77157bc40 .travis.yml: Install tqdm to make ci happy
Signed-off-by: Andrew Andrianov <andrew@ncrmnt.org>
2017-10-19 11:26:42 +03:00
092fbdc842 protocols.py: Implement progress callback and tqdm progressbar
Signed-off-by: Andrew Andrianov <andrew@ncrmnt.org>
2017-10-19 11:26:27 +03:00
e0bda73fed Merge pull request #29 from grigorig/advanced-tests
Advanced tests
2017-10-18 23:22:40 +02:00
57100062af Rename test/ to tests/ 2017-10-12 23:02:02 +02:00
030497beb0 Extract StcAutoProtocol class, fix autodetection
With the introduction of real abstract classes, it is not possible
anymore to instantiate StcBaseProtocol. Instead, extract some of the
code for autodetection into the new class StcAutoProtocol and use
that for autodetection.
2017-10-12 23:02:02 +02:00
fd923f3a92 Cleanup utils
Just a tiny simplification, found by pylint.
2017-10-12 23:02:02 +02:00
b145fb364a Remove unneeded include 2017-10-12 23:02:02 +02:00
a29c9bf42e Add fuzzing programming cycle tests 2017-10-12 23:02:02 +02:00
1cde6da007 stc15: check that a UID has been received
Found by fuzzing. In some cases it's possible that we end up without
a valid UID. Detect and workaround.
2017-10-12 23:02:02 +02:00
ca30a508aa Fix various issues in frequency trimming
Found by fuzzing. The frequency trimming functions did a bad job of
checking for possible out of bounds accesses and didn't handle various
failure cases correctly. Add suitable checks to fix the issues found.

v2: fix one check, add several new ones
2017-10-12 23:01:50 +02:00
b9208c4772 Add length checks for status packets
Fuzzing found a number of issues when status packets are cut short.
Introduce checks on the length of status packets to fix these issues.
2017-10-11 23:20:20 +02:00
ad5a89297f Check length of responses
Fuzzing found lots of issues when packets are cut short. This should
rarely happen, but stcgal should be able to handle it without crashing.

This adds length checks when checking the magic of packets or when
checking checksums.
2017-10-11 23:20:20 +02:00
0cb56f4919 Use abc for StcBaseProtocol
Use the abc module to declare StcBaseProtocol as an abstract base
class and clean up imports while at it.
2017-10-11 23:20:20 +02:00
f195258eb5 Clean up options utilities
Use abc to declare an abstract base class and add some documentation.
2017-10-11 23:20:20 +02:00
ff9530833d Update gitignore 2017-10-11 23:19:11 +02:00
8b0fdcb42a Clean up Intel HEX utilities
No functional change intended.
2017-10-11 23:18:59 +02:00
20 changed files with 470 additions and 233 deletions

8
.gitignore vendored
View File

@ -1,8 +1,12 @@
*~ *~
*.pyc *.pyc
*.egg-info *.egg-info
__pycache__ *.eggs/
*.pybuild/
__pycache__/
/build /build
/dist /dist
/deb_dist /deb_dist
/.vscode /debian/stcgal*
/debian/files
/.vscode

View File

@ -11,7 +11,7 @@ python:
before_install: before_install:
- sudo apt install rpm dpkg-dev debhelper dh-python python3-setuptools fakeroot python3-serial python3-yaml - sudo apt install rpm dpkg-dev debhelper dh-python python3-setuptools fakeroot python3-serial python3-yaml
install: install:
- pip install pyserial pyusb - pip install pyserial pyusb tqdm
script: script:
- python setup.py build - python setup.py build
- python setup.py test - python setup.py test

4
debian/control vendored
View File

@ -2,14 +2,14 @@ Source: stcgal
Section: electronics Section: electronics
Priority: optional Priority: optional
Maintainer: Andrew Andrianov <andrew@ncrmnt.org> Maintainer: Andrew Andrianov <andrew@ncrmnt.org>
Build-Depends: debhelper (>= 9), python3, python3-setuptools, dh-python Build-Depends: debhelper (>= 9), python3, python3-setuptools, dh-python, python3-serial, python3-tqdm, python3-yaml
Standards-Version: 3.9.5 Standards-Version: 3.9.5
Homepage: https://github.com/grigorig/stcgal Homepage: https://github.com/grigorig/stcgal
X-Python3-Version: >= 3.2 X-Python3-Version: >= 3.2
Package: stcgal Package: stcgal
Architecture: all Architecture: all
Depends: ${misc:Depends}, python3, python3-serial Depends: ${misc:Depends}, python3, python3-serial, python3-tqdm
Recommends: python3-usb (>= 1.0.0~b2) Recommends: python3-usb (>= 1.0.0~b2)
Description: STC MCU ISP flash tool Description: STC MCU ISP flash tool
stcgal is a command line flash programming tool for STC MCU Ltd. stcgal is a command line flash programming tool for STC MCU Ltd.

View File

@ -27,7 +27,7 @@ from setuptools import setup, find_packages
setup( setup(
name = "stcgal", name = "stcgal",
version = stcgal.__version__, version = stcgal.__version__,
packages = find_packages(exclude=["doc", "test"]), packages = find_packages(exclude=["doc", "tests"]),
install_requires = ["pyserial"], install_requires = ["pyserial"],
extras_require = { extras_require = {
"usb": ["pyusb>=1.0.0"] "usb": ["pyusb>=1.0.0"]
@ -55,6 +55,6 @@ setup(
"Topic :: Software Development :: Embedded Systems", "Topic :: Software Development :: Embedded Systems",
"Topic :: Software Development", "Topic :: Software Development",
], ],
test_suite = "test", test_suite = "tests",
tests_require = ["PyYAML"], tests_require = ["PyYAML"],
) )

View File

@ -32,7 +32,10 @@ class StcGal:
def __init__(self, opts): def __init__(self, opts):
self.opts = opts self.opts = opts
self.initialize_protocol(opts)
def initialize_protocol(self, opts):
"""Initialize protocol backend"""
if opts.protocol == "stc89": if opts.protocol == "stc89":
self.protocol = Stc89Protocol(opts.port, opts.handshake, opts.baud) self.protocol = Stc89Protocol(opts.port, opts.handshake, opts.baud)
elif opts.protocol == "stc12a": elif opts.protocol == "stc12a":
@ -50,8 +53,7 @@ class StcGal:
elif opts.protocol == "usb15": elif opts.protocol == "usb15":
self.protocol = StcUsb15Protocol() self.protocol = StcUsb15Protocol()
else: else:
self.protocol = StcBaseProtocol(opts.port, opts.handshake, opts.baud) self.protocol = StcAutoProtocol(opts.port, opts.handshake, opts.baud)
self.protocol.debug = opts.debug self.protocol.debug = opts.debug
def emit_options(self, options): def emit_options(self, options):
@ -133,14 +135,14 @@ class StcGal:
try: try:
self.protocol.connect(autoreset=self.opts.autoreset, resetcmd=self.opts.resetcmd) self.protocol.connect(autoreset=self.opts.autoreset, resetcmd=self.opts.resetcmd)
if self.opts.protocol == "auto": if isinstance(self.protocol, StcAutoProtocol):
if not self.protocol.protocol_name: if not self.protocol.protocol_name:
raise StcProtocolException("cannot detect protocol") raise StcProtocolException("cannot detect protocol")
base_protocol = self.protocol base_protocol = self.protocol
self.opts.protocol = self.protocol.protocol_name self.opts.protocol = self.protocol.protocol_name
print("Protocol detected: %s" % self.opts.protocol) print("Protocol detected: %s" % self.opts.protocol)
# recreate self.protocol with proper protocol class # recreate self.protocol with proper protocol class
self.__init__(self.opts) self.initialize_protocol(self.opts)
else: else:
base_protocol = None base_protocol = None

View File

@ -5,201 +5,214 @@
import struct import struct
import codecs import codecs
class IHex(object):
@classmethod
def read(cls, lines):
ihex = cls()
segbase = 0 class IHex:
for line in lines: """Intel HEX parser and writer"""
line = line.strip()
if not line: continue
t, a, d = ihex.parse_line(line) @classmethod
if t == 0x00: def read(cls, lines):
ihex.insert_data(segbase + a, d) """Read Intel HEX data from string or lines"""
ihex = cls()
elif t == 0x01: segbase = 0
break # Should we check for garbage after this? for line in lines:
line = line.strip()
if not line:
continue
elif t == 0x02: t, a, d = ihex.parse_line(line)
ihex.set_mode(16) if t == 0x00:
segbase = struct.unpack(">H", d[0:2])[0] << 4 ihex.insert_data(segbase + a, d)
elif t == 0x03: elif t == 0x01:
ihex.set_mode(16) break # Should we check for garbage after this?
cs, ip = struct.unpack(">2H", d[0:2]) elif t == 0x02:
ihex.set_start((cs, ip)) ihex.set_mode(16)
segbase = struct.unpack(">H", d[0:2])[0] << 4
elif t == 0x04: elif t == 0x03:
ihex.set_mode(32) ihex.set_mode(16)
segbase = struct.unpack(">H", d[0:2])[0] << 16
elif t == 0x05: cs, ip = struct.unpack(">2H", d[0:2])
ihex.set_mode(32) ihex.set_start((cs, ip))
ihex.set_start(struct.unpack(">I", d[0:4])[0])
else: elif t == 0x04:
raise ValueError("Invalid type byte") ihex.set_mode(32)
segbase = struct.unpack(">H", d[0:2])[0] << 16
return ihex elif t == 0x05:
ihex.set_mode(32)
ihex.set_start(struct.unpack(">I", d[0:4])[0])
@classmethod else:
def read_file(cls, fname): raise ValueError("Invalid type byte")
f = open(fname, "rb")
ihex = cls.read(f)
f.close()
return ihex
def __init__(self): return ihex
self.areas = {}
self.start = None
self.mode = 8
self.row_bytes = 16
def set_row_bytes(self, row_bytes): @classmethod
"""Set output hex file row width (bytes represented per row).""" def read_file(cls, fname):
if row_bytes < 1 or row_bytes > 0xff: """Read Intel HEX data from file"""
raise ValueError("Value out of range: (%r)" % row_bytes) f = open(fname, "rb")
self.row_bytes = row_bytes ihex = cls.read(f)
f.close()
def extract_data(self, start=None, end=None): return ihex
if start is None:
start = 0
if end is None:
result = bytearray()
for addr, data in self.areas.items():
if addr >= start:
if len(result) < (addr - start):
result[len(result):addr-start] = bytes(addr-start-len(result))
result[addr-start:addr-start+len(data)] = data
return bytes(result)
else:
result = bytearray()
for addr, data in self.areas.items():
if addr >= start and addr < end:
data = data[:end-addr]
if len(result) < (addr - start):
result[len(result):addr-start] = bytes(addr-start-len(result))
result[addr-start:addr-start+len(data)] = data
return bytes(result)
def set_start(self, start=None):
self.start = start
def set_mode(self, mode): def __init__(self):
self.mode = mode self.areas = {}
self.start = None
self.mode = 8
self.row_bytes = 16
def get_area(self, addr): def set_row_bytes(self, row_bytes):
for start, data in self.areas.items(): """Set output hex file row width (bytes represented per row)."""
end = start + len(data) if row_bytes < 1 or row_bytes > 0xff:
if addr >= start and addr <= end: raise ValueError("Value out of range: (%r)" % row_bytes)
return start self.row_bytes = row_bytes
return None def extract_data(self, start=None, end=None):
"""Extract binary data"""
if start is None:
start = 0
def insert_data(self, istart, idata): if end is None:
iend = istart + len(idata) result = bytearray()
area = self.get_area(istart) for addr, data in self.areas.items():
if area is None: if addr >= start:
self.areas[istart] = idata if len(result) < (addr - start):
result[len(result):addr - start] = bytes(
addr - start - len(result))
result[addr - start:addr - start + len(data)] = data
else: return bytes(result)
data = self.areas[area]
# istart - iend + len(idata) + len(data)
self.areas[area] = data[:istart-area] + idata + data[iend-area:]
def calc_checksum(self, bytes): else:
total = sum(bytes) result = bytearray()
return (-total) & 0xFF
def parse_line(self, rawline): for addr, data in self.areas.items():
if rawline[0:1] != b":": if addr >= start and addr < end:
raise ValueError("Invalid line start character (%r)" % rawline[0]) data = data[:end - addr]
if len(result) < (addr - start):
result[len(result):addr - start] = bytes(
addr - start - len(result))
result[addr - start:addr - start + len(data)] = data
try: return bytes(result)
#line = rawline[1:].decode("hex")
line = codecs.decode(rawline[1:], "hex_codec")
except:
raise ValueError("Invalid hex data")
length, addr, type = struct.unpack(">BHB", line[:4]) def set_start(self, start=None):
self.start = start
dataend = length + 4 def set_mode(self, mode):
data = line[4:dataend] self.mode = mode
#~ print line[dataend:dataend + 2], repr(line) def get_area(self, addr):
cs1 = line[dataend] for start, data in self.areas.items():
cs2 = self.calc_checksum(line[:dataend]) end = start + len(data)
if addr >= start and addr <= end:
return start
if cs1 != cs2: return None
raise ValueError("Checksums do not match")
return (type, addr, data) def insert_data(self, istart, idata):
iend = istart + len(idata)
def make_line(self, type, addr, data): area = self.get_area(istart)
line = struct.pack(">BHB", len(data), addr, type) if area is None:
line += data self.areas[istart] = idata
line += chr(self.calc_checksum(line))
#~ return ":" + line.encode("hex")
return ":" + line.encode("hex").upper() + "\r\n"
def write(self): else:
output = "" data = self.areas[area]
# istart - iend + len(idata) + len(data)
for start, data in sorted(self.areas.items()): self.areas[area] = data[
i = 0 :istart - area] + idata + data[iend - area:]
segbase = 0
while i < len(data): def calc_checksum(self, data):
chunk = data[i:i + self.row_bytes] total = sum(data)
return (-total) & 0xFF
addr = start def parse_line(self, rawline):
newsegbase = segbase if rawline[0:1] != b":":
raise ValueError("Invalid line start character (%r)" % rawline[0])
if self.mode == 8: try:
addr = addr & 0xFFFF line = codecs.decode(rawline[1:], "hex_codec")
except:
raise ValueError("Invalid hex data")
elif self.mode == 16: length, addr, line_type = struct.unpack(">BHB", line[:4])
t = addr & 0xFFFF
newsegbase = (addr - t) >> 4
addr = t
if newsegbase != segbase: dataend = length + 4
output += self.make_line(0x02, 0, struct.pack(">H", newsegbase)) data = line[4:dataend]
segbase = newsegbase
elif self.mode == 32: cs1 = line[dataend]
newsegbase = addr >> 16 cs2 = self.calc_checksum(line[:dataend])
addr = addr & 0xFFFF
if newsegbase != segbase: if cs1 != cs2:
output += self.make_line(0x04, 0, struct.pack(">H", newsegbase)) raise ValueError("Checksums do not match")
segbase = newsegbase
output += self.make_line(0x00, addr, chunk) return (line_type, addr, data)
i += self.row_bytes def make_line(self, line_type, addr, data):
start += self.row_bytes line = struct.pack(">BHB", len(data), addr, line_type)
line += data
line += chr(self.calc_checksum(line))
return ":" + line.encode("hex").upper() + "\r\n"
if self.start is not None: def write(self):
if self.mode == 16: """Write Intel HEX data to string"""
output += self.make_line(0x03, 0, struct.pack(">2H", self.start[0], self.start[1])) output = ""
elif self.mode == 32:
output += self.make_line(0x05, 0, struct.pack(">I", self.start))
output += self.make_line(0x01, 0, "") for start, data in sorted(self.areas.items()):
return output i = 0
segbase = 0
def write_file(self, fname): while i < len(data):
f = open(fname, "w") chunk = data[i:i + self.row_bytes]
f.write(self.write())
f.close() addr = start
newsegbase = segbase
if self.mode == 8:
addr = addr & 0xFFFF
elif self.mode == 16:
t = addr & 0xFFFF
newsegbase = (addr - t) >> 4
addr = t
if newsegbase != segbase:
output += self.make_line(
0x02, 0, struct.pack(">H", newsegbase))
segbase = newsegbase
elif self.mode == 32:
newsegbase = addr >> 16
addr = addr & 0xFFFF
if newsegbase != segbase:
output += self.make_line(
0x04, 0, struct.pack(">H", newsegbase))
segbase = newsegbase
output += self.make_line(0x00, addr, chunk)
i += self.row_bytes
start += self.row_bytes
if self.start is not None:
if self.mode == 16:
output += self.make_line(
0x03, 0, struct.pack(">2H", self.start[0], self.start[1]))
elif self.mode == 32:
output += self.make_line(
0x05, 0, struct.pack(">I", self.start))
output += self.make_line(0x01, 0, "")
return output
def write_file(self, fname):
"""Write Intel HEX data to file"""
f = open(fname, "w")
f.write(self.write())
f.close()

View File

@ -21,15 +21,24 @@
# #
import struct import struct
from abc import ABC
from stcgal.utils import Utils from stcgal.utils import Utils
class BaseOption: class BaseOption(ABC):
"""Base class for options"""
def __init__(self):
self.options = ()
self.msr = None
def print(self): def print(self):
"""Print current configuration to standard output"""
print("Target options:") print("Target options:")
for name, get_func, _ in self.options: for name, get_func, _ in self.options:
print(" %s=%s" % (name, get_func())) print(" %s=%s" % (name, get_func()))
def set_option(self, name, value): def set_option(self, name, value):
"""Set value of a specific option"""
for opt, _, set_func in self.options: for opt, _, set_func in self.options:
if opt == name: if opt == name:
print("Option %s=%s" % (name, value)) print("Option %s=%s" % (name, value))
@ -38,12 +47,14 @@ class BaseOption:
raise ValueError("unknown") raise ValueError("unknown")
def get_option(self, name): def get_option(self, name):
"""Get option value for a specific option"""
for opt, get_func, _ in self.options: for opt, get_func, _ in self.options:
if opt == name: if opt == name:
return get_func(name) return get_func(name)
raise ValueError("unknown") raise ValueError("unknown")
def get_msr(self): def get_msr(self):
"""Get array of model-specific configuration registers"""
return bytes(self.msr) return bytes(self.msr)
@ -51,6 +62,7 @@ class Stc89Option(BaseOption):
"""Manipulation STC89 series option byte""" """Manipulation STC89 series option byte"""
def __init__(self, msr): def __init__(self, msr):
super().__init__()
self.msr = msr self.msr = msr
self.options = ( self.options = (
("cpu_6t_enabled", self.get_t6, self.set_t6), ("cpu_6t_enabled", self.get_t6, self.set_t6),
@ -129,6 +141,7 @@ class Stc12AOption(BaseOption):
"""Manipulate STC12A series option bytes""" """Manipulate STC12A series option bytes"""
def __init__(self, msr): def __init__(self, msr):
super().__init__()
assert len(msr) == 4 assert len(msr) == 4
self.msr = bytearray(msr) self.msr = bytearray(msr)
@ -213,6 +226,7 @@ class Stc12Option(BaseOption):
"""Manipulate STC10/11/12 series option bytes""" """Manipulate STC10/11/12 series option bytes"""
def __init__(self, msr): def __init__(self, msr):
super().__init__()
assert len(msr) == 4 assert len(msr) == 4
self.msr = bytearray(msr) self.msr = bytearray(msr)
@ -337,6 +351,7 @@ class Stc12Option(BaseOption):
class Stc15AOption(BaseOption): class Stc15AOption(BaseOption):
def __init__(self, msr): def __init__(self, msr):
super().__init__()
assert len(msr) == 13 assert len(msr) == 13
self.msr = bytearray(msr) self.msr = bytearray(msr)
@ -435,6 +450,7 @@ class Stc15AOption(BaseOption):
class Stc15Option(BaseOption): class Stc15Option(BaseOption):
def __init__(self, msr): def __init__(self, msr):
super().__init__()
assert len(msr) >= 4 assert len(msr) >= 4
self.msr = bytearray(msr) self.msr = bytearray(msr)

View File

@ -21,13 +21,20 @@
# #
import serial import serial
import sys, os, time, struct, re, errno import sys
import os
import time
import struct
import re
import errno
import argparse import argparse
import collections import collections
from stcgal.models import MCUModelDatabase from stcgal.models import MCUModelDatabase
from stcgal.utils import Utils from stcgal.utils import Utils
from stcgal.options import * from stcgal.options import Stc89Option, Stc12Option, Stc12AOption, Stc15Option, Stc15AOption
from abc import ABC, abstractmethod
import functools import functools
import tqdm
try: try:
import usb.core, usb.util import usb.core, usb.util
@ -46,7 +53,7 @@ class StcProtocolException(Exception):
pass pass
class StcBaseProtocol: class StcBaseProtocol(ABC):
"""Basic functionality for STC BSL protocols""" """Basic functionality for STC BSL protocols"""
"""magic word that starts a packet""" """magic word that starts a packet"""
@ -77,6 +84,22 @@ class StcBaseProtocol:
self.debug = False self.debug = False
self.status_packet = None self.status_packet = None
self.protocol_name = None self.protocol_name = None
self.bar = None
self.progress_cb = self.progress_bar_cb
def progress_text_cb(self, current, written, maximum):
print(current, written, maximum)
def progress_bar_cb(self, current, written, maximum):
if not self.bar:
self.bar = tqdm.tqdm(
total = maximum,
unit = " Bytes",
desc = "Writing flash"
)
self.bar.update(written)
if current == maximum:
self.bar.close()
def dump_packet(self, data, receive=True): def dump_packet(self, data, receive=True):
if self.debug: if self.debug:
@ -103,6 +126,10 @@ class StcBaseProtocol:
return packet[5:-1] return packet[5:-1]
@abstractmethod
def write_packet(self, packet_data):
pass
def read_packet(self): def read_packet(self):
"""Read and check packet from MCU. """Read and check packet from MCU.
@ -192,20 +219,6 @@ class StcBaseProtocol:
mcu_name += "E" if self.status_packet[17] < 0x70 else "W" mcu_name += "E" if self.status_packet[17] < 0x70 else "W"
self.model = self.model._replace(name = mcu_name) self.model = self.model._replace(name = mcu_name)
protocol_database = [("stc89", r"STC(89|90)(C|LE)\d"),
("stc12a", r"STC12(C|LE)\d052"),
("stc12b", r"STC12(C|LE)(52|56)"),
("stc12", r"(STC|IAP)(10|11|12)\D"),
("stc15a", r"(STC|IAP)15[FL][01]0\d(E|EA|)$"),
("stc15", r"(STC|IAP|IRC)15\D")]
for protocol_name, pattern in protocol_database:
if re.match(pattern, self.model.name):
self.protocol_name = protocol_name
break
else:
self.protocol_name = None
def get_status_packet(self): def get_status_packet(self):
"""Read and decode status packet""" """Read and decode status packet"""
@ -287,6 +300,8 @@ class StcBaseProtocol:
try: try:
self.pulse() self.pulse()
self.status_packet = self.get_status_packet() self.status_packet = self.get_status_packet()
if len(self.status_packet) < 23:
raise StcProtocolException("status packet too short")
except (StcFramingException, serial.SerialTimeoutException): pass except (StcFramingException, serial.SerialTimeoutException): pass
print("done") print("done")
@ -296,7 +311,21 @@ class StcBaseProtocol:
self.initialize_model() self.initialize_model()
def initialize(self, base_protocol = None): @abstractmethod
def initialize_status(self, status_packet):
"""Initialize internal state from status packet"""
pass
@abstractmethod
def initialize_options(self, status_packet):
"""Initialize options from status packet"""
pass
def initialize(self, base_protocol=None):
"""
Initialize from another instance. This is an alternative for calling
connect() and is used by protocol autodetection.
"""
if base_protocol: if base_protocol:
self.ser = base_protocol.ser self.ser = base_protocol.ser
self.ser.parity = self.PARITY self.ser.parity = self.PARITY
@ -324,6 +353,39 @@ class StcBaseProtocol:
print("Disconnected!") print("Disconnected!")
class StcAutoProtocol(StcBaseProtocol):
"""
Protocol handler for autodetection of protocols. Does not implement full
functionality for any device class.
"""
def initialize_model(self):
super().initialize_model()
protocol_database = [("stc89", r"STC(89|90)(C|LE)\d"),
("stc12a", r"STC12(C|LE)\d052"),
("stc12b", r"STC12(C|LE)(52|56)"),
("stc12", r"(STC|IAP)(10|11|12)\D"),
("stc15a", r"(STC|IAP)15[FL][01]0\d(E|EA|)$"),
("stc15", r"(STC|IAP|IRC)15\D")]
for protocol_name, pattern in protocol_database:
if re.match(pattern, self.model.name):
self.protocol_name = protocol_name
break
else:
self.protocol_name = None
def initialize_options(self, status_packet):
raise NotImplementedError
def initialize_status(self, status_packet):
raise NotImplementedError
def write_packet(self, packet_data):
raise NotImplementedError
class Stc89Protocol(StcBaseProtocol): class Stc89Protocol(StcBaseProtocol):
"""Protocol handler for STC 89/90 series""" """Protocol handler for STC 89/90 series"""
@ -384,6 +446,9 @@ class Stc89Protocol(StcBaseProtocol):
def initialize_options(self, status_packet): def initialize_options(self, status_packet):
"""Initialize options""" """Initialize options"""
if len(status_packet) < 20:
raise StcProtocolException("invalid options in status packet")
self.options = Stc89Option(status_packet[19]) self.options = Stc89Option(status_packet[19])
self.options.print() self.options.print()
@ -505,8 +570,6 @@ class Stc89Protocol(StcBaseProtocol):
as the block size (depends on MCU's RAM size). as the block size (depends on MCU's RAM size).
""" """
print("Writing %d bytes: " % len(data), end="")
sys.stdout.flush()
for i in range(0, len(data), self.PROGRAM_BLOCKSIZE): for i in range(0, len(data), self.PROGRAM_BLOCKSIZE):
packet = bytes(3) packet = bytes(3)
packet += struct.pack(">H", i) packet += struct.pack(">H", i)
@ -516,13 +579,12 @@ class Stc89Protocol(StcBaseProtocol):
csum = sum(packet[7:]) & 0xff csum = sum(packet[7:]) & 0xff
self.write_packet(packet) self.write_packet(packet)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x80: if len(response) < 1 or response[0] != 0x80:
raise StcProtocolException("incorrect magic in write packet") raise StcProtocolException("incorrect magic in write packet")
elif response[1] != csum: elif len(response) < 2 or response[1] != csum:
raise StcProtocolException("verification checksum mismatch") raise StcProtocolException("verification checksum mismatch")
print(".", end="") self.progress_cb(i, self.PROGRAM_BLOCKSIZE, len(data))
sys.stdout.flush() self.progress_cb(len(data), self.PROGRAM_BLOCKSIZE, len(data))
print(" done")
def program_options(self): def program_options(self):
"""Program option byte into flash""" """Program option byte into flash"""
@ -620,6 +682,9 @@ class Stc12AProtocol(Stc12AOptionsMixIn, Stc89Protocol):
def initialize_options(self, status_packet): def initialize_options(self, status_packet):
"""Initialize options""" """Initialize options"""
if len(status_packet) < 31:
raise StcProtocolException("invalid options in status packet")
# create option state # create option state
self.options = Stc12AOption(status_packet[23:26] + status_packet[29:30]) self.options = Stc12AOption(status_packet[23:26] + status_packet[29:30])
self.options.print() self.options.print()
@ -809,6 +874,9 @@ class Stc12BaseProtocol(StcBaseProtocol):
def initialize_options(self, status_packet): def initialize_options(self, status_packet):
"""Initialize options""" """Initialize options"""
if len(status_packet) < 29:
raise StcProtocolException("invalid options in status packet")
# create option state # create option state
self.options = Stc12Option(status_packet[23:26] + status_packet[27:28]) self.options = Stc12Option(status_packet[23:26] + status_packet[27:28])
self.options.print() self.options.print()
@ -886,8 +954,6 @@ class Stc12BaseProtocol(StcBaseProtocol):
as the block size (depends on MCU's RAM size). as the block size (depends on MCU's RAM size).
""" """
print("Writing %d bytes: " % len(data), end="")
sys.stdout.flush()
for i in range(0, len(data), self.PROGRAM_BLOCKSIZE): for i in range(0, len(data), self.PROGRAM_BLOCKSIZE):
packet = bytes(3) packet = bytes(3)
packet += struct.pack(">H", i) packet += struct.pack(">H", i)
@ -898,9 +964,8 @@ class Stc12BaseProtocol(StcBaseProtocol):
response = self.read_packet() response = self.read_packet()
if response[0] != 0x00: if response[0] != 0x00:
raise StcProtocolException("incorrect magic in write packet") raise StcProtocolException("incorrect magic in write packet")
print(".", end="") self.progress_cb(i, self.PROGRAM_BLOCKSIZE, len(data))
sys.stdout.flush() self.progress_cb(len(data), self.PROGRAM_BLOCKSIZE, len(data))
print(" done")
print("Finishing write: ", end="") print("Finishing write: ", end="")
sys.stdout.flush() sys.stdout.flush()
@ -943,6 +1008,9 @@ class Stc15AProtocol(Stc12Protocol):
def initialize_options(self, status_packet): def initialize_options(self, status_packet):
"""Initialize options""" """Initialize options"""
if len(status_packet) < 37:
raise StcProtocolException("invalid options in status packet")
# create option state # create option state
self.options = Stc15AOption(status_packet[23:36]) self.options = Stc15AOption(status_packet[23:36])
self.options.print() self.options.print()
@ -1055,15 +1123,19 @@ class Stc15AProtocol(Stc12Protocol):
self.write_packet(packet) self.write_packet(packet)
self.pulse(timeout=1.0) self.pulse(timeout=1.0)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x65: if len(response) < 36 or response[0] != 0x65:
raise StcProtocolException("incorrect magic in handshake packet") raise StcProtocolException("incorrect magic in handshake packet")
# determine programming speed trim value # determine programming speed trim value
target_trim_a, target_count_a = struct.unpack(">HH", response[28:32]) target_trim_a, target_count_a = struct.unpack(">HH", response[28:32])
target_trim_b, target_count_b = struct.unpack(">HH", response[32:36]) target_trim_b, target_count_b = struct.unpack(">HH", response[32:36])
if target_count_a == target_count_b:
raise StcProtocolException("frequency trimming failed")
m = (target_trim_b - target_trim_a) / (target_count_b - target_count_a) m = (target_trim_b - target_trim_a) / (target_count_b - target_count_a)
n = target_trim_a - m * target_count_a n = target_trim_a - m * target_count_a
program_trim = round(m * program_count + n) program_trim = round(m * program_count + n)
if program_trim > 65535 or program_trim < 0:
raise StcProtocolException("frequency trimming failed")
# determine trim trials for second round # determine trim trials for second round
trim_a, count_a = struct.unpack(">HH", response[12:16]) trim_a, count_a = struct.unpack(">HH", response[12:16])
@ -1082,10 +1154,14 @@ class Stc15AProtocol(Stc12Protocol):
target_count_a = count_a target_count_a = count_a
target_count_b = count_b target_count_b = count_b
# linear interpolate to find range to try next # linear interpolate to find range to try next
if target_count_a == target_count_b:
raise StcProtocolException("frequency trimming failed")
m = (target_trim_b - target_trim_a) / (target_count_b - target_count_a) m = (target_trim_b - target_trim_a) / (target_count_b - target_count_a)
n = target_trim_a - m * target_count_a n = target_trim_a - m * target_count_a
target_trim = round(m * user_count + n) target_trim = round(m * user_count + n)
target_trim_start = min(max(target_trim - 5, target_trim_a), target_trim_b) target_trim_start = min(max(target_trim - 5, target_trim_a), target_trim_b)
if target_trim_start + 11 > 65535 or target_trim_start < 0:
raise StcProtocolException("frequency trimming failed")
# trim challenge-response, second round # trim challenge-response, second round
packet = bytes([0x65]) packet = bytes([0x65])
@ -1097,7 +1173,7 @@ class Stc15AProtocol(Stc12Protocol):
self.write_packet(packet) self.write_packet(packet)
self.pulse(timeout=1.0) self.pulse(timeout=1.0)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x65: if len(response) < 56 or response[0] != 0x65:
raise StcProtocolException("incorrect magic in handshake packet") raise StcProtocolException("incorrect magic in handshake packet")
# determine best trim value # determine best trim value
@ -1156,7 +1232,11 @@ class Stc15Protocol(Stc15AProtocol):
def initialize_options(self, status_packet): def initialize_options(self, status_packet):
"""Initialize options""" """Initialize options"""
if len(status_packet) < 14:
raise StcProtocolException("invalid options in status packet")
# create option state # create option state
# XXX: check how option bytes are concatenated here
self.options = Stc15Option(status_packet[5:8] + status_packet[12:13] + status_packet[37:38]) self.options = Stc15Option(status_packet[5:8] + status_packet[12:13] + status_packet[37:38])
self.options.print() self.options.print()
@ -1201,6 +1281,8 @@ class Stc15Protocol(Stc15AProtocol):
calib_data = response[2:] calib_data = response[2:]
challenge_data = packet[2:] challenge_data = packet[2:]
calib_len = response[1] calib_len = response[1]
if len(calib_data) < 2 * calib_len:
raise StcProtocolException("range calibration data missing")
for i in range(calib_len - 1): for i in range(calib_len - 1):
count_a, count_b = struct.unpack(">HH", calib_data[2*i:2*i+4]) count_a, count_b = struct.unpack(">HH", calib_data[2*i:2*i+4])
@ -1210,6 +1292,8 @@ class Stc15Protocol(Stc15AProtocol):
m = (trim_b - trim_a) / (count_b - count_a) m = (trim_b - trim_a) / (count_b - count_a)
n = trim_a - m * count_a n = trim_a - m * count_a
target_trim = round(m * target_count + n) target_trim = round(m * target_count + n)
if target_trim > 65536 or target_trim < 0:
raise StcProtocolException("frequency trimming failed")
return (target_trim, trim_range) return (target_trim, trim_range)
return None return None
@ -1221,6 +1305,8 @@ class Stc15Protocol(Stc15AProtocol):
calib_data = response[2:] calib_data = response[2:]
challenge_data = packet[2:] challenge_data = packet[2:]
calib_len = response[1] calib_len = response[1]
if len(calib_data) < 2 * calib_len:
raise StcProtocolException("trim calibration data missing")
best = None best = None
best_count = sys.maxsize best_count = sys.maxsize
@ -1231,6 +1317,9 @@ class Stc15Protocol(Stc15AProtocol):
best_count = abs(count - target_count) best_count = abs(count - target_count)
best = (trim_adj, trim_range), count best = (trim_adj, trim_range), count
if not best:
raise StcProtocolException("frequency trimming failed")
return best return best
def calibrate(self): def calibrate(self):
@ -1260,7 +1349,7 @@ class Stc15Protocol(Stc15AProtocol):
self.write_packet(packet) self.write_packet(packet)
self.pulse(b"\xfe", timeout=1.0) self.pulse(b"\xfe", timeout=1.0)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x00: if len(response) < 2 or response[0] != 0x00:
raise StcProtocolException("incorrect magic in handshake packet") raise StcProtocolException("incorrect magic in handshake packet")
# select ranges and trim values # select ranges and trim values
@ -1279,7 +1368,7 @@ class Stc15Protocol(Stc15AProtocol):
self.write_packet(packet) self.write_packet(packet)
self.pulse(b"\xfe", timeout=1.0) self.pulse(b"\xfe", timeout=1.0)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x00: if len(response) < 2 or response[0] != 0x00:
raise StcProtocolException("incorrect magic in handshake packet") raise StcProtocolException("incorrect magic in handshake packet")
# select final values # select final values
@ -1305,7 +1394,7 @@ class Stc15Protocol(Stc15AProtocol):
packet += bytes([iap_wait]) packet += bytes([iap_wait])
self.write_packet(packet) self.write_packet(packet)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x01: if len(response) < 1 or response[0] != 0x01:
raise StcProtocolException("incorrect magic in handshake packet") raise StcProtocolException("incorrect magic in handshake packet")
time.sleep(0.2) time.sleep(0.2)
self.ser.baudrate = self.baud_transfer self.ser.baudrate = self.baud_transfer
@ -1322,7 +1411,7 @@ class Stc15Protocol(Stc15AProtocol):
packet += bytes([0x00, 0x00, iap_wait]) packet += bytes([0x00, 0x00, iap_wait])
self.write_packet(packet) self.write_packet(packet)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x01: if len(response) < 1 or response[0] != 0x01:
raise StcProtocolException("incorrect magic in handshake packet") raise StcProtocolException("incorrect magic in handshake packet")
time.sleep(0.2) time.sleep(0.2)
self.ser.baudrate = self.baud_transfer self.ser.baudrate = self.baud_transfer
@ -1348,9 +1437,9 @@ class Stc15Protocol(Stc15AProtocol):
packet += bytes([0x00, 0x00, 0x5a, 0xa5]) packet += bytes([0x00, 0x00, 0x5a, 0xa5])
self.write_packet(packet) self.write_packet(packet)
response = self.read_packet() response = self.read_packet()
if response[0] == 0x0f: if len(response) == 1 and response[0] == 0x0f:
raise StcProtocolException("MCU is locked") raise StcProtocolException("MCU is locked")
if response[0] != 0x05: if len(response) < 1 or response[0] != 0x05:
raise StcProtocolException("incorrect magic in handshake packet") raise StcProtocolException("incorrect magic in handshake packet")
print("done") print("done")
@ -1371,18 +1460,20 @@ class Stc15Protocol(Stc15AProtocol):
packet += bytes([0x00, 0x5a, 0xa5]) packet += bytes([0x00, 0x5a, 0xa5])
self.write_packet(packet) self.write_packet(packet)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x03: if len(response) < 1 or response[0] != 0x03:
raise StcProtocolException("incorrect magic in handshake packet") raise StcProtocolException("incorrect magic in handshake packet")
print("done") print("done")
if len(response) >= 8: if len(response) >= 8:
self.uid = response[1:8] self.uid = response[1:8]
# we should have a UID at this point
if not self.uid:
raise StcProtocolException("UID is missing")
def program_flash(self, data): def program_flash(self, data):
"""Program the MCU's flash memory.""" """Program the MCU's flash memory."""
print("Writing %d bytes: " % len(data), end="")
sys.stdout.flush()
for i in range(0, len(data), self.PROGRAM_BLOCKSIZE): for i in range(0, len(data), self.PROGRAM_BLOCKSIZE):
packet = bytes([0x22]) if i == 0 else bytes([0x02]) packet = bytes([0x22]) if i == 0 else bytes([0x02])
packet += struct.pack(">H", i) packet += struct.pack(">H", i)
@ -1392,11 +1483,10 @@ class Stc15Protocol(Stc15AProtocol):
while len(packet) < self.PROGRAM_BLOCKSIZE + 3: packet += b"\x00" while len(packet) < self.PROGRAM_BLOCKSIZE + 3: packet += b"\x00"
self.write_packet(packet) self.write_packet(packet)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x02 or response[1] != 0x54: if len(response) < 2 or response[0] != 0x02 or response[1] != 0x54:
raise StcProtocolException("incorrect magic in write packet") raise StcProtocolException("incorrect magic in write packet")
print(".", end="") self.progress_cb(i, self.PROGRAM_BLOCKSIZE, len(data))
sys.stdout.flush() self.progress_cb(len(data), self.PROGRAM_BLOCKSIZE, len(data))
print(" done")
# BSL 7.2+ needs a write finish packet according to dumps # BSL 7.2+ needs a write finish packet according to dumps
if self.bsl_version >= 0x72: if self.bsl_version >= 0x72:
@ -1405,7 +1495,7 @@ class Stc15Protocol(Stc15AProtocol):
packet = bytes([0x07, 0x00, 0x00, 0x5a, 0xa5]) packet = bytes([0x07, 0x00, 0x00, 0x5a, 0xa5])
self.write_packet(packet) self.write_packet(packet)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x07 or response[1] != 0x54: if len(response) < 2 or response[0] != 0x07 or response[1] != 0x54:
raise StcProtocolException("incorrect magic in finish packet") raise StcProtocolException("incorrect magic in finish packet")
print("done") print("done")
@ -1444,7 +1534,7 @@ class Stc15Protocol(Stc15AProtocol):
packet += self.build_options() packet += self.build_options()
self.write_packet(packet) self.write_packet(packet)
response = self.read_packet() response = self.read_packet()
if response[0] != 0x04 or response[1] != 0x54: if len(response) < 2 or response[0] != 0x04 or response[1] != 0x54:
raise StcProtocolException("incorrect magic in option packet") raise StcProtocolException("incorrect magic in option packet")
print("done") print("done")
@ -1579,8 +1669,6 @@ class StcUsb15Protocol(Stc15Protocol):
def program_flash(self, data): def program_flash(self, data):
"""Program the MCU's flash memory.""" """Program the MCU's flash memory."""
print("Writing %d bytes: " % len(data), end="")
sys.stdout.flush()
for i in range(0, len(data), self.PROGRAM_BLOCKSIZE): for i in range(0, len(data), self.PROGRAM_BLOCKSIZE):
packet = data[i:i+self.PROGRAM_BLOCKSIZE] packet = data[i:i+self.PROGRAM_BLOCKSIZE]
while len(packet) < self.PROGRAM_BLOCKSIZE: packet += b"\x00" while len(packet) < self.PROGRAM_BLOCKSIZE: packet += b"\x00"
@ -1590,9 +1678,8 @@ class StcUsb15Protocol(Stc15Protocol):
response = self.read_packet() response = self.read_packet()
if response[0] != 0x02 or response[1] != 0x54: if response[0] != 0x02 or response[1] != 0x54:
raise StcProtocolException("incorrect magic in write packet") raise StcProtocolException("incorrect magic in write packet")
print(".", end="") self.progress_cb(i, self.PROGRAM_BLOCKSIZE, len(data))
sys.stdout.flush() self.progress_cb(len(data), self.PROGRAM_BLOCKSIZE, len(data))
print(" done")
def program_options(self): def program_options(self):
print("Setting options: ", end="") print("Setting options: ", end="")

View File

@ -29,14 +29,12 @@ class Utils:
def to_bool(cls, val): def to_bool(cls, val):
"""make sensible boolean from string or other type value""" """make sensible boolean from string or other type value"""
if val is None: if not val:
return False return False
if isinstance(val, bool): if isinstance(val, bool):
return val return val
elif isinstance(val, int): elif isinstance(val, int):
return bool(val) return bool(val)
elif len(val) == 0:
return False
else: else:
return True if val[0].lower() == "t" or val[0] == "1" else False return True if val[0].lower() == "t" or val[0] == "1" else False

119
tests/test_fuzzing.py Normal file
View File

@ -0,0 +1,119 @@
#
# Copyright (c) 2017 Grigori Goronzy <greg@chown.ath.cx>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
"""Tests with fuzzing of input data"""
import random
import sys
import unittest
from unittest.mock import patch
import yaml
import stcgal.frontend
import stcgal.protocols
from tests.test_program import get_default_opts, convert_to_bytes
class ByteArrayFuzzer:
"""Fuzzer for byte arrays"""
def __init__(self):
self.rng = random.Random()
self.cut_propability = 0.01 # probability for cutting off an array early
self.cut_min = 0 # minimum cut amount
self.cut_max = sys.maxsize # maximum cut amount
self.bitflip_probability = 0.0001 # probability for flipping a bit
self.randomize_probability = 0.001 # probability for randomizing a char
def fuzz(self, inp):
"""Fuzz an array of bytes according to predefined settings"""
arr = bytearray(inp)
arr = self.cut_off(arr)
self.randomize(arr)
return bytes(arr)
def randomize(self, arr):
"""Randomize array contents with bitflips and random bytes"""
for i, _ in enumerate(arr):
for j in range(8):
if self.rng.random() < self.bitflip_probability:
arr[i] ^= (1 << j)
if self.rng.random() < self.randomize_probability:
arr[i] = self.rng.getrandbits(8)
def cut_off(self, arr):
"""Cut off data from end of array"""
if self.rng.random() < self.cut_propability:
cut_limit = min(len(arr), self.cut_max)
cut_len = self.rng.randrange(self.cut_min, cut_limit)
arr = arr[0:len(arr) - cut_len]
return arr
class TestProgramFuzzed(unittest.TestCase):
"""Special programming cycle tests that use a fuzzing approach"""
@patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet")
@patch("stcgal.protocols.serial.Serial", autospec=True)
@patch("stcgal.protocols.time.sleep")
@patch("sys.stdout")
@patch("sys.stderr")
def test_program_fuzz(self, err, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test programming cycles with fuzzing enabled"""
yml = [
"./tests/iap15f2k61s2.yml",
"./tests/stc12c2052ad.yml",
"./tests/stc15w4k56s4.yml",
"./tests/stc12c5a60s2.yml",
"./tests/stc89c52rc.yml",
"./tests/stc15l104w.yml",
"./tests/stc15f104e.yml",
]
fuzzer = ByteArrayFuzzer()
fuzzer.cut_propability = 0.01
fuzzer.bitflip_probability = 0.005
fuzzer.rng = random.Random(1)
for y in yml:
with self.subTest(msg="trace {}".format(y)):
self.single_fuzz(y, serial_mock, fuzzer, read_mock, err, out,
sleep_mock, write_mock)
def single_fuzz(self, yml, serial_mock, fuzzer, read_mock, err, out, sleep_mock, write_mock):
"""Test a single programming cycle with fuzzing"""
with open(yml) as test_file:
test_data = yaml.load(test_file.read())
for _ in range(1000):
with self.subTest():
opts = get_default_opts()
opts.protocol = test_data["protocol"]
opts.code_image.read.return_value = bytes(test_data["code_data"])
serial_mock.return_value.inWaiting.return_value = 1
fuzzed_responses = []
for arr in convert_to_bytes(test_data["responses"]):
fuzzed_responses.append(fuzzer.fuzz(arr))
read_mock.side_effect = fuzzed_responses
gal = stcgal.frontend.StcGal(opts)
self.assertGreaterEqual(gal.run(), 0)
err.reset_mock()
out.reset_mock()
sleep_mock.reset_mock()
serial_mock.reset_mock()
write_mock.reset_mock()
read_mock.reset_mock()

View File

@ -57,7 +57,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout") @patch("sys.stdout")
def test_program_stc89(self, out, sleep_mock, serial_mock, write_mock, read_mock): def test_program_stc89(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC89 protocol""" """Test a programming cycle with STC89 protocol"""
self._program_yml("./test/stc89c52rc.yml", serial_mock, read_mock) self._program_yml("./tests/stc89c52rc.yml", serial_mock, read_mock)
@patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet") @patch("stcgal.protocols.Stc89Protocol.write_packet")
@ -66,7 +66,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout") @patch("sys.stdout")
def test_program_stc12(self, out, sleep_mock, serial_mock, write_mock, read_mock): def test_program_stc12(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC12 protocol""" """Test a programming cycle with STC12 protocol"""
self._program_yml("./test/stc12c5a60s2.yml", serial_mock, read_mock) self._program_yml("./tests/stc12c5a60s2.yml", serial_mock, read_mock)
@patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet") @patch("stcgal.protocols.Stc89Protocol.write_packet")
@ -75,7 +75,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout") @patch("sys.stdout")
def test_program_stc12a(self, out, sleep_mock, serial_mock, write_mock, read_mock): def test_program_stc12a(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC12A protocol""" """Test a programming cycle with STC12A protocol"""
self._program_yml("./test/stc12c2052ad.yml", serial_mock, read_mock) self._program_yml("./tests/stc12c2052ad.yml", serial_mock, read_mock)
def test_program_stc12b(self): def test_program_stc12b(self):
"""Test a programming cycle with STC12B protocol""" """Test a programming cycle with STC12B protocol"""
@ -88,7 +88,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout") @patch("sys.stdout")
def test_program_stc15f2(self, out, sleep_mock, serial_mock, write_mock, read_mock): def test_program_stc15f2(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC15 protocol, F2 series""" """Test a programming cycle with STC15 protocol, F2 series"""
self._program_yml("./test/iap15f2k61s2.yml", serial_mock, read_mock) self._program_yml("./tests/iap15f2k61s2.yml", serial_mock, read_mock)
@patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet") @patch("stcgal.protocols.Stc89Protocol.write_packet")
@ -97,7 +97,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout") @patch("sys.stdout")
def test_program_stc15w4(self, out, sleep_mock, serial_mock, write_mock, read_mock): def test_program_stc15w4(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC15 protocol, W4 series""" """Test a programming cycle with STC15 protocol, W4 series"""
self._program_yml("./test/stc15w4k56s4.yml", serial_mock, read_mock) self._program_yml("./tests/stc15w4k56s4.yml", serial_mock, read_mock)
@unittest.skip("trace is broken") @unittest.skip("trace is broken")
@patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.StcBaseProtocol.read_packet")
@ -107,7 +107,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout") @patch("sys.stdout")
def test_program_stc15a(self, out, sleep_mock, serial_mock, write_mock, read_mock): def test_program_stc15a(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC15A protocol""" """Test a programming cycle with STC15A protocol"""
self._program_yml("./test/stc15f104e.yml", serial_mock, read_mock) self._program_yml("./tests/stc15f104e.yml", serial_mock, read_mock)
@patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.StcBaseProtocol.read_packet")
@patch("stcgal.protocols.Stc89Protocol.write_packet") @patch("stcgal.protocols.Stc89Protocol.write_packet")
@ -116,7 +116,7 @@ class ProgramTests(unittest.TestCase):
@patch("sys.stdout") @patch("sys.stdout")
def test_program_stc15l1(self, out, sleep_mock, serial_mock, write_mock, read_mock): def test_program_stc15l1(self, out, sleep_mock, serial_mock, write_mock, read_mock):
"""Test a programming cycle with STC15 protocol, L1 series""" """Test a programming cycle with STC15 protocol, L1 series"""
self._program_yml("./test/stc15l104w.yml", serial_mock, read_mock) self._program_yml("./tests/stc15l104w.yml", serial_mock, read_mock)
def test_program_stc15w4_usb(self): def test_program_stc15w4_usb(self):
"""Test a programming cycle with STC15W4 USB protocol""" """Test a programming cycle with STC15W4 USB protocol"""
@ -133,4 +133,3 @@ class ProgramTests(unittest.TestCase):
read_mock.side_effect = convert_to_bytes(test_data["responses"]) read_mock.side_effect = convert_to_bytes(test_data["responses"])
gal = stcgal.frontend.StcGal(opts) gal = stcgal.frontend.StcGal(opts)
self.assertEqual(gal.run(), 0) self.assertEqual(gal.run(), 0)

View File

@ -24,7 +24,6 @@
import argparse import argparse
import unittest import unittest
from unittest.mock import patch
from stcgal.utils import Utils, BaudType from stcgal.utils import Utils, BaudType
class TestUtils(unittest.TestCase): class TestUtils(unittest.TestCase):