Skip to content

Commit

Permalink
Update dependencies, update code, removed unused methods.
Browse files Browse the repository at this point in the history
  • Loading branch information
Riccorl committed Jan 30, 2021
1 parent b262c32 commit 0d2ec62
Show file tree
Hide file tree
Showing 10 changed files with 35 additions and 69,922 deletions.
23 changes: 7 additions & 16 deletions scripts/bert_base_span.sh
Original file line number Diff line number Diff line change
@@ -1,26 +1,17 @@
#!/bin/bash
source /home/orlando/miniconda3/bin/activate allennlp
source /Users/ric/mambaforge/bin/activate srl-mt

HOME="/home/orlando"
DATASET="$HOME/datasets/ontonotes/conll-formatted-ontonotes-verbatlas"
PROJECT="$HOME/transformer-srl"
#HOME="/home/orlando"
DATASET="/Users/ric/Documents/ComputerScience/Projects/transformer-srl/data/conll2012_pb"
PROJECT="/Users/ric/Documents/ComputerScience/Projects/transformer-srl"
# local
# DATASET="/mnt/d/Datasets/conll2012/conll-formatted-ontonotes-verbatlas-subset"
# PROJECT="/mnt/c/Users/rikkw/Desktop/Ric/Projects/srl-bert-span"

export SRL_TRAIN_DATA_PATH="$DATASET/data/train"
export SRL_VALIDATION_DATA_PATH="$DATASET/data/development"

CONFIG="$PROJECT/training_config/bert_base.jsonnet"
CONFIG="$PROJECT/training_config/bert_base_span.jsonnet"
MODEL_DIR="$PROJECT/models/bert_base_conll2012"

free_mem=$(nvidia-smi --query-gpu=memory.free --format=csv -i 1 | grep -Eo [0-9]+)

echo "$free_mem MB"
while [ "$free_mem" -lt 10000 ]; do
free_mem=$(nvidia-smi --query-gpu=memory.free --format=csv -i 1 | grep -Eo [0-9]+)
sleep 5
done

echo "GPU finally free, training..."

allennlp train $CONFIG -s models/bert_base_va --include-package transformer_srl #--recover
allennlp train $CONFIG -s $MODEL_DIR --include-package transformer_srl --force #--recover
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="transformer_srl", # Replace with your own username
version="2.4.11",
version="2.5",
author="Riccardo Orlando",
author_email="orlandoricc@gmail.com",
description="SRL Transformer model",
Expand All @@ -20,8 +20,8 @@
"Operating System :: OS Independent",
],
install_requires=[
"allennlp>=1.2,<1.3",
"allennlp_models>=1.2,<1.3",
"allennlp>=2.0,<2.1",
"allennlp_models>=2.0,<2.1",
"spacy>=2.3,<2.4"
],
python_requires=">=3.6",
Expand Down
2 changes: 1 addition & 1 deletion training_config/bert_base_span.jsonnet
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,6 @@
"grad_norm": 1.0,
"num_epochs": 15,
"validation_metric": "+f1_role",
"cuda_device": 0,
"cuda_device": -1,
},
}
8 changes: 4 additions & 4 deletions transformer_srl/dataset_readers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import pathlib
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple

Expand All @@ -19,8 +18,6 @@

logger = logging.getLogger(__name__)

FRAME_LIST_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "framelist.txt"


