diff --git a/.allennlp_plugins b/.allennlp_plugins index 19dbb52..057be1e 100644 --- a/.allennlp_plugins +++ b/.allennlp_plugins @@ -1 +1 @@ -srl_transformers \ No newline at end of file +transformer_srl \ No newline at end of file diff --git a/setup.py b/setup.py index 14f72fc..0a74821 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="transformer_srl", # Replace with your own username - version="1.0", + version="1.1", author="Riccardo Orlando", author_email="orlandoricc@gmail.com", description="SRL Transformer model", diff --git a/training_config/bert_base.jsonnet b/training_config/bert_base.jsonnet index 5f70275..3aa9479 100644 --- a/training_config/bert_base.jsonnet +++ b/training_config/bert_base.jsonnet @@ -1,6 +1,6 @@ { "dataset_reader": { - "type": "srl_transformers", + "type": "transformer_srl", "bert_model_name": "bert-base-cased", }, @@ -15,7 +15,7 @@ "validation_data_path": std.extVar("SRL_VALIDATION_DATA_PATH"), "model": { - "type": "srl_transformers", + "type": "transformer_srl", "embedding_dropout": 0.1, "bert_model": "bert-base-cased", }, diff --git a/training_config/bert_base_ml.jsonnet b/training_config/bert_base_ml.jsonnet index f2427f7..0dc1792 100644 --- a/training_config/bert_base_ml.jsonnet +++ b/training_config/bert_base_ml.jsonnet @@ -1,6 +1,6 @@ { "dataset_reader": { - "type": "srl_transformers", + "type": "transformer_srl", "bert_model_name": "bert-base-multilingual-cased", }, @@ -15,7 +15,7 @@ "validation_data_path": std.extVar("SRL_VALIDATION_DATA_PATH"), "model": { - "type": "srl_transformers", + "type": "transformer_srl", "embedding_dropout": 0.1, "bert_model": "bert-base-multilingual-cased", }, diff --git a/training_config/bert_large.jsonnet b/training_config/bert_large.jsonnet index b49a13c..abb5006 100644 --- a/training_config/bert_large.jsonnet +++ b/training_config/bert_large.jsonnet @@ -1,6 +1,6 @@ { "dataset_reader": { - "type": "srl_transformers", + "type": "transformer_srl", "bert_model_name": "bert-large-cased", }, @@ -15,7 +15,7 @@ "validation_data_path": std.extVar("SRL_VALIDATION_DATA_PATH"), "model": { - "type": "srl_transformers", + "type": "transformer_srl", "embedding_dropout": 0.1, "bert_model": "bert-large-cased", }, diff --git a/training_config/bert_tiny.jsonnet b/training_config/bert_tiny.jsonnet index 8724511..1e6c5f9 100644 --- a/training_config/bert_tiny.jsonnet +++ b/training_config/bert_tiny.jsonnet @@ -1,6 +1,6 @@ { "dataset_reader": { - "type": "srl_transformers", + "type": "transformer_srl", "bert_model_name": "mrm8488/bert-tiny-finetuned-squadv2", }, @@ -16,7 +16,7 @@ "validation_data_path": std.extVar("SRL_VALIDATION_DATA_PATH"), "model": { - "type": "srl_transformers", + "type": "transformer_srl", "embedding_dropout": 0.1, "bert_model": "mrm8488/bert-tiny-finetuned-squadv2", }, diff --git a/training_config/xlmr_base.jsonnet b/training_config/xlmr_base.jsonnet index f070ca5..58f36ee 100644 --- a/training_config/xlmr_base.jsonnet +++ b/training_config/xlmr_base.jsonnet @@ -1,6 +1,6 @@ { "dataset_reader": { - "type": "srl_transformers", + "type": "transformer_srl", "bert_model_name": "xlm-roberta-base", }, @@ -15,7 +15,7 @@ "validation_data_path": std.extVar("SRL_VALIDATION_DATA_PATH"), "model": { - "type": "srl_transformers", + "type": "transformer_srl", "embedding_dropout": 0.1, "bert_model": "xlm-roberta-base", }, diff --git a/transformer_srl/dataset_readers.py b/transformer_srl/dataset_readers.py index 1fb3677..7b9a8dd 100644 --- a/transformer_srl/dataset_readers.py +++ b/transformer_srl/dataset_readers.py @@ -126,7 +126,7 @@ def _convert_frames_indices_to_wordpiece_indices( return ["O"] + new_frame_labels + ["O"] -@DatasetReader.register("srl_transformers") +@DatasetReader.register("transformer_srl") class SrlTransformersReader(SrlReader): """ This DatasetReader is designed to read in the English OntoNotes v5.0 data diff --git a/transformer_srl/models.py b/transformer_srl/models.py index 944519d..686a4a9 100644 --- a/transformer_srl/models.py +++ b/transformer_srl/models.py @@ -27,7 +27,7 @@ FRAME_ROLE_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "frame2role.csv" -@Model.register("srl_transformers") +@Model.register("transformer_srl") class SrlTransformers(SrlBert): """ @@ -309,4 +309,4 @@ def _get_label_tokens(self, namespace: str = "labels"): def _get_label_ids(self, namespace: str = "labels"): return self.vocab.get_index_to_token_vocabulary(namespace).keys() - default_predictor = "srl_transformers" + default_predictor = "transformer_srl" diff --git a/transformer_srl/predictors.py b/transformer_srl/predictors.py index a35c44a..aae2323 100644 --- a/transformer_srl/predictors.py +++ b/transformer_srl/predictors.py @@ -13,7 +13,7 @@ from spacy.tokens import Doc -@Predictor.register("srl_transformers") +@Predictor.register("transformer_srl") class SrlTransformersPredictor(SemanticRoleLabelerPredictor): def __init__( self, model: Model, dataset_reader: DatasetReader, language: str = "en_core_web_sm",