Skip to content

Commit a316188

Browse files
egalliashrit-ms
authored and committed
[WebNN EP] Automatically use ml-tensor for outputs (#24282)
### Description If it would improve performance, this patch moves outputs to MLTensor backed Tensors. ### Motivation and Context We are currently performing an extra copy on output tensors located in the CPU when using the WebNN EP (MLTensor -(copy)-> wasm heap -(copy)-> JS). This patch removes this copy by moving the readback to JS instead of wasm. As an extra benefit, we can also start the readbacks and wait for them in parallel. This change is similar to #23073
1 parent 63b8670 commit a316188

File tree

5 files changed

+87
-6
lines changed

5 files changed

+87
-6
lines changed

js/web/lib/wasm/jsep/backend-webnn.ts

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -79,11 +79,20 @@ export class WebNNBackend {
7979
* Maps from session id to list of graph inputs.
8080
*/
8181
private sessionGraphInputs: Map<number, string[]> = new Map();
82+
/**
83+
* Maps from session id to list of graph outputs.
84+
*/
85+
private sessionGraphOutputs: Map<number, string[]> = new Map();
8286
/**
8387
* Temporary graph inputs for the current session.
8488
* These inputs will be registered when the session is created.
8589
*/
8690
private temporaryGraphInputs: string[] = [];
91+
/**
92+
* Temporary graph outputs for the current session.
93+
* These outputs will be registered when the session is created.
94+
*/
95+
private temporaryGraphOutputs: string[] = [];
8796
/**
8897
* Temporary tensors for the current session.
8998
*/
@@ -167,10 +176,15 @@ export class WebNNBackend {
167176
this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs);
168177
this.temporaryGraphInputs = [];
169178
}
179+
if (this.temporaryGraphOutputs.length > 0) {
180+
this.sessionGraphOutputs.set(sessionId, this.temporaryGraphOutputs);
181+
this.temporaryGraphOutputs = [];
182+
}
170183
}
171184

