33 changes: 26 additions & 7 deletions colpali_engine/interpretability/similarity_map_utils.py
@@ -1,4 +1,4 @@
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch
from einops import rearrange
@@ -56,24 +56,43 @@ def get_similarity_maps_from_embeddings(
return similarity_maps


def normalize_similarity_map(similarity_map: torch.Tensor) -> torch.Tensor:
def normalize_similarity_map(
similarity_map: torch.Tensor,
value_range: Optional[Tuple[float, float]] = None,
) -> torch.Tensor:
"""
Normalize the similarity map to have values in the range [0, 1].

Args:
similarity_map: tensor of shape (n_patch_x, n_patch_y) or (batch_size, n_patch_x, n_patch_y)
value_range: optional tuple specifying the (min, max) range to use for normalization.
When None, the min/max are computed from the input tensor (default behavior).
"""
if similarity_map.ndim not in [2, 3]:
raise ValueError(
"The input tensor must have 2 dimensions (n_patch_x, n_patch_y) or "
"3 dimensions (batch_size, n_patch_x, n_patch_y)."
)

# Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0] # (1, 1) or (batch_size, 1, 1)

# Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] # (1, 1) or (batch_size, 1, 1)
if value_range is None:
# Compute the minimum values along the last two dimensions (n_patch_x, n_patch_y)
min_vals = similarity_map.min(dim=-1, keepdim=True)[0].min(
dim=-2, keepdim=True
)[0] # (1, 1) or (batch_size, 1, 1)

# Compute the maximum values along the last two dimensions (n_patch_x, n_patch_y)
max_vals = similarity_map.max(dim=-1, keepdim=True)[0].max(
dim=-2, keepdim=True
)[0] # (1, 1) or (batch_size, 1, 1)
else:
min_vals, max_vals = value_range
broadcast_shape = (1,) * similarity_map.ndim
min_vals = torch.as_tensor(min_vals, dtype=similarity_map.dtype, device=similarity_map.device).view(
broadcast_shape
)
max_vals = torch.as_tensor(max_vals, dtype=similarity_map.dtype, device=similarity_map.device).view(
broadcast_shape
)

# Normalize the tensor
# NOTE: Add a small epsilon to avoid division by zero.
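For reference, a minimal usage sketch of the new `value_range` parameter, importing straight from the module shown above (tensor values are toy data):

```python
import torch

from colpali_engine.interpretability.similarity_map_utils import normalize_similarity_map

# Two toy 2x2 maps for one query: one informative, one nearly constant.
maps = torch.tensor(
    [
        [[0.0, 0.5], [0.5, 1.0]],
        [[0.49, 0.50], [0.50, 0.51]],
    ]
)

# Default: each map is stretched to [0, 1] with its own min/max, so the
# near-constant second map fills the whole color scale.
per_map = normalize_similarity_map(maps)

# With an explicit (min, max), both maps share one scale and the
# near-constant map stays visually flat after normalization.
shared = normalize_similarity_map(
    maps, value_range=(maps.min().item(), maps.max().item())
)
```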
19 changes: 17 additions & 2 deletions colpali_engine/interpretability/similarity_maps.py
@@ -1,4 +1,4 @@
from typing import List, Tuple
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
@@ -15,6 +15,7 @@ def plot_similarity_map(
similarity_map: torch.Tensor,
figsize: Tuple[int, int] = (8, 8),
show_colorbar: bool = False,
normalization_range: Optional[Tuple[float, float]] = None,
) -> Tuple[plt.Figure, plt.Axes]:
"""
Plot and overlay a similarity map over the input image.
@@ -42,7 +43,10 @@

# Normalize the similarity map and convert it to Pillow image
similarity_map_array = (
normalize_similarity_map(similarity_map).to(torch.float32).cpu().numpy()
normalize_similarity_map(similarity_map, value_range=normalization_range)
.to(torch.float32)
.cpu()
.numpy()
) # (n_patches_x, n_patches_y)

# Reshape the similarity map to match the PIL shape convention
@@ -78,6 +82,7 @@ def plot_all_similarity_maps(
figsize: Tuple[int, int] = (8, 8),
show_colorbar: bool = False,
add_title: bool = True,
normalize_per_query: bool = True,
) -> List[Tuple[plt.Figure, plt.Axes]]:
"""
For each token in the query, plot and overlay a similarity map over the input image.
@@ -93,6 +98,8 @@
figsize: size of the figure
show_colorbar: whether to show a colorbar
add_title: whether to add a title with the token and the max similarity score
normalize_per_query: if True (default), share a single min/max normalization range across all token
maps of the provided query similarity tensor. This avoids stretching near-constant maps to the full color scale.

Example usage for one query-image pair:

@@ -133,12 +140,20 @@

plots: List[Tuple[plt.Figure, plt.Axes]] = []

normalization_range: Optional[Tuple[float, float]] = None
if normalize_per_query:
normalization_range = (
similarity_maps.min().item(),
similarity_maps.max().item(),
)

for idx, token in enumerate(query_tokens):
fig, ax = plot_similarity_map(
image=image,
similarity_map=similarity_maps[idx],
figsize=figsize,
show_colorbar=show_colorbar,
normalization_range=normalization_range,
)

if add_title:
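A usage sketch of the new `normalize_per_query` flag; the `image`/`query_tokens`/`similarity_maps` keyword names are read off the loop body above, and all input values are illustrative:

```python
import torch
from PIL import Image

from colpali_engine.interpretability.similarity_maps import plot_all_similarity_maps

# Illustrative inputs: a blank image, two tokens, and a random
# (n_tokens, n_patches_x, n_patches_y) similarity tensor.
image = Image.new("RGB", (448, 448), "white")
query_tokens = ["climate", "change"]
similarity_maps = torch.rand(len(query_tokens), 32, 32)

# New default: one min/max range shared by all token maps of the query.
plots = plot_all_similarity_maps(
    image=image,
    query_tokens=query_tokens,
    similarity_maps=similarity_maps,
)

# Previous behavior: normalize each token map independently.
plots_old = plot_all_similarity_maps(
    image=image,
    query_tokens=query_tokens,
    similarity_maps=similarity_maps,
    normalize_per_query=False,
)
```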
[File header not captured — the hunks below modify ColIdefics3Processor]
@@ -1,20 +1,30 @@
import math
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature, Idefics3Processor

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.processing_utils import (
BaseVisualRetrieverProcessor,
Idefics3SplitImageInterpretabilityMixin,
)


class ColIdefics3Processor(BaseVisualRetrieverProcessor, Idefics3Processor):
class ColIdefics3Processor(
Idefics3SplitImageInterpretabilityMixin,
BaseVisualRetrieverProcessor,
Idefics3Processor,
):
"""
Processor for ColIdefics3.
"""

query_augmentation_token: ClassVar[str] = "<end_of_utterance>"
image_token: ClassVar[str] = "<image>"
visual_prompt_prefix: ClassVar[str] = "<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
visual_prompt_prefix: ClassVar[str] = (
"<|im_start|>User:<image>Describe the image.<end_of_utterance>\nAssistant:"
)

def __init__(self, *args, image_seq_len=64, **kwargs):
super().__init__(*args, image_seq_len=image_seq_len, **kwargs)
@@ -72,5 +82,36 @@ def get_n_patches(
self,
image_size: Tuple[int, int],
patch_size: int,
*args,
**kwargs,
) -> Tuple[int, int]:
raise NotImplementedError("This method is not implemented for ColIdefics3.")
"""
Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
size (height, width) with the given patch size.

This method mirrors the Idefics3 image processing logic:
1. Resize the image so the longest edge equals the processor's longest_edge setting
2. Calculate the number of patches in each direction using ceiling division

Args:
image_size: Tuple of (height, width) in pixels.
patch_size: The size of each square patch in pixels.

Returns:
Tuple of (n_patches_x, n_patches_y) representing the number of patches
along the width and height dimensions respectively.
"""
# Get the longest_edge from the image processor's size configuration
longest_edge = self.image_processor.size.get("longest_edge", 4 * patch_size)

# Step 1: Calculate resized dimensions using the mixin helper method
height_new, width_new = self._calculate_resized_dimensions(
image_size, longest_edge
)

# Step 2: Calculate the number of patches in each direction
# This mirrors the split_image logic from Idefics3ImageProcessor
n_patches_y = math.ceil(height_new / patch_size)
n_patches_x = math.ceil(width_new / patch_size)

return n_patches_x, n_patches_y
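A quick sanity check of the arithmetic above, with hypothetical config values; the resize step below is a simplified stand-in for the mixin's `_calculate_resized_dimensions` helper, whose exact rounding may differ:

```python
import math

# Hypothetical inputs: a 1024x768 (height, width) image, 14px patches,
# and longest_edge=448 from the image processor config.
height, width = 1024, 768
patch_size, longest_edge = 14, 448

# Step 1: resize so the longest edge equals longest_edge (aspect ratio kept;
# simple rounding as a stand-in for _calculate_resized_dimensions).
scale = longest_edge / max(height, width)
height_new, width_new = round(height * scale), round(width * scale)  # 448, 336

# Step 2: ceiling division into patches.
n_patches_y = math.ceil(height_new / patch_size)  # ceil(448 / 14) = 32
n_patches_x = math.ceil(width_new / patch_size)   # ceil(336 / 14) = 24
```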
[File header not captured — the hunks below modify ColModernVBertProcessor]
@@ -1,15 +1,23 @@
import math
from typing import ClassVar, List, Optional, Tuple, Union

import torch
from PIL import Image
from transformers import BatchEncoding, BatchFeature, Idefics3Processor

from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.processing_utils import (
BaseVisualRetrieverProcessor,
Idefics3SplitImageInterpretabilityMixin,
)


class ColModernVBertProcessor(BaseVisualRetrieverProcessor, Idefics3Processor):
class ColModernVBertProcessor(
Idefics3SplitImageInterpretabilityMixin,
BaseVisualRetrieverProcessor,
Idefics3Processor,
):
"""
Processor for ColIdefics3.
Processor for ColModernVBert.
"""

query_augmentation_token: ClassVar[str] = "<end_of_utterance>"
@@ -73,6 +81,49 @@ def score(
def get_n_patches(
self,
image_size: Tuple[int, int],
patch_size: int,
patch_size: Optional[int] = None,
*args,
**kwargs,
) -> Tuple[int, int]:
raise NotImplementedError("This method is not implemented for ColIdefics3.")
"""
Get the number of patches (n_patches_x, n_patches_y) that will be used to process an image of
size (height, width) with the given patch size.

This method mirrors the Idefics3 image processing logic with image splitting:
1. Resize the image so the longest edge equals the processor's longest_edge setting
2. Split into sub-patches of max_image_size (512x512)
3. Each sub-patch becomes image_seq_len tokens (8x8 grid)

Note: The patch_size parameter is kept for API compatibility but is not used in the
calculation. The actual patch dimensions are determined by the image splitting logic
and image_seq_len.

Args:
image_size: Tuple of (height, width) in pixels.
patch_size: The size of each square patch in pixels (unused, kept for API compatibility).

Returns:
Tuple of (n_patches_x, n_patches_y) representing the number of token patches
along the width and height dimensions respectively (excluding global patch).
"""
# Get the longest_edge from the image processor's size configuration
longest_edge = self.image_processor.size.get("longest_edge", 2048)

# Step 1: Calculate resized dimensions using the mixin helper method
height_new, width_new = self._calculate_resized_dimensions(
image_size, longest_edge
)

# Step 2: Calculate number of sub-patches (512x512 patches)
# This mirrors the split_image logic from Idefics3ImageProcessor
max_image_size = self.image_processor.max_image_size.get("longest_edge", 512)
n_subpatches_x = math.ceil(width_new / max_image_size)
n_subpatches_y = math.ceil(height_new / max_image_size)

# Step 3: Calculate token grid dimensions
# Each sub-patch becomes image_seq_len tokens (typically 64 = 8x8 grid)
tokens_per_subpatch_side = int(math.sqrt(self.image_seq_len))
n_patches_x = n_subpatches_x * tokens_per_subpatch_side
n_patches_y = n_subpatches_y * tokens_per_subpatch_side

return n_patches_x, n_patches_y
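The same kind of sanity check for the split-image path, using the defaults visible in this hunk (`longest_edge=2048`, `max_image_size=512`, `image_seq_len=64`); the image size and resize rounding are illustrative:

```python
import math

# Hypothetical input: a 4096x3072 (height, width) image with
# longest_edge=2048, max_image_size=512, image_seq_len=64.
height, width = 4096, 3072

# Step 1: resize so the longest edge equals 2048 (aspect ratio kept).
scale = 2048 / max(height, width)
height_new, width_new = round(height * scale), round(width * scale)  # 2048, 1536

# Step 2: ceiling division into 512x512 sub-patches.
n_subpatches_y = math.ceil(height_new / 512)  # 4
n_subpatches_x = math.ceil(width_new / 512)   # 3

# Step 3: each sub-patch contributes an 8x8 token grid (sqrt(64) = 8).
tokens_per_subpatch_side = int(math.sqrt(64))            # 8
n_patches_y = n_subpatches_y * tokens_per_subpatch_side  # 32
n_patches_x = n_subpatches_x * tokens_per_subpatch_side  # 24
```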