Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions book/src/week1-03-gqa.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,6 @@ x = scaled_dot_product_attention_grouped(q, k, v, scale, mask) -> B, L, H_q, D ;
x = linear(x, wo) -> B, L, E
```

Keep in mind that you should use non-traditional RoPE.

You can test your implementation by running the following command:

```bash
Expand Down
81 changes: 74 additions & 7 deletions src/tiny_llm/attention.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import mlx.core as mx
from .basics import softmax, linear


def scaled_dot_product_attention_simple(
    query: mx.array,
    key: mx.array,
    value: mx.array,
    scale: float | None = None,
    mask: mx.array | None = None,
) -> mx.array:
    """Compute softmax(Q K^T * scale + mask) V.

    query: (..., L, D); key/value: (..., S, D) with broadcastable leading
    (batch/head) axes. When `scale` is None it defaults to 1/sqrt(D).
    `mask` is additive (-inf at disallowed positions) and must broadcast
    against the (..., L, S) score tensor. Returns (..., L, D).
    """
    D = query.shape[-1]

    scores = mx.matmul(query, key.swapaxes(-2, -1))
    # Default scale is 1/sqrt(D).
    scores = scores * (mx.rsqrt(D) if scale is None else scale)

    if mask is not None:
        scores = scores + mask

    return mx.matmul(softmax(scores, axis=-1), value)

class SimpleMultiHeadAttention:
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        wq: mx.array,
        wk: mx.array,
        wv: mx.array,
        wo: mx.array,
    ):
        """Multi-head attention with explicit Q/K/V/O projection weights.

        hidden_size (E) splits into num_heads (H) heads of size
        D = E // H; each weight is an E x (H x D) projection matrix.
        """
        assert hidden_size % num_heads == 0, "hidden_size must be divisible by num_heads"
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_hidden_size = hidden_size // num_heads

        # E x (H x D)
        self.wq = wq
        self.wk = wk
        self.wv = wv
        self.wo = wo

