-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path01_texture_inference.comp
More file actions
49 lines (42 loc) · 1.85 KB
/
Copy path01_texture_inference.comp
File metadata and controls
49 lines (42 loc) · 1.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
/*!
\file 01_texture_inference.comp
\author Sho Ikeda
\brief HLSL compute kernel for texture generation using MLP
\copyright Copyright (c) 2026 Advanced Micro Devices, Inc. All Rights Reserved.
SPDX-License-Identifier: MIT
*/
#include "mlp.hlsl"
#include "texture_inference_common.hlsl"
// Shader resources forwarded by inference() to texkernel::inferenceStep.
// NOTE(review): the byte layout of each buffer is defined by the included
// headers, not by this file -- confirm there before changing host-side uploads.

// Read-only raw buffer; presumably holds the input UV coordinates -- TODO confirm.
ByteAddressBuffer UvBuffer;
// Read-write raw buffer; the only writable resource handed to inferenceStep.
RWByteAddressBuffer OutputBuffer;
// Read-only raw buffer; presumably holds the MLP weight matrices -- TODO confirm.
ByteAddressBuffer WeightBuffer;
// Read-only raw buffer; presumably holds the MLP bias vectors -- TODO confirm.
ByteAddressBuffer BiasBuffer;
// Global int constants. inference() casts both to uint and passes them to
// inferenceStep packed as uint2(FIRST, HIDDEN).
// NOTE(review): units (bytes vs. elements) are not visible in this file -- confirm.
int TEST_WEIGHT_MATRIX_SIZE_FIRST;
int TEST_WEIGHT_MATRIX_SIZE_HIDDEN;
/*!
  \brief Forwards the work of one thread to texkernel::inferenceStep.

  The network configuration (layer count, hidden dimensions, activations,
  matrix layout, alignments, bias on/off) comes entirely from the MINIDXNN_*
  build macros; only the scalar type is chosen by the caller.

  \tparam Type Scalar type; its component enum is looked up through
               mininn::impl::TypeTraits<Type>::COMPONENT_TYPE.
  \param [in] threadId Thread index, passed through to inferenceStep unchanged.
  */
template <typename Type>
void inference(const uint threadId)
{
  // Activation types selected at build time inside the mininn namespace.
  using HiddenActivation = mininn:: MINIDXNN_ACTIVATION_HIDDEN_TYPE;
  using LastActivation = mininn:: MINIDXNN_ACTIVATION_LAST_TYPE;

  // Compile-time constants consumed as template arguments below.
  const dx::linalg::ComponentEnum componentType = mininn::impl::TypeTraits<Type>::COMPONENT_TYPE;
  const dx::linalg::MatrixLayoutEnum weightLayout = (dx::linalg::MatrixLayoutEnum)MINIDXNN_WEIGHT_MATRIX_LAYOUT;

  // The two signed size constants travel together as (first, hidden).
  const uint2 weightMatrixSizes = uint2((uint)TEST_WEIGHT_MATRIX_SIZE_FIRST,
                                        (uint)TEST_WEIGHT_MATRIX_SIZE_HIDDEN);

  texkernel::inferenceStep<Type,
                           MINIDXNN_NUM_LAYERS,
                           MINIDXNN_HIDDEN_LAYER_DIMENSIONS,
                           componentType,
                           weightLayout,
                           HiddenActivation,
                           LastActivation,
                           MINIDXNN_WEIGHT_MATRIX_ALIGNMENT,
                           MINIDXNN_WEIGHT_MATRIX_VECTOR_STRIDE_ALIGNMENT,
                           MINIDXNN_BIAS_VECTOR_ALIGNMENT,
                           (MINIDXNN_HAS_BIAS != 0)>(
      threadId,
      UvBuffer,
      OutputBuffer,
      WeightBuffer,
      BiasBuffer,
      weightMatrixSizes,
      MINIDXNN_NUM_TASKS);
}
/*!
  \brief Compute entry point that runs inference<float> for one thread.

  \param [in] groupThreadId Thread position inside its group; only .x is used.
  \param [in] groupId Group position inside the dispatch; only .x is used.
  */
[numthreads(MINIDXNN_NUM_THREADS_X, 1, 1)]
void inferenceF32Kernel(const uint3 groupThreadId : SV_GroupThreadID, const uint3 groupId : SV_GroupID)
{
  // Flat x index: first thread of this group plus the offset within the group.
  const uint groupOffset = groupId.x * MINIDXNN_NUM_THREADS_X;
  inference<float>(groupOffset + groupThreadId.x);
}
/*!
  \brief Compute entry point that runs inference<half> for one thread.

  \param [in] groupThreadId Thread position inside its group; only .x is used.
  \param [in] groupId Group position inside the dispatch; only .x is used.
  */
[numthreads(MINIDXNN_NUM_THREADS_X, 1, 1)]
void inferenceF16Kernel(const uint3 groupThreadId : SV_GroupThreadID, const uint3 groupId : SV_GroupID)
{
  // Flat x index: first thread of this group plus the offset within the group.
  const uint groupOffset = groupId.x * MINIDXNN_NUM_THREADS_X;
  inference<half>(groupOffset + groupThreadId.x);
}