-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcalibration.py
47 lines (42 loc) · 1.98 KB
/
calibration.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import numpy as np
from os.path import join
from glob import glob
from tqdm import tqdm
from scipy.special import softmax
import re
import argparse
# Command-line interface. --model_dir is a glob pattern; every matching
# directory is processed independently by the loop below.
parser = argparse.ArgumentParser(
    description="Calibrate pairwise preference logits stored as q-<id>_logit.npy files.",
)
parser.add_argument(
    "--model_dir",
    type=str,
    default="outputs/run.pairwise*preference_matrix",
    help=(
        "Glob pattern for the per-model output directories. Each matched "
        "directory name must contain 'allpair.<model>_preference_matrix' "
        "and the directory must hold q-<id>_logit.npy files "
        "(default: %(default)s)."
    ),
)
args = parser.parse_args()
# Accumulators filled by the loop below. They are not read later in this
# script; presumably they are inspected interactively or imported elsewhere
# (TODO confirm).
token_prob_df = []  # one summary row (dict) per model directory
model_preference_mat = {}  # "<model>[-icl]" -> {qid: calibrated preference matrix}
ori_model_preference_mat = {}  # "<model>[-icl]" -> {qid: array loaded from q-<id>_ori.npy}

# Each per-query input file is named q-<id>_logit.npy; its derived files are
# written next to it as q-<id><tag>.npy.
_LOGIT_SUFFIX = "_logit.npy"
_QID_PATTERN = re.compile(r"q-(\d+)_logit\.npy")
_MODEL_NAME_PATTERN = re.compile(r"allpair\.(.+)_preference_matrix")


def _query_id(q_file):
    """Return the ``<id>`` of a ``q-<id>_logit.npy`` path, or None if the name does not fit."""
    from os.path import basename  # local import: the file header only imports ``join``

    # Match the file name only, so directory components can never be
    # mistaken for (part of) the query id.
    match = _QID_PATTERN.fullmatch(basename(q_file))
    return match.group(1) if match else None


def _sibling_path(q_file, tag):
    """Return the path of the ``q-<id><tag>.npy`` file that sits next to ``q_file``.

    Only the trailing ``_logit.npy`` is rewritten. ``str.replace("_logit", ...)``
    on the whole path would also rewrite any directory name containing
    ``_logit`` and point the file at the wrong location.
    """
    return q_file[: -len(_LOGIT_SUFFIX)] + tag + ".npy"


def _calibrate(logit_arr):
    """Return the calibrated preference matrix for one query's logits.

    Assumes ``logit_arr`` is (n, n, k) with k >= 2, where ``[i, j]`` holds the
    logits for the comparison with i in slot "A" (index 0) and j in slot "B"
    (index 1) -- TODO confirm against the producer of the _logit files.
    The transpose below requires the first two axes to have equal length.
    """
    # Softmax over the last axis; keep the probability assigned to slot "A".
    prob_a = softmax(logit_arr, -1)[..., 0]
    # Pair each ordering (i, j) with its reverse (j, i) and renormalise:
    # entry [i, j] becomes sigmoid(prob_a[i, j] - prob_a[j, i]), i.e. 0.5
    # when both orderings give slot "A" the same probability.
    return softmax(np.stack([prob_a, prob_a.T], -1), -1)[..., 0]


for model_dir in tqdm(sorted(glob(args.model_dir))):
    # Resolve the model name first, so a badly named directory fails before
    # any output file is written for it.
    name_match = _MODEL_NAME_PATTERN.search(model_dir)
    if name_match is None:
        raise ValueError(
            f"Cannot extract a model name from {model_dir!r}: expected "
            "'allpair.<model>_preference_matrix' in the directory name"
        )
    model_name = name_match.group(1)
    is_icl = "icl" in model_dir
    model_key = model_name + ("-icl" if is_icl else "")

    mean_list = []  # per-query mean of the raw _logit values over axes (0, 1)
    preference_mat = {}  # qid -> calibrated preference matrix
    ori_preference_mat = {}  # qid -> contents of q-<id>_ori.npy
    for q_file in sorted(glob(join(model_dir, "q-*_logit.npy"))):
        qid = _query_id(q_file)
        if qid is None:
            # The glob also accepts names such as q-abc_logit.npy; those are
            # not per-query files.
            tqdm.write(f"Skipping unrecognised file {q_file}")
            continue
        logit_arr = np.load(q_file).astype(np.float32)
        # Uncalibrated comparison matrix: True where slot "A" outscores slot "B".
        np.save(_sibling_path(q_file, "_wocal"), logit_arr[..., 0] > logit_arr[..., 1])
        # Calibrated matrix (continuous) and its thresholded decision matrix.
        preference_mat[qid] = _calibrate(logit_arr)
        np.save(_sibling_path(q_file, "_fix"), preference_mat[qid] > 0.5)
        np.save(_sibling_path(q_file, "_calogit"), preference_mat[qid])
        mean_list.append(logit_arr.mean((0, 1)))
        # q-<id>_ori.npy must already exist. Older runs apparently saved it as
        # q-<id>.npy; rename such files to the _ori name before running this.
        ori_preference_mat[qid] = np.load(_sibling_path(q_file, "_ori"))

    if not mean_list:
        # Nothing to summarise (np.stack([]) would raise); report and move on.
        tqdm.write(f"No q-<id>_logit.npy files found in {model_dir}; skipping")
        continue

    ori_model_preference_mat[model_key] = ori_preference_mat
    model_preference_mat[model_key] = preference_mat
    # Average over queries of the per-query means of the raw _logit values
    # (not softmax probabilities); index 0 is slot "A", index 1 slot "B".
    prob_value = np.stack(mean_list).mean(0)
    token_prob_df.append({
        "model": model_name, "ICL": is_icl,
        "A": prob_value[0], "B": prob_value[1],
        "#Query": len(mean_list),
    })
print("Calibration Finished")