Skip to content

[wasm][aot] Optimize 64 bit const shuffles. Otherwise prefer vector swizzle. #115351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 48 additions & 9 deletions src/mono/mono/mini/mini-llvm.c
Original file line number Diff line number Diff line change
Expand Up @@ -10449,24 +10449,63 @@ MONO_RESTORE_WARNING
break;
}
case OP_WASM_SIMD_SWIZZLE: {
LLVMValueRef bidx = LLVMBuildBitCast (builder, rhs, LLVMVectorType (i1_t, 16), "");
LLVMValueRef bidx = rhs;
int nelems = LLVMGetVectorSize (LLVMTypeOf (lhs));
if (nelems < 16) {
int stride = 16 / nelems;

if (nelems != 16) {
if (LLVMIsConstant(rhs) && nelems == 2) {
LLVMValueRef indexes [16];
for (int i = 0; i < nelems; ++i)
indexes [i] = LLVMBuildExtractElement (builder, rhs, const_int32 (i), "");

LLVMValueRef shuffle_val = LLVMConstNull (LLVMVectorType (i4_t, nelems));

for (int i = 0; i < nelems; ++i)
shuffle_val = LLVMBuildInsertElement (builder, shuffle_val, convert (ctx, indexes [i], i4_t), const_int32 (i), "");
values [ins->dreg] = LLVMBuildShuffleVector (builder, lhs, LLVMGetUndef (LLVMTypeOf (lhs)), shuffle_val, "");
break;
}

// clamp each index to the lowest invalid value (nelems)
// so it will remain invalid but won't overflow to a valid
// value during the multiply and add below
LLVMTypeRef idx_t = LLVMTypeOf (bidx);
LLVMTypeRef elem_t = LLVMGetElementType (idx_t);
LLVMValueRef minv = broadcast_constant (nelems, elem_t, nelems);
LLVMValueRef cmp = LLVMBuildICmp (builder, LLVMIntULT, bidx, minv, "");
bidx = LLVMBuildSelect (builder, cmp, bidx, minv, "");

// cast indices to i8x16
bidx = LLVMBuildBitCast (builder, bidx, LLVMVectorType (i1_t, 16), "");

// build our offset and fill constant vectors
// fill is used to copy the index value to every byte in the lane
// offset is used to add the position of each byte with in a lane
int shift = nelems == 8 ? 1 : (nelems == 4 ? 2 : 3);
LLVMValueRef fill = LLVMConstNull (LLVMVectorType (i1_t, 16));
LLVMValueRef offset = LLVMConstNull (LLVMVectorType (i1_t, 16));
int stride = 16 / nelems;
for (int i = 0; i < nelems; ++i) {
LLVMValueRef fills [16];
LLVMValueRef offsets [16];
for (int i = 0, k = 0; i < nelems; i++, k += stride) {
for (int j = 0; j < stride; ++j) {
offset = LLVMBuildInsertElement (builder, offset, const_int8 (j), const_int8 (i * stride + j), "");
fill = LLVMBuildInsertElement (builder, fill, const_int8 (i * stride), const_int8 (i * stride + j), "");
offsets[k + j] = const_int8 (j);
fills[k + j] = const_int8 (k);
}
}
LLVMValueRef fill = LLVMConstVector (fills, 16);
LLVMValueRef offset = LLVMConstVector (offsets, 16);

// multiply the indices by the stride using bidx << shift
// llvm should optimize the below to an i8x16 shl intrinsic
LLVMValueRef shiftv = create_shift_vector (ctx, bidx, const_int32 (shift));
bidx = LLVMBuildShl (builder, bidx, shiftv, "");

// copy the shifted value to every byte of a lane via swizzle
LLVMValueRef args [] = { bidx, fill };
bidx = call_intrins (ctx, INTRINS_WASM_SWIZZLE, args, "");
bidx = LLVMBuildAdd (builder, bidx, offset, "");

// add the byte offset to each byte in the lane using bidx | offset
// we can use OR instead of ADD here because the low bits are 0 due to <<
bidx = LLVMBuildOr (builder, bidx, offset, "");
}
LLVMValueRef lhs_b = LLVMBuildBitCast (builder, lhs, LLVMVectorType (i1_t, 16), "");
LLVMValueRef args [] = { lhs_b, bidx };
Expand Down
Loading