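"""Average per-utterance evaluation metrics from a JSONL file.

Each input line is expected to hold one JSON object emitted by an upstream
evaluation step; a minimal sketch of a record, with hypothetical values:

    {"eval_status": "success", "language": "en", "prompt_language": "zh",
     "wer": 0.12, "cer": 0.05, "mcd": 4.31, "singmos": 3.8}

Records whose "eval_status" is not "success" are skipped. When any record
carries a "prompt_language" field, the averages are additionally split into
"parallel" (language == prompt_language) and "cross" subsets.
"""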
import json
import argparse


def get_args():
    parser = argparse.ArgumentParser(description="Compute average evaluation metrics.")
    parser.add_argument("--input_file", type=str, required=True)
    parser.add_argument("--result_file", type=str, required=True)
    return parser.parse_args()

def main():
    args = get_args()

    # Metrics to aggregate
    metrics = {
        # ASR
        'wer': [], 'cer': [],
        # Spectrum
        'mcd': [],
        # Speaker Similarity
        'prompt_gen_cos_sim': [],
        # FFE, GPE, VDE
        'ffe': [], 'gpe': [], 'vde': [],
        # MOS
        'singmos': [], 'sheet': [],
    }

    print(f"Reading from {args.input_file} ...")
    valid_count = 0
    has_prompt_language = False

    # First pass: collect metric values from every successfully evaluated record.
    try:
        with open(args.input_file, 'r', encoding='utf-8') as fin:
            for line in fin:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    continue
                if data.get('eval_status') != 'success':
                    continue
                valid_count += 1
                if 'prompt_language' in data:
                    has_prompt_language = True
                for key in metrics.keys():
                    if key in data:
                        try:
                            metrics[key].append(float(data[key]))
                        except (ValueError, TypeError):
                            pass
    except FileNotFoundError:
        print(f"Error: File {args.input_file} not found.")
        return

    def _init_metrics():
        # Fresh set of metric lists, one empty list per metric key.
        return {k: [] for k in metrics.keys()}

    def _avg_metrics(metric_dict):
        # Average each metric's collected values; metrics with no collected
        # values are reported as "0.0000" rather than omitted.
        out = {}
        for key, val_list in metric_dict.items():
            if len(val_list) > 0:
                avg_val = sum(val_list) / len(val_list)
                out[key] = f"{avg_val:.4f}"
            else:
                out[key] = "0.0000"
        return out
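    # For example, _avg_metrics({'wer': [0.1, 0.2], 'cer': []}) returns
    # {'wer': '0.1500', 'cer': '0.0000'} (values hypothetical).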

    if has_prompt_language:
        # Second pass: split metrics into "parallel" (prompt language matches
        # the target language) and "cross" (languages differ) subsets.
        parallel_metrics = _init_metrics()
        cross_metrics = _init_metrics()
        parallel_count = 0
        cross_count = 0
        if valid_count > 0:
            with open(args.input_file, 'r', encoding='utf-8') as fin:
                for line in fin:
                    line = line.strip()
                    if not line:
                        continue
                    try:
                        data = json.loads(line)
                    except json.JSONDecodeError:
                        continue
                    if data.get('eval_status') != 'success':
                        continue
                    if 'prompt_language' not in data:
                        continue
                    if data.get('language') == data.get('prompt_language'):
                        target = parallel_metrics
                        parallel_count += 1
                    else:
                        target = cross_metrics
                        cross_count += 1
                    for key in metrics.keys():
                        if key in data:
                            try:
                                target[key].append(float(data[key]))
                            except (ValueError, TypeError):
                                pass
        out_parallel = _avg_metrics(parallel_metrics)
        out_cross = _avg_metrics(cross_metrics)
        out_dict = {
            "parallel": {
                "infer_num": parallel_count,
                **out_parallel,
            },
            "cross": {
                "infer_num": cross_count,
                **out_cross,
            },
        }
        print(f"\n=== Evaluation Summary (Samples: {valid_count}) ===")
        print(f"parallel infer_num: {parallel_count}")
        print(f"cross infer_num: {cross_count}")
    else:
        out_all = _avg_metrics(metrics)
        out_dict = {
            "infer_num": valid_count,
            **out_all,
        }
        print(f"\n=== Evaluation Summary (Samples: {valid_count}) ===")
        print(f"infer_num: {valid_count}")

    with open(args.result_file, 'w', encoding='utf-8') as fout:
        json.dump(out_dict, fout, indent=4, ensure_ascii=False)
    print(f'\nDetailed averaged results saved to {args.result_file}')


if __name__ == "__main__":
    main()
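
# Example invocation (paths hypothetical):
#   python average.py --input_file eval_results.jsonl --result_file averages.json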