Skip to content

Implements Predictor specialization for multi-diffusion#1573

Open
CharlelieLrt wants to merge 17 commits intoNVIDIA:mainfrom
CharlelieLrt:multi_diffusion_sampling
Open

Implements Predictor specialization for multi-diffusion#1573
CharlelieLrt wants to merge 17 commits intoNVIDIA:mainfrom
CharlelieLrt:multi_diffusion_sampling

Conversation

@CharlelieLrt
Copy link
Copy Markdown
Collaborator

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Apr 17, 2026

Greptile Summary

This PR introduces MultiDiffusionPredictor, a Predictor-protocol-compatible wrapper for test-time patch-based diffusion sampling using MultiDiffusionModel2D. It pre-patches the condition and positional embedding once at construction time to avoid redundant work per step, refactors the duplicate PE injection logic into a shared _inject_patched_pos_embd helper, and generalises the DDP/compile-unwrapping helper into a reusable _unwrap_module utility. A thorough test suite covering constructor validation, non-regression, gradient flow, torch.compile, and checkpoint round-trips is included.

  • P1 (predictor.py line 406): _skip_positional_embedding_injection = True is permanently written to the underlying MultiDiffusionModel2D and never restored. If the same model object is used for loss computation after the predictor is created (e.g., continued fine-tuning), PE injection is silently skipped, producing wrong training outputs without any error.

Important Files Changed

Filename Overview
physicsnemo/diffusion/multi_diffusion/predictor.py New file implementing MultiDiffusionPredictor — a Predictor-protocol wrapper for patch-based sampling. Correctly pre-patches condition and PE at construction, but permanently mutates the underlying model's _skip_positional_embedding_injection flag with no restore mechanism (P1), and the dead is_compiling() guard in init is misleading.
physicsnemo/diffusion/multi_diffusion/models.py Adds _skip_positional_embedding_injection flag and refactors PE injection into a shared _inject_patched_pos_embd helper. The refactoring is clean and the logic is preserved correctly.
physicsnemo/diffusion/utils/utils.py Extracts the DDP/torch.compile unwrapping logic into a reusable generic _unwrap_module function. Clean generalization with correct TypeVar usage.
physicsnemo/diffusion/multi_diffusion/losses.py Replaces the local _unwrap_multi_diffusion helper with the new generic _unwrap_module utility. Equivalent behavior, no logic changes.
physicsnemo/diffusion/guidance/dps_guidance.py One-line change: DPSScorePredictor now explicitly inherits from Predictor, correctly aligning it with the protocol.
physicsnemo/diffusion/multi_diffusion/init.py Adds MultiDiffusionPredictor to the public module exports.
test/diffusion/test_multi_diffusion_predictor.py Comprehensive test suite for MultiDiffusionPredictor covering constructor validation, non-regression outputs, gradient flow, torch.compile compatibility, and checkpoint round-trips.
test/diffusion/test_multi_diffusion_sampling.py End-to-end sampling tests for MultiDiffusionPredictor integrated with all three noise schedulers (EDM, VE, VP), two solvers, and compiled predictor paths.

Comments Outside Diff (1)

  1. physicsnemo/diffusion/multi_diffusion/predictor.py, line 405-406 (link)

    P1 Persistent flag side-effect on underlying model

    _skip_positional_embedding_injection = True is written to _md_model and never restored. If the same MultiDiffusionModel2D is used for loss computation after the predictor is created — for example, after sampling a checkpoint to validate a mid-training result or during continued fine-tuning — the model's forward will silently skip PE injection in both the "no patching" and "with patching" paths, producing wrong outputs from the loss without any error or warning.

    A minimal safeguard would be to save and restore the original flag, or at minimum include a runtime warning that the model is permanently modified:

    # Save original value so callers can restore it
    self._prev_skip_pe = self._md_model._skip_positional_embedding_injection
    self._md_model._skip_positional_embedding_injection = True

    Or add a close() / __del__ to restore the flag, or document that the model must not be used for training after predictor construction.

Reviews (1): Last reviewed commit: "Implements Predictor specialization for ..." | Re-trigger Greptile

