Add stream_predict_coref
Aethor committed Nov 13, 2023
1 parent e85b8d5 commit 5212b6c
Showing 2 changed files with 44 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tibert"
-version = "0.2.3"
+version = "0.2.4"
 description = "BERT for Coreference Resolution"
 authors = ["Arthur Amalvy <arthur.amalvy@univ-avignon.fr>"]
 license = "GPL-3.0-only"
56 changes: 43 additions & 13 deletions tibert/predict.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-from typing import TYPE_CHECKING, Literal, List, Union, cast
+from typing import TYPE_CHECKING, Generator, Literal, List, Union, cast
 from transformers import PreTrainedTokenizerFast
 import torch
 from tqdm import tqdm
@@ -14,15 +14,16 @@
 )


-def predict_coref(
+def stream_predict_coref(
     documents: List[Union[str, List[str]]],
     model: BertForCoreferenceResolution,
     tokenizer: PreTrainedTokenizerFast,
     batch_size: int = 1,
     quiet: bool = False,
     device_str: Literal["cpu", "cuda", "auto"] = "auto",
     lang: str = "en",
-) -> List[CoreferenceDocument]:
+) -> Generator[CoreferenceDocument, None, None]:
+
     """Predict coreference chains for a list of documents.
     :param documents: A list of documents, tokenized or not. If
         documents are not tokenized, MosesTokenizer will tokenize them
@@ -47,7 +48,7 @@ def predict_coref(
     device = torch.device(device_str)

     if len(documents) == 0:
-        return []
+        return

     # Tokenize input sentences if needed
     if isinstance(documents[0], str):
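
A note on the `return []` -> `return` change above: inside a generator
function, a bare return ends iteration, so the empty-input case now yields
nothing rather than returning an empty list. A minimal standalone sketch of
the idiom (illustrative, not tibert code):

from typing import Generator

def stream_numbers(limit: int) -> Generator[int, None, None]:
    # A bare return inside a generator ends iteration immediately,
    # producing an empty generator rather than an empty list.
    if limit <= 0:
        return
    for n in range(limit):
        yield n

assert list(stream_numbers(0)) == []  # mirrors the empty-documents path
assert list(stream_numbers(3)) == [0, 1, 2]
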
@@ -72,10 +73,10 @@
     model = model.eval()  # type: ignore
     model = model.to(device)

-    preds = []
-
     with torch.no_grad():
+
         for i, batch in enumerate(tqdm(dataloader, disable=quiet)):
+
             local_batch_size = batch["input_ids"].shape[0]

             start_idx = batch_size * i
@@ -84,21 +85,50 @@
             batch = batch.to(device)
             out: BertCoreferenceResolutionOutput = model(**batch)
+
             out_docs = out.coreference_documents(
                 [
                     [tokenizer.decode(t) for t in input_ids]  # type: ignore
                     for input_ids in batch["input_ids"]
                 ]
             )
-            out_docs = [
-                out_doc.from_wpieced_to_tokenized(original_doc.tokens, batch, batch_i)
-                for batch_i, (original_doc, out_doc) in enumerate(
-                    zip(batch_docs, out_docs)
-                )
-            ]
-            preds += out_docs
-
-    return preds
+
+            for batch_i, (original_doc, out_doc) in enumerate(
+                zip(batch_docs, out_docs)
+            ):
+                doc = out_doc.from_wpieced_to_tokenized(
+                    original_doc.tokens, batch, batch_i
+                )
+                yield doc
+
+
+def predict_coref(
+    documents: List[Union[str, List[str]]],
+    model: BertForCoreferenceResolution,
+    tokenizer: PreTrainedTokenizerFast,
+    batch_size: int = 1,
+    quiet: bool = False,
+    device_str: Literal["cpu", "cuda", "auto"] = "auto",
+    lang: str = "en",
+) -> List[CoreferenceDocument]:
+    """Predict coreference chains for a list of documents.
+    :param documents: A list of documents, tokenized or not. If
+        documents are not tokenized, MosesTokenizer will tokenize them
+        automatically.
+    :param tokenizer: tokenizer used to split tokens into wordpieces.
+    :param batch_size: prediction batch size.
+    :param quiet: if ``True``, disables the ``tqdm`` progress bar.
+    :param lang: lang for ``MosesTokenizer``.
+    :return: a list of ``CoreferenceDocument``, with annotated
+        coreference chains.
+    """
+    return list(
+        stream_predict_coref(
+            documents, model, tokenizer, batch_size, quiet, device_str, lang
+        )
+    )


 def predict_coref_simple(
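
For context, here is a usage sketch of the new streaming entry point. It is
illustrative only: the checkpoint name is hypothetical, and the package-root
import of BertForCoreferenceResolution is an assumption, not part of this
commit.

from transformers import BertTokenizerFast
from tibert import BertForCoreferenceResolution
from tibert.predict import stream_predict_coref

# Hypothetical checkpoint name; substitute a real coreference checkpoint.
model = BertForCoreferenceResolution.from_pretrained("my-coref-checkpoint")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

documents = [
    "Jane met her brother. She was glad to see him.",
    "The ship left the harbor. It never returned.",
]

# Unlike predict_coref, which materializes the whole list before returning,
# the generator yields each CoreferenceDocument as soon as its batch has
# been processed, so results can be consumed or saved incrementally.
for doc in stream_predict_coref(documents, model, tokenizer, batch_size=2):
    print(doc)

Keeping predict_coref as a thin list(...) wrapper around the generator also
means both entry points share a single prediction loop, so their behavior
cannot drift apart.
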
