diff --git a/orca_python/classifiers/OrdinalDecomposition.py b/orca_python/classifiers/OrdinalDecomposition.py
index 9ea2fbe..ff76995 100644
--- a/orca_python/classifiers/OrdinalDecomposition.py
+++ b/orca_python/classifiers/OrdinalDecomposition.py
@@ -3,13 +3,31 @@
 import numpy as np
 from sklearn.base import BaseEstimator, ClassifierMixin, _fit_context
 from sklearn.utils._param_validation import StrOptions
-from sklearn.utils.validation import check_array, check_is_fitted, check_X_y
+from sklearn.utils.validation import check_is_fitted
+
+# scikit-learn >= 1.6
+try:
+    from sklearn.utils.validation import validate_data as _sk_validate_data
+
+    def _validate_data_compat(estimator, X, y=None, *, reset=True, **kwargs):
+        y_arg = "no_validation" if y is None else y
+        return _sk_validate_data(estimator, X, y_arg, reset=reset, **kwargs)
+
+
+# scikit-learn < 1.6
+except ImportError:
+
+    def _validate_data_compat(estimator, X, y=None, *, reset=True, **kwargs):
+        if y is None:
+            return estimator._validate_data(X, reset=reset, **kwargs)
+        return estimator._validate_data(X, y, reset=reset, **kwargs)
+
 
 from orca_python.model_selection import load_classifier
 
 
-class OrdinalDecomposition(BaseEstimator, ClassifierMixin):
-    """OrdinalDecomposition ensemble classifier.
+class OrdinalDecomposition(ClassifierMixin, BaseEstimator):
+    """Ordinal decomposition ensemble classifier.
 
     This class implements an ensemble model where an ordinal problem is decomposed
     into several binary subproblems, each one of which will generate a different (binary)
@@ -20,69 +38,67 @@ class OrdinalDecomposition(BaseEstimator, ClassifierMixin):
 
     Parameters
     ----------
-    dtype : str
-        Type of decomposition to be performed by classifier. May be one of 4 different
-        types: 'ordered_partitions', 'one_vs_next', 'one_vs_followers' or
-        'one_vs_previous'.
-
-        The coding matrix generated by each method, for a problem with 5 classes will
-        be as follows:
-
-        ordered_partitions    one_vs_next    one_vs_followers    one_vs_previous
-
-        -, -, -, -;           -,  ,  ,  ;    -,  ,  ,  ;         +, +, +, +;
-        +, -, -, -;           +, -,  ,  ;    +, -,  ,  ;         +, +, +, -;
-        +, +, -, -;            , +, -,  ;    +, +, -,  ;         +, +, -,  ;
-        +, +, +, -;            ,  , +, -;    +, +, +, -;         +, -,  ,  ;
-        +, +, +, +;            ,  ,  , +;    +, +, +, +;         -,  ,  ,  ;
-
-        where rows represent classes and columns represent base classifiers. Plus signs
-        indicate that for that classifier, the label will be part of the positive
-        class, on the other hand, a minus sign places that class into the negative one
-        for that binary problem. If there is no sign, then those samples will not be
-        used when building the model.
-
-    decision_method : str
-        Decision method that transforms the predictions of the n different base
-        classifiers to produce the final label (one among the real ordinal classes).
-
-    base_classifier : str
-        Base classifier used to build a model for each binary subproblem. The base
-        classifier need to be a classifier of orca-python framework or any classifier
-        available in sklearn. Other classifiers implemented in sklearn's API can be
-        used here.
-
-    parameters : dict
-        This dictionary will store the parameters used to build the base classifier.
-        Only one value per parameter is allowed.
+    dtype : {'ordered_partitions', 'one_vs_next', 'one_vs_followers', 'one_vs_previous'}, \
+            default='ordered_partitions'
+        Type of decomposition used to build the coding matrix. Each row of the
+        coding matrix corresponds to a class and each column to a binary subproblem.
+        Entries are in {-1, 0, +1}: -1 for negative class, +1 for positive class,
+        and 0 if the class is ignored in that subproblem.
+
+    decision_method : {'exponential_loss', 'hinge_loss', 'logarithmic_loss', 'frank_hall'}, \
+            default='frank_hall'
+        Method to aggregate the predictions of the binary estimators into class
+        probabilities or labels.
+
+    base_classifier : str, default='LogisticRegression'
+        Name of the base classifier to be instantiated via
+        :func:`orca_python.model_selection.load_classifier`. It can refer to
+        a classifier available in the orca-python framework or to a scikit-learn
+        compatible classifier.
+
+    parameters : dict or None, default=None
+        Hyperparameters to initialize the base classifier. If ``None``, the
+        defaults of the base classifier are used. Each key must map to a single value.
 
     Attributes
     ----------
