|
import os |
|
import re |
|
import math |
|
import json |
|
import copy |
|
import argparse |
|
import warnings |
|
import traceback |
|
|
|
import cv2 |
|
import torch |
|
import pysubs2 |
|
import numpy as np |
|
import pyarrow.parquet as pq |
|
from tqdm import tqdm |
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
import sys |
|
sys.path.append('./') |
|
from videollama2 import model_init, mm_infer |
|
from videollama2.utils import disable_torch_init |
|
|
|
|
|
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') |
|
|
|
|
|
def split_list(lst, n):
    """Split *lst* into n (roughly) equal-sized chunks.

    Each chunk holds ceil(len(lst) / n) items, so the last chunk may be
    shorter.  Note: an empty *lst* raises ValueError (step of 0), same as
    the historical behavior.
    """
    size = math.ceil(len(lst) / n)
    return [lst[start:start + size] for start in range(0, len(lst), size)]
|
|
|
|
|
def get_chunk(lst, n, k):
    """Return the k-th (0-based) of n roughly equal chunks of *lst*."""
    return split_list(lst, n)[k]
|
|
|
|
|
def get_seq_frames(total_num_frames, desired_num_frames):
    """
    Calculate the indices of frames to extract from a video.

    The span [0, total_num_frames - 1] is divided into
    `desired_num_frames` equal segments; the (rounded) midpoint of each
    segment is selected.

    Parameters:
    total_num_frames (int): Total number of frames in the video.
    desired_num_frames (int): Desired number of frames to extract.

    Returns:
    list: List of indices of frames to extract.
    """
    segment = float(total_num_frames - 1) / desired_num_frames

    # Midpoint of each segment, with both endpoints rounded first
    # (np.round uses round-half-to-even).
    return [
        (int(np.round(segment * k)) + int(np.round(segment * (k + 1)))) // 2
        for k in range(desired_num_frames)
    ]
