Commit 315f0b9

hit ddp training scripts as well

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent: b1313fe

File tree: 6 files changed (+12, -45 lines)

bionemo-recipes/recipes/esm2_native_te/perf_logger.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -143,3 +143,6 @@ def finish(self):
 
         wandb.finish()
         self._progress_bar.close()
+
+        if self.fp8_stats_enabled:
+            debug_api.end_debug()
```
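
With the `fp8_stats_enabled` flag, `PerfLogger` owns the TE debug-API lifecycle, so the training scripts below no longer need their own teardown code. A minimal sketch of the pattern, assuming the flag is set from `args.fp8_stats_config.enabled` at construction and that the per-step `debug_api.step()` calls deleted from the training loops in this commit now live with the logger's per-step logging; the `nvdlfw_inspect` import path is also an assumption, following Transformer Engine's precision-debug-tools convention:

```python
# Hedged sketch: only finish() appears in this diff; __init__ and log_step
# wiring are assumptions about where the removed calls ended up.
import nvdlfw_inspect.api as debug_api  # assumed import, per TE's debug tooling


class PerfLogger:
    def __init__(self, fp8_stats_enabled: bool):
        # Assumed: populated from args.fp8_stats_config.enabled.
        self.fp8_stats_enabled = fp8_stats_enabled

    def log_step(self) -> None:
        # Assumed: the debug_api.step() calls removed from the DDP training
        # loops are issued here alongside the other per-step metrics.
        if self.fp8_stats_enabled:
            debug_api.step()

    def finish(self) -> None:
        # From this diff: ending TE debugging is now the logger's job.
        if self.fp8_stats_enabled:
            debug_api.end_debug()
```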

bionemo-recipes/recipes/esm2_native_te/train_ddp.py
Lines changed: 4 additions & 21 deletions

```diff
@@ -30,6 +30,7 @@
 from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint
 from dataset import create_bshd_dataloader, create_thd_dataloader
 from distributed_config import DistributedConfig
+from fp8_debugging import initialize_fp8_debugging
 from perf_logger import PerfLogger
 from scheduler import get_linear_schedule_with_warmup
 
@@ -52,24 +53,10 @@ def main(args: DictConfig) -> float | None:
     torch.distributed.init_process_group(backend="nccl", device_id=device)
     torch.cuda.set_device(dist_config.local_rank)
 
-    # TE Debug feature logging
-    if args.fp8_stats_config.enabled and not args.fp8_config.enabled:
-        raise ValueError(
-            "fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats"
-        )
-
+    # TE Debug feature logging - MUST be done BEFORE FSDP wrapping
     if args.fp8_stats_config.enabled:
-        fp8_stats_file = args.fp8_stats_config.fp8_stats_file
-        fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}"
-        fp8_log_dir.mkdir(parents=True, exist_ok=True)
-        logger.info(f"Logging FP8 stats to {fp8_log_dir}")
-        te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
-        debug_api.initialize(
-            config_file=fp8_stats_file,
-            feature_dirs=[te_features_dir],
-            log_dir=fp8_log_dir,
-            default_logging_enabled=True,
-        )
+        initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
+
     # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2
     # and MFSDP.
     device_mesh = init_device_mesh("cuda", mesh_shape=(dist_config.world_size,), mesh_dim_names=("ddp",))
@@ -157,8 +144,6 @@ def main(args: DictConfig) -> float | None:
         loss = outputs.loss
         loss.backward()
 
-        if args.fp8_stats_config.enabled:
-            debug_api.step()
         # Compute and clip gradient norms.
         total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0).item()
 
@@ -206,8 +191,6 @@ def main(args: DictConfig) -> float | None:
 
     # Clean up distributed training
     perf_logger.finish()
-    if args.fp8_stats_config.enabled:
-        debug_api.end_debug()
     torch.distributed.destroy_process_group()
 
     return perf_logger.min_loss
```
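
The `fp8_debugging` module itself was added in the parent commit and isn't shown here, but the inline code deleted above and the new call site pin down what `initialize_fp8_debugging` must do. A plausible reconstruction, with the exact signature inferred from `initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)` (the real helper may differ in details):

```python
# fp8_debugging.py -- plausible reconstruction from the deleted inline code;
# `enabled` arrives via the **args.fp8_stats_config expansion at the call site.
import logging
from pathlib import Path

import nvdlfw_inspect.api as debug_api  # assumed import, per TE's debug tooling
import transformer_engine

logger = logging.getLogger(__name__)


def initialize_fp8_debugging(dist_config, enabled, fp8_stats_file, fp8_log_dir, fp8_enabled):
    """Validate the FP8 stats config and start TE's debug API with a per-rank log dir."""
    if enabled and not fp8_enabled:
        raise ValueError(
            "fp8_stats_config.enabled is true but fp8_config.enabled is false, please set "
            "fp8_config.enabled to true in the config if you wish to collect FP8 stats"
        )

    # Each rank logs to its own subdirectory, as in the deleted inline code.
    log_dir = Path(fp8_log_dir) / f"rank_{dist_config.rank}"
    log_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"Logging FP8 stats to {log_dir}")

    # TE ships its debug feature definitions alongside the installed package.
    te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
    debug_api.initialize(
        config_file=fp8_stats_file,
        feature_dirs=[te_features_dir],
        log_dir=log_dir,
        default_logging_enabled=True,
    )
```

Folding the `fp8_stats_config.enabled`/`fp8_config.enabled` consistency check into the helper means every entry point (DDP, FSDP2, MFSDP) gets the same validation for free instead of duplicating it per script.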

bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
Lines changed: 0 additions & 2 deletions

```diff
@@ -205,8 +205,6 @@ def main(args: DictConfig) -> float | None:
 
     # Clean up distributed training
     perf_logger.finish()
-    if args.fp8_stats_config.enabled:
-        debug_api.end_debug()
     torch.distributed.destroy_process_group()
 
     return perf_logger.min_loss
```

bionemo-recipes/recipes/llama3_native_te/perf_logger.py
Lines changed: 3 additions & 0 deletions

```diff
@@ -177,6 +177,9 @@ def finish(self):
         wandb.finish()
         self._progress_bar.close()
 
+        if self.fp8_stats_enabled:
+            debug_api.end_debug()
+
 
 def setup_profiler(args: DictConfig, wandb_run: wandb.Run):
     """Setup a basic torch profiler for the experiment.
```

bionemo-recipes/recipes/llama3_native_te/train_ddp.py
Lines changed: 2 additions & 20 deletions

```diff
@@ -32,6 +32,7 @@
 from checkpoint import load_checkpoint_ddp, save_checkpoint_ddp, save_final_model_ddp, should_save_checkpoint
 from dataset import create_bshd_dataloader, create_thd_dataloader
 from distributed_config import DistributedConfig
+from fp8_debugging import initialize_fp8_debugging
 from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
 from perf_logger import PerfLogger
 from scheduler import get_cosine_annealing_schedule_with_warmup
@@ -56,23 +57,8 @@ def main(args: DictConfig) -> float | None:
     torch.cuda.set_device(dist_config.local_rank)
 
     # TE Debug feature logging
-    if args.fp8_stats_config.enabled and not args.fp8_config.enabled:
-        raise ValueError(
-            "fp8_stats_config.enabled is true but fp8_config.enabled is false, please set fp8_config.enabled to true in the config if you wish to collect FP8 stats"
-        )
-
     if args.fp8_stats_config.enabled:
-        fp8_stats_file = args.fp8_stats_config.fp8_stats_file
-        fp8_log_dir = Path(args.fp8_stats_config.fp8_log_dir) / f"rank_{dist_config.rank}"
-        fp8_log_dir.mkdir(parents=True, exist_ok=True)
-        logger.info(f"Logging FP8 stats to {fp8_log_dir}")
-        te_features_dir = str(Path(transformer_engine.__file__).parent / "debug" / "features")
-        debug_api.initialize(
-            config_file=fp8_stats_file,
-            feature_dirs=[te_features_dir],
-            log_dir=fp8_log_dir,
-            default_logging_enabled=True,
-        )
+        initialize_fp8_debugging(dist_config, **args.fp8_stats_config, fp8_enabled=args.fp8_config.enabled)
 
     # Create a device mesh for DDP. While this isn't strictly necessary, it mirrors the device mesh we create for FSDP2
     # and MFSDP.
@@ -163,8 +149,6 @@ def main(args: DictConfig) -> float | None:
 
         # Log microbatch step data for accumulation metrics
         perf_logger.log_micro_step(batch=batch, outputs=outputs)
-        if args.fp8_stats_config.enabled:
-            debug_api.step()
 
         # Gradient accumulation - only step optimizer after accumulating gradients
         if micro_step % args.grad_acc_steps == 0:
@@ -215,8 +199,6 @@ def main(args: DictConfig) -> float | None:
 
     # Clean up distributed training
     perf_logger.finish()
-    if args.fp8_stats_config.enabled:
-        debug_api.end_debug()
     torch.distributed.destroy_process_group()
 
     return perf_logger.min_loss
```
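
The `**args.fp8_stats_config` expansion in both `train_ddp.py` hunks only resolves if the config group's keys match the helper's parameters exactly. Field names below are taken from the deleted code; the values and the `OmegaConf.create` scaffolding are illustrative only (the recipes type `args` as `DictConfig`, so OmegaConf is the natural stand-in for the hydra config here):

```python
# Illustrative only: the shape args must have for the call site to resolve.
from omegaconf import OmegaConf

args = OmegaConf.create(
    {
        "fp8_config": {"enabled": True},
        "fp8_stats_config": {
            "enabled": True,  # also gates the call itself
            "fp8_stats_file": "fp8_stats.yaml",  # config_file for debug_api.initialize
            "fp8_log_dir": "./fp8_logs",  # rank_{N} subdirectories created under here
        },
    }
)

# The keyword expansion at the call site then binds as:
#   initialize_fp8_debugging(
#       dist_config,
#       enabled=True,
#       fp8_stats_file="fp8_stats.yaml",
#       fp8_log_dir="./fp8_logs",
#       fp8_enabled=True,
#   )
```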

bionemo-recipes/recipes/llama3_native_te/train_fsdp2.py
Lines changed: 0 additions & 2 deletions

```diff
@@ -229,8 +229,6 @@ def main(args: DictConfig) -> float | None:
 
     # Clean up distributed training
    perf_logger.finish()
-    if args.fp8_stats_config.enabled:
-        debug_api.end_debug()
     torch.distributed.destroy_process_group()
 
     return perf_logger.min_loss
```