-    classes_ : list
-        List that contains all different class labels found in the original dataset.
+    estimators_ : list of estimators
+        The fitted binary estimators, one per subproblem.
+
+    classes_ : ndarray of shape (n_classes,)
+        The unique class labels seen during fit.
+
+    n_features_in_ : int
+        Number of features seen during fit.
 
     coding_matrix_ : array-like, shape (n_targets, n_targets-1)
         Matrix that defines which classes will be used to build the model of each
         subproblem, and in which binary class they belong inside those new models.
-        Further explained previously.
+        Illustrated in the Notes section below.
 
-    classifiers_ : list of classifiers
-        Initially empty, will include all fitted models for each subproblem once the fit
-        function for this class is called successfully.
+    Notes
+    -----
+    For ``n_classes=5``, the four decomposition types generate the following
+    coding matrices (rows = classes, columns = binary subproblems). Entries are
+    ``+1`` for positive class membership and ``-1`` for negative class membership.
+
+    ::
 
-    X_ : array-like, shape (n_samples, n_features)
-        Training patterns array, where n_samples is the number of samples and
-        n_features is the number of features.
+        ordered_partitions    one_vs_next      one_vs_followers    one_vs_previous
 
-    y_ : array-like, shape (n_samples,)
-        Target vector relative to X.
+        [-1 -1 -1 -1]         [-1  0  0  0]    [-1  0  0  0]       [ 1  1  1  1]
+        [ 1 -1 -1 -1]         [ 1 -1  0  0]    [ 1 -1  0  0]       [ 1  1  1 -1]
+        [ 1  1 -1 -1]         [ 0  1 -1  0]    [ 1  1 -1  0]       [ 1  1 -1  0]
+        [ 1  1  1 -1]         [ 0  0  1 -1]    [ 1  1  1 -1]       [ 1 -1  0  0]
+        [ 1  1  1  1]         [ 0  0  0  1]    [ 1  1  1  1]       [-1  0  0  0]
 
     References
     ----------
