Skip to content

Commit aba34b6

Browse files
committed
initial refactoring of glmnet prediction code
1 parent 0ea300d commit aba34b6

File tree

6 files changed

+239
-82
lines changed

6 files changed

+239
-82
lines changed

R/linear_reg.R

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
#' \pkg{spark}
6464
#'
6565
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "spark")}
66-
#'
66+
#'
6767
#' \pkg{keras}
6868
#'
6969
#' \Sexpr[results=rd]{parsnip:::show_fit(parsnip:::linear_reg(), "keras")}
@@ -216,12 +216,66 @@ organize_glmnet_pred <- function(x, object) {
216216

217217
# ------------------------------------------------------------------------------
218218

219+
# For `predict` methods that use `glmnet`, we have specific methods.
220+
# Only one value of the penalty should be allowed when called by `predict()`:
221+
222+
check_penalty <- function(penalty = NULL, object, multi = FALSE) {
223+
224+
if (is.null(penalty)) {
225+
penalty <- object$fit$lambda
226+
}
227+
228+
# when using `predict()`, allow for a single lambda
229+
if (!multi) {
230+
if (length(penalty) != 1)
231+
stop("`penalty` should be a single numeric value. ",
232+
"`multi_predict()` can be used to get multiple predictions ",
233+
"per row of data.", call. = FALSE)
234+
}
235+
236+
if (length(object$fit$lambda) == 1 && penalty != object$fit$lambda)
237+
stop("The glmnet model was fit with a single penalty value of ",
238+
object$fit$lambda, ". Predicting with a value of ",
239+
penalty, " will give incorrect results from `glmnet()`.",
240+
call. = FALSE)
241+
242+
penalty
243+
}
244+
245+
# ------------------------------------------------------------------------------
246+
# glmnet call stack for linear regression using `predict` when object has
247+
# classes "_elnet" and "model_fit":
248+
#
249+
# predict()
250+
# predict._elnet(penalty = NULL) <-- checks and sets penalty
251+
# predict.model_fit() <-- checks for extra vars in ...
252+
# predict_numeric()
253+
# predict_numeric._elnet()
254+
# predict_numeric.model_fit()
255+
# predict.elnet()
256+
257+
258+
# glmnet call stack for linear regression using `multi_predict` when object has
259+
# classes "_elnet" and "model_fit":
260+
#
261+
# multi_predict()
262+
# multi_predict._elnet(penalty = NULL)
263+
# predict._elnet(multi = TRUE) <-- checks and sets penalty
264+
# predict.model_fit() <-- checks for extra vars in ...
265+
# predict_raw()
266+
# predict_raw._elnet()
267+
# predict_raw.model_fit(opts = list(s = penalty))
268+
# predict.elnet()
269+
270+
219271
#' @export
220272
predict._elnet <-
221-
function(object, new_data, type = NULL, opts = list(), ...) {
273+
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
222274
if (any(names(enquos(...)) == "newdata"))
223275
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
224-
276+
277+
object$spec$args$penalty <- check_penalty(penalty, object, multi)
278+
225279
object$spec <- eval_args(object$spec)
226280
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
227281
}
@@ -230,7 +284,7 @@ predict._elnet <-
230284
predict_numeric._elnet <- function(object, new_data, ...) {
231285
if (any(names(enquos(...)) == "newdata"))
232286
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
233-
287+
234288
object$spec <- eval_args(object$spec)
235289
predict_numeric.model_fit(object, new_data = new_data, ...)
236290
}
@@ -239,8 +293,9 @@ predict_numeric._elnet <- function(object, new_data, ...) {
239293
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
240294
if (any(names(enquos(...)) == "newdata"))
241295
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
242-
296+
243297
object$spec <- eval_args(object$spec)
298+
opts$s <- object$spec$args$penalty
244299
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
245300
}
246301

@@ -251,14 +306,17 @@ multi_predict._elnet <-
251306
function(object, new_data, type = NULL, penalty = NULL, ...) {
252307
if (any(names(enquos(...)) == "newdata"))
253308
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
254-
309+
255310
dots <- list(...)
256-
if (is.null(penalty))
257-
penalty <- object$fit$lambda
258-
dots$s <- penalty
259311

260312
object$spec <- eval_args(object$spec)
261-
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
313+
314+
if (is.null(penalty)) {
315+
penalty <- object$fit$lambda
316+
}
317+
318+
pred <- predict._elnet(object, new_data = new_data, type = "raw",
319+
opts = dots, penalty = penalty, multi = TRUE)
262320
param_key <- tibble(group = colnames(pred), penalty = penalty)
263321
pred <- as_tibble(pred)
264322
pred$.row <- 1:nrow(pred)

R/logistic_reg.R

Lines changed: 66 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -235,41 +235,41 @@ organize_glmnet_prob <- function(x, object) {
235235
}
236236

237237
# ------------------------------------------------------------------------------
238+
# glmnet call stack for linear regression using `predict` when object has
239+
# classes "_lognet" and "model_fit" (for class predictions):
240+
#
241+
# predict()
242+
# predict._lognet(penalty = NULL) <-- checks and sets penalty
243+
# predict.model_fit() <-- checks for extra vars in ...
244+
# predict_class()
245+
# predict_class._lognet()
246+
# predict_class.model_fit()
247+
# predict.lognet()
248+
249+
250+
# glmnet call stack for linear regression using `multi_predict` when object has
251+
# classes "_lognet" and "model_fit" (for class predictions):
252+
#
253+
# multi_predict()
254+
# multi_predict._lognet(penalty = NULL)
255+
# predict._lognet(multi = TRUE) <-- checks and sets penalty
256+
# predict.model_fit() <-- checks for extra vars in ...
257+
# predict_raw()
258+
# predict_raw._lognet()
259+
# predict_raw.model_fit(opts = list(s = penalty))
260+
# predict.lognet()
238261

239-
#' @export
240-
predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
241-
if (any(names(enquos(...)) == "newdata"))
242-
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
243-
244-
object$spec <- eval_args(object$spec)
245-
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
246-
}
247-
248-
#' @export
249-
predict_class._lognet <- function (object, new_data, ...) {
250-
if (any(names(enquos(...)) == "newdata"))
251-
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
252-
253-
object$spec <- eval_args(object$spec)
254-
predict_class.model_fit(object, new_data = new_data, ...)
255-
}
262+
# ------------------------------------------------------------------------------
256263

257264
#' @export
258-
predict_classprob._lognet <- function (object, new_data, ...) {
265+
predict._lognet <- function (object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
259266
if (any(names(enquos(...)) == "newdata"))
260267
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
261268

262-
object$spec <- eval_args(object$spec)
263-
predict_classprob.model_fit(object, new_data = new_data, ...)
264-
}
265-
266-
#' @export
267-
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
268-
if (any(names(enquos(...)) == "newdata"))
269-
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
269+
object$spec$args$penalty <- check_penalty(penalty, object, multi)
270270

271271
object$spec <- eval_args(object$spec)
272-
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
272+
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
273273
}
274274

275275

@@ -281,23 +281,26 @@ multi_predict._lognet <-
281281
if (any(names(enquos(...)) == "newdata"))
282282
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
283283

284+
if (is_quosure(penalty))
285+
penalty <- eval_tidy(penalty)
286+
284287
dots <- list(...)
285288
if (is.null(penalty))
286-
penalty <- object$fit$lambda
289+
penalty <- eval_tidy(object$fit$lambda)
287290
dots$s <- penalty
288291

289292
if (is.null(type))
290293
type <- "class"
291-
if (!(type %in% c("class", "prob", "link"))) {
292-
stop ("`type` should be either 'class', 'link', or 'prob'.", call. = FALSE)
294+
if (!(type %in% c("class", "prob", "link", "raw"))) {
295+
stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
293296
}
294297
if (type == "prob")
295298
dots$type <- "response"
296299
else
297300
dots$type <- type
298301

299302
object$spec <- eval_args(object$spec)
300-
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
303+
pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots)
301304
param_key <- tibble(group = colnames(pred), penalty = penalty)
302305
pred <- as_tibble(pred)
303306
pred$.row <- 1:nrow(pred)
@@ -321,6 +324,38 @@ multi_predict._lognet <-
321324
tibble(.pred = pred)
322325
}
323326

