feat: patch torch cosine_similarity
LutingWang committed Oct 10, 2024
1 parent 4b5a1f1 commit 29ba5f3
Showing 2 changed files with 20 additions and 1 deletion.
12 changes: 11 additions & 1 deletion docs/source/pretrained/t5.py
@@ -2,6 +2,8 @@
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
 from transformers.modeling_outputs import BaseModelOutput
 
+import todd
+
 tokenizer: T5Tokenizer = AutoTokenizer.from_pretrained(
     'pretrained/t5/t5-large',
 )
@@ -10,6 +12,14 @@
     'Studies have been shown that owning a dog is good for you',
     return_tensors='pt',
 )
+print(tokenizer.convert_ids_to_tokens(tokens['input_ids'][0]))
+# ['▁Studies', '▁have', '▁been', '▁shown', '▁that', '▁own', 'ing', '▁', 'a',
+#  '▁dog', '▁is', '▁good', '▁for', '▁you', '</s>']
+if todd.Store.cuda:  # pylint: disable=using-constant-test
+    model = model.cuda()
+    tokens = tokens.to('cuda')
 with torch.no_grad():
     outputs: BaseModelOutput = model(**tokens)
-print(outputs.last_hidden_state)
+print(outputs.last_hidden_state[0, -1])
+# tensor([ 0.0900, -0.0016, -0.0112,  ..., 0.0356, -0.0397,  0.0150],
+#        device='cuda:0')
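
As a hedged aside, this doc snippet pairs naturally with the helper the commit adds below: the per-token embeddings it produces can be compared pairwise without a large broadcast intermediate. A sketch continuing from the code above (the direct todd.patches.torch.aten import path is an assumption based on the file layout below):

from todd.patches.torch.aten import cosine_similarity  # assumed import path

# Pairwise similarity between all 15 token embeddings of the sentence.
embeddings = outputs.last_hidden_state[0]  # (15, 1024) for t5-large
similarity = cosine_similarity(embeddings, embeddings)  # (15, 15) matrix
print(similarity[0, -1])  # '▁Studies' vs. '</s>'
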
9 changes: 9 additions & 0 deletions todd/patches/torch/aten.py
@@ -1,12 +1,14 @@
 __all__ = [
     'all_close',
     'random_int',
+    'cosine_similarity',
 ]
 
 from typing import Any, cast
 
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 
 from .distributed import get_world_size
 
@@ -24,3 +26,10 @@ def all_close(x: Any, y: Any, *args, **kwargs) -> bool:
     if not isinstance(y, torch.Tensor):
         y = torch.tensor(y)
     return torch.allclose(x, y, *args, **kwargs)
+
+
+def cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    # Avoids the CUDA out-of-memory error of torch.cosine_similarity.
+    x = F.normalize(x)
+    y = F.normalize(y)
+    return torch.einsum('x d, y d -> x y', x, y)
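
Why this avoids the OOM: a naive pairwise torch.cosine_similarity call broadcasts x and y into an (N, M, D) intermediate, while normalizing first and contracting with einsum keeps peak memory at the (N, M) output. A minimal equivalence sketch, assuming the patched module is importable as laid out above (shapes and tolerance are illustrative):

import torch

from todd.patches.torch.aten import cosine_similarity  # the helper above

x = torch.randn(4, 8)
y = torch.randn(6, 8)
# Naive version: broadcasting materializes a (4, 6, 8) tensor, which is
# what exhausts CUDA memory once N, M, and D grow large.
naive = torch.cosine_similarity(x[:, None], y[None], dim=-1)
assert torch.allclose(cosine_similarity(x, y), naive, atol=1e-6)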
