55# License: MIT
66
77import numbers
8+ import warnings
89from collections import Counter
910
1011import numpy as np
12+ from sklearn .base import clone
13+ from sklearn .neighbors import KNeighborsClassifier , NearestNeighbors
1114from sklearn .utils import _safe_indexing
1215
13- from ...utils import Substitution , check_neighbors_object
16+ from ...utils import Substitution
1417from ...utils ._docstring import _n_jobs_docstring
15- from ...utils ._param_validation import HasMethods , Interval , StrOptions
16- from ...utils .fixes import _mode
18+ from ...utils ._param_validation import HasMethods , Hidden , Interval , StrOptions
1719from ..base import BaseCleaningSampler
1820from ._edited_nearest_neighbours import EditedNearestNeighbours
1921
@@ -35,9 +37,14 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
3537 ----------
3638 {sampling_strategy}
3739
40+ edited_nearest_neighbours : estimator object, default=None
41+ The :class:`~imblearn.under_sampling.EditedNearestNeighbours` (ENN)
42+ object to clean the dataset. If `None`, a default ENN is created with
43+ `kind_sel="mode"` and `n_neighbors=n_neighbors`.
44+
3845 n_neighbors : int or estimator object, default=3
3946 If ``int``, size of the neighbourhood to consider to compute the
40- nearest neighbors. If object, an estimator that inherits from
47+ K-nearest neighbors. If object, an estimator that inherits from
4148 :class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
4249 find the nearest-neighbors. By default, it will be a 3-NN.
4350
@@ -52,6 +59,11 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
5259 The strategy `"all"` will be less conservative than `'mode'`. Thus,
5360 more samples will be removed when `kind_sel="all"` generally.
5461
62+ .. deprecated:: 0.12
63+ `kind_sel` is deprecated in 0.12 and will be removed in 0.14.
64+ Currently the parameter has no effect and corresponds always to the
65+ `"all"` strategy.
66+
5567 threshold_cleaning : float, default=0.5
5567 Threshold used to decide whether to consider a class or not during the cleaning
5769 after applying ENN. A class will be considered during cleaning when:
@@ -70,9 +82,16 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
7082 corresponds to the class labels from which to sample and the values
7183 are the number of samples to sample.
7284
85+ edited_nearest_neighbours_ : estimator object
86+ The edited nearest neighbour object used to make the first resampling.
87+
7388 nn_ : estimator object
7489 Validated K-nearest Neighbours object created from `n_neighbors` parameter.
7590
91+ classes_to_clean_ : list
92+ The classes considered for under-sampling by `nn_` in the second cleaning
93+ phase.
94+
7695 sample_indices_ : ndarray of shape (n_new_samples,)
7796 Indices of the samples selected.
7897
@@ -118,52 +137,75 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler):
118137 >>> ncr = NeighbourhoodCleaningRule()
119138 >>> X_res, y_res = ncr.fit_resample(X, y)
120139 >>> print('Resampled dataset shape %s' % Counter(y_res))
121- Resampled dataset shape Counter({{1: 877 , 0: 100}})
140+ Resampled dataset shape Counter({{1: 888 , 0: 100}})
122141 """
123142
124143 _parameter_constraints : dict = {
125144 ** BaseCleaningSampler ._parameter_constraints ,
145+ "edited_nearest_neighbours" : [
146+ HasMethods (["fit_resample" ]),
147+ None ,
148+ ],
126149 "n_neighbors" : [
127150 Interval (numbers .Integral , 1 , None , closed = "left" ),
128151 HasMethods (["kneighbors" , "kneighbors_graph" ]),
129152 ],
130- "kind_sel" : [StrOptions ({"all" , "mode" })],
131- "threshold_cleaning" : [Interval (numbers .Real , 0 , 1 , closed = "neither" )],
153+ "kind_sel" : [StrOptions ({"all" , "mode" }), Hidden ( StrOptions ({ "deprecated" })) ],
154+ "threshold_cleaning" : [Interval (numbers .Real , 0 , None , closed = "neither" )],
132155 "n_jobs" : [numbers .Integral , None ],
133156 }
134157
def __init__(
    self,
    *,
    sampling_strategy="auto",
    edited_nearest_neighbours=None,
    n_neighbors=3,
    kind_sel="deprecated",
    threshold_cleaning=0.5,
    n_jobs=None,
):
    # The sampling strategy is handled by the parent sampler.
    super().__init__(sampling_strategy=sampling_strategy)

    # Every hyper-parameter is stored exactly as received; the estimators
    # derived from them are only built in `_validate_estimator`.
    self.edited_nearest_neighbours = edited_nearest_neighbours
    self.n_neighbors = n_neighbors
    # "deprecated" is the sentinel default: any other value triggers the
    # deprecation warning when resampling.
    self.kind_sel = kind_sel
    self.threshold_cleaning = threshold_cleaning
    self.n_jobs = n_jobs
149174
def _validate_estimator(self):
    """Create the objects required by NCR.

    Sets two fitted-style attributes on the instance:

    - ``nn_``: the classifier used in the second cleaning phase, built
      from ``n_neighbors``.
    - ``edited_nearest_neighbours_``: the ENN sampler used in the first
      cleaning phase, either a default one or a clone of the user-supplied
      ``edited_nearest_neighbours``.
    """
    if isinstance(self.n_neighbors, numbers.Integral):
        self.nn_ = KNeighborsClassifier(
            n_neighbors=self.n_neighbors, n_jobs=self.n_jobs
        )
    elif isinstance(self.n_neighbors, NearestNeighbors):
        # Backward compatibility when passing a NearestNeighbors object.
        # One neighbour is subtracted -- presumably because such objects
        # were configured with an extra neighbour to account for the query
        # sample itself (TODO confirm against the previous release notes).
        #
        # Forward the neighbour-search settings that both estimators accept
        # so a custom metric or algorithm is not silently discarded when
        # the object is converted to a classifier.
        user_params = self.n_neighbors.get_params(deep=False)
        shared_names = ("algorithm", "leaf_size", "metric", "p", "metric_params")
        self.nn_ = KNeighborsClassifier(
            n_neighbors=self.n_neighbors.n_neighbors - 1,
            n_jobs=self.n_jobs,
            **{name: user_params[name] for name in shared_names},
        )
    else:
        # Any other estimator is used as provided, on an unfitted copy.
        self.nn_ = clone(self.n_neighbors)

    if self.edited_nearest_neighbours is None:
        # Default first-phase cleaner: ENN with the "mode" selection rule
        # and the same neighbourhood specification as this sampler.
        self.edited_nearest_neighbours_ = EditedNearestNeighbours(
            sampling_strategy=self.sampling_strategy,
            n_neighbors=self.n_neighbors,
            kind_sel="mode",
            n_jobs=self.n_jobs,
        )
    else:
        # Clone so the user's sampler is never modified by fitting.
        self.edited_nearest_neighbours_ = clone(self.edited_nearest_neighbours)
156198
157199 def _fit_resample (self , X , y ):
200+ if self .kind_sel != "deprecated" :
201+ warnings .warn (
202+ "`kind_sel` is deprecated in 0.12 and will be removed in 0.14. "
203+ "It already has not effect and corresponds to the `'all'` option." ,
204+ FutureWarning ,
205+ )
158206 self ._validate_estimator ()
159- enn = EditedNearestNeighbours (
160- sampling_strategy = self .sampling_strategy ,
161- n_neighbors = self .n_neighbors ,
162- kind_sel = "mode" ,
163- n_jobs = self .n_jobs ,
164- )
165- enn .fit_resample (X , y )
166- index_not_a1 = enn .sample_indices_
207+ self .edited_nearest_neighbours_ .fit_resample (X , y )
208+ index_not_a1 = self .edited_nearest_neighbours_ .sample_indices_
167209 index_a1 = np .ones (y .shape , dtype = bool )
168210 index_a1 [index_not_a1 ] = False
169211 index_a1 = np .flatnonzero (index_a1 )
@@ -172,30 +214,34 @@ def _fit_resample(self, X, y):
172214 target_stats = Counter (y )
173215 class_minority = min (target_stats , key = target_stats .get )
174216 # compute which classes to consider for cleaning for the A2 group
175- classes_under_sample = [
217+ self . classes_to_clean_ = [
176218 c
177219 for c , n_samples in target_stats .items ()
178220 if (
179221 c in self .sampling_strategy_ .keys ()
180- and (n_samples > X . shape [ 0 ] * self .threshold_cleaning )
222+ and (n_samples > target_stats [ class_minority ] * self .threshold_cleaning )
181223 )
182224 ]
183- self .nn_ .fit (X )
225+ self .nn_ .fit (X , y )
226+
184227 class_minority_indices = np .flatnonzero (y == class_minority )
185- X_class = _safe_indexing (X , class_minority_indices )
186- y_class = _safe_indexing (y , class_minority_indices )
187- nnhood_idx = self .nn_ .kneighbors (X_class , return_distance = False )[:, 1 :]
188- nnhood_label = y [nnhood_idx ]
189- if self .kind_sel == "mode" :
190- nnhood_label_majority , _ = _mode (nnhood_label , axis = 1 )
191- nnhood_bool = np .ravel (nnhood_label_majority ) == y_class
192- else : # self.kind_sel == "all":
193- nnhood_label_majority = nnhood_label == class_minority
194- nnhood_bool = np .all (nnhood_label , axis = 1 )
195- # compute a2 group
196- index_a2 = np .ravel (nnhood_idx [~ nnhood_bool ])
197- index_a2 = np .unique (
198- [index for index in index_a2 if y [index ] in classes_under_sample ]
228+ X_minority = _safe_indexing (X , class_minority_indices )
229+ y_minority = _safe_indexing (y , class_minority_indices )
230+
231+ y_pred_minority = self .nn_ .predict (X_minority )
232+ # add an additional sample since the query points contain the original dataset
233+ neighbors_to_minority_indices = self .nn_ .kneighbors (
234+ X_minority , n_neighbors = self .nn_ .n_neighbors + 1 , return_distance = False
235+ )[:, 1 :]
236+
237+ mask_misclassified_minority = y_pred_minority != y_minority
238+ index_a2 = np .ravel (neighbors_to_minority_indices [mask_misclassified_minority ])
239+ index_a2 = np .array (
240+ [
241+ index
242+ for index in np .unique (index_a2 )
243+ if y [index ] in self .classes_to_clean_
244+ ]
199245 )
200246
201247 union_a1_a2 = np .union1d (index_a1 , index_a2 ).astype (int )
0 commit comments