Skip to content

Commit 8955546

Browse files
authored
fix b class models on OrangePi (#2210)
1 parent 1d3b326 commit 8955546

File tree

4 files changed

+33
-12
lines changed

4 files changed

+33
-12
lines changed

mindtorch/_apis/npu.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def inplace_copy(self, value):
6262
Args:
6363
value (Tensor): The tensor from which to copy the data.
6464
"""
65-
if use_pyboost:
65+
if use_pyboost():
6666
return pyboost.inplace_copy_op(self, value)
6767
else:
68-
self.assign_value(value)
68+
legacy.assign(self, value)
6969
return self
7070

7171
def slice(input, dim, start, end, step):
@@ -85,7 +85,15 @@ def slice(input, dim, start, end, step):
8585
if use_pyboost():
8686
return pyboost.slice_ext_op(input, dim, start, end, step)
8787
else:
88-
return legacy.slice(input, dim, start, end, step)
88+
ndim = input.ndim
89+
begins = [0] * ndim
90+
ends = [i for i in input.shape]
91+
strides = [1] * ndim
92+
begins[dim] = start
93+
ends[dim] = end
94+
strides[dim] = step
95+
return legacy.strided_slice(input, tuple(begins), tuple(ends), tuple(strides), 0, 0, 0, 0, 0)
96+
8997

9098
def embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq):
9199
"""
@@ -829,7 +837,7 @@ def bmm(input, other):
829837
return legacy.batch_mat_mul(input, other, False, False)
830838

831839
def topk(input, k, dim, largest, sorted):
832-
if use_pyboost():
840+
if use_pyboost() and not ON_ORANGE_PI:
833841
return pyboost.topk_ext_op(input, k, dim, largest, sorted)
834842

835843
if not largest:
@@ -1296,9 +1304,9 @@ def remainder_tensor_scalar(input, other):
12961304
return out
12971305

12981306
def baddbmm(input, batch1, batch2, alpha=1, beta=1):
1299-
if use_pyboost():
1307+
if use_pyboost() and not ON_ORANGE_PI:
13001308
return pyboost.baddbmm_op(input, batch1, batch2, alpha, beta)
1301-
return legacy.baddbmm(input, batch1, batch2, alpha, beta)
1309+
return add(mul(input, beta), mul(bmm(batch1, batch2), alpha))
13021310

13031311
def floor(input):
13041312
if use_pyboost():
@@ -1844,4 +1852,16 @@ def cumprod(input, dim, dtype):
18441852
out = legacy.cum_prod(input, dim, False, False)
18451853
if dtype is not None:
18461854
out = cast(out, dtype)
1847-
return out
1855+
return out
1856+
1857+
def scatter_nd_update(input, indices, updates):
1858+
return legacy.scatter_nd_update(input, indices, updates, True)
1859+
1860+
def assign(input, value):
1861+
return inplace_copy(input, value)
1862+
1863+
def strided_slice(input, begin, end, strides, begin_mask=0, end_mask=0, ellipsis_mask=0, new_axis_mask=0, shrink_axis_mask=0):
1864+
return legacy.strided_slice(input, tuple(begin), tuple(end), tuple(strides), begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask)
1865+
1866+
def tensor_scatter_update(input, indices, updates):
1867+
return legacy.tensor_scatter_update(input, indices, updates)

mindtorch/_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class StubTensor: pass
2323
from ._bind import get_device_in_context, device_, get_default_dtype
2424
from ._utils import _rebuild_tensor_v2
2525
from ._C.size import Size
26-
from .configs import DEVICE_TARGET, cpu_use_numpy
26+
from .configs import DEVICE_TARGET, cpu_use_numpy, ON_ORANGE_PI
2727

2828
device_map = {
2929
'cpu': 'CPU',
@@ -282,7 +282,7 @@ def __setitem__(self, slices, value):
282282
if value.device != self.device:
283283
value._device = self.device
284284

285-
if self.device.type == 'npu':
285+
if self.device.type == 'npu' and not ON_ORANGE_PI:
286286
if value.device != self.device:
287287
value._device = self.device
288288
out = ops.tensor_setitem(self, slices, value)
@@ -301,7 +301,7 @@ def __iadd__(self, other):
301301
return self.copy_(ops.add(self, other))
302302

303303
def __radd__(self, other):
304-
return Tensor.__add__(other, self)
304+
return ops.add(other, self)
305305

306306
def __div__(self, other):
307307
# if 0 in self.shape:

mindtorch/nn/modules/activation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1804,7 +1804,7 @@ def forward(self, input: Tensor) -> Tensor:
18041804
"""
18051805
Runs the forward pass.
18061806
"""
1807-
return F.softmax(input, self.dim, _stacklevel=5)
1807+
return F.softmax(input, self.dim)
18081808

18091809
def extra_repr(self) -> str:
18101810
"""

mindtorch/ops/array.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,7 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
634634
raise TypeError(f"Index {index} contain unsupported elements")
635635
self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index(
636636
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape)
637+
637638
return self_viewed, remain_indexes
638639

639640

@@ -1162,7 +1163,7 @@ def strided_slice_update(x, begin, end, strides, updates,
11621163
# for i in range(ndim-1, -1, -1):
11631164
# if (shrink_axis_mask >> i) & 1:
11641165
# x_updated = mindtorch.squeeze(x_updated, dim=i)
1165-
1166+
11661167
return x_updated
11671168

11681169

0 commit comments

Comments
 (0)