Skip to content

Commit ccd33ff

Browse files
authored
Merge pull request #69 from topepo/fix/protect-core-model-args
Fix/protect core model args
2 parents 7bc24ff + f9faf4e commit ccd33ff

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

R/arguments.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,11 @@ prune_arg_list <- function(x, whitelist = NULL, modified = character(0)) {
7070
x
7171
}
7272

73-
check_others <- function(args, obj) {
73+
check_others <- function(args, obj, core_args) {
7474
# Make sure that we are not trying to modify an argument that
75-
# is explicitly protected in the method metadata
76-
common_args <- intersect(obj$protect, names(args))
75+
# is explicitly protected in the method metadata or arg_key
76+
protected_args <- unique(c(obj$protect, core_args))
77+
common_args <- intersect(protected_args, names(args))
7778
if (length(common_args) > 0) {
7879
args <- args[!(names(args) %in% common_args)]
7980
common_args <- paste0(common_args, collapse = ", ")

R/translate.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ translate.default <- function(x, engine, ...) {
5959
# check secondary arguments to see if they are in the final
6060
# expression unless there are dots, warn if protected args are
6161
# being altered
62-
x$others <- check_others(x$others, x$method$fit)
62+
eng_arg_key <- arg_key[[x$engine]]
63+
x$others <- check_others(x$others, x$method$fit, eng_arg_key)
6364

6465
# keep only modified args
6566
modifed_args <- !vapply(actual_args, null_value, lgl(1))

tests/testthat/test_rand_forest_ranger.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ num_pred <- c("funded_amnt", "annual_inc", "num_il_tl")
1111
lc_basic <- rand_forest()
1212
lc_ranger <- rand_forest(others = list(seed = 144))
1313

14-
bad_ranger_cls <- rand_forest(others = list(min.node.size = -10))
14+
bad_ranger_cls <- rand_forest(others = list(replace = "bad"))
1515
bad_rf_cls <- rand_forest(others = list(sampsize = -10))
1616

1717
ctrl <- fit_control(verbosity = 1, catch = FALSE)
@@ -160,7 +160,7 @@ num_pred <- names(mtcars)[3:6]
160160

161161
car_basic <- rand_forest()
162162

163-
bad_ranger_reg <- rand_forest(others = list(min.node.size = -10))
163+
bad_ranger_reg <- rand_forest(others = list(replace = "bad"))
164164
bad_rf_reg <- rand_forest(others = list(sampsize = -10))
165165

166166
ctrl <- list(verbosity = 1, catch = FALSE)

0 commit comments

Comments
 (0)