@@ -141,8 +141,9 @@ class TensorWrapper {
141141 return this . mlContext . readTensor ( this . mlTensor ) ;
142142 }
143143
144- public sameTypeAndShape ( dataType : MLOperandDataType , shape : readonly number [ ] ) : boolean {
144+ public canReuseTensor ( context : MLContext , dataType : MLOperandDataType , shape : readonly number [ ] ) : boolean {
145145 return (
146+ this . mlContext === context &&
146147 this . dataType === dataType &&
147148 this . tensorShape . length === shape . length &&
148149 this . tensorShape . every ( ( v , i ) => v === shape [ i ] )
@@ -176,12 +177,13 @@ class TensorIdTracker {
176177 }
177178
178179 public async ensureTensor (
180+ context : MLContext ,
179181 dataType : MLOperandDataType ,
180182 shape : readonly number [ ] ,
181183 copyOld : boolean ,
182184 ) : Promise < MLTensor > {
183185 if ( this . wrapper ) {
184- if ( this . wrapper . sameTypeAndShape ( dataType , shape ) ) {
186+ if ( this . wrapper . canReuseTensor ( context , dataType , shape ) ) {
185187 return this . wrapper . tensor ;
186188 } else {
187189 if ( copyOld ) {
@@ -288,7 +290,7 @@ class TensorManagerImpl implements TensorManager {
288290 if ( ! tensor ) {
289291 throw new Error ( 'Tensor not found.' ) ;
290292 }
291- return tensor . ensureTensor ( dataType , shape , copyOld ) ;
293+ return tensor . ensureTensor ( this . backend . currentContext , dataType , shape , copyOld ) ;
292294 }
293295
294296 public upload ( tensorId : TensorId , data : Uint8Array ) : void {
@@ -354,15 +356,15 @@ class TensorManagerImpl implements TensorManager {
354356 readable : boolean ,
355357 ) : Promise < TensorWrapper > {
356358 const sessionId = this . backend . currentSessionId ;
359+ const context = this . backend . currentContext ;
357360 for ( const [ index , tensor ] of this . freeTensors . entries ( ) ) {
358- if ( tensor . sameTypeAndShape ( dataType , shape ) ) {
361+ if ( tensor . canReuseTensor ( context , dataType , shape ) ) {
359362 LOG_DEBUG ( 'verbose' , ( ) => `[WebNN] Reusing tensor {dataType: ${ dataType } , shape: ${ shape } }` ) ;
360363 const wrapper = this . freeTensors . splice ( index , 1 ) [ 0 ] ;
361364 wrapper . sessionId = sessionId ;
362365 return wrapper ;
363366 }
364367 }
365- const context = this . backend . currentContext ;
366368 LOG_DEBUG ( 'verbose' , ( ) => `[WebNN] MLContext.createTensor {dataType: ${ dataType } , shape: ${ shape } }` ) ;
367369 const tensor = await context . createTensor ( {
368370 dataType,
0 commit comments