import random
from typing import Any, Dict, List, Optional

import numpy as np


def medqa_doc_to_text(doc: Dict[str, Any], lmms_eval_specific_kwargs: Optional[Dict[str, Any]] = None):
    """
    Build the MCQ prompt from a MEDQA sample.

    Expected doc fields (from the `lmms-lab/MEDQA` parquet):
    - "question": str
    - "options": dict mapping letters to option strings (e.g., {"A": "...", "B": "..."})
    - Some samples may expose the choices as a list instead; we normalize
      either form to a lettered block.
    - We do not use visuals for MEDQA.
    """
    question = doc.get("question", "").strip()

    # Normalize options into "A. ..." style lines
    options = doc.get("options")
    if isinstance(options, dict):
        # Keep only A-E, in letter order, when present
        ordered_keys = [k for k in ["A", "B", "C", "D", "E"] if k in options]
        options_block = "\n".join(f"{k}. {str(options[k]).strip()}" for k in ordered_keys)
    elif isinstance(options, list):
        letters = ["A", "B", "C", "D", "E"]
        options_block = "\n".join(f"{letters[i]}. {str(opt).strip()}" for i, opt in enumerate(options))
    else:
        # Fallback: use the raw value if it is already string-like
        options_block = str(options) if options is not None else ""

    # Missing or absent kwargs default to empty strings instead of raising KeyError
    lmms_eval_specific_kwargs = lmms_eval_specific_kwargs or {}
    pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "")
    post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "")
    prompt = f"{question}\n{options_block}"
    return f"{pre_prompt}{prompt}{post_prompt}"
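

# A quick sanity check of the prompt shape. The sample and kwargs below are
# hypothetical, not drawn from the dataset:
#
#   medqa_doc_to_text(
#       {"question": "Which vitamin deficiency causes scurvy?",
#        "options": {"A": "Vitamin B12", "B": "Vitamin C"}},
#       {"pre_prompt": "", "post_prompt": "\nAnswer with the option letter."},
#   )
#   -> "Which vitamin deficiency causes scurvy?\nA. Vitamin B12\nB. Vitamin C\nAnswer with the option letter."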


def medqa_doc_to_target(doc: Dict[str, Any]) -> str:
    """
    Return the ground-truth answer letter.

    MEDQA on HF commonly provides either:
    - "answer_idx": a letter like "A"/"B"/..., or
    - "answer": a string holding either the letter (e.g., "C") or the full
      option text. We prioritize the explicit letter when available.
    """
    # Prefer the explicit answer-letter field when present
    ans_idx = doc.get("answer_idx")
    if isinstance(ans_idx, str) and len(ans_idx.strip()) == 1:
        return ans_idx.strip().upper()

    # Some variants store the letter in "answer" directly
    ans = doc.get("answer")
    if isinstance(ans, str) and len(ans.strip()) == 1 and ans.strip().upper() in ["A", "B", "C", "D", "E"]:
        return ans.strip().upper()

    # If the answer is given as option text, map it back to its letter
    options = doc.get("options")
    if isinstance(options, dict) and isinstance(ans, str):
        for k, v in options.items():
            if isinstance(v, str) and v.strip() == ans.strip():
                return k

    # Fallback: unknown -> return a dummy letter; evaluation will mark it incorrect
    return "A"
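

# For instance (hypothetical sample): a doc with
# options={"A": "Aspirin", "B": "Heparin"} and answer="Heparin" but no
# "answer_idx" resolves to "B" via the text-matching branch above.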


def medqa_doc_to_choice(doc: Dict[str, Any]) -> List[str]:
    # Detect how many choices are present and return the corresponding letters
    options = doc.get("options")
    if isinstance(options, dict):
        present = [k for k in ["A", "B", "C", "D", "E"] if k in options]
        if present:
            return present
    if isinstance(options, list):
        n = min(len(options), 5)
        return ["A", "B", "C", "D", "E"][:n]
    # Default to 5-way if uncertain
    return ["A", "B", "C", "D", "E"]
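

# E.g., a hypothetical doc with options={"A": "...", "B": "...", "C": "..."}
# yields ["A", "B", "C"], so scoring adapts to 3-, 4-, or 5-way questions.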


def medqa_process_results(doc: Dict[str, Any], result: List[str]):
    """
    Parse the model output and compute accuracy against the gold letter.
    We robustly extract a single option letter from the response.
    """
    response = result[0].strip()
    all_choices = medqa_doc_to_choice(doc)
    pred = _parse_multi_choice_response(response, all_choices)
    gt_ans = medqa_doc_to_target(doc)
    score = 1.0 if pred == gt_ans else 0.0
    return {"accuracy": score}
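

# Illustrative scoring (hypothetical model output): if the gold letter is "C",
# medqa_process_results(doc, ["The answer is (C)."]) returns {"accuracy": 1.0};
# any other parsed letter yields {"accuracy": 0.0}.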


def _parse_multi_choice_response(response: str, all_choices: List[str]) -> str:
    # Strip punctuation from the ends, then pad with spaces so the bare-letter
    # checks below can match on word boundaries
    response = " " + response.strip(" ,.!?;:'") + " "

    candidates = []
    # "(A)" style
    for c in all_choices:
        if f"({c})" in response:
            candidates.append(c)

    # Bare letter surrounded by spaces
    if len(candidates) == 0:
        for c in all_choices:
            if f" {c} " in response:
                candidates.append(c)

    # "A.", "B.", etc.
    if len(candidates) == 0:
        for c in all_choices:
            if f"{c}." in response:
                candidates.append(c)

    if len(candidates) == 0:
        # Nothing recognizable: fall back to a random choice
        return random.choice(all_choices)
    if len(candidates) > 1:
        # Several letters were mentioned: keep the one whose last occurrence is
        # latest, so a closing "the answer is X" wins over earlier mentions.
        # Check all three surface forms; rfind returns -1 when a form is absent.
        last_positions = [
            max(response.rfind(f"({can})"), response.rfind(f" {can} "), response.rfind(f"{can}."))
            for can in candidates
        ]
        return candidates[int(np.argmax(last_positions))]
    return candidates[0]
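

# Illustrative parses (hypothetical model outputs):
#   _parse_multi_choice_response("The answer is (B).", ["A", "B", "C"]) -> "B"
#   _parse_multi_choice_response("Both A and C fit, but the best answer is C",
#                                ["A", "B", "C", "D", "E"]) -> "C"
#   If no letter is found at all, a random choice is returned, so repeated
#   runs on unparseable outputs are not deterministic.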