Skip to content

Commit df80895

Browse files
committed
Arm backend: Test fixes for TOSA on Arm64
Signed-off-by: Zingo Andersen <Zingo.Andersen@arm.com> Change-Id: Ia8b796cb91c92ff45b36478d3b6904b25314f00c
1 parent 8dbbd9d commit df80895

4 files changed

Lines changed: 23 additions & 1 deletion

File tree

backends/arm/test/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66

77
import os
8+
import platform
89

910
from datetime import datetime
1011

@@ -27,6 +28,10 @@
2728
from executorch.backends.arm.vgf import VgfCompileSpec
2829

2930

31+
def is_aarch64_host() -> bool:
32+
return platform.machine().lower() in ("aarch64", "arm64")
33+
34+
3035
def get_time_formatted_path(path: str, log_prefix: str) -> str:
3136
"""Returns the log path with the current time appended to it. Used for
3237
debugging.

backends/arm/test/models/Qwen3_VL/test_qwen3_vl_layers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,16 @@ def test_qwen3_vl_tosa_FP(test_case: Qwen3VLTestCase):
488488
@common.parametrize(
489489
"test_case",
490490
TOSA_BF16_TEST_CASES,
491+
xfails=(
492+
{
493+
"vision_patch_embed": (
494+
"MLETORCH-2048: Large bf16 patch embedding mismatch on aarch64",
495+
AssertionError,
496+
),
497+
}
498+
if common.is_aarch64_host()
499+
else None
500+
),
491501
)
492502
def test_qwen3_vl_tosa_FP_bf16(test_case: Qwen3VLTestCase):
493503
model, inputs = test_case.model_cls.prepare_model_and_inputs()

backends/arm/test/models/test_mobilenet_v3_arm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def test_mv3_tosa_FP():
4545
pipeline.run()
4646

4747

48+
# Different atol for TOSA on ARM (MLETORCH-2048: Large bf16 patch embedding mismatch on aarch64)
4849
@pytest.mark.slow
4950
def test_mv3_tosa_FP_fp16():
5051
input_tensor_fp16 = torch.rand(
@@ -57,7 +58,7 @@ def test_mv3_tosa_FP_fp16():
5758
aten_op=[],
5859
exir_op=[],
5960
use_to_edge_transform_and_lower=True,
60-
atol=6e-2,
61+
atol=6.5e-2 if common.is_aarch64_host() else 6e-2,
6162
)
6263
pipeline.run()
6364

backends/arm/test/models/test_resnet18.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,12 @@ def test_resnet_18_tosa_FP():
4949
pipeline.run()
5050

5151

52+
@pytest.mark.xfail(
53+
common.is_aarch64_host(),
54+
reason="MLETORCH-2048: Large bf16 ResNet18 mismatch on aarch64",
55+
raises=AssertionError,
56+
strict=True,
57+
)
5258
def test_resnet_18_tosa_FP_bf16():
5359
bf16_model = resnet18(weights=ResNet18_Weights).eval()
5460
bf16_model = bf16_model.to(torch.bfloat16)

0 commit comments

Comments
 (0)