@@ -176,7 +176,7 @@ def __init__(
176176 )
177177
178178 self ._config = self ._model .config
179- self .model .to ("cuda" ).eval ().bfloat16 ()
179+ self .model .to (self . device ).eval ().bfloat16 ()
180180 self .model .tie_weights ()
181181 self .truncation = truncation
182182 self .batch_size_per_gpu = int (batch_size )
@@ -207,10 +207,9 @@ def __init__(
207207 elif accelerator .num_processes == 1 and device_map == "auto" :
208208 eval_logger .info (f"Using { accelerator .num_processes } devices with tensor parallelism" )
209209 self ._rank = 0
210- self ._word_size = 1
210+ self ._world_size = 1
211211 else :
212212 eval_logger .info (f"Using single device: { self ._device } " )
213- self .model .to (self ._device )
214213 self ._rank = 0
215214 self ._world_size = 1
216215 self .accelerator = accelerator
@@ -405,10 +404,11 @@ def process_audio(self, audio_array, sampling_rate):
405404 audio = audio [:, 0 ]
406405 target_sr = 16000
407406 CHUNK_LIM = 480000
408- import librosa
409407
410408 if sampling_rate != target_sr :
411409 speech_wav = librosa .resample (audio_array , orig_sr = sampling_rate , target_sr = target_sr ).astype (np .float32 )
410+ else :
411+ speech_wav = audio_array .astype (np .float32 )
412412 speechs = []
413413 speech_wavs = []
414414
@@ -485,13 +485,13 @@ def _collate(x):
485485 eval_logger .info (f"Video { visuals } can not load, check the source" )
486486 continue
487487 audio = self .extract_audio (visual )
488- audio .write_audiofile ("./video_audio .wav" )
489- video_audio_path = "./video_audio .wav"
488+ audio .write_audiofile (f "./video_audio_ { self . rank } .wav" )
489+ video_audio_path = f "./video_audio_ { self . rank } .wav"
490490 speech , speech_length , speech_chunk , speech_wav = self .load_audio (video_audio_path )
491- speechs .append (speech .bfloat16 ().to ("cuda" ))
492- speech_lengths .append (speech_length .to ("cuda" ))
493- speech_chunks .append (speech_chunk .to ("cuda" ))
494- speech_wavs .append (speech_wav .to ("cuda" ))
491+ speechs .append (speech .bfloat16 ().to (self . device ))
492+ speech_lengths .append (speech_length .to (self . device ))
493+ speech_chunks .append (speech_chunk .to (self . device ))
494+ speech_wavs .append (speech_wav .to (self . device ))
495495 os .remove (video_audio_path )
496496
497497 # Process images of video
@@ -508,7 +508,7 @@ def _collate(x):
508508 if frame_idx is None :
509509 frame_idx = np .arange (0 , len (video_processed ), dtype = int ).tolist ()
510510
511- video_processed = torch .cat (video_processed , dim = 0 ).bfloat16 ().to ("cuda" )
511+ video_processed = torch .cat (video_processed , dim = 0 ).bfloat16 ().to (self . device )
512512 video_processed = (video_processed , video_processed )
513513
514514 video_data = (video_processed , (384 , 384 ), "video" )
@@ -522,44 +522,46 @@ def _collate(x):
522522 image_tensor_ , image_highres_tensor_ = process_anyres_highres_image (visual , self ._image_processor )
523523 image_tensor .append (image_tensor_ )
524524 image_highres_tensor .append (image_highres_tensor_ )
525- if all (x .shape == image_tensor [0 ].shape for x in image_tensor ):
526- image_tensor = torch .stack (image_tensor , dim = 0 )
527- if all (x .shape == image_highres_tensor [0 ].shape for x in image_highres_tensor ):
528- image_highres_tensor = torch .stack (image_highres_tensor , dim = 0 )
529- if type (image_tensor ) is list :
530- image_tensor = [_image .bfloat16 ().to ("cuda" ) for _image in image_tensor ]
531- else :
532- image_tensor = image_tensor .bfloat16 ().to ("cuda" )
533- if type (image_highres_tensor ) is list :
534- image_highres_tensor = [_image .bfloat16 ().to ("cuda" ) for _image in image_highres_tensor ]
535- else :
536- image_highres_tensor = image_highres_tensor .bfloat16 ().to ("cuda" )
537-
538- # Processing dummy audio, as required by model
539- speechs .append (torch .zeros (1 , 3000 , 128 ).bfloat16 ().to ("cuda" ))
540- speech_lengths .append (torch .LongTensor ([3000 ]).to ("cuda" ))
541- speech_wavs .append (torch .zeros ([1 , 480000 ]).to ("cuda" ))
542- speech_chunks .append (torch .LongTensor ([1 ]).to ("cuda" ))
543525
544526 elif isinstance (visual , dict ) and "array" in visual : # For Audio
545527 if MODALITY is None :
546528 MODALITY = "AUDIO"
547529 mels , speech_length , speech_chunk , speech_wav = self .process_audio (visual ["array" ], visual ["sampling_rate" ])
548- speechs .append (mels .bfloat16 ().to ("cuda" ))
549- speech_lengths .append (speech_length .to ("cuda" ))
550- speech_chunks .append (speech_chunk .to ("cuda" ))
551- speech_wavs .append (speech_wav .to ("cuda" ))
530+ speechs .append (mels .bfloat16 ().to (self . device ))
531+ speech_lengths .append (speech_length .to (self . device ))
532+ speech_chunks .append (speech_chunk .to (self . device ))
533+ speech_wavs .append (speech_wav .to (self . device ))
552534
553535 # Processing dummy images, as required by model
554- images .append (torch .zeros (1 , 3 , 224 , 224 ).to (dtype = torch .bfloat16 , device = "cuda" , non_blocking = True ))
555- images_highres .append (torch .zeros (1 , 3 , 224 , 224 ).to (dtype = torch .bfloat16 , device = "cuda" , non_blocking = True ))
536+ images .append (torch .zeros (1 , 3 , 224 , 224 ).to (dtype = torch .bfloat16 , device = self . device , non_blocking = True ))
537+ images_highres .append (torch .zeros (1 , 3 , 224 , 224 ).to (dtype = torch .bfloat16 , device = self . device , non_blocking = True ))
556538 image_sizes .append ((224 , 224 ))
557539
558540 if not video_processed and MODALITY == "VIDEO" :
559541 # If video is not processed, skip the iteration
560542 pbar .update (1 )
561543 continue
562544
545+ if MODALITY == "IMAGE" :
546+ if all (x .shape == image_tensor [0 ].shape for x in image_tensor ):
547+ image_tensor = torch .stack (image_tensor , dim = 0 )
548+ if all (x .shape == image_highres_tensor [0 ].shape for x in image_highres_tensor ):
549+ image_highres_tensor = torch .stack (image_highres_tensor , dim = 0 )
550+ if type (image_tensor ) is list :
551+ image_tensor = [_image .bfloat16 ().to ("cuda" ) for _image in image_tensor ]
552+ else :
553+ image_tensor = image_tensor .bfloat16 ().to ("cuda" )
554+ if type (image_highres_tensor ) is list :
555+ image_highres_tensor = [_image .bfloat16 ().to ("cuda" ) for _image in image_highres_tensor ]
556+ else :
557+ image_highres_tensor = image_highres_tensor .bfloat16 ().to ("cuda" )
558+
559+ # Processing dummy audio, as required by model
560+ speechs .append (torch .zeros (1 , 3000 , 128 ).bfloat16 ().to ("cuda" ))
561+ speech_lengths .append (torch .LongTensor ([3000 ]).to ("cuda" ))
562+ speech_wavs .append (torch .zeros ([1 , 480000 ]).to ("cuda" ))
563+ speech_chunks .append (torch .LongTensor ([1 ]).to ("cuda" ))
564+
563565 # we assume all gen kwargs in the batch are the same
564566 # this is safe to assume because the `grouper` object ensures it.
565567 gen_kwargs = all_gen_kwargs [0 ]
@@ -601,11 +603,11 @@ def _collate(x):
601603 eval_logger .debug (f"Prompt for doc ID { doc_id [0 ]} :\n \n { prompt } \n " )
602604
603605 if MODALITY == "AUDIO" :
604- input_ids = tokenizer_speech_token (prompt , self .tokenizer , SPEECH_TOKEN_INDEX , return_tensors = "pt" ).unsqueeze (0 ).to (self ._device )
606+ input_ids = tokenizer_speech_token (prompt , self .tokenizer , SPEECH_TOKEN_INDEX , return_tensors = "pt" ).unsqueeze (0 ).to (self .device )
605607 elif MODALITY == "IMAGE" :
606- input_ids = tokenizer_image_token (prompt , self .tokenizer , IMAGE_TOKEN_INDEX , return_tensors = "pt" ).unsqueeze (0 ).to (self ._device )
608+ input_ids = tokenizer_image_token (prompt , self .tokenizer , IMAGE_TOKEN_INDEX , return_tensors = "pt" ).unsqueeze (0 ).to (self .device )
607609 elif MODALITY == "VIDEO" :
608- input_ids = tokenizer_speech_image_token (prompt , self .tokenizer , IMAGE_TOKEN_INDEX , return_tensors = "pt" ).unsqueeze (0 ).to ("cuda" )
610+ input_ids = tokenizer_speech_image_token (prompt , self .tokenizer , IMAGE_TOKEN_INDEX , return_tensors = "pt" ).unsqueeze (0 ).to (self . device )
609611 pad_token_ids = 151643
610612 attention_masks = input_ids .ne (pad_token_ids ).long ().to (self .device )
611613
0 commit comments