|
4 | 4 | sys.path.insert(0, glob.glob('build/lib.linux-*')[0])
|
5 | 5 | sys.path.insert(0,'./src')
|
6 | 6 |
|
| 7 | +from functools import partial |
7 | 8 | import pytest
|
8 | 9 | import jax
|
9 | 10 | import jax.numpy as jnp
|
@@ -100,22 +101,114 @@ def test_flash_fwd_vmap(n, seqlen, h, d, causal, local, dtype):
|
100 | 101 | k = jax.random.normal(jax.random.PRNGKey(1), [x, n, seqlen, h, d], dtype=jnp.float32)
|
101 | 102 | v = jax.random.normal(jax.random.PRNGKey(2), [x, n, seqlen, h, d], dtype=jnp.float32)
|
102 | 103 |
|
103 |
| - @jax.jit |
104 | 104 | def ref(q,k,v):
|
105 | 105 | return ref_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
|
106 |
| - @jax.jit |
107 | 106 | def flash(q,k,v):
|
108 | 107 | return flash_mha(q,k,v, is_causal=bool(causal), window_size=window_size)
|
109 | 108 |
|
110 |
| - ref_out = jnp.stack([ref(q[i],k[i],v[i]) for i in range(x)]) |
| 109 | + ref_out = jax.vmap(ref)(q,k,v) |
111 | 110 | q = q.astype(dtype)
|
112 | 111 | k = k.astype(dtype)
|
113 | 112 | v = v.astype(dtype)
|
114 |
| - f16_out = jnp.stack([ref(q[i],k[i],v[i]) for i in range(x)]) |
115 |
| - |
| 113 | + f16_out = jax.vmap(ref)(q,k,v) |
116 | 114 |
|
117 | 115 | out = jax.vmap(flash)(q,k,v)
|
118 | 116 | check(ref_out, f16_out, out)
|
119 | 117 |
|
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@pytest.mark.parametrize("local", ['local', ''])
@pytest.mark.parametrize("causal", ['causal', ''])
@pytest.mark.parametrize("d", [59, 32])
@pytest.mark.parametrize("h", [1, 4])
@pytest.mark.parametrize("seqlen", [97, 128])
@pytest.mark.parametrize("n", [1])
def test_flash_fwd_vmapq(n, seqlen, h, d, causal, local, dtype):
    """Forward attention under ``jax.vmap`` where only ``q`` is mapped.

    ``q`` has an extra leading axis of size ``x``; ``k`` and ``v`` do not
    and are passed with ``in_axes=None``. Three outputs are produced and
    handed to the file's ``check`` helper (defined outside this excerpt):
    ``ref_mha`` on the float32 inputs, ``ref_mha`` on the inputs cast to
    ``dtype``, and ``flash_mha`` on the inputs cast to ``dtype``.
    """
    # A truthy `local` selects a (3, 3) window; otherwise (-1, -1) is used.
    window_size = (3, 3) if local else (-1, -1)
    attn_kwargs = dict(is_causal=bool(causal), window_size=window_size)

    x = 4
    mapped_shape = [x, n, seqlen, h, d]
    shared_shape = [n, seqlen, h, d]
    q = jax.random.normal(jax.random.PRNGKey(0), mapped_shape, dtype=jnp.float32)
    k = jax.random.normal(jax.random.PRNGKey(1), shared_shape, dtype=jnp.float32)
    v = jax.random.normal(jax.random.PRNGKey(2), shared_shape, dtype=jnp.float32)

    # Map over axis 0 of q only; k and v are shared by every mapped call.
    q_only = (0, None, None)
    ref_over_q = jax.vmap(
        lambda q_, k_, v_: ref_mha(q_, k_, v_, **attn_kwargs), in_axes=q_only
    )
    flash_over_q = jax.vmap(
        lambda q_, k_, v_: flash_mha(q_, k_, v_, **attn_kwargs), in_axes=q_only
    )

    # The float32 reference is computed before the inputs are cast.
    ref_out = ref_over_q(q, k, v)

    q, k, v = (t.astype(dtype) for t in (q, k, v))
    f16_out = ref_over_q(q, k, v)
    out = flash_over_q(q, k, v)

    check(ref_out, f16_out, out)
