5858
5959sklearn_version = parse_version (sklearn .__version__ )
6060
61+
6162def sample_dataset_generator ():
6263 X , y = make_classification (
6364 n_samples = 1000 ,
@@ -67,10 +68,13 @@ def sample_dataset_generator():
6768 random_state = 0 ,
6869 )
6970 return X , y
71+
72+
7073@pytest .fixture (name = "sample_dataset_generator" )
7174def sample_dataset_generator_fixture ():
7275 return sample_dataset_generator ()
7376
77+
7478def _set_checking_parameters (estimator ):
7579 params = estimator .get_params ()
7680 name = estimator .__class__ .__name__
@@ -289,6 +293,7 @@ def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
289293 X_res , y_res = sampler .fit_resample (X , y )
290294 assert Counter (y_res )[1 ] == expected_stat
291295
296+
292297def check_samplers_sparse (name , sampler_orig ):
293298 sampler = clone (sampler_orig )
294299 # check that sparse matrices can be passed through the sampler leading to
@@ -348,7 +353,7 @@ def check_samplers_list(name, sampler_orig):
348353 assert_allclose (y_res , y_res_list )
349354
350355
351- def check_samplers_multiclass_ova (name , sampler_orig , sample_dataset_generator ):
356+ def check_samplers_multiclass_ova (name , sampler_orig ):
352357 sampler = clone (sampler_orig )
353358 # Check that multiclass target lead to the same results than OVA encoding
354359 X , y = sample_dataset_generator ()
@@ -360,15 +365,15 @@ def check_samplers_multiclass_ova(name, sampler_orig, sample_dataset_generator):
360365 assert_allclose (y_res , y_res_ova .argmax (axis = 1 ))
361366
362367
363- def check_samplers_2d_target (name , sampler_orig , sample_dataset_generator ):
368+ def check_samplers_2d_target (name , sampler_orig ):
364369 sampler = clone (sampler_orig )
365370 X , y = sample_dataset_generator ()
366371
367372 y = y .reshape (- 1 , 1 ) # Make the target 2d
368373 sampler .fit_resample (X , y )
369374
370375
371- def check_samplers_preserve_dtype (name , sampler_orig , sample_dataset_generator ):
376+ def check_samplers_preserve_dtype (name , sampler_orig ):
372377 sampler = clone (sampler_orig )
373378 X , y = sample_dataset_generator ()
374379 # Cast X and y to not default dtype
@@ -379,7 +384,7 @@ def check_samplers_preserve_dtype(name, sampler_orig, sample_dataset_generator):
379384 assert y .dtype == y_res .dtype , "y dtype is not preserved"
380385
381386
382- def check_samplers_sample_indices (name , sampler_orig , sample_dataset_generator ):
387+ def check_samplers_sample_indices (name , sampler_orig ):
383388 sampler = clone (sampler_orig )
384389 X , y = sample_dataset_generator ()
385390 sampler .fit_resample (X , y )
0 commit comments