Skip to content

Commit f2a731a

Browse files
authored
Merge pull request #390 from tidymodels/stan-function-switch
remove warnings with new stan version
2 parents a82ed40 + 7412288 commit f2a731a

File tree

5 files changed

+58
-16
lines changed

5 files changed

+58
-16
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ export(show_call)
177177
export(show_engines)
178178
export(show_fit)
179179
export(show_model_info)
180+
export(stan_conf_int)
180181
export(surv_reg)
181182
export(svm_poly)
182183
export(svm_rbf)

R/linear_reg_data.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,11 @@ set_pred(
258258
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
259259
res
260260
},
261-
func = c(pkg = "rstanarm", fun = "posterior_linpred"),
261+
func = c(pkg = "parsnip", fun = "stan_conf_int"),
262262
args =
263263
list(
264264
object = expr(object$fit),
265-
newdata = expr(new_data),
266-
transform = TRUE,
267-
seed = expr(sample.int(10^5, 1))
265+
newdata = expr(new_data)
268266
)
269267
)
270268
)

R/logistic_reg_data.R

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,11 @@ set_pred(
108108
res_1 <- res_2
109109
res_1$lo <- 1 - res_2$hi
110110
res_1$hi <- 1 - res_2$lo
111-
res <- bind_cols(res_1, res_2)
112111
lo_nms <- paste0(".pred_lower_", object$lvl)
113112
hi_nms <- paste0(".pred_upper_", object$lvl)
114-
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
113+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
114+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
115+
res <- bind_cols(res_1, res_2)
115116

116117
if (object$spec$method$pred$conf_int$extras$std_error)
117118
res$.std_error <- results$se.fit
@@ -509,22 +510,22 @@ set_pred(
509510
res_1 <- res_2
510511
res_1$lo <- 1 - res_2$hi
511512
res_1$hi <- 1 - res_2$lo
512-
res <- bind_cols(res_1, res_2)
513513
lo_nms <- paste0(".pred_lower_", object$lvl)
514514
hi_nms <- paste0(".pred_upper_", object$lvl)
515-
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
515+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
516+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
517+
res <- bind_cols(res_1, res_2)
516518

517-
if (object$spec$method$pred$conf_int$extras$std_error)
519+
if (object$spec$method$pred$conf_int$extras$std_error) {
518520
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
521+
}
519522
res
520523
},
521-
func = c(pkg = "rstanarm", fun = "posterior_linpred"),
524+
func = c(pkg = "parsnip", fun = "stan_conf_int"),
522525
args =
523526
list(
524-
object = quote(object$fit),
525-
newdata = quote(new_data),
526-
transform = TRUE,
527-
seed = expr(sample.int(10^5, 1))
527+
object = expr(object$fit),
528+
newdata = expr(new_data)
528529
)
529530
)
530531
)
@@ -554,10 +555,11 @@ set_pred(
554555
res_1 <- res_2
555556
res_1$lo <- 1 - res_2$hi
556557
res_1$hi <- 1 - res_2$lo
557-
res <- bind_cols(res_1, res_2)
558558
lo_nms <- paste0(".pred_lower_", object$lvl)
559559
hi_nms <- paste0(".pred_upper_", object$lvl)
560-
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
560+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
561+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
562+
res <- bind_cols(res_1, res_2)
561563

562564
if (object$spec$method$pred$pred_int$extras$std_error)
563565
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)

R/misc.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,4 +305,28 @@ update_engine_parameters <- function(eng_args, ...) {
305305
ret
306306
}
307307

308+
# ------------------------------------------------------------------------------
309+
# Since stan changed the function interface
310+
#' Wrapper for stan confidence intervals
311+
#' @param object A stan model fit
312+
#' @param newdata A data set.
313+
#' @export
314+
#' @keywords internal
315+
stan_conf_int <- function(object, newdata) {
316+
check_installs(list(method = list(libs = "rstanarm")))
317+
if (utils::packageVersion("rstanarm") >= "2.21.1") {
318+
fn <- rlang::call2("posterior_epred", .ns = "rstanarm",
319+
object = expr(object),
320+
newdata = expr(newdata),
321+
seed = expr(sample.int(10^5, 1)))
322+
} else {
323+
fn <- rlang::call2("posterior_linpred", .ns = "rstanarm",
324+
object = expr(object),
325+
newdata = expr(newdata),
326+
transform = TRUE,
327+
seed = expr(sample.int(10^5, 1)))
328+
}
329+
rlang::eval_tidy(fn)
330+
}
331+
308332

man/stan_conf_int.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)