@@ -347,10 +347,17 @@ def _ray_predict(
             **compat_predict_kwargs,
         )

-    def _ray_get_wrap_evaluation_matrices_compat_kwargs(self) -> dict:
+    def _ray_get_wrap_evaluation_matrices_compat_kwargs(
+            self, label_transform=None) -> dict:
+        ret = {}
+        if "label_transform" in inspect.signature(
+                _wrap_evaluation_matrices).parameters:
+            # XGBoost < 1.6.0
+            identity_func = lambda x: x  # noqa
+            ret["label_transform"] = label_transform or identity_func
         if hasattr(self, "enable_categorical"):
-            return {"enable_categorical": self.enable_categorical}
-        return {}
+            ret["enable_categorical"] = self.enable_categorical
+        return ret

     # copied from the file in the top comment
     # provided here for compatibility with legacy xgboost versions
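
The compat helper above uses inspect.signature to check whether the installed
_wrap_evaluation_matrices still accepts label_transform (XGBoost < 1.6.0) and only
forwards the keyword in that case. Below is a minimal, self-contained sketch of that
introspection pattern; maybe_call and _target are hypothetical names invented for the
example, not part of xgboost-ray or XGBoost.

import inspect

def _target(data, label_transform=None):
    # Hypothetical stand-in for an API whose keyword arguments vary by version.
    transform = label_transform or (lambda x: x)
    return [transform(x) for x in data]

def maybe_call(func, data, **optional_kwargs):
    # Forward only the keywords the installed version actually accepts.
    accepted = inspect.signature(func).parameters
    kwargs = {k: v for k, v in optional_kwargs.items() if k in accepted}
    return func(data, **kwargs)

print(maybe_call(_target, [1, 2, 3], label_transform=str, legacy_flag=True))
# -> ['1', '2', '3']; legacy_flag is dropped because _target does not accept it

The helper in the diff builds its return dict the same way: a key is only present when
the installed version understands it, so callers can splat the result unconditionally.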
@@ -450,8 +457,13 @@ def fit(
         else:
             obj = None

-        model, feval, params = self._configure_fit(xgb_model, eval_metric,
-                                                   params)
+        try:
+            model, feval, params = self._configure_fit(xgb_model, eval_metric,
+                                                       params)
+        except TypeError:
+            # XGBoost >= 1.6.0
+            model, feval, params, early_stopping_rounds = self._configure_fit(
+                xgb_model, eval_metric, params, early_stopping_rounds)

         # remove those as they will be set in RayXGBoostActor
         params.pop("n_jobs", None)
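
The except TypeError branch above is the shim for newer XGBoost (labeled >= 1.6.0 in the
diff): when the three-argument call no longer matches the installed _configure_fit
signature, the resulting TypeError triggers the four-argument form, which also hands
early_stopping_rounds back. A small sketch of that pattern, assuming a made-up
configure_v2 function that stands in for the newer signature (not a real XGBoost API):

def configure_v2(model, metric, params, early_stopping_rounds):
    # Hypothetical stand-in for the newer signature: extra argument, four return values.
    return model, metric, params, early_stopping_rounds

def call_compat(configure, model, metric, params, early_stopping_rounds):
    try:
        # Older signature: three arguments, three return values.
        model, metric, params = configure(model, metric, params)
    except TypeError:
        # Newer signature: retry with the extra argument.
        model, metric, params, early_stopping_rounds = configure(
            model, metric, params, early_stopping_rounds)
    return model, metric, params, early_stopping_rounds

print(call_compat(configure_v2, None, "rmse", {"eta": 0.1}, 10))
# -> (None, 'rmse', {'eta': 0.1}, 10)

One trade-off of this approach is that any TypeError raised inside the old-style call
also triggers the fallback; that is acceptable for a version shim but worth keeping in mind.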
@@ -638,8 +650,13 @@ def fit(
             params["objective"] = "multi:softprob"
             params["num_class"] = self.n_classes_

-        model, feval, params = self._configure_fit(xgb_model, eval_metric,
-                                                   params)
+        try:
+            model, feval, params = self._configure_fit(xgb_model, eval_metric,
+                                                       params)
+        except TypeError:
+            # XGBoost >= 1.6.0
+            model, feval, params, early_stopping_rounds = self._configure_fit(
+                xgb_model, eval_metric, params, early_stopping_rounds)

         if train_dmatrix is None:
             train_dmatrix, evals = _wrap_evaluation_matrices(
@@ -656,13 +673,13 @@ def fit(
                 base_margin_eval_set=base_margin_eval_set,
                 eval_group=None,
                 eval_qid=None,
-                label_transform=label_transform,
                 # changed in xgboost-ray:
                 create_dmatrix=lambda **kwargs: RayDMatrix(**{
                     **kwargs,
                     **ray_dmatrix_params
                 }),
-                **self._ray_get_wrap_evaluation_matrices_compat_kwargs())
+                **self._ray_get_wrap_evaluation_matrices_compat_kwargs(
+                    label_transform=label_transform))

         # remove those as they will be set in RayXGBoostActor
         params.pop("n_jobs", None)
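
Routing label_transform through _ray_get_wrap_evaluation_matrices_compat_kwargs works
because splatting a dict that simply omits a key is a no-op, so this call site no longer
needs a version check of its own. A tiny illustrative sketch; render and compat_kwargs
are invented names, not related to xgboost-ray:

def render(text, prefix="", suffix=""):
    return prefix + text + suffix

def compat_kwargs(supports_suffix):
    # Hypothetical helper: only include keywords the installed version understands.
    return {"suffix": "!"} if supports_suffix else {}

print(render("hello", **compat_kwargs(supports_suffix=False)))  # -> hello
print(render("hello", **compat_kwargs(supports_suffix=True)))   # -> hello!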
@@ -970,8 +987,13 @@ def fit(
         evals_result = {}
         params = self.get_xgb_params()

-        model, feval, params = self._configure_fit(xgb_model, eval_metric,
-                                                   params)
+        try:
+            model, feval, params = self._configure_fit(xgb_model, eval_metric,
+                                                       params)
+        except TypeError:
+            # XGBoost >= 1.6.0
+            model, feval, params, early_stopping_rounds = self._configure_fit(
+                xgb_model, eval_metric, params, early_stopping_rounds)
         if callable(feval):
             raise ValueError(
                 "Custom evaluation metric is not yet supported for XGBRanker.")