|
3 | 3 | Test runner script for the WEAC package. |
4 | 4 |
|
5 | 5 | This script discovers and runs all tests in the tests directory. |
| 6 | +Provides a pytest-like output with detailed reporting. |
6 | 7 | """ |
7 | 8 |
|
8 | 9 | import os |
9 | 10 | import sys |
| 11 | +import time |
10 | 12 | import unittest |
| 13 | +from collections import defaultdict |
| 14 | +from typing import Dict |
| 15 | + |
| 16 | + |
| 17 | +class PytestLikeTextTestResult(unittest.TextTestResult): |
| 18 | + """A test result class that provides pytest-like output format.""" |
| 19 | + |
| 20 | + PASS = "\033[92m" # Green |
| 21 | + FAIL = "\033[91m" # Red |
| 22 | + SKIP = "\033[93m" # Yellow |
| 23 | + END = "\033[0m" # Reset color |
| 24 | + BOLD = "\033[1m" # Bold text |
| 25 | + |
| 26 | + def __init__(self, stream, descriptions, verbosity): |
| 27 | + """Initialize the test result object.""" |
| 28 | + # Override descriptions to prevent unittest from printing the test docstring |
| 29 | + super().__init__(stream, False, verbosity) |
| 30 | + self.stream = stream |
| 31 | + self.verbosity = verbosity |
| 32 | + self.descriptions = ( |
| 33 | + False # Override to prevent unittest from printing docstrings |
| 34 | + ) |
| 35 | + self.successes = [] |
| 36 | + self.start_time = time.time() |
| 37 | + self.test_times: Dict[str, float] = {} |
| 38 | + self.module_counts: Dict[str, Dict[str, int]] = defaultdict( |
| 39 | + lambda: defaultdict(int) |
| 40 | + ) |
| 41 | + self.total_tests = 0 |
| 42 | + self.test_counter = 0 |
| 43 | + |
| 44 | + # Print header |
| 45 | + self.stream.write( |
| 46 | + f"\n{self.BOLD}============================== test session starts =============================={self.END}\n" |
| 47 | + ) |
| 48 | + self.stream.write(f"platform: {sys.platform}, Python {sys.version.split()[0]}\n") |
| 49 | + self.stream.write( |
| 50 | + f"rootdir: {os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))}\n" |
| 51 | + ) |
| 52 | + self.stream.flush() |
| 53 | + |
| 54 | + def getDescription(self, test): |
| 55 | + """Override to return an empty description, preventing unittest from printing the docstring.""" |
| 56 | + return "" |
| 57 | + |
| 58 | + def set_total_tests(self, count): |
| 59 | + """Set the total number of tests to be run.""" |
| 60 | + self.total_tests = count |
| 61 | + |
| 62 | + def startTest(self, test): |
| 63 | + """Called when a test starts.""" |
| 64 | + super().startTest(test) |
| 65 | + self.test_start_time = time.time() |
| 66 | + self.test_counter += 1 |
| 67 | + |
| 68 | + if self.verbosity > 1: |
| 69 | + # Extract test name and module in a cleaner format |
| 70 | + test_id = test.id() |
| 71 | + module_name, class_name, test_name = test_id.split(".")[-3:] |
| 72 | + |
| 73 | + # Get test description |
| 74 | + doc = test._testMethodDoc or "" |
| 75 | + |
| 76 | + # Print the test name with progress indicator and description |
| 77 | + progress = f"[ {self.test_counter}/{self.total_tests} ]" |
| 78 | + self.stream.write(f"\n{progress} {module_name}.{class_name}.{test_name}\n") |
| 79 | + if doc: |
| 80 | + self.stream.write(f" {doc}\n") |
| 81 | + |
| 82 | + # Indentation for the result |
| 83 | + self.stream.write(" ") |
| 84 | + self.stream.flush() |
| 85 | + |
| 86 | + def _get_module_name(self, test): |
| 87 | + """Extract module name from test.""" |
| 88 | + return test.__class__.__module__.split(".")[-1] |
| 89 | + |
| 90 | + def addSuccess(self, test): |
| 91 | + """Called when a test succeeds.""" |
| 92 | + super().addSuccess(test) |
| 93 | + self.successes.append(test) |
| 94 | + self.test_times[test.id()] = time.time() - self.test_start_time |
| 95 | + module_name = self._get_module_name(test) |
| 96 | + self.module_counts[module_name]["passed"] += 1 |
| 97 | + |
| 98 | + if self.verbosity > 1: |
| 99 | + self.stream.write(f" {self.PASS}✓ PASS{self.END}\n") |
| 100 | + self.stream.flush() |
| 101 | + |
| 102 | + def addError(self, test, err): |
| 103 | + """Called when a test raises an error.""" |
| 104 | + super().addError(test, err) |
| 105 | + self.test_times[test.id()] = time.time() - self.test_start_time |
| 106 | + module_name = self._get_module_name(test) |
| 107 | + self.module_counts[module_name]["errors"] += 1 |
| 108 | + |
| 109 | + if self.verbosity > 1: |
| 110 | + self.stream.write(f" {self.FAIL}E ERROR{self.END}\n") |
| 111 | + self.stream.flush() |
| 112 | + |
| 113 | + def addFailure(self, test, err): |
| 114 | + """Called when a test fails.""" |
| 115 | + super().addFailure(test, err) |
| 116 | + self.test_times[test.id()] = time.time() - self.test_start_time |
| 117 | + module_name = self._get_module_name(test) |
| 118 | + self.module_counts[module_name]["failures"] += 1 |
| 119 | + |
| 120 | + if self.verbosity > 1: |
| 121 | + self.stream.write(f" {self.FAIL}✗ FAIL{self.END}\n") |
| 122 | + self.stream.flush() |
| 123 | + |
| 124 | + def addSkip(self, test, reason): |
| 125 | + """Called when a test is skipped.""" |
| 126 | + super().addSkip(test, reason) |
| 127 | + self.test_times[test.id()] = time.time() - self.test_start_time |
| 128 | + module_name = self._get_module_name(test) |
| 129 | + self.module_counts[module_name]["skipped"] += 1 |
| 130 | + |
| 131 | + if self.verbosity > 1: |
| 132 | + self.stream.write(f" {self.SKIP}s SKIP{self.END} [{reason}]\n") |
| 133 | + self.stream.flush() |
| 134 | + |
| 135 | + def printErrors(self): |
| 136 | + """Print a formatted report of errors and failures.""" |
| 137 | + if self.errors or self.failures: |
| 138 | + self.stream.write( |
| 139 | + f"\n{self.BOLD}============================== FAILURES =============================={self.END}\n" |
| 140 | + ) |
| 141 | + |
| 142 | + for test, err in self.errors + self.failures: |
| 143 | + test_id = test.id() |
| 144 | + module_name, class_name, test_name = test_id.split(".")[-3:] |
| 145 | + self.stream.write( |
| 146 | + f"\n{self.BOLD}{self.FAIL}FAILED{self.END} {module_name}.{class_name}.{test_name}{self.END}\n" |
| 147 | + ) |
| 148 | + self.stream.write(f"{err}\n") |
| 149 | + |
| 150 | + def printTotal(self): |
| 151 | + """Print a summary of all tests run.""" |
| 152 | + total_time = time.time() - self.start_time |
| 153 | + total_tests = self.testsRun |
| 154 | + passed = len(self.successes) |
| 155 | + failures = len(self.failures) |
| 156 | + errors = len(self.errors) |
| 157 | + skipped = len(self.skipped) |
| 158 | + |
| 159 | + # Print per-module summary |
| 160 | + self.stream.write( |
| 161 | + f"\n{self.BOLD}============================== test summary info =============================={self.END}\n" |
| 162 | + ) |
| 163 | + |
| 164 | + for module, counts in sorted(self.module_counts.items()): |
| 165 | + module_total = sum(counts.values()) |
| 166 | + result_str = [] |
| 167 | + if counts["passed"]: |
| 168 | + result_str.append(f"{self.PASS}{counts['passed']} passed{self.END}") |
| 169 | + if counts["failures"]: |
| 170 | + result_str.append(f"{self.FAIL}{counts['failures']} failed{self.END}") |
| 171 | + if counts["errors"]: |
| 172 | + result_str.append(f"{self.FAIL}{counts['errors']} errors{self.END}") |
| 173 | + if counts["skipped"]: |
| 174 | + result_str.append(f"{self.SKIP}{counts['skipped']} skipped{self.END}") |
| 175 | + |
| 176 | + self.stream.write(f"{module}: {', '.join(result_str)}\n") |
| 177 | + |
| 178 | + # Print overall summary |
| 179 | + self.stream.write( |
| 180 | + f"\n{self.BOLD}============================== {total_tests} tests ran in {total_time:.2f}s =============================={self.END}\n" |
| 181 | + ) |
| 182 | + |
| 183 | + result_parts = [] |
| 184 | + if passed: |
| 185 | + result_parts.append(f"{self.PASS}{passed} passed{self.END}") |
| 186 | + if failures: |
| 187 | + result_parts.append(f"{self.FAIL}{failures} failed{self.END}") |
| 188 | + if errors: |
| 189 | + result_parts.append(f"{self.FAIL}{errors} errors{self.END}") |
| 190 | + if skipped: |
| 191 | + result_parts.append(f"{self.SKIP}{skipped} skipped{self.END}") |
| 192 | + |
| 193 | + self.stream.write(", ".join(result_parts) + "\n") |
| 194 | + |
| 195 | + |
| 196 | +class PytestLikeTextTestRunner(unittest.TextTestRunner): |
| 197 | + """A test runner that uses PytestLikeTextTestResult to display results.""" |
| 198 | + |
| 199 | + def __init__( |
| 200 | + self, |
| 201 | + stream=None, |
| 202 | + descriptions=False, # Override to prevent unittest from printing docstrings |
| 203 | + verbosity=1, |
| 204 | + failfast=False, |
| 205 | + buffer=False, |
| 206 | + warnings=None, |
| 207 | + ): |
| 208 | + """Initialize the runner.""" |
| 209 | + super().__init__(stream, descriptions, verbosity, failfast, buffer, warnings) |
| 210 | + |
| 211 | + def _makeResult(self): |
| 212 | + """Create and return a test result object that will be used to store results.""" |
| 213 | + return PytestLikeTextTestResult(self.stream, self.descriptions, self.verbosity) |
| 214 | + |
| 215 | + def run(self, test): |
| 216 | + """Run the given test case or test suite.""" |
| 217 | + result = self._makeResult() |
| 218 | + result.set_total_tests(self._count_tests(test)) |
| 219 | + |
| 220 | + self.stream.write(f"collecting ... {result.total_tests} items collected\n") |
| 221 | + |
| 222 | + # Run tests |
| 223 | + startTime = time.time() |
| 224 | + startTestRun = getattr(result, "startTestRun", None) |
| 225 | + if startTestRun is not None: |
| 226 | + startTestRun() |
| 227 | + try: |
| 228 | + test(result) |
| 229 | + finally: |
| 230 | + stopTestRun = getattr(result, "stopTestRun", None) |
| 231 | + if stopTestRun is not None: |
| 232 | + stopTestRun() |
| 233 | + |
| 234 | + result.printErrors() |
| 235 | + result.printTotal() |
| 236 | + return result |
| 237 | + |
| 238 | + def _count_tests(self, test): |
| 239 | + """Count the total number of tests in a test suite.""" |
| 240 | + if hasattr(test, "_tests"): |
| 241 | + return sum(self._count_tests(t) for t in test._tests) |
| 242 | + else: |
| 243 | + return 1 |
| 244 | + |
| 245 | + |
| 246 | +class CustomTextTestRunner(unittest.TextTestRunner): |
| 247 | + """Hide default unittest output since we're using our custom runner.""" |
| 248 | + |
| 249 | + def run(self, test): |
| 250 | + """Run the test suite with no output.""" |
| 251 | + result = super().run(test) |
| 252 | + return result |
11 | 253 |
|
12 | 254 |
|
13 | 255 | def run_tests(): |
14 | 256 | """Discover and run all tests in the tests directory.""" |
| 257 | + # Redirect both standard out and standard error to capture unittest output |
| 258 | + # This prevents duplicate output from the standard unittest runner |
| 259 | + import io |
| 260 | + from contextlib import redirect_stderr, redirect_stdout |
| 261 | + |
| 262 | + f = io.StringIO() |
| 263 | + |
15 | 264 | # Get the directory containing this script |
16 | 265 | test_dir = os.path.dirname(os.path.abspath(__file__)) |
17 | 266 |
|
18 | | - # Discover all tests in the tests directory |
19 | | - test_suite = unittest.defaultTestLoader.discover(test_dir) |
| 267 | + # Create a test runner with pytest-like output |
| 268 | + test_runner = PytestLikeTextTestRunner(verbosity=2) |
20 | 269 |
|
21 | | - # Create a test runner |
22 | | - test_runner = unittest.TextTestRunner(verbosity=2) |
| 270 | + # Discover all tests in the tests directory |
| 271 | + with redirect_stdout(f), redirect_stderr(f): |
| 272 | + test_suite = unittest.defaultTestLoader.discover(test_dir) |
23 | 273 |
|
24 | | - # Run the tests |
| 274 | + # Run the tests with our custom output |
25 | 275 | result = test_runner.run(test_suite) |
26 | 276 |
|
27 | 277 | # Return appropriate exit code |
|
0 commit comments