diff --git a/doc/whats_new/v0.11.rst b/doc/whats_new/v0.11.rst index 5c303d0bf..a6cb21821 100644 --- a/doc/whats_new/v0.11.rst +++ b/doc/whats_new/v0.11.rst @@ -13,6 +13,10 @@ Bug fixes `bool` and `pd.category` by delegating the conversion to scikit-learn encoder. :pr:`1002` by :user:`Guillaume Lemaitre `. +- Handle sparse matrices in :class:`~imblearn.over_sampling.SMOTEN` and raise a warning + since it requires a conversion to dense matrices. + :pr:`1003` by :user:`Guillaume Lemaitre `. + Compatibility ............. diff --git a/imblearn/over_sampling/_smote/base.py b/imblearn/over_sampling/_smote/base.py index e4ea9d2d8..da880f87c 100644 --- a/imblearn/over_sampling/_smote/base.py +++ b/imblearn/over_sampling/_smote/base.py @@ -14,6 +14,7 @@ import numpy as np from scipy import sparse from sklearn.base import clone +from sklearn.exceptions import DataConversionWarning from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder from sklearn.utils import _safe_indexing, check_array, check_random_state from sklearn.utils.sparsefuncs_fast import ( @@ -893,7 +894,7 @@ def _check_X_y(self, X, y): y, reset=True, dtype=None, - accept_sparse=False, + accept_sparse=["csr", "csc"], ) return X, y, binarize_y @@ -927,6 +928,17 @@ def _fit_resample(self, X, y): FutureWarning, ) + if sparse.issparse(X): + X_sparse_format = X.format + X = X.toarray() + warnings.warn( + "Passing a sparse matrix to SMOTEN is not really efficient since it is" + " converted to a dense array internally.", + DataConversionWarning, + ) + else: + X_sparse_format = None + self._validate_estimator() X_resampled = [X.copy()] @@ -964,7 +976,12 @@ def _fit_resample(self, X, y): X_resampled = np.vstack(X_resampled) y_resampled = np.hstack(y_resampled) - return X_resampled, y_resampled + if X_sparse_format == "csr": + return sparse.csr_matrix(X_resampled), y_resampled + elif X_sparse_format == "csc": + return sparse.csc_matrix(X_resampled), y_resampled + else: + return X_resampled, y_resampled def _more_tags(self): return {"X_types": ["2darray", "dataframe", "string"]} diff --git a/imblearn/over_sampling/_smote/tests/test_smoten.py b/imblearn/over_sampling/_smote/tests/test_smoten.py index 6bd9d8356..db9c14f99 100644 --- a/imblearn/over_sampling/_smote/tests/test_smoten.py +++ b/imblearn/over_sampling/_smote/tests/test_smoten.py @@ -1,6 +1,8 @@ import numpy as np import pytest -from sklearn.preprocessing import OrdinalEncoder +from sklearn.exceptions import DataConversionWarning +from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder +from sklearn.utils._testing import _convert_container from imblearn.over_sampling import SMOTEN @@ -56,6 +58,24 @@ def test_smoten_resampling(): np.testing.assert_array_equal(y_generated, "not apple") +@pytest.mark.parametrize("sparse_format", ["sparse_csr", "sparse_csc"]) +def test_smoten_sparse_input(data, sparse_format): + """Check that we handle sparse input in SMOTEN even if it is not efficient. + + Non-regression test for: + https://github.com/scikit-learn-contrib/imbalanced-learn/issues/971 + """ + X, y = data + X = OneHotEncoder().fit_transform(X) + X = _convert_container(X, sparse_format) + + with pytest.warns(DataConversionWarning, match="is not really efficient"): + X_res, y_res = SMOTEN(random_state=0).fit_resample(X, y) + + assert X_res.format == X.format + assert X_res.shape[0] == len(y_res) + + def test_smoten_categorical_encoder(data): """Check that `categorical_encoder` is used when provided."""