Skip to content

Commit 4268ab8

Browse files
committed
Improve detection of descriptor functions using the globals package
1 parent d9b90ff commit 4268ab8

File tree

4 files changed

+200
-159
lines changed

4 files changed

+200
-159
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ Imports:
2525
glue,
2626
magrittr,
2727
stats,
28-
tidyr
28+
tidyr,
29+
globals
2930
Roxygen: list(markdown = TRUE)
3031
RoxygenNote: 6.1.0.9000
3132
Suggests:

R/descriptors.R

Lines changed: 130 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,107 @@
11
#' @name descriptors
2-
#' @aliases descriptors n_obs n_cols n_preds n_facts n_levs
2+
#' @aliases descriptors .n_obs .n_cols .n_preds .n_facts .n_levs .x .y .dat
33
#' @title Data Set Characteristics Available when Fitting Models
4-
#' @description When using the `fit` functions there are some
4+
#' @description When using the `fit()` functions there are some
55
#' variables that will be available for use in arguments. For
66
#' example, if the user would like to choose an argument value
7-
#' based on the current number of rows in a data set, the `n_obs`
8-
#' variable can be used. See Details below.
7+
#' based on the current number of rows in a data set, the `.n_obs()`
8+
#' function can be used. See Details below.
99
#' @details
10-
#' Existing variables:
10+
#' Existing functions:
1111
#' \itemize{
12-
#' \item `n_obs`: the current number of rows in the data set.
13-
#' \item `n_cols`: the number of columns in the data set that are
12+
#' \item `.n_obs()`: The current number of rows in the data set.
13+
#' \item `.n_cols()`: The number of columns in the data set that are
1414
#' associated with the predictors prior to dummy variable creation.
15-
#' \item `n_preds`: the number of predictors after dummy variables
15+
#' \item `.n_preds()`: The number of predictors after dummy variables
1616
#' are created (if any).
17-
#' \item `n_facts`: the number of factor predictors in the dat set.
18-
#' \item `n_levs`: If the outcome is a factor, this is a table
19-
#' with the counts for each level (and `NA` otherwise)
17+
#' \item `.n_facts()`: The number of factor predictors in the data set.
18+
#' \item `.n_levs()`: If the outcome is a factor, this is a table
19+
#' with the counts for each level (and `NA` otherwise).
20+
#' \item `.x()`: The predictors returned in the format given. Either a
21+
#' data frame or a matrix.
22+
#' \item `.y()`: The known outcomes returned in the format given. Either
23+
#' a vector, matrix, or data frame.
24+
#' \item `.dat()`: A data frame containing all of the predictors and the
25+
#' outcomes. If `fit_xy()` was used, the outcomes are attached as the
26+
#' column, `..y`.
2027
#' }
2128
#'
2229
#' For example, if you use the model formula `Sepal.Width ~ .` with the `iris`
2330
#' data, the values would be
2431
#' \preformatted{
25-
#' n_cols = 4 (the 4 columns in `iris`)
26-
#' n_preds = 5 (3 numeric columns + 2 from Species dummy variables)
27-
#' n_obs = 150
28-
#' n_levs = NA (no factor outcome)
29-
#' n_facts = 1 (the Species predictor)
32+
#' .n_cols() = 4 (the 4 columns in `iris`)
33+
#' .n_preds() = 5 (3 numeric columns + 2 from Species dummy variables)
34+
#' .n_obs() = 150
35+
#' .n_levs() = NA (no factor outcome)
36+
#' .n_facts() = 1 (the Species predictor)
37+
#' .y() = <vector> (Sepal.Width as a vector)
38+
#' .x() = <data.frame> (The other 4 columns as a data frame)
39+
#' .dat() = <data.frame> (The full data set)
3040
#' }
3141
#'
3242
#' If the formula `Species ~ .` were used:
3343
#' \preformatted{
34-
#' n_cols = 4 (the 4 numeric columns in `iris`)
35-
#' n_preds = 4 (same)
36-
#' n_obs = 150
37-
#' n_levs = c(setosa = 50, versicolor = 50, virginica = 50)
38-
#' n_facts = 0
44+
#' .n_cols() = 4 (the 4 numeric columns in `iris`)
45+
#' .n_preds() = 4 (same)
46+
#' .n_obs() = 150
47+
#' .n_levs() = c(setosa = 50, versicolor = 50, virginica = 50)
48+
#' .n_facts() = 0
49+
#' .y() = <vector> (Species as a vector)
50+
#' .x() = <data.frame> (The other 4 columns as a data frame)
51+
#' .dat() = <data.frame> (The full data set)
3952
#' }
4053
#'
41-
#' To use these in a model fit, either `expression` or `rlang::expr` can be
42-
#' used to delay the evaluation of the argument value until the time when the
43-
#' model is run via `fit` (and the variables listed above are available).
54+
#' To use these in a model fit, pass them to a model specification.
55+
#' The evaluation is delayed until the time when the
56+
#' model is run via `fit()` (and the functions listed above are available).
4457
#' For example:
4558
#'
4659
#' \preformatted{
47-
#' library(rlang)
4860
#'
4961
#' data("lending_club")
5062
#'
51-
#' rand_forest(mode = "classification", mtry = expr(n_cols - 2))
63+
#' rand_forest(mode = "classification", mtry = .n_cols() - 2)
5264
#' }
5365
#'
54-
#' When no instance of `expr` is found in any of the argument
55-
#' values, the descriptor calculation code will not be executed.
66+
#' When no descriptors are found, the computation of the descriptor values
67+
#' is not executed.
5668
#'
5769
NULL
5870

