@@ -29,7 +29,7 @@ namespace xgboost {
2929namespace tree {
3030struct ScalarTreeView ;
3131struct MultiTargetTreeView ;
32- }
32+ } // namespace tree
3333
3434class Json ;
3535
@@ -88,7 +88,7 @@ class RegTree : public Model {
8888 /* * @brief tree node */
8989 class Node {
9090 public:
91- XGBOOST_DEVICE Node () {
91+ XGBOOST_DEVICE Node () {
9292 // assert compact alignment
9393 static_assert (sizeof (Node) == 4 * sizeof (int ) + sizeof (Info), " Node: 64 bit align" );
9494 }
@@ -132,16 +132,12 @@ class RegTree : public Model {
132132 * \brief set the left child
133133 * \param nid node id to right child
134134 */
135- XGBOOST_DEVICE void SetLeftChild (int nid) {
136- this ->cleft_ = nid;
137- }
135+ XGBOOST_DEVICE void SetLeftChild (int nid) { this ->cleft_ = nid; }
138136 /* !
139137 * \brief set the right child
140138 * \param nid node id to right child
141139 */
142- XGBOOST_DEVICE void SetRightChild (int nid) {
143- this ->cright_ = nid;
144- }
140+ XGBOOST_DEVICE void SetRightChild (int nid) { this ->cright_ = nid; }
145141 /* !
146142 * \brief set split condition of current node
147143 * \param split_index feature index to split
@@ -166,30 +162,25 @@ class RegTree : public Model {
166162 this ->cright_ = right;
167163 }
168164 /* ! \brief mark that this node is deleted */
169- XGBOOST_DEVICE void MarkDelete () {
170- this ->sindex_ = kDeletedNodeMarker ;
171- }
165+ XGBOOST_DEVICE void MarkDelete () { this ->sindex_ = kDeletedNodeMarker ; }
172166 /* ! \brief Reuse this deleted node. */
173- XGBOOST_DEVICE void Reuse () {
174- this ->sindex_ = 0 ;
175- }
167+ XGBOOST_DEVICE void Reuse () { this ->sindex_ = 0 ; }
176168 // set parent
177169 XGBOOST_DEVICE void SetParent (int pidx, bool is_left_child = true ) {
178170 if (is_left_child) pidx |= (1U << 31 );
179171 this ->parent_ = pidx;
180172 }
181173 bool operator ==(const Node& b) const {
182- return parent_ == b.parent_ && cleft_ == b.cleft_ &&
183- cright_ == b.cright_ && sindex_ == b.sindex_ &&
184- info_.leaf_value == b.info_ .leaf_value ;
174+ return parent_ == b.parent_ && cleft_ == b.cleft_ && cright_ == b.cright_ &&
175+ sindex_ == b.sindex_ && info_.leaf_value == b.info_ .leaf_value ;
185176 }
186177
187178 private:
188179 /* !
189180 * \brief in leaf node, we have weights, in non-leaf nodes,
190181 * we have split condition
191182 */
192- union Info{
183+ union Info {
193184 bst_float leaf_value;
194185 SplitCondT split_cond;
195186 };
@@ -277,9 +268,7 @@ class RegTree : public Model {
277268 }
278269
279270 /* ! \brief get node statistics given nid */
280- RTreeNodeStat& Stat (int nid) {
281- return stats_.HostVector ()[nid];
282- }
271+ RTreeNodeStat& Stat (int nid) { return stats_.HostVector ()[nid]; }
283272
284273 void LoadModel (Json const & in) override ;
285274 void SaveModel (Json* out) const override ;
@@ -314,11 +303,9 @@ class RegTree : public Model {
314303 * \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
315304 * some updaters use the right child index of leaf as a marker
316305 */
317- void ExpandNode (bst_node_t nid, unsigned split_index, bst_float split_value,
318- bool default_left, bst_float base_weight,
319- bst_float left_leaf_weight, bst_float right_leaf_weight,
320- bst_float loss_change, float sum_hess, float left_sum,
321- float right_sum,
306+ void ExpandNode (bst_node_t nid, unsigned split_index, bst_float split_value, bool default_left,
307+ bst_float base_weight, bst_float left_leaf_weight, bst_float right_leaf_weight,
308+ bst_float loss_change, float sum_hess, float left_sum, float right_sum,
322309 bst_node_t leaf_right_child = kInvalidNodeId );
323310 /* *
324311 * @brief Expands a leaf node into two additional leaf nodes for a multi-target tree.
@@ -365,6 +352,15 @@ class RegTree : public Model {
365352 bst_float base_weight, bst_float left_leaf_weight,
366353 bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
367354 float left_sum, float right_sum);
355+ /* *
356+ * @brief Expands a leaf node with categories for a multi-target tree.
357+ */
358+ void ExpandCategorical (bst_node_t nidx, bst_feature_t split_index,
359+ common::Span<const uint32_t > split_cat, bool default_left,
360+ linalg::VectorView<float const > base_weight,
361+ linalg::VectorView<float const > left_weight,
362+ linalg::VectorView<float const > right_weight, float loss_chg,
363+ float sum_hess, float left_sum, float right_sum);
368364 /* *
369365 * @brief Whether this tree has categorical split.
370366 */
@@ -567,7 +563,7 @@ class RegTree : public Model {
567563 // vector of nodes
568564 HostDeviceVector<Node> nodes_;
569565 // free node space, used during training process
570- std::vector<int > deleted_nodes_;
566+ std::vector<int > deleted_nodes_;
571567 // stats of nodes
572568 HostDeviceVector<RTreeNodeStat> stats_;
573569 HostDeviceVector<FeatureType> split_types_;
@@ -632,13 +628,9 @@ inline void RegTree::FVec::Fill(SparsePage::Inst const& inst) {
632628
633629inline void RegTree::FVec::Drop () { this ->Init (this ->Size ()); }
634630
635- inline size_t RegTree::FVec::Size () const {
636- return data_.size ();
637- }
631+ inline size_t RegTree::FVec::Size () const { return data_.size (); }
638632
639- inline float RegTree::FVec::GetFvalue (size_t i) const {
640- return data_[i];
641- }
633+ inline float RegTree::FVec::GetFvalue (size_t i) const { return data_[i]; }
642634
643635inline bool RegTree::FVec::IsMissing (size_t i) const { return std::isnan (data_[i]); }
644636
0 commit comments