Skip to content

Add dataloader checkpointing example to docs #938

@emergenz

Description

@emergenz

Currently, the documentation states that dataloader checkpointing is possible using Orbax, but it does not explain how to actually do it. I found what I believe is the correct approach through experimentation and by reading the Grain and Orbax source code.

TL;DR: use Orbax's `CheckpointManager`, register Grain's checkpoint handlers with it, and checkpoint the `DataLoaderIterator` itself (created using `_create_initial_state()`), not the plain Python generator you later wrap around it for the training loop.

Here is an example of how to use dataloader checkpointing (for future reference):

# Example: checkpointing a Grain DataLoader together with model state via Orbax.
# NOTE(review): this snippet relies on names defined elsewhere in the original
# training script (`args`, `TokenizerVQVAE`, `get_dataloader`,
# `count_parameters_by_component`, `TrainState`, `train_step`, `ocp`, `grain`,
# `optax`, `wandb`, ...).
if __name__ == "__main__":
    jax.distributed.initialize()
    num_devices = jax.device_count()
    if num_devices == 0:
        raise ValueError("No JAX devices found.")
    print(f"Running on {num_devices} devices.")

    if args.batch_size % num_devices != 0:
        raise ValueError(
            f"Global batch size {args.batch_size} must be divisible by "
            f"number of devices {num_devices}."
        )

    # Only used to build dummy inputs for parameter initialization.
    per_device_batch_size_for_init = args.batch_size // num_devices

    rng = jax.random.PRNGKey(args.seed)

    # --- Initialize model ---
    tokenizer = TokenizerVQVAE(
        in_dim=args.image_channels,
        model_dim=args.model_dim,
        latent_dim=args.latent_dim,
        num_latents=args.num_latents,
        patch_size=args.patch_size,
        num_blocks=args.num_blocks,
        num_heads=args.num_heads,
        dropout=args.dropout,
        codebook_dropout=args.codebook_dropout,
    )
    rng, _rng = jax.random.split(rng)
    image_shape = (args.image_height, args.image_width, args.image_channels)
    inputs = dict(
        videos=jnp.zeros(
            (per_device_batch_size_for_init, args.seq_len, *image_shape),
            dtype=jnp.float32,
        ),
    )
    init_params = tokenizer.init(_rng, inputs)

    param_counts = count_parameters_by_component(init_params)

    if args.log and jax.process_index() == 0:
        wandb.init(
            entity=args.entity,
            project=args.project,
            name=args.name,
            tags=args.tags,
            group="debug",
            config=args,
        )
        wandb.config.update({"model_param_count": param_counts})

    print("Parameter counts:")
    print(param_counts)

    # --- Initialize optimizer ---
    lr_schedule = optax.warmup_cosine_decay_schedule(
        args.min_lr, args.max_lr, args.warmup_steps, args.num_steps
    )
    tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
    train_state = TrainState.create(apply_fn=tokenizer.apply, params=init_params, tx=tx)

    # FIXME: switch to create_hybrid_device_mesh for runs spanning multiple nodes
    device_mesh_arr = create_device_mesh((num_devices,))
    mesh = Mesh(devices=device_mesh_arr, axis_names=("data",))

    # Model state is replicated on every device; video batches are split along
    # their leading (batch) axis over the "data" mesh axis.
    replicated_sharding = NamedSharding(mesh, PartitionSpec())
    videos_sharding = NamedSharding(
        mesh, PartitionSpec("data", None, None, None, None)
    )
    train_state = jax.device_put(train_state, replicated_sharding)

    # --- Initialize checkpoint manager ---
    # Each item of the Composite checkpoint needs both a save and a restore
    # handler registered under the same name used in ocp.args.Composite below.
    step = 0
    handler_registry = ocp.handlers.DefaultCheckpointHandlerRegistry()
    handler_registry.add('model_state', ocp.args.StandardSave, ocp.handlers.StandardCheckpointHandler)
    handler_registry.add('model_state', ocp.args.StandardRestore, ocp.handlers.StandardCheckpointHandler)
    handler_registry.add('dataloader_state', grain.checkpoint.CheckpointSave, grain.checkpoint.CheckpointHandler) # type: ignore
    handler_registry.add('dataloader_state', grain.checkpoint.CheckpointRestore, grain.checkpoint.CheckpointHandler) # type: ignore

    checkpoint_options = ocp.CheckpointManagerOptions(
        save_interval_steps=args.log_checkpoint_interval,
        max_to_keep=3,
        step_format_fixed_length=6,
        cleanup_tmp_directories=True,
    )

    checkpoint_manager = ocp.CheckpointManager(
        args.ckpt_dir,
        options=checkpoint_options,
        handler_registry=handler_registry,
    )

    # --- Create DataLoaderIterator from dataloader ---
    array_record_files = [
        os.path.join(args.data_dir, x)
        for x in os.listdir(args.data_dir)
        if x.endswith(".array_record")
    ]
    grain_dataloader = get_dataloader(
        array_record_files,
        args.seq_len,
        # NOTE: We deliberately pass the global batch size
        # The dataloader shards the dataset across all processes
        args.batch_size,
        *image_shape,
        num_workers=8,
        prefetch_buffer_size=1,
        seed=args.seed,
    )
    # `iter()` on a Grain DataLoader returns a checkpointable DataLoaderIterator,
    # so there is no need to call the private `_create_initial_state()`.
    # This iterator object (not a generator wrapping it) is what gets saved and
    # restored, because it is the object that carries the dataloader state.
    grain_iterator = iter(grain_dataloader)

    # --- Restore checkpoint ---
    if args.restore_ckpt:
        # Query once so the step we restore and the step we resume from agree.
        latest_step = checkpoint_manager.latest_step()
        if latest_step is None:
            raise FileNotFoundError(
                f"restore_ckpt was requested but no checkpoint was found in {args.ckpt_dir}."
            )
        abstract_train_state = jax.tree_util.tree_map(ocp.utils.to_shape_dtype_struct, train_state)
        restored = checkpoint_manager.restore(
            latest_step,
            args=ocp.args.Composite(
                model_state=ocp.args.StandardRestore(abstract_train_state),
                dataloader_state=grain.checkpoint.CheckpointRestore(grain_iterator),
            ),
        )
        train_state = restored["model_state"]
        grain_iterator = restored["dataloader_state"]
        step = latest_step
        print(f"Restored dataloader and model state from step {step}")

    # --- TRAIN LOOP ---
    # Wrap the (possibly restored) Grain iterator only now, after restoring,
    # so the generator pulls from the restored position. The generator itself
    # is never checkpointed; `grain_iterator` is.
    dataloader = (jax.make_array_from_process_local_data(videos_sharding, elem) for elem in grain_iterator) # type: ignore
    print(f"Starting training from step {step}...")
    while step < args.num_steps:
        for videos in dataloader:
            # --- Train step ---
            rng, _rng, _rng_dropout = jax.random.split(rng, 3)

            inputs = dict(videos=videos, rng=_rng, dropout_rng=_rng_dropout)
            train_state, loss, recon, metrics = train_step(train_state, inputs)
            print(f"Step {step}, loss: {loss}")
            step += 1

            # --- Checkpointing ---
            if args.save_ckpt and step % args.log_checkpoint_interval == 0:
                checkpoint_manager.save(
                    step,
                    args=ocp.args.Composite(
                        model_state=ocp.args.StandardSave(train_state),
                        dataloader_state=grain.checkpoint.CheckpointSave(grain_iterator),
                    ),
                )
                print(f"Saved checkpoint at step {step}")
            if step >= args.num_steps:
                break
        else:
            # `dataloader` is a one-shot generator. Once `grain_iterator` is
            # exhausted, re-entering the `for` yields nothing, and without this
            # exit the `while` would spin forever.
            print(
                f"Dataloader exhausted at step {step} before reaching "
                f"num_steps={args.num_steps}; stopping training."
            )
            break

    checkpoint_manager.close()

Metadata

Metadata

Assignees

No one assigned

    Labels

    documentation (Improvements or additions to documentation), type:feature (New feature or request)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions