Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@ Status DynamicQuantizeLSTM::Compute(OpKernelContext* context) const {
const Tensor* r_zp = context->Input<Tensor>(11);

const TensorShape& W_zp_shape = w_zp->Shape();
const TensorShape& R_zp_shape = w_zp->Shape();
const TensorShape& R_zp_shape = r_zp->Shape();
const TensorShape& W_scale_shape = w_scale->Shape();
const TensorShape& R_scale_shape = r_scale->Shape();

WeightCheck(W_zp_shape, W_zero_point);
WeightCheck(R_zp_shape, R_zero_point);
WeightCheck(W_scale_shape, W_scale);
WeightCheck(W_scale_shape, R_scale);
WeightCheck(R_scale_shape, R_scale);

const bool is_W_signed = (W != nullptr) ? W->IsDataType<int8_t>() : is_W_signed_;
const bool is_R_signed = (R != nullptr) ? R->IsDataType<int8_t>() : is_R_signed_;
Expand Down
79 changes: 79 additions & 0 deletions onnxruntime/test/contrib_ops/quantize_lstm_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -525,5 +525,84 @@ TEST(DynamicQuantLSTMTest, SharedPrepackedWeights) {
}
#endif

// Builds a minimal per-tensor DynamicQuantizeLSTM and runs it expecting the kernel's input
// validation to reject the recurrence quantization parameters. The caller supplies the R_scale and
// R_zero_point shapes so a shape that is inconsistent with R (e.g. first dim != num_directions) can
// be exercised. The recurrence parameters must be validated symmetrically with the input (W) ones.
static void RunQuantLSTMExpectInvalidRecurrenceQuantParam(const std::vector<int64_t>& r_scale_dims,
const std::vector<int64_t>& r_zp_dims,
const std::string& expected_error) {
OpTester test("DynamicQuantizeLSTM", 1 /*opset_version*/, onnxruntime::kMSDomain /*domain*/);

constexpr int64_t num_directions = 1;
constexpr int64_t input_size = 2;
constexpr int64_t hidden_size = 2;
constexpr int64_t batch_size = 1;
constexpr int64_t seq_len = 1;

auto num_elements = [](const std::vector<int64_t>& dims) {
int64_t count = 1;
for (int64_t dim : dims) {
count *= dim;
}
return static_cast<size_t>(count);
};

test.AddAttribute<std::vector<std::string>>("activations", {"sigmoid", "tanh", "tanh"});
test.AddAttribute("direction", "forward");
test.AddAttribute("hidden_size", hidden_size);
test.AddAttribute<int64_t>("input_forget", static_cast<int64_t>(0));

// X: [seq_length, batch_size, input_size]
test.AddInput<float>("X", {seq_len, batch_size, input_size},
std::vector<float>(num_elements({seq_len, batch_size, input_size}), 0.0f));

// W / R quantized weight values are irrelevant: validation fails before any dequantization.
test.AddInput<uint8_t>("W", {num_directions, input_size, 4 * hidden_size},
std::vector<uint8_t>(num_elements({num_directions, input_size, 4 * hidden_size}), 0));
test.AddInput<uint8_t>("R", {num_directions, hidden_size, 4 * hidden_size},
std::vector<uint8_t>(num_elements({num_directions, hidden_size, 4 * hidden_size}), 0));

test.AddOptionalInputEdge<float>(); // B
test.AddOptionalInputEdge<int>(); // sequence_lens
test.AddInput<float>("initial_h", {num_directions, batch_size, hidden_size},
std::vector<float>(num_elements({num_directions, batch_size, hidden_size}), 0.0f));
test.AddInput<float>("initial_c", {num_directions, batch_size, hidden_size},
std::vector<float>(num_elements({num_directions, batch_size, hidden_size}), 0.0f));
test.AddOptionalInputEdge<float>(); // P

// Valid per-tensor quantization parameters for the input weights.
test.AddInput<float>("W_scale", {num_directions}, std::vector<float>(num_directions, 1.0f));
test.AddInput<uint8_t>("W_zero_point", {num_directions}, std::vector<uint8_t>(num_directions, 0));

// Recurrence parameters with caller-supplied (possibly inconsistent) shapes.
test.AddInput<float>("R_scale", r_scale_dims, std::vector<float>(num_elements(r_scale_dims), 1.0f));
test.AddInput<uint8_t>("R_zero_point", r_zp_dims, std::vector<uint8_t>(num_elements(r_zp_dims), 0));

// Placeholder outputs (not validated: the run fails during input validation).
test.AddOutput<float>("Y", {seq_len, num_directions, batch_size, hidden_size},
std::vector<float>(num_elements({seq_len, num_directions, batch_size, hidden_size}), 0.0f));
test.AddOutput<float>("Y_h", {num_directions, batch_size, hidden_size},
std::vector<float>(num_elements({num_directions, batch_size, hidden_size}), 0.0f));
test.AddOutput<float>("Y_c", {num_directions, batch_size, hidden_size},
std::vector<float>(num_elements({num_directions, batch_size, hidden_size}), 0.0f));

test.Run(OpTester::ExpectResult::kExpectFailure, expected_error);
}

TEST(DynamicQuantLSTMTest, RejectsInconsistentRecurrenceZeroPointShape) {
// R_zero_point's first dim must equal num_directions (1); {2} is inconsistent and must be rejected
// rather than silently validated against the input zero point's shape.
RunQuantLSTMExpectInvalidRecurrenceQuantParam(/*r_scale_dims=*/{1}, /*r_zp_dims=*/{2},
"Input R_zero_point must have shape");
}

TEST(DynamicQuantLSTMTest, RejectsInconsistentRecurrenceScaleShape) {
// R_scale's first dim must equal num_directions (1); {2} is inconsistent and must be rejected
// rather than silently validated against the input scale's shape.
RunQuantLSTMExpectInvalidRecurrenceQuantParam(/*r_scale_dims=*/{2}, /*r_zp_dims=*/{1},
"Input R_scale must have shape");
}

} // namespace test
} // namespace onnxruntime
Loading