From de2fc8672f4162245dd60866e3cdc546badd598b Mon Sep 17 00:00:00 2001 From: Abhirupa Ghosh <100681585+AbhirupaGhosh@users.noreply.github.com> Date: Tue, 20 Jan 2026 11:24:24 -0700 Subject: [PATCH 1/6] Enhance plotting functions and documentation Updated documentation for plotPRC, plotROC, plotCM, plotDensity, and plotTopFeatsVI functions. Added new functions for plotting ROC curves, confusion matrices, and density of predicted class probabilities. --- R/plot_ml.R | 258 ++++++++++++++++++++++++++++++++-------------------- 1 file changed, 157 insertions(+), 101 deletions(-) diff --git a/R/plot_ml.R b/R/plot_ml.R index 6e3eb40..088c481 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -19,20 +19,36 @@ #' @importFrom tune extract_fit_parsnip #' @importFrom vip vip #' @importFrom yardstick pr_curve -#' @importFrom graphics barplot NULL -#' plotPRC() -#' -#' Plots the precision-recall curve given a set of test data plus predicted AMR -#' phenotypes. -#' -#' @param test_data_plus_predictions Test data (tibble) with an added column for -#' predicted phenotype labels, such as the output of `predict()`. -#' @return A precision-recall curve as a `ggplot2` object -#' @export +#' Plot a Precision-Recall Curve +#' +#' Generates a precision-recall curve (PRC) for AMR phenotype prediction results. +#' @param test_data_plus_predictions A tibble containing test data with added +#' prediction columns, typically the output of `runMLmodels()`. +#' +#' @return A `ggplot2` object showing the precision-recall curve. +#' +#' @details +#' The function uses `yardstick::pr_curve()` to compute the PR curve and then +#' visualizes it using `ggplot2`. 
+#' +#' @examples +#' \dontrun{ +#' test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv) +#' plotPRC(test_data_plus_predictions) +#' } +#' +#' @export plotPRC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) +test_data_plus_predictions <- test_data_plus_predictions |> +dplyr::mutate( +genome_drug.resistant_phenotype = factor( +genome_drug.resistant_phenotype, +levels = c("Resistant", "Susceptible") +) +) prc <- yardstick::pr_curve( test_data_plus_predictions, @@ -46,113 +62,154 @@ plotPRC <- function(test_data_plus_predictions) { return(prc) } -#' plotTopFeatsVI() -#' -#' Generates a plot showing the top features and their variable importance -#' scores. -#' -#' @param fit Best model fit, such as the output of `fitBestModel()` -#' @param n_top_feats [num] Number of top features to plot -#' @return Variable importance plot (a `ggplot2` object) +#' Plot a Receiver Operating Characteristic (ROC) Curve +#' +#' Generates a ROC curve for AMR phenotype prediction results. +#' +#' @param test_data_plus_predictions A tibble with test data and prediction +#' columns (output of `runMLmodels()`). +#' +#' @return A ROC curve plotted using `ggplot2::autoplot()`. 
+#' #' @export -plotTopFeatsVI <- function(fit, n_top_feats = 10) { - .checkArgWflow(fit) - .checkArgNTopFeats(n_top_feats) - - vip <- fit |> - tune::extract_fit_parsnip() |> - vip::vip(num_features = n_top_feats) + - ggplot2::xlab("Top Features") + +plotROC <- function(test_data_plus_predictions) { + .checkArgTestDataPlusPredictions(test_data_plus_predictions) + test_data_plus_predictions <- test_data_plus_predictions |> +dplyr::mutate( +genome_drug.resistant_phenotype = factor( +genome_drug.resistant_phenotype, +levels = c("Resistant", "Susceptible") +) +) + + roc <- yardstick::roc_curve( + test_data_plus_predictions, + genome_drug.resistant_phenotype, .pred_Resistant + ) |> + ggplot2::autoplot(type = "se") + ggplot2::theme(panel.grid = ggplot2::element_blank()) - - return(vip) + + return(roc) } -#' plotDefaultEval() -#' -#' Plots performance metric or runtime vs. training data proportion or number -#' of cross-validation folds, colored by model. -#' -#' @param default_eval_tibble Output of `findOptimalMLDefaults()` -#' @param x_default_eval [chr] x value of default evaluation plot: "train_prop" -#' or "n_fold" -#' @param y_default_eval [chr] y value of default evaluation plot. It can be -#' "avg_runtime_sec" or one of the following performance metrics: -#' "avg_f1_score", "avg_log2_apop", "avg_bal_acc", or "avg_nmcc" -#' @param xlab [chr] Label for x axis -#' @param ylab [chr] Label for y axis -#' @return A `ggplot2` scatterplot (performance metric or runtime vs. -#' `train_prop` or `n_fold`), colored by model +#' Plot a Confusion Matrix Heatmap +#' +#' Produces a heatmap visualization of the confusion matrix for AMR predictions. +#' +#' @param test_data_plus_predictions A tibble containing true and predicted +#' phenotype labels. +#' +#' @return A heatmap (`ggplot2` object) showing the confusion matrix. 
+#' #' @export -plotDefaultEval <- function( - default_eval_tibble, x_default_eval = "train_prop", - y_default_eval = "avg_f1_score", xlab = "Train Data Proportion", - ylab = "Average F1 Score" -) { - .checkArgTibble(default_eval_tibble) - .checkArgXDefaultEval(x_default_eval) - .checkArgYDefaultEval(y_default_eval) - .checkArgXYLabs(xlab = xlab, ylab = ylab) - - if (x_default_eval == "n_fold") { - default_eval_tibble <- default_eval_tibble |> - dplyr::filter(train_prop == 0.8) - } else { - default_eval_tibble <- default_eval_tibble |> - dplyr::filter(train_prop != 0.8) - } - - default_eval_plot <- ggplot2::ggplot( - default_eval_tibble, - ggplot2::aes( - x = unlist(default_eval_tibble[x_default_eval]), - y = unlist(default_eval_tibble[y_default_eval]), color = model +plotCM <- function(test_data_plus_predictions) { + .checkArgTestDataPlusPredictions(test_data_plus_predictions) + test_data_plus_predictions <- test_data_plus_predictions |> + dplyr::mutate( + genome_drug.resistant_phenotype = factor( + genome_drug.resistant_phenotype, + levels = c("Resistant", "Susceptible") + ), + .pred_class = factor( + .pred_class, + levels = c("Resistant", "Susceptible") + ) ) - ) + - ggplot2::geom_line(size = 1.5) + - ggplot2::geom_point(size = 3) + - ggplot2::theme( - axis.line = ggplot2::element_line(linewidth = 1.5), - axis.ticks = ggplot2::element_line(linewidth = 1.5, colour = "black"), - axis.text = ggplot2::element_text(size = 16, colour = "black"), - axis.title = ggplot2::element_text(size = 16, face = "bold"), - panel.grid = ggplot2::element_blank(), - panel.background = ggplot2::element_blank(), - legend.text = ggplot2::element_text(size = 16), - legend.title = ggplot2::element_text(size = 16, face = "bold") - ) + - ggplot2::labs(x = xlab, y = ylab, color = "Model") + test_data_plus_predictions |> +yardstick::conf_mat(truth = genome_drug.resistant_phenotype, + estimate = .pred_class) |> +ggplot2::autoplot(type = "heatmap") +} - return(default_eval_plot) +#' Plot 
Density of Predicted Class Probabilities +#' +#' Visualizes how predicted class probabilities differ between resistant and +#' susceptible genome-drug combinations. +#' +#' @param test_data_plus_predictions Tibble with prediction probabilities and +#' true labels. +#' +#' @return A ggplot2 density plot. +#' +#' @export +plotDensity <- function(test_data_plus_predictions) { + test_data_plus_predictions |> +ggplot2::ggplot(ggplot2::aes(x = .pred_Resistant, +fill = genome_drug.resistant_phenotype)) + +ggplot2::geom_density(alpha = 0.5) +} + +#' Plot Top Feature Importances +#' +#' Creates a bar plot showing the most important features affecting +#' AMR phenotype predictions. +#' +#' @param topfeat A tibble containing feature importance scores +#' (output of `runMLmodels()`). +#' @param n_top_feats Number of top features to display (default: 10). +#' +#' @return A bar plot of variable importance (`ggplot2` object). +#' +#' @examples +#' \dontrun{ +#' topfeat <- readr::read_tsv(results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv) +#' plotTopFeatsVI(topfeat) +#' } +#' +#' @export +plotTopFeatsVI <- function(topfeat, n_top_feats = 10) { + .checkArgNTopFeats(n_top_feats) + + vip <- topfeat |> + dplyr::slice_max(order_by = Importance, n = n_top_feats) |> + dplyr::mutate( + Variable = factor(Variable, levels = rev(Variable)), # preserve order as shown in table + Sign = factor(Sign, levels = c("POS", "NEG")) + ) |> + ggplot2::ggplot(ggplot2::aes(x = Importance, y = Variable, fill = Sign)) + + ggplot2::geom_col() + + ggplot2::scale_fill_manual( + values = c( + "POS" = "#c6d8d3", + "NEG" = "#f6c9a1" + ) + ) + + ggplot2::labs( + x = "Importance", + y = "Features" + ) + + ggplot2::theme_minimal(base_size = 14) + + ggplot2::theme( + panel.grid.minor = ggplot2::element_blank(), + axis.text.y = ggplot2::element_text(size = 10) + ) + + return(vip) } -#' getBaselineComparisonBarplot() -#' -#' Generates a bar plot that compares model performance with and without -#' 
randomly shuffled AMR phenotype labels. -#' -#' @param non_shuffled_label_results Output of `runMLPipeline()` -#' (`shuffle_labels = FALSE`) -#' @param shuffled_label_results Output of `runMLPipeline()` -#' (`shuffle_labels = TRUE`) -#' @return A bar plot with balanced accuracy comparisons per antibiotic +#' Compare Baseline Performance With and Without Shuffled Labels +#' +#' Produces a bar plot comparing balanced accuracy for each antibiotic using +#' true AMR labels vs. randomly shuffled labels. +#' +#' @param non_shuffled_label_results Output of `runMLPipeline(shuffle_labels = FALSE)` +#' @param shuffled_label_results Output of `runMLPipeline(shuffle_labels = TRUE)` +#' +#' @return A base R barplot comparing balanced accuracy across models. +#' #' @export getBaselineComparisonBarplot <- function( non_shuffled_label_results, shuffled_label_results ) { - .checkArgTibble(non_shuffled_label_results) - .checkArgTibble(shuffled_label_results) - - drugs <- non_shuffled_label_results |> - dplyr::select(antibiotic) |> - dplyr::pull() + .checkArgTibble(non_shuffled_label_results$performance_tibble) + .checkArgTibble(shuffled_label_results$performance_tibble) - non_shuffled_bal_acc <- non_shuffled_label_results |> + non_shuffled_bal_acc <- non_shuffled_label_results$performance_tibble |> dplyr::select(bal_acc) |> dplyr::pull() - shuffled_bal_acc <- shuffled_label_results |> + shuffled_bal_acc <- shuffled_label_results$performance_tibble |> dplyr::select(bal_acc) |> dplyr::pull() @@ -160,13 +217,12 @@ getBaselineComparisonBarplot <- function( nrow = 2, byrow = TRUE ) - colnames(bal_acc_matrix) <- drugs rownames(bal_acc_matrix) <- c("Non-Shuffled Labels", "Shuffled Labels") baseline_comparison_barplot <- barplot(bal_acc_matrix, beside = TRUE, legend.text = TRUE, col = c("skyblue", "lightpink"), - ylab = "Balanced Accuracy", xlab = "Antibiotic" + ylab = "Balanced Accuracy" ) return(baseline_comparison_barplot) From 9486c555d0933cbeb340b2ece9dce4dc637e9b63 Mon Sep 17 00:00:00 
2001 From: AbhirupaGhosh Date: Wed, 21 Jan 2026 06:33:43 +0000 Subject: [PATCH 2/6] Style code (GHA) --- R/core_ml.R | 217 ++++++++++++------- R/generate_matrices_ml.R | 197 ++++++++++------- R/globals.R | 2 - R/plot_ml.R | 121 ++++++----- R/prep_ml.R | 6 +- R/run_ML.R | 450 +++++++++++++++++++++------------------ R/run_ml_pipeline.R | 41 ++-- vignettes/intro.Rmd | 30 +-- 8 files changed, 605 insertions(+), 459 deletions(-) diff --git a/R/core_ml.R b/R/core_ml.R index db874db..1fb0c1b 100644 --- a/R/core_ml.R +++ b/R/core_ml.R @@ -73,7 +73,8 @@ NULL #' @return An `rsplit` object #' @export splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280) { - .checkArgTibble(ml_input_tibble, ml = TRUE); .checkArgSplit(split) + .checkArgTibble(ml_input_tibble, ml = TRUE) + .checkArgSplit(split) .checkArgSeed(seed) set.seed(seed) @@ -85,7 +86,7 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280 # If in CV mode: # Still retain a stratified testing holdout purely for final reporting metrics; # CV is only performed on the training portion. 
- prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test + prop_train_for_holdout <- 0.8 # 80 percent train, 20 percent reserved test data_split <- rsample::initial_split( ml_input_tibble, prop = prop_train_for_holdout, @@ -115,7 +116,8 @@ splitMLInputTibble <- function(ml_input_tibble, split = c(0.6, 0.2), seed = 5280 #' @return A `recipe` object #' @export buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { - .checkArgTibble(train_data, ml = TRUE); .checkArgUsePCA(use_pca) + .checkArgTibble(train_data, ml = TRUE) + .checkArgUsePCA(use_pca) .checkArgPCAThreshold(pca_threshold) target_var <- .getTargetVarName(train_data) |> as.character() @@ -124,8 +126,10 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { nm <- names(train_data) id_cols <- setdiff(nm[grepl("^genome", nm)], target_var) - rec <- recipes::recipe(formula = stats::reformulate(".", response = target_var), - data = train_data) + rec <- recipes::recipe( + formula = stats::reformulate(".", response = target_var), + data = train_data + ) # Only update roles if we actually have ID columns to mark as metadata if (length(id_cols) > 0) { @@ -146,7 +150,6 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { } - #' buildLRModel() #' #' Builds a logistic regression model. 
@@ -158,13 +161,17 @@ buildRecipe <- function(train_data, use_pca = FALSE, pca_threshold = 0.95) { buildLRModel <- function(multi_class = FALSE) { .checkArgMultiClass(multi_class) - if(!multi_class) { - lr_mod <- parsnip::logistic_reg(penalty = hardhat::tune(), - mixture = hardhat::tune()) |> + if (!multi_class) { + lr_mod <- parsnip::logistic_reg( + penalty = hardhat::tune(), + mixture = hardhat::tune() + ) |> parsnip::set_engine(engine = "glmnet") - } else if(multi_class) { - lr_mod <- parsnip::multinom_reg(penalty = hardhat::tune(), - mixture = hardhat::tune()) |> + } else if (multi_class) { + lr_mod <- parsnip::multinom_reg( + penalty = hardhat::tune(), + mixture = hardhat::tune() + ) |> parsnip::set_engine(engine = "glmnet") } @@ -181,9 +188,11 @@ buildLRModel <- function(multi_class = FALSE) { #' @return A `workflow` object #' @export buildWflow <- function(parsnip_mod, recipe) { - .checkArgParsnipMod(parsnip_mod); .checkArgRecipe(recipe) + .checkArgParsnipMod(parsnip_mod) + .checkArgRecipe(recipe) - wflow <- workflows::workflow() |> workflows::add_model(parsnip_mod) |> + wflow <- workflows::workflow() |> + workflows::add_model(parsnip_mod) |> workflows::add_recipe(recipe) return(wflow) @@ -203,21 +212,21 @@ buildWflow <- function(parsnip_mod, recipe) { #' @return A logistic regression tuning grid as a tibble #' @export buildTuningGrid <- function( - model = "LR", - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5 + model = "LR", + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5 ) { .checkArgModel(model) - + if (model == "LR") { .checkArgPenaltyVec(penalty_vec) .checkArgMixVec(mix_vec) - + penalty <- rep(penalty_vec, each = length(mix_vec)) mixture <- rep(mix_vec, length(penalty_vec)) grid <- tibble::tibble(penalty, mixture) } - + return(grid) } @@ -237,13 +246,14 @@ buildTuningGrid <- function( #' @export tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), n_fold = 5) { - .checkArgTibble(grid); 
.checkArgWflow(wflow) + .checkArgTibble(grid) + .checkArgWflow(wflow) .checkArgDataSplit(data_split) split_class <- class(data_split)[1] # Always do CV on the training portion of the split - train_df <- rsample::training(data_split) + train_df <- rsample::training(data_split) target_var <- .getTargetVarName(train_df) if (identical(split_class, "initial_split")) { @@ -259,9 +269,9 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), tune_res <- tune::tune_grid( wflow, resamples = resamples, - grid = grid, - control = tune::control_grid(save_pred = TRUE), - metrics = yardstick::metric_set( + grid = grid, + control = tune::control_grid(save_pred = TRUE), + metrics = yardstick::metric_set( yardstick::f_meas, yardstick::pr_auc, yardstick::spec, @@ -286,7 +296,8 @@ tuneGrid <- function(wflow, data_split, grid = buildTuningGrid(model = "LR"), #' @return Best model workflow #' @export selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") { - .checkArgTuneRes(tune_res); .checkArgWflow(wflow) + .checkArgTuneRes(tune_res) + .checkArgWflow(wflow) .checkArgSelectBestMetric(select_best_metric) best_mod <- tune::select_best(tune_res, metric = select_best_metric) @@ -306,7 +317,8 @@ selectBestModel <- function(tune_res, wflow, select_best_metric = "mcc") { #' @return Best model fit #' @export fitBestModel <- function(final_mod, train_data) { - .checkArgWflow(final_mod); .checkArgTibble(train_data, ml = TRUE) + .checkArgWflow(final_mod) + .checkArgTibble(train_data, ml = TRUE) fit <- final_mod |> parsnip::fit(data = train_data) @@ -324,8 +336,7 @@ fitBestModel <- function(final_mod, train_data) { model <- class(fit$fit$actions$model$spec)[1] - if(model %in% c("logistic_reg", "multinom_reg")) { - + if (model %in% c("logistic_reg", "multinom_reg")) { penalty <- fit$fit$fit$spec$args$penalty mixture <- tryCatch( @@ -334,7 +345,6 @@ fitBestModel <- function(final_mod, train_data) { ) tibble::tibble(penalty = penalty, mixture = mixture) - } 
else { stop("The `fit` object provided must correspond to 'logistic_reg' or 'multinom_reg'.") } @@ -353,7 +363,8 @@ fitBestModel <- function(final_mod, train_data) { #' labels #' @export predictML <- function(fit, test_data) { - .checkArgWflow(fit); .checkArgTibble(test_data, ml = TRUE) + .checkArgWflow(fit) + .checkArgTibble(test_data, ml = TRUE) test_data_plus_predictions <- parsnip::augment(fit, test_data) @@ -396,7 +407,8 @@ getConfusionMatrix <- function(test_data_plus_predictions) { mcc <- test_data_plus_predictions |> yardstick::mcc(truth = !!target_var, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() + dplyr::select(.estimate) |> + as.numeric() nmcc <- (mcc + 1) / 2 @@ -413,15 +425,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateF1 <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } f1 <- test_data_plus_predictions |> - yardstick::f_meas(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::f_meas( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(f1) @@ -437,16 +455,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateAUPRC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } auprc <- test_data_plus_predictions |> yardstick::pr_auc( - truth = genome_drug.resistant_phenotype, .pred_Resistant) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + truth = genome_drug.resistant_phenotype, .pred_Resistant + ) |> + dplyr::select(.estimate) |> + as.numeric() |> + round(2) return(auprc) } @@ -461,26 +484,33 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateLog2APOP <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } auprc <- .calculateAUPRC(test_data_plus_predictions) prior <- sum( - test_data_plus_predictions$genome_drug.resistant_phenotype == "Resistant") / + test_data_plus_predictions$genome_drug.resistant_phenotype == "Resistant" + ) / nrow(test_data_plus_predictions) - if(prior > 0.3 && prior < 0.7) { - warning(paste("Classes are roughly balanced.", - "Calculation of log2(AUPRC/prior) may be inappropriate.")) - } else if(prior >= 0.7) { - warning(paste("Classes are imbalanced toward the resistant phenotype.", - "Calculation of log2(AUPRC/prior) may be inappropriate.")) + if (prior > 0.3 && prior < 0.7) { + warning(paste( + "Classes are roughly balanced.", + "Calculation of log2(AUPRC/prior) may be inappropriate." + )) + } else if (prior >= 0.7) { + warning(paste( + "Classes are imbalanced toward the resistant phenotype.", + "Calculation of log2(AUPRC/prior) may be inappropriate." + )) } - log2_apop <- log2(auprc/prior) |> round(2) + log2_apop <- log2(auprc / prior) |> round(2) return(log2_apop) } @@ -495,16 +525,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateBalAcc <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } bal_acc <- test_data_plus_predictions |> yardstick::bal_accuracy( - truth = genome_drug.resistant_phenotype, estimate = .pred_class) |> - dplyr::select(.estimate) |> as.numeric() |> round(2) + truth = genome_drug.resistant_phenotype, estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> + round(2) return(bal_acc) } @@ -519,15 +554,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateSensitivity <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." + )) } sens <- test_data_plus_predictions |> - yardstick::sens(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::sens( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(sens) @@ -543,15 +584,21 @@ getConfusionMatrix <- function(test_data_plus_predictions) { .calculateSpecificity <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) - if(!("genome_drug.resistant_phenotype" %in% + if (!("genome_drug.resistant_phenotype" %in% colnames(test_data_plus_predictions))) { - stop(paste("`test_data_plus_predictions` does not have a column for", - "`genome_drug.resistant_phenotype`.")) + stop(paste( + "`test_data_plus_predictions` does not have a column for", + "`genome_drug.resistant_phenotype`." 
+ )) } spec <- test_data_plus_predictions |> - yardstick::spec(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> dplyr::select(.estimate) |> as.numeric() |> + yardstick::spec( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + dplyr::select(.estimate) |> + as.numeric() |> round(2) return(spec) @@ -598,30 +645,36 @@ calculateEvalMets <- function(test_data_plus_predictions) { #' `Importance`, and a column for `Sign` (or, for multi-class, a tibble with #' per-class columns of importance scores for each `Variable`) #' @export -extractTopFeats <- function(fit, prop_vi_top_feats = c(0, 1), - n_top_feats = NA) { +extractTopFeats <- function( + fit, prop_vi_top_feats = c(0, 1), + n_top_feats = NA +) { .checkArgWflow(fit) - if(!is.na(n_top_feats)) {prop_vi_top_feats <- NA} + if (!is.na(n_top_feats)) { + prop_vi_top_feats <- NA + } # Arg checking for every permutation of `prop_vi_top_feats` and `n_top_feats` - if(is.na(n_top_feats) & any(!is.na(prop_vi_top_feats))) { + if (is.na(n_top_feats) & any(!is.na(prop_vi_top_feats))) { .checkArgPropVITopFeats(prop_vi_top_feats) - } else if(any(is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { + } else if (any(is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { .checkArgNTopFeats(n_top_feats) - } else if(any(!is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { + } else if (any(!is.na(prop_vi_top_feats)) & !is.na(n_top_feats)) { stop("Set either `n_top_feats` or `prop_vi_top_feats` to `NA` but not both.") - } else if(any(is.na(prop_vi_top_feats)) & is.na(n_top_feats)) { + } else if (any(is.na(prop_vi_top_feats)) & is.na(n_top_feats)) { stop("Please specify either `n_top_feats` or `prop_vi_top_feats`.") } - feats_arranged <- fit |> workflowsets::extract_fit_parsnip() |> vip::vi() |> + feats_arranged <- fit |> + workflowsets::extract_fit_parsnip() |> + vip::vi() |> dplyr::arrange(dplyr::desc(Importance)) - if(!is.na(n_top_feats)) { + if (!is.na(n_top_feats)) { top_feats_and_VIs <- 
feats_arranged |> dplyr::slice(1:n_top_feats) - } else if(any(!is.na(prop_vi_top_feats))) { + } else if (any(!is.na(prop_vi_top_feats))) { cum_vi_lower <- prop_vi_top_feats[1] * sum(feats_arranged$Importance) cum_vi_upper <- prop_vi_top_feats[2] * sum(feats_arranged$Importance) @@ -638,9 +691,11 @@ extractTopFeats <- function(fit, prop_vi_top_feats = c(0, 1), # Take a different approach if using multi-class (the previous code would give # a less meaningful result). - if(class(fit$fit$actions$model$spec)[1] == "multinom_reg") { - warning(paste("Extracting top features from a multi-class model.", - "The `prop_vi_top_feats` and `n_top_feats` arguments do not apply.")) + if (class(fit$fit$actions$model$spec)[1] == "multinom_reg") { + warning(paste( + "Extracting top features from a multi-class model.", + "The `prop_vi_top_feats` and `n_top_feats` arguments do not apply." + )) fit_penalty <- .getFitHps(fit)["penalty"] |> as.numeric() glmnet_fit <- parsnip::extract_fit_engine(fit) diff --git a/R/generate_matrices_ml.R b/R/generate_matrices_ml.R index bb1dc68..b19beef 100644 --- a/R/generate_matrices_ml.R +++ b/R/generate_matrices_ml.R @@ -156,7 +156,6 @@ skipImbalancedMatrix <- function(genome_ids, split, stratify_by = NULL, verbosity = c("minimal", "debug")) { - verbosity <- match.arg(verbosity) log <- .make_logger(verbosity) @@ -197,8 +196,10 @@ skipImbalancedMatrix <- function(genome_ids, if (!dir.exists(matrix_path)) dir.create(matrix_path, recursive = TRUE) log("info", paste0("Matrix output directory: ", matrix_path)) - log("debug", paste0("Stratification: ", - ifelse(is.null(stratify_column), "None", stratify_column))) + log("debug", paste0( + "Stratification: ", + ifelse(is.null(stratify_column), "None", stratify_column) + )) # Feature and matrix types feature_types <- list( @@ -220,9 +221,11 @@ skipImbalancedMatrix <- function(genome_ids, # Safe DBI-quoting quote_condition <- function(group_cols, group_values, con) { - ids <- vapply(group_cols, - function(col) 
DBI::dbQuoteIdentifier(con, col), - character(1)) + ids <- vapply( + group_cols, + function(col) DBI::dbQuoteIdentifier(con, col), + character(1) + ) vals <- vapply( group_cols, function(col) { @@ -256,7 +259,6 @@ skipImbalancedMatrix <- function(genome_ids, log("debug", paste0("Found ", nrow(all_groups), " groups for type: ", group_type)) for (i in seq_len(nrow(all_groups))) { - # New connection for this group con <- DBI::dbConnect(duckdb::duckdb(), parquet_duckdb_path) @@ -268,13 +270,14 @@ skipImbalancedMatrix <- function(genome_ids, condition_string <- quote_condition(group_cols, group_values, con) # Strat filter - strat_filter <- if (!is.null(stratify_column)) + strat_filter <- if (!is.null(stratify_column)) { sprintf("AND \"%s\" IS NOT NULL AND \"%s\" != ''", stratify_column, stratify_column) - else "" + } else { + "" + } # Genome selection logic if (group_type %in% c("drug_class", "drug_class_year", "drug_class_country")) { - genome_ids <- DBI::dbGetQuery(con, sprintf(" WITH class_phenotypes AS ( SELECT \"genome_drug.genome_id\" AS genome_id, @@ -290,7 +293,6 @@ skipImbalancedMatrix <- function(genome_ids, FROM class_phenotypes WHERE any_resistant = 1 OR all_susceptible = 1 ", condition_string))[[1]] - } else { genome_ids <- DBI::dbGetQuery(con, sprintf(" SELECT DISTINCT \"genome_drug.genome_id\" @@ -310,19 +312,24 @@ skipImbalancedMatrix <- function(genome_ids, ", condition_string)) phenotype_summary <- paste( - apply(phenotype_counts_all, 1, - function(row) paste0(row["phenotype"], "=", row["count"])), + apply( + phenotype_counts_all, 1, + function(row) paste0(row["phenotype"], "=", row["count"]) + ), collapse = "; " ) # Apply skip logic if (skipImbalancedMatrix(genome_ids, phenotype_counts_all, n_fold, split, - verbosity = verbosity)) { - + verbosity = verbosity + )) { readr::write_lines( - sprintf("%s\tToo few samples for CV/split\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tToo few samples 
for CV/split\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -331,9 +338,12 @@ skipImbalancedMatrix <- function(genome_ids, if (length(genome_ids) < 40) { readr::write_lines( - sprintf("%s\tToo few observations\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tToo few observations\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -351,9 +361,12 @@ skipImbalancedMatrix <- function(genome_ids, if (nrow(phen2) < 2) { readr::write_lines( - sprintf("%s\tOnly one phenotype class\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), - log_path, append = TRUE + sprintf( + "%s\tOnly one phenotype class\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), + log_path, + append = TRUE ) DBI::dbDisconnect(con, shutdown = FALSE) @@ -363,13 +376,14 @@ skipImbalancedMatrix <- function(genome_ids, # Create selected_genomes DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE selected_genomes (genome_id VARCHAR)") DBI::dbWriteTable(con, "selected_genomes", - data.frame(genome_id = genome_ids), append = TRUE) + data.frame(genome_id = genome_ids), + append = TRUE + ) # Feature and matrix generation steps for (ftype in names(feature_types)) { - fview <- feature_types[[ftype]]$view - fid <- feature_types[[ftype]]$id_col + fid <- feature_types[[ftype]]$id_col # binary view DBI::dbExecute(con, sprintf(" @@ -389,13 +403,14 @@ skipImbalancedMatrix <- function(genome_ids, } for (mtype in names(matrix_types)) { - binary_only <- matrix_types[[mtype]]$binary_only if (ftype == "struct" && !binary_only) next - mview <- sprintf("%s_%s", ftype, - ifelse(grepl("binary", mtype), "binary", "counts")) - value_col <- matrix_types[[mtype]]$value_col + mview <- sprintf( + "%s_%s", ftype, + ifelse(grepl("binary", mtype), "binary", "counts") + 
) + value_col <- matrix_types[[mtype]]$value_col filter_clause <- matrix_types[[mtype]]$filter # select features with non-zero variance @@ -409,29 +424,38 @@ skipImbalancedMatrix <- function(genome_ids, keep_features <- DBI::dbGetQuery(con, keep_query)[["feature_id"]] if (length(keep_features) == 0) { - log("info", paste0("All features filtered for ", - ftype, " - ", mtype, " - ", group_label)) + log("info", paste0( + "All features filtered for ", + ftype, " - ", mtype, " - ", group_label + )) next } - DBI::dbExecute(con, - "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)") + DBI::dbExecute( + con, + "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)" + ) DBI::dbWriteTable(con, - "keep_features", - data.frame(feature_id = keep_features), - append = TRUE) + "keep_features", + data.frame(feature_id = keep_features), + append = TRUE + ) mtype_label <- matrix_types[[mtype]]$label - long_out_path <- file.path(matrix_path, - sprintf("%s_%s_%s_%s_%s_sparse.parquet", - bug, group_type, group_label, ftype, mtype_label)) + long_out_path <- file.path( + matrix_path, + sprintf( + "%s_%s_%s_%s_%s_sparse.parquet", + bug, group_type, group_label, ftype, mtype_label + ) + ) long_out_path_sql <- gsub("\\\\", "/", long_out_path) # phenotype case phenotype_case <- if (group_type %in% - c("drug_class", "drug_class_year", "drug_class_country")) { + c("drug_class", "drug_class_year", "drug_class_country")) { " CASE WHEN MAX(CASE WHEN f.\"genome_drug.resistant_phenotype\"='Resistant' @@ -451,13 +475,20 @@ skipImbalancedMatrix <- function(genome_ids, " } - strat_col_select <- if (!is.null(stratify_by)) - sprintf(", f.\"%s\"", stratify_column) else "" + strat_col_select <- if (!is.null(stratify_by)) { + sprintf(", f.\"%s\"", stratify_column) + } else { + "" + } - strat_col_group <- if (!is.null(stratify_by)) - sprintf(", f.\"%s\"", stratify_column) else "" + strat_col_group <- if (!is.null(stratify_by)) { + sprintf(", f.\"%s\"", stratify_column) + } else { + 
"" + } - copy_sql <- sprintf(" + copy_sql <- sprintf( + " COPY ( SELECT f.\"genome_drug.genome_id\" AS genome_id, @@ -478,18 +509,21 @@ skipImbalancedMatrix <- function(genome_ids, TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') ", - fid, value_col, phenotype_case, strat_col_select, - mview, fid, condition_string, - strat_filter, fid, strat_col_group, fid, - long_out_path_sql) + fid, value_col, phenotype_case, strat_col_select, + mview, fid, condition_string, + strat_filter, fid, strat_col_group, fid, + long_out_path_sql + ) ok <- try(DBI::dbExecute(con, copy_sql), silent = TRUE) # On copy failure, log + continue without stopping entire pipeline if (inherits(ok, "try-error")) { readr::write_lines( - sprintf("%s\tCOPY_failed\t%d\t%s", - group_label, length(genome_ids), phenotype_summary), + sprintf( + "%s\tCOPY_failed\t%d\t%s", + group_label, length(genome_ids), phenotype_summary + ), log_path, append = TRUE ) @@ -530,7 +564,7 @@ skipImbalancedMatrix <- function(genome_ids, # Normalize paths to forward slashes for consistency matrix_path <- gsub("\\\\", "/", file.path(path, paste0("matrix_", stratify_by))) - LOO_path <- gsub("\\\\", "/", file.path(path, paste0("LOO_matrix_", stratify_by))) + LOO_path <- gsub("\\\\", "/", file.path(path, paste0("LOO_matrix_", stratify_by))) if (!dir.exists(matrix_path)) { log("info", paste0("The matrix directory ", matrix_path, " does not exist.")) @@ -626,9 +660,11 @@ skipImbalancedMatrix <- function(genome_ids, out_file <- gsub("\\\\", "/", file.path( LOO_path, - paste0(sub_prefix, "_", stratify_by, "_", - drug_class, "_leaveout_", leave_one_out, "_", - sub_feature, "_sparse.parquet") + paste0( + sub_prefix, "_", stratify_by, "_", + drug_class, "_leaveout_", leave_one_out, "_", + sub_feature, "_sparse.parquet" + ) )) arrow::write_parquet(combined, out_file) created <<- c(created, out_file) @@ -702,7 +738,7 @@ skipImbalancedMatrix <- function(genome_ids, # Build one matrix per feature type and matrix type for (ftype in 
names(feature_types)) { fview <- feature_types[[ftype]]$view - fid <- feature_types[[ftype]]$id_col + fid <- feature_types[[ftype]]$id_col for (mtype in names(matrix_types)) { binary_only <- matrix_types[[mtype]]$binary_only @@ -722,8 +758,9 @@ skipImbalancedMatrix <- function(genome_ids, # Selected genomes DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE selected_genomes (genome_id VARCHAR)") DBI::dbWriteTable(con, "selected_genomes", - data.frame(genome_id = genomes_to_keep), - append = TRUE) + data.frame(genome_id = genomes_to_keep), + append = TRUE + ) # Binary view DBI::dbExecute(con, sprintf(" @@ -763,13 +800,15 @@ skipImbalancedMatrix <- function(genome_ids, DBI::dbExecute(con, "CREATE OR REPLACE TEMP TABLE keep_features (feature_id VARCHAR)") DBI::dbWriteTable(con, "keep_features", - data.frame(feature_id = keep_features), - append = TRUE) + data.frame(feature_id = keep_features), + append = TRUE + ) + - - copy_sql <- sprintf(" + copy_sql <- sprintf( + " COPY ( - SELECT + SELECT f.\"genome_drug.genome_id\" AS genome_id, %s AS feature_id, MAX(CAST(%s AS DOUBLE)) AS value, @@ -779,26 +818,26 @@ skipImbalancedMatrix <- function(genome_ids, JOIN keep_features kf ON %s = kf.feature_id JOIN metadata f ON genome_id = f.\"genome_drug.genome_id\" WHERE resistant_classes <> 'Intermediate' - GROUP BY - f.\"genome_drug.genome_id\", - %s, + GROUP BY + f.\"genome_drug.genome_id\", + %s, resistant_classes - ORDER BY - f.\"genome_drug.genome_id\", + ORDER BY + f.\"genome_drug.genome_id\", %s ) TO '%s' (FORMAT 'parquet', COMPRESSION 'zstd') - ", - fid, # %s -> feature_id expression column name - value_col, # %s -> value column to CAST - mview, # %s -> source view (binary or counts) - fid, # %s -> join to keep_features - fid, # %s -> group by feature id - fid, # %s -> order by feature id - out_file_sql # %s -> destination parquet file + ", + fid, # %s -> feature_id expression column name + value_col, # %s -> value column to CAST + mview, # %s -> source view (binary or 
counts) + fid, # %s -> join to keep_features + fid, # %s -> group by feature id + fid, # %s -> order by feature id + out_file_sql # %s -> destination parquet file ) - + ok <- try(DBI::dbExecute(con, copy_sql), silent = TRUE) if (inherits(ok, "try-error")) { log("info", paste0("COPY failed for MDR matrix: ", out_file)) diff --git a/R/globals.R b/R/globals.R index a6595d2..131d016 100644 --- a/R/globals.R +++ b/R/globals.R @@ -8,7 +8,6 @@ "_PACKAGE" utils::globalVariables(c( - # Prediction columns from tidymodels ".estimate", ".pred_Resistant", @@ -52,7 +51,6 @@ utils::globalVariables(c( "pair_id", "parts", "phenotype", - "precision", "prefix", "prefix_key", diff --git a/R/plot_ml.R b/R/plot_ml.R index b41efb5..42d51ec 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -22,33 +22,33 @@ NULL #' Plot a Precision-Recall Curve -#' +#' #' Generates a precision-recall curve (PRC) for AMR phenotype prediction results. #' @param test_data_plus_predictions A tibble containing test data with added #' prediction columns, typically the output of `runMLmodels()`. -#' +#' #' @return A `ggplot2` object showing the precision-recall curve. -#' -#' @details +#' +#' @details #' The function uses `yardstick::pr_curve()` to compute the PR curve and then #' visualizes it using `ggplot2`. 
-#' +#' #' @examples #' \dontrun{ -#' test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv) +#' test_data_plus_predictions <- readr::read_tsv("results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv") #' plotPRC(test_data_plus_predictions) -#' } -#' +#' } +#' #' @export plotPRC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) -test_data_plus_predictions <- test_data_plus_predictions |> -dplyr::mutate( -genome_drug.resistant_phenotype = factor( -genome_drug.resistant_phenotype, -levels = c("Resistant", "Susceptible") -) -) + test_data_plus_predictions <- test_data_plus_predictions |> + dplyr::mutate( + genome_drug.resistant_phenotype = factor( + genome_drug.resistant_phenotype, + levels = c("Resistant", "Susceptible") + ) + ) prc <- yardstick::pr_curve( test_data_plus_predictions, @@ -63,24 +63,24 @@ levels = c("Resistant", "Susceptible") } #' Plot a Receiver Operating Characteristic (ROC) Curve -#' +#' #' Generates a ROC curve for AMR phenotype prediction results. -#' +#' #' @param test_data_plus_predictions A tibble with test data and prediction #' columns (output of `runMLmodels()`). -#' +#' #' @return A ROC curve plotted using `ggplot2::autoplot()`. 
-#' +#' #' @export plotROC <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) test_data_plus_predictions <- test_data_plus_predictions |> -dplyr::mutate( -genome_drug.resistant_phenotype = factor( -genome_drug.resistant_phenotype, -levels = c("Resistant", "Susceptible") -) -) + dplyr::mutate( + genome_drug.resistant_phenotype = factor( + genome_drug.resistant_phenotype, + levels = c("Resistant", "Susceptible") + ) + ) roc <- yardstick::roc_curve( test_data_plus_predictions, @@ -88,19 +88,19 @@ levels = c("Resistant", "Susceptible") ) |> ggplot2::autoplot(type = "se") + ggplot2::theme(panel.grid = ggplot2::element_blank()) - + return(roc) } #' Plot a Confusion Matrix Heatmap -#' +#' #' Produces a heatmap visualization of the confusion matrix for AMR predictions. -#' +#' #' @param test_data_plus_predictions A tibble containing true and predicted #' phenotype labels. -#' +#' #' @return A heatmap (`ggplot2` object) showing the confusion matrix. -#' +#' #' @export plotCM <- function(test_data_plus_predictions) { .checkArgTestDataPlusPredictions(test_data_plus_predictions) @@ -116,62 +116,66 @@ plotCM <- function(test_data_plus_predictions) { ) ) test_data_plus_predictions |> -yardstick::conf_mat(truth = genome_drug.resistant_phenotype, - estimate = .pred_class) |> -ggplot2::autoplot(type = "heatmap") + yardstick::conf_mat( + truth = genome_drug.resistant_phenotype, + estimate = .pred_class + ) |> + ggplot2::autoplot(type = "heatmap") } #' Plot Density of Predicted Class Probabilities -#' +#' #' Visualizes how predicted class probabilities differ between resistant and #' susceptible genome-drug combinations. -#' +#' #' @param test_data_plus_predictions Tibble with prediction probabilities and #' true labels. -#' +#' #' @return A ggplot2 density plot. 
-#' +#' #' @export plotDensity <- function(test_data_plus_predictions) { test_data_plus_predictions |> -ggplot2::ggplot(ggplot2::aes(x = .pred_Resistant, -fill = genome_drug.resistant_phenotype)) + -ggplot2::geom_density(alpha = 0.5) + ggplot2::ggplot(ggplot2::aes( + x = .pred_Resistant, + fill = genome_drug.resistant_phenotype + )) + + ggplot2::geom_density(alpha = 0.5) } #' Plot Top Feature Importances -#' +#' #' Creates a bar plot showing the most important features affecting #' AMR phenotype predictions. -#' +#' #' @param topfeat A tibble containing feature importance scores #' (output of `runMLmodels()`). #' @param n_top_feats Number of top features to display (default: 10). -#' +#' #' @return A bar plot of variable importance (`ggplot2` object). -#' +#' #' @examples #' \dontrun{ -#' topfeat <- readr::read_tsv(results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv) +#' topfeat <- readr::read_tsv("results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv") #' plotTopFeatsVI(topfeat) -#' } -#' +#' } +#' #' @export plotTopFeatsVI <- function(topfeat, n_top_feats = 10) { .checkArgNTopFeats(n_top_feats) - vip <- topfeat |> + vip <- topfeat |> dplyr::slice_max(order_by = Importance, n = n_top_feats) |> dplyr::mutate( - Variable = factor(Variable, levels = rev(Variable)), # preserve order as shown in table + Variable = factor(Variable, levels = rev(Variable)), # preserve order as shown in table Sign = factor(Sign, levels = c("POS", "NEG")) - ) |> + ) |> ggplot2::ggplot(ggplot2::aes(x = Importance, y = Variable, fill = Sign)) + ggplot2::geom_col() + ggplot2::scale_fill_manual( values = c( - "POS" = "#c6d8d3", - "NEG" = "#f6c9a1" + "POS" = "#c6d8d3", + "NEG" = "#f6c9a1" ) ) + ggplot2::labs( @@ -183,20 +187,20 @@ plotTopFeatsVI <- function(topfeat, n_top_feats = 10) { panel.grid.minor = ggplot2::element_blank(), axis.text.y = ggplot2::element_text(size = 10) ) - - return(vip) + + return(vip) } #' Compare Baseline Performance With and Without 
Shuffled Labels -#' +#' #' Produces a bar plot comparing balanced accuracy for each antibiotic using #' true AMR labels vs. randomly shuffled labels. -#' +#' #' @param non_shuffled_label_results Output of `runMLPipeline(shuffle_labels = FALSE)` #' @param shuffled_label_results Output of `runMLPipeline(shuffle_labels = TRUE)` -#' +#' #' @return A base R barplot comparing balanced accuracy across models. -#' +#' #' @export plotBaselineComparison <- function( non_shuffled_label_results, @@ -268,7 +272,6 @@ plotFishers <- function( alpha = 0.05, label_top_n = 5 ) { - required_cols <- c("gene", "adj_p_value", "sig_after_bh") missing_cols <- setdiff(required_cols, colnames(fisher_df)) diff --git a/R/prep_ml.R b/R/prep_ml.R index d47c160..4a5954e 100644 --- a/R/prep_ml.R +++ b/R/prep_ml.R @@ -111,8 +111,10 @@ loadMLInputTibble <- function(parquet_path) { if (exists(".ml_logger")) { log <- .ml_logger("minimal") - log("debug", paste0("ML tibble constructed: ", nrow(ml_input_tibble), - " genomes x ", getNumFeat(ml_input_tibble), " features")) + log("debug", paste0( + "ML tibble constructed: ", nrow(ml_input_tibble), + " genomes x ", getNumFeat(ml_input_tibble), " features" + )) } if (anyDuplicated(dplyr::pull(ml_input_tibble, genome_id)) != 0) { diff --git a/R/run_ML.R b/R/run_ML.R index eba37f8..2ed07e7 100644 --- a/R/run_ML.R +++ b/R/run_ML.R @@ -4,9 +4,11 @@ #' the ML matrices with these new split/CV values instead. 
#' @noRd .resolveSplitParams <- function(parquet_path, - defaults = list(split = c(0.8, 0), - seed = 5280, - n_fold = 5)) { + defaults = list( + split = c(0.8, 0), + seed = 5280, + n_fold = 5 + )) { # matrix_dir is the directory that contains the parquet files matrix_dir <- normalizePath(dirname(parquet_path)) params_json <- .readMLParameters(matrix_dir) @@ -16,8 +18,8 @@ } list( - split = if (!is.null(params_json$split)) params_json$split else defaults$split, - seed = if (!is.null(params_json$seed)) params_json$seed else defaults$seed, + split = if (!is.null(params_json$split)) params_json$split else defaults$split, + seed = if (!is.null(params_json$seed)) params_json$seed else defaults$seed, n_fold = if (!is.null(params_json$n_fold)) params_json$n_fold else defaults$n_fold ) } @@ -53,8 +55,9 @@ #' #' # LOO analysis stratified by year #' paths_loo <- createMLResultDir("/path/to/results", -#' stratify_by = "year", -#' LOO = TRUE) +#' stratify_by = "year", +#' LOO = TRUE +#' ) #' #' # MDR analysis #' paths_mdr <- createMLResultDir("/path/to/results", MDR = TRUE) @@ -90,16 +93,17 @@ createMLResultDir <- function(path, ) } else { # Determine prefixes (only in non-MDR mode) - full_prefix <- paste0(ifelse(isTRUE(LOO), "LOO_", ""), - ifelse(isTRUE(cross_test), "cross_test_", "")) + full_prefix <- paste0( + ifelse(isTRUE(LOO), "LOO_", ""), + ifelse(isTRUE(cross_test), "cross_test_", "") + ) half_prefix <- ifelse(isTRUE(LOO), "LOO_", "") # Determine suffix suffix <- if (is.null(stratify_by) || identical(stratify_by, "")) { "" } else { - switch( - stratify_by, + switch(stratify_by, "country" = "_country", "year" = "_year", stop("`stratify_by` must be NULL, 'country', or 'year'.") @@ -127,20 +131,20 @@ createMLResultDir <- function(path, return(paths) } - # createAllMLResultDir <- function(path) { - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = TRUE, MDR = 
FALSE) - # createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = TRUE) - # createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = TRUE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = FALSE, MDR = FALSE) - # createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = TRUE, MDR = FALSE) - # } - # +# createAllMLResultDir <- function(path) { +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = NULL, LOO = FALSE, cross_test = FALSE, MDR = TRUE) +# createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "year", LOO = TRUE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = FALSE, cross_test = TRUE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = FALSE, MDR = FALSE) +# createMLResultDir(path, stratify_by = "country", LOO = TRUE, cross_test = TRUE, MDR = FALSE) +# } +# #' Create machine 
learning input list #' @@ -174,8 +178,9 @@ createMLResultDir <- function(path, #' #' # Cross-test with year stratification #' inputs_ct <- createMLinputList("/path/to/results", -#' stratify_by = "year", -#' cross_test = TRUE) +#' stratify_by = "year", +#' cross_test = TRUE +#' ) #' #' # MDR analysis #' inputs_mdr <- createMLinputList("/path/to/results", MDR = TRUE) @@ -187,10 +192,10 @@ createMLinputList <- function(path, LOO = FALSE, MDR = FALSE, cross_test = FALSE) { - # Validate inputs - if (!is.character(path) || length(path) != 1 || is.na(path)) + if (!is.character(path) || length(path) != 1 || is.na(path)) { stop("`path` must be a valid file path string.") + } path <- normalizePath(path) @@ -225,21 +230,17 @@ createMLinputList <- function(path, # Multi-drug resistance models # ============================ if (MDR) { - parsed <- tibble::tibble(ref_file = files_vec) |> dplyr::mutate( parts = stringr::str_split(basename(ref_file), "_"), - species = purrr::map_chr(parts, ~ .x[1]), - mdr_tag = purrr::map_chr(parts, ~ .x[2]), # always "MDR" + mdr_tag = purrr::map_chr(parts, ~ .x[2]), # always "MDR" phenotype = purrr::map_chr(parts, ~ paste(.x[3:4], collapse = "_")), # Feature is 5th + 6th tokens feature_type = purrr::map_chr(parts, ~ .x[5]), feature_subtype = purrr::map_chr(parts, ~ stringr::str_remove(.x[6], "_sparse.parquet")), - feature = purrr::map2_chr(feature_type, feature_subtype, paste, sep = "_"), - output_prefix = paste0("MDR_", phenotype, "_", feature) ) @@ -247,38 +248,43 @@ createMLinputList <- function(path, dplyr::mutate( test_file = NA_character_, matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # For all other modeling types - # 
============================ + # ============================ + # For all other modeling types + # ============================ } else { - parsed <- tibble::tibble(ref_file = files_vec) |> dplyr::mutate( - parts = stringr::str_split(basename(ref_file), "_"), + parts = stringr::str_split(basename(ref_file), "_"), i_sparse = purrr::map_int(parts, ~ .get_idx(.x, "sparse.parquet")), - i_strat = purrr::map_int(parts, ~ { - if (is.null(stratify_by)) return(NA_integer_) + i_strat = purrr::map_int(parts, ~ { + if (is.null(stratify_by)) { + return(NA_integer_) + } .get_idx(.x, stratify_by) }), # Feature = last two tokens before sparse.parquet feature = purrr::map2_chr(parts, i_sparse, ~ { - i <- .y; x <- .x - if (is.na(i) || i < 3) return(NA_character_) + i <- .y + x <- .x + if (is.na(i) || i < 3) { + return(NA_character_) + } paste(x[(i - 2):(i - 1)], collapse = "_") }), # Drug or drug class extraction drug_or_class = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x + i <- .y + x <- .x # Stratified models if (!is.na(i)) { @@ -304,32 +310,40 @@ createMLinputList <- function(path, # Stratification value (if present) strat_value = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x - if (is.na(i)) return("") + i <- .y + x <- .x + if (is.na(i)) { + return("") + } # default position is two tokens after the strat label j <- i + 2 # if there's an intervening 'leaveout', skip over it if (j <= length(x) && identical(x[j], "leaveout")) j <- j + 1 - if (j <= length(x)) return(x[j]) - "" # no stratification + if (j <= length(x)) { + return(x[j]) + } + "" # no stratification }), # Prefix key for grouping prefix_key = purrr::map2_chr(parts, i_strat, ~ { - i <- .y; x <- .x + i <- .y + x <- .x # Case A: stratified -> prefix before the stratify label if (!is.na(i)) { - if (i - 1 >= 1) return(paste(x[1:(i - 1)], collapse = "_")) + if (i - 1 >= 1) { + return(paste(x[1:(i - 1)], collapse = "_")) + } return("") } # Case B: unstratified -> prefix is first two tokens - if (x[2] == 
"drug" && x[3] != "class"){ + if (x[2] == "drug" && x[3] != "class") { # Case A: Cje_drug_X return(paste(x[1:2], collapse = "_")) } - if (x[2] == "drug" && x[3] == "class"){ + if (x[2] == "drug" && x[3] == "class") { # Case A: Cje_drug_X return(paste(x[1:3], collapse = "_")) } @@ -345,18 +359,17 @@ createMLinputList <- function(path, test_file = NA_character_, output_prefix = gsub("_sparse\\.parquet$", "", basename(ref_file)), matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # Cross-test modeling, no LOO - # ============================ + # ============================ + # Cross-test modeling, no LOO + # ============================ } else if (cross_test && !LOO) { - if (is.null(stratify_by)) { # Case A: stratify_by = NULL, pair across abx within same feature + prefix pairs <- parsed |> @@ -366,8 +379,10 @@ createMLinputList <- function(path, dplyr::select(test_file = ref_file, feature, prefix_key, strat_value, test_drug = drug_or_class), by = c("feature", "prefix_key", "strat_value") ) |> - dplyr::filter(ref_file != test_file, - ref_drug != test_drug) |> + dplyr::filter( + ref_file != test_file, + ref_drug != test_drug + ) |> dplyr::distinct() |> dplyr::mutate( output_prefix = paste0( @@ -380,10 +395,10 @@ createMLinputList <- function(path, out <- pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) @@ -392,30 +407,29 @@ createMLinputList <- function(path, # Case B: stratify_by != 
NULL, pair same drug/class, prefix, feature, # but across different stratification groups pairs <- parsed |> - dplyr::select(ref_file, feature, prefix_key, strat_value, - drug_or_class) |> - + dplyr::select( + ref_file, feature, prefix_key, strat_value, + drug_or_class + ) |> # self-join ONLY on prefix_key, drug/class, feature dplyr::inner_join( parsed |> - dplyr::select(test_file = ref_file, - feature, prefix_key, strat_value_test = strat_value, - drug_or_class), + dplyr::select( + test_file = ref_file, + feature, prefix_key, strat_value_test = strat_value, + drug_or_class + ), by = c("prefix_key", "feature", "drug_or_class") ) |> - # do NOT test file against itself dplyr::filter(ref_file != test_file) |> - # enforce different stratification group dplyr::filter(strat_value != strat_value_test) |> - # remove symmetric duplicates (A,B == B,A) dplyr::rowwise() |> dplyr::mutate(pair_id = paste(sort(c(ref_file, test_file)), collapse = "||")) |> dplyr::ungroup() |> dplyr::distinct(pair_id, .keep_all = TRUE) |> - dplyr::mutate( output_prefix = paste0( prefix_key, "_", @@ -429,19 +443,18 @@ createMLinputList <- function(path, out <- pairs |> dplyr::mutate( matrix_path = paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) - # ============================ - # Cross-test + LOO modeling - # ============================ + # ============================ + # Cross-test + LOO modeling + # ============================ } else if (cross_test && LOO) { - # LOO requires special directory structure resolution test_path <- file.path(path, stringr::str_remove(basename(paths$matrix_path), "^LOO_")) test_path <- normalizePath(test_path) @@ -461,10 +474,10 @@ createMLinputList <- function(path, out <- loo_pairs |> dplyr::mutate( matrix_path = 
paths$matrix_path, - out_perf = paths$ML_performance, - out_top = paths$ML_top_features, - out_models= paths$ML_models, - out_pred = paths$ML_prediction + out_perf = paths$ML_performance, + out_top = paths$ML_top_features, + out_models = paths$ML_models, + out_pred = paths$ML_prediction ) return(out) @@ -472,9 +485,11 @@ createMLinputList <- function(path, } # If we ever get here, something wasn't covered - stop("Unhandled combination of arguments: ", - "MDR=", MDR, ", cross_test=", cross_test, ", LOO=", LOO, - ", stratify_by=", if (is.null(stratify_by)) "NULL" else stratify_by) + stop( + "Unhandled combination of arguments: ", + "MDR=", MDR, ", cross_test=", cross_test, ", LOO=", LOO, + ", stratify_by=", if (is.null(stratify_by)) "NULL" else stratify_by + ) } @@ -544,13 +559,15 @@ createMLinputList <- function(path, #' #' # Run with more threads and minimal output #' runMDRmodels("/path/to/results", -#' threads = 32, -#' verbose = FALSE) +#' threads = 32, +#' verbose = FALSE +#' ) #' #' # Run without saving model fits (save disk space) #' runMDRmodels("/path/to/results", -#' threads = 16, -#' return_fit = FALSE) +#' threads = 16, +#' return_fit = FALSE +#' ) #' } #' #' @seealso @@ -571,12 +588,12 @@ runMDRmodels <- function(path, use_saved_split = TRUE, shuffle_labels = FALSE, use_pca = FALSE) { - files <- createMLinputList(path, - stratify_by = NULL, - LOO = FALSE, - cross_test = FALSE, - MDR = TRUE) + stratify_by = NULL, + LOO = FALSE, + cross_test = FALSE, + MDR = TRUE + ) if (nrow(files) == 0) { message("No MDR files found to process. 
Exiting.") @@ -594,18 +611,19 @@ runMDRmodels <- function(path, # Auto tags for shuffled and PCA shuffle_tag <- if (isTRUE(shuffle_labels)) "shuffled_" else "" - pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" + pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" results_list <- future.apply::future_lapply( seq_len(nrow(files)), FUN = function(i) { - - ref_parquet <- files$ref_file[i] + ref_parquet <- files$ref_file[i] output_prefix <- files$output_prefix[i] if (interactive()) { - message(sprintf("[runMDRmodels] %d/%d: %s", - i, nrow(files), basename(ref_parquet))) + message(sprintf( + "[runMDRmodels] %d/%d: %s", + i, nrow(files), basename(ref_parquet) + )) } ml_input <- loadMLInputTibble(ref_parquet) @@ -619,32 +637,37 @@ runMDRmodels <- function(path, list(split = split, seed = 5280, n_fold = n_fold) } - res <- try({ - runMLPipeline( - ml_input_tibble = ml_input, - test_data = NA, - model = "LR", - split = sp$split, - n_fold = sp$n_fold, - prop_vi_top_feats = prop_vi_top_feats, - n_top_feats = NA, - use_pca = use_pca, - pca_threshold = pca_threshold, - shuffle_labels = shuffle_labels, - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5, - select_best_metric = "mcc", - seed = sp$seed, - verbose = verbose, - return_tune_res = return_tune_res, - return_fit = return_fit, - return_pred = return_pred - ) - }, silent = TRUE) + res <- try( + { + runMLPipeline( + ml_input_tibble = ml_input, + test_data = NA, + model = "LR", + split = sp$split, + n_fold = sp$n_fold, + prop_vi_top_feats = prop_vi_top_feats, + n_top_feats = NA, + use_pca = use_pca, + pca_threshold = pca_threshold, + shuffle_labels = shuffle_labels, + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5, + select_best_metric = "mcc", + seed = sp$seed, + verbose = verbose, + return_tune_res = return_tune_res, + return_fit = return_fit, + return_pred = return_pred + ) + }, + silent = TRUE + ) if (inherits(res, 
"try-error")) { - warning("Model failed for: ", output_prefix, - "\n Error: ", attr(res, "condition")$message) + warning( + "Model failed for: ", output_prefix, + "\n Error: ", attr(res, "condition")$message + ) return(NULL) } @@ -652,19 +675,25 @@ runMDRmodels <- function(path, base <- paste0(shuffle_tag, output_prefix, pca_tag) if (!is.null(res$performance_tibble)) { - readr::write_tsv(res$performance_tibble, - file.path(files$out_perf[i], paste0(base, "_performance.tsv"))) + readr::write_tsv( + res$performance_tibble, + file.path(files$out_perf[i], paste0(base, "_performance.tsv")) + ) } if (!is.null(res$top_feat_tibble)) { - readr::write_tsv(res$top_feat_tibble, - file.path(files$out_top[i], paste0(base, "_top_features.tsv"))) + readr::write_tsv( + res$top_feat_tibble, + file.path(files$out_top[i], paste0(base, "_top_features.tsv")) + ) } if (!is.null(res$fit)) { saveRDS(res$fit, file.path(files$out_models[i], paste0(base, "_model_fit.rds"))) } if (!is.null(res$pred)) { - readr::write_tsv(res$pred, - file.path(files$out_pred[i], paste0(base, "_prediction.tsv"))) + readr::write_tsv( + res$pred, + file.path(files$out_pred[i], paste0(base, "_prediction.tsv")) + ) } NULL @@ -783,21 +812,24 @@ runMDRmodels <- function(path, #' #' # Cross-test with year stratification #' runMLmodels("/path/to/results", -#' stratify_by = "year", -#' cross_test = TRUE, -#' threads = 32) +#' stratify_by = "year", +#' cross_test = TRUE, +#' threads = 32 +#' ) #' #' # LOO analysis stratified by country with cross-testing #' runMLmodels("/path/to/results", -#' stratify_by = "country", -#' LOO = TRUE, -#' cross_test = TRUE, -#' verbose = TRUE) +#' stratify_by = "country", +#' LOO = TRUE, +#' cross_test = TRUE, +#' verbose = TRUE +#' ) #' #' # Run without saving model fits (save disk space) #' runMLmodels("/path/to/results", -#' stratify_by = "year", -#' return_fit = FALSE) +#' stratify_by = "year", +#' return_fit = FALSE +#' ) #' } #' #' @seealso @@ -823,19 +855,21 @@ runMLmodels <- 
function(path, use_saved_split = TRUE, shuffle_labels = FALSE, use_pca = FALSE) { - if (!is.null(stratify_by)) { - if (!is.character(stratify_by) || length(stratify_by) != 1L) + if (!is.character(stratify_by) || length(stratify_by) != 1L) { stop("`stratify_by` must be NULL or a single string: 'year' or 'country'.") - if (!stratify_by %in% c("year", "country")) + } + if (!stratify_by %in% c("year", "country")) { stop("`stratify_by` must be NULL, 'year', or 'country'.") + } } files <- createMLinputList(path, - stratify_by = stratify_by, - LOO = LOO, - MDR = FALSE, - cross_test = cross_test) + stratify_by = stratify_by, + LOO = LOO, + MDR = FALSE, + cross_test = cross_test + ) if (nrow(files) == 0) { message("No files found to process. Exiting.") @@ -864,8 +898,7 @@ runMLmodels <- function(path, strat_suffix <- if (is.null(stratify_by) || identical(stratify_by, "")) { "" } else { - switch( - stratify_by, + switch(stratify_by, "country" = "_country", "year" = "_year", stop("`stratify_by` must be NULL, 'year', or 'country'.") @@ -874,18 +907,19 @@ runMLmodels <- function(path, # Auto naming for shuffled and PCA shuffle_tag <- if (isTRUE(shuffle_labels)) "shuffled_" else "" - pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" + pca_tag <- if (isTRUE(use_pca)) paste0("_pca", as.character(pca_threshold)) else "" results_list <- future.apply::future_lapply( seq_len(nrow(files)), FUN = function(i) { - - ref_parquet <- files$ref_file[i] + ref_parquet <- files$ref_file[i] output_prefix <- files$output_prefix[i] if (interactive()) { - message(sprintf("[runMLmodels] %d/%d: %s", - i, nrow(files), basename(ref_parquet))) + message(sprintf( + "[runMLmodels] %d/%d: %s", + i, nrow(files), basename(ref_parquet) + )) } ml_input <- loadMLInputTibble(ref_parquet) @@ -910,32 +944,37 @@ runMLmodels <- function(path, list(split = split, seed = 5280, n_fold = n_fold) } - res <- try({ - runMLPipeline( - ml_input_tibble = ml_input, - test_data = test_data, - 
model = "LR", - split = sp$split, - n_fold = sp$n_fold, - prop_vi_top_feats = prop_vi_top_feats, - n_top_feats = NA, - use_pca = use_pca, - pca_threshold = pca_threshold, - shuffle_labels = shuffle_labels, - penalty_vec = 10^seq(-4, -1, length.out = 10), - mix_vec = 0:5 / 5, - select_best_metric = "mcc", - seed = sp$seed, - verbose = verbose, - return_tune_res = return_tune_res, - return_fit = return_fit, - return_pred = return_pred - ) - }, silent = TRUE) + res <- try( + { + runMLPipeline( + ml_input_tibble = ml_input, + test_data = test_data, + model = "LR", + split = sp$split, + n_fold = sp$n_fold, + prop_vi_top_feats = prop_vi_top_feats, + n_top_feats = NA, + use_pca = use_pca, + pca_threshold = pca_threshold, + shuffle_labels = shuffle_labels, + penalty_vec = 10^seq(-4, -1, length.out = 10), + mix_vec = 0:5 / 5, + select_best_metric = "mcc", + seed = sp$seed, + verbose = verbose, + return_tune_res = return_tune_res, + return_fit = return_fit, + return_pred = return_pred + ) + }, + silent = TRUE + ) if (inherits(res, "try-error")) { - warning("Model failed for: ", output_prefix, - "\n Error: ", attr(res, "condition")$message) + warning( + "Model failed for: ", output_prefix, + "\n Error: ", attr(res, "condition")$message + ) return(NULL) } @@ -943,19 +982,25 @@ runMLmodels <- function(path, base <- paste0(shuffle_tag, config_prefix, output_prefix, pca_tag, strat_suffix) if (!is.null(res$performance_tibble)) { - readr::write_tsv(res$performance_tibble, - file.path(files$out_perf[i], paste0(base, "_performance.tsv"))) + readr::write_tsv( + res$performance_tibble, + file.path(files$out_perf[i], paste0(base, "_performance.tsv")) + ) } if (!is.null(res$top_feat_tibble)) { - readr::write_tsv(res$top_feat_tibble, - file.path(files$out_top[i], paste0(base, "_top_features.tsv"))) + readr::write_tsv( + res$top_feat_tibble, + file.path(files$out_top[i], paste0(base, "_top_features.tsv")) + ) } if (!is.null(res$fit)) { saveRDS(res$fit, file.path(files$out_models[i], 
paste0(base, "_model_fit.rds"))) } if (!is.null(res$pred)) { - readr::write_tsv(res$pred, - file.path(files$out_pred[i], paste0(base, "_prediction.tsv"))) + readr::write_tsv( + res$pred, + file.path(files$out_pred[i], paste0(base, "_prediction.tsv")) + ) } NULL @@ -973,7 +1018,6 @@ runMLmodels <- function(path, } - #' Run the entire AMR ML pipeline from a parquet-backed DuckDB #' #' This function provides a complete end-to-end AMR machine learning workflow. @@ -1006,11 +1050,12 @@ runModelingPipeline <- function(parquet_duckdb_path, pca_threshold = 0.99, verbose = TRUE, use_saved_split = TRUE) { - parquet_duckdb_path <- normalizePath(parquet_duckdb_path) if (!file.exists(parquet_duckdb_path)) { - stop("Parquet-backed DuckDB at ", parquet_duckdb_path, " not found.\n", - "Are you using `{Bug}.duckdb` instead of `{Bug}_parquet.duckdb?`") + stop( + "Parquet-backed DuckDB at ", parquet_duckdb_path, " not found.\n", + "Are you using `{Bug}.duckdb` instead of `{Bug}_parquet.duckdb?`" + ) } out_root <- dirname(parquet_duckdb_path) @@ -1024,9 +1069,9 @@ runModelingPipeline <- function(parquet_duckdb_path, generateMLInputs( parquet_duckdb_path = parquet_duckdb_path, out_path = out_root, - n_fold = n_fold, - split = split, - min_n = min_n, + n_fold = n_fold, + split = split, + min_n = min_n, verbosity = if (verbose) "minimal" else "debug" ) @@ -1089,12 +1134,13 @@ runModelingPipeline <- function(parquet_duckdb_path, # All done! 
if (verbose) { message("\n=== AMR-ML Pipeline Complete ===") - message("All matrices, models, top feature lists, and performance metrics saved under:\n ", - out_root) + message( + "All matrices, models, top feature lists, and performance metrics saved under:\n ", + out_root + ) message("\nTo inspect model outputs, see directories such as:") message(" ML_performance/, ML_models/, ML_prediction/, ML_top_features/") } invisible(out_root) } - diff --git a/R/run_ml_pipeline.R b/R/run_ml_pipeline.R index 2a97c00..ad9f691 100644 --- a/R/run_ml_pipeline.R +++ b/R/run_ml_pipeline.R @@ -93,20 +93,21 @@ runMLPipeline <- function( .checkArgReturnPred(return_pred) - # Set `n_fold` to `NA` if not using cross-validation. if (split[2] != 0) { n_fold <- NA } # Confirm resolved split params - if (verbose) { - mode <- if (split[2] == 0) "cv" else "splits" - message(sprintf("ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", - mode, split[1], split[2], - ifelse(is.na(n_fold), "NA", as.character(n_fold)), - as.character(seed))) - } + if (verbose) { + mode <- if (split[2] == 0) "cv" else "splits" + message(sprintf( + "ML split mode: %s | split = c(%.2f, %.2f) | n_fold = %s | seed = %s", + mode, split[1], split[2], + ifelse(is.na(n_fold), "NA", as.character(n_fold)), + as.character(seed) + )) + } # Create a variable indicating whether external `test_data` was provided. This # will be set to `TRUE` later if the `test_data` argument is not `NA`. @@ -116,10 +117,10 @@ runMLPipeline <- function( # Determine whether multi-class classification is to be performed. 
if (as.character(.getTargetVarName(ml_input_tibble)) == "resistant_classes") { - multi_class <- TRUE - } else { - multi_class <- FALSE - } + multi_class <- TRUE + } else { + multi_class <- FALSE + } if (model != "LR" & multi_class) { stop(paste( @@ -262,7 +263,7 @@ runMLPipeline <- function( mix_vec = mix_vec ) } - + recipe <- buildRecipe(train_data, use_pca = use_pca, pca_threshold = pca_threshold @@ -421,14 +422,16 @@ runMLPipeline <- function( all_results[["fit"]] <- fit } - if(return_pred) { - if(!multi_class){ + if (return_pred) { + if (!multi_class) { all_results[["pred"]] <- test_data_plus_predictions |> - dplyr::select(c(genome_id, .pred_class, .pred_Resistant, - .pred_Susceptible, genome_drug.resistant_phenotype)) - } - all_results[["pred"]] <- test_data_plus_predictions + dplyr::select(c( + genome_id, .pred_class, .pred_Resistant, + .pred_Susceptible, genome_drug.resistant_phenotype + )) } + all_results[["pred"]] <- test_data_plus_predictions + } return(all_results) } diff --git a/vignettes/intro.Rmd b/vignettes/intro.Rmd index 996eb6b..af5bc8e 100644 --- a/vignettes/intro.Rmd +++ b/vignettes/intro.Rmd @@ -264,19 +264,19 @@ ml_tibble_reduced <- removeTopFeats(ml_tibble, top_features) ### Precision-recall curve ```{r plot-prc} -test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) plotPRC(test_data_plus_predictions) ``` ### ROC curve ```{r plot-roc} -test_data_plus_predictions <- readr::read_tsv(results/ML_pred/Sfl_drug_AMP_domains_binary_prediction.tsv) +test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) plotROC(test_data_plus_predictions) ``` ### Variable importance plot ```{r plot-vi} -topfeat <- readr::read_tsv(results/ML_top_features/Sfl_drug_AMP_domains_binary_top_features.tsv) +topfeat <- readr::read_tsv(results / ML_top_features / 
Sfl_drug_AMP_domains_binary_top_features.tsv) plotTopFeatsVI(topfeat) ``` ### Baseline comparison barplot @@ -326,7 +326,6 @@ You can label the top N features to highlight the strongest hits (default is 5) ```{r} plotFishers(fisher_results) plotFishers(fisher_results, alpha = 0.01, label_top_n = 5) - ``` ## Wrapper to run all models @@ -338,14 +337,15 @@ Given a DuckDB file produced by `runDataProcessing()`, it: 5. saves performance metrics, fitted models, predictions, and top feature rankings ``` {r} runModelingPipeline(parquet_duckdb_path, - threads = 16, - n_fold = 5, - split = c(1, 0), - min_n = 25, - prop_vi_top_feats = c(0, 1), - pca_threshold = 0.99, - verbose = TRUE, - use_saved_split = TRUE) + threads = 16, + n_fold = 5, + split = c(1, 0), + min_n = 25, + prop_vi_top_feats = c(0, 1), + pca_threshold = 0.99, + verbose = TRUE, + use_saved_split = TRUE +) ``` Merge the performance and top features of each kind of models into a parquet that will serve as starting data for `amRshiny` package @@ -357,7 +357,7 @@ buildPerformancePq( LOO = FALSE, MDR = FALSE, cross_test = FALSE, - out_parquet = NULL, + out_parquet = NULL, compression = "zstd", verbose = TRUE ) @@ -367,8 +367,8 @@ buildTopFeatsPq( LOO = FALSE, MDR = FALSE, cross_test = FALSE, - out_parquet = NULL, + out_parquet = NULL, compression = "zstd", verbose = TRUE -) +) ``` From d3a97bfff946b1be6abb2cf1b9e3067084bea6f4 Mon Sep 17 00:00:00 2001 From: Abhirupa Ghosh <100681585+AbhirupaGhosh@users.noreply.github.com> Date: Fri, 29 May 2026 16:39:23 -0600 Subject: [PATCH 3/6] Refactor plot functions to accept file inputs --- R/plot_ml.R | 58 +++++++++++++++++++---------------------------------- 1 file changed, 21 insertions(+), 37 deletions(-) diff --git a/R/plot_ml.R b/R/plot_ml.R index 42d51ec..a7b459a 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -24,8 +24,8 @@ NULL #' Plot a Precision-Recall Curve #' #' Generates a precision-recall curve (PRC) for AMR phenotype prediction results. 
-#' @param test_data_plus_predictions A tibble containing test data with added -#' prediction columns, typically the output of `runMLmodels()`. +#' @param test_data_plus_predictions_file A file containing test data with added +#' prediction columns, typically the output of `runMLmodels(return_pred=TRUE)`. #' #' @return A `ggplot2` object showing the precision-recall curve. #' @@ -35,12 +35,13 @@ NULL #' #' @examples #' \dontrun{ -#' test_data_plus_predictions <- readr::read_tsv(results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv) -#' plotPRC(test_data_plus_predictions) +#' test_data_plus_predictions_file <- "results / ML_pred / Sfl_drug_AMP_domains_binary_prediction.tsv" +#' plotPRC(test_data_plus_predictions_file) #' } #' #' @export -plotPRC <- function(test_data_plus_predictions) { +plotPRC <- function(test_data_plus_predictions_file) { + test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) .checkArgTestDataPlusPredictions(test_data_plus_predictions) test_data_plus_predictions <- test_data_plus_predictions |> dplyr::mutate( @@ -66,13 +67,14 @@ plotPRC <- function(test_data_plus_predictions) { #' #' Generates a ROC curve for AMR phenotype prediction results. #' -#' @param test_data_plus_predictions A tibble with test data and prediction -#' columns (output of `runMLmodels()`). +#' @param test_data_plus_predictions_file A file with test data and prediction +#' columns (output of `runMLmodels(return_pred=TRUE)`). #' #' @return A ROC curve plotted using `ggplot2::autoplot()`. 
#' #' @export -plotROC <- function(test_data_plus_predictions) { +plotROC <- function(test_data_plus_predictions_file) { + test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) .checkArgTestDataPlusPredictions(test_data_plus_predictions) test_data_plus_predictions <- test_data_plus_predictions |> dplyr::mutate( @@ -96,13 +98,14 @@ plotROC <- function(test_data_plus_predictions) { #' #' Produces a heatmap visualization of the confusion matrix for AMR predictions. #' -#' @param test_data_plus_predictions A tibble containing true and predicted +#' @param test_data_plus_predictions_file A file containing true and predicted #' phenotype labels. #' #' @return A heatmap (`ggplot2` object) showing the confusion matrix. #' #' @export -plotCM <- function(test_data_plus_predictions) { +plotCM <- function(test_data_plus_predictions_file) { + test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) .checkArgTestDataPlusPredictions(test_data_plus_predictions) test_data_plus_predictions <- test_data_plus_predictions |> dplyr::mutate( @@ -123,32 +126,12 @@ plotCM <- function(test_data_plus_predictions) { ggplot2::autoplot(type = "heatmap") } -#' Plot Density of Predicted Class Probabilities -#' -#' Visualizes how predicted class probabilities differ between resistant and -#' susceptible genome-drug combinations. -#' -#' @param test_data_plus_predictions Tibble with prediction probabilities and -#' true labels. -#' -#' @return A ggplot2 density plot. -#' -#' @export -plotDensity <- function(test_data_plus_predictions) { - test_data_plus_predictions |> - ggplot2::ggplot(ggplot2::aes( - x = .pred_Resistant, - fill = genome_drug.resistant_phenotype - )) + - ggplot2::geom_density(alpha = 0.5) -} - #' Plot Top Feature Importances #' #' Creates a bar plot showing the most important features affecting #' AMR phenotype predictions. 
#' -#' @param topfeat A tibble containing feature importance scores +#' @param topfeat_file A file containing feature importance scores #' (output of `runMLmodels()`). #' @param n_top_feats Number of top features to display (default: 10). #' @@ -156,12 +139,13 @@ plotDensity <- function(test_data_plus_predictions) { #' #' @examples #' \dontrun{ -#' topfeat <- readr::read_tsv(results / ML_top_features / Sfl_drug_AMP_domains_binary_top_features.tsv) -#' plotTopFeatsVI(topfeat) +#' topfeat_file <- "results / ML_top_features / Sfl_drug_AMP_domains_binary_top_features.tsv" +#' plotTopFeatsVI(topfeat_file) #' } #' #' @export -plotTopFeatsVI <- function(topfeat, n_top_feats = 10) { +plotTopFeatsVI <- function(topfeat_file, n_top_feats = 10) { + topfeat <- readr::read_tsv(topfeat_file) .checkArgNTopFeats(n_top_feats) vip <- topfeat |> @@ -279,8 +263,8 @@ plotFishers <- function( stop("Missing required columns: ", paste(missing_cols, collapse = ", ")) } - plot_df <- fisher_df %>% - dplyr::arrange(adj_p_value) %>% + plot_df <- fisher_df |> + dplyr::arrange(adj_p_value) |> dplyr::mutate( rank = dplyr::row_number(), neg_log10_adj_p = -log10(adj_p_value), @@ -312,7 +296,7 @@ plotFishers <- function( ) if (label_top_n > 0) { - label_df <- plot_df %>% + label_df <- plot_df |> dplyr::slice_head(n = label_top_n) p <- p + From 241836639dc00c8d4e904242d6428d35064a9f7b Mon Sep 17 00:00:00 2001 From: AbhirupaGhosh Date: Fri, 29 May 2026 22:42:41 +0000 Subject: [PATCH 4/6] Style code (GHA) --- R/plot_ml.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/plot_ml.R b/R/plot_ml.R index a7b459a..6ee976d 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -41,7 +41,7 @@ NULL #' #' @export plotPRC <- function(test_data_plus_predictions_file) { - test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) + test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) .checkArgTestDataPlusPredictions(test_data_plus_predictions) 
test_data_plus_predictions <- test_data_plus_predictions |> dplyr::mutate( @@ -74,7 +74,7 @@ plotPRC <- function(test_data_plus_predictions_file) { #' #' @export plotROC <- function(test_data_plus_predictions_file) { - test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) + test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) .checkArgTestDataPlusPredictions(test_data_plus_predictions) test_data_plus_predictions <- test_data_plus_predictions |> dplyr::mutate( @@ -105,7 +105,7 @@ plotROC <- function(test_data_plus_predictions_file) { #' #' @export plotCM <- function(test_data_plus_predictions_file) { - test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) + test_data_plus_predictions <- readr::read_tsv(test_data_plus_predictions_file) .checkArgTestDataPlusPredictions(test_data_plus_predictions) test_data_plus_predictions <- test_data_plus_predictions |> dplyr::mutate( From f2eb43c427d87e3f732505aa6113665f4f885713 Mon Sep 17 00:00:00 2001 From: Abhirupa Ghosh <100681585+AbhirupaGhosh@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:45:56 -0600 Subject: [PATCH 5/6] Implement drug performance plotting functions in plot_ml.R Add functions to plot drug phenotype distribution, model performance, and cross-drug generalization heatmaps. --- R/plot_ml.R | 798 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 798 insertions(+) diff --git a/R/plot_ml.R b/R/plot_ml.R index 6ee976d..fab0759 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -313,3 +313,801 @@ plotFishers <- function( return(p) } + +#' Plot drug phenotype distribution +#' +#' Reads metadata and generates a stacked bar plot showing counts of resistant +#' and susceptible phenotypes per antibiotic. +#' +#' @param metadata_path Character. Path to directory containing `metadata.parquet`. +#' +#' @return A ggplot object. 
+#' @export +#' +#' @examples +#' plotDrugDist(metadata_path = "data/Campylobacter/") +plotDrugDist <- function(metadata_path = "."){ + +metadata <- arrow::read_parquet(file.path(metadata_path,"metadata.parquet")) + +##################### phenotype distribution (drugs) ######################### +drug_dist <- metadata |> + dplyr::distinct(genome.genome_id, + genome_drug.antibiotic, + drug_abbr, + genome_drug.resistant_phenotype) |> + dplyr::count(genome_drug.antibiotic, + drug_abbr, + genome_drug.resistant_phenotype) |> + dplyr::group_by(genome_drug.antibiotic, drug_abbr) |> + dplyr::mutate(total = sum(n)) |> + dplyr::ungroup() |> + dplyr::mutate( + label = paste0(genome_drug.antibiotic, " (", drug_abbr, ")"), + label = forcats::fct_reorder(label, total) + ) + +p <- ggplot2::ggplot( + drug_dist, + ggplot2::aes( + x = label, + y = n, + fill = genome_drug.resistant_phenotype + ) +) + + ggplot2::geom_col(color = "black", width = 0.8) + + ggplot2::coord_flip() + + ggplot2::scale_fill_manual( + values = c( + "Resistant" = "#d4872a", + "Susceptible" = "#5b8db8" + ), + name = "Phenotype" + ) + + ggplot2::labs( + x = "Antibiotic", + y = "Number of unique genomes" + ) + + ggplot2::theme_classic(base_size = 14) + +p +} + +#' Plot drug-level model performance +#' +#' Generates heatmaps and ridge plots summarizing model performance (MCC) +#' across drugs and feature types. +#' +#' @param metadata_path Character. Path to `metadata.parquet`. +#' @param performance_path Character. Path to `all_performance.parquet`. +#' +#' @return A patchwork ggplot object combining multiple panels. 
+#' @export +#' +#' @examples +#' plotDrugPerf(metadata_path = "data/Campylobacter/", performance_path = "data/Campylobacter/ML_performance/") +plotDrugPerf <- function(metadata_path = ".", performance_path = ".") { + +metadata <- arrow::read_parquet(file.path(metadata_path,"metadata.parquet")) + +performance <- arrow::read_parquet(file.path(performance_path,"all_performance.parquet")) + +######################## drug performances ################################# +median_drug <- performance |> + dplyr::filter( + drug_label == "drug", + !shuffled # keep real models; remove if you want both + ) |> + dplyr::group_by(drug_or_class, feature_type, feature_subtype) |> + dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") |> + dplyr::left_join(plot_df, by = c("drug_or_class" = "drug_abbr")) |> + dplyr::mutate(drug_or_class = reorder(drug_or_class, total)) + +drug_p1 <- ggplot2::ggplot(median_drug, + ggplot2::aes(x = feature_type, + y = drug_or_class, + fill = median_mcc)) + + ggplot2::geom_tile(color = "grey90", width = 0.9) + + + ggplot2::scale_fill_gradientn( + colors = c( + "#C4B8A8", # low + "#FAFAF7", # around 0 + "#5F84C9", # medium/high (~0.7–0.9) + "#0F2A5A" # very dark for ~1 + ), + values = scales::rescale(c(-1, 0, 0.85, 1)), + name = "Best MCC" + ) + + + ggplot2::labs(x = "Feature type") + + ggplot2::theme_minimal(base_size = 12) + + ggplot2::theme( + axis.text = ggplot2::element_text(size = 10, colour = "black"), + axis.title = ggplot2::element_text(size = 12), + axis.title.y = ggplot2::element_blank(), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1), + legend.position = "bottom" + ) + + ggplot2::coord_fixed() + +drug_p1 + +median_feature <- performance |> + dplyr::filter( + drug_label == "drug", + !shuffled + ) |> + dplyr::group_by(drug_or_class, feature_type) |> + dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") |> + dplyr::left_join(plot_df, by = c("drug_or_class" = "drug_abbr")) |> + 
dplyr::mutate(drug_or_class = reorder(drug_or_class, total)) + + +feat_pal <- c( + "args" = "#56B4E9", # sky blue + "cogs" = "#E69F00", # orange + "genes" = "#009E73", # bluish green + "domains" = "#F0E442", # yellow + "proteins" = "#CC79A7", # reddish purple + "struct" = "#D55E00" # vermillion +) + +rc_perf <- ggplot2::ggplot(median_feature |> + dplyr::distinct(drug_or_class, + feature_type, median_mcc), + ggplot2::aes(x = median_mcc, y = drug_or_class)) + + + ggridges::geom_density_ridges( + scale = 0.75, + rel_min_height = 0.01, + alpha = 0.4, + fill = "grey90", + colour = "grey70" + ) + + + ggplot2::geom_point( + position = position_jitter(height = 0.1), + size = 2, + alpha = 0.8, + aes(color = feature_type) + ) + + ggplot2::scale_color_manual(values = feat_pal, name = "Feature type") + + ggplot2::stat_summary( + fun = median, + geom = "point", + size = 2, + color = "black" + ) + + ggplot2::theme_minimal(base_size = 14) + + ggplot2::theme( + axis.text = ggplot2::element_text(size = 10, colour = "black"), + axis.title = ggplot2::element_text(size = 12), + axis.title.y = ggplot2::element_blank(), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1), + legend.position = "right", + panel.grid.minor = ggplot2::element_blank(), + axis.line = ggplot2::element_line(color = "black") + ) + +rc_perf + +final_plot = drug_p1 + + rc_perf + + patchwork::plot_layout( + widths = c(2, 2), # adjust proportions + guides = "collect" + ) & + ggplot2::theme( + legend.position = "bottom" + ) + +final_plot + +} + +#' Plot cross-drug generalization heatmap +#' +#' Creates a heatmap showing cross-drug model performance (MCC), where models +#' trained on one drug are evaluated on another. +#' +#' @param cross_test_performance_path Character. Path to `cross_drug_perf.parquet`. +#' @param drug_performance_path Character. Path to `all_performance.parquet`. +#' +#' @return A ComplexHeatmap object. 
+#' @export +#' +#' @examples +#' plotCrossDrug(cross_test_performance_path = "data/Campylobacter/cross_test_ML_performance", drug_performance_path = "data/Campylobacter/ML_performance/") +plotCrossDrug <- function(cross_test_performance_path = ".", drug_performance_path = ".") { + + cross_drug <- arrow::read_parquet(file.path(cross_test_performance_path,"cross_drug_perf.parquet")) + performance <- arrow::read_parquet(file.path(drug_performance_path,"all_performance.parquet")) + +###################### CROSS DRUG Testing ############################# +heatmap_df <- cross_drug |> + # dplyr::filter(tested_on %in% (cross_drug |> dplyr::pull(drug_or_class))) |> + dplyr::group_by(drug_or_class, tested_on) |> + dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") + +same_drugs <- performance |> + dplyr::filter(drug_label == "drug", + drug_or_class %in% (cross_drug |> + dplyr::distinct(drug_or_class) |> + dplyr::pull())) |> + dplyr::group_by(drug_or_class) |> + dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") |> + dplyr::mutate(tested_on = drug_or_class) |> + dplyr::distinct(drug_or_class, tested_on, median_mcc) + +heatmap_df <- heatmap_df |> + dplyr::add_row(same_drugs) |> + dplyr::left_join(metadata |> + dplyr::distinct(drug_abbr, class_abbr), + by = c("drug_or_class" = "drug_abbr")) |> + dplyr::rename("drug_class" = "class_abbr") |> + dplyr::left_join(metadata |> + dplyr::distinct(drug_abbr, class_abbr), + by = c("tested_on" = "drug_abbr")) + +# Row annotation (already similar to what you did) +annotation_row <- heatmap_df |> + dplyr::distinct(drug_or_class, drug_class) |> + tibble::column_to_rownames("drug_or_class") + +# Column annotation +annotation_col <- heatmap_df |> + dplyr::distinct(tested_on, class_abbr) |> + tibble::column_to_rownames("tested_on") + +mat <- heatmap_df |> + dplyr::select(drug_or_class, tested_on, median_mcc) |> + tidyr::pivot_wider(names_from = tested_on, values_from = median_mcc) |> + 
tibble::column_to_rownames("drug_or_class") |> + as.matrix() + +row_order <- heatmap_df |> + dplyr::distinct(drug_or_class, drug_class) |> + dplyr::arrange(drug_class, drug_or_class) |> + dplyr::pull(drug_or_class) + +col_order <- heatmap_df |> + dplyr::distinct(tested_on, class_abbr) |> + dplyr::arrange(class_abbr, tested_on) |> + dplyr::pull(tested_on) + +# mat[is.na(mat)] <- 0 +mat <- mat[row_order, col_order] + +# Align annotations +annotation_row <- annotation_row[row_order, , drop = FALSE] +annotation_col <- annotation_col[col_order, , drop = FALSE] + + +# Collect all classes from both row and column +classes <- base::union( + annotation_row$drug_class, + annotation_col$class_abbr +) + +# Create ONE named color vector +class_colors <- stats::setNames( + scales::hue_pal() (length(classes)), + classes +) + +heat_colors <- colorRampPalette(RColorBrewer::brewer.pal(11, "RdBu"))(100) + +# ---- Convert annotations ---- +ha_row <- ComplexHeatmap::rowAnnotation( + drug_class = annotation_row$drug_class, + col = list(drug_class = class_colors), + show_annotation_name = FALSE, + show_legend = FALSE +) + +ha_col <- ComplexHeatmap::HeatmapAnnotation( + class_abbr = annotation_col$class_abbr, + col = list(class_abbr = class_colors), + show_annotation_name = FALSE, na_col = "grey3" +) + +# ---- Color function (instead of breaks + palette) ---- +col_fun <- circlize::colorRamp2( + seq(-max_val, max_val, length.out = length(heat_colors)), + heat_colors +) +# ---- Heatmap ---- +cross_drug_hm <- ComplexHeatmap::Heatmap( + mat, + name = "median_mcc", + col = col_fun, + cluster_rows = FALSE, + cluster_columns = FALSE, + row_order = row_order, + column_order = col_order, + left_annotation = ha_row, + top_annotation = ha_col, + show_row_names = TRUE, + show_column_names = TRUE, + column_title = "tested on", + row_title = "trained on", + column_title_side = "bottom", + row_title_side = "right", + row_names_gp = grid::gpar(fontsize = 14), + column_names_gp = grid::gpar(fontsize = 
14), + column_names_rot = 0, + + # remove borders like pheatmap + rect_gp = grid::gpar(col = NA), + + # legends + show_heatmap_legend = TRUE +) + +cross_drug_hm +} + +#' Plot stratified model performance +#' +#' Visualizes model performance (MCC) stratified by year or country, +#' comparing within-group vs cross-group evaluation. +#' +#' @param year_or_country Character. Either "year" or "country". +#' @param stratified_performance_path Character. Path to stratified performance files. +#' @param stratified_cross_performance_path Character. Path to cross-stratified performance files. +#' +#' @return A ggplot object. +#' @export +#' +#' @examples +#' plotStratifiedPerf("year", stratified_performance_path = "data/Campylobacter/ML_year_performance", +#' stratified_cross_performance_path = "data/Campylobacter/cross_test_ML_year_performance") +plotStratifiedPerf <- function(year_or_country = "year", + stratified_performance_path = ".", + stratified_cross_performance_path = ".") { + + perf <- arrow::read_parquet(file.path(stratified_performance_path, + paste0(year_or_country,"_perf.parquet"))) + + cross_test <- arrow::read_parquet(file.path(stratified_cross_performance_path, + paste0("cross_", + year_or_country, + "_perf.parquet"))) +if(year_or_country == "year") { +all <- perf |> + dplyr::rename("train_year" = "strat_value") |> + dplyr::mutate(test_year = train_year) |> + dplyr::select(drug_label, drug_or_class, + train_year, test_year, feature_type, feature_subtype, mcc) |> + dplyr::bind_rows(cross_test |> + dplyr::select(drug_label, drug_or_class, + train_year, test_year, feature_type, + feature_subtype, mcc)) |> + dplyr::mutate(category = dplyr::if_else( + train_year == test_year, "same year bin", "different year bin")) +} +else { +all <- perf |> + dplyr::rename("train_country" = "strat_value") |> + dplyr::mutate(test_country = train_country) |> + dplyr::select(drug_label, drug_or_class, + train_country, test_country, + feature_type, feature_subtype, mcc) |> + 
dplyr::bind_rows(cross_test |> + dplyr::select(drug_label, drug_or_class, + train_country, test_country, + feature_type, feature_subtype, mcc)) |> + dplyr::mutate(category = dplyr::if_else( + train_country == test_country, "same country", "different country")) +} + + fill_vals <- if (year_or_country == "year") { + c( + "same year bin" = "#b3cde3", + "different year bin" = "#fbb4ae" + ) + } else { + c( + "same country" = "#b3cde3", + "different country" = "#fbb4ae" + ) + } + + plot <- ggplot2::ggplot( + all |> + dplyr::filter(drug_label == "drug", !is.na(mcc)), + ggplot2::aes(x = mcc, y = drug_or_class, fill = category) + ) + + ggridges::geom_density_ridges( + alpha = 0.5, + scale = 1, + rel_min_height = 0.01, + position = "identity" + ) + + ggplot2::geom_vline(xintercept = 0, linetype = "dashed", color = "black") + + ggplot2::scale_fill_manual(values = fill_vals) + + ggplot2::theme_minimal(base_size = 14) + + ggplot2::labs( + title = if (year_or_country == "year") { + "Temporal performance by drug" + } else { + "Geographical performance by drug" + }, + x = "MCC", + y = "Drug", + fill = "Tested on" + ) + + ggplot2::theme( + axis.title = ggplot2::element_text(colour = "black", size = 10), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, size = 10, colour = "black"), + axis.text.y = ggplot2::element_text(size = 10, colour = "black"), + axis.title.y = ggplot2::element_blank(), + legend.title = ggplot2::element_text(size = 12), + legend.text = ggplot2::element_text(size = 10), + legend.position = "bottom", + plot.title = ggplot2::element_text(face = "bold"), + panel.grid.minor = ggplot2::element_blank(), + plot.margin = margin(0, 0, 0, 0) + ) +plot +} + +#' Plot multi-drug resistance (MDR) model performance +#' +#' Generates violin plots of performance, feature importance summaries, +#' and prediction confusion-style visualizations for MDR models. +#' +#' @param MDR_performance_path Character. Path to `MDR_perf.parquet`. 
+#' @param MDR_top_feature_path Character. Path to `MDR_top_features.parquet`. +#' @param MDR_pred_path Character. Path to `MDR_pred.parquet`. +#' +#' @return A list of ggplot objects. +#' @export +#' +#' @examples +#' plotMDR(MDR_performance_path = "data/Campylobacter/MDR_ML_performance", MDR_top_feature_path = "data/Campylobacter/MDR_ML_top_features", +#' MDR_pred_path = "data/Campylobacter/MDR_ML_pred") +plotMDR <- function(MDR_performance_path = ".", MDR_top_feature_path = ".", + MDR_pred_path = ".") { + +MDR_perf <- arrow::read_parquet(file.path(MDR_performance_path, "MDR_perf.parquet")) + +# ---- Violin plot ---- +perf_plot <- ggplot2::ggplot(MDR_perf, + ggplot2::aes(x = feature_type, y = mcc)) + + + # violins (overall distribution per feature type) + ggplot2::geom_violin(fill = "grey85", color = NA, alpha = 0.8) + + + # points (colored by binary vs counts) + ggplot2::geom_jitter( + ggplot2::aes(color = feature_subtype), + width = 0.12, size = 2, alpha = 0.8 + ) + + + ggplot2::scale_color_manual(values = c( + "binary" = "#7B9CB5", + "counts" = "#CC8644" + )) + + + ggplot2::theme_minimal(base_size = 12) + + ggplot2::labs( + # title = "MDR model performances", + # subtitle = "Violin = distribution per feature type; points = binary vs counts", + x = "Feature type", + y = "MCC", + color = "Feature\nsubtype" + ) + + + ggplot2::theme( + legend.position = "right", + plot.title = ggplot2::element_text(face = "bold") + ) + + ggplot2::theme( + axis.title = ggplot2::element_text(colour = "black", size = 10), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, size = 10, colour = "black"), + axis.text.y = ggplot2::element_text(size = 14, colour = "black"), + legend.title = ggplot2::element_text(size = 12), + legend.text = ggplot2::element_text(size = 10), + legend.position = "none", + title = ggplot2::element_text(face = "bold"), + + panel.background = ggplot2::element_blank(), + panel.grid.minor = ggplot2::element_blank(), + panel.grid.major.x = 
ggplot2::element_blank(), # remove vertical lines + panel.grid.major.y = ggplot2::element_line(color = "grey80"), # keep horizontal lines + + axis.line = ggplot2::element_line(color = "black") + ) + + ggplot2::scale_y_continuous(limits = c(0, 1)) + +perf_plot + +MDR_pred <- arrow::read_parquet(file.path(MDR_pred_path, "MDR_pred.parquet")) |> + dplyr::mutate( + diff_top2 = purrr::pmap_dbl(dplyr::across(dplyr::contains(".pred") & dplyr::where(is.numeric)), function(...) { + x <- c(...) + sx <- sort(x, decreasing = TRUE) + sx[1] - sx[2] + }) # Difference between prediction probabilities of top two classes + ) |> + dplyr::select(genome_id, resistant_classes, .pred_class, diff_top2, + feature_type, feature_subtype, seed) |> + dplyr::group_by(resistant_classes, .pred_class, feature_type) |> + dplyr::summarise(mean_margin = mean(diff_top2), n = n(), .groups = "drop") |> + dplyr::group_by(resistant_classes, feature_type) |> # normalize within true class + dplyr::mutate(sum = sum(n), prop = n / sum(n)) |> + dplyr::ungroup() + +MDR_pred_plot <- ggplot2::ggplot(MDR_pred, + ggplot2::aes(x = resistant_classes, + y = .pred_class)) + + ggplot2::geom_tile(ggplot2::aes(fill = prop)) + + ggplot2::geom_point(ggplot2::aes(size = mean_margin), color = "black") + + ggplot2::facet_wrap(~ feature_type) + + ggplot2::scale_fill_distiller( + palette = "RdBu", + direction = 1, # flip with -1 if needed + name = "Prediction proportion" + ) + + ggplot2::labs(x = "true class", y = "predicted class") + + ggplot2::scale_size(range = c(1, 6), name = "Mean margin") + + ggplot2::coord_equal() + + ggplot2::theme_minimal() + + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)) + + ggplot2::theme( + axis.title = ggplot2::element_text(colour = "black", size = 10), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, size = 10, colour = "black"), + axis.text.y = ggplot2::element_text(size = 10, colour = "black"), + legend.title = ggplot2::element_text(size = 12), + 
legend.text = ggplot2::element_text(size = 10), + legend.position = "right", + title = ggplot2::element_text(face = "bold") + ) + +MDR_pred_plot + +# MDR_feat <- arrow::read_parquet(file.path( +# MDR_top_feature_path,"MDR_top_features.parquet")) |> +# pivot_longer(-c(Variable, feature_type, feature_subtype, seed), +# values_to = "Importance", +# names_to = "Resistant_classes") |> +# filter(!Importance == 0) +# +# MDR_feat_clean <- MDR_feat |> +# dplyr::filter(feature_type != "struct") |> +# dplyr::group_by(Resistant_classes, feature_type, feature_subtype, seed) |> +# dplyr::slice_max(Importance, n = top_n, with_ties = FALSE) |> +# dplyr::ungroup() |> +# dplyr::mutate(Variable = gsub( ".NCBIFAM", "", Variable)) |> +# dplyr::mutate(Variable = gsub("^X", "", Variable)) |> +# dplyr::mutate(Variable = dplyr::if_else( +# feature_type == "domains", gsub("_.*", "", Variable), Variable)) |> +# dplyr::mutate(Variable = dplyr::if_else( +# feature_type == "proteins", gsub("fig.", "fig|", Variable), Variable)) |> +# dplyr::left_join(cluster_feature, by = c("Variable" = "feature")) |> +# dplyr::mutate( +# cluster = dplyr::coalesce(cluster, Variable) +# ) +# +# cluster_df <- MDR_feat_clean |> +# dplyr::group_by(Resistant_classes, cluster) |> +# dplyr::summarise( +# Importance = median(Importance, na.rm = TRUE), +# .groups = "drop" +# ) +# +# top_clusters <- cluster_df |> +# group_by(Resistant_classes) |> +# group_modify(~{ +# +# df <- .x +# +# top_pos <- df |> +# arrange(desc(Importance)) |> +# slice_head(n = 10) +# +# top_neg <- df |> +# arrange(Importance) |> +# slice_head(n = 10) +# +# bind_rows(top_pos, top_neg) +# }) |> +# ungroup() +# +# top_clusters <- top_clusters |> +# dplyr::left_join(protein_names, by = c("cluster" = "proteinID")) |> +# dplyr::mutate( +# proteinName = dplyr::coalesce(proteinName, cluster), # fallback +# proteinName = stringr::str_trunc(proteinName, 50) +# ) |> +# dplyr::distinct(Resistant_classes, proteinName, Importance) |> +# # ✅ reorder AFTER naming 
+# dplyr::group_by(Resistant_classes) |> +# dplyr::mutate( +# proteinName = forcats::fct_reorder(proteinName, Importance) +# ) |> +# dplyr::ungroup() +# +# ggplot(top_clusters, +# aes(x = Importance, y = proteinName)) + +# +# # line (lollipop stem) +# geom_segment( +# aes(x = 0, xend = Importance, +# y = proteinName, yend = proteinName), +# color = "grey60" +# ) + +# +# # dot +# geom_point( +# aes(color = Importance > 0), +# size = 3 +# ) + +# +# facet_wrap(~ Resistant_classes, scales = "free_y") + +# +# scale_color_manual( +# values = c("TRUE" = "#5b8db8", # positive +# "FALSE" = "#d4872a"), # negative +# guide = "none" +# ) + +# +# theme_minimal(base_size = 13) + +# labs( +# x = "Median importance", +# y = "Cluster" +# ) + +# theme( +# panel.grid.minor = element_blank(), +# strip.text = element_text(face = "bold") +# ) + + +} + +#' Compare shuffled vs real model performance +#' +#' Creates boxplots comparing performance (MCC) between real and shuffled labels +#' across feature types. +#' +#' @param metadata_path Character. Path to `metadata.parquet`. +#' @param performance_path Character. Path to `all_performance.parquet`. +#' +#' @return A ggplot object. 
+#' @export +#' +#' @examples +#' plotShuffleVsReal(metadata_path = "data/Campylobacter/", performance_path = "data/Campylobacter/ML_performance") +plotShuffleVsReal <- function(metadata_path = ".", performance_path = ".") { + + metadata <- arrow::read_parquet(file.path(metadata_path,"metadata.parquet")) + performance <- arrow::read_parquet(file.path(performance_path,"all_performance.parquet")) + + performance |> + dplyr::mutate( + shuffled_label = dplyr::if_else(shuffled, "shuffled", "real") + ) |> + ggplot2::ggplot(ggplot2::aes(x = feature_subtype, y = mcc, fill = shuffled_label)) + + ggplot2::geom_boxplot( + width = 0.55, outlier.size = 0.8, outlier.alpha = 0.4, + outlier.color = "grey50", linewidth = 0.4 + ) + + ggplot2::geom_hline(yintercept = 0, linetype = "dashed", color = "grey60", linewidth = 0.4) + + ggplot2::scale_fill_manual( + values = c("real" = "#7B9CB5", "shuffled" = "#C4B8A8"), + name = NULL + ) + + ggplot2::scale_y_continuous(limits = c(-0.2, 1), breaks = seq(-0.2, 1, 0.2)) + + ggplot2::facet_wrap(~ feature_type, nrow = 1) + + ggplot2::theme_minimal(base_size = 12) + + theme( + panel.grid.major.x = element_blank(), + panel.grid.minor = element_blank(), + panel.grid.major.y = element_line(color = "#E5E2D9", linewidth = 0.4), + strip.text = element_text(color = "grey30", face = "bold", size = 10), + # strip.background = element_rect(fill = "#EEEAE0", color = NA), + axis.text = element_text(color = "grey45"), + legend.position = "top", + legend.text = element_text(color = "grey40", size = 10) + ) + + labs( + x = NULL, y = "MCC" + ) + +} + +#' Plot top contributing feature clusters +#' +#' Identifies top contributing clusters across feature types and drugs, +#' and visualizes their relative contributions. +#' +#' @param top_feat_path Character. Path to `all_top_features.parquet`. +#' @param cluster_feature_path Character. Path to `cluster_feature.parquet`. +#' @param protein_names_path Character. Path to `protein_names.parquet`. 
+#' @param top_n Integer. Number of top features to retain per model. +#' +#' @return A ggplot object. +#' @export +#' +#' @examples +#' plotTopClusters(top_feat_path = "data/Campylobacter/ML_top_features", cluster_feature_path = "data/Campylobacter/", +#' protein_names_path = "data/Campylobacter/", top_n = 10) +plotTopClusters <- function(top_feat_path = ".", cluster_feature_path = ".", + protein_names_path = ".", top_n = 10) { + + ################### Top features ######################### + + top_feat <- arrow::read_parquet(file.path(top_feat_path, "all_top_features.parquet")) + cluster_feature <- arrow::read_parquet(file.path(cluster_feature_path, "cluster_feature.parquet")) + protein_names <- arrow::read_parquet(file.path(protein_names_path, "protein_names.parquet")) + + # which clusters appear in top n across feature types per drug + # join top features with cluster mapping, filter out struct and shuffled + top_feat_clean <- top_feat |> + dplyr::filter(!shuffled, feature_type != "struct", drug_label == "drug") |> + dplyr::group_by(drug_or_class, feature_type, feature_subtype, seed) |> + dplyr::slice_max(Importance, n = top_n, with_ties = FALSE) |> + dplyr::ungroup() |> + dplyr::mutate(Variable = gsub( ".NCBIFAM", "", Variable)) |> + dplyr::mutate(Variable = gsub("^X", "", Variable)) |> + dplyr::mutate(Variable = dplyr::if_else( + feature_type == "domains", gsub("_.*", "", Variable), Variable)) |> + dplyr::mutate(Variable = dplyr::if_else( + feature_type == "proteins", gsub("fig.", "fig|", Variable), Variable)) |> + dplyr::left_join(cluster_feature, by = c("Variable" = "feature")) |> + dplyr::mutate( + cluster = dplyr::coalesce(cluster, Variable), # fallback to Variable if no match + Importance_signed = dplyr::if_else(Sign == "NEG", -Importance, Importance) + ) + + shared_mat <- top_feat_clean |> + dplyr::group_by(drug_or_class, feature_type, cluster) |> + dplyr::summarise(abs_imp = median(Importance, na.rm = TRUE), .groups = "drop") |> + + # convert to 
contribution within each feature_type + dplyr::group_by(drug_or_class, feature_type) |> + dplyr::mutate(contribution = abs_imp / sum(abs_imp, na.rm = TRUE)) |> + + # pick top n contributors + dplyr::slice_max(contribution, n = top_n, with_ties = FALSE) |> + dplyr::ungroup() |> + + dplyr::add_count(drug_or_class, cluster, name = "n_feat_types") |> + dplyr::left_join(protein_names, by = c("cluster" = "proteinID")) |> + dplyr::mutate( + proteinName = stringr::str_trunc(proteinName, 50), + proteinName = forcats::fct_reorder(proteinName, n_feat_types) + ) + + feat_plot <- ggplot2::ggplot(shared_mat, + ggplot2::aes(x = feature_type, + y = proteinName, + fill = contribution)) + + ggplot2::geom_tile(color = "#FAFAF7", linewidth = 0.5, width = 0.9, height = 0.9) + + # coord_fixed() + + ggplot2::scale_fill_distiller( + palette = "RdPu", + direction = 1, + name = "contribution", + na.value = "#EEEAE0" + ) + + ggplot2::facet_wrap(~ drug_or_class, scales = "free_y") + + ggplot2::theme_minimal(base_size = 12) + + ggplot2::theme( + panel.grid.major.x = ggplot2::element_blank(), + panel.grid.minor = ggplot2::element_blank(), + panel.grid.major.y = ggplot2::element_line(color = "#E5E2D9", linewidth = 0.4), + strip.text = ggplot2::element_text(color = "grey30", face = "bold", size = 10), + strip.background = ggplot2::element_rect(fill = "#EEEAE0", color = NA), + axis.title.y = ggplot2::element_blank(), + axis.text.x = ggplot2::element_text(color = "black", angle = 30, hjust = 1, size = 6), + axis.text.y = ggplot2::element_text(color = "black", size = 6), + legend.position = "bottom", + legend.text = ggplot2::element_text(color = "grey40", size = 10), + legend.title = ggplot2::element_text(color = "grey40", size = 10) + ) + + feat_plot +} From 7ece64468621c326cfc65b75973397381bf8c815 Mon Sep 17 00:00:00 2001 From: AbhirupaGhosh Date: Mon, 1 Jun 2026 23:47:29 +0000 Subject: [PATCH 6/6] Style code (GHA) --- R/plot_ml.R | 1163 ++++++++++++++++++++++++++------------------------- 1 file 
changed, 597 insertions(+), 566 deletions(-) diff --git a/R/plot_ml.R b/R/plot_ml.R index fab0759..825b488 100644 --- a/R/plot_ml.R +++ b/R/plot_ml.R @@ -326,51 +326,54 @@ plotFishers <- function( #' #' @examples #' plotDrugDist(metadata_path = "data/Campylobacter/") -plotDrugDist <- function(metadata_path = "."){ - -metadata <- arrow::read_parquet(file.path(metadata_path,"metadata.parquet")) - -##################### phenotype distribution (drugs) ######################### -drug_dist <- metadata |> - dplyr::distinct(genome.genome_id, - genome_drug.antibiotic, - drug_abbr, - genome_drug.resistant_phenotype) |> - dplyr::count(genome_drug.antibiotic, - drug_abbr, - genome_drug.resistant_phenotype) |> - dplyr::group_by(genome_drug.antibiotic, drug_abbr) |> - dplyr::mutate(total = sum(n)) |> - dplyr::ungroup() |> - dplyr::mutate( - label = paste0(genome_drug.antibiotic, " (", drug_abbr, ")"), - label = forcats::fct_reorder(label, total) - ) +plotDrugDist <- function(metadata_path = ".") { + metadata <- arrow::read_parquet(file.path(metadata_path, "metadata.parquet")) + + ##################### phenotype distribution (drugs) ######################### + drug_dist <- metadata |> + dplyr::distinct( + genome.genome_id, + genome_drug.antibiotic, + drug_abbr, + genome_drug.resistant_phenotype + ) |> + dplyr::count( + genome_drug.antibiotic, + drug_abbr, + genome_drug.resistant_phenotype + ) |> + dplyr::group_by(genome_drug.antibiotic, drug_abbr) |> + dplyr::mutate(total = sum(n)) |> + dplyr::ungroup() |> + dplyr::mutate( + label = paste0(genome_drug.antibiotic, " (", drug_abbr, ")"), + label = forcats::fct_reorder(label, total) + ) -p <- ggplot2::ggplot( - drug_dist, - ggplot2::aes( - x = label, - y = n, - fill = genome_drug.resistant_phenotype - ) -) + - ggplot2::geom_col(color = "black", width = 0.8) + - ggplot2::coord_flip() + - ggplot2::scale_fill_manual( - values = c( - "Resistant" = "#d4872a", - "Susceptible" = "#5b8db8" - ), - name = "Phenotype" - ) + - ggplot2::labs( - 
x = "Antibiotic", - y = "Number of unique genomes" + p <- ggplot2::ggplot( + drug_dist, + ggplot2::aes( + x = label, + y = n, + fill = genome_drug.resistant_phenotype + ) ) + - ggplot2::theme_classic(base_size = 14) + ggplot2::geom_col(color = "black", width = 0.8) + + ggplot2::coord_flip() + + ggplot2::scale_fill_manual( + values = c( + "Resistant" = "#d4872a", + "Susceptible" = "#5b8db8" + ), + name = "Phenotype" + ) + + ggplot2::labs( + x = "Antibiotic", + y = "Number of unique genomes" + ) + + ggplot2::theme_classic(base_size = 14) -p + p } #' Plot drug-level model performance @@ -387,123 +390,125 @@ p #' @examples #' plotDrugPerf(metadata_path = "data/Campylobacter/", performance_path = "data/Campylobacter/ML_performance/") plotDrugPerf <- function(metadata_path = ".", performance_path = ".") { - -metadata <- arrow::read_parquet(file.path(metadata_path,"metadata.parquet")) - -performance <- arrow::read_parquet(file.path(performance_path,"all_performance.parquet")) - -######################## drug performances ################################# -median_drug <- performance |> - dplyr::filter( - drug_label == "drug", - !shuffled # keep real models; remove if you want both - ) |> - dplyr::group_by(drug_or_class, feature_type, feature_subtype) |> - dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") |> - dplyr::left_join(plot_df, by = c("drug_or_class" = "drug_abbr")) |> - dplyr::mutate(drug_or_class = reorder(drug_or_class, total)) - -drug_p1 <- ggplot2::ggplot(median_drug, - ggplot2::aes(x = feature_type, - y = drug_or_class, - fill = median_mcc)) + - ggplot2::geom_tile(color = "grey90", width = 0.9) + - - ggplot2::scale_fill_gradientn( - colors = c( - "#C4B8A8", # low - "#FAFAF7", # around 0 - "#5F84C9", # medium/high (~0.7–0.9) - "#0F2A5A" # very dark for ~1 - ), - values = scales::rescale(c(-1, 0, 0.85, 1)), - name = "Best MCC" - ) + - - ggplot2::labs(x = "Feature type") + - ggplot2::theme_minimal(base_size = 12) + - ggplot2::theme( - 
axis.text = ggplot2::element_text(size = 10, colour = "black"), - axis.title = ggplot2::element_text(size = 12), - axis.title.y = ggplot2::element_blank(), - axis.text.x = ggplot2::element_text(angle = 45, hjust = 1), - legend.position = "bottom" - ) + - ggplot2::coord_fixed() + metadata <- arrow::read_parquet(file.path(metadata_path, "metadata.parquet")) -drug_p1 + performance <- arrow::read_parquet(file.path(performance_path, "all_performance.parquet")) -median_feature <- performance |> - dplyr::filter( - drug_label == "drug", - !shuffled - ) |> - dplyr::group_by(drug_or_class, feature_type) |> - dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") |> - dplyr::left_join(plot_df, by = c("drug_or_class" = "drug_abbr")) |> - dplyr::mutate(drug_or_class = reorder(drug_or_class, total)) - - -feat_pal <- c( - "args" = "#56B4E9", # sky blue - "cogs" = "#E69F00", # orange - "genes" = "#009E73", # bluish green - "domains" = "#F0E442", # yellow - "proteins" = "#CC79A7", # reddish purple - "struct" = "#D55E00" # vermillion -) - -rc_perf <- ggplot2::ggplot(median_feature |> - dplyr::distinct(drug_or_class, - feature_type, median_mcc), - ggplot2::aes(x = median_mcc, y = drug_or_class)) + - - ggridges::geom_density_ridges( - scale = 0.75, - rel_min_height = 0.01, - alpha = 0.4, - fill = "grey90", - colour = "grey70" - ) + - - ggplot2::geom_point( - position = position_jitter(height = 0.1), - size = 2, - alpha = 0.8, - aes(color = feature_type) - ) + - ggplot2::scale_color_manual(values = feat_pal, name = "Feature type") + - ggplot2::stat_summary( - fun = median, - geom = "point", - size = 2, - color = "black" + ######################## drug performances ################################# + median_drug <- performance |> + dplyr::filter( + drug_label == "drug", + !shuffled # keep real models; remove if you want both + ) |> + dplyr::group_by(drug_or_class, feature_type, feature_subtype) |> + dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = 
"drop") |> + dplyr::left_join(plot_df, by = c("drug_or_class" = "drug_abbr")) |> + dplyr::mutate(drug_or_class = reorder(drug_or_class, total)) + + drug_p1 <- ggplot2::ggplot( + median_drug, + ggplot2::aes( + x = feature_type, + y = drug_or_class, + fill = median_mcc + ) ) + - ggplot2::theme_minimal(base_size = 14) + - ggplot2::theme( - axis.text = ggplot2::element_text(size = 10, colour = "black"), - axis.title = ggplot2::element_text(size = 12), - axis.title.y = ggplot2::element_blank(), - axis.text.x = ggplot2::element_text(angle = 45, hjust = 1), - legend.position = "right", - panel.grid.minor = ggplot2::element_blank(), - axis.line = ggplot2::element_line(color = "black") - ) + ggplot2::geom_tile(color = "grey90", width = 0.9) + + ggplot2::scale_fill_gradientn( + colors = c( + "#C4B8A8", # low + "#FAFAF7", # around 0 + "#5F84C9", # medium/high (~0.7–0.9) + "#0F2A5A" # very dark for ~1 + ), + values = scales::rescale(c(-1, 0, 0.85, 1)), + name = "Best MCC" + ) + + ggplot2::labs(x = "Feature type") + + ggplot2::theme_minimal(base_size = 12) + + ggplot2::theme( + axis.text = ggplot2::element_text(size = 10, colour = "black"), + axis.title = ggplot2::element_text(size = 12), + axis.title.y = ggplot2::element_blank(), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1), + legend.position = "bottom" + ) + + ggplot2::coord_fixed() -rc_perf + drug_p1 -final_plot = drug_p1 + - rc_perf + - patchwork::plot_layout( - widths = c(2, 2), # adjust proportions - guides = "collect" - ) & - ggplot2::theme( - legend.position = "bottom" + median_feature <- performance |> + dplyr::filter( + drug_label == "drug", + !shuffled + ) |> + dplyr::group_by(drug_or_class, feature_type) |> + dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") |> + dplyr::left_join(plot_df, by = c("drug_or_class" = "drug_abbr")) |> + dplyr::mutate(drug_or_class = reorder(drug_or_class, total)) + + + feat_pal <- c( + "args" = "#56B4E9", # sky blue + "cogs" = "#E69F00", # orange 
+ "genes" = "#009E73", # bluish green + "domains" = "#F0E442", # yellow + "proteins" = "#CC79A7", # reddish purple + "struct" = "#D55E00" # vermillion ) -final_plot + rc_perf <- ggplot2::ggplot( + median_feature |> + dplyr::distinct( + drug_or_class, + feature_type, median_mcc + ), + ggplot2::aes(x = median_mcc, y = drug_or_class) + ) + + ggridges::geom_density_ridges( + scale = 0.75, + rel_min_height = 0.01, + alpha = 0.4, + fill = "grey90", + colour = "grey70" + ) + + ggplot2::geom_point( + position = position_jitter(height = 0.1), + size = 2, + alpha = 0.8, + aes(color = feature_type) + ) + + ggplot2::scale_color_manual(values = feat_pal, name = "Feature type") + + ggplot2::stat_summary( + fun = median, + geom = "point", + size = 2, + color = "black" + ) + + ggplot2::theme_minimal(base_size = 14) + + ggplot2::theme( + axis.text = ggplot2::element_text(size = 10, colour = "black"), + axis.title = ggplot2::element_text(size = 12), + axis.title.y = ggplot2::element_blank(), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1), + legend.position = "right", + panel.grid.minor = ggplot2::element_blank(), + axis.line = ggplot2::element_line(color = "black") + ) + rc_perf + + final_plot <- drug_p1 + + rc_perf + + patchwork::plot_layout( + widths = c(2, 2), # adjust proportions + guides = "collect" + ) & + ggplot2::theme( + legend.position = "bottom" + ) + + final_plot } #' Plot cross-drug generalization heatmap @@ -520,132 +525,137 @@ final_plot #' @examples #' plotCrossDrug(cross_test_performance_path = "data/Campylobacter/cross_test_ML_performance", drug_performance_path = "data/Campylobacter/ML_performance/") plotCrossDrug <- function(cross_test_performance_path = ".", drug_performance_path = ".") { - - cross_drug <- arrow::read_parquet(file.path(cross_test_performance_path,"cross_drug_perf.parquet")) - performance <- arrow::read_parquet(file.path(drug_performance_path,"all_performance.parquet")) - -###################### CROSS DRUG Testing 
############################# -heatmap_df <- cross_drug |> - # dplyr::filter(tested_on %in% (cross_drug |> dplyr::pull(drug_or_class))) |> - dplyr::group_by(drug_or_class, tested_on) |> - dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") - -same_drugs <- performance |> - dplyr::filter(drug_label == "drug", - drug_or_class %in% (cross_drug |> - dplyr::distinct(drug_or_class) |> - dplyr::pull())) |> - dplyr::group_by(drug_or_class) |> - dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") |> - dplyr::mutate(tested_on = drug_or_class) |> - dplyr::distinct(drug_or_class, tested_on, median_mcc) - -heatmap_df <- heatmap_df |> - dplyr::add_row(same_drugs) |> - dplyr::left_join(metadata |> - dplyr::distinct(drug_abbr, class_abbr), - by = c("drug_or_class" = "drug_abbr")) |> - dplyr::rename("drug_class" = "class_abbr") |> - dplyr::left_join(metadata |> - dplyr::distinct(drug_abbr, class_abbr), - by = c("tested_on" = "drug_abbr")) - -# Row annotation (already similar to what you did) -annotation_row <- heatmap_df |> - dplyr::distinct(drug_or_class, drug_class) |> - tibble::column_to_rownames("drug_or_class") - -# Column annotation -annotation_col <- heatmap_df |> - dplyr::distinct(tested_on, class_abbr) |> - tibble::column_to_rownames("tested_on") - -mat <- heatmap_df |> - dplyr::select(drug_or_class, tested_on, median_mcc) |> - tidyr::pivot_wider(names_from = tested_on, values_from = median_mcc) |> - tibble::column_to_rownames("drug_or_class") |> - as.matrix() - -row_order <- heatmap_df |> - dplyr::distinct(drug_or_class, drug_class) |> - dplyr::arrange(drug_class, drug_or_class) |> - dplyr::pull(drug_or_class) - -col_order <- heatmap_df |> - dplyr::distinct(tested_on, class_abbr) |> - dplyr::arrange(class_abbr, tested_on) |> - dplyr::pull(tested_on) - -# mat[is.na(mat)] <- 0 -mat <- mat[row_order, col_order] - -# Align annotations -annotation_row <- annotation_row[row_order, , drop = FALSE] -annotation_col <- 
annotation_col[col_order, , drop = FALSE] - - -# Collect all classes from both row and column -classes <- base::union( - annotation_row$drug_class, - annotation_col$class_abbr -) - -# Create ONE named color vector -class_colors <- stats::setNames( - scales::hue_pal() (length(classes)), - classes -) - -heat_colors <- colorRampPalette(RColorBrewer::brewer.pal(11, "RdBu"))(100) - -# ---- Convert annotations ---- -ha_row <- ComplexHeatmap::rowAnnotation( - drug_class = annotation_row$drug_class, - col = list(drug_class = class_colors), - show_annotation_name = FALSE, - show_legend = FALSE -) - -ha_col <- ComplexHeatmap::HeatmapAnnotation( - class_abbr = annotation_col$class_abbr, - col = list(class_abbr = class_colors), - show_annotation_name = FALSE, na_col = "grey3" -) - -# ---- Color function (instead of breaks + palette) ---- -col_fun <- circlize::colorRamp2( - seq(-max_val, max_val, length.out = length(heat_colors)), - heat_colors -) -# ---- Heatmap ---- -cross_drug_hm <- ComplexHeatmap::Heatmap( - mat, - name = "median_mcc", - col = col_fun, - cluster_rows = FALSE, - cluster_columns = FALSE, - row_order = row_order, - column_order = col_order, - left_annotation = ha_row, - top_annotation = ha_col, - show_row_names = TRUE, - show_column_names = TRUE, - column_title = "tested on", - row_title = "trained on", - column_title_side = "bottom", - row_title_side = "right", - row_names_gp = grid::gpar(fontsize = 14), - column_names_gp = grid::gpar(fontsize = 14), - column_names_rot = 0, - - # remove borders like pheatmap - rect_gp = grid::gpar(col = NA), - - # legends - show_heatmap_legend = TRUE -) - -cross_drug_hm + cross_drug <- arrow::read_parquet(file.path(cross_test_performance_path, "cross_drug_perf.parquet")) + performance <- arrow::read_parquet(file.path(drug_performance_path, "all_performance.parquet")) + + ###################### CROSS DRUG Testing ############################# + heatmap_df <- cross_drug |> + # dplyr::filter(tested_on %in% (cross_drug |> 
dplyr::pull(drug_or_class))) |> + dplyr::group_by(drug_or_class, tested_on) |> + dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") + + same_drugs <- performance |> + dplyr::filter( + drug_label == "drug", + drug_or_class %in% (cross_drug |> + dplyr::distinct(drug_or_class) |> + dplyr::pull()) + ) |> + dplyr::group_by(drug_or_class) |> + dplyr::summarise(median_mcc = median(mcc, na.rm = TRUE), .groups = "drop") |> + dplyr::mutate(tested_on = drug_or_class) |> + dplyr::distinct(drug_or_class, tested_on, median_mcc) + + heatmap_df <- heatmap_df |> + dplyr::add_row(same_drugs) |> + dplyr::left_join( + metadata |> + dplyr::distinct(drug_abbr, class_abbr), + by = c("drug_or_class" = "drug_abbr") + ) |> + dplyr::rename("drug_class" = "class_abbr") |> + dplyr::left_join( + metadata |> + dplyr::distinct(drug_abbr, class_abbr), + by = c("tested_on" = "drug_abbr") + ) + + # Row annotation (already similar to what you did) + annotation_row <- heatmap_df |> + dplyr::distinct(drug_or_class, drug_class) |> + tibble::column_to_rownames("drug_or_class") + + # Column annotation + annotation_col <- heatmap_df |> + dplyr::distinct(tested_on, class_abbr) |> + tibble::column_to_rownames("tested_on") + + mat <- heatmap_df |> + dplyr::select(drug_or_class, tested_on, median_mcc) |> + tidyr::pivot_wider(names_from = tested_on, values_from = median_mcc) |> + tibble::column_to_rownames("drug_or_class") |> + as.matrix() + + row_order <- heatmap_df |> + dplyr::distinct(drug_or_class, drug_class) |> + dplyr::arrange(drug_class, drug_or_class) |> + dplyr::pull(drug_or_class) + + col_order <- heatmap_df |> + dplyr::distinct(tested_on, class_abbr) |> + dplyr::arrange(class_abbr, tested_on) |> + dplyr::pull(tested_on) + + # mat[is.na(mat)] <- 0 + mat <- mat[row_order, col_order] + + # Align annotations + annotation_row <- annotation_row[row_order, , drop = FALSE] + annotation_col <- annotation_col[col_order, , drop = FALSE] + + + # Collect all classes from both row and 
column + classes <- base::union( + annotation_row$drug_class, + annotation_col$class_abbr + ) + + # Create ONE named color vector + class_colors <- stats::setNames( + scales::hue_pal()(length(classes)), + classes + ) + + heat_colors <- colorRampPalette(RColorBrewer::brewer.pal(11, "RdBu"))(100) + + # ---- Convert annotations ---- + ha_row <- ComplexHeatmap::rowAnnotation( + drug_class = annotation_row$drug_class, + col = list(drug_class = class_colors), + show_annotation_name = FALSE, + show_legend = FALSE + ) + + ha_col <- ComplexHeatmap::HeatmapAnnotation( + class_abbr = annotation_col$class_abbr, + col = list(class_abbr = class_colors), + show_annotation_name = FALSE, na_col = "grey3" + ) + + # ---- Color function (instead of breaks + palette) ---- + col_fun <- circlize::colorRamp2( + seq(-max_val, max_val, length.out = length(heat_colors)), + heat_colors + ) + # ---- Heatmap ---- + cross_drug_hm <- ComplexHeatmap::Heatmap( + mat, + name = "median_mcc", + col = col_fun, + cluster_rows = FALSE, + cluster_columns = FALSE, + row_order = row_order, + column_order = col_order, + left_annotation = ha_row, + top_annotation = ha_col, + show_row_names = TRUE, + show_column_names = TRUE, + column_title = "tested on", + row_title = "trained on", + column_title_side = "bottom", + row_title_side = "right", + row_names_gp = grid::gpar(fontsize = 14), + column_names_gp = grid::gpar(fontsize = 14), + column_names_rot = 0, + + # remove borders like pheatmap + rect_gp = grid::gpar(col = NA), + + # legends + show_heatmap_legend = TRUE + ) + + cross_drug_hm } #' Plot stratified model performance @@ -661,47 +671,63 @@ cross_drug_hm #' @export #' #' @examples -#' plotStratifiedPerf("year", stratified_performance_path = "data/Campylobacter/ML_year_performance", -#' stratified_cross_performance_path = "data/Campylobacter/cross_test_ML_year_performance") -plotStratifiedPerf <- function(year_or_country = "year", - stratified_performance_path = ".", +#' plotStratifiedPerf("year", +#' 
stratified_performance_path = "data/Campylobacter/ML_year_performance", +#' stratified_cross_performance_path = "data/Campylobacter/cross_test_ML_year_performance" +#' ) +plotStratifiedPerf <- function(year_or_country = "year", + stratified_performance_path = ".", stratified_cross_performance_path = ".") { - - perf <- arrow::read_parquet(file.path(stratified_performance_path, - paste0(year_or_country,"_perf.parquet"))) - - cross_test <- arrow::read_parquet(file.path(stratified_cross_performance_path, - paste0("cross_", - year_or_country, - "_perf.parquet"))) -if(year_or_country == "year") { -all <- perf |> - dplyr::rename("train_year" = "strat_value") |> - dplyr::mutate(test_year = train_year) |> - dplyr::select(drug_label, drug_or_class, - train_year, test_year, feature_type, feature_subtype, mcc) |> - dplyr::bind_rows(cross_test |> - dplyr::select(drug_label, drug_or_class, - train_year, test_year, feature_type, - feature_subtype, mcc)) |> - dplyr::mutate(category = dplyr::if_else( - train_year == test_year, "same year bin", "different year bin")) -} -else { -all <- perf |> - dplyr::rename("train_country" = "strat_value") |> - dplyr::mutate(test_country = train_country) |> - dplyr::select(drug_label, drug_or_class, - train_country, test_country, - feature_type, feature_subtype, mcc) |> - dplyr::bind_rows(cross_test |> - dplyr::select(drug_label, drug_or_class, - train_country, test_country, - feature_type, feature_subtype, mcc)) |> - dplyr::mutate(category = dplyr::if_else( - train_country == test_country, "same country", "different country")) -} - + perf <- arrow::read_parquet(file.path( + stratified_performance_path, + paste0(year_or_country, "_perf.parquet") + )) + + cross_test <- arrow::read_parquet(file.path( + stratified_cross_performance_path, + paste0( + "cross_", + year_or_country, + "_perf.parquet" + ) + )) + if (year_or_country == "year") { + all <- perf |> + dplyr::rename("train_year" = "strat_value") |> + dplyr::mutate(test_year = train_year) |> + 
dplyr::select( + drug_label, drug_or_class, + train_year, test_year, feature_type, feature_subtype, mcc + ) |> + dplyr::bind_rows(cross_test |> + dplyr::select( + drug_label, drug_or_class, + train_year, test_year, feature_type, + feature_subtype, mcc + )) |> + dplyr::mutate(category = dplyr::if_else( + train_year == test_year, "same year bin", "different year bin" + )) + } else { + all <- perf |> + dplyr::rename("train_country" = "strat_value") |> + dplyr::mutate(test_country = train_country) |> + dplyr::select( + drug_label, drug_or_class, + train_country, test_country, + feature_type, feature_subtype, mcc + ) |> + dplyr::bind_rows(cross_test |> + dplyr::select( + drug_label, drug_or_class, + train_country, test_country, + feature_type, feature_subtype, mcc + )) |> + dplyr::mutate(category = dplyr::if_else( + train_country == test_country, "same country", "different country" + )) + } + fill_vals <- if (year_or_country == "year") { c( "same year bin" = "#b3cde3", @@ -713,7 +739,7 @@ all <- perf |> "different country" = "#fbb4ae" ) } - + plot <- ggplot2::ggplot( all |> dplyr::filter(drug_label == "drug", !is.na(mcc)), @@ -750,7 +776,7 @@ all <- perf |> panel.grid.minor = ggplot2::element_blank(), plot.margin = margin(0, 0, 0, 0) ) -plot + plot } #' Plot multi-drug resistance (MDR) model performance @@ -766,205 +792,208 @@ plot #' @export #' #' @examples -#' plotMDR(MDR_performance_path = "data/Campylobacter/MDR_ML_performance", MDR_top_feature_path = "data/Campylobacter/MDR_ML_top_features", -#' MDR_pred_path = "data/Campylobacter/MDR_ML_pred") -plotMDR <- function(MDR_performance_path = ".", MDR_top_feature_path = ".", +#' plotMDR( +#' MDR_performance_path = "data/Campylobacter/MDR_ML_performance", MDR_top_feature_path = "data/Campylobacter/MDR_ML_top_features", +#' MDR_pred_path = "data/Campylobacter/MDR_ML_pred" +#' ) +plotMDR <- function(MDR_performance_path = ".", MDR_top_feature_path = ".", MDR_pred_path = ".") { - -MDR_perf <- 
arrow::read_parquet(file.path(MDR_performance_path, "MDR_perf.parquet")) - -# ---- Violin plot ---- -perf_plot <- ggplot2::ggplot(MDR_perf, - ggplot2::aes(x = feature_type, y = mcc)) + - - # violins (overall distribution per feature type) - ggplot2::geom_violin(fill = "grey85", color = NA, alpha = 0.8) + - - # points (colored by binary vs counts) - ggplot2::geom_jitter( - ggplot2::aes(color = feature_subtype), - width = 0.12, size = 2, alpha = 0.8 - ) + - - ggplot2::scale_color_manual(values = c( - "binary" = "#7B9CB5", - "counts" = "#CC8644" - )) + - - ggplot2::theme_minimal(base_size = 12) + - ggplot2::labs( - # title = "MDR model performances", - # subtitle = "Violin = distribution per feature type; points = binary vs counts", - x = "Feature type", - y = "MCC", - color = "Feature\nsubtype" - ) + - - ggplot2::theme( - legend.position = "right", - plot.title = ggplot2::element_text(face = "bold") - ) + - ggplot2::theme( - axis.title = ggplot2::element_text(colour = "black", size = 10), - axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, size = 10, colour = "black"), - axis.text.y = ggplot2::element_text(size = 14, colour = "black"), - legend.title = ggplot2::element_text(size = 12), - legend.text = ggplot2::element_text(size = 10), - legend.position = "none", - title = ggplot2::element_text(face = "bold"), - - panel.background = ggplot2::element_blank(), - panel.grid.minor = ggplot2::element_blank(), - panel.grid.major.x = ggplot2::element_blank(), # remove vertical lines - panel.grid.major.y = ggplot2::element_line(color = "grey80"), # keep horizontal lines - - axis.line = ggplot2::element_line(color = "black") - ) + - ggplot2::scale_y_continuous(limits = c(0, 1)) - -perf_plot - -MDR_pred <- arrow::read_parquet(file.path(MDR_pred_path, "MDR_pred.parquet")) |> - dplyr::mutate( - diff_top2 = purrr::pmap_dbl(dplyr::across(dplyr::contains(".pred") & dplyr::where(is.numeric)), function(...) { - x <- c(...) 
- sx <- sort(x, decreasing = TRUE) - sx[1] - sx[2] - }) # Difference between prediction probabilities of top two classes - ) |> - dplyr::select(genome_id, resistant_classes, .pred_class, diff_top2, - feature_type, feature_subtype, seed) |> - dplyr::group_by(resistant_classes, .pred_class, feature_type) |> - dplyr::summarise(mean_margin = mean(diff_top2), n = n(), .groups = "drop") |> - dplyr::group_by(resistant_classes, feature_type) |> # normalize within true class - dplyr::mutate(sum = sum(n), prop = n / sum(n)) |> - dplyr::ungroup() - -MDR_pred_plot <- ggplot2::ggplot(MDR_pred, - ggplot2::aes(x = resistant_classes, - y = .pred_class)) + - ggplot2::geom_tile(ggplot2::aes(fill = prop)) + - ggplot2::geom_point(ggplot2::aes(size = mean_margin), color = "black") + - ggplot2::facet_wrap(~ feature_type) + - ggplot2::scale_fill_distiller( - palette = "RdBu", - direction = 1, # flip with -1 if needed - name = "Prediction proportion" + MDR_perf <- arrow::read_parquet(file.path(MDR_performance_path, "MDR_perf.parquet")) + + # ---- Violin plot ---- + perf_plot <- ggplot2::ggplot( + MDR_perf, + ggplot2::aes(x = feature_type, y = mcc) ) + - ggplot2::labs(x = "true class", y = "predicted class") + - ggplot2::scale_size(range = c(1, 6), name = "Mean margin") + - ggplot2::coord_equal() + - ggplot2::theme_minimal() + - ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)) + - ggplot2::theme( - axis.title = ggplot2::element_text(colour = "black", size = 10), - axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, size = 10, colour = "black"), - axis.text.y = ggplot2::element_text(size = 10, colour = "black"), - legend.title = ggplot2::element_text(size = 12), - legend.text = ggplot2::element_text(size = 10), - legend.position = "right", - title = ggplot2::element_text(face = "bold") - ) -MDR_pred_plot - -# MDR_feat <- arrow::read_parquet(file.path( -# MDR_top_feature_path,"MDR_top_features.parquet")) |> -# pivot_longer(-c(Variable, feature_type, 
feature_subtype, seed), -# values_to = "Importance", -# names_to = "Resistant_classes") |> -# filter(!Importance == 0) -# -# MDR_feat_clean <- MDR_feat |> -# dplyr::filter(feature_type != "struct") |> -# dplyr::group_by(Resistant_classes, feature_type, feature_subtype, seed) |> -# dplyr::slice_max(Importance, n = top_n, with_ties = FALSE) |> -# dplyr::ungroup() |> -# dplyr::mutate(Variable = gsub( ".NCBIFAM", "", Variable)) |> -# dplyr::mutate(Variable = gsub("^X", "", Variable)) |> -# dplyr::mutate(Variable = dplyr::if_else( -# feature_type == "domains", gsub("_.*", "", Variable), Variable)) |> -# dplyr::mutate(Variable = dplyr::if_else( -# feature_type == "proteins", gsub("fig.", "fig|", Variable), Variable)) |> -# dplyr::left_join(cluster_feature, by = c("Variable" = "feature")) |> -# dplyr::mutate( -# cluster = dplyr::coalesce(cluster, Variable) -# ) -# -# cluster_df <- MDR_feat_clean |> -# dplyr::group_by(Resistant_classes, cluster) |> -# dplyr::summarise( -# Importance = median(Importance, na.rm = TRUE), -# .groups = "drop" -# ) -# -# top_clusters <- cluster_df |> -# group_by(Resistant_classes) |> -# group_modify(~{ -# -# df <- .x -# -# top_pos <- df |> -# arrange(desc(Importance)) |> -# slice_head(n = 10) -# -# top_neg <- df |> -# arrange(Importance) |> -# slice_head(n = 10) -# -# bind_rows(top_pos, top_neg) -# }) |> -# ungroup() -# -# top_clusters <- top_clusters |> -# dplyr::left_join(protein_names, by = c("cluster" = "proteinID")) |> -# dplyr::mutate( -# proteinName = dplyr::coalesce(proteinName, cluster), # fallback -# proteinName = stringr::str_trunc(proteinName, 50) -# ) |> -# dplyr::distinct(Resistant_classes, proteinName, Importance) |> -# # ✅ reorder AFTER naming -# dplyr::group_by(Resistant_classes) |> -# dplyr::mutate( -# proteinName = forcats::fct_reorder(proteinName, Importance) -# ) |> -# dplyr::ungroup() -# -# ggplot(top_clusters, -# aes(x = Importance, y = proteinName)) + -# -# # line (lollipop stem) -# geom_segment( -# aes(x = 0, xend = 
Importance, -# y = proteinName, yend = proteinName), -# color = "grey60" -# ) + -# -# # dot -# geom_point( -# aes(color = Importance > 0), -# size = 3 -# ) + -# -# facet_wrap(~ Resistant_classes, scales = "free_y") + -# -# scale_color_manual( -# values = c("TRUE" = "#5b8db8", # positive -# "FALSE" = "#d4872a"), # negative -# guide = "none" -# ) + -# -# theme_minimal(base_size = 13) + -# labs( -# x = "Median importance", -# y = "Cluster" -# ) + -# theme( -# panel.grid.minor = element_blank(), -# strip.text = element_text(face = "bold") -# ) + # violins (overall distribution per feature type) + ggplot2::geom_violin(fill = "grey85", color = NA, alpha = 0.8) + + + # points (colored by binary vs counts) + ggplot2::geom_jitter( + ggplot2::aes(color = feature_subtype), + width = 0.12, size = 2, alpha = 0.8 + ) + + ggplot2::scale_color_manual(values = c( + "binary" = "#7B9CB5", + "counts" = "#CC8644" + )) + + ggplot2::theme_minimal(base_size = 12) + + ggplot2::labs( + # title = "MDR model performances", + # subtitle = "Violin = distribution per feature type; points = binary vs counts", + x = "Feature type", + y = "MCC", + color = "Feature\nsubtype" + ) + + ggplot2::theme( + legend.position = "right", + plot.title = ggplot2::element_text(face = "bold") + ) + + ggplot2::theme( + axis.title = ggplot2::element_text(colour = "black", size = 10), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, size = 10, colour = "black"), + axis.text.y = ggplot2::element_text(size = 14, colour = "black"), + legend.title = ggplot2::element_text(size = 12), + legend.text = ggplot2::element_text(size = 10), + legend.position = "none", + title = ggplot2::element_text(face = "bold"), + panel.background = ggplot2::element_blank(), + panel.grid.minor = ggplot2::element_blank(), + panel.grid.major.x = ggplot2::element_blank(), # remove vertical lines + panel.grid.major.y = ggplot2::element_line(color = "grey80"), # keep horizontal lines + + axis.line = ggplot2::element_line(color = 
"black") + ) + + ggplot2::scale_y_continuous(limits = c(0, 1)) + + perf_plot + MDR_pred <- arrow::read_parquet(file.path(MDR_pred_path, "MDR_pred.parquet")) |> + dplyr::mutate( + diff_top2 = purrr::pmap_dbl(dplyr::across(dplyr::contains(".pred") & dplyr::where(is.numeric)), function(...) { + x <- c(...) + sx <- sort(x, decreasing = TRUE) + sx[1] - sx[2] + }) # Difference between prediction probabilities of top two classes + ) |> + dplyr::select( + genome_id, resistant_classes, .pred_class, diff_top2, + feature_type, feature_subtype, seed + ) |> + dplyr::group_by(resistant_classes, .pred_class, feature_type) |> + dplyr::summarise(mean_margin = mean(diff_top2), n = n(), .groups = "drop") |> + dplyr::group_by(resistant_classes, feature_type) |> # normalize within true class + dplyr::mutate(sum = sum(n), prop = n / sum(n)) |> + dplyr::ungroup() + + MDR_pred_plot <- ggplot2::ggplot( + MDR_pred, + ggplot2::aes( + x = resistant_classes, + y = .pred_class + ) + ) + + ggplot2::geom_tile(ggplot2::aes(fill = prop)) + + ggplot2::geom_point(ggplot2::aes(size = mean_margin), color = "black") + + ggplot2::facet_wrap(~feature_type) + + ggplot2::scale_fill_distiller( + palette = "RdBu", + direction = 1, # flip with -1 if needed + name = "Prediction proportion" + ) + + ggplot2::labs(x = "true class", y = "predicted class") + + ggplot2::scale_size(range = c(1, 6), name = "Mean margin") + + ggplot2::coord_equal() + + ggplot2::theme_minimal() + + ggplot2::theme(axis.text.x = ggplot2::element_text(angle = 45, hjust = 1)) + + ggplot2::theme( + axis.title = ggplot2::element_text(colour = "black", size = 10), + axis.text.x = ggplot2::element_text(angle = 45, hjust = 1, size = 10, colour = "black"), + axis.text.y = ggplot2::element_text(size = 10, colour = "black"), + legend.title = ggplot2::element_text(size = 12), + legend.text = ggplot2::element_text(size = 10), + legend.position = "right", + title = ggplot2::element_text(face = "bold") + ) + MDR_pred_plot + + # MDR_feat <- 
arrow::read_parquet(file.path( + # MDR_top_feature_path,"MDR_top_features.parquet")) |> + # pivot_longer(-c(Variable, feature_type, feature_subtype, seed), + # values_to = "Importance", + # names_to = "Resistant_classes") |> + # filter(!Importance == 0) + # + # MDR_feat_clean <- MDR_feat |> + # dplyr::filter(feature_type != "struct") |> + # dplyr::group_by(Resistant_classes, feature_type, feature_subtype, seed) |> + # dplyr::slice_max(Importance, n = top_n, with_ties = FALSE) |> + # dplyr::ungroup() |> + # dplyr::mutate(Variable = gsub( ".NCBIFAM", "", Variable)) |> + # dplyr::mutate(Variable = gsub("^X", "", Variable)) |> + # dplyr::mutate(Variable = dplyr::if_else( + # feature_type == "domains", gsub("_.*", "", Variable), Variable)) |> + # dplyr::mutate(Variable = dplyr::if_else( + # feature_type == "proteins", gsub("fig.", "fig|", Variable), Variable)) |> + # dplyr::left_join(cluster_feature, by = c("Variable" = "feature")) |> + # dplyr::mutate( + # cluster = dplyr::coalesce(cluster, Variable) + # ) + # + # cluster_df <- MDR_feat_clean |> + # dplyr::group_by(Resistant_classes, cluster) |> + # dplyr::summarise( + # Importance = median(Importance, na.rm = TRUE), + # .groups = "drop" + # ) + # + # top_clusters <- cluster_df |> + # group_by(Resistant_classes) |> + # group_modify(~{ + # + # df <- .x + # + # top_pos <- df |> + # arrange(desc(Importance)) |> + # slice_head(n = 10) + # + # top_neg <- df |> + # arrange(Importance) |> + # slice_head(n = 10) + # + # bind_rows(top_pos, top_neg) + # }) |> + # ungroup() + # + # top_clusters <- top_clusters |> + # dplyr::left_join(protein_names, by = c("cluster" = "proteinID")) |> + # dplyr::mutate( + # proteinName = dplyr::coalesce(proteinName, cluster), # fallback + # proteinName = stringr::str_trunc(proteinName, 50) + # ) |> + # dplyr::distinct(Resistant_classes, proteinName, Importance) |> + # # ✅ reorder AFTER naming + # dplyr::group_by(Resistant_classes) |> + # dplyr::mutate( + # proteinName = 
forcats::fct_reorder(proteinName, Importance) + # ) |> + # dplyr::ungroup() + # + # ggplot(top_clusters, + # aes(x = Importance, y = proteinName)) + + # + # # line (lollipop stem) + # geom_segment( + # aes(x = 0, xend = Importance, + # y = proteinName, yend = proteinName), + # color = "grey60" + # ) + + # + # # dot + # geom_point( + # aes(color = Importance > 0), + # size = 3 + # ) + + # + # facet_wrap(~ Resistant_classes, scales = "free_y") + + # + # scale_color_manual( + # values = c("TRUE" = "#5b8db8", # positive + # "FALSE" = "#d4872a"), # negative + # guide = "none" + # ) + + # + # theme_minimal(base_size = 13) + + # labs( + # x = "Median importance", + # y = "Cluster" + # ) + + # theme( + # panel.grid.minor = element_blank(), + # strip.text = element_text(face = "bold") + # ) } #' Compare shuffled vs real model performance @@ -981,41 +1010,39 @@ MDR_pred_plot #' @examples #' plotShuffleVsReal(metadata_path = "data/Campylobacter/", performance_path = "data/Campylobacter/ML_performance") plotShuffleVsReal <- function(metadata_path = ".", performance_path = ".") { - - metadata <- arrow::read_parquet(file.path(metadata_path,"metadata.parquet")) - performance <- arrow::read_parquet(file.path(performance_path,"all_performance.parquet")) - - performance |> - dplyr::mutate( - shuffled_label = dplyr::if_else(shuffled, "shuffled", "real") - ) |> - ggplot2::ggplot(ggplot2::aes(x = feature_subtype, y = mcc, fill = shuffled_label)) + + metadata <- arrow::read_parquet(file.path(metadata_path, "metadata.parquet")) + performance <- arrow::read_parquet(file.path(performance_path, "all_performance.parquet")) + + performance |> + dplyr::mutate( + shuffled_label = dplyr::if_else(shuffled, "shuffled", "real") + ) |> + ggplot2::ggplot(ggplot2::aes(x = feature_subtype, y = mcc, fill = shuffled_label)) + ggplot2::geom_boxplot( - width = 0.55, outlier.size = 0.8, outlier.alpha = 0.4, - outlier.color = "grey50", linewidth = 0.4 - ) + + width = 0.55, outlier.size = 0.8, outlier.alpha = 
0.4, + outlier.color = "grey50", linewidth = 0.4 + ) + ggplot2::geom_hline(yintercept = 0, linetype = "dashed", color = "grey60", linewidth = 0.4) + ggplot2::scale_fill_manual( - values = c("real" = "#7B9CB5", "shuffled" = "#C4B8A8"), - name = NULL - ) + + values = c("real" = "#7B9CB5", "shuffled" = "#C4B8A8"), + name = NULL + ) + ggplot2::scale_y_continuous(limits = c(-0.2, 1), breaks = seq(-0.2, 1, 0.2)) + - ggplot2::facet_wrap(~ feature_type, nrow = 1) + + ggplot2::facet_wrap(~feature_type, nrow = 1) + ggplot2::theme_minimal(base_size = 12) + - theme( - panel.grid.major.x = element_blank(), - panel.grid.minor = element_blank(), - panel.grid.major.y = element_line(color = "#E5E2D9", linewidth = 0.4), - strip.text = element_text(color = "grey30", face = "bold", size = 10), - # strip.background = element_rect(fill = "#EEEAE0", color = NA), - axis.text = element_text(color = "grey45"), - legend.position = "top", - legend.text = element_text(color = "grey40", size = 10) - ) + - labs( - x = NULL, y = "MCC" - ) - + theme( + panel.grid.major.x = element_blank(), + panel.grid.minor = element_blank(), + panel.grid.major.y = element_line(color = "#E5E2D9", linewidth = 0.4), + strip.text = element_text(color = "grey30", face = "bold", size = 10), + # strip.background = element_rect(fill = "#EEEAE0", color = NA), + axis.text = element_text(color = "grey45"), + legend.position = "top", + legend.text = element_text(color = "grey40", size = 10) + ) + + labs( + x = NULL, y = "MCC" + ) } #' Plot top contributing feature clusters @@ -1032,59 +1059,63 @@ plotShuffleVsReal <- function(metadata_path = ".", performance_path = ".") { #' @export #' #' @examples -#' plotTopClusters(top_feat_path = "data/Campylobacter/ML_top_features", cluster_feature_path = "data/Campylobacter/", -#' protein_names_path = "data/Campylobacter/", top_n = 10) -plotTopClusters <- function(top_feat_path = ".", cluster_feature_path = ".", +#' plotTopClusters( +#' top_feat_path = 
"data/Campylobacter/ML_top_features", cluster_feature_path = "data/Campylobacter/", +#' protein_names_path = "data/Campylobacter/", top_n = 10 +#' ) +plotTopClusters <- function(top_feat_path = ".", cluster_feature_path = ".", protein_names_path = ".", top_n = 10) { - ################### Top features ######################### - + top_feat <- arrow::read_parquet(file.path(top_feat_path, "all_top_features.parquet")) cluster_feature <- arrow::read_parquet(file.path(cluster_feature_path, "cluster_feature.parquet")) protein_names <- arrow::read_parquet(file.path(protein_names_path, "protein_names.parquet")) - + # which clusters appear in top n across feature types per drug # join top features with cluster mapping, filter out struct and shuffled - top_feat_clean <- top_feat |> + top_feat_clean <- top_feat |> dplyr::filter(!shuffled, feature_type != "struct", drug_label == "drug") |> dplyr::group_by(drug_or_class, feature_type, feature_subtype, seed) |> dplyr::slice_max(Importance, n = top_n, with_ties = FALSE) |> - dplyr::ungroup() |> - dplyr::mutate(Variable = gsub( ".NCBIFAM", "", Variable)) |> + dplyr::ungroup() |> + dplyr::mutate(Variable = gsub(".NCBIFAM", "", Variable)) |> dplyr::mutate(Variable = gsub("^X", "", Variable)) |> dplyr::mutate(Variable = dplyr::if_else( - feature_type == "domains", gsub("_.*", "", Variable), Variable)) |> + feature_type == "domains", gsub("_.*", "", Variable), Variable + )) |> dplyr::mutate(Variable = dplyr::if_else( - feature_type == "proteins", gsub("fig.", "fig|", Variable), Variable)) |> + feature_type == "proteins", gsub("fig.", "fig|", Variable), Variable + )) |> dplyr::left_join(cluster_feature, by = c("Variable" = "feature")) |> dplyr::mutate( - cluster = dplyr::coalesce(cluster, Variable), # fallback to Variable if no match + cluster = dplyr::coalesce(cluster, Variable), # fallback to Variable if no match Importance_signed = dplyr::if_else(Sign == "NEG", -Importance, Importance) ) - + shared_mat <- top_feat_clean |> 
dplyr::group_by(drug_or_class, feature_type, cluster) |> dplyr::summarise(abs_imp = median(Importance, na.rm = TRUE), .groups = "drop") |> - # convert to contribution within each feature_type dplyr::group_by(drug_or_class, feature_type) |> dplyr::mutate(contribution = abs_imp / sum(abs_imp, na.rm = TRUE)) |> - # pick top n contributors dplyr::slice_max(contribution, n = top_n, with_ties = FALSE) |> dplyr::ungroup() |> - dplyr::add_count(drug_or_class, cluster, name = "n_feat_types") |> dplyr::left_join(protein_names, by = c("cluster" = "proteinID")) |> dplyr::mutate( proteinName = stringr::str_trunc(proteinName, 50), proteinName = forcats::fct_reorder(proteinName, n_feat_types) ) - - feat_plot <- ggplot2::ggplot(shared_mat, - ggplot2::aes(x = feature_type, - y = proteinName, - fill = contribution)) + + + feat_plot <- ggplot2::ggplot( + shared_mat, + ggplot2::aes( + x = feature_type, + y = proteinName, + fill = contribution + ) + ) + ggplot2::geom_tile(color = "#FAFAF7", linewidth = 0.5, width = 0.9, height = 0.9) + # coord_fixed() + ggplot2::scale_fill_distiller( @@ -1093,7 +1124,7 @@ plotTopClusters <- function(top_feat_path = ".", cluster_feature_path = ".", name = "contribution", na.value = "#EEEAE0" ) + - ggplot2::facet_wrap(~ drug_or_class, scales = "free_y") + + ggplot2::facet_wrap(~drug_or_class, scales = "free_y") + ggplot2::theme_minimal(base_size = 12) + ggplot2::theme( panel.grid.major.x = ggplot2::element_blank(), @@ -1108,6 +1139,6 @@ plotTopClusters <- function(top_feat_path = ".", cluster_feature_path = ".", legend.text = ggplot2::element_text(color = "grey40", size = 10), legend.title = ggplot2::element_text(color = "grey40", size = 10) ) - + feat_plot }