-
Notifications
You must be signed in to change notification settings - Fork 318
Expand file tree
/
Copy pathp12.mojo
More file actions
93 lines (76 loc) · 2.81 KB
/
p12.mojo
File metadata and controls
93 lines (76 loc) · 2.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from memory import UnsafePointer, stack_allocation
from gpu import thread_idx, block_idx, block_dim, barrier
from gpu.host import DeviceContext
from gpu.memory import AddressSpace
from testing import assert_equal
# Problem configuration: a single-block launch where the vector length
# equals the thread count, so every element gets exactly one thread.
comptime TPB = 8  # threads per block (also the shared-memory tile size)
comptime SIZE = 8  # number of elements in each input vector
comptime BLOCKS_PER_GRID = (1, 1)  # one block is enough since SIZE <= TPB
comptime THREADS_PER_BLOCK = (TPB, 1)
comptime dtype = DType.float32  # element type for all device buffers
# ANCHOR: dot_product_solution
fn dot_product(
    output: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    a: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    b: UnsafePointer[Scalar[dtype], MutAnyOrigin],
    size: UInt,
):
    """Compute the dot product of `a` and `b` (length `size`) into
    `output[0]` using a shared-memory tree reduction.

    Assumes a single thread block of TPB threads; each thread computes one
    elementwise product, then the block cooperatively sums them.
    """
    shared = stack_allocation[
        TPB,
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    if global_i < size:
        shared[local_i] = a[global_i] * b[global_i]
    else:
        # BUG FIX: threads past `size` previously left their shared slot
        # uninitialized, so the reduction below could fold garbage into the
        # sum whenever size < TPB. Zero is the identity for addition, so
        # padding with 0 makes the kernel correct for any size <= TPB.
        shared[local_i] = 0
    barrier()
    # The naive alternative -- every thread doing `output[0] += shared[local_i]`
    # -- is a race: all threads read-modify-write the same global location.
    # Instead do a parallel tree reduction in shared memory. Shared memory is
    # private to the block, so `barrier()` after each halving step is enough
    # to make every partial sum visible before the next step reads it.
    stride = UInt(TPB // 2)
    while stride > 0:
        # The first `stride` threads fold the upper half onto the lower half.
        # Writers touch indices < stride and readers touch >= stride, so no
        # slot is both read and written within one step.
        if local_i < stride:
            shared[local_i] += shared[local_i + stride]
        barrier()
        stride //= 2
    # Only thread 0 writes the final result to global memory.
    if local_i == 0:
        output[0] = shared[0]
# ANCHOR_END: dot_product_solution
def main():
    # Host driver: allocate device buffers, fill the inputs on the host,
    # launch the dot_product kernel, then verify against a CPU reference sum.
    with DeviceContext() as ctx:
        # Single-element output buffer for the reduced dot product.
        out = ctx.enqueue_create_buffer[dtype](1)
        out.enqueue_fill(0)
        a = ctx.enqueue_create_buffer[dtype](SIZE)
        a.enqueue_fill(0)
        b = ctx.enqueue_create_buffer[dtype](SIZE)
        b.enqueue_fill(0)
        # Initialize a = b = [0, 1, ..., SIZE-1]; mapping to host synchronizes
        # the buffer for CPU access and writes back on exit of the `with`.
        with a.map_to_host() as a_host, b.map_to_host() as b_host:
            for i in range(SIZE):
                a_host[i] = i
                b_host[i] = i
        # NOTE(review): the kernel appears twice in the parameter list --
        # presumably the type-checked launch form of enqueue_function;
        # confirm against the DeviceContext API docs.
        ctx.enqueue_function[dot_product, dot_product](
            out,
            a,
            b,
            UInt(SIZE),
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )
        # CPU reference: expected[0] = sum(a[i] * b[i]).
        expected = ctx.enqueue_create_host_buffer[dtype](1)
        expected.enqueue_fill(0)
        # Wait for all enqueued work (fills, kernel) before reading results.
        ctx.synchronize()
        with a.map_to_host() as a_host, b.map_to_host() as b_host:
            for i in range(SIZE):
                expected[0] += a_host[i] * b_host[i]
        # Compare the GPU result with the reference; raises on mismatch.
        with out.map_to_host() as out_host:
            print("out:", out_host)
            print("expected:", expected)
            assert_equal(out_host[0], expected[0])