-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcalc_code_bleu.py
110 lines (83 loc) · 4.38 KB
/
calc_code_bleu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# -*- coding:utf-8 -*-
import json
import argparse
import bleu as bleu
import weighted_ngram_match as weighted_ngram_match
import syntax_match as syntax_match
import dataflow_match as dataflow_match
from pathlib import Path
# Ancestor directory three levels up from this file (parents[0] is the folder
# that contains the file). compute_codebleu resolves the per-language keyword
# files relative to this directory.
root_directory = Path(__file__).parents[2]
def make_weights(reference_tokens, key_word_list):
    """Map every reference token to its weight for the weighted n-gram match.

    A token that appears in *key_word_list* gets weight 1; any other token
    gets weight 0.2. Repeated tokens collapse into a single dictionary entry.
    """
    weights = {}
    for tok in reference_tokens:
        if tok in key_word_list:
            weights[tok] = 1
        else:
            weights[tok] = 0.2
    return weights
def compute_codebleu(hypothesis, references, lang, params='0.25,0.25,0.25,0.25'):
    """Compute the CodeBLEU score of *hypothesis* against *references*.

    The score is a weighted sum of four components: n-gram match (BLEU),
    keyword-weighted n-gram match, syntax match and dataflow match.

    Args:
        hypothesis: list of hypothesis strings, one per example.
        references: list with one entry per example; each entry is a list of
            reference strings for that example.
        lang: language name. It selects the keyword file
            ``keywords/<lang>.txt`` and is forwarded to the syntax and
            dataflow matchers.
        params: comma-separated component weights in the order
            ``alpha,beta,gamma,theta``.

    Returns:
        A tuple ``(code_bleu_score, (ngram_match_score,
        weighted_ngram_match_score, syntax_match_score,
        dataflow_match_score))``.

    Raises:
        ValueError: if *params* does not hold exactly four numeric values.
        OSError: if the keyword file for *lang* cannot be opened.
    """
    alpha, beta, gamma, theta = [float(x) for x in params.split(',')]

    # calculate ngram match (BLEU)
    tokenized_hyps = [hyp.split() for hyp in hypothesis]
    tokenized_refs = [[ref.split() for ref in reference] for reference in references]
    ngram_match_score = bleu.corpus_bleu(tokenized_refs, tokenized_hyps)

    # calculate weighted ngram match
    kw_file = root_directory.joinpath("files_to_be_submitted/code_implementations/keywords/{}.txt".format(lang))
    # Read inside a context manager so the file handle is closed right away.
    # A frozenset makes each `token in keywords` test in make_weights O(1)
    # instead of a scan over the whole keyword list.
    with open(kw_file, 'r', encoding='utf-8') as kw_handle:
        keywords = frozenset(line.strip() for line in kw_handle)
    # Each reference becomes a [tokens, {token: weight}] pair.
    tokenized_refs_with_weights = [
        [
            [reference_tokens, make_weights(reference_tokens, keywords)]
            for reference_tokens in reference
        ]
        for reference in tokenized_refs
    ]
    weighted_ngram_match_score = weighted_ngram_match.corpus_bleu(tokenized_refs_with_weights, tokenized_hyps)

    # calculate syntax match
    syntax_match_score = syntax_match.corpus_syntax_match(references, hypothesis, lang)

    # calculate dataflow match
    dataflow_match_score = dataflow_match.corpus_dataflow_match(references, hypothesis, lang)

    code_bleu_score = (
        alpha * ngram_match_score
        + beta * weighted_ngram_match_score
        + gamma * syntax_match_score
        + theta * dataflow_match_score
    )

    return code_bleu_score, (ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score)
def _read_stripped_lines(path):
    """Return the lines of the UTF-8 text file *path*, each stripped of surrounding whitespace."""
    # The context manager closes the handle as soon as the lines are read.
    with open(path, 'r', encoding='utf-8') as handle:
        return [line.strip() for line in handle]


def main():
    """Command-line entry point: read references and hypotheses, print CodeBLEU.

    Command-line arguments:
        --refs       one or more reference files, one example per line.
        --json_refs  treat each reference line as a JSON object and use its
                     ``code`` field as the reference text.
        --hyp        hypothesis file, one example per line.
        --lang       programming language of the code.
        --params     comma-separated weights ``alpha,beta,gamma,theta``.

    Every reference file must have the same number of lines as the
    hypothesis file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--refs', type=str, nargs='+', required=True, help='reference files')
    parser.add_argument('--json_refs', action='store_true', help='reference files are JSON files')
    parser.add_argument('--hyp', type=str, required=True, help='hypothesis file')
    parser.add_argument('--lang', type=str, required=True,
                        choices=['java', 'javascript', 'c_sharp', 'php', 'go', 'python', 'ruby'],
                        help='programming language')
    # The help text used to omit theta, although four weights are parsed.
    parser.add_argument('--params', type=str, default='0.25,0.25,0.25,0.25',
                        help='alpha, beta, gamma and theta')
    args = parser.parse_args()

    # List(List(String))
    # -> length of the outer List is number of references per translation
    # -> length of the inner List is number of total examples
    pre_references = [_read_stripped_lines(ref_file) for ref_file in args.refs]

    # List(String)
    hypothesis = _read_stripped_lines(args.hyp)

    for ref_file, ref_lines in zip(args.refs, pre_references):
        assert len(hypothesis) == len(ref_lines), \
            'reference file {} has {} lines but the hypothesis file has {}'.format(
                ref_file, len(ref_lines), len(hypothesis))

    # Regroup from "one list per reference file" to "one list per example".
    references = []
    for i in range(len(hypothesis)):
        refs_for_example = [ref_lines[i] for ref_lines in pre_references]
        if args.json_refs:
            refs_for_example = [json.loads(raw)['code'] for raw in refs_for_example]
        references.append(refs_for_example)

    # references holds exactly one entry per example. The previous check
    # compared against len(pre_references) * len(hypothesis), which fails as
    # soon as more than one reference file is supplied.
    assert len(references) == len(hypothesis), \
        'expected one reference group per hypothesis'

    # references is List(List(String)) where the inner List is a
    # list of reference translations for one example.
    code_bleu_score, (ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score) = \
        compute_codebleu(hypothesis, references, args.lang, args.params)

    print('ngram match: {0}, weighted ngram match: {1}, syntax_match: {2}, dataflow_match: {3}'.
          format(ngram_match_score, weighted_ngram_match_score, syntax_match_score, dataflow_match_score))
    print('CodeBLEU score: %.2f' % (code_bleu_score * 100.0))
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()