Comment thread physicsnemo/diffusion/multi_diffusion/predictor.py Outdated
Comment thread physicsnemo/diffusion/multi_diffusion/predictor.py
Comment thread physicsnemo/diffusion/multi_diffusion/predictor.py
Compiling the predictor instance directly was producing divergent results
under torch 2.10 in the sample() loop (euler cases only). Follow the same
pattern as test_samplers.py::TestSampleCompile and compile the denoiser
closure instead — tracing through it still verifies that the predictor's
__call__ path is compile-compatible.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
torch 2.10 Dynamo crashes with Fatal Python error: Aborted when tracing
the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
inside sample() with fullgraph=True. Allow graph breaks here; the
predictor compile contract is still tested in isolation by
test_multi_diffusion_predictor.py::TestCompile.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Dispatch on pos_embd presence and model_kwargs is now resolved once at
__init__ into a specialized closure, so __call__ is branch-free and the
no-kwargs path avoids ** expansion. This keeps fullgraph=True compile
cleanly traceable under torch 2.10 (which was hitting a Dynamo abort on
the nested MultiDiffusionPredictor -> MultiDiffusionModel2D call chain
when the denoiser closure was compiled in the sample() loop).

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Reverts the two earlier CI-fix attempts (compile-denoiser switch, predictor
hot-path flatten) since neither actually fixed the divergence. The
underlying issue is an upstream torch>=2.10 Dynamo bug: euler + compiled
MultiDiffusionPredictor produces numerically divergent results. Heun works,
predictor compiles correctly in isolation. For euler we now assert only
shape + isfinite until the upstream bug is resolved.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
sample() passes t.expand(B) (a stride-0 non-contiguous tensor) into
solver.step(). HeunSolver already forces .contiguous() on both tensors to
prevent torch.compile from specializing on the stride pattern of the first
call and then either mis-firing guards or silently recompiling on
subsequent calls with different underlying storage.

EulerSolver and EDMStochasticEulerSolver had no such guard, which was a
latent bug exposed by torch 2.10 (stricter stride tracking) in the
multi-diffusion compiled sample loop — producing 90%+ element divergence
vs eager on the first call and a Dynamo abort on the second call. Apply
the same fix uniformly across all four solver steps.

Also revert the temporary loosened euler assertion in
test_multi_diffusion_sampling.py now that the real fix is in place.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
…sionPredictor

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Move the _patching None check out of the is_compiling guard in
MultiDiffusionModel2D so the type checker narrows self._patching
to RandomPatching2D | GridPatching2D for the rest of each method,
and route fuse/reset_patch_indices through isinstance.

Streamline TestConstructor to only exercise the public contract
(.fuse, .model, setter round-trip) and drop assertions on private
caches. Compile the denoiser instead of the predictor in
TestMultiDiffusionSampleCompile and add TestMultiDiffusionFullSamplerCompile
mirroring test_samplers.py.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
pos_embd.unsqueeze(0).expand(B, -1, -1, -1) produces a stride-0 view
(all B copies share storage). Passing this through nn.ReflectionPad2d
and F.unfold inside image_batching triggers a glibc heap corruption
on torch 2.10 (CI, not locally on torch 2.8) when the first non-regression
posembd_sin test runs. Same class of fix as the earlier euler solver.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Instantiating torch.nn.ReflectionPad2d inside image_batching on every
call creates a fresh nn.Module each time, which torch.compile / AOT
autograd struggles to trace cleanly under fullgraph=True on torch 2.10.
Switch to torch.nn.functional.pad which is a plain functional call and
traces without allocating a module. Same result semantically.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
einops.rearrange goes through a pattern-matched lowering path that
torch.compile / inductor on torch 2.10 handles fragilely in the
image_batching / image_fuse hot paths. The underlying transform is a
plain view + permute + view, so express it directly: this gives inductor
a straightforward sequence of ops to trace, and drops the einops
dependency from this module.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
Under torch.compile / inductor on torch 2.10, a compiled sample() call
through MultiDiffusionPredictor was returning a tensor whose metadata
was valid but whose data pointer was dangling (use-after-free) — the
caller SIGABRTed on the first read of the tensor data. Add .contiguous()
at the two boundaries that returned a view: image_fuse returns
x_folded[...] / overlap_count[...], and MultiDiffusionModel2D.forward
returns the (possibly fused) inner-model output. Forcing fresh storage
on each boundary prevents the returned tensor from aliasing a buffer
whose lifetime ends with the compiled frame.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
The second torch.compile call of a fused MultiDiffusionPredictor was
segfaulting (SIGSEGV) while the first succeeded. .contiguous() is a
no-op when the tensor is already contiguous, so inductor could still
see the returned tensor as aliasing an internal buffer across calls.
.clone() always allocates fresh storage, so successive compiled calls
get independent outputs. Also drop the redundant .contiguous() added
earlier in MultiDiffusionModel2D.forward now that image_fuse owns that
boundary.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
…e on torch>=2.10

Revert commits 3dfcdb5, 746518f and a007c46 (native-torch rearrange
in image_batching/image_fuse, .contiguous() on returned tensors, .clone()
at fuse boundary) since they did not resolve the torch 2.10 inductor
codegen segfault in TestMultiDiffusionFullSamplerCompile. Keep commits
7e1db11 (pos_embd .contiguous() for the glibc heap corruption in
posembd_sin non-regression tests) and feb0d9e (ReflectionPad2d → F.pad).

Gate TestMultiDiffusionFullSamplerCompile with xfail(run=False) when
torch>=2.10 so the SIGSEGV does not bring down the pytest process.
TestMultiDiffusionSampleCompile (per-step denoiser compile) still runs.

Signed-off-by: Charlelie Laurent <claurent@nvidia.com>
@CharlelieLrt
Copy link
Copy Markdown
Collaborator Author

/blossom-ci

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant