@@ -264,3 +264,43 @@ def collate_fn(batch):
264264 labels = torch .Tensor (labels ).long ()
265265
266266 return features , labels
267+
268+
269+ class NoDuplicateClassesDataLoader :
270+
271+ def __init__ (self , train_examples , batch_size ):
272+ self .batch_size = batch_size
273+ self .collate_fn = None
274+ self .train_examples = train_examples
275+
276+ # TODO: add assert batch_size <= num_classes
277+
278+ def __iter__ (self ):
279+ label_class_dict = {}
280+ random .shuffle (self .train_examples )
281+ for example in self .train_examples :
282+ example_label_list = label_class_dict .get (example .label , [])
283+ example_label_list .append (example )
284+ label_class_dict [example .label ] = example_label_list
285+
286+ for _ in range (self .__len__ ()):
287+ batch = []
288+ classes_in_batch = set ()
289+
290+ while len (batch ) < self .batch_size :
291+ class_to_add = random .choice (label_class_dict .keys ())
292+ if class_to_add not in classes_in_batch :
293+ example = label_class_dict [class_to_add ].pop (0 )
294+ batch .append (example )
295+
296+ # list of examples for this class is empty and needs to be refilled
297+ if len (label_class_dict [class_to_add ]) == 0 :
298+ random .shuffle (self .train_examples )
299+ for example in self .train_examples :
300+ if example .label == class_to_add :
301+ label_class_dict [class_to_add ].append (example )
302+
303+ yield self .collate_fn (batch ) if self .collate_fn is not None else batch
304+
305+ def __len__ (self ):
306+ return math .floor (len (self .train_examples ) / self .batch_size )
0 commit comments