From d13c014a53454df8f921ffdd57ea0ebdf773cf71 Mon Sep 17 00:00:00 2001 From: Jeff Rhoades <37990507+rhoadesScholar@users.noreply.github.com> Date: Thu, 30 Jan 2025 12:29:39 -0500 Subject: [PATCH 1/2] feat: add flow command for live predictions --- examples/train_3D.py | 22 ++++++------ pyproject.toml | 1 + .../cli/__init__.py | 3 +- .../cli/visualize.py | 34 +++++++++++++++++++ src/cellmap_segmentation_challenge/config.py | 1 - 5 files changed, 48 insertions(+), 13 deletions(-) diff --git a/examples/train_3D.py b/examples/train_3D.py index 0b65103..9b684ce 100644 --- a/examples/train_3D.py +++ b/examples/train_3D.py @@ -37,24 +37,24 @@ iterations_per_epoch = 1000 # number of iterations per epoch random_seed = 42 # random seed for reproducibility -# classes = ["nuc", "er"] # list of classes to segment -classes = get_tested_classes() # list of classes to segment +classes = ["nuc", "er", "mito"] # sample list of classes to segment +# classes = get_tested_classes() # list of all classes in the challenge -# # Defining model (comment out all that are not used) -# # 3D UNet -# model_name = "3d_unet" # name of the model to use -# model_to_load = "3d_unet" # name of the pre-trained model to load -# model = UNet_3D(1, len(classes)) +# Defining model (comment out all that are not used) +# 3D UNet +model_name = "3d_unet" # name of the model to use +model_to_load = "3d_unet" # name of the pre-trained model to load +model = UNet_3D(1, len(classes)) # 3D ResNet # model_name = "3d_resnet" # name of the model to use # model_to_load = "3d_resnet" # name of the pre-trained model to load # model = ResNet(ndims=3, output_nc=len(classes)) -# 3D ViT VNet -model_name = "3d_vnet" # name of the model to use -model_to_load = "3d_vnet" # name of the pre-trained model to load -model = ViTVNet(len(classes), img_size=input_array_info["shape"]) +# # 3D ViT VNet +# model_name = "3d_vnet" # name of the model to use +# model_to_load = "3d_vnet" # name of the pre-trained model to load +# model = ViTVNet(len(classes), img_size=input_array_info["shape"]) load_model = "latest" # load the latest model or the best validation model diff --git a/pyproject.toml b/pyproject.toml index 2a79921..0a77d1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "cellmap-schemas==0.8.0", "click>=8, <9", "cellmap-data@git+https://github.com/janelia-cellmap/cellmap-data", + "cellmap-flow", "tqdm", "zarr < 3.0.0", "xarray-ome-ngff >= 3.1.1, <=4.0.0", diff --git a/src/cellmap_segmentation_challenge/cli/__init__.py b/src/cellmap_segmentation_challenge/cli/__init__.py index 672ede2..7d406a0 100644 --- a/src/cellmap_segmentation_challenge/cli/__init__.py +++ b/src/cellmap_segmentation_challenge/cli/__init__.py @@ -6,7 +6,7 @@ from .predict import predict_cli from .process import process_cli from .train import train_cli -from .visualize import visualize_cli +from .visualize import visualize_cli, flow from .speedtest import speedtest_cli @@ -31,3 +31,4 @@ def echo(): run.add_command(visualize_cli, name="visualize") run.add_command(package_submission_cli, name="pack-results") run.add_command(speedtest_cli, name="speedtest") +run.add_command(flow, name="flow") diff --git a/src/cellmap_segmentation_challenge/cli/visualize.py b/src/cellmap_segmentation_challenge/cli/visualize.py index 920f54f..4f00813 100644 --- a/src/cellmap_segmentation_challenge/cli/visualize.py +++ b/src/cellmap_segmentation_challenge/cli/visualize.py @@ -1,4 +1,6 @@ import click +import os +from cellmap_segmentation_challenge import SEARCH_PATH, RAW_NAME @click.command @@ -58,3 +60,35 @@ def visualize_cli(datasets, crops, classes, kinds): classes=classes.split(","), kinds=kinds.split(","), ) + + +@click.command +@click.option( + "--script_path", + "-s", + type=click.STRING, + required=True, + help="Path to the script to run for live prediction.", +) +@click.option( + "--dataset", + "-d", + type=click.STRING, + required=True, + help="Dataset to view (Example: 'jrc_cos7-1a')", +) +def flow(script_path, dataset): + """ + Run a cellmap-flow to visualize live predictions using a script defining a model config, visualizing the results in Neuroglancer. + + Parameters + ---------- + script_path : str + Path to the script defining the model config (e.g. `examples/train_2D.py`). + dataset : str + Dataset to view (Example: 'jrc_cos7-1a'), + """ + + dataset_path = SEARCH_PATH.format(dataset=dataset, name=RAW_NAME) + + os.system(f"cellmap_flow script -s {script_path} -d {dataset_path}") diff --git a/src/cellmap_segmentation_challenge/config.py b/src/cellmap_segmentation_challenge/config.py index 11e154a..7ab7222 100644 --- a/src/cellmap_segmentation_challenge/config.py +++ b/src/cellmap_segmentation_challenge/config.py @@ -21,7 +21,6 @@ TRUTH_PATH = (BASE_DATA_PATH / "ground_truth.zarr").path # s3 paths -# GT_S3_BUCKET = "janelia-cellmap-fg5f2y1pl8" GT_S3_BUCKET = "janelia-cosem-datasets" RAW_S3_BUCKET = "janelia-cosem-datasets" S3_SEARCH_PATH = "{dataset}/{dataset}.zarr/recon-1/{name}" From 343190214778c61738a72814e91b5d714d008200 Mon Sep 17 00:00:00 2001 From: rhoadesScholar Date: Thu, 30 Jan 2025 14:13:42 -0500 Subject: [PATCH 2/2] add level specification to flow --- .../cli/visualize.py | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/src/cellmap_segmentation_challenge/cli/visualize.py b/src/cellmap_segmentation_challenge/cli/visualize.py index 4f00813..3739ccf 100644 --- a/src/cellmap_segmentation_challenge/cli/visualize.py +++ b/src/cellmap_segmentation_challenge/cli/visualize.py @@ -1,6 +1,10 @@ import click import os + +from upath import UPath from cellmap_segmentation_challenge import SEARCH_PATH, RAW_NAME +from cellmap_segmentation_challenge.utils import load_safe_config +from cellmap_data.utils import find_level @click.command @@ -77,7 +81,15 @@ def visualize_cli(datasets, crops, classes, kinds): required=True, help="Dataset to view (Example: 'jrc_cos7-1a')", ) -def flow(script_path, dataset): +@click.option( + "--level", + "-l", + type=click.STRING, + required=False, + default=None, + help="(Optional) Scale level to feed to model (Example: 's0'). If not specified, will be inferred from input_array_info in the config script.", +) +def flow(script_path, dataset, level): """ Run a cellmap-flow to visualize live predictions using a script defining a model config, visualizing the results in Neuroglancer. @@ -87,8 +99,19 @@ def flow(script_path, dataset): Path to the script defining the model config (e.g. `examples/train_2D.py`). dataset : str Dataset to view (Example: 'jrc_cos7-1a'), + level : str + (Optional) Scale level to feed to model (Example: 's0'). If not specified, will be inferred from input_array_info in the config script. """ + config = load_safe_config(script_path) + dataset_path = SEARCH_PATH.format(dataset=dataset, name=RAW_NAME) + if level is None: + level = find_level( + dataset_path, + {k: v for k, v in zip("zyx", config.input_array_info["scale"])}, + ) - os.system(f"cellmap_flow script -s {script_path} -d {dataset_path}") + os.system( + f"cellmap_flow script -s {script_path} -d {(UPath(dataset_path) / level).path}" + )