Skip to content

Commit 87adde2

Browse files
authored
1.12: Compute types before ival (#2928)
1 parent 07d7b9e commit 87adde2

File tree

1 file changed

+38
-39
lines changed

1 file changed

+38
-39
lines changed

src/rules/customrules.jl

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -574,45 +574,6 @@ function enzyme_custom_setup_args(
574574
push!(actives, op)
575575
else
576576

577-
ival = nothing
578-
roots_ival = nothing
579-
if B !== nothing
580-
ival = if is_constant_value(gutils, op)
581-
@assert orig_activep != activep
582-
@assert orig_activep == API.DFT_CONSTANT
583-
if val == nothing
584-
load!(B, iarty, ogval)
585-
else
586-
val
587-
end
588-
else
589-
invert_pointer(gutils, op, B)
590-
end
591-
@assert ival !== nothing
592-
593-
uncache_arg = uncacheable[arg.codegen.i] != 0
594-
if roots_op !== nothing
595-
uncache_arg |= uncacheable[arg.codegen.i + 1] != 0
596-
end
597-
if uncache_arg
598-
# TODO we will are not restoring the bits_ref data of the
599-
# shadow value (though now we are at least doing so properly for primal)
600-
# x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/2304
601-
end
602-
603-
if reverse && !is_constant_value(gutils, op)
604-
ival = lookup_value(gutils, ival, B)
605-
end
606-
if roots_op !== nothing
607-
roots_ival = invert_pointer(gutils, roots_op, B)
608-
if reverse
609-
roots_ival = lookup_value(gutils, roots_ival, B)
610-
end
611-
end
612-
@assert ival !== nothing
613-
end
614-
615-
616577
shadowty = arg.typ
617578
mixed = false
618579
if width == 1
@@ -654,6 +615,44 @@ function enzyme_custom_setup_args(
654615
if mixed
655616
@assert arg.cc == GPUCompiler.BITS_REF
656617
end
618+
619+
ival = nothing
620+
roots_ival = nothing
621+
if B !== nothing
622+
ival = if is_constant_value(gutils, op)
623+
@assert orig_activep != activep
624+
@assert orig_activep == API.DFT_CONSTANT
625+
if val == nothing
626+
load!(B, iarty, ogval)
627+
else
628+
val
629+
end
630+
else
631+
invert_pointer(gutils, op, B)
632+
end
633+
@assert ival !== nothing
634+
635+
uncache_arg = uncacheable[arg.codegen.i] != 0
636+
if roots_op !== nothing
637+
uncache_arg |= uncacheable[arg.codegen.i + 1] != 0
638+
end
639+
if uncache_arg
640+
# TODO we will are not restoring the bits_ref data of the
641+
# shadow value (though now we are at least doing so properly for primal)
642+
# x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/2304
643+
end
644+
645+
if reverse && !is_constant_value(gutils, op)
646+
ival = lookup_value(gutils, ival, B)
647+
end
648+
if roots_op !== nothing
649+
roots_ival = invert_pointer(gutils, roots_op, B)
650+
if reverse
651+
roots_ival = lookup_value(gutils, roots_ival, B)
652+
end
653+
end
654+
@assert ival !== nothing
655+
end
657656

658657
if B !== nothing
659658
@assert ival !== nothing

0 commit comments

Comments
 (0)