Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 155 additions & 27 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Modified for Adaptive Superpixels feature
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from typing import Any, Callable, Optional
Expand All @@ -18,56 +19,173 @@
from timm.data import create_transform


# Superpixel generation function
def generate_superpixels(
    image_tensor_batch: torch.Tensor,
    k_values: torch.Tensor,
    m_values: torch.Tensor,
    denormalize_transform: Callable,
    downsample_factor: int,
    spix_method: str,
    device: torch.device
) -> torch.Tensor:
    """
    Generate superpixel assignment maps for a batch of images.

    Each image is denormalized, scaled by 255, downsampled by
    ``1 / downsample_factor`` and segmented independently with its own
    K and m values.

    Args:
        image_tensor_batch (torch.Tensor): Batch of image tensors (N, C, H, W).
            Assumed to be normalized.
        k_values (torch.Tensor): K value (number of superpixels) per image (N,).
        m_values (torch.Tensor): m value (compactness) per image (N,).
        denormalize_transform (Callable): Callable applied to each (C, H, W)
            CPU tensor before it is multiplied by 255 (e.g. a Denormalize instance).
        downsample_factor (int): Factor by which each image is downsampled
            before segmentation.
        spix_method (str): Superpixel algorithm, either 'fastslic' or 'slic'.
        device (torch.device): Device the stacked assignment tensor is moved to.

    Returns:
        torch.Tensor: int32 assignment maps stacked along a new batch axis,
            shape (N, 1, H_spix, W_spix), where H_spix and W_spix are the
            dimensions of the downsampled image passed to the segmenter.

    Raises:
        NotImplementedError: If ``spix_method`` is not 'fastslic' or 'slic'.
    """
    # Fail fast: reject an unsupported method before doing any per-image work.
    if spix_method not in ('fastslic', 'slic'):
        raise NotImplementedError(f"Superpixel method {spix_method} not implemented.")

    batch_assignments = []
    for i, img_tensor in enumerate(image_tensor_batch):
        batch_assignments.append(
            _segment_single_image(
                img_tensor,
                k_values[i].item(),
                m_values[i].item(),
                denormalize_transform,
                downsample_factor,
                spix_method,
            )
        )

    # Stack the per-image (1, H_spix, W_spix) maps and move the result to `device`.
    return torch.stack(batch_assignments).to(device)


def _segment_single_image(
    img_tensor,
    k_i,
    m_i,
    denormalize_transform,
    downsample_factor,
    spix_method,
):
    """Return the (1, H_spix, W_spix) int32 superpixel map for one (C, H, W) image."""
    # detach() so tensors that require grad can be converted to NumPy;
    # cpu() because the NumPy conversion needs host memory.
    denormalized = denormalize_transform(img_tensor.detach().cpu())

    # Scale by 255 and change layout from (C, H, W) to (H, W, C).
    image_hwc = np.array(denormalized * 255).transpose(1, 2, 0)

    # Downsample with anti-aliasing, then round, clip to [0, 255] and cast to uint8.
    image_small = rescale(
        image_hwc,
        1 / downsample_factor,
        anti_aliasing=True,
        channel_axis=2
    ).round().clip(0, 255).astype(np.uint8)

    if spix_method == 'fastslic':
        slic_engine = SlicAvx2(num_components=int(k_i), compactness=m_i)
        assignment = slic_engine.iterate(image_small)
    elif spix_method == 'slic':
        assignment = slic(
            image_small,
            n_segments=int(k_i),
            compactness=m_i,
            channel_axis=-1,  # channels are the last axis of the HWC array
            start_label=0     # segment labels start from 0
        )
    else:
        raise NotImplementedError(f"Superpixel method {spix_method} not implemented.")

    # Intermediate arrays are locals of this helper, so they are released on
    # return without explicit `del` statements.
    return torch.tensor(assignment, dtype=torch.int32).unsqueeze(0)


