-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathmain.py
69 lines (58 loc) · 2.08 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import argparse
import json
import os
from conversers import load_attack_and_target_models
# Experiments whose data records carry an 'inception_attack' prompt that is
# asked before the record's own 'questions' list.
INCEPTION_FIRST_EXPS = ('main', 'further_q')


def parse_args(argv=None):
    """Parse the command-line options for one target-model run.

    Args:
        argv: Optional list of argument strings; ``None`` means use
            ``sys.argv[1:]`` (argparse's default).

    Returns:
        argparse.Namespace with ``target_model``, ``target_max_n_tokens``,
        ``exp_name`` and ``defense`` attributes.
    """
    parser = argparse.ArgumentParser()
    ########### Target model parameters ##########
    parser.add_argument(
        "--target-model",
        default="vicuna",
        choices=["vicuna", 'falcon', 'llama2', "gpt-3.5-turbo", "gpt-4"],
        help="Name of target model.",
    )
    parser.add_argument(
        "--target-max-n-tokens",
        type=int,
        default=300,
        help="Maximum number of generated tokens for the target."
    )
    # "--exp-name" is an alias matching the other hyphenated flags;
    # "--exp_name" stays first so the attribute is still args.exp_name.
    parser.add_argument(
        "--exp_name",
        "--exp-name",
        type=str,
        default="main",
        choices=['main', 'abl_c', 'abl_layer', 'multi_scene', 'abl_fig6_4', 'further_q'],
        help="Experiment file name"
    )
    parser.add_argument(
        "--defense",
        type=str,
        default="none",
        choices=['none', 'sr', 'ic'],
        help="LLM defense: None, Self-Reminder, In-Context"
    )
    ##################################################
    return parser.parse_args(argv)


def build_questions(data, exp_name):
    """Return the ordered list of prompts to send for one data record.

    For experiments in ``INCEPTION_FIRST_EXPS`` the record's
    ``'inception_attack'`` prompt comes first, followed by its
    ``'questions'``; otherwise only ``'questions'`` is used.
    """
    if exp_name in INCEPTION_FIRST_EXPS:
        return [data['inception_attack']] + data['questions']
    return data['questions']


def _load_data(exp_name):
    """Load the experiment's record list from ``./res/data_<exp_name>.json``."""
    # JSON text is UTF-8 by spec; don't depend on the platform locale.
    with open(f'./res/data_{exp_name}.json', encoding='utf-8') as fh:
        return json.load(fh)


def _query_topic(data, args):
    """Send every prompt of one record to a freshly loaded target model.

    Returns a dict ``{'topic': ..., 'qA_pairs': [{'Q': prompt, 'A': response}, ...]}``
    where each response is whatever ``get_response`` returned (also printed).
    """
    questions = build_questions(data, args.exp_name)
    topic = data['topic']
    # A new model wrapper is loaded for every record, as the original script
    # did. NOTE(review): presumably this resets conversation state kept by the
    # wrapper between topics -- confirm before hoisting it out of the loop.
    # The wrapper is released when this function returns, i.e. before the
    # next record's load.
    target_lm = load_attack_and_target_models(args)
    qa_pairs = []
    for question in questions:
        response = target_lm.get_response(question, args.defense)
        qa_pairs.append({'Q': question, 'A': response})
        print(response)
    return {'topic': topic, 'qA_pairs': qa_pairs}


def _save_results(results, args):
    """Write *results* as JSON under ``./results/`` and return the file path."""
    # Serialize before opening so a serialization error cannot truncate an
    # existing results file.
    results_dumped = json.dumps(results)
    os.makedirs('results', exist_ok=True)
    out_path = f'./results/{args.target_model}_{args.exp_name}_{args.defense}_results.json'
    with open(out_path, 'w', encoding='utf-8') as fh:
        fh.write(results_dumped)
    return out_path


def main(argv=None):
    """Run every record of the chosen experiment against the target model."""
    args = parse_args(argv)
    datas = _load_data(args.exp_name)
    results = [_query_topic(data, args) for data in datas]
    _save_results(results, args)


if __name__ == '__main__':
    main()