Skip to content

Commit b440a39

Browse files
committed
changed failed model predictions to n = 1
1 parent 0d8a87d commit b440a39

File tree

5 files changed

+17
-17
lines changed

5 files changed

+17
-17
lines changed

R/predict_class.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ predict_class.model_fit <- function(object, new_data, ...) {
1717
stop("No class prediction module defined for this model.", call. = FALSE)
1818

1919
if (inherits(object$fit, "try-error")) {
20-
return(failed_class(n = nrow(new_data), lvl = object$lvl))
20+
return(failed_class(lvl = object$lvl))
2121
}
2222

2323
new_data <- prepare_data(object, new_data)
@@ -58,7 +58,7 @@ predict_class <- function(object, ...)
5858

5959
# Some `predict()` helpers for failed models:
6060

61-
failed_class <- function(n, lvl) {
61+
failed_class <- function(n = 1, lvl) {
6262
res <- rep(NA_character_, n)
6363
res <- factor(res, levels = lvl)
6464
res

R/predict_classprob.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
1414
stop("No class probability module defined for this model.", call. = FALSE)
1515

1616
if (inherits(object$fit, "try-error")) {
17-
return(failed_classprob(n = nrow(new_data), lvl = object$lvl))
17+
return(failed_classprob(lvl = object$lvl))
1818
}
1919

2020
new_data <- prepare_data(object, new_data)
@@ -55,7 +55,7 @@ predict_classprob <- function(object, ...)
5555

5656
# Some `predict()` helpers for failed models:
5757

58-
failed_classprob <- function(n, lvl) {
58+
failed_classprob <- function(n = 1, lvl) {
5959
res <- matrix(NA_real_, nrow = n, ncol = length(lvl))
6060
colnames(res) <- lvl
6161
as_tibble(res)

R/predict_interval.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error
1515
"engine.", call. = FALSE)
1616

1717
if (inherits(object$fit, "try-error")) {
18-
return(failed_int(n = nrow(new_data), lvl = object$lvl))
18+
return(failed_int(lvl = object$lvl))
1919
}
2020

2121
new_data <- prepare_data(object, new_data)
@@ -52,7 +52,7 @@ predict_confint <- function(object, ...)
5252

5353
# Some `predict()` helpers for failed models:
5454

55-
failed_int <- function(n, lvl = NULL, nms = ".pred") {
55+
failed_int <- function(n = 1, lvl = NULL, nms = ".pred") {
5656
# TODO figure out multivariate models
5757
if (is.null(lvl)) {
5858
res <- matrix(NA_real_, nrow = n, ncol = length(nms) * 2)
@@ -81,7 +81,7 @@ predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error
8181
"engine.", call. = FALSE)
8282

8383
if (inherits(object$fit, "try-error")) {
84-
return(failed_int(n = nrow(new_data), lvl = object$lvl))
84+
return(failed_int(lvl = object$lvl))
8585
}
8686

8787
new_data <- prepare_data(object, new_data)

R/predict_numeric.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
1616

1717
if (inherits(object$fit, "try-error")) {
1818
# TODO handle multivariate cases
19-
return(failed_numeric(n = nrow(new_data)))
19+
return(failed_numeric())
2020
}
2121

2222
new_data <- prepare_data(object, new_data)
@@ -56,7 +56,7 @@ predict_numeric <- function(object, ...)
5656

5757
# Some `predict()` helpers for failed models:
5858

59-
failed_numeric <- function(n, nms = ".pred") {
59+
failed_numeric <- function(n = 1, nms = ".pred") {
6060
res <- matrix(NA_real_, ncol = length(nms), nrow = n)
6161
if (length(nms) > 1) {
6262
colnames(res) <- nms

tests/testthat/test_failed_models.R

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ test_that('numeric model', {
3535
fit(Sepal.Length ~ ., data = iris_bad, control = ctrl)
3636

3737
num_res <- predict(lm_mod, iris_bad[1:11, -1])
38-
expect_equal(num_res, tibble(.pred = rep(NA_real_, 11)))
38+
expect_equal(num_res, tibble(.pred = rep(NA_real_, 1)))
3939

40-
exp_int_res <- tibble(.pred_lower = rep(NA_real_, 11), .pred_upper = rep(NA_real_, 11))
40+
exp_int_res <- tibble(.pred_lower = rep(NA_real_, 1), .pred_upper = rep(NA_real_, 1))
4141
ci_res <- predict(lm_mod, iris_bad[1:11, -1], type = "conf_int")
4242
expect_equal(ci_res, exp_int_res)
4343

@@ -55,22 +55,22 @@ test_that('classification model', {
5555
fit(Class ~ log(funded_amnt) + int_rate + big_num, data = lending_club, control = ctrl)
5656

5757
cls_res <- predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class))
58-
exp_cls_res <- tibble(.pred_class = factor(rep(NA_character_, 7), levels = lvl))
58+
exp_cls_res <- tibble(.pred_class = factor(rep(NA_character_, 1), levels = lvl))
5959
expect_equal(cls_res, exp_cls_res)
6060

6161
prb_res <-
6262
predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "prob")
63-
exp_prb_res <- tibble(.pred_bad = rep(NA_real_, 7), .pred_good = rep(NA_real_, 7))
63+
exp_prb_res <- tibble(.pred_bad = rep(NA_real_, 1), .pred_good = rep(NA_real_, 1))
6464
expect_equal(prb_res, exp_prb_res)
6565

6666
ci_res <-
6767
predict(log_reg, lending_club %>% dplyr::slice(1:7) %>% dplyr::select(-Class), type = "conf_int")
6868
exp_ci_res <-
6969
tibble(
70-
.pred_lower_bad = rep(NA_real_, 7),
71-
.pred_upper_bad = rep(NA_real_, 7),
72-
.pred_lower_good = rep(NA_real_, 7),
73-
.pred_upper_good = rep(NA_real_, 7)
70+
.pred_lower_bad = rep(NA_real_, 1),
71+
.pred_upper_bad = rep(NA_real_, 1),
72+
.pred_lower_good = rep(NA_real_, 1),
73+
.pred_upper_good = rep(NA_real_, 1)
7474
)
7575
expect_equal(ci_res, exp_ci_res)
7676
})

0 commit comments

Comments
 (0)