Skip to content

Commit 85cbe58

Browse files
authored
Minor modification to support DC-AE-turbo (#12169)
* minor modification to support dc-ae-turbo
* minor
1 parent 4d9b822 commit 85cbe58

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 14 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -299,6 +299,7 @@ def __init__(
299299
act_fn: Union[str, Tuple[str]] = "silu",
300300
upsample_block_type: str = "pixel_shuffle",
301301
in_shortcut: bool = True,
302+
conv_act_fn: str = "relu",
302303
):
303304
super().__init__()
304305

@@ -349,7 +350,7 @@ def __init__(
349350
channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
350351

351352
self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
352-
self.conv_act = nn.ReLU()
353+
self.conv_act = get_activation(conv_act_fn)
353354
self.conv_out = None
354355

355356
if layers_per_block[0] > 0:
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
414415
The normalization type(s) to use in the decoder.
415416
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
416417
The activation function(s) to use in the decoder.
418+
encoder_out_shortcut (`bool`, defaults to `True`):
419+
Whether to use shortcut at the end of the encoder.
420+
decoder_in_shortcut (`bool`, defaults to `True`):
421+
Whether to use shortcut at the beginning of the decoder.
422+
decoder_conv_act_fn (`str`, defaults to `"relu"`):
423+
The activation function to use at the end of the decoder.
417424
scaling_factor (`float`, defaults to `1.0`):
418425
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
419426
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
@@ -441,6 +448,9 @@ def __init__(
441448
downsample_block_type: str = "pixel_unshuffle",
442449
decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
443450
decoder_act_fns: Union[str, Tuple[str]] = "silu",
451+
encoder_out_shortcut: bool = True,
452+
decoder_in_shortcut: bool = True,
453+
decoder_conv_act_fn: str = "relu",
444454
scaling_factor: float = 1.0,
445455
) -> None:
446456
super().__init__()
@@ -454,6 +464,7 @@ def __init__(
454464
layers_per_block=encoder_layers_per_block,
455465
qkv_multiscales=encoder_qkv_multiscales,
456466
downsample_block_type=downsample_block_type,
467+
out_shortcut=encoder_out_shortcut,
457468
)
458469
self.decoder = Decoder(
459470
in_channels=in_channels,
@@ -466,6 +477,8 @@ def __init__(
466477
norm_type=decoder_norm_types,
467478
act_fn=decoder_act_fns,
468479
upsample_block_type=upsample_block_type,
480+
in_shortcut=decoder_in_shortcut,
481+
conv_act_fn=decoder_conv_act_fn,
469482
)
470483

471484
self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)

0 commit comments

Comments
 (0)