diff --git a/src/deepsparse/transformers/metrics.py b/src/deepsparse/transformers/metrics.py
index 394db12813..418f137e17 100644
--- a/src/deepsparse/transformers/metrics.py
+++ b/src/deepsparse/transformers/metrics.py
@@ -20,8 +20,8 @@
 
 
 import numpy
-import torch
 from sklearn.metrics import precision_recall_fscore_support
+from scipy.special import log_softmax
 
 
 __all__ = [
@@ -71,11 +71,7 @@ def add_batch(self, predictions: numpy.ndarray, targets: numpy.ndarray):
         targets = targets.flatten()
 
         # Compute negative log-likelihood and accumulate
-        self._neg_log_likelihood += torch.nn.functional.cross_entropy(
-            torch.tensor(predictions),
-            torch.tensor(targets),
-            reduction="sum",
-        ).item()
+        self._neg_log_likelihood += _cross_entropy(predictions, targets, reduction="sum").sum()
 
         # Track number of tokens processed
         self._number_tokens += predictions.shape[0]
@@ -90,11 +86,7 @@ def add_batch(self, predictions: numpy.ndarray, targets: numpy.ndarray):
             targets = numpy.expand_dims(targets, axis=0)
 
         # Compute negative log-likelihoods for batch
-        neg_log_likelihoods = torch.nn.functional.cross_entropy(
-            torch.tensor(predictions.transpose(0, 2, 1)),
-            torch.tensor(targets),
-            reduction="none",
-        ).numpy().mean(-1)
+        neg_log_likelihoods = _cross_entropy(predictions, targets)
 
         # Compute perplexities for batch
         perplexities = numpy.exp(neg_log_likelihoods)
@@ -181,3 +173,15 @@ def compute(self) -> Dict[str, float]:
         results["f1_std"] = f1.std()
 
         return results
+
+
+def _cross_entropy(predictions, targets, reduction="mean"):
+    logp = log_softmax(predictions, axis=-1)
+    neg_log_likelihoods = -1. * numpy.take_along_axis(logp, numpy.expand_dims(targets, axis=-1), axis=-1)
+    neg_log_likelihoods = numpy.squeeze(neg_log_likelihoods, axis=-1)
+    if reduction == "mean":
+        neg_log_likelihoods = neg_log_likelihoods.mean(axis=-1)
+    elif reduction == "sum":
+        neg_log_likelihoods = neg_log_likelihoods.sum(axis=-1)
+
+    return neg_log_likelihoods
\ No newline at end of file
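
Not part of the patch: a minimal usage sketch of the new numpy-based `_cross_entropy` helper on toy data, for reviewers sanity-checking the removal of `torch.nn.functional.cross_entropy`. It assumes logits of shape (batch, seq_len, vocab) and integer targets of shape (batch, seq_len); the `logits`/`labels` names and random inputs are illustrative only.

    import numpy

    # assumes the _cross_entropy helper added by this patch is importable
    from deepsparse.transformers.metrics import _cross_entropy

    # toy batch: 2 sequences of 3 tokens over a vocabulary of 5
    rng = numpy.random.default_rng(0)
    logits = rng.normal(size=(2, 3, 5))
    labels = rng.integers(0, 5, size=(2, 3))

    # default reduction="mean": per-sequence mean negative log-likelihood, shape (2,)
    nll = _cross_entropy(logits, labels)

    # per-sequence perplexity, mirroring the add_batch path in the patch above
    perplexities = numpy.exp(nll)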