def _convert_verb_indices_to_wordpiece_indices(
verb_indices: List[int], offsets: List[int], binary: bool = True
Expand Down Expand Up @@ -468,7 +465,10 @@ def _read(self, file_path: str):
# transpose rolses, to have a list of roles per frame
roles = list(map(list, zip(*roles)))
current_frame = 0
for i, frame, in enumerate(frames):
for (
i,
frame,
) in enumerate(frames):
if frame != "_":
verb_indicator = [0] * len(frames)
verb_indicator[i] = 1
Expand Down
136 changes: 12 additions & 124 deletions transformer_srl/models.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import pathlib
from typing import Any, Dict, List, Union

import numpy as np
import torch
import torch.nn.functional as F
from allennlp.data import TextFieldTensors, Vocabulary
from allennlp.models.model import Model
from allennlp.modules import Seq2SeqEncoder
from allennlp.nn import InitializerApplicator, util
from allennlp.nn.util import get_device_of, get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics.fbeta_measure import FBetaMeasure
from allennlp_models.structured_prediction import SrlBert
from allennlp_models.structured_prediction.metrics.srl_eval_scorer import (
Expand All @@ -19,10 +17,8 @@
from torch import nn
from transformers import AutoModel

from transformer_srl.utils import load_label_list, load_lemma_frame, load_role_frame
from transformer_srl.utils import load_label_list

LEMMA_FRAME_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "lemma2va_ml.tsv"
FRAME_ROLE_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "frame2role_ml.tsv"
FRAME_LIST_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "framelist.txt"
ROLE_LIST_PATH = pathlib.Path(__file__).resolve().parent / "resources" / "rolelist.txt"

Expand Down Expand Up @@ -57,17 +53,11 @@ def __init__(
label_smoothing: float = None,
ignore_span_metric: bool = False,
srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
restrict_frames: bool = False,
restrict_roles: bool = False,
inventory: str = "verbatlas",
**kwargs,
) -> None:
# bypass SrlBert constructor
Model.__init__(self, vocab, **kwargs)
self.lemma_frame_dict = load_lemma_frame(LEMMA_FRAME_PATH)
self.frame_role_dict = load_role_frame(FRAME_ROLE_PATH)
self.restrict_frames = restrict_frames
self.restrict_roles = restrict_roles
self.transformer = AutoModel.from_pretrained(bert_model)
self.frame_criterion = nn.CrossEntropyLoss()
if inventory == "verbatlas":
Expand Down Expand Up @@ -146,7 +136,10 @@ def forward( # type: ignore
mask = get_text_field_mask(tokens)
input_ids = util.get_token_ids_from_text_field_tensors(tokens)
bert_embeddings, _ = self.transformer(
input_ids=input_ids, token_type_ids=verb_indicator, attention_mask=mask,
input_ids=input_ids,
token_type_ids=verb_indicator,
attention_mask=mask,
return_dict=False,
)
# extract embeddings
embedded_text_input = self.embedding_dropout(bert_embeddings)
Expand Down Expand Up @@ -224,32 +217,7 @@ def forward( # type: ignore
def decode_frames(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# frame prediction
frame_probabilities = output_dict["frame_probabilities"]
if self.restrict_frames:
frame_probabilities = frame_probabilities.cpu().data.numpy()
lemmas = output_dict["lemma"]
candidate_labels = [self.lemma_frame_dict.get(l, []) for l in lemmas]
# clear candidates from unknowns
label_set = set(k for k in self._get_label_tokens("frames_labels"))
candidate_labels_ids = [
[
self.vocab.get_token_index(l, namespace="frames_labels")
for l in cl
if l in label_set
]
for cl in candidate_labels
]

frame_predictions = []
for cl, fp in zip(candidate_labels_ids, frame_probabilities):
# restrict candidates from verbatlas inventory
fp_candidates = np.take(fp, cl)
if fp_candidates.size > 0:
frame_predictions.append(cl[fp_candidates.argmax(axis=-1)])
else:
frame_predictions.append(fp.argmax(axis=-1))
else:
frame_predictions = frame_probabilities.argmax(dim=-1).cpu().data.numpy()

frame_predictions = frame_probabilities.argmax(dim=-1).cpu().data.numpy()
output_dict["frame_tags"] = [
self.vocab.get_token_from_index(f, namespace="frames_labels") for f in frame_predictions
]
Expand All @@ -263,33 +231,9 @@ def make_output_human_readable(
self, output_dict: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
output_dict = self.decode_frames(output_dict)
if self.restrict_roles:
output_dict = self._mask_args(output_dict)
output_dict = super().make_output_human_readable(output_dict)
return output_dict

def _mask_args(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
class_probs = output_dict["class_probabilities"]
device = get_device_of(class_probs)
# torch doesn't like -1 as cpu device
device = torch.device("cuda" if device >= 0 else "cpu")
lemmas = output_dict["lemma"]
frames = output_dict["frame_tags"]
candidate_mask = torch.ones_like(class_probs, dtype=torch.bool).to(device)
for i, (l, f) in enumerate(zip(lemmas, frames)):
candidates = self.frame_role_dict.get((l, f), [])
if candidates:
canidate_ids = [
self.vocab.get_token_index(r, namespace="labels") for r in candidates
]
canidate_ids = torch.tensor(canidate_ids).to(device)
canidate_ids = canidate_ids.repeat(candidate_mask.shape[1], 1)
candidate_mask[i].scatter_(1, canidate_ids, False)
else:
candidate_mask[i].fill_(False)
class_probs.masked_fill_(candidate_mask, 0)
return output_dict

@overrides
def get_metrics(self, reset: bool = False):
if self.ignore_span_metric:
Expand All @@ -303,13 +247,9 @@ def get_metrics(self, reset: bool = False):
# This can be a lot of metrics, as there are 3 per class.
# we only really care about the overall metrics, so we filter for them here.
metric_dict_filtered = {
x.split("-")[0] + "_role": y
for x, y in metric_dict.items()
if "overall" in x #and "f1" in x
}
frame_metric_dict = {
x + "_frame": y for x, y in frame_metric_dict.items() #if "fscore" in x
x.split("-")[0] + "_role": y for x, y in metric_dict.items() if "overall" in x
}
frame_metric_dict = {x + "_frame": y for x, y in frame_metric_dict.items()}
return {**metric_dict_filtered, **frame_metric_dict}

def _get_label_tokens(self, namespace: str = "labels"):
Expand Down Expand Up @@ -351,17 +291,10 @@ def __init__(
label_smoothing: float = None,
ignore_span_metric: bool = False,
srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
restrict_frames: bool = False,
restrict_roles: bool = False,
**kwargs,
) -> None:
# bypass SrlBert constructor
Model.__init__(self, vocab, **kwargs)
self.lemma_frame_dict = load_lemma_frame(LEMMA_FRAME_PATH)
self.frame_role_dict = load_role_frame(FRAME_ROLE_PATH)
self.restrict_frames = restrict_frames
self.restrict_roles = restrict_roles

if isinstance(model_name, str):
self.transformer = AutoModel.from_pretrained(model_name)
else:
Expand Down Expand Up @@ -435,6 +368,7 @@ def forward( # type: ignore
input_ids=util.get_token_ids_from_text_field_tensors(tokens),
token_type_ids=verb_indicator,
attention_mask=mask,
return_dict=False,
)

# extract embeddings
Expand Down Expand Up @@ -489,32 +423,7 @@ def forward( # type: ignore
def decode_frames(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
# frame prediction
frame_probabilities = output_dict["frame_probabilities"]
if self.restrict:
frame_probabilities = frame_probabilities.cpu().data.numpy()
lemmas = output_dict["lemma"]
candidate_labels = [self.lemma_frame_dict.get(l, []) for l in lemmas]
# clear candidates from unknowns
label_set = set(k for k in self._get_label_tokens("frames_labels"))
candidate_labels_ids = [
[
self.vocab.get_token_index(l, namespace="frames_labels")
for l in cl
if l in label_set
]
for cl in candidate_labels
]

frame_predictions = []
for cl, fp in zip(candidate_labels_ids, frame_probabilities):
# restrict candidates from verbatlas inventory
fp_candidates = np.take(fp, cl)
if fp_candidates.size > 0:
frame_predictions.append(cl[fp_candidates.argmax(axis=-1)])
else:
frame_predictions.append(fp.argmax(axis=-1))
else:
frame_predictions = frame_probabilities.argmax(dim=-1).cpu().data.numpy()

frame_predictions = frame_probabilities.argmax(dim=-1).cpu().data.numpy()
output_dict["frame_tags"] = [
self.vocab.get_token_from_index(f, namespace="frames_labels") for f in frame_predictions
]
Expand All @@ -530,7 +439,7 @@ def make_output_human_readable(
output_dict = self.decode_frames(output_dict)
# if self.restrict:
# output_dict = self._mask_args(output_dict)
# output_dict = super().make_output_human_readable(output_dict)
output_dict = super().make_output_human_readable(output_dict)
roles_probabilities = output_dict["role_probabilities"]
roles_predictions = roles_probabilities.argmax(dim=-1).cpu().data.numpy()

Expand All @@ -540,26 +449,6 @@ def make_output_human_readable(
]
return output_dict

def _mask_args(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
class_probs = output_dict["class_probabilities"]
device = get_device_of(class_probs)
lemmas = output_dict["lemma"]
frames = output_dict["frame_tags"]
candidate_mask = torch.ones_like(class_probs, dtype=torch.bool).to(device)
for i, (l, f) in enumerate(zip(lemmas, frames)):
candidates = self.frame_role_dict.get((l, f), [])
if candidates:
canidate_ids = [
self.vocab.get_token_index(r, namespace="labels") for r in candidates
]
canidate_ids = torch.tensor(canidate_ids).to(device)
canidate_ids = canidate_ids.repeat(candidate_mask.shape[1], 1)
candidate_mask[i].scatter_(1, canidate_ids, False)
else:
candidate_mask[i].fill_(False)
class_probs.masked_fill_(candidate_mask, 0)
return output_dict

@overrides
def get_metrics(self, reset: bool = False):
role_metric_dict = self.f1_role_metric.get_metric(reset=reset)
Expand All @@ -582,4 +471,3 @@ def _get_label_ids(self, namespace: str = "labels"):
return self.vocab.get_index_to_token_vocabulary(namespace).keys()

default_predictor = "transformer_srl"

13 changes: 8 additions & 5 deletions transformer_srl/predictors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import enum
from typing import List, Dict, Type
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.token_class import Token

import numpy
from allennlp.common import plugins
Expand All @@ -17,7 +17,10 @@
@Predictor.register("transformer_srl")
class SrlTransformersPredictor(SemanticRoleLabelerPredictor):
def __init__(
self, model: Model, dataset_reader: DatasetReader, language: str = "en_core_web_sm",
self,
model: Model,
dataset_reader: DatasetReader,
language: str = "en_core_web_sm",
) -> None:
super().__init__(model, dataset_reader, language)

Expand Down Expand Up @@ -180,9 +183,9 @@ def from_archive(
model_type = config.get("model").get("type")
model_class, _ = Model.resolve_class_name(model_type)
predictor_name = model_class.default_predictor
predictor_class: Type[Predictor] = Predictor.by_name( # type: ignore
predictor_name
) if predictor_name is not None else cls
predictor_class: Type[Predictor] = (
Predictor.by_name(predictor_name) if predictor_name is not None else cls # type: ignore
)

if dataset_reader_to_load == "validation" and "validation_dataset_reader" in config:
dataset_reader_params = config["validation_dataset_reader"]
Expand Down
Loading

0 comments on commit 0d2ec62

Please sign in to comment.