Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion mindnlp/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .ops import *
from .serialization import load, save
from ._bind import get_default_dtype, set_default_dtype
from .amp import autocast, GradScaler

from . import profiler, cuda, optim, amp, compiler, jit, version, __future__, overrides, \
return_types, linalg, fx, backends, testing
Expand Down Expand Up @@ -79,4 +80,18 @@ def wrap_func(fn):
return fn
if fn is not None:
return wrap_func(fn)
return wrap_func
return wrap_func

# Default dtype used by autocast for each supported device type.
# NOTE(review): "AUTO_CAST_DTYE" is a misspelling of "DTYPE", but the name is
# kept as-is because it is module-level and may be referenced by external code.
AUTO_CAST_DTYE = {
    'cuda': float16,
    'cpu': float16,
    'npu': float16,
    'Ascend': float16
}

def set_autocast_dtype(device_type, dtype):
    """Set the dtype that autocast uses for ``device_type``.

    Args:
        device_type: one of 'cuda', 'cpu', 'npu', 'Ascend'.
        dtype: the dtype autocast should cast to on that device.

    Raises:
        AssertionError: if ``device_type`` is not a supported device.
    """
    # membership test directly on the dict (idiomatic; `.keys()` is redundant)
    assert device_type in AUTO_CAST_DTYE, \
        f'{device_type} is not in {AUTO_CAST_DTYE.keys()}'
    AUTO_CAST_DTYE[device_type] = dtype

def get_autocast_dtype(device_type):
    """Return the dtype autocast currently uses for ``device_type``.

    Raises:
        KeyError: if ``device_type`` is not a supported device.
    """
    return AUTO_CAST_DTYE[device_type]
4 changes: 4 additions & 0 deletions mindnlp/core/_dtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
def is_floating_point(self):
return isinstance(self, (typing.Float, typing.BFloat16))

def is_complex(self):
    """Return True if this dtype is a complex type.

    Bound onto mindspore's ``Type`` below; ``typing`` here is mindspore's
    dtype typing module (imported elsewhere in this file), not stdlib typing.
    """
    return isinstance(self, typing.Complex)

Type.is_floating_point = is_floating_point
Type.is_complex = is_complex
Type.__str__ = Type.__repr__

@property
Expand Down
51 changes: 46 additions & 5 deletions mindnlp/core/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ class StubTensor: pass
mindspore.float16: 2,
}

# Map mindspore ``device_target`` names to torch-style device strings,
# used when constructing ``device_`` objects for Tensor.device below.
DEVICE_MAP = {
    'GPU': 'cuda',
    'Ascend': 'npu',
    'CPU': 'cpu'
}

class TypedTensorMeta(_TensorMeta):
def __isinstancecheck__(self, instance):
if not isinstance(instance, Tensor):
Expand Down Expand Up @@ -77,11 +83,13 @@ def tensor(data, *, dtype=None, device=None, requires_grad=False):
if device is None:
device = get_default_device()

data_np = np.array(data, order='C') # must be C for mindspore Tensor
if dtype is not None:
data_np = data_np.astype(dtype2np[dtype])
tensor = Tensor(data, dtype=dtype)
else:
tensor = Tensor(data)

tensor = Tensor(data_np).to(device)
tensor = tensor.to(device)
tensor.requires_grad_(requires_grad)
return tensor

def is_tensor(x):
Expand Down Expand Up @@ -144,8 +152,8 @@ def data_ptr(self):
Tensor.data_ptr = data_ptr
StubTensor.data_ptr = data_ptr

Tensor.device = device_('not support yet.')
StubTensor.device = device_('not support yet.')
Tensor.device = device_(DEVICE_MAP[mindspore.get_context('device_target')])
StubTensor.device = device_(DEVICE_MAP[mindspore.get_context('device_target')])

