Skip to content

Commit cb08638

Browse files
committed
allow 'objective' for xgboost to be passed as an engine argument
1 parent 7b81378 commit cb08638

File tree

3 files changed

+29
-8
lines changed

3 files changed

+29
-8
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* For `xgboost` models, users can now pass `objective` to `set_engine("xgboost")`.
4+
35
# parsnip 0.1.4
46

57
* `show_engines()` will provide information on the current set for a model.

R/boost_tree.R

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,8 @@ xgb_train <- function(
312312
min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
313313
early_stop = NULL, ...) {
314314

315+
others <- list(...)
316+
315317
num_class <- length(levels(y))
316318

317319
if (!is.numeric(validation) || validation < 0 || validation >= 1) {
@@ -327,13 +329,15 @@ xgb_train <- function(
327329
}
328330

329331

330-
if (is.numeric(y)) {
331-
loss <- "reg:squarederror"
332-
} else {
333-
if (num_class == 2) {
334-
loss <- "binary:logistic"
332+
if (!any(names(others) == "objective")) {
333+
if (is.numeric(y)) {
334+
others$objective <- "reg:squarederror"
335335
} else {
336-
loss <- "multi:softprob"
336+
if (num_class == 2) {
337+
others$objective <- "binary:logistic"
338+
} else {
339+
others$objective <- "multi:softprob"
340+
}
337341
}
338342
}
339343

@@ -378,7 +382,6 @@ xgb_train <- function(
378382
watchlist = quote(x$watchlist),
379383
params = arg_list,
380384
nrounds = nrounds,
381-
objective = loss,
382385
early_stopping_rounds = early_stop
383386
)
384387
if (!is.null(num_class) && num_class > 2) {
@@ -388,7 +391,7 @@ xgb_train <- function(
388391
call <- make_call(fun = "xgb.train", ns = "xgboost", main_args)
389392

390393
# override or add some other args
391-
others <- list(...)
394+
392395
others <-
393396
others[!(names(others) %in% c("data", "weights", "nrounds", "num_class", names(arg_list)))]
394397
if (!(any(names(others) == "verbose"))) {

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,26 @@ test_that('xgboost regression prediction', {
159159

160160
form_pred <- predict(form_fit$fit, newdata = xgb.DMatrix(data = as.matrix(mtcars[1:8, -1])))
161161
expect_equal(form_pred, predict(form_fit, new_data = mtcars[1:8, -1])$.pred)
162+
163+
expect_equal(form_fit$fit$params$objective, "reg:squarederror")
164+
162165
})
163166

164167

165168

169+
test_that('xgboost alternate objective', {
170+
skip_if_not_installed("xgboost")
171+
172+
spec <-
173+
boost_tree() %>%
174+
set_engine("xgboost", objective = "reg:pseudohubererror") %>%
175+
set_mode("regression")
176+
177+
xgb_fit <- spec %>% fit(mpg ~ ., data = mtcars)
178+
expect_equal(xgb_fit$fit$params$objective, "reg:pseudohubererror")
179+
})
180+
181+
166182
test_that('submodel prediction', {
167183

168184
skip_if_not_installed("xgboost")

0 commit comments

Comments
 (0)