Changes from 4 commits
21 changes: 15 additions & 6 deletions deepxde/model.py
@@ -109,15 +109,24 @@ def compile(

- For backend JAX:

- Generic format for an Optax schedule:
  ("schedule_name", {"kwarg": value, ...})
  init_value will default to `lr` if not provided in kwargs.
- `linear_schedule
  <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.linear_schedule>`_:
  ("linear", end_value, transition_steps)
- `cosine_decay_schedule
  <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.cosine_decay_schedule>`_:
  ("cosine", decay_steps, alpha)
- `exponential_decay
  <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.exponential_decay>`_:
  ("exponential", transition_steps, decay_rate)
- `warmup_cosine_decay_schedule
  <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.warmup_cosine_decay_schedule>`_:
  ("warmup_cosine", {"init_value": ..., "peak_value": ..., "warmup_steps": ..., "decay_steps": ...})
- `warmup_exponential_decay_schedule
  <https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.schedules.warmup_exponential_decay_schedule>`_:
  ("warmup_exponential", {"init_value": ..., "peak_value": ..., "warmup_steps": ..., "transition_steps": ..., "decay_rate": ...})

loss_weights: A list specifying scalar coefficients (Python floats) to
weight the loss contributions. The loss value that will be minimized by
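To make the documented decay formats concrete, here is a minimal usage sketch. The toy dataset and network are assumptions for illustration (and it assumes the JAX backend is selected, e.g. via `DDE_BACKEND=jax`); they are not part of this PR:

```python
import numpy as np
import deepxde as dde

# Toy regression problem, illustrative only: fit y = sin(2*pi*x) on [0, 1].
x = np.linspace(0, 1, 100)[:, None]
y = np.sin(2 * np.pi * x)
data = dde.data.DataSet(X_train=x, y_train=y, X_test=x, y_test=y)
net = dde.nn.FNN([1, 32, 32, 1], "tanh", "Glorot normal")
model = dde.Model(data, net)

# Generic format: ("schedule_name", {"kwarg": value, ...});
# init_value falls back to lr when omitted from the kwargs.
model.compile(
    "adam",
    lr=1e-3,
    decay=("warmup_cosine", {"peak_value": 1e-2, "warmup_steps": 500, "decay_steps": 5000}),
)

# Legacy 3-tuple format, still supported:
model.compile("adam", lr=1e-3, decay=("exponential", 1000, 0.9))
```

Internally, `_get_learningrate` (changed in the second file below) turns these tuples into an Optax schedule that is passed to the optimizer as its learning rate.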
68 changes: 58 additions & 10 deletions deepxde/optimizers/jax/optimizers.py
@@ -2,6 +2,7 @@

import jax
import optax
from optax import schedules


apply_updates = optax.apply_updates
@@ -36,13 +37,60 @@ def get(optimizer, learning_rate=None, decay=None):
def _get_learningrate(lr, decay):
    if decay is None:
        return lr
    if decay[0] == "linear":
        return optax.linear_schedule(lr, decay[1], decay[2])
    if decay[0] == "cosine":
        return optax.cosine_decay_schedule(lr, decay[1], decay[2])
    if decay[0] == "exponential":
        return optax.exponential_decay(lr, decay[1], decay[2])

    raise NotImplementedError(
        f"{decay[0]} learning rate decay to be implemented for backend jax."
    )
    if callable(decay):
        return decay

Contributor: If you want to add forgiveness for just passing in a decay function, then it must be documented. Additionally, keep 1 task / PR.

Contributor Author: ok, I've removed it
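For context on the thread above: an Optax schedule is itself a plain step-to-learning-rate callable, which is what "passing in a decay function" would have meant here. A minimal sketch with illustrative values:

```python
import optax

# Any callable mapping step -> learning rate would qualify, e.g. a
# user-built Optax schedule.
my_schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
print(my_schedule(0))  # 0.001 at step 0
```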

    schedule_map = {
        "linear": schedules.linear_schedule,
        "cosine": schedules.cosine_decay_schedule,
        "exponential": schedules.exponential_decay,
        "warmup_cosine": schedules.warmup_cosine_decay_schedule,
        "warmup_exponential": schedules.warmup_exponential_decay_schedule,
    }

    # Preferred format: decay = ("schedule_name", {"kwarg": value, ...})
    # Legacy tuple formats are still supported below.
    if isinstance(decay, tuple) and len(decay) == 2 and isinstance(decay[1], dict):
        name, kwargs = decay
    elif isinstance(decay, tuple) and len(decay) == 3 and decay[0] == "linear":
        name, kwargs = "linear", {
            "init_value": lr,
            "end_value": decay[1],
            "transition_steps": decay[2],
        }
    elif isinstance(decay, tuple) and len(decay) == 3 and decay[0] == "cosine":
        name, kwargs = "cosine", {
            "init_value": lr,
            "decay_steps": decay[1],
            "alpha": decay[2],
        }
    elif isinstance(decay, tuple) and len(decay) == 3 and decay[0] == "exponential":
        name, kwargs = "exponential", {
            "init_value": lr,
            "transition_steps": decay[1],
            "decay_rate": decay[2],
        }
    else:
        raise ValueError(
            "For JAX, use decay=(name, kwargs) with Optax schedule kwargs, "
            "or a legacy tuple for linear/cosine/exponential. See "
            "https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html"
        )

    # Keep a simple default: if the schedule accepts init_value, use lr unless provided.
    if isinstance(kwargs, dict) and "init_value" not in kwargs:
        kwargs = {"init_value": lr, **kwargs}

    try:
        return schedule_map[name](**kwargs)
    except KeyError as exc:
        raise NotImplementedError(
            f"Unknown JAX decay schedule '{name}'. Supported schedules: "
            f"{list(schedule_map.keys())}. See "
            "https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html"
        ) from exc
    except TypeError as exc:
        raise ValueError(
            f"Invalid kwargs for JAX decay schedule '{name}': {kwargs}. See "
            "https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html"
        ) from exc
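As a rough check of what the helper produces, the legacy exponential tuple expands to the Optax call below; the step values are illustrative assumptions:

```python
import optax

lr = 1e-3

# decay=("exponential", 1000, 0.9) expands to:
schedule = optax.schedules.exponential_decay(
    init_value=lr, transition_steps=1000, decay_rate=0.9
)
print(schedule(0))     # 0.001
print(schedule(1000))  # 0.001 * 0.9 = 0.0009 (staircase=False by default)

# Optax schedules are step -> learning-rate callables and can be handed
# straight to an optimizer:
opt = optax.adam(learning_rate=schedule)
```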