feat: Add best-model checkpoint saving by monitored metric #2037

@mintleaf84

Summary

Currently, TrainCheckpointCallback only saves checkpoints periodically, every N training steps. There is no built-in mechanism to save the best model based on a validation metric (e.g., val/loss).

Motivation

When training MLIP models, users typically want to keep the best-performing checkpoint based on a validation metric, not just the most recent periodic save. This is a common feature in other training frameworks (PyTorch Lightning, Hugging Face Trainer, etc.).

Proposed Solution

Extend TrainCheckpointCallback with:

TrainCheckpointCallback(
    checkpoint_every_n_steps=1000,  # periodic save (existing)
    monitor="val/loss",              # metric to track (new)
    mode="min",                      # min or max (new)
    save_top_k=3,                    # keep top-K best (new)
)

  • monitor: name of the validation metric to track
  • mode: "min" (e.g., loss) or "max" (e.g., accuracy)
  • save_top_k: number of best checkpoints to keep
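
To make the meaning of mode concrete, here is a minimal sketch of the comparison it implies. The helper name is_improvement and the NaN handling are assumptions for illustration only; neither exists in the current callback.

```python
import math
from typing import Optional


def is_improvement(current: float, best: Optional[float], mode: str = "min") -> bool:
    """Return True if `current` beats `best` under the given mode.

    Hypothetical helper, not part of TrainCheckpointCallback.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    if math.isnan(current):
        # Assumption: a NaN metric never counts as an improvement.
        return False
    if best is None:
        # First observed value is always the best so far.
        return True
    return current < best if mode == "min" else current > best
```

With mode="min" a val/loss of 0.4 improves on 0.5; with mode="max" the comparison is reversed. Strict inequality means a tie keeps the earlier checkpoint.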

Backward Compatibility

When monitor is not set, behavior is identical to the current implementation. Zero breaking changes.

Implementation Sketch

def on_eval_end(self, state: State, unit: EvalUnit) -> None:
    if self.monitor is None:
        return
    current_metric = getattr(unit, "last_eval_metrics", {}).get(self.monitor)
    if current_metric is None:
        return
    # compare with best, save if better, cleanup old best checkpoints

The metric would be read from unit.last_eval_metrics (stored during on_eval_epoch_end).
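
The "compare with best, save if better, cleanup old best checkpoints" step could be backed by bookkeeping along these lines. This is a sketch under stated assumptions: TopKTracker and its update method are hypothetical names, and the class only decides what to save and what to delete, leaving the actual file I/O to the callback.

```python
from typing import List, Optional, Tuple


class TopKTracker:
    """Tracks the k best (metric, checkpoint path) pairs seen so far.

    Hypothetical helper for illustration; it performs no file I/O.
    """

    def __init__(self, k: int = 3, mode: str = "min") -> None:
        if k < 1:
            raise ValueError("k must be >= 1")
        if mode not in ("min", "max"):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.k = k
        self.mode = mode
        self._entries: List[Tuple[float, str]] = []  # kept sorted, best first

    @property
    def best(self) -> Optional[Tuple[float, str]]:
        return self._entries[0] if self._entries else None

    def update(self, metric: float, path: str) -> Tuple[bool, Optional[str]]:
        """Offer a candidate checkpoint.

        Returns (should_save, path_to_delete). path_to_delete is the
        checkpoint pushed out of the top k, or None if nothing is evicted.
        """
        if metric != metric:  # NaN: assumed to never qualify
            return False, None
        candidate = (metric, path)
        ranked = self._entries + [candidate]
        # Stable sort: on ties the earlier checkpoint stays ahead.
        ranked.sort(key=lambda e: e[0], reverse=(self.mode == "max"))
        evicted = ranked[self.k:]  # holds at most one entry
        if evicted and evicted[0] is candidate:
            return False, None  # candidate did not make the top k
        self._entries = ranked[: self.k]
        return True, (evicted[0][1] if evicted else None)
```

In on_eval_end, the callback would call update(current_metric, candidate_path), write the checkpoint only when should_save is true, and then remove path_to_delete if one is returned. Saving before deleting avoids a window in which fewer than k checkpoints exist on disk.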

Additional Context

Comparable options in other training frameworks:

  • PyTorch Lightning: ModelCheckpoint(monitor="val_loss", mode="min")
  • Hugging Face: TrainingArguments(load_best_model_at_end=True, metric_for_best_model="eval_loss")
  • DeepSpeed: engine.save_checkpoint(save_dir, tag="best"), where the caller chooses the tag

Would the team be open to a PR for this feature?
