45
45
46
46
import safetensors
47
47
import safetensors .numpy
48
+ from safetensors import deserialize
48
49
49
50
from mindnlp .core import nn
50
51
from mindnlp .core .nn import Parameter
@@ -1575,6 +1576,48 @@ def get_tensor(self, name):
1575
1576
return self .tensors [name ].get ()
1576
1577
1577
1578
1579
+ def legacy_safe_load_file (filename ):
1580
+ """
1581
+ This function safely loads a file containing state dictionary data and converts it into a dictionary of MindSpore Parameters.
1582
+
1583
+ Args:
1584
+ filename (str): The path to the file containing the state dictionary data to be loaded.
1585
+
1586
+ Returns:
1587
+ dict: A dictionary where keys are parameter names and values are MindSpore Parameters.
1588
+
1589
+ Raises:
1590
+ FileNotFoundError: If the specified file 'filename' does not exist.
1591
+ ValueError: If the data in the file is not in the correct format to create MindSpore Parameters.
1592
+ """
1593
+ with open (filename , "rb" ) as f :
1594
+ data = f .read ()
1595
+
1596
+ safeview = deserialize (data )
1597
+
1598
+ result = {}
1599
+ try :
1600
+ for k , v in safeview :
1601
+ dtype = _MS_TYPES [v ["dtype" ]]
1602
+ if (not SUPPORT_BF16 and dtype != mindspore .bfloat16 ) or SUPPORT_BF16 :
1603
+ arr = Tensor .convert_bytes_to_tensor (bytes (v ["data" ]), tuple (v ["shape" ]), dtype )
1604
+ result [k ] = Tensor (arr )
1605
+ else :
1606
+ raise TypeError ('Do not support bfloat16 on current device, use numpy as convert buffer to boost load.' )
1607
+ return result
1608
+
1609
+ except Exception as e :
1610
+ for k , v in safeview :
1611
+ dtype = _NP_TYPES [v ["dtype" ]]
1612
+ arr = np .frombuffer (v ["data" ], dtype = dtype ).reshape (v ["shape" ])
1613
+
1614
+ if (not SUPPORT_BF16 and dtype != bfloat16 ) or SUPPORT_BF16 :
1615
+ result [k ] = Tensor .from_numpy (arr )
1616
+ else :
1617
+ result [k ] = Tensor .from_numpy (arr .astype (np .float16 ))
1618
+ return result
1619
+
1620
+
1578
1621
def safe_load_file (filename ):
1579
1622
"""
1580
1623
This function safely loads a file containing state dictionary data and converts it into a dictionary of MindSpore Parameters.
@@ -1591,9 +1634,12 @@ def safe_load_file(filename):
1591
1634
"""
1592
1635
1593
1636
result = {}
1594
- with fast_safe_open (filename , framework = "np" ) as f :
1595
- for k in f .keys ():
1596
- result [k ] = f .get_tensor (k )
1637
+ try :
1638
+ with fast_safe_open (filename , framework = "np" ) as f :
1639
+ for k in f .keys ():
1640
+ result [k ] = f .get_tensor (k )
1641
+ except Exception as e :
1642
+ result = legacy_safe_load_file (filename )
1597
1643
return result
1598
1644
1599
1645
0 commit comments