-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
110 lines (79 loc) · 3.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import json
import torch
import torch.distributed as dist
from typing import List, Union, Optional, Tuple, Mapping, Dict
def normalize_instruction(instruction: str):
    """Trim surrounding whitespace and drop a single trailing ';', ':', ',' or '.'."""
    text = instruction.strip()
    # Only one trailing punctuation character is removed; an empty string
    # never matches endswith, so it is returned unchanged.
    if text.endswith((';', ':', ',', '.')):
        return text[:-1]
    return text
def dist_gather_tensor(t: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """All-gather `t` from every rank and concatenate along dim 0.

    Returns None when given None. Requires an initialized process group.
    """
    if t is None:
        return None
    local = t.contiguous()
    world_size = dist.get_world_size()
    gathered = [torch.empty_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)
    # Put the original local tensor back into its own slot — presumably so the
    # local shard stays connected to autograd (all_gather outputs are copies);
    # TODO confirm against the training loop.
    gathered[dist.get_rank()] = local
    return torch.cat(gathered, dim=0)
@torch.no_grad()
def select_grouped_indices(scores: torch.Tensor,
                           group_size: int,
                           start: int = 0) -> torch.Tensor:
    """Build a (batch, group_size) long index matrix into the columns of `scores`.

    Row b holds the consecutive indices
    [start + b*group_size, ..., start + (b+1)*group_size - 1],
    i.e. the column slice belonging to the b-th group.
    """
    assert scores.dim() == 2
    batch_size = scores.size(0)
    assert batch_size * group_size <= scores.size(1)
    # Broadcast a per-row base offset against the within-group offsets.
    row_base = torch.arange(batch_size, dtype=torch.long).unsqueeze(-1) * group_size
    col_offset = torch.arange(group_size, dtype=torch.long).unsqueeze(0)
    return (row_base + col_offset + start).to(scores.device)
def full_contrastive_scores_and_labels(
        query: torch.Tensor,
        key: torch.Tensor,
        use_all_pairs: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute contrastive similarity scores between queries and keys.

    With use_all_pairs=False returns just query-key scores of shape
    (batch, batch * n_psg). With use_all_pairs=True, key-query, query-query
    and key-key similarities (self-pairs masked to -inf) are appended along
    the last dim, giving shape (batch, batch * n_psg + 3 * batch).
    Labels are simply 0..batch-1: row i's positive is column i of `query @ key.T`
    — presumably the first `batch` rows of `key` are the positives; verify
    against the batch construction.
    """
    num_queries = query.shape[0]
    assert key.shape[0] % num_queries == 0, '{} % {} > 0'.format(key.shape[0], num_queries)
    train_n_passages = key.shape[0] // num_queries  # passages per query; not used below
    labels = torch.arange(0, num_queries, dtype=torch.long, device=query.device)

    # batch_size x (batch_size x n_psg)
    qk = torch.mm(query, key.t())
    if not use_all_pairs:
        return qk, labels

    # One key row per query: rows 0..batch-1 of `key`.
    sliced_key = key.index_select(dim=0, index=labels)
    assert num_queries == sliced_key.shape[0]

    def _masked_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Pairwise dot products with the diagonal (self-pairs) masked out.
        sim = torch.mm(a, b.t())
        sim.fill_diagonal_(float('-inf'))
        return sim

    scores = torch.cat(
        [qk,
         _masked_sim(sliced_key, query),
         _masked_sim(query, query),
         _masked_sim(sliced_key, sliced_key)],
        dim=-1)
    return scores, labels
def slice_batch_dict(batch_dict: Dict[str, torch.Tensor], prefix: str) -> dict:
    """Return the entries of `batch_dict` whose key starts with `prefix`,
    with the prefix stripped from each resulting key."""
    selected = {}
    cut = len(prefix)
    for full_key, value in batch_dict.items():
        if full_key.startswith(prefix):
            selected[full_key[cut:]] = value
    return selected
class AverageMeter(object):
    """Accumulates a running sum and count, exposing their mean as `avg`."""

    def __init__(self, name: str, round_digits: int = 3):
        self.name = name                  # label used in the string form
        self.round_digits = round_digits  # decimals shown in the string form
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record value `val` observed `n` times and refresh the mean."""
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        return '{}: {}'.format(self.name, round(self.avg, self.round_digits))
if __name__ == '__main__':
    # Smoke test: 4 queries with 3 passages each, 16-dim random embeddings.
    q = torch.randn(4, 16)
    k = torch.randn(4 * 3, 16)
    all_scores, all_labels = full_contrastive_scores_and_labels(q, k)
    print(all_scores.shape)
    print(all_labels)