9 changes: 9 additions & 0 deletions deepxde/model.py
@@ -118,6 +118,15 @@ def compile(
 - `exponential_decay
   <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.exponential_decay>`_:
   ("exponential", transition_steps, decay_rate)
+- `warmup_cosine_decay_schedule
+  <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.warmup_cosine_decay_schedule>`_:
+  ("warmup_cosine", peak_value, warmup_steps, decay_steps, end_value)
+- `warmup_exponential_decay_schedule
+  <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.warmup_exponential_decay_schedule>`_:
+  ("warmup_exponential", peak_value, warmup_steps, transition_steps, decay_rate)
+
+Only the required arguments are listed; optional positional arguments can be appended in Optax order.
+`init_value` never appears in the tuple/list, since it is always supplied by the `lr` argument.

 loss_weights: A list specifying scalar coefficients (Python floats) to
     weight the loss contributions. The loss value that will be minimized by
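For illustration, the new schedules would be requested like this (a minimal sketch: the numeric values are illustrative rather than tuned, and `model` is assumed to be an already-built `dde.Model`):

# `lr` supplies init_value; the decay tuple supplies the remaining
# required arguments in Optax order.
model.compile(
    "adam",
    lr=1e-5,
    decay=("warmup_cosine", 1e-3, 1000, 20000, 1e-6),
)

# ("warmup_exponential", peak_value, warmup_steps, transition_steps, decay_rate)
model.compile(
    "adam",
    lr=1e-5,
    decay=("warmup_exponential", 1e-3, 1000, 5000, 0.9),
)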
14 changes: 10 additions & 4 deletions deepxde/optimizers/jax/optimizers.py
@@ -2,6 +2,7 @@

import jax
import optax
+from optax import schedules


apply_updates = optax.apply_updates
@@ -37,12 +38,17 @@ def _get_learningrate(lr, decay):
     if decay is None:
         return lr
     if decay[0] == "linear":
-        return optax.linear_schedule(lr, decay[1], decay[2])
+        return schedules.linear_schedule(lr, *decay[1:])
     if decay[0] == "cosine":
-        return optax.cosine_decay_schedule(lr, decay[1], decay[2])
+        return schedules.cosine_decay_schedule(lr, *decay[1:])
     if decay[0] == "exponential":
-        return optax.exponential_decay(lr, decay[1], decay[2])
+        return schedules.exponential_decay(lr, *decay[1:])
+    if decay[0] == "warmup_cosine":
+        return schedules.warmup_cosine_decay_schedule(lr, *decay[1:])
+    if decay[0] == "warmup_exponential":
+        return schedules.warmup_exponential_decay_schedule(lr, *decay[1:])

     raise NotImplementedError(
-        f"{decay[0]} learning rate decay to be implemented for backend jax."
+        f"Unknown decay schedule '{decay[0]}' for JAX backend. "
+        f"Supported: 'linear', 'cosine', 'exponential', 'warmup_cosine', 'warmup_exponential'."
     )
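Because the tail of the decay tuple is splatted into the schedule constructor, each tuple is equivalent to a direct Optax call. A sketch with illustrative numbers:

import optax

# decay=("warmup_cosine", 1e-3, 1000, 20000, 1e-6) with lr=1e-5 expands to:
schedule = optax.schedules.warmup_cosine_decay_schedule(
    init_value=1e-5,    # supplied by `lr`, never by the tuple
    peak_value=1e-3,    # decay[1]
    warmup_steps=1000,  # decay[2]
    decay_steps=20000,  # decay[3]
    end_value=1e-6,     # decay[4]
)

# An Optax schedule maps the step count to a learning rate:
print(schedule(0))      # init_value at step 0
print(schedule(1000))   # peak_value once warmup ends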
4 changes: 3 additions & 1 deletion examples/pinn_forward/Poisson_Dirichlet_1d_exactBC.py
@@ -49,8 +49,10 @@ def output_transform(x, y):
 net.apply_output_transform(output_transform)

 model = dde.Model(data, net)
+# Most backends
 model.compile("adam", lr=1e-4, decay=("inverse time", 1000, 0.3), metrics=["l2 relative error"])
-
+# Backend jax (does not support "inverse time" decay)
+# model.compile("adam", lr=1e-4, decay=("exponential", 1000, 0.9), metrics=["l2 relative error"])
 losshistory, train_state = model.train(iterations=30000)

 dde.saveplot(losshistory, train_state, issave=True, isplot=True)
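With this PR, the commented-out JAX line could also use one of the new warmup schedules, for example (illustrative, untuned values):

# Backend jax, using the newly supported warmup_cosine schedule
# model.compile(
#     "adam",
#     lr=1e-5,
#     decay=("warmup_cosine", 1e-4, 2000, 28000, 1e-7),
#     metrics=["l2 relative error"],
# )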