Background
After unifying _op_rules and _op_partial_rules into a single registration mechanism (#386), operator.getitem remained the only op with special-case handling, in three places:
placement_options.py: propagate_tensor_meta is skipped for getitem
apply_sharding.py: _call_getitem handles redistribution instead of _redistribute_and_adjust_args
optimize_sharding.py: _compute_edge_costs has a getitem-specific branch to extract src_spec[node.args[1]] from a tuple
The root cause is that getitem_rule sets input_specs=(output_specs,) — a single spec representing the selected element — rather than the full tuple of upstream specs. This mismatch with the actual input structure (a tuple of tensors) forces special handling everywhere.
What was attempted
Step 1: Fix getitem_rule to set proper input_specs
Changed getitem_rule to set input_specs to the full tuple of upstream specs (one per tuple element), with redistribute costs expanded to N rows (one per element, dummy costs for all except the selected index):
# Before
s = OpSpec(output_specs, input_specs=(output_specs,))
s.redistribute_cost = [costs_for_selected_element]
# After
s = OpSpec(
    output_specs,
    input_specs=input_specs,  # full tuple
    redistribute_cost=redistribute_costs,  # N rows
)
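A minimal sketch of how the expanded inputs could be assembled. The helper name build_getitem_inputs, the string placeholders standing in for specs, and the choice of 0.0 as the dummy cost are all assumptions made for illustration; the actual rule may build these differently.

```python
from typing import List, Optional, Sequence, Tuple


def build_getitem_inputs(
    upstream_specs: Sequence[Optional[str]],
    index: int,
    selected_costs: List[float],
) -> Tuple[Tuple[Optional[str], ...], List[List[float]]]:
    """Return (input_specs, redistribute_costs) with one entry per tuple element.

    Hypothetical helper: specs are plain strings here, and non-selected
    elements get a dummy row of zeros with the same width as the real row.
    """
    input_specs = tuple(upstream_specs)
    redistribute_costs = [
        list(selected_costs) if i == index else [0.0] * len(selected_costs)
        for i in range(len(input_specs))
    ]
    return input_specs, redistribute_costs


# A 3-element upstream tuple where element 1 is selected and
# element 2 is a non-tensor (None).
specs, costs = build_getitem_inputs(
    ["S(0)", "R", None], index=1, selected_costs=[0.0, 4.0]
)
```

Only the row at the selected index carries real costs; the other rows exist purely so that the number of cost rows matches the number of input_specs.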
Step 2: Remove operator.getitem skip in propagate_tensor_meta
With N input_specs matching the N tensor elements in user_args, propagate_tensor_meta's assertion len(tensor_metas) == len(input_specs) would pass. This required filtering out None input_specs (non-tensor tuple elements) so that the remaining specs line up with tensor_metas, which only contains actual tensors.
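The filtering step can be sketched as follows. This is an illustration under assumptions, not the code in placement_options.py: the helper name tensor_input_specs and the string stand-ins for specs and tensor metadata are hypothetical.

```python
from typing import List, Optional, Sequence


def tensor_input_specs(input_specs: Sequence[Optional[str]]) -> List[str]:
    """Drop None entries, which stand for non-tensor tuple elements."""
    return [spec for spec in input_specs if spec is not None]


# Upstream tuple with a non-tensor element in the middle.
input_specs = ("S(0)", None, "R")
# Metadata exists only for the two real tensors.
tensor_metas = ["meta_for_elem0", "meta_for_elem2"]

filtered_specs = tensor_input_specs(input_specs)
# The length check now compares like with like.
lengths_match = len(tensor_metas) == len(filtered_specs)
```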
Step 3: Remove _call_getitem in apply_sharding.py
Updated _redistribute_and_adjust_args to expand tuple output_specs into individual specs when building curr_specs, with a parallel input_node_per_spec list to track which FX node each spec came from.
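The expansion described above might look roughly like this. It is a sketch under assumptions: nodes and specs are represented as strings, and expand_specs and the output_spec_of mapping are hypothetical names that do not appear in apply_sharding.py.

```python
from typing import Dict, List, Sequence, Tuple, Union

Spec = str
SpecOrTuple = Union[Spec, Tuple[Spec, ...]]


def expand_specs(
    input_nodes: Sequence[str],
    output_spec_of: Dict[str, SpecOrTuple],
) -> Tuple[List[Spec], List[str]]:
    """Flatten tuple output specs, recording the producer node of each spec."""
    curr_specs: List[Spec] = []
    input_node_per_spec: List[str] = []
    for node in input_nodes:
        spec = output_spec_of[node]
        # A tuple-producing node contributes one spec per element,
        # all attributed to the same node.
        elements = spec if isinstance(spec, tuple) else (spec,)
        for element_spec in elements:
            curr_specs.append(element_spec)
            input_node_per_spec.append(node)
    return curr_specs, input_node_per_spec


curr_specs, input_node_per_spec = expand_specs(
    ["split", "x"], {"split": ("S(0)", "R"), "x": "R"}
)
```

The two lists stay index-aligned, so a spec at position k can always be traced back to the node that produced it.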
Step 4: Update optimize_sharding.py
_compute_edge_costs: Replaced the operator.getitem branch with a generic isinstance(src_spec, tuple) check using argi to index into tuple output_specs
_build_decision_vars: Same expansion of all_input_nodes into input_nodes_per_spec for tuple output_specs
validate: Same expansion logic
Where it broke: the ILP solver model
The solver creates one binary decision variable per (node_idx, argi, out_idx, inp_idx) tuple. With 1 input_spec (old model), getitem has one argi=0 dimension, producing M variables (one per strategy pair). With N input_specs (new model), getitem has N argi values, producing N×M variables.
The uniqueness constraint requires each argi to independently select exactly one strategy:
∀i,a: Σ_{o,j} x_{i,a,o,j} = 1
With N args, this means N independent selections. The consistency constraint forces all args to agree on the same output placement, but they can still select different inp_idx values. The net result: the solver selects multiple strategies for a single getitem node, violating the solution validation assertion.
The fundamental issue is that the solver's ILP model assumes each argi corresponds to a distinct input FX node. For getitem, all N args come from the same upstream node — they're not independent inputs but elements of a single tuple output.
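The failure mode can be reproduced by brute force on a toy instance. The sketch below is not the solver's code: it enumerates, for one getitem node with hypothetical sizes (2 args, 2 output strategies, 2 input strategies), every selection that satisfies per-arg uniqueness and output consistency, then counts how many of those still mix inp_idx values across args.

```python
from itertools import product
from typing import List, Tuple

Selection = Tuple[Tuple[int, int], ...]  # one (out_idx, inp_idx) pair per argi


def feasible_selections(num_args: int, num_out: int, num_inp: int) -> List[Selection]:
    """Selections allowed by uniqueness plus output consistency.

    Uniqueness: each argi picks exactly one (out_idx, inp_idx) pair.
    Consistency: all argi values share the same out_idx.
    Nothing here ties the inp_idx values together.
    """
    pairs = list(product(range(num_out), range(num_inp)))
    feasible = []
    for selection in product(pairs, repeat=num_args):
        if len({out_idx for out_idx, _ in selection}) == 1:
            feasible.append(selection)
    return feasible


selections = feasible_selections(num_args=2, num_out=2, num_inp=2)
# Selections where the args disagree on inp_idx: legal for the toy model,
# but they amount to several strategies for one getitem node.
mixed = [s for s in selections if len({inp_idx for _, inp_idx in s}) > 1]
```

In this toy instance half of the feasible selections are mixed, which is the kind of solution the validation assertion rejects.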
What would need to change to complete this
1. Solver variable model
The solver needs to understand that for getitem, all argi values reference the same producer node and should be coupled. Options:
- Collapse the N args back to 1 in the solver: Add a mapping layer that translates between the N-input-spec model (used by rules and propagate_tensor_meta) and the 1-arg model (used by the solver). The solver would see getitem as having 1 arg with 1 cost row (the selected element's costs), while the rule and post-processing see N input_specs.
- Teach the solver about tuple outputs: Add a constraint type that links all argi values for the same upstream tuple node, forcing them to select the same inp_idx. This is the more principled approach but requires new constraint logic.
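As a rough illustration of the first option, a mapping layer could look like the sketch below. Both function names are hypothetical and nothing like them exists in the code base yet; the sketch also assumes that the single solver choice should be replicated across all tuple elements when translating back.

```python
from typing import List


def to_solver_costs(
    redistribute_costs: List[List[float]], selected_index: int
) -> List[List[float]]:
    """Rule view (N rows) -> solver view (1 row): keep the selected element's row."""
    return [redistribute_costs[selected_index]]


def expand_solver_choice(chosen_inp_idx: int, num_elements: int) -> List[int]:
    """Solver view -> rule view: every element shares the producer's one choice."""
    return [chosen_inp_idx] * num_elements


# N = 3 rows from the rule; element 1 is the one getitem selects.
rule_costs = [[0.0, 0.0], [0.0, 4.0], [0.0, 0.0]]
solver_costs = to_solver_costs(rule_costs, selected_index=1)
per_element_choice = expand_solver_choice(
    chosen_inp_idx=1, num_elements=len(rule_costs)
)
```

With this shape the solver keeps its existing one-arg model for getitem, and the coupling between elements is enforced by construction rather than by a new constraint.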
2. Flow constraints
add_output_input_consistent_constraint links producer output placements to consumer input placements. With N args from the same producer, each arg would generate a flow constraint. These would need to be consistent — all N args must agree on the same producer strategy since they come from the same node.
3. Cost accounting
With N redistribute cost rows, the compute cost is divided by N (per_arg_compute = compute_cost / num_args). This is wrong for getitem since only one element is actually used; the compute cost should not be diluted across N dummy args.
4. Apply sharding
The _redistribute_and_adjust_args changes (expanding tuple output_specs) worked correctly in isolation. These changes would be valid once the solver produces correct solutions.
Current state
operator.getitem remains special-cased in three places. The special-case handling is correct and well-understood. Removing it requires reworking the solver's ILP variable model, which is a substantially larger change than the rule unification.