Skip to content

Commit 27cc7e8

Browse files
authored
Walk transparent ops when extracting input quant params (#20139)
Differential Revision: D107922730. Pull Request resolved: #20139.
1 parent bd7426d commit 27cc7e8

1 file changed

Lines changed: 57 additions & 20 deletions

File tree

backends/cadence/aot/compiler_funcs.py

Lines changed: 57 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,28 @@
2222

2323
logger: logging.Logger = logging.getLogger(__name__)
2424
QuantArgs = tuple[float, int, int, int, torch.dtype]
25+
TRANSPARENT_OPS: frozenset[torch._ops.OpOverloadPacket] = frozenset(
26+
{
27+
torch.ops.aten.view,
28+
torch.ops.aten.view_copy,
29+
torch.ops.aten._unsafe_view,
30+
torch.ops.aten.reshape,
31+
torch.ops.aten.permute,
32+
torch.ops.aten.permute_copy,
33+
torch.ops.aten.transpose,
34+
torch.ops.aten.transpose_copy,
35+
torch.ops.aten.squeeze,
36+
torch.ops.aten.squeeze_copy,
37+
torch.ops.aten.unsqueeze,
38+
torch.ops.aten.unsqueeze_copy,
39+
torch.ops.aten.slice,
40+
torch.ops.aten.slice_copy,
41+
torch.ops.aten.contiguous,
42+
torch.ops.aten.clone,
43+
torch.ops.aten.to,
44+
torch.ops.aten._to_copy,
45+
}
46+
)
2547

2648

2749
@torch.no_grad()
@@ -244,36 +266,51 @@ def extract_input_quant_params_from_graph(
244266
) -> dict[int, QuantArgs]:
245267
"""
246268
Extract quantization parameters from the FX graph for model inputs.
269+
270+
For each name in ``input_names``, walk forward from the matching input
271+
node through value-preserving "transparent" ops (reshape, permute, ...)
272+
until reaching the ``quantize_per_tensor`` that fixes that input's scale
273+
and zero-point. Results are keyed by the index into ``input_names``.
247274
"""
248275
quant_args: dict[int, QuantArgs] = {}
249276
found_names: set[str] = set()
250277

251278
if not input_names:
252279
return quant_args
253280

281+
# Inputs are referenced by node name, which may be a placeholder or a node
282+
# that unpacks/derives the input (e.g. a `getitem` off a tuple/multi-output
283+
# input, as the modai eye-tracking model does), so look the start node up
284+
# across all nodes -- not just placeholders. Build the name->node map once
285+
# and reuse it for every requested input.
286+
nodes_by_name = {n.name: n for n in module.graph.nodes}
287+
288+
quantize_ops = _get_quantize_ops()
254289
for idx, name in enumerate(input_names):
255-
for node in module.graph.nodes:
256-
if node.op != "call_function":
290+
start = nodes_by_name.get(name)
291+
if start is None:
292+
continue
293+
seen: set[torch.fx.Node] = set()
294+
to_visit: list[torch.fx.Node] = list(start.users)
295+
while to_visit:
296+
node = to_visit.pop()
297+
if node in seen or node.op != "call_function":
257298
continue
258-
259-
if (
260-
node.args
261-
and isinstance(node.args[0], torch.fx.Node)
262-
and node.args[0].name == name
263-
and not node.name.startswith("_assert_tensor_metadata")
264-
and "quantize_per_tensor" in str(node.target)
265-
):
266-
args = node.args[1:]
267-
if len(args) >= 5:
268-
quant_args[idx] = (
269-
float(args[0]), # scale
270-
int(args[1]), # zero_point
271-
int(args[2]), # qmin
272-
int(args[3]), # qmax
273-
args[4], # dtype
274-
)
275-
found_names.add(name)
299+
seen.add(node)
300+
if node.target in quantize_ops:
301+
# Normalize args→kwargs so params passed positionally or as
302+
# kwargs (or via defaults) are all handled uniformly.
303+
quant_args[idx] = (
304+
float(get_arg(node, "scale", float)),
305+
int(get_arg(node, "zero_point", int)),
306+
int(get_arg(node, "quant_min", int)),
307+
int(get_arg(node, "quant_max", int)),
308+
get_arg(node, "dtype", torch.dtype),
309+
)
310+
found_names.add(name)
276311
break
312+
if getattr(node.target, "overloadpacket", None) in TRANSPARENT_OPS:
313+
to_visit.extend(node.users)
277314

278315
missing_names = set(input_names) - found_names
279316
if missing_names:

0 commit comments

Comments
 (0)