Accept Transformer models with custom number of layers and heads (#31)
guillaumekln authored Oct 28, 2019
1 parent e32621d commit b4ed10b
Showing 8 changed files with 27 additions and 10 deletions.
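
In practice, this change lets a converter be given a TransformerSpec with an arbitrary number of layers and attention heads, instead of being limited to the predefined TransformerBase and TransformerBig catalog entries. A minimal sketch of the Python-side usage, assuming this revision of the package is installed; the TransformerSpec constructor call is taken from the test added in python/tests/test.py below, while the converter class and its arguments are illustrative assumptions, not a verbatim API reference:

import ctranslate2

# A Transformer with custom geometry: 6 layers, 8 attention heads.
model_spec = ctranslate2.specs.TransformerSpec(num_layers=6, num_heads=8)

# Hypothetical conversion call -- the exact converter class and arguments
# depend on the source framework (OpenNMT-py or OpenNMT-tf).
converter = ctranslate2.converters.OpenNMTTFConverter(
    "v1/savedmodel", src_vocab=None, tgt_vocab=None)
converter.convert("ende_ctranslate2", model_spec)
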
3 changes: 2 additions & 1 deletion include/ctranslate2/models/transformer.h
@@ -12,13 +12,14 @@ namespace ctranslate2 {
class TransformerModel : public Model
{
public:
+ TransformerModel(const std::string& path, size_t spec_revision, size_t num_heads = 0);
size_t num_heads() const;
size_t current_spec_revision() const override;
std::unique_ptr<layers::Encoder> make_encoder() const override;
std::unique_ptr<layers::Decoder> make_decoder() const override;
protected:
- TransformerModel(const std::string& path, size_t spec_revision, size_t num_heads);
void register_variable(const std::string& name, StorageView& variable) override;
+ void finalize() override;

size_t _num_heads;
};
4 changes: 2 additions & 2 deletions python/ctranslate2/converters/opennmt_py.py
@@ -1,6 +1,6 @@
from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
- from ctranslate2.specs import catalog, transformer_spec
+ from ctranslate2.specs import transformer_spec


class OpenNMTPyConverter(Converter):
@@ -22,7 +22,7 @@ def _load(self, model_spec):
variables = checkpoint["model"]
variables["generator.weight"] = checkpoint["generator"]["0.weight"]
variables["generator.bias"] = checkpoint["generator"]["0.bias"]
- if isinstance(model_spec, (catalog.TransformerBase, catalog.TransformerBig)):
+ if isinstance(model_spec, transformer_spec.TransformerSpec):
set_transformer_spec(model_spec, variables)
else:
raise NotImplementedError()
4 changes: 2 additions & 2 deletions python/ctranslate2/converters/opennmt_tf.py
@@ -4,7 +4,7 @@

from ctranslate2.converters import utils
from ctranslate2.converters.converter import Converter
- from ctranslate2.specs import catalog, transformer_spec
+ from ctranslate2.specs import transformer_spec


def load_model(model_path, src_vocab=None, tgt_vocab=None):
@@ -87,7 +87,7 @@ def _load(self, model_spec):
self._model_path,
src_vocab=self._src_vocab,
tgt_vocab=self._tgt_vocab)
- if isinstance(model_spec, (catalog.TransformerBase, catalog.TransformerBig)):
+ if isinstance(model_spec, transformer_spec.TransformerSpec):
if version == 2:
set_transformer_spec_v2(model_spec, variables)
else:
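
The isinstance change in both converters means any spec deriving from TransformerSpec is accepted, including the predefined catalog entries (which subclass it) and custom instances. A short illustration, assuming this revision of the package is importable:

from ctranslate2.specs import catalog, transformer_spec

# The catalog presets subclass TransformerSpec, so the new check
# isinstance(model_spec, transformer_spec.TransformerSpec) accepts all of these:
for spec in (catalog.TransformerBase(),
             catalog.TransformerBig(),
             transformer_spec.TransformerSpec(6, 12)):  # custom head count
    assert isinstance(spec, transformer_spec.TransformerSpec)
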
4 changes: 2 additions & 2 deletions python/ctranslate2/specs/catalog.py
@@ -5,8 +5,8 @@

class TransformerBase(transformer_spec.TransformerSpec):
def __init__(self):
- super(TransformerBase, self).__init__(6)
+ super(TransformerBase, self).__init__(6, 8)

class TransformerBig(transformer_spec.TransformerSpec):
def __init__(self):
- super(TransformerBig, self).__init__(6)
+ super(TransformerBig, self).__init__(6, 16)
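
With this change the named presets are plain TransformerSpec instances with a fixed geometry: 6 layers and 8 heads for the base model, 6 layers and 16 heads for the big model. A quick check, assuming the package from this commit is importable:

from ctranslate2.specs import catalog

base = catalog.TransformerBase()   # TransformerSpec(6, 8)
big = catalog.TransformerBig()     # TransformerSpec(6, 16)

# num_heads is stored as a NumPy int8 scalar (see transformer_spec.py below).
assert int(base.num_heads) == 8
assert int(big.num_heads) == 16
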
11 changes: 9 additions & 2 deletions python/ctranslate2/specs/transformer_spec.py
@@ -4,13 +4,20 @@


class TransformerSpec(model_spec.LayerSpec):
- def __init__(self, num_layers):
+ """Describes a Transformer model.
+
+ The specification is invariant to hidden dimensions but requires to
+ explicitly set the number of layers and attention heads.
+ """
+ def __init__(self, num_layers, num_heads):
+ import numpy as np
+ self.num_heads = np.dtype("int8").type(num_heads)
self.encoder = TransformerEncoderSpec(num_layers)
self.decoder = TransformerDecoderSpec(num_layers)

@property
def revision(self):
- return 2
+ return 3

class TransformerEncoderSpec(model_spec.LayerSpec):
def __init__(self, num_layers):
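
Storing num_heads as a NumPy int8 scalar means it gets serialized with the converted model like any other variable, so the C++ runtime can recover the head count without it being hard-coded in a named spec. A standalone sketch of that round trip using plain NumPy (the actual CTranslate2 serialization format is not part of this diff):

import numpy as np

num_heads = 8

# What the spec does: wrap the integer in an int8 NumPy scalar so that it is
# written out as a (tiny) variable alongside the model weights.
stored = np.dtype("int8").type(num_heads)

# What the loader conceptually does later: read the scalar back.
recovered = int(stored)   # counterpart of as_scalar<int8_t>() on the C++ side
assert recovered == 8
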
1 change: 1 addition & 0 deletions python/tests/test.py
@@ -92,6 +92,7 @@ def test_return_attention():
@pytest.mark.parametrize(
"model_path,src_vocab,tgt_vocab,model_spec",
[("v1/savedmodel", None, None, "TransformerBase"),
("v1/savedmodel", None, None, ctranslate2.specs.TransformerSpec(num_layers=6, num_heads=8)),
("v1/checkpoint", "ar.vocab", "en.vocab", ctranslate2.specs.TransformerBase()),
("v2/checkpoint", "ar.vocab", "en.vocab", ctranslate2.specs.TransformerBase()),
])
2 changes: 2 additions & 0 deletions src/models/model.cc
@@ -313,6 +313,8 @@ namespace ctranslate2 {
model = new TransformerBaseModel(path, spec_revision);
else if (spec == "TransformerBig")
model = new TransformerBigModel(path, spec_revision);
else if (spec == "TransformerSpec")
model = new TransformerModel(path, spec_revision);
else
throw std::invalid_argument("Unsupported model spec " + spec);

8 changes: 7 additions & 1 deletion src/models/transformer.cc
@@ -46,7 +46,7 @@ namespace ctranslate2 {
}

size_t TransformerModel::current_spec_revision() const {
- return 2;
+ return 3;
}

void TransformerModel::register_variable(const std::string& name, StorageView& variable) {
@@ -56,6 +56,12 @@ namespace ctranslate2 {
Model::register_variable(var_name, variable);
}

+ void TransformerModel::finalize() {
+ Model::finalize();
+ if (_spec_revision >= 3)
+ _num_heads = get_variable("num_heads").as_scalar<int8_t>();
+ }

std::unique_ptr<layers::Encoder> TransformerModel::make_encoder() const {
return std::unique_ptr<layers::Encoder>(new TransformerEncoder(*this, "encoder"));
}
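
finalize() only looks up the stored num_heads variable when the model was produced with spec revision 3 or later; older models keep the head count passed to the constructor by the named spec. A rough Python analogue of that fallback logic, for illustration only (the authoritative implementation is the C++ above):

def resolve_num_heads(variables, spec_revision, constructor_num_heads):
    """Mirrors TransformerModel::finalize(): prefer the stored scalar."""
    if spec_revision >= 3:
        # Revision 3 models carry the head count as an int8 scalar variable.
        return int(variables["num_heads"])
    # Pre-revision-3 models rely on the value baked into the named spec
    # (8 for TransformerBase, 16 for TransformerBig).
    return constructor_num_heads

# Example: an old "TransformerBig" model vs. a new custom model.
assert resolve_num_heads({}, 2, 16) == 16
assert resolve_num_heads({"num_heads": 12}, 3, 0) == 12
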