def _expand(self, *size):
if len(size) == 1:
Expand Down Expand Up @@ -207,6 +215,25 @@ def __getitem__(self, slices):
Tensor.__getitem__ = __getitem__
StubTensor.__getitem__ = __getitem__

# Keep a reference to the original implementation before patching.
origin_setitem = Tensor.__setitem__

def __setitem__(self, slices, value):
    """Tensor.__setitem__ wrapper that clamps float ``+/-inf`` assignments
    to the dtype's finite max/min before delegating to the original
    implementation (presumably because some backends reject storing inf
    here — TODO confirm). Removed the commented-out range-slice handling
    that was dead code.
    """
    if isinstance(value, float):
        if value == float('inf'):
            value = ops.finfo(self.dtype).max
        elif value == -float('inf'):
            value = ops.finfo(self.dtype).min
    return origin_setitem(self, slices, value)

Tensor.__setitem__ = __setitem__
StubTensor.__setitem__ = __setitem__

def numel(self):
return math.prod(self.shape)

Expand Down Expand Up @@ -338,6 +365,20 @@ def data(self, new_value):
Tensor.narrow = ops.narrow
StubTensor.narrow = ops.narrow

def bitwise_or_(self, other):
    """In-place bitwise OR: writes ``ops.bitwise_or(self, other)`` back into
    ``self`` via ``copy_`` and returns ``self`` (torch in-place convention)."""
    self.copy_(ops.bitwise_or(self, other))
    return self

for _cls in (Tensor, StubTensor):
    _cls.bitwise_or_ = bitwise_or_
    _cls.masked_fill = ops.masked_fill

# fix TypeError: unhashable type: 'StubTensor'
StubTensor.__hash__ = Tensor.__hash__


def _rebuild_from_type_v2(func, new_type, args, state):
ret = func(*args)
Expand Down
1 change: 1 addition & 0 deletions mindnlp/core/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import cuda, mps
5 changes: 5 additions & 0 deletions mindnlp/core/backends/mps/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def is_available():
    """Always False: the MPS (Apple Metal) backend is not supported here."""
    return False

def is_built():
    """Always False: this shim is never compiled with MPS support."""
    return False
80 changes: 48 additions & 32 deletions mindnlp/core/nn/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
from typing import Optional, Tuple, List
import numpy as np
import mindspore
from mindspore import ops, mint
from mindspore.ops._primitive_cache import _get_cache_prim

Expand All @@ -11,16 +12,15 @@

generator_step_ = 12

def gelu(input, approximate='none'):
if approximate == 'tanh':
return execute('gelu', input)
return input * 0.5 * (1.0 + core.erf(input / math.sqrt(2.0)))
def gelu(input, *, approximate='none'):
    """GELU activation. Dispatches to the mint kernel under pyboost,
    otherwise falls back to the legacy ops implementation."""
    if not use_pyboost():
        return ops.gelu(input, approximate)
    return mint.nn.functional.gelu(input, approximate=approximate)

def relu(input, inplace=False):
if inplace:
execute('inplace_relu', input)
return input
return execute('relu', input)
if use_pyboost():
return mint.nn.functional.relu(input)
return ops.relu(input)

def tanh(input, inplace=False):
if use_pyboost():
Expand All @@ -29,10 +29,16 @@ def tanh(input, inplace=False):


def sigmoid(input):
return execute('sigmoid', input)
if use_pyboost() and not ON_ORANGE_PI:
return mint.nn.functional.sigmoid(input)
return ops.sigmoid(input)

def silu(input):
return execute('silu', input)
def silu(input, inplace=False):
    """SiLU (swish) activation: ``input * sigmoid(input)``.

    NOTE(review): ``inplace`` is accepted for torch API compatibility but is
    ignored by every branch below — confirm no caller relies on it.
    """
    # CPU / OrangePi: compute via the explicit formula instead of the kernel
    if DEVICE_TARGET == 'CPU' or ON_ORANGE_PI:
        return input * sigmoid(input)
    if use_pyboost():
        return mint.nn.functional.silu(input)
    return ops.silu(input)

