diff --git a/.github/workflows/push_pr.yml b/.github/workflows/push_pr.yml
index 5a82e81b8..1a4ae3b9d 100644
--- a/.github/workflows/push_pr.yml
+++ b/.github/workflows/push_pr.yml
@@ -32,6 +32,12 @@ jobs:
         run: pip install -r requirements/requirements.txt
       - name: DeepSpeed requirements
         run: pip install -r requirements/requirements.deepspeed.txt
+      - name: Faiss requirements
+        run: |
+          if [ "$RUNNER_OS" == "Linux" ]; then
+            pip install -r requirements/requirements.faiss-cpu.txt
+          fi
+        shell: bash
       - name: Development requirements
         run: pip install -r requirements/requirements.dev.txt
       - name: Unit tests
diff --git a/CHANGELOG.md b/CHANGELOG.md
index cd768da63..f3d73856b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,6 +11,18 @@ Note that Sockeye has checks in place to not translate with an old model that wa
 Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.
 
+## [3.1.28]
+
+### Added
+
+- Added the kNN-MT model from [Khandelwal et al., 2021](https://arxiv.org/abs/2010.00710).
+  - Installation: see the [faiss installation guide](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md) -- installation via conda is recommended.
+  - Building a faiss index from a sockeye model takes two steps:
+    - Generate decoder states: `sockeye-generate-decoder-states -m [model] --source [src] --target [tgt] --output-dir [output dir]`
+    - Build index: `sockeye-knn -i [input_dir] -o [output_dir] -t [faiss_index_signature]` where `input_dir` is the same as `output_dir` from the `sockeye-generate-decoder-states` command.
+  - Faiss index signature reference: [see here](https://github.com/facebookresearch/faiss/wiki/The-index-factory)
+  - Running inference using the built index: `sockeye-translate ... --knn-index [index_dir] --knn-lambda [interpolation_weight]` where `index_dir` is the same as `output_dir` from the `sockeye-knn` command.
+
 ## [3.1.27]
 
 ### Changed
diff --git a/requirements/requirements.faiss-cpu.txt b/requirements/requirements.faiss-cpu.txt
new file mode 100644
index 000000000..3e1473919
--- /dev/null
+++ b/requirements/requirements.faiss-cpu.txt
@@ -0,0 +1 @@
+faiss-cpu >= 1.7.2
diff --git a/requirements/requirements.faiss-gpu.txt b/requirements/requirements.faiss-gpu.txt
new file mode 100644
index 000000000..e567b137d
--- /dev/null
+++ b/requirements/requirements.faiss-gpu.txt
@@ -0,0 +1 @@
+faiss-gpu >= 1.7.2
diff --git a/setup.py b/setup.py
index 2a5334230..c0aac8ca9 100644
--- a/setup.py
+++ b/setup.py
@@ -87,6 +87,8 @@ def get_requirements(filename):
             'sockeye-translate = sockeye.translate:main',
             'sockeye-vocab = sockeye.vocab:main',
             'sockeye-rerank = sockeye.rerank:main',
+            'sockeye-knn = sockeye.knn:main',
+            'sockeye-generate-decoder-states = sockeye.generate_decoder_states:main',
         ],
     }
     args = dict(
diff --git a/sockeye/__init__.py b/sockeye/__init__.py
index 3ea151618..78bb654be 100644
--- a/sockeye/__init__.py
+++ b/sockeye/__init__.py
@@ -11,4 +11,4 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
-__version__ = '3.1.27'
+__version__ = '3.1.28'
diff --git a/sockeye/arguments.py b/sockeye/arguments.py
index dd87a8d7f..6a613ea08 100644
--- a/sockeye/arguments.py
+++ b/sockeye/arguments.py
@@ -1131,6 +1131,7 @@ def add_translate_cli_args(params):
     add_inference_args(params)
     add_device_args(params)
     add_logging_args(params)
+    add_knn_mt_args(params)  # for kNN MT
 
 
 def add_score_cli_args(params):
@@ -1173,7 +1174,40 @@ def add_score_cli_args(params):
                         'peaked predictions, values > 1.0 produce smoothed distributions.')
 
     params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
-                        help="Data type. Default: %(default)s infers from saved model.")
+                        help="Data type. Default: infers from saved model.")
+
+    add_logging_args(params)
+
+
+def add_state_generation_args(params):
+    add_training_data_args(params, required=True)
+    add_vocab_args(params)
+    add_device_args(params)
+    add_batch_args(params, default_batch_size=56, default_batch_type=C.BATCH_TYPE_SENTENCE)
+
+    decode_params = params.add_argument_group("Decoder state generation parameters")
+
+    decode_params.add_argument('--state-dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16],
+                               help="Data type of the decoder state store. Default: infers from saved model.")
+
+    decode_params.add_argument("--model", "-m", required=True,
+                               help="Model directory containing trained model.")
+
+    decode_params.add_argument(C.TRAINING_ARG_MAX_SEQ_LEN,
+                               type=multiple_values(num_values=2, greater_or_equal=1),
+                               default=None,
+                               help='Maximum sequence length in tokens. '
+                                    'Use "x:x" to specify separate values for src&tgt. Default: Read from model.')
+
+    # common params with translate CLI
+    add_length_penalty_args(params)
+    add_brevity_penalty_args(params)
+
+    decode_params.add_argument("--output-dir", "-o", default=None,
+                               help="The path to the directory that stores the decoder states.")
+
+    decode_params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
+                               help="Data type. Default: infers from saved model.")
 
     add_logging_args(params)
 
@@ -1343,7 +1377,7 @@ def add_inference_args(params):
     add_brevity_penalty_args(decode_params)
 
     decode_params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
-                               help="Data type. Default: %(default)s infers from saved model.")
+                               help="Data type. Default: infers from saved model.")
 
     add_clamp_to_dtype_arg(decode_params)
 
@@ -1380,6 +1414,7 @@ def add_brevity_penalty_args(params):
                         'ratio, used for brevity penalty calculation, for all inputs. If zero, uses the average of length '
                         'ratios from the training data over all models. Default: %(default)s.')
 
+
 def add_clamp_to_dtype_arg(params):
     params.add_argument('--clamp-to-dtype',
                         action='store_true',
@@ -1423,3 +1458,46 @@ def add_build_vocab_args(params):
     params.add_argument('-o', '--output', required=True, type=str, help="Output filename to write vocabulary to.")
     add_vocab_args(params)
     add_process_pool_args(params)
+
+
+def add_knn_mt_args(params):
+    knn_params = params.add_argument_group("kNN MT parameters")
+
+    knn_params.add_argument('--knn-index',
+                            type=str,
+                            help='Optionally use a KNN index during inference to '
+                                 'retrieve similar hidden states and corresponding target tokens.',
+                            default=None)
+    knn_params.add_argument('--knn-lambda',
+                            type=float,
+                            help="Interpolation parameter when using KNN index. 
Default: %(default)s.", + default=C.DEFAULT_KNN_LAMBDA) + + +def add_build_knn_index_args(params): + params.add_argument('-i', '--input-dir', + required=True, + type=str, + help='The directory that contains the stored decoder states and values ' + f'({C.KNN_STATE_DATA_STORE_NAME} and {C.KNN_WORD_DATA_STORE_NAME}).') + params.add_argument('-o', '--output-dir', + default=None, + type=str, + help='The path to the output directory. Will reuse input directory if not specified.') + params.add_argument('-t', '--index-type', + default=None, + type=str, + help='An optional field to specify the type of the index. ' + 'Will override settings in the config. ' + 'The type is specified with a faiss index factory signature, see here: ' + 'https://github.com/facebookresearch/faiss/wiki/The-index-factory') + params.add_argument('--train-data-input-file', + default=None, + type=str, + help='An optional field to reuse an already-built training data sample for the index. ' + 'Otherwise, a (slow) sampling step might need to be run.') + params.add_argument('--train-data-size', + default=None, + type=int, + help='An optional field to specify the size of the training sample. ' + 'Will override settings in the config.') diff --git a/sockeye/beam_search.py b/sockeye/beam_search.py index 48d640a88..a82ef83c4 100644 --- a/sockeye/beam_search.py +++ b/sockeye/beam_search.py @@ -65,10 +65,12 @@ class _SingleModelInference(_Inference): def __init__(self, model: SockeyeModel, skip_softmax: bool = False, - constant_length_ratio: float = 0.0) -> None: + constant_length_ratio: float = 0.0, + knn_lambda: float = C.DEFAULT_KNN_LAMBDA) -> None: self._model = model self._skip_softmax = skip_softmax self._const_lr = constant_length_ratio + self.knn_lambda = knn_lambda def state_structure(self) -> List: return [self._model.state_structure()] @@ -83,10 +85,17 @@ def decode_step(self, vocab_slice_ids: Optional[pt.Tensor] = None, target_prefix_factor_mask: Optional[pt.Tensor] = None, factor_vocab_size: Optional[int] = None): - logits, states, target_factor_outputs = self._model.decode_step(step_input, states, vocab_slice_ids) + logits, knn_probs, states, target_factor_outputs = self._model.decode_step(step_input, states, vocab_slice_ids) if not self._skip_softmax: - logits = pt.log_softmax(logits, dim=-1) - scores = -logits # shape: (batch*beam, output_vocab_size/len(vocab_slice_ids)) + if knn_probs is None: # no knn used + probs = pt.log_softmax(logits, dim=-1) + else: + probs = pt.log(self.knn_lambda * pt.softmax(logits, dim=-1) + (1 - self.knn_lambda) * knn_probs) + else: + assert knn_probs is None, "Can't skip softmax with KNN." 
+ probs = logits + + scores = -probs # shape: (batch*beam, output_vocab_size/len(vocab_slice_ids)) target_factors = None # type: Optional[pt.Tensor] if target_factor_outputs: @@ -120,7 +129,8 @@ class _EnsembleInference(_Inference): def __init__(self, models: List[SockeyeModel], ensemble_mode: str = 'linear', - constant_length_ratio: float = 0.0) -> None: + constant_length_ratio: float = 0.0, + knn_lambda: float = C.DEFAULT_KNN_LAMBDA) -> None: self._models = models if ensemble_mode == 'linear': self._interpolation = self.linear_interpolation @@ -129,6 +139,7 @@ def __init__(self, else: raise ValueError() self._const_lr = constant_length_ratio + self.knn_lambda = knn_lambda def state_structure(self) -> List: return [model.state_structure() for model in self._models] @@ -163,8 +174,11 @@ def decode_step(self, for model, model_state_structure in zip(self._models, self.state_structure()): model_states = states[state_index:state_index+len(model_state_structure)] state_index += len(model_state_structure) - logits, model_states, target_factor_outputs = model.decode_step(step_input, model_states, vocab_slice_ids) - probs = logits.softmax(dim=-1) + logits, knn_probs, model_states, target_factor_outputs = model.decode_step(step_input, model_states, vocab_slice_ids) + if knn_probs is None: + probs = logits.softmax(dim=-1) + else: + probs = self.knn_lambda * pt.softmax(logits, dim=-1) + (1 - self.knn_lambda) * knn_probs outputs.append(probs) if target_factor_outputs: target_factor_probs = [tfo.softmax(dim=-1) for tfo in target_factor_outputs] @@ -1086,6 +1100,7 @@ def get_search_algorithm(models: List[SockeyeModel], ensemble_mode: str = 'linear', beam_search_stop: str = C.BEAM_SEARCH_STOP_ALL, constant_length_ratio: float = 0.0, + knn_lambda: float = C.DEFAULT_KNN_LAMBDA, sample: Optional[int] = None, prevent_unk: bool = False, greedy: bool = False, @@ -1114,7 +1129,8 @@ def get_search_algorithm(models: List[SockeyeModel], num_target_factors=models[0].num_target_factors, inference=_SingleModelInference(model=models[0], skip_softmax=True, - constant_length_ratio=0.0), + constant_length_ratio=0.0, + knn_lambda=knn_lambda), skip_nvs=skip_nvs, nvs_thresh=nvs_thresh) else: @@ -1125,7 +1141,8 @@ def get_search_algorithm(models: List[SockeyeModel], logger.info("Enabled skipping softmax for a single model and greedy decoding.") inference = _SingleModelInference(model=models[0], skip_softmax=skip_softmax, - constant_length_ratio=constant_length_ratio) + constant_length_ratio=constant_length_ratio, + knn_lambda=knn_lambda) else: inference = _EnsembleInference(models=models, ensemble_mode=ensemble_mode, diff --git a/sockeye/constants.py b/sockeye/constants.py index b0553caee..670c24b48 100644 --- a/sockeye/constants.py +++ b/sockeye/constants.py @@ -287,6 +287,7 @@ DTYPE_FP32 = 'float32' DTYPE_TF32 = 'tf32' DTYPE_INT8 = 'int8' +DTYPE_INT16 = 'int16' DTYPE_INT32 = 'int32' LARGE_POSITIVE_VALUE = 99999999. 
LARGE_VALUES = {
@@ -381,3 +382,13 @@
 BREVITY_PENALTY_CONSTANT = 'constant'
 BREVITY_PENALTY_LEARNED = 'learned'
 BREVITY_PENALTY_NONE = 'none'
+
+# k-nn
+KNN_STATE_DATA_STORE_NAME = "keys.npy"
+KNN_WORD_DATA_STORE_NAME = "vals.npy"
+KNN_WORD_DATA_STORE_DTYPE = DTYPE_INT32
+KNN_CONFIG_NAME = "config.yaml"
+KNN_INDEX_NAME = "key_index"
+KNN_EPSILON = 1e-6
+DEFAULT_DATA_STORE_BLOCK_SIZE = 1024 * 1024
+DEFAULT_KNN_LAMBDA = 0.8
\ No newline at end of file
diff --git a/sockeye/generate_decoder_states.py b/sockeye/generate_decoder_states.py
new file mode 100644
index 000000000..8df008f77
--- /dev/null
+++ b/sockeye/generate_decoder_states.py
@@ -0,0 +1,280 @@
+# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not
+# use this file except in compliance with the License. A copy of the License
+# is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import argparse
+import logging
+import os
+from typing import Dict, List
+
+import numpy as np
+import torch as pt
+
+from . import arguments
+from . import constants as C
+from . import data_io
+from . import utils
+from .log import setup_main_logger
+from .model import SockeyeModel, load_model
+from .vocab import Vocab
+from .utils import check_condition
+from .knn import KNNConfig, get_state_store_path, get_word_store_path, get_config_path
+
+# Temporary logger; the real one (which may log to a file) is created in the main function
+logger = logging.getLogger(__name__)
+
+
+class NumpyMemmapStorage:
+    """
+    Wraps a numpy memmap as a datastore for decoder state vectors.
+
+    :param file_name: disk file path to store the memory-mapped file
+    :param num_dim: number of dimensions of the vectors in the data store
+    :param dtype: data type of the vectors in the data store
+    """
+
+    def __init__(self,
+                 file_name: str,
+                 num_dim: int,
+                 dtype: np.dtype) -> None:
+        self.file_name = file_name
+        self.num_dim = num_dim  # dimension of a single entry
+        self.dtype = dtype
+        self.block_size = -1
+        self.mmap = None
+        self.tail_idx = 0  # where the next entry should be inserted
+        self.size = 0  # size of storage already assigned
+
+    def open(self, initial_size: int, block_size: int) -> None:
+        """Create a memmap handle and initialize its sizes."""
+        self.mmap = np.memmap(self.file_name, dtype=self.dtype, mode='w+', shape=(initial_size, self.num_dim))
+        self.size = initial_size
+        self.block_size = block_size
+
+    def add(self, array: np.ndarray) -> None:
+        """
+        Add an array of entries to the store. NumPy memmaps cannot be resized, so the number of
+        entries must be estimated up front and passed to `open` as `initial_size`. If adding
+        `array` would exceed that pre-allocated size, a warning is logged and nothing is written.
+
+        :param array: the array of states to be added.
+        """
+        assert self.mmap is not None
+        num_entries, num_dim = array.shape
+        assert num_dim == self.num_dim
+
+        if self.tail_idx + num_entries > self.size:
+            # bail out
+            logger.warning(
+                f"Trying to write {num_entries} entries into a numpy memmap that "
+                f"has size {self.size} and already has {self.tail_idx} entries. Nothing is written."
+            )
+        else:
+            start = self.tail_idx
+            end = self.tail_idx + num_entries
+            self.mmap[start:end] = array
+
+            self.tail_idx += num_entries
+
+
+class DecoderStateGenerator:
+    """
+    Generate decoder states by using a translation model to force-decode a parallel dataset.
+
+    :param model: Sockeye translation model used to generate the states.
+    :param source_vocabs: source vocabs for the translation model.
+    :param target_vocabs: target vocabs for the translation model.
+    :param output_dir: path to the directory that stores the decoder-state memmaps.
+    :param max_seq_len_source: maximum source length for decoding.
+    :param max_seq_len_target: maximum target length for decoding.
+    :param state_data_type: data type for storing decoder states.
+    :param word_data_type: data type for storing word indexes.
+    :param device: device (cpu/gpu) for decoding.
+    """
+
+    def __init__(self,
+                 model: SockeyeModel,
+                 source_vocabs: List[Vocab],
+                 target_vocabs: List[Vocab],
+                 output_dir: str,
+                 max_seq_len_source: int,
+                 max_seq_len_target: int,
+                 state_data_type: str,
+                 word_data_type: str,
+                 device: pt.device) -> None:
+        self.model = model
+        self.source_vocabs = source_vocabs
+        self.target_vocabs = target_vocabs
+        self.device = device
+        self.traced_model = None
+        self.max_seq_len_source = max_seq_len_source
+        self.max_seq_len_target = max_seq_len_target
+
+        self.output_dir = output_dir
+        self.state_store_file = None
+        self.words_store_file = None
+
+        # info for KNNConfig
+        self.num_states = 0
+        self.dimension = None
+        self.state_data_type = utils.get_numpy_dtype(state_data_type)
+        self.word_data_type = utils.get_numpy_dtype(word_data_type)
+
+    @staticmethod
+    def probe_token_count(target_path: str, max_seq_len: int) -> int:
+        """Count the number of tokens in the file at `target_path`, with each line truncated at `max_seq_len`."""
+        token_count = 0
+        with open(target_path, 'r') as f:
+            for line in f:
+                token_count += min(len(line.split()) + 1, max_seq_len)  # +1 for EOS
+        return token_count
+
+    def init_store_file(self, initial_size: int) -> None:
+        """Initialize the memory map files."""
+        self.dimension = self.model.config.config_decoder.model_size
+
+        self.state_store_file = NumpyMemmapStorage(get_state_store_path(self.output_dir),
+                                                   self.dimension, self.state_data_type)
+        self.words_store_file = NumpyMemmapStorage(get_word_store_path(self.output_dir),
+                                                   1, self.word_data_type)  # dim=1 because entries are scalar word indexes
+        self.state_store_file.open(initial_size, 1)
+        self.words_store_file.open(initial_size, 1)
+
+    def generate_states_and_store(self,
+                                  sources: List[str],
+                                  targets: List[str],
+                                  batch_size: int) -> None:
+        """
+        Generate decoder states by force-decoding the sentence pairs in `sources` and `targets` with an NMT model.
+
+        :param sources: list of source segments.
+        :param targets: list of target segments.
+        :param batch_size: number of sentence pairs to decode at once.
+        """
+        assert self.state_store_file is not None, \
+            "You should call init_store_file first to initialize the store files."
+
+        # get data iter
+        data_iter = data_io.get_scoring_data_iters(
+            sources=sources,
+            targets=targets,
+            source_vocabs=self.source_vocabs,
+            target_vocabs=self.target_vocabs,
+            batch_size=batch_size,
+            max_seq_len_source=self.max_seq_len_source,
+            max_seq_len_target=self.max_seq_len_target
+        )
+
+        with pt.inference_mode():
+            for batch_no, batch in enumerate(data_iter, 1):
+                if batch_no % 1000 == 0:
+                    logger.debug("At batch number {0}".format(batch_no))
+
+                # get decoder states
+                batch = batch.load(self.device)
+                model_inputs = (batch.source, batch.source_length, batch.target, batch.target_length)
+                if self.traced_model is None:
+                    trace_inputs = {'get_decoder_states': model_inputs}
+                    self.traced_model = pt.jit.trace_module(self.model, trace_inputs, strict=False)
+                # shape: (batch, seq_len, hidden_dim)
+                decoder_states = self.traced_model.get_decoder_states(*model_inputs)
+
+                # flatten batch and seq_len dimensions, remove pads on the target
+                pad_mask = (batch.target != C.PAD_ID)[:, :, 0]  # shape: (batch, seq_len)
+                flat_target = batch.target[pad_mask].cpu().detach().numpy()
+                flat_states = decoder_states[pad_mask].cpu().detach().numpy()
+
+                # store
+                self.state_store_file.add(flat_states)
+                self.words_store_file.add(flat_target)
+
+    def save_config(self):
+        """
+        Save a config file with information about the data store.
+        """
+        config = KNNConfig(
+            index_size=self.num_states,
+            dimension=self.dimension,
+            state_data_type=utils.dtype_to_str(self.state_data_type),
+            word_data_type=utils.dtype_to_str(self.word_data_type),
+            # the remaining two values are only placeholders -- they are left for the faiss index builder to fill
+            index_type="",
+            train_data_size=-1,
+        )
+        config.save(get_config_path(self.output_dir))
+
+
+def store(args: argparse.Namespace):
+    """Build a data store with an existing model and a parallel corpus."""
+    use_cpu = args.use_cpu
+    if not pt.cuda.is_available():
+        logger.info("CUDA not available, using cpu")
+        use_cpu = True
+    device = pt.device('cpu') if use_cpu else pt.device('cuda', args.device_id)
+    logger.info(f"Scoring device: {device}")
+
+    model, source_vocabs, target_vocabs = load_model(args.model, device=device, dtype=args.dtype)
+    model.eval()
+
+    max_seq_len_source = model.max_supported_len_source
+    max_seq_len_target = model.max_supported_len_target
+    if args.max_seq_len is not None:
+        max_seq_len_source = min(args.max_seq_len[0] + C.SPACE_FOR_XOS, max_seq_len_source)
+        max_seq_len_target = min(args.max_seq_len[1] + C.SPACE_FOR_XOS, max_seq_len_target)
+
+    sources = [args.source] + args.source_factors
+    sources = [str(os.path.abspath(source)) for source in sources]
+    targets = [args.target] + args.target_factors
+    targets = [str(os.path.abspath(target)) for target in targets]
+
+    check_condition(len(targets) == model.num_target_factors,
+                    "Number of target inputs/factors provided (%d) does not match number of target factors "
+                    "required by the model (%d)" % (len(targets), model.num_target_factors))
+
+    # if state data type is None, use inferred data type
+    if args.state_dtype is None:
+        args.state_dtype = utils.dtype_to_str(model.dtype)
+
+    check_condition(not os.path.isfile(args.output_dir), f"{args.output_dir} already exists as a file")
+    if not os.path.exists(args.output_dir):
+        os.mkdir(args.output_dir)
+
+    generator = DecoderStateGenerator(model, source_vocabs, target_vocabs, args.output_dir,
+                                      max_seq_len_source, max_seq_len_target,
+                                      args.state_dtype, C.KNN_WORD_DATA_STORE_DTYPE, device)
+    generator.num_states = \
DecoderStateGenerator.probe_token_count(targets[0], max_seq_len_target) + generator.init_store_file(generator.num_states) + generator.generate_states_and_store(sources, targets, args.batch_size) + generator.save_config() + + +def main(): + params = arguments.ConfigArgumentParser( + description='CLI to generate decoder states from parallel data with a trained model, ' + 'and build a data store from it.' + ) + arguments.add_state_generation_args(params) + args = params.parse_args() + check_condition(args.batch_type == C.BATCH_TYPE_SENTENCE, "Batching by number of words is not supported") + + setup_main_logger(file_logging=False, + console=not args.quiet, + level=args.loglevel) # pylint: disable=no-member + + utils.log_basic_info(args) + + store(args) + + +if __name__ == "__main__": + main() diff --git a/sockeye/inference.py b/sockeye/inference.py index e819287e1..f1cc993e5 100644 --- a/sockeye/inference.py +++ b/sockeye/inference.py @@ -763,6 +763,7 @@ def __init__(self, sample: Optional[int] = None, output_scores: bool = False, constant_length_ratio: float = 0.0, + knn_lambda: float = C.DEFAULT_KNN_LAMBDA, max_output_length_num_stds: int = C.DEFAULT_NUM_STD_MAX_OUTPUT_LENGTH, max_input_length: Optional[int] = None, max_output_length: Optional[int] = None, @@ -812,6 +813,7 @@ def __init__(self, beam_search_stop=beam_search_stop, scorer=self._scorer, constant_length_ratio=constant_length_ratio, + knn_lambda=knn_lambda, prevent_unk=prevent_unk, greedy=greedy, skip_nvs=skip_nvs, diff --git a/sockeye/knn.py b/sockeye/knn.py new file mode 100755 index 000000000..dbad4a38e --- /dev/null +++ b/sockeye/knn.py @@ -0,0 +1,210 @@ +# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You may not +# use this file except in compliance with the License. A copy of the License +# is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is distributed on +# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +# express or implied. See the License for the specific language governing +# permissions and limitations under the License. + +from dataclasses import dataclass +import logging +import numpy as np +import os +import shutil +from typing import Optional + +from . import arguments +from sockeye import config, utils, constants as C +from sockeye.log import setup_main_logger + +try: + import faiss +except ImportError: + pass + +logger = logging.getLogger(__name__) + + +@dataclass +class KNNConfig(config.Config): + """ + KNNConfig defines knn-specific configurations, including the information about the data store + as well as the index itself. + + :param index_size: Size of the index and the data store. + :param dimension: Number of dimensions of the decoder states. + :param state_data_type: Data type of the decoder states (keys). + :param word_data_type: Data type of the stored word indexes (values). + :param index_type: faiss index signature, see https://github.com/facebookresearch/faiss/wiki/The-index-factory + :param train_data_size: Size of the training data used to train the index (if it needs to be trained). + """ + index_size: int + dimension: int + state_data_type: str + word_data_type: str + index_type: str + train_data_size: int + + +class FaissIndexBuilder: + """ + Builds a faiss index from a data store containing stored keys (i.e., decoder hidden states for k-NN-based MT). 
+
+    :param config: a KNNConfig object containing the index configuration information
+    :param use_gpu: build index on a gpu
+    :param device_id: device id if building index on gpu
+    """
+
+    def __init__(self, config: KNNConfig, use_gpu: bool = False, device_id: int = 0):
+        utils.check_import_faiss()  # faiss will definitely be used for this class, so check here
+        self.config = config
+        self.use_gpu = use_gpu
+        self.device_id = device_id
+
+    def init_faiss_index(self, train_sample: Optional[np.memmap] = None):
+        """Initialize the Faiss index to be built and conduct training if needed."""
+        index = faiss.index_factory(self.config.dimension, self.config.index_type)
+        # pylint is disabled for members that only exist in faiss-gpu
+        if self.use_gpu:
+            res = faiss.StandardGpuResources()  # pylint: disable=no-member
+            co = faiss.GpuClonerOptions()  # pylint: disable=no-member
+            index = faiss.index_cpu_to_gpu(res, self.device_id, index, co)  # pylint: disable=no-member
+
+        if not index.is_trained and train_sample is not None:
+            index.train(train_sample.astype(np.float32))  # unfortunately, faiss index only supports float32
+        elif not index.is_trained:
+            raise ValueError("Index needs training but no training sample was provided.")
+
+        return index
+
+    def add_items(self, index, keys: np.ndarray):
+        """Add items to the index (must call `init_faiss_index` first)."""
+        item_count, key_dim = keys.shape
+        assert key_dim == self.config.dimension
+
+        index.add(keys.astype(np.float32))  # unfortunately, faiss index only supports float32
+
+    def block_add_items(self, index, keys: np.ndarray, block_size: int = C.DEFAULT_DATA_STORE_BLOCK_SIZE):
+        """Add items to the index in blocks -- used for a large number of items (must call `init_faiss_index` first)."""
+        item_count, key_dim = keys.shape
+        assert key_dim == self.config.dimension
+
+        n_blocks = item_count // block_size
+        for i in range(n_blocks):
+            logger.debug(f"adding block no.{i}")
+            start = block_size * i
+            end = block_size * (i + 1)
+            index.add(keys[start:end].astype(np.float32))  # unfortunately, faiss index only supports float32
+
+        if block_size * n_blocks < item_count:
+            start = block_size * n_blocks
+            index.add(keys[start:item_count].astype(np.float32))  # unfortunately, faiss index only supports float32
+
+    @staticmethod
+    def build_train_sample(keys: np.ndarray, sample_size: int):
+        """Randomly sample `sample_size` keys as a training sample."""
+        item_count, _ = keys.shape
+        assert 0 < sample_size <= item_count
+
+        if sample_size < item_count:
+            train_sample_idx = np.random.choice(np.arange(item_count), size=[sample_size], replace=False)
+            train_sample = keys[train_sample_idx]
+        else:
+            train_sample = keys
+
+        return train_sample
+
+    def build_faiss_index(self, keys: np.ndarray, train_sample: Optional[np.memmap] = None):
+        """
+        Top-level function of the class to build a faiss index for a set of keys, optionally with samples for training.
+ """ + item_count, _ = keys.shape + if train_sample is None and self.config.train_data_size > 0: + train_sample = FaissIndexBuilder.build_train_sample(keys, self.config.train_data_size) + + index = self.init_faiss_index(train_sample) + self.block_add_items(index, keys) + + return index + + +def get_state_store_path(dir): + """Get the path to the state store file given a kNN export directory.""" + return os.path.join(dir, C.KNN_STATE_DATA_STORE_NAME) + + +def get_word_store_path(dir): + """Get the path to the word store file given a kNN export directory.""" + return os.path.join(dir, C.KNN_WORD_DATA_STORE_NAME) + + +def get_config_path(dir): + """Get the path to the kNN config file given a kNN export directory.""" + return os.path.join(dir, C.KNN_CONFIG_NAME) + + +def build_knn_index_package(args): + """Top-level function that builds a kNN index package (kNN index and config file) + from an existing state and word store.""" + state_store_filename = get_state_store_path(args.input_dir) + word_store_filename = get_word_store_path(args.input_dir) + config_filename = get_config_path(args.input_dir) + utils.check_condition(os.path.exists(state_store_filename), f"Input file {state_store_filename} not found!") + utils.check_condition(os.path.exists(word_store_filename), f"Input file {word_store_filename} not found!") + utils.check_condition(os.path.exists(config_filename), f"Config file {config_filename} not found!") + utils.check_import_faiss() + + setup_main_logger(file_logging=False, + console=not args.quiet, + level=args.loglevel) # pylint: disable=no-member + utils.log_basic_info(args) + + config = KNNConfig.load(config_filename) + if args.index_type is not None: + config.index_type = args.index_type + if args.train_data_size is not None: + config.train_data_size = args.train_data_size + + keys = np.memmap(state_store_filename, dtype=config.state_data_type, + mode='r', shape=(config.index_size, config.dimension)) + builder = FaissIndexBuilder(config, not args.use_cpu, args.device_id) + train_sample = None + if args.train_data_input_file is not None: + train_sample = np.memmap(args.train_data_input_file, dtype=config.state_data_type, + mode='r', shape=(config.index_size, config.dimension)) + index = builder.build_faiss_index(keys, train_sample) + + if not args.use_cpu: + index_cpu = faiss.index_gpu_to_cpu(index) # pylint: disable=no-member + else: + index_cpu = index + + if not args.output_dir: + args.output_dir = args.input_dir + elif args.output_dir != args.input_dir: + shutil.copy(word_store_filename, os.path.join(args.output_dir, C.KNN_WORD_DATA_STORE_NAME)) + + if not os.path.exists(args.output_dir): + os.mkdir(args.output_dir) + + faiss.write_index(index_cpu, os.path.join(args.output_dir, C.KNN_INDEX_NAME)) + config.save(os.path.join(args.output_dir, C.KNN_CONFIG_NAME)) + + +def main(): + params = arguments.ConfigArgumentParser(description='CLI to build knn index.') + arguments.add_build_knn_index_args(params) + arguments.add_logging_args(params) + arguments.add_device_args(params) + args = params.parse_args() + + build_knn_index_package(args) + + +if __name__ == "__main__": + main() diff --git a/sockeye/layers.py b/sockeye/layers.py index bf2730833..0827f11e3 100644 --- a/sockeye/layers.py +++ b/sockeye/layers.py @@ -16,6 +16,7 @@ from dataclasses import dataclass from typing import List, Optional, Tuple +import numpy as np import torch as pt import torch.nn.functional as F @@ -117,6 +118,64 @@ def forward(self, data: pt.Tensor, vocab_slice_ids: Optional[pt.Tensor] = None) return 
         return F.linear(data, weight, bias)
 
 
+class KNN(pt.nn.Module):
+    """
+    An alternative output layer that can produce an output distribution over the vocabulary
+    by using the decoder hidden state to query into an index.
+    For more details, see: https://arxiv.org/abs/2010.00710.
+
+    :param keys_index: faiss index used for k-NN query.
+    :param vals: a list of word indexes that maps key ids to their corresponding vocabulary ids.
+    :param vocab_size: the size of the output vocabulary.
+    :param k: number of candidates to be retrieved by the k-nearest-neighbor query.
+    :param temperature: temperature that controls the smoothness of the output distribution.
+    :param state_store: an optional state store object that is used to compute the exact distance
+                        between the query and the index.
+    """
+
+    def __init__(self,
+                 keys_index: "faiss.Index",  # type: ignore  # suppress mypy error because faiss is an optional import
+                 vals: np.memmap,
+                 vocab_size: int,
+                 k: int = 64,
+                 temperature: float = 10,
+                 state_store: Optional[np.memmap] = None) -> None:
+        super().__init__()
+        self.keys_index = keys_index
+        self.vals = vals
+        self.vocab_size = vocab_size
+        self.k = k
+        self.temperature = temperature
+        self.state_store = state_store
+
+    def forward(self, data: pt.Tensor):
+        # faiss only supports float32
+        distances, indices = self.keys_index.search(data.cpu().numpy().astype(np.float32), self.k)
+        # map indices to tokens
+        y = self.vals[(indices + 1) % len(self.vals)]
+        # no EOS is inserted in the generated data store, so we need to use the BOS of the next sentence as EOS
+        y[y == C.BOS_ID] = C.EOS_ID
+
+        # use exact distance when state_store is available
+        if self.state_store is not None:
+            raw_keys = pt.from_numpy(self.state_store[indices]).to(device=data.device)  # (data.shape[0], k, dim)
+            distances = pt.norm(data.unsqueeze(1) - raw_keys, p=2, dim=-1)  # data lacks a k axis, so we need to create one
+        else:
+            distances = np.sqrt(distances)  # unlike pytorch, faiss doesn't take the sqrt for us
+            distances = pt.from_numpy(distances).to(device=data.device)
+
+        # pytorch expects long for indexes
+        y = pt.from_numpy(y).to(device=data.device).long()
+
+        probs = pt.exp(-distances / self.temperature)
+        full_probs = pt.zeros((data.shape[0], self.vocab_size), device=data.device)
+        full_probs.scatter_add_(src=probs, index=y.squeeze(2), dim=-1)
+        z = pt.sum(full_probs, dim=-1).unsqueeze(-1)
+        z[z < C.KNN_EPSILON] = C.KNN_EPSILON  # avoid division by zero (which may happen when distances of all items are large)
+        full_probs.div_(z)
+        return full_probs
+
+
 @dataclass
 class LengthRatioConfig(config.Config):
     num_layers: int  # Number of layers
@@ -520,7 +579,7 @@ def __init__(self,
         self.kv_interleaved = False
 
     def separate_kv(self):
-        """ Writes kv input projection parameters in non-interleaved format (compatible with F.multi_head_attention). """
+        """Writes kv input projection parameters in non-interleaved format (compatible with F.multi_head_attention). """
         assert self.kv_interleaved
         with pt.no_grad():
             k, v = self.ff_kv.weight.data.view(self.heads, 2 * self.depth_per_head, self._depth_key_value).split(
@@ -531,7 +590,7 @@ def separate_kv(self):
         self.kv_interleaved = False
 
     def interleave_kv(self):
-        """ Writes kv input projection parameters in interleaved format (compatible with interleaved matmul). """
+        """Writes kv input projection parameters in interleaved format (compatible with interleaved matmul). """
""" assert not self.kv_interleaved with pt.no_grad(): k, v = self.ff_kv.weight.data.split(self.depth, dim=0) diff --git a/sockeye/model.py b/sockeye/model.py index 27ea5534a..ed46cd9ce 100644 --- a/sockeye/model.py +++ b/sockeye/model.py @@ -19,6 +19,7 @@ from functools import lru_cache from typing import cast, Dict, List, Optional, Tuple, Union +import numpy as np import torch as pt from sockeye import __version__ @@ -34,6 +35,14 @@ from .encoder import FactorConfig from .layers import LengthRatioConfig from . import nvs +from sockeye.knn import KNNConfig + +try: + import faiss # pylint: disable=E0401 + # The following import will allow us to pass pytorch arrays directly to faiss + import faiss.contrib.torch_utils # pylint: disable=E0401 +except: + pass logger = logging.getLogger(__name__) @@ -178,6 +187,8 @@ def __init__(self, for name, dtype in mismatched_dtype_params: logger.warn(f'{name}: {dtype} -> {self.dtype}') + self.knn : Optional[layers.KNN] = None + def cast(self, dtype: Union[pt.dtype, str]): dtype = utils.get_torch_dtype(dtype) @@ -278,7 +289,8 @@ def _embed_and_encode(self, def decode_step(self, step_input: pt.Tensor, states: List[pt.Tensor], - vocab_slice_ids: Optional[pt.Tensor] = None) -> Tuple[pt.Tensor, List[pt.Tensor], List[pt.Tensor]]: + vocab_slice_ids: Optional[pt.Tensor] = None) -> Tuple[pt.Tensor,pt.Tensor, List[pt.Tensor], + List[pt.Tensor]]: """ One step decoding of the translation model. @@ -288,7 +300,7 @@ def decode_step(self, :param vocab_slice_ids: Optional list of vocabulary ids to use for reduced matrix multiplication at the output layer. - :return: logits, list of new model states, other target factor logits. + :return: logits, KNN output if present otherwise None, list of new model states, other target factor logits. 
""" decode_step_inputs = [step_input, states] if vocab_slice_ids is not None: @@ -298,13 +310,19 @@ def decode_step(self, decode_step_module = _DecodeStep(self.embedding_target, self.decoder, self.output_layer, - self.factor_output_layers) + self.factor_output_layers, + self.knn) self.traced_decode_step = pt.jit.trace(decode_step_module, decode_step_inputs) # the traced module returns a flat list of tensors decode_step_outputs = self.traced_decode_step(*decode_step_inputs) - step_output, *target_factor_outputs = decode_step_outputs[:self.num_target_factors] - new_states = decode_step_outputs[self.num_target_factors:] - return step_output, new_states, target_factor_outputs + # +1 for the decoder output, which will be used to generate kNN output + step_output, decoder_out, *target_factor_outputs = decode_step_outputs[:self.num_target_factors + 1] + + # do the query here because it cannot be traced (jit.ignore does not play well with tracing) + knn_output = self.knn(decoder_out) if self.knn is not None else None + + new_states = decode_step_outputs[self.num_target_factors + 1:] + return step_output, knn_output, new_states, target_factor_outputs def forward(self, source, source_length, target, target_length): # pylint: disable=arguments-differ # When updating only the decoder (specified directly or implied by @@ -334,6 +352,17 @@ def forward(self, source, source_length, target, target_length): # pylint: disa return forward_output + def get_decoder_states(self, source, source_length, target, target_length): + """Same as `forward`, but skip the output layer and return the decoder states.""" + with pt.no_grad() if self.train_decoder_only or self.forward_pass_cache_size > 0 else utils.no_context(): + source_encoded, source_encoded_length, target_embed, states, nvs_prediction = self.embed_and_encode( + source, + source_length, + target) + + decoder_states = self.decoder.decode_seq(target_embed, states=states) + return decoder_states + def predict_output_length(self, source_encoded: pt.Tensor, source_encoded_length: pt.Tensor, @@ -462,6 +491,29 @@ def set_parameters(self, (name, model_params[name].size(), new_params[name].size()) model_params[name].data[:] = new_params[name].data + def load_knn_index(self, knn_index_folder: str) -> None: + """ + Load kNN index from a directory. + + :param knn_index_folder: same as `output_dir` from the `sockeye-knn` command, + containing the index and a config file. 
+ """ + utils.check_import_faiss() + knn_config = KNNConfig.load(os.path.join(knn_index_folder, C.KNN_CONFIG_NAME)) + knn_config = cast(KNNConfig, knn_config) # load returns a Config class, need to cast to subclass KNNConfig + keys_index = faiss.read_index(os.path.join(knn_index_folder, C.KNN_INDEX_NAME)) + vals = np.memmap(os.path.join(knn_index_folder, C.KNN_WORD_DATA_STORE_NAME), + dtype=utils.get_numpy_dtype(knn_config.word_data_type), + mode='r', + shape=(knn_config.index_size, 1)) # type: np.memmap + state_store = None # type: Optional[np.memmap] + if os.path.isfile(os.path.join(knn_index_folder, C.KNN_STATE_DATA_STORE_NAME)): + state_store = np.memmap(os.path.join(knn_index_folder, C.KNN_STATE_DATA_STORE_NAME), + dtype=utils.get_numpy_dtype(knn_config.state_data_type), + mode='r', + shape=(knn_config.index_size, knn_config.dimension)) + self.knn = layers.KNN(keys_index, vals, vocab_size=self.config.vocab_target_size, state_store=state_store) + @staticmethod def save_version(folder: str): """ @@ -573,13 +625,15 @@ def __init__(self, embedding_target: encoder.Embedding, decoder: decoder.Decoder, output_layer: layers.OutputLayer, - factor_output_layers: pt.nn.ModuleList): + factor_output_layers: pt.nn.ModuleList, + knn : Optional[layers.KNN] = None): super().__init__() self.embedding_target = embedding_target self.decoder = decoder self.output_layer = pt.jit.script(output_layer) self.factor_output_layers = factor_output_layers self.has_target_factors = bool(factor_output_layers) + self.knn = knn def forward(self, step_input, @@ -594,7 +648,7 @@ def forward(self, # return values are collected in a flat list due to constraints in mixed return types in traced modules # (can only by tensors, or lists of tensors or dicts of tensors, but no mix of them). - outputs = [step_output] + outputs = [step_output, decoder_out] if self.has_target_factors: outputs += [fol(decoder_out) for fol in self.factor_output_layers] outputs += new_states @@ -649,7 +703,8 @@ def load_model(model_folder: str, train_decoder_only: bool = False, allow_missing: bool = False, set_grad_req_null: bool = True, - forward_pass_cache_size: int = 0) -> Tuple[SockeyeModel, List[vocab.Vocab], List[vocab.Vocab]]: + forward_pass_cache_size: int = 0, + knn_index: Optional[str] = None) -> Tuple[SockeyeModel, List[vocab.Vocab], List[vocab.Vocab]]: """ Load a model from model_folder. @@ -664,6 +719,7 @@ def load_model(model_folder: str, :param allow_missing: Allow missing parameters in the loaded model. :param set_grad_req_null: Set grad_req to null for model parameters. :param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of forward pass. + :param knn_index: Optional path to a folder containing a KNN model index. :return: List of models, source vocabularies, target vocabularies. """ source_vocabs = vocab.load_source_vocabs(model_folder) @@ -690,6 +746,9 @@ def load_model(model_folder: str, allow_missing=allow_missing, ignore_extra=False) + if knn_index is not None: + model.load_knn_index(knn_index) + model.to(device) if set_grad_req_null: @@ -721,8 +780,10 @@ def load_models(device: pt.device, train_decoder_only: bool = False, allow_missing: bool = False, set_grad_req_null: bool = True, - forward_pass_cache_size: int = 0) -> Tuple[List[SockeyeModel], - List[vocab.Vocab], List[vocab.Vocab]]: + forward_pass_cache_size: int = 0, + knn_index: Optional[str] = None) -> Tuple[List[SockeyeModel], + List[vocab.Vocab], + List[vocab.Vocab]]: """ Loads a list of models for inference. 
@@ -737,6 +798,7 @@ def load_models(device: pt.device,
     :param allow_missing: Allow missing parameters in the loaded models.
     :param set_grad_req_null: Set grad_req to null for model parameters.
     :param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of forward pass.
+    :param knn_index: Optional path to a folder containing a KNN model index.
     :return: List of models, source vocabulary, target vocabulary, source factor vocabularies.
     """
     logger.info("Loading %d model(s) from %s ...", len(model_folders), model_folders)
@@ -760,7 +822,8 @@ def load_models(device: pt.device,
                                train_decoder_only=train_decoder_only,
                                allow_missing=allow_missing,
                                set_grad_req_null=set_grad_req_null,
-                               forward_pass_cache_size=forward_pass_cache_size)
+                               forward_pass_cache_size=forward_pass_cache_size,
+                               knn_index=knn_index)
         models.append(model)
         source_vocabs.append(src_vcbs)
         target_vocabs.append(trg_vcbs)
diff --git a/sockeye/translate.py b/sockeye/translate.py
index eb33146c4..f336c3e72 100644
--- a/sockeye/translate.py
+++ b/sockeye/translate.py
@@ -69,12 +69,14 @@ def run_translate(args: argparse.Namespace):
         device = init_device(args, logger)
         logger.info(f"Translate Device: {device}")
 
+
     models, source_vocabs, target_vocabs = load_models(device=device,
                                                        model_folders=args.models,
                                                        checkpoints=args.checkpoints,
                                                        dtype=args.dtype,
                                                        clamp_to_dtype=args.clamp_to_dtype,
-                                                       inference_only=True)
+                                                       inference_only=True,
+                                                       knn_index=args.knn_index)
 
     restrict_lexicon = None  # type: Optional[Union[RestrictLexicon, Dict[str, RestrictLexicon]]]
     if args.restrict_lexicon is not None:
@@ -134,6 +136,7 @@ def run_translate(args: argparse.Namespace):
                                   sample=args.sample,
                                   output_scores=output_handler.reports_score(),
                                   constant_length_ratio=constant_length_ratio,
+                                  knn_lambda=args.knn_lambda,
                                   max_output_length_num_stds=args.max_output_length_num_stds,
                                   max_input_length=args.max_input_length,
                                   max_output_length=args.max_output_length,
diff --git a/sockeye/utils.py b/sockeye/utils.py
index 2a4557c91..1e1e981aa 100644
--- a/sockeye/utils.py
+++ b/sockeye/utils.py
@@ -540,6 +540,23 @@ def get_torch_dtype(dtype: Union[pt.dtype, str]) -> pt.dtype:
     raise ValueError(f'Cannot convert to Torch dtype: {dtype}')
 
 
+_STRING_TO_NUMPY_DTYPE = {
+    C.DTYPE_FP16: np.float16,
+    C.DTYPE_FP32: np.float32,
+    C.DTYPE_INT8: np.int8,
+    C.DTYPE_INT16: np.int16,
+    C.DTYPE_INT32: np.int32,
+}
+
+
+def get_numpy_dtype(dtype: Union[np.dtype, str]):
+    if isinstance(dtype, np.dtype):
+        return dtype
+    if dtype in _STRING_TO_NUMPY_DTYPE:
+        return _STRING_TO_NUMPY_DTYPE[dtype]
+    raise ValueError(f'Cannot convert to NumPy dtype: {dtype}')
+
+
 def log_parameters(model: pt.nn.Module):
     """
     Logs information about model parameters.
@@ -700,6 +717,23 @@ def using_deepspeed() -> bool:
     return _using_deepspeed
 
 
+# Track whether Faiss has been confirmed importable
+_faiss_checked = False
+
+
+def check_import_faiss():
+    """
+    Make sure the faiss module can be imported.
+    """
+    global _faiss_checked
+    if not _faiss_checked:
+        try:
+            import faiss  # pylint: disable=E0401
+            _faiss_checked = True
+        except ImportError:
+            raise RuntimeError('To run kNN-MT models, please install faiss by following '
+                               'https://github.com/facebookresearch/faiss/blob/main/INSTALL.md')
+
+
 def count_seq_len(sample: str, count_type: str = 'char', replace_tokens: Optional[List] = None) -> int:
     """
     Count sequence length, after replacing (optional) token/s.
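A note on the decode-time math: for a single model, the beam-search changes above (sockeye/beam_search.py) reduce to interpolating the model distribution with the kNN distribution before moving to log space. A minimal sketch with random stand-in tensors (shapes and values are illustrative only; `knn_lambda` defaults to `C.DEFAULT_KNN_LAMBDA = 0.8`):

    import torch as pt

    knn_lambda = 0.8                                # C.DEFAULT_KNN_LAMBDA
    logits = pt.randn(2, 8)                         # (batch * beam, vocab) output-layer logits
    knn_probs = pt.softmax(pt.randn(2, 8), dim=-1)  # stand-in for the layers.KNN output

    # as in _SingleModelInference.decode_step: mix the two distributions, then take the log
    probs = pt.log(knn_lambda * pt.softmax(logits, dim=-1) + (1 - knn_lambda) * knn_probs)
    scores = -probs  # beam search minimizes negative log-probabilities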
diff --git a/test/unit/test_knn.py b/test/unit/test_knn.py
new file mode 100644
index 000000000..644180a9e
--- /dev/null
+++ b/test/unit/test_knn.py
@@ -0,0 +1,175 @@
+# Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You may not
+# use this file except in compliance with the License. A copy of the License
+# is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+import numpy as np
+import os
+import pytest
+import tempfile
+import torch as pt
+
+import sockeye
+import sockeye.constants as C
+from sockeye.generate_decoder_states import NumpyMemmapStorage, DecoderStateGenerator
+from sockeye.knn import KNNConfig, FaissIndexBuilder, get_config_path, get_state_store_path, get_word_store_path
+from sockeye.vocab import build_vocab
+
+# Only run certain tests in this file if faiss is installed
+try:
+    import faiss  # pylint: disable=E0401
+    faiss_installed = True
+except ImportError:
+    faiss_installed = False
+
+
+def test_numpy_memmap_storage():
+    # test open
+    with tempfile.TemporaryDirectory() as memmap_dir:
+        memmap_file = os.path.join(memmap_dir, "foo")
+        store = NumpyMemmapStorage(memmap_file, 64, np.float16)
+        store.open(64, 8)
+
+        # test add
+        ones_block_16 = np.ones((16, 64), dtype=np.float16)
+        ones_block_32 = np.ones((32, 64), dtype=np.float16)
+        zeros_block_16 = np.zeros((16, 64), dtype=np.float16)
+        zeros_block_32 = np.zeros((32, 64), dtype=np.float16)
+        # add for 0-15
+        store.add(ones_block_16)
+        assert (store.mmap[:16, :] == ones_block_16).all()
+        assert (store.mmap[16:32, :] == zeros_block_16).all()
+        assert (store.mmap[32:, :] == zeros_block_32).all()
+        # add for 16-31
+        store.add(ones_block_16)
+        assert (store.mmap[:32, :] == ones_block_32).all()
+        assert (store.mmap[32:, :] == zeros_block_32).all()
+        # add with size overflow -- this should trigger a warning without doing anything
+        store.add(np.ones((64, 64), dtype=np.float16))
+        assert (store.mmap[:32, :] == ones_block_32).all()
+        assert (store.mmap[32:, :] == zeros_block_32).all()
+
+
+def test_decoder_state_generator():
+    data = 'One Ring to rule them all, One Ring to find them'
+    max_seq_len_source = 30
+    max_seq_len_target = 30
+
+    vocabs = [build_vocab([data])]
+    config_embed = sockeye.encoder.EmbeddingConfig(vocab_size=len(vocabs[0]), num_embed=16, dropout=0.0)
+    config_encoder = sockeye.encoder.EncoderConfig(model_size=16, attention_heads=2,
+                                                   feed_forward_num_hidden=16, depth_key_value=16,
+                                                   act_type='relu', num_layers=2, dropout_attention=0.0,
+                                                   dropout_act=0.0, dropout_prepost=0.0,
+                                                   positional_embedding_type='fixed', preprocess_sequence='n',
+                                                   postprocess_sequence='n', max_seq_len_source=max_seq_len_source,
+                                                   max_seq_len_target=max_seq_len_target)
+    config_data = sockeye.data_io.DataConfig(data_statistics=None,
+                                             max_seq_len_source=max_seq_len_source,
+                                             max_seq_len_target=max_seq_len_target,
+                                             num_source_factors=0, num_target_factors=0)
+    config = sockeye.model.ModelConfig(config_data=config_data,
+                                       vocab_source_size=len(vocabs[0]), vocab_target_size=len(vocabs[0]),
+                                       config_embed_source=config_embed, config_embed_target=config_embed,
+                                       config_encoder=config_encoder, config_decoder=config_encoder)
+
+    with tempfile.TemporaryDirectory() as model_dir, \
tempfile.TemporaryDirectory() as data_dir: + params_fname = os.path.join(model_dir, C.PARAMS_BEST_NAME) + + # create and save float32 model + model = sockeye.model.SockeyeModel(config=config) + assert model.dtype == pt.float32 + for param in model.parameters(): + assert param.dtype == pt.float32 + model.save_config(model_dir) + model.save_version(model_dir) + model.save_parameters(params_fname) + model.eval() + + # add dummy sentence to data_path + data_path = os.path.join(data_dir, "data.txt") + data_file = open(data_path, 'w') + data_file.write(data + '\n') + data_file.close() + + # create generator from mock model and vocab + generator = DecoderStateGenerator(model, vocabs, vocabs, data_dir, max_seq_len_source, max_seq_len_target, + 'float32', 'int32', pt.device('cpu')) + max_seq_len_target = min(max_seq_len_target + C.SPACE_FOR_XOS, max_seq_len_target) + generator.num_states = DecoderStateGenerator.probe_token_count(data_path, max_seq_len_target) + + # test init_store_file + generator.init_store_file(generator.num_states) + generator.dimension = 16 + + # test generate_states_and_store + data_paths = [data_path] + generator.generate_states_and_store(data_paths, data_paths, 1) + + # check if state and word store files are there + assert os.path.isfile(get_state_store_path(data_dir)) + assert os.path.isfile(get_word_store_path(data_dir)) + + # test save_config + generator.save_config() + + # check if the config content makes sense + config = KNNConfig.load(get_config_path(data_dir)) + assert config.index_size == DecoderStateGenerator.probe_token_count(data_path, max_seq_len_target) + assert config.dimension == 16 + assert config.state_data_type == 'float32' + assert config.word_data_type == 'int32' + assert config.index_type == '' + assert config.train_data_size == -1 + + +@pytest.mark.skipif(not faiss_installed, reason='Faiss is not installed') +def test_faiss_index_builder(): + num_data_points = 16 + num_dimensions = 16 + + config = KNNConfig(num_data_points, num_dimensions, 'float32', 'int32', "Flat", -1) + builder = FaissIndexBuilder(config) + index = builder.init_faiss_index() + + # build data + states = np.outer(np.arange(num_data_points, dtype=np.float32), np.ones(num_dimensions, dtype=np.float32)) + + # offset should be < 0.5 + def query_tests(offset): + # check by querying into the index + for i in range(1, num_data_points - 1): + query = np.expand_dims(states[i], axis=0) + offset + dists, idxs = index.search(query, 3) # 3 because we expect to see itself and the two neighboring ones + + # check for idxs + assert idxs[0][0] == i + + # check for dists + # note that faiss won't do sqrt for the L2 distances + # (reference: https://github.com/facebookresearch/faiss/wiki/MetricType-and-distances#metric_l2) + gld_dists = np.array([[np.power(offset, 2) * num_dimensions, + np.power(1 - offset, 2) * num_dimensions, np.power(1 + offset, 2) * num_dimensions]]) + assert np.all(np.isclose(dists, gld_dists)) + + # test add_items + builder.add_items(index, states) + query_tests(0.1) + index.reset() + + # test block_add_items + builder.block_add_items(index, states, block_size=3) # add items with block size of 3 + query_tests(0.1) + + # test build_train_sample + train_sample = builder.build_train_sample(states, 8) + assert train_sample.shape == (8, num_dimensions) + assert states.dtype == train_sample.dtype diff --git a/test/unit/test_layers.py b/test/unit/test_layers.py index 895d1fa6a..d987d65c5 100644 --- a/test/unit/test_layers.py +++ b/test/unit/test_layers.py @@ -11,11 +11,22 @@ # express 
or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
+from math import pow
+
+import numpy as np
 import pytest
 import torch as pt
 
+import sockeye.constants as C
 import sockeye.layers
 import sockeye.transformer
+from sockeye.knn import KNNConfig, FaissIndexBuilder
+
+# Only run certain tests in this file if faiss is installed
+try:
+    import faiss  # pylint: disable=E0401
+    faiss_installed = True
+except ImportError:
+    faiss_installed = False
 
 
 def test_lhuc():
@@ -166,3 +177,67 @@ def test_interleaved_multihead_self_attention(seq_len, batch_size, hidden, heads
     r_test, _ = mha(inputs, previous_states=None, mask=mask)  # Note: can also handle the mask repeated on the qlen axis
     assert pt.allclose(r_train, r_test, atol=1e-06)
+
+
+@pytest.mark.skipif(not faiss_installed, reason='Faiss is not installed')
+def test_knn_layer():
+    num_data_points = 16
+    num_dimensions = 16
+    assert num_dimensions > 4  # there are at least 4 items in a vocabulary
+
+    config = KNNConfig(num_data_points, num_dimensions, 'float32', 'int32', "Flat", -1)
+    builder = FaissIndexBuilder(config)
+    index = builder.init_faiss_index()
+
+    # build data
+    states = np.outer(np.arange(num_data_points, dtype=np.float32), np.ones(num_dimensions, dtype=np.float32))
+    # prepend a dummy entry at the beginning, since the KNN layer reads vals[(indices + 1) % len(vals)]
+    words = np.arange(num_data_points + 1, dtype=np.int32) - 1
+    words[0] = 0
+    words = np.expand_dims(words, axis=1)
+    builder.add_items(index, states)
+
+    # in case BOS and/or EOS ID are changed, the test should be revisited to make sure no overflow/underflow occurs
+    assert C.BOS_ID + 2 < num_data_points - 1
+    assert C.EOS_ID < C.BOS_ID + 2 or C.EOS_ID > num_data_points
+
+    def build_gld_probs(offset):
+        gld_dists = pt.sqrt(pt.FloatTensor([pow(1 + offset, 2) * num_dimensions,
+                                            pow(offset, 2) * num_dimensions,
+                                            pow(1 - offset, 2) * num_dimensions]))
+        gld_logits = pt.exp(-gld_dists)
+        gld_probs = gld_logits.div_(pt.sum(gld_logits))
+        return gld_probs
+
+    def query_test(knn_layer, offset):
+        for i in range(C.BOS_ID + 2, num_data_points - 1):
+            query = np.expand_dims(states[i], axis=0) + offset
+            query = pt.from_numpy(query)
+            probs = knn_layer(query)
+
+            logits_idxs = pt.LongTensor(list(range(i - 1, i + 2)))
+            gld_probs = pt.zeros(1, num_dimensions)
+            gld_probs[0, logits_idxs] = build_gld_probs(offset)
+            assert pt.allclose(probs, gld_probs)
+
+    # test when inexact distances are used
+    knn_layer = sockeye.layers.KNN(index, words, vocab_size=16, k=3, temperature=1)
+    query_test(knn_layer, 0.1)
+
+    # test when exact distances are used
+    knn_layer = sockeye.layers.KNN(index, words, vocab_size=16, k=3, temperature=1, state_store=states)
+    query_test(knn_layer, 0.1)
+
+    # test BOS case & scatter_add
+    assert C.BOS_ID == 2
+    assert C.EOS_ID == 3
+    offset = 0.1
+    query = np.expand_dims(states[C.BOS_ID], axis=0) + offset
+    query = pt.from_numpy(query)
+    probs = knn_layer(query)
+    gld_probs_unscattered = build_gld_probs(offset)
+    gld_probs = pt.zeros(1, num_dimensions)
+    gld_probs[0, 1] = gld_probs_unscattered[0]
+    gld_probs[0, 3] = gld_probs_unscattered[1] + gld_probs_unscattered[2]
+    gld_probs = gld_probs.div_(pt.sum(gld_probs))
+
+    assert pt.allclose(probs, gld_probs)
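For completeness, the retrieval-to-probability computation that `FaissIndexBuilder` and `layers.KNN` implement can be sketched end to end as follows. This is a minimal sketch assuming `faiss-cpu` is installed; the sizes are illustrative, and the real layer additionally scatter-adds the per-neighbor probabilities onto vocabulary ids via the `vals` memmap (as exercised in `test_knn_layer` above):

    import faiss
    import numpy as np
    import torch as pt

    dim, k, temperature = 16, 3, 10
    keys = np.outer(np.arange(8, dtype=np.float32), np.ones(dim, dtype=np.float32))
    index = faiss.index_factory(dim, "Flat")  # same factory call used by FaissIndexBuilder
    index.add(keys)                           # faiss indexes only support float32

    query = (keys[4:5] + 0.1).astype(np.float32)
    sq_dists, idxs = index.search(query, k)   # faiss returns squared L2 distances
    dists = pt.from_numpy(np.sqrt(sq_dists))  # layers.KNN takes the square root itself
    probs = pt.exp(-dists / temperature)
    probs = probs / probs.sum(dim=-1, keepdim=True)  # normalize over the k retrieved neighbors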