-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtag.py
59 lines (51 loc) · 2.32 KB
/
tag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from flair.data import Span
from flair.datasets import ColumnCorpus, DataLoader
from flair.embeddings import *
from flair.models import SequenceTagger
# Preferred compute device: the first CUDA GPU when PyTorch reports one,
# otherwise the CPU, so the value stays usable on hosts without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Label type used when reading spans from sentences (see get_spans calls).
tag_type = 'ner'
def process_file(tagger: SequenceTagger, file_path: Union[str, Path], out_path: Union[str, Path], print_corpus=None):
    """Tag one column-formatted file and write the predicted spans to ``out_path``.

    The file is loaded as the train split of a ``ColumnCorpus`` whose columns
    are ``text``, ``begin``, ``end`` and ``ner``.

    Args:
        tagger: Model used for evaluation and prediction.
        file_path: Path of the column file to read.
        out_path: Path the predicted spans are written to.
        print_corpus: Optional path. When given, the tagger is evaluated on
            the corpus and the detailed results are printed; if no file exists
            at this path yet, the spans already present in the corpus are
            written there as well.

    Returns:
        The number of predicted spans, or 0 when the corpus contains no
        sentences or an ``IndexError`` is raised while processing.
    """
    try:
        data_folder, file_name = os.path.split(file_path)
        corpus: ColumnCorpus = ColumnCorpus(
            data_folder=data_folder,
            train_file=file_name,
            column_format={0: 'text', 1: 'begin', 2: 'end', 3: 'ner'}
        )
        sentences = corpus.get_all_sentences()
        if len(sentences) == 0:
            return 0

        if print_corpus is not None:
            result, _loss = tagger.evaluate(DataLoader(sentences))
            print(result.detailed_results)
            # Only write the corpus spans when the target file is absent, so
            # an existing file is never overwritten.
            if not os.path.isfile(print_corpus):
                results: List[Span] = []
                for sentence in corpus.train:
                    for span in sentence.get_spans(tag_type):
                        # Compare by value: identity (`is not`) on strings
                        # does not reliably filter out the "O" tag.
                        if span.tag != "O":
                            results.append(span)
                print_spans_in_brat_format(results, print_corpus)

        return tag_corpus(corpus, file_path, out_path, tagger)
    except IndexError:
        log.error(f'IndexError in file: "{file_path}"!')
        return 0
def tag_corpus(corpus, filename, outpath, tagger):
    """Predict tags for all sentences of ``corpus`` and write the spans to ``outpath``.

    Args:
        corpus: Corpus whose sentences (``get_all_sentences()``) are tagged.
        filename: Name of the source file; only used in the debug log message.
        outpath: Path the collected spans are written to.
        tagger: Model whose ``predict`` method is applied to the sentences.

    Returns:
        The number of spans whose tag is not ``"O"``.
    """
    tagged: List[Sentence] = tagger.predict(sentences=corpus.get_all_sentences())

    results: List[Span] = []
    for sentence in tagged:
        for span in sentence.get_spans(tag_type):
            # Compare by value: identity (`is not`) on strings does not
            # reliably filter out the "O" tag.
            if span.tag != "O":
                results.append(span)

    log.debug(f'Found {len(results)} tags in {filename}')
    print_spans_in_brat_format(results, outpath)
    return len(results)
def print_spans_in_brat_format(results: List[Span], outpath: Union[str, Path], ralign=False):
    """Write spans to ``outpath``, one tab-separated ``T<n>`` line per span.

    Every line consists of three tab-separated fields: the id ``T<n>`` (``n``
    counts from 1), then ``<tag> <start_pos> <end_pos>``, then the span text.
    The file is opened in write mode with UTF-8 encoding, so existing content
    is replaced.

    Args:
        results: Spans to write; ``tag``, ``start_pos``, ``end_pos`` and
            ``text`` are read from each one.
        outpath: Path of the output file.
        ralign: When true, each number is padded with trailing spaces to the
            digit count of ``len(results)`` so the following columns line up.
    """
    with open(outpath, 'w', encoding='utf-8') as outfile:
        # Exact digit count of the largest id; also yields 1 for empty input,
        # without relying on floating-point log10 rounding.
        id_width = len(str(len(results)))
        for i, span in enumerate(results, start=1):
            span_id = f'{i: <{id_width}d}' if ralign else f'{i}'
            print(f'T{span_id}\t{span.tag} {span.start_pos} {span.end_pos}\t{span.text}', file=outfile)