-
Notifications
You must be signed in to change notification settings - Fork 0
feat: Add support/config for InstanceSegmentationMask #5
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cbb334f
c4aa044
c4527da
2bb7506
4a3d320
bdf0085
5d4bb52
598e594
00a379d
6dec931
5d6ac85
6101bb2
129fde5
0cf3ffe
7eccd55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,12 +53,24 @@ class ZarrReader: | |
def __init__(self, fs, zarrdir): | ||
self.fs = fs | ||
self.zarrdir = zarrdir | ||
self._loc = None | ||
|
||
def get_data(self): | ||
loc = ome_zarr.io.ZarrLocation(self.fs.destformat(self.zarrdir)) | ||
loc = self._load_zarr_loc() | ||
data = loc.load("0") | ||
return data | ||
|
||
@property | ||
def attrs(self): | ||
loc = self._load_zarr_loc() | ||
group = zarr.group(loc.store) | ||
return group.attrs | ||
|
||
def _load_zarr_loc(self): | ||
if self._loc is None: | ||
self._loc = ome_zarr.io.ZarrLocation(self.fs.destformat(self.zarrdir)) | ||
return self._loc | ||
|
||
|
||
class ZarrWriter: | ||
def __init__(self, fs: FileSystemApi, zarrdir: str): | ||
|
@@ -97,6 +109,7 @@ def write_data( | |
voxel_spacing: List[Tuple[float, float, float]], | ||
chunk_size: Tuple[int, int, int] = (256, 256, 256), | ||
scale_z_axis: bool = True, | ||
store_labels_metadata: bool = False, | ||
): | ||
pyramid = [] | ||
scales = [] | ||
|
@@ -110,6 +123,36 @@ def write_data( | |
pyramid.append(d) | ||
scales.append(self.ome_zarr_transforms(vs)) | ||
|
||
# Store the labels contained in the data if the flag is activated | ||
if store_labels_metadata: | ||
|
||
arr = data[0] | ||
|
||
# t = time.perf_counter() | ||
labels = [int(label) for label in np.unique(arr[arr > 0])] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I tested this because I actually expected it to make a bigger difference. From what I can tell, it is slightly more efficient to take the uniques over the unfiltered array and then filter out all labels <= 0 from the result in the list comprehension afterwards. |
||
# print(f"Time full image {time.perf_counter() - t:.3f}s {labels}") | ||
|
||
# t = time.perf_counter() | ||
# sub = arr[::10, :, :] | ||
# labels = set(int(label) for label in np.unique(sub[sub > 0])) | ||
# sub = arr[:, ::10, :] | ||
# labels.update(int(label) for label in np.unique(sub[sub > 0])) | ||
# sub = arr[:, :, :10] | ||
# labels.update(int(label) for label in np.unique(sub[sub > 0])) | ||
# print(f"Time 10th slices {time.perf_counter() - t:.3f}s {list(labels)}") | ||
|
||
# t = time.perf_counter() | ||
# sub = arr[::50, :, :] | ||
# labels = set(int(label) for label in np.unique(sub[sub > 0])) | ||
# sub = arr[:, ::50, :] | ||
# labels.update(int(label) for label in np.unique(sub[sub > 0])) | ||
# sub = arr[:, :, :50] | ||
# labels.update(int(label) for label in np.unique(sub[sub > 0])) | ||
# print(f"Time 50th slices {time.perf_counter() - t:.3f}s {list(labels)}") | ||
|
||
label_values = [{"id": label, "label": f"{label}"} for label in labels] | ||
self.root_group.attrs["labels_metadata"] = {"version": "1.0", "labels": label_values} | ||
|
||
# Write the pyramid to the zarr store | ||
return ome_zarr.writer.write_multiscale( | ||
pyramid, | ||
|
@@ -344,12 +387,18 @@ def pyramid_to_omezarr( | |
zarrdir: str, | ||
write: bool = True, | ||
pyramid_voxel_spacing: List[Tuple[float, float, float]] = None, | ||
store_labels_metadata: bool = False, | ||
) -> str: | ||
destination_zarrdir = fs.destformat(zarrdir) | ||
# Write zarr data as 256^3 voxel chunks | ||
if write: | ||
writer = ZarrWriter(fs, destination_zarrdir) | ||
writer.write_data(pyramid, voxel_spacing=pyramid_voxel_spacing, chunk_size=(256, 256, 256)) | ||
writer.write_data( | ||
pyramid, | ||
voxel_spacing=pyramid_voxel_spacing, | ||
chunk_size=(256, 256, 256), | ||
store_labels_metadata=store_labels_metadata, | ||
) | ||
else: | ||
print(f"skipping remote push for {destination_zarrdir}") | ||
return os.path.basename(zarrdir) | ||
|
@@ -446,6 +495,48 @@ def has_label(self) -> bool: | |
return bool(np.any(self.volume_reader.get_pyramid_base_data() == self.label)) | ||
|
||
|
||
class MultiLabelMaskConverter(TomoConverter): | ||
def __init__( | ||
self, | ||
fs: FileSystemApi, | ||
filename: str, | ||
header_only: bool = False, | ||
scale_0_dims: tuple[int, int, int] | None = None, | ||
): | ||
super().__init__(fs=fs, filename=filename, header_only=header_only, scale_0_dims=scale_0_dims) | ||
|
||
def get_pyramid_base_data(self) -> np.ndarray: | ||
data = self.volume_reader.get_pyramid_base_data() | ||
|
||
if not self.scale_0_dims: | ||
return self.scaled_data_transformation(data) | ||
|
||
from scipy.ndimage import zoom | ||
|
||
x, y, z = data.shape | ||
nx, ny, nz = self.scale_0_dims | ||
zoom_factor = (nx / x, ny / y, nz / z) | ||
|
||
# rescaled = rescale( | ||
# data, | ||
# scale=zoom_factor, | ||
# order=0, | ||
# preserve_range=True, | ||
# anti_aliasing=False, | ||
# ) | ||
|
||
rescaled = zoom(data, zoom=zoom_factor, order=0) | ||
|
||
return self.scaled_data_transformation(rescaled) | ||
|
||
@classmethod | ||
def scaled_data_transformation(cls, data: np.ndarray) -> np.ndarray: | ||
# For instance segmentation masks we have multiple labels, so we want a uint32 output. | ||
# downscale_local_mean will return float array even for bool input with non-binary values | ||
# return data.astype(np.uint32) | ||
return data | ||
|
||
|
||
def get_volume_metadata(config: DepositionImportConfig, output_prefix: str) -> dict[str, Any]: | ||
# Generates metadata related to volume files. | ||
scales = [] | ||
|
@@ -494,7 +585,10 @@ def get_converter( | |
label: int | None = None, | ||
scale_0_dims: tuple[int, int, int] | None = None, | ||
threshold: float | None = None, | ||
multilabels: bool = False, | ||
) -> TomoConverter | MaskConverter: | ||
if multilabels: | ||
return MultiLabelMaskConverter(fs, tomo_filename, scale_0_dims=scale_0_dims) | ||
if label is not None: | ||
return MaskConverter(fs, tomo_filename, label, scale_0_dims=scale_0_dims, threshold=threshold) | ||
return TomoConverter(fs, tomo_filename, scale_0_dims=scale_0_dims) | ||
|
@@ -512,15 +606,17 @@ def make_pyramids( | |
label: int = None, | ||
scale_0_dims=None, | ||
threshold: float | None = None, | ||
multilabels: bool = False, | ||
): | ||
tc = get_converter(fs, tomo_filename, label, scale_0_dims, threshold) | ||
tc = get_converter(fs, tomo_filename, label, scale_0_dims, threshold, multilabels=multilabels) | ||
pyramid, pyramid_voxel_spacing = tc.make_pyramid(scale_z_axis=scale_z_axis, voxel_spacing=voxel_spacing) | ||
_ = tc.pyramid_to_omezarr( | ||
fs, | ||
pyramid, | ||
f"{output_prefix}.zarr", | ||
write_zarr, | ||
pyramid_voxel_spacing=pyramid_voxel_spacing, | ||
store_labels_metadata=multilabels, | ||
) | ||
_ = tc.pyramid_to_mrc(fs, pyramid, f"{output_prefix}.mrc", write_mrc, header_mapper, voxel_spacing) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -107,6 +107,8 @@ def _instantiate( | |
anno = PointAnnotation(**instance_args) | ||
if shape == "InstanceSegmentation": | ||
anno = InstanceSegmentationAnnotation(**instance_args) | ||
if shape == "InstanceSegmentationMask": | ||
anno = InstanceSegmentationMaskAnnotation(**instance_args) | ||
if shape == "TriangularMesh": | ||
anno = TriangularMeshAnnotation(**instance_args) | ||
if shape == "TriangularMeshGroup": | ||
|
@@ -316,6 +318,42 @@ def convert(self, output_prefix: str): | |
) | ||
|
||
|
||
class InstanceSegmentationMaskAnnotation(VolumeAnnotationSource): | ||
shape = "InstanceSegmentationMask" | ||
mask_label: int | ||
scale_factor: float | ||
is_portal_standard: bool | ||
|
||
def __init__( | ||
self, | ||
mask_label: int | None = None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need to keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah yes, good catch — I removed it from the template, but I forgot to remove it from there. |
||
scale_factor: float = 1.0, | ||
is_portal_standard: bool = False, | ||
*args, | ||
**kwargs, | ||
) -> None: | ||
super().__init__(*args, **kwargs) | ||
self.mask_label = mask_label if mask_label else 1 | ||
self.scale_factor = scale_factor | ||
self.is_portal_standard = is_portal_standard | ||
|
||
def convert(self, output_prefix: str): | ||
# output_dims = self.get_output_dim() if self.rescale else None | ||
aranega marked this conversation as resolved.
Show resolved
Hide resolved
|
||
output_dims = self.get_output_dim() | ||
# output_dims = None | ||
|
||
return make_pyramids( | ||
self.config.fs, | ||
self.get_output_filename(output_prefix), | ||
self.path, | ||
write_mrc=False, | ||
write_zarr=self.config.write_zarr, | ||
voxel_spacing=self.get_voxel_spacing().as_float(), | ||
scale_0_dims=output_dims, | ||
multilabels=True, | ||
) | ||
|
||
|
||
class SemanticSegmentationMaskAnnotation(VolumeAnnotationSource): | ||
shape = "SegmentationMask" # Don't expose SemanticSegmentationMask to the public portal. | ||
mask_label: int | ||
|
Uh oh!
There was an error while loading. Please reload this page.