diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py
index 15cc90f73747..f918b6a87b99 100644
--- a/keras/src/backend/jax/nn.py
+++ b/keras/src/backend/jax/nn.py
@@ -1182,19 +1182,45 @@ def wrap_flash_attention(
             " in TPU kernel attention"
         )

+    num_heads = query.shape[1]
+    q_len = query.shape[2]
+    kv_len = key.shape[2]
+
     if custom_mask is not None:
-        mask = splash_attention_mask.NumpyMask(array=custom_mask)
+        mask = jnp.asarray(custom_mask, dtype=jnp.bool_)
+
+        if mask.ndim == 2:
+            mask = mask[None, ...]
+        elif mask.ndim == 3:
+            mask = mask.reshape(-1, mask.shape[-2], mask.shape[-1])
+        else:
+            raise ValueError(
+                "`custom_mask` must have rank 2 or 3. "
+                f"Received shape {mask.shape}."
+            )
+
+        if mask.shape[0] == 1 and num_heads > 1:
+            mask = jnp.broadcast_to(
+                mask, (num_heads, mask.shape[1], mask.shape[2])
+            )
+        elif mask.shape[0] not in (1, num_heads):
+            raise ValueError(
+                "Expected `custom_mask` to provide either a single mask "
+                "shared across heads or one mask per head. "
+                f"Received {mask.shape[0]} masks for {num_heads} heads."
+            )
+
+        if mask.shape[1] != q_len or mask.shape[2] != kv_len:
+            raise ValueError(
+                "The spatial dimensions of `custom_mask` must match the "
+                "query/key sequence lengths. "
+                f"Received mask shape {mask.shape}, expected "
+                f"(*, {q_len}, {kv_len})."
+            )
+
+        # `make_splash_mha` expects a `MultiHeadMask`, not a raw array,
+        # so wrap each per-head boolean slice in a `NumpyMask`.
+        mask = splash_attention_mask.MultiHeadMask(
+            masks=tuple(
+                splash_attention_mask.NumpyMask(array=jax.device_get(head))
+                for head in mask
+            )
+        )
     else:
-        mask = splash_attention_mask.CausalMask(
-            shape=(query.shape[2], query.shape[2])
-        )
+        mask = splash_attention_mask.CausalMask(shape=(q_len, kv_len))
+        mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * num_heads)

-    # Create multi-head mask
-    multi_head_mask = splash_attention_mask.MultiHeadMask(
-        masks=(mask,) * query.shape[1]
-    )
     splash_kernel = splash_attention_kernel.make_splash_mha(
-        mask=multi_head_mask,
+        mask=mask,
         head_shards=head_shards,
         q_seq_shards=q_seq_shards,
         attn_logits_soft_cap=attn_logits_soft_cap,
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
index 2dc8aec5a105..ba6e4565c1b9 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop.py
@@ -50,16 +50,23 @@ class RandomCrop(BaseImagePreprocessingLayer):
     """

     def __init__(
-        self, height, width, seed=None, data_format=None, name=None, **kwargs
+        self,
+        height,
+        width,
+        seed=None,
+        data_format=None,
+        name=None,
+        center_crop=True,
+        **kwargs,
     ):
         super().__init__(name=name, **kwargs)
         self.height = height
         self.width = width
         self.seed = (
             seed if seed is not None else backend.random.make_default_seed()
         )
         self.generator = SeedGenerator(seed)
         self.data_format = backend.standardize_data_format(data_format)
+        # Controls inference behavior: center crop when True, direct
+        # resize to the target size when False.
+        self.center_crop = center_crop

         if self.data_format == "channels_first":
             self.height_axis = -2
@@ -92,7 +99,7 @@ def get_random_transformation(self, data, training=True, seed=None):
                 f"height and width. "
                f"Received: images.shape={input_shape}"
            )

-        if training and input_height > self.height and input_width > self.width:
+        if (
+            training
+            and input_height >= self.height
+            and input_width >= self.width
+        ):
             h_start = self.backend.cast(
                 self.backend.random.uniform(
                     (),
@@ -112,70 +119,83 @@ def get_random_transformation(self, data, training=True, seed=None):
                 "int32",
             )
         else:
-            crop_height = int(float(input_width * self.height) / self.width)
-            crop_height = max(min(input_height, crop_height), 1)
-            crop_width = int(float(input_height * self.width) / self.height)
-            crop_width = max(min(input_width, crop_width), 1)
-            h_start = int(float(input_height - crop_height) / 2)
-            w_start = int(float(input_width - crop_width) / 2)
+            # Taken at inference time (training=False) and when the
+            # input is smaller than the target size; behavior is
+            # controlled by the `center_crop` flag.
+            if self.center_crop:
+                # Center crop; clamp the offsets at 0 in case the
+                # input is smaller than the target size.
+                h_start = self.backend.cast(
+                    self.backend.numpy.maximum(
+                        (input_height - self.height) // 2, 0
+                    ),
+                    "int32",
+                )
+                w_start = self.backend.cast(
+                    self.backend.numpy.maximum(
+                        (input_width - self.width) // 2, 0
+                    ),
+                    "int32",
+                )
+            else:
+                # Direct resize: zero the offsets; cropping is bypassed
+                # in `transform_images`.
+                h_start = self.backend.cast(0, "int32")
+                w_start = self.backend.cast(0, "int32")

         return h_start, w_start

     def transform_images(self, images, transformation, training=True):
-        if training:
-            images = self.backend.cast(images, self.compute_dtype)
-            crop_box_hstart, crop_box_wstart = transformation
-            crop_height = self.height
-            crop_width = self.width
-
-            if self.data_format == "channels_last":
-                if len(images.shape) == 4:
-                    images = images[
-                        :,
-                        crop_box_hstart : crop_box_hstart + crop_height,
-                        crop_box_wstart : crop_box_wstart + crop_width,
-                        :,
-                    ]
-                else:
-                    images = images[
-                        crop_box_hstart : crop_box_hstart + crop_height,
-                        crop_box_wstart : crop_box_wstart + crop_width,
-                        :,
-                    ]
-            else:
-                if len(images.shape) == 4:
-                    images = images[
-                        :,
-                        :,
-                        crop_box_hstart : crop_box_hstart + crop_height,
-                        crop_box_wstart : crop_box_wstart + crop_width,
-                    ]
-                else:
-                    images = images[
-                        :,
-                        crop_box_hstart : crop_box_hstart + crop_height,
-                        crop_box_wstart : crop_box_wstart + crop_width,
-                    ]
-
-            shape = self.backend.shape(images)
-            new_height = shape[self.height_axis]
-            new_width = shape[self.width_axis]
-            if (
-                not isinstance(new_height, int)
-                or not isinstance(new_width, int)
-                or new_height != self.height
-                or new_width != self.width
-            ):
-                # Resize images if size mismatch or
-                # if size mismatch cannot be determined
-                # (in the case of a TF dynamic shape).
-                images = self.backend.image.resize(
-                    images,
-                    size=(self.height, self.width),
-                    data_format=self.data_format,
-                )
-                # Resize may have upcasted the outputs
-                images = self.backend.cast(images, self.compute_dtype)
+        images = self.backend.cast(images, self.compute_dtype)
+        crop_box_hstart, crop_box_wstart = transformation
+        crop_height = self.height
+        crop_width = self.width
+
+        # In inference mode with `center_crop=False`, skip cropping and
+        # resize directly to the target size.
+        if not training and not self.center_crop:
+            images = self.backend.image.resize(
+                images,
+                size=(self.height, self.width),
+                data_format=self.data_format,
+            )
+            # Resize may have upcast the outputs.
+            images = self.backend.cast(images, self.compute_dtype)
+            return images
+
+        if self.data_format == "channels_last":
+            if len(images.shape) == 4:
+                images = images[
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                    :,
+                ]
+            else:
+                images = images[
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                    :,
+                ]
+        else:
+            if len(images.shape) == 4:
+                images = images[
+                    :,
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                ]
+            else:
+                images = images[
+                    :,
+                    crop_box_hstart : crop_box_hstart + crop_height,
+                    crop_box_wstart : crop_box_wstart + crop_width,
+                ]
+
+        # Resize if the cropped image doesn't match the target size,
+        # or if the mismatch cannot be determined (in the case of a
+        # TF dynamic shape).
+        shape = self.backend.shape(images)
+        new_height = shape[self.height_axis]
+        new_width = shape[self.width_axis]
+        if (
+            not isinstance(new_height, int)
+            or not isinstance(new_width, int)
+            or new_height != self.height
+            or new_width != self.width
+        ):
+            images = self.backend.image.resize(
+                images,
+                size=(self.height, self.width),
+                data_format=self.data_format,
+            )
+            # Resize may have upcast the outputs.
+            images = self.backend.cast(images, self.compute_dtype)
         return images

     def transform_labels(self, labels, transformation, training=True):
@@ -199,58 +219,57 @@ def transform_bounding_boxes(
         }
         """

-        if training:
-            h_start, w_start = transformation
-            if not self.backend.is_tensor(bounding_boxes["boxes"]):
-                bounding_boxes = densify_bounding_boxes(
-                    bounding_boxes, backend=self.backend
-                )
-            boxes = bounding_boxes["boxes"]
-            # Convert to a standard xyxy as operations are done xyxy
-            # by default.
-            boxes = convert_format(
-                boxes=boxes,
-                source=self.bounding_box_format,
-                target="xyxy",
-                height=self.height,
-                width=self.width,
-            )
-            h_start = self.backend.cast(h_start, boxes.dtype)
-            w_start = self.backend.cast(w_start, boxes.dtype)
-            if len(self.backend.shape(boxes)) == 3:
-                boxes = self.backend.numpy.stack(
-                    [
-                        self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
-                        self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
-                        self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
-                        self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
-                    ],
-                    axis=-1,
-                )
-            else:
-                boxes = self.backend.numpy.stack(
-                    [
-                        self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
-                        self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
-                        self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
-                        self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
-                    ],
-                    axis=-1,
-                )
-
-            # Convert to user defined bounding box format
-            boxes = convert_format(
-                boxes=boxes,
-                source="xyxy",
-                target=self.bounding_box_format,
-                height=self.height,
-                width=self.width,
-            )
-
-            return {
-                "boxes": boxes,
-                "labels": bounding_boxes["labels"],
-            }
-        return bounding_boxes
+        # Apply the crop offsets in both training and inference, since
+        # inference now center-crops rather than passing inputs through.
+        h_start, w_start = transformation
+        if not self.backend.is_tensor(bounding_boxes["boxes"]):
+            bounding_boxes = densify_bounding_boxes(
+                bounding_boxes, backend=self.backend
+            )
+        boxes = bounding_boxes["boxes"]
+        # Convert to a standard xyxy as operations are done xyxy by
+        # default.
+        boxes = convert_format(
+            boxes=boxes,
+            source=self.bounding_box_format,
+            target="xyxy",
+            height=self.height,
+            width=self.width,
+        )
+        h_start = self.backend.cast(h_start, boxes.dtype)
+        w_start = self.backend.cast(w_start, boxes.dtype)
+        if len(self.backend.shape(boxes)) == 3:
+            boxes = self.backend.numpy.stack(
+                [
+                    self.backend.numpy.maximum(boxes[:, :, 0] - h_start, 0),
+                    self.backend.numpy.maximum(boxes[:, :, 1] - w_start, 0),
+                    self.backend.numpy.maximum(boxes[:, :, 2] - h_start, 0),
+                    self.backend.numpy.maximum(boxes[:, :, 3] - w_start, 0),
+                ],
+                axis=-1,
+            )
+        else:
+            boxes = self.backend.numpy.stack(
+                [
+                    self.backend.numpy.maximum(boxes[:, 0] - h_start, 0),
+                    self.backend.numpy.maximum(boxes[:, 1] - w_start, 0),
+                    self.backend.numpy.maximum(boxes[:, 2] - h_start, 0),
+                    self.backend.numpy.maximum(boxes[:, 3] - w_start, 0),
+                ],
+                axis=-1,
+            )

+        # Convert to the user-defined bounding box format.
+        boxes = convert_format(
+            boxes=boxes,
+            source="xyxy",
+            target=self.bounding_box_format,
+            height=self.height,
+            width=self.width,
+        )
+
+        return {
+            "boxes": boxes,
+            "labels": bounding_boxes["labels"],
+        }

     def transform_segmentation_masks(
         self, segmentation_masks, transformation, training=True
@@ -271,6 +290,7 @@ def get_config(self):
                 "width": self.width,
                 "seed": self.seed,
                 "data_format": self.data_format,
+                "center_crop": self.center_crop,
             }
         )
         return config
diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
index c4796a2b2248..4c2bbb157e88 100644
--- a/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
+++ b/keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py
@@ -67,7 +67,8 @@ def test_random_crop_full(self):
         inp = np.random.random(input_shape)
         layer = layers.RandomCrop(height, width)
         actual_output = layer(inp, training=False)
-        self.assertAllClose(inp, actual_output)
+        # A full-size center crop returns the input unchanged.
+        self.assertAllClose(inp, actual_output)

     def test_random_crop_partial(self):
         if backend.config.image_data_format() == "channels_last":
@@ -163,3 +164,65 @@ def test_dict_input(self):
             data["bounding_boxes"]["labels"],
             transformed_data["bounding_boxes"]["labels"],
         )
+
+    def test_validation_center_crop(self):
+        """Test that validation mode performs center cropping."""
+        layer = layers.RandomCrop(2, 2)
+
+        # Build a test image whose corners carry distinct values and
+        # whose center 2x2 region is all zeros.
+        if backend.config.image_data_format() == "channels_last":
+            test_image = np.zeros((4, 4, 3))
+            test_image[0, 0] = [1, 0, 0]  # Top-left: red
+            test_image[0, 3] = [0, 1, 0]  # Top-right: green
+            test_image[3, 0] = [0, 0, 1]  # Bottom-left: blue
+            test_image[3, 3] = [1, 1, 0]  # Bottom-right: yellow
+        else:
+            test_image = np.zeros((3, 4, 4))
+            test_image[0, 0, 0] = 1  # Top-left: red
+            test_image[1, 0, 3] = 1  # Top-right: green
+            test_image[2, 3, 0] = 1  # Bottom-left: blue
+            test_image[0, 3, 3] = 1  # Bottom-right: yellow (red channel)
+            test_image[1, 3, 3] = 1  # Bottom-right: yellow (green channel)
+
+        # Validation mode should center crop.
+        validation_output = layer(test_image, training=False)
+
+        expected_shape = (
+            (2, 2, 3)
+            if backend.config.image_data_format() == "channels_last"
+            else (3, 2, 2)
+        )
+        self.assertEqual(validation_output.shape, expected_shape)
+        # The center 2x2 region of the input is all zeros, so a correct
+        # center crop contains none of the marked corner values.
+        self.assertAllClose(validation_output, np.zeros(expected_shape))
+
+    def test_edge_case_exact_dimensions(self):
+        """Test cropping when image dimensions exactly match target."""
+        layer = layers.RandomCrop(4, 4)
+
+        if backend.config.image_data_format() == "channels_last":
+            test_image = np.random.random((4, 4, 3))
+        else:
+            test_image = np.random.random((3, 4, 4))
+
+        # Training mode with exact dimensions should still work.
+        training_output = layer(test_image, training=True)
+        expected_shape = (
+            (4, 4, 3)
+            if backend.config.image_data_format() == "channels_last"
+            else (3, 4, 4)
+        )
+        self.assertEqual(training_output.shape, expected_shape)
+
+        # Validation mode should also work.
+        validation_output = layer(test_image, training=False)
+        self.assertEqual(validation_output.shape, expected_shape)
+
+    def test_validation_resize_mode(self):
+        """Validation mode resizes directly when `center_crop=False`."""
+        layer = layers.RandomCrop(
+            2, 2, data_format="channels_last", center_crop=False
+        )
+        test_image = np.random.random((4, 4, 3))
+        validation_output = layer(test_image, training=False)
+        # Output should be resized to the target size.
+        self.assertEqual(validation_output.shape, (2, 2, 3))
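
Note (illustrative, not part of the patch): the `custom_mask` handling in
wrap_flash_attention above normalizes a rank-2 or rank-3 mask into one
boolean mask per head before the splash kernel is built. A standalone
sketch of that normalization, written against plain `jax.numpy` with a
hypothetical helper name:

    import jax.numpy as jnp

    def normalize_attention_mask(custom_mask, num_heads, q_len, kv_len):
        """Return a boolean (num_heads, q_len, kv_len) mask array."""
        mask = jnp.asarray(custom_mask, dtype=jnp.bool_)
        if mask.ndim == 2:  # (q_len, kv_len), shared across heads
            mask = mask[None, ...]
        elif mask.ndim == 3:  # (num_heads, q_len, kv_len)
            mask = mask.reshape(-1, mask.shape[-2], mask.shape[-1])
        else:
            raise ValueError(f"Expected rank 2 or 3, got {mask.shape}.")
        # A single shared mask is broadcast to one mask per head.
        if mask.shape[0] == 1 and num_heads > 1:
            mask = jnp.broadcast_to(mask, (num_heads,) + mask.shape[1:])
        if mask.shape != (num_heads, q_len, kv_len):
            raise ValueError(f"Incompatible mask shape {mask.shape}.")
        return mask

    causal = jnp.tril(jnp.ones((128, 128), dtype=bool))
    per_head = normalize_attention_mask(
        causal, num_heads=8, q_len=128, kv_len=128
    )
    assert per_head.shape == (8, 128, 128)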
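Likewise, a minimal usage sketch of the new `center_crop` flag on
`RandomCrop`, assuming this branch is installed; the shapes and input
values are arbitrary placeholders:

    import numpy as np
    from keras import layers

    images = np.random.random((8, 224, 224, 3)).astype("float32")

    # Default (center_crop=True): random crop during training,
    # center crop at inference.
    crop = layers.RandomCrop(160, 160)
    train_out = crop(images, training=True)  # random 160x160 window
    eval_out = crop(images, training=False)  # central 160x160 window

    # center_crop=False: inference resizes the full image to the target
    # size instead of cropping, so no border pixels are discarded.
    resize_layer = layers.RandomCrop(160, 160, center_crop=False)
    eval_resized = resize_layer(images, training=False)

    assert tuple(train_out.shape) == (8, 160, 160, 3)
    assert tuple(eval_out.shape) == (8, 160, 160, 3)
    assert tuple(eval_resized.shape) == (8, 160, 160, 3)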