Skip to content

Commit 1d3ec4f

Browse files
committed
Sub-group
1 parent 48905ce commit 1d3ec4f

File tree

4 files changed

+152
-3
lines changed

4 files changed

+152
-3
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ Random = "1"
3838
Random123 = "1.7.1"
3939
RandomNumbers = "1.6.0"
4040
Reexport = "1"
41-
SPIRVIntrinsics = "0.5"
41+
SPIRVIntrinsics = "0.5.7"
4242
SPIRV_LLVM_Backend_jll = "20"
4343
SPIRV_Tools_jll = "2025.1"
4444
StaticArrays = "1"

lib/cl/device.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,20 @@ end
139139
return tuple([Int(r) for r in result]...)
140140
end
141141

142+
# error handling inspired by rusticl
143+
# https://gitlab.freedesktop.org/mesa/mesa/-/blob/c4385d6fb0938231114eb3023082cd33788b89b4/src/gallium/frontends/rusticl/api/device.rs#L314-320
144+
if s == :sub_group_sizes
145+
res_size = Ref{Csize_t}()
146+
err = unchecked_clGetDeviceInfo(d, CL_DEVICE_SUB_GROUP_SIZES_INTEL, C_NULL, C_NULL, res_size)
147+
if err == CL_SUCCESS && res_size[] > 1
148+
result = Vector{Csize_t}(undef, res_size[] ÷ sizeof(Csize_t))
149+
clGetDeviceInfo(d, CL_DEVICE_SUB_GROUP_SIZES_INTEL, sizeof(result), result, C_NULL)
150+
return tuple([Int(r) for r in result]...)
151+
else
152+
return tuple(0, 1)
153+
end
154+
end
155+
142156
if s == :max_image2d_shape
143157
width = Ref{Csize_t}()
144158
height = Ref{Csize_t}()
@@ -273,3 +287,40 @@ function cl_device_type(dtype::Symbol)
273287
end
274288
return cl_dtype
275289
end
290+
291+
# Whether device `d` advertises sub-group support via the Khronos or Intel extension.
function sub_groups_supported(d::Device)
    exts = d.extensions
    return "cl_khr_subgroups" in exts || "cl_intel_subgroups" in exts
end
292+
"""
    sub_group_size(d::Device)

Return the sub-group size to use for device `d`, or 0 when the device does
not support sub-groups.

Uses vendor-specific device queries where available (AMD wavefront width,
NVIDIA warp size); otherwise picks from the device's reported
`sub_group_sizes`, preferring 32, then 64, then 16, and falling back to the
first reported size.
"""
function sub_group_size(d::Device)
    # Guard: no sub-group support means size 0.
    # BUGFIX: was `sub_groups_supported(d) || 0`, which evaluated the 0,
    # discarded it, and fell through to the queries below.
    sub_groups_supported(d) || return 0
    if "cl_amd_device_attribute_query" in d.extensions
        scalar = Ref{cl_uint}()
        clGetDeviceInfo(d, CL_DEVICE_WAVEFRONT_WIDTH_AMD, sizeof(cl_uint), scalar, C_NULL)
        return Int(scalar[])
    elseif "cl_nv_device_attribute_query" in d.extensions
        scalar = Ref{cl_uint}()
        clGetDeviceInfo(d, CL_DEVICE_WARP_SIZE_NV, sizeof(cl_uint), scalar, C_NULL)
        return Int(scalar[])
    else
        sg_sizes = d.sub_group_sizes
        # Prefer common hardware sizes when the device supports several.
        return if length(sg_sizes) == 1
            Int(only(sg_sizes))
        elseif 32 in sg_sizes
            32
        elseif 64 in sg_sizes
            64
        elseif 16 in sg_sizes
            16
        else
            Int(first(sg_sizes))
        end
    end
end
317+
# Element types usable with `sub_group_shuffle` on device `d`.
# Empty when the cl_khr_subgroup_shuffle extension is absent; Float16/Float64
# are included only when the corresponding fp extensions are present.
function sub_group_shuffle_supported_types(d::Device)
    "cl_khr_subgroup_shuffle" in d.extensions || return DataType[]
    types = DataType[Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Float32]
    if "cl_khr_fp16" in d.extensions
        push!(types, Float16)
    end
    if "cl_khr_fp64" in d.extensions
        push!(types, Float64)
    end
    return types
end

src/compiler/compilation.jl

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
## gpucompiler interface
22

3-
struct OpenCLCompilerParams <: AbstractCompilerParams end
3+
# OpenCL-specific GPUCompiler parameters; part of the compiler configuration
# (and thus of the compilation cache key, see the `hash` method below).
Base.@kwdef struct OpenCLCompilerParams <: AbstractCompilerParams
    sub_group_size::Int # Some devices support multiple sizes. This is used to force one when needed
