Skip to content
Open
10 changes: 8 additions & 2 deletions src/fklearn/training/classification.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Any
from typing import List, Any, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -502,6 +502,7 @@ def lgbm_classification_learner(df: pd.DataFrame,
learning_rate: float = 0.1,
num_estimators: int = 100,
extra_params: LogType = None,
categorical_features: Union[List[str], str] = "auto",
prediction_column: str = "prediction",
weight_column: str = None,
encode_extra_cols: bool = True) -> LearnerReturnType:
Expand Down Expand Up @@ -549,6 +550,11 @@ def lgbm_classification_learner(df: pd.DataFrame,
https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst
If not passed, the default will be used.

categorical_features : list of str, or 'auto', optional (default="auto")
A list of column names that should be treated as categorical features.
See the categorical_feature hyper-parameter in:
https://github.com/Microsoft/LightGBM/blob/master/docs/Parameters.rst

prediction_column : str
The name of the column with the predictions from the model.

Expand All @@ -570,7 +576,7 @@ def lgbm_classification_learner(df: pd.DataFrame,
features = features if not encode_extra_cols else expand_features_encoded(df, features)

dtrain = lgbm.Dataset(df[features].values, label=df[target], feature_name=list(map(str, features)), weight=weights,
Comment thread
jmoralez marked this conversation as resolved.
Outdated
silent=True)
silent=True, categorical_feature=categorical_features)

bst = lgbm.train(params, dtrain, num_estimators)

Expand Down
1 change: 1 addition & 0 deletions tests/training/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ def test_lgbm_classification_learner():
learning_rate=0.1,
num_estimators=20,
extra_params={"max_depth": 4, "seed": 42},
categorical_features=["x2"],
Comment thread
jmoralez marked this conversation as resolved.
Outdated
prediction_column="prediction",
weight_column="w")

Expand Down