Jax scheduler lr decay #2081
echen5503
left a comment
While I agree that the dictionary version is cleaner than passing in a tuple, the tuple format is used for all other backends AFAIK.
Let's focus on adding the cosine/exponential decays themselves, then we can open a PR for the change of format for all backends, with backwards compatibility as you suggested.
raise NotImplementedError(
    f"{decay[0]} learning rate decay to be implemented for backend jax."
)
if callable(decay):
If you want to allow passing in a decay function directly, then it must be documented. Additionally, please keep to one task per PR.
ok, I've removed it
bonneted
left a comment
I've changed back to tuple arguments, only adding warmup decay schedules.
I've added support for optional positional arguments in the tuple (documented in model.py).
I updated the docs to clearly state that init_value is always supplied by the lr argument.
I also updated to use the optax.schedules namespace instead of flat aliases, since these are planned for removal (see https://github.com/google-deepmind/optax/blob/main/optax/__init__.py).
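For readers following along, here is a minimal sketch of how a decay tuple with optional positional arguments could be unpacked, with init_value always taken from lr. The tuple layout, the `_SCHEDULE_PARAMS` table and the helper name `decay_tuple_to_kwargs` are hypothetical and are not taken from this PR; the parameter names mirror those of optax's `exponential_decay` and `cosine_decay_schedule`.

```python
# Hypothetical layout: (name, *positional_args). Illustrative only, not the
# layout documented in model.py.
_SCHEDULE_PARAMS = {
    # name: (required parameter names, {optional parameter name: default})
    "exponential": (
        ("transition_steps", "decay_rate"),
        {"transition_begin": 0, "staircase": False},
    ),
    "cosine": (("decay_steps",), {"alpha": 0.0}),
}


def decay_tuple_to_kwargs(lr, decay):
    """Turn ("name", arg1, ...) into (name, kwargs); init_value comes from lr."""
    name, *args = decay
    if name not in _SCHEDULE_PARAMS:
        raise NotImplementedError(
            f"{name} learning rate decay to be implemented for backend jax."
        )
    required, optional = _SCHEDULE_PARAMS[name]
    names = required + tuple(optional)
    if not len(required) <= len(args) <= len(names):
        raise ValueError(
            f"{name} decay expects {len(required)} to {len(names)} "
            f"arguments, got {len(args)}."
        )
    kwargs = {"init_value": lr, **optional}  # lr always supplies init_value
    kwargs.update(zip(names, args))  # positional values override the defaults
    return name, kwargs


# Only the required arguments given: optional ones keep their defaults.
print(decay_tuple_to_kwargs(1e-3, ("exponential", 1000, 0.9)))
# One optional positional argument (alpha) supplied.
print(decay_tuple_to_kwargs(1e-3, ("cosine", 5000, 0.1)))
```

The resulting kwargs could then be forwarded to the matching function in optax.schedules; that dispatch step is left out here to keep the sketch dependency-free.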
echen5503
left a comment
Great, looks much better now, and thanks for catching something that I didn't see in my initial review. Could you provide a quick test to show that the new decay schedules work properly?
I've fixed the formatting issue.
I've added a simple example that runs all decay schedules, with optional parameters for some of them. It works:
Using backend: jax
Other supported backends: tensorflow.compat.v1, tensorflow, pytorch, paddle.
paddle supports more examples now and is recommended.
Enable just-in-time compilation with XLA.
WARNING:2026-04-04 11:33:39,895:jax._src.xla_bridge:850: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Testing linear... Compiling model...
'compile' took 1.607393 s
final metric: 6.225e-02
Testing cosine... Compiling model...
'compile' took 0.013514 s
final metric: 6.239e-02
Testing exponential... Compiling model...
'compile' took 0.012375 s
final metric: 3.411e-04
Testing warmup_cosine... Compiling model...
'compile' took 0.011391 s
final metric: 7.794e-02
Testing warmup_exponential... Compiling model...
'compile' took 0.012295 s
final metric: 4.453e-04
All schedules ran successfully!
Testing linear with optional parameters... Compiling model...
'compile' took 0.012792 s
final metric: 5.528e-02
Testing exponential with optional parameters... Compiling model...
'compile' took 0.011313 s
final metric: 3.215e-04
Testing warmup_exponential with optional parameters... Compiling model...
'compile' took 0.012612 s
final metric: 4.179e-04
All schedules with optional parameters ran successfully!
I have also fixed the Poisson example for JAX: inverse time decay is not supported there, so I changed it to exponential decay for JAX.
Let me know if everything looks good, I’ll then remove the test.
All comments addressed.
1322c06 to 0fdb1df
Add warmup cosine/exponential decay for JAX.
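As a rough illustration of what a warmup cosine schedule does, here is a plain-Python sketch of its shape: a linear ramp from init_value to peak_value, followed by a cosine decay down to end_value. It is a stand-in written for this note, not the optax implementation and not the code in this PR; the function name `warmup_cosine` is hypothetical, and it assumes decay_steps counts the whole schedule, warmup included.

```python
import math


def warmup_cosine(step, init_value, peak_value, warmup_steps, decay_steps,
                  end_value=0.0):
    """Learning rate at `step`: linear warmup, then cosine decay.

    Assumes decay_steps > warmup_steps > 0 and that decay_steps covers the
    whole schedule (warmup included).
    """
    if step < warmup_steps:
        # Linear ramp from init_value up to peak_value.
        return init_value + (peak_value - init_value) * step / warmup_steps
    # Fraction of the decay phase completed, clamped to 1 once it is over.
    remaining = decay_steps - warmup_steps
    progress = min(step - warmup_steps, remaining) / remaining
    # Cosine factor goes from 1 (start of decay) to 0 (end of decay).
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return end_value + (peak_value - end_value) * cosine


# Illustrative values only: ramp 0 -> 1e-3 over 100 steps, then decay to 0.
for step in (0, 50, 100, 550, 1000, 2000):
    print(step, warmup_cosine(step, 0.0, 1e-3, 100, 1000))
```

After decay_steps the value stays at end_value, so running longer than the schedule does not restart it.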
I also refactored the way arguments are passed to the optax scheduler so that they are less implicit. Currently, the meaning of each list item can differ depending on the context; with kwargs it is explicit and can easily be extended to new decay schedules.
I kept backward compatibility (list instead of kwargs) for linear/exponential/cosine.
It seems that it was only used in the Poisson_Dirichlet_1d_exactBC example in the library examples. I can also remove the backward compatibility and update the example.
If no init_value is given, it defaults to the lr argument of the optimizer.
I tested it on the Poisson_Dirichlet_1d_exactBC example with all decays using kwargs, and it works fine.
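The backward-compatible handling described above could look roughly like the sketch below, which accepts either the legacy list/tuple or a kwargs-style dict and falls back to lr when init_value is missing. The dict layout (a "name" key plus keyword arguments), the `_LEGACY_ORDER` table and the helper name `normalize_decay` are assumptions made for illustration; the positional order in the legacy form is likewise assumed and is not taken from the library.

```python
# Assumed positional order for the legacy list/tuple form (illustrative only).
_LEGACY_ORDER = {
    "linear": ("end_value", "transition_steps"),
    "cosine": ("decay_steps", "alpha"),
    "exponential": ("transition_steps", "decay_rate"),
}


def normalize_decay(lr, decay):
    """Return (name, kwargs) from a legacy sequence or a kwargs-style dict."""
    if isinstance(decay, dict):
        kwargs = dict(decay)  # copy so the caller's dict is not modified
        name = kwargs.pop("name")
    else:
        name, *args = decay
        if name not in _LEGACY_ORDER:
            raise ValueError(f"{name} decay has no legacy list form.")
        order = _LEGACY_ORDER[name]
        if len(args) > len(order):
            raise ValueError(f"Too many arguments for {name} decay.")
        kwargs = dict(zip(order, args))
    # If no init_value is given, default to the optimizer's lr.
    kwargs.setdefault("init_value", lr)
    return name, kwargs


# Legacy list form and dict form end up in the same shape.
print(normalize_decay(1e-3, ["cosine", 5000, 0.0]))
print(normalize_decay(1e-3, {"name": "cosine", "decay_steps": 5000}))
```

Keeping the legacy order in one table means new schedules only need the dict form, while the three existing ones keep working unchanged.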