-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathgenerate.py
80 lines (65 loc) · 2.45 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# -*- coding: utf-8 -*-
import argparse
import json
import os
import options
from module import augmentor
from module import generator
from module import sampler
from utils.tokenizer import PreDefinedVocab
from utils.tokenizer import WordpieceTokenizer
def _build_tokenizer(vocab_file):
    """Return a WordpieceTokenizer over the vocabulary stored in *vocab_file*."""
    vocab = PreDefinedVocab(
        vocab_file=vocab_file,
        unk_token='[UNK]',
        sep_token='[SEP]',
        pad_token='[PAD]',
        mask_token='[MASK]',
        cls_token='[CLS]',
    )
    return WordpieceTokenizer(vocab)


def _build_bert_generator(args, tokenizer):
    """Load BertForMaskedLM from ``args.model_name_or_path`` and wrap it."""
    # Imported lazily so pytorch_transformers is only required for 'bert'.
    from pytorch_transformers import BertForMaskedLM
    bert = BertForMaskedLM.from_pretrained(args.model_name_or_path)
    # NOTE(review): 'temparature' (sic) is kept as spelled in the original;
    # presumably it matches the option name registered in options.py.
    return generator.BertGenerator(tokenizer, bert, args.temparature)


# --sampling_strategy value -> zero-argument factory building the sampler.
# Lambdas keep the project modules from being touched at import time.
_SAMPLER_BUILDERS = {
    'random': lambda: sampler.UniformSampler(),
}

# --augmentation_strategy value -> (factory(args, tokenizer), to_word).
# to_word is forwarded unchanged to augmentor.ReplacingAugmentor.
_GENERATOR_BUILDERS = {
    'dropout': (lambda args, tokenizer: generator.DropoutGenerator(), False),
    'blank': (lambda args, tokenizer: generator.BlankGenerator(
        mask_token=tokenizer.vocab.mask_token), False),
    'unigram': (lambda args, tokenizer: generator.UnigramGenerator(
        args.unigram_frequency_for_generation), True),
    'bigramkn': (lambda args, tokenizer: generator.BigramKNGenerator(
        args.bigram_frequency_for_generation), True),
    'wordnet': (lambda args, tokenizer: generator.WordNetGenerator(
        lang=args.lang_for_wordnet), True),
    'word2vec': (lambda args, tokenizer: generator.Word2vecGenerator(
        args.w2v_file), True),
    'ppdb': (lambda args, tokenizer: generator.PPDBGenerator(
        args.ppdb_file), True),
    'bert': (_build_bert_generator, False),
}


def _lookup(table, option, value):
    """Return ``table[value]``, or raise ValueError listing the valid choices."""
    try:
        return table[value]
    except KeyError:
        raise ValueError('unknown {} {!r}; expected one of: {}'.format(
            option, value, ', '.join(sorted(table)))) from None


def main(args):
    """Augment every line of ``args.input`` and print the results to stdout.

    Builds a WordPiece tokenizer from ``args.vocab_file``, a sampler chosen by
    ``args.sampling_strategy`` and a replacement generator chosen by
    ``args.augmentation_strategy``. It then feeds each input line (trailing
    whitespace stripped) through ``augmentor.ReplacingAugmentor`` with
    ``args.augmentation_rate``. Strategy-specific attributes are read only
    for the selected strategy: ``unigram_frequency_for_generation``,
    ``bigram_frequency_for_generation``, ``lang_for_wordnet``, ``w2v_file``,
    ``ppdb_file``, ``model_name_or_path`` and ``temparature``.

    Args:
        args: parsed command-line namespace.

    Raises:
        ValueError: if ``sampling_strategy`` or ``augmentation_strategy`` is
            not a supported value.
    """
    # Resolve both strategies first so a typo fails fast, before the
    # vocabulary (and possibly a BERT model) is loaded.
    make_sampler = _lookup(
        _SAMPLER_BUILDERS, 'sampling_strategy', args.sampling_strategy)
    make_generator, to_word = _lookup(
        _GENERATOR_BUILDERS, 'augmentation_strategy',
        args.augmentation_strategy)

    tokenizer = _build_tokenizer(args.vocab_file)
    sampling_fn = make_sampler()
    generator_fn = make_generator(args, tokenizer)
    augmentor_fn = augmentor.ReplacingAugmentor(
        tokenizer, sampling_fn, generator_fn, to_word=to_word)

    # Explicit UTF-8: the default text encoding is locale-dependent.
    with open(args.input, 'r', encoding='utf-8') as f:
        for line in f:
            # rstrip() removes the newline and any other trailing whitespace.
            augmented_line = augmentor_fn(line.rstrip(), args.augmentation_rate)
            print(augmented_line)
if __name__ == '__main__':
    # Assemble the CLI from the shared option groups, then run generation.
    parser = argparse.ArgumentParser('')
    for register_opts in (options.generate_opts, options.sub_opts):
        register_opts(parser)
    args = parser.parse_args()
    main(args)