deepxde/model.py (9 additions, 0 deletions)

@@ -118,6 +118,15 @@ def compile(
     - `exponential_decay
       <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.exponential_decay>`_:
       ("exponential", transition_steps, decay_rate)
+    - `warmup_cosine_decay_schedule
+      <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.warmup_cosine_decay_schedule>`_:
+      ("warmup_cosine", peak_value, warmup_steps, decay_steps, end_value)
+    - `warmup_exponential_decay_schedule
+      <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.warmup_exponential_decay_schedule>`_:
+      ("warmup_exponential", peak_value, warmup_steps, transition_steps, decay_rate)
+
+    Only required arguments are listed; optional positional arguments can be supplied in Optax order.
+    `init_value` is not included in the tuple/list, since it is always taken from the `lr` argument.
 
     loss_weights: A list specifying scalar coefficients (Python floats) to
         weight the loss contributions. The loss value that will be minimized by
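For context, here is a minimal usage sketch (not part of the diff) showing how a decay tuple documented above would be passed to `Model.compile`. The model construction is assumed to exist already, and the schedule values are purely illustrative; `init_value` always comes from `lr`, and the remaining tuple entries follow Optax order.

import deepxde as dde

# `model` is an already-built dde.Model (data + network construction omitted).
# A warmup-cosine schedule is selected via the decay tuple:
# ("warmup_cosine", peak_value, warmup_steps, decay_steps, end_value).
model.compile(
    "adam",
    lr=1e-3,
    decay=("warmup_cosine", 1e-2, 1000, 10000, 1e-5),
)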
deepxde/optimizers/jax/optimizers.py (10 additions, 4 deletions)

@@ -2,6 +2,7 @@

 import jax
 import optax
+from optax import schedules
 
 
 apply_updates = optax.apply_updates
@@ -37,12 +38,17 @@ def _get_learningrate(lr, decay):
     if decay is None:
         return lr
     if decay[0] == "linear":
-        return optax.linear_schedule(lr, decay[1], decay[2])
+        return schedules.linear_schedule(lr, *decay[1:])
     if decay[0] == "cosine":
-        return optax.cosine_decay_schedule(lr, decay[1], decay[2])
+        return schedules.cosine_decay_schedule(lr, *decay[1:])
     if decay[0] == "exponential":
-        return optax.exponential_decay(lr, decay[1], decay[2])
+        return schedules.exponential_decay(lr, *decay[1:])
+    if decay[0] == "warmup_cosine":
+        return schedules.warmup_cosine_decay_schedule(lr, *decay[1:])
+    if decay[0] == "warmup_exponential":
+        return schedules.warmup_exponential_decay_schedule(lr, *decay[1:])
 
     raise NotImplementedError(
-        f"{decay[0]} learning rate decay to be implemented for backend jax."
+        f"Unknown decay schedule '{decay[0]}' for JAX backend. "
+        "Supported: 'linear', 'cosine', 'exponential', 'warmup_cosine', 'warmup_exponential'."
     )
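As a sanity check on the tuple unpacking (not part of the diff), the sketch below calls the Optax schedule directly with the same arguments `_get_learningrate` forwards; the numbers are illustrative.

import optax

lr = 1e-3
decay = ("warmup_cosine", 1e-2, 1000, 10000, 1e-5)

# _get_learningrate(lr, decay) reduces to this call:
# init_value=lr, then peak_value, warmup_steps, decay_steps, end_value.
schedule = optax.schedules.warmup_cosine_decay_schedule(lr, *decay[1:])

print(schedule(0))      # init_value (1e-3) at the start of warmup
print(schedule(1000))   # peak_value (1e-2) once warmup ends
print(schedule(10000))  # end_value (1e-5); decay_steps is the total schedule length, warmup included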