diff --git a/src/Three.TSL.js b/src/Three.TSL.js index a2dd262411c7e9..1c41325e4d4d0e 100644 --- a/src/Three.TSL.js +++ b/src/Three.TSL.js @@ -392,6 +392,10 @@ export const rotateUV = TSL.rotateUV; export const roughness = TSL.roughness; export const round = TSL.round; export const rtt = TSL.rtt; +export const quadBroadcast = TSL.quadBroadcast; +export const quadSwapDiagonal = TSL.quadSwapDiagonal; +export const quadSwapX = TSL.quadSwapX; +export const quadSwapY = TSL.quadSwapY; export const sRGBTransferEOTF = TSL.sRGBTransferEOTF; export const sRGBTransferOETF = TSL.sRGBTransferOETF; export const sampler = TSL.sampler; @@ -434,8 +438,29 @@ export const storageObject = TSL.storageObject; export const storageTexture = TSL.storageTexture; export const string = TSL.string; export const sub = TSL.sub; +export const subgroupAdd = TSL.subgroupAdd; +export const subgroupAll = TSL.subgroupAll; +export const subgroupAnd = TSL.subgroupAnd; +export const subgroupAny = TSL.subgroupAny; +export const subgroupBallot = TSL.subgroupBallot; +export const subgroupBroadcast = TSL.subgroupBroadcast; +export const subgroupBroadcastFirst = TSL.subgroupBroadcastFirst; +export const subgroupElect = TSL.subgroupElect; +export const subgroupExclusiveAdd = TSL.subgroupExclusiveAdd; +export const subgroupExclusiveMul = TSL.subgroupExclusiveMul; +export const subgroupInclusiveAdd = TSL.subgroupInclusiveAdd; +export const subgroupInclusiveMul = TSL.subgroupInclusiveMul; export const subgroupIndex = TSL.subgroupIndex; +export const subgroupMax = TSL.subgroupMax; +export const subgroupMin = TSL.subgroupMin; +export const subgroupMul = TSL.subgroupMul; +export const subgroupOr = TSL.subgroupOr; +export const subgroupShuffle = TSL.subgroupShuffle; +export const subgroupShuffleDown = TSL.subgroupShuffleDown; +export const subgroupShuffleUp = TSL.subgroupShuffleUp; +export const subgroupShuffleXor = TSL.subgroupShuffleXor; export const subgroupSize = TSL.subgroupSize; +export const subgroupXor = TSL.subgroupXor; export const tan = TSL.tan; export const tangentGeometry = TSL.tangentGeometry; export const tangentLocal = TSL.tangentLocal; diff --git a/src/nodes/TSL.js b/src/nodes/TSL.js index 0e6d480d0a3eea..bd4f9db530addb 100644 --- a/src/nodes/TSL.js +++ b/src/nodes/TSL.js @@ -128,6 +128,7 @@ export * from './gpgpu/ComputeBuiltinNode.js'; export * from './gpgpu/BarrierNode.js'; export * from './gpgpu/WorkgroupInfoNode.js'; export * from './gpgpu/AtomicFunctionNode.js'; +export * from './gpgpu/SubgroupFunctionNode.js'; // lighting export * from './accessors/Lights.js'; diff --git a/src/nodes/gpgpu/SubgroupFunctionNode.js b/src/nodes/gpgpu/SubgroupFunctionNode.js new file mode 100644 index 00000000000000..786e7cfb090a0b --- /dev/null +++ b/src/nodes/gpgpu/SubgroupFunctionNode.js @@ -0,0 +1,432 @@ +import TempNode from '../core/TempNode.js'; +import { addMethodChaining, nodeProxy } from '../tsl/TSLCore.js'; + + +/** + * This class represents a set of built in WGSL shader functions that sync + * synchronously execute an operation across a subgroup, or 'wave', of compute + * or fragment shader invocations within a workgroup. Typically, these functions + * will synchronously execute an operation using data from all active invocations + * within the subgroup, then broadcast that result to all active invocations. In + * other graphics APIs, subgroup functions are also referred to as wave intrinsics + * (DirectX/HLSL) or warp intrinsics (CUDA). + * + * @augments TempNode + */ +class SubgroupFunctionNode extends TempNode { + + static get type() { + + return 'SubgroupFunctionNode'; + + } + + /** + * Constructs a new function node. + * + * @param {String} method - The subgroup/wave intrinsic method to construct. + * @param {Node} [aNode=null] - The method's first argument. + * @param {Node} [bNode=null] - The method's second argument. + */ + constructor( method, aNode = null, bNode = null ) { + + super(); + + /** + * The subgroup/wave intrinsic method to construct. + * + * @type {String} + */ + this.method = method; + + /** + * The method's first argument. + * + * @type {Node} + */ + this.aNode = aNode; + + /** + * The method's second argument. + * + * @type {Node} + */ + this.bNode = bNode; + + } + + getInputType( builder ) { + + const aType = this.aNode ? this.aNode.getNodeType( builder ) : null; + const bType = this.bNode ? this.bNode.getNodeType( builder ) : null; + + const aLen = builder.isMatrix( aType ) ? 0 : builder.getTypeLength( aType ); + const bLen = builder.isMatrix( bType ) ? 0 : builder.getTypeLength( bType ); + + if ( aLen > bLen ) { + + return aType; + + } else { + + return bType; + + } + + } + + getNodeType( builder ) { + + const method = this.method; + + if ( method === SubgroupFunctionNode.SUBGROUP_ELECT ) { + + return 'bool'; + + } else if ( method === SubgroupFunctionNode.SUBGROUP_BALLOT ) { + + return 'uvec4'; + + } else { + + return this.getInputType( builder ); + + } + + } + + generate( builder, output ) { + + const method = this.method; + + const type = this.getNodeType( builder ); + const inputType = this.getInputType( builder ); + + const a = this.aNode; + const b = this.bNode; + + const params = []; + + if ( + method === SubgroupFunctionNode.SUBGROUP_BROADCAST || + method === SubgroupFunctionNode.SUBGROUP_SHUFFLE || + method === SubgroupFunctionNode.QUAD_BROADCAST + ) { + + const bType = b.getNodeType( builder ); + + params.push( + a.build( builder, type ), + b.build( builder, bType === 'float' ? 'int' : type ) + ); + + } else if ( + method === SubgroupFunctionNode.SUBGROUP_SHUFFLE_XOR || + method === SubgroupFunctionNode.SUBGROUP_SHUFFLE_DOWN || + method === SubgroupFunctionNode.SUBGROUP_SHUFFLE_UP + ) { + + params.push( + a.build( builder, type ), + b.build( builder, 'uint' ) + ); + + } else { + + if ( a !== null ) params.push( a.build( builder, inputType ) ); + if ( b !== null ) params.push( b.build( builder, inputType ) ); + + } + + const paramsString = params.length === 0 ? '()' : `( ${params.join( ', ' )} )`; + + return builder.format( `${ builder.getMethod( method, type ) }${paramsString}`, type, output ); + + + + } + + serialize( data ) { + + super.serialize( data ); + + data.method = this.method; + + } + + deserialize( data ) { + + super.deserialize( data ); + + this.method = data.method; + + } + +} + +// 0 inputs +SubgroupFunctionNode.SUBGROUP_ELECT = 'subgroupElect'; + +// 1 input +SubgroupFunctionNode.SUBGROUP_BALLOT = 'subgroupBallot'; +SubgroupFunctionNode.SUBGROUP_ADD = 'subgroupAdd'; +SubgroupFunctionNode.SUBGROUP_INCLUSIVE_ADD = 'subgroupInclusiveAdd'; +SubgroupFunctionNode.SUBGROUP_EXCLUSIVE_AND = 'subgroupExclusiveAdd'; +SubgroupFunctionNode.SUBGROUP_MUL = 'subgroupMul'; +SubgroupFunctionNode.SUBGROUP_INCLUSIVE_MUL = 'subgroupInclusiveMul'; +SubgroupFunctionNode.SUBGROUP_EXCLUSIVE_MUL = 'subgroupExclusiveMul'; +SubgroupFunctionNode.SUBGROUP_AND = 'subgroupAnd'; +SubgroupFunctionNode.SUBGROUP_OR = 'subgroupOr'; +SubgroupFunctionNode.SUBGROUP_XOR = 'subgroupXor'; +SubgroupFunctionNode.SUBGROUP_MIN = 'subgroupMin'; +SubgroupFunctionNode.SUBGROUP_MAX = 'subgroupMax'; +SubgroupFunctionNode.SUBGROUP_ALL = 'subgroupAll'; +SubgroupFunctionNode.SUBGROUP_ANY = 'subgroupAny'; +SubgroupFunctionNode.SUBGROUP_BROADCAST_FIRST = 'subgroupBroadcastFirst'; +SubgroupFunctionNode.QUAD_SWAP_X = 'quadSwapX'; +SubgroupFunctionNode.QUAD_SWAP_Y = 'quadSwapY'; +SubgroupFunctionNode.QUAD_SWAP_DIAGONAL = 'quadSwapDiagonal'; + +// 2 inputs +SubgroupFunctionNode.SUBGROUP_BROADCAST = 'subgroupBroadcast'; +SubgroupFunctionNode.SUBGROUP_SHUFFLE = 'subgroupShuffle'; +SubgroupFunctionNode.SUBGROUP_SHUFFLE_XOR = 'subgroupShuffleXor'; +SubgroupFunctionNode.SUBGROUP_SHUFFLE_UP = 'subgroupShuffleUp'; +SubgroupFunctionNode.SUBGROUP_SHUFFLE_DOWN = 'subgroupShuffleDown'; +SubgroupFunctionNode.QUAD_BROADCAST = 'quadBroadcast'; + +export default SubgroupFunctionNode; + + + +/** + * Returns true if this invocation has the lowest subgroup_invocation_id + * among active invocations in the subgroup. + * + * @method + * @return {bool} The result of the computation. + */ +export const subgroupElect = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_ELECT ); + +/** + * Returns a set of bitfields where the bit corresponding to subgroup_invocation_id + * is 1 if pred is true for that active invocation and 0 otherwise. + * + * @method + * @param {bool} pred - A boolean that sets the bit corresponding to the invocations subgroup invocation id. + * @return {vec4}- A bitfield corresponding to the pred value of each subgroup invocation. + */ +export const subgroupBallot = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_BALLOT ); + +/** + * A reduction that adds e among all active invocations and returns that result. + * + * @method + * @param {number} e - The value provided to the reduction by the current invocation. + * @return {number} The accumulated result of the reduction operation. + */ +export const subgroupAdd = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_ADD ); + +/** + * An inclusive scan returning the sum of e for all active invocations with subgroup_invocation_id less than or equal to this invocation. + * + * @method + * @param {number} e - The value provided to the inclusive scan by the current invocation. + * @return {number} The accumulated result of the inclusive scan operation. + */ +export const subgroupInclusiveAdd = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_INCLUSIVE_ADD ); + +/** + * An exclusive scan that returns the sum of e for all active invocations with subgroup_invocation_id less than this invocation. + * + * @method + * @param {number} e - The value provided to the exclusive scan by the current invocation. + * @return {number} The accumulated result of the exclusive scan operation. + */ +export const subgroupExclusiveAdd = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_EXCLUSIVE_AND ); + +/** + * A reduction that multiplies e among all active invocations and returns that result. + * + * @method + * @param {number} e - The value provided to the reduction by the current invocation. + * @return {number} The accumulated result of the reduction operation. + */ +export const subgroupMul = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_MUL ); + +/** + * An inclusive scan returning the product of e for all active invocations with subgroup_invocation_id less than or equal to this invocation. + * + * @method + * @param {number} e - The value provided to the inclusive scan by the current invocation. + * @return {number} The accumulated result of the inclusive scan operation. + */ +export const subgroupInclusiveMul = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_INCLUSIVE_MUL ); + +/** + * An exclusive scan that returns the product of e for all active invocations with subgroup_invocation_id less than this invocation. + * + * @method + * @param {number} e - The value provided to the exclusive scan by the current invocation. + * @return {number} The accumulated result of the exclusive scan operation. + */ +export const subgroupExclusiveMul = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_EXCLUSIVE_MUL ); + +/** + * A reduction that performs a bitwise and of e among all active invocations and returns that result. + * + * @method + * @param {number} e - The value provided to the reduction by the current invocation. + * @return {number} The result of the reduction operation. + */ +export const subgroupAnd = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_AND ); + +/** + * A reduction that performs a bitwise or of e among all active invocations and returns that result. + * + * @method + * @param {number} e - The value provided to the reduction by the current invocation. + * @return {number} The result of the reduction operation. + */ +export const subgroupOr = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_OR ); + +/** + * A reduction that performs a bitwise xor of e among all active invocations and returns that result. + * + * @method + * @param {number} e - The value provided to the reduction by the current invocation. + * @return {number} The result of the reduction operation. + */ +export const subgroupXor = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_XOR ); + +/** + * A reduction that performs a min of e among all active invocations and returns that result. + * + * @method + * @param {number} e - The value provided to the reduction by the current invocation. + * @return {number} The result of the reduction operation. + */ +export const subgroupMin = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_MIN ); + +/** + * A reduction that performs a max of e among all active invocations and returns that result. + * + * @method + * @param {number} e - The value provided to the reduction by the current invocation. + * @return {number} The result of the reduction operation. + */ +export const subgroupMax = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_MAX ); + +/** + * Returns true if e is true for all active invocations in the subgroup. + * + * @method + * @return {bool} The result of the computation. + */ +export const subgroupAll = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_ALL ); + +/** + * Returns true if e is true for any active invocation in the subgroup + * + * @method + * @return {bool} The result of the computation. + */ +export const subgroupAny = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_ANY ); + +/** + * Broadcasts e from the active invocation with the lowest subgroup_invocation_id in the subgroup to all other active invocations. + * + * @method + * @param {number} e - The value to broadcast from the lowest subgroup invocation. + * @param {number} id - The subgroup invocation to broadcast from. + * @return {number} The broadcast value. + */ +export const subgroupBroadcastFirst = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_BROADCAST_FIRST ); + +/** + * Swaps e between invocations in the quad in the X direction. + * + * @method + * @param {number} e - The value to swap from the current invocation. + * @return {number} The value received from the swap operation. + */ +export const quadSwapX = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.QUAD_SWAP_X ); + +/** + * Swaps e between invocations in the quad in the Y direction. + * + * @method + * @param {number} e - The value to swap from the current invocation. + * @return {number} The value received from the swap operation. + */ +export const quadSwapY = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.QUAD_SWAP_Y ); + +/** + * Swaps e between invocations in the quad diagonally. + * + * @method + * @param {number} e - The value to swap from the current invocation. + * @return {number} The value received from the swap operation. + */ +export const quadSwapDiagonal = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.QUAD_SWAP_DIAGONAL ); + +/** + * Broadcasts e from the invocation whose subgroup_invocation_id matches id, to all active invocations. + * + * @method + * @param {number} e - The value to broadcast from subgroup invocation 'id'. + * @param {number} id - The subgroup invocation to broadcast from. + * @return {number} The broadcast value. + */ +export const subgroupBroadcast = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_BROADCAST ); + +/** + * Returns v from the active invocation whose subgroup_invocation_id matches id + * + * @method + * @param {number} v - The value to return from subgroup invocation id^mask. + * @param {number} id - The subgroup invocation which returns the value v. + * @return {number} The broadcast value. + */ +export const subgroupShuffle = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_SHUFFLE ); + +/** + * Returns v from the active invocation whose subgroup_invocation_id matches subgroup_invocation_id ^ mask. + * + * @method + * @param {number} v - The value to return from subgroup invocation id^mask. + * @param {number} mask - A bitmask that determines the target invocation via a XOR operation. + * @return {number} The broadcast value. + */ +export const subgroupShuffleXor = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_SHUFFLE_XOR ); + +/** + * Returns v from the active invocation whose subgroup_invocation_id matches subgroup_invocation_id - delta + * + * @method + * @param {number} v - The value to return from subgroup invocation id^mask. + * @param {number} delta - A value that offsets the current in. + * @return {number} The broadcast value. + */ +export const subgroupShuffleUp = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_SHUFFLE_UP ); + +/** + * Returns v from the active invocation whose subgroup_invocation_id matches subgroup_invocation_id + delta + * + * @method + * @param {number} v - The value to return from subgroup invocation id^mask. + * @param {number} delta - A value that offsets the current subgroup invocation. + * @return {number} The broadcast value. + */ +export const subgroupShuffleDown = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.SUBGROUP_SHUFFLE_DOWN ); + +/** + * Broadcasts e from the quad invocation with id equal to id. + * + * @method + * @param {number} e - The value to broadcast. + * @return {number} The broadcast value. + */ +export const quadBroadcast = /*@__PURE__*/ nodeProxy( SubgroupFunctionNode, SubgroupFunctionNode.QUAD_BROADCAST ); + +addMethodChaining( 'subgroupElect', subgroupElect );