|
856 | 856 | return nothing |
857 | 857 | end |
858 | 858 |
|
| 859 | + |
| 860 | +@register_fwd function genericmemory_slice_fwd(B, orig, gutils, normalR, shadowR) |
| 861 | + ctx = LLVM.context(orig) |
| 862 | + |
| 863 | + if is_constant_value(gutils, orig) || unsafe_load(shadowR) == C_NULL |
| 864 | + return true |
| 865 | + end |
| 866 | + |
| 867 | + origops = LLVM.operands(orig) |
| 868 | + |
| 869 | + width = get_width(gutils) |
| 870 | + |
| 871 | + shadowin = invert_pointer(gutils, origops[1], B) |
| 872 | + shadowdata = invert_pointer(gutils, origops[2], B) |
| 873 | + len = new_from_original(gutils, origops[3]) |
| 874 | + |
| 875 | + i8 = LLVM.IntType(8) |
| 876 | + algn = 0 |
| 877 | + |
| 878 | + shadowres = |
| 879 | + UndefValue(LLVM.LLVMType(API.EnzymeGetShadowType(width, value_type(orig)))) |
| 880 | + for idx = 1:width |
| 881 | + ev = if width == 1 |
| 882 | + shadowin |
| 883 | + else |
| 884 | + extract_value!(B, shadowin, idx - 1) |
| 885 | + end |
| 886 | + ev2 = if width == 1 |
| 887 | + shadowdata |
| 888 | + else |
| 889 | + extract_value!(B, shadowdata, idx - 1) |
| 890 | + end |
| 891 | + callv = call_samefunc_with_inverted_bundles!( |
| 892 | + B, |
| 893 | + gutils, |
| 894 | + orig, |
| 895 | + [ev, ev2, len], |
| 896 | + [API.VT_Shadow, API.VT_Shadow, API.VT_Primal], |
| 897 | + false, |
| 898 | + ) #=lookup=# |
| 899 | + if is_constant_value(gutils, origops[1]) |
| 900 | + emit_error(B, orig, "ijl_genericmemory_slice memory argument (1st arg) was constant but return was active") |
| 901 | + end |
| 902 | + if is_constant_value(gutils, origops[2]) |
| 903 | + emit_error(B, orig, "ijl_genericmemory_slice ptr argument (2nd arg) was constant but return was active") |
| 904 | + end |
| 905 | + if get_runtime_activity(gutils) |
| 906 | + prev = new_from_original(gutils, orig) |
| 907 | + callv = LLVM.select!( |
| 908 | + B, |
| 909 | + LLVM.icmp!( |
| 910 | + B, |
| 911 | + LLVM.API.LLVMIntNE, |
| 912 | + ev, |
| 913 | + new_from_original(gutils, origops[1]), |
| 914 | + ), |
| 915 | + callv, |
| 916 | + prev, |
| 917 | + ) |
| 918 | + if idx == 1 |
| 919 | + API.moveBefore(prev, callv, B) |
| 920 | + end |
| 921 | + end |
| 922 | + shadowres = if width == 1 |
| 923 | + callv |
| 924 | + else |
| 925 | + insert_value!(B, shadowres, callv, idx - 1) |
| 926 | + end |
| 927 | + end |
| 928 | + |
| 929 | + unsafe_store!(shadowR, shadowres.ref) |
| 930 | + return false |
| 931 | +end |
| 932 | + |
| 933 | +@register_aug function genericmemory_slice_augfwd(B, orig, gutils, normalR, shadowR, tapeR) |
| 934 | + return genericmemory_slice_fwd(B, orig, gutils, normalR, shadowR) |
| 935 | +end |
| 936 | + |
| 937 | +@register_rev function genericmemory_slice_rev(B, orig, gutils, tape) |
| 938 | + return nothing |
| 939 | +end |
| 940 | + |
859 | 941 | @register_fwd function arrayreshape_fwd(B, orig, gutils, normalR, shadowR) |
860 | 942 | if is_constant_value(gutils, orig) && is_constant_inst(gutils, orig) |
861 | 943 | return true |
@@ -2111,6 +2193,12 @@ end |
2111 | 2193 | @revfunc(genericmemory_copy_slice_rev), |
2112 | 2194 | @fwdfunc(genericmemory_copy_slice_fwd), |
2113 | 2195 | ) |
| 2196 | + register_handler!( |
| 2197 | + ("jl_genericmemory_slice", "ijl_genericmemory_slice"), |
| 2198 | + @augfunc(genericmemory_slice_augfwd), |
| 2199 | + @revfunc(genericmemory_slice_rev), |
| 2200 | + @fwdfunc(genericmemory_slice_fwd), |
| 2201 | + ) |
2114 | 2202 | register_handler!( |
2115 | 2203 | ("jl_reshape_array", "ijl_reshape_array"), |
2116 | 2204 | @augfunc(arrayreshape_augfwd), |
|
0 commit comments