-    .. [1] P.A. Gutierrez, M. Perez-Ortiz, J. Sanchez-Monedero, F. Fernandez-Navarro
-           and C. Hervas-Martinez, "Ordinal regression methods: survey and
-           experimental study", IEEE Transactions on Knowledge and Data Engineering,
-           Vol. 28. Issue 1, 2016, https://doi.org/10.1109/TKDE.2015.2457911
+    .. [1] P.A. Gutiérrez, M. Pérez-Ortiz, J. Sánchez-Monedero, F. Fernández-Navarro
+           and C. Hervás-Martínez, "Ordinal regression methods: survey and
+           experimental study", IEEE Transactions on Knowledge and Data
+           Engineering, Vol. 28. Issue 1, 2016,
+           https://doi.org/10.1109/TKDE.2015.2457911
 
     """
 
@@ -103,7 +119,7 @@ class OrdinalDecomposition(BaseEstimator, ClassifierMixin):
             )
         ],
         "base_classifier": [str],
-        "parameters": [dict],
+        "parameters": [dict, None],
     }
 
     def __init__(
@@ -111,7 +127,7 @@ def __init__(
         dtype="ordered_partitions",
         decision_method="frank_hall",
        base_classifier="LogisticRegression",
-        parameters={},
+        parameters=None,
    ):
         self.dtype = dtype
         self.decision_method = decision_method
@@ -120,16 +136,15 @@
 
     @_fit_context(prefer_skip_nested_validation=True)
     def fit(self, X, y):
-        """Fit the model with the training data.
+ """Fit underlying estimators to data matrix X and target(s) y. Parameters ---------- - X : {array-like, sparse matrix} of shape (n_samples, n_features) - Training patterns array, where n_samples is the number of samples and - n_features is the number of features. + X : ndarray or sparse matrix of shape (n_samples, n_features) + The input data. - y : array-like of shape (n_samples,) - Target vector relative to X. + y : ndarray of shape (n_samples,) + The target values. Returns ------- @@ -142,34 +157,43 @@ def fit(self, X, y): If parameters are invalid or data has wrong format. """ - X, y = check_X_y(X, y) - - self.X_ = X - self.y_ = y + X, y = _validate_data_compat( + self, X, y, accept_sparse=False, ensure_2d=True, dtype=None + ) # Get list of different labels of the dataset self.classes_ = np.unique(y) + if self.classes_.size < 2: + raise ValueError("OrdinalDecomposition requires at least 2 classes.") + + dtype = str(self.dtype).lower() + decision = str(self.decision_method).lower() + if decision == "frank_hall" and dtype != "ordered_partitions": + raise ValueError( + "When using Frank and Hall decision method,\ + ordered_partitions must be used" + ) # Give each train input its corresponding output label # for each binary classifier - self.coding_matrix_ = self._coding_matrix( - self.dtype.lower(), len(self.classes_) - ) + self.coding_matrix_ = self._coding_matrix(dtype, len(self.classes_)) class_labels = self.coding_matrix_[(np.digitize(y, self.classes_) - 1), :] - self.classifiers_ = [] + self.estimators_ = [] + parameters = {} if self.parameters is None else self.parameters + # Fitting n_targets - 1 classifiers for n in range(len(class_labels[0, :])): + estimator = load_classifier(self.base_classifier, param_grid=parameters) + if not hasattr(estimator, "predict_proba"): + raise TypeError( + f'Base estimator "{self.base_classifier}" must implement predict_proba.' + ) - estimator = load_classifier( - self.base_classifier, param_grid=self.parameters - ) - estimator.fit( - X[np.where(class_labels[:, n] != 0)], - np.ravel(class_labels[np.where(class_labels[:, n] != 0), n].T), - ) + mask = class_labels[:, n] != 0 + estimator.fit(X[mask], class_labels[mask, n].ravel()) - self.classifiers_.append(estimator) + self.estimators_.append(estimator) return self @@ -179,13 +203,12 @@ def predict(self, X): Parameters ---------- X : {array-like, sparse matrix} of shape (n_samples, n_features) - Test patterns array, where n_samples is the number of samples and - n_features is the number of features. + The input data. Returns ------- - y_pred : array, shape (n_samples,) - Class labels for samples in X. + y_pred : ndarray of shape (n_samples,) + The predicted classes. Raises ------ @@ -199,15 +222,14 @@ def predict(self, X): If the specified loss method is not implemented. 
""" - check_is_fitted(self, ["X_", "y_"]) - X = check_array(X) + check_is_fitted(self, ["estimators_", "classes_", "coding_matrix_"]) + X = _validate_data_compat(self, X, reset=False, ensure_2d=True, dtype=None) # Getting predicted labels for dataset from each classifier predictions = self._get_predictions(X) decision_method = self.decision_method.lower() if decision_method == "exponential_loss": - # Scaling predictions from [0,1] range to [-1,1] predictions = predictions * 2 - 1 @@ -216,7 +238,6 @@ def predict(self, X): y_pred = self.classes_[np.argmin(losses, axis=1)] elif decision_method == "hinge_loss": - # Scaling predictions from [0,1] range to [-1,1] predictions = predictions * 2 - 1 @@ -225,7 +246,6 @@ def predict(self, X): y_pred = self.classes_[np.argmin(losses, axis=1)] elif decision_method == "logarithmic_loss": - # Scaling predictions from [0,1] range to [-1,1] predictions = predictions * 2 - 1 @@ -234,7 +254,6 @@ def predict(self, X): y_pred = self.classes_[np.argmin(losses, axis=1)] elif decision_method == "frank_hall": - # Transforming from binary problems to the original problem y_proba = self._frank_hall_method(predictions) y_pred = self.classes_[np.argmax(y_proba, axis=1)] @@ -247,13 +266,14 @@ def predict(self, X): return y_pred def predict_proba(self, X): - """Return the probability of the sample for each class in the model. + """Probability estimates. + + The returned estimates for all classes are ordered by label of classes. Parameters ---------- X : {array-like, sparse matrix} of shape (n_samples, n_features) - Test patterns array, where n_samples is the number of samples and - n_features is the number of features. + The input data. Returns ------- @@ -273,54 +293,50 @@ def predict_proba(self, X): If the specified loss method is not implemented. 
""" - check_is_fitted(self, ["X_", "y_"]) - X = check_array(X) + check_is_fitted(self, ["estimators_", "classes_", "coding_matrix_"]) + X = _validate_data_compat(self, X, reset=False, ensure_2d=True, dtype=None) # Getting predicted labels for dataset from each classifier predictions = self._get_predictions(X) decision_method = self.decision_method.lower() if decision_method == "exponential_loss": - # Scaling predictions from [0,1] range to [-1,1] predictions = predictions * 2 - 1 # Transforming from binary problems to the original problem - losses = self._exponential_loss(predictions) - losses = 1 / losses.astype(float) - y_proba = [] - for losse in losses: - y_proba.append((np.exp(losse) / np.sum(np.exp(losse)))) - y_proba = np.array(y_proba) + losses = self._exponential_loss(predictions).astype(float) + eps = np.finfo(float).tiny + scores = 1.0 / (losses + eps) + scores -= scores.max(axis=1, keepdims=True) + y_proba = np.exp(scores) + y_proba /= y_proba.sum(axis=1, keepdims=True) elif decision_method == "hinge_loss": - # Scaling predictions from [0,1] range to [-1,1] predictions = predictions * 2 - 1 # Transforming from binary problems to the original problem - losses = self._hinge_loss(predictions) - losses = 1 / losses.astype(float) - y_proba = [] - for losse in losses: - y_proba.append((np.exp(losse) / np.sum(np.exp(losse)))) - y_proba = np.array(y_proba) + losses = self._hinge_loss(predictions).astype(float) + eps = np.finfo(float).tiny + scores = 1.0 / (losses + eps) + scores -= scores.max(axis=1, keepdims=True) + y_proba = np.exp(scores) + y_proba /= y_proba.sum(axis=1, keepdims=True) elif decision_method == "logarithmic_loss": - # Scaling predictions from [0,1] range to [-1,1] predictions = predictions * 2 - 1 # Transforming from binary problems to the original problem - losses = self._logarithmic_loss(predictions) - losses = 1 / losses.astype(float) - y_proba = [] - for losse in losses: - y_proba.append((np.exp(losse) / np.sum(np.exp(losse)))) - y_proba = np.array(y_proba) + losses = self._logarithmic_loss(predictions).astype(float) + eps = np.finfo(float).tiny + scores = 1.0 / (losses + eps) + scores -= scores.max(axis=1, keepdims=True) + y_proba = np.exp(scores) + y_proba /= y_proba.sum(axis=1, keepdims=True) elif decision_method == "frank_hall": - # Transforming from binary problems to the original problem y_proba = self._frank_hall_method(predictions) @@ -356,30 +372,25 @@ def _coding_matrix(self, dtype, n_classes): """ if dtype == "ordered_partitions": - coding_matrix = np.triu((-2 * np.ones(n_classes - 1))) + 1 coding_matrix = np.vstack([coding_matrix, np.ones((1, n_classes - 1))]) elif dtype == "one_vs_next": - plus_ones = np.diagflat(np.ones((1, n_classes - 1), dtype=int), -1) minus_ones = -(np.eye(n_classes, n_classes - 1, dtype=int)) coding_matrix = minus_ones + plus_ones[:, :-1] elif dtype == "one_vs_followers": - minus_ones = np.diagflat(-np.ones((1, n_classes), dtype=int)) plus_ones = np.tril(np.ones(n_classes), -1) coding_matrix = (plus_ones + minus_ones)[:, :-1] elif dtype == "one_vs_previous": - plusones = np.triu(np.ones(n_classes)) minusones = -np.diagflat(np.ones((1, n_classes - 1)), -1) coding_matrix = np.flip((plusones + minusones)[:, :-1], axis=1) else: - raise ValueError("Decomposition type %s does not exist" % dtype) return coding_matrix.astype(int) @@ -394,8 +405,7 @@ def _get_predictions(self, X): Parameters ---------- X : {array-like, sparse matrix} of shape (n_samples, n_features) - Test patterns array, where n_samples is the number of samples and - 
-            n_features is the number of features.
+            The input data.
 
         Returns
         -------
