-
Notifications
You must be signed in to change notification settings - Fork 520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[ONNX] Add per channel quantization support for Onnx.QLinearConv op #3917
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Vivek. I think you need to modify some of the output quantization handling in the per-channel case. Maybe store a bool that tracks if we are in the per-channel case so you can reuse it for the output.
It looks like this conversion automatically fuses the input and weight quantization with the convolution, so the only thing that fuse-quantized-ops is going to do is quantize the bias (which won't work currently in the per-channel case). I think it is fine, but we won't be able to check correctness e2e until we address the per-channel quantization, unfortunately.
return failure(); | ||
auto weightShape = weightTy.getSizes(); | ||
auto weightScaleShape = weightScaleTy.getSizes(); | ||
Value weightScaleScalar = extract(weightScale); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
extract won't work if the weight scale isn't a single element. I'd put this in the else block below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see you use this below to handle the quantization of the output, but this must also be per-channel if the weight is per-channel.
Value weightScaleScalar = extract(weightScale); | ||
if (weightScaleShape.size() == 1 && | ||
weightScaleShape[0] != Torch::kUnknownSize && | ||
weightScaleShape[0] == weightShape[0]) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Additionally check that weightShape[0] != 1
since we don't want to lower to per-channel when there is only one channel.
} else { | ||
weightZp = extract(weightZp); | ||
weight = makePerTensor(weight, weightScaleScalar, weightZp); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A bit of a nit, but I'd prefer an else if
here with the conditions for makePerTensor
, and then an else branch with an unreachable, just to be very clear about what assumptions are being made in each case.
|
||
cTy = rewriter.getType<Torch::ValueTensorType>( | ||
outputTy = rewriter.getType<Torch::ValueTensorType>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, this is a bit subtle. The last optional input for this op is the int32 bias, assumed to be quantized via the product of input and weight scales. This implies that the quantization of the bias (and also the output of the convolution) is also per-channel if the weight was per-channel quantized. This part is fine, but we will need to case out the logic below.
|
||
Value outScale = rewriter.create<Torch::AtenMulFloatOp>( | ||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), aScale, | ||
bScale); | ||
binder.getLoc(), rewriter.getType<Torch::FloatType>(), inputScale, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will need to possibly be float x tensor mul.
This commit extends the OnnxToTorch Lowering for Onnx.QLinearConv op by adding the support for per channel quantization for the weight argument.
Signed-off-by: Vivek Khandelwal [email protected]