diff --git a/.gitignore b/.gitignore index 6906a010..09400a80 100644 --- a/.gitignore +++ b/.gitignore @@ -155,3 +155,4 @@ waveorder/_version.py # example data /examples/data_temp/* /logs/* +runs/* \ No newline at end of file diff --git a/docs/examples/models/isotropic_thin_3d.py b/docs/examples/models/isotropic_thin_3d.py index 9693645d..73b2d968 100644 --- a/docs/examples/models/isotropic_thin_3d.py +++ b/docs/examples/models/isotropic_thin_3d.py @@ -10,6 +10,7 @@ import napari import numpy as np +import torch from waveorder.models import isotropic_thin_3d @@ -32,9 +33,9 @@ ] ) transfer_function_arguments = { - "z_position_list": (np.arange(z_shape) - z_shape // 2) * z_pixel_size, - "numerical_aperture_illumination": 0.9, - "numerical_aperture_detection": 1.2, + "z_position_list": (torch.arange(z_shape) - z_shape // 2) * z_pixel_size, + "numerical_aperture_illumination": torch.tensor([0.9]), + "numerical_aperture_detection": torch.tensor([1.2]), } # Create a disk phantom diff --git a/docs/examples/visuals/optimize_phase_recon.py b/docs/examples/visuals/optimize_phase_recon.py new file mode 100644 index 00000000..acbc4854 --- /dev/null +++ b/docs/examples/visuals/optimize_phase_recon.py @@ -0,0 +1,274 @@ +from datetime import datetime + +import napari +import numpy as np +import torch + +# Commenting biahub dependency for now +# from biahub.cli.utils import model_to_yaml +# from biahub.settings import StitchSettings +from iohub import open_ome_zarr +from iohub.ngff import TransformationMeta +from torch.utils.tensorboard import SummaryWriter + +from waveorder import optics, util +from waveorder.models import isotropic_thin_3d + + +# === Core Functions === +def run_reconstruction( + zyx_tile: torch.Tensor, recon_args: dict +) -> torch.Tensor: + + # Prepare transfer function arguments + tf_args = recon_args.copy() + Z, _, _ = zyx_tile.shape + tf_args["z_position_list"] = ( + torch.arange(Z) - (Z // 2) + recon_args["z_offset"] + ) * recon_args["z_scale"] + tf_args.pop("z_offset") + tf_args.pop("z_scale") + + # Core reconstruction calls + tf_abs, tf_phase = isotropic_thin_3d.calculate_transfer_function(**tf_args) + system = isotropic_thin_3d.calculate_singular_system(tf_abs, tf_phase) + _, yx_phase_recon = isotropic_thin_3d.apply_inverse_transfer_function( + zyx_tile, system, regularization_strength=1e-2 + ) + return yx_phase_recon + + +def compute_midband_power( + yx_array: torch.Tensor, + NA_det: float, + lambda_ill: float, + pixel_size: float, + band: tuple[float, float] = (0.125, 0.25), +) -> torch.Tensor: + _, _, fxx, fyy = util.gen_coordinate(yx_array.shape, pixel_size) + frr = torch.tensor(np.sqrt(fxx**2 + fyy**2)) + xy_abs_fft = torch.abs(torch.fft.fftn(yx_array)) + cutoff = 2 * NA_det / lambda_ill + mask = torch.logical_and(frr > cutoff * band[0], frr < cutoff * band[1]) + return torch.sum(xy_abs_fft[mask]) + + +def extract_tiles( + zyx_data: np.ndarray, num_tiles: tuple[int, int], overlap_pct: float +) -> tuple[dict[str, np.ndarray], dict[str, tuple[int, int, int]]]: + Z, Y, X = zyx_data.shape + tile_height = int( + np.ceil(Y / (num_tiles[0] - (num_tiles[0] - 1) * overlap_pct)) + ) + tile_width = int( + np.ceil(X / (num_tiles[1] - (num_tiles[1] - 1) * overlap_pct)) + ) + stride_y = int(tile_height * (1 - overlap_pct)) + stride_x = int(tile_width * (1 - overlap_pct)) + + tiles = {} + translations = {} + for yi in range(num_tiles[0]): + for xi in range(num_tiles[1]): + y0, x0 = yi * stride_y, xi * stride_x + y1, x1 = min(y0 + tile_height, Y), min(x0 + tile_width, X) + tile_name = f"0/0/{yi:03d}{xi:03d}" + tiles[tile_name] = zyx_data[:, y0:y1, x0:x1] + translations[tile_name] = (0, y0, x0) + return tiles, translations + + +def log_optimization_progress( + step: int, + optimization_params: dict[str, torch.nn.Parameter], + loss: torch.Tensor, + tb_writer: SummaryWriter, + recon_args: dict, + yx_recon: torch.Tensor, +) -> None: + # Print progress + print(f"Step {step + 1}/{NUM_ITERATIONS}") + for name, param in optimization_params.items(): + print(f"\t{name} = {param.item():.4f}") + print(f"\tLoss: {loss.item():.2e}\n") + + # Log metrics and images + tb_writer.add_scalar("Loss", loss.item(), step) + for name, param in optimization_params.items(): + tb_writer.add_scalar(name, param.item(), step) + + yx_pixel_factor = 2 + fyy, fxx = util.generate_frequencies( + [yx_pixel_factor * x for x in recon_args["yx_shape"]], + recon_args["yx_pixel_size"] / yx_pixel_factor, + ) + pupil = optics.generate_tilted_pupil( + fxx, + fyy, + recon_args["numerical_aperture_illumination"], + recon_args["wavelength_illumination"], + recon_args["index_of_refraction_media"], + recon_args["tilt_angle_zenith"], + recon_args["tilt_angle_azimuth"], + ) + tb_writer.add_image( + "Illumination Pupil", + torch.fft.fftshift(pupil).detach().numpy()[None], + step, + ) + tb_writer.add_image( + "Reconstructed Phase", yx_recon.detach().numpy()[None], step + ) + + +def prepare_optimizer( + optimizable_params: dict[str, tuple[bool, float, float]], +) -> tuple[dict[str, torch.nn.Parameter], torch.optim.Optimizer]: + optimization_params: dict[str, torch.nn.Parameter] = {} + optimizer_config = [] + for name, (enabled, initial, lr) in optimizable_params.items(): + if enabled: + param = torch.nn.Parameter( + torch.tensor([initial], device="cpu"), requires_grad=True + ) + optimization_params[name] = param + optimizer_config.append({"params": [param], "lr": lr}) + + optimizer = torch.optim.Adam(optimizer_config) + return optimization_params, optimizer + + +def optimize_tile( + zyx_tile: torch.Tensor, + recon_args: dict, + optimizable_params: dict[str, tuple[bool, float, float]], + tb_writer: SummaryWriter, + num_iterations: int = 10, +) -> torch.Tensor: + optimization_params, optimizer = prepare_optimizer(optimizable_params) + + for step in range(num_iterations): + + # Update params + for name, param in optimization_params.items(): + recon_args[name] = param + + # Run reconstruction and compute loss + yx_recon = run_reconstruction(zyx_tile, recon_args) + loss = -compute_midband_power( + yx_recon, + NA_det=0.15, + lambda_ill=recon_args["wavelength_illumination"], + pixel_size=recon_args["yx_pixel_size"], + band=(0.1, 0.2), + ) + + # Update optimizer + loss.backward() + optimizer.step() + optimizer.zero_grad() + + log_optimization_progress( + step, optimization_params, loss, tb_writer, recon_args, yx_recon + ) + + return yx_recon.detach() + + +# === Configuration === +# INPUTS +INPUT_PATH = "/hpc/projects/intracellular_dashboard/ops/ops0031_20250424/0-convert/live_imaging/tracking_symlink.zarr" +INPUT_FOV = "A/1/001007" +SUBTILES = ["0/0/001001"] # or "all" + +# OUTPUTS +OUTPUT_PATH = "./optimized_recon.zarr" +OUTPUT_CHANNEL_NAME = "recon" + +# TILING +STITCH_CONFIG_PATH = "./stitch_config.yaml" +NUM_TILES = (6, 6) +OVERLAP_FRACTION = 0.2 + +# OPTIMIZATION +NUM_ITERATIONS = 10 +LOGS_DIR = "./runs" +FIXED_PARAMS = { + "wavelength_illumination": 0.450, + "index_of_refraction_media": 1.0, + "invert_phase_contrast": True, +} +OPTIMIZABLE_PARAMS = { # (optimize?, initial_value, learning_rate) + "z_offset": (True, 0.0, 0.01), + "numerical_aperture_detection": (True, 0.15, 0.001), + "numerical_aperture_illumination": (True, 0.1, 0.001), + "tilt_angle_zenith": (True, 0.1, 0.005), + "tilt_angle_azimuth": (True, 260 * np.pi / 180, 0.001), +} + +# === Main Execution === +input_store = open_ome_zarr(INPUT_PATH) +zyx_data = input_store[INPUT_FOV].data[0][0] +_, _, z_scale, y_scale, x_scale = input_store[INPUT_FOV].scale + +output_store = open_ome_zarr( + OUTPUT_PATH, layout="hcs", mode="w", channel_names=[OUTPUT_CHANNEL_NAME] +) +tiles, translations = extract_tiles(zyx_data, NUM_TILES, OVERLAP_FRACTION) +# Commenting biahub dependency for now +# model_to_yaml( +# StitchSettings(total_translation=translations), STITCH_CONFIG_PATH +# ) + +if SUBTILES == "all": + selected_keys = tiles.keys() +else: + selected_keys = SUBTILES + +for key in selected_keys: + zyx_tile = torch.tensor(tiles[key], dtype=torch.float32, device="cpu") + + print(f"Processing tile {key}") + timestamp = datetime.now().strftime("%d%H%M") + log_dir = f"{LOGS_DIR}/tile_{key.replace('/', '_')}_{timestamp}" + tb_writer = SummaryWriter(log_dir=log_dir) + + # Prepare reconstruction arguments + recon_args = FIXED_PARAMS + for name, value in OPTIMIZABLE_PARAMS.items(): + recon_args[name] = torch.tensor( + [value[1]], dtype=torch.float32, device="cpu" + ) + recon_args["yx_shape"] = zyx_tile.shape[1:] + recon_args["yx_pixel_size"] = y_scale + recon_args["z_scale"] = z_scale + + initial_recon = run_reconstruction(zyx_tile, recon_args) + optimized_recon = optimize_tile( + zyx_tile, + recon_args, + OPTIMIZABLE_PARAMS, + tb_writer, + num_iterations=NUM_ITERATIONS, + ) + tb_writer.close() + + # Write to napari viewer + scale = [z_scale, y_scale, x_scale] + viewer = napari.Viewer() + viewer.add_image( + initial_recon.numpy()[None], name=f"initial-{key}", scale=scale + ) + viewer.add_image( + optimized_recon.numpy()[None], name=f"optimized-{key}", scale=scale + ) + viewer.add_image(zyx_tile, name=f"tile-{key}", scale=scale) + + # Write to output store + pos = output_store.create_position(*key.split("/")) + pos.create_image( + "0", + optimized_recon[None, None, None].numpy(), + transform=[TransformationMeta(type="scale", scale=[1, 1] + scale)], + ) + input("Press Enter to continue...") diff --git a/pyproject.toml b/pyproject.toml index ff6a1a96..af228c19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -67,6 +67,7 @@ visual = [ "napari-ome-zarr>=0.3.2", # drag and drop convenience "pycromanager==0.27.2", "jupyter", + "tensorboard", ] dev = [ "black==25.1.0", diff --git a/tests/cli_tests/test_compute_tf.py b/tests/cli_tests/test_compute_tf.py index b3cdaf59..86856c62 100644 --- a/tests/cli_tests/test_compute_tf.py +++ b/tests/cli_tests/test_compute_tf.py @@ -1,5 +1,5 @@ -import numpy as np import pytest +import torch from click.testing import CliRunner from waveorder.cli import settings @@ -22,7 +22,7 @@ ) def test_position_list_from_shape_scale_offset(shape, scale, offset, expected): result = _position_list_from_shape_scale_offset(shape, scale, offset) - np.testing.assert_allclose(result, expected) + torch.testing.assert_close(result, torch.tensor(expected)) def test_compute_transfer(tmp_path, example_plate): diff --git a/tests/models/test_isotropic_thin_3d.py b/tests/models/test_isotropic_thin_3d.py index 9bd49bf7..09f66675 100644 --- a/tests/models/test_isotropic_thin_3d.py +++ b/tests/models/test_isotropic_thin_3d.py @@ -9,7 +9,7 @@ def test_calculate_transfer_function(invert_phase_contrast): Hu, Hp = isotropic_thin_3d.calculate_transfer_function( yx_shape=(100, 101), yx_pixel_size=6.5 / 40, - z_position_list=[-1, 0, 1], + z_position_list=torch.tensor([-1, 0, 1]), wavelength_illumination=0.5, index_of_refraction_media=1.0, numerical_aperture_illumination=0.4, diff --git a/tests/test_optics.py b/tests/test_optics.py index d51b5bef..d5a00df8 100644 --- a/tests/test_optics.py +++ b/tests/test_optics.py @@ -8,11 +8,23 @@ def test_generate_pupil(): pupil = optics.generate_pupil(radial_frequencies, 0.5, 0.5) # Corners are in the pupil - assert pupil[0, 0] == 1 - assert pupil[-1, -1] == 1 + assert torch.isclose(pupil[0, 0], torch.tensor(1.0), rtol=1e-3) + assert torch.isclose(pupil[-1, -1], torch.tensor(1.0), rtol=1e-3) # Center is outside the pupil - assert pupil[5, 5] == 0 + assert pupil[5, 5] < 1e-3 + + +def test_generate_pupil_cutoff(): + """ + Test generate_pupil at the cutoff frequency. + """ + frr = torch.tensor([[0.5, 1.0, 1.5]]) + NA = 1.0 + lamb_in = 1.0 + pupil = optics.generate_pupil(frr, NA, lamb_in) + # At cutoff, sigmoid should be ~0.5 + assert torch.isclose(pupil[0, 1], torch.tensor(0.5), atol=1e-3) def test_generate_propagation_kernel(): @@ -27,7 +39,7 @@ def test_generate_propagation_kernel(): assert propagation_kernel.shape == (3, 10, 10) assert propagation_kernel[1, 0, 0] == 1 - assert propagation_kernel[1, 5, 5] == 0 + assert torch.abs(propagation_kernel[1, 5, 5]) < 1e-3 def test_gen_Greens_function_z(): @@ -41,7 +53,7 @@ def test_gen_Greens_function_z(): ) assert G.shape == (3, 10, 10) - assert G[1, 5, 5] == 0 + assert torch.abs(G[1, 5, 5]) < 1e-3 def test_WOTF_2D(): @@ -57,7 +69,12 @@ def test_WOTF_2D(): ) # Absorption DC term - assert absorption_transfer_function[0, 0] == 2 + assert torch.isclose( + torch.real(absorption_transfer_function[0, 0]), + torch.tensor(2.0), + rtol=1e-3, + ) + assert torch.abs(torch.imag(absorption_transfer_function[0, 0])) < 1e-3 # No phase contrast for an in-focus slice assert torch.all(torch.real(phase_transfer_function) == 0) diff --git a/waveorder/cli/compute_transfer_function.py b/waveorder/cli/compute_transfer_function.py index 4374b19c..0e6a715d 100644 --- a/waveorder/cli/compute_transfer_function.py +++ b/waveorder/cli/compute_transfer_function.py @@ -1,7 +1,7 @@ from pathlib import Path import click -import numpy as np +import torch from iohub.ngff import Position, open_ome_zarr from waveorder import focus @@ -25,18 +25,20 @@ def _position_list_from_shape_scale_offset( shape: int, scale: float, offset: float -) -> list: +) -> torch.Tensor: """ - Generates a list of positions based on the given array shape, pixel size (scale), and offset. + Generates a 1D tensor of positions based on the given array shape, pixel size (scale), and offset. Examples -------- >>> _position_list_from_shape_scale_offset(5, 1.0, 0.0) - [2.0, 1.0, 0.0, -1.0, -2.0] + tensor([ 2., 1., 0., -1., -2.]) >>> _position_list_from_shape_scale_offset(4, 0.5, 1.0) - [1.5, 1.0, 0.5, 0.0] + tensor([1.5, 1.0, 0.5, 0.0]) """ - return list((-np.arange(shape) + (shape // 2) + offset) * scale) + return ( + -torch.arange(shape, dtype=torch.float32) + (shape // 2) + offset + ) * scale def generate_and_save_vector_birefringence_transfer_function( diff --git a/waveorder/models/isotropic_fluorescent_thick_3d.py b/waveorder/models/isotropic_fluorescent_thick_3d.py index 1a1f1d1e..e4d40f9a 100644 --- a/waveorder/models/isotropic_fluorescent_thick_3d.py +++ b/waveorder/models/isotropic_fluorescent_thick_3d.py @@ -1,6 +1,5 @@ from typing import Literal -import numpy as np import torch from torch import Tensor @@ -31,6 +30,7 @@ def calculate_transfer_function( z_padding: int, index_of_refraction_media: float, numerical_aperture_detection: float, + detection_phase_zernike_vector: Tensor = torch.tensor([0.0]), confocal_pinhole_diameter: float | None = None, ) -> Tensor: """Calculate the optical transfer function for fluorescence imaging. @@ -56,6 +56,9 @@ def calculate_transfer_function( Refractive index of imaging medium numerical_aperture_detection : float Numerical aperture of detection objective + detection_phase_zernike_vector : Tensor, optional + Zernike phase vector for detection objective. If None, no phase correction + is applied. confocal_pinhole_diameter : float | None, optional Diameter of confocal pinhole in image space (demagnified). If None, computes widefield OTF. If specified, computes confocal OTF. @@ -81,8 +84,8 @@ def calculate_transfer_function( transverse_nyquist = transverse_nyquist / 2 axial_nyquist = axial_nyquist / 2 - yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist)) - z_factor = int(np.ceil(z_pixel_size / axial_nyquist)) + yx_factor = int(torch.ceil(yx_pixel_size / transverse_nyquist)) + z_factor = int(torch.ceil(z_pixel_size / axial_nyquist)) optical_transfer_function = _calculate_wrap_unsafe_transfer_function( ( @@ -96,6 +99,7 @@ def calculate_transfer_function( z_padding, index_of_refraction_media, numerical_aperture_detection, + detection_phase_zernike_vector, confocal_pinhole_diameter, ) zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:] @@ -141,21 +145,24 @@ def _calculate_wrap_unsafe_transfer_function( z_padding: int, index_of_refraction_media: float, numerical_aperture_detection: float, + detection_phase_zernike_vector: Tensor = torch.tensor([0.0]), confocal_pinhole_diameter: float | None = None, ) -> Tensor: - radial_frequencies = util.generate_radial_frequencies( - zyx_shape[1:], yx_pixel_size - ) + fyy, fxx = util.generate_frequencies(zyx_shape[1:], yx_pixel_size) + radial_frequencies = torch.sqrt(fyy**2 + fxx**2) z_total = zyx_shape[0] + 2 * z_padding z_position_list = torch.fft.ifftshift( (torch.arange(z_total) - z_total // 2) * z_pixel_size ) - det_pupil = optics.generate_pupil( - radial_frequencies, + det_pupil = optics.generate_tilted_pupil( + fxx, + fyy, numerical_aperture_detection, wavelength_emission, + index_of_refraction_media, + phase_zernike_vector=detection_phase_zernike_vector, ) propagation_kernel = optics.generate_propagation_kernel( @@ -164,14 +171,12 @@ def _calculate_wrap_unsafe_transfer_function( wavelength_emission / index_of_refraction_media, z_position_list, ) - point_spread_function = ( torch.abs(torch.fft.ifft2(propagation_kernel, dim=(1, 2))) ** 2 ) optical_transfer_function = torch.fft.fftn( point_spread_function, dim=(0, 1, 2) ) - # Confocal: multiply excitation PSF with detection PSF (downweighted by pinhole) if confocal_pinhole_diameter is not None: pinhole_otf_2d = _calculate_pinhole_aperture_otf( @@ -197,6 +202,7 @@ def _calculate_wrap_unsafe_transfer_function( optical_transfer_function /= torch.max( torch.abs(optical_transfer_function) ) # normalize + # NB: this is a /= operation, but in-place operations do not propagate gradients return optical_transfer_function diff --git a/waveorder/models/isotropic_thin_3d.py b/waveorder/models/isotropic_thin_3d.py index 16f253b9..b68aef60 100644 --- a/waveorder/models/isotropic_thin_3d.py +++ b/waveorder/models/isotropic_thin_3d.py @@ -1,6 +1,5 @@ from typing import Literal, Tuple -import numpy as np import torch from torch import Tensor @@ -38,19 +37,23 @@ def generate_test_phantom( def calculate_transfer_function( yx_shape: Tuple[int, int], yx_pixel_size: float, - z_position_list: list, + z_position_list: torch.Tensor, wavelength_illumination: float, index_of_refraction_media: float, numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, + tilt_angle_zenith: torch.Tensor = torch.tensor([0.0]), + tilt_angle_azimuth: torch.Tensor = torch.tensor([0.0]), ) -> Tuple[Tensor, Tensor]: transverse_nyquist = sampling.transverse_nyquist( wavelength_illumination, numerical_aperture_illumination, numerical_aperture_detection, ) - yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist)) + yx_factor = int( + torch.ceil(torch.tensor(yx_pixel_size / transverse_nyquist)) + ) ( absorption_2d_to_3d_transfer_function, @@ -67,6 +70,8 @@ def calculate_transfer_function( numerical_aperture_illumination, numerical_aperture_detection, invert_phase_contrast=invert_phase_contrast, + tilt_angle_zenith=tilt_angle_zenith, + tilt_angle_azimuth=tilt_angle_azimuth, ) absorption_2d_to_3d_transfer_function_out = torch.zeros( @@ -97,12 +102,14 @@ def calculate_transfer_function( def _calculate_wrap_unsafe_transfer_function( yx_shape: Tuple[int, int], yx_pixel_size: float, - z_position_list: list, + z_position_list: torch.Tensor, wavelength_illumination: float, index_of_refraction_media: float, numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, + tilt_angle_zenith: float = 0.0, + tilt_angle_azimuth: float = 0.0, ) -> Tuple[Tensor, Tensor]: if numerical_aperture_illumination >= numerical_aperture_detection: print( @@ -111,19 +118,30 @@ def _calculate_wrap_unsafe_transfer_function( "numerical_aperture_illumination to 0.9 * " "numerical_aperture_detection to avoid singularities." ) - numerical_aperture_illumination = 0.9 * numerical_aperture_detection + numerical_aperture_illumination = torch.where( + numerical_aperture_illumination >= numerical_aperture_detection, + numerical_aperture_detection, + numerical_aperture_illumination, + ) if invert_phase_contrast: - z_position_list = [-1 * x for x in z_position_list] - radial_frequencies = util.generate_radial_frequencies( - yx_shape, yx_pixel_size - ) + z_positions = z_position_list * -1 + else: + z_positions = z_position_list.clone() - illumination_pupil = optics.generate_pupil( - radial_frequencies, + fyy, fxx = util.generate_frequencies(yx_shape, yx_pixel_size) + radial_frequencies = torch.sqrt(fyy**2 + fxx**2) + + illumination_pupil = optics.generate_tilted_pupil( + fxx, + fyy, numerical_aperture_illumination, wavelength_illumination, + index_of_refraction_media, + tilt_angle_zenith, + tilt_angle_azimuth, ) + detection_pupil = optics.generate_pupil( radial_frequencies, numerical_aperture_detection, @@ -133,17 +151,17 @@ def _calculate_wrap_unsafe_transfer_function( radial_frequencies, detection_pupil, wavelength_illumination / index_of_refraction_media, - torch.tensor(z_position_list), + z_positions, ) - zyx_shape = (len(z_position_list),) + tuple(yx_shape) + zyx_shape = (len(z_positions),) + tuple(yx_shape) absorption_2d_to_3d_transfer_function = torch.zeros( zyx_shape, dtype=torch.complex64 ) phase_2d_to_3d_transfer_function = torch.zeros( zyx_shape, dtype=torch.complex64 ) - for z in range(len(z_position_list)): + for z in range(len(z_positions)): ( absorption_2d_to_3d_transfer_function[z], phase_2d_to_3d_transfer_function[z], @@ -188,13 +206,27 @@ def calculate_singular_system( ), dim=0, ) - YXsf_transfer_function = sfYX_transfer_function.permute(2, 3, 0, 1) - Up, Sp, Vhp = torch.linalg.svd(YXsf_transfer_function, full_matrices=False) - U = Up.permute(2, 3, 0, 1) - S = Sp.permute(2, 0, 1) - Vh = Vhp.permute(2, 3, 0, 1) + + # phase only reconstruction + sfYX_transfer_function = phase_2d_to_3d_transfer_function.unsqueeze(0) + U = torch.ones_like(sfYX_transfer_function) + S = torch.linalg.norm(sfYX_transfer_function, dim=1) + Vh = sfYX_transfer_function / S.unsqueeze(0) return U, S, Vh + # Absorption and phase reconstruction + # Gradients do not work with complex-valued SVD, so only phase reconstructions + # are supported for now. + + # YXsf_transfer_function = sfYX_transfer_function.permute(2, 3, 0, 1) + # Up, Sp, Vhp = torch.linalg.svd(YXsf_transfer_function, full_matrices=False) + + # U = Up.permute(2, 3, 0, 1) + # S = Sp.permute(2, 0, 1) + # Vh = Vhp.permute(2, 3, 0, 1) + + # return U, S, Vh + def visualize_transfer_function( viewer, @@ -257,17 +289,18 @@ def apply_transfer_function( # simulate absorbing object yx_absorption_hat = torch.fft.fftn(yx_absorption) - zyx_absorption_data_hat = yx_absorption_hat[None, ...] * torch.real( - absorption_2d_to_3d_transfer_function + zyx_absorption_data_hat = ( + yx_absorption_hat[None, ...] * absorption_2d_to_3d_transfer_function ) + zyx_absorption_data = torch.real( torch.fft.ifftn(zyx_absorption_data_hat, dim=(1, 2)) ) # simulate phase object yx_phase_hat = torch.fft.fftn(yx_phase) - zyx_phase_data_hat = yx_phase_hat[None, ...] * torch.real( - phase_2d_to_3d_transfer_function + zyx_phase_data_hat = ( + yx_phase_hat[None, ...] * phase_2d_to_3d_transfer_function ) zyx_phase_data = torch.real( torch.fft.ifftn(zyx_phase_data_hat, dim=(1, 2)) @@ -331,14 +364,16 @@ def apply_inverse_transfer_function( # TODO Consider refactoring with vectorial transfer function SVD if reconstruction_algorithm == "Tikhonov": - print("Computing inverse filter") U, S, Vh = singular_system S_reg = S / (S**2 + regularization_strength) sfyx_inverse_filter = torch.einsum( - "sj...,j...,jf...->fs...", U, S_reg, Vh + "sj...,j...,jf...->fs...", U, S_reg, Vh.conj() ) - absorption_yx, phase_yx = apply_filter_bank(sfyx_inverse_filter, zyx) + # Phase only reconstruction + # absorption_yx, phase_yx = apply_filter_bank(sfyx_inverse_filter, zyx) + phase_yx = apply_filter_bank(sfyx_inverse_filter, zyx)[0] + absorption_yx = torch.zeros_like(phase_yx) # ADMM deconvolution with anisotropic TV regularization elif reconstruction_algorithm == "TV": diff --git a/waveorder/models/phase_thick_3d.py b/waveorder/models/phase_thick_3d.py index 5a2e2547..1fd5d1f7 100644 --- a/waveorder/models/phase_thick_3d.py +++ b/waveorder/models/phase_thick_3d.py @@ -43,6 +43,8 @@ def calculate_transfer_function( numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, + tilt_angle_zenith: torch.Tensor = torch.tensor([0.0]), + tilt_angle_azimuth: torch.Tensor = torch.tensor([0.0]), ) -> tuple[np.ndarray, np.ndarray]: transverse_nyquist = sampling.transverse_nyquist( wavelength_illumination, @@ -50,13 +52,14 @@ def calculate_transfer_function( numerical_aperture_detection, ) axial_nyquist = sampling.axial_nyquist( - wavelength_illumination, + torch.tensor(wavelength_illumination), numerical_aperture_detection, - index_of_refraction_media, + torch.tensor(index_of_refraction_media), ) - - yx_factor = int(np.ceil(yx_pixel_size / transverse_nyquist)) - z_factor = int(np.ceil(z_pixel_size / axial_nyquist)) + yx_factor = int( + torch.ceil(torch.tensor(yx_pixel_size / transverse_nyquist)) + ) + z_factor = int(torch.ceil(torch.tensor(z_pixel_size / axial_nyquist))) ( real_potential_transfer_function, @@ -75,6 +78,8 @@ def calculate_transfer_function( numerical_aperture_illumination, numerical_aperture_detection, invert_phase_contrast=invert_phase_contrast, + tilt_angle_zenith=tilt_angle_zenith, + tilt_angle_azimuth=tilt_angle_azimuth, ) zyx_out_shape = (zyx_shape[0] + 2 * z_padding,) + zyx_shape[1:] @@ -98,10 +103,11 @@ def _calculate_wrap_unsafe_transfer_function( numerical_aperture_illumination: float, numerical_aperture_detection: float, invert_phase_contrast: bool = False, + tilt_angle_zenith: torch.Tensor = torch.tensor([0.0]), + tilt_angle_azimuth: torch.Tensor = torch.tensor([0.0]), ) -> tuple[np.ndarray, np.ndarray]: - radial_frequencies = util.generate_radial_frequencies( - zyx_shape[1:], yx_pixel_size - ) + fyy, fxx = util.generate_frequencies(zyx_shape[1:], yx_pixel_size) + radial_frequencies = torch.sqrt(fyy**2 + fxx**2) z_total = zyx_shape[0] + 2 * z_padding z_position_list = torch.fft.ifftshift( (torch.arange(z_total) - z_total // 2) * z_pixel_size @@ -109,10 +115,14 @@ def _calculate_wrap_unsafe_transfer_function( if invert_phase_contrast: z_position_list = torch.flip(z_position_list, dims=(0,)) - ill_pupil = optics.generate_pupil( - radial_frequencies, - numerical_aperture_illumination, + ill_pupil = optics.generate_tilted_pupil( + fxx, + fyy, + torch.tensor(numerical_aperture_illumination), wavelength_illumination, + index_of_refraction_media, + tilt_angle_zenith, + tilt_angle_azimuth, ) det_pupil = optics.generate_pupil( radial_frequencies, @@ -144,7 +154,6 @@ def _calculate_wrap_unsafe_transfer_function( greens_function_z, z_pixel_size, ) - return real_potential_transfer_function, imag_potential_transfer_function diff --git a/waveorder/optics.py b/waveorder/optics.py index 46204517..16bd35d0 100644 --- a/waveorder/optics.py +++ b/waveorder/optics.py @@ -4,6 +4,8 @@ import torch from numpy.fft import fft2, fftn, fftshift, ifft2, ifftn, ifftshift +from waveorder import zernike + def Jones_sample(Ein, t, sa): """ @@ -118,34 +120,122 @@ def analyzer_output(Ein, alpha, beta): return Eout -def generate_pupil(frr, NA, lamb_in): +def generate_pupil( + frr: torch.Tensor, NA: float, lamb_in: float, slope: float = 4.0 +) -> torch.Tensor: """ - - compute pupil function given spatial frequency, NA, wavelength. + Generate a soft-edged pupil function using a sigmoid roll-off. The default + sigmoid softens the edge within ~1 voxel of the boundary. Parameters ---------- - frr : torch.tensor - radial frequency coordinate in units of inverse length + frr : torch.Tensor + radial frequency coordinates (units: 1/length) - NA : float - numerical aperture of the pupil function (normalized by the refractive index of the immersion media) + NA : float + numerical aperture (unitless) lamb_in : float - wavelength of the light in free space - in units of length (inverse of frr's units) + wavelength of light (same length units as frr^-1) + + slope : float, optional + steepness of the sigmoid roll-off (default 4.0 keeps ~90% of + sigmoid within a single voxel) Returns ------- - Pupil : numpy.ndarray - pupil function with the specified parameters with the size of (Ny, Nx) + pupil : torch.Tensor + pupil function, pupil.shape == frr.shape, values in [0, 1] + """ + pixel_slope = slope / torch.abs(frr[0, 1] - frr[0, 0]) + cutoff = NA / lamb_in + pupil = torch.sigmoid(pixel_slope * (cutoff - frr)) + return pupil + + +def generate_tilted_pupil( + fxx: torch.Tensor, + fyy: torch.Tensor, + NA: torch.Tensor, + lamb_in: float, + n: float = 1.0, + tilt_angle_zenith: torch.Tensor = torch.Tensor([0.0]), + tilt_angle_azimuth: torch.Tensor = torch.Tensor([0.0]), + slope: float = 4.0, + phase_zernike_vector: torch.Tensor = None, +) -> torch.Tensor: + """ + Generate a soft-edged 2-D pupil that may be tilted on the Ewald sphere. + Parameters + ---------- + fxx, fyy : torch.Tensor + Cartesian frequency grids (units: 1/length) with identical shape. + NA : torch.Tensor + Numerical aperture of the objective (must satisfy NA ≤ n). + lamb_in : float + Illumination wavelength (units: length) + n : float, optional + Refractive index of the immersion medium (default 1.0). + tilt_angle_zenith : float, optional + Polar angle θ (radians) between the pupil axis and +z (0 = untilted). + tilt_angle_azimuth : float, optional + Azimuth φ (radians) of the tilt in the xy-plane (0 = +x). + slope : float, optional + Controls sigmoid roll-off (≈ 90 % change in one pixel when slope=4). + phase_zernike_vector : torch.Tensor, optional + Zernike phase coefficients (radians), shape [N]. + + Returns + ------- + pupil : torch.Tensor + 2-D soft mask with values in [0, 1] and shape == fx.shape. """ + if NA > n: + raise ValueError("NA must be ≤ n (otherwise angle would be complex).") + + # constants + K = n / lamb_in # Ewald-sphere radius + cos_alpha_max = torch.sqrt(1 - (NA / n) ** 2) + + # sampling metrics + # Assume fxx, fyy are on a regular grid ⇒ pixel spacing in fx direction: + df = torch.abs(fxx[0, 1] - fxx[0, 0]) + pixel_slope = slope / df + + # sphere coordinates + fz_sq = K**2 - fxx**2 - fyy**2 + inside_sphere = fz_sq >= 0 + # clamp to avoid negative round-off, but keep gradients + fz = torch.sqrt(torch.clamp(fz_sq, min=0.0)) + + # tilt unit vector + sx = torch.sin(tilt_angle_zenith) * torch.cos(tilt_angle_azimuth) + sy = torch.sin(tilt_angle_zenith) * torch.sin(tilt_angle_azimuth) + sz = torch.cos(tilt_angle_zenith) + + # dot-product test + dot = fxx * sx + fyy * sy + fz * sz # v · s + threshold = K * cos_alpha_max + pupil_soft = torch.sigmoid(pixel_slope * (dot - threshold)) + + # mask outside sphere + pupil = pupil_soft * inside_sphere.to(fxx.dtype) + + if phase_zernike_vector is None or phase_zernike_vector.numel() == 0: + return pupil + + norm = NA / lamb_in + rho_sq = ((fxx / norm) ** 2 + (fyy / norm) ** 2).clamp(min=1e-6) + rho = torch.sqrt(rho_sq).clamp(max=1.0) + theta = torch.atan2(fyy, fxx) + phase = torch.zeros_like(rho) - Pupil = torch.zeros(frr.shape) - Pupil[frr < NA / lamb_in] = 1 + for j, coeff in enumerate(phase_zernike_vector): + m, n = zernike.noll_to_zern(j + 1) + phase = phase + (coeff * zernike.zernike(m, n, rho, theta)) - return Pupil + return pupil * torch.exp(1j * phase) def gen_sector_Pupil(fxx, fyy, NA, lamb_in, sector_angle, rotation_angle): @@ -461,7 +551,9 @@ def generate_greens_function_z( oblique_factor = ( (1 - wavelength_illumination**2 * radial_frequencies**2) - * pupil_support + * pupil_support.type( + torch.complex64 + ) # complex to avoid sqrt(-1) -> nan ) ** (1 / 2) / wavelength_illumination if axially_even: diff --git a/waveorder/sampling.py b/waveorder/sampling.py index aab179d3..950b5d0f 100644 --- a/waveorder/sampling.py +++ b/waveorder/sampling.py @@ -1,4 +1,3 @@ -import numpy as np import torch @@ -65,7 +64,7 @@ def axial_nyquist( """ n_on_lambda = index_of_refraction_media / wavelength_emission - cutoff_frequency = n_on_lambda - np.sqrt( + cutoff_frequency = n_on_lambda - torch.sqrt( n_on_lambda**2 - (numerical_aperture_detection / wavelength_emission) ** 2 ) diff --git a/waveorder/zernike.py b/waveorder/zernike.py new file mode 100644 index 00000000..bed82663 --- /dev/null +++ b/waveorder/zernike.py @@ -0,0 +1,40 @@ +import torch + + +def noll_to_zern(j: int) -> tuple[int, int]: + n = 0 + j1 = j - 1 + while j1 > n: + n += 1 + j1 -= n + m = -n + 2 * j1 + return m, n + + +def factorial(n: int) -> int: + return 1 if n < 2 else n * factorial(n - 1) + + +def zernike_radial(m: int, n: int, rho: torch.Tensor) -> torch.Tensor: + R = torch.zeros_like(rho) + for k in range((n - abs(m)) // 2 + 1): + num = (-1.0) ** k * factorial(n - k) + denom = ( + factorial(k) + * factorial((n + abs(m)) // 2 - k) + * factorial((n - abs(m)) // 2 - k) + ) + R += num / denom * rho ** (n - 2 * k) + return R + + +def zernike( + m: int, n: int, rho: torch.Tensor, theta: torch.Tensor +) -> torch.Tensor: + R = zernike_radial(m, n, rho) + if m > 0: + return R * torch.cos(m * theta) + elif m < 0: + return R * torch.sin(-m * theta) + else: + return R