-
Notifications
You must be signed in to change notification settings - Fork 37
Add 1D wave equation example #7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| ## 1D Wave Equation | ||
| Given the wave equation: | ||
|
|
||
| $$ u_{tt} = c^2 u_{xx}, $$ | ||
|
|
||
| with wave speed $c = 1$, domain $x \in [0, 1]$ and $t \in [0, 1]$, and the initial and boundary conditions: | ||
| $$u(0, x) = \sin(\pi x), $$ | ||
| $$u_t(0, x) = 0, $$ | ||
| $$u(t, 0) = 0,$$ | ||
| $$u(t, 1) = 0.$$ | ||
|
|
||
| The exact solution is $u(x, t) = \sin(\pi x) \cos(\pi t)$. | ||
|
|
||
| ### Problem Setup | ||
|
|
||
| | 1D Wave Equation | | | ||
| |------------------------------|---| | ||
| | PDE equation | $f = u_{tt} - c^2 u_{xx}$ | | ||
| | Initial condition (displacement) | $u(0, x) = \sin(\pi x)$ | | ||
| | Initial condition (velocity) | $u_t(0, x) = 0$ | | ||
| | Dirichlet boundary conditions | $u(t, 0) = u(t, 1) = 0$ | | ||
| | The output of net | $[u(t, x)]$ | | ||
| | Layers of net | $[2] + 5 \times [50] + [1]$ | | ||
| | Sample count from collection points | $10000$ | | ||
| | Sample count from initial conditions | $100$ | | ||
| | Sample count from boundary conditions | $100$ | | ||
| | Loss function | $\text{MSE}_0 + \text{MSE}_b + \text{MSE}_c$ | | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| defaults: | ||
| - train | ||
| - _self_ | ||
|
|
||
| N0: 100 | ||
| N_b: 100 | ||
| N_f: 10_000 | ||
|
|
||
| time_domain: | ||
| _target_: pinnstf2.data.TimeDomain | ||
| t_interval: [0, 1.0] | ||
| t_points: 201 | ||
|
|
||
| spatial_domain: | ||
| _target_: pinnstf2.data.Interval | ||
| x_interval: [0, 1] | ||
| shape: [256, 1] | ||
|
|
||
| mesh: | ||
| _target_: pinnstf2.data.Mesh | ||
| root_dir: null | ||
| read_data_fn: ??? | ||
| ub: [1.0, 1.0] | ||
| lb: [0.0, 0.0] | ||
|
|
||
| train_datasets: | ||
| - mesh_sampler: | ||
| _target_: pinnstf2.data.MeshSampler | ||
| _partial_: true | ||
| num_sample: ${N_f} | ||
| collection_points: | ||
| - f | ||
|
|
||
| - initial_condition: | ||
| _target_: pinnstf2.data.InitialCondition | ||
| _partial_: true | ||
| num_sample: ${N0} | ||
| solution: | ||
| - u | ||
| - v | ||
|
|
||
| - dirichlet_boundary_condition: | ||
| _target_: pinnstf2.data.DirichletBoundaryCondition | ||
| _partial_: true | ||
| num_sample: ${N_b} | ||
| solution: | ||
| - u | ||
|
|
||
| val_dataset: | ||
| - mesh_sampler: | ||
| _target_: pinnstf2.data.MeshSampler | ||
| _partial_: true | ||
| solution: | ||
| - u | ||
|
|
||
| pred_dataset: | ||
| - mesh_sampler: | ||
| _target_: pinnstf2.data.MeshSampler | ||
| _partial_: true | ||
| solution: | ||
| - u | ||
|
|
||
| net: | ||
| _target_: pinnstf2.models.FCN | ||
| layers: [2, 50, 50, 50, 50, 50, 1] | ||
| output_names: | ||
| - u | ||
|
|
||
| trainer: | ||
| max_epochs: 20000 | ||
| check_val_every_n_epoch: 20001 | ||
|
|
||
| model: | ||
| loss_fn: mse | ||
|
|
||
| train: true | ||
| val: true | ||
| test: false | ||
| optimized_metric: | ||
| error: | ||
| - u | ||
|
|
||
| plotting: null | ||
|
|
||
| seed: 1234 | ||
| task_name: wave_equation | ||
|
|
||
| hydra: | ||
| searchpath: | ||
| - pkg://pinnstf2/conf |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| from typing import Dict, Optional | ||
|
|
||
| import hydra | ||
| import numpy as np | ||
| import tensorflow as tf | ||
| from omegaconf import DictConfig | ||
|
|
||
| import pinnstf2 | ||
|
|
||
|
|
||
| C = 1.0 # wave speed | ||
|
|
||
|
|
||
| def read_data_fn(root_path): | ||
| """Generate analytical reference data for the 1D wave equation. | ||
|
|
||
| Exact solution: u(x, t) = sin(pi*x) * cos(pi*c*t) | ||
| satisfying u_tt = c^2 * u_xx with u(0,t) = u(1,t) = 0. | ||
|
|
||
| :param root_path: The root directory containing the data (unused for analytical data). | ||
| :return: Dictionary with exact solution array. | ||
| """ | ||
|
|
||
| nx, nt = 256, 201 | ||
| x = np.linspace(0, 1, nx) | ||
| t = np.linspace(0, 1, nt) | ||
| X, T = np.meshgrid(x, t, indexing="ij") | ||
| exact_u = np.sin(np.pi * X) * np.cos(np.pi * C * T) | ||
| return {"u": exact_u} | ||
|
Comment on lines
+24
to
+29
|
||
|
|
||
|
|
||
| def pde_fn( | ||
| outputs: Dict[str, tf.Tensor], | ||
| x: tf.Tensor, | ||
| t: tf.Tensor, | ||
| ): | ||
| """Define the wave equation PDE residual: u_tt - c^2 * u_xx = 0. | ||
|
|
||
| :param outputs: Dictionary of network outputs. | ||
| :param x: Spatial coordinate tensor. | ||
| :param t: Temporal coordinate tensor. | ||
| :return: Updated outputs with PDE residual. | ||
| """ | ||
|
|
||
| u_x, u_t = pinnstf2.utils.gradient(outputs["u"], [x, t]) | ||
| u_xx = pinnstf2.utils.gradient(u_x, x) | ||
| u_tt = pinnstf2.utils.gradient(u_t, t) | ||
| outputs["f"] = u_tt - C**2 * u_xx | ||
| return outputs | ||
|
|
||
|
|
||
| def output_fn( | ||
| outputs: Dict[str, tf.Tensor], | ||
| x: tf.Tensor, | ||
| t: tf.Tensor, | ||
| ): | ||
| """Compute velocity field for initial velocity enforcement. | ||
|
|
||
| :param outputs: Dictionary of network outputs. | ||
| :param x: Spatial coordinate tensor. | ||
| :param t: Temporal coordinate tensor. | ||
| :return: Updated outputs with velocity field. | ||
| """ | ||
|
|
||
| outputs["v"] = pinnstf2.utils.gradient(outputs["u"], t) | ||
| return outputs | ||
|
|
||
|
|
||
| @hydra.main(version_base="1.3", config_path="configs", config_name="config.yaml") | ||
| def main(cfg: DictConfig) -> Optional[float]: | ||
| """Main entry point for training. | ||
|
|
||
| :param cfg: DictConfig configuration composed by Hydra. | ||
| :return: Optional[float] with optimized metric value. | ||
| """ | ||
|
|
||
| pinnstf2.utils.extras(cfg) | ||
|
|
||
| metric_dict, _ = pinnstf2.train( | ||
| cfg, read_data_fn=read_data_fn, pde_fn=pde_fn, output_fn=output_fn | ||
| ) | ||
|
|
||
| metric_value = pinnstf2.utils.get_metric_value( | ||
| metric_dict=metric_dict, metric_names=cfg.get("optimized_metric") | ||
| ) | ||
|
|
||
| return metric_value | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The markdown table is not formatted consistently with other example READMEs: the rows start with
||(double pipe), which renders as an extra empty column in many Markdown parsers. Use a standard table format (single leading|per row) like the other examples underexamples/*/README.md.