
Commit 9f8b987

[Dataset] Add mmau task (EvolvingLMMs-Lab#585)
1 parent fe7a5a6 commit 9f8b987

File tree: 5 files changed (+171, -0 lines)
lmms_eval/tasks/mmau/_default_template_yaml

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
dataset_path: lmms-lab/mmau
dataset_kwargs:
  token: True
doc_to_target: "answer"
doc_to_visual: !function utils.doc_to_audio
doc_to_text: !function utils.doc_to_text
doc_to_choice: !function utils.doc_to_choice
generation_kwargs:
  max_new_tokens: 128
  do_sample: false
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer with the option's letter from the given choices directly."

process_results: !function utils.mmau_process_results

metadata:
  version: 0.0
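The "!function utils.<name>" entries above bind config keys to callables defined in this task's utils.py (shown further down). lmms-eval resolves these tags with its own config loader; purely to illustrate the general idea, a hand-rolled PyYAML constructor for such a tag could look like the sketch below. The module path and file name here are assumptions for the example, and this is not lmms-eval code.

```python
# Illustration only -- not lmms-eval's loader.
# Assumes this task's utils.py is importable as "utils" from the current directory.
import importlib

import yaml


def _function_constructor(loader, node):
    # A node like "utils.doc_to_text" is split into a module path and an attribute name.
    module_name, attr = loader.construct_scalar(node).rsplit(".", 1)
    return getattr(importlib.import_module(module_name), attr)


yaml.SafeLoader.add_constructor("!function", _function_constructor)

with open("_default_template_yaml") as f:
    config = yaml.safe_load(f)

print(config["doc_to_text"])  # <function doc_to_text at 0x...>
```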

lmms_eval/tasks/mmau/mmau.yaml

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
group: mmau
task:
  - mmau_test_mini
  - mmau_test
lmms_eval/tasks/mmau/mmau_test.yaml

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
task: "mmau_test"
test_split: test

metric_list:
  - metric: submission
    aggregation: !function utils.mmau_aggregate_results_for_submission
    higher_is_better: true

include: _default_template_yaml
lmms_eval/tasks/mmau/mmau_test_mini.yaml

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
task: "mmau_test_mini"
test_split: test_mini

metric_list:
  - metric: accuracy
    aggregation: !function utils.mmau_aggregate_results
    higher_is_better: true

include: _default_template_yaml

lmms_eval/tasks/mmau/utils.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import datetime
import json
import os
import random
import re
import sys
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import yaml
from loguru import logger as eval_logger

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file


def doc_to_audio(doc):
    return [doc["audio"]]

def doc_to_text(doc, lmms_eval_specific_kwargs):
    letter = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J"]
    pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
    post_prompt = lmms_eval_specific_kwargs["post_prompt"]
    question = doc["question"]
    choices = json.loads(doc["choices"])
    choices = "\n".join([f"{letter[i]}. {choice}" for i, choice in enumerate(choices)])
    return f"{pre_prompt}{question}\n{choices}{post_prompt}"

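Not part of the commit, but as a quick sanity check of what doc_to_text produces: with the default pre_prompt/post_prompt from _default_template_yaml and a made-up doc (the question and choices below are invented), the rendered prompt looks like this:

```python
# Illustration only: a made-up MMAU-style doc; "choices" is a JSON string, as in the dataset.
from utils import doc_to_text  # assumes this task's utils.py is importable

doc = {
    "question": "Which instrument is playing?",
    "choices": '["piano", "violin", "drums"]',
}
kwargs = {"pre_prompt": "", "post_prompt": "\nAnswer with the option's letter from the given choices directly."}

print(doc_to_text(doc, kwargs))
# Which instrument is playing?
# A. piano
# B. violin
# C. drums
# Answer with the option's letter from the given choices directly.
```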
def doc_to_choice(doc):
    choices = json.loads(doc["choices"])
    return choices

def mmau_process_results(doc, result):
    letter = ["A", "B", "C", "D"]
    response = parse_multi_choice_response(result[0], letter)
    response = letter_to_ans(response, json.loads(doc["choices"]))
    doc["model_prediction"] = response
    response = response.strip().lower()
    gt_ans = doc["answer"].strip().lower()
    score = 1.0 if response == gt_ans else 0.0

    return {"accuracy": {"overall": score, "task": doc["task"]}, "submission": {**doc}}

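Again purely illustrative (made-up doc and model output): mmau_process_results parses the letter out of the raw generation, maps it back to the answer string, and returns both the per-doc accuracy entry and the full doc for the submission file.

```python
# Illustration only: made-up doc and model output.
from utils import mmau_process_results  # assumes this task's utils.py is importable

doc = {
    "question": "Which instrument is playing?",
    "choices": '["piano", "violin", "drums"]',
    "answer": "violin",
    "task": "music",
}
result = ["(B) violin"]  # hypothetical raw model generation

out = mmau_process_results(doc, result)
print(out["accuracy"])                        # {'overall': 1.0, 'task': 'music'}
print(out["submission"]["model_prediction"])  # violin
```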
def mmau_aggregate_results(results):
    total_correct = 0
    group_totals = defaultdict(int)
    group_correct = defaultdict(int)

    for result in results:
        accuracy = result["overall"]
        total_correct += accuracy

        group_totals[result["task"]] += 1
        group_correct[result["task"]] += accuracy

    overall_accuracy = round(total_correct * 100 / len(results), 5)
    categorical_accuracy = {key: round(group_correct[key] * 100 / group_totals[key], 5) for key in group_totals.keys()}
    eval_logger.info("=" * 50)
    eval_logger.info(f"Overall accuracy: {overall_accuracy}")
    eval_logger.info("Categorical accuracy: ")
    for key, value in categorical_accuracy.items():
        eval_logger.info(f"{key} accuracy: {value}")
    eval_logger.info("=" * 50)
    return overall_accuracy

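For illustration, feeding two hand-made per-doc accuracy entries (the shape produced by mmau_process_results above) through the aggregator:

```python
# Illustration only: two made-up per-doc accuracy entries.
from utils import mmau_aggregate_results  # assumes this task's utils.py is importable

results = [
    {"overall": 1.0, "task": "music"},
    {"overall": 0.0, "task": "sound"},
]
print(mmau_aggregate_results(results))  # 50.0 (per-task accuracies are logged via loguru)
```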
def mmau_aggregate_results_for_submission(results, args):
    path = generate_submission_file("mmau_submission.json", args)
    filtered_results = []
    keys_to_keep = ["id", "audio_id", "question", "choices", "model_prediction", "dataset", "task", "split", "category", "sub-category", "difficulty"]

    for result in results:
        filtered_result = {key: result[key] for key in keys_to_keep if key in result}
        filtered_results.append(filtered_result)

    results = filtered_results
    with open(path, "w") as f:
        json.dump(results, f, indent=4)
    eval_logger.info(f"Results saved to {path}.")

def parse_multi_choice_response(response, all_choices):
    """
    Parse the prediction from the generated response.
    Return the predicted choice letter e.g., A, B, C, D.
    """
    # Clean response of unwanted characters
    for char in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(char)
    response = " " + response + " "  # Add space to avoid partial match

    candidates = []
    # Look for choices with parentheses, e.g., (A)
    for choice in all_choices:
        if f"({choice})" in response:
            candidates.append(choice)

    # Look for simple choices, e.g., A, B, C
    if len(candidates) == 0:
        for choice in all_choices:
            if f" {choice} " in response:
                candidates.append(choice)

    # Look for choices with periods, e.g., A., B., C.
    if len(candidates) == 0:
        for choice in all_choices:
            if f"{choice}." in response:
                candidates.append(choice)

    # If no candidates, randomly choose one
    if len(candidates) == 0:
        pred_index = random.choice(all_choices)
    elif len(candidates) > 1:
        # If more than one candidate, choose the last one found
        start_indexes = [response.rfind(f" {can} ") for can in candidates]
        pred_index = candidates[np.argmax(start_indexes)]
    else:
        # If only one candidate, use it
        pred_index = candidates[0]

    return pred_index

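Not part of the commit; a few hand-checked calls showing how the three matching passes in parse_multi_choice_response behave (the random fallback only fires when nothing matches):

```python
# Illustration only: exercising the three matching passes by hand.
from utils import parse_multi_choice_response  # assumes this task's utils.py is importable

letters = ["A", "B", "C", "D"]
print(parse_multi_choice_response("The answer is (B).", letters))   # B  (parenthesised letter)
print(parse_multi_choice_response("I would pick C here", letters))  # C  (bare letter between spaces)
print(parse_multi_choice_response("A. piano", letters))             # A  (letter followed by a period)
```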
def letter_to_ans(letter, choices):
    return choices[ord(letter) - ord("A")]
