Skip to content

Commit 7a1459e

Browse files
author
Guillaume Lemaitre
committed
Solving the issue of the stopping criterion of the RENN
Conflicts: imblearn/under_sampling/tests/test_repeated_edited_nearest_neighbours.py
1 parent 04f8d30 commit 7a1459e

File tree

2 files changed

+77
-5
lines changed

2 files changed

+77
-5
lines changed

imblearn/under_sampling/edited_nearest_neighbours.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,14 +330,56 @@ def _sample(self, X, y):
330330

331331
prev_len = y_.shape[0]
332332
if self.return_indices:
333-
X_, y_, idx_ = self.enn_.fit_sample(X_, y_)
334-
idx_under = idx_under[idx_]
333+
X_enn, y_enn, idx_enn = self.enn_.fit_sample(X_, y_)
335334
else:
336-
X_, y_ = self.enn_.fit_sample(X_, y_)
337-
338-
if prev_len == y_.shape[0]:
335+
X_enn, y_enn = self.enn_.fit_sample(X_, y_)
336+
337+
# Check the stopping criterion
338+
# 1. If there are no changes in the vector y
339+
# 2. If the number of samples in one of the other classes becomes lower than
340+
# the number of samples in the minority class
341+
# 3. If one of the classes disappears
342+
343+
# Case 1
344+
b_conv = (prev_len == y_enn.shape[0])
345+
346+
# Case 2
347+
stats_enn = Counter(y_enn)
348+
self.logger.debug('Current ENN stats: %s', stats_enn)
349+
# Get the number of samples in the non-minority classes
350+
count_non_min = np.array([val for val, key
351+
in zip(stats_enn.itervalues(),
352+
stats_enn.iterkeys())
353+
if key != self.min_c_])
354+
self.logger.debug('Number of samples in the non-majority'
355+
' classes: %s', count_non_min)
356+
# Check whether the minority class stops being the minority
357+
b_min_bec_maj = np.any(count_non_min < self.stats_c_[self.min_c_])
358+
359+
# Case 3
360+
b_remove_maj_class = (len(stats_enn) < len(self.stats_c_))
361+
362+
if b_conv or b_min_bec_maj or b_remove_maj_class:
363+
# If this is a normal convergence, get the last data
364+
if b_conv:
365+
if self.return_indices:
366+
X_, y_, = X_enn, y_enn
367+
idx_under = idx_under[idx_enn]
368+
else:
369+
X_, y_, = X_enn, y_enn
370+
# Log the variables to explain the stop of the algorithm
371+
self.logger.debug('RENN converged: %s', b_conv)
372+
self.logger.debug('RENN minority become majority: %s',
373+
b_min_bec_maj)
374+
self.logger.debug('RENN remove one class: %s',
375+
b_remove_maj_class)
339376
break
340377

378+
# Update the data for the next iteration
379+
X_, y_, = X_enn, y_enn
380+
if self.return_indices:
381+
idx_under = idx_under[idx_enn]
382+
341383
self.logger.info('Under-sampling performed: %s', Counter(y_))
342384

343385
X_resampled, y_resampled = X_, y_

imblearn/under_sampling/tests/test_repeated_edited_nearest_neighbours.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from numpy.testing import assert_array_equal
1010
from numpy.testing import assert_warns
1111

12+
from collections import Counter
13+
1214
from sklearn.datasets import make_classification
1315
from sklearn.utils.estimator_checks import check_estimator
1416

@@ -140,3 +142,31 @@ def test_renn_sample_wrong_X():
140142
renn.fit(X, Y)
141143
assert_raises(RuntimeError, renn.sample, np.random.random((100, 40)),
142144
np.array([0] * 50 + [1] * 50))
145+
146+
147+
def test_continuous_error():
    """Check that fitting on a continuous target emits a ``UserWarning``."""

    # Build a continuous-valued target vector
    continuous_target = np.linspace(0, 1, 5000)
    renn = RepeatedEditedNearestNeighbours(random_state=RND_SEED)
    assert_warns(UserWarning, renn.fit, X, continuous_target)
155+
156+
157+
def test_multiclass_fit_sample():
    """Check the class counts after ``fit_sample`` on a 3-class target."""

    # Turn the target into a three-class problem by relabelling the first
    # 1000 samples as class 2
    multiclass_target = Y.copy()
    multiclass_target[:1000] = 2

    # Run the under-sampling
    renn = RepeatedEditedNearestNeighbours(random_state=RND_SEED)
    _, target_resampled = renn.fit_sample(X, multiclass_target)

    # Compare the resampled class distribution with the expected one
    class_counts = Counter(target_resampled)
    for label, expected in ((0, 400), (1, 3600), (2, 1000)):
        assert_equal(class_counts[label], expected)

0 commit comments

Comments
 (0)