|
| 1 | +import json |
| 2 | +import os |
| 3 | +from pathlib import Path |
| 4 | + |
| 5 | +import numpy as np |
| 6 | +import yaml |
| 7 | +from loguru import logger as eval_logger |
| 8 | + |
hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
base_cache_dir = os.path.expanduser(hf_home)

# Read the task YAML by hand and drop any "!function" lines before parsing:
# those lines reference Python callables that yaml.safe_load cannot handle.
with open(Path(__file__).parent / "temporalbench_short_qa.yaml", "r") as f:
    safe_lines = [line for line in f if "!function" not in line]
cache_name = yaml.safe_load("".join(safe_lines))["dataset_kwargs"]["cache_dir"]


# Populated lazily by prep_data(); videoscore_dict is reserved but unused here.
textscore_dict, videoscore_dict = {}, {}
| 21 | + |
| 22 | + |
def prep_data():
    """Load temporalbench_short_qa.json from the HF cache and index it by idx.

    Populates the module-level ``textscore_dict`` mapping each question's
    ``idx`` to its full record, and returns that mapping.
    """
    # NOTE: a leftover debugger call (breakpoint()) was removed here; it
    # halted every evaluation run at this point.
    global textscore_dict, videoscore_dict
    cache_dir = os.path.join(base_cache_dir, cache_name)
    with open(os.path.join(cache_dir, "temporalbench_short_qa.json")) as f:
        textscore_list = json.load(f)
    textscore_dict = {item["idx"]: item for item in textscore_list}
    return textscore_dict
| 33 | + |
| 34 | + |
def temporalbench_doc_to_visual(doc):
    """Resolve the on-disk path of the video referenced by *doc*.

    Args:
        doc: question record containing a ``video_name`` relative to the cache.

    Returns:
        Absolute path to the video file.

    Raises:
        FileNotFoundError: if the video is absent from the local cache.
            (Was a bare ``Exception``; FileNotFoundError is more precise and
            remains an Exception subclass, so existing handlers still match.)
    """
    cache_dir = os.path.join(base_cache_dir, cache_name)
    video_path = os.path.join(cache_dir, doc["video_name"])
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"video path:{video_path} does not exist, please check")
    return video_path
