Skip to content

Commit a5daf50

Browse files
authored
Clearer allocation API and add matmul test (#80)
1 parent 3f398e2 commit a5daf50

23 files changed

+599
-153
lines changed

.formatter.exs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ locals_without_parens = [
77
defk: 2,
88
defbind: 1,
99
set!: 2,
10-
ptr!: 1,
11-
ptr!: 2,
10+
tmp!: 1,
11+
tmp!: 2,
12+
new!: 1,
13+
new!: 2,
1214
launch!: 3,
1315
defer: 1,
1416
free!: 1,

bench/enif_merge_sort.ex

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,13 @@ defmodule ENIFMergeSort do
1515

1616
@err %ArgumentError{message: "list expected"}
1717
defm sort(env, list) :: Term.t() do
18-
len_ptr = ptr! i32()
18+
len_ptr = tmp! i32()
1919

2020
if enif_get_list_length(env, list, len_ptr) != 0 do
21-
movable_list_ptr = ptr! Term.t()
21+
movable_list_ptr = tmp! Term.t()
2222
set! movable_list_ptr[0], list
2323
len = len_ptr[0]
24-
arr = ptr! Term.t(), len
24+
arr = new! Term.t(), len
2525
SortUtil.copy_terms(env, movable_list_ptr, arr)
2626
zero = const 0 :: i32()
2727
do_sort(arr, zero, len - 1)

bench/enif_quick_sort.ex

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ defmodule ENIFQuickSort do
1212

1313
defm partition(arr :: Pointer.t(Term.t()), low :: i32(), high :: i32()) :: i32() do
1414
pivot = arr[high]
15-
i_ptr = ptr! i32()
15+
i_ptr = tmp! i32()
1616
set! i_ptr[0], low - 1
1717
start = arr + low
1818

@@ -39,13 +39,13 @@ defmodule ENIFQuickSort do
3939

4040
@err %ArgumentError{message: "list expected"}
4141
defm sort(env, list) :: Term.t() do
42-
len_ptr = ptr! i32()
42+
len_ptr = tmp! i32()
4343

4444
if enif_get_list_length(env, list, len_ptr) != 0 do
45-
movable_list_ptr = ptr! Term.t()
45+
movable_list_ptr = tmp! Term.t()
4646
set! movable_list_ptr[0], list
4747
len = len_ptr[0]
48-
arr = ptr! Term.t(), len
48+
arr = new! Term.t(), len
4949
SortUtil.copy_terms(env, movable_list_ptr, arr)
5050
zero = const 0 :: i32()
5151
do_sort(arr, zero, len - 1)

bench/enif_tim_sort.ex

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ defmodule ENIFTimSort do
1111
for_loop {temp, i} <- {start, n} do
1212
i = value index.casts(i) :: i32()
1313
i = i + start_i
14-
j_ptr = ptr! i32()
14+
j_ptr = tmp! i32()
1515
set! j_ptr[0], i - 1
1616

1717
while(j_ptr[0] >= left && arr[j_ptr[0]] > temp) do
@@ -27,9 +27,8 @@ defmodule ENIFTimSort do
2727

2828
defm tim_sort(arr :: Pointer.t(Term.t()), n :: i32()) do
2929
run = const 32 :: i32()
30-
i_ptr = ptr! i32()
31-
zero = const 0 :: i32()
32-
set! i_ptr[0], zero
30+
i_ptr = tmp! i32()
31+
set! i_ptr[0], 0
3332

3433
while i_ptr[0] < n do
3534
i = i_ptr[0]
@@ -38,14 +37,14 @@ defmodule ENIFTimSort do
3837
set! i_ptr[0], i + run
3938
end
4039

41-
size_ptr = ptr! i32()
40+
size_ptr = tmp! i32()
4241
set! size_ptr[0], run
4342

4443
while size_ptr[0] < n do
4544
size = size_ptr[0]
4645

47-
left_ptr = ptr! i32()
48-
set! left_ptr[0], zero
46+
left_ptr = tmp! i32()
47+
set! left_ptr[0], 0
4948

5049
while left_ptr[0] < n do
5150
left = left_ptr[0]
@@ -66,13 +65,13 @@ defmodule ENIFTimSort do
6665

6766
@err %ArgumentError{message: "list expected"}
6867
defm sort(env, list) :: Term.t() do
69-
len_ptr = ptr! i32()
68+
len_ptr = tmp! i32()
7069

7170
if enif_get_list_length(env, list, len_ptr) != 0 do
72-
movable_list_ptr = ptr! Term.t()
71+
movable_list_ptr = tmp! Term.t()
7372
set! movable_list_ptr[0], list
7473
len = len_ptr[0]
75-
arr = ptr! Term.t(), len
74+
arr = new! Term.t(), len
7675
SortUtil.copy_terms(env, movable_list_ptr, arr)
7776
tim_sort(arr, len)
7877
defer free! arr

bench/gpu/matmul.ex

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
defmodule MatMulKernel do
2+
@moduledoc false
3+
use Charms
4+
alias Charms.{Term, Pointer}
5+
alias Charms.GPU
6+
7+
# Matrix Dimensions
8+
# A: (M x K), B: (K x N), C: (M x N)
9+
@m 64
10+
# Inner dimension (must match cols of A and rows of B)
11+
@k 128
12+
@n 32
13+
14+
@size_a @m * @k
15+
@size_b @k * @n
16+
@size_c @m * @n
17+
18+
@block_size 1024
19+
20+
# Kernel: C = A * B
21+
# A is m*k, B is k*n, C is m*n
22+
defk matmul(a :: Pointer.t(f32()), b :: Pointer.t(f32()), c :: Pointer.t(f32())) do
23+
# Global thread index (maps to C matrix)
24+
idx = GPU.block_id() * @block_size + GPU.thread_id()
25+
26+
# Map linear index to matrix coordinates (row, col) of C (m x n)
27+
row = idx / @n
28+
col = rem(idx, @n)
29+
30+
if idx < @size_c do
31+
# Accumulator for the dot product
32+
sum_ptr = tmp! f32()
33+
set! sum_ptr[0], 0.0
34+
35+
# Iterator k
36+
k_ptr = tmp! i32()
37+
set! k_ptr[0], 0
38+
39+
# Loop over the shared dimension K
40+
while k_ptr[0] < @k do
41+
k = k_ptr[0]
42+
43+
# A[row * k_dim + k]
44+
val_a = a[row * @k + k]
45+
# B[k * n_dim + col]
46+
val_b = b[k * @n + col]
47+
48+
set! sum_ptr[0], sum_ptr[0] + val_a * val_b
49+
set! k_ptr[0], k + 1
50+
end
51+
52+
# Store result in C
53+
set! c[idx], sum_ptr[0]
54+
end
55+
end
56+
57+
@grid_size ceil(@size_c / @block_size)
58+
59+
defm main(env, l_a :: Term.t(), l_b :: Term.t()) :: Term.t() do
60+
size_a = Term.to_i64!(env, @size_a)
61+
size_b = Term.to_i64!(env, @size_b)
62+
size_c = Term.to_i64!(env, @size_c)
63+
64+
# 1. Allocate Device Memory
65+
a = GPU.allocate(f32(), size_a)
66+
b = GPU.allocate(f32(), size_b)
67+
c = GPU.allocate(f32(), size_c)
68+
69+
# 2. Allocate Host Memory (Dedicated buffers as requested)
70+
buffer_a = GPU.allocate(f32(), size_a, host_shared: true)
71+
buffer_b = GPU.allocate(f32(), size_b, host_shared: true)
72+
buffer_c = GPU.allocate(f32(), size_c, host_shared: true)
73+
74+
# 3. Cleanup
75+
defer GPU.await([
76+
GPU.dealloc(a),
77+
GPU.dealloc(b),
78+
GPU.dealloc(c),
79+
GPU.dealloc(buffer_a),
80+
GPU.dealloc(buffer_b),
81+
GPU.dealloc(buffer_c)
82+
])
83+
84+
# 4. Copy Input (Host -> Buffer -> Device)
85+
movable_list_ptr = tmp! Term.t()
86+
87+
# Copy A
88+
set! movable_list_ptr[0], l_a
89+
copy_terms_as_floats(env, movable_list_ptr, buffer_a)
90+
GPU.memcpy(a, buffer_a) |> GPU.await()
91+
92+
# Copy B
93+
set! movable_list_ptr[0], l_b
94+
copy_terms_as_floats(env, movable_list_ptr, buffer_b)
95+
GPU.memcpy(b, buffer_b) |> GPU.await()
96+
97+
# 5. Launch Kernel
98+
launch! matmul(a, b, c), Term.to_i64!(env, @grid_size), Term.to_i64!(env, @block_size)
99+
100+
# 6. Copy Output (Device -> Buffer -> Host)
101+
GPU.memcpy(buffer_c, c) |> GPU.await()
102+
103+
# 7. Construct Elixir List from Buffer C
104+
arr = new! Term.t(), size_c
105+
defer free! arr
106+
107+
for_loop {element, i} <- {buffer_c, size_c} do
108+
element = value arith.extf(element) :: f64()
109+
set! arr[i], enif_make_double(env, element)
110+
end
111+
112+
size_c_i32 = value arith.trunci(size_c) :: i32()
113+
enif_make_list_from_array(env, arr, size_c_i32)
114+
end
115+
116+
defm copy_terms_as_floats(env, tail :: Pointer.t(Term.t()), arr :: Pointer.t(f32())) do
117+
head = tmp! Term.t()
118+
zero = const 0 :: i32()
119+
i_ptr = tmp! i32()
120+
set! i_ptr[0], zero
121+
122+
while(enif_get_list_cell(env, tail[0], head, tail) > 0) do
123+
double_ptr = tmp! f64()
124+
enif_get_double(env, head[0], double_ptr)
125+
i = i_ptr[0]
126+
set! arr[i], value(arith.truncf(double_ptr[0]) :: f32())
127+
set! i_ptr[0], i + 1
128+
end
129+
end
130+
131+
def random_list(size) do
132+
Enum.map(1..size, fn _ -> :rand.uniform() end)
133+
end
134+
135+
def dims, do: {@m, @k, @n}
136+
end
137+
138+
defmodule SquareMatMulKernel do
139+
@moduledoc false
140+
use Charms
141+
alias Charms.{Term, Pointer}
142+
alias Charms.GPU
143+
144+
# Matrix Dimensions (N x N)
145+
# 64 x 64 matrix = 4,096 elements
146+
@width 64
147+
@size @width * @width
148+
@block_size 1024
149+
150+
# Kernel: C = A * B
151+
# Uses 1D thread indexing mapped to 2D matrix coordinates
152+
defk matmul(a :: Pointer.t(f32()), b :: Pointer.t(f32()), c :: Pointer.t(f32())) do
153+
# Global thread index
154+
idx = GPU.block_id() * @block_size + GPU.thread_id()
155+
156+
# Map linear index to matrix coordinates (row, col)
157+
# Note: width must be constant or passed as arg. Using module attr for simplicity.
158+
row = idx / @width
159+
col = rem(idx, @width)
160+
161+
if idx < @size do
162+
# Accumulator for the dot product
163+
sum_ptr = tmp! f32()
164+
set! sum_ptr[0], 0.0
165+
166+
# Iterator k
167+
k_ptr = tmp! i32()
168+
set! k_ptr[0], 0
169+
170+
while k_ptr[0] < @width do
171+
k = k_ptr[0]
172+
173+
# A[row * width + k]
174+
val_a = a[row * @width + k]
175+
# B[k * width + col]
176+
val_b = b[k * @width + col]
177+
178+
set! sum_ptr[0], sum_ptr[0] + val_a * val_b
179+
set! k_ptr[0], k + 1
180+
end
181+
182+
# Store result in C
183+
set! c[idx], sum_ptr[0]
184+
end
185+
end
186+
187+
@grid_size ceil(@size / @block_size)
188+
189+
defm main(env, l_a :: Term.t(), l_b :: Term.t()) :: Term.t() do
190+
size = Term.to_i64!(env, @size)
191+
192+
# 1. Allocate Device Memory
193+
a = GPU.allocate(f32(), size)
194+
b = GPU.allocate(f32(), size)
195+
c = GPU.allocate(f32(), size)
196+
buffer = GPU.allocate(f32(), size, host_shared: true)
197+
198+
# 2. Cleanup
199+
defer GPU.await([
200+
GPU.dealloc(a),
201+
GPU.dealloc(b),
202+
GPU.dealloc(c),
203+
GPU.dealloc(buffer)
204+
])
205+
206+
# 3. Copy Input (Host -> Device)
207+
movable_list_ptr = tmp! Term.t()
208+
209+
# Copy A
210+
set! movable_list_ptr[0], l_a
211+
copy_terms_as_floats(env, movable_list_ptr, buffer)
212+
GPU.memcpy(a, buffer) |> GPU.await()
213+
214+
# Copy B
215+
set! movable_list_ptr[0], l_b
216+
copy_terms_as_floats(env, movable_list_ptr, buffer)
217+
GPU.memcpy(b, buffer) |> GPU.await()
218+
219+
# 4. Launch Kernel
220+
# We launch enough threads to cover the N*N matrix
221+
launch! matmul(a, b, c), Term.to_i64!(env, @grid_size), Term.to_i64!(env, @block_size)
222+
223+
# 5. Copy Output (Device -> Host)
224+
GPU.memcpy(buffer, c) |> GPU.await()
225+
226+
# 6. Construct Elixir List from Buffer
227+
arr = new! Term.t(), size
228+
defer free! arr
229+
230+
for_loop {element, i} <- {buffer, size} do
231+
element = value arith.extf(element) :: f64()
232+
set! arr[i], enif_make_double(env, element)
233+
end
234+
235+
size_i32 = value arith.trunci(size) :: i32()
236+
enif_make_list_from_array(env, arr, size_i32)
237+
end
238+
239+
defm copy_terms_as_floats(env, tail :: Pointer.t(Term.t()), arr :: Pointer.t(f32())) do
240+
head = tmp! Term.t()
241+
zero = const 0 :: i32()
242+
i_ptr = tmp! i32()
243+
set! i_ptr[0], zero
244+
245+
while(enif_get_list_cell(env, tail[0], head, tail) > 0) do
246+
double_ptr = tmp! f64()
247+
enif_get_double(env, head[0], double_ptr)
248+
i = i_ptr[0]
249+
set! arr[i], value(arith.truncf(double_ptr[0]) :: f32())
250+
set! i_ptr[0], i + 1
251+
end
252+
end
253+
254+
# Helper to generate data for the test
255+
def random_matrix() do
256+
Enum.map(1..@size, fn _ -> :rand.uniform() end)
257+
end
258+
259+
def width, do: @width
260+
end

0 commit comments

Comments
 (0)