imblearn/utils/estimator_checks.py: 88 changes (25 additions & 63 deletions)
@@ -59,6 +59,22 @@
 sklearn_version = parse_version(sklearn.__version__)
 
 
+def sample_dataset_generator():
+    X, y = make_classification(
+        n_samples=1000,
+        n_classes=3,
+        n_informative=4,
+        weights=[0.2, 0.3, 0.5],
+        random_state=0,
+    )
+    return X, y
+
+
+@pytest.fixture(name="sample_dataset_generator")
+def sample_dataset_generator_fixture():
+    return sample_dataset_generator()
+
+
 def _set_checking_parameters(estimator):
     params = estimator.get_params()
     name = estimator.__class__.__name__
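
The helper is deliberately exposed twice: the check_* functions below call the plain module-level function, while the pytest fixture registered under the same name (via name="sample_dataset_generator") hands the ready-made (X, y) tuple to tests that request it as an argument. A minimal usage sketch, with hypothetical names and assuming make_classification's default of 20 features:

# Illustration only, not part of the diff.
def check_example(sampler):
    # plain checks call the module-level helper directly
    X, y = sample_dataset_generator()
    sampler.fit_resample(X, y)


def test_example(sample_dataset_generator):
    # pytest injects the fixture's return value, i.e. the (X, y) tuple
    X, y = sample_dataset_generator
    assert X.shape == (1000, 20)
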
@@ -233,13 +249,7 @@ def check_samplers_fit(name, sampler_orig):
 
 def check_samplers_fit_resample(name, sampler_orig):
     sampler = clone(sampler_orig)
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     target_stats = Counter(y)
     X_res, y_res = sampler.fit_resample(X, y)
     if isinstance(sampler, BaseOverSampler):
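
To see what the refactored check exercises, here is a hedged sketch of calling it directly with a concrete sampler (RandomOverSampler is just an example) and of the class balance the shared dataset provides:

# Illustration only, not part of the diff.
from collections import Counter

from imblearn.over_sampling import RandomOverSampler

X, y = sample_dataset_generator()
print(Counter(y))  # roughly 200 / 300 / 500 samples for classes 0 / 1 / 2
check_samplers_fit_resample("RandomOverSampler", RandomOverSampler(random_state=0))
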
@@ -269,13 +279,7 @@ def check_samplers_fit_resample(name, sampler_orig):
 def check_samplers_sampling_strategy_fit_resample(name, sampler_orig):
     sampler = clone(sampler_orig)
     # in this test we will force all samplers to not change the class 1
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     expected_stat = Counter(y)[1]
     if isinstance(sampler, BaseOverSampler):
         sampling_strategy = {2: 498, 0: 498}
@@ -298,13 +302,7 @@ def check_samplers_sparse(name, sampler_orig):
     sampler = clone(sampler_orig)
     # check that sparse matrices can be passed through the sampler leading to
     # the same results than dense
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     X_sparse = sparse.csr_matrix(X)
     X_res_sparse, y_res_sparse = sampler.fit_resample(X_sparse, y)
     sampler = clone(sampler)
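
The sparse check relies on a sampler returning the same resampled data for dense and CSR input; a standalone sketch of that round trip, assuming RandomUnderSampler and the helper defined above:

# Illustration only, not part of the diff.
import numpy as np
from scipy import sparse

from imblearn.under_sampling import RandomUnderSampler

X, y = sample_dataset_generator()
rus = RandomUnderSampler(random_state=0)
X_res, y_res = rus.fit_resample(X, y)
X_res_sp, y_res_sp = rus.fit_resample(sparse.csr_matrix(X), y)
np.testing.assert_allclose(X_res_sp.toarray(), X_res)  # identical samples either way
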
@@ -318,13 +316,7 @@ def check_samplers_pandas(name, sampler_orig):
     pd = pytest.importorskip("pandas")
     sampler = clone(sampler_orig)
     # Check that the samplers handle pandas dataframe and pandas series
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     X_df = pd.DataFrame(X, columns=[str(i) for i in range(X.shape[1])])
     y_df = pd.DataFrame(y)
     y_s = pd.Series(y, name="class")
@@ -351,13 +343,7 @@ def check_samplers_pandas(name, sampler_orig):
 def check_samplers_list(name, sampler_orig):
     sampler = clone(sampler_orig)
     # Check that the can samplers handle simple lists
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     X_list = X.tolist()
     y_list = y.tolist()
 
@@ -374,13 +360,7 @@ def check_samplers_list(name, sampler_orig):
 def check_samplers_multiclass_ova(name, sampler_orig):
     sampler = clone(sampler_orig)
     # Check that multiclass target lead to the same results than OVA encoding
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     y_ova = label_binarize(y, classes=np.unique(y))
     X_res, y_res = sampler.fit_resample(X, y)
     X_res_ova, y_res_ova = sampler.fit_resample(X, y_ova)
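
The OVA check compares resampling the integer target with resampling its one-vs-all encoding; label_binarize simply produces one indicator column per class, for example:

# Illustration only, not part of the diff.
import numpy as np
from sklearn.preprocessing import label_binarize

y_small = np.array([0, 2, 1, 2])
print(label_binarize(y_small, classes=[0, 1, 2]))
# [[1 0 0]
#  [0 0 1]
#  [0 1 0]
#  [0 0 1]]
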
@@ -391,27 +371,15 @@ def check_samplers_multiclass_ova(name, sampler_orig):
 
 def check_samplers_2d_target(name, sampler_orig):
     sampler = clone(sampler_orig)
-    X, y = make_classification(
-        n_samples=100,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
 
     y = y.reshape(-1, 1)  # Make the target 2d
     sampler.fit_resample(X, y)
 
 
 def check_samplers_preserve_dtype(name, sampler_orig):
     sampler = clone(sampler_orig)
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     # Cast X and y to not default dtype
     X = X.astype(np.float32)
     y = y.astype(np.int32)
@@ -422,13 +390,7 @@ def check_samplers_preserve_dtype(name, sampler_orig):
 
 def check_samplers_sample_indices(name, sampler_orig):
     sampler = clone(sampler_orig)
-    X, y = make_classification(
-        n_samples=1000,
-        n_classes=3,
-        n_informative=4,
-        weights=[0.2, 0.3, 0.5],
-        random_state=0,
-    )
+    X, y = sample_dataset_generator()
     sampler.fit_resample(X, y)
     sample_indices = sampler._get_tags().get("sample_indices", None)
     if sample_indices:
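
In practice these check_* functions are driven by the estimator-check entry point rather than called one by one; a sketch, assuming the scikit-learn-style parametrize_with_checks exported by imblearn.utils.estimator_checks in current releases:

# Illustration only, not part of the diff; parametrize_with_checks is assumed here.
from imblearn.over_sampling import RandomOverSampler
from imblearn.utils.estimator_checks import parametrize_with_checks


@parametrize_with_checks([RandomOverSampler(random_state=0)])
def test_sampler_compliance(estimator, check):
    check(estimator)
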