99 changes: 78 additions & 21 deletions deepxde/icbc/boundary_conditions.py
@@ -14,7 +14,7 @@

import numbers
from abc import ABC, abstractmethod
from functools import wraps
from functools import wraps, lru_cache

import numpy as np

@@ -33,15 +33,24 @@ class BC(ABC):
geom: A ``deepxde.geometry.Geometry`` instance.
on_boundary: A function: (x, Geometry.on_boundary(x)) -> True/False.
component: The output component satisfying this BC.
depends_on_trainable_variables: Whether this BC depends on any trainable variable or not.
"""

def __init__(self, geom, on_boundary, component):
def __init__(self, geom, on_boundary, component, depends_on_trainable_variables=None):
self.geom = geom
self.on_boundary = lambda x, on: np.array(
[on_boundary(x[i], on[i]) for i in range(len(x))]
)
self.component = component

if depends_on_trainable_variables is None:
_warn_dependance_on_trainable_variables()
depends_on_trainable_variables = False
self.depends_on_trainable_variables = depends_on_trainable_variables
Contributor: Perhaps it might be a good idea to warn the user of the issue if self.depends_on_trainable is not explicitly given.

Contributor Author (@kyouma, Apr 1, 2026): I can see two ways to check whether depends_on_trainable is set explicitly and to print the warning:

  1. Inside the BC/IC classes' __init__(). This bombards with warnings even users who do not solve inverse problems, but it does not create any interference between the BC/IC classes and, for example, the Model class, which can also check the BCs' properties.
  2. Inside Model.compile() (i.e., examining the elements of self.data.bcs) if and only if external_trainable_variables is not empty. This will not work for the tensorflow.compat.v1 backend (the list is ignored by Model.compile() and, subsequently, by the users), and it may require stricter control over future additions of classes connectible to Model.data (if they are not children of the Data class).

    Edit: For tensorflow.compat.v1 there is no explicit caching in BCs and ICs, so maybe the 2nd option is also viable.

I have made a commit with the 1st option. Following the DeepXDE style, the warning is issued with a plain print, so that it does not distract users unaffected by the bug too much. Besides, I will make an example of an inverse problem with a trainable IC and explicitly state there the necessity of setting depends_on_trainable to True. And maybe I will even add a new entry to the FAQ section.
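A minimal sketch of such a trainable-IC setup, assuming the PyTorch backend and the flag added in this PR; the variable `a0`, the function `initial_value`, and the problem setup are hypothetical illustrations, not code from the PR:

```python
import torch
import deepxde as dde

# Trainable initial amplitude: an external trainable variable of the
# inverse problem, optimized alongside the network weights.
a0 = dde.Variable(1.0)

def initial_value(x):
    # The IC values depend on the trainable a0. If they were cached,
    # d(loss)/d(a0) would be computed against a stale tensor.
    return a0 * torch.ones((len(x), 1))

geom = dde.geometry.Interval(0, 1)
timedomain = dde.geometry.TimeDomain(0, 1)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

ic = dde.icbc.IC(
    geomtime,
    initial_value,
    lambda _, on_initial: on_initial,
    depends_on_trainable_variables=True,  # disable caching of initial_value
)

# ... build data/net/model as usual, then:
# model.compile("adam", lr=1e-3, external_trainable_variables=[a0])
```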


# If learnable geometry is introduced, the boundary normal caching must be disabled
Contributor: What do you mean by learnable geometry? Finding the shape of the domain for an inverse problem?

Contributor Author: Yes. Or actually, I should have written "learnable or variable" (for example, a domain that does not depend on any trainable variables, but still changes its shape or size over time). But the introduction of such things into DeepXDE seems to be a very distant prospect.

# if the "geometry depends on trainable variables" flag of `self.geom` is True,
# similar to `self.func` here.
self.boundary_normal = npfunc_range_autocache(
Contributor: Aren't you forgetting to add the flag here?

Contributor Author: If DeepXDE supports non-constant geometry, then indeed it is safer, or even mandatory, to set the caching flag for the boundary normal calculation, too.

Correct me if I am wrong, but as far as I know, DeepXDE does not support such geometry: even parametrized geometry is reduced to some normalized constant geometry, with the parameters transferred into the functions.

If there is any concern that something may change in the future, it is better to apply the flag here, too, just like I have changed some lines with if...elif statements to make them more bug-proof against possible future changes.

Contributor: It seems DeepXDE only supports constant geometry, and to my knowledge there is no plan to make it non-constant in the near future; the current focus is fixing existing geometry bugs.

Contributor Author: Then it may be OK to leave the boundary normal calculation as it is now. I can add a comment to each creation of self.boundary_normal saying that, to support learnable boundaries, the "disable caching" flag must be passed inside it, too.

Or actually, because any changes for learnable geometry support might be made in other places of the library and this comment might be overlooked, I will test how disabling caching for self.boundary_normal affects performance. If the cost is within ~5%, I will add the flag to self.boundary_normal, too.

Contributor Author: Actually, this is a more difficult question. If learnable geometry is introduced, it is likely to be done via the settings of the geometry objects, so one could not pass a BC object's depends_on_trainable_variables down to manage the boundary normal caching. Thus, I have only added comments to warn future contributors.

Contributor: Sounds good.

utils.return_tensor(self.geom.boundary_normal)
)
@@ -67,9 +76,9 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
class DirichletBC(BC):
"""Dirichlet boundary conditions: y(x) = func(x)."""

def __init__(self, geom, func, on_boundary, component=0):
super().__init__(geom, on_boundary, component)
self.func = npfunc_range_autocache(utils.return_tensor(func))
def __init__(self, geom, func, on_boundary, component=0, depends_on_trainable_variables=None):
super().__init__(geom, on_boundary, component, depends_on_trainable_variables)
self.func = npfunc_range_autocache(utils.return_tensor(func), self.depends_on_trainable_variables)

def error(self, X, inputs, outputs, beg, end, aux_var=None):
values = self.func(X, beg, end, aux_var)
@@ -84,9 +93,9 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
class NeumannBC(BC):
"""Neumann boundary conditions: dy/dn(x) = func(x)."""

def __init__(self, geom, func, on_boundary, component=0):
super().__init__(geom, on_boundary, component)
self.func = npfunc_range_autocache(utils.return_tensor(func))
def __init__(self, geom, func, on_boundary, component=0, depends_on_trainable_variables=None):
super().__init__(geom, on_boundary, component, depends_on_trainable_variables)
self.func = npfunc_range_autocache(utils.return_tensor(func), self.depends_on_trainable_variables)

def error(self, X, inputs, outputs, beg, end, aux_var=None):
values = self.func(X, beg, end, aux_var)
@@ -96,8 +105,13 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
class RobinBC(BC):
"""Robin boundary conditions: dy/dn(x) = func(x, y)."""

def __init__(self, geom, func, on_boundary, component=0):
super().__init__(geom, on_boundary, component)
def __init__(self, geom, func, on_boundary, component=0, depends_on_trainable_variables=None):
# `depends_on_trainable_variables` is here in order to be consistent
# with other BC/IC functions with `func`
# and in case in future caching is added here, too.
super().__init__(geom, on_boundary, component, depends_on_trainable_variables)
# If for some reason caching of `func` is added here in future,
# it must be disabled if `self.depends_on_trainable_variables` is True.
self.func = func

def error(self, X, inputs, outputs, beg, end, aux_var=None):
@@ -110,7 +124,7 @@ class PeriodicBC(BC):
"""Periodic boundary conditions on component_x."""

def __init__(self, geom, component_x, on_boundary, derivative_order=0, component=0):
super().__init__(geom, on_boundary, component)
super().__init__(geom, on_boundary, component, False)
self.component_x = component_x
self.derivative_order = derivative_order
if derivative_order > 1:
@@ -145,6 +159,7 @@ class OperatorBC(BC):
`inputs` and `outputs` are the network input and output tensors,
respectively; `X` are the NumPy array of the `inputs`.
on_boundary: (x, Geometry.on_boundary(x)) -> True/False.
depends_on_trainable_variables: Whether this BC depends on any trainable variable or not.

Warning:
If you use `X` in `func`, then do not set ``num_test`` when you define
@@ -154,8 +169,13 @@
which cannot be fixed in an easy way for all backends.
"""

def __init__(self, geom, func, on_boundary):
super().__init__(geom, on_boundary, 0)
def __init__(self, geom, func, on_boundary, depends_on_trainable_variables=None):
# `depends_on_trainable_variables` is here in order to be consistent
Contributor: I agree with not putting a warning for those that don't use autocache.

Contributor Author (@kyouma, Apr 4, 2026): Should I change the default depends_on_trainable_variables argument value to False here and in other similar classes?

# with other BC/IC functions with `func`
# and in case in future caching is added here, too.
super().__init__(geom, on_boundary, 0, depends_on_trainable_variables)
# If for some reason caching of `func` is added here in future,
# it must be disabled if `self.depends_on_trainable_variables` is True.
self.func = func

def error(self, X, inputs, outputs, beg, end, aux_var=None):
@@ -253,13 +273,25 @@ class PointSetOperatorBC:
Note: if you want to use a batch size here, you should also set the callback
'dde.callbacks.PDEPointResampler(bc_points=True)' in training.
shuffle: Randomize the order on each pass through the data when batching.
depends_on_trainable_variables: Whether this BC depends on any trainable variable or not.
"""

def __init__(self, points, values, func, batch_size=None, shuffle=True):
def __init__(self, points, values, func, batch_size=None, shuffle=True, depends_on_trainable_variables=None):
self.points = np.array(points, dtype=config.real(np))
if not isinstance(values, numbers.Number) and values.shape[1] != 1:
raise RuntimeError("PointSetOperatorBC should output 1D values")
self.values = bkd.as_tensor(values, dtype=config.real(bkd.lib))

# `depends_on_trainable_variables` is here in order to be consistent
# with other BC/IC functions with `func`
# and in case in future caching is added here, too.
if depends_on_trainable_variables is None:
_warn_dependance_on_trainable_variables()
depends_on_trainable_variables = False
self.depends_on_trainable_variables = depends_on_trainable_variables
Contributor: The rationale behind this is just consistency, right? Because this is not used.

Contributor Author: The consistency and similar "interfaces" in the API docs are beautiful, but I have added it here mainly to make it bug-proof in case some type of caching is applied here in the future. Apart from that, it is not necessary here at all; it is not even a style preference.

Contributor: Sorry, when I said "style choice" I meant "design/robustness choice".

Contributor Author: Right now there is indeed no need to add these lines here (and in one or two other places where the BC does not use caching for func at all), but if someone wanted to change this behavior in the future for some reason, they might be of help. As an alternative, I could have written a comment about the necessity of adding the ability to disable caching if caching is ever applied here, but I decided to just implement a part of this mechanism instead.

Contributor: I think adding a comment for the areas where the flag is added but isn't used is sufficient.


# If for some reason caching of `func` is added here in future,
# it must be disabled if `self.depends_on_trainable_variables` is True
self.func = func
self.batch_size = batch_size

@@ -311,11 +343,18 @@ class Interface2DBC:
on_boundary1: First edge func. (x, Geometry.on_boundary(x)) -> True/False.
on_boundary2: Second edge func. (x, Geometry.on_boundary(x)) -> True/False.
direction (string): "normal" or "tangent".
depends_on_trainable_variables: Whether this BC depends on any trainable variable or not.
"""

def __init__(self, geom, func, on_boundary1, on_boundary2, direction="normal"):
def __init__(self, geom, func, on_boundary1, on_boundary2, direction="normal", depends_on_trainable_variables=None):
self.geom = geom
self.func = npfunc_range_autocache(utils.return_tensor(func))

if depends_on_trainable_variables is None:
_warn_dependance_on_trainable_variables()
depends_on_trainable_variables = False
self.depends_on_trainable_variables = depends_on_trainable_variables

self.func = npfunc_range_autocache(utils.return_tensor(func), self.depends_on_trainable_variables)
self.on_boundary1 = lambda x, on: np.array(
[on_boundary1(x[i], on[i]) for i in range(len(x))]
)
@@ -324,6 +363,9 @@ def __init__(self, geom, func, on_boundary1, on_boundary2, direction="normal"):
)
self.direction = direction

# If learnable geometry is introduced, the boundary normal caching must be disabled
# if the "geometry depends on trainable variables" flag of `self.geom` is True,
# similar to `self.func` here.
self.boundary_normal = npfunc_range_autocache(
utils.return_tensor(self.geom.boundary_normal)
)
@@ -371,10 +413,11 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None):
return left_values + right_values - values


def npfunc_range_autocache(func):
def npfunc_range_autocache(func, disable_caching=False):
"""Call a NumPy function on a range of the input ndarray.

If the backend is pytorch, the results are cached based on the id of X.
For BC/IC objects that depend on trainable variables, caching must be disabled.
"""
# For some BCs, we need to call self.func(X[beg:end]) in BC.error(). For backend
# tensorflow.compat.v1/tensorflow, self.func() is only called once in graph mode,
@@ -421,13 +464,27 @@ def wrapper_cache_auxiliary(X, beg, end, aux_var):
cache[key] = func(X[beg:end], aux_var[beg:end])
return cache[key]

if backend_name in ["tensorflow.compat.v1", "tensorflow", "jax"]:
if (backend_name in ["tensorflow.compat.v1", "tensorflow", "jax"]) or disable_caching:
if utils.get_num_args(func) == 1:
return wrapper_nocache
if utils.get_num_args(func) == 2:
elif utils.get_num_args(func) == 2:
return wrapper_nocache_auxiliary
if backend_name in ["pytorch", "paddle"]:
elif backend_name in ["pytorch", "paddle"]:
if utils.get_num_args(func) == 1:
return wrapper_cache
if utils.get_num_args(func) == 2:
elif utils.get_num_args(func) == 2:
Contributor: What's the point of turning this into elif? Isn't this a common early-return design pattern, where you don't need elif?

Contributor Author: I thought that because of the `... or disable_caching` part there can be a situation when both outer conditions are True, and if the returns are ever changed into assignments for some reason, a bug may be overlooked. That is why I decided to mark all these branches as mutually exclusive.

Contributor: The outer elif seems necessary by this reasoning, but what about the inner elif?

Contributor Author: For consistency and explicitness.

Contributor: Ok, I guess it's just a style choice.

return wrapper_nocache_auxiliary
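For context, a minimal self-contained sketch of the failure mode that `disable_caching` guards against; it is independent of DeepXDE, and all names (`cached_call`, `a`, `X`) are hypothetical. A cache keyed only on the identity of the input array keeps returning the tensor computed with an old state of a trainable variable, so later updates of that variable are never reflected in the loss:

```python
import torch

cache = {}

def cached_call(func, X):
    # Simplified analogue of wrapper_cache: the key depends only on the
    # input array, not on trainable variables captured inside func.
    key = id(X)
    if key not in cache:
        cache[key] = func(X)
    return cache[key]

a = torch.tensor(1.0, requires_grad=True)  # trainable variable
X = torch.ones(3, 1)

y1 = cached_call(lambda x: a * x, X)
with torch.no_grad():
    a += 1.0  # an optimizer step updates the variable in place
y2 = cached_call(lambda x: a * x, X)

assert y2 is y1  # stale: the cached tensor still reflects the old `a`
```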


Contributor: Use @cache here, so it only prints once, and maybe add a docstring that tells future maintainers that it only prints once.

Contributor: Not sure that @cache is the best solution for this, but we shouldn't spam the user with warnings.

@lru_cache(maxsize=1)
def _warn_dependance_on_trainable_variables():
"""Print the warning that must contain the same message in constructors of various BC and IC classes.
The warning will be shown only once, which is achieved by using `lru_cache(maxsize=1)`
for a function with no arguments.
"""
print(
"Warning: The BC/IC instance initialization parameter `depends_on_trainable_variables` "
"must be explicitly set to either True or False for all BC and IC objects, or else "
"the gradients of the loss function with respect to the external trainable variables "
"in inverse problems may become wrong."
)
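A quick illustration of the print-once behavior discussed above (names hypothetical): `lru_cache(maxsize=1)` memoizes the first call's return value (None) of a zero-argument function, so the body never runs again within the same process:

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def _warn_once():
    print("warned")  # executed only on the first call

_warn_once()  # prints "warned"
_warn_once()  # cache hit on the memoized None; nothing is printed
```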
27 changes: 22 additions & 5 deletions deepxde/icbc/initial_conditions.py
@@ -4,17 +4,34 @@

import numpy as np

from .boundary_conditions import npfunc_range_autocache
from .boundary_conditions import npfunc_range_autocache, _warn_dependance_on_trainable_variables
from .. import backend as bkd
from .. import utils


class IC:
"""Initial conditions: y([x, t0]) = func([x, t0])."""

def __init__(self, geom, func, on_initial, component=0):
"""Initial conditions: y([x, t0]) = func([x, t0]).

Args:
geom: A ``deepxde.geometry.Geometry`` instance.
func: A function takes arguments (`inputs`, `outputs`)
and outputs a tensor of size `N x 1`, where `N` is the length of `inputs`.
`inputs` and `outputs` are the network input and output tensors,
respectively.
on_initial: A function: (x, Geometry.on_initial(x)) -> True/False.
component: The output component satisfying this IC.
depends_on_trainable_variables: Whether this IC depends on any trainable variable or not.
"""

def __init__(self, geom, func, on_initial, component=0, depends_on_trainable_variables=None):
self.geom = geom
self.func = npfunc_range_autocache(utils.return_tensor(func))

if depends_on_trainable_variables is None:
_warn_dependance_on_trainable_variables()
depends_on_trainable_variables = False
self.depends_on_trainable_variables = depends_on_trainable_variables

self.func = npfunc_range_autocache(utils.return_tensor(func), self.depends_on_trainable_variables)
self.on_initial = lambda x, on: np.array(
[on_initial(x[i], on[i]) for i in range(len(x))]
)