Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions examples/train_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,24 +37,24 @@
iterations_per_epoch = 1000 # number of iterations per epoch
random_seed = 42 # random seed for reproducibility

# classes = ["nuc", "er"] # list of classes to segment
classes = get_tested_classes() # list of classes to segment
classes = ["nuc", "er", "mito"] # sample list of classes to segment
# classes = get_tested_classes() # list of all classes in the challenge

# # Defining model (comment out all that are not used)
# # 3D UNet
# model_name = "3d_unet" # name of the model to use
# model_to_load = "3d_unet" # name of the pre-trained model to load
# model = UNet_3D(1, len(classes))
# Defining model (comment out all that are not used)
# 3D UNet
model_name = "3d_unet" # name of the model to use
model_to_load = "3d_unet" # name of the pre-trained model to load
model = UNet_3D(1, len(classes))

# 3D ResNet
# model_name = "3d_resnet" # name of the model to use
# model_to_load = "3d_resnet" # name of the pre-trained model to load
# model = ResNet(ndims=3, output_nc=len(classes))

# 3D ViT VNet
model_name = "3d_vnet" # name of the model to use
model_to_load = "3d_vnet" # name of the pre-trained model to load
model = ViTVNet(len(classes), img_size=input_array_info["shape"])
# # 3D ViT VNet
# model_name = "3d_vnet" # name of the model to use
# model_to_load = "3d_vnet" # name of the pre-trained model to load
# model = ViTVNet(len(classes), img_size=input_array_info["shape"])

load_model = "latest" # load the latest model or the best validation model

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ dependencies = [
"cellmap-schemas==0.8.0",
"click>=8, <9",
"cellmap-data@git+https://github.com/janelia-cellmap/cellmap-data",
"cellmap-flow",
"tqdm",
"zarr < 3.0.0",
"xarray-ome-ngff >= 3.1.1, <=4.0.0",
Expand Down
3 changes: 2 additions & 1 deletion src/cellmap_segmentation_challenge/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .predict import predict_cli
from .process import process_cli
from .train import train_cli
from .visualize import visualize_cli
from .visualize import visualize_cli, flow
from .speedtest import speedtest_cli


Expand All @@ -31,3 +31,4 @@ def echo():
run.add_command(visualize_cli, name="visualize")
run.add_command(package_submission_cli, name="pack-results")
run.add_command(speedtest_cli, name="speedtest")
run.add_command(flow, name="flow")
57 changes: 57 additions & 0 deletions src/cellmap_segmentation_challenge/cli/visualize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import click
import os

from upath import UPath
from cellmap_segmentation_challenge import SEARCH_PATH, RAW_NAME
from cellmap_segmentation_challenge.utils import load_safe_config
from cellmap_data.utils import find_level


@click.command
Expand Down Expand Up @@ -58,3 +64,54 @@ def visualize_cli(datasets, crops, classes, kinds):
classes=classes.split(","),
kinds=kinds.split(","),
)


@click.command
@click.option(
"--script_path",
"-s",
type=click.STRING,
required=True,
help="Path to the script to run for live prediction.",
)
@click.option(
"--dataset",
"-d",
type=click.STRING,
required=True,
help="Dataset to view (Example: 'jrc_cos7-1a')",
)
@click.option(
"--level",
"-l",
type=click.STRING,
required=False,
default=None,
help="(Optional) Scale level to feed to model (Example: 's0'). If not specified, will be inferred from input_array_info in the config script.",
)
def flow(script_path, dataset, level):
"""
Run a cellmap-flow to visualize live predictions using a script defining a model config, visualizing the results in Neuroglancer.

Parameters
----------
script_path : str
Path to the script defining the model config (e.g. `examples/train_2D.py`).
dataset : str
Dataset to view (Example: 'jrc_cos7-1a'),
level : str
(Optional) Scale level to feed to model (Example: 's0'). If not specified, will be inferred from input_array_info in the config script.
"""

config = load_safe_config(script_path)

dataset_path = SEARCH_PATH.format(dataset=dataset, name=RAW_NAME)
if level is None:
level = find_level(
dataset_path,
{k: v for k, v in zip("zyx", config.input_array_info["scale"])},
)

os.system(
f"cellmap_flow script -s {script_path} -d {(UPath(dataset_path) / level).path}"
)
1 change: 0 additions & 1 deletion src/cellmap_segmentation_challenge/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
TRUTH_PATH = (BASE_DATA_PATH / "ground_truth.zarr").path

# s3 paths
# GT_S3_BUCKET = "janelia-cellmap-fg5f2y1pl8"
GT_S3_BUCKET = "janelia-cosem-datasets"
RAW_S3_BUCKET = "janelia-cosem-datasets"
S3_SEARCH_PATH = "{dataset}/{dataset}.zarr/recon-1/{name}"
Expand Down
Loading