Use safetensors in addition to python pickle files (.pth) for tensors and metadata #3011
Draft
mhalle wants to merge 8 commits into MIC-DKFZ:master from
Conversation
Replace the dual-write .pth + .safetensors prototype with a self-describing three-file layout: weights.safetensors (with weight_layout=torch_ncdhw in the metadata header), trainer_state.safetensors (flattened optimizer + grad_scaler tensors), and a JSON sidecar holding everything Python plus tensor placeholder skeletons. Recursive walkers split/merge arbitrary nested optimizer state, so warmed Adam buffers and step counters round-trip byte-equal. safetensors is now a hard dependency (the trainer always writes the new format); the optional [safetensors] extra is dropped. The convert CLI gains a --delete-pth flag. Tests cover layout, weight_layout metadata, network and optimizer round-trip, reload-into-real-AdamW, Python metadata, load_optimizer=False, legacy .pth fallback, and the convert helper. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Save path always writes the safetensors layout via save_checkpoint, then also writes the legacy .pth iff nnUNet_save_pth=1 (the default — opt out explicitly to drop .pth). Load path routes through load_checkpoint, which prefers safetensors and falls back to .pth so existing runs keep working. Also fixes a latent bug where load_checkpoint(dict) never assigned the local 'checkpoint' variable. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
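The prefer-safetensors-then-fall-back routing described in this commit boils down to a sibling-file check. A standalone sketch, with the helper name `resolve_checkpoint` invented here for illustration (the real logic lives in `load_checkpoint`):

```python
from pathlib import Path

def resolve_checkpoint(pth_path):
    """Prefer the safetensors sibling of a legacy .pth checkpoint path,
    fall back to the .pth itself if no sibling exists (sketch only)."""
    pth = Path(pth_path)
    sibling = pth.with_suffix(".safetensors")
    if sibling.exists():
        return sibling, "safetensors"
    if pth.exists():
        return pth, "pth"
    raise FileNotFoundError(f"no checkpoint found for {pth}")
```

Existing runs keep working because the second branch is exactly the old behavior.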
Both the inference predictor and the pretrained-weight loader only need network_weights and the small JSON metadata. Skip the trainer-state safetensors file via load_optimizer=False so distribution-only artifacts (no optimizer state) load cleanly. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Document the three-file layout, the nnUNet_save_pth opt-out, the convert CLI, and the weight_layout metadata tag that non-PyTorch loaders should read. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The safety argument for safetensors got materially stronger with CVE-2025-32434 (CVSS 9.3, fixed in PyTorch 2.6.0), which demonstrated RCE through torch.load even with weights_only=True. nnU-Net checkpoints contain optimizer state and arbitrary init_args, so every legacy .pth load takes the unsafe weights_only=False path. Recommend nnUNet_save_pth=0 for anyone distributing models to third parties. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
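The load-time code-execution hazard is inherent to pickle's design, not specific to any one CVE: unpickling calls whatever `__reduce__` returns. A harmless toy demonstration of standard pickle behavior (not nnU-Net code):

```python
import pickle

class Payload:
    """A benign-looking object whose pickle runs a callable on load."""
    def __reduce__(self):
        # Unpickling will call print(...) with these args. Swap in any
        # other callable and a .pth-style pickle becomes an attack vector.
        return (print, ("arbitrary code ran at load time",))

blob = pickle.dumps(Payload())
obj = pickle.loads(blob)  # prints the message during deserialization
```

safetensors avoids this by construction: the file is a JSON header plus raw tensor bytes, with no executable payload.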
.pth files are PyTorch pickle archives, so MLX/JAX/ONNX/C++ runtimes can't deserialize them without taking a hard torch dependency just to load weights. safetensors is framework-agnostic by design and is what makes the MLX inference port possible without a torch runtime requirement. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
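The format itself is simple enough to read with nothing but the standard library: an 8-byte little-endian header length, a JSON header (tensor dtypes, shapes, byte offsets, plus an optional `__metadata__` string map such as `weight_layout`), then raw tensor bytes. A stdlib-only sketch that writes and re-reads a toy file per the safetensors spec (illustrative, not this PR's code):

```python
import json
import struct

def write_safetensors(path, tensors, metadata):
    """tensors: name -> (dtype, shape, raw_bytes). Writes the safetensors
    container: u64-LE header size, JSON header, then raw tensor data."""
    header, offset = {"__metadata__": metadata}, 0
    for name, (dtype, shape, raw) in tensors.items():
        header[name] = {"dtype": dtype, "shape": shape,
                        "data_offsets": [offset, offset + len(raw)]}
        offset += len(raw)
    header_bytes = json.dumps(header).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        for _, (_, _, raw) in tensors.items():
            f.write(raw)

def read_header(path):
    """Any runtime (MLX, JAX, C++) can do this much with no torch."""
    with open(path, "rb") as f:
        (size,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(size).decode("utf-8"))
```

A non-PyTorch loader reads `read_header(p)["__metadata__"]["weight_layout"]` and knows how to transpose before touching a single tensor byte.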
The previous single <base>.json conflated inference metadata (init_args, trainer_name, mirroring axes) with trainer state (logging, optimizer skeleton, epoch). Distributing a model meant shipping the optimizer skeleton and full training-loss history along with the weights — bloat plus a fingerprinting surface. The new layout pairs each safetensors file with its own JSON:

- <base>.safetensors + <base>.json (inference)
- <base>.trainer_state.safetensors + <base>.trainer_state.json (resume)

Distribution = the inference pair, two small files. Resume = all four. The trainer-state pair is omitted entirely when a checkpoint has no optimizer state. Tests gain test_inference_meta_isolated_from_trainer_state to lock the boundary, and test_inference_only_artifact for the two-file path.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
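The split/merge round trip described in these commits can be sketched in plain Python. Here tensors are simulated by an `is_tensor` predicate and placeholders are one-key dicts; the real `_split`/`_merge` in `checkpoint_io.py` handle torch tensors and more edge cases, so this is a simplified illustration:

```python
PLACEHOLDER = "__tensor__"

def split_state(obj, is_tensor, prefix=""):
    """Recursively replace tensors with placeholders, collecting them
    into a flat {dotted_key: tensor} dict suitable for safetensors."""
    if is_tensor(obj):
        return {PLACEHOLDER: prefix}, {prefix: obj}
    if isinstance(obj, dict):
        skel, flat = {}, {}
        for k, v in obj.items():
            key = f"{prefix}.{k}" if prefix else str(k)
            s, f = split_state(v, is_tensor, key)
            skel[k] = s
            flat.update(f)
        return skel, flat
    if isinstance(obj, list):
        skel, flat = [], {}
        for i, v in enumerate(obj):
            key = f"{prefix}.{i}" if prefix else str(i)
            s, f = split_state(v, is_tensor, key)
            skel.append(s)
            flat.update(f)
        return skel, flat
    return obj, {}  # plain Python value: stays in the JSON skeleton

def merge_state(skel, flat):
    """Inverse walk: swap placeholders back for the flat tensors."""
    if isinstance(skel, dict):
        if set(skel) == {PLACEHOLDER}:
            return flat[skel[PLACEHOLDER]]
        return {k: merge_state(v, flat) for k, v in skel.items()}
    if isinstance(skel, list):
        return [merge_state(v, flat) for v in skel]
    return skel
```

Because the skeleton keeps the original nesting and the flat dict keeps the tensors, `merge_state(*split_state(state, is_tensor))` reproduces the input, which is the round-trip property the tests assert.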
Accidentally committed in the previous commit via 'git add -A'. uv is not the project's package manager and uv.lock is a local-tooling artifact that doesn't belong in the tree. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

## Summary

Replaces nnU-Net's `.pth` pickle checkpoints with a self-describing safetensors layout, with a backwards-compatible dual-write so existing tooling keeps working.

## Why
nnU-Net's existing code calls `torch.load(..., weights_only=False)` everywhere it loads a checkpoint. As of PyTorch 2.6 the default is `weights_only=True`, and CVE-2025-32434 (CVSS 9.3, fixed in 2.6.0) demonstrated remote code execution through `torch.load` even with `weights_only=True`. nnU-Net checkpoints contain optimizer state and arbitrary `init_args`, so every legacy load takes the unsafe path. Distributing pretrained `.pth` models to third parties is now distributing a known attack vector.

There is also a portability angle: `.pth` files are PyTorch pickle archives, so MLX, JAX, ONNX runtimes, and native C++ inference cannot deserialize them without taking a hard `torch` dependency purely to load weights. safetensors is framework-agnostic by design — the MLX inference port for Apple Silicon already uses this to provide a torch-free runtime.

## What changes
For a checkpoint named `checkpoint_final.pth` the trainer now writes:

- `checkpoint_final.safetensors`: the network weights, with `weight_layout=torch_ncdhw` in the safetensors metadata header so non-PyTorch loaders can transpose deterministically.
- `checkpoint_final.json`: inference metadata (`init_args`, `trainer_name`, `inference_allowed_mirroring_axes`). Small.
- `checkpoint_final.trainer_state.safetensors`: flattened optimizer and grad_scaler tensors (keys like `optimizer.state.<param_id>.<buffer_name>`).
- `checkpoint_final.trainer_state.json`: `current_epoch`, `_best_ema`, `logging`, plus skeletons for the optimizer and grad_scaler dicts with tensor placeholders.
- `checkpoint_final.pth`: the legacy pickle, written unless `nnUNet_save_pth=0`.

The two inference files are the full distribution artifact for a pretrained model — ship them and a third party can run inference without ever touching pickle. The two trainer-state files are needed only to resume training and contain no information that an inference user needs (or that you may want to expose: `logging` carries the full training-loss history, which can fingerprint a dataset). The trainer-state pair is omitted entirely if the checkpoint has no optimizer state.

A recursive `_split`/`_merge` walker in `checkpoint_io.py` handles arbitrary nested optimizer state, so warmed Adam buffers (`exp_avg`, `exp_avg_sq`, `step`) and grad_scaler scale tensors round-trip byte-equal. Inference callers pass `load_optimizer=False` to skip the trainer-state pair entirely.

## Backwards compatibility
`load_checkpoint` prefers the safetensors layout but falls back to `torch.load` on `.pth` if the safetensors siblings don't exist, so existing model zoos load unchanged. Dual-writing `.pth` is on by default, gated by `nnUNet_save_pth=1`; set the env var to `0` to disable. The suggested deprecation path: dual-write in this release, flip the default to off in the next, remove the writer in the release after that. The `.pth` reader stays indefinitely. A latent bug where `nnUNetTrainer.load_checkpoint(<dict>)` never assigned `checkpoint` is fixed in passing.

## Dependency change
`safetensors>=0.4.0` is a required dependency. The trainer always writes the new format, so the dependency is genuinely needed at runtime; `safetensors` is a small Rust extension distributed as wheels on every supported platform and installs in well under a second.

## Tests
Twelve round-trip tests in `nnunetv2/tests/test_checkpoint_io.py`, all passing locally:

- file layout on disk (including the dual-written legacy `.pth`)
- inference metadata isolated from trainer state (no `logging` or optimizer skeletons leak into the distributable JSON)
- `weight_layout` metadata header
- optimizer state reloads into a real `AdamW` and matches
- `load_optimizer=False` skip path
- `.pth` fallback when safetensors siblings absent
- `convert_pth_to_safetensors` end-to-end
- `FileNotFoundError` on missing checkpoint

These exercise `checkpoint_io.py` in isolation. The branch has not yet been run through `nnunetv2/tests/integration_tests/run_integration_test.sh` — that needs a CUDA box and the Hippocampus dataset, and would benefit from a maintainer's environment before merging.

## CLI
A new `nnUNetv2_convert_to_safetensors` entry point converts existing `.pth` checkpoints in place; a `--delete-pth` flag removes the original after conversion.

## Docs
New `documentation/explanation_checkpoint_format.md` covers the layout, the env var, the conversion CLI, the `weight_layout` metadata tag, and the security rationale. Follows the existing `explanation_*.md` naming convention.

## Test plan
`pytest nnunetv2/tests/test_checkpoint_io.py` passes (12/12), including the legacy `.pth` fallback path.

## Open questions for review
- Hard vs. optional dependency on `safetensors`. I chose hard. Defensible because the trainer always writes the new format and the wheel is tiny, but happy to revisit if you'd rather keep the install footprint minimal and gate the new format on the import being available.
- The env var name `nnUNet_save_pth` matches the existing `nnUNet_*` env var convention. Open to alternatives.

🤖 Generated with Claude Code