Skip to content

Commit a80414b

Browse files
jkooyKairuiHu
authored and committed
Support new task mmworld (#269)
* support new task mmworld * Apply linting fixes
1 parent 8ca58e9 commit a80414b

File tree

2 files changed

+206
-0
lines changed

2 files changed

+206
-0
lines changed
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
dataset_path: MMWorld/MMWorld
dataset_kwargs:
  token: True
  # Videos are cached under <HF_HOME>/mmworld; utils.py reads this key at import time.
  cache_dir: mmworld
  video: True
  # From_YouTube: True
task: mmworld
test_split: train
output_type: generate_until
doc_to_visual: !function utils.mmworld_doc_to_visual
doc_to_text: !function utils.mmworld_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 16
  temperature: 0
  top_p: 1.0
  num_beams: 1
  do_sample: false
# The return value of process_results will be used by metrics
process_results: !function utils.mmworld_process_results
# Note that the metric name can be either a registered metric function (such as the case for GQA) or a key name returned by process_results
metric_list:
  - metric: mmworld_accuracy
    aggregation: !function utils.mmworld_aggregate_results
    higher_is_better: true
lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer with the option's letter from the given choices directly."
  gpt4v:
    pre_prompt: ""
    post_prompt: "\nAnswer the question with A, B, C, or D."
  xcomposer2_4khd:
    pre_prompt: "[UNUSED_TOKEN_146]user\n"
    post_prompt: " Answer this question with A, B, C, or D.[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n"
metadata:
  - version: 0.0

lmms_eval/tasks/mmworld/utils.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
import datetime
2+
import json
3+
import os
4+
import re
5+
import shutil
6+
import sys
7+
from collections import defaultdict
8+
from pathlib import Path
9+
from typing import Dict, List, Optional, Union
10+
11+
import cv2
12+
import numpy as np
13+
import yaml
14+
from loguru import logger as eval_logger
15+
16+
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
17+
18+
# The seven MMWorld discipline categories; per-discipline accuracy is
# reported by mmworld_aggregate_results below.
DISCIPLINES = ["Tech & Engineering", "Science", "Health & Medicine", "Sports & Arts", "Game", "Business", "Embodied Tasks"]


# NOTE(review): appears unused in this module — likely leftover from a
# copied task file; confirm before removing.
replace_prompt = " Please answer yes or no."
22+
23+
# with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
24+
# raw_data = f.readlines()
25+
# safe_data = []
26+
# for i, line in enumerate(raw_data):
27+
# # remove function definition since yaml load cannot handle it
28+
# if "!function" not in line:
29+
# safe_data.append(line)
30+
31+
# config = yaml.safe_load("".join(safe_data))
32+
33+
# Resolve the HuggingFace cache root; videos live under
# <HF_HOME>/<dataset_kwargs.cache_dir> as declared in mmworld.yaml.
hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
# cache_dir = os.path.join(hf_home, cache_dir)
# base_cache_dir = config["dataset_kwargs"]["cache_dir"]
base_cache_dir = os.path.expanduser(hf_home)
# Load mmworld.yaml at import time just to read dataset_kwargs.cache_dir.
with open(Path(__file__).parent / "mmworld.yaml", "r") as f:
    raw_data = f.readlines()
    safe_data = []
    for i, line in enumerate(raw_data):
        # remove function definition since yaml load cannot handle it
        if "!function" not in line:
            safe_data.append(line)
cache_name = yaml.safe_load("".join(safe_data))["dataset_kwargs"]["cache_dir"]
45+
46+
47+
def extract_and_remove_subfolders(cache_dir):
    """Flatten *cache_dir* in place.

    Every file found anywhere below *cache_dir* is moved up to its root;
    the emptied subdirectory tree is then deleted bottom-up.
    """
    # First pass: relocate all files to the top level.
    for dirpath, _subdirs, filenames in os.walk(cache_dir):
        for name in filenames:
            src = os.path.join(dirpath, name)
            dst = os.path.join(cache_dir, name)
            if src != dst:
                shutil.move(src, dst)

    # Second pass (bottom-up so children go before parents): drop the
    # now-empty directories.
    for dirpath, subdirs, _filenames in os.walk(cache_dir, topdown=False):
        for sub in subdirs:
            os.rmdir(os.path.join(dirpath, sub))
59+
60+
61+
def mmworld_doc_to_visual(doc):
    """Resolve the on-disk video file for one MMWorld example.

    Args:
        doc: dataset row; ``doc["video_id"]`` holds the (possibly
            path-like) video identifier.

    Returns:
        A single-element list with the resolved video path. Exits the
        process if no matching file exists in the cache.
    """
    cache_dir = os.path.join(base_cache_dir, cache_name)
    # Downloaded archives may extract into nested folders; flatten once so
    # every probe below only needs to look in the cache root.
    extract_and_remove_subfolders(cache_dir)
    video_name = doc["video_id"].split("/")[-1]
    video_path_doc = video_name + ".mp4"
    # Guard against ids that already end in ".mp4".
    video_path = os.path.join(cache_dir, video_path_doc).replace(".mp4.mp4", ".mp4")

    # Probe common extension / "shorts:" naming variants in order.
    # NOTE(review): .replace("mp4", "MP4") swaps the FIRST "mp4" occurrence
    # anywhere in the path, not just the suffix — harmless for the default
    # cache layout, but worth confirming.
    if os.path.exists(video_path):
        pass
    elif os.path.exists(video_path.replace("mp4", "MP4")):
        video_path = video_path.replace("mp4", "MP4")
    elif os.path.exists(video_path.replace("mp4", "avi")):
        video_path = video_path.replace("mp4", "avi")
    elif os.path.exists(os.path.join(cache_dir, "shorts:" + video_path_doc)):
        video_path = os.path.join(cache_dir, "shorts:" + video_path_doc)
    elif os.path.exists(os.path.join(cache_dir, "shorts:" + video_name + ".MP4")):
        # Bug fix: this branch previously assigned a path with the video id
        # duplicated ("shorts:<id><video_id>.MP4"), so the file it had just
        # found was never actually returned.
        video_path = os.path.join(cache_dir, "shorts:" + video_name + ".MP4")
    elif os.path.exists(os.path.join(cache_dir, "shorts:" + video_name + ".avi")):
        video_path = os.path.join(cache_dir, "shorts:" + video_name + ".avi")
    else:
        sys.exit(f"video path:{video_path} does not exist, please check")

    return [video_path]
83+
84+
85+
def mmworld_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the full multiple-choice prompt for one MMWorld example.

    Args:
        doc: dataset row with "question" and "options" fields.
        lmms_eval_specific_kwargs: optional dict; its "post_prompt" entry
            overrides the default answer cue.

    Returns:
        Prompt string: instruction + question with options + answer cue.
    """
    option_prompt = "Select the best answer to the following multiple-choice question based on the video and the subtitles. Respond with only the letter (A, B, C, or D) of the correct option."
    question = doc["question"] + "\n" + str(doc["options"])
    # Bug fix: the declared default None previously crashed the
    # `"post_prompt" in lmms_eval_specific_kwargs` membership test.
    if lmms_eval_specific_kwargs is None:
        lmms_eval_specific_kwargs = {}
    post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "The best answer is:")
    return option_prompt + "\n" + question + "\n" + post_prompt
93+
94+
95+
def extract_characters_regex(s):
    """Extract the answer letter (A-D) from a raw model response.

    Common answer prefixes are stripped first; the first A/B/C/D character
    remaining is returned. Returns "" when the response is long free-form
    text containing no option letter, or when no letter is found at all.
    """
    s = s.strip()
    answer_prefixes = [
        "The best answer is",
        "The correct answer is",
        "The answer is",
        "The answer",
        # Bug fix: the next four entries were previously fused into two by
        # implicit string concatenation (missing commas), so none of them
        # ever matched and e.g. "Best option: C" parsed as "B" (from "Best").
        "The best option is",
        "The correct option is",
        "Best answer:",
        "Best option:",
    ]
    for answer_prefix in answer_prefixes:
        s = s.replace(answer_prefix, "")

    # A long answer with no option letter counts as unanswered.
    if len(s.split()) > 10 and not re.search("[ABCD]", s):
        return ""

    matches = re.search(r"[ABCD]", s)
    if matches is None:
        return ""
    return matches[0]
115+
116+
117+
def mmworld_process_results(doc, results):
    """Turn one model prediction into the per-example record consumed by
    ``mmworld_aggregate_results``.

    Args:
        doc: the eval-dataset instance being scored.
        results: single-element list ``[pred]`` with the raw model output.

    Returns:
        ``{"mmworld_accuracy": record}`` where record carries the video id,
        discipline, parsed predicted letter, and gold answer letter.
    """
    raw_prediction = results[0]
    parsed_letter = extract_characters_regex(raw_prediction)
    # gt_ans = doc["answer"].lower().strip().replace(".", "")

    record = {
        "video_id": doc["video_id"],
        "discipline": doc["discipline"],
        "pred_answer": parsed_letter,
        "answer": doc["correct_answer_label"].upper(),
    }
    return {"mmworld_accuracy": record}
133+
134+
135+
def mmworld_aggregate_results(results):
    """Aggregate per-example records into one overall accuracy score.

    Args:
        results: list of record dicts produced by mmworld_process_results.

    Returns:
        Overall accuracy as a percentage (0 when nothing was answered).
        Per-discipline accuracies are logged as a side effect.
    """
    tallies = {f"{discipline}": {"correct": 0, "answered": 0} for discipline in DISCIPLINES}

    for record in results:
        bucket = tallies[f"{record['discipline']}"]
        bucket["answered"] += 1
        bucket["correct"] += record["pred_answer"] == record["answer"]

    # Per-discipline report (substring key match kept from the original).
    for discipline in DISCIPLINES:
        total_correct = sum(v["correct"] for k, v in tallies.items() if discipline in k)
        total_answered = sum(v["answered"] for k, v in tallies.items() if discipline in k)
        eval_logger.info(f"Evaluation on DISCIPLINES: {discipline}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")

    total_correct = sum(v["correct"] for v in tallies.values())
    total_answered = sum(v["answered"] for v in tallies.values())
    eval_logger.info(f"Overall Performance: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")
    return 100 * total_correct / total_answered if total_answered > 0 else 0

0 commit comments

Comments
 (0)