-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcollate.py
104 lines (84 loc) · 3.93 KB
/
collate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import warnings

import numpy as np
import torch
# Silence NumPy's VisibleDeprecationWarning (historically emitted when building
# ragged object arrays). Newer NumPy releases no longer expose ``np.warnings``,
# and NumPy 2.0 moved the warning class into ``np.exceptions``, so look it up
# defensively instead of failing at import time.
_VisibleDeprecationWarning = getattr(
    getattr(np, "exceptions", np), "VisibleDeprecationWarning", None
)
if _VisibleDeprecationWarning is not None:
    warnings.filterwarnings('ignore', category=_VisibleDeprecationWarning)

# Maximum number of reviews kept per user / item before padding.
TRUNCATE_LEN = 30
class CollateTest:
    """Collate ``(user, item, rating)`` samples into padded review batches.

    For each sample, the stored review vectors of its user and of its item are
    truncated to at most ``truncate_len`` rows, stacked, and zero-padded at the
    end up to the longest list in the batch. A boolean mask marks the padded
    slots (True = padding). No target review is produced; a ``torch.zeros(1)``
    placeholder is returned in its place.
    """

    def __init__(self, user_reviews_dict, item_reviews_dict, truncate_len=None, embed_dim=512):
        """Store the review lookups used by ``__call__``.

        Args:
            user_reviews_dict: maps a user id to that user's review vectors; each
                value must convert via ``np.asarray`` to a ``(k, embed_dim)``
                array (presumably precomputed review embeddings -- TODO confirm
                against the dataset code).
            item_reviews_dict: the same mapping keyed by item id.
            truncate_len: maximum number of reviews kept per user/item; ``None``
                (the default) uses the module-level ``TRUNCATE_LEN``.
            embed_dim: width of each review vector; defaults to 512, the width
                this collator has always padded to.
        """
        self.user_reviews_dict = user_reviews_dict
        self.item_reviews_dict = item_reviews_dict
        self.truncate_len = truncate_len
        self.embed_dim = embed_dim
        # Not read by this class; kept because external code may rely on it.
        self.pad_vector = np.ones(embed_dim)

    def get_key_mask(self, max_length, pad_length):
        """Return a float mask of length ``max_length`` whose last ``pad_length`` entries are 1.

        ``pad_length == 0`` is special-cased because ``mask[-0:]`` would select
        the whole tensor.
        """
        current_key_mask = torch.zeros([max_length])
        if pad_length != 0:
            current_key_mask[-pad_length:] = 1
        return current_key_mask

    def _as_rows(self, reviews):
        """Convert one review list to a float tensor of shape ``(n, embed_dim)``."""
        rows = torch.as_tensor(np.asarray(reviews), dtype=torch.get_default_dtype())
        if rows.numel() == 0:
            # np.asarray([]) has shape (0,), which cannot fill a (0, embed_dim) slot.
            return rows.new_zeros((0, self.embed_dim))
        return rows

    def stack_and_pad(self, review_list, bsz):
        """Stack per-sample review arrays into one zero-padded batch tensor.

        Args:
            review_list: sequence of ``bsz`` review arrays, each ``(k_i, embed_dim)``;
                an empty list/array yields an all-padding row.
            bsz: number of samples; must equal ``len(review_list)``.

        Returns:
            ``(review_tensor, key_mask)``: a float tensor of shape
            ``(bsz, max_k, embed_dim)`` and a bool tensor of shape
            ``(bsz, max_k)`` that is True at padded positions.
        """
        max_length = max((len(reviews) for reviews in review_list), default=0)
        # Zero-initialised, so every slot past a sample's own reviews is padding.
        review_tensor = torch.zeros([bsz, max_length, self.embed_dim])
        key_mask = torch.empty([bsz, max_length])
        for idx, reviews in enumerate(review_list):
            rows = self._as_rows(reviews)
            key_mask[idx] = self.get_key_mask(max_length, max_length - len(rows))
            review_tensor[idx, :len(rows)] = rows
        return review_tensor, key_mask.bool()

    def __call__(self, batch):
        """Collate a list of ``(user, item, rating)`` samples.

        Returns:
            ``(users, items, reviews, ratings, user_reviews, item_reviews,
            user_key_mask, item_key_mask)`` where ``users``/``items`` are long
            tensors, ``reviews`` is a ``torch.zeros(1)`` placeholder, ``ratings``
            is a float tensor, and the review tensors / masks come from
            ``stack_and_pad``.
        """
        bsz = len(batch)
        users = [sample[0] for sample in batch]
        items = [sample[1] for sample in batch]
        ratings = [sample[2] for sample in batch]
        limit = TRUNCATE_LEN if self.truncate_len is None else self.truncate_len
        # Plain lists, not np.array(...): the per-sample arrays usually differ in
        # length, and NumPy >= 1.24 raises ValueError for such ragged input.
        user_review_list = [np.asarray(self.user_reviews_dict[user])[:limit] for user in users]
        item_review_list = [np.asarray(self.item_reviews_dict[item])[:limit] for item in items]
        user_reviews, user_key_mask = self.stack_and_pad(user_review_list, bsz)
        item_reviews, item_key_mask = self.stack_and_pad(item_review_list, bsz)
        users = torch.tensor(users, dtype=torch.long)
        items = torch.tensor(items, dtype=torch.long)
        reviews = torch.zeros(1)
        ratings = torch.tensor(ratings, dtype=torch.float)
        return users, items, reviews, ratings, user_reviews, item_reviews, user_key_mask, item_key_mask
