-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasampler.py
99 lines (87 loc) · 4.12 KB
/
datasampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from transformers.tokenization_utils_base import BatchEncoding
from torch.utils.data import DistributedSampler, Dataset
import torch
from typing import Optional, List, Iterator
import torch.distributed as dist
import math
from tqdm import tqdm
class DistributedBucketSampler(DistributedSampler):
    r"""
    Distributed sampler that orders samples by length before sharding them across replicas.

    This logic is written to be HuggingFace Transformers friendly: when ``lengths`` is not
    given, each dataset item is expected to be a ``dict`` / ``BatchEncoding`` and the length
    of ``item[model_input_name]`` is used.

    All indices are sorted by descending length, and rank ``r`` then takes every
    ``num_replicas``-th index starting at position ``r``. At a given position the replicas
    therefore receive samples that are neighbours in the length ordering.

    The ordering is fully deterministic: ``shuffle`` and ``epoch`` are stored on the instance
    but are not consulted by ``__iter__``.

    Args:
        dataset: Dataset used to infer lengths when ``lengths`` is not given.
        num_replicas: Number of participating processes. Defaults to ``dist.get_world_size()``.
        rank: Rank of the current process. Defaults to ``dist.get_rank()``.
        drop_last: If ``True``, drop the tail of the sorted indices so that every replica gets
            the same number of samples; otherwise pad by repeating indices from the head.
        shuffle: Stored as ``self.shuffle``; currently has no effect on the produced order.
        lengths: Precomputed per-sample lengths (a list, or a ``torch.Tensor`` which is
            converted to a list).
        model_input_name: Key used to read the sequence from each dataset item when lengths
            are inferred. Defaults to ``"input_ids"``.

    Raises:
        ValueError: If neither ``dataset`` nor ``lengths`` is given, if ``rank`` is outside
            ``[0, num_replicas - 1]``, or if lengths cannot be inferred from the dataset items.
        RuntimeError: If ``num_replicas`` / ``rank`` must be inferred but the distributed
            package is not available.
    """

    # Copied and adapted from PyTorch DistributedSampler.
    def __init__(
        self,
        dataset: Optional[Dataset] = None,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        drop_last: bool = False,
        shuffle: bool = False,
        lengths: Optional[List[int]] = None,
        model_input_name: Optional[str] = None,
    ):
        if dataset is None and lengths is None:
            raise ValueError("One of dataset and lengths must be provided.")
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        # Fail early with a clear message instead of a later assertion failure or a
        # surprising slice produced by a negative / too-large start offset.
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]"
            )

        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        self.shuffle = shuffle

        if lengths is None:
            model_input_name = model_input_name if model_input_name is not None else "input_ids"
            first_item = dataset[0]
            if (
                not isinstance(first_item, (dict, BatchEncoding))
                or model_input_name not in first_item
            ):
                raise ValueError(
                    "Can only automatically infer lengths for datasets whose items are dictionaries with an "
                    f"'{model_input_name}' key."
                )
            # Walks the whole dataset once; tqdm only provides a progress bar.
            lengths = [len(feature[model_input_name]) for feature in tqdm(dataset)]
        elif isinstance(lengths, torch.Tensor):
            print(
                "If lengths is a torch.Tensor, DistributedLengthGroupedSampler will be slow. Converting lengths to"
                " List[int]..."
            )
            lengths = lengths.tolist()
        self.lengths = lengths

        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.lengths) % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil((len(self.lengths) - self.num_replicas) / self.num_replicas)
        else:
            self.num_samples = math.ceil(len(self.lengths) / self.num_replicas)
        self.total_size = self.num_samples * self.num_replicas

    def get_bucket_indices(self, lengths: List[int]) -> List[int]:
        """Return the sample indices ordered by descending length.

        The sort is stable, so samples of equal length keep their original relative order.

        Args:
            lengths: Length of each sample; position ``i`` holds the length of sample ``i``.

        Returns:
            Indices into ``lengths``, longest sample first.
        """
        ordered_pairs = sorted(enumerate(lengths), key=lambda pair: pair[1], reverse=True)
        return [index for index, _ in ordered_pairs]

    def __iter__(self) -> Iterator:
        """Yield this rank's share of the length-sorted indices.

        Returns:
            An iterator over exactly ``self.num_samples`` dataset indices.
        """
        # Deterministic, length-sorted order (no shuffling is applied).
        indices = self.get_bucket_indices(self.lengths)

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size > 0:
                # The padding can exceed the number of available indices (dataset smaller
                # than the padding needed), so repeat the list as often as necessary.
                repeats = math.ceil(padding_size / len(indices))
                indices += (indices * repeats)[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[: self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)