Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions keras/src/backend/jax/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,19 +1182,45 @@ def wrap_flash_attention(
" in TPU kernel attention"
)

num_heads = query.shape[1]
q_len = query.shape[2]
kv_len = key.shape[2]

if custom_mask is not None:
mask = splash_attention_mask.NumpyMask(array=custom_mask)
mask = jnp.asarray(custom_mask, dtype=jnp.bool_)

if mask.ndim == 2:
mask = mask[None, ...]
elif mask.ndim == 3:
mask = mask.reshape(-1, mask.shape[-2], mask.shape[-1])
else:
raise ValueError(
"`custom_mask` must have rank 2 or 3. "
f"Received shape {mask.shape}."
)

if mask.shape[0] == 1 and num_heads > 1:
mask = jnp.broadcast_to(mask, (num_heads, mask.shape[1], mask.shape[2]))
elif mask.shape[0] not in (1, num_heads):
raise ValueError(
"Expected `custom_mask` to provide either a single mask "
"shared across heads or one mask per head. "
f"Received {mask.shape[0]} masks for {num_heads} heads."
)

if mask.shape[1] != q_len or mask.shape[2] != kv_len:
raise ValueError(
"The spatial dimensions of `custom_mask` must match the "
"query/key sequence lengths. "
f"Received mask shape {mask.shape}, expected "
f"(*, {q_len}, {kv_len})."
)
else:
mask = splash_attention_mask.CausalMask(
shape=(query.shape[2], query.shape[2])
)
mask = splash_attention_mask.CausalMask(shape=(q_len, kv_len))
mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)

