utils.py
# coding=utf-8
# Copyright 2023 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import os
import pickle
from typing import Iterator, List, Optional, Tuple

import numpy as np
import torch
from torch import ByteTensor


def pad(seq_batch: List[List[object]],
        pad_token=0,
        min_len: Optional[int] = None) -> Tuple[List[List[object]], ByteTensor]:
  """Pads a batch of sequences with pad_token to the same length.

  Args:
    seq_batch: A batch of sequences to pad.
    pad_token: The token to pad with.
    min_len: The minimum length to pad to. If None, the maximum sequence
      length in seq_batch is used.

  Returns:
    A tuple of (padded_seq_batch, mask):
      padded_seq_batch: The batch of sequences, each padded to the same length.
      mask: A ByteTensor of shape (batch_size, max_len) with 1 for real tokens
        and 0 for padding.
  """
  max_len = max(len(seq) for seq in seq_batch)
  if min_len is not None:
    max_len = max(max_len, min_len)
  batch_size = len(seq_batch)
  mask = torch.ones(batch_size, max_len).byte()
  padded = []
  for i, seq in enumerate(seq_batch):
    padding = max_len - len(seq)
    padded.append(seq + [pad_token] * padding)
    if padding > 0:
      mask[i, -padding:] = 0
  return padded, mask
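
# Example (illustrative; values assumed, not from the original source):
#   pad([[1, 2, 3], [4]]) returns ([[1, 2, 3], [4, 0, 0]], mask)
#   with mask.tolist() == [[1, 1, 1], [1, 0, 0]].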


def mask_renormalize(probs: torch.FloatTensor,
                     mask: torch.ByteTensor) -> torch.FloatTensor:
  """Renormalizes probs with a mask so that the unmasked entries sum to 1.

  Args:
    probs (torch.FloatTensor): batch of probability distributions with shape
      (batch_dim1, batch_dim2, ..., num_elements).
    mask (torch.ByteTensor): masks out elements if the value is 0.

  Returns:
    renormalized_probs (torch.FloatTensor): tensor of the same shape as probs.
      Each batch row (last dim) sums to 1, and masked entries have value 0.
      If all entries in a row are masked, the row sums to 0.
  """
  # Zero out the masked entries.
  masked_probs = probs * mask.float()
  # Renormalize the unmasked entries; the epsilon avoids division by zero
  # when an entire row is masked.
  renormalized_probs = masked_probs / (
      masked_probs.sum(dim=-1, keepdim=True) + 1e-8)
  return renormalized_probs
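
# Example (illustrative): for probs = torch.tensor([[0.25, 0.25, 0.5]]) and
# mask = torch.tensor([[1, 1, 0]], dtype=torch.uint8), mask_renormalize
# returns approximately [[0.5, 0.5, 0.0]].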


def as_batches(parallel_data: List[List[object]],
               batch_size: int,
               sequence_length: int) -> Iterator[List[List[object]]]:
  """Iterable of batches of consecutive subsequences of length sequence_length.

  A single pass through this iterable includes every starting position in
  each of the parallel sequences exactly once, in shuffled order.

  Args:
    parallel_data (List[List[object]]): parallel sequences of consecutive
      timesteps of data. Each batch only contains consecutive subsequences
      drawn from within a single parallel sequence.
    batch_size (int): size of the batches. The last batch may contain fewer
      than batch_size sequences.
    sequence_length (int): length of the sequences in each batch.

  Yields:
    List[List[object]]: the outer list has length at most batch_size; the
      inner lists all have length sequence_length and are consecutive
      subsequences of a single parallel sequence.
  """
  positions = []
  for i, seq in enumerate(parallel_data):
    # Include every valid start position: the last one is
    # len(seq) - sequence_length, so the range bound must include it.
    positions.extend(
        (i, start) for start in range(len(seq) - sequence_length + 1))
  # Shuffle so batches mix positions across the parallel sequences.
  np.random.shuffle(positions)
  # Yield batches of batch_size subsequences, each of length sequence_length.
  for i in range(math.ceil(len(positions) / batch_size)):
    batch = [
        parallel_data[index][start:start + sequence_length]
        for index, start in positions[i * batch_size:(i + 1) * batch_size]
    ]
    # (batch_size, sequence_length)
    yield batch
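
# Example (illustrative): list(as_batches([[1, 2, 3], [4, 5, 6]],
# batch_size=2, sequence_length=2)) yields two batches covering the windows
# [1, 2], [2, 3], [4, 5], [5, 6] in shuffled order.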


def save_pickle(content, pickle_file_path, overwrite=False):
  """Saves content to pickle_file_path.

  Args:
    content (object): object to save.
    pickle_file_path (str): path to save the content to.
    overwrite (bool): if True, silently overwrite an existing file.

  Raises:
    ValueError: if the file already exists and overwrite is False.
  """
  if not overwrite and os.path.exists(pickle_file_path):
    raise ValueError(f"File already exists: {pickle_file_path}")
  print("Saving pickle file:", pickle_file_path)
  with open(pickle_file_path, "wb") as f:
    pickle.dump(content, f)
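
# Example (illustrative; the path is hypothetical):
#   save_pickle({"step": 1}, "/tmp/ckpt.pkl") writes the dict; calling it
#   again on the same path without overwrite=True raises ValueError.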


def load_pickle(pickle_file_path):
  """Loads a pickle file.

  Args:
    pickle_file_path (str): path to the pickle file.

  Raises:
    ValueError: if the file does not exist.

  Returns:
    object: the object loaded from the pickle file.
  """
  if not os.path.exists(pickle_file_path):
    raise ValueError(f"File does not exist: {pickle_file_path}")
  print("Loading pickle file:", pickle_file_path)
  with open(pickle_file_path, "rb") as f:
    return pickle.load(f)
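
# Example round trip (illustrative; the path is hypothetical):
#   save_pickle([1, 2, 3], "/tmp/data.pkl", overwrite=True)
#   assert load_pickle("/tmp/data.pkl") == [1, 2, 3]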