Skip to content

Commit ab1a405

Browse files
committed
Merge branch 'master' into add-in
2 parents cc6d563 + 49364a7 commit ab1a405

15 files changed

+113
-117
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ export(show_call)
177177
export(show_engines)
178178
export(show_fit)
179179
export(show_model_info)
180+
export(stan_conf_int)
180181
export(surv_reg)
181182
export(svm_poly)
182183
export(svm_rbf)

R/aaa_models.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ check_interface_val <- function(x) {
352352
#' a formula interface, typically some predictor preprocessing must
353353
#' be conducted. `glmnet` is a good example of this.
354354
#'
355-
#' There are three options that can be used for the encodings:
355+
#' There are four options that can be used for the encodings:
356356
#'
357357
#' `predictor_indicators` describes whether and how to create indicator/dummy
358358
#' variables from factor predictors. There are three options: `"none"` (do not
@@ -369,10 +369,15 @@ check_interface_val <- function(x) {
369369
#' intercept, `model.matrix()` computes a full set of indicators for the
370370
#' _first_ factor variable, but an incomplete set for the remainder.
371371
#'
372-
#' Finally, the option `remove_intercept` will remove the intercept column
372+
#' Next, the option `remove_intercept` will remove the intercept column
373373
#' _after_ `model.matrix()` is finished. This can be useful if the model
374374
#' function (e.g. `lm()`) automatically generates an intercept.
375375
#'
376+
#' Finally, `allow_sparse_x` specifies whether the model function can natively
377+
#' accommodate a sparse matrix representation for predictors during fitting
378+
#' and tuning.
379+
#'
380+
#'
376381
#' @references "How to build a parsnip model"
377382
#' \url{https://www.tidymodels.org/learn/develop/models/}
378383
#' @examples

R/descriptors.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -250,9 +250,9 @@ get_descr_spark <- function(formula, data) {
250250
.obs <- function() obs
251251
.lvls <- function() y_vals
252252
.facts <- function() factor_pred
253-
.x <- function() abort("Descriptor `.x()` not defined for Spark.")
254-
.y <- function() abort("Descriptor `.y()` not defined for Spark.")
255-
.dat <- function() abort("Descriptor `.dat()` not defined for Spark.")
253+
.x <- function() abort("Descriptor .x() not defined for Spark.")
254+
.y <- function() abort("Descriptor .y() not defined for Spark.")
255+
.dat <- function() abort("Descriptor .dat() not defined for Spark.")
256256

257257
# still need .x(), .y(), .dat() ?
258258

R/linear_reg_data.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,13 +258,11 @@ set_pred(
258258
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
259259
res
260260
},
261-
func = c(pkg = "rstanarm", fun = "posterior_linpred"),
261+
func = c(pkg = "parsnip", fun = "stan_conf_int"),
262262
args =
263263
list(
264264
object = expr(object$fit),
265-
newdata = expr(new_data),
266-
transform = TRUE,
267-
seed = expr(sample.int(10^5, 1))
265+
newdata = expr(new_data)
268266
)
269267
)
270268
)

R/logistic_reg_data.R

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,11 @@ set_pred(
108108
res_1 <- res_2
109109
res_1$lo <- 1 - res_2$hi
110110
res_1$hi <- 1 - res_2$lo
111-
res <- bind_cols(res_1, res_2)
112111
lo_nms <- paste0(".pred_lower_", object$lvl)
113112
hi_nms <- paste0(".pred_upper_", object$lvl)
114-
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
113+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
114+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
115+
res <- bind_cols(res_1, res_2)
115116

116117
if (object$spec$method$pred$conf_int$extras$std_error)
117118
res$.std_error <- results$se.fit
@@ -509,22 +510,22 @@ set_pred(
509510
res_1 <- res_2
510511
res_1$lo <- 1 - res_2$hi
511512
res_1$hi <- 1 - res_2$lo
512-
res <- bind_cols(res_1, res_2)
513513
lo_nms <- paste0(".pred_lower_", object$lvl)
514514
hi_nms <- paste0(".pred_upper_", object$lvl)
515-
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
515+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
516+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
517+
res <- bind_cols(res_1, res_2)
516518

517-
if (object$spec$method$pred$conf_int$extras$std_error)
519+
if (object$spec$method$pred$conf_int$extras$std_error) {
518520
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)
521+
}
519522
res
520523
},
521-
func = c(pkg = "rstanarm", fun = "posterior_linpred"),
524+
func = c(pkg = "parsnip", fun = "stan_conf_int"),
522525
args =
523526
list(
524-
object = quote(object$fit),
525-
newdata = quote(new_data),
526-
transform = TRUE,
527-
seed = expr(sample.int(10^5, 1))
527+
object = expr(object$fit),
528+
newdata = expr(new_data)
528529
)
529530
)
530531
)
@@ -554,10 +555,11 @@ set_pred(
554555
res_1 <- res_2
555556
res_1$lo <- 1 - res_2$hi
556557
res_1$hi <- 1 - res_2$lo
557-
res <- bind_cols(res_1, res_2)
558558
lo_nms <- paste0(".pred_lower_", object$lvl)
559559
hi_nms <- paste0(".pred_upper_", object$lvl)
560-
colnames(res) <- c(lo_nms[1], hi_nms[1], lo_nms[2], hi_nms[2])
560+
colnames(res_1) <- c(lo_nms[1], hi_nms[1])
561+
colnames(res_2) <- c(lo_nms[2], hi_nms[2])
562+
res <- bind_cols(res_1, res_2)
561563

562564
if (object$spec$method$pred$pred_int$extras$std_error)
563565
res$.std_error <- apply(results, 2, sd, na.rm = TRUE)

R/misc.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,4 +305,28 @@ update_engine_parameters <- function(eng_args, ...) {
305305
ret
306306
}
307307

308+
# ------------------------------------------------------------------------------
309+
# Since stan changed the function interface
310+
#' Wrapper for stan confidence intervals
311+
#' @param object A stan model fit
312+
#' @param newdata A data set.
313+
#' @export
314+
#' @keywords internal
315+
stan_conf_int <- function(object, newdata) {
316+
check_installs(list(method = list(libs = "rstanarm")))
317+
if (utils::packageVersion("rstanarm") >= "2.21.1") {
318+
fn <- rlang::call2("posterior_epred", .ns = "rstanarm",
319+
object = expr(object),
320+
newdata = expr(newdata),
321+
seed = expr(sample.int(10^5, 1)))
322+
} else {
323+
fn <- rlang::call2("posterior_linpred", .ns = "rstanarm",
324+
object = expr(object),
325+
newdata = expr(newdata),
326+
transform = TRUE,
327+
seed = expr(sample.int(10^5, 1)))
328+
}
329+
rlang::eval_tidy(fn)
330+
}
331+
308332

man/set_new_model.Rd

Lines changed: 6 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/stan_conf_int.Rd

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
-4.49 KB
Binary file not shown.

tests/testthat/test_boost_tree_C50.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,14 @@ test_that('argument checks for data dimensions', {
153153
set_engine("C5.0") %>%
154154
set_mode("classification")
155155

156-
f_fit <- spec %>% fit(species ~ ., data = penguins)
157-
xy_fit <- spec %>% fit_xy(x = penguins[, -1], y = penguins$species)
156+
expect_warning(
157+
f_fit <- spec %>% fit(species ~ ., data = penguins),
158+
"1000 samples were requested"
159+
)
160+
expect_warning(
161+
xy_fit <- spec %>% fit_xy(x = penguins[, -1], y = penguins$species),
162+
"1000 samples were requested"
163+
)
158164

159165
expect_equal(f_fit$fit$control$minCases, nrow(penguins))
160166
expect_equal(xy_fit$fit$control$minCases, nrow(penguins))

0 commit comments

Comments
 (0)