Description
Our current validator refuses non-isbitstype arguments, with the exception of arguments whose type passes the `Core.Compiler.isconstType` test. This makes it possible to, e.g., broadcast types, as these arguments are only used to specialize the kernel and are not actually used by the generated code (even though they are passed, as opposed to ghost/singleton values).
In JuliaGPU/CUDA.jl#2514, it was noted that some code (notably closure-heavy code generated by Zygote) still refuses to compile, even though the generated code doesn't actually use the non-isbits value. For example:
# Parametric struct with a single field `a` of type `T`.
struct Bar{T}
a::T
end
# Reproducer: launches a kernel whose first argument is a closure that
# captures the type `Bar` (bound to the local `capture`).
function main()
a = cu(zeros(5))
# Bind the type itself (not an instance) so the closure below captures it.
capture = Bar
# Closure over `capture`; calling it invokes `capture(arg)`.
function closure(arg)
capture(arg)
end
# Kernel: applies `f` to `x[]` and returns without a value.
function kernel(f, x)
f(x[])
return
end
@cuda kernel(closure, a)
end
The problem here is that the closure captures the type, making the closure non-isbits too. But because the closure is not a const type, we fail compilation, even though the generated code is perfectly fine:
define ptx_kernel void @_Z6kernel7closureI4TypeI3BazI1TEEE13CuDeviceArrayI7Float32Li1ELi1EE({ i64, i32 } %state, [1 x {}*] %0, { i8 addrspace(1)*, i64, [1 x i64], i64 } %1) local_unnamed_addr {
conversion:
ret void
}
Note how the closure argument does really contain a managed pointer. In this case, we can work around the issue by reviving the more lenient validation removed in #24, where we not only checked for `Core.Compiler.isconstType`, but also whether the value is actually used:
diff --git a/src/driver.jl b/src/driver.jl
index 9e05eb6..a4cff8f 100644
--- a/src/driver.jl
+++ b/src/driver.jl
@@ -88,8 +88,7 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
end
@timeit_debug to "Validation" begin
- check_method(job) # not optional
- validate && check_invocation(job)
+ check_method(job)
end
prepare_job!(job)
@@ -99,6 +98,10 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
ir, ir_meta = emit_llvm(job; libraries, toplevel, optimize, cleanup, only_entry, validate)
+ validate && @timeit_debug to "Validation" begin
+ check_invocation(job, ir_meta.entry)
+ end
+
if output == :llvm
if strip
@timeit_debug to "strip debug info" strip_debuginfo!(ir)
diff --git a/src/validation.jl b/src/validation.jl
index e1a355b..9f1f869 100644
--- a/src/validation.jl
+++ b/src/validation.jl
@@ -66,7 +66,7 @@ function explain_nonisbits(@nospecialize(dt), depth=1; maxdepth=10)
return msg
end
-function check_invocation(@nospecialize(job::CompilerJob))
+function check_invocation(@nospecialize(job::CompilerJob), entry::LLVM.Function)
sig = job.source.specTypes
ft = sig.parameters[1]
tt = Tuple{sig.parameters[2:end]...}
@@ -77,6 +77,9 @@ function check_invocation(@nospecialize(job::CompilerJob))
real_arg_i = 0
for (arg_i,dt) in enumerate(sig.parameters)
+ println(Core.stdout, arg_i)
+ println(Core.stdout, dt)
+
isghosttype(dt) && continue
Core.Compiler.isconstType(dt) && continue
real_arg_i += 1
@@ -89,9 +92,13 @@ function check_invocation(@nospecialize(job::CompilerJob))
end
if !isbitstype(dt)
- throw(KernelError(job, "passing and using non-bitstype argument",
- """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
- $(explain_nonisbits(dt))"""))
+ param = parameters(entry)[real_arg_i]
+ if !isempty(uses(param))
+ println(Core.stdout, string(entry))
+ throw(KernelError(job, "passing and using non-bitstype argument",
+ """Argument $arg_i to your kernel function is of type $dt, which is not isbits:
+ $(explain_nonisbits(dt))"""))
+ end
end
end
Sadly, this approach is insufficient for more complex cases such as:
# Parametric struct with two fields, `a` and `b`, both of type `T`.
struct Bar{T}
a::T
b::T
end
# Reproducer for the broadcast case: broadcasts a closure that wraps the
# type `Bar{Float32}` over two arrays created with `cu`.
function main2()
# `foo(f)` returns an anonymous varargs function that forwards its arguments to `f`.
foo(f) = (args...) -> f(args...)
# NOTE(review): `d` is assigned here but not used below; the broadcast calls `foo(c)` again.
a = cu(zeros(5)); b = cu(ones(5)); c = Bar{Float32}; d = foo(c)
# Broadcast the wrapped constructor over `a` and `b`.
foo(c).(a, b)
end
Argument 4 to your kernel function is of type Base.Broadcast.Broadcasted{CUDA.CuArrayStyle{1, CUDA.DeviceMemory}, Tuple{Base.OneTo{Int64}}, var"#3#5"{Type{Bar{Float32}}}, Tuple{Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}, Base.Broadcast.Extruded{CuDeviceVector{Float32, 1}, Tuple{Bool}, Tuple{Int64}}}}, which is not isbits:
.f is of type var"#3#5"{Type{Bar{Float32}}} which is not isbits.
.f is of type Type{Bar{Float32}} which is not isbits.
define ptx_kernel void @_Z3_3415CuKernelContext13CuDeviceArrayI15BrokenBroadcastI3AnyELi1ELi1EE11BroadcastedI12CuArrayStyleILi1E12DeviceMemoryE5TupleI5OneToI5Int64EE2_3I4TypeI3BarI7Float32EEES9_I8ExtrudedIS0_ISH_Li1ELi1EES9_I4BoolES9_ISB_EESQ_EESB_({ i64, i32 } %state, { i8 addrspace(1)*, i64, [1 x i64], i64 } %0, { [1 x {}*], [2 x { { i8 addrspace(1)*, i64, [1 x i64], i64 }, [1 x i8], [1 x i64] }], [1 x [1 x i64]] } %1, i64 signext %2) local_unnamed_addr {
conversion:
%.fca.3.extract = extractvalue { i8 addrspace(1)*, i64, [1 x i64], i64 } %0, 3
%.not7 = icmp slt i64 %2, 1
br i1 %.not7, label %common.ret, label %L5.lr.ph
L5.lr.ph: ; preds = %conversion
%3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%4 = add nuw nsw i32 %3, 1
%5 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%6 = zext i32 %5 to i64
%7 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%8 = zext i32 %7 to i64
%9 = mul nuw nsw i64 %6, %8
%10 = zext i32 %4 to i64
%11 = add nuw nsw i64 %9, %10
%12 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
%13 = mul i32 %12, %7
%14 = sext i32 %13 to i64
br label %L5
L5: ; preds = %L5, %L5.lr.ph
%value_phi8 = phi i64 [ 1, %L5.lr.ph ], [ %20, %L5 ]
%15 = add i64 %value_phi8, -1
%16 = mul i64 %15, %14
%17 = add i64 %11, %16
%18 = icmp slt i64 %17, 1
%19 = icmp sgt i64 %17, %.fca.3.extract
%spec.select = select i1 %18, i1 true, i1 %19
%20 = add i64 %value_phi8, 1
%.not = icmp sgt i64 %20, %2
%or.cond = select i1 %spec.select, i1 true, i1 %.not
br i1 %or.cond, label %common.ret, label %L5
common.ret: ; preds = %L5, %conversion
ret void
}
Note how the non-isbits Broadcasted argument is used, so it also fails the more lenient validation check, but it's just not the managed pointer that's being used.
I'm not sure how to proceed with this. Simply removing the validation and trusting that other aspects of IR validation will error seems too optimistic -- IIRC we introduced this check to prevent accidentally reading CPU memory from the GPU. And actually detecting whether the managed pointer field is the one being used seems hard.
I'm also not sure how important that is; we've not received many bug reports about this, and the motivating example by @BioTurboNick would simply fail after validation anyway because it involves a broken broadcast (producing `Any` values). So maybe this isn't very important.