Skip to content
6 changes: 3 additions & 3 deletions R/arg_check_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -703,19 +703,19 @@ NULL
#' @keywords internal
#' @param y_default_eval [chr] y value of default evaluation plot. It can be
#' "avg_runtime_sec" or one of the following performance metrics:
#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc"
#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc"
#'
.checkArgYDefaultEval <- function(y_default_eval) {
if (!is.character(y_default_eval)) {
stop("The `y_default_eval` argument can only take character values.")
}

if (!(y_default_eval %in%
c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_nmcc"))
c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", "avg_nmcc"))
) {
stop(paste(
"`y_default_eval` must be one of:",
"'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_nmcc'."
"'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_mcc', 'avg_nmcc'."
))
}
}
Expand Down
41 changes: 27 additions & 14 deletions R/core_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -483,28 +483,40 @@ getConfusionMatrix <- function(test_data_plus_predictions) {
return(CM)
}

#' .calculatenMCC()
#' .calculateMCC()
#'
#' Returns the normalized (to a 0 to 1 scale instead of -1 to 1) Matthews
#' correlation coefficient (nMCC) based on the AMR phenotype predictions by an
#' Returns the Matthews correlation coefficient (MCC)
#' based on the AMR phenotype predictions by an
#' ML model compared against the actual values.
#'
#' @inheritParams getConfusionMatrix
#' @return Normalized (to a 0 to 1 scale instead of -1 to 1) Matthews
#' correlation coefficient (nMCC)
.calculatenMCC <- function(test_data_plus_predictions) {
#' @return Matthews correlation coefficient (MCC), range -1 to 1
.calculateMCC <- function(test_data_plus_predictions) {
.checkArgTestDataPlusPredictions(test_data_plus_predictions)

target_var <- .getTargetVarName(test_data_plus_predictions)

mcc <- test_data_plus_predictions |>
yardstick::mcc(truth = !!target_var, estimate = .pred_class) |>
dplyr::select(.estimate) |>
as.numeric()
as.numeric() |>
round(2)

nmcc <- (mcc + 1) / 2
return(mcc)
Comment thread
epbrenner marked this conversation as resolved.
}

return(round(nmcc, 2))
#' .calculatenMCC()
#'
#' Returns the normalized (0 to 1) Matthews correlation coefficient (nMCC)
#' based on the AMR phenotype predictions by an ML model compared against
#' the actual values.
#'
#' @inheritParams getConfusionMatrix
#' @return Normalized Matthews correlation coefficient (nMCC), range 0 to 1
.calculatenMCC <- function(test_data_plus_predictions) {
mcc <- .calculateMCC(test_data_plus_predictions)
nmcc <- round((mcc + 1) / 2, 2)
return(nmcc)
}

#' .calculateF1()
Expand Down Expand Up @@ -597,8 +609,8 @@ getConfusionMatrix <- function(test_data_plus_predictions) {
))
} else if (prior >= 0.7) {
warning(paste(
"Classes are imbalanced toward the resistant phenotype.",
"Calculation of log2(AUPRC/prior) may be inappropriate."
"Classes are imbalanced for this model.",
"The use of the log2(AUPRC/prior) metric may be more informative in this imbalanced model."
))
}

