Skip to content
32 changes: 7 additions & 25 deletions examples/pinn_forward/diffusion_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,16 @@ def pde(x, y):
# Backend jax
# dy_t, _ = dde.grad.jacobian(y, x, j=1)
# dy_xx, _ = dde.grad.hessian(y, x, j=0)
# Backend tensorflow.compat.v1 or tensorflow
# Cross-backend source term
f = dde.backend.exp(-x[:, 1:]) * (
dde.backend.sin(np.pi * x[:, 0:1]) - np.pi**2 * dde.backend.sin(np.pi * x[:, 0:1])
)
# Backend tensorflow.compat.v1 or tensorflow, pytorch, jax, paddle
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Backend tensorflow.compat.v1 or tensorflow, pytorch, jax, paddle

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to say this, because it is implied

return (
dy_t
- dy_xx
+ tf.exp(-x[:, 1:])
* (tf.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * tf.sin(np.pi * x[:, 0:1]))
)
# Backend pytorch
# return (
# dy_t
# - dy_xx
# + torch.exp(-x[:, 1:])
# * (torch.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * torch.sin(np.pi * x[:, 0:1]))
# )
# Backend jax
# return (
# dy_t
# - dy_xx
# + jnp.exp(-x[:, 1:])
# * (jnp.sin(np.pi * x[..., 0:1]) - np.pi ** 2 * jnp.sin(np.pi * x[..., 0:1]))
# )
# Backend paddle
# return (
# dy_t
# - dy_xx
# + paddle.exp(-x[:, 1:])
# * (paddle.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * paddle.sin(np.pi * x[:, 0:1]))
# )
+ f
)


def func(x):
Expand Down
40 changes: 8 additions & 32 deletions examples/pinn_forward/diffusion_1d_exactBC.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,16 @@ def pde(x, y):
# Backend jax
# dy_t, _ = dde.grad.jacobian(y, x, i=0, j=1)
# dy_xx, _ = dde.grad.hessian(y, x, i=0, j=0)
# Backend tensorflow.compat.v1 or tensorflow
# Cross-backend source term
f = dde.backend.exp(-x[:, 1:]) * (
dde.backend.sin(np.pi * x[:, 0:1]) - np.pi**2 * dde.backend.sin(np.pi * x[:, 0:1])
)
# Backend tensorflow.compat.v1 or tensorflow, pytorch, jax, paddle
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Backend tensorflow.compat.v1 or tensorflow, pytorch, jax, paddle

return (
dy_t
- dy_xx
+ tf.exp(-x[:, 1:])
* (tf.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * tf.sin(np.pi * x[:, 0:1]))
+ f
)
# Backend pytorch
# return (
# dy_t
# - dy_xx
# + torch.exp(-x[:, 1:])
# * (torch.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * torch.sin(np.pi * x[:, 0:1]))
# )
# Backend jax
# return (
# dy_t
# - dy_xx
# + jnp.exp(-x[:, 1:])
# * (jnp.sin(np.pi * x[..., 0:1]) - np.pi ** 2 * jnp.sin(np.pi * x[..., 0:1]))
# )
# Backend paddle
# return (
# dy_t
# - dy_xx
# + paddle.exp(-x[:, 1:])
# * (paddle.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * paddle.sin(np.pi * x[:, 0:1]))
# )


def func(x):
Expand All @@ -63,14 +45,8 @@ def func(x):
initializer = "Glorot uniform"
net = dde.nn.FNN(layer_size, activation, initializer)
net.apply_output_transform(
# Backend tensorflow.compat.v1 or tensorflow
lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + tf.sin(np.pi * x[:, 0:1])
# Backend pytorch
# lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + torch.sin(np.pi * x[:, 0:1])
# Backend jax
# lambda x, y: x[..., 1:2] * (1 - x[..., 0:1] ** 2) * y + jnp.sin(np.pi * x[..., 0:1])
# Backend paddle
# lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + paddle.sin(np.pi * x[:, 0:1])
# This single line now works for TensorFlow, PyTorch, JAX, and Paddle
lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + dde.backend.sin(np.pi * x[:, 0:1])
)

model = dde.Model(data, net)
Expand Down
34 changes: 8 additions & 26 deletions examples/pinn_forward/diffusion_1d_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,36 +18,18 @@ def pde(x, y):
# Backend jax
# dy_t, _ = dde.grad.jacobian(y, x, i=0, j=1)
# dy_xx, _ = dde.grad.hessian(y, x, i=0, j=0)
# Backend tensorflow.compat.v1 or tensorflow
# Cross-backend source term
f = dde.backend.exp(-x[:, 1:]) * (
dde.backend.sin(np.pi * x[:, 0:1]) - np.pi**2 * dde.backend.sin(np.pi * x[:, 0:1])
)
# Backend tensorflow.compat.v1 or tensorflow, pytorch, jax, paddle
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Backend tensorflow.compat.v1 or tensorflow, pytorch, jax, paddle

return (
dy_t
- dy_xx
+ tf.exp(-x[:, 1:])
* (tf.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * tf.sin(np.pi * x[:, 0:1]))
+ f
)
# Backend pytorch
# return (
# dy_t
# - dy_xx
# + torch.exp(-x[:, 1:])
# * (torch.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * torch.sin(np.pi * x[:, 0:1]))
# )
# Backend jax
# return (
# dy_t
# - dy_xx
# + jnp.exp(-x[:, 1:])
# * (jnp.sin(np.pi * x[..., 0:1]) - np.pi ** 2 * jnp.sin(np.pi * x[..., 0:1]))
# )
# Backend paddle
# return (
# dy_t
# - dy_xx
# + paddle.exp(-x[:, 1:])
# * (paddle.sin(np.pi * x[:, 0:1]) - np.pi ** 2 * paddle.sin(np.pi * x[:, 0:1]))
# )




def func(x):
    """Exact solution u(x, t) = sin(pi * x) * exp(-t) of the 1D diffusion problem.

    `x` is an (N, 2) array whose first column is the spatial coordinate and
    whose second column is time; returns an (N, 1) array of solution values.
    """
    space, time = x[:, 0:1], x[:, 1:]
    return np.exp(-time) * np.sin(np.pi * space)

Expand Down