@@ -3,47 +3,88 @@ defaults:
33 - /base
44 - _self_
55
6- # lepton info
6+ branch : jwilber/add-accelerate-l1-3b-config
7+
8+ # ###########################################################
9+ # lepton job info
10+ # ###########################################################
711node_group : yo-bom-lepton-001
812num_nodes : 2
913device_type : gpu
10- num_devices : 2
14+ num_devices : 8
1115gpu_type : h100-sxm
12- total_gpus : ${multiply:${num_devices},${num_nodes}}
1316resource_shape : " ${device_type}.${num_devices}x${gpu_type}"
1417
18+ # ###########################################################
1519# recipe identifiers
20+ # mostly used for logging and observability
21+ # ###########################################################
1622recipe_subdir : esm2_accelerate_te
1723model_type : esm2
24+ variant : train # train, finetune
25+
26+ # Core identifiers for filtering
27+ framework : native # native, accelerate
28+ parallelism_strategy : fsdp2 # ddp, fsdp2, mfsdp
29+ precision : fp8 # likely bf16 or fp8
30+ te_enabled : true
31+ fp8_enabled : true
32+
33+ # Catchall for additional features/configs
34+ extras : [] # e.g. [thd]
35+
36+ # ###########################################################
37+ # wandb info (total_gpus used for group name)
38+ # ###########################################################
39+ # `total_gpus` calculated from lepton job info above
40+ total_gpus : ${multiply:${num_devices},${num_nodes}}
1841
19- # wandb
2042wandb_init_args :
2143 project : " test_convergence__recipes__${sanitize:${branch}}"
22- group : " ${model_type}__${task_cmd}__${total_gpus}__ ${sanitize:${gpu_type}}"
44+ group : " ${model_type}__${task_cmd}__${total_gpus}gpus__ ${sanitize:${gpu_type}}"
2345 job_type : " ${recipe_subdir}"
2446 name : null
2547
48+ # ###########################################################
49+ # task commands
50+ # shared across all products (if not explicitly overridden)
51+ # ###########################################################
52+ # task_cmd: train_fsdp2 # mfsdp
53+ task_cmd : train
54+
55+ # script overrides
56+ # these should match the keys in the recipe's config file
57+ # model_tag: nvidia/esm2_t36_3B_UR50D
58+
59+ micro_batch_size : 4
60+ # num_warmup_steps: 20_000
2661# config overrides
2762trainer :
2863 report_to : " wandb"
2964
30- # train specific commands
31- task_cmd : train
32- stop_after_n_steps : 10
65+ stop_after_n_steps : 100
3366
34- # configs to run
67+ # ###########################################################
68+ # Each product is a different config to run, alongside
69+ # config-specific arguments. Must have a `wandb_name`.
70+ # ###########################################################
3571products :
36- - config : L0_sanity
72+ - config : L1_3B
73+ acc_config : default
3774 wandb_name : " ${config}__${now:%Y%m%d-%H%M%S}__${gitsha:}"
3875
39- # training script to run
76+ # ###########################################################
77+ # run script
78+ # This gets called right after `checkout_script` in the base config.
79+ # ###########################################################
4080run_script : |
41- accelerate launch --config_file accelerate_config/default.yaml \
81+ accelerate launch --config_file accelerate_config/${acc_config}.yaml \
4282 ${task_cmd}.py \
4383 --config-name=${config} \
4484 stop_after_n_steps=${stop_after_n_steps} \
45- wandb_init_args.mode=${wandb_init_args.mode} \
85+ +wandb_init_args.mode=${wandb_init_args.mode} \
4686 +wandb_init_args.project=${wandb_init_args.project} \
4787 +wandb_init_args.group=${wandb_init_args.group} \
4888 +wandb_init_args.job_type=${wandb_init_args.job_type} \
49- wandb_init_args.name=${wandb_name}
89+ wandb_init_args.name=${wandb_name} \
90+ trainer.per_device_train_batch_size=${micro_batch_size}
0 commit comments