diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py
index 04e3f399c5e5..6a9bf1a76205 100644
--- a/keras/src/layers/attention/attention.py
+++ b/keras/src/layers/attention/attention.py
@@ -119,7 +119,7 @@ def _calculate_scores(self, query, key):
             Tensor of shape `(batch_size, Tq, Tv)`.
         """
         if self.score_mode == "dot":
-            scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1]))
+            scores = ops.matmul(query, ops.swapaxes(key, -2, -1))
             if self.scale is not None:
                 scores = ops.multiply(scores, self.scale)
         elif self.score_mode == "concat":
@@ -256,7 +256,7 @@ def compute_output_shape(self, input_shape):
 
         output_shape = (*query_shape[:-1], value_shape[-1])
         if self._return_attention_scores:
-            scores_shape = (query_shape[0], query_shape[1], key_shape[1])
+            scores_shape = (*query_shape[:-1], key_shape[-2])
             return output_shape, scores_shape
         return output_shape
 
@@ -283,10 +283,9 @@ def compute_output_spec(
         # Handle attention scores if requested
         if self._return_attention_scores or return_attention_scores:
             scores_shape = (
-                query.shape[0],
-                query.shape[1],
-                key.shape[1],
-            )  # (batch_size, Tq, Tv)
+                *query.shape[:-1],
+                key.shape[-2],
+            )  # (*batch_dims, Tq, Tv)
             attention_scores_spec = KerasTensor(
                 scores_shape, dtype=self.compute_dtype
             )
diff --git a/keras/src/layers/attention/attention_test.py b/keras/src/layers/attention/attention_test.py
index 2dc5e44825ad..2f5927315f7a 100644
--- a/keras/src/layers/attention/attention_test.py
+++ b/keras/src/layers/attention/attention_test.py
@@ -394,6 +394,20 @@ def test_attention_compute_output_shape(self):
             output.shape,
         )
 
+    def test_attention_nd_inputs(self):
+        """Test that Attention handles N-D inputs (e.g. 4D from Conv2D)."""
+        layer = layers.Attention()
+        # 4D inputs: (batch, height, width, channels)
+        query = np.random.random((2, 8, 6, 4)).astype(np.float32)
+        value = np.random.random((2, 8, 6, 4)).astype(np.float32)
+        output = layer([query, value])
+        self.assertEqual(output.shape, (2, 8, 6, 4))
+
+        # With return_attention_scores
+        output, scores = layer([query, value], return_attention_scores=True)
+        self.assertEqual(output.shape, (2, 8, 6, 4))
+        self.assertEqual(scores.shape, (2, 8, 6, 6))
+
     def test_return_attention_scores_true(self):
         """Test that the layer returns attention scores along with outputs."""
         # Generate dummy input data