forked from Xilinx/onnx-mlir
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathNNPALimit.cpp
More file actions
148 lines (129 loc) · 4.75 KB
/
Copy pathNNPALimit.cpp
File metadata and controls
148 lines (129 loc) · 4.75 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
/*
* SPDX-License-Identifier: Apache-2.0
*/
//===----------------------- NNPALimit.cpp --------------------------------===//
//
// Copyright 2022-2024 The IBM Research Authors.
//
// =============================================================================
//
// The NNPA constant values.
//
//===----------------------------------------------------------------------===//
#include "src/Accelerators/NNPA/Support/NNPALimit.hpp"
#include "src/Compiler/CompilerOptions.hpp"
#include <assert.h>
#include <string>
using namespace onnx_mlir;
//===----------------------------------------------------------------------===//
// Scan an mcpu or march flag value into an NNPALevel.
//
// Recognized values (exact, case-sensitive match):
//   "z16", "arch14" -> NNPALevel::M14
//   "z17", "arch15" -> NNPALevel::M15
// Any other value, including the empty string, yields NNPALevel::NONE.
//
// The flag is taken by const reference because this function is called over
// and over again; taking it by value would copy the string on every call.
static NNPALevel getNNPAFromTargetFlag(const std::string &str) {
  if (str == "z16" || str == "arch14")
    return NNPALevel::M14;
  if (str == "z17" || str == "arch15")
    return NNPALevel::M15;
  return NNPALevel::NONE;
}
// Derive the NNPA level from the compiler flags. The march flag has priority;
// mcpu is consulted only when march does not map to a known level.
NNPALevel getNNPAFromFlags() {
  const NNPALevel fromMarch = getNNPAFromTargetFlag(march);
  if (fromMarch != NNPALevel::NONE)
    return fromMarch;
  return getNNPAFromTargetFlag(mcpu);
}
//===----------------------------------------------------------------------===//
// Print NNPALevel as a string (depending on which option was given).
//
// The level is printed using mcpu, march, or both, depending on which of the
// options were given to the compiler. When both were given, the two parts are
// separated by a single space, mcpu first. The result is empty when neither
// option was given or when level is NNPALevel::NONE.
//
// Favor the zYY names below over the archXX names.
// NOTE(review): M15 is currently printed as "arch15" even though the flag
// parser in this file also accepts "z17"; confirm whether "z17" should be
// printed here instead.
std::string getNNPAString(NNPALevel level) {
  std::string val;
  if (!mcpu.empty()) {
    // The mcpu compiler option is defined, give an answer.
    if (level == NNPALevel::M14)
      val = "--mcpu=z16"; // Note: --mcpu is deprecated.
    else if (level == NNPALevel::M15)
      val = "--mcpu=arch15"; // Note: --mcpu is deprecated.
    else
      assert(level == NNPALevel::NONE && "unknown mcpu option");
  }
  if (!march.empty()) {
    // val can only be non-empty here when level is M14 or M15, so checking
    // val alone is enough to decide whether a separator is needed.
    if (!val.empty())
      val += " ";
    // The march compiler option is defined, give an answer.
    if (level == NNPALevel::M14)
      val += "--march=z16";
    else if (level == NNPALevel::M15)
      val += "--march=arch15";
    else
      assert(level == NNPALevel::NONE && "unknown march option");
  }
  return val;
}
/// Check whether the given NNPA level, e.g. the level for "z16", is compatible
/// with the NNPA level selected by the compiler flags (see getNNPAFromFlags).
/// Returns true when `level` is less than or equal to the flag level, except
/// that false is returned when both levels are NONE.
bool isCompatibleWithNNPALevel(NNPALevel level) {
  const NNPALevel current = getNNPAFromFlags();
  const bool bothNone =
      (level == NNPALevel::NONE) && (current == NNPALevel::NONE);
  return !bothNone && (level <= current);
}
/// Check whether the NNPA level selected by the compiler flags (see
/// getNNPAFromFlags), e.g. the level for "z16", is less than or equal to the
/// given NNPA level. Returns false when both levels are NONE.
bool isLessEqualNNPALevel(NNPALevel level) {
  const NNPALevel current = getNNPAFromFlags();
  const bool bothNone =
      (level == NNPALevel::NONE) && (current == NNPALevel::NONE);
  return !bothNone && (current <= level);
}
//===----------------------------------------------------------------------===//
// Max dimension checks
//
// Maximum dimension index size supported by the NNPA for arch14, as obtained
// with zdnn_get_nnpa_max_dim_idx_size(). This value depends on the hardware.
static constexpr int64_t NNPA_ARCH14_MAXIMUM_DIMENSION_INDEX_SIZE = 32768;
/*
ARCH15 sizes are dimension dependent. The values in the table below were
obtained by running:
for(int i=1; i<=4; ++i) {
uint32_t maxDimSize = zdnn_get_max_for_dim((uint8_t) i);
printf(" max size for dim e%i: %i\n", i, (int) maxDimSize);
}
which printed:
max size for dim e1: 2097152
max size for dim e2: 1048576
max size for dim e3: 32768
max size for dim e4: 32768
*/
// Per-dimension maxima for arch15; entry (e - 1) holds the limit for dim e.
static constexpr int64_t NNPA_ARCH15_MAXIMUM_DIMENSION_INDEX_SIZES[] = {
/*e1*/ 2097152, /*e2*/ 1048576, /*e3*/ 32768, /*e4*/ 32768};
// Return the maximum size allowed for dimension `dim` (index into the memref
// shape) of a tensor of the given `rank`, for the NNPA level selected by the
// compiler flags. Returns 0 when the rank exceeds 4 or when no compatible NNPA
// level is selected.
int64_t NNPAGetMaxForDim(int64_t dim, int64_t rank) {
  assert(rank >= 0 && "expected positive rank");
  assert(dim >= 0 && dim < rank && "dim outside range [0..rank)");
  // Only ranks up to 4 are handled; anything larger yields 0.
  if (rank <= 4) {
    // Levels are tested from newest NNPA to oldest, so that the most recent
    // compatible one is selected.
    if (isCompatibleWithNNPALevel(NNPALevel::M15)) {
      // Memref indices count from the outermost dimension, while the e-numbers
      // count from the innermost one:
      //   rank 4: memref index (0, 1, 2, 3) -> e (4, 3, 2, 1)
      //   rank 3: memref index (0, 1, 2)    -> e (3, 2, 1)
      //   rank 2: memref index (0, 1)       -> e (2, 1)
      //   rank 1: memref index (0)          -> e (1)
      // The table stores the limit for dim e at position e - 1.
      const int64_t tableIndex = (rank - dim) - 1;
      return NNPA_ARCH15_MAXIMUM_DIMENSION_INDEX_SIZES[tableIndex];
    }
    if (isCompatibleWithNNPALevel(NNPALevel::M14))
      return NNPA_ARCH14_MAXIMUM_DIMENSION_INDEX_SIZE;
  }
  return 0;
}