[Feature] Support for Gemma-3 Models #821
Conversation
Signed-off-by: Radha Gulhane <[email protected]>
Caution: Review failed. The pull request is closed.
Walkthrough
Adds a new simple model wrapper Gemma3 (registered as "gemma3"), updates the model registry, and adds an example shell script to run lmms_eval with a Gemma-3 IT checkpoint via accelerate for specified evaluation tasks.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
actor User
participant CLI as lmms_eval (CLI)
participant Registry as ModelRegistry
participant Gemma as Gemma3
participant Proc as Processor/Tokenizer
participant HF as HF Model (Gemma-3 IT)
User->>CLI: accelerate launch ... --model gemma3 --tasks ...
CLI->>Registry: resolve "gemma3"
Registry-->>CLI: lmms_eval.models.simple.gemma3.Gemma3
CLI->>Gemma: initialize(pretrained, device_map, ...)
rect rgba(220,240,255,0.4)
Gemma->>Proc: load tokenizer/processor
Gemma->>HF: load model (single/MULTI_GPU/FSDP aware)
end
loop For each batched request group
CLI->>Gemma: generate_until(requests)
Gemma->>Gemma: collate/sort by length, assemble messages + visuals
Gemma->>Proc: apply_chat_template -> tensors
Proc-->>Gemma: input tensors (ids, pixel_values, masks)
Gemma->>HF: generate(...)
HF-->>Gemma: output ids
Gemma->>Proc: decode + apply stoppers
Gemma-->>CLI: ordered text results
end
CLI-->>User: metrics and outputs
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Poem
Pre-merge checks and finishing touches
✅ Passed checks (3 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
Actionable comments posted: 10
🧹 Nitpick comments (7)
examples/models/gemma3.sh (1)
1-1: Add a shebang and strict mode for portability and safety.
Without a shebang, shells may interpret differently; strict mode avoids silent failures.
Apply this diff:
+#!/usr/bin/env bash
+set -euo pipefail
+
 # Run and exactly reproduce gemma3 results!
 # mme as an example

lmms_eval/models/simple/gemma3.py (6)
236-239: Prefer iterable unpacking for list assembly.
Cleaner and avoids an intermediate list.
Apply this diff:
-            message.append(
-                {
-                    "role": "user",
-                    "content": processed_visuals + [{"type": "text", "text": context}],
-                }
-            )
+            message.append(
+                {"role": "user", "content": [*processed_visuals, {"type": "text", "text": context}]}
+            )
20-22: Over-broad warning suppression.
Global ignore can hide real issues; scope or narrow categories instead.
Apply this diff:
-warnings.simplefilter("ignore", category=DeprecationWarning)
-warnings.filterwarnings("ignore")
+warnings.filterwarnings("ignore", category=DeprecationWarning)
89-93: Optional check: be explicit with Optional comparisons (Pyright).
Use `is not None` to satisfy strict Optional checks.
Apply this diff:
-        if reasoning_prompt:
+        if reasoning_prompt is not None:
             self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n")
         else:
             self.reasoning_prompt = None
216-221: Reasoning prompt concatenation: add a separating newline.
Prevents accidental token fusion.
Apply this diff:
-            if self.reasoning_prompt:
-                context = context.strip() + self.reasoning_prompt
+            if self.reasoning_prompt:
+                context = context.strip() + "\n" + self.reasoning_prompt
                 contexts[i] = context
64-67: dtype choice: make configurable.
Hardcoding bfloat16 can fail on older GPUs; expose via init or detect.
Would you like a patch to auto-select torch.float16 on devices without bf16?
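A minimal sketch of the kind of dtype fallback suggested here, using standard torch capability checks (the helper name is illustrative and not part of the PR):

```python
import torch


def pick_dtype(device: str = "cuda") -> torch.dtype:
    """Prefer bfloat16 where supported, else float16 on CUDA, else float32."""
    if device.startswith("cuda") and torch.cuda.is_available():
        # Ampere and newer GPUs report bf16 support; older GPUs fall back to fp16.
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    return torch.float32
```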
193-203: Exception type and message style.
Use TypeError for wrong types per Ruff TRY004/TRY003.
Apply this diff:
-        elif not isinstance(until, list):
-            raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str, list], but got {type(until)}")
+        elif not isinstance(until, list):
+            raise TypeError("gen_kwargs['until'] must be str or list[str]")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
examples/models/gemma3.sh (1 hunks)
lmms_eval/models/__init__.py (1 hunks)
lmms_eval/models/simple/gemma3.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
📄 CodeRabbit inference engine (CLAUDE.md)
**/*.py: Type hints are required for all Python code
Public APIs must have docstrings
Maximum line length is 88 characters
Use PEP 8 naming: snake_case for functions/variables
Class names must use PascalCase
Constants should be in UPPER_SNAKE_CASE
Use f-strings for string formatting
Use early returns to avoid nested conditions
Use descriptive names; prefix handler functions with 'handle'
Prefer constants over functions where possible
Prefer functional, immutable approaches when not verbose
Define composing (higher-level) functions before their components
Mark issues in existing code with TODO: prefix in comments
Use functional and stateless approaches where they improve clarity
Use Ruff to enforce: import sorting (I001) and no unused imports
For long strings, wrap using parentheses rather than backslashes
Format long function calls over multiple lines with proper indentation
Split long import lists across multiple lines
Use Pyright type checking: add explicit None checks for Optional values
Use Pyright type narrowing for strings where applicable
Use Ruff (via pre-commit) to format and lint Python files
Document public APIs and test thoroughly
Files:
lmms_eval/models/__init__.py
lmms_eval/models/simple/gemma3.py
🧬 Code graph analysis (1)
lmms_eval/models/simple/gemma3.py (4)
lmms_eval/api/instance.py (2)
  Instance (6-29)
  args (25-29)
lmms_eval/api/model.py (2)
  lmms (26-331)
  add_partial (348-352)
lmms_eval/api/registry.py (1)
  register_model (11-24)
lmms_eval/utils.py (3)
  Collator (865-1020)
  chunks (136-171)
  get_batched (891-913)
🪛 Shellcheck (0.10.0)
examples/models/gemma3.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
🪛 Ruff (0.12.2)
lmms_eval/models/simple/gemma3.py
199-199: Prefer TypeError exception for invalid type
(TRY004)
199-199: Avoid specifying long messages outside the exception class
(TRY003)
237-237: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
🔇 Additional comments (3)
lmms_eval/models/__init__.py (1)
28-28: Registering gemma3 looks correct.
Mapping resolves to lmms_eval.models.simple.gemma3.Gemma3 as intended.
lmms_eval/models/simple/gemma3.py (2)
269-277: Consider removing tie_weights() call.
Generally unnecessary post-loading; may be a no-op or slow.
Do we need weight tying here? If not, removing can speed init slightly.
186-193: Verified — dependency present and usage consistent. self.task_dict is initialized in lmms_eval/api/model.py and populated by evaluator (lmms_eval/evaluator.py); doc_to_visual is implemented with the expected callable/string signature in lmms_eval/api/task.py; using doc_to_visual[0] is correct in batched code because zip(*chunk) yields tuples.
    def __init__(
        self,
        pretrained: str = "google/gemma-3-27b-it",
        device: Optional[str] = "cuda",
        device_map: Optional[str] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = True,
        use_cache=True,
        attn_implementation: Optional[str] = None,
        min_pixels: int = 256 * 28 * 28,
        max_pixels: int = 1605632,
        max_num_frames: int = 32,
        interleave_visuals: Optional[bool] = False,
        system_prompt: Optional[str] = "You are a helpful assistant.",
        reasoning_prompt: Optional[str] = None,
        **kwargs,
    ) -> None:
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
Conflicting kwargs usage for max_length.
You assert kwargs == {} but later read max_length from kwargs; unreachable path.
Apply this diff to make max_length explicit and remove dead kwargs access:
- system_prompt: Optional[str] = "You are a helpful assistant.",
- reasoning_prompt: Optional[str] = None,
- **kwargs,
+ system_prompt: Optional[str] = "You are a helpful assistant.",
+ reasoning_prompt: Optional[str] = None,
+ max_length: int = 2048,
) -> None:
super().__init__()
- # Do not use kwargs for now
- assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+ # NOTE: no extra kwargs are accepted to avoid silent misconfigurations
+ # (add parameters explicitly if needed)
@@
- self._max_length = kwargs.get("max_length", 2048)
+        self._max_length = int(max_length)

Also applies to: 77-79
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 34 to 55 (and also adjust
lines ~77-79), the constructor asserts kwargs == {} but later code reads
max_length from kwargs, making that path unreachable; add max_length as an
explicit parameter (e.g., max_length: Optional[int] = None) to the __init__
signature, remove any reads of max_length from kwargs, and drop the kwargs
assertion only if other legitimate kwargs are expected (otherwise keep assertion
and ensure no kwargs are used anywhere); update the later block at lines ~77-79
to reference the explicit max_length parameter instead of accessing kwargs.
        batch_size: Optional[Union[int, str]] = 1,
        trust_remote_code: Optional[bool] = True,
        use_cache=True,
batch_size type mismatch.
Annotated as Optional[Union[int, str]] but cast with int(); non‑numeric strings will crash.
Apply this diff to accept only int:
- batch_size: Optional[Union[int, str]] = 1,
+ batch_size: int = 1,
@@
- self.batch_size_per_gpu = int(batch_size)
+        self.batch_size_per_gpu = batch_size

Also applies to: 80-80
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 39-41 (and also at line 80),
the batch_size parameter is annotated as Optional[Union[int, str]] but the
implementation casts it with int(), which will crash on non-numeric strings;
change the annotation to Optional[int] and update the code to accept only ints
by removing string handling/casting (or explicitly validate and raise a clear
TypeError if a non-int is passed), and adjust any callers or docstrings
accordingly so batch_size is strictly an int (or None).
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
        self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)
AutoTokenizer.from_pretrained does not accept device_map.
Passing device_map will raise a TypeError; remove it from tokenizer init.
Apply this diff:
- self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
+ self._tokenizer = AutoTokenizer.from_pretrained(
+ pretrained,
+ trust_remote_code=trust_remote_code,
+ )📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=trust_remote_code, device_map=self.device_map)
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            pretrained,
+            trust_remote_code=trust_remote_code,
+        )
         self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels)
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 74-76,
AutoTokenizer.from_pretrained is being called with a device_map keyword which
AutoTokenizer does not accept; remove the device_map argument from the tokenizer
initialization and ensure any device_map usage is applied when loading the model
(not the tokenizer), i.e., call AutoTokenizer.from_pretrained(pretrained,
trust_remote_code=trust_remote_code) and keep device_map only for model loading
logic.
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.max_num_frames = max_num_frames
🛠️ Refactor suggestion
Unused parameters: interleave_visuals, max_num_frames.
They are not used; either implement or remove to avoid confusion.
Would you like me to wire these into visual processing (e.g., limit frame sampling)?
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 85 to 88, the constructor
assigns max_num_frames and ignores interleave_visuals (unused) which causes
confusion; either remove the unused parameters from the signature and delete
their assignments, or persist and use them: store self.interleave_visuals =
interleave_visuals and self.max_num_frames = max_num_frames (if not already),
then wire them into visual processing code—use interleave_visuals to control
whether visual frames are interleaved with other modalities and enforce
self.max_num_frames when sampling/iterating frames (truncate or sample to that
limit) wherever frames are prepared or batched.
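A minimal sketch of the frame-limiting idea raised above, assuming decoded frames arrive as a list of PIL images (the helper name and call site are hypothetical, not taken from the PR):

```python
from typing import List

from PIL import Image


def sample_frames(frames: List[Image.Image], max_num_frames: int) -> List[Image.Image]:
    """Uniformly subsample a frame list so that at most max_num_frames are kept."""
    if max_num_frames <= 0:
        return []
    if len(frames) <= max_num_frames:
        return list(frames)
    if max_num_frames == 1:
        return [frames[0]]
    # Evenly spaced indices across the clip, always keeping the first and last frame.
    step = (len(frames) - 1) / (max_num_frames - 1)
    indices = [round(i * step) for i in range(max_num_frames)]
    return [frames[i] for i in indices]
```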
    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        # return self.tokenizer.eod_id
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size
🛠️ Refactor suggestion
Add missing type hints and docstrings for public API.
Guidelines require type hints and docstrings; properties and flatten lack them.
Apply this diff:
class Gemma3(lmms):
- """
- Gemma3 Model
- https://huggingface.co/google/gemma-3-27b-it
- """
+ """Gemma-3 simple wrapper for lmms-eval.
+
+ Loads a Gemma-3 IT checkpoint and provides batched `generate_until`.
+ """
@@
- def config(self):
- # return the associated transformers.AutoConfig for the given pretrained model.
- return self._config
+ def config(self) -> PretrainedConfig:
+ """Return the underlying Transformers config."""
+ return self._config
@@
- def tokenizer(self):
- return self._tokenizer
+ def tokenizer(self) -> PreTrainedTokenizerBase:
+ """Return the tokenizer/processor tokenizer."""
+ return self._tokenizer
@@
- def model(self):
- # returns the model, unwrapping it if using Accelerate
+ def model(self) -> Gemma3ForConditionalGeneration:
+ """Return the model, unwrapped if using Accelerate."""
if hasattr(self, "accelerator"):
return self.accelerator.unwrap_model(self._model)
else:
return self._model
@@
- def eot_token_id(self):
+ def eot_token_id(self) -> int:
# we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
# return self.tokenizer.eod_id
return self.tokenizer.eos_token_id
@@
- def max_length(self):
+ def max_length(self) -> int:
return self._max_length
@@
- def batch_size(self):
+ def batch_size(self) -> int:
return self.batch_size_per_gpu
@@
- def device(self):
+ def device(self) -> torch.device:
return self._device
@@
- def rank(self):
+ def rank(self) -> int:
return self._rank
@@
- def world_size(self):
+ def world_size(self) -> int:
return self._world_size
@@
- def flatten(self, input):
- new_list = []
- for i in input:
- for j in i:
- new_list.append(j)
- return new_list
+ def flatten(self, items: List[List[str]]) -> List[str]:
+ """Flatten a list-of-lists one level."""
+ return [j for i in items for j in i]
@@
- def generate_until(self, requests: List[Instance]) -> List[str]:
+ def generate_until(self, requests: List[Instance]) -> List[str]:
+ """Generate responses for each Instance until stop tokens."""
res = []
@@
- def generate_until_multi_round(self, requests) -> List[str]:
- raise NotImplementedError("TODO: Implement multi-round generation")
+ def generate_until_multi_round(self, requests: List[Instance]) -> List[str]:
+ """Multi-round chat generation (not yet implemented)."""
+        raise NotImplementedError("TODO: Implement multi-round generation")

Also applies to: 160-166, 167-167, 298-300
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 114 to 156 (also apply similar
changes at 160-166, 167, and 298-300), several public @property methods lack
type hints and docstrings; add explicit return type annotations for each
property (config, tokenizer, model, eot_token_id, max_length, batch_size,
device, rank, world_size) and short one-line docstrings describing the returned
value and types (e.g., "Return the transformers.AutoConfig for the pretrained
model."), and fix the batch_size property to return self.batch_size_per_gpu with
the proper type hint; ensure imports support any forward types if needed and
keep docstrings consistent with project style.
        for i in range(len(contexts)):
            if "<image>" in contexts[i]:
                contexts[i] = contexts[i].replace("<image>", "")

        batched_messages = []
        for i, context in enumerate(contexts):
            if "<image>" in context:
                context = context.replace("<image>", "")
🛠️ Refactor suggestion
You strip twice (Lines 207–211 and 213–215). Keep one.
Apply this diff to remove the first loop:
- for i in range(len(contexts)):
- if "<image>" in contexts[i]:
- contexts[i] = contexts[i].replace("<image>", "")📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-        for i in range(len(contexts)):
-            if "<image>" in contexts[i]:
-                contexts[i] = contexts[i].replace("<image>", "")
         batched_messages = []
         for i, context in enumerate(contexts):
             if "<image>" in context:
                 context = context.replace("<image>", "")
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 207 to 215 there is duplicate
removal of the "<image>" token (first loop lines 207–211 and again inside the
batched_messages loop lines 213–215); remove the first standalone loop (lines
207–211) so that "<image>" is only stripped once when building batched_messages,
leaving the single replacement inside the batched_messages creation.
            if self.device_map == "auto":
                inputs = inputs.to("cuda")
            else:
                inputs = inputs.to(self.device)
🛠️ Refactor suggestion
Avoid double .to() and wrap the long call.
Inputs are moved twice; the call also exceeds 88 characters.
Inputs are moved twice; also exceed 88 chars.
Apply this diff:
- inputs = self.processor.apply_chat_template(batched_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", padding="max_length", pad_to_multiple_of=8, max_length=self.max_length).to(
- self.model.device, dtype=torch.bfloat16
- )
-
- if self.device_map == "auto":
- inputs = inputs.to("cuda")
- else:
- inputs = inputs.to(self.device)
+ inputs = self.processor.apply_chat_template(
+ batched_messages,
+ add_generation_prompt=True,
+ tokenize=True,
+ return_dict=True,
+ return_tensors="pt",
+ padding="max_length",
+ pad_to_multiple_of=8,
+ max_length=self.max_length,
+ )
+ target_device = "cuda" if self.device_map == "auto" else self.device
+            inputs = inputs.to(target_device, dtype=torch.bfloat16)

Also applies to: 243-245
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 243-245 and 247-251, the code
calls .to() twice on inputs and creates lines that exceed the 88-char limit;
replace the conditional device selection with a single device variable (e.g.
device = "cuda" if self.device_map == "auto" else self.device) and then call
inputs = inputs.to(device) once, breaking the line if necessary to keep it under
88 characters.
        default_gen_kwargs = {
            "max_new_tokens": 128,
            "temperature": 0.0,  # Set to 0 for greedy default
            "top_p": None,
            "num_beams": 1,
        }
        # Update with provided kwargs
        current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs}

        if current_gen_kwargs["temperature"] > 0:
            current_gen_kwargs["do_sample"] = True
        else:
            current_gen_kwargs["do_sample"] = False
            current_gen_kwargs["temperature"] = None
            current_gen_kwargs["top_p"] = None
Generation kwargs: avoid passing None values to generate().
temperature/top_p=None can raise at runtime. Only pass when set; keep greedy defaults.
Apply this diff:
- default_gen_kwargs = {
- "max_new_tokens": 128,
- "temperature": 0.0, # Set to 0 for greedy default
- "top_p": None,
- "num_beams": 1,
- }
+ default_gen_kwargs = {
+ "max_new_tokens": 128,
+ "temperature": 0.0, # greedy
+ "num_beams": 1,
+ }
@@
- if current_gen_kwargs["temperature"] > 0:
- current_gen_kwargs["do_sample"] = True
- else:
- current_gen_kwargs["do_sample"] = False
- current_gen_kwargs["temperature"] = None
- current_gen_kwargs["top_p"] = None
+ do_sample = float(current_gen_kwargs.get("temperature", 0.0)) > 0.0
+ top_p = current_gen_kwargs.get("top_p", None)
+ if do_sample and top_p is None:
+ top_p = 0.95
@@
- cont = self.model.generate(
- **inputs,
- do_sample=current_gen_kwargs["do_sample"],
- temperature=current_gen_kwargs["temperature"],
- top_p=current_gen_kwargs["top_p"],
- num_beams=current_gen_kwargs["num_beams"],
- max_new_tokens=current_gen_kwargs["max_new_tokens"],
- use_cache=self.use_cache,
- )
+ gen_args = dict(
+ **inputs,
+ num_beams=current_gen_kwargs["num_beams"],
+ max_new_tokens=current_gen_kwargs["max_new_tokens"],
+ use_cache=self.use_cache,
+ )
+ if do_sample:
+ gen_args.update(do_sample=True, temperature=float(current_gen_kwargs["temperature"]), top_p=float(top_p))
+            cont = self.model.generate(**gen_args)

Also applies to: 269-277
🤖 Prompt for AI Agents
In lmms_eval/models/simple/gemma3.py around lines 253-268 (and also apply same
fix to 269-277), the merged generation kwargs may include keys with value None
which will raise at runtime when passed to generate(); after merging
default_gen_kwargs and gen_kwargs, determine do_sample = True if temperature is
provided and > 0, otherwise set do_sample = False and ensure temperature and
top_p are omitted rather than set to None; finally, remove any keys whose value
is None from current_gen_kwargs (or build current_gen_kwargs only from keys with
non-None values) so generate() only receives valid parameters.
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Actionable comments posted: 0
♻️ Duplicate comments (1)
examples/models/gemma3.sh (1)
4-10: Nice: configurables + quoting addressed.
Env-override defaults and consistent quoting improve reuse and safety. LGTM.
🧹 Nitpick comments (6)
examples/models/gemma3.sh (6)
1-3: Add shebang and strict mode; fix ShellCheck SC2148.
Without a shebang, direct execution depends on the parent shell. Enable strict mode for safer failures.
Apply:
+#!/usr/bin/env bash
+# shellcheck shell=bash
+set -euo pipefail
+IFS=$'\n\t'
 # Run and exactly reproduce gemma3 results!
 # mme as an example
2-2: Fix/clarify the comment (“mme”).
Likely meant MMMU; otherwise reference the exact benchmark.
-# mme as an example
+# MMMU as an example (adjust --tasks as needed)
9-15: Ensure output directory exists; avoid trailing slash.
Prevent runtime failure if logs/ missing and avoid double slashes.
-OUTPUT_PATH="${OUTPUT_PATH:-./logs/}"
+OUTPUT_PATH="${OUTPUT_PATH:-./logs}"
+mkdir -p "${OUTPUT_PATH}"
 accelerate launch --num_processes "${NUM_PROCESSES}" --main_process_port "${MAIN_PORT}" -m lmms_eval \
     --model gemma3 \
     --model_args "pretrained=${MODEL_ID}" \
     --tasks "${TASKS}" \
     --batch_size "${BATCH_SIZE}" --output_path "${OUTPUT_PATH}"
6-6: Make model args extensible from the environment.
Allows passing dtype/attn impl/device_map without editing the script.
-MODEL_ID="${MODEL_ID:-google/gemma-3-4b-it}"
+MODEL_ID="${MODEL_ID:-google/gemma-3-4b-it}"
+MODEL_ARGS="${MODEL_ARGS:-pretrained=${MODEL_ID}}"

And update the launch:
-    --model_args "pretrained=${MODEL_ID}" \
+    --model_args "${MODEL_ARGS}" \
4-15: Add an explicit seed for reproducibility.
The script claims exact reproduction; set and pass a seed.
 NUM_PROCESSES="${NUM_PROCESSES:-8}"
 MAIN_PORT="${MAIN_PORT:-12345}"
 MODEL_ID="${MODEL_ID:-google/gemma-3-4b-it}"
 TASKS="${TASKS:-mmmu_val,ai2d,mathvista_testmini}"
 BATCH_SIZE="${BATCH_SIZE:-1}"
 OUTPUT_PATH="${OUTPUT_PATH:-./logs}"
+SEED="${SEED:-1234}"
 mkdir -p "${OUTPUT_PATH}"
 accelerate launch --num_processes "${NUM_PROCESSES}" --main_process_port "${MAIN_PORT}" -m lmms_eval \
     --model gemma3 \
-    --model_args "pretrained=${MODEL_ID}" \
+    --model_args "${MODEL_ARGS:-pretrained=${MODEL_ID}}" \
     --tasks "${TASKS}" \
-    --batch_size "${BATCH_SIZE}" --output_path "${OUTPUT_PATH}"
+    --batch_size "${BATCH_SIZE}" --output_path "${OUTPUT_PATH}" \
+    --seed "${SEED}"

Please confirm lmms_eval actually honors --seed globally for these tasks; if not, we should thread it through model/task configs.
11-15: Sanity-check NUM_PROCESSES vs available GPUs.
Hardcoding 8 can over-subscribe smaller machines; consider defaulting to nproc/GPU count or document expectation.
Would you like a small snippet to auto-detect GPUs (via nvidia-smi) and set NUM_PROCESSES accordingly?
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
examples/models/gemma3.sh (1 hunks)
🧰 Additional context used
🪛 Shellcheck (0.10.0)
examples/models/gemma3.sh
[error] 1-1: Tips depend on target shell and yours is unknown. Add a shebang or a 'shell' directive.
(SC2148)
Hi, @RadhaGulhane13. Do you know what the problem is?
I see the problem here. During the code refactor, the line that loaded the model with Gemma3ForConditionalGeneration was replaced with an AutoModel-based load. The issue is that AutoModel loads Gemma3Model, but the correct class for generation is Gemma3ForConditionalGeneration. Refer to my commit https://github.com/EvolvingLMMs-Lab/lmms-eval/tree/68ae823f405dc66e10cbac440a546ac629d7d005
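For reference, a minimal sketch of the generation-capable loading path, assuming the standard transformers API (the checkpoint, dtype, and device_map here are illustrative, not copied from the PR):

```python
import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

pretrained = "google/gemma-3-4b-it"  # illustrative checkpoint

# AutoModel resolves to Gemma3Model (no LM head); Gemma3ForConditionalGeneration
# is the class that actually supports .generate().
model = Gemma3ForConditionalGeneration.from_pretrained(
    pretrained,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
processor = AutoProcessor.from_pretrained(pretrained)
```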
Sorry, I incorrectly accepted changes made by the Code Agent; I will fix this soon.
Thanks!!! @RadhaGulhane13 @Luodian
Description
This PR introduces initial support for Gemma-3 models.
Current evaluations indicate good alignment with VQA tasks and reproducibility on MATH.
Community contributions to help reproduce results on MMMU-pro and MMMU-val would be highly valuable.
Evaluation
Example Usage
An example usage script is available at:
examples/models/gemma3.sh

Related Issues
#664
Summary by CodeRabbit
New Features
Chores