From c9dbc54428061f60e90aab1f77239496ef67dae7 Mon Sep 17 00:00:00 2001 From: Rahul Tuli Date: Wed, 18 Oct 2023 10:29:53 -0400 Subject: [PATCH] [update] Description of `multi_class` attribute (#1327) * Update multi_class description * Fix typo --- .../pipelines/mnli_text_classification.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/deepsparse/transformers/pipelines/mnli_text_classification.py b/src/deepsparse/transformers/pipelines/mnli_text_classification.py index 634b543e6a..1cc05e954c 100644 --- a/src/deepsparse/transformers/pipelines/mnli_text_classification.py +++ b/src/deepsparse/transformers/pipelines/mnli_text_classification.py @@ -71,7 +71,11 @@ class MnliTextClassificationConfig(BaseModel): description="Index of mnli model outputs which denotes contradiction", default=2 ) multi_class: bool = Field( - description="True if class probabilities are independent, default False", + description="Whether or not multiple candidate labels can be true. " + "If `False`, the scores are normalized as the softmax of entailment" + " score. If `True`, the labels are considered independent and probabilities " + "are normalized for each candidate by doing a softmax of the entailment score " + "vs. the contradiction score. Default is `False`.", default=False, ) @@ -95,8 +99,12 @@ class MnliTextClassificationInput(ZeroShotTextClassificationInputBase): default=None, ) multi_class: Optional[bool] = Field( - description="True if class probabilities are independent, default False. " - "If provided, overrides the multi_class value in the config.", + description="Whether or not multiple candidate labels can be true. " + "If `False`, the scores are normalized as the softmax of entailment score. " + "If `True`, the labels are considered independent and probabilities are " + "normalized for each candidate by doing a softmax of the entailment score " + "vs. the contradiction score. Default is `False`. If provided, overrides " + "the multi_class value in the config.", default=None, )