 import jax
 import jax.numpy as jnp
-import taichi as ti
+import taichi as taichi
+import pytest
+import platform

 import brainpy.math as bm

 bm.set_platform('cpu')

+if not platform.platform().startswith('Windows'):
+  pytest.skip(allow_module_level=True)
+
+
 # @ti.kernel
 # def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
 #                   vector: ti.types.ndarray(ndim=1),
⋮
 #       for j in range(num_cols):
 #         out[indices[i, j]] += weight_0

-@ti.func
-def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
-  return weight[0]
+@taichi.func
+def get_weight(weight: taichi.types.ndarray(ndim=1)) -> taichi.f32:
+  return weight[0]

-@ti.func
-def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
-  out[index] += weight_val

-@ti.kernel
-def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
-                  vector: ti.types.ndarray(ndim=1),
-                  weight: ti.types.ndarray(ndim=1),
-                  out: ti.types.ndarray(ndim=1)):
-  weight_val = get_weight(weight)
-  num_rows, num_cols = indices.shape
-  ti.loop_config(serialize=True)
-  for i in range(num_rows):
-    if vector[i]:
-      for j in range(num_cols):
-        update_output(out, indices[i, j], weight_val)
+@taichi.func
+def update_output(out: taichi.types.ndarray(ndim=1), index: taichi.i32, weight_val: taichi.f32):
+  out[index] += weight_val


-prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
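+# Event-driven ELL-style update: each row whose vector entry is nonzero (an "event")
+# scatters the shared weight into `out` at that row's column indices; the outer
+# loop is serialized for the CPU backend via loop_config(serialize=True).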
+@taichi.kernel
+def event_ell_cpu(indices: taichi.types.ndarray(ndim=2),
+                  vector: taichi.types.ndarray(ndim=1),
+                  weight: taichi.types.ndarray(ndim=1),
+                  out: taichi.types.ndarray(ndim=1)):
+  weight_val = get_weight(weight)
+  num_rows, num_cols = indices.shape
+  taichi.loop_config(serialize=True)
+  for i in range(num_rows):
+    if vector[i]:
+      for j in range(num_cols):
+        update_output(out, indices[i, j], weight_val)


-# def test_taichi_op_register():
-#   s = 1000
-#   indices = bm.random.randint(0, s, (s, 1000))
-#   vector = bm.random.rand(s) < 0.1
-#   weight = bm.array([1.0])
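+# Register the Taichi kernel as a brainpy XLA custom operator callable from JAX.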
+prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
+

-#   out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
+def test_taichi_op_register():
+  s = 1000
+  indices = bm.random.randint(0, s, (s, 1000))
+  vector = bm.random.rand(s) < 0.1
+  weight = bm.array([1.0])

-#   out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
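+  # Output shape and dtype are declared via jax.ShapeDtypeStruct in `outs`;
+  # the op is invoked twice, presumably to also cover the repeated-call path.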
+  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

-#   print(out)
-#   bm.clear_buffer_memory()
+  out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

+  print(out)
+  bm.clear_buffer_memory()

 # test_taichi_op_register()