29 changes: 26 additions & 3 deletions vlm_eval/models/instructblip.py
@@ -15,7 +15,12 @@
 from vlm_eval.util.interfaces import VLM, ImageProcessor, Tokenizer
 
 # Define InstructBLIP Mapping from Model ID --> HF Hub Path
-INSTRUCTBLIP_MODELS = {"instructblip-vicuna-7b": "Salesforce/instructblip-vicuna-7b"}
+INSTRUCTBLIP_MODELS = {
+    # fmt: off
+    "instructblip-vicuna-7b": "Salesforce/instructblip-vicuna-7b",
+    "instructblip-vicuna-13b": "Salesforce/instructblip-vicuna-13b",
+    # fmt: on
+}
 
 
 class InstructBLIP(VLM):
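For readers unfamiliar with how this table is consumed, here is a minimal sketch of loading the newly added 13B checkpoint through `transformers`, under the assumption that the repo's loader resolves IDs via `INSTRUCTBLIP_MODELS` (the actual loading kwargs in this repo may differ):

```python
from transformers import InstructBlipForConditionalGeneration, InstructBlipProcessor

from vlm_eval.models.instructblip import INSTRUCTBLIP_MODELS

# Resolve the model ID added in this PR to its HF Hub path, then pull weights
# and the paired processor; the repo's own loader may add dtype/device kwargs.
hub_path = INSTRUCTBLIP_MODELS["instructblip-vicuna-13b"]
model = InstructBlipForConditionalGeneration.from_pretrained(hub_path)
processor = InstructBlipProcessor.from_pretrained(hub_path)
```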
@@ -57,9 +62,10 @@ def __init__(
# "temperature": 1,
# }

# For computing likelihoods --> get tokens corresponding to "true", "false" and "yes", "no"
# For computing likelihoods --> get tokens corresponding to "true", "false" and "yes", "no" for T/F questions
# and also the lower-case alphabet for MC questions
self.string2idx = {}
for trigger_string in ["true", "false", "yes", "no"]:
for trigger_string in ["true", "false", "yes", "no"] + [f"{chr(ord('a') + idx)}" for idx in range(26)]:
token_idx_list = self.text_img_processor.tokenizer.encode(trigger_string, add_special_tokens=False)
assert len(token_idx_list) == 1, f'String "{trigger_string}" is tokenized as more than one token!'
self.string2idx[trigger_string] = token_idx_list[0]
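The single-token assertion is what makes likelihood scoring cheap: each answer letter occupies exactly one vocabulary slot, so one forward pass yields a probability for every choice. A hypothetical scoring helper (not part of this diff; the function name and `first_token_logits` are illustrative):

```python
import torch

def score_mc_letters(first_token_logits: torch.Tensor, string2idx: dict, num_choices: int) -> int:
    """Rank answer letters "a", "b", ... by next-token probability; return the argmax index."""
    letters = [chr(ord("a") + i) for i in range(num_choices)]
    letter_ids = torch.tensor([string2idx[letter] for letter in letters])
    # Because every trigger string maps to a single token, one softmax over the
    # vocabulary at the first generated position scores all choices at once.
    probs = torch.softmax(first_token_logits, dim=-1)[letter_ids]
    return int(probs.argmax())
```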
@@ -93,6 +99,8 @@ def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]:
         bbox_refer_prompt_fn = self.get_bbox_refer_chat_prompt_fn()
         text_vqa_prompt_fn = self.get_vqa_chat_prompt_fn(uncertainty_aware=False, ocr_handling=True)
         captioning_prompt_fn = self.get_captioning_prompt_fn()
+        tally_qa_prompt_fn = self.get_mc_prompt_fn()
+        ai2d_prompt_fn = self.get_mc_prompt_fn()
 
         return {
             "vqa-v2": vqa_prompt_fn,
@@ -103,6 +111,7 @@ def get_prompt_fn(self, dataset_family: str = "vqa-v2") -> Callable[[str], str]:
"pope": vqa_prompt_fn,
"refcoco": bbox_refer_prompt_fn,
"ocid-ref": bbox_refer_prompt_fn,
"ai2d": ai2d_prompt_fn,
# Generic for GUI
"captioning": captioning_prompt_fn,
"bbox_pred": bbox_refer_prompt_fn,
@@ -154,6 +163,20 @@ def contrast_caption_prompt_fn(caption: str) -> str:
 
         return contrast_caption_prompt_fn
 
+    @staticmethod
+    def get_mc_prompt_fn() -> Callable[[str, List[str]], str]:
+        """Generates the full reference prompt for a multiple-choice question-answering task."""
+
+        def mc_prompt_fn(question: str, choices: List[str]) -> str:
+            # Create Choice String
+            assert len(choices) <= 26, "Too many answer choices vs. possible letters in the alphabet!"
+            choice_str = "\n".join([f"{chr(ord('A') + idx)}. {choice}" for idx, choice in enumerate(choices)])
+            q_prompt = "{}\n{}".format(question, choice_str)
+            q_prompt += "\nAnswer with the option's letter from the given choices directly."
+            return q_prompt
+
+        return mc_prompt_fn
+
     @staticmethod
     def get_bbox_refer_chat_prompt_fn() -> Callable[[str], str]:
         """Generates the full reference prompt for a referring expression localization task."""
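For reference, this is what the new prompt builder produces; since `get_mc_prompt_fn` is a staticmethod, it can be exercised without instantiating the model:

```python
from vlm_eval.models.instructblip import InstructBLIP

mc_prompt_fn = InstructBLIP.get_mc_prompt_fn()
print(mc_prompt_fn("What shape is the sign?", ["circle", "triangle", "square"]))
# What shape is the sign?
# A. circle
# B. triangle
# C. square
# Answer with the option's letter from the given choices directly.
```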