Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions apiv2/schema/schema.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,9 @@ enums:
InstanceSegmentation:
text: InstanceSegmentation
description: A volume with labels for multiple instances
InstanceSegmentationMask:
text: InstanceSegmentationMask
description: A mask with labels for multiple instances
Mesh:
text: Mesh
description: A surface mesh volumes
Expand Down
9 changes: 9 additions & 0 deletions ingestion_tools/dataset_configs/template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,15 @@ annotations: OPTIONAL
delimiter: OPTIONAL, STRING (DEFAULT ',')
parent_filters: see InstanceSegmentation.parent_filters
exclude: SEE InstanceSegmentation.exclude
- InstanceSegmentationMask:
file_format: see InstanceSegmentation.file_format
glob_string: see InstanceSegmentation.glob_string
glob_strings: see InstanceSegmentation.glob_strings
is_visualization_default: see InstanceSegmentation.is_visualization_default
is_portal_standard: OPTIONAL, BOOLEAN (DEFAULT FALSE)
scale_factor: OPTIONAL, FLOAT (DEFAULT 1) (POSITIVE)
parent_filters: see InstanceSegmentation.parent_filters
exclude: SEE InstanceSegmentation.exclude
- SegmentationMask:
file_format: see InstanceSegmentation.file_format
glob_string: see InstanceSegmentation.glob_string
Expand Down
102 changes: 99 additions & 3 deletions ingestion_tools/scripts/common/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,24 @@ class ZarrReader:
def __init__(self, fs, zarrdir):
self.fs = fs
self.zarrdir = zarrdir
self._loc = None

def get_data(self):
loc = ome_zarr.io.ZarrLocation(self.fs.destformat(self.zarrdir))
loc = self._load_zarr_loc()
data = loc.load("0")
return data

@property
def attrs(self):
loc = self._load_zarr_loc()
group = zarr.group(loc.store)
return group.attrs

def _load_zarr_loc(self):
if self._loc is None:
self._loc = ome_zarr.io.ZarrLocation(self.fs.destformat(self.zarrdir))
return self._loc


