@@ -117,7 +117,7 @@ def add(input, other, alpha=1.0): # pylint: disable=unused-argument
117117 Returns:
118118 Tensor: The result of the addition.
119119 """
120- if use_pyboost ():
120+ if use_pyboost () and not ON_ORANGE_PI :
121121 return pyboost .add_ext_op (input , other , alpha )
122122 if alpha == 1.0 :
123123 return legacy .add (input , other )
@@ -724,7 +724,7 @@ def less_equal(input, other):
724724
725725def select (condition , input , other ):
726726 if ON_ORANGE_PI :
727- return add (mul (condition , input ), mul (bitwise_not (condition ), other ))
727+ return legacy . add (mul (condition , input ), mul (bitwise_not (condition ), other ))
728728 if use_pyboost ():
729729 return pyboost .select_op (condition , input , other )
730730 return legacy .select (condition , input , other )
@@ -975,8 +975,23 @@ def inplace_zero(input):
975975 return input
976976
977977def mse_loss (input , target , reduction ):
978- if use_pyboost ():
978+ if use_pyboost () and not ON_ORANGE_PI :
979979 return pyboost .mse_loss_ext_op (input , target , reduction )
980+ x = square (input - target )
981+ average_flag = True
982+ reduce_flag = True
983+ if reduction == 'sum' :
984+ average_flag = False
985+ if reduction == 'none' :
986+ reduce_flag = False
987+
988+ if reduce_flag and average_flag :
989+ x = mean (x , tuple (range (x .ndim )), False , None )
990+
991+ if reduce_flag and not average_flag :
992+ x = sum (x , tuple (range (x .ndim )), False , None )
993+
994+ return x
980995
981996def abs (input ):
982997 if use_pyboost ():
@@ -1126,7 +1141,7 @@ def pow_scalar_tensor(input, scalar):
11261141 return legacy .pow (input , scalar )
11271142
11281143def adaptive_avg_pool2d (input , output_size ):
1129- if use_pyboost ():
1144+ if use_pyboost () and not ON_ORANGE_PI :
11301145 return pyboost .adaptive_avg_pool2d_ext_op (input , output_size )
11311146 return legacy .adaptive_avg_pool2_d (input , output_size )
11321147
@@ -1362,6 +1377,21 @@ def _check_maxpool_padding(padding, nd):
13621377 return (0 ,) * (3 - nd ) + tuple (padding )
13631378 return padding
13641379
1380+ def _cal_dilation (dilation , nd ):
1381+ """check the dilation"""
1382+ if isinstance (dilation , int ):
1383+ return dilation
1384+ if isinstance (dilation , tuple ):
1385+ if len (dilation ) == 1 :
1386+ return dilation [0 ]
1387+ if len (dilation ) == nd :
1388+ return (3 - nd ) * (1 ,) + dilation
1389+ if nd == 1 :
1390+ raise ValueError (f"the length of 'dilation' must be 1, but got { len (dilation )} ." )
1391+ raise ValueError (f"the length of 'dilation' must be 1 or { nd } , but got { len (dilation )} ." )
1392+ raise ValueError (f"the 'dilation' must be int or tuple, but got { type (dilation )} ." )
1393+
1394+
13651395def max_pool2d (input , kernel_size , stride = 1 , padding = 0 , dilation = 1 , ceil_mode = False , return_indices = False ):
13661396 # out, indices = legacy.max_pool_with_argmax_v2(input, kernel_size, stride, padding, dilation, ceil_mode)
13671397 if not ON_ORANGE_PI :
@@ -1379,6 +1409,7 @@ def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=Fa
13791409 elif isinstance (stride , int ):
13801410 stride = (1 , stride , stride )
13811411 padding = _check_maxpool_padding (padding , 2 )
1412+ dilation = _cal_dilation (dilation , 2 )
13821413
13831414 input = expand_dims (input , 2 )
13841415 out , indices = legacy .max_pool3_d_with_argmax (input , kernel_size , stride , padding ,
@@ -1550,9 +1581,9 @@ def outer(input, other):
15501581 return legacy .outer (input , other )
15511582
15521583def addcmul (input , tensor1 , tensor2 , value = 1.0 ):
1553- if use_pyboost ():
1584+ if use_pyboost () and not ON_ORANGE_PI :
15541585 return pyboost .addcmul_op (input , tensor1 , tensor2 , value )
1555- return legacy .addcmul ( input , tensor1 , tensor2 , value )
1586+ return legacy .add ( mul ( mul ( tensor1 , tensor2 ) , value ), input )
15561587
15571588def prelu (input , weight ):
15581589 if use_pyboost ():
0 commit comments