The repository for our paper: How to Configure Good In-Context Sequence for Visual Question Answering
We use OpenFlamingo and its framework to implement various retrieval strategies on three different VQA datasets.
Create a conda environment for running the following code. This repository currently supports the anonymous submission; it will be finalized in the formal version.
git clone https://github.com/GaryJiajia/OFv2_ICL_VQA.git
cd OFv2_ICL_VQA
conda env create -f environment.yml
conda activate ofv2
pip install git+https://github.com/openai/CLIP.git
We use the VQAv2, OK-VQA, and VizWiz datasets. You need to download these datasets yourself, including the images and annotations.
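The evaluation scripts below expect local paths such as mscoco2014/train2014/ and vqav2/v2_OpenEnded_mscoco_train2014_questions.json. As a quick sanity check, a minimal sketch along the following lines (the relative root directories are assumptions; point them to wherever you stored the downloads) can confirm the required files are in place.
from pathlib import Path

# Assumed dataset layout, matching the paths used by the evaluation script below;
# adjust the roots to wherever you downloaded the data.
expected_paths = [
    "mscoco2014/train2014",                                # COCO 2014 training images
    "mscoco2014/val2014",                                  # COCO 2014 validation images
    "vqav2/v2_OpenEnded_mscoco_train2014_questions.json",  # VQAv2 train questions
    "vqav2/v2_mscoco_train2014_annotations.json",          # VQAv2 train annotations
    "vqav2/v2_OpenEnded_mscoco_val2014_questions.json",    # VQAv2 val questions
    "vqav2/v2_mscoco_val2014_annotations.json",            # VQAv2 val annotations
]
for p in expected_paths:
    print(("OK      " if Path(p).exists() else "MISSING ") + p)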
To run evaluations on OK-VQA, you will first need to download WordNet for NLTK:
import nltk
nltk.download('wordnet')
OpenFlamingo is a multimodal language model that can be used for a variety of tasks. It is trained on a large multimodal dataset (e.g. Multimodal C4) and can be used to generate text conditioned on interleaved images/text. You can read its blog and code for more information.
OpenFlamingo combines a pretrained vision encoder and a language model using cross-attention layers. In our experiments we use OpenFlamingo-9B, which uses the pretrained ViT-L-14 vision encoder from the OpenCLIP package and MPT-7B as the pretrained language model. Initialize the model with the following code.
from open_flamingo import create_model_and_transforms
model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-7b",
    tokenizer_path="anas-awadalla/mpt-7b",
    cross_attn_every_n_layers=4,
)
# grab model checkpoint from huggingface hub
from huggingface_hub import hf_hub_download
import torch
checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-9B-vitl-mpt7b", "checkpoint.pt")
model.load_state_dict(torch.load(checkpoint_path), strict=False)
We provide a demo in the file /demo_test/demo_test.ipynb. You can use different demonstrations and query samples to experience the results generated by OpenFlamingo. Below is a 2-shot demo test.
from open_flamingo import create_model_and_transforms
import torch
from PIL import Image
from PIL import ImageFilter
import requests
class PATH:
    lm_path = "path for mpt-7b"
    lm_tokenizer_path = "path for mpt-7b"
    checkpoint_path = "path for openflamingo v2 checkpoint.pt"
args = PATH()
device_set = 'cuda:0'
device = torch.device(device_set)
flamingo, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path='ViT-L-14',
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path=args.lm_path,
    tokenizer_path=args.lm_tokenizer_path,
    cross_attn_every_n_layers=4,
    # additional parameters introduced in this repository
    inference=True,
    precision='fp16',
    device=device_set,
    checkpoint_path=args.checkpoint_path,
)
demo_image_one = Image.open(
    requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    ).raw
)
demo_image_two = Image.open("test-006.jpg")
query_image = Image.open("test-006.jpg")
tokenizer.padding_side = "left"
lang_x = tokenizer(
    ["<image>Question: What kind of animals in the image? Answer: Dog. <|endofchunk|><image>Question: What kind of animals in the image? Answer: Dog. <|endofchunk|><image>Question: What kind of animals in the image? Answer:"],
    return_tensors="pt",
)
# Each processed image is (C, H, W); stack the demonstrations and the query, then reshape to
# OpenFlamingo's expected vision_x shape (batch_size, num_media, num_frames, channels, height, width).
vision_x = [image_processor(demo_image_one).unsqueeze(0), image_processor(demo_image_two).unsqueeze(0), image_processor(query_image).unsqueeze(0)]
vision_x = torch.cat(vision_x, dim=0)
vision_x = vision_x.unsqueeze(1).unsqueeze(0)
# load data to gpus
vision_x = vision_x.to(device).half()
print(vision_x.device)
input_ids=lang_x["input_ids"]
attention_mask = lang_x["attention_mask"]
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
generated_text = flamingo.generate(
    vision_x=vision_x,
    lang_x=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=20,
    num_beams=3,
)
print(tokenizer.decode(generated_text[0]))
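The decoded string contains the full prompt followed by the newly generated tokens. A minimal, hypothetical helper for pulling out only the predicted answer to the query (assuming the prompt format above, where the query ends with "Answer:") could look like this.
# Hypothetical post-processing helper; assumes the prompt format used above.
def extract_answer(decoded: str) -> str:
    answer = decoded.split("Answer:")[-1]        # keep only the text after the last "Answer:"
    answer = answer.split("<|endofchunk|>")[0]   # drop anything after the chunk separator
    return answer.strip()

print(extract_answer(tokenizer.decode(generated_text[0])))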
An example evaluation script is at open_flamingo/scripts/run_vqav2.sh, as follows:
DEVICE=0 # gpu number
RANDOM_ID="VQAv2_Result_file_name"
RESULTS_FILE="results_${RANDOM_ID}.json"
export MASTER_ADDR='localhost'
export MASTER_PORT='10000'
python open_flamingo/eval/evaluate_vqa.py \
--retrieval_name $RANDOM_ID \
--lm_path "Path for mpt-7b" \
--lm_tokenizer_path "Path for mpt-7b" \
--checkpoint_path "Path for OpenFlamingo-9B-vitl-mpt7b checkpoint.pt" \
--vision_encoder_path "ViT-L-14" \
--vision_encoder_pretrained 'openai' \
--device $DEVICE \
--vqav2_train_image_dir_path "mscoco2014/train2014/" \
--vqav2_train_questions_json_path "vqav2/v2_OpenEnded_mscoco_train2014_questions.json" \
--vqav2_train_annotations_json_path "vqav2/v2_mscoco_train2014_annotations.json" \
--vqav2_test_image_dir_path "mscoco2014/val2014/" \
--vqav2_test_questions_json_path "vqav2/v2_OpenEnded_mscoco_val2014_questions.json" \
--vqav2_test_annotations_json_path "vqav2/v2_mscoco_val2014_annotations.json" \
--results_file $RESULTS_FILE \
--num_samples 5000 \
--shots 4 8 16 32 \
--num_trials 1 \
--seed 5 \
--batch_size 1 \
--cross_attn_every_n_layers 4 \
--precision fp16 \
--dataset_name vqav2 \
--eval_vqav2
echo "evaluation complete! results written to $RESULTS_FILE"
Change the parameters according to your needs and use the following commands. The evaluation can be run on a single RTX 3090 GPU with FP16 precision.
cd OFv2_ICL_VQA  # change to the repository root
bash open_flamingo/scripts/run_vqav2.sh
Before running the above script, you have to run retrieval/img2img_clip_style.py to generate the "validation_xxx.npy" retrieval results file, which is used in eval/eval_datasets.py. For more details, see the answer in this issue.
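The retrieval script is part of this repository; as a rough illustration of the underlying idea only (not the actual implementation), image-to-image retrieval with CLIP can be sketched as follows. The image directories and the output file name mirror the paths above; the rest is an assumption.
import clip
import numpy as np
import torch
from glob import glob
from PIL import Image

# Rough sketch of CLIP-based image-to-image retrieval (not the repository's exact script).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
clip_model, clip_preprocess = clip.load("ViT-L/14", device=device)

def embed_images(paths):
    # Encode a list of image paths into L2-normalized CLIP features.
    feats = []
    with torch.no_grad():
        for p in paths:
            image = clip_preprocess(Image.open(p)).unsqueeze(0).to(device)
            f = clip_model.encode_image(image)
            feats.append(f / f.norm(dim=-1, keepdim=True))
    return torch.cat(feats, dim=0)

# Assumed image directories, matching the evaluation script above.
train_feats = embed_images(sorted(glob("mscoco2014/train2014/*.jpg")))
val_feats = embed_images(sorted(glob("mscoco2014/val2014/*.jpg")))

# For each validation (query) image, rank the training images by cosine similarity
# so the evaluation code can pick the most similar demonstrations.
similarity = val_feats @ train_feats.T
ranking = similarity.argsort(dim=-1, descending=True).cpu().numpy()
np.save("validation_xxx.npy", ranking)  # output file name as referenced above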
If you need to use different retrieval methods, you can change the parameters of control_signals in open_flamingo/eval/evaluate_vqa.py.
control_signals = {"clip": True,                            # If False, use random sampling (RS) instead of CLIP-based retrieval.
                   "retrieval_name": args.retrieval_name,   # The results file name.
                   "retrieval_type": "SI",                  # Name of the retrieval method: SI/SQ/SI_Q...
                   "mismatch_type": "normal",               # The mismatch type: normal/answer/image/question/question-answer.
                   "specification": False,                  # Add the instruction.
                   "declaration": False,                    # Add the declarative sentence into the demonstrations.
                   "add_declaration": False,                # Add the declarative sentence into the demonstrations.
                   "gauss": True,                           # Blur the query image.
                   "None_ICE": False,                       # In the 0-shot setting, whether to offer demonstrations to the model.
                   "order": "order"}                        # The order of the demonstrations: order/reverse.
This code is based on the second version of OpenFlamingo.