test_llm.py (forked from NhatHoang2002/ToXCL)
import csv
import os
import time
from argparse import ArgumentParser
import openai
import pandas as pd
from eval_metrics import (compute_classification_scores,
                          compute_generation_scores)
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

openai.api_key = "<YOUR_OPENAI_API_KEY>"
device = "cuda"

# Populated in __main__ only when the Mistral model is selected.
model, tokenizer = None, None

PROMPT_TEMPLATE = """
The input is a tweet that might contain toxic speech. You are required to detect whether it is hateful or not, and if it is hateful, please give a brief explanation of why it is considered hateful.
The Output can be either of the form "hate <SEP> <GENERATED_EXPLANATION>" or "normal <SEP> none".
Input: {}
Output:
"""

def get_mistral_answer(prompt):
    # Wrap the prompt in Mistral's chat template, generate, and strip the
    # prompt echo: the assistant reply follows "[/INST]" and ends at "</s>".
    messages = [{"role": "user", "content": prompt}]
    encodeds = tokenizer.apply_chat_template(messages, return_tensors="pt")
    model_inputs = encodeds.to(device)
    generated_ids = model.generate(model_inputs, max_new_tokens=32, do_sample=True)
    assistant_message = tokenizer.batch_decode(generated_ids)[0]
    return assistant_message.split("[/INST]")[1].strip().split("</s>")[0].strip()

def get_chatgpt_answer(prompt):
    messages = [{"role": "user", "content": prompt}]
    # Retry indefinitely on API errors (rate limits, transient failures),
    # sleeping 30 s between attempts.
    response = None
    while response is None:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo-0613",
                messages=messages,
            )
        except Exception as msg:
            print(msg)
            print('sleeping because of exception ...')
            time.sleep(30)
    return response.choices[0].message["content"]
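
# Note: openai.ChatCompletion is the pre-1.0 openai-python interface. With
# openai>=1.0 the rough equivalent (an untested sketch, not the original
# author's code) would be:
#   client = openai.OpenAI(api_key="...")
#   response = client.chat.completions.create(model="gpt-3.5-turbo-0613", messages=messages)
#   text = response.choices[0].message.content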

if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument('model_name', choices=['chatgpt', 'mistral'])
    parser.add_argument('dataset_name')
    parser.add_argument('--output_dir', default='saved/llm')
    args = parser.parse_args()

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    output_path = os.path.join(args.output_dir, f"{args.model_name}_{args.dataset_name}_result.csv")

    if args.model_name == 'chatgpt':
        get_output_fn = get_chatgpt_answer
    elif args.model_name == 'mistral':
        model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2").to(device)
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
        get_output_fn = get_mistral_answer

    # Load the validation split; columns are assumed to be
    # (text, ..., label, explanation).
    test_data = []
    with open(f"data/{args.dataset_name}_valid.csv", 'r', newline='') as file:
        csvreader = csv.reader(file)
        _ = next(csvreader)  # skip header row
        for row in csvreader:
            test_data.append({
                'text': row[0].strip(),
                'label': row[2].strip(),
                'explanation': row[3].strip(),
            })

    # Uncomment to (re)generate model outputs; otherwise a saved run at
    # output_path is evaluated.
    # test_output = []
    # for idx, row in tqdm(enumerate(test_data), total=len(test_data)):
    #     tqdm.write(f">>> {row['text']}")
    #     prompt = PROMPT_TEMPLATE.format(row['text'])
    #     output = get_output_fn(prompt)
    #     test_output.append({
    #         'text': row['text'],
    #         'output': output,
    #         'label': row['label'],
    #         'gold_explanation': row['explanation']
    #     })
    # pd.DataFrame(test_output).to_csv(output_path, index=False)

    df = pd.read_csv(output_path)
    cls_ground_truth = df.label.tolist()
    generation_ground_truth = df.gold_explanation.tolist()
    output = df.output.tolist()

    # Split each output on "<SEP>" into (label, explanation); outputs that do
    # not follow the template are treated as "hate" with the raw text as the
    # explanation.
    cls_generated, generation_generated = [], []
    for x in output:
        if "<SEP>" in x:
            cls_generated.append(x.split("<SEP>")[0].strip())
            generation_generated.append(x.split("<SEP>")[1].strip())
        else:
            cls_generated.append("hate")
            generation_generated.append(x)

    print("Classification (ACC, F1):", compute_classification_scores(cls_ground_truth, cls_generated))
    print("Generation (BLEU-4, ROUGE-L, METEOR, BERTScore):", compute_generation_scores(generation_ground_truth, generation_generated))