Commit 4b0b462

Log forward intermediates hashes w/pp vs w/o pp
stack-info: PR: #246, branch: xmfan/stack/20
1 parent 580144b commit 4b0b462

3 files changed: +50 −6 lines changed

autoparallel/api.py

Lines changed: 15 additions & 2 deletions
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import copy
+import functools
 import itertools
 import warnings
 from contextlib import ExitStack, contextmanager
@@ -42,7 +43,11 @@
 )
 from .init_weights import hook_params_setters
 from .optimize_sharding import ShardingOptimizer
-from .utils import _get_device_from_mesh
+from .utils import (
+    NumericsLogger,
+    _get_device_from_mesh,
+    debug_boxed_nop_preserve_node_meta,
+)
 
 _APPLY_VIEW_MM_VIEW_PATTERN = False
 
@@ -193,6 +198,7 @@ def __init__(
         ac_stage_size_in_GiB: Optional[Union[float, str]] = "auto",
         reshard_after_forward: bool = True,
         dynamic: bool = False,
+        numerics_logger: NumericsLogger | None = None,
         **kwargs,
     ):
         self.stack = ExitStack()
@@ -220,7 +226,14 @@ def __init__(
         self.model = move_to_fake(model, self.fake_mode, device)
         self.input_fn = input_fn
         self.mesh = mesh
-        self.compiler_fn = compile_fx_inner if compile else boxed_nop_preserve_node_meta
+        if compile:
+            self.compiler_fn = compile_fx_inner
+        elif numerics_logger:
+            self.compiler_fn = functools.partial(
+                debug_boxed_nop_preserve_node_meta, numerics_logger=numerics_logger
+            )
+        else:
+            self.compiler_fn = boxed_nop_preserve_node_meta
         self.enable_ac = enable_ac
         self.ac_stage_size_in_GiB = ac_stage_size_in_GiB
         self.reshard_after_forward = reshard_after_forward
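
Usage sketch (not part of the diff): when compile is falsy and a logger is passed, the graph runs through debug_boxed_nop_preserve_node_meta instead of the plain boxed nop. The call mirrors examples/example_ds3_local_map.py below; model, input_fn, mesh, and logs_dir are taken from that example, everything else is illustrative.

# Minimal sketch of the new parameter, following the example change below.
numerics_logger = NumericsLogger(logs_dir)  # logs_dir: directory that will hold the per-rank log files
with AutoParallel(
    model, input_fn, mesh, dynamic=True, numerics_logger=numerics_logger
) as autop:
    ...  # constraints, sharding, and compilation as in the example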

autoparallel/utils.py

Lines changed: 27 additions & 1 deletion
@@ -341,7 +341,7 @@ def log(self, node: str, args: Iterable[Any], inputs_or_outputs: str):
                 continue
 
             self._logs.append(
-                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)}"
+                f"{node=}, {inputs_or_outputs}[{i}]={torch.hash_tensor(arg)} nan={torch.any(torch.isnan(arg))}"
             )
 
     def run_node(self, n: torch.fx.Node) -> Any:
@@ -429,6 +429,18 @@ def log_model_weights(self, parallel_mod):
 
         print(f"Weight hashes written to {path}")
 
+    def log_fw_intermediates(self, logs):
+        rank = torch.distributed.get_rank()
+        path = self.dir / f"rank_{rank}_fw_intermediates.log"
+        with open(path, "a") as f:
+            f.write("\n".join(logs) + "\n")
+
+    def log_forward_output(self, fw_out):
+        if self.rank == 0:
+            path = self.dir / "fw_out.log"
+            with open(path, "a") as f:
+                f.write(f"fw_out={hash_tensor(fw_out)}\n")
+
     def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
         path = self.dir / "pp_weights.log"
 
@@ -463,3 +475,17 @@ def log_pp_model_weights(self, orig_mod, stage_mods, num_world_stages, ranks):
 
         if self.rank == 0:
            print(f"Weight hashes written to {path}")
+
+
+def debug_boxed_nop_preserve_node_meta(fx_g, example_inputs, numerics_logger):
+    def run(args):
+        with torch.fx.traceback.preserve_node_meta():
+            interp = DebugInterpreter(fx_g)
+            out = interp.boxed_run(args)
+            mylogs = interp.get_logs()
+            if numerics_logger:
+                numerics_logger.log_fw_intermediates(mylogs)
+            return out
+
+    run._boxed_call = True
+    return run
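
Since log_fw_intermediates appends plain-text lines to rank_{rank}_fw_intermediates.log under the logger's directory, a pipeline-parallel run can be checked against a non-pp baseline by diffing those files per rank. A minimal sketch: compare_fw_intermediates and the two-directory layout are hypothetical, only the filename pattern comes from the code above.

from pathlib import Path

def compare_fw_intermediates(baseline_dir: Path, pp_dir: Path, rank: int = 0) -> None:
    # Read the per-rank forward-intermediate hash logs written by the two runs.
    a = (baseline_dir / f"rank_{rank}_fw_intermediates.log").read_text().splitlines()
    b = (pp_dir / f"rank_{rank}_fw_intermediates.log").read_text().splitlines()
    for i, (la, lb) in enumerate(zip(a, b)):
        if la != lb:
            # First node whose input/output hash (or nan flag) diverges.
            print(f"first divergence at line {i}:\n  {la}\n  {lb}")
            return
    print("no divergence in the compared prefix")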

examples/example_ds3_local_map.py

Lines changed: 8 additions & 3 deletions
@@ -133,7 +133,12 @@ def input_fn():
     device=device,
 )
 
-with AutoParallel(model, input_fn, mesh, dynamic=True) as autop:
+numerics_logger = None
+if rng_seed is not None:
+    numerics_logger = NumericsLogger(logs_dir)
+with AutoParallel(
+    model, input_fn, mesh, dynamic=True, numerics_logger=numerics_logger
+) as autop:
     autop.add_parameter_memory_constraint(low=None, high=None)
 
     # x_sharding = (Shard(0), Replicate())
@@ -153,7 +158,7 @@ def input_fn():
 # ) # maybe not correct value
 parallel_mod.init_weights(buffer_device=device, seed=rng_seed)
 if rng_seed is not None:
-    NumericsLogger(logs_dir).log_model_weights(parallel_mod)
+    numerics_logger.log_model_weights(parallel_mod)
 
 x = (
     torch.randint(
@@ -173,7 +178,7 @@ def input_fn():
         shape_env=shape_env,
     ):
         # # now let's run it
-        out = parallel_mod(*x)
+        out = parallel_mod(*x, numerics_logger=numerics_logger)
         out.backward(torch.randn_like(out))
 else:
     out = parallel_mod(*x)