def mish(input):
return ops.mish(input)
Expand All @@ -54,7 +60,10 @@ def softplus(input, beta=1, threshold=20):
return ops.softplus(input, beta, threshold)

def logsigmoid(input):
return execute('logsigmoid', input)
if use_pyboost():
return mint.nn.functional.logsigmoid(input)
return ops.logsigmoid(input)


def leaky_relu(input, alpha=0.2):
if use_pyboost():
Expand Down Expand Up @@ -221,13 +230,13 @@ def embedding(input, weight, padding_idx=None, max_norm=None, norm_type=2.0, sca
return ops.gather(weight, input, 0)

def rms_norm(input, normalized_shape, weight, eps=1e-5):
return execute('rms_norm', input, weight, eps)[0]
return ops.rms_norm(input, weight, eps)[0]

def fast_gelu(x):
return ops.fast_gelu(x)

def swiglu(x, dim=-1):
return execute('swiglu', x, dim)
return ops.swiglu(x, dim)

def apply_rotary_pos_emb(query, key, cos, sin, position_ids, cos_format=0):
return mindspore.ops.auto_generate.gen_ops_def.apply_rotary_pos_emb_(
Expand Down Expand Up @@ -361,8 +370,8 @@ def l1_loss(input, target, reduction='mean'):
return ops.l1_loss(input, target, reduction)

def smooth_l1_loss(input, target, beta=1.0, reduction='none'):
input = input.to(mindspore.float32)
target = target.to(mindspore.float32)
input = input.to(core.float32)
target = target.to(core.float32)
return ops.smooth_l1_loss(input, target, beta, reduction)

def kl_div(logits, labels, reduction='mean', log_target=False):
Expand Down Expand Up @@ -634,25 +643,32 @@ def _in_projection_packed(
b_q, b_k, b_v = b.chunk(3)
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

def scaled_dot_product_attention(query, key, value, attn_mask, dropout_p, is_causal):
embed_size = query.shape[-1]
scaling_factor = ops.sqrt(ops.sqrt(core.Tensor(embed_size, dtype=query.dtype)))
query = query / scaling_factor

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
is_causal=False, scale=None, enable_gqa=False) -> core.Tensor:
L, S = query.size(-2), key.size(-2)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
attn_bias = core.zeros(L, S, dtype=query.dtype, device=query.device)
if is_causal:
L = query.shape[-2]
S = key.shape[-2]
attn_mask = ops.ones((L, S), core.bool_).tril()
assert attn_mask is None
temp_mask = core.ones(L, S, dtype=core.bool).tril(diagonal=0)
attn_bias = attn_bias.masked_fill_(temp_mask.logical_not(), core.finfo(attn_bias.dtype).min)
attn_bias.to(query.dtype)

attn = ops.matmul(query, key.swapaxes(-2, -1) / scaling_factor)
if attn_mask is not None:
attn = attn + attn_mask
attn = softmax(attn, -1)
if dropout_p > 0.:
attn = ops.dropout(attn, dropout_p)
output = ops.matmul(attn, value)
if attn_mask.dtype == core.bool:
attn_bias = attn_bias.masked_fill_(attn_mask.logical_not(), core.finfo(attn_bias.dtype).min)
else:
attn_bias = attn_mask + attn_bias

return output
if enable_gqa:
key = key.repeat_interleave(query.size(-3) // key.size(-3), -3)
value = value.repeat_interleave(query.size(-3) // value.size(-3), -3)

attn_weight = query @ key.transpose(-2, -1) * scale_factor
attn_weight += attn_bias
attn_weight = softmax(attn_weight, dim=-1)
attn_weight = dropout(attn_weight, dropout_p, training=True)
return attn_weight @ value


def _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads):
Expand Down Expand Up @@ -1197,4 +1213,4 @@ def make_causal_mask(
)

def rotary_position_embedding(x, cos, sin, mode=0):
return execute('rotary_position_embedding', x, cos, sin, mode)
return ops.rotary_position_embedding(x, cos, sin, mode)
14 changes: 11 additions & 3 deletions mindnlp/core/ops/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@
from mindspore.ops.operations._grad_ops import StridedSliceGrad

from ..configs import use_pyboost, ON_ORANGE_PI
from .other import broadcast_tensors
from .other import broadcast_tensors, finfo
from ._inner import call_ms_func
from .creation import arange

# adjoint

Expand All @@ -31,7 +30,11 @@ def cat(tensors, dim=0, *, out=None, **kwargs):

# concat
has_concat = hasattr(mindspore.mint, 'concat')
def concat(tensors, dim=0, *, out=None):
def concat(tensors, dim=0, *, out=None, **kwargs):
    """Alias of ``cat`` that also accepts a numpy/mindspore-style ``axis``
    keyword. ``axis`` may only be supplied when ``dim`` is left at its
    default of 0.
    """
    axis = kwargs.pop('axis', None)
    # `is not None` rather than truthiness: axis=0 is a valid explicit value
    # that the previous `if axis:` check silently ignored.
    if axis is not None:
        assert dim == 0, "Can not set `axis` and `dim` at same time."
        dim = axis
    return cat(tensors, dim, out=out)

# concatenate
Expand Down Expand Up @@ -372,6 +375,11 @@ def where(condition, *args, out=None):
return nonzero(condition, as_tuple=True)
assert len(args) == 2
input, other = args
if isinstance(input, float) and input == -float("inf"):
input = finfo(other.dtype).min
if isinstance(other, float) and other == -float("inf"):
other = finfo(input.dtype).min

output = mindspore.mint.where(condition, input, other)
if out is not None:
out.assign_value(output)
Expand Down
12 changes: 9 additions & 3 deletions mindnlp/core/ops/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,8 +791,6 @@ def tril(input, diagonal=0, *, out=None):

# triu
has_triu = hasattr(mindspore.mint, "triu")


def triu(input, diagonal=0, *, out=None):
if use_pyboost() and has_triu:
return call_ms_func(mindspore.mint.triu, input, diagonal, out=out)
Expand Down Expand Up @@ -821,8 +819,16 @@ def unflatten(x, dim, sizes):

# resolve_neg


# Feature-detect the mint kernel once at import time.
has_masked_fill = hasattr(mindspore.mint, "masked_fill")
def masked_fill(input, mask, value):
    """Fill elements of ``input`` where ``mask`` is True with ``value``.

    Float ``+/-inf`` fill values are clamped to the dtype's finite min/max
    (presumably because downstream kernels reject inf — TODO confirm).
    """
    if isinstance(value, float):
        if value == -float('inf'):
            value = finfo(input.dtype).min
        if value == float('inf'):
            value = finfo(input.dtype).max

    if has_masked_fill:
        return mindspore.mint.masked_fill(input, mask, value)
    # fallback: cached primitive, value promoted to the input's dtype
    masked_fill_ = _get_cache_prim(ops.MaskedFill)()
    return masked_fill_(input, mask, mindspore.tensor(value, dtype=input.dtype))

Expand Down
7 changes: 6 additions & 1 deletion tests/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ def run_tests():
"""
# 获取命令行参数(排除脚本名本身)
pytest_args = sys.argv[1:]
# unsupported features: sdpa / loss.backward / torchscript / torch.fx / torch.compile
skip_ut = "not sdpa " \
"and not headmasking " \
"and not gradient_checkpointing " \
"and not retain_grad " \
"and not data_parallel"
"and not data_parallel " \
"and not with_static_cache " \
"and not compile " \
"and not compilation " \
"and not torchscript "

pytest_args.extend(['-k', skip_ut])
if not pytest_args:
Expand Down
Loading