import argparse
import torch
from llavavid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llavavid.conversation import conv_templates, SeparatorStyle
from llavavid.model.builder import load_pretrained_model
from llavavid.utils import disable_torch_init
from llavavid.mm_utils import process_anyres_image, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
import json
import os
import math
import time  # used for rate-limit backoff in the gpt4v path
from tqdm import tqdm
from decord import VideoReader, cpu
from transformers import AutoConfig
import cv2
import base64
import openai
from PIL import Image
import numpy as np
def split_list(lst, n):
"""Split a list into n (roughly) equal-sized chunks"""
chunk_size = [Link](len(lst) / n) # integer division
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
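# Example (illustrative): split_list(list(range(10)), 3)
# -> [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]; the last chunk may be shorter.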
def get_chunk(lst, n, k):
    """Return the k-th of n chunks, e.g. to shard videos across parallel workers."""
    chunks = split_list(lst, n)
    return chunks[k]
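# Example (illustrative): get_chunk(["a.mp4", "b.mp4", "c.mp4"], n=2, k=0)
# -> ["a.mp4", "b.mp4"], i.e. the shard this worker should process.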
def parse_args():
"""
Parse command-line arguments.
"""
    parser = argparse.ArgumentParser()
# Define the command-line arguments
parser.add_argument("--video_path", help="Path to the video files.",
required=True)
parser.add_argument("--output_dir", help="Directory to save the model results
JSON.", required=True)
parser.add_argument("--output_name", help="Name of the file for storing results
JSON.", required=True)
parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--conv-mode", type=str, default=None)
parser.add_argument("--chunk-idx", type=int, default=0)
parser.add_argument("--mm_resampler_type", type=str, default="spatial_pool")
parser.add_argument("--mm_spatial_pool_stride", type=int, default=4)
parser.add_argument("--mm_spatial_pool_out_channels", type=int, default=1024)
parser.add_argument("--mm_spatial_pool_mode", type=str, default="average")
parser.add_argument("--image_aspect_ratio", type=str, default="anyres")
parser.add_argument("--image_grid_pinpoints", type=str, default="[(224, 448),
(224, 672), (224, 896), (448, 448), (448, 224), (672, 224), (896, 224)]")
parser.add_argument("--mm_patch_merge_type", type=str, default="spatial_unpad")
parser.add_argument("--overwrite", type=lambda x: (str(x).lower() == 'true'),
default=True)
parser.add_argument("--for_get_frames_num", type=int, default=4)
parser.add_argument("--load_8bit", type=lambda x: (str(x).lower() == 'true'),
default=False)
parser.add_argument("--prompt", type=str, default=None)
parser.add_argument("--api_key", type=str, help="OpenAI API key")
parser.add_argument("--mm_newline_position", type=str, default="no_token")
parser.add_argument("--force_sample", type=lambda x: (str(x).lower() ==
'true'), default=False)
return parser.parse_args()
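# Example invocation (paths, model name, and script name are illustrative):
#   python video_demo.py --video_path ./videos --output_dir ./results \
#       --output_name preds --model-path lmms-lab/LLaVA-NeXT-Video-7B \
#       --conv-mode vicuna_v1 --prompt "Describe this video."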
def load_video(video_path, args):
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
fps = round(vr.get_avg_fps())
    frame_idx = [i for i in range(0, len(vr), fps)]  # roughly one frame per second
    # sample_fps = args.for_get_frames_num if total_frame_num > args.for_get_frames_num else total_frame_num
if len(frame_idx) > args.for_get_frames_num or args.force_sample:
sample_fps = args.for_get_frames_num
        uniform_sampled_frames = np.linspace(0, total_frame_num - 1, sample_fps, dtype=int)
frame_idx = uniform_sampled_frames.tolist()
spare_frames = vr.get_batch(frame_idx).asnumpy()
    # Save frames as images (debugging aid)
    # for i, frame in enumerate(spare_frames):
    #     cv2.imwrite(f"{args.output_dir}/frame_{i}.jpg", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
return spare_frames
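# Note: decord returns RGB frames, so with the default for_get_frames_num=4,
# load_video yields a (4, H, W, 3) uint8 array for clips longer than ~4 seconds.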
def load_video_base64(path):
    video = cv2.VideoCapture(path)
base64Frames = []
    while video.isOpened():
        success, frame = video.read()
if not success:
break
        _, buffer = cv2.imencode(".jpg", frame)
        base64Frames.append(base64.b64encode(buffer).decode("utf-8"))
    video.release()
# print(len(base64Frames), "frames read.")
return base64Frames
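# Each entry of base64Frames is a base64-encoded JPEG string, the inline image
# format expected by the OpenAI vision chat API used in run_inference below.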
def run_inference(args):
"""
    Run inference on the given video(s) with a LLaVA-style video model (or the
    gpt4v API path) and write predictions to a JSONL results file.
Args:
args: Command-line arguments.
"""
# Initialize the model
if "gpt4v" != args.model_path:
        model_name = get_model_name_from_path(args.model_path)
        cfg_pretrained = AutoConfig.from_pretrained(args.model_path)  # needed later for the "mistral" checks even when --overwrite is false
# Set model configuration parameters if they exist
        if args.overwrite:
            overwrite_config = {}
            overwrite_config["mm_spatial_pool_mode"] = args.mm_spatial_pool_mode
            overwrite_config["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
            overwrite_config["mm_newline_position"] = args.mm_newline_position
if "qwen" not in args.model_path.lower():
if "224" in cfg_pretrained.mm_vision_tower:
# suppose the length of text tokens is around 1000, from bo's
report
least_token_number =
args.for_get_frames_num*(16//args.mm_spatial_pool_stride)**2 + 1000
else:
least_token_number =
args.for_get_frames_num*(24//args.mm_spatial_pool_stride)**2 + 1000
                scaling_factor = math.ceil(least_token_number / 4096)
if scaling_factor >= 2:
if "vicuna" in cfg_pretrained._name_or_path.lower():
print(float(scaling_factor))
overwrite_config["rope_scaling"] = {"factor":
float(scaling_factor), "type": "linear"}
overwrite_config["max_sequence_length"] = 4096 * scaling_factor
overwrite_config["tokenizer_model_max_length"] = 4096 *
scaling_factor
            tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, load_8bit=args.load_8bit, overwrite_config=overwrite_config)
else:
            tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name)
else:
pass
# Create the output directory if it doesn't exist
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
output_name = args.output_name
    answers_file = os.path.join(args.output_dir, f"{output_name}.json")
ans_file = open(answers_file, "w")
video_path = args.video_path
    all_video_paths = []
# Check if the video_path is a directory or a file
    if os.path.isdir(video_path):
# If it's a directory, loop over all files in the directory
        for filename in os.listdir(video_path):
# Load the video file
            cur_video_path = os.path.join(video_path, f"{filename}")
            all_video_paths.append(cur_video_path)
else:
# If it's a file, just process the video
        all_video_paths.append(video_path)

    for video_path in all_video_paths:
sample_set = {}
        question = args.prompt
sample_set["Q"] = question
sample_set["video_name"] = video_path
# Check if the video exists
        if os.path.exists(video_path):
if "gpt4v" != args.model_path:
video = load_video(video_path, args)
                video = image_processor.preprocess(video, return_tensors="pt")["pixel_values"].half().cuda()
video = [video]
else:
video = load_video_base64(video_path)
interval = int(len(video) / args.for_get_frames_num)
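                # interval strides the frame list so that roughly
                # for_get_frames_num frames are sent to the API via video[0::interval].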
# Run inference on the video and add the output to the list
if "gpt4v" != args.model_path:
qs = question
            if model.config.mm_use_im_start_end:
                qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + qs
else:
qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
conv = conv_templates[args.conv_mode].copy()
            conv.append_message(conv.roles[0], qs)
            conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
            input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).cuda()
if tokenizer.pad_token_id is None:
if "qwen" in tokenizer.name_or_path.lower():
print("Setting pad token to bos token for qwen model.")
                    tokenizer.pad_token_id = 151643  # "<|endoftext|>" in the Qwen tokenizer
attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
cur_prompt = question
else:
prompt = question
system_error = ""
if "gpt4v" != args.model_path:
with torch.inference_mode():
                # Greedy decoding (do_sample=False); temperature and top_p are inert here.
                if "mistral" not in cfg_pretrained._name_or_path.lower():
                    output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=False, temperature=0.0, max_new_tokens=1024, top_p=0.1, num_beams=1, use_cache=True, stopping_criteria=[stopping_criteria])
                else:
                    # Mistral-based checkpoints run without the keyword stopping criteria.
                    output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=False, temperature=0.0, max_new_tokens=1024, top_p=0.1, num_beams=1, use_cache=True)
else:
openai.api_key = args.api_key # Your API key here
max_num_retries = 0
retry = 5
            PROMPT_MESSAGES = [
                {
                    "role": "user",
                    "content": [
                        f"These are frames from a video that I want to upload. Answer one question about this video: {prompt}",
                        *map(lambda x: {"image": x, "resize": 336}, video[0::interval]),
                    ],
                },
            ]
params = {
"model": "gpt-4-vision-preview", #gpt-4-1106-vision-preview
"messages": PROMPT_MESSAGES,
"max_tokens": 1024,
}
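            # Note: openai.api_key plus openai.ChatCompletion below follow the
            # legacy (pre-1.0) openai SDK; with openai>=1.0 this would be
            # client.chat.completions.create(...) instead.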
            success_flag = False
            while max_num_retries < retry:
                try:
                    result = openai.ChatCompletion.create(**params)
                    outputs = result.choices[0].message.content
                    success_flag = True
break
                except Exception as inst:
                    max_num_retries += 1  # count failed attempts so the retry loop can terminate
                    if "error" in dir(inst):
                        if inst.error.code == "rate_limit_exceeded":
                            if "TPM" in inst.error.message:
                                time.sleep(30)
                                continue
                            else:
                                import pdb; pdb.set_trace()
                        elif inst.error.code == "insufficient_quota":
                            print("insufficient_quota key")
                            exit()
                        elif inst.error.code == "content_policy_violation":
                            print("content_policy_violation")
                            system_error = "content_policy_violation"
                            break
                        print("Found error message in response: ", str(inst.error.message), "error code: ", str(inst.error.code))
                    continue
            if not success_flag:
                print(f"Calling OpenAI failed after retrying {max_num_retries} times. Check the logs for details.")
                exit()
if "gpt4v" != args.model_path:
            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
else:
print(len(video[0::interval]))
print(f"Question: {prompt}\n")
print(f"Response: {outputs}\n")
if "gpt4v" == args.model_path:
if system_error == 'content_policy_violation':
continue
elif system_error == "":
continue
else:
import pdb;pdb.set_trace()
if "mistral" not in cfg_pretrained._name_or_path.lower():
            if outputs.endswith(stop_str):
                outputs = outputs[: -len(stop_str)]
        outputs = outputs.strip()
sample_set["pred"] = outputs
        ans_file.write(json.dumps(sample_set, ensure_ascii=False) + "\n")
ans_file.flush()
ans_file.close()
if __name__ == "__main__":
args = parse_args()
run_inference(args)