@@ -299,6 +299,7 @@ def __init__(
299
299
act_fn : Union [str , Tuple [str ]] = "silu" ,
300
300
upsample_block_type : str = "pixel_shuffle" ,
301
301
in_shortcut : bool = True ,
302
+ conv_act_fn : str = "relu" ,
302
303
):
303
304
super ().__init__ ()
304
305
@@ -349,7 +350,7 @@ def __init__(
349
350
channels = block_out_channels [0 ] if layers_per_block [0 ] > 0 else block_out_channels [1 ]
350
351
351
352
self .norm_out = RMSNorm (channels , 1e-5 , elementwise_affine = True , bias = True )
352
- self .conv_act = nn . ReLU ( )
353
+ self .conv_act = get_activation ( conv_act_fn )
353
354
self .conv_out = None
354
355
355
356
if layers_per_block [0 ] > 0 :
@@ -414,6 +415,12 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
414
415
The normalization type(s) to use in the decoder.
415
416
decoder_act_fns (`Union[str, Tuple[str]]`, defaults to `"silu"`):
416
417
The activation function(s) to use in the decoder.
418
+ encoder_out_shortcut (`bool`, defaults to `True`):
419
+ Whether to use a shortcut at the end of the encoder.
420
+ decoder_in_shortcut (`bool`, defaults to `True`):
421
+ Whether to use a shortcut at the beginning of the decoder.
422
+ decoder_conv_act_fn (`str`, defaults to `"relu"`):
423
+ The activation function to use at the end of the decoder.
417
424
scaling_factor (`float`, defaults to `1.0`):
418
425
The multiplicative inverse of the root mean square of the latent features. This is used to scale the latent
419
426
space to have unit variance when training the diffusion model. The latents are scaled with the formula `z =
@@ -441,6 +448,9 @@ def __init__(
441
448
downsample_block_type : str = "pixel_unshuffle" ,
442
449
decoder_norm_types : Union [str , Tuple [str ]] = "rms_norm" ,
443
450
decoder_act_fns : Union [str , Tuple [str ]] = "silu" ,
451
+ encoder_out_shortcut : bool = True ,
452
+ decoder_in_shortcut : bool = True ,
453
+ decoder_conv_act_fn : str = "relu" ,
444
454
scaling_factor : float = 1.0 ,
445
455
) -> None :
446
456
super ().__init__ ()
@@ -454,6 +464,7 @@ def __init__(
454
464
layers_per_block = encoder_layers_per_block ,
455
465
qkv_multiscales = encoder_qkv_multiscales ,
456
466
downsample_block_type = downsample_block_type ,
467
+ out_shortcut = encoder_out_shortcut ,
457
468
)
458
469
self .decoder = Decoder (
459
470
in_channels = in_channels ,
@@ -466,6 +477,8 @@ def __init__(
466
477
norm_type = decoder_norm_types ,
467
478
act_fn = decoder_act_fns ,
468
479
upsample_block_type = upsample_block_type ,
480
+ in_shortcut = decoder_in_shortcut ,
481
+ conv_act_fn = decoder_conv_act_fn ,
469
482
)
470
483
471
484
self .spatial_compression_ratio = 2 ** (len (encoder_block_out_channels ) - 1 )
0 commit comments