@@ -693,6 +693,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
693693
694694 if mode == API. DEM_ForwardMode && (used || idx != 0 )
695695 # Zero any jlvalue_t inner elements of preceeding allocation.
696+
696697 # Specifically in forward mode, you will first run the original allocation,
697698 # then all shadow allocations. These allocations will thus all run before
698699 # any value may store into them. For example, as follows:
@@ -703,12 +704,38 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
703704 # As a result, by the time of the subsequent GC allocation, the memory in the preceeding
704705 # allocation might be undefined, and trigger a GC error. To avoid this,
705706 # we will explicitly zero the GC'd fields of the previous allocation.
707+
708+ # Reverse mode will do similarly, except doing the shadow first
706709 prev = LLVM. Instruction (prev)
707710 B = LLVM. IRBuilder ()
708711 position! (B, LLVM. Instruction (LLVM. API. LLVMGetNextInstruction (prev)))
709712
710713 create_recursive_stores (B, Ty, prev)
711714 end
715+ if (mode == API. DEM_ReverseModePrimal || mode == API. DEM_ReverseModeCombined) && used
716+ # Zero any jlvalue_t inner elements of preceeding allocation.
717+
718+ # Specifically in reverse mode, you will run the original allocation,
719+ # after all shadow allocations. The shadow allocations will thus all run before any value may store into them. For example, as follows:
720+ # %"orig'" = julia.gcalloc(...)
721+ # %orig = julia.gc_alloc(...)
722+ # store "orig'"[0] = jlvaluet'
723+ # store orig[0] = jlvaluet
724+ #
725+ # Normally this is fine, since we will memset right after the shadow
726+ # however we will do this memset non atomically and if you have a case like the following, there will be an issue
727+
728+ # %"orig'" = julia.gcalloc(...)
729+ # memset("orig'")
730+ # %orig = julia.gc_alloc(...)
731+ # store "orig'"[0] = jlvaluet'
732+ # store orig[0] = jlvaluet
733+ #
734+ # Julia could decide to dead store eliminate the memset (not being read before the store of jlvaluet'), resulting in an error
735+ B = LLVM. IRBuilder ()
736+ position! (B, LLVM. Instruction (LLVM. API. LLVMGetNextInstruction (V)))
737+ create_recursive_stores (B, Ty, V)
738+ end
712739
713740 nothing
714741end
0 commit comments