 import itertools
 from typing import Any, List, NamedTuple, Optional, Union
 
+import jax
+from absl import logging
 from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies
 
 from axlearn.common import causal_lm, config
@@ -252,7 +254,6 @@ def get_trainer_kwargs(
     max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
     max_sequence_length = MAX_SEQUENCE_LENGTH[version]
     train_batch_size = tokens_per_batch // max_sequence_length
-
     # Whether to use grouped query attention.
     num_kv_heads = None
     if version in (Version.V3, Version.V3_TIKTOKEN):
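The hunk above only rewires how these quantities are derived from one another, so the relationships are easy to sanity-check by hand. A minimal sketch with assumed token budgets (the concrete numbers are illustrative, not the values any Fuji config actually uses):

    # Illustrative values only; the real ones come from TOTAL_TOKENS and MAX_SEQUENCE_LENGTH.
    total_tokens = 2 * (1024**4)        # assumed ~2.2e12 training tokens
    tokens_per_batch = 4 * (1024**2)    # assumed 4M-token global batch
    max_sequence_length = 4096

    max_step = total_tokens // tokens_per_batch                  # 524_288 steps
    train_batch_size = tokens_per_batch // max_sequence_length   # 1_024 sequences per step
    assert train_batch_size * max_sequence_length == tokens_per_batch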
@@ -813,6 +814,67 @@ def get_trainer_kwargs(
             ),
         )
     elif model_size == "150B":
+        ##################################################################################
+        max_sequence_length = MAX_SEQUENCE_LENGTH[Version.V2]  # 4096
+
+        # model_parallelism * fsdp == num_chips_in_trillium (256)
+        model_parallelism = 4
+        fsdp = 64
+
+        current_pdbs = 0.5
+        train_batch_size = int(current_pdbs * len(jax.devices()))
+
+        # 16 * (1024**2) / 4096 = 4096
+        tokens_per_batch = int(train_batch_size * max_sequence_length)
+
+        # 32M tokens is the max global tokens we can train on.
+        # We must modify either the pdbs or the model sharding to accommodate 128 slices.
+        if tokens_per_batch > 32 * (1024**2):
+            tokens_per_batch = 32 * (1024**2)
+            # if we want to modify the pdbs:
+            # current_pdbs = 0.25
+
+            # otherwise we can modify the model sharding.
+            model_parallelism = 8
+            fsdp = 32
+
+        # 32M tokens is the max global tokens we can train on.
+        assert tokens_per_batch <= 32 * (1024**2)
+        assert fsdp * model_parallelism == 256
+
+        # 1 / model_parallelism = 1 / 4 = 0.25
+        min_pdbs = 1 / model_parallelism
+        max_pdbs = 1
+
+        # More than 1 pdbs causes an OOM.
+        assert current_pdbs < max_pdbs
+        assert current_pdbs >= min_pdbs
+
+        # maximum number of devices we can use this config on =
+        # train_batch_size // min_pdbs = 4096 / 0.25 = 16384
+        max_devices = int(train_batch_size // min_pdbs)
+
+        assert isinstance(train_batch_size, int)
+        assert isinstance(tokens_per_batch, int)
+
+        logging.info(
+            (
+                "******* DEBUGGING: max_sequence_length: %s, model_parallelism: %s,"
+                " fsdp: %s, current_pdbs: %s, train_batch_size: %s,"
+                " tokens_per_batch: %s, min_pdbs: %s, max_pdbs: %s, max_devices: %s"
+            ),
+            max_sequence_length,
+            model_parallelism,
+            fsdp,
+            current_pdbs,
+            train_batch_size,
+            tokens_per_batch,
+            min_pdbs,
+            max_pdbs,
+            max_devices,
+        )
+        ##################################################################################
+
         trainer_kwargs = dict(
             model_kwargs=dict(
                 num_layers=80,
@@ -828,8 +890,9 @@ def get_trainer_kwargs(
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            max_step=max_step,
-            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4),
+            max_step=100_000,  # max_step,
+            save_every_n_steps=100,
+            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=fsdp, model=model_parallelism),
             mesh_rules=(
                 (
                     # Target per-device token count = 4k.
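For reference, the arithmetic the new 150B branch performs can be reproduced in isolation. The sketch below assumes 128 Trillium slices of 256 chips each (32,768 devices) and mirrors the branch's logic; `plan_150b` is a hypothetical helper written for illustration, not part of axlearn:

    # Hypothetical helper mirroring the 150B debug branch above (assumes 256-chip slices).
    def plan_150b(num_devices: int, pdbs: float = 0.5, seq_len: int = 4096) -> dict:
        model_parallelism, fsdp = 4, 64              # model * fsdp must equal one slice (256 chips)
        train_batch_size = int(pdbs * num_devices)
        tokens_per_batch = train_batch_size * seq_len
        if tokens_per_batch > 32 * (1024**2):        # 32M-token global cap
            tokens_per_batch = 32 * (1024**2)
            model_parallelism, fsdp = 8, 32          # reshard instead of lowering pdbs
        assert fsdp * model_parallelism == 256
        return dict(model=model_parallelism, fsdp=fsdp,
                    batch=train_batch_size, tokens=tokens_per_batch)

    # 128 slices x 256 chips: batch = 16_384, raw tokens = 64M (clamped to 32M), mesh -> model=8, fsdp=32.
    print(plan_150b(128 * 256))

As in the diff itself, the clamped token count only feeds the assertion; train_batch_size is passed through to trainer_kwargs unchanged.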
@@ -971,6 +1034,12 @@ def trainer_configs(
         if model_size not in TOTAL_TOKENS[version]:  # This combination does not exist.
             continue
         vocab_size = VOCAB_SIZE[version]
+        logging.info(
+            "******* DEBUGGING: version: %s, model_size: %s, flash_attention: %s",
+            version,
+            model_size,
+            flash_attention,
+        )
         config_name = make_config_name(
             arch=arch,
             model_size=model_size,