Skip to content

Commit cd1ee8a

Browse files
committed
add NoDuplicateClassesDataLoader
1 parent f8b9873 commit cd1ee8a

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

src/setfit/data.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,43 @@ def collate_fn(batch):
264264
labels = torch.Tensor(labels).long()
265265

266266
return features, labels
267+
268+
269+
class NoDuplicateClassesDataLoader:
    """Iterable that yields batches in which every example has a distinct label.

    Each element of ``train_examples`` must expose a hashable ``label``
    attribute. Iterating yields ``len(self)`` batches of ``batch_size``
    examples each; within one batch no two examples share a label. If
    ``collate_fn`` is set (it defaults to ``None``), each batch is passed
    through it before being yielded.
    """

    def __init__(self, train_examples, batch_size):
        """Store the examples and the batch size.

        Args:
            train_examples: Sized iterable of objects with a ``label`` attribute.
            batch_size: Number of examples (and therefore of distinct labels)
                per batch. It must not exceed the number of distinct labels;
                this is checked when iteration starts.
        """
        self.batch_size = batch_size
        # Optional callable applied to each batch; callers may assign it
        # after construction.
        self.collate_fn = None
        self.train_examples = train_examples

    def __iter__(self):
        """Yield ``len(self)`` batches with pairwise-distinct labels.

        Raises:
            ValueError: If at least one batch would be produced but
                ``batch_size`` is larger than the number of distinct labels,
                since such a batch cannot be filled without repeating a label.
        """
        num_batches = len(self)
        if num_batches <= 0:
            return

        # Group the examples by label without reordering the caller's list.
        examples_by_label = {}
        for example in self.train_examples:
            examples_by_label.setdefault(example.label, []).append(example)

        labels = list(examples_by_label)
        if self.batch_size > len(labels):
            raise ValueError(
                f"batch_size ({self.batch_size}) must not exceed the number "
                f"of distinct classes ({len(labels)})"
            )

        # One pool of not-yet-served examples per label, filled lazily.
        pools = {label: [] for label in labels}

        for _ in range(num_batches):
            batch = []
            # Sampling without replacement guarantees distinct labels and
            # always terminates, unlike retrying random.choice until unseen.
            for label in random.sample(labels, self.batch_size):
                pool = pools[label]
                if not pool:
                    # Pool exhausted: start a new shuffled pass over this
                    # label's examples only.
                    pool.extend(examples_by_label[label])
                    random.shuffle(pool)
                # The pool is shuffled, so taking from the end is as random
                # as taking from the front, and O(1).
                batch.append(pool.pop())

            yield self.collate_fn(batch) if self.collate_fn is not None else batch

    def __len__(self):
        """Return the number of full batches per pass over ``train_examples``."""
        return math.floor(len(self.train_examples) / self.batch_size)

0 commit comments

Comments
 (0)