diff --git a/vlm_eval/models/instructblip.py b/vlm_eval/models/instructblip.py index e84f5c4..8fa3c4f 100644 --- a/vlm_eval/models/instructblip.py +++ b/vlm_eval/models/instructblip.py @@ -15,7 +15,12 @@ from vlm_eval.util.interfaces import VLM, ImageProcessor, Tokenizer # Define InstructBLIP Mapping from Model ID --> HF Hub Path -INSTRUCTBLIP_MODELS = {"instructblip-vicuna-7b": "Salesforce/instructblip-vicuna-7b"} +INSTRUCTBLIP_MODELS = { + # fmt: off + "instructblip-vicuna-7b": "Salesforce/instructblip-vicuna-7b", + "instructblip-vicuna-13b": "Salesforce/instructblip-vicuna-13b", + # fmt: on +} class InstructBLIP(VLM): @@ -57,9 +62,10 @@ def __init__( # "temperature": 1, # } - # For computing likelihoods --> get tokens corresponding to "true", "false" and "yes", "no" + # For computing likelihoods --> get tokens corresponding to "true", "false" and "yes", "no" for T/F questions + # and also the lower-case alphabet for MC questions self.string2idx = {} - for trigger_string in ["true", "false", "yes", "no"]: + for trigger_string in ["true", "false", "yes", "no"] + [f"{chr(ord('a') + idx)}" for idx in range(26)]: token_idx_list = self.text_img_processor.tokenizer.encode(trigger_string, add_special_tokens=False) assert len(token_idx_list) == 1, f'String "{trigger_string}" is tokenized as more than one token!' 
self.string2idx[trigger_string] = token_idx_list[0] @@ -93,6 +99,8 @@ def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]: bbox_refer_prompt_fn = self.get_bbox_refer_chat_prompt_fn() text_vqa_prompt_fn = self.get_vqa_chat_prompt_fn(uncertainty_aware=False, ocr_handling=True) captioning_prompt_fn = self.get_captioning_prompt_fn() + tally_qa_prompt_fn = self.get_mc_prompt_fn() + ai2d_prompt_fn = self.get_mc_prompt_fn() return { "vqa-v2": vqa_prompt_fn, @@ -103,6 +111,7 @@ def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]: "pope": vqa_prompt_fn, "refcoco": bbox_refer_prompt_fn, "ocid-ref": bbox_refer_prompt_fn, + "ai2d": ai2d_prompt_fn, # Generic for GUI "captioning": captioning_prompt_fn, "bbox_pred": bbox_refer_prompt_fn, @@ -154,6 +163,20 @@ def contrast_caption_prompt_fn(caption: str) -> str: return contrast_caption_prompt_fn + @staticmethod + def get_mc_prompt_fn() -> Callable[[str, List[str]], str]: + """Generates the full reference prompt for a multiple-choice question-answer task.""" + + def mc_prompt_fn(question: str, choices: List[str]) -> str: + # Create Choice String + assert len(choices) <= 26, "Too many answer choices vs. possible letters in the alphabet!" + choice_str = "\n".join([f"{chr(ord('A') + idx)}. {choice}" for idx, choice in enumerate(choices)]) + q_prompt = "{}\n{}".format(question, choice_str) + q_prompt += "\nAnswer with the option's letter from the given choices directly." + return q_prompt + + return mc_prompt_fn + + @staticmethod + def get_bbox_refer_chat_prompt_fn() -> Callable[[str], str]: + """Generates the full reference prompt for a referring expression localization task."""