327+
328+
329+
330+
331+
#' @export
332+
predict_class._lognet <- function (object, new_data, ...) {
333+
if (any(names(enquos(...)) == "newdata"))
334+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
335+
336+
object$spec <- eval_args(object$spec)
337+
predict_class.model_fit(object, new_data = new_data, ...)
338+
}
339+
340+
#' @export
341+
predict_classprob._lognet <- function (object, new_data, ...) {
342+
if (any(names(enquos(...)) == "newdata"))
343+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
344+
345+
object$spec <- eval_args(object$spec)
346+
predict_classprob.model_fit(object, new_data = new_data, ...)
347+
}
348+
349+
#' @export
350+
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
351+
if (any(names(enquos(...)) == "newdata"))
352+
stop("Did you mean to use `new_data` instead of `newdata`?", call. = FALSE)
353+
354+
object$spec <- eval_args(object$spec)
355+
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
356+
}
357+
358+
324359
# ------------------------------------------------------------------------------
325360

326361
#' @importFrom utils globalVariables

R/multinom_reg.R

Lines changed: 54 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -188,54 +188,46 @@ organize_multnet_prob <- function(x, object) {
188188
}
189189

190190
# ------------------------------------------------------------------------------
191+
# glmnet call stack for linear regression using `predict` when object has
192+
# classes "_multnet" and "model_fit" (for class predictions):
193+
#
194+
# predict()
195+
# predict._multnet(penalty = NULL) <-- checks and sets penalty
196+
# predict.model_fit() <-- checks for extra vars in ...
197+
# predict_class()
198+
# predict_class._multnet()
199+
# predict.multnet()
200+
201+
202+
# glmnet call stack for linear regression using `multi_predict` when object has
203+
# classes "_multnet" and "model_fit" (for class predictions):
204+
#
205+
# multi_predict()
206+
# multi_predict._multnet(penalty = NULL)
207+
# predict._multnet(multi = TRUE) <-- checks and sets penalty
208+
# predict.model_fit() <-- checks for extra vars in ...
209+
# predict_raw()
210+
# predict_raw._multnet()
211+
# predict_raw.model_fit(opts = list(s = penalty))
212+
# predict.multnet()
191213