|
|
|
|
|
class VideoMMEDataset(Dataset):
    """Dataset over Video-MME records.

    Each item carries the decoded video tensor, the subtitle lines that
    overlap the sampled frames, and the original record.  A missing or
    unreadable video yields ``video=None`` so the consumer can mark the
    record instead of crashing.
    """

    # Container extensions probed, in order, for each YouTube id.
    video_formats = ['.mp4', '.avi', '.mov', '.mkv']

    def __init__(self, video_folder, subtitle_folder, data_list, processor):
        """
        Parameters:
        video_folder (str): directory holding downloaded videos.
        subtitle_folder (str): directory holding `<youtube_id>.srt` files.
        data_list (list): per-video records (see `load_parquet`).
        processor (callable): maps a video path to a frame tensor
            (assumed shape (num_frames, ...) — first dim is frame count).
        """
        self.video_folder = video_folder
        self.subtitle_folder = subtitle_folder
        self.data_list = data_list
        self.processor = processor

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        line = self.data_list[idx]

        video_ytid = line['url'].split('watch?v=')[-1]

        # fix: initialize video_path -- the original left it unbound when no
        # file existed and relied on the resulting NameError being swallowed
        # by a bare except.
        video_path = None
        for fmt in self.video_formats:
            temp_path = os.path.join(self.video_folder, f'{video_ytid}{fmt}')
            if os.path.exists(temp_path):
                video_path = temp_path
                break

        subtitle_path = os.path.join(self.subtitle_folder, f'{video_ytid}.srt')

        video_tensor = None
        num_frames = 0
        if video_path is not None:
            try:
                video_tensor = self.processor(video_path)
                num_frames = video_tensor.shape[0]
            except Exception:  # fix: narrow from bare except
                traceback.print_exc()
                print(f'It occurs error when reading {video_ytid}')
                video_tensor = None
                num_frames = 0
        else:
            # Same best-effort marker as the unreadable-video case.
            print(f'It occurs error when reading {video_ytid}')

        if video_tensor is not None and os.path.exists(subtitle_path):
            # Map each sampled frame index to a timestamp and collect the
            # subtitle cue (if any) active at that moment.
            cv2_vr = cv2.VideoCapture(video_path)
            duration = int(cv2_vr.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = cv2_vr.get(cv2.CAP_PROP_FPS)
            cv2_vr.release()  # fix: release the capture handle
            selected_frame_ids = get_seq_frames(duration, num_frames)

            subs = pysubs2.load(subtitle_path, encoding="utf-8")
            subtitles = []
            for selected_frame_id in selected_frame_ids:
                sub_text = ""
                cur_time = pysubs2.make_time(fps=fps, frames=selected_frame_id)
                for sub in subs:
                    if sub.start < cur_time and sub.end > cur_time:
                        sub_text = sub.text.replace("\\N", " ")
                        break
                if sub_text.strip():
                    subtitles.append(sub_text)
            subtitles = "\n".join(subtitles)
        else:
            subtitles = ""

        return {
            'video': video_tensor,
            'subtitle': subtitles,
            'record': line,
        }
|
|
|
|
|
def collate_fn(batch):
    """Collate dataset items by unzipping each field into its own list.

    Returns a (videos, subtitles, records) tuple of parallel lists.
    """
    fields = ('video', 'subtitle', 'record')
    return tuple([item[field] for item in batch] for field in fields)
|
|
|
|
|
def load_parquet(parquet_file):
    """Read a Video-MME parquet file and regroup its flat question rows
    into one dict per video.

    Rows are assumed to be ordered by video, with `video_id` a 1-based
    (possibly zero-padded) sequence number: a row whose numeric id
    exceeds the number of videos collected so far starts a new entry,
    otherwise its question joins the most recent video.
    """
    frame = pq.read_table(parquet_file).to_pandas()

    videos = []
    for row in frame.itertuples():
        question = {
            "question_id": row.question_id,
            "task_type": row.task_type,
            "question": row.question,
            "choices": list(row.options),
            "answer": row.answer,
        }

        if len(videos) < int(row.video_id):
            # First question of a new video.
            videos.append({
                "video_id": row.video_id,
                "youtube_id": row.videoID,
                "url": row.url,
                "duration": row.duration,
                "domain": row.domain,
                "sub_category": row.sub_category,
                "questions": [question],
            })
        else:
            videos[-1]['questions'].append(question)

    return videos
|
|
|
|
|
def build_videomme_eval(args, processor):
    """Build a DataLoader over this worker's shard of the Video-MME set.

    The question file is split into `args.num_chunks` shards and only
    shard `args.chunk_idx` is served; ordering is preserved (no shuffle).
    """
    shard = get_chunk(load_parquet(args.question_file), args.num_chunks, args.chunk_idx)
    dataset = VideoMMEDataset(args.video_folder, args.subtitle_folder, shard, processor)
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )
|
|
|
|
|
def videomme_dump(record, instruct, options, output):
    """Parse a model response into a multiple-choice letter.

    First looks for an explicit standalone letter A-D in `output`;
    failing that, matches each option's text (single-digit options are
    also compared against their spelled-out word) case-insensitively.
    Falls back to 'C' when nothing matches.

    Parameters:
    record (dict): video record; only `youtube_id` is used, for the error message.
    instruct (str): the prompt sent to the model (error message only).
    options (list): option texts without their leading letters.
    output (str): raw model response.

    Returns:
    str: one of 'A', 'B', 'C', 'D'.
    """
    letters = ['A', 'B', 'C', 'D']

    digit2word = {
        '1': 'one',
        '2': 'two',
        '3': 'three',
        '4': 'four',
        '5': 'five',
        '6': 'six',
        '7': 'seven',
        '8': 'eight',
        '9': 'nine',
        '0': 'zero',
    }

    # Drop the word "answer" so its capital 'A' cannot be mistaken for option A.
    output = output.replace('answer', '')
    output = output.replace('Answer', '')
    # fix: raw string (the original non-raw pattern relied on invalid
    # escape sequences). A standalone letter, optionally wrapped in
    # parentheses/brackets or followed by '.'.
    pred_answer = re.findall(r'[\(\ \[]*([A-D])[\)\.\ \]]*', output)
    try:
        find_flag = False
        if len(pred_answer) == 0:
            # No explicit letter: look for the option text in the response.
            for idx, opt in enumerate(options):
                opt2 = digit2word.get(opt, opt)
                if opt.lower() in output.lower() or opt2.lower() in output.lower():
                    pred_idx = idx
                    find_flag = True
                    break
        else:
            # The capture group is a bare letter; no stripping needed.
            pred_answer = pred_answer[0].strip()
            pred_idx = letters.index(pred_answer)
            find_flag = True

        # fix: explicit raise instead of assert -- an assert vanishes under
        # `python -O`, which would leave pred_idx unbound and crash below.
        if not find_flag:
            raise ValueError('The video \"{}\" instruct: \n\"{}\"\n output: \n\"{}\"\n is not in the expected format'.format(record['youtube_id'], instruct, output))
    except Exception:  # fix: narrow from bare except
        traceback.print_exc()
        pred_idx = 2  # default to 'C' when the response cannot be parsed

    return letters[pred_idx]
