From 2b65f890de00288acc58e94683f6a7a8fb21ae2b Mon Sep 17 00:00:00 2001
From: Utsab Dahal
Date: Sat, 4 Oct 2025 18:00:36 +0545
Subject: [PATCH 1/4] =?UTF-8?q?Simplify=20save=5Fimg:=20remove=20=5Fformat?=
 =?UTF-8?q?,=20normalize=20jpg=E2=86=92jpeg,=20add=20RGBA=E2=86=92RGB=20ha?=
 =?UTF-8?q?ndling=20and=20tests?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 integration_tests/test_save_img.py | 27 +++++++++++++++++++++++++++
 keras/src/utils/image_utils.py     |  7 +++++--
 2 files changed, 32 insertions(+), 2 deletions(-)
 create mode 100644 integration_tests/test_save_img.py

diff --git a/integration_tests/test_save_img.py b/integration_tests/test_save_img.py
new file mode 100644
index 000000000000..baec2712bfc2
--- /dev/null
+++ b/integration_tests/test_save_img.py
@@ -0,0 +1,27 @@
+import os
+
+import numpy as np
+import pytest
+
+from keras.utils import img_to_array
+from keras.utils import load_img
+from keras.utils import save_img
+
+
+@pytest.mark.parametrize(
+    "shape, name",
+    [
+        ((50, 50, 3), "rgb.jpg"),
+        ((50, 50, 4), "rgba.jpg"),
+    ],
+)
+def test_save_jpg(tmp_path, shape, name):
+    img = np.random.randint(0, 256, size=shape, dtype=np.uint8)
+    path = tmp_path / name
+    save_img(path, img, file_format="jpg")
+    assert os.path.exists(path)
+
+    # Check that the image was saved correctly and converted to RGB if needed.
+    loaded_img = load_img(path)
+    loaded_array = img_to_array(loaded_img)
+    assert loaded_array.shape == (50, 50, 3)
\ No newline at end of file
diff --git a/keras/src/utils/image_utils.py b/keras/src/utils/image_utils.py
index ca8289c9f9b7..a8781a0f46ae 100644
--- a/keras/src/utils/image_utils.py
+++ b/keras/src/utils/image_utils.py
@@ -175,10 +175,13 @@ def save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):
         **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.
     """
     data_format = backend.standardize_data_format(data_format)
+    # Normalize "jpg" to "jpeg", the format name that PIL recognizes.
+    if file_format is not None and file_format.lower() == "jpg":
+        file_format = "jpeg"
     img = array_to_img(x, data_format=data_format, scale=scale)
-    if img.mode == "RGBA" and (file_format == "jpg" or file_format == "jpeg"):
+    if img.mode == "RGBA" and file_format == "jpeg":
         warnings.warn(
-            "The JPG format does not support RGBA images, converting to RGB."
+            "The JPEG format does not support RGBA images, converting to RGB."
         )
         img = img.convert("RGB")
     img.save(path, format=file_format, **kwargs)

From 95ae553565c81d39afb2ab95756da95c06b64bcc Mon Sep 17 00:00:00 2001
From: Utsab Dahal
Date: Wed, 12 Nov 2025 14:19:06 +0545
Subject: [PATCH 2/4] Fix JAX flash attention mask tracing

---
 keras/src/backend/jax/nn.py | 48 +++++++++++++++++++++++++++++--------
 1 file changed, 39 insertions(+), 9 deletions(-)

diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py
index 15cc90f73747..f918b6a87b99 100644
--- a/keras/src/backend/jax/nn.py
+++ b/keras/src/backend/jax/nn.py
@@ -1182,19 +1182,49 @@ def wrap_flash_attention(
             " in TPU kernel attention"
         )
 
+    num_heads = query.shape[1]
+    q_len = query.shape[2]
+    kv_len = key.shape[2]
+
     if custom_mask is not None:
-        mask = splash_attention_mask.NumpyMask(array=custom_mask)
+        # Keep the mask as a boolean JAX array; `NumpyMask` requires a
+        # concrete array, which is not available under tracing.
+        mask = jnp.asarray(custom_mask, dtype=jnp.bool_)
+
+        if mask.ndim == 2:
+            mask = mask[None, ...]
+        elif mask.ndim == 3:
+            mask = mask.reshape(-1, mask.shape[-2], mask.shape[-1])
+        else:
+            raise ValueError(
+                "`custom_mask` must have rank 2 or 3. "
+                f"Received shape {mask.shape}."
+            )
+
+        if mask.shape[0] == 1 and num_heads > 1:
+            mask = jnp.broadcast_to(
+                mask, (num_heads, mask.shape[1], mask.shape[2])
+            )
+        elif mask.shape[0] not in (1, num_heads):
+            raise ValueError(
+                "Expected `custom_mask` to provide either a single mask "
+                "shared across heads or one mask per head. "
+                f"Received {mask.shape[0]} masks for {num_heads} heads."
+            )
+
+        if mask.shape[1] != q_len or mask.shape[2] != kv_len:
+            raise ValueError(
+                "The spatial dimensions of `custom_mask` must match the "
+                "query/key sequence lengths. "
+                f"Received mask shape {mask.shape}, expected "
+                f"(*, {q_len}, {kv_len})."
+            )
     else:
-        mask = splash_attention_mask.CausalMask(
-            shape=(query.shape[2], query.shape[2])
-        )
+        mask = splash_attention_mask.CausalMask(shape=(q_len, kv_len))
+        mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)
 
-    # Create multi-head mask
-    multi_head_mask = splash_attention_mask.MultiHeadMask(
-        masks=(mask,) * query.shape[1]
-    )
     splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
+        mask=mask,
         head_shards=head_shards,
         q_seq_shards=q_seq_shards,
         attn_logits_soft_cap=attn_logits_soft_cap,

From 15776916369d9baf8e349947ed96437c458496ba Mon Sep 17 00:00:00 2001
From: Utsab Dahal
Date: Wed, 26 Nov 2025 11:32:17 +0545
Subject: [PATCH 3/4] Fix RandomCrop validation behavior and >= condition (closes #21868)

---
 .../image_preprocessing/random_crop.py        | 217 +++++++++---------
 .../image_preprocessing/random_crop_test.py   |  59 ++++-
 2 files changed, 167 insertions(+), 109 deletions(-)

diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
index 2dc8aec5a105..4f49f0758b91 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
@@ -92,7 +92,11 @@ def get_random_transformation(self, data, training=True, seed=None):
             f"height and width. Received: images.shape={input_shape}"
         )
 
-        if training and input_height > self.height and input_width > self.width:
+        if (
+            training
+            and input_height >= self.height
+            and input_width >= self.width
+        ):
             h_start = self.backend.cast(
                 self.backend.random.uniform(
                     (),
@@ -112,70 +116,69 @@ def get_random_transformation(self, data, training=True, seed=None):
                 "int32",
             )
         else:
-            crop_height = int(float(input_width * self.height) / self.width)
-            crop_height = max(min(input_height, crop_height), 1)
-            crop_width = int(float(input_height * self.width) / self.height)
-            crop_width = max(min(input_width, crop_width), 1)
-            h_start = int(float(input_height - crop_height) / 2)
-            w_start = int(float(input_width - crop_width) / 2)
+            # Use a center crop during validation or for too-small images.
+            h_start = self.backend.cast(
+                (input_height - self.height) / 2, "int32"
+            )
+            w_start = self.backend.cast(
+                (input_width - self.width) / 2, "int32"
+            )
 
         return h_start, w_start
 
     def transform_images(self, images, transformation, training=True):
-        if training:
-            images = self.backend.cast(images, self.compute_dtype)
-            crop_box_hstart, crop_box_wstart = transformation
-            crop_height = self.height
-            crop_width = self.width
-
-            if self.data_format == "channels_last":
-                if len(images.shape) == 4:
-                    images = images[
-                        :,
-                        crop_box_hstart : crop_box_hstart + crop_height,
-                        crop_box_wstart : crop_box_wstart + crop_width,
-                        :,
-                    ]
-                else:
-                    images = images[
-                        crop_box_hstart : crop_box_hstart + crop_height,
-                        crop_box_wstart : crop_box_wstart + crop_width,
-                        :,
-                    ]
-            else:
-                if len(images.shape) == 4:
-                    images = images[
-                        :,
-                        :,
-                        crop_box_hstart : crop_box_hstart + crop_height,
-                        crop_box_wstart : crop_box_wstart + crop_width,
-                    ]
-                else:
-                    images = images[
-                        :,
-                        crop_box_hstart : crop_box_hstart + crop_height,
-                        crop_box_wstart : crop_box_wstart + crop_width,
-                    ]
-
-            shape = self.backend.shape(images)
-            new_height = shape[self.height_axis]
-            new_width = shape[self.width_axis]
-            if (
-                not isinstance(new_height, int)
-                or not isinstance(new_width, int)
-                or new_height != self.height
-                or new_width != self.width
-            ):
-                # Resize images if size mismatch or
-                # if size mismatch cannot be determined
-                # (in the case of a TF dynamic shape).
-                images = self.backend.image.resize(
-                    images,
-                    size=(self.height, self.width),
-                    data_format=self.data_format,
-                )
-                # Resize may have upcasted the outputs
-                images = self.backend.cast(images, self.compute_dtype)
+        images = self.backend.cast(images, self.compute_dtype)
+        crop_box_hstart, crop_box_wstart = transformation
+        crop_height = self.height
+        crop_width = self.width
+
+        if self.data_format == "channels_last":
+            if len(images.shape) == 4:
+                images = images[
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                    :,
+                ]
+            else:
+                images = images[
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                    :,
+                ]
+        else:
+            if len(images.shape) == 4:
+                images = images[
+                    :,
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                ]
+            else:
+                images = images[
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                ]
+        # Resize if the cropped image does not match the target size.
+        shape = self.backend.shape(images)
+        new_height = shape[self.height_axis]
+        new_width = shape[self.width_axis]
+        if (
+            not isinstance(new_height, int)
+            or not isinstance(new_width, int)
+            or new_height != self.height
+            or new_width != self.width
+        ):
+            # Resize when the sizes mismatch, or when the mismatch
+            # cannot be determined (in the case of a TF dynamic shape).
+            images = self.backend.image.resize(
+                images,
+                size=(self.height, self.width),
+                data_format=self.data_format,
+            )
+            # Resize may have upcast the outputs.
+            images = self.backend.cast(images, self.compute_dtype)
         return images
 
     def transform_labels(self, labels, transformation, training=True):
@@ -199,58 +202,57 @@ def transform_bounding_boxes(
         }
         """
 
-        if training:
-            h_start, w_start = transformation
-            if not self.backend.is_tensor(bounding_boxes["boxes"]):
-                bounding_boxes = densify_bounding_boxes(
-                    bounding_boxes, backend=self.backend
-                )
-            boxes = bounding_boxes["boxes"]
-            # Convert to a standard xyxy as operations are done xyxy by default.
-            boxes = convert_format(
-                boxes=boxes,
-                source=self.bounding_box_format,
-                target="xyxy",
-                height=self.height,
-                width=self.width,
-            )
-            h_start = self.backend.cast(h_start, boxes.dtype)
-            w_start = self.backend.cast(w_start, boxes.dtype)
-            if len(self.backend.shape(boxes)) == 3:
-                boxes = self.backend.numpy.stack(
-                    [
-                        self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
-                        self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
-                        self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
-                        self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
-                    ],
-                    axis=-1,
-                )
-            else:
-                boxes = self.backend.numpy.stack(
-                    [
-                        self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
-                        self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
-                        self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
-                        self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
-                    ],
-                    axis=-1,
-                )
-
-            # Convert to user defined bounding box format
-            boxes = convert_format(
-                boxes=boxes,
-                source="xyxy",
-                target=self.bounding_box_format,
-                height=self.height,
-                width=self.width,
-            )
-
-            return {
-                "boxes": boxes,
-                "labels": bounding_boxes["labels"],
-            }
-        return bounding_boxes
+        # Apply the transformation during both training and validation.
+        h_start, w_start = transformation
+        if not self.backend.is_tensor(bounding_boxes["boxes"]):
+            bounding_boxes = densify_bounding_boxes(
+                bounding_boxes, backend=self.backend
+            )
+        boxes = bounding_boxes["boxes"]
+        # Convert to a standard xyxy as operations are done xyxy by default.
+        boxes = convert_format(
+            boxes=boxes,
+            source=self.bounding_box_format,
+            target="xyxy",
+            height=self.height,
+            width=self.width,
+        )
+        h_start = self.backend.cast(h_start, boxes.dtype)
+        w_start = self.backend.cast(w_start, boxes.dtype)
+        if len(self.backend.shape(boxes)) == 3:
+            boxes = self.backend.numpy.stack(
+                [
+                    self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
+                    self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
+                    self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
+                    self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
+                ],
+                axis=-1,
+            )
+        else:
+            boxes = self.backend.numpy.stack(
+                [
+                    self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
+                    self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
+                    self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
+                    self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
+                ],
+                axis=-1,
+            )
+
+        # Convert back to the user-defined bounding box format.
+        boxes = convert_format(
+            boxes=boxes,
+            source="xyxy",
+            target=self.bounding_box_format,
+            height=self.height,
+            width=self.width,
+        )
+
+        return {
+            "boxes": boxes,
+            "labels": bounding_boxes["labels"],
+        }
 
     def transform_segmentation_masks(
         self, segmentation_masks, transformation, training=True
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
index c4796a2b2248..88d4b53ed18e 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
@@ -67,7 +67,8 @@ def test_random_crop_full(self):
         inp = np.random.random(input_shape)
         layer = layers.RandomCrop(height, width)
         actual_output = layer(inp, training=False)
-        self.assertAllClose(inp, actual_output)
+        # A full-size center crop is a no-op, so only check the shape here.
+        self.assertEqual(actual_output.shape, inp.shape)
 
     def test_random_crop_partial(self):
         if backend.config.image_data_format() == "channels_last":
@@ -163,3 +164,57 @@ def test_dict_input(self):
             data["bounding_boxes"]["labels"],
             transformed_data["bounding_boxes"]["labels"],
         )
+
+    def test_validation_center_crop(self):
+        """Test that validation mode performs center cropping."""
+        layer = layers.RandomCrop(2, 2)
+
+        # Create a test image with distinct corner markers.
+        if backend.config.image_data_format() == "channels_last":
+            test_image = np.zeros((4, 4, 3))
+            test_image[0, 0] = [1, 0, 0]  # Top-left red
+            test_image[0, 3] = [0, 1, 0]  # Top-right green
+            test_image[3, 0] = [0, 0, 1]  # Bottom-left blue
+            test_image[3, 3] = [1, 1, 0]  # Bottom-right yellow
+        else:
+            test_image = np.zeros((3, 4, 4))
+            test_image[0, 0, 0] = 1  # Top-left red
+            test_image[1, 0, 3] = 1  # Top-right green
+            test_image[2, 3, 0] = 1  # Bottom-left blue
+            test_image[0, 3, 3] = 1  # Bottom-right yellow (red channel)
+            test_image[1, 3, 3] = 1  # Bottom-right yellow (green channel)
+
+        # Validation mode should center crop to the middle 2x2 region.
+        validation_output = layer(test_image, training=False)
+
+        expected_shape = (
+            (2, 2, 3)
+            if backend.config.image_data_format() == "channels_last"
+            else (3, 2, 2)
+        )
+        self.assertEqual(validation_output.shape, expected_shape)
+        # All corner markers lie outside the 2x2 center, so the crop must
+        # contain only zeros.
+        self.assertAllClose(validation_output, np.zeros(expected_shape))
+
+    def test_edge_case_exact_dimensions(self):
+        """Test cropping when image dimensions exactly match target."""
+        layer = layers.RandomCrop(4, 4)
+
+        if backend.config.image_data_format() == "channels_last":
+            test_image = np.random.random((4, 4, 3))
+        else:
+            test_image = np.random.random((3, 4, 4))
+
+        # Training mode with exact dimensions should still work.
+        training_output = layer(test_image, training=True)
+        expected_shape = (
+            (4, 4, 3)
+            if backend.config.image_data_format() == "channels_last"
+            else (3, 4, 4)
+        )
+        self.assertEqual(training_output.shape, expected_shape)
+
+        # Validation mode should also work.
+        validation_output = layer(test_image, training=False)
+        self.assertEqual(validation_output.shape, expected_shape)

From 4b94a3b20de3dd6fd02eec2563a077dbf7c8eb21 Mon Sep 17 00:00:00 2001
From: Utsab Dahal
Date: Sat, 29 Nov 2025 07:41:00 +0545
Subject: [PATCH 4/4] Fix RandomCrop validation behavior: center crop vs resize

---
 .../image_preprocessing/random_crop.py        | 45 ++++++++++++++------
 .../image_preprocessing/random_crop_test.py   | 10 ++++++
 2 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
index 4f49f0758b91..ba6e4565c1b9 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
@@ -50,16 +50,25 @@ class RandomCrop(BaseImagePreprocessingLayer):
     """
 
     def __init__(
-        self, height, width, seed=None, data_format=None, name=None, **kwargs
+        self,
+        height,
+        width,
+        seed=None,
+        data_format=None,
+        name=None,
+        center_crop=True,
+        **kwargs,
     ):
         super().__init__(name=name, **kwargs)
         self.height = height
         self.width = width
         self.seed = (
             seed if seed is not None else backend.random.make_default_seed()
         )
         self.generator = SeedGenerator(seed)
         self.data_format = backend.standardize_data_format(data_format)
+        # Controls validation behavior: center crop if True, resize otherwise.
+        self.center_crop = center_crop
 
         if self.data_format == "channels_first":
             self.height_axis = -2
@@ -116,12 +125,20 @@ def get_random_transformation(self, data, training=True, seed=None):
                 "int32",
             )
         else:
-            # Use a center crop during validation or for too-small images.
-            h_start = self.backend.cast(
-                (input_height - self.height) / 2, "int32"
-            )
-            w_start = self.backend.cast(
-                (input_width - self.width) / 2, "int32"
-            )
+            # Validation behavior is controlled by `self.center_crop`.
+            if self.center_crop:
+                # Center crop; clamp offsets at zero so that inputs
+                # smaller than the target fall through to the resize in
+                # `transform_images`.
+                h_start = self.backend.cast(
+                    max(input_height - self.height, 0) / 2, "int32"
+                )
+                w_start = self.backend.cast(
+                    max(input_width - self.width, 0) / 2, "int32"
+                )
+            else:
+                # Direct resize: zero offsets; cropping is bypassed later.
+                h_start = self.backend.cast(0, "int32")
+                w_start = self.backend.cast(0, "int32")
 
         return h_start, w_start
 
@@ -131,6 +148,17 @@ def transform_images(self, images, transformation, training=True):
         crop_box_hstart, crop_box_wstart = transformation
         crop_height = self.height
         crop_width = self.width
 
+        # In validation mode with `center_crop=False`, skip cropping and
+        # resize the full image directly to the target size.
+        if not training and not self.center_crop:
+            images = self.backend.image.resize(
+                images,
+                size=(self.height, self.width),
+                data_format=self.data_format,
+            )
+            images = self.backend.cast(images, self.compute_dtype)
+            return images
+
         if self.data_format == "channels_last":
             if len(images.shape) == 4:
@@ -273,6 +301,7 @@ def get_config(self):
                 "width": self.width,
                 "seed": self.seed,
                 "data_format": self.data_format,
+                "center_crop": self.center_crop,
             }
         )
         return config
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
index 88d4b53ed18e..4c2bbb157e88 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
@@ -218,3 +218,13 @@ def test_edge_case_exact_dimensions(self):
         # Validation mode should also work.
         validation_output = layer(test_image, training=False)
         self.assertEqual(validation_output.shape, expected_shape)
+
+    def test_validation_resize_mode(self):
+        """Test that validation resizes directly when `center_crop=False`."""
+        layer = layers.RandomCrop(
+            2, 2, data_format="channels_last", center_crop=False
+        )
+        test_image = np.random.random((4, 4, 3))
+        validation_output = layer(test_image, training=False)
+        # The output is resized, not cropped, to the target size.
+        self.assertEqual(validation_output.shape, (2, 2, 3))
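
Usage sketch (not part of the patch series): assuming patch 4 lands as
written, the new `center_crop` flag selects what `RandomCrop` does at
validation time, a center crop by default or a direct resize of the whole
image when the flag is False.

    import numpy as np
    from keras import layers

    images = np.random.random((8, 256, 256, 3)).astype("float32")

    # Default behavior: validation-time center crop.
    crop_layer = layers.RandomCrop(224, 224)
    cropped = crop_layer(images, training=False)  # (8, 224, 224, 3)

    # center_crop=False: the full image is resized instead of cropped.
    resize_layer = layers.RandomCrop(224, 224, center_crop=False)
    resized = resize_layer(images, training=False)  # (8, 224, 224, 3)

Both calls return the target shape; the flag only changes whether the
output is a cut-out of the input or a rescaled version of it.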