Skip to content

Commit 6a44c2a

Browse files
committed
Qualcomm AI Engine Direct - Test Framework Refactor
Co-authored-by: @winskuo-quic, @chenweng-quic - introduce pytest and reorganize the file architecture for finer-grained testing - wider coverage of operator test with combinations of different precisions, codebase was changed accordingly - add feature tests for HTP
1 parent b69cbcd commit 6a44c2a

74 files changed

Lines changed: 8429 additions & 355 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.ci/scripts/test_model.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ test_model_with_qnn() {
264264
;;
265265
esac
266266

267-
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.${SCRIPT_FOLDER}.${EXPORT_SCRIPT} -b ${CMAKE_OUTPUT_DIR} -m ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS
267+
"${PYTHON_EXECUTABLE}" -m examples.qualcomm.${SCRIPT_FOLDER}.${EXPORT_SCRIPT} --build_folder ${CMAKE_OUTPUT_DIR} --soc_model ${QNN_CHIPSET} --ci --compile_only $EXTRA_FLAGS
268268
EXPORTED_MODEL=$(find "./${EXPORT_SCRIPT}" -type f -name "${MODEL_NAME}*.pte" -print -quit)
269269
}
270270

backends/qualcomm/_passes/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,9 @@
5252
from .recompose_pad_maxpool2d import RecomposePadMaxPool2d
5353
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
5454
from .recompose_rms_norm import RecomposeRmsNorm
55-
from .reduce_dynamic_range import ReduceDynamicRange
5655
from .remove_0d_tensor import Remove0DTensor
5756
from .remove_redundancy import RemoveRedundancy
5857
from .replace_arange_args import ReplaceArangeArgs
59-
from .replace_inf_values import ReplaceInfValues
6058
from .resolve_debug_handle import ResolveDebugHandle
6159
from .seq_mse import SeqMSE
6260
from .tag_quant_io import TagQuantIO
@@ -110,11 +108,9 @@
110108
RecomposePadMaxPool2d,
111109
RecomposePixelUnshuffle,
112110
RecomposeRmsNorm,
113-
ReduceDynamicRange,
114111
Remove0DTensor,
115112
RemoveRedundancy,
116113
ReplaceArangeArgs,
117-
ReplaceInfValues,
118114
ResolveDebugHandle,
119115
SeqMSE,
120116
TagQuantIO,

