From 759c430c2b45c4de2eaf117d401b38f039cbeb4f Mon Sep 17 00:00:00 2001 From: Christopher Jones Date: Fri, 14 Mar 2025 10:18:13 -0500 Subject: [PATCH] Improve BDT in L1Trigger/L1TMuonEndCap - fixed memory corruption problems which could lead to double deletes or deletion of random memory locations - memory handling now done via std::unique_ptr - improved const correctness --- .../L1TMuonEndCap/interface/bdt/Forest.h | 20 ++-- L1Trigger/L1TMuonEndCap/interface/bdt/Node.h | 30 ++--- L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h | 18 +-- L1Trigger/L1TMuonEndCap/src/bdt/Forest.cc | 73 +++++------- L1Trigger/L1TMuonEndCap/src/bdt/Node.cc | 57 ++++------ L1Trigger/L1TMuonEndCap/src/bdt/Tree.cc | 105 +++++++----------- 6 files changed, 127 insertions(+), 176 deletions(-) diff --git a/L1Trigger/L1TMuonEndCap/interface/bdt/Forest.h b/L1Trigger/L1TMuonEndCap/interface/bdt/Forest.h index a1b4cb5c80d10..9cb4c10753847 100644 --- a/L1Trigger/L1TMuonEndCap/interface/bdt/Forest.h +++ b/L1Trigger/L1TMuonEndCap/interface/bdt/Forest.h @@ -3,6 +3,7 @@ #ifndef L1Trigger_L1TMuonEndCap_emtf_Forest #define L1Trigger_L1TMuonEndCap_emtf_Forest +#include #include "Tree.h" #include "LossFunctions.h" #include "CondFormats/L1TObjects/interface/L1TMuonEndCapForest.h" @@ -14,28 +15,29 @@ namespace emtf { // Constructor(s)/Destructor Forest(); Forest(std::vector& trainingEvents); - ~Forest(); + ~Forest() = default; Forest(const Forest& forest); Forest& operator=(const Forest& forest); Forest(Forest&& forest) = default; + Forest& operator=(Forest&& forest) = default; // Get/Set void setTrainingEvents(std::vector& trainingEvents); std::vector getTrainingEvents(); // Returns the number of trees in the forest. - unsigned int size(); + unsigned int size() const; // Get info on variable importance. - void rankVariables(std::vector& rank); + void rankVariables(std::vector& rank) const; // Output the list of split values used for each variable. - void saveSplitValues(const char* savefilename); + void saveSplitValues(const char* savefilename) const; // Helpful operations - void listEvents(std::vector >& e); - void sortEventVectors(std::vector >& e); + void listEvents(std::vector>& e) const; + void sortEventVectors(std::vector>& e) const; void generate(int numTrainEvents, int numTestEvents, double sigma); void loadForestFromXML(const char* directory, unsigned int numTrees); void loadFromCondPayload(const L1TMuonEndCapForest::DForest& payload); @@ -63,9 +65,9 @@ namespace emtf { Tree* getTree(unsigned int i); private: - std::vector > events; - std::vector > subSample; - std::vector trees; + std::vector> events; + std::vector> subSample; + std::vector> trees; }; } // namespace emtf diff --git a/L1Trigger/L1TMuonEndCap/interface/bdt/Node.h b/L1Trigger/L1TMuonEndCap/interface/bdt/Node.h index 346a19117f2a4..f5c5d203b9b58 100644 --- a/L1Trigger/L1TMuonEndCap/interface/bdt/Node.h +++ b/L1Trigger/L1TMuonEndCap/interface/bdt/Node.h @@ -5,6 +5,7 @@ #include #include +#include #include "Event.h" namespace emtf { @@ -13,43 +14,46 @@ namespace emtf { public: Node(); Node(std::string cName); - ~Node(); + ~Node() = default; Node(Node &&) = default; Node(const Node &) = delete; Node &operator=(const Node &) = delete; - std::string getName(); + std::string getName() const; void setName(std::string sName); - double getErrorReduction(); + double getErrorReduction() const; void setErrorReduction(double sErrorReduction); Node *getLeftDaughter(); - void setLeftDaughter(Node *sLeftDaughter); + const Node *getLeftDaughter() const; + void setLeftDaughter(std::unique_ptr sLeftDaughter); + const Node *getRightDaughter() const; Node *getRightDaughter(); - void setRightDaughter(Node *sLeftDaughter); + void setRightDaughter(std::unique_ptr sLeftDaughter); Node *getParent(); + const Node *getParent() const; void setParent(Node *sParent); - double getSplitValue(); + double getSplitValue() const; void setSplitValue(double sSplitValue); - int getSplitVariable(); + int getSplitVariable() const; void setSplitVariable(int sSplitVar); - double getFitValue(); + double getFitValue() const; void setFitValue(double sFitValue); - double getTotalError(); + double getTotalError() const; void setTotalError(double sTotalError); - double getAvgError(); + double getAvgError() const; void setAvgError(double sAvgError); - int getNumEvents(); + int getNumEvents() const; void setNumEvents(int sNumEvents); std::vector > &getEvents(); @@ -64,8 +68,8 @@ namespace emtf { private: std::string name; - Node *leftDaughter; - Node *rightDaughter; + std::unique_ptr leftDaughter; + std::unique_ptr rightDaughter; Node *parent; double splitValue; diff --git a/L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h b/L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h index f9c18ebdb6fde..4a5aa2384ba99 100644 --- a/L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h +++ b/L1Trigger/L1TMuonEndCap/interface/bdt/Tree.h @@ -4,6 +4,7 @@ #define L1Trigger_L1TMuonEndCap_emtf_Tree #include +#include #include "Node.h" #include "TXMLEngine.h" #include "CondFormats/L1TObjects/interface/L1TMuonEndCapForest.h" @@ -16,14 +17,15 @@ namespace emtf { public: Tree(); Tree(std::vector>& cEvents); - ~Tree(); + ~Tree() = default; Tree(const Tree& tree); Tree& operator=(const Tree& tree); - Tree(Tree&& tree) noexcept; + Tree(Tree&& tree) noexcept = default; void setRootNode(Node* sRootNode); Node* getRootNode(); + const Node* getRootNode() const; void setTerminalNodes(std::list& sTNodes); std::list& getTerminalNodes(); @@ -48,17 +50,17 @@ namespace emtf { const L1TMuonEndCapForest::DTreeNode& node, Node* tnode); - void rankVariables(std::vector& v); - void rankVariablesRecursive(Node* node, std::vector& v); + void rankVariables(std::vector& v) const; + void rankVariablesRecursive(Node* node, std::vector& v) const; - void getSplitValues(std::vector>& v); - void getSplitValuesRecursive(Node* node, std::vector>& v); + void getSplitValues(std::vector>& v) const; + void getSplitValuesRecursive(Node* node, std::vector>& v) const; double getBoostWeight(void) const { return boostWeight; } void setBoostWeight(double wgt) { boostWeight = wgt; } private: - Node* rootNode; + std::unique_ptr rootNode; std::list terminalNodes; int numTerminalNodes; double rmsError; @@ -66,7 +68,7 @@ namespace emtf { unsigned xmlVersion; // affects only XML loading part, save uses an old format and looses the boostWeight // this is the main recursive workhorse function that compensates for Nodes being non-copyable - Node* copyFrom(const Node* local_root); // no garantees if throws in the process + std::unique_ptr copyFrom(const Node* local_root); // no garantees if throws in the process // a dumb DFS tree traversal void findLeafs(Node* local_root, std::list& tn); }; diff --git a/L1Trigger/L1TMuonEndCap/src/bdt/Forest.cc b/L1Trigger/L1TMuonEndCap/src/bdt/Forest.cc index 62603324b6e69..717a9097b0f2f 100644 --- a/L1Trigger/L1TMuonEndCap/src/bdt/Forest.cc +++ b/L1Trigger/L1TMuonEndCap/src/bdt/Forest.cc @@ -45,37 +45,17 @@ Forest::Forest() { events = std::vector>(1); } Forest::Forest(std::vector& trainingEvents) { setTrainingEvents(trainingEvents); } -///////////////////////////////////////////////////////////////////////// -// _______________________Destructor____________________________________// -////////////////////////////////////////////////////////////////////////// - -Forest::~Forest() { - // When the forest is destroyed it will delete the trees as well as the - // events from the training and testing sets. - // The user may want the events to remain after they destroy the forest - // this should be changed in future upgrades. - - for (unsigned int i = 0; i < trees.size(); i++) { - if (trees[i]) - delete trees[i]; - } -} - Forest::Forest(const Forest& forest) { - transform(forest.trees.cbegin(), forest.trees.cend(), back_inserter(trees), [](const Tree* tree) { - return new Tree(*tree); + transform(forest.trees.cbegin(), forest.trees.cend(), back_inserter(trees), [](const std::unique_ptr& tree) { + return std::make_unique(*tree); }); } Forest& Forest::operator=(const Forest& forest) { - for (unsigned int i = 0; i < trees.size(); i++) { - if (trees[i]) - delete trees[i]; - } trees.resize(0); - - transform(forest.trees.cbegin(), forest.trees.cend(), back_inserter(trees), [](const Tree* tree) { - return new Tree(*tree); + trees.reserve(forest.trees.size()); + transform(forest.trees.cbegin(), forest.trees.cend(), back_inserter(trees), [](const std::unique_ptr& tree) { + return std::make_unique(*tree); }); return *this; } @@ -113,7 +93,7 @@ std::vector Forest::getTrainingEvents() { return events[0]; } // return the ith tree Tree* Forest::getTree(unsigned int i) { if (/*i>=0 && */ i < trees.size()) - return trees[i]; + return trees[i].get(); else { //std::cout << i << "is an invalid input for getTree. Out of range." << std::endl; return nullptr; @@ -124,7 +104,7 @@ Tree* Forest::getTree(unsigned int i) { // ______________________Various_Helpful_Functions______________________// ////////////////////////////////////////////////////////////////////////// -unsigned int Forest::size() { +unsigned int Forest::size() const { // Return the number of trees in the forest. return trees.size(); } @@ -139,7 +119,7 @@ unsigned int Forest::size() { // ---------------------------------------------------------------------- ////////////////////////////////////////////////////////////////////////// -void Forest::listEvents(std::vector>& e) { +void Forest::listEvents(std::vector>& e) const { // Simply list the events in each event vector. We have multiple copies // of the events vector. Each copy is sorted according to a different // determining variable. @@ -178,7 +158,7 @@ bool compareEventsById(Event* e1, Event* e2) { // ---------------------------------------------------------------------- ////////////////////////////////////////////////////////////////////////// -void Forest::sortEventVectors(std::vector>& e) { +void Forest::sortEventVectors(std::vector>& e) const { // When a node chooses the optimum split point and split variable it needs // the events to be sorted according to the variable it is considering. @@ -192,7 +172,7 @@ void Forest::sortEventVectors(std::vector>& e) { // ---------------------------------------------------------------------- ////////////////////////////////////////////////////////////////////////// -void Forest::rankVariables(std::vector& rank) { +void Forest::rankVariables(std::vector& rank) const { // This function ranks the determining variables according to their importance // in determining the fit. Use a low learning rate for better results. // Separates completely useless variables from useful ones well, @@ -242,7 +222,7 @@ void Forest::rankVariables(std::vector& rank) { // ---------------------------------------------------------------------- ////////////////////////////////////////////////////////////////////////// -void Forest::saveSplitValues(const char* savefilename) { +void Forest::saveSplitValues(const char* savefilename) const { // This function gathers all of the split values from the forest and puts them into lists. std::ofstream splitvaluefile; @@ -377,8 +357,8 @@ void Forest::doRegression(int nodeLimit, for (unsigned int i = 0; i < (unsigned)treeLimit; i++) { // std::cout << "++Building Tree " << i << "... " << std::endl; - Tree* tree = new Tree(events); - trees.push_back(tree); + trees.emplace_back(std::make_unique(events)); + Tree* tree = trees.back().get(); tree->buildTree(nodeLimit); // Update the targets for the next tree to fit. @@ -426,7 +406,7 @@ void Forest::predictEvents(std::vector& eventsp, unsigned int numtrees) void Forest::appendCorrection(std::vector& eventsp, int treenum) { // Update the prediction by appending the next correction. - Tree* tree = trees[treenum]; + Tree* tree = trees[treenum].get(); tree->filterEvents(eventsp); // Update the events with their new prediction. @@ -463,7 +443,7 @@ void Forest::predictEvent(Event* e, unsigned int numtrees) { void Forest::appendCorrection(Event* e, int treenum) { // Update the prediction by appending the next correction. - Tree* tree = trees[treenum]; + Tree* tree = trees[treenum].get(); Node* terminalNode = tree->filterEvent(e); // Update the event with its new prediction. @@ -478,12 +458,13 @@ void Forest::loadForestFromXML(const char* directory, unsigned int numTrees) { // Load a forest that has already been created and stored into XML somewhere. // Initialize the vector of trees. - trees = std::vector(numTrees); + trees.resize(numTrees); + trees.shrink_to_fit(); // Load the Forest. // std::cout << std::endl << "Loading Forest from XML ... " << std::endl; for (unsigned int i = 0; i < numTrees; i++) { - trees[i] = new Tree(); + trees[i] = std::make_unique(); std::stringstream ss; ss << directory << "/" << i << ".xml"; @@ -499,17 +480,12 @@ void Forest::loadFromCondPayload(const L1TMuonEndCapForest::DForest& forest) { // Initialize the vector of trees. unsigned int numTrees = forest.size(); - // clean-up leftovers from previous initialization (if any) - for (unsigned int i = 0; i < trees.size(); i++) { - if (trees[i]) - delete trees[i]; - } - - trees = std::vector(numTrees); + trees.resize(numTrees); + trees.shrink_to_fit(); // Load the Forest. for (unsigned int i = 0; i < numTrees; i++) { - trees[i] = new Tree(); + trees[i] = std::make_unique(); trees[i]->loadFromCondPayload(forest[i]); } } @@ -556,7 +532,8 @@ void Forest::doStochasticRegression( // Prepare some things. sortEventVectors(events); - trees = std::vector(treeLimit); + trees.resize(treeLimit); + trees.shrink_to_fit(); // See how long the regression takes. TStopwatch timer; @@ -572,7 +549,7 @@ void Forest::doStochasticRegression( for (unsigned int i = 0; i < (unsigned)treeLimit; i++) { // Build the tree using a random subsample. prepareRandomSubsample(fraction); - trees[i] = new Tree(subSample); + trees[i] = std::make_unique(subSample); trees[i]->buildTree(nodeLimit); // Fit all of the events based upon the tree we built using @@ -580,7 +557,7 @@ void Forest::doStochasticRegression( trees[i]->filterEvents(events[0]); // Update the targets for the next tree to fit. - updateRegTargets(trees[i], learningRate, l); + updateRegTargets(trees[i].get(), learningRate, l); // Save trees to xml in some directory. std::ostringstream ss; diff --git a/L1Trigger/L1TMuonEndCap/src/bdt/Node.cc b/L1Trigger/L1TMuonEndCap/src/bdt/Node.cc index e4090d308b903..7a232bd0e305c 100644 --- a/L1Trigger/L1TMuonEndCap/src/bdt/Node.cc +++ b/L1Trigger/L1TMuonEndCap/src/bdt/Node.cc @@ -53,79 +53,70 @@ Node::Node(std::string cName) { errorReduction = -1; } -////////////////////////////////////////////////////////////////////////// -// _______________________Destructor____________________________________// -////////////////////////////////////////////////////////////////////////// - -Node::~Node() { - // Recursively delete all nodes in the tree. - if (leftDaughter) - delete leftDaughter; - if (rightDaughter) - delete rightDaughter; -} - ////////////////////////////////////////////////////////////////////////// // ______________________Get/Set________________________________________// ////////////////////////////////////////////////////////////////////////// void Node::setName(std::string sName) { name = sName; } -std::string Node::getName() { return name; } +std::string Node::getName() const { return name; } // ---------------------------------------------------------------------- void Node::setErrorReduction(double sErrorReduction) { errorReduction = sErrorReduction; } -double Node::getErrorReduction() { return errorReduction; } +double Node::getErrorReduction() const { return errorReduction; } // ---------------------------------------------------------------------- -void Node::setLeftDaughter(Node* sLeftDaughter) { leftDaughter = sLeftDaughter; } +void Node::setLeftDaughter(std::unique_ptr sLeftDaughter) { leftDaughter = std::move(sLeftDaughter); } -Node* Node::getLeftDaughter() { return leftDaughter; } +Node* Node::getLeftDaughter() { return leftDaughter.get(); } +const Node* Node::getLeftDaughter() const { return leftDaughter.get(); } -void Node::setRightDaughter(Node* sRightDaughter) { rightDaughter = sRightDaughter; } +void Node::setRightDaughter(std::unique_ptr sRightDaughter) { rightDaughter = std::move(sRightDaughter); } -Node* Node::getRightDaughter() { return rightDaughter; } +Node* Node::getRightDaughter() { return rightDaughter.get(); } +const Node* Node::getRightDaughter() const { return rightDaughter.get(); } // ---------------------------------------------------------------------- void Node::setParent(Node* sParent) { parent = sParent; } Node* Node::getParent() { return parent; } +const Node* Node::getParent() const { return parent; } // ---------------------------------------------------------------------- void Node::setSplitValue(double sSplitValue) { splitValue = sSplitValue; } -double Node::getSplitValue() { return splitValue; } +double Node::getSplitValue() const { return splitValue; } void Node::setSplitVariable(int sSplitVar) { splitVariable = sSplitVar; } -int Node::getSplitVariable() { return splitVariable; } +int Node::getSplitVariable() const { return splitVariable; } // ---------------------------------------------------------------------- void Node::setFitValue(double sFitValue) { fitValue = sFitValue; } -double Node::getFitValue() { return fitValue; } +double Node::getFitValue() const { return fitValue; } // ---------------------------------------------------------------------- void Node::setTotalError(double sTotalError) { totalError = sTotalError; } -double Node::getTotalError() { return totalError; } +double Node::getTotalError() const { return totalError; } void Node::setAvgError(double sAvgError) { avgError = sAvgError; } -double Node::getAvgError() { return avgError; } +double Node::getAvgError() const { return avgError; } // ---------------------------------------------------------------------- void Node::setNumEvents(int sNumEvents) { numEvents = sNumEvents; } -int Node::getNumEvents() { return numEvents; } +int Node::getNumEvents() const { return numEvents; } // ---------------------------------------------------------------------- @@ -252,14 +243,12 @@ void Node::listEvents() { void Node::theMiracleOfChildBirth() { // Create Daughter Nodes - Node* left = new Node(name + " left"); - Node* right = new Node(name + " right"); + leftDaughter = std::make_unique(name + " left"); + rightDaughter = std::make_unique(name + " right"); // Link the Nodes Appropriately - leftDaughter = left; - rightDaughter = right; - left->setParent(this); - right->setParent(this); + leftDaughter->setParent(this); + rightDaughter->setParent(this); } // ---------------------------------------------------------------------- @@ -280,8 +269,8 @@ void Node::filterEventsToDaughters() { unsigned int sv = splitVariable; double sp = splitValue; - Node* left = leftDaughter; - Node* right = rightDaughter; + Node* left = leftDaughter.get(); + Node* right = rightDaughter.get(); std::vector > l(events.size()); std::vector > r(events.size()); @@ -320,8 +309,8 @@ Node* Node::filterEventToDaughter(Event* e) { unsigned int sv = splitVariable; double sp = splitValue; - Node* left = leftDaughter; - Node* right = rightDaughter; + Node* left = leftDaughter.get(); + Node* right = rightDaughter.get(); Node* nextNode = nullptr; // Prevent out-of-bounds access diff --git a/L1Trigger/L1TMuonEndCap/src/bdt/Tree.cc b/L1Trigger/L1TMuonEndCap/src/bdt/Tree.cc index 3eaac170bf48a..dadf3818164d8 100644 --- a/L1Trigger/L1TMuonEndCap/src/bdt/Tree.cc +++ b/L1Trigger/L1TMuonEndCap/src/bdt/Tree.cc @@ -28,37 +28,26 @@ using namespace emtf; Tree::Tree() { - rootNode = new Node("root"); + rootNode = std::make_unique("root"); - terminalNodes.push_back(rootNode); + terminalNodes.push_back(rootNode.get()); numTerminalNodes = 1; boostWeight = 0; xmlVersion = 2017; } Tree::Tree(std::vector>& cEvents) { - rootNode = new Node("root"); + rootNode = std::make_unique("root"); rootNode->setEvents(cEvents); - terminalNodes.push_back(rootNode); + terminalNodes.push_back(rootNode.get()); numTerminalNodes = 1; boostWeight = 0; xmlVersion = 2017; } -////////////////////////////////////////////////////////////////////////// -// _______________________Destructor____________________________________// -////////////////////////////////////////////////////////////////////////// - -Tree::~Tree() { - // When the tree is destroyed it will delete all of the nodes in the tree. - // The deletion begins with the rootnode and continues recursively. - if (rootNode) - delete rootNode; -} Tree::Tree(const Tree& tree) { - // unfortunately, authors of these classes didn't use const qualifiers - rootNode = copyFrom(const_cast(tree).getRootNode()); + rootNode = copyFrom(tree.getRootNode()); numTerminalNodes = tree.numTerminalNodes; rmsError = tree.rmsError; boostWeight = tree.boostWeight; @@ -66,16 +55,13 @@ Tree::Tree(const Tree& tree) { terminalNodes.resize(0); // find new leafs - findLeafs(rootNode, terminalNodes); + findLeafs(rootNode.get(), terminalNodes); /// if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error(); } Tree& Tree::operator=(const Tree& tree) { - if (rootNode) - delete rootNode; - // unfortunately, authors of these classes didn't use const qualifiers - rootNode = copyFrom(const_cast(tree).getRootNode()); + rootNode = copyFrom(tree.getRootNode()); numTerminalNodes = tree.numTerminalNodes; rmsError = tree.rmsError; boostWeight = tree.boostWeight; @@ -83,32 +69,32 @@ Tree& Tree::operator=(const Tree& tree) { terminalNodes.resize(0); // find new leafs - findLeafs(rootNode, terminalNodes); + findLeafs(rootNode.get(), terminalNodes); /// if( numTerminalNodes != terminalNodes.size() ) throw std::runtime_error(); return *this; } -Node* Tree::copyFrom(const Node* local_root) { +std::unique_ptr Tree::copyFrom(const Node* local_root) { // end-case if (!local_root) return nullptr; - Node* lr = const_cast(local_root); + const Node* lr = local_root; // recursion - Node* left_new_child = copyFrom(lr->getLeftDaughter()); - Node* right_new_child = copyFrom(lr->getRightDaughter()); + auto left_new_child = copyFrom(lr->getLeftDaughter()); + auto right_new_child = copyFrom(lr->getRightDaughter()); // performing main work at this level - Node* new_local_root = new Node(lr->getName()); + auto new_local_root = std::make_unique(lr->getName()); if (left_new_child) - left_new_child->setParent(new_local_root); + left_new_child->setParent(new_local_root.get()); if (right_new_child) - right_new_child->setParent(new_local_root); - new_local_root->setLeftDaughter(left_new_child); - new_local_root->setRightDaughter(right_new_child); + right_new_child->setParent(new_local_root.get()); + new_local_root->setLeftDaughter(std::move(left_new_child)); + new_local_root->setRightDaughter(std::move(right_new_child)); new_local_root->setErrorReduction(lr->getErrorReduction()); new_local_root->setSplitValue(lr->getSplitValue()); new_local_root->setSplitVariable(lr->getSplitVariable()); @@ -135,23 +121,14 @@ void Tree::findLeafs(Node* local_root, std::list& tn) { findLeafs(local_root->getRightDaughter(), tn); } -Tree::Tree(Tree&& tree) noexcept - : rootNode(tree.rootNode), - terminalNodes(std::move(tree.terminalNodes)), - numTerminalNodes(tree.numTerminalNodes), - rmsError(tree.rmsError), - boostWeight(tree.boostWeight), - xmlVersion(tree.xmlVersion) { - tree.rootNode = nullptr; // this line is the only reason not to use default move constructor -} - ////////////////////////////////////////////////////////////////////////// // ______________________Get/Set________________________________________// ////////////////////////////////////////////////////////////////////////// -void Tree::setRootNode(Node* sRootNode) { rootNode = sRootNode; } +void Tree::setRootNode(Node* sRootNode) { rootNode.reset(sRootNode); } -Node* Tree::getRootNode() { return rootNode; } +Node* Tree::getRootNode() { return rootNode.get(); } +const Node* Tree::getRootNode() const { return rootNode.get(); } // ---------------------------------------------------------------------- @@ -249,7 +226,7 @@ void Tree::filterEvents(std::vector& tEvents) { // The tree now knows about the events it needs to fit. // Filter them into a predictive region (terminal node). - filterEventsRecursive(rootNode); + filterEventsRecursive(rootNode.get()); } // ---------------------------------------------------------------------- @@ -277,7 +254,7 @@ Node* Tree::filterEvent(Event* e) { // given by the tEvents vector. // Filter the event into a predictive region (terminal node). - Node* node = filterEventRecursive(rootNode, e); + Node* node = filterEventRecursive(rootNode.get(), e); return node; } @@ -296,7 +273,7 @@ Node* Tree::filterEventRecursive(Node* node, Event* e) { // ---------------------------------------------------------------------- -void Tree::rankVariablesRecursive(Node* node, std::vector& v) { +void Tree::rankVariablesRecursive(Node* node, std::vector& v) const { // We recursively go through all of the nodes in the tree and find the // total error reduction for each variable. The one with the most // error reduction should be the most important. @@ -328,11 +305,11 @@ void Tree::rankVariablesRecursive(Node* node, std::vector& v) { // ---------------------------------------------------------------------- -void Tree::rankVariables(std::vector& v) { rankVariablesRecursive(rootNode, v); } +void Tree::rankVariables(std::vector& v) const { rankVariablesRecursive(rootNode.get(), v); } // ---------------------------------------------------------------------- -void Tree::getSplitValuesRecursive(Node* node, std::vector>& v) { +void Tree::getSplitValuesRecursive(Node* node, std::vector>& v) const { // We recursively go through all of the nodes in the tree and find the // split points used for each split variable. @@ -360,20 +337,22 @@ void Tree::getSplitValuesRecursive(Node* node, std::vector>& // ---------------------------------------------------------------------- -void Tree::getSplitValues(std::vector>& v) { getSplitValuesRecursive(rootNode, v); } +void Tree::getSplitValues(std::vector>& v) const { getSplitValuesRecursive(rootNode.get(), v); } ////////////////////////////////////////////////////////////////////////// // ______________________Storage/Retrieval______________________________// ////////////////////////////////////////////////////////////////////////// -template -std::string numToStr(T num) { - // Convert a number to a string. - std::stringstream ss; - ss << num; - std::string s = ss.str(); - return s; -} +namespace { + template + std::string numToStr(T num) { + // Convert a number to a string. + std::stringstream ss; + ss << num; + std::string s = ss.str(); + return s; + } +} // namespace // ---------------------------------------------------------------------- @@ -392,10 +371,10 @@ void Tree::saveToXML(const char* c) { // Add the root node. XMLNodePointer_t root = xml->NewChild(nullptr, nullptr, rootNode->getName().c_str()); - addXMLAttributes(xml, rootNode, root); + addXMLAttributes(xml, rootNode.get(), root); // Recursively write the tree to XML. - saveToXMLRecursive(xml, rootNode, root); + saveToXMLRecursive(xml, rootNode.get(), root); // Make the XML Document. XMLDocPointer_t xmldoc = xml->NewDoc(); @@ -464,7 +443,7 @@ void Tree::loadFromXML(const char* filename) { xmlVersion = 2016; } // Recursively connect nodes together. - loadFromXMLRecursive(xml, mainnode, rootNode); + loadFromXMLRecursive(xml, mainnode, rootNode.get()); // Release memory before exit xml->FreeDoc(xmldoc); @@ -548,12 +527,10 @@ void Tree::loadFromXMLRecursive(TXMLEngine* xml, XMLNodePointer_t xnode, Node* t void Tree::loadFromCondPayload(const L1TMuonEndCapForest::DTree& tree) { // start fresh in case this is not the only call to construct a tree - if (rootNode) - delete rootNode; - rootNode = new Node("root"); + rootNode = std::make_unique("root"); const L1TMuonEndCapForest::DTreeNode& mainnode = tree[0]; - loadFromCondPayloadRecursive(tree, mainnode, rootNode); + loadFromCondPayloadRecursive(tree, mainnode, rootNode.get()); } void Tree::loadFromCondPayloadRecursive(const L1TMuonEndCapForest::DTree& tree,