Commit 9f32ec2

Update OH to v1.18.0
1 parent d1cada3 · commit 9f32ec2

2 files changed (+26 -22 lines)


docs/hpu.md

Lines changed: 18 additions & 5 deletions
@@ -10,7 +10,7 @@ Next changes are required to enable training on HPU:
 
 It is also recommended to use HPU optimized versions of transformers:
 
-```python
+```Python
 from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
 adapt_transformers_to_gaudi()
 ```
@@ -29,8 +29,21 @@ To compute bucked size, we use next algorithm:
 This approach limits overhead of the bucketing to 1/16 th of the longest sample and allows us to significantly reduce number of recompilations.
 
 ## How to run
-To run training make next changes to config file:
-```json
+To run training build docker using next dockerfile:
+```Dockerfile
+FROM vault.habana.ai/gaudi-docker/1.21.0/rhel9.4/habanalabs/pytorch-installer-2.6.0:1.21.0-555
+
+ARG CMAKE_ARGS="-DGGML_NATIVE=off"
+
+WORKDIR /app
+RUN pip install git+https://github.com/instructlab/[email protected]
+
+WORKDIR /app
+RUN pip install git+https://github.com/huggingface/[email protected]
+```
+
+Then make next changes to config file:
+```YAML
 train:
   device: hpu
   distributed_backend: fsdp
@@ -40,8 +53,8 @@ train:
   disable_flash_attn: true
 ```
 
-And use this command line:
-```bash
+And finally run this command line:
+```BASH
 ilab --config=./config.yaml model train --pipeline accelerated --data-path ./data.jsonl
 ```
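
The hunk above only quotes the tail of the bucketing explanation, so here is a minimal illustrative sketch of a padding rule that is consistent with the quoted "1/16 th of the longest sample" bound. The helper name `bucketed_length`, the fixed 16-way step, and the round-up-to-a-multiple strategy are assumptions for illustration; the repository's actual algorithm lives in the unchanged portion of docs/hpu.md and the training code and may differ.

```python
# Hypothetical sketch, not the repo's implementation: pad each sample up to
# the next multiple of (longest_sample / 16).  Per-sample padding overhead is
# then bounded by roughly 1/16 of the longest sample, while the number of
# distinct shapes seen by the HPU graph compiler stays small, which reduces
# recompilations.
import math


def bucketed_length(sample_len: int, longest_sample: int, num_buckets: int = 16) -> int:
    """Round sample_len up to the next bucket boundary (illustrative helper)."""
    step = max(1, math.ceil(longest_sample / num_buckets))
    return min(longest_sample, math.ceil(sample_len / step) * step)


# Example: with a 2048-token longest sample the step is 128, so a 1000-token
# sample pads to 1024 and a 2000-token sample pads to 2048.
print(bucketed_length(1000, 2048))  # 1024
print(bucketed_length(2000, 2048))  # 2048
```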

src/instructlab/training/accelerator.py

Lines changed: 8 additions & 17 deletions
@@ -132,6 +132,7 @@ def get_fsdp_config(self):
         from functools import partial
 
         # Third Party
+        from accelerate.utils import FullyShardedDataParallelPlugin
         from peft.utils.other import fsdp_auto_wrap_policy
         from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
         from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
@@ -159,27 +160,17 @@ def get_fsdp_config(self):
         prefetch_policy = (
             BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
         )
+        fsdp_plugin = FullyShardedDataParallelPlugin(
+            auto_wrap_policy=wrap_policy,
+            limit_all_gathers=True,
+            backward_prefetch=prefetch_policy,
+            sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
+            cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
+        )
 
         if self.device_str == "hpu":
-            from optimum.habana.accelerate.utils import GaudiFullyShardedDataParallelPlugin
-            fsdp_plugin = GaudiFullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrap_policy,
-                limit_all_gathers=True,
-                backward_prefetch=prefetch_policy,
-                sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
-                cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
-            )
             fsdp_plugin.use_orig_params=True
             fsdp_plugin.sync_module_states=True
-        else:
-            from accelerate.utils import FullyShardedDataParallelPlugin
-            fsdp_plugin = FullyShardedDataParallelPlugin(
-                auto_wrap_policy=wrap_policy,
-                limit_all_gathers=True,
-                backward_prefetch=prefetch_policy,
-                sharding_strategy=ShardingStrategy[self.fsdp_sharding_strategy],
-                cpu_offload=CPUOffload(self.fsdp_cpu_offload_params),
-            )
 
         # `use_orig_params` must be disabled when using LoRA and FSDP together
         # Source: https://huggingface.co/docs/peft/en/accelerate/fsdp#the-important-parts
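
Taken together, the change above drops the HPU-only GaudiFullyShardedDataParallelPlugin and builds a single FullyShardedDataParallelPlugin the same way on every backend, toggling two Gaudi-specific flags afterwards. Below is a condensed, self-contained sketch of that resulting flow; the free-standing function, its parameter names, and the default values are simplified stand-ins for the real Accelerator class in src/instructlab/training/accelerator.py.

```python
# Condensed sketch of the post-commit FSDP plugin construction.  The function
# signature and defaults are stand-ins; the real code reads these values from
# the Accelerator instance (self.fsdp_sharding_strategy, self.device_str, ...).
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload


def build_fsdp_plugin(wrap_policy, device_str: str, sharding_strategy: str = "FULL_SHARD",
                      cpu_offload_params: bool = False, is_lora: bool = False):
    # LoRA uses BACKWARD_POST prefetching, full fine-tuning BACKWARD_PRE, as in the diff.
    prefetch_policy = (
        BackwardPrefetch.BACKWARD_POST if is_lora else BackwardPrefetch.BACKWARD_PRE
    )
    # One plugin type for every device; before this commit the HPU path built
    # a separate GaudiFullyShardedDataParallelPlugin from optimum-habana.
    fsdp_plugin = FullyShardedDataParallelPlugin(
        auto_wrap_policy=wrap_policy,
        limit_all_gathers=True,
        backward_prefetch=prefetch_policy,
        sharding_strategy=ShardingStrategy[sharding_strategy],
        cpu_offload=CPUOffload(cpu_offload_params),
    )
    if device_str == "hpu":
        # Gaudi still needs these two toggles on the shared plugin.
        fsdp_plugin.use_orig_params = True
        fsdp_plugin.sync_module_states = True
    return fsdp_plugin
```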
