22"""Scikit-learn wrapper interface for LightGBM."""
33import copy
44from inspect import signature
5- from typing import Callable , Dict , List , Optional , Tuple , Union
5+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union
66
77import numpy as np
88
@@ -582,38 +582,49 @@ def set_params(self, **params):
582582 self ._other_params [key ] = value
583583 return self
584584
585- def fit (self , X , y ,
586- sample_weight = None , init_score = None , group = None ,
587- eval_set = None , eval_names = None , eval_sample_weight = None ,
588- eval_class_weight = None , eval_init_score = None , eval_group = None ,
589- eval_metric = None , early_stopping_rounds = None ,
590- feature_name = 'auto' , categorical_feature = 'auto' ,
591- callbacks = None , init_model = None ):
592- """Docstring is set after definition, using a template."""
585+ def _process_params (self , stage : str ) -> Dict [str , Any ]:
586+ """Process the parameters of this estimator based on its type, parameter aliases, etc.
587+
588+ Parameters
589+ ----------
590+ stage : str
591+ Name of the stage (can be ``fit`` or ``predict``) this method is called from.
592+
593+ Returns
594+ -------
595+ processed_params : dict
596+ Processed parameter names mapped to their values.
597+ """
598+ assert stage in {"fit" , "predict" }
593599 params = self .get_params ()
594600
595601 params .pop ('objective' , None )
596602 for alias in _ConfigAliases .get ('objective' ):
597603 if alias in params :
598- self . _objective = params .pop (alias )
604+ obj = params .pop (alias )
599605 _log_warning (f"Found '{ alias } ' in params. Will use it instead of 'objective' argument" )
600- if self ._objective is None :
601- if isinstance (self , LGBMRegressor ):
602- self ._objective = "regression"
603- elif isinstance (self , LGBMClassifier ):
604- if self ._n_classes > 2 :
605- self ._objective = "multiclass"
606+ if stage == "fit" :
607+ self ._objective = obj
608+ if stage == "fit" :
609+ if self ._objective is None :
610+ if isinstance (self , LGBMRegressor ):
611+ self ._objective = "regression"
612+ elif isinstance (self , LGBMClassifier ):
613+ if self ._n_classes > 2 :
614+ self ._objective = "multiclass"
615+ else :
616+ self ._objective = "binary"
617+ elif isinstance (self , LGBMRanker ):
618+ self ._objective = "lambdarank"
606619 else :
607- self ._objective = "binary"
608- elif isinstance (self , LGBMRanker ):
609- self ._objective = "lambdarank"
610- else :
611- raise ValueError ("Unknown LGBMModel type." )
620+ raise ValueError ("Unknown LGBMModel type." )
612621 if callable (self ._objective ):
613- self ._fobj = _ObjectiveFunctionWrapper (self ._objective )
622+ if stage == "fit" :
623+ self ._fobj = _ObjectiveFunctionWrapper (self ._objective )
614624 params ['objective' ] = 'None' # objective = nullptr for unknown objective
615625 else :
616- self ._fobj = None
626+ if stage == "fit" :
627+ self ._fobj = None
617628 params ['objective' ] = self ._objective
618629
619630 params .pop ('importance_type' , None )
@@ -634,16 +645,6 @@ def fit(self, X, y,
634645 eval_at = params .pop (alias )
635646 params ['eval_at' ] = eval_at
636647
637- # Do not modify original args in fit function
638- # Refer to https://github.com/microsoft/LightGBM/pull/2619
639- eval_metric_list = copy .deepcopy (eval_metric )
640- if not isinstance (eval_metric_list , list ):
641- eval_metric_list = [eval_metric_list ]
642-
643- # Separate built-in from callable evaluation metrics
644- eval_metrics_callable = [_EvalFunctionWrapper (f ) for f in eval_metric_list if callable (f )]
645- eval_metrics_builtin = [m for m in eval_metric_list if isinstance (m , str )]
646-
647648 # register default metric for consistency with callable eval_metric case
648649 original_metric = self ._objective if isinstance (self ._objective , str ) else None
649650 if original_metric is None :
@@ -658,6 +659,28 @@ def fit(self, X, y,
658659 # overwrite default metric by explicitly set metric
659660 params = _choose_param_value ("metric" , params , original_metric )
660661
662+ return params
663+
664+ def fit (self , X , y ,
665+ sample_weight = None , init_score = None , group = None ,
666+ eval_set = None , eval_names = None , eval_sample_weight = None ,
667+ eval_class_weight = None , eval_init_score = None , eval_group = None ,
668+ eval_metric = None , early_stopping_rounds = None ,
669+ feature_name = 'auto' , categorical_feature = 'auto' ,
670+ callbacks = None , init_model = None ):
671+ """Docstring is set after definition, using a template."""
672+ params = self ._process_params (stage = "fit" )
673+
674+ # Do not modify original args in fit function
675+ # Refer to https://github.com/microsoft/LightGBM/pull/2619
676+ eval_metric_list = copy .deepcopy (eval_metric )
677+ if not isinstance (eval_metric_list , list ):
678+ eval_metric_list = [eval_metric_list ]
679+
680+ # Separate built-in from callable evaluation metrics
681+ eval_metrics_callable = [_EvalFunctionWrapper (f ) for f in eval_metric_list if callable (f )]
682+ eval_metrics_builtin = [m for m in eval_metric_list if isinstance (m , str )]
683+
661684 # concatenate metric from params (or default if not provided in params) and eval_metric
662685 params ['metric' ] = [params ['metric' ]] if isinstance (params ['metric' ], (str , type (None ))) else params ['metric' ]
663686 params ['metric' ] = [e for e in eval_metrics_builtin if e not in params ['metric' ]] + params ['metric' ]
@@ -799,8 +822,23 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
799822 raise ValueError ("Number of features of the model must "
800823 f"match the input. Model n_features_ is { self ._n_features } and "
801824 f"input n_features is { n_features } " )
825+ # retrieve original params that possibly can be used in both training and prediction
826+ # and then overwrite them (considering aliases) with params that were passed directly in prediction
827+ predict_params = self ._process_params (stage = "predict" )
828+ for alias in _ConfigAliases .get_by_alias (
829+ "data" ,
830+ "X" ,
831+ "raw_score" ,
832+ "start_iteration" ,
833+ "num_iteration" ,
834+ "pred_leaf" ,
835+ "pred_contrib" ,
836+ * kwargs .keys ()
837+ ):
838+ predict_params .pop (alias , None )
839+ predict_params .update (kwargs )
802840 return self ._Booster .predict (X , raw_score = raw_score , start_iteration = start_iteration , num_iteration = num_iteration ,
803- pred_leaf = pred_leaf , pred_contrib = pred_contrib , ** kwargs )
841+ pred_leaf = pred_leaf , pred_contrib = pred_contrib , ** predict_params )
804842
805843 predict .__doc__ = _lgbmmodel_doc_predict .format (
806844 description = "Return the predicted value for each sample." ,
0 commit comments