Skip to content

Commit f8ef8c3

Browse files
committed
fixes for glmnet predictions
1 parent 3f5c464 commit f8ef8c3

File tree

10 files changed

+133
-38
lines changed

10 files changed

+133
-38
lines changed

NAMESPACE

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,22 @@ S3method(multi_predict,"_lognet")
99
S3method(multi_predict,"_multnet")
1010
S3method(multi_predict,"_xgb.Booster")
1111
S3method(multi_predict,default)
12+
S3method(predict,"_elnet")
13+
S3method(predict,"_lognet")
1214
S3method(predict,"_multnet")
1315
S3method(predict,model_fit)
16+
S3method(predict_class,"_lognet")
1417
S3method(predict_class,model_fit)
18+
S3method(predict_classprob,"_lognet")
19+
S3method(predict_classprob,"_multnet")
1520
S3method(predict_classprob,model_fit)
1621
S3method(predict_confint,model_fit)
22+
S3method(predict_num,"_elnet")
1723
S3method(predict_num,model_fit)
1824
S3method(predict_predint,model_fit)
25+
S3method(predict_raw,"_elnet")
26+
S3method(predict_raw,"_lognet")
27+
S3method(predict_raw,"_multnet")
1928
S3method(predict_raw,model_fit)
2029
S3method(print,boost_tree)
2130
S3method(print,linear_reg)
@@ -131,6 +140,7 @@ importFrom(purrr,map_dbl)
131140
importFrom(purrr,map_df)
132141
importFrom(purrr,map_dfr)
133142
importFrom(purrr,map_lgl)
143+
importFrom(rlang,eval_tidy)
134144
importFrom(rlang,sym)
135145
importFrom(rlang,syms)
136146
importFrom(stats,.checkMFClasses)

R/arguments.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,20 @@ set_mode <- function(object, mode) {
116116
object
117117
}
118118

119+
# ------------------------------------------------------------------------------
119120

121+
#' @importFrom rlang eval_tidy
122+
#' @importFrom purrr map
123+
maybe_eval <- function(x) {
124+
# if descriptors are in `x`, eval fails
125+
y <- try(rlang::eval_tidy(x), silent = TRUE)
126+
if (inherits(y, "try-error"))
127+
y <- x
128+
y
129+
}
130+
131+
eval_args <- function(spec, ...) {
132+
spec$args <- purrr::map(spec$args, maybe_eval)
133+
spec$others <- purrr::map(spec$others, maybe_eval)
134+
spec
135+
}

R/linear_reg.R

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,27 @@ organize_glmnet_pred <- function(x, object) {
226226
}
227227

228228

229+
# ------------------------------------------------------------------------------
230+
231+
#' @export
232+
predict._elnet <-
233+
function(object, new_data, type = NULL, opts = list(), ...) {
234+
object$spec <- eval_args(object$spec)
235+
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
236+
}
237+
238+
#' @export
239+
predict_num._elnet <- function(object, new_data, ...) {
240+
object$spec <- eval_args(object$spec)
241+
predict_num.model_fit(object, new_data = new_data, ...)
242+
}
243+
244+
#' @export
245+
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
246+
object$spec <- eval_args(object$spec)
247+
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
248+
}
249+
229250
#' @importFrom dplyr full_join as_tibble arrange
230251
#' @importFrom tidyr gather
231252
#' @export
@@ -235,6 +256,8 @@ multi_predict._elnet <-
235256
if (is.null(penalty))
236257
penalty <- object$fit$lambda
237258
dots$s <- penalty
259+
260+
object$spec <- eval_args(object$spec)
238261
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
239262
param_key <- tibble(group = colnames(pred), penalty = penalty)
240263
pred <- as_tibble(pred)

R/linear_reg_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ linear_reg_lm_data <-
8686
)
8787
)
8888

