Skip to content

Commit

Permalink
dataset reader for propbank inventory
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed Aug 9, 2020
1 parent d6efd40 commit 009c3e3
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@

Semantic Role Labeling based on [AllenNLP implementation](https://demo.allennlp.org/semantic-role-labeling) of [Shi et al, 2019](https://arxiv.org/abs/1904.05255). It uses [VerbAatlas](http://verbatlas.org/) inventory and it's trained also on predicate disambiguation, in addition to arguments identification and disambiguation.

### To-Dos

- [x] Works with both PropBank and VerbAtlas (infer inventory from dataset reader)
- [ ] Compatibility with all models from Huggingface's Transformers.
- Now works only with models that accept 1 as token type id

### Infos

- Language Model: BERT
- Dataset: CoNLL 2012


### Results
### Results with VerbAtlas

With `bert-base-cased`:
```
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="transformer_srl", # Replace with your own username
version="2.2rc12",
version="2.2rc13",
author="Riccardo Orlando",
author_email="orlandoricc@gmail.com",
description="SRL Transformer model",
Expand Down
18 changes: 14 additions & 4 deletions transformer_srl/dataset_readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,7 @@ def _read(self, file_path: str):
if sentence.srl_frames:
for (_, tags) in sentence.srl_frames:
verb_indicator = [1 if label[-2:] == "-V" else 0 for label in tags]
frames = [
f if v == 1 else "O"
for f, v in zip(sentence.predicate_framenet_ids, verb_indicator)
]
frames = self._get_predicate_labels(sentence, verb_indicator)
lemmas = [
f for f, v in zip(sentence.predicate_lemmas, verb_indicator) if v == 1
]
Expand Down Expand Up @@ -359,6 +356,19 @@ def _convert_tags_to_wordpiece_tags(self, tags: List[str], offsets: List[int]) -
# Add O tags for cls and sep tokens.
return ["O"] + new_tags + ["O"]

def _get_predicate_labels(self, sentence, verb_indicator):
frames = [f if v == 1 else "O" for f, v in zip(frame_labels, verb_indicator)]
labels = []
for i, v in enumerate(verb_indicator):
if v == 1:
label = (
"{}.{}".format(sentence.predicate_lemmas[i], sentence.predicate_framenet_ids[i])
if sentence.predicate_framenet_ids[i].isdigit()
else sentence.predicate_framenet_ids[i]
)
labels.append(label)
return labels


@DatasetReader.register("transformer_srl_dependency")
class SrlUdpDatasetReader(SrlTransformersSpanReader):
Expand Down

0 comments on commit 009c3e3

Please sign in to comment.