diff --git a/.gitignore b/.gitignore index 4f7110253..f64f9d000 100644 --- a/.gitignore +++ b/.gitignore @@ -182,3 +182,4 @@ examples/tutorials/z_other_tutorials/json_io/multiple_series_faults_computed.jso examples/tutorials/z_other_tutorials/json_io/combination_model.json examples/tutorials/z_other_tutorials/json_io/combination_model_computed.json /test/temp/ +test/test_modules/run_test.py diff --git a/gempy/modules/mesh_extranction/marching_cubes.py b/gempy/modules/mesh_extranction/marching_cubes.py index c992f35f1..41d6d9def 100644 --- a/gempy/modules/mesh_extranction/marching_cubes.py +++ b/gempy/modules/mesh_extranction/marching_cubes.py @@ -1,3 +1,4 @@ +import os import numpy as np from typing import Optional from skimage import measure @@ -72,6 +73,14 @@ def extract_mesh_for_element(structural_element: StructuralElement, mask : np.ndarray, optional Optional mask to restrict the mesh extraction to specific regions. """ + if type(scalar_field).__module__ == 'torch': + import torch + scalar_field = torch.to_numpy(scalar_field) + if type(mask).__module__ == "torch": + import torch + mask = torch.to_numpy(mask) + + # Extract mesh using marching cubes verts, faces, _, _ = measure.marching_cubes( volume=scalar_field.reshape(regular_grid.resolution), diff --git a/test/test_modules/test_marching_cubes_pytorch.py b/test/test_modules/test_marching_cubes_pytorch.py new file mode 100644 index 000000000..7e3b8cdcd --- /dev/null +++ b/test/test_modules/test_marching_cubes_pytorch.py @@ -0,0 +1,64 @@ +""" +Copied from "test_marching_cubes.py" to test the pytorch implementation of marching cubes with minor adjustments +""" + +import os + +os.environ["DEFAULT_BACKEND"] = "PYTORCH" + +import numpy as np +from gempy_engine.core.data.raw_arrays_solution import RawArraysSolution + +import gempy as gp +from gempy.core.data.enumerators import ExampleModel +from gempy.core.data.grid_modules import RegularGrid +from gempy.modules.mesh_extranction import marching_cubes +from gempy.optional_dependencies import require_gempy_viewer + +PLOT = True + + +def test_marching_cubes_implementation(): + assert os.environ["DEFAULT_BACKEND"] == "PYTORCH" + model = gp.generate_example_model(ExampleModel.COMBINATION, compute_model=False) + + # Change the grid to only be the dense grid + dense_grid: RegularGrid = RegularGrid( + extent=model.grid.extent, + resolution=np.array([40, 20, 20]) + ) + + model.grid.dense_grid = dense_grid + gp.set_active_grid( + grid=model.grid, + grid_type=[model.grid.GridTypes.DENSE], + reset=True + ) + model.interpolation_options = gp.data.InterpolationOptions.init_dense_grid_options() + gp.compute_model(model) + + # Assert + assert model.solutions.block_solution_type == RawArraysSolution.BlockSolutionType.DENSE_GRID + assert model.solutions.dc_meshes is None + arrays = model.solutions.raw_arrays # * arrays is equivalent to gempy v2 solutions + + # assert arrays.scalar_field_matrix.shape == (3, 8_000) # * 3 surfaces, 8000 points + + marching_cubes.set_meshes_with_marching_cubes(model) + + # Assert + assert model.solutions.block_solution_type == RawArraysSolution.BlockSolutionType.DENSE_GRID + assert model.solutions.dc_meshes is None + assert model.structural_frame.structural_groups[0].elements[0].vertices.shape == (600, 3) + assert model.structural_frame.structural_groups[1].elements[0].vertices.shape == (860, 3) + assert model.structural_frame.structural_groups[2].elements[0].vertices.shape == (1_256, 3) + assert model.structural_frame.structural_groups[2].elements[1].vertices.shape == (1_680, 3) + + if PLOT: + gpv = require_gempy_viewer() + gtv: gpv.GemPyToVista = gpv.plot_3d( + model=model, + show_data=True, + image=True, + show=True + )