Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@


import os
import platform

from datetime import datetime

Expand All @@ -27,6 +28,10 @@
from executorch.backends.arm.vgf import VgfCompileSpec


def is_aarch64_host() -> bool:
return platform.machine().lower() in ("aarch64", "arm64")


def get_time_formatted_path(path: str, log_prefix: str) -> str:
"""Returns the log path with the current time appended to it. Used for
debugging.
Expand Down
10 changes: 10 additions & 0 deletions backends/arm/test/models/Qwen3_VL/test_qwen3_vl_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,16 @@ def test_qwen3_vl_tosa_FP(test_case: Qwen3VLTestCase):
@common.parametrize(
"test_case",
TOSA_BF16_TEST_CASES,
xfails=(
{
"vision_patch_embed": (
"MLETORCH-2048: Large bf16 patch embedding mismatch on aarch64",
AssertionError,
),
}
if common.is_aarch64_host()
else None
),
)
def test_qwen3_vl_tosa_FP_bf16(test_case: Qwen3VLTestCase):
model, inputs = test_case.model_cls.prepare_model_and_inputs()
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/test/models/test_mobilenet_v3_arm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def test_mv3_tosa_FP():
pipeline.run()


# Slightly higher atol for TOSA FP16 on aarch64 (MLETORCH-2048: numeric mismatch)
@pytest.mark.slow
def test_mv3_tosa_FP_fp16():
input_tensor_fp16 = torch.rand(
Expand All @@ -57,7 +58,7 @@ def test_mv3_tosa_FP_fp16():
aten_op=[],
exir_op=[],
use_to_edge_transform_and_lower=True,
atol=6e-2,
atol=6.5e-2 if common.is_aarch64_host() else 6e-2,
)
pipeline.run()

Expand Down
6 changes: 6 additions & 0 deletions backends/arm/test/models/test_resnet18.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def test_resnet_18_tosa_FP():
pipeline.run()


@pytest.mark.xfail(
common.is_aarch64_host(),
reason="MLETORCH-2048: Large bf16 ResNet18 mismatch on aarch64",
raises=AssertionError,
strict=True,
)
def test_resnet_18_tosa_FP_bf16():
bf16_model = resnet18(weights=ResNet18_Weights).eval()
bf16_model = bf16_model.to(torch.bfloat16)
Expand Down
Loading