diff --git a/colpali_engine/interpretability/similarity_map_utils.py b/colpali_engine/interpretability/similarity_map_utils.py index ce5930ebe..9ab6f192d 100644 --- a/colpali_engine/interpretability/similarity_map_utils.py +++ b/colpali_engine/interpretability/similarity_map_utils.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from einops import rearrange @@ -56,12 +56,17 @@ def get_similarity_maps_from_embeddings( return similarity_maps -def normalize_similarity_map(similarity_map: torch.Tensor) -> torch.Tensor: +def normalize_similarity_map( + similarity_map: torch.Tensor, + value_range: Optional[Tuple[float, float]] = None, +) -> torch.Tensor: """ Normalize the similarity map to have values in the range [0, 1]. Args: similarity_map: tensor of shape (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y) + value_range: optional tuple specifying the (min, max) range to use for normalization. + When None, the min/max are computed from the input tensor (default behavior). """ if similarity_map.ndim not in [2, 3]: raise ValueError( @@ -69,11 +74,25 @@ def normalize_similarity_map(similarity_map: torch.Tensor) -> torch.Tensor: "3 dimensions (batch_size, n_patch_x, n_patch_y)." ) - # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y) - min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0] # (1, 1) or (batch_size, 1, 1) - - # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y) - max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] # (1, 1) or (batch_size, 1, 1) + if value_range is None: + # Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y) + min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min( + dim=-2, keepdim=True + )[0] # (1, 1) or (batch_size, 1, 1) + + # Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y) + max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max( + dim=-2, keepdim=True + )[0] # (1, 1) or (batch_size, 1, 1) + else: + min_vals, max_vals = value_range + broadcast_shape = (1,) * similarity_map.ndim + min_vals = torch.as_tensor(min_vals, dtype=similarity_map.dtype, device=similarity_map.device).view( + broadcast_shape + ) + max_vals = torch.as_tensor(max_vals, dtype=similarity_map.dtype, device=similarity_map.device).view( + broadcast_shape + ) # Normalize the tensor # NOTE: Add a small epsilon to avoid division by zero. diff --git a/colpali_engine/interpretability/similarity_maps.py b/colpali_engine/interpretability/similarity_maps.py index bc677fa60..ce95d653e 100644 --- a/colpali_engine/interpretability/similarity_maps.py +++ b/colpali_engine/interpretability/similarity_maps.py @@ -1,4 +1,4 @@ -from typing import List, Tuple +from typing import List, Optional, Tuple import matplotlib.pyplot as plt import numpy as np @@ -15,6 +15,7 @@ def plot_similarity_map( similarity_map: torch.Tensor, figsize: Tuple[int, int] = (8, 8), show_colorbar: bool = False, + normalization_range: Optional[Tuple[float, float]] = None, ) -> Tuple[plt.Figure, plt.Axes]: """ Plot and overlay a similarity map over the input image. 
@@ -42,7 +43,10 @@ def plot_similarity_map(
 
     # Normalize the similarity map and convert it to Pillow image
     similarity_map_array = (
-        normalize_similarity_map(similarity_map).to(torch.float32).cpu().numpy()
+        normalize_similarity_map(similarity_map, value_range=normalization_range)
+        .to(torch.float32)
+        .cpu()
+        .numpy()
     )  # (n_patches_x, n_patches_y)
 
     # Reshape the similarity map to match the PIL shape convention
@@ -78,6 +82,7 @@ def plot_all_similarity_maps(
     figsize: Tuple[int, int] = (8, 8),
     show_colorbar: bool = False,
     add_title: bool = True,
+    normalize_per_query: bool = True,
 ) -> List[Tuple[plt.Figure, plt.Axes]]:
     """
     For each token in the query, plot and overlay a similarity map over the input image.
@@ -93,6 +98,8 @@ def plot_all_similarity_maps(
         figsize: size of the figure
         show_colorbar: whether to show a colorbar
         add_title: whether to add a title with the token and the max similarity score
+        normalize_per_query: if True (default), reuse a single min/max range across all tokens in the
+            provided query similarity tensor. This avoids stretching near-constant maps to the full color scale.
 
     Example usage for one query-image pair:
 
@@ -133,12 +140,20 @@ def plot_all_similarity_maps(
 
     plots: List[Tuple[plt.Figure, plt.Axes]] = []
 
+    normalization_range: Optional[Tuple[float, float]] = None
+    if normalize_per_query:
+        normalization_range = (
+            similarity_maps.min().item(),
+            similarity_maps.max().item(),
+        )
+
     for idx, token in enumerate(query_tokens):
         fig, ax = plot_similarity_map(
             image=image,
             similarity_map=similarity_maps[idx],
             figsize=figsize,
             show_colorbar=show_colorbar,
+            normalization_range=normalization_range,
         )
 
         if add_title:
diff --git a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py
index 497fe12cc..acd4b0e37 100644
--- a/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py
+++ b/colpali_engine/models/idefics3/colidefics3/processing_colidefics3.py
@@ -1,20 +1,30 @@
+import math
 from typing import ClassVar, List, Optional, Tuple, Union
 
 import torch
 from PIL import Image
 from transformers import BatchEncoding, BatchFeature, Idefics3Processor
 
-from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
+from colpali_engine.utils.processing_utils import (
+    BaseVisualRetrieverProcessor,
+    Idefics3SplitImageInterpretabilityMixin,
+)
 
 
-class ColIdefics3Processor(BaseVisualRetrieverProcessor, Idefics3Processor):
+class ColIdefics3Processor(
+    Idefics3SplitImageInterpretabilityMixin,
+    BaseVisualRetrieverProcessor,
+    Idefics3Processor,
+):
     """
     Processor for ColIdefics3.
     """
 
     query_augmentation_token: ClassVar[str] = "<end_of_utterance>"
     image_token: ClassVar[str] = "<image>"
-    visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
+    visual_prompt_prefix: ClassVar[str] = (
+        "<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
+    )
 
     def __init__(self, *args, image_seq_len=64, **kwargs):
         super().__init__(*args, image_seq_len=image_seq_len, **kwargs)
@@ -72,5 +82,36 @@ def get_n_patches(
         self,
         image_size: Tuple[int, int],
         patch_size: int,
+        *args,
+        **kwargs,
     ) -> Tuple[int, int]:
-        raise NotImplementedError("This method is not implemented for ColIdefics3.")
+        """
+        Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
+        size (height, width) with the given patch size.
+
+        This method mirrors the Idefics3 image processing logic:
+        1. Resize the image so the longest edge equals the processor's longest_edge setting
+        2.
Calculate the number of patches in each direction using ceiling division + + Args: + image_size: Tuple of (height, width) in pixels. + patch_size: The size of each square patch in pixels. + + Returns: + Tuple of (n_patches_x, n_patches_y) representing the number of patches + along the width and height dimensions respectively. + """ + # Get the longest_edge from the image processor's size configuration + longest_edge = self.image_processor.size.get("longest_edge", 4 * patch_size) + + # Step 1: Calculate resized dimensions using the mixin helper method + height_new, width_new = self._calculate_resized_dimensions( + image_size, longest_edge + ) + + # Step 2: Calculate the number of patches in each direction + # This mirrors the split_image logic from Idefics3ImageProcessor + n_patches_y = math.ceil(height_new / patch_size) + n_patches_x = math.ceil(width_new / patch_size) + + return n_patches_x, n_patches_y diff --git a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py index 14867745e..0aa42aee4 100644 --- a/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py +++ b/colpali_engine/models/modernvbert/colvbert/processing_colmodernvbert.py @@ -1,15 +1,23 @@ +import math from typing import ClassVar, List, Optional, Tuple, Union import torch from PIL import Image from transformers import BatchEncoding, BatchFeature, Idefics3Processor -from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor +from colpali_engine.utils.processing_utils import ( + BaseVisualRetrieverProcessor, + Idefics3SplitImageInterpretabilityMixin, +) -class ColModernVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor): +class ColModernVBertProcessor( + Idefics3SplitImageInterpretabilityMixin, + BaseVisualRetrieverProcessor, + Idefics3Processor, +): """ - Processor for ColIdefics3. + Processor for ColModernVBert. """ query_augmentation_token: ClassVar[str] = "" @@ -73,6 +81,49 @@ def score( def get_n_patches( self, image_size: Tuple[int, int], - patch_size: int, + patch_size: Optional[int] = None, + *args, + **kwargs, ) -> Tuple[int, int]: - raise NotImplementedError("This method is not implemented for ColIdefics3.") + """ + Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of + size (height, width) with the given patch size. + + This method mirrors the Idefics3 image processing logic with image splitting: + 1. Resize the image so the longest edge equals the processor's longest_edge setting + 2. Split into sub-patches of max_image_size (512x512) + 3. Each sub-patch becomes image_seq_len tokens (8x8 grid) + + Note: The patch_size parameter is kept for API compatibility but is not used in the + calculation. The actual patch dimensions are determined by the image splitting logic + and image_seq_len. + + Args: + image_size: Tuple of (height, width) in pixels. + patch_size: The size of each square patch in pixels (unused, kept for API compatibility). + + Returns: + Tuple of (n_patches_x, n_patches_y) representing the number of token patches + along the width and height dimensions respectively (excluding global patch). 
+ """ + # Get the longest_edge from the image processor's size configuration + longest_edge = self.image_processor.size.get("longest_edge", 2048) + + # Step 1: Calculate resized dimensions using the mixin helper method + height_new, width_new = self._calculate_resized_dimensions( + image_size, longest_edge + ) + + # Step 2: Calculate number of sub-patches (512x512 patches) + # This mirrors the split_image logic from Idefics3ImageProcessor + max_image_size = self.image_processor.max_image_size.get("longest_edge", 512) + n_subpatches_x = math.ceil(width_new / max_image_size) + n_subpatches_y = math.ceil(height_new / max_image_size) + + # Step 3: Calculate token grid dimensions + # Each sub-patch becomes image_seq_len tokens (typically 64 = 8x8 grid) + tokens_per_subpatch_side = int(math.sqrt(self.image_seq_len)) + n_patches_x = n_subpatches_x * tokens_per_subpatch_side + n_patches_y = n_subpatches_y * tokens_per_subpatch_side + + return n_patches_x, n_patches_y diff --git a/colpali_engine/utils/processing_utils.py b/colpali_engine/utils/processing_utils.py index 4c1d9617b..86ef191bd 100644 --- a/colpali_engine/utils/processing_utils.py +++ b/colpali_engine/utils/processing_utils.py @@ -1,5 +1,6 @@ import importlib import logging +import math from abc import ABC, abstractmethod from typing import ClassVar, List, Optional, Tuple, Union @@ -254,3 +255,250 @@ def get_n_patches( image of size (height, width) with the given patch size. """ pass + + +class Idefics3SplitImageInterpretabilityMixin: + """ + Mixin class providing interpretability support for Idefics3-style image splitting processors. + + This mixin adds methods for: + - Getting image token masks (full and local-only) + - Calculating patch grid dimensions + - Rearranging embeddings from sub-patch order to spatial order + - Computing similarity maps with correct spatial correspondence + + This is designed for processors that use Idefics3-style image splitting where: + 1. Images are resized to fit within a longest_edge constraint + 2. Images are split into sub-patches (e.g., 512x512 patches) + 3. Each sub-patch becomes image_seq_len tokens (e.g., 64 tokens in an 8x8 grid) + 4. A global patch is added as the last image_seq_len tokens + + Both ColIdefics3Processor and ColModernVBertProcessor use this pattern. + """ + + # These attributes must be provided by the implementing class + image_token: str # e.g., "" + image_seq_len: int # e.g., 64 + tokenizer: any # The tokenizer instance + image_processor: any # The image processor instance + + def get_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: + """ + Get a tensor mask that identifies the image tokens in the batch. + + Args: + batch_images: BatchFeature containing processed images with input_ids. + + Returns: + A boolean tensor of the same shape as input_ids, where True indicates + an image token position. + """ + image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token) + return batch_images.input_ids == image_token_id + + def get_local_image_mask(self, batch_images: BatchFeature) -> torch.Tensor: + """ + Get a tensor mask that identifies only the LOCAL image tokens in the batch, + excluding the global patch tokens. + + In Idefics3 with image splitting, images are split into multiple sub-patches + plus one global patch. The global patch tokens are the last image_seq_len + image tokens for each image. For interpretability purposes, we typically want + to exclude the global patch since it doesn't have spatial correspondence. 
+ + Args: + batch_images: BatchFeature containing processed images with input_ids. + + Returns: + A boolean tensor of the same shape as input_ids, where True indicates + a LOCAL image token position (excluding global patch). + """ + # Get the full image mask first + full_mask = self.get_image_mask(batch_images) + local_mask = full_mask.clone() + + # For each batch item, exclude the last image_seq_len image tokens (global patch) + batch_size = batch_images.input_ids.shape[0] + + for batch_idx in range(batch_size): + # Find all image token positions in this batch item + image_positions = full_mask[batch_idx].nonzero(as_tuple=True)[0] + + if len(image_positions) > self.image_seq_len: + # Exclude the last image_seq_len tokens (global patch) + global_patch_start = len(image_positions) - self.image_seq_len + global_patch_indices = image_positions[global_patch_start:] + + # Set these positions to False in the local mask + for idx in global_patch_indices: + local_mask[batch_idx, idx] = False + + return local_mask + + def _calculate_resized_dimensions( + self, + image_size: Tuple[int, int], + longest_edge: Optional[int], + ) -> Tuple[int, int]: + """ + Calculate the resized dimensions for an image based on the longest_edge constraint. + + This mirrors the Idefics3ImageProcessor logic for resizing images. + + Args: + image_size: Tuple of (height, width) in pixels. + longest_edge: Maximum size for the longest edge. If None, no resizing is applied. + + Returns: + Tuple of (height_new, width_new) representing the resized dimensions. + """ + height, width = image_size + + # Handle edge case where resizing is disabled (research use case) + if longest_edge is None: + return height, width + + # Resize the image so the longest edge equals longest_edge + aspect_ratio = width / height + + if width >= height: + # Width is the longest edge + width_new = longest_edge + height_new = int(width_new / aspect_ratio) + # Ensure height is even (as per Idefics3 implementation) + if height_new % 2 != 0: + height_new += 1 + else: + # Height is the longest edge + height_new = longest_edge + width_new = int(height_new * aspect_ratio) + # Ensure width is even (as per Idefics3 implementation) + if width_new % 2 != 0: + width_new += 1 + + # Ensure minimum size of 1 + height_new = max(height_new, 1) + width_new = max(width_new, 1) + + return height_new, width_new + + def rearrange_image_embeddings( + self, + image_embeddings: torch.Tensor, + image_mask: torch.Tensor, + n_patches: Tuple[int, int], + ) -> torch.Tensor: + """ + Rearrange image embeddings from sub-patch order to spatial order. + + In Idefics3 with image splitting, tokens are arranged sub-patch by sub-patch: + - Tokens 0-63: First 512x512 sub-patch (8x8 tokens) + - Tokens 64-127: Second 512x512 sub-patch (8x8 tokens) + - etc. + + This method rearranges them into a proper 2D spatial grid where tokens + are organized by their actual spatial position in the image. 
+ + Args: + image_embeddings: tensor of shape (sequence_length, dim) for a single image + image_mask: boolean tensor of shape (sequence_length,) indicating image tokens + n_patches: tuple of (n_patches_x, n_patches_y) - total token grid dimensions + + Returns: + tensor of shape (n_patches_x, n_patches_y, dim) with spatially correct ordering + """ + # Extract only the image token embeddings + masked_embeddings = image_embeddings[image_mask] # (n_patches_x * n_patches_y, dim) + + n_patches_x, n_patches_y = n_patches + dim = masked_embeddings.shape[-1] + + # Calculate sub-patch grid dimensions + tokens_per_subpatch_side = int(math.sqrt(self.image_seq_len)) + n_subpatches_x = n_patches_x // tokens_per_subpatch_side + n_subpatches_y = n_patches_y // tokens_per_subpatch_side + + # Reshape from flat sub-patch order to sub-patch grid + # Current order: (n_subpatches_y * n_subpatches_x * tokens_per_side * tokens_per_side, dim) + # Reshape to: (n_subpatches_y, n_subpatches_x, tokens_per_side, tokens_per_side, dim) + reshaped = masked_embeddings.reshape( + n_subpatches_y, + n_subpatches_x, + tokens_per_subpatch_side, + tokens_per_subpatch_side, + dim, + ) + + # Permute to interleave sub-patch rows and columns + # From: (n_subpatches_y, n_subpatches_x, tokens_per_side, tokens_per_side, dim) + # To: (n_subpatches_y, tokens_per_side, n_subpatches_x, tokens_per_side, dim) + permuted = reshaped.permute(0, 2, 1, 3, 4) + + # Final reshape to (n_patches_y, n_patches_x, dim) + # Note: This gives (height, width, dim) ordering + spatial_grid = permuted.reshape(n_patches_y, n_patches_x, dim) + + # Transpose to get (n_patches_x, n_patches_y, dim) to match expected format + # This gives (width, height, dim) ordering which matches the similarity map convention + spatial_grid = spatial_grid.permute(1, 0, 2) + + return spatial_grid + + def get_similarity_maps_from_embeddings( + self, + image_embeddings: torch.Tensor, + query_embeddings: torch.Tensor, + n_patches: Union[Tuple[int, int], List[Tuple[int, int]]], + image_mask: torch.Tensor, + ) -> List[torch.Tensor]: + """ + Get similarity maps with correct spatial ordering for Idefics3-style image splitting. + + This method correctly handles the sub-patch token ordering used by Idefics3 processors, + where tokens are arranged sub-patch by sub-patch rather than in row-major order across + the entire image. + + Args: + image_embeddings: tensor of shape (batch_size, image_tokens, dim) + query_embeddings: tensor of shape (batch_size, query_tokens, dim) + n_patches: number of patches per dimension (n_patches_x, n_patches_y). + If a single tuple, it's broadcasted to all batch items. + image_mask: tensor of shape (batch_size, image_tokens) indicating LOCAL image tokens + (use get_local_image_mask to exclude global patch) + + Returns: + List of tensors, each of shape (query_tokens, n_patches_x, n_patches_y) + """ + if isinstance(n_patches, tuple): + n_patches = [n_patches] * image_embeddings.size(0) + + similarity_maps: List[torch.Tensor] = [] + + for idx in range(image_embeddings.size(0)): + # Sanity check + if image_mask[idx].sum() != n_patches[idx][0] * n_patches[idx][1]: + raise ValueError( + f"The number of patches ({n_patches[idx][0]} x {n_patches[idx][1]} = " + f"{n_patches[idx][0] * n_patches[idx][1]}) " + f"does not match the number of non-padded image tokens ({image_mask[idx].sum()}). " + f"Hint: Use get_local_image_mask() instead of get_image_mask() to exclude the global patch." 
+                )
+
+            # Rearrange image embeddings to correct spatial order
+            image_embedding_grid = self.rearrange_image_embeddings(
+                image_embeddings[idx],
+                image_mask[idx],
+                n_patches[idx],
+            )  # (n_patches_x, n_patches_y, dim)
+
+            # Compute similarity: einsum("nk,ijk->nij", query, image_grid)
+            # query: (query_tokens, dim)
+            # image_grid: (n_patches_x, n_patches_y, dim)
+            # result: (query_tokens, n_patches_x, n_patches_y)
+            similarity_map = torch.einsum(
+                "nk,ijk->nij", query_embeddings[idx], image_embedding_grid
+            )
+
+            similarity_maps.append(similarity_map)
+
+        return similarity_maps
diff --git a/examples/interpretability/colmodernvbert/generate_interpretability_maps.py b/examples/interpretability/colmodernvbert/generate_interpretability_maps.py
new file mode 100644
index 000000000..24185ba1d
--- /dev/null
+++ b/examples/interpretability/colmodernvbert/generate_interpretability_maps.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+"""
+Simplified example for generating ColModernVBert interpretability maps.
+
+This example follows the same user-friendly API pattern as the ColPali cookbook,
+but uses ColModernVBert with automatic handling of Idefics3-style image splitting.
+
+Usage:
+    python examples/interpretability/colmodernvbert/generate_interpretability_maps.py
+"""
+
+from pathlib import Path
+import uuid
+from typing import cast, Any
+
+import matplotlib.pyplot as plt
+import torch
+from PIL import Image
+
+from colpali_engine.interpretability.similarity_maps import plot_all_similarity_maps
+from colpali_engine.models import ColModernVBert, ColModernVBertProcessor
+
+
+def main():
+    print("=== ColModernVBert Simple Interpretability Example ===\n")
+
+    # ==================== USER INPUTS ====================
+    use_real_document = True  # Set to False to use a blank test image
+    # =====================================================
+
+    if use_real_document:
+        # Load a real document from DocVQA dataset
+        print("Loading a real document from DocVQA dataset...")
+        from datasets import load_dataset
+
+        dataset = load_dataset(
+            "vidore/docvqa_test_subsampled", split="test", streaming=True
+        )
+        # streaming datasets may yield values that type checkers treat as Sequence;
+        # cast to dict so string indexing (sample["image"]) is accepted by the type checker.
+        sample = dict(next(iter(dataset)))
+        image = sample["image"]
+        query = sample["query"]
+        print(f"Document loaded! Query: '{query}'")
+    else:
+        # For demo purposes, use a simple test image
+        print("Creating a demo test image...")
+        image = Image.new("RGB", (800, 600), color="white")
+        query = "What is the total revenue?"
+ + # Load model and processor + print("Loading model and processor...") + processor = ColModernVBertProcessor.from_pretrained("ModernVBERT/colmodernvbert") + model = ColModernVBert.from_pretrained("ModernVBERT/colmodernvbert") + model.eval() + + # Preprocess inputs + print(f"\nProcessing query: '{query}'") + batch_images = processor.process_images([image]) + batch_queries = processor.process_queries([query]) + + # Forward passes + print("Computing embeddings...") + with torch.no_grad(): + image_embeddings = model(**batch_images) + query_embeddings = model(**batch_queries) + + # Get the number of image patches + n_patches = processor.get_n_patches((image.size[1], image.size[0])) + print(f"Number of patches: {n_patches[0]} x {n_patches[1]}") + # Get LOCAL image mask (excludes global patch for spatial correspondence) + image_mask = processor.get_local_image_mask(cast(Any, batch_images)) + + # Generate similarity maps using the processor's method + # This automatically handles Idefics3-style sub-patch rearrangement! + similarity_maps_batch = processor.get_similarity_maps_from_embeddings( + image_embeddings=image_embeddings, + query_embeddings=query_embeddings, + n_patches=n_patches, + image_mask=image_mask, + ) + + # Get the similarity map for our input image + similarity_maps = similarity_maps_batch[ + 0 + ] # (query_length, n_patches_x, n_patches_y) + print(f"Similarity map shape: {similarity_maps.shape}") + + # Get query tokens (filtering out special tokens) + input_ids = batch_queries.input_ids[0].tolist() + query_tokens = processor.tokenizer.convert_ids_to_tokens(batch_queries.input_ids[0]) + special_token_ids = set(processor.tokenizer.all_special_ids or []) + + filtered_tokens = [] + filtered_indices = [] + for idx, (token, token_id) in enumerate(zip(query_tokens, input_ids)): + if token_id in special_token_ids: + continue + filtered_tokens.append(token) + filtered_indices.append(idx) + + # Filter similarity maps to match tokens + similarity_maps = similarity_maps[filtered_indices] + + # Clean tokens for display (remove special characters that may cause encoding issues) + display_tokens = [t.replace("Ġ", " ").replace("▁", " ") for t in filtered_tokens] + print(f"\nQuery tokens: {display_tokens}") + print( + f"Similarity range: [{similarity_maps.min().item():.3f}, {similarity_maps.max().item():.3f}]" + ) + + # Generate all similarity maps + print("\nGenerating similarity maps for all tokens...") + plots = plot_all_similarity_maps( + image=image, + query_tokens=filtered_tokens, + similarity_maps=similarity_maps, + figsize=(8, 8), + show_colorbar=False, + add_title=True, + ) + + # Save the plots + output_dir = Path("outputs/interpretability/colmodernvbert/" + uuid.uuid4().hex[:8]) + output_dir.mkdir(parents=True, exist_ok=True) + + for idx, (fig, ax) in enumerate(plots): + token = filtered_tokens[idx] + # Sanitize token for filename (remove special characters) + token_safe = ( + token.replace("<", "") + .replace(">", "") + .replace("Ġ", "") + .replace("▁", "") + .replace("?", "") + .replace(":", "") + .replace("/", "") + .replace("\\", "") + .replace("|", "") + .replace("*", "") + .replace('"', "") + ) + if not token_safe: + token_safe = f"token_{idx}" + savepath = output_dir / f"similarity_map_{idx}_{token_safe}.png" + fig.savefig(savepath, bbox_inches="tight") + print(f" Saved: {savepath.name}") + plt.close(fig) + + print(f"\n[SUCCESS] All similarity maps saved to: {output_dir.absolute()}") + + +if __name__ == "__main__": + main() diff --git 
a/tests/models/idefics3/colidefics3/test_processing_colidefics3.py b/tests/models/idefics3/colidefics3/test_processing_colidefics3.py index 4f8599a38..fce61a173 100644 --- a/tests/models/idefics3/colidefics3/test_processing_colidefics3.py +++ b/tests/models/idefics3/colidefics3/test_processing_colidefics3.py @@ -13,11 +13,15 @@ def model_name() -> str: @pytest.fixture(scope="module") -def processor_from_pretrained(model_name: str) -> Generator[ColIdefics3Processor, None, None]: +def processor_from_pretrained( + model_name: str, +) -> Generator[ColIdefics3Processor, None, None]: yield cast(ColIdefics3Processor, ColIdefics3Processor.from_pretrained(model_name)) -def test_load_processor_from_pretrained(processor_from_pretrained: ColIdefics3Processor): +def test_load_processor_from_pretrained( + processor_from_pretrained: ColIdefics3Processor, +): assert isinstance(processor_from_pretrained, ColIdefics3Processor) @@ -63,3 +67,101 @@ def test_process_queries(processor_from_pretrained: ColIdefics3Processor): assert "input_ids" in batch_encoding assert isinstance(batch_encoding["input_ids"], torch.Tensor) assert cast(torch.Tensor, batch_encoding["input_ids"]).shape[0] == len(queries) + + +def test_get_n_patches(processor_from_pretrained: ColIdefics3Processor): + """ + Test that get_n_patches returns the correct number of patches for various image sizes. + """ + # Get the patch size from the image processor + patch_size = processor_from_pretrained.image_processor.max_image_size.get( + "longest_edge", 512 + ) + + # Test case 1: Small square image + image_size = (100, 100) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + assert isinstance(n_patches_x, int) + assert isinstance(n_patches_y, int) + assert n_patches_x > 0 + assert n_patches_y > 0 + + # Test case 2: Wide image (width > height) + image_size = (100, 200) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + assert n_patches_x >= n_patches_y # More patches along width + + # Test case 3: Tall image (height > width) + image_size = (200, 100) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + assert n_patches_y >= n_patches_x # More patches along height + + # Test case 4: Square image + image_size = (500, 500) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + assert n_patches_x == n_patches_y # Equal patches for square image + + +def test_get_n_patches_matches_actual_processing( + processor_from_pretrained: ColIdefics3Processor, +): + """ + Test that get_n_patches matches the actual number of patches produced by process_images. 
+ """ + # Create a test image + image_size = (16, 32) # PIL Image.new takes (width, height) + image = Image.new("RGB", image_size, color="black") + + # Process the image to get actual patch count + batch_feature = processor_from_pretrained.process_images([image]) + # pixel_values shape is [batch_size, num_patches, channels, patch_height, patch_width] + actual_num_patches = batch_feature["pixel_values"].shape[1] + + # Get the patch size from the image processor + patch_size = processor_from_pretrained.image_processor.max_image_size.get( + "longest_edge", 512 + ) + + # Calculate expected patches using get_n_patches + # Note: image_size for get_n_patches is (height, width), but PIL uses (width, height) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + (image_size[1], image_size[0]), patch_size + ) + expected_num_patches = n_patches_x * n_patches_y + + # The actual number of patches includes the global image patch (+1) + # So we compare with expected + 1 + assert ( + actual_num_patches == expected_num_patches + 1 + ), f"Expected {expected_num_patches + 1} patches (including global), got {actual_num_patches}" + + +def test_get_image_mask(processor_from_pretrained: ColIdefics3Processor): + """ + Test that get_image_mask correctly identifies image tokens. + """ + # Create a dummy image + image_size = (16, 32) + image = Image.new("RGB", image_size, color="black") + images = [image] + + # Process the image + batch_feature = processor_from_pretrained.process_images(images) + + # Get the image mask + image_mask = processor_from_pretrained.get_image_mask(batch_feature) + + # Assertions + assert isinstance(image_mask, torch.Tensor) + assert image_mask.shape == batch_feature.input_ids.shape + assert image_mask.dtype == torch.bool + # There should be some image tokens (True values) in the mask + assert image_mask.sum() > 0 diff --git a/tests/models/modernvbert/test_interpretability_colmodernvbert.py b/tests/models/modernvbert/test_interpretability_colmodernvbert.py new file mode 100644 index 000000000..310300092 --- /dev/null +++ b/tests/models/modernvbert/test_interpretability_colmodernvbert.py @@ -0,0 +1,366 @@ +""" +Test interpretability maps for ColModernVBert model. + +This module tests: +1. get_n_patches() method - calculates correct patch dimensions +2. get_image_mask() method - identifies image tokens correctly +3. 
End-to-end similarity map generation +""" + +from typing import Generator, cast + +import pytest +import torch +from PIL import Image + +from colpali_engine.models import ColModernVBert, ColModernVBertProcessor +from colpali_engine.interpretability.similarity_map_utils import ( + normalize_similarity_map, +) + + +@pytest.fixture(scope="module") +def model_name() -> str: + return "ModernVBERT/colmodernvbert" + + +@pytest.fixture(scope="module") +def processor_from_pretrained( + model_name: str, +) -> Generator[ColModernVBertProcessor, None, None]: + yield cast( + ColModernVBertProcessor, ColModernVBertProcessor.from_pretrained(model_name) + ) + + +@pytest.fixture(scope="module") +def model_from_pretrained(model_name: str) -> Generator[ColModernVBert, None, None]: + yield cast(ColModernVBert, ColModernVBert.from_pretrained(model_name)) + + +class TestGetNPatches: + """Test the get_n_patches method for calculating patch dimensions.""" + + def test_get_n_patches_returns_integers( + self, processor_from_pretrained: ColModernVBertProcessor + ): + """Test that get_n_patches returns integer values.""" + patch_size = 14 # Common patch size for vision transformers + image_size = (100, 100) + + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + + assert isinstance(n_patches_x, int) + assert isinstance(n_patches_y, int) + assert n_patches_x > 0 + assert n_patches_y > 0 + + def test_get_n_patches_wide_image( + self, processor_from_pretrained: ColModernVBertProcessor + ): + """Test that wide images have more patches along width.""" + patch_size = 14 + image_size = (100, 200) # (height, width) - wider than tall + + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + + # n_patches_x is along width, n_patches_y is along height + assert ( + n_patches_x >= n_patches_y + ), f"Expected more patches along width, got x={n_patches_x}, y={n_patches_y}" + + def test_get_n_patches_tall_image( + self, processor_from_pretrained: ColModernVBertProcessor + ): + """Test that tall images have more patches along height.""" + patch_size = 14 + image_size = (200, 100) # (height, width) - taller than wide + + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + + assert ( + n_patches_y >= n_patches_x + ), f"Expected more patches along height, got x={n_patches_x}, y={n_patches_y}" + + def test_get_n_patches_square_image( + self, processor_from_pretrained: ColModernVBertProcessor + ): + """Test that square images have equal patches in both dimensions.""" + patch_size = 14 + image_size = (500, 500) + + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + + assert ( + n_patches_x == n_patches_y + ), f"Expected equal patches for square image, got x={n_patches_x}, y={n_patches_y}" + + def test_get_n_patches_aspect_ratio_preservation( + self, processor_from_pretrained: ColModernVBertProcessor + ): + """Test that aspect ratio is approximately preserved in patch dimensions.""" + patch_size = 14 + + # Test with a 2:1 aspect ratio image + image_size = (300, 600) # height=300, width=600 + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + image_size, patch_size + ) + + # The aspect ratio of patches should be close to 2:1 + patch_ratio = n_patches_x / n_patches_y + expected_ratio = 2.0 + + # Allow tolerance due to: + # 1. Image splitting into 512x512 sub-patches (quantization effects) + # 2. 
Even-dimension rounding in resize logic + # 3. Ceiling division in patch calculations + # These factors can cause ~25% deviation from the ideal aspect ratio + assert 1.5 <= patch_ratio <= 2.5, f"Expected ~2:1 ratio, got {patch_ratio:.2f}" + + +class TestGetImageMask: + """Test the get_image_mask method for identifying image tokens.""" + + def test_get_image_mask_shape( + self, processor_from_pretrained: ColModernVBertProcessor + ): + """Test that image mask has the same shape as input_ids.""" + image = Image.new("RGB", (64, 32), color="red") + batch_feature = processor_from_pretrained.process_images([image]) + + image_mask = processor_from_pretrained.get_image_mask(batch_feature) + + assert image_mask.shape == batch_feature.input_ids.shape + assert image_mask.dtype == torch.bool + + def test_get_image_mask_has_image_tokens( + self, processor_from_pretrained: ColModernVBertProcessor + ): + """Test that the mask identifies some image tokens.""" + image = Image.new("RGB", (64, 32), color="blue") + batch_feature = processor_from_pretrained.process_images([image]) + + image_mask = processor_from_pretrained.get_image_mask(batch_feature) + + # There should be image tokens present + assert ( + image_mask.sum() > 0 + ), "Expected to find image tokens in the processed batch" + + def test_get_image_mask_batch_consistency( + self, processor_from_pretrained: ColModernVBertProcessor + ): + """Test that image mask works correctly with batched images.""" + images = [ + Image.new("RGB", (64, 32), color="red"), + Image.new("RGB", (128, 64), color="green"), + ] + batch_feature = processor_from_pretrained.process_images(images) + + image_mask = processor_from_pretrained.get_image_mask(batch_feature) + + assert image_mask.shape[0] == len(images) + # Each image should have some image tokens + for i in range(len(images)): + assert image_mask[i].sum() > 0, f"Image {i} should have image tokens" + + +class TestEndToEndInterpretability: + """Test end-to-end interpretability map generation.""" + + @pytest.mark.slow + def test_similarity_maps_shape( + self, + processor_from_pretrained: ColModernVBertProcessor, + model_from_pretrained: ColModernVBert, + ): + """Test that similarity maps have the correct shape based on get_n_patches.""" + # Create a test image + image_size_pil = (128, 64) # PIL uses (width, height) + image = Image.new("RGB", image_size_pil, color="white") + + # Create a query + query = "test query" + + # Process image and query + batch_images = processor_from_pretrained.process_images([image]) + batch_queries = processor_from_pretrained.process_texts([query]) + + # Get patch size from the model or processor + # ModernVBert uses patch_size from its config + patch_size = ( + 14 # Default for many vision transformers (unused but required for API) + ) + + # Calculate expected patches + # Note: image_size for get_n_patches is (height, width) + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + (image_size_pil[1], image_size_pil[0]), patch_size # (height, width) + ) + + # Get embeddings + with torch.no_grad(): + image_embeddings = model_from_pretrained(**batch_images) + query_embeddings = model_from_pretrained(**batch_queries) + + # Get LOCAL image mask (excluding global patch for interpretability) + image_mask = processor_from_pretrained.get_local_image_mask(batch_images) + + # Calculate similarity maps using the processor's method + # This correctly handles the sub-patch token ordering + similarity_maps = processor_from_pretrained.get_similarity_maps_from_embeddings( + 
image_embeddings=image_embeddings, + query_embeddings=query_embeddings, + n_patches=(n_patches_x, n_patches_y), + image_mask=image_mask, + ) + + # Check shape + assert len(similarity_maps) == 1 # One batch item + query_length = query_embeddings.shape[1] + + # similarity_maps[0] should have shape (query_tokens, n_patches_x, n_patches_y) + expected_shape = (query_length, n_patches_x, n_patches_y) + assert ( + similarity_maps[0].shape == expected_shape + ), f"Expected shape {expected_shape}, got {similarity_maps[0].shape}" + + @pytest.mark.slow + def test_similarity_maps_values( + self, + processor_from_pretrained: ColModernVBertProcessor, + model_from_pretrained: ColModernVBert, + ): + """Test that similarity map values are reasonable after normalization.""" + image = Image.new("RGB", (64, 64), color="black") + query = "dark image" + + batch_images = processor_from_pretrained.process_images([image]) + batch_queries = processor_from_pretrained.process_texts([query]) + + patch_size = 14 + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + (64, 64), patch_size + ) + + with torch.no_grad(): + image_embeddings = model_from_pretrained(**batch_images) + query_embeddings = model_from_pretrained(**batch_queries) + + # Use LOCAL image mask (excluding global patch) + image_mask = processor_from_pretrained.get_local_image_mask(batch_images) + + # Use the processor's method for correct sub-patch ordering + similarity_maps = processor_from_pretrained.get_similarity_maps_from_embeddings( + image_embeddings=image_embeddings, + query_embeddings=query_embeddings, + n_patches=(n_patches_x, n_patches_y), + image_mask=image_mask, + ) + + # Normalize and check values + sim_map = similarity_maps[0][0] # First query token's similarity map + normalized_map = normalize_similarity_map(sim_map) + + # After normalization, values should be in [0, 1] + assert normalized_map.min() >= 0.0 + assert normalized_map.max() <= 1.0 + assert ( + normalized_map.max() == 1.0 + ) # Max should be exactly 1.0 after normalization + + @pytest.mark.slow + def test_patch_count_matches_mask_count( + self, + processor_from_pretrained: ColModernVBertProcessor, + ): + """Test that the number of LOCAL image tokens matches expected patch count.""" + image_size_pil = (128, 128) + image = Image.new("RGB", image_size_pil, color="gray") + + batch_feature = processor_from_pretrained.process_images([image]) + + # Use LOCAL image mask (excluding global patch) + local_image_mask = processor_from_pretrained.get_local_image_mask(batch_feature) + + # Count actual LOCAL image tokens + actual_local_tokens = local_image_mask.sum().item() + + # Calculate expected patches + patch_size = 14 + n_patches_x, n_patches_y = processor_from_pretrained.get_n_patches( + (image_size_pil[1], image_size_pil[0]), patch_size + ) + expected_local_patches = n_patches_x * n_patches_y + + # LOCAL tokens should match exactly + assert ( + actual_local_tokens == expected_local_patches + ), f"Expected {expected_local_patches} local image tokens, got {actual_local_tokens}" + + @pytest.mark.slow + def test_global_patch_excluded( + self, + processor_from_pretrained: ColModernVBertProcessor, + ): + """Test that global patch is correctly excluded from local mask.""" + image_size_pil = (128, 128) + image = Image.new("RGB", image_size_pil, color="gray") + + batch_feature = processor_from_pretrained.process_images([image]) + + full_mask = processor_from_pretrained.get_image_mask(batch_feature) + local_mask = processor_from_pretrained.get_local_image_mask(batch_feature) + + 
full_count = full_mask.sum().item() + local_count = local_mask.sum().item() + + # The difference should be exactly image_seq_len (global patch tokens) + image_seq_len = processor_from_pretrained.image_seq_len + assert ( + full_count - local_count == image_seq_len + ), f"Expected {image_seq_len} global patch tokens, got {full_count - local_count}" + + +class TestInterpretabilityConsistency: + """Test consistency of interpretability across different scenarios.""" + + def test_different_image_sizes_produce_different_patch_counts( + self, + processor_from_pretrained: ColModernVBertProcessor, + ): + """Test that different image sizes produce different patch dimensions.""" + patch_size = 14 + + small_patches = processor_from_pretrained.get_n_patches((100, 100), patch_size) + large_patches = processor_from_pretrained.get_n_patches((500, 500), patch_size) + + # Larger images should produce the same or different patch counts + # depending on the longest_edge configuration + # At minimum, verify both are valid + assert small_patches[0] > 0 and small_patches[1] > 0 + assert large_patches[0] > 0 and large_patches[1] > 0 + + def test_consistent_patch_calculation( + self, + processor_from_pretrained: ColModernVBertProcessor, + ): + """Test that get_n_patches is deterministic.""" + patch_size = 14 + image_size = (256, 512) + + # Call multiple times + result1 = processor_from_pretrained.get_n_patches(image_size, patch_size) + result2 = processor_from_pretrained.get_n_patches(image_size, patch_size) + result3 = processor_from_pretrained.get_n_patches(image_size, patch_size) + + assert result1 == result2 == result3, "get_n_patches should be deterministic"
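For reference, a minimal sketch of what the shared normalization range does in practice. It assumes this branch of `colpali-engine` is installed; the toy similarity maps below are invented and only illustrate the behavior of `normalize_similarity_map` / `normalize_per_query`, not real model outputs.

```python
import torch

from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map

# Invented 3x4 similarity maps for two query tokens: one informative, one nearly constant.
informative = torch.tensor([[0.0, 1.0, 2.0, 3.0]] * 3)
near_constant = torch.linspace(0.04, 0.06, steps=12).reshape(3, 4)

# Per-map normalization (value_range=None) stretches the near-constant map to the full color scale.
print(float(normalize_similarity_map(near_constant).max()))  # ~1.0 despite tiny absolute scores

# plot_all_similarity_maps(..., normalize_per_query=True) instead reuses one min/max per query:
maps = torch.stack([informative, near_constant])
shared_range = (maps.min().item(), maps.max().item())
print(float(normalize_similarity_map(near_constant, value_range=shared_range).max()))  # ~0.02
```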
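Similarly, a self-contained toy (plain PyTorch, with invented sizes) showing the reshape/permute that `rearrange_image_embeddings` performs to turn Idefics3 sub-patch token order into a row-major grid over the whole image:

```python
import torch

# Toy setup: image_seq_len = 4 (2x2 tokens per sub-patch) and an image split into
# 1 row x 2 columns of sub-patches, so the full token grid is 2 x 4.
tokens_per_side, n_sub_y, n_sub_x, dim = 2, 1, 2, 1
seq = torch.arange(n_sub_y * n_sub_x * tokens_per_side**2, dtype=torch.float32).reshape(-1, dim)
# Sequence order is sub-patch by sub-patch: [0, 1, 2, 3] (left sub-patch), [4, 5, 6, 7] (right).

grid = (
    seq.reshape(n_sub_y, n_sub_x, tokens_per_side, tokens_per_side, dim)
    .permute(0, 2, 1, 3, 4)  # interleave sub-patch rows with per-sub-patch token rows
    .reshape(n_sub_y * tokens_per_side, n_sub_x * tokens_per_side, dim)
)
print(grid[..., 0])
# tensor([[0., 1., 4., 5.],
#         [2., 3., 6., 7.]])  -> row-major over the full image; rearrange_image_embeddings
# then permutes this (n_patches_y, n_patches_x, dim) grid to (n_patches_x, n_patches_y, dim).
```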