Commit e5c0227

Compare microbatch forward outputs and gradients
stack-info: PR: #246, branch: xmfan/stack/20
1 parent 2895806 commit e5c0227

5 files changed: +140 −29 lines

autoparallel/api.py

Lines changed: 15 additions & 2 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import functools
 import itertools
 import warnings
 from contextlib import ExitStack, contextmanager
@@ -42,7 +43,11 @@
 )
 from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
-from .utils import _get_device_from_mesh
+from .utils import (
+    NumericsLogger,
+    _get_device_from_mesh,
+    debug_boxed_nop_preserve_node_meta,
+)
 
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
@@ -193,6 +198,7 @@ def __init__(
         ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
         reshard_after_forward: bool = True,
         dynamic: bool = False,
+        numerics_logger: NumericsLogger | None = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -220,7 +226,14 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
-        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
+        if compile:
+            self.compiler_fn = compile_fx_inner
+        elif numerics_logger:
+            self.compiler_fn = functools.partial(
+                debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
+            )
+        else:
+            self.compiler_fn = boxed_nop_preserve_node_meta  # type: ignore[assignment]
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
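Taken together, the new keyword routes the non-compiled path through the debug interpreter. A minimal usage sketch, assuming the import paths below; model, input_fn, mesh, and logs_dir are placeholders, and the logger is constructed the same way as in the local_map example later in this commit:

# Sketch, not part of the diff: opt into per-op numerics logging when not compiling.
from autoparallel.api import AutoParallel
from autoparallel.utils import NumericsLogger

numerics_logger = NumericsLogger(logs_dir)  # logs_dir: directory where numerics logs are written
with AutoParallel(
    model, input_fn, mesh, dynamic=True, numerics_logger=numerics_logger
) as autop:
    autop.add_parameter_memory_constraint(low=None, high=None)
    ...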

autoparallel/graph_pp_runner.py

Lines changed: 3 additions & 0 deletions
@@ -234,6 +234,7 @@ def stage_forward(
     action: _Action,
     ctx: _PipelineContext,
     numerics_logs: Optional[list[str]] = None,
+    forward_hook: Callable | None = None,
 ) -> None:
     schedule = ctx.schedule_ref
     assert isinstance(schedule, _PipelineScheduleRuntime)
@@ -292,6 +293,8 @@ def stage_forward(
     # Output chunks is only used for the last stage since we only merge the output of the last stage
     if stage.is_last:
         stage.output_chunks.append(output)
+        if forward_hook:
+            forward_hook(stage, action, output)
 
     stage.fwd_cache[mb_index] = (
         output_tuple,  # stage_output
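The hook fires only on the last stage, right after the microbatch output is recorded. A hedged sketch of the expected call shape (stage, action, output); print_last_stage_output is an illustrative name, and the registration pattern shown in the comment mirrors examples/example_ds3_pp.py later in this commit:

# Sketch, not part of the diff.
def print_last_stage_output(stage, action, output):
    # stage: the GraphPipelineStage that just ran its forward
    # action: the schedule _Action (carries microbatch_index)
    # output: the last stage's forward output for this microbatch
    if stage.is_last:
        print(f"mb{action.microbatch_index} fwd out norm: {output.norm()}")

# Wired up by partially applying stage_forward, e.g.:
# schedule.register_custom_function(
#     FORWARD,
#     functools.partial(stage_forward, numerics_logs=None, forward_hook=print_last_stage_output),
# )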

autoparallel/utils.py

Lines changed: 52 additions & 1 deletion
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
                 continue
 
             self._logs.append(
-                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
             )
 
     def run_node(self, n: torch.fx.Node) -> Any:
@@ -429,6 +429,20 @@ def log_model_weights(self, parallel_mod):
 
         print(f"Weight hashes written to {path}")
 
+    def log_fw_intermediates(self, logs):
+        rank = torch.distributed.get_rank()
+        path = self.dir / f"rank_{rank}_fw_intermediates.log"
+        with open(path, "a") as f:
+            f.write("\n".join(logs) + "\n")
+
+    def log_diff(self, t, rank=0, prefix="?"):
+        if self.rank == rank:
+            path = self.dir / "diff.log"
+            if isinstance(t, torch.distributed.tensor.DTensor):
+                t = t.to_local()
+            with open(path, "a") as f:
+                f.write(f"[{prefix}] hash={hash_tensor(t)}, norm={torch.norm(t)}\n")
+
     def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
         path = self.dir / "pp_weights.log"
 
@@ -463,3 +477,40 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
 
         if self.rank == 0:
             print(f"Weight hashes written to {path}")
+
+    def log_pp_grads(self, orig_mod, stage_mods, num_world_stages, ranks):
+        path = self.dir / "diff.log"
+
+        torch.distributed.barrier()
+        for i in range(num_world_stages):
+            if self.rank in ranks and i in stage_mods:
+                grad_logs = []
+                real_params = dict(stage_mods[i].named_parameters())
+                for name, _ in orig_mod.named_parameters():
+                    if name not in real_params:
+                        continue
+                    grad = real_params[name].grad
+                    if grad is None:
+                        grad_logs.append(f"[grad {name}] None")
+                    else:
+                        grad = grad.to_local()
+                        grad_logs.append(
+                            f"[grad {name}] hash={hash_tensor(grad)}, norm={torch.norm(grad)}"
+                        )
+                with open(path, "a") as f:
+                    f.write("\n".join(grad_logs) + "\n")
+            torch.distributed.barrier()
+
+
+def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
+    def run(args):
+        with torch.fx.traceback.preserve_node_meta():
+            interp = DebugInterpreter(fx_g)
+            out = interp.boxed_run(args)
+            mylogs = interp.get_logs()
+            if numerics_logger:
+                numerics_logger.log_fw_intermediates(mylogs)
+            return out
 
+    run._boxed_call = True
+    return run
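debug_boxed_nop_preserve_node_meta is a drop-in for boxed_nop_preserve_node_meta: instead of compiling, it returns a runner that executes the FX graph through DebugInterpreter and flushes per-node hash/NaN lines to rank_{rank}_fw_intermediates.log. A self-contained toy of the same pattern, using a plain torch.fx.Interpreter and a norm in place of torch.hash_tensor (LoggingInterpreter, toy_debug_backend, and Toy are illustrative names, not part of the repo):

import torch
import torch.fx


class LoggingInterpreter(torch.fx.Interpreter):
    """Toy stand-in for DebugInterpreter: record a cheap summary of every node's output."""

    def __init__(self, gm):
        super().__init__(gm)
        self.logs: list[str] = []

    def run_node(self, n: torch.fx.Node):
        out = super().run_node(n)
        if isinstance(out, torch.Tensor):
            self.logs.append(
                f"node={n.name} norm={out.float().norm().item():.4f} "
                f"nan={torch.isnan(out).any().item()}"
            )
        return out


def toy_debug_backend(gm: torch.fx.GraphModule, example_inputs):
    # Return a runner instead of a compiled artifact, mirroring the "boxed nop" idea.
    def run(*args):
        interp = LoggingInterpreter(gm)
        out = interp.run(*args)
        print("\n".join(interp.logs))
        return out

    return run


class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) @ x.t()


runner = toy_debug_backend(torch.fx.symbolic_trace(Toy()), example_inputs=None)
runner(torch.randn(4, 4))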

examples/example_ds3_local_map.py

Lines changed: 37 additions & 19 deletions
@@ -118,7 +118,8 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
         mscale=0.70,
     )
 
-    bs = 4 * mesh.shape[0] * mesh.shape[1]
+    local_batch_size = 2
+    global_batch_size = local_batch_size * mesh.shape[0] * mesh.shape[1]
     device = torch.device(f"cuda:{local_rank}")
 
     # parallelize the model
@@ -129,11 +130,16 @@ def input_fn():
         return torch.randint(
             0,
             config.vocab_size,
-            (bs, seq_len),
+            (global_batch_size, seq_len),
             device=device,
         )
 
-    with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
+    numerics_logger = None
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+    with AutoParallel(
+        model, input_fn, mesh, dynamic=True, numerics_logger=None
+    ) as autop:
         autop.add_parameter_memory_constraint(low=None, high=None)
 
         # x_sharding = (Shard(0), Replicate())
@@ -153,17 +159,22 @@ def input_fn():
     # ) # maybe not correct value
     parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
     if rng_seed is not None:
-        numerics_logger = NumericsLogger(logs_dir)
         numerics_logger.log_model_weights(parallel_mod)
-
-    x = (
-        torch.randint(
-            0,
-            config.vocab_size,
-            (bs // mesh.shape[0] // mesh.shape[1], seq_len),
-            device=device,
-        ),
+        torch.manual_seed(rng_seed)
+
+    n_microbatches = 16
+    full_batch = torch.randint(
+        0,
+        config.vocab_size,
+        (local_batch_size * n_microbatches, seq_len),
+        device=device,
     )
+    microbatches = torch.split(full_batch, local_batch_size, dim=0)
+    assert len(microbatches) == n_microbatches
+    if rng_seed:
+        numerics_logger.log_diff(
+            full_batch.to(torch.float32), prefix="full batch input"
+        )
 
     # Symbolically evaluate in case you want to test running a graph bigger than your gpu
     if fake_evaluate:
@@ -173,15 +184,22 @@ def input_fn():
             allow_non_fake_inputs=True,
             shape_env=shape_env,
         ):
-            # # now let's run it
-            out = parallel_mod(*x)
-            out.backward(torch.randn_like(out))
+            # now let's run it
+            for x in microbatches:
+                out = parallel_mod(x)
+                out.backward(torch.ones_like(out))
     else:
-        out = parallel_mod(*x)
-        assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+        for i, x in enumerate(microbatches):
+            assert x.shape[0] == 2
+            out = parallel_mod(x)
+            assert not torch.any(torch.isnan(out)), "Found NaNs in forward output"
+            out.backward(torch.ones_like(out))
+            if rng_seed is not None:
+                numerics_logger.log_diff(out, prefix=f"mb{i} fwd out")
+
         if rng_seed is not None:
-            numerics_logger.log_forward_output(out)
-        out.backward(torch.randn_like(out))
+            for k, v in parallel_mod.named_parameters():
+                numerics_logger.log_diff(v.grad, prefix=f"grad {k}")
 
     print("All good!")
 
examples/example_ds3_pp.py

Lines changed: 33 additions & 7 deletions
@@ -163,9 +163,9 @@ def run_test(fake_evaluate: bool, rng_seed: Optional[int], logs_dir: str):
     # This is the spmd mesh to be used for tracing
     mesh = world_mesh[("dp_mod_ep", "ep")]
 
-    global_batch_size = 32 * dp_degree
     # Batch size that will be supplied to the schedule and will be broken down into microbatches
-    local_batch_size = global_batch_size // dp_degree
+    local_batch_size = 32
+    # global_batch_size = local_batch_size * dp_degree
     n_microbatches = 16
     # Batch size with which the spmd graphs will actually be executed
     microbatch_size = local_batch_size // n_microbatches
@@ -412,10 +412,6 @@ def shape_inference_output_fn_last_stage():
 
     world_size = torch.distributed.get_world_size()
     num_world_stages = world_size * len(stage_mods)
-    if rng_seed is not None:
-        NumericsLogger(logs_dir).log_pp_model_weights(
-            model, stage_mods, num_world_stages, ranks=[0, 4]
-        )
 
     stages = []
     # Step 4. Construct pipeline stages for this pp_rank using the stage modules, graphs and metadata
@@ -440,6 +436,7 @@ def shape_inference_output_fn_last_stage():
             group=world_mesh.get_group("pp"),
         )
         stages.append(stage)
+
     # Step 5. Construct the pipeline runner using the pipeline stages for this pp_rank
     schedule = build_pipeline_schedule(
         stages=stages,
@@ -451,9 +448,32 @@ def shape_inference_output_fn_last_stage():
         backward_requires_autograd=False,
     )
     assert isinstance(schedule, _PipelineScheduleRuntime)
+
+    if rng_seed is not None:
+        numerics_logger = NumericsLogger(logs_dir)
+        numerics_logger.log_pp_model_weights(
+            model, stage_mods, num_world_stages, ranks=[0, 4]
+        )
+        torch.manual_seed(rng_seed)
+
+    def last_stage_forward_hook(
+        stage: GraphPipelineStage, action: str, output: torch.Tensor
+    ):
+        if not stage.is_last or rng_seed is None:
+            return
+
+        rank = torch.distributed.get_rank()
+        if rank == 4:
+            numerics_logger.log_diff(
+                output, rank=4, prefix=f"mb{action.microbatch_index} fwd out"
+            )
+
     # Step 6. Override the pipeline runner's action implementations
     schedule.register_custom_function(
-        FORWARD, functools.partial(stage_forward, numerics_logs=None)
+        FORWARD,
+        functools.partial(
+            stage_forward, numerics_logs=None, forward_hook=last_stage_forward_hook
+        ),
     )
     schedule.register_custom_function(FULL_BACKWARD, stage_full_backward)
     schedule.register_custom_function(REDUCE_GRAD, stage_reduce_grad)
@@ -476,10 +496,16 @@ def shape_inference_output_fn_last_stage():
     with torch.no_grad():
         if pp_rank == 0:
             x = runtime_input_fn()
+            if rng_seed:
+                numerics_logger.log_diff(
+                    x.to(torch.float32), prefix="full batch input"
+                )
            graph_pp_runner.step(x)
         else:
             graph_pp_runner.step()
 
+    numerics_logger.log_pp_grads(model, stage_mods, num_world_stages, ranks=[0, 4])
+
     print("All good!")
 
     if torch.distributed.is_initialized():
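With the same rng_seed, the two examples now write a matching sequence of records to their diff.log files: a "full batch input" line, one "mb{i} fwd out" line per microbatch, and one "[grad ...]" line per parameter (written on rank 0 in the local_map example and rank 4 in the PP example). A hypothetical helper, not part of this PR, for spotting where the two logs diverge:

from pathlib import Path


def compare_diff_logs(path_a: str, path_b: str) -> None:
    # Hypothetical helper: report the first line where the two numerics logs disagree.
    a = Path(path_a).read_text().splitlines()
    b = Path(path_b).read_text().splitlines()
    for i, (la, lb) in enumerate(zip(a, b)):
        if la != lb:
            print(f"first mismatch at line {i}:\n  a: {la}\n  b: {lb}")
            return
    print(f"no mismatches in the first {min(len(a), len(b))} lines")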
