Skip to content

Commit df85909

Browse files
authored
Generic memory slice (#2234)
* Generic memory slice * fix
1 parent f46e44d commit df85909

File tree

2 files changed

+96
-1
lines changed

2 files changed

+96
-1
lines changed

src/llvm/attributes.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ const nofreefns = Set{String}((
4141
"jl_array_ptr_copy",
4242
"ijl_array_copy",
4343
"jl_array_copy",
44+
"ijl_genericmemory_slice",
45+
"jl_genericmemory_slice",
4446
"ijl_genericmemory_copy_slice",
4547
"jl_genericmemory_copy_slice",
4648
"ijl_get_nth_field_checked",
@@ -644,6 +646,8 @@ function annotate!(mod::LLVM.Module)
644646
"ijl_alloc_array_3d",
645647
"jl_array_copy",
646648
"ijl_array_copy",
649+
"jl_genericmemory_slice",
650+
"ijl_genericmemory_slice",
647651
"jl_genericmemory_copy_slice",
648652
"ijl_genericmemory_copy_slice",
649653
"jl_alloc_genericmemory",
@@ -670,8 +674,11 @@ function annotate!(mod::LLVM.Module)
670674
LLVM.EnumAttribute("inaccessiblememonly")
671675
else
672676
if fname in (
677+
"jl_genericmemory_slice",
678+
"ijl_genericmemory_slice",
673679
"jl_genericmemory_copy_slice",
674-
"ijl_genericmemory_copy_slice",)
680+
"ijl_genericmemory_copy_slice",
681+
)
675682
EnumAttribute(
676683
"memory",
677684
MemoryEffect(

src/rules/llvmrules.jl

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -856,6 +856,88 @@ end
856856
return nothing
857857
end
858858

859+
860+
# Forward-mode rule for `jl_genericmemory_slice` / `ijl_genericmemory_slice`.
#
# For each `idx in 1:width`, re-issues the original call on the shadows of the first
# two operands (the memory and the pointer) together with the primal third operand,
# and writes the combined shadow result through `shadowR`.
#
# Returns `true` without emitting anything when `orig` is a constant value or the
# value loaded from `shadowR` is `C_NULL`; otherwise stores the shadow and returns
# `false`. `normalR` is not used by this rule.
@register_fwd function genericmemory_slice_fwd(B, orig, gutils, normalR, shadowR)
    if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL
        return true
    end

    origops = LLVM.operands(orig)
    width = get_width(gutils)

    shadowin = invert_pointer(gutils, origops[1], B)
    shadowdata = invert_pointer(gutils, origops[2], B)
    len = new_from_original(gutils, origops[3])

    # Loop-invariant queries: evaluate them once rather than on every iteration.
    mem_is_const = is_constant_value(gutils, origops[1])
    ptr_is_const = is_constant_value(gutils, origops[2])
    runtime_activity = get_runtime_activity(gutils)

    shadowres =
        UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig))))
    for idx = 1:width
        # With width == 1 the shadows are plain values; otherwise they are aggregates
        # from which the idx-th element is extracted.
        ev = width == 1 ? shadowin : extract_value!(B, shadowin, idx - 1)
        ev2 = width == 1 ? shadowdata : extract_value!(B, shadowdata, idx - 1)

        callv = call_samefunc_with_inverted_bundles!(
            B,
            gutils,
            orig,
            [ev, ev2, len],
            [API.VT_Shadow, API.VT_Shadow, API.VT_Primal],
            false, #=lookup=#
        )

        # The return is active here (see the early return above), so a constant
        # memory or pointer operand is reported as an error.
        if mem_is_const
            emit_error(B, orig, "ijl_genericmemory_slice memory argument (1st arg) was constant but return was active")
        end
        if ptr_is_const
            emit_error(B, orig, "ijl_genericmemory_slice ptr argument (2nd arg) was constant but return was active")
        end

        if runtime_activity
            # Use the freshly computed shadow only when the shadow memory differs
            # from the primal memory; otherwise fall back to the primal result.
            prev = new_from_original(gutils, orig)
            callv = LLVM.select!(
                B,
                LLVM.icmp!(
                    B,
                    LLVM.API.LLVMIntNE,
                    ev,
                    new_from_original(gutils, origops[1]),
                ),
                callv,
                prev,
            )
            # NOTE(review): presumably repositions the primal call ahead of its use
            # in the select; confirm against `API.moveBefore`.
            if idx == 1
                API.moveBefore(prev, callv, B)
            end
        end

        shadowres = if width == 1
            callv
        else
            insert_value!(B, shadowres, callv, idx - 1)
        end
    end

    unsafe_store!(shadowR, shadowres.ref)
    return false
end
932+
933+
# Augmented-forward rule for `jl_genericmemory_slice` / `ijl_genericmemory_slice`.
#
# Delegates entirely to the forward rule; `tapeR` is left untouched, so nothing is
# recorded for the reverse pass.
@register_aug function genericmemory_slice_augfwd(B, orig, gutils, normalR, shadowR, tapeR)
    result = genericmemory_slice_fwd(B, orig, gutils, normalR, shadowR)
    return result
end
936+
937+
# Reverse rule for `jl_genericmemory_slice` / `ijl_genericmemory_slice`.
#
# Emits no instructions and ignores `tape`: the augmented-forward rule reuses the
# forward rule, which already produces the shadow.
@register_rev function genericmemory_slice_rev(B, orig, gutils, tape)
    return nothing
end
940+
859941
@register_fwd function arrayreshape_fwd(B, orig, gutils, normalR, shadowR)
860942
if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig)
861943
return true
@@ -2111,6 +2193,12 @@ end
21112193
@revfunc(genericmemory_copy_slice_rev),
21122194
@fwdfunc(genericmemory_copy_slice_fwd),
21132195
)
2196+
register_handler!(
2197+
("jl_genericmemory_slice", "ijl_genericmemory_slice"),
2198+
@augfunc(genericmemory_slice_augfwd),
2199+
@revfunc(genericmemory_slice_rev),
2200+
@fwdfunc(genericmemory_slice_fwd),
2201+
)
21142202
register_handler!(
21152203
("jl_reshape_array", "ijl_reshape_array"),
21162204
@augfunc(arrayreshape_augfwd),

0 commit comments

Comments
 (0)