Skip to content

Commit 6ed1cf3

Browse files
authored
fix safetensors loading bug and add some modifications based on orangepi (#2055)
1 parent 7b308e5 commit 6ed1cf3

File tree

3 files changed

+52
-4
lines changed

3 files changed

+52
-4
lines changed

mindnlp/core/nn/functional.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
def gelu(input, approximate='none'):
    """
    Apply the GELU activation to ``input``.

    The ``mindspore.mint`` implementation is used when ``use_pyboost()`` is
    truthy or when ``ON_ORANGE_PI`` is set; otherwise the call is forwarded
    to ``ops.gelu``.

    Args:
        input: The tensor to activate.
        approximate (str): Approximation mode, passed through unchanged to
            the selected backend. Defaults to ``'none'``.

    Returns:
        The result of the selected backend's ``gelu`` call.
    """
    # Both conditions select the same mint kernel; `use_pyboost()` is still
    # evaluated first, exactly as with two separate `if` statements.
    if use_pyboost() or ON_ORANGE_PI:
        return mindspore.mint.nn.functional.gelu(input, approximate=approximate)
    return ops.gelu(input, approximate)
1820

1921
def relu(input):

mindnlp/core/serialization.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545

4646
import safetensors
4747
import safetensors.numpy
48+
from safetensors import deserialize
4849

4950
from mindnlp.core import nn
5051
from mindnlp.core.nn import Parameter
@@ -1575,6 +1576,48 @@ def get_tensor(self, name):
15751576
return self.tensors[name].get()
15761577

15771578

1579+
def legacy_safe_load_file(filename):
    """
    Load a safetensors file by reading it fully into memory and deserializing it.

    Two conversion strategies are tried in order:

    1. A direct path that builds each tensor from its raw bytes with
       ``Tensor.convert_bytes_to_tensor``.
    2. A NumPy path that views each buffer with ``np.frombuffer`` and wraps it
       with ``Tensor.from_numpy``. It is used when the file contains bfloat16
       data while ``SUPPORT_BF16`` is false, or when the direct path fails for
       any reason. On this path bfloat16 arrays are cast to ``np.float16``
       if ``SUPPORT_BF16`` is false.

    Args:
        filename (str): Path of the safetensors file to load.

    Returns:
        dict: Maps each tensor name stored in the file to a ``Tensor``.

    Raises:
        OSError: If ``filename`` cannot be opened or read (this includes
            ``FileNotFoundError``).
        Exception: Errors raised by ``deserialize`` or by the NumPy path are
            not caught here and propagate to the caller.
    """
    import logging  # local import: keeps this block self-contained

    with open(filename, "rb") as f:
        data = f.read()

    safeview = deserialize(data)

    # --- Direct path -------------------------------------------------------
    # `direct` becomes None as soon as this path is known to be unusable.
    direct = {}
    try:
        for name, entry in safeview:
            ms_dtype = _MS_TYPES[entry["dtype"]]
            if not SUPPORT_BF16 and ms_dtype == mindspore.bfloat16:
                # Expected condition, not an error: bfloat16 cannot be built
                # directly on this device, so hand over to the NumPy path.
                direct = None
                break
            raw = Tensor.convert_bytes_to_tensor(
                bytes(entry["data"]), tuple(entry["shape"]), ms_dtype
            )
            direct[name] = Tensor(raw)
    except Exception as exc:  # deliberate best-effort: any failure -> NumPy path
        logging.getLogger(__name__).debug(
            "Direct tensor conversion failed for %s (%r); "
            "falling back to the NumPy path.", filename, exc
        )
        direct = None

    if direct is not None:
        return direct

    # --- NumPy path --------------------------------------------------------
    # Start from a fresh dict so nothing produced by an aborted direct path
    # can leak into the returned mapping.
    result = {}
    for name, entry in safeview:
        np_dtype = _NP_TYPES[entry["dtype"]]
        arr = np.frombuffer(entry["data"], dtype=np_dtype).reshape(entry["shape"])
        if not SUPPORT_BF16 and np_dtype == bfloat16:
            # No bfloat16 support on this device: store as float16 instead.
            arr = arr.astype(np.float16)
        result[name] = Tensor.from_numpy(arr)
    return result
1619+
1620+
15781621
def safe_load_file(filename):
15791622
"""
15801623
This function safely loads a file containing state dictionary data and converts it into a dictionary of MindSpore Parameters.
@@ -1591,9 +1634,12 @@ def safe_load_file(filename):
15911634
"""
15921635

15931636
result = {}
1594-
with fast_safe_open(filename, framework="np") as f:
1595-
for k in f.keys():
1596-
result[k] = f.get_tensor(k)
1637+
try:
1638+
with fast_safe_open(filename, framework="np") as f:
1639+
for k in f.keys():
1640+
result[k] = f.get_tensor(k)
1641+
except Exception as e:
1642+
result = legacy_safe_load_file(filename)
15971643
return result
15981644

15991645

mindnlp/engine/trainer/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def num_tokens(self, train_ds: 'mindspore.dataset.Dataset', max_steps: Optional[
594594
"""
595595
train_tokens = 0
596596
try:
597-
for step, batch in train_ds.create_dict_iterator():
597+
for step, batch in enumerate(train_ds.create_dict_iterator()):
598598
tokens = batch["input_ids"].numel()
599599
if max_steps is not None:
600600
return tokens * max_steps

0 commit comments

Comments
 (0)