Skip to content

Commit a78e83a

Browse files
authored
Merge pull request #47552 from quinnanm/axol1tl_emulator_v5_pr
Changes to prepare for AXOL1TL v5 (15_0_X)
2 parents 11b84ba + 59888a0 commit a78e83a

File tree

2 files changed

+27
-11
lines changed

2 files changed

+27
-11
lines changed

L1Trigger/L1TGlobal/interface/AXOL1TLCondition.h

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
#include <iosfwd>
1111
#include <string>
12+
#include <utility>
1213

1314
#include "L1Trigger/L1TGlobal/interface/ConditionEvaluation.h"
1415
#include "DataFormats/L1Trigger/interface/L1Candidate.h"

L1Trigger/L1TGlobal/src/AXOL1TLCondition.cc

+26-11
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@
4141
#include "FWCore/MessageLogger/interface/MessageLogger.h"
4242
#include "FWCore/MessageLogger/interface/MessageDrop.h"
4343

44+
namespace {
45+
//template function for reading results
46+
template <typename ResultType, typename LossType>
47+
LossType readResult(hls4mlEmulator::Model& model) {
48+
std::pair<ResultType, LossType> ADModelResult; //model outputs a pair of the (result vector, loss)
49+
model.read_result(&ADModelResult);
50+
return ADModelResult.second;
51+
}
52+
} // namespace
53+
4454
l1t::AXOL1TLCondition::AXOL1TLCondition()
4555
: ConditionEvaluation(), m_gtAXOL1TLTemplate{nullptr}, m_gtGTB{nullptr}, m_model{nullptr} {}
4656

@@ -130,10 +140,7 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
130140

131141
//types of inputs and outputs
132142
typedef ap_fixed<18, 13> inputtype;
133-
typedef std::array<ap_fixed<10, 7, AP_RND_CONV, AP_SAT>, 8> resulttype; //v3
134143
typedef ap_ufixed<18, 14> losstype;
135-
typedef std::pair<resulttype, losstype> pairtype;
136-
// typedef std::array<ap_fixed<10, 7>, 13> resulttype; //deprecated v1 type:
137144

138145
//define zero
139146
inputtype fillzero = 0.0;
@@ -148,10 +155,10 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
148155
inputtype EtSumInput[EtSumVecSize];
149156

150157
//declare result vectors +score
151-
resulttype result;
158+
// resulttype result;
152159
losstype loss;
153-
pairtype ADModelResult; //model outputs a pair of the (result vector, loss)
154-
float score = -1.0; //not sure what the best default is hm??
160+
// pairtype ADModelResult; //model outputs a pair of the (result vector, loss)
161+
float score = -1.0; //not sure what the best default is hm??
155162

156163
//check number of input objects we actually have (muons, jets etc)
157164
int NCandMu = candMuVec->size(useBx);
@@ -198,8 +205,8 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
198205
if (iMu < NMuons) { //stop if fill the Nobjects we need
199206
MuInput[0 + (3 * iMu)] = ((candMuVec->at(useBx, iMu))->hwPt()) /
200207
2; //index 0,3,6,9 //have to do hwPt/2 in order to match original et inputs
201-
MuInput[1 + (3 * iMu)] = (candMuVec->at(useBx, iMu))->hwEta(); //index 1,4,7,10
202-
MuInput[2 + (3 * iMu)] = (candMuVec->at(useBx, iMu))->hwPhi(); //index 2,5,8,11
208+
MuInput[1 + (3 * iMu)] = (candMuVec->at(useBx, iMu))->hwEtaAtVtx(); //index 1,4,7,10
209+
MuInput[2 + (3 * iMu)] = (candMuVec->at(useBx, iMu))->hwPhiAtVtx(); //index 2,5,8,11
203210
}
204211
}
205212
}
@@ -234,10 +241,18 @@ const bool l1t::AXOL1TLCondition::evaluateCondition(const int bxEval) const {
234241
//now run the inference
235242
m_model->prepare_input(ADModelInput); //scaling internal here
236243
m_model->predict();
237-
m_model->read_result(&ADModelResult); // this should be the square sum model result
244+
// m_model->read_result(&ADModelResult); // this should be the square sum model result
245+
if ((m_model_loader.model_name() == "GTADModel_v3") ||
246+
(m_model_loader.model_name() == "GTADModel_v4")) { //v3/v4 overwrite
247+
using resulttype = std::array<ap_fixed<10, 7, AP_RND_CONV, AP_SAT>, 8>;
248+
loss = readResult<resulttype, losstype>(*m_model);
249+
} else { //v5 default
250+
using resulttype = ap_fixed<18, 14, AP_RND_CONV, AP_SAT>;
251+
loss = readResult<resulttype, losstype>(*m_model);
252+
}
238253

239-
result = ADModelResult.first;
240-
loss = ADModelResult.second;
254+
// result = ADModelResult.first;
255+
// loss = ADModelResult.second;
241256
score = ((loss).to_float()) * 16.0; //scaling to match threshold
242257
//save score to class variable in case score saving needed
243258
setScore(score);

0 commit comments

Comments
 (0)