1- from typing import Any , Dict , List , Union
1+ from typing import Any , Callable , Dict , Optional , List , Union
22
33import numpy as np
44import pandas as pd
@@ -443,6 +443,8 @@ def lgbm_regression_learner(
443443 prediction_column : str = "prediction" ,
444444 weight_column : str = None ,
445445 encode_extra_cols : bool = True ,
446+ valid_dfs : Optional [List [pd .DataFrame ]] = None ,
447+ callbacks : Optional [List [Callable ]] = None ,
446448) -> LearnerReturnType :
447449 """
448450 Fits an LGBM regressor to the dataset.
@@ -465,7 +467,7 @@ def lgbm_regression_learner(
465467
466468 target : str
467469 The name of the column in `df` that should be used as target for the model.
468- This column should be binary, since this is a classification model .
470+ This column should be binary, if the objective is classification.
469471
470472 learning_rate : float
471473 Float in the range (0, 1]
@@ -495,6 +497,13 @@ def lgbm_regression_learner(
495497
496498 encode_extra_cols : bool (default: True)
497499 If True, treats all columns in `df` with name pattern `fklearn_feat__col==val` as feature columns.
500+
501+ valid_dfs : list of pandas.DataFrame, optional (default=None)
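502+ A list of validation datasets passed to LightGBM as `valid_sets` and evaluated during training, e.g. to drive early stopping when an early-stopping callback is supplied via `callbacks`.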
503+
504+ callbacks : list of callable, or None, optional (default=None)
505+ List of callback functions that are applied at each iteration.
506+ See Callbacks in LightGBM Python API for more information.
498507 """
499508
500509 import lightgbm as lgbm
@@ -508,8 +517,21 @@ def lgbm_regression_learner(
508517 features = features if not encode_extra_cols else expand_features_encoded (df , features )
509518
510519 dtrain = lgbm .Dataset (df [features ].values , label = df [target ], feature_name = list (map (str , features )), weight = weights )
520+ valid_sets = (
521+ [
522+ lgbm .Dataset (
523+ valid_df [features ].values ,
524+ label = valid_df [target ],
525+ feature_name = list (map (str , features )),
526+ weight = valid_df [weight_column ].values if weight_column else None ,
527+ )
528+ for valid_df in valid_dfs
529+ ]
530+ if valid_dfs is not None
531+ else None
532+ )
511533
512- bst = lgbm .train (params , dtrain , num_estimators )
534+ bst = lgbm .train (params , dtrain , num_estimators , valid_sets = valid_sets , callbacks = callbacks )
513535
514536 def p (new_df : pd .DataFrame , apply_shap : bool = False ) -> pd .DataFrame :
515537 col_dict = {prediction_column : bst .predict (new_df [features ].values )}
0 commit comments