.github/workflows/convergence-tests.yml (2 changes: 1 addition & 1 deletion)
@@ -39,7 +39,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
-      model_config: ${{ github.event_name == 'schedule' && fromJSON('["esm2_native_te_650m", "esm2_native_te_3b", "esm2_native_te_15b", "codonfm_ptl_te"]') || fromJSON(format('["{0}"]', github.event.inputs.model_config)) }}
+      model_config: ${{ github.event_name == 'schedule' && fromJSON('["esm2_native_te_650m", "esm2_native_te_15b", "codonfm_ptl_te"]') || fromJSON(format('["{0}"]', github.event.inputs.model_config)) }}
fail-fast: false
steps:
- name: Checkout
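The scheduled matrix drops the 3B config, while manual `workflow_dispatch` runs can still target it through the single-input branch. A quick sanity check of how the two branches of the expression resolve; the JSON strings come from the workflow above, and the dispatch input value here is a hypothetical example:

```python
import json

# Scheduled runs: fromJSON over the literal list, now without esm2_native_te_3b.
scheduled = json.loads('["esm2_native_te_650m", "esm2_native_te_15b", "codonfm_ptl_te"]')

# workflow_dispatch runs: format('["{0}"]', inputs.model_config) wraps one input value.
# "esm2_native_te_3b" is a hypothetical input; any single config name works.
dispatch = json.loads('["esm2_native_te_3b"]')

print(scheduled)  # ['esm2_native_te_650m', 'esm2_native_te_15b', 'codonfm_ptl_te']
print(dispatch)   # ['esm2_native_te_3b']
```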
(file name not captured in this view)
@@ -32,6 +32,8 @@ framework: native # native, accelerate
precision: bf16 # likely bf16 or fp8
te_enabled: true
fp8_enabled: false
+fp8_recipe: ""
+fp8_format: ""
# thd_enabled: false

# Catchall for additional features/configs
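The new `fp8_recipe` and `fp8_format` keys hold a dotted class path and an enum name as strings. A minimal sketch of how a training script might resolve them into a Transformer Engine recipe object; the `build_fp8_recipe` helper and its kwarg wiring are assumptions, not this recipe's actual code:

```python
import importlib

def build_fp8_recipe(fp8_recipe: str, fp8_format: str):
    """Resolve the YAML's dotted-path/enum-name strings into a TE recipe.

    Hypothetical helper; the real training code may wire this differently.
    """
    module_path, class_name = fp8_recipe.rsplit(".", 1)
    recipe_cls = getattr(importlib.import_module(module_path), class_name)

    from transformer_engine.common.recipe import Format
    # Assumes the recipe class accepts an fp8_format kwarg, as DelayedScaling does.
    return recipe_cls(fp8_format=getattr(Format, fp8_format))

# e.g. with the values from the products below:
# recipe = build_fp8_recipe("transformer_engine.common.recipe.Float8BlockScaling", "E4M3")
```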
@@ -123,6 +125,28 @@ products:
    micro_batch_size: 4
    wandb_name: "esm2_native_15b__fsdp2__baseline__${now:%Y%m%d-%H%M%S}__${gitsha:}"
    job_name: "esm2-native-15b-fsdp2-baseline"
+  # TE bshd perf, FSDP2, FP8
+  - config: L1_15B_perf_test
+    task_cmd: train_fsdp2
+    parallelism_strategy: fsdp2
+    thd_enabled: false
+    fp8_enabled: true
+    fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
+    fp8_format: E4M3
+    micro_batch_size: 4
+    wandb_name: "esm2_native_15b__fsdp2__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
+    job_name: "esm2-native-15b-fsdp2-fp8"
+  # TE thd perf, FSDP2, FP8
+  - config: L1_15B_perf_test
+    task_cmd: train_fsdp2
+    parallelism_strategy: fsdp2
+    thd_enabled: true
+    fp8_enabled: true
+    fp8_recipe: transformer_engine.common.recipe.Float8BlockScaling
+    fp8_format: E4M3
+    micro_batch_size: 4
+    wandb_name: "esm2_native_15b__fsdp2__thd__fp8__${now:%Y%m%d-%H%M%S}__${gitsha:}"
+    job_name: "esm2-native-15b-fsdp2-thd-fp8"

############################################################
# run script
@@ -156,4 +180,6 @@ run_script: |
checkpoint.resume_from_checkpoint=${resume_from_checkpoint} \
+checkpoint.save_checkpoints=${save_checkpoints} \
+checkpoint.use_distributed_checkpoint_fsdp2=${use_distributed_checkpoint_fsdp2} \
-    fp8_config.enabled=${fp8_enabled}
+    fp8_config.enabled=${fp8_enabled} \
+    fp8_config.fp8_recipe=${fp8_recipe} \
+    fp8_config.fp8_format=${fp8_format}
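Downstream, these Hydra overrides presumably end up gating Transformer Engine's FP8 autocast context. A sketch under that assumption; the layer and input are placeholders for the ESM-2 model, and whether `Float8BlockScaling` takes an `fp8_format` kwarg is an assumption, while `fp8_autocast` and its `fp8_recipe` argument are standard TE API:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8BlockScaling, Format

# Placeholder layer/input; the real recipe wraps the ESM-2 model instead.
layer = te.Linear(1024, 1024)
x = torch.randn(8, 1024, device="cuda")

# Mirrors fp8_config.enabled / fp8_recipe / fp8_format from the overrides above.
recipe = Float8BlockScaling(fp8_format=Format.E4M3)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)
```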
(file name not captured in this view)
@@ -58,15 +58,15 @@ wandb_init_args:
# these should match the keys in the recipe's config file
model_tag: nvidia/esm2_t36_3B_UR50D
# task_cmd: train_fsdp2 # mfsdp
-num_train_steps: 20_000
+num_train_steps: 10_000
# dataset commands
micro_batch_size: 16
load_dataset_kwargs_path: nvidia/esm2_uniref_pretraining_data
load_dataset_kwargs_streaming: true
load_dataset_kwargs_revision: 4ac1d2973567e46b8ca95901f4b4793a21305995 # pragma: allowlist secret

# lr commands
-num_warmup_steps: 2_000
+num_warmup_steps: 1_000
# checkpoint controls
ckpt_dir: ""
save_checkpoints: false
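Both convergence configs halve the run (20k to 10k steps) while keeping warmup at 10% of training (2k to 1k). A minimal sketch of that schedule, assuming a standard linear warmup/decay such as `transformers.get_linear_schedule_with_warmup`; the recipe's actual scheduler and learning rate may differ:

```python
import torch
from transformers import get_linear_schedule_with_warmup

# Dummy parameter standing in for the ESM-2 model; lr is illustrative.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=4e-4)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=1_000,     # was 2_000
    num_training_steps=10_000,  # was 20_000
)
# call scheduler.step() once per training step
```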
(file name not captured in this view)
@@ -58,7 +58,7 @@ wandb_init_args:
# these should match the keys in the recipe's config file
model_tag: nvidia/esm2_t36_650M_UR50D
# task_cmd: train_fsdp2 # mfsdp
-num_train_steps: 20_000
+num_train_steps: 10_000
# dataset commands
micro_batch_size: 16
load_dataset_kwargs_path: nvidia/esm2_uniref_pretraining_data
@@ -67,7 +67,7 @@ load_dataset_kwargs_revision: 4ac1d2973567e46b8ca95901f4b4793a21305995 # pragma: allowlist secret
num_workers: 1

# lr commands
-num_warmup_steps: 2_000
+num_warmup_steps: 1_000
# checkpoint controls
ckpt_dir: ""
save_checkpoints: false
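The dataset keys map onto `datasets.load_dataset` arguments: streaming avoids materializing the full UniRef corpus, and the pinned revision keeps nightly runs reproducible. A sketch of the equivalent call; the `load_dataset_kwargs_` prefix-to-kwarg mapping is an assumption about how the recipe forwards these values:

```python
from datasets import load_dataset

ds = load_dataset(
    "nvidia/esm2_uniref_pretraining_data",  # load_dataset_kwargs_path
    streaming=True,                         # load_dataset_kwargs_streaming
    revision="4ac1d2973567e46b8ca95901f4b4793a21305995",  # pinned revision
)
# Streaming returns iterable splits; samples are fetched lazily during training.
```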
@@ -91,16 +91,16 @@ products:
    wandb_name: "esm2_native_650m__fsdp2__thd__${now:%Y%m%d-%H%M%S}__${gitsha:}"
    job_name: "esm2-native-650m-fsdp2-thd"
  # OSS Convergence Baseline
-  - config: L1_650M
-    model_tag: facebook/esm2_t33_650M_UR50D
-    num_nodes: 8
-    num_devices: 8
-    task_cmd: train_fsdp2
-    parallelism_strategy: fsdp2
-    thd_enabled: false
-    micro_batch_size: 32
-    wandb_name: "esm2_native_650m__fsdp2__${now:%Y%m%d-%H%M%S}__${gitsha:}"
-    job_name: "esm2-native-650m-fsdp2"
+  # - config: L1_650M
+  #   model_tag: facebook/esm2_t33_650M_UR50D
+  #   num_nodes: 8
+  #   num_devices: 8
+  #   task_cmd: train_fsdp2
+  #   parallelism_strategy: fsdp2
+  #   thd_enabled: false
+  #   micro_batch_size: 32
+  #   wandb_name: "esm2_native_650m__fsdp2__${now:%Y%m%d-%H%M%S}__${gitsha:}"
+  #   job_name: "esm2-native-650m-fsdp2"

############################################################
# run script