class ZarrWriter:
def __init__(self, fs: FileSystemApi, zarrdir: str):
Expand Down Expand Up @@ -97,6 +109,7 @@ def write_data(
voxel_spacing: List[Tuple[float, float, float]],
chunk_size: Tuple[int, int, int] = (256, 256, 256),
scale_z_axis: bool = True,
store_labels_metadata: bool = False,
):
pyramid = []
scales = []
Expand All @@ -110,6 +123,36 @@ def write_data(
pyramid.append(d)
scales.append(self.ome_zarr_transforms(vs))

# Store the labels contained in the data if the flag is activated
if store_labels_metadata:

arr = data[0]

# t = time.perf_counter()
labels = [int(label) for label in np.unique(arr[arr > 0])]
Copy link

@seankmartin seankmartin Oct 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this as I actually expected it to have a bigger difference, but it is slightly more efficient from what I can tell to take the uniques over the unfiltered array and then afterwards filter out from the uniques all the labels <=0 in the list comprehension

# print(f"Time full image {time.perf_counter() - t:.3f}s {labels}")

# t = time.perf_counter()
# sub = arr[::10, :, :]
# labels = set(int(label) for label in np.unique(sub[sub > 0]))
# sub = arr[:, ::10, :]
# labels.update(int(label) for label in np.unique(sub[sub > 0]))
# sub = arr[:, :, :10]
# labels.update(int(label) for label in np.unique(sub[sub > 0]))
# print(f"Time 10th slices {time.perf_counter() - t:.3f}s {list(labels)}")

# t = time.perf_counter()
# sub = arr[::50, :, :]
# labels = set(int(label) for label in np.unique(sub[sub > 0]))
# sub = arr[:, ::50, :]
# labels.update(int(label) for label in np.unique(sub[sub > 0]))
# sub = arr[:, :, :50]
# labels.update(int(label) for label in np.unique(sub[sub > 0]))
# print(f"Time 50th slices {time.perf_counter() - t:.3f}s {list(labels)}")

label_values = [{"id": label, "label": f"{label}"} for label in labels]
self.root_group.attrs["labels_metadata"] = {"version": "1.0", "labels": label_values}

# Write the pyramid to the zarr store
return ome_zarr.writer.write_multiscale(
pyramid,
Expand Down Expand Up @@ -344,12 +387,18 @@ def pyramid_to_omezarr(
zarrdir: str,
write: bool = True,
pyramid_voxel_spacing: List[Tuple[float, float, float]] = None,
store_labels_metadata: bool = False,
) -> str:
destination_zarrdir = fs.destformat(zarrdir)
# Write zarr data as 256^3 voxel chunks
if write:
writer = ZarrWriter(fs, destination_zarrdir)
writer.write_data(pyramid, voxel_spacing=pyramid_voxel_spacing, chunk_size=(256, 256, 256))
writer.write_data(
pyramid,
voxel_spacing=pyramid_voxel_spacing,
chunk_size=(256, 256, 256),
store_labels_metadata=store_labels_metadata,
)
else:
print(f"skipping remote push for {destination_zarrdir}")
return os.path.basename(zarrdir)
Expand Down Expand Up @@ -446,6 +495,48 @@ def has_label(self) -> bool:
return bool(np.any(self.volume_reader.get_pyramid_base_data() == self.label))


class MultiLabelMaskConverter(TomoConverter):
def __init__(
self,
fs: FileSystemApi,
filename: str,
header_only: bool = False,
scale_0_dims: tuple[int, int, int] | None = None,
):
super().__init__(fs=fs, filename=filename, header_only=header_only, scale_0_dims=scale_0_dims)

def get_pyramid_base_data(self) -> np.ndarray:
data = self.volume_reader.get_pyramid_base_data()

if not self.scale_0_dims:
return self.scaled_data_transformation(data)

from scipy.ndimage import zoom

x, y, z = data.shape
nx, ny, nz = self.scale_0_dims
zoom_factor = (nx / x, ny / y, nz / z)

# rescaled = rescale(
# data,
# scale=zoom_factor,
# order=0,
# preserve_range=True,
# anti_aliasing=False,
# )

rescaled = zoom(data, zoom=zoom_factor, order=0)

return self.scaled_data_transformation(rescaled)

@classmethod
def scaled_data_transformation(cls, data: np.ndarray) -> np.ndarray:
# For instance segmentation masks we have multiple labels, so we want an uint 32 output.
# downscale_local_mean will return float array even for bool input with non-binary values
# return data.astype(np.uint32)
return data


def get_volume_metadata(config: DepositionImportConfig, output_prefix: str) -> dict[str, Any]:
# Generates metadata related to volume files.
scales = []
Expand Down Expand Up @@ -494,7 +585,10 @@ def get_converter(
label: int | None = None,
scale_0_dims: tuple[int, int, int] | None = None,
threshold: float | None = None,
multilabels: bool = False,
) -> TomoConverter | MaskConverter:
if multilabels:
return MultiLabelMaskConverter(fs, tomo_filename, scale_0_dims=scale_0_dims)
if label is not None:
return MaskConverter(fs, tomo_filename, label, scale_0_dims=scale_0_dims, threshold=threshold)
return TomoConverter(fs, tomo_filename, scale_0_dims=scale_0_dims)
Expand All @@ -512,15 +606,17 @@ def make_pyramids(
label: int = None,
scale_0_dims=None,
threshold: float | None = None,
multilabels: bool = False,
):
tc = get_converter(fs, tomo_filename, label, scale_0_dims, threshold)
tc = get_converter(fs, tomo_filename, label, scale_0_dims, threshold, multilabels=multilabels)
pyramid, pyramid_voxel_spacing = tc.make_pyramid(scale_z_axis=scale_z_axis, voxel_spacing=voxel_spacing)
_ = tc.pyramid_to_omezarr(
fs,
pyramid,
f"{output_prefix}.zarr",
write_zarr,
pyramid_voxel_spacing=pyramid_voxel_spacing,
store_labels_metadata=multilabels,
)
_ = tc.pyramid_to_mrc(fs, pyramid, f"{output_prefix}.mrc", write_mrc, header_mapper, voxel_spacing)

Expand Down
38 changes: 38 additions & 0 deletions ingestion_tools/scripts/importers/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ def _instantiate(
anno = PointAnnotation(**instance_args)
if shape == "InstanceSegmentation":
anno = InstanceSegmentationAnnotation(**instance_args)
if shape == "InstanceSegmentationMask":
anno = InstanceSegmentationMaskAnnotation(**instance_args)
if shape == "TriangularMesh":
anno = TriangularMeshAnnotation(**instance_args)
if shape == "TriangularMeshGroup":
Expand Down Expand Up @@ -316,6 +318,42 @@ def convert(self, output_prefix: str):
)


class InstanceSegmentationMaskAnnotation(VolumeAnnotationSource):
shape = "InstanceSegmentationMask"
mask_label: int
scale_factor: float
is_portal_standard: bool

def __init__(
self,
mask_label: int | None = None,

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need to keep mask_label or can it be removed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, good catch, I removed it from the template, but I forgot to remove it from there

scale_factor: float = 1.0,
is_portal_standard: bool = False,
*args,
**kwargs,
) -> None:
super().__init__(*args, **kwargs)
self.mask_label = mask_label if mask_label else 1
self.scale_factor = scale_factor
self.is_portal_standard = is_portal_standard

def convert(self, output_prefix: str):
# output_dims = self.get_output_dim() if self.rescale else None
output_dims = self.get_output_dim()
# output_dims = None

return make_pyramids(
self.config.fs,
self.get_output_filename(output_prefix),
self.path,
write_mrc=False,
write_zarr=self.config.write_zarr,
voxel_spacing=self.get_voxel_spacing().as_float(),
scale_0_dims=output_dims,
multilabels=True,
)


class SemanticSegmentationMaskAnnotation(VolumeAnnotationSource):
shape = "SegmentationMask" # Don't expose SemanticSegmentationMask to the public portal.
mask_label: int
Expand Down
57 changes: 44 additions & 13 deletions ingestion_tools/scripts/importers/visualization_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from common import colors
from common.colors import generate_hash, to_base_hash_input
from common.finders import DefaultImporterFactory
from common.image import VolumeInfo
from common.image import VolumeInfo, ZarrReader
from common.metadata import NeuroglancerMetadata
from importers.annotation import OrientedPointAnnotation
from importers.base_importer import BaseImporter
Expand Down Expand Up @@ -74,8 +74,9 @@ def _to_segmentation_mask_layer(
source_path: str,
file_metadata: dict[str, Any],
name_prefix: str,
color: str,
color: str | dict[int, str],
resolution: tuple[float, float, float],
visible_segments: tuple[int, ...] = (1,),
**kwargs,
) -> dict[str, Any]:
return state_generator.generate_segmentation_mask_layer(
Expand All @@ -85,6 +86,7 @@ def _to_segmentation_mask_layer(
color=color,
scale=resolution,
is_visible=file_metadata.get("is_visualization_default"),
visible_segments=visible_segments,
)

def _to_point_layer(
Expand Down Expand Up @@ -166,6 +168,7 @@ def get_annotation_layer_info(self, alignment_metadata_path: str) -> dict[str, A
"InstanceSegmentation",
"TriangularMesh",
"TriangularMeshGroup",
"InstanceSegmentationMask",
}:
print(f"Skipping file with unknown shape {shape}")
continue
Expand All @@ -177,26 +180,36 @@ def get_annotation_layer_info(self, alignment_metadata_path: str) -> dict[str, A
print(f"Skipping file with unsupported format {file.get('format')}")
continue

nb_colors = 1
visible_segments = None
if shape == "InstanceSegmentationMask":
# We load the ome zarr file and get the unique non zero labels and then set of those as visible
visible_segments = self._get_labels(file.get("path"))
nb_colors = len(visible_segments)

color_seed = generate_hash({**annotation_hash_input, **{"shape": shape}})
hex_colors, float_colors = colors.get_hex_colors(1, exclude=colors_used, seed=color_seed)
hex_colors, float_colors = colors.get_hex_colors(nb_colors, exclude=colors_used, seed=color_seed)

path = self.config.to_formatted_path(
os.path.join(precompute_path, f"{metadata_file_name}_{shape.lower()}"),
)

is_instance_seg = shape == "InstanceSegmentation"
is_instance_seg = shape == "InstanceSegmentation" or shape == "InstanceSegmentationMask"

annotation_layer_info[file.get("path")] = {
args = {
"source_path": path,
"file_metadata": file,
"name_prefix": name_prefix,
"color": hex_colors[0],
"shape": shape,
"args": {
"source_path": path,
"file_metadata": file,
"name_prefix": name_prefix,
"color": hex_colors[0],
"shape": shape,
},
}

if shape == "InstanceSegmentationMask":
args["visible_segments"] = visible_segments
args["color"] = dict(zip(visible_segments, hex_colors))

annotation_layer_info[file.get("path")] = {"shape": shape, "args": args}

if not is_instance_seg:
colors_used.append(float_colors[0])

Expand Down Expand Up @@ -226,6 +239,23 @@ def _has_oriented_mesh(self, path: str):
mesh_folder_path = os.path.join(self.config.output_prefix, oriented_mesh_filename)
return fs.exists(mesh_folder_path)

def _get_labels(self, path: str):
segmentation_filename = os.path.join(self.config.output_prefix, path)

reader = ZarrReader(self.config.fs, segmentation_filename)
try:
labels_info = reader.attrs.get("labels_metadata", {})["labels"]
labels = [label["id"] for label in labels_info]
except Exception:
# Get labels iterating by chunks over the tab
# We lazy import dask and numpy
import dask.array as da
import numpy as np

arr = reader.get_data()
labels = set(da.unique(arr[arr > 0]).compute().astype(np.integer))
return tuple(labels)

def _create_config(self, alignment_metadata_path: str) -> dict[str, Any]:
tomogram = self.get_tomogram()
volume_info = tomogram.get_output_volume_info()
Expand All @@ -237,6 +267,7 @@ def _create_config(self, alignment_metadata_path: str) -> dict[str, Any]:
t = time()
print("Start contrast limit computation for", tomogram)
contrast_limits = tomogram.get_contrast_limits()

print(f"Computed contrast limit {contrast_limits} in {(time() - t):.2f}s")
layers = [self._to_tomogram_layer(tomogram, volume_info, resolution, contrast_limits)]

Expand All @@ -245,7 +276,7 @@ def _create_config(self, alignment_metadata_path: str) -> dict[str, Any]:
for _, info in annotation_layer_info.items():
args = {**info["args"], "resolution": resolution}
shape = info["shape"]
if shape == "SegmentationMask":
if shape == "SegmentationMask" or shape == "InstanceSegmentationMask":
layers.append(self._to_segmentation_mask_layer(**args))
elif shape in {"Point", "OrientedPoint", "InstanceSegmentation"}:
if shape == "OrientedPoint":
Expand Down
Loading