diff --git a/R/arg_check_ml.R b/R/arg_check_ml.R index f6acf09..2056323 100644 --- a/R/arg_check_ml.R +++ b/R/arg_check_ml.R @@ -703,7 +703,7 @@ NULL #' @keywords internal #' @param y_default_eval [chr] y value of default evaluation plot. It can be #' "avg_runtime_sec" or one of the following performance metrics: -#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc" +#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc" #' .checkArgYDefaultEval <- function(y_default_eval) { if (!is.character(y_default_eval)) { @@ -711,11 +711,11 @@ NULL } if (!(y_default_eval %in% - c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_nmcc")) + c("avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", "avg_nmcc")) ) { stop(paste( "`y_default_eval` must be one of:", - "'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_nmcc'." + "'avg_f1_score', 'avg_log2_apop', 'avg_bal_acc', 'avg_mcc', 'avg_nmcc'." )) } } diff --git a/R/core_ml.R b/R/core_ml.R index d805e9c..ec29d41 100644 --- a/R/core_ml.R +++ b/R/core_ml.R @@ -483,16 +483,15 @@ getConfusionMatrix <- function(test_data_plus_predictions) { return(CM) } -#' .calculatenMCC() +#' .calculateMCC() #' -#' Returns the normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC) based on the AMR phenotype predictions by an +#' Returns the Matthews correlation coefficient (MCC) +#' based on the AMR phenotype predictions by an #' ML model compared against the actual values. #' #' @inheritParams getConfusionMatrix -#' @return Normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC) -.calculatenMCC <- function(test_data_plus_predictions) { +#' @return Matthews correlation coefficient (MCC), range -1 to 1 +.calculateMCC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) target_var <- .getTargetVarName(test_data_plus_predictions) @@ -500,11 +499,24 @@ getConfusionMatrix <- function(test_data_plus_predictions) { mcc <- test_data_plus_predictions |> yardstick::mcc(truth = !!target_var, estimate = .pred_class) |> dplyr::select(.estimate) |> - as.numeric() + as.numeric() |> + round(2) - nmcc <- (mcc + 1) / 2 + return(mcc) +} - return(round(nmcc, 2)) +#' .calculatenMCC() +#' +#' Returns the normalized (0 to 1) Matthews correlation coefficient (nMCC) +#' based on the AMR phenotype predictions by an ML model compared against +#' the actual values. +#' +#' @inheritParams getConfusionMatrix +#' @return Normalized Matthews correlation coefficient (nMCC), range 0 to 1 +.calculatenMCC <- function(test_data_plus_predictions) { + mcc <- .calculateMCC(test_data_plus_predictions) + nmcc <- round((mcc + 1) / 2, 2) + return(nmcc) } #' .calculateF1() @@ -597,8 +609,8 @@ getConfusionMatrix <- function(test_data_plus_predictions) { )) } else if (prior >= 0.7) { warning(paste( - "Classes are imbalanced toward the resistant phenotype.", - "Calculation of log2(AUPRC/prior) may be inappropriate." + "Classes are imbalanced for this model.", + "The use of the log2(AUPRC/prior) metric may be more informative in this imbalanced model." )) } @@ -699,8 +711,8 @@ getConfusionMatrix <- function(test_data_plus_predictions) { #' calculateEvalMets() #' #' Returns the F1 score, area under the precision-recall curve (AUPRC), balanced -#' accuracy, normalized (to a 0 to 1 scale instead of -1 to 1) Matthews -#' correlation coefficient (nMCC), and log2(AUPRC/prior) based on the AMR +#' accuracy, Matthews correlation coefficient (MCC), normalized MCC (nMCC), +#' and log2(AUPRC/prior) based on the AMR #' phenotype predictions by an ML model compared against the actual values. #' #' @inheritParams getConfusionMatrix @@ -732,10 +744,11 @@ calculateEvalMets <- function(test_data_plus_predictions) { bal_acc <- .calculateBalAcc(test_data_plus_predictions) sens <- .calculateSensitivity(test_data_plus_predictions) spec <- .calculateSpecificity(test_data_plus_predictions) + mcc <- .calculateMCC(test_data_plus_predictions) nmcc <- .calculatenMCC(test_data_plus_predictions) log2_apop <- .calculateLog2APOP(test_data_plus_predictions) - return(c(f1, auprc, bal_acc, nmcc, log2_apop)) + return(c(f1, auprc, bal_acc, mcc, nmcc, log2_apop)) } #' extractTopFeats() diff --git a/R/generate_matrices_ml.R b/R/generate_matrices_ml.R index 978edb0..601a678 100644 --- a/R/generate_matrices_ml.R +++ b/R/generate_matrices_ml.R @@ -288,14 +288,14 @@ skipImbalancedMatrix <- function(genome_ids, if (group_type %in% c("drug_class", "drug_class_year", "drug_class_country")) { genome_ids <- DBI::dbGetQuery(con, sprintf(" WITH class_phenotypes AS ( - SELECT \"genome_drug.genome_id\" AS genome_id, + SELECT \"genome.genome_id\" AS genome_id, MAX(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Resistant' THEN 1 ELSE 0 END) AS any_resistant, MIN(CASE WHEN \"genome_drug.resistant_phenotype\" = 'Susceptible' THEN 1 ELSE 0 END) AS all_susceptible FROM metadata WHERE %s - GROUP BY \"genome_drug.genome_id\" + GROUP BY \"genome.genome_id\" ) SELECT genome_id FROM class_phenotypes @@ -303,7 +303,7 @@ skipImbalancedMatrix <- function(genome_ids, ", condition_string))[[1]] } else { genome_ids <- DBI::dbGetQuery(con, sprintf(" - SELECT DISTINCT \"genome_drug.genome_id\" + SELECT DISTINCT \"genome.genome_id\" FROM metadata WHERE %s AND \"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible') @@ -499,7 +499,7 @@ skipImbalancedMatrix <- function(genome_ids, " COPY ( SELECT - f.\"genome_drug.genome_id\" AS genome_id, + f.\"genome.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, %s AS \"genome_drug.resistant_phenotype\" @@ -507,12 +507,12 @@ skipImbalancedMatrix <- function(genome_ids, FROM %s JOIN selected_genomes USING (genome_id) JOIN keep_features kf ON %s = kf.feature_id - JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" + JOIN metadata f ON genome_id = f.\"genome.genome_id\" WHERE %s AND f.\"genome_drug.resistant_phenotype\" IN ('Resistant','Susceptible') %s - GROUP BY f.\"genome_drug.genome_id\", %s %s - ORDER BY f.\"genome_drug.genome_id\", %s + GROUP BY f.\"genome.genome_id\", %s %s + ORDER BY f.\"genome.genome_id\", %s ) TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') @@ -736,7 +736,7 @@ skipImbalancedMatrix <- function(genome_ids, DBI::dbDisconnect(con0, shutdown = FALSE) classes <- metadata_all |> - dplyr::select(genome_drug.genome_id, resistant_classes) |> + dplyr::select(genome.genome_id, resistant_classes) |> dplyr::distinct() |> dplyr::group_by(resistant_classes) |> dplyr::count() |> @@ -748,7 +748,7 @@ skipImbalancedMatrix <- function(genome_ids, genomes_to_keep <- metadata_all |> dplyr::filter(resistant_classes %in% classes) |> - dplyr::pull(genome_drug.genome_id) + dplyr::pull(genome.genome_id) # Build one matrix per feature type and matrix type for (ftype in names(feature_types)) { @@ -828,21 +828,21 @@ skipImbalancedMatrix <- function(genome_ids, " COPY ( SELECT - f.\"genome_drug.genome_id\" AS genome_id, + f.\"genome.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, resistant_classes FROM %s JOIN selected_genomes USING (genome_id) JOIN keep_features kf ON %s = kf.feature_id - JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" + JOIN metadata f ON genome_id = f.\"genome.genome_id\" WHERE resistant_classes <> 'Intermediate' GROUP BY - f.\"genome_drug.genome_id\", + f.\"genome.genome_id\", %s, resistant_classes ORDER BY - f.\"genome_drug.genome_id\", + f.\"genome.genome_id\", %s ) TO '%s' diff --git a/R/globals.R b/R/globals.R index 131d016..e1e3ab5 100644 --- a/R/globals.R +++ b/R/globals.R @@ -44,8 +44,12 @@ utils::globalVariables(c( "idx_strat", "model", "neg_log10_adj_p", + "mcc", "nmcc", "num_obs", + "seed", + "sens", + "spec", "output_prefix", "p_value", "pair_id", diff --git a/R/ife_ml.R b/R/ife_ml.R index 7358640..e6e38c0 100644 --- a/R/ife_ml.R +++ b/R/ife_ml.R @@ -48,7 +48,7 @@ removeTopFeats <- function(ml_input_tibble, top_feat_tibble) { #' runIFE #' Removes top features identified by ML models and retrains iteratively; -#' returns nMCC at each iteration. +#' returns MCC at each iteration. #' #' @param ml_input_tibble An ML-ready tibble generated by `loadMLInputTibble()` #' @param by_num [bool] Set to `TRUE` if removing top features as a percentage @@ -102,6 +102,7 @@ runIFE <- function( num_obs_vec <- c() res_prop_vec <- c() fit_mixture_vec <- c() + mcc_vec <- c() nmcc_vec <- c() n_feats_removed_vec <- c() total_feats_removed_vec <- c() @@ -173,6 +174,9 @@ runIFE <- function( fit_mixture_vec[i] <- ml_res$performance_tibble |> dplyr::select(fit_mixture) |> as.numeric() + mcc_vec[i] <- ml_res$performance_tibble |> + dplyr::select(mcc) |> + as.numeric() nmcc_vec[i] <- ml_res$performance_tibble |> dplyr::select(nmcc) |> as.numeric() @@ -266,7 +270,7 @@ runIFE <- function( percent_removed = c(0, percent_removal_vec), removal_type = rep(removal_type, length(num_obs_vec)), num_obs = num_obs_vec, res_prop = res_prop_vec, - fit_mixture = fit_mixture_vec, nmcc = nmcc_vec, + fit_mixture = fit_mixture_vec, mcc = mcc_vec, nmcc = nmcc_vec, n_feats_removed = n_feats_removed_vec, total_feats_removed = total_feats_removed_vec, run_time_sec = run_time_vec ) diff --git a/R/plot_ml.R b/R/plot_ml.R index c99886c..a986394 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -99,7 +99,7 @@ plotTopFeatsVI <- function(fit, n_top_feats = 10) { #' or "n_fold" #' @param y_default_eval [chr] y value of default evaluation plot. It can be #' "avg_runtime_sec" or one of the following performance metrics: -#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc" +#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", "avg_mcc", or "avg_nmcc" #' @param xlab [chr] Label for x axis #' @param ylab [chr] Label for y axis #' @return A `ggplot2` scatterplot (performance metric or runtime vs. diff --git a/R/prep_ml.R b/R/prep_ml.R index ef3bbff..9ce2ef2 100644 --- a/R/prep_ml.R +++ b/R/prep_ml.R @@ -113,6 +113,7 @@ loadMLInputTibble <- function(parquet_path) { target_var <- .getTargetVarName(long_tibble) ml_input_tibble <- long_tibble |> + dplyr::distinct() |> dplyr::mutate(!!target_var := as.factor(!!target_var)) |> tidyr::pivot_wider( id_cols = dplyr::all_of( diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index eab8dcc..5c47271 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -310,10 +310,11 @@ runMLPipeline <- function( log2_apop <- .calculateLog2APOP(test_data_plus_predictions) } + mcc <- .calculateMCC(test_data_plus_predictions) nmcc <- .calculatenMCC(test_data_plus_predictions) if (verbose) { - message(paste("Normalized Matthews correlation coefficient:", nmcc)) + message(paste("Matthews correlation coefficient:", mcc, "| nMCC:", nmcc)) } top_feat_tibble <- extractTopFeats(fit, @@ -375,7 +376,7 @@ runMLPipeline <- function( performance_tibble <- tibble::tibble( num_obs = num_obs_ml_input_tibble, n_feat = getNumFeat(ml_input_tibble), model, train_prop = split[1], - val_prop = split[2], n_fold, nmcc, run_time_sec, + val_prop = split[2], n_fold, mcc, nmcc, run_time_sec, seed, date = as.character(Sys.Date()) ) @@ -396,18 +397,20 @@ runMLPipeline <- function( ) |> tibble::add_column(bal_acc, .after = "nmcc") |> tibble::add_column(f1, .after = "nmcc") |> - tibble::add_column(log2_apop, .after = "nmcc") + tibble::add_column(log2_apop, .after = "nmcc") |> + tibble::add_column(sens, .after = "nmcc") |> + tibble::add_column(spec, .after = "nmcc") } if (model == "LR") { performance_tibble <- performance_tibble |> - tibble::add_column(fit_penalty, .before = "nmcc") |> - tibble::add_column(fit_mixture, .before = "nmcc") + tibble::add_column(fit_penalty, .before = "mcc") |> + tibble::add_column(fit_mixture, .before = "mcc") } else if (model == "RF" || model == "BT") { performance_tibble <- performance_tibble |> - tibble::add_column(fit_trees, .before = "nmcc") |> - tibble::add_column(fit_mtry, .before = "nmcc") |> - tibble::add_column(fit_min_n, .before = "nmcc") + tibble::add_column(fit_trees, .before = "mcc") |> + tibble::add_column(fit_mtry, .before = "mcc") |> + tibble::add_column(fit_min_n, .before = "mcc") } if (external_test_data) { diff --git a/README.Rmd b/README.Rmd index 0783647..68a00f0 100644 --- a/README.Rmd +++ b/README.Rmd @@ -103,7 +103,7 @@ This uses specific matrices to test whether ML models can predict resistance aga - **Data preparation**: Load Parquet files and prepare ML-ready datasets - **Model training**: User-customizable logistic regression via tidymodels -- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion matrices +- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion matrices - **Feature importance**: Extract and rank predictive features See the [package vignette](https://jravilab.github.io/amRml/articles/intro.html) for detailed usage. diff --git a/README.md b/README.md index 0099d07..3940387 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ associated with MDR. - **Data preparation**: Load Parquet files and prepare ML-ready datasets - **Model training**: User-customizable logistic regression via tidymodels -- **Evaluation**: nMCC, F1, balanced accuracy, AuPRC, and confusion +- **Evaluation**: MCC, nMCC, F1, balanced accuracy, AuPRC, and confusion matrices - **Feature importance**: Extract and rank predictive features diff --git a/doc/intro.R b/doc/intro.R index 3782033..d7ca536 100644 --- a/doc/intro.R +++ b/doc/intro.R @@ -112,7 +112,8 @@ knitr::opts_chunk$set( ## ----metrics------------------------------------------------------------------ # # Individual metrics -# nmcc <- calculatenMCC(predictions) # Normalized MCC (0-1 scale) +# mcc <- calculateMCC(predictions) # Matthews correlation coefficient (-1 to 1) +# nmcc <- calculatenMCC(predictions) # Normalized MCC (0 to 1) # f1 <- calculateF1(predictions) # F1 score # bal_acc <- calculateBalAcc(predictions) # Balanced accuracy # auprc <- calculateAUPRC(predictions) # Area under PR curve @@ -150,7 +151,7 @@ knitr::opts_chunk$set( # verbose = TRUE # ) # -# # Results include nMCC at each iteration +# # Results include MCC and nMCC at each iteration # ife_results$ife_performance_tibble # ife_results$feats_removed # If return_feats = TRUE # @@ -233,8 +234,8 @@ knitr::opts_chunk$set( # ) # # # 5. Compare real vs baseline performance -# cat("Real nMCC:", results$performance_tibble$nmcc, "\n") -# cat("Baseline nMCC:", baseline_results$performance_tibble$nmcc, "\n") +# cat("Real MCC:", results$performance_tibble$mcc, "| nMCC:", results$performance_tibble$nmcc, "\n") +# cat("Baseline MCC:", baseline_results$performance_tibble$mcc, "| nMCC:", baseline_results$performance_tibble$nmcc, "\n") # # # 6. Run iterative feature elimination # ife_results <- runIFE( diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd index be27d7a..cbe4f2b 100644 --- a/vignettes/intro.Rmd +++ b/vignettes/intro.Rmd @@ -124,7 +124,7 @@ head(results$top_feat_tibble) | `model` | Model type (`"LR"`) | | `train_prop`, `val_prop` | Train/validation split proportions | | `fit_penalty`, `fit_mixture` | Fitted hyperparameters | -| `nmcc`, `f1`, `bal_acc`, `log2_apop` | Performance metrics | +| `mcc`, `nmcc`, `f1`, `bal_acc`, `log2_apop` | Performance metrics | | `run_time_sec` | Runtime in seconds | **`top_feat_tibble`** — ranked feature importance: