
Commit 7761876

[Feat] Support VideoLLaMA3 (#588)
* add videollama3
* update
1 parent c8c146e commit 7761876

File tree

2 files changed: +282 -0 lines changed


lmms_eval/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@
     "tinyllava": "TinyLlava",
     "videoChatGPT": "VideoChatGPT",
     "videochat2": "VideoChat2",
+    "videollama3": "VideoLLaMA3",
     "video_llava": "VideoLLaVA",
     "vila": "VILA",
     "vita": "VITA",

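The single added line above registers the CLI name "videollama3" against the VideoLLaMA3 class implemented in the new file below. As a rough sketch of how such a name-to-class entry is typically resolved (the dict name AVAILABLE_MODELS, the importlib lookup, and the resolve_model helper are illustrative assumptions, not shown in this diff):

import importlib

# Mirrors the mapping style used in lmms_eval/models/__init__.py (dict name assumed).
AVAILABLE_MODELS = {"videollama3": "VideoLLaMA3"}

def resolve_model(name: str):
    # Lazily import lmms_eval/models/<name>.py and return the registered class.
    class_name = AVAILABLE_MODELS[name]
    module = importlib.import_module(f"lmms_eval.models.{name}")
    return getattr(module, class_name)

# e.g. model_cls = resolve_model("videollama3")

The point of a string-valued mapping like this is that heavyweight model dependencies are not imported until the model is actually selected on the command line.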
lmms_eval/models/videollama3.py

Lines changed: 281 additions & 0 deletions
@@ -0,0 +1,281 @@
import base64
import uuid
from io import BytesIO
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from accelerate import Accelerator, DistributedType
from decord import VideoReader, cpu
from loguru import logger as eval_logger
from PIL import Image
from tqdm import tqdm
from transformers import (
    AutoImageProcessor,
    AutoModel,
    AutoModelForCausalLM,
    AutoProcessor,
)

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model


@register_model("videollama3")
class VideoLLaMA3(lmms):
    """
    VideoLLaMA3 Model
    Video checkpoint from Hugging Face: DAMO-NLP-SG/VideoLLaMA3-7B
    Image checkpoint from Hugging Face: DAMO-NLP-SG/VideoLLaMA3-7B-Image

    Example usage:

    accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \
        --model videollama3 \
        --model_args pretrained=DAMO-NLP-SG/VideoLLaMA3-7B \
        --tasks mvbench \
        --batch_size 1 \
        --log_samples \
        --log_samples_suffix debug \
        --output_path ./logs/

    accelerate launch --num_processes=8 --main_process_port 12345 -m lmms_eval \
        --model videollama3 \
        --model_args pretrained=DAMO-NLP-SG/VideoLLaMA3-7B-Image \
        --tasks docvqa_test \
        --batch_size 1 \
        --log_samples \
        --log_samples_suffix debug \
        --output_path ./logs/
    """

    def __init__(
        self,
        pretrained: str = "DAMO-NLP-SG/VideoLLaMA3-7B",
        device: Optional[str] = "cuda",
        device_map: Optional[str] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        use_flash_attention_2: Optional[bool] = True,
        max_num_frames: int = 180,
        use_custom_video_loader=False,  # True for video-mmmu
        **kwargs,
    ) -> None:
        super().__init__()
        # Do not use kwargs for now
        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        elif accelerator.num_processes == 1 and device_map == "auto":
            self._device = torch.device(device)
            self.device_map = device_map
        else:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"

        if use_flash_attention_2:
            self._model = AutoModelForCausalLM.from_pretrained(
                pretrained,
                trust_remote_code=True,
                device_map=self.device_map,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
            )
        else:
            self._model = AutoModelForCausalLM.from_pretrained(
                pretrained,
                trust_remote_code=True,
                device_map=self.device_map,
                torch_dtype=torch.bfloat16,
            )
        self.processor = AutoProcessor.from_pretrained(pretrained, trust_remote_code=True)
        self.tokenizer = self.processor.tokenizer
        self.max_num_frames = max_num_frames
        self.batch_size_per_gpu = int(batch_size)
        self.use_custom_video_loader = use_custom_video_loader

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            if accelerator.distributed_type == DistributedType.FSDP:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Loglikelihood is not implemented for VideoLLaMA3")

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            # padded context length. this is useful to simplify the batching logic and more importantly to make
            # automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tokenizer.encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
            visuals = self.flatten(visuals)

            gen_kwargs = all_gen_kwargs[0]

            message = []

            processed_visuals = []
            for i, context in enumerate(contexts):
                if len(visuals) > 0:
                    visual = visuals[i] if i < len(visuals) else None
                    if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")):  # Video file
                        if self.use_custom_video_loader:
                            frames, timestamps = read_video_custom(visual)
                            message.append({"role": "user", "content": [{"type": "video", "video": frames, "timestamps": timestamps, "num_frames": len(timestamps)}, {"type": "text", "text": context}]})
                        else:
                            message.append({"role": "user", "content": [{"type": "video", "video": {"video_path": visual, "fps": 1, "max_frames": self.max_num_frames}}, {"type": "text", "text": context}]})
                    elif isinstance(visual, Image.Image):
                        message.append({"role": "user", "content": [{"type": "image", "image": visual}, {"type": "text", "text": context}]})
                    elif isinstance(visual, (list, tuple)) and all(isinstance(v, Image.Image) for v in visual):  # Multiple images
                        image_content = []
                        for v in visual:
                            image_content.append({"type": "image", "image": v})
                        message.append({"role": "user", "content": image_content + [{"type": "text", "text": context}]})
                else:
                    message.append({"role": "user", "content": [{"type": "text", "text": context}]})
            inputs = self.processor(conversation=message, return_tensors="pt", add_generation_prompt=True)

            do_sample = gen_kwargs.get("do_sample", False)
            temperature = gen_kwargs.get("temperature", 0.2 if do_sample else 1.0)
            top_p = gen_kwargs.get("top_p", 0.9 if do_sample else 1.0)
            top_k = gen_kwargs.get("top_k", 20 if do_sample else 50)
            max_new_tokens = gen_kwargs.get("max_new_tokens", 2048)

            inputs = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
            if "pixel_values" in inputs:
                inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)

            with torch.inference_mode():
                output_ids = self.model.generate(
                    **inputs,
                    do_sample=do_sample,
                    temperature=temperature,
                    max_new_tokens=max_new_tokens,
                    top_p=top_p,
                    top_k=top_k,
                    use_cache=True,
                    pad_token_id=self.processor.tokenizer.eos_token_id,
                )

            answers = self.processor.batch_decode(output_ids, skip_special_tokens=True)

            for i, ans in enumerate(answers):
                answers[i] = ans.strip()

            for ans, context in zip(answers, contexts):
                res.append(ans)
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
                pbar.update(1)

        res = re_ords.get_original(res)

        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation")


def read_video_custom(video_path, fps=1, max_frames_num=180, force_include_last_frame=True):
    vr = VideoReader(video_path, ctx=cpu(0))
    duration = len(vr)
    vid_fps = vr.get_avg_fps()
    fps_list = []

    if fps is not None and duration / vid_fps < max_frames_num:
        segment_len = min(vid_fps // fps, duration)
        frame_ids = np.arange(segment_len // 2, duration, segment_len, dtype=int)
        if force_include_last_frame:
            last_frame_id = duration - 1
            if last_frame_id not in frame_ids:
                frame_ids = frame_ids.tolist()
                frame_ids.append(last_frame_id)
    else:
        if duration <= max_frames_num:
            frame_ids = np.arange(duration).astype(int).tolist()
        else:
            frame_ids = np.linspace(0, duration - 1, max_frames_num, dtype=int)
            if force_include_last_frame:
                last_frame_id = duration - 1
                if last_frame_id not in frame_ids:
                    uniform_sampled_frames = np.linspace(0, duration - 1, max_frames_num - 1, dtype=int)
                    frame_ids = uniform_sampled_frames.tolist()
                    frame_ids.append(last_frame_id)

    for frame_id in frame_ids:
        fps_list.append(frame_id / vid_fps)

    frames = vr.get_batch(frame_ids).asnumpy()
    # print(fps_list)
    return frames, fps_list
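To make the sampling rule in read_video_custom concrete, the numpy-only sketch below reproduces the index and timestamp math for a hypothetical 600-frame clip at 30 fps (the short-video branch: roughly one frame per requested second, taken at segment midpoints, with the last frame force-included). The clip parameters are assumptions for illustration and no decord decoding is performed.

import numpy as np

# Hypothetical clip: 600 frames at 30 fps, i.e. a 20-second video.
vid_fps, duration = 30.0, 600
fps, max_frames_num = 1, 180

if duration / vid_fps < max_frames_num:
    # Short video: one segment per requested second, sample its midpoint frame.
    segment_len = min(vid_fps // fps, duration)
    frame_ids = np.arange(segment_len // 2, duration, segment_len, dtype=int).tolist()
    if (duration - 1) not in frame_ids:
        frame_ids.append(duration - 1)  # force-include the last frame
else:
    # Long video: spread max_frames_num indices uniformly across the clip.
    frame_ids = np.linspace(0, duration - 1, max_frames_num, dtype=int).tolist()

timestamps = [fid / vid_fps for fid in frame_ids]  # seconds, as returned in fps_list
print(len(frame_ids), timestamps[:3], timestamps[-1])
# -> 21 [0.5, 1.5, 2.5] 19.96...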

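Stepping back to generate_until: the method does not hand raw tensors to the processor but a chat-style "conversation" list. For reference, here is a standalone sketch of the payloads it assembles for a video and for a single image; the schema mirrors the code above, while the file path, prompt text, and blank image are placeholders for illustration.

from PIL import Image

# Placeholder inputs, for illustration only.
prompt = "Describe the clip."
video_path = "example.mp4"            # hypothetical path
image = Image.new("RGB", (336, 336))  # stand-in for a dataset image

# Video request: the default loader receives the path plus fps / frame caps.
video_message = [
    {
        "role": "user",
        "content": [
            {"type": "video", "video": {"video_path": video_path, "fps": 1, "max_frames": 180}},
            {"type": "text", "text": prompt},
        ],
    }
]

# Image request: the PIL image object is passed through directly.
image_message = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": prompt},
        ],
    }
]

# Either list is then passed as:
#   inputs = processor(conversation=message, return_tensors="pt", add_generation_prompt=True)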