Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions keras/src/layers/attention/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _calculate_scores(self, query, key):
Tensor of shape `(batch_size, Tq, Tv)`.
"""
if self.score_mode == "dot":
scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1]))
scores = ops.matmul(query, ops.swapaxes(key, -2, -1))
if self.scale is not None:
scores = ops.multiply(scores, self.scale)
elif self.score_mode == "concat":
Expand Down Expand Up @@ -256,7 +256,7 @@ def compute_output_shape(self, input_shape):

output_shape = (*query_shape[:-1], value_shape[-1])
if self._return_attention_scores:
scores_shape = (query_shape[0], query_shape[1], key_shape[1])
scores_shape = (*query_shape[:-1], key_shape[-2])
return output_shape, scores_shape
return output_shape

Expand All @@ -283,10 +283,9 @@ def compute_output_spec(
# Handle attention scores if requested
if self._return_attention_scores or return_attention_scores:
scores_shape = (
query.shape[0],
query.shape[1],
key.shape[1],
) # (batch_size, Tq, Tv)
*query.shape[:-1],
key.shape[-2],
) # (*batch_dims, Tq, Tv)
attention_scores_spec = KerasTensor(
scores_shape, dtype=self.compute_dtype
)
Expand Down
14 changes: 14 additions & 0 deletions keras/src/layers/attention/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,20 @@ def test_attention_compute_output_shape(self):
output.shape,
)

def test_attention_nd_inputs(self):
    """Attention should accept inputs with more than 3 dims (e.g. Conv2D feature maps)."""
    layer = layers.Attention()
    # Query/value shaped (batch, height, width, channels), as produced by Conv2D.
    shape = (2, 8, 6, 4)
    q = np.random.random(shape).astype(np.float32)
    v = np.random.random(shape).astype(np.float32)

    # Plain call: output keeps the query's full N-D shape.
    self.assertEqual(layer([q, v]).shape, shape)

    # Scores cover (batch dims..., Tq, Tv) where Tq/Tv are the second-to-last axes.
    out, attn = layer([q, v], return_attention_scores=True)
    self.assertEqual(out.shape, shape)
    self.assertEqual(attn.shape, (2, 8, 6, 6))

def test_return_attention_scores_true(self):
"""Test that the layer returns attention scores along with outputs."""
# Generate dummy input data
Expand Down