
Conversation

@RadhaGulhane13 (Contributor) commented Sep 14, 2025

Description

This PR introduces initial support for Gemma-3 models.

Current evaluations indicate good alignment with VQA tasks and reproducibility on MATH.

Community contributions to help reproduce results on MMMU-pro and MMMU-val would be highly valuable.


Evaluation

Model                  mmmu_val  ai2d   chartqa  docvqa_val  mathvista_testmini
google/gemma-3-4b-it   39.55     71.44  50.96    69.08       52.0

Example Usage

An example usage script is available at:
examples/models/gemma3.sh


Related Issues

#664

Summary by CodeRabbit

  • New Features

    • Added Gemma3 model support with text+visual inputs, batched generation, configurable system/reasoning prompts, token/temperature/stop controls, progress feedback, and multi-GPU/distributed execution.
  • Chores

    • Added an example script to run reproducible Gemma3 evaluations across benchmarks with predefined settings and logging.

@coderabbitai coderabbitai bot (Contributor) commented Sep 14, 2025

Caution

Review failed

The pull request is closed.

Walkthrough

Adds a new simple model wrapper Gemma3 (registered as "gemma3"), updates the model registry, and adds an example shell script to run lmms_eval with a Gemma-3 IT checkpoint via accelerate for specified evaluation tasks.

Changes

Cohort / File(s), with summary of changes:

Model wrapper (Gemma3): lmms_eval/models/simple/gemma3.py
  Adds the Gemma3 class: __init__ loads the pretrained Gemma-3 IT model/processor/tokenizer with device and distributed handling, exposes properties (config, tokenizer, model, device, rank, world_size, etc.), and implements generate_until with batching, visuals handling, and stopping logic, plus placeholders for loglikelihood and multi-round generation.

Model registry: lmms_eval/models/__init__.py
  Adds "gemma3": "Gemma3" to AVAILABLE_SIMPLE_MODELS (see the sketch after this list).

Example evaluation script: examples/models/gemma3.sh
  New script launching lmms_eval via accelerate with google/gemma-3-4b-it (8 processes, port 12345, batch_size 1), targeting the tasks mmmu_val, ai2d, and mathvista_testmini and writing outputs to ./logs/.
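
For orientation, a minimal sketch of the registry change described above (illustrative; the real dict in lmms_eval/models/__init__.py contains many other entries):

# lmms_eval/models/__init__.py (excerpt, illustrative)
AVAILABLE_SIMPLE_MODELS = {
    # ... existing simple models ...
    "gemma3": "Gemma3",  # resolved to lmms_eval.models.simple.gemma3.Gemma3
}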

Sequence Diagram(s)

sequenceDiagram
    autonumber
    actor User
    participant CLI as lmms_eval (CLI)
    participant Registry as ModelRegistry
    participant Gemma as Gemma3
    participant Proc as Processor/Tokenizer
    participant HF as HF Model (Gemma-3 IT)

    User->>CLI: accelerate launch ... --model gemma3 --tasks ...
    CLI->>Registry: resolve "gemma3"
    Registry-->>CLI: lmms_eval.models.simple.gemma3.Gemma3
    CLI->>Gemma: initialize(pretrained, device_map, ...)
    rect rgba(220,240,255,0.4)
      Gemma->>Proc: load tokenizer/processor
      Gemma->>HF: load model (single/MULTI_GPU/FSDP aware)
    end

    loop For each batched request group
      CLI->>Gemma: generate_until(requests)
      Gemma->>Gemma: collate/sort by length, assemble messages + visuals
      Gemma->>Proc: apply_chat_template -> tensors
      Proc-->>Gemma: input tensors (ids, pixel_values, masks)
      Gemma->>HF: generate(...)
      HF-->>Gemma: output ids
      Gemma->>Proc: decode + apply stoppers
      Gemma-->>CLI: ordered text results
    end

    CLI-->>User: metrics and outputs
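
In code terms, the per-batch step in the diagram corresponds roughly to the following hedged sketch (argument names follow the processor call quoted later in this review; the exact signatures in gemma3.py may differ):

# Minimal sketch of one generate_until batch (illustrative, not the PR's exact code).
def generate_batch(model, processor, batched_messages, max_new_tokens=128):
    inputs = processor.apply_chat_template(
        batched_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # Decode only the newly generated tokens for each sequence in the batch.
    new_tokens = output_ids[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)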

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Poem

I twitch my whiskers at the new display,
Gemma3 hops in to help and play.
Pixels, prompts, and batched delight,
Generations glow into the night.
Logs and carrots, tidy and bright. 🥕✨

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
Title Check: ✅ Passed. The title "[Feature] Support for Gemma-3 Models" succinctly and accurately describes the primary change in this PR: adding Gemma-3 model support via a new model wrapper, mapping, and example script. It is concise, focused on the main change, and avoids unnecessary detail or noise. A teammate scanning the history would understand the primary intent from this title.
Description Check: ✅ Passed. The PR description contains a clear "Description" section, an evaluation results table, an example usage script path, and a related-issues link, satisfying the repository's requirement for a detailed description. It provides motivation, quantitative results, and usage guidance, so reviewers have the necessary context to evaluate the change. The explicit checklist boxes from the template are not present in the body, but their absence is non-critical given the content provided.
Docstring Coverage: ✅ Passed. No functions found in the changes; docstring coverage check skipped.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c28f92d and 96d75c2.

📒 Files selected for processing (2)
  • examples/models/gemma3.sh (1 hunks)
  • lmms_eval/models/simple/gemma3.py (1 hunks)


@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 10

🧹 Nitpick comments (7)
examples/models/gemma3.sh (1)

1-1: Add a shebang and strict mode for portability and safety.

Without a shebang, shells may interpret differently; strict mode avoids silent failures.

Apply this diff:

+#!/usr/bin/env bash
+set -euo pipefail
+
 # Run and exactly reproduce gemma3 results!
 # mme as an example
lmms_eval/models/simple/gemma3.py (6)

236-239: Prefer iterable unpacking for list assembly.

Cleaner and avoids an intermediate list.

Apply this diff:

-                message.append(
-                    {
-                        "role": "user",
-                        "content": processed_visuals + [{"type": "text", "text": context}],
-                    }
-                )
+                message.append(
+                    {"role": "user", "content": [*processed_visuals, {"type": "text", "text": context}]}
+                )

20-22: Over-broad warning suppression.

Global ignore can hide real issues; scope or narrow categories instead.

Apply this diff:

-warnings.simplefilter("ignore", category=DeprecationWarning)
-warnings.filterwarnings("ignore")
+warnings.filterwarnings("ignore", category=DeprecationWarning)

89-93: Optional check: be explicit with Optional comparisons (Pyright).

Use is not None to satisfy strict Optional checks.

Apply this diff:

-        if reasoning_prompt:
+        if reasoning_prompt is not None:
             self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n")
         else:
             self.reasoning_prompt = None

216-221: Reasoning prompt concatenation: add a separating newline.

Prevents accidental token fusion.

Apply this diff:

-                if self.reasoning_prompt:
-                    context = context.strip() + self.reasoning_prompt
+                if self.reasoning_prompt:
+                    context = context.strip() + "\n" + self.reasoning_prompt
                     contexts[i] = context

64-67: dtype choice: make configurable.

Hardcoding bfloat16 can fail on older GPUs; expose via init or detect.

Would you like a patch to auto-select torch.float16 on devices without bf16?
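
A possible shape for that (purely illustrative, not part of the PR; assumes a CUDA-centric setup):

# Hypothetical helper: prefer bf16 when the GPU supports it, else fp16, else fp32.
import torch

def pick_dtype(device: str = "cuda") -> torch.dtype:
    if device.startswith("cuda") and torch.cuda.is_available():
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32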


193-203: Exception type and message style.

Use TypeError for wrong types per Ruff TRY004/TRY003.

Apply this diff:

-            elif not isinstance(until, list):
-                raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str, list], but got {type(until)}")
+            elif not isinstance(until, list):
+                raise TypeError("gen_kwargs['until'] must be str or list[str]")
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b87aa4d and 68ae823.

📒 Files selected for processing (3)
  • examples/models/gemma3.sh (1 hunks)
  • lmms_eval/models/__init__.py (1 hunks)
  • lmms_eval/models/simple/gemma3.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

**/*.py: Type hints are required for all Python code
Public APIs must have docstrings
Maximum line length is 88 characters
Use PEP 8 naming: snake_case for functions/variables
Class names must use PascalCase
Constants should be in UPPER_SNAKE_CASE
Use f-strings for string formatting
Use early returns to avoid nested conditions
Use descriptive names; prefix handler functions with 'handle'
Prefer constants over functions where possible
Prefer functional, immutable approaches when not verbose
Define composing (higher-level) functions before their components
Mark issues in existing code with TODO: prefix in comments
Use functional and stateless approaches where they improve clarity
Use Ruff to enforce: import sorting (I001) and no unused imports
For long strings, wrap using parentheses rather than backslashes
Format long function calls over multiple lines with proper indentation
Split long import lists across multiple lines
Use Pyright type checking: add explicit None checks for Optional values
Use Pyright type narrowing for strings where applicable
Use Ruff (via pre-commit) to format and lint Python files
Document public APIs and test thoroughly

Files:

  • lmms_eval/models/__init__.py
  • lmms_eval/models/simple/gemma3.py
🧬 Code graph analysis (1)
lmms_eval/models/simple/gemma3.py (4)
lmms_eval/api/instance.py (2)
  • Instance (6-29)
  • args (25-29)
lmms_eval/api/model.py (2)
  • lmms (26-331)
  • add_partial (348-352)
lmms_eval/api/registry.py (1)
  • register_model (11-24)
lmms_eval/utils.py (3)
  • Collator (865-1020)
  • chunks (136-171)
  • get_batched (891-913)
🪛 Shellcheck (0.10.0)
examples/models/gemma3.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

🪛 Ruff (0.12.2)
lmms_eval/models/simple/gemma3.py

199-199: Prefer TypeError exception for invalid type

(TRY004)


199-199: Avoid specifying long messages outside the exception class

(TRY003)


237-237: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)

🔇 Additional comments (3)
lmms_eval/models/__init__.py (1)

28-28: Registering gemma3 looks correct.

Mapping resolves to lmms_eval.models.simple.gemma3.Gemma3 as intended.

lmms_eval/models/simple/gemma3.py (2)

269-277: Consider removing tie_weights() call.

Generally unnecessary post-loading; may be a no-op or slow.

Do we need weight tying here? If not, removing can speed init slightly.


186-193: Verified — dependency present and usage consistent. self.task_dict is initialized in lmms_eval/api/model.py and populated by evaluator (lmms_eval/evaluator.py); doc_to_visual is implemented with the expected callable/string signature in lmms_eval/api/task.py; using doc_to_visual[0] is correct in batched code because zip(*chunk) yields tuples.
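
To make the zip(*chunk) point concrete, a generic illustration (the field order here is hypothetical; the real Instance.args layout in lmms-eval may differ):

# zip(*chunk) transposes per-request tuples into per-field tuples, so
# indexing a field with [0] picks that field's value for the first request.
chunk = [("ctx_a", {"max_new_tokens": 32}, "visual_fn", 0),
         ("ctx_b", {"max_new_tokens": 32}, "visual_fn", 1)]
contexts, all_gen_kwargs, doc_to_visual, doc_ids = zip(*chunk)
assert doc_to_visual[0] == "visual_fn"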

Comment on lines 34 to 55
    def __init__(
        self,
        pretrained: str = "google/gemma-3-27b-it",
        device: Optional[str] = "cuda",
        device_map: Optional[str] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = True,
        use_cache=True,
        attn_implementation: Optional[str] = None,
        min_pixels: int = 256 * 28 * 28,
        max_pixels: int = 1605632,
        max_num_frames: int = 32,
        interleave_visuals: Optional[bool] = False,
        system_prompt: Optional[str] = "You are a helpful assistant.",
        reasoning_prompt: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()

⚠️ Potential issue

Conflicting kwargs usage for max_length.

You assert kwargs == {} but later read max_length from kwargs; unreachable path.

Apply this diff to make max_length explicit and remove dead kwargs access:

-        system_prompt: Optional[str] = "You are a helpful assistant.",
-        reasoning_prompt: Optional[str] = None,
-        **kwargs,
+        system_prompt: Optional[str] = "You are a helpful assistant.",
+        reasoning_prompt: Optional[str] = None,
+        max_length: int = 2048,
     ) -> None:
         super().__init__()
-        # Do not use kwargs for now
-        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+        # NOTE: no extra kwargs are accepted to avoid silent misconfigurations
+        # (add parameters explicitly if needed)
@@
-        self._max_length = kwargs.get("max_length", 2048)
+        self._max_length = int(max_length)

Also applies to: 77-79

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 34 to 55 (and also adjust
lines ~77-79), the constructor asserts kwargs == {} but later code reads
max_length from kwargs, making that path unreachable; add max_length as an
explicit parameter (e.g., max_length: Optional[int] = None) to the __init__
signature, remove any reads of max_length from kwargs, and drop the kwargs
assertion only if other legitimate kwargs are expected (otherwise keep assertion
and ensure no kwargs are used anywhere); update the later block at lines ~77-79
to reference the explicit max_length parameter instead of accessing kwargs.

Comment on lines +39 to +41
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = True,
        use_cache=True,

⚠️ Potential issue

batch_size type mismatch.

Annotated as Optional[Union[int, str]] but cast with int(); non‑numeric strings will crash.

Apply this diff to accept only int:

-        batch_size: Optional[Union[int, str]] = 1,
+        batch_size: int = 1,
@@
-        self.batch_size_per_gpu = int(batch_size)
+        self.batch_size_per_gpu = batch_size

Also applies to: 80-80

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 39-41 (and also at line 80),
the batch_size parameter is annotated as Optional[Union[int, str]] but the
implementation casts it with int(), which will crash on non-numeric strings;
change the annotation to Optional[int] and update the code to accept only ints
by removing string handling/casting (or explicitly validate and raise a clear
TypeError if a non-int is passed), and adjust any callers or docstrings
accordingly so batch_size is strictly an int (or None).

Comment on lines +74 to +76
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
        self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

AutoTokenizer.from_pretrained does not accept device_map.

Passing device_map will raise a TypeError; remove it from tokenizer init.

Apply this diff:

-        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            pretrained,
+            trust_remote_code=trust_remote_code,
+        )
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 74-76,
AutoTokenizer.from_pretrained is being called with a device_map keyword which
AutoTokenizer does not accept; remove the device_map argument from the tokenizer
initialization and ensure any device_map usage is applied when loading the model
(not the tokenizer), i.e., call AutoTokenizer.from_pretrained(pretrained,
trust_remote_code=trust_remote_code) and keep device_map only for model loading
logic.

Comment on lines +85 to +88
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.max_num_frames = max_num_frames


🛠️ Refactor suggestion

Unused parameters: interleave_visuals, max_num_frames.

They are not used; either implement or remove to avoid confusion.

Would you like me to wire these into visual processing (e.g., limit frame sampling)?

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 85 to 88, the constructor
assigns max_num_frames and ignores interleave_visuals (unused) which causes
confusion; either remove the unused parameters from the signature and delete
their assignments, or persist and use them: store self.interleave_visuals =
interleave_visuals and self.max_num_frames = max_num_frames (if not already),
then wire them into visual processing code—use interleave_visuals to control
whether visual frames are interleaved with other modalities and enforce
self.max_num_frames when sampling/iterating frames (truncate or sample to that
limit) wherever frames are prepared or batched.
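
If wiring them in is the preferred route, one possible shape for the frame cap (illustrative only; "frames" and sample_frames are hypothetical names, not code from the PR):

# Hypothetical helper: uniformly subsample decoded frames to at most max_num_frames.
def sample_frames(frames: list, max_num_frames: int) -> list:
    if len(frames) <= max_num_frames:
        return frames
    step = len(frames) / max_num_frames
    return [frames[int(i * step)] for i in range(max_num_frames)]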

Comment on lines +114 to +156
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        # return self.tokenizer.eod_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size


🛠️ Refactor suggestion

Add missing type hints and docstrings for public API.

Guidelines require type hints and docstrings; properties and flatten lack them.

Apply this diff:

 class Gemma3(lmms):
-    """
-    Gemma3 Model
-    https://huggingface.co/google/gemma-3-27b-it
-    """
+    """Gemma-3 simple wrapper for lmms-eval.
+
+    Loads a Gemma-3 IT checkpoint and provides batched `generate_until`.
+    """

@@
-    def config(self):
-        # return the associated transformers.AutoConfig for the given pretrained model.
-        return self._config
+    def config(self) -> PretrainedConfig:
+        """Return the underlying Transformers config."""
+        return self._config
@@
-    def tokenizer(self):
-        return self._tokenizer
+    def tokenizer(self) -> PreTrainedTokenizerBase:
+        """Return the tokenizer/processor tokenizer."""
+        return self._tokenizer
@@
-    def model(self):
-        # returns the model, unwrapping it if using Accelerate
+    def model(self) -> Gemma3ForConditionalGeneration:
+        """Return the model, unwrapped if using Accelerate."""
         if hasattr(self, "accelerator"):
             return self.accelerator.unwrap_model(self._model)
         else:
             return self._model
@@
-    def eot_token_id(self):
+    def eot_token_id(self) -> int:
         # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
         # return self.tokenizer.eod_id
         return self.tokenizer.eos_token_id
@@
-    def max_length(self):
+    def max_length(self) -> int:
         return self._max_length
@@
-    def batch_size(self):
+    def batch_size(self) -> int:
         return self.batch_size_per_gpu
@@
-    def device(self):
+    def device(self) -> torch.device:
         return self._device
@@
-    def rank(self):
+    def rank(self) -> int:
         return self._rank
@@
-    def world_size(self):
+    def world_size(self) -> int:
         return self._world_size
@@
-    def flatten(self, input):
-        new_list = []
-        for i in input:
-            for j in i:
-                new_list.append(j)
-        return new_list
+    def flatten(self, items: List[List[str]]) -> List[str]:
+        """Flatten a list-of-lists one level."""
+        return [j for i in items for j in i]
@@
-    def generate_until(self, requests: List[Instance]) -> List[str]:
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        """Generate responses for each Instance until stop tokens."""
         res = []
@@
-    def generate_until_multi_round(self, requests) -> List[str]:
-        raise NotImplementedError("TODO: Implement multi-round generation")
+    def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
+        """Multi-round chat generation (not yet implemented)."""
+        raise NotImplementedError("TODO: Implement multi-round generation")

Also applies to: 160-166, 167-167, 298-300

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 114 to 156 (also apply similar
changes at 160-166, 167, and 298-300), several public @property methods lack
type hints and docstrings; add explicit return type annotations for each
property (config, tokenizer, model, eot_token_id, max_length, batch_size,
device, rank, world_size) and short one-line docstrings describing the returned
value and types (e.g., "Return the transformers.AutoConfig for the pretrained
model."), and fix the batch_size property to return self.batch_size_per_gpu with
the proper type hint; ensure imports support any forward types if needed and
keep docstrings consistent with project style.

Comment on lines +207 to +215
            for i in range(len(contexts)):
                if "<image>" in contexts[i]:
                    contexts[i] = contexts[i].replace("<image>", "")

            batched_messages = []
            for i, context in enumerate(contexts):
                if "<image>" in context:
                    context = context.replace("<image>", "")


🛠️ Refactor suggestion

Duplicate '<image>' cleanup.

You strip twice (Lines 207–211 and 213–215). Keep one.

Apply this diff to remove the first loop:

-            for i in range(len(contexts)):
-                if "<image>" in contexts[i]:
-                    contexts[i] = contexts[i].replace("<image>", "")
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 207 to 215 there is duplicate
removal of the "<image>" token (first loop lines 207–211 and again inside the
batched_messages loop lines 213–215); remove the first standalone loop (lines
207–211) so that "<image>" is only stripped once when building batched_messages,
leaving the single replacement inside the batched_messages creation.

Comment on lines +247 to +251
            if self.device_map == "auto":
                inputs = inputs.to("cuda")
            else:
                inputs = inputs.to(self.device)


🛠️ Refactor suggestion

Avoid double .to() and wrap long call.

Inputs are moved twice; also exceed 88 chars.

Apply this diff:

-            inputs = self.processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", padding="max_length", pad_to_multiple_of=8, max_length=self.max_length).to(
-                self.model.device, dtype=torch.bfloat16
-            )
-
-            if self.device_map == "auto":
-                inputs = inputs.to("cuda")
-            else:
-                inputs = inputs.to(self.device)
+            inputs = self.processor.apply_chat_template(
+                batched_messages,
+                add_generation_prompt=True,
+                tokenize=True,
+                return_dict=True,
+                return_tensors="pt",
+                padding="max_length",
+                pad_to_multiple_of=8,
+                max_length=self.max_length,
+            )
+            target_device = "cuda" if self.device_map == "auto" else self.device
+            inputs = inputs.to(target_device, dtype=torch.bfloat16)

Also applies to: 243-245

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 243-245 and 247-251, the code
calls .to() twice on inputs and creates lines that exceed the 88-char limit;
replace the conditional device selection with a single device variable (e.g.
device = "cuda" if self.device_map == "auto" else self.device) and then call
inputs = inputs.to(device) once, breaking the line if necessary to keep it under
88 characters.

Comment on lines +253 to +268
            default_gen_kwargs = {
                "max_new_tokens": 128,
                "temperature": 0.0,  # Set to 0 for greedy default
                "top_p": None,
                "num_beams": 1,
            }
            # Update with provided kwargs
            current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs}

            if current_gen_kwargs["temperature"] > 0:
                current_gen_kwargs["do_sample"] = True
            else:
                current_gen_kwargs["do_sample"] = False
                current_gen_kwargs["temperature"] = None
                current_gen_kwargs["top_p"] = None


⚠️ Potential issue

Generation kwargs: avoid passing None values to generate().

temperature/top_p=None can raise at runtime. Only pass when set; keep greedy defaults.

Apply this diff:

-            default_gen_kwargs = {
-                "max_new_tokens": 128,
-                "temperature": 0.0,  # Set to 0 for greedy default
-                "top_p": None,
-                "num_beams": 1,
-            }
+            default_gen_kwargs = {
+                "max_new_tokens": 128,
+                "temperature": 0.0,  # greedy
+                "num_beams": 1,
+            }
@@
-            if current_gen_kwargs["temperature"] > 0:
-                current_gen_kwargs["do_sample"] = True
-            else:
-                current_gen_kwargs["do_sample"] = False
-                current_gen_kwargs["temperature"] = None
-                current_gen_kwargs["top_p"] = None
+            do_sample = float(current_gen_kwargs.get("temperature", 0.0)) > 0.0
+            top_p = current_gen_kwargs.get("top_p", None)
+            if do_sample and top_p is None:
+                top_p = 0.95
@@
-            cont = self.model.generate(
-                **inputs,
-                do_sample=current_gen_kwargs["do_sample"],
-                temperature=current_gen_kwargs["temperature"],
-                top_p=current_gen_kwargs["top_p"],
-                num_beams=current_gen_kwargs["num_beams"],
-                max_new_tokens=current_gen_kwargs["max_new_tokens"],
-                use_cache=self.use_cache,
-            )
+            gen_args = dict(
+                **inputs,
+                num_beams=current_gen_kwargs["num_beams"],
+                max_new_tokens=current_gen_kwargs["max_new_tokens"],
+                use_cache=self.use_cache,
+            )
+            if do_sample:
+                gen_args.update(do_sample=True, temperature=float(current_gen_kwargs["temperature"]), top_p=float(top_p))
+            cont = self.model.generate(**gen_args)

Also applies to: 269-277

🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 253-268 (and also apply same
fix to 269-277), the merged generation kwargs may include keys with value None
which will raise at runtime when passed to generate(); after merging
default_gen_kwargs and gen_kwargs, determine do_sample = True if temperature is
provided and > 0, otherwise set do_sample = False and ensure temperature and
top_p are omitted rather than set to None; finally, remove any keys whose value
is None from current_gen_kwargs (or build current_gen_kwargs only from keys with
non-None values) so generate() only receives valid parameters.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
examples/models/gemma3.sh (1)

4-10: Nice: configurables + quoting addressed.

Env‑override defaults and consistent quoting improve reuse and safety. LGTM.

🧹 Nitpick comments (6)
examples/models/gemma3.sh (6)

1-3: Add shebang and strict mode; fix ShellCheck SC2148.

Without a shebang, direct execution depends on the parent shell. Enable strict mode for safer failures.

Apply:

+#!/usr/bin/env bash
+# shellcheck shell=bash
+set -euo pipefail
+IFS=$'\n\t'
 # Run and exactly reproduce gemma3 results!
 # mme as an example

2-2: Fix/clarify the comment (“mme”).

Likely meant MMMU; otherwise reference the exact benchmark.

-# mme as an example
+# MMMU as an example (adjust --tasks as needed)

9-15: Ensure output directory exists; avoid trailing slash.

Prevent runtime failure if logs/ missing and avoid double slashes.

-OUTPUT_PATH="${OUTPUT_PATH:-./logs/}"
+OUTPUT_PATH="${OUTPUT_PATH:-./logs}"
 
+mkdir -p "${OUTPUT_PATH}"
 accelerate launch --num_processes "${NUM_PROCESSES}" --main_process_port "${MAIN_PORT}" -m lmms_eval \
   --model gemma3 \
   --model_args "pretrained=${MODEL_ID}" \
   --tasks "${TASKS}" \
   --batch_size "${BATCH_SIZE}" --output_path "${OUTPUT_PATH}"

6-6: Make model args extensible from the environment.

Allows passing dtype/attn impl/device_map without editing the script.

-MODEL_ID="${MODEL_ID:-google/gemma-3-4b-it}"
+MODEL_ID="${MODEL_ID:-google/gemma-3-4b-it}"
+MODEL_ARGS="${MODEL_ARGS:-pretrained=${MODEL_ID}}"

And update the launch:

-  --model_args "pretrained=${MODEL_ID}" \
+  --model_args "${MODEL_ARGS}" \

4-15: Add an explicit seed for reproducibility.

The script claims exact reproduction; set and pass a seed.

 NUM_PROCESSES="${NUM_PROCESSES:-8}"
 MAIN_PORT="${MAIN_PORT:-12345}"
 MODEL_ID="${MODEL_ID:-google/gemma-3-4b-it}"
 TASKS="${TASKS:-mmmu_val,ai2d,mathvista_testmini}"
 BATCH_SIZE="${BATCH_SIZE:-1}"
 OUTPUT_PATH="${OUTPUT_PATH:-./logs}"
+SEED="${SEED:-1234}"
 
 mkdir -p "${OUTPUT_PATH}"
 accelerate launch --num_processes "${NUM_PROCESSES}" --main_process_port "${MAIN_PORT}" -m lmms_eval \
   --model gemma3 \
-  --model_args "pretrained=${MODEL_ID}" \
+  --model_args "${MODEL_ARGS:-pretrained=${MODEL_ID}}" \
   --tasks "${TASKS}" \
-  --batch_size "${BATCH_SIZE}" --output_path "${OUTPUT_PATH}"
+  --batch_size "${BATCH_SIZE}" --output_path "${OUTPUT_PATH}" \
+  --seed "${SEED}"

Please confirm lmms_eval actually honors --seed globally for these tasks; if not, we should thread it through model/task configs.


11-15: Sanity-check NUM_PROCESSES vs available GPUs.

Hardcoding 8 can over-subscribe smaller machines; consider defaulting to nproc/GPU count or document expectation.

Would you like a small snippet to auto-detect GPUs (via nvidia-smi) and set NUM_PROCESSES accordingly?
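
A lightweight alternative to parsing nvidia-smi (illustrative; this uses torch rather than nvidia-smi and would need to be invoked from the script, e.g. via python -c):

# Hypothetical helper: default NUM_PROCESSES to the visible GPU count (at least 1).
import torch

num_processes = max(torch.cuda.device_count(), 1)
print(num_processes)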

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 68ae823 and c28f92d.

📒 Files selected for processing (1)
  • examples/models/gemma3.sh (1 hunks)
🧰 Additional context used
🪛 Shellcheck (0.10.0)
examples/models/gemma3.sh

[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.

(SC2148)

@Luodian Luodian merged commit af2ae71 into EvolvingLMMs-Lab:main Sep 17, 2025
1 check passed
@zzhbrr (Contributor) commented Sep 17, 2025

Hi, @RadhaGulhane13.
After running TASKS=mme bash examples/models/gemma3.sh, the error reported is:

Model Responding:   0%|                                                                                                      | 0/2374 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/home/zzh/llmserving/benchmark/lmms-eval/lmms_eval/__main__.py", line 347, in cli_evaluate
    results, samples = cli_evaluate_single(args)
  File "/home/zzh/llmserving/benchmark/lmms-eval/lmms_eval/__main__.py", line 474, in cli_evaluate_single
    results = evaluator.simple_evaluate(
  File "/home/zzh/llmserving/benchmark/lmms-eval/lmms_eval/utils.py", line 536, in _wrapper
    return fn(*args, **kwargs)
  File "/home/zzh/llmserving/benchmark/lmms-eval/lmms_eval/evaluator.py", line 268, in simple_evaluate
    results = evaluate(
  File "/home/zzh/llmserving/benchmark/lmms-eval/lmms_eval/utils.py", line 536, in _wrapper
    return fn(*args, **kwargs)
  File "/home/zzh/llmserving/benchmark/lmms-eval/lmms_eval/evaluator.py", line 501, in evaluate
    resps = getattr(lm, reqtype)(cloned_reqs)  # Choiszt run generate until
  File "/home/zzh/llmserving/benchmark/lmms-eval/lmms_eval/models/simple/gemma3.py", line 300, in generate_until
    cont = self.model.generate(
  File "/home/zzh/anaconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1962, in __getattr__
    raise AttributeError(
AttributeError: 'Gemma3Model' object has no attribute 'generate'
2025-09-17 22:57:53 | ERROR    | __main__:cli_evaluate:369 - Error during evaluation: 'Gemma3Model' object has no attribute 'generate'. Please set `--verbosity=DEBUG` to get more information.

Do you know what the problem is?

@RadhaGulhane13 (Contributor, Author) replied:
(quoting @zzhbrr) After running TASKS=mme bash examples/models/gemma3.sh, the error is AttributeError: 'Gemma3Model' object has no attribute 'generate'. Do you know what the problem is?

I see the problem here. During the code refactor, the following line:

self._model = Gemma3ForConditionalGeneration.from_pretrained(pretrained, **model_kwargs).eval()

was replaced with:

try:
    self._model = AutoModelForVision2Seq.from_pretrained(pretrained, **model_kwargs).eval()
except Exception:
    # Fallback to a more generic approach if specific model class not found
    from transformers import AutoModel

    self._model = AutoModel.from_pretrained(pretrained, **model_kwargs).eval()

The issue is that AutoModel loads Gemma3Model, but the correct class for generation is Gemma3ForConditionalGeneration.

Refer to my commit https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/68ae823f405dc66e10cbac440a546ac629d7d005
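
For anyone hitting this before a fix lands, a quick sanity check (illustrative only) that the loaded class is generation-capable:

from transformers import Gemma3ForConditionalGeneration

model = Gemma3ForConditionalGeneration.from_pretrained("google/gemma-3-4b-it")
# AutoModel returns the base Gemma3Model, which lacks generate(); the
# ForConditionalGeneration class provides it.
assert hasattr(model, "generate")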

@Luodian (Contributor) commented Sep 18, 2025

Sorry, I incorrectly accepted changes made by the code agent; I will fix this soon.

@zzhbrr (Contributor) commented Sep 18, 2025

Thanks!!! @RadhaGulhane13 @Luodian
