Unifying operator.getitem with the standard code path #388

@fmassa

Background

After _op_rules and _op_partial_rules were unified into a single registration mechanism (#386), operator.getitem remained the only op with special-case handling, in three places:

  1. placement_options.py: propagate_tensor_meta is skipped for getitem
  2. apply_sharding.py: _call_getitem handles redistribution instead of _redistribute_and_adjust_args
  3. optimize_sharding.py: _compute_edge_costs has a getitem-specific branch to extract src_spec[node.args[1]] from a tuple

The root cause is that getitem_rule sets input_specs=(output_specs,), a single spec representing the selected element, rather than the full tuple of upstream specs. The actual input is a tuple of tensors, and this mismatch forces special handling everywhere.

What was attempted

Step 1: Fix getitem_rule to set proper input_specs

Changed getitem_rule to set input_specs to the full tuple of upstream specs (one per tuple element), with redistribute costs expanded to N rows (one per element, dummy costs for all except the selected index):

# Before
s = OpSpec(output_specs, input_specs=(output_specs,))
s.redistribute_cost = [costs_for_selected_element]

# After
s = OpSpec(output_specs, input_specs=input_specs,  # full tuple
           redistribute_cost=redistribute_costs)    # N rows
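
The N-row cost expansion described above can be sketched as follows. This is a minimal illustration, not the code from the attempt: the helper name expand_getitem_costs is hypothetical, and using 0.0 as the dummy cost is an assumption, since the issue does not say which value the dummy rows carry.

```python
from typing import List, Sequence


def expand_getitem_costs(
    selected_costs: Sequence[float],
    num_elements: int,
    selected_index: int,
    dummy_cost: float = 0.0,  # assumption: the issue does not state the dummy value
) -> List[List[float]]:
    """Return one cost row per tuple element; only the selected row is real."""
    if not 0 <= selected_index < num_elements:
        raise IndexError("selected_index out of range")
    rows = []
    for i in range(num_elements):
        if i == selected_index:
            rows.append(list(selected_costs))
        else:
            # Placeholder row for an element that getitem does not read.
            rows.append([dummy_cost] * len(selected_costs))
    return rows


rows = expand_getitem_costs([1.5, 0.0, 3.0], num_elements=3, selected_index=1)
```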

Step 2: Remove operator.getitem skip in propagate_tensor_meta

With N input_specs matching the N tensor elements in user_args, propagate_tensor_meta's assertion len(tensor_metas) == len(input_specs) would pass. This required filtering out None input_specs (which stand for non-tensor tuple elements) so that the remaining specs line up with tensor_metas, which contains only actual tensors.
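
A rough sketch of that filtering step, using plain strings as stand-ins for specs and tensor metadata. The function name align_specs_with_metas is hypothetical and does not appear in placement_options.py.

```python
def align_specs_with_metas(input_specs, tensor_metas):
    """Drop None specs (non-tensor elements) and pair the rest with tensor_metas."""
    tensor_specs = [spec for spec in input_specs if spec is not None]
    # Mirrors the length check that propagate_tensor_meta is described as making.
    assert len(tensor_metas) == len(tensor_specs)
    return list(zip(tensor_specs, tensor_metas))


# Stand-in values: the middle tuple element is not a tensor, so its spec is None.
pairs = align_specs_with_metas(
    input_specs=("Shard(0)", None, "Replicate"),
    tensor_metas=["meta_a", "meta_b"],
)
```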

Step 3: Remove _call_getitem in apply_sharding.py

Updated _redistribute_and_adjust_args to expand tuple output_specs into individual specs when building curr_specs, with a parallel input_node_per_spec list to track which FX node each spec came from.
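
The expansion could look roughly like this. It is a sketch under assumptions: nodes and specs are represented by strings, and expand_specs and output_specs_by_node are hypothetical names chosen for the example.

```python
def expand_specs(input_nodes, output_specs_by_node):
    """Flatten tuple output specs, remembering the producer node of each spec."""
    curr_specs = []
    input_node_per_spec = []
    for node in input_nodes:
        spec = output_specs_by_node[node]
        # A tuple-producing node contributes one entry per tuple element.
        elements = spec if isinstance(spec, tuple) else (spec,)
        for element in elements:
            curr_specs.append(element)
            input_node_per_spec.append(node)
    return curr_specs, input_node_per_spec


curr_specs, input_node_per_spec = expand_specs(
    ["split", "bias"],
    {"split": ("Shard(0)", "Shard(1)"), "bias": "Replicate"},
)
```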

Step 4: Update optimize_sharding.py

  • _compute_edge_costs: Replaced the operator.getitem branch with a generic isinstance(src_spec, tuple) check using argi to index into tuple output_specs
  • _build_decision_vars: Same expansion of all_input_nodes into input_nodes_per_spec for tuple output_specs
  • validate: Same expansion logic
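
For the _compute_edge_costs change, the generic check amounts to something like the following. The helper name select_src_spec is hypothetical, and strings again stand in for real specs.

```python
def select_src_spec(src_spec, argi):
    """Pick the producer spec for argument argi without naming operator.getitem."""
    if isinstance(src_spec, tuple):
        # Tuple-producing node: argi indexes into the per-element specs.
        return src_spec[argi]
    return src_spec


from_tuple = select_src_spec(("Shard(0)", "Replicate"), argi=1)
from_single = select_src_spec("Shard(1)", argi=0)
```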

Where it broke: the ILP solver model

The solver creates one binary decision variable per (node_idx, argi, out_idx, inp_idx) tuple. With 1 input_spec (old model), getitem has a single argi=0 dimension, producing M variables (one per strategy pair). With N input_specs (new model), getitem has N argi values, producing N×M variables.
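
To make the counting concrete, here is a small enumeration of the variable keys. It assumes M is the number of (out_idx, inp_idx) pairs, which is how "one per strategy pair" reads; decision_var_keys is a hypothetical helper, not a solver function.

```python
from itertools import product


def decision_var_keys(node_idx, num_args, num_out, num_inp):
    """List the (node_idx, argi, out_idx, inp_idx) keys for one node."""
    return [
        (node_idx, argi, out_idx, inp_idx)
        for argi, out_idx, inp_idx in product(
            range(num_args), range(num_out), range(num_inp)
        )
    ]


old_model = decision_var_keys(node_idx=7, num_args=1, num_out=2, num_inp=3)
new_model = decision_var_keys(node_idx=7, num_args=4, num_out=2, num_inp=3)
print(len(old_model), len(new_model))  # M = 6 and N*M = 24
```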

The uniqueness constraint requires each argi to independently select exactly one strategy:

∀i,a: Σ_{o,j} x_{i,a,o,j} = 1

With N args, this means N independent selections. The consistency constraint forces all args to agree on the same output placement, but they can still select different inp_idx values. The net result: the solver selects multiple strategies for a single getitem node, violating the solution validation assertion.
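
The failure mode can be reproduced with a toy brute-force enumeration. This is not the actual ILP: it assumes only what the text above states, namely that each argi picks exactly one (out_idx, inp_idx) pair and that consistency ties the out_idx values together but not the inp_idx values.

```python
from itertools import product


def feasible_assignments(num_args, num_out, num_inp):
    """Yield per-arg (out_idx, inp_idx) choices that satisfy both constraints."""
    pairs = list(product(range(num_out), range(num_inp)))
    # Uniqueness: each argi selects exactly one pair.
    for choice in product(pairs, repeat=num_args):
        # Consistency: all args agree on the output placement.
        if len({out_idx for out_idx, _ in choice}) == 1:
            yield choice


all_feasible = list(feasible_assignments(num_args=2, num_out=2, num_inp=2))
# Feasible solutions in which the args disagree on inp_idx.
mixed = [c for c in all_feasible if len({inp_idx for _, inp_idx in c}) > 1]
print(len(all_feasible), len(mixed))
```

In this toy setting half of the feasible assignments pick different inp_idx values for the two args, which is the kind of solution the validation step rejects.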

The fundamental issue is that the solver's ILP model assumes each argi corresponds to a distinct input FX node. For getitem, all N args come from the same upstream node — they're not independent inputs but elements of a single tuple output.

What would need to change to complete this

1. Solver variable model

The solver needs to understand that for getitem, all argi values reference the same producer node and should be coupled. Options:

  • Collapse the N args back to 1 in the solver: Add a mapping layer that translates between the N-input-spec model (used by rules and propagate_tensor_meta) and the 1-arg model (used by the solver). The solver would see getitem as having 1 arg with 1 cost row (the selected element's costs), while the rule and post-processing see N input_specs.

  • Teach the solver about tuple outputs: Add a constraint type that links all argi values for the same upstream tuple node, forcing them to select the same inp_idx. This is the more principled approach but requires new constraint logic.
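
Both options can be sketched in a few lines. Neither function exists in the codebase; collapse_for_solver and satisfies_link are hypothetical names, and the second one only checks the property that a real linking constraint would have to enforce.

```python
def collapse_for_solver(input_specs, redistribute_costs, selected_index):
    """Option 1: expose only the selected element (1 arg, 1 cost row) to the solver."""
    return (input_specs[selected_index],), [redistribute_costs[selected_index]]


def satisfies_link(choice):
    """Option 2: every argi of one tuple producer picks the same inp_idx."""
    return len({inp_idx for _out_idx, inp_idx in choice}) == 1


solver_specs, solver_costs = collapse_for_solver(
    input_specs=("Shard(0)", "Replicate", "Shard(1)"),
    redistribute_costs=[[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]],
    selected_index=1,
)
coupled = satisfies_link(((0, 1), (0, 1), (0, 1)))
uncoupled = satisfies_link(((0, 0), (0, 1), (0, 1)))
```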

2. Flow constraints

add_output_input_consistent_constraint links producer output placements to consumer input placements. With N args from the same producer, each arg would generate a flow constraint. These would need to be consistent — all N args must agree on the same producer strategy since they come from the same node.

3. Cost accounting

With N redistribute cost rows, the compute cost is divided by N (per_arg_compute = compute_cost / num_args). This is wrong for getitem since only one element is actually used, so the compute cost should not be diluted across N dummy args.
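
A small arithmetic illustration of the dilution, with made-up numbers. The division is the per_arg_compute formula quoted above; charging the whole cost to the selected element is one possible fix and is an assumption here, not something the issue prescribes.

```python
def per_arg_compute(compute_cost, num_args):
    """Split the compute cost evenly across args, as in the formula quoted above."""
    return compute_cost / num_args


def getitem_compute_share(compute_cost, num_args, argi, selected_index):
    """Assumed alternative: only the selected element carries the compute cost."""
    return compute_cost if argi == selected_index else 0.0


diluted = per_arg_compute(compute_cost=8.0, num_args=4)
undiluted = [
    getitem_compute_share(8.0, num_args=4, argi=a, selected_index=2)
    for a in range(4)
]
```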

4. Apply sharding

The _redistribute_and_adjust_args changes (expanding tuple output_specs) worked correctly in isolation. These changes would be valid once the solver produces correct solutions.

Current state

operator.getitem remains special-cased in three places. The special-case handling is correct and well-understood. Removing it requires reworking the solver's ILP variable model, which is a substantially larger change than the rule unification.

Labels: enhancement (New feature or request)