Expand Down Expand Up @@ -699,8 +711,8 @@ getConfusionMatrix <- function(test_data_plus_predictions) {
#' calculateEvalMets()
#'
#' Returns the F1 score, area under the precision-recall curve (AUPRC), balanced
#' accuracy, normalized (to a 0 to 1 scale instead of -1 to 1) Matthews
#' correlation coefficient (nMCC), and log2(AUPRC/prior) based on the AMR
#' accuracy, Matthews correlation coefficient (MCC), normalized MCC (nMCC),
#' and log2(AUPRC/prior) based on the AMR
#' phenotype predictions by an ML model compared against the actual values.
#'
#' @inheritParams getConfusionMatrix
Expand Down Expand Up @@ -732,10 +744,11 @@ calculateEvalMets <- function(test_data_plus_predictions) {
bal_acc <- .calculateBalAcc(test_data_plus_predictions)
sens <- .calculateSensitivity(test_data_plus_predictions)
spec <- .calculateSpecificity(test_data_plus_predictions)
mcc <- .calculateMCC(test_data_plus_predictions)
nmcc <- .calculatenMCC(test_data_plus_predictions)
log2_apop <- .calculateLog2APOP(test_data_plus_predictions)

return(c(f1, auprc, bal_acc, nmcc, log2_apop))
return(c(f1, auprc, bal_acc, mcc, nmcc, log2_apop))
}

#' extractTopFeats()
Expand Down
26 changes: 13 additions & 13 deletions R/generate_matrices_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -288,22 +288,22 @@ skipImbalancedMatrix <- function(genome_ids,
if (group_type %in% c("drug_class", "drug_class_year", "drug_class_country")) {
genome_ids <- DBI::dbGetQuery(con, sprintf("
WITH class_phenotypes AS (
SELECT \"genome_drug.genome_id\" AS genome_id,
SELECT \"genome.genome_id\" AS genome_id,
MAX(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Resistant'
THEN 1 ELSE 0 END) AS any_resistant,
MIN(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Susceptible'
THEN 1 ELSE 0 END) AS all_susceptible
FROM metadata
WHERE %s
GROUP BY \"genome_drug.genome_id\"
GROUP BY \"genome.genome_id\"
)
SELECT genome_id
FROM class_phenotypes
WHERE any_resistant = 1 OR all_susceptible = 1
", condition_string))[[1]]
} else {
genome_ids <- DBI::dbGetQuery(con, sprintf("
SELECT DISTINCT \"genome_drug.genome_id\"
SELECT DISTINCT \"genome.genome_id\"
FROM metadata
WHERE %s
AND \"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible')
Expand Down Expand Up @@ -499,20 +499,20 @@ skipImbalancedMatrix <- function(genome_ids,
"
COPY (
SELECT
f.\"genome_drug.genome_id\" AS genome_id,
f.\"genome.genome_id\" AS genome_id,
%s AS feature_id,
MAX(CAST(%s AS DOUBLE)) AS value,
%s AS \"genome_drug.resistant_phenotype\"
%s
FROM %s
JOIN selected_genomes USING (genome_id)
JOIN keep_features kf ON %s = kf.feature_id
JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\"
JOIN metadata f ON genome_id = f.\"genome.genome_id\"
WHERE %s
AND f.\"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible')
%s
GROUP BY f.\"genome_drug.genome_id\", %s %s
ORDER BY f.\"genome_drug.genome_id\", %s
GROUP BY f.\"genome.genome_id\", %s %s
ORDER BY f.\"genome.genome_id\", %s
)
TO '%s'
(FORMAT 'parquet', COMPRESSION 'zstd')
Expand Down Expand Up @@ -736,7 +736,7 @@ skipImbalancedMatrix <- function(genome_ids,
DBI::dbDisconnect(con0, shutdown = FALSE)

classes <- metadata_all |>
dplyr::select(genome_drug.genome_id, resistant_classes) |>
dplyr::select(genome.genome_id, resistant_classes) |>
dplyr::distinct() |>
dplyr::group_by(resistant_classes) |>
dplyr::count() |>
Expand All @@ -748,7 +748,7 @@ skipImbalancedMatrix <- function(genome_ids,

genomes_to_keep <- metadata_all |>
dplyr::filter(resistant_classes %in% classes) |>
dplyr::pull(genome_drug.genome_id)
dplyr::pull(genome.genome_id)

# Build one matrix per feature type and matrix type
for (ftype in names(feature_types)) {
Expand Down Expand Up @@ -828,21 +828,21 @@ skipImbalancedMatrix <- function(genome_ids,
"
COPY (
SELECT
f.\"genome_drug.genome_id\" AS genome_id,
f.\"genome.genome_id\" AS genome_id,
%s AS feature_id,
MAX(CAST(%s AS DOUBLE)) AS value,
resistant_classes
FROM %s
JOIN selected_genomes USING (genome_id)
JOIN keep_features kf ON %s = kf.feature_id
JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\"
JOIN metadata f ON genome_id = f.\"genome.genome_id\"
WHERE resistant_classes <> 'Intermediate'
GROUP BY
f.\"genome_drug.genome_id\",
f.\"genome.genome_id\",
%s,
resistant_classes
ORDER BY
f.\"genome_drug.genome_id\",
f.\"genome.genome_id\",
%s
)
TO '%s'
Expand Down
4 changes: 4 additions & 0 deletions R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,12 @@ utils::globalVariables(c(
"idx_strat",
"model",
"neg_log10_adj_p",
"mcc",
"nmcc",
"num_obs",
"seed",
"sens",
"spec",
"output_prefix",
"p_value",
"pair_id",
Expand Down
8 changes: 6 additions & 2 deletions R/ife_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ removeTopFeats <- function(ml_input_tibble, top_feat_tibble) {

#' runIFE
#' Removes top features identified by ML models and retrains iteratively;
#' returns nMCC at each iteration.
#' returns MCC at each iteration.
#'
#' @param ml_input_tibble An ML-ready tibble generated by `loadMLInputTibble()`
#' @param by_num [bool] Set to `TRUE` if removing top features as a percentage
Expand Down Expand Up @@ -102,6 +102,7 @@ runIFE <- function(
num_obs_vec <- c()
res_prop_vec <- c()
fit_mixture_vec <- c()
mcc_vec <- c()
nmcc_vec <- c()
n_feats_removed_vec <- c()
total_feats_removed_vec <- c()
Expand Down Expand Up @@ -173,6 +174,9 @@ runIFE <- function(
fit_mixture_vec[i] <- ml_res$performance_tibble |>
dplyr::select(fit_mixture) |>
as.numeric()
mcc_vec[i] <- ml_res$performance_tibble |>
dplyr::select(mcc) |>
as.numeric()
nmcc_vec[i] <- ml_res$performance_tibble |>
dplyr::select(nmcc) |>
as.numeric()
Expand Down Expand Up @@ -266,7 +270,7 @@ runIFE <- function(
percent_removed = c(0, percent_removal_vec),
removal_type = rep(removal_type, length(num_obs_vec)),
num_obs = num_obs_vec, res_prop = res_prop_vec,
fit_mixture = fit_mixture_vec, nmcc = nmcc_vec,
fit_mixture = fit_mixture_vec, mcc = mcc_vec, nmcc = nmcc_vec,
n_feats_removed = n_feats_removed_vec,
total_feats_removed = total_feats_removed_vec, run_time_sec = run_time_vec
)
Expand Down
2 changes: 1 addition & 1 deletion R/plot_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ plotTopFeatsVI <- function(fit, n_top_feats = 10) {
#' or "n_fold"
#' @param y_default_eval [chr] y value of default evaluation plot. It can be
#' "avg_runtime_sec" or one of the following performance metrics:
#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc"
#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc"
#' @param xlab [chr] Label for x axis
#' @param ylab [chr] Label for y axis
#' @return A `ggplot2` scatterplot (performance metric or runtime vs.
Expand Down
1 change: 1 addition & 0 deletions R/prep_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ loadMLInputTibble <- function(parquet_path) {
target_var <- .getTargetVarName(long_tibble)

ml_input_tibble <- long_tibble |>
dplyr::distinct() |>
dplyr::mutate(!!target_var := as.factor(!!target_var)) |>
tidyr::pivot_wider(
id_cols = dplyr::all_of(
Expand Down
19 changes: 11 additions & 8 deletions R/run_ml_pipeline.R
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,11 @@ runMLPipeline <- function(
log2_apop <- .calculateLog2APOP(test_data_plus_predictions)
}

mcc <- .calculateMCC(test_data_plus_predictions)
nmcc <- .calculatenMCC(test_data_plus_predictions)

if (verbose) {
message(paste("Normalized Matthews correlation coefficient:", nmcc))
message(paste("Matthews correlation coefficient:", mcc, "| nMCC:", nmcc))
}

top_feat_tibble <- extractTopFeats(fit,
Expand Down Expand Up @@ -375,7 +376,7 @@ runMLPipeline <- function(
performance_tibble <- tibble::tibble(
num_obs = num_obs_ml_input_tibble,
n_feat = getNumFeat(ml_input_tibble), model, train_prop = split[1],
val_prop = split[2], n_fold, nmcc, run_time_sec,
val_prop = split[2], n_fold, mcc, nmcc, run_time_sec, seed,
date = as.character(Sys.Date())
)

Expand All @@ -396,18 +397,20 @@ runMLPipeline <- function(
) |>
tibble::add_column(bal_acc, .after = "nmcc") |>
tibble::add_column(f1, .after = "nmcc") |>
tibble::add_column(log2_apop, .after = "nmcc")
tibble::add_column(log2_apop, .after = "nmcc") |>
tibble::add_column(sens, .after = "nmcc") |>
tibble::add_column(spec, .after = "nmcc")
}

if (model == "LR") {
performance_tibble <- performance_tibble |>
tibble::add_column(fit_penalty, .before = "nmcc") |>
tibble::add_column(fit_mixture, .before = "nmcc")
tibble::add_column(fit_penalty, .before = "mcc") |>
tibble::add_column(fit_mixture, .before = "mcc")
} else if (model == "RF" || model == "BT") {
performance_tibble <- performance_tibble |>
tibble::add_column(fit_trees, .before = "nmcc") |>
tibble::add_column(fit_mtry, .before = "nmcc") |>
tibble::add_column(fit_min_n, .before = "nmcc")
tibble::add_column(fit_trees, .before = "mcc") |>
tibble::add_column(fit_mtry, .before = "mcc") |>
tibble::add_column(fit_min_n, .before = "mcc")
}

if (external_test_data) {
Expand Down
2 changes: 1 addition & 1 deletion README.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ This uses specific matrices to test whether ML models can predict resistance aga

- **Data preparation**: Load Parquet files and prepare ML-ready datasets
- **Model training**: User-customizable logistic regression via tidymodels
- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion matrices
- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion matrices
- **Feature importance**: Extract and rank predictive features

See the [package vignette](https://jravilab.github.io/amRml/articles/intro.html) for detailed usage.
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ associated with MDR.
- **Data preparation**: Load Parquet files and prepare ML-ready datasets
- **Model training**: User-customizable logistic regression via
tidymodels
- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion
- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion
matrices
- **Feature importance**: Extract and rank predictive features

Expand Down
9 changes: 5 additions & 4 deletions doc/intro.R
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ knitr::opts_chunk$set(

## ----metrics------------------------------------------------------------------
# # Individual metrics
# nmcc <- calculatenMCC(predictions) # Normalized MCC (0-1 scale)
# mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1)
# nmcc <- calculatenMCC(predictions) # Normalized MCC (0 to 1)
# f1 <- calculateF1(predictions) # F1 score
# bal_acc <- calculateBalAcc(predictions) # Balanced accuracy
# auprc <- calculateAUPRC(predictions) # Area under PR curve
Expand Down Expand Up @@ -150,7 +151,7 @@ knitr::opts_chunk$set(
# verbose = TRUE
# )
#
# # Results include nMCC at each iteration
# # Results include MCC and nMCC at each iteration
# ife_results$ife_performance_tibble
# ife_results$feats_removed # If return_feats = TRUE
#
Expand Down Expand Up @@ -233,8 +234,8 @@ knitr::opts_chunk$set(
# )
#
# # 5. Compare real vs baseline performance
# cat("Real nMCC:", results$performance_tibble$nmcc, "\n")
# cat("Baseline nMCC:", baseline_results$performance_tibble$nmcc, "\n")
# cat("Real MCC:", results$performance_tibble$mcc, "| nMCC:", results$performance_tibble$nmcc, "\n")
# cat("Baseline MCC:", baseline_results$performance_tibble$mcc, "| nMCC:", baseline_results$performance_tibble$nmcc, "\n")
#
# # 6. Run iterative feature elimination
# ife_results <- runIFE(
Expand Down
2 changes: 1 addition & 1 deletion vignettes/intro.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ head(results$top_feat_tibble)
| `model` | Model type (`"LR"`) |
| `train_prop`, `val_prop` | Train/validation split proportions |
| `fit_penalty`, `fit_mixture` | Fitted hyperparameters |
| `nmcc`, `f1`, `bal_acc`, `log2_apop` | Performance metrics |
| `mcc`, `nmcc`, `f1`, `bal_acc`, `log2_apop` | Performance metrics |
| `run_time_sec` | Runtime in seconds |

**`top_feat_tibble`** — ranked feature importance:
Expand Down