Skip to content
This repository was archived by the owner on Nov 16, 2023. It is now read-only.

Fix an error caused by glsl mod() function #294

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions lib/backends/webgl/ops/conv-pack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ import {WebGLReshapePacked} from './reshape-packed';
export class WebGLConvPacked extends Conv {
protected artifacts: Artifact[];
protected programInfo: ProgramInfo[];
private kernelReshape = new WebGLReshapePacked();
private im2col: WebGLIm2ColPacked;
private matmul = new WebGLMatMulPacked();
private outputReshape = new WebGLReshapePacked();

run(inferenceHandler: WebGLInferenceHandler, inputs: Tensor[]): Tensor[] {
const programManager = inferenceHandler.session.programManager;
Expand All @@ -34,30 +38,31 @@ export class WebGLConvPacked extends Conv {
this.kernelShape}, pads:${this.pads}, strides:${this.strides}`);

const outputShape = WebGLConv.calcOutputShape(xshape, kshape, this.dilations, this.pads, this.strides);
const im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
const matmul = new WebGLMatMulPacked();
const reshape = new WebGLReshapePacked();
if (this.im2col === undefined) {
this.im2col = new WebGLIm2ColPacked(outputShape, kshape, this.dilations, this.pads, this.strides);
}
// shape for kernel reshape
const shape =
new Tensor([2], 'int32', undefined, undefined, new Int32Array([kshape[0], kshape[1] * kshape[2] * kshape[3]]));
if (!this.artifacts) {
this.artifacts = [];
this.programInfo = [];
this.programInfo[0] = im2col.createProgramInfo(inferenceHandler, [inputs[0], inputs[1]]);
this.programInfo[0] = this.im2col.createProgramInfo(inferenceHandler, [inputs[0], inputs[1]]);
this.artifacts[0] = programManager.build(this.programInfo[0]);

this.programInfo[1] = reshape.createProgramInfo(inferenceHandler, [inputs[1], shape]);
this.programInfo[1] = this.kernelReshape.createProgramInfo(inferenceHandler, [inputs[1], shape]);
this.artifacts[1] = programManager.build(this.programInfo[1]);
}

// run im2col
const runDataIm2col = im2col.createRunData(inferenceHandler, this.programInfo[0], [inputs[0], inputs[1]]);
const runDataIm2col = this.im2col.createRunData(inferenceHandler, this.programInfo[0], [inputs[0], inputs[1]]);
inferenceHandler.checkAndUpdateTextureForm(this.artifacts[0], runDataIm2col);
programManager.run(this.artifacts[0], runDataIm2col);
const im2colOutput = runDataIm2col.outputTextureData.tensor;

// reshape kernel
const runDataKernelReshape = reshape.createRunData(inferenceHandler, this.programInfo[1], [inputs[1], shape]);
const runDataKernelReshape =
this.kernelReshape.createRunData(inferenceHandler, this.programInfo[1], [inputs[1], shape]);
inferenceHandler.checkAndUpdateTextureForm(this.artifacts[1], runDataKernelReshape);
programManager.run(this.artifacts[1], runDataKernelReshape);
const kernelReshaped = runDataKernelReshape.outputTextureData.tensor;
Expand All @@ -66,11 +71,11 @@ export class WebGLConvPacked extends Conv {
const hasBias = (inputs.length === 3);
assert(this.artifacts.length > 1, () => 'expect at least 2 artifacts created');
if (this.artifacts.length === 2) {
this.programInfo[2] = matmul.createProgramInfo(
this.programInfo[2] = this.matmul.createProgramInfo(
inferenceHandler, hasBias ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]);
this.artifacts[2] = programManager.build(this.programInfo[2]);
}
const runDataMatmul = matmul.createRunData(
const runDataMatmul = this.matmul.createRunData(
inferenceHandler, this.programInfo[2],
hasBias ? [kernelReshaped, im2colOutput, inputs[2]] : [kernelReshaped, im2colOutput]);
inferenceHandler.checkAndUpdateTextureForm(this.artifacts[2], runDataMatmul);
Expand All @@ -84,11 +89,11 @@ export class WebGLConvPacked extends Conv {

assert(this.artifacts.length > 2, () => 'expect at least 3 artifacts created');
if (this.artifacts.length === 3) {
this.programInfo[3] = reshape.createProgramInfo(inferenceHandler, [matmulOutput, outputShapeTensor]);
this.programInfo[3] = this.outputReshape.createProgramInfo(inferenceHandler, [matmulOutput, outputShapeTensor]);
this.artifacts[3] = programManager.build(this.programInfo[3]);
}
const runDataOutputReshape =
reshape.createRunData(inferenceHandler, this.programInfo[3], [matmulOutput, outputShapeTensor]);
this.outputReshape.createRunData(inferenceHandler, this.programInfo[3], [matmulOutput, outputShapeTensor]);
inferenceHandler.checkAndUpdateTextureForm(this.artifacts[3], runDataOutputReshape);
programManager.run(this.artifacts[3], runDataOutputReshape);
return [runDataOutputReshape.outputTextureData.tensor];
Expand Down
7 changes: 3 additions & 4 deletions lib/backends/webgl/ops/im2col-pack.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,11 @@ export class WebGLIm2ColPacked implements WebGLOperator {

if(blockIndex < ${im2colShape[1]} && pos < ${im2colShape[0]}) {
offsetY = int(blockIndex / (${this.convOutputShape[rank - 1]})) * ${this.strides[0]} - ${this.pads[1]};
d0 = offsetY + ${this.dilations[0]} * (int(mod(float(pos), ${kernelSize}.)) / ${wshape[2]} );
d0 = offsetY + ${this.dilations[0]} * (imod(pos, ${kernelSize}) / ${wshape[2]});

if(d0 < ${xshape[rowDim]} && d0 >= 0) {
offsetX = int(mod(float(blockIndex), ${this.convOutputShape[rank - 1]}.) * ${this.strides[1]}. - ${
this.pads[0]}.);
d1 = offsetX + ${this.dilations[1]} * (int(mod(mod(float(pos), ${kernelSize}.), ${wshape[2]}.)));
offsetX = imod(blockIndex, ${this.convOutputShape[rank - 1]}) * ${this.strides[1]} - ${this.pads[0]};
d1 = offsetX + ${this.dilations[1]} * imod(imod(pos, ${kernelSize}), ${wshape[2]});

if(d1 < ${xshape[colDim]} && d1 >= 0) {

Expand Down
46 changes: 35 additions & 11 deletions lib/backends/webgl/ops/reshape-packed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {Tensor} from '../../../tensor';
import {ShapeUtil} from '../../../util';
import {getGlsl} from '../glsl-source';
import {WebGLInferenceHandler} from '../inference-handler';
import {ProgramInfo, RunData, WebGLOperator} from '../types';
import {ProgramInfo, RunData, TextureData, WebGLOperator} from '../types';
import {TextureLayout} from '../types';

import {unpackFromChannel} from './packing_utils';
Expand All @@ -32,14 +32,18 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
// the same between input shape and output shape, the packed reshape can be
// treated as no-op.
const originInputShape = inputs[0].dims;
const inputShape3D = processDims3D(inputs[0].dims);
this.inputShape3D = processDims3D(inputs[0].dims);
let inputLayout: TextureLayout;
if (originInputShape.length === 3) {
inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true);
} else {
inputLayout = handler.getOrCreateTextureLayout(inputs[0], 4, true, originInputShape, true);
if (originInputShape.length !== 3) {
const originalInputLayout = inputLayout;
// if originShape is not a 3D shape, create texture layout from the processed shape.
inputLayout =
handler.createTextureLayoutFromShape(inputShape3D, 4, inputShape3D, {isPacked: true, reverseWH: true});
inputLayout = handler.createTextureLayoutFromShape(
this.inputShape3D, 4, this.inputShape3D, {isPacked: true, reverseWH: true});
// if the processed input shape produces texture layout differnt from the original
// one, the run data has to use the processed (3D) input shape later.
this.needSqueezeInputData =
(inputLayout.height !== originalInputLayout.height) || (inputLayout.width !== originalInputLayout.width);
}

this.outputShape = ShapeUtil.calculateReshapedDims(originInputShape, inputs[1].integerData);
Expand Down Expand Up @@ -86,9 +90,10 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
const glsl = getGlsl(handler.session.backend.glContext.version);

const shaderSource = `
${getReshapedInputCoords(inputShape3D)}
${getReshapedInputCoords(this.inputShape3D)}
${getFlattenedIndexFrom3D(squeezedOutputShape)}
${unpackFromChannel()}

void main() {
ivec3 rc = getOutputCoords();

Expand All @@ -99,7 +104,6 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
int cols = ${squeezedOutputShape[1]};

${mainLoop}

${glsl.output} = result;
}
`;
Expand All @@ -115,8 +119,26 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
};
}
createRunData(handler: WebGLInferenceHandler, programInfo: ProgramInfo, inputs: Tensor[]): RunData {
const inputTDs =
[handler.getOrCreateTextureData(inputs[0], handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false))];
let inputTDs: [TextureData];
const originalInputLayout = handler.getOrCreateTextureLayout(inputs[0], 1, false, [], false);
const originalInputTD = handler.getOrCreateTextureData(inputs[0], originalInputLayout, false);

if (this.needSqueezeInputData) {
const squeezedInputLayout: TextureLayout = {
channels: 1,
height: originalInputLayout.height,
width: originalInputLayout.width,
shape: this.inputShape3D,
strides: ShapeUtil.computeStrides(this.inputShape3D),
unpackedShape: this.inputShape3D,
};
const squeezedInputTD =
handler.createSharedTextureData(squeezedInputLayout, inputs[0].type, originalInputTD.texture);
inputTDs = [squeezedInputTD];

} else {
inputTDs = [originalInputTD];
}
let outputLayout = this.originalOutputLayout;
if (outputLayout === undefined) {
const originInputShape = inputs[0].dims;
Expand All @@ -133,6 +155,8 @@ export class WebGLReshapePacked extends Reshape implements WebGLOperator {
}
protected outputShape: ReadonlyArray<number>;
private originalOutputLayout: TextureLayout;
private inputShape3D: [number, number, number];
private needSqueezeInputData = false;
}

function processDims3D(shape: readonly number[]|ReadonlyArray<number>|Tensor.IntegerType): [number, number, number] {
Expand Down
5 changes: 5 additions & 0 deletions test/unittests/backends/webgl/test_reshape_packed.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,5 +156,10 @@ function getTestData(): TestData[] {
inputShape: [2, 2, 2, 4],
outputShape: [2, 1, 4, 4],
},
{
elementCount: 18432,
inputShape: [512, 36, 1, 1],
outputShape: [512, 36],
},
];
}