Conversation
fsdp_enabled (bool): Whether the model is an FSDP model.

Returns:
    Union[torch.Tensor, None]: The total gradient norm before clipping for the 'norm' clipping type,
Let's just always return the result, not just for 'norm'.
It's a weird contract for a separate downstream function to have to know about this behavior.
The downstream function always guards by checking whether the clipping_type is 'norm', so that should be good enough.
I thought about it, but the other two options don't have top-level scalar values that can be returned:
- value - returns None (see https://github.com/pytorch/pytorch/blob/v2.7.0/torch/nn/utils/clip_grad.py#L250)
- adaptive - is an elementwise op, so there isn't a clear value to return
That's why I stuck with returning None for those. I agree that the contract is awkward, but we need to propagate the norm to the logger.
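The return-contract difference discussed here can be seen directly in PyTorch's gradient clipping utilities (a minimal sketch; the tensor values are illustrative, but the return behavior matches torch 2.x):

```python
import torch
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

# A tiny parameter with a known gradient.
p = torch.nn.Parameter(torch.zeros(3))
p.grad = torch.tensor([3.0, 4.0, 0.0])  # L2 norm = 5.0

# 'norm' clipping clips in place AND returns the pre-clip total norm,
# so there is a scalar available to propagate to a logger.
total_norm = clip_grad_norm_([p], max_norm=1.0)
print(float(total_norm))  # 5.0

# 'value' clipping also modifies gradients in place, but returns None,
# so there is no top-level scalar to return.
result = clip_grad_value_([p], clip_value=0.5)
print(result)  # None
```

This is why only the 'norm' clipping path can surface a value to log without an API change to the other clipping types.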
oh lol, why do they modify in place for value and not for norm 😆 😓
Approved; not ideal, but it makes sense that this is the best we can do.
Context
Currently `torch.nn.utils.clip_grad_norm_` and `FSDP.clip_grad_norm_` apply gradient clipping in place but also return the pre-clip gradient norm; however, this value is not captured or logged anywhere. We can't change the API for all gradient clipping methods, since some don't have a top-level scalar to return, but we can for gradient norm clipping, the most frequent one we use.
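Concretely, capturing the otherwise-discarded return value could look like this (a hypothetical helper; `clip_gradients` and its signature are illustrative, not the PR's actual code):

```python
import torch

def clip_gradients(model: torch.nn.Module, clipping_threshold: float):
    # clip_grad_norm_ clips in place AND returns the pre-clip total norm,
    # which previously was simply dropped instead of being logged.
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), max_norm=clipping_threshold
    )
    return total_norm

model = torch.nn.Linear(2, 2)
loss = model(torch.randn(4, 2)).sum()
loss.backward()
norm = clip_gradients(model, clipping_threshold=1.0)
print(float(norm))  # pre-clip gradient norm, ready to be logged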
This PR propagates the value out of the helper and into the algorithm, where it can be logged. Since clipping is fairly bursty, we also compute a rolling window over `clipping_frequency_window` samples to provide a more parseable metric.

Known caveats

- `_clipping_history` is not persisted, so the metric will change slightly upon resumption
- `clipping_threshold: 100`

Experiments
A couple of example experiments that showcase the functionality with SFT and GRPO are `2025-07-21-debug-gradient-clipping` and `2025-07-25-math-rlvr-grpo` respectively.
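The rolling clipping-frequency metric described above could be sketched roughly as follows (a minimal sketch; `ClippingFrequencyTracker` and its method names are hypothetical, and the in-memory deque mirrors the `_clipping_history` caveat of not being persisted across resumptions):

```python
from collections import deque

class ClippingFrequencyTracker:
    """Hypothetical sketch of a rolling clipping-frequency metric.

    Records 1 when the pre-clip norm exceeded the threshold (clipping fired)
    and 0 otherwise, then reports the mean over the last `window` samples
    to smooth out bursty clipping events.
    """

    def __init__(self, window: int = 100):
        # Not persisted in any checkpoint, so the metric resets on resumption.
        self._history = deque(maxlen=window)

    def update(self, total_norm: float, clipping_threshold: float) -> float:
        self._history.append(1.0 if total_norm > clipping_threshold else 0.0)
        return sum(self._history) / len(self._history)

tracker = ClippingFrequencyTracker(window=4)
for norm in [0.5, 2.0, 3.0, 0.1]:
    freq = tracker.update(norm, clipping_threshold=1.0)
print(freq)  # 0.5: two of the last four steps clipped
```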