diff --git a/pyproject.toml b/pyproject.toml index 3432d01e4..e44d918fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ description = "Scilpy: diffusion MRI tools and utilities" authors = [{ name = "SCIL Team" }] readme = "README.md" requires-python = ">=3.11, <3.13" -license-files = ["LICENSE"] +license = { file = "LICENSE" } classifiers = [ "Development Status :: 3 - Alpha", "Environment :: Console", @@ -206,6 +206,7 @@ scil_sh_fusion = "scilpy.cli.scil_sh_fusion:main" scil_sh_to_aodf = "scilpy.cli.scil_sh_to_aodf:main" scil_sh_to_rish = "scilpy.cli.scil_sh_to_rish:main" scil_sh_to_sf = "scilpy.cli.scil_sh_to_sf:main" +scil_fodf_global_sf_threshold = "scilpy.cli.scil_fodf_global_sf_threshold:main" scil_stats_group_comparison = "scilpy.cli.scil_stats_group_comparison:main" scil_surface_apply_transform = "scilpy.cli.scil_surface_apply_transform:main" scil_surface_convert = "scilpy.cli.scil_surface_convert:main" diff --git a/src/scilpy/cli/scil_NODDI_maps.py b/src/scilpy/cli/scil_NODDI_maps.py index 1104620d8..328dc1ed6 100755 --- a/src/scilpy/cli/scil_NODDI_maps.py +++ b/src/scilpy/cli/scil_NODDI_maps.py @@ -22,10 +22,10 @@ import tempfile import amico -from dipy.io.gradients import read_bvals_bvecs import numpy as np from scilpy.io.gradients import fsl2mrtrix +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, add_verbose_arg, @@ -117,7 +117,11 @@ def main(): assert_headers_compatible(parser, args.in_dwi, optional=args.mask) # Generate a scheme file from the bvals and bvecs files - bvals, _ = read_bvals_bvecs(args.in_bval, args.in_bvec) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + bvals = simg.bvals + world_bvecs = simg.world_bvecs + _ = check_b0_threshold(bvals.min(), b0_thr=args.tolerance, skip_b0_check=args.skip_b0_check, overwrite_with_min=False) @@ -135,13 +139,16 @@ def main(): logging.info('Will compute NODDI with AMICO on {} shells at found at {}.' .format(len(shells_centroids), np.sort(shells_centroids))) - # Save the resulting bvals to a temporary file + # Save the resulting bvals and bvecs to temporary files tmp_dir = tempfile.TemporaryDirectory() tmp_scheme_filename = os.path.join(tmp_dir.name, 'gradients.b') tmp_bval_filename = os.path.join(tmp_dir.name, 'bval') + tmp_bvec_filename = os.path.join(tmp_dir.name, 'bvec') np.savetxt(tmp_bval_filename, shells_centroids[indices_shells], newline=' ', fmt='%i') - fsl2mrtrix(tmp_bval_filename, args.in_bvec, tmp_scheme_filename) + # Use world_bvecs for the MRTrix scheme file to ensure consistency + np.savetxt(tmp_bvec_filename, world_bvecs.T, fmt='%.8f') + fsl2mrtrix(tmp_bval_filename, tmp_bvec_filename, tmp_scheme_filename) with redirected_stdout: # Load the data diff --git a/src/scilpy/cli/scil_bingham_metrics.py b/src/scilpy/cli/scil_bingham_metrics.py index a193941dc..5e9e171b3 100755 --- a/src/scilpy/cli/scil_bingham_metrics.py +++ b/src/scilpy/cli/scil_bingham_metrics.py @@ -30,12 +30,12 @@ ------------------------------------------------------------------------------ """ -import nibabel as nib import time import argparse import logging from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, validate_nbr_processes, @@ -98,10 +98,13 @@ def main(): assert_outputs_exist(parser, args, [], optional=outputs) assert_headers_compatible(parser, args.in_bingham, args.mask) - bingham_im = nib.load(args.in_bingham) - bingham = bingham_im.get_fdata() - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + simg_bingham = StatefulImage.load(args.in_bingham) + bingham = simg_bingham.get_fdata() + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg_bingham.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) nbr_processes = validate_nbr_processes(parser, args) @@ -112,7 +115,7 @@ def main(): t1 = time.perf_counter() logging.info('FD computed in (s): {0}'.format(t1 - t0)) if args.out_fd: - nib.save(nib.Nifti1Image(fd, bingham_im.affine), args.out_fd) + StatefulImage.from_data(fd, simg_bingham).save(args.out_fd) if args.out_fs: t0 = time.perf_counter() @@ -120,15 +123,15 @@ def main(): fs = compute_fiber_spread(bingham, fd) t1 = time.perf_counter() logging.info('FS computed in (s): {0}'.format(t1 - t0)) - nib.save(nib.Nifti1Image(fs, bingham_im.affine), args.out_fs) + StatefulImage.from_data(fs, simg_bingham).save(args.out_fs) if args.out_ff: t0 = time.perf_counter() logging.info('Computing fiber fraction.') ff = compute_fiber_fraction(fd) t1 = time.perf_counter() - logging.info('FS computed in (s): {0}'.format(t1 - t0)) - nib.save(nib.Nifti1Image(ff, bingham_im.affine), args.out_ff) + logging.info('FF computed in (s): {0}'.format(t1 - t0)) + StatefulImage.from_data(ff, simg_bingham).save(args.out_ff) if __name__ == '__main__': diff --git a/src/scilpy/cli/scil_btensor_metrics.py b/src/scilpy/cli/scil_btensor_metrics.py index 2c8439eca..11f10f858 100755 --- a/src/scilpy/cli/scil_btensor_metrics.py +++ b/src/scilpy/cli/scil_btensor_metrics.py @@ -43,9 +43,9 @@ import nibabel as nib import numpy as np -from scilpy.image.utils import extract_affine from scilpy.io.btensor import generate_btensor_input from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_processes_arg, add_verbose_arg, add_skip_b0_check_arg, @@ -178,7 +178,9 @@ def main(): raise ValueError(msg) # Loading - affine = extract_affine(args.in_dwis) + simg = StatefulImage.load(args.in_dwis[0]) + simg.to_ras() + affine = simg.affine # Note. This script does not currently allow using a separate b0_threshold # for the b0s. Using the tolerance. To change this, we would have to @@ -199,11 +201,14 @@ def main(): 'No mask provided. The fit might not converge due to noise. ' 'Please provide a mask if it is the case.') else: - mask = get_data_as_mask(nib.load(args.mask), dtype=bool) + mask_simg = StatefulImage.load(args.mask) + mask_simg.to_ras() + mask = get_data_as_mask(mask_simg, dtype=bool) if args.fa is not None: - vol = nib.load(args.fa) - FA = vol.get_fdata(dtype=np.float32) + fa_simg = StatefulImage.load(args.fa) + fa_simg.to_ras() + FA = fa_simg.get_fdata(dtype=np.float32) # Processing parameters = fit_gamma(data, gtab_infos, mask=mask, diff --git a/src/scilpy/cli/scil_bundle_generate_priors.py b/src/scilpy/cli/scil_bundle_generate_priors.py index 187ce480c..14d1a07e1 100755 --- a/src/scilpy/cli/scil_bundle_generate_priors.py +++ b/src/scilpy/cli/scil_bundle_generate_priors.py @@ -22,6 +22,7 @@ import numpy as np from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, @@ -104,10 +105,17 @@ def main(): assert_outputs_exist(parser, args, required) # Loading - img_sh = nib.load(args.in_fodf) - sh_shape = img_sh.shape - sh_order = find_order_from_nb_coeff(sh_shape) sh_basis, is_legacy = parse_sh_basis_arg(args) + simg_sh = StatefulImage.load(args.in_fodf, is_orientation=True, + is_world_space=not args.is_voxel_space, + sh_basis=sh_basis, is_legacy=is_legacy) + # Bring to voxel space for multiplication with TODI (which is in vox space) + input_sh_3d = simg_sh.to_voxel_direction(sh_basis=sh_basis, + is_legacy=is_legacy).astype(np.float32) + + sh_shape = input_sh_3d.shape + sh_order = find_order_from_nb_coeff(sh_shape) + img_mask = nib.load(args.in_mask) mask_data = get_data_as_mask(img_mask) @@ -124,17 +132,22 @@ def main(): # SF to SH # Memory friendly saving, as soon as possible saving then delete - priors_3d = np.zeros(sh_shape) + priors_3d = np.zeros(sh_shape, dtype=np.float32) sphere = get_sphere(name='repulsion724') priors_3d[sub_mask_3d] = sf_to_sh(todi_sf, sphere, sh_order_max=sh_order, basis_type=sh_basis, - legacy=is_legacy) - nib.save(nib.Nifti1Image(priors_3d, img_mask.affine), out_priors) + legacy=is_legacy).astype(np.float32) + + simg_priors = StatefulImage(priors_3d, img_mask.affine, + sh_basis=sh_basis, is_legacy=is_legacy, + is_orientation=True, is_world_space=False) + simg_priors.to_world_direction() + nib.save(simg_priors, out_priors) del priors_3d # Back to SF - input_sh_3d = img_sh.get_fdata(dtype=np.float32) + # input_sh_3d is already in voxel space input_sf_1d = sh_to_sf(input_sh_3d[sub_mask_3d], sphere, sh_order_max=sh_order, basis_type=sh_basis, legacy=is_legacy) @@ -155,8 +168,13 @@ def main(): input_sh_3d[sub_mask_3d] = sf_to_sh(mult_sf_1d, sphere, sh_order_max=sh_order, basis_type=sh_basis, - legacy=is_legacy) - nib.save(nib.Nifti1Image(input_sh_3d, img_mask.affine), out_efod) + legacy=is_legacy).astype(np.float32) + + simg_efod = StatefulImage(input_sh_3d, img_mask.affine, + sh_basis=sh_basis, is_legacy=is_legacy, + is_orientation=True, is_world_space=False) + simg_efod.to_world_direction() + nib.save(simg_efod, out_efod) del input_sh_3d nib.save(nib.Nifti1Image(sub_mask_3d.astype(np.uint8), img_mask.affine), diff --git a/src/scilpy/cli/scil_dki_metrics.py b/src/scilpy/cli/scil_dki_metrics.py index b40df3699..9b41ea3c3 100755 --- a/src/scilpy/cli/scil_dki_metrics.py +++ b/src/scilpy/cli/scil_dki_metrics.py @@ -55,20 +55,18 @@ import dipy.reconst.dki as dki import dipy.reconst.msdki as msdki -from dipy.io.gradients import read_bvals_bvecs from dipy.core.gradients import gradient_table from scilpy.dwi.operations import compute_residuals from scilpy.image.volume_operations import smooth_to_fwhm from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, add_tolerance_arg, assert_headers_compatible) from scilpy.gradients.bvec_bval_tools import (check_b0_threshold, - is_normalized_bvecs, - identify_shells, - normalize_bvecs) + identify_shells) from scilpy.version import version_string @@ -184,16 +182,22 @@ def main(): assert_headers_compatible(parser, args.in_dwi, args.mask) # Loading - img = nib.load(args.in_dwi) - data = img.get_fdata(dtype=np.float32) - affine = img.affine - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None - - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) - if not is_normalized_bvecs(bvecs): - logging.warning('Your b-vectors do not seem normalized...') - bvecs = normalize_bvecs(bvecs) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + # DKI fit expects RAS (via dipy) + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + affine = simg.affine + bvals = simg.bvals + bvecs = simg.world_bvecs + + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.to_ras() + mask = get_data_as_mask(mask_simg, dtype=bool) # Note. This script does not currently allow using a separate b0_threshold # for the b0s. Using the tolerance. To change this, we would have to diff --git a/src/scilpy/cli/scil_dti_metrics.py b/src/scilpy/cli/scil_dti_metrics.py index 551cbb373..983b414d5 100755 --- a/src/scilpy/cli/scil_dti_metrics.py +++ b/src/scilpy/cli/scil_dti_metrics.py @@ -25,12 +25,10 @@ import argparse import logging -import nibabel as nib import numpy as np from dipy.core.gradients import gradient_table import dipy.denoise.noise_estimate as ne -from dipy.io.gradients import read_bvals_bvecs from dipy.reconst.dti import (TensorModel, color_fa, fractional_anisotropy, geodesic_anisotropy, mean_diffusivity, axial_diffusivity, norm, @@ -41,6 +39,7 @@ from scilpy.dwi.operations import compute_residuals, \ compute_residuals_statistics from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, @@ -184,19 +183,28 @@ def main(): assert_headers_compatible(parser, args.in_dwi, args.mask) # Loading - img = nib.load(args.in_dwi) - data = img.get_fdata(dtype=np.float32) - affine = img.affine - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) - logging.info('Tensor estimation with the {} method...'.format(args.method)) - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + # Reorient to RAS for DIPY + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.world_bvecs if not is_normalized_bvecs(bvecs): - logging.warning('Your b-vectors do not seem normalized...') + logger.warning('Your b-vectors do not seem normalized...') bvecs = normalize_bvecs(bvecs) + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) + + logging.info('Tensor estimation with the {} method...'.format(args.method)) + # How the b0_threshold is used: gtab.b0s_mask is used # 1) In TensorModel in Dipy: # - The S0 images used as any other image in the design matrix and in @@ -229,40 +237,39 @@ def main(): tensor_vals_reordered = convert_tensor_from_dipy_format( tensor_vals, final_format=args.tensor_format) - fiber_tensors = nib.Nifti1Image( - tensor_vals_reordered.astype(np.float32), affine) - nib.save(fiber_tensors, args.tensor) + StatefulImage.from_data(tensor_vals_reordered.astype(np.float32), + simg).save(args.tensor) - del tensor_vals, fiber_tensors, tensor_vals_reordered + del tensor_vals, tensor_vals_reordered if args.fa or args.rgb: FA = fractional_anisotropy(tenfit.evals) FA[np.isnan(FA)] = 0 FA = np.clip(FA, 0, 1) if args.fa: - nib.save(nib.Nifti1Image(FA.astype(np.float32), affine), args.fa) + StatefulImage.from_data(FA.astype(np.float32), simg).save(args.fa) if args.rgb: RGB = color_fa(FA, tenfit.evecs) - nib.save(nib.Nifti1Image(np.array(255 * RGB, 'uint8'), affine), - args.rgb) + StatefulImage.from_data(np.array(255 * RGB, 'uint8'), + simg).save(args.rgb) if args.ga: GA = geodesic_anisotropy(tenfit.evals) GA[np.isnan(GA)] = 0 - nib.save(nib.Nifti1Image(GA.astype(np.float32), affine), args.ga) + StatefulImage.from_data(GA.astype(np.float32), simg).save(args.ga) if args.md: MD = mean_diffusivity(tenfit.evals) - nib.save(nib.Nifti1Image(MD.astype(np.float32), affine), args.md) + StatefulImage.from_data(MD.astype(np.float32), simg).save(args.md) if args.ad: AD = axial_diffusivity(tenfit.evals) - nib.save(nib.Nifti1Image(AD.astype(np.float32), affine), args.ad) + StatefulImage.from_data(AD.astype(np.float32), simg).save(args.ad) if args.rd: RD = radial_diffusivity(tenfit.evals) - nib.save(nib.Nifti1Image(RD.astype(np.float32), affine), args.rd) + StatefulImage.from_data(RD.astype(np.float32), simg).save(args.rd) if args.mode: # Compute tensor mode @@ -271,31 +278,32 @@ def main(): # Since the mode computation can generate NANs when not masked, # we need to remove them. non_nan_indices = np.isfinite(inter_mode) - mode = np.zeros(inter_mode.shape) - mode[non_nan_indices] = inter_mode[non_nan_indices] - nib.save(nib.Nifti1Image(mode.astype(np.float32), affine), args.mode) + mode_data = np.zeros(inter_mode.shape) + mode_data[non_nan_indices] = inter_mode[non_nan_indices] + StatefulImage.from_data(mode_data.astype(np.float32), + simg).save(args.mode) if args.norm: NORM = norm(tenfit.quadratic_form) - nib.save(nib.Nifti1Image(NORM.astype(np.float32), affine), args.norm) + StatefulImage.from_data(NORM.astype(np.float32), simg).save(args.norm) if args.evecs: - evecs = tenfit.evecs.astype(np.float32) - nib.save(nib.Nifti1Image(evecs, affine), args.evecs) + evecs_data = tenfit.evecs.astype(np.float32) + StatefulImage.from_data(evecs_data, simg).save(args.evecs) # save individual e-vectors also for i in range(3): - nib.save(nib.Nifti1Image(evecs[..., i], affine), - add_filename_suffix(args.evecs, '_v'+str(i+1))) + StatefulImage.from_data(evecs_data[..., i], simg).save( + add_filename_suffix(args.evecs, '_v'+str(i+1))) if args.evals: - evals = tenfit.evals.astype(np.float32) - nib.save(nib.Nifti1Image(evals, affine), args.evals) + evals_data = tenfit.evals.astype(np.float32) + StatefulImage.from_data(evals_data, simg).save(args.evals) # save individual e-values also for i in range(3): - nib.save(nib.Nifti1Image(evals[..., i], affine), - add_filename_suffix(args.evals, '_e' + str(i+1))) + StatefulImage.from_data(evals_data[..., i], simg).save( + add_filename_suffix(args.evals, '_e' + str(i+1))) if args.p_i_signal: S0 = np.mean(data[..., gtab.b0s_mask], axis=-1, keepdims=True) @@ -305,8 +313,8 @@ def main(): if args.mask is not None: pis_mask *= mask - nib.save(nib.Nifti1Image(pis_mask.astype(np.int16), affine), - args.p_i_signal) + StatefulImage.from_data(pis_mask.astype(np.int16), + simg).save(args.p_i_signal) if args.pulsation: STD = np.std(data[..., ~gtab.b0s_mask], axis=-1) @@ -314,8 +322,8 @@ def main(): if args.mask is not None: STD *= mask - nib.save(nib.Nifti1Image(STD.astype(np.float32), affine), - add_filename_suffix(args.pulsation, '_std_dwi')) + StatefulImage.from_data(STD.astype(np.float32), simg).save( + add_filename_suffix(args.pulsation, '_std_dwi')) if np.sum(gtab.b0s_mask) <= 1: logger.info('Not enough b=0 images to output standard ' @@ -330,8 +338,8 @@ def main(): if args.mask is not None: STD *= mask - nib.save(nib.Nifti1Image(STD.astype(np.float32), affine), - add_filename_suffix(args.pulsation, '_std_b0')) + StatefulImage.from_data(STD.astype(np.float32), simg).save( + add_filename_suffix(args.pulsation, '_std_b0')) if args.residual: if mask is None: @@ -354,7 +362,7 @@ def main(): R, data_diff = compute_residuals( predicted_data=tenfit2_predict.astype(np.float32), real_data=data, b0s_mask=gtab.b0s_mask, mask=mask) - nib.save(nib.Nifti1Image(R.astype(np.float32), affine), args.residual) + StatefulImage.from_data(R.astype(np.float32), simg).save(args.residual) # Each volume's residual statistics R_k, q1, q3, iqr, std = compute_residuals_statistics(data_diff) diff --git a/src/scilpy/cli/scil_dwi_extract_b0.py b/src/scilpy/cli/scil_dwi_extract_b0.py index e8d8e93e1..c98bc9d7a 100755 --- a/src/scilpy/cli/scil_dwi_extract_b0.py +++ b/src/scilpy/cli/scil_dwi_extract_b0.py @@ -12,12 +12,12 @@ import os from dipy.core.gradients import gradient_table -from dipy.io.gradients import read_bvals_bvecs import nibabel as nib import numpy as np from scilpy.dwi.utils import extract_b0 +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist) @@ -93,7 +93,10 @@ def main(): # Outputs are not checked, since multiple use cases # are possible and hard to check - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + bvals = simg.bvals + bvecs = simg.world_bvecs args.b0_threshold = check_b0_threshold(bvals.min(), b0_thr=args.b0_threshold, @@ -112,16 +115,13 @@ def main(): elif args.cluster_first: extract_in_cluster = True - image = nib.load(args.in_dwi) - b0_volumes = extract_b0( - image, gtab.b0s_mask, extract_in_cluster, strategy, args.block_size) + simg, gtab.b0s_mask, extract_in_cluster, strategy, args.block_size) if len(b0_volumes.shape) > 3 and not args.single_image: - _split_time_steps(b0_volumes, image.affine, image.header, args.out_b0) + _split_time_steps(b0_volumes, simg.affine, simg.header, args.out_b0) else: - nib.save(nib.Nifti1Image(b0_volumes, image.affine, image.header), - args.out_b0) + StatefulImage.from_data(b0_volumes, simg).save(args.out_b0) if __name__ == '__main__': diff --git a/src/scilpy/cli/scil_dwi_to_sh.py b/src/scilpy/cli/scil_dwi_to_sh.py index 463ed2032..118e0fb67 100755 --- a/src/scilpy/cli/scil_dwi_to_sh.py +++ b/src/scilpy/cli/scil_dwi_to_sh.py @@ -9,10 +9,9 @@ import logging from dipy.core.gradients import gradient_table -from dipy.io.gradients import read_bvals_bvecs -import nibabel as nib import numpy as np +from scilpy.io.stateful_image import StatefulImage from scilpy.gradients.bvec_bval_tools import check_b0_threshold from scilpy.io.image import get_data_as_mask from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, @@ -68,10 +67,12 @@ def main(): assert_outputs_exist(parser, args, args.out_sh) assert_headers_compatible(parser, args.in_dwi, args.mask) - vol = nib.load(args.in_dwi) - dwi = vol.get_fdata(dtype=np.float32) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + dwi = simg.get_fdata(dtype=np.float32) - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + bvals = simg.bvals + bvecs = simg.world_bvecs # gtab.b0s_mask in used in compute_sh_coefficients to get the b0s. args.b0_threshold = check_b0_threshold(bvals.min(), @@ -81,15 +82,18 @@ def main(): sh_basis, is_legacy = parse_sh_basis_arg(args) - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.to_ras() + mask = get_data_as_mask(mask_simg, dtype=bool) sh = compute_sh_coefficients(dwi, gtab, args.b0_threshold, args.sh_order, sh_basis, args.smooth, use_attenuation=args.use_attenuation, mask=mask, is_legacy=is_legacy) - nib.save(nib.Nifti1Image(sh.astype(np.float32), vol.affine), args.out_sh) + StatefulImage.from_data(sh.astype(np.float32), simg).save(args.out_sh) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_fibertube_tracking.py b/src/scilpy/cli/scil_fibertube_tracking.py index 75c0477f4..ded8253db 100755 --- a/src/scilpy/cli/scil_fibertube_tracking.py +++ b/src/scilpy/cli/scil_fibertube_tracking.py @@ -155,12 +155,13 @@ def _build_arg_parser(): ' faces. [%(default)s]') ftod_g.add_argument('--sfthres', dest='sf_threshold', metavar='sf_th', type=float, default=0.1, - help='Spherical function relative threshold. ' - '[%(default)s]') + help='Spherical function relative threshold ' + 'within each voxel. [%(default)s]') ftod_g.add_argument('--sfthres_init', metavar='sf_th', type=float, - default=0.5, dest='sf_threshold_init', - help="Spherical function relative threshold value " - "for the \ninitial direction. [%(default)s]") + default=0.5, + help='Spherical function relative threshold ' + 'within each voxel for the \n' + 'initial direction. [%(default)s]') seed_group = p.add_argument_group( 'Seeding options') @@ -263,7 +264,8 @@ def main(): # Since the scilpy Tracker requires a mask, we provide a fake one that will # never interfere. fake_mask_data = np.ones(in_sft.dimensions) - fake_mask = DataVolume(fake_mask_data, in_sft.voxel_sizes, 'nearest') + fake_mask = DataVolume(fake_mask_data, in_sft.voxel_sizes, in_sft.affine, + interpolation='nearest') if args.use_ftODF: logging.debug("Instantiating FTODF datavolume") @@ -302,7 +304,7 @@ def main(): logging.debug("Instantiating ODF propagator") propagator = ODFPropagator( datavolume, args.step_size, args.rk_order, args.algo, sh_basis, - args.sf_threshold, args.sf_threshold_init, theta, args.sphere, + args.sf_threshold, args.sfthres_init, theta, args.sphere, sub_sphere=args.sub_sphere, space=our_space, origin=our_origin, is_legacy=is_legacy) else: diff --git a/src/scilpy/cli/scil_fodf_global_sf_threshold.py b/src/scilpy/cli/scil_fodf_global_sf_threshold.py new file mode 100644 index 000000000..f618c7d3f --- /dev/null +++ b/src/scilpy/cli/scil_fodf_global_sf_threshold.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +""" +Compute a binary mask based on a global SF threshold. +The script masks voxels where the max SF amplitude is below +either a relative factor or an absolute threshold. + +The absolute threshold can be estimated from the mean/median maximum fODF in the +ventricles, computed with scil_fodf_max_in_ventricles. + +The input can be either SH coefficients or peaks. However, the vectors +cannot be normalized, as the amplitude is used for thresholding. +""" + +import argparse +import logging + +import nibabel as nib +import numpy as np + +from scilpy.io.stateful_image import StatefulImage +from scilpy.io.utils import (add_sh_basis_args, add_sphere_arg, + add_verbose_arg, add_overwrite_arg, + assert_inputs_exist, assert_outputs_exist, + parse_sh_basis_arg) +from scilpy.reconst.utils import compute_sf_threshold_mask +from scilpy.version import version_string + + +def _build_arg_parser(): + p = argparse.ArgumentParser(description=__doc__, + formatter_class=argparse.RawTextHelpFormatter, + epilog=version_string) + + p.add_argument('in_odf', + help='Input ODF file (SH or Peaks) (.nii.gz).') + p.add_argument('out_mask', + help='Output binary mask (.nii.gz).') + + thr_g = p.add_mutually_exclusive_group(required=True) + thr_g.add_argument('--relative', type=float, + help='Global SF threshold relative factor (0-1).') + thr_g.add_argument('--absolute', type=float, + help='Global SF absolute threshold.') + add_sh_basis_args(p) + add_sphere_arg(p) + add_overwrite_arg(p) + add_verbose_arg(p) + + return p + + +def main(): + parser = _build_arg_parser() + args = parser.parse_args() + logging.getLogger().setLevel(logging.getLevelName(args.verbose)) + + assert_inputs_exist(parser, args.in_odf) + assert_outputs_exist(parser, args, args.out_mask) + + sh_basis, is_legacy = parse_sh_basis_arg(args) + + logging.info("Loading ODF data.") + simg = StatefulImage.load(args.in_odf, is_orientation=True, + sh_basis=sh_basis, is_legacy=is_legacy) + + data = simg.to_voxel_direction(sh_basis=sh_basis, + is_legacy=is_legacy).astype(np.float32) + + logging.info("Computing global SF threshold mask.") + mask, global_max, threshold = compute_sf_threshold_mask( + data, sphere_name=args.sphere, relative_factor=args.relative, + absolute_threshold=args.absolute, sh_basis=sh_basis, + is_legacy=is_legacy) + + logging.info("Global max SF amplitude: {:.4f}".format(global_max)) + if args.relative is not None: + logging.info("Relative threshold: {:.4f} (Factor: {})" + .format(threshold, args.relative)) + else: + logging.info("Absolute threshold used: {:.4f}".format(args.absolute)) + + logging.info("Number of voxels in mask: {}".format(np.sum(mask))) + + # Save mask + mask_img = nib.Nifti1Image(mask.astype(np.uint8), simg.affine, + simg.header) + nib.save(mask_img, args.out_mask) + + +if __name__ == "__main__": + main() diff --git a/src/scilpy/cli/scil_fodf_memsmt.py b/src/scilpy/cli/scil_fodf_memsmt.py index 0fe6537e7..00219b4e6 100755 --- a/src/scilpy/cli/scil_fodf_memsmt.py +++ b/src/scilpy/cli/scil_fodf_memsmt.py @@ -180,7 +180,7 @@ def main(): dtype=bool) if args.mask else None # Checking data and sh_order - verify_data_vs_sh_order(data, args.sh_order) + verify_data_vs_sh_order(data, args.sh_order, gtab=gtab) sh_basis, is_legacy = parse_sh_basis_arg(args) # Checking response functions and computing mesmt response function diff --git a/src/scilpy/cli/scil_fodf_metrics.py b/src/scilpy/cli/scil_fodf_metrics.py index 92df1b578..f7b69d5ee 100755 --- a/src/scilpy/cli/scil_fodf_metrics.py +++ b/src/scilpy/cli/scil_fodf_metrics.py @@ -40,6 +40,7 @@ from dipy.direction.peaks import reshape_peaks_for_visualization from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, add_processes_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, @@ -138,11 +139,14 @@ def main(): assert_headers_compatible(parser, args.in_fODF, args.mask) # Loading - vol = nib.load(args.in_fODF) - data = vol.get_fdata(dtype=np.float32) - affine = vol.affine - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + simg = StatefulImage.load(args.in_fODF) + data = simg.get_fdata(dtype=np.float32) + affine = simg.affine + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) sphere = get_sphere(name=args.sphere) sh_basis, is_legacy = parse_sh_basis_arg(args) @@ -168,26 +172,26 @@ def main(): # Save result if args.nufo: - nib.save(nib.Nifti1Image(nufo_map.astype(np.float32), affine), - args.nufo) + nufo_img = nib.Nifti1Image(nufo_map.astype(np.float32), affine) + StatefulImage.create_from(nufo_img, simg).save(args.nufo) if args.afd_max: - nib.save(nib.Nifti1Image(afd_max.astype(np.float32), affine), - args.afd_max) + afd_max_img = nib.Nifti1Image(afd_max.astype(np.float32), affine) + StatefulImage.create_from(afd_max_img, simg).save(args.afd_max) if args.afd_total: # this is the analytical afd total afd_tot = data[:, :, :, 0] - nib.save(nib.Nifti1Image(afd_tot.astype(np.float32), affine), - args.afd_total) + afd_tot_img = nib.Nifti1Image(afd_tot.astype(np.float32), affine) + StatefulImage.create_from(afd_tot_img, simg).save(args.afd_total) if args.afd_sum: - nib.save(nib.Nifti1Image(afd_sum.astype(np.float32), affine), - args.afd_sum) + afd_sum_img = nib.Nifti1Image(afd_sum.astype(np.float32), affine) + StatefulImage.create_from(afd_sum_img, simg).save(args.afd_sum) if args.rgb: - nib.save(nib.Nifti1Image(rgb_map.astype('uint8'), affine), - args.rgb) + rgb_img = nib.Nifti1Image(rgb_map.astype('uint8'), affine) + StatefulImage.create_from(rgb_img, simg).save(args.rgb) if args.peaks or args.peak_values: if not args.abs_peaks_and_values: @@ -196,15 +200,19 @@ def main(): where=peak_values[..., 0, None] != 0) peak_dirs[...] *= peak_values[..., :, None] if args.peaks: - nib.save(nib.Nifti1Image( + peaks_img = nib.Nifti1Image( reshape_peaks_for_visualization(peak_dirs), - affine), args.peaks) + affine) + StatefulImage.create_from(peaks_img, simg).save(args.peaks) if args.peak_values: - nib.save(nib.Nifti1Image(peak_values, vol.affine), - args.peak_values) + peak_vals_img = nib.Nifti1Image(peak_values, affine) + StatefulImage.create_from(peak_vals_img, simg).save( + args.peak_values) if args.peak_indices: - nib.save(nib.Nifti1Image(peak_indices, vol.affine), args.peak_indices) + peak_indices_img = nib.Nifti1Image(peak_indices, affine) + StatefulImage.create_from(peak_indices_img, simg).save( + args.peak_indices) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_fodf_msmt.py b/src/scilpy/cli/scil_fodf_msmt.py index cd782b331..3f084c01a 100755 --- a/src/scilpy/cli/scil_fodf_msmt.py +++ b/src/scilpy/cli/scil_fodf_msmt.py @@ -22,15 +22,14 @@ from dipy.core.gradients import gradient_table, unique_bvals_tolerance from dipy.data import get_sphere -from dipy.io.gradients import read_bvals_bvecs from dipy.reconst.mcsd import MultiShellDeconvModel, multi_shell_fiber_response -import nibabel as nib import numpy as np from scilpy.gradients.bvec_bval_tools import (check_b0_threshold, normalize_bvecs, is_normalized_bvecs) from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, assert_inputs_exist, assert_outputs_exist, add_sh_basis_args, add_skip_b0_check_arg, @@ -132,18 +131,31 @@ def main(): wm_frf = np.loadtxt(args.in_wm_frf) gm_frf = np.loadtxt(args.in_gm_frf) csf_frf = np.loadtxt(args.in_csf_frf) - vol = nib.load(args.in_dwi) - data = vol.get_fdata(dtype=np.float32) - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + # Orientation standardization? + # Reconstruction logic (dipy/scilpy) often prefers specific orientation. + # We reorient secondary inputs to match the primary one. + # If we want to be fully robust, we could force RAS here, but let's see. + # scil_frf_msmt used to_ras(), so let's be consistent. + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.world_bvecs # Checking data and sh_order wm_frf, gm_frf, csf_frf = verify_frf_files(wm_frf, gm_frf, csf_frf) - verify_data_vs_sh_order(data, args.sh_order) sh_basis, is_legacy = parse_sh_basis_arg(args) # Checking mask - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) # Checking bvals, bvecs values and loading gtab if not is_normalized_bvecs(bvecs): @@ -161,6 +173,9 @@ def main(): overwrite_with_min=False) gtab = gradient_table(bvals, bvecs=bvecs, b0_threshold=args.tolerance) + # Checking data and sh_order + verify_data_vs_sh_order(data, args.sh_order, gtab=gtab) + # Loading spheres reg_sphere = get_sphere(name='symmetric362') @@ -206,8 +221,8 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(wm_coeff.astype(np.float32), - vol.affine), args.wm_out_fODF) + res_simg = StatefulImage.from_data(wm_coeff.astype(np.float32), simg) + res_simg.save(args.wm_out_fODF) if args.gm_out_fODF: gm_coeff = shm_coeff[..., 1] @@ -218,8 +233,8 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(gm_coeff.astype(np.float32), - vol.affine), args.gm_out_fODF) + res_simg = StatefulImage.from_data(gm_coeff.astype(np.float32), simg) + res_simg.save(args.gm_out_fODF) if args.csf_out_fODF: csf_coeff = shm_coeff[..., 0] @@ -230,18 +245,18 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(csf_coeff.astype(np.float32), - vol.affine), args.csf_out_fODF) + res_simg = StatefulImage.from_data(csf_coeff.astype(np.float32), simg) + res_simg.save(args.csf_out_fODF) if args.vf: - nib.save(nib.Nifti1Image(vf.astype(np.float32), - vol.affine), args.vf) + res_simg = StatefulImage.from_data(vf.astype(np.float32), simg) + res_simg.save(args.vf) if args.vf_rgb: vf_rgb = vf / np.max(vf) * 255 vf_rgb = np.clip(vf_rgb, 0, 255) - nib.save(nib.Nifti1Image(vf_rgb.astype(np.uint8), - vol.affine), args.vf_rgb) + res_simg = StatefulImage.from_data(vf_rgb.astype(np.uint8), simg) + res_simg.save(args.vf_rgb) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_fodf_ssst.py b/src/scilpy/cli/scil_fodf_ssst.py index 73a420e08..4690f278b 100755 --- a/src/scilpy/cli/scil_fodf_ssst.py +++ b/src/scilpy/cli/scil_fodf_ssst.py @@ -12,22 +12,24 @@ from dipy.core.gradients import gradient_table from dipy.data import get_sphere -from dipy.io.gradients import read_bvals_bvecs from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel import nibabel as nib import numpy as np +from scilpy.dwi.operations import compute_dwi_attenuation from scilpy.gradients.bvec_bval_tools import (check_b0_threshold, + identify_shells, normalize_bvecs, is_normalized_bvecs) from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_processes_arg, add_sh_basis_args, add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, assert_headers_compatible) from scilpy.reconst.fodf import fit_from_model -from scilpy.reconst.sh import convert_sh_basis +from scilpy.reconst.sh import convert_sh_basis, verify_data_vs_sh_order from scilpy.version import version_string @@ -54,6 +56,11 @@ def _build_arg_parser(): '--mask', metavar='', help='Path to a binary mask. Only the data inside the mask will be ' 'used \nfor computations and reconstruction.') + p.add_argument( + '--voxel_wise_s0', action='store_true', + help='If set, performs voxel-wise S0 normalization before ' + 'deconvolution. \nIn this case, the mean_b0_val from the ' + 'FRF file is ignored.') add_b0_thresh_arg(p) add_skip_b0_check_arg(p, will_overwrite_with_min=True) @@ -77,25 +84,26 @@ def main(): # Loading data full_frf = np.loadtxt(args.frf_file) - vol = nib.load(args.in_dwi) - data = vol.get_fdata(dtype=np.float32) - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + # Reorient to RAS for DIPY + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.world_bvecs # Checking mask - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.to_ras() + mask = get_data_as_mask(mask_simg, dtype=bool) sh_order = args.sh_order sh_basis, is_legacy = parse_sh_basis_arg(args) - # Checking data and sh_order - if data.shape[-1] < (sh_order + 1) * (sh_order + 2) / 2: - logging.warning( - 'We recommend having at least {} unique DWI volumes, but you ' - 'currently have {} volumes. Try lowering the parameter sh_order ' - 'in case of non convergence.'.format( - (sh_order + 1) * (sh_order + 2) / 2, data.shape[-1])) - # Checking bvals, bvecs values and loading gtab if not is_normalized_bvecs(bvecs): logging.warning('Your b-vectors do not seem normalized...') @@ -107,6 +115,28 @@ def main(): skip_b0_check=args.skip_b0_check) gtab = gradient_table(bvals, bvecs=bvecs, b0_threshold=args.b0_threshold) + # Checking data and sh_order + verify_data_vs_sh_order(data, sh_order, gtab=gtab) + + # Checking shells + shells_centroids, _ = identify_shells(bvals, args.b0_threshold, + round_centroids=True) + dwi_shells = shells_centroids[shells_centroids > args.b0_threshold] + shells_centroids = list(sorted(shells_centroids[shells_centroids > args.b0_threshold])) + min_non_b0_shell = np.min(shells_centroids) if len(shells_centroids) > 0 else 0 + max_non_b0_delta = np.ediff1d(shells_centroids)[0] if len(shells_centroids) > 1 else 0 + if max_non_b0_delta >= min_non_b0_shell: + logging.warning( + 'Your shells seem to be very far apart (max delta: {}, min non-b0 shell: {}). ' + 'This might cause problems for the estimation of the FRF. ' + 'Consider using scil_frf_msmt.py.'.format(max_non_b0_delta, min_non_b0_shell)) + + if len(dwi_shells) > 0 and np.max(dwi_shells) < 900 and sh_order > 4: + logging.warning( + 'Your maximum b-value ({}) is relatively low. ' + 'High SH order ({}) might be unstable. ' + 'Consider using --sh_order 4.'.format(np.max(dwi_shells), sh_order)) + # Checking full_frf and separating it if not full_frf.shape[0] == 4: raise ValueError('FRF file did not contain 4 elements. ' @@ -114,6 +144,16 @@ def main(): frf = full_frf[0:3] mean_b0_val = full_frf[3] + if args.voxel_wise_s0: + if np.any(gtab.b0s_mask): + logging.info("Applying voxel-wise S0 normalization.") + b0_mean = np.mean(data[..., gtab.b0s_mask], axis=-1) + data = compute_dwi_attenuation(data, b0_mean) + mean_b0_val = 1.0 + else: + logging.warning("Voxel-wise S0 normalization requested but no b0 " + "volumes found. Skipping normalization.") + # Loading the sphere reg_sphere = get_sphere(name='symmetric362') @@ -134,9 +174,11 @@ def main(): is_input_legacy=True, is_output_legacy=is_legacy, nbr_processes=args.nbr_processes) - nib.save(nib.Nifti1Image(shm_coeff.astype(np.float32), - affine=vol.affine, - header=vol.header), args.out_fODF) + + fodf_img = nib.Nifti1Image(shm_coeff.astype(np.float32), + affine=simg.affine, + header=simg.header) + StatefulImage.create_from(fodf_img, simg).save(args.out_fODF) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_frf_msmt.py b/src/scilpy/cli/scil_frf_msmt.py index 0b4640cbc..86a5110ab 100755 --- a/src/scilpy/cli/scil_frf_msmt.py +++ b/src/scilpy/cli/scil_frf_msmt.py @@ -26,13 +26,12 @@ import logging from dipy.core.gradients import unique_bvals_tolerance -from dipy.io.gradients import read_bvals_bvecs -import nibabel as nib import numpy as np from scilpy.dwi.utils import extract_dwi_shell from scilpy.gradients.bvec_bval_tools import check_b0_threshold from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_precision_arg, add_skip_b0_check_arg, add_verbose_arg, assert_inputs_exist, @@ -157,9 +156,15 @@ def main(): roi_radii = assert_roi_radii_format(parser) # Loading - vol = nib.load(args.in_dwi) - data = vol.get_fdata(dtype=np.float32) - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + # FRF computation often expects RAS (via dipy) + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.world_bvecs dti_lim = args.dti_bval_limit @@ -172,7 +177,7 @@ def main(): list_bvals = unique_bvals_tolerance(bvals, tol=args.tolerance) if not np.all(list_bvals <= dti_lim): _, data_dti, bvals_dti, bvecs_dti = extract_dwi_shell( - vol, bvals, bvecs, list_bvals[list_bvals <= dti_lim], + simg, bvals, bvecs, list_bvals[list_bvals <= dti_lim], tol=args.tolerance) bvals_dti = np.squeeze(bvals_dti) else: @@ -180,14 +185,29 @@ def main(): bvals_dti = None bvecs_dti = None - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None - mask_wm = get_data_as_mask(nib.load(args.mask_wm), - dtype=bool) if args.mask_wm else None - mask_gm = get_data_as_mask(nib.load(args.mask_gm), - dtype=bool) if args.mask_gm else None - mask_csf = get_data_as_mask(nib.load(args.mask_csf), - dtype=bool) if args.mask_csf else None + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) + + mask_wm = None + if args.mask_wm: + mask_wm_simg = StatefulImage.load(args.mask_wm) + mask_wm_simg.reorient(simg.axcodes) + mask_wm = get_data_as_mask(mask_wm_simg, dtype=bool) + + mask_gm = None + if args.mask_gm: + mask_gm_simg = StatefulImage.load(args.mask_gm) + mask_gm_simg.reorient(simg.axcodes) + mask_gm = get_data_as_mask(mask_gm_simg, dtype=bool) + + mask_csf = None + if args.mask_csf: + mask_csf_simg = StatefulImage.load(args.mask_csf) + mask_csf_simg.reorient(simg.axcodes) + mask_csf = get_data_as_mask(mask_csf_simg, dtype=bool) # Processing responses, frf_masks = compute_msmt_frf(data, bvals, bvecs, @@ -208,10 +228,10 @@ def main(): # Saving masks_files = [args.wm_frf_mask, args.gm_frf_mask, args.csf_frf_mask] - for mask, mask_file in zip(frf_masks, masks_files): + for frf_mask, mask_file in zip(frf_masks, masks_files): if mask_file: - nib.save(nib.Nifti1Image(mask.astype(np.uint8), vol.affine), - mask_file) + res_simg = StatefulImage.from_data(frf_mask.astype(np.uint8), simg) + res_simg.save(mask_file) frf_out = [args.out_wm_frf, args.out_gm_frf, args.out_csf_frf] diff --git a/src/scilpy/cli/scil_frf_ssst.py b/src/scilpy/cli/scil_frf_ssst.py index bd027da42..8191e56b3 100755 --- a/src/scilpy/cli/scil_frf_ssst.py +++ b/src/scilpy/cli/scil_frf_ssst.py @@ -16,12 +16,12 @@ import argparse import logging -from dipy.io.gradients import read_bvals_bvecs -import nibabel as nib import numpy as np -from scilpy.gradients.bvec_bval_tools import check_b0_threshold +from scilpy.gradients.bvec_bval_tools import (check_b0_threshold, + identify_shells) from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_precision_arg, add_skip_b0_check_arg, add_verbose_arg, @@ -103,18 +103,45 @@ def main(): roi_radii = assert_roi_radii_format(parser) - vol = nib.load(args.in_dwi) - data = vol.get_fdata(dtype=np.float32) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + + simg.to_ras() + + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.world_bvecs - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) args.b0_threshold = check_b0_threshold(bvals.min(), b0_thr=args.b0_threshold, skip_b0_check=args.skip_b0_check) - mask = get_data_as_mask(nib.load(args.mask), - dtype=bool) if args.mask else None - mask_wm = get_data_as_mask(nib.load(args.mask_wm), - dtype=bool) if args.mask_wm else None + shells_centroids, _ = identify_shells(bvals, args.b0_threshold, + round_centroids=True) + shells_centroids = list(sorted( + shells_centroids[shells_centroids > args.b0_threshold])) + min_non_b0_shell = np.min(shells_centroids) \ + if len(shells_centroids) > 0 else 0 + max_non_b0_delta = np.ediff1d(shells_centroids)[0] \ + if len(shells_centroids) > 1 else 0 + if max_non_b0_delta >= min_non_b0_shell: + logging.warning( + 'Your shells seem to be very far apart (max delta: {}, ' + 'min non-b0 shell: {}). This might cause problems for the ' + 'estimation of the FRF. Consider using scil_frf_msmt.py.' + .format(max_non_b0_delta, min_non_b0_shell)) + + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.to_ras() + mask = get_data_as_mask(mask_simg, dtype=bool) + + mask_wm = None + if args.mask_wm: + mask_wm_simg = StatefulImage.load(args.mask_wm) + mask_wm_simg.to_ras() + mask_wm = get_data_as_mask(mask_wm_simg, dtype=bool) full_response = compute_ssst_frf( data, bvals, bvecs, args.b0_threshold, mask=mask, diff --git a/src/scilpy/cli/scil_gradients_validate_correct.py b/src/scilpy/cli/scil_gradients_validate_correct.py index a52b1107f..be290cee3 100755 --- a/src/scilpy/cli/scil_gradients_validate_correct.py +++ b/src/scilpy/cli/scil_gradients_validate_correct.py @@ -2,23 +2,15 @@ # -*- coding: utf-8 -*- """ Detect sign flips and/or axes swaps in the gradients table from a fiber -coherence index [1]. The script takes as input the principal direction(s) -at each voxel, the b-vectors and the fractional anisotropy map and outputs -a corrected b-vectors file. +coherence index [1]. The script takes as input the DWI, b-values and b-vectors +and outputs a corrected b-vectors file. A typical pipeline could be: ->>> scil_dti_metrics dwi.nii.gz bval bvec --not_all --fa fa.nii.gz - --evecs peaks.nii.gz ->>> scil_gradients_validate_correct bvec peaks_v1.nii.gz fa.nii.gz bvec_corr +>>> scil_gradients_validate_correct dwi.nii.gz bval bvec bvec_corr -Note that peaks_v1.nii.gz is the file containing the direction associated -to the highest eigenvalue at each voxel. - -It is also possible to use a file containing multiple principal directions per -voxel, given that they are sorted by decreasing amplitude. In that case, the -first direction (with the highest amplitude) will be chosen for validation. -Only 4D data is supported, so the directions must be stored in a single -dimension. For example, peaks.nii.gz from scil_fodf_metrics could be used. +The script refits the DTI model 24 times (once for each possible axis +permutation and flip) and chooses the one that maximizes the fiber coherence +index. For performance, the fit is only performed on voxels with FA > 0.5. ------------------------------------------------------------------------------ Reference: @@ -30,17 +22,22 @@ """ import argparse +import itertools import logging -from dipy.io.gradients import read_bvals_bvecs +from dipy.core.gradients import gradient_table +from dipy.reconst.dti import TensorModel import numpy as np -import nibabel as nib +from tqdm import tqdm from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, assert_outputs_exist, add_verbose_arg, - assert_headers_compatible) + add_b0_thresh_arg, add_skip_b0_check_arg) from scilpy.io.image import get_data_as_mask -from scilpy.reconst.fiber_coherence import compute_coherence_table_for_transforms +from scilpy.io.stateful_image import StatefulImage +from scilpy.gradients.bvec_bval_tools import check_b0_threshold +from scilpy.reconst.fiber_coherence import (compute_fiber_coherence, + NB_FLIPS) from scilpy.version import version_string @@ -49,25 +46,24 @@ def _build_arg_parser(): formatter_class=argparse.RawTextHelpFormatter, epilog=version_string) + p.add_argument('in_dwi', + help='Path to the input DWI file.') + p.add_argument('in_bval', + help='Path to the b-values file.') p.add_argument('in_bvec', - help='Path to bvec file.') - p.add_argument('in_peaks', - help='Path to peaks file.') - p.add_argument('in_FA', - help='Path to the fractional anisotropy file.') + help='Path to the b-vectors file to validate.') p.add_argument('out_bvec', help='Path to corrected bvec file (FSL format).') p.add_argument('--mask', - help='Path to an optional mask. If set, FA and Peaks will ' - 'only be used inside the mask.') - p.add_argument('--fa_threshold', default=0.2, type=float, + help='Path to an optional mask. If set, DTI fit will ' + 'only be performed inside the mask.') + p.add_argument('--fa_threshold', default=0.5, type=float, help='FA threshold. Only voxels with FA higher ' 'than fa_threshold will be considered. [%(default)s]') - p.add_argument('--column_wise', action='store_true', - help='Specify if input peaks are column-wise (..., 3, N) ' - 'instead of row-wise (..., N, 3).') + add_b0_thresh_arg(p) + add_skip_b0_check_arg(p, will_overwrite_with_min=True) add_verbose_arg(p) add_overwrite_arg(p) return p @@ -78,50 +74,100 @@ def main(): args = parser.parse_args() logging.getLogger().setLevel(logging.getLevelName(args.verbose)) - assert_inputs_exist(parser, [args.in_bvec, args.in_peaks, args.in_FA], + assert_inputs_exist(parser, [args.in_dwi, args.in_bval, args.in_bvec], optional=args.mask) assert_outputs_exist(parser, args, args.out_bvec) - assert_headers_compatible(parser, [args.in_peaks, args.in_FA], - optional=args.mask) - - _, bvecs = read_bvals_bvecs(None, args.in_bvec) - fa = nib.load(args.in_FA).get_fdata() - peaks = nib.load(args.in_peaks).get_fdata() - - if peaks.shape[-1] > 3: - logging.info('More than one principal direction per voxel was given.') - peaks = peaks[..., 0:3] - logging.info('The first peak is assumed to be the biggest.') - - # convert peaks to a volume of shape (H, W, D, N, 3) - if args.column_wise: - peaks = np.reshape(peaks, peaks.shape[:3] + (3, -1)) - peaks = np.transpose(peaks, axes=(0, 1, 2, 4, 3)) - else: - peaks = np.reshape(peaks, peaks.shape[:3] + (-1, 3)) - peaks = np.squeeze(peaks) - if args.mask: - mask = get_data_as_mask(nib.load(args.mask), ref_shape=peaks.shape) - fa[np.logical_not(mask)] = 0 - peaks[np.logical_not(mask)] = 0 + # Loading data + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + simg.to_ras() - peaks[fa < args.fa_threshold] = 0 - coherence, transform = compute_coherence_table_for_transforms(peaks, fa) + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.world_bvecs + + mask = None + if args.mask: + mask_simg = StatefulImage.load(args.mask) + mask_simg.to_ras() + mask = get_data_as_mask(mask_simg, dtype=bool) + + # Initial DTI fit to get FA and identify high-FA voxels + args.b0_threshold = check_b0_threshold(bvals.min(), + b0_thr=args.b0_threshold, + skip_b0_check=args.skip_b0_check) + gtab = gradient_table(bvals, bvecs=bvecs, b0_threshold=args.b0_threshold) + tenmodel = TensorModel(gtab, fit_method='WLS', + min_signal=np.min(data[data > 0])) + tenfit = tenmodel.fit(data, mask=mask) + fa = tenfit.fa + + # Define high-FA mask for coherence calculation + high_fa_mask = fa > args.fa_threshold + if mask is not None: + high_fa_mask &= mask + + if np.sum(high_fa_mask) == 0: + logging.error('No voxels found with FA > {}. Aborting.' + .format(args.fa_threshold)) + return + + # Generate 24 possible permutation/flips of gradient directions + permutations = list(itertools.permutations([0, 1, 2])) + transforms = np.zeros((len(permutations) * NB_FLIPS, 3, 3)) + for i in range(len(permutations)): + transforms[i * NB_FLIPS, np.arange(3), permutations[i]] = 1 + for ii in range(3): + flip = np.eye(3) + flip[ii, ii] = -1 + transforms[ii + i * NB_FLIPS + + 1] = transforms[i * NB_FLIPS].dot(flip) + + # Iterative refit and coherence calculation + best_coherence = -1 + best_t = None + + logging.info('Refitting DTI 24 times for gradient validation...') + for t in tqdm(transforms): + # Transform bvecs + # Note: Dipy expects bvecs as (N, 3). We apply the transform to axes. + # G' = G @ T + bvecs_candidate = bvecs @ t + + gtab_candidate = gradient_table(bvals, bvecs=bvecs_candidate, + b0_threshold=args.b0_threshold) + tenmodel_candidate = TensorModel(gtab_candidate, fit_method='WLS', + min_signal=np.min(data[data > 0])) + + # Fit ONLY on the high-FA mask to save time + tenfit_candidate = tenmodel_candidate.fit(data, mask=high_fa_mask) + + # Extract the principal direction (v1) + # evecs is (H, W, D, 3, 3), evecs[..., 0] is the first eigenvector (peak) + peaks = tenfit_candidate.evecs[..., 0] + + # Compute coherence + coherence = compute_fiber_coherence(peaks, fa) + + if coherence > best_coherence: + best_coherence = coherence + best_t = t - best_t = transform[np.argmax(coherence)] if (best_t == np.eye(3)).all(): - logging.info('b-vectors are already correct.') + logging.info('b-vectors are already correct. Coherence: {:.2f}' + .format(best_coherence)) correct_bvecs = bvecs else: - logging.info('Applying correction to b-vectors. ' - 'Transform is: \n{0}.'.format(best_t)) - correct_bvecs = np.dot(bvecs, best_t) + logging.info('Applying correction to b-vectors. Coherence: {:.2f} ' + '\nTransform is: \n{}.'.format(best_coherence, best_t)) + correct_bvecs = bvecs @ best_t logging.info('Saving bvecs to file: {0}.'.format(args.out_bvec)) - # FSL format (3, N) - np.savetxt(args.out_bvec, correct_bvecs.T, '%.8f') + # Save using StatefulImage to ensure they are in the original voxel space + simg.attach_world_gradients(bvals, correct_bvecs) + simg.save_gradients(args.in_bval, args.out_bvec) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_mti_maps_MT.py b/src/scilpy/cli/scil_mti_maps_MT.py index b39cb80bd..1944d3838 100755 --- a/src/scilpy/cli/scil_mti_maps_MT.py +++ b/src/scilpy/cli/scil_mti_maps_MT.py @@ -93,6 +93,7 @@ import numpy as np from scilpy.io.mti import add_common_args_mti, load_and_verify_mti +from scilpy.io.image import load_img, get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, add_verbose_arg, assert_output_dirs_exist_and_empty) @@ -186,7 +187,8 @@ def main(): optional=args.in_mtoff_t1 or [] + [args.mask]) # Define reference image for saving maps - affine = nib.load(input_maps_lists[0][0]).affine + ref_img, _ = load_img(input_maps_lists[0][0]) + affine = ref_img.affine # Other checks, loading, saving contrast_maps. single_echo, flip_angles, rep_times, B1_map, contrast_maps = \ @@ -251,8 +253,13 @@ def main(): img_data_list.append(MTsat) # Apply thresholds on maps + mask_data = None + if args.mask: + mask_img, _ = load_img(args.mask) + mask_data = get_data_as_mask(mask_img) + for i, map in enumerate(img_data_list): - img_data_list[i] = threshold_map(map, args.mask, 0, 100) + img_data_list[i] = threshold_map(map, mask_data, 0, 100) # Save ihMT and MT images if args.filtering: diff --git a/src/scilpy/cli/scil_mti_maps_ihMT.py b/src/scilpy/cli/scil_mti_maps_ihMT.py index 7ab913273..c669aecfa 100755 --- a/src/scilpy/cli/scil_mti_maps_ihMT.py +++ b/src/scilpy/cli/scil_mti_maps_ihMT.py @@ -108,6 +108,7 @@ import numpy as np from scilpy.io.mti import add_common_args_mti, load_and_verify_mti +from scilpy.io.image import load_img, get_data_as_mask from scilpy.io.utils import (add_overwrite_arg, assert_inputs_exist, add_verbose_arg, assert_output_dirs_exist_and_empty) @@ -272,8 +273,14 @@ def main(): # Apply thresholds on maps upper_thresholds = [100, 100, 10, 10] idx_contrast_lists = [[0, 1, 2, 3, 4], [3, 4], [0, 1, 2, 3], [3, 4]] + + mask_data = None + if args.mask: + mask_img, _ = load_img(args.mask) + mask_data = get_data_as_mask(mask_img) + for i, map in enumerate(img_data): - img_data[i] = threshold_map(map, args.mask, 0, upper_thresholds[i], + img_data[i] = threshold_map(map, mask_data, 0, upper_thresholds[i], idx_contrast_list=idx_contrast_lists[i], contrast_maps=contrast_maps) diff --git a/src/scilpy/cli/scil_qball_metrics.py b/src/scilpy/cli/scil_qball_metrics.py index b835af699..42bf0f556 100755 --- a/src/scilpy/cli/scil_qball_metrics.py +++ b/src/scilpy/cli/scil_qball_metrics.py @@ -22,15 +22,13 @@ from dipy.core.gradients import gradient_table from dipy.data import get_sphere -from dipy.io import read_bvals_bvecs from dipy.direction.peaks import (peaks_from_model, reshape_peaks_for_visualization) from dipy.reconst.shm import QballModel, CsaOdfModel, anisotropic_power -from scilpy.gradients.bvec_bval_tools import (check_b0_threshold, - is_normalized_bvecs, - normalize_bvecs) +from scilpy.gradients.bvec_bval_tools import check_b0_threshold from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_b0_thresh_arg, add_overwrite_arg, add_processes_arg, add_sh_basis_args, add_skip_b0_check_arg, add_verbose_arg, @@ -126,15 +124,11 @@ def main(): parallel = nbr_processes > 1 # Load data - img = nib.load(args.in_dwi) - data = img.get_fdata(dtype=np.float32) - - bvals, bvecs = read_bvals_bvecs(args.in_bval, args.in_bvec) - - if not is_normalized_bvecs(bvecs): - logging.warning('Your b-vectors do not seem normalized... Normalizing ' - 'now.') - bvecs = normalize_bvecs(bvecs) + simg = StatefulImage.load(args.in_dwi) + simg.load_gradients(args.in_bval, args.in_bvec) + data = simg.get_fdata(dtype=np.float32) + bvals = simg.bvals + bvecs = simg.world_bvecs # Usage of gtab.b0s_mask in dipy's models is not very well documented, but # we can see that it is indeed used. @@ -172,31 +166,30 @@ def main(): num_processes=nbr_processes) if args.gfa: - nib.save(nib.Nifti1Image(odfpeaks.gfa.astype(np.float32), img.affine), - args.gfa) + res = odfpeaks.gfa.astype(np.float32) + StatefulImage.from_data(res, simg).save(args.gfa) if args.peaks: - nib.save(nib.Nifti1Image(reshape_peaks_for_visualization(odfpeaks), - img.affine), args.peaks) + res = reshape_peaks_for_visualization(odfpeaks) + StatefulImage.from_data(res, simg).save(args.peaks) if args.peak_indices: - nib.save(nib.Nifti1Image(odfpeaks.peak_indices, img.affine), - args.peak_indices) + res = odfpeaks.peak_indices + StatefulImage.from_data(res, simg).save(args.peak_indices) if args.sh: - nib.save(nib.Nifti1Image( - odfpeaks.shm_coeff.astype(np.float32), img.affine), - args.sh) + res = odfpeaks.shm_coeff.astype(np.float32) + StatefulImage.from_data(res, simg).save(args.sh) if args.nufo: peaks_count = (odfpeaks.peak_indices > -1).sum(3) - nib.save(nib.Nifti1Image(peaks_count.astype(np.int32), img.affine), - args.nufo) + res = peaks_count.astype(np.int32) + StatefulImage.from_data(res, simg).save(args.nufo) if args.a_power: odf_a_power = anisotropic_power(odfpeaks.shm_coeff) - nib.save(nib.Nifti1Image(odf_a_power.astype(np.float32), img.affine), - args.a_power) + res = odf_a_power.astype(np.float32) + StatefulImage.from_data(res, simg).save(args.a_power) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_search_keywords.py b/src/scilpy/cli/scil_search_keywords.py index 57cda997b..e1c1e9030 100755 --- a/src/scilpy/cli/scil_search_keywords.py +++ b/src/scilpy/cli/scil_search_keywords.py @@ -204,18 +204,15 @@ def main(): continue # Highlight keywords based on verbosity level - with open(hidden_dir / f'{match}.help', 'r') as f: + with open(hidden_dir / f'{match}.help', 'r', encoding='utf-8') as f: docstrings = f.read() - all_experessions = stemmed_keywords + keywords + phrases \ - + stemmed_phrases + all_expressions = set(stemmed_keywords + keywords + phrases + stemmed_phrases) if not args.no_synonyms: - all_experessions += synonyms - - all_experessions = set(all_experessions) + all_expressions.update(synonyms) highlighted_docstring = _highlight_keywords(docstrings, - all_experessions) + all_expressions) if args.verbose == 'INFO': first_sentence = _split_first_sentence( highlighted_docstring)[0] @@ -235,8 +232,8 @@ def main(): original_word = keyword_mapping.get( word, phrase_mapping.get(word, word)) logging.info( - f"{Fore.LIGHTGREEN_EX}Occurrence of '{original_word}': ' \ - f'{score}{Style.RESET_ALL}") + f"{Fore.LIGHTGREEN_EX}Occurrence of '{original_word}': " + f"{score}{Style.RESET_ALL}") logging.info(f"{Fore.LIGHTBLUE_EX}{'=' * SPACING_LEN}") logging.info("\n") diff --git a/src/scilpy/cli/scil_sh_to_sf.py b/src/scilpy/cli/scil_sh_to_sf.py index 1e30e75c7..e131d4d02 100755 --- a/src/scilpy/cli/scil_sh_to_sf.py +++ b/src/scilpy/cli/scil_sh_to_sf.py @@ -13,14 +13,13 @@ import argparse import logging -import nibabel as nib import numpy as np -from dipy.core.gradients import gradient_table from dipy.core.sphere import Sphere from dipy.data import SPHERE_FILES, get_sphere from dipy.io import read_bvals_bvecs from scilpy.gradients.bvec_bval_tools import DEFAULT_B0_THRESHOLD +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_processes_arg, add_sh_basis_args, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, @@ -145,18 +144,31 @@ def main(): bvals, _ = read_bvals_bvecs(args.in_bval, None) # Load SH - vol_sh = nib.load(args.in_sh) - data_sh = vol_sh.get_fdata(dtype=np.float32) + simg_sh = StatefulImage.load(args.in_sh) + data_sh = simg_sh.get_fdata(dtype=np.float32) # Sample SF from SH if args.sphere: sphere = get_sphere(name=args.sphere) else: # args.in_bvec is set. - gtab = gradient_table(bvals, bvecs=bvecs, - b0_threshold=args.b0_threshold) - # Remove bvecs corresponding to b0 images - bvecs = bvecs[np.logical_not(gtab.b0s_mask)] - sphere = Sphere(xyz=bvecs) + # Manually rotate bvecs to world space instead of using load_gradients + # because the number of volumes in SH (e.g. 15) does not match the + # number of gradients (e.g. 21). + ref_affine = simg_sh._original_affine \ + if simg_sh._original_affine is not None else simg_sh.affine + R = simg_sh._get_rotation_matrix(ref_affine) + + if StatefulImage.needs_fsl_flip(ref_affine): + bvecs[:, 0] *= -1 + + world_bvecs = np.dot(bvecs, R.T) + # Normalize + norms = np.linalg.norm(world_bvecs, axis=1) + world_bvecs[norms > 1e-6] /= norms[norms > 1e-6][:, None] + + # Use world_bvecs for projection to ensure world-space alignment + bvecs_to_use = world_bvecs[np.logical_not(bvals <= args.b0_threshold)] + sphere = Sphere(xyz=bvecs_to_use) sf = convert_sh_to_sf(data_sh, sphere, input_basis=sh_basis, @@ -178,8 +190,9 @@ def main(): # Add b0 images to SF (and bvals if necessary) if --in_b0 was provided if args.in_b0: # Load b0 - vol_b0 = nib.load(args.in_b0) - data_b0 = vol_b0.get_fdata(dtype=args.dtype) + simg_b0 = StatefulImage.load(args.in_b0) + simg_b0.reorient(simg_sh.axcodes) + data_b0 = simg_b0.get_fdata(dtype=args.dtype) if data_b0.ndim == 3: data_b0 = data_b0[..., np.newaxis] @@ -206,10 +219,24 @@ def main(): # Save new bvecs if args.out_bvec: - np.savetxt(args.out_bvec, new_bvecs.T, fmt='%.8f') + # We need to save bvecs in the original voxel space for FSL compatibility. + # If we used a sphere from bvecs, they were world_bvecs. + # We should use simg_sh to transform them back. + if not args.sphere: + # Reconstruct world bvecs with b0s if necessary + full_world_bvecs = np.zeros((len(new_bvecs), 3)) + start_idx = data_b0.shape[-1] if args.in_b0 else 0 + full_world_bvecs[start_idx:] = sphere.vertices + + # Use a dummy StatefulImage to save them in voxel space + simg_out = StatefulImage.from_data(sf, simg_sh) + simg_out.attach_world_gradients([0]*len(new_bvecs), full_world_bvecs) + simg_out.save_gradients(args.out_bval, args.out_bvec) + else: + np.savetxt(args.out_bvec, new_bvecs.T, fmt='%.8f') # Save SF - nib.save(nib.Nifti1Image(sf, vol_sh.affine), args.out_sf) + StatefulImage.from_data(sf, simg_sh).save(args.out_sf) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_tracking_local.py b/src/scilpy/cli/scil_tracking_local.py index bed9e5e7b..565caac62 100755 --- a/src/scilpy/cli/scil_tracking_local.py +++ b/src/scilpy/cli/scil_tracking_local.py @@ -42,6 +42,7 @@ to disable backward tracking. This option isn't available for CPU tracking. * Random number generator seed (RNG): CPU and GPU use different RNG implementations,< + assert_inputs_exist, so the same `--seed` is reproducible within a backend but does not guarantee identical streamlines across CPU vs GPU tracking. @@ -66,11 +67,14 @@ from nibabel.streamlines import TrkFile, detect_format from dipy.data import get_sphere +from dipy.io.stateful_tractogram import Space +from scilpy.reconst.utils import compute_sf_threshold_mask from dipy.tracking import utils as track_utils from dipy.tracking.local_tracking import LocalTracking from dipy.tracking.stopping_criterion import BinaryStoppingCriterion from dipy.tracking.tracker import eudx_tracking from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_sphere_arg, add_verbose_arg, assert_headers_compatible, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, @@ -187,18 +191,44 @@ def main(): "Ignoring.") args.save_seeds = False + sh_basis, is_legacy = parse_sh_basis_arg(args) + # Make sure the data is isotropic. Else, the strategy used # when providing information to dipy (i.e. working as if in voxel space) # will not yield correct results. Tracking is performed in voxel space # in both the GPU and CPU cases. - odf_sh_img = nib.load(args.in_odf) - if not np.allclose(np.mean(odf_sh_img.header.get_zooms()[:3]), - odf_sh_img.header.get_zooms()[0], atol=1e-03): + odf_sh_simg = StatefulImage.load(args.in_odf, is_orientation=True, + is_world_space=not args.is_voxel_space, + sh_basis=sh_basis) + if not np.allclose(np.mean(odf_sh_simg.header.get_zooms()[:3]), + odf_sh_simg.header.get_zooms()[0], atol=1e-03): parser.error( 'ODF SH file is not isotropic. Tracking cannot be ran robustly.') logging.debug("Loading masks and finding seeds.") - mask_data = get_data_as_mask(nib.load(args.in_mask), dtype=bool) + mask_simg = StatefulImage.load(args.in_mask) + mask_simg.reorient(odf_sh_simg.axcodes) + mask_data = get_data_as_mask(mask_simg, dtype=bool) + + # ODF data + odf_sh_data = odf_sh_simg.to_voxel_direction( + sh_basis=sh_basis, nbr_processes=1).astype(np.float32) + + sf_mask = None + if args.global_sf_rel_thr is not None or args.global_sf_abs_thr is not None: + sf_mask, global_max, threshold = compute_sf_threshold_mask( + odf_sh_data, sphere_name=args.sphere, + relative_factor=args.global_sf_rel_thr, + absolute_threshold=args.global_sf_abs_thr, sh_basis=sh_basis, + is_legacy=is_legacy) + logging.info("Global SF threshold mask: Global Max SF amplitude: {:.4f}" + .format(global_max)) + if args.global_sf_rel_thr is not None: + logging.info("Global SF threshold mask: Computed threshold: {:.4f} " + "(Factor: {})".format(threshold, args.global_sf_rel_thr)) + else: + logging.info("Global SF threshold mask: Absolute threshold: {:.4f}" + .format(args.global_sf_abs_thr)) if args.npv: nb_seeds = args.npv @@ -210,26 +240,31 @@ def main(): nb_seeds = 1 seed_per_vox = True - voxel_size = odf_sh_img.header.get_zooms()[0] + voxel_size = odf_sh_simg.header.get_zooms()[0] vox_step_size = args.step_size / voxel_size - seed_img = nib.load(args.in_seed) - - sh_basis, is_legacy = parse_sh_basis_arg(args) + seed_simg = StatefulImage.load(args.in_seed) + seed_simg.reorient(odf_sh_simg.axcodes) - if np.count_nonzero(seed_img.get_fdata(dtype=np.float32)) == 0: + if np.count_nonzero(seed_simg.get_fdata(dtype=np.float32)) == 0: raise IOError('The image {} is empty. ' 'It can\'t be loaded as ' 'seeding mask.'.format(args.in_seed)) - # Note. Seeds are in voxel world, center origin. - # (See the examples in random_seeds_from_mask). + # Note. Seeds are in world space (RASMM) for CPU, and voxel space for GPU. + # Both use center origin. logging.info("Preparing seeds.") + # Always track in voxel space to avoid affine-related orientation issues + # and match the voxel-oriented ODF data. + tracking_space = Space.VOX + tracking_affine = np.eye(4) + if args.in_custom_seeds: seeds = np.squeeze(load_matrix_in_any_format(args.in_custom_seeds)) else: + # Use identity affine to get seeds in voxel space seeds = track_utils.random_seeds_from_mask( - seed_img.get_fdata(dtype=np.float32), - np.eye(4), + seed_simg.get_fdata(dtype=np.float32), + tracking_affine, seeds_count=nb_seeds, seed_count_per_voxel=seed_per_vox, random_seed=args.seed) @@ -239,22 +274,26 @@ def main(): # LocalTracking.maxlen is actually the maximum length # per direction, we need to filter post-tracking. max_steps_per_direction = int(args.max_length / args.step_size) - stopping_criterion = BinaryStoppingCriterion(mask_data) + + combined_mask = mask_data + if sf_mask is not None: + combined_mask = np.logical_and(mask_data, sf_mask) + + stopping_criterion = BinaryStoppingCriterion(combined_mask) logging.info("Starting CPU local tracking.") if args.algo == 'eudx': streamlines_generator = eudx_tracking( seeds, stopping_criterion, - np.eye(4), + tracking_affine, pam=get_direction_getter( - args.in_odf, args.algo, args.sphere, + odf_sh_data, args.algo, args.sphere, args.sub_sphere, args.theta, sh_basis, voxel_size, args.sf_threshold, args.sh_to_pmf, args.probe_length, args.probe_radius, args.probe_quality, args.probe_count, args.support_exponent, is_legacy=is_legacy), - max_cross=1, max_len=max_steps_per_direction, step_size=vox_step_size, max_angle=get_theta(args.theta, args.algo), @@ -264,14 +303,14 @@ def main(): else: streamlines_generator = LocalTracking( get_direction_getter( - args.in_odf, args.algo, args.sphere, + odf_sh_data, args.algo, args.sphere, args.sub_sphere, args.theta, sh_basis, voxel_size, args.sf_threshold, args.sh_to_pmf, args.probe_length, args.probe_radius, args.probe_quality, args.probe_count, args.support_exponent, is_legacy=is_legacy), stopping_criterion, - seeds, np.eye(4), + seeds, tracking_affine, step_size=vox_step_size, max_cross=1, maxlen=max_steps_per_direction, fixedstep=True, return_all=True, @@ -283,15 +322,12 @@ def main(): # to agree with DIPY's implementation max_strl_len = int(2.0 * args.max_length / args.step_size) + 1 - # data volume - odf_sh = odf_sh_img.get_fdata(dtype=np.float32) - # GPU tracking needs the full sphere sphere = get_sphere(name=args.sphere).subdivide(n=args.sub_sphere) logging.info("Starting GPU local tracking.") streamlines_generator = GPUTracker( - odf_sh, mask_data, seeds, + odf_sh_data, mask_data, seeds, vox_step_size, max_strl_len, theta=get_theta(args.theta, args.algo), sf_threshold=args.sf_threshold, @@ -306,9 +342,9 @@ def main(): # save streamlines on-the-fly to file save_tractogram(streamlines_generator, tracts_format, - odf_sh_img, total_nb_seeds, args.out_tractogram, + odf_sh_simg, total_nb_seeds, args.out_tractogram, args.min_length, args.max_length, args.compress_th, - args.save_seeds, args.verbose) + args.save_seeds, args.verbose, space=tracking_space) # Final logging logging.info('Saved tractogram to {0}.'.format(args.out_tractogram)) diff --git a/src/scilpy/cli/scil_tracking_local_dev.py b/src/scilpy/cli/scil_tracking_local_dev.py index 6b36fcb63..470c0fecc 100755 --- a/src/scilpy/cli/scil_tracking_local_dev.py +++ b/src/scilpy/cli/scil_tracking_local_dev.py @@ -63,20 +63,19 @@ import json import dipy.core.geometry as gm +from dipy.io.stateful_tractogram import Space, Origin import nibabel as nib -import numpy as np - -from dipy.io.stateful_tractogram import StatefulTractogram, Space -from dipy.io.stateful_tractogram import Origin -from dipy.io.streamline import save_tractogram from nibabel.streamlines import detect_format, TrkFile +import numpy as np from scilpy.io.image import assert_same_resolution +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_processes_arg, add_sphere_arg, add_verbose_arg, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, verify_compression_th, load_matrix_in_any_format) +from scilpy.reconst.utils import compute_sf_threshold_mask from scilpy.image.volume_space_management import DataVolume from scilpy.tracking.propagator import ODFPropagator from scilpy.tracking.rap import RAPContinue, RAPSwitch @@ -87,7 +86,8 @@ add_tracking_options, get_theta, verify_streamline_length_options, - verify_seed_options) + verify_seed_options, + save_tractogram) from scilpy.version import version_string from scilpy.image.labels import get_data_as_labels from scilpy.io.image import get_data_as_mask @@ -114,7 +114,7 @@ def _build_arg_parser(): track_g.add_argument('--sfthres_init', metavar='sf_th', type=float, default=0.5, dest='sf_threshold_init', help="Spherical function relative threshold value " - "for the \ninitial direction. [%(default)s]") + "within each voxel for the \ninitial direction. [%(default)s]") track_g.add_argument('--rk_order', metavar="K", type=int, default=1, choices=[1, 2, 4], help="The order of the Runge-Kutta integration used " @@ -242,26 +242,27 @@ def main(): # ------- PREPARING DATA ------- theta = gm.math.radians(get_theta(args.theta, args.algo)) - max_nbr_pts = int(args.max_length / args.step_size) - min_nbr_pts = max(int(args.min_length / args.step_size), 1) - if args.in_odf: - assert_same_resolution([args.in_mask, args.in_odf, args.in_seed]) - - # Choosing our space and origin for this tracking - # If save_seeds, space and origin must be vox, center. Choosing those - # values. + # Always track in voxel space to avoid affine-related orientation issues + # and match the voxel-oriented ODF data. our_space = Space.VOX - our_origin = Origin('center') + our_origin = Origin.NIFTI logging.info("Loading seeding mask.") - seed_img = nib.load(args.in_seed) - seed_data = seed_img.get_fdata(caching='unchanged', dtype=float) + seed_simg = StatefulImage.load(args.in_seed) + seed_data = seed_simg.get_fdata(caching='unchanged', dtype=float) if np.count_nonzero(seed_data) == 0: raise IOError('The image {} is empty. ' 'It can\'t be loaded as ' 'seeding mask.'.format(args.in_seed)) - seed_res = seed_img.header.get_zooms()[:3] + seed_res = seed_simg.header.get_zooms()[:3] + voxel_size = np.average(seed_res) + vox_step_size = args.step_size / voxel_size + + max_nbr_pts = int(args.max_length / args.step_size) + min_nbr_pts = max(int(args.min_length / args.step_size), 1) + if args.in_odf: + assert_same_resolution([args.in_mask, args.in_odf, args.in_seed]) # ------- INSTANTIATING SEED GENERATOR ------- if args.in_custom_seeds: @@ -270,7 +271,9 @@ def main(): origin=our_origin) nbr_seeds = len(seeds) else: + # Use identity affine for voxel space seeding seed_generator = SeedGenerator(seed_data, seed_res, + affine=np.eye(4), space=our_space, origin=our_origin, n_repeats=args.n_repeats_per_seed) @@ -288,18 +291,56 @@ def main(): ' value > 0.'.format(args.in_seed)) logging.info("Loading tracking mask.") - mask_img = nib.load(args.in_mask) - mask_data = mask_img.get_fdata(caching='unchanged', dtype=float) - mask_res = mask_img.header.get_zooms()[:3] - mask = DataVolume(mask_data, mask_res, args.mask_interp) + mask_simg = StatefulImage.load(args.in_mask) + mask_simg.reorient(seed_simg.axcodes) + mask_data = mask_simg.get_fdata(caching='unchanged', dtype=float) + mask_res = mask_simg.header.get_zooms()[:3] + # Use identity affine for DataVolume to match voxel space tracking + mask = DataVolume(mask_data, mask_res, affine=np.eye(4), + interpolation=args.mask_interp) + + sh_basis, is_legacy = parse_sh_basis_arg(args) # ------- INSTANTIATING PROPAGATOR ------- if args.in_odf: logging.info("Loading ODF SH data.") - odf_sh_img = nib.load(args.in_odf) - odf_sh_data = odf_sh_img.get_fdata(caching='unchanged', dtype=float) - odf_sh_res = odf_sh_img.header.get_zooms()[:3] - dataset = DataVolume(odf_sh_data, odf_sh_res, args.sh_interp) + odf_sh_simg = StatefulImage.load(args.in_odf, is_orientation=True, + is_world_space=not args.is_voxel_space, + sh_basis=sh_basis) + odf_sh_simg.reorient(seed_simg.axcodes) + odf_sh_data = odf_sh_simg.to_voxel_direction( + sh_basis=sh_basis, nbr_processes=1).astype(np.float32) + + sf_mask = None + if args.global_sf_rel_thr is not None or args.global_sf_abs_thr is not None: + sf_mask, global_max, threshold = compute_sf_threshold_mask( + odf_sh_data, sphere_name=args.sphere, + relative_factor=args.global_sf_rel_thr, + absolute_threshold=args.global_sf_abs_thr, sh_basis=sh_basis, + is_legacy=is_legacy) + logging.info("Global SF threshold mask: Global Max SF amplitude: {:.4f}" + .format(global_max)) + if args.global_sf_rel_thr is not None: + logging.info("Global SF threshold mask: Computed threshold: {:.4f} " + "(Factor: {})".format(threshold, args.global_sf_rel_thr)) + else: + logging.info("Global SF threshold mask: Absolute threshold: {:.4f}" + .format(args.global_sf_abs_thr)) + + # Re-instantiate DataVolume with original mask_data + mask = DataVolume(mask_data, mask_res, affine=np.eye(4), + interpolation=args.mask_interp) + + if sf_mask is not None: + # Mask the stopping criterion + mask_data = np.logical_and(mask_data, sf_mask) + mask = DataVolume(mask_data, mask_res, affine=np.eye(4), + interpolation=args.mask_interp) + + odf_sh_res = odf_sh_simg.header.get_zooms()[:3] + # Use identity affine for DataVolume to match voxel space tracking + dataset = DataVolume(odf_sh_data, odf_sh_res, affine=np.eye(4), + interpolation=args.sh_interp) logging.info("Instantiating propagator.") # Converting step size to vox space @@ -307,12 +348,7 @@ def main(): # 1e-3. assert np.allclose(np.mean(odf_sh_res[:3]), odf_sh_res, atol=1e-03) - voxel_size = odf_sh_img.header.get_zooms()[0] - vox_step_size = args.step_size / voxel_size - - # Using space and origin in the propagator: vox and center, like - # in dipy. - sh_basis, is_legacy = parse_sh_basis_arg(args) + # Using space and origin in the propagator: VOX and NIFTI. propagator = ODFPropagator( dataset, vox_step_size, args.rk_order, args.algo, sh_basis, @@ -333,24 +369,24 @@ def main(): if filename not in loaded_datasets: odf_sh_img = nib.load(filename) odf_sh_res = odf_sh_img.header.get_zooms()[:3] - voxel_size = odf_sh_img.header.get_zooms()[0] - vox_step_size = cfg.get('step_size', args.step_size) / voxel_size + # Use identity affine for DataVolume to match voxel space tracking loaded_datasets[filename] = DataVolume( odf_sh_img.get_fdata(caching='unchanged', dtype=float), - odf_sh_res, args.sh_interp) + odf_sh_res, affine=np.eye(4), + interpolation=args.sh_interp) # Get params from rap_policies file + algo = cfg.get('algo', args.algo) + theta = gm.math.radians(get_theta(cfg.get('theta', args.theta), algo)) sh_basis_name = cfg.get('sh_basis', 'descoteaux07_legacy') sh_basis = ('descoteaux07' if 'descoteaux07' in sh_basis_name else 'tournier07') - algo = cfg.get('algo', args.algo) - theta = gm.math.radians(get_theta(cfg.get('theta', args.theta), algo)) is_legacy = 'legacy' in sh_basis_name # Build propagator from rap_policies file propagators[label] = ODFPropagator( - loaded_datasets[filename], vox_step_size, args.rk_order, - algo, sh_basis, args.sf_threshold, + loaded_datasets[filename], cfg.get('step_size', args.step_size) / voxel_size, + args.rk_order, algo, sh_basis, args.sf_threshold, args.sf_threshold_init, theta, args.sphere, sub_sphere=args.sub_sphere, space=our_space, origin=our_origin, is_legacy=is_legacy) @@ -374,7 +410,10 @@ def main(): rap_img = nib.load(args.rap_mask) rap_mask_data = get_data_as_mask(rap_img) rap_mask_res = rap_img.header.get_zooms()[:3] - rap_volume = DataVolume(rap_mask_data, rap_mask_res, args.mask_interp) + # Use identity affine for DataVolume to match voxel space tracking + rap_volume = DataVolume(rap_mask_data, rap_mask_res, + affine=np.eye(4), + interpolation=args.mask_interp) elif args.rap_labels: logging.info("Loading RAP labels.") rap_label_img = nib.load(args.rap_labels) @@ -386,7 +425,10 @@ def main(): rap_label_data = get_data_as_labels(rap_label_img) rap_label_res = rap_label_img.header.get_zooms()[:3] - rap_volume = DataVolume(rap_label_data, rap_label_res, 'nearest') + # Use identity affine for DataVolume to match voxel space tracking + rap_volume = DataVolume(rap_label_data, rap_label_res, + affine=np.eye(4), + interpolation='nearest') if args.rap_method == "continue": rap = RAPContinue(rap_volume, propagator, max_nbr_pts, @@ -397,11 +439,13 @@ def main(): rap = None logging.info("Instantiating tracker.") + # We must force save_seeds=True so that Tracker returns (streamlines, seeds) + # as expected by scilpy.tracking.utils.save_tractogram tracker = Tracker(propagator, mask, seed_generator, nbr_seeds, min_nbr_pts, max_nbr_pts, args.max_invalid_nb_points, - compression_th=args.compress_th, + compression_th=None, nbr_processes=args.nbr_processes, - save_seeds=args.save_seeds, + save_seeds=True, mmap_mode='r+', rng_seed=args.rng_seed, track_forward_only=args.forward_only, skip=args.skip, @@ -418,27 +462,16 @@ def main(): .format(len(streamlines), nbr_seeds, str_time)) # save seeds if args.save_seeds is given - # We seeded (and tracked) in vox, center, which is what is expected for - # seeds. - if args.save_seeds: - data_per_streamline = {'seeds': seeds} - else: - data_per_streamline = {} # Save RAP entry/exit mask if requested if args.rap_save_entry_exit: - tracker.save_rap_entry_exit_mask(args.rap_save_entry_exit, mask_img) - - # Compared with scil_tracking_local, using sft rather than - # LazyTractogram to deal with space. - # Contrary to scilpy or dipy, where space after tracking is vox, here - # space after tracking is voxmm. - # Smallest possible streamline coordinate is (0,0,0), equivalent of - # corner origin (TrackVis) - sft = StatefulTractogram(streamlines, mask_img, - space=our_space, origin=our_origin, - data_per_streamline=data_per_streamline) - save_tractogram(sft, args.out_tractogram) + tracker.save_rap_entry_exit_mask(args.rap_save_entry_exit, mask_simg) + + # save streamlines on-the-fly to file + save_tractogram(zip(streamlines, seeds), tracts_format, + odf_sh_simg, nbr_seeds, args.out_tractogram, + args.min_length, args.max_length, args.compress_th, + args.save_seeds, args.verbose, space=our_space) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_tracking_pft.py b/src/scilpy/cli/scil_tracking_pft.py index 1161c2ca8..1b994bacb 100755 --- a/src/scilpy/cli/scil_tracking_pft.py +++ b/src/scilpy/cli/scil_tracking_pft.py @@ -38,24 +38,24 @@ from dipy.data import get_sphere, HemiSphere from dipy.direction import (ProbabilisticDirectionGetter, DeterministicMaximumDirectionGetter) -from dipy.io.utils import (get_reference_info, - create_tractogram_header) from dipy.tracking.local_tracking import ParticleFilteringTracking from dipy.tracking.stopping_criterion import (ActStoppingCriterion, CmcStoppingCriterion) +from dipy.io.stateful_tractogram import Space from dipy.tracking import utils as track_utils -from dipy.tracking.streamlinespeed import length, compress_streamlines import nibabel as nib -from nibabel.streamlines import LazyTractogram +from nibabel.streamlines import detect_format import numpy as np from scilpy.io.image import get_data_as_mask -from scilpy.io.utils import (add_overwrite_arg, add_sh_basis_args, - add_verbose_arg, assert_inputs_exist, +from scilpy.io.stateful_image import StatefulImage +from scilpy.io.utils import (add_verbose_arg, assert_inputs_exist, assert_outputs_exist, parse_sh_basis_arg, - assert_headers_compatible, add_compression_arg, + assert_headers_compatible, verify_compression_th) -from scilpy.tracking.utils import get_theta +from scilpy.tracking.utils import (add_out_options, get_theta, + add_tracking_options, + save_tractogram) from scilpy.version import version_string @@ -81,34 +81,18 @@ def _build_arg_parser(): p.add_argument('out_tractogram', help='Tractogram output file (must be .trk or .tck).') - track_g = p.add_argument_group('Tracking options') + track_g = add_tracking_options(p) track_g.add_argument('--algo', default='prob', choices=['det', 'prob'], help='Algorithm to use (must be "det" or "prob"). ' '[%(default)s]') - track_g.add_argument('--step', dest='step_size', type=float, default=0.2, - help='Step size in mm. [%(default)s]') - track_g.add_argument('--min_length', type=float, default=10., - help='Minimum length of a streamline in mm. ' - '[%(default)s]') - track_g.add_argument('--max_length', type=float, default=300., - help='Maximum length of a streamline in mm. ' - '[%(default)s]') - track_g.add_argument('--theta', type=float, - help='Maximum angle between 2 steps. ' - '["det"=45, "prob"=20]') track_g.add_argument('--act', action='store_true', help='If set, uses anatomically-constrained ' 'tractography (ACT) \ninstead of continuous map ' 'criterion (CMC).') - track_g.add_argument('--sfthres', dest='sf_threshold', - type=float, default=0.1, - help='Spherical function relative threshold. ' - '[%(default)s]') track_g.add_argument('--sfthres_init', dest='sf_threshold_init', type=float, default=0.5, help='Spherical function relative threshold value ' - 'for the \ninitial direction. [%(default)s]') - add_sh_basis_args(track_g) + 'within each voxel for the \ninitial direction. [%(default)s]') seed_group = p.add_argument_group( 'Seeding options', @@ -130,19 +114,13 @@ def _build_arg_parser(): help='Length of PFT forward tracking (mm). ' '[%(default)s]') - out_g = p.add_argument_group('Output options') + out_g = add_out_options(p) out_g.add_argument('--all', dest='keep_all', action='store_true', help='If set, keeps "excluded" streamlines.\n' 'NOT RECOMMENDED, except for debugging.') out_g.add_argument('--seed', type=int, help='Random number generator seed.') - add_overwrite_arg(out_g) - out_g.add_argument('--save_seeds', action='store_true', - help='If set, save the seeds used for the tracking \n ' - 'in the data_per_streamline property.') - - add_compression_arg(out_g) add_verbose_arg(p) return p @@ -187,9 +165,13 @@ def main(): if args.nt and args.nt <= 0: parser.error('Total number of seeds must be > 0.') - fodf_sh_img = nib.load(args.in_sh) - if not np.allclose(np.mean(fodf_sh_img.header.get_zooms()[:3]), - fodf_sh_img.header.get_zooms()[0], atol=1e-03): + sh_basis, is_legacy = parse_sh_basis_arg(args) + + fodf_sh_simg = StatefulImage.load(args.in_sh, is_orientation=True, + is_world_space=not args.is_voxel_space, + sh_basis=sh_basis) + if not np.allclose(np.mean(fodf_sh_simg.header.get_zooms()[:3]), + fodf_sh_simg.header.get_zooms()[0], atol=1e-03): parser.error( 'SH file is not isotropic. Tracking cannot be ran robustly.') @@ -199,8 +181,6 @@ def main(): if not np.allclose(np.linalg.norm(tracking_sphere.vertices, axis=1), 1.): raise RuntimeError('Tracking sphere should be unit normed.') - sh_basis, is_legacy = parse_sh_basis_arg(args) - if args.algo == 'det': dgklass = DeterministicMaximumDirectionGetter else: @@ -213,7 +193,7 @@ def main(): # relative_peak_threshold is for initial directions filtering # min_separation_angle is the initial separation angle for peak extraction dg = dgklass.from_shcoeff( - fodf_sh_img.get_fdata(dtype=np.float32), + fodf_sh_simg.to_voxel_direction(sh_basis=sh_basis, nbr_processes=1), max_angle=theta, sphere=tracking_sphere, basis_type=sh_basis, @@ -221,20 +201,43 @@ def main(): pmf_threshold=args.sf_threshold, relative_peak_threshold=args.sf_threshold_init) - map_include_img = nib.load(args.in_map_include) - map_exclude_img = nib.load(args.map_exclude_file) - voxel_size = np.average(map_include_img.header['pixdim'][1:4]) + map_include_simg = StatefulImage.load(args.in_map_include) + map_include_simg.reorient(fodf_sh_simg.axcodes) + map_exclude_simg = StatefulImage.load(args.map_exclude_file) + map_exclude_simg.reorient(fodf_sh_simg.axcodes) + + map_include_data = map_include_simg.get_fdata(dtype=np.float32) + map_exclude_data = map_exclude_simg.get_fdata(dtype=np.float32) + + sf_mask = None + + # In PFT, exclude map = 1 and include map = 0 ensures stopping and excluding. + # Apply to maps only for stopping criterion. + if sf_mask is not None: + map_include_data[~sf_mask] = 0 + map_exclude_data[~sf_mask] = 1 + + voxel_size = np.average(fodf_sh_simg.header.get_zooms()[:3]) + vox_step_size = args.step_size / voxel_size + + # Always track in voxel space to avoid affine-related orientation issues + # and match the voxel-oriented ODF data. + tracking_space = Space.VOX + tracking_affine = np.eye(4) if not args.act: + # tissue_classifier expects parameters in the tracking space. + # Since we track in voxel space (identity affine), we use + # vox_step_size and average_voxel_size = 1.0. tissue_classifier = CmcStoppingCriterion( - map_include_img.get_fdata(dtype=np.float32), - map_exclude_img.get_fdata(dtype=np.float32), - step_size=args.step_size, - average_voxel_size=voxel_size) + map_include_data, + map_exclude_data, + step_size=vox_step_size, + average_voxel_size=1.0) else: tissue_classifier = ActStoppingCriterion( - map_include_img.get_fdata(dtype=np.float32), - map_exclude_img.get_fdata(dtype=np.float32)) + map_include_data, + map_exclude_data) if args.npv: nb_seeds = args.npv @@ -246,64 +249,45 @@ def main(): nb_seeds = 1 seed_per_vox = True - voxel_size = fodf_sh_img.header.get_zooms()[0] - vox_step_size = args.step_size / voxel_size - seed_img = nib.load(args.in_seed) + seed_simg = StatefulImage.load(args.in_seed) + seed_simg.reorient(fodf_sh_simg.axcodes) + seeds = track_utils.random_seeds_from_mask( - get_data_as_mask(seed_img, dtype=bool), - np.eye(4), + get_data_as_mask(seed_simg, dtype=bool), + tracking_affine, seeds_count=nb_seeds, seed_count_per_voxel=seed_per_vox, random_seed=args.seed) + total_nb_seeds = len(seeds) # Note that max steps is used once for the forward pass, and # once for the backwards. This doesn't, in fact, control the real # max length max_steps = int(args.max_length / args.step_size) + 1 + # We must force save_seeds=True so that the generator yields (strl, seed) + # as expected by scilpy.tracking.utils.save_tractogram pft_streamlines = ParticleFilteringTracking( dg, tissue_classifier, seeds, - np.eye(4), + tracking_affine, max_cross=1, step_size=vox_step_size, maxlen=max_steps, - pft_back_tracking_dist=args.back_tracking, - pft_front_tracking_dist=args.forward_tracking, + pft_back_tracking_dist=args.back_tracking / voxel_size, + pft_front_tracking_dist=args.forward_tracking / voxel_size, particle_count=args.particles, return_all=args.keep_all, random_seed=args.seed, - save_seeds=args.save_seeds) + save_seeds=True) - scaled_min_length = args.min_length / voxel_size - scaled_max_length = args.max_length / voxel_size + tracts_format = detect_format(args.out_tractogram) - if args.save_seeds: - filtered_streamlines, seeds = \ - zip(*((s, p) for s, p in pft_streamlines - if scaled_min_length <= length(s) <= scaled_max_length)) - data_per_streamlines = {'seeds': lambda: seeds} - else: - filtered_streamlines = \ - (s for s in pft_streamlines - if scaled_min_length <= length(s) <= scaled_max_length) - data_per_streamlines = {} - - if args.compress_th: - filtered_streamlines = ( - compress_streamlines(s, args.compress_th) - for s in filtered_streamlines) - - tractogram = LazyTractogram(lambda: filtered_streamlines, - data_per_streamlines, - affine_to_rasmm=seed_img.affine) - - filetype = nib.streamlines.detect_format(args.out_tractogram) - reference = get_reference_info(seed_img) - header = create_tractogram_header(filetype, *reference) - - # Use generator to save the streamlines on-the-fly - nib.streamlines.save(tractogram, args.out_tractogram, header=header) + # save streamlines on-the-fly to file + save_tractogram(pft_streamlines, tracts_format, + fodf_sh_simg, total_nb_seeds, args.out_tractogram, + args.min_length, args.max_length, args.compress_th, + args.save_seeds, args.verbose, space=tracking_space) if __name__ == '__main__': diff --git a/src/scilpy/cli/scil_tractogram_compute_TODI.py b/src/scilpy/cli/scil_tractogram_compute_TODI.py index d30473356..48453b596 100755 --- a/src/scilpy/cli/scil_tractogram_compute_TODI.py +++ b/src/scilpy/cli/scil_tractogram_compute_TODI.py @@ -24,6 +24,7 @@ import numpy as np from scilpy.io.image import get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.io.streamlines import load_tractogram_with_reference from scilpy.io.utils import (add_overwrite_arg, add_reference_arg, add_sh_basis_args, add_verbose_arg, @@ -139,12 +140,15 @@ def main(): if args.out_todi_sh: sh_basis, is_legacy = parse_sh_basis_arg(args) - img = todi_obj.get_sh(sh_basis, args.sh_order, - full_basis=args.asymmetric, - is_legacy=is_legacy) - img = todi_obj.reshape_to_3d(img) - img = nib.Nifti1Image(img.astype(np.float32), affine) - img.to_filename(args.out_todi_sh) + data = todi_obj.get_sh(sh_basis, args.sh_order, + full_basis=args.asymmetric, + is_legacy=is_legacy) + data = todi_obj.reshape_to_3d(data).astype(np.float32) + simg = StatefulImage(data, affine, sh_basis=sh_basis, + is_legacy=is_legacy, is_orientation=True, + is_world_space=False) + simg.to_world_direction() + nib.save(simg, args.out_todi_sh) if args.out_tdi: img = todi_obj.get_tdi() @@ -153,6 +157,7 @@ def main(): img.to_filename(args.out_tdi) if args.out_todi_sf: + # SF rotation is not yet supported in StatefulImage img = todi_obj.get_todi() img = todi_obj.reshape_to_3d(img) img = nib.Nifti1Image(img.astype(np.float32), affine) diff --git a/src/scilpy/cli/scil_tractogram_flip.py b/src/scilpy/cli/scil_tractogram_flip.py index 82531564e..b0d6394e8 100755 --- a/src/scilpy/cli/scil_tractogram_flip.py +++ b/src/scilpy/cli/scil_tractogram_flip.py @@ -11,10 +11,10 @@ import argparse import logging -from dipy.io.streamline import save_tractogram - -from scilpy.io.streamlines import load_tractogram_with_reference -from scilpy.io.utils import (add_reference_arg, +from scilpy.io.streamlines import (load_tractogram_with_reference, + save_tractogram) +from scilpy.io.utils import (add_bbox_arg, + add_reference_arg, add_verbose_arg, add_overwrite_arg, assert_inputs_exist, @@ -39,6 +39,7 @@ def _build_arg_parser(): 'and y axes use: x y.') add_reference_arg(p) + add_bbox_arg(p) add_verbose_arg(p) add_overwrite_arg(p) @@ -58,7 +59,8 @@ def main(): sft.to_corner() new_sft = flip_sft(sft, args.axes) - save_tractogram(new_sft, args.out_tractogram) + save_tractogram(new_sft, args.out_tractogram, False, + bbox_valid_check=args.bbox_check) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_tractogram_project_map_to_streamlines.py b/src/scilpy/cli/scil_tractogram_project_map_to_streamlines.py index f56852447..e896a992b 100755 --- a/src/scilpy/cli/scil_tractogram_project_map_to_streamlines.py +++ b/src/scilpy/cli/scil_tractogram_project_map_to_streamlines.py @@ -136,7 +136,8 @@ def main(): else: interp = "nearest" - map_volume = DataVolume(map_data, map_res, interp) + map_volume = DataVolume(map_data, map_res, map_img.affine, + interpolation=interp) logging.info("Projecting map onto streamlines") streamline_data = project_map_to_streamlines( diff --git a/src/scilpy/cli/scil_viz_bingham_fit.py b/src/scilpy/cli/scil_viz_bingham_fit.py index 9bd0cd868..384a30b8f 100755 --- a/src/scilpy/cli/scil_viz_bingham_fit.py +++ b/src/scilpy/cli/scil_viz_bingham_fit.py @@ -14,8 +14,6 @@ import argparse import logging -import nibabel as nib - from dipy.data import get_sphere, SPHERE_FILES from scilpy.io.utils import (add_overwrite_arg, @@ -24,6 +22,7 @@ assert_outputs_exist) from scilpy.utils.spatial import RAS_AXES_NAMES from scilpy.utils.spatial import get_axis_index +from scilpy.io.stateful_image import StatefulImage from scilpy.version import version_string from scilpy.viz.backends.fury import (create_interactive_window, @@ -78,6 +77,11 @@ def _build_arg_parser(): help='Color each bingham distribution with a ' 'different color. [%(default)s]') + p.add_argument('--is_voxel_space', action='store_true', + help='If set, assumes the input Bingham parameters are ' + 'already in \nvoxel space. Default assumes world ' + 'space (RAS).') + return p @@ -95,7 +99,10 @@ def _get_data_from_inputs(args): """ Load data given by args. """ - bingham = nib.load(args.in_bingham).get_fdata() + simg = StatefulImage.load(args.in_bingham, is_orientation=True, + is_world_space=not args.is_voxel_space) + simg.to_ras() + bingham = simg.to_voxel_direction() if not args.slice_index: slice_index = bingham.shape[get_axis_index(args.axis_name)] // 2 else: @@ -103,7 +110,7 @@ def _get_data_from_inputs(args): bingham = bingham[_get_slicing_for_axis(args.axis_name, slice_index, bingham.shape)] - return bingham + return bingham, simg.affine def main(): @@ -118,18 +125,20 @@ def main(): assert_inputs_exist(parser, args.in_bingham) assert_outputs_exist(parser, args, [], args.output) - data = _get_data_from_inputs(args) + data, affine = _get_data_from_inputs(args) sph = get_sphere(name=args.sphere) actors = create_bingham_slicer(data, args.axis_name, args.slice_index, sph, - color_per_lobe=args.color_per_lobe) + color_per_lobe=args.color_per_lobe, + affine=None) # Prepare and display the scene scene = create_scene(actors, args.axis_name, args.slice_index, data.shape[:3], - args.win_dims[0] / args.win_dims[1]) + args.win_dims[0] / args.win_dims[1], + affine=None) if not args.silent: create_interactive_window( diff --git a/src/scilpy/cli/scil_viz_fodf.py b/src/scilpy/cli/scil_viz_fodf.py index 753eae950..8117652d2 100755 --- a/src/scilpy/cli/scil_viz_fodf.py +++ b/src/scilpy/cli/scil_viz_fodf.py @@ -22,7 +22,6 @@ import argparse import logging -import nibabel as nib import numpy as np from dipy.data import get_sphere @@ -36,6 +35,7 @@ parse_sh_basis_arg, assert_headers_compatible) from scilpy.io.image import assert_same_resolution, get_data_as_mask +from scilpy.io.stateful_image import StatefulImage from scilpy.utils.spatial import RAS_AXES_NAMES from scilpy.version import version_string from scilpy.viz.backends.fury import (create_interactive_window, @@ -220,7 +220,12 @@ def _get_data_from_inputs(args): Load data given by args. Perform checks to ensure dimensions agree between the data for mask, background, peaks and fODF. """ - fodf = nib.load(args.in_fodf).get_fdata(dtype=np.float32) + sh_basis, is_legacy = parse_sh_basis_arg(args) + fodf_simg = StatefulImage.load(args.in_fodf, is_orientation=True, + is_world_space=not args.is_voxel_space, + sh_basis=sh_basis, is_legacy=is_legacy) + fodf_simg.to_ras() + fodf = fodf_simg.to_voxel_direction(sh_basis=sh_basis, is_legacy=is_legacy) # Optional: bg = None @@ -231,16 +236,25 @@ def _get_data_from_inputs(args): variance = None if args.background: assert_same_resolution([args.background, args.in_fodf]) - bg = nib.load(args.background).get_fdata() + bg_simg = StatefulImage.load(args.background) + bg_simg.reorient(fodf_simg.axcodes) + bg = bg_simg.get_fdata() if args.in_transparency_mask: + tm_simg = StatefulImage.load(args.in_transparency_mask) + tm_simg.reorient(fodf_simg.axcodes) transparency_mask = get_data_as_mask( - nib.load(args.in_transparency_mask), dtype=bool) + tm_simg, dtype=bool) if args.mask: assert_same_resolution([args.mask, args.in_fodf]) - mask = get_data_as_mask(nib.load(args.mask), dtype=bool) + mask_simg = StatefulImage.load(args.mask) + mask_simg.reorient(fodf_simg.axcodes) + mask = get_data_as_mask(mask_simg, dtype=bool) if args.peaks: assert_same_resolution([args.peaks, args.in_fodf]) - peaks = nib.load(args.peaks).get_fdata() + peaks_simg = StatefulImage.load(args.peaks, is_orientation=True, + is_world_space=not args.is_voxel_space) + peaks_simg.reorient(fodf_simg.axcodes) + peaks = peaks_simg.to_voxel_direction() if len(peaks.shape) == 4: last_dim = peaks.shape[-1] if last_dim % 3 == 0: @@ -250,13 +264,20 @@ def _get_data_from_inputs(args): raise ValueError('Peaks volume last dimension ({0}) cannot ' 'be reshaped as (npeaks, 3).' .format(peaks.shape[-1])) - if args.peaks_values: - assert_same_resolution([args.peaks_values, args.in_fodf]) - peak_vals =\ - nib.load(args.peaks_values).get_fdata() + if args.peaks_values: + assert_same_resolution([args.peaks_values, args.in_fodf]) + peak_vals_simg = StatefulImage.load(args.peaks_values) + peak_vals_simg.reorient(fodf_simg.axcodes) + peak_vals =\ + peak_vals_simg.get_fdata() if args.variance: assert_same_resolution([args.variance, args.in_fodf]) - variance = nib.load(args.variance).get_fdata(dtype=np.float32) + variance_simg = StatefulImage.load(args.variance, is_orientation=True, + is_world_space=not args.is_voxel_space, + sh_basis=sh_basis, is_legacy=is_legacy) + variance_simg.reorient(fodf_simg.axcodes) + variance = variance_simg.to_voxel_direction(sh_basis=sh_basis, + is_legacy=is_legacy) if len(variance.shape) == 3: variance = np.reshape(variance, variance.shape + (1,)) if variance.shape != fodf.shape: @@ -264,14 +285,15 @@ def _get_data_from_inputs(args): 'variance {1}.' .format(fodf.shape, variance.shape)) - return fodf, bg, transparency_mask, mask, peaks, peak_vals, variance + return (fodf, bg, transparency_mask, mask, peaks, peak_vals, variance, + fodf_simg.affine) def main(): parser = _build_arg_parser() args = _parse_args(parser) (fodf, bg, transparency_mask, mask, peaks, peaks_values, - variance) = _get_data_from_inputs(args) + variance, affine) = _get_data_from_inputs(args) sph = get_sphere(name=args.sphere) sh_order, full_basis = get_sh_order_and_fullness(fodf.shape[-1]) sh_basis, is_legacy = parse_sh_basis_arg(args) @@ -292,7 +314,7 @@ def main(): sh_variance=variance, mask=mask, nb_subdivide=args.sph_subdivide, radial_scale=not args.radial_scale_off, norm=not args.norm_off, colormap=args.colormap or color_rgb, variance_k=args.variance_k, - variance_color=var_color, is_legacy=is_legacy) + variance_color=var_color, is_legacy=is_legacy, affine=None) actors.append(odf_actor) # Instantiate a variance slicer actor if a variance image is supplied @@ -308,7 +330,8 @@ def main(): value_range=args.bg_range, opacity=args.bg_opacity, offset=args.bg_offset, - interpolation=args.bg_interpolation) + interpolation=args.bg_interpolation, + affine=None) actors.append(bg_actor) # Instantiate a peaks slicer actor if peaks are supplied @@ -323,7 +346,8 @@ def main(): color=args.peaks_color, peaks_width=args.peaks_width, opacity=args.peaks_opacity, - symmetric=not full_basis) + symmetric=not full_basis, + affine=None) actors.append(peaks_actor) @@ -332,20 +356,23 @@ def main(): args.slice_index, fodf.shape[:3], args.win_dims[0] / args.win_dims[1], - bg_color=args.bg_color) + bg_color=args.bg_color, + affine=None) mask_scene = None if transparency_mask is not None: mask_actor = create_texture_slicer(transparency_mask.astype("uint8"), args.axis_name, args.slice_index, - offset=0.0) + offset=0.0, + affine=None) mask_scene = create_scene([mask_actor], args.axis_name, args.slice_index, transparency_mask.shape, args.win_dims[0] / args.win_dims[1], - bg_color=args.bg_color) + bg_color=args.bg_color, + affine=None) if not args.silent: create_interactive_window(scene, args.win_dims, args.interactor) diff --git a/src/scilpy/cli/scil_volume_apply_transform.py b/src/scilpy/cli/scil_volume_apply_transform.py index e07c454d5..fd60404ee 100755 --- a/src/scilpy/cli/scil_volume_apply_transform.py +++ b/src/scilpy/cli/scil_volume_apply_transform.py @@ -10,7 +10,6 @@ import argparse import logging -import nibabel as nib import numpy as np from scilpy.image.volume_operations import apply_transform diff --git a/src/scilpy/cli/scil_volume_math.py b/src/scilpy/cli/scil_volume_math.py index e9816ae04..e84771ead 100755 --- a/src/scilpy/cli/scil_volume_math.py +++ b/src/scilpy/cli/scil_volume_math.py @@ -21,6 +21,7 @@ from scilpy.image.volume_math import (get_image_ops, get_operations_doc) from scilpy.io.image import load_img +from scilpy.io.stateful_image import StatefulImage from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, assert_outputs_exist) @@ -76,9 +77,10 @@ def main(): # Find at least one mask, but prefer a 4D mask if there is any. mask = None found_ref = False + ref_img = None for input_arg in args.in_args: if not is_float(input_arg): - ref_img = nib.load(input_arg) + ref_img, _ = load_img(input_arg) found_ref = True if mask is None: mask = np.zeros(ref_img.shape) @@ -92,19 +94,20 @@ def main(): # Load all input masks. input_img = [] for input_arg in args.in_args: + img, dtype = load_img(input_arg) if not is_float(input_arg) and \ - not is_header_compatible(ref_img, input_arg): + not is_header_compatible(ref_img, img): parser.error('Inputs do not have a compatible header.') - img, dtype = load_img(input_arg) if not isinstance(img, float): - args.data_type = img.header.get_data_dtype() if args.data_type is None else args.data_type + args.data_type = img.header.get_data_dtype() \ + if args.data_type is None else args.data_type - if isinstance(img, nib.Nifti1Image) and \ + if isinstance(img, StatefulImage) and \ dtype != ref_img.get_data_dtype() and \ not args.data_type: parser.error('Inputs do not have a compatible data type.\n' 'Use --data_type to specify output datatype.') - if args.operation in binary_op and isinstance(img, nib.Nifti1Image): + if args.operation in binary_op and isinstance(img, StatefulImage): data = img.get_fdata(dtype=np.float64) unique = np.unique(data) if not len(unique) <= 2: @@ -116,7 +119,7 @@ def main(): 'binary arrays, will be converted.\n' 'Non-zeros will be set to ones.') - if isinstance(img, nib.Nifti1Image): + if isinstance(img, StatefulImage): data = img.get_fdata(dtype=np.float64) if data.ndim == 4: mask[np.sum(data, axis=3).astype(bool) > 0] = 1 @@ -144,7 +147,10 @@ def main(): new_img = nib.Nifti1Image(output_data, ref_img.affine, header=ref_img.header) - nib.save(new_img, args.out_image) + + # Use StatefulImage.create_from to ensure original orientation + # ref_img is also a StatefulImage (loaded via load_img earlier) + StatefulImage.create_from(new_img, ref_img).save(args.out_image) if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_volume_modify_voxel_order.py b/src/scilpy/cli/scil_volume_modify_voxel_order.py index 5575f5dd8..719056b14 100644 --- a/src/scilpy/cli/scil_volume_modify_voxel_order.py +++ b/src/scilpy/cli/scil_volume_modify_voxel_order.py @@ -32,6 +32,7 @@ import argparse import logging import nibabel as nib +import numpy as np from scilpy.io.utils import (add_overwrite_arg, add_verbose_arg, @@ -54,6 +55,15 @@ def _build_arg_parser(): p.add_argument('--new_voxel_order', required=True, help='The new voxel order (e.g., "RAS", "1,2,3").') + p.add_argument('--in_bvec', + help='Path of the b-vectors file.') + p.add_argument('--in_bval', + help='Path of the b-values file.') + p.add_argument('--out_bvec', + help='Path of the modified b-vectors file to write.') + p.add_argument('--out_bval', + help='Path of the modified b-values file to write.') + add_verbose_arg(p) add_overwrite_arg(p) @@ -65,18 +75,41 @@ def main(): args = parser.parse_args() logging.getLogger().setLevel(logging.getLevelName(args.verbose)) - assert_inputs_exist(parser, args.in_image) - assert_outputs_exist(parser, args, args.out_image) + assert_inputs_exist(parser, args.in_image, args.in_bvec) + assert_outputs_exist(parser, args, args.out_image, args.out_bvec) img = nib.load(args.in_image) simg = StatefulImage.load(args.in_image) + if args.in_bvec: + bvecs = np.loadtxt(args.in_bvec) + if bvecs.shape[0] == 3 and bvecs.shape[1] != 3: + bvecs = bvecs.T + + # Create dummy bvals to satisfy StatefulImage validation + bvals = np.zeros(len(bvecs)) + simg.attach_gradients(bvals, bvecs) + parsed_voxel_order = parse_voxel_order(args.new_voxel_order, dimensions=len(img.shape)) simg.reorient(parsed_voxel_order) - nib.save(simg, args.out_image) + # To enforce the new voxel order in the header, we need to create + # a new StatefulImage, which will update the header accordingly. + new_simg = StatefulImage.convert_to_simg(simg, simg.bvals, simg.bvecs) + new_simg.save(args.out_image) + + if args.in_bvec and args.out_bvec: + if args.in_bval and args.out_bval: + new_simg.save_gradients(args.out_bval, args.out_bvec) + else: + # If no bval file or no output bval path, save only bvecs. + # new_simg.bvecs returns bvecs in the current (new) orientation. + np.savetxt(args.out_bvec, new_simg.bvecs.T, fmt='%.8f') + if args.in_bval and not args.out_bval: + logging.warning("b-values were provided but no output path " + "was specified. b-values will not be saved.") if __name__ == "__main__": diff --git a/src/scilpy/cli/scil_volume_resample.py b/src/scilpy/cli/scil_volume_resample.py index 761efa0c2..5575ac6f0 100755 --- a/src/scilpy/cli/scil_volume_resample.py +++ b/src/scilpy/cli/scil_volume_resample.py @@ -17,7 +17,6 @@ import argparse import logging -import nibabel as nib import numpy as np from scilpy.io.utils import (add_verbose_arg, add_overwrite_arg, @@ -86,12 +85,12 @@ def main(): if args.enforce_voxel_size and not args.voxel_size: parser.error("Cannot enforce voxel size without a voxel size.") - if args.volume_size and (not len(args.volume_size) == 1 and - not len(args.volume_size) == 3): + if args.volume_size and (not len(args.volume_size) == 1 + and not len(args.volume_size) == 3): parser.error('Invalid dimensions for --volume_size.') - if args.voxel_size and (not len(args.voxel_size) == 1 and - not len(args.voxel_size) == 3): + if args.voxel_size and (not len(args.voxel_size) == 1 + and not len(args.voxel_size) == 3): parser.error('Invalid dimensions for --voxel_size.') logging.info('Loading raw data from %s', args.in_image) @@ -100,15 +99,15 @@ def main(): ref_img = None if args.ref: - ref_img = nib.load(args.ref) + ref_img = StatefulImage.load(args.ref) # Must not verify that headers are compatible. But can verify that, at # least, the first columns of their affines are compatible. - img_zoom_invert = [1 / zoom for zoom in ref_img.header.get_zooms()[:3]] + img_zoom_invert = [1 / zoom for zoom in simg.header.get_zooms()[:3]] ref_zoom_invert = [1 / zoom for zoom in ref_img.header.get_zooms()[:3]] - img_affine = np.dot(simg.affine[:3, :3], img_zoom_invert) - ref_affine = np.dot(ref_img.affine[:3, :3], ref_zoom_invert) + img_affine = np.dot(simg.affine[:3, :3], np.diag(img_zoom_invert)) + ref_affine = np.dot(ref_img.affine[:3, :3], np.diag(ref_zoom_invert)) if not np.allclose(img_affine, ref_affine): parser.error("The --ref image should have the same affine as the " diff --git a/src/scilpy/cli/scil_volume_reshape.py b/src/scilpy/cli/scil_volume_reshape.py index 597e40f3f..d9bb6da68 100755 --- a/src/scilpy/cli/scil_volume_reshape.py +++ b/src/scilpy/cli/scil_volume_reshape.py @@ -75,8 +75,8 @@ def main(): assert_inputs_exist(parser, args.in_image, args.ref) assert_outputs_exist(parser, args, args.out_image) - if args.volume_size and (not len(args.volume_size) == 1 and - not len(args.volume_size) == 3): + if args.volume_size and (not len(args.volume_size) == 1 + and not len(args.volume_size) == 3): parser.error('--volume_size takes in either 1 or 3 arguments.') logging.info('Loading raw data from %s', args.in_image) diff --git a/src/scilpy/cli/tests/test_fodf_ssst.py b/src/scilpy/cli/tests/test_fodf_ssst.py index 9e8278fcb..befa3fad7 100644 --- a/src/scilpy/cli/tests/test_fodf_ssst.py +++ b/src/scilpy/cli/tests/test_fodf_ssst.py @@ -38,3 +38,20 @@ def test_execution_processing(script_runner, monkeypatch): '--sh_basis', 'tournier07', '--processes', '1', '--b0_threshold', '1', '-f']) assert not ret.success + + +def test_execution_voxel_wise_s0(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_dwi = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop_3000.nii.gz') + in_bval = os.path.join(SCILPY_HOME, 'processing', + '3000.bval') + in_bvec = os.path.join(SCILPY_HOME, 'processing', + '3000.bvec') + in_frf = os.path.join(SCILPY_HOME, 'processing', + 'frf.txt') + ret = script_runner.run(['scil_fodf_ssst', in_dwi, in_bval, in_bvec, + in_frf, 'fodf_vw.nii.gz', '--sh_order', '4', + '--sh_basis', 'tournier07', '--processes', '1', + '--voxel_wise_s0']) + assert ret.success diff --git a/src/scilpy/cli/tests/test_gradients_validate_correct.py b/src/scilpy/cli/tests/test_gradients_validate_correct.py index 8c79a6f41..e1653155b 100644 --- a/src/scilpy/cli/tests/test_gradients_validate_correct.py +++ b/src/scilpy/cli/tests/test_gradients_validate_correct.py @@ -26,27 +26,8 @@ def test_execution_processing_dti_peaks(script_runner, monkeypatch): in_bvec = os.path.join(SCILPY_HOME, 'processing', '1000.bvec') - # generate the peaks file and fa map we'll use to test our script - script_runner.run(['scil_dti_metrics', in_dwi, in_bval, in_bvec, - '--not_all', '--fa', 'fa.nii.gz', - '--evecs', 'evecs.nii.gz']) # test the actual script - ret = script_runner.run(['scil_gradients_validate_correct', in_bvec, - 'evecs_v1.nii.gz', 'fa.nii.gz', - 'bvec_corr', '-v']) - assert ret.success - - -def test_execution_processing_fodf_peaks(script_runner, monkeypatch): - monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) - in_bvec = os.path.join(SCILPY_HOME, 'processing', - 'dwi.bvec') - in_peaks = os.path.join(SCILPY_HOME, 'processing', - 'peaks.nii.gz') - in_fa = os.path.join(SCILPY_HOME, 'processing', - 'fa.nii.gz') - - # test the actual script - ret = script_runner.run(['scil_gradients_validate_correct', in_bvec, - in_peaks, in_fa, 'bvec_corr_fodf', '-v']) + ret = script_runner.run(['scil_gradients_validate_correct', + in_dwi, in_bval, in_bvec, 'bvec_corr.bvec', + '--fa_thresh', '0.5', '-v']) assert ret.success diff --git a/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py b/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py deleted file mode 100644 index 847f070e6..000000000 --- a/src/scilpy/cli/tests/test_scil_volume_modify_voxel_order.py +++ /dev/null @@ -1,53 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- - -import os -import nibabel as nib -import numpy as np -import tempfile - - -tmp_dir = tempfile.TemporaryDirectory() - - -def test_help_option(script_runner): - ret = script_runner.run(['scil_volume_modify_voxel_order', '--help']) - assert ret.success - - -def test_execution(script_runner, monkeypatch): - monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) - in_file = 'input.nii.gz' - img = nib.Nifti1Image(np.zeros((10, 20, 30)), np.eye(4)) - nib.save(img, in_file) - - # Test with character-based voxel order - out_file_lps = 'output_lps.nii.gz' - ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, - out_file_lps, '--new_voxel_order=LPS', '-f']) - assert ret.success - lps_img = nib.load(out_file_lps) - assert nib.aff2axcodes(lps_img.affine) == ('L', 'P', 'S') - - # Test with numeric voxel order - out_file_asr = 'output_asr.nii.gz' - ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, - out_file_asr, '--new_voxel_order=3,1,2', '-f']) - assert ret.success - asr_img = nib.load(out_file_asr) - assert nib.aff2axcodes(asr_img.affine) == ('S', 'R', 'A') - - # Test with negative numeric voxel order - out_file_lai = 'output_lai.nii.gz' - ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, - out_file_lai, '--new_voxel_order=-1,2,-3', - '-f']) - assert ret.success - lai_img = nib.load(out_file_lai) - assert nib.aff2axcodes(lai_img.affine) == ('L', 'A', 'I') - - # Test with invalid input - ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, - 'output.nii.gz', '--new_voxel_order=invalid', - '-f']) - assert not ret.success diff --git a/src/scilpy/cli/tests/test_sh_to_sf.py b/src/scilpy/cli/tests/test_sh_to_sf.py index 41156a8d9..5eff117c2 100644 --- a/src/scilpy/cli/tests/test_sh_to_sf.py +++ b/src/scilpy/cli/tests/test_sh_to_sf.py @@ -29,7 +29,7 @@ def test_execution_in_sphere(script_runner, monkeypatch): in_bval, '--in_b0', in_b0, '--out_bval', 'sf_724.bval', '--out_bvec', 'sf_724.bvec', '--sphere', 'symmetric724', '--dtype', 'float32', - '--processes', '1']) + '--processes', '1', '-f']) assert ret.success diff --git a/src/scilpy/cli/tests/test_tracking_local.py b/src/scilpy/cli/tests/test_tracking_local.py index 2a06a9294..64dac472a 100644 --- a/src/scilpy/cli/tests/test_tracking_local.py +++ b/src/scilpy/cli/tests/test_tracking_local.py @@ -2,11 +2,13 @@ # -*- coding: utf-8 -*- import os +import pytest import tempfile import numpy as np from scilpy import SCILPY_HOME from scilpy.io.fetcher import fetch_data, get_testing_files_dict +from scilpy.gpuparallel.opencl_utils import have_opencl # If they already exist, this only takes 5 seconds (check md5sum) fetch_data(get_testing_files_dict(), keys=['tracking.zip']) @@ -73,6 +75,7 @@ def test_execution_sphere_subdivide(script_runner, monkeypatch): assert ret.success +@pytest.mark.skipif(not have_opencl, reason='pyopencl not installed') def test_execution_sphere_gpu(script_runner, monkeypatch): monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) in_fodf = os.path.join(SCILPY_HOME, 'tracking', 'fodf.nii.gz') @@ -83,7 +86,7 @@ def test_execution_sphere_gpu(script_runner, monkeypatch): '--use_gpu', '--sphere', 'symmetric362', '--npv', '1']) - assert not ret.success + assert ret.success def test_sh_interp_without_gpu(script_runner, monkeypatch): @@ -122,6 +125,7 @@ def test_batch_size_without_gpu(script_runner, monkeypatch): assert not ret.success +@pytest.mark.skipif(not have_opencl, reason='pyopencl not installed') def test_algo_with_gpu(script_runner, monkeypatch): monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) in_fodf = os.path.join(SCILPY_HOME, 'tracking', 'fodf.nii.gz') @@ -131,7 +135,7 @@ def test_algo_with_gpu(script_runner, monkeypatch): in_mask, in_mask, 'gpu_det.trk', '--algo', 'det', '--use_gpu', '--nt', '100']) - assert not ret.success + assert ret.success def test_execution_tracking_fodf_no_compression(script_runner, monkeypatch): diff --git a/src/scilpy/cli/tests/test_tractogram_convert.py b/src/scilpy/cli/tests/test_tractogram_convert.py index 6805bc535..98ca86ef2 100644 --- a/src/scilpy/cli/tests/test_tractogram_convert.py +++ b/src/scilpy/cli/tests/test_tractogram_convert.py @@ -24,5 +24,7 @@ def test_execution_surface_vtk_fib(script_runner, monkeypatch): in_fa = os.path.join(SCILPY_HOME, 'surface_vtk_fib', 'fa.nii.gz') ret = script_runner.run(['scil_tractogram_convert', in_fib, - 'gyri_fanning.trk', '--reference', in_fa]) + 'gyri_fanning.trk', '--reference', in_fa, + '--no_bbox_check']) + assert ret.success diff --git a/src/scilpy/cli/tests/test_tractogram_flip.py b/src/scilpy/cli/tests/test_tractogram_flip.py index d77a5ff92..79f29bc11 100644 --- a/src/scilpy/cli/tests/test_tractogram_flip.py +++ b/src/scilpy/cli/tests/test_tractogram_flip.py @@ -24,5 +24,7 @@ def test_execution_surface_vtk_fib(script_runner, monkeypatch): in_fa = os.path.join(SCILPY_HOME, 'surface_vtk_fib', 'fa.nii.gz') ret = script_runner.run(['scil_tractogram_flip', in_fib, - 'gyri_fanning_fliped.tck', 'x', '--reference', in_fa]) + 'gyri_fanning_fliped.tck', 'x', '--reference', in_fa, + '--no_bbox_check']) + assert ret.success diff --git a/src/scilpy/cli/tests/test_volume_apply_transform.py b/src/scilpy/cli/tests/test_volume_apply_transform.py index d133f2cbf..83e56a6d0 100644 --- a/src/scilpy/cli/tests/test_volume_apply_transform.py +++ b/src/scilpy/cli/tests/test_volume_apply_transform.py @@ -60,3 +60,25 @@ def test_execution_interp_lin(script_runner, monkeypatch): 'template_lin.nii.gz', '--inverse', '--interp', 'linear', '-f']) assert ret.success + + +def test_execution_and_header_compatibility(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_model = os.path.join(SCILPY_HOME, 'bst', 'template', + 'template0.nii.gz') + in_fa = os.path.join(SCILPY_HOME, 'bst', + 'fa.nii.gz') + in_aff = os.path.join(SCILPY_HOME, 'bst', + 'output0GenericAffine.mat') + out_filename = 'template_lin_header_test.nii.gz' + + # Run the transformation + ret = script_runner.run(['scil_volume_apply_transform', + in_model, in_fa, in_aff, + out_filename, '--inverse', '-f']) + assert ret.success + + # Check for header compatibility between the output and the reference + ret = script_runner.run(['scil_header_validate_compatibility', + out_filename, in_fa]) + assert ret.success, "Headers are not compatible!" diff --git a/src/scilpy/cli/tests/test_volume_modify_voxel_order.py b/src/scilpy/cli/tests/test_volume_modify_voxel_order.py new file mode 100644 index 000000000..c1305efa9 --- /dev/null +++ b/src/scilpy/cli/tests/test_volume_modify_voxel_order.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import os +import nibabel as nib +import numpy as np +import tempfile + +from scilpy import SCILPY_HOME +from scilpy.io.fetcher import fetch_data, get_testing_files_dict + +# If they already exist, this only takes 5 seconds (check md5sum) +fetch_data(get_testing_files_dict(), keys=['processing.zip']) +tmp_dir = tempfile.TemporaryDirectory() + + +def test_help_option(script_runner): + ret = script_runner.run(['scil_volume_modify_voxel_order', '--help']) + assert ret.success + + +def test_execution(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_file = 'input.nii.gz' + img = nib.Nifti1Image(np.zeros((10, 20, 30)), np.eye(4)) + nib.save(img, in_file) + + # Test with character-based voxel order + out_file_lps = 'output_lps.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file_lps, '--new_voxel_order=LPS', '-f']) + assert ret.success + lps_img = nib.load(out_file_lps) + assert nib.aff2axcodes(lps_img.affine) == ('L', 'P', 'S') + + # Test with numeric voxel order + out_file_asr = 'output_asr.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file_asr, '--new_voxel_order=3,1,2', '-f']) + assert ret.success + asr_img = nib.load(out_file_asr) + assert nib.aff2axcodes(asr_img.affine) == ('S', 'R', 'A') + + # Test with negative numeric voxel order + out_file_lai = 'output_lai.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file_lai, '--new_voxel_order=-1,2,-3', + '-f']) + assert ret.success + lai_img = nib.load(out_file_lai) + assert nib.aff2axcodes(lai_img.affine) == ('L', 'A', 'I') + + # Test with invalid input + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + 'output.nii.gz', '--new_voxel_order=invalid', + '-f']) + assert not ret.success + + +def test_execution_with_gradients(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + + # 1. Create a 4D dummy NIfTI (RAS) + n_volumes = 2 + in_file = 'input_4d.nii.gz' + data = np.zeros((10, 10, 10, n_volumes)) + img = nib.Nifti1Image(data, np.eye(4)) + nib.save(img, in_file) + + # 2. Create bvecs + bvecs = np.array([[0, 0, 0], [1, 0, 0]]) # X-direction in RAS + + in_bvec = 'input.bvec' + np.savetxt(in_bvec, bvecs.T, fmt='%.8f') + + # 3. Run script to modify voxel order to LPS + out_file = 'output_lps.nii.gz' + out_bvec = 'output_lps.bvec' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file, '--new_voxel_order=LPS', + '--in_bvec', in_bvec, '--out_bvec', out_bvec, '-f']) + assert ret.success + + # 4. Verify image + lps_img = nib.load(out_file) + assert nib.aff2axcodes(lps_img.affine) == ('L', 'P', 'S') + + # 5. Verify gradients (they should be reoriented to match LPS) + assert os.path.exists(out_bvec) + + saved_bvecs = np.loadtxt(out_bvec).T # loadtxt returns (3, N) for FSL + + # RAS to LPS: flip X and Y. + # Original bvec [1, 0, 0] (X) should become [-1, 0, 0] + expected_bvecs = np.array([[0, 0, 0], [-1, 0, 0]]) + assert np.allclose(saved_bvecs, expected_bvecs, atol=1e-3) + + +def test_execution_with_gradients_numeric(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + + # 1. Create a 4D dummy NIfTI (RAS) + n_volumes = 2 + in_file = 'input_4d_num.nii.gz' + data = np.zeros((10, 10, 10, n_volumes)) + img = nib.Nifti1Image(data, np.eye(4)) + nib.save(img, in_file) + + # 2. Create bvecs + bvecs = np.array([[0, 0, 0], [1, 0, 0]]) # X-direction in RAS + + in_bvec = 'input_num.bvec' + np.savetxt(in_bvec, bvecs.T, fmt='%.8f') + + # 3. Run script to modify voxel order to LPS using numeric: -1,-2,3 + out_file = 'output_lps_num.nii.gz' + out_bvec = 'output_lps_num.bvec' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_file, + out_file, '--new_voxel_order=-1,-2,3', + '--in_bvec', in_bvec, '--out_bvec', out_bvec, '-f']) + assert ret.success + + # 4. Verify image + lps_img = nib.load(out_file) + assert nib.aff2axcodes(lps_img.affine)[:3] == ('L', 'P', 'S') + + # 5. Verify gradients + assert os.path.exists(out_bvec) + saved_bvecs = np.loadtxt(out_bvec).T + expected_bvecs = np.array([[0, 0, 0], [-1, 0, 0]]) + assert np.allclose(saved_bvecs, expected_bvecs, atol=1e-3) + + +def test_execution_real_data(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_image = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop.nii.gz') + + # Verify original orientation is RAS + img_in = nib.load(in_image) + assert nib.aff2axcodes(img_in.affine) == ('R', 'A', 'S') + + # Test LPS + out_lps = 'real_lps.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_image, + out_lps, '--new_voxel_order=LPS', '-f']) + assert ret.success + img = nib.load(out_lps) + assert nib.aff2axcodes(img.affine) == ('L', 'P', 'S') + + # Test RAS + out_ras = 'real_ras.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_image, + out_ras, '--new_voxel_order=RAS', '-f']) + assert ret.success + img = nib.load(out_ras) + assert nib.aff2axcodes(img.affine) == ('R', 'A', 'S') + + # Test LPI + out_lpi = 'real_lpi.nii.gz' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_image, + out_lpi, '--new_voxel_order=LPI', '-f']) + assert ret.success + img = nib.load(out_lpi) + assert nib.aff2axcodes(img.affine) == ('L', 'P', 'I') + + +def test_execution_with_bvec_real_data(script_runner, monkeypatch): + monkeypatch.chdir(os.path.expanduser(tmp_dir.name)) + in_image = os.path.join(SCILPY_HOME, 'processing', + 'dwi_crop.nii.gz') + in_bvec = os.path.join(SCILPY_HOME, 'processing', + 'dwi.bvec') + + # Verify original orientation is RAS + img_in = nib.load(in_image) + assert nib.aff2axcodes(img_in.affine) == ('R', 'A', 'S') + + # Test LPI + out_lpi = 'real_lpi_grad.nii.gz' + out_bvec = 'real_lpi_grad.bvec' + ret = script_runner.run(['scil_volume_modify_voxel_order', in_image, + out_lpi, '--new_voxel_order=LPS', + '--in_bvec', in_bvec, '--out_bvec', out_bvec, '-f']) + assert ret.success + + # Verify image + img = nib.load(out_lpi) + assert nib.aff2axcodes(img.affine)[:3] == ('L', 'P', 'S') + + # Verify bvec + assert os.path.exists(out_bvec) + old_bvecs = np.loadtxt(in_bvec) + new_bvecs = np.loadtxt(out_bvec) + + # RAS to LPS: flip X, Y + expected_bvecs = old_bvecs.copy() + expected_bvecs[0, :] *= -1 # Flip X + expected_bvecs[1, :] *= -1 # Flip Y + assert np.allclose(new_bvecs, expected_bvecs, atol=1e-3) diff --git a/src/scilpy/image/tests/test_volume_operations.py b/src/scilpy/image/tests/test_volume_operations.py index 54b789aa1..8baa6123a 100644 --- a/src/scilpy/image/tests/test_volume_operations.py +++ b/src/scilpy/image/tests/test_volume_operations.py @@ -180,7 +180,7 @@ def test_resample_volume(): # Ref: 2x2x2, voxel size 3x3x3 ref3d = np.ones((2, 2, 2)) - ref_affine = np.eye(4)*3 + ref_affine = np.eye(4) * 3 ref_affine[-1, -1] = 1 # 1) Option volume_shape: expecting an output of 2x2x2, which means @@ -217,7 +217,7 @@ def test_resample_volume(): def test_reshape_volume_pad(): # 3D img simg = StatefulImage( - np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(float), + np.arange(1, (3**3) + 1).reshape((3, 3, 3)).astype(float), np.eye(4)) # 1) Reshaping to 4x4x4, padding with 0 @@ -237,7 +237,7 @@ def test_reshape_volume_pad(): # 4D img (2 "stacked" 3D volumes) simg = StatefulImage( - np.arange(1, ((3**3) * 2)+1).reshape((3, 3, 3, 2)).astype(float), + np.arange(1, ((3**3) * 2) + 1).reshape((3, 3, 3, 2)).astype(float), np.eye(4)) # 2) Reshaping to 5x5x5, padding with 0 @@ -248,7 +248,7 @@ def test_reshape_volume_pad(): def test_reshape_volume_crop(): # 3D img simg = StatefulImage( - np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(float), + np.arange(1, (3**3) + 1).reshape((3, 3, 3)).astype(float), np.eye(4)) # 1) Cropping to 1x1x1 @@ -265,7 +265,7 @@ def test_reshape_volume_crop(): # 4D img simg = StatefulImage( - np.arange(1, ((3**3) * 2)+1).reshape((3, 3, 3, 2)).astype(float), + np.arange(1, ((3**3) * 2) + 1).reshape((3, 3, 3, 2)).astype(float), np.eye(4)) # 2) Cropping to 2x2x2 @@ -278,7 +278,7 @@ def test_reshape_volume_crop(): def test_reshape_volume_dtype(): # 3D img simg = StatefulImage( - np.arange(1, (3**3)+1).reshape((3, 3, 3)).astype(np.uint16), + np.arange(1, (3**3) + 1).reshape((3, 3, 3)).astype(np.uint16), np.eye(4)) # 1) Staying in 3x3x3, same dtype diff --git a/src/scilpy/image/volume_operations.py b/src/scilpy/image/volume_operations.py index 02c52e957..48460351a 100644 --- a/src/scilpy/image/volume_operations.py +++ b/src/scilpy/image/volume_operations.py @@ -188,8 +188,11 @@ def apply_transform(transfo, reference, raise ValueError('Does not support this dataset (shape, type, etc)') moved_nib_img = nib.Nifti1Image(resampled.astype(orig_type), grid2world) - return StatefulImage.create_from(moved_nib_img, - StatefulImage.convert_to_simg(reference)) + if isinstance(reference, StatefulImage): + return StatefulImage.create_from(moved_nib_img, reference) + else: + return StatefulImage.create_from( + moved_nib_img, StatefulImage.convert_to_simg(reference)) def transform_dwi(reg_obj, static, dwi, interpolation='linear'): @@ -213,7 +216,7 @@ def transform_dwi(reg_obj, static, dwi, interpolation='linear'): trans_dwi: nib.Nifti1Image The warped 4D volume. """ - trans_dwi = np.zeros(static.shape + (dwi.shape[3],), dtype=dwi.dtype) + trans_dwi = np.zeros(static.shape[:3] + (dwi.shape[3],), dtype=dwi.dtype) for i in range(dwi.shape[3]): trans_dwi[..., i] = reg_obj.transform(dwi[..., i], interpolation=interpolation) @@ -272,8 +275,8 @@ def register_image(static, static_grid2world, moving, moving_grid2world, level_iters = [250, 100, 50, 25] if fine else [50, 25, 5] # With images too small, dipy fails with no clear warning. - if (np.any(np.asarray(moving.shape) < 8) or - np.any(np.asarray(static.shape) < 8)): + if (np.any(np.asarray(moving.shape) < 8) + or np.any(np.asarray(static.shape) < 8)): raise ValueError("Current implementation of registration was prepared " "with factors up to 8. Requires images with at least " "8 voxels in each direction.") @@ -397,7 +400,7 @@ def compute_snr(dwi, bval, bvec, b0_thr, mask, noise_mask=None, noise_map=None, # Add the upper half in order to delete the neck and shoulder # when inverting the mask - noise_mask[..., :noise_mask.shape[-1]//2] = 1 + noise_mask[..., :noise_mask.shape[-1] // 2] = 1 # Reverse the mask to get only noise noise_mask = (~noise_mask).astype(bool) diff --git a/src/scilpy/image/volume_space_management.py b/src/scilpy/image/volume_space_management.py index 6f5b54429..b9ade16fb 100644 --- a/src/scilpy/image/volume_space_management.py +++ b/src/scilpy/image/volume_space_management.py @@ -21,7 +21,8 @@ class DataVolume(object): Class to access/interpolate data from nibabel object """ - def __init__(self, data, voxres, interpolation=None, must_be_3d=False): + def __init__(self, data, voxres, affine=None, interpolation=None, + must_be_3d=False): """ Parameters ---------- @@ -29,6 +30,8 @@ def __init__(self, data, voxres, interpolation=None, must_be_3d=False): The data, ex, loaded from nibabel img.get_fdata(). voxres: np.array(3,) The pixel resolution, ex, using img.header.get_zooms()[:3]. + affine: np.array(4,4) + The affine matrix mapping voxel coordinates to RASMM. interpolation: str or None The interpolation choice amongst "trilinear" or "nearest". If None, functions getting a coordinate in mm instead of voxel @@ -43,9 +46,14 @@ def __init__(self, data, voxres, interpolation=None, must_be_3d=False): raise Exception("Interpolation must be 'trilinear' or " "'nearest'") - self.data = data + self.data = data.astype(np.float64, copy=False) self.nb_coeffs = data.shape[-1] self.voxres = voxres + self.affine = affine + if affine is not None: + self.inv_affine = np.linalg.inv(affine) + else: + self.inv_affine = None if must_be_3d and self.data.ndim != 3: raise Exception("Data should have been 3D but data dimension is:" @@ -88,7 +96,7 @@ def get_value_at_coordinate(self, x, y, z, space, origin): x, y, z: floats Voxel coordinates along each axis. space: dipy Space - 'vox' or 'voxmm'. + 'vox', 'voxmm' or 'rasmm'. origin: dipy Origin 'corner' or 'center'. @@ -101,9 +109,10 @@ def get_value_at_coordinate(self, x, y, z, space, origin): return self._vox_to_value(x, y, z, origin) elif space == Space.VOXMM: return self._voxmm_to_value(x, y, z, origin) + elif space == Space.RASMM: + return self._rasmm_to_value(x, y, z, origin) else: - raise NotImplementedError("We have not prepared the DataVolume to " - "work in RASMM space yet.") + raise ValueError("Space should be a choice of Dipy Space.") def is_idx_in_bound(self, i, j, k): """ @@ -132,7 +141,7 @@ def is_coordinate_in_bound(self, x, y, z, space, origin): x, y, z: floats Voxel coordinates along each axis. space: dipy Space - 'vox' or 'voxmm'. + 'vox', 'voxmm' or 'rasmm'. origin: dipy Origin 'corner' or 'center'. @@ -145,9 +154,10 @@ def is_coordinate_in_bound(self, x, y, z, space, origin): return self._is_vox_in_bound(x, y, z, origin) elif space == Space.VOXMM: return self._is_voxmm_in_bound(x, y, z, origin) + elif space == Space.RASMM: + return self._is_rasmm_in_bound(x, y, z, origin) else: - raise NotImplementedError("We have not prepared the DataVolume to " - "work in RASMM space yet.") + raise ValueError("Space should be a choice of Dipy Space.") def _clip_idx_to_bound(self, i, j, k): """ @@ -369,6 +379,94 @@ def _is_voxmm_in_bound(self, x, y, z, origin): """ return self.is_idx_in_bound(*self.voxmm_to_idx(x, y, z, origin)) + def rasmm_to_vox(self, x, y, z, origin): + """ + Get voxel space coordinates at position x, y, z (rasmm). + + Parameters + ---------- + x, y, z: floats + Position coordinate (rasmm) along x, y, z axis. + origin: dipy Origin + 'corner' or 'center'. + + Return + ------ + x, y, z: floats + Voxel space coordinates for position x, y, z. + """ + if self.inv_affine is None: + raise ValueError("Affine matrix is required for RASMM space.") + + vox_corner = np.dot(self.inv_affine, [x, y, z, 1])[:3] + if origin == Origin('center'): + return vox_corner - 0.5 + return vox_corner + + def vox_to_rasmm(self, x, y, z, origin): + """ + Get RASMM space coordinates at position x, y, z (vox). + + Parameters + ---------- + x, y, z: floats + Position coordinate (vox) along x, y, z axis. + origin: dipy Origin + 'corner' or 'center'. + + Return + ------ + x, y, z: floats + RASMM space coordinates for position x, y, z. + """ + if self.affine is None: + raise ValueError("Affine matrix is required for RASMM space.") + + if origin == Origin('center'): + x, y, z = x + 0.5, y + 0.5, z + 0.5 + + return np.dot(self.affine, [x, y, z, 1])[:3] + + def _rasmm_to_value(self, x, y, z, origin): + """ + Get the voxel value at voxel position x, y, z (rasmm) in the dataset. + If the coordinates are out of bound, the nearest voxel value is taken. + Value is interpolated based on the value of self.interpolation. + + Parameters + ---------- + x, y, z: floats + Position coordinate (rasmm) along x, y, z axis. + origin: dipy Space + 'center' or 'corner'. + + Return + ------ + value: ndarray (self.dims[-1],) or float + Interpolated value at position x, y, z (rasmm). If the last + dimension is of length 1, return a scalar value. + """ + return self._vox_to_value(*self.rasmm_to_vox(x, y, z, origin), origin) + + def _is_rasmm_in_bound(self, x, y, z, origin): + """ + Test if the position x, y, z rasmm is in the dataset range. + + Parameters + ---------- + x, y, z: floats + Position coordinate (rasmm) along x, y, z axis. + origin: dipy Space + 'center' or 'corner'. + + Return + ------ + value: bool + True if position is in dataset range and false otherwise. + """ + return self.is_idx_in_bound(*self.vox_to_idx( + *self.rasmm_to_vox(x, y, z, origin), origin)) + class FibertubeDataVolume(DataVolume): """ @@ -442,15 +540,18 @@ def get_value_at_coordinate(self, x, y, z, space, origin): return self._voxmm_to_value(*self.vox_to_voxmm(x, y, z), origin) elif space == Space.VOXMM: return self._voxmm_to_value(x, y, z, origin) + elif space == Space.RASMM: + return self._voxmm_to_value(*self.rasmm_to_voxmm(x, y, z), origin) else: - raise NotImplementedError("We have not prepared the DataVolume " - "to work in RASMM space yet.") + raise ValueError("Space should be a choice of Dipy Space.") def is_idx_in_bound(self, i, j, k): return super().is_idx_in_bound(i, j, k) def is_coordinate_in_bound(self, x, y, z, space, origin): FibertubeDataVolume._validate_origin(origin) + if space == Space.RASMM: + return self._is_rasmm_in_bound(x, y, z, origin) return super().is_coordinate_in_bound(x, y, z, space, origin) @staticmethod @@ -486,6 +587,24 @@ def vox_to_voxmm(self, x, y, z): y * self.voxres[1], z * self.voxres[2]] + def rasmm_to_voxmm(self, x, y, z): + """ + Get voxmm space coordinates at position x, y, z (rasmm). + + Parameters + ---------- + x, y, z: floats + Position coordinate (rasmm) along x, y, z axis. + + Return + ------ + x, y, z: floats + voxmm space coordinates for position x, y, z. + """ + # FibertubeDataVolume only supports origin center (NIFTI) + vox = self.rasmm_to_vox(x, y, z, Origin.NIFTI) + return self.vox_to_voxmm(*vox) + def _clip_voxmm_to_bound(self, x, y, z, origin): return self.vox_to_voxmm(*self._clip_vox_to_bound( *self.voxmm_to_vox(x, y, z), origin)) diff --git a/src/scilpy/io/btensor.py b/src/scilpy/io/btensor.py index 73cbe5f40..1411df2c1 100644 --- a/src/scilpy/io/btensor.py +++ b/src/scilpy/io/btensor.py @@ -1,15 +1,10 @@ -import logging - from dipy.core.gradients import (gradient_table, unique_bvals_tolerance, get_bval_indices) -from dipy.io.gradients import read_bvals_bvecs -import nibabel as nib import numpy as np from scilpy.dwi.utils import extract_dwi_shell -from scilpy.gradients.bvec_bval_tools import (normalize_bvecs, - is_normalized_bvecs, - check_b0_threshold) +from scilpy.gradients.bvec_bval_tools import check_b0_threshold +from scilpy.io.stateful_image import StatefulImage bshapes = {0: "STE", 1: "LTE", -0.5: "PTE", 0.5: "CTE"} @@ -111,21 +106,24 @@ def generate_btensor_input(in_dwis, in_bvals, in_bvecs, in_bdeltas, for inputf, bvalsf, bvecsf, b_delta in zip(in_dwis, in_bvals, in_bvecs, in_bdeltas): if inputf: # verifies if the input file exists - vol = nib.load(inputf) - bvals, bvecs = read_bvals_bvecs(bvalsf, bvecsf) + simg = StatefulImage.load(inputf) + simg.load_gradients(bvalsf, bvecsf) + simg.to_ras() + + bvals = simg.bvals + bvecs = simg.world_bvecs + _ = check_b0_threshold(bvals.min(), b0_thr=tol, skip_b0_check=skip_b0_check, overwrite_with_min=False) if np.sum([bvals > tol]) != 0: bvals = np.round(bvals) - if not is_normalized_bvecs(bvecs): - logging.warning('Your b-vectors do not seem normalized...') - bvecs = normalize_bvecs(bvecs) + ubvals = unique_bvals_tolerance(bvals, tol=tol) for ubval in ubvals: # Loop over all unique bvals # Extracting the data for the ubval shell indices, shell_data, _, output_bvecs = \ - extract_dwi_shell(vol, bvals, bvecs, [ubval], tol=tol) + extract_dwi_shell(simg, bvals, bvecs, [ubval], tol=tol) nb_bvecs = len(indices) # Adding the current data to each arrays of interest acq_index_full = np.concatenate([acq_index_full, diff --git a/src/scilpy/io/image.py b/src/scilpy/io/image.py index af17a4bab..91d89e11a 100644 --- a/src/scilpy/io/image.py +++ b/src/scilpy/io/image.py @@ -2,11 +2,12 @@ from dipy.io.utils import is_header_compatible import logging -import nibabel as nib import numpy as np import os from scilpy.utils import is_float +from scilpy.io.stateful_image import StatefulImage + def load_img(arg): """ @@ -22,7 +23,7 @@ def load_img(arg): else: if not os.path.isfile(arg): raise ValueError('Input file {} does not exist.'.format(arg)) - img = nib.load(arg) + img = StatefulImage.load(arg) shape = img.header.get_data_shape() dtype = img.header.get_data_dtype() logging.info('Loaded {} of shape {} and data_type {}.'.format( @@ -87,14 +88,18 @@ def get_data_as_mask(mask_img, dtype=np.uint8): Data (dtype : np.uint8 or bool). """ # Verify that out data type is ok - if not (issubclass(np.dtype(dtype).type, np.uint8) or - issubclass(np.dtype(dtype).type, np.dtype(bool).type)): + if not (issubclass(np.dtype(dtype).type, np.uint8) + or issubclass(np.dtype(dtype).type, np.dtype(bool).type)): raise IOError('Output data type must be uint8 or bool. ' 'Current data type is {}.'.format(dtype)) # Verify that loaded datatype is ok curr_type = mask_img.get_data_dtype().type - basename = os.path.basename(mask_img.get_filename()) + if hasattr(mask_img, 'get_filename') and mask_img.get_filename(): + basename = os.path.basename(mask_img.get_filename()) + else: + basename = "unnamed" + if np.issubdtype(curr_type, np.signedinteger) or \ np.issubdtype(curr_type, np.unsignedinteger) \ or np.issubdtype(curr_type, np.dtype(bool).type): diff --git a/src/scilpy/io/mti.py b/src/scilpy/io/mti.py index 7d7ec5090..3b998da75 100644 --- a/src/scilpy/io/mti.py +++ b/src/scilpy/io/mti.py @@ -228,7 +228,7 @@ def _prepare_B1_map(args, flip_angles, extended_dir, affine): """ B1_map = None if args.in_B1_map and args.in_mtoff_t1: - B1_img = nib.load(args.in_B1_map) + B1_img, _ = load_img(args.in_B1_map) B1_map = B1_img.get_fdata(dtype=np.float32) B1_map = adjust_B1_map_intensities(B1_map, nominal=args.B1_nominal) B1_map = smooth_B1_map(B1_map, wdims=args.B1_smooth_dims) diff --git a/src/scilpy/io/stateful_image.py b/src/scilpy/io/stateful_image.py index b61e0b834..825e320d7 100644 --- a/src/scilpy/io/stateful_image.py +++ b/src/scilpy/io/stateful_image.py @@ -1,6 +1,10 @@ # -*- coding: utf-8 -*- import nibabel as nib +import numpy as np +from scipy.linalg import polar + +from dipy.io.gradients import read_bvals_bvecs from dipy.io.utils import get_reference_info from scilpy.utils.orientation import validate_voxel_order @@ -18,7 +22,10 @@ class StatefulImage(nib.Nifti1Image): def __init__(self, dataobj, affine, header=None, extra=None, file_map=None, original_affine=None, original_dimensions=None, original_voxel_sizes=None, - original_axcodes=None): + original_axcodes=None, bvals=None, bvecs=None, + gradients_original_order=True, + sh_basis='descoteaux07', is_legacy=True, + is_orientation=False, is_world_space=True): """ Initialize a StatefulImage object. @@ -32,8 +39,34 @@ def __init__(self, dataobj, affine, header=None, extra=None, self._original_voxel_sizes = original_voxel_sizes self._original_axcodes = original_axcodes + # Directional information + self._sh_basis = sh_basis + self._is_legacy = is_legacy + self._is_orientation = is_orientation + self._is_world_space = is_world_space + + # Store gradient information + self._bvals = None + self._world_bvecs = None + if bvals is not None and bvecs is not None: + self.attach_gradients(bvals, bvecs, gradients_original_order) + + def _get_rotation_matrix(self, affine): + """ + Extract the pure rotation component from a 4x4 affine matrix. + """ + # Extract 3x3 part + A = affine[:3, :3] + # Polar decomposition: A = P * R + # R is the closest orthogonal matrix to A. + # We want the orthogonal part that matches the image's orientation. + R, P = polar(A) + return R + @classmethod - def load(cls, filename, to_orientation="RAS"): + def load(cls, filename, to_orientation="RAS", + is_orientation=False, is_world_space=True, + sh_basis='descoteaux07', is_legacy=True): """ Load a NIfTI image, store its original orientation, and reorient it. @@ -43,6 +76,12 @@ def load(cls, filename, to_orientation="RAS"): Path to the NIfTI file. to_orientation : str or tuple, optional The target orientation for the in-memory data. Default is "RAS". + is_orientation : bool, optional + Whether the image contains directional data (SH, Peaks, SF). + Default is False. + is_world_space : bool, optional + Whether the directional data is already in world space. + Only used if is_orientation is True. Default is True. Returns ------- @@ -66,11 +105,171 @@ def load(cls, filename, to_orientation="RAS"): else: reoriented_img = img - return cls(reoriented_img.dataobj, reoriented_img.affine, + simg = cls(reoriented_img.dataobj, reoriented_img.affine, reoriented_img.header, original_affine=original_affine, original_dimensions=original_dims, original_voxel_sizes=original_voxel_sizes, - original_axcodes=original_axcodes) + original_axcodes=original_axcodes, + sh_basis=sh_basis, is_legacy=is_legacy, + is_orientation=is_orientation, + is_world_space=is_world_space) + + if is_orientation and not is_world_space: + # Move from original voxel space to world space + # Note: We use original_affine because the data was loaded + # in that space. + data = simg.get_fdata(dtype=np.float32) + R = simg._get_rotation_matrix(original_affine) + rotated_data = simg._rotate_direction_data(data, R, + sh_basis=sh_basis, + is_legacy=is_legacy) + simg = cls.from_data(rotated_data, simg) + + return simg + + def to_voxel_direction(self, data=None, sh_basis=None, + is_legacy=None, nbr_processes=None): + """ + Transform directional data from world space to current voxel space. + + Parameters + ---------- + data : np.ndarray, optional + The directional data to transform. If None, uses the image data + and updates it in-place. + sh_basis : str, optional + The SH basis of the directional data. Defaults to self.sh_basis. + is_legacy : bool, optional + Whether the SH basis is legacy. Defaults to self.is_legacy. + nbr_processes : int, optional + Number of processes to use for rotation. + + Returns + ------- + np.ndarray + The transformed directional data in voxel space. + """ + if sh_basis is None: + sh_basis = self.sh_basis + if is_legacy is None: + is_legacy = self.is_legacy + + if data is None: + if not self.is_orientation: + raise ValueError("Image is not marked as directional.") + if not self.is_world_space: + return self.get_fdata(dtype=np.float32) + + data = self.get_fdata(dtype=np.float32) + R = self._get_rotation_matrix(self.affine).T + rotated_data = self._rotate_direction_data(data, R, + sh_basis=sh_basis, + is_legacy=is_legacy, + nbr_processes=nbr_processes) + self._dataobj = rotated_data + self._is_world_space = False + return rotated_data + + # R_world_to_voxel = R_voxel_to_world.T + R = self._get_rotation_matrix(self.affine).T + return self._rotate_direction_data(data, R, sh_basis=sh_basis, + is_legacy=is_legacy, + nbr_processes=nbr_processes) + + def to_world_direction(self, data=None, sh_basis=None, + is_legacy=None, nbr_processes=None): + """ + Transform directional data from voxel space to world space. + + Parameters + ---------- + data : np.ndarray, optional + The directional data to transform. If None, uses the image data + and updates it in-place. + sh_basis : str, optional + The SH basis of the directional data. Defaults to self.sh_basis. + is_legacy : bool, optional + Whether the SH basis is legacy. Defaults to self.is_legacy. + nbr_processes : int, optional + Number of processes to use for rotation. + + Returns + ------- + np.ndarray + The transformed directional data in world space. + """ + if sh_basis is None: + sh_basis = self.sh_basis + if is_legacy is None: + is_legacy = self.is_legacy + + if data is None: + if not self.is_orientation: + raise ValueError("Image is not marked as directional.") + if self.is_world_space: + return self.get_fdata(dtype=np.float32) + + data = self.get_fdata(dtype=np.float32) + R = self._get_rotation_matrix(self.affine) + rotated_data = self._rotate_direction_data(data, R, + sh_basis=sh_basis, + is_legacy=is_legacy, + nbr_processes=nbr_processes) + self._dataobj = rotated_data + self._is_world_space = True + return rotated_data + + R = self._get_rotation_matrix(self.affine) + return self._rotate_direction_data(data, R, sh_basis=sh_basis, + is_legacy=is_legacy, + nbr_processes=nbr_processes) + + def _rotate_direction_data(self, data, R, sh_basis='descoteaux07', + is_legacy=True, nbr_processes=None): + """ + Internal helper to rotate SH or Peaks data. + """ + from scilpy.reconst.utils import (get_sh_order_and_fullness, + is_data_peaks) + + original_shape = data.shape + if len(original_shape) == 5 and original_shape[-1] == 7: + # Bingham-like data: [amp, mu1_x, mu1_y, mu1_z, mu2_x, mu2_y, mu2_z] + # We rotate mu1 and mu2 + bingham_data = data.copy() + mu1 = bingham_data[..., 1:4].reshape(-1, 3) + mu2 = bingham_data[..., 4:7].reshape(-1, 3) + rotated_mu1 = np.dot(mu1, R.T) + rotated_mu2 = np.dot(mu2, R.T) + bingham_data[..., 1:4] = rotated_mu1.reshape(original_shape[:4] + (3,)) + bingham_data[..., 4:7] = rotated_mu2.reshape(original_shape[:4] + (3,)) + return bingham_data + + # Handle 5D data + if len(original_shape) == 5: + # We treat each "lobe" independently for rotation if it's not SH + data = data.reshape(original_shape[0:3] + (-1,)) + + last_dim = data.shape[-1] + is_sh = not is_data_peaks(data) + if is_sh: + from scilpy.reconst.sh import rotate_sh + # SH data can be 4D (XxYxZxN) + order, full = get_sh_order_and_fullness(last_dim) + return rotate_sh(data, R, basis_type=sh_basis, + full_basis=full, is_legacy=is_legacy, + nbr_processes=nbr_processes) + elif last_dim % 3 == 0: + # Assume Peaks (N*3) + # Reshape to (..., N, 3), rotate, and reshape back + reshaped_data = data.reshape(-1, 3) + rotated_data = np.dot(reshaped_data, R.T) + return rotated_data.reshape(original_shape) + else: + raise ValueError( + f"Could not identify directional data type for " + f"shape {original_shape}. Not SH (wrong #coeffs) and " + f"not Peaks (not a multiple of 3).") def save(self, filename): """ @@ -87,8 +286,20 @@ def save(self, filename): "with StatefulImage.load() or that original_axcodes was" "provided when creating the StatefulImage instance.") - self.reorient_to_original() - nib.save(self, filename) + current_axcodes = self.axcodes[:3] + target_axcodes = self._original_axcodes[:3] + + if current_axcodes == target_axcodes: + nib.save(self, filename) + else: + start_ornt = nib.orientations.axcodes2ornt(current_axcodes) + target_ornt = nib.orientations.axcodes2ornt(target_axcodes) + transform = nib.orientations.ornt_transform(start_ornt, + target_ornt) + # Use Nifti1Image.as_reoriented to get a temporary object + # in the original orientation for saving. + reoriented_img = nib.Nifti1Image.as_reoriented(self, transform) + nib.save(reoriented_img, filename) @staticmethod def create_from(source, reference): @@ -110,15 +321,57 @@ def create_from(source, reference): A new StatefulImage with the source image's data and the reference image's original orientation information. """ + bvals = None + bvecs = None + if reference.bvals is not None and reference.world_bvecs is not None: + if source.ndim == 4 and len(reference.bvals) == source.shape[3]: + bvals = reference.bvals + # Transform world-space bvecs to source voxel space + R_source = reference._get_rotation_matrix(source.affine) + bvecs = np.dot(reference.world_bvecs, R_source) + + if StatefulImage.needs_fsl_flip(source.affine): + bvecs[:, 0] *= -1 + orig_dims = reference._original_dimensions + orig_vox = reference._original_voxel_sizes return StatefulImage(source.dataobj, source.affine, header=source.header, original_affine=reference._original_affine, - original_dimensions=reference._original_dimensions, - original_voxel_sizes=reference._original_voxel_sizes, - original_axcodes=reference._original_axcodes) + original_dimensions=orig_dims, + original_voxel_sizes=orig_vox, + original_axcodes=reference._original_axcodes, + bvals=bvals, bvecs=bvecs, + gradients_original_order=False, + sh_basis=reference.sh_basis, + is_legacy=reference.is_legacy, + is_orientation=reference.is_orientation, + is_world_space=reference.is_world_space) + + @staticmethod + def from_data(data, reference): + """ + Create a new StatefulImage from a numpy array, preserving the original + orientation information from a reference StatefulImage. + + Parameters + ---------- + data : numpy.ndarray + The image data to use for the new StatefulImage. + reference : StatefulImage + The reference image from which to copy original orientation + information. + + Returns + ------- + StatefulImage + A new StatefulImage with the data and the reference + image's original orientation information. + """ + new_img = nib.Nifti1Image(data, reference.affine, reference.header) + return StatefulImage.create_from(new_img, reference) @staticmethod - def convert_to_simg(img): + def convert_to_simg(img, bvals=None, bvecs=None): """ Initialize a StatefulImage from an existing Nifti1Image. @@ -129,13 +382,210 @@ def convert_to_simg(img): ---------- img : nib.Nifti1Image The Nifti1Image to initialize from. + bvals : array-like, optional + B-values. + bvecs : array-like, optional + B-vectors. """ + original_axcodes = nib.orientations.aff2axcodes(img.affine) + if len(img.shape) == 4: + original_axcodes += ('T',) + return StatefulImage(img.dataobj, img.affine, header=img.header, original_affine=img.affine.copy(), original_dimensions=img.header.get_data_shape(), original_voxel_sizes=img.header.get_zooms(), - original_axcodes=nib.orientations.aff2axcodes( - img.affine)) + original_axcodes=original_axcodes, + bvals=bvals, bvecs=bvecs) + + @staticmethod + def needs_fsl_flip(affine): + """ + According to BIDS/MRtrix convention, if the determinant of the + 3x3 rotation/scaling part of the affine is positive (neurological), + the x-component of the FSL-format bvecs must be flipped. + """ + return np.linalg.det(affine[:3, :3]) > 0 + + @property + def _needs_fsl_flip(self): + return StatefulImage.needs_fsl_flip(self.affine) + + @property + def sh_basis(self): + """Get the SH basis.""" + return self._sh_basis + + @property + def is_legacy(self): + """Get whether the SH basis is legacy.""" + return self._is_legacy + + @property + def is_orientation(self): + """Get whether the image contains directional data.""" + return self._is_orientation + + @property + def is_world_space(self): + """Get whether the directional data is in world space.""" + return self._is_world_space + + @property + def bvals(self): + """Get the current b-values.""" + return self._bvals + + @property + def bvecs(self): + """Get the current (reoriented) b-vectors in voxel space.""" + if self._world_bvecs is None: + return None + # Transform from world space to current voxel space + R = self._get_rotation_matrix(self.affine) + # v_voxel = v_world * R + bvecs = np.dot(self._world_bvecs, R) + + if self._needs_fsl_flip: + bvecs[:, 0] *= -1 + + return bvecs + + @property + def world_bvecs(self): + """Get the current b-vectors in world space.""" + return self._world_bvecs + + def attach_gradients(self, bvals, bvecs, original_order=True): + """ + Attach b-values and b-vectors to the image. + Gradients are stored internally in world space. + + Parameters + ---------- + bvals : array-like + B-values. + bvecs : array-like + B-vectors. + original_order : bool, optional + If True, assumes b-vectors are in the original voxel order. + If False, assumes b-vectors match current in-memory orientation. + Default is True. + """ + self._bvals = np.asanyarray(bvals) + bvecs = np.asanyarray(bvecs).copy() + + # Validate shapes + if self._bvals.ndim != 1: + raise ValueError("bvals must be a 1D array.") + if bvecs.ndim != 2 or bvecs.shape[1] != 3: + raise ValueError("bvecs must be an (N, 3) array.") + if len(self._bvals) != len(bvecs): + raise ValueError("bvals and bvecs must have the same length.") + + # Validate against image data + if len(self._bvals) != self.shape[3]: + raise ValueError(f"Number of gradients ({len(self._bvals)}) does " + f"not match number of volumes ({self.shape[3]}).") + + if original_order: + # Transform from original voxel space to world space + ref_affine = self._original_affine \ + if self._original_affine is not None else self.affine + else: + # Transform from current voxel space to world space + ref_affine = self.affine + + R = self._get_rotation_matrix(ref_affine) + + # Apply BIDS flip if needed + if StatefulImage.needs_fsl_flip(ref_affine): + bvecs[:, 0] *= -1 + + self._world_bvecs = np.dot(bvecs, R.T) + + # Normalize + norms = np.linalg.norm(self._world_bvecs, axis=1) + self._world_bvecs[norms > 1e-6] /= norms[norms > 1e-6][:, None] + + def attach_world_gradients(self, bvals, world_bvecs): + """ + Attach b-values and world-space b-vectors to the image. + + Parameters + ---------- + bvals : array-like + B-values. + world_bvecs : array-like + B-vectors in world space (RAS mm). + """ + self._bvals = np.asanyarray(bvals) + self._world_bvecs = np.asanyarray(world_bvecs).copy() + + # Validate shapes + if self._bvals.ndim != 1: + raise ValueError("bvals must be a 1D array.") + if self._world_bvecs.ndim != 2 or self._world_bvecs.shape[1] != 3: + raise ValueError("world_bvecs must be an (N, 3) array.") + if len(self._bvals) != len(self._world_bvecs): + raise ValueError( + "bvals and world_bvecs must have the same length.") + + # Normalize + norms = np.linalg.norm(self._world_bvecs, axis=1) + self._world_bvecs[norms > 1e-6] /= norms[norms > 1e-6][:, None] + + def load_gradients(self, bval_path, bvec_path): + """ + Load b-values and b-vectors from FSL-formatted files. + + Parameters + ---------- + bval_path : str + Path to the bvals file. + bvec_path : str + Path to the bvecs file. + """ + bvals, bvecs = read_bvals_bvecs(bval_path, bvec_path) + self.attach_gradients(bvals, bvecs) + + def save_gradients(self, bval_path, bvec_path): + """ + Save b-values and b-vectors to FSL-formatted files. + Ensures b-vectors match the original voxel order. + + Parameters + ---------- + bval_path : str + Path to save the bvals file. + bvec_path : str + Path to save the bvecs file. + """ + if self._bvals is None or self._world_bvecs is None: + raise ValueError("No gradients attached to this StatefulImage.") + + # Transform from world space back to original voxel space + ref_affine = self._original_affine \ + if self._original_affine is not None else self.affine + R = self._get_rotation_matrix(ref_affine) + # v_voxel = v_world * R + bvecs_to_save = np.dot(self._world_bvecs, R) + + # According to BIDS/MRtrix convention, if the determinant of the + # affine is positive (neurological), the x-component of the bvecs + # must be flipped. + if StatefulImage.needs_fsl_flip(ref_affine): + bvecs_to_save[:, 0] *= -1 + + np.savetxt(bvec_path, bvecs_to_save.T, fmt='%.8f') + np.savetxt(bval_path, self._bvals[None, :], fmt='%.3f') + + def _reorient_gradients(self, start_axcodes, target_axcodes): + """ + Internal helper to reorient b-vectors. + Now that b-vectors are in world space, this does nothing. + """ + pass def reorient_to_original(self): """ @@ -150,8 +600,8 @@ def reorient_to_original(self): """ if self._original_axcodes is None: raise ValueError( - "Original axis codes are not set cannot reorient to original" - "orientation.") + "Original axis codes are not set. Cannot reorient to original" + " orientation.") self.reorient(self._original_axcodes) def reorient(self, target_axcodes): @@ -163,40 +613,31 @@ def reorient(self, target_axcodes): target_axcodes : str or tuple The target orientation axis codes (e.g., "LPS", ("R", "A", "S")). """ - validate_voxel_order(target_axcodes) - - current_axcodes = nib.orientations.aff2axcodes(self.affine) - if current_axcodes == tuple(target_axcodes): - return + if target_axcodes is None: + raise ValueError("Axis codes cannot be None.") - # Check unique are only valid axis codes - valid_codes = {'L', 'R', 'A', 'P', 'S', 'I'} - for code in target_axcodes: - if code not in valid_codes: - raise ValueError(f"Invalid axis code '{code}' in target.") + # Ensure target_axcodes has the same number of dimensions as self.shape + # by padding with unique placeholder codes if necessary. + target_axcodes = tuple(target_axcodes[:3]) - # Check L/R, A/P, S/I pairs are not both present - pairs = [('L', 'R'), ('A', 'P'), ('S', 'I')] - for pair in pairs: - if pair[0] in target_axcodes and pair[1] in target_axcodes: - raise ValueError(f"Conflicting axis codes '{pair[0]}' and " - f"'{pair[1]}' in target.") + validate_voxel_order(target_axcodes, dimensions=3) - # Check no repeated axis codes (LL, RR, etc.) - if len(set(target_axcodes)) != 3: - raise ValueError("Target axis codes must be unique.") + current_axcodes = self.axcodes[:3] + if current_axcodes == target_axcodes: + return start_ornt = nib.orientations.axcodes2ornt(current_axcodes) target_ornt = nib.orientations.axcodes2ornt(target_axcodes) transform = nib.orientations.ornt_transform(start_ornt, target_ornt) - reoriented_img = self.as_reoriented(transform) - self.__init__(reoriented_img.dataobj, reoriented_img.affine, - reoriented_img.header, - original_affine=self._original_affine, - original_dimensions=self._original_dimensions, - original_voxel_sizes=self._original_voxel_sizes, - original_axcodes=self._original_axcodes) + # Use Nifti1Image.as_reoriented to get a temporary object + # with the new orientation. + reoriented_img = nib.Nifti1Image.as_reoriented(self, transform) + + # Update Nifti1Image attributes in-place + self._dataobj = reoriented_img.dataobj + self._affine = reoriented_img.affine + self._header = reoriented_img.header def to_ras(self): """Convenience method to reorient in-memory data to RAS.""" @@ -214,7 +655,7 @@ def to_reference(self, obj): Parameters ---------- obj : object - Reference object from which orientation information can be obtained. + Reference object from which orientation information is obtained. Must not be an instance of ``StatefulImage``. Raises @@ -227,20 +668,43 @@ def to_reference(self, obj): raise TypeError('Reference object must not be a StatefulImage.') _, _, _, voxel_order = get_reference_info(obj) - self.reorient(voxel_order) + self.reorient(voxel_order[:3]) @property def axcodes(self): """Get the axis codes for the current image orientation.""" - return nib.orientations.aff2axcodes(self.affine) + codes = list(nib.orientations.aff2axcodes(self.affine)) + if len(self.shape) > 3: + extra_codes = ['T', 'U', 'V', 'W', 'X', 'Y', 'Z'] + for i in range(3, len(self.shape)): + codes.append(extra_codes[i-3]) + return tuple(codes) @property def original_axcodes(self): """Get the axis codes for the original image orientation.""" return self._original_axcodes + @property + def original_affine(self): + """Get the original image affine.""" + return self._original_affine + + @property + def original_header(self): + """Get a header matching the original image orientation.""" + # Create a copy of the current header but with original info + header = self.header.copy() + header.set_sform(self._original_affine) + header.set_qform(self._original_affine) + if self._original_voxel_sizes is not None: + header.set_zooms(self._original_voxel_sizes) + if self._original_dimensions is not None: + header.set_data_shape(self._original_dimensions) + return header + def __str__(self): - """Return a string representation of the image, including orientation.""" + """Return a string representation including orientation information.""" base_str = super().__str__() current_axcodes = self.axcodes reoriented = current_axcodes != self._original_axcodes diff --git a/src/scilpy/io/tests/test_stateful_image.py b/src/scilpy/io/tests/test_stateful_image.py index e0fb840b3..033d6e3e4 100644 --- a/src/scilpy/io/tests/test_stateful_image.py +++ b/src/scilpy/io/tests/test_stateful_image.py @@ -193,7 +193,7 @@ def test_direct_instantiation(): @pytest.mark.parametrize("codes, error_msg", [ (None, "Axis codes cannot be None."), - ("INVALID", "Target axis codes must be of length 3."), + ("INVALID", "Invalid axis code 'N' in target."), ("RAR", "Target axis codes must be unique."), ("LRR", "Target axis codes must be unique."), ("LRA", "Conflicting axis codes 'L' and 'R' in target."), diff --git a/src/scilpy/io/tests/test_stateful_image_gradients.py b/src/scilpy/io/tests/test_stateful_image_gradients.py new file mode 100644 index 000000000..11b54093e --- /dev/null +++ b/src/scilpy/io/tests/test_stateful_image_gradients.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- + +import os +import pytest +import tempfile +from contextlib import contextmanager + +import nibabel as nib +import numpy as np + +from scilpy.io.stateful_image import StatefulImage + + +@contextmanager +def create_dummy_nifti_with_gradients(filename="test.nii.gz", n_volumes=5, affine=None): + """ + Create a dummy NIfTI file and gradient files for testing. + """ + with tempfile.TemporaryDirectory() as tmpdir: + shape = (10, 10, 10, n_volumes) + if affine is None: + affine = np.eye(4) + data = np.random.rand(*shape).astype(np.float32) + img = nib.Nifti1Image(data, affine) + + file_path = os.path.join(tmpdir, filename) + nib.save(img, file_path) + + bvals = np.random.randint(0, 3000, n_volumes) + bvecs = np.random.randn(n_volumes, 3) + bvecs /= (np.linalg.norm(bvecs, axis=1)[:, None] + 1e-8) + + bval_path = os.path.join(tmpdir, "test.bval") + bvec_path = os.path.join(tmpdir, "test.bvec") + + np.savetxt(bval_path, bvals[None, :], fmt='%d') + np.savetxt(bvec_path, bvecs.T, fmt='%.8f') + + yield file_path, bval_path, bvec_path, bvals, bvecs + + +def test_attach_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + + assert np.allclose(simg.bvals, bvals) + assert np.allclose(simg.bvecs, bvecs) + + +def test_load_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.load_gradients(bval_p, bvec_p) + + assert np.allclose(simg.bvals, bvals) + assert np.allclose(simg.bvecs, bvecs, atol=1e-5) + + +def test_reorient_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + + # LPS reorientation: flip x and y + simg.to_lps() + assert simg.axcodes == ("L", "P", "S", "T") + + expected_bvecs = bvecs.copy() + expected_bvecs[:, 0] *= -1 + expected_bvecs[:, 1] *= -1 + + assert np.allclose(simg.bvecs, expected_bvecs) + + # Reorient back to RAS + simg.to_ras() + assert simg.axcodes == ("R", "A", "S", "T") + assert np.allclose(simg.bvecs, bvecs) + + +def test_save_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + simg.to_lps() + + tmp_dir = os.path.dirname(img_p) + out_bval = os.path.join(tmp_dir, "out.bval") + out_bvec = os.path.join(tmp_dir, "out.bvec") + + simg.save_gradients(out_bval, out_bvec) + + # Saved gradients should be back in RAS (original) + saved_bvals = np.loadtxt(out_bval) + saved_bvecs = np.loadtxt(out_bvec).T + + assert np.allclose(saved_bvals, bvals) + assert np.allclose(saved_bvecs, bvecs) + + # StatefulImage itself should still be in LPS + assert simg.axcodes == ("L", "P", "S", "T") + + +def test_create_from_with_gradients(): + with create_dummy_nifti_with_gradients() as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + simg.to_lps() + + # Create new simg from source (RAS) but with same reference (LPS) + source_nii = nib.load(img_p) + new_simg = StatefulImage.create_from(source_nii, simg) + + # new_simg matches source_nii (RAS) + assert new_simg.axcodes == ("R", "A", "S", "T") + # bvecs should have been reoriented back to RAS to match source_nii + assert np.allclose(new_simg.bvecs, bvecs) + assert np.allclose(new_simg.bvals, bvals) + + +def test_validation_errors(): + with create_dummy_nifti_with_gradients(n_volumes=5) as \ + (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + + # Wrong number of volumes + with pytest.raises(ValueError, + match="Number of gradients.*does not match number of volumes"): + simg.attach_gradients(bvals[:3], bvecs[:3]) + + # Wrong shape + with pytest.raises(ValueError, match="bvals must be a 1D array"): + simg.attach_gradients(bvals[:, None], bvecs) + + +def test_gradient_consistency_across_orientations(): + """ + Comprehensive test: + 1. Create RAS image + gradients. + 2. Reorient to LAS, LPS, LPI. + 3. Save in those orientations. + 4. Load back and verify they all return to the same RAS state. + """ + n_volumes = 4 + with create_dummy_nifti_with_gradients(n_volumes=n_volumes) as \ + (img_p, bval_p, bvec_p, bvals, bvecs): + simg_ras = StatefulImage.load(img_p) + simg_ras.attach_gradients(bvals, bvecs) + + # Original bvecs are in RAS (matching simg_ras.axcodes) + original_bvecs = simg_ras.bvecs.copy() + + for target_ornt in ["LAS", "LPS", "LPI"]: + with tempfile.TemporaryDirectory() as tmpdir: + # 1. Reorient + simg_ras.reorient(target_ornt) + + # 2. Create a "new" original at this orientation so we can save it AS is + # convert_to_simg sets the current state as the "original" + simg_target = StatefulImage.convert_to_simg( + simg_ras, simg_ras.bvals, simg_ras.bvecs) + + # 3. Save + target_img_p = os.path.join(tmpdir, "target.nii.gz") + target_bval_p = os.path.join(tmpdir, "target.bval") + target_bvec_p = os.path.join(tmpdir, "target.bvec") + + simg_target.save(target_img_p) + simg_target.save_gradients(target_bval_p, target_bvec_p) + + # 4. Load back (defaults to RAS) + simg_verify = StatefulImage.load( + target_img_p, to_orientation="RAS") + simg_verify.load_gradients(target_bval_p, target_bvec_p) + + # 5. Verify + assert simg_verify.axcodes == ("R", "A", "S", "T") + # Threshold for float precision after multiple transforms + assert np.allclose(simg_verify.bvecs, + original_bvecs, atol=1e-5) + assert np.allclose(simg_verify.bvals, bvals) + + # Go back to RAS for next iteration + simg_ras.to_ras() + + +def test_world_bvecs_non_diagonal_affine(): + # Create a rotation matrix (45 degrees around Z) + theta = np.pi / 4 + R = np.array([ + [np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1] + ]) + affine = np.eye(4) + affine[:3, :3] = R + + with create_dummy_nifti_with_gradients(affine=affine) as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + + # world_bvecs should be (bvecs * [-1, 1, 1]) * R.T because det(R) > 0 + bvecs_flipped = bvecs.copy() + bvecs_flipped[:, 0] *= -1 + expected_world_bvecs = np.dot(bvecs_flipped, R.T) + assert np.allclose(simg.world_bvecs, expected_world_bvecs) + + # bvecs property should return original bvecs (voxel space) + assert np.allclose(simg.bvecs, bvecs) + + # Save and reload + tmp_dir = os.path.dirname(img_p) + out_bval = os.path.join(tmp_dir, "out.bval") + out_bvec = os.path.join(tmp_dir, "out.bvec") + simg.save_gradients(out_bval, out_bvec) + + saved_bvecs = np.loadtxt(out_bvec).T + assert np.allclose(saved_bvecs, bvecs) + + +def test_world_bvecs_negative_det_affine(): + # LAS affine (det < 0) + affine = np.diag([-1, 1, 1, 1]) + + with create_dummy_nifti_with_gradients(affine=affine) as (img_p, bval_p, bvec_p, bvals, bvecs): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs) + + # world_bvecs should be bvecs * R.T because det(R) < 0 + R = simg._get_rotation_matrix(affine) + expected_world_bvecs = np.dot(bvecs, R.T) + assert np.allclose(simg.world_bvecs, expected_world_bvecs) + + # Verify that bvecs property returns original bvecs + assert np.allclose(simg.bvecs, bvecs) + + +def test_world_bvecs_reorientation_roundtrip(): + # Start with LAS (det < 0) + affine_las = np.diag([-1, 1, 1, 1]) + + with create_dummy_nifti_with_gradients(affine=affine_las) as (img_p, bval_p, bvec_p, bvals, bvecs_las): + simg = StatefulImage.load(img_p) + simg.attach_gradients(bvals, bvecs_las) + + # Reorient to RAS (det > 0) + simg.to_ras() + assert simg.axcodes == ("R", "A", "S", "T") + + # world_bvecs should remain the same + R_las = simg._get_rotation_matrix(affine_las) + expected_world_bvecs = np.dot(bvecs_las, R_las.T) + assert np.allclose(simg.world_bvecs, expected_world_bvecs) + + # bvecs in RAS should be flipped in X compared to world_bvecs * R_ras + # because det(RAS) > 0. + # Since R_ras = I, bvecs_ras = world_bvecs * [-1, 1, 1] + expected_bvecs_ras = expected_world_bvecs.copy() + expected_bvecs_ras[:, 0] *= -1 + assert np.allclose(simg.bvecs, expected_bvecs_ras) + + # Reorient back to LAS + simg.reorient("LAS") + assert np.allclose(simg.bvecs, bvecs_las) + assert np.allclose(simg.world_bvecs, expected_world_bvecs) diff --git a/src/scilpy/io/utils.py b/src/scilpy/io/utils.py index b8a6a86a8..db6a1e301 100644 --- a/src/scilpy/io/utils.py +++ b/src/scilpy/io/utils.py @@ -372,6 +372,10 @@ def add_sh_basis_args(parser, mandatory=False, input_output=False): parser.add_argument(arg_name, nargs=nargs, choices=choices, default=def_val, help=help_msg) + parser.add_argument('--is_voxel_space', action='store_true', + help='If set, assumes the input fODF/Peaks are already ' + 'in \nvoxel space. Default assumes world space ' + '(RAS).') def parse_sh_basis_arg(args): @@ -459,7 +463,11 @@ def add_peaks_screenshot_args(parser, default_width=3.0, default_alpha=1.0, help="Width of the peaks lines. [%(default)s]") rpg.add_argument("--peaks_opacity", type=ranged_type(float, 0., 1.), default=default_alpha, - help="Opacity value for the peaks overlay. [%(default)s]") + help="Opacity of the peaks, from 0 to 1. [%(default)s]") + rpg.add_argument('--is_voxel_space', action='store_true', + help='If set, assumes the input fODF/Peaks are already ' + 'in \nvoxel space. Default assumes world space ' + '(RAS).') def add_overlays_screenshot_args(parser, default_alpha=0.5, @@ -1250,7 +1258,14 @@ def get_default_screenshotting_data(args, peaks=True): peaks_imgs = None if peaks and args.peaks: - peaks_imgs = [nib.load(f) for f in args.peaks] + from scilpy.io.stateful_image import StatefulImage + peaks_imgs = [] + for f in args.peaks: + simg = StatefulImage.load(f, is_orientation=True, + is_world_space=not args.is_voxel_space) + # For screenshotting, we want the data in voxel space + # as the screenshotting actors currently assume voxel space. + peaks_imgs.append(simg) return (volume_img, transparency_img, diff --git a/src/scilpy/reconst/fodf.py b/src/scilpy/reconst/fodf.py index ed8ceb7c5..bdfd9300d 100644 --- a/src/scilpy/reconst/fodf.py +++ b/src/scilpy/reconst/fodf.py @@ -4,12 +4,9 @@ import multiprocessing import numpy as np -from dipy.data import get_sphere from dipy.reconst.mcsd import MSDeconvFit from dipy.reconst.multi_voxel import MultiVoxelFit -from dipy.reconst.shm import sh_to_sf_matrix -from scilpy.reconst.utils import find_order_from_nb_coeff from dipy.utils.optpkg import optional_package cvx, have_cvxpy, _ = optional_package("cvxpy") @@ -67,14 +64,14 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, Mean maximum fODF value and mask of voxels used. """ - order = find_order_from_nb_coeff(data) - sphere = get_sphere(name='repulsion100') - b_matrix, _ = sh_to_sf_matrix(sphere, sh_order_max=order, - basis_type=sh_basis, legacy=is_legacy) + from scilpy.reconst.utils import compute_max_sf_amplitude + out_mask = np.zeros(data.shape[:-1]) if mask is None: - mask = np.ones(data.shape[:-1]) + mask = np.ones(data.shape[:-1], dtype=bool) + else: + mask = mask.astype(bool) # 1000 works well at 2x2x2 = 8 mm3 # Hence, we multiply by the volume of a voxel @@ -84,14 +81,14 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, else: max_number_of_voxels = 1000 logging.debug("Searching for ventricle voxels, up to a maximum of {} " - "voxels.".format(max_number_of_voxels)) + f"voxels: {max_number_of_voxels}") # In the case of 2D-like data (3D data with one dimension size of 1), or # a small 3D dataset, the full range of data is scanned. if small_dims: - all_i = list(range(0, data.shape[0])) - all_j = list(range(0, data.shape[1])) - all_k = list(range(0, data.shape[2])) + range_i = slice(None) + range_j = slice(None) + range_k = slice(None) # In the case of a normal 3D dataset, a window is created in the middle of # the image to capture the ventricles. No need to scan the whole image. # (Automatic definition of window's radius based on the shape of the data.) @@ -104,35 +101,46 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis, else: radius = 5 - all_i = list(range(int(data.shape[0]/2) - radius, - int(data.shape[0]/2) + radius)) - all_j = list(range(int(data.shape[1]/2) - radius, - int(data.shape[1]/2) + radius)) - all_k = list(range(int(data.shape[2]/2) - radius, - int(data.shape[2]/2) + radius)) - - # Ok. Now find ventricle voxels. - list_of_max = [] - for i in all_i: - for j in all_j: - for k in all_k: - if len(list_of_max) > max_number_of_voxels - 1: - continue - if fa[i, j, k] < fa_threshold \ - and md[i, j, k] > md_threshold \ - and mask[i, j, k] == 1: - sf = np.dot(data[i, j, k], b_matrix) - list_of_max.append(sf.max()) - out_mask[i, j, k] = 1 - - logging.info('Number of voxels detected: {}'.format(len(list_of_max))) + mid_i, mid_j, mid_k = [int(s / 2) for s in data.shape[:-1]] + range_i = slice(mid_i - radius, mid_i + radius) + range_j = slice(mid_j - radius, mid_j + radius) + range_k = slice(mid_k - radius, mid_k + radius) + + # Find ventricle voxels candidates + ventricle_mask = np.zeros(data.shape[:-1], dtype=bool) + ventricle_mask[range_i, range_j, range_k] = ( + (fa[range_i, range_j, range_k] < fa_threshold) & + (md[range_i, range_j, range_k] > md_threshold) & + (mask[range_i, range_j, range_k]) + ) + + # Limit the number of voxels + ventricle_indices = np.argwhere(ventricle_mask) + if len(ventricle_indices) > max_number_of_voxels: + ventricle_indices = ventricle_indices[:max_number_of_voxels] + ventricle_mask[:] = False + ventricle_mask[tuple(ventricle_indices.T)] = True + + if not np.any(ventricle_mask): + logging.warning('No voxels found for evaluation! Change your fa ' + 'and/or md thresholds') + return 0, out_mask + + # Compute SF max in selected voxels + list_of_max = compute_max_sf_amplitude(data, sh_basis, is_legacy, + sphere_name='repulsion100', + mask=ventricle_mask) + list_of_max = list_of_max[ventricle_mask] + out_mask = ventricle_mask.astype(float) + + logging.info(f'Number of voxels detected: {len(list_of_max)}') if len(list_of_max) == 0: logging.warning('No voxels found for evaluation! Change your fa ' 'and/or md thresholds') return 0, out_mask - logging.info('Average max fodf value: {}'.format(np.mean(list_of_max))) - logging.info('Median max fodf value: {}'.format(np.median(list_of_max))) + logging.info(f'Average max fodf value: {np.mean(list_of_max)}') + logging.info(f'Median max fodf value: {np.median(list_of_max)}') if use_median: return np.median(list_of_max), out_mask else: diff --git a/src/scilpy/reconst/mti.py b/src/scilpy/reconst/mti.py index 976c53ad7..93b43c4b4 100644 --- a/src/scilpy/reconst/mti.py +++ b/src/scilpy/reconst/mti.py @@ -5,8 +5,6 @@ import scipy.io import scipy.ndimage -from scilpy.io.image import get_data_as_mask - def py_fspecial_gauss(shape, sigma): """ @@ -151,7 +149,7 @@ def compute_ratio_map(mt_on_single, mt_off, mt_on_dual=None): return MTR -def threshold_map(computed_map, in_mask, +def threshold_map(computed_map, mask_data, lower_threshold, upper_threshold, idx_contrast_list=None, contrast_maps=None): """ @@ -167,7 +165,8 @@ def threshold_map(computed_map, in_mask, ---------- computed_map: 3D-Array data. Myelin map (ihMT or non-ihMT maps) - in_mask: Path to binary T1 mask from T1 segmentation. + mask_data: Numpy array. + Binary T1 mask from T1 segmentation. Must be the sum of GM+WM+CSF. lower_threshold: Value for low thresold upper_thresold: Value for up thresold @@ -188,10 +187,8 @@ def threshold_map(computed_map, in_mask, computed_map[computed_map < lower_threshold] = 0 computed_map[computed_map > upper_threshold] = 0 - # Load and apply sum of T1 probability maps on myelin maps - if in_mask is not None: - mask_image = nib.load(in_mask) - mask_data = get_data_as_mask(mask_image) + # Apply T1 mask on myelin maps + if mask_data is not None: computed_map[np.where(mask_data == 0)] = 0 # Apply threshold based on combination of specific contrast maps diff --git a/src/scilpy/reconst/sh.py b/src/scilpy/reconst/sh.py index 1190fdf7c..29cd99d40 100644 --- a/src/scilpy/reconst/sh.py +++ b/src/scilpy/reconst/sh.py @@ -5,19 +5,20 @@ import numpy as np from dipy.core.sphere import Sphere +from dipy.core.subdivide_octahedron import create_unit_sphere from dipy.direction.peaks import peak_directions from dipy.reconst.odf import gfa from dipy.reconst.shm import (sh_to_sf_matrix, order_from_ncoef, sf_to_sh, sph_harm_ind_list) - from scilpy.gradients.bvec_bval_tools import (identify_shells, is_normalized_bvecs, normalize_bvecs, DEFAULT_B0_THRESHOLD) from scilpy.dwi.operations import compute_dwi_attenuation +from scilpy.reconst.utils import get_sh_order_and_fullness -def verify_data_vs_sh_order(data, sh_order): +def verify_data_vs_sh_order(data, sh_order, gtab=None): """ Raises a warning if the dwi data shape is not enough for the chosen sh_order. @@ -28,13 +29,24 @@ def verify_data_vs_sh_order(data, sh_order): Diffusion signal as weighted images (4D). sh_order: int SH order to fit, by default 4. + gtab: GradientTable, optional + Dipy object that contains all bvals and bvecs. """ - if data.shape[-1] < (sh_order + 1) * (sh_order + 2) / 2: - logging.warning( - 'We recommend having at least {} unique DWIs volumes, but you ' - 'currently have {} volumes. Try lowering the parameter --sh_order ' - 'in case of non convergence.'.format( - (sh_order + 1) * (sh_order + 2) / 2, data.shape[-1])) + if gtab is not None: + num_dwi = np.sum(~gtab.b0s_mask) + if num_dwi < (sh_order + 1) * (sh_order + 2) / 2: + logging.warning( + 'We recommend having at least {} unique DWI volumes, but you ' + 'currently have {} volumes (excluding b0). Try lowering the ' + 'parameter --sh_order in case of non convergence.'.format( + (sh_order + 1) * (sh_order + 2) / 2, num_dwi)) + else: + if data.shape[-1] < (sh_order + 1) * (sh_order + 2) / 2: + logging.warning( + 'We recommend having at least {} unique DWIs volumes, but you ' + 'currently have {} volumes. Try lowering the parameter --sh_order ' + 'in case of non convergence.'.format( + (sh_order + 1) * (sh_order + 2) / 2, data.shape[-1])) def compute_sh_coefficients(dwi, gradient_table, @@ -178,10 +190,113 @@ def compute_rish(sh, mask=None, full_basis=False): rish *= mask[..., None] orders = sorted(np.unique(order_ids)) - return rish, orders +def rotate_sh(sh_coeffs, rotation_matrix, basis_type='descoteaux07', + full_basis=False, is_legacy=True, nbr_processes=1, + in_place=False): + """ + Rotate SH coefficients using a rotation matrix. + + This implementation uses a discrete approach: + 1. Sample SH to SF on a dense sphere. + 2. Rotate the sphere points by the inverse rotation. + 3. Fit back to SH. + + Parameters + ---------- + sh_coeffs : np.ndarray + SH coefficients. Can be 1D or 4D (XxYxZxN). + rotation_matrix : np.ndarray (3, 3) + Rotation matrix. + basis_type : str, optional + SH basis type. + full_basis : bool, optional + Whether the SH basis is full. + is_legacy : bool, optional + Whether the SH basis is legacy. + nbr_processes : int, optional + Number of processes to use for rotation. + Default: 1. This is to avoid RAM problems and using all CPU for loading + FODFs by default. + in_place : bool, optional + If True, applies the rotation in-place directly on sh_coeffs + to save memory. Note: this alters the input array. + Default: False + + Returns + ------- + rotated_sh : np.ndarray + Rotated SH coefficients. + """ + if np.allclose(rotation_matrix, np.eye(3), atol=1e-6): + if not in_place: + return sh_coeffs.copy() + return sh_coeffs + + sh_order, full_basis = get_sh_order_and_fullness(sh_coeffs.shape[-1]) + + if nbr_processes is None or nbr_processes <= 0: + nbr_processes = 1 + + # Dense sphere to minimize aliasing/error + + # Level 5 octahedron subdivision gives 1026 vertices. + sphere = create_unit_sphere(recursion_level=5) + + # To rotate the function f by R, we want g(x) = f(R^-1 x). + # We sample g at points x_j (the sphere vertices). + # g(x_j) = f(R^-1 x_j). + # R^-1 x_j are the "rotated" sphere vertices. + inv_R = np.linalg.inv(rotation_matrix) + rotated_xyz = np.dot(sphere.vertices, inv_R.T) + rotated_sphere = Sphere(xyz=rotated_xyz) + + # Handle 1D vs 4D data + original_shape = sh_coeffs.shape + if len(original_shape) == 1: + sh_coeffs = sh_coeffs[None, None, None, :] + + if not in_place: + sh_coeffs = sh_coeffs.copy() + + # Compute mask to save memory on large volumes + mask = np.any(sh_coeffs, axis=-1) + + indices = np.nonzero(mask) + num_non_zero = indices[0].shape[0] + + CHUNK_SIZE = 50000 + for start_idx in range(0, num_non_zero, CHUNK_SIZE): + end_idx = min(start_idx + CHUNK_SIZE, num_non_zero) + + chunk_idx = tuple(idx[start_idx:end_idx] for idx in indices) + sh_chunk = sh_coeffs[chunk_idx] + + # Sample original SH at rotated positions + # Use scilpy's convert_sh_to_sf for memory efficiency (masking) + sf_chunk = convert_sh_to_sf(sh_chunk.astype(np.float32), rotated_sphere, + input_basis=basis_type, + input_full_basis=full_basis, + is_input_legacy=is_legacy, + mask=None, + dtype="float32", + nbr_processes=nbr_processes) + + # Fit these values back to SH using the ORIGINAL sphere (the canonical basis) + rotated_sh_chunk = sf_to_sh(sf_chunk, sphere, sh_order_max=sh_order, + basis_type=basis_type, full_basis=full_basis, + legacy=is_legacy) + + sh_coeffs[chunk_idx] = rotated_sh_chunk.astype(sh_coeffs.dtype) + + if len(original_shape) == 1: + return sh_coeffs.reshape(-1) + + return sh_coeffs + + def _peaks_from_sh_parallel(args): (shm_coeff, B, sphere, relative_peak_threshold, absolute_threshold, min_separation_angle, @@ -213,10 +328,11 @@ def _peaks_from_sh_loop(shm_coeff, B, sphere, relative_peak_threshold, odf = np.dot(shm_coeff[idx], B) odf[odf < absolute_threshold] = 0. - dirs, peaks, ind = peak_directions(odf, sphere, - relative_peak_threshold=relative_peak_threshold, - min_separation_angle=min_separation_angle, - is_symmetric=is_symmetric) + dirs, peaks, ind = peak_directions( + odf, sphere, + relative_peak_threshold=relative_peak_threshold, + min_separation_angle=min_separation_angle, + is_symmetric=is_symmetric) if peaks.shape[0] != 0: n = min(npeaks, peaks.shape[0]) @@ -354,7 +470,7 @@ def peaks_from_sh(shm_coeff, sphere, mask=None, relative_peak_threshold=0.5, # Bring back to the original shape peak_dirs_array = np.zeros(data_shape[0:3] + (npeaks, 3)) peak_values_array = np.zeros(data_shape[0:3] + (npeaks,)) - peak_indices_array = np.zeros(data_shape[0:3] + (npeaks,)) + peak_indices_array = np.full(data_shape[0:3] + (npeaks,), -1, dtype=np.int32) peak_dirs_array[mask] = tmp_peak_dirs_array peak_values_array[mask] = tmp_peak_values_array peak_indices_array[mask] = tmp_peak_indices_array @@ -555,14 +671,13 @@ def _convert_sh_basis_parallel(args): def _convert_sh_basis_loop(sh, B_in, invB_out): """ - Loops on 2D (ravelled) data and fits each voxel separately. + Vectorized SH basis conversion. For a more complete description of parameters, see convert_sh_basis. """ # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. - for idx in range(sh.shape[0]): - if sh[idx].any(): - sf = np.dot(sh[idx], B_in) - sh[idx] = np.dot(sf, invB_out) + if sh.any(): + sf = np.dot(sh, B_in) + sh = np.dot(sf, invB_out) return sh @@ -623,15 +738,14 @@ def convert_sh_basis(shm_coeff, sphere, mask=None, data_shape = shm_coeff.shape if mask is None: - mask = np.sum(shm_coeff, axis=3).astype(bool) + mask = np.sum(shm_coeff, axis=-1).astype(bool) nbr_processes = multiprocessing.cpu_count() \ if nbr_processes is None or nbr_processes < 0 else nbr_processes - # Ravel the first 3 dimensions while keeping the 4th intact, like a list of - # 1D time series voxels. + # Ravel the first N-1 dimensions while keeping the last one intact. shm_coeff = shm_coeff[mask].reshape( - (np.count_nonzero(mask), data_shape[3])) + (np.count_nonzero(mask), data_shape[-1])) # Separating the case nbr_processes=1 to help get good coverage metrics # (codecov does not deal well with multiprocessing) @@ -642,20 +756,19 @@ def convert_sh_basis(shm_coeff, sphere, mask=None, shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes) pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_convert_sh_basis_parallel, - zip(shm_coeff_chunks, - itertools.repeat(B_in), - itertools.repeat(invB_out), - np.arange(len(shm_coeff_chunks)))) - pool.close() - pool.join() - - # Re-assemble the chunk together. chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks]) tmp_shm_coeff_array = np.zeros((np.count_nonzero(mask), data_shape[3])) - for i, new_shm_coeff in results: + + for i, new_shm_coeff in pool.imap_unordered(_convert_sh_basis_parallel, + zip(shm_coeff_chunks, + itertools.repeat(B_in), + itertools.repeat(invB_out), + np.arange(len(shm_coeff_chunks)))): tmp_shm_coeff_array[chunk_len[i]:chunk_len[i+1], :] = new_shm_coeff + pool.close() + pool.join() + # Bring back to the original shape shm_coeff_array = np.zeros(data_shape) shm_coeff_array[mask] = tmp_shm_coeff_array @@ -671,15 +784,11 @@ def _convert_sh_to_sf_parallel(args): def _convert_sh_to_sf_loop(sh, new_output_dim, B_in): """ - Loops on 2D data and fits each voxel separately. + Vectorized matrix multiplication for SH to SF conversion. See convert_sh_to_sf for more information. """ # Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels. - sf = np.zeros((sh.shape[0], new_output_dim), dtype=np.float32) - - for idx in range(sh.shape[0]): - if sh[idx].any(): - sf[idx] = np.dot(sh[idx], B_in) + sf = np.dot(sh, B_in).astype(np.float32) return sf @@ -687,7 +796,7 @@ def _convert_sh_to_sf_loop(sh, new_output_dim, B_in): def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32", input_basis='descoteaux07', input_full_basis=False, is_input_legacy=True, - nbr_processes=multiprocessing.cpu_count()): + nbr_processes=None): """Converts spherical harmonic coefficients to an SF sphere Parameters @@ -733,15 +842,17 @@ def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32", data_shape = shm_coeff.shape if mask is None: - mask = np.sum(shm_coeff, axis=3).astype(bool) + mask = np.sum(shm_coeff, axis=-1).astype(bool) + + nbr_processes = multiprocessing.cpu_count() if nbr_processes is None \ + or nbr_processes <= 0 else nbr_processes output_dim = len(sphere.vertices) - new_shape = data_shape[:3] + (output_dim,) + new_shape = data_shape[:-1] + (output_dim,) - # Ravel the first 3 dimensions while keeping the 4th intact, like a list of - # 1D time series voxels. + # Ravel the first N-1 dimensions while keeping the last one intact. shm_coeff = shm_coeff[mask].reshape( - (np.count_nonzero(mask), data_shape[3])) + (np.count_nonzero(mask), data_shape[-1])) # Separating the case nbr_processes=1 to help get good coverage metrics # (codecov does not deal well with multiprocessing) @@ -752,21 +863,20 @@ def convert_sh_to_sf(shm_coeff, sphere, mask=None, dtype="float32", shm_coeff_chunks = np.array_split(shm_coeff, nbr_processes) pool = multiprocessing.Pool(nbr_processes) - results = pool.map(_convert_sh_to_sf_parallel, - zip(shm_coeff_chunks, - itertools.repeat(B_in), - itertools.repeat(output_dim), - np.arange(len(shm_coeff_chunks)))) - pool.close() - pool.join() - - # Re-assemble the chunk together. chunk_len = np.cumsum([0] + [len(c) for c in shm_coeff_chunks]) - tmp_sf_array = np.zeros((np.count_nonzero(mask), new_shape[3]), + tmp_sf_array = np.zeros((np.count_nonzero(mask), new_shape[-1]), dtype=dtype) - for i, new_sf in results: + + for i, new_sf in pool.imap_unordered(_convert_sh_to_sf_parallel, + zip(shm_coeff_chunks, + itertools.repeat(B_in), + itertools.repeat(output_dim), + np.arange(len(shm_coeff_chunks)))): tmp_sf_array[chunk_len[i]:chunk_len[i + 1], :] = new_sf + pool.close() + pool.join() + # Bring back to the original shape sf_array = np.zeros(new_shape, dtype=dtype) sf_array[mask] = tmp_sf_array diff --git a/src/scilpy/reconst/tests/test_fodf.py b/src/scilpy/reconst/tests/test_fodf.py index 89911702d..a0a7b5c86 100644 --- a/src/scilpy/reconst/tests/test_fodf.py +++ b/src/scilpy/reconst/tests/test_fodf.py @@ -37,7 +37,7 @@ def test_get_ventricles_max_fodf(): sf1 = np.dot(fodf_3x3_order8_descoteaux07[1, 0, 0], b_matrix) sf2 = np.dot(fodf_3x3_order8_descoteaux07[1, 1, 0], b_matrix) - assert mean == np.mean([np.max(sf1), np.max(sf2)]) + assert np.allclose(mean, np.mean([np.max(sf1), np.max(sf2)]), atol=1e-6) def test_get_ventricles_max_fodf_median(): @@ -64,7 +64,7 @@ def test_get_ventricles_max_fodf_median(): sf1 = np.dot(fodf_3x3_order8_descoteaux07[1, 0, 0], b_matrix) sf2 = np.dot(fodf_3x3_order8_descoteaux07[1, 1, 0], b_matrix) - assert median == np.median([np.max(sf1), np.max(sf2)]) + assert np.allclose(median, np.median([np.max(sf1), np.max(sf2)]), atol=1e-6) def test_get_ventricles_max_fodf_mask(): diff --git a/src/scilpy/reconst/utils.py b/src/scilpy/reconst/utils.py index 054236177..1ed85b98e 100644 --- a/src/scilpy/reconst/utils.py +++ b/src/scilpy/reconst/utils.py @@ -33,7 +33,9 @@ def get_maximas(data, sphere, b_matrix, threshold, absolute_threshold, spherical_func = np.dot(data, b_matrix.T) spherical_func[np.nonzero(spherical_func < absolute_threshold)] = 0. return peak_directions( - spherical_func, sphere, threshold, min_separation_angle) + spherical_func, sphere, + relative_peak_threshold=threshold, + min_separation_angle=min_separation_angle) def get_sphere_neighbours(sphere, max_angle): @@ -57,3 +59,196 @@ def get_sphere_neighbours(sphere, max_angle): np.outer(zs, zs)) neighbours = scalar_prods >= np.cos(max_angle) return neighbours + + +def is_data_peaks(img_data): + """ + Heuristic to find out if the input are peaks or fodf. + fodf are always around 0.15 and peaks around 0.75. + Peaks have more zero values than fodf. The first value of fodf is + usually the highest. + + Parameters + ---------- + img_data : np.ndarray + 4D image data where the last dimension contains directional info. + + Returns + ------- + is_peaks : bool + True if data is likely peaks, False if likely fODF (SH). + """ + last_dim = img_data.shape[-1] + if last_dim == 3: + return True + + # Sum of absolute values to detect non-zero voxels correctly + non_zeros_mask = np.any(np.abs(img_data) > 0, axis=-1) + if not np.count_nonzero(non_zeros_mask): + return False + + try: + order, full = get_sh_order_and_fullness(last_dim) + # Symmetric SH must be even order + if not full and order % 2 != 0: + return False + except ValueError: + # If not a valid SH number of coefficients, and not 3, + # it might be something else, but if it's a multiple of 3 + # it's likely Peaks. + if last_dim % 3 == 0: + return True + return False + + data_nz = img_data[non_zeros_mask] + + # If all triplets have the same norm, it is likely peaks, otherwise SH. + if last_dim % 3 == 0: + norm = np.linalg.norm(data_nz.reshape(-1, 3), axis=-1) + if np.all(np.isclose(norm, norm[0])): + return True + + # If the max is in the first triplet but not at index 0, it's likely Peaks. + # Smoothed SH almost always has max at index 0 + argmax_indices = np.argmax(np.abs(data_nz), axis=-1) + if last_dim % 3 == 0 and \ + np.mean(np.logical_or(argmax_indices == 1, + argmax_indices == 2)) > 0.1: + return True + + # Exact zeros. SH almost never has exact zeros in real data. + # Peaks often have exact zeros for unused lobes + zero_ratio = np.mean(data_nz == 0) + if zero_ratio > 0.05: + return True + + # Default to SH + return False + + +def compute_max_sf_amplitude(data, sh_basis, is_legacy, + sphere_name='repulsion100', mask=None): + """ + Compute the maximum SF amplitude for each voxel. + Only computes SF for voxels where data is non-zero (or in mask) to save RAM. + + Parameters + ---------- + data : np.ndarray + ODF data (SH). + sh_basis : str + SH basis ('tournier07' or 'descoteaux07'). + is_legacy : bool + Whether the SH basis is legacy. + sphere_name : str or dipy.core.sphere.Sphere, optional + Sphere name for SF conversion or Sphere object. + mask : np.ndarray, optional + Binary mask. If provided, only voxels in mask are computed. + + Returns + ------- + max_sf : np.ndarray + Maximum SF amplitude per voxel. + """ + from dipy.data import get_sphere + from dipy.reconst.shm import sh_to_sf_matrix + from dipy.core.sphere import Sphere + + if mask is None: + mask = np.any(data, axis=-1) + + order = find_order_from_nb_coeff(data) + if isinstance(sphere_name, (Sphere,)): + sphere = sphere_name + else: + sphere = get_sphere(name=sphere_name) + + b_matrix, _ = sh_to_sf_matrix(sphere, sh_order_max=order, + basis_type=sh_basis, legacy=is_legacy) + + max_sf = np.zeros(data.shape[:-1], dtype=np.float32) + if np.any(mask): + # Vectorized SF computation for masked voxels + sf = np.dot(data[mask], b_matrix) + max_sf[mask] = np.max(sf, axis=-1) + + return max_sf + + +def compute_sf_threshold_mask(data, sphere_name='repulsion100', + relative_factor=None, + absolute_threshold=None, + sh_basis='descoteaux07', + is_legacy=True, postprocess_mask=True): + """ + Compute a binary mask based on a global SF amplitude threshold. + + Parameters + ---------- + data : np.ndarray + ODF data (SH or Peaks). + sphere_name : str or dipy.core.sphere.Sphere, optional + Sphere name for SF conversion or Sphere object. + relative_factor : float, optional + Factor between 0 and 1. Threshold is factor * global_max_sf. + absolute_threshold : float, optional + Absolute threshold on SF amplitude. + sh_basis : str, optional + SH basis ('tournier07' or 'descoteaux07'). + is_legacy : bool, optional + Whether the SH basis is legacy. + + Returns + ------- + mask : np.ndarray + Binary mask. + global_max : float + Global maximum SF amplitude. + threshold : float + Computed threshold value. + """ + if relative_factor is None and absolute_threshold is None: + raise ValueError("Either relative_factor or absolute_threshold " + "must be provided.") + + is_peaks = is_data_peaks(data) + if is_peaks: + npeaks = data.shape[-1] // 3 + peaks = data.reshape(data.shape[:3] + (npeaks, 3)) + norms = np.linalg.norm(peaks, axis=-1) + # maximum amplitude/norm across peaks + max_amp = np.max(norms, axis=-1) + else: + max_amp = compute_max_sf_amplitude(data, sh_basis, is_legacy, + sphere_name=sphere_name) + + global_max = np.max(max_amp) + + if absolute_threshold is not None: + threshold = absolute_threshold + else: + threshold = relative_factor * global_max + + mask = max_amp >= threshold + + if postprocess_mask: + import scipy.ndimage as ndi + # Postprocess to labels all elements and count voxels for each label + labels = ndi.label(mask)[0] + label_counts = np.bincount(labels.ravel()) + # Find the largest connected component (excluding background) + largest_label = np.argmax(label_counts[1:]) + 1 # +1 to skip background + # Create a mask for the largest connected component + mask = labels == largest_label + inverted_mask = ~mask + + # Remove isolated voxels in the inverted mask (holes in the main mask) + labels_inverted = ndi.label(inverted_mask)[0] + label_counts_inverted = np.bincount(labels_inverted.ravel()) + for label, count in enumerate(label_counts_inverted): + if label == 0: + continue # Skip background + if count < 100: # Threshold for filling holes (can be adjusted) + mask[labels_inverted == label] = True + + return mask, global_max, threshold diff --git a/src/scilpy/tests/test_stateful_image_direction.py b/src/scilpy/tests/test_stateful_image_direction.py new file mode 100644 index 000000000..ddaca1bdd --- /dev/null +++ b/src/scilpy/tests/test_stateful_image_direction.py @@ -0,0 +1,103 @@ +# -*- coding: utf-8 -*- + +import numpy as np +import nibabel as nib +from scilpy.io.stateful_image import StatefulImage +from scilpy.reconst.utils import is_data_peaks + + +def test_peak_direction_transform(): + # Create a 90-degree rotation affine (X-axis) + # y_world = -z_voxel, z_world = y_voxel + affine = np.array([ + [1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1] + ]) + + # 1. Test Peaks (3 coefficients) + data_peaks = np.zeros((2, 2, 2, 3)) + data_peaks[:, :, :, :] = [0, 0, 1] # Voxel Z + + img = nib.Nifti1Image(data_peaks, affine) + simg = StatefulImage.convert_to_simg(img) + + # Voxel (0,0,1) -> World (0,-1,0) + world_peaks = simg.to_world_direction(data_peaks) + expected_world = [0, -1, 0] + np.testing.assert_allclose(world_peaks[0, 0, 0], expected_world, atol=1e-5) + + # World (0,-1,0) -> Voxel (0,0,1) + voxel_peaks = simg.to_voxel_direction(world_peaks) + expected_voxel = [0, 0, 1] + np.testing.assert_allclose(voxel_peaks[0, 0, 0], expected_voxel, atol=1e-5) + + +def test_sh_direction_transform(): + # Create a 90-degree rotation affine (X-axis) + affine = np.array([ + [1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1] + ]) + + # Order 2, 6 coefficients for symmetric + data_sh = np.zeros((2, 2, 2, 6)) + # Isotropic part, make sure it's the max so it's recognized as SH + data_sh[:, :, :, 0] = 5.0 + data_sh[:, :, :, 1:] = 0.01 # Add noise to prevent exact zeros + data_sh[:, :, :, 3] = 1.0 # Some orientation part + + img = nib.Nifti1Image(data_sh, affine) + simg = StatefulImage.convert_to_simg(img) + + # Verify it doesn't crash and changes coefficients + world_sh = simg.to_world_direction(data_sh) + assert not np.allclose(world_sh[0, 0, 0], data_sh[0, 0, 0]) + + # Reverting should return original + back_sh = simg.to_voxel_direction(world_sh) + np.testing.assert_allclose(back_sh, data_sh, atol=1e-5) + + +def test_stateful_image_load_direction(tmp_path): + affine = np.array([ + [1, 0, 0, 0], + [0, 0, -1, 0], + [0, 1, 0, 0], + [0, 0, 0, 1] + ]) + data_peaks = np.zeros((2, 2, 2, 3)) + data_peaks[:, :, :, :] = [0, 0, 1] # Voxel Z + + img_path = str(tmp_path / "voxel_peaks.nii.gz") + nib.save(nib.Nifti1Image(data_peaks, affine), img_path) + + # Load as voxel-space directional image + # Internal representation should move to World Space (0, -1, 0) + simg = StatefulImage.load(img_path, is_orientation=True, + is_world_space=False) + + expected_world = [0, -1, 0] + np.testing.assert_allclose(simg.get_fdata()[0, 0, 0], expected_world, + atol=1e-5) + + +def test_heuristic_is_data_peaks(): + # Peaks: multiple peaks with zeros or high argmax + peaks_data = np.zeros((2, 2, 2, 6)) + # Make sure the max is in the first triplet to pass the + # `argmax_indices > 2` check + # But place it at index 1 to trigger the `== 1 or == 2` check + peaks_data[0, 0, 0, :3] = [0, 1, 0] # Peak 1 is Y + # Argmax is 1 -> is_peaks should be True + assert is_data_peaks(peaks_data) is True + + # SH: First value (l=0) is usually highest + sh_data = np.zeros((2, 2, 2, 6)) + sh_data[:, :, :, 0] = 1.0 # l=0 + sh_data[:, :, :, 1:] = 0.1 # Small l=2 + # Argmax is 0 -> is_peaks should be False + assert is_data_peaks(sh_data) is False diff --git a/src/scilpy/tests/test_tracking_io_alignment.py b/src/scilpy/tests/test_tracking_io_alignment.py new file mode 100644 index 000000000..1f2c8a51c --- /dev/null +++ b/src/scilpy/tests/test_tracking_io_alignment.py @@ -0,0 +1,221 @@ +import numpy as np +import nibabel as nib +import pytest +from dipy.io.stateful_tractogram import StatefulTractogram, Space +from dipy.io.streamline import load_tractogram, save_tractogram +from scilpy.tracking.utils import save_tractogram as scil_save_tractogram + + +def create_fake_header(affine, shape=(10, 10, 10)): + data = np.zeros(shape) + img = nib.Nifti1Image(data, affine) + return img + + +@pytest.mark.parametrize("affine_type", ["iso_1mm", "iso_2mm", "aniso", "complex"]) +@pytest.mark.parametrize("ext", [".trk", ".tck"]) +def test_tracking_io_alignment(tmp_path, affine_type, ext): + if affine_type == "iso_1mm": + affine = np.diag([1, 1, 1, 1]) + elif affine_type == "iso_2mm": + affine = np.diag([2, 2, 2, 1]) + elif affine_type == "aniso": + affine = np.diag([1, 1, 2, 1]) + elif affine_type == "complex": + # Rotation 30 deg around Z, scaling, translation + theta = np.radians(30) + c, s = np.cos(theta), np.sin(theta) + R = np.array([ + [c, -s, 0], + [s, c, 0], + [0, 0, 1] + ]) + S = np.diag([1.1, 0.9, 1.2]) + T = np.array([10, -20, 30]) + affine = np.eye(4) + affine[:3, :3] = R @ S + affine[:3, 3] = T + + img = create_fake_header(affine) + img_path = str(tmp_path / "ref.nii.gz") + nib.save(img, img_path) + + # Create streamlines in VOXEL space, origin CENTER + # (0,0,0) to (5,5,5) + vox_streamlines = [np.array([ + [0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [5, 5, 5] + ], dtype=float)] + + # Convert to RASMM for StatefulTractogram + # StatefulTractogram expects streamlines in RASMM if space is Space.RASMM + sft = StatefulTractogram(vox_streamlines, img, Space.VOX) + + output_path = str(tmp_path / f"tracto{ext}") + + # Method 1: Use DIPY save_tractogram (standard) + save_tractogram(sft, output_path) + + # Reload and check + sft_loaded = load_tractogram(output_path, img_path) + + # Check streamlines in VOX space + sft_loaded.to_vox() + loaded_vox = sft_loaded.streamlines + + assert len(loaded_vox) == len(vox_streamlines) + for orig, loaded in zip(vox_streamlines, loaded_vox): + assert np.allclose(orig, loaded, atol=1e-3) + + # Check streamlines in RASMM space + sft.to_rasmm() + sft_loaded.to_rasmm() + for orig, loaded in zip(sft.streamlines, sft_loaded.streamlines): + assert np.allclose(orig, loaded, atol=1e-3) + + +@pytest.mark.parametrize("affine_type", ["iso_1mm", "iso_2mm", "aniso", "complex"]) +@pytest.mark.parametrize("ext", [".trk", ".tck"]) +def test_scil_save_tractogram_alignment(tmp_path, affine_type, ext): + if affine_type == "iso_1mm": + affine = np.diag([1, 1, 1, 1]) + elif affine_type == "iso_2mm": + affine = np.diag([2, 2, 2, 1]) + elif affine_type == "aniso": + affine = np.diag([1, 1, 2, 1]) + elif affine_type == "complex": + theta = np.radians(30) + c, s = np.cos(theta), np.sin(theta) + R = np.array([[c, -s, 0], [s, c, 0], [0, 0, 1]]) + S = np.diag([1.1, 0.9, 1.2]) + T = np.array([10, -20, 30]) + affine = np.eye(4) + affine[:3, :3] = R @ S + affine[:3, 3] = T + + img = create_fake_header(affine) + img_path = str(tmp_path / "ref.nii.gz") + nib.save(img, img_path) + + # Create streamlines in VOXEL space, origin CENTER + vox_streamlines = [np.array([ + [0, 0, 0], + [1, 1, 1], + [2, 2, 2], + [5, 5, 5] + ], dtype=float)] + + # Generator for scil_save_tractogram + # it yields (streamline, seed) + # We make it a list so it's re-iterable if needed + stream_gen_list = [(s.copy(), s[0].copy()) for s in vox_streamlines] + + output_path = str(tmp_path / f"scil_tracto{ext}") + tracts_format = nib.streamlines.detect_format(output_path) + + # scil_save_tractogram(streamlines_generator, tracts_format, ref_img, total_nb_seeds, + # out_tractogram, min_length, max_length, compress, save_seeds, verbose) + scil_save_tractogram(stream_gen_list, tracts_format, img, len(vox_streamlines), + output_path, 0, 1000, None, False, False) + + # Reload and check + sft_loaded = load_tractogram(output_path, img_path) + sft_loaded.to_vox() + loaded_vox = sft_loaded.streamlines + + assert len(loaded_vox) == len(vox_streamlines) + for orig, loaded in zip(vox_streamlines, loaded_vox): + # Using a slightly larger tolerance because TRK/TCK might have some + # precision loss or 0.5 offset handling differences + assert np.allclose(orig, loaded, atol=1e-3) + + +def test_tck_trk_physical_alignment(tmp_path): + # Rotation 45 deg around X + theta = np.radians(45) + c, s = np.cos(theta), np.sin(theta) + R = np.array([ + [1, 0, 0], + [0, c, -s], + [0, s, c] + ]) + affine = np.eye(4) + affine[:3, :3] = R + + img = create_fake_header(affine) + img_path = str(tmp_path / "ref.nii.gz") + nib.save(img, img_path) + + vox_streamlines = [np.array([ + [0, 0, 0], + [1, 2, 3], + [5, 5, 5] + ], dtype=float)] + + sft = StatefulTractogram(vox_streamlines, img, Space.VOX) + + trk_path = str(tmp_path / "tracto.trk") + tck_path = str(tmp_path / "tracto.tck") + + save_tractogram(sft, trk_path) + save_tractogram(sft, tck_path) + + sft_trk = load_tractogram(trk_path, img_path) + sft_tck = load_tractogram(tck_path, img_path) + + sft_trk.to_rasmm() + sft_tck.to_rasmm() + + for s_trk, s_tck in zip(sft_trk.streamlines, sft_tck.streamlines): + assert np.allclose(s_trk, s_tck, atol=1e-3) + + +def test_negative_det_alignment(tmp_path): + # LAS affine (det < 0) + affine = np.diag([-1, 1, 1, 1]) + # Add some translation to make it more interesting + affine[:3, 3] = [100, 100, 100] + + img = create_fake_header(affine) + img_path = str(tmp_path / "ref.nii.gz") + nib.save(img, img_path) + + vox_streamlines = [np.array([ + [0, 0, 0], + [1, 2, 3], + [5, 5, 5] + ], dtype=float)] + + sft = StatefulTractogram(vox_streamlines, img, Space.VOX) + + trk_path = str(tmp_path / "tracto.trk") + tck_path = str(tmp_path / "tracto.tck") + + save_tractogram(sft, trk_path) + save_tractogram(sft, tck_path) + + sft_trk = load_tractogram(trk_path, img_path) + sft_tck = load_tractogram(tck_path, img_path) + + sft_trk.to_rasmm() + sft_tck.to_rasmm() + + # Verify they align in RASMM + for s_trk, s_tck in zip(sft_trk.streamlines, sft_tck.streamlines): + assert np.allclose(s_trk, s_tck, atol=1e-3) + + # Verify RASMM coordinates manually + # v_rasmm = R * v_vox + T + # For LAS: R = [[-1, 0, 0], [0, 1, 0], [0, 0, 1]], T = [100, 100, 100] + # [0,0,0] -> [-1*0+100, 1*0+100, 1*0+100] = [100, 100, 100] + # [1,2,3] -> [-1*1+100, 1*2+100, 1*3+100] = [99, 102, 103] + expected_rasmm = [np.array([ + [100, 100, 100], + [99, 102, 103], + [95, 105, 105] + ], dtype=float)] + + for s_trk, s_exp in zip(sft_trk.streamlines, expected_rasmm): + assert np.allclose(s_trk, s_exp, atol=1e-3) diff --git a/src/scilpy/tests/test_world_space_pipeline.py b/src/scilpy/tests/test_world_space_pipeline.py new file mode 100644 index 000000000..058a6c409 --- /dev/null +++ b/src/scilpy/tests/test_world_space_pipeline.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- + +import nibabel as nib +import numpy as np +import pytest +from dipy.core.gradients import gradient_table +from dipy.reconst.dti import TensorModel +from dipy.io.stateful_tractogram import Space +from dipy.io.streamline import load_tractogram +from scilpy.io.stateful_image import StatefulImage +from scilpy.tracking.seed import SeedGenerator +from scilpy.tracking.utils import save_tractogram + + +@pytest.fixture +def rotated_las_dataset(tmp_path): + """ + Create a mock LAS dataset with 45-degree rotation around Z and 2mm voxels. + """ + affine = np.array([ + [-1.414, 1.414, 0.0, 50.0], + [-1.414, -1.414, 0.0, 50.0], + [0.0, 0.0, 2.0, -20.0], + [0.0, 0.0, 0.0, 1.0] + ]) + + # Gradients (6 dirs + 1 b0) - Defined in WORLD X alignment + bvecs_world = np.array([ + [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], + [-1, 0, 0], [0, -1, 0], [0, 0, -1] + ]) + bvals = np.array([0, 1000, 1000, 1000, 1000, 1000, 1000]) + + # Simulate signal + data = np.ones((10, 10, 10, 7)) * 20 + data[2:8, 2:8, 2:8, 0] = 100 + for i in range(1, 7): + g = bvecs_world[i] + cos_theta = np.dot(g, [1, 0, 0]) + data[2:8, 2:8, 2:8, i] = 100 * np.exp(-1.0 * (cos_theta**2)) + + # Back-project world bvecs to voxel space for FSL file + R = affine[:3, :3] + R_inv = np.linalg.inv(R / np.linalg.norm(R, axis=0)) + bvecs_vox = np.dot(bvecs_world, R_inv.T) + if np.linalg.det(R) > 0: + bvecs_vox[:, 0] *= -1 + + dwi_path = str(tmp_path / "dwi.nii.gz") + bval_path = str(tmp_path / "dwi.bval") + bvec_path = str(tmp_path / "dwi.bvec") + + nib.save(nib.Nifti1Image(data.astype(np.float32), affine), dwi_path) + np.savetxt(bval_path, bvals[None, :], fmt='%d') + np.savetxt(bvec_path, bvecs_vox.T, fmt='%.8f') + + return dwi_path, bval_path, bvec_path, bvecs_world + + +def test_stateful_image_world_gradients(rotated_las_dataset): + dwi, bval, bvec, bvecs_world_truth = rotated_las_dataset + simg = StatefulImage.load(dwi) + simg.load_gradients(bval, bvec) + + # Assert world_bvecs match truth + np.testing.assert_allclose(simg.world_bvecs, bvecs_world_truth, atol=1e-2) + + # Assert saving and reloading recovers world truth + tmp_bvec = bvec + "_tmp.bvec" + simg.save_gradients(bval, tmp_bvec) + simg2 = StatefulImage.load(dwi) + simg2.load_gradients(bval, tmp_bvec) + np.testing.assert_allclose(simg2.world_bvecs, bvecs_world_truth, atol=1e-2) + + +def test_dti_fitting_world_space(rotated_las_dataset): + dwi, bval, bvec, _ = rotated_las_dataset + simg = StatefulImage.load(dwi) + simg.load_gradients(bval, bvec) + + gtab = gradient_table(simg.bvals, bvecs=simg.world_bvecs) + tenmodel = TensorModel(gtab) + tenfit = tenmodel.fit(simg.get_fdata()) + + peak = tenfit.evecs[5, 5, 5, 0] + # Fiber was simulated along physical X + assert np.abs(peak[0]) > 0.8 + + +def test_tracking_seeding_world_space(rotated_las_dataset): + dwi, bval, bvec, _ = rotated_las_dataset + simg = StatefulImage.load(dwi) + + seed_data = np.zeros((10, 10, 10)) + seed_data[5, 5, 5] = 1 + + seed_gen = SeedGenerator(seed_data, simg.header.get_zooms()[:3], + affine=simg.affine, space=Space.RASMM) + + rng = np.random.RandomState(42) + seed = seed_gen.get_next_pos(rng, np.arange(1), 0) + + # Project back to voxel space and check index + inv_affine = np.linalg.inv(simg.affine) + seed_vox = np.dot(inv_affine, np.append(seed, 1.0))[:3] + np.testing.assert_allclose(seed_vox, [5.5, 5.5, 5.5], atol=1.0) + + +def test_save_tractogram_world_space(tmp_path, rotated_las_dataset): + dwi, bval, bvec, _ = rotated_las_dataset + simg = StatefulImage.load(dwi) + + # World coordinate for center of (5,5,5) + seed_world = np.dot(simg.affine, [5.5, 5.5, 5.5, 1])[:3] + streamline = np.array([seed_world, seed_world + [10, 0, 0]], dtype=float) + + def mock_gen(): + yield streamline, seed_world + + out_trk_scil = str(tmp_path / "test_scil.trk") + from nibabel.streamlines import TrkFile as NibTrkFile + save_tractogram(mock_gen, NibTrkFile, simg, 1, out_trk_scil, + 0, 1000, None, True, False, space=Space.RASMM) + + sft_scil = load_tractogram(out_trk_scil, dwi, bbox_valid_check=False) + assert len(sft_scil.streamlines) == 1, "Scilpy save_tractogram produced empty file!" + sft_scil.to_rasmm() + + # Assert coordinates match + np.testing.assert_allclose(sft_scil.streamlines[0], streamline, atol=1e-2) + # Assert seed was saved correctly in DPS + np.testing.assert_allclose(sft_scil.data_per_streamline['seeds'][0], seed_world, atol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/src/scilpy/tracking/propagator.py b/src/scilpy/tracking/propagator.py index 32f7e6b5d..82bb9746c 100644 --- a/src/scilpy/tracking/propagator.py +++ b/src/scilpy/tracking/propagator.py @@ -373,11 +373,6 @@ def __init__(self, datavolume, step_size, super().__init__(datavolume, step_size, rk_order, dipy_sphere, sub_sphere, space, origin) - if self.space == Space.RASMM: - raise NotImplementedError( - "This version of the propagator on ODF is not ready to work " - "in RASMM space.") - # Warn user if the rk order does not match the algo if rk_order != 1 and algo == 'prob': logging.warning('Probabilistic tracking with RK order != 1 is ' diff --git a/src/scilpy/tracking/seed.py b/src/scilpy/tracking/seed.py index da13f1d18..710629259 100644 --- a/src/scilpy/tracking/seed.py +++ b/src/scilpy/tracking/seed.py @@ -10,15 +10,12 @@ class SeedGenerator: """ Class to get seeding positions. - Generated seeds are in voxmm space, origin=corner. Ex: a seed sampled - exactly at voxel i,j,k = (0,1,2), with resolution 3x3x3mm will have - coordinates x,y,z = (0, 3, 6). - - Using get_next_pos, seeds are placed randomly within the voxel. In the same - example as above, seed sampled in voxel i,j,k = (0,1,2) will be somewhere - in the range x = [0, 3], y = [3, 6], z = [6, 9]. + Generated seeds are in the specified space (default Space.VOX) and + origin (default Origin.CENTER). For tracking in physical space, + Space.RASMM should be used. """ - def __init__(self, data, voxres, + + def __init__(self, data, voxres, affine=None, space=Space('vox'), origin=Origin('center'), n_repeats=1): """ Parameters @@ -28,22 +25,26 @@ def __init__(self, data, voxres, to find all voxels with values > 0, but will not be kept in memory. voxres: np.ndarray(3,) The pixel resolution, ex, using img.header.get_zooms()[:3]. + affine: np.ndarray(4,4) + The affine matrix mapping voxel coordinates to RASMM. n_repeats: int Number of times a same seed position is returned. If used, we supposed that calls to either get_next_pos or get_next_n_pos are used sequentially. Not verified. """ self.voxres = voxres + self.affine = affine self.n_repeats = n_repeats self.origin = origin self.space = space - if space == Space.RASMM: - raise NotImplementedError("We do not support rasmm space.") - elif space not in [Space.VOX, Space.VOXMM]: + if space not in [Space.VOX, Space.VOXMM, Space.RASMM]: raise ValueError("Space should be a choice of Dipy Space.") if origin not in [Origin.NIFTI, Origin.TRACKVIS]: raise ValueError("Origin should be a choice of Dipy Origin.") + if space == Space.RASMM and affine is None: + raise ValueError("Affine matrix is required for RASMM space.") + # self.seed are all the voxels where a seed could be placed # (voxel space, origin=corner, int numbers). self.seeds_vox_corner = np.array(np.where(np.squeeze(data) > 0), @@ -112,8 +113,14 @@ def get_next_pos(self, random_generator, shuffled_indices, which_seed): return x, y, z elif self.space == Space.VOXMM: return x * self.voxres[0], y * self.voxres[1], z * self.voxres[2] + elif self.space == Space.RASMM: + # If origin is center, we need to add 0.5 to get back to corner + # before applying the affine. + if self.origin == Origin('center'): + x, y, z = x + 0.5, y + 0.5, z + 0.5 + return np.dot(self.affine, [x, y, z, 1])[:3] else: - raise NotImplementedError("We do not support rasmm space.") + raise ValueError("Space should be a choice of Dipy Space.") def get_next_n_pos(self, random_generator, shuffled_indices, which_seed_start, n): @@ -209,8 +216,14 @@ def get_next_n_pos(self, random_generator, shuffled_indices, seed = [x * self.voxres[0], y * self.voxres[1], z * self.voxres[2]] + elif self.space == Space.RASMM: + # If origin is center, we need to add 0.5 to get back to corner + # before applying the affine. + if self.origin == Origin('center'): + x, y, z = x + 0.5, y + 0.5, z + 0.5 + seed = np.dot(self.affine, [x, y, z, 1])[:3] else: - raise NotImplementedError("We do not support rasmm space.") + raise ValueError("Space should be a choice of Dipy Space.") seeds.append(seed) return seeds @@ -274,8 +287,9 @@ class FibertubeSeedGenerator(SeedGenerator): fibertube tracking. Generates a given number of seed within the first segment of a given number of fibertubes. """ + def __init__(self, centerlines, diameters, nb_seeds_per_fibertube, - local_seeding: Literal['center', 'random']): + local_seeding: Literal['center', 'random']): """ Parameters ---------- @@ -387,6 +401,7 @@ class CustomSeedsDispenser(SeedGenerator): Adaptation of the scilpy.tracking.seed.SeedGenerator interface for using already generated, custom seeds. """ + def __init__(self, custom_seeds, space=Space('vox'), origin=Origin('center')): """ diff --git a/src/scilpy/tracking/tests/test_tracking_utils.py b/src/scilpy/tracking/tests/test_tracking_utils.py new file mode 100644 index 000000000..81cd1d7f8 --- /dev/null +++ b/src/scilpy/tracking/tests/test_tracking_utils.py @@ -0,0 +1,70 @@ +import numpy as np +import nibabel as nib +import pytest + +from dipy.io.streamline import load_tractogram + +from scilpy.tracking.utils import save_tractogram as scil_save_tractogram + + +def create_fake_header(affine, shape=(10, 10, 10)): + data = np.zeros(shape) + img = nib.Nifti1Image(data, affine) + return img + + +@pytest.mark.parametrize("affine_type", ["iso_1mm", "iso_2mm", "aniso", "complex"]) +@pytest.mark.parametrize("ext", [".trk", ".tck"]) +def test_scil_save_tractogram_alignment(tmp_path, affine_type, ext): + if affine_type == "iso_1mm": + affine = np.diag([1, 1, 1, 1]) + elif affine_type == "iso_2mm": + affine = np.diag([2, 2, 2, 1]) + elif affine_type == "aniso": + affine = np.diag([1, 1, 2, 1]) + elif affine_type == "complex": + # Rotation 30 deg around Z, scaling, translation + theta = np.radians(30) + c, s = np.cos(theta), np.sin(theta) + R = np.array([ + [c, -s, 0], + [s, c, 0], + [0, 0, 1] + ]) + S = np.diag([1.1, 0.9, 1.2]) + T = np.array([10, -20, 30]) + affine = np.eye(4) + affine[:3, :3] = R @ S + affine[:3, 3] = T + + img = create_fake_header(affine) + img_path = str(tmp_path / "ref.nii.gz") + nib.save(img, img_path) + + # Create streamlines in VOXEL space, origin CENTER + vox_streamlines = [np.array([ + [1, 1, 1], + [2, 2, 2], + [5, 5, 5] + ], dtype=float)] + + # Generator for scil_save_tractogram + # it yields (streamline, seed) + stream_gen_list = [(s.copy(), s[0].copy()) for s in vox_streamlines] + + output_path = str(tmp_path / f"scil_tracto{ext}") + tracts_format = nib.streamlines.detect_format(output_path) + + # scil_save_tractogram(streamlines_generator, tracts_format, ref_img, total_nb_seeds, + # out_tractogram, min_length, max_length, compress, save_seeds, verbose) + scil_save_tractogram(stream_gen_list, tracts_format, img, len(vox_streamlines), + output_path, 0, 1000, None, False, False) + + # Reload and check + sft_loaded = load_tractogram(output_path, img_path) + sft_loaded.to_vox() + loaded_vox = sft_loaded.streamlines + + assert len(loaded_vox) == len(vox_streamlines) + for orig, loaded in zip(vox_streamlines, loaded_vox): + assert np.allclose(orig, loaded, atol=1e-3) diff --git a/src/scilpy/tracking/tracker.py b/src/scilpy/tracking/tracker.py index c8db1f1c6..19304a429 100644 --- a/src/scilpy/tracking/tracker.py +++ b/src/scilpy/tracking/tracker.py @@ -112,10 +112,6 @@ def __init__(self, propagator: AbstractPropagator, mask: DataVolume, self.origin = self.propagator.origin self.space = self.propagator.space - if self.space == Space.RASMM: - raise NotImplementedError( - "This version of the Tracker is not ready to work in RASMM " - "space.") if (seed_generator.origin != propagator.origin or seed_generator.space != propagator.space): raise ValueError("Seed generator and propagator must work with " @@ -159,7 +155,8 @@ def save_rap_entry_exit_mask(self, output_path, reference_img): mask_data = np.zeros(reference_img.shape[:3], dtype=np.uint8) # Convert coordinates to voxel space and set mask values - # Each element is a tuple (coord, coord_type) where coord_type is 1 (entry) or 2 (exit) + # Each element is a tuple (coord, coord_type) where coord_type is 1 + # (entry) or 2 (exit) for coord, coord_type in self.rap_entry_exit_coords: # Coordinates are already in voxel space (VOX, center) # Round to nearest integer voxel @@ -168,7 +165,7 @@ def save_rap_entry_exit_mask(self, output_path, reference_img): # Check bounds if (0 <= vox_coord[0] < mask_data.shape[0] and 0 <= vox_coord[1] < mask_data.shape[1] and - 0 <= vox_coord[2] < mask_data.shape[2]): + 0 <= vox_coord[2] < mask_data.shape[2]): # Use max to handle overlapping entry/exit points # If both entry and exit occur at same voxel, exit (2) will prevail mask_data[vox_coord[0], vox_coord[1], vox_coord[2]] = max( @@ -182,7 +179,8 @@ def save_rap_entry_exit_mask(self, output_path, reference_img): entry_count = sum(1 for _, t in self.rap_entry_exit_coords if t == 1) exit_count = sum(1 for _, t in self.rap_entry_exit_coords if t == 2) logging.info(f"Saved RAP entry/exit mask to {output_path}") - logging.info(f"Entry coordinates: {entry_count}, Exit coordinates: {exit_count}") + logging.info( + f"Entry coordinates: {entry_count}, Exit coordinates: {exit_count}") logging.info(f"Unique voxels with entry (1): {np.sum(mask_data == 1)}, " f"exit (2): {np.sum(mask_data == 2)}") @@ -309,7 +307,6 @@ def _get_streamlines_sub(self, params): List of list of 3D positions (streamlines). """ chunk_id, lock = params - global multiprocess_init_args self._reload_data_for_new_process(multiprocess_init_args) try: @@ -392,7 +389,7 @@ def _get_streamlines(self, chunk_id, lock=None): # on current process ID. eps = s + chunk_id / (self.nbr_processes + 1) line_generator = np.random.default_rng( - np.abs(hash((seed + (eps, eps, eps), self.rng_seed)))) + np.abs(hash((tuple(seed + (eps, eps, eps)), self.rng_seed)))) # Forward and backward tracking line = self._get_line_both_directions(seed, line_generator) @@ -509,7 +506,8 @@ def _propagate_line(self, line, previous_dir): # Detect entering RAP region if is_currently_in_rap and not in_rap_region: - self.rap_entry_exit_coords.append((line[-1].copy(), 1)) # 1 for entry + self.rap_entry_exit_coords.append( + (line[-1].copy(), 1)) # 1 for entry in_rap_region = True logging.debug(f"TRACKER ENTERING pos={np.round(line[-1], 2)}") @@ -527,7 +525,8 @@ def _propagate_line(self, line, previous_dir): new_pos = line[-1] # Verify that our RAP propagated point stays within the tracking mask - propagation_can_continue = self._verify_stopping_criteria(new_pos) + propagation_can_continue = self._verify_stopping_criteria( + new_pos) if not propagation_can_continue: logging.debug("TRACKER out of mask, stop.") line.pop() @@ -547,13 +546,16 @@ def _propagate_line(self, line, previous_dir): if invalid_direction_count > self.max_invalid_dirs: break - propagation_can_continue = self._verify_stopping_criteria(new_pos) + propagation_can_continue = self._verify_stopping_criteria( + new_pos) if propagation_can_continue or self.append_last_point: line.append(new_pos) previous_dir = new_dir - logging.debug(f"TRACKER end of propagation: {len(line)} total points, last pos={np.round(line[-1], 2)}") + logging.debug( + "TRACKER end of propagation: {} total points, last pos={}" + .format(len(line), np.round(line[-1], 2))) return line def _verify_stopping_criteria(self, last_pos): @@ -614,6 +616,7 @@ class GPUTracker(): GPU tracking mode. `prob` samples directions from the SF and `det` follows the maximum SF direction. """ + def __init__(self, sh, mask, seeds, step_size, max_nbr_pts, theta=20.0, sf_threshold=0.1, sh_interp='trilinear', sh_basis='descoteaux07', is_legacy=True, batch_size=100000, @@ -697,7 +700,8 @@ def _track(self): 'true' if self.forward_only else 'false') cl_kernel.set_define('PROBABILISTIC', 'true' if self.probabilistic else 'false') - cl_kernel.set_define('RNG_SEED', '{}u'.format(np.uint32(self.rng_seed))) + cl_kernel.set_define( + 'RNG_SEED', '{}u'.format(np.uint32(self.rng_seed))) cl_kernel.set_define('SF_THRESHOLD', '{:.8f}f'.format(self.sf_threshold)) cl_kernel.set_define('SH_INTERP_NN', @@ -749,4 +753,4 @@ def _track(self): # output is yielded so that we can use LazyTractogram. # seed and strl with origin center (same as DIPY) - yield strl - 0.5, seed - 0.5 \ No newline at end of file + yield strl - 0.5, seed - 0.5 diff --git a/src/scilpy/tracking/utils.py b/src/scilpy/tracking/utils.py index 6a8f26a96..1e069f348 100644 --- a/src/scilpy/tracking/utils.py +++ b/src/scilpy/tracking/utils.py @@ -13,9 +13,11 @@ from dipy.direction import (DeterministicMaximumDirectionGetter, ProbabilisticDirectionGetter, PTTDirectionGetter) from dipy.direction.peaks import PeaksAndMetrics +from dipy.io.stateful_tractogram import Origin, Space from dipy.io.utils import create_tractogram_header, get_reference_info from dipy.reconst.shm import sh_to_sf_matrix from dipy.tracking.streamlinespeed import compress_streamlines, length + from scilpy.io.utils import (add_compression_arg, add_overwrite_arg, add_sh_basis_args) from scilpy.reconst.utils import find_order_from_nb_coeff, get_maximas @@ -99,8 +101,20 @@ def add_tracking_options(p): '["eudx"=60, "det"=45, "prob"=20, "ptt"=20]') track_g.add_argument('--sfthres', dest='sf_threshold', metavar='sf_th', type=float, default=0.1, - help='Spherical function relative threshold. ' - '[%(default)s]') + help='Spherical function relative threshold ' + 'within each voxel. [%(default)s]') + global_sf_g = track_g.add_mutually_exclusive_group() + global_sf_g.add_argument('--global_sf_rel_thr', metavar='FACTOR', + type=float, nargs='?', const=0.1, default=None, + help='Global SF relative threshold factor.' \ + 'If set, masks voxels where\nmax SF amplitude < ' + 'FACTOR * max global SF amplitude. \n' + 'If used without a value, default is [%(const)s].') + global_sf_g.add_argument('--global_sf_abs_thr', metavar='ABS_THR', + type=float, + help='Global SF absolute threshold.' + 'If set, masks voxels where \n' + 'max SF amplitude < ABS_THR.') add_sh_basis_args(track_g) return track_g @@ -203,7 +217,8 @@ def tqdm_if_verbose(generator: Iterable, verbose: bool, *args, **kwargs): def save_tractogram( streamlines_generator, tracts_format, ref_img, total_nb_seeds, - out_tractogram, min_length, max_length, compress, save_seeds, verbose + out_tractogram, min_length, max_length, compress, save_seeds, verbose, + space=Space.VOX, origin=Origin.NIFTI ): """ Save the streamlines on-the-fly using a generator. Tracts are filtered according to their length and compressed if requested. Seeds @@ -216,7 +231,7 @@ def save_tractogram( Streamlines generator. tracts_format : TrkFile or TckFile Tractogram format. - ref_img : nibabel.Nifti1Image + ref_img : nibabel.Nifti1Image or scilpy.io.stateful_image.StatefulImage Image used as reference. total_nb_seeds : int Total number of seeds. @@ -233,63 +248,95 @@ def save_tractogram( data_per_streamline property. verbose : bool If True, display progression bar. + space : Space + Space in which the streamlines are generated. + origin : Origin + Origin in which the streamlines are generated. """ + voxel_size = np.array(ref_img.header.get_zooms()[:3]) + # If ref_img is a StatefulImage, we want to save relative to its + # original on-disk orientation, not the internal (likely RAS) one. + from scilpy.io.stateful_image import StatefulImage + is_stateful = isinstance(ref_img, StatefulImage) + + if is_stateful: + affine_mod = ref_img.affine.copy() + affine_ori = ref_img._original_affine + original_voxel_size = np.array(ref_img._original_voxel_sizes[:3]) + else: + affine_mod = ref_img.affine.copy() + affine_ori = ref_img.affine.copy() + original_voxel_size = voxel_size - voxel_size = ref_img.header.get_zooms()[0] - - scaled_min_length = min_length / voxel_size - scaled_max_length = max_length / voxel_size - - # Tracking is expected to be returned in voxel space, origin `center`. def tracks_generator_wrapper(): - for strl, seed in tqdm_if_verbose(streamlines_generator, + # If streamlines_generator is a callable, call it to get a new generator. + # This allows re-iterating if LazyTractogram needs it. + if callable(streamlines_generator): + iterable = streamlines_generator() + else: + iterable = streamlines_generator + + miniters = int(total_nb_seeds / 100) if total_nb_seeds >= 100 else 1 + for strl, seed in tqdm_if_verbose(iterable, verbose=verbose, total=total_nb_seeds, - miniters=int(total_nb_seeds / 100), + miniters=miniters, leave=False): - if (scaled_min_length <= length(strl) <= scaled_max_length): - # Seeds are saved with origin `center` by our own convention. - # Other scripts (e.g. scil_tractogram_seed_density_map) expect - # so. - dps = {} + # 1. Get to RASMM (physical world space) for filtering and compression + if space == Space.VOX: + strl_rasmm = nib.affines.apply_affine(affine_mod, strl) + elif space == Space.VOXMM: + strl_rasmm = nib.affines.apply_affine( + affine_mod, strl / voxel_size) + elif space == Space.RASMM: + strl_rasmm = strl + else: + raise ValueError("Unknown space") + + strl_len = length(strl_rasmm) + if (min_length <= strl_len <= max_length): + # Prepare DPS for this streamline + strl_dps = {} if save_seeds: - dps['seeds'] = seed + strl_dps['seeds'] = seed if compress: - # compression threshold is given in mm, but we - # are in voxel space - strl = compress_streamlines( - strl, compress / voxel_size) + strl_rasmm = compress_streamlines(strl_rasmm, compress) - # TODO: Use nibabel utilities for dealing with spaces if tracts_format is TrkFile: - # Streamlines are dumped in mm space with - # origin `corner`. This is what is expected by - # LazyTractogram for .trk files (although this is not - # specified anywhere in the doc) - strl += 0.5 - strl *= voxel_size # in mm. + # TRK expects VOXMM relative to original orientation + strl_vox = nib.affines.apply_affine( + np.linalg.inv(affine_ori), strl_rasmm) + # Add half-voxel shift to go from scilpy center-origin + # to nibabel corner-origin in voxmm. + strl_to_save = (strl_vox + 0.5) * original_voxel_size else: - # Streamlines are dumped in true world space with - # origin center as expected by .tck files. - strl = np.dot(strl, ref_img.affine[:3, :3]) + \ - ref_img.affine[:3, 3] + # TCK expects RASMM + strl_to_save = strl_rasmm - yield TractogramItem(strl, dps, {}) + yield TractogramItem(strl_to_save, strl_dps, {}) tractogram = LazyTractogram.from_data_func(tracks_generator_wrapper) - tractogram.affine_to_rasmm = ref_img.affine + tractogram.affine_to_rasmm = np.eye(4) filetype = nib.streamlines.detect_format(out_tractogram) - reference = get_reference_info(ref_img) - header = create_tractogram_header(filetype, *reference) + + if is_stateful: + reference = (ref_img._original_affine, + ref_img._original_dimensions[:3], + ref_img._original_voxel_sizes[:3], + "".join(ref_img._original_axcodes[:3])) + header = create_tractogram_header(filetype, *reference) + else: + reference = get_reference_info(ref_img) + header = create_tractogram_header(filetype, *reference) # Use generator to save the streamlines on-the-fly nib.streamlines.save(tractogram, out_tractogram, header=header) -def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, +def get_direction_getter(img_data, algo, sphere, sub_sphere, theta, sh_basis, voxel_size, sf_threshold, sh_to_pmf, probe_length, probe_radius, probe_quality, probe_count, support_exponent, is_legacy=True): @@ -297,8 +344,8 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, Parameters ---------- - in_img: str - Path to the input odf file. + img_data: ndarray + The input odf data. algo: str Algorithm to use for tracking. Can be 'det', 'prob', 'ptt' or 'eudx'. sphere: str @@ -343,21 +390,14 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, dg: dipy.direction.DirectionGetter The direction getter object. """ - img_data = nib.load(in_img).get_fdata(dtype=np.float32) - sphere = HemiSphere.from_sphere( get_sphere(name=sphere)).subdivide(n=sub_sphere) # Theta depends on user choice and algorithm theta = get_theta(theta, algo) - # Heuristic to find out if the input are peaks or fodf - # fodf are always around 0.15 and peaks around 0.75 - # Peaks have more zero values than fodf. The first value of fodf is - # usually the highest. - non_zeros_count = np.count_nonzero(np.sum(img_data, axis=-1)) - non_first_val_count = np.count_nonzero(np.argmax(img_data, axis=-1)) - is_peaks = non_first_val_count / non_zeros_count > 0.5 + from scilpy.reconst.utils import is_data_peaks + is_peaks = is_data_peaks(img_data) if algo in ['det', 'prob', 'ptt']: if is_peaks: @@ -367,10 +407,10 @@ def get_direction_getter(in_img, algo, sphere, sub_sphere, theta, sh_basis, kwargs = {} if algo == 'ptt': dg_class = PTTDirectionGetter - # Considering the step size usually used, the probe length - # can be set as the voxel size. - kwargs = {'probe_length': probe_length, - 'probe_radius': probe_radius, + # Probe length and radius are in mm, convert to voxel units + # since tracking is performed in voxel space (identity affine). + kwargs = {'probe_length': probe_length / voxel_size, + 'probe_radius': probe_radius / voxel_size, 'probe_quality': probe_quality, 'probe_count': probe_count, 'data_support_exponent': support_exponent} diff --git a/src/scilpy/tractanalysis/voxel_boundary_intersection.pyx b/src/scilpy/tractanalysis/voxel_boundary_intersection.pyx index 70cf12510..5050b47b0 100644 --- a/src/scilpy/tractanalysis/voxel_boundary_intersection.pyx +++ b/src/scilpy/tractanalysis/voxel_boundary_intersection.pyx @@ -51,23 +51,23 @@ def subdivide_streamlines_at_voxel_faces(streamlines): cdef: cnp.npy_intp nb_streamlines = len(streamlines._lengths) cnp.npy_intp at_point = 0 + cnp.float32_t[:, :] data_view_in = streamlines_data # Multiplying by 6 is simply a heuristic to avoiding resizing too many # times. In my bundles tests, I had either 0 or 1 resize. cnp.npy_intp max_points = (streamlines_data.size // 6) * 12 new_array_sequence = nib.streamlines.array_sequence.ArraySequence() - new_array_sequence._lengths.resize(nb_streamlines) - new_array_sequence._offsets.resize(nb_streamlines) + new_array_sequence._lengths.resize(nb_streamlines, refcheck=False) + new_array_sequence._offsets.resize(nb_streamlines, refcheck=False) new_array_sequence._data = np.empty(max_points * 3, np.float32) cdef: cnp.npy_intp[:] lengths_view_in = streamlines._lengths cnp.npy_intp[:] offsets_view_in = streamlines._offsets - float[:, :] data_view_in = streamlines_data cnp.npy_intp[:] lengths_view_out = new_array_sequence._lengths cnp.npy_intp[:] offsets_view_out = new_array_sequence._offsets - cnp.float32_t[:] data_view_out = new_array_sequence._data + cnp.float32_t[:] data_view_out = np.asarray(new_array_sequence._data, dtype=np.float32) cdef Pointers pointers pointers.lengths_in = &lengths_view_in[0] diff --git a/src/scilpy/tractograms/uncompress.pyx b/src/scilpy/tractograms/uncompress.pyx index 6cf137b10..d8c94ae69 100644 --- a/src/scilpy/tractograms/uncompress.pyx +++ b/src/scilpy/tractograms/uncompress.pyx @@ -69,13 +69,13 @@ def streamlines_to_voxel_coordinates(streamlines, return_mapping=False): cnp.npy_intp max_points = (streamlines._data.size // 3) new_array_sequence = nib.streamlines.array_sequence.ArraySequence() - new_array_sequence._lengths.resize(nb_streamlines) - new_array_sequence._offsets.resize(nb_streamlines) + new_array_sequence._lengths.resize(nb_streamlines, refcheck=False) + new_array_sequence._offsets.resize(nb_streamlines, refcheck=False) new_array_sequence._data = np.empty(max_points * 3, np.uint16) points_to_index = nib.streamlines.array_sequence.ArraySequence() - points_to_index._lengths.resize(nb_streamlines) - points_to_index._offsets.resize(nb_streamlines) + points_to_index._lengths.resize(nb_streamlines, refcheck=False) + points_to_index._offsets.resize(nb_streamlines, refcheck=False) points_to_index._data = np.zeros(int(streamlines._data.size / 3), np.uint16) cdef: diff --git a/src/scilpy/utils/orientation.py b/src/scilpy/utils/orientation.py index 1c2b70365..65e3ebea0 100644 --- a/src/scilpy/utils/orientation.py +++ b/src/scilpy/utils/orientation.py @@ -6,16 +6,19 @@ def validate_voxel_order(axcodes, dimensions=3): """ Validate a set of axis codes. + Parameters ---------- axcodes : str or tuple or list The axis codes to validate (e.g., "LPS", ("R", "A", "S")). dimensions : int The number of dimensions of the image. + Returns ------- tuple A tuple of validated axis codes. + Raises ------ ValueError @@ -26,12 +29,13 @@ def validate_voxel_order(axcodes, dimensions=3): axcodes = tuple(axcodes) if len(axcodes) != dimensions: - raise ValueError(f"Target axis codes must be of length {dimensions}.") + raise ValueError(f"Target axis codes must be of length {dimensions}. " + f"Got {len(axcodes)}.") # Check unique are only valid axis codes valid_codes = {"L", "R", "A", "P", "S", "I"} - if dimensions == 4: - valid_codes.add("T") + if dimensions >= 4: + valid_codes.update(["T", "U", "V", "W", "X", "Y", "Z"]) for code in axcodes: if code not in valid_codes: raise ValueError(f"Invalid axis code '{code}' in target.") @@ -53,16 +57,18 @@ def parse_voxel_order(order_str, dimensions=3): """ Parse the voxel order string into a tuple of axis codes. """ - order_str_cleaned = order_str.replace(',', '').replace(' ', '') + order_str_cleaned = order_str.replace(',', '').replace(' ', '').upper() - if dimensions == 4 and order_str_cleaned.isalpha(): - raise ValueError("Alphabetical voxel order is not supported for 4D " - "images. Please use numeric format.") + if dimensions == 4 and order_str_cleaned.isalpha() and \ + len(order_str_cleaned) == 3: + order_str_cleaned += 'T' if order_str_cleaned.isalpha(): - if len(order_str_cleaned) != 3: - raise ValueError("Voxel order string must have 3 characters.") - return validate_voxel_order(tuple(order_str_cleaned.upper())) + if len(order_str_cleaned) != dimensions: + raise ValueError(f"Voxel order string must have {dimensions} " + f"characters.") + return validate_voxel_order(tuple(order_str_cleaned), + dimensions=dimensions) if order_str_cleaned.replace('-', '').isdigit(): numeric_parts = re.findall(r'-?\d', order_str_cleaned) @@ -73,9 +79,12 @@ def parse_voxel_order(order_str, dimensions=3): if dimensions == 4: ras_map = {1: 'R', 2: 'A', 3: 'S', 4: 'T'} - flip_map = {'R': 'L', 'A': 'P', 'S': 'I', 'T': 'T'} + flip_map = {'R': 'L', 'A': 'P', 'S': 'I'} if len(numeric_parts) == 4: - if abs(int(numeric_parts[3])) != 4: + if int(numeric_parts[3]) == -4: + raise ValueError("Flipping the 4th dimension is not " + "supported.") + if int(numeric_parts[3]) != 4: raise ValueError("The 4th dimension must be 4 or -4.") else: ras_map = {1: 'R', 2: 'A', 3: 'S'} @@ -86,19 +95,24 @@ def parse_voxel_order(order_str, dimensions=3): num = int(part) axis = ras_map[abs(num)] if num < 0: + if axis not in flip_map: + raise ValueError(f"Axis {axis} cannot be flipped.") axis = flip_map[axis] order.append(axis) + if dimensions == 4 and len(order) == 3: + order.append('T') + # Check for duplicate axes - if len(set(order)) != len(numeric_parts): + if len(set(order)) != len(order): # Handle swapped axes from numeric input (e.g., '231') axis_vals = [ras_map[abs(int(p))] for p in numeric_parts] if len(set(axis_vals)) == len(numeric_parts): - return validate_voxel_order(tuple(order), dimensions=len(numeric_parts)) + return validate_voxel_order(tuple(order), dimensions=dimensions) else: raise ValueError("Invalid numeric voxel order. " "Axes cannot be repeated.") - return validate_voxel_order(tuple(order), dimensions=len(numeric_parts)) - + return validate_voxel_order(tuple(order), dimensions=dimensions) + raise ValueError(f"Invalid voxel order format: {order_str}") diff --git a/src/scilpy/utils/scilpy_bot.py b/src/scilpy/utils/scilpy_bot.py index b3792776e..cce8b04ad 100644 --- a/src/scilpy/utils/scilpy_bot.py +++ b/src/scilpy/utils/scilpy_bot.py @@ -57,7 +57,7 @@ def _make_title(text): Returns a formatted title string with centered text and spacing """ return f'{Fore.LIGHTBLUE_EX}{Style.BRIGHT}{text.center(SPACING_LEN, "=")}' \ - f'{Style.RESET_ALL}' + f'{Style.RESET_ALL}' def _get_docstring_from_script_path(script): @@ -273,7 +273,7 @@ def _highlight_keywords(text, all_expressions): # Function to apply highlighting to the matched word def apply_highlight(match): return f'{Fore.LIGHTYELLOW_EX}{Style.BRIGHT}{match.group(0)}' \ - f'{Style.RESET_ALL}' + f'{Style.RESET_ALL}' # Replace the matched word with its highlighted version text = pattern.sub(apply_highlight, text) diff --git a/src/scilpy/utils/tests/test_orientation.py b/src/scilpy/utils/tests/test_orientation.py index 815794550..a6af6af39 100644 --- a/src/scilpy/utils/tests/test_orientation.py +++ b/src/scilpy/utils/tests/test_orientation.py @@ -84,19 +84,17 @@ def test_parse_voxel_order_invalid_format(): match="Voxel order string must have 3 or 4 numbers."): parse_voxel_order("1,2,3,4,5", dimensions=4) + def test_parse_voxel_order_4d_valid_numeric(): """Test parsing of valid 4D numeric voxel order strings.""" assert parse_voxel_order("1,2,3,4", dimensions=4) == ("R", "A", "S", "T") assert parse_voxel_order("-1,2,-3,4", dimensions=4) == ("L", "A", "I", "T") - assert parse_voxel_order("2,3,1", dimensions=4) == ("A", "S", "R") + assert parse_voxel_order("2,3,1", dimensions=4) == ("A", "S", "R", "T") -def test_parse_voxel_order_4d_invalid_alpha(): - """Test that 4D alphabetical voxel order strings raise an error.""" - with pytest.raises(ValueError, - match="Alphabetical voxel order is not supported for 4D " - "images. Please use numeric format."): - parse_voxel_order("RAS", dimensions=4) +def test_parse_voxel_order_4d_alpha(): + """Test that 4D alphabetical voxel order strings are now supported.""" + assert parse_voxel_order("RAS", dimensions=4) == ("R", "A", "S", "T") def test_parse_voxel_order_4d_invalid_numeric(): @@ -105,6 +103,10 @@ def test_parse_voxel_order_4d_invalid_numeric(): match="The 4th dimension must be 4 or -4."): parse_voxel_order("1,2,3,5", dimensions=4) + with pytest.raises(ValueError, + match="Flipping the 4th dimension is not supported."): + parse_voxel_order("1,2,3,-4", dimensions=4) + with pytest.raises(ValueError, match="Voxel order string must have 3 or 4 numbers."): parse_voxel_order("1,2", dimensions=4) diff --git a/src/scilpy/viz/backends/fury.py b/src/scilpy/viz/backends/fury.py index 3cd8888a5..828f661f4 100644 --- a/src/scilpy/viz/backends/fury.py +++ b/src/scilpy/viz/backends/fury.py @@ -20,7 +20,8 @@ class CamParams(Enum): PARA_SCALE = 'parallel_scale' -def initialize_camera(orientation, slice_index, volume_shape, aspect_ratio): +def initialize_camera(orientation, slice_index, volume_shape, aspect_ratio, + affine=None): """ Initialize a camera for a given orientation. The camera's focus (VIEW_CENTER) is set to the slice_index along the chosen orientation, at @@ -58,6 +59,8 @@ def initialize_camera(orientation, slice_index, volume_shape, aspect_ratio): Shape of the sliced volume. aspect_ratio : float Ratio between viewport's width and height. + affine : np.ndarray, optional + Voxel-to-world affine matrix. Returns ------- @@ -88,11 +91,30 @@ def initialize_camera(orientation, slice_index, volume_shape, aspect_ratio): camera[CamParams.VIEW_UP] = np.zeros((3,)) camera[CamParams.VIEW_UP][vert_idx] = -1.0 + if affine is not None: + # Transform VIEW_CENTER + center = np.append(camera[CamParams.VIEW_CENTER], 1.0) + camera[CamParams.VIEW_CENTER] = np.dot(affine, center)[:3] + + # Transform VIEW_POS + pos = np.append(camera[CamParams.VIEW_POS], 1.0) + camera[CamParams.VIEW_POS] = np.dot(affine, pos)[:3] + + # Transform VIEW_UP + up = np.append(camera[CamParams.VIEW_UP], 0.0) + camera[CamParams.VIEW_UP] = np.dot(affine, up)[:3] + camera[CamParams.VIEW_UP] /= np.linalg.norm(camera[CamParams.VIEW_UP]) + # Based on : https://stackoverflow.com/questions/6565703/ # math-algorithm-fit-image-to-screen-retain-aspect-ratio - remain_axis = np.delete(volume_shape, [axis_index, vert_idx], 0) - ref_height = volume_shape[vert_idx] - if remain_axis[0] / volume_shape[vert_idx] > aspect_ratio: + voxel_size = np.ones(3) + if affine is not None: + voxel_size = np.linalg.norm(affine[:3, :3], axis=0) + + world_shape = np.array(volume_shape) * voxel_size + remain_axis = np.delete(world_shape, [axis_index, vert_idx], 0) + ref_height = world_shape[vert_idx] + if remain_axis[0] / world_shape[vert_idx] > aspect_ratio: ref_height = remain_axis[0] / aspect_ratio # From vtkCamera documentation, see SetViewAngle and SetParallelScale @@ -129,7 +151,8 @@ def set_display_extent(slicer_actor, orientation, volume_shape, slice_index): slicer_actor.display_extent(*extents) -def set_viewport(scene, orientation, slice_index, volume_shape, aspect_ratio): +def set_viewport(scene, orientation, slice_index, volume_shape, aspect_ratio, + affine=None): """ Place the camera in the scene to capture all its content at a given slice_index. @@ -146,11 +169,13 @@ def set_viewport(scene, orientation, slice_index, volume_shape, aspect_ratio): Shape of the sliced volume. aspect_ratio : float Ratio between viewport's width and height. + affine : np.ndarray, optional + Voxel-to-world affine matrix. """ scene.projection(proj_type='parallel') camera = initialize_camera( - orientation, slice_index, volume_shape, aspect_ratio) + orientation, slice_index, volume_shape, aspect_ratio, affine=affine) scene.set_camera(position=camera[CamParams.VIEW_POS], focal_point=camera[CamParams.VIEW_CENTER], view_up=camera[CamParams.VIEW_UP]) @@ -162,7 +187,7 @@ def set_viewport(scene, orientation, slice_index, volume_shape, aspect_ratio): def create_scene(actors, orientation, slice_index, volume_shape, aspect_ratio, - *, bg_color=(0, 0, 0)): + *, bg_color=(0, 0, 0), affine=None): """ Create a 3D scene containing actors fitting inside a grid. The camera is placed based on the orientation supplied by the user. The projection mode @@ -182,6 +207,8 @@ def create_scene(actors, orientation, slice_index, volume_shape, aspect_ratio, Ratio between viewport's width and height. bg_color: tuple Background color expressed as RGB triplet in the range [0, 1]. + affine : np.ndarray, optional + Voxel-to-world affine matrix. Returns ------- @@ -196,14 +223,15 @@ def create_scene(actors, orientation, slice_index, volume_shape, aspect_ratio, for _actor in actors: scene.add(_actor) - set_viewport(scene, orientation, slice_index, volume_shape, aspect_ratio) + set_viewport(scene, orientation, slice_index, volume_shape, aspect_ratio, + affine=affine) return scene def create_interactive_window(scene, window_size, interactor, *, title="Viewer", open_window=True - ): # pragma: no cover + ): # pragma: no cover # (Function ignored from coverage statistics) """ Create a 3D window with the content of scene, equiped with an interactor. @@ -296,7 +324,7 @@ def snapshot_scenes(scenes, window_size): def create_contours_actor(contours, opacity=1., linewidth=3., - color=[255, 0, 0]): + color=[255, 0, 0], affine=None): """ Create an actor from a vtkPolyData of contours @@ -310,6 +338,8 @@ def create_contours_actor(contours, opacity=1., linewidth=3., Thickness of the contour line. color : tuple, list of int Color of the contour in RGB [0, 255]. + affine : np.ndarray, optional + Voxel-to-world affine matrix. Returns ------- @@ -317,7 +347,24 @@ def create_contours_actor(contours, opacity=1., linewidth=3., Fury object containing the contours' information. """ - contours_actor = get_actor_from_polydata(contours) + if affine is not None: + import vtk + vtk_matrix = vtk.vtkMatrix4x4() + for i in range(4): + for j in range(4): + vtk_matrix.SetElement(i, j, affine[i, j]) + + transform = vtk.vtkTransform() + transform.SetMatrix(vtk_matrix) + + transform_filter = vtk.vtkTransformPolyDataFilter() + transform_filter.SetTransform(transform) + transform_filter.SetInputData(contours) + transform_filter.Update() + contours_actor = get_actor_from_polydata(transform_filter.GetOutput()) + else: + contours_actor = get_actor_from_polydata(contours) + contours_actor.GetMapper().ScalarVisibilityOff() contours_actor.GetProperty().SetLineWidth(linewidth) contours_actor.GetProperty().SetColor(color) @@ -328,7 +375,8 @@ def create_contours_actor(contours, opacity=1., linewidth=3., def create_odf_actors(sf_fodf, sphere, scale, sf_variance=None, mask=None, radial_scale=False, norm=False, colormap=None, - variance_k=1.0, variance_color=None, B_mat=None): + variance_k=1.0, variance_color=None, B_mat=None, + affine=None): """ Create an ODF slicer actor displaying a fODF slice. The input volume is a 3-dimensional grid containing the SH coefficients of the fODF at each @@ -361,6 +409,8 @@ def create_odf_actors(sf_fodf, sphere, scale, sf_variance=None, mask=None, Optional SH to SF matrix for projecting `odfs` given in SH coefficients on the `sphere`. If None, then the input is assumed to be expressed in SF coefficients. + affine : np.ndarray, optional + Voxel-to-world affine matrix. Returns ------- @@ -387,7 +437,8 @@ def create_odf_actors(sf_fodf, sphere, scale, sf_variance=None, mask=None, var_actor = actor.odf_slicer(fodf_uncertainty, mask=mask, norm=False, radial_scale=radial_scale, sphere=sphere, scale=scale, - colormap=variance_color) + colormap=variance_color, + affine=affine) var_actor.GetProperty().SetDiffuse(0.0) var_actor.GetProperty().SetAmbient(1.0) @@ -396,14 +447,16 @@ def create_odf_actors(sf_fodf, sphere, scale, sf_variance=None, mask=None, odf_actor = actor.odf_slicer(sf_fodf, mask=mask, norm=False, radial_scale=radial_scale, sphere=sphere, scale=scale, - colormap=colormap, B_matrix=B_mat) + colormap=colormap, B_matrix=B_mat, + affine=affine) return odf_actor, var_actor def create_peaks_actor(peaks, mask, opacity=1.0, linewidth=1.0, color=None, symmetric=False, lut_values=None, lod=False, - lod_nb_points=10000, lod_points_size=3): + lod_nb_points=10000, lod_points_size=3, + affine=None): """ Create a Peaks actor from a N-dimensional array. Data can be from 2D (M 3D peaks) to 5D (XxYxZxM 3D peaks). Color is None by default so coloring @@ -433,6 +486,8 @@ def create_peaks_actor(peaks, mask, opacity=1.0, linewidth=1.0, color=None, Number of points to use for level of detail rendering. lod_points_size : int Size of the points for level of detail rendering. + affine : np.ndarray, optional + Voxel-to-world affine matrix. Returns ------- @@ -440,7 +495,10 @@ def create_peaks_actor(peaks, mask, opacity=1.0, linewidth=1.0, color=None, Fury object containing the peaks' information. """ - return actor.peak_slicer(peaks, mask=mask, affine=np.eye(4), + if affine is None: + affine = np.eye(4) + + return actor.peak_slicer(peaks, mask=mask, affine=affine, colors=color, opacity=opacity, linewidth=linewidth, symmetric=symmetric, peaks_values=lut_values, diff --git a/src/scilpy/viz/screenshot.py b/src/scilpy/viz/screenshot.py index 4c509d92a..a8456e6d0 100644 --- a/src/scilpy/viz/screenshot.py +++ b/src/scilpy/viz/screenshot.py @@ -103,7 +103,7 @@ def screenshot_peaks(img, orientation, slice_ids, size, mask_img=None): Parameters ---------- - img : nib.Nifti1Image + img : nib.Nifti1Image or StatefulImage Peaks volume image. orientation : str Slicing axis name. @@ -122,7 +122,13 @@ def screenshot_peaks(img, orientation, slice_ids, size, mask_img=None): if mask_img: mask = mask_img.get_fdata().astype(bool) - peaks_actor = create_peaks_slicer(img.get_fdata(), orientation, 0, + from scilpy.io.stateful_image import StatefulImage + if isinstance(img, StatefulImage): + data = img.to_voxel_direction() + else: + data = img.get_fdata() + + peaks_actor = create_peaks_slicer(data, orientation, 0, mask=mask) return snapshot_slices([peaks_actor], slice_ids, orientation, diff --git a/src/scilpy/viz/slice.py b/src/scilpy/viz/slice.py index c8977d2c0..5b44e8a67 100644 --- a/src/scilpy/viz/slice.py +++ b/src/scilpy/viz/slice.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from dipy.reconst.shm import sh_to_sf, sh_to_sf_matrix +from dipy.reconst.shm import sh_to_sf from fury import actor import numpy as np @@ -16,7 +16,7 @@ def create_texture_slicer(texture, orientation, slice_index, *, mask=None, value_range=None, opacity=1.0, offset=0.5, - lut=None, interpolation='nearest'): + lut=None, interpolation='nearest', affine=None): """ Create a texture displayed at a given offset (in the given orientation) from the origin of the grid. @@ -45,6 +45,8 @@ def create_texture_slicer(texture, orientation, slice_index, *, mask=None, interpolation : str Interpolation mode for the texture image. Choices are nearest or linear. Defaults to nearest. + affine : np.ndarray, optional + Voxel-to-world affine matrix. Returns ------- @@ -52,7 +54,11 @@ def create_texture_slicer(texture, orientation, slice_index, *, mask=None, Fury object containing the texture information. """ - affine = affine_from_offset(orientation, offset) + offset_affine = affine_from_offset(orientation, offset) + if affine is not None: + affine = np.dot(affine, offset_affine) + else: + affine = offset_affine if mask is not None: texture[np.where(mask == 0)] = 0 @@ -128,13 +134,13 @@ def create_contours_slicer(data, contour_values, orientation, slice_index, def create_peaks_slicer(data, orientation, slice_index, *, peak_values=None, mask=None, color=None, peaks_width=1.0, - opacity=1.0, symmetric=False): + opacity=1.0, symmetric=False, affine=None): """ Create a peaks slicer actor rendering a slice of the input peaks. Parameters ---------- - data : np.ndarray + data : np.ndarray or StatefulImage Peaks data. orientation : str Name of the axis to visualize. Choices are axial, coronal and sagittal. @@ -155,6 +161,8 @@ def create_peaks_slicer(data, orientation, slice_index, *, peak_values=None, If True, peaks are drawn for both peaks_dirs and -peaks_dirs. Else, peaks are only drawn for directions given by peaks_dirs. Defaults to False. + affine : np.ndarray, optional + Voxel-to-world affine matrix. Returns ------- @@ -162,6 +170,10 @@ def create_peaks_slicer(data, orientation, slice_index, *, peak_values=None, Fury object containing the peaks information. """ + from scilpy.io.stateful_image import StatefulImage + if isinstance(data, StatefulImage): + data = data.to_voxel_direction() + # Reshape peaks volume to XxYxZxNx3 data = data.reshape(data.shape[:3] + (-1, 3)) norm = np.linalg.norm(data, axis=-1) @@ -182,7 +194,8 @@ def create_peaks_slicer(data, orientation, slice_index, *, peak_values=None, peaks_slicer = create_peaks_actor(data, mask, opacity=opacity, linewidth=peaks_width, color=color, lut_values=peak_values, - symmetric=symmetric) + symmetric=symmetric, + affine=affine) set_display_extent(peaks_slicer, orientation, data.shape, slice_index) @@ -193,7 +206,8 @@ def create_odf_slicer(sh_fodf, orientation, slice_index, sphere, sh_order, sh_basis, full_basis, scale, sh_variance=None, mask=None, nb_subdivide=None, radial_scale=False, norm=False, colormap=None, variance_k=1, - variance_color=(255, 255, 255), is_legacy=True): + variance_color=(255, 255, 255), is_legacy=True, + affine=None): """ Create a ODF slicer actor displaying a fODF slice. The input volume is a 3-dimensional grid containing the SH coefficients of the fODF for each @@ -202,7 +216,7 @@ def create_odf_slicer(sh_fodf, orientation, slice_index, sphere, sh_order, Parameters ---------- - sh_fodf : np.ndarray + sh_fodf : np.ndarray or StatefulImage Spherical harmonics of fODF data. orientation : str Name of the axis to visualize. Choices are axial, coronal and sagittal. @@ -218,7 +232,7 @@ def create_odf_slicer(sh_fodf, orientation, slice_index, sphere, sh_order, Boolean indicating if the basis is full or not. scale : float Scaling factor for FODF. - sh_variance : np.ndarray, optional + sh_variance : np.ndarray or StatefulImage, optional Spherical harmonics of the variance fODF data. mask : np.ndarray, optional Only the data inside the mask will be displayed. Defaults to None. @@ -237,6 +251,8 @@ def create_odf_slicer(sh_fodf, orientation, slice_index, sphere, sh_order, Color of the variance fODF data, in RGB. is_legacy: bool Whether the SH basis is used in legacy formats [True]. + affine : np.ndarray, optional + Voxel-to-world affine matrix. Returns ------- @@ -246,33 +262,35 @@ def create_odf_slicer(sh_fodf, orientation, slice_index, sphere, sh_order, Fury object containing the odf variance information. """ + from scilpy.io.stateful_image import StatefulImage + if isinstance(sh_fodf, StatefulImage): + sh_fodf = sh_fodf.to_voxel_direction(sh_basis=sh_basis) + + if isinstance(sh_variance, StatefulImage): + sh_variance = sh_variance.to_voxel_direction(sh_basis=sh_basis) + # Subdivide the spheres if nb_subdivide is provided if nb_subdivide is not None: sphere = sphere.subdivide(n=nb_subdivide) - fodf = sh_to_sf(sh_fodf, sphere, + # Always compute SF from SH to avoid FURY's odf_slicer bugs with matrix multiplication + fodf = sh_to_sf(sh_fodf, sphere, sh_order_max=sh_order, basis_type=sh_basis, full_basis=full_basis, legacy=is_legacy) fodf_var = None - B_mat = None if sh_variance is not None: fodf_var = sh_to_sf(sh_variance, sphere, sh_order_max=sh_order, basis_type=sh_basis, full_basis=full_basis, legacy=is_legacy) - else: - fodf = sh_fodf - B_mat = sh_to_sf_matrix(sphere, sh_order_max=sh_order, - basis_type=sh_basis, - full_basis=full_basis, return_inv=False) odf_actor, var_actor = create_odf_actors(fodf, sphere, scale, fodf_var, mask, radial_scale, norm, colormap, variance_k, variance_color, - B_mat=B_mat) + affine=affine) set_display_extent(odf_actor, orientation, sh_fodf.shape[:3], slice_index) if sh_variance is not None: @@ -283,13 +301,13 @@ def create_odf_slicer(sh_fodf, orientation, slice_index, sphere, sh_order, def create_bingham_slicer(data, orientation, slice_index, - sphere, color_per_lobe=False): + sphere, color_per_lobe=False, affine=None): """ Create a bingham fit slicer using a combination of odf_slicer actors Parameters ---------- - data: Array + data: Array or StatefulImage Volume of shape (X, Y, Z, N_LOBES, NB_PARAMS) containing the Bingham distributions parameters. Note, NB_PARAMS is usually 7. One of X, Y, Z should be of value 1 (one slice). @@ -302,12 +320,19 @@ def create_bingham_slicer(data, orientation, slice_index, color_per_lobe: bool If true, each Bingham distribution is colored using a disting color. Else, Bingham distributions are colored by their orientation. + affine: np.ndarray, optional + Voxel-to-world affine matrix. Return ------ actors: list of fury odf_slicer actors ODF slicer actors representing the Bingham distributions. """ + + from scilpy.io.stateful_image import StatefulImage + if isinstance(data, StatefulImage): + data = data.to_voxel_direction() + shape = data.shape if len(shape) != 5: raise ValueError('Expecting bingham data to be 5D ' @@ -327,7 +352,7 @@ def create_bingham_slicer(data, orientation, slice_index, color = colors[nn] if color_per_lobe else None odf_actor, _ = create_odf_actors(sf, sphere, 0.5, colormap=color, - radial_scale=True) + radial_scale=True, affine=affine) set_display_extent(odf_actor, orientation, shape[:3], slice_index) actors.append(odf_actor)