89+
# Note: For glmnet, you will need to make model-specific predict methods.
90+
# See linear_reg.R
8991
linear_reg_glmnet_data <-
9092
list(
9193
libs = "glmnet",

R/logistic_reg.R

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,31 @@ organize_glmnet_prob <- function(x, object) {
247247

248248
# ------------------------------------------------------------------------------
249249

250+
#' @export
251+
predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
252+
object$spec <- eval_args(object$spec)
253+
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
254+
}
255+
256+
#' @export
257+
predict_class._lognet <- function (object, new_data, ...) {
258+
object$spec <- eval_args(object$spec)
259+
predict_class.model_fit(object, new_data = new_data, ...)
260+
}
261+
262+
#' @export
263+
predict_classprob._lognet <- function (object, new_data, ...) {
264+
object$spec <- eval_args(object$spec)
265+
predict_classprob.model_fit(object, new_data = new_data, ...)
266+
}
267+
268+
#' @export
269+
predict_raw._lognet <- function (object, new_data, opts = list(), ...) {
270+
object$spec <- eval_args(object$spec)
271+
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
272+
}
273+
274+
250275
#' @importFrom dplyr full_join as_tibble arrange
251276
#' @importFrom tidyr gather
252277
#' @export
@@ -255,6 +280,7 @@ multi_predict._lognet <-
255280
dots <- list(...)
256281
if (is.null(penalty))
257282
penalty <- object$lambda
283+
dots$s <- penalty
258284

259285
if (is.null(type))
260286
type <- "class"
@@ -266,7 +292,7 @@ multi_predict._lognet <-
266292
else
267293
dots$type <- type
268294

269-
dots$s <- penalty
295+
object$spec <- eval_args(object$spec)
270296
pred <- predict(object, new_data = new_data, type = "raw", opts = dots)
271297
param_key <- tibble(group = colnames(pred), penalty = penalty)
272298
pred <- as_tibble(pred)

R/logistic_reg_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ logistic_reg_glm_data <-
9595
)
9696
)
9797