def __call__(
self,
Expand All @@ -31,21 +45,74 @@ def __call__(
value: mx.array,
mask: mx.array | None = None,
) -> mx.array:
pass
N, L = query.shape[0], query.shape[1]

# N x L x H x D
query = linear(query, self.wq).reshape(N, L, self.num_heads, self.head_hidden_size)
key = linear(key, self.wk).reshape(N, L, self.num_heads, self.head_hidden_size)
value = linear(value, self.wv).reshape(N, L, self.num_heads, self.head_hidden_size)

def causal_mask(L: int, S: int, dtype: mx.Dtype) -> mx.array:
pass
# N x H x L x D
query = query.swapaxes(-1, -2)
key = key.swapaxes(-1, -2)
value = value.swapaxes(-1, -2)

# N x H x L x D
mh_attention = scaled_dot_product_attention_simple(
query,
key,
value,
mask=mask)
# N x L x H x D
mh_attention = mh_attention.swapaxes(1, 2)

# N x L x E
mh_attention = mh_attention.reshape(N, L, -1)

out = linear(mh_attention, self.wo)
return out



def causal_mask(L: int, S: int, dtype: mx.Dtype) -> mx.array:
    """Additive causal mask of shape (L, S).

    Entries are 0 where attention is allowed and -inf where it is not;
    the diagonal is offset by S - L so the last query row may see all
    S key positions.
    """
    allowed = mx.tril(mx.ones((L, S)), k=S - L)
    return mx.where(allowed, mx.array(0), mx.array(-mx.inf)).astype(dtype)

def scaled_dot_product_attention_grouped(
    query: mx.array,
    key: mx.array,
    value: mx.array,
    scale: float | None = None,
    mask: mx.array | str | None = None,
) -> mx.array:
    """Grouped-query attention: H_q query heads share H kv heads.

    query: (..., H_q, L, D); key/value: (..., H, S, D) with H_q % H == 0.
    `mask` may be None, the string "causal", or an additive array whose
    total size equals that of the (..., H_q, L, S) score tensor.
    Returns an array with the same shape as `query`.
    """
    H_q, L, D = query.shape[-3], query.shape[-2], query.shape[-1]
    H, S = key.shape[-3], key.shape[-2]

    assert H_q % H == 0, "Query heads must be divisible by kv heads"

    n_repeats = H_q // H

    # Group each kv head with its n_repeats query heads so key/value
    # broadcast across the repeat axis without materializing copies.
    q_shape = query.shape
    query = query.reshape(-1, H, n_repeats, L, D)
    key = key.reshape(-1, H, 1, S, D)
    value = value.reshape(-1, H, 1, S, D)

    scores = mx.matmul(query, key.swapaxes(-2, -1))
    scores = scores * (mx.rsqrt(D) if scale is None else scale)

    if isinstance(mask, str) and mask == "causal":
        scores = scores + causal_mask(L, S, scores.dtype)
    elif isinstance(mask, mx.array):
        # NOTE(review): this requires mask.size == scores.size; a merely
        # broadcastable mask (e.g. plain (L, S)) would fail the reshape.
        scores = scores + mask.reshape(scores.shape)
    elif mask is not None:
        raise NotImplementedError

    return mx.matmul(softmax(scores, axis=-1), value).reshape(q_shape)



def flash_attention(
Expand Down
16 changes: 12 additions & 4 deletions src/tiny_llm/basics.py
Original file line number Diff line number Diff line change
def softmax(x: mx.array, axis: int) -> mx.array:
    """Numerically stable softmax along `axis`.

    Manual implementation (resolving the TODO): subtracting the running
    maximum before exponentiating prevents overflow for large logits and
    leaves the result unchanged, matching mx.softmax.
    """
    x_max = mx.max(x, axis=axis, keepdims=True)
    exps = mx.exp(x - x_max)
    return exps / mx.sum(exps, axis=axis, keepdims=True)


def linear(
    x: mx.array,
    w: mx.array,
    bias: mx.array | None = None,
) -> mx.array:
    """Affine map x @ w.T (+ bias).

    w is (out_features, in_features), matching the PyTorch/MLX weight
    layout; bias, when given, must broadcast against the output.
    """
    out = mx.matmul(x, w.T)
    if bias is not None:
        out = out + bias
    return out

def silu(x: mx.array) -> mx.array:
    """SiLU (swish) activation: x * sigmoid(x).

    The sigmoid is inlined rather than defined as a nested function on
    every call.
    """
    return x * (1.0 / (1.0 + mx.exp(-x)))

def logsumexp_norm(x: mx.array) -> mx.array:
    """Softmax over the last axis via the log-sum-exp trick.

    BUGFIX: the reductions use keepdims=True so `x - c` and
    `x - logsumexp` broadcast for any batch shape; without it the
    arithmetic only lined up for 1-D input or a single-row batch.
    """
    c = x.max(axis=-1, keepdims=True)
    logsumexp = c + mx.log(mx.sum(mx.exp(x - c), axis=-1, keepdims=True))
    return mx.exp(x - logsumexp)
12 changes: 9 additions & 3 deletions src/tiny_llm/embedding.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import mlx.core as mx
from .basics import linear


class Embedding:
    def __init__(self, vocab_size: int, embedding_dim: int, weight: mx.array):
        """Token-embedding table; weight is (vocab_size, embedding_dim)."""
        assert weight.shape[0] == vocab_size
        assert weight.shape[1] == embedding_dim

        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.weight = weight

    def __call__(self, x: mx.array) -> mx.array:
        """Look up embeddings for integer token ids: (...,) -> (..., embedding_dim)."""
        return self.weight[x, :]

    def as_linear(self, x: mx.array) -> mx.array:
        """Tied output projection: (..., embedding_dim) -> (..., vocab_size) logits."""
        return linear(x, self.weight)
23 changes: 22 additions & 1 deletion src/tiny_llm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from mlx_lm.tokenizer_utils import TokenizerWrapper
from .qwen2_week1 import Qwen2ModelWeek1
from .qwen2_week2 import Qwen2ModelWeek2
from .basics import logsumexp_norm
from typing import Callable


Expand All @@ -12,7 +13,27 @@ def simple_generate(
sampler: Callable[[mx.array], mx.array] | None,
) -> str:
def _step(model, y):
pass
output_logits = model(y[None])
logits = output_logits[:, -1, :]
logprobs = logsumexp_norm(logits)
if sampler is None:
return mx.argmax(logprobs, axis=-1)
else:
return sampler(logprobs)

prompt_tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False))
next_token = None

detokenizer = tokenizer.detokenizer
detokenizer.reset()

while next_token is None or next_token.item() != tokenizer.eos_token_id:
next_token = _step(model, prompt_tokens)
mx.eval(next_token)
prompt_tokens = mx.concat([prompt_tokens, next_token])
if next_token.item() != tokenizer.eos_token_id:
detokenizer.add_token(next_token.item())
print(detokenizer.last_segment, end="", flush=True)


