# ===----------------------------------------------------------------------=== #
#
# This file is Modular Inc proprietary.
#
# ===----------------------------------------------------------------------=== #
"""
Step 03: Causal Masking
Implement causal attention masking that prevents tokens from attending to future positions.
Tasks:
1. Import functional module (as F) and Tensor from max.nn
2. Add @F.functional decorator to the causal_mask function
3. Create a constant tensor filled with negative infinity
4. Broadcast the mask to the correct shape (sequence_length, n)
5. Apply band_part to create the lower triangular causal structure
Run: pixi run s03
"""
# 1: Import the required modules from MAX
from max.driver import Device
from max.dtype import DType
from max.graph import Dim, DimLike
from max.tensor import Tensor

# TODO: Import the max.functional module with the alias F
# https://docs.modular.com/max/api/python/functional/
# (Tensor is already imported from max.tensor above;
#  https://docs.modular.com/max/api/python/tensor/)


# 2: Add the @F.functional decorator to make this a MAX functional operation
# TODO: Add the decorator here
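# A possible completion (sketch, assuming the import is written as
# `import max.functional as F`):
#
#     @F.functional
#     def causal_mask(...): ...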
def causal_mask(
    sequence_length: DimLike,
    num_tokens: DimLike,
    *,
    dtype: DType,
    device: Device,
) -> Tensor:
"""Create a causal mask for autoregressive attention.
Args:
sequence_length: Length of the sequence.
num_tokens: Number of tokens.
dtype: Data type for the mask.
device: Device to create the mask on.
Returns:
A causal mask tensor.
"""
    # Calculate the total sequence length
    n = Dim(sequence_length) + num_tokens
    # 3: Create a constant tensor filled with negative infinity
    # TODO: Use Tensor.constant() with float("-inf"), dtype, and device parameters
    # https://docs.modular.com/max/api/python/tensor/#max.tensor.Tensor.constant
    # Hint: This creates the base mask value that will block attention to future tokens
    mask = None
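    # A possible completion, per the hint above (sketch; keyword names follow
    # the linked docs):
    #
    #     mask = Tensor.constant(float("-inf"), dtype=dtype, device=device)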

    # 4: Broadcast the mask to the correct shape
    # TODO: Use F.broadcast_to() to expand mask to shape (sequence_length, n)
    # https://docs.modular.com/max/api/python/functional/#max.functional.broadcast_to
    # Hint: This creates a 2D attention mask matrix
    mask = None
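    # For example (sketch, following the hint):
    #
    #     mask = F.broadcast_to(mask, (sequence_length, n))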

    # 5: Apply band_part to create the causal (lower triangular) structure and return the mask
    # TODO: Use F.band_part() with num_lower=None, num_upper=0, exclude=True
    # https://docs.modular.com/max/api/python/functional/#max.functional.band_part
    # Hint: With exclude=True this zeroes the lower triangle (the allowed
    #       positions) and keeps -inf above the diagonal, so each token can
    #       attend only to itself and earlier positions
    return None
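    # A possible completion (sketch; argument names are taken from the hint):
    #
    #     return F.band_part(mask, num_lower=None, num_upper=0, exclude=True)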