diff --git a/README.rst b/README.rst index 900d8dc67..b238744cc 100644 --- a/README.rst +++ b/README.rst @@ -30,7 +30,7 @@ .. |PythonMinVersion| replace:: 3.8 .. |NumPyMinVersion| replace:: 1.17.3 .. |SciPyMinVersion| replace:: 1.3.2 -.. |ScikitLearnMinVersion| replace:: 1.1.3 +.. |ScikitLearnMinVersion| replace:: 1.0.2 .. |MatplotlibMinVersion| replace:: 3.1.2 .. |PandasMinVersion| replace:: 1.0.5 .. |TensorflowMinVersion| replace:: 2.4.3 diff --git a/azure-pipelines.yml b/azure-pipelines.yml index 2c5ec9af2..5a7c4aa99 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -158,7 +158,7 @@ jobs: # Linux environment to test the latest available dependencies and MKL. pylatest_pip_openblas_pandas: DISTRIB: 'conda-pip-latest' - PYTHON_VERSION: '3.9' + PYTHON_VERSION: '*' TEST_DOCS: 'true' TEST_DOCSTRINGS: 'true' CHECK_WARNINGS: 'true' @@ -185,7 +185,7 @@ jobs: TENSORFLOW_VERSION: 'min' TEST_DOCS: 'true' TEST_DOCSTRINGS: 'false' # it is going to fail because of scikit-learn inheritance - CHECK_WARNINGS: 'true' + CHECK_WARNINGS: 'false' # in case the older version raises some FutureWarnings pylatest_pip_keras: DISTRIB: 'conda-pip-latest-keras' CONDA_CHANNEL: 'conda-forge' @@ -209,7 +209,7 @@ jobs: KERAS_VERSION: 'min' TEST_DOCS: 'true' TEST_DOCSTRINGS: 'false' # it is going to fail because of scikit-learn inheritance - CHECK_WARNINGS: 'true' + CHECK_WARNINGS: 'false' # in case the older version raises some FutureWarnings # Currently runs on Python 3.8 while only Python 3.7 available # - template: build_tools/azure/posix-docker.yml diff --git a/build_tools/azure/posix-docker.yml b/build_tools/azure/posix-docker.yml index b9bc59b99..cd85e4644 100644 --- a/build_tools/azure/posix-docker.yml +++ b/build_tools/azure/posix-docker.yml @@ -30,6 +30,7 @@ jobs: THREADPOOLCTL_VERSION: 'latest' COVERAGE: 'false' TEST_DOCSTRINGS: 'false' + CHECK_WARNINGS: 'false' BLAS: 'openblas' # Set in azure-pipelines.yml DISTRIB: '' diff --git a/build_tools/azure/posix.yml b/build_tools/azure/posix.yml index b5c92f520..59367a29e 100644 --- a/build_tools/azure/posix.yml +++ b/build_tools/azure/posix.yml @@ -36,6 +36,7 @@ jobs: COVERAGE: 'true' TEST_DOCS: 'false' TEST_DOCSTRINGS: 'false' + CHECK_WARNINGS: 'false' SHOW_SHORT_SUMMARY: 'false' strategy: matrix: diff --git a/build_tools/azure/test_script.sh b/build_tools/azure/test_script.sh index 57136b41c..446b08b38 100755 --- a/build_tools/azure/test_script.sh +++ b/build_tools/azure/test_script.sh @@ -34,7 +34,7 @@ if [[ "$COVERAGE" == "true" ]]; then TEST_CMD="$TEST_CMD --cov-config='$COVERAGE_PROCESS_START' --cov imblearn --cov-report=" fi -if [[ -n "$CHECK_WARNINGS" ]]; then +if [[ "$CHECK_WARNINGS" == "true" ]]; then # numpy's 1.19.0's tostring() deprecation is ignored until scipy and joblib removes its usage TEST_CMD="$TEST_CMD -Werror::DeprecationWarning -Werror::FutureWarning -Wignore:tostring:DeprecationWarning" diff --git a/build_tools/azure/windows.yml b/build_tools/azure/windows.yml index 5b986a6ef..240e4c24d 100644 --- a/build_tools/azure/windows.yml +++ b/build_tools/azure/windows.yml @@ -21,6 +21,7 @@ jobs: PYTEST_XDIST_VERSION: 'latest' TEST_DIR: '$(Agent.WorkFolder)/tmp_folder' CPU_COUNT: '2' + CHECK_WARNINGS: 'false' strategy: matrix: ${{ insert }}: ${{ parameters.matrix }} diff --git a/doc/ensemble.rst b/doc/ensemble.rst index 21d6a6e0c..e556e4693 100644 --- a/doc/ensemble.rst +++ b/doc/ensemble.rst @@ -38,7 +38,7 @@ data set, this classifier will favor the majority classes:: >>> bc.fit(X_train, y_train) #doctest: BaggingClassifier(...) 
>>> y_pred = bc.predict(X_test) - >>> balanced_accuracy_score(y_test, y_pred) # doctest: + >>> balanced_accuracy_score(y_test, y_pred) 0.77... In :class:`BalancedBaggingClassifier`, each bootstrap sample will be further @@ -54,10 +54,10 @@ sampling is controlled by the parameter `sampler` or the two parameters ... sampling_strategy='auto', ... replacement=False, ... random_state=0) - >>> bbc.fit(X_train, y_train) # doctest: + >>> bbc.fit(X_train, y_train) BalancedBaggingClassifier(...) >>> y_pred = bbc.predict(X_test) - >>> balanced_accuracy_score(y_test, y_pred) # doctest: + >>> balanced_accuracy_score(y_test, y_pred) 0.8... Changing the `sampler` will give rise to different known implementation @@ -78,10 +78,10 @@ each tree of the forest will be provided a balanced bootstrap sample >>> from imblearn.ensemble import BalancedRandomForestClassifier >>> brf = BalancedRandomForestClassifier(n_estimators=100, random_state=0) - >>> brf.fit(X_train, y_train) # doctest: + >>> brf.fit(X_train, y_train) BalancedRandomForestClassifier(...) >>> y_pred = brf.predict(X_test) - >>> balanced_accuracy_score(y_test, y_pred) # doctest: + >>> balanced_accuracy_score(y_test, y_pred) 0.8... .. _boosting: @@ -97,10 +97,10 @@ a boosting iteration :cite:`seiffert2009rusboost`:: >>> from imblearn.ensemble import RUSBoostClassifier >>> rusboost = RUSBoostClassifier(n_estimators=200, algorithm='SAMME.R', ... random_state=0) - >>> rusboost.fit(X_train, y_train) # doctest: + >>> rusboost.fit(X_train, y_train) RUSBoostClassifier(...) >>> y_pred = rusboost.predict(X_test) - >>> balanced_accuracy_score(y_test, y_pred) # doctest: + >>> balanced_accuracy_score(y_test, y_pred) 0... A specific method which uses :class:`~sklearn.ensemble.AdaBoostClassifier` as @@ -111,10 +111,10 @@ the :class:`BalancedBaggingClassifier` API, one can construct the ensemble as:: >>> from imblearn.ensemble import EasyEnsembleClassifier >>> eec = EasyEnsembleClassifier(random_state=0) - >>> eec.fit(X_train, y_train) # doctest: + >>> eec.fit(X_train, y_train) EasyEnsembleClassifier(...) >>> y_pred = eec.predict(X_test) - >>> balanced_accuracy_score(y_test, y_pred) # doctest: + >>> balanced_accuracy_score(y_test, y_pred) 0.6... .. topic:: Examples diff --git a/imblearn/_min_dependencies.py b/imblearn/_min_dependencies.py index 72976f2b1..aaa5ce9ae 100644 --- a/imblearn/_min_dependencies.py +++ b/imblearn/_min_dependencies.py @@ -4,7 +4,7 @@ NUMPY_MIN_VERSION = "1.17.3" SCIPY_MIN_VERSION = "1.3.2" PANDAS_MIN_VERSION = "1.0.5" -SKLEARN_MIN_VERSION = "1.1.3" +SKLEARN_MIN_VERSION = "1.0.2" TENSORFLOW_MIN_VERSION = "2.4.3" KERAS_MIN_VERSION = "2.4.3" JOBLIB_MIN_VERSION = "1.1.1" diff --git a/imblearn/combine/_smote_enn.py b/imblearn/combine/_smote_enn.py index b13b47a32..20c650776 100644 --- a/imblearn/combine/_smote_enn.py +++ b/imblearn/combine/_smote_enn.py @@ -89,7 +89,7 @@ class SMOTEENN(BaseSampler): >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.combine import SMOTEENN # doctest: + >>> from imblearn.combine import SMOTEENN >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... 
n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/combine/_smote_tomek.py b/imblearn/combine/_smote_tomek.py index 80d9675a8..a6996d81d 100644 --- a/imblearn/combine/_smote_tomek.py +++ b/imblearn/combine/_smote_tomek.py @@ -87,8 +87,7 @@ class SMOTETomek(BaseSampler): >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.combine import \ -SMOTETomek # doctest: + >>> from imblearn.combine import SMOTETomek >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/ensemble/_bagging.py b/imblearn/ensemble/_bagging.py index 56b19c06f..5a80156cd 100644 --- a/imblearn/ensemble/_bagging.py +++ b/imblearn/ensemble/_bagging.py @@ -9,16 +9,23 @@ import warnings import numpy as np +from joblib import Parallel from sklearn.base import clone from sklearn.ensemble import BaggingClassifier +from sklearn.ensemble._bagging import _parallel_decision_function +from sklearn.ensemble._base import _partition_estimators from sklearn.tree import DecisionTreeClassifier +from sklearn.utils.fixes import delayed +from sklearn.utils.validation import check_is_fitted from ..pipeline import Pipeline from ..under_sampling import RandomUnderSampler from ..under_sampling.base import BaseUnderSampler from ..utils import Substitution, check_sampling_strategy, check_target_type +from ..utils._available_if import available_if from ..utils._docstring import _n_jobs_docstring, _random_state_docstring from ..utils._validation import _deprecate_positional_args +from ._common import _estimator_has @Substitution( @@ -229,8 +236,7 @@ class BalancedBaggingClassifier(BaggingClassifier): >>> from sklearn.datasets import make_classification >>> from sklearn.model_selection import train_test_split >>> from sklearn.metrics import confusion_matrix - >>> from imblearn.ensemble import \ -BalancedBaggingClassifier # doctest: + >>> from imblearn.ensemble import BalancedBaggingClassifier >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) @@ -239,7 +245,7 @@ class BalancedBaggingClassifier(BaggingClassifier): >>> X_train, X_test, y_train, y_test = train_test_split(X, y, ... random_state=0) >>> bbc = BalancedBaggingClassifier(random_state=42) - >>> bbc.fit(X_train, y_train) # doctest: + >>> bbc.fit(X_train, y_train) BalancedBaggingClassifier(...) >>> y_pred = bbc.predict(X_test) >>> print(confusion_matrix(y_test, y_pred)) @@ -408,6 +414,53 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None): # None. return super()._fit(X, y, self.max_samples, sample_weight=None) + # TODO: remove when minimum supported version of scikit-learn is 1.1 + @available_if(_estimator_has("decision_function")) + def decision_function(self, X): + """Average of the decision functions of the base classifiers. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Sparse matrices are accepted only if + they are supported by the base estimator. + + Returns + ------- + score : ndarray of shape (n_samples, k) + The decision function of the input samples. The columns correspond + to the classes in sorted order, as they appear in the attribute + ``classes_``. 
Regression and binary classification are special + cases with ``k == 1``, otherwise ``k==n_classes``. + """ + check_is_fitted(self) + + # Check data + X = self._validate_data( + X, + accept_sparse=["csr", "csc"], + dtype=None, + force_all_finite=False, + reset=False, + ) + + # Parallel loop + n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs) + + all_decisions = Parallel(n_jobs=n_jobs, verbose=self.verbose)( + delayed(_parallel_decision_function)( + self.estimators_[starts[i] : starts[i + 1]], + self.estimators_features_[starts[i] : starts[i + 1]], + X, + ) + for i in range(n_jobs) + ) + + # Reduce + decisions = sum(all_decisions) / self.n_estimators + + return decisions + def _more_tags(self): tags = super()._more_tags() tags_key = "_xfail_checks" diff --git a/imblearn/ensemble/_common.py b/imblearn/ensemble/_common.py new file mode 100644 index 000000000..eb24e737d --- /dev/null +++ b/imblearn/ensemble/_common.py @@ -0,0 +1,15 @@ +def _estimator_has(attr): + """Check if we can delegate a method to the underlying estimator. + First, we check the first fitted estimator if available, otherwise we + check the estimator attribute. + """ + + def check(self): + if hasattr(self, "estimators_"): + return hasattr(self.estimators_[0], attr) + elif self.estimator is not None: + return hasattr(self.estimator, attr) + else: # TODO(1.4): Remove when the base_estimator deprecation cycle ends + return hasattr(self.base_estimator, attr) + + return check diff --git a/imblearn/ensemble/_easy_ensemble.py b/imblearn/ensemble/_easy_ensemble.py index 9303af8d4..b67f39fb2 100644 --- a/imblearn/ensemble/_easy_ensemble.py +++ b/imblearn/ensemble/_easy_ensemble.py @@ -9,15 +9,22 @@ import warnings import numpy as np +from joblib import Parallel from sklearn.base import clone from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier +from sklearn.ensemble._bagging import _parallel_decision_function +from sklearn.ensemble._base import _partition_estimators +from sklearn.utils.fixes import delayed +from sklearn.utils.validation import check_is_fitted from ..pipeline import Pipeline from ..under_sampling import RandomUnderSampler from ..under_sampling.base import BaseUnderSampler from ..utils import Substitution, check_sampling_strategy, check_target_type +from ..utils._available_if import available_if from ..utils._docstring import _n_jobs_docstring, _random_state_docstring from ..utils._validation import _deprecate_positional_args +from ._common import _estimator_has MAX_INT = np.iinfo(np.int32).max @@ -31,7 +38,7 @@ class EasyEnsembleClassifier(BaggingClassifier): """Bag of balanced boosted learners also known as EasyEnsemble. This algorithm is known as EasyEnsemble [1]_. The classifier is an - ensemble of AdaBoost learners trained on different balanced boostrap + ensemble of AdaBoost learners trained on different balanced bootstrap samples. The balancing is achieved by random under-sampling. Read more in the :ref:`User Guide `. @@ -154,8 +161,7 @@ class EasyEnsembleClassifier(BaggingClassifier): >>> from sklearn.datasets import make_classification >>> from sklearn.model_selection import train_test_split >>> from sklearn.metrics import confusion_matrix - >>> from imblearn.ensemble import \ -EasyEnsembleClassifier # doctest: + >>> from imblearn.ensemble import EasyEnsembleClassifier >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... 
n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) @@ -164,7 +170,7 @@ class EasyEnsembleClassifier(BaggingClassifier): >>> X_train, X_test, y_train, y_test = train_test_split(X, y, ... random_state=0) >>> eec = EasyEnsembleClassifier(random_state=42) - >>> eec.fit(X_train, y_train) # doctest: + >>> eec.fit(X_train, y_train) EasyEnsembleClassifier(...) >>> y_pred = eec.predict(X_test) >>> print(confusion_matrix(y_test, y_pred)) @@ -314,3 +320,50 @@ def _fit(self, X, y, max_samples=None, max_depth=None, sample_weight=None): # RandomUnderSampler is not supporting sample_weight. We need to pass # None. return super()._fit(X, y, self.max_samples, sample_weight=None) + + # TODO: remove when minimum supported version of scikit-learn is 1.1 + @available_if(_estimator_has("decision_function")) + def decision_function(self, X): + """Average of the decision functions of the base classifiers. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + The training input samples. Sparse matrices are accepted only if + they are supported by the base estimator. + + Returns + ------- + score : ndarray of shape (n_samples, k) + The decision function of the input samples. The columns correspond + to the classes in sorted order, as they appear in the attribute + ``classes_``. Regression and binary classification are special + cases with ``k == 1``, otherwise ``k==n_classes``. + """ + check_is_fitted(self) + + # Check data + X = self._validate_data( + X, + accept_sparse=["csr", "csc"], + dtype=None, + force_all_finite=False, + reset=False, + ) + + # Parallel loop + n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs) + + all_decisions = Parallel(n_jobs=n_jobs, verbose=self.verbose)( + delayed(_parallel_decision_function)( + self.estimators_[starts[i] : starts[i + 1]], + self.estimators_features_[starts[i] : starts[i + 1]], + X, + ) + for i in range(n_jobs) + ) + + # Reduce + decisions = sum(all_decisions) / self.n_estimators + + return decisions diff --git a/imblearn/ensemble/_forest.py b/imblearn/ensemble/_forest.py index e753ef450..ff0ad8813 100644 --- a/imblearn/ensemble/_forest.py +++ b/imblearn/ensemble/_forest.py @@ -8,6 +8,7 @@ from warnings import warn import numpy as np +import sklearn from joblib import Parallel from numpy import float32 as DTYPE from numpy import float64 as DOUBLE @@ -22,7 +23,7 @@ ) from sklearn.exceptions import DataConversionWarning from sklearn.tree import DecisionTreeClassifier -from sklearn.utils import _safe_indexing, check_random_state +from sklearn.utils import _safe_indexing, check_random_state, parse_version from sklearn.utils.fixes import delayed from sklearn.utils.multiclass import type_of_target from sklearn.utils.validation import _check_sample_weight @@ -35,6 +36,7 @@ from ..utils._validation import _deprecate_positional_args, check_sampling_strategy MAX_INT = np.iinfo(np.int32).max +sklearn_version = parse_version(sklearn.__version__) def _local_parallel_build_trees( @@ -49,6 +51,7 @@ def _local_parallel_build_trees( verbose=0, class_weight=None, n_samples_bootstrap=None, + forest=None, ): # resample before to fit the tree X_resampled, y_resampled = sampler.fit_resample(X, y) @@ -56,18 +59,34 @@ def _local_parallel_build_trees( sample_weight = _safe_indexing(sample_weight, sampler.sample_indices_) if _get_n_samples_bootstrap is not None: n_samples_bootstrap = min(n_samples_bootstrap, X_resampled.shape[0]) - tree = _parallel_build_trees( - tree, - bootstrap, - X_resampled, - y_resampled, 
- sample_weight, - tree_idx, - n_trees, - verbose=verbose, - class_weight=class_weight, - n_samples_bootstrap=n_samples_bootstrap, - ) + + if sklearn_version >= parse_version("1.1"): + tree = _parallel_build_trees( + tree, + bootstrap, + X_resampled, + y_resampled, + sample_weight, + tree_idx, + n_trees, + verbose=verbose, + class_weight=class_weight, + n_samples_bootstrap=n_samples_bootstrap, + ) + else: + # TODO: remove when the minimum version of scikit-learn supported is 1.1 + tree = _parallel_build_trees( + tree, + forest, + X_resampled, + y_resampled, + sample_weight, + tree_idx, + n_trees, + verbose=verbose, + class_weight=class_weight, + n_samples_bootstrap=n_samples_bootstrap, + ) return sampler, tree @@ -324,9 +343,9 @@ class labels (multi-output problem). ... n_informative=4, weights=[0.2, 0.3, 0.5], ... random_state=0) >>> clf = BalancedRandomForestClassifier(max_depth=2, random_state=0) - >>> clf.fit(X, y) # doctest: + >>> clf.fit(X, y) BalancedRandomForestClassifier(...) - >>> print(clf.feature_importances_) # doctest: + >>> print(clf.feature_importances_) [...] >>> print(clf.predict([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])) @@ -582,6 +601,7 @@ def fit(self, X, y, sample_weight=None): verbose=self.verbose, class_weight=self.class_weight, n_samples_bootstrap=n_samples_bootstrap, + forest=self, ) for i, (s, t) in enumerate(zip(samplers, trees)) ) diff --git a/imblearn/ensemble/_weight_boosting.py b/imblearn/ensemble/_weight_boosting.py index db817e150..de6b4a004 100644 --- a/imblearn/ensemble/_weight_boosting.py +++ b/imblearn/ensemble/_weight_boosting.py @@ -152,9 +152,9 @@ class RUSBoostClassifier(AdaBoostClassifier): ... n_informative=4, weights=[0.2, 0.3, 0.5], ... random_state=0) >>> clf = RUSBoostClassifier(random_state=0) - >>> clf.fit(X, y) # doctest: + >>> clf.fit(X, y) RUSBoostClassifier(...) - >>> clf.predict(X) # doctest: + >>> clf.predict(X) array([...]) """ diff --git a/imblearn/ensemble/tests/test_bagging.py b/imblearn/ensemble/tests/test_bagging.py index c2b580d7f..9c23b59ff 100644 --- a/imblearn/ensemble/tests/test_bagging.py +++ b/imblearn/ensemble/tests/test_bagging.py @@ -67,9 +67,15 @@ def test_balanced_bagging_classifier(estimator, params): ) X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0) - BalancedBaggingClassifier(estimator=estimator, random_state=0, **params).fit( + bag = BalancedBaggingClassifier(estimator=estimator, random_state=0, **params).fit( X_train, y_train - ).predict(X_test) + ) + bag.predict(X_test) + bag.predict_proba(X_test) + bag.predict_log_proba(X_test) + bag.score(X_test, y_test) + if hasattr(estimator, "decision_function"): + bag.decision_function(X_test) def test_bootstrap_samples(): diff --git a/imblearn/over_sampling/_adasyn.py b/imblearn/over_sampling/_adasyn.py index caaa4b91d..fc52f8460 100644 --- a/imblearn/over_sampling/_adasyn.py +++ b/imblearn/over_sampling/_adasyn.py @@ -102,8 +102,7 @@ class ADASYN(BaseOverSampler): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.over_sampling import \ -ADASYN # doctest: + >>> from imblearn.over_sampling import ADASYN >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... 
n_features=20, n_clusters_per_class=1, n_samples=1000, diff --git a/imblearn/over_sampling/_random_over_sampler.py b/imblearn/over_sampling/_random_over_sampler.py index bd7f4026c..937d4a0fe 100644 --- a/imblearn/over_sampling/_random_over_sampler.py +++ b/imblearn/over_sampling/_random_over_sampler.py @@ -118,8 +118,7 @@ class RandomOverSampler(BaseOverSampler): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.over_sampling import \ -RandomOverSampler # doctest: + >>> from imblearn.over_sampling import RandomOverSampler >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/over_sampling/_smote/base.py b/imblearn/over_sampling/_smote/base.py index 9aad1ec4a..2aaadcadb 100644 --- a/imblearn/over_sampling/_smote/base.py +++ b/imblearn/over_sampling/_smote/base.py @@ -289,8 +289,7 @@ class SMOTE(BaseSMOTE): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.over_sampling import \ -SMOTE # doctest: + >>> from imblearn.over_sampling import SMOTE >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/over_sampling/_smote/filter.py b/imblearn/over_sampling/_smote/filter.py index ca968224d..92201f6f1 100644 --- a/imblearn/over_sampling/_smote/filter.py +++ b/imblearn/over_sampling/_smote/filter.py @@ -133,8 +133,7 @@ class BorderlineSMOTE(BaseSMOTE): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.over_sampling import \ -BorderlineSMOTE # doctest: + >>> from imblearn.over_sampling import BorderlineSMOTE >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) @@ -387,8 +386,7 @@ class SVMSMOTE(BaseSMOTE): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.over_sampling import \ -SVMSMOTE # doctest: + >>> from imblearn.over_sampling import SVMSMOTE >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/pipeline.py b/imblearn/pipeline.py index 320f086d1..2b687aa69 100644 --- a/imblearn/pipeline.py +++ b/imblearn/pipeline.py @@ -90,7 +90,7 @@ class Pipeline(pipeline.Pipeline): >>> from sklearn.neighbors import KNeighborsClassifier as KNN >>> from sklearn.metrics import classification_report >>> from imblearn.over_sampling import SMOTE - >>> from imblearn.pipeline import Pipeline # doctest: + >>> from imblearn.pipeline import Pipeline >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... 
n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) @@ -101,7 +101,7 @@ class Pipeline(pipeline.Pipeline): >>> knn = KNN() >>> pipeline = Pipeline([('smt', smt), ('pca', pca), ('knn', knn)]) >>> X_train, X_test, y_train, y_test = tts(X, y, random_state=42) - >>> pipeline.fit(X_train, y_train) # doctest: + >>> pipeline.fit(X_train, y_train) Pipeline(...) >>> y_hat = pipeline.predict(X_test) >>> print(classification_report(y_test, y_hat)) @@ -437,7 +437,6 @@ def make_pipeline(*steps, memory=None, verbose=False): >>> from sklearn.naive_bayes import GaussianNB >>> from sklearn.preprocessing import StandardScaler >>> make_pipeline(StandardScaler(), GaussianNB(priors=None)) - ... # doctest: Pipeline(steps=[('standardscaler', StandardScaler()), ('gaussiannb', GaussianNB())]) """ diff --git a/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py b/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py index 0378e43fb..b0f26df48 100644 --- a/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py +++ b/imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py @@ -91,18 +91,18 @@ class CondensedNearestNeighbour(BaseCleaningSampler): Examples -------- - >>> from collections import Counter # doctest: +SKIP - >>> from sklearn.datasets import fetch_mldata # doctest: +SKIP + >>> from collections import Counter # doctest: +SKIP + >>> from sklearn.datasets import fetch_mldata # doctest: +SKIP >>> from imblearn.under_sampling import \ -CondensedNearestNeighbour # doctest: +SKIP - >>> pima = fetch_mldata('diabetes_scale') # doctest: +SKIP - >>> X, y = pima['data'], pima['target'] # doctest: +SKIP - >>> print('Original dataset shape %s' % Counter(y)) # doctest: +SKIP - Original dataset shape Counter({{1: 500, -1: 268}}) # doctest: +SKIP - >>> cnn = CondensedNearestNeighbour(random_state=42) # doctest: +SKIP - >>> X_res, y_res = cnn.fit_resample(X, y) #doctest: +SKIP - >>> print('Resampled dataset shape %s' % Counter(y_res)) # doctest: +SKIP - Resampled dataset shape Counter({{-1: 268, 1: 227}}) # doctest: +SKIP +CondensedNearestNeighbour # doctest: +SKIP + >>> pima = fetch_mldata('diabetes_scale') # doctest: +SKIP + >>> X, y = pima['data'], pima['target'] # doctest: +SKIP + >>> print('Original dataset shape %s' % Counter(y)) # doctest: +SKIP + Original dataset shape Counter({{1: 500, -1: 268}}) # doctest: +SKIP + >>> cnn = CondensedNearestNeighbour(random_state=42) # doctest: +SKIP + >>> X_res, y_res = cnn.fit_resample(X, y) #doctest: +SKIP + >>> print('Resampled dataset shape %s' % Counter(y_res)) # doctest: +SKIP + Resampled dataset shape Counter({{-1: 268, 1: 227}}) # doctest: +SKIP """ @_deprecate_positional_args diff --git a/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py b/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py index 2390db065..3d096f684 100644 --- a/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py +++ b/imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py @@ -101,8 +101,7 @@ class EditedNearestNeighbours(BaseCleaningSampler): >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.under_sampling import \ -EditedNearestNeighbours # doctest: + >>> from imblearn.under_sampling import EditedNearestNeighbours >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... 
n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) @@ -469,8 +468,7 @@ class without early stopping. -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.under_sampling import \ -AllKNN # doctest: + >>> from imblearn.under_sampling import AllKNN >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py index 6835753b4..21cbbdf92 100644 --- a/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py +++ b/imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py @@ -97,7 +97,7 @@ class InstanceHardnessThreshold(BaseUnderSampler): Original dataset shape Counter({{1: 900, 0: 100}}) >>> iht = InstanceHardnessThreshold(random_state=42) >>> X_res, y_res = iht.fit_resample(X, y) - >>> print('Resampled dataset shape %s' % Counter(y_res)) # doctest: + >>> print('Resampled dataset shape %s' % Counter(y_res)) Resampled dataset shape Counter({{1: 5..., 0: 100}}) """ diff --git a/imblearn/under_sampling/_prototype_selection/_nearmiss.py b/imblearn/under_sampling/_prototype_selection/_nearmiss.py index d9571b82c..ce594ab52 100644 --- a/imblearn/under_sampling/_prototype_selection/_nearmiss.py +++ b/imblearn/under_sampling/_prototype_selection/_nearmiss.py @@ -93,8 +93,7 @@ class NearMiss(BaseUnderSampler): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.under_sampling import \ -NearMiss # doctest: + >>> from imblearn.under_sampling import NearMiss >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py index e67eca5a8..7ff383b2a 100644 --- a/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py +++ b/imblearn/under_sampling/_prototype_selection/_neighbourhood_cleaning_rule.py @@ -102,8 +102,7 @@ class NeighbourhoodCleaningRule(BaseCleaningSampler): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.under_sampling import \ -NeighbourhoodCleaningRule # doctest: + >>> from imblearn.under_sampling import NeighbourhoodCleaningRule >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... 
n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py b/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py index 616bd1c16..01ce831ad 100644 --- a/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py +++ b/imblearn/under_sampling/_prototype_selection/_one_sided_selection.py @@ -89,8 +89,7 @@ class OneSidedSelection(BaseCleaningSampler): >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.under_sampling import \ - OneSidedSelection # doctest: + >>> from imblearn.under_sampling import OneSidedSelection >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py index d54704c05..d86f6e9a3 100644 --- a/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py +++ b/imblearn/under_sampling/_prototype_selection/_random_under_sampler.py @@ -65,8 +65,7 @@ class RandomUnderSampler(BaseUnderSampler): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.under_sampling import \ -RandomUnderSampler # doctest: + >>> from imblearn.under_sampling import RandomUnderSampler >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/under_sampling/_prototype_selection/_tomek_links.py b/imblearn/under_sampling/_prototype_selection/_tomek_links.py index f7cacacdb..dd1331a5c 100644 --- a/imblearn/under_sampling/_prototype_selection/_tomek_links.py +++ b/imblearn/under_sampling/_prototype_selection/_tomek_links.py @@ -71,8 +71,7 @@ class TomekLinks(BaseCleaningSampler): -------- >>> from collections import Counter >>> from sklearn.datasets import make_classification - >>> from imblearn.under_sampling import \ -TomekLinks # doctest: + >>> from imblearn.under_sampling import TomekLinks >>> X, y = make_classification(n_classes=2, class_sep=2, ... weights=[0.1, 0.9], n_informative=3, n_redundant=1, flip_y=0, ... n_features=20, n_clusters_per_class=1, n_samples=1000, random_state=10) diff --git a/imblearn/utils/_available_if.py b/imblearn/utils/_available_if.py new file mode 100644 index 000000000..9b2c5e6db --- /dev/null +++ b/imblearn/utils/_available_if.py @@ -0,0 +1,99 @@ +"""This is a copy of sklearn/utils/_available_if.py. It can be removed when +we support scikit-learn >= 1.1. +""" +# mypy: ignore-errors + +from functools import update_wrapper, wraps +from types import MethodType + +import sklearn +from sklearn.utils import parse_version + +sklearn_version = parse_version(sklearn.__version__) + +if sklearn_version < parse_version("1.1"): + + class _AvailableIfDescriptor: + """Implements a conditional property using the descriptor protocol. + + Using this class to create a decorator will raise an ``AttributeError`` + if check(self) returns a falsey value. Note that if check raises an error + this will also result in hasattr returning false. + + See https://docs.python.org/3/howto/descriptor.html for an explanation of + descriptors. 
+ """ + + def __init__(self, fn, check, attribute_name): + self.fn = fn + self.check = check + self.attribute_name = attribute_name + + # update the docstring of the descriptor + update_wrapper(self, fn) + + def __get__(self, obj, owner=None): + attr_err = AttributeError( + f"This {owner.__name__!r} has no attribute {self.attribute_name!r}" + ) + if obj is not None: + # delegate only on instances, not the classes. + # this is to allow access to the docstrings. + if not self.check(obj): + raise attr_err + out = MethodType(self.fn, obj) + + else: + # This makes it possible to use the decorated method as an + # unbound method, for instance when monkeypatching. + @wraps(self.fn) + def out(*args, **kwargs): + if not self.check(args[0]): + raise attr_err + return self.fn(*args, **kwargs) + + return out + + def available_if(check): + """An attribute that is available only if check returns a truthy value. + + Parameters + ---------- + check : callable + When passed the object with the decorated method, this should return + a truthy value if the attribute is available, and either return False + or raise an AttributeError if not available. + + Returns + ------- + callable + Callable makes the decorated method available if `check` returns + a truthy value, otherwise the decorated method is unavailable. + + Examples + -------- + >>> from sklearn.utils.metaestimators import available_if + >>> class HelloIfEven: + ... def __init__(self, x): + ... self.x = x + ... + ... def _x_is_even(self): + ... return self.x % 2 == 0 + ... + ... @available_if(_x_is_even) + ... def say_hello(self): + ... print("Hello") + ... + >>> obj = HelloIfEven(1) + >>> hasattr(obj, "say_hello") + False + >>> obj.x = 2 + >>> hasattr(obj, "say_hello") + True + >>> obj.say_hello() + Hello + """ + return lambda fn: _AvailableIfDescriptor(fn, check, attribute_name=fn.__name__) + +else: + from sklearn.utils.metaestimators import available_if # noqa diff --git a/imblearn/utils/estimator_checks.py b/imblearn/utils/estimator_checks.py index 14c5d9a1a..871f7ac26 100644 --- a/imblearn/utils/estimator_checks.py +++ b/imblearn/utils/estimator_checks.py @@ -12,6 +12,7 @@ import numpy as np import pytest +import sklearn from scipy import sparse from sklearn.base import clone from sklearn.cluster import KMeans @@ -28,12 +29,15 @@ assert_raises_regex, ) from sklearn.utils.estimator_checks import _get_check_estimator_ids, _maybe_mark_xfail +from sklearn.utils.fixes import parse_version from sklearn.utils.multiclass import type_of_target from imblearn.datasets import make_imbalance from imblearn.over_sampling.base import BaseOverSampler from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler +sklearn_version = parse_version(sklearn.__version__) + def _set_checking_parameters(estimator): params = estimator.get_params() @@ -41,9 +45,13 @@ def _set_checking_parameters(estimator): if "n_estimators" in params: estimator.set_params(n_estimators=min(5, estimator.n_estimators)) if name == "ClusterCentroids": + if sklearn_version < parse_version("1.1"): + algorithm = "full" + else: + algorithm = "lloyd" estimator.set_params( voting="soft", - estimator=KMeans(random_state=0, algorithm="lloyd", n_init=1), + estimator=KMeans(random_state=0, algorithm=algorithm, n_init=1), ) if name == "KMeansSMOTE": estimator.set_params(kmeans_estimator=12)