def simple_generate_with_kv_cache(
Expand Down
10 changes: 8 additions & 2 deletions src/tiny_llm/layer_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@

class RMSNorm:
    def __init__(self, dim: int, weight: mx.array, eps: float = 1e-5):
        """RMS normalization over the last axis; weight is a (dim,) gain."""
        self.dim = dim
        self.weight = weight
        self.eps = eps

    def __call__(self, x: mx.array) -> mx.array:
        """Normalize x by its root-mean-square along the last axis.

        Accumulates in float32 for numerical stability, then casts back
        to the input dtype. BUGFIX: the previous version always returned
        float32, silently upcasting half-precision activations.
        """
        orig_dtype = x.dtype
        x = x.astype(mx.float32)

        x = x * mx.rsqrt(mx.mean(x.square(), axis=-1, keepdims=True) + self.eps)
        x = x * self.weight.astype(mx.float32)
        return x.astype(orig_dtype)
64 changes: 61 additions & 3 deletions src/tiny_llm/positional_encoding.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,75 @@
import mlx.core as mx


class RoPE:
    # Single-entry memo of the last (offset, L, base, dim) cos/sin tables.
    _RoPE_cache_key = None
    _RoPE_cache_value = None

    def __init__(
        self,
        dims: int,
        seq_len: int,
        base: int = 10000,
        traditional: bool = False,
    ):
        """Rotary positional embedding over head dimension `dims`.

        traditional=True rotates interleaved (even, odd) pairs;
        traditional=False rotates the two contiguous halves of the head
        dimension.
        """
        self.dims = dims
        self.max_seq_len = seq_len
        self.base = base
        self.traditional = traditional

    def __call__(
        self, x: mx.array, offset: list[slice] | slice | None = None
    ) -> mx.array:
        """Apply RoPE to x of shape (N, L, H, D); returns the same shape.

        `offset` selects the absolute positions of the L tokens (None
        means positions 0..L-1). Per-batch lists of slices are rejected
        explicitly (see _cal_cos_sin_theta).
        """
        N, L, H, D = x.shape
        assert D == self.dims

        costh, sinth = RoPE._cal_cos_sin_theta(offset, L, self.base, self.dims)
        # (L, D/2) -> (1, L, 1, D/2) so the tables broadcast over batch
        # and head axes. BUGFIX: the previous x.reshape(-1, H, D)
        # collapsed N and L into one axis, so the (L, 1, D/2) tables
        # only broadcast correctly when N == 1.
        costh = costh[None, :, None, :]
        sinth = sinth[None, :, None, :]

        if self.traditional:
            # Rotate interleaved (even, odd) pairs of the last axis.
            x1 = x[..., ::2]
            x2 = x[..., 1::2]
            rx1 = costh * x1 - sinth * x2
            rx2 = sinth * x1 + costh * x2
            # Re-interleave the rotated pairs back into the last axis.
            rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1).reshape(N, L, H, D)
        else:
            # Rotate the two contiguous halves of the head dimension.
            x1 = x[..., : D // 2]
            x2 = x[..., D // 2 :]
            rx1 = costh * x1 - sinth * x2
            rx2 = sinth * x1 + costh * x2
            rx = mx.concatenate([rx1, rx2], axis=-1)
        return rx

    @classmethod
    def _cal_cos_sin_theta(cls, offset, L, base, dim):
        """Return (cos, sin) tables of shape (L, dim // 2) for the given
        positions, memoizing the most recent argument tuple."""
        if (offset, L, base, dim) == cls._RoPE_cache_key:
            return cls._RoPE_cache_value

        if offset is None:
            positions = list(range(L))
        elif isinstance(offset, slice):
            positions = list(range(offset.start or 0, offset.stop, offset.step or 1))
            assert len(positions) == L
        else:
            # BUGFIX: the original fell through here leaving `off`
            # undefined (NameError); fail with a clear message instead.
            raise NotImplementedError("list[slice] offsets are not supported")

        pos = mx.array(positions, dtype=mx.float32)

        half = dim // 2

        # freq[i] = base ** (-i / half), i = 0..half-1
        freq = mx.exp(-mx.arange(0.0, half) * (mx.log(base) / half))

        theta = pos.reshape(-1, 1) * freq.reshape(1, -1)
        cos_theta = mx.cos(theta)
        sin_theta = mx.sin(theta)

        assert cos_theta.shape == (L, half)
        assert sin_theta.shape == (L, half)

        cls._RoPE_cache_key = (offset, L, base, dim)
        cls._RoPE_cache_value = (cos_theta, sin_theta)

        return cos_theta, sin_theta
Loading