-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgrading.py
More file actions
69 lines (54 loc) · 2.41 KB
/
Copy pathgrading.py
File metadata and controls
69 lines (54 loc) · 2.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import json
import os
import datasets
from tqdm import tqdm
import argparse
from utils import grading_text, grading_math, pred_cut
# Credentials forwarded to utils.grading_text (see grading()) when grading
# text / choice answers with an LLM. Supply them through the GRADING_API_KEY
# and GRADING_BASE_URL environment variables so real secrets never have to be
# written into (and committed with) this file; "xxx" is only a placeholder.
API_KEY = os.environ.get("GRADING_API_KEY", "xxx")
BASE_URL = os.environ.get("GRADING_BASE_URL", "xxx")
def grading(pred_extract, info, grading_model="gpt-4.1-mini"):
    """Grade one model prediction against the gold record of its question.

    Questions whose ``answer_type`` is "text", "single_choice" or
    "multiple_choice" are delegated to ``grading_text``, which receives the
    grading model name and the module-level ``API_KEY`` / ``BASE_URL``.
    Every other answer type is graded with ``grading_math``; if that does not
    return ``True``, the prediction is trimmed with ``pred_cut`` and graded
    once more.

    Args:
        pred_extract: The model response to grade.
        info: Gold record for the question; must contain ``answer_type``.
        grading_model: Model name forwarded to ``grading_text``.

    Returns:
        For text/choice questions, whatever ``grading_text`` returns. For the
        other types: ``True`` if the full response grades as correct,
        otherwise the result of grading the trimmed response, or ``False``
        if trimming or grading the trimmed response raised.
    """
    if info['answer_type'] in ["text", "single_choice", "multiple_choice"]:
        return grading_text(pred_extract, info,
                            grading_model=grading_model, API_KEY=API_KEY, BASE_URL=BASE_URL)

    # First attempt: grade the full response. A failure here must not abort
    # the whole evaluation run, so report it and fall through to the retry.
    try:
        result = grading_math(pred_extract, info)
    except Exception as e:
        print("Error in grading for full result.")
        print(e)
        result = False
    if result is True:
        return result

    # Second attempt: trim the response and grade again. The trim itself is
    # guarded too, so one malformed response scores False instead of crashing.
    try:
        pred_extract_cut = pred_cut(pred_extract)
        return grading_math(pred_extract_cut, info)
    except Exception as e:
        print("Error in grading for cut result.")
        print(e)
        return False
def load_model_responses(path):
    """Load model responses from a JSON Lines file.

    Each non-blank line is parsed with ``json.loads``. Blank or
    whitespace-only lines (e.g. a stray empty line at the end of the file)
    are skipped instead of raising ``json.JSONDecodeError``.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        A list with one parsed object per non-blank line, in file order.
    """
    with open(path, "r", encoding="utf-8") as f_read:
        return [json.loads(line) for line in f_read if line.strip()]


def average_accuracy(question_id2acc):
    """Return the mean of the per-question scores in *question_id2acc*.

    Values are summed directly, so booleans count as 1 / 0. An empty mapping
    yields 0.0 instead of raising ``ZeroDivisionError``.
    """
    if not question_id2acc:
        return 0.0
    return sum(question_id2acc.values()) / len(question_id2acc)


def main():
    """Grade a model-response file against General365 and save the results.

    Reads ``model_responses/<response_file>`` (one JSON object per line with
    ``question_id`` and ``model_response`` keys) and grades each response
    against the matching record of the ``test`` split of
    ``meituan-longcat/General365_Public``. It prints the average accuracy and
    writes it, together with the per-question scores, as JSON to
    ``grading_results/result_<response_file>.log``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--response_file", type=str, required=True, help="Input file name.")
    parser.add_argument("--grading_model", type=str, default="gpt-4.1-mini", help="Model name for grading.")
    args = parser.parse_args()

    dataset = datasets.load_dataset("meituan-longcat/General365_Public")
    # If an id repeats, the later record wins.
    question_id_to_info = {item["id"]: item for item in dataset["test"]}

    response_path = os.path.join("model_responses", args.response_file)
    model_responses = load_model_responses(response_path)
    if not model_responses:
        print(f"Warning: no model responses found in {response_path}.")

    # Keyed by question id: a repeated question_id overwrites the earlier score.
    question_id2acc = {}
    for model_response in tqdm(model_responses):
        question_id = model_response["question_id"]
        pred = model_response["model_response"]
        gold_info = question_id_to_info[question_id]
        question_id2acc[question_id] = grading(pred, gold_info, grading_model=args.grading_model)

    avg_acc = average_accuracy(question_id2acc)
    print(f"Average Accuracy: {avg_acc}")
    grading_result = {
        "average_accuracy": avg_acc,
        "per_question_accuracy": question_id2acc,
    }

    write_file = os.path.join("grading_results", f"result_{args.response_file}.log")
    # Create the output file's real parent directory: when response_file
    # contains a sub-path, that directory lies below grading_results/.
    os.makedirs(os.path.dirname(write_file), exist_ok=True)
    with open(write_file, "w", encoding="utf-8") as f_write:
        json.dump(grading_result, f_write, indent=4, ensure_ascii=False)


if __name__ == "__main__":
    main()