71+
#' @export
72+
#' @rdname descriptors
73+
.n_cols <- function() descr_env$.n_cols()
74+
75+
#' @export
76+
#' @rdname descriptors
77+
.n_preds <- function() descr_env$.n_preds()
78+
79+
#' @export
80+
#' @rdname descriptors
81+
.n_obs <- function() descr_env$.n_obs()
82+
83+
#' @export
84+
#' @rdname descriptors
85+
.n_levs <- function() descr_env$.n_levs()
86+
87+
#' @export
88+
#' @rdname descriptors
89+
.n_facts <- function() descr_env$.n_facts()
90+
91+
#' @export
92+
#' @rdname descriptors
93+
.x <- function() descr_env$.x()
94+
95+
#' @export
96+
#' @rdname descriptors
97+
.y <- function() descr_env$.y()
98+
99+
#' @export
100+
#' @rdname descriptors
101+
.dat <- function() descr_env$.dat()
102+
103+
# Descriptor retrievers --------------------------------------------------------
104+
59105
get_descr_form <- function(formula, data) {
60106
if (inherits(data, "tbl_spark")) {
61107
res <- get_descr_spark(formula, data)
@@ -209,11 +255,11 @@ get_descr_spark <- function(formula, data) {
209255

210256
get_descr_xy <- function(x, y) {
211257

212-
if(is.factor(y)) {
213-
.n_levs <- function() {
214-
table(y, dnn = NULL)
215-
}
216-
} else n_levs <- function() { NA }
258+
.n_levs <- if (is.factor(y)) {
259+
function() table(y, dnn = NULL)
260+
} else {
261+
function() NA
262+
}
217263

218264
.n_cols <- function() {
219265
ncol(x)
@@ -235,9 +281,7 @@ get_descr_xy <- function(x, y) {
235281
}
236282

237283
.dat <- function() {
238-
x <- as.data.frame(x)
239-
x[[".y"]] <- y
240-
x
284+
convert_xy_to_form_fit(x, y)
241285
}
242286

243287
.x <- function() {
@@ -278,51 +322,52 @@ make_descr <- function(object) {
278322
any(expr_main) | any(expr_others)
279323
}
280324

281-
# # given a quosure arg, does the expression contain a descriptor function?
282-
# find_descr <- function(x) {
283-
#
284-
# if(is_quosure(x)) {
285-
# x <- rlang::quo_get_expr(x)
286-
# }
287-
#
288-
# if(is_descr(x)) {
289-
# TRUE
290-
# }
291-
#
292-
# # handles NULL, literals
293-
# else if (is.atomic(x) | is.name(x)) {
294-
# FALSE
295-
# }
296-
#
297-
# else if (is.call(x)) {
298-
# any(rlang::squash_lgl(lapply(x, find_descr)))
299-
# }
300-
#
301-
# else {
302-
# # User supplied incorrect input
303-
# stop("Don't know how to handle type ", typeof(x),
304-
# call. = FALSE)
305-
# }
306-
#
307-
# }
308-
#
309-
# is_descr <- function(expr) {
310-
#
311-
# descriptors <- list(
312-
# expr(.n_cols),
313-
# expr(.n_preds),
314-
# expr(.n_obs),
315-
# expr(.n_levs),
316-
# expr(.n_facts),
317-
# expr(.x),
318-
# expr(.y),
319-
# expr(.dat)
320-
# )
321-
#
322-
# any(map_lgl(descriptors, identical, y = expr))
323-
# }
324-
325-
# descrs = list of functions that actually eval .n_cols()
325+
# Locate descriptors -----------------------------------------------------------
326+
327+
# take a list of arguments, see if any require descriptors
328+
requires_descrs <- function(lst) {
329+
any(map_lgl(lst, has_any_descrs))
330+
}
331+
332+
# given a quosure arg, does the expression contain a descriptor function?
333+
has_any_descrs <- function(x) {
334+
335+
.x_expr <- rlang::get_expr(x)
336+
.x_env <- rlang::get_env(x, parent.frame())
337+
338+
# x may wrap an already-evaluated value (e.g. a constant), leaving its env empty;
339+
# required so we don't pass an empty env to findGlobals(), which is an error
340+
if (identical(.x_env, rlang::empty_env())) {
341+
return(FALSE)
342+
}
343+
344+
# globals::globalsOf() is recursive and finds globals if the user passes
345+
# in a function that wraps a descriptor fn
346+
.globals <- globals::globalsOf(expr = .x_expr, envir = .x_env)
347+
.globals <- names(.globals)
348+
349+
any(map_lgl(.globals, is_descr))
350+
}
351+
352+
is_descr <- function(x) {
353+
354+
descrs <- list(
355+
".n_cols",
356+
".n_preds",
357+
".n_obs",
358+
".n_levs",
359+
".n_facts",
360+
".x",
361+
".y",
362+
".dat"
363+
)
364+
365+
any(map_lgl(descrs, identical, y = x))
366+
}
367+
368+
# Helpers for overwriting descriptors temporarily ------------------------------
369+
370+
# descrs = list of functions that actually eval to .n_cols()
326371
poke_descrs <- function(descrs) {
327372

328373
descr_names <- names(descr_env)
@@ -348,51 +393,14 @@ scoped_descrs <- function(descrs, frame = caller_env()) {
348393

349394
# Inline everything so the call will succeed in any environment
350395
expr <- call2(on.exit, call2(poke_descrs, old), add = TRUE)
351-
eval_bare(expr, frame)
396+
rlang::eval_bare(expr, frame)
352397

353398
invisible(old)
354399
}
355400

356-
#' @export
357-
.n_cols <- function() {
358-
descr_env$.n_cols()
359-
}
360-
361-
#' @export
362-
.n_preds <- function() {
363-
descr_env$.n_preds()
364-
}
365-
366-
#' @export
367-
.n_obs <- function() {
368-
descr_env$.n_obs()
369-
}
370-
371-
#' @export
372-
.n_levs <- function() {
373-
descr_env$.n_levs()
374-
}
375-
376-
#' @export
377-
.n_facts <- function() {
378-
descr_env$.n_facts()
379-
}
380-
381-
#' @export
382-
.x <- function() {
383-
descr_env$.x()
384-
}
385-
386-
#' @export
387-
.y <- function() {
388-
descr_env$.y()
389-
}
390-
391-
#' @export
392-
.dat <- function() {
393-
descr_env$.dat()
394-
}
395-
401+
# Environment that descriptors are found in.
402+
# Originally set to error. At fit time, these are temporarily overridden
403+
# with their actual implementations
396404
descr_env <- rlang::new_environment(
397405
data = list(
398406
.n_cols = function() abort("Descriptor context not set"),

R/fit_helpers.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ form_form <-
1515

1616
object <- check_mode(object, y_levels)
1717

18-
# need to improve this to find any descriptors
19-
if(make_descr(object)) {
18+
# if descriptors are needed, update descr_env with the calculated values
19+
if(requires_descrs(object$args)) {
2020
data_stats <- get_descr_form(env$formula, env$data)
2121
scoped_descrs(data_stats)
2222
}
@@ -67,8 +67,8 @@ xy_xy <- function(object, env, control, target = "none", ...) {
6767

6868
object <- check_mode(object, levels(env$y))
6969

70-
# need to improve this to find any descriptors
71-
if(make_descr(object)) {
70+
# if descriptors are needed, update descr_env with the calculated values
71+
if(requires_descrs(object$args)) {
7272
data_stats <- get_descr_form(env$formula, env$data)
7373
scoped_descrs(data_stats)
7474
}

0 commit comments

Comments
 (0)