diff --git a/vlmeval/config.py b/vlmeval/config.py
index f7e7d35de..3e6f85021 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -1185,6 +1185,8 @@
     "Ovis2-16B": partial(Ovis2, model_path="AIDC-AI/Ovis2-16B"),
     "Ovis2-34B": partial(Ovis2, model_path="AIDC-AI/Ovis2-34B"),
     "Ovis-U1-3B": partial(OvisU1, model_path="AIDC-AI/Ovis-U1-3B"),
+    "Ovis2.5-2B": partial(Ovis2_5, model_path="AIDC-AI/Ovis2.5-2B"),
+    "Ovis2.5-9B": partial(Ovis2_5, model_path="AIDC-AI/Ovis2.5-9B"),
 }
 
 mantis_series = {
diff --git a/vlmeval/vlm/__init__.py b/vlmeval/vlm/__init__.py
index 8f4451259..4d98c6d27 100644
--- a/vlmeval/vlm/__init__.py
+++ b/vlmeval/vlm/__init__.py
@@ -69,7 +69,7 @@
     PLLaVA,
 )
 from .vila import VILA, NVILA
-from .ovis import Ovis, Ovis1_6, Ovis1_6_Plus, Ovis2, OvisU1
+from .ovis import Ovis, Ovis1_6, Ovis1_6_Plus, Ovis2, OvisU1, Ovis2_5
 from .mantis import Mantis
 from .mixsense import LLama3Mixsense
 from .parrot import Parrot
diff --git a/vlmeval/vlm/ovis/__init__.py b/vlmeval/vlm/ovis/__init__.py
index 2bda2ba75..d98fa7be9 100644
--- a/vlmeval/vlm/ovis/__init__.py
+++ b/vlmeval/vlm/ovis/__init__.py
@@ -1,3 +1,3 @@
-from .ovis import Ovis, Ovis1_6, Ovis1_6_Plus, Ovis2, OvisU1
+from .ovis import Ovis, Ovis1_6, Ovis1_6_Plus, Ovis2, OvisU1, Ovis2_5
 
-__all__ = ['Ovis', 'Ovis1_6', 'Ovis1_6_Plus', 'Ovis2', 'OvisU1']
+__all__ = ['Ovis', 'Ovis1_6', 'Ovis1_6_Plus', 'Ovis2', 'OvisU1', 'Ovis2_5']
diff --git a/vlmeval/vlm/ovis/ovis.py b/vlmeval/vlm/ovis/ovis.py
index fc475037f..aa686fe46 100644
--- a/vlmeval/vlm/ovis/ovis.py
+++ b/vlmeval/vlm/ovis/ovis.py
@@ -720,3 +720,222 @@ def prepare_inputs(self, message, dataset=None):
         ], dim=0)
 
         return prompt, input_ids, attention_mask, pixel_values, grid_thws
+
+
+class Ovis2_5(BaseModel):
+    INSTALL_REQ = False
+    INTERLEAVE = True
+
+    def __init__(self, model_path='AIDC-AI/Ovis2.5-9B', **kwargs):
+        assert model_path is not None
+        # Recommended environment: transformers>=4.51.3, torch>=2.4.0
+        self.model_path = model_path
+        self.device = torch.cuda.current_device()
+        self.dtype = torch.bfloat16
+
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_path,
+            torch_dtype=self.dtype,
+            trust_remote_code=True
+        )
+        self.model = self.model.eval().to(device=self.device)
+
+        # Thinking-mode configuration
+        self.enable_thinking = kwargs.pop('enable_thinking', True)
+        self.enable_thinking_budget = kwargs.pop('enable_thinking_budget', True)
+        self.thinking_budget = kwargs.pop('thinking_budget', 2048)
+
+        self.text_tokenizer = self.model.text_tokenizer
+        self.pad_token_id = self.text_tokenizer.pad_token_id
+        if hasattr(self.model, 'generation_config'):
+            self.eos_token_id = self.model.generation_config.eos_token_id
+        else:
+            self.eos_token_id = self.text_tokenizer.eos_token_id
+
+        self.image_placeholder = '<image>'
+
+        # Generation kwargs: ensure max_new_tokens > thinking_budget + 25
+        max_new_tokens = kwargs.pop('max_new_tokens', 3072)
+        if self.enable_thinking and max_new_tokens <= self.thinking_budget + 25:
+            max_new_tokens = self.thinking_budget + 512
+
+        self.gen_kwargs = dict(
+            max_new_tokens=max_new_tokens,
+            do_sample=False,
+            top_p=None,
+            top_k=None,
+            temperature=None,
+            repetition_penalty=None,
+            eos_token_id=self.eos_token_id,
+            pad_token_id=self.pad_token_id,
+            use_cache=True,
+            enable_thinking=self.enable_thinking,
+            enable_thinking_budget=self.enable_thinking_budget,
+            thinking_budget=self.thinking_budget,
+        )
+        self.gen_kwargs.update(kwargs)
+
+        # Datasets whose prompts should request chain-of-thought reasoning
+        self.use_cot = {'MMMU', 'MMVet', 'MMStar', 'MathVista', 'MathVerse', 'MathVision'}
+
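+    # NOTE: the dataset-specific prompt builders below follow the pattern of the
+    # earlier Ovis wrappers in this file; the CoT suffix they append is matched
+    # and stripped again in generate_inner().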
+    def use_custom_prompt(self, dataset):
+        if any(dataset.startswith(prefix) for prefix in ['MMVet', 'MathVista', 'MathVerse', 'MathVision']):
+            return True
+        if DATASET_TYPE(dataset) in ('Y/N', 'MCQ'):
+            return True
+        return False
+
+    def build_yorn_prompt(self, line, dataset=None):
+        prompt = line['question']
+        if listinstr(['HallusionBench'], dataset):
+            prompt += ' Please answer yes or no.'
+        prompt += '\nAnswer the question using a single word or phrase.'
+        return prompt
+
+    def build_multi_choice_prompt(self, line, dataset=None, use_cot=False):
+        prompt = line['question']
+        hint = line['hint'] if ('hint' in line and not pd.isna(line['hint'])) else None
+        if hint is not None:
+            prompt = hint + '\n' + prompt
+
+        options = {
+            cand: line[cand]
+            for cand in string.ascii_uppercase
+            if cand in line and not pd.isna(line[cand])
+        }
+        for key, item in options.items():
+            prompt += f'\n{key}. {item}'
+
+        if len(options):
+            if use_cot:
+                prompt += "\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."
+            else:
+                prompt += "\nAnswer with the option's letter from the given choices directly."
+
+        return prompt
+
+    def build_mmvet_prompt(self, line, dataset=None, use_cot=False):
+        prompt = line['question']
+        if use_cot:
+            prompt += "\nProvide a step-by-step solution to the problem carefully."
+        return prompt
+
+    def build_math_prompt(self, line, dataset=None, use_cot=False):
+        prompt = line['question']
+        if use_cot:
+            prompt += "\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."
+        return prompt
+
+    def build_prompt(self, line, dataset=None):
+        assert self.use_custom_prompt(dataset)
+        assert isinstance(dataset, str)
+        tgt_path = self.dump_image(line, dataset)
+
+        use_cot = any(dataset.startswith(prefix) for prefix in self.use_cot)
+
+        if dataset == 'MMVet':
+            prompt = self.build_mmvet_prompt(line, dataset, use_cot)
+        elif any(dataset.startswith(prefix) for prefix in ('MathVista', 'MathVerse', 'MathVision')):
+            prompt = self.build_math_prompt(line, dataset, use_cot)
+        elif DATASET_TYPE(dataset) == 'Y/N':
+            prompt = self.build_yorn_prompt(line, dataset)
+        elif DATASET_TYPE(dataset) == 'MCQ':
+            prompt = self.build_multi_choice_prompt(line, dataset, use_cot)
+        else:
+            raise RuntimeError(f'Invalid dataset type: {DATASET_TYPE(dataset)}')
+
+        message = [dict(type='image', value=s) for s in tgt_path] + [dict(type='text', value=prompt)]
+
+        # MMMU interleaves images within the question text, so re-split the message
+        if dataset.startswith('MMMU_'):
+            from ... import MMMUDataset
+            message = MMMUDataset.split_MMMU(message)
+
+        return message
+
+    def generate_inner(self, message, dataset=None):
+        def _extract_answer(text):
+            answer_index = text.lower().find('the answer is')
+            if answer_index != -1:
+                answer_index += len('the answer is')
+                answer = text[answer_index:].lstrip(':').strip()
+            else:
+                answer = text
+            return answer
+
+        # DynaMath special handling: always request a CoT-style final answer
+        if dataset == 'DynaMath':
+            message[-1]['value'] += "\nProvide a step-by-step solution to the problem, and conclude with 'the answer is' followed by the final solution."
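+
+        # prepare_inputs() returns the flattened query text alongside the input
+        # tensors so that the CoT answer-extraction check below can inspect the
+        # exact prompt that was sent to the model.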
+        input_ids, pixel_values, grid_thws, prompt = self.prepare_inputs(message, dataset)
+
+        outputs = self.model.generate(
+            inputs=input_ids,
+            pixel_values=pixel_values,
+            grid_thws=grid_thws,
+            **self.gen_kwargs
+        )
+
+        response = self.text_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Drop the thinking trace, if any, and keep only the final answer
+        if self.enable_thinking:
+            think_end = response.rfind('</think>')
+            if think_end != -1:
+                think_end += len('</think>')
+                response = response[think_end:].strip()
+
+        # Extract the final answer for CoT prompts
+        if "conclude with 'the answer is' followed by the final solution." in prompt:
+            response = _extract_answer(response)
+
+        return response
+
+    def prepare_inputs(self, message, dataset=None):
+        # Build the query string: convert the VLMEvalKit message list to Ovis2.5 format
+        images = [x['value'] for x in message if x['type'] == 'image']
+        texts = [x['value'] for x in message if x['type'] == 'text']
+
+        if DATASET_MODALITY(dataset) == 'VIDEO':  # video inputs
+            chunks = [self.image_placeholder for x in message if x['type'] != 'text']
+            chunks += [x['value'].strip() for x in message if x['type'] == 'text' and x['value'] != '']
+            query = '\n'.join(chunks)
+        elif len(images) == 0:  # text-only inputs
+            query = '\n'.join(texts)
+        elif len(images) == 1 and len(texts) == 1:  # single-image inputs
+            query = self.image_placeholder + '\n' + texts[0]
+        else:  # interleaved inputs
+            chunks = [x['value'].strip() if x['type'] == 'text' else self.image_placeholder for x in message]
+            query = '\n'.join(chunks)
+
+        # Convert to the Ovis2.5 chat-message format
+        ovis_messages = [{
+            "role": "user",
+            "content": []
+        }]
+
+        # Images first
+        for image_path in images:
+            ovis_messages[0]["content"].append({
+                "type": "image",
+                "image": Image.open(image_path)
+            })
+
+        # Then the concatenated text content
+        if texts:
+            ovis_messages[0]["content"].append({
+                "type": "text",
+                "text": '\n'.join(texts)
+            })
+
+        # Preprocess inputs using the Ovis2.5 API
+        input_ids, pixel_values, grid_thws = self.model.preprocess_inputs(
+            messages=ovis_messages,
+            add_generation_prompt=True,
+            enable_thinking=self.enable_thinking
+        )
+
+        # Move tensors to the target device
+        input_ids = input_ids.to(device=self.device)
+        pixel_values = pixel_values.to(device=self.device) if pixel_values is not None else None
+        grid_thws = grid_thws.to(device=self.device) if grid_thws is not None else None
+
+        return input_ids, pixel_values, grid_thws, query
\ No newline at end of file
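
Reviewer note: a minimal smoke test for the new wrapper, sketched under the
following assumptions (not part of the patch): a CUDA device is available,
`demo.jpg` stands in for any local image, and `supported_VLM` in
`vlmeval/config.py` exposes the two registry entries added above.

    from vlmeval.config import supported_VLM

    # Instantiate via the new registry entry; extra kwargs are forwarded to
    # Ovis2_5.__init__ (thinking-mode controls, max_new_tokens, etc.).
    model = supported_VLM['Ovis2.5-9B']()

    msg = [
        dict(type='image', value='demo.jpg'),   # hypothetical local image path
        dict(type='text', value='Describe this image in one sentence.'),
    ]
    print(model.generate(msg))

    # Thinking mode can also be disabled per instance:
    fast = supported_VLM['Ovis2.5-2B'](enable_thinking=False, max_new_tokens=1024)
    print(fast.generate(msg))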