68
68
69
69
70
70
__all__ = [
71
- "is_module_offloaded" ,
72
71
"get_execution_device" ,
73
72
"get_offloaded_device" ,
74
- "update_prefix_dict" ,
75
73
"update_parameter_data" ,
76
74
"register_offload_parameter" ,
77
75
"update_offload_parameter" ,
@@ -117,11 +115,6 @@ def fallback_fn(*args, **kwargs):
117
115
""" Candidates for Depreciation """
118
116
119
117
120
- @check_accelerate (fallback = False )
121
- def is_module_offloaded (module : torch .nn .Module ) -> bool :
122
- return has_offloaded_params (module )
123
-
124
-
125
118
def get_offloaded_device (module : torch .nn .Module ) -> torch .device :
126
119
"""
127
120
:param module: module to check
@@ -137,25 +130,6 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
137
130
return get_execution_device (module )
138
131
139
132
140
- @check_accelerate (fallback = None )
141
- def update_prefix_dict (module : torch .nn .Module , key : str , data : torch .Tensor ):
142
- """
143
- Updates the offloaded state dict for a given module. Parameter named key is replaced
144
- by data. This is neccesary because parameter updates for offloaded modules do not
145
- persist automatically between loads. This function only affects the offloaded
146
- state dict and not the current state of the loaded module.
147
-
148
- :param module: module containing the parameter to update
149
- :param key: name of parameter to update
150
- :param data: tensor to update parameter with in the offloaded state dict
151
- """
152
- if not has_offloaded_params (module ):
153
- raise ValueError ("Prefix dict is only applicable to offloaded modules" )
154
-
155
- weights_map = module ._hf_hook .weights_map
156
- offload_to_weights_map (weights_map , key , data )
157
-
158
-
159
133
def update_parameter_data (
160
134
module : torch .nn .Module , new_param_data : torch .Tensor , param_name : str
161
135
):
0 commit comments