Commit 1ad3206

add medqa task
1 parent 197a934 commit 1ad3206

File tree:
  lmms_eval/tasks/medqa/medqa.yaml
  lmms_eval/tasks/medqa/utils.py

2 files changed: +148 -0

lmms_eval/tasks/medqa/medqa.yaml

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
dataset_path: lmms-lab/MEDQA
dataset_kwargs:
  token: True

task: "medqa"
test_split: test
doc_to_target: !function utils.medqa_doc_to_target
doc_to_visual: null
doc_to_text: !function utils.medqa_doc_to_text
doc_to_choice: !function utils.medqa_doc_to_choice

lmms_eval_specific_kwargs:
  default:
    pre_prompt: ""
    post_prompt: "\nAnswer with the option's letter from the given choices directly: "
metric_list:
  - metric: accuracy
    aggregation: mean
    higher_is_better: true

process_results: !function utils.medqa_process_results

metadata:
  version: 0.0
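With `doc_to_visual: null`, MEDQA runs as a text-only task. Under the default kwargs above, the rendered prompt for each sample (built by `utils.medqa_doc_to_text`, below) takes this shape, with placeholders standing in for the real question and options:

    <question>
    A. <option A>
    B. <option B>
    C. <option C>
    D. <option D>
    E. <option E>
    Answer with the option's letter from the given choices directly: 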

lmms_eval/tasks/medqa/utils.py

Lines changed: 124 additions & 0 deletions
@@ -0,0 +1,124 @@
import random
from typing import List, Dict, Any

import numpy as np


def medqa_doc_to_text(doc: Dict[str, Any], lmms_eval_specific_kwargs: Dict[str, Any]) -> str:
    """
    Build the MCQ prompt from a MEDQA sample.

    Expected doc fields (from the `lmms-lab/MEDQA` parquet):
    - "question": str
    - "options": dict mapping letters to option strings (e.g., {"A": "...", "B": "..."})
    - Some samples may also expose choices as a list; we normalize to a lettered block.
    - MEDQA is text-only, so no visuals are used.
    """
    question = doc.get("question", "").strip()

    # Normalize options into "A. ..." style lines
    options = doc.get("options")
    if isinstance(options, dict):
        # Keep only A-E, in letter order, if present
        ordered_keys = [k for k in ["A", "B", "C", "D", "E"] if k in options]
        options_block = "\n".join([f"{k}. {str(options[k]).strip()}" for k in ordered_keys])
    elif isinstance(options, list):
        letters = ["A", "B", "C", "D", "E"]
        options_block = "\n".join([f"{letters[i]}. {str(opt).strip()}" for i, opt in enumerate(options)])
    else:
        # Fallback: use the raw value if it is already string-like
        options_block = str(options) if options is not None else ""

    pre_prompt = lmms_eval_specific_kwargs["pre_prompt"]
    post_prompt = lmms_eval_specific_kwargs["post_prompt"]
    prompt = f"{question}\n{options_block}"
    return f"{pre_prompt}{prompt}{post_prompt}"
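# Illustration (hypothetical sample, not drawn from the dataset):
#   doc = {"question": "Which vitamin deficiency causes scurvy?",
#          "options": {"A": "Vitamin A", "B": "Vitamin B12", "C": "Vitamin C", "D": "Vitamin D"}}
#   kwargs = {"pre_prompt": "", "post_prompt": "\nAnswer with the option's letter from the given choices directly: "}
#   medqa_doc_to_text(doc, kwargs) yields the question, the four lettered
#   options on separate lines, then the post prompt.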
def medqa_doc_to_target(doc: Dict[str, Any]) -> str:
    """
    Return the ground-truth answer letter.

    MEDQA on HF commonly provides either:
    - "answer_idx": a letter like "A"/"B"/..., or
    - "answer": a string holding the letter itself or the full option text.
    We prefer the explicit letter when available.
    """
    # Prefer the explicit answer-letter field when present
    if "answer_idx" in doc and isinstance(doc["answer_idx"], str) and len(doc["answer_idx"]) == 1:
        return doc["answer_idx"].strip()

    # Some variants store the letter in "answer" directly
    ans = doc.get("answer")
    if isinstance(ans, str) and len(ans.strip()) == 1 and ans.strip().upper() in ["A", "B", "C", "D", "E"]:
        return ans.strip().upper()

    # If the answer is given as option text, map it back to a letter via options
    options = doc.get("options")
    if isinstance(options, dict) and isinstance(ans, str):
        for k, v in options.items():
            if isinstance(v, str) and v.strip() == ans.strip():
                return k

    # Fallback: unknown -> return a dummy letter; evaluation will mark it incorrect
    return "A"
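# Resolution order, illustrated on hypothetical docs:
#   {"answer_idx": "C"}                                                      -> "C"  (explicit letter field)
#   {"answer": "b"}                                                          -> "B"  (single-letter answer, upper-cased)
#   {"answer": "Vitamin C", "options": {"B": "Vitamin B12", "C": "Vitamin C"}} -> "C"  (option text mapped back to its letter)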
def medqa_doc_to_choice(doc: Dict[str, Any]) -> List[str]:
    # Detect how many choices are present and return the corresponding letters
    if isinstance(doc.get("options"), dict):
        present = [k for k in ["A", "B", "C", "D", "E"] if k in doc["options"]]
        if present:
            return present
    if isinstance(doc.get("options"), list):
        n = min(len(doc["options"]), 5)
        return ["A", "B", "C", "D", "E"][:n]
    # Default to 5-way if uncertain
    return ["A", "B", "C", "D", "E"]
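# e.g., a 4-key options dict {"A": ..., "B": ..., "C": ..., "D": ...} yields
# ["A", "B", "C", "D"], a 4-element options list yields the same, and anything
# else falls back to the full 5-way set.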
def medqa_process_results(doc: Dict[str, Any], result: List[str]) -> Dict[str, float]:
    """
    Parse model output and compute accuracy against the gold letter.
    We robustly extract a single letter from the response.
    """
    response = result[0].strip()
    all_choices = medqa_doc_to_choice(doc)
    pred = _parse_multi_choice_response(response, all_choices)
    gt_ans = medqa_doc_to_target(doc)
    score = 1.0 if pred == gt_ans else 0.0
    return {"accuracy": score}
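# e.g., for a hypothetical doc whose gold letter resolves to "C":
#   medqa_process_results(doc, ["The answer is (C)."]) -> {"accuracy": 1.0}
#   medqa_process_results(doc, ["B"])                  -> {"accuracy": 0.0}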
def _parse_multi_choice_response(response: str, all_choices: List[str]) -> str:
    # Clean punctuation around the response
    for ch in [",", ".", "!", "?", ";", ":", "'"]:
        response = response.strip(ch)
    response = " " + response + " "

    candidates = []
    # "(A)" style
    for c in all_choices:
        if f"({c})" in response:
            candidates.append(c)

    # plain letter surrounded by spaces
    if len(candidates) == 0:
        for c in all_choices:
            if f" {c} " in response:
                candidates.append(c)

    # "A.", "B.", etc.
    if len(candidates) == 0:
        for c in all_choices:
            if f"{c}." in response:
                candidates.append(c)

    if len(candidates) == 0:
        return random.choice(all_choices)
    if len(candidates) > 1:
        # Choose the last occurrence to mitigate explanations that mention multiple letters
        start_indexes = [response.rfind(f" {can} ") for can in candidates]
        return candidates[int(np.argmax(start_indexes))]
    return candidates[0]
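# Fallback chain on toy responses (4-way choices assumed; inputs are hypothetical):
#   _parse_multi_choice_response("(B)", ["A", "B", "C", "D"])              -> "B"
#   _parse_multi_choice_response("The answer is B", ["A", "B", "C", "D"])  -> "B"
#   _parse_multi_choice_response("B.", ["A", "B", "C", "D"])               -> "B"
#   _parse_multi_choice_response("no letter at all", ["A", "B", "C", "D"]) -> a random letter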
