Make basic BC/IC classes correctly trainable #2077

Open

kyouma wants to merge 8 commits into lululxvi:master from kyouma:fix_caching_of_icbc_with_variables

Conversation

kyouma (Contributor) commented Mar 30, 2026

Hello.

In relation to issue #2074.

Currently it is possible to include trainable variables in BC/IC for inverse problems only by using OperatorBC (at least for the main types of BC/IC), but I did not manage to find an explicit statement about this in the documentation or FAQ. If one tries to use a DirichletBC, NeumannBC, IC, etc. that depends on trainable variables, the function values are cached and the gradients with respect to the trainable variables are incorrect, which leads to terrible convergence behavior.

This PR proposes to require the basic BC/IC classes to specify whether or not they depend on trainable variables (which disables function caching for them). This is done independently for each BC/IC to protect performance.

I have put the new argument at the end of the class signatures in order to preserve backward compatibility for scripts with positional arguments. But the default value is None, which will cause an exception saying that the user needs to set it explicitly.

The alternative to these (or any other) code changes is to add a warning to the documentation (at least to the BC, IC, and inverse-problem example sections) saying that for inverse problems with trainable variables in BC/IC one must use OperatorBC.

It was possible to include trainable variables in BC/IC only by using OperatorBC (at least for the main types of BC/IC), but this was not explicitly stated in the documentation or FAQ. Now the basic BC/IC classes require specifying whether they depend on trainable variables (which disables function caching for them).

Backward compatibility will not be broken, as the order of arguments is the same; the new argument only requires explicit setting at runtime.
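For illustration, here is a rough sketch of what the proposed API means for user code. The flag name comes from this PR; the geometry and the PyTorch-backend lambda are a generic example of mine, not code from the PR:

```python
import deepxde as dde  # assumes the PyTorch backend is selected
import torch

C = dde.Variable(1.0)  # external trainable variable entering the BC

geom = dde.geometry.Interval(0, 1)
bc = dde.icbc.DirichletBC(
    geom,
    # The target values depend on C, so they must not be cached:
    lambda x: C * torch.ones((x.shape[0], 1)),
    lambda _, on_boundary: on_boundary,
    depends_on_trainable_variables=True,  # the new flag proposed in this PR
)
```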
kyouma and others added 4 commits April 1, 2026 00:23
Use the current changes, but set the default value of this new flag to False and remove the boolean check.
- Old scripts will work as they do now.
- If someone needs to train a BC/IC, they will have to specify it explicitly, just like the trainable variables.

Co-authored-by: Edwin Chen <echen1ffa@gmail.com>
Set the default value of the `depends_on_trainable_variables` flag to False and remove the boolean check.
- Old scripts will work as they do now.
- If someone needs to train a BC/IC, they will have to specify it explicitly, just like the trainable variables.
kyouma (Contributor, Author) commented Mar 31, 2026

Hello. Please check the changes.

I have updated the default value, removed the boolean check, and updated the docstrings to reflect the changes in the documentation.

The next step is to add an example and emphasize there the need to set the depends_on_trainable_variables flag. I will work on it.

echen5503 (Contributor) left a comment


It might be a good idea to warn the user if depends_on_trainable_variables isn't set.

[on_boundary(x[i], on[i]) for i in range(len(x))]
)
self.component = component
self.depends_on_trainable_variables = depends_on_trainable_variables
Contributor:

Perhaps it might be a good idea to warn the user of the issue if self.depends_on_trainable_variables is not explicitly given.

kyouma (Contributor, Author) commented Apr 1, 2026

I can see two ways to check whether depends_on_trainable_variables is set explicitly and to print the warning:

  1. inside the BC/IC classes' __init__() - this will bombard even users who do not solve inverse problems with warnings, but it creates no interference between the BC/IC classes and, for example, the Model class, which can also check the BCs' properties;
  2. inside Model.compile() (i.e., examining the elements of self.data.bcs) if and only if external_trainable_variables is not empty - this will not work for the tensorflow.compat.v1 backend (the list is ignored by Model.compile() and, consequently, by the users), and it may require stricter control over future additions of classes connectible to Model.data (if they are not children of the Data class).

    Edit: For tensorflow.compat.v1 there is no explicit caching in BCs and ICs, so maybe the 2nd option is also viable.

I have made a commit with the 1st option. Following the DeepXDE style, the warning is emitted with a plain print, so that it will not distract users unaffected by the bug too much. Besides, I will make an example of an inverse problem with a trainable IC and explicitly state there the necessity of setting depends_on_trainable_variables to True. And maybe even add a new entry to the FAQ section.

elif utils.get_num_args(func) == 2:
    return wrapper_nocache_auxiliary


Contributor:

Use @cache here, so it only prints once, and maybe add a docstring that tells future maintainers that it only prints once.

Contributor:

Not sure that @cache is the best solution for this, but we shouldn't spam the user with warnings.

echen5503 (Contributor):

Nice, lru_cache(1) is cleaner than @cache by itself. As a final test, can you run the non-converging code from this issue and verify that it now converges?
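For context, a minimal sketch of the print-once pattern settled on here. The helper name `_warn_dependance_on_trainable_variables` matches the one in the PR diff shown later; the message text is my assumption, not the PR's wording:

```python
import functools


@functools.lru_cache(maxsize=1)
def _warn_dependance_on_trainable_variables():
    """Print the warning only once.

    lru_cache(maxsize=1) memoizes the single no-argument call, so repeated
    invocations from many BC/IC constructors do not spam the user.
    """
    print(
        "Warning: depends_on_trainable_variables is not set and defaults to"
        " False. If this BC/IC depends on external trainable variables, set"
        " it to True; otherwise its cached values will give wrong gradients."
    )
```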

kyouma (Contributor, Author) commented Apr 1, 2026

I was trying to modify the diffusion equation inverse problem example by moving the unknown constant C from the PDE to the IC:

torch.sin(C * x[:, 0:1]) * torch.exp(-x[:, 1:2])

and found out that if I put it inside sin(), then, with caching on, an error about a second .backward() call appears (I am using PyTorch):

    * * *
  File "C:\Users\Yakov\Documents\GitHub\deepxde\deepxde\model.py", line 394, in closure
    total_loss.backward()

    * * *

  File "C:\Users\Yakov\.pyenv310\lib\site-packages\torch\autograd\graph.py", line 824, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
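
The mechanism behind this error can be reproduced outside DeepXDE; a minimal generic PyTorch sketch of mine, not DeepXDE code:

```python
import torch

C = torch.tensor(1.0, requires_grad=True)
cached = torch.sin(C)  # a "cached" value that still carries C's autograd graph

(cached ** 2).backward()  # the first backward pass frees the graph through sin()
try:
    (cached * 3.0).backward()  # reusing the cached value needs a second pass
except RuntimeError as err:
    print(err)  # "Trying to backward through the graph a second time ..."
```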

And if I place the unknown constant C (initial guess 1.0) into the bias of the IC:

torch.sin(torch.pi * x[:, 0:1]) * torch.exp(-x[:, 1:2]) + C

then it runs even with caching on but, as expected, C does not converge (it must become zero; the interleaved bracketed lines show the current value of C):

Step      Train loss                                  Test loss                                   Test metric
0 [1.00e+00]
1000      [5.31e-03, 1.87e-02, 4.47e-02, 4.10e-03]    [9.81e-03, 1.87e-02, 4.47e-02, 4.10e-03]    [9.68e-01]
1000 [5.47e-01]
2000      [5.48e-03, 1.63e-02, 1.87e-02, 1.98e-03]    [1.36e-02, 1.63e-02, 1.87e-02, 1.98e-03]    [9.56e-01]
2000 [7.74e-02]
3000      [1.40e-03, 9.87e-03, 7.07e-03, 1.76e-03]    [1.52e-02, 9.87e-03, 7.07e-03, 1.76e-03]    [9.09e-01]
3000 [-3.50e-01]
4000      [1.39e-03, 5.78e-03, 5.10e-03, 2.28e-03]    [2.98e-02, 5.78e-03, 5.10e-03, 2.28e-03]    [8.99e-01]
4000 [-6.98e-01]
5000      [3.71e-03, 3.69e-03, 3.36e-03, 1.79e-03]    [8.48e-02, 3.69e-03, 3.36e-03, 1.79e-03]    [8.82e-01]
5000 [-1.08e+00]
6000      [7.08e-04, 2.72e-03, 1.92e-03, 1.98e-03]    [1.35e-01, 2.72e-03, 1.92e-03, 1.98e-03]    [8.88e-01]
6000 [-1.53e+00]
...

When I disable caching for the IC, everything works:

Step      Train loss                                  Test loss                                   Test metric
0         [1.53e+01, 5.92e-02, 1.25e+00, 8.94e-02]    [1.64e+01, 5.92e-02, 1.25e+00, 8.94e-02]    [8.63e-01]
0 [1.00e+00]
1000      [4.23e-03, 5.93e-03, 1.73e-02, 8.38e-04]    [6.61e-03, 5.93e-03, 1.73e-02, 8.38e-04]    [5.84e-01]
1000 [6.11e-01]
2000      [7.02e-04, 2.42e-03, 3.20e-03, 2.24e-04]    [1.23e-03, 2.42e-03, 3.20e-03, 2.24e-04]    [3.43e-01]
2000 [3.59e-01]
3000      [4.45e-04, 7.93e-04, 4.27e-04, 3.41e-05]    [6.68e-04, 7.93e-04, 4.27e-04, 3.41e-05]    [1.72e-01]
3000 [1.83e-01]
4000      [1.21e-03, 2.48e-04, 5.16e-05, 1.15e-04]    [1.19e-03, 2.48e-04, 5.16e-05, 1.15e-04]    [8.70e-02]
4000 [9.01e-02]
5000      [7.69e-05, 5.23e-05, 1.97e-05, 1.35e-05]    [2.22e-04, 5.23e-05, 1.97e-05, 1.35e-05]    [4.34e-02]
5000 [4.62e-02]
6000      [6.58e-05, 1.82e-05, 1.14e-05, 6.68e-06]    [1.70e-04, 1.82e-05, 1.14e-05, 6.68e-06]    [2.46e-02]
6000 [2.67e-02]
7000      [5.18e-05, 9.94e-06, 8.40e-06, 4.30e-06]    [1.29e-04, 9.94e-06, 8.40e-06, 4.30e-06]    [1.73e-02]
7000 [1.86e-02]
...

And also for C inside sin() (the ground truth value is $\pi$):

Step      Train loss                                  Test loss                                   Test metric
0         [1.60e+01, 1.78e-02, 2.83e-01, 1.39e-01]    [1.72e+01, 1.78e-02, 2.83e-01, 1.39e-01]    [1.15e+00]
0 [1.00e+00]
1000      [4.98e-03, 6.52e-03, 7.09e-02, 8.99e-04]    [5.79e-03, 6.52e-03, 7.09e-02, 8.99e-04]    [1.19e-01]
1000 [1.49e+00]
2000      [4.84e-03, 3.14e-03, 2.14e-02, 3.07e-05]    [7.04e-03, 3.14e-03, 2.14e-02, 3.07e-05]    [9.88e-02]
2000 [2.28e+00]
3000      [2.21e-04, 5.63e-04, 6.48e-04, 1.02e-05]    [1.98e-03, 5.63e-04, 6.48e-04, 1.02e-05]    [4.36e-02]
3000 [2.90e+00]
4000      [6.63e-05, 1.55e-04, 4.20e-05, 2.23e-05]    [1.11e-03, 1.55e-04, 4.20e-05, 2.23e-05]    [2.07e-02]
4000 [3.04e+00]
5000      [4.39e-05, 8.11e-05, 1.13e-05, 1.48e-05]    [8.35e-04, 8.11e-05, 1.13e-05, 1.48e-05]    [1.51e-02]
5000 [3.07e+00]
6000      [3.38e-05, 5.56e-05, 6.44e-06, 8.92e-06]    [7.05e-04, 5.56e-05, 6.44e-06, 8.92e-06]    [1.28e-02]
6000 [3.08e+00]
7000      [7.29e-05, 4.36e-05, 5.76e-06, 5.87e-06]    [6.75e-04, 4.36e-05, 5.76e-06, 5.87e-06]    [1.16e-02]
7000 [3.09e+00]
...

I think that adding retain_graph=True to some .backward() calls (maybe jacobian() and hessian() are the culprits?) would be too difficult to control and test across the whole library, and the same might be needed for the other backends. Thus, disabling caching for some BCs and ICs is not only a question of convergence, but also a question of solvability for a lot of inverse problems.

I have also found discussion #1727, where the author uses the same trick to make the IC learnable.

echen5503 (Contributor):

When you say "When I disable caching for the IC," you mean that you set the depends_on_trainable_variables flag to True, right? Could you provide the code for the test?

kyouma (Contributor, Author) commented Apr 3, 2026

Yes, I do.

Here it is. I am planning to add this as an example of a trainable IC. It is the same problem as in the DeepXDE examples, but I have moved C from the PDE into the sin() inside the IC (so C must converge to $\pi$).

diffusion_1d_inverse_learnable_ic.py
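
For readers who cannot open the attachment, here is a rough sketch of the setup, not the attached script itself: it follows the official diffusion_1d inverse example (PyTorch backend), with C moved from the PDE into the IC and caching disabled via the new flag:

```python
import deepxde as dde
import numpy as np
import torch

C = dde.Variable(1.0)  # unknown constant inside the IC; must converge to pi


def pde(x, y):
    # C has been moved out of the PDE; the diffusion coefficient is fixed to 1.
    dy_t = dde.grad.jacobian(y, x, i=0, j=1)
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)
    return (
        dy_t
        - dy_xx
        + torch.exp(-x[:, 1:2])
        * (torch.sin(np.pi * x[:, 0:1]) - np.pi**2 * torch.sin(np.pi * x[:, 0:1]))
    )


def func(x):
    # Exact solution: used for the BC, the observation data, and the test metric.
    return np.sin(np.pi * x[:, 0:1]) * np.exp(-x[:, 1:2])


geom = dde.geometry.Interval(-1, 1)
timedomain = dde.geometry.TimeDomain(0, 1)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

bc = dde.icbc.DirichletBC(geomtime, func, lambda _, on_boundary: on_boundary)
# The IC depends on the trainable C, so its caching must be disabled:
ic = dde.icbc.IC(
    geomtime,
    lambda x: torch.sin(C * x[:, 0:1]) * torch.exp(-x[:, 1:2]),
    lambda _, on_initial: on_initial,
    depends_on_trainable_variables=True,
)

observe_x = np.vstack((np.linspace(-1, 1, num=10), np.full((10), 1))).T
observe_y = dde.icbc.PointSetBC(observe_x, func(observe_x), component=0)

data = dde.data.TimePDE(
    geomtime,
    pde,
    [bc, ic, observe_y],
    num_domain=40,
    num_boundary=20,
    num_initial=10,
    anchors=observe_x,
    solution=func,
    num_test=10000,
)
net = dde.nn.FNN([2] + [32] * 3 + [1], "tanh", "Glorot uniform")
model = dde.Model(data, net)

model.compile(
    "adam", lr=0.001, metrics=["l2 relative error"], external_trainable_variables=C
)
variable = dde.callbacks.VariableValue(C, period=1000)
model.train(iterations=20000, callbacks=[variable])
```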

echen5503 (Contributor) left a comment

Looks good, just a few small changes. Does this bug apply to PDE as well?

 if utils.get_num_args(func) == 1:
     return wrapper_cache
-if utils.get_num_args(func) == 2:
+elif utils.get_num_args(func) == 2:
Contributor:

What's the point of turning this into elif? Isn't this a common early return design pattern, where you don't need elif?

kyouma (Contributor, Author):

I thought that, because of the "...or disable_caching" part, there can be a situation when both outer conditions are True, and if the returns are changed into assignments in the future for some reason, a bug may be overlooked. That is why I decided to mark all these branches as mutually exclusive.
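A toy reconstruction of the dispatch shape, to make the concern concrete. The wrapper names follow the diff context above; the bodies and the exact condition layout are my assumptions, not the PR's code:

```python
import inspect


def npfunc_range_autocache_sketch(func, disable_caching=False):
    cache = {}

    def wrapper_cache(x):
        # Cached path: only safe if func's output carries no trainable variables.
        key = id(x)
        if key not in cache:
            cache[key] = func(x)
        return cache[key]

    def wrapper_nocache(x):
        return func(x)

    def wrapper_nocache_auxiliary(x, aux):
        return func(x, aux)

    num_args = len(inspect.signature(func).parameters)
    if num_args == 1:
        # In the real code one condition also includes `... or disable_caching`,
        # so two branches could be true at once; elif marks them as mutually
        # exclusive, so that if the returns ever become assignments, a later
        # branch cannot silently overwrite an earlier one.
        if disable_caching:
            return wrapper_nocache
        return wrapper_cache
    elif num_args == 2:
        return wrapper_nocache_auxiliary
```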

Contributor:

The outer elif seems necessary by this reasoning, but what about the inner elif?

kyouma (Contributor, Author):

For consistency and explicitness.

Contributor:

Ok, I guess it's just a style choice.

kyouma (Contributor, Author) commented Apr 3, 2026

As far as I understand, it only applies to the IC and to the BCs that wrap the func constructor argument in the caching wrapper. It causes either wrong gradients or an exception due to multiple backward() calls on the loss object stored in the cache.

echen5503 (Contributor):

That's my understanding too; the PDE classes don't seem to use the cache.

echen5503 (Contributor) left a comment

Need a few clarifications

depends_on_trainable_variables = False
self.depends_on_trainable_variables = depends_on_trainable_variables

self.boundary_normal = npfunc_range_autocache(
Contributor:

Aren't you forgetting to add the flag here?

kyouma (Contributor, Author):

If DeepXDE supported non-constant geometry, then it would indeed be safer, or even mandatory, to set the caching flag for the boundary normal calculation, too.

Correct me if I am wrong, but as far as I know, DeepXDE does not support such geometry: even parametrized geometry is reduced to some normalized constant geometry, with the parameters transferred into the functions.

If there is any concern that something may change in the future, it is better to apply the flag here, too, just as I have changed some lines to if...elif statements to make them more bug-proof against possible future changes.

Contributor:

It seems DeepXDE only supports constant geometry, and to my knowledge there is no plan to make it non-constant in the near future; the current focus is fixing existing geometry bugs.

kyouma (Contributor, Author):

Then it may be OK to leave the boundary normal calculation as it is now. I can add a comment to each creation of self.boundary_normal saying that, to support learnable boundaries, the "disable caching" flag must be passed inside it, too.

Or actually, because any changes for learnable-geometry support might be made in other places of the library and this comment might be overlooked, I will test how disabling caching for self.boundary_normal affects performance. If the slowdown is within ~5%, I will add the flag to self.boundary_normal, too.

kyouma (Contributor, Author):

Actually, this is a more difficult question. If learnable geometry is introduced, it is likely to be done via the settings of the geometry objects, so one could not pass a BC object's depends_on_trainable_variables down to manage the boundary normal caching. Thus, I have only added comments to warn future contributors.

Contributor:

Sounds good

if depends_on_trainable_variables is None:
    _warn_dependance_on_trainable_variables()
    depends_on_trainable_variables = False
self.depends_on_trainable_variables = depends_on_trainable_variables
Contributor:

The rationale behind this is just consistency, right? Because this value is not used.

kyouma (Contributor, Author):

Consistency and similar "interfaces" in the API docs are nice, but I have added it here mainly to make the code bug-proof against future changes, in case some type of caching is applied here, too. Apart from that, it is not necessary here at all; it is not even a style preference.

Contributor:

Sorry, when I said "style choice" I meant "design/robustness choice"

kyouma (Contributor, Author):

Right now there is indeed no need for these lines here (and in one or two other places where the BC does not use caching for func at all), but if someone wants to change this behavior in the future, they might be of help. As an alternative, I could write a comment about the necessity of adding the ability to disable caching if it is ever applied here, but I have decided to just implement part of this mechanism here instead.

Contributor:

I think adding a comment for the areas where the flag is added but isn't used is sufficient.

kyouma force-pushed the fix_caching_of_icbc_with_variables branch from 36cdc53 to b1e820b on April 4, 2026 at 19:09
depends_on_trainable_variables = False
self.depends_on_trainable_variables = depends_on_trainable_variables

# If learnable geometry is introduced, the boundary normal caching must be disabled
Contributor:

What do you mean by learnable geometry? Finding the shape of the domain for an inverse problem?

kyouma (Contributor, Author):

Yes. Or actually, I should have written "learnable or variable" (for example, a domain that does not depend on any trainable variables but still changes its shape or size over time). But the introduction of such things into DeepXDE seems to be a very distant prospect.

-def __init__(self, geom, func, on_boundary):
+def __init__(self, geom, func, on_boundary, depends_on_trainable_variables=None):
     super().__init__(geom, on_boundary, 0)
+    # `depends_on_trainable_variables` is here in order to be consistent
Contributor:

I agree with not putting a warning for those that don't use autocache.

kyouma (Contributor, Author) commented Apr 4, 2026

Should I change the default depends_on_trainable_variables argument value to False here and in other similar classes?
