diff --git a/python/ctranslate2/converters/eole_ct2.py b/python/ctranslate2/converters/eole_ct2.py
new file mode 100644
index 000000000..e20d879c9
--- /dev/null
+++ b/python/ctranslate2/converters/eole_ct2.py
@@ -0,0 +1,352 @@
+import argparse
+
+from eole.config.run import PredictConfig
+from eole.constants import PositionEncodingType
+from eole.inputters.inputter import vocabs_to_dict
+from eole.models.model import BaseModel
+
+from ctranslate2.converters import utils
+from ctranslate2.converters.converter import Converter
+from ctranslate2.specs import common_spec, transformer_spec
+
+_SUPPORTED_ACTIVATIONS = {
+    "gelu": common_spec.Activation.GELU,
+    "fast_gelu": common_spec.Activation.GELUTanh,
+    "relu": common_spec.Activation.RELU,
+    "gated-silu": common_spec.Activation.SWISH,
+}
+
+
+def _get_model_spec_seq2seq(
+    config, variables, src_vocabs, tgt_vocabs, num_source_embeddings
+):
+    """Creates a model specification from the model config."""
+    with_relative_position = (
+        getattr(config.embeddings, "position_encoding_type", None)
+        == PositionEncodingType.Relative
+    )
+    with_rotary = (
+        getattr(config.embeddings, "position_encoding_type", None)
+        == PositionEncodingType.Rotary
+    )
+    if with_rotary:
+        raise ValueError(
+            "Rotary embeddings are not supported yet for encoder/decoder models"
+        )
+    with_alibi = (
+        getattr(config.embeddings, "position_encoding_type", None)
+        == PositionEncodingType.Alibi
+    )
+    if with_alibi:
+        raise ValueError("Alibi is not supported yet for encoder/decoder models")
+    activation_fn = getattr(config, "mlp_activation_fn", "relu")
+
+    # Return the first head of the last layer unless the model was trained with alignments.
+    if getattr(config.decoder, "lambda_align", 0) == 0:
+        alignment_layer = -1
+        alignment_heads = 1
+    else:
+        alignment_layer = config.decoder.alignment_layer
+        alignment_heads = config.decoder.alignment_heads
+
+    num_heads = getattr(config.decoder, "heads", 8)
+    # num_kv = getattr(config.decoder, "heads_kv", 0)
+    # if num_kv == num_heads or num_kv == 0:
+    #     num_kv = None
+    # rotary_dim = 0 if with_rotary else None
+    # rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
+    ffn_glu = activation_fn == "gated-silu"
+    sliding_window = getattr(config, "sliding_window", 0)
+    if sliding_window != 0:
+        raise ValueError(
+            "Sliding window is not supported yet for encoder/decoder models"
+        )
+
+    model_spec = transformer_spec.TransformerSpec.from_config(
+        (config.encoder.layers, config.decoder.layers),
+        num_heads,
+        with_relative_position=with_relative_position,
+        # alibi=with_alibi,
+        activation=_SUPPORTED_ACTIVATIONS[activation_fn],
+        ffn_glu=ffn_glu,
+        rms_norm=config.layer_norm == "rms",
+        # rotary_dim=rotary_dim,
+        # rotary_interleave=rotary_interleave,
+        # num_heads_kv=num_kv,
+        # sliding_window=sliding_window,
+        alignment_layer=alignment_layer,
+        alignment_heads=alignment_heads,
+        num_source_embeddings=num_source_embeddings,
+        # multi_query_attention=getattr(opt, "multiquery", False),
+    )
+
+    set_transformer_spec(model_spec, variables)
+    for src_vocab in src_vocabs:
+        model_spec.register_source_vocabulary(src_vocab)
+    for tgt_vocab in tgt_vocabs:
+        model_spec.register_target_vocabulary(tgt_vocab)
+
+    return model_spec
+
+
+def _get_model_spec_lm(
+    config, variables, src_vocabs, tgt_vocabs, num_source_embeddings
+):
+    """Creates a model specification from the model config."""
+    with_relative_position = (
+        getattr(config.embeddings, "position_encoding_type", None)
+        == PositionEncodingType.Relative
+    )
+    with_rotary = (
+        getattr(config.embeddings, "position_encoding_type", None)
+        == PositionEncodingType.Rotary
+    )
+    with_alibi = (
+        getattr(config.embeddings, "position_encoding_type", None)
+        == PositionEncodingType.Alibi
+    )
+    activation_fn = getattr(config, "mlp_activation_fn", "relu")
+    num_heads = getattr(config.decoder, "heads", 8)
+    num_kv = getattr(config.decoder, "heads_kv", 0)
+    if num_kv == num_heads or num_kv == 0:
+        num_kv = None
+    rotary_dim = 0 if with_rotary else None
+    rotary_interleave = getattr(config.rope_config, "rotary_interleave", True)
+    ffn_glu = activation_fn == "gated-silu"
+    sliding_window = getattr(config, "sliding_window", 0)
+
+    model_spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+        config.decoder.layers,
+        num_heads,
+        activation=_SUPPORTED_ACTIVATIONS[activation_fn],
+        ffn_glu=ffn_glu,
+        with_relative_position=with_relative_position,
+        alibi=with_alibi,
+        rms_norm=config.layer_norm == "rms",
+        rotary_dim=rotary_dim,
+        rotary_interleave=rotary_interleave,
+        num_heads_kv=num_kv,
+        sliding_window=sliding_window,
+        # multi_query_attention=getattr(opt, "multiquery", False),
+    )
+
+    set_transformer_decoder(
+        model_spec.decoder,
+        variables,
+        with_encoder_attention=False,
+    )
+
+    for tgt_vocab in tgt_vocabs:
+        model_spec.register_vocabulary(tgt_vocab)
+
+    return model_spec
+
+
+def get_vocabs(vocab):
+    src_vocabs = [vocab["src"]]
+    tgt_vocabs = [vocab["tgt"]]
+    return src_vocabs, tgt_vocabs
+
+
+class EoleConverter(Converter):
+    """Converts models trained with Eole."""
+
+    def __init__(self, model_path: str):
+        """Initializes the Eole converter.
+
+        Arguments:
+          model_path: Path to the Eole model directory.
+        """
+        self._model_path = model_path
+
+    def _load(self):
+        import torch
+
+        # Build a minimal predict config to locate and load the checkpoint
+        # (src is required by PredictConfig but not used for conversion).
+        config = PredictConfig(model_path=self._model_path, src="dummy")
+
+        vocabs, model, model_config = BaseModel.load_test_model(config)
+        vocabs_dict = vocabs_to_dict(vocabs)
+
+        config.model = model_config
+        src_vocabs, tgt_vocabs = get_vocabs(vocabs_dict)
+
+        if config.model.decoder.decoder_type == "transformer_lm":
+            spec = _get_model_spec_lm(
+                config.model,
+                model.state_dict(),
+                src_vocabs,
+                tgt_vocabs,
+                num_source_embeddings=len(src_vocabs),
+            )
+        else:
+            spec = _get_model_spec_seq2seq(
+                config.model,
+                model.state_dict(),
+                src_vocabs,
+                tgt_vocabs,
+                num_source_embeddings=len(src_vocabs),
+            )
+            spec.config.decoder_start_token = vocabs["decoder_start_token"]
+
+        spec.config.bos_token = vocabs["specials"]["bos_token"]
+        spec.config.eos_token = vocabs["specials"]["eos_token"]
+        spec.config.unk_token = vocabs["specials"]["unk_token"]
+        spec.config.layer_norm_epsilon = getattr(config, "norm_eps", 1e-6)
+
+        return spec
+
+
+def set_transformer_spec(spec, variables):
+    set_transformer_encoder(spec.encoder, variables)
+    set_transformer_decoder(spec.decoder, variables)
+
+
+def set_transformer_encoder(spec, variables):
+    set_input_layers(spec, variables, "src_emb")
+    set_layer_norm(spec.layer_norm, variables, "encoder.layer_norm")
+    for i, layer in enumerate(spec.layer):
+        set_transformer_encoder_layer(
+            layer, variables, "encoder.transformer_layers.%d" % i
+        )
+
+
+def set_transformer_decoder(spec, variables, with_encoder_attention=True):
+    set_input_layers(spec, variables, "tgt_emb")
+    set_layer_norm(spec.layer_norm, variables, "decoder.layer_norm")
+    for i, layer in enumerate(spec.layer):
+        set_transformer_decoder_layer(
+            layer,
+            variables,
+            "decoder.transformer_layers.%d" % i,
+            with_encoder_attention=with_encoder_attention,
+        )
+
+    set_linear(spec.projection, variables, "generator")
+
+
+def set_input_layers(spec, variables, scope):
+    if hasattr(spec, "position_encodings"):
+        set_position_encodings(
+            spec.position_encodings,
+            variables,
+            "%s.pe" % scope,
+        )
+    else:
+        spec.scale_embeddings = False
+
+    embeddings_specs = spec.embeddings
+    # Encoder embeddings are stored in a list (onmt/ct2 legacy with features).
+    if isinstance(embeddings_specs, list):
+        embeddings_specs = embeddings_specs[0]
+    set_embeddings(embeddings_specs, variables, "%s.embeddings" % scope)
+
+
+def set_transformer_encoder_layer(spec, variables, scope):
+    set_multi_head_attention(
+        spec.self_attention,
+        variables,
+        "%s.self_attn" % scope,
+        self_attention=True,
+    )
+    set_layer_norm(
+        spec.self_attention.layer_norm, variables, "%s.input_layernorm" % scope
+    )
+    set_layer_norm(
+        spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
+    )
+    set_ffn(spec.ffn, variables, "%s.mlp" % scope)
+
+
+def set_transformer_decoder_layer(spec, variables, scope, with_encoder_attention=True):
+    set_multi_head_attention(
+        spec.self_attention,
+        variables,
+        "%s.self_attn" % scope,
+        self_attention=True,
+    )
+    set_layer_norm(
+        spec.self_attention.layer_norm, variables, "%s.input_layernorm" % scope
+    )
+    if with_encoder_attention:
+        set_multi_head_attention(spec.attention, variables, "%s.context_attn" % scope)
+        set_layer_norm(
+            spec.attention.layer_norm, variables, "%s.precontext_layernorm" % scope
+        )
+    set_layer_norm(
+        spec.ffn.layer_norm, variables, "%s.post_attention_layernorm" % scope
+    )
+    set_ffn(spec.ffn, variables, "%s.mlp" % scope)
+
+
+def set_ffn(spec, variables, scope):
+    set_linear(spec.linear_0, variables, "%s.gate_up_proj" % scope)
+    set_linear(spec.linear_1, variables, "%s.down_proj" % scope)
+    if hasattr(spec, "linear_0_noact"):
+        set_linear(spec.linear_0_noact, variables, "%s.up_proj" % scope)
+
+
+def set_multi_head_attention(spec, variables, scope, self_attention=False):
+    if self_attention:
+        # CTranslate2 expects the Q/K/V projections fused into a single linear layer.
+        split_layers = [common_spec.LinearSpec() for _ in range(3)]
+        set_linear(split_layers[0], variables, "%s.linear_query" % scope)
+        set_linear(split_layers[1], variables, "%s.linear_keys" % scope)
+        set_linear(split_layers[2], variables, "%s.linear_values" % scope)
+        utils.fuse_linear(spec.linear[0], split_layers)
+    else:
+        set_linear(spec.linear[0], variables, "%s.linear_query" % scope)
+        split_layers = [common_spec.LinearSpec() for _ in range(2)]
+        set_linear(split_layers[0], variables, "%s.linear_keys" % scope)
+        set_linear(split_layers[1], variables, "%s.linear_values" % scope)
+        utils.fuse_linear(spec.linear[1], split_layers)
+    set_linear(spec.linear[-1], variables, "%s.final_linear" % scope)
+    if hasattr(spec, "relative_position_keys"):
+        spec.relative_position_keys = _get_variable(
+            variables, "%s.relative_positions_embeddings.weight" % scope
+        )
+        spec.relative_position_values = spec.relative_position_keys
+
+
+def set_layer_norm(spec, variables, scope):
+    try:
+        spec.gamma = _get_variable(variables, "%s.weight" % scope)
+    except KeyError:
+        # Compatibility with older models using a custom LayerNorm module.
+        spec.gamma = _get_variable(variables, "%s.a_2" % scope)
+        spec.beta = _get_variable(variables, "%s.b_2" % scope)
+    try:
+        spec.beta = _get_variable(variables, "%s.bias" % scope)
+    except KeyError:
+        pass
+
+
+def set_linear(spec, variables, scope):
+    spec.weight = _get_variable(variables, "%s.weight" % scope)
+    bias = variables.get("%s.bias" % scope)
+    if bias is not None:
+        spec.bias = bias
+
+
+def set_embeddings(spec, variables, scope):
+    spec.weight = _get_variable(variables, "%s.weight" % scope)
+
+
+def set_position_encodings(spec, variables, scope):
+    spec.encodings = _get_variable(variables, "%s.pe" % scope).squeeze()
+
+
+def _get_variable(variables, name):
+    return variables[name]
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+    parser.add_argument("--model_path", required=True, help="Model path.")
+    Converter.declare_arguments(parser)
+    args = parser.parse_args()
+    EoleConverter(args.model_path).convert_from_args(args)
+
+
+if __name__ == "__main__":
+    main()
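
Usage note (not part of the diff): a minimal sketch of how this converter would typically be driven once the file lands at python/ctranslate2/converters/eole_ct2.py as above. The conversion itself goes through the base Converter.convert() method; "eole_model_dir" and "ct2_model" below are placeholder paths, not names defined by this patch.

    # Hypothetical paths: point "eole_model_dir" at a model directory produced by Eole training.
    from ctranslate2.converters.eole_ct2 import EoleConverter

    converter = EoleConverter("eole_model_dir")
    converter.convert("ct2_model", quantization="int8", force=True)

From the command line, the same conversion should be reachable through the arguments declared by Converter.declare_arguments (for example --output_dir and --quantization), e.g. python -m ctranslate2.converters.eole_ct2 --model_path eole_model_dir --output_dir ct2_model.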