# Create multi-head mask
multi_head_mask = splash_attention_mask.MultiHeadMask(
masks=(mask,) * query.shape[1]
)
splash_kernel = splash_attention_kernel.make_splash_mha(
mask=multi_head_mask,
mask=mask,
head_shards=head_shards,
q_seq_shards=q_seq_shards,
attn_logits_soft_cap=attn_logits_soft_cap,
Expand Down
242 changes: 131 additions & 111 deletions keras/src/layers/preprocessing/image_preprocessing/random_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,23 @@ class RandomCrop(BaseImagePreprocessingLayer):
"""

def __init__(
self, height, width, seed=None, data_format=None, name=None, **kwargs
self,
height,
width,
seed=None,
data_format=None,
name=None,
center_crop=True,
**kwargs,
):
super().__init__(name=name, **kwargs)
self.height = height
self.width = width
self.seed = (
seed if seed is not None else backend.random.make_default_seed()
)
self.seed = seed if seed is not None else backend.random.make_default_seed()
self.generator = SeedGenerator(seed)
self.data_format = backend.standardize_data_format(data_format)
# New flag to control validation behavior: center crop if True, otherwise resize.
self.center_crop = center_crop

if self.data_format == "channels_first":
self.height_axis = -2
Expand Down Expand Up @@ -92,7 +99,7 @@ def get_random_transformation(self, data, training=True, seed=None):
f"height and width. Received: images.shape={input_shape}"
)

if training and input_height > self.height and input_width > self.width:
if training and input_height >= self.height and input_width >= self.width:
h_start = self.backend.cast(
self.backend.random.uniform(
(),
Expand All @@ -112,70 +119,83 @@ def get_random_transformation(self, data, training=True, seed=None):
"int32",
)
else:
crop_height = int(float(input_width * self.height) / self.width)
crop_height = max(min(input_height, crop_height), 1)
crop_width = int(float(input_height * self.width) / self.height)
crop_width = max(min(input_width, crop_width), 1)
h_start = int(float(input_height - crop_height) / 2)
w_start = int(float(input_width - crop_width) / 2)
# Validation (training=False) behavior based on self.center_crop flag
if self.center_crop:
# Center crop
h_start = self.backend.cast((input_height - self.height) / 2, "int32")
w_start = self.backend.cast((input_width - self.width) / 2, "int32")
else:
# Direct resize: set offsets to zero; cropping will be bypassed later
h_start = self.backend.cast(0, "int32")
w_start = self.backend.cast(0, "int32")

return h_start, w_start

def transform_images(self, images, transformation, training=True):
if training:
images = self.backend.cast(images, self.compute_dtype)
crop_box_hstart, crop_box_wstart = transformation
crop_height = self.height
crop_width = self.width

# If we are in validation mode and center_crop is False, skip cropping and directly resize.
if not training and not self.center_crop:
# Direct resize to target size
images = self.backend.image.resize(
images,
size=(self.height, self.width),
data_format=self.data_format,
)
images = self.backend.cast(images, self.compute_dtype)
crop_box_hstart, crop_box_wstart = transformation
crop_height = self.height
crop_width = self.width
return images

if self.data_format == "channels_last":
if len(images.shape) == 4:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
:,
]
else:
images = images[
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
:,
]
if self.data_format == "channels_last":
if len(images.shape) == 4:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
:,
]
else:
if len(images.shape) == 4:
images = images[
:,
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
else:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]

shape = self.backend.shape(images)
new_height = shape[self.height_axis]
new_width = shape[self.width_axis]
if (
not isinstance(new_height, int)
or not isinstance(new_width, int)
or new_height != self.height
or new_width != self.width
):
# Resize images if size mismatch or
# if size mismatch cannot be determined
# (in the case of a TF dynamic shape).
images = self.backend.image.resize(
images,
size=(self.height, self.width),
data_format=self.data_format,
)
# Resize may have upcasted the outputs
images = self.backend.cast(images, self.compute_dtype)
images = images[
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
:,
]
else:
if len(images.shape) == 4:
images = images[
:,
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
else:
images = images[
:,
crop_box_hstart : crop_box_hstart + crop_height,
crop_box_wstart : crop_box_wstart + crop_width,
]
# Resize if the cropped image doesn't match target size
shape = self.backend.shape(images)
new_height = shape[self.height_axis]
new_width = shape[self.width_axis]
if (
not isinstance(new_height, int)
or not isinstance(new_width, int)
or new_height != self.height
or new_width != self.width
):
# Resize images if size mismatch or
# if size mismatch cannot be determined
# (in the case of a TF dynamic shape).
images = self.backend.image.resize(
images,
size=(self.height, self.width),
data_format=self.data_format,
)
# Resize may have upcasted the outputs
images = self.backend.cast(images, self.compute_dtype)
return images

def transform_labels(self, labels, transformation, training=True):
Expand All @@ -199,58 +219,57 @@ def transform_bounding_boxes(
}
"""

if training:
h_start, w_start = transformation
if not self.backend.is_tensor(bounding_boxes["boxes"]):
bounding_boxes = densify_bounding_boxes(
bounding_boxes, backend=self.backend
)
boxes = bounding_boxes["boxes"]
# Convert to a standard xyxy as operations are done xyxy by default.
boxes = convert_format(
boxes=boxes,
source=self.bounding_box_format,
target="xyxy",
height=self.height,
width=self.width,
# Apply transformation for both training and validation
h_start, w_start = transformation
if not self.backend.is_tensor(bounding_boxes["boxes"]):
bounding_boxes = densify_bounding_boxes(
bounding_boxes, backend=self.backend
)
h_start = self.backend.cast(h_start, boxes.dtype)
w_start = self.backend.cast(w_start, boxes.dtype)
if len(self.backend.shape(boxes)) == 3:
boxes = self.backend.numpy.stack(
[
self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
],
axis=-1,
)
else:
boxes = self.backend.numpy.stack(
[
self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
],
axis=-1,
)

# Convert to user defined bounding box format
boxes = convert_format(
boxes=boxes,
source="xyxy",
target=self.bounding_box_format,
height=self.height,
width=self.width,
boxes = bounding_boxes["boxes"]
# Convert to a standard xyxy as operations are done xyxy by default.
boxes = convert_format(
boxes=boxes,
source=self.bounding_box_format,
target="xyxy",
height=self.height,
width=self.width,
)
h_start = self.backend.cast(h_start, boxes.dtype)
w_start = self.backend.cast(w_start, boxes.dtype)
if len(self.backend.shape(boxes)) == 3:
boxes = self.backend.numpy.stack(
[
self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
],
axis=-1,
)
else:
boxes = self.backend.numpy.stack(
[
self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
],
axis=-1,
)

return {
"boxes": boxes,
"labels": bounding_boxes["labels"],
}
return bounding_boxes
# Convert to user defined bounding box format
boxes = convert_format(
boxes=boxes,
source="xyxy",
target=self.bounding_box_format,
height=self.height,
width=self.width,
)

return {
"boxes": boxes,
"labels": bounding_boxes["labels"],
}

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
Expand All @@ -271,6 +290,7 @@ def get_config(self):
"width": self.width,
"seed": self.seed,
"data_format": self.data_format,
"center_crop": self.center_crop,
}
)
return config
Loading
Loading