diff --git a/samples/gnmt/README.md b/samples/gnmt/README.md new file mode 100644 index 0000000..c2c761b --- /dev/null +++ b/samples/gnmt/README.md @@ -0,0 +1,17 @@ +# GNMT (Google Neural Machine Translation) Model + +This directory contains an implementation of GNMT that was adapted from the +code found in the [MLPerf training repository](https://github.com/mlperf/training/tree/master/rnn_translator). + +To launch an interactive Skyline profiling session for GNMT, run +``` +skyline interactive entry_point.py +``` + + +## License + +This code, with the exception of the `skyline_` prefixed functions in +`entry_point.py`, was adapted from the MLPerf training benchmarks and therefore +shares the same license. The unmodified license can be found in the `LICENSE` +file in the `seq2seq` directory. diff --git a/samples/gnmt/entry_point.py b/samples/gnmt/entry_point.py new file mode 100644 index 0000000..fb328d9 --- /dev/null +++ b/samples/gnmt/entry_point.py @@ -0,0 +1,327 @@ +import argparse +from ast import literal_eval + +import torch +import torch.nn as nn +import torch.optim +import torch.distributed as dist + +import seq2seq.data.config as config +import seq2seq.utils as utils +from seq2seq.models.gnmt import GNMT +from seq2seq.train.fp_optimizers import Fp32Optimizer +from seq2seq.train.lr_scheduler import WarmupMultiStepLR +from seq2seq.train.smoothing import LabelSmoothing + +torch.backends.cudnn.benchmark = True + + +def get_args(): + def exclusive_group(group, name, default, help): + destname = name.replace('-', '_') + subgroup = group.add_mutually_exclusive_group(required=False) + subgroup.add_argument(f'--{name}', dest=f'{destname}', + action='store_true', + help=f'{help} (use \'--no-{name}\' to disable)') + subgroup.add_argument(f'--no-{name}', dest=f'{destname}', + action='store_false', help=argparse.SUPPRESS) + subgroup.set_defaults(**{destname: default}) + + parser = argparse.ArgumentParser( + description='GNMT training', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + # dataset + dataset = parser.add_argument_group('dataset setup') + dataset.add_argument('--dataset-dir', default='data/wmt16_de_en', + help='path to the directory with training/test data') + dataset.add_argument('--max-size', default=None, type=int, + help='use at most MAX_SIZE elements from training \ + dataset (useful for benchmarking), by default \ + uses entire dataset') + + # results + results = parser.add_argument_group('results setup') + results.add_argument('--results-dir', default='results', + help='path to directory with results, it will be \ + automatically created if it does not exist') + results.add_argument('--save', default='gnmt', + help='defines subdirectory within RESULTS_DIR for \ + results from this training run') + results.add_argument('--print-freq', default=10, type=int, + help='print log every PRINT_FREQ batches') + + # model + model = parser.add_argument_group('model setup') + model.add_argument('--hidden-size', default=1024, type=int, + help='model hidden size') + model.add_argument('--num-layers', default=4, type=int, + help='number of RNN layers in encoder and in decoder') + model.add_argument('--dropout', default=0.2, type=float, + help='dropout applied to input of RNN cells') + + exclusive_group(group=model, name='share-embedding', default=True, + help='use shared embeddings for encoder and decoder') + + model.add_argument('--smoothing', default=0.1, type=float, + help='label smoothing, if equal to zero model will use \ + CrossEntropyLoss, if not zero model will be trained \ + with label smoothing loss') + + # setup + general = parser.add_argument_group('general setup') + general.add_argument('--math', default='fp32', choices=['fp16', 'fp32'], + help='arithmetic type') + general.add_argument('--seed', default=None, type=int, + help='master seed for random number generators, if \ + "seed" is undefined then the master seed will be \ + sampled from random.SystemRandom()') + + exclusive_group(group=general, name='eval', default=True, + help='run validation and test after every epoch') + exclusive_group(group=general, name='env', default=False, + help='print info about execution env') + exclusive_group(group=general, name='cuda', default=True, + help='enables cuda') + exclusive_group(group=general, name='cudnn', default=True, + help='enables cudnn') + + # training + training = parser.add_argument_group('training setup') + training.add_argument('--train-batch-size', default=128, type=int, + help='training batch size per worker') + training.add_argument('--train-global-batch-size', default=None, type=int, + help='global training batch size, this argument \ + does not have to be defined, if it is defined it \ + will be used to automatically \ + compute train_iter_size \ + using the equation: train_iter_size = \ + train_global_batch_size // (train_batch_size * \ + world_size)') + training.add_argument('--train-iter-size', metavar='N', default=1, + type=int, + help='training iter size, training loop will \ + accumulate gradients over N iterations and execute \ + optimizer every N steps') + training.add_argument('--epochs', default=8, type=int, + help='max number of training epochs') + + training.add_argument('--grad-clip', default=5.0, type=float, + help='enables gradient clipping and sets maximum \ + norm of gradients') + training.add_argument('--max-length-train', default=50, type=int, + help='maximum sequence length for training \ + (including special BOS and EOS tokens)') + training.add_argument('--min-length-train', default=0, type=int, + help='minimum sequence length for training \ + (including special BOS and EOS tokens)') + training.add_argument('--train-loader-workers', default=2, type=int, + help='number of workers for training data loading') + training.add_argument('--batching', default='bucketing', type=str, + choices=['random', 'sharding', 'bucketing'], + help='select batching algorithm') + training.add_argument('--shard-size', default=80, type=int, + help='shard size for "sharding" batching algorithm, \ + in multiples of global batch size') + training.add_argument('--num-buckets', default=5, type=int, + help='number of buckets for "bucketing" batching \ + algorithm') + + # optimizer + optimizer = parser.add_argument_group('optimizer setup') + optimizer.add_argument('--optimizer', type=str, default='Adam', + help='training optimizer') + optimizer.add_argument('--lr', type=float, default=1.00e-3, + help='learning rate') + + # scheduler + scheduler = parser.add_argument_group('learning rate scheduler setup') + scheduler.add_argument('--warmup-steps', type=str, default='200', + help='number of learning rate warmup iterations') + scheduler.add_argument('--remain-steps', type=str, default='0.666', + help='starting iteration for learning rate decay') + scheduler.add_argument('--decay-interval', type=str, default='None', + help='interval between learning rate decay steps') + scheduler.add_argument('--decay-steps', type=int, default=4, + help='max number of learning rate decay steps') + scheduler.add_argument('--decay-factor', type=float, default=0.5, + help='learning rate decay factor') + + # validation + val = parser.add_argument_group('validation setup') + val.add_argument('--val-batch-size', default=64, type=int, + help='batch size for validation') + val.add_argument('--max-length-val', default=125, type=int, + help='maximum sequence length for validation \ + (including special BOS and EOS tokens)') + val.add_argument('--min-length-val', default=0, type=int, + help='minimum sequence length for validation \ + (including special BOS and EOS tokens)') + val.add_argument('--val-loader-workers', default=0, type=int, + help='number of workers for validation data loading') + + # test + test = parser.add_argument_group('test setup') + test.add_argument('--test-batch-size', default=128, type=int, + help='batch size for test') + test.add_argument('--max-length-test', default=150, type=int, + help='maximum sequence length for test \ + (including special BOS and EOS tokens)') + test.add_argument('--min-length-test', default=0, type=int, + help='minimum sequence length for test \ + (including special BOS and EOS tokens)') + test.add_argument('--beam-size', default=5, type=int, + help='beam size') + test.add_argument('--len-norm-factor', default=0.6, type=float, + help='length normalization factor') + test.add_argument('--cov-penalty-factor', default=0.1, type=float, + help='coverage penalty factor') + test.add_argument('--len-norm-const', default=5.0, type=float, + help='length normalization constant') + test.add_argument('--intra-epoch-eval', metavar='N', default=0, type=int, + help='evaluate within training epoch, this option will \ + enable extra N equally spaced evaluations executed \ + during each training epoch') + test.add_argument('--test-loader-workers', default=0, type=int, + help='number of workers for test data loading') + + # checkpointing + chkpt = parser.add_argument_group('checkpointing setup') + chkpt.add_argument('--start-epoch', default=0, type=int, + help='manually set initial epoch counter') + chkpt.add_argument('--resume', default=None, type=str, metavar='PATH', + help='resumes training from checkpoint from PATH') + chkpt.add_argument('--save-all', action='store_true', default=False, + help='saves checkpoint after every epoch') + chkpt.add_argument('--save-freq', default=5000, type=int, + help='save checkpoint every SAVE_FREQ batches') + chkpt.add_argument('--keep-checkpoints', default=0, type=int, + help='keep only last KEEP_CHECKPOINTS checkpoints, \ + affects only checkpoints controlled by --save-freq \ + option') + + # benchmarking + benchmark = parser.add_argument_group('benchmark setup') + benchmark.add_argument('--target-bleu', default=24.0, type=float, + help='target accuracy, training will be stopped \ + when the target is achieved') + + # distributed + distributed = parser.add_argument_group('distributed setup') + distributed.add_argument('--rank', default=0, type=int, + help='global rank of the process, do not set!') + distributed.add_argument('--local_rank', default=0, type=int, + help='local rank of the process, do not set!') + + args = parser.parse_args([]) + + args.warmup_steps = literal_eval(args.warmup_steps) + args.remain_steps = literal_eval(args.remain_steps) + args.decay_interval = literal_eval(args.decay_interval) + + return args + + +def build_criterion(vocab_size, padding_idx, smoothing): + if smoothing == 0.: + loss_weight = torch.ones(vocab_size) + loss_weight[padding_idx] = 0 + criterion = nn.CrossEntropyLoss(weight=loss_weight, size_average=False) + else: + criterion = LabelSmoothing(padding_idx, smoothing) + + return criterion + + +class GNMTWithLoss(nn.Module): + def __init__(self, gnmt, loss_fn): + super().__init__() + self.gnmt = gnmt + self.loss_fn = loss_fn + + def forward(self, src, src_len, tgt, tgt_len): + out = self.gnmt(src, src_len, tgt[:-1]) + T, B = out.size(0), out.size(1) + tgt_labels = tgt[1:] + loss = self.loss_fn( + out.view(T * B, -1), + tgt_labels.contiguous().view(-1)) + return loss / B + + +def skyline_model_provider(): + args = get_args() + vocab_size = 32317 + model_config = { + 'hidden_size': args.hidden_size, + 'num_layers': args.num_layers, + 'dropout': args.dropout, + 'batch_first': False, + 'share_embedding': args.share_embedding, + } + model = GNMTWithLoss( + GNMT(vocab_size=vocab_size, **model_config), + build_criterion(vocab_size, config.PAD, args.smoothing), + ).cuda() + model.zero_grad() + return model + + +def skyline_input_provider(batch_size=64): + vocab_size = 32000 + src_len = 25 + tgt_len = 25 + + device = torch.device('cuda') + + src = torch.randint( + low=0, + high=vocab_size, + size=(src_len, batch_size), + dtype=torch.int64, + device=device, + ) + tgt = torch.randint( + low=0, + high=vocab_size, + size=(tgt_len, batch_size), + dtype=torch.int64, + device=device, + ) + + src_len_tensor = torch.tensor( + [src_len] * batch_size, dtype=torch.int64, device=device) + tgt_len_tensor = torch.tensor( + [tgt_len] * batch_size, dtype=torch.int64, device=device) + + return src, src_len_tensor, tgt, tgt_len_tensor + + +def skyline_iteration_provider(model): + args = get_args() + opt_config = { + 'optimizer': args.optimizer, + 'lr': args.lr, + } + scheduler_config = { + 'warmup_steps': args.warmup_steps, + 'remain_steps': args.remain_steps, + 'decay_interval': args.decay_interval, + 'decay_steps': args.decay_steps, + 'decay_factor': args.decay_factor, + } + + train_loader_len = 437268 + total_train_iters = train_loader_len // args.train_iter_size * args.epochs + opt_name = opt_config.pop('optimizer') + optimizer = torch.optim.__dict__[opt_name](model.parameters(), **opt_config) + scheduler = WarmupMultiStepLR( + optimizer, total_train_iters, **scheduler_config) + fp_optimizer = Fp32Optimizer(model, args.grad_clip) + + def iteration(src, src_len, tgt, tgt_len): + loss = model(src, src_len, tgt, tgt_len) + loss.backward() + fp_optimizer.step(optimizer, scheduler, update=True) + + return iteration diff --git a/samples/gnmt/seq2seq/LICENSE b/samples/gnmt/seq2seq/LICENSE new file mode 100644 index 0000000..4343c76 --- /dev/null +++ b/samples/gnmt/seq2seq/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2017 Elad Hoffer +Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/samples/gnmt/seq2seq/data/config.py b/samples/gnmt/seq2seq/data/config.py new file mode 100644 index 0000000..0582e04 --- /dev/null +++ b/samples/gnmt/seq2seq/data/config.py @@ -0,0 +1,32 @@ +PAD_TOKEN = '' +UNK_TOKEN = '' +BOS_TOKEN = '' +EOS_TOKEN = '<\s>' + +# special PAD, UNKNOWN, BEGIN-OF-STRING, END-OF-STRING tokens +PAD, UNK, BOS, EOS = [0, 1, 2, 3] + +# path to the BPE vocabulary file, relative to the data directory, it should +# point to file generated by subword-nmt/get_vocab.py +VOCAB_FNAME = 'vocab.bpe.32000' + +# paths to source and target training files, relative to the data directory, it +# should point to BPE-encoded files, generated by subword-nmt/apply_bpe.py +SRC_TRAIN_FNAME = 'train.tok.clean.bpe.32000.en' +TGT_TRAIN_FNAME = 'train.tok.clean.bpe.32000.de' + +# paths to source and target validation files, relative to the data directory, +# it should point to BPE-encoded files, generated by subword-nmt/apply_bpe.py +SRC_VAL_FNAME = 'newstest_dev.tok.clean.bpe.32000.en' +TGT_VAL_FNAME = 'newstest_dev.tok.clean.bpe.32000.de' + +# path to the test source file, relative to the data directory, it should point +# to BPE-encoded file, generated by subword-nmt/apply_bpe.py +SRC_TEST_FNAME = 'newstest2014.tok.bpe.32000.en' + +# path to the test target file, relative to the data directory, it should point +# to plaintext file, tokenization is performed by the sacrebleu package +TGT_TEST_TARGET_FNAME = 'newstest2014.de' + +# path to the moses detokenizer, relative to the data directory +DETOKENIZER = 'mosesdecoder/scripts/tokenizer/detokenizer.perl' diff --git a/samples/gnmt/seq2seq/data/dataset.py b/samples/gnmt/seq2seq/data/dataset.py new file mode 100644 index 0000000..5d63884 --- /dev/null +++ b/samples/gnmt/seq2seq/data/dataset.py @@ -0,0 +1,385 @@ +import logging +from operator import itemgetter + +import torch +from torch.utils.data import DataLoader +from torch.utils.data import Dataset + +import seq2seq.data.config as config +from seq2seq.data.sampler import BucketingSampler +from seq2seq.data.sampler import DistributedSampler +from seq2seq.data.sampler import ShardingSampler +from seq2seq.data.sampler import StaticDistributedSampler + + +def build_collate_fn(batch_first=False, parallel=True, sort=False): + """ + Factory for collate_fn functions. + + :param batch_first: if True returns batches in (batch, seq) format, if + False returns in (seq, batch) format + :param parallel: if True builds batches from parallel corpus (src, tgt) + :param sort: if True sorts by src sequence length within each batch + """ + def collate_seq(seq): + """ + Builds batches for training or inference. + Batches are returned as pytorch tensors, with padding. + + :param seq: list of sequences + """ + lengths = [len(s) for s in seq] + batch_length = max(lengths) + + shape = (batch_length, len(seq)) + seq_tensor = torch.full(shape, config.PAD, dtype=torch.int64) + + for i, s in enumerate(seq): + end_seq = lengths[i] + seq_tensor[:end_seq, i].copy_(s[:end_seq]) + + if batch_first: + seq_tensor = seq_tensor.t() + + return (seq_tensor, lengths) + + def parallel_collate(seqs): + """ + Builds batches from parallel dataset (src, tgt), optionally sorts batch + by src sequence length. + + :param seqs: tuple of (src, tgt) sequences + """ + src_seqs, tgt_seqs = zip(*seqs) + if sort: + indices, src_seqs = zip(*sorted(enumerate(src_seqs), + key=lambda item: len(item[1]), + reverse=True)) + tgt_seqs = [tgt_seqs[idx] for idx in indices] + + return tuple([collate_seq(s) for s in [src_seqs, tgt_seqs]]) + + def single_collate(src_seqs): + """ + Builds batches from text dataset, optionally sorts batch by src + sequence length. + + :param src_seqs: source sequences + """ + if sort: + indices, src_seqs = zip(*sorted(enumerate(src_seqs), + key=lambda item: len(item[1]), + reverse=True)) + else: + indices = range(len(src_seqs)) + + return collate_seq(src_seqs), tuple(indices) + + if parallel: + return parallel_collate + else: + return single_collate + + +class TextDataset(Dataset): + def __init__(self, src_fname, tokenizer, min_len=None, max_len=None, + sort=False, max_size=None): + """ + Constructor for the TextDataset. Builds monolingual dataset. + + :param src_fname: path to the file with data + :param tokenizer: tokenizer + :param min_len: minimum sequence length + :param max_len: maximum sequence length + :param sort: sorts dataset by sequence length + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + + self.min_len = min_len + self.max_len = max_len + self.parallel = False + self.sorted = False + + self.src = self.process_data(src_fname, tokenizer, max_size) + + if min_len is not None and max_len is not None: + self.filter_data(min_len, max_len) + + lengths = [len(s) for s in self.src] + self.lengths = torch.tensor(lengths) + + if sort: + self.sort_by_length() + + def sort_by_length(self): + """ + Sorts dataset by the sequence length. + """ + self.lengths, indices = self.lengths.sort(descending=True) + + self.src = [self.src[idx] for idx in indices] + self.indices = indices.tolist() + self.sorted = True + + def unsort(self, array): + """ + "Unsorts" given array (restores original order of elements before + dataset was sorted by sequence length). + + :param array: array to be "unsorted" + """ + if self.sorted: + inverse = sorted(enumerate(self.indices), key=itemgetter(1)) + array = [array[i[0]] for i in inverse] + return array + + def filter_data(self, min_len, max_len): + """ + Preserves only samples which satisfy the following inequality: + min_len <= sample sequence length <= max_len + + :param min_len: minimum sequence length + :param max_len: maximum sequence length + """ + logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}') + + initial_len = len(self.src) + filtered_src = [] + for src in self.src: + if min_len <= len(src) <= max_len: + filtered_src.append(src) + + self.src = filtered_src + filtered_len = len(self.src) + logging.info(f'Pairs before: {initial_len}, after: {filtered_len}') + + def process_data(self, fname, tokenizer, max_size): + """ + Loads data from the input file. + + :param fname: input file name + :param tokenizer: tokenizer + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + logging.info(f'Processing data from {fname}') + data = [] + with open(fname) as dfile: + for idx, line in enumerate(dfile): + if max_size and idx == max_size: + break + entry = tokenizer.segment(line) + entry = torch.tensor(entry) + data.append(entry) + return data + + def __len__(self): + return len(self.src) + + def __getitem__(self, idx): + return self.src[idx] + + def get_loader(self, batch_size=1, seeds=None, shuffle=False, + num_workers=0, batch_first=False, pad=False, + batching=None, batching_opt={}): + + collate_fn = build_collate_fn(batch_first, parallel=self.parallel, + sort=True) + + if shuffle: + if batching == 'random': + sampler = DistributedSampler(self, batch_size, seeds) + elif batching == 'sharding': + sampler = ShardingSampler(self, batch_size, seeds, + batching_opt['shard_size']) + elif batching == 'bucketing': + sampler = BucketingSampler(self, batch_size, seeds, + batching_opt['num_buckets']) + else: + raise NotImplementedError + else: + sampler = StaticDistributedSampler(self, batch_size, pad) + + return DataLoader(self, + batch_size=batch_size, + collate_fn=collate_fn, + sampler=sampler, + num_workers=num_workers, + pin_memory=True, + drop_last=False) + + +class ParallelDataset(TextDataset): + def __init__(self, src_fname, tgt_fname, tokenizer, + min_len, max_len, sort=False, max_size=None): + """ + Constructor for the ParallelDataset. + Tokenization is done when the data is loaded from the disk. + + :param src_fname: path to the file with src language data + :param tgt_fname: path to the file with tgt language data + :param tokenizer: tokenizer + :param min_len: minimum sequence length + :param max_len: maximum sequence length + :param sort: sorts dataset by sequence length + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + + self.min_len = min_len + self.max_len = max_len + self.parallel = True + self.sorted = False + + self.src = self.process_data(src_fname, tokenizer, max_size) + self.tgt = self.process_data(tgt_fname, tokenizer, max_size) + assert len(self.src) == len(self.tgt) + + self.filter_data(min_len, max_len) + assert len(self.src) == len(self.tgt) + + src_lengths = [len(s) for s in self.src] + tgt_lengths = [len(t) for t in self.tgt] + self.src_lengths = torch.tensor(src_lengths) + self.tgt_lengths = torch.tensor(tgt_lengths) + self.lengths = self.src_lengths + self.tgt_lengths + + if sort: + self.sort_by_length() + + def sort_by_length(self): + """ + Sorts dataset by the sequence length. + """ + self.lengths, indices = self.lengths.sort(descending=True) + + self.src = [self.src[idx] for idx in indices] + self.tgt = [self.tgt[idx] for idx in indices] + self.src_lengths = [self.src_lengths[idx] for idx in indices] + self.tgt_lengths = [self.tgt_lengths[idx] for idx in indices] + self.indices = indices.tolist() + self.sorted = True + + def filter_data(self, min_len, max_len): + """ + Preserves only samples which satisfy the following inequality: + min_len <= src sample sequence length <= max_len AND + min_len <= tgt sample sequence length <= max_len + + :param min_len: minimum sequence length + :param max_len: maximum sequence length + """ + logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}') + + initial_len = len(self.src) + filtered_src = [] + filtered_tgt = [] + for src, tgt in zip(self.src, self.tgt): + if min_len <= len(src) <= max_len and \ + min_len <= len(tgt) <= max_len: + filtered_src.append(src) + filtered_tgt.append(tgt) + + self.src = filtered_src + self.tgt = filtered_tgt + filtered_len = len(self.src) + logging.info(f'Pairs before: {initial_len}, after: {filtered_len}') + + def __getitem__(self, idx): + return self.src[idx], self.tgt[idx] + + +class LazyParallelDataset(TextDataset): + def __init__(self, src_fname, tgt_fname, tokenizer, + min_len, max_len, sort=False, max_size=None): + """ + Constructor for the LazyParallelDataset. + Tokenization is done on the fly. + + :param src_fname: path to the file with src language data + :param tgt_fname: path to the file with tgt language data + :param tokenizer: tokenizer + :param min_len: minimum sequence length + :param max_len: maximum sequence length + :param sort: sorts dataset by sequence length + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + self.min_len = min_len + self.max_len = max_len + self.parallel = True + self.sorted = False + self.tokenizer = tokenizer + + self.raw_src = self.process_raw_data(src_fname, max_size) + self.raw_tgt = self.process_raw_data(tgt_fname, max_size) + assert len(self.raw_src) == len(self.raw_tgt) + + logging.info(f'Filtering data, min len: {min_len}, max len: {max_len}') + # Subtracting 2 because EOS and BOS are added later during tokenization + self.filter_raw_data(min_len - 2, max_len - 2) + assert len(self.raw_src) == len(self.raw_tgt) + + # Adding 2 because EOS and BOS are added later during tokenization + src_lengths = [i + 2 for i in self.src_len] + tgt_lengths = [i + 2 for i in self.tgt_len] + self.src_lengths = torch.tensor(src_lengths) + self.tgt_lengths = torch.tensor(tgt_lengths) + self.lengths = self.src_lengths + self.tgt_lengths + + def process_raw_data(self, fname, max_size): + """ + Loads data from the input file. + + :param fname: input file name + :param max_size: loads at most 'max_size' samples from the input file, + if None loads the entire dataset + """ + logging.info(f'Processing data from {fname}') + data = [] + with open(fname) as dfile: + for idx, line in enumerate(dfile): + if max_size and idx == max_size: + break + data.append(line) + return data + + def filter_raw_data(self, min_len, max_len): + """ + Preserves only samples which satisfy the following inequality: + min_len <= src sample sequence length <= max_len AND + min_len <= tgt sample sequence length <= max_len + + :param min_len: minimum sequence length + :param max_len: maximum sequence length + """ + initial_len = len(self.raw_src) + filtered_src = [] + filtered_tgt = [] + filtered_src_len = [] + filtered_tgt_len = [] + for src, tgt in zip(self.raw_src, self.raw_tgt): + src_len = src.count(' ') + 1 + tgt_len = tgt.count(' ') + 1 + if min_len <= src_len <= max_len and \ + min_len <= tgt_len <= max_len: + filtered_src.append(src) + filtered_tgt.append(tgt) + filtered_src_len.append(src_len) + filtered_tgt_len.append(tgt_len) + + self.raw_src = filtered_src + self.raw_tgt = filtered_tgt + self.src_len = filtered_src_len + self.tgt_len = filtered_tgt_len + filtered_len = len(self.raw_src) + logging.info(f'Pairs before: {initial_len}, after: {filtered_len}') + + def __getitem__(self, idx): + src = torch.tensor(self.tokenizer.segment(self.raw_src[idx])) + tgt = torch.tensor(self.tokenizer.segment(self.raw_tgt[idx])) + return src, tgt + + def __len__(self): + return len(self.raw_src) diff --git a/samples/gnmt/seq2seq/data/sampler.py b/samples/gnmt/seq2seq/data/sampler.py new file mode 100644 index 0000000..204b560 --- /dev/null +++ b/samples/gnmt/seq2seq/data/sampler.py @@ -0,0 +1,278 @@ +import logging + +import torch +from torch.utils.data.sampler import Sampler + +from seq2seq.utils import get_rank +from seq2seq.utils import get_world_size +from seq2seq.utils import gnmt_print + + +class DistributedSampler(Sampler): + def __init__(self, dataset, batch_size, seeds, world_size=None, rank=None): + """ + Constructor for the DistributedSampler. + + :param dataset: dataset + :param batch_size: local batch size + :param seeds: list of seeds, one seed for each training epoch + :param world_size: number of distributed workers + :param rank: rank of the current process + """ + if world_size is None: + world_size = get_world_size() + if rank is None: + rank = get_rank() + + self.dataset = dataset + self.world_size = world_size + self.rank = rank + self.epoch = 0 + self.seeds = seeds + + self.batch_size = batch_size + self.global_batch_size = batch_size * world_size + + self.data_len = len(self.dataset) + + self.num_samples = self.data_len // self.global_batch_size \ + * self.global_batch_size + + def init_rng(self): + """ + Creates new RNG, seed depends on current epoch idx. + """ + rng = torch.Generator() + seed = self.seeds[self.epoch] + logging.info(f'Sampler for epoch {self.epoch} uses seed {seed}') + rng.manual_seed(seed) + return rng + + def distribute_batches(self, indices): + """ + Assigns batches to workers. + Consecutive ranks are getting consecutive batches. + + :param indices: torch.tensor with batch indices + """ + assert len(indices) == self.num_samples + + indices = indices.view(-1, self.batch_size) + indices = indices[self.rank::self.world_size].contiguous() + indices = indices.view(-1) + indices = indices.tolist() + + assert len(indices) == self.num_samples // self.world_size + return indices + + def reshuffle_batches(self, indices, rng): + """ + Permutes global batches + + :param indices: torch.tensor with batch indices + :param rng: instance of torch.Generator + """ + indices = indices.view(-1, self.global_batch_size) + num_batches = indices.shape[0] + order = torch.randperm(num_batches, generator=rng) + indices = indices[order, :] + indices = indices.view(-1) + return indices + + def __iter__(self): + rng = self.init_rng() + # generate permutation + indices = torch.randperm(self.data_len, generator=rng) + + # make indices evenly divisible by (batch_size * world_size) + indices = indices[:self.num_samples] + + # assign batches to workers + indices = self.distribute_batches(indices) + return iter(indices) + + def set_epoch(self, epoch): + """ + Sets current epoch index. + Epoch index is used to seed RNG in __iter__() function. + + :param epoch: index of current epoch + """ + self.epoch = epoch + + def __len__(self): + return self.num_samples // self.world_size + + +class ShardingSampler(DistributedSampler): + def __init__(self, dataset, batch_size, seeds, shard_size, + world_size=None, rank=None): + """ + Constructor for the ShardingSampler. + + :param dataset: dataset + :param batch_size: local batch size + :param seeds: list of seeds, one seed for each training epoch + :param shard_size: number of global batches within one shard + :param world_size: number of distributed workers + :param rank: rank of the current process + """ + + super().__init__(dataset, batch_size, seeds, world_size, rank) + + self.shard_size = shard_size + self.num_samples = self.data_len // self.global_batch_size \ + * self.global_batch_size + + def __iter__(self): + rng = self.init_rng() + # generate permutation + indices = torch.randperm(self.data_len, generator=rng) + # make indices evenly divisible by (batch_size * world_size) + indices = indices[:self.num_samples] + + # splits the dataset into chunks of 'self.shard_size' global batches + # each, sorts by (src + tgt) sequence length within each chunk, + # reshuffles all global batches + shard_size = self.global_batch_size * self.shard_size + nshards = (self.num_samples + shard_size - 1) // shard_size + + lengths = self.dataset.lengths[indices] + + shards = [indices[i * shard_size:(i+1) * shard_size] for i in range(nshards)] + len_shards = [lengths[i * shard_size:(i+1) * shard_size] for i in range(nshards)] + + # sort by (src + tgt) sequence length within each shard + indices = [] + for len_shard in len_shards: + _, ind = len_shard.sort() + indices.append(ind) + + output = tuple(shard[idx] for shard, idx in zip(shards, indices)) + + # build batches + indices = torch.cat(output) + # perform global reshuffle of all global batches + indices = self.reshuffle_batches(indices, rng) + # distribute batches to individual workers + indices = self.distribute_batches(indices) + return iter(indices) + + +class BucketingSampler(DistributedSampler): + def __init__(self, dataset, batch_size, seeds, num_buckets, + world_size=None, rank=None): + """ + Constructor for the BucketingSampler. + + :param dataset: dataset + :param batch_size: local batch size + :param seeds: list of seeds, one seed for each training epoch + :param num_buckets: number of buckets + :param world_size: number of distributed workers + :param rank: rank of the current process + """ + + super().__init__(dataset, batch_size, seeds, world_size, rank) + + self.num_buckets = num_buckets + bucket_width = (dataset.max_len + num_buckets - 1) // num_buckets + + # assign sentences to buckets based on src and tgt sequence lengths + bucket_ids = torch.max(dataset.src_lengths // bucket_width, + dataset.tgt_lengths // bucket_width) + bucket_ids.clamp_(0, num_buckets - 1) + + # build buckets + all_indices = torch.tensor(range(self.data_len)) + self.buckets = [] + self.num_samples = 0 + global_bs = self.global_batch_size + + for bid in range(num_buckets): + # gather indices for current bucket + indices = all_indices[bucket_ids == bid] + self.buckets.append(indices) + + # count number of samples in current bucket + samples = len(indices) // global_bs * global_bs + self.num_samples += samples + + def __iter__(self): + rng = self.init_rng() + global_bs = self.global_batch_size + + indices = [] + for bid in range(self.num_buckets): + # random shuffle within current bucket + perm = torch.randperm(len(self.buckets[bid]), generator=rng) + bucket_indices = self.buckets[bid][perm] + + # make bucket_indices evenly divisible by global batch size + length = len(bucket_indices) // global_bs * global_bs + bucket_indices = bucket_indices[:length] + assert len(bucket_indices) % self.global_batch_size == 0 + + # add samples from current bucket to indices for current epoch + indices.append(bucket_indices) + + indices = torch.cat(indices) + assert len(indices) % self.global_batch_size == 0 + + # perform global reshuffle of all global batches + indices = self.reshuffle_batches(indices, rng) + # distribute batches to individual workers + indices = self.distribute_batches(indices) + return iter(indices) + + +class StaticDistributedSampler(Sampler): + def __init__(self, dataset, batch_size, pad, world_size=None, rank=None): + """ + Constructor for the StaticDistributedSampler. + + :param dataset: dataset + :param batch_size: local batch size + :param pad: if True: pads dataset to a multiple of global_batch_size + samples + :param world_size: number of distributed workers + :param rank: rank of the current process + """ + if world_size is None: + world_size = get_world_size() + if rank is None: + rank = get_rank() + + self.world_size = world_size + + global_batch_size = batch_size * world_size + + data_len = len(dataset) + num_samples = (data_len + global_batch_size - 1) \ + // global_batch_size * global_batch_size + self.num_samples = num_samples + + indices = list(range(data_len)) + if pad: + # pad dataset to a multiple of global_batch_size samples, uses + # sample with idx 0 as pad + indices += [0] * (num_samples - len(indices)) + else: + # temporary pad to a multiple of global batch size, pads with "-1" + # which is later removed from the list of indices + indices += [-1] * (num_samples - len(indices)) + indices = torch.tensor(indices) + + indices = indices.view(-1, batch_size) + indices = indices[rank::world_size].contiguous() + indices = indices.view(-1) + # remove temporary pad + indices = indices[indices != -1] + indices = indices.tolist() + self.indices = indices + + def __iter__(self): + return iter(self.indices) + + def __len__(self): + return len(self.indices) diff --git a/samples/gnmt/seq2seq/data/tokenizer.py b/samples/gnmt/seq2seq/data/tokenizer.py new file mode 100644 index 0000000..b2d9549 --- /dev/null +++ b/samples/gnmt/seq2seq/data/tokenizer.py @@ -0,0 +1,105 @@ +import logging +from collections import defaultdict +from functools import partial + +import seq2seq.data.config as config + + +class Tokenizer: + """ + Tokenizer class. + """ + def __init__(self, vocab_fname=None, pad=1, separator='@@'): + """ + Constructor for the Tokenizer class. + + :param vocab_fname: path to the file with vocabulary + :param pad: pads vocabulary to a multiple of 'pad' tokens + :param separator: tokenization separator + """ + if vocab_fname: + self.separator = separator + + logging.info(f'Building vocabulary from {vocab_fname}') + vocab = [config.PAD_TOKEN, config.UNK_TOKEN, + config.BOS_TOKEN, config.EOS_TOKEN] + + with open(vocab_fname) as vfile: + for line in vfile: + vocab.append(line.strip()) + + self.pad_vocabulary(vocab, pad) + + self.vocab_size = len(vocab) + logging.info(f'Size of vocabulary: {self.vocab_size}') + + self.tok2idx = defaultdict(partial(int, config.UNK)) + for idx, token in enumerate(vocab): + self.tok2idx[token] = idx + + self.idx2tok = {} + for key, value in self.tok2idx.items(): + self.idx2tok[value] = key + + def pad_vocabulary(self, vocab, pad): + """ + Pads vocabulary to a multiple of 'pad' tokens. + + :param vocab: list with vocabulary + :param pad: integer + """ + vocab_size = len(vocab) + padded_vocab_size = (vocab_size + pad - 1) // pad * pad + for i in range(0, padded_vocab_size - vocab_size): + token = f'madeupword{i:04d}' + vocab.append(token) + assert len(vocab) % pad == 0 + + def get_state(self): + logging.info(f'Saving state of the tokenizer') + state = { + 'separator': self.separator, + 'vocab_size': self.vocab_size, + 'tok2idx': self.tok2idx, + 'idx2tok': self.idx2tok, + } + return state + + def set_state(self, state): + logging.info(f'Restoring state of the tokenizer') + self.separator = state['separator'] + self.vocab_size = state['vocab_size'] + self.tok2idx = state['tok2idx'] + self.idx2tok = state['idx2tok'] + + def segment(self, line): + """ + Tokenizes single sentence and adds special BOS and EOS tokens. + + :param line: sentence + + returns: list representing tokenized sentence + """ + line = line.strip().split() + entry = [self.tok2idx[i] for i in line] + entry = [config.BOS] + entry + [config.EOS] + return entry + + def detokenize(self, inputs, delim=' '): + """ + Detokenizes single sentence and removes token separator characters. + + :param inputs: sequence of tokens + :param delim: tokenization delimiter + + returns: string representing detokenized sentence + """ + detok = delim.join([self.idx2tok[idx] for idx in inputs]) + detok = detok.replace(self.separator + ' ', '') + detok = detok.replace(self.separator, '') + + detok = detok.replace(config.BOS_TOKEN, '') + detok = detok.replace(config.EOS_TOKEN, '') + detok = detok.replace(config.PAD_TOKEN, '') + detok = detok.strip() + return detok diff --git a/samples/gnmt/seq2seq/inference/beam_search.py b/samples/gnmt/seq2seq/inference/beam_search.py new file mode 100644 index 0000000..11338b1 --- /dev/null +++ b/samples/gnmt/seq2seq/inference/beam_search.py @@ -0,0 +1,291 @@ +import torch + +from seq2seq.data.config import BOS +from seq2seq.data.config import EOS +from seq2seq.utils import gnmt_print + + +class SequenceGenerator: + """ + Generator for the autoregressive inference with beam search decoding. + """ + def __init__(self, model, beam_size=5, max_seq_len=100, cuda=False, + len_norm_factor=0.6, len_norm_const=5, + cov_penalty_factor=0.1): + """ + Constructor for the SequenceGenerator. + + Beam search decoding supports coverage penalty and length + normalization. For details, refer to Section 7 of the GNMT paper + (https://arxiv.org/pdf/1609.08144.pdf). + + :param model: model which implements generate method + :param beam_size: decoder beam size + :param max_seq_len: maximum decoder sequence length + :param cuda: whether to use cuda + :param len_norm_factor: length normalization factor + :param len_norm_const: length normalization constant + :param cov_penalty_factor: coverage penalty factor + """ + + self.model = model + self.cuda = cuda + self.beam_size = beam_size + self.max_seq_len = max_seq_len + self.len_norm_factor = len_norm_factor + self.len_norm_const = len_norm_const + self.cov_penalty_factor = cov_penalty_factor + + self.batch_first = self.model.batch_first + + def greedy_search(self, batch_size, initial_input, initial_context=None): + """ + Greedy decoder. + + :param batch_size: decoder batch size + :param initial_input: initial input, usually tensor of BOS tokens + :param initial_context: initial context, usually [encoder_context, + src_seq_lengths, None] + + returns: (translation, lengths, counter) + translation: (batch_size, max_seq_len) - indices of target tokens + lengths: (batch_size) - lengths of generated translations + counter: number of iterations of the decoding loop + """ + max_seq_len = self.max_seq_len + + translation = torch.zeros(batch_size, max_seq_len, dtype=torch.int64) + lengths = torch.ones(batch_size, dtype=torch.int64) + active = torch.arange(0, batch_size, dtype=torch.int64) + base_mask = torch.arange(0, batch_size, dtype=torch.int64) + + if self.cuda: + translation = translation.cuda() + lengths = lengths.cuda() + active = active.cuda() + base_mask = base_mask.cuda() + + translation[:, 0] = BOS + words, context = initial_input, initial_context + + if self.batch_first: + word_view = (-1, 1) + ctx_batch_dim = 0 + else: + word_view = (1, -1) + ctx_batch_dim = 1 + + counter = 0 + for idx in range(1, max_seq_len): + if not len(active): + break + counter += 1 + + words = words.view(word_view) + output = self.model.generate(words, context, 1) + words, logprobs, attn, context = output + words = words.view(-1) + + translation[active, idx] = words + lengths[active] += 1 + + terminating = (words == EOS) + + if terminating.any(): + not_terminating = ~terminating + + mask = base_mask[:len(active)] + mask = mask.masked_select(not_terminating) + active = active.masked_select(not_terminating) + + words = words[mask] + context[0] = context[0].index_select(ctx_batch_dim, mask) + context[1] = context[1].index_select(0, mask) + context[2] = context[2].index_select(1, mask) + + return translation, lengths, counter + + def beam_search(self, batch_size, initial_input, initial_context=None): + """ + Beam search decoder. + + :param batch_size: decoder batch size + :param initial_input: initial input, usually tensor of BOS tokens + :param initial_context: initial context, usually [encoder_context, + src_seq_lengths, None] + + returns: (translation, lengths, counter) + translation: (batch_size, max_seq_len) - indices of target tokens + lengths: (batch_size) - lengths of generated translations + counter: number of iterations of the decoding loop + """ + beam_size = self.beam_size + norm_const = self.len_norm_const + norm_factor = self.len_norm_factor + max_seq_len = self.max_seq_len + cov_penalty_factor = self.cov_penalty_factor + + translation = torch.zeros(batch_size * beam_size, max_seq_len, + dtype=torch.int64) + lengths = torch.ones(batch_size * beam_size, dtype=torch.int64) + scores = torch.zeros(batch_size * beam_size, dtype=torch.float32) + + active = torch.arange(0, batch_size * beam_size, dtype=torch.int64) + base_mask = torch.arange(0, batch_size * beam_size, dtype=torch.int64) + global_offset = torch.arange(0, batch_size * beam_size, beam_size, + dtype=torch.int64) + + eos_beam_fill = torch.tensor([0] + (beam_size - 1) * [float('-inf')]) + + if self.cuda: + translation = translation.cuda() + lengths = lengths.cuda() + active = active.cuda() + base_mask = base_mask.cuda() + scores = scores.cuda() + global_offset = global_offset.cuda() + eos_beam_fill = eos_beam_fill.cuda() + + translation[:, 0] = BOS + + words, context = initial_input, initial_context + + if self.batch_first: + word_view = (-1, 1) + ctx_batch_dim = 0 + attn_query_dim = 1 + else: + word_view = (1, -1) + ctx_batch_dim = 1 + attn_query_dim = 0 + + # replicate context + if self.batch_first: + # context[0] (encoder state): (batch, seq, feature) + _, seq, feature = context[0].shape + context[0].unsqueeze_(1) + context[0] = context[0].expand(-1, beam_size, -1, -1) + context[0] = context[0].contiguous().view(batch_size * beam_size, + seq, feature) + # context[0]: (batch * beam, seq, feature) + else: + # context[0] (encoder state): (seq, batch, feature) + seq, _, feature = context[0].shape + context[0].unsqueeze_(2) + context[0] = context[0].expand(-1, -1, beam_size, -1) + context[0] = context[0].contiguous().view(seq, batch_size * + beam_size, feature) + # context[0]: (seq, batch * beam, feature) + + # context[1] (encoder seq length): (batch) + context[1].unsqueeze_(1) + context[1] = context[1].expand(-1, beam_size) + context[1] = context[1].contiguous().view(batch_size * beam_size) + # context[1]: (batch * beam) + + accu_attn_scores = torch.zeros(batch_size * beam_size, seq) + if self.cuda: + accu_attn_scores = accu_attn_scores.cuda() + + counter = 0 + for idx in range(1, self.max_seq_len): + if not len(active): + break + counter += 1 + + eos_mask = (words == EOS) + eos_mask = eos_mask.view(-1, beam_size) + + terminating, _ = eos_mask.min(dim=1) + + lengths[active[~eos_mask.view(-1)]] += 1 + + output = self.model.generate(words, context, beam_size) + words, logprobs, attn, context = output + + attn = attn.float().squeeze(attn_query_dim) + attn = attn.masked_fill(eos_mask.view(-1).unsqueeze(1), 0) + accu_attn_scores[active] += attn + + # words: (batch, beam, k) + words = words.view(-1, beam_size, beam_size) + words = words.masked_fill(eos_mask.unsqueeze(2), EOS) + + # logprobs: (batch, beam, k) + logprobs = logprobs.float().view(-1, beam_size, beam_size) + + if eos_mask.any(): + logprobs[eos_mask] = eos_beam_fill + + active_scores = scores[active].view(-1, beam_size) + # new_scores: (batch, beam, k) + new_scores = active_scores.unsqueeze(2) + logprobs + + if idx == 1: + new_scores[:, 1:, :].fill_(float('-inf')) + + new_scores = new_scores.view(-1, beam_size * beam_size) + # index: (batch, beam) + _, index = new_scores.topk(beam_size, dim=1) + source_beam = index / beam_size + + new_scores = new_scores.view(-1, beam_size * beam_size) + best_scores = torch.gather(new_scores, 1, index) + scores[active] = best_scores.view(-1) + + words = words.view(-1, beam_size * beam_size) + words = torch.gather(words, 1, index) + + # words: (1, batch * beam) + words = words.view(word_view) + + offset = global_offset[:source_beam.shape[0]] + source_beam += offset.unsqueeze(1) + + translation[active, :] = translation[active[source_beam.view(-1)], :] + translation[active, idx] = words.view(-1) + + lengths[active] = lengths[active[source_beam.view(-1)]] + + context[2] = context[2].index_select(1, source_beam.view(-1)) + + if terminating.any(): + not_terminating = ~terminating + not_terminating = not_terminating.unsqueeze(1) + not_terminating = not_terminating.expand(-1, beam_size).contiguous() + + normalization_mask = active.view(-1, beam_size)[terminating] + + # length normalization + norm = lengths[normalization_mask].float() + norm = (norm_const + norm) / (norm_const + 1.0) + norm = norm ** norm_factor + + scores[normalization_mask] /= norm + + # coverage penalty + penalty = accu_attn_scores[normalization_mask] + penalty = penalty.clamp(0, 1) + penalty = penalty.log() + penalty[penalty == float('-inf')] = 0 + penalty = penalty.sum(dim=-1) + + scores[normalization_mask] += cov_penalty_factor * penalty + + mask = base_mask[:len(active)] + mask = mask.masked_select(not_terminating.view(-1)) + + words = words.index_select(ctx_batch_dim, mask) + context[0] = context[0].index_select(ctx_batch_dim, mask) + context[1] = context[1].index_select(0, mask) + context[2] = context[2].index_select(1, mask) + + active = active.masked_select(not_terminating.view(-1)) + + scores = scores.view(batch_size, beam_size) + _, idx = scores.max(dim=1) + + translation = translation[idx + global_offset, :] + lengths = lengths[idx + global_offset] + + return translation, lengths, counter diff --git a/samples/gnmt/seq2seq/inference/inference.py b/samples/gnmt/seq2seq/inference/inference.py new file mode 100644 index 0000000..5ec3a4b --- /dev/null +++ b/samples/gnmt/seq2seq/inference/inference.py @@ -0,0 +1,289 @@ +import contextlib +import logging +import os +import subprocess +import time + +import torch +import torch.distributed as dist + +import seq2seq.data.config as config +from seq2seq.inference.beam_search import SequenceGenerator +from seq2seq.utils import AverageMeter +from seq2seq.utils import barrier +from seq2seq.utils import get_rank +from seq2seq.utils import get_world_size + + +def gather_predictions(preds): + world_size = get_world_size() + if world_size > 1: + all_preds = [preds.new(preds.size(0), preds.size(1)) for i in range(world_size)] + dist.all_gather(all_preds, preds) + preds = torch.cat(all_preds) + return preds + + +class Translator: + """ + Translator runs validation on test dataset, executes inference, optionally + computes BLEU score using sacrebleu. + """ + def __init__(self, + model, + tokenizer, + loader, + beam_size=5, + len_norm_factor=0.6, + len_norm_const=5.0, + cov_penalty_factor=0.1, + max_seq_len=50, + cuda=False, + print_freq=1, + dataset_dir=None, + save_path=None, + target_bleu=None): + + self.model = model + self.tokenizer = tokenizer + self.loader = loader + self.insert_target_start = [config.BOS] + self.insert_src_start = [config.BOS] + self.insert_src_end = [config.EOS] + self.batch_first = model.batch_first + self.cuda = cuda + self.beam_size = beam_size + self.print_freq = print_freq + self.dataset_dir = dataset_dir + self.target_bleu = target_bleu + self.save_path = save_path + + self.distributed = (get_world_size() > 1) + + self.generator = SequenceGenerator( + model=self.model, + beam_size=beam_size, + max_seq_len=max_seq_len, + cuda=cuda, + len_norm_factor=len_norm_factor, + len_norm_const=len_norm_const, + cov_penalty_factor=cov_penalty_factor) + + def build_eval_path(self, epoch, iteration): + """ + Appends index of the current epoch and index of the current iteration + to the name of the file with results. + + :param epoch: index of the current epoch + :param iteration: index of the current iteration + """ + if iteration is not None: + eval_fname = f'eval_epoch_{epoch}_iter_{iteration}' + else: + eval_fname = f'eval_epoch_{epoch}' + eval_path = os.path.join(self.save_path, eval_fname) + return eval_path + + def run(self, calc_bleu=True, epoch=None, iteration=None, eval_path=None, + summary=False, reference_path=None): + """ + Runs translation on test dataset. + + :param calc_bleu: if True compares results with reference and computes + BLEU score + :param epoch: index of the current epoch + :param iteration: index of the current iteration + :param eval_path: path to the file for saving results + :param summary: if True prints summary + :param reference_path: path to the file with reference translation + """ + if self.cuda: + test_bleu = torch.cuda.FloatTensor([0]) + break_training = torch.cuda.LongTensor([0]) + else: + test_bleu = torch.FloatTensor([0]) + break_training = torch.LongTensor([0]) + + if eval_path is None: + eval_path = self.build_eval_path(epoch, iteration) + detok_eval_path = eval_path + '.detok' + + with contextlib.suppress(FileNotFoundError): + os.remove(eval_path) + os.remove(detok_eval_path) + + rank = get_rank() + logging.info(f'Running evaluation on test set') + self.model.eval() + torch.cuda.empty_cache() + + output = self.evaluate(epoch, iteration, summary) + output = output[:len(self.loader.dataset)] + output = self.loader.dataset.unsort(output) + + if rank == 0: + with open(eval_path, 'a') as eval_file: + eval_file.writelines(output) + if calc_bleu: + self.run_detokenizer(eval_path) + test_bleu[0] = self.run_sacrebleu(detok_eval_path, reference_path) + if summary: + logging.info(f'BLEU on test dataset: {test_bleu[0]:.2f}') + + if self.target_bleu and test_bleu[0] >= self.target_bleu: + logging.info(f'Target accuracy reached') + break_training[0] = 1 + + barrier() + torch.cuda.empty_cache() + logging.info(f'Finished evaluation on test set') + + if self.distributed: + dist.broadcast(break_training, 0) + dist.broadcast(test_bleu, 0) + + return test_bleu[0].item(), break_training[0].item() + + def evaluate(self, epoch, iteration, summary): + """ + Runs evaluation on test dataset. + + :param epoch: index of the current epoch + :param iteration: index of the current iteration + :param summary: if True prints summary + """ + batch_time = AverageMeter(False) + tot_tok_per_sec = AverageMeter(False) + iterations = AverageMeter(False) + enc_seq_len = AverageMeter(False) + dec_seq_len = AverageMeter(False) + stats = {} + + output = [] + + for i, (src, indices) in enumerate(self.loader): + translate_timer = time.time() + src, src_length = src + + batch_size = self.loader.batch_size + global_batch_size = batch_size * get_world_size() + beam_size = self.beam_size + + bos = [self.insert_target_start] * (batch_size * beam_size) + bos = torch.LongTensor(bos) + if self.batch_first: + bos = bos.view(-1, 1) + else: + bos = bos.view(1, -1) + + src_length = torch.LongTensor(src_length) + stats['total_enc_len'] = int(src_length.sum()) + + if self.cuda: + src = src.cuda() + src_length = src_length.cuda() + bos = bos.cuda() + + with torch.no_grad(): + context = self.model.encode(src, src_length) + context = [context, src_length, None] + + if beam_size == 1: + generator = self.generator.greedy_search + else: + generator = self.generator.beam_search + preds, lengths, counter = generator(batch_size, bos, context) + + stats['total_dec_len'] = lengths.sum().item() + stats['iters'] = counter + + indices = torch.tensor(indices).to(preds) + preds = preds.scatter(0, indices.unsqueeze(1).expand_as(preds), preds) + + preds = gather_predictions(preds).cpu() + + for pred in preds: + pred = pred.tolist() + detok = self.tokenizer.detokenize(pred) + output.append(detok + '\n') + + elapsed = time.time() - translate_timer + batch_time.update(elapsed, batch_size) + + total_tokens = stats['total_dec_len'] + stats['total_enc_len'] + ttps = total_tokens / elapsed + tot_tok_per_sec.update(ttps, batch_size) + + iterations.update(stats['iters']) + enc_seq_len.update(stats['total_enc_len'] / batch_size, batch_size) + dec_seq_len.update(stats['total_dec_len'] / batch_size, batch_size) + + if i % self.print_freq == 0: + log = [] + log += f'TEST ' + if epoch is not None: + log += f'[{epoch}]' + if iteration is not None: + log += f'[{iteration}]' + log += f'[{i}/{len(self.loader)}]\t' + log += f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + log += f'Decoder iters {iterations.val:.1f} ({iterations.avg:.1f})\t' + log += f'Tok/s {tot_tok_per_sec.val:.0f} ({tot_tok_per_sec.avg:.0f})' + log = ''.join(log) + logging.info(log) + + tot_tok_per_sec.reduce('sum') + enc_seq_len.reduce('mean') + dec_seq_len.reduce('mean') + batch_time.reduce('mean') + iterations.reduce('sum') + + if summary and get_rank() == 0: + time_per_sentence = (batch_time.avg / global_batch_size) + log = [] + log += f'TEST SUMMARY:\n' + log += f'Lines translated: {len(self.loader.dataset)}\t' + log += f'Avg total tokens/s: {tot_tok_per_sec.avg:.0f}\n' + log += f'Avg time per batch: {batch_time.avg:.3f} s\t' + log += f'Avg time per sentence: {1000*time_per_sentence:.3f} ms\n' + log += f'Avg encoder seq len: {enc_seq_len.avg:.2f}\t' + log += f'Avg decoder seq len: {dec_seq_len.avg:.2f}\t' + log += f'Total decoder iterations: {int(iterations.sum)}' + log = ''.join(log) + logging.info(log) + + return output + + def run_detokenizer(self, eval_path): + """ + Executes moses detokenizer on eval_path file and saves result to + eval_path + ".detok" file. + + :param eval_path: path to the tokenized input + """ + logging.info('Running detokenizer') + detok_path = os.path.join(self.dataset_dir, config.DETOKENIZER) + detok_eval_path = eval_path + '.detok' + + with open(detok_eval_path, 'w') as detok_eval_file, \ + open(eval_path, 'r') as eval_file: + subprocess.run(['perl', f'{detok_path}'], stdin=eval_file, + stdout=detok_eval_file, stderr=subprocess.DEVNULL) + + def run_sacrebleu(self, detok_eval_path, reference_path): + """ + Executes sacrebleu and returns BLEU score. + + :param detok_eval_path: path to the test file + :param reference_path: path to the reference file + """ + if reference_path is None: + reference_path = os.path.join(self.dataset_dir, + config.TGT_TEST_TARGET_FNAME) + sacrebleu_params = '--score-only -lc --tokenize intl' + logging.info(f'Running sacrebleu (parameters: {sacrebleu_params})') + sacrebleu = subprocess.run([f'sacrebleu --input {detok_eval_path} \ + {reference_path} {sacrebleu_params}'], + stdout=subprocess.PIPE, shell=True) + test_bleu = float(sacrebleu.stdout.strip()) + return test_bleu diff --git a/samples/gnmt/seq2seq/models/attention.py b/samples/gnmt/seq2seq/models/attention.py new file mode 100644 index 0000000..230f1bc --- /dev/null +++ b/samples/gnmt/seq2seq/models/attention.py @@ -0,0 +1,164 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.parameter import Parameter + + +class BahdanauAttention(nn.Module): + """ + Bahdanau Attention (https://arxiv.org/abs/1409.0473) + Implementation is very similar to tf.contrib.seq2seq.BahdanauAttention + """ + def __init__(self, query_size, key_size, num_units, normalize=False, + batch_first=False, init_weight=0.1): + """ + Constructor for the BahdanauAttention. + + :param query_size: feature dimension for query + :param key_size: feature dimension for keys + :param num_units: internal feature dimension + :param normalize: whether to normalize energy term + :param batch_first: if True batch size is the 1st dimension, if False + the sequence is first and batch size is second + :param init_weight: range for uniform initializer used to initialize + Linear key and query transform layers and linear_att vector + """ + super(BahdanauAttention, self).__init__() + + self.normalize = normalize + self.batch_first = batch_first + self.num_units = num_units + + self.linear_q = nn.Linear(query_size, num_units, bias=False) + self.linear_k = nn.Linear(key_size, num_units, bias=False) + nn.init.uniform_(self.linear_q.weight.data, -init_weight, init_weight) + nn.init.uniform_(self.linear_k.weight.data, -init_weight, init_weight) + + self.linear_att = Parameter(torch.Tensor(num_units)) + + self.mask = None + + if self.normalize: + self.normalize_scalar = Parameter(torch.Tensor(1)) + self.normalize_bias = Parameter(torch.Tensor(num_units)) + else: + self.register_parameter('normalize_scalar', None) + self.register_parameter('normalize_bias', None) + + self.reset_parameters(init_weight) + + def reset_parameters(self, init_weight): + """ + Sets initial random values for trainable parameters. + """ + stdv = 1. / math.sqrt(self.num_units) + self.linear_att.data.uniform_(-init_weight, init_weight) + + if self.normalize: + self.normalize_scalar.data.fill_(stdv) + self.normalize_bias.data.zero_() + + def set_mask(self, context_len, context): + """ + sets self.mask which is applied before softmax + ones for inactive context fields, zeros for active context fields + + :param context_len: b + :param context: if batch_first: (b x t_k x n) else: (t_k x b x n) + + self.mask: (b x t_k) + """ + + if self.batch_first: + max_len = context.size(1) + else: + max_len = context.size(0) + + indices = torch.arange(0, max_len, dtype=torch.int64, + device=context.device) + self.mask = indices >= (context_len.unsqueeze(1)) + + def calc_score(self, att_query, att_keys): + """ + Calculate Bahdanau score + + :param att_query: b x t_q x n + :param att_keys: b x t_k x n + + returns: b x t_q x t_k scores + """ + + b, t_k, n = att_keys.size() + t_q = att_query.size(1) + + att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n) + att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n) + sum_qk = att_query + att_keys + + if self.normalize: + sum_qk = sum_qk + self.normalize_bias + linear_att = self.linear_att / self.linear_att.norm() + linear_att = linear_att * self.normalize_scalar + else: + linear_att = self.linear_att + + out = torch.tanh(sum_qk).matmul(linear_att) + return out + + def forward(self, query, keys): + """ + + :param query: if batch_first: (b x t_q x n) else: (t_q x b x n) + :param keys: if batch_first: (b x t_k x n) else (t_k x b x n) + + :returns: (context, scores_normalized) + context: if batch_first: (b x t_q x n) else (t_q x b x n) + scores_normalized: if batch_first (b x t_q x t_k) else (t_q x b x t_k) + """ + + # first dim of keys and query has to be 'batch', it's needed for bmm + if not self.batch_first: + keys = keys.transpose(0, 1) + if query.dim() == 3: + query = query.transpose(0, 1) + + if query.dim() == 2: + single_query = True + query = query.unsqueeze(1) + else: + single_query = False + + b = query.size(0) + t_k = keys.size(1) + t_q = query.size(1) + + # FC layers to transform query and key + processed_query = self.linear_q(query) + processed_key = self.linear_k(keys) + + # scores: (b x t_q x t_k) + scores = self.calc_score(processed_query, processed_key) + + if self.mask is not None: + mask = self.mask.unsqueeze(1).expand(b, t_q, t_k) + # I can't use -INF because of overflow check in pytorch + scores.data.masked_fill_(mask, -65504.0) + + # Normalize the scores, softmax over t_k + scores_normalized = F.softmax(scores, dim=-1) + + # Calculate the weighted average of the attention inputs according to + # the scores + # context: (b x t_q x n) + context = torch.bmm(scores_normalized, keys) + + if single_query: + context = context.squeeze(1) + scores_normalized = scores_normalized.squeeze(1) + elif not self.batch_first: + context = context.transpose(0, 1) + scores_normalized = scores_normalized.transpose(0, 1) + + return context, scores_normalized diff --git a/samples/gnmt/seq2seq/models/decoder.py b/samples/gnmt/seq2seq/models/decoder.py new file mode 100644 index 0000000..7e22a1c --- /dev/null +++ b/samples/gnmt/seq2seq/models/decoder.py @@ -0,0 +1,222 @@ +import itertools + +import torch +import torch.nn as nn + +import seq2seq.data.config as config +from seq2seq.models.attention import BahdanauAttention +from seq2seq.utils import init_lstm_ + + +class RecurrentAttention(nn.Module): + """ + LSTM wrapped with an attention module. + """ + def __init__(self, input_size=1024, context_size=1024, hidden_size=1024, + num_layers=1, batch_first=False, dropout=0.2, + init_weight=0.1): + """ + Constructor for the RecurrentAttention. + + :param input_size: number of features in input tensor + :param context_size: number of features in output from encoder + :param hidden_size: internal hidden size + :param num_layers: number of layers in LSTM + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param dropout: probability of dropout (on input to LSTM layer) + :param init_weight: range for the uniform initializer + """ + + super(RecurrentAttention, self).__init__() + + self.rnn = nn.LSTM(input_size, hidden_size, num_layers, bias=True, + batch_first=batch_first) + init_lstm_(self.rnn, init_weight) + + self.attn = BahdanauAttention(hidden_size, context_size, context_size, + normalize=True, batch_first=batch_first) + + self.dropout = nn.Dropout(dropout) + + def forward(self, inputs, hidden, context, context_len): + """ + Execute RecurrentAttention. + + :param inputs: tensor with inputs + :param hidden: hidden state for LSTM layer + :param context: context tensor from encoder + :param context_len: vector of encoder sequence lengths + + :returns (rnn_outputs, hidden, attn_output, attn_scores) + """ + # set attention mask, sequences have different lengths, this mask + # allows to include only valid elements of context in attention's + # softmax + self.attn.set_mask(context_len, context) + + inputs = self.dropout(inputs) + rnn_outputs, hidden = self.rnn(inputs, hidden) + attn_outputs, scores = self.attn(rnn_outputs, context) + + return rnn_outputs, hidden, attn_outputs, scores + + +class Classifier(nn.Module): + """ + Fully-connected classifier + """ + def __init__(self, in_features, out_features, init_weight=0.1): + """ + Constructor for the Classifier. + + :param in_features: number of input features + :param out_features: number of output features (size of vocabulary) + :param init_weight: range for the uniform initializer + """ + super(Classifier, self).__init__() + self.classifier = nn.Linear(in_features, out_features) + nn.init.uniform_(self.classifier.weight.data, -init_weight, init_weight) + nn.init.uniform_(self.classifier.bias.data, -init_weight, init_weight) + + def forward(self, x): + """ + Execute the classifier. + + :param x: output from decoder + """ + out = self.classifier(x) + return out + + +class ResidualRecurrentDecoder(nn.Module): + """ + Decoder with Embedding, LSTM layers, attention, residual connections and + optinal dropout. + + Attention implemented in this module is different than the attention + discussed in the GNMT arxiv paper. In this model the output from the first + LSTM layer of the decoder goes into the attention module, then the + re-weighted context is concatenated with inputs to all subsequent LSTM + layers in the decoder at the current timestep. + + Residual connections are enabled after 3rd LSTM layer, dropout is applied + on inputs to LSTM layers. + """ + def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2, + batch_first=False, embedder=None, init_weight=0.1): + """ + Constructor of the ResidualRecurrentDecoder. + + :param vocab_size: size of vocabulary + :param hidden_size: hidden size for LSMT layers + :param num_layers: number of LSTM layers + :param dropout: probability of dropout (on input to LSTM layers) + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param embedder: instance of nn.Embedding, if None constructor will + create new embedding layer + :param init_weight: range for the uniform initializer + """ + super(ResidualRecurrentDecoder, self).__init__() + + self.num_layers = num_layers + + self.att_rnn = RecurrentAttention(hidden_size, hidden_size, + hidden_size, num_layers=1, + batch_first=batch_first, + dropout=dropout) + + self.rnn_layers = nn.ModuleList() + for _ in range(num_layers - 1): + self.rnn_layers.append( + nn.LSTM(2 * hidden_size, hidden_size, num_layers=1, bias=True, + batch_first=batch_first)) + + for lstm in self.rnn_layers: + init_lstm_(lstm, init_weight) + + if embedder is not None: + self.embedder = embedder + else: + self.embedder = nn.Embedding(vocab_size, hidden_size, + padding_idx=config.PAD) + nn.init.uniform_(self.embedder.weight.data, -init_weight, init_weight) + + self.classifier = Classifier(hidden_size, vocab_size) + self.dropout = nn.Dropout(p=dropout) + + def init_hidden(self, hidden): + """ + Converts flattened hidden state (from sequence generator) into a tuple + of hidden states. + + :param hidden: None or flattened hidden state for decoder RNN layers + """ + if hidden is not None: + # per-layer chunks + hidden = hidden.chunk(self.num_layers) + # (h, c) chunks for LSTM layer + hidden = tuple(i.chunk(2) for i in hidden) + else: + hidden = [None] * self.num_layers + + self.next_hidden = [] + return hidden + + def append_hidden(self, h): + """ + Appends the hidden vector h to the list of internal hidden states. + + :param h: hidden vector + """ + if self.inference: + self.next_hidden.append(h) + + def package_hidden(self): + """ + Flattens the hidden state from all LSTM layers into one tensor (for + the sequence generator). + """ + if self.inference: + hidden = torch.cat(tuple(itertools.chain(*self.next_hidden))) + else: + hidden = None + return hidden + + def forward(self, inputs, context, inference=False): + """ + Execute the decoder. + + :param inputs: tensor with inputs to the decoder + :param context: state of encoder, encoder sequence lengths and hidden + state of decoder's LSTM layers + :param inference: if True stores and repackages hidden state + """ + self.inference = inference + + enc_context, enc_len, hidden = context + hidden = self.init_hidden(hidden) + + x = self.embedder(inputs) + + x, h, attn, scores = self.att_rnn(x, hidden[0], enc_context, enc_len) + self.append_hidden(h) + + x = torch.cat((x, attn), dim=2) + x = self.dropout(x) + x, h = self.rnn_layers[0](x, hidden[1]) + self.append_hidden(h) + + for i in range(1, len(self.rnn_layers)): + residual = x + x = torch.cat((x, attn), dim=2) + x = self.dropout(x) + x, h = self.rnn_layers[i](x, hidden[i + 1]) + self.append_hidden(h) + x = x + residual + + x = self.classifier(x) + hidden = self.package_hidden() + + return x, scores, [enc_context, enc_len, hidden] diff --git a/samples/gnmt/seq2seq/models/encoder.py b/samples/gnmt/seq2seq/models/encoder.py new file mode 100644 index 0000000..0b1944a --- /dev/null +++ b/samples/gnmt/seq2seq/models/encoder.py @@ -0,0 +1,95 @@ +import torch.nn as nn +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +import seq2seq.data.config as config +from seq2seq.utils import init_lstm_ + + +class ResidualRecurrentEncoder(nn.Module): + """ + Encoder with Embedding, LSTM layers, residual connections and optional + dropout. + + The first LSTM layer is bidirectional and uses variable sequence length + API, the remaining (num_layers-1) layers are unidirectional. Residual + connections are enabled after third LSTM layer, dropout is applied on + inputs to LSTM layers. + """ + def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2, + batch_first=False, embedder=None, init_weight=0.1): + """ + Constructor for the ResidualRecurrentEncoder. + + :param vocab_size: size of vocabulary + :param hidden_size: hidden size for LSTM layers + :param num_layers: number of LSTM layers, 1st layer is bidirectional + :param dropout: probability of dropout (on input to LSTM layers) + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param embedder: instance of nn.Embedding, if None constructor will + create new embedding layer + :param init_weight: range for the uniform initializer + """ + super(ResidualRecurrentEncoder, self).__init__() + self.batch_first = batch_first + self.rnn_layers = nn.ModuleList() + # 1st LSTM layer, bidirectional + self.rnn_layers.append( + nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=True, + batch_first=batch_first, bidirectional=True)) + + # 2nd LSTM layer, with 2x larger input_size + self.rnn_layers.append( + nn.LSTM((2 * hidden_size), hidden_size, num_layers=1, bias=True, + batch_first=batch_first)) + + # Remaining LSTM layers + for _ in range(num_layers - 2): + self.rnn_layers.append( + nn.LSTM(hidden_size, hidden_size, num_layers=1, bias=True, + batch_first=batch_first)) + + for lstm in self.rnn_layers: + init_lstm_(lstm, init_weight) + + self.dropout = nn.Dropout(p=dropout) + + if embedder is not None: + self.embedder = embedder + else: + self.embedder = nn.Embedding(vocab_size, hidden_size, + padding_idx=config.PAD) + nn.init.uniform_(self.embedder.weight.data, -init_weight, init_weight) + + def forward(self, inputs, lengths): + """ + Execute the encoder. + + :param inputs: tensor with indices from the vocabulary + :param lengths: vector with sequence lengths (excluding padding) + + returns: tensor with encoded sequences + """ + x = self.embedder(inputs) + + # bidirectional layer + x = self.dropout(x) + x = pack_padded_sequence(x, lengths.cpu().numpy(), + batch_first=self.batch_first) + x, _ = self.rnn_layers[0](x) + x, _ = pad_packed_sequence(x, batch_first=self.batch_first) + + # 1st unidirectional layer + x = self.dropout(x) + x, _ = self.rnn_layers[1](x) + + # the rest of unidirectional layers, + # with residual connections starting from 3rd layer + for i in range(2, len(self.rnn_layers)): + residual = x + x = self.dropout(x) + x, _ = self.rnn_layers[i](x) + x = x + residual + + return x diff --git a/samples/gnmt/seq2seq/models/gnmt.py b/samples/gnmt/seq2seq/models/gnmt.py new file mode 100644 index 0000000..4832f72 --- /dev/null +++ b/samples/gnmt/seq2seq/models/gnmt.py @@ -0,0 +1,52 @@ +import torch.nn as nn + +import seq2seq.data.config as config +from seq2seq.models.decoder import ResidualRecurrentDecoder +from seq2seq.models.encoder import ResidualRecurrentEncoder +from seq2seq.models.seq2seq_base import Seq2Seq +from seq2seq.utils import gnmt_print + + +class GNMT(Seq2Seq): + """ + GNMT v2 model + """ + def __init__(self, vocab_size, hidden_size=1024, num_layers=4, dropout=0.2, + batch_first=False, share_embedding=True): + """ + Constructor for the GNMT v2 model. + + :param vocab_size: size of vocabulary (number of tokens) + :param hidden_size: internal hidden size of the model + :param num_layers: number of layers, applies to both encoder and + decoder + :param dropout: probability of dropout (in encoder and decoder) + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param share_embedding: if True embeddings are shared between encoder + and decoder + """ + + super(GNMT, self).__init__(batch_first=batch_first) + + if share_embedding: + embedder = nn.Embedding(vocab_size, hidden_size, + padding_idx=config.PAD) + nn.init.uniform_(embedder.weight.data, -0.1, 0.1) + else: + embedder = None + + self.encoder = ResidualRecurrentEncoder(vocab_size, hidden_size, + num_layers, dropout, + batch_first, embedder) + + self.decoder = ResidualRecurrentDecoder(vocab_size, hidden_size, + num_layers, dropout, + batch_first, embedder) + + def forward(self, input_encoder, input_enc_len, input_decoder): + context = self.encode(input_encoder, input_enc_len) + context = (context, input_enc_len, None) + output, _, _ = self.decode(input_decoder, context) + + return output diff --git a/samples/gnmt/seq2seq/models/seq2seq_base.py b/samples/gnmt/seq2seq/models/seq2seq_base.py new file mode 100644 index 0000000..4b06317 --- /dev/null +++ b/samples/gnmt/seq2seq/models/seq2seq_base.py @@ -0,0 +1,64 @@ +import torch.nn as nn +from torch.nn.functional import log_softmax + + +class Seq2Seq(nn.Module): + """ + Generic Seq2Seq module, with an encoder and a decoder. + """ + def __init__(self, encoder=None, decoder=None, batch_first=False): + """ + Constructor for the Seq2Seq module. + + :param encoder: encoder module + :param decoder: decoder module + :param batch_first: if True the model uses (batch, seq, feature) + tensors, if false the model uses (seq, batch, feature) tensors + """ + super(Seq2Seq, self).__init__() + self.encoder = encoder + self.decoder = decoder + self.batch_first = batch_first + + def encode(self, inputs, lengths): + """ + Applies the encoder to inputs with a given input sequence lengths. + + :param inputs: tensor with inputs (batch, seq_len) if 'batch_first' + else (seq_len, batch) + :param lengths: vector with sequence lengths (excluding padding) + """ + return self.encoder(inputs, lengths) + + def decode(self, inputs, context, inference=False): + """ + Applies the decoder to inputs, given the context from the encoder. + + :param inputs: tensor with inputs (batch, seq_len) if 'batch_first' + else (seq_len, batch) + :param context: context from the encoder + :param inference: if True inference mode, if False training mode + """ + return self.decoder(inputs, context, inference) + + def generate(self, inputs, context, beam_size): + """ + Autoregressive generator, works with SequenceGenerator class. + Executes decoder (in inference mode), applies log_softmax and topK for + inference with beam search decoding. + + :param inputs: tensor with inputs to the decoder + :param context: context from the encoder + :param beam_size: beam size for the generator + + returns: (words, logprobs, scores, new_context) + words: indices of topK tokens + logprobs: log probabilities of topK tokens + scores: scores from the attention module (for coverage penalty) + new_context: new decoder context, includes new hidden states for + decoder RNN cells + """ + logits, scores, new_context = self.decode(inputs, context, True) + logprobs = log_softmax(logits, dim=-1) + logprobs, words = logprobs.topk(beam_size, dim=-1) + return words, logprobs, scores, new_context diff --git a/samples/gnmt/seq2seq/train/fp_optimizers.py b/samples/gnmt/seq2seq/train/fp_optimizers.py new file mode 100644 index 0000000..8f65caa --- /dev/null +++ b/samples/gnmt/seq2seq/train/fp_optimizers.py @@ -0,0 +1,161 @@ +import logging +import math + +import torch +from torch.nn.utils import clip_grad_norm_ + + +class Fp16Optimizer: + """ + Mixed precision optimizer with dynamic loss scaling and backoff. + https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html#scalefactor + """ + @staticmethod + def set_grads(params, params_with_grad): + """ + Copies gradients from param_with_grad to params + + :param params: dst parameters + :param params_with_grad: src parameters + """ + for param, param_w_grad in zip(params, params_with_grad): + if param.grad is None: + param.grad = torch.nn.Parameter(torch.empty_like(param)) + param.grad.data.copy_(param_w_grad.grad.data) + + @staticmethod + def set_weights(params, new_params): + """ + Copies parameters from new_params to params + + :param params: dst parameters + :param new_params: src parameters + """ + for param, new_param in zip(params, new_params): + param.data.copy_(new_param.data) + + def __init__(self, fp16_model, grad_clip=float('inf'), loss_scale=8192, + dls_downscale=2, dls_upscale=2, dls_upscale_interval=128): + """ + Constructor for the Fp16Optimizer. + + :param fp16_model: model (previously casted to half) + :param grad_clip: coefficient for gradient clipping, max L2 norm of the + gradients + :param loss_scale: initial loss scale + :param dls_downscale: loss downscale factor, loss scale is divided by + this factor when NaN/INF occurs in the gradients + :param dls_upscale: loss upscale factor, loss scale is multiplied by + this factor if previous dls_upscale_interval batches finished + successfully + :param dls_upscale_interval: interval for loss scale upscaling + """ + logging.info('Initializing fp16 optimizer') + self.initialize_model(fp16_model) + + self.since_last_invalid = 0 + self.loss_scale = loss_scale + self.dls_downscale = dls_downscale + self.dls_upscale = dls_upscale + self.dls_upscale_interval = dls_upscale_interval + self.grad_clip = grad_clip + + def initialize_model(self, model): + """ + Initializes internal state and build fp32 master copy of weights. + + :param model: fp16 model + """ + logging.info('Initializing fp32 clone weights') + self.fp16_model = model + self.fp16_model.zero_grad() + self.fp32_params = [param.to(torch.float32).detach() + for param in model.parameters()] + + for param in self.fp32_params: + param.requires_grad = True + + def step(self, loss, optimizer, scheduler, update=True): + """ + Performs one step of the optimizer. + Applies loss scaling, computes gradients in fp16, converts gradients to + fp32, inverts scaling and applies optional gradient norm clipping. + If gradients are finite, it applies update to fp32 master weights and + copies updated parameters to fp16 model for the next iteration. If + gradients are not finite, it skips the batch and adjusts scaling factor + for the next iteration. + + :param loss: value of loss function + :param optimizer: optimizer + :param update: if True executes weight update + """ + loss *= self.loss_scale + loss.backward() + + if update: + self.set_grads(self.fp32_params, self.fp16_model.parameters()) + if self.loss_scale != 1.0: + for param in self.fp32_params: + param.grad.data /= self.loss_scale + + norm = clip_grad_norm_(self.fp32_params, self.grad_clip) + + if math.isfinite(norm): + scheduler.step() + optimizer.step() + self.set_weights(self.fp16_model.parameters(), + self.fp32_params) + self.since_last_invalid += 1 + else: + self.loss_scale /= self.dls_downscale + self.since_last_invalid = 0 + logging.info(f'Gradient norm: {norm}') + logging.info(f'Skipped batch, new scale: {self.loss_scale}') + + if self.since_last_invalid >= self.dls_upscale_interval: + self.loss_scale *= self.dls_upscale + self.loss_scale = min(self.loss_scale, 8192.0) + logging.info(f'Upscaling, new scale: {self.loss_scale}') + self.since_last_invalid = 0 + + self.fp16_model.zero_grad() + + +class Fp32Optimizer: + """ + Standard optimizer, computes backward and applies weight update. + """ + def __init__(self, model, grad_clip=None): + """ + Constructor for the Fp32Optimizer + + :param model: model + :param grad_clip: coefficient for gradient clipping, max L2 norm of the + gradients + """ + self.initialize_model(model) + self.grad_clip = grad_clip + + def initialize_model(self, model): + """ + Initializes state of the model. + + :param model: model + """ + self.model = model + self.model.zero_grad() + + def step(self, optimizer, scheduler, update=True): + """ + Performs one step of the optimizer. + + :param loss: value of loss function + :param optimizer: optimizer + :param update: if True executes weight update + """ + if update: + if self.grad_clip != float('inf'): + clip_grad_norm_(self.model.parameters(), self.grad_clip) + optimizer.step() + scheduler.step() + self.model.zero_grad() diff --git a/samples/gnmt/seq2seq/train/lr_scheduler.py b/samples/gnmt/seq2seq/train/lr_scheduler.py new file mode 100644 index 0000000..de1d459 --- /dev/null +++ b/samples/gnmt/seq2seq/train/lr_scheduler.py @@ -0,0 +1,95 @@ +import logging +import math + +import torch + +from seq2seq.utils import gnmt_print + + +def perhaps_convert_float(param, total): + if isinstance(param, float): + param = int(param * total) + return param + + +class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): + """ + Learning rate scheduler with exponential warmup and step decay. + """ + def __init__(self, optimizer, iterations, warmup_steps=0, + remain_steps=1.0, decay_interval=None, decay_steps=4, + decay_factor=0.5, last_epoch=-1): + """ + Constructor of WarmupMultiStepLR. + + Parameters: warmup_steps, remain_steps and decay_interval accept both + integers and floats as an input. Integer input is interpreted as + absolute index of iteration, float input is interpreted as a fraction + of total training iterations (epochs * steps_per_epoch). + + If decay_interval is None then the decay will happen at regulary spaced + intervals ('decay_steps' decays between iteration indices + 'remain_steps' and 'iterations'). + + :param optimizer: instance of optimizer + :param iterations: total number of training iterations + :param warmup_steps: number of warmup iterations + :param remain_steps: start decay at 'remain_steps' iteration + :param decay_interval: interval between LR decay steps + :param decay_steps: max number of decay steps + :param decay_factor: decay factor + :param last_epoch: the index of last iteration + """ + + # iterations before learning rate reaches base LR + self.warmup_steps = perhaps_convert_float(warmup_steps, iterations) + + # iteration at which decay starts + self.remain_steps = perhaps_convert_float(remain_steps, iterations) + + # number of steps between each decay + if decay_interval is None: + # decay at regulary spaced intervals + decay_iterations = iterations - self.remain_steps + self.decay_interval = decay_iterations // (decay_steps) + self.decay_interval = max(self.decay_interval, 1) + else: + self.decay_interval = perhaps_convert_float(decay_interval, + iterations) + + # multiplicative decay factor + self.decay_factor = decay_factor + + # max number of decay steps + self.decay_steps = decay_steps + + if self.warmup_steps > self.remain_steps: + logging.warn(f'warmup_steps should not be larger than ' + f'remain_steps, setting warmup_steps=remain_steps') + self.warmup_steps = self.remain_steps + + super(WarmupMultiStepLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch <= self.warmup_steps: + # exponential lr warmup + if self.warmup_steps != 0: + warmup_factor = math.exp(math.log(0.01) / self.warmup_steps) + else: + warmup_factor = 1.0 + inv_decay = warmup_factor ** (self.warmup_steps - self.last_epoch) + lr = [base_lr * inv_decay for base_lr in self.base_lrs] + + elif self.last_epoch >= self.remain_steps: + # step decay + decay_iter = self.last_epoch - self.remain_steps + num_decay_steps = decay_iter // self.decay_interval + 1 + num_decay_steps = min(num_decay_steps, self.decay_steps) + lr = [ + base_lr * (self.decay_factor ** num_decay_steps) + for base_lr in self.base_lrs + ] + else: + # base lr + lr = [base_lr for base_lr in self.base_lrs] + return lr diff --git a/samples/gnmt/seq2seq/train/smoothing.py b/samples/gnmt/seq2seq/train/smoothing.py new file mode 100644 index 0000000..d5b60b2 --- /dev/null +++ b/samples/gnmt/seq2seq/train/smoothing.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + + +class LabelSmoothing(nn.Module): + """ + NLL loss with label smoothing. + """ + def __init__(self, padding_idx, smoothing=0.0): + """ + Constructor for the LabelSmoothing module. + + :param padding_idx: index of the PAD token + :param smoothing: label smoothing factor + """ + super(LabelSmoothing, self).__init__() + self.padding_idx = padding_idx + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + + def forward(self, x, target): + logprobs = torch.nn.functional.log_softmax(x, dim=-1, + dtype=torch.float32) + + non_pad_mask = (target != self.padding_idx) + nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) + nll_loss = nll_loss.squeeze(1)[non_pad_mask] + smooth_loss = -logprobs.mean(dim=-1)[non_pad_mask] + loss = self.confidence * nll_loss + self.smoothing * smooth_loss + return loss.sum() diff --git a/samples/gnmt/seq2seq/train/trainer.py b/samples/gnmt/seq2seq/train/trainer.py new file mode 100644 index 0000000..e3c2a5f --- /dev/null +++ b/samples/gnmt/seq2seq/train/trainer.py @@ -0,0 +1,384 @@ +import logging +import os +import time +from itertools import cycle + +import numpy as np +import torch +import torch.optim +import torch.utils.data +from apex.parallel import DistributedDataParallel as DDP + +from seq2seq.train.fp_optimizers import Fp16Optimizer +from seq2seq.train.fp_optimizers import Fp32Optimizer +from seq2seq.train.lr_scheduler import WarmupMultiStepLR +from seq2seq.utils import AverageMeter +from seq2seq.utils import gnmt_print +from seq2seq.utils import sync_workers + + +class Seq2SeqTrainer: + """ + Seq2SeqTrainer + """ + def __init__(self, + model, + criterion, + opt_config, + scheduler_config, + print_freq=10, + save_freq=1000, + grad_clip=float('inf'), + batch_first=False, + save_info={}, + save_path='.', + train_iterations=0, + checkpoint_filename='checkpoint%s.pth', + keep_checkpoints=5, + math='fp32', + cuda=True, + distributed=False, + intra_epoch_eval=0, + iter_size=1, + translator=None, + verbose=False): + """ + Constructor for the Seq2SeqTrainer. + + :param model: model to train + :param criterion: criterion (loss function) + :param opt_config: dictionary with options for the optimizer + :param scheduler_config: dictionary with options for the learning rate + scheduler + :param print_freq: prints short summary every 'print_freq' iterations + :param save_freq: saves checkpoint every 'save_freq' iterations + :param grad_clip: coefficient for gradient clipping + :param batch_first: if True the model uses (batch,seq,feature) tensors, + if false the model uses (seq, batch, feature) + :param save_info: dict with additional state stored in each checkpoint + :param save_path: path to the directiory for checkpoints + :param train_iterations: total number of training iterations to execute + :param checkpoint_filename: name of files with checkpoints + :param keep_checkpoints: max number of checkpoints to keep + :param math: arithmetic type + :param cuda: if True use cuda, if False train on cpu + :param distributed: if True run distributed training + :param intra_epoch_eval: number of additional eval runs within each + training epoch + :param iter_size: number of iterations between weight updates + :param translator: instance of Translator, runs inference on test set + :param verbose: enables verbose logging + """ + super(Seq2SeqTrainer, self).__init__() + self.model = model + self.criterion = criterion + self.epoch = 0 + self.save_info = save_info + self.save_path = save_path + self.save_freq = save_freq + self.save_counter = 0 + self.checkpoint_filename = checkpoint_filename + self.checkpoint_counter = cycle(range(keep_checkpoints)) + self.opt_config = opt_config + self.cuda = cuda + self.distributed = distributed + self.print_freq = print_freq + self.batch_first = batch_first + self.verbose = verbose + self.loss = None + self.translator = translator + self.intra_epoch_eval = intra_epoch_eval + self.iter_size = iter_size + + if cuda: + self.model = self.model.cuda() + self.criterion = self.criterion.cuda() + + if math == 'fp16': + self.model = self.model.half() + + if distributed: + self.model = DDP(self.model) + + if math == 'fp16': + self.fp_optimizer = Fp16Optimizer(self.model, grad_clip) + params = self.fp_optimizer.fp32_params + elif math == 'fp32': + self.fp_optimizer = Fp32Optimizer(self.model, grad_clip) + params = self.model.parameters() + + opt_name = opt_config.pop('optimizer') + self.optimizer = torch.optim.__dict__[opt_name](params, **opt_config) + logging.info(f'Using optimizer: {self.optimizer}') + + self.scheduler = WarmupMultiStepLR(self.optimizer, train_iterations, + **scheduler_config) + + def iterate(self, src, tgt, update=True, training=True): + """ + Performs one iteration of the training/validation. + + :param src: batch of examples from the source language + :param tgt: batch of examples from the target language + :param update: if True: optimizer does update of the weights + :param training: if True: executes optimizer + """ + src, src_length = src + tgt, tgt_length = tgt + src_length = torch.LongTensor(src_length) + tgt_length = torch.LongTensor(tgt_length) + + num_toks = {} + num_toks['tgt'] = int(sum(tgt_length - 1)) + num_toks['src'] = int(sum(src_length)) + + if self.cuda: + src = src.cuda() + src_length = src_length.cuda() + tgt = tgt.cuda() + + if self.batch_first: + output = self.model(src, src_length, tgt[:, :-1]) + tgt_labels = tgt[:, 1:] + T, B = output.size(1), output.size(0) + else: + output = self.model(src, src_length, tgt[:-1]) + tgt_labels = tgt[1:] + T, B = output.size(0), output.size(1) + + loss = self.criterion(output.view(T * B, -1), + tgt_labels.contiguous().view(-1)) + + loss_per_batch = loss.item() + loss /= (B * self.iter_size) + + if training: + self.fp_optimizer.step(loss, self.optimizer, self.scheduler, + update) + + loss_per_token = loss_per_batch / num_toks['tgt'] + loss_per_sentence = loss_per_batch / B + + return loss_per_token, loss_per_sentence, num_toks + + def feed_data(self, data_loader, training=True): + """ + Runs training or validation on batches from data_loader. + + :param data_loader: data loader + :param training: if True runs training else runs validation + """ + if training: + assert self.optimizer is not None + eval_fractions = np.linspace(0, 1, self.intra_epoch_eval+2)[1:-1] + iters_with_update = len(data_loader) // self.iter_size + eval_iters = (eval_fractions * iters_with_update).astype(int) + eval_iters = eval_iters * self.iter_size + eval_iters = set(eval_iters) + + batch_time = AverageMeter() + data_time = AverageMeter() + losses_per_token = AverageMeter(skip_first=False) + losses_per_sentence = AverageMeter(skip_first=False) + + tot_tok_time = AverageMeter() + src_tok_time = AverageMeter() + tgt_tok_time = AverageMeter() + + batch_size = data_loader.batch_size + + end = time.time() + for i, (src, tgt) in enumerate(data_loader): + self.save_counter += 1 + # measure data loading time + data_time.update(time.time() - end) + + update = False + if i % self.iter_size == self.iter_size - 1: + update = True + + # do a train/evaluate iteration + stats = self.iterate(src, tgt, update, training=training) + loss_per_token, loss_per_sentence, num_toks = stats + + # measure accuracy and record loss + losses_per_token.update(loss_per_token, num_toks['tgt']) + losses_per_sentence.update(loss_per_sentence, batch_size) + + # measure elapsed time + elapsed = time.time() - end + batch_time.update(elapsed) + src_tok_time.update(num_toks['src'] / elapsed) + tgt_tok_time.update(num_toks['tgt'] / elapsed) + tot_num_toks = num_toks['tgt'] + num_toks['src'] + tot_tok_time.update(tot_num_toks / elapsed) + self.loss = losses_per_token.avg + + if training and i in eval_iters: + test_bleu, _ = self.translator.run(calc_bleu=True, + epoch=self.epoch, + iteration=i) + + log = [] + log += [f'TRAIN [{self.epoch}][{i}/{len(data_loader)}]'] + log += [f'BLEU: {test_bleu:.2f}'] + log = '\t'.join(log) + logging.info(log) + + self.model.train() + self.preallocate(data_loader, training=True) + + if i % self.print_freq == 0: + phase = 'TRAIN' if training else 'VALIDATION' + log = [] + log += [f'{phase} [{self.epoch}][{i}/{len(data_loader)}]'] + log += [f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'] + log += [f'Data {data_time.val:.2e} ({data_time.avg:.2e})'] + log += [f'Tok/s {tot_tok_time.val:.0f} ({tot_tok_time.avg:.0f})'] + if self.verbose: + log += [f'Src tok/s {src_tok_time.val:.0f} ({src_tok_time.avg:.0f})'] + log += [f'Tgt tok/s {tgt_tok_time.val:.0f} ({tgt_tok_time.avg:.0f})'] + log += [f'Loss/sentence {losses_per_sentence.val:.1f} ({losses_per_sentence.avg:.1f})'] + log += [f'Loss/tok {losses_per_token.val:.4f} ({losses_per_token.avg:.4f})'] + if training: + lr = self.optimizer.param_groups[0]['lr'] + log += [f'LR {lr:.3e}'] + log = '\t'.join(log) + logging.info(log) + + save_chkpt = (self.save_counter % self.save_freq) == (self.save_freq - 1) + if training and save_chkpt: + self.save_counter = 0 + self.save_info['iteration'] = i + identifier = next(self.checkpoint_counter, -1) + if identifier != -1: + with sync_workers() as rank: + if rank == 0: + self.save(identifier=identifier) + + end = time.time() + + tot_tok_time.reduce('sum') + losses_per_token.reduce('mean') + + return losses_per_token.avg, tot_tok_time.avg + + def preallocate(self, data_loader, training): + """ + Generates maximum sequence length batch and runs forward and backward + pass without updating model parameters. + + :param data_loader: data loader + :param training: if True preallocates memory for backward pass + """ + batch_size = data_loader.batch_size + max_len = data_loader.dataset.max_len + + src_length = [max_len] * batch_size + tgt_length = [max_len] * batch_size + + if self.batch_first: + shape = (batch_size, max_len) + else: + shape = (max_len, batch_size) + + src = torch.full(shape, 4, dtype=torch.int64) + tgt = torch.full(shape, 4, dtype=torch.int64) + src = src, src_length + tgt = tgt, tgt_length + self.iterate(src, tgt, update=False, training=training) + self.model.zero_grad() + + def optimize(self, data_loader): + """ + Sets model in training mode, preallocates memory and runs training on + data provided by data_loader. + + :param data_loader: data loader + """ + torch.set_grad_enabled(True) + self.model.train() + torch.cuda.empty_cache() + self.preallocate(data_loader, training=True) + output = self.feed_data(data_loader, training=True) + self.model.zero_grad() + torch.cuda.empty_cache() + return output + + def evaluate(self, data_loader): + """ + Sets model in eval mode, disables gradients, preallocates memory and + runs validation on data provided by data_loader. + + :param data_loader: data loader + """ + torch.set_grad_enabled(False) + self.model.eval() + torch.cuda.empty_cache() + self.preallocate(data_loader, training=False) + output = self.feed_data(data_loader, training=False) + self.model.zero_grad() + torch.cuda.empty_cache() + return output + + def load(self, filename): + """ + Loads checkpoint from filename. + + :param filename: path to the checkpoint file + """ + if os.path.isfile(filename): + checkpoint = torch.load(filename, map_location={'cuda:0': 'cpu'}) + if self.distributed: + self.model.module.load_state_dict(checkpoint['state_dict']) + else: + self.model.load_state_dict(checkpoint['state_dict']) + self.fp_optimizer.initialize_model(self.model) + self.optimizer.load_state_dict(checkpoint['optimizer']) + self.scheduler.load_state_dict(checkpoint['scheduler']) + self.epoch = checkpoint['epoch'] + self.loss = checkpoint['loss'] + logging.info(f'Loaded checkpoint {filename} (epoch {self.epoch})') + else: + logging.error(f'Invalid checkpoint: {filename}') + + def save(self, identifier=None, is_best=False, save_all=False): + """ + Stores checkpoint to a file. + + :param identifier: identifier for periodic checkpoint + :param is_best: if True stores checkpoint to 'model_best.pth' + :param save_all: if True stores checkpoint after completed training + epoch + """ + + def write_checkpoint(state, filename): + filename = os.path.join(self.save_path, filename) + logging.info(f'Saving model to {filename}') + torch.save(state, filename) + + if self.distributed: + model_state = self.model.module.state_dict() + else: + model_state = self.model.state_dict() + + state = { + 'epoch': self.epoch, + 'state_dict': model_state, + 'optimizer': self.optimizer.state_dict(), + 'scheduler': self.scheduler.state_dict(), + 'loss': getattr(self, 'loss', None), + } + state = dict(list(state.items()) + list(self.save_info.items())) + + if identifier is not None: + filename = self.checkpoint_filename % identifier + write_checkpoint(state, filename) + + if is_best: + filename = 'model_best.pth' + write_checkpoint(state, filename) + + if save_all: + filename = f'checkpoint_epoch_{self.epoch:03d}.pth' + write_checkpoint(state, filename) diff --git a/samples/gnmt/seq2seq/utils.py b/samples/gnmt/seq2seq/utils.py new file mode 100644 index 0000000..c692a01 --- /dev/null +++ b/samples/gnmt/seq2seq/utils.py @@ -0,0 +1,346 @@ +import logging.config +import os +import random +import sys +import time +from contextlib import contextmanager + +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.init as init +import torch.utils.collect_env + + +def gnmt_print(*args, **kwargs): + """ + Wrapper for MLPerf compliance logging calls. + All arguments but 'sync' are passed to mlperf_log.gnmt_print function. + If 'sync' is set to True then the wrapper will synchronize all distributed + workers. 'sync' should be set to True for all compliance tags that require + accurate timing (RUN_START, RUN_STOP etc.) + """ + if kwargs.pop('sync'): + barrier() + if get_rank() == 0: + kwargs['stack_offset'] = 2 + + +def init_lstm_(lstm, init_weight=0.1): + """ + Initializes weights of LSTM layer. + Weights and biases are initialized with uniform(-init_weight, init_weight) + distribution. + + :param lstm: instance of torch.nn.LSTM + :param init_weight: range for the uniform initializer + """ + # Initialize hidden-hidden weights + init.uniform_(lstm.weight_hh_l0.data, -init_weight, init_weight) + # Initialize input-hidden weights: + init.uniform_(lstm.weight_ih_l0.data, -init_weight, init_weight) + + # Initialize bias. PyTorch LSTM has two biases, one for input-hidden GEMM + # and the other for hidden-hidden GEMM. Here input-hidden bias is + # initialized with uniform distribution and hidden-hidden bias is + # initialized with zeros. + init.uniform_(lstm.bias_ih_l0.data, -init_weight, init_weight) + init.zeros_(lstm.bias_hh_l0.data) + + if lstm.bidirectional: + init.uniform_(lstm.weight_hh_l0_reverse.data, -init_weight, init_weight) + init.uniform_(lstm.weight_ih_l0_reverse.data, -init_weight, init_weight) + + init.uniform_(lstm.bias_ih_l0_reverse.data, -init_weight, init_weight) + init.zeros_(lstm.bias_hh_l0_reverse.data) + + +def generate_seeds(rng, size): + """ + Generate list of random seeds + + :param rng: random number generator + :param size: length of the returned list + """ + seeds = [rng.randint(0, 2**32 - 1) for _ in range(size)] + return seeds + + +def broadcast_seeds(seeds, device): + """ + Broadcasts random seeds to all distributed workers. + Returns list of random seeds (broadcasted from workers with rank 0). + + :param seeds: list of seeds (integers) + :param device: torch.device + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + seeds_tensor = torch.LongTensor(seeds).to(device) + torch.distributed.broadcast(seeds_tensor, 0) + seeds = seeds_tensor.tolist() + return seeds + + +def setup_seeds(master_seed, epochs, device): + """ + Generates seeds from one master_seed. + Function returns (worker_seeds, shuffling_seeds), worker_seeds are later + used to initialize per-worker random number generators (mostly for + dropouts), shuffling_seeds are for RNGs resposible for reshuffling the + dataset before each epoch. + Seeds are generated on worker with rank 0 and broadcasted to all other + workers. + + :param master_seed: master RNG seed used to initialize other generators + :param epochs: number of epochs + :param device: torch.device (used for distributed.broadcast) + """ + if master_seed is None: + # random master seed, random.SystemRandom() uses /dev/urandom on Unix + master_seed = random.SystemRandom().randint(0, 2**32 - 1) + if get_rank() == 0: + # master seed is reported only from rank=0 worker, it's to avoid + # confusion, seeds from rank=0 are later broadcasted to other + # workers + logging.info(f'Using random master seed: {master_seed}') + else: + # master seed was specified from command line + logging.info(f'Using master seed from command line: {master_seed}') + + # initialize seeding RNG + seeding_rng = random.Random(master_seed) + + # generate worker seeds, one seed for every distributed worker + worker_seeds = generate_seeds(seeding_rng, get_world_size()) + + # generate seeds for data shuffling, one seed for every epoch + shuffling_seeds = generate_seeds(seeding_rng, epochs) + + # broadcast seeds from rank=0 to other workers + worker_seeds = broadcast_seeds(worker_seeds, device) + shuffling_seeds = broadcast_seeds(shuffling_seeds, device) + return worker_seeds, shuffling_seeds + + +def barrier(): + """ + Works as a temporary distributed barrier, currently pytorch + doesn't implement barrier for NCCL backend. + Calls all_reduce on dummy tensor and synchronizes with GPU. + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.all_reduce(torch.cuda.FloatTensor(1)) + torch.cuda.synchronize() + + +def get_rank(): + """ + Gets distributed rank or returns zero if distributed is not initialized. + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + else: + rank = 0 + return rank + + +def get_world_size(): + """ + Gets total number of distributed workers or returns one if distributed is + not initialized. + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + return world_size + + +@contextmanager +def sync_workers(): + """ + Yields distributed rank and synchronizes all workers on exit. + """ + rank = get_rank() + yield rank + barrier() + + +@contextmanager +def timer(name, ndigits=2, sync_gpu=True): + if sync_gpu: + torch.cuda.synchronize() + start = time.time() + yield + if sync_gpu: + torch.cuda.synchronize() + stop = time.time() + elapsed = round(stop - start, ndigits) + logging.info(f'TIMER {name} {elapsed}') + + +def setup_logging(log_file=os.devnull): + """ + Configures logging. + By default logs from all workers are printed to the console, entries are + prefixed with "N: " where N is the rank of the worker. Logs printed to the + console don't include timestaps. + Full logs with timestamps are saved to the log_file file. + """ + class RankFilter(logging.Filter): + def __init__(self, rank): + self.rank = rank + + def filter(self, record): + record.rank = self.rank + return True + + rank = get_rank() + rank_filter = RankFilter(rank) + + logging_format = "%(asctime)s - %(levelname)s - %(rank)s - %(message)s" + logging.basicConfig(level=logging.DEBUG, + format=logging_format, + datefmt="%Y-%m-%d %H:%M:%S", + filename=log_file, + filemode='w') + console = logging.StreamHandler(sys.stdout) + console.setLevel(logging.INFO) + formatter = logging.Formatter('%(rank)s: %(message)s') + console.setFormatter(formatter) + logging.getLogger('').addHandler(console) + logging.getLogger('').addFilter(rank_filter) + + +def set_device(cuda, local_rank): + """ + Sets device based on local_rank and returns instance of torch.device. + + :param cuda: if True: use cuda + :param local_rank: local rank of the worker + """ + if cuda: + torch.cuda.set_device(local_rank) + device = torch.device('cuda') + else: + device = torch.device('cpu') + return device + + +def init_distributed(cuda): + """ + Initializes distributed backend. + + :param cuda: (bool) if True initializes nccl backend, if False initializes + gloo backend + """ + world_size = int(os.environ.get('WORLD_SIZE', 1)) + distributed = (world_size > 1) + if distributed: + backend = 'nccl' if cuda else 'gloo' + dist.init_process_group(backend=backend, + init_method='env://') + assert dist.is_initialized() + return distributed + + +def log_env_info(): + """ + Prints information about execution environment. + """ + logging.info('Collecting environment information...') + env_info = torch.utils.collect_env.get_pretty_env_info() + logging.info(f'{env_info}') + + +def pad_vocabulary(math): + if math == 'fp16': + pad_vocab = 8 + elif math == 'fp32': + pad_vocab = 1 + return pad_vocab + + +class AverageMeter: + """ + Computes and stores the average and current value + """ + def __init__(self, skip_first=True): + self.reset() + self.skip = skip_first + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + + if self.skip: + self.skip = False + else: + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def reduce(self, op): + """ + Reduces average value over all workers. + + :param op: 'sum' or 'mean', reduction operator + """ + if op not in ('sum', 'mean'): + raise NotImplementedError + + distributed = (get_world_size() > 1) + if distributed: + # Backward/forward compatibility around + # https://github.com/pytorch/pytorch/commit/540ef9b1fc5506369a48491af8a285a686689b36 and + # https://github.com/pytorch/pytorch/commit/044d00516ccd6572c0d6ab6d54587155b02a3b86 + # To accomodate change in Pytorch's distributed API + if hasattr(dist, "get_backend"): + _backend = dist.get_backend() + if hasattr(dist, "DistBackend"): + backend_enum_holder = dist.DistBackend + else: + backend_enum_holder = dist.Backend + else: + _backend = dist._backend + backend_enum_holder = dist.dist_backend + + cuda = _backend == backend_enum_holder.NCCL + + if cuda: + avg = torch.cuda.FloatTensor([self.avg]) + _sum = torch.cuda.FloatTensor([self.sum]) + else: + avg = torch.FloatTensor([self.avg]) + _sum = torch.FloatTensor([self.sum]) + + _reduce_op = dist.reduce_op if hasattr(dist, "reduce_op") else dist.ReduceOp + dist.all_reduce(avg, op=_reduce_op.SUM) + dist.all_reduce(_sum, op=_reduce_op.SUM) + self.avg = avg.item() + self.sum = _sum.item() + + if op == 'mean': + self.avg /= get_world_size() + self.sum /= get_world_size() + + +def debug_tensor(tensor, name): + """ + Simple utility which helps with debugging. + Takes a tensor and outputs: min, max, avg, std, number of NaNs, number of + INFs. + + :param tensor: torch tensor + :param name: name of the tensor (only for logging) + """ + logging.info(name) + tensor = tensor.detach().float().cpu().numpy() + logging.info(f'MIN: {tensor.min()} MAX: {tensor.max()} ' + f'AVG: {tensor.mean()} STD: {tensor.std()} ' + f'NAN: {np.isnan(tensor).sum()} INF: {np.isinf(tensor).sum()}')