forked from NVIDIA/TensorRT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgroupNormPluginCommon.h
More file actions
69 lines (62 loc) · 2.19 KB
/
groupNormPluginCommon.h
File metadata and controls
69 lines (62 loc) · 2.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TRT_GROUPNORM_PLUGIN_COMMON_H
#define TRT_GROUPNORM_PLUGIN_COMMON_H
#include <cstdint>
#include <cuda.h>
#include <cuda_fp16.h>
//! Parameter block shared by the NHWC GroupNorm kernels (optionally fused with Swish).
//!
//! Keep the member order stable: any positional aggregate initialization of this
//! struct elsewhere would silently mis-assign fields if members were reordered.
struct GroupNormNHWCParams
{
    // ---- Buffers -------------------------------------------------------------
    __half* dst;             //!< Output activations, fp16, NHWC layout.
    __half const* src;       //!< Input activations, fp16, NHWC layout (pointer-to-const).
    float const* gamma;      //!< GroupNorm scaling factor (pointer-to-const).
    float const* beta;       //!< GroupNorm additive term (pointer-to-const).
    //! Temporary storage for the global (cross-block) parallel reduction.
    //! Documented size: BLOCKS_PER_BATCH x C x 2.
    //! NOTE(review): confirm this against the workspace size actually reserved for it.
    float* redBuffer;

    // ---- Problem description -------------------------------------------------
    int32_t n;               //!< Number of instances in the batch.
    int32_t h;               //!< Height of each activation map.
    int32_t w;               //!< Width of each activation map.
    int32_t c;               //!< Number of channels.
    int32_t groups;          //!< Number of normalization groups.
    bool withSwish;          //!< When true, the Swish activation is applied as well.

    // ---- Precomputed values that steer kernel execution -----------------------
    int32_t hw;              //!< Activations per instance (h * w).
    int32_t hwPerBlock;      //!< Activations handled by each block.
    //! C-dimension blocking parameter; the name suggests channels per block.
    //! NOTE(review): the original comment read "blocks per activation in the C
    //! dimension" -- confirm the intended meaning against the kernel.
    int32_t cPerBlock;
    int32_t cPerGroup;       //!< Channels per group.
    //! Precomputed stride between consecutive instances. Stored as int32_t, so a
    //! single instance must stay below 2^31 elements.
    int32_t hwc;
    float invHWC;            //!< Inverse of hwc as a float, used to compute mean/variance.
    int32_t groupsPerBlock;  //!< Number of groups handled by each block.
};
#endif // TRT_GROUPNORM_PLUGIN_COMMON_H