Skip to content

Commit 3247d4f

Browse files
authored
Merge pull request #400 from tidymodels/allow-sparse-docs
Document `allow_sparse_x`
2 parents f2a731a + 97206ee commit 3247d4f

File tree

9 files changed

+52
-98
lines changed

9 files changed

+52
-98
lines changed

R/aaa_models.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ check_interface_val <- function(x) {
352352
#' a formula interface, typically some predictor preprocessing must
353353
#' be conducted. `glmnet` is a good example of this.
354354
#'
355-
#' There are three options that can be used for the encodings:
355+
#' There are four options that can be used for the encodings:
356356
#'
357357
#' `predictor_indicators` describes whether and how to create indicator/dummy
358358
#' variables from factor predictors. There are three options: `"none"` (do not
@@ -369,10 +369,15 @@ check_interface_val <- function(x) {
369369
#' intercept, `model.matrix()` computes a full set of indicators for the
370370
#' _first_ factor variable, but an incomplete set for the remainder.
371371
#'
372-
#' Finally, the option `remove_intercept` will remove the intercept column
372+
#' Next, the option `remove_intercept` will remove the intercept column
373373
#' _after_ `model.matrix()` is finished. This can be useful if the model
374374
#' function (e.g. `lm()`) automatically generates an intercept.
375375
#'
376+
#' Finally, `allow_sparse_x` specifies whether the model function can natively
377+
#' accommodate a sparse matrix representation for predictors during fitting
378+
#' and tuning.
379+
#'
380+
#'
376381
#' @references "How to build a parsnip model"
377382
#' \url{https://www.tidymodels.org/learn/develop/models/}
378383
#' @examples

man/set_new_model.Rd

Lines changed: 6 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
-4.49 KB
Binary file not shown.

tests/testthat/test_boost_tree_C50.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,14 @@ test_that('argument checks for data dimensions', {
153153
set_engine("C5.0") %>%
154154
set_mode("classification")
155155

156-
f_fit <- spec %>% fit(species ~ ., data = penguins)
157-
xy_fit <- spec %>% fit_xy(x = penguins[, -1], y = penguins$species)
156+
expect_warning(
157+
f_fit <- spec %>% fit(species ~ ., data = penguins),
158+
"1000 samples were requested"
159+
)
160+
expect_warning(
161+
xy_fit <- spec %>% fit_xy(x = penguins[, -1], y = penguins$species),
162+
"1000 samples were requested"
163+
)
158164

159165
expect_equal(f_fit$fit$control$minCases, nrow(penguins))
160166
expect_equal(xy_fit$fit$control$minCases, nrow(penguins))

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -281,20 +281,20 @@ test_that('early stopping', {
281281
regex = NA
282282
)
283283

284-
expect_warning(
284+
expect_warning(
285285
reg_fit <-
286286
boost_tree(trees = 20, stop_iter = 30, mode = "regression") %>%
287287
set_engine("xgboost", validation = .1) %>%
288288
fit(mpg ~ ., data = mtcars[-(1:4), ]),
289289
regex = "`early_stop` was reduced to 19"
290290
)
291-
expect_error(
292-
reg_fit <-
293-
boost_tree(trees = 20, stop_iter = 0, mode = "regression") %>%
294-
set_engine("xgboost", validation = .1) %>%
295-
fit(mpg ~ ., data = mtcars[-(1:4), ]),
296-
regex = "`early_stop` should be on"
297-
)
291+
expect_error(
292+
reg_fit <-
293+
boost_tree(trees = 20, stop_iter = 0, mode = "regression") %>%
294+
set_engine("xgboost", validation = .1) %>%
295+
fit(mpg ~ ., data = mtcars[-(1:4), ]),
296+
regex = "`early_stop` should be on"
297+
)
298298
})
299299

300300

@@ -379,9 +379,14 @@ test_that('argument checks for data dimensions', {
379379
penguins_dummy <- model.matrix(species ~ ., data = penguins)
380380
penguins_dummy <- as.data.frame(penguins_dummy[, -1])
381381

382-
f_fit <- spec %>% fit(species ~ ., data = penguins)
383-
xy_fit <- spec %>% fit_xy(x = penguins_dummy, y = penguins$species)
384-
382+
expect_warning(
383+
f_fit <- spec %>% fit(species ~ ., data = penguins),
384+
"1000 samples were requested"
385+
)
386+
expect_warning(
387+
xy_fit <- spec %>% fit_xy(x = penguins_dummy, y = penguins$species),
388+
"1000 samples were requested"
389+
)
385390
expect_equal(f_fit$fit$params$colsample_bytree, 1)
386391
expect_equal(f_fit$fit$params$min_child_weight, nrow(penguins))
387392
expect_equal(xy_fit$fit$params$colsample_bytree, 1)

tests/testthat/test_linear_reg_keras.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,11 @@ test_that('model fitting', {
5151
),
5252
regexp = NA
5353
)
54-
fit1$elapsed <- fit2$elapsed
55-
expect_equal(fit1, fit2)
54+
expect_equal(
55+
unlist(keras::get_weights(fit1$fit)),
56+
unlist(keras::get_weights(fit2$fit)),
57+
tolerance = .1
58+
)
5659

5760
expect_error(
5861
fit(

tests/testthat/test_logistic_reg_keras.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,11 @@ test_that('model fitting', {
6464
),
6565
regexp = NA
6666
)
67-
fit1$elapsed <- fit2$elapsed
68-
expect_equal(fit1, fit2)
67+
expect_equal(
68+
unlist(keras::get_weights(fit1$fit)),
69+
unlist(keras::get_weights(fit2$fit)),
70+
tolerance = .1
71+
)
6972

7073
expect_error(
7174
fit(

tests/testthat/test_multinom_reg_keras.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,11 @@ test_that('model fitting', {
6060
),
6161
regexp = NA
6262
)
63-
fit1$elapsed <- fit2$elapsed
64-
expect_equal(fit1, fit2)
63+
expect_equal(
64+
unlist(keras::get_weights(fit1$fit)),
65+
unlist(keras::get_weights(fit2$fit)),
66+
tolerance = .1
67+
)
6568

6669
expect_error(
6770
fit(

tests/testthat/test_varying.R

Lines changed: 0 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ library(dplyr)
55

66
context("varying parameters")
77

8-
load(test_path("recipes_examples.RData"))
9-
108
test_that('main parsnip arguments', {
119

1210
mod_1 <- rand_forest() %>%
@@ -94,49 +92,6 @@ test_that('other parsnip arguments', {
9492
expect_equal(other_4, exp_4)
9593
})
9694

97-
98-
test_that('recipe parameters', {
99-
100-
# un-randomify the id names
101-
rec_1_id <- rec_1
102-
rec_1_id$steps[[1]]$id <- "center_1"
103-
rec_1_id$steps[[2]]$id <- "knnimpute_1"
104-
rec_1_id$steps[[3]]$id <- "pca_1"
105-
106-
rec_res_1 <- varying_args(rec_1_id)
107-
108-
exp_1 <- tibble(
109-
name = c("K", "num", "threshold", "options"),
110-
varying = c(TRUE, TRUE, FALSE, FALSE),
111-
id = c("knnimpute_1", rep("pca_1", 3)),
112-
type = rep("step", 4)
113-
)
114-
115-
expect_equal(rec_res_1, exp_1)
116-
117-
# un-randomify the id names
118-
rec_3_id <- rec_3
119-
rec_3_id$steps[[1]]$id <- "center_1"
120-
rec_3_id$steps[[2]]$id <- "knnimpute_1"
121-
rec_3_id$steps[[3]]$id <- "pca_1"
122-
123-
rec_res_3 <- varying_args(rec_3_id)
124-
exp_3 <- exp_1
125-
exp_3$varying <- FALSE
126-
expect_equal(rec_res_3, exp_3)
127-
128-
rec_res_4 <- varying_args(rec_4)
129-
130-
exp_4 <- tibble(
131-
name = character(),
132-
varying = logical(),
133-
id = character(),
134-
type = character()
135-
)
136-
137-
expect_equal(rec_res_4, exp_4)
138-
})
139-
14095
test_that("empty lists return FALSE - #131", {
14196
expect_equal(
14297
parsnip:::find_varying(list()),
@@ -164,33 +119,3 @@ test_that("varying() deeply nested in calls can be located - #134", {
164119
TRUE
165120
)
166121
})
167-
168-
test_that("recipe steps with non-varying args error if specified as varying()", {
169-
170-
rec_bad_varying <- rec_1
171-
rec_bad_varying$steps[[1]]$skip <- varying()
172-
173-
expect_error(
174-
varying_args(rec_bad_varying),
175-
"The following argument for a recipe step of type 'step_center' is not allowed to vary: 'skip'."
176-
)
177-
})
178-
179-
test_that("`full = FALSE` returns only varying arguments", {
180-
181-
x_spec <- rand_forest(min_n = varying()) %>%
182-
set_engine("ranger", sample.fraction = varying())
183-
184-
x_rec <- rec_1
185-
186-
expect_equal(
187-
varying_args(x_spec, full = FALSE)$name,
188-
c("min_n", "sample.fraction")
189-
)
190-
191-
expect_equal(
192-
varying_args(x_rec, full = FALSE)$name,
193-
c("K", "num")
194-
)
195-
196-
})

0 commit comments

Comments
 (0)