diff --git a/setup.py b/setup.py index 73521cf..a027db3 100755 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ from setuptools import setup, find_packages setup( name = "stcgal", version = stcgal.__version__, - packages = find_packages(exclude=["doc", "test"]), + packages = find_packages(exclude=["doc", "tests"]), install_requires = ["pyserial"], extras_require = { "usb": ["pyusb>=1.0.0"] @@ -55,6 +55,6 @@ setup( "Topic :: Software Development :: Embedded Systems", "Topic :: Software Development", ], - test_suite = "test", + test_suite = "tests", tests_require = ["PyYAML"], ) diff --git a/stcgal/frontend.py b/stcgal/frontend.py index 2bade62..997ac45 100644 --- a/stcgal/frontend.py +++ b/stcgal/frontend.py @@ -32,7 +32,10 @@ class StcGal: def __init__(self, opts): self.opts = opts + self.initialize_protocol(opts) + def initialize_protocol(self, opts): + """Initialize protocol backend""" if opts.protocol == "stc89": self.protocol = Stc89Protocol(opts.port, opts.handshake, opts.baud) elif opts.protocol == "stc12a": @@ -50,8 +53,7 @@ class StcGal: elif opts.protocol == "usb15": self.protocol = StcUsb15Protocol() else: - self.protocol = StcBaseProtocol(opts.port, opts.handshake, opts.baud) - + self.protocol = StcAutoProtocol(opts.port, opts.handshake, opts.baud) self.protocol.debug = opts.debug def emit_options(self, options): @@ -133,14 +135,14 @@ class StcGal: try: self.protocol.connect(autoreset=self.opts.autoreset, resetcmd=self.opts.resetcmd) - if self.opts.protocol == "auto": + if isinstance(self.protocol, StcAutoProtocol): if not self.protocol.protocol_name: raise StcProtocolException("cannot detect protocol") base_protocol = self.protocol self.opts.protocol = self.protocol.protocol_name print("Protocol detected: %s" % self.opts.protocol) # recreate self.protocol with proper protocol class - self.__init__(self.opts) + self.initialize_protocol(self.opts) else: base_protocol = None diff --git a/stcgal/options.py b/stcgal/options.py index 8bdccbf..639c2e1 100644 --- a/stcgal/options.py +++ b/stcgal/options.py @@ -21,15 +21,24 @@ # import struct +from abc import ABC from stcgal.utils import Utils -class BaseOption: +class BaseOption(ABC): + """Base class for options""" + + def __init__(self): + self.options = () + self.msr = None + def print(self): + """Print current configuration to standard output""" print("Target options:") for name, get_func, _ in self.options: print(" %s=%s" % (name, get_func())) def set_option(self, name, value): + """Set value of a specific option""" for opt, _, set_func in self.options: if opt == name: print("Option %s=%s" % (name, value)) @@ -38,12 +47,14 @@ class BaseOption: raise ValueError("unknown") def get_option(self, name): + """Get option value for a specific option""" for opt, get_func, _ in self.options: if opt == name: return get_func(name) raise ValueError("unknown") def get_msr(self): + """Get array of model-specific configuration registers""" return bytes(self.msr) @@ -51,6 +62,7 @@ class Stc89Option(BaseOption): """Manipulation STC89 series option byte""" def __init__(self, msr): + super().__init__() self.msr = msr self.options = ( ("cpu_6t_enabled", self.get_t6, self.set_t6), @@ -129,6 +141,7 @@ class Stc12AOption(BaseOption): """Manipulate STC12A series option bytes""" def __init__(self, msr): + super().__init__() assert len(msr) == 4 self.msr = bytearray(msr) @@ -213,6 +226,7 @@ class Stc12Option(BaseOption): """Manipulate STC10/11/12 series option bytes""" def __init__(self, msr): + super().__init__() assert len(msr) == 4 self.msr = bytearray(msr) @@ -337,6 +351,7 @@ class Stc12Option(BaseOption): class Stc15AOption(BaseOption): def __init__(self, msr): + super().__init__() assert len(msr) == 13 self.msr = bytearray(msr) @@ -435,6 +450,7 @@ class Stc15AOption(BaseOption): class Stc15Option(BaseOption): def __init__(self, msr): + super().__init__() assert len(msr) >= 4 self.msr = bytearray(msr) diff --git a/stcgal/protocols.py b/stcgal/protocols.py index 31b88f9..1060a7d 100644 --- a/stcgal/protocols.py +++ b/stcgal/protocols.py @@ -21,12 +21,18 @@ # import serial -import sys, os, time, struct, re, errno +import sys +import os +import time +import struct +import re +import errno import argparse import collections from stcgal.models import MCUModelDatabase from stcgal.utils import Utils -from stcgal.options import * +from stcgal.options import Stc89Option, Stc12Option, Stc12AOption, Stc15Option, Stc15AOption +from abc import ABC, abstractmethod import functools try: @@ -46,7 +52,7 @@ class StcProtocolException(Exception): pass -class StcBaseProtocol: +class StcBaseProtocol(ABC): """Basic functionality for STC BSL protocols""" """magic word that starts a packet""" @@ -103,6 +109,10 @@ class StcBaseProtocol: return packet[5:-1] + @abstractmethod + def write_packet(self, packet_data): + pass + def read_packet(self): """Read and check packet from MCU. @@ -192,20 +202,6 @@ class StcBaseProtocol: mcu_name += "E" if self.status_packet[17] < 0x70 else "W" self.model = self.model._replace(name = mcu_name) - protocol_database = [("stc89", r"STC(89|90)(C|LE)\d"), - ("stc12a", r"STC12(C|LE)\d052"), - ("stc12b", r"STC12(C|LE)(52|56)"), - ("stc12", r"(STC|IAP)(10|11|12)\D"), - ("stc15a", r"(STC|IAP)15[FL][01]0\d(E|EA|)$"), - ("stc15", r"(STC|IAP|IRC)15\D")] - - for protocol_name, pattern in protocol_database: - if re.match(pattern, self.model.name): - self.protocol_name = protocol_name - break - else: - self.protocol_name = None - def get_status_packet(self): """Read and decode status packet""" @@ -287,6 +283,8 @@ class StcBaseProtocol: try: self.pulse() self.status_packet = self.get_status_packet() + if len(self.status_packet) < 23: + raise StcProtocolException("status packet too short") except (StcFramingException, serial.SerialTimeoutException): pass print("done") @@ -296,7 +294,21 @@ class StcBaseProtocol: self.initialize_model() - def initialize(self, base_protocol = None): + @abstractmethod + def initialize_status(self, status_packet): + """Initialize internal state from status packet""" + pass + + @abstractmethod + def initialize_options(self, status_packet): + """Initialize options from status packet""" + pass + + def initialize(self, base_protocol=None): + """ + Initialize from another instance. This is an alternative for calling + connect() and is used by protocol autodetection. + """ if base_protocol: self.ser = base_protocol.ser self.ser.parity = self.PARITY @@ -324,6 +336,39 @@ class StcBaseProtocol: print("Disconnected!") +class StcAutoProtocol(StcBaseProtocol): + """ + Protocol handler for autodetection of protocols. Does not implement full + functionality for any device class. + """ + + def initialize_model(self): + super().initialize_model() + + protocol_database = [("stc89", r"STC(89|90)(C|LE)\d"), + ("stc12a", r"STC12(C|LE)\d052"), + ("stc12b", r"STC12(C|LE)(52|56)"), + ("stc12", r"(STC|IAP)(10|11|12)\D"), + ("stc15a", r"(STC|IAP)15[FL][01]0\d(E|EA|)$"), + ("stc15", r"(STC|IAP|IRC)15\D")] + + for protocol_name, pattern in protocol_database: + if re.match(pattern, self.model.name): + self.protocol_name = protocol_name + break + else: + self.protocol_name = None + + def initialize_options(self, status_packet): + raise NotImplementedError + + def initialize_status(self, status_packet): + raise NotImplementedError + + def write_packet(self, packet_data): + raise NotImplementedError + + class Stc89Protocol(StcBaseProtocol): """Protocol handler for STC 89/90 series""" @@ -384,6 +429,9 @@ class Stc89Protocol(StcBaseProtocol): def initialize_options(self, status_packet): """Initialize options""" + if len(status_packet) < 20: + raise StcProtocolException("invalid options in status packet") + self.options = Stc89Option(status_packet[19]) self.options.print() @@ -516,9 +564,9 @@ class Stc89Protocol(StcBaseProtocol): csum = sum(packet[7:]) & 0xff self.write_packet(packet) response = self.read_packet() - if response[0] != 0x80: + if len(response) < 1 or response[0] != 0x80: raise StcProtocolException("incorrect magic in write packet") - elif response[1] != csum: + elif len(response) < 2 or response[1] != csum: raise StcProtocolException("verification checksum mismatch") print(".", end="") sys.stdout.flush() @@ -620,6 +668,9 @@ class Stc12AProtocol(Stc12AOptionsMixIn, Stc89Protocol): def initialize_options(self, status_packet): """Initialize options""" + if len(status_packet) < 31: + raise StcProtocolException("invalid options in status packet") + # create option state self.options = Stc12AOption(status_packet[23:26] + status_packet[29:30]) self.options.print() @@ -809,6 +860,9 @@ class Stc12BaseProtocol(StcBaseProtocol): def initialize_options(self, status_packet): """Initialize options""" + if len(status_packet) < 29: + raise StcProtocolException("invalid options in status packet") + # create option state self.options = Stc12Option(status_packet[23:26] + status_packet[27:28]) self.options.print() @@ -943,6 +997,9 @@ class Stc15AProtocol(Stc12Protocol): def initialize_options(self, status_packet): """Initialize options""" + if len(status_packet) < 37: + raise StcProtocolException("invalid options in status packet") + # create option state self.options = Stc15AOption(status_packet[23:36]) self.options.print() @@ -1055,15 +1112,19 @@ class Stc15AProtocol(Stc12Protocol): self.write_packet(packet) self.pulse(timeout=1.0) response = self.read_packet() - if response[0] != 0x65: + if len(response) < 36 or response[0] != 0x65: raise StcProtocolException("incorrect magic in handshake packet") # determine programming speed trim value target_trim_a, target_count_a = struct.unpack(">HH", response[28:32]) target_trim_b, target_count_b = struct.unpack(">HH", response[32:36]) + if target_count_a == target_count_b: + raise StcProtocolException("frequency trimming failed") m = (target_trim_b - target_trim_a) / (target_count_b - target_count_a) n = target_trim_a - m * target_count_a program_trim = round(m * program_count + n) + if program_trim > 65535 or program_trim < 0: + raise StcProtocolException("frequency trimming failed") # determine trim trials for second round trim_a, count_a = struct.unpack(">HH", response[12:16]) @@ -1082,10 +1143,14 @@ class Stc15AProtocol(Stc12Protocol): target_count_a = count_a target_count_b = count_b # linear interpolate to find range to try next + if target_count_a == target_count_b: + raise StcProtocolException("frequency trimming failed") m = (target_trim_b - target_trim_a) / (target_count_b - target_count_a) n = target_trim_a - m * target_count_a target_trim = round(m * user_count + n) target_trim_start = min(max(target_trim - 5, target_trim_a), target_trim_b) + if target_trim_start + 11 > 65535 or target_trim_start < 0: + raise StcProtocolException("frequency trimming failed") # trim challenge-response, second round packet = bytes([0x65]) @@ -1097,7 +1162,7 @@ class Stc15AProtocol(Stc12Protocol): self.write_packet(packet) self.pulse(timeout=1.0) response = self.read_packet() - if response[0] != 0x65: + if len(response) < 56 or response[0] != 0x65: raise StcProtocolException("incorrect magic in handshake packet") # determine best trim value @@ -1156,7 +1221,11 @@ class Stc15Protocol(Stc15AProtocol): def initialize_options(self, status_packet): """Initialize options""" + if len(status_packet) < 14: + raise StcProtocolException("invalid options in status packet") + # create option state + # XXX: check how option bytes are concatenated here self.options = Stc15Option(status_packet[5:8] + status_packet[12:13] + status_packet[37:38]) self.options.print() @@ -1201,6 +1270,8 @@ class Stc15Protocol(Stc15AProtocol): calib_data = response[2:] challenge_data = packet[2:] calib_len = response[1] + if len(calib_data) < 2 * calib_len: + raise StcProtocolException("range calibration data missing") for i in range(calib_len - 1): count_a, count_b = struct.unpack(">HH", calib_data[2*i:2*i+4]) @@ -1210,6 +1281,8 @@ class Stc15Protocol(Stc15AProtocol): m = (trim_b - trim_a) / (count_b - count_a) n = trim_a - m * count_a target_trim = round(m * target_count + n) + if target_trim > 65536 or target_trim < 0: + raise StcProtocolException("frequency trimming failed") return (target_trim, trim_range) return None @@ -1221,6 +1294,8 @@ class Stc15Protocol(Stc15AProtocol): calib_data = response[2:] challenge_data = packet[2:] calib_len = response[1] + if len(calib_data) < 2 * calib_len: + raise StcProtocolException("trim calibration data missing") best = None best_count = sys.maxsize @@ -1231,6 +1306,9 @@ class Stc15Protocol(Stc15AProtocol): best_count = abs(count - target_count) best = (trim_adj, trim_range), count + if not best: + raise StcProtocolException("frequency trimming failed") + return best def calibrate(self): @@ -1260,7 +1338,7 @@ class Stc15Protocol(Stc15AProtocol): self.write_packet(packet) self.pulse(b"\xfe", timeout=1.0) response = self.read_packet() - if response[0] != 0x00: + if len(response) < 2 or response[0] != 0x00: raise StcProtocolException("incorrect magic in handshake packet") # select ranges and trim values @@ -1279,7 +1357,7 @@ class Stc15Protocol(Stc15AProtocol): self.write_packet(packet) self.pulse(b"\xfe", timeout=1.0) response = self.read_packet() - if response[0] != 0x00: + if len(response) < 2 or response[0] != 0x00: raise StcProtocolException("incorrect magic in handshake packet") # select final values @@ -1305,7 +1383,7 @@ class Stc15Protocol(Stc15AProtocol): packet += bytes([iap_wait]) self.write_packet(packet) response = self.read_packet() - if response[0] != 0x01: + if len(response) < 1 or response[0] != 0x01: raise StcProtocolException("incorrect magic in handshake packet") time.sleep(0.2) self.ser.baudrate = self.baud_transfer @@ -1322,7 +1400,7 @@ class Stc15Protocol(Stc15AProtocol): packet += bytes([0x00, 0x00, iap_wait]) self.write_packet(packet) response = self.read_packet() - if response[0] != 0x01: + if len(response) < 1 or response[0] != 0x01: raise StcProtocolException("incorrect magic in handshake packet") time.sleep(0.2) self.ser.baudrate = self.baud_transfer @@ -1348,9 +1426,9 @@ class Stc15Protocol(Stc15AProtocol): packet += bytes([0x00, 0x00, 0x5a, 0xa5]) self.write_packet(packet) response = self.read_packet() - if response[0] == 0x0f: + if len(response) == 1 and response[0] == 0x0f: raise StcProtocolException("MCU is locked") - if response[0] != 0x05: + if len(response) < 1 or response[0] != 0x05: raise StcProtocolException("incorrect magic in handshake packet") print("done") @@ -1371,13 +1449,17 @@ class Stc15Protocol(Stc15AProtocol): packet += bytes([0x00, 0x5a, 0xa5]) self.write_packet(packet) response = self.read_packet() - if response[0] != 0x03: + if len(response) < 1 or response[0] != 0x03: raise StcProtocolException("incorrect magic in handshake packet") print("done") if len(response) >= 8: self.uid = response[1:8] + # we should have a UID at this point + if not self.uid: + raise StcProtocolException("UID is missing") + def program_flash(self, data): """Program the MCU's flash memory.""" @@ -1392,7 +1474,7 @@ class Stc15Protocol(Stc15AProtocol): while len(packet) < self.PROGRAM_BLOCKSIZE + 3: packet += b"\x00" self.write_packet(packet) response = self.read_packet() - if response[0] != 0x02 or response[1] != 0x54: + if len(response) < 2 or response[0] != 0x02 or response[1] != 0x54: raise StcProtocolException("incorrect magic in write packet") print(".", end="") sys.stdout.flush() @@ -1405,7 +1487,7 @@ class Stc15Protocol(Stc15AProtocol): packet = bytes([0x07, 0x00, 0x00, 0x5a, 0xa5]) self.write_packet(packet) response = self.read_packet() - if response[0] != 0x07 or response[1] != 0x54: + if len(response) < 2 or response[0] != 0x07 or response[1] != 0x54: raise StcProtocolException("incorrect magic in finish packet") print("done") @@ -1444,7 +1526,7 @@ class Stc15Protocol(Stc15AProtocol): packet += self.build_options() self.write_packet(packet) response = self.read_packet() - if response[0] != 0x04 or response[1] != 0x54: + if len(response) < 2 or response[0] != 0x04 or response[1] != 0x54: raise StcProtocolException("incorrect magic in option packet") print("done") diff --git a/stcgal/utils.py b/stcgal/utils.py index 3204a30..2239e6d 100644 --- a/stcgal/utils.py +++ b/stcgal/utils.py @@ -29,14 +29,12 @@ class Utils: def to_bool(cls, val): """make sensible boolean from string or other type value""" - if val is None: + if not val: return False if isinstance(val, bool): return val elif isinstance(val, int): return bool(val) - elif len(val) == 0: - return False else: return True if val[0].lower() == "t" or val[0] == "1" else False diff --git a/test/__init__.py b/tests/__init__.py similarity index 100% rename from test/__init__.py rename to tests/__init__.py diff --git a/test/iap15f2k61s2.yml b/tests/iap15f2k61s2.yml similarity index 100% rename from test/iap15f2k61s2.yml rename to tests/iap15f2k61s2.yml diff --git a/test/stc12c2052ad.yml b/tests/stc12c2052ad.yml similarity index 100% rename from test/stc12c2052ad.yml rename to tests/stc12c2052ad.yml diff --git a/test/stc12c5a60s2.yml b/tests/stc12c5a60s2.yml similarity index 100% rename from test/stc12c5a60s2.yml rename to tests/stc12c5a60s2.yml diff --git a/test/stc15f104e.yml b/tests/stc15f104e.yml similarity index 100% rename from test/stc15f104e.yml rename to tests/stc15f104e.yml diff --git a/test/stc15l104w.yml b/tests/stc15l104w.yml similarity index 100% rename from test/stc15l104w.yml rename to tests/stc15l104w.yml diff --git a/test/stc15w4k56s4.yml b/tests/stc15w4k56s4.yml similarity index 100% rename from test/stc15w4k56s4.yml rename to tests/stc15w4k56s4.yml diff --git a/test/stc89c52rc.yml b/tests/stc89c52rc.yml similarity index 100% rename from test/stc89c52rc.yml rename to tests/stc89c52rc.yml diff --git a/tests/test_fuzzing.py b/tests/test_fuzzing.py new file mode 100644 index 0000000..52864b9 --- /dev/null +++ b/tests/test_fuzzing.py @@ -0,0 +1,119 @@ +# +# Copyright (c) 2017 Grigori Goronzy +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +"""Tests with fuzzing of input data""" + +import random +import sys +import unittest +from unittest.mock import patch +import yaml +import stcgal.frontend +import stcgal.protocols +from tests.test_program import get_default_opts, convert_to_bytes + +class ByteArrayFuzzer: + """Fuzzer for byte arrays""" + + def __init__(self): + self.rng = random.Random() + self.cut_propability = 0.01 # probability for cutting off an array early + self.cut_min = 0 # minimum cut amount + self.cut_max = sys.maxsize # maximum cut amount + self.bitflip_probability = 0.0001 # probability for flipping a bit + self.randomize_probability = 0.001 # probability for randomizing a char + + def fuzz(self, inp): + """Fuzz an array of bytes according to predefined settings""" + arr = bytearray(inp) + arr = self.cut_off(arr) + self.randomize(arr) + return bytes(arr) + + def randomize(self, arr): + """Randomize array contents with bitflips and random bytes""" + for i, _ in enumerate(arr): + for j in range(8): + if self.rng.random() < self.bitflip_probability: + arr[i] ^= (1 << j) + if self.rng.random() < self.randomize_probability: + arr[i] = self.rng.getrandbits(8) + + def cut_off(self, arr): + """Cut off data from end of array""" + if self.rng.random() < self.cut_propability: + cut_limit = min(len(arr), self.cut_max) + cut_len = self.rng.randrange(self.cut_min, cut_limit) + arr = arr[0:len(arr) - cut_len] + return arr + +class TestProgramFuzzed(unittest.TestCase): + """Special programming cycle tests that use a fuzzing approach""" + + @patch("stcgal.protocols.StcBaseProtocol.read_packet") + @patch("stcgal.protocols.Stc89Protocol.write_packet") + @patch("stcgal.protocols.serial.Serial", autospec=True) + @patch("stcgal.protocols.time.sleep") + @patch("sys.stdout") + @patch("sys.stderr") + def test_program_fuzz(self, err, out, sleep_mock, serial_mock, write_mock, read_mock): + """Test programming cycles with fuzzing enabled""" + yml = [ + "./tests/iap15f2k61s2.yml", + "./tests/stc12c2052ad.yml", + "./tests/stc15w4k56s4.yml", + "./tests/stc12c5a60s2.yml", + "./tests/stc89c52rc.yml", + "./tests/stc15l104w.yml", + "./tests/stc15f104e.yml", + ] + fuzzer = ByteArrayFuzzer() + fuzzer.cut_propability = 0.01 + fuzzer.bitflip_probability = 0.005 + fuzzer.rng = random.Random(1) + for y in yml: + with self.subTest(msg="trace {}".format(y)): + self.single_fuzz(y, serial_mock, fuzzer, read_mock, err, out, + sleep_mock, write_mock) + + def single_fuzz(self, yml, serial_mock, fuzzer, read_mock, err, out, sleep_mock, write_mock): + """Test a single programming cycle with fuzzing""" + with open(yml) as test_file: + test_data = yaml.load(test_file.read()) + for _ in range(1000): + with self.subTest(): + opts = get_default_opts() + opts.protocol = test_data["protocol"] + opts.code_image.read.return_value = bytes(test_data["code_data"]) + serial_mock.return_value.inWaiting.return_value = 1 + fuzzed_responses = [] + for arr in convert_to_bytes(test_data["responses"]): + fuzzed_responses.append(fuzzer.fuzz(arr)) + read_mock.side_effect = fuzzed_responses + gal = stcgal.frontend.StcGal(opts) + self.assertGreaterEqual(gal.run(), 0) + err.reset_mock() + out.reset_mock() + sleep_mock.reset_mock() + serial_mock.reset_mock() + write_mock.reset_mock() + read_mock.reset_mock() diff --git a/test/test_program.py b/tests/test_program.py similarity index 91% rename from test/test_program.py rename to tests/test_program.py index a3444c1..c3938af 100644 --- a/test/test_program.py +++ b/tests/test_program.py @@ -57,7 +57,7 @@ class ProgramTests(unittest.TestCase): @patch("sys.stdout") def test_program_stc89(self, out, sleep_mock, serial_mock, write_mock, read_mock): """Test a programming cycle with STC89 protocol""" - self._program_yml("./test/stc89c52rc.yml", serial_mock, read_mock) + self._program_yml("./tests/stc89c52rc.yml", serial_mock, read_mock) @patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.Stc89Protocol.write_packet") @@ -66,7 +66,7 @@ class ProgramTests(unittest.TestCase): @patch("sys.stdout") def test_program_stc12(self, out, sleep_mock, serial_mock, write_mock, read_mock): """Test a programming cycle with STC12 protocol""" - self._program_yml("./test/stc12c5a60s2.yml", serial_mock, read_mock) + self._program_yml("./tests/stc12c5a60s2.yml", serial_mock, read_mock) @patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.Stc89Protocol.write_packet") @@ -75,7 +75,7 @@ class ProgramTests(unittest.TestCase): @patch("sys.stdout") def test_program_stc12a(self, out, sleep_mock, serial_mock, write_mock, read_mock): """Test a programming cycle with STC12A protocol""" - self._program_yml("./test/stc12c2052ad.yml", serial_mock, read_mock) + self._program_yml("./tests/stc12c2052ad.yml", serial_mock, read_mock) def test_program_stc12b(self): """Test a programming cycle with STC12B protocol""" @@ -88,7 +88,7 @@ class ProgramTests(unittest.TestCase): @patch("sys.stdout") def test_program_stc15f2(self, out, sleep_mock, serial_mock, write_mock, read_mock): """Test a programming cycle with STC15 protocol, F2 series""" - self._program_yml("./test/iap15f2k61s2.yml", serial_mock, read_mock) + self._program_yml("./tests/iap15f2k61s2.yml", serial_mock, read_mock) @patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.Stc89Protocol.write_packet") @@ -97,7 +97,7 @@ class ProgramTests(unittest.TestCase): @patch("sys.stdout") def test_program_stc15w4(self, out, sleep_mock, serial_mock, write_mock, read_mock): """Test a programming cycle with STC15 protocol, W4 series""" - self._program_yml("./test/stc15w4k56s4.yml", serial_mock, read_mock) + self._program_yml("./tests/stc15w4k56s4.yml", serial_mock, read_mock) @unittest.skip("trace is broken") @patch("stcgal.protocols.StcBaseProtocol.read_packet") @@ -107,7 +107,7 @@ class ProgramTests(unittest.TestCase): @patch("sys.stdout") def test_program_stc15a(self, out, sleep_mock, serial_mock, write_mock, read_mock): """Test a programming cycle with STC15A protocol""" - self._program_yml("./test/stc15f104e.yml", serial_mock, read_mock) + self._program_yml("./tests/stc15f104e.yml", serial_mock, read_mock) @patch("stcgal.protocols.StcBaseProtocol.read_packet") @patch("stcgal.protocols.Stc89Protocol.write_packet") @@ -116,7 +116,7 @@ class ProgramTests(unittest.TestCase): @patch("sys.stdout") def test_program_stc15l1(self, out, sleep_mock, serial_mock, write_mock, read_mock): """Test a programming cycle with STC15 protocol, L1 series""" - self._program_yml("./test/stc15l104w.yml", serial_mock, read_mock) + self._program_yml("./tests/stc15l104w.yml", serial_mock, read_mock) def test_program_stc15w4_usb(self): """Test a programming cycle with STC15W4 USB protocol""" @@ -133,4 +133,3 @@ class ProgramTests(unittest.TestCase): read_mock.side_effect = convert_to_bytes(test_data["responses"]) gal = stcgal.frontend.StcGal(opts) self.assertEqual(gal.run(), 0) - \ No newline at end of file diff --git a/test/test_utils.py b/tests/test_utils.py similarity index 98% rename from test/test_utils.py rename to tests/test_utils.py index e5fce66..59809b7 100644 --- a/test/test_utils.py +++ b/tests/test_utils.py @@ -24,7 +24,6 @@ import argparse import unittest -from unittest.mock import patch from stcgal.utils import Utils, BaudType class TestUtils(unittest.TestCase):