Skip to content

Commit ae7d46f

Browse files
authored
Fix resume bug (#300)
* Add safe globals Signed-off-by: choiyj <cyj21c6352@gmail.com> * Move numpy allowlist to DualBrainTrainer __init__ --------- Signed-off-by: choiyj <cyj21c6352@gmail.com>
1 parent 4ea96a1 commit ae7d46f

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

gr00t/experiment/trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import os
1818
from typing import Optional
1919

20+
import numpy as np
2021
import torch
2122
import transformers
2223
from torch.utils.data import Dataset, Sampler
@@ -64,6 +65,10 @@ class DualBrainTrainer(transformers.Trainer):
6465
def __init__(self, **kwargs):
6566
self.compute_dtype = kwargs.pop("compute_dtype")
6667
super().__init__(**kwargs)
68+
# Allowlist numpy globals for safe RNG state unpickling in PyTorch 2.1+
69+
torch.serialization.add_safe_globals(
70+
[np.core.multiarray._reconstruct, np.ndarray, np.dtype, np.dtypes.UInt32DType]
71+
)
6772

6873
def _get_train_sampler(self):
6974
return BaseSampler(self.train_dataset, shuffle=True, seed=self.args.seed)

0 commit comments

Comments
 (0)