11import * as d from 'typegpu/data' ;
22import tgpu , { type Eventual , type TgpuFn } from 'typegpu' ;
3-
4- export type SampleFiller = TgpuFn <
5- (
6- x : d . I32 ,
7- y : d . I32 ,
8- outSamplerPtr : d . Ptr < 'function' , d . WgslArray < d . Vec4f > , 'read-write' > ,
9- ) => d . Void
10- > ;
11- export type KernelReader = TgpuFn < ( idx : d . U32 ) => d . Vec4f > ;
3+ import { dot } from 'typegpu/std' ;
124
135/** Has to be divisible by 4 */
146export const inChannelsSlot = tgpu . slot < number > ( ) ;
157export const outChannelsSlot = tgpu . slot < number > ( ) ;
168export const kernelRadiusSlot = tgpu . slot < number > ( ) ;
17- const sampleFillerSlot = tgpu . slot < SampleFiller > ( ) ;
18- const kernelReaderSlot = tgpu . slot < KernelReader > ( ) ;
199
2010export const inChannelsQuarter = tgpu [ '~unstable' ] . derived ( ( ) => {
21- if ( inChannelsSlot . value % 4 !== 0 ) {
11+ if ( inChannelsSlot . $ % 4 !== 0 ) {
2212 throw new Error ( `'inChannels' has to be divisible by 4` ) ;
2313 }
24- return inChannelsSlot . value / 4 ;
14+ return inChannelsSlot . $ / 4 ;
2515} ) ;
2616
17+ const SampleArray = tgpu [ '~unstable' ] . derived ( ( ) =>
18+ d . arrayOf ( d . vec4f , inChannelsQuarter . $ ) ,
19+ ) ;
20+
21+ const ResultArray = tgpu [ '~unstable' ] . derived ( ( ) =>
22+ d . arrayOf ( d . f32 , outChannelsSlot . $ ) ,
23+ ) ;
24+
25+ export type SampleFiller = TgpuFn <
26+ (
27+ x : d . I32 ,
28+ y : d . I32 ,
29+ outSamples : d . Ptr < 'function' , typeof SampleArray . $ , 'read-write' > ,
30+ ) => d . Void
31+ > ;
32+ export type KernelReader = TgpuFn < ( idx : d . U32 ) => d . Vec4f > ;
33+
34+ const fillSampleSlot = tgpu . slot < SampleFiller > ( ) ;
35+ const readKernelSlot = tgpu . slot < KernelReader > ( ) ;
36+
2737const _convolveFn = tgpu [ '~unstable' ] . derived ( ( ) => {
28- return tgpu
29- . fn ( [ d . vec2u , d . ptrFn ( d . arrayOf ( d . f32 , outChannelsSlot . value ) ) ] ) (
30- /* wgsl */ `(coord: vec2u, result: ptr<function, array<f32, outChannels>>) {
31- var sample = array<vec4f, inChannelsQuarter>( );
38+ return tgpu . fn ( [ d . vec2u , d . ptrFn ( ResultArray . $ ) ] ) ( ( coord , outResult ) => {
39+ const sample = SampleArray . $ ( ) ;
40+ const kernelRadius = d . i32 ( kernelRadiusSlot . $ ) ;
41+ const kernelRadiusU = d . u32 ( kernelRadiusSlot . $ ) ;
3242
33- var coord_idx: u32 = 0 ;
34- for (var i: i32 = -i32(kernelRadiusSlot) ; i <= i32(kernelRadiusSlot) ; i++) {
35- for (var j: i32 = -i32(kernelRadiusSlot) ; j <= i32(kernelRadiusSlot) ; j++) {
36- fillSample( i32(coord.x) + i, i32(coord.y) + j, & sample);
43+ let coord_idx = d . u32 ( 0 ) ;
44+ for ( let i = - kernelRadius ; i <= kernelRadius ; i ++ ) {
45+ for ( let j = - kernelRadius ; j <= kernelRadius ; j ++ ) {
46+ fillSampleSlot . $ ( d . i32 ( coord . x ) + i , d . i32 ( coord . y ) + j , sample ) ;
3747
38- for (var out_c: u32 = 0; out_c < outChannels; out_c++) {
39- var weight_idx = (coord_idx + out_c * (2 * kernelRadiusSlot + 1) * (2 * kernelRadiusSlot + 1)) * inChannelsQuarter;
40- for (var in_c: u32 = 0; in_c < inChannelsQuarter; in_c++) {
41- (*result)[out_c] += dot(sample[in_c], readKernel(weight_idx));
42- weight_idx++;
43- }
44- }
48+ for ( let out_c = d . u32 ( 0 ) ; out_c < outChannelsSlot . $ ; out_c ++ ) {
49+ let weight_idx =
50+ ( coord_idx +
51+ out_c * ( 2 * kernelRadiusU + 1 ) * ( 2 * kernelRadiusU + 1 ) ) *
52+ inChannelsQuarter . $ ;
4553
46- coord_idx++;
54+ for ( let in_c = d . u32 ( 0 ) ; in_c < inChannelsQuarter . $ ; in_c ++ ) {
55+ outResult [ out_c ] += dot ( sample [ in_c ] , readKernelSlot . $ ( weight_idx ) ) ;
56+ weight_idx ++ ;
4757 }
4858 }
49- }` ,
50- )
51- . $uses ( {
52- inChannelsQuarter,
53- outChannels : outChannelsSlot ,
54- kernelRadiusSlot,
55- fillSample : sampleFillerSlot ,
56- readKernel : kernelReaderSlot ,
57- } )
58- . $name ( '_convolveFn' ) ;
59+
60+ coord_idx ++ ;
61+ }
62+ }
63+ } ) ;
5964} ) ;
6065
6166export const convolveFn = ( {
@@ -66,5 +71,5 @@ export const convolveFn = ({
6671 kernelReader : Eventual < KernelReader > ;
6772} ) =>
6873 _convolveFn
69- . with ( sampleFillerSlot , sampleFiller )
70- . with ( kernelReaderSlot , kernelReader ) ;
74+ . with ( fillSampleSlot , sampleFiller )
75+ . with ( readKernelSlot , kernelReader ) ;
0 commit comments