|
|
|
|
|
def run_inference(args):
    """Run Video-MME inference, with and without subtitles, over one shard.

    Writes two files: `args.answer_file` (video only) and the same path
    with `.json` replaced by `_sub.json` (video + subtitles).  Each record
    is written as a JSON object followed by ",\n" (kept as-is to match
    the downstream parser's expectations).  Missing videos are written
    with `missing: True` and skipped.
    """
    disable_torch_init()

    model, processor, tokenizer = model_init(args.model_path)

    answer_file = os.path.expanduser(args.answer_file)
    answer_sub_file = answer_file.replace('.json', '_sub.json')
    os.makedirs(os.path.dirname(answer_file), exist_ok=True)

    val_loader = build_videomme_eval(args, processor['video'])

    # fix: context managers guarantee both output files are flushed and
    # closed even if inference raises part-way through.
    with open(answer_file, "w") as ans_file, open(answer_sub_file, "w") as ans_sub_file:
        for videos, subtitles, records in tqdm(val_loader):
            # batch_size is 1 in practice; unpack the single item.
            video_tensor = videos[0]
            subtitle = subtitles[0]
            record = records[0]

            new_record = copy.deepcopy(record)
            new_record_sub = copy.deepcopy(record)

            if video_tensor is None:
                # Video missing/unreadable: mark both outputs and move on.
                new_record['missing'] = True
                ans_file.write(json.dumps(new_record) + ",\n")
                new_record_sub['missing'] = True
                ans_sub_file.write(json.dumps(new_record_sub) + ",\n")
                continue

            new_record['missing'] = False
            new_record_sub['missing'] = False

            questions = record['questions']
            for idx, question in enumerate(questions):
                q = question['question']
                choices = question['choices']
                # Strip the leading "A. " etc. to get the bare option text.
                options = [re.findall(r'[A-D]\. (.*).', c)[0] for c in choices]

                instruct = "Select the best answer to the following multiple-choice question based on the video. Respond with only the letter (A, B, C, or D) of the correct option.\n"
                instruct += f"{q}\n"
                for cho in choices:
                    instruct += f"{cho}\n"

                instruct += "Answer with the option\'s letter from the given choices directly and only give the best option. The best answer is: "
                output = mm_infer(video_tensor, instruct, model=model, tokenizer=tokenizer, modal='video', do_sample=False)
                new_record['questions'][idx]['response'] = videomme_dump(record, instruct, options, output)

                # Second pass: same prompt with the subtitles prepended.
                instruct = f"This video's subtitles are listed below:\n{subtitle}\n" + instruct
                output = mm_infer(video_tensor, instruct, model=model, tokenizer=tokenizer, modal='video', do_sample=False)
                new_record_sub['questions'][idx]['response'] = videomme_dump(record, instruct, options, output)

            ans_file.write(json.dumps(new_record) + ",\n")
            ans_sub_file.write(json.dumps(new_record_sub) + ",\n")
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point. --num-chunks/--chunk-idx shard the benchmark
    # across parallel workers; all path arguments are required.
    arg_parser = argparse.ArgumentParser()

    arg_parser.add_argument('--model-path', help='', required=True)
    arg_parser.add_argument('--video-folder', help='Directory containing video files.', required=True)
    arg_parser.add_argument('--subtitle-folder', help='Directory containing subtitle files.', required=True)
    arg_parser.add_argument('--question-file', help='Path to the ground truth file containing question.', required=True)
    arg_parser.add_argument('--answer-file', help='Path to the ground truth file containing answers.', required=True)
    arg_parser.add_argument("--num-chunks", type=int, default=1)
    arg_parser.add_argument("--chunk-idx", type=int, default=0)
    arg_parser.add_argument("--device", type=str, required=False, default='cuda:0')
    arg_parser.add_argument("--batch-size", type=int, default=1)
    arg_parser.add_argument("--num-workers", type=int, default=8)

    run_inference(arg_parser.parse_args())
|
|