self-supervised.py
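"""Generate silver keyphrases for self-training.

Reads a JSONL file of documents with "id", "title", "abstract" and
"keyphrases" fields, runs a BART-based text2text generation pipeline over
each "<title><s><abstract>" input, and writes the documents back out with
the gold keyphrases moved to "references" and the model predictions stored
under "keyphrases".
"""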
import sys
import math
import json
import argparse
from tqdm import tqdm
from datasets import load_dataset
from transformers import BartForConditionalGeneration, AutoTokenizer, Text2TextGenerationPipeline
from transformers.pipelines.pt_utils import KeyDataset
# Command-line arguments.
parser = argparse.ArgumentParser()
parser.add_argument("-m", "--model", default="models/bart-base")
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--cut_off", default=2000, type=int)
parser.add_argument("-i", "--input")
parser.add_argument("-o", "--output")
args = parser.parse_args()
print("Using model: {}".format(args.model))
print("Input file: {}".format(args.input))
print("Output file: {}".format(args.output))
# Read the input JSONL file (one JSON document per line).
input_dataset = []
with open(args.input, 'r') as f:
    for line in f:
        input_dataset.append(json.loads(line))


def chunker(seq, size):
    """Yield successive slices of `seq` of length `size`."""
    return (seq[pos:pos + size] for pos in range(0, len(seq), size))


def join_titles_and_abstracts(sample, special_token="<s>"):
    """Concatenate title and abstract, separated by `special_token`."""
    sample["src"] = "{}{}{}".format(sample["title"], special_token, sample["abstract"])
    return sample
# Load the same input file as a `datasets` Dataset and build the model input
# by joining each title and abstract.
dataset = load_dataset("json", data_files=args.input, split="train")
dataset = dataset.map(join_titles_and_abstracts)

# Keep only the first `cut_off` samples for self-training.
cut_off = min(args.cut_off, len(dataset))
input_dataset = input_dataset[:cut_off]
dataset = dataset.select(range(cut_off))
print("Cut off for self-training samples: {}".format(len(dataset)))

# Load the BART model and wrap it in a text2text generation pipeline.
tokenizer = AutoTokenizer.from_pretrained(args.model, model_max_length=512)
model = BartForConditionalGeneration.from_pretrained(args.model)
pipe = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)  # optionally: device="mps"
# Generate keyphrases for every document. Predictions come back in dataset
# order, so the i-th prediction is matched to the i-th document id.
doc_ids = dataset["id"]
pred_phrases = {}
for pred in tqdm(pipe(KeyDataset(dataset, "src"),
                      batch_size=args.batch_size,
                      truncation="only_first",
                      max_length=40,
                      num_beams=1,
                      num_return_sequences=1),
                 total=len(dataset)):
    pred_phrases[doc_ids[len(pred_phrases)]] = pred[0]["generated_text"].split(";")

# Keep the gold keyphrases as "references" and replace "keyphrases" with the
# model predictions (the silver labels used for self-training).
for i, sample in enumerate(input_dataset):
    input_dataset[i]["references"] = sample["keyphrases"]
    input_dataset[i]["keyphrases"] = pred_phrases[sample["id"]]
# Write the re-labelled documents back out as JSONL.
with open(args.output, "w") as f:
    f.write("\n".join(json.dumps(doc) for doc in input_dataset))
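
# Example invocation (a sketch; the file paths below are hypothetical, only
# the flags are defined by the argument parser above):
#
#   python self-supervised.py --model models/bart-base \
#       --input data/unlabeled.jsonl --output data/self_labeled.jsonl \
#       --batch_size 8 --cut_off 2000
#
# Each line of the output file is a JSON object whose "keyphrases" field holds
# the model predictions (split on ";") and whose "references" field holds the
# original gold keyphrases.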