@@ -322,11 +322,11 @@ def forward(self, input: Tensor) -> Tensor:
322322class _ConvTransposeNd (_ConvNd ):
323323 def __init__ (self , in_channels , out_channels , kernel_size , stride ,
324324 padding , dilation , transposed , output_padding ,
325- groups , bias , padding_mode , dtype = None ) -> None :
325+ groups , bias , padding_mode , dtype = None , device = None ) -> None :
326326 if padding_mode != 'zeros' :
327327 raise ValueError (f'Only "zeros" padding mode is supported for { self .__class__ .__name__ } ' )
328328
329- factory_kwargs = {'dtype' : dtype }
329+ factory_kwargs = {'dtype' : dtype , 'device' : device }
330330 super ().__init__ (
331331 in_channels , out_channels , kernel_size , stride ,
332332 padding , dilation , transposed , output_padding ,
@@ -426,62 +426,71 @@ class ConvTranspose1d(_ConvTransposeNd):
426426 bias (Tensor): the learnable bias of the module of shape (out_channels)
427427 """
428428
429- def __init__ (self , in_channels , out_channels , kernel_size , stride = 1 ,
430- padding = 0 , output_padding = 0 , groups = 1 , bias = True , dilation = 1 , padding_mode : str = 'zeros' ):
429+ def __init__ (
430+ self ,
431+ in_channels : int ,
432+ out_channels : int ,
433+ kernel_size : _size_1_t ,
434+ stride : _size_1_t = 1 ,
435+ padding : _size_1_t = 0 ,
436+ output_padding : _size_1_t = 0 ,
437+ groups : int = 1 ,
438+ bias : bool = True ,
439+ dilation : _size_1_t = 1 ,
440+ padding_mode : str = "zeros" ,
441+ device = None ,
442+ dtype = None ,
443+ ) -> None :
444+ factory_kwargs = {"device" : device , "dtype" : dtype }
431445 kernel_size = _single (kernel_size )
432446 stride = _single (stride )
433447 padding = _single (padding )
434448 dilation = _single (dilation )
435449 output_padding = _single (output_padding )
436- super (ConvTranspose1d , self ).__init__ (
437- in_channels , out_channels , kernel_size , stride , padding , dilation ,
438- True , output_padding , groups , bias , padding_mode )
439-
440- pad_mode = 'pad'
441- pad = padding
442- if isinstance (padding , tuple ):
443- pad = (0 , 0 , padding [0 ], padding [0 ])
444- elif isinstance (padding , int ):
445- pad = (0 , 0 ) + (padding ,) * 2
446- if not isinstance (padding , (int , tuple )):
447- pad_mode = padding
448- pad = (0 ,) * 4
449-
450- # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
451- self .conv2d_transpose = mops .Conv2DTranspose (out_channel = self .out_channels ,
452- kernel_size = (1 ,) + self .kernel_size ,
453- mode = 1 ,
454- pad_mode = pad_mode ,
455- pad = pad ,
456- stride = (1 ,) + self .stride ,
457- dilation = (1 ,) + self .dilation ,
458- group = self .groups )
459- self .h_add = _deconv_output_length (pad_mode , 1 , 1 , 1 , pad [0 ] + pad [1 ])
460- self .w_add = _deconv_output_length (pad_mode , kernel_size [0 ], stride [0 ], dilation [0 ], pad [2 ] + pad [3 ])
461-
462- def forward (self , input , output_size = None ):
463- if self .padding_mode != 'zeros' :
464- raise ValueError ('Only `zeros` padding mode is supported for ConvTranspose2d' )
450+ super ().__init__ (
451+ in_channels ,
452+ out_channels ,
453+ kernel_size ,
454+ stride ,
455+ padding ,
456+ dilation ,
457+ True ,
458+ output_padding ,
459+ groups ,
460+ bias ,
461+ padding_mode ,
462+ ** factory_kwargs ,
463+ )
464+
465+ def forward (self , input : Tensor , output_size : Optional [list [int ]] = None ) -> Tensor :
466+ if self .padding_mode != "zeros" :
467+ raise ValueError (
468+ "Only `zeros` padding mode is supported for ConvTranspose1d"
469+ )
465470
466471 assert isinstance (self .padding , tuple )
467472 # One cannot replace List by Tuple or Sequence in "_output_padding" because
468473 # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
469474 num_spatial_dims = 1
470475 output_padding = self ._output_padding (
471- input , output_size , self .stride , self .padding , self .kernel_size , # type: ignore[arg-type]
472- num_spatial_dims , self .dilation ) # type: ignore[arg-type]
473- input = mops .expand_dims (input , 2 )
474- n , _ , h , w = input .shape
475- conv2d_trans_ret = self .conv2d_transpose (input , self .weight .expand_dims (2 ),
476- (n , self .out_channels ,
477- h + self .h_add ,
478- w * self .stride [0 ] + self .w_add ))
479- if self .bias is not None :
480- conv2d_trans_ret = mops .bias_add (conv2d_trans_ret , self .bias )
481-
482- conv2d_trans_ret = conv2d_trans_ret .squeeze (2 )
483- conv2d_trans_ret = ops .pad (conv2d_trans_ret , (0 ,) + output_padding , value = 0. )
484- return conv2d_trans_ret
476+ input ,
477+ output_size ,
478+ self .stride , # type: ignore[arg-type]
479+ self .padding , # type: ignore[arg-type]
480+ self .kernel_size , # type: ignore[arg-type]
481+ num_spatial_dims ,
482+ self .dilation , # type: ignore[arg-type]
483+ )
484+ return F .conv_transpose1d (
485+ input ,
486+ self .weight ,
487+ self .bias ,
488+ self .stride ,
489+ self .padding ,
490+ output_padding ,
491+ self .groups ,
492+ self .dilation ,
493+ )
485494
486495
487496def _deconv_output_length (pad_mode , filter_size , stride_size , dilation_size , padding ):
@@ -582,66 +591,80 @@ class ConvTranspose2d(_ConvTransposeNd):
582591 https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
583592 """
584593
585- def __init__ (self , in_channels , out_channels , kernel_size , stride = 1 ,
586- padding = 0 , output_padding = 0 , groups = 1 , bias = True , dilation = 1 ,
587- padding_mode = 'zeros' , dtype = None ):
588- factory_kwargs = {'dtype' : dtype }
594+ def __init__ (
595+ self ,
596+ in_channels : int ,
597+ out_channels : int ,
598+ kernel_size : _size_2_t ,
599+ stride : _size_2_t = 1 ,
600+ padding : _size_2_t = 0 ,
601+ output_padding : _size_2_t = 0 ,
602+ groups : int = 1 ,
603+ bias : bool = True ,
604+ dilation : _size_2_t = 1 ,
605+ padding_mode : str = "zeros" ,
606+ device = None ,
607+ dtype = None ,
608+ ) -> None :
609+ factory_kwargs = {"device" : device , "dtype" : dtype }
589610 kernel_size = _pair (kernel_size )
590611 stride = _pair (stride )
591612 padding = _pair (padding )
592613 dilation = _pair (dilation )
593614 output_padding = _pair (output_padding )
594615 super ().__init__ (
595- in_channels , out_channels , kernel_size , stride , padding , dilation ,
596- True , output_padding , groups , bias , padding_mode , ** factory_kwargs )
597-
598- pad_mode = 'pad'
599- pad = padding
600- if isinstance (padding , tuple ):
601- pad = (padding [0 ], padding [0 ], padding [1 ], padding [1 ])
602- elif isinstance (padding , int ):
603- pad = (padding ,) * 4
604- if not isinstance (padding , (int , tuple )):
605- pad_mode = padding
606- pad = (0 ,) * 4
607-
608- # cause Conv2DTranspose's out_channel refers to Conv2D's out_channel.
609- self .conv2d_transpose = mops .Conv2DTranspose (out_channel = in_channels ,
610- kernel_size = kernel_size ,
611- mode = 1 ,
612- pad_mode = pad_mode ,
613- pad = pad ,
614- stride = stride ,
615- dilation = dilation ,
616- group = groups )
617-
618- self .h_add = _deconv_output_length (pad_mode , kernel_size [0 ], stride [0 ], dilation [0 ], pad [0 ] + pad [1 ])
619- self .w_add = _deconv_output_length (pad_mode , kernel_size [1 ], stride [1 ], dilation [1 ], pad [2 ] + pad [3 ])
620-
621- def forward (self , input , output_size = None ):
622- if self .padding_mode != 'zeros' :
623- raise ValueError ('Only `zeros` padding mode is supported for ConvTranspose2d' )
616+ in_channels ,
617+ out_channels ,
618+ kernel_size ,
619+ stride ,
620+ padding ,
621+ dilation ,
622+ True ,
623+ output_padding ,
624+ groups ,
625+ bias ,
626+ padding_mode ,
627+ ** factory_kwargs ,
628+ )
629+
630+ def forward (self , input : Tensor , output_size : Optional [list [int ]] = None ) -> Tensor :
631+ """
632+ Performs the forward pass.
633+
634+ Attributes:
635+ input (Tensor): The input tensor.
636+ output_size (list[int], optional): A list of integers representing
637+ the size of the output tensor. Default is None.
638+ """
639+ if self .padding_mode != "zeros" :
640+ raise ValueError (
641+ "Only `zeros` padding mode is supported for ConvTranspose2d"
642+ )
624643
625644 assert isinstance (self .padding , tuple )
626645 # One cannot replace List by Tuple or Sequence in "_output_padding" because
627646 # TorchScript does not support `Sequence[T]` or `Tuple[T, ...]`.
628647 num_spatial_dims = 2
629648 output_padding = self ._output_padding (
630- input , output_size , self .stride , self .padding , self .kernel_size , # type: ignore[arg-type]
631- num_spatial_dims , self .dilation ) # type: ignore[arg-type]
632-
633- n , _ , h , w = input .shape
634- conv2d_trans_ret = self .conv2d_transpose (input , self .weight ,
635- (n , self .out_channels ,
636- h * self .stride [0 ] + self .h_add ,
637- w * self .stride [1 ] + self .w_add ))
638- if self .bias is not None :
639- conv2d_trans_ret = mops .bias_add (conv2d_trans_ret , self .bias )
640-
641- conv2d_trans_ret = ops .pad (conv2d_trans_ret , output_padding , value = 0. )
642-
643- return conv2d_trans_ret
649+ input ,
650+ output_size ,
651+ self .stride , # type: ignore[arg-type]
652+ self .padding , # type: ignore[arg-type]
653+ self .kernel_size , # type: ignore[arg-type]
654+ num_spatial_dims ,
655+ self .dilation , # type: ignore[arg-type]
656+ )
644657
658+ return F .conv_transpose2d (
659+ input ,
660+ self .weight ,
661+ self .bias ,
662+ self .stride ,
663+ self .padding ,
664+ output_padding ,
665+ self .groups ,
666+ self .dilation ,
667+ )
645668
646669# class ConvTranspose3d(_ConvTransposeNd):
647670# r"""Applies a 3D transposed convolution operator over an input image composed of several input
0 commit comments