Skip to content

Commit

Permalink
Add kNN-MT (#1062)
Browse files Browse the repository at this point in the history
Co-authored-by: Shuoyang Ding <dings@amazon.com>
Co-authored-by: Gary Gao <gary_gao2000@hotmail.com>
Co-authored-by: Jeremy Gwinnup <jeremy@gwinnup.org>
Co-authored-by: Tobias Domhan <domhant@amazon.com>
  • Loading branch information
5 people authored Dec 10, 2022
1 parent 288baa7 commit b4c8427
Show file tree
Hide file tree
Showing 18 changed files with 1,056 additions and 27 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/push_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ jobs:
run: pip install -r requirements/requirements.txt
- name: DeepSpeed requirements
run: pip install -r requirements/requirements.deepspeed.txt
- name: Faiss requirements
run: |
if [ "$RUNNER_OS" == "Linux" ]; then
pip install -r requirements/requirements.faiss-cpu.txt
fi
shell: bash
- name: Development requirements
run: pip install -r requirements/requirements.dev.txt
- name: Unit tests
Expand Down
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,18 @@ Note that Sockeye has checks in place to not translate with an old model that wa

Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.

## [3.1.28]

### Added

- Added kNN-MT model from [Khandelwal et al., 2021](https://arxiv.org/abs/2010.00710).
- Installation: see [faiss document](https://github.com/facebookresearch/faiss/blob/main/INSTALL.md) -- installation via conda is recommended.
- Building a faiss index from a sockeye model takes two steps:
- Generate decoder states: `sockeye-generate-decoder-states -m [model] --source [src] --target [tgt] --output-dir [output dir]`
- Build index: `sockeye-knn -i [input_dir] -o [output_dir] -t [faiss_index_signature]` where `input_dir` is the same as `output_dir` from the `sockeye-generate-decoder-states` command.
- Faiss index signature reference: [see here](https://github.com/facebookresearch/faiss/wiki/The-index-factory)
- Running inference using the built index: `sockeye-translate ... --knn-index [index_dir] --knn-lambda [interpolation_weight]` where `index_dir` is the same as `output_dir` from the `sockeye-knn` command.

## [3.1.27]

### Changed
Expand Down
1 change: 1 addition & 0 deletions requirements/requirements.faiss-cpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
faiss-cpu >= 1.7.2
1 change: 1 addition & 0 deletions requirements/requirements.faiss-gpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
faiss-gpu >= 1.7.2
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def get_requirements(filename):
'sockeye-translate = sockeye.translate:main',
'sockeye-vocab = sockeye.vocab:main',
'sockeye-rerank = sockeye.rerank:main',
'sockeye-knn = sockeye.knn:main',
'sockeye-generate-decoder-states = sockeye.generate_decoder_states:main'
],
}
args = dict(
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '3.1.27'
__version__ = '3.1.28'
82 changes: 80 additions & 2 deletions sockeye/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,6 +1131,7 @@ def add_translate_cli_args(params):
add_inference_args(params)
add_device_args(params)
add_logging_args(params)
add_knn_mt_args(params) # for kNN MT


def add_score_cli_args(params):
Expand Down Expand Up @@ -1173,7 +1174,40 @@ def add_score_cli_args(params):
'peaked predictions, values > 1.0 produce smoothed distributions.')

params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
help="Data type. Default: %(default)s infers from saved model.")
help="Data type. Default: infers from saved model.")

add_logging_args(params)


def add_state_generation_args(params):
add_training_data_args(params, required=True)
add_vocab_args(params)
add_device_args(params)
add_batch_args(params, default_batch_size=56, default_batch_type=C.BATCH_TYPE_SENTENCE)

decode_params = params.add_argument_group("Decoder state generation parameters")

params.add_argument('--state-dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16],
help="Data type of the decoder state store. Default: infers from saved model.")

params.add_argument("--model", "-m", required=True,
help="Model directory containing trained model.")

params.add_argument(C.TRAINING_ARG_MAX_SEQ_LEN,
type=multiple_values(num_values=2, greater_or_equal=1),
default=None,
help='Maximum sequence length in tokens.'
'Use "x:x" to specify separate values for src&tgt. Default: Read from model.')

# common params with translate CLI
add_length_penalty_args(params)
add_brevity_penalty_args(params)

