From b6827627683bd710802c24f8935160b6268c7569 Mon Sep 17 00:00:00 2001 From: KemingWu Date: Thu, 4 Dec 2025 11:37:48 +0000 Subject: [PATCH 1/3] add task gedit-bench --- examples/models/bagel.sh | 95 +++- lmms_eval/api/task.py | 10 +- lmms_eval/models/simple/bagel.py | 319 ++++++++--- lmms_eval/tasks/gedit_bench/__init__.py | 4 + lmms_eval/tasks/gedit_bench/gedit_bench.yaml | 110 ++++ lmms_eval/tasks/gedit_bench/secret.env | 0 lmms_eval/tasks/gedit_bench/utils.py | 526 ++++++++++++++++++ .../tasks/gedit_bench/viescore/__init__.py | 141 +++++ .../viescore/mllm_tools/__init__.py | 0 .../gedit_bench/viescore/mllm_tools/gemini.py | 154 +++++ .../viescore/mllm_tools/idefics2_eval.py | 41 ++ .../mllm_tools/mantis_idefics2_eval.py | 41 ++ .../viescore/mllm_tools/minicpmv_eval.py | 42 ++ .../gedit_bench/viescore/mllm_tools/openai.py | 169 ++++++ .../viescore/mllm_tools/qwen25vl_eval.py | 85 +++ .../gedit_bench/viescore/mllm_tools/utils.py | 70 +++ .../viescore/mllm_tools/vllm_qwen_eval.py | 220 ++++++++ .../gedit_bench/viescore/parse_prompt.py | 22 + lmms_eval/tasks/gedit_bench/viescore/utils.py | 369 ++++++++++++ .../tasks/gedit_bench/viescore/vie_prompts.py | 405 ++++++++++++++ 20 files changed, 2749 insertions(+), 74 deletions(-) create mode 100644 lmms_eval/tasks/gedit_bench/__init__.py create mode 100644 lmms_eval/tasks/gedit_bench/gedit_bench.yaml create mode 100644 lmms_eval/tasks/gedit_bench/secret.env create mode 100644 lmms_eval/tasks/gedit_bench/utils.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/__init__.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/__init__.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/gemini.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/idefics2_eval.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/mantis_idefics2_eval.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/minicpmv_eval.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/openai.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/qwen25vl_eval.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/utils.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/mllm_tools/vllm_qwen_eval.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/parse_prompt.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/utils.py create mode 100644 lmms_eval/tasks/gedit_bench/viescore/vie_prompts.py diff --git a/examples/models/bagel.sh b/examples/models/bagel.sh index fee4e6ee0..0ed9ceaef 100644 --- a/examples/models/bagel.sh +++ b/examples/models/bagel.sh @@ -1,9 +1,9 @@ #!/bin/bash -# Bagel Model Evaluation Script +# Bagel Model Evaluation Script for GEdit-Bench # # This script demonstrates how to run lmms-eval with the Bagel multimodal model -# for text-to-image generation tasks. +# for image editing tasks using GEdit-Bench. # # Prerequisites: # 1. Clone Bagel repository at lmms-eval root: @@ -14,19 +14,92 @@ # Download from https://huggingface.co/ByteDance-Seed/BAGEL-7B-MoT # # Usage: -# bash examples/models/bagel.sh +# # Use local Qwen2.5-VL for evaluation: +# bash examples/models/bagel.sh qwen25vl +# +# # Use vLLM remote Qwen for evaluation: +# bash examples/models/bagel.sh vllm_qwen +# +# # Use GPT-4o for evaluation: +# bash examples/models/bagel.sh gpt4o + +# Activate conda environment (uncomment and modify if needed) +# source miniconda3/etc/profile.d/conda.sh +# conda activate lmms-eval + +# ============================================ +# Configuration +# ============================================ + +MODEL_PATH=/your/path/to/models/BAGEL-7B-MoT +TASK=gedit_bench + +# GEdit-Bench environment variables +export GEDIT_BENCH_MODEL_NAME="bagel" +export GEDIT_BENCH_OUTPUT_DIR="./logs/bagel_persistent_folder/bagel_generated_images" +export GEDIT_BENCH_VIE_KEY_PATH="./lmms_eval/tasks/gedit_bench/secret.env" -# Set model path - should point to the model weights directory -# Can be absolute path or relative path -MODEL_PATH=$1 -export GOOGLE_API_KEY= -TASK=$2 +# ============================================ +# Evaluation Backend Selection +# ============================================ -# Run evaluation with BFloat16 (default, full precision) - accelerate launch -m lmms_eval \ +# Get backend from command line argument, default to "qwen25vl" +EVAL_BACKBONE=${1:-vllm_qwen25vl} + +if [ "$EVAL_BACKBONE" == "vllm_qwen" ] || [ "$EVAL_BACKBONE" == "vllm_qwen25vl" ] || [ "$EVAL_BACKBONE" == "vllm_qwen3vl" ]; then + echo "Using vLLM Qwen for VIEScore evaluation..." + export GEDIT_BENCH_VIE_BACKBONE="$EVAL_BACKBONE" + # vLLM API settings - modify these for your setup + export VLLM_API_BASE="http://host:8000/v1" + # export VLLM_API_BASE="${VLLM_API_BASE:-http://localhost:8000/v1}" + export VLLM_API_KEY="${VLLM_API_KEY:-EMPTY}" + export VLLM_MODEL_NAME="${VLLM_MODEL_NAME:-Qwen/Qwen2.5-VL-72B-Instruct-AWQ}" + echo " VLLM_API_BASE: $VLLM_API_BASE" + echo " VLLM_MODEL_NAME: $VLLM_MODEL_NAME" +elif [ "$EVAL_BACKBONE" == "gpt4o" ]; then + echo "Using GPT-4o for VIEScore evaluation..." + export GEDIT_BENCH_VIE_BACKBONE="gpt4o" + # Set your OpenAI API key + # export OPENAI_API_KEY="your-api-key-here" +else + echo "Using local Qwen2.5-VL for VIEScore evaluation..." + export GEDIT_BENCH_VIE_BACKBONE="qwen25vl" +fi + +# ============================================ +# Run Evaluation +# ============================================ + +echo "============================================" +echo "Starting GEdit-Bench evaluation..." +echo "============================================" +echo " Model: Bagel" +echo " Model Path: $MODEL_PATH" +echo " Evaluation Backend: $GEDIT_BENCH_VIE_BACKBONE" +echo " Output Directory: $GEDIT_BENCH_OUTPUT_DIR" +echo "============================================" +echo "" + +# 图像编辑任务 (GEdit-Bench) +# task_mode=edit: 输入图像 + 编辑指令 -> 编辑后的图像 +accelerate launch -m lmms_eval \ --model bagel \ - --model_args pretrained=${MODEL_PATH},mode=1 \ + --model_args pretrained=${MODEL_PATH},task_mode=edit \ --tasks $TASK \ --batch_size 1 \ --log_samples \ --output_path ./logs/ + +echo "" +echo "============================================" +echo "Evaluation complete!" +echo "============================================" + +# 如果是文本生图任务,使用 task_mode=generate: +# accelerate launch -m lmms_eval \ +# --model bagel \ +# --model_args pretrained=${MODEL_PATH},task_mode=generate \ +# --tasks ueval \ +# --batch_size 1 \ +# --log_samples \ +# --output_path ./logs/ diff --git a/lmms_eval/api/task.py b/lmms_eval/api/task.py index ed9aa7576..73309ca38 100755 --- a/lmms_eval/api/task.py +++ b/lmms_eval/api/task.py @@ -1057,13 +1057,21 @@ def concat_tar_parts(tar_parts, output_tar): **dataset_kwargs if dataset_kwargs is not None else {}, ) + # Ensure dataset is a DatasetDict so downstream logic that expects multiple splits works. + if not isinstance(self.dataset, datasets.DatasetDict): + split_name = self.config.test_split or self.config.validation_split or self.config.training_split or "train" + self.dataset = datasets.DatasetDict({split_name: self.dataset}) + if self.config.process_docs is not None: for split in self.dataset: if split in [self.config.training_split, self.config.validation_split, self.config.test_split, self.config.fewshot_split]: self.dataset[split] = self.config.process_docs(self.dataset[split]) # copy dataset, remove image features - self.dataset_no_image = self.dataset.copy() + try: + self.dataset_no_image = self.dataset.copy() + except AttributeError: + self.dataset_no_image = datasets.DatasetDict({k: v for k, v in self.dataset.items()}) for doc_name in self.dataset_no_image: remove_cols = [] features = self.dataset_no_image[doc_name].features diff --git a/lmms_eval/models/simple/bagel.py b/lmms_eval/models/simple/bagel.py index 8d95dbecc..82df1acc4 100644 --- a/lmms_eval/models/simple/bagel.py +++ b/lmms_eval/models/simple/bagel.py @@ -6,12 +6,7 @@ import numpy as np import torch -from accelerate import ( - Accelerator, - infer_auto_device_map, - init_empty_weights, - load_checkpoint_and_dispatch, -) +from accelerate import Accelerator, init_empty_weights, load_checkpoint_and_dispatch from loguru import logger as eval_logger from tqdm import tqdm @@ -30,19 +25,84 @@ eval_logger.warning(f"Bagel repository not found at {bagel_path}. " f"Please clone it: cd {wd} && git clone https://github.com/ByteDance-Seed/Bagel.git") +def _check_bagel_modifications(bagel_path: str) -> bool: + """ + Check if the required modifications have been made to the Bagel repository. + + Returns True if modifications are detected, False otherwise. + """ + modifications_needed = [] + + # Check 1: Bagel/modeling/bagel/bagel.py - forward_cache_update_vae dtype/device fix + bagel_model_file = os.path.join(bagel_path, "modeling/bagel/bagel.py") + if os.path.exists(bagel_model_file): + with open(bagel_model_file, "r") as f: + content = f.read() + if "padded_images.to(device=vae_model.encoder.conv_in.weight.device" not in content: + modifications_needed.append("Bagel/modeling/bagel/bagel.py: forward_cache_update_vae() method needs dtype/device fix for VAE encode") + + # Check 2: Bagel/inferencer.py - decode_image dtype/device fix + inferencer_file = os.path.join(bagel_path, "inferencer.py") + if os.path.exists(inferencer_file): + with open(inferencer_file, "r") as f: + content = f.read() + if "latent.to(device=self.vae_model.decoder.conv_in.weight.device" not in content: + modifications_needed.append("Bagel/inferencer.py: decode_image() method needs dtype/device fix for VAE decode") + + if modifications_needed: + eval_logger.warning("=" * 80) + eval_logger.warning("IMPORTANT: Bagel repository requires modifications to work with lmms-eval!") + eval_logger.warning("=" * 80) + eval_logger.warning("") + eval_logger.warning("The following modifications are needed to fix dtype/device mismatch errors:") + eval_logger.warning("") + for i, mod in enumerate(modifications_needed, 1): + eval_logger.warning(f" {i}. {mod}") + eval_logger.warning("") + eval_logger.warning("Required changes:") + eval_logger.warning("") + eval_logger.warning("1. In Bagel/modeling/bagel/bagel.py, find forward_cache_update_vae() method and add:") + eval_logger.warning(" BEFORE: padded_latent = vae_model.encode(padded_images)") + eval_logger.warning(" AFTER: padded_images = padded_images.to(device=vae_model.encoder.conv_in.weight.device, dtype=vae_model.encoder.conv_in.weight.dtype)") + eval_logger.warning(" padded_latent = vae_model.encode(padded_images)") + eval_logger.warning("") + eval_logger.warning("2. In Bagel/inferencer.py, find decode_image() method and add:") + eval_logger.warning(" BEFORE: image = self.vae_model.decode(latent)") + eval_logger.warning(" AFTER: latent = latent.to(device=self.vae_model.decoder.conv_in.weight.device, dtype=self.vae_model.decoder.conv_in.weight.dtype)") + eval_logger.warning(" image = self.vae_model.decode(latent)") + eval_logger.warning("") + eval_logger.warning("=" * 80) + return False + return True + + +# Check for required Bagel modifications when module is loaded +if os.path.exists(bagel_path): + _check_bagel_modifications(bagel_path) + + @register_model("bagel") class Bagel(lmms): """ Bagel Multimodal Model - Supports text-to-image generation with optional thinking process + Supports text-to-image generation and image editing with optional thinking process Example usage: + # Text-to-Image Generation accelerate launch -m lmms_eval \ --model bagel \ - --model_args pretrained=/path/to/BAGEL-7B-MoT,mode=1 \ + --model_args pretrained=/path/to/BAGEL-7B-MoT,task_mode=generate \ --tasks ueval \ --batch_size 1 \ --output_path ./logs/ + + # Image Editing (e.g., GEdit-Bench) + accelerate launch -m lmms_eval \ + --model bagel \ + --model_args pretrained=/path/to/BAGEL-7B-MoT,task_mode=edit \ + --tasks gedit_bench \ + --batch_size 1 \ + --output_path ./logs/ """ def __init__( @@ -52,7 +112,9 @@ def __init__( load_in_8bit: bool = False, output_image_dir: Optional[str] = None, show_thinking: bool = False, + task_mode: str = "generate", # "generate" for text-to-image, "edit" for image editing cfg_text_scale: float = 4.0, + cfg_img_scale: float = 2.0, # Image guidance scale for edit tasks cfg_interval: float = 0.4, timestep_shift: float = 3.0, num_timesteps: int = 50, @@ -113,6 +175,13 @@ def __init__( self.load_in_8bit = load_in_8bit self.show_thinking = show_thinking self.continual_mode = continual_mode + self.task_mode = task_mode # "generate" or "edit" + + # Validate task mode + if task_mode not in ["generate", "edit"]: + raise ValueError(f"Invalid task_mode: {task_mode}. Must be 'generate' or 'edit'") + + eval_logger.info(f"Bagel task_mode: {task_mode}") # Validate quantization settings if load_in_4bit and load_in_8bit: @@ -128,11 +197,16 @@ def __init__( # Generation hyperparameters self.cfg_text_scale = cfg_text_scale + self.cfg_img_scale = cfg_img_scale # For edit tasks self.cfg_interval = cfg_interval self.timestep_shift = timestep_shift self.num_timesteps = num_timesteps self.cfg_renorm_min = cfg_renorm_min - self.cfg_renorm_type = cfg_renorm_type + # Use different default cfg_renorm_type based on task_mode + if task_mode == "edit" and cfg_renorm_type == "global": + self.cfg_renorm_type = "text_channel" # Better for edit tasks + else: + self.cfg_renorm_type = cfg_renorm_type self.max_think_token_n = max_think_token_n self.do_sample = do_sample self.text_temperature = text_temperature @@ -247,38 +321,25 @@ def _load_model(self): vae_transform = self.ImageTransform(1024, 512, 16) vit_transform = self.ImageTransform(980, 224, 14) - # Setup device map for multi-GPU - device_map = infer_auto_device_map( - model, - max_memory={i: "80GiB" for i in range(torch.cuda.device_count())}, - no_split_module_classes=["Bagel", "Qwen2MoTDecoderLayer"], - ) - - # Ensure certain modules are on the same device - same_device_modules = ["language_model.model.embed_tokens", "time_embedder", "latent_pos_embed", "vae2llm", "llm2vae", "connector", "vit_pos_embed"] - - if torch.cuda.device_count() == 1: - first_device = device_map.get(same_device_modules[0], "cuda:0") - for k in same_device_modules: - device_map[k] = first_device if k in device_map else "cuda:0" - else: - first_device = device_map.get(same_device_modules[0]) - for k in same_device_modules: - if k in device_map: - device_map[k] = first_device - # Load checkpoint based on precision mode checkpoint_path = os.path.join(model_path, "ema.safetensors") + local_rank = self._rank + if hasattr(self, "accelerator") and self.accelerator is not None: + local_rank = self.accelerator.local_process_index + device = f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu" + eval_logger.info(f"Loading model to {device}") + + device_map = {"": device} + if self.precision_mode == "bf16": - # BF16: Full precision + inference_dtype = torch.bfloat16 model = load_checkpoint_and_dispatch( model, checkpoint=checkpoint_path, device_map=device_map, - offload_buffers=True, - offload_folder="offload", - dtype=torch.bfloat16, + offload_buffers=False, + dtype=inference_dtype, force_hooks=True, ).eval() eval_logger.info("Loaded model in BFloat16 precision") @@ -302,6 +363,7 @@ def _load_model(self): eval_logger.info("Loaded model in 4-bit (NF4) quantization") except ImportError: raise ImportError("4-bit quantization requires bitsandbytes. " "Install it with: pip install bitsandbytes") + inference_dtype = torch.bfloat16 elif self.precision_mode == "8bit": # INT8: 8-bit quantization @@ -322,10 +384,15 @@ def _load_model(self): eval_logger.info("Loaded model in 8-bit (INT8) quantization") except ImportError: raise ImportError("8-bit quantization requires bitsandbytes. " "Install it with: pip install bitsandbytes") + inference_dtype = torch.float32 else: raise ValueError(f"Unknown precision mode: {self.precision_mode}") + # Move VAE model to the same device/dtype as main model + vae_model = vae_model.to(device, dtype=inference_dtype) + eval_logger.info(f"Moved VAE model to {device} (dtype={inference_dtype})") + # Create inferencer self.inferencer = self.InterleaveInferencer( model=model, @@ -338,6 +405,7 @@ def _load_model(self): self._model = model self._tokenizer = tokenizer + self._device = device @property def rank(self): @@ -355,6 +423,10 @@ def model(self): def tokenizer(self): return self._tokenizer + @property + def device(self): + return self._device + def set_seed(self, seed: int): """Set random seeds for reproducibility""" if seed > 0: @@ -369,49 +441,126 @@ def set_seed(self, seed: int): torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - def generate_text_and_image(self, prompt: str, doc_id: str, task: str) -> Tuple[str, List[str]]: + def generate_text_and_image(self, prompt: str, doc_id: str, task: str, input_image=None, key: str = None, task_type: str = None, instruction_language: str = None, source_image=None, edit_type: str = None) -> Tuple[str, List[str]]: """ - Generate text and image from prompt + Generate or edit image based on prompt and optional input image Args: - prompt: Input text prompt + prompt: Input text prompt (for generation) or editing instruction (for edit) doc_id: Document ID for file naming task: Task name for file naming + input_image: Optional PIL Image for editing tasks + key: Unique key for naming (used by GEdit-Bench and ImgEdit) + task_type: Task type for GEdit-Bench (e.g., "background_change") + instruction_language: Language for GEdit-Bench ("en" or "cn") + source_image: Original source image to save as _SRCIMG + edit_type: Edit type for ImgEdit (e.g., "replace", "add", "adjust") Returns: Tuple of (generated_text, list_of_image_paths) """ self.set_seed(self.seed) - # Prepare inference hyperparameters - inference_hyper = { - "max_think_token_n": self.max_think_token_n if self.show_thinking else 1024, - "do_sample": self.do_sample if self.show_thinking else False, - "text_temperature": self.text_temperature if self.show_thinking else 0.3, - "cfg_text_scale": self.cfg_text_scale, - "cfg_interval": [self.cfg_interval, 1.0], - "timestep_shift": self.timestep_shift, - "num_timesteps": self.num_timesteps, - "cfg_renorm_min": self.cfg_renorm_min, - "cfg_renorm_type": self.cfg_renorm_type, - "image_shapes": self.image_shapes, - } - - # Generate - result = self.inferencer(text=prompt, think=self.show_thinking, **inference_hyper) + if self.task_mode == "edit": + # Image editing mode + if input_image is None: + eval_logger.warning(f"Edit mode but no input image provided for doc_id {doc_id}") + return "", [] + + # Prepare edit hyperparameters (different from generation) + inference_hyper = { + "max_think_token_n": self.max_think_token_n if self.show_thinking else 1024, + "do_sample": self.do_sample if self.show_thinking else False, + "text_temperature": self.text_temperature if self.show_thinking else 0.3, + "cfg_text_scale": self.cfg_text_scale, + "cfg_img_scale": self.cfg_img_scale, + "cfg_interval": [0.0, 1.0], # Edit tasks use [0.0, 1.0] + "timestep_shift": self.timestep_shift, + "num_timesteps": self.num_timesteps, + "cfg_renorm_min": self.cfg_renorm_min, + "cfg_renorm_type": self.cfg_renorm_type, + } + + # Ensure input_image is RGB + if hasattr(input_image, "convert"): + input_image = input_image.convert("RGB") + + # Generate edited image + result = self.inferencer(image=input_image, text=prompt, think=self.show_thinking, **inference_hyper) + else: + # Text-to-image generation mode + # Prepare generation hyperparameters + inference_hyper = { + "max_think_token_n": self.max_think_token_n if self.show_thinking else 1024, + "do_sample": self.do_sample if self.show_thinking else False, + "text_temperature": self.text_temperature if self.show_thinking else 0.3, + "cfg_text_scale": self.cfg_text_scale, + "cfg_interval": [self.cfg_interval, 1.0], + "timestep_shift": self.timestep_shift, + "num_timesteps": self.num_timesteps, + "cfg_renorm_min": self.cfg_renorm_min, + "cfg_renorm_type": self.cfg_renorm_type, + "image_shapes": self.image_shapes, + } + + # Generate new image + result = self.inferencer(text=prompt, think=self.show_thinking, **inference_hyper) # Extract text output_text = result.get("text", "") - # Save image + # Save image based on task type output_images = [] if "image" in result and result["image"] is not None: image = result["image"] - safe_filename = f"{task}_{doc_id}.png" - image_path = os.path.join(self.output_image_dir, safe_filename) - image.save(image_path) - output_images.append(image_path) - eval_logger.info(f"Saved image: {image_path}") + + # Check if this is ImgEdit task (has IMGEDIT_MODEL_NAME env var or edit_type) + imgedit_model_name = os.getenv("IMGEDIT_MODEL_NAME") + gedit_model_name = os.getenv("GEDIT_BENCH_MODEL_NAME") + + if imgedit_model_name and key and edit_type: + # ImgEdit style path: {output_dir}/{model_name}/{key}.png + save_dir = os.path.join(self.output_image_dir, imgedit_model_name) + os.makedirs(save_dir, exist_ok=True) + + # Save generated image + image_path = os.path.join(save_dir, f"{key}.png") + image.save(image_path) + output_images.append(image_path) + eval_logger.info(f"Saved ImgEdit image: {image_path}") + + # Save source image as _SRCIMG if provided + if source_image is not None: + src_image_path = os.path.join(save_dir, f"{key}_SRCIMG.png") + if hasattr(source_image, "save"): + source_image.save(src_image_path) + eval_logger.info(f"Saved source image: {src_image_path}") + + elif key and task_type and instruction_language: + # GEdit-Bench style path: {output_dir}/{model_name}/fullset/{task_type}/{instruction_language}/{key}.png + model_name = gedit_model_name or "bagel" + save_dir = os.path.join(self.output_image_dir, model_name, "fullset", task_type, instruction_language) + os.makedirs(save_dir, exist_ok=True) + + # Save generated image + image_path = os.path.join(save_dir, f"{key}.png") + image.save(image_path) + output_images.append(image_path) + eval_logger.info(f"Saved GEdit-Bench image: {image_path}") + + # Save source image as _SRCIMG if provided + if source_image is not None: + src_image_path = os.path.join(save_dir, f"{key}_SRCIMG.png") + if hasattr(source_image, "save"): + source_image.save(src_image_path) + eval_logger.info(f"Saved source image: {src_image_path}") + else: + # Fallback to simple naming + safe_filename = f"{task}_{doc_id}.png" + image_path = os.path.join(self.output_image_dir, safe_filename) + image.save(image_path) + output_images.append(image_path) + eval_logger.info(f"Saved image: {image_path}") return output_text, output_images @@ -423,12 +572,15 @@ def format_output(self, text: str, images: List[str]) -> str: def generate_until(self, requests: List[Instance]) -> List[str]: """Main inference method""" res = [] - pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Bagel Generating") + desc = "Bagel Editing" if self.task_mode == "edit" else "Bagel Generating" + pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc=desc) def get_uuid(task, split, doc_id): return f"{task}___{split}___{doc_id}" - for contexts, _, _, doc_id, task, split in [reg.args for reg in requests]: + for req in requests: + # Unpack arguments: (ctx, generation_kwargs, doc_to_visual, doc_id, task, split) + contexts, gen_kwargs, doc_to_visual, doc_id, task, split = req.args doc_uuid = get_uuid(task, split, doc_id) # Check cache @@ -440,9 +592,52 @@ def get_uuid(task, split, doc_id): pbar.update(1) continue - # Generate + # Get input image and doc metadata for edit tasks + input_image = None + source_image = None + key = None + task_type = None + instruction_language = None + edit_type = None # For ImgEdit + doc = None + + # Try to get the document from task dataset + try: + doc = self.task_dict[task][split][doc_id] + except Exception as e: + eval_logger.debug(f"Could not get doc for doc_id {doc_id}: {e}") + + if self.task_mode == "edit": + # doc_to_visual is a function that returns list of images + if callable(doc_to_visual) and doc is not None: + try: + visuals = doc_to_visual(doc) + if visuals and len(visuals) > 0: + input_image = visuals[0] + if hasattr(input_image, "convert"): + input_image = input_image.convert("RGB") + eval_logger.debug(f"Got input image for doc_id {doc_id}") + except Exception as e: + eval_logger.warning(f"Failed to get input image for doc_id {doc_id}: {e}") + + # Extract task-specific fields from doc + if doc is not None: + key = doc.get("key", str(doc_id)) + # GEdit-Bench specific fields + task_type = doc.get("task_type", "unknown") + instruction_language = doc.get("instruction_language", "en") + # ImgEdit specific fields + edit_type = doc.get("edit_type") # e.g., "replace", "add", "adjust" + # Get source image (original un-resized image) for saving as _SRCIMG + source_image = doc.get("input_image") or doc.get("input_image_raw") + if source_image and hasattr(source_image, "convert"): + source_image = source_image.convert("RGB") + + # Generate/Edit prompt = contexts - output_text, output_images = self.generate_text_and_image(prompt, str(doc_id), task) + output_text, output_images = self.generate_text_and_image( + prompt, str(doc_id), task, input_image=input_image, key=key, task_type=task_type, instruction_language=instruction_language, source_image=source_image, edit_type=edit_type + ) # Format output formatted_output = self.format_output(output_text, output_images) diff --git a/lmms_eval/tasks/gedit_bench/__init__.py b/lmms_eval/tasks/gedit_bench/__init__.py new file mode 100644 index 000000000..fb8c1a656 --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/__init__.py @@ -0,0 +1,4 @@ +""" +GEdit-Bench task package initializer. +Ensures relative imports (e.g. from .viescore import VIEScore) work correctly. +""" diff --git a/lmms_eval/tasks/gedit_bench/gedit_bench.yaml b/lmms_eval/tasks/gedit_bench/gedit_bench.yaml new file mode 100644 index 000000000..141b0db8c --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/gedit_bench.yaml @@ -0,0 +1,110 @@ +dataset_path: stepfun-ai/GEdit-Bench +dataset_kwargs: + token: True + load_from_disk: True +task: "gedit_bench" +test_split: train +output_type: generate_until +doc_to_visual: !function utils.gedit_bench_doc_to_visual +doc_to_text: !function utils.gedit_bench_doc_to_text +doc_to_target: "instruction" +generation_kwargs: + max_new_tokens: 512 + temperature: 0 + top_p: 1.0 + num_beams: 1 + do_sample: false +# The return value of process_results will be used by metrics +process_results: !function utils.gedit_bench_process_results + +# Metrics breakdown: +# - Overall scores (all samples) +# - By language: en (English), cn (Chinese) +# - By subset: fullset (all), intersection (subset with both en and cn) +# - Score types: semantics, quality, overall + +metric_list: + # ========================================== + # Overall Scores (All Samples) + # ========================================== + - metric: gedit_bench_semantics_score + aggregation: !function utils.gedit_bench_aggregate_results + higher_is_better: true + - metric: gedit_bench_quality_score + aggregation: !function utils.gedit_bench_aggregate_results + higher_is_better: true + - metric: gedit_bench_overall_score + aggregation: !function utils.gedit_bench_aggregate_results + higher_is_better: true + + # ========================================== + # English - Fullset + # ========================================== + - metric: gedit_bench_semantics_score + aggregation: !function utils.gedit_bench_aggregate_en_fullset + higher_is_better: true + metric_name: gedit_bench_en_fullset_semantics + - metric: gedit_bench_quality_score + aggregation: !function utils.gedit_bench_aggregate_en_fullset + higher_is_better: true + metric_name: gedit_bench_en_fullset_quality + - metric: gedit_bench_overall_score + aggregation: !function utils.gedit_bench_aggregate_en_fullset + higher_is_better: true + metric_name: gedit_bench_en_fullset_overall + + # ========================================== + # English - Intersection + # ========================================== + - metric: gedit_bench_semantics_score + aggregation: !function utils.gedit_bench_aggregate_en_intersection + higher_is_better: true + metric_name: gedit_bench_en_intersection_semantics + - metric: gedit_bench_quality_score + aggregation: !function utils.gedit_bench_aggregate_en_intersection + higher_is_better: true + metric_name: gedit_bench_en_intersection_quality + - metric: gedit_bench_overall_score + aggregation: !function utils.gedit_bench_aggregate_en_intersection + higher_is_better: true + metric_name: gedit_bench_en_intersection_overall + + # ========================================== + # Chinese - Fullset + # ========================================== + - metric: gedit_bench_semantics_score + aggregation: !function utils.gedit_bench_aggregate_cn_fullset + higher_is_better: true + metric_name: gedit_bench_cn_fullset_semantics + - metric: gedit_bench_quality_score + aggregation: !function utils.gedit_bench_aggregate_cn_fullset + higher_is_better: true + metric_name: gedit_bench_cn_fullset_quality + - metric: gedit_bench_overall_score + aggregation: !function utils.gedit_bench_aggregate_cn_fullset + higher_is_better: true + metric_name: gedit_bench_cn_fullset_overall + + # ========================================== + # Chinese - Intersection + # ========================================== + - metric: gedit_bench_semantics_score + aggregation: !function utils.gedit_bench_aggregate_cn_intersection + higher_is_better: true + metric_name: gedit_bench_cn_intersection_semantics + - metric: gedit_bench_quality_score + aggregation: !function utils.gedit_bench_aggregate_cn_intersection + higher_is_better: true + metric_name: gedit_bench_cn_intersection_quality + - metric: gedit_bench_overall_score + aggregation: !function utils.gedit_bench_aggregate_cn_intersection + higher_is_better: true + metric_name: gedit_bench_cn_intersection_overall + +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "" +metadata: + - version: 0.1 + description: "GEdit-Bench with detailed metrics by language and subset" diff --git a/lmms_eval/tasks/gedit_bench/secret.env b/lmms_eval/tasks/gedit_bench/secret.env new file mode 100644 index 000000000..e69de29bb diff --git a/lmms_eval/tasks/gedit_bench/utils.py b/lmms_eval/tasks/gedit_bench/utils.py new file mode 100644 index 000000000..b5664315a --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/utils.py @@ -0,0 +1,526 @@ +""" +GEdit-Bench Utils +Image editing evaluation task using VIEScore +""" + +import json +import math +import os +import shutil +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional + +import numpy as np +from loguru import logger as eval_logger +from PIL import Image + +# Try to import VIEScore +try: + from lmms_eval.tasks.gedit_bench.viescore import VIEScore + + VIESCORE_AVAILABLE = True +except ImportError: + VIESCORE_AVAILABLE = False + eval_logger.warning("VIEScore not available. Please install it: pip install viescore. " "Evaluation scores will not be computed.") + +# Task groups for GEdit-Bench +GEDIT_BENCH_GROUPS = [ + "background_change", + "color_alter", + "material_alter", + "motion_change", + "ps_human", + "style_change", + "subject-add", + "subject-remove", + "subject-replace", + "text_change", + "tone_transfer", +] + + +def calculate_dimensions(target_area, ratio): + """Calculate dimensions maintaining aspect ratio""" + width = math.sqrt(target_area * ratio) + height = width / ratio + new_area = width * height + return int(width), int(height), int(new_area) + + +def _get_vie_score(backbone: str = "gpt4o", key_path: Optional[str] = None): + """ + Get or create VIEScore instance. + Note: In multi-process environments, each process will have its own instance. + """ + if not VIESCORE_AVAILABLE: + raise ImportError("VIEScore is not available. Please install it: pip install viescore") + + # Create a new instance each time (safe for multi-process) + # VIEScore initialization is relatively lightweight + return VIEScore(backbone=backbone, task="tie", key_path=key_path) + + +def gedit_bench_doc_to_visual(doc): + """Extract input image from document""" + # Try different possible field names + input_image = doc.get("input_image") or doc.get("input_image_raw") + if input_image is None: + eval_logger.warning(f"No input image found in document. Available keys: {list(doc.keys())}") + return [] + # Convert to RGB if it's a PIL Image + if hasattr(input_image, "convert"): + return [input_image.convert("RGB")] + return [input_image] + + +def gedit_bench_doc_to_text(doc, lmms_eval_specific_kwargs=None): + """Extract instruction text from document""" + instruction = doc.get("instruction", "").strip() + pre_prompt = "" + post_prompt = "" + if lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") + post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") + return f"{pre_prompt}{instruction}{post_prompt}" + + +def gedit_bench_doc_to_target(doc): + """Extract target instruction (for reference)""" + return doc.get("instruction", "") + + +def _save_image_to_structure( + image_path: str, + key: str, + task_type: str, + instruction_language: str, + output_base_dir: str, + model_name: str = "default", +) -> str: + """ + Save image to the required directory structure: + results/{model_name}/fullset/{task_type}/{instruction_language}/{key}.png + + Args: + image_path: Path to the generated image + key: Unique key for this sample + task_type: Task type (e.g., "background_change") + instruction_language: Language of instruction ("en" or "cn") + output_base_dir: Base directory for outputs + model_name: Name of the model being evaluated + + Returns: + Path to the saved image + """ + # Create directory structure + save_dir = os.path.join(output_base_dir, model_name, "fullset", task_type, instruction_language) + os.makedirs(save_dir, exist_ok=True) + + # Save image with key as filename (preserve extension if exists) + if os.path.exists(image_path): + # Copy image to new location + save_path = os.path.join(save_dir, f"{key}.png") + shutil.copy2(image_path, save_path) + return save_path + else: + eval_logger.warning(f"Image not found at {image_path}, skipping save") + return "" + + +def gedit_bench_process_results(doc, results, **kwargs): + """ + Process model predictions: + 1. Parse JSON output to extract text and images + 2. Save images to required directory structure + 3. Evaluate using VIEScore + + Args: + doc: Document containing input image, instruction, key, task_type, etc. + results: Model predictions [JSON string with {"text": "...", "images": [...]}] + **kwargs: Additional arguments (may include full_docs) + + Returns: + Dict with metrics: semantics_score, quality_score, overall_score + """ + # Get configuration from environment variables or use defaults + # Note: defaults should match bagel.py's output_image_dir structure + model_name = os.getenv("GEDIT_BENCH_MODEL_NAME", "bagel") + output_base_dir = os.getenv("GEDIT_BENCH_OUTPUT_DIR", "./logs/bagel_persistent_folder/bagel_generated_images") + vie_backbone = os.getenv("GEDIT_BENCH_VIE_BACKBONE", "gpt4o") + vie_key_path = os.getenv("GEDIT_BENCH_VIE_KEY_PATH", None) + pred = results[0] if results else "{}" + try: + pred = json.loads(pred) + except (json.JSONDecodeError, TypeError): + eval_logger.warning(f"Failed to parse prediction JSON: {pred}") + pred = {"text": "", "images": []} + + model_text = pred.get("text", "") + model_images = pred.get("images", []) + + # Extract document fields + key = doc.get("key", "unknown") + task_type = doc.get("task_type", "unknown") + instruction = doc.get("instruction", "") + instruction_language = doc.get("instruction_language", "en") + intersection_exist = doc.get("Intersection_exist", False) + + # Get input image (try different possible field names) + input_image = doc.get("input_image") or doc.get("input_image_raw") + + # If input_image is None, try to load from saved _SRCIMG file + # This happens when process_results is called after generation, and the doc doesn't contain PIL image data + input_image_pil = None + if input_image is not None: + input_image_pil = input_image.convert("RGB") if hasattr(input_image, "convert") else input_image + else: + # Try to load from _SRCIMG file saved during generation + src_img_path = os.path.join(output_base_dir, model_name, "fullset", task_type, instruction_language, f"{key}_SRCIMG.png") + if os.path.exists(src_img_path): + try: + input_image_pil = Image.open(src_img_path).convert("RGB") + eval_logger.debug(f"Loaded source image from {src_img_path}") + except Exception as e: + eval_logger.warning(f"Failed to load source image from {src_img_path}: {e}") + + if input_image_pil is None: + eval_logger.warning(f"No input image found for key {key} (neither in doc nor as _SRCIMG file)") + return { + "gedit_bench_semantics_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + "gedit_bench_quality_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + "gedit_bench_overall_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + } + + # Save generated images to required structure (or use existing) + edited_image_path = None + if model_images and len(model_images) > 0: + # Use first generated image + generated_image_path = model_images[0] + # Check if the image is already at the target location + if os.path.exists(generated_image_path): + edited_image_path = generated_image_path + # Also copy to standard structure if not already there + target_path = os.path.join(output_base_dir, model_name, "fullset", task_type, instruction_language, f"{key}.png") + if generated_image_path != target_path and not os.path.exists(target_path): + edited_image_path = _save_image_to_structure( + generated_image_path, + key, + task_type, + instruction_language, + output_base_dir, + model_name, + ) + else: + eval_logger.warning(f"Generated image not found at {generated_image_path}") + + # If no image from model results, try to find existing generated image in the standard location + if edited_image_path is None: + existing_path = os.path.join(output_base_dir, model_name, "fullset", task_type, instruction_language, f"{key}.png") + if os.path.exists(existing_path): + edited_image_path = existing_path + eval_logger.debug(f"Found existing generated image at {existing_path}") + + # If still no edited image, return zero scores + if edited_image_path is None: + eval_logger.warning(f"No generated images found for key {key}") + return { + "gedit_bench_semantics_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + "gedit_bench_quality_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + "gedit_bench_overall_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + } + + # Evaluate using VIEScore + if not VIESCORE_AVAILABLE: + eval_logger.warning("VIEScore not available, skipping evaluation") + return { + "gedit_bench_semantics_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + "gedit_bench_quality_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + "gedit_bench_overall_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + } + + try: + # Load edited image + edited_image_pil = Image.open(edited_image_path).convert("RGB") + + # Resize images to target area (512x512 equivalent) + source_img_width, source_img_height, _ = calculate_dimensions(512 * 512, input_image_pil.width / input_image_pil.height) + edited_img_width, edited_img_height, _ = calculate_dimensions(512 * 512, edited_image_pil.width / edited_image_pil.height) + + input_image_pil = input_image_pil.resize((source_img_width, source_img_height)) + edited_image_pil = edited_image_pil.resize((edited_img_width, edited_img_height)) + + # Get VIEScore instance + vie_score = _get_vie_score(backbone=vie_backbone, key_path=vie_key_path) + + # Evaluate: VIEScore.evaluate returns [semantics_score, quality_score, overall_score] + score_list = vie_score.evaluate([input_image_pil, edited_image_pil], instruction) + semantics_score, quality_score, overall_score = score_list + + eval_logger.info(f"[{task_type}] Key {key}: " f"Semantics={semantics_score:.3f}, " f"Quality={quality_score:.3f}, " f"Overall={overall_score:.3f}, " f"Language={instruction_language}") + + return { + "gedit_bench_semantics_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": float(semantics_score), + "intersection_exist": intersection_exist, + }, + "gedit_bench_quality_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": float(quality_score), + "intersection_exist": intersection_exist, + }, + "gedit_bench_overall_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": float(overall_score), + "intersection_exist": intersection_exist, + }, + } + except Exception as e: + eval_logger.error(f"Error evaluating key {key}: {e}") + return { + "gedit_bench_semantics_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + "gedit_bench_quality_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + "gedit_bench_overall_score": { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": 0.0, + "intersection_exist": intersection_exist, + }, + } + + +def gedit_bench_aggregate_results(results): + """ + Aggregate results across all samples and compute final scores + + Args: + results: List of result dicts from process_results, each containing: + - key: Sample key + - task_type: Task type + - instruction_language: Language ("en" or "cn") + - score: Score value + - intersection_exist: Whether intersection exists + + Returns: + Final aggregated score (average across all samples) + """ + if not results: + return 0.0 + + # Calculate average score + scores = [r["score"] for r in results if "score" in r] + if not scores: + return 0.0 + + avg_score = np.mean(scores) + + # Log breakdown by task type and language + task_type_scores = defaultdict(list) + language_scores = defaultdict(list) + intersection_scores = [] + non_intersection_scores = [] + + for r in results: + if "score" in r: + task_type = r.get("task_type", "unknown") + language = r.get("instruction_language", "unknown") + intersection_exist = r.get("intersection_exist", False) + + task_type_scores[task_type].append(r["score"]) + language_scores[language].append(r["score"]) + + if intersection_exist: + intersection_scores.append(r["score"]) + else: + non_intersection_scores.append(r["score"]) + + # Log statistics + eval_logger.info(f"Overall average score: {avg_score:.3f}") + eval_logger.info(f"Number of samples: {len(scores)}") + + if task_type_scores: + eval_logger.info("Scores by task type:") + for task_type, task_scores in sorted(task_type_scores.items()): + task_avg = np.mean(task_scores) + eval_logger.info(f" {task_type}: {task_avg:.3f} (n={len(task_scores)})") + + if language_scores: + eval_logger.info("Scores by language:") + for language, lang_scores in sorted(language_scores.items()): + lang_avg = np.mean(lang_scores) + eval_logger.info(f" {language}: {lang_avg:.3f} (n={len(lang_scores)})") + + if intersection_scores: + intersection_avg = np.mean(intersection_scores) + eval_logger.info(f"Intersection samples average: {intersection_avg:.3f} (n={len(intersection_scores)})") + + if non_intersection_scores: + non_intersection_avg = np.mean(non_intersection_scores) + eval_logger.info(f"Non-intersection samples average: {non_intersection_avg:.3f} (n={len(non_intersection_scores)})") + + return avg_score + + +# ============================================ +# Detailed Aggregation Functions by Language and Subset +# ============================================ + + +def _aggregate_by_filter(results, language: str = None, intersection_only: bool = None): + """ + Helper function to aggregate scores with filters. + + Args: + results: List of result dicts + language: Filter by language ("en" or "cn"), None for all + intersection_only: True for intersection subset, False for non-intersection, None for all + + Returns: + Average score for filtered samples + """ + if not results: + return 0.0 + + filtered_scores = [] + for r in results: + if "score" not in r: + continue + + # Apply language filter + if language is not None: + if r.get("instruction_language", "unknown") != language: + continue + + # Apply intersection filter + if intersection_only is not None: + is_intersection = r.get("intersection_exist", False) + if intersection_only and not is_intersection: + continue + if not intersection_only and is_intersection: + continue + + filtered_scores.append(r["score"]) + + if not filtered_scores: + return 0.0 + + return float(np.mean(filtered_scores)) + + +# ========================================== +# English Language Aggregations +# ========================================== + + +def gedit_bench_aggregate_en_fullset(results): + """Aggregate English fullset scores (all English samples)""" + return _aggregate_by_filter(results, language="en", intersection_only=None) + + +def gedit_bench_aggregate_en_intersection(results): + """Aggregate English intersection subset scores""" + return _aggregate_by_filter(results, language="en", intersection_only=True) + + +# ========================================== +# Chinese Language Aggregations +# ========================================== + + +def gedit_bench_aggregate_cn_fullset(results): + """Aggregate Chinese fullset scores (all Chinese samples)""" + return _aggregate_by_filter(results, language="cn", intersection_only=None) + + +def gedit_bench_aggregate_cn_intersection(results): + """Aggregate Chinese intersection subset scores""" + return _aggregate_by_filter(results, language="cn", intersection_only=True) + + +# ========================================== +# Intersection/Non-intersection Aggregations (All Languages) +# ========================================== + + +def gedit_bench_aggregate_intersection(results): + """Aggregate intersection subset scores (all languages)""" + return _aggregate_by_filter(results, language=None, intersection_only=True) + + +def gedit_bench_aggregate_fullset(results): + """Aggregate fullset scores (all samples, all languages)""" + return _aggregate_by_filter(results, language=None, intersection_only=None) diff --git a/lmms_eval/tasks/gedit_bench/viescore/__init__.py b/lmms_eval/tasks/gedit_bench/viescore/__init__.py new file mode 100644 index 000000000..1e50433c0 --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/__init__.py @@ -0,0 +1,141 @@ +import math + +from lmms_eval.tasks.gedit_bench.viescore import vie_prompts +from lmms_eval.tasks.gedit_bench.viescore.utils import mllm_output_to_dict + + +class VIEScore: + def __init__(self, backbone="gpt4o", task="t2i", key_path=None) -> None: + self.task = task + self.backbone_name = backbone + + if self.task not in ["t2i", "tie", "t2v"]: + raise ValueError("task must be either 't2i' or 'tie'") + + if self.backbone_name == "gpt4o": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.openai import GPT4o + + self.model = GPT4o(key_path, model_name="gpt-4.1") + elif self.backbone_name == "gpt4v": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.openai import GPT4v + + self.model = GPT4v(key_path) + elif self.backbone_name == "gemini": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.gemini import Gemini + + self.model = Gemini() + elif self.backbone_name == "idefics2": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.idefics2_eval import ( + Idefics2, + ) + + self.model = Idefics2() + elif self.backbone_name == "mantis": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.mantis_idefics2_eval import ( + Mantis, + ) + + self.model = Mantis() + elif self.backbone_name == "minicpmv": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.minicpmv_eval import ( + MiniCPMV, + ) + + self.model = MiniCPMV() + elif self.backbone_name == "qwen25vl": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.qwen25vl_eval import ( + Qwen25VL, + ) + + self.model = Qwen25VL() + # vLLM-based backends (remote API) + elif self.backbone_name == "vllm_qwen" or self.backbone_name == "vllm_qwen25vl": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.vllm_qwen_eval import ( + VLLMQwen25VL, + ) + + self.model = VLLMQwen25VL() + elif self.backbone_name == "vllm_qwen3vl": + from lmms_eval.tasks.gedit_bench.viescore.mllm_tools.vllm_qwen_eval import ( + VLLMQwen3VL, + ) + + self.model = VLLMQwen3VL() + else: + raise NotImplementedError(f"backbone '{backbone}' not supported. " f"Available: gpt4o, gpt4v, gemini, idefics2, mantis, minicpmv, qwen25vl, " f"vllm_qwen, vllm_qwen25vl, vllm_qwen3vl") + self.context = vie_prompts._context_no_delimit + if self.task == "t2i": + self.SC_prompt = "\n".join([self.context, vie_prompts._prompts_0shot_one_image_gen_rule, vie_prompts._prompts_0shot_t2i_rule_SC]) + self.PQ_prompt = "\n".join([self.context, vie_prompts._prompts_0shot_rule_PQ]) + elif self.task == "tie": + self.SC_prompt = "\n".join([self.context, vie_prompts._prompts_0shot_two_image_edit_rule, vie_prompts._prompts_0shot_tie_rule_SC]) + self.PQ_prompt = "\n".join([self.context, vie_prompts._prompts_0shot_rule_PQ]) + elif self.task == "t2v": + self.SC_prompt = "\n".join([self.context, vie_prompts._prompts_0shot_one_video_gen_rule, vie_prompts._prompts_0shot_t2v_rule_SC]) + self.PQ_prompt = "\n".join([self.context, vie_prompts._prompts_0shot_t2v_rule_PQ]) + + def evaluate(self, image_prompts, text_prompt, extract_overall_score_only=False, extract_all_score=True, echo_output=False): + if not isinstance(image_prompts, list): + image_prompts = [image_prompts] + if self.backbone_name in ["gpt4o", "gpt4v"]: + self.model.use_encode = False if isinstance(image_prompts[0], str) else True + # print("Using encode:", self.model.use_encode) + if self.task == "t2i": + _SC_prompt = self.SC_prompt.replace("", text_prompt) + elif self.task == "tie": + _SC_prompt = self.SC_prompt.replace("", text_prompt) + elif self.task == "t2v": + _SC_prompt = self.SC_prompt.replace("", text_prompt) + SC_prompt_final = self.model.prepare_prompt(image_prompts, _SC_prompt) + if self.task == "tie": + PQ_prompt_final = self.model.prepare_prompt(image_prompts[-1], self.PQ_prompt) + else: + PQ_prompt_final = self.model.prepare_prompt(image_prompts, self.PQ_prompt) + + results_dict = {} + + SC_dict = False + PQ_dict = False + tries = 0 + max_tries = 1 + while SC_dict is False or PQ_dict is False: + tries += 1 + guess_if_cannot_parse = True if tries > max_tries else False + result_SC = self.model.get_parsed_output(SC_prompt_final) + result_PQ = self.model.get_parsed_output(PQ_prompt_final) + SC_dict = mllm_output_to_dict(result_SC, give_up_parsing=guess_if_cannot_parse) + PQ_dict = mllm_output_to_dict(result_PQ, give_up_parsing=guess_if_cannot_parse) + + if SC_dict == "rate_limit_exceeded" or PQ_dict == "rate_limit_exceeded": + print("rate_limit_exceeded") + raise ValueError("rate_limit_exceeded") + results_dict["SC"] = SC_dict + results_dict["PQ"] = PQ_dict + if echo_output: + print("results_dict", results_dict) + if extract_all_score: + SC_score = min(results_dict["SC"]["score"]) + PQ_score = min(results_dict["PQ"]["score"]) + O_score = math.sqrt(SC_score * PQ_score) + return [SC_score, PQ_score, O_score] + if extract_overall_score_only: + SC_scores = results_dict["SC"]["score"] + PQ_scores = results_dict["PQ"]["score"] + O_score = math.sqrt(min(SC_scores) * min(PQ_scores)) + return O_score + return results_dict + + +if __name__ == "__main__": + model = VIEScore(backbone="gemini", task="t2i") + from datasets import load_dataset + + dataset = load_dataset("TIGER-Lab/GenAI-Arena-Bench", "image_generation") + dataset = dataset["test"] + print("Now running the VIEScore model") + for idx in range(5): + left_image = dataset["left_image"][idx] + right_image = dataset["right_image"][idx] + prompt = dataset["prompt"][idx] + print(model.evaluate(left_image, prompt, extract_all_score=True)) + print(model.evaluate(right_image, prompt, extract_all_score=True)) diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/__init__.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/gemini.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/gemini.py new file mode 100644 index 000000000..9f67fc63b --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/gemini.py @@ -0,0 +1,154 @@ +""" +Install the Google AI Python SDK + +$ pip install google-generativeai + +See the getting started guide for more information: +https://ai.google.dev/gemini-api/docs/get-started/python +""" + +import os +import tempfile +from io import BytesIO +from typing import List +from urllib.parse import urlparse + +import google.generativeai as genai +import requests +from PIL import Image + +genai.configure(api_key=os.environ["GEMINI_API_KEY"]) + + +def upload_to_gemini(input, mime_type=None): + """Uploads the given file or PIL image to Gemini. + + See https://ai.google.dev/gemini-api/docs/prompting_with_media + """ + if isinstance(input, str): + # Input is a file path + file = genai.upload_file(input, mime_type=mime_type) + elif isinstance(input, Image.Image): + # Input is a PIL image + with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmp_file: + input.save(tmp_file, format="JPEG") + tmp_file_path = tmp_file.name + file = genai.upload_file(tmp_file_path, mime_type=mime_type or "image/jpeg") + os.remove(tmp_file_path) + else: + raise ValueError("Unsupported input type. Must be a file path or PIL Image.") + + # print(f"Uploaded file '{file.display_name}' as: {file.uri}") + return file + + +def save_image_from_url(url, base_save_directory="tmp", file_name=None): + # Parse the URL to create a directory path + parsed_url = urlparse(url) + url_path = os.path.join(parsed_url.netloc, parsed_url.path.lstrip("/")) + save_directory = os.path.join(base_save_directory, os.path.dirname(url_path)) + + # Create the directory if it doesn't exist + if not os.path.exists(save_directory): + os.makedirs(save_directory) + + # Get the image from the URL + response = requests.get(url) + if response.status_code == 200: + # Open the image + image = Image.open(BytesIO(response.content)) + + # Set the file name if not provided + if not file_name: + file_name = os.path.basename(parsed_url.path) + + # Save the image locally + file_path = os.path.join(save_directory, file_name) + image.save(file_path) + + return file_path + else: + raise Exception(f"Failed to retrieve image from URL. Status code: {response.status_code}") + + +class Gemini: + def __init__(self, model_name="gemini-1.5-pro-latest"): + # Create the model + # See https://ai.google.dev/api/python/google/generativeai/GenerativeModel + generation_config = { + "temperature": 1, + "top_p": 0.95, + "top_k": 64, + "max_output_tokens": 8192, + "response_mime_type": "text/plain", + } + safety_settings = [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_NONE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_NONE", + }, + ] + self.model = genai.GenerativeModel( + model_name=model_name, + safety_settings=safety_settings, + generation_config=generation_config, + ) + + def prepare_prompt(self, image_links: List = [], text_prompt: str = ""): + if not isinstance(image_links, list): + image_links = [image_links] + + images_prompt = [] + for image_link in image_links: + if isinstance(image_link, str): + image = save_image_from_url(image_link) + else: + image = image_link + image = upload_to_gemini(image, mime_type="image/jpeg") + images_prompt.append(image) + + prompt_content = [images_prompt, text_prompt] + return prompt_content + + def get_parsed_output(self, prompt): + images_prompt = prompt[0] + text_prompt = prompt[1] + chat_session = self.model.start_chat( + history=[ + { + "role": "user", + "parts": images_prompt, + }, + ] + ) + try: + response = chat_session.send_message(text_prompt) + except: + return "Error in sending message to chat session." + return self.extract_response(response) + + def extract_response(self, response): + response = response.text + return response + + +if __name__ == "__main__": + model = Gemini() + prompt = model.prepare_prompt( + ["https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/DiffEdit/sample_34_1.jpg", "https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/input/sample_34_1.jpg"], "What is difference between two images?" + ) + print("prompt : \n", prompt) + res = model.get_parsed_output(prompt) + print("result : \n", res) diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/idefics2_eval.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/idefics2_eval.py new file mode 100644 index 000000000..2bae04177 --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/idefics2_eval.py @@ -0,0 +1,41 @@ +import os +import time +from typing import List + +import torch +from transformers import AutoModelForVision2Seq, AutoProcessor +from transformers.image_utils import load_image +from transformers.utils import is_flash_attn_2_available + + +class Idefics2: + def __init__(self, model_path: str = "HuggingFaceM4/idefics2-8b") -> None: + attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else None + print(f"Using {attn_implementation} for attention implementation") + self.model = AutoModelForVision2Seq.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16, _attn_implementation=attn_implementation).eval() + self.processor = AutoProcessor.from_pretrained(model_path) + + def prepare_prompt(self, image_links: List = [], text_prompt: str = ""): + if not isinstance(image_links, list): + image_links = [image_links] + messages = [{"role": "user", "content": [{"type": "image"}] * len(image_links) + [{"type": "text", "text": text_prompt}]}] + prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) + images = [load_image(image_link) for image_link in image_links] # Support PIL images as well + inputs = self.processor(text=prompt, images=images, return_tensors="pt") + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + return inputs + + def get_parsed_output(self, inputs): + generate_ids = self.model.generate(**inputs, max_new_tokens=512, num_beams=1) + generated_text = self.processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + return generated_text + + +if __name__ == "__main__": + model = Idefics2() + prompt = model.prepare_prompt( + ["https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/DiffEdit/sample_34_1.jpg", "https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/input/sample_34_1.jpg"], "What is difference between two images?" + ) + # print("prompt : \n", prompt) + res = model.get_parsed_output(prompt) + print("result : \n", res) diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/mantis_idefics2_eval.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/mantis_idefics2_eval.py new file mode 100644 index 000000000..d12f09bbf --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/mantis_idefics2_eval.py @@ -0,0 +1,41 @@ +import os +import time +from typing import List + +import torch +from transformers import AutoModelForVision2Seq, AutoProcessor +from transformers.image_utils import load_image +from transformers.utils import is_flash_attn_2_available + + +class Mantis: + def __init__(self, model_path: str = "TIGER-Lab/Mantis-8B-Idefics2") -> None: + attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else None + print(f"Using {attn_implementation} for attention implementation") + self.model = AutoModelForVision2Seq.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16, _attn_implementation=attn_implementation).eval() + self.processor = AutoProcessor.from_pretrained(model_path) + + def prepare_prompt(self, image_links: List = [], text_prompt: str = ""): + if not isinstance(image_links, list): + image_links = [image_links] + messages = [{"role": "user", "content": [{"type": "image"}] * len(image_links) + [{"type": "text", "text": text_prompt}]}] + prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True) + images = [load_image(image_link) for image_link in image_links] # Support PIL images as well + inputs = self.processor(text=prompt, images=images, return_tensors="pt") + inputs = {k: v.to(self.model.device) for k, v in inputs.items()} + return inputs + + def get_parsed_output(self, inputs): + generate_ids = self.model.generate(**inputs, max_new_tokens=512, num_beams=1) + generated_text = self.processor.batch_decode(generate_ids[:, inputs["input_ids"].shape[1] :], skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + return generated_text + + +if __name__ == "__main__": + model = Mantis() + prompt = model.prepare_prompt( + ["https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/DiffEdit/sample_34_1.jpg", "https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/input/sample_34_1.jpg"], "What is difference between two images?" + ) + # print("prompt : \n", prompt) + res = model.get_parsed_output(prompt) + print("result : \n", res) diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/minicpmv_eval.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/minicpmv_eval.py new file mode 100644 index 000000000..bf69e04d6 --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/minicpmv_eval.py @@ -0,0 +1,42 @@ +import os +import time +from typing import List + +import torch +from PIL import Image +from transformers import AutoModel, AutoTokenizer +from transformers.utils import is_flash_attn_2_available + + +class MiniCPMV: + def __init__(self) -> None: + attn_implementation = "flash_attention_2" if is_flash_attn_2_available() else None + self.model = AutoModel.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True, torch_dtype=torch.float16, device_map="auto", _attn_implementation=attn_implementation).eval() + self.tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-Llama3-V-2_5", trust_remote_code=True) + + print(f"Using {attn_implementation} for attention implementation") + + def prepare_prompt(self, image_links: List = [], text_prompt: str = ""): + if not isinstance(image_links, list): + image_links = [image_links] + messages = [{"role": "user", "content": [{"type": "image"}] * len(image_links) + [{"type": "text", "text": text_prompt}]}] + return messages + + def get_parsed_output(self, inputs): + res = self.model.chat( + image=None, + msgs=inputs, + tokenizer=self.tokenizer, + sampling=False, # if sampling=False, beam_search will be used by default + ) + return res + + +if __name__ == "__main__": + model = MiniCPMV() + prompt = model.prepare_prompt( + ["https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/DiffEdit/sample_34_1.jpg", "https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/input/sample_34_1.jpg"], "What is difference between two images?" + ) + # print("prompt : \n", prompt) + res = model.get_parsed_output(prompt) + print("result : \n", res) diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/openai.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/openai.py new file mode 100644 index 000000000..7d012bc1e --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/openai.py @@ -0,0 +1,169 @@ +import base64 +import os +from io import BytesIO, StringIO +from typing import List, Optional, Tuple, Union + +import requests +from PIL import Image, ImageOps + + +def get_api_key(file_path): + # Read the API key from the first line of the file + with open(file_path, "r") as file: + return file.readline().strip() + + +# Function to encode the image +def encode_image(image_path): + with open(image_path, "rb") as image_file: + return base64.b64encode(image_file.read()).decode("utf-8") + + +def pick_next_item(current_item, item_list): + if current_item not in item_list: + raise ValueError("Current item is not in the list") + current_index = item_list.index(current_item) + next_index = (current_index + 1) % len(item_list) + + return item_list[next_index] + + +# Function to encode a PIL image +def encode_pil_image(pil_image): + # Create an in-memory binary stream + image_stream = BytesIO() + + # Save the PIL image to the binary stream in JPEG format (you can change the format if needed) + pil_image.save(image_stream, format="JPEG") + + # Get the binary data from the stream and encode it as base64 + image_data = image_stream.getvalue() + base64_image = base64.b64encode(image_data).decode("utf-8") + + return base64_image + + +def load_image(image: Union[str, Image.Image], format: str = "RGB", size: Optional[Tuple] = None) -> Image.Image: + """ + Load an image from a given path or URL and convert it to a PIL Image. + + Args: + image (Union[str, Image.Image]): The image path, URL, or a PIL Image object to be loaded. + format (str, optional): Desired color format of the resulting image. Defaults to "RGB". + size (Optional[Tuple], optional): Desired size for resizing the image. Defaults to None. + + Returns: + Image.Image: A PIL Image in the specified format and size. + + Raises: + ValueError: If the provided image format is not recognized. + """ + if isinstance(image, str): + if image.startswith("http://") or image.startswith("https://"): + image = Image.open(requests.get(image, stream=True).raw) + elif os.path.isfile(image): + image = Image.open(image) + else: + raise ValueError(f"Incorrect path or url, URLs must start with `http://` or `https://`, and {image} is not a valid path") + elif isinstance(image, Image.Image): + image = image + else: + raise ValueError("Incorrect format used for image. Should be an url linking to an image, a local path, or a PIL image.") + image = ImageOps.exif_transpose(image) + image = image.convert(format) + if size != None: + image = image.resize(size, Image.LANCZOS) + return image + + +class GPT4v: + def __init__(self, api_key_path="keys/secret.env", are_images_encoded=False, model_name="gpt-4-vision-preview"): + """OpenAI GPT-4-vision model wrapper + Args: + api_key_path (str): Path to the API key file. Defaults to 'keys/secret.env'. + are_images_encoded (bool): Whether the images are encoded in base64. Defaults to False. + """ + self.multiple_api_keys = False + self.current_key_file = None + self.key_lists = None + if isinstance(api_key_path, list): + self.key_lists = api_key_path + self.current_key_file = api_key_path[0] + self.api_key = get_api_key(self.current_key_file) + self.multiple_api_keys = True + else: + self.api_key = get_api_key(api_key_path) + + if not self.api_key: + print("API key not found.") + exit(1) + + self.url = "https://api.openai.com/v1/chat/completions" + self.model_name = model_name + self.use_encode = are_images_encoded + + def prepare_prompt(self, image_links: List = [], text_prompt: str = ""): + prompt_content = [] + text_dict = {"type": "text", "text": text_prompt} + prompt_content.append(text_dict) + + if not isinstance(image_links, list): + image_links = [image_links] + for image_link in image_links: + image = load_image(image_link) + if self.use_encode == True: + visual_dict = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encode_pil_image(image)}"}} + else: + visual_dict = {"type": "image_url", "image_url": {"url": image_link}} + prompt_content.append(visual_dict) + return prompt_content + + def get_parsed_output(self, prompt): + payload = {"model": self.model_name, "messages": [{"role": "user", "content": prompt}], "max_tokens": 1400} + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + response = requests.post(self.url, json=payload, headers=headers) + # return response.text + return self.extract_response(response) + + def extract_response(self, response): + response = response.json() + + try: + out = response["choices"][0]["message"]["content"] + return out + except: + if response["error"]["code"] == "content_policy_violation": + print("Code is content_policy_violation") + elif response["error"]["code"] == "rate_limit_exceeded" or response["error"]["code"] == "insufficient_quota": + print(f"Code is {response['error']['code']}") + print(response["error"]["message"]) + if self.multiple_api_keys == True: + new_key = pick_next_item(self.current_key_file, self.key_lists) + self.update_key(new_key) + self.current_key_file = new_key # override key + print("New key is from the file: ", new_key) + else: + print("Code is different") + print(response) + return "" + + def update_key(self, key, load_from_file=True): + if load_from_file: + self.api_key = get_api_key(key) + else: + self.api_key = key + + +class GPT4o(GPT4v): + def __init__(self, api_key_path="keys/secret.env", are_images_encoded=False, model_name="gpt-4o-2024-05-13"): + super().__init__(api_key_path, are_images_encoded, model_name) + + +if __name__ == "__main__": + model = GPT4o("secret_t2.env", model_name="gpt-4.1") + prompt = model.prepare_prompt( + ["https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/DiffEdit/sample_34_1.jpg", "https://chromaica.github.io/Museum/ImagenHub_Text-Guided_IE/input/sample_34_1.jpg"], "What is difference between two images?" + ) + print("prompt : \n", prompt) + res = model.get_parsed_output(prompt) + print("result : \n", res) diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/qwen25vl_eval.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/qwen25vl_eval.py new file mode 100644 index 000000000..745c547bb --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/qwen25vl_eval.py @@ -0,0 +1,85 @@ +import base64 +import os +import random +import time +from io import BytesIO +from typing import List + +import megfile +import numpy as np +import requests +import torch +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import ( + AutoModel, + AutoProcessor, + AutoTokenizer, + Qwen2_5_VLForConditionalGeneration, +) +from transformers.utils import is_flash_attn_2_available + + +def process_image(image): + img_byte_arr = BytesIO() + image.save(img_byte_arr, format="PNG") + img_byte_arr = img_byte_arr.getvalue() + return img_byte_arr + + +def set_seed(seed: int): + """ + Args: + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + seed (`int`): The seed to set. + """ + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +class Qwen25VL: + def __init__(self) -> None: + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained("/pfs/training-data/hf/models/Qwen/Qwen2.5-VL-72B-Instruct-AWQ", torch_dtype=torch.float16, device_map="auto").eval() + self.processor = AutoProcessor.from_pretrained("/pfs/training-data/hf/models/Qwen/Qwen2.5-VL-72B-Instruct-AWQ") + + def prepare_prompt(self, image_links: List = [], text_prompt: str = ""): + if not isinstance(image_links, list): + image_links = [image_links] + + image_links_base64 = [] + + messages = [{"role": "user", "content": [{"type": "image", "image": img_link} for img_link in image_links] + [{"type": "text", "text": text_prompt}]}] + return messages + + def get_parsed_output(self, messages): + set_seed(42) + # Prepare the inputs + text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + image_inputs, video_inputs = process_vision_info(messages) + + # Process inputs + inputs = self.processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + inputs = inputs.to("cuda") + + # Generate output + generation_config = { + "max_new_tokens": 512, + "num_beams": 1, + "do_sample": False, + "temperature": 0.1, + "top_p": None, + } + generated_ids = self.model.generate(**inputs, **generation_config) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + output_text = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + return output_text[0] if output_text else "" + + +if __name__ == "__main__": + model = Qwen25VL() + prompt = model.prepare_prompt(["https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"], "Describe the image in detail.") + res = model.get_parsed_output(prompt) + print("result : \n", res) diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/utils.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/utils.py new file mode 100644 index 000000000..760b62c46 --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/utils.py @@ -0,0 +1,70 @@ +import base64 +from io import BytesIO +from typing import List + +import requests +from PIL import Image + + +def pil_image_to_base64(pil_image, format="PNG"): + buffered = BytesIO() + pil_image.save(buffered, format=format) # Save image to the buffer in the specified format + img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") # Encode the buffer's content to base64 + return img_str + + +def load_image(image_file): + if image_file.startswith("http"): + response = requests.get(image_file) + image = Image.open(BytesIO(response.content)).convert("RGB") + else: + import os + + image = Image.open(image_file).convert("RGB") + return image + + +def load_images(image_files): + out = [] + for image_file in image_files: + image = load_image(image_file) + out.append(image) + return out + + +def merge_images(image_links: List = []): + """Merge multiple images into one image + + Args: + image_links (List, optional): List of image links. Defaults to []. + + Returns: + [type]: [description] + """ + if len(image_links) == 0: + return None + images = load_images(image_links) + if len(images) == 1: + return images[0] + widths, heights = zip(*(i.size for i in images)) + average_height = sum(heights) // len(heights) + for i, im in enumerate(images): + # scale in proportion + images[i] = im.resize((int(im.size[0] * average_height / im.size[1]), average_height)) + widths, heights = zip(*(i.size for i in images)) + total_width = sum(widths) + max_height = max(heights) + new_im = Image.new("RGB", (total_width + 10 * (len(images) - 1), max_height)) + x_offset = 0 + for i, im in enumerate(images): + if i > 0: + # past a column of 1 pixel starting from x_offset width being black, 8 pixels being white, and 1 pixel being black + new_im.paste(Image.new("RGB", (1, max_height), (0, 0, 0)), (x_offset, 0)) + x_offset += 1 + new_im.paste(Image.new("RGB", (8, max_height), (255, 255, 255)), (x_offset, 0)) + x_offset += 8 + new_im.paste(Image.new("RGB", (1, max_height), (0, 0, 0)), (x_offset, 0)) + x_offset += 1 + new_im.paste(im, (x_offset, 0)) + x_offset += im.size[0] + return new_im diff --git a/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/vllm_qwen_eval.py b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/vllm_qwen_eval.py new file mode 100644 index 000000000..995db5277 --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/mllm_tools/vllm_qwen_eval.py @@ -0,0 +1,220 @@ +""" +vLLM-based Qwen VL model evaluation backend. + +This module provides evaluation using Qwen2.5-VL or Qwen3-VL models served via vLLM. +vLLM provides an OpenAI-compatible API endpoint. + +Usage: + 1. Start vLLM server on remote machine: + python -m vllm.entrypoints.openai.api_server \ + --model Qwen/Qwen2.5-VL-72B-Instruct \ + --port 8000 \ + --tensor-parallel-size 4 + + 2. Set environment variables: + export VLLM_API_BASE="http://remote-server:8000/v1" + export VLLM_MODEL_NAME="Qwen/Qwen2.5-VL-72B-Instruct" # optional + + 3. Use backbone="vllm_qwen" in VIEScore + +""" + +import base64 +import os +import random +from io import BytesIO +from typing import List, Optional + +import numpy as np + + +def encode_image_to_base64(image) -> str: + """Convert PIL Image to base64 string""" + if isinstance(image, str): + # It's a file path + with open(image, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + else: + # It's a PIL Image + buffer = BytesIO() + image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + +def set_seed(seed: int): + """Set random seed for reproducibility""" + random.seed(seed) + np.random.seed(seed) + try: + import torch + + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + except ImportError: + pass + + +class VLLMQwen: + """ + vLLM-based Qwen VL model for image evaluation. + + Connects to a vLLM server running Qwen2.5-VL or Qwen3-VL and uses + the OpenAI-compatible API for inference. + + Environment variables: + - VLLM_API_BASE: Base URL of vLLM server (e.g., "http://localhost:8000/v1") + - VLLM_API_KEY: API key if required (default: "EMPTY") + - VLLM_MODEL_NAME: Model name to use (default: auto-detect from server) + - VLLM_TIMEOUT: Request timeout in seconds (default: 120) + """ + + def __init__( + self, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + model_name: Optional[str] = None, + timeout: int = 120, + ) -> None: + """ + Initialize vLLM Qwen client. + + Args: + api_base: vLLM server base URL. Defaults to VLLM_API_BASE env var. + api_key: API key if required. Defaults to VLLM_API_KEY env var or "EMPTY". + model_name: Model name. Defaults to VLLM_MODEL_NAME env var. + timeout: Request timeout in seconds. + """ + self.api_base = api_base or os.getenv("VLLM_API_BASE", "http://localhost:8000/v1") + self.api_key = api_key or os.getenv("VLLM_API_KEY", "EMPTY") + self.model_name = model_name or os.getenv("VLLM_MODEL_NAME") + self.timeout = int(os.getenv("VLLM_TIMEOUT", str(timeout))) + + # Auto-detect model name if not provided + if not self.model_name: + self.model_name = self._get_model_name() + + print(f"VLLMQwen initialized: api_base={self.api_base}, model={self.model_name}") + + def _get_model_name(self) -> str: + """Get model name from vLLM server""" + try: + from openai import OpenAI + + client = OpenAI( + api_key=self.api_key, + base_url=self.api_base, + timeout=self.timeout, + ) + models = client.models.list() + if models.data: + return models.data[0].id + except Exception as e: + print(f"Warning: Could not auto-detect model name: {e}") + return "default" + + def prepare_prompt(self, image_links: List = [], text_prompt: str = ""): + """ + Prepare prompt for Qwen VL model. + + Args: + image_links: List of PIL Images or image paths + text_prompt: Text prompt + + Returns: + List of messages in OpenAI chat format + """ + if not isinstance(image_links, list): + image_links = [image_links] + + # Build content list with images and text + content = [] + + for img in image_links: + img_base64 = encode_image_to_base64(img) + content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{img_base64}"}}) + + content.append({"type": "text", "text": text_prompt}) + + messages = [{"role": "user", "content": content}] + + return messages + + def get_parsed_output(self, messages) -> str: + """ + Get model output for the given messages. + + Args: + messages: Messages in OpenAI chat format + + Returns: + Model response text + """ + set_seed(42) + + try: + from openai import OpenAI + except ImportError: + raise ImportError("openai package is required. Install with: pip install openai") + + try: + client = OpenAI( + api_key=self.api_key, + base_url=self.api_base, + timeout=self.timeout, + ) + + response = client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=512, + temperature=0.1, + top_p=None, + ) + + return response.choices[0].message.content if response.choices else "" + + except Exception as e: + print(f"Error calling vLLM API: {e}") + raise + + +class VLLMQwen25VL(VLLMQwen): + """vLLM-based Qwen2.5-VL model""" + + def __init__(self, **kwargs) -> None: + # Set default model name for Qwen2.5-VL if not provided + if "model_name" not in kwargs and not os.getenv("VLLM_MODEL_NAME"): + kwargs["model_name"] = os.getenv("VLLM_MODEL_NAME", "Qwen/Qwen2.5-VL-72B-Instruct") + super().__init__(**kwargs) + + +class VLLMQwen3VL(VLLMQwen): + """vLLM-based Qwen3-VL model""" + + def __init__(self, **kwargs) -> None: + # Set default model name for Qwen3-VL if not provided + if "model_name" not in kwargs and not os.getenv("VLLM_MODEL_NAME"): + kwargs["model_name"] = os.getenv("VLLM_MODEL_NAME", "Qwen/Qwen3-VL-72B-Instruct") + super().__init__(**kwargs) + + +if __name__ == "__main__": + # Test the vLLM Qwen client + import sys + + # Check if API base is set + api_base = os.getenv("VLLM_API_BASE") + if not api_base: + print("Please set VLLM_API_BASE environment variable") + print("Example: export VLLM_API_BASE='http://localhost:8000/v1'") + sys.exit(1) + + print(f"Testing vLLM Qwen client with API base: {api_base}") + + model = VLLMQwen25VL() + + # Test with a simple text prompt (no image) + messages = [{"role": "user", "content": [{"type": "text", "text": "Hello, what model are you?"}]}] + + response = model.get_parsed_output(messages) + print(f"Response: {response}") diff --git a/lmms_eval/tasks/gedit_bench/viescore/parse_prompt.py b/lmms_eval/tasks/gedit_bench/viescore/parse_prompt.py new file mode 100644 index 000000000..f760f70b4 --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/parse_prompt.py @@ -0,0 +1,22 @@ +import os + + +def create_python_file_with_texts(folder_path, output_file): + with open(output_file, "w", encoding="utf-8") as out_file: + out_file.write("# This file is generated automatically through parse_prompt.py\n\n") + for root, dirs, files in os.walk(folder_path): + for file in files: + if file.endswith(".txt"): + file_path = os.path.join(root, file) + var_name = "_" + file_path.replace(folder_path, "").replace(os.sep, "_").replace(".txt", "").strip("_") + with open(file_path, "r", encoding="utf-8") as f: + content = f.read().replace('"""', '"""') + out_file.write(f'{var_name} = """{content}"""\n\n') + + +# Example usage +current_file_path = os.path.abspath(__file__) +current_folder_path = os.path.dirname(current_file_path) +folder_path = os.path.join(current_folder_path, "prompts_raw") +output_file = os.path.join(current_folder_path, "vie_prompts.py") +create_python_file_with_texts(folder_path, output_file) diff --git a/lmms_eval/tasks/gedit_bench/viescore/utils.py b/lmms_eval/tasks/gedit_bench/viescore/utils.py new file mode 100644 index 000000000..1df331bc1 --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/utils.py @@ -0,0 +1,369 @@ +import ast +import json +import os +import random +from typing import List, Optional, Union + +import regex as re + + +def fix_json(input_str): + # Add double quotes around keys using regex + fixed_str = re.sub(r"(\w+):", r'"\1":', input_str) + + # Add double quotes around string values if necessary and wrap int/float values in [] + def format_value(match): + key, value, comma = match.groups() + value = value.strip() + # Check if value is an integer or float + if re.match(r"^-?\d+(\.\d+)?$", value): + value = f"[{value}]" + # Check if value is a boolean or null + elif re.match(r"^(true|false|null)$", value, re.IGNORECASE): + pass # leave as is + else: + # Add quotes around string values + value = f'"{value}"' + return f"{key}: {value}{comma}" + + fixed_str = re.sub(r'(".*?"):(.*?)(,|})', format_value, fixed_str) + + return fixed_str + + +def read_file_to_string(file_path): + """ + Reads the contents of a text file and returns it as a string. + + :param file_path: The path to the text file. + :return: A string containing the contents of the file. + """ + try: + with open(file_path, "r", encoding="utf-8") as file: + return file.read() + except FileNotFoundError: + print(f"The file {file_path} was not found.") + return None + except Exception as e: + print(f"An error occurred: {e}") + return None + + +def read_files_to_string(file_paths): + """ + Reads the contents of multiple text files and returns them as a single string, + with each file's contents separated by a newline. + + :param file_paths: A list of paths to text files. + :return: A string containing the concatenated contents of the files. + """ + all_contents = [] # List to hold the contents of each file + + for file_path in file_paths: + try: + with open(file_path, "r", encoding="utf-8") as file: + all_contents.append(file.read()) + except FileNotFoundError: + print(f"The file {file_path} was not found.") + except Exception as e: + print(f"An error occurred while reading {file_path}: {e}") + + # Join all the contents with a newline character + return "\n".join(all_contents) + + +def get_file_path(filename: Union[str, os.PathLike], search_from: Union[str, os.PathLike] = "."): + """ + Search for a file across a directory and return its absolute path. + + Args: + filename (Union[str, os.PathLike]): The name of the file to search for. + search_from (Union[str, os.PathLike], optional): The directory from which to start the search. Defaults to ".". + + Returns: + str: Absolute path to the found file. + + Raises: + FileNotFoundError: If the file is not found. + """ + for root, dirs, files in os.walk(search_from): + for name in files: + if name == filename: + return os.path.abspath(os.path.join(root, name)) + raise FileNotFoundError(filename, "not found.") + + +# +========================================================================================= +def verify(s, target_sequence): + # Count the occurrences of the target sequence + count = s.count(target_sequence) + + # Check if the target sequence appears exactly twice + return count == 2 + + +def is_int_between_0_and_10(s): + try: + num = int(s) + return 0 <= num <= 10 + except ValueError: + return False + + +def is_str_a_list_of_ints_0_to_10(s): + try: + # Attempt to parse the string as a Python literal (list, dict, etc.) + parsed = ast.literal_eval(s) + + # Check if the parsed object is a list + if not isinstance(parsed, list): + return False + + # Check if all elements are integers and between 0 to 10 + return all(isinstance(item, int) and 0 <= item <= 10 for item in parsed) + + except (ValueError, SyntaxError): + # If parsing fails or any other error occurs + return False + + +def is_str_valid_score_format_brackets(s): + try: + # Removing brackets and splitting the string by commas + content = s.strip("[]").split(",") + + length = len(content) + + # Parsing each element and checking the format and range + scores = {} + for item in content: + key, value = item.split(":") + key = key.strip() + value = int(value.strip()) + + # Check if the key starts with 'score' and the value is in the correct range + if not key.startswith("score") or not 0 <= value <= 10: + return False + + scores[key] = value + + fetch_words = [f"score{i+1}" for i in range(length)] + # Check if at least 'score1' and 'score2' are present + return all(key in scores for key in fetch_words) + + except (ValueError, SyntaxError): + # If any parsing error occurs + return False + + +# +========================================================================================= +def mllm_output_to_dict(input_string, give_up_parsing=False): + """ + Args: + input_string (str): actually the output of the mllm model to be parsed + output_file_name (str): The name of the output file. + """ + # Catch for gpt4v rate_limit_exceeded error + if input_string == "rate_limit_exceeded": + return "rate_limit_exceeded" + + # Define the delimiters + delimiter = "||V^=^V||" + + if input_string.count(delimiter) == 2: + if not verify(input_string, delimiter): + print("The required delimiters were not found correctly in the string.") + return False + # Extract the content between the delimiters + start_index = input_string.find(delimiter) + len(delimiter) + end_index = input_string.rfind(delimiter) + else: + # find the json mannually + # some mllm tends not to output the delimiters, but it does output the json contents + # so we will find the json content mannually + start_index = input_string.find("{") + end_index = input_string.rfind("}") + 1 + if start_index == -1 or end_index == 0: + # json not found + # some mllm tends to output only a list of scores like [6, 0], + # this time we will just get the scores and ignore the reasoning (other part of the json) + start_index = input_string.find("[") + end_index = input_string.rfind("]") + 1 + if give_up_parsing: # if we want to give up parsing + guessed_value = random.randint(0, 10) + print(f"Failed to find the json content in the string. Guess a value : {guessed_value}.") + json_content = {"score": [guessed_value], "reasoning": f"guess_if_cannot_parse | {input_string}"} + json_str = json.dumps(json_content) + input_string = json_str + start_index = 0 + end_index = len(json_str) + elif re.match(r"^\[\d+, ?\d+\]$", input_string[start_index:end_index]): + scores = json.loads(input_string[start_index:end_index]) + if not isinstance(scores, list): + scores = [scores] + json_content = {"score": scores, "reasoning": "System: output is simply a list of scores"} + json_str = json.dumps(json_content) + input_string = json_str + start_index = 0 + end_index = len(json_str) + elif is_int_between_0_and_10(input_string): # if output is simply a number + scores = [int(input_string)] + json_content = {"score": scores, "reasoning": "System: output is simply a number"} + json_str = json.dumps(json_content) + input_string = json_str + start_index = 0 + end_index = len(json_str) + else: + print("Failed to find the json content in the string.") + return False + + # Check if we found two delimiters + if start_index != -1 and end_index != -1 and start_index != end_index: + # Extract the JSON string + json_str = input_string[start_index:end_index].strip() + json_str = json_str.replace("\n", "") + # Parse the JSON string into a dictionary + try: + new_data = json.loads(json_str) + if not isinstance(new_data["score"], list): + new_data["score"] = [new_data["score"]] + except: + print("Now fixing: ", json_str) + try: + new_data = json.loads(fix_json(json_str)) + return new_data + except: + print("Error: Cannot fix", json_str) + return False + return new_data + else: + print("The required delimiters were not found correctly in the string.") + return False + + +def write_entry_to_json_file(input_string, uid, prompt_input, vision_input, output_file_name, give_up_parsing=False): + """ + Args: + input_string (str): actually the output of the mllm model to be parsed + uid (str): The unique identifier for the each item in the test data + prompt_input (str): The prompt input for the entry. text prompt. + vision_input (str): The vision input for the entry. image links. + output_file_name (str): The name of the output file. + """ + # Catch for gpt4v rate_limit_exceeded error + if input_string == "rate_limit_exceeded": + return "rate_limit_exceeded" + + # Define the delimiters + delimiter = "||V^=^V||" + + if input_string.count(delimiter) == 2: + if not verify(input_string, delimiter): + print("The required delimiters were not found correctly in the string.") + return False + # Extract the content between the delimiters + start_index = input_string.find(delimiter) + len(delimiter) + end_index = input_string.rfind(delimiter) + else: + # find the json mannually + # some mllm tends not to output the delimiters, but it does output the json contents + # so we will find the json content mannually + start_index = input_string.find("{") + end_index = input_string.rfind("}") + 1 + if start_index == -1 or end_index == 0: + # json not found + # some mllm tends to output only a list of scores like [6, 0], + # this time we will just get the scores and ignore the reasoning (other part of the json) + start_index = input_string.find("[") + end_index = input_string.rfind("]") + 1 + if give_up_parsing: # if we want to give up parsing + guessed_value = random.randint(0, 10) + print(f"Failed to find the json content in the string. Guess a value : {guessed_value}.") + json_content = {"score": [guessed_value], "reasoning": f"guess_if_cannot_parse | {input_string}"} + json_str = json.dumps(json_content) + input_string = json_str + start_index = 0 + end_index = len(json_str) + elif re.match(r"^\[\d+, ?\d+\]$", input_string[start_index:end_index]): + scores = json.loads(input_string[start_index:end_index]) + json_content = {"score": scores, "reasoning": None} + json_str = json.dumps(json_content) + input_string = json_str + start_index = 0 + end_index = len(json_str) + elif is_int_between_0_and_10(input_string): # if output is simply a number + scores = [int(input_string)] + json_content = {"score": scores, "reasoning": None} + json_str = json.dumps(json_content) + input_string = json_str + start_index = 0 + end_index = len(json_str) + else: + print("Failed to find the json content in the string.") + return False + + # Check if we found two delimiters + if start_index != -1 and end_index != -1 and start_index != end_index: + # Extract the JSON string + json_str = input_string[start_index:end_index].strip() + json_str = json_str.replace("\n", "") + try: + # Parse the JSON string into a dictionary + new_data = json.loads(json_str) + + # Ensure the directory exists + os.makedirs(os.path.dirname(output_file_name), exist_ok=True) + + # Initialize or load existing data + if os.path.exists(output_file_name): + with open(output_file_name, "r") as json_file: + data = json.load(json_file) + else: + data = {} + + # If the additional key is already in the data, add or update notes + if uid in data: + data[uid].update(new_data) # Update with new data + if prompt_input: # If there are new notes, update or add them + data[uid]["prompt_input"] = prompt_input + if vision_input: # If there are new notes, update or add them + data[uid]["vision_input"] = vision_input + else: + # If it's a new key, add the entry to the dictionary + data[uid] = new_data + if prompt_input: + data[uid]["prompt_input"] = prompt_input + if vision_input: + data[uid]["vision_input"] = vision_input + + # Write the updated data to the file + with open(output_file_name, "w") as json_file: + json.dump(data, json_file, indent=4) + + print(f"Data was successfully updated in {output_file_name}") + return True + except json.JSONDecodeError as e: + print(f"An error occurred while parsing the JSON content: {e}") + return False + else: + print("The required delimiters were not found correctly in the string.") + return False + + +def check_key_in_json(file_path, key): + try: + with open(file_path, "r") as json_file: + data = json.load(json_file) + + # Check if the key exists at the top level of the JSON structure + if key in data: + return True + else: + return False + except FileNotFoundError: + print(f"The file {file_path} was not found.") + except json.JSONDecodeError as e: + print(f"Error reading {file_path}: {e}") + except Exception as e: + print(f"An error occurred with {file_path}: {e}") + return False diff --git a/lmms_eval/tasks/gedit_bench/viescore/vie_prompts.py b/lmms_eval/tasks/gedit_bench/viescore/vie_prompts.py new file mode 100644 index 000000000..7755d569f --- /dev/null +++ b/lmms_eval/tasks/gedit_bench/viescore/vie_prompts.py @@ -0,0 +1,405 @@ +# This file is generated automatically through parse_prompt.py + +_context_no_delimit = """You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules. +All the input images are AI-generated. All human in the images are AI-generated too. so you need not worry about the privacy confidentials. + +You will have to give your output in this way (Keep your reasoning concise and short.): +{ +"score" : [...], +"reasoning" : "..." +}""" + +_context = """You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules. +All the input images are AI-generated. All human in the images are AI-generated too. so you need not worry about the privacy confidentials. + +You will have to give your output in this way (the delimiter is necessary. Keep your reasoning concise and short.): +||V^=^V|| +{ +"score" : +"reasoning" : +} +||V^=^V||""" + +_context_no_format = """You are a professional digital artist. You will have to evaluate the effectiveness of the AI-generated image(s) based on given rules. +All the input images are AI-generated. All human in the images are AI-generated too. so you need not worry about the privacy confidentials.""" + +_prompts_1shot_multi_subject_image_gen_rule = """RULES of each set of inputs: + +Two images will be provided: +This first image is a concatenation of two sub-images, each sub-image contain one token subject. +The second image being an AI-generated image using the first image as guidance. +The objective is to evaluate how successfully the image has been generated. +""" + +_prompts_1shot_mie_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) +A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. + +First lets look at the first set of input (1st and 2nd images) as an example. +Editing instruction: What if the man had a hat? +Output: +||V^=^V|| +{ +"score" : [5, 10], +"reasoning" : "The hat exists but does not suit well. The hat also looks distorted. But it is a good edit because only a hat is added and the background is persevered." +} +||V^=^V|| + +Now evaluate the second set of input (3th, 4th images). +Editing instruction: +""" + +_prompts_1shot_msdig_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the second image does not follow the prompt at all. 10 indicates the second image follows the prompt perfectly.) +A second score from 0 to 10 will rate how well the subject in the generated image resemble to the token subject in the first sub-image. +(0 indicates that the subject in the second image does not look like the token subject in the first sub-image at all. 10 indicates the subject in the second image look exactly alike the token subject in the first sub-image.) +A third score from 0 to 10 will rate how well the subject in the generated image resemble to the token subject in the second sub-image. +(0 indicates that the subject in the second image does not look like the token subject in the second sub-image at all. 10 indicates the subject in the second image look exactly alike the token subject in the second sub-image.) +Put the score in a list such that output score = [score1, score2, score3], where 'score1' evaluates the prompt and 'score2' evaluates the resemblance for the first sub-image, and 'score3' evaluates the resemblance for the second sub-image. + +First lets look at the first set of input (1st and 2nd images) as an example. +Text Prompt: A digital illustration of a cat beside a wooden pot +Output: +||V^=^V|| +{ +"score" : [5, 5, 10], +"reasoning" : "The cat is not beside the wooden pot. The pot looks partially resemble to the subject pot. The cat looks highly resemble to the subject cat." +} +||V^=^V|| + +Now evaluate the second set of input (3th, 4th images). +Text Prompt: """ + +_prompts_1shot_t2i_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the AI generated image does not follow the prompt at all. 10 indicates the AI generated image follows the prompt perfectly.) + +Put the score in a list such that output score = [score]. + +First lets look at the first set of input (1st image) as an example. +Text Prompt: A pink and a white frisbee are on the ground. +Output: +||V^=^V|| +{ +"score" : [5], +"reasoning" : "White frisbee not present in the image." +} +||V^=^V|| + +Now evaluate the second set of input (2nd image). +Text Prompt: +""" + +_prompts_1shot_tie_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) +A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. + +First lets look at the first set of input (1st and 2nd images) as an example. +Editing instruction: What if the man had a hat? +Output: +||V^=^V|| +{ +"score" : [5, 10], +"reasoning" : "The hat exists but does not suit well. The hat also looks distorted. But it is a good edit because only a hat is added and the background is persevered." +} +||V^=^V|| + +Now evaluate the second set of input (3th, 4th images). +Editing instruction: +""" + +_prompts_1shot_sdie_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will rate how well the subject in the generated image resemble to the token subject in the second image. +(0 indicates that the subject in the third image does not look like the token subject at all. 10 indicates the subject in the third image look exactly alike the token subject.) +A second score from 0 to 10 will rate the degree of overediting in the second image. +(0 indicates that the scene in the edited image is completely different from the first image. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the resemblance and 'score2' evaluates the degree of overediting. + +First lets look at the first set of input (1st, 2nd and 3rd images) as an example. +Subject: +Output: +||V^=^V|| +{ +"score" : [5, 10], +"reasoning" : "The monster toy looks partially resemble to the token subject. The edit is minimal." +} +||V^=^V|| + +Now evaluate the second set of input (4th, 5th, and 6th images). +Subject: +""" + +_prompts_1shot_one_image_gen_rule = """RULES of each set of inputs: + +One image will be provided; The image is an AI-generated image. +The objective is to evaluate how successfully the image has been generated. +""" + +_prompts_1shot_sdig_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the second image does not follow the prompt at all. 10 indicates the second image follows the prompt perfectly.) +A second score from 0 to 10 will rate how well the subject in the generated image resemble to the token subject in the first image. +(0 indicates that the subject in the second image does not look like the token subject at all. 10 indicates the subject in the second image look exactly alike the token subject.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the prompt and 'score2' evaluates the resemblance. + +First lets look at the first set of input (1st and 2nd images) as an example. +Text Prompt: a red cartoon figure eating a banana +Output: +||V^=^V|| +{ +"score" : [10, 5], +"reasoning" : "The red cartoon figure is eating a banana. The red cartoon figure looks partially resemble to the subject." +} +||V^=^V|| + +Now evaluate the second set of input (3th, 4th images). +Text Prompt: +""" + +_prompts_1shot_rule_PQ = """RULES of each set of inputs: + +One image will be provided; The image is an AI-generated image. +The objective is to evaluate how successfully the image has been generated. + +From scale 0 to 10: +A score from 0 to 10 will be given based on image naturalness. +( + 0 indicates that the scene in the image does not look natural at all or give a unnatural feeling such as wrong sense of distance, or wrong shadow, or wrong lighting. + 10 indicates that the image looks natural. +) +A second score from 0 to 10 will rate the image artifacts. +( + 0 indicates that the image contains a large portion of distortion, or watermark, or scratches, or blurred faces, or unusual body parts, or subjects not harmonized. + 10 indicates the image has no artifacts. +) +Put the score in a list such that output score = [naturalness, artifacts] + + +First lets look at the first set of input (1st image) as an example. +Output: +||V^=^V|| +{ +"score" : [5, 5], +"reasoning" : "The image gives an unnatural feeling on hands of the girl. There is also minor distortion on the eyes of the girl." +} +||V^=^V|| + +Now evaluate the second set of input (2nd image). + +""" + +_prompts_1shot_subject_image_gen_rule = """RULES of each set of inputs: + +Two images will be provided: The first being a token subject image and the second being an AI-generated image using the first image as guidance. +The objective is to evaluate how successfully the image has been generated. +""" + +_prompts_1shot_cig_rule_SC = """ +From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the second image does not follow the prompt at all. 10 indicates the second image follows the prompt perfectly.) +A second score from 0 to 10 will rate how well the generated image is following the guidance image. +(0 indicates that the second image is not following the guidance at all. 10 indicates that second image is following the guidance image.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the prompt and 'score2' evaluates the guidance. + +First lets look at the first set of input (1st and 2nd images) as an example. +Text Prompt: the bridge is red, Golden Gate Bridge in San Francisco, USA +Output: +||V^=^V|| +{ +"score" : [5, 5], +"reasoning" : "The bridge is red. But half of the bridge is gone." +} +||V^=^V|| + +Now evaluate the second set of input (3th, 4th images). +Text Prompt: +""" + +_prompts_1shot_two_image_edit_rule = """RULES of each set of inputs: + +Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. +The objective is to evaluate how successfully the editing instruction has been executed in the second image. + +Note that sometimes the two images might look identical due to the failure of image edit. +""" + +_prompts_1shot_subject_image_edit_rule = """RULES of each set of inputs: + +Three images will be provided: +The first image is a input image to be edited. +The second image is a token subject image. +The third image is an AI-edited image from the first image. it should contain a subject that looks alike the subject in second image. +The objective is to evaluate how successfully the image has been edited. +""" + +_prompts_1shot_control_image_gen_rule = """RULES of each set of inputs: + +Two images will be provided: The first being a processed image (e.g. Canny edges, openpose, grayscale etc.) and the second being an AI-generated image using the first image as guidance. +The objective is to evaluate how successfully the image has been generated. +""" + +_prompts_0shot_two_image_edit_rule = """RULES: + +Two images will be provided: The first being the original AI-generated image and the second being an edited version of the first. +The objective is to evaluate how successfully the editing instruction has been executed in the second image. + +Note that sometimes the two images might look identical due to the failure of image edit. +""" + +_prompts_0shot_one_video_gen_rule = """RULES: + +The images are extracted from a AI-generated video according to the text prompt. +The objective is to evaluate how successfully the video has been generated. +""" + +_prompts_0shot_t2v_rule_PQ = """RULES: + +The image frames are AI-generated. +The objective is to evaluate how successfully the image frames has been generated. + +From scale 0 to 10: +A score from 0 to 10 will be given based on the image frames naturalness. +( + 0 indicates that the scene in the image frames does not look natural at all or give a unnatural feeling such as wrong sense of distance, or wrong shadow, or wrong lighting. + 10 indicates that the image frames looks natural. +) +A second score from 0 to 10 will rate the image frames artifacts. +( + 0 indicates that the image frames contains a large portion of distortion, or watermark, or scratches, or blurred faces, or unusual body parts, or subjects not harmonized. + 10 indicates the image frames has no artifacts. +) +Put the score in a list such that output score = [naturalness, artifacts] +""" + +_prompts_0shot_msdig_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the second image does not follow the prompt at all. 10 indicates the second image follows the prompt perfectly.) +A second score from 0 to 10 will rate how well the subject in the generated image resemble to the token subject in the first sub-image. +(0 indicates that the subject in the second image does not look like the token subject in the first sub-image at all. 10 indicates the subject in the second image look exactly alike the token subject in the first sub-image.) +A third score from 0 to 10 will rate how well the subject in the generated image resemble to the token subject in the second sub-image. +(0 indicates that the subject in the second image does not look like the token subject in the second sub-image at all. 10 indicates the subject in the second image look exactly alike the token subject in the second sub-image.) +Put the score in a list such that output score = [score1, score2, score3], where 'score1' evaluates the prompt and 'score2' evaluates the resemblance for the first sub-image, and 'score3' evaluates the resemblance for the second sub-image. + +Text Prompt: +""" + +_prompts_0shot_sdie_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will rate how well the subject in the generated image resemble to the token subject in the second image. +(0 indicates that the subject in the third image does not look like the token subject at all. 10 indicates the subject in the third image look exactly alike the token subject.) +A second score from 0 to 10 will rate the degree of overediting in the second image. +(0 indicates that the scene in the edited image is completely different from the first image. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the resemblance and 'score2' evaluates the degree of overediting. + +Subject: """ + +_prompts_0shot_subject_image_edit_rule = """RULES: + +Three images will be provided: +The first image is a input image to be edited. +The second image is a token subject image. +The third image is an AI-edited image from the first image. it should contain a subject that looks alike the subject in second image. +The objective is to evaluate how successfully the image has been edited. +""" + +_prompts_0shot_mie_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) +A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. + +Editing instruction: +""" + +_prompts_0shot_sdig_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the second image does not follow the prompt at all. 10 indicates the second image follows the prompt perfectly.) +A second score from 0 to 10 will rate how well the subject in the generated image resemble to the token subject in the first image. +(0 indicates that the subject in the second image does not look like the token subject at all. 10 indicates the subject in the second image look exactly alike the token subject.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the prompt and 'score2' evaluates the resemblance. + +Text Prompt: +""" + +_prompts_0shot_tie_rule_SC = """ +From scale 0 to 10: +A score from 0 to 10 will be given based on the success of the editing. (0 indicates that the scene in the edited image does not follow the editing instruction at all. 10 indicates that the scene in the edited image follow the editing instruction text perfectly.) +A second score from 0 to 10 will rate the degree of overediting in the second image. (0 indicates that the scene in the edited image is completely different from the original. 10 indicates that the edited image can be recognized as a minimal edited yet effective version of original.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the editing success and 'score2' evaluates the degree of overediting. + +Editing instruction: +""" + +_prompts_0shot_t2i_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the AI generated image does not follow the prompt at all. 10 indicates the AI generated image follows the prompt perfectly.) + +Put the score in a list such that output score = [score]. + +Text Prompt: +""" + +_prompts_0shot_cig_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the second image does not follow the prompt at all. 10 indicates the second image follows the prompt perfectly.) +A second score from 0 to 10 will rate how well the generated image is following the guidance image. +(0 indicates that the second image is not following the guidance at all. 10 indicates that second image is following the guidance image.) +Put the score in a list such that output score = [score1, score2], where 'score1' evaluates the prompt and 'score2' evaluates the guidance. + +Text Prompt: """ + +_prompts_0shot_control_image_gen_rule = """RULES: + +Two images will be provided: The first being a processed image (e.g. Canny edges, openpose, grayscale etc.) and the second being an AI-generated image using the first image as guidance. +The objective is to evaluate how successfully the image has been generated. +""" + +_prompts_0shot_rule_PQ = """RULES: + +The image is an AI-generated image. +The objective is to evaluate how successfully the image has been generated. + +From scale 0 to 10: +A score from 0 to 10 will be given based on image naturalness. +( + 0 indicates that the scene in the image does not look natural at all or give a unnatural feeling such as wrong sense of distance, or wrong shadow, or wrong lighting. + 10 indicates that the image looks natural. +) +A second score from 0 to 10 will rate the image artifacts. +( + 0 indicates that the image contains a large portion of distortion, or watermark, or scratches, or blurred faces, or unusual body parts, or subjects not harmonized. + 10 indicates the image has no artifacts. +) +Put the score in a list such that output score = [naturalness, artifacts] +""" + +_prompts_0shot_t2v_rule_SC = """From scale 0 to 10: +A score from 0 to 10 will be given based on the success in following the prompt. +(0 indicates that the image frames does not follow the prompt at all. 10 indicates the image frames follows the prompt perfectly.) + +Put the score in a list such that output score = [score]. + +Text Prompt: +""" + +_prompts_0shot_multi_subject_image_gen_rule = """RULES: + +Two images will be provided: +This first image is a concatenation of two sub-images, each sub-image contain one token subject. +The second image being an AI-generated image using the first image as guidance. +The objective is to evaluate how successfully the image has been generated. +""" + +_prompts_0shot_subject_image_gen_rule = """RULES: + +Two images will be provided: The first being a token subject image and the second being an AI-generated image using the first image as guidance. +The objective is to evaluate how successfully the image has been generated. +""" + +_prompts_0shot_one_image_gen_rule = """RULES: + +The image is an AI-generated image according to the text prompt. +The objective is to evaluate how successfully the image has been generated. +""" From 55b44cd69873dfa57b720767ed3c21b7740cbcd5 Mon Sep 17 00:00:00 2001 From: KemingWu Date: Thu, 4 Dec 2025 11:45:34 +0000 Subject: [PATCH 2/3] fix videomme --- lmms_eval/tasks/videomme/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lmms_eval/tasks/videomme/utils.py b/lmms_eval/tasks/videomme/utils.py index 36e529dc4..b205af172 100644 --- a/lmms_eval/tasks/videomme/utils.py +++ b/lmms_eval/tasks/videomme/utils.py @@ -9,11 +9,11 @@ from typing import Dict, List, Optional, Union import cv2 +import datasets import numpy as np import yaml from loguru import logger as eval_logger -import datasets from lmms_eval.tasks._task_utils.file_utils import generate_submission_file VIDEO_TYPE = ["short", "medium", "long"] From 121e21916a5db98bfcfce8ffdf8e65da7edb2a72 Mon Sep 17 00:00:00 2001 From: KemingWu Date: Fri, 5 Dec 2025 06:16:23 +0000 Subject: [PATCH 3/3] add task imgedit for bagel --- lmms_eval/models/simple/bagel.py | 14 +- lmms_eval/tasks/gedit_bench/gedit_bench.yaml | 76 +- lmms_eval/tasks/gedit_bench/utils.py | 311 ++--- lmms_eval/tasks/imgedit/__init__.py | 3 + lmms_eval/tasks/imgedit/imgedit.yaml | 134 ++ lmms_eval/tasks/imgedit/prepare_dataset.py | 200 +++ lmms_eval/tasks/imgedit/utils.py | 1185 ++++++++++++++++++ 7 files changed, 1731 insertions(+), 192 deletions(-) create mode 100644 lmms_eval/tasks/imgedit/__init__.py create mode 100644 lmms_eval/tasks/imgedit/imgedit.yaml create mode 100644 lmms_eval/tasks/imgedit/prepare_dataset.py create mode 100644 lmms_eval/tasks/imgedit/utils.py diff --git a/lmms_eval/models/simple/bagel.py b/lmms_eval/models/simple/bagel.py index 82df1acc4..a2c8d2339 100644 --- a/lmms_eval/models/simple/bagel.py +++ b/lmms_eval/models/simple/bagel.py @@ -234,10 +234,18 @@ def __init__( else: self.response_persistent_folder = response_persistent_folder - if output_image_dir is None: - self.output_image_dir = os.path.join(self.response_persistent_folder, "bagel_generated_images") - else: + # Check for task-specific output directory from environment variables + # Priority: output_image_dir param > IMGEDIT_OUTPUT_DIR > GEDIT_BENCH_OUTPUT_DIR > default + if output_image_dir is not None: self.output_image_dir = output_image_dir + elif os.getenv("IMGEDIT_OUTPUT_DIR"): + self.output_image_dir = os.getenv("IMGEDIT_OUTPUT_DIR") + eval_logger.info(f"Using IMGEDIT_OUTPUT_DIR: {self.output_image_dir}") + elif os.getenv("GEDIT_BENCH_OUTPUT_DIR"): + self.output_image_dir = os.getenv("GEDIT_BENCH_OUTPUT_DIR") + eval_logger.info(f"Using GEDIT_BENCH_OUTPUT_DIR: {self.output_image_dir}") + else: + self.output_image_dir = os.path.join(self.response_persistent_folder, "bagel_generated_images") os.makedirs(self.output_image_dir, exist_ok=True) eval_logger.info(f"Image output directory: {self.output_image_dir}") diff --git a/lmms_eval/tasks/gedit_bench/gedit_bench.yaml b/lmms_eval/tasks/gedit_bench/gedit_bench.yaml index 141b0db8c..17e229514 100644 --- a/lmms_eval/tasks/gedit_bench/gedit_bench.yaml +++ b/lmms_eval/tasks/gedit_bench/gedit_bench.yaml @@ -17,15 +17,17 @@ generation_kwargs: # The return value of process_results will be used by metrics process_results: !function utils.gedit_bench_process_results +# ============================================ # Metrics breakdown: # - Overall scores (all samples) # - By language: en (English), cn (Chinese) # - By subset: fullset (all), intersection (subset with both en and cn) # - Score types: semantics, quality, overall +# ============================================ metric_list: # ========================================== - # Overall Scores (All Samples) + # Overall Scores (All Samples) - Global Average # ========================================== - metric: gedit_bench_semantics_score aggregation: !function utils.gedit_bench_aggregate_results @@ -38,73 +40,61 @@ metric_list: higher_is_better: true # ========================================== - # English - Fullset + # English - Fullset (All English samples) # ========================================== - - metric: gedit_bench_semantics_score - aggregation: !function utils.gedit_bench_aggregate_en_fullset + - metric: gedit_bench_en_fullset_semantics + aggregation: !function utils.gedit_bench_aggregate_en_fullset_semantics higher_is_better: true - metric_name: gedit_bench_en_fullset_semantics - - metric: gedit_bench_quality_score - aggregation: !function utils.gedit_bench_aggregate_en_fullset + - metric: gedit_bench_en_fullset_quality + aggregation: !function utils.gedit_bench_aggregate_en_fullset_quality higher_is_better: true - metric_name: gedit_bench_en_fullset_quality - - metric: gedit_bench_overall_score - aggregation: !function utils.gedit_bench_aggregate_en_fullset + - metric: gedit_bench_en_fullset_overall + aggregation: !function utils.gedit_bench_aggregate_en_fullset_overall higher_is_better: true - metric_name: gedit_bench_en_fullset_overall # ========================================== - # English - Intersection + # English - Intersection (English samples in intersection set) # ========================================== - - metric: gedit_bench_semantics_score - aggregation: !function utils.gedit_bench_aggregate_en_intersection + - metric: gedit_bench_en_intersection_semantics + aggregation: !function utils.gedit_bench_aggregate_en_intersection_semantics higher_is_better: true - metric_name: gedit_bench_en_intersection_semantics - - metric: gedit_bench_quality_score - aggregation: !function utils.gedit_bench_aggregate_en_intersection + - metric: gedit_bench_en_intersection_quality + aggregation: !function utils.gedit_bench_aggregate_en_intersection_quality higher_is_better: true - metric_name: gedit_bench_en_intersection_quality - - metric: gedit_bench_overall_score - aggregation: !function utils.gedit_bench_aggregate_en_intersection + - metric: gedit_bench_en_intersection_overall + aggregation: !function utils.gedit_bench_aggregate_en_intersection_overall higher_is_better: true - metric_name: gedit_bench_en_intersection_overall # ========================================== - # Chinese - Fullset + # Chinese - Fullset (All Chinese samples) # ========================================== - - metric: gedit_bench_semantics_score - aggregation: !function utils.gedit_bench_aggregate_cn_fullset + - metric: gedit_bench_cn_fullset_semantics + aggregation: !function utils.gedit_bench_aggregate_cn_fullset_semantics higher_is_better: true - metric_name: gedit_bench_cn_fullset_semantics - - metric: gedit_bench_quality_score - aggregation: !function utils.gedit_bench_aggregate_cn_fullset + - metric: gedit_bench_cn_fullset_quality + aggregation: !function utils.gedit_bench_aggregate_cn_fullset_quality higher_is_better: true - metric_name: gedit_bench_cn_fullset_quality - - metric: gedit_bench_overall_score - aggregation: !function utils.gedit_bench_aggregate_cn_fullset + - metric: gedit_bench_cn_fullset_overall + aggregation: !function utils.gedit_bench_aggregate_cn_fullset_overall higher_is_better: true - metric_name: gedit_bench_cn_fullset_overall # ========================================== - # Chinese - Intersection + # Chinese - Intersection (Chinese samples in intersection set) # ========================================== - - metric: gedit_bench_semantics_score - aggregation: !function utils.gedit_bench_aggregate_cn_intersection + - metric: gedit_bench_cn_intersection_semantics + aggregation: !function utils.gedit_bench_aggregate_cn_intersection_semantics higher_is_better: true - metric_name: gedit_bench_cn_intersection_semantics - - metric: gedit_bench_quality_score - aggregation: !function utils.gedit_bench_aggregate_cn_intersection + - metric: gedit_bench_cn_intersection_quality + aggregation: !function utils.gedit_bench_aggregate_cn_intersection_quality higher_is_better: true - metric_name: gedit_bench_cn_intersection_quality - - metric: gedit_bench_overall_score - aggregation: !function utils.gedit_bench_aggregate_cn_intersection + - metric: gedit_bench_cn_intersection_overall + aggregation: !function utils.gedit_bench_aggregate_cn_intersection_overall higher_is_better: true - metric_name: gedit_bench_cn_intersection_overall lmms_eval_specific_kwargs: default: pre_prompt: "" post_prompt: "" metadata: - - version: 0.1 - description: "GEdit-Bench with detailed metrics by language and subset" + - version: 0.2 + description: "GEdit-Bench with detailed metrics by language (en/cn) and subset (fullset/intersection)" diff --git a/lmms_eval/tasks/gedit_bench/utils.py b/lmms_eval/tasks/gedit_bench/utils.py index b5664315a..c88d0eab6 100644 --- a/lmms_eval/tasks/gedit_bench/utils.py +++ b/lmms_eval/tasks/gedit_bench/utils.py @@ -128,6 +128,57 @@ def _save_image_to_structure( return "" +def _create_result_entry(key, task_type, instruction_language, score, intersection_exist): + """Helper to create a result entry dict""" + return { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "score": score, + "intersection_exist": intersection_exist, + } + + +def _create_all_metric_results(key, task_type, instruction_language, semantics_score, quality_score, overall_score, intersection_exist): + """ + Create result dict with all metric keys for detailed breakdown. + + Returns metrics for: + - Overall scores (all samples) + - English fullset and intersection + - Chinese fullset and intersection + """ + base_entry = { + "key": key, + "task_type": task_type, + "instruction_language": instruction_language, + "intersection_exist": intersection_exist, + } + + return { + # Overall scores (used for global aggregation) + "gedit_bench_semantics_score": {**base_entry, "score": semantics_score}, + "gedit_bench_quality_score": {**base_entry, "score": quality_score}, + "gedit_bench_overall_score": {**base_entry, "score": overall_score}, + # English fullset metrics + "gedit_bench_en_fullset_semantics": {**base_entry, "score": semantics_score}, + "gedit_bench_en_fullset_quality": {**base_entry, "score": quality_score}, + "gedit_bench_en_fullset_overall": {**base_entry, "score": overall_score}, + # English intersection metrics + "gedit_bench_en_intersection_semantics": {**base_entry, "score": semantics_score}, + "gedit_bench_en_intersection_quality": {**base_entry, "score": quality_score}, + "gedit_bench_en_intersection_overall": {**base_entry, "score": overall_score}, + # Chinese fullset metrics + "gedit_bench_cn_fullset_semantics": {**base_entry, "score": semantics_score}, + "gedit_bench_cn_fullset_quality": {**base_entry, "score": quality_score}, + "gedit_bench_cn_fullset_overall": {**base_entry, "score": overall_score}, + # Chinese intersection metrics + "gedit_bench_cn_intersection_semantics": {**base_entry, "score": semantics_score}, + "gedit_bench_cn_intersection_quality": {**base_entry, "score": quality_score}, + "gedit_bench_cn_intersection_overall": {**base_entry, "score": overall_score}, + } + + def gedit_bench_process_results(doc, results, **kwargs): """ Process model predictions: @@ -141,7 +192,7 @@ def gedit_bench_process_results(doc, results, **kwargs): **kwargs: Additional arguments (may include full_docs) Returns: - Dict with metrics: semantics_score, quality_score, overall_score + Dict with metrics for all breakdown categories """ # Get configuration from environment variables or use defaults # Note: defaults should match bagel.py's output_image_dir structure @@ -186,29 +237,7 @@ def gedit_bench_process_results(doc, results, **kwargs): if input_image_pil is None: eval_logger.warning(f"No input image found for key {key} (neither in doc nor as _SRCIMG file)") - return { - "gedit_bench_semantics_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - "gedit_bench_quality_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - "gedit_bench_overall_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - } + return _create_all_metric_results(key, task_type, instruction_language, 0.0, 0.0, 0.0, intersection_exist) # Save generated images to required structure (or use existing) edited_image_path = None @@ -242,56 +271,12 @@ def gedit_bench_process_results(doc, results, **kwargs): # If still no edited image, return zero scores if edited_image_path is None: eval_logger.warning(f"No generated images found for key {key}") - return { - "gedit_bench_semantics_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - "gedit_bench_quality_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - "gedit_bench_overall_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - } + return _create_all_metric_results(key, task_type, instruction_language, 0.0, 0.0, 0.0, intersection_exist) # Evaluate using VIEScore if not VIESCORE_AVAILABLE: eval_logger.warning("VIEScore not available, skipping evaluation") - return { - "gedit_bench_semantics_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - "gedit_bench_quality_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - "gedit_bench_overall_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - } + return _create_all_metric_results(key, task_type, instruction_language, 0.0, 0.0, 0.0, intersection_exist) try: # Load edited image @@ -311,61 +296,18 @@ def gedit_bench_process_results(doc, results, **kwargs): score_list = vie_score.evaluate([input_image_pil, edited_image_pil], instruction) semantics_score, quality_score, overall_score = score_list - eval_logger.info(f"[{task_type}] Key {key}: " f"Semantics={semantics_score:.3f}, " f"Quality={quality_score:.3f}, " f"Overall={overall_score:.3f}, " f"Language={instruction_language}") - - return { - "gedit_bench_semantics_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": float(semantics_score), - "intersection_exist": intersection_exist, - }, - "gedit_bench_quality_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": float(quality_score), - "intersection_exist": intersection_exist, - }, - "gedit_bench_overall_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": float(overall_score), - "intersection_exist": intersection_exist, - }, - } + eval_logger.info(f"[{task_type}] Key {key}: " f"Semantics={semantics_score:.3f}, " f"Quality={quality_score:.3f}, " f"Overall={overall_score:.3f}, " f"Language={instruction_language}, " f"Intersection={intersection_exist}") + + return _create_all_metric_results(key, task_type, instruction_language, float(semantics_score), float(quality_score), float(overall_score), intersection_exist) except Exception as e: eval_logger.error(f"Error evaluating key {key}: {e}") - return { - "gedit_bench_semantics_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - "gedit_bench_quality_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - "gedit_bench_overall_score": { - "key": key, - "task_type": task_type, - "instruction_language": instruction_language, - "score": 0.0, - "intersection_exist": intersection_exist, - }, - } + return _create_all_metric_results(key, task_type, instruction_language, 0.0, 0.0, 0.0, intersection_exist) def gedit_bench_aggregate_results(results): """ - Aggregate results across all samples and compute final scores + Aggregate results across all samples and compute final scores. + Also logs detailed breakdown by task type, language, and intersection status. Args: results: List of result dicts from process_results, each containing: @@ -409,34 +351,34 @@ def gedit_bench_aggregate_results(results): non_intersection_scores.append(r["score"]) # Log statistics - eval_logger.info(f"Overall average score: {avg_score:.3f}") + eval_logger.info(f"Overall average score: {avg_score:.4f}") eval_logger.info(f"Number of samples: {len(scores)}") if task_type_scores: eval_logger.info("Scores by task type:") for task_type, task_scores in sorted(task_type_scores.items()): task_avg = np.mean(task_scores) - eval_logger.info(f" {task_type}: {task_avg:.3f} (n={len(task_scores)})") + eval_logger.info(f" {task_type}: {task_avg:.4f} (n={len(task_scores)})") if language_scores: eval_logger.info("Scores by language:") for language, lang_scores in sorted(language_scores.items()): lang_avg = np.mean(lang_scores) - eval_logger.info(f" {language}: {lang_avg:.3f} (n={len(lang_scores)})") + eval_logger.info(f" {language}: {lang_avg:.4f} (n={len(lang_scores)})") if intersection_scores: intersection_avg = np.mean(intersection_scores) - eval_logger.info(f"Intersection samples average: {intersection_avg:.3f} (n={len(intersection_scores)})") + eval_logger.info(f"Intersection samples average: {intersection_avg:.4f} (n={len(intersection_scores)})") if non_intersection_scores: non_intersection_avg = np.mean(non_intersection_scores) - eval_logger.info(f"Non-intersection samples average: {non_intersection_avg:.3f} (n={len(non_intersection_scores)})") + eval_logger.info(f"Non-intersection samples average: {non_intersection_avg:.4f} (n={len(non_intersection_scores)})") - return avg_score + return float(avg_score) # ============================================ -# Detailed Aggregation Functions by Language and Subset +# Helper Function for Filtered Aggregation # ============================================ @@ -447,7 +389,7 @@ def _aggregate_by_filter(results, language: str = None, intersection_only: bool Args: results: List of result dicts language: Filter by language ("en" or "cn"), None for all - intersection_only: True for intersection subset, False for non-intersection, None for all + intersection_only: True for intersection subset only, None for all (fullset) Returns: Average score for filtered samples @@ -465,12 +407,12 @@ def _aggregate_by_filter(results, language: str = None, intersection_only: bool if r.get("instruction_language", "unknown") != language: continue - # Apply intersection filter - if intersection_only is not None: + # Apply intersection filter (only filter if intersection_only is True) + # intersection_only=None means fullset (all samples of that language) + # intersection_only=True means only intersection samples + if intersection_only is True: is_intersection = r.get("intersection_exist", False) - if intersection_only and not is_intersection: - continue - if not intersection_only and is_intersection: + if not is_intersection: continue filtered_scores.append(r["score"]) @@ -478,12 +420,99 @@ def _aggregate_by_filter(results, language: str = None, intersection_only: bool if not filtered_scores: return 0.0 - return float(np.mean(filtered_scores)) + avg = float(np.mean(filtered_scores)) + + # Log filter info + lang_str = language if language else "all" + subset_str = "intersection" if intersection_only else "fullset" + eval_logger.debug(f"Aggregating {lang_str} {subset_str}: {avg:.4f} (n={len(filtered_scores)})") + + return avg + + +# ============================================ +# English - Fullset Aggregations +# ============================================ + + +def gedit_bench_aggregate_en_fullset_semantics(results): + """Aggregate English fullset semantics scores""" + return _aggregate_by_filter(results, language="en", intersection_only=None) + + +def gedit_bench_aggregate_en_fullset_quality(results): + """Aggregate English fullset quality scores""" + return _aggregate_by_filter(results, language="en", intersection_only=None) -# ========================================== -# English Language Aggregations -# ========================================== +def gedit_bench_aggregate_en_fullset_overall(results): + """Aggregate English fullset overall scores""" + return _aggregate_by_filter(results, language="en", intersection_only=None) + + +# ============================================ +# English - Intersection Aggregations +# ============================================ + + +def gedit_bench_aggregate_en_intersection_semantics(results): + """Aggregate English intersection semantics scores""" + return _aggregate_by_filter(results, language="en", intersection_only=True) + + +def gedit_bench_aggregate_en_intersection_quality(results): + """Aggregate English intersection quality scores""" + return _aggregate_by_filter(results, language="en", intersection_only=True) + + +def gedit_bench_aggregate_en_intersection_overall(results): + """Aggregate English intersection overall scores""" + return _aggregate_by_filter(results, language="en", intersection_only=True) + + +# ============================================ +# Chinese - Fullset Aggregations +# ============================================ + + +def gedit_bench_aggregate_cn_fullset_semantics(results): + """Aggregate Chinese fullset semantics scores""" + return _aggregate_by_filter(results, language="cn", intersection_only=None) + + +def gedit_bench_aggregate_cn_fullset_quality(results): + """Aggregate Chinese fullset quality scores""" + return _aggregate_by_filter(results, language="cn", intersection_only=None) + + +def gedit_bench_aggregate_cn_fullset_overall(results): + """Aggregate Chinese fullset overall scores""" + return _aggregate_by_filter(results, language="cn", intersection_only=None) + + +# ============================================ +# Chinese - Intersection Aggregations +# ============================================ + + +def gedit_bench_aggregate_cn_intersection_semantics(results): + """Aggregate Chinese intersection semantics scores""" + return _aggregate_by_filter(results, language="cn", intersection_only=True) + + +def gedit_bench_aggregate_cn_intersection_quality(results): + """Aggregate Chinese intersection quality scores""" + return _aggregate_by_filter(results, language="cn", intersection_only=True) + + +def gedit_bench_aggregate_cn_intersection_overall(results): + """Aggregate Chinese intersection overall scores""" + return _aggregate_by_filter(results, language="cn", intersection_only=True) + + +# ============================================ +# Legacy Functions (for backward compatibility) +# ============================================ def gedit_bench_aggregate_en_fullset(results): @@ -496,11 +525,6 @@ def gedit_bench_aggregate_en_intersection(results): return _aggregate_by_filter(results, language="en", intersection_only=True) -# ========================================== -# Chinese Language Aggregations -# ========================================== - - def gedit_bench_aggregate_cn_fullset(results): """Aggregate Chinese fullset scores (all Chinese samples)""" return _aggregate_by_filter(results, language="cn", intersection_only=None) @@ -511,11 +535,6 @@ def gedit_bench_aggregate_cn_intersection(results): return _aggregate_by_filter(results, language="cn", intersection_only=True) -# ========================================== -# Intersection/Non-intersection Aggregations (All Languages) -# ========================================== - - def gedit_bench_aggregate_intersection(results): """Aggregate intersection subset scores (all languages)""" return _aggregate_by_filter(results, language=None, intersection_only=True) diff --git a/lmms_eval/tasks/imgedit/__init__.py b/lmms_eval/tasks/imgedit/__init__.py new file mode 100644 index 000000000..653769f18 --- /dev/null +++ b/lmms_eval/tasks/imgedit/__init__.py @@ -0,0 +1,3 @@ +# ImgEdit Benchmark Task +# Paper: ImgEdit: A Unified Image Editing Benchmark +# https://github.com/sysuyy/ImgEdit diff --git a/lmms_eval/tasks/imgedit/imgedit.yaml b/lmms_eval/tasks/imgedit/imgedit.yaml new file mode 100644 index 000000000..278001137 --- /dev/null +++ b/lmms_eval/tasks/imgedit/imgedit.yaml @@ -0,0 +1,134 @@ +# ImgEdit Benchmark Configuration +# Paper: ImgEdit: A Unified Image Editing Benchmark +# https://github.com/sysuyy/ImgEdit +# +# Dataset preparation: +# 1. Download Benchmark.tar from https://huggingface.co/datasets/sysuyy/ImgEdit/blob/main/Benchmark.tar +# 2. Extract: tar -xvf Benchmark.tar +# 3. Run prepare_dataset.py to convert JSON to HuggingFace dataset format: +# python lmms_eval/tasks/imgedit/prepare_dataset.py \ +# --json_file /path/to/singleturn.json \ +# --img_root /path/to/singleturn/ \ +# --output_dir /path/to/output/dataset +# +# Environment variables: +# - IMGEDIT_MODEL_NAME: Name of the model being evaluated (default: "default") +# - IMGEDIT_OUTPUT_DIR: Directory to save generated images (default: "./logs/imgedit_results") +# - IMGEDIT_ORIGIN_IMG_ROOT: Root directory of original images (for runtime loading) +# - IMGEDIT_EVAL_BACKBONE: Evaluation model - "gpt4o" (default) or "qwen25vl" +# - OPENAI_API_KEY: OpenAI API key (for GPT-4o) +# - OPENAI_BASE_URL: Optional custom OpenAI API base URL +# - QWEN_MODEL_PATH: Path to Qwen2.5-VL model (for qwen25vl backend) + +# Local dataset path (after running prepare_dataset.py) +dataset_path: sysuyy/ImgEdit +dataset_kwargs: + load_from_disk: True +task: "imgedit" +test_split: train # The prepared dataset uses "train" split +output_type: generate_until + +# Document processing functions +doc_to_visual: !function utils.imgedit_doc_to_visual +doc_to_text: !function utils.imgedit_doc_to_text +doc_to_target: "prompt" + +# Generation parameters +generation_kwargs: + max_new_tokens: 512 + temperature: 0 + top_p: 1.0 + num_beams: 1 + do_sample: false + +# Process results using GPT-4o or Qwen2.5-VL evaluation (controlled by IMGEDIT_EVAL_BACKBONE) +process_results: !function utils.imgedit_process_results + +# Metrics +# ========================================== +# Overall metrics (aggregated across all edit types) +# ========================================== +# Score1: Prompt Compliance / Style Fidelity / etc. (depends on edit type) +# Score2: Visual Naturalness / Content Preservation / etc. +# Score3: Physical & Detail Integrity / Rendering Quality / etc. +# avg_score: Average of Score1, Score2, Score3 +metric_list: + # Overall scores + - metric: imgedit_score1 + aggregation: !function utils.imgedit_aggregate_results + higher_is_better: true + - metric: imgedit_score2 + aggregation: !function utils.imgedit_aggregate_results + higher_is_better: true + - metric: imgedit_score3 + aggregation: !function utils.imgedit_aggregate_results + higher_is_better: true + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_results + higher_is_better: true + + # ========================================== + # Per-type average scores + # ========================================== + # Replace type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_replace + higher_is_better: true + metric_name: imgedit_replace_score + + # Add type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_add + higher_is_better: true + metric_name: imgedit_add_score + + # Adjust type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_adjust + higher_is_better: true + metric_name: imgedit_adjust_score + + # Remove type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_remove + higher_is_better: true + metric_name: imgedit_remove_score + + # Style type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_style + higher_is_better: true + metric_name: imgedit_style_score + + # Action type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_action + higher_is_better: true + metric_name: imgedit_action_score + + # Extract type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_extract + higher_is_better: true + metric_name: imgedit_extract_score + + # Background type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_background + higher_is_better: true + metric_name: imgedit_background_score + + # Compose type + - metric: imgedit_avg_score + aggregation: !function utils.imgedit_aggregate_compose + higher_is_better: true + metric_name: imgedit_compose_score + +lmms_eval_specific_kwargs: + default: + pre_prompt: "" + post_prompt: "" + +metadata: + - version: 0.2 + description: "ImgEdit Benchmark - Unified Image Editing Evaluation using GPT-4o or Qwen2.5-VL" diff --git a/lmms_eval/tasks/imgedit/prepare_dataset.py b/lmms_eval/tasks/imgedit/prepare_dataset.py new file mode 100644 index 000000000..ce345603f --- /dev/null +++ b/lmms_eval/tasks/imgedit/prepare_dataset.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +""" +Prepare ImgEdit dataset for lmms-eval + +This script converts the ImgEdit singleturn.json format to a Hugging Face datasets +compatible format that can be loaded by lmms-eval. + +Usage: + python prepare_dataset.py \ + --json_file /path/to/singleturn.json \ + --img_root /path/to/singleturn/ \ + --output_dir /path/to/output/dataset + +The output dataset will be saved in Arrow format and can be loaded with: + datasets.load_from_disk(output_dir) +""" + +import argparse +import json +import os +from pathlib import Path + +from datasets import Dataset, Features, Image, Value +from PIL import Image as PILImage +from tqdm import tqdm + + +def load_singleturn_json(json_path: str) -> dict: + """Load the singleturn.json file""" + with open(json_path, "r") as f: + return json.load(f) + + +def convert_to_dataset_format( + singleturn_data: dict, + img_root: str, + verify_images: bool = True, +) -> list: + """ + Convert singleturn.json format to dataset format. + + Args: + singleturn_data: Dict loaded from singleturn.json + img_root: Root directory containing images (e.g., .../singleturn/) + verify_images: Whether to verify that images exist + + Returns: + List of dicts suitable for creating a HF Dataset + """ + records = [] + missing_images = [] + + for key, item in tqdm(singleturn_data.items(), desc="Converting"): + image_id = item.get("id", "") # e.g., "animal/000342021.jpg" + prompt = item.get("prompt", "") + edit_type = item.get("edit_type", "adjust") + + # Full path to original image + image_path = os.path.join(img_root, image_id) + + if verify_images and not os.path.exists(image_path): + missing_images.append(image_path) + continue + + record = { + "key": key, + "id": image_id, + "prompt": prompt, + "edit_type": edit_type, + "image_path": image_path, + } + + # Try to load image if verification is enabled + if verify_images: + try: + img = PILImage.open(image_path).convert("RGB") + record["input_image"] = img + except Exception as e: + print(f"Warning: Failed to load image {image_path}: {e}") + missing_images.append(image_path) + continue + + records.append(record) + + if missing_images: + print(f"\nWarning: {len(missing_images)} images not found:") + for path in missing_images[:10]: + print(f" - {path}") + if len(missing_images) > 10: + print(f" ... and {len(missing_images) - 10} more") + + return records + + +def create_dataset(records: list, include_images: bool = True) -> Dataset: + """Create a Hugging Face Dataset from records""" + if include_images: + # Dataset with embedded images + features = Features( + { + "key": Value("string"), + "id": Value("string"), + "prompt": Value("string"), + "edit_type": Value("string"), + "image_path": Value("string"), + "input_image": Image(), + } + ) + else: + # Dataset with only paths (images loaded at runtime) + # Remove input_image from records if present + for record in records: + if "input_image" in record: + del record["input_image"] + + features = Features( + { + "key": Value("string"), + "id": Value("string"), + "prompt": Value("string"), + "edit_type": Value("string"), + "image_path": Value("string"), + } + ) + + return Dataset.from_list(records, features=features) + + +def main(): + parser = argparse.ArgumentParser(description="Prepare ImgEdit dataset for lmms-eval") + parser.add_argument( + "--json_file", + type=str, + default="ImgEdit/Benchmark/singleturn/singleturn.json", + help="Path to singleturn.json file", + ) + parser.add_argument( + "--img_root", + type=str, + default="ImgEdit/Benchmark/singleturn", + help="Root directory containing images", + ) + parser.add_argument( + "--output_dir", + type=str, + default="ImgEdit/imgedit_dataset", + help="Output directory for the prepared dataset", + ) + parser.add_argument( + "--embed_images", + action="store_true", + help="Embed images in the dataset (larger file size but faster loading)", + ) + parser.add_argument( + "--skip_verification", + action="store_true", + help="Skip image verification (useful if images are not yet downloaded)", + ) + args = parser.parse_args() + + print(f"Loading singleturn.json from: {args.json_file}") + singleturn_data = load_singleturn_json(args.json_file) + print(f"Found {len(singleturn_data)} entries") + + print(f"\nConverting to dataset format...") + print(f"Image root: {args.img_root}") + records = convert_to_dataset_format( + singleturn_data, + args.img_root, + verify_images=not args.skip_verification, + ) + print(f"Converted {len(records)} valid entries") + + print(f"\nCreating HuggingFace Dataset...") + dataset = create_dataset(records, include_images=args.embed_images) + + print(f"\nSaving dataset to: {args.output_dir}") + os.makedirs(args.output_dir, exist_ok=True) + dataset.save_to_disk(args.output_dir) + + print(f"\nDone! Dataset saved with {len(dataset)} samples") + print(f"\nTo use with lmms-eval, update imgedit.yaml:") + print(f" dataset_path: {args.output_dir}") + print(f" dataset_kwargs:") + print(f" load_from_disk: True") + print(f" test_split: train") + + # Print edit type distribution + edit_types = {} + for record in records: + edit_type = record.get("edit_type", "unknown") + edit_types[edit_type] = edit_types.get(edit_type, 0) + 1 + + print(f"\nEdit type distribution:") + for edit_type, count in sorted(edit_types.items()): + print(f" {edit_type}: {count}") + + +if __name__ == "__main__": + main() diff --git a/lmms_eval/tasks/imgedit/utils.py b/lmms_eval/tasks/imgedit/utils.py new file mode 100644 index 000000000..72d10155e --- /dev/null +++ b/lmms_eval/tasks/imgedit/utils.py @@ -0,0 +1,1185 @@ +""" +ImgEdit Benchmark Utils +Image editing evaluation task using GPT-4o or Qwen2.5-VL + +Based on: https://github.com/sysuyy/ImgEdit +Paper: ImgEdit: A Unified Image Editing Benchmark + +Environment variables: + - IMGEDIT_EVAL_BACKBONE: "gpt4o" or "qwen25vl" (default: "gpt4o") + - IMGEDIT_MODEL_NAME: Name of the model being evaluated + - IMGEDIT_OUTPUT_DIR: Directory to save generated images + - IMGEDIT_ORIGIN_IMG_ROOT: Root directory of original images + - OPENAI_API_KEY: OpenAI API key (for GPT-4o) + - OPENAI_BASE_URL: Optional custom OpenAI API base URL + - QWEN_MODEL_PATH: Path to Qwen2.5-VL model (default: Qwen/Qwen2.5-VL-72B-Instruct-AWQ) +""" + +import base64 +import json +import os +import re +from collections import defaultdict +from io import BytesIO +from typing import Dict, List, Optional, Tuple + +import numpy as np +from loguru import logger as eval_logger +from PIL import Image + +# Global variable to cache Qwen2.5-VL model (lazy loading) +_qwen25vl_model = None + +# Edit type prompts for evaluation +# These are the evaluation criteria for different edit types +IMGEDIT_PROMPTS = { + "replace": """ +You are a data rater specializing in grading image replacement edits. You will be given two images (before and after editing) and the corresponding editing instructions. Your task is to evaluate the replacement editing effect on a 5-point scale from three perspectives: + +Prompt Compliance +1 Target not replaced, or an unrelated object edited. +2 Only part of the target replaced, or wrong class/description used. +3 Target largely replaced but other objects altered, remnants visible, or count/position clearly wrong. +4 Correct object fully replaced; only minor attribute errors (colour, size, etc.). +5 Perfect replacement: all and only the specified objects removed; new objects' class, number, position, scale, pose and detail exactly match the prompt. + +Visual Naturalness +1 Image heavily broken or new object deformed / extremely blurred. +2 Obvious seams, smears, or strong mismatch in resolution or colour; background not restored. +3 Basic style similar, but lighting or palette clashes; fuzzy edges or noise are noticeable. +4 Style almost uniform; tiny edge artefacts visible only on close inspection; casual viewers see no edit. +5 Completely seamless; new objects blend fully with the scene, edit area undetectable. + +Physical & Detail Integrity +1 Floating, interpenetration, severe perspective/light errors; key original elements ruined; background heavily warped. +2 Missing shadows/occlusion; large background shifts or holes. +3 Lighting, perspective and contact surfaces mostly correct; small but tolerable errors; background adjusted locally. +4 New objects interact realistically with scene (shadows, reflections, texture) and preserve existing details; background change minimal. +5 Physically flawless and enhances realism: accurate highlights, shadows, reflections, ambient effects; background untouched. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Prompt Compliance: A number from 1 to 5. +Visual Naturalness: A number from 1 to 5. +Physical & Detail Integrity: A number from 1 to 5. +editing instruction is : . + +Below are the images before and after editing: +""", + "add": """ +You are a data rater specializing in grading image addition edits. You will be given two images (before and after editing) and the corresponding editing instructions. Your task is to evaluate the added object(s) on a 5-point scale from three perspectives: + +Prompt Compliance +1 Nothing added or the added content is corrupt. +2 Added object is a wrong class or unrelated to the prompt. +3 Correct class, but key attributes (position, colour, size, count, etc.) are wrong. +4 Main attributes correct; only minor details off or 1-2 small features missing. +5 Every stated attribute correct and scene logic reasonable; only microscopic flaws. + +Visual Naturalness +1 Image badly broken or full of artefacts. +2 Obvious paste marks; style, resolution, or palette strongly mismatch. +3 General style similar, but lighting or colours clearly clash; noticeable disharmony. +4 Style almost uniform; small edge issues visible only when zoomed. +5 Perfect blend; no visible difference between added object and original image. + +Physical & Detail Coherence +1 Severe physical errors (floating, wrong perspective/light); key original elements blocked; background heavily distorted. +2 Contact or occlusion handled poorly; minor background shifts, jaggies or noise; background visibly changed. +3 Lighting, perspective, and contact mostly correct; remaining flaws small and acceptable; limited background change. +4 Shadows, reflections, and material response believable; no loss of original detail; background changes are minute. +5 Added object enhances overall realism: precise highlights, shadows, ambient effects; background essentially untouched. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Prompt Compliance: A number from 1 to 5. +Visual Naturalness: A number from 1 to 5. +Physical & Detail Coherence: A number from 1 to 5. +editing instruction is : . + +Below are the images before and after editing: +""", + "adjust": """ +You are a data rater specializing in grading attribute alteration edits. You will be given two images (before and after editing) and the corresponding editing instructions. Your task is to evaluate the attribute change on a 5-point scale from three perspectives: + +Prompt Compliance +1 Target not adjusted, wrong object touched, or geometry changed. +2 Right object but wrong attribute value/direction; only part edited; other objects also altered; slight stretch/crop. +3 Mainly correct object and attribute, yet large hue/brightness/texture error; minor collateral edits; visible jaggies/distortion. +4 All requested objects adjusted, only their attributes changed; shape kept; small inaccuracy in colour, material or amount. +5 Exactly and only the requested objects adjusted; colour, material, gloss etc. match the prompt perfectly; shape 100% intact; zero unintended edits. + +Visual Seamlessness +1 Massive colour spill, mosaics or heavy noise; image nearly unusable. +2 Clear smears/bleeding on edges; abrupt resolution or tone shift; highlights/shadows clipped; background gaps. +3 Overall palette OK but local tone or grain conflicts; soft edges; noticeable disharmony. +4 Style unified, transitions smooth; only slight edge artefacts visible when zoomed. +5 No detectable edit traces; colours/materials fuse with scene lighting; edit area practically invisible. + +Physical & Detail Fidelity +1 Object floating, interpenetrating, or severe perspective/light mismatch; background badly warped. +2 Missing shadows/highlights; wrong reflection direction; background visibly discoloured or distorted. +3 Light, perspective and contact surface largely correct; minor acceptable flaws; background only locally affected. +4 Adjusted material interacts believably with scene; shadows, highlights, reflections handled well; original details preserved. +5 High physical realism: fine micro-highlights, diffuse bounce, subsurface effects present; overall scene realism improved. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Prompt Compliance: A number from 1 to 5. +Visual Seamlessness: A number from 1 to 5. +Physical & Detail Fidelity: A number from 1 to 5. +editing instruction is : . + +Below are the images before and after editing: +""", + "remove": """ +You are a data rater specializing in grading object removal edits. You will be given two images (before and after editing) and the corresponding editing instructions. Your task is to evaluate the removal quality on a 5-point scale from three perspectives: + +Prompt Compliance +1 Nothing removed, or an unrelated object edited. +2 Target only partly removed, or a different instance/class deleted, or another object appears in the gap. +3 Target mostly removed but extra objects also deleted, or fragments of the target remain. +4 Only the specified objects removed, but a few tiny/background items deleted by mistake, or the count is wrong. +5 Perfect: all and only the requested objects removed; every other element untouched. + +Visual Naturalness +1 Image badly broken (large holes, strong artefacts). +2 Clear erase marks; colour/resolution mismatch; background not restored. +3 General look acceptable yet lighting/colour/style still clash; blur or noise visible. +4 Style consistent; minor edge issues visible only when zoomed. +5 Seamless: removal is virtually impossible to spot. + +Physical & Detail Integrity +1 Severe physical errors (floating items, wrong perspective/light); key scene elements damaged; background heavily warped. +2 Large un-filled gaps or obvious background shifts. +3 Lighting, perspective and contacts mostly correct; flaws small and tolerable; background adjusted locally. +4 Background reconstruction clean; existing details preserved; only minute changes outside the removal area. +5 Physically flawless and even enhances realism: accurate light/shadow/texture infill, high-quality micro-details. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Prompt Compliance: A number from 1 to 5. +Visual Naturalness: A number from 1 to 5. +Physical & Detail Integrity: A number from 1 to 5. +editing instruction is : . + +Below are the images before and after editing: +""", + "style": """ +You are a data rater specializing in grading style transfer edits. You will be given an input image, a reference style, and the styled result. Your task is to evaluate the style transfer on a 5-point scale from three perspectives: + +Style Fidelity +1 Target style absent or clearly wrong. +2 Style shows in a few areas only, or mixed with unrelated styles. +3 Key traits (palette, brushwork, texture) present but patchy or inconsistent. +4 Style reproduced across almost the whole image; only small local mismatches. +5 Full, faithful transfer: colour, texture, brushwork, lighting all match the exemplar over the entire image. + +Content Preservation +1 Major objects or layout lost/distorted; original scene barely recognisable. +2 Main subject recognisable, but size, perspective or key parts clearly wrong/missing. +3 Overall structure correct; some local warping or minor omissions. +4 Nearly all geometry intact; only slight, non-distracting deformation. +5 All objects and spatial relations kept; only stylistic, harmless distortion. + +Rendering Quality +1 Heavy noise, banding, pixel damage or blur; image unusable. +2 Visible seams, aliasing, colour drift; low resolution or chaotic strokes. +3 Moderate quality: local blur/noise/texture breaks, but generally acceptable. +4 Sharp, coherent strokes; tiny artefacts visible only when zoomed. +5 High resolution, no artefacts; strokes, textures and colour transitions look fully natural. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Style Fidelity: A number from 1 to 5. +Content Preservation: A number from 1 to 5. +Rendering Quality: A number from 1 to 5. +editing instruction is : . + +Below are the input, reference style, and styled output image: +""", + "action": """ +You are a data rater specializing in grading action or expression change edits. You will be given two images (before and after editing) and the editing instruction. Your task is to evaluate the motion or expression change on a 5-point scale from three perspectives: + +Action / Expression Fidelity +1 No visible change, or wrong action / expression. +2 Partial or clearly incorrect pose; only some body parts change; expression direction wrong. +3 Main idea present but details off (angle, side, intensity, missing gesture). +4 Requested pose / expression achieved with just minor inaccuracy (small angular drift, timing nuance). +5 Exact match to prompt: every limb, gesture, and facial muscle aligns with the described action. + +Identity Preservation +1 Person unrecognisable; face or body replaced. +2 Strong drift: key facial features, hairstyle or clothing heavily altered. +3 Mostly same identity; moderate changes in some features but still recognisable. +4 Identity clearly the same; only subtle stylisation or lighting differences. +5 Perfect preservation of face, hairstyle, skin tone, clothing and accessories. + +Visual & Anatomical Coherence +1 Severe artifacts: broken or duplicated limbs, extreme distortion, heavy noise/blur. +2 Noticeable cut-out halos, proportion errors, lighting or perspective clearly off. +3 Generally plausible; minor joint or shading issues; small noise/blur acceptable. +4 Clean render; anatomy, lighting, depth and edges consistent; flaws only on close inspection. +5 Flawless realism or stylistic coherence; perfect anatomy, lighting, shadows and texture continuity. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Action Fidelity: A number from 1 to 5. +Identity Preservation: A number from 1 to 5. +Visual & Anatomical Coherence: A number from 1 to 5. +editing instruction is : . + +Below are the images before and after editing: +""", + "extract": """ +You are a data rater specializing in grading object cut-out quality. You will be given an image with the object extracted on a white background. Your task is to evaluate the cut-out accuracy on a 5-point scale from three perspectives: + +Object Selection & Identity +1 Wrong object or multiple objects extracted. +2 Correct class but only part of the object, or obvious intrusions from other items. +3 Object largely correct yet small pieces missing / extra, identity still recognisable. +4 Full object with clear identity; only tiny mis-crop (e.g., tip of antenna). +5 Exact requested object, complete and unmistakably the same instance (ID). + +Mask Precision & Background Purity +1 Large background remnants, holes in mask, or non-white backdrop dominates. +2 Noticeable jagged edges, colour fringes, grey/colour patches in white area. +3 Acceptable mask; minor edge softness or faint halo visible on close look. +4 Clean, smooth edges; white (#FFFFFF) background uniform, tiny artefacts only when zoomed. +5 Crisp anti-aliased contour, zero spill or halo; backdrop perfectly pure white throughout. + +Object Integrity & Visual Quality +1 Severe blur, compression, deformation, or missing parts; unusable. +2 Moderate noise, colour shift, or slight warping; details clearly degraded. +3 Overall intact with minor softness or noise; colours mostly preserved. +4 Sharp detail, accurate colours; negligible artefacts. +5 Pristine: high-resolution detail, true colours, no artefacts or distortion. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Object Identity: A number from 1 to 5. +Mask Precision: A number from 1 to 5. +Visual Quality: A number from 1 to 5. +editing instruction is : . + +Below is the extracted object image: +""", + "background": """ +You are a data rater specializing in grading background editing. You will be given two images (before and after editing) and the editing instruction. Your task is to evaluate the background change on a 5-point scale from three perspectives: + +Instruction Compliance +1 No change, or background unrelated to prompt, or foreground also replaced/distorted. +2 Background partly replaced or wrong style/content; foreground noticeably altered. +3 Main background replaced but elements missing/extra, or faint spill onto subject edges. +4 Requested background fully present; foreground intact except minute artefacts or small prompt mismatch (e.g. colour tone). +5 Background exactly matches prompt (content, style, placement); all foreground pixels untouched. + +Visual Seamlessness (Edge & Texture Blend) +1 Large tearing, posterisation, extreme blur/noise; edit area obvious at a glance. +2 Clear cut-out halos, colour-resolution gap, or heavy smudge strokes. +3 Blend acceptable but visible on closer look: slight edge blur, grain or palette shift. +4 Nearly invisible seams; textures and sharpness aligned, only minor issues when zoomed in. +5 Indistinguishable composite: edges, textures, resolution and colour grading perfectly continuous. + +Physical Consistency (Lighting, Perspective, Depth) +1 Severe mismatch: wrong horizon, conflicting light direction, floating subject, warped geometry. +2 Noticeable but not extreme inconsistencies in light, shadows or scale; depth cues off. +3 Overall believable; small errors in shadow length, perspective or ambient colour. +4 Lighting, scale, depth, and camera angle well matched; only subtle discrepancies. +5 Physically flawless: foreground and new background share coherent light, shadows, reflections, perspective and atmospheric depth, enhancing overall realism. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Instruction Compliance: A number from 1 to 5. +Visual Seamlessness: A number from 1 to 5. +Physical Consistency: A number from 1 to 5. +editing instruction is : . + +Below are the images before and after editing: +""", + "compose": """ +You are a data rater specializing in grading hybrid image edits (involving multiple operations on multiple objects). You will be given two images (before and after editing) and the editing instruction. Your task is to evaluate the overall editing quality on a 5-point scale from three perspectives: + +Instruction Compliance +1 Neither object nor operations match the prompt; wrong items edited or shapes distorted. +2 Only one object correctly edited, or both edited but with wrong/partial operations; collateral changes to other items. +3 Both target objects touched, each with the requested operation broadly correct but missing details (e.g., wrong colour value, incomplete removal). +4 Both objects receive the exact operations; tiny deviations in amount, position, or parameter. No unintended edits elsewhere. +5 Perfect execution: each object fully reflects its specified operation, all other scene elements untouched. + +Visual Naturalness (Seamlessness) +1 Large artefacts, obvious cut-outs, heavy blur/noise; edits conspicuous at a glance. +2 Clear edge halos, colour or resolution mismatch, awkward scaling. +3 Acceptable but visible on close look: slight edge softness, minor palette or focus shift. +4 Edits blend smoothly; seams hard to spot, textures and sharpness largely consistent. +5 Indistinguishable composite: colour grading, grain, resolution and style fully match the original image. + +Physical Consistency & Fine Detail +1 Severe lighting/perspective mismatch, missing or wrong shadows; objects appear floating or warped. +2 Noticeable but tolerable inconsistencies in illumination, scale, or depth cues. +3 Generally plausible; small errors in shadow length, reflection angle, or texture alignment. +4 Lighting, perspective, and material response closely match; only subtle flaws visible when zoomed. +5 Physically flawless: shadows, highlights, reflections, depth and texture perfectly integrated, enhancing overall realism. +The second and third score should no higher than first score!!! + +Example Response Format: +Brief reasoning: A short explanation of the score based on the criteria above, no more than 20 words. +Instruction Compliance: A number from 1 to 5. +Visual Naturalness: A number from 1 to 5. +Physical Consistency & Fine Detail: A number from 1 to 5. +editing instruction is : . + +Below are the images before and after editing: +""", +} + +# Edit types supported +IMGEDIT_EDIT_TYPES = [ + "replace", + "add", + "adjust", + "remove", + "style", + "action", + "extract", + "background", + "compose", +] + + +def image_to_base64(image) -> Optional[str]: + """Convert PIL Image or image path to base64 string""" + try: + if isinstance(image, str): + # It's a path + if not os.path.exists(image): + eval_logger.warning(f"Image file not found: {image}") + return None + with open(image, "rb") as f: + return base64.b64encode(f.read()).decode("utf-8") + elif hasattr(image, "save"): + # It's a PIL Image + buffer = BytesIO() + image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + else: + eval_logger.warning(f"Unknown image type: {type(image)}") + return None + except Exception as e: + eval_logger.error(f"Error converting image to base64: {e}") + return None + + +def parse_gpt_scores(response_text: str) -> Tuple[float, float, float]: + """ + Parse GPT/Qwen response to extract three scores. + Returns tuple of (score1, score2, score3) + """ + try: + # Find all numbers in the format "Score Name: X" + score_pattern = r":\s*(\d+)" + matches = re.findall(score_pattern, response_text) + + if len(matches) >= 3: + # Take the last 3 numbers (the actual scores) + scores = [float(matches[-3]), float(matches[-2]), float(matches[-1])] + return tuple(scores) + + # Alternative: find standalone numbers on lines + lines = response_text.strip().split("\n") + scores = [] + for line in lines: + # Look for patterns like "Prompt Compliance: 4" or just "4" + match = re.search(r"(\d+)\s*$", line.strip()) + if match: + scores.append(float(match.group(1))) + + if len(scores) >= 3: + return (scores[-3], scores[-2], scores[-1]) + + eval_logger.warning(f"Could not parse 3 scores from response: {response_text[:200]}...") + return (0.0, 0.0, 0.0) + except Exception as e: + eval_logger.error(f"Error parsing scores: {e}") + return (0.0, 0.0, 0.0) + + +def calculate_average_score(scores: Tuple[float, float, float]) -> float: + """Calculate average of three scores""" + return sum(scores) / 3.0 + + +def imgedit_doc_to_visual(doc): + """ + Extract input image from document. + + Priority order: + 1. input_image field (PIL Image from embedded dataset) + 2. image_path field (full path saved by prepare_dataset.py) + 3. id field + IMGEDIT_ORIGIN_IMG_ROOT (relative path like "animal/000342021.jpg") + 4. input_image as string path + """ + origin_img_root = os.getenv("IMGEDIT_ORIGIN_IMG_ROOT", "") + + # 1. Try input_image field (PIL Image from embedded dataset) + input_image = doc.get("input_image") or doc.get("image") + if input_image is not None and hasattr(input_image, "convert"): + try: + return [input_image.convert("RGB")] + except Exception as e: + eval_logger.error(f"Error converting input_image: {e}") + + # 2. Try image_path field (full path saved by prepare_dataset.py) + image_path = doc.get("image_path", "") + if image_path and os.path.exists(image_path): + try: + return [Image.open(image_path).convert("RGB")] + except Exception as e: + eval_logger.error(f"Error loading image from image_path {image_path}: {e}") + + # 3. Try id field + origin_img_root (relative path like "animal/000342021.jpg") + image_id = doc.get("id", "") + if image_id and origin_img_root: + full_path = os.path.join(origin_img_root, image_id) + if os.path.exists(full_path): + try: + return [Image.open(full_path).convert("RGB")] + except Exception as e: + eval_logger.error(f"Error loading image from {full_path}: {e}") + + # 4. Try input_image as string path + if input_image is not None and isinstance(input_image, str): + # Try as absolute path first + if os.path.exists(input_image): + try: + return [Image.open(input_image).convert("RGB")] + except Exception as e: + eval_logger.error(f"Error loading image from {input_image}: {e}") + # Try with origin_img_root + elif origin_img_root: + full_path = os.path.join(origin_img_root, input_image) + if os.path.exists(full_path): + try: + return [Image.open(full_path).convert("RGB")] + except Exception as e: + eval_logger.error(f"Error loading image from {full_path}: {e}") + + eval_logger.warning(f"No input image found in document. " f"Available keys: {list(doc.keys())}, " f"image_path={image_path}, id={image_id}, origin_img_root={origin_img_root}") + return [] + + +def imgedit_doc_to_text(doc, lmms_eval_specific_kwargs=None): + """Extract instruction text from document""" + instruction = doc.get("prompt", "").strip() + pre_prompt = "" + post_prompt = "" + if lmms_eval_specific_kwargs: + pre_prompt = lmms_eval_specific_kwargs.get("pre_prompt", "") + post_prompt = lmms_eval_specific_kwargs.get("post_prompt", "") + return f"{pre_prompt}{instruction}{post_prompt}" + + +def imgedit_doc_to_target(doc): + """Extract target instruction (for reference)""" + return doc.get("prompt", "") + + +# ============================================ +# Qwen2.5-VL Evaluation Backend +# ============================================ + + +def _get_qwen25vl_model(): + """ + Get or create Qwen2.5-VL model instance (lazy loading, singleton pattern). + """ + global _qwen25vl_model + if _qwen25vl_model is not None: + return _qwen25vl_model + + try: + import random + + import numpy as np + import torch + from qwen_vl_utils import process_vision_info + from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration + + model_path = os.getenv("QWEN_MODEL_PATH", "/pfs/training-data/hf/models/Qwen/Qwen2.5-VL-72B-Instruct-AWQ") + + eval_logger.info(f"Loading Qwen2.5-VL model from {model_path}...") + + model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto").eval() + processor = AutoProcessor.from_pretrained(model_path) + + _qwen25vl_model = {"model": model, "processor": processor, "process_vision_info": process_vision_info} + + eval_logger.info("Qwen2.5-VL model loaded successfully!") + return _qwen25vl_model + + except ImportError as e: + eval_logger.error(f"Failed to import Qwen2.5-VL dependencies: {e}") + eval_logger.error("Please install: pip install transformers qwen-vl-utils") + return None + except Exception as e: + eval_logger.error(f"Failed to load Qwen2.5-VL model: {e}") + return None + + +def _call_qwen25vl_for_evaluation( + original_image, + edited_image, + edit_prompt: str, + edit_type: str, +) -> Optional[str]: + """ + Call Qwen2.5-VL for image editing evaluation. + + Args: + original_image: Original image (PIL Image) + edited_image: Edited image (PIL Image) + edit_prompt: The editing instruction + edit_type: Type of edit (replace, add, adjust, etc.) + + Returns: + Model response text or None if failed + """ + import random + + import numpy as np + import torch + + def set_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + qwen_model = _get_qwen25vl_model() + if qwen_model is None: + return None + + model = qwen_model["model"] + processor = qwen_model["processor"] + process_vision_info = qwen_model["process_vision_info"] + + # Get prompt template for this edit type + prompt_template = IMGEDIT_PROMPTS.get(edit_type, IMGEDIT_PROMPTS["adjust"]) + full_prompt = prompt_template.replace("", edit_prompt) + + try: + # # Build message content for Qwen2.5-VL + # if edit_type == "extract": + # # For extract, only show the edited image + # messages = [ + # { + # "role": "user", + # "content": [ + # {"type": "image", "image": edited_image}, + # {"type": "text", "text": full_prompt}, + # ], + # } + # ] + # else: + # For other types, show both original and edited images + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": original_image}, + {"type": "image", "image": edited_image}, + {"type": "text", "text": full_prompt}, + ], + } + ] + + set_seed(42) + + # Prepare the inputs + text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + image_inputs, video_inputs = process_vision_info(messages) + + # Process inputs + inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt") + inputs = inputs.to("cuda") + + # Generate output + generation_config = { + "max_new_tokens": 512, + "num_beams": 1, + "do_sample": False, + "temperature": 0.1, + "top_p": None, + } + generated_ids = model.generate(**inputs, **generation_config) + generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)] + output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + return output_text[0] if output_text else "" + + except Exception as e: + eval_logger.error(f"Error calling Qwen2.5-VL: {e}") + return None + + +# ============================================ +# GPT-4o Evaluation Backend +# ============================================ + + +def _call_gpt_for_evaluation( + original_image, + edited_image, + edit_prompt: str, + edit_type: str, + api_key: Optional[str] = None, + base_url: Optional[str] = None, +) -> Optional[str]: + """ + Call GPT-4o for image editing evaluation. + + Args: + original_image: Original image (PIL Image or path) + edited_image: Edited image (PIL Image or path) + edit_prompt: The editing instruction + edit_type: Type of edit (replace, add, adjust, etc.) + api_key: OpenAI API key + base_url: OpenAI API base URL + + Returns: + GPT response text or None if failed + """ + try: + from openai import OpenAI + except ImportError: + eval_logger.error("OpenAI package not installed. Run: pip install openai") + return None + + # Get API credentials from environment if not provided + api_key = api_key or os.getenv("OPENAI_API_KEY") + base_url = base_url or os.getenv("OPENAI_BASE_URL") + + if not api_key: + eval_logger.error("OpenAI API key not found. Set OPENAI_API_KEY environment variable.") + return None + + # Convert images to base64 + original_b64 = image_to_base64(original_image) + edited_b64 = image_to_base64(edited_image) + + if not original_b64 or not edited_b64: + eval_logger.error("Failed to convert images to base64") + return None + + # Get prompt template for this edit type + prompt_template = IMGEDIT_PROMPTS.get(edit_type, IMGEDIT_PROMPTS["adjust"]) + full_prompt = prompt_template.replace("", edit_prompt) + + try: + client_kwargs = {"api_key": api_key} + if base_url: + client_kwargs["base_url"] = base_url + + client = OpenAI(**client_kwargs) + + # Build message content + content = [ + {"type": "text", "text": full_prompt}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{original_b64}"}}, + ] + + # For extract type, only show the edited image + if edit_type != "extract": + content.append({"type": "image_url", "image_url": {"url": f"data:image/png;base64,{edited_b64}"}}) + else: + # For extract, replace the second image with the edited one + content[1] = {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{edited_b64}"}} + + response = client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": content}], + max_tokens=512, + ) + + return response.choices[0].message.content + except Exception as e: + eval_logger.error(f"Error calling GPT API: {e}") + return None + + +# ============================================ +# vLLM Qwen Evaluation Backend +# ============================================ + + +def _call_vllm_qwen_for_evaluation( + original_image, + edited_image, + edit_prompt: str, + edit_type: str, +) -> Optional[str]: + """ + Call Qwen model via vLLM API for image editing evaluation. + + Environment variables: + - VLLM_API_BASE: Base URL of vLLM server (e.g., "http://localhost:8000/v1") + - VLLM_API_KEY: API key if required (default: "EMPTY") + - VLLM_MODEL_NAME: Model name (default: auto-detect) + + Args: + original_image: Original image (PIL Image) + edited_image: Edited image (PIL Image) + edit_prompt: The editing instruction + edit_type: Type of edit (replace, add, adjust, etc.) + + Returns: + Model response text or None if failed + """ + try: + from openai import OpenAI + except ImportError: + eval_logger.error("OpenAI package not installed. Run: pip install openai") + return None + + api_base = os.getenv("VLLM_API_BASE") + if not api_base: + eval_logger.error("VLLM_API_BASE environment variable not set") + return None + + api_key = os.getenv("VLLM_API_KEY", "EMPTY") + model_name = os.getenv("VLLM_MODEL_NAME", "default") + timeout = int(os.getenv("VLLM_TIMEOUT", "120")) + + # Get prompt template for this edit type + prompt_template = IMGEDIT_PROMPTS.get(edit_type, IMGEDIT_PROMPTS["adjust"]) + full_prompt = prompt_template.replace("", edit_prompt) + + # Convert images to base64 + original_b64 = image_to_base64(original_image) + edited_b64 = image_to_base64(edited_image) + + if not original_b64 or not edited_b64: + eval_logger.error("Failed to convert images to base64") + return None + + try: + client = OpenAI( + api_key=api_key, + base_url=api_base, + timeout=timeout, + ) + + # Auto-detect model name if not set + if model_name == "default": + try: + models = client.models.list() + if models.data: + model_name = models.data[0].id + except Exception: + pass + + # Build message content + content = [ + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{original_b64}"}}, + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{edited_b64}"}}, + {"type": "text", "text": full_prompt}, + ] + + response = client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": content}], + max_tokens=512, + temperature=0.1, + ) + + return response.choices[0].message.content if response.choices else "" + + except Exception as e: + eval_logger.error(f"Error calling vLLM API: {e}") + return None + + +# ============================================ +# Unified Evaluation Function +# ============================================ + + +def _call_model_for_evaluation( + original_image, + edited_image, + edit_prompt: str, + edit_type: str, +) -> Optional[str]: + """ + Call the configured model for evaluation. + + The backend is selected via IMGEDIT_EVAL_BACKBONE environment variable: + - "gpt4o" (default): Use GPT-4o via OpenAI API + - "qwen25vl": Use Qwen2.5-VL locally + - "vllm_qwen" / "vllm_qwen25vl" / "vllm_qwen3vl": Use Qwen via vLLM API + """ + backbone = os.getenv("IMGEDIT_EVAL_BACKBONE", "gpt4o").lower() + + if backbone == "qwen25vl": + eval_logger.debug(f"Using Qwen2.5-VL (local) for evaluation (edit_type={edit_type})") + return _call_qwen25vl_for_evaluation(original_image, edited_image, edit_prompt, edit_type) + elif backbone in ["vllm_qwen", "vllm_qwen25vl", "vllm_qwen3vl"]: + eval_logger.debug(f"Using vLLM Qwen for evaluation (edit_type={edit_type})") + return _call_vllm_qwen_for_evaluation(original_image, edited_image, edit_prompt, edit_type) + else: + eval_logger.debug(f"Using GPT-4o for evaluation (edit_type={edit_type})") + return _call_gpt_for_evaluation(original_image, edited_image, edit_prompt, edit_type) + + +# ============================================ +# Process Results +# ============================================ + + +def imgedit_process_results(doc, results, **kwargs): + """ + Process model predictions: + 1. Parse JSON output to extract text and images + 2. Save images to required directory structure + 3. Evaluate using GPT-4o or Qwen2.5-VL + + Args: + doc: Document containing input image, instruction, key, edit_type, etc. + results: Model predictions [JSON string with {"text": "...", "images": [...]}] + **kwargs: Additional arguments + + Returns: + Dict with metrics: imgedit_score1, imgedit_score2, imgedit_score3, imgedit_avg_score + """ + # Get configuration from environment variables + model_name = os.getenv("IMGEDIT_MODEL_NAME", "default") + output_base_dir = os.getenv("IMGEDIT_OUTPUT_DIR", "./logs/imgedit_results") + origin_img_root = os.getenv("IMGEDIT_ORIGIN_IMG_ROOT", "") + + # Parse prediction + pred = results[0] if results else "{}" + try: + pred = json.loads(pred) + except (json.JSONDecodeError, TypeError): + eval_logger.warning(f"Failed to parse prediction JSON: {pred}") + pred = {"text": "", "images": []} + + model_images = pred.get("images", []) + + # Extract document fields + key = doc.get("key", str(doc.get("id", "unknown"))) + edit_type = doc.get("edit_type", "adjust") + edit_prompt = doc.get("prompt", "") + image_id = doc.get("id", "") # relative path like "animal/000342021.jpg" + + # Get input/original image - try multiple sources in order of priority + input_image_pil = None + + # 1. Try input_image field (PIL Image from embedded dataset) + input_image = doc.get("input_image") or doc.get("image") + if input_image is not None and hasattr(input_image, "convert"): + try: + input_image_pil = input_image.convert("RGB") + eval_logger.debug(f"Loaded input_image from doc (PIL) for key {key}") + except Exception as e: + eval_logger.warning(f"Failed to convert input_image for key {key}: {e}") + + # 2. Try image_path field (full path saved by prepare_dataset.py) + if input_image_pil is None: + image_path = doc.get("image_path", "") + if image_path and os.path.exists(image_path): + try: + input_image_pil = Image.open(image_path).convert("RGB") + eval_logger.debug(f"Loaded image from image_path: {image_path}") + except Exception as e: + eval_logger.warning(f"Failed to load image from {image_path}: {e}") + + # 3. Try id field with origin_img_root (relative path like "animal/000342021.jpg") + if input_image_pil is None and image_id and origin_img_root: + full_path = os.path.join(origin_img_root, image_id) + if os.path.exists(full_path): + try: + input_image_pil = Image.open(full_path).convert("RGB") + eval_logger.debug(f"Loaded image from origin_img_root + id: {full_path}") + except Exception as e: + eval_logger.warning(f"Failed to load image from {full_path}: {e}") + + # 4. Try input_image as string path + if input_image_pil is None and input_image is not None and isinstance(input_image, str): + # Try as absolute path first + if os.path.exists(input_image): + try: + input_image_pil = Image.open(input_image).convert("RGB") + eval_logger.debug(f"Loaded image from input_image path: {input_image}") + except Exception as e: + eval_logger.warning(f"Failed to load image from {input_image}: {e}") + # Try with origin_img_root + elif origin_img_root: + full_path = os.path.join(origin_img_root, input_image) + if os.path.exists(full_path): + try: + input_image_pil = Image.open(full_path).convert("RGB") + eval_logger.debug(f"Loaded image from origin_img_root + input_image: {full_path}") + except Exception as e: + eval_logger.warning(f"Failed to load image from {full_path}: {e}") + + # 5. Try to load from saved _SRCIMG file + if input_image_pil is None: + src_img_path = os.path.join(output_base_dir, model_name, f"{key}_SRCIMG.png") + if os.path.exists(src_img_path): + try: + input_image_pil = Image.open(src_img_path).convert("RGB") + eval_logger.debug(f"Loaded source image from _SRCIMG: {src_img_path}") + except Exception as e: + eval_logger.warning(f"Failed to load source image from {src_img_path}: {e}") + + # Return zero scores if no input image + if input_image_pil is None: + eval_logger.warning(f"No input image found for key {key}. " f"Tried: input_image={doc.get('input_image') is not None}, " f"image_path={doc.get('image_path', '')}, " f"id={image_id}, origin_img_root={origin_img_root}") + return _create_zero_result(key, edit_type) + + # Find edited image + edited_image_path = None + edited_image_pil = None + + if model_images and len(model_images) > 0: + generated_image_path = model_images[0] + if os.path.exists(generated_image_path): + edited_image_path = generated_image_path + try: + edited_image_pil = Image.open(edited_image_path).convert("RGB") + except Exception as e: + eval_logger.warning(f"Failed to load edited image: {e}") + + # Try to find from standard location + if edited_image_pil is None: + existing_path = os.path.join(output_base_dir, model_name, f"{key}.png") + if os.path.exists(existing_path): + edited_image_path = existing_path + try: + edited_image_pil = Image.open(existing_path).convert("RGB") + except Exception as e: + eval_logger.warning(f"Failed to load edited image from {existing_path}: {e}") + + # Return zero scores if no edited image + if edited_image_pil is None: + eval_logger.warning(f"No edited image found for key {key}") + return _create_zero_result(key, edit_type) + + # Call model for evaluation (GPT-4o or Qwen2.5-VL based on config) + model_response = _call_model_for_evaluation( + input_image_pil, + edited_image_pil, + edit_prompt, + edit_type, + ) + + if model_response is None: + eval_logger.warning(f"Model evaluation failed for key {key}") + return _create_zero_result(key, edit_type) + + # Parse scores from model response + score1, score2, score3 = parse_gpt_scores(model_response) + avg_score = calculate_average_score((score1, score2, score3)) + + eval_logger.info(f"[{edit_type}] Key {key}: " f"Score1={score1:.1f}, Score2={score2:.1f}, Score3={score3:.1f}, " f"Avg={avg_score:.2f}") + + return { + "imgedit_score1": { + "key": key, + "edit_type": edit_type, + "score": float(score1), + }, + "imgedit_score2": { + "key": key, + "edit_type": edit_type, + "score": float(score2), + }, + "imgedit_score3": { + "key": key, + "edit_type": edit_type, + "score": float(score3), + }, + "imgedit_avg_score": { + "key": key, + "edit_type": edit_type, + "score": float(avg_score), + "model_response": model_response, + }, + } + + +def _create_zero_result(key: str, edit_type: str) -> Dict: + """Create a zero-score result dict""" + return { + "imgedit_score1": { + "key": key, + "edit_type": edit_type, + "score": 0.0, + }, + "imgedit_score2": { + "key": key, + "edit_type": edit_type, + "score": 0.0, + }, + "imgedit_score3": { + "key": key, + "edit_type": edit_type, + "score": 0.0, + }, + "imgedit_avg_score": { + "key": key, + "edit_type": edit_type, + "score": 0.0, + }, + } + + +# ============================================ +# Aggregation Functions +# ============================================ + + +def imgedit_aggregate_results(results): + """ + Aggregate results across all samples and compute final scores. + Returns overall average score. + + Args: + results: List of result dicts from process_results + + Returns: + Final aggregated score (average across all samples) + """ + if not results: + return 0.0 + + # Calculate average score + scores = [r["score"] for r in results if "score" in r] + if not scores: + return 0.0 + + avg_score = np.mean(scores) + + # Log breakdown by edit type + edit_type_scores = defaultdict(list) + + for r in results: + if "score" in r: + edit_type = r.get("edit_type", "unknown") + edit_type_scores[edit_type].append(r["score"]) + + # Log statistics + eval_logger.info(f"Overall average score: {avg_score:.3f}") + eval_logger.info(f"Number of samples: {len(scores)}") + + if edit_type_scores: + eval_logger.info("Scores by edit type:") + for edit_type, type_scores in sorted(edit_type_scores.items()): + type_avg = np.mean(type_scores) + eval_logger.info(f" {edit_type}: {type_avg:.3f} (n={len(type_scores)})") + + return avg_score + + +def imgedit_aggregate_by_type(results): + """ + Aggregate results by edit type and return a dict of scores per type. + + Args: + results: List of result dicts from process_results + + Returns: + Dict mapping edit_type to average score + """ + if not results: + return {} + + edit_type_scores = defaultdict(list) + + for r in results: + if "score" in r: + edit_type = r.get("edit_type", "unknown") + edit_type_scores[edit_type].append(r["score"]) + + type_averages = {} + for edit_type, type_scores in edit_type_scores.items(): + type_averages[edit_type] = float(np.mean(type_scores)) + + # Log the breakdown + eval_logger.info("=" * 50) + eval_logger.info("Scores by Edit Type:") + eval_logger.info("=" * 50) + for edit_type in IMGEDIT_EDIT_TYPES: + if edit_type in type_averages: + eval_logger.info(f" {edit_type}: {type_averages[edit_type]:.3f} (n={len(edit_type_scores[edit_type])})") + eval_logger.info("=" * 50) + + return type_averages + + +# Per-type aggregation functions for YAML +def _aggregate_for_type(results, target_type: str): + """Helper to aggregate scores for a specific edit type""" + if not results: + return 0.0 + + type_scores = [r["score"] for r in results if r.get("edit_type") == target_type and "score" in r] + + if not type_scores: + return 0.0 + + return float(np.mean(type_scores)) + + +def imgedit_aggregate_replace(results): + """Aggregate scores for 'replace' edit type""" + return _aggregate_for_type(results, "replace") + + +def imgedit_aggregate_add(results): + """Aggregate scores for 'add' edit type""" + return _aggregate_for_type(results, "add") + + +def imgedit_aggregate_adjust(results): + """Aggregate scores for 'adjust' edit type""" + return _aggregate_for_type(results, "adjust") + + +def imgedit_aggregate_remove(results): + """Aggregate scores for 'remove' edit type""" + return _aggregate_for_type(results, "remove") + + +def imgedit_aggregate_style(results): + """Aggregate scores for 'style' edit type""" + return _aggregate_for_type(results, "style") + + +def imgedit_aggregate_action(results): + """Aggregate scores for 'action' edit type""" + return _aggregate_for_type(results, "action") + + +def imgedit_aggregate_extract(results): + """Aggregate scores for 'extract' edit type""" + return _aggregate_for_type(results, "extract") + + +def imgedit_aggregate_background(results): + """Aggregate scores for 'background' edit type""" + return _aggregate_for_type(results, "background") + + +def imgedit_aggregate_compose(results): + """Aggregate scores for 'compose' edit type""" + return _aggregate_for_type(results, "compose")