From 667a79b3f259c952d155f4a96a067100a9bcadfe Mon Sep 17 00:00:00 2001 From: oliver Date: Fri, 14 Apr 2023 13:58:05 -0700 Subject: [PATCH 1/6] fixed grid effect --- modulus/datapipes/benchmarks/darcy.py | 29 ++++++++++++++------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/modulus/datapipes/benchmarks/darcy.py b/modulus/datapipes/benchmarks/darcy.py index f69e1f2012..db3c7953a1 100644 --- a/modulus/datapipes/benchmarks/darcy.py +++ b/modulus/datapipes/benchmarks/darcy.py @@ -257,22 +257,23 @@ def generate_batch(self) -> None: if normalized_inf_residual < ( self.convergence_threshold * grid_reduction_factor ): - - # upsample to higher resolution - if grid_reduction_factor > 1: - wp.launch( - kernel=bilinear_upsample_batched_2d, - dim=self.dim, - inputs=[ - self.darcy0, - self.dim[1], - self.dim[2], - grid_reduction_factor, - ], - device=self.device, - ) break + # upsample to higher resolution + if grid_reduction_factor > 1: + wp.launch( + kernel=bilinear_upsample_batched_2d, + dim=self.dim, + inputs=[ + self.darcy0, + self.dim[1], + self.dim[2], + grid_reduction_factor, + ], + device=self.device, + ) + + def __iter__(self) -> Tuple[Tensor, Tensor]: """ Yields From a26f7e4d1b96ba0a863cedafd721987baa8bef14 Mon Sep 17 00:00:00 2001 From: Oliver Date: Fri, 18 Jul 2025 12:29:26 -0700 Subject: [PATCH 2/6] added transient mesh dataset --- CHANGELOG.md | 2 + physicsnemo/datapipes/cae/__init__.py | 1 + physicsnemo/datapipes/cae/mesh_datapipe.py | 131 ++---- physicsnemo/datapipes/cae/readers.py | 130 ++++++ .../datapipes/cae/transient_mesh_datapipe.py | 398 ++++++++++++++++++ pyproject.toml | 6 +- .../datapipes/test_transient_mesh_datapipe.py | 130 ++++++ 7 files changed, 698 insertions(+), 100 deletions(-) create mode 100644 physicsnemo/datapipes/cae/transient_mesh_datapipe.py create mode 100644 test/datapipes/test_transient_mesh_datapipe.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1dd4f90809..f5145b4bad 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Safe API to override `__init__`'s arguments saved in checkpoint file with `Module.from_checkpoint("chkpt.mdlus", models_args)`. - PyTorch Geometric MeshGraphNet backend. +- Transient Mesh Dataset. ### Changed @@ -27,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Existing DGL-based vortex shedding example has been renamed to `vortex_shedding_mgn_dgl`. Added new `vortex_shedding_mgn` example that uses PyTorch Geometric instead. 
- HEALPixLayer can now use earth2grid HEALPix padding ops, if desired +- Mesh Dataset supports vtm files ### Deprecated diff --git a/physicsnemo/datapipes/cae/__init__.py b/physicsnemo/datapipes/cae/__init__.py index c0d17ff723..2557c2093f 100644 --- a/physicsnemo/datapipes/cae/__init__.py +++ b/physicsnemo/datapipes/cae/__init__.py @@ -16,3 +16,4 @@ from .domino_datapipe import DoMINODataPipe from .mesh_datapipe import MeshDatapipe +from .transient_mesh_datapipe import TransientMeshDatapipe diff --git a/physicsnemo/datapipes/cae/mesh_datapipe.py b/physicsnemo/datapipes/cae/mesh_datapipe.py index 844bd474f7..a5f44e477b 100644 --- a/physicsnemo/datapipes/cae/mesh_datapipe.py +++ b/physicsnemo/datapipes/cae/mesh_datapipe.py @@ -17,7 +17,6 @@ import numpy as np import torch -import vtk try: import nvidia.dali as dali @@ -38,7 +37,14 @@ from physicsnemo.datapipes.datapipe import Datapipe from physicsnemo.datapipes.meta import DatapipeMetaData -from .readers import read_cgns, read_vtp, read_vtu +from .readers import ( + parse_vtk_polydata, + parse_vtk_unstructuredgrid, + read_cgns, + read_vtm, + read_vtp, + read_vtu, +) @dataclass @@ -57,7 +63,7 @@ class MeshDatapipe(Datapipe): Parameters ---------- data_dir : str - Directory where ERA5 data is stored + Directory where data is stored variables : List[str, None] Ordered list of variables to be loaded from the files num_variables : int @@ -70,8 +76,8 @@ class MeshDatapipe(Datapipe): If provided, the statistics are used to normalize the attributes batch_size : int, optional Batch size, by default 1 - num_steps : int, optional - Number of timesteps are included in the output variables, by default 1 + num_samples : int, optional + Number of samples to be loaded from the files, by default 1 shuffle : bool, optional Shuffle dataset, by default True num_workers : int, optional @@ -84,6 +90,20 @@ class MeshDatapipe(Datapipe): Number of training processes, by default 1 cache_data : False, optional Whether to cache the data in memory for faster access in subsequent epochs, by default False + + Note + ---- + The data is expected to be stored in the following format: + data_dir/ + ├── mesh_0001.vtp + ├── mesh_0002.vtp + └── ... + + The data is returned as a tuple of vertices, attributes, and edges. 
+ An example of the data output a tuple of tensors: + vertices: torch.Size([batch_size, num_vertices, dim]) + ux: torch.Size([batch_size, num_vertices, 1]) + edges: torch.Size([batch_size, num_edges, 2]) """ def __init__( @@ -150,6 +170,8 @@ def parse_dataset_files(self) -> None: pattern = "*.vtp" case "vtu": pattern = "*.vtu" + case "vtm": + pattern = "*.vtm" case "cgns": pattern = "*.cgns" case _: @@ -378,6 +400,8 @@ def mesh_reader(self): return read_vtp if self.file_format == "vtu": return read_vtu + if self.file_format == "vtm": + return read_vtm if self.file_format == "cgns": return read_cgns else: @@ -387,101 +411,10 @@ def mesh_reader(self): def parse_vtk_data(self): if self.file_format == "vtp": - return _parse_vtk_polydata - elif self.file_format in ["vtu", "cgns"]: - return _parse_vtk_unstructuredgrid + return parse_vtk_polydata + elif self.file_format in ["vtu", "cgns", "vtm"]: + return parse_vtk_unstructuredgrid else: raise NotImplementedError( f"Data type {self.file_format} is not supported yet" ) - - -def _parse_vtk_polydata(polydata, variables): - # Fetch vertices - points = polydata.GetPoints() - if points is None: - raise ValueError("Failed to get points from the polydata.") - vertices = torch.tensor( - np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]), - dtype=torch.float32, - ) - - # Fetch node attributes # TODO modularize - attributes = [] - point_data = polydata.GetPointData() - if point_data is None: - raise ValueError("Failed to get point data from the unstructured grid.") - for array_name in variables: - try: - array = point_data.GetArray(array_name) - except ValueError: - raise ValueError( - f"Failed to get array {array_name} from the unstructured grid." - ) - array_data = np.zeros( - (points.GetNumberOfPoints(), array.GetNumberOfComponents()) - ) - for j in range(points.GetNumberOfPoints()): - array.GetTuple(j, array_data[j]) - attributes.append(torch.tensor(array_data, dtype=torch.float32)) - attributes = torch.cat(attributes, dim=-1) - # TODO torch.cat is usually very inefficient when the number of items is large. - # If possible, the resulting tensor should be pre-allocated and filled in during the loop. - - # Fetch edges - polys = polydata.GetPolys() - if polys is None: - raise ValueError("Failed to get polygons from the polydata.") - polys.InitTraversal() - edges = [] - id_list = vtk.vtkIdList() - for _ in range(polys.GetNumberOfCells()): - polys.GetNextCell(id_list) - num_ids = id_list.GetNumberOfIds() - edges = [ - (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) - ] - edges = torch.tensor(edges, dtype=torch.long) - - return vertices, attributes, edges - - -def _parse_vtk_unstructuredgrid(grid, variables): - # Fetch vertices - points = grid.GetPoints() - if points is None: - raise ValueError("Failed to get points from the unstructured grid.") - vertices = torch.tensor( - np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]), - dtype=torch.float32, - ) - - # Fetch node attributes # TODO modularize - attributes = [] - point_data = grid.GetPointData() - if point_data is None: - raise ValueError("Failed to get point data from the unstructured grid.") - for array_name in variables: - try: - array = point_data.GetArray(array_name) - except ValueError: - raise ValueError( - f"Failed to get array {array_name} from the unstructured grid." 
- ) - array_data = np.zeros( - (points.GetNumberOfPoints(), array.GetNumberOfComponents()) - ) - for j in range(points.GetNumberOfPoints()): - array.GetTuple(j, array_data[j]) - attributes.append(torch.tensor(array_data, dtype=torch.float32)) - if variables: - attributes = torch.cat(attributes, dim=-1) - else: - attributes = torch.zeros((1,), dtype=torch.float32) - - # Return a dummy tensor of zeros for edges since they are not directly computable - return ( - vertices, - attributes, - torch.zeros((0, 2), dtype=torch.long), - ) # Dummy tensor for edges diff --git a/physicsnemo/datapipes/cae/readers.py b/physicsnemo/datapipes/cae/readers.py index b083f2e50d..4aa54b8f19 100644 --- a/physicsnemo/datapipes/cae/readers.py +++ b/physicsnemo/datapipes/cae/readers.py @@ -17,12 +17,104 @@ import os from typing import Any +import numpy as np import torch import vtk Tensor = torch.Tensor +def parse_vtk_polydata(polydata, variables): + # Fetch vertices + points = polydata.GetPoints() + if points is None: + raise ValueError("Failed to get points from the polydata.") + vertices = torch.tensor( + np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]), + dtype=torch.float32, + ) + + # Fetch node attributes # TODO modularize + attributes = [] + point_data = polydata.GetPointData() + if point_data is None: + raise ValueError("Failed to get point data from the unstructured grid.") + for array_name in variables: + try: + array = point_data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." + ) + array_data = np.zeros( + (points.GetNumberOfPoints(), array.GetNumberOfComponents()) + ) + for j in range(points.GetNumberOfPoints()): + array.GetTuple(j, array_data[j]) + attributes.append(torch.tensor(array_data, dtype=torch.float32)) + attributes = torch.cat(attributes, dim=-1) + # TODO torch.cat is usually very inefficient when the number of items is large. + # If possible, the resulting tensor should be pre-allocated and filled in during the loop. + + # Fetch edges + polys = polydata.GetPolys() + if polys is None: + raise ValueError("Failed to get polygons from the polydata.") + polys.InitTraversal() + edges = [] + id_list = vtk.vtkIdList() + for _ in range(polys.GetNumberOfCells()): + polys.GetNextCell(id_list) + num_ids = id_list.GetNumberOfIds() + edges = [ + (id_list.GetId(j), id_list.GetId((j + 1) % num_ids)) for j in range(num_ids) + ] + edges = torch.tensor(edges, dtype=torch.long) + + return vertices, attributes, edges + + +def parse_vtk_unstructuredgrid(grid, variables): + # Fetch vertices + points = grid.GetPoints() + if points is None: + raise ValueError("Failed to get points from the unstructured grid.") + vertices = torch.tensor( + np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]), + dtype=torch.float32, + ) + + # Fetch node attributes # TODO modularize + attributes = [] + point_data = grid.GetPointData() + if point_data is None: + raise ValueError("Failed to get point data from the unstructured grid.") + for array_name in variables: + try: + array = point_data.GetArray(array_name) + except ValueError: + raise ValueError( + f"Failed to get array {array_name} from the unstructured grid." 
+ ) + array_data = np.zeros( + (points.GetNumberOfPoints(), array.GetNumberOfComponents()) + ) + for j in range(points.GetNumberOfPoints()): + array.GetTuple(j, array_data[j]) + attributes.append(torch.tensor(array_data, dtype=torch.float32)) + if variables: + attributes = torch.cat(attributes, dim=-1) + else: + attributes = torch.zeros((1,), dtype=torch.float32) + + # Return a dummy tensor of zeros for edges since they are not directly computable + return ( + vertices, + attributes, + torch.zeros((0, 2), dtype=torch.long), + ) # Dummy tensor for edges + + def read_vtp(file_path: str) -> Any: # TODO add support for older format (VTK) """ Read a VTP file and return the polydata. @@ -95,6 +187,44 @@ def read_vtu(file_path: str) -> Any: return grid +def read_vtm(file_path: str) -> Any: + """ + Read a VTM (VTK MultiBlock) file and return the unstructured grid data. + + Parameters + ---------- + file_path : str + Path to the VTM file. + + Returns + ------- + vtkUnstructuredGrid + The unstructured grid data extracted from the multi-block dataset. + """ + # Check if file exists + if not os.path.exists(file_path): + raise FileNotFoundError(f"{file_path} does not exist.") + + # Check if file has .vtm extension + if not file_path.endswith(".vtm"): + raise ValueError(f"Expected a .vtm file, got {file_path}") + + # Create a VTM reader + reader = vtk.vtkXMLMultiBlockDataReader() + reader.SetFileName(file_path) + reader.Update() + + # Get the multi-block dataset + multi_block = reader.GetOutput() + + # Check if the multi-block dataset is valid + if multi_block is None: + raise ValueError(f"Failed to read multi-block data from {file_path}") + + # Extract and return the vtkUnstructuredGrid from the multi-block dataset + return _extract_unstructured_grid(multi_block) + + def read_cgns(file_path: str) -> Any: """ Read a CGNS file and return the unstructured grid data. diff --git a/physicsnemo/datapipes/cae/transient_mesh_datapipe.py b/physicsnemo/datapipes/cae/transient_mesh_datapipe.py new file mode 100644 index 0000000000..5905b80b94 --- /dev/null +++ b/physicsnemo/datapipes/cae/transient_mesh_datapipe.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +import torch +import vtk + +try: + import nvidia.dali as dali + import nvidia.dali.plugin.pytorch as dali_pth +except ImportError: + raise ImportError( + "DALI dataset requires NVIDIA DALI package to be installed. 
" + + "The package can be installed at:\n" + + "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html" + ) + +from dataclasses import dataclass +from pathlib import Path +from typing import Iterable, List, Tuple, Union + +from torch import Tensor + +from physicsnemo.datapipes.datapipe import Datapipe +from physicsnemo.datapipes.meta import DatapipeMetaData + +from .readers import parse_vtk_polydata, parse_vtk_unstructuredgrid, read_cgns, read_vtp, read_vtu, read_vtm + + +@dataclass +class MetaData(DatapipeMetaData): + name: str = "TransientMeshDatapipe" + # Optimization + auto_device: bool = True + cuda_graphs: bool = True + # Parallel + ddp_sharding: bool = True + + +class TransientMeshDatapipe(Datapipe): + """DALI data pipeline for transient mesh data + + The data is expected to be stored in the following format: + data_dir/ + ├── sim_0001/ + │ ├── mesh_0001.vtm + │ ├── mesh_0002.vtm + │ └── ... + ├── sim_0002/ + │ ├── mesh_0001.vtm + │ ├── mesh_0002.vtm + │ └── ... + └── ... + The data is returned as a tuple of vertices, attributes, and edges. + An example of the data output a tuple of tensors: + vertices : torch.Size([batch_size, sequence_length, num_vertices, dim]) + ux : torch.Size([batch_size, sequence_length, num_vertices, 1]) + edges : torch.Size([batch_size, sequence_length, num_edges, 2]) + + Parameters + ---------- + data_dir : str + Root directory containing sub-folders for each simulation run. + variables : List[str] + Ordered list of variable names to read from each mesh file. + num_variables : int + Number of variables (channels) expected in ``variables``. + file_format : str, optional + Mesh file format, by default "vtm". Supported formats: "vtm", "vtp", "vtu", "cgns". + stats_dir : Union[str, None], optional + Directory holding ``global_means.npy`` and ``global_stds.npy`` files used for normalisation. + sequence_length : int, optional + Number of consecutive timesteps returned in each sample sequence, by default ``2``. + batch_size : int, optional + Samples per batch, by default ``1``. + shuffle : bool, optional + Shuffle sequences each epoch, by default ``True``. + num_workers : int, optional + Number of Python workers used by DALI external source, by default ``1``. + device : Union[str, torch.device], optional + Device on which the DALI pipeline runs, by default the first CUDA device. + process_rank : int, optional + Local rank id when using distributed training, by default ``0``. + world_size : int, optional + Total number of distributed processes, by default ``1``. + cache_data : bool, optional + Whether to cache parsed mesh data in memory, by default ``False``. 
+ """ + + def __init__( + self, + data_dir: str, + variables: List[str], + num_variables: int, + file_format: str = "vtp", + stats_dir: Union[str, None] = None, + sequence_length: int = 2, + batch_size: int = 1, + shuffle: bool = True, + num_workers: int = 1, + device: Union[str, torch.device] = "cuda", + process_rank: int = 0, + world_size: int = 1, + cache_data: bool = False, + ): + super().__init__(meta=MetaData()) + self.file_format = file_format + self.variables = variables + self.num_variables = num_variables + self.sequence_length = sequence_length + self.batch_size = batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.data_dir = Path(data_dir) + self.stats_dir = Path(stats_dir) if stats_dir is not None else None + self.process_rank = process_rank + self.world_size = world_size + self.cache_data = cache_data + + # if self.batch_size > 1: + # raise NotImplementedError("Batch size greater than 1 is not supported yet") + + # Set up device, needed for pipeline + if isinstance(device, str): + device = torch.device(device) + # Need a index id if cuda + if device.type == "cuda" and device.index is None: + device = torch.device("cuda:0") + self.device = device + + # check root directory exists + if not self.data_dir.is_dir(): + raise IOError(f"Error, data directory {self.data_dir} does not exist") + + self.parse_dataset_files() + self.load_statistics() + + self.pipe = self._create_pipeline() + + def parse_dataset_files(self) -> None: + """Parses the data directory and builds a list of fixed-length sequences. + + Each sub-directory inside ``data_dir`` is assumed to correspond to one + simulation run that contains an ordered series of mesh files. For a + chosen ``sequence_length`` this routine creates all sliding-window + sequences and stores them in ``self.sequence_paths``. + """ + # Determine file glob pattern + match self.file_format: + case "vtp": + pattern = "*.vtp" + case "vtu": + pattern = "*.vtu" + case "vtm": + pattern = "*.vtm" + case "cgns": + pattern = "*.cgns" + case _: + raise NotImplementedError( + f"Data type {self.file_format} is not supported yet" + ) + + # Build the list of sequences. + self.sequence_paths: List[List[str]] = [] + sim_dirs = [p for p in sorted(self.data_dir.iterdir()) if p.is_dir()] + + # Fallback: if no sub-directories are present but the current directory already + # contains mesh files matching the pattern, treat *data_dir* itself as a single + # simulation folder so that users can point the datapipe directly at one run. + if not sim_dirs: + raise IOError( + f"No mesh files matching '{pattern}' found in {self.data_dir} and no sub-directories present." + ) + + for sim_dir in sim_dirs: + files = sorted(str(fp) for fp in sim_dir.glob(pattern)) + if len(files) < self.sequence_length: + self.logger.warning( + f"Skipping {sim_dir} – only {len(files)} files but sequence_length={self.sequence_length}" + ) + continue + for i in range(len(files) - self.sequence_length + 1): + self.sequence_paths.append(files[i : i + self.sequence_length]) + + self.logger.info(f"Total number of sequences: {len(self.sequence_paths)}") + + def load_statistics( + self, + ) -> None: # TODO generalize and combine with climate/era5_hdf5 datapipes + """Loads statistics from pre-computed numpy files + + The statistic files should be of name global_means.npy and global_std.npy with + a shape of [1, C] located in the stat_dir. 
+ + Raises + ------ + IOError + If mean or std numpy files are not found + AssertionError + If loaded numpy arrays are not of correct size + """ + # If no stats dir we just skip loading the stats + if self.stats_dir is None: + self.mu = None + self.std = None + return + # load normalisation values + mean_stat_file = self.stats_dir / Path("global_means.npy") + std_stat_file = self.stats_dir / Path("global_stds.npy") + + if not mean_stat_file.exists(): + raise IOError(f"Mean statistics file {mean_stat_file} not found") + if not std_stat_file.exists(): + raise IOError(f"Std statistics file {std_stat_file} not found") + + # has shape [1, C] + self.mu = np.load(str(mean_stat_file))[:, 0 : self.num_variables] + # has shape [1, C] + self.std = np.load(str(std_stat_file))[:, 0 : self.num_variables] + + if not self.mu.shape == self.std.shape == (1, self.num_variables): + raise AssertionError("Error, normalisation arrays have wrong shape") + + def _create_pipeline(self) -> dali.Pipeline: + """Create DALI pipeline + + Returns + ------- + dali.Pipeline + Mesh DALI pipeline + """ + pipe = dali.Pipeline( + batch_size=self.batch_size, + num_threads=2, + prefetch_queue_depth=2, + py_num_workers=self.num_workers, + device_id=self.device.index, + py_start_method="spawn", + ) + + with pipe: + source = TransientMeshDaliExternalSource( + sequence_paths=self.sequence_paths, + file_format=self.file_format, + variables=self.variables, + batch_size=self.batch_size, + shuffle=self.shuffle, + process_rank=self.process_rank, + world_size=self.world_size, + cache_data=self.cache_data, + ) + # Update length of dataset + self.length = len(source) // self.batch_size + # Read current batch. + vertices, attributes, edges = dali.fn.external_source( + source, + num_outputs=3, + parallel=True, + batch=False, + ) + + if self.device.type == "cuda": + # Move tensors to GPU as external_source won't do that. + vertices = vertices.gpu() + attributes = attributes.gpu() + edges = edges.gpu() + + # Normalize attributes if statistics are available. + if self.stats_dir is not None: + attributes = dali.fn.normalize(attributes, mean=self.mu, stddev=self.std) + + # Set outputs. + pipe.set_outputs(vertices, attributes, edges) + + return pipe + + def __iter__(self): + # Reset the pipeline before creating an iterator to enable epochs. + self.pipe.reset() + # Create DALI PyTorch iterator. + return dali_pth.DALIGenericIterator([self.pipe], ["vertices", "x", "edges"]) + + def __len__(self): + return self.length + + +class TransientMeshDaliExternalSource: + """DALI external source that yields fixed-length sequences of mesh data.""" + + def __init__( + self, + sequence_paths: Iterable[Iterable[str]], + file_format: str, + variables: List[str], + batch_size: int = 1, + shuffle: bool = True, + process_rank: int = 0, + world_size: int = 1, + cache_data: bool = False, + ): + self.sequence_paths = list(sequence_paths) + self.file_format = file_format + self.variables = variables + self.batch_size = batch_size + self.shuffle = shuffle + self.cache_data = cache_data + + self.last_epoch = None + + # Shard indices if running in parallel (e.g. DDP). + all_indices = np.arange(len(self.sequence_paths)) + self.indices = np.array_split(all_indices, world_size)[process_rank] + + # Number of full batches (DALI does not support incomplete batches in parallel mode). + self.num_batches = len(self.indices) // self.batch_size + + # Helpers for reading / parsing single mesh files. 
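+        # Hypothetical clarifying note: mesh_reader() selects the file reader for
+        # self.file_format (read_vtp, read_vtu, read_vtm or read_cgns), and
+        # parse_vtk_data() selects the matching parser that converts the VTK object
+        # into (vertices, attributes, edges) tensors.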
+ self.mesh_reader_fn = self.mesh_reader() + self.parse_vtk_data_fn = self.parse_vtk_data() + + # Optional in-memory cache keyed by absolute file path. + if self.cache_data: + unique_files = {fp for seq in self.sequence_paths for fp in seq} + self.data_cache = {fp: None for fp in unique_files} + + def __call__(self, sample_info: dali.types.SampleInfo) -> Tuple[Tensor, Tensor, Tensor]: + if sample_info.iteration >= self.num_batches: + raise StopIteration() + + # Epoch-wise shuffling. + if self.shuffle and sample_info.epoch_idx != self.last_epoch: + np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices) + self.last_epoch = sample_info.epoch_idx + + idx = self.indices[sample_info.idx_in_epoch] + seq_files = self.sequence_paths[idx] + + vertices_seq, attributes_seq, edges_seq = [], [], [] + for fp in seq_files: + if self.cache_data: + cached = self.data_cache.get(fp) + if cached is None: + data = self.mesh_reader_fn(fp) + cached = self.parse_vtk_data_fn(data, self.variables) + self.data_cache[fp] = cached + v, a, e = cached + else: + v, a, e = self.parse_vtk_data_fn(self.mesh_reader_fn(fp), self.variables) + vertices_seq.append(v) + attributes_seq.append(a) + edges_seq.append(e) + + vertices = torch.stack(vertices_seq, dim=0) + attributes = torch.stack(attributes_seq, dim=0) + edges = torch.stack(edges_seq, dim=0) + + return vertices, attributes, edges + + def __len__(self): + return len(self.indices) + + def mesh_reader(self): + if self.file_format == "vtp": + return read_vtp + if self.file_format == "vtu": + return read_vtu + if self.file_format == "vtm": + return read_vtm + if self.file_format == "cgns": + return read_cgns + else: + raise NotImplementedError( + f"Data type {self.file_format} is not supported yet" + ) + + def parse_vtk_data(self): + if self.file_format == "vtp": + return parse_vtk_polydata + elif self.file_format in ["vtu", "cgns", "vtm"]: + return parse_vtk_unstructuredgrid + else: + raise NotImplementedError( + f"Data type {self.file_format} is not supported yet" + ) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 0c1957bdad..c1449e07a5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ makani = [ fignet = [ "jaxtyping>=0.2", - "torch_scatter>=2.1", + #"torch_scatter>=2.1", "torchinfo>=1.8", "warp-lang>=1.0", "webdataset>=0.2", @@ -107,6 +107,10 @@ all = [ ] +[tool.uv] +no-build-isolation-package = ["physicsnemo", "torch_scatter"] + + [tool.setuptools.dynamic] version = {attr = "physicsnemo.__version__"} diff --git a/test/datapipes/test_transient_mesh_datapipe.py b/test/datapipes/test_transient_mesh_datapipe.py new file mode 100644 index 0000000000..d6bc7e81c4 --- /dev/null +++ b/test/datapipes/test_transient_mesh_datapipe.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import random +from pathlib import Path + +import pytest +from pytest_utils import import_or_fail + + +@import_or_fail(["vtk"]) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("file_format", ["vtp", "vtu"]) +def test_transient_mesh_datapipe(device, file_format, tmp_path, pytestconfig): + """Smoke-tests the TransientMeshDatapipe with a synthetic VTP time-sequence.""" + + import vtk + + from physicsnemo.datapipes.cae import TransientMeshDatapipe + + def _write_random_mesh(num_points: int, num_triangles: int, out_file: Path): + """Create a random VTP or VTU mesh depending on file extension.""" + # Create random points + points = vtk.vtkPoints() + for _ in range(num_points): + x, y, z = ( + random.uniform(-10, 10), + random.uniform(-10, 10), + random.uniform(-10, 10), + ) + points.InsertNextPoint(x, y, z) + + # Create triangles + triangles = vtk.vtkCellArray() + for _ in range(num_triangles): + p1, p2, p3 = ( + random.randint(0, num_points - 1), + random.randint(0, num_points - 1), + random.randint(0, num_points - 1), + ) + triangle = vtk.vtkTriangle() + triangle.GetPointIds().SetId(0, p1) + triangle.GetPointIds().SetId(1, p2) + triangle.GetPointIds().SetId(2, p3) + triangles.InsertNextCell(triangle) + + # Attribute array + scalars = vtk.vtkDoubleArray() + scalars.SetName("RandomFeatures") + for _ in range(num_points): + scalars.InsertNextValue(random.uniform(0, 1)) + + if out_file.suffix == ".vtp": + poly = vtk.vtkPolyData() + poly.SetPoints(points) + poly.SetPolys(triangles) + poly.GetPointData().SetScalars(scalars) + writer = vtk.vtkXMLPolyDataWriter() + writer.SetFileName(str(out_file)) + writer.SetInputData(poly) + writer.Write() + else: + grid = vtk.vtkUnstructuredGrid() + grid.SetPoints(points) + grid.SetCells(vtk.VTK_TRIANGLE, triangles) + grid.GetPointData().SetScalars(scalars) + writer = vtk.vtkXMLUnstructuredGridWriter() + writer.SetFileName(str(out_file)) + writer.SetInputData(grid) + writer.Write() + + # ------------------------------------------------------------------ + # Build temporary dataset: 1 simulation directory, 3 timesteps. + # ------------------------------------------------------------------ + root_dir = tmp_path / "dataset" + sim_dir = root_dir / "simulation_000" + sim_dir.mkdir(parents=True) + + for step in range(3): # Need >= sequence_length (2) + 1 + file_path = sim_dir / f"mesh_{step:04d}.{file_format}" + _write_random_mesh(num_points=10, num_triangles=20, out_file=file_path) + + # ------------------------------------------------------------------ + # Instantiate the datapipe and validate basic behaviour. + # ------------------------------------------------------------------ + sequence_length = 2 + dp = TransientMeshDatapipe( + data_dir=root_dir, + variables=["RandomFeatures"], + num_variables=1, + file_format=file_format, + sequence_length=sequence_length, + batch_size=1, + shuffle=False, + num_workers=1, + device=device, + ) + + # There are 3 files → (3 - 2 + 1) = 2 sequences. + assert len(dp) == 2 + + for batch in dp: + sample = batch[0] + vertices = sample["vertices"] + x = sample["x"] + edges = sample["edges"] + + # Expected shapes: (B=1, S, V, ...) + assert vertices.shape[:2] == (1, sequence_length) + assert vertices.shape[-1] == 3 + assert x.shape[:3] == (1, sequence_length, 10) + assert x.shape[-1] == 1 + # Edges tensor last dim must be 2 + assert edges.shape[-1] == 2 + + # Only iterate first batch for speed. 
+ break \ No newline at end of file From 2cf6876c758aca838d240f2cc5c53d9cb1c53f4d Mon Sep 17 00:00:00 2001 From: Oliver Date: Fri, 18 Jul 2025 12:31:57 -0700 Subject: [PATCH 3/6] fixed darcy --- physicsnemo/datapipes/benchmarks/darcy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/physicsnemo/datapipes/benchmarks/darcy.py b/physicsnemo/datapipes/benchmarks/darcy.py index 025d24e7a0..e85e55e9f4 100644 --- a/physicsnemo/datapipes/benchmarks/darcy.py +++ b/physicsnemo/datapipes/benchmarks/darcy.py @@ -273,7 +273,6 @@ def generate_batch(self) -> None: ], device=self.device, ) - def __iter__(self) -> Tuple[Tensor, Tensor]: """ From f4a1dc7e7a60ea2b4bce06a0e412ef2adbaa579e Mon Sep 17 00:00:00 2001 From: Oliver Date: Fri, 18 Jul 2025 12:34:11 -0700 Subject: [PATCH 4/6] udid comment --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c1449e07a5..27d90ad304 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ makani = [ fignet = [ "jaxtyping>=0.2", - #"torch_scatter>=2.1", + "torch_scatter>=2.1", "torchinfo>=1.8", "warp-lang>=1.0", "webdataset>=0.2", From a7ee016cb0af3b3793d02ab2bc8e1938ccfa03a3 Mon Sep 17 00:00:00 2001 From: Oliver Date: Fri, 18 Jul 2025 12:40:19 -0700 Subject: [PATCH 5/6] not sure --- physicsnemo/datapipes/cae/domino_datapipe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/physicsnemo/datapipes/cae/domino_datapipe.py b/physicsnemo/datapipes/cae/domino_datapipe.py index ad83331f24..84334a7cd6 100644 --- a/physicsnemo/datapipes/cae/domino_datapipe.py +++ b/physicsnemo/datapipes/cae/domino_datapipe.py @@ -34,7 +34,10 @@ from pathlib import Path from typing import Literal, Optional, Protocol, Sequence, Union -import cuml +try: + import cuml +except ImportError: + pass # TODO: Fix dependency (ohennigh) import cupy as cp import numpy as np import torch From 5abec54a8106eff612a23503f4df1570a9c62920 Mon Sep 17 00:00:00 2001 From: Oliver Date: Mon, 21 Jul 2025 10:21:59 -0700 Subject: [PATCH 6/6] fixing ci --- .../datapipes/cae/transient_mesh_datapipe.py | 24 ++++++++++++++----- .../datapipes/test_transient_mesh_datapipe.py | 2 +- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/physicsnemo/datapipes/cae/transient_mesh_datapipe.py b/physicsnemo/datapipes/cae/transient_mesh_datapipe.py index 5905b80b94..3e76526aaf 100644 --- a/physicsnemo/datapipes/cae/transient_mesh_datapipe.py +++ b/physicsnemo/datapipes/cae/transient_mesh_datapipe.py @@ -17,7 +17,6 @@ import numpy as np import torch -import vtk try: import nvidia.dali as dali @@ -38,7 +37,14 @@ from physicsnemo.datapipes.datapipe import Datapipe from physicsnemo.datapipes.meta import DatapipeMetaData -from .readers import parse_vtk_polydata, parse_vtk_unstructuredgrid, read_cgns, read_vtp, read_vtu, read_vtm +from .readers import ( + parse_vtk_polydata, + parse_vtk_unstructuredgrid, + read_cgns, + read_vtm, + read_vtp, + read_vtu, +) @dataclass @@ -281,7 +287,9 @@ def _create_pipeline(self) -> dali.Pipeline: # Normalize attributes if statistics are available. if self.stats_dir is not None: - attributes = dali.fn.normalize(attributes, mean=self.mu, stddev=self.std) + attributes = dali.fn.normalize( + attributes, mean=self.mu, stddev=self.std + ) # Set outputs. 
pipe.set_outputs(vertices, attributes, edges) @@ -337,7 +345,9 @@ def __init__( unique_files = {fp for seq in self.sequence_paths for fp in seq} self.data_cache = {fp: None for fp in unique_files} - def __call__(self, sample_info: dali.types.SampleInfo) -> Tuple[Tensor, Tensor, Tensor]: + def __call__( + self, sample_info: dali.types.SampleInfo + ) -> Tuple[Tensor, Tensor, Tensor]: if sample_info.iteration >= self.num_batches: raise StopIteration() @@ -359,7 +369,9 @@ def __call__(self, sample_info: dali.types.SampleInfo) -> Tuple[Tensor, Tensor, self.data_cache[fp] = cached v, a, e = cached else: - v, a, e = self.parse_vtk_data_fn(self.mesh_reader_fn(fp), self.variables) + v, a, e = self.parse_vtk_data_fn( + self.mesh_reader_fn(fp), self.variables + ) vertices_seq.append(v) attributes_seq.append(a) edges_seq.append(e) @@ -395,4 +407,4 @@ def parse_vtk_data(self): else: raise NotImplementedError( f"Data type {self.file_format} is not supported yet" - ) \ No newline at end of file + ) diff --git a/test/datapipes/test_transient_mesh_datapipe.py b/test/datapipes/test_transient_mesh_datapipe.py index d6bc7e81c4..69034ca0d7 100644 --- a/test/datapipes/test_transient_mesh_datapipe.py +++ b/test/datapipes/test_transient_mesh_datapipe.py @@ -127,4 +127,4 @@ def _write_random_mesh(num_points: int, num_triangles: int, out_file: Path): assert edges.shape[-1] == 2 # Only iterate first batch for speed. - break \ No newline at end of file + break
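
A minimal, hypothetical usage sketch of the TransientMeshDatapipe introduced in PATCH 2/6; the
data directory, variable name and stats directory below are placeholders, while the constructor
arguments, output keys ("vertices", "x", "edges") and shapes follow the class docstring and the
new unit test:

    from physicsnemo.datapipes.cae import TransientMeshDatapipe

    # Each sub-directory of data_dir holds one simulation run as an ordered series of .vtm files.
    datapipe = TransientMeshDatapipe(
        data_dir="./simulations",      # placeholder path
        variables=["ux"],              # placeholder point-data array name stored on the meshes
        num_variables=1,
        file_format="vtm",
        sequence_length=2,             # sliding windows of 2 consecutive timesteps
        batch_size=1,
        shuffle=True,
        device="cuda",
        stats_dir=None,                # or a directory with global_means.npy / global_stds.npy of shape [1, C]
    )

    for batch in datapipe:
        sample = batch[0]
        vertices = sample["vertices"]  # [batch, sequence, num_vertices, 3]
        fields = sample["x"]           # [batch, sequence, num_vertices, num_variables]
        edges = sample["edges"]        # [batch, sequence, num_edges, 2]
        break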