Skip to content

Commit 4af19f2

Browse files
authored
update (#251)
1 parent dddaeab commit 4af19f2

3 files changed

Lines changed: 32 additions & 4 deletions

File tree

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
## [4.2.1] - 2026-04-22
4+
- **Enhancement**
5+
- Add `valid_dfs` and `callbacks` parameters to `lgbm_regression_learner` for early-stopping and custom LightGBM callback support.
6+
37
## [4.2.0] - 2026-04-22
48
- **Enhancement**
59
- Support `numpy>=1.26,<3` (adds numpy 2.x support).

src/fklearn/training/regression.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any, Dict, List, Union
1+
from typing import Any, Callable, Dict, Optional, List, Union
22

33
import numpy as np
44
import pandas as pd
@@ -443,6 +443,8 @@ def lgbm_regression_learner(
443443
prediction_column: str = "prediction",
444444
weight_column: str = None,
445445
encode_extra_cols: bool = True,
446+
valid_dfs: Optional[List[pd.DataFrame]] = None,
447+
callbacks: Optional[List[Callable]] = None,
446448
) -> LearnerReturnType:
447449
"""
448450
Fits an LGBM regressor to the dataset.
@@ -465,7 +467,7 @@ def lgbm_regression_learner(
465467
466468
target : str
467469
The name of the column in `df` that should be used as target for the model.
468-
This column should be binary, since this is a classification model.
470+
This column should be binary if the objective is classification.
469471
470472
learning_rate : float
471473
Float in the range (0, 1]
@@ -495,6 +497,13 @@ def lgbm_regression_learner(
495497
496498
encode_extra_cols : bool (default: True)
497499
If True, treats all columns in `df` with name pattern `fklearn_feat__col==val` as feature columns.
500+
501+
valid_dfs : list of pandas.DataFrame, optional (default=None)
502+
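A list of datasets to be used for early-stopping during training.
Each dataset must contain the feature columns and the `target` column (and `weight_column`, when one is set); they are passed to `lgbm.train` as `valid_sets`.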
503+
504+
callbacks : list of callable, or None, optional (default=None)
505+
List of callback functions that are applied at each iteration.
506+
See Callbacks in LightGBM Python API for more information.
498507
"""
499508

500509
import lightgbm as lgbm
@@ -508,8 +517,21 @@ def lgbm_regression_learner(
508517
features = features if not encode_extra_cols else expand_features_encoded(df, features)
509518

510519
dtrain = lgbm.Dataset(df[features].values, label=df[target], feature_name=list(map(str, features)), weight=weights)
520+
valid_sets = (
521+
[
522+
lgbm.Dataset(
523+
valid_df[features].values,
524+
label=valid_df[target],
525+
feature_name=list(map(str, features)),
526+
weight=valid_df[weight_column].values if weight_column else None,
527+
)
528+
for valid_df in valid_dfs
529+
]
530+
if valid_dfs is not None
531+
else None
532+
)
511533

512-
bst = lgbm.train(params, dtrain, num_estimators)
534+
bst = lgbm.train(params, dtrain, num_estimators, valid_sets=valid_sets, callbacks=callbacks)
513535

514536
def p(new_df: pd.DataFrame, apply_shap: bool = False) -> pd.DataFrame:
515537
col_dict = {prediction_column: bst.predict(new_df[features].values)}

tests/training/test_regression.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections import Counter
22

3+
import lightgbm
34
import numpy as np
45
import pandas as pd
56

@@ -220,9 +221,10 @@ def test_lgbm_regression_learner():
220221
extra_params={"max_depth": 2, "seed": 42},
221222
prediction_column="prediction",
222223
weight_column="w",
224+
callbacks=[lightgbm.log_evaluation()],
223225
)
224226

225-
predict_fn, pred_train, log = learner(df_train)
227+
predict_fn, pred_train, log = learner(df_train, valid_dfs=[df_train])
226228

227229
pred_test = predict_fn(df_test)
228230

0 commit comments

Comments
 (0)