From 595ca841fc1c40090b969919003b1fc0d438b8f6 Mon Sep 17 00:00:00 2001 From: IamTingTing <6121smile@gmail.com> Date: Fri, 18 Jul 2025 21:06:39 +0800 Subject: [PATCH] fix: dynamic input dim in DiffusionModelEncoder Signed-off-by: IamTingTing <6121smile@gmail.com> --- monai/networks/nets/diffusion_model_unet.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py index 11196bb343..ad05f6dd6a 100644 --- a/monai/networks/nets/diffusion_model_unet.py +++ b/monai/networks/nets/diffusion_model_unet.py @@ -33,6 +33,7 @@ import math from collections.abc import Sequence +from typing import Optional import torch from torch import nn @@ -2005,7 +2006,7 @@ def __init__( self.down_blocks.append(down_block) - self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)) + self.out: Optional[nn.Module] = None def forward( self, @@ -2048,6 +2049,12 @@ def forward( h, _ = downsample_block(hidden_states=h, temb=emb, context=context) h = h.reshape(h.shape[0], -1) + + # 5. out + if self.out is None: + self.out = nn.Sequential( + nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels) + ) output: torch.Tensor = self.out(h) return output