diff --git a/src/deepsparse/transformers/eval_downstream.py b/src/deepsparse/transformers/eval_downstream.py index 85486b312f..b434dec625 100644 --- a/src/deepsparse/transformers/eval_downstream.py +++ b/src/deepsparse/transformers/eval_downstream.py @@ -123,10 +123,12 @@ def mnli_eval(args): ) print(f"Engine info: {text_classify.model}") + label_map = {"entailment": 0, "neutral": 1, "contradiction": 2} + for idx, sample in enumerate(tqdm(mnli_matched)): - pred = text_classify(sample["premise"], sample["hypothesis"]) + pred = text_classify([[sample["premise"], sample["hypothesis"]]]) mnli_metrics.add_batch( - predictions=[int(pred[0]["label"].split("_")[-1])], + predictions=[label_map.get(pred[0]["label"])], references=[sample["label"]], ) @@ -134,9 +136,9 @@ def mnli_eval(args): break for idx, sample in enumerate(tqdm(mnli_mismatched)): - pred = text_classify(sample["premise"], sample["hypothesis"]) + pred = text_classify([[sample["premise"], sample["hypothesis"]]]) mnli_metrics.add_batch( - predictions=[int(pred[0]["label"].split("_")[-1])], + predictions=[label_map.get(pred[0]["label"])], references=[sample["label"]], ) @@ -161,11 +163,13 @@ def qqp_eval(args): ) print(f"Engine info: {text_classify.model}") + label_map = {"not_duplicate": 0, "duplicate": 1} + for idx, sample in enumerate(tqdm(qqp)): pred = text_classify([[sample["question1"], sample["question2"]]]) qqp_metrics.add_batch( - predictions=[int(pred[0]["label"].split("_")[-1])], + predictions=[label_map.get(pred[0]["label"])], references=[sample["label"]], ) @@ -190,13 +194,15 @@ def sst2_eval(args): ) print(f"Engine info: {text_classify.model}") + label_map = {"negative": 0, "positive": 1} + for idx, sample in enumerate(tqdm(sst2)): pred = text_classify( sample["sentence"], ) sst2_metrics.add_batch( - predictions=[int(pred[0]["label"].split("_")[-1])], + predictions=[label_map.get(pred[0]["label"])], references=[sample["label"]], ) @@ -229,6 +235,7 @@ def parse_args(): "--dataset", type=str, choices=list(SUPPORTED_DATASETS.keys()), + required=True, ) parser.add_argument(