Skip to content

Commit 8d2b048

Browse files
committed
feat: add balanced batch sampling
1 parent 4b954bb commit 8d2b048

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

neuralnetlib/utils.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,48 @@ def shuffle(x: np.ndarray, y: np.ndarray = None, random_state: int = None) -> tu
5050
return shuffled_x
5151

5252

53+
def balanced_batch_sampling(n_classes: int, real_samples: np.ndarray, labels: np.ndarray, batch_size: int, rng: np.random.Generator) -> tuple:
    """Draw a class-balanced batch by sampling the same number of rows from each class.

    Every class contributes ``batch_size // n_classes`` rows, drawn with
    replacement. The returned batch therefore holds
    ``(batch_size // n_classes) * n_classes`` rows, which is fewer than
    ``batch_size`` whenever ``batch_size`` is not a multiple of ``n_classes``.

    Args:
        n_classes (int): The number of classes; must be positive
        real_samples (np.ndarray): The real samples, indexed along their first axis
        labels (np.ndarray): The labels of the samples in one-hot encoding;
            a row belongs to class ``c`` when ``labels[row, c] == 1``
        batch_size (int): The requested number of samples (rounded down to a
            multiple of ``n_classes``)
        rng (np.random.Generator): The random number generator

    Raises:
        ValueError: If ``n_classes`` is not positive
        ValueError: If the batch size is less than the number of classes
        ValueError: If at least one class has no sample in ``labels``

    Returns:
        tuple: A tuple of (real_samples, labels) restricted to the selected
            rows, in shuffled order
    """
    # Guard first: a zero n_classes would otherwise fail with ZeroDivisionError
    # and a negative one with an opaque indexing error further down.
    # The message is in French to match this function's other error messages.
    if n_classes <= 0:
        raise ValueError(f"n_classes ({n_classes}) doit être strictement positif")

    samples_per_class = batch_size // n_classes
    # "< 1" also rejects a negative batch_size, which floor division turns
    # into a negative per-class count rather than 0.
    if samples_per_class < 1:
        raise ValueError(f"batch_size ({batch_size}) doit être au moins égal au nombre de classes ({n_classes})")

    # Row positions of the members of each class, read from the one-hot columns.
    class_indices = [np.nonzero(labels[:, class_idx] == 1)[0] for class_idx in range(n_classes)]

    empty_classes = [class_idx for class_idx, indices in enumerate(class_indices) if indices.size == 0]
    if empty_classes:
        raise ValueError(f"Les classes {empty_classes} n'ont aucun échantillon dans le dataset")

    # Sampling with replacement lets classes smaller than samples_per_class
    # still fill their share of the batch.
    selected_indices = np.concatenate([
        rng.choice(indices, size=samples_per_class, replace=True)
        for indices in class_indices
    ])
    # Shuffle so the batch is not grouped class by class.
    rng.shuffle(selected_indices)

    return real_samples[selected_indices], labels[selected_indices]
93+
94+
5395
def progress_bar(current: int, total: int, width: int = 30, message: str = "") -> None:
5496
"""
5597
Prints a progress bar to the console.

0 commit comments

Comments
 (0)