| 41 | + |
| 42 | + |
def temporalbench_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Return the prompt text for *doc* (the raw question string)."""
    question = doc["question"]
    return question
| 45 | + |
| 46 | + |
def temporalbench_process_results(doc, results):
    """Pair the first model response with its source document for aggregation."""
    return {"temporalbench_score": {"item": doc, "pred": results[0]}}
| 52 | + |
| 53 | + |
def temporalbench_caption_aggregate_results(results):
    """Score captions by embedding similarity between predictions and GT.

    Args:
        results: list of ``{"item": <question record>, "pred": <caption str>}``.

    Returns:
        Mean cosine similarity (as a percentage, 0-100) between the
        SentenceTransformer embeddings of each prediction and its ground truth.
    """
    import torch
    from sentence_transformers import SentenceTransformer, util

    preds = [{"idx": d["item"]["idx"], "response": d["pred"]} for d in results]
    id2question = {d["item"]["idx"]: d["item"] for d in results}

    gt_list = [id2question[p["idx"]]["GT"] for p in preds]
    ref_list = [p["response"] for p in preds]

    # Fall back to CPU when no GPU is present; the original hard-coded
    # "cuda:0" and crashed on CPU-only hosts.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

    # Encode refs and GTs as one batch, then split the embedding matrix.
    embeddings = model.encode(ref_list + gt_list, convert_to_tensor=True, device=device)
    ref_embeddings = embeddings[: len(ref_list)]
    gt_embeddings = embeddings[len(ref_list) :]

    # Per-pair cosine similarity is the diagonal of the full similarity matrix.
    cosine_scores = util.cos_sim(ref_embeddings, gt_embeddings).diagonal()
    return cosine_scores.mean().item() * 100
| 88 | + |
| 89 | + |
def temporalbench_aggregate_results(results):
    """Aggregate binary-QA results into overall and per-group accuracies.

    Args:
        results: list of ``{"item": <question record>, "pred": <response str>}``.
            Each item carries ``idx``, ``GT``, ``video_name``, ``dataset`` and
            optionally a non-empty ``category``.

    Returns:
        ``(binary_accuracy, multiple_binary_accuracy)`` as percentages.
        "Multiple binary" accuracy counts a video as correct only when every
        question about that video was answered correctly.
    """
    if not results:
        # Avoid ZeroDivisionError on an empty result set.
        eval_logger.info("temporalbench: no results to aggregate")
        return 0.0, 0.0

    preds = [{"idx": d["item"]["idx"], "response": d["pred"]} for d in results]
    id2question = {d["item"]["idx"]: d["item"] for d in results}

    # Original code tested the *last* loop item only (stale `data` variable);
    # report per-category tables when any item carries a non-empty category.
    has_category = any(item.get("category", "") != "" for item in id2question.values())

    correct_count = 0
    multiple_binary_qa_correct = {}
    binary_qa_per_dataset = {}
    multiple_binary_qa_per_dataset = {}
    binary_qa_per_category = {}
    multiple_binary_qa_per_category = {}

    for pred in preds:
        # Binary QA accuracy: compare the first letter of the response with GT.
        # Slicing ([:1]) makes an empty response count as wrong instead of
        # raising IndexError.
        idx = pred["idx"]
        gt = id2question[idx]["GT"]
        predict_correct = gt.lower() == pred["response"][:1].lower()
        if predict_correct:
            correct_count += 1

        # Multiple Binary QA accuracy: a video stays True only while every
        # one of its questions is answered correctly.
        video_name = id2question[idx]["video_name"]
        if video_name not in multiple_binary_qa_correct:
            multiple_binary_qa_correct[video_name] = True
        if not predict_correct:
            multiple_binary_qa_correct[video_name] = False

        # Per-dataset performance.
        dataset = id2question[idx]["dataset"]
        if dataset not in binary_qa_per_dataset:
            binary_qa_per_dataset[dataset] = []
            multiple_binary_qa_per_dataset[dataset] = {}
        binary_qa_per_dataset[dataset].append(predict_correct)
        if video_name not in multiple_binary_qa_per_dataset[dataset]:
            multiple_binary_qa_per_dataset[dataset][video_name] = True
        if not predict_correct:
            multiple_binary_qa_per_dataset[dataset][video_name] = False

        # Per-category performance (only for items that actually carry one).
        if has_category:
            category = id2question[idx].get("category", "")
            if category:
                if category not in binary_qa_per_category:
                    binary_qa_per_category[category] = []
                    multiple_binary_qa_per_category[category] = {}
                binary_qa_per_category[category].append(predict_correct)
                if video_name not in multiple_binary_qa_per_category[category]:
                    multiple_binary_qa_per_category[category][video_name] = True
                if not predict_correct:
                    multiple_binary_qa_per_category[category][video_name] = False

    # Build the report. Column widths for the aligned table below.
    width_dataset = 40  # for dataset names
    width_counts = 15  # for correct/total counts
    width_percentage = 1  # for percentages
    loginfo = ""
    loginfo += "*" * 20
    Binary_accuracy = correct_count / len(preds) * 100
    loginfo += "\n"
    loginfo += f"{'Binary Accuracy:':<{width_dataset}} {correct_count}/{len(preds):<{width_counts}} {Binary_accuracy:>{width_percentage}.2f}%"
    mba_correct = sum(1 for v in multiple_binary_qa_correct.values() if v)
    Multiple_Binary_accuracy = mba_correct / len(multiple_binary_qa_correct) * 100
    loginfo += "\n"
    loginfo += f"{'Multiple Binary Accuracy:':<{width_dataset}} {mba_correct}/{len(multiple_binary_qa_correct):<{width_counts}} {Multiple_Binary_accuracy:>{width_percentage}.2f}%"
    # Header. The second f-string was previously a dead statement on its own
    # line and never reached the log; it is now concatenated as intended.
    loginfo += "\n"
    loginfo += "+" * 110
    loginfo += "\n"
    loginfo += (
        f"|+++ {'Dataset':<{width_dataset}}Binary Accuracy {'':<{7}} {'':>{width_percentage}} "
        f"||| Multiple Binary Accuracy {'':<{width_counts}} {'':>{width_percentage}}"
    )
    loginfo += "\n"
    loginfo += "+" * 110
    for dataset, binary_qa in binary_qa_per_dataset.items():
        mba_correct = sum(1 for v in multiple_binary_qa_per_dataset[dataset].values() if v)
        loginfo += "\n"
        loginfo += f"|--- {dataset + ' ':<{width_dataset}} {sum(binary_qa)}/{len(binary_qa):<{width_counts}} {sum(binary_qa)/len(binary_qa) * 100:>{width_percentage}.2f}% ||| {mba_correct}/{len(multiple_binary_qa_per_dataset[dataset]):<{width_counts}} {mba_correct/len(multiple_binary_qa_per_dataset[dataset]) * 100:>{width_percentage}.2f}%"

    if has_category:
        loginfo += "\n"
        loginfo += "+" * 110
        loginfo += "\n"
        # Same dead-statement fix as the dataset header above.
        loginfo += (
            f"|-- {'Category':<{width_dataset}}Binary Accuracy {'':<{7}} {'':>{width_percentage}} "
            f"||| Multiple Binary Accuracy {'':<{width_counts}} {'':>{width_percentage}}"
        )
        loginfo += "\n"
        loginfo += "+" * 110
        # Fixed display order for categories (7 "Others" is listed last by
        # design, after 8 "Event Order").
        category_mapping = {
            1: "Action Order",
            2: "Action Frequency",
            3: "Action Type",
            4: "Motion Magnitude",
            5: "Motion Direction/Orientation",
            6: "Action Effector",
            8: "Event Order",
            7: "Others",
        }
        for category_index, category in category_mapping.items():
            if category in binary_qa_per_category:
                binary_qa = binary_qa_per_category[category]
                mba_correct = sum(1 for v in multiple_binary_qa_per_category[category].values() if v)
                loginfo += "\n"
                loginfo += (
                    f"|--- {category + ' ':<{width_dataset}} {sum(binary_qa)}/{len(binary_qa):<{width_counts}} {sum(binary_qa)/len(binary_qa) * 100:>{width_percentage}.2f}% "
                    f"||| {mba_correct}/{len(multiple_binary_qa_per_category[category]):<{width_counts}} {mba_correct/len(multiple_binary_qa_per_category[category]) * 100:>{width_percentage}.2f}%"
                )
    eval_logger.info(loginfo)
    return Binary_accuracy, Multiple_Binary_accuracy
0 commit comments