We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 4ea96a1 commit ae7d46fCopy full SHA for ae7d46f
gr00t/experiment/trainer.py
@@ -17,6 +17,7 @@
17
import os
18
from typing import Optional
19
20
+import numpy as np
21
import torch
22
import transformers
23
from torch.utils.data import Dataset, Sampler
@@ -64,6 +65,10 @@ class DualBrainTrainer(transformers.Trainer):
64
65
def __init__(self, **kwargs):
66
self.compute_dtype = kwargs.pop("compute_dtype")
67
super().__init__(**kwargs)
68
+ # Allowlist numpy globals for safe RNG state unpickling in PyTorch 2.1+
69
+ torch.serialization.add_safe_globals(
70
+ [np.core.multiarray._reconstruct, np.ndarray, np.dtype, np.dtypes.UInt32DType]
71
+ )
72
73
def _get_train_sampler(self):
74
return BaseSampler(self.train_dataset, shuffle=True, seed=self.args.seed)
0 commit comments