@@ -132,6 +132,7 @@ def get_fsdp_config(self):
         from functools import partial

         # Third Party
+        from accelerate.utils import FullyShardedDataParallelPlugin
         from peft.utils.other import fsdp_auto_wrap_policy
         from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
         from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
@@ -159,27 +160,17 @@ def get_fsdp_config(self):
         prefetch_policy = (
             BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
         )
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            auto_wrap_policy=wrap_policy,
+            limit_all_gathers=True,
+            backward_prefetch=prefetch_policy,
+            sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
+            cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
+        )

         if self.device_str == "hpu":
-            from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
-            fsdp_plugin = GaudiFullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrap_policy,
-                limit_all_gathers=True,
-                backward_prefetch=prefetch_policy,
-                sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
-                cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
-            )
             fsdp_plugin.use_orig_params = True
             fsdp_plugin.sync_module_states = True
-        else:
-            from accelerate.utils import FullyShardedDataParallelPlugin
-            fsdp_plugin = FullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrap_policy,
-                limit_all_gathers=True,
-                backward_prefetch=prefetch_policy,
-                sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
-                cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
-            )

         # `use_orig_params` must be disabled when using LoRA and FSDP together
         # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
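Note: the diff now builds a single FullyShardedDataParallelPlugin unconditionally and keeps only the HPU-specific attribute tweaks inside the device check. A minimal sketch of how the returned plugin might be consumed downstream, assuming the caller hands it to accelerate's Accelerator; the setup_accelerator helper and config argument are hypothetical names, not part of this PR:

# Minimal usage sketch (not from this PR).
# Only FullyShardedDataParallelPlugin and Accelerator(fsdp_plugin=...) are
# real accelerate APIs; the helper below is illustrative.
from accelerate import Accelerator


def setup_accelerator(config):
    # get_fsdp_config() is assumed to return the plugin built in the diff above.
    fsdp_plugin = config.get_fsdp_config()
    # Accelerator picks up FSDP settings from the plugin rather than env vars.
    return Accelerator(fsdp_plugin=fsdp_plugin)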