@@ -163,9 +163,9 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
     # This is the spmd mesh to be used for tracing
     mesh = world_mesh[("dp_mod_ep", "ep")]
 
-    global_batch_size = 32 * dp_degree
     # Batch size that will be supplied to the schedule and will be broken down into microbatches
-    local_batch_size = global_batch_size // dp_degree
+    local_batch_size = 32
+    # global_batch_size = local_batch_size * dp_degree
     n_microbatches = 16
     # Batch size with which the spmd graphs will actually be executed
     microbatch_size = local_batch_size // n_microbatches
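For reference, the arithmetic this hunk sets up: each data-parallel rank now feeds a fixed local batch of 32 samples to the schedule, which the runtime splits into 16 microbatches of 2 samples each; the global batch is only implied by the comment. A minimal sketch, with an illustrative dp_degree (the data-parallel degree is not fixed by this diff):

    # Batch-size bookkeeping mirroring the hunk above (dp_degree value is hypothetical)
    local_batch_size = 32                                    # per data-parallel rank, fed to the schedule
    n_microbatches = 16
    microbatch_size = local_batch_size // n_microbatches     # 2 samples per spmd graph execution
    dp_degree = 4                                            # hypothetical data-parallel degree
    global_batch_size = local_batch_size * dp_degree         # 128 samples per optimizer step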
@@ -412,10 +412,6 @@ def shape_inference_output_fn_last_stage():
 
     world_size = torch.distributed.get_world_size()
     num_world_stages = world_size * len(stage_mods)
-    if rng_seed is not None:
-        NumericsLogger(logs_dir).log_pp_model_weights(
-            model, stage_mods, num_world_stages, ranks=[0, 4]
-        )
 
     stages = []
     # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -440,6 +436,7 @@ def shape_inference_output_fn_last_stage():
             group=world_mesh.get_group("pp"),
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
@@ -451,9 +448,32 @@ def shape_inference_output_fn_last_stage():
         backward_requires_autograd=False,
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
+
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, ranks=[0, 4]
+        )
+        torch.manual_seed(rng_seed)
+
+    def last_stage_forward_hook(
+        stage: GraphPipelineStage, action: str, output: torch.Tensor
+    ):
+        if not stage.is_last or rng_seed is None:
+            return
+
+        rank = torch.distributed.get_rank()
+        if rank == 4:
+            numerics_logger.log_diff(
+                output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
+            )
+
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
+        FORWARD,
+        functools.partial(
+            stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
+        ),
     )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
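The hook reaches the runtime through functools.partial: the extra keyword arguments (numerics_logs, forward_hook) are bound up front, so the schedule keeps invoking the registered callable with whatever positional arguments it normally hands to stage_forward. A standalone sketch of that binding pattern, with toy names rather than the PR's actual signatures:

    import functools

    def stage_forward_like(action, *, numerics_logs=None, forward_hook=None):
        # Toy stand-in for stage_forward: run the step, then let the hook inspect the output.
        output = f"output-for-{action}"
        if forward_hook is not None:
            forward_hook(output)
        return output

    def my_hook(output):
        print("hook saw", output)

    bound = functools.partial(stage_forward_like, numerics_logs=None, forward_hook=my_hook)
    bound("mb0-forward")  # the runtime only supplies the positional action argument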
@@ -476,10 +496,17 @@ def shape_inference_output_fn_last_stage():
     with torch.no_grad():
         if pp_rank == 0:
             x = runtime_input_fn()
+            if rng_seed is not None:
+                numerics_logger.log_diff(
+                    x.to(torch.float32), prefix="full batch input"
+                )
             graph_pp_runner.step(x)
         else:
             graph_pp_runner.step()
 
+    if rng_seed is not None:
+        numerics_logger.log_pp_grads(model, stage_mods, num_world_stages, ranks=[0, 4])
+
     print("All good!")
 
     if torch.distributed.is_initialized():