class CollateTrain:
    """Collate ``(user, item, review, rating)`` training samples into padded batches.

    Works like the evaluation collator, except that one review is removed from
    every user's and item's (truncated) review list before stacking: the row
    equal to the sample's own target review, or the last row if none matches.
    This is presumably to keep the target review out of the model inputs --
    NOTE(review): confirm intent. Because of the removal, the padded outputs
    have one fewer slot than the longest input list.
    """

    def __init__(self, user_reviews_dict, item_reviews_dict, truncate_len=None, embed_dim=512):
        """Store the review lookups used by ``__call__``.

        Args:
            user_reviews_dict: maps a user id to that user's review vectors; each
                value must convert via ``np.asarray`` to a ``(k, embed_dim)``
                array (presumably precomputed review embeddings -- TODO confirm).
            item_reviews_dict: the same mapping keyed by item id.
            truncate_len: maximum number of reviews kept per user/item; ``None``
                (the default) uses the module-level ``TRUNCATE_LEN``.
            embed_dim: width of each review vector; defaults to 512, the width
                this collator has always padded to.
        """
        self.user_reviews_dict = user_reviews_dict
        self.item_reviews_dict = item_reviews_dict
        self.truncate_len = truncate_len
        self.embed_dim = embed_dim
        # Not read by this class; kept because external code may rely on it.
        self.pad_vector = np.ones(embed_dim)

    def get_key_mask(self, max_length, pad_length):
        """Return a float mask of length ``max_length - 1`` whose last ``pad_length`` entries are 1.

        The length is one shorter than ``max_length`` because ``stack_and_pad``
        removes one review from every list. ``pad_length == 0`` is special-cased
        because ``mask[-0:]`` would select the whole tensor.
        """
        current_key_mask = torch.zeros([max_length - 1])
        if pad_length != 0:
            current_key_mask[-pad_length:] = 1
        return current_key_mask

    def _as_rows(self, reviews):
        """Convert one review list to a float tensor of shape ``(n, embed_dim)``."""
        rows = torch.as_tensor(np.asarray(reviews), dtype=torch.get_default_dtype())
        if rows.numel() == 0:
            # np.asarray([]) has shape (0,), which cannot fill a (0, embed_dim) slot.
            return rows.new_zeros((0, self.embed_dim))
        return rows

    @staticmethod
    def _drop_true_review(reviews, true_review):
        """Return ``reviews`` with exactly one row removed (none if it is empty).

        The first row exactly equal to ``true_review`` is removed. If no row
        matches (e.g. truncation cut the target off), the last row is removed
        instead, so every non-empty list shrinks by one.
        """
        for i, review in enumerate(reviews):
            if (true_review == review).all():
                return np.delete(reviews, i, axis=0)
        # Drop the last row rather than a fixed index (TRUNCATE_LEN - 1): a fixed
        # index raises IndexError for lists shorter than the truncation limit.
        return reviews[:-1]

    def stack_and_pad(self, review_list, true_reviews, bsz):
        """Remove each sample's target review, then stack and zero-pad the rest.

        Args:
            review_list: sequence of ``bsz`` review arrays, each ``(k_i, embed_dim)``.
            true_reviews: sequence of ``bsz`` CPU tensors holding each sample's
                target review (compared row-by-row via ``.numpy()``).
            bsz: number of samples; must equal ``len(review_list)``.

        Returns:
            ``(review_tensor, key_mask)``: a float tensor of shape
            ``(bsz, max_k - 1, embed_dim)`` and a bool tensor of shape
            ``(bsz, max_k - 1)`` that is True at padded positions.
        """
        max_length = max((len(reviews) for reviews in review_list), default=0)
        # Every list loses one review; clamp so an all-empty batch gives width 0.
        out_length = max(max_length - 1, 0)
        # Zero-initialised, so every slot past a sample's kept reviews is padding.
        review_tensor = torch.zeros([bsz, out_length, self.embed_dim])
        key_mask = torch.empty([bsz, out_length])
        for idx, reviews in enumerate(review_list):
            kept = self._drop_true_review(np.asarray(reviews), true_reviews[idx].numpy())
            rows = self._as_rows(kept)
            # get_key_mask(m, p) builds a mask of length m - 1, hence out_length + 1.
            key_mask[idx] = self.get_key_mask(out_length + 1, out_length - len(rows))
            review_tensor[idx, :len(rows)] = rows
        return review_tensor, key_mask.bool()

    def __call__(self, batch):
        """Collate a list of ``(user, item, review, rating)`` samples.

        ``review`` must be a NumPy array (it is passed to ``torch.from_numpy``).

        Returns:
            ``(users, items, reviews, ratings, user_reviews, item_reviews,
            user_key_mask, item_key_mask)`` where ``users``/``items`` are long
            tensors, ``reviews`` is the stacked target reviews, ``ratings`` is a
            float tensor, and the review tensors / masks come from
            ``stack_and_pad``.
        """
        bsz = len(batch)
        users = [sample[0] for sample in batch]
        items = [sample[1] for sample in batch]
        reviews = [torch.from_numpy(sample[2]) for sample in batch]
        ratings = [sample[3] for sample in batch]
        limit = TRUNCATE_LEN if self.truncate_len is None else self.truncate_len
        # Plain lists, not np.array(...): the per-sample arrays usually differ in
        # length, and NumPy >= 1.24 raises ValueError for such ragged input.
        user_review_list = [np.asarray(self.user_reviews_dict[user])[:limit] for user in users]
        item_review_list = [np.asarray(self.item_reviews_dict[item])[:limit] for item in items]
        # The same target review is removed from both the user's and the item's lists.
        user_reviews, user_key_mask = self.stack_and_pad(user_review_list, reviews, bsz)
        item_reviews, item_key_mask = self.stack_and_pad(item_review_list, reviews, bsz)
        users = torch.tensor(users, dtype=torch.long)
        items = torch.tensor(items, dtype=torch.long)
        reviews = torch.stack(reviews, dim=0)
        ratings = torch.tensor(ratings, dtype=torch.float)
        return users, items, reviews, ratings, user_reviews, item_reviews, user_key_mask, item_key_mask