Skip to content

Commit 61f4de5

Browse files
committed
[js/webgpu] Donot record with computePassEncoder when capturing
1 parent 497b06f commit 61f4de5

File tree

3 files changed

+77
-40
lines changed

3 files changed

+77
-40
lines changed

js/web/lib/wasm/jsep/backend-webgpu.ts

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,21 @@ import {
2323
TimestampQuery,
2424
} from './webgpu/types';
2525

26-
interface CommandInfo {
26+
interface ComputeCommand {
2727
readonly kernelId: number;
2828
readonly computePipeline: GPUComputePipeline;
2929
readonly bindGroup: GPUBindGroup;
3030
readonly dispatchGroup: [number, number, number];
3131
}
3232

33+
interface MemcpyCommand {
34+
readonly source: GPUBuffer;
35+
readonly dest: GPUBuffer;
36+
readonly size: number;
37+
}
38+
39+
type Command = ComputeCommand | MemcpyCommand;
40+
3341
interface KernelInfo {
3442
readonly kernelType: string;
3543
readonly kernelName: string;
@@ -234,9 +242,9 @@ export class WebGpuBackend {
234242
env: Env;
235243
sessionStatus: SessionState = 'default';
236244
/**
237-
* a SessionID -> CommandInfo[] mapping. It's used to record all GPU commands for corresponding session.
245+
* a SessionID -> Command[] mapping. It's used to record all GPU commands for corresponding session.
238246
*/
239-
capturedCommandList: Map<number, CommandInfo[]> = new Map();
247+
capturedCommandList: Map<number, Command[]> = new Map();
240248

241249
/**
242250
* a SessionID -> PendingKernelInfo[] mapping for profiling.
@@ -837,13 +845,19 @@ export class WebGpuBackend {
837845
}
838846
return gpuData.buffer;
839847
}
848+
849+
async replayAndDownloadGpuData(gpuBuffer: GPUBuffer, originalSize: number): Promise<Uint8Array> {
850+
this.replay();
851+
return downloadGpuData(this, gpuBuffer, originalSize);
852+
}
853+
840854
createDownloader(
841855
gpuBuffer: GPUBuffer,
842856
size: number,
843857
type: Tensor.GpuBufferDataTypes,
844858
): () => Promise<Tensor.DataType> {
845859
return async () => {
846-
const data = await downloadGpuData(this, gpuBuffer, size);
860+
const data = await this.replayAndDownloadGpuData(gpuBuffer, size);
847861
return createView(data.buffer, type);
848862
};
849863
}
@@ -909,18 +923,27 @@ export class WebGpuBackend {
909923
for (let i = 0; i < length; i++) {
910924
const computePassEncoder = this.getComputePassEncoder();
911925
const command = sessionCommandList![i];
912-
this.writeTimestamp(this.pendingDispatchNumber * 2);
913-
computePassEncoder.setPipeline(command.computePipeline);
914-
computePassEncoder.setBindGroup(0, command.bindGroup);
915-
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
916-
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
917-
this.pendingDispatchNumber++;
918-
if (this.queryType !== 'none') {
919-
this.pendingKernels.push(sessionPendingKernels![i]);
920-
}
921-
if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
926+
if ('bindGroup' in command) {
927+
this.writeTimestamp(this.pendingDispatchNumber * 2);
928+
computePassEncoder.setPipeline(command.computePipeline);
929+
computePassEncoder.setBindGroup(0, command.bindGroup);
930+
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
931+
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
932+
933+
this.pendingDispatchNumber++;
934+
if (this.queryType !== 'none') {
935+
this.pendingKernels.push(sessionPendingKernels![i]);
936+
}
937+
if (this.pendingDispatchNumber >= this.maxDispatchNumber || this.queryType === 'at-passes') {
938+
this.endComputePass();
939+
}
940+
} else {
941+
const commandEncoder = this.getCommandEncoder();
942+
this.pendingDispatchNumber++;
922943
this.endComputePass();
944+
commandEncoder.copyBufferToBuffer(command.source, 0, command.dest, 0, command.size);
923945
}
946+
924947
if (this.pendingDispatchNumber >= this.maxDispatchNumber) {
925948
this.flush();
926949
}

js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -274,16 +274,28 @@ class GpuDataManagerImpl implements GpuDataManager {
274274

275275
const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize);
276276

277-
// GPU copy
278-
const commandEncoder = this.backend.getCommandEncoder();
279-
this.backend.endComputePass();
280-
commandEncoder.copyBufferToBuffer(
281-
sourceGpuDataCache.gpuData.buffer,
282-
0,
283-
destinationGpuDataCache.gpuData.buffer,
284-
0,
285-
size,
286-
);
277+
if (this.backend.sessionStatus === 'capturing') {
278+
const command = {
279+
source: sourceGpuDataCache.gpuData.buffer,
280+
dest: destinationGpuDataCache.gpuData.buffer,
281+
size,
282+
};
283+
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
284+
sessionCommandList!.push(command);
285+
286+
this.backend.pendingDispatchNumber++;
287+
} else {
288+
// GPU copy
289+
const commandEncoder = this.backend.getCommandEncoder();
290+
this.backend.endComputePass();
291+
commandEncoder.copyBufferToBuffer(
292+
sourceGpuDataCache.gpuData.buffer,
293+
0,
294+
destinationGpuDataCache.gpuData.buffer,
295+
0,
296+
size,
297+
);
298+
}
287299
}
288300

289301
registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number {

js/web/lib/wasm/jsep/webgpu/program-manager.ts

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ export class ProgramManager {
4141
): void {
4242
TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
4343
const device = this.backend.device;
44-
const computePassEncoder = this.backend.getComputePassEncoder();
45-
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
4644
const entries = [];
4745
for (const input of inputs) {
4846
entries.push({ binding: entries.length, resource: { buffer: input.buffer } });
@@ -68,23 +66,27 @@ export class ProgramManager {
6866
};
6967
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
7068
sessionCommandList!.push(commandInfo);
71-
}
69+
} else {
70+
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
71+
const computePassEncoder = this.backend.getComputePassEncoder();
72+
computePassEncoder.setPipeline(buildArtifact.computePipeline);
73+
computePassEncoder.setBindGroup(0, bindGroup);
74+
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
7275

73-
computePassEncoder.setPipeline(buildArtifact.computePipeline);
74-
computePassEncoder.setBindGroup(0, bindGroup);
75-
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
76-
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
77-
this.backend.pendingDispatchNumber++;
76+
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
77+
this.backend.pendingDispatchNumber++;
7878

79-
if (
80-
this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber ||
81-
this.backend.queryType === 'at-passes'
82-
) {
83-
this.backend.endComputePass();
84-
}
85-
if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) {
86-
this.backend.flush();
79+
if (
80+
this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber ||
81+
this.backend.queryType === 'at-passes'
82+
) {
83+
this.backend.endComputePass();
84+
}
85+
if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) {
86+
this.backend.flush();
87+
}
8788
}
89+
8890
TRACE_FUNC_END(buildArtifact.programInfo.name);
8991
}
9092
dispose(): void {

0 commit comments

Comments
 (0)