Merged
5 changes: 2 additions & 3 deletions .github/workflows/conda-build.yml
@@ -19,16 +19,15 @@ jobs:
- uses: actions/checkout@v4
- uses: conda-incubator/setup-miniconda@v3
with:
miniforge-version: latest
auto-update-conda: true
auto-activate-base: true
activate-environment: ""
channel-priority: strict
conda-remove-defaults: "true"
miniforge-variant: Miniforge3
- shell: bash -l {0}
run: conda info --envs
- name: Build pytorch-3dunet
shell: bash -l {0}
run: |
conda install -q conda-build
conda install -q conda-build conda=25.7.0
conda build -c conda-forge conda-recipe
23 changes: 23 additions & 0 deletions .github/workflows/lint.yml
@@ -0,0 +1,23 @@
name: Lint

on:
push:
branches:
- "master"
pull_request:
branches: [master]

jobs:
ruff:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Install ruff
run: pip install ruff
- name: Run ruff check
run: ruff check .
- name: Run ruff format check
run: ruff format --check .
38 changes: 22 additions & 16 deletions README.md
@@ -37,43 +37,48 @@ The format of the raw and label datasets depends on whether the problem is 2D or
| multi-channel | (C, 1, Y, X) | (C, Z, Y, X) |

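The shape conventions in the table above can be illustrated with a short NumPy sketch (the array names and sizes below are hypothetical examples, not part of the package API):

```python
import numpy as np

# Hypothetical sizes: depth, height, width, channels
Z, Y, X, C = 64, 128, 128, 2

raw_3d_single = np.zeros((Z, Y, X), dtype=np.float32)    # single-channel 3D volume
raw_3d_multi = np.zeros((C, Z, Y, X), dtype=np.float32)  # multi-channel 3D volume
raw_2d_multi = np.zeros((C, 1, Y, X), dtype=np.float32)  # multi-channel 2D: singleton Z axis

# A 2D multi-channel stack is just the 3D layout with a depth of 1
assert raw_2d_multi.shape[1] == 1
```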
## Prerequisites

- NVIDIA GPU
- CUDA CuDNN
- [Miniconda](https://www.anaconda.com/docs/getting-started/miniconda/install)
- Python 3.11+
- NVIDIA GPU (optional but recommended for training/prediction speedup)

### Running on Windows/OSX

`pytorch-3dunet` is a cross-platform package and runs on Windows and OS X as well.

## Installation

- The easiest way to install `pytorch-3dunet` package is via conda:
The easiest way to install `pytorch-3dunet` package is via conda:

```bash
conda install -c conda-forge pytorch-3dunet
```
# Create a new conda environment "3dunet" with the latest Python version from the conda-forge channel
conda create -n 3dunet python -c conda-forge -y

**Note:** The conda package does not include PyTorch dependencies. You need to install them separately in your conda environment:
# Activate the conda environment
conda activate 3dunet

```bash
# Install PyTorch with CUDA support (adjust for your CUDA version, below it's CUDA 11.8)
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
# pytorch-3dunet does not include PyTorch dependencies, so that one can install the desired PyTorch version (with/without CUDA support) separately
pip install torch torchvision
# you may need to adjust the command above depending on your GPU and the CUDA version you want to use, e.g. for CUDA 12.6:
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cu126
# or for CPU-only version:
# pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu

# Or install CPU-only version
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
# Install the latest pytorch-3dunet package from conda-forge channel
conda install -c conda-forge pytorch-3dunet
```

After installation the following commands will be accessible within the conda environment:
`train3dunet` for training the network and `predict3dunet` for prediction (see below).

- One can also install directly from source, i.e. go to the checkout directory and run:

One can also install directly from source, i.e. go to the checkout directory and run:
```
pip install -e .
```

### Installation tips
Make sure that the installed `torch` is compatible with your CUDA version, otherwise the training/prediction will fail to run on GPU.
The PyTorch package ships with its own CUDA runtime libraries, so you don't need to install CUDA separately on your system.
However, you must ensure that the PyTorch/CUDA version you choose is compatible with your GPU’s compute capability.
See [PyTorch installation guide](https://pytorch.org/get-started/locally/) for more details.
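A quick way to sanity-check your setup is to query `torch` itself. The helper below is a minimal sketch (the function name `cuda_status` is ours, not part of pytorch-3dunet); it degrades gracefully when `torch` is missing or CPU-only:

```python
def cuda_status():
    """Return a one-line summary of the installed torch / CUDA setup."""
    try:
        import torch
    except ImportError:
        return "torch is not installed"
    if torch.cuda.is_available():
        # torch.version.cuda is the CUDA runtime version torch was built against
        return f"torch {torch.__version__} with CUDA {torch.version.cuda} on {torch.cuda.get_device_name(0)}"
    return f"torch {torch.__version__} (CPU only)"


print(cuda_status())
```

If this reports "CPU only" on a machine with an NVIDIA GPU, the installed wheel was likely built without CUDA support for your driver.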

## Train
Given that `pytorch-3dunet` package was installed via conda as described above, you can train the network by simply invoking:
@@ -98,7 +103,7 @@ One can monitor the training progress with Tensorboard `tensorboard --logdir <ch
When training with `WeightedCrossEntropyLoss` or `CrossEntropyLoss`, the target dataset has to be 3D; see also the PyTorch
documentation for CE loss: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html
2. When training with `BCEWithLogitsLoss`, `DiceLoss`, `BCEDiceLoss`, `GeneralizedDiceLoss` set `final_sigmoid=True` in
the `model` part of the config so that the sigmoid is applied to the logits.
the `model` part of the config so that the sigmoid is applied to the logits in inference mode.
3. When training with cross entropy based losses (`WeightedCrossEntropyLoss`, `CrossEntropyLoss`) set
`final_sigmoid=False` so that `Softmax` normalization is applied to the logits.
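Assuming the standard YAML config layout, the two cases above differ only in the `final_sigmoid` flag of the `model` section (a minimal, hypothetical fragment):

```yaml
model:
  name: UNet3D
  # binary/Dice-style losses (BCEWithLogitsLoss, DiceLoss, BCEDiceLoss, GeneralizedDiceLoss):
  final_sigmoid: true
  # cross-entropy losses (CrossEntropyLoss, WeightedCrossEntropyLoss) use Softmax instead:
  # final_sigmoid: false
```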

@@ -303,6 +308,7 @@ pip install -e .

Tests can be run via `pytest`.
The device the tests should be run on can be specified with the `--device` argument (`cpu`, `mps`, or `cuda` - default: `cpu`).
Linting is done via `ruff` (see `pyproject.toml` for configuration).

## Release new version on `conda-forge` channel
To release a new version of `pytorch-3dunet` on the `conda-forge` channel, follow these steps:
1 change: 1 addition & 0 deletions environment.yaml
@@ -15,6 +15,7 @@ dependencies:
- scikit-image
- pyyaml
- pytest
- ruff
- pip
- pip:
- torch
24 changes: 23 additions & 1 deletion pyproject.toml
@@ -37,4 +37,26 @@ predict3dunet = "pytorch3dunet.predict:main"
version = {attr = "pytorch3dunet.__version__.__version__"}

[tool.setuptools.packages.find]
exclude = ["tests*"]
exclude = ["tests*"]

[tool.ruff]
line-length = 120
target-version = "py311"

[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
]
ignore = [
"E501", # line too long (handled by formatter)
]

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
2 changes: 1 addition & 1 deletion pytorch3dunet/__init__.py
@@ -1 +1 @@
from .__version__ import __version__
from .__version__ import __version__ # noqa: F401
2 changes: 1 addition & 1 deletion pytorch3dunet/__version__.py
@@ -1 +1 @@
__version__ = '1.9.3'
__version__ = "1.9.3"
97 changes: 58 additions & 39 deletions pytorch3dunet/augment/transformer-test.ipynb
@@ -6,15 +6,26 @@
"metadata": {},
"outputs": [],
"source": [
"from transforms import LabelToAffinities, RandomFlip, RandomRotate, RandomRotate90, RandomContrast, Normalize, ElasticDeformation, StandardLabelToBoundary, AdditiveGaussianNoise, AdditivePoissonNoise, LabelToBoundaryAndAffinities, FlyWingBoundary, LabelToBoundaryAndAffinities, Standardize\n",
"import matplotlib.pyplot as plt\n",
"import h5py\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from transforms import (\n",
" AdditiveGaussianNoise,\n",
" AdditivePoissonNoise,\n",
" ElasticDeformation,\n",
" LabelToAffinities,\n",
" RandomContrast,\n",
" RandomFlip,\n",
" RandomRotate,\n",
" RandomRotate90,\n",
" Standardize,\n",
" StandardLabelToBoundary,\n",
")\n",
"\n",
"path = '../../resources/sample_ovule.h5'\n",
"with h5py.File(path, 'r') as f:\n",
" raw = f['raw'][...]\n",
" label = f['label'][...]"
"path = \"../../resources/sample_ovule.h5\"\n",
"with h5py.File(path, \"r\") as f:\n",
" raw = f[\"raw\"][...]\n",
" label = f[\"label\"][...]"
]
},
{
@@ -168,28 +179,28 @@
],
"source": [
"# LabelToBoundary demo\n",
"path = '../resources/sample_patch.h5'\n",
"path = \"../resources/sample_patch.h5\"\n",
"\n",
"with h5py.File(path, 'r') as f:\n",
" label = f['label'][...]\n",
" t1 = LabelToAffinities(offsets=[1,4,6,8], aggregate_affinities=True)\n",
"with h5py.File(path, \"r\") as f:\n",
" label = f[\"label\"][...]\n",
" t1 = LabelToAffinities(offsets=[1, 4, 6, 8], aggregate_affinities=True)\n",
" label_transformed1 = t1(label)\n",
" \n",
"\n",
" sltb = StandardLabelToBoundary(blur=True, sigma=1.3)\n",
" label_transformed1 = np.append(label_transformed1, sltb(label), axis=0)\n",
" \n",
" fig, axes = plt.subplots(1, 6, figsize=(25, 5))\n",
"\n",
" fig, axes = plt.subplots(1, 6, figsize=(25, 5))\n",
" ax = axes.ravel()\n",
"\n",
" # show label\n",
" ax[0].set_title('Label')\n",
" ax[0].imshow(label[40, ...], cmap='flag')\n",
" ax[0].set_title(\"Label\")\n",
" ax[0].imshow(label[40, ...], cmap=\"flag\")\n",
" ax[0].set_axis_off()\n",
"\n",
" for i in range(label_transformed1.shape[0]): \n",
" ax[i+1].set_title(f'Affinities{i}')\n",
" ax[i+1].imshow(label_transformed1[i, 40, ...])\n",
" ax[i+1].set_axis_off()"
" for i in range(label_transformed1.shape[0]):\n",
" ax[i + 1].set_title(f\"Affinities{i}\")\n",
" ax[i + 1].imshow(label_transformed1[i, 40, ...])\n",
" ax[i + 1].set_axis_off()"
]
},
{
@@ -215,36 +226,44 @@
"source": [
"# demo some of the augmentations\n",
"rs = np.random.RandomState()\n",
"raw_transformers = [RandomFlip(rs), RandomRotate90(rs), RandomRotate(rs, angle_spectrum=20, axes=[(2, 1)], mode='reflect'), RandomContrast(rs, alpha=(0.5, 1.5), execution_probability=1.0), ElasticDeformation(rs, 3, alpha=20, sigma=3, execution_probability=1.0), AdditiveGaussianNoise(rs, scale=(0.0, 0.5), execution_probability=1.0), AdditivePoissonNoise(rs, lam=(0.0, 0.5), execution_probability=1.0)]\n",
"raw_transformers = [\n",
" RandomFlip(rs),\n",
" RandomRotate90(rs),\n",
" RandomRotate(rs, angle_spectrum=20, axes=[(2, 1)], mode=\"reflect\"),\n",
" RandomContrast(rs, alpha=(0.5, 1.5), execution_probability=1.0),\n",
" ElasticDeformation(rs, 3, alpha=20, sigma=3, execution_probability=1.0),\n",
" AdditiveGaussianNoise(rs, scale=(0.0, 0.5), execution_probability=1.0),\n",
" AdditivePoissonNoise(rs, lam=(0.0, 0.5), execution_probability=1.0),\n",
"]\n",
"\n",
"with h5py.File(path, 'r') as f:\n",
" raw = f['raw'][...]\n",
"with h5py.File(path, \"r\") as f:\n",
" raw = f[\"raw\"][...]\n",
" mid_z = raw.shape[0] // 2\n",
" raw = Standardize(np.mean(raw), np.std(raw))(raw)\n",
" #raw = Normalize(np.min(raw), np.max(raw))(raw)\n",
" label = f['label'][...]\n",
" \n",
" # raw = Normalize(np.min(raw), np.max(raw))(raw)\n",
" label = f[\"label\"][...]\n",
"\n",
" # show transforms\n",
" fig, axes = plt.subplots(len(raw_transformers), 3, figsize=(15, len(raw_transformers) * 5))\n",
" ax = axes.ravel()\n",
" \n",
"\n",
" for t, i in enumerate(range(0, 3 * len(raw_transformers), 3)):\n",
" transformer = raw_transformers[t]\n",
" ax[i].set_title('Raw')\n",
" ax[i].set_title(\"Raw\")\n",
" ax[i].imshow(raw[mid_z, ...])\n",
" ax[i].set_axis_off()\n",
"\n",
" # show boundary for the 1st offset\n",
" ax[i+1].set_title(f'{type(transformer).__name__}-0')\n",
" ax[i + 1].set_title(f\"{type(transformer).__name__}-0\")\n",
" aug = transformer(raw)\n",
" ax[i+1].imshow(aug[mid_z, ...])\n",
" ax[i+1].set_axis_off()\n",
" ax[i + 1].imshow(aug[mid_z, ...])\n",
" ax[i + 1].set_axis_off()\n",
"\n",
" # show boundary for the 4th offset\n",
" ax[i+2].set_title(f'{type(transformer).__name__}-1')\n",
" ax[i + 2].set_title(f\"{type(transformer).__name__}-1\")\n",
" aug = transformer(raw)\n",
" ax[i+2].imshow(aug[mid_z, ...])\n",
" ax[i+2].set_axis_off()"
" ax[i + 2].imshow(aug[mid_z, ...])\n",
" ax[i + 2].set_axis_off()"
]
},
{
@@ -273,11 +292,11 @@
"rs = np.random.RandomState()\n",
"transformer = AdditiveGaussianNoise(rs, scale=(0.0, 1.0), execution_probability=1.0)\n",
"\n",
"ax[0].set_title('Raw')\n",
"ax[0].set_title(\"Raw\")\n",
"ax[0].imshow(raw[mid_z, ...])\n",
"ax[0].set_axis_off()\n",
"for i in range(1, 10):\n",
" ax[i].set_title(f'Noise-{i}')\n",
" ax[i].set_title(f\"Noise-{i}\")\n",
" aug = transformer(raw)\n",
" ax[i].imshow(aug[mid_z, ...])\n",
" ax[i].set_axis_off()"
@@ -309,11 +328,11 @@
"rs = np.random.RandomState()\n",
"transformer = AdditivePoissonNoise(rs, lam=(0.0, 1.0), execution_probability=1.0)\n",
"\n",
"ax[0].set_title('Raw')\n",
"ax[0].set_title(\"Raw\")\n",
"ax[0].imshow(raw[mid_z, ...])\n",
"ax[0].set_axis_off()\n",
"for i in range(1, 10):\n",
" ax[i].set_title(f'Noise-{i}')\n",
" ax[i].set_title(f\"Noise-{i}\")\n",
" aug = transformer(raw)\n",
" ax[i].imshow(aug[mid_z, ...])\n",
" ax[i].set_axis_off()"
@@ -346,11 +365,11 @@
"t1 = AdditiveGaussianNoise(rs, scale=(0.0, 1.0), execution_probability=1.0)\n",
"t2 = AdditivePoissonNoise(rs, lam=(0.0, 0.5), execution_probability=1.0)\n",
"\n",
"ax[0].set_title('Raw')\n",
"ax[0].set_title(\"Raw\")\n",
"ax[0].imshow(raw[mid_z, ...])\n",
"ax[0].set_axis_off()\n",
"for i in range(1, 10):\n",
" ax[i].set_title(f'Noise-{i}')\n",
" ax[i].set_title(f\"Noise-{i}\")\n",
" aug = t1(t2(raw))\n",
" ax[i].imshow(aug[mid_z, ...])\n",
" ax[i].set_axis_off()"