Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@
*/
public abstract class VectorArchitecture {

/** Distinguishes between vector compress and expand operation support. */
public enum CompressExpandOp {
COMPRESS,
EXPAND
}

/**
* The stride (in bytes) for vectors of ordinary object pointers in memory. That is, this is the
* compressed reference size if compressed references are enabled.
Expand Down Expand Up @@ -365,11 +371,12 @@ protected int getSupportedVectorLogicLengthHelper(LogicNode logicNode, int maxLe
/**
* Get the maximum supported vector length for a vector compress/expand based on a mask.
*
* @param elementStamp the stamp of the elements to be blended
* @param elementStamp the stamp of the elements to be compressed/expanded
* @param maxLength the maximum length to return
* @param op the operation (compress or expand)
* @return the number of elements that can be compressed/expanded by a single instruction
*/
public abstract int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength);
public abstract int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength, CompressExpandOp op);

/**
* Determine the minimum alignment in bytes that is guaranteed for objects.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ public int getSupportedVectorBlendLength(Stamp elementStamp, int maxLength) {
}

@Override
public int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength) {
public int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength, CompressExpandOp op) {
return 1;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -611,16 +611,29 @@ private boolean isImpossibleLongToDoubleConversion(Stamp result, Stamp input) {
}

@Override
public int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength) {
public int getSupportedVectorCompressExpandLength(Stamp elementStamp, int maxLength, CompressExpandOp op) {
if (!hasMinimumVectorizationRequirements(maxLength)) {
return 1;
}

AVXSize avxSize = compressExpandOps.getSupportedAVXSize(elementStamp, maxLength);
int supportedLength = getSupportedVectorLength(elementStamp, maxLength, avxSize);
if (op == CompressExpandOp.COMPRESS && supportedLength == 1 && supportsByteCompressFallback(elementStamp)) {
/*
* AVX byte-compress fallback: emulate byte compress with shuffle-based code paths.
*/
supportedLength = getSupportedVectorLength(elementStamp, maxLength, getMaxSupportedAVXSize(arch.getFeatures()));
}
return Math.min(supportedLength, maxLength);
}

private boolean supportsByteCompressFallback(Stamp elementStamp) {
return elementStamp instanceof IntegerStamp integerStamp &&
integerStamp.getBits() == Byte.SIZE &&
arch.getFeatures().contains(CPUFeature.AVX2) &&
arch.getFeatures().contains(CPUFeature.POPCNT);
}

