Skip to content

Commit d5d1567

Browse files
Threshold filter
1 parent 0d5e769 commit d5d1567

File tree

6 files changed

+84
-8
lines changed

6 files changed

+84
-8
lines changed

scip_tiff.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ masking:
1515
bbox_channel_index: 0
1616
export: false
1717
kwargs:
18+
smooth: 1
1819
normalization:
1920
lower: 0
2021
upper: 1

scip_zarr.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ masking:
1515
export: false
1616
kwargs:
1717
noisy_channels: [0]
18+
filter:
1819
normalization:
1920
export:
2021
format: parquet

src/scip/filters/__init__.py

Whitespace-only changes.

src/scip/filters/threshold.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Mapping
2+
3+
from scip.utils.util import copy_without
4+
import dask.bag
5+
import scipy.stats
6+
7+
8+
def feature_partition(part):
9+
for i in range(len(part)):
10+
if "pixels" in part[i]:
11+
part[i]["filter_sum"] = part[i]["pixels"][0].sum()
12+
return part
13+
14+
15+
def item(bag: dask.bag.Bag) -> Mapping[str, dask.bag.Item]:
16+
bag = bag.filter(lambda a: "filter_sum" in a).map(lambda a: a["filter_sum"])
17+
mu = bag.mean()
18+
std = bag.std()
19+
return dict(mu=mu, std=std)
20+
21+
22+
def predicate(x, *, mu, std):
23+
q5 = scipy.stats.norm.ppf(0.05, loc=mu, scale=std)
24+
if ("filter_sum" in x) and (x["filter_sum"] > q5):
25+
return x
26+
else:
27+
return copy_without(x, without=["mask", "pixels"])

src/scip/loading/zarr.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,46 @@
2323
from pathlib import Path
2424
import re
2525
from typing import Tuple
26+
import copy
27+
28+
29+
def reload_image_partition(
30+
partition,
31+
channels,
32+
clip: int,
33+
regex: str,
34+
limit: int = -1
35+
):
36+
z = zarr.open(partition[0]["path"])
37+
indices = [p["zarr_idx"] for p in partition]
38+
data = z.get_coordinate_selection(indices)
39+
shapes = numpy.array(z.attrs["shape"])[indices]
40+
41+
newpartition = copy.deepcopy(partition)
42+
for i in range(len(partition)):
43+
if "mask" not in partition[i]:
44+
continue
45+
46+
if clip is not None:
47+
newpartition[i]["pixels"] = numpy.clip(data[i].reshape(shapes[i])[channels], 0, clip)
48+
else:
49+
newpartition[i]["pixels"] = data[i].reshape(shapes[i])[channels]
50+
newpartition[i]["pixels"] = newpartition[i]["pixels"].astype(numpy.float32)
51+
return newpartition
2652

2753

2854
def load_image_partition(partition, z, channels, clip):
55+
2956
start, end = partition[0]["zarr_idx"], partition[-1]["zarr_idx"]
3057
data = z[start:end + 1]
3158
shapes = z.attrs["shape"][start:end + 1]
32-
for i, event in enumerate(partition):
59+
60+
for i in range(len(partition)):
3361
if clip is not None:
34-
event["pixels"] = numpy.clip(data[i].reshape(shapes[i])[channels], 0, clip)
62+
partition[i]["pixels"] = numpy.clip(data[i].reshape(shapes[i])[channels], 0, clip)
3563
else:
36-
event["pixels"] = data[i].reshape(shapes[i])[channels]
37-
event["pixels"] = event["pixels"].astype(numpy.float32)
64+
partition[i]["pixels"] = data[i].reshape(shapes[i])[channels]
65+
partition[i]["pixels"] = partition[i]["pixels"].astype(numpy.float32)
3866
return partition
3967

4068

src/scip/main.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,10 @@ def get_images_bag(
4747
channels: List[int],
4848
config: dict,
4949
partition_size: int,
50-
gpu_accelerated: bool
50+
gpu_accelerated: bool,
51+
loader_module
5152
) -> Tuple[dask.bag.Bag, int, dict]:
5253

53-
loader_module = import_module('scip.loading.%s' % config["loading"]["format"])
5454
loader = partial(
5555
loader_module.bag_from_directory,
5656
channels=channels,
@@ -109,7 +109,7 @@ def get_schema(event):
109109

110110

111111
def remove_pixels(event):
112-
newevent = copy_without(event, ["pixels", "shape", "mask"])
112+
newevent = copy_without(event, ["pixels", "mask"])
113113
newevent["shape"] = list(event["mask"].shape)
114114
newevent["mask"] = event["mask"].ravel()
115115
return newevent
@@ -203,13 +203,15 @@ def main(
203203

204204
logger.debug("loading images in to bags")
205205

206+
loader_module = import_module('scip.loading.%s' % config["loading"]["format"])
206207
with dask.config.set(**{'array.slicing.split_large_chunks': False}):
207208
images, maximum_pixel_value, loader_meta = get_images_bag(
208209
paths=paths,
209210
channels=channels,
210211
config=config,
211212
partition_size=partition_size,
212-
gpu_accelerated=gpu > 0
213+
gpu_accelerated=gpu > 0,
214+
loader_module=loader_module
213215
)
214216

215217
futures = []
@@ -309,6 +311,23 @@ def main(
309311
output=output
310312
))
311313

314+
if config["filter"] is not None:
315+
filter_module = import_module('scip.loading.%s' % config["filter"]["name"])
316+
317+
images = images.map_partitions(filter_module.feature_partition)
318+
images = images.map(copy_without, without=["pixels"]).persist()
319+
filter_items = filter_module.grouped_item(images, key="group")
320+
321+
filter_items = filter_items["mu"].compute()
322+
323+
images = images.map(filter_module.predicate, **filter_items)
324+
325+
images = images.map_partitions(
326+
loader_module.reload_image_partition,
327+
channels=channels,
328+
**(config["loading"]["loader_kwargs"] or dict())
329+
)
330+
312331
quantiles = None
313332
if config["normalization"] is not None:
314333
logger.debug("performing normalization")

0 commit comments

Comments
 (0)