| 146 | + |
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@pytest.mark.parametrize("local", ['local', ''])
@pytest.mark.parametrize("causal", ['causal', ''])
@pytest.mark.parametrize("d", [59, 32])
@pytest.mark.parametrize("h", [1, 4])
@pytest.mark.parametrize("seqlen", [97, 128])
@pytest.mark.parametrize("n", [1])
def test_flash_bwd_vmap(n, seqlen, h, d, causal, local, dtype):
    """Backward pass through attention vmapped over ``q``, ``k`` and ``v``.

    All three inputs and the cotangent share the shape
    ``[x, n, seqlen, h, d]``. The result of the VJP applied to the cotangent
    is computed three times: ``ref_mha`` in float32, ``ref_mha`` after
    casting to ``dtype``, and ``flash_mha`` after casting to ``dtype``.
    The three results are handed to the file's ``check`` helper (defined
    outside this excerpt).
    """
    # A truthy `local` selects a (3, 3) window; otherwise (-1, -1) is used.
    window_size = (3, 3) if local else (-1, -1)

    x = 4
    shape = [x, n, seqlen, h, d]
    # Seeds 0..3 give q, k, v and the output cotangent respectively.
    q, k, v, cotangent = (
        jax.random.normal(jax.random.PRNGKey(seed), shape, dtype=jnp.float32)
        for seed in range(4)
    )

    def grads(mha, q, k, v):
        # NOTE: `cotangent` is read from the enclosing scope when this runs,
        # so calls made after the cast below use the cast cotangent.
        batched = jax.vmap(
            lambda q_, k_, v_: mha(
                q_, k_, v_, is_causal=bool(causal), window_size=window_size
            ),
            in_axes=(0, 0, 0),
        )
        _, pullback = jax.vjp(batched, q, k, v)
        return pullback(cotangent)

    # The float32 reference is computed before anything is cast.
    ref_out = grads(ref_mha, q, k, v)

    q, k, v, cotangent = (t.astype(dtype) for t in (q, k, v, cotangent))
    f16_out = grads(ref_mha, q, k, v)
    out = grads(flash_mha, q, k, v)

    check(ref_out, f16_out, out)
| 179 | + |
@pytest.mark.parametrize("dtype", [jnp.float16, jnp.bfloat16])
@pytest.mark.parametrize("local", ['local', ''])
@pytest.mark.parametrize("causal", ['causal', ''])
@pytest.mark.parametrize("d", [59, 32])
@pytest.mark.parametrize("h", [1, 4])
@pytest.mark.parametrize("seqlen", [97, 128])
@pytest.mark.parametrize("n", [1])
def test_flash_bwd_vmapq(n, seqlen, h, d, causal, local, dtype):
    """Backward pass through attention vmapped over ``q`` only.

    ``q`` and the cotangent have shape ``[x, n, seqlen, h, d]``; ``k`` and
    ``v`` have shape ``[n, seqlen, h, d]`` and are passed with
    ``in_axes=None``. The result of the VJP applied to the cotangent is
    computed three times: ``ref_mha`` in float32, ``ref_mha`` after casting
    to ``dtype``, and ``flash_mha`` after casting to ``dtype``. The three
    results are handed to the file's ``check`` helper (defined outside this
    excerpt).
    """
    # A truthy `local` selects a (3, 3) window; otherwise (-1, -1) is used.
    window_size = (3, 3) if local else (-1, -1)

    x = 4
    mapped_shape = [x, n, seqlen, h, d]
    shared_shape = [n, seqlen, h, d]
    q = jax.random.normal(jax.random.PRNGKey(0), mapped_shape, dtype=jnp.float32)
    k = jax.random.normal(jax.random.PRNGKey(1), shared_shape, dtype=jnp.float32)
    v = jax.random.normal(jax.random.PRNGKey(2), shared_shape, dtype=jnp.float32)
    cotangent = jax.random.normal(
        jax.random.PRNGKey(3), mapped_shape, dtype=jnp.float32
    )

    def grads(mha, q, k, v):
        # NOTE: `cotangent` is read from the enclosing scope when this runs,
        # so calls made after the cast below use the cast cotangent.
        over_q = jax.vmap(
            lambda q_, k_, v_: mha(
                q_, k_, v_, is_causal=bool(causal), window_size=window_size
            ),
            in_axes=(0, None, None),
        )
        _, pullback = jax.vjp(over_q, q, k, v)
        return pullback(cotangent)

    # The float32 reference is computed before anything is cast.
    ref_out = grads(ref_mha, q, k, v)

    q, k, v, cotangent = (t.astype(dtype) for t in (q, k, v, cotangent))
    f16_out = grads(ref_mha, q, k, v)
    out = grads(flash_mha, q, k, v)

    check(ref_out, f16_out, out)
| 212 | + |
120 | 213 | if __name__ == '__main__':
|
121 | 214 | test_flash_bwd(1,4,1,32,4,False,False,jnp.float16)
|
0 commit comments