@@ -50,6 +50,48 @@ def shuffle(x: np.ndarray, y: np.ndarray = None, random_state: int = None) -> tu
5050 return shuffled_x
5151
5252
def balanced_batch_sampling(
    n_classes: int,
    real_samples: np.ndarray,
    labels: np.ndarray,
    batch_size: int,
    rng: np.random.Generator,
) -> tuple:
    """Draw a class-balanced batch by sampling the same number of items per class.

    Each class contributes ``batch_size // n_classes`` samples, drawn with
    replacement from the rows whose label column for that class equals 1.
    The returned batch therefore holds ``n_classes * (batch_size // n_classes)``
    samples, which is smaller than ``batch_size`` when ``batch_size`` is not a
    multiple of ``n_classes``. The selected rows are shuffled before being
    returned so that classes are interleaved.

    Args:
        n_classes (int): The number of classes (must be at least 1)
        real_samples (np.ndarray): The real samples, indexed along the first axis
        labels (np.ndarray): The labels of the samples in one-hot encoding
        batch_size (int): The target total number of samples to select
        rng (np.random.Generator): The random number generator

    Raises:
        ValueError: If ``n_classes`` is less than 1
        ValueError: If the batch size is less than the number of classes
        ValueError: If at least one class has no sample in ``labels``

    Returns:
        tuple: A tuple of (real_samples, labels) where each array has the selected samples
    """
    # Guard first: n_classes == 0 would otherwise surface as a ZeroDivisionError
    # and a negative value as an obscure indexing error further down.
    if n_classes < 1:
        raise ValueError(f"n_classes ({n_classes}) doit être un entier strictement positif")

    samples_per_class = batch_size // n_classes
    # "<= 0" also rejects negative batch sizes with the same explicit message.
    if samples_per_class <= 0:
        raise ValueError(f"batch_size ({batch_size} ) doit être au moins égal au nombre de classes ({n_classes} )")

    # Row indices belonging to each class, read from the one-hot label columns.
    class_indices = [np.nonzero(labels[:, class_idx] == 1)[0] for class_idx in range(n_classes)]

    empty_classes = [i for i, indices in enumerate(class_indices) if len(indices) == 0]
    if empty_classes:
        raise ValueError(f"Les classes {empty_classes} n'ont aucun échantillon dans le dataset")

    # Sampling with replacement lets classes with fewer rows than
    # samples_per_class still fill their quota.
    selected_indices = np.concatenate(
        [
            rng.choice(indices, size=samples_per_class, replace=True)
            for indices in class_indices
        ]
    )
    # Shuffle in place so the batch is not grouped class by class.
    rng.shuffle(selected_indices)

    return real_samples[selected_indices], labels[selected_indices]
93+
94+
5395def progress_bar (current : int , total : int , width : int = 30 , message : str = "" ) -> None :
5496 """
5597 Prints a progress bar to the console.
0 commit comments