Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions examples/wave_equation/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
## 1D Wave Equation
Given the wave equation:

$$ u_{tt} = c^2 u_{xx}, $$

with wave speed $c = 1$, domain $x \in [0, 1]$ and $t \in [0, 1]$, and the initial and boundary conditions:
$$u(0, x) = \sin(\pi x), $$
$$u_t(0, x) = 0, $$
$$u(t, 0) = 0,$$
$$u(t, 1) = 0.$$

The exact solution is $u(x, t) = \sin(\pi x) \cos(\pi t)$.

### Problem Setup

| 1D Wave Equation | |
|------------------------------|---|
Comment on lines +16 to +17
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The markdown table is not formatted consistently with other example READMEs: the rows start with || (double pipe), which renders as an extra empty column in many Markdown parsers. Use a standard table format (single leading | per row) like the other examples under examples/*/README.md.

Suggested change
| 1D Wave Equation | |
|------------------------------|---|
| Quantity | Value |
|------------------------------|-------------------------------|

Copilot uses AI. Check for mistakes.
| PDE equation | $f = u_{tt} - c^2 u_{xx}$ |
| Initial condition (displacement) | $u(0, x) = \sin(\pi x)$ |
| Initial condition (velocity) | $u_t(0, x) = 0$ |
| Dirichlet boundary conditions | $u(t, 0) = u(t, 1) = 0$ |
| The output of net | $[u(t, x)]$ |
| Layers of net | $[2] + 5 \times [50] + [1]$ |
| Sample count from collection points | $10000$ |
| Sample count from initial conditions | $100$ |
| Sample count from boundary conditions | $100$ |
| Loss function | $\text{MSE}_0 + \text{MSE}_b + \text{MSE}_c$ |
90 changes: 90 additions & 0 deletions examples/wave_equation/configs/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
defaults:
- train
- _self_

N0: 100
N_b: 100
N_f: 10_000

time_domain:
_target_: pinnstf2.data.TimeDomain
t_interval: [0, 1.0]
t_points: 201

spatial_domain:
_target_: pinnstf2.data.Interval
x_interval: [0, 1]
shape: [256, 1]

mesh:
_target_: pinnstf2.data.Mesh
root_dir: null
read_data_fn: ???
ub: [1.0, 1.0]
lb: [0.0, 0.0]

train_datasets:
- mesh_sampler:
_target_: pinnstf2.data.MeshSampler
_partial_: true
num_sample: ${N_f}
collection_points:
- f

- initial_condition:
_target_: pinnstf2.data.InitialCondition
_partial_: true
num_sample: ${N0}
solution:
- u
- v

- dirichlet_boundary_condition:
_target_: pinnstf2.data.DirichletBoundaryCondition
_partial_: true
num_sample: ${N_b}
solution:
- u

val_dataset:
- mesh_sampler:
_target_: pinnstf2.data.MeshSampler
_partial_: true
solution:
- u

pred_dataset:
- mesh_sampler:
_target_: pinnstf2.data.MeshSampler
_partial_: true
solution:
- u

net:
_target_: pinnstf2.models.FCN
layers: [2, 50, 50, 50, 50, 50, 1]
output_names:
- u

trainer:
max_epochs: 20000
check_val_every_n_epoch: 20001

model:
loss_fn: mse

train: true
val: true
test: false
optimized_metric:
error:
- u

plotting: null

seed: 1234
task_name: wave_equation

hydra:
searchpath:
- pkg://pinnstf2/conf
91 changes: 91 additions & 0 deletions examples/wave_equation/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from typing import Dict, Optional

import hydra
import numpy as np
import tensorflow as tf
from omegaconf import DictConfig

import pinnstf2


C = 1.0 # wave speed


def read_data_fn(root_path):
"""Generate analytical reference data for the 1D wave equation.

Exact solution: u(x, t) = sin(pi*x) * cos(pi*c*t)
satisfying u_tt = c^2 * u_xx with u(0,t) = u(1,t) = 0.

:param root_path: The root directory containing the data (unused for analytical data).
:return: Dictionary with exact solution array.
"""

nx, nt = 256, 201
x = np.linspace(0, 1, nx)
t = np.linspace(0, 1, nt)
X, T = np.meshgrid(x, t, indexing="ij")
exact_u = np.sin(np.pi * X) * np.cos(np.pi * C * T)
return {"u": exact_u}
Comment on lines +24 to +29
Copy link

Copilot AI Mar 11, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config.yaml requests initial-condition supervision for both u and v, but read_data_fn() only returns {"u": exact_u}. This will raise a KeyError when Mesh.on_initial_boundary() tries to slice self.solution["v"] for the initial condition sampler. Include an analytical v solution array in the returned dict (e.g., v = ∂u/∂t) or adjust the config to avoid requesting v from the mesh solution.

Copilot uses AI. Check for mistakes.


def pde_fn(
outputs: Dict[str, tf.Tensor],
x: tf.Tensor,
t: tf.Tensor,
):
"""Define the wave equation PDE residual: u_tt - c^2 * u_xx = 0.

:param outputs: Dictionary of network outputs.
:param x: Spatial coordinate tensor.
:param t: Temporal coordinate tensor.
:return: Updated outputs with PDE residual.
"""

u_x, u_t = pinnstf2.utils.gradient(outputs["u"], [x, t])
u_xx = pinnstf2.utils.gradient(u_x, x)
u_tt = pinnstf2.utils.gradient(u_t, t)
outputs["f"] = u_tt - C**2 * u_xx
return outputs


def output_fn(
outputs: Dict[str, tf.Tensor],
x: tf.Tensor,
t: tf.Tensor,
):
"""Compute velocity field for initial velocity enforcement.

:param outputs: Dictionary of network outputs.
:param x: Spatial coordinate tensor.
:param t: Temporal coordinate tensor.
:return: Updated outputs with velocity field.
"""

outputs["v"] = pinnstf2.utils.gradient(outputs["u"], t)
return outputs


@hydra.main(version_base="1.3", config_path="configs", config_name="config.yaml")
def main(cfg: DictConfig) -> Optional[float]:
"""Main entry point for training.

:param cfg: DictConfig configuration composed by Hydra.
:return: Optional[float] with optimized metric value.
"""

pinnstf2.utils.extras(cfg)

metric_dict, _ = pinnstf2.train(
cfg, read_data_fn=read_data_fn, pde_fn=pde_fn, output_fn=output_fn
)

metric_value = pinnstf2.utils.get_metric_value(
metric_dict=metric_dict, metric_names=cfg.get("optimized_metric")
)

return metric_value


if __name__ == "__main__":
main()
Loading