|
| 1 | +import json |
| 2 | +import re |
| 3 | + |
| 4 | + |
| 5 | +def extract_answer(text): |
| 6 | + match = re.findall(r"(?<!^)[A-Z]", text) |
| 7 | + if match: |
| 8 | + return match[0] |
| 9 | + return None |
| 10 | + |
| 11 | + |
| 12 | +def livexiv_doc_to_visual(doc): |
| 13 | + return [doc["image"].convert("RGB")] |
| 14 | + |
| 15 | + |
| 16 | +def livexiv_doc_to_text(doc, model_specific_kwargs=None): |
| 17 | + question = doc["question"] |
| 18 | + question += "\n" + f"A. {doc['option_a']}\n" |
| 19 | + question += f"B. {doc['option_b']}\n" |
| 20 | + question += f"C. {doc['option_c']}\n" |
| 21 | + question += f"D. {doc['option_d']}" |
| 22 | + return f"{question}\nAnswer with the option's letter from the given choices directly." |
| 23 | + |
| 24 | + |
| 25 | +def livexiv_process_result(doc, result): |
| 26 | + pred = result[0].strip() |
| 27 | + if len(pred) > 1: |
| 28 | + if "answer" in pred.lower(): |
| 29 | + pred = extract_answer(pred) |
| 30 | + else: |
| 31 | + pred = pred[0] |
| 32 | + answer = doc["gt"] |
| 33 | + |
| 34 | + return {f"livexiv_tqa": {"pred": pred, "answer": answer}} |
| 35 | + |
| 36 | + |
| 37 | +def livexiv_aggregation_result(results): |
| 38 | + total_count = 0 |
| 39 | + total_correct = 0 |
| 40 | + for result in results: |
| 41 | + try: |
| 42 | + if result["pred"].lower().strip() == result["answer"].lower().strip(): |
| 43 | + total_correct += 1 |
| 44 | + except Exception as e: |
| 45 | + print(e) |
| 46 | + |
| 47 | + total_count += 1 |
| 48 | + return total_correct / total_count |
| 49 | + |
| 50 | + |
| 51 | +def livexiv_aggregation_result_all(results): |
| 52 | + score = livexiv_aggregation_result(results) |
| 53 | + stored_results = [] |
| 54 | + for result in results: |
| 55 | + stored_results.append({"question_id": result["question_id"], "prediction": result["pred"]}) |
| 56 | + with open("./livexiv_tqa_submission.json", "w") as f: |
| 57 | + json.dump(stored_results, f, indent=4) |
| 58 | + print("Storing files for LiveXiv-TQA submission ...") |
| 59 | + |
| 60 | + return score |
| 61 | + |
| 62 | + |
| 63 | +def livexiv_doc_to_text_mc(doc): |
| 64 | + question = doc["question"] |
| 65 | + return f"{question} Answer :" |
| 66 | + |
| 67 | + |
| 68 | +def livexiv_doc_to_choice(doc): |
| 69 | + return [doc["option_a"], doc["option_b"], doc["option_c"], doc["option_d"]] |
| 70 | + |
| 71 | + |
| 72 | +def livexiv_doc_to_mc_target(doc): |
| 73 | + answer2choice = {"A": "option_a", "B": "option_b", "C": "option_c", "D": "option_d"} |
| 74 | + return doc[answer2choice[doc["answer"]]] |
0 commit comments