backends/qualcomm/_passes/decompose_col_im.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def _decompose_im2col(self, graph_module: torch.fx.GraphModule):
3030
if node.target == self.im2col_op:
3131
input_node = node.args[0]
3232
kernel_size = node.args[1]
33+
dilation = node.args[2]
34+
padding = node.args[3]
3335
stride = node.args[4]
3436
batch_size = node.meta["val"].shape[0]
3537
assert (
@@ -41,6 +43,12 @@ def _decompose_im2col(self, graph_module: torch.fx.GraphModule):
4143
assert (
4244
kernel_size[0] == kernel_size[1]
4345
), "im2col can only be converted when kernel height == width"
46+
assert all(
47+
d == 1 for d in dilation
48+
), "col2im can only be converted when dilation equals to (1, 1)"
49+
assert all(
50+
p == 0 for p in padding
51+
), "col2im can only be converted when padding equals to (0, 0)"
4452
users = list(node.users.keys())
4553
with graph_module.graph.inserting_after(input_node):
4654
pixel_unshuffle_node = graph_module.graph.create_node(
@@ -77,6 +85,8 @@ def _decompose_col2im(self, graph_module: torch.fx.GraphModule):
7785
input_node = node.args[0]
7886
output_size = node.args[1]
7987
kernel_size = node.args[2]
88+
dilation = node.args[3]
89+
padding = node.args[4]
8090
stride = node.args[5]
8191
batch_size = node.meta["val"].shape[0]
8292
assert (
@@ -88,6 +98,13 @@ def _decompose_col2im(self, graph_module: torch.fx.GraphModule):
8898
assert (
8999
kernel_size[0] == kernel_size[1]
90100
), "col2im can only be converted when kernel height == width"
101+
assert all(
102+
d == 1 for d in dilation
103+
), "col2im can only be converted when dilation equals to (1, 1)"
104+
assert all(
105+
p == 0 for p in padding
106+
), "col2im can only be converted when padding equals to (0, 0)"
107+
91108
users = list(node.users.keys())
92109
with graph_module.graph.inserting_after(input_node):
93110
view_tensor = input_node.meta["val"].reshape(

backends/qualcomm/_passes/decompose_linalg_vector_norm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ def forward(self, x):
2525
self.dim = 0
2626

2727
x = torch.abs(x)
28+
29+
# QNN would not be able to compute pow where exponential is inf or -inf.
30+
if self.exp == float("inf"):
31+
return torch.amax(x, dim=self.dim, keepdim=self.keepdim)
32+
if self.exp == float("-inf"):
33+
return torch.amin(x, dim=self.dim, keepdim=self.keepdim)
2834
x = torch.pow(x, self.exp)
2935
x = torch.sum(x, dim=self.dim, keepdim=self.keepdim)
3036
return torch.pow(x, 1.0 / self.exp)

backends/qualcomm/_passes/decompose_remainder.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,10 @@ def __init__(self):
2323
super(DecomposeRemainder, self).__init__()
2424
self.remainder_targets = {
2525
torch.ops.aten.remainder.Scalar,
26+
torch.ops.aten.remainder.Scalar_Tensor,
2627
torch.ops.aten.remainder.Tensor,
2728
exir_ops.edge.aten.remainder.Scalar,
29+
exir_ops.edge.aten.remainder.Scalar_Tensor,
2830
exir_ops.edge.aten.remainder.Tensor,
2931
}
3032

@@ -35,7 +37,7 @@ def call(self, graph_module: torch.fx.GraphModule):
3537

3638
for node in list(graph.nodes):
3739
if node.op == "call_function" and node.target in self.remainder_targets:
38-
x_node = node.args[0]
40+
x_arg = node.args[0]
3941
y_arg = node.args[1]
4042
is_edge = isinstance(node.target, EdgeOpOverload)
4143
meta = node.meta
@@ -61,8 +63,21 @@ def call(self, graph_module: torch.fx.GraphModule):
6163
else torch.ops.aten.sub.Tensor
6264
)
6365

64-
is_scalar = not isinstance(y_arg, torch.fx.Node)
65-
if is_scalar and is_edge:
66+
is_x_scalar = not isinstance(x_arg, torch.fx.Node)
67+
if is_x_scalar and is_edge:
68+
if x_arg not in const_cache:
69+
attr_name = get_new_attr_name_with_prefix("_remainder_const_")(
70+
graph_module
71+
)
72+
const_cache[x_arg] = get_const_node(
73+
graph, graph_module, attr_name, x_arg, node
74+
)
75+
x_node = const_cache[x_arg]
76+
else:
77+
x_node = x_arg
78+
79+
is_y_scalar = not isinstance(y_arg, torch.fx.Node)
80+
if is_y_scalar and is_edge:
6681
if y_arg not in const_cache:
6782
attr_name = get_new_attr_name_with_prefix("_remainder_const_")(
6883
graph_module

backends/qualcomm/_passes/decompose_roll.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ def __init__(self, val_shape, shifts, dims):
1616
super().__init__()
1717
self.val_shape = val_shape
1818
if dims[0] is None:
19-
self.shifts = [shifts[0] % torch.numel(torch.tensor(val_shape))]
19+
self.shifts = [shifts[0] % torch.numel(torch.empty(val_shape))]
2020
else:
2121
self.shifts = [shift % val_shape[dim] for shift, dim in zip(shifts, dims)]
2222
self.dims = dims

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,17 @@ class TensorOpInfo:
3333

3434
SCALAR_OPS = {
3535
aten.eq.Scalar: TensorOpInfo(aten.eq.Tensor, False, False),
36+
aten.eq.Tensor: TensorOpInfo(aten.eq.Tensor, False, False),
3637
aten.ge.Scalar: TensorOpInfo(aten.ge.Tensor, False, False),
38+
aten.ge.Tensor: TensorOpInfo(aten.ge.Tensor, False, False),
3739
aten.gt.Scalar: TensorOpInfo(aten.gt.Tensor, False, False),
40+
aten.gt.Tensor: TensorOpInfo(aten.gt.Tensor, False, False),
3841
aten.le.Scalar: TensorOpInfo(aten.le.Tensor, False, False),
42+
aten.le.Tensor: TensorOpInfo(aten.le.Tensor, False, False),
3943
aten.lt.Scalar: TensorOpInfo(aten.lt.Tensor, False, False),
44+
aten.lt.Tensor: TensorOpInfo(aten.lt.Tensor, False, False),
4045
aten.ne.Scalar: TensorOpInfo(aten.ne.Tensor, False, False),
46+
aten.ne.Tensor: TensorOpInfo(aten.ne.Tensor, False, False),
4147
aten.add.Scalar: TensorOpInfo(aten.add.Tensor, False, False),
4248
aten.add_.Scalar: TensorOpInfo(aten.add_.Tensor, False, False),
4349
# For below cases, refer to LiftAddTensor Model in UT for sample
@@ -87,6 +93,7 @@ def _build_tensor_constant(
8793
) -> TensorConstant:
8894
# For dtype, in some cases, we cannot use node.args[0] as scalar dtype.
8995
# Ex: Where op args[0] can be bool, however, we probably want args[1] and args[2] to be dtype same as node.meta["val"] instead of bool type
96+
9097
first_arg = node.args[0]
9198
tensor = torch.tensor(
9299
const_val,

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,9 @@
5757
RecomposePadMaxPool2d,
5858
RecomposePixelUnshuffle,
5959
RecomposeRmsNorm,
60-
ReduceDynamicRange,
6160
Remove0DTensor,
6261
RemoveRedundancy,
6362
ReplaceArangeArgs,
64-
ReplaceInfValues,
6563
ResolveDebugHandle,
6664
TagQuantIO,
6765
)
@@ -226,7 +224,6 @@ def transform_for_to_edge_pipeline(
226224
# Before quantizer
227225
def transform_for_annotation_pipeline(self, graph_module: GraphModule):
228226
self.add_pass(RemoveRedundancy(quantization_capture=True))
229-
self.add_pass(ReduceDynamicRange())
230227
self.add_pass(RecomposePixelUnshuffle(quantization_capture=True))
231228
self.add_pass(RecomposeRmsNorm(quantization_capture=True))
232229
self.add_pass(ReplaceArangeArgs())
@@ -255,7 +252,6 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
255252
self.add_pass(DecomposeSelectScatter())
256253
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
257254
self.add_pass(DecomposeLogVariants())
258-
self.add_pass(ReplaceInfValues())
259255
self.add_pass(LiftConstantScalarOperands())
260256
self.add_pass(InsertReshapeForReduceOps())
261257
return self._transform(graph_module)

backends/qualcomm/_passes/reduce_dynamic_range.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

backends/qualcomm/_passes/replace_inf_values.py

Lines changed: 0 additions & 66 deletions
This file was deleted.

0 commit comments

Comments
 (0)