params.add_argument("--output-dir", "-o", default=None,
help="The path to the directory that stores the decoder states.")

params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
help="Data type. Default: infers from saved model.")

add_logging_args(params)

Expand Down Expand Up @@ -1343,7 +1377,7 @@ def add_inference_args(params):
add_brevity_penalty_args(decode_params)

decode_params.add_argument('--dtype', default=None, choices=[None, C.DTYPE_FP32, C.DTYPE_FP16, C.DTYPE_INT8],
help="Data type. Default: %(default)s infers from saved model.")
help="Data type. Default: infers from saved model.")
add_clamp_to_dtype_arg(decode_params)


Expand Down Expand Up @@ -1380,6 +1414,7 @@ def add_brevity_penalty_args(params):
'ratio, used for brevity penalty calculation, for all inputs. If zero, uses the average of length '
'ratios from the training data over all models. Default: %(default)s.')


def add_clamp_to_dtype_arg(params):
params.add_argument('--clamp-to-dtype',
action='store_true',
Expand Down Expand Up @@ -1423,3 +1458,46 @@ def add_build_vocab_args(params):
params.add_argument('-o', '--output', required=True, type=str, help="Output filename to write vocabulary to.")
add_vocab_args(params)
add_process_pool_args(params)


def add_knn_mt_args(params):
knn_params = params.add_argument_group("kNN MT parameters")

knn_params.add_argument('--knn-index',
type=str,
help='Optionally use a KNN index during inference to '
'retrieve similar hidden states and corresponding target tokens.',
default=None)
knn_params.add_argument('--knn-lambda',
type=float,
help="Interpolation parameter when using KNN index. Default: %(default)s.",
default=C.DEFAULT_KNN_LAMBDA)


def add_build_knn_index_args(params):
params.add_argument('-i', '--input-dir',
required=True,
type=str,
help='The directory that contains the stored decoder states and values '
f'({C.KNN_STATE_DATA_STORE_NAME} and {C.KNN_WORD_DATA_STORE_NAME}).')
params.add_argument('-o', '--output-dir',
default=None,
type=str,
help='The path to the output directory. Will reuse input directory if not specified.')
params.add_argument('-t', '--index-type',
default=None,
type=str,
help='An optional field to specify the type of the index. '
'Will override settings in the config. '
'The type is specified with a faiss index factory signature, see here: '
'https://github.com/facebookresearch/faiss/wiki/The-index-factory')
params.add_argument('--train-data-input-file',
default=None,
type=str,
help='An optional field to reuse an already-built training data sample for the index. '
'Otherwise, a (slow) sampling step might need to be run.')
params.add_argument('--train-data-size',
default=None,
type=int,
help='An optional field to specify the size of the training sample. '
'Will override settings in the config.')
35 changes: 26 additions & 9 deletions sockeye/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ class _SingleModelInference(_Inference):
def __init__(self,
model: SockeyeModel,
skip_softmax: bool = False,
constant_length_ratio: float = 0.0) -> None:
constant_length_ratio: float = 0.0,
knn_lambda: float = C.DEFAULT_KNN_LAMBDA) -> None:
self._model = model
self._skip_softmax = skip_softmax
self._const_lr = constant_length_ratio
self.knn_lambda = knn_lambda

def state_structure(self) -> List:
return [self._model.state_structure()]
Expand All @@ -83,10 +85,17 @@ def decode_step(self,
vocab_slice_ids: Optional[pt.Tensor] = None,
target_prefix_factor_mask: Optional[pt.Tensor] = None,
factor_vocab_size: Optional[int] = None):
logits, states, target_factor_outputs = self._model.decode_step(step_input, states, vocab_slice_ids)
logits, knn_probs, states, target_factor_outputs = self._model.decode_step(step_input, states, vocab_slice_ids)
if not self._skip_softmax:
logits = pt.log_softmax(logits, dim=-1)
scores = -logits # shape: (batch*beam, output_vocab_size/len(vocab_slice_ids))
if knn_probs is None: # no knn used
probs = pt.log_softmax(logits, dim=-1)
else:
probs = pt.log(self.knn_lambda * pt.softmax(logits, dim=-1) + (1 - self.knn_lambda) * knn_probs)
else:
assert knn_probs is None, "Can't skip softmax with KNN."
probs = logits

scores = -probs # shape: (batch*beam, output_vocab_size/len(vocab_slice_ids))

