3636from imblearn .over_sampling .base import BaseOverSampler
3737from imblearn .under_sampling .base import BaseCleaningSampler , BaseUnderSampler
3838
39+
3940def sample_dataset_generator ():
4041 X , y = make_classification (
4142 n_samples = 1000 ,
@@ -45,10 +46,13 @@ def sample_dataset_generator():
4546 random_state = 0 ,
4647 )
4748 return X , y
49+
50+
4851@pytest .fixture (name = "sample_dataset_generator" )
4952def sample_dataset_generator_fixture ():
5053 return sample_dataset_generator ()
5154
55+
5256def _set_checking_parameters (estimator ):
5357 params = estimator .get_params ()
5458 name = estimator .__class__ .__name__
@@ -261,6 +265,7 @@ def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
261265 X_res , y_res = sampler .fit_resample (X , y )
262266 assert Counter (y_res )[1 ] == expected_stat
263267
268+
264269def check_samplers_sparse (name , sampler_orig ):
265270 sampler = clone (sampler_orig )
266271 # check that sparse matrices can be passed through the sampler leading to
@@ -320,7 +325,7 @@ def check_samplers_list(name, sampler_orig):
320325 assert_allclose (y_res , y_res_list )
321326
322327
323- def check_samplers_multiclass_ova (name , sampler_orig , sample_dataset_generator ):
328+ def check_samplers_multiclass_ova (name , sampler_orig ):
324329 sampler = clone (sampler_orig )
325330 # Check that multiclass target lead to the same results than OVA encoding
326331 X , y = sample_dataset_generator ()
@@ -332,15 +337,15 @@ def check_samplers_multiclass_ova(name, sampler_orig, sample_dataset_generator):
332337 assert_allclose (y_res , y_res_ova .argmax (axis = 1 ))
333338
334339
335- def check_samplers_2d_target (name , sampler_orig , sample_dataset_generator ):
340+ def check_samplers_2d_target (name , sampler_orig ):
336341 sampler = clone (sampler_orig )
337342 X , y = sample_dataset_generator ()
338343
339344 y = y .reshape (- 1 , 1 ) # Make the target 2d
340345 sampler .fit_resample (X , y )
341346
342347
343- def check_samplers_preserve_dtype (name , sampler_orig , sample_dataset_generator ):
348+ def check_samplers_preserve_dtype (name , sampler_orig ):
344349 sampler = clone (sampler_orig )
345350 X , y = sample_dataset_generator ()
346351 # Cast X and y to not default dtype
@@ -351,7 +356,7 @@ def check_samplers_preserve_dtype(name, sampler_orig, sample_dataset_generator):
351356 assert y .dtype == y_res .dtype , "y dtype is not preserved"
352357
353358
354- def check_samplers_sample_indices (name , sampler_orig , sample_dataset_generator ):
359+ def check_samplers_sample_indices (name , sampler_orig ):
355360 sampler = clone (sampler_orig )
356361 X , y = sample_dataset_generator ()
357362 sampler .fit_resample (X , y )
0 commit comments