
Commit 9721837

Replace huggingface args with hydra configs (#57)
* Replace huggingface config args with hydra configs
* Fix train.py script to remove old args
* Add hydra-core dependency to pyproject.toml and fix lint
* Update README.md and rename llama.yaml to llama-3-8b.yaml
1 parent a908b1b commit 9721837

10 files changed: +126, -220 lines

README.md

Lines changed: 1 addition & 2 deletions
@@ -27,8 +27,7 @@ Train Llama 3 8B using torch_xla:
 
 ```sh
 export HF_TOKEN='...your huggingface token...'
-XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py \
-  torchprime/torch_xla_models/configs/run.json
+XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py
 ```
 
 Train Llama 3 8B using torchax:
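Since the JSON args file is gone, run settings now come from the Hydra configs added later in this commit and can be overridden on the command line with Hydra's `key=value` syntax. A sketch, assuming `train.py` is a standard Hydra entrypoint and using keys that appear in `configs/base.yaml` below:

```sh
export HF_TOKEN='...your huggingface token...'
# Select the model config group and override two scalars defined in configs/base.yaml
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py \
  model=llama-3-8b global_batch_size=16 max_steps=100
```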

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ dependencies = [
   "transformers==4.44.2",
   "transformers[torch]==4.44.2",
   "datasets==3.0.0",
+  "hydra-core==1.3.0",
   "optax==0.2.4",
   "fire==0.7.0",
   "tensorflow-cpu==2.18.0",

torchprime/torch_xla_models/README.md

Lines changed: 4 additions & 17 deletions
@@ -24,8 +24,7 @@
 3. Run the training script:
 
 ```
-XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py \
-  torchprime/torch_xla_models/configs/run.json
+XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py
 ```
 
 ## Running on XPK
@@ -37,20 +36,7 @@ export HF_TOKEN='... hugging face token ...'
 export XLA_IR_DEBUG=1
 export XLA_HLO_DEBUG=1
 
-tp run torchprime/torch_xla_models/train.py \
-  --dataset_name wikitext \
-  --dataset_config_name 'wikitext-103-raw-v1' \
-  --output_dir /tmp \
-  --cache_dir /tmp \
-  --global_batch_size 256 \
-  --logging_steps 10 \
-  --max_steps 30 \
-  --profile_step 5 \
-  --model_id 'meta-llama/Meta-Llama-3-8B' \
-  --tokenizer_name 'meta-llama/Meta-Llama-3-8B' \
-  --block_size 8192 \
-  --fsdp full_shard \
-  --fsdp_config torchprime/torch_xla_models/configs/fsdp_config.json
+tp run torchprime/torch_xla_models/train.py
 ```
 
 This will build the dockerfile and launch it on XPK.
@@ -59,5 +45,6 @@ This will build the dockerfile and launch it on XPK.
 ## Key Components
 
 - `train.py`: Main training script that sets up the model, data, and training loop
-- `configs/run.json`: Configuration file for the training script
+- `configs/base.yaml`: Configuration file for the training script
+- `configs/model`: Configuration files for the training models
 - `llama/model.py`: Implementation of the Llama model
torchprime/torch_xla_models/configs/base.yaml

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+# This defines the order in which configs are loaded. Later configs
+# override earlier ones.
+defaults:
+  - _self_  # refers to this config file
+  - model: llama-3-8b  # refers to model/llama-3-8b.yaml
+
+dataset_name: wikitext
+dataset_config_name: wikitext-2-raw-v1
+global_batch_size: 8
+logging_steps: 10
+max_steps: 15
+block_size: 8192
+cache_dir: /tmp/
+seed: 42
+profile_step: -1
+profile_logdir: /tmp/profile
+profile_duration: 100000
+fsdp:
+  transformer_layer_cls_to_wrap:
+    - LlamaDecoderLayer
+  xla_fsdp_grad_ckpt: true
+optimizer:
+  learning_rate: 5.e-5
+lr_scheduler:
+  type: linear
+  warmup_steps: 0
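The `train.py` changes themselves are not shown in this diff. A minimal sketch of what a Hydra entrypoint consuming `configs/base.yaml` could look like; the decorator arguments and the `main` function name are assumptions, not the actual torchprime code:

```python
import hydra
from omegaconf import DictConfig, OmegaConf


# Assumed entrypoint shape: Hydra composes configs/base.yaml plus the selected
# model group (model/llama-3-8b.yaml lands under the `model` key) into one DictConfig.
@hydra.main(version_base=None, config_path="configs", config_name="base")
def main(config: DictConfig):
  print(OmegaConf.to_yaml(config))  # inspect the fully composed config
  # ... build the model from config.model, the optimizer from config.optimizer, etc.


if __name__ == "__main__":
  main()
```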

torchprime/torch_xla_models/configs/fsdp_config.json

Lines changed: 0 additions & 8 deletions
This file was deleted.
torchprime/torch_xla_models/configs/model/llama-3-8b.yaml

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+vocab_size: 128256
+hidden_size: 4096
+intermediate_size: 14336
+num_hidden_layers: 32
+num_attention_heads: 32
+num_key_value_heads: 8
+hidden_act: silu
+max_position_embeddings: 131072
+bos_token_id: 128000
+eos_token_id: 128001
+tokenizer_name: meta-llama/Meta-Llama-3-8B
+initializer_range: 0.02
+rms_norm_eps: 1.0e-05
+attention_dropout: false
+attention_bias: false
+flash_attention: true
+rope_theta: 500000.0
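Because the model classes now accept a plain `DictConfig` (see the `llama/model.py` diff below), a model config like this one can also be loaded directly with OmegaConf outside of Hydra composition. A rough sketch of that usage, assuming the repository-relative path shown above:

```python
from omegaconf import OmegaConf

# Load the model YAML into a DictConfig; the model code reads these fields
# via attribute access (config.hidden_size, config.num_hidden_layers, ...).
config = OmegaConf.load("torchprime/torch_xla_models/configs/model/llama-3-8b.yaml")

print(config.hidden_size)          # 4096
print(config.num_attention_heads)  # 32
print(config.flash_attention)      # True

# The same DictConfig can be passed straight to the rewritten model classes,
# e.g. LlamaForCausalLM(config), just as the unit test below does with OmegaConf.create.
```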

torchprime/torch_xla_models/configs/run.json

Lines changed: 0 additions & 22 deletions
This file was deleted.

torchprime/torch_xla_models/llama/model.py

Lines changed: 26 additions & 45 deletions
@@ -22,11 +22,10 @@
 
 import torch
 import torch_xla.debug.profiler as xp
+from omegaconf import DictConfig
 from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers.activations import ACT2FN
-from transformers.modeling_utils import PreTrainedModel
-from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import CausalLMOutputWithPast
 from transformers.utils import logging
 
@@ -52,12 +51,7 @@ def forward(self, hidden_states):
 
 class LlamaRotaryEmbedding(nn.Module):
   def __init__(
-    self,
-    dim,
-    max_position_embeddings=2048,
-    base=10000,
-    device=None,
-    scaling_factor=1.0,
+    self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0
   ):
     super().__init__()
     self.scaling_factor = scaling_factor
@@ -161,7 +155,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class LlamaAttention(nn.Module):
   """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-  def __init__(self, config: LlamaConfig, layer_idx: int | None = None):
+  def __init__(self, config: DictConfig, layer_idx: int | None = None):
     super().__init__()
     self.config = config
     self.layer_idx = layer_idx
@@ -290,7 +284,7 @@ def forward(
 
 
 class LlamaDecoderLayer(nn.Module):
-  def __init__(self, config: LlamaConfig, layer_idx: int):
+  def __init__(self, config: DictConfig, layer_idx: int):
     super().__init__()
     self.hidden_size = config.hidden_size
 
@@ -338,35 +332,19 @@ def forward(
     return hidden_states
 
 
-class LlamaPreTrainedModel(PreTrainedModel):
-  def _init_weights(self, module):
-    std = self.config.initializer_range
-    if isinstance(module, nn.Linear):
-      module.weight.data.normal_(mean=0.0, std=std)
-      if module.bias is not None:
-        module.bias.data.zero_()
-    elif isinstance(module, nn.Embedding):
-      module.weight.data.normal_(mean=0.0, std=std)
-      if module.padding_idx is not None:
-        module.weight.data[module.padding_idx].zero_()
-
-
-class LlamaModel(LlamaPreTrainedModel):
+class LlamaModel(nn.Module):
   """
   Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
 
   Args:
-    config: LlamaConfig
+    config: DictConfig
   """
 
-  def __init__(self, config: LlamaConfig):
-    super().__init__(config)
-    self.padding_idx = config.pad_token_id
+  def __init__(self, config: DictConfig):
+    super().__init__()
     self.vocab_size = config.vocab_size
 
-    self.embed_tokens = nn.Embedding(
-      config.vocab_size, config.hidden_size, self.padding_idx
-    )
+    self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
     self.layers = nn.ModuleList(
       [
         LlamaDecoderLayer(config, layer_idx)
@@ -375,9 +353,6 @@ def __init__(self, config: LlamaConfig):
     )
     self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-    # Initialize weights and apply final processing
-    self.post_init()
-
   @xp.trace_me("LlamaModel")
   def forward(
     self,
@@ -393,11 +368,7 @@ def forward(
     # Create a causal mask without calling the current method
     seq_length = inputs_embeds.size(1)
     causal_mask = torch.triu(
-      torch.full(
-        (seq_length, seq_length),
-        float("-inf"),
-        device=inputs_embeds.device,
-      ),
+      torch.full((seq_length, seq_length), float("-inf"), device=inputs_embeds.device),
       diagonal=1,
     )
     causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # Add batch and head dimension
@@ -411,24 +382,34 @@
     # decoder layers
     for decoder_layer in self.layers:
       hidden_states = decoder_layer(
-        hidden_states,
-        attention_mask=causal_mask,
-        position_ids=position_ids,
+        hidden_states, attention_mask=causal_mask, position_ids=position_ids
      )
 
     hidden_states = self.norm(hidden_states)
     return hidden_states
 
 
-class LlamaForCausalLM(LlamaPreTrainedModel):
+class LlamaForCausalLM(nn.Module):
   def __init__(self, config):
-    super().__init__(config)
+    super().__init__()
+    self.config = config
     self.model = LlamaModel(config)
     self.vocab_size = config.vocab_size
     self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
     # Initialize weights and apply final processing
-    self.post_init()
+    self.apply(self._init_weights)
+
+  def _init_weights(self, module):
+    std = self.config.initializer_range
+    if isinstance(module, nn.Linear):
+      module.weight.data.normal_(mean=0.0, std=std)
+      if module.bias is not None:
+        module.bias.data.zero_()
+    elif isinstance(module, nn.Embedding):
+      module.weight.data.normal_(mean=0.0, std=std)
+      if module.padding_idx is not None:
+        module.weight.data[module.padding_idx].zero_()
 
   @xp.trace_me("LlamaForCausalLM")
   def forward(
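Two things change here: the model classes take an OmegaConf `DictConfig` instead of a Hugging Face `LlamaConfig`, and weight initialization moves from `PreTrainedModel.post_init()` to `nn.Module.apply`, which recursively visits every submodule (and the module itself). A toy illustration of the `apply` mechanism, not torchprime code:

```python
import torch
from torch import nn


class Toy(nn.Module):
  def __init__(self):
    super().__init__()
    self.embed = nn.Embedding(10, 4)
    self.proj = nn.Linear(4, 4)
    # apply() walks self.embed, self.proj, and Toy itself, calling _init_weights on each
    self.apply(self._init_weights)

  def _init_weights(self, module):
    if isinstance(module, (nn.Linear, nn.Embedding)):
      module.weight.data.normal_(mean=0.0, std=0.02)  # weights drawn from N(0, 0.02^2)
      if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()


toy = Toy()
print(toy.proj.bias.abs().sum())  # tensor(0.), biases were zeroed
```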

torchprime/torch_xla_models/tests/test_llama.py

Lines changed: 20 additions & 1 deletion
@@ -3,6 +3,7 @@
 
 import torch
 import torch_xla
+from omegaconf import OmegaConf
 from transformers import AutoConfig
 from transformers import LlamaForCausalLM as HfLlamaForCausalLM
 
@@ -24,10 +25,28 @@ def setUp(self):
       vocab_size=self.vocab_size,
     )
     config.flash_attention = False
+    torchprime_config = OmegaConf.create(
+      {
+        "vocab_size": 128,
+        "hidden_size": 8,
+        "intermediate_size": 16,
+        "num_hidden_layers": 1,
+        "num_attention_heads": 8,
+        "num_key_value_heads": 8,
+        "hidden_act": "silu",
+        "max_position_embeddings": 8192,
+        "initializer_range": 0.02,
+        "rms_norm_eps": 1.0e-05,
+        "attention_dropout": False,
+        "attention_bias": False,
+        "flash_attention": False,
+        "rope_theta": 500000.0,
+      }
+    )
     # place model on CPU device first
     with torch.device("cpu"):
       self.hf_model = HfLlamaForCausalLM(config)
-      self.model = LlamaForCausalLM(config)
+      self.model = LlamaForCausalLM(torchprime_config)
     self.model.load_state_dict(self.hf_model.state_dict())
 
   def test_forward_our_model_against_hf_model(self):
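The `load_state_dict` call in `setUp` relies on the rewritten model keeping the same parameter names and shapes as the Hugging Face implementation, even though it no longer subclasses `PreTrainedModel`. A hypothetical check that could be added inside `setUp` to make that invariant explicit (not part of the test file):

```python
# Strict load_state_dict already enforces key equality, but the symmetric
# difference pinpoints any mismatch if the model structure drifts.
hf_keys = set(self.hf_model.state_dict().keys())
tp_keys = set(self.model.state_dict().keys())
assert hf_keys == tp_keys, hf_keys ^ tp_keys
```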
