Skip to content

Commit 4d4a292

Browse files
authored
test: fully mask idle batching slots (#109)
Signed-off-by: Connor1996 <zbk602423539@gmail.com>
1 parent 8718a29 commit 4d4a292

File tree

2 files changed

+1
-4
lines changed

2 files changed

+1
-4
lines changed

src/tiny_llm_ref/kv_cache.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -78,9 +78,6 @@ def get_seq_len(data):
7878
)
7979
for b in range(B):
8080
if data[b] is None:
81-
# for some reasons we need to do this, otherwise it will cause wrong output?
82-
# maybe precision issues?
83-
masks[b, :, :] = causal_mask(mask_length, seq_len, dtype=key.dtype)
8481
continue
8582
key, value, S, mask = data[b]
8683
keys[b, :, seq_len - S : seq_len, :] = key

tests_refsol/test_week_2_day_6.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,7 @@ def test_task_2_batching_kv_cache():
236236
expected_mask = mx.array(
237237
[
238238
[[[-mx.inf, 0.0, 0.0, -mx.inf], [-mx.inf, 0.0, 0.0, 0.0]]],
239-
[[[0.0, 0.0, 0.0, -mx.inf], [0.0, 0.0, 0.0, 0.0]]],
239+
[[[-mx.inf, -mx.inf, -mx.inf, -mx.inf], [-mx.inf, -mx.inf, -mx.inf, -mx.inf]]],
240240
[[[0.0, 0.0, 0.0, -mx.inf], [0.0, 0.0, 0.0, 0.0]]],
241241
],
242242
dtype=mx.float32,

0 commit comments

Comments
 (0)