
Batches loaded from wrong epoch when resuming from second epoch #40690

@ngazagna-qc

Description


System Info

Required system information

- `transformers` version: 4.57.0.dev0
- Platform: Linux-5.15.0-133-generic-x86_64-with-glibc2.35
- Python version: 3.10.12
- Huggingface_hub version: 0.34.4
- Safetensors version: 0.6.2
- Accelerate version: 1.10.1
- Accelerate config:    not found
- DeepSpeed version: not installed
- PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
- Tensorflow version (GPU?): 2.15.1 (False)
- Flax version (CPU?/GPU?/TPU?): 0.7.0 (cpu)
- Jax version: 0.4.13
- JaxLib version: 0.4.13
- Using distributed or parallel set-up in script?: no
- Using GPU in script?: no
- GPU type: GRID A100D-16C

Who can help?

@zach-huggingface @SunMarc as it concerns the `transformers` Trainer

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

1. Bug description

Consider the setup used in the provided reproduction script (a minimal sketch of an equivalent setup follows the list):

  • number of data points: 10
  • batch size: 2

So 1 epoch = 5 steps.
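A minimal sketch of this kind of setup, assuming a toy dataset wrapper (`IndexLoggingDataset` is a hypothetical helper, not part of the linked script) to record the order in which samples are served:

import torch
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments

class IndexLoggingDataset(Dataset):
    """Tiny 10-sample dataset that records every index it serves."""
    def __init__(self, n=10):
        self.n = n
        self.seen = []  # order in which samples are requested

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        self.seen.append(idx)
        return {"input_ids": torch.tensor([idx]), "labels": torch.tensor([idx])}

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=2,   # 10 samples / 2 = 5 steps per epoch
    num_train_epochs=3,
    save_strategy="steps",
    save_steps=6,                    # checkpoint mid-way through epoch 1
    seed=42,
)
# trainer = Trainer(model=..., args=args, train_dataset=IndexLoggingDataset())
# trainer.train()                                  # full run
# trainer.train(resume_from_checkpoint=True)       # resumed run for comparison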

If we run the training to completion and monitor the data order, we get:

  • epoch 0: 4, 1, 7, 5, 3, 9, 0, 8, 6, 2
  • epoch 1: 5, 6, || 1, 2, 0, 8, 9, 3, 7, 4
  • epoch 2: 8, 7, 1, 5, 6, 9, 0, 4, 2, 3

But if we stop the training at step 6 and resume it (from the point marked ||) until the end, we get the following data order:

  • epoch 0: 4, 1, 7, 5, 3, 9, 0, 8, 6, 2
  • epoch 1: 5, 6 || 7, 5, 3, 9, 0, 8, 6, 2
  • epoch 2: 8, 7, 1, 5, 6, 9, 0, 4, 2, 3

We found that `epoch_dataloader.iteration` is not properly set for the first epoch after resuming. It is still set to 0, which is why the remainder of the resumed epoch follows the same order as epoch 0 (compare the last 4 batches of the resumed epoch 1 with the last 4 batches of epoch 0 above).
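A conceptual sketch (not Accelerate's actual code) of why `iteration = 0` replays epoch 0: if the shuffle for each pass is derived from something like a base seed plus the iteration counter, then resuming with the counter still at 0 regenerates exactly the epoch-0 permutation, and `skip_first_batches` then only drops its first few batches.

import torch

def epoch_permutation(base_seed, iteration, n=10):
    g = torch.Generator()
    g.manual_seed(base_seed + iteration)   # per-epoch reseeding
    return torch.randperm(n, generator=g).tolist()

base_seed = 42
print(epoch_permutation(base_seed, iteration=0))  # epoch 0 order
print(epoch_permutation(base_seed, iteration=1))  # epoch 1 order
# After resuming mid-epoch 1 with iteration left at 0, the sampler reproduces
# the epoch-0 order; setting iteration = epochs_trained (here 1) restores the
# epoch-1 permutation before the already-trained batches are skipped.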

2. Reproducing the error

The script to run is available at https://github.com/ngazagna-qc/transformers/blob/fix-data-order-resumed-epoch/reproduce_wrong_resumed_epoch.py.
Run:

python reproduce_wrong_resumed_epoch.py --trainer-class Trainer

Expected behavior

3. Bug fix

We provide the fixed Trainer here: https://github.com/ngazagna-qc/transformers/blob/fix-data-order-resumed-epoch/src/transformers/trainer_fixed.py#L56

The fix consists of adding a single line to the `_inner_training_loop` method:

            if steps_trained_in_current_epoch > 0:
                epoch_dataloader = skip_first_batches(epoch_dataloader, steps_trained_in_current_epoch)
                #### BEGINNING OF THE FIX ####
                epoch_dataloader.iteration = epochs_trained  # FIX: set dataloader to correct epoch
                #### END OF THE FIX ####
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

You can verify that this restores the correct data order by running:

python reproduce_wrong_resumed_epoch.py --trainer-class TrainerFixed
