Summary
Currently, `TrainCheckpointCallback` only supports periodic checkpoint saving by training steps. There is no built-in mechanism to save the best model based on a validation metric (e.g., `val/loss`).
Motivation
When training MLIP models, users typically want to keep the best-performing checkpoint based on a validation metric, not just the most recent periodic save. This is a common feature in other training frameworks (PyTorch Lightning, Hugging Face Trainer, etc.).
Proposed Solution
Extend `TrainCheckpointCallback` with:
```python
TrainCheckpointCallback(
    checkpoint_every_n_steps=1000,  # periodic save (existing)
    monitor="val/loss",             # metric to track (new)
    mode="min",                     # "min" or "max" (new)
    save_top_k=3,                   # keep the top-K best checkpoints (new)
)
```
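- `monitor`: name of the validation metric to track (e.g., `val/loss`)
- `mode`: `"min"` if lower is better (e.g., loss), `"max"` if higher is better (e.g., accuracy)
- `save_top_k`: number of best checkpoints to keep; a best checkpoint is deleted once it drops out of the top K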
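To make the `mode` / `save_top_k` semantics concrete, here is a minimal sketch of the bookkeeping the callback would need. `TopKCheckpointTracker` and its `update` / `best` methods are hypothetical names, not part of the existing codebase; the sketch only decides which checkpoint paths to keep or delete and assumes the actual saving and file removal are done by the callback's existing checkpoint logic.

```python
from __future__ import annotations

import heapq


class TopKCheckpointTracker:
    """Hypothetical helper that tracks the K best checkpoints by a metric.

    mode="min" keeps the K smallest metric values (e.g. val/loss);
    mode="max" keeps the K largest (e.g. accuracy).
    """

    def __init__(self, mode: str = "min", save_top_k: int = 1) -> None:
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        if save_top_k < 1:
            raise ValueError(f"save_top_k must be >= 1, got {save_top_k}")
        self.mode = mode
        self.save_top_k = save_top_k
        # Min-heap of (score, step, path); the root is the WORST kept checkpoint.
        # score is -metric in "min" mode so a larger score is always better.
        self._heap: list[tuple[float, int, str]] = []

    def _score(self, metric: float) -> float:
        return -metric if self.mode == "min" else metric

    def update(self, metric: float, step: int, path: str) -> tuple[bool, list[str]]:
        """Return (should_save, paths_to_delete) for a new eval result."""
        entry = (self._score(metric), step, path)
        if len(self._heap) < self.save_top_k:
            # Fewer than K checkpoints kept so far: always save.
            heapq.heappush(self._heap, entry)
            return True, []
        if entry[0] <= self._heap[0][0]:
            # Not strictly better than the worst kept checkpoint: skip it.
            return False, []
        # Replace the worst kept checkpoint; the caller deletes its file.
        evicted = heapq.heapreplace(self._heap, entry)
        return True, [evicted[2]]

    def best(self) -> list[tuple[float, str]]:
        """Kept checkpoints as (metric, path), best first."""
        ranked = sorted(self._heap, reverse=True)
        return [(-s if self.mode == "min" else s, p) for s, _, p in ranked]
```

Keying the heap on the worst kept checkpoint makes each eval an O(log K) comparison, and ties keep the earlier checkpoint. In `on_eval_end`, the callback would save only when `should_save` is true and then remove the returned paths.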
Backward Compatibility
When `monitor` is not set, behavior is identical to the current implementation, so there are no breaking changes.
Implementation Sketch
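```python
def on_eval_end(self, state: State, unit: EvalUnit) -> None:
    # No metric to monitor: keep the existing periodic-only behavior.
    if self.monitor is None:
        return
    # Metrics are expected to be cached on the unit during on_eval_epoch_end.
    current_metric = getattr(unit, "last_eval_metrics", {}).get(self.monitor)
    if current_metric is None:
        return
    # Compare with the tracked best, save if better, and delete
    # best checkpoints that fall out of the top K.
```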
The metric would be read from `unit.last_eval_metrics` (stored during `on_eval_epoch_end`).
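For illustration, the sketch below shows one way an eval unit could cache aggregated metrics in `on_eval_epoch_end` so the callback can read them afterwards. `ToyEvalUnit`, its `eval_step` method, and `read_monitored_metric` are hypothetical: the toy unit does not subclass the project's real eval unit, and averaging the per-batch loss is only an assumed aggregation.

```python
from typing import Any, Dict, List, Optional


class ToyEvalUnit:
    """Hypothetical stand-in for an eval unit that caches epoch metrics."""

    def __init__(self) -> None:
        self.last_eval_metrics: Dict[str, float] = {}
        self._losses: List[float] = []

    def eval_step(self, loss: float) -> None:
        # Collect per-batch values during the eval epoch.
        self._losses.append(loss)

    def on_eval_epoch_end(self) -> None:
        # Aggregate per-batch values into one number per metric name and keep
        # them on the unit so a callback can read them in on_eval_end.
        if not self._losses:
            return
        mean_loss = sum(self._losses) / len(self._losses)
        self.last_eval_metrics = {"val/loss": mean_loss}
        self._losses = []


def read_monitored_metric(unit: Any, monitor: str) -> Optional[float]:
    """Same lookup as the on_eval_end sketch; None if the metric is missing."""
    value = getattr(unit, "last_eval_metrics", {}).get(monitor)
    # Convert single-element tensors or NumPy scalars to a plain float.
    return None if value is None else float(value)
```

Because the lookup falls back to an empty dict, a unit that never populates `last_eval_metrics` simply skips best-model saving instead of raising.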
Additional Context
This is a common pattern in training frameworks:
- PyTorch Lightning: `ModelCheckpoint(monitor="val_loss", mode="min")`
- Hugging Face: `TrainingArguments(load_best_model_at_end=True)`
- DeepSpeed: `checkpoint_tag="best"`
Would the team be open to a PR for this feature?