Skip to content

Commit 0b686cd

Browse files
authored
[Accelerate] Remove is_module_offloaded and update_prefix_dict (#366)
Signed-off-by: Kyle Sayers <[email protected]>
1 parent ac2f257 commit 0b686cd

File tree

1 file changed

+0
-26
lines changed

1 file changed

+0
-26
lines changed

src/compressed_tensors/utils/offload.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,8 @@
6868

6969

7070
__all__ = [
71-
"is_module_offloaded",
7271
"get_execution_device",
7372
"get_offloaded_device",
74-
"update_prefix_dict",
7573
"update_parameter_data",
7674
"register_offload_parameter",
7775
"update_offload_parameter",
@@ -117,11 +115,6 @@ def fallback_fn(*args, **kwargs):
117115
""" Candidates for Depreciation """
118116

119117

120-
@check_accelerate(fallback=False)
121-
def is_module_offloaded(module: torch.nn.Module) -> bool:
122-
return has_offloaded_params(module)
123-
124-
125118
def get_offloaded_device(module: torch.nn.Module) -> torch.device:
126119
"""
127120
:param module: module to check
@@ -137,25 +130,6 @@ def get_offloaded_device(module: torch.nn.Module) -> torch.device:
137130
return get_execution_device(module)
138131

139132

140-
@check_accelerate(fallback=None)
141-
def update_prefix_dict(module: torch.nn.Module, key: str, data: torch.Tensor):
142-
"""
143-
Updates the offloaded state dict for a given module. Parameter named key is replaced
144-
by data. This is neccesary because parameter updates for offloaded modules do not
145-
persist automatically between loads. This function only affects the offloaded
146-
state dict and not the current state of the loaded module.
147-
148-
:param module: module containing the parameter to update
149-
:param key: name of parameter to update
150-
:param data: tensor to update parameter with in the offloaded state dict
151-
"""
152-
if not has_offloaded_params(module):
153-
raise ValueError("Prefix dict is only applicable to offloaded modules")
154-
155-
weights_map = module._hf_hook.weights_map
156-
offload_to_weights_map(weights_map, key, data)
157-
158-
159133
def update_parameter_data(
160134
module: torch.nn.Module, new_param_data: torch.Tensor, param_name: str
161135
):

0 commit comments

Comments
 (0)