Feature/training #9

Open
wants to merge 22 commits into base: master
17 changes: 9 additions & 8 deletions configs/model/transformer-base.yaml
@@ -11,13 +11,14 @@ model_params:
dropout: 0.1

data_params:
num_workers: 8
num_workers: 8

train_hparams:
batch_size: 25000 # by num_tokens
warmup_steps: 4000
optimizer: adam
beta_1: 0.9 # activates when optimizer is Adam
beta_2: 0.98 # activates when optimizer is Adam
eps: 1e-9 # FIXME: LayerNorm eps?
steps: 100000
batch_size: 6250 # by num_tokens, paper: 25000 tokens
steps: 400000 # paper: 100000, increased since batch_size was decreased
warmup_steps: 4000
optimizer: adam
beta_1: 0.9 # activates when optimizer is Adam
beta_2: 0.98 # activates when optimizer is Adam
eps: 1e-9 # activates when optimizer is Adam
label_smoothing_eps: 0.1
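
The smaller batch_size works together with the lazy optimizer update added in src/model/transformer.py further down; a minimal arithmetic check, using only the values from this config:

# Sanity check: 6250-token batches with lazily applied optimizer steps
# reproduce the paper's schedule of 100k updates at 25k tokens each.
batch_size = 6250                  # tokens per batch (this config)
paper_tokens_per_step = 25000
steps = 400000                     # training steps (this config)

update_period = paper_tokens_per_step // batch_size   # 4 -> one optimizer step every 4 batches
optimizer_updates = steps // update_period             # 100000 -> matches the paper
print(update_period, optimizer_updates)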
1 change: 1 addition & 0 deletions configs/model/transformer-big.yaml
@@ -20,6 +20,7 @@ train_hparams:
beta_1: 0.9 # activates when optimizer is Adam
beta_2: 0.98 # activates when optimizer is Adam
eps: 1e-9
label_smoothing_eps: 0.1
steps: 300000

langpair: en-fr
9,990 changes: 9,990 additions & 0 deletions corpus/wmt14.example.train.norm.de

Large diffs are not rendered by default.

9,990 changes: 9,990 additions & 0 deletions corpus/wmt14.example.train.norm.en

Large diffs are not rendered by default.

7 changes: 5 additions & 2 deletions requirements.txt
@@ -1,4 +1,7 @@
omegaconf>=2.0.0
tokenizers>=0.8.1
tokenizers>=0.9.0
tqdm>=4.48.0
pytorch-lightning>=0.8.5
pytorch-lightning==1.0.0
wandb>=0.10.5
sacrebleu>=1.4.14
datasets>=1.1.2
40 changes: 40 additions & 0 deletions src/bleu.py
@@ -0,0 +1,40 @@
import torch
from torch import Tensor
from datasets import load_metric

from src.utils import Config, load_tokenizer


class BLEU():
"""Calculate BLEU"""

def __init__(self, langpair: str):
self.configs = Config()
self.tokenizer = load_tokenizer(langpair)
self.metric = load_metric("sacrebleu")

def compute(self, target_hat_indices: Tensor, target_indices: Tensor, target_lengths: Tensor) -> float:
"""
Args:
target_hat_indices: model predictions in indices (batch_size, vocab_size, max_len)
target_indices: reference sentences in indices (batch_size, max_len)
target_lengths: reference sentence lengths, including bos and eos
"""
pred_indices = torch.argmax(target_hat_indices, dim=1)
target_hat_sentences, target_sentences = [], []
for i in range(target_indices.size(0)):
real_length = target_lengths[i] - 2 # remove eos and bos
real_length = real_length.int()
pred = pred_indices[i][:real_length]
target = target_indices[i][:real_length]
target_hat_sentence = self.tokenizer.decode(pred.cpu().numpy().tolist())
target_sentence = self.tokenizer.decode(target.cpu().numpy().tolist())
target_hat_sentences.append(target_hat_sentence)
target_sentences.append([target_sentence]) # sacrebleu expects references to be a list of lists
print('')
print(f'pred: {target_hat_sentence}')
print(f'ref: {target_sentence}')
self.metric.add_batch(predictions=target_hat_sentences, references=target_sentences)
score = self.metric.compute()
return score['score']
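
A standalone sketch of the reference format sacrebleu expects through datasets.load_metric, mirroring the list-of-lists wrapping above (the example sentence is made up):

from datasets import load_metric

metric = load_metric("sacrebleu")
predictions = ["the cat sat on the mat"]
references = [["the cat sat on the mat"]]   # one inner list of references per prediction
metric.add_batch(predictions=predictions, references=references)
print(metric.compute()["score"])            # 100.0 for an exact match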

18 changes: 10 additions & 8 deletions src/dataloader.py
@@ -45,7 +45,7 @@ def setup(self, stage: Optional[str] = None) -> None:

def batch_by_tokens(
self, dataset: Dataset, max_tokens: Optional[int] = None
) -> List[torch.Tensor]:
) -> List[List]:
"""Create mini-batch tensors by number of tokens

Args:
@@ -59,29 +59,31 @@ def batch_by_tokens(
indices_batches:
"""
max_tokens = (
25000 if max_tokens is None else self.configs.model.train_hparams.batch_size
self.configs.model.train_hparams.batch_size
if max_tokens is None
else max_tokens
)

start_idx = 0
source_sample_lens, target_sample_lens = [], []
indices_batches = []
for end_idx in range(len(dataset)):
source_sample_lens.append(dataset[end_idx]["source"]["length"])
target_sample_lens.append(dataset[end_idx]["target"]["length"])
source_sample_lens.append(dataset[end_idx]["source"]["padded_token"].size(0))
target_sample_lens.append(dataset[end_idx]["target"]["padded_token"].size(0))
# when batch is full
if (
sum(source_sample_lens) > max_tokens
or sum(target_sample_lens) > max_tokens
):
indices_batch = torch.arange(start_idx, end_idx)
indices_batch = torch.arange(start_idx, end_idx).tolist()
indices_batches.append(indices_batch)
start_idx = end_idx
source_sample_lens, target_sample_lens = [source_sample_lens[-1]], [
target_sample_lens[-1]
] # end_idx is not included
# when iteration ends
elif end_idx == len(dataset):
indices_batch = torch.arange(start_idx, end_idx)
elif end_idx == len(dataset) - 1:
indices_batch = torch.arange(start_idx, end_idx).tolist()
indices_batches.append(indices_batch)
return indices_batches
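
The index lists returned here are consumed as a batch_sampler in train_dataloader/val_dataloader below; a toy sketch of that wiring (dataset contents and batch boundaries are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(10))
indices_batches = [[0, 1, 2], [3, 4], [5, 6, 7, 8, 9]]   # e.g. the output of batch_by_tokens
loader = DataLoader(dataset, batch_sampler=indices_batches)
for (batch,) in loader:
    print(batch)   # tensors of varying batch size, one per index list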

@@ -96,7 +98,7 @@ def train_dataloader(self) -> DataLoader:
num_workers=self.configs.model.data_params.num_workers,
)

def valid_dataloader(self) -> DataLoader:
def val_dataloader(self) -> DataLoader:
batch_sampler = self.batch_by_tokens(self.valid_dataset)
return DataLoader(
self.valid_dataset,
30 changes: 30 additions & 0 deletions src/loss.py
@@ -0,0 +1,30 @@
import torch
from torch import nn, Tensor
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss


class LabelSmoothingLoss(_Loss):

def __init__(self, num_classes: int, epsilon: float) -> None:
super().__init__()
self.num_classes = num_classes
self.epsilon = epsilon
self.confidence = torch.tensor(1 - self.epsilon)
self.regularization = torch.tensor(self.epsilon / self.num_classes)

def forward(self, target_hat: Tensor, target: Tensor, ignore_index: int = -100) -> Tensor:
"""
Args:
target_hat: model prediction (batch_size, vocab_size, max_len)
target: vocab index (batch_size, max_len)
ignore_index: index whose target positions do not contribute to the loss
"""
target_hat = target_hat.log_softmax(dim=1)
with torch.no_grad():
true = F.one_hot(target, num_classes=self.num_classes).transpose(1, 2).float() # (batch_size, vocab_size, max_len)
true *= self.confidence
true += self.regularization
true.masked_fill_(target.unsqueeze(dim=1) == ignore_index, 0) # to ignore ignore_index prediction
loss = torch.mean(torch.sum(-true * target_hat, dim=1))
return loss
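
A quick numeric check of the smoothed target distribution built above (num_classes and the target index are toy values):

import torch
from torch.nn import functional as F

num_classes, epsilon = 5, 0.1
target = torch.tensor([[2]])                            # (batch_size=1, max_len=1)
true = F.one_hot(target, num_classes=num_classes).transpose(1, 2).float()
true = true * (1 - epsilon) + epsilon / num_classes     # smooth the one-hot distribution
print(true.squeeze())   # tensor([0.0200, 0.0200, 0.9200, 0.0200, 0.0200]) -- sums to 1.0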
2 changes: 1 addition & 1 deletion src/model/decoder.py
@@ -30,7 +30,7 @@ def __init__(self, is_base: bool = True):

def attention_mask(self, batch_size: int, seq_len: int) -> Tensor:
attention_shape = (batch_size, seq_len, seq_len)
attention_mask = np.triu(np.ones(attention_shape), k=1).astype("unit8")
attention_mask = np.triu(np.ones(attention_shape), k=1).astype("uint8")
attention_mask = torch.from_numpy(attention_mask) == 0
return attention_mask # (batch_size, seq_len, seq_len)
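
For reference, the mask this produces for a toy seq_len of 3 (batch_size 1); positions above the diagonal come out False and are masked in the attention:

import numpy as np
import torch

attention_mask = np.triu(np.ones((1, 3, 3)), k=1).astype("uint8")
print(torch.from_numpy(attention_mask) == 0)
# tensor([[[ True, False, False],
#          [ True,  True, False],
#          [ True,  True,  True]]])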

37 changes: 14 additions & 23 deletions src/model/modules.py
@@ -11,6 +11,7 @@ class PositionalEncoding(nn.Module):
"""PositionalEncoding

Attributes:
batch_size: batch size of the input
max_len: maximum length of the tokens
embedding_dim: embedding dimension of the given token
"""
@@ -30,14 +31,16 @@ def __init__(self, max_len: int, embedding_dim: int, is_base: bool = True) -> None:
)
positional_encoding[:, 0::2] = torch.sin(position / div_term)
positional_encoding[:, 1::2] = torch.cos(position / div_term)
positional_encoding = positional_encoding.unsqueeze(0).transpose(
0, 1
) # (max_len, 1, embedding_dim)
self.register_buffer(
"positional_encoding", positional_encoding
) # TODO: register_buffer?
positional_encoding = positional_encoding.unsqueeze(
0
) # (1, max_len, embedding_dim)
self.register_buffer("positional_encoding", positional_encoding)

def forward(self, embeddings: Tensor) -> Tensor:
batch_size = embeddings.size(0)
positional_encoding = self.positional_encoding.repeat(batch_size, 1, 1)  # use a local copy, keep the buffer intact
if torch.cuda.is_available():
positional_encoding = positional_encoding.to('cuda')
embeddings = (
embeddings + positional_encoding
) # (batch_size, max_len, embedding_dim)
@@ -63,8 +66,6 @@ def __init__(self, langpair: str, is_base: bool = True) -> None:
self.vocab_size, self.dim_model, padding_idx=padding_idx
)
self.scale = self.dim_model ** 0.5
self.max_len = configs.model.model_params.max_len
self.positional_encoding = PositionalEncoding(self.max_len, self.dim_model)

def forward(self, source_tokens: torch.Tensor) -> nn.Embedding:
"""Get embedding matrix for source tokens
@@ -79,9 +80,9 @@ def forward(self, source_tokens: torch.Tensor) -> nn.Embedding:
source_tokens
) # (batch_size, max_len, dim_model)
embeddings *= self.scale
embeddings = self.positional_encoding(
embeddings
) # (batch_size, max_len, dim_model)
_, max_len, dim_model = embeddings.size() # max_len varies with the batch
positional_encoding = PositionalEncoding(max_len, dim_model)
embeddings = positional_encoding(embeddings) # (batch_size, max_len, dim_model)
return embeddings


@@ -130,18 +131,6 @@ def forward(
value: value embedding (batch_size, max_len, dim_model)
attention_mask: used to implement masked_attention (batch_size, max_len, max_len)
"""
if self.masked_attention:
assert (
key == value
), "masked self-attention requires key, and value to be of the same"
assert (
attention_mask is not None
), "masked self-attention requires attention mask"
else:
assert (
query == key == value
), "self-attention requires query, key, and value to be of the same"

q = self.q_project(query) # (batch_size, max_len, dim_q)
k = self.k_project(key) # (batch_size, max_len, dim_k)
v = self.v_project(value) # (batch_size, max_len, dim_v)
@@ -152,6 +141,8 @@ def forward(
qk = qk.masked_fill(qk == 0, self.config.model.train_hparams.eps)

if self.masked_attention:
if torch.cuda.is_available():
attention_mask = attention_mask.to('cuda')
qk = qk.masked_fill(
attention_mask == 0, self.config.model.train_hparams.eps
)
96 changes: 88 additions & 8 deletions src/model/transformer.py
@@ -1,24 +1,35 @@
import numpy as np
import torch
from pytorch_lightning import LightningModule
from torch import Tensor, nn
from torch.nn import functional as F

from src.model.decoder import Decoder
from src.model.encoder import Encoder
from src.utils import Config
from src.utils import Config, load_tokenizer
from src.loss import LabelSmoothingLoss
from src.bleu import BLEU


class Transformer(nn.Module):
class Transformer(LightningModule):
"""Transformer Model"""

def __init__(self, langpair: str, is_base: bool = True) -> None:
super().__init__()
configs = Config()
configs.add_tokenizer(langpair)
configs.add_model(is_base)
dim_model: int = configs.model.model_params.dim_model
vocab_size = configs.tokenizer.vocab_size
self.configs = Config()
self.configs.add_tokenizer(langpair)
self.configs.add_model(is_base)
self.dim_model: int = self.configs.model.model_params.dim_model
vocab_size = self.configs.tokenizer.vocab_size
tokenizer = load_tokenizer(langpair)
self.padding_idx = tokenizer.token_to_id("<pad>")

self.encoder = Encoder(langpair)
self.decoder = Decoder(langpair)
self.linear = nn.Linear(dim_model, vocab_size)
self.linear = nn.Linear(self.dim_model, vocab_size)

self.loss_with_label_smoothing = LabelSmoothingLoss(vocab_size, self.configs.model.train_hparams.label_smoothing_eps)
self.bleu = BLEU(langpair)

def forward(
self,
@@ -33,3 +44,72 @@ def forward(
)
output = self.linear(target_emb)
return output

def training_step(self, batch, batch_idx):
source = batch["source"]
target = batch["target"]
target_hat = self(
source["padded_token"],
source["mask"],
target["padded_token"],
target["mask"],
) # (batch_size, max_len, vocab_size)
target_hat.transpose_(1, 2) # (batch_size, vocab_size, max_len)
target["padded_token"] = target["padded_token"][
:, 1:
] # remove <bos> from target
target_hat = target_hat[:, :, :-1] # match shape with target
loss = self.loss_with_label_smoothing(target_hat, target['padded_token'], ignore_index=self.padding_idx)
perplexity = torch.exp(loss)
bleu = self.bleu.compute(target_hat, target['padded_token'], target['length'])
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True) # pytorch lightning logger
self.log('train_bleu', bleu, on_step=True, on_epoch=True, prog_bar=True, logger=True) # pytorch lightning logger
self.logger.experiment.log({'train_loss': loss, 'train_perplexity': perplexity, 'train_bleu': bleu})
return loss
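
A shape-only sketch of the alignment performed above (all sizes are toy values): the prediction drops its last position and the target drops <bos>, so both cover the same max_len - 1 steps.

import torch

batch_size, vocab_size, max_len = 2, 10, 5
target_hat = torch.randn(batch_size, max_len, vocab_size).transpose(1, 2)   # (2, 10, 5)
target = torch.randint(0, vocab_size, (batch_size, max_len))                # (2, 5)
target = target[:, 1:]               # drop <bos>             -> (2, 4)
target_hat = target_hat[:, :, :-1]   # drop the last position -> (2, 10, 4)
print(target_hat.shape, target.shape)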

def validation_step(self, batch, batch_idx):
source = batch["source"]
target = batch["target"]
target_hat = self(
source["padded_token"],
source["mask"],
target["padded_token"],
target["mask"],
) # (batch_size, max_len, vocab_size)
target_hat.transpose_(1, 2) # (batch_size, vocab_size, max_len)
target["padded_token"] = target["padded_token"][
:, 1:
] # remove <bos> from target
target_hat = target_hat[:, :, :-1] # match shape with target
loss = self.loss_with_label_smoothing(target_hat, target['padded_token'], ignore_index=self.padding_idx)
perplexity = torch.exp(loss)
bleu = self.bleu.compute(target_hat, target['padded_token'], target['length'])
self.log('valid_loss', loss)
self.log('valid_bleu', bleu)
self.logger.experiment.log({'valid_loss': loss, 'valid_perplexity': perplexity, 'valid_bleu': bleu})
return loss

def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(),
betas=(
self.configs.model.train_hparams.beta_1,
self.configs.model.train_hparams.beta_2,
),
eps=self.configs.model.train_hparams.eps,
)
return optimizer

def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
# update the optimizer lazily, e.g. every 2 steps if batch_size is 12500 tokens (paper: 25000 tokens per step)
update_period = 25000 // self.configs.model.train_hparams.batch_size
warmup_steps = self.configs.model.train_hparams.warmup_steps
global_step = self.trainer.global_step
if global_step % update_period == 0:
step_num = global_step // update_period # lazy update
lr_scale = np.min([np.power(step_num, -0.5), step_num * np.power(warmup_steps, -1.5)])
lr = lr_scale * np.power(self.dim_model, -0.5)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
optimizer.step()
optimizer.zero_grad()
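
The rule above is the warmup schedule from the Transformer paper, lr = dim_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5); a minimal check of the peak value, assuming the base model's dim_model of 512:

import numpy as np

def noam_lr(step_num, dim_model=512, warmup_steps=4000):
    lr_scale = np.min([np.power(step_num, -0.5), step_num * np.power(warmup_steps, -1.5)])
    return lr_scale * np.power(dim_model, -0.5)

print(noam_lr(4000))   # ~0.000699, the peak learning rate reached at the end of warmup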
5 changes: 3 additions & 2 deletions src/utils.py
@@ -1,6 +1,7 @@
from pathlib import Path, PosixPath
from typing import List

import torch
from omegaconf import DictConfig, OmegaConf
from tokenizers import SentencePieceBPETokenizer

@@ -54,8 +55,8 @@ def load_tokenizer(langpair: str) -> SentencePieceBPETokenizer:
)

tokenizer = SentencePieceBPETokenizer(
vocab_file=str(vocab_filepath),
merges_file=str(merges_filepath),
vocab=str(vocab_filepath),
merges=str(merges_filepath),
)
return tokenizer
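
The keyword rename here accompanies the tokenizers>=0.9.0 bump in requirements.txt, where the constructor arguments become vocab/merges; a minimal sketch with hypothetical file paths (guarded so it only constructs the tokenizer if the files exist):

from pathlib import Path
from tokenizers import SentencePieceBPETokenizer

# hypothetical paths -- load_tokenizer resolves the real ones from the config
vocab_path, merges_path = Path("vocab.json"), Path("merges.txt")
if vocab_path.exists() and merges_path.exists():
    tokenizer = SentencePieceBPETokenizer(vocab=str(vocab_path), merges=str(merges_path))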
