Skip to content

Commit 772a542

Browse files
committed
Export custom training helpers
1 parent 61db24d commit 772a542

File tree

7 files changed

+54
-4
lines changed

7 files changed

+54
-4
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,15 @@ export(.n_obs)
5757
export(.n_preds)
5858
export(.x)
5959
export(.y)
60+
export(C5.0_train)
6061
export(boost_tree)
6162
export(check_empty_ellipse)
6263
export(fit)
6364
export(fit.model_spec)
6465
export(fit_control)
6566
export(fit_xy)
6667
export(fit_xy.model_spec)
68+
export(keras_mlp)
6769
export(linear_reg)
6870
export(logistic_reg)
6971
export(make_classes)
@@ -97,6 +99,7 @@ export(varying_args)
9799
export(varying_args.model_spec)
98100
export(varying_args.recipe)
99101
export(varying_args.step)
102+
export(xgb_train)
100103
import(rlang)
101104
importFrom(dplyr,arrange)
102105
importFrom(dplyr,as_tibble)

R/boost_tree.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,9 @@ check_args.boost_tree <- function(object) {
257257

258258
# xgboost helpers --------------------------------------------------------------
259259

260+
#' Training helper for xgboost
261+
#'
262+
#' @export
260263
xgb_train <- function(
261264
x, y,
262265
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bytree = 1,
@@ -399,6 +402,9 @@ xgb_by_tree <- function(tree, object, new_data, type, ...) {
399402

400403
# C5.0 helpers -----------------------------------------------------------------
401404

405+
#' Training helper for C5.0
406+
#'
407+
#' @export
402408
C5.0_train <-
403409
function(x, y, weights = NULL, trials = 15, minCases = 2, sample = 0, ...) {
404410
other_args <- list(...)

R/boost_tree_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ boost_tree_xgboost_data <-
2424
fit = list(
2525
interface = "matrix",
2626
protect = c("x", "y"),
27-
func = c(pkg = NULL, fun = "xgb_train"),
27+
func = c(pkg = "parsnip", fun = "xgb_train"),
2828
defaults =
2929
list(
3030
nthread = 1,
@@ -94,7 +94,7 @@ boost_tree_C5.0_data <-
9494
fit = list(
9595
interface = "data.frame",
9696
protect = c("x", "y", "weights"),
97-
func = c(pkg = NULL, fun = "C5.0_train"),
97+
func = c(pkg = "parsnip", fun = "C5.0_train"),
9898
defaults = list()
9999
),
100100
classes = list(

R/mlp_data.R

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ mlp_keras_data <-
2222
fit = list(
2323
interface = "matrix",
2424
protect = c("x", "y"),
25-
func = c(pkg = NULL, fun = "keras_mlp"),
25+
func = c(pkg = "parsnip", fun = "keras_mlp"),
2626
defaults = list()
2727
),
2828
pred = list(
@@ -131,6 +131,9 @@ class2ind <- function (x, drop2nd = FALSE) {
131131
y
132132
}
133133

134+
#' MLP in Keras
135+
#'
136+
#' @export
134137
keras_mlp <-
135138
function(x, y,
136139
hidden_units = 5, decay = 0, dropout = 0, epochs = 20, act = "softmax",
@@ -155,7 +158,7 @@ keras_mlp <-
155158
else
156159
y <- matrix(y, ncol = 1)
157160
}
158-
161+
159162
model <- keras::keras_model_sequential()
160163
if(decay > 0) {
161164
model %>%

man/C5.0_train.Rd

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/keras_mlp.Rd

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/xgb_train.Rd

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)