Skip to content

Commit 91745cf

Browse files
authored
Bugfix: avoid squeezing batch size 1 dimension in lstm and gru emulation (#913)
Also fixes an issue of squeeze and unsqueeze emulation code. It should use for...of loop to get axes values rather than indices. Fix #889
1 parent 94bcacb commit 91745cf

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

index.bs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5630,14 +5630,14 @@ partial dictionary MLOpSupportLimits {
56305630
null);
56315631
let currentHidden = squeeze(
56325632
builder,
5633-
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]));
5633+
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]), [0]);
56345634

56355635
for (let step = 0; step < steps; ++step) {
56365636
const slice =
56375637
(dir == 1 || direction == 'backward' ? steps - step - 1 : step);
56385638
const currentInput = squeeze(
56395639
builder,
5640-
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]));
5640+
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]), [0]);
56415641

56425642
currentHidden = builder.gruCell(
56435643
currentInput,
@@ -7011,17 +7011,17 @@ partial dictionary MLOpSupportLimits {
70117011

70127012
let currentHidden = squeeze(
70137013
builder,
7014-
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]));
7014+
builder.slice(hiddenState, [dir, 0, 0], [1, batchSize, hiddenSize]), [0]);
70157015
let currentCell = squeeze(
70167016
builder,
7017-
builder.slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]));
7017+
builder.slice(cellState, [dir, 0, 0], [1, batchSize, hiddenSize]), [0]);
70187018

70197019
for (let step = 0; step < steps; ++step) {
70207020
const slice =
70217021
(dir == 1 || direction == 'backward' ? steps - step - 1 : step);
70227022
const currentInput = squeeze(
70237023
builder,
7024-
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]));
7024+
builder.slice(input, [slice, 0, 0], [1, batchSize, inputSize]), [0]);
70257025

70267026
[currentHidden, currentCell] = builder.lstmCell(
70277027
currentInput,
@@ -10466,7 +10466,7 @@ Operations present in other neural network inference APIs can often be emulated
1046610466
axes.push(i);
1046710467
});
1046810468
const shape = Array.from(input.shape);
10469-
for (let axis in axes.sort().reverse())
10469+
for (let axis of axes.sort().reverse())
1047010470
if (axis < shape.length && shape[axis] == 1)
1047110471
shape.splice(axis, 1);
1047210472
return builder.reshape(input, shape);
@@ -10485,7 +10485,7 @@ Operations present in other neural network inference APIs can often be emulated
1048510485
<pre highlight="js">
1048610486
function unsqueeze(builder, input, axes) {
1048710487
const shape = Array.from(input.shape);
10488-
for (let axis in axes.sort())
10488+
for (let axis of axes.sort())
1048910489
shape.splice(axis, 0, 1);
1049010490
return builder.reshape(input, shape);
1049110491
}

0 commit comments

Comments
 (0)