@Override
public int getObjectAlignment() {
return objectAlignment;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@
import jdk.graal.compiler.lir.amd64.AMD64AddressValue;
import jdk.graal.compiler.lir.amd64.AMD64Binary;
import jdk.graal.compiler.lir.amd64.AMD64Move;
import jdk.graal.compiler.lir.amd64.vector.AVXByteCompress;
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorBinary;
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorBlend;
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorClearOp;
Expand Down Expand Up @@ -2076,6 +2077,24 @@ public static AMD64Assembler.VexOp getMaskedOpcode(AMD64 arch, MaskedOpMetaData

@Override
public Variable emitVectorCompress(LIRKind resultKind, Value source, Value mask) {
AMD64Kind kind = (AMD64Kind) resultKind.getPlatformKind();
AVXSize size = AVXKind.getRegisterSize(kind);
if (kind.getScalar() == AMD64Kind.BYTE &&
!supports(AMD64.CPUFeature.AVX512_VBMI2) &&
supports(AMD64.CPUFeature.AVX2) &&
supports(AMD64.CPUFeature.POPCNT)) {
/*
* VPCOMPRESSB (native byte compress) requires AVX512_VBMI2. Without it, byte compress
* must be emulated with the AVX2 shuffle-based fallback.
*/
Variable result = getLIRGen().newVariable(resultKind);
AMD64Kind maskKind = size == AVXSize.ZMM ? AMD64Kind.QWORD : AMD64Kind.DWORD;
Value scalarMask = emitMoveOpMaskToInteger(LIRKind.value(maskKind), mask, kind.getVectorLength());
getLIRGen().append(new AVXByteCompress.CompressBytesWithMaskOp(getLIRGen(), asAllocatable(result), asAllocatable(source), asAllocatable(scalarMask)));
return result;
}

GraalError.guarantee(supports(AMD64.CPUFeature.AVX512_VBMI2), "compress without fallback requires AVX512_VBMI2");
Variable result = getLIRGen().newVariable(resultKind);
getLIRGen().append(new AVX512CompressExpand.CompressOp(result, asAllocatable(source), asAllocatable(mask)));
return result;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorMove;
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorShuffle;
import jdk.graal.compiler.lir.amd64.vector.AMD64VectorUnary;
import jdk.graal.compiler.lir.amd64.vector.AVXByteCompress;
import jdk.graal.compiler.lir.asm.ArrayDataPointerConstant;
import jdk.graal.compiler.vector.nodes.simd.SimdConstant;
import jdk.graal.compiler.vector.nodes.simd.SimdStamp;
Expand Down Expand Up @@ -1463,6 +1464,15 @@ public Variable emitVectorOpMaskOrTestMove(Value left, Value right, boolean allZ

@Override
public Variable emitVectorCompress(LIRKind resultKind, Value source, Value mask) {
AMD64Kind kind = (AMD64Kind) resultKind.getPlatformKind();
AVXSize size = AVXKind.getRegisterSize(kind);
if (kind.getScalar() == AMD64Kind.BYTE && supports(CPUFeature.AVX2) && supports(CPUFeature.POPCNT) && (size == AVXSize.XMM || size == AVXSize.YMM)) {
Variable result = getLIRGen().newVariable(resultKind);
Variable scalarMask = getLIRGen().newVariable(LIRKind.value(AMD64Kind.DWORD));
getLIRGen().append(new AMD64VectorUnary.AVXUnaryRROp(VPMOVMSKB, size, scalarMask, asAllocatable(mask)));
getLIRGen().append(new AVXByteCompress.CompressBytesWithMaskOp(getLIRGen(), asAllocatable(result), asAllocatable(source), asAllocatable(scalarMask)));
return result;
}
throw GraalError.shouldNotReachHere("AVX/AVX2 does not support compress/expand");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ protected SimdCompressNode(SimdStamp stamp, ValueNode src, ValueNode mask) {
SimdStamp srcStamp = (SimdStamp) src.stamp(NodeView.DEFAULT);
SimdStamp maskStamp = (SimdStamp) mask.stamp(NodeView.DEFAULT);
GraalError.guarantee(stamp.isCompatible(srcStamp), "%s - %s", stamp, src);
GraalError.guarantee(maskStamp.getComponent(0) instanceof LogicValueStamp, "%s", mask);
GraalError.guarantee(maskStamp.isMask(), "%s", mask);
GraalError.guarantee(srcStamp.getVectorLength() == maskStamp.getVectorLength(), "%s - %s", src, mask);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ protected SimdExpandNode(SimdStamp stamp, ValueNode src, ValueNode mask) {
SimdStamp srcStamp = (SimdStamp) src.stamp(NodeView.DEFAULT);
SimdStamp maskStamp = (SimdStamp) mask.stamp(NodeView.DEFAULT);
GraalError.guarantee(stamp.isCompatible(srcStamp), "%s - %s", stamp, src);
GraalError.guarantee(maskStamp.getComponent(0) instanceof LogicValueStamp, "%s", mask);
GraalError.guarantee(maskStamp.isMask(), "%s", mask);
GraalError.guarantee(srcStamp.getVectorLength() == maskStamp.getVectorLength(), "%s - %s", src, mask);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@
import jdk.graal.compiler.nodes.spi.CanonicalizerTool;
import jdk.graal.compiler.nodes.spi.CoreProviders;
import jdk.graal.compiler.vector.architecture.VectorArchitecture;
import jdk.graal.compiler.vector.nodes.simd.SimdConcatNode;
import jdk.graal.compiler.vector.nodes.simd.SimdConstant;
import jdk.graal.compiler.vector.nodes.simd.SimdCutNode;
import jdk.graal.compiler.vector.nodes.simd.SimdStamp;
import jdk.graal.compiler.vector.replacements.vectorapi.VectorAPIOperations;
import jdk.vm.ci.meta.JavaConstant;
Expand Down Expand Up @@ -195,22 +197,21 @@ public boolean canExpand(VectorArchitecture vectorArch, EconomicMap<ValueNode, S
boolean supportedDirectly = vectorArch.getSupportedVectorShiftWithScalarCount(elementStamp, vectorLength, op) == vectorLength;
if (supportedDirectly) {
return true;
} else {
/*
* Special case for byte shifts on AMD64: See if we can extend to shorts, shift, and
* narrow back to bytes. Not relevant for AArch64, which has native byte shifts and
* takes the "supportedDirectly" path above.
*/
if (PrimitiveStamp.getBits(elementStamp) == Byte.SIZE) {
IntegerStamp byteStamp = (IntegerStamp) elementStamp;
IntegerStamp shortStamp = StampFactory.forInteger(Short.SIZE);
ArithmeticOpTable.IntegerConvertOp<?> extend = (op.equals(byteStamp.getOps().getUShr()) ? byteStamp.getOps().getZeroExtend() : byteStamp.getOps().getSignExtend());
return vectorArch.getSupportedVectorConvertLength(shortStamp, byteStamp, vectorLength, extend) == vectorLength &&
vectorArch.getSupportedVectorShiftWithScalarCount(shortStamp, vectorLength, op) == vectorLength &&
vectorArch.getSupportedVectorConvertLength(byteStamp, shortStamp, vectorLength, shortStamp.getOps().getNarrow()) == vectorLength;
}
}
/*
* Special case for byte shifts on backends without direct full-width byte shift support:
* first try widen->shift->narrow for the full vector, then fall back to doing the same
* transform on two half vectors and concatenating the result. AArch64 typically takes the
* direct path above because it has native byte shifts.
*/
if (PrimitiveStamp.getBits(elementStamp) != Byte.SIZE) {
return false;
}
IntegerStamp byteStamp = (IntegerStamp) elementStamp;
if (canExpandByteShiftViaWidening(vectorArch, byteStamp, vectorLength)) {
return true;
}
return canExpandByteShiftViaLaneSplit(vectorArch, byteStamp, vectorLength);
}

@Override
Expand All @@ -225,15 +226,53 @@ public ValueNode expand(VectorArchitecture vectorArch, NodeMap<ValueNode> expand
boolean supportedDirectly = vectorArch.getSupportedVectorShiftWithScalarCount(elementStamp, vectorLength, op) == vectorLength;
if (supportedDirectly) {
return ShiftNode.shiftOp(value, shiftAmount, NodeView.DEFAULT, op);
} else {
GraalError.guarantee(PrimitiveStamp.getBits(elementStamp) == Byte.SIZE, "unexpected stamp: %s", elementStamp);
IntegerStamp byteStamp = (IntegerStamp) elementStamp;
ValueNode extendedVector = (op.equals(byteStamp.getOps().getUShr())
? ZeroExtendNode.create(value, Byte.SIZE, Short.SIZE, NodeView.DEFAULT)
: SignExtendNode.create(value, Byte.SIZE, Short.SIZE, NodeView.DEFAULT));
ValueNode shiftedVector = ShiftNode.shiftOp(extendedVector, shiftAmount, NodeView.DEFAULT, op);
ValueNode narrowedVector = NarrowNode.create(shiftedVector, Short.SIZE, Byte.SIZE, NodeView.DEFAULT);
return narrowedVector;
}
GraalError.guarantee(PrimitiveStamp.getBits(elementStamp) == Byte.SIZE, "unexpected stamp: %s", elementStamp);
IntegerStamp byteStamp = (IntegerStamp) elementStamp;
if (canExpandByteShiftViaWidening(vectorArch, byteStamp, vectorLength)) {
return emitByteShiftViaWidening(value, shiftAmount, byteStamp);
}
if (canExpandByteShiftViaLaneSplit(vectorArch, byteStamp, vectorLength)) {
return emitByteShiftViaLaneSplit(value, shiftAmount, byteStamp, vectorLength);
}
throw GraalError.shouldNotReachHere("byte vector shift cannot be expanded");
}

private boolean canExpandByteShiftViaWidening(VectorArchitecture vectorArch, IntegerStamp byteStamp, int vectorLength) {
IntegerStamp shortStamp = StampFactory.forInteger(Short.SIZE);
ArithmeticOpTable.IntegerConvertOp<?> extend = (op.equals(byteStamp.getOps().getUShr()) ? byteStamp.getOps().getZeroExtend() : byteStamp.getOps().getSignExtend());
return vectorArch.getSupportedVectorConvertLength(shortStamp, byteStamp, vectorLength, extend) == vectorLength &&
vectorArch.getSupportedVectorShiftWithScalarCount(shortStamp, vectorLength, op) == vectorLength &&
vectorArch.getSupportedVectorConvertLength(byteStamp, shortStamp, vectorLength, shortStamp.getOps().getNarrow()) == vectorLength;
}

private ValueNode emitByteShiftViaWidening(ValueNode value, ValueNode shiftAmount, IntegerStamp byteStamp) {
ValueNode extendedVector = (op.equals(byteStamp.getOps().getUShr())
? ZeroExtendNode.create(value, Byte.SIZE, Short.SIZE, NodeView.DEFAULT)
: SignExtendNode.create(value, Byte.SIZE, Short.SIZE, NodeView.DEFAULT));
ValueNode shiftedVector = ShiftNode.shiftOp(extendedVector, shiftAmount, NodeView.DEFAULT, op);
return NarrowNode.create(shiftedVector, Short.SIZE, Byte.SIZE, NodeView.DEFAULT);
}

private boolean canExpandByteShiftViaLaneSplit(VectorArchitecture vectorArch, IntegerStamp byteStamp, int vectorLength) {
if (vectorLength <= 1 || (vectorLength & 1) != 0) {
return false;
}
int halfLength = vectorLength / 2;
IntegerStamp shortStamp = StampFactory.forInteger(Short.SIZE);
ArithmeticOpTable.IntegerConvertOp<?> extend = (op.equals(byteStamp.getOps().getUShr()) ? byteStamp.getOps().getZeroExtend() : byteStamp.getOps().getSignExtend());
return vectorArch.supportsVectorConcat(halfLength * Byte.BYTES) &&
vectorArch.getSupportedVectorConvertLength(shortStamp, byteStamp, halfLength, extend) == halfLength &&
vectorArch.getSupportedVectorShiftWithScalarCount(shortStamp, halfLength, op) == halfLength &&
vectorArch.getSupportedVectorConvertLength(byteStamp, shortStamp, halfLength, shortStamp.getOps().getNarrow()) == halfLength;
}

private ValueNode emitByteShiftViaLaneSplit(ValueNode value, ValueNode shiftAmount, IntegerStamp byteStamp, int vectorLength) {
int halfLength = vectorLength / 2;
ValueNode lowBytes = new SimdCutNode(value, 0, halfLength);
ValueNode highBytes = new SimdCutNode(value, halfLength, halfLength);
ValueNode lowShifted = emitByteShiftViaWidening(lowBytes, shiftAmount, byteStamp);
ValueNode highShifted = emitByteShiftViaWidening(highBytes, shiftAmount, byteStamp);
return new SimdConcatNode(lowShifted, highShifted);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ public boolean canExpand(VectorArchitecture vectorArch, EconomicMap<ValueNode, S
if (opr == MASK_COMPRESS_OP) {
return elementStamp instanceof LogicValueStamp;
} else {
return vectorArch.getSupportedVectorCompressExpandLength(elementStamp, vectorStamp.getVectorLength()) == vectorStamp.getVectorLength();
VectorArchitecture.CompressExpandOp op = opr == COMPRESS_OP ? VectorArchitecture.CompressExpandOp.COMPRESS : VectorArchitecture.CompressExpandOp.EXPAND;
return vectorArch.getSupportedVectorCompressExpandLength(elementStamp, vectorStamp.getVectorLength(), op) == vectorStamp.getVectorLength();
}
}

Expand Down