@@ -404,7 +414,7 @@ def _get_predictions(self, X):
 
         """
         predictions = np.array(
-            list(map(lambda c: c.predict_proba(X)[:, 1], self.classifiers_))
+            [est.predict_proba(X)[:, 1] for est in self.estimators_]
         ).T
 
         return predictions
@@ -423,23 +433,14 @@ def _exponential_loss(self, predictions):
 
         Returns
         -------
-        e_losses : array, shape (n_samples, n_targets)
+        e_losses : ndarray of shape (n_samples, n_classes)
             Exponential losses for each sample of dataset X. One different value
             for each class label.
 
         """
-        # Computing exponential losses
-        e_losses = np.zeros((predictions.shape[0], (predictions.shape[1] + 1)))
-        for i in range(predictions.shape[1] + 1):
-
-            e_losses[:, i] = np.sum(
-                np.exp(
-                    -predictions
-                    * np.tile(self.coding_matrix_[i, :], (predictions.shape[0], 1))
-                ),
-                axis=1,
-            )
-
+        C = self.coding_matrix_[None, :, :]
+        M = predictions[:, None, :]
+        e_losses = np.exp(-M * C).sum(axis=2)
         return e_losses
 
     def _hinge_loss(self, predictions):
@@ -456,27 +457,14 @@ def _hinge_loss(self, predictions):
 
         Returns
         -------
-        h_losses : array, shape (n_samples, n_targets)
+        h_losses : ndarray of shape (n_samples, n_classes)
             Hinge losses for each sample of dataset X. One different value
             for each class label.
 
         """
-        # Computing Hinge losses
-        h_losses = np.zeros((predictions.shape[0], (predictions.shape[1] + 1)))
-        for i in range(predictions.shape[1] + 1):
-
-            h_losses[:, i] = np.sum(
-                np.maximum(
-                    0,
-                    (
-                        1
-                        - np.tile(self.coding_matrix_[i, :], (predictions.shape[0], 1))
-                        * predictions
-                    ),
-                ),
-                axis=1,
-            )
-
+        C = self.coding_matrix_[None, :, :]
+        M = predictions[:, None, :]
+        h_losses = np.maximum(0.0, 1.0 - C * M).sum(axis=2)
         return h_losses
 
     def _logarithmic_loss(self, predictions):
@@ -493,27 +481,14 @@ def _logarithmic_loss(self, predictions):
 
         Returns
         -------
-        l_losses : array, shape (n_samples, n_targets)
+        l_losses : ndarray of shape (n_samples, n_classes)
             Logarithmic losses for each sample of dataset X. One different value
             for each class label.
 
         """
-        # Computing logarithmic losses
-        l_losses = np.zeros((predictions.shape[0], (predictions.shape[1] + 1)))
-        for i in range(predictions.shape[1] + 1):
-
-            l_losses[:, i] = np.sum(
-                np.log(
-                    1
-                    + np.exp(
-                        -2
-                        * np.tile(self.coding_matrix_[i, :], (predictions.shape[0], 1))
-                        * predictions
-                    )
-                ),
-                axis=1,
-            )
-
+        C = self.coding_matrix_[None, :, :]
+        M = predictions[:, None, :]
+        l_losses = np.log1p(np.exp(-2.0 * C * M)).sum(axis=2)
         return l_losses
 
     def _frank_hall_method(self, predictions):
@@ -530,21 +505,10 @@ def _frank_hall_method(self, predictions):
 
         Returns
         -------
-        y_proba : array, shape (n_samples, n_targets)
-            Class labels predicted for samples in dataset X.
-
-        Raises
-        ------
-        AttributeError
-            If the decomposition type is not ordered_partitions.
+        y_proba : ndarray of shape (n_samples, n_classes)
+            Class membership probabilities for each sample.
""" - if self.dtype.lower() != "ordered_partitions": - raise AttributeError( - "When using Frank and Hall decision method,\ - ordered_partitions must be used" - ) - y_proba = np.empty([(predictions.shape[0]), (predictions.shape[1] + 1)]) # Probabilities of each set to belong to the first ordinal class diff --git a/orca_python/classifiers/tests/test_ordinal_decomposition.py b/orca_python/classifiers/tests/test_ordinal_decomposition.py index b79d567..67f5786 100644 --- a/orca_python/classifiers/tests/test_ordinal_decomposition.py +++ b/orca_python/classifiers/tests/test_ordinal_decomposition.py @@ -158,14 +158,9 @@ def test_coding_matrix(dtype, expected_cm): npt.assert_array_equal(cm, expected_cm) -def test_frank_hall_method(X): +def test_frank_hall_method(): """Test that frank and hall method returns expected values for one toy problem (starting off predicted probabilities given by each binary classifier).""" - # Checking frank_hall cannot be used whitout ordered_partitions - classifier = OrdinalDecomposition(dtype="one_vs_next", decision_method="frank_hall") - with pytest.raises(AttributeError): - classifier._frank_hall_method(X) - classifier = OrdinalDecomposition(dtype="ordered_partitions") classifier.coding_matrix_ = classifier._coding_matrix(classifier.dtype, 5) @@ -352,3 +347,10 @@ def test_ordinal_decomposition_predict_invalid_input_raises_error(X, y): with pytest.raises(ValueError): classifier.predict([]) + + +def test_frank_hall_method_raises_error(X, y): + """Test that using frank_hall with invalid dtype raises a ValueError.""" + classifier = OrdinalDecomposition(dtype="one_vs_next", decision_method="frank_hall") + with pytest.raises(ValueError): + classifier.fit(X, y)