# collator.py
from dataclasses import dataclass, field
from transformers import (
DataCollator,
)
from typing import Dict, List, Optional
import torch
# prepares lm_labels from target_ids, returns examples with keys as expected by the forward method
# this is necessacry because the trainer directly passes this dict as arguments to the model
# so make sure the keys match the parameter names of the forward method
@dataclass
class T2TDataCollator:
    """Collate a list of dataset examples into a batch of model inputs.

    Builds `labels` from each example's `target_ids`, replacing padding
    positions with -100 so the loss computation ignores them (presumably the
    model's loss uses the standard ignore_index=-100 convention — confirm
    against the model in use). The returned dict's keys must match the
    parameter names of the model's `forward` method, because the Trainer
    passes this dict directly as keyword arguments.
    """

    # Token id used for padding inside `target_ids`. Defaults to 0, which
    # preserves the original hard-coded behavior.
    pad_token_id: int = 0

    def __call__(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.

        Each sample must be a dict with 1-D tensors of equal length under the
        keys 'input_ids', 'attention_mask', 'target_ids', and
        'target_attention_mask'.

        Returns:
            A dictionary of stacked tensors keyed for the model's forward
            method: 'input_ids', 'attention_mask', 'labels', and
            'decoder_attention_mask'.
        """
        input_ids = torch.stack([example['input_ids'] for example in batch])
        attention_mask = torch.stack([example['attention_mask'] for example in batch])
        lm_labels = torch.stack([example['target_ids'] for example in batch])
        # Mask padding so the loss function skips these positions.
        lm_labels[lm_labels == self.pad_token_id] = -100
        decoder_attention_mask = torch.stack(
            [example['target_attention_mask'] for example in batch]
        )
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': lm_labels,
            'decoder_attention_mask': decoder_attention_mask
        }