@@ -62,10 +62,10 @@ def inplace_copy(self, value):
6262 Args:
6363 value (Tensor): The tensor from which to copy the data.
6464 """
65- if use_pyboost :
65+ if use_pyboost () :
6666 return pyboost .inplace_copy_op (self , value )
6767 else :
68- self . assign_value ( value )
68+ legacy . assign ( self , value )
6969 return self
7070
7171def slice (input , dim , start , end , step ):
@@ -85,7 +85,15 @@ def slice(input, dim, start, end, step):
8585 if use_pyboost ():
8686 return pyboost .slice_ext_op (input , dim , start , end , step )
8787 else :
88- return legacy .slice (input , dim , start , end , step )
88+ ndim = input .ndim
89+ begins = [0 ] * ndim
90+ ends = [i for i in input .shape ]
91+ strides = [1 ] * ndim
92+ begins [dim ] = start
93+ ends [dim ] = end
94+ strides [dim ] = step
95+ return legacy .strided_slice (input , tuple (begins ), tuple (ends ), tuple (strides ), 0 , 0 , 0 , 0 , 0 )
96+
8997
9098def embedding (input , weight , padding_idx , max_norm , norm_type , scale_grad_by_freq ):
9199 """
@@ -829,7 +837,7 @@ def bmm(input, other):
829837 return legacy .batch_mat_mul (input , other , False , False )
830838
831839def topk (input , k , dim , largest , sorted ):
832- if use_pyboost ():
840+ if use_pyboost () and not ON_ORANGE_PI :
833841 return pyboost .topk_ext_op (input , k , dim , largest , sorted )
834842
835843 if not largest :
@@ -1296,9 +1304,9 @@ def remainder_tensor_scalar(input, other):
12961304 return out
12971305
12981306def baddbmm (input , batch1 , batch2 , alpha = 1 , beta = 1 ):
1299- if use_pyboost ():
1307+ if use_pyboost () and not ON_ORANGE_PI :
13001308 return pyboost .baddbmm_op (input , batch1 , batch2 , alpha , beta )
1301- return legacy . baddbmm ( input , batch1 , batch2 , alpha , beta )
1309+ return add ( mul ( input , beta ), mul ( bmm ( batch1 , batch2 ) , alpha ) )
13021310
13031311def floor (input ):
13041312 if use_pyboost ():
@@ -1844,4 +1852,16 @@ def cumprod(input, dim, dtype):
18441852 out = legacy .cum_prod (input , dim , False , False )
18451853 if dtype is not None :
18461854 out = cast (out , dtype )
1847- return out
1855+ return out
1856+
1857+ def scatter_nd_update (input , indices , updates ):
1858+ return legacy .scatter_nd_update (input , indices , updates , True )
1859+
1860+ def assign (input , value ):
1861+ return inplace_copy (input , value )
1862+
1863+ def strided_slice (input , begin , end , strides , begin_mask = 0 , end_mask = 0 , ellipsis_mask = 0 , new_axis_mask = 0 , shrink_axis_mask = 0 ):
1864+ return legacy .strided_slice (input , tuple (begin ), tuple (end ), tuple (strides ), begin_mask , end_mask , ellipsis_mask , new_axis_mask , shrink_axis_mask )
1865+
1866+ def tensor_scatter_update (input , indices , updates ):
1867+ return legacy .tensor_scatter_update (input , indices , updates )
0 commit comments