@@ -23,13 +23,21 @@ import {
2323 TimestampQuery ,
2424} from './webgpu/types' ;
2525
/**
 * A recorded compute dispatch that can be encoded again later.
 *
 * - `kernelId`: numeric kernel identifier (presumably the kernel that produced this dispatch; confirm at the capture site).
 * - `computePipeline`: pipeline handed to `setPipeline` when the command is re-encoded.
 * - `bindGroup`: bind group set at index 0 when the command is re-encoded.
 * - `dispatchGroup`: the three workgroup counts spread into `dispatchWorkgroups`.
 */
26- interface CommandInfo {
26+ interface ComputeCommand {
2727 readonly kernelId : number ;
2828 readonly computePipeline : GPUComputePipeline ;
2929 readonly bindGroup : GPUBindGroup ;
3030 readonly dispatchGroup : [ number , number , number ] ;
3131}
3232
/**
 * A recorded buffer-to-buffer copy.
 *
 * When re-encoded it becomes `copyBufferToBuffer(source, 0, dest, 0, size)`, i.e. both offsets are always 0.
 *
 * - `source`: buffer the data is read from.
 * - `dest`: buffer the data is written to.
 * - `size`: copy size passed to `copyBufferToBuffer` (a byte count per the WebGPU API).
 */
33+ interface MemcpyCommand {
34+ readonly source : GPUBuffer ;
35+ readonly dest : GPUBuffer ;
36+ readonly size : number ;
37+ }
38+
/**
 * Any command that can be stored in a captured command list: a compute dispatch or a buffer copy.
 * The union has no explicit tag; consumers tell the variants apart by the presence of the
 * `bindGroup` property, which only `ComputeCommand` declares.
 */
39+ type Command = ComputeCommand | MemcpyCommand ;
40+
3341interface KernelInfo {
3442 readonly kernelType : string ;
3543 readonly kernelName : string ;
@@ -234,9 +242,9 @@ export class WebGpuBackend {
234242 env : Env ;
235243 sessionStatus : SessionState = 'default' ;
236244 /**
237- * a SessionID -> CommandInfo [] mapping. It's used to record all GPU commands for corresponding session.
245+ * A SessionID -> Command[] mapping. It records all captured GPU commands for the corresponding session.
238246 */
239- capturedCommandList : Map < number , CommandInfo [ ] > = new Map ( ) ;
247+ capturedCommandList : Map < number , Command [ ] > = new Map ( ) ;
240248
241249 /**
242250 * a SessionID -> PendingKernelInfo[] mapping for profiling.
@@ -837,13 +845,19 @@ export class WebGpuBackend {
837845 }
838846 return gpuData . buffer ;
839847 }
848+
/**
 * Calls `replay()` on this backend and then downloads the contents of a GPU buffer.
 *
 * NOTE(review): `replay()` is invoked without `await`; this presumably relies on it being
 * synchronous (only encoding/submitting work). Confirm against its definition.
 *
 * @param gpuBuffer the GPU buffer to read back.
 * @param originalSize size forwarded unchanged to `downloadGpuData` (presumably the byte length to read; confirm).
 * @returns a promise resolving to the `Uint8Array` produced by `downloadGpuData`.
 */
849+ async replayAndDownloadGpuData ( gpuBuffer : GPUBuffer , originalSize : number ) : Promise < Uint8Array > {
850+ this . replay ( ) ;
851+ return downloadGpuData ( this , gpuBuffer , originalSize ) ;
852+ }
853+
/**
 * Builds a deferred downloader for a GPU buffer.
 *
 * Nothing is read back when this method is called. The returned async function performs the
 * download each time it is invoked and wraps the resulting bytes with `createView`.
 *
 * In this diff the `-` line is the previous behavior (calling `downloadGpuData` directly) and
 * the `+` line is the new behavior (going through `replayAndDownloadGpuData`).
 *
 * @param gpuBuffer the GPU buffer to read back.
 * @param size size forwarded to the download call.
 * @param type data type passed to `createView` to choose the returned view.
 * @returns a function that resolves to the downloaded data as a `Tensor.DataType`.
 */
840854 createDownloader (
841855 gpuBuffer : GPUBuffer ,
842856 size : number ,
843857 type : Tensor . GpuBufferDataTypes ,
844858 ) : ( ) => Promise < Tensor . DataType > {
845859 return async ( ) => {
846- const data = await downloadGpuData ( this , gpuBuffer , size ) ;
860+ const data = await this . replayAndDownloadGpuData ( gpuBuffer , size ) ;
847861 return createView ( data . buffer , type ) ;
848862 } ;
849863 }
@@ -909,18 +923,27 @@ export class WebGpuBackend {
909923 for ( let i = 0 ; i < length ; i ++ ) {
910924 const computePassEncoder = this . getComputePassEncoder ( ) ;
911925 const command = sessionCommandList ! [ i ] ;
912- this . writeTimestamp ( this . pendingDispatchNumber * 2 ) ;
913- computePassEncoder . setPipeline ( command . computePipeline ) ;
914- computePassEncoder . setBindGroup ( 0 , command . bindGroup ) ;
915- computePassEncoder . dispatchWorkgroups ( ...command . dispatchGroup ) ;
916- this . writeTimestamp ( this . pendingDispatchNumber * 2 + 1 ) ;
917- this . pendingDispatchNumber ++ ;
918- if ( this . queryType !== 'none' ) {
919- this . pendingKernels . push ( sessionPendingKernels ! [ i ] ) ;
920- }
921- if ( this . pendingDispatchNumber >= this . maxDispatchNumber || this . queryType === 'at-passes' ) {
926+ if ( 'bindGroup' in command ) {
927+ this . writeTimestamp ( this . pendingDispatchNumber * 2 ) ;
928+ computePassEncoder . setPipeline ( command . computePipeline ) ;
929+ computePassEncoder . setBindGroup ( 0 , command . bindGroup ) ;
930+ computePassEncoder . dispatchWorkgroups ( ...command . dispatchGroup ) ;
931+ this . writeTimestamp ( this . pendingDispatchNumber * 2 + 1 ) ;
932+
933+ this . pendingDispatchNumber ++ ;
934+ if ( this . queryType !== 'none' ) {
935+ this . pendingKernels . push ( sessionPendingKernels ! [ i ] ) ;
936+ }
937+ if ( this . pendingDispatchNumber >= this . maxDispatchNumber || this . queryType === 'at-passes' ) {
938+ this . endComputePass ( ) ;
939+ }
940+ } else {
941+ const commandEncoder = this . getCommandEncoder ( ) ;
942+ this . pendingDispatchNumber ++ ;
922943 this . endComputePass ( ) ;
944+ commandEncoder . copyBufferToBuffer ( command . source , 0 , command . dest , 0 , command . size ) ;
923945 }
946+
924947 if ( this . pendingDispatchNumber >= this . maxDispatchNumber ) {
925948 this . flush ( ) ;
926949 }
0 commit comments