ShardingOptimizer.save()/load() resurrects constraints removed via remove_constraints() #476

@fmassa

Summary

remove_constraints() updates the live optimizer state, but that removal is not faithfully preserved across save() / load().

Today, save() serializes constraint_log, and load() reconstructs constraints by replaying that append-only log. Since removals are not recorded in the serialized state, any constraints that were removed before saving are re-applied after loading.
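The mechanism can be reproduced outside AutoParallel with a minimal, self-contained sketch of an append-only log that is replayed on load. LogReplayOptimizer, add_constraint() and the string specs are hypothetical stand-ins, not the real ShardingOptimizer API. Constraints are modeled as a name → spec dict, mirroring how remove_constraints() takes constraint names.

```python
import json
import os
import tempfile


class LogReplayOptimizer:
    """Hypothetical stand-in for the current design: the live constraints are
    a name -> spec dict, but save() writes only the append-only add log."""

    def __init__(self):
        self.constraints = {}     # live state, keyed by constraint name
        self.constraint_log = []  # append-only record of add calls

    def add_constraint(self, name, spec):
        self.constraints[name] = spec
        self.constraint_log.append({"name": name, "spec": spec})
        return [name]

    def remove_constraints(self, names):
        # Mutates live state only; nothing is appended to constraint_log.
        for name in names:
            self.constraints.pop(name, None)

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.constraint_log, f)

    @classmethod
    def load(cls, path):
        opt = cls()
        with open(path) as f:
            for entry in json.load(f):
                # Replaying every logged add re-creates removed constraints too.
                opt.add_constraint(entry["name"], entry["spec"])
        return opt


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.ap")
    opt = LogReplayOptimizer()
    names = opt.add_constraint("node_constraint_0", "Shard(0)")
    opt.remove_constraints(names)
    opt.save(path)
    loaded = LogReplayOptimizer.load(path)

print(sorted(opt.constraints))     # []
print(sorted(loaded.constraints))  # ['node_constraint_0']
```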

This appears to be a pre-existing limitation of the save/load design for interactive constraint editing, rather than something introduced by a single recent change.

Repro

Case 1: removed node constraint comes back after load

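```python
# AutoParallel, ShardingOptimizer and the setup objects (model, input_fn, mesh,
# x_sharding, out_sharding, node, target_placement) are assumed to be in scope.
with AutoParallel(model, input_fn, mesh) as autop:
    autop.add_input_constraints([x_sharding])
    autop.add_output_constraints([out_sharding])
    opt = autop.sharding_optimizer

    base = opt.get_solution()  # solution before adding the node constraint
    names = opt.add_node_constraint(node, placement=target_placement)
    constrained = opt.resolve()

    opt.remove_constraints(names)
    reverted = opt.resolve()  # node constraint no longer applied

    opt.save("model.ap")

loaded = ShardingOptimizer.load("model.ap")
after_load = loaded.resolve()
# Expected: after_load is equivalent to reverted (the constraint stays removed).
# Actual: the removed node constraint is applied again.
```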

Case 2: removed memory constraint comes back after load

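```python
# AutoParallel, ShardingOptimizer and the setup objects (model, input_fn, mesh,
# x_sharding, out_sharding) are assumed to be in scope.
with AutoParallel(model, input_fn, mesh) as autop:
    autop.add_input_constraints([x_sharding])
    autop.add_output_constraints([out_sharding])
    autop.add_parameter_memory_constraint(low=None, high=None)

    constrained = autop.optimize_placement()

    opt = autop.sharding_optimizer
    opt.remove_constraints(["memory_constraint_high", "memory_constraint_low"])
    reverted = opt.resolve()  # memory constraints no longer applied

    opt.save("model.ap")

loaded = ShardingOptimizer.load("model.ap")
after_load = loaded.resolve()
# Expected: after_load is equivalent to reverted (the constraints stay removed).
# Actual: the removed memory constraints are rebuilt.
```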

Actual behavior

After load(), the optimizer behaves as if the removed constraints were never removed:

  • removed node constraints are applied again
  • removed memory constraints are rebuilt again

This happens because load() replays the original constraint_log, which still contains the original add_*constraint calls but has no record of the later remove_constraints() calls.

Expected behavior

A saved optimizer should preserve the current active constraint state, not just the historical sequence of added constraints.

If a constraint has been removed before save(), then after load() + resolve() it should remain removed.
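The expected behavior can be phrased as a round-trip check: after save() + load(), the active constraints must equal those of the in-memory optimizer. The sketch below pairs such a check with a hypothetical SnapshotOptimizer that serializes the active constraint set directly (one of the directions listed under "Possible directions"). All names here are illustrative assumptions, not the real ShardingOptimizer API.

```python
import json
import os
import tempfile


class SnapshotOptimizer:
    """Hypothetical optimizer that saves the active constraint set itself,
    so removals are reflected in the saved state automatically."""

    def __init__(self):
        self.constraints = {}  # name -> constraint spec

    def add_constraint(self, name, spec):
        self.constraints[name] = spec
        return [name]

    def remove_constraints(self, names):
        for name in names:
            self.constraints.pop(name, None)

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.constraints, f)

    @classmethod
    def load(cls, path):
        opt = cls()
        with open(path) as f:
            opt.constraints = json.load(f)
        return opt


def assert_roundtrip_preserves_active(opt, path):
    """save() + load() must reproduce the currently active constraints,
    including any removals made before saving."""
    opt.save(path)
    loaded = type(opt).load(path)
    assert loaded.constraints == opt.constraints, (
        sorted(opt.constraints), sorted(loaded.constraints))
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    opt = SnapshotOptimizer()
    opt.add_constraint("memory_constraint_low", 0.0)
    opt.add_constraint("memory_constraint_high", 1.0)
    opt.remove_constraints(["memory_constraint_high", "memory_constraint_low"])
    loaded = assert_roundtrip_preserves_active(opt, os.path.join(tmp, "model.ap"))
```

Serializing the active set is the most direct way to satisfy this check, but it assumes each constraint can be written out as plain data. If the real constraints are solver objects (the root-cause notes mention prob.constraints), they would need their own serialization, which may be why a replayable log was used in the first place.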

Why this matters

This breaks offline / notebook-style "what-if" workflows that rely on:

  1. adding constraints
  2. re-solving
  3. removing constraints
  4. saving the resulting optimizer state for later exploration

After reloading, the optimizer can produce results that differ from those of the in-memory optimizer at the time it was saved.

This is especially confusing because remove_constraints() works as expected in-process, but its effect does not survive serialization.

Likely root cause

  • remove_constraints() mutates live optimizer state (prob.constraints, and possibly related helper state)
  • save() serializes constraint_log
  • load() reconstructs constraints by replaying constraint_log
  • removals are not represented in the serialized data model

Possible directions

Any of these could work:

  • record removals in the serialized constraint history
  • serialize the active constraint set directly instead of replaying an append-only log
  • serialize a "removed constraints" tombstone set and apply it after replay
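The first direction, recording removals in the serialized history, can be sketched by logging both operations as events and replaying them in order. EventLogOptimizer and its event format are hypothetical and only illustrate the idea.

```python
import json
import os
import tempfile


class EventLogOptimizer:
    """Hypothetical sketch: removals go into the same log as additions,
    so replaying the log reproduces the active state."""

    def __init__(self):
        self.constraints = {}     # name -> constraint spec
        self.constraint_log = []  # ordered add/remove events

    def add_constraint(self, name, spec):
        self.constraints[name] = spec
        self.constraint_log.append({"op": "add", "name": name, "spec": spec})
        return [name]

    def remove_constraints(self, names):
        for name in names:
            self.constraints.pop(name, None)
        self.constraint_log.append({"op": "remove", "names": list(names)})

    def save(self, path):
        with open(path, "w") as f:
            json.dump(self.constraint_log, f)

    @classmethod
    def load(cls, path):
        opt = cls()
        with open(path) as f:
            events = json.load(f)
        for event in events:  # replay in order, removals included
            if event["op"] == "add":
                opt.add_constraint(event["name"], event["spec"])
            elif event["op"] == "remove":
                opt.remove_constraints(event["names"])
            else:
                raise ValueError(f"unknown log op: {event['op']!r}")
        return opt


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.ap")
    opt = EventLogOptimizer()
    names = opt.add_constraint("node_constraint_0", "Shard(0)")
    opt.remove_constraints(names)
    opt.add_constraint("memory_constraint_high", 1.0)
    opt.save(path)
    loaded = EventLogOptimizer.load(path)
```

Because load() goes through the same add/remove methods, the reloaded optimizer ends up with the same log as the original, so saving a loaded optimizer round-trips again.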

The key requirement is that save() / load() preserve the optimizer's effective active constraints, not just the original additions.
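The third direction, a tombstone set applied after replay, keeps the existing add-only log and stores removed names alongside it. TombstoneOptimizer below is a hypothetical sketch of that approach.

```python
import json
import os
import tempfile


class TombstoneOptimizer:
    """Hypothetical sketch: the add-only log is kept as-is, and a set of
    removed names (tombstones) is saved and applied after replay."""

    def __init__(self):
        self.constraints = {}     # name -> constraint spec
        self.constraint_log = []  # unchanged append-only add log
        self.removed = set()      # tombstones for removed constraint names

    def add_constraint(self, name, spec):
        self.constraints[name] = spec
        self.constraint_log.append({"name": name, "spec": spec})
        self.removed.discard(name)  # re-adding a name clears its tombstone
        return [name]

    def remove_constraints(self, names):
        for name in names:
            self.constraints.pop(name, None)
            self.removed.add(name)

    def save(self, path):
        state = {"constraint_log": self.constraint_log,
                 "removed": sorted(self.removed)}
        with open(path, "w") as f:
            json.dump(state, f)

    @classmethod
    def load(cls, path):
        opt = cls()
        with open(path) as f:
            state = json.load(f)
        for entry in state["constraint_log"]:
            opt.add_constraint(entry["name"], entry["spec"])
        opt.remove_constraints(state["removed"])  # apply tombstones after replay
        return opt


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.ap")
    opt = TombstoneOptimizer()
    opt.add_constraint("node_constraint_0", "Shard(0)")
    opt.add_constraint("memory_constraint_high", 1.0)
    opt.remove_constraints(["node_constraint_0"])
    opt.add_constraint("node_constraint_0", "Replicate()")  # re-added later
    opt.remove_constraints(["memory_constraint_high"])
    opt.save(path)
    loaded = TombstoneOptimizer.load(path)
```

Clearing the tombstone in add_constraint() matters: without it, a constraint that was removed and later re-added under the same name would be dropped again on load.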