target_factors = None # type: Optional[pt.Tensor]
if target_factor_outputs:
Expand Down Expand Up @@ -120,7 +129,8 @@ class _EnsembleInference(_Inference):
def __init__(self,
models: List[SockeyeModel],
ensemble_mode: str = 'linear',
constant_length_ratio: float = 0.0) -> None:
constant_length_ratio: float = 0.0,
knn_lambda: float = C.DEFAULT_KNN_LAMBDA) -> None:
self._models = models
if ensemble_mode == 'linear':
self._interpolation = self.linear_interpolation
Expand All @@ -129,6 +139,7 @@ def __init__(self,
else:
raise ValueError()
self._const_lr = constant_length_ratio
self.knn_lambda = knn_lambda

def state_structure(self) -> List:
return [model.state_structure() for model in self._models]
Expand Down Expand Up @@ -163,8 +174,11 @@ def decode_step(self,
for model, model_state_structure in zip(self._models, self.state_structure()):
model_states = states[state_index:state_index+len(model_state_structure)]
state_index += len(model_state_structure)
logits, model_states, target_factor_outputs = model.decode_step(step_input, model_states, vocab_slice_ids)
probs = logits.softmax(dim=-1)
logits, knn_probs, model_states, target_factor_outputs = model.decode_step(step_input, model_states, vocab_slice_ids)
if knn_probs is None:
probs = logits.softmax(dim=-1)
else:
probs = self.knn_lambda * pt.softmax(logits, dim=-1) + (1 - self.knn_lambda) * knn_probs
outputs.append(probs)
if target_factor_outputs:
target_factor_probs = [tfo.softmax(dim=-1) for tfo in target_factor_outputs]
Expand Down Expand Up @@ -1086,6 +1100,7 @@ def get_search_algorithm(models: List[SockeyeModel],
ensemble_mode: str = 'linear',
beam_search_stop: str = C.BEAM_SEARCH_STOP_ALL,
constant_length_ratio: float = 0.0,
knn_lambda: float = C.DEFAULT_KNN_LAMBDA,
sample: Optional[int] = None,
prevent_unk: bool = False,
greedy: bool = False,
Expand Down Expand Up @@ -1114,7 +1129,8 @@ def get_search_algorithm(models: List[SockeyeModel],
num_target_factors=models[0].num_target_factors,
inference=_SingleModelInference(model=models[0],
skip_softmax=True,
constant_length_ratio=0.0),
constant_length_ratio=0.0,
knn_lambda=knn_lambda),
skip_nvs=skip_nvs,
nvs_thresh=nvs_thresh)
else:
Expand All @@ -1125,7 +1141,8 @@ def get_search_algorithm(models: List[SockeyeModel],
logger.info("Enabled skipping softmax for a single model and greedy decoding.")
inference = _SingleModelInference(model=models[0],
skip_softmax=skip_softmax,
constant_length_ratio=constant_length_ratio)
constant_length_ratio=constant_length_ratio,
knn_lambda=knn_lambda)
else:
inference = _EnsembleInference(models=models,
ensemble_mode=ensemble_mode,
Expand Down
11 changes: 11 additions & 0 deletions sockeye/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@
DTYPE_FP32 = 'float32'
DTYPE_TF32 = 'tf32'
DTYPE_INT8 = 'int8'
DTYPE_INT16 = 'int16'
DTYPE_INT32 = 'int32'
LARGE_POSITIVE_VALUE = 99999999.
LARGE_VALUES = {
Expand Down Expand Up @@ -381,3 +382,13 @@
BREVITY_PENALTY_CONSTANT = 'constant'
BREVITY_PENALTY_LEARNED = 'learned'
BREVITY_PENALTY_NONE = 'none'

# k-nn
KNN_STATE_DATA_STORE_NAME = "keys.npy"
KNN_WORD_DATA_STORE_NAME = "vals.npy"
KNN_WORD_DATA_STORE_DTYPE = DTYPE_INT32
KNN_CONFIG_NAME = "config.yaml"
KNN_INDEX_NAME = "key_index"
KNN_EPSILON = 1e-6
DEFAULT_DATA_STORE_BLOCK_SIZE = 1024 * 1024
DEFAULT_KNN_LAMBDA = 0.8
Loading

0 comments on commit b4c8427

Please sign in to comment.