Skip to content

Commit dfd8d59

Browse files
committed
updates for quantile predictions
1 parent 3d19f60 commit dfd8d59

File tree

4 files changed

+68
-3
lines changed

4 files changed

+68
-3
lines changed

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ S3method(predict_confint,model_fit)
2222
S3method(predict_num,"_elnet")
2323
S3method(predict_num,model_fit)
2424
S3method(predict_predint,model_fit)
25+
S3method(predict_quantile,model_fit)
2526
S3method(predict_raw,"_elnet")
2627
S3method(predict_raw,"_lognet")
2728
S3method(predict_raw,"_multnet")
@@ -95,6 +96,8 @@ export(predict_num)
9596
export(predict_num.model_fit)
9697
export(predict_predint)
9798
export(predict_predint.model_fit)
99+
export(predict_quantile)
100+
export(predict_quantile.model_fit)
98101
export(predict_raw)
99102
export(predict_raw.model_fit)
100103
export(rand_forest)
@@ -113,10 +116,12 @@ import(rlang)
113116
importFrom(dplyr,arrange)
114117
importFrom(dplyr,as_tibble)
115118
importFrom(dplyr,bind_cols)
119+
importFrom(dplyr,bind_rows)
116120
importFrom(dplyr,collect)
117121
importFrom(dplyr,full_join)
118122
importFrom(dplyr,funs)
119123
importFrom(dplyr,group_by)
124+
importFrom(dplyr,mutate)
120125
importFrom(dplyr,pull)
121126
importFrom(dplyr,rename)
122127
importFrom(dplyr,rename_at)
@@ -159,6 +164,7 @@ importFrom(stats,predict)
159164
importFrom(stats,qnorm)
160165
importFrom(stats,qt)
161166
importFrom(stats,quantile)
167+
importFrom(stats,setNames)
162168
importFrom(stats,terms)
163169
importFrom(stats,update)
164170
importFrom(tibble,as_tibble)

R/misc.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,3 +178,15 @@ check_args <- function(object) {
178178
check_args.default <- function(object) {
179179
invisible(object)
180180
}
181+
182+
# ------------------------------------------------------------------------------
183+
184+
# copied form recipes
185+
186+
names0 <- function (num, prefix = "x") {
187+
if (num < 1)
188+
stop("`num` should be > 0", call. = FALSE)
189+
ind <- format(1:num)
190+
ind <- gsub(" ", "0", ind)
191+
paste0(prefix, ind)
192+
}

R/predict.R

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
#' @param object An object of class `model_fit`
88
#' @param new_data A rectangular data object, such as a data frame.
99
#' @param type A single character value or `NULL`. Possible values
10-
#' are "numeric", "class", "probs", "conf_int", "pred_int", or
11-
#' "raw". When `NULL`, `predict` will choose an appropriate value
10+
#' are "numeric", "class", "probs", "conf_int", "pred_int", "quantile",
11+
#' or "raw". When `NULL`, `predict` will choose an appropriate value
1212
#' based on the model's mode.
1313
#' @param opts A list of optional arguments to the underlying
1414
#' predict function that will be used when `type = "raw"`. The
@@ -45,6 +45,10 @@
4545
#' produces for class probabilities (or other non-scalar outputs),
4646
#' the columns will be named `.pred_lower_classlevel` and so on.
4747
#'
48+
#' Quantile predictions return a tibble with a column `.pred`, which is
49+
#' a list-column. Each list element contains a tibble with columns
50+
#' `.pred` and `.quantile` (and perhaps others).
51+
#'
4852
#' Using `type = "raw"` with `predict.model_fit` (or using
4953
#' `predict_raw`) will return the unadulterated results of the
5054
#' prediction function.
@@ -96,6 +100,7 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...
96100
prob = predict_classprob(object = object, new_data = new_data, ...),
97101
conf_int = predict_confint(object = object, new_data = new_data, ...),
98102
pred_int = predict_predint(object = object, new_data = new_data, ...),
103+
quantile = predict_quantile(object = object, new_data = new_data, ...),
99104
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
100105
stop("I don't know about type = '", "'", type, call. = FALSE)
101106
)
@@ -112,7 +117,8 @@ predict.model_fit <- function (object, new_data, type = NULL, opts = list(), ...
112117
res
113118
}
114119

115-
pred_types <- c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int")
120+
pred_types <-
121+
c("raw", "numeric", "class", "link", "prob", "conf_int", "pred_int", "quantile")
116122

117123
#' @importFrom glue glue_collapse
118124
check_pred_type <- function(object, type) {

R/predict_quantile.R

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
#' @keywords internal
2+
#' @rdname other_predict
3+
#' @param quant A vector of numbers between 0 and 1 for the quantile being
4+
#' predicted.
5+
#' @inheritParams predict.model_fit
6+
#' @method predict_quantile model_fit
7+
#' @export predict_quantile.model_fit
8+
#' @export
9+
predict_quantile.model_fit <-
10+
function (object, new_data, quantile = (1:9)/10, ...) {
11+
12+
if (is.null(object$spec$method$quantile))
13+
stop("No quantile prediction method defined for this ",
14+
"engine.", call. = FALSE)
15+
16+
new_data <- prepare_data(object, new_data)
17+
18+
# preprocess data
19+
if (!is.null(object$spec$method$quantile$pre))
20+
new_data <- object$spec$method$quantile$pre(new_data, object)
21+
22+
# Pass some extra arguments to be used in post-processor
23+
object$spec$method$quantile$args$p <- quantile
24+
pred_call <- make_pred_call(object$spec$method$quantile)
25+
26+
res <- eval_tidy(pred_call)
27+
28+
# post-process the predictions
29+
if(!is.null(object$spec$method$quantile$post)) {
30+
res <- object$spec$method$quantile$post(res, object)
31+
}
32+
33+
res
34+
}
35+
36+
#' @export
37+
#' @keywords internal
38+
#' @rdname other_predict
39+
#' @inheritParams predict.model_fit
40+
predict_quantile <- function (object, ...)
41+
UseMethod("predict_quantile")

0 commit comments

Comments
 (0)