
Commit 9650021: Update params to work for number of devices

1 parent: 5ecc953

File tree: 9 files changed, +227 -25 lines


Dockerfile

Lines changed: 9 additions & 1 deletion

@@ -88,7 +88,15 @@ ARG EXTRAS=
 ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
 # Ensure we install the TPU version, even if building locally.
 # Jax will fallback to CPU when run on a machine without TPU.
-RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean
+COPY libtpu.so /root/libtpu.so
+RUN uv pip install --prerelease=allow .[core,gcp,tpu] && uv cache clean
+RUN uv pip install libtpu==0.0.14
+
+# Add this line to print the installed version of libtpu.
+RUN pip show libtpu | grep Version
+RUN pip show jax | grep Version
+RUN pip show jaxlib | grep Version
+
 RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi
 COPY . .
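For reference, a minimal sanity check (not part of this commit) that could be run inside the built image to confirm JAX initializes against the bundled libtpu; it assumes TPU_LIBRARY_PATH is set to /root/libtpu.so as in the jobset change below.

# Hypothetical check, assuming the image built from this Dockerfile and
# TPU_LIBRARY_PATH=/root/libtpu.so exported by the jobset builder.
import os
import jax

print("TPU_LIBRARY_PATH =", os.environ.get("TPU_LIBRARY_PATH"))
print("jax", jax.__version__, "devices:", jax.devices())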

axlearn/cloud/gcp/jobs/launch.py

Lines changed: 1 addition & 1 deletion

@@ -780,6 +780,6 @@ def _wrapped_usage(


 if __name__ == "__main__":
-    configure_logging(logging.INFO)
+    configure_logging(logging.DEBUG)
     _private_flags()
     app.run(main)

axlearn/cloud/gcp/jobset_utils.py

Lines changed: 13 additions & 15 deletions

@@ -4,7 +4,6 @@

 import io
 import logging
-import math
 import os
 from dataclasses import dataclass
 from typing import Any, Optional, Sequence
@@ -27,10 +26,7 @@
 )
 from axlearn.cloud.gcp.config import gcp_settings
 from axlearn.cloud.gcp.node_pool import PRE_PROVISIONER_LABEL
-from axlearn.cloud.gcp.system_characteristics import (
-    GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS,
-    USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS,
-)
+from axlearn.cloud.gcp.system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS
 from axlearn.cloud.gcp.tpu import get_default_env, infer_tpu_workers
 from axlearn.cloud.gcp.utils import validate_jobset_name
 from axlearn.common.compiler_options import infer_tpu_type
@@ -451,15 +447,17 @@ def _build_container(self) -> Nested[Any]:
         if cfg.enable_tpu_ici_resiliency is not None:
             env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower()

+        env_vars["TPU_LIBRARY_PATH"] = "/root/libtpu.so"
+
         resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
-        # Set request memory by host machine type.
-        machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
-            system.gce_machine_type, None
-        )
-        if machine_memory_gi is not None:
-            request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
-            resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
-            resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
+        # # Set request memory by host machine type.
+        # machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
+        #     system.gce_machine_type, None
+        # )
+        # if machine_memory_gi is not None:
+        #     request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
+        #     resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
+        #     resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}

         k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]
         k8s_env_vars.append(
@@ -509,8 +507,8 @@ def _build_uploader_container(
     interval_s = 60
     sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done"
     resources = {
-        "requests": {"cpu": "100m", "memory": "128Mi"},
-        "limits": {"cpu": "500m", "memory": "256Mi"},
+        # "requests": {"cpu": "100m", "memory": "128Mi"},
+        # "limits": {"cpu": "500m", "memory": "256Mi"},
     }
     return dict(
         name="output-uploader",
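As a rough sketch (assumed shape, not produced by this diff), the TPU container spec now reduces to the TPU chip limit plus the pinned libtpu path, with the memory request/limit logic disabled.

# Illustrative only: approximate effect of _build_container after this change.
resources = {"limits": {"google.com/tpu": system.chips_per_vm}}  # no memory request/limit
env_vars["TPU_LIBRARY_PATH"] = "/root/libtpu.so"  # matches the COPY in the Dockerfile above
k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]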

axlearn/common/array_serialization.py

Lines changed: 13 additions & 0 deletions

@@ -559,21 +559,34 @@ def serialize(

         # pylint: disable-next=redefined-outer-name
         async def _run_serializer():
+            logging.info(
+                "******* DEBUG GlobalAsyncCheckpointManager _run_serializer "
+                "with number of commit_futures: %s",
+                len(commit_futures),
+            )
             future_writer = jax.tree.map(
                 serialization.async_serialize, arrays, tensorstore_specs, commit_futures
             )
+            logging.info("******* DEBUG GlobalAsyncCheckpointManager _run_serializer Completed")
             return await asyncio.gather(*future_writer)

+        # Is this the problem?
+        logging.info("******* DEBUG Starting to run _run_serializer")
+
         # Note: We need to run the coroutine in another event loop driven by a separate thread.
         # The current event loop might be already running an async function when `serialize` is
         # invoked from a coroutine, in which case asyncio.get_running_loop().run_until_complete()
         # would not be able to execute another coroutine to completion.
         asyncio.run_coroutine_threadsafe(_run_serializer(), self._loop).result()

+        logging.info("******* DEBUG Starting to run _run_serializer")
+
         self._add_futures(
             jax.tree_util.tree_flatten(commit_futures)[0] + (additional_futures or [])
         )

+        logging.info("******* DEBUG Starting to run async_commit")
+
         # Used in wait_until_finished to check on process != 0, if the checkpoint
         # has finished writing.
         self._start_async_commit(on_commit_callback)
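The note about running the coroutine on a separate loop is the key constraint being instrumented here. A minimal, standalone sketch of that pattern (independent of axlearn) looks like:

import asyncio
import threading

# A dedicated event loop driven by a background thread (analogous to self._loop).
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def _work():
    await asyncio.sleep(0.1)
    return "done"

# Safe to call even if the caller is itself inside a running event loop.
print(asyncio.run_coroutine_threadsafe(_work(), loop).result())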

axlearn/common/checkpointer.py

Lines changed: 13 additions & 0 deletions

@@ -175,11 +175,14 @@ def async_save_tf_savables(
     When this call returns, `value_map` can be safely mutated, but saving to `dir` will not
     complete unless the returned future is set.
     """
+    logging.info("******* DEBUG Saving TF savables to %s async", dir)
     # pylint: disable-next=consider-using-with
     f = tempfile.TemporaryDirectory()
     for path, value in utils.flatten_items(value_map):
         tf_checkpoint = tf.train.Checkpoint(value)
+        logging.info("******* DEBUG Writing %s to path %s", f.name, path)
         tf_checkpoint.write(os.path.join(f.name, path))
+        logging.info("******* DEBUG Done writing %s to path %s", f.name, path)
     return executor.submit(_upload_dir, f, dst_dir=dir)


@@ -399,6 +402,7 @@ def __init__(self, cfg: Config):
         # TODO(markblee): Consider making BoundedDataShardedAsyncCheckpointManager
         # the default once stable.
         if cfg.max_concurrent_gb is not None or cfg.max_data_shard_degree:
+            logging.info("******* DEBUG Using BoundedDataShardedAsyncCheckpointManager")
             self._manager = BoundedDataShardedAsyncCheckpointManager(
                 max_concurrent_gb=cfg.max_concurrent_gb,
                 timeout_secs=cfg.timeout_secs,
@@ -411,6 +415,7 @@ def __init__(self, cfg: Config):
                     f"shard_threshold_bytes is set to {cfg.shard_threshold_bytes}, but "
                     "max_data_shard_degree is not set. It will not take any effect."
                 )
+            logging.info("******* DEBUG Using GlobalAsyncCheckpointManager")
             self._manager = GlobalAsyncCheckpointManager(timeout_secs=cfg.timeout_secs)
         if cfg.max_concurrent_restore_gb is not None and cfg.max_concurrent_restore_gb <= 0:
             raise ValueError(
@@ -514,8 +519,12 @@ def save_to_dir(
         logging.info("Creating directories: %s", dirs)
         list(self._executor.map(fs.makedirs, dirs))
         logging.info("All directories created")
+
+        logging.info("******* DEBUG starting sync_global_devices")
         # Wait for directory and index creation.
         multihost_utils.sync_global_devices(ckpt_dir)
+        logging.info("******* DEBUG finished sync_global_devices")
+
         # Each worker writes its tf checkpoints under a different path.
         save_tf_future = async_save_tf_savables(
             spec.tf_ckpt_map,
@@ -527,6 +536,7 @@ def save_to_dir(
         )

        def commit():
+            logging.info("******* DEBUG starting on_commit_callback")
             on_commit_callback(ckpt_dir=ckpt_dir, index=spec.index)
             logging.info(
                 "Serialization of %s completed in %s seconds.",
@@ -538,6 +548,9 @@ def commit():
         logging.debug(
             "array_values=%s tensorstore=%s", utils.shapes(spec.gda_values), spec.tensorstore_specs
         )
+        logging.info(
+            "array_values=%s tensorstore=%s", utils.shapes(spec.gda_values), spec.tensorstore_specs
+        )
         self._manager.serialize(
             spec.gda_values,
             spec.tensorstore_specs,
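The new logs bracket multihost_utils.sync_global_devices, which only returns once every process reaches the barrier with the same name; a missing "finished" line on some host is therefore the signal being looked for. A minimal sketch of the barrier (the path here is hypothetical):

from jax.experimental import multihost_utils

ckpt_dir = "gs://example-bucket/checkpoints/step_00001000"  # hypothetical
# Blocks until all JAX processes call sync_global_devices with the same string.
multihost_utils.sync_global_devices(ckpt_dir)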

axlearn/common/compiler_options.py

Lines changed: 19 additions & 3 deletions

@@ -59,6 +59,12 @@ def default_xla_options(
         # further if you see "Allocator failed to allocate". A feature
         # to dynamically allocate may come later: b/380514965
         megascale_grpc_premap_memory_bytes=17179869184,
+        # DEBUGGING ONLY: RapidEye output directory for debugging purposes,
+        megascale_rapid_eye_error_digest_log_path="/output/rapideye/",
+        # megascale_jax_offset_launch_id_by_module_name="false",
+        # megascale_jax_use_device_set_based_launch_id="false",
+        # enable megascale debug port.
+        megascale_debug_port=8081,
         # Flag controlling the maximum number of overlapping host offloadings.
         xla_tpu_host_transfer_overlap_limit=24,
         # Flag controlling the maximum number of overlapping cross-DCN send/recv.
@@ -149,10 +155,10 @@ def default_xla_options(
         # Similar to megascale_error_reporter_abort_on_hang but for unrecoverable errors.
         megascale_error_reporter_abort_on_error="true",
         # Increase the timeout at which a hang is detected/reported, default is 5m.
-        megascale_graph_hang_threshold="10m",
+        megascale_graph_hang_threshold="20m",
         # Similar to megascale_graph_hang_threshold but specific to within a launch_id.
         # Default is 1m.
-        megascale_graph_within_launch_hang_threshold="10m",
+        megascale_graph_within_launch_hang_threshold="20m",
         # TODO(ethanli): temporary workaround to avoid memory leak in megascale.
         megascale_grpc_enable_xor_tracer="false",
     )
@@ -163,7 +169,16 @@ def default_xla_options(
             int(v)
             continue
         except ValueError:
-            assert v in [True, False, "true", "false", "megachip_tccontrol", "10m"], (k, v)
+            assert v in [
+                True,
+                False,
+                "true",
+                "false",
+                "megachip_tccontrol",
+                "10m",
+                "20m",
+                "/output/rapideye/",
+            ], (k, v)

     return options

@@ -302,6 +317,7 @@ def infer_xla_performance_flags(
     if current_configuration in mesh_configurations_for_sparse_core_offloading:
         flags = dict(
             # Must disable continuation fusion to enable sparse core offloading.
+            # AXLEARN TESTING NOTE: We are disabling this to test for SparseCore related issues.
             xla_tpu_enable_async_collective_fusion_fuse_all_gather="false",
             xla_tpu_enable_async_collective_fusion_fuse_all_reduce="false",
             xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter="false",
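Context for the widened assert: default_xla_options validates each flag value as either an int or a member of a small allowlist, so the new string values ("20m" and the RapidEye path) must be added to that list. A simplified sketch of that validation shape, with an illustrative two-entry dict standing in for the full options:

allowed = [True, False, "true", "false", "megachip_tccontrol", "10m", "20m", "/output/rapideye/"]
for k, v in {"megascale_graph_hang_threshold": "20m", "megascale_debug_port": 8081}.items():
    try:
        int(v)  # ints (and int-like strings) pass through
        continue
    except ValueError:
        assert v in allowed, (k, v)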

axlearn/common/trainer.py

Lines changed: 2 additions & 2 deletions

@@ -638,7 +638,7 @@ def run(
                 )
                 self.vlog(3, "Done step %s", self.step)
                 num_steps += 1
-                if num_steps % 100 == 0:
+                if num_steps % 1 == 0:
                     now = time.perf_counter()
                     average_step_time = (now - start_time) / num_steps
                     self._step_log("Average step time: %s seconds", average_step_time)
@@ -1111,7 +1111,7 @@ def _run_step(
         # Run the compiled function.
         self._trainer_state, outputs = compiled_train_step_fn(self.trainer_state, input_batch)

-        if self.step % 100 == 0 or 0 <= self.step <= 5:
+        if self.step % 10 == 0 or 0 <= self.step <= 5:
             self._step_log(
                 "loss=%s aux=%s",
                 outputs["loss"],

axlearn/experiments/text/gpt/fuji.py

Lines changed: 72 additions & 3 deletions

@@ -15,6 +15,8 @@
 import itertools
 from typing import Any, List, NamedTuple, Optional, Union

+import jax
+from absl import logging
 from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies

 from axlearn.common import causal_lm, config
@@ -252,7 +254,6 @@ def get_trainer_kwargs(
     max_step = TOTAL_TOKENS[version][model_size] // tokens_per_batch
     max_sequence_length = MAX_SEQUENCE_LENGTH[version]
     train_batch_size = tokens_per_batch // max_sequence_length
-
     # Whether to use grouped query attention.
     num_kv_heads = None
     if version in (Version.V3, Version.V3_TIKTOKEN):
@@ -813,6 +814,67 @@ def get_trainer_kwargs(
             ),
         )
     elif model_size == "150B":
+        ##################################################################################
+        max_sequence_length = MAX_SEQUENCE_LENGTH[Version.V2]  # 4096
+
+        # model_parallelism * fsdp == num_chips_in_trillium (256)
+        model_parallelism = 4
+        fsdp = 64
+
+        current_pdbs = 0.5
+        train_batch_size = int(current_pdbs * len(jax.devices()))
+
+        # 16 * (1024**2) / 4096 = 4096
+        tokens_per_batch = int(train_batch_size * max_sequence_length)
+
+        # 32M tokens is the max global tokens we can train on.
+        # We must modify either the pdbs or the model sharding to accommodate 128 slices.
+        if tokens_per_batch > 32 * (1024**2):
+            tokens_per_batch = 32 * (1024**2)
+            # if we want to modify the pdbs:
+            # current_pdbs = 0.25

+            # otherwise we can modify the model sharding.
+            model_parallelism = 8
+            fsdp = 32
+
+        # 32M tokens is the max global tokens we can train on.
+        assert tokens_per_batch <= 32 * (1024**2)
+        assert fsdp * model_parallelism == 256
+
+        # 1 / model_parallelism = 1 / 4 = 0.25
+        min_pdbs = 1 / model_parallelism
+        max_pdbs = 1
+
+        # More than 1 pdbs causes an OOM.
+        assert current_pdbs < max_pdbs
+        assert current_pdbs >= min_pdbs
+
+        # maximum number of devices we can use this config on =
+        # train_batch_size // min_pdbs = 4096 / 0.25 = 16384
+        max_devices = int(train_batch_size // min_pdbs)
+
+        assert isinstance(train_batch_size, int)
+        assert isinstance(tokens_per_batch, int)
+
+        logging.info(
+            (
+                "******* DEBUGGING: max_sequence_length: %s, model_parallelism: %s,"
+                " fsdp: %s, current_pdbs: %s, train_batch_size: %s,"
+                " tokens_per_batch: %s, min_pdbs: %s, max_pdbs: %s, max_devices: %s"
+            ),
+            max_sequence_length,
+            model_parallelism,
+            fsdp,
+            current_pdbs,
+            train_batch_size,
+            tokens_per_batch,
+            min_pdbs,
+            max_pdbs,
+            max_devices,
+        )
+        ##################################################################################
+
         trainer_kwargs = dict(
             model_kwargs=dict(
                 num_layers=80,
@@ -828,8 +890,9 @@ def get_trainer_kwargs(
             learner_kwargs=dict(peak_lr=1.5e-4, weight_decay=0.1),
             max_sequence_length=max_sequence_length,
             train_batch_size=train_batch_size,
-            max_step=max_step,
-            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=64, model=4),
+            max_step=100_000,  # max_step,
+            save_every_n_steps=100,
+            mesh_shape=mesh_shape_from_axes(data=-1, fsdp=fsdp, model=model_parallelism),
             mesh_rules=(
                 (
                     # Target per-device token count = 4k.
@@ -971,6 +1034,12 @@ def trainer_configs(
         if model_size not in TOTAL_TOKENS[version]:  # This combination does not exist.
             continue
         vocab_size = VOCAB_SIZE[version]
+        logging.info(
+            "******* DEBUGGING: version: %s, model_size: %s, flash_attention: %s",
+            version,
+            model_size,
+            flash_attention,
+        )
         config_name = make_config_name(
             arch=arch,
             model_size=model_size,
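To make the arithmetic above concrete, here is a worked example under two assumed device counts (the slice counts are illustrative assumptions, not part of the diff):

max_sequence_length = 4096
current_pdbs = 0.5

# 32 slices x 256 chips = 8192 devices: matches the inline "16M tokens" comment.
train_batch_size = int(current_pdbs * 8192)                 # 4096
tokens_per_batch = train_batch_size * max_sequence_length   # 16 * 1024**2, under the 32M cap

# 128 slices x 256 chips = 32768 devices: exceeds the cap.
train_batch_size = int(current_pdbs * 32768)                # 16384
tokens_per_batch = train_batch_size * max_sequence_length   # 64 * 1024**2 > 32M
# so the branch clamps tokens_per_batch to 32M and switches the per-replica sharding
# to model_parallelism=8, fsdp=32 (still 256 chips per model replica).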
