44# LICENSE file in the root directory of this source tree.
55"""Provide quantization annotation logic for Arm backends.
66
7- This module computes per-node quantization properties and applies input/output
8- annotations to FX graphs using TorchAO qspecs.
9-
7+ This module computes per-node quantization properties and applies
8+ input/output annotations to FX graphs using TorchAO qspecs.
109"""
1110
1211import logging
@@ -57,7 +56,6 @@ class _OpQuantProperties:
5756 indexed by argument positions.
5857 quant_output (Optional[_QuantProperty]): Quantization spec for the
5958 node's output when applicable.
60-
6159 """
6260
6361 def __init__ (self ):
@@ -73,7 +71,6 @@ def _as_list(x):
7371
7472 Returns:
7573 list: ``x`` if already a list; otherwise ``[x]``.
76-
7774 """
7875 if isinstance (x , (list , tuple )):
7976 return x
@@ -122,7 +119,6 @@ def _is_ok_for_quantization(
122119
123120 Returns:
124121 bool: `True` if the node can be quantized, otherwise `False`.
125-
126122 """
127123 # Check output
128124 if quant_properties .quant_output is not None :
@@ -182,7 +178,6 @@ def _get_node_target(module: torch.nn.Module | torch.fx.GraphModule, target_str:
182178
183179 Returns:
184180 Any: Resolved attribute on the module.
185-
186181 """
187182 targets = target_str .split ("." )
188183 for target in targets [:- 1 ]:
@@ -195,7 +190,6 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
195190
196191 Large scalars are skipped because ``torch.histc`` supports values only up
197192 to a certain upper bound.
198-
199193 """
200194 HISTC_UPPER_BOUND = 3.4028235e15
201195 if node .op == "get_attr" and isinstance (node .target , str ):
@@ -213,7 +207,8 @@ def _is_large_scalar(node: Node, gm: torch.fx.GraphModule):
213207
214208
215209def _is_non_float_tensor (node : Node ) -> bool :
216- """Check if the output of a node has a data type other than `torch.float32`.
210+ """Check if the output of a node has a data type other than
211+ `torch.float32`.
217212
218213 If the output is not `torch.float32`, quantization cannot be performed, as
219214 observers only work with floating-point tensors.
@@ -230,7 +225,6 @@ def _is_non_float_tensor(node: Node) -> bool:
230225 `torch.float32` as its data type.
231226 - If node.meta["val"] is missing or is not an instance of `FakeTensor`,
232227 the function returns True.
233-
234228 """
235229 if "val" in node .meta and isinstance (node .meta ["val" ], Sequence ):
236230 return any (
@@ -258,7 +252,6 @@ def _annotate_input(node: Node, quant_property: _QuantProperty):
258252 Raises:
259253 RuntimeError: If the node is already annotated.
260254 TypeError: If an input argument is not a ``Node`` instance.
261-
262255 """
263256 if is_annotated (node ):
264257 raise RuntimeError (
@@ -295,7 +288,6 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
295288 RuntimeError: If the node is already annotated.
296289 ValueError: If ``mark_annotated`` is True, ``optional`` is True, or
297290 ``index`` is not zero.
298-
299291 """
300292 if is_annotated (node ):
301293 raise RuntimeError (
@@ -322,7 +314,6 @@ def _match_pattern(
322314 ``pattern``. If ``filter_fn`` is provided, require all nodes in the chain
323315 to pass the filter. Each pattern element is a list of disjunctive node
324316 targets.
325-
326317 """
327318 if len (pattern ) < 1 :
328319 raise ValueError ("No pattern provided" )
@@ -408,6 +399,7 @@ def _match_pattern(
408399 torch .ops .aten .squeeze_copy .default ,
409400 torch .ops .aten .squeeze_copy .dim ,
410401 torch .ops .aten .squeeze_ .dim ,
402+ torch .ops .aten .squeeze_copy .dims ,
411403 torch .ops .aten .squeeze .dim ,
412404 torch .ops .aten .squeeze .dims ,
413405 torch .ops .aten .unbind .int ,
@@ -503,7 +495,6 @@ def get_quant_properties( # noqa: C901
503495 Returns:
504496 _OpQuantProperties | None: Properties to apply, or ``None`` if the
505497 node is unsupported or not suitable for quantization.
506-
507498 """
508499 if node .target == torch .ops .aten .conv_transpose2d .input :
509500 weight_qspec = _adjust_weight_qspec_for_conv_transpose (
@@ -820,7 +811,6 @@ def annotate_graph( # type: ignore[return]
820811
821812 Returns:
822813 Optional[List[List[Node]]]: Reserved for future use; currently None.
823-
824814 """
825815 for node in gm .graph .nodes :
826816 if node .op != "call_function" :
0 commit comments