class SpixImageFolder(datasets.ImageFolder):
    """ImageFolder that can also return a superpixel assignment map per sample.

    In static mode (``adaptive_superpixels=False``) each item carries a
    superpixel map computed with the fixed ``n_segments`` / ``compactness``
    values. In adaptive mode only the image and target are returned.
    """

    def __init__(
        self,
        root: str,
        n_segments=196,  # Default number of superpixels (K) if not adaptive
        compactness=10,  # Default compactness (m) if not adaptive
        downsample=2,  # Default downsample factor for superpixel generation
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = datasets.folder.default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
        spix_method='fastslic',  # Default superpixel algorithm
        adaptive_superpixels=False,  # If True, superpixels are not generated by this class
    ):
        super().__init__(root, transform, target_transform, loader, is_valid_file)
        self.n_segments = n_segments  # K for static superpixel mode
        self.adaptive_superpixels = adaptive_superpixels  # Controls behavior in __getitem__
        # The transformed sample is denormalized before superpixel generation.
        self.denormalize = Denormalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
        self.compactness = compactness  # m for static superpixel mode
        self.downsample = downsample
        self.spix_method = spix_method

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index of the data sample to retrieve.

        Returns:
            tuple:
                - If `self.adaptive_superpixels` is `True` (adaptive mode):
                  `(sample, target)`
                  where `sample` is the transformed image and `target` is the class index.
                - If `self.adaptive_superpixels` is `False` (static superpixel mode):
                  `(sample, assignment, target)`
                  where `assignment` is the superpixel map tensor for the sample.
        """
        path, target = self.samples[index]  # Image path and class target
        sample = self.loader(path)  # Load image using the configured loader

        # Apply image transformations (e.g., augmentations, normalization)
        if self.transform is not None:
            sample = self.transform(sample)

        # Apply target transformations if any
        if self.target_transform is not None:
            target = self.target_transform(target)

        if self.adaptive_superpixels:
            # Adaptive mode: this class returns only the image and its target;
            # superpixels are expected to be generated by the caller.
            return sample, target

        # Static mode: generate superpixels here with fixed K (n_segments) and
        # m (compactness). `generate_superpixels` works on batches, so the
        # single sample is wrapped in a batch of one.
        current_device = sample.device
        assignment_batch = generate_superpixels(
            image_tensor_batch=sample.unsqueeze(0),  # (C,H,W) -> (1,C,H,W)
            k_values=torch.tensor([float(self.n_segments)], device=current_device),
            m_values=torch.tensor([float(self.compactness)], device=current_device),
            denormalize_transform=self.denormalize,
            downsample_factor=self.downsample,
            spix_method=self.spix_method,
            device=current_device,  # Keep the result on the sample's device
        )
        # Drop the batch dimension: (1,1,H_spix,W_spix) -> (1,H_spix,W_spix)
        assignment = assignment_batch.squeeze(0)
        return sample, assignment, target


class Denormalize(torch.nn.Module):
Expand Down Expand Up @@ -145,7 +263,17 @@ def build_dataset(is_train, args):
elif args.data_set == 'IMNET':
root = os.path.join(args.data_path, 'train' if is_train else 'val')
if 'suit' in args.model:
dataset = SpixImageFolder(root, transform=transform, n_segments=args.n_spix_segments, compactness=args.compactness, downsample=args.downsample, spix_method=args.spix_method)
# Forward the adaptive_superpixels flag to the dataset. getattr falls back to False
# when args has no adaptive_superpixels attribute, which keeps the static behaviour.
dataset = SpixImageFolder(
root,
transform=transform,
n_segments=args.n_spix_segments,
compactness=args.compactness,
downsample=args.downsample,
spix_method=args.spix_method,
adaptive_superpixels=getattr(args, 'adaptive_superpixels', False) # Use getattr to avoid error if not present
)
else:
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
Expand Down
Loading