@@ -330,14 +330,56 @@ def _sample(self, X, y):
330
330
331
331
prev_len = y_ .shape [0 ]
332
332
if self .return_indices :
333
- X_ , y_ , idx_ = self .enn_ .fit_sample (X_ , y_ )
334
- idx_under = idx_under [idx_ ]
333
+ X_enn , y_enn , idx_enn = self .enn_ .fit_sample (X_ , y_ )
335
334
else :
336
- X_ , y_ = self .enn_ .fit_sample (X_ , y_ )
337
-
338
- if prev_len == y_ .shape [0 ]:
335
+ X_enn , y_enn = self .enn_ .fit_sample (X_ , y_ )
336
+
337
+ # Check the stopping criterion
338
+ # 1. If there are no changes in the vector y
339
+ # 2. If the number of samples in any other class becomes smaller than
340
+ # the number of samples in the minority class
341
+ # 3. If one of the classes disappears
342
+
343
+ # Case 1
344
+ b_conv = (prev_len == y_enn .shape [0 ])
345
+
346
+ # Case 2
347
+ stats_enn = Counter (y_enn )
348
+ self .logger .debug ('Current ENN stats: %s' , stats_enn )
349
+ # Get the number of samples in the non-minority classes
350
+ count_non_min = np .array ([val for val , key
351
+ in zip (stats_enn .itervalues (),
352
+ stats_enn .iterkeys ())
353
+ if key != self .min_c_ ])
354
+ self .logger .debug ('Number of samples in the non-majority'
355
+ ' classes: %s' , count_non_min )
356
+ # Check whether the minority class stops being the minority
357
+ b_min_bec_maj = np .any (count_non_min < self .stats_c_ [self .min_c_ ])
358
+
359
+ # Case 3
360
+ b_remove_maj_class = (len (stats_enn ) < len (self .stats_c_ ))
361
+
362
+ if b_conv or b_min_bec_maj or b_remove_maj_class :
363
+ # If this is a normal convergence, get the last data
364
+ if b_conv :
365
+ if self .return_indices :
366
+ X_ , y_ , = X_enn , y_enn
367
+ idx_under = idx_under [idx_enn ]
368
+ else :
369
+ X_ , y_ , = X_enn , y_enn
370
+ # Log the variables to explain the stop of the algorithm
371
+ self .logger .debug ('RENN converged: %s' , b_conv )
372
+ self .logger .debug ('RENN minority become majority: %s' ,
373
+ b_min_bec_maj )
374
+ self .logger .debug ('RENN remove one class: %s' ,
375
+ b_remove_maj_class )
339
376
break
340
377
378
+ # Update the data for the next iteration
379
+ X_ , y_ , = X_enn , y_enn
380
+ if self .return_indices :
381
+ idx_under = idx_under [idx_enn ]
382
+
341
383
self .logger .info ('Under-sampling performed: %s' , Counter (y_ ))
342
384
343
385
X_resampled , y_resampled = X_ , y_
0 commit comments