192-
#' @export
193-
predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
194-
object$spec <- eval_args(object$spec)
195-
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
196-
}
197-
198-
#' @export
199-
predict_class._lognet <- function (object, new_data, ...) {
200-
object$spec <- eval_args(object$spec)
201-
predict_class.model_fit(object, new_data = new_data, ...)
202-
}
203-
204-
#' @export
205-
predict_classprob._multnet <- function (object, new_data, ...) {
206-
object$spec <- eval_args(object$spec)
207-
predict_classprob.model_fit(object, new_data = new_data, ...)
208-
}
209-
210-
#' @export
211-
predict_raw._multnet <- function (object, new_data, opts = list(), ...) {
212-
object$spec <- eval_args(object$spec)
213-
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
214-
}
215-
214+
# ------------------------------------------------------------------------------
216215

217216
#' @export
218217
predict._multnet <-
219-
function(object, new_data, type = NULL, opts = list(), penalty = NULL, ...) {
220-
dots <- list(...)
221-
if (is.null(penalty))
222-
penalty <- object$fit$lambda
218+
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
219+
220+
object$spec$args$penalty <- check_penalty(penalty, object, multi)
223221

224-
if (length(penalty) != 1)
225-
stop("`penalty` should be a single numeric value. ",
226-
"`multi_predict()` can be used to get multiple predictions ",
227-
"per row of data.", call. = FALSE)
228222
object$spec <- eval_args(object$spec)
229223
res <- predict.model_fit(
230224
object = object,
231225
new_data = new_data,
232226
type = type,
233-
opts = opts,
234-
penalty = penalty
227+
opts = opts
235228
)
236-
res
237-
}
238-
229+
res
230+
}
239231

240232
#' @importFrom dplyr full_join as_tibble arrange
241233
#' @importFrom tidyr gather
@@ -255,8 +247,8 @@ multi_predict._multnet <-
255247

256248
if (is.null(type))
257249
type <- "class"
258-
if (!(type %in% c("class", "prob", "link"))) {
259-
stop ("`type` should be either 'class', 'link', or 'prob'.", call. = FALSE)
250+
if (!(type %in% c("class", "prob", "link", "raw"))) {
251+
stop ("`type` should be either 'class', 'link', 'raw', or 'prob'.", call. = FALSE)
260252
}
261253
if (type == "prob")
262254
dots$type <- "response"
@@ -296,6 +288,29 @@ multi_predict._multnet <-
296288
tibble(.pred = pred)
297289
}
298290

291+
#' @export
292+
predict_class._multnet <- function (object, new_data, ...) {
293+
object$spec <- eval_args(object$spec)
294+
predict_class.model_fit(object, new_data = new_data, ...)
295+
}
296+
297+
#' @export
298+
predict_classprob._multnet <- function (object, new_data, ...) {
299+
object$spec <- eval_args(object$spec)
300+
predict_classprob.model_fit(object, new_data = new_data, ...)
301+
}
302+
303+
#' @export
304+
predict_raw._multnet <- function (object, new_data, opts = list(), ...) {
305+
object$spec <- eval_args(object$spec)
306+
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
307+
}
308+
309+
310+
311+
# ------------------------------------------------------------------------------
312+
313+
# This checks as a pre-processor in the model data object
299314
check_glmnet_lambda <- function(dat, object) {
300315
if (length(object$fit$lambda) > 1)
301316
stop(

0 commit comments

Comments
 (0)