Skip to content

Commit 54edb43

Browse files
authored
[WebNN] Fixes MLTensor caching across different contexts (#23100)
We weren't checking that MLTensors were from the same context before reusing them. Found while debugging microsoft/webnn-developer-preview#69
1 parent 5afab78 commit 54edb43

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

js/web/lib/wasm/jsep/webnn/tensor-manager.ts

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,9 @@ class TensorWrapper {
141141
return this.mlContext.readTensor(this.mlTensor);
142142
}
143143

144-
public sameTypeAndShape(dataType: MLOperandDataType, shape: readonly number[]): boolean {
144+
public canReuseTensor(context: MLContext, dataType: MLOperandDataType, shape: readonly number[]): boolean {
145145
return (
146+
this.mlContext === context &&
146147
this.dataType === dataType &&
147148
this.tensorShape.length === shape.length &&
148149
this.tensorShape.every((v, i) => v === shape[i])
@@ -176,12 +177,13 @@ class TensorIdTracker {
176177
}
177178

178179
public async ensureTensor(
180+
context: MLContext,
179181
dataType: MLOperandDataType,
180182
shape: readonly number[],
181183
copyOld: boolean,
182184
): Promise<MLTensor> {
183185
if (this.wrapper) {
184-
if (this.wrapper.sameTypeAndShape(dataType, shape)) {
186+
if (this.wrapper.canReuseTensor(context, dataType, shape)) {
185187
return this.wrapper.tensor;
186188
} else {
187189
if (copyOld) {
@@ -288,7 +290,7 @@ class TensorManagerImpl implements TensorManager {
288290
if (!tensor) {
289291
throw new Error('Tensor not found.');
290292
}
291-
return tensor.ensureTensor(dataType, shape, copyOld);
293+
return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld);
292294
}
293295

294296
public upload(tensorId: TensorId, data: Uint8Array): void {
@@ -354,15 +356,15 @@ class TensorManagerImpl implements TensorManager {
354356
readable: boolean,
355357
): Promise<TensorWrapper> {
356358
const sessionId = this.backend.currentSessionId;
359+
const context = this.backend.currentContext;
357360
for (const [index, tensor] of this.freeTensors.entries()) {
358-
if (tensor.sameTypeAndShape(dataType, shape)) {
361+
if (tensor.canReuseTensor(context, dataType, shape)) {
359362
LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
360363
const wrapper = this.freeTensors.splice(index, 1)[0];
361364
wrapper.sessionId = sessionId;
362365
return wrapper;
363366
}
364367
}
365-
const context = this.backend.currentContext;
366368
LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
367369
const tensor = await context.createTensor({
368370
dataType,

0 commit comments

Comments
 (0)