diff --git a/src/vpt_plugin_cellpose/__init__.py b/src/vpt_plugin_cellpose/__init__.py index 5d6c8e1..4e47caa 100644 --- a/src/vpt_plugin_cellpose/__init__.py +++ b/src/vpt_plugin_cellpose/__init__.py @@ -8,6 +8,7 @@ class CellposeSegProperties: model_dimensions: str version: str custom_weights: Optional[str] = None + use_gpu: bool = False @dataclass(frozen=True) @@ -18,3 +19,4 @@ class CellposeSegParameters: flow_threshold: float mask_threshold: float minimum_mask_size: int + batch_size: Optional[int] = 8 # default in Cellpose v1* diff --git a/src/vpt_plugin_cellpose/predict.py b/src/vpt_plugin_cellpose/predict.py index d223784..9181b48 100644 --- a/src/vpt_plugin_cellpose/predict.py +++ b/src/vpt_plugin_cellpose/predict.py @@ -26,9 +26,9 @@ def run(images: ImageSet, properties: CellposeSegProperties, parameters: Cellpos return np.zeros((image.shape[0],) + image.shape[1:-1]) if properties.custom_weights: - model = models.CellposeModel(gpu=False, pretrained_model=properties.custom_weights, net_avg=False) + model = models.CellposeModel(gpu=properties.use_gpu, pretrained_model=properties.custom_weights, net_avg=False) else: - model = models.Cellpose(gpu=False, model_type=properties.model, net_avg=False) + model = models.Cellpose(gpu=properties.use_gpu, model_type=properties.model, net_avg=False) to_segment_z = list(set(range(image.shape[0])).difference(empty_z_levels)) mask = model.eval( @@ -42,6 +42,7 @@ def run(images: ImageSet, properties: CellposeSegProperties, parameters: Cellpos min_size=parameters.minimum_mask_size, tile=True, do_3D=(properties.model_dimensions == "3D"), + batch_size=parameters.batch_size, )[0] mask = mask.reshape((len(to_segment_z),) + image.shape[1:-1]) for i in empty_z_levels: