Skip to content

Commit 9eed3af

Browse files
authored
fix bi* models on OrangePi (#2207)
1 parent a4a7f06 commit 9eed3af

File tree

3 files changed

+41
-8
lines changed

3 files changed

+41
-8
lines changed

mindtorch/_apis/npu.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def add(input, other, alpha=1.0): # pylint: disable=unused-argument
117117
Returns:
118118
Tensor: The result of the addition.
119119
"""
120-
if use_pyboost():
120+
if use_pyboost() and not ON_ORANGE_PI:
121121
return pyboost.add_ext_op(input, other, alpha)
122122
if alpha == 1.0:
123123
return legacy.add(input, other)
@@ -724,7 +724,7 @@ def less_equal(input, other):
724724

725725
def select(condition, input, other):
    """Element-wise select: where ``condition`` holds take ``input``, else ``other``."""
    if ON_ORANGE_PI:
        # OrangePi fallback: emulate select arithmetically as
        # cond * input + (~cond) * other, using legacy.add (plain add is
        # avoided here — presumably its ext kernel is broken on OrangePi).
        taken = mul(condition, input)
        rejected = mul(bitwise_not(condition), other)
        return legacy.add(taken, rejected)
    if use_pyboost():
        return pyboost.select_op(condition, input, other)
    return legacy.select(condition, input, other)
@@ -975,8 +975,23 @@ def inplace_zero(input):
975975
return input
976976

977977
def mse_loss(input, target, reduction):
    """Mean squared error loss: ``(input - target) ** 2`` reduced per ``reduction``.

    Args:
        input: prediction tensor.
        target: target tensor (shape compatible with ``input`` for ``-``).
        reduction: one of ``'mean'``, ``'sum'`` or ``'none'``.

    Returns:
        Tensor: a scalar for ``'mean'``/``'sum'``, the element-wise squared
        error for ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the accepted strings.
    """
    # Fused kernel path is skipped on OrangePi (broken there — see commit note).
    if use_pyboost() and not ON_ORANGE_PI:
        return pyboost.mse_loss_ext_op(input, target, reduction)
    # Fix: previously any unrecognized reduction string silently fell through
    # to the 'mean' behaviour; reject it explicitly instead.
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(f"reduction must be 'mean', 'sum' or 'none', but got {reduction!r}.")
    x = square(input - target)
    if reduction == 'mean':
        # Reduce over every axis -> scalar.
        x = mean(x, tuple(range(x.ndim)), False, None)
    elif reduction == 'sum':
        x = sum(x, tuple(range(x.ndim)), False, None)
    return x
980995

981996
def abs(input):
982997
if use_pyboost():
@@ -1126,7 +1141,7 @@ def pow_scalar_tensor(input, scalar):
11261141
return legacy.pow(input, scalar)
11271142

11281143
def adaptive_avg_pool2d(input, output_size):
    """Adaptive 2-D average pooling of ``input`` down to ``output_size``."""
    # The pyboost ext kernel is unusable on OrangePi, so fall back to legacy there.
    fused_ok = use_pyboost() and not ON_ORANGE_PI
    if fused_ok:
        return pyboost.adaptive_avg_pool2d_ext_op(input, output_size)
    return legacy.adaptive_avg_pool2_d(input, output_size)
11321147

@@ -1362,6 +1377,21 @@ def _check_maxpool_padding(padding, nd):
13621377
return (0,) * (3 - nd) + tuple(padding)
13631378
return padding
13641379

1380+
def _cal_dilation(dilation, nd):
1381+
"""check the dilation"""
1382+
if isinstance(dilation, int):
1383+
return dilation
1384+
if isinstance(dilation, tuple):
1385+
if len(dilation) == 1:
1386+
return dilation[0]
1387+
if len(dilation) == nd:
1388+
return (3 - nd) * (1,) + dilation
1389+
if nd == 1:
1390+
raise ValueError(f"the length of 'dilation' must be 1, but got {len(dilation)}.")
1391+
raise ValueError(f"the length of 'dilation' must be 1 or {nd}, but got {len(dilation)}.")
1392+
raise ValueError(f"the 'dilation' must be int or tuple, but got {type(dilation)}.")
1393+
1394+
13651395
def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=False, return_indices=False):
13661396
# out, indices = legacy.max_pool_with_argmax_v2(input, kernel_size, stride, padding, dilation, ceil_mode)
13671397
if not ON_ORANGE_PI:
@@ -1379,6 +1409,7 @@ def max_pool2d(input, kernel_size, stride=1, padding=0, dilation=1, ceil_mode=Fa
13791409
elif isinstance(stride, int):
13801410
stride = (1, stride, stride)
13811411
padding = _check_maxpool_padding(padding, 2)
1412+
dilation = _cal_dilation(dilation, 2)
13821413

13831414
input = expand_dims(input, 2)
13841415
out, indices = legacy.max_pool3_d_with_argmax(input, kernel_size, stride, padding,
@@ -1550,9 +1581,9 @@ def outer(input, other):
15501581
return legacy.outer(input, other)
15511582

15521583
def addcmul(input, tensor1, tensor2, value=1.0):
    """Return ``input + value * tensor1 * tensor2``."""
    if use_pyboost() and not ON_ORANGE_PI:
        return pyboost.addcmul_op(input, tensor1, tensor2, value)
    # Compose from primitive ops for the OrangePi / non-pyboost path.
    scaled_product = mul(mul(tensor1, tensor2), value)
    return legacy.add(scaled_product, input)
15561587

15571588
def prelu(input, weight):
15581589
if use_pyboost():

mindtorch/_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ def __setitem__(self, slices, value):
287287
def __add__(self, other):
    """Implement ``self + other``."""
    # if 0 in self.shape:
    #     return self
    # Boolean tensors are combined with bitwise OR instead of arithmetic
    # add — NOTE(review): presumably mirrors torch's bool-add semantics.
    is_bool = self.dtype == mindtorch.bool
    return ops.bitwise_or(self, other) if is_bool else ops.add(self, other)
291293

292294
def __iadd__(self, other):

mindtorch/nn/functional.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from mindtorch._C import default_generator
1010
from mindtorch.nn.modules.utils import _pair
1111

12-
from ..configs import ON_A2, ON_A1, FLASH_ATTN_MASK_VALID
12+
from ..configs import ON_A2, ON_A1, ON_ORANGE_PI, FLASH_ATTN_MASK_VALID
1313

1414
generator_step_ = 12
1515

@@ -991,7 +991,7 @@ def group_norm(input, num_groups, weight=None, bias=None, eps=1e-5):
991991
weight = mindtorch.ones([input.shape[1]], dtype=input.dtype, device=input.device)
992992
if bias is None:
993993
bias = mindtorch.zeros([input.shape[1]], dtype=input.dtype, device=input.device)
994-
if input.device.type == 'npu':
994+
if input.device.type == 'npu' and not ON_ORANGE_PI:
995995
return execute('group_norm', input, num_groups, weight, bias, eps)[0]
996996

997997
input_shape = input.shape

0 commit comments

Comments
 (0)