|
22 | 22 |
|
23 | 23 | logger: logging.Logger = logging.getLogger(__name__) |
24 | 24 | QuantArgs = tuple[float, int, int, int, torch.dtype] |
| 25 | +TRANSPARENT_OPS: frozenset[torch._ops.OpOverloadPacket] = frozenset( |
| 26 | + { |
| 27 | + torch.ops.aten.view, |
| 28 | + torch.ops.aten.view_copy, |
| 29 | + torch.ops.aten._unsafe_view, |
| 30 | + torch.ops.aten.reshape, |
| 31 | + torch.ops.aten.permute, |
| 32 | + torch.ops.aten.permute_copy, |
| 33 | + torch.ops.aten.transpose, |
| 34 | + torch.ops.aten.transpose_copy, |
| 35 | + torch.ops.aten.squeeze, |
| 36 | + torch.ops.aten.squeeze_copy, |
| 37 | + torch.ops.aten.unsqueeze, |
| 38 | + torch.ops.aten.unsqueeze_copy, |
| 39 | + torch.ops.aten.slice, |
| 40 | + torch.ops.aten.slice_copy, |
| 41 | + torch.ops.aten.contiguous, |
| 42 | + torch.ops.aten.clone, |
| 43 | + torch.ops.aten.to, |
| 44 | + torch.ops.aten._to_copy, |
| 45 | + } |
| 46 | +) |
25 | 47 |
|
26 | 48 |
|
27 | 49 | @torch.no_grad() |
@@ -244,36 +266,51 @@ def extract_input_quant_params_from_graph( |
244 | 266 | ) -> dict[int, QuantArgs]: |
245 | 267 | """ |
246 | 268 | Extract quantization parameters from the FX graph for model inputs. |
| 269 | +
|
| 270 | + For each name in ``input_names``, walk forward from the matching input |
| 271 | + node through value-preserving "transparent" ops (reshape, permute, ...) |
| 272 | + until reaching the ``quantize_per_tensor`` that fixes that input's scale |
| 273 | + and zero-point. Results are keyed by the index into ``input_names``. |
247 | 274 | """ |
248 | 275 | quant_args: dict[int, QuantArgs] = {} |
249 | 276 | found_names: set[str] = set() |
250 | 277 |
|
251 | 278 | if not input_names: |
252 | 279 | return quant_args |
253 | 280 |
|
| 281 | + # Inputs are referenced by node name, which may be a placeholder or a node |
| 282 | + # that unpacks/derives the input (e.g. a `getitem` off a tuple/multi-output |
| 283 | + # input, as the modai eye-tracking model does), so look the start node up |
| 284 | + # across all nodes -- not just placeholders. Build the name->node map once |
| 285 | + # and reuse it for every requested input. |
| 286 | + nodes_by_name = {n.name: n for n in module.graph.nodes} |
| 287 | + |
| 288 | + quantize_ops = _get_quantize_ops() |
254 | 289 | for idx, name in enumerate(input_names): |
255 | | - for node in module.graph.nodes: |
256 | | - if node.op != "call_function": |
| 290 | + start = nodes_by_name.get(name) |
| 291 | + if start is None: |
| 292 | + continue |
| 293 | + seen: set[torch.fx.Node] = set() |
| 294 | + to_visit: list[torch.fx.Node] = list(start.users) |
| 295 | + while to_visit: |
| 296 | + node = to_visit.pop() |
| 297 | + if node in seen or node.op != "call_function": |
257 | 298 | continue |
258 | | - |
259 | | - if ( |
260 | | - node.args |
261 | | - and isinstance(node.args[0], torch.fx.Node) |
262 | | - and node.args[0].name == name |
263 | | - and not node.name.startswith("_assert_tensor_metadata") |
264 | | - and "quantize_per_tensor" in str(node.target) |
265 | | - ): |
266 | | - args = node.args[1:] |
267 | | - if len(args) >= 5: |
268 | | - quant_args[idx] = ( |
269 | | - float(args[0]), # scale |
270 | | - int(args[1]), # zero_point |
271 | | - int(args[2]), # qmin |
272 | | - int(args[3]), # qmax |
273 | | - args[4], # dtype |
274 | | - ) |
275 | | - found_names.add(name) |
| 299 | + seen.add(node) |
| 300 | + if node.target in quantize_ops: |
| 301 | + # Normalize args→kwargs so params passed positionally or as |
| 302 | + # kwargs (or via defaults) are all handled uniformly. |
| 303 | + quant_args[idx] = ( |
| 304 | + float(get_arg(node, "scale", float)), |
| 305 | + int(get_arg(node, "zero_point", int)), |
| 306 | + int(get_arg(node, "quant_min", int)), |
| 307 | + int(get_arg(node, "quant_max", int)), |
| 308 | + get_arg(node, "dtype", torch.dtype), |
| 309 | + ) |
| 310 | + found_names.add(name) |
276 | 311 | break |
| 312 | + if getattr(node.target, "overloadpacket", None) in TRANSPARENT_OPS: |
| 313 | + to_visit.extend(node.users) |
277 | 314 |
|
278 | 315 | missing_names = set(input_names) - found_names |
279 | 316 | if missing_names: |
|
0 commit comments