-
Notifications
You must be signed in to change notification settings - Fork 69
Open
Labels
documentation — Improvements or additions to documentation; type:feature — New feature or request
Description
Currently, the documentation states that dataloader checkpointing is possible using orbax, but it is not documented how to actually do it. I essentially found the (I think?) correct way by experimentation and reading the grain/orbax source code.
TL;DR you need to use CheckpointManager, add handlers, and checkpoint the DataLoaderIterator (created using _create_initial_state()) before converting it to a python iterable.
Here is an example of how to use dataloader checkpointing (for future reference):
if __name__ == "__main__":
    # Example training script demonstrating grain dataloader checkpointing with
    # orbax: register grain's CheckpointSave/CheckpointRestore handlers on an
    # orbax CheckpointManager and checkpoint the grain DataLoaderIterator itself.
    jax.distributed.initialize()
    num_devices = jax.device_count()
    if num_devices == 0:
        raise ValueError("No JAX devices found.")
    print(f"Running on {num_devices} devices.")
    # The global batch is split across devices, so it must divide evenly.
    if args.batch_size % num_devices != 0:
        raise ValueError(
            f"Global batch size {args.batch_size} must be divisible by "
            f"number of devices {num_devices}."
        )
    per_device_batch_size_for_init = args.batch_size // num_devices
    rng = jax.random.PRNGKey(args.seed)
    # --- Initialize model ---
    tokenizer = TokenizerVQVAE(
        in_dim=args.image_channels,
        model_dim=args.model_dim,
        latent_dim=args.latent_dim,
        num_latents=args.num_latents,
        patch_size=args.patch_size,
        num_blocks=args.num_blocks,
        num_heads=args.num_heads,
        dropout=args.dropout,
        codebook_dropout=args.codebook_dropout,
    )
    rng, _rng = jax.random.split(rng)
    image_shape = (args.image_height, args.image_width, args.image_channels)
    # Dummy zero-filled batch (per-device sized) used only to initialize params.
    inputs = dict(
        videos=jnp.zeros(
            (per_device_batch_size_for_init, args.seq_len, *image_shape),
            dtype=jnp.float32,
        ),
    )
    init_params = tokenizer.init(_rng, inputs)
    param_counts = count_parameters_by_component(init_params)
    # Only process 0 initializes wandb, to avoid duplicate runs in
    # multi-process (multi-host) training.
    if args.log and jax.process_index() == 0:
        wandb.init(
            entity=args.entity,
            project=args.project,
            name=args.name,
            tags=args.tags,
            group="debug",
            config=args,
        )
        wandb.config.update({"model_param_count": param_counts})
        print("Parameter counts:")
        print(param_counts)
    # --- Initialize optimizer ---
    lr_schedule = optax.warmup_cosine_decay_schedule(
        args.min_lr, args.max_lr, args.warmup_steps, args.num_steps
    )
    tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
    train_state = TrainState.create(apply_fn=tokenizer.apply, params=init_params, tx=tx)
    # FIXME: switch to create_hybrid_device_mesh for runs spanning multiple nodes
    device_mesh_arr = create_device_mesh((num_devices,))
    mesh = Mesh(devices=device_mesh_arr, axis_names=("data",))
    # Model state is fully replicated; video batches are sharded along the
    # leading (batch) axis of their 5D (batch, seq, H, W, C) layout.
    replicated_sharding = NamedSharding(mesh, PartitionSpec())
    videos_sharding = NamedSharding(
        mesh, PartitionSpec("data", None, None, None, None)
    )
    train_state = jax.device_put(train_state, replicated_sharding)
    # --- Initialize checkpoint manager ---
    step = 0
    # Register a save handler AND a restore handler for each checkpointable
    # item: orbax's StandardCheckpointHandler for the model state, grain's
    # CheckpointHandler for the dataloader iterator state.
    handler_registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
    handler_registry.add('model_state', ocp.args.StandardSave, ocp.handlers.StandardCheckpointHandler)
    handler_registry.add('model_state', ocp.args.StandardRestore, ocp.handlers.StandardCheckpointHandler)
    handler_registry.add('dataloader_state', grain.checkpoint.CheckpointSave, grain.checkpoint.CheckpointHandler)  # type: ignore
    handler_registry.add('dataloader_state', grain.checkpoint.CheckpointRestore, grain.checkpoint.CheckpointHandler)  # type: ignore
    checkpoint_options = ocp.CheckpointManagerOptions(
        save_interval_steps=args.log_checkpoint_interval,
        max_to_keep=3,
        step_format_fixed_length=6,
        cleanup_tmp_directories=True,
    )
    checkpoint_manager = ocp.CheckpointManager(
        args.ckpt_dir,
        options=checkpoint_options,
        handler_registry=handler_registry,
    )
    # --- Create DataLoaderIterator from dataloader ---
    array_record_files = [
        os.path.join(args.data_dir, x)
        for x in os.listdir(args.data_dir)
        if x.endswith(".array_record")
    ]
    grain_dataloader = get_dataloader(
        array_record_files,
        args.seq_len,
        # NOTE: We deliberately pass the global batch size
        # The dataloader shards the dataset across all processes
        args.batch_size,
        *image_shape,
        num_workers=8,
        prefetch_buffer_size=1,
        seed=args.seed,
    )
    # NOTE(review): _create_initial_state() is a private grain API — confirm
    # against the grain version in use. The DataLoaderIterator must exist (and
    # be the object that is checkpointed/restored) before it is wrapped in the
    # plain generator expression below.
    initial_state = grain_dataloader._create_initial_state()
    grain_iterator = grain.DataLoaderIterator(grain_dataloader, initial_state)
    # --- Restore checkpoint ---
    if args.restore_ckpt:
        # StandardRestore needs an abstract (shape/dtype-only) pytree target.
        abstract_train_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, train_state)
        restored = checkpoint_manager.restore(
            checkpoint_manager.latest_step(),
            args=ocp.args.Composite(
                model_state=ocp.args.StandardRestore(abstract_train_state),
                dataloader_state=grain.checkpoint.CheckpointRestore(grain_iterator),
            )
        )
        train_state = restored["model_state"]
        grain_iterator = restored["dataloader_state"]
        # Resume the step counter from the checkpoint step (0 if none found).
        step = checkpoint_manager.latest_step() or 0
        print(f"Restored dataloader and model state from step {step}")
    # --- TRAIN LOOP ---
    # Wrap the grain iterator so each process-local batch becomes a global
    # jax.Array sharded over the "data" mesh axis.
    dataloader = (jax.make_array_from_process_local_data(videos_sharding, elem) for elem in grain_iterator)  # type: ignore
    print(f"Starting training from step {step}...")
    # NOTE(review): if the generator is exhausted before step reaches
    # num_steps, the outer while loop will spin without progress — verify the
    # grain dataloader repeats indefinitely.
    while step < args.num_steps:
        for videos in dataloader:
            # --- Train step ---
            rng, _rng, _rng_dropout = jax.random.split(rng, 3)
            inputs = dict(videos=videos, rng=_rng, dropout_rng=_rng_dropout)
            train_state, loss, recon, metrics = train_step(train_state, inputs)
            print(f"Step {step}, loss: {loss}")
            step += 1
            # --- Checkpointing ---
            # Save model and dataloader state together so a restore resumes
            # both from the same step.
            if args.save_ckpt and step % args.log_checkpoint_interval == 0:
                checkpoint_manager.save(
                    step,
                    args=ocp.args.Composite(
                        model_state=ocp.args.StandardSave(train_state),
                        dataloader_state=grain.checkpoint.CheckpointSave(grain_iterator),
                    )
                )
                print(f"Saved checkpoint at step {step}")
            if step >= args.num_steps:
                break
    checkpoint_manager.close()
Metadata
Metadata
Assignees
Labels
documentation — Improvements or additions to documentation; type:feature — New feature or request