Skip to content

Commit

Permalink
fix gemma bug (#1660)
Browse files — browse the repository at this point in the history
* fix gemma bug

* fix black
  (branch information)
minhthuc2502 authored Apr 10, 2024
1 parent 1deef09 commit 28812c1
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/ctranslate2/converters/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1291,10 +1291,16 @@ def get_model_spec(self, model):
if num_heads_kv == num_heads:
num_heads_kv = None

activation_config = getattr(
model.config, "hidden_activation", "gelu_pytorch_tanh"
)

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
num_layers,
num_heads,
activation=common_spec.Activation.GELU,
activation=common_spec.Activation.GELU
if activation_config == "gelu"
else common_spec.Activation.GELUTanh,
pre_norm=True,
ffn_glu=True,
rms_norm=True,
Expand Down

0 comments on commit 28812c1

Please sign in to comment.