Add regr.fuser learner (fused Lasso regression)#555

Open
EngineerDanny wants to merge 3 commits into mlr-org:main from EngineerDanny:main
@EngineerDanny

Summary

Adds a new regression learner regr.fuser that wraps the fused Lasso model from fuser and requires a single grouping column via the group role.

Implementation Details

  • New learner: LearnerRegrFuser
  • Backend: fuser::fusedL2DescentGLMNet(X, y, groups, lambda, G, gamma, scaling)
  • Requires exactly one grouping column via the group role
  • Defaults: lambda = NULL, gamma = 1, scaling = FALSE, G = NULL (falls back to an all-ones K×K matrix)
  • Errors on unseen groups at prediction time
  • Validates scalar lambda, gamma, and G dimensions
  • Tests for happy path + validation cases
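
Two of the rules listed above, the default G and the unseen-group error, can be sketched in base R. This is only an illustration and not the code from this PR: the helpers `default_G` and `check_groups` are hypothetical names, and the sketch assumes that K is the number of distinct groups seen during training.

```r
# Hypothetical helpers for illustration; not part of fuser or mlr3extralearners.

# Build a default fusion-weight matrix: all ones, K x K,
# assuming K is the number of distinct training groups.
default_G = function(groups) {
  k = length(unique(groups))
  matrix(1, nrow = k, ncol = k)
}

# Stop if the new data contains a group that was not seen in training.
check_groups = function(train_groups, new_groups) {
  unseen = setdiff(unique(new_groups), unique(train_groups))
  if (length(unseen) > 0L) {
    stop(sprintf("Unseen groups at prediction time: %s",
      paste(unseen, collapse = ", ")))
  }
  invisible(TRUE)
}

train_groups = rep(c("A", "B"), each = 10)
G = default_G(train_groups)

# Known groups pass silently; an unknown group "C" raises an error,
# which is captured here as a message string.
check_groups(train_groups, c("A", "B"))
res = tryCatch(
  check_groups(train_groups, c("A", "C")),
  error = function(e) conditionMessage(e)
)
```

Failing early with the offending group names is one reasonable design, since a model fitted per group has no coefficients to apply to a group it never saw.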

Example

library(mlr3)
library(mlr3extralearners)

data = data.frame(
  y = rnorm(20),
  x1 = rnorm(20),
  x2 = rnorm(20),
  group = rep(c("A", "B"), each = 10)
)

task = TaskRegr$new(id = "fuser", backend = data, target = "y")
task$set_col_roles("group", roles = "group")

learner = lrn("regr.fuser", lambda = 0.01, gamma = 0.01, scaling = FALSE)
learner$train(task)
pred = learner$predict(task)

Tests

  • Rscript -e 'testthat::test_file("tests/testthat/test_fuser_regr_fuser.R")'

Notes

  • Requires fuser (GitHub fork): remotes::install_github("EngineerDanny/fuser")

Closes #426

Member

be-marc left a comment

Hey, thanks for contributing. There are a lot of unnecessary checks; otherwise, this looks good to me. You need to release your fuser fork on r-universe; GitHub remotes do not work well for us.

x = as_numeric_matrix(task$data(cols = feature_names))
y = as.numeric(task$truth())

if (length(y) != nrow(x)) {
Member

This check is superfluous. mlr3 guarantees this.
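
As a small illustration of this point, the features and the target of a task are read from the same rows, so their lengths agree by construction. The sketch below uses the built-in `mtcars` task from mlr3 as a stand-in; it is not the code under review.

```r
library(mlr3)

# Built-in example task, used here only as a stand-in.
task = tsk("mtcars")

# Features and target are both taken from the rows currently in use.
x = as.matrix(task$data(cols = task$feature_names))
y = as.numeric(task$truth())

# All three counts refer to the same set of rows.
sizes = c(length(y), nrow(x), task$nrow)
```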

stopf("Target length (%i) must match the number of rows (%i).",
length(y), nrow(x))
}
if (anyNA(x) || anyNA(y)) {
Member

This is also unnecessary. We check this in mlr3 when the learner does not have the property "missings".

groups = as.character(groups)
}
groups = as.vector(groups)
if (length(groups) != nrow(x)) {
Member

Also unnecessary.

stopf("Grouping column length (%i) must match the number of rows (%i).",
length(groups), nrow(x))
}
if (anyNA(groups)) {
Member

Also unnecessary.

}

pv = self$param_set$get_values(tags = "train")
if (!is.null(pv$lambda) && (!is.numeric(pv$lambda) || length(pv$lambda) != 1L)) {
Member

Guaranteed by the parameter set
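
A minimal sketch of why such checks are redundant, assuming the learner declares its hyperparameters with paradox roughly as below. The parameter set here is an assumption made for illustration; the definition in this PR may differ.

```r
library(paradox)

# Assumed parameter set; special_vals = list(NULL) allows lambda = NULL.
param_set = ps(
  lambda  = p_dbl(lower = 0, special_vals = list(NULL), tags = "train"),
  gamma   = p_dbl(lower = 0, default = 1, tags = "train"),
  scaling = p_lgl(default = FALSE, tags = "train")
)

# A numeric scalar is accepted.
ok_scalar = param_set$test(list(lambda = 0.01))

# A length-2 vector and a non-numeric value are both rejected
# by the parameter set itself, before any learner code runs.
bad_vector = param_set$test(list(lambda = c(0.1, 0.2)))
bad_type   = param_set$test(list(gamma = "a"))
```

Because invalid values are rejected when they are set, the training code can read `pv$lambda` and `pv$gamma` without re-checking type or length.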

if (!is.null(pv$lambda) && (!is.numeric(pv$lambda) || length(pv$lambda) != 1L)) {
stopf("Parameter 'lambda' must be a numeric scalar or NULL.")
}
if (!is.numeric(pv$gamma) || length(pv$gamma) != 1L) {
Member

Also guaranteed by the parameter set.

stopf("Grouping column length (%i) must match the number of rows (%i).",
length(groups), nrow(x))
}
if (anyNA(x)) {
Member

Unnecessary

}

x = as_numeric_matrix(task$data(cols = self$model$feature_names))
if (length(groups) != nrow(x)) {
Member

Unnecessary

@tdhock
Contributor

tdhock commented Mar 26, 2026

I would suggest changing the role name from group to subset.

  • In mlr3, data in a group must stay together when splitting, which is not the same concept as a group in fuser.
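
To illustrate the splitting behavior referred to here, the sketch below sets the group role on a toy task and checks that a 2-fold cross-validation never splits a group across folds. The data and the task id are made up for the example.

```r
library(mlr3)

set.seed(1)
data = data.frame(
  y = rnorm(40),
  x1 = rnorm(40),
  group = rep(c("A", "B", "C", "D"), each = 10)
)

task = TaskRegr$new(id = "grouped", backend = data, target = "y")
task$set_col_roles("group", roles = "group")

# With the group role set, folds are formed from whole groups.
cv = rsmp("cv", folds = 2)
cv$instantiate(task)

# For each fold, collect the groups that appear in its test set.
test_groups = lapply(seq_len(cv$iters), function(i) {
  unique(data$group[cv$test_set(i)])
})

# TRUE when no group appears in both test sets.
no_overlap = length(intersect(test_groups[[1]], test_groups[[2]])) == 0L
```

In this sense the role controls how rows are resampled, whereas fuser uses the grouping column to fit and fuse per-group coefficients, which is the distinction behind the suggested rename.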

be-marc force-pushed the main branch 2 times, most recently from e0d78c9 to 18bf58b on April 4, 2026 18:42
Development

Successfully merging this pull request may close these issues.

[LRNRQ] Add <algorithm> from package fuser

3 participants