Skip to content

Commit e8b80b9

Browse files
committed
[js/webgpu] Donot record with computePassEncoder when capturing
1 parent 497b06f commit e8b80b9

File tree

3 files changed

+55
-23
lines changed

3 files changed

+55
-23
lines changed

js/web/lib/wasm/jsep/backend-webgpu.ts

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,13 @@ import {
2424
} from './webgpu/types';
2525

2626
interface CommandInfo {
27-
readonly kernelId: number;
28-
readonly computePipeline: GPUComputePipeline;
29-
readonly bindGroup: GPUBindGroup;
30-
readonly dispatchGroup: [number, number, number];
27+
readonly kernelId?: number;
28+
readonly computePipeline?: GPUComputePipeline;
29+
readonly bindGroup?: GPUBindGroup;
30+
readonly dispatchGroup?: [number, number, number];
31+
readonly source?: GPUBuffer;
32+
readonly dest?: GPUBuffer;
33+
readonly size?: number;
3134
}
3235

3336
interface KernelInfo {
@@ -909,10 +912,16 @@ export class WebGpuBackend {
909912
for (let i = 0; i < length; i++) {
910913
const computePassEncoder = this.getComputePassEncoder();
911914
const command = sessionCommandList![i];
912-
this.writeTimestamp(this.pendingDispatchNumber * 2);
913-
computePassEncoder.setPipeline(command.computePipeline);
914-
computePassEncoder.setBindGroup(0, command.bindGroup);
915-
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup);
915+
if (command.bindGroup) {
916+
this.writeTimestamp(this.pendingDispatchNumber * 2);
917+
computePassEncoder.setPipeline(command.computePipeline!);
918+
computePassEncoder.setBindGroup(0, command.bindGroup);
919+
computePassEncoder.dispatchWorkgroups(...command.dispatchGroup!);
920+
} else {
921+
this.writeTimestamp(this.pendingDispatchNumber * 2);
922+
const commandEncoder = this.getCommandEncoder();
923+
commandEncoder.copyBufferToBuffer(command.source!, 0, command.dest!, 0, command.size!);
924+
}
916925
this.writeTimestamp(this.pendingDispatchNumber * 2 + 1);
917926
this.pendingDispatchNumber++;
918927
if (this.queryType !== 'none') {

js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -274,16 +274,39 @@ class GpuDataManagerImpl implements GpuDataManager {
274274

275275
const size = calcNormalizedBufferSize(sourceGpuDataCache.originalSize);
276276

277-
// GPU copy
278-
const commandEncoder = this.backend.getCommandEncoder();
279-
this.backend.endComputePass();
280-
commandEncoder.copyBufferToBuffer(
281-
sourceGpuDataCache.gpuData.buffer,
282-
0,
283-
destinationGpuDataCache.gpuData.buffer,
284-
0,
285-
size,
286-
);
277+
if (this.backend.sessionStatus === 'capturing') {
278+
const commandInfo = {
279+
source: sourceGpuDataCache.gpuData.buffer,
280+
dest: destinationGpuDataCache.gpuData.buffer,
281+
size: size,
282+
};
283+
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
284+
sessionCommandList!.push(commandInfo);
285+
} else {
286+
// GPU copy
287+
const commandEncoder = this.backend.getCommandEncoder();
288+
this.backend.endComputePass();
289+
commandEncoder.copyBufferToBuffer(
290+
sourceGpuDataCache.gpuData.buffer,
291+
0,
292+
destinationGpuDataCache.gpuData.buffer,
293+
0,
294+
size,
295+
);
296+
}
297+
298+
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
299+
this.backend.pendingDispatchNumber++;
300+
301+
if (
302+
this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber ||
303+
this.backend.queryType === 'at-passes'
304+
) {
305+
this.backend.endComputePass();
306+
}
307+
if (this.backend.pendingDispatchNumber >= this.backend.maxDispatchNumber) {
308+
this.backend.flush();
309+
}
287310
}
288311

289312
registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number {

js/web/lib/wasm/jsep/webgpu/program-manager.ts

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ export class ProgramManager {
4141
): void {
4242
TRACE_FUNC_BEGIN(buildArtifact.programInfo.name);
4343
const device = this.backend.device;
44-
const computePassEncoder = this.backend.getComputePassEncoder();
4544
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2);
4645
const entries = [];
4746
for (const input of inputs) {
@@ -68,11 +67,12 @@ export class ProgramManager {
6867
};
6968
const sessionCommandList = this.backend.capturedCommandList.get(this.backend.currentSessionId!);
7069
sessionCommandList!.push(commandInfo);
70+
} else {
71+
const computePassEncoder = this.backend.getComputePassEncoder();
72+
computePassEncoder.setPipeline(buildArtifact.computePipeline);
73+
computePassEncoder.setBindGroup(0, bindGroup);
74+
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
7175
}
72-
73-
computePassEncoder.setPipeline(buildArtifact.computePipeline);
74-
computePassEncoder.setBindGroup(0, bindGroup);
75-
computePassEncoder.dispatchWorkgroups(...dispatchGroup);
7676
this.backend.writeTimestamp(this.backend.pendingDispatchNumber * 2 + 1);
7777
this.backend.pendingDispatchNumber++;
7878

0 commit comments

Comments
 (0)