
Use safetensors in addition to python pickle files (.pth) for tensors and metadata#3011

Draft
mhalle wants to merge 8 commits into MIC-DKFZ:master from mhalle:feature/safetensors-support

Conversation


mhalle commented Apr 7, 2026

Summary

Replaces nnU-Net's .pth pickle checkpoints with a self-describing safetensors layout, with a backwards-compatible dual-write so existing tooling keeps working.

Why

nnU-Net's existing code calls torch.load(..., weights_only=False) everywhere it loads a checkpoint. As of PyTorch 2.6 the default is weights_only=True, and CVE-2025-32434 (CVSS 9.3, fixed in 2.6.0) demonstrated remote code execution through torch.load even with weights_only=True. nnU-Net checkpoints contain optimizer state and arbitrary init_args, so every legacy load takes the unsafe path. Distributing pretrained .pth models to third parties is now distributing a known attack vector.

There is also a portability angle: .pth files are PyTorch pickle archives, so MLX, JAX, ONNX runtimes, and native C++ inference cannot deserialize them without taking a hard torch dependency purely to load weights. safetensors is framework-agnostic by design — the MLX inference port for Apple Silicon already uses this to provide a torch-free runtime.

What changes

For a checkpoint named checkpoint_final.pth the trainer now writes:

| File | Group | Contents |
| --- | --- | --- |
| checkpoint_final.safetensors | inference | Network weights only. Carries weight_layout=torch_ncdhw in the safetensors metadata header so non-PyTorch loaders can transpose deterministically. |
| checkpoint_final.json | inference | Inference metadata only: init_args, trainer_name, inference_allowed_mirroring_axes. Small. |
| checkpoint_final.trainer_state.safetensors | trainer state | Optimizer + grad_scaler tensors, flattened with dotted keys (optimizer.state.<param_id>.<buffer_name>). |
| checkpoint_final.trainer_state.json | trainer state | Trainer Python state: current_epoch, _best_ema, logging, plus skeletons for the optimizer and grad_scaler dicts with tensor placeholders. |
| checkpoint_final.pth | (legacy) | Legacy pickle. Written by default for backwards compat; opt out with nnUNet_save_pth=0. |

The two inference files are the full distribution artifact for a pretrained model — ship them and a third party can run inference without ever touching pickle. The two trainer state files are needed only to resume training and contain no information an inference user needs (and some you may not want to expose: logging carries the full training-loss history, which can fingerprint a dataset). The trainer-state pair is omitted entirely if the checkpoint has no optimizer state.

A recursive _split / _merge walker in checkpoint_io.py handles arbitrary nested optimizer state, so warmed Adam buffers (exp_avg, exp_avg_sq, step) and grad_scaler scale tensors round-trip byte-equal. Inference callers pass load_optimizer=False to skip the trainer-state pair entirely.
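The walker idea can be sketched like this (a from-scratch illustration with an assumed placeholder sentinel, not the PR's actual checkpoint_io.py code; numpy arrays stand in for torch tensors):

```python
import numpy as np

TENSOR_PLACEHOLDER = "__tensor__"  # assumed sentinel; the PR's real marker may differ

def split_state(obj, prefix=""):
    """Walk nested dicts/lists, pulling arrays out under dotted keys
    and leaving placeholders in a JSON-safe skeleton."""
    if isinstance(obj, np.ndarray):
        return TENSOR_PLACEHOLDER, {prefix.rstrip("."): obj}
    if isinstance(obj, dict):
        skel, tensors = {}, {}
        for k, v in obj.items():
            s, t = split_state(v, f"{prefix}{k}.")
            skel[k] = s
            tensors.update(t)
        return skel, tensors
    if isinstance(obj, list):
        skel, tensors = [], {}
        for i, v in enumerate(obj):
            s, t = split_state(v, f"{prefix}{i}.")
            skel.append(s)
            tensors.update(t)
        return skel, tensors
    return obj, {}  # plain Python value: stays in the JSON skeleton

def merge_state(skel, tensors, prefix=""):
    """Inverse walk: re-insert tensors into the skeleton at their dotted keys."""
    if skel == TENSOR_PLACEHOLDER:
        return tensors[prefix.rstrip(".")]
    if isinstance(skel, dict):
        return {k: merge_state(v, tensors, f"{prefix}{k}.") for k, v in skel.items()}
    if isinstance(skel, list):
        return [merge_state(v, tensors, f"{prefix}{i}.") for i, v in enumerate(skel)]
    return skel

# Example: a warmed optimizer state with one tensor buffer and one scalar.
state = {"optimizer": {"state": {"0": {"exp_avg": np.ones(3), "step": 5}}}}
skel, tensors = split_state(state)
restored = merge_state(skel, tensors)
```

Because the split is lossless for any nesting depth, warmed exp_avg/exp_avg_sq buffers and step counters round-trip without special-casing the optimizer class.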

Backwards compatibility

  • Reading old checkpoints just works. load_checkpoint prefers the safetensors layout but falls back to torch.load on .pth if the safetensors siblings don't exist. Existing model zoos load unchanged.
  • Writing the legacy .pth is on by default, gated by nnUNet_save_pth=1. Set the env var to 0 to disable. The suggested deprecation path is: this release dual-write, next release flip the default to off, release after that remove the writer. The .pth reader stays indefinitely.
  • The latent bug where nnUNetTrainer.load_checkpoint(<dict>) never assigned checkpoint is fixed in passing.

Dependency change

safetensors>=0.4.0 is a required dependency. The trainer always writes the new format, so the dependency is genuinely needed at runtime; safetensors is a small Rust extension distributed as wheels on every supported platform and installs in well under a second.

Tests

Twelve round-trip tests in nnunetv2/tests/test_checkpoint_io.py, all passing locally:

  • File layout (four files written, no .pth)
  • Inference metadata is isolated from trainer state (locks the boundary so future edits can't leak logging or optimizer skeletons into the distributable JSON)
  • Inference-only artifact: a checkpoint with no optimizer state writes exactly the two inference files
  • weight_layout metadata header
  • Network weights round-trip (byte-equal)
  • Optimizer state round-trip (byte-equal, including warmed AdamW with 3 real steps)
  • Reconstructed optimizer state loads cleanly into a fresh AdamW and matches
  • Python metadata round-trip
  • load_optimizer=False skip path
  • Legacy .pth fallback when safetensors siblings absent
  • convert_pth_to_safetensors end-to-end
  • FileNotFoundError on missing checkpoint

These exercise checkpoint_io.py in isolation. The branch has not yet been run through nnunetv2/tests/integration_tests/run_integration_test.sh — that needs a CUDA box and the Hippocampus dataset, and would benefit from a maintainer's environment before merging.

CLI

A new nnUNetv2_convert_to_safetensors entry point converts existing .pth checkpoints in place:

nnUNetv2_convert_to_safetensors path/to/checkpoint_final.pth
nnUNetv2_convert_to_safetensors path/to/model_dir --recursive
nnUNetv2_convert_to_safetensors path/to/checkpoint_final.pth --delete-pth

Docs

New documentation/explanation_checkpoint_format.md covers the layout, the env var, the conversion CLI, the weight_layout metadata tag, and the security rationale. Follows the existing explanation_*.md naming convention.

Test plan

  • pytest nnunetv2/tests/test_checkpoint_io.py passes (12/12)
  • Smoke import of all touched modules
  • Full integration test on a CUDA box with the Hippocampus dataset (needs a maintainer's environment)
  • Verify training resume from a safetensors checkpoint produces equivalent loss curves to resume from .pth
  • Verify W&B and meta-logger integration is unaffected by the trainer changes

Open questions for review

  1. Hard dep vs optional extra for safetensors. I chose hard. Defensible because the trainer always writes the new format and the wheel is tiny, but happy to revisit if you'd rather keep the install footprint minimal and gate the new format on the import being available.
  2. Env var name. nnUNet_save_pth matches the existing nnUNet_* env var convention. Open to alternatives.
  3. Deprecation timeline. Suggested above as dual-write → opt-in → removal across three releases. Faster or slower is fine.

🤖 Generated with Claude Code

mhalle and others added 6 commits April 7, 2026 18:55
Replace the dual-write .pth + .safetensors prototype with a self-describing
three-file layout: weights.safetensors (with weight_layout=torch_ncdhw in the
metadata header), trainer_state.safetensors (flattened optimizer + grad_scaler
tensors), and a JSON sidecar holding everything Python plus tensor placeholder
skeletons. Recursive walkers split/merge arbitrary nested optimizer state, so
warmed Adam buffers and step counters round-trip byte-equal.

safetensors is now a hard dependency (the trainer always writes the new
format); the optional [safetensors] extra is dropped. The convert CLI gains
a --delete-pth flag.

Tests cover layout, weight_layout metadata, network and optimizer round-trip,
reload-into-real-AdamW, Python metadata, load_optimizer=False, legacy .pth
fallback, and the convert helper.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Save path always writes the safetensors layout via save_checkpoint, then
also writes the legacy .pth iff nnUNet_save_pth=1 (the default — opt out
explicitly to drop .pth). Load path routes through load_checkpoint, which
prefers safetensors and falls back to .pth so existing runs keep working.

Also fixes a latent bug where load_checkpoint(dict) never assigned the local
'checkpoint' variable.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Both the inference predictor and the pretrained-weight loader only need
network_weights and the small JSON metadata. Skip the trainer-state
safetensors file via load_optimizer=False so distribution-only artifacts
(no optimizer state) load cleanly.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Document the three-file layout, the nnUNet_save_pth opt-out, the convert
CLI, and the weight_layout metadata tag that non-PyTorch loaders should read.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The safety argument for safetensors got materially stronger with
CVE-2025-32434 (CVSS 9.3, fixed in PyTorch 2.6.0), which demonstrated RCE
through torch.load even with weights_only=True. nnU-Net checkpoints
contain optimizer state and arbitrary init_args, so every legacy .pth load
takes the unsafe weights_only=False path. Recommend nnUNet_save_pth=0 for
anyone distributing models to third parties.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
.pth files are PyTorch pickle archives, so MLX/JAX/ONNX/C++ runtimes can't
deserialize them without taking a hard torch dependency just to load
weights. safetensors is framework-agnostic by design and is what makes the
MLX inference port possible without a torch runtime requirement.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
mhalle changed the title from "Canonical safetensors checkpoint layout" to "Use safetensors in addition to python pickle files (.pth) for tensors and metadata" Apr 7, 2026
mhalle and others added 2 commits April 7, 2026 19:23
The previous single <base>.json conflated inference metadata (init_args,
trainer_name, mirroring axes) with trainer state (logging, optimizer
skeleton, epoch). Distributing a model meant shipping the optimizer
skeleton and full training-loss history along with the weights — bloat
plus a fingerprinting surface.

New layout pairs each safetensors file with its own JSON:

  <base>.safetensors                 + <base>.json                 (inference)
  <base>.trainer_state.safetensors   + <base>.trainer_state.json   (resume)

Distribution = the inference pair, two small files. Resume = all four.
The trainer-state pair is omitted entirely when a checkpoint has no
optimizer state.

Tests gain test_inference_meta_isolated_from_trainer_state to lock the
boundary, and test_inference_only_artifact for the two-file path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Accidentally committed in the previous commit via 'git add -A'. uv is not
the project's package manager and uv.lock is a local-tooling artifact
that doesn't belong in the tree.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
FabianIsensee self-assigned this Apr 7, 2026

FabianIsensee (Member) commented Apr 9, 2026

Hey thanks for this PR! I tested it locally and at least the writing of checkpoints works, but there are at least two issues I found so far:

  • continuation of the training doesn't work with the new checkpoints. It doesn't seem to detect the latest / best checkpoints anymore (or expects to find the .pth files, which may not exist!)
  • we constantly struggle with checkpoints swamping our network drives and the safetensors variant is WAY too big:
    [screenshot: file-size listing omitted]
    It consumes >2x the storage space. We need a solution for that

Best,
Fabian