end
6+
# Fold the sub-group size into the parameter object's hash, so that
# configurations differing only in forced sub-group size hash differently.
Base.hash(params::OpenCLCompilerParams, h::UInt) = hash(params.sub_group_size, h)
11+
412
const OpenCLCompilerConfig = CompilerConfig{SPIRVCompilerTarget, OpenCLCompilerParams}
513
const OpenCLCompilerJob = CompilerJob{SPIRVCompilerTarget,OpenCLCompilerParams}
614

@@ -29,6 +37,12 @@ function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
2937
Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
3038
job, mod, entry)
3139

40+
# Set the subgroup size if supported
41+
sg_size = job.config.params.sub_group_size
42+
if sg_size >= 0
43+
metadata(entry)["intel_reqd_sub_group_size"] = MDNode([ConstantInt(Int32(sg_size))])
44+
end
45+
3246
# if this kernel uses our RNG, we should prime the shared state.
3347
# XXX: these transformations should really happen at the Julia IR level...
3448
if haskey(functions(mod), "julia.opencl.random_keys") && job.config.kernel
@@ -131,9 +145,16 @@ end
131145
supports_fp16 = "cl_khr_fp16" in dev.extensions
132146
supports_fp64 = "cl_khr_fp64" in dev.extensions
133147

148+
# Set to -1 if specifying a subgroup size is not supported
149+
sub_group_size = if "cl_intel_required_subgroup_size" in dev.extensions
150+
cl.sub_group_size(dev)
151+
else
152+
-1
153+
end
154+
134155
# create GPUCompiler objects
135156
target = SPIRVCompilerTarget(; supports_fp16, supports_fp64, validate=true, kwargs...)
136-
params = OpenCLCompilerParams()
157+
params = OpenCLCompilerParams(; sub_group_size)
137158
CompilerConfig(target, params; kernel, name, always_inline)
138159
end
139160

test/intrinsics.jl

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,83 @@ end
166166
@test call_on_device(OpenCL.mad, x, y, z) ≈ x * y + z
167167
end
168168

169+
if cl.sub_groups_supported(cl.device())
170+
171+
# Per-work-item record of the sub-group indexing intrinsics, filled in
# on-device by `test_subgroup_kernel` below.
struct SubgroupData
    sub_group_size::UInt32       # get_sub_group_size()
    max_sub_group_size::UInt32   # get_max_sub_group_size()
    num_sub_groups::UInt32       # get_num_sub_groups()
    sub_group_id::UInt32         # get_sub_group_id()
    sub_group_local_id::UInt32   # get_sub_group_local_id()
end
178+
# Kernel: each work-item records its sub-group indexing intrinsics into
# `results` at its global index.
function test_subgroup_kernel(results)
    idx = get_global_id(1)

    # Work-items beyond the array length write nothing.
    idx > length(results) && return

    @inbounds results[idx] = SubgroupData(
        get_sub_group_size(),
        get_max_sub_group_size(),
        get_num_sub_groups(),
        get_sub_group_id(),
        get_sub_group_local_id()
    )
    return
end
192+
193+
@testset "Sub-groups" begin
    sg_size = cl.sub_group_size(cl.device())

    @testset "Indexing intrinsics" begin
        # Test with small kernel: 2 sub-groups per work-group, 2 work-groups.
        sg_n = 2
        local_size = sg_size * sg_n
        numworkgroups = 2
        N = local_size * numworkgroups

        results = CLVector{SubgroupData}(undef, N)
        kernel = @opencl launch = false test_subgroup_kernel(results)

        kernel(results; local_size, global_size=N)

        host_results = Array(results)

        # Verify results make sense
        for (i, sg_data) in enumerate(host_results)
            @test sg_data.sub_group_size == sg_size
            @test sg_data.max_sub_group_size == sg_size
            @test sg_data.num_sub_groups == sg_n

            # Group ID should be 1-based
            expected_sub_group = div(((i - 1) % local_size), sg_size) + 1
            @test sg_data.sub_group_id == expected_sub_group

            # Local ID should be 1-based within group
            expected_sg_local = ((i - 1) % sg_size) + 1
            @test sg_data.sub_group_local_id == expected_sg_local
        end
    end

    @testset "shuffle idx" begin
        # Each lane swaps its element with the mirrored lane's element,
        # reversing the vector within a single sub-group.
        function shfl_idx_kernel(d)
            i = get_sub_group_local_id()
            j = get_sub_group_size() - i + 1

            d[i] = sub_group_shuffle(d[i], j)

            return
        end

        @testset for T in cl.sub_group_shuffle_supported_types(cl.device())
            a = rand(T, sg_size)
            d_a = CLArray(a)
            @opencl local_size = sg_size global_size = sg_size shfl_idx_kernel(d_a)
            @test Array(d_a) == reverse(a)
        end
    end
end
244+
end # if cl.sub_groups_supported(cl.device())
245+
169246
@testset "SIMD - $N x $T" for N in simd_ns, T in float_types
170247
# codegen emits i48 here, which SPIR-V doesn't support
171248
# XXX: fix upstream?

0 commit comments

Comments
 (0)