Skip to content

Commit 2d85ab0

Browse files
committed
Merge branch 'develop'
2 parents b63db5a + 0abcb23 commit 2d85ab0

3 files changed

Lines changed: 620 additions & 7 deletions

File tree

.coverage

52 KB
Binary file not shown.

tests/run_tests.py

Lines changed: 255 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,275 @@
33
Test runner script for the WEAC package.
44
55
This script discovers and runs all tests in the tests directory.
6+
Provides a pytest-like output with detailed reporting.
67
"""
78

89
import os
910
import sys
11+
import time
1012
import unittest
13+
from collections import defaultdict
14+
from typing import Dict
15+
16+
17+
class PytestLikeTextTestResult(unittest.TextTestResult):
18+
"""A test result class that provides pytest-like output format."""
19+
20+
PASS = "\033[92m" # Green
21+
FAIL = "\033[91m" # Red
22+
SKIP = "\033[93m" # Yellow
23+
END = "\033[0m" # Reset color
24+
BOLD = "\033[1m" # Bold text
25+
26+
def __init__(self, stream, descriptions, verbosity):
27+
"""Initialize the test result object."""
28+
# Override descriptions to prevent unittest from printing the test docstring
29+
super().__init__(stream, False, verbosity)
30+
self.stream = stream
31+
self.verbosity = verbosity
32+
self.descriptions = (
33+
False # Override to prevent unittest from printing docstrings
34+
)
35+
self.successes = []
36+
self.start_time = time.time()
37+
self.test_times: Dict[str, float] = {}
38+
self.module_counts: Dict[str, Dict[str, int]] = defaultdict(
39+
lambda: defaultdict(int)
40+
)
41+
self.total_tests = 0
42+
self.test_counter = 0
43+
44+
# Print header
45+
self.stream.write(
46+
f"\n{self.BOLD}============================== test session starts =============================={self.END}\n"
47+
)
48+
self.stream.write(f"platform: {sys.platform}, Python {sys.version.split()[0]}\n")
49+
self.stream.write(
50+
f"rootdir: {os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))}\n"
51+
)
52+
self.stream.flush()
53+
54+
def getDescription(self, test):
55+
"""Override to return an empty description, preventing unittest from printing the docstring."""
56+
return ""
57+
58+
def set_total_tests(self, count):
59+
"""Set the total number of tests to be run."""
60+
self.total_tests = count
61+
62+
def startTest(self, test):
63+
"""Called when a test starts."""
64+
super().startTest(test)
65+
self.test_start_time = time.time()
66+
self.test_counter += 1
67+
68+
if self.verbosity > 1:
69+
# Extract test name and module in a cleaner format
70+
test_id = test.id()
71+
module_name, class_name, test_name = test_id.split(".")[-3:]
72+
73+
# Get test description
74+
doc = test._testMethodDoc or ""
75+
76+
# Print the test name with progress indicator and description
77+
progress = f"[ {self.test_counter}/{self.total_tests} ]"
78+
self.stream.write(f"\n{progress} {module_name}.{class_name}.{test_name}\n")
79+
if doc:
80+
self.stream.write(f" {doc}\n")
81+
82+
# Indentation for the result
83+
self.stream.write(" ")
84+
self.stream.flush()
85+
86+
def _get_module_name(self, test):
87+
"""Extract module name from test."""
88+
return test.__class__.__module__.split(".")[-1]
89+
90+
def addSuccess(self, test):
91+
"""Called when a test succeeds."""
92+
super().addSuccess(test)
93+
self.successes.append(test)
94+
self.test_times[test.id()] = time.time() - self.test_start_time
95+
module_name = self._get_module_name(test)
96+
self.module_counts[module_name]["passed"] += 1
97+
98+
if self.verbosity > 1:
99+
self.stream.write(f" {self.PASS}✓ PASS{self.END}\n")
100+
self.stream.flush()
101+
102+
def addError(self, test, err):
103+
"""Called when a test raises an error."""
104+
super().addError(test, err)
105+
self.test_times[test.id()] = time.time() - self.test_start_time
106+
module_name = self._get_module_name(test)
107+
self.module_counts[module_name]["errors"] += 1
108+
109+
if self.verbosity > 1:
110+
self.stream.write(f" {self.FAIL}E ERROR{self.END}\n")
111+
self.stream.flush()
112+
113+
def addFailure(self, test, err):
114+
"""Called when a test fails."""
115+
super().addFailure(test, err)
116+
self.test_times[test.id()] = time.time() - self.test_start_time
117+
module_name = self._get_module_name(test)
118+
self.module_counts[module_name]["failures"] += 1
119+
120+
if self.verbosity > 1:
121+
self.stream.write(f" {self.FAIL}✗ FAIL{self.END}\n")
122+
self.stream.flush()
123+
124+
def addSkip(self, test, reason):
125+
"""Called when a test is skipped."""
126+
super().addSkip(test, reason)
127+
self.test_times[test.id()] = time.time() - self.test_start_time
128+
module_name = self._get_module_name(test)
129+
self.module_counts[module_name]["skipped"] += 1
130+
131+
if self.verbosity > 1:
132+
self.stream.write(f" {self.SKIP}s SKIP{self.END} [{reason}]\n")
133+
self.stream.flush()
134+
135+
def printErrors(self):
136+
"""Print a formatted report of errors and failures."""
137+
if self.errors or self.failures:
138+
self.stream.write(
139+
f"\n{self.BOLD}============================== FAILURES =============================={self.END}\n"
140+
)
141+
142+
for test, err in self.errors + self.failures:
143+
test_id = test.id()
144+
module_name, class_name, test_name = test_id.split(".")[-3:]
145+
self.stream.write(
146+
f"\n{self.BOLD}{self.FAIL}FAILED{self.END} {module_name}.{class_name}.{test_name}{self.END}\n"
147+
)
148+
self.stream.write(f"{err}\n")
149+
150+
def printTotal(self):
151+
"""Print a summary of all tests run."""
152+
total_time = time.time() - self.start_time
153+
total_tests = self.testsRun
154+
passed = len(self.successes)
155+
failures = len(self.failures)
156+
errors = len(self.errors)
157+
skipped = len(self.skipped)
158+
159+
# Print per-module summary
160+
self.stream.write(
161+
f"\n{self.BOLD}============================== test summary info =============================={self.END}\n"
162+
)
163+
164+
for module, counts in sorted(self.module_counts.items()):
165+
module_total = sum(counts.values())
166+
result_str = []
167+
if counts["passed"]:
168+
result_str.append(f"{self.PASS}{counts['passed']} passed{self.END}")
169+
if counts["failures"]:
170+
result_str.append(f"{self.FAIL}{counts['failures']} failed{self.END}")
171+
if counts["errors"]:
172+
result_str.append(f"{self.FAIL}{counts['errors']} errors{self.END}")
173+
if counts["skipped"]:
174+
result_str.append(f"{self.SKIP}{counts['skipped']} skipped{self.END}")
175+
176+
self.stream.write(f"{module}: {', '.join(result_str)}\n")
177+
178+
# Print overall summary
179+
self.stream.write(
180+
f"\n{self.BOLD}============================== {total_tests} tests ran in {total_time:.2f}s =============================={self.END}\n"
181+
)
182+
183+
result_parts = []
184+
if passed:
185+
result_parts.append(f"{self.PASS}{passed} passed{self.END}")
186+
if failures:
187+
result_parts.append(f"{self.FAIL}{failures} failed{self.END}")
188+
if errors:
189+
result_parts.append(f"{self.FAIL}{errors} errors{self.END}")
190+
if skipped:
191+
result_parts.append(f"{self.SKIP}{skipped} skipped{self.END}")
192+
193+
self.stream.write(", ".join(result_parts) + "\n")
194+
195+
196+
class PytestLikeTextTestRunner(unittest.TextTestRunner):
197+
"""A test runner that uses PytestLikeTextTestResult to display results."""
198+
199+
def __init__(
200+
self,
201+
stream=None,
202+
descriptions=False, # Override to prevent unittest from printing docstrings
203+
verbosity=1,
204+
failfast=False,
205+
buffer=False,
206+
warnings=None,
207+
):
208+
"""Initialize the runner."""
209+
super().__init__(stream, descriptions, verbosity, failfast, buffer, warnings)
210+
211+
def _makeResult(self):
212+
"""Create and return a test result object that will be used to store results."""
213+
return PytestLikeTextTestResult(self.stream, self.descriptions, self.verbosity)
214+
215+
def run(self, test):
216+
"""Run the given test case or test suite."""
217+
result = self._makeResult()
218+
result.set_total_tests(self._count_tests(test))
219+
220+
self.stream.write(f"collecting ... {result.total_tests} items collected\n")
221+
222+
# Run tests
223+
startTime = time.time()
224+
startTestRun = getattr(result, "startTestRun", None)
225+
if startTestRun is not None:
226+
startTestRun()
227+
try:
228+
test(result)
229+
finally:
230+
stopTestRun = getattr(result, "stopTestRun", None)
231+
if stopTestRun is not None:
232+
stopTestRun()
233+
234+
result.printErrors()
235+
result.printTotal()
236+
return result
237+
238+
def _count_tests(self, test):
239+
"""Count the total number of tests in a test suite."""
240+
if hasattr(test, "_tests"):
241+
return sum(self._count_tests(t) for t in test._tests)
242+
else:
243+
return 1
244+
245+
246+
class CustomTextTestRunner(unittest.TextTestRunner):
247+
"""Hide default unittest output since we're using our custom runner."""
248+
249+
def run(self, test):
250+
"""Run the test suite with no output."""
251+
result = super().run(test)
252+
return result
11253

12254

13255
def run_tests():
14256
"""Discover and run all tests in the tests directory."""
257+
# Redirect both standard out and standard error to capture unittest output
258+
# This prevents duplicate output from the standard unittest runner
259+
import io
260+
from contextlib import redirect_stderr, redirect_stdout
261+
262+
f = io.StringIO()
263+
15264
# Get the directory containing this script
16265
test_dir = os.path.dirname(os.path.abspath(__file__))
17266

18-
# Discover all tests in the tests directory
19-
test_suite = unittest.defaultTestLoader.discover(test_dir)
267+
# Create a test runner with pytest-like output
268+
test_runner = PytestLikeTextTestRunner(verbosity=2)
20269

21-
# Create a test runner
22-
test_runner = unittest.TextTestRunner(verbosity=2)
270+
# Discover all tests in the tests directory
271+
with redirect_stdout(f), redirect_stderr(f):
272+
test_suite = unittest.defaultTestLoader.discover(test_dir)
23273

24-
# Run the tests
274+
# Run the tests with our custom output
25275
result = test_runner.run(test_suite)
26276

27277
# Return appropriate exit code

0 commit comments

Comments
 (0)