diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py
index e8f6f7fb4..48ba0077d 100644
--- a/imblearn/utils/estimator_checks.py
+++ b/imblearn/utils/estimator_checks.py
@@ -59,6 +59,27 @@
 sklearn_version = parse_version(sklearn.__version__)
 
 
+def sample_dataset_generator(n_samples=1000):
+    """Return the imbalanced 3-class ``(X, y)`` dataset used by sampler checks.
+
+    ``n_samples`` defaults to 1000; smaller values keep cheap checks fast.
+    The fixed ``random_state`` makes every call return the same data.
+    """
+    return make_classification(
+        n_samples=n_samples,
+        n_classes=3,
+        n_informative=4,
+        weights=[0.2, 0.3, 0.5],
+        random_state=0,
+    )
+
+
+@pytest.fixture(name="sample_dataset_generator")
+def sample_dataset_generator_fixture():
+    """Pytest fixture exposing the default sample dataset."""
+    return sample_dataset_generator()
+
+
 def _set_checking_parameters(estimator):
     params = estimator.get_params()
     name = estimator.__class__.__name__
@@ -233,13 +254,7 @@ def check_samplers_fit(name, sampler_orig):
 
 def check_samplers_fit_resample(name, sampler_orig):
     sampler = clone(sampler_orig)
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     target_stats = Counter(y)
     X_res, y_res = sampler.fit_resample(X, y)
     if isinstance(sampler, BaseOverSampler):
@@ -269,13 +284,7 @@ def check_samplers_fit_resample(name, sampler_orig):
 def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
     sampler = clone(sampler_orig)
     # in this test we will force all samplers to not change the class 1
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     expected_stat = Counter(y)[1]
     if isinstance(sampler, BaseOverSampler):
         sampling_strategy = {2: 498, 0: 498}
@@ -298,13 +307,7 @@ def check_samplers_sparse(name, sampler_orig):
     sampler = clone(sampler_orig)
     # check that sparse matrices can be passed through the sampler leading to
     # the same results than dense
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     X_sparse = sparse.csr_matrix(X)
     X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
     sampler = clone(sampler)
@@ -318,13 +321,7 @@ def check_samplers_pandas(name, sampler_orig):
     pd = pytest.importorskip("pandas")
     sampler = clone(sampler_orig)
     # Check that the samplers handle pandas dataframe and pandas series
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     X_df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
     y_df = pd.DataFrame(y)
     y_s = pd.Series(y, name="class")
@@ -351,13 +348,7 @@ def check_samplers_pandas(name, sampler_orig):
 def check_samplers_list(name, sampler_orig):
     sampler = clone(sampler_orig)
     # Check that the can samplers handle simple lists
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     X_list = X.tolist()
     y_list = y.tolist()
 
@@ -374,13 +365,7 @@ def check_samplers_list(name, sampler_orig):
 def check_samplers_multiclass_ova(name, sampler_orig):
     sampler = clone(sampler_orig)
     # Check that multiclass target lead to the same results than OVA encoding
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     y_ova = label_binarize(y, classes=np.unique(y))
     X_res, y_res = sampler.fit_resample(X, y)
     X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
@@ -391,13 +376,7 @@ def check_samplers_multiclass_ova(name, sampler_orig):
 
 def check_samplers_2d_target(name, sampler_orig):
     sampler = clone(sampler_orig)
-    X, y = make_classification(
-        n_samples=100,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator(n_samples=100)
 
     y = y.reshape(-1, 1)  # Make the target 2d
     sampler.fit_resample(X, y)
@@ -405,13 +384,7 @@ def check_samplers_2d_target(name, sampler_orig):
 
 def check_samplers_preserve_dtype(name, sampler_orig):
     sampler = clone(sampler_orig)
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     # Cast X and y to not default dtype
     X = X.astype(np.float32)
     y = y.astype(np.int32)
@@ -422,13 +395,7 @@ def check_samplers_preserve_dtype(name, sampler_orig):
 
 def check_samplers_sample_indices(name, sampler_orig):
     sampler = clone(sampler_orig)
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     sampler.fit_resample(X, y)
     sample_indices = sampler._get_tags().get("sample_indices", None)
     if sample_indices: