diff --git a/README.md b/README.md index 3dae92fe..06be2658 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ The following algorithms are currently implemented. Any methods that can be cast as an adaptation of the input data can be used in one of two ways: - a scikit-learn transformer (Adapter) which provides both a full Classifier/Regressor estimator - - or an `Adapter` that can be used in a DA pipeline with `make_da_pipeline`. + - or an `Adapter` that can be used in a DA pipeline with `make_da_pipeline`. Refer to the examples below and visit [the gallery](https://scikit-adaptation.github.io/auto_examples/index.html)for more details. ### Deep learning domain adaptation algorithms @@ -84,7 +84,7 @@ details, please refer to this [example](https://scikit-adaptation.github.io/auto First, the DA data in the SKADA API is stored in the following format: ```python -X, y, sample_domain +X, y, sample_domain ``` Where `X` is the input data, `y` is the target labels and `sample_domain` is the @@ -112,7 +112,7 @@ pipe = make_da_pipeline(StandardScaler(), CORALAdapter(), LogisticRegression()) pipe.fit(X, y, sample_domain=sample_domain) # sample_domain passed by name ``` -Please note that for `Adapter` classes that implement sample reweighting, the +Please note that for `Adapter` classes that implement sample reweighting, the subsequent classifier/regressor must require sample_weights as input. This is done with the `set_fit_requires` method. For instance, with `LogisticRegression`, you would use `LogisticRegression().set_fit_requires('sample_weight')`: @@ -143,7 +143,7 @@ cv = SourceTargetShuffleSplit() scorer = PredictionEntropyScorer() # cross val score -scores = cross_val_score(pipe, X, y, params={'sample_domain': sample_domain}, +scores = cross_val_score(pipe, X, y, params={'sample_domain': sample_domain}, cv=cv, scoring=scorer) # grid search @@ -239,8 +239,10 @@ The library is distributed under the 3-Clause BSD license. [33] Kang, G., Jiang, L., Yang, Y., & Hauptmann, A. G. (2019). [Contrastive Adaptation Network for Unsupervised Domain Adaptation](https://arxiv.org/abs/1901.00976). In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 4893-4902). -[34] Jin, Ying, Wang, Ximei, Long, Mingsheng, Wang, Jianmin. [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/pdf/1912.03699). ECCV, 2020. +[34] Jin, Ying, Wang, Ximei, Long, Mingsheng, Wang, Jianmin. [Minimum Class Confusion for Versatile Domain Adaptation](https://arxiv.org/pdf/1912.03699). ECCV, 2020. [35] Zhang, Y., Liu, T., Long, M., & Jordan, M. I. (2019). [Bridging Theory and Algorithm for Domain Adaptation](https://arxiv.org/abs/1904.05801). In Proceedings of the 36th International Conference on Machine Learning, (pp. 7404-7413). [36] Xiao, Zhiqing, Wang, Haobo, Jin, Ying, Feng, Lei, Chen, Gang, Huang, Fei, Zhao, Junbo.[SPA: A Graph Spectral Alignment Perspective for Domain Adaptation](https://arxiv.org/pdf/2310.17594). In Neurips, 2023. + +[37] Xie, Renchunzi, Odonnat, Ambroise, Feofanov, Vasilii, Deng, Weijian, Zhang, Jianfeng and An, Bo. [MaNo: Exploiting Matrix Norm for Unsupervised Accuracy Estimation Under Distribution Shifts](https://arxiv.org/pdf/2405.18979). In NeurIPS, 2024. \ No newline at end of file diff --git a/docs/source/all.rst b/docs/source/all.rst index 1df6de32..a562929d 100644 --- a/docs/source/all.rst +++ b/docs/source/all.rst @@ -13,7 +13,7 @@ Main module :py:mod:`skada` :no-members: :no-inherited-members: -Sample reweighting DA methods +Sample reweighting DA methods ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. .. autosummary:: @@ -211,6 +211,7 @@ DA metrics :py:mod:`skada.metrics` SoftNeighborhoodDensity CircularValidation MixValScorer + MaNoScorer Model Selection :py:mod:`skada.model_selection` diff --git a/skada/deep/tests/test_deep_scorer.py b/skada/deep/tests/test_deep_scorer.py index 2922201e..a704fd08 100644 --- a/skada/deep/tests/test_deep_scorer.py +++ b/skada/deep/tests/test_deep_scorer.py @@ -1,4 +1,5 @@ # Author: Yanis Lalou +# Ambroise Odonnat # # License: BSD 3-Clause @@ -16,6 +17,7 @@ CircularValidation, DeepEmbeddedValidation, ImportanceWeightedScorer, + MaNoScorer, MixValScorer, PredictionEntropyScorer, SoftNeighborhoodDensity, @@ -29,6 +31,7 @@ PredictionEntropyScorer(), SoftNeighborhoodDensity(), CircularValidation(), + MaNoScorer(), MixValScorer(), ImportanceWeightedScorer(), ], @@ -66,6 +69,7 @@ def test_generic_scorer_on_deepmodel(scorer, da_dataset): PredictionEntropyScorer(), SoftNeighborhoodDensity(), DeepEmbeddedValidation(), + MaNoScorer(), ], ) def test_generic_scorer(scorer, da_dataset): @@ -101,6 +105,7 @@ def test_generic_scorer(scorer, da_dataset): [ DeepEmbeddedValidation(), ImportanceWeightedScorer(), + MaNoScorer(), ], ) def test_scorer_with_nd_features(scorer, da_dataset): @@ -191,7 +196,14 @@ def test_dev_scorer_on_source_only(da_dataset): assert ~np.isnan(scores), "The score is computed" -def test_dev_exception_layer_name(da_dataset): +@pytest.mark.parametrize( + "scorer", + [ + DeepEmbeddedValidation(), + MaNoScorer(), + ], +) +def test_exception_layer_name(scorer, da_dataset): X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"]) X_test, y_test, sample_domain_test = da_dataset.pack_test(as_targets=["t"]) @@ -209,4 +221,92 @@ def test_dev_exception_layer_name(da_dataset): estimator.fit(X, y, sample_domain=sample_domain) with pytest.raises(ValueError, match="The layer_name of the estimator is not set."): - DeepEmbeddedValidation()(estimator, X, y, sample_domain) + scorer(estimator, X, y, sample_domain) + + +def test_mano_softmax(da_dataset): + X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"]) + X_test, y_test, sample_domain_test = da_dataset.pack_test(as_targets=["t"]) + + estimator = DeepCoral( + ToyModule2D(proba=True), + reg=1, + layer_name="dropout", + batch_size=10, + max_epochs=10, + train_split=None, + ) + + X = X.astype(np.float32) + X_test = X_test.astype(np.float32) + + # without dict + estimator.fit(X, y, sample_domain=sample_domain) + + estimator.predict(X_test, sample_domain=sample_domain_test, allow_source=True) + estimator.predict_proba(X, sample_domain=sample_domain, allow_source=True) + + scorer = MaNoScorer(threshold=-1) + scorer(estimator, X, y, sample_domain) + print(scorer.chosen_normalization.lower()) + assert ( + scorer.chosen_normalization.lower() == "softmax" + ), "the wrong normalization was chosen" + + +def test_mano_taylor(da_dataset): + X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"]) + X_test, y_test, sample_domain_test = da_dataset.pack_test(as_targets=["t"]) + + estimator = DeepCoral( + ToyModule2D(proba=True), + reg=1, + layer_name="dropout", + batch_size=10, + max_epochs=10, + train_split=None, + ) + + X = X.astype(np.float32) + X_test = X_test.astype(np.float32) + + # without dict + estimator.fit(X, y, sample_domain=sample_domain) + + estimator.predict(X_test, sample_domain=sample_domain_test, allow_source=True) + estimator.predict_proba(X, sample_domain=sample_domain, allow_source=True) + + scorer = MaNoScorer(threshold=float("inf")) + scorer(estimator, X, y, sample_domain) + assert ( + scorer.chosen_normalization.lower() == "taylor" + ), "the wrong normalization was chosen" + + +def test_mano_output_range(da_dataset): + X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"]) + X_test, y_test, sample_domain_test = da_dataset.pack_test(as_targets=["t"]) + + estimator = DeepCoral( + ToyModule2D(proba=True), + reg=1, + layer_name="dropout", + batch_size=10, + max_epochs=10, + train_split=None, + ) + + X = X.astype(np.float32) + X_test = X_test.astype(np.float32) + + # without dict + estimator.fit(X, y, sample_domain=sample_domain) + + estimator.predict(X_test, sample_domain=sample_domain_test, allow_source=True) + estimator.predict_proba(X, sample_domain=sample_domain, allow_source=True) + + scorer = MaNoScorer(threshold=float("inf")) + score = scorer(estimator, X, y, sample_domain) + assert (scorer._sign * score >= 0) and ( + scorer._sign * score <= 1 + ), "The output range should be [-1, 0] or [0, 1]." diff --git a/skada/metrics.py b/skada/metrics.py index a70f030a..64d9f058 100644 --- a/skada/metrics.py +++ b/skada/metrics.py @@ -2,6 +2,7 @@ # Remi Flamary # Oleksii Kachaiev # Yanis Lalou +# Ambroise Odonnat # # License: BSD 3-Clause @@ -396,6 +397,7 @@ def _score(self, estimator, X, y, sample_domain=None, **kwargs): ) has_transform_method = False + if not isinstance(estimator, Pipeline): # The estimator is a deep model if estimator.module_.layer_name is None: @@ -782,3 +784,136 @@ def _score(self, estimator, X, y=None, sample_domain=None, **params): ice_score = ice_diff return self._sign * ice_score + + +class MaNoScorer(_BaseDomainAwareScorer): + """ + MaNo scorer inspired by [37]_, an approach for unsupervised accuracy estimation. + + This scorer used the model's predictions on target data to estimate + the accuracy of the model. The original implementation in [37]_ is + tailored to neural networks and consist of three steps: + 1) Recover target logits (inference step), + 2) Normalize them as probabilities (e.g., with softmax), + 3) Aggregate by averaging the p-norms of the target normalized logits. + + To ensure compatibility with any estimator, we adapt the original implementation. + If the estimator is a neural network, follow 1) --> 2) --> 3) like in [37]_. + Else, directly use the probabilities predicted by the estimator and then do 3). + + See [37]_ for details. + + Parameters + ---------- + p : int, default=4 + Order for the p-norm normalization. + It must be non-negative. + threshold : int, default=5 + Threshold value to determine which normalization to use. + If threshold <= 0, softmax normalization is always used. + See Eq.(6) of [37]_ for more details. + greater_is_better : bool, default=True + Whether higher scores are better. + + Returns + ------- + score : float in [0, 1] (or [-1, 0] depending on the value of self._sign). + + References + ---------- + .. [37] Renchunzi Xie et al. MaNo: Matrix Norm for Unsupervised Accuracy Estimation + under Distribution Shifts. + In NeurIPS, 2024. + """ + + def __init__(self, p=4, threshold=5, greater_is_better=True): + super().__init__() + self.p = p + self.threshold = threshold + self._sign = 1 if greater_is_better else -1 + self.chosen_normalization = None + + if self.p <= 0: + raise ValueError("The order of the p-norm must be positive") + + def _score(self, estimator, X, y, sample_domain=None, **params): + if not hasattr(estimator, "predict_proba"): + raise AttributeError( + "The estimator passed should have a 'predict_proba' method. " + f"The estimator {estimator!r} does not." + ) + + X, y, sample_domain = check_X_y_domain(X, y, sample_domain, allow_nd=True) + source_idx = extract_source_indices(sample_domain) + + # Check from y values if it is a classification problem + y_type = _find_y_type(y) + if y_type != Y_Type.DISCRETE: + raise ValueError("MaNo scorer only supports classification problems.") + + if not isinstance(estimator, Pipeline): + # The estimator is a deep model + if estimator.module_.layer_name is None: + raise ValueError("The layer_name of the estimator is not set.") + + # 1) Recover logits on target + logits = estimator.infer(X[~source_idx], **params).cpu().detach().numpy() + + # 2) Normalize logits to obtain probabilities + criterion = self._get_criterion(logits) + proba = self._softrun( + logits=logits, + criterion=criterion, + threshold=self.threshold, + ) + else: + # Directly recover predicted probabilities + proba = estimator.predict_proba( + X[~source_idx], sample_domain=sample_domain[~source_idx], **params + ) + + # 3) Aggregate following Eq.(2) of [37]_. + score = np.mean(proba**self.p) ** (1 / self.p) + + return self._sign * score + + def _get_criterion(self, logits): + """ + Compute criterion to select the proper normalization. + See Eq.(6) of [1]_ for more details. + """ + proba = self._stable_softmax(logits) + proba = np.log(proba) + divergence = -np.mean(proba) + + return divergence + + def _softrun(self, logits, criterion, threshold): + """Normalize the logits following Eq.(6) of [37]_.""" + if criterion > threshold: + # Apply softmax normalization + outputs = self._stable_softmax(logits) + self.chosen_normalization = "softmax" + else: + # Apply Taylor approximation + outputs = self._taylor_softmax(logits) + self.chosen_normalization = "taylor" + + return outputs + + @staticmethod + def _stable_softmax(logits): + """Compute softmax function.""" + logits -= np.max(logits, axis=1, keepdims=True) + exp_logits = np.exp(logits) + exp_logits /= np.sum(exp_logits, axis=1, keepdims=True) + return exp_logits + + @staticmethod + def _taylor_softmax(logits): + """Compute Taylor approximation of order 2 of softmax.""" + tay_logits = 1 + logits + logits**2 / 2 + tay_logits -= np.min(tay_logits, axis=1, keepdims=True) + tay_logits /= np.sum(tay_logits, axis=1, keepdims=True) + + return tay_logits diff --git a/skada/tests/test_scorer.py b/skada/tests/test_scorer.py index 328a605f..7ff53d91 100644 --- a/skada/tests/test_scorer.py +++ b/skada/tests/test_scorer.py @@ -2,6 +2,7 @@ # Remi Flamary # Oleksii Kachaiev # Yanis Lalou +# Ambroise Odonnat # # License: BSD 3-Clause @@ -23,6 +24,7 @@ CircularValidation, DeepEmbeddedValidation, ImportanceWeightedScorer, + MaNoScorer, MixValScorer, PredictionEntropyScorer, SoftNeighborhoodDensity, @@ -38,6 +40,7 @@ SoftNeighborhoodDensity(), DeepEmbeddedValidation(), CircularValidation(), + MaNoScorer(), ], ) def test_generic_scorer(scorer, da_dataset): @@ -92,6 +95,7 @@ def test_supervised_scorer(da_dataset): [ PredictionEntropyScorer(), SoftNeighborhoodDensity(), + MaNoScorer(), ], ) def test_scorer_with_entropy_requires_predict_proba(scorer, da_dataset): @@ -351,6 +355,51 @@ def test_mixval_scorer_regression(da_reg_dataset): scorer(estimator, X, y, sample_domain) +def test_mano_scorer(da_dataset): + X, y, sample_domain = da_dataset.pack_train(as_sources=["s"], as_targets=["t"]) + estimator = make_da_pipeline( + DensityReweightAdapter(), + LogisticRegression().set_fit_request(sample_weight=True), + ) + + estimator.fit(X, y, sample_domain=sample_domain) + + scorer = MaNoScorer() + score_mean = scorer._score(estimator, X, y, sample_domain=sample_domain) + assert isinstance(score_mean, float), "score_mean is not a float" + + # Test softmax normalization + scorer = MaNoScorer(threshold=-1) + score_mean = scorer._score(estimator, X, y, sample_domain=sample_domain) + assert isinstance(score_mean, float), "score_mean is not a float" + + # Test softmax normalization + scorer = MaNoScorer(threshold=float("inf")) + score_mean = scorer._score(estimator, X, y, sample_domain=sample_domain) + assert isinstance(score_mean, float), "score_mean is not a float" + + # Test invalid p-norm order + with pytest.raises(ValueError): + MaNoScorer(p=-1) + + # Test correct output range + scorer = MaNoScorer() + score_mean = scorer._score(estimator, X, y, sample_domain=sample_domain) + assert (scorer._sign * score_mean >= 0) and ( + scorer._sign * score_mean <= 1 + ), "The output range should be [-1, 0] or [0, 1]." + + +def test_mano_scorer_regression(da_reg_dataset): + X, y, sample_domain = da_reg_dataset.pack(as_sources=["s"], as_targets=["t"]) + + estimator = make_da_pipeline(DensityReweightAdapter(), LogisticRegression()) + + scorer = MaNoScorer() + with pytest.raises(ValueError): + scorer(estimator, X, y, sample_domain) + + @pytest.mark.parametrize( "scorer", [ @@ -360,6 +409,7 @@ def test_mixval_scorer_regression(da_reg_dataset): SoftNeighborhoodDensity(), CircularValidation(), MixValScorer(alpha=0.55, random_state=42), + MaNoScorer(), ], ) def test_scorer_with_nd_input(scorer, da_dataset):