-
Notifications
You must be signed in to change notification settings - Fork 959
Add physical context to examples, and use backend-agnostic functions #2066
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
vishnu-2206
wants to merge
7
commits into
lululxvi:master
Choose a base branch
from
vishnu-2206:doc-physics-diffusion-vsr
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 3 commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
1508cad
Refractor diffusion_1d example with physical context and descriptive …
vishnu-2206 64a176f
chore: use dde.backend for source term and revert docstring changes
vishnu-2206 f1fbe3a
fix: unify 1D diffusion suite with universal dde.backend and physics …
vishnu-2206 f355cb5
fix: refactor 1D diffusion example to be backend-agnostic using dde.b…
vishnu-2206 f0d08ae
fix: update exactBC.py with backend-agnostic transform and approved s…
vishnu-2206 5321b8f
add 1D heat equation example with adaptive resampling
vishnu-2206 b2f54be
Remove redundant backend labels and cleanup comments
CarrotBear-00 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -18,34 +18,16 @@ def pde(x, y): | |||
| # Backend jax | ||||
| # dy_t, _ = dde.grad.jacobian(y, x, i=0, j=1) | ||||
| # dy_xx, _ = dde.grad.hessian(y, x, i=0, j=0) | ||||
| # Backend tensorflow.compat.v1 or tensorflow | ||||
| # Cross-backend source term | ||||
| f = dde.backend.exp(-x[:, 1:]) * ( | ||||
| dde.backend.sin(np.pi * x[:, 0:1]) - np.pi**2 * dde.backend.sin(np.pi * x[:, 0:1]) | ||||
| ) | ||||
| # Backend tensorflow.compat.v1 or tensorflow, pytorch, jax, paddle | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| return ( | ||||
| dy_t | ||||
| - dy_xx | ||||
| + tf.exp(-x[:, 1:]) | ||||
| * (tf.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * tf.sin(np.pi * x[:, 0:1])) | ||||
| + f | ||||
| ) | ||||
| # Backend pytorch | ||||
| # return ( | ||||
| # dy_t | ||||
| # - dy_xx | ||||
| # + torch.exp(-x[:, 1:]) | ||||
| # * (torch.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * torch.sin(np.pi * x[:, 0:1])) | ||||
| # ) | ||||
| # Backend jax | ||||
| # return ( | ||||
| # dy_t | ||||
| # - dy_xx | ||||
| # + jnp.exp(-x[:, 1:]) | ||||
| # * (jnp.sin(np.pi * x[..., 0:1]) - np.pi ** 2 * jnp.sin(np.pi * x[..., 0:1])) | ||||
| # ) | ||||
| # Backend paddle | ||||
| # return ( | ||||
| # dy_t | ||||
| # - dy_xx | ||||
| # + paddle.exp(-x[:, 1:]) | ||||
| # * (paddle.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * paddle.sin(np.pi * x[:, 0:1])) | ||||
| # ) | ||||
|
|
||||
|
|
||||
| def func(x): | ||||
|
|
@@ -63,14 +45,8 @@ def func(x): | |||
| initializer = "Glorot uniform" | ||||
| net = dde.nn.FNN(layer_size, activation, initializer) | ||||
| net.apply_output_transform( | ||||
| # Backend tensorflow.compat.v1 or tensorflow | ||||
| lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + tf.sin(np.pi * x[:, 0:1]) | ||||
| # Backend pytorch | ||||
| # lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + torch.sin(np.pi * x[:, 0:1]) | ||||
| # Backend jax | ||||
| # lambda x, y: x[..., 1:2] * (1 - x[..., 0:1] ** 2) * y + jnp.sin(np.pi * x[..., 0:1]) | ||||
| # Backend paddle | ||||
| # lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + paddle.sin(np.pi * x[:, 0:1]) | ||||
| # This single line now works for TensorFlow, PyTorch, JAX, and Paddle | ||||
| lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + dde.backend.sin(np.pi * x[:, 0:1]) | ||||
| ) | ||||
|
|
||||
| model = dde.Model(data, net) | ||||
|
|
||||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -18,36 +18,18 @@ def pde(x, y): | |||
| # Backend jax | ||||
| # dy_t, _ = dde.grad.jacobian(y, x, i=0, j=1) | ||||
| # dy_xx, _ = dde.grad.hessian(y, x, i=0, j=0) | ||||
| # Backend tensorflow.compat.v1 or tensorflow | ||||
| # Cross-backend source term | ||||
| f = dde.backend.exp(-x[:, 1:]) * ( | ||||
| dde.backend.sin(np.pi * x[:, 0:1]) - np.pi**2 * dde.backend.sin(np.pi * x[:, 0:1]) | ||||
| ) | ||||
| # Backend tensorflow.compat.v1 or tensorflow, pytorch, jax, paddle | ||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| return ( | ||||
| dy_t | ||||
| - dy_xx | ||||
| + tf.exp(-x[:, 1:]) | ||||
| * (tf.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * tf.sin(np.pi * x[:, 0:1])) | ||||
| + f | ||||
| ) | ||||
| # Backend pytorch | ||||
| # return ( | ||||
| # dy_t | ||||
| # - dy_xx | ||||
| # + torch.exp(-x[:, 1:]) | ||||
| # * (torch.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * torch.sin(np.pi * x[:, 0:1])) | ||||
| # ) | ||||
| # Backend jax | ||||
| # return ( | ||||
| # dy_t | ||||
| # - dy_xx | ||||
| # + jnp.exp(-x[:, 1:]) | ||||
| # * (jnp.sin(np.pi * x[..., 0:1]) - np.pi ** 2 * jnp.sin(np.pi * x[..., 0:1])) | ||||
| # ) | ||||
| # Backend paddle | ||||
| # return ( | ||||
| # dy_t | ||||
| # - dy_xx | ||||
| # + paddle.exp(-x[:, 1:]) | ||||
| # * (paddle.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * paddle.sin(np.pi * x[:, 0:1])) | ||||
| # ) | ||||
|
|
||||
|
|
||||
|
|
||||
|
|
||||
| def func(x): | ||||
| return np.sin(np.pi * x[:, 0:1]) * np.exp(-x[:, 1:]) | ||||
|
|
||||
|
|
||||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to say this, because it is implied