-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoutput.py
More file actions
33 lines (24 loc) · 787 Bytes
/
output.py
File metadata and controls
33 lines (24 loc) · 787 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
import pandas as pd
from tqdm import tqdm
import json
from datasets import load_dataset
def main(args):
    """Score model answers against CorrelationQA ground truth and print accuracy.

    Loads the ``true_answer`` column of the ``train`` split of the
    ``ScottHan/CorrelationQA`` dataset and compares it line-by-line with the
    model outputs read from ``args.input_path``. A line counts as correct when
    the lower-cased ground-truth answer appears as a substring of the
    lower-cased model output on the same line.

    Prints the number of model answers read, then the accuracy (fraction of
    correct lines).

    Args:
        args: Parsed command-line namespace; only ``args.input_path`` is read.

    Raises:
        ValueError: If the number of lines in ``args.input_path`` differs from
            the number of ground-truth answers.
    """
    # A list-like column of ground-truth answer strings (not a DataFrame).
    true_answers = load_dataset("ScottHan/CorrelationQA")['train']['true_answer']
    # Explicit encoding: the default is locale-dependent, so non-ASCII model
    # output could otherwise decode differently (or fail) across platforms.
    with open(args.input_path, encoding="utf-8") as f:
        mllm_answer = f.readlines()
    print(len(mllm_answer))
    # Raise instead of assert: asserts vanish under ``python -O`` and a
    # misaligned answer file would then be scored silently and wrongly.
    if len(true_answers) != len(mllm_answer):
        raise ValueError(
            f"Expected {len(true_answers)} answers, but {args.input_path} "
            f"has {len(mllm_answer)} lines"
        )
    # Case-insensitive substring match; each boolean contributes 0 or 1.
    correct = sum(
        truth.lower() in answer.lower()
        for truth, answer in zip(true_answers, mllm_answer)
    )
    total = len(mllm_answer)
    if total == 0:
        # Avoid ZeroDivisionError when there is nothing to score.
        print('Average Accuracy: n/a (no answers)')
        return
    print(f'Average Accuracy: {correct / total}')
if __name__ == "__main__":
    # Command-line entry point: read the optional --input_path flag
    # (default "output.txt") and score that file.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_path",
        type=str,
        default="output.txt",
    )
    args = parser.parse_args()
    main(args)