Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions deepxde/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from . import config
from . import gradients as grad
from . import optimizers
from . import utils
from .backend import backend_name, jax, paddle, tf, torch

Expand Down Expand Up @@ -603,3 +604,121 @@ def on_epoch_end(self):
raise ValueError(
"`num_bcs` changed! Please update the loss function by `model.compile`."
)


class TimeTracker(Callback):
    """Track the elapsed time and show it together with the estimated remaining time.

    Args:
        period: How often to show the time estimations (default is 500 iterations).
    """

    def __init__(self, period=500):
        super().__init__()
        self.period = period
        # Wall-clock time at on_train_begin; None until training starts.
        self.t_start = None
        # Iteration count when training began / at the last status display.
        self.starting_epoch = 0
        self.last_display_epoch = 0
        # Total number of iterations for this run; set externally by Model.train()
        # for non-external optimizers (see model.py).
        self.total_iterations = None
        # Print the unsupported-optimizer warning only once, not every period.
        self._warned_unsupported = False

    def _format_time(self, seconds):
        """Format a duration (in seconds) compactly based on its magnitude.

        - If hours > 0: "X:YY:ZZ"
        - Otherwise: "YY:ZZ"
        """
        if seconds < 0:
            return "N/A"

        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)

        if hours > 0:
            return f"{hours}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"

    def on_train_begin(self):
        self.t_start = time.time()
        self.starting_epoch = self._get_iteration()
        self.last_display_epoch = self._get_iteration()

    def _get_iteration(self):
        """Return the model's current iteration, or 0 if unavailable."""
        # Use None (not 0) as the sentinel for a missing train_state.
        train_state = getattr(self.model, "train_state", None)
        if train_state is None:
            return 0
        return getattr(train_state, "iteration", 0)

    def _get_epochs_since_last(self):
        return self._get_iteration() - self.last_display_epoch

    def _get_epochs_since_start(self):
        return self._get_iteration() - self.starting_epoch

    def on_epoch_end(self):
        if self._get_epochs_since_last() >= self.period:
            self.last_display_epoch = self._get_iteration()
            self._display_time_status()

    def on_train_end(self):
        # Show a final status line if any iterations ran since the last display.
        if self._get_epochs_since_last() > 0:
            self._display_time_status()

    def _display_time_status(self):
        """Display elapsed time and estimated remaining time."""
        if self.t_start is None:
            return

        elapsed = time.time() - self.t_start
        current_step = self._get_epochs_since_start()

        # Estimate remaining time from the average per-iteration rate so far.
        if current_step > 0:
            rate = elapsed / current_step

            # Get total iterations if available from model
            total_iterations = None
            if not optimizers.is_external_optimizer(self.model.opt_name):
                total_iterations = self.total_iterations
            elif self.model.opt_name in ["L-BFGS", "L-BFGS-B"]:
                total_iterations = optimizers.LBFGS_options["maxiter"]
            elif self.model.opt_name in ["NNCG"]:
                total_iterations = optimizers.NNCG_options["cgmaxiter"]
            elif not self._warned_unsupported:
                # Warn once instead of on every display period.
                self._warned_unsupported = True
                print(
                    f"Warning: The optimizer {self.model.opt_name} "
                    "is not supported fully by the `TimeTracker` callback. "
                    "Open an issue on the DeepXDE repository."
                )

            if total_iterations is not None:
                remaining_steps = max(0, total_iterations - current_step)
                remaining = rate * remaining_steps
            else:
                # Can't estimate without knowing total iterations
                remaining = None
        else:
            rate = None
            remaining = None

        # Format output
        elapsed_str = self._format_time(elapsed)
        # Guard rate == 0 (elapsed below timer resolution) to avoid ZeroDivisionError.
        if rate:
            rate_str = f"{1 / rate:.2f}it/s" if (1 / rate > 0.05) else f"{rate:.2f}s/it"
            rate_str = f", {rate_str}"
        else:
            rate_str = ""

        if remaining is not None:
            remaining_str = self._format_time(remaining)
            print(f"{self._get_iteration()} [{elapsed_str}<{remaining_str}{rate_str}]")
        else:
            print(f"{self._get_iteration()} [{elapsed_str}{rate_str}]")
4 changes: 4 additions & 0 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,10 @@ def train(
self.batch_size = batch_size
self.callbacks = CallbackList(callbacks=callbacks)
self.callbacks.set_model(self)
for cb in self.callbacks.callbacks:
if type(cb).__name__ == "TimeTracker":
cb.total_iterations = iterations

if disregard_previous_best:
self.train_state.disregard_best()

Expand Down
Loading