Skip to content

Commit

Permalink
Remove torch dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
anmarques committed Aug 23, 2023
1 parent 756169c commit 97b5f1a
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/deepsparse/transformers/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@

import numpy

import torch
from sklearn.metrics import precision_recall_fscore_support
from scipy.special import log_softmax


__all__ = [
Expand Down Expand Up @@ -71,11 +71,7 @@ def add_batch(self, predictions: numpy.ndarray, targets: numpy.ndarray):
targets = targets.flatten()

# Compute negative log-likelihood and accumulate
self._neg_log_likelihood += torch.nn.functional.cross_entropy(
torch.tensor(predictions),
torch.tensor(targets),
reduction="sum",
).item()
self._neg_log_likelihood += _cross_entropy(predictions, targets, reduction="sum").sum()

# Track number of tokens processed
self._number_tokens += predictions.shape[0]
Expand All @@ -90,11 +86,7 @@ def add_batch(self, predictions: numpy.ndarray, targets: numpy.ndarray):
targets = numpy.expand_dims(targets, axis=0)

# Compute negative log-likelihoods for batch
neg_log_likelihoods = torch.nn.functional.cross_entropy(
torch.tensor(predictions.transpose(0, 2, 1)),
torch.tensor(targets),
reduction="none",
).numpy().mean(-1)
neg_log_likelihoods = _cross_entropy(predictions, targets)

# Compute perplexities for batch
perplexities = numpy.exp(neg_log_likelihoods)
Expand Down Expand Up @@ -181,3 +173,15 @@ def compute(self) -> Dict[str, float]:
results["f1_std"] = f1.std()

return results


def _cross_entropy(predictions, targets, reduction="mean"):
logp = log_softmax(predictions, axis=-1)
neg_log_likelihoods = -1. * numpy.take_along_axis(logp, numpy.expand_dims(targets, axis=-1), axis=-1)
neg_log_likelihoods = numpy.squeeze(neg_log_likelihoods, axis=-1)
if reduction == "mean":
neg_log_likelihoods = neg_log_likelihoods.mean(axis=-1)
elif reduction == "sum":
neg_log_likelihoods = neg_log_likelihoods.sum(axis=-1)

return neg_log_likelihoods

0 comments on commit 97b5f1a

Please sign in to comment.