
Commit 97b3707

Update params to work for number of devices
1 parent 035d471 commit 97b3707

File tree: 12 files changed, +253 −27 lines

Dockerfile

Lines changed: 9 additions & 1 deletion
@@ -88,7 +88,15 @@ ARG EXTRAS=
 ENV UV_FIND_LINKS=https://storage.googleapis.com/jax-releases/libtpu_releases.html
 # Ensure we install the TPU version, even if building locally.
 # Jax will fallback to CPU when run on a machine without TPU.
-RUN uv pip install --prerelease=allow .[core,tpu] && uv cache clean
+COPY libtpu.so /root/libtpu.so
+RUN uv pip install --prerelease=allow .[core,gcp,tpu] && uv cache clean
+RUN uv pip install libtpu==0.0.14
+
+# Add this line to print the installed version of libtpu.
+RUN pip show libtpu | grep Version
+RUN pip show jax | grep Version
+RUN pip show jaxlib | grep Version
+
 RUN if [ -n "$EXTRAS" ]; then uv pip install .[$EXTRAS] && uv cache clean; fi
 COPY . .

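The image now bundles a local libtpu.so, adds the gcp extra, pins libtpu==0.0.14, and prints the installed libtpu/jax/jaxlib versions at build time. A minimal runtime probe (not part of this commit; the /root/libtpu.so path mirrors the COPY above, everything else is a generic check) to confirm JAX actually sees TPU devices rather than silently falling back to CPU:

# tpu_probe.py -- illustrative sanity check only.
import os

os.environ.setdefault("TPU_LIBRARY_PATH", "/root/libtpu.so")  # Matches the COPY destination.

import jax  # noqa: E402  # Imported after setting the env var so the bundled libtpu is picked up.

print("jax:", jax.__version__, "backend:", jax.default_backend())
print("device_count:", jax.device_count())  # Reports CPU devices if no TPU is attached.
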
axlearn/cloud/gcp/jobs/launch.py

Lines changed: 1 addition & 1 deletion
@@ -780,6 +780,6 @@ def _wrapped_usage(
 
 
 if __name__ == "__main__":
-    configure_logging(logging.INFO)
+    configure_logging(logging.DEBUG)
     _private_flags()
     app.run(main)

axlearn/cloud/gcp/jobset_utils.py

Lines changed: 13 additions & 15 deletions
@@ -4,7 +4,6 @@
 
 import io
 import logging
-import math
 import os
 from dataclasses import dataclass
 from typing import Any, Optional, Sequence
@@ -27,10 +26,7 @@
 )
 from axlearn.cloud.gcp.config import gcp_settings
 from axlearn.cloud.gcp.node_pool import PRE_PROVISIONER_LABEL
-from axlearn.cloud.gcp.system_characteristics import (
-    GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS,
-    USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS,
-)
+from axlearn.cloud.gcp.system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS
 from axlearn.cloud.gcp.tpu import get_default_env, infer_tpu_workers
 from axlearn.cloud.gcp.utils import validate_jobset_name
 from axlearn.common.compiler_options import infer_tpu_type
@@ -451,15 +447,17 @@ def _build_container(self) -> Nested[Any]:
         if cfg.enable_tpu_ici_resiliency is not None:
             env_vars["ENABLE_ICI_RESILIENCY"] = str(cfg.enable_tpu_ici_resiliency).lower()
 
+        env_vars["TPU_LIBRARY_PATH"] = "/root/libtpu.so"
+
         resources = {"limits": {"google.com/tpu": system.chips_per_vm}}
-        # Set request memory by host machine type.
-        machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
-            system.gce_machine_type, None
-        )
-        if machine_memory_gi is not None:
-            request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
-            resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
-            resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
+        # # Set request memory by host machine type.
+        # machine_memory_gi = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS.get(
+        #     system.gce_machine_type, None
+        # )
+        # if machine_memory_gi is not None:
+        #     request_memory_gi = machine_memory_gi * _MEMORY_REQUEST_PERCENTAGE
+        #     resources["limits"]["memory"] = f"{machine_memory_gi}Gi"
+        #     resources["requests"] = {"memory": f"{math.floor(request_memory_gi)}Gi"}
 
         k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]
         k8s_env_vars.append(
@@ -509,8 +507,8 @@ def _build_uploader_container(
     interval_s = 60
     sync_command = f"while true; do gsutil -m rsync -r {src} {dst}; sleep {interval_s}; done"
     resources = {
-        "requests": {"cpu": "100m", "memory": "128Mi"},
-        "limits": {"cpu": "500m", "memory": "256Mi"},
+        # "requests": {"cpu": "100m", "memory": "128Mi"},
+        # "limits": {"cpu": "500m", "memory": "256Mi"},
     }
     return dict(
         name="output-uploader",

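The container builder now injects TPU_LIBRARY_PATH pointing at the bundled libtpu and drops the host-memory requests/limits, leaving only the TPU chip limit. A rough sketch of the resulting resources/env structures (illustrative only; chips_per_vm=4 is a placeholder, not taken from the commit):

# Mirrors the dict shapes built in _build_container above.
chips_per_vm = 4  # Placeholder value.
env_vars = {"TPU_LIBRARY_PATH": "/root/libtpu.so"}

resources = {"limits": {"google.com/tpu": chips_per_vm}}  # No memory requests/limits anymore.
k8s_env_vars = [dict(name=k, value=str(v)) for k, v in env_vars.items()]

print(resources)
print(k8s_env_vars)  # [{'name': 'TPU_LIBRARY_PATH', 'value': '/root/libtpu.so'}]
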
axlearn/cloud/gcp/tpu.py

Lines changed: 3 additions & 1 deletion
@@ -13,12 +13,14 @@
 
 def get_default_env(*, tpu_type: str, num_tpu_slices: int, job_name: str) -> dict[str, Any]:
     """Gets the default environment for TPU pods."""
+    del job_name  # Unused.
     return dict(
         # Use a large refresh to mitigate DNS timeout issues until tf>2.12 upgrade.
         GCS_RESOLVE_REFRESH_SECS=600,
         TPU_TYPE=tpu_type,
         NUM_TPU_SLICES=num_tpu_slices,
-        XLA_FLAGS=f"--xla_dump_to=/output/{job_name}/xla",
+        XLA_FLAGS="",
+        # XLA_FLAGS=f"--xla_dump_to=/output/{job_name}/xla",
         TF_CPP_MIN_LOG_LEVEL=0,
         # Necessary for surfacing FATAL TPU errors.
         TPU_STDERR_LOG_LEVEL=0,

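get_default_env now ignores job_name and ships an empty XLA_FLAGS, so per-job XLA dumping under /output/{job_name}/xla is disabled. A usage sketch, assuming axlearn is importable in your environment and with placeholder argument values:

from axlearn.cloud.gcp.tpu import get_default_env

env = get_default_env(tpu_type="v5p-8", num_tpu_slices=1, job_name="ignored")  # Placeholders.
assert env["XLA_FLAGS"] == ""  # The XLA dump directory is no longer set.
assert env["NUM_TPU_SLICES"] == 1
print(env["TPU_TYPE"])
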
axlearn/common/array_serialization.py

Lines changed: 13 additions & 0 deletions
@@ -559,21 +559,34 @@ def serialize(
 
         # pylint: disable-next=redefined-outer-name
         async def _run_serializer():
+            logging.info(
+                "******* DEBUG GlobalAsyncCheckpointManager _run_serializer "
+                "with number of commit_futures: %s",
+                len(commit_futures),
+            )
             future_writer = jax.tree.map(
                 serialization.async_serialize, arrays, tensorstore_specs, commit_futures
             )
+            logging.info("******* DEBUG GlobalAsyncCheckpointManager _run_serializer Completed")
             return await asyncio.gather(*future_writer)
 
+        # Is this the problem?
+        logging.info("******* DEBUG Starting to run _run_serializer")
+
         # Note: We need to run the coroutine in another event loop driven by a separate thread.
         # The current event loop might be already running an async function when `serialize` is
         # invoked from a coroutine, in which case asyncio.get_running_loop().run_until_complete()
         # would not be able to execute another coroutine to completion.
         asyncio.run_coroutine_threadsafe(_run_serializer(), self._loop).result()
 
+        logging.info("******* DEBUG Starting to run _run_serializer")
+
         self._add_futures(
             jax.tree_util.tree_flatten(commit_futures)[0] + (additional_futures or [])
         )
 
+        logging.info("******* DEBUG Starting to run async_commit")
+
         # Used in wait_until_finished to check on process != 0, if the checkpoint
         # has finished writing.
         self._start_async_commit(on_commit_callback)

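The added logging brackets the call that drives _run_serializer on a separate event-loop thread. That pattern, reduced to a standalone example (nothing below comes from axlearn beyond the asyncio.run_coroutine_threadsafe call itself):

import asyncio
import threading

# A background event loop driven by its own thread, as in the manager's self._loop.
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def _work():
    await asyncio.sleep(0.1)  # Stand-in for gathering async serialization futures.
    return "done"

# .result() blocks the calling thread until the coroutine completes on `loop`,
# even if the caller is itself already inside a running event loop.
print(asyncio.run_coroutine_threadsafe(_work(), loop).result())
loop.call_soon_threadsafe(loop.stop)
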
axlearn/common/checkpointer.py

Lines changed: 13 additions & 0 deletions
@@ -175,11 +175,14 @@ def async_save_tf_savables(
     When this call returns, `value_map` can be safely mutated, but saving to `dir` will not
     complete unless the returned future is set.
     """
+    logging.info("******* DEBUG Saving TF savables to %s async", dir)
     # pylint: disable-next=consider-using-with
     f = tempfile.TemporaryDirectory()
     for path, value in utils.flatten_items(value_map):
         tf_checkpoint = tf.train.Checkpoint(value)
+        logging.info("******* DEBUG Writing %s to path %s", f.name, path)
         tf_checkpoint.write(os.path.join(f.name, path))
+        logging.info("******* DEBUG Done writing %s to path %s", f.name, path)
     return executor.submit(_upload_dir, f, dst_dir=dir)
 
 
@@ -399,6 +402,7 @@ def __init__(self, cfg: Config):
         # TODO(markblee): Consider making BoundedDataShardedAsyncCheckpointManager
         # the default once stable.
         if cfg.max_concurrent_gb is not None or cfg.max_data_shard_degree:
+            logging.info("******* DEBUG Using BoundedDataShardedAsyncCheckpointManager")
             self._manager = BoundedDataShardedAsyncCheckpointManager(
                 max_concurrent_gb=cfg.max_concurrent_gb,
                 timeout_secs=cfg.timeout_secs,
@@ -411,6 +415,7 @@ def __init__(self, cfg: Config):
                 f"shard_threshold_bytes is set to {cfg.shard_threshold_bytes}, but "
                 "max_data_shard_degree is not set. It will not take any effect."
             )
+            logging.info("******* DEBUG Using GlobalAsyncCheckpointManager")
             self._manager = GlobalAsyncCheckpointManager(timeout_secs=cfg.timeout_secs)
         if cfg.max_concurrent_restore_gb is not None and cfg.max_concurrent_restore_gb <= 0:
             raise ValueError(
@@ -514,8 +519,12 @@ def save_to_dir(
         logging.info("Creating directories: %s", dirs)
         list(self._executor.map(fs.makedirs, dirs))
         logging.info("All directories created")
+
+        logging.info("******* DEBUG starting sync_global_devices")
         # Wait for directory and index creation.
         multihost_utils.sync_global_devices(ckpt_dir)
+        logging.info("******* DEBUG finished sync_global_devices")
+
         # Each worker writes its tf checkpoints under a different path.
         save_tf_future = async_save_tf_savables(
             spec.tf_ckpt_map,
@@ -527,6 +536,7 @@ def save_to_dir(
         )
 
         def commit():
+            logging.info("******* DEBUG starting on_commit_callback")
             on_commit_callback(ckpt_dir=ckpt_dir, index=spec.index)
             logging.info(
                 "Serialization of %s completed in %s seconds.",
@@ -538,6 +548,9 @@ def commit():
         logging.debug(
             "array_values=%s tensorstore=%s", utils.shapes(spec.gda_values), spec.tensorstore_specs
         )
+        logging.info(
+            "array_values=%s tensorstore=%s", utils.shapes(spec.gda_values), spec.tensorstore_specs
+        )
         self._manager.serialize(
             spec.gda_values,
             spec.tensorstore_specs,

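The new DEBUG logs bracket the local tf.train.Checkpoint.write into a temporary directory before the upload future is submitted. A minimal sketch of that write step (the variable names and the checkpointed value are placeholders; only the Checkpoint.write call mirrors the code above):

import os
import tempfile

import tensorflow as tf

tmp = tempfile.TemporaryDirectory()
ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))  # Placeholder savable.
prefix = os.path.join(tmp.name, "state")
ckpt.write(prefix)  # Local write; the upload to the checkpoint dir happens in the executor.
print(sorted(os.listdir(tmp.name)))  # e.g. ['state.data-00000-of-00001', 'state.index']
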
axlearn/common/compiler_options.py

Lines changed: 31 additions & 3 deletions
@@ -59,6 +59,12 @@ def default_xla_options(
         # further if you see "Allocator failed to allocate". A feature
         # to dynamically allocate may come later: b/380514965
         megascale_grpc_premap_memory_bytes=17179869184,
+        # DEBUGGING ONLY: RapidEye output directory for debugging purposes,
+        megascale_rapid_eye_error_digest_log_path="/output/rapideye/",
+        # megascale_jax_offset_launch_id_by_module_name="false",
+        # megascale_jax_use_device_set_based_launch_id="false",
+        # enable megascale debug port.
+        megascale_debug_port=8081,
         # Flag controlling the maximum number of overlapping host offloadings.
         xla_tpu_host_transfer_overlap_limit=24,
         # Flag controlling the maximum number of overlapping cross-DCN send/recv.
@@ -149,12 +155,20 @@ def default_xla_options(
         # Similar to megascale_error_reporter_abort_on_hang but for unrecoverable errors.
         megascale_error_reporter_abort_on_error="true",
         # Increase the timeout at which a hang is detected/reported, default is 5m.
-        megascale_graph_hang_threshold="10m",
+        megascale_graph_hang_threshold="60m",
         # Similar to megascale_graph_hang_threshold but specific to within a launch_id.
         # Default is 1m.
-        megascale_graph_within_launch_hang_threshold="10m",
+        megascale_graph_within_launch_hang_threshold="60m",
         # TODO(ethanli): temporary workaround to avoid memory leak in megascale.
         megascale_grpc_enable_xor_tracer="false",
+        # # The duration of missing heartbeats before shutting down.
+        # jax_heartbeat_timeout="100s",
+        # # JAX gRPC timeout duration.
+        # jax_rpc_timeout="120s",
+        # # JAX distributed initialization timeout.
+        # jax_distributed_initialization_timeout="3600s",
+        # # JAX shutdown timeout duration
+        # jax_distributed_shutdown_timeout="5m",
     )
 
     # Validate options. Will never fail if this function is implemented correctly.
@@ -163,7 +177,20 @@ def default_xla_options(
             int(v)
             continue
         except ValueError:
-            assert v in [True, False, "true", "false", "megachip_tccontrol", "10m"], (k, v)
+            assert v in [
+                True,
+                False,
+                "true",
+                "false",
+                "megachip_tccontrol",
+                "10m",
+                "60m",
+                "100s",
+                "120s",
+                "3600s",
+                "5m",
+                "/output/rapideye/",
+            ], (k, v)
 
     return options
 
@@ -302,6 +329,7 @@ def infer_xla_performance_flags(
     if current_configuration in mesh_configurations_for_sparse_core_offloading:
         flags = dict(
             # Must disable continuation fusion to enable sparse core offloading.
+            # AXLEARN TESTING NOTE: We are disabling this to test for SparseCore related issues.
            xla_tpu_enable_async_collective_fusion_fuse_all_gather="false",
            xla_tpu_enable_async_collective_fusion_fuse_all_reduce="false",
            xla_tpu_enable_async_collective_fusion_fuse_reduce_scatter="false",

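The commit adds a RapidEye log path and a megascale debug port, relaxes the hang thresholds to 60m, and widens the value whitelist in the validation loop to match. That validation logic, restated as a standalone loop over a few of the new values (the surrounding default_xla_options machinery is omitted):

options = dict(
    megascale_debug_port=8081,
    megascale_graph_hang_threshold="60m",
    megascale_rapid_eye_error_digest_log_path="/output/rapideye/",
    megascale_error_reporter_abort_on_error="true",
)
allowed = [True, False, "true", "false", "megachip_tccontrol",
           "10m", "60m", "100s", "120s", "3600s", "5m", "/output/rapideye/"]
for k, v in options.items():
    try:
        int(v)  # Numeric values (ports, byte counts, ...) pass unchecked.
        continue
    except ValueError:
        assert v in allowed, (k, v)  # Non-numeric values must be whitelisted.
print("all options valid")
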
axlearn/common/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -626,7 +626,7 @@ def run(
                 )
                 self.vlog(3, "Done step %s", self.step)
                 num_steps += 1
-                if num_steps % 100 == 0:
+                if num_steps % 10 == 0:
                     now = time.perf_counter()
                     average_step_time = (now - start_time) / num_steps
                     self._step_log("Average step time: %s seconds", average_step_time)
@@ -1099,7 +1099,7 @@ def _run_step(
         # Run the compiled function.
         self._trainer_state, outputs = compiled_train_step_fn(self.trainer_state, input_batch)
 
-        if self.step % 100 == 0 or 0 <= self.step <= 5:
+        if self.step % 10 == 0 or 0 <= self.step <= 5:
             self._step_log(
                 "loss=%s aux=%s",
                 outputs["loss"],

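Both logging cadences drop from every 100 steps to every 10, so average step time and loss are reported ten times more often. The cadence pattern in isolation (all values below are illustrative, not from the trainer):

import time

start_time = time.perf_counter()
num_steps = 0
for _ in range(30):
    time.sleep(0.001)  # Stand-in for one training step.
    num_steps += 1
    if num_steps % 10 == 0:  # Previously % 100.
        average_step_time = (time.perf_counter() - start_time) / num_steps
        print(f"Average step time: {average_step_time:.4f} seconds")
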
axlearn/common/utils_spmd.py

Lines changed: 10 additions & 0 deletions
@@ -88,6 +88,16 @@ def setup(
             coordinator_address=distributed_coordinator,
             num_processes=num_processes,
             process_id=process_id,
+            # The duration of missing heartbeats before shutting down.
+            heartbeat_timeout="120s",
+            # JAX distributed initialization timeout.
+            initialization_timeout="3600s",
+            # JAX distributed shutdown timeout.
+            shutdown_timeout="3600s",
+            # RPC timeout.
+            rpc_timeout="3600s",
+            # RPC timeout for heartbeat.
+            coordinator_rpc_timeout="3600s",
         )
         if jax_backend == "gpu":
             # jax 0.4.34 introduced a change to cluster auto-detection behavior, supplying

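setup() now threads much longer heartbeat, initialization, shutdown, and RPC timeouts into the distributed runtime. A reduced sketch of a distributed JAX initialization with a longer init timeout; only the kwargs shown below are standard jax.distributed.initialize parameters, and the additional timeout kwargs added by this commit are assumed to be accepted by the JAX version pinned in this image:

import jax

# Blocks until all processes join (or the timeout expires); values are placeholders.
jax.distributed.initialize(
    coordinator_address="10.0.0.1:1234",  # Placeholder coordinator address.
    num_processes=2,
    process_id=0,
    initialization_timeout=3600,  # Seconds; mirrors the "3600s" setting above.
)
print(jax.process_count())
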