98+
# Note: For glmnet, you will need to make model-specific predict methods.
99+
# See logistic_reg.R
98100
logistic_reg_glmnet_data <-
99101
list(
100102
libs = "glmnet",

R/multinom_reg.R

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,31 @@ organize_multnet_prob <- function(x, object) {
200200

201201
# ------------------------------------------------------------------------------
202202

203+
#' @export
204+
predict._lognet <- function (object, new_data, type = NULL, opts = list(), ...) {
205+
object$spec <- eval_args(object$spec)
206+
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
207+
}
208+
209+
#' @export
210+
predict_class._lognet <- function (object, new_data, ...) {
211+
object$spec <- eval_args(object$spec)
212+
predict_class.model_fit(object, new_data = new_data, ...)
213+
}
214+
215+
#' @export
216+
predict_classprob._multnet <- function (object, new_data, ...) {
217+
object$spec <- eval_args(object$spec)
218+
predict_classprob.model_fit(object, new_data = new_data, ...)
219+
}
220+
221+
#' @export
222+
predict_raw._multnet <- function (object, new_data, opts = list(), ...) {
223+
object$spec <- eval_args(object$spec)
224+
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
225+
}
226+
227+
203228
#' @export
204229
predict._multnet <-
205230
function(object, new_data, type = NULL, opts = list(), penalty = NULL, ...) {
@@ -211,6 +236,7 @@ predict._multnet <-
211236
stop("`penalty` should be a single numeric value. ",
212237
"`multi_predict` can be used to get multiple predictions ",
213238
"per row of data.", call. = FALSE)
239+
object$spec <- eval_args(object$spec)
214240
res <- predict.model_fit(
215241
object = object,
216242
new_data = new_data,
@@ -227,9 +253,13 @@ predict._multnet <-
227253
#' @export
228254
multi_predict._multnet <-
229255
function(object, new_data, type = NULL, penalty = NULL, ...) {
256+
if (is_quosure(penalty))
257+
penalty <- eval_tidy(penalty)
258+
230259
dots <- list(...)
231260
if (is.null(penalty))
232-
penalty <- object$lambda
261+
penalty <- eval_tidy(object$lambda)
262+
dots$s <- penalty
233263

234264
if (is.null(type))
235265
type <- "class"
@@ -241,7 +271,7 @@ multi_predict._multnet <-
241271
else
242272
dots$type <- type
243273

244-
dots$s <- penalty
274+
object$spec <- eval_args(object$spec)
245275
pred <- predict.model_fit(object, new_data = new_data, type = "raw", opts = dots)
246276

247277
format_probs <- function(x) {

tests/testthat/test_linear_reg_glmnet.R

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,7 @@ test_that('glmnet prediction, single lambda', {
7171
s = iris_basic$spec$args$penalty)
7272
uni_pred <- unname(uni_pred[,1])
7373

74-
# TODO neet a fix here
75-
# expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred]))
74+
expect_equal(uni_pred, predict_num(res_xy, iris[1:5, num_pred]))
7675

7776
res_form <- fit(
7877
iris_basic,
@@ -90,8 +89,8 @@ test_that('glmnet prediction, single lambda', {
9089
newx = form_pred,
9190
s = res_form$spec$spec$args$penalty)
9291
form_pred <- unname(form_pred[,1])
93-
# TODO neet a fix here
94-
# expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
92+
93+
expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
9594
})
9695

9796

@@ -119,8 +118,7 @@ test_that('glmnet prediction, multiple lambda', {
119118
mult_pred$lambda <- rep(lams, each = 5)
120119
mult_pred <- mult_pred[,-2]
121120

122-
# TODO neet a fix here
123-
# expect_equal(mult_pred, predict_num(res_xy, iris[1:5, num_pred]))
121+
expect_equal(mult_pred, predict_num(res_xy, iris[1:5, num_pred]))
124122

125123
res_form <- fit(
126124
iris_mult,
@@ -141,8 +139,7 @@ test_that('glmnet prediction, multiple lambda', {
141139
form_pred$lambda <- rep(lams, each = 5)
142140
form_pred <- form_pred[,-2]
143141

144-
# TODO neet a fix here
145-
# expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
142+
expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
146143
})
147144

148145
test_that('glmnet prediction, all lambda', {
@@ -164,8 +161,7 @@ test_that('glmnet prediction, all lambda', {
164161
all_pred$lambda <- rep(res_xy$fit$lambda, each = 5)
165162
all_pred <- all_pred[,-2]
166163

167-
# TODO neet a fix here
168-
# expect_equal(all_pred, predict_num(res_xy, iris[1:5, num_pred]))
164+
expect_equal(all_pred, predict_num(res_xy, iris[1:5, num_pred]))
169165

170166
# test that the lambda seq is in the right order (since no docs on this)
171167
tmp_pred <- predict(res_xy$fit, newx = as.matrix(iris[1:5, num_pred]),
@@ -189,8 +185,7 @@ test_that('glmnet prediction, all lambda', {
189185
form_pred$lambda <- rep(res_form$fit$lambda, each = 5)
190186
form_pred <- form_pred[,-2]
191187

192-
# TODO neet a fix here
193-
# expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
188+
expect_equal(form_pred, predict_num(res_form, iris[1:5, c("Sepal.Width", "Species")]))
194189
})
195190

196191

tests/testthat/test_logistic_reg_glmnet.R

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ test_that('glmnet prediction, one lambda', {
6767
uni_pred <- factor(uni_pred, levels = levels(lending_club$Class))
6868
uni_pred <- unname(uni_pred)
6969

70-
# not currently working; will fix
71-
# expect_equal(uni_pred, predict_class(xy_fit, lending_club[1:7, num_pred]))
70+
expect_equal(uni_pred, predict_class(xy_fit, lending_club[1:7, num_pred]))
7271

7372
res_form <- fit(
7473
logistic_reg(penalty = 0.1),
@@ -88,8 +87,8 @@ test_that('glmnet prediction, one lambda', {
8887
form_pred <- ifelse(form_pred >= 0.5, "good", "bad")
8988
form_pred <- factor(form_pred, levels = levels(lending_club$Class))
9089
form_pred <- unname(form_pred)
91-
# not currently working; will fix
92-
# expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")]))
90+
91+
expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")]))
9392

9493
})
9594

@@ -118,8 +117,7 @@ test_that('glmnet prediction, mulitiple lambda', {
118117
mult_pred$lambda <- rep(lams, each = 7)
119118
mult_pred <- mult_pred[, -2]
120119

121-
# not currently working; will fix
122-
# expect_equal(mult_pred, predict_class(xy_fit, lending_club[1:7, num_pred]))
120+
expect_equal(mult_pred, predict_class(xy_fit, lending_club[1:7, num_pred]))
123121

124122
res_form <- fit(
125123
logistic_reg(penalty = lams),
@@ -142,14 +140,12 @@ test_that('glmnet prediction, mulitiple lambda', {
142140
form_pred$lambda <- rep(lams, each = 7)
143141
form_pred <- form_pred[, -2]
144142

145-
# not currently working; will fix
146-
# expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")]))
143+
expect_equal(form_pred, predict_class(res_form, lending_club[1:7, c("funded_amnt", "int_rate")]))
147144

148145
})
149146

150147
test_that('glmnet prediction, no lambda', {
151148

152-
skip("not currently working; will fix")
153149
skip_if_not_installed("glmnet")
154150

155151
xy_fit <- fit_xy(
@@ -163,7 +159,7 @@ test_that('glmnet prediction, no lambda', {
163159
mult_pred <-
164160
predict(xy_fit$fit,
165161
newx = as.matrix(lending_club[1:7, num_pred]),
166-
s = xy_fit$spec$args$penalty, type = "response")
162+
s = xy_fit$fit$lambda, type = "response")
167163
mult_pred <- stack(as.data.frame(mult_pred))
168164
mult_pred$values <- ifelse(mult_pred$values >= 0.5, "good", "bad")
169165
mult_pred$values <- factor(mult_pred$values, levels = levels(lending_club$Class))
@@ -199,7 +195,6 @@ test_that('glmnet prediction, no lambda', {
199195

200196
test_that('glmnet probabilities, one lambda', {
201197

202-
skip("not currently working; will fix")
203198
skip_if_not_installed("glmnet")
204199

205200
xy_fit <- fit_xy(
@@ -243,7 +238,6 @@ test_that('glmnet probabilities, one lambda', {
243238

244239
test_that('glmnet probabilities, mulitiple lambda', {
245240

246-
skip("not currently working; will fix")
247241
skip_if_not_installed("glmnet")
248242

249243
lams <- c(0.01, 0.1)
@@ -292,7 +286,6 @@ test_that('glmnet probabilities, mulitiple lambda', {
292286

293287
test_that('glmnet probabilities, no lambda', {
294288

295-
skip("not currently working; will fix")
296289
skip_if_not_installed("glmnet")
297290

298291
xy_fit <- fit_xy(

tests/testthat/test_multinom_reg_glmnet.R

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,10 @@ test_that('glmnet probabilities, mulitiple lambda', {
118118
names(mult_pred) <- NULL
119119
mult_pred <- tibble(.pred = mult_pred)
120120

121-
# needs fixin
122-
# expect_equal(
123-
# mult_pred$.pred,
124-
# multi_predict(xy_fit, iris[rows, 1:4], penalty = xy_fit$spec$args$penalty, type = "prob")$.pred
125-
# )
121+
expect_equal(
122+
mult_pred$.pred,
123+
multi_predict(xy_fit, iris[rows, 1:4], penalty = lams, type = "prob")$.pred
124+
)
126125

127126
mult_class <- names(mult_probs)[apply(mult_probs, 1, which.max)]
128127
mult_class <- tibble(
@@ -135,11 +134,10 @@ test_that('glmnet probabilities, mulitiple lambda', {
135134
names(mult_class) <- NULL
136135
mult_class <- tibble(.pred = mult_class)
137136

138-
# needs fixin
139-
# expect_equal(
140-
# mult_class$.pred,
141-
# multi_predict(xy_fit, iris[rows, 1:4], penalty = xy_fit$spec$args$penalty)$.pred
142-
# )
137+
expect_equal(
138+
mult_class$.pred,
139+
multi_predict(xy_fit, iris[rows, 1:4], penalty = lams)$.pred
140+
)
143141
})
144142

145143

0 commit comments

Comments
 (0)