
Jax scheduler lr decay #2081

Open

bonneted wants to merge 8 commits into lululxvi:master from bonneted:jax-scheduler-lr-decay

Conversation

@bonneted (Contributor) commented Apr 1, 2026

Add warmup cosine/exponential decay for JAX.
I also refactored how the arguments are passed to the optax scheduler so they are less implicit (currently the meaning of the list items depends on the context); with kwargs they are explicit and can easily be extended to new decay schedules.
I kept backward compatibility (list instead of kwargs) for linear/exponential/cosine.
It seems this is only used in the Poisson_Dirichlet_1d_exactBC example in the library.
I can also remove the backward compatibility and update the example.
If no init_value is given, it defaults to the lr argument of the optimizer.

I tested it on the Poisson_Dirichlet_1d_exactBC example with all decay schedules using kwargs, and it works fine.
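As a purely hypothetical illustration of the kwargs-style spec (the exact format in the PR may differ; the schedule name and values here are placeholders), the decay argument could look like this, with init_value omitted and taken from lr:

    # Hypothetical sketch only; the actual dict/kwargs format proposed in the PR may differ.
    # init_value is not given, so it falls back to the optimizer's lr (1e-3 here).
    model.compile(
        "adam",
        lr=1e-3,
        decay={
            "name": "warmup_cosine",
            "peak_value": 1e-3,
            "warmup_steps": 200,
            "decay_steps": 2000,
        },
    )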

@echen5503 (Contributor) left a comment

While I agree that the dictionary version is cleaner than passing in a tuple, the tuple format is used for all other backends AFAIK.

Let's focus on adding the cosine/exponential decays themselves; then we can open a PR to change the format for all backends, with backward compatibility as you suggested.
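For reference, the tuple-style decay spec the other backends currently accept looks roughly like this (placeholder values, not taken from a real example):

    # Existing tuple format: (name, *schedule_args), e.g. inverse time decay
    # with placeholder decay_steps and decay_rate.
    model.compile("adam", lr=1e-3, decay=("inverse time", 1000, 0.5))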

Comment thread on deepxde/optimizers/jax/optimizers.py (outdated):

    raise NotImplementedError(
        f"{decay[0]} learning rate decay to be implemented for backend jax."
    )
    if callable(decay):
Contributor:

If you want to allow passing in a decay function directly, then it must be documented. Additionally, keep to one task per PR.

Contributor Author:

ok, I've removed it

@bonneted (Contributor, Author) left a comment

I've changed back to tuple arguments and only added the warmup decay schedules.
I've added support for optional positional arguments in the tuple (documented in model.py).
I updated the docs to clearly state that init_value is always supplied by the lr argument.
I also updated the code to use the optax.schedules namespace instead of the flat aliases, since those are planned for removal (see https://github.com/google-deepmind/optax/blob/main/optax/__init__.py).
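A minimal sketch (not the PR's actual code) of how such a tuple could be mapped onto the optax.schedules namespace, assuming decay = (name, *args) and that the optimizer's lr supplies init_value; the positional argument orders follow the optax.schedules signatures:

    import optax

    def _get_lr_schedule(lr, decay):
        # Hypothetical helper: map a (name, *args) tuple to an optax schedule,
        # using lr as the schedule's init_value.
        name, *args = decay
        if name == "linear":
            # args: (end_value, transition_steps)
            return optax.schedules.linear_schedule(lr, *args)
        if name == "cosine":
            # args: (decay_steps,) and optionally alpha
            return optax.schedules.cosine_decay_schedule(lr, *args)
        if name == "exponential":
            # args: (transition_steps, decay_rate)
            return optax.schedules.exponential_decay(lr, *args)
        if name == "warmup_cosine":
            # args: (peak_value, warmup_steps, decay_steps)
            return optax.schedules.warmup_cosine_decay_schedule(lr, *args)
        if name == "warmup_exponential":
            # args: (peak_value, warmup_steps, transition_steps, decay_rate)
            return optax.schedules.warmup_exponential_decay_schedule(lr, *args)
        raise NotImplementedError(
            f"{name} learning rate decay to be implemented for backend jax."
        )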

@echen5503 (Contributor) left a comment
Great, looks much better now, and thanks for catching something that I didn't see in my initial review. Could you provide a quick test to show that the new decay schedules work properly?

Comment threads on deepxde/model.py (outdated)
@bonneted (Contributor, Author) left a comment

I've fixed the formatting issue.
I've added a simple example that runs all the decay schedules, passing optional parameters for some of them (a rough sketch of the test structure follows the output below). It works:

Using backend: jax
Other supported backends: tensorflow.compat.v1, tensorflow, pytorch, paddle.
paddle supports more examples now and is recommended.
Enable just-in-time compilation with XLA.

WARNING:2026-04-04 11:33:39,895:jax._src.xla_bridge:850: An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Testing linear... Compiling model...
'compile' took 1.607393 s

final metric: 6.225e-02
Testing cosine... Compiling model...
'compile' took 0.013514 s

final metric: 6.239e-02
Testing exponential... Compiling model...
'compile' took 0.012375 s

final metric: 3.411e-04
Testing warmup_cosine... Compiling model...
'compile' took 0.011391 s

final metric: 7.794e-02
Testing warmup_exponential... Compiling model...
'compile' took 0.012295 s

final metric: 4.453e-04

All schedules ran successfully!
Testing linear with optional parameters... Compiling model...
'compile' took 0.012792 s

final metric: 5.528e-02
Testing exponential with optional parameters... Compiling model...
'compile' took 0.011313 s

final metric: 3.215e-04
Testing warmup_exponential with optional parameters... Compiling model...
'compile' took 0.012612 s

final metric: 4.179e-04

All schedules with optional parameters ran successfully!
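For reference, the quick test could be structured roughly like this (hypothetical sketch, not the actual test script; build_model and the tuple parameters are placeholders):

    # Hypothetical test sketch: compile and train the same model with each
    # decay schedule and check that training completes without error.
    schedules = [
        ("linear", 0.0, 2000),
        ("cosine", 2000),
        ("exponential", 1000, 0.9),
        ("warmup_cosine", 1e-3, 200, 2000),
        ("warmup_exponential", 1e-3, 200, 1000, 0.9),
    ]
    for decay in schedules:
        print(f"Testing {decay[0]}...")
        model = build_model()  # placeholder: builds the Poisson_Dirichlet_1d_exactBC model
        model.compile("adam", lr=1e-4, decay=decay, metrics=["l2 relative error"])
        model.train(iterations=2000)
    print("All schedules ran successfully!")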

I have also fixed the Poisson example for JAX: inverse time decay is not supported by this backend, so I changed it to exponential decay.
Let me know if everything looks good; I'll then remove the test.

@echen5503 (Contributor)

All comments addressed.

@bonneted force-pushed the jax-scheduler-lr-decay branch from 1322c06 to 0fdb1df on April 4, 2026 at 16:29