-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess.py
74 lines (60 loc) · 2.16 KB
/
preprocess.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import json
import nltk
from tqdm import tqdm
import pickle
import utils
def get_line_number(file_name):
    """Return the number of lines in the text file at *file_name*.

    The file is streamed line by line, so memory use stays flat no matter
    how large the file is.

    Args:
        file_name: Path of the file to scan.

    Returns:
        The number of lines as an int; 0 for an empty file. A final line
        without a trailing newline still counts as one line.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    # Only the line count matters here, so decoding must never be what makes
    # this fail: pin the codec instead of inheriting the platform locale, and
    # substitute undecodable bytes rather than raising. Line boundaries are
    # ASCII bytes, so the replacement cannot change the count.
    with open(file_name, 'r', encoding='utf-8', errors='replace') as f:
        return sum(1 for _ in f)
def main():
    """Build the abstract/title dataset from the arXiv metadata snapshot.

    Reads ``data/arxiv-metadata-oai-snapshot.json`` as one JSON record per
    line, keeps the records whose ``categories`` field lists ``cs.AI`` and
    whose abstract and title pass the length and skip-word filters below,
    POS-tags each kept title with NLTK, prints the number of kept papers,
    and pickles three parallel lists to ``data/preprocessed.pkl`` under the
    keys ``'abstract'``, ``'title'`` and ``'title_pos'``.
    """
    # Inclusive bounds on the whitespace-separated token counts.
    abstract_len_min, abstract_len_max = 50, 256
    title_len_min, title_len_max = 4, 20
    # Matched as plain substrings (not whole words): a paper is dropped when
    # its abstract/title contains any of these strings anywhere.
    abstract_skip_words = [
        'withdrawn',
    ]
    title_skip_words = [
        'reply', 'Reply',
        'comment', 'Comment',
    ]
    target_category = 'cs.AI'
    data_fname = 'data/arxiv-metadata-oai-snapshot.json'

    # The extra pass over the file exists only to give tqdm a total.
    line_number = get_line_number(data_fname)

    abstract_all = []
    title_all = []
    title_pos_all = []
    # Explicit UTF-8 so decoding does not depend on the platform locale.
    with open(data_fname, 'r', encoding='utf-8') as f:
        for line in tqdm(f, total=line_number):
            # A blank line (e.g. a trailing one) is not a record; skipping it
            # avoids a JSONDecodeError from json.loads.
            if not line.strip():
                continue
            paper = json.loads(line)
            # 'categories' is a whitespace-separated list of category codes.
            if target_category not in paper['categories'].split():
                continue

            abstract = paper['abstract'].strip().replace('\n', ' ')
            # Wrapped titles are joined on "newline + space" exactly as
            # before; the second replace additionally catches a newline that
            # is NOT followed by a space, which would otherwise survive into
            # the stored title.
            title = (
                paper['title'].strip().replace('\n ', ' ').replace('\n', ' ')
            )
            # NOTE(review): replace_special_tokens lives in the project-local
            # utils module and is not visible here; the filters below operate
            # on whatever text it returns.
            abstract = utils.replace_special_tokens(abstract)
            title = utils.replace_special_tokens(title)

            abstract_len = len(abstract.split())
            title_len = len(title.split())
            if not abstract_len_min <= abstract_len <= abstract_len_max:
                continue
            if not title_len_min <= title_len <= title_len_max:
                continue
            if any(w in abstract for w in abstract_skip_words):
                continue
            if any(w in title for w in title_skip_words):
                continue

            # Tagging runs last because it is only needed for papers that
            # survived every filter. Tokens are lower-cased before tagging;
            # the stored title keeps its original casing.
            title_tokens = nltk.word_tokenize(title.lower())
            title_pos = [pos for _, pos in nltk.pos_tag(title_tokens)]

            abstract_all.append(abstract)
            title_all.append(title)
            title_pos_all.append(title_pos)

    count = len(title_all)
    print(f"dataset size: {count}")

    # The three lists are index-aligned: entry i of each describes paper i.
    with open('data/preprocessed.pkl', 'wb') as f:
        pickle.dump({
            'abstract': abstract_all,
            'title': title_all,
            'title_pos': title_pos_all,
        }, f)


if __name__ == '__main__':
    main()