Skip to content

Commit 368518b

Browse files
committed
fixes for as_tibble.matrix with no colnames
1 parent 6499cd4 commit 368518b

File tree

6 files changed

+18
-18
lines changed

6 files changed

+18
-18
lines changed

R/mlp.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,9 @@ nnet_softmax <- function(results, object) {
379379
results <- cbind(1 - results, results)
380380

381381
results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
382-
results <- as_tibble(t(results))
382+
results <- t(results)
383383
names(results) <- paste0(".pred_", object$lvl)
384+
results <- as_tibble(results)
384385
results
385386
}
386387

R/mlp_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,8 +139,8 @@ set_pred(
139139
value = list(
140140
pre = NULL,
141141
post = function(x, object) {
142-
x <- as_tibble(x)
143142
colnames(x) <- object$lvl
143+
x <- as_tibble(x)
144144
x
145145
},
146146
func = c(pkg = "keras", fun = "predict_proba"),

R/multinom_reg_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ set_pred(
216216
value = list(
217217
pre = NULL,
218218
post = function(x, object) {
219-
x <- as_tibble(x)
220219
colnames(x) <- object$lvl
220+
x <- as_tibble(x)
221221
x
222222
},
223223
func = c(pkg = "keras", fun = "predict_proba"),

R/nullmodel.R

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,14 +111,15 @@ predict.nullmodel <- function (object, new_data = NULL, type = NULL, ...) {
111111
out <- factor(rep(object$value, n), levels = object$levels)
112112
}
113113
} else {
114-
if(type %in% c("prob", "class")) stop("Only numeric predicitons are applicable to regression models")
115-
if(length(object$value) == 1) {
114+
if (type %in% c("prob", "class")) {
115+
stop("Only numeric predicitons are applicable to regression models")
116+
}
117+
if (length(object$value) == 1) {
116118
out <- rep(object$value, n)
117119
} else {
118-
out <- as_tibble(matrix(rep(object$value, n),
119-
ncol = length(object$value), byrow = TRUE))
120-
121-
names(out) <- names(object$value)
120+
out <- matrix(rep(object$value, n), ncol = length(object$value), byrow = TRUE)
121+
colnames(out) <- names(object$value)
122+
out <- as_tibble(out)
122123
}
123124
}
124125
out

R/nullmodel_data.R

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ set_pred(
9595
value = list(
9696
pre = NULL,
9797
post = function(x, object) {
98-
str(as_tibble(x))
9998
as_tibble(x)
10099
},
101100
func = c(fun = "predict"),

tests/testthat/test_logistic_reg_keras.R

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -160,10 +160,9 @@ test_that('classification probabilities', {
160160
y = tr_dat$Class
161161
)
162162

163-
keras_pred <-
164-
keras::predict_proba(lr_fit$fit, as.matrix(te_dat[, -1])) %>%
165-
as_tibble() %>%
166-
setNames(paste0(".pred_", lr_fit$lvl))
163+
keras_pred <- keras::predict_proba(lr_fit$fit, as.matrix(te_dat[, -1]))
164+
colnames(keras_pred) <- paste0(".pred_", lr_fit$lvl)
165+
keras_pred <- as_tibble(keras_pred)
167166

168167
parsnip_pred <- predict(lr_fit, te_dat[, -1], type = "prob")
169168
expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
@@ -177,10 +176,10 @@ test_that('classification probabilities', {
177176
y = tr_dat$Class
178177
)
179178

180-
keras_pred <-
181-
keras::predict_proba(plrfit$fit, as.matrix(te_dat[, -1])) %>%
182-
as_tibble() %>%
183-
setNames(paste0(".pred_", lr_fit$lvl))
179+
keras_pred <- keras::predict_proba(plrfit$fit, as.matrix(te_dat[, -1]))
180+
colnames(keras_pred) <- paste0(".pred_", lr_fit$lvl)
181+
keras_pred <- as_tibble(keras_pred)
182+
184183
parsnip_pred <- predict(plrfit, te_dat[, -1], type = "prob")
185184
expect_equal(as.data.frame(keras_pred), as.data.frame(parsnip_pred))
186185

0 commit comments

Comments
 (0)