Skip to content

fix clipping logic, add test for clipping functions#73

Merged
ClashLuke merged 35 commits into
HomebrewML:mainfrom
alexjwilliams:clipfix
Jul 26, 2025
Merged

fix clipping logic, add test for clipping functions#73
ClashLuke merged 35 commits into
HomebrewML:mainfrom
alexjwilliams:clipfix

Conversation

@alexjwilliams

Copy link
Copy Markdown
Contributor

Fixes #72.

Current state of the test:

FAILED test/test_clip.py::test_clip[128-2-Muon] - AssertionError: nan before clipping
FAILED test/test_clip.py::test_clip[128-2-MuonLaProp] - AssertionError: nan before clipping
FAILED test/test_clip.py::test_clip[128-2-ForeachNewtonPSGDLRA] - ValueError: Hessian approximation requires a closure.
FAILED test/test_clip.py::test_clip[128-2-ForeachMuon] - AssertionError: nan before clipping
FAILED test/test_clip.py::test_clip[128-2-ForeachCachedNewtonPSGD] - ValueError: Hessian approximation requires a closure.
FAILED test/test_clip.py::test_clip[128-2-NewtonPSGDLRA] - ValueError: Hessian approximation requires a closure.
FAILED test/test_clip.py::test_clip[128-2-NewtonHybrid2PSGDLRA] - ValueError: Hessian approximation requires a closure.
FAILED test/test_clip.py::test_clip[128-2-NewtonHybrid2PSGDKron] - ValueError: Hessian approximation requires a closure.
FAILED test/test_clip.py::test_clip[128-2-NewtonPSGDKron] - ValueError: Hessian approximation requires a closure.

I believe the ValueError is due to some unrelated issue. The AssertionError occurs at the very beginning of a clip function, before anything is done to the input tensors. Therefore, it seems that this error is due to Muon just not behaving well when its updates/grads are clipped. I am not really all that familiar with Muon, so let me know if the presence of this error makes sense to you and what you would like me to change to remedy it.

@ClashLuke

Copy link
Copy Markdown
Member

I've merged both clamps into one. Could you double-check whether it still looks good? I'll get the tests to run/pass in a few hours.

@ClashLuke ClashLuke force-pushed the main branch 2 times, most recently from 06fb5f7 to d099575 Compare July 25, 2025 16:37
- fix clip_at:
    * remove the redundant clip_at = max(clip_at, eps) in
      _compilable_rmsnorm_clip_
    * remove the clip_at = max(clip_at, eps) in _clip and in
      _compilable_global_rmsnorm_clip_. These lines made the numerator
      of the scalar equal to max(clip_at, eps) instead of just clip_at.
- divide only once by numel in _compilable_global_rmsnorm_clip_ for
  better perf/elegance
- other straightforward fixes
@alexjwilliams

Copy link
Copy Markdown
Contributor Author

Sorry for the delay. I have been recovering from surgery.

There are some issues with the edits. I fixed the straightforward ones in the new commit. Please take a look at it. Now for the more nuanced "issues" (considerations):

The equality that you wrote down,

(x / y.clamp(min=eps)).clamp(max=1) 
== min(x/max(y,eps),1) (A) 
== x / y.clamp(min=max(x, eps)) 
== x / max(y, max(x,eps)) (B)

is true, but the original expression is not equal to A:

(clip_at / (norm + 1e-6)).clamp(max=1.0) 
== (x / (y + eps)).clamp(max=1) 
== min(x/(y+eps), 1) (C) 
!= min(x/max(y,eps),1) (A)

This being said, the use of max(y,eps) appears to be better than using y+eps. For y < eps, the expression A is always 1, since x/eps is large. C behaves similarly in this case. But, for y >= eps, the expression A, when it is not equal to 1, is x/y exactly, rather than x/(y+eps) which is the value of C in this case. I imagine you were aware of this when making your changes, but I wanted to lay it out in detail just in case.

So, the code is not wrong, but it does not match pytorch. As I stated in the issue, pytorch uses the following logic (which is what I copied exactly):

clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

If you think/know that using max(y,eps) performs materially better in practice, then keep it like you have it. But if not, we should just match pytorch to ensure stability for users if they switch between the two.

Now consider the value you used for eps, 1e-8. This, combined with max(y,eps), actually provides less aggressive stabilization than either y+eps or max(y,eps) together with eps=1e-6. Suppose 1e-8 < y < 1e-6. Then, for eps=1e-8, max(y,eps) == y, which is smaller than the larger epsilon 1e-6! y+eps will also be smaller than the larger epsilon, unless y is really close to 1e-6 (within 1e-8). So, in either case, we are dividing x by a number smaller than the original epsilon. Note that pytorch uses 1e-6 for epsilon and it is not user configurable. I think we should copy this, unless your experience tells you deviating is better.

@ClashLuke

Copy link
Copy Markdown
Member

Awesome, thank you for the detailed breakdown!

Yeah, the max is on purpose. In my tests, the max empirically converges better by preserving numerical accuracy for longer. This is not the same semantics as torch, but it's consistent with how eps is handled in the rest of the library.

I especially appreciate you catching the x vs x32 multiplication issue! Merging this now. Feel free to contribute again if you spot other issues!

Regardless, I hope your surgery went well and wish you a swift recovery.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

clipping logic is incorrect

2 participants