Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions source/tests/consistent/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class CommonTest(ABC):
pd_class: ClassVar[type | None]
"""Paddle model class."""
array_api_strict_class: ClassVar[type | None]
"""array_api_strict model class."""
args: ClassVar[Argument | list[Argument] | None]
"""Arguments that maps to the `data`."""
skip_dp: ClassVar[bool] = False
Expand All @@ -124,6 +125,98 @@ class CommonTest(ABC):
atol = 1e-10
"""Absolute tolerance for comparing the return value. Override for float32."""

def __init_subclass__(cls, **kwargs) -> None:
super().__init_subclass__(**kwargs)
cls._prune_disabled_test_methods()

@classmethod
def _prune_disabled_test_methods(cls) -> None:
"""Drop inherited backend tests that are guaranteed to skip.

The consistency suites generate many parameterized unittest classes.
In CI jobs that only exercise a subset of backends/devices, collecting
inherited test methods that will always call ``skipTest`` adds noise and
scheduling overhead without improving coverage.

Shadowing those inherited methods with ``None`` keeps unittest/pytest
from collecting them in the first place while leaving potentially-runnable
methods untouched.
"""
for test_name in cls._disabled_test_methods():
if callable(getattr(cls, test_name, None)):
setattr(cls, test_name, None)

@classmethod
def _disabled_test_methods(cls) -> set[str]:
try:
case = cls()
except Exception:
return set()

def resolve_skip(name: str) -> bool | None:
try:
return bool(getattr(case, name))
except Exception:
return None

disabled = set()
backend_tests = {
"skip_tf": (
"test_tf_consistent_with_ref",
"test_tf_self_consistent",
),
"skip_dp": (
"test_dp_consistent_with_ref",
"test_dp_self_consistent",
),
"skip_pt": (
"test_pt_consistent_with_ref",
"test_pt_self_consistent",
"test_dp_pt_api",
),
"skip_pt_expt": (
"test_pt_expt_consistent_with_ref",
"test_pt_expt_self_consistent",
"test_dp_pt_expt_api",
),
"skip_jax": (
"test_jax_consistent_with_ref",
"test_jax_self_consistent",
),
"skip_pd": (
"test_pd_consistent_with_ref",
"test_pd_self_consistent",
),
"skip_array_api_strict": (
"test_array_api_strict_consistent_with_ref",
"test_array_api_strict_self_consistent",
),
}
for skip_attr, test_names in backend_tests.items():
if resolve_skip(skip_attr) is True:
disabled.update(test_names)

if getattr(case, "pt_expt_class", None) is None:
disabled.update(
(
"test_pt_expt_consistent_with_ref",
"test_pt_expt_self_consistent",
"test_dp_pt_expt_api",
)
)

if TEST_DEVICE != "cpu" and CI:
disabled.update(
(
"test_dp_consistent_with_ref",
"test_dp_self_consistent",
"test_array_api_strict_consistent_with_ref",
"test_array_api_strict_self_consistent",
)
)

return disabled

def setUp(self) -> None:
self.unique_id = uuid4().hex

Expand Down
Loading