Skip to content

Commit 1c13d55

Browse files
committed
spark updates for rf
1 parent 87c3b3f commit 1c13d55

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

R/rand_forest.R

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,11 @@ translate.rand_forest <- function(x, engine, ...) {
159159
)
160160
else
161161
x$method$fit_args$type <- x$mode
162+
163+
# See "Details" in ?ml_random_forest_classifier
164+
if (is.numeric(x$method$fit_args$feature_subset_strategy))
165+
x$method$fit_args$feature_subset_strategy <-
166+
paste(x$method$fit_args$feature_subset_strategy)
162167
}
163168

164169
# add checks to error trap or change things for this method

R/rand_forest_constr.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ rand_forest_randomForest_fit <-
3939
libs = "randomForest",
4040
interface = "data.frame",
4141
protect = c("x", "y"),
42-
fit_name = c(pkg = "randomForest", fun = "randomForest.default"),
42+
fit_name = c(pkg = "randomForest", fun = "randomForest"),
4343
alternates =
4444
list()
4545
)
@@ -49,7 +49,7 @@ rand_forest_spark_fit <-
4949
list(
5050
libs = "sparklyr",
5151
interface = "spark",
52-
protect = c("x", "features_col", "label_col", "type"),
52+
protect = c("x", "formula", "type"),
5353
fit_name = c(pkg = "sparklyr", fun = "ml_random_forest"),
5454
alternates =
5555
list(

0 commit comments

Comments
 (0)