Skip to content

Commit c83dd0c

Browse files
committed
Update imblearn/utils/estimator_checks.py
1 parent 83669e2 commit c83dd0c

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

imblearn/utils/estimator_checks.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858

5959
sklearn_version = parse_version(sklearn.__version__)
6060

61+
6162
def sample_dataset_generator():
6263
X, y = make_classification(
6364
n_samples=1000,
@@ -67,10 +68,13 @@ def sample_dataset_generator():
6768
random_state=0,
6869
)
6970
return X, y
71+
72+
7073
@pytest.fixture(name="sample_dataset_generator")
7174
def sample_dataset_generator_fixture():
7275
return sample_dataset_generator()
7376

77+
7478
def _set_checking_parameters(estimator):
7579
params = estimator.get_params()
7680
name = estimator.__class__.__name__
@@ -289,6 +293,7 @@ def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
289293
X_res, y_res = sampler.fit_resample(X, y)
290294
assert Counter(y_res)[1] == expected_stat
291295

296+
292297
def check_samplers_sparse(name, sampler_orig):
293298
sampler = clone(sampler_orig)
294299
# check that sparse matrices can be passed through the sampler leading to
@@ -348,7 +353,7 @@ def check_samplers_list(name, sampler_orig):
348353
assert_allclose(y_res, y_res_list)
349354

350355

351-
def check_samplers_multiclass_ova(name, sampler_orig, sample_dataset_generator):
356+
def check_samplers_multiclass_ova(name, sampler_orig):
352357
sampler = clone(sampler_orig)
353358
# Check that multiclass target lead to the same results than OVA encoding
354359
X, y = sample_dataset_generator()
@@ -360,15 +365,15 @@ def check_samplers_multiclass_ova(name, sampler_orig, sample_dataset_generator):
360365
assert_allclose(y_res, y_res_ova.argmax(axis=1))
361366

362367

363-
def check_samplers_2d_target(name, sampler_orig, sample_dataset_generator):
368+
def check_samplers_2d_target(name, sampler_orig):
364369
sampler = clone(sampler_orig)
365370
X, y = sample_dataset_generator()
366371

367372
y = y.reshape(-1, 1) # Make the target 2d
368373
sampler.fit_resample(X, y)
369374

370375

371-
def check_samplers_preserve_dtype(name, sampler_orig, sample_dataset_generator):
376+
def check_samplers_preserve_dtype(name, sampler_orig):
372377
sampler = clone(sampler_orig)
373378
X, y = sample_dataset_generator()
374379
# Cast X and y to not default dtype
@@ -379,7 +384,7 @@ def check_samplers_preserve_dtype(name, sampler_orig, sample_dataset_generator):
379384
assert y.dtype == y_res.dtype, "y dtype is not preserved"
380385

381386

382-
def check_samplers_sample_indices(name, sampler_orig, sample_dataset_generator):
387+
def check_samplers_sample_indices(name, sampler_orig):
383388
sampler = clone(sampler_orig)
384389
X, y = sample_dataset_generator()
385390
sampler.fit_resample(X, y)

0 commit comments

Comments
 (0)