-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprune_contexts.py
executable file
·57 lines (45 loc) · 1.74 KB
/
prune_contexts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#!/usr/bin/env python
#
# Author: (c) 2016 Vincent Kriz <kriz@ufal.mff.cuni.cz>
#
import sys
import logging
import argparse
from collections import Counter
# Log a progress message every this many input lines.
PROGRESS_EVERY = 1000000


def parse_args(argv=None):
    """Parse the command line.

    :param argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    :return: namespace with ``vocabulary`` (path), ``threshold`` (float) and
        ``output`` (path).
    """
    parser = argparse.ArgumentParser()
    parser.description = 'Prune training data by the given vocabulary.'
    parser.add_argument('--vocabulary', required=True, help='a vocabulary destination')
    # type=float: without it the threshold stays a string and every frequency
    # comparison would be lexicographic (e.g. '9' < '10' is False).
    parser.add_argument('--threshold', required=True, type=float, help='a word frequency threshold')
    parser.add_argument('--output', required=True, help='a path where final training data will be stored')
    return parser.parse_args(argv)


def load_vocabulary(lines):
    """Build a word -> frequency mapping from ``word<TAB>frequency`` lines.

    Blank lines are skipped. Frequencies are parsed with ``float`` so that
    they compare numerically against the threshold.

    :param lines: iterable of text lines (e.g. an open file).
    :return: dict mapping each word to its frequency; a word listed more than
        once keeps its last frequency.
    :raises ValueError: if a non-blank line does not have exactly two
        tab-separated fields or its frequency is not a number.
    """
    vocabulary = {}
    for line in lines:
        line = line.rstrip()
        if not line:
            continue
        word, frequency = line.split('\t')
        vocabulary[word] = float(frequency)
    return vocabulary


def _parse_pair(line):
    """Return ``(target_word, context_word)`` for a line, or ``None`` if malformed.

    The expected shape is ``<target> <context>_<suffix>``: exactly one space
    and exactly one underscore in the context. The suffix is discarded
    (the original script named it ``context_deprel``).
    """
    try:
        target_word, context = line.rstrip().split(' ')
        context_word, _suffix = context.split('_')
    except ValueError:
        return None
    return target_word, context_word


def _is_frequent(word, vocabulary, threshold):
    """Return True if *word* is in *vocabulary* with frequency >= *threshold*."""
    frequency = vocabulary.get(word)
    return frequency is not None and frequency >= threshold


def prune_contexts(lines, vocabulary, threshold, progress_every=PROGRESS_EVERY):
    """Yield the input lines whose target and context words are both frequent.

    A line is kept only when both its target word and its context word are
    in ``vocabulary`` with a frequency of at least ``threshold``. Lines that
    do not match the ``<target> <context>_<suffix>`` shape are skipped; their
    count is logged as a warning once the input is exhausted.

    :param lines: iterable of text lines (e.g. ``sys.stdin``).
    :param vocabulary: mapping word -> numeric frequency.
    :param threshold: minimum frequency (inclusive) required for both words.
    :param progress_every: positive interval, in lines, between progress logs.
    :return: generator of the kept lines, unchanged (trailing newline kept).
    """
    n_malformed = 0
    for n_line, line in enumerate(lines):
        if n_line % progress_every == 0:
            logging.info('Processed %d lines.', n_line)
        pair = _parse_pair(line)
        if pair is None:
            n_malformed += 1
            continue
        target_word, context_word = pair
        if (_is_frequent(target_word, vocabulary, threshold)
                and _is_frequent(context_word, vocabulary, threshold)):
            yield line
    if n_malformed:
        logging.warning('Skipped %d malformed lines.', n_malformed)


def main(argv=None):
    """Filter pair lines read from stdin by vocabulary frequency into --output."""
    logging.basicConfig(format='%(asctime)-15s [%(levelname)7s] %(funcName)s - %(message)s', level=logging.DEBUG)
    args = parse_args(argv)

    logging.info('Loading vocabulary...')
    with open(args.vocabulary, 'r') as fvocab:
        vocabulary = load_vocabulary(fvocab)
    logging.info('Vocabulary size: %d', len(vocabulary))

    with open(args.output, 'w') as fout:
        fout.writelines(prune_contexts(sys.stdin, vocabulary, args.threshold))


if __name__ == '__main__':
    main()