172185
public onReleaseSession(sessionId: number): void {
173186
this.sessionGraphInputs.delete(sessionId);
187+
this.sessionGraphOutputs.delete(sessionId);
174188
const mlContext = this.mlContextBySessionId.get(sessionId)!;
175189
if (!mlContext) {
176190
// Current session is not a WebNN session.
@@ -363,6 +377,10 @@ export class WebNNBackend {
363377
this.temporaryGraphInputs.push(inputName);
364378
}
365379

380+
public registerGraphOutput(outputName: string): void {
381+
this.temporaryGraphOutputs.push(outputName);
382+
}
383+
366384
public isGraphInput(sessionId: number, inputName: string): boolean {
367385
const inputNames = this.sessionGraphInputs.get(sessionId);
368386
if (!inputNames) {
@@ -371,6 +389,14 @@ export class WebNNBackend {
371389
return inputNames.includes(inputName);
372390
}
373391

392+
public isGraphOutput(sessionId: number, outputName: string): boolean {
393+
const outputNames = this.sessionGraphOutputs.get(sessionId);
394+
if (!outputNames) {
395+
return false;
396+
}
397+
return outputNames.includes(outputName);
398+
}
399+
374400
public isInt64Supported(sessionId: number): boolean {
375401
const context = this.mlContextBySessionId.get(sessionId);
376402
return !!context?.opSupportLimits().input.dataTypes.includes('int64');

js/web/lib/wasm/wasm-core-impl.ts

Lines changed: 44 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -172,7 +172,13 @@ export const initEp = async (env: Env, epName: string): Promise<void> => {
172172
/**
173173
* valid data locations for input/output tensors.
174174
*/
175-
type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-tensor';
175+
type SupportedTensorDataLocationForInputOutput =
176+
| 'cpu'
177+
| 'cpu-pinned'
178+
| 'gpu-buffer'
179+
| 'ml-tensor'
180+
// Use 'ml-tensor' during inference, but output a tensor located on the CPU.
181+
| 'ml-tensor-cpu-output';
176182

177183
type IOBindingState = {
178184
/**
@@ -424,6 +430,11 @@ export const createSession = async (
424430
typeof options?.preferredOutputLocation === 'string'
425431
? options.preferredOutputLocation
426432
: (options?.preferredOutputLocation?.[nameString] ?? 'cpu');
433+
const isGraphOutput = wasm.webnnIsGraphOutput;
434+
if (location === 'cpu' && isGraphOutput && isGraphOutput(sessionHandle, nameString)) {
435+
outputPreferredLocations.push('ml-tensor-cpu-output');
436+
continue;
437+
}
427438
if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-tensor') {
428439
throw new Error(`Not supported preferred output location: ${location}.`);
429440
}
@@ -438,7 +449,10 @@ export const createSession = async (
438449

439450
// use IO binding only when at least one output is preferred to be on GPU.
440451
let bindingState: IOBindingState | null = null;
441-
if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor')) {
452+
if (
453+
!BUILD_DEFS.DISABLE_JSEP &&
454+
outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor' || l === 'ml-tensor-cpu-output')
455+
) {
442456
ioBindingHandle = wasm._OrtCreateBinding(sessionHandle);
443457
if (ioBindingHandle === 0) {
444458
checkLastError("Can't create IO binding.");
@@ -447,7 +461,10 @@ export const createSession = async (
447461
bindingState = {
448462
handle: ioBindingHandle,
449463
outputPreferredLocations,
450-
outputPreferredLocationsEncoded: outputPreferredLocations.map((l) => dataLocationStringToEnum(l)),
464+
outputPreferredLocationsEncoded: outputPreferredLocations
465+
// 'ml-tensor-cpu-output' is treated as 'ml-tensor' for the purpose of IO binding.
466+
.map((l) => (l === 'ml-tensor-cpu-output' ? 'ml-tensor' : l))
467+
.map((l) => dataLocationStringToEnum(l)),
451468
};
452469
}
453470

@@ -599,10 +616,11 @@ export const prepareInputOutputTensor = async (
599616
}
600617
} else {
601618
const isGraphInput = wasm.webnnIsGraphInput;
602-
if (dataType !== 'string' && isGraphInput) {
619+
const isGraphOutput = wasm.webnnIsGraphOutput;
620+
if (dataType !== 'string' && isGraphInput && isGraphOutput) {
603621
const tensorName = wasm.UTF8ToString(tensorNameUTF8Encoded);
604622
// Promote the tensor to 'ml-tensor' if it is a graph input.
605-
if (isGraphInput(sessionId, tensorName)) {
623+
if (isGraphInput(sessionId, tensorName) || isGraphOutput(sessionId, tensorName)) {
606624
const dataTypeEnum = tensorDataTypeStringToEnum(dataType);
607625
dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!;
608626
actualLocation = 'ml-tensor';
@@ -810,6 +828,7 @@ export const run = async (
810828
}
811829

812830
const output: TensorMetadata[] = [];
831+
const outputPromises: Array<Promise<[number, Tensor.DataType]>> = [];
813832

814833
for (let i = 0; i < outputCount; i++) {
815834
const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*'));
@@ -958,6 +977,20 @@ export const run = async (
958977
},
959978
'ml-tensor',
960979
]);
980+
} else if (preferredLocation === 'ml-tensor-cpu-output' && size > 0) {
981+
const data = wasm.webnnCreateMLTensorDownloader!(dataOffset, type as Tensor.MLTensorDataTypes)();
982+
const index = output.length;
983+
// Delay the data download and releasing the tensor until we can wait for all output tensors to be downloaded.
984+
keepOutputTensor = true;
985+
outputPromises.push(
986+
(async () => {
987+
const result: [number, Tensor.DataType] = [index, await data];
988+
wasm.webnnReleaseTensorId!(dataOffset);
989+
wasm._OrtReleaseTensor(tensor);
990+
return result;
991+
})(),
992+
);
993+
output.push([type, dims, [], 'cpu']);
961994
} else {
962995
const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type);
963996
const data = new typedArrayConstructor(size);
@@ -975,7 +1008,6 @@ export const run = async (
9751008
if (!keepOutputTensor) {
9761009
wasm._OrtReleaseTensor(tensor);
9771010
}
978-
wasm.webnnOnRunEnd?.(sessionHandle);
9791011
}
9801012
}
9811013

@@ -992,8 +1024,14 @@ export const run = async (
9921024
false,
9931025
]);
9941026
}
1027+
// Wait for all output tensor data to be downloaded.
1028+
for (const [index, data] of await Promise.all(outputPromises)) {
1029+
output[index][2] = data;
1030+
}
9951031
return output;
9961032
} finally {
1033+
wasm.webnnOnRunEnd?.(sessionHandle);
1034+
9971035
wasm.stackRestore(beforeRunStack);
9981036

9991037
if (BUILD_DEFS.USE_WEBGPU_EP) {

js/web/lib/wasm/wasm-types.ts

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -287,6 +287,19 @@ export declare namespace JSEP {
287287
* @returns whether the input is a WebNN graph input.
288288
*/
289289
webnnIsGraphInput: (sessionId: number, inputName: string) => boolean;
290+
/**
291+
* [exported from pre-jsep.js] Register a WebNN graph output.
292+
* @param outputName - specify the output name.
293+
*/
294+
webnnRegisterGraphOutput: (outputName: string) => void;
295+
/**
296+
* [exported from pre-jsep.js] Check if a graph output is a WebNN graph output.
297+
* @param sessionId - specify the session ID.
298+
* @param outputName - specify the output name.
299+
* @returns whether the output is a WebNN graph output.
300+
*/
301+
webnnIsGraphOutput: (sessionId: number, outputName: string) => boolean;
302+
290303
/**
291304
* [exported from pre-jsep.js] Create a temporary MLTensor for a session.
292305
* @param sessionId - specify the session ID.

onnxruntime/core/providers/webnn/builders/model_builder.cc

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -300,6 +300,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
300300
emscripten::val::module_property("webnnRegisterGraphInput")(name);
301301
input_names_.push_back(name);
302302
} else {
303+
emscripten::val::module_property("webnnRegisterGraphOutput")(name);
303304
output_names_.push_back(name);
304305
}
305306

onnxruntime/wasm/pre-jsep.js

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -151,6 +151,9 @@ Module["jsepInit"] = (name, params) => {
151151
Module["webnnRegisterGraphInput"] =
152152
backend["registerGraphInput"].bind(backend);
153153
Module["webnnIsGraphInput"] = backend["isGraphInput"].bind(backend);
154+
Module["webnnRegisterGraphOutput"] =
155+
backend["registerGraphOutput"].bind(backend);
156+
Module["webnnIsGraphOutput"] = backend["isGraphOutput"].bind(backend);
154157

155158
Module["webnnCreateTemporaryTensor"] =
156159
backend["createTemporaryTensor"].bind(backend);

0 commit comments

Comments (0)