diff --git a/.gitignore b/.gitignore index 1210604..8446421 100644 --- a/.gitignore +++ b/.gitignore @@ -119,6 +119,9 @@ poetry.lock # notebooks **/*.ipynb - +#parquet +data/hv/peptide/staged/01_hv.features.parquet +# env yaml +envs/env.yaml # scratch tmp diff --git a/envs/env.yaml b/envs/env.yaml index f27fbd8..983af30 100644 --- a/envs/env.yaml +++ b/envs/env.yaml @@ -22,6 +22,8 @@ dependencies: # fasta parsing - biopython + # scikit-learn + - scikit-learn # for pulling large datasets - dvc @@ -30,3 +32,4 @@ dependencies: # swap me for git URL - -e /tgen_labs/altin/alphafold3/runs/linear_peptide - git+https://github.com/ljwoods2/mdaf3.git@main + - git+https://github.com/pegerto/mdakit_sasa@main diff --git a/notebooks/hv/30mer_pLDDT.py b/notebooks/hv/30mer_pLDDT.py new file mode 100644 index 0000000..aa637e3 --- /dev/null +++ b/notebooks/hv/30mer_pLDDT.py @@ -0,0 +1,388 @@ +import polars as pl + +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st + +# import dataframes from "staged" directory +fp_test_dat = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/hv/focal_protein/staged/00_focal_protein.filt.parquet" +) +all_statistics = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/staged/01_hv.features.parquet" +) + + +true_mean = ( + all_statistics.filter(pl.col("epitope")) + .select(pl.col("Mean_pLDDT")) + .to_series() + .to_list() +) + +avg_true_mean = sum(true_mean) / len(true_mean) + +false_mean = ( + all_statistics.filter(~pl.col("epitope")) + .select(pl.col("Mean_pLDDT")) + .to_series() + .to_list() +) +avg_false_mean = sum(false_mean) / len(false_mean) + + +true_std = ( + all_statistics.filter(pl.col("epitope")) + .select(pl.col("Std_pLDDT")) + .to_series() + .to_list() +) +avg_true_std = 
sum(true_std) / len(true_std)
+
+false_std = (
+    all_statistics.filter(~pl.col("epitope"))
+    .select(pl.col("Std_pLDDT"))
+    .to_series()
+    .to_list()
+)
+avg_false_std = sum(false_std) / len(false_std)
+
+
+true_min = (
+    all_statistics.filter(pl.col("epitope"))
+    .select(pl.col("Min_pLDDT"))
+    .to_series()
+    .to_list()
+)
+
+avg_true_min = sum(true_min) / len(true_min)
+
+false_min = (
+    all_statistics.filter(~pl.col("epitope"))
+    .select(pl.col("Min_pLDDT"))
+    .to_series()
+    .to_list()
+)
+avg_false_min = sum(false_min) / len(false_min)
+
+# Values are printed in the order the labels name them: mean, min, std.
+print(
+    "Average mean,min,std pLDDT values respectively for epitope: "
+    + str(avg_true_mean)
+    + ", "
+    + str(avg_true_min)
+    + ", "
+    + str(avg_true_std)
+    + "\nAverage mean,min,std pLDDT values respectively for non-epitope: "
+    + str(avg_false_mean)
+    + ", "
+    + str(avg_false_min)
+    + ", "
+    + str(avg_false_std)
+)
+
+
+def plot_epitope_non_epitope_stats(
+    avg_true_mean: float,
+    avg_true_min: float,
+    avg_true_std: float,
+    avg_false_mean: float,
+    avg_false_min: float,
+    avg_false_std: float,
+) -> plt.Figure:
+    """Draw a grouped bar chart of 30-mer pLDDT statistics and return its Figure.
+
+    Two groups (Epitope / Non-Epitope) with mean, min and std-dev bars, each labelled.
+
+    Args:
+        avg_true_mean (float): Average mean pLDDT for epitopes.
+        avg_true_min (float): Average minimum pLDDT for epitopes.
+        avg_true_std (float): Average standard deviation pLDDT for epitopes.
+        avg_false_mean (float): Average mean pLDDT for non-epitopes.
+        avg_false_min (float): Average minimum pLDDT for non-epitopes.
+        avg_false_std (float): Average standard deviation pLDDT for non-epitopes.
+    """
+    categories = ["Epitope", "Non-Epitope"]
+    # Data for each statistic type, ordered [epitope, non-epitope]
+    mean_values = [avg_true_mean, avg_false_mean]
+    min_values = [avg_true_min, avg_false_min]
+    std_values = [avg_true_std, avg_false_std]
+
+    # Set up bar positions
+    x = np.arange(len(categories))  # the label locations
+    width = 0.25  # the width of the bars
+
+    fig, ax = plt.subplots(figsize=(10, 7))
+
+    # Three bars per group: mean left of the tick, min on it, std-dev right of it
+    rects1 = ax.bar(
+        x - width,
+        mean_values,
+        width,
+        label="Mean pLDDT",
+        color="skyblue",
+        edgecolor="grey",
+    )
+    rects2 = ax.bar(
+        x, min_values, width, label="Min pLDDT", color="lightcoral", edgecolor="grey"
+    )
+    rects3 = ax.bar(
+        x + width,
+        std_values,
+        width,
+        label="Std Dev pLDDT",
+        color="lightgreen",
+        edgecolor="grey",
+    )
+
+    # Add labels, title, and custom x-axis tick labels
+    ax.set_ylabel("pLDDT Value", fontsize=12)
+    ax.set_title("pLDDT Statistics 30-mer: Epitopes vs Non-Epitopes", fontsize=16)
+    ax.set_xticks(x)
+    ax.set_xticklabels(categories, fontsize=12)
+    ax.legend()
+    ax.grid(axis="y", linestyle="--", alpha=0.7)
+
+    # Write each bar's height (2 decimals) just above the bar
+    def autolabel_single_bar(rects):
+        for rect in rects:
+            height = rect.get_height()
+            ax.annotate(
+                f"{height:.2f}",
+                xy=(rect.get_x() + rect.get_width() / 2, height),
+                xytext=(0, 3),  # 3 points vertical offset
+                textcoords="offset points",
+                ha="center",
+                va="bottom",
+                fontsize=9,
+            )
+
+    autolabel_single_bar(rects1)
+    autolabel_single_bar(rects2)
+    autolabel_single_bar(rects3)
+
+    plt.tight_layout()
+    return fig
+
+
+pLDDT_statistics_30mer = plot_epitope_non_epitope_stats(
+    avg_true_mean,
+    avg_true_min,
+    avg_true_std,
+    avg_false_mean,
+    avg_false_min,
+    avg_false_std,
+)
+pLDDT_statistics_30mer.savefig(
+    "../results/figures/pLDDT_statistics_30mer_epitope_vs_non-epitopes.png"
+)
+
+
+true_mean_min_9mer = (
+    all_statistics.filter(pl.col("epitope"))
+    .select(pl.col("9mer_Mean_pLDDT").list.min())
+    .to_series()
+    .to_list()
+)
+
+avg_true_mean_min_9mer = sum(true_mean_min_9mer) / len(true_mean_min_9mer)
+
+false_mean_min_9mer = (
+    all_statistics.filter(~pl.col("epitope"))
+    .select(pl.col("9mer_Mean_pLDDT").list.min())
+    .to_series()
+    .to_list()
+)
+avg_false_mean_min_9mer = sum(false_mean_min_9mer) / len(false_mean_min_9mer)
+
+# Same per-class averaging, now over the minimum of each row's "9mer_std_pLDDT" list
+true_std_min_9mer = (
+    all_statistics.filter(pl.col("epitope"))
+    .select(pl.col("9mer_std_pLDDT").list.min())
+    .to_series()
+    .to_list()
+)
+
+avg_true_std_min_9mer = sum(true_std_min_9mer) / len(true_std_min_9mer)
+
+false_std_min_9mer = (
+    all_statistics.filter(~pl.col("epitope"))
+    .select(pl.col("9mer_std_pLDDT").list.min())
+    .to_series()
+    .to_list()
+)
+avg_false_std_min_9mer = sum(false_std_min_9mer) / len(false_std_min_9mer)
+
+# Only two values are reported per class here (mean, std), so both labels say so.
+print(
+    "Average mean,std pLDDT values respectively for epitope: "
+    + str(avg_true_mean_min_9mer)
+    + ", "
+    + str(avg_true_std_min_9mer)
+    + "\nAverage mean,std pLDDT values respectively for non-epitope: "
+    + str(avg_false_mean_min_9mer)
+    + ", "
+    + str(avg_false_std_min_9mer)
+)
+
+
+def plot_epitope_non_epitope_stats_9mer(
+    avg_true_mean_min_9mer: float,
+    avg_true_std_min_9mer: float,
+    avg_false_mean_min_9mer: float,
+    avg_false_std_min_9mer: float,
+) -> plt.Figure:
+    """Draw a grouped bar chart of 9-mer pLDDT statistics and return its Figure.
+
+    Two groups (Epitope / Non-Epitope) with a mean bar and a std-dev bar each.
+
+    Args:
+        avg_true_mean_min_9mer (float): Height of the epitope "Mean pLDDT" bar.
+        avg_true_std_min_9mer (float): Height of the epitope "Std Dev pLDDT" bar.
+        avg_false_mean_min_9mer (float): Height of the non-epitope "Mean pLDDT" bar.
+        avg_false_std_min_9mer (float): Height of the non-epitope "Std Dev pLDDT" bar.
+    """
+    categories = ["Epitope", "Non-Epitope"]
+    # Data for each statistic type, ordered [epitope, non-epitope]
+    mean_values = [avg_true_mean_min_9mer, avg_false_mean_min_9mer]
+    std_values = [avg_true_std_min_9mer, avg_false_std_min_9mer]
+
+    # Set up bar positions
+    x = np.arange(len(categories))  # the label locations
+    width = 0.25  # the width of the bars
+
+    fig, ax = plt.subplots(figsize=(10, 7))
+
+    # Two bars per group, offset by half a bar width so the pair is centred on its tick
+    rects1 = ax.bar(
+        x - width / 2,
+        mean_values,
+        width,
+        label="Mean pLDDT",
+        color="skyblue",
+        edgecolor="grey",
+    )
+    rects3 = ax.bar(
+        x + width / 2,
+        std_values,
+        width,
+        label="Std Dev pLDDT",
+        color="lightgreen",
+        edgecolor="grey",
+    )
+
+    # Add labels, title, and custom x-axis tick labels
+    ax.set_ylabel("pLDDT Value", fontsize=12)
+    ax.set_title("pLDDT Statistics 9-mer: Epitopes vs Non-Epitopes", fontsize=16)
+    ax.set_xticks(x)
+    ax.set_xticklabels(categories, fontsize=12)
+    ax.legend()
+    ax.grid(axis="y", linestyle="--", alpha=0.7)
+
+    # Write each bar's height (2 decimals) just above the bar
+    def autolabel_single_bar(rects):
+        for rect in rects:
+            height = rect.get_height()
+            ax.annotate(
+                f"{height:.2f}",
+                xy=(rect.get_x() + rect.get_width() / 2, height),
+                xytext=(0, 3),  # 3 points vertical offset
+                textcoords="offset points",
+                ha="center",
+                va="bottom",
+                fontsize=9,
+            )
+
+    autolabel_single_bar(rects1)
+    autolabel_single_bar(rects3)
+
+    plt.tight_layout()
+    return fig
+
+
+pLDDT_avg_9mer = plot_epitope_non_epitope_stats_9mer(
+    avg_true_mean_min_9mer,
+    avg_true_std_min_9mer,
+    avg_false_mean_min_9mer,
+    avg_false_std_min_9mer,
+)
+pLDDT_avg_9mer.savefig(
+    "../results/figures/pLDDT_statistics_9mer_epitope_vs_non-epitopes.png"
+)
+
+
+y_hat = st.normalized_pLDDT_30mer(all_statistics, "Mean_pLDDT")
+print(y_hat)
+
+
+y = all_statistics.select("epitope").to_series()
+print(y)
+
+mean_of_30mer = st.plot_auc_roc_curve(y, y_hat, "Normalized Mean of pLDDT 30-mer ROC")
+mean_of_30mer.savefig("../results/figures/pLDDT_ROC_mean_30mers.png")
+# Add the min and the max of each row's "9mer_Mean_pLDDT" list as new columns.
+all_statistics = all_statistics.with_columns(
+    pl.col("9mer_Mean_pLDDT").list.min().alias("Min_of_means_9mer")
+)
+
+all_statistics = all_statistics.with_columns(
+    pl.col("9mer_Mean_pLDDT").list.max().alias("Max_of_means_9mer")
+)
+
+# Plot and save one ROC figure per derived column, scored against the "epitope" labels in y.
+y_hat_min = st.normalized_pLDDT_30mer(all_statistics, "Min_of_means_9mer")
+min_roc_9mer = st.plot_auc_roc_curve(y, y_hat_min, "Normalized Min of pLDDT 9-mer ROC")
+min_roc_9mer.savefig("../results/figures/pLDDT_ROC_min_9mers.png")
+
+y_hat_max = st.normalized_pLDDT_30mer(all_statistics, "Max_of_means_9mer")
+max_roc_9mer = st.plot_auc_roc_curve(y, y_hat_max, "Normalized Max of pLDDT 9-mer ROC")
+max_roc_9mer.savefig("../results/figures/pLDDT_ROC_max_9mers.png")
+# NOTE(review): st.pl_avg_weight / st.pl_helix / st.pl_beta are not visible here;
+# presumably they add the "avg_atomic_weight", "helix" and "beta" columns — confirm.
+all_statistics = st.pl_avg_weight(
+    all_statistics, "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/inference"
+)
+
+# The figures returned by the remaining plot_auc_roc_curve calls are not saved.
+y_hat_weight = st.normalized_pLDDT_30mer(all_statistics, "avg_atomic_weight")
+st.plot_auc_roc_curve(y, y_hat_weight, "Normalized atomic weight 30-mer ROC")
+
+all_statistics = st.pl_helix(
+    all_statistics, "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/inference"
+)
+# NOTE(review): fill_null(0) is applied to the divisor only — confirm a 0 divisor is intended.
+all_statistics = all_statistics.with_columns(
+    (
+        (
+            pl.col("helix").list.sum().cast(pl.Float64)
+            / pl.col("helix").list.len().cast(pl.Float64).fill_null(0)
+        )
+        * 100
+    ).alias("true_helix_percentage")
+)
+
+all_statistics = st.pl_beta(
+    all_statistics, "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/inference"
+)
+# NOTE(review): fill_null(0) again sits on the divisor, not on the resulting percentage.
+all_statistics = all_statistics.with_columns(
+    (
+        (
+            pl.col("beta").list.sum().cast(pl.Float64)
+            / pl.col("beta").list.len().cast(pl.Float64).fill_null(0)
+        )
+        * 100
+    ).alias("true_beta_percentage")
+)
+
+
+y_hat_helix = st.normalized_pLDDT_30mer(all_statistics, "true_helix_percentage")
+st.plot_auc_roc_curve(y, y_hat_helix, "Normalized helix percentage pLDDT 9mer ROC")
+
+y_hat_beta = st.normalized_pLDDT_30mer(all_statistics, "true_beta_percentage")
+st.plot_auc_roc_curve(y, y_hat_beta, "Normalized beta pleats percentage pLDDT 9mer ROC") diff --git a/notebooks/hv/atomic_weight_AUC_notebook.py b/notebooks/hv/atomic_weight_AUC_notebook.py new file mode 100644 index 0000000..326752d --- /dev/null +++ b/notebooks/hv/atomic_weight_AUC_notebook.py @@ -0,0 +1,86 @@ +import polars as pl + +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf + + +# import dataframes from "staged" directory +fp_test_dat = pl.read_parquet( + "../data/hv/focal_protein/staged/00_focal_protein.filt.parquet" +) +all_statistics = pl.read_parquet("../data/hv/peptide/staged/01_hv.features.parquet") +all_statistics_fp = pl.read_parquet("../data/hv/peptide/staged/01_hv.exploded.parquet") + + +y_hat_weight = st.normalized_pLDDT_30mer(all_statistics, "atomic_weight") +y = all_statistics.select("epitope").to_series() +st.plot_auc_roc_curve(y, y_hat_weight, "Normalized atomic weight 30-mer ROC") + + +y_hat_mass_fp = st.normalized_pLDDT_30mer(all_statistics, "atomic_weight") +y_true = all_statistics.select(pl.col("epitope")) +st.plot_auc_roc_curve(y_true, y_hat_mass_fp, "Normalized atomic mass for 30-mer fp ROC") + + +aggrigate_atomic_weights = all_statistics.group_by("peptide").agg( + (pl.col("9mer_weight").list.max()).first().alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +aggrigate_atomic_weights + + +y_hat_mass_fp = st.normalized_pLDDT_30mer(aggrigate_atomic_weights, "score") +y_true = aggrigate_atomic_weights.select(pl.col("epitope")) +st.plot_auc_roc_curve( + y_true, y_hat_mass_fp, "Normalized max of atomic mass for 9-mers ROC" +) + + +aggrigate_atomic_weights_mean = all_statistics_fp.group_by("fp_job_names").agg( + 
(pl.col("9mer_weight").list.mean()).alias("score"), + pl.col("epitope").alias("epitope"), +) +aggrigate_atomic_weights_min = all_statistics_fp.group_by("fp_job_names").agg( + (pl.col("9mer_weight").list.min()).alias("score"), + pl.col("epitope").alias("epitope"), +) +aggrigate_atomic_weights_max = all_statistics_fp.group_by("fp_job_names").agg( + (pl.col("9mer_weight").list.max()).alias("score"), + pl.col("epitope").alias("epitope"), +) + + +import sklearn + +aggrigate_atomic_weights_mean = aggrigate_atomic_weights_mean.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +aggrigate_atomic_weights_min = aggrigate_atomic_weights_min.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +aggrigate_atomic_weights_max = aggrigate_atomic_weights_max.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) + + +mean_auc_fp = aggrigate_atomic_weights_mean.select("AUC").mean() +mean_auc_min = aggrigate_atomic_weights_min.select("AUC").mean() +mean_auc_max = aggrigate_atomic_weights_max.select("AUC").mean() +print(mean_auc_fp) +print(mean_auc_min) +print(mean_auc_max) diff --git a/notebooks/hv/graphs_and_data.py b/notebooks/hv/graphs_and_data.py new file mode 100644 index 0000000..0ec8dec --- /dev/null +++ b/notebooks/hv/graphs_and_data.py @@ -0,0 +1,344 @@ +import polars as pl + +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from 
af3_linear_epitopes import statistics_focal as stf + + +# import dataframes from "staged" directory +fp_test_dat = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/hv/focal_protein/staged/00_focal_protein.filt.parquet" +) +all_statistics = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/staged/01_hv.features.parquet" +) +all_statistics_fp = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/staged/01_hv.exploded.parquet" +) + + +true_mean = ( + all_statistics.filter(pl.col("epitope")) + .select(pl.col("Mean_pLDDT")) + .to_series() + .to_list() +) + +avg_true_mean = sum(true_mean) / len(true_mean) + +false_mean = ( + all_statistics.filter(~pl.col("epitope")) + .select(pl.col("Mean_pLDDT")) + .to_series() + .to_list() +) +avg_false_mean = sum(false_mean) / len(false_mean) + + +true_std = ( + all_statistics.filter(pl.col("epitope")) + .select(pl.col("Std_pLDDT")) + .to_series() + .to_list() +) +avg_true_std = sum(true_std) / len(true_std) + +false_std = ( + all_statistics.filter(~pl.col("epitope")) + .select(pl.col("Std_pLDDT")) + .to_series() + .to_list() +) +avg_false_std = sum(false_std) / len(false_std) + + +true_min = ( + all_statistics.filter(pl.col("epitope")) + .select(pl.col("Min_pLDDT")) + .to_series() + .to_list() +) + +avg_true_min = sum(true_min) / len(true_min) + +false_min = ( + all_statistics.filter(~pl.col("epitope")) + .select(pl.col("Min_pLDDT")) + .to_series() + .to_list() +) +avg_false_min = sum(false_min) / len(false_min) + + +print( + "Average mean,min,std pLLDT values respectively for epitope: " + + str(avg_true_mean) + + ", " + + str(avg_true_min) + + ", " + + str(avg_true_std) + + "\nAverage mean,min,std pLLDT values respectively for non-epitope: " + + str(avg_false_mean) + + ", " + + str(avg_false_min) + + ", " + + str(avg_false_std) +) + + +pLDDT_statistics_30mer = st.plot_epitope_non_epitope_stats_30mer( + avg_true_mean, + avg_true_min, + avg_true_std, + avg_false_mean, + 
avg_false_min, + avg_false_std, +) + + +true_mean_min_9mer = ( + all_statistics.filter(pl.col("epitope")) + .select(pl.col("9mer_Mean_pLDDT").list.min()) + .to_series() + .to_list() +) + +avg_true_mean_min_9mer = sum(true_mean_min_9mer) / len(true_mean_min_9mer) + +false_mean_min_9mer = ( + all_statistics.filter(~pl.col("epitope")) + .select(pl.col("9mer_Mean_pLDDT").list.min()) + .to_series() + .to_list() +) +avg_false_mean_min_9mer = sum(false_mean_min_9mer) / len(false_mean_min_9mer) + +# -------------------------------------------------------------------------------------- +true_min_min_9mer = ( + all_statistics.filter(pl.col("epitope")) + .select(pl.col("9mer_min_pLDDT").list.min()) + .to_series() + .to_list() +) + +avg_true_min_min_9mer = sum(true_min_min_9mer) / len(true_min_min_9mer) + +false_min_min_9mer = ( + all_statistics.filter(~pl.col("epitope")) + .select(pl.col("9mer_min_pLDDT").list.min()) + .to_series() + .to_list() +) +avg_false_min_min_9mer = sum(false_min_min_9mer) / len(false_min_min_9mer) + +# ------------------------------------------------------------------------------------- +true_std_min_9mer = ( + all_statistics.filter(pl.col("epitope")) + .select(pl.col("9mer_std_pLDDT").list.min()) + .to_series() + .to_list() +) + +avg_true_std_min_9mer = sum(true_std_min_9mer) / len(true_std_min_9mer) + +false_std_min_9mer = ( + all_statistics.filter(~pl.col("epitope")) + .select(pl.col("9mer_std_pLDDT").list.min()) + .to_series() + .to_list() +) +avg_false_std_min_9mer = sum(false_std_min_9mer) / len(false_std_min_9mer) + + +print( + "Average mean,min,std pLLDT values respectively for epitope: " + + str(avg_true_mean_min_9mer) + + ", " + + str(avg_true_min_min_9mer) + + ", " + + str(avg_true_std_min_9mer) + + "\nAverage mean,min,std pLLDT values respectively for non-epitope: " + + str(avg_false_mean_min_9mer) + + ", " + + str(avg_false_min_min_9mer) + + ", " + + str(avg_false_std_min_9mer) +) + + +pLDDT_avg_9mer = 
st.plot_epitope_non_epitope_stats_9mer( + avg_true_mean_min_9mer, + avg_true_std_min_9mer, + avg_false_mean_min_9mer, + avg_false_std_min_9mer, +) + + +y_hat = st.normalized_pLDDT_30mer(all_statistics, "Mean_pLDDT") + + +y = all_statistics.select("epitope").to_series() + + +mean_of_30mer = st.plot_auc_roc_curve(y, y_hat, "Normalized Mean of pLDDT 30-mear ROC") + + +all_statistics = all_statistics.with_columns( + pl.col("9mer_Mean_pLDDT").list.min().alias("Min_of_means_9mer_peptide") +) + + +all_statistics = all_statistics.with_columns( + pl.col("9mer_Mean_pLDDT").list.max().alias("Max_of_means_9mer_peptide") +) + + +y_hat_min = st.normalized_pLDDT_30mer(all_statistics, "Min_of_means_9mer_peptide") +min_roc_9mer = st.plot_auc_roc_curve(y, y_hat_min, "Normalized Min of pLDDT 9-mer ROC") + + +y_hat_max = st.normalized_pLDDT_30mer(all_statistics, "Max_of_means_9mer_peptide") +max_roc_9mer = st.plot_auc_roc_curve(y, y_hat_max, "Normalized Max of pLDDT 9-mer ROC") + + +y_hat_weight = st.normalized_pLDDT_30mer(all_statistics, "atomic_weight") +norm_atomic_weight_30mer = st.plot_auc_roc_curve( + y, y_hat_weight, "Normalized atomic weight 30-mer peptide ROC" +) + + +y = all_statistics_fp.select("epitope").to_series() +y_hat_fp_30mer = st.normalized_pLDDT_30mer(all_statistics_fp, "mean_pLDDT_slice") +norm_fp_pLDDT_mean = st.plot_auc_roc_curve( + y, y_hat_fp_30mer, "Normalized focal protein pLDDT mean 30-mer ROC" +) + + +# takes the mean of every mean pLDDT score for a peptide in different focal protiens +fp_aggrigate_30mer = all_statistics_fp.group_by("peptide").agg( + (pl.col("pLDDT_slice_9mer").list.mean().mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +y_hat_geometric_fp = st.normalized_pLDDT_30mer(fp_aggrigate_30mer, "score") +y_true = fp_aggrigate_30mer.select(pl.col("epitope")) +norm_mean_pLDDT_9mer_fp = st.plot_auc_roc_curve( + y_true, + y_hat_geometric_fp, + "nomralized mean of every mean pLDDT score for a peptide in different focal 
protiens ROC", +) + + +all_statistics_fp = all_statistics_fp.with_columns( + (pl.col("pLDDT_slice_9mer").list.eval((pl.element() / 100).log().mean().exp())) + .list.first() + .alias("Geometric_mean_9mer") +) + + +y_hat_geometric = st.normalized_pLDDT_30mer(all_statistics_fp, "Geometric_mean_9mer") +norm_geometric_mean_9mer_fp = st.plot_auc_roc_curve( + y, y_hat_geometric, "Normalized geometric mean 9-mer focal protein ROC" +) + + +fp_aggrigate = all_statistics_fp.group_by("peptide").agg( + (pl.col("Geometric_mean_9mer").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +y_hat_geometric_fp = fp_aggrigate.select(pl.col("score")) +y_true = fp_aggrigate.select(pl.col("epitope")) +geometric_mean_9mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_geometric_fp, "geometric mean 9-mer fp ROC" +) + + +fp_aggrigate_9mer = all_statistics_fp.group_by("fp_job_names").agg( + (pl.col("Geometric_mean_9mer")).alias("score"), pl.col("epitope").alias("epitope") +) + + +import sklearn + +fp_aggrigate_9mer = fp_aggrigate_9mer.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) + + +mean_auc = fp_aggrigate_9mer.select("AUC").mean() +print(mean_auc) + + +y_hat_mass_fp = st.normalized_pLDDT_30mer(all_statistics_fp, "atomic_weight") +y_true = all_statistics_fp.select(pl.col("epitope")) +norm_weight_30mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_mass_fp, "Normalized atomic weight for 30-mer fp ROC" +) + + +aggrigate_atomic_weights = all_statistics.group_by("peptide").agg( + (pl.col("9mer_weight").list.min()).first().alias("score"), + pl.col("epitope").first().alias("epitope"), +) +y_hat_mass_fp = st.normalized_pLDDT_30mer(aggrigate_atomic_weights, "score") +y_true = aggrigate_atomic_weights.select(pl.col("epitope")) +norm_min_weight_9mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_mass_fp, "Normalized min of atomic weight 
for 9-mers ROC" +) + + +aggrigate_atomic_weights = all_statistics.group_by("peptide").agg( + (pl.col("9mer_weight").list.max()).first().alias("score"), + pl.col("epitope").first().alias("epitope"), +) +y_hat_mass_fp = st.normalized_pLDDT_30mer(aggrigate_atomic_weights, "score") +y_true = aggrigate_atomic_weights.select(pl.col("epitope")) +norm_max_weight_9mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_mass_fp, "Normalized max of atomic mass for 9-mers ROC" +) +# Shows the AUC scores for each focal protein based on the mean, max, and min of the atomic weights of the 9-mers +aggrigate_atomic_weights_mean = all_statistics_fp.group_by("fp_job_names").agg( + (pl.col("9mer_weight").list.mean()).alias("score"), + pl.col("epitope").alias("epitope"), +) +aggrigate_atomic_weights_min = all_statistics_fp.group_by("fp_job_names").agg( + (pl.col("9mer_weight").list.min()).alias("score"), + pl.col("epitope").alias("epitope"), +) +aggrigate_atomic_weights_max = all_statistics_fp.group_by("fp_job_names").agg( + (pl.col("9mer_weight").list.max()).alias("score"), + pl.col("epitope").alias("epitope"), +) + +aggrigate_atomic_weights_mean = aggrigate_atomic_weights_mean.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +aggrigate_atomic_weights_min = aggrigate_atomic_weights_min.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +aggrigate_atomic_weights_max = aggrigate_atomic_weights_max.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +mean_auc_fp = aggrigate_atomic_weights_mean.select("AUC").mean() +mean_auc_min = 
aggrigate_atomic_weights_min.select("AUC").mean() +mean_auc_max = aggrigate_atomic_weights_max.select("AUC").mean() +print(mean_auc_fp) +print(mean_auc_min) +print(mean_auc_max) diff --git a/notebooks/hv/pLDDT_AUC_notebook.py b/notebooks/hv/pLDDT_AUC_notebook.py new file mode 100644 index 0000000..93f7204 --- /dev/null +++ b/notebooks/hv/pLDDT_AUC_notebook.py @@ -0,0 +1,131 @@ +import polars as pl + +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf + + +# import dataframes from "staged" directory +fp_test_dat = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/hv/focal_protein/staged/00_focal_protein.filt.parquet" +) +all_statistics = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/staged/01_hv.features.parquet" +) +all_statistics_fp = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/staged/01_hv.exploded.parquet" +) + + +y_hat = st.normalized_pLDDT_30mer(all_statistics, "Mean_pLDDT") + + +y = all_statistics.select("epitope").to_series() + + +st.plot_auc_roc_curve(y, y_hat, "Normalized Mean pLDDT 30-mer ROC") + + +all_statistics = all_statistics.with_columns( + pl.col("9mer_Mean_pLDDT").list.min().alias("Min_of_means_9mer_peptide") +) + + +all_statistics = all_statistics.with_columns( + pl.col("9mer_Mean_pLDDT").list.max().alias("Max_of_means_9mer_peptide") +) + + +y_hat_min = st.normalized_pLDDT_30mer(all_statistics, "Min_of_means_9mer_peptide") +st.plot_auc_roc_curve(y, y_hat_min, "Normalized Min pLDDT 9mer ROC") + + +y_hat_max = st.normalized_pLDDT_30mer(all_statistics, "Max_of_means_9mer_peptide") +st.plot_auc_roc_curve(y, y_hat_max, "Normalized Max pLDDT 9mer ROC") + + +y_hat_weight = 
st.normalized_pLDDT_30mer(all_statistics, "atomic_weight") +st.plot_auc_roc_curve(y, y_hat_weight, "Normalized atomic weight 30-mer ROC") + + +fp_test_dat + + +all_statistics_fp + + +y = all_statistics_fp.select("epitope").to_series() +print(len(y)) +y_hat_fp_30mer = st.normalized_pLDDT_30mer(all_statistics_fp, "mean_pLDDT_slice") +print(len(y_hat_fp_30mer)) +st.plot_auc_roc_curve( + y, y_hat_fp_30mer, "Normalized focal protein pLDDT mean 30-mer ROC" +) + + +fp_aggrigate_30mer = all_statistics_fp.group_by("peptide").agg( + (pl.col("mean_pLDDT_slice").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +y_hat_geometric_fp = st.normalized_pLDDT_30mer(fp_aggrigate_30mer, "score") +y_true = fp_aggrigate_30mer.select(pl.col("epitope")) +st.plot_auc_roc_curve(y_true, y_hat_geometric_fp, "geometric mean 9-mer fp ROC") + + +all_statistics_fp + + +all_statistics_fp = all_statistics_fp.with_columns( + (pl.col("pLDDT_slice_9mer").list.eval((pl.element() / 100).log().mean().exp())) + .list.first() + .alias("Geometric_mean_9mer") +) + + +all_statistics_fp + + +y_hat_geometric = st.normalized_pLDDT_30mer(all_statistics_fp, "Geometric_mean_9mer") +st.plot_auc_roc_curve(y, y_hat_geometric, "Normalized geometric mean 9-mer fp ROC") + + +fp_aggrigate = all_statistics_fp.group_by("peptide").agg( + (pl.col("Geometric_mean_9mer").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +y_hat_geometric_fp = fp_aggrigate.select(pl.col("score")) +y_true = fp_aggrigate.select(pl.col("epitope")) +st.plot_auc_roc_curve( + y_true, y_hat_geometric_fp, "Normalized geometric mean 9-mer fp ROC" +) + + +fp_aggrigate_9mer = all_statistics_fp.group_by("fp_job_names").agg( + (pl.col("Geometric_mean_9mer")).alias("score"), pl.col("epitope").alias("epitope") +) + + +fp_aggrigate_9mer + + +import sklearn + +fp_aggrigate_9mer = fp_aggrigate_9mer.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda 
x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"]))
+    .alias("AUC")
+)
+
+
+mean_auc = fp_aggrigate_9mer.select("AUC").mean()
+print(mean_auc)
diff --git a/notebooks/hv/rsa_sa_AUC_notebook.py b/notebooks/hv/rsa_sa_AUC_notebook.py
new file mode 100644
index 0000000..c70365d
--- /dev/null
+++ b/notebooks/hv/rsa_sa_AUC_notebook.py
@@ -0,0 +1,36 @@
+
+import polars as pl
+
+# this one is my package
+from mdaf3.AF3OutputParser import AF3Output
+from mdaf3.FeatureExtraction import *
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.metrics import roc_curve, auc
+from pathlib import Path
+from af3_linear_epitopes import statistics as st
+from af3_linear_epitopes import statistics_focal as stf
+
+
+# import dataframes from "staged" directory
+fp_test_dat = pl.read_parquet(
+    "/scratch/sromero/af3-linear-epitopes/data/hv/focal_protein/staged/00_focal_protein.filt.parquet"
+)
+all_statistics = pl.read_parquet(
+    "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/staged/01_hv.features.parquet"
+)
+all_statistics_fp = pl.read_parquet(
+    "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/staged/01_hv.exploded.parquet"
+)
+
+all_statistics_fp
+# with_columns returns a new frame; rebind it so "Mean_RSA" exists for the lookup below.
+all_statistics_fp = all_statistics_fp.with_columns(
+    pl.col("RSA").list.mean().alias("Mean_RSA")
+)
+
+y_hat_RSA_fp = st.normalized_pLDDT_30mer(all_statistics_fp, "Mean_RSA")
+y_true_RSA = all_statistics_fp.select(pl.col("epitope"))
+st.plot_auc_roc_curve(
+    y_true_RSA, y_hat_RSA_fp, "Normalized mean RSA values for 30mer fp ROC"
+)
diff --git a/notebooks/hv_class/hv_class.py b/notebooks/hv_class/hv_class.py
new file mode 100644
index 0000000..ecce471
--- /dev/null
+++ b/notebooks/hv_class/hv_class.py
@@ -0,0 +1,123 @@
+import polars as pl
+
+# this one is my package
+from mdaf3.AF3OutputParser import AF3Output
+from mdaf3.FeatureExtraction import *
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.metrics import roc_curve, auc
+from pathlib import Path
+from af3_linear_epitopes import statistics as st
+from 
af3_linear_epitopes import statistics_focal as stf + + +# import dataframes from "staged" directory + +fp_hv_class_dat = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/hv_class_focal_protein.filt.clust.parquet" +).filter(pl.col("representative")) +hv_class_dat = pl.read_parquet( + "../../data/hv_class/peptide/staged/hv_class_peptide.filt.parquet" +) +all_statistics_hv_class = pl.read_parquet( + "../../data/hv_class/peptide/staged/01_hv_class.features.parquet" +) +all_statistics_hv_class_fp = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/01_hv_class.exploded.parquet" +) + + +y = all_statistics_hv_class_fp.select("epitope").to_series() +print(len(y)) +y_hat_fp_30mer = st.normalized_pLDDT_30mer( + all_statistics_hv_class_fp, "mean_pLDDT_slice" +) +print(len(y_hat_fp_30mer)) +hv_class_norm_mean_pLDDT_30mer_fp = st.plot_auc_roc_curve( + y, y_hat_fp_30mer, "hv_class Normalized focal protein pLDDT mean 30-mer ROC" +) +hv_class_norm_mean_pLDDT_30mer_fp.savefig( + "../../results/figures/hv_class_norm_mean_pLDDT_30mer_fp.png" +) + + +fp_aggrigate_30mer = all_statistics_hv_class_fp.group_by("peptide").agg( + (pl.col("mean_pLDDT_slice").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +y_hat_geometric_fp = st.normalized_pLDDT_30mer(fp_aggrigate_30mer, "score") +y_true = fp_aggrigate_30mer.select(pl.col("epitope")) +hv_class_geometric_mean_pLDDT_30mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_geometric_fp, "hv_class geometric mean pLDDT 30-mer fp ROC" +) +hv_class_geometric_mean_pLDDT_30mer_fp.savefig( + "../../results/figures/hv_class_geometric_mean_pLDDT_30mer_fp.png" +) + + +all_statistics_hv_class_fp = all_statistics_hv_class_fp.with_columns( + (pl.col("pLDDT_slice_9mer").list.eval((pl.element() / 100).log().mean().exp())) + .list.first() + .alias("Geometric_mean_9mer") +) + + +all_statistics_hv_class_fp + + +y_hat_geometric = st.normalized_pLDDT_30mer( + all_statistics_hv_class_fp, "Geometric_mean_9mer" +) 
+hv_class_norm_geometric_mean_pLDDT_9mer_fp = st.plot_auc_roc_curve( + y, y_hat_geometric, "hv_class Normalized geometric mean 9-mer fp ROC" +) +hv_class_norm_geometric_mean_pLDDT_9mer_fp.savefig( + "../../results/figures/hv_class_norm_geometric_mean_pLDDT_9mer_fp.png" +) + + +fp_aggrigate = all_statistics_hv_class_fp.group_by("peptide").agg( + (pl.col("Geometric_mean_9mer").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +y_hat_geometric_fp = fp_aggrigate.select(pl.col("score")) +y_true = fp_aggrigate.select(pl.col("epitope")) +hv_class_geometric_mean_pLDDT_9mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_geometric_fp, "hv_class geometric mean 9-mer fp ROC" +) +hv_class_geometric_mean_pLDDT_9mer_fp.savefig( + "../../results/figures/hv_class_geometric_mean_pLDDT_9mer_fp.png" +) + + +fp_aggrigate_9mer = all_statistics_hv_class_fp.group_by("job_name").agg( + (pl.col("Geometric_mean_9mer")).alias("score"), pl.col("epitope").alias("epitope") +) + + +fp_aggrigate_9mer = fp_aggrigate_9mer.with_columns( + pl.col("score").list.len().alias("#_of_fp_its_in") +) + +percentage = ( + fp_aggrigate_9mer.filter(pl.col("#_of_fp_its_in") > 1).height + / fp_aggrigate_9mer.height +) * 100 +print(percentage) + + +import sklearn + +fp_aggrigate_9mer = fp_aggrigate_9mer.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) + + +mean_auc = fp_aggrigate_9mer.select("AUC").mean() +print(mean_auc) diff --git a/notebooks/hv_class/hv_class_scratch.py b/notebooks/hv_class/hv_class_scratch.py new file mode 100644 index 0000000..77d1903 --- /dev/null +++ b/notebooks/hv_class/hv_class_scratch.py @@ -0,0 +1,142 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.17.2 +# kernelspec: +# display_name: linear-epitope +# language: python +# 
name: python3 +# --- + +# %% +# %load_ext autoreload +# %autoreload 2 + +# %% +import polars as pl + +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf + + +# import dataframes from "staged" directory + +fp_hv_class_dat = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/hv_class_focal_protein.filt.clust.parquet" +).filter(pl.col("representative")) +hv_class_dat = pl.read_parquet( + "../../data/hv_class/peptide/staged/hv_class_peptide.filt.parquet" +) +all_statistics_hv_class = pl.read_parquet( + "../../data/hv_class/peptide/staged/01_hv_class.features.parquet" +) +all_statistics_hv_class_fp = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/01_hv_class.exploded.parquet" +) + +# %% +y = all_statistics_hv_class_fp.select("epitope").to_series() +print(len(y)) +y_hat_fp_30mer = st.normalized_pLDDT_30mer( + all_statistics_hv_class_fp, "mean_pLDDT_slice" +) +print(len(y_hat_fp_30mer)) +hv_class_norm_mean_pLDDT_30mer_fp = st.plot_auc_roc_curve( + y, y_hat_fp_30mer, "hv_class Normalized focal protein pLDDT mean 30-mer ROC" +) +hv_class_norm_mean_pLDDT_30mer_fp.savefig( + "../../results/figures/hv_class_norm_mean_pLDDT_30mer_fp.png" +) + +# %% +fp_aggrigate_30mer = all_statistics_hv_class_fp.group_by("peptide").agg( + (pl.col("mean_pLDDT_slice").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + +# %% +y_hat_geometric_fp = st.normalized_pLDDT_30mer(fp_aggrigate_30mer, "score") +y_true = fp_aggrigate_30mer.select(pl.col("epitope")) +hv_class_geometric_mean_pLDDT_30mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_geometric_fp, "hv_class geometric mean pLDDT 30-mer fp ROC" +) +hv_class_geometric_mean_pLDDT_30mer_fp.savefig( + 
"../../results/figures/hv_class_geometric_mean_pLDDT_30mer_fp.png" +) + +# %% +all_statistics_hv_class_fp = all_statistics_hv_class_fp.with_columns( + (pl.col("pLDDT_slice_9mer").list.eval((pl.element() / 100).log().mean().exp())) + .list.first() + .alias("Geometric_mean_9mer") +) + +# %% +all_statistics_hv_class_fp + +# %% +y_hat_geometric = st.normalized_pLDDT_30mer( + all_statistics_hv_class_fp, "Geometric_mean_9mer" +) +hv_class_norm_geometric_mean_pLDDT_9mer_fp = st.plot_auc_roc_curve( + y, y_hat_geometric, "hv_class Normalized geometric mean 9-mer fp ROC" +) +hv_class_norm_geometric_mean_pLDDT_9mer_fp.savefig( + "../../results/figures/hv_class_norm_geometric_mean_pLDDT_9mer_fp.png" +) + +# %% +fp_aggrigate = all_statistics_hv_class_fp.group_by("peptide").agg( + (pl.col("Geometric_mean_9mer").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + +# %% +y_hat_geometric_fp = fp_aggrigate.select(pl.col("score")) +y_true = fp_aggrigate.select(pl.col("epitope")) +hv_class_geometric_mean_pLDDT_9mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_geometric_fp, "hv_class geometric mean 9-mer fp ROC" +) +hv_class_geometric_mean_pLDDT_9mer_fp.savefig( + "../../results/figures/hv_class_geometric_mean_pLDDT_9mer_fp.png" +) + +# %% +fp_aggrigate_9mer = all_statistics_hv_class_fp.group_by("job_name").agg( + (pl.col("Geometric_mean_9mer")).alias("score"), pl.col("epitope").alias("epitope") +) + +# %% +fp_aggrigate_9mer = fp_aggrigate_9mer.with_columns( + pl.col("score").list.len().alias("#_of_fp_its_in") +) + +percentage = ( + fp_aggrigate_9mer.filter(pl.col("#_of_fp_its_in") > 1).height + / fp_aggrigate_9mer.height +) * 100 +print(percentage) + +# %% +import sklearn + +fp_aggrigate_9mer = fp_aggrigate_9mer.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) + +# %% +mean_auc = fp_aggrigate_9mer.select("AUC").mean() 
+print(mean_auc) diff --git a/notebooks/hv_class/hv_class_structure_data.py b/notebooks/hv_class/hv_class_structure_data.py new file mode 100644 index 0000000..63f9ba0 --- /dev/null +++ b/notebooks/hv_class/hv_class_structure_data.py @@ -0,0 +1,128 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: linear-epitope +# language: python +# name: python3 +# --- + +# %% +# %load_ext autoreload +# %autoreload 2 + +# %% +import polars as pl + +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf + + +# import dataframes from "staged" directory + +fp_hv_class_dat = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/hv_class_focal_protein.filt.clust.parquet" +).filter(pl.col("representative")) +hv_class_dat = pl.read_parquet( + "../../data/hv_class/peptide/staged/hv_class_peptide.filt.parquet" +) +all_statistics_hv_class = pl.read_parquet( + "../../data/hv_class/peptide/staged/01_hv_class.features.parquet" +) +all_statistics_hv_class_fp = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/01_hv_class.exploded.parquet" +) + +# %% +all_statistics_hv_class_fp = st.pl_structure( + all_statistics_hv_class_fp, + "../../data/hv_class/peptide/inference", +) + +# %% +y_hat_structure_fp = st.normalized_pLDDT_30mer( + all_statistics_hv_class_fp, "helix_percentage" +) +y_true = all_statistics_hv_class_fp.select(pl.col("epitope")) +hv_class_geometric_mean_pLDDT_30mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_structure_fp, "hv_class normalized helix percentage 30-mer fp ROC" +) + +# %% 
+y_hat_structure_fp = st.normalized_pLDDT_30mer( + all_statistics_hv_class_fp, "beta_sheet_percentage" +) +y_true = all_statistics_hv_class_fp.select(pl.col("epitope")) +hv_class_geometric_mean_pLDDT_30mer_fp = st.plot_auc_roc_curve( + y_true, + y_hat_structure_fp, + "hv_class normalized beta sheet percentage 30-mer fp ROC", +) + +# %% +y_hat_structure_fp = st.normalized_pLDDT_30mer( + all_statistics_hv_class_fp, "loop_percentage" +) +y_true = all_statistics_hv_class_fp.select(pl.col("epitope")) +hv_class_geometric_mean_pLDDT_30mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_structure_fp, "hv_class normalized loop percentage 30-mer fp ROC" +) + +# %% +aggregate_mean_pLDDT_fp = all_statistics_hv_class_fp.group_by("fp_job_names").agg( + (pl.col("helix_percentage")).alias("score"), + pl.col("epitope").alias("epitope"), +) +import sklearn + +aggregate_mean_pLDDT_fp = aggregate_mean_pLDDT_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +mean_auc_pLDDT = aggregate_mean_pLDDT_fp.select("AUC").mean() +print(mean_auc_pLDDT) + +# %% +aggregate_mean_pLDDT_fp = all_statistics_hv_class_fp.group_by("fp_job_names").agg( + (pl.col("beta_sheet_percentage")).alias("score"), + pl.col("epitope").alias("epitope"), +) +import sklearn + +aggregate_mean_pLDDT_fp = aggregate_mean_pLDDT_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +mean_auc_pLDDT = aggregate_mean_pLDDT_fp.select("AUC").mean() +print(mean_auc_pLDDT) + +# %% +aggregate_mean_pLDDT_fp = all_statistics_hv_class_fp.group_by("fp_job_names").agg( + (pl.col("loop_percentage")).alias("score"), + pl.col("epitope").alias("epitope"), +) +import sklearn + +aggregate_mean_pLDDT_fp = aggregate_mean_pLDDT_fp.with_columns( + 
pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +mean_auc_pLDDT = aggregate_mean_pLDDT_fp.select("AUC").mean() +print(mean_auc_pLDDT) diff --git a/notebooks/hv_class/rsa_sa_hv_class_AUC.py b/notebooks/hv_class/rsa_sa_hv_class_AUC.py new file mode 100644 index 0000000..2a5146e --- /dev/null +++ b/notebooks/hv_class/rsa_sa_hv_class_AUC.py @@ -0,0 +1,182 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: linear-epitope +# language: python +# name: python3 +# --- + +# %% +import polars as pl + + +# %% +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf + + +# %% +# import dataframes from "staged" directory +fp_hv_class_dat = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/hv_class_focal_protein.filt.clust.parquet" +).filter(pl.col("representative")) +hv_class_dat = pl.read_parquet( + "../../data/hv_class/peptide/staged/hv_class_peptide.filt.parquet" +) +all_statistics_hv_class_fp = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/01_hv_class.exploded.parquet" +) +rsa_data = pl.read_parquet( + "../../data/hv_class/focal_protein/staged/05_focal_protein.rsa.parquet" +).filter(pl.col("representative")) + + +# %% +all_statistics_hv_class_fp = rsa_data.join( + all_statistics_hv_class_fp, left_on="job_name", right_on="fp_job_names" +) + + +# %% +all_statistics_hv_class_fp = all_statistics_hv_class_fp.drop( + "seq_right", "raw_protein_ids_right" +) + + 
+# %% +all_statistics_hv_class_fp = all_statistics_hv_class_fp.rename( + {"job_name": "fp_job_name", "job_name_right": "job_name"} +) + + +# %% +all_statistics_hv_class_fp = st.rsa_mean(all_statistics_hv_class_fp) + + +# %% +y_hat_RSA_fp = st.normalized_pLDDT_30mer( + all_statistics_hv_class_fp, "mean_rsa_slice", 0 +) +y_true_RSA = all_statistics_hv_class_fp.select(pl.col("epitope")) +inverse_norm_rsa_mean_30mer_fp_hv_class = st.plot_auc_roc_curve( + y_true_RSA, + y_hat_RSA_fp, + "hv_class Normalized mean RSA values for 30mer fp ROC", +) +inverse_norm_rsa_mean_30mer_fp_hv_class.savefig( + "../../results/figures/inverse_norm_rsa_mean_30mer_fp_hv_class.png" +) + + +# %% +y_hat_SA_fp = st.normalized_pLDDT_30mer(all_statistics_hv_class_fp, "mean_sa_slice", -1) +y_true_SA = all_statistics_hv_class_fp.select(pl.col("epitope")) +inverse_norm_sa_mean_30mer_fp_hv_class = st.plot_auc_roc_curve( + y_true_SA, + y_hat_SA_fp, + "hv_class inverse Normalized mean SA values for 30mer fp ROC", +) +inverse_norm_sa_mean_30mer_fp_hv_class.savefig( + "../../results/figures/inverse_norm_sa_mean_30mer_fp_hv_class.png" +) + + +# %% +aggregate_mean_rsa = all_statistics_hv_class_fp.group_by("peptide").agg( + (pl.col("mean_rsa_slice").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +# %% +y_hat_RSA_fp = aggregate_mean_rsa.select(pl.col("score")) +y_true_RSA = aggregate_mean_rsa.select(pl.col("epitope")) +rsa_mean_aggregate_30mer_fp_hv_class = st.plot_auc_roc_curve( + y_true_RSA, y_hat_RSA_fp, "hv_class aggregate mean RSA values for 30mer fp ROC" +) +rsa_mean_aggregate_30mer_fp_hv_class.savefig( + "../../results/figures/rsa_mean_aggregate_30mer_fp_hv_class.png" +) + + +# %% +aggregate_mean_sa = all_statistics_hv_class_fp.group_by("peptide").agg( + (pl.col("mean_sa_slice").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +# %% +y_hat_SA_fp = aggregate_mean_sa.select(pl.col("score")) +y_true_SA = aggregate_mean_sa.select(pl.col("epitope")) 
+sa_mean_aggregate_30mer_fp_hv_class = st.plot_auc_roc_curve( + y_true_SA, + y_hat_SA_fp, + "hv_class aggregate mean SA values for 30mer fp ROC", +) +sa_mean_aggregate_30mer_fp_hv_class.savefig( + "../../results/figures/sa_mean_aggregate_30mer_fp_hv_class.png" +) + + +# %% +aggregate_mean_rsa_fp = all_statistics_hv_class_fp.group_by("fp_job_name").agg( + (pl.col("mean_rsa_slice")).alias("score"), + pl.col("epitope").alias("epitope"), +) +aggregate_mean_sa_fp = all_statistics_hv_class_fp.group_by("fp_job_name").agg( + (pl.col("mean_sa_slice")).alias("score"), + pl.col("epitope").alias("epitope"), +) + + +# %% +import sklearn + + +# %% +aggregate_mean_rsa_fp = aggregate_mean_rsa_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +aggregate_mean_sa_fp = aggregate_mean_sa_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) + + +# %% +mean_auc_rsa = aggregate_mean_rsa_fp.select("AUC").mean() +mean_auc_sa = aggregate_mean_sa_fp.select("AUC").mean() +print(mean_auc_rsa) +print(mean_auc_sa) + +# %% +auc_rsa_data = aggregate_mean_rsa_fp["AUC"].to_list() +auc_sa_data = aggregate_mean_sa_fp["AUC"].to_list() +st.display_boxplot(auc_rsa_data, "Mean auc for RSA values box plot", "", "AUC") +st.display_boxplot(auc_sa_data, "Mean auc for SA values box plot", "", "AUC") + +# %% +aggregate_mean_sa_fp.filter(pl.col("AUC") == 0) + +# %% +zero_auc = aggregate_mean_sa_fp.filter(pl.col("AUC") == 0) +zero_auc.filter(pl.col("score").list.len() == 2) diff --git a/notebooks/in_class/in_class_structure_data.py b/notebooks/in_class/in_class_structure_data.py new file mode 100644 index 0000000..dbc269b --- /dev/null +++ b/notebooks/in_class/in_class_structure_data.py @@ -0,0 +1,273 @@ +# --- +# jupyter: +# 
jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: linear-epitope +# language: python +# name: python3 +# --- + +# %% +# %load_ext autoreload +# %autoreload 2 + + +# %% +import polars as pl + +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf +from scipy import stats + + +# import dataframes from "staged" directory + +fp_in_class_dat = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/in_class/focal_protein/staged/in_class_focal_protein.filt.clust.parquet" +).filter(pl.col("representative")) +in_class_dat = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/in_class/peptide/staged/in_class_peptide.filt.parquet" +) + +all_statistics_in_class_fp = pl.read_parquet( + "../../data/in_class/focal_protein/staged/01_in_class.exploded.parquet" +) + +# %% +all_statistics_in_class_fp = st.pl_structure_fp( + all_statistics_in_class_fp, + "../../data/in_class/focal_protein/inference", +) + + # %% + all_statistics_in_class_fp = all_statistics_in_class_fp.with_columns( + (pl.col("helix") / 30).alias("helix_percentage"), + (pl.col("beta") / 30).alias("beta_sheet_percentage"), + (pl.col("loop") / 30).alias("loop_percentage"), + ) + +# %% +all_statistics_in_class_fp + +# %% +y_hat_structure_fp = st.normalized_pLDDT_30mer( + all_statistics_in_class_fp, "helix_percentage", -1 +) +y_true = all_statistics_in_class_fp.select(pl.col("epitope")) +in_class_helix_30mer_fp = st.plot_auc_roc_curve( + y_true, + y_hat_structure_fp, + "in_class inverse normalized helix percentage 30-mer fp ROC", +) + +# %% +y_hat_structure_fp 
= st.normalized_pLDDT_30mer( + all_statistics_in_class_fp, "beta_sheet_percentage", 0 +) +y_true = all_statistics_in_class_fp.select(pl.col("epitope")) +in_class_beta_sheet_30mer_fp = st.plot_auc_roc_curve( + y_true, + y_hat_structure_fp, + "in_class inverse normalized beta sheet percentage 30-mer fp ROC", +) + +# %% +y_hat_structure_fp = st.normalized_pLDDT_30mer( + all_statistics_in_class_fp, "loop_percentage", -1 +) +y_true = all_statistics_in_class_fp.select(pl.col("epitope")) +in_class_loop_30mer_fp = st.plot_auc_roc_curve( + y_true, y_hat_structure_fp, "iv_class normalized loop percentage 30-mer fp ROC" +) + +# %% +aggregate_mean_pLDDT_fp = all_statistics_in_class_fp.group_by("fp_job_names").agg( + (pl.col("helix_percentage")).alias("score"), + pl.col("epitope").alias("epitope"), +) +import sklearn + +aggregate_mean_pLDDT_fp = aggregate_mean_pLDDT_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +mean_auc_pLDDT = aggregate_mean_pLDDT_fp.select("AUC").mean() +print(mean_auc_pLDDT) + +# %% +aggregate_mean_pLDDT_fp = all_statistics_in_class_fp.group_by("fp_job_names").agg( + (pl.col("beta_sheet_percentage")).alias("score"), + pl.col("epitope").alias("epitope"), +) +import sklearn + +aggregate_mean_pLDDT_fp = aggregate_mean_pLDDT_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +mean_auc_pLDDT = aggregate_mean_pLDDT_fp.select("AUC").mean() +print(mean_auc_pLDDT) + +# %% +aggregate_mean_pLDDT_fp = all_statistics_in_class_fp.group_by("fp_job_names").agg( + (pl.col("loop_percentage")).alias("score"), + pl.col("epitope").alias("epitope"), +) + +aggregate_mean_pLDDT_fp = aggregate_mean_pLDDT_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), 
pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +mean_auc_pLDDT = aggregate_mean_pLDDT_fp.select("AUC").mean() +print(mean_auc_pLDDT) + +# %% +all_statistics_in_class_fp = st.pl_amino_acids( + all_statistics_in_class_fp, + "../../data/in_class/focal_protein/inference", +) + +# %% +all_statistics_in_class_fp = all_statistics_in_class_fp.unnest("amino_acid_count") + + # %% + amino_acid_count_epitope = { + "A": [], # Alanine + "R": [], # Arginine + "N": [], # Asparagine + "D": [], # Aspartic Acid + "C": [], # Cysteine + "Q": [], # Glutamine + "E": [], # Glutamic Acid + "G": [], # Glycine + "H": [], # Histidine + "I": [], # Isoleucine + "L": [], # Leucine + "K": [], # Lysine + "M": [], # Methionine + "F": [], # Phenylalanine + "P": [], # Proline + "S": [], # Serine + "T": [], # Threonine + "W": [], # Tryptophan + "Y": [], # Tyrosine + "V": [], # Valine + } + amino_acid_count_non_epitope = { + "A": [], # Alanine + "R": [], # Arginine + "N": [], # Asparagine + "D": [], # Aspartic Acid + "C": [], # Cysteine + "Q": [], # Glutamine + "E": [], # Glutamic Acid + "G": [], # Glycine + "H": [], # Histidine + "I": [], # Isoleucine + "L": [], # Leucine + "K": [], # Lysine + "M": [], # Methionine + "F": [], # Phenylalanine + "P": [], # Proline + "S": [], # Serine + "T": [], # Threonine + "W": [], # Tryptophan + "Y": [], # Tyrosine + "V": [], # Valine + } + amino_acid_p_values = { + "A": 0.0, # Alanine + "R": 0.0, # Arginine + "N": 0.0, # Asparagine + "D": 0.0, # Aspartic Acid + "C": 0.0, # Cysteine + "Q": 0.0, # Glutamine + "E": 0.0, # Glutamic Acid + "G": 0.0, # Glycine + "H": 0.0, # Histidine + "I": 0.0, # Isoleucine + "L": 0.0, # Leucine + "K": 0.0, # Lysine + "M": 0.0, # Methionine + "F": 0.0, # Phenylalanine + "P": 0.0, # Proline + "S": 0.0, # Serine + "T": 0.0, # Threonine + "W": 0.0, # Tryptophan + "Y": 0.0, # Tyrosine + "V": 0.0, # Valine + } + amino_acid_t_values = { + "A": 0.0, # 
Alanine + "R": 0.0, # Arginine + "N": 0.0, # Asparagine + "D": 0.0, # Aspartic Acid + "C": 0.0, # Cysteine + "Q": 0.0, # Glutamine + "E": 0.0, # Glutamic Acid + "G": 0.0, # Glycine + "H": 0.0, # Histidine + "I": 0.0, # Isoleucine + "L": 0.0, # Leucine + "K": 0.0, # Lysine + "M": 0.0, # Methionine + "F": 0.0, # Phenylalanine + "P": 0.0, # Proline + "S": 0.0, # Serine + "T": 0.0, # Threonine + "W": 0.0, # Tryptophan + "Y": 0.0, # Tyrosine + "V": 0.0, # Valine + } + +# %% +filtered_non_epitope = all_statistics_in_class_fp.filter(~pl.col("epitope")) +filtered = all_statistics_in_class_fp.filter(pl.col("epitope")) + +symbols = list(amino_acid_count_epitope.keys()) +for i in range(0, len(symbols)): + amino_acid_count_epitope[symbols[i]] = filtered[symbols[i]].to_list() + amino_acid_count_non_epitope[symbols[i]] = filtered_non_epitope[ + symbols[i] + ].to_list() + t_stat, p_value = stats.ttest_ind( + amino_acid_count_epitope[symbols[i]], amino_acid_count_non_epitope[symbols[i]] + ) + amino_acid_p_values[symbols[i]] = p_value + amino_acid_t_values[symbols[i]] = t_stat + +# %% +amino_acid_t_values + +# %% +st.plot_dictionary_bar_chart( + amino_acid_p_values, + "p_values of amino acids tested from epitope to non-epitope regions", + "Amino Acid Symbols", + "p-values", +) + +# %% +st.plot_dictionary_bar_chart( + amino_acid_t_values, + "t_values of amino acids tested from epitope to non-epitope regions", + "Amino Acid Symbols", + "t-values", +) diff --git a/notebooks/in_class/in_data.py b/notebooks/in_class/in_data.py new file mode 100644 index 0000000..ab899e0 --- /dev/null +++ b/notebooks/in_class/in_data.py @@ -0,0 +1,164 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: linear-epitope +# language: python +# name: python3 +# --- + +# %% +import polars as pl + +# %% +# this one is my 
package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf + + +# %% [markdown] +# import dataframes from "staged" directory + +# %% +fp_in_class_dat = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/in_class/focal_protein/staged/in_class_focal_protein.filt.clust.parquet" +).filter(pl.col("representative")) +in_class_dat = pl.read_parquet( + "/scratch/sromero/af3-linear-epitopes/data/in_class/peptide/staged/in_class_peptide.filt.parquet" +) + +# %% +all_statistics_in_class_fp = pl.read_parquet( + "../../data/in_class/focal_protein/staged/01_in_class.exploded.parquet" +) + + +# %% +in_class_dat + + +# %% +fp_in_class_dat + + +# %% +all_statistics_in_class_fp + + +# %% +y = all_statistics_in_class_fp.select("epitope").to_series() +print(len(y)) +y_hat_fp_30mer = st.normalized_pLDDT_30mer( + all_statistics_in_class_fp, "mean_pLDDT_slice" +) +print(len(y_hat_fp_30mer)) +in_class_norm_30mer_pLDDT_ROC = st.plot_auc_roc_curve( + y, y_hat_fp_30mer, "in_class Normalized focal protein pLDDT mean 30-mer ROC" +) +in_class_norm_30mer_pLDDT_ROC.savefig( + "../../results/figures/in_class_norm_30mer_pLDDT_ROC.png" +) + + +# %% +fp_aggrigate_30mer = all_statistics_in_class_fp.group_by("peptide").agg( + (pl.col("mean_pLDDT_slice").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +# %% +y_hat_geometric_fp = st.normalized_pLDDT_30mer(fp_aggrigate_30mer, "score") +y_true = fp_aggrigate_30mer.select(pl.col("epitope")) +in_class_geo_mean_9mer_pLDDT_ROC = st.plot_auc_roc_curve( + y_true, y_hat_geometric_fp, "in_class geometric mean 9-mer fp ROC" +) +in_class_geo_mean_9mer_pLDDT_ROC.savefig( + "../../results/figures/in_class_geo_mean_9mer_pLDDT_ROC.png" +) + + +# %% 
+all_statistics_in_class_fp = all_statistics_in_class_fp.with_columns( + (pl.col("pLDDT_slice_9mer").list.eval((pl.element() / 100).log().mean().exp())) + .list.first() + .alias("Geometric_mean_9mer") +) + + +# %% +y_hat_geometric = st.normalized_pLDDT_30mer( + all_statistics_in_class_fp, "Geometric_mean_9mer" +) +in_class_norm_geo_mean_9mer_pLDDT_ROC = st.plot_auc_roc_curve( + y, y_hat_geometric, "in_class Normalized geometric mean 9-mer fp ROC" +) +in_class_norm_geo_mean_9mer_pLDDT_ROC.savefig( + "../../results/figures/in_class_norm_geo_mean_9mer_pLDDT_ROC.png" +) + + +# %% +fp_aggrigate = all_statistics_in_class_fp.group_by("peptide").agg( + (pl.col("Geometric_mean_9mer").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +# %% +y_hat_geometric_fp = fp_aggrigate.select(pl.col("score")) +y_true = fp_aggrigate.select(pl.col("epitope")) +in_class_geo_mean_30mer_pLDDT_ROC = st.plot_auc_roc_curve( + y_true, y_hat_geometric_fp, "in_class geometric mean 30-mer fp ROC" +) +in_class_geo_mean_30mer_pLDDT_ROC.savefig( + "../../results/figures/in_class_geo_mean_30mer_pLDDT_ROC.png" +) + + +# %% +fp_aggrigate_9mer = all_statistics_in_class_fp.group_by("job_name").agg( + (pl.col("Geometric_mean_9mer")).alias("score"), pl.col("epitope").alias("epitope") +) + + +# %% +fp_aggrigate_9mer = fp_aggrigate_9mer.with_columns( + pl.col("score").list.len().alias("#_of_fp_its_in") +) + +# %% +percentage = ( + fp_aggrigate_9mer.filter(pl.col("#_of_fp_its_in") > 1).height + / fp_aggrigate_9mer.height +) * 100 +print(percentage) + + +# %% +import sklearn + +# %% +fp_aggrigate_9mer = fp_aggrigate_9mer.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) + + +# %% +mean_auc = fp_aggrigate_9mer.select("AUC").mean() +print(mean_auc) diff --git a/notebooks/in_class/rsa_sa_in_class_AUC.py b/notebooks/in_class/rsa_sa_in_class_AUC.py 
new file mode 100644 index 0000000..4419888 --- /dev/null +++ b/notebooks/in_class/rsa_sa_in_class_AUC.py @@ -0,0 +1,166 @@ +# --- +# jupyter: +# jupytext: +# cell_metadata_filter: -all +# custom_cell_magics: kql +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.11.2 +# kernelspec: +# display_name: linear-epitope +# language: python +# name: python3 +# --- + +# %% +# %load_ext autoreload +# %autoreload 2 + +# %% +import polars as pl + +# %% +# this one is my package +from mdaf3.AF3OutputParser import AF3Output +from mdaf3.FeatureExtraction import * +import matplotlib.pyplot as plt +import numpy as np +from sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf + + +# %% +# import dataframes from "staged" directory +fp_in_class_dat = pl.read_parquet( + "../../data/in_class/focal_protein/staged/in_class_focal_protein.filt.clust.parquet" +).filter(pl.col("representative")) +in_class_dat = pl.read_parquet( + "../../data/in_class/peptide/staged/in_class_peptide.filt.parquet" +) + +# %% +all_statistics_in_class_fp = pl.read_parquet( + "../../data/in_class/focal_protein/staged/01_in_class.exploded.parquet" +) +rsa_data = pl.read_parquet( + "../../data/in_class/focal_protein/staged/05_focal_protein.rsa.parquet" +).filter(pl.col("representative")) + +# %% +all_statistics_in_class_fp = rsa_data.join( + all_statistics_in_class_fp, left_on="job_name", right_on="fp_job_names" +) + + +# %% +all_statistics_in_class_fp = all_statistics_in_class_fp.drop( + "seq_right", "raw_protein_ids_right" +) + + +# %% +all_statistics_in_class_fp = all_statistics_in_class_fp.rename( + {"job_name": "fp_job_name", "job_name_right": "job_name"} +) + + +# %% +all_statistics_in_class_fp = st.rsa_mean(all_statistics_in_class_fp) + + +# %% +y_hat_RSA_fp = st.normalized_pLDDT_30mer(all_statistics_in_class_fp, "mean_rsa_slice") 
+y_true_RSA = all_statistics_in_class_fp.select(pl.col("epitope")) +in_class_norm_rsa_mean_30mer_ROC = st.plot_auc_roc_curve( + y_true_RSA, y_hat_RSA_fp, "in_class Normalized mean RSA values for 30mer fp ROC" +) +in_class_norm_rsa_mean_30mer_ROC.savefig( + "../../results/figures/in_class_norm_rsa_mean_30mer_ROC.png" +) + + +# %% +y_hat_SA_fp = st.normalized_pLDDT_30mer(all_statistics_in_class_fp, "mean_sa_slice") +y_true_SA = all_statistics_in_class_fp.select(pl.col("epitope")) +in_class_norm_sa_mean_30mer_ROC = st.plot_auc_roc_curve( + y_true_SA, y_hat_SA_fp, "in_class Normalized mean SA values for 30mer fp ROC" +) +in_class_norm_sa_mean_30mer_ROC.savefig( + "../../results/figures/in_class_norm_sa_mean_30mer_ROC.png" +) + + +# %% +aggregate_mean_rsa = all_statistics_in_class_fp.group_by("peptide").agg( + (pl.col("mean_rsa_slice").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +# %% +y_hat_RSA_fp = aggregate_mean_rsa.select(pl.col("score")) +y_true_RSA = aggregate_mean_rsa.select(pl.col("epitope")) +in_class_aggregate_rsa_mean_30mer_ROC = st.plot_auc_roc_curve( + y_true_RSA, y_hat_RSA_fp, "in_class aggregate mean RSA values for 30mer fp ROC" +) +in_class_aggregate_rsa_mean_30mer_ROC.savefig( + "../../results/figures/in_class_aggregate_rsa_mean_30mer_ROC.png" +) + + +# %% +aggregate_mean_sa = all_statistics_in_class_fp.group_by("peptide").agg( + (pl.col("mean_sa_slice").mean()).alias("score"), + pl.col("epitope").first().alias("epitope"), +) + + +# %% +y_hat_SA_fp = aggregate_mean_sa.select(pl.col("score")) +y_true_SA = aggregate_mean_sa.select(pl.col("epitope")) +in_class_aggregate_sa_mean_30mer_ROC = st.plot_auc_roc_curve( + y_true_SA, + y_hat_SA_fp, + "in_class aggregate mean SA values for 30mer fp ROC", +) +in_class_aggregate_sa_mean_30mer_ROC.savefig( + "../../results/figures/in_class_aggregate_sa_mean_30mer_ROC.png" +) + + +# %% +aggregate_mean_rsa_fp = all_statistics_in_class_fp.group_by("fp_job_name").agg( + 
(pl.col("mean_rsa_slice")).alias("score"), + pl.col("epitope").alias("epitope"), +) +aggregate_mean_sa_fp = all_statistics_in_class_fp.group_by("fp_job_name").agg( + (pl.col("mean_sa_slice")).alias("score"), + pl.col("epitope").alias("epitope"), +) + + +# %% +import sklearn + +# %% +aggregate_mean_rsa_fp = aggregate_mean_rsa_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) +aggregate_mean_sa_fp = aggregate_mean_sa_fp.with_columns( + pl.struct(pl.col("score").alias("y_hat"), pl.col("epitope").alias("y_true")) + .map_elements(lambda x: sklearn.metrics.roc_auc_score(x["y_true"], x["y_hat"])) + .alias("AUC") +) + + +# %% +mean_auc_rsa = aggregate_mean_rsa_fp.select("AUC").mean() +mean_auc_sa = aggregate_mean_sa_fp.select("AUC").mean() +print(mean_auc_rsa) +print(mean_auc_sa) diff --git a/results/figures/geometric_mean_9mer_fp_ROC.png b/results/figures/geometric_mean_9mer_fp_ROC.png new file mode 100644 index 0000000..e72becc Binary files /dev/null and b/results/figures/geometric_mean_9mer_fp_ROC.png differ diff --git a/results/figures/hv_class_geometric_mean_pLDDT_30mer_fp.png b/results/figures/hv_class_geometric_mean_pLDDT_30mer_fp.png new file mode 100644 index 0000000..f8d6a67 Binary files /dev/null and b/results/figures/hv_class_geometric_mean_pLDDT_30mer_fp.png differ diff --git a/results/figures/hv_class_geometric_mean_pLDDT_9mer_fp.png b/results/figures/hv_class_geometric_mean_pLDDT_9mer_fp.png new file mode 100644 index 0000000..09a1ca8 Binary files /dev/null and b/results/figures/hv_class_geometric_mean_pLDDT_9mer_fp.png differ diff --git a/results/figures/hv_class_norm_geometric_mean_pLDDT_9mer_fp.png b/results/figures/hv_class_norm_geometric_mean_pLDDT_9mer_fp.png new file mode 100644 index 0000000..78b55f4 Binary files /dev/null and b/results/figures/hv_class_norm_geometric_mean_pLDDT_9mer_fp.png differ diff 
--git a/results/figures/hv_class_norm_mean_pLDDT_30mer_fp.png b/results/figures/hv_class_norm_mean_pLDDT_30mer_fp.png new file mode 100644 index 0000000..c84370b Binary files /dev/null and b/results/figures/hv_class_norm_mean_pLDDT_30mer_fp.png differ diff --git a/results/figures/in_class_aggregate_rsa_mean_30mer_ROC.png b/results/figures/in_class_aggregate_rsa_mean_30mer_ROC.png new file mode 100644 index 0000000..faa1849 Binary files /dev/null and b/results/figures/in_class_aggregate_rsa_mean_30mer_ROC.png differ diff --git a/results/figures/in_class_aggregate_sa_mean_30mer_ROC.png b/results/figures/in_class_aggregate_sa_mean_30mer_ROC.png new file mode 100644 index 0000000..a25a683 Binary files /dev/null and b/results/figures/in_class_aggregate_sa_mean_30mer_ROC.png differ diff --git a/results/figures/in_class_geo_mean_30mer_pLDDT_ROC.png b/results/figures/in_class_geo_mean_30mer_pLDDT_ROC.png new file mode 100644 index 0000000..3939b20 Binary files /dev/null and b/results/figures/in_class_geo_mean_30mer_pLDDT_ROC.png differ diff --git a/results/figures/in_class_geo_mean_9mer_pLDDT_ROC.png b/results/figures/in_class_geo_mean_9mer_pLDDT_ROC.png new file mode 100644 index 0000000..6884764 Binary files /dev/null and b/results/figures/in_class_geo_mean_9mer_pLDDT_ROC.png differ diff --git a/results/figures/in_class_norm_30mer_pLDDT_ROC.png b/results/figures/in_class_norm_30mer_pLDDT_ROC.png new file mode 100644 index 0000000..16aa577 Binary files /dev/null and b/results/figures/in_class_norm_30mer_pLDDT_ROC.png differ diff --git a/results/figures/in_class_norm_geo_mean_9mer_pLDDT_ROC.png b/results/figures/in_class_norm_geo_mean_9mer_pLDDT_ROC.png new file mode 100644 index 0000000..fc48a68 Binary files /dev/null and b/results/figures/in_class_norm_geo_mean_9mer_pLDDT_ROC.png differ diff --git a/results/figures/in_class_norm_rsa_mean_30mer_ROC.png b/results/figures/in_class_norm_rsa_mean_30mer_ROC.png new file mode 100644 index 0000000..0be96f2 Binary files /dev/null 
and b/results/figures/in_class_norm_rsa_mean_30mer_ROC.png differ diff --git a/results/figures/in_class_norm_sa_mean_30mer_ROC.png b/results/figures/in_class_norm_sa_mean_30mer_ROC.png new file mode 100644 index 0000000..dbd5ac0 Binary files /dev/null and b/results/figures/in_class_norm_sa_mean_30mer_ROC.png differ diff --git a/results/figures/inverse_norm_rsa_mean_30mer_fp_hv_class.png b/results/figures/inverse_norm_rsa_mean_30mer_fp_hv_class.png new file mode 100644 index 0000000..c7b88c2 Binary files /dev/null and b/results/figures/inverse_norm_rsa_mean_30mer_fp_hv_class.png differ diff --git a/results/figures/inverse_norm_sa_mean_30mer_fp_hv_class.png b/results/figures/inverse_norm_sa_mean_30mer_fp_hv_class.png new file mode 100644 index 0000000..1f263c4 Binary files /dev/null and b/results/figures/inverse_norm_sa_mean_30mer_fp_hv_class.png differ diff --git a/results/figures/norm_atomic_weight_30mer_ROC.png b/results/figures/norm_atomic_weight_30mer_ROC.png new file mode 100644 index 0000000..268ed8c Binary files /dev/null and b/results/figures/norm_atomic_weight_30mer_ROC.png differ diff --git a/results/figures/norm_fp_pLDDT_mean_30mer_ROC.png b/results/figures/norm_fp_pLDDT_mean_30mer_ROC.png new file mode 100644 index 0000000..3e28593 Binary files /dev/null and b/results/figures/norm_fp_pLDDT_mean_30mer_ROC.png differ diff --git a/results/figures/norm_geometric_mean_9mer_fp_ROC.png b/results/figures/norm_geometric_mean_9mer_fp_ROC.png new file mode 100644 index 0000000..a2f8d82 Binary files /dev/null and b/results/figures/norm_geometric_mean_9mer_fp_ROC.png differ diff --git a/results/figures/norm_max_weight_9mer_fp_ROC.png b/results/figures/norm_max_weight_9mer_fp_ROC.png new file mode 100644 index 0000000..f068712 Binary files /dev/null and b/results/figures/norm_max_weight_9mer_fp_ROC.png differ diff --git a/results/figures/norm_mean_pLDDT_9mer_fp_ROC.png b/results/figures/norm_mean_pLDDT_9mer_fp_ROC.png new file mode 100644 index 0000000..e782a5a Binary 
files /dev/null and b/results/figures/norm_mean_pLDDT_9mer_fp_ROC.png differ diff --git a/results/figures/norm_min_weight_9mer_fp_ROC.png b/results/figures/norm_min_weight_9mer_fp_ROC.png new file mode 100644 index 0000000..d26260b Binary files /dev/null and b/results/figures/norm_min_weight_9mer_fp_ROC.png differ diff --git a/results/figures/norm_weight_30mer_fp_ROC.png b/results/figures/norm_weight_30mer_fp_ROC.png new file mode 100644 index 0000000..7309604 Binary files /dev/null and b/results/figures/norm_weight_30mer_fp_ROC.png differ diff --git a/results/figures/pLDDT_ROC_max_9mers.png b/results/figures/pLDDT_ROC_max_9mers.png new file mode 100644 index 0000000..bb4d68b Binary files /dev/null and b/results/figures/pLDDT_ROC_max_9mers.png differ diff --git a/results/figures/pLDDT_ROC_max_9mers_ROC.png b/results/figures/pLDDT_ROC_max_9mers_ROC.png new file mode 100644 index 0000000..bb4d68b Binary files /dev/null and b/results/figures/pLDDT_ROC_max_9mers_ROC.png differ diff --git a/results/figures/pLDDT_ROC_mean_30mers.png b/results/figures/pLDDT_ROC_mean_30mers.png new file mode 100644 index 0000000..f59b219 Binary files /dev/null and b/results/figures/pLDDT_ROC_mean_30mers.png differ diff --git a/results/figures/pLDDT_ROC_min_9mers.png b/results/figures/pLDDT_ROC_min_9mers.png new file mode 100644 index 0000000..2061b8d Binary files /dev/null and b/results/figures/pLDDT_ROC_min_9mers.png differ diff --git a/results/figures/pLDDT_statistics_30mer_epitope_vs_non-epitopes.png b/results/figures/pLDDT_statistics_30mer_epitope_vs_non-epitopes.png new file mode 100644 index 0000000..2ce43cf Binary files /dev/null and b/results/figures/pLDDT_statistics_30mer_epitope_vs_non-epitopes.png differ diff --git a/results/figures/pLDDT_statistics_9mer_epitope_vs_non-epitopes.png b/results/figures/pLDDT_statistics_9mer_epitope_vs_non-epitopes.png new file mode 100644 index 0000000..4c28f2a Binary files /dev/null and 
b/results/figures/pLDDT_statistics_9mer_epitope_vs_non-epitopes.png differ diff --git a/results/figures/rsa_mean_aggregate_30mer_fp_hv_class.png b/results/figures/rsa_mean_aggregate_30mer_fp_hv_class.png new file mode 100644 index 0000000..a0d500e Binary files /dev/null and b/results/figures/rsa_mean_aggregate_30mer_fp_hv_class.png differ diff --git a/results/figures/sa_mean_aggregate_30mer_fp_hv_class.png b/results/figures/sa_mean_aggregate_30mer_fp_hv_class.png new file mode 100644 index 0000000..4b23c3f Binary files /dev/null and b/results/figures/sa_mean_aggregate_30mer_fp_hv_class.png differ diff --git a/scripts/hv_class/05_run_extract_feat_focal_protein.sh b/scripts/hv_class/05_run_extract_feat_focal_protein.sh new file mode 100755 index 0000000..28e9a54 --- /dev/null +++ b/scripts/hv_class/05_run_extract_feat_focal_protein.sh @@ -0,0 +1,19 @@ +#!/bin/bash +#SBATCH --job-name=extract_feat +#SBATCH --mail-type=ALL +#SBATCH --mail-user=lwoods@tgen.org +#SBATCH --ntasks=1 +#SBATCH --mem=64G +#SBATCH --time=5-00:00:00 +#SBATCH -c 16 +#SBATCH --output=tmp/nextflow/hv_class/focal_protein/extract_feat.%j.log + +. ./scripts/setup.sh + +# env vars +export NXF_LOG_FILE=tmp/nextflow/hv_class/focal_protein/extract_feat/nextflow.log +export NXF_CACHE_DIR=tmp/nextflow/hv_class/focal_protein/extract_feat/ + +nextflow run \ + ./workflows/05_extract_feat_focal_protein.nf \ + --dset_name hv_class diff --git a/scripts/in_class/05_run_extract_feat_focal_protein.sh b/scripts/in_class/05_run_extract_feat_focal_protein.sh new file mode 100755 index 0000000..145431a --- /dev/null +++ b/scripts/in_class/05_run_extract_feat_focal_protein.sh @@ -0,0 +1,19 @@ +#!/bin/bash +#SBATCH --job-name=extract_feat +#SBATCH --mail-type=ALL +#SBATCH --mail-user=lwoods@tgen.org +#SBATCH --ntasks=1 +#SBATCH --mem=64G +#SBATCH --time=5-00:00:00 +#SBATCH -c 16 +#SBATCH --output=tmp/nextflow/in_class/focal_protein/extract_feat.%j.log + +. 
./scripts/setup.sh + +# env vars +export NXF_LOG_FILE=tmp/nextflow/in_class/focal_protein/extract_feat/nextflow.log +export NXF_CACHE_DIR=tmp/nextflow/in_class/focal_protein/extract_feat/ + +nextflow run \ + ./workflows/05_extract_feat_focal_protein.nf \ + --dset_name in_class diff --git a/scripts/test/05_run_extract_feat_focal_protein.sh b/scripts/test/05_run_extract_feat_focal_protein.sh new file mode 100755 index 0000000..d4564be --- /dev/null +++ b/scripts/test/05_run_extract_feat_focal_protein.sh @@ -0,0 +1,19 @@ +#!/bin/bash +#SBATCH --job-name=extract_feat +#SBATCH --mail-type=ALL +#SBATCH --mail-user=lwoods@tgen.org +#SBATCH --ntasks=1 +#SBATCH --mem=64G +#SBATCH --time=5-00:00:00 +#SBATCH -c 16 +#SBATCH --output=tmp/nextflow/test/focal_protein/extract_feat.%j.log + +. ./scripts/setup.sh + +# env vars +export NXF_LOG_FILE=tmp/nextflow/test/focal_protein/extract_feat/nextflow.log +export NXF_CACHE_DIR=tmp/nextflow/test/focal_protein/extract_feat/ + +nextflow run \ + ./workflows/05_extract_feat_focal_protein.nf \ + --dset_name test diff --git a/src/af3_linear_epitopes/sasa.py b/src/af3_linear_epitopes/sasa.py new file mode 100644 index 0000000..152fe27 --- /dev/null +++ b/src/af3_linear_epitopes/sasa.py @@ -0,0 +1,64 @@ +from MDAnalysis.exceptions import NoDataError +from mdakit_sasa.analysis.sasaanalysis import SASAAnalysis +import freesasa +import numpy as np + + +class PatchedSASAAnalysis(SASAAnalysis): + + def _prepare(self): + self.results.total_area = np.zeros( + self.n_frames, + dtype=float, + ) + self.results.residue_area = np.zeros( + (self.n_frames, len(self.universe.residues.resids)), + dtype=float, + ) + self.results.relative_residue_area = np.zeros( + (self.n_frames, len(self.universe.residues.resids)), + dtype=float, + ) + + def _single_frame(self): + """Calculate data from a single frame of trajectory""" + + structure = freesasa.Structure() + # FreeSasa structure accepts PDBS if not available requires to reconstruct the structure using 
`addAtom` + for a in self.atomgroup: + x, y, z = a.position + try: + resname = a.resname + except NoDataError: + resname = "ANY" # Default classifier value + + structure.addAtom( + a.type.rjust(2), resname, a.resnum.item(), a.segid, x, y, z + ) + + # Define 1 cpu for windows avoid freesasa code to calculate it. + parametes = freesasa.Parameters() + if self._is_windows(): + parametes.setNThreads(1) + + result = freesasa.calc(structure, parametes) + + residue_areas = [ + result.residueAreas()[s][r] + for s in list(result.residueAreas().keys()) + for r in list(result.residueAreas()[s].keys()) + ] + self.results.total_area[self._frame_index] = result.totalArea() + + # Defend agains residue counts mismatch + if len(self.universe.residues.resids) != len(residue_areas): + logger.error( + f"Residude count do not match the expectation, residue SASA not in results { len(self.universe.residues.resids)} != {len(residue_areas)}" + ) + else: + self.results.residue_area[self._frame_index] = [ + r.total for r in residue_areas + ] + self.results.relative_residue_area[self._frame_index] = [ + r.relativeTotal for r in residue_areas + ] diff --git a/src/af3_linear_epitopes/statistics.py b/src/af3_linear_epitopes/statistics.py new file mode 100644 index 0000000..111ecab --- /dev/null +++ b/src/af3_linear_epitopes/statistics.py @@ -0,0 +1,1034 @@ +from .sasa import PatchedSASAAnalysis +import polars as pl +import matplotlib.pyplot as plt +import numpy as np +from MDAnalysis.analysis.dssp import DSSP +from mdaf3.FeatureExtraction import * +from mdaf3.AF3OutputParser import AF3Output +from pathlib import Path +from sklearn.metrics import roc_curve, auc +import matplotlib.patches as mpatches + +# adds the basic statistic data(ex: mean, min, and standard deviation) to the dataset +CHUNKSIZE = 15 + + +def polar_charge(row): + polar_amino_acid = ["S", "T", "N", "Q", "C", "Y", "D", "E", "K", "R", "H"] + non_polar_amino_acid = ["A", "V", "L", "I", "P", "F", "W", "M", "G"] + seq = row["peptide"] 
+ polar_count = 0 + non_polar = 0 + for i in range(0, len(seq)): + if seq[i] in polar_amino_acid: + polar_count += 1 + else: + non_polar += 1 + row["polar"] = polar_count + row["non_polar"] = non_polar + return row + + +def pl_polar_charge(dataset, path): + amino_polar_charge = split_apply_combine(dataset, polar_charge, chunksize=CHUNKSIZE) + return amino_polar_charge + + +# mean rsa value +def rsa_mean(dataset): + dataset = dataset.with_columns( + pl.col("RSA") + .list.slice(offset=pl.col("fp_seq_idxs"), length=30) + .list.mean() + .alias("mean_rsa_slice") + ) + dataset = dataset.with_columns( + pl.col("SA") + .list.slice(offset=pl.col("fp_seq_idxs"), length=30) + .list.mean() + .alias("mean_sa_slice") + ) + return dataset + + +# distance from head and tail of amino acid +def distance(dataset): + dataset = dataset.with_columns( + (pl.col("fp_seq_idxs") + 14.5).alias("distance_from_head") + ) + dataset = dataset.with_columns( + (pl.col("seq").str.len_chars() - (pl.col("fp_seq_idxs") + 14.5)).alias( + "distance_from_tail" + ) + ) + return dataset + + +# finding amino acid sequence patterns +def amino_acid_freq(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + seq = row["peptide"] + amino_acids = { + "A": 0, # Alanine + "R": 0, # Arginine + "N": 0, # Asparagine + "D": 0, # Aspartic Acid + "C": 0, # Cysteine + "Q": 0, # Glutamine + "E": 0, # Glutamic Acid + "G": 0, # Glycine + "H": 0, # Histidine + "I": 0, # Isoleucine + "L": 0, # Leucine + "K": 0, # Lysine + "M": 0, # Methionine + "F": 0, # Phenylalanine + "P": 0, # Proline + "S": 0, # Serine + "T": 0, # Threonine + "W": 0, # Tryptophan + "Y": 0, # Tyrosine + "V": 0, # Valine + } + for i in range(0, len(seq)): + amino_acids[seq[i]] += 1 + row["most_frequent_amino_acid"] = max_key = max(amino_acids, key=amino_acids.get) + row["amino_acid_count"] = amino_acids + return row + + +def pl_amino_acids(dataset, path): + amino_acid = split_apply_combine( + dataset, amino_acid_freq, path, chunksize=CHUNKSIZE + ) + 
return amino_acid + + +# finding the SASA for our datasets +def sasa_fp(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + analysis = PatchedSASAAnalysis(u) + analysis.run() + row["RSA"] = analysis.results.relative_residue_area[0].tolist() + row["SA"] = analysis.results.residue_area[0].tolist() + + return row + + +def pl_sasa_fp(dataset, path): + area = split_apply_combine(dataset, sasa_fp, path, chunksize=CHUNKSIZE) + return area + + +# finding the avg atomic weight for the 30-mers +def avg_atomic_weight_30mer(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + weight = u.atoms.total_mass() + avg_weight = weight + row["atomic_weight"] = avg_weight + return row + + +def avg_atomic_weight_9mer(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + index = 0 + mass = [] + for j in range(index, index + 22): + mass.append(u.residues[j : j + 9].atoms.total_mass()) + row["9mer_weight"] = mass + return row + + +def pl_9mer_weight(dataset, path): + weight_9mer = split_apply_combine( + dataset, avg_atomic_weight_9mer, path, chunksize=CHUNKSIZE + ) + return weight_9mer + + +def pl_avg_weight(dataset, path): + avg_weight = split_apply_combine( + dataset, avg_atomic_weight_30mer, path, chunksize=CHUNKSIZE + ) + return avg_weight + + +def raw_helix_indices_bool(sel): + # find helices + # https://docs.mdanalysis.org/2.8.0/documentation_pages/analysis/dssp.html + helix_resindices_boolmask = DSSP(sel).run().results.dssp_ndarray[0, :, 1] + return helix_resindices_boolmask.tolist() + + +def raw_beta_indices_bool(sel): + # find helices + # https://docs.mdanalysis.org/2.8.0/documentation_pages/analysis/dssp.html + beta_resindices_boolmask = DSSP(sel).run().results.dssp_ndarray[0, :, 2] + return beta_resindices_boolmask.tolist() + + +def raw_loop_indices_bool(sel): + # find helices + # https://docs.mdanalysis.org/2.8.0/documentation_pages/analysis/dssp.html + 
beta_resindices_boolmask = DSSP(sel).run().results.dssp_ndarray[0, :, 0] + return beta_resindices_boolmask.tolist() + + +# Recreating an error code:______________________________________________________ +def pl_structure_error(dataset, path): + beta_dataset = split_apply_combine(dataset, beta_error, path, chunksize=CHUNKSIZE) + beta_helix_dataset = split_apply_combine( + beta_dataset, helix_error, path, chunksize=CHUNKSIZE + ) + beta_helix_loop_dataset = split_apply_combine( + beta_helix_dataset, loop_error, path, chunksize=CHUNKSIZE + ) + beta_helix_loop_dataset = beta_helix_loop_dataset.with_columns( + (pl.col("helix") / 30).alias("helix_percentage"), + (pl.col("beta") / 30).alias("beta_sheet_percentage"), + (pl.col("loop") / 30).alias("loop_percentage"), + ) + return beta_helix_loop_dataset + + +def loop_error(row, path): + af3 = AF3Output(Path(path) / row["fp_job_names"]) + u = af3.get_mda_universe() + row["loop"] = sum(raw_loop_indices_bool(u)) + return row + + +# finding the beta pleats of the 30-mer +def beta_error(row, path): + af3 = AF3Output(Path(path) / row["fp_job_names"]) + u = af3.get_mda_universe() + row["beta"] = sum(raw_beta_indices_bool(u)) + return row + + +# finding the helix's of the 30mers +def helix_error(row, path): + af3 = AF3Output(Path(path) / row["fp_job_names"]) + u = af3.get_mda_universe() + row["helix"] = sum(raw_helix_indices_bool(u)) + return row + + +# ___________________________________________________________ + + +def pl_structure(dataset, path): + beta_dataset = split_apply_combine(dataset, beta, path, chunksize=CHUNKSIZE) + beta_helix_dataset = split_apply_combine( + beta_dataset, helix, path, chunksize=CHUNKSIZE + ) + beta_helix_loop_dataset = split_apply_combine( + beta_helix_dataset, loop, path, chunksize=CHUNKSIZE + ) + beta_helix_loop_dataset = beta_helix_loop_dataset.with_columns( + (pl.col("helix") / 30).alias("helix_percentage"), + (pl.col("beta") / 30).alias("beta_sheet_percentage"), + (pl.col("loop") / 
30).alias("loop_percentage"), + ) + return beta_helix_loop_dataset + + +def pl_structure_fp(dataset, path): + dataset = split_apply_combine(dataset, structure, path, chunksize=CHUNKSIZE) + return dataset + + +def structure(row, path): + af3 = AF3Output(Path(path) / row["fp_job_names"]) + u = af3.get_mda_universe() + index = row["fp_seq_idxs"] + row["loop"] = raw_loop_indices_bool(u) + row["beta"] = raw_beta_indices_bool(u) + row["helix"] = raw_helix_indices_bool(u) + row["loop"] = sum(row["loop"][index : index + 30]) + row["beta"] = sum(row["beta"][index : index + 30]) + row["helix"] = sum(row["helix"][index : index + 30]) + return row + + +def loop(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + row["loop"] = sum(raw_loop_indices_bool(u)) + return row + + +# finding the beta pleats of the 30-mer +def beta(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + row["beta"] = sum(raw_beta_indices_bool(u)) + return row + + +def pl_beta(dataset, path): + beta_dataset = split_apply_combine(dataset, beta, path, chunksize=CHUNKSIZE) + return beta_dataset + + +# finding the helix's of the 30mers +def helix(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + row["helix"] = sum(raw_helix_indices_bool(u)) + return row + + +def pl_helix(dataset, path): + helix_dataset = split_apply_combine(dataset, helix, path, chunksize=CHUNKSIZE) + return helix_dataset + + +# creating col of statistics(mean,min,std) of every 9mer peptide +def statistics_9mer(dataset, path): + mean_dataset_9mer = pl_mean_9mer(dataset, path) + min_mean_dataset_9mer = pl_min_9mer(mean_dataset_9mer, path) + return pl_std_9mer(min_mean_dataset_9mer, path) + + +def peptide_9mer(dataset, path): + col_list = [] + peptides = dataset.select("peptide").to_series() + for j in range(0, len(peptides)): + peptide_30mer = peptides[j] + nine_mer_seq = [] + for i in range(0, 22): + 
nine_mer_seq.append(peptide_30mer[i : i + 9]) + col_list.append(nine_mer_seq) + dataset = dataset.with_columns(pl.Series(col_list).alias("9mer_seq")) + return dataset + + +def mean_func_9mer(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + pLDDT = af3.get_mda_universe().atoms.select_atoms("name CA").tempfactors + pLDDT_9mer = [] + for i in range(0, 22): + pLDDT_9mer.append(pLDDT[i : i + 9].mean()) + + row["9mer_Mean_pLDDT"] = pLDDT_9mer + return row + + +def pl_mean_9mer(dataset, path): + mean_dataset = split_apply_combine( + dataset, mean_func_9mer, path, chunksize=CHUNKSIZE + ) + return mean_dataset + + +def min_func_9mer(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + pLDDT = af3.get_mda_universe().atoms.select_atoms("name CA").tempfactors + pLDDT_9mer = [] + for i in range(0, 22): + pLDDT_9mer.append(pLDDT[i : i + 9].min()) + + row["9mer_min_pLDDT"] = pLDDT_9mer + return row + + +def pl_min_9mer(dataset, path): + min_dataset = split_apply_combine(dataset, min_func_9mer, path, chunksize=CHUNKSIZE) + return min_dataset + + +def std_func_9mer(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + pLDDT = af3.get_mda_universe().atoms.select_atoms("name CA").tempfactors + pLDDT_9mer = [] + for i in range(0, 22): + pLDDT_9mer.append(pLDDT[i : i + 9].std()) + + row["9mer_std_pLDDT"] = pLDDT_9mer + return row + + +def pl_std_9mer(dataset, path): + std_dataset = split_apply_combine(dataset, std_func_9mer, path, chunksize=CHUNKSIZE) + return std_dataset + + +# Now working with PAE values +def pae_statistics(dataset, path): + mean_dataset_pae = pl_pae_mean(dataset, path) + min_mean_dataset_pae = pl_pae_min(mean_dataset_pae, path) + return pl_pae_std(min_mean_dataset_pae, path) + + +def pae_mean_func(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + residx = u.residues[0:30].resindices + row["mean_PAE_values"] = af3.get_pae_ndarr()[residx].mean() + return row + + +def pl_pae_mean(dataset, path): + 
mean_dataset = split_apply_combine( + dataset, pae_mean_func, path, chunksize=CHUNKSIZE + ) + return mean_dataset + + +def pae_min_func(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + residx = u.residues[0:30].resindices + row["min_PAE_values"] = af3.get_pae_ndarr()[residx].min() + return row + + +def pl_pae_min(dataset, path): + min_dataset = split_apply_combine(dataset, pae_min_func, path, chunksize=CHUNKSIZE) + return min_dataset + + +def pae_std_func(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + u = af3.get_mda_universe() + residx = u.residues[0:30].resindices + row["std_PAE_values"] = af3.get_pae_ndarr()[residx].std() + return row + + +def pl_pae_std(dataset, path): + std_dataset = split_apply_combine(dataset, pae_std_func, path, chunksize=CHUNKSIZE) + return std_dataset + + +def statistics(dataset, path): + mean_dataset = pl_mean(dataset, path) + min_mean_dataset = pl_min(mean_dataset, path) + return pl_std(min_mean_dataset, path) + + +# normalizing the pLDDT values +def normalized_pLDDT_30mer(dataset, colname: str, inverse: int): + max_pLDDT = dataset.select(pl.col(colname)).max().item() + print("max:" + str(max_pLDDT)) + min_pLDDT = dataset.select(pl.col(colname)).min().item() + print("min:" + str(min_pLDDT)) + if inverse == -1: + normalized_series = ( + dataset.with_columns( + ((1 - (pl.col(colname) - min_pLDDT) / (max_pLDDT - min_pLDDT))).alias( + "normalized_pLDDT" + ) + ) + .select(pl.col("normalized_pLDDT")) + .to_series() + .to_numpy() + ) + else: + normalized_series = ( + dataset.with_columns( + (((pl.col(colname) - min_pLDDT) / (max_pLDDT - min_pLDDT))).alias( + "normalized_pLDDT" + ) + ) + .select(pl.col("normalized_pLDDT")) + .to_series() + .to_numpy() + ) + return normalized_series + + +# Is the functions that appends the mean column to the dataset +def mean_func(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + row["Mean_pLDDT"] = ( + 
af3.get_mda_universe().atoms.select_atoms("name CA").tempfactors.mean() + ) + return row + + +# same for mean but for minimum +def min_func(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + row["Min_pLDDT"] = ( + af3.get_mda_universe().atoms.select_atoms("name CA").tempfactors.min() + ) + return row + + +# for standard deviation +def std_func(row, path): + af3 = AF3Output(Path(path) / row["job_name"]) + row["Std_pLDDT"] = ( + af3.get_mda_universe().atoms.select_atoms("name CA").tempfactors.std() + ) + return row + + +# applies the mean function to the dataset +def pl_mean(dataset, path): + mean_dataset = split_apply_combine(dataset, mean_func, path, chunksize=CHUNKSIZE) + return mean_dataset + + +# applies the min function to the dataset +def pl_min(dataset, path): + min_dataset = split_apply_combine(dataset, min_func, path, chunksize=CHUNKSIZE) + return min_dataset + + +# applies the standard deviation to the dataset +def pl_std(dataset, path): + std_dataset = split_apply_combine(dataset, std_func, path, chunksize=CHUNKSIZE) + return std_dataset + + +# box and whisker plot +def display_boxplot(data, title="Box and Whisker Plot", x_label="", y_label="Value"): + """ + Displays a box and whisker plot for the given data using Matplotlib and Polars. + + Args: + data (list, numpy.ndarray, polars.Series, polars.DataFrame, or dict/list of such): + The data to plot. + - If a single list, numpy.ndarray, or polars.Series: a single box plot. + - If a polars.DataFrame: + - If it has one numeric column, that column will be plotted. + - If it has multiple numeric columns, each will get a box plot. + - If it has a 'category' column and a 'value' column, it will plot + box plots per category. + - If a dictionary: keys are categories, values are lists/arrays/Series of data. + - If a list of lists/arrays/Series: each inner list/array/Series represents a category. + title (str, optional): The title of the plot. Defaults to "Box and Whisker Plot". 
+ x_label (str, optional): The label(s) for the x-axis. + - If a string, used as the overall x-axis label. + - If a list of strings, used as tick labels for multiple categories. + Defaults to an empty string. + y_label (str, optional): The label for the y-axis. Defaults to "Value". + """ + plot_data = [] + category_labels = [] + + # --- Data Preparation Logic using Polars --- + if isinstance(data, (list, np.ndarray)): + # Single dataset (list or numpy array) + plot_data.append(data) + category_labels.append("") # No specific category label for a single plot + elif isinstance(data, pl.Series): + # Single Polars Series + plot_data.append(data.to_list()) + category_labels.append("") + elif isinstance(data, pl.DataFrame): + # Polars DataFrame handling + if "category" in data.columns and "value" in data.columns: + # Assume long format: 'category' column for grouping, 'value' for data + grouped = data.group_by("category").agg( + pl.col("value").list().alias("values") + ) + for row in grouped.iter_rows(named=True): + plot_data.append(row["values"]) + category_labels.append(str(row["category"])) + if not x_label: + x_label = "Category" + else: + # Plot each numeric column as a separate box + for col_name in data.columns: + if data[col_name].dtype.is_numeric(): # Check if column is numeric + plot_data.append(data[col_name].to_list()) + category_labels.append(col_name) + if not x_label: + x_label = "Columns" # Default label for multiple columns + + elif isinstance(data, dict): + # Dictionary of datasets (keys are categories) + for key, value in data.items(): + if isinstance(value, pl.Series): + plot_data.append(value.to_list()) + elif isinstance(value, (list, np.ndarray)): + plot_data.append(value) + else: + print( + f"Warning: Skipping unsupported data type for key '{key}' in dictionary." 
+ ) + continue + category_labels.append(str(key)) + if not x_label: + x_label = "Category" + + elif isinstance(data, list) and all( + isinstance(d, (list, np.ndarray, pl.Series)) for d in data + ): + # List of datasets (each element is a category) + for i, dataset in enumerate(data): + if isinstance(dataset, pl.Series): + plot_data.append(dataset.to_list()) + elif isinstance(dataset, (list, np.ndarray)): + plot_data.append(dataset) + else: + continue # Should not happen due to all() check + cat_label = f"Category {i+1}" + if isinstance(x_label, list) and i < len(x_label): + cat_label = x_label[i] + category_labels.append(cat_label) + if not x_label: + x_label = "Category" + else: + print( + "Error: Unsupported data format. Please provide a list, numpy array, polars Series/DataFrame, or a list/dictionary of such for multiple plots." + ) + return + + # --- Plotting with Matplotlib --- + if not plot_data: + print("No valid data to plot.") + return + + plt.figure(figsize=(8, 6)) + + # Handle single vs. 
multiple box plots + if ( + len(plot_data) == 1 and not category_labels[0] + ): # Single plot, no explicit category + plt.boxplot(plot_data[0]) + plt.tick_params( + axis="x", which="both", bottom=False, top=False, labelbottom=False + ) # Hide x-axis ticks/labels + else: + plt.boxplot(plot_data) + if category_labels and all(category_labels): # If we have valid category labels + plt.xticks( + ticks=np.arange(1, len(category_labels) + 1), + labels=category_labels, + rotation=45, + ha="right", + ) + plt.xlabel(x_label) # Set x-axis label if provided + plt.title(title) + plt.ylabel(y_label) + plt.grid(axis="y", linestyle="--", alpha=0.7) + plt.tight_layout() + plt.show() + + +# the code below are functions to create the bar graphs and AUC curves +def plot_epitope_non_epitope_stats_9mer( + avg_true_mean_min_9mer: float, + avg_true_std_min_9mer: float, + avg_false_mean_min_9mer: float, + avg_false_std_min_9mer: float, +): + """ + Creates a grouped bar graph comparing mean, minimum, and standard deviation + of pLDDT values for Epitopes and Non-Epitopes. + + Args: + avg_true_mean_min_9mer (float): Average mean pLDDT for epitopes. + avg_true_std_min_9mer (float): Average standard deviation pLDDT for epitopes. + avg_false_mean_min_9mer (float): Average mean pLDDT for non-epitopes. + avg_false_std_min_9mer (float): Average standard deviation pLDDT for non-epitopes. 
+ """ + categories = ["Epitope", "Non-Epitope"] + # Data for each statistic type + mean_values = [avg_true_mean_min_9mer, avg_false_mean_min_9mer] + std_values = [avg_true_std_min_9mer, avg_false_std_min_9mer] + + # Set up bar positions + x = np.arange(len(categories)) # the label locations + width = 0.25 # the width of the bars + + fig, ax = plt.subplots(figsize=(10, 7)) + + # Create bars for Mean, Min, and Std Dev for both categories + rects1 = ax.bar( + x - width, + mean_values, + width, + label="Mean pLDDT", + color="skyblue", + edgecolor="grey", + ) + rects3 = ax.bar( + x + width, + std_values, + width, + label="Std Dev pLDDT", + color="lightgreen", + edgecolor="grey", + ) + + # Add labels, title, and custom x-axis tick labels + ax.set_ylabel("pLDDT Value", fontsize=12) + ax.set_title("pLDDT Statistics 9-mer: Epitopes vs Non-Epitopes", fontsize=16) + ax.set_xticks(x) + ax.set_xticklabels(categories, fontsize=12) + ax.legend() + ax.grid(axis="y", linestyle="--", alpha=0.7) + + # Add value labels on top of the bars + def autolabel_single_bar(rects): + for rect in rects: + height = rect.get_height() + ax.annotate( + f"{height:.2f}", + xy=(rect.get_x() + rect.get_width() / 2, height), + xytext=(0, 3), # 3 points vertical offset + textcoords="offset points", + ha="center", + va="bottom", + fontsize=9, + ) + + autolabel_single_bar(rects1) + autolabel_single_bar(rects3) + + plt.tight_layout() + return fig + + +def plot_epitope_non_epitope_stats_30mer( + avg_true_mean: float, + avg_true_min: float, + avg_true_std: float, + avg_false_mean: float, + avg_false_min: float, + avg_false_std: float, +): + """ + Creates a grouped bar graph comparing mean, minimum, and standard deviation + of pLDDT values for Epitopes and Non-Epitopes. + + Args: + avg_true_mean (float): Average mean pLDDT for epitopes. + avg_true_min (float): Average minimum pLDDT for epitopes. + avg_true_std (float): Average standard deviation pLDDT for epitopes. 
+ avg_false_mean (float): Average mean pLDDT for non-epitopes. + avg_false_min (float): Average minimum pLDDT for non-epitopes. + avg_false_std (float): Average standard deviation pLDDT for non-epitopes. + """ + categories = ["Epitope", "Non-Epitope"] + # Data for each statistic type + mean_values = [avg_true_mean, avg_false_mean] + min_values = [avg_true_min, avg_false_min] + std_values = [avg_true_std, avg_false_std] + + # Set up bar positions + x = np.arange(len(categories)) # the label locations + width = 0.25 # the width of the bars + + fig, ax = plt.subplots(figsize=(10, 7)) + + # Create bars for Mean, Min, and Std Dev for both categories + rects1 = ax.bar( + x - width, + mean_values, + width, + label="Mean pLDDT", + color="skyblue", + edgecolor="grey", + ) + rects2 = ax.bar( + x, min_values, width, label="Min pLDDT", color="lightcoral", edgecolor="grey" + ) + rects3 = ax.bar( + x + width, + std_values, + width, + label="Std Dev pLDDT", + color="lightgreen", + edgecolor="grey", + ) + + # Add labels, title, and custom x-axis tick labels + ax.set_ylabel("pLDDT Value", fontsize=12) + ax.set_title("pLDDT Statistics 30-mer: Epitopes vs Non-Epitopes", fontsize=16) + ax.set_xticks(x) + ax.set_xticklabels(categories, fontsize=12) + ax.legend() + ax.grid(axis="y", linestyle="--", alpha=0.7) + + # Add value labels on top of the bars + def autolabel_single_bar(rects): + for rect in rects: + height = rect.get_height() + ax.annotate( + f"{height:.2f}", + xy=(rect.get_x() + rect.get_width() / 2, height), + xytext=(0, 3), # 3 points vertical offset + textcoords="offset points", + ha="center", + va="bottom", + fontsize=9, + ) + + autolabel_single_bar(rects1) + autolabel_single_bar(rects2) + autolabel_single_bar(rects3) + + plt.tight_layout() + return fig + + +def plot_auc_roc_curve( + y_true: np.ndarray, + y_scores: np.ndarray, + title: str = "Receiver Operating Characteristic (ROC) Curve", +) -> plt.Figure: + """ + Generates and plots the Receiver Operating Characteristic 
(ROC) curve and calculates + the Area Under the Curve (AUC) for binary classification predictions. + + Args: + y_true (np.ndarray): True binary labels (0 or 1). + y_scores (np.ndarray): Target scores, usually the probability estimates + of the positive class. + title (str, optional): Title for the plot. Defaults to + "Receiver Operating Characteristic (ROC) Curve". + + Returns: + matplotlib.figure.Figure: The generated matplotlib figure object containing the ROC curve. + """ + + # Ensure inputs are numpy arrays + y_true = np.asarray(y_true) + y_scores = np.asarray(y_scores) + + # Calculate False Positive Rate (FPR), True Positive Rate (TPR), and thresholds + # fpr: array of shape (n_thresholds,) + # Increasing false positive rates such that element i is the false positive rate + # of predictions with score >= thresholds[i]. + # tpr: array of shape (n_thresholds,) + # Increasing true positive rates such that element i is the true positive rate + # of predictions with score >= thresholds[i]. + # thresholds: array of shape (n_thresholds,) + # Decreasing thresholds on the decision function used to compute fpr and tpr. + # thresholds[0] represents no instances being predicted as positive. + fpr, tpr, thresholds = roc_curve(y_true, y_scores) + + # The AUC provides an aggregate measure of performance across all possible + # classification thresholds. It ranges from 0 to 1, where 1 is perfect + # classification and 0.5 is random. + roc_auc = auc(fpr, tpr) + + fig, ax = plt.subplots(figsize=(8, 8)) + + ax.plot( + fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (AUC = {roc_auc:.2f})" + ) + + # A random classifier would have an AUC of 0.5, indicated by this diagonal line. 
+ ax.plot( + [0, 1], + [0, 1], + color="navy", + lw=2, + linestyle="--", + label="Random Classifier (AUC = 0.5)", + ) + + ax.set_title(title, fontsize=16) + ax.set_xlabel("False Positive Rate (FPR)", fontsize=12) + ax.set_ylabel("True Positive Rate (TPR)", fontsize=12) + + ax.set_xlim([0.0, 1.0]) + ax.set_ylim([0.0, 1.05]) + + ax.legend(loc="lower right") + + ax.grid(True, linestyle="--", alpha=0.7) + + plt.tight_layout() + return fig + + +# dot plot against two arrays +def plot_dot_plot( + x_values, + y_values, + title="Dot Plot", + x_label="X-axis", + y_label="Y-axis", + marker_style="o", + marker_color="blue", + alpha=0.7, + figsize=(8, 6), +): + """ + Plots two arrays against each other as a dot plot (scatter plot). + + Args: + x_values (list or numpy.ndarray): The values for the x-axis. + y_values (list or numpy.ndarray): The values for the y-axis. + Must have the same length as x_values. + title (str, optional): The title of the plot. Defaults to "Dot Plot". + x_label (str, optional): The label for the x-axis. Defaults to "X-axis". + y_label (str, optional): The label for the y-axis. Defaults to "Y-axis". + marker_style (str, optional): The style of the markers. E.g., 'o' for circles, + 'x' for 'x's, '*' for stars. Defaults to 'o'. + marker_color (str, optional): The color of the markers. E.g., 'blue', 'red', + 'green', '#FF5733'. Defaults to 'blue'. + alpha (float, optional): The transparency of the markers (0.0 to 1.0). + Useful for visualizing dense data. Defaults to 0.7. + figsize (tuple, optional): The size of the figure (width, height) in inches. + Defaults to (8, 6). 
# One-letter residue codes that the bar charts below colour as "polar".
_POLAR_AMINO_ACIDS = frozenset("STNQCYDEKRH")


def _sorted_keys_and_values(data_dict, sort_by_value):
    """Return parallel key/value lists ordered by value (ascending) or by key."""
    if sort_by_value:
        ordered = sorted(data_dict.items(), key=lambda item: item[1])
    else:
        # Alphabetical by key gives a stable order when not sorting by value.
        ordered = sorted(data_dict.items())
    return [key for key, _ in ordered], [value for _, value in ordered]


def _polarity_colors(keys):
    """Return "blue" for keys that are polar residue codes and "red" otherwise."""
    return ["blue" if key in _POLAR_AMINO_ACIDS else "red" for key in keys]


def _polarity_legend_handles():
    """Build the red / blue legend patches shared by both bar charts."""
    # Imported lazily, as in the original functions, so the empty-input path
    # never needs matplotlib.
    import matplotlib.patches as mpatches

    return [
        mpatches.Patch(color="red", label="Non-Polar amino acid"),
        mpatches.Patch(color="blue", label="Polar amino acid"),
    ]


def plot_dictionary_bar_chart(
    data_dict: dict[str, float],
    title: str = "Comparison of Values per Category",
    x_label: str = "Category",
    y_label: str = "Value",
    sort_by_value: bool = False,  # Set to True to sort bars by their height
):
    """
    Plot a bar chart for a dictionary whose keys map to single numeric values.

    Bars whose key is one of the polar residue codes (S, T, N, Q, C, Y, D, E,
    K, R, H) are drawn in blue; all other bars are drawn in red. The figure is
    displayed with ``plt.show()``; nothing is returned.

    Args:
        data_dict (dict[str, float]): Category name -> value to plot.
        title (str, optional): The main title for the plot.
        x_label (str, optional): The label for the x-axis (categories).
        y_label (str, optional): The label for the y-axis (values).
        sort_by_value (bool, optional): If True, bars are sorted by ascending
            value. Defaults to False (sorted by key name).

    Returns:
        None. An empty ``data_dict`` prints an error message and returns
        without plotting.
    """
    if not data_dict:
        print("Error: The input dictionary is empty. No chart to plot.")
        return

    keys, values = _sorted_keys_and_values(data_dict, sort_by_value)

    plt.figure(figsize=(10, 6))
    plt.bar(keys, values, color=_polarity_colors(keys))

    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    # Rotate x-axis labels when there are many categories to prevent overlap
    if len(keys) > 5:  # Arbitrary threshold, adjust as needed
        plt.xticks(rotation=45, ha="right")

    plt.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.legend(handles=_polarity_legend_handles())
    plt.show()


def plot_p_values_bar_chart(
    data_dict: dict[str, float],
    title: str = "Comparison of Values per Category",
    x_label: str = "Category",
    y_label: str = "Value",
    sort_by_value: bool = False,  # Set to True to sort bars by their height
):
    """
    Plot per-category values on a logarithmic y-axis with a 0.05 reference line.

    Same colouring as ``plot_dictionary_bar_chart`` (polar residue codes in
    blue, everything else in red). In addition the y-axis is log-scaled and a
    dotted black horizontal line is drawn at y = 0.05. The figure is displayed
    with ``plt.show()``; nothing is returned.

    Args:
        data_dict (dict[str, float]): Category name -> value to plot.
        title (str, optional): The main title for the plot.
        x_label (str, optional): The label for the x-axis (categories).
        y_label (str, optional): The label for the y-axis (values).
        sort_by_value (bool, optional): If True, bars are sorted by ascending
            value. Defaults to False (sorted by key name).

    Returns:
        None. An empty ``data_dict`` prints an error message and returns
        without plotting.
    """
    if not data_dict:
        print("Error: The input dictionary is empty. No chart to plot.")
        return

    keys, values = _sorted_keys_and_values(data_dict, sort_by_value)

    plt.figure(figsize=(10, 6))
    plt.bar(keys, values, color=_polarity_colors(keys))

    plt.yscale("log")

    # Keep the returned line so the legend shows the actual dotted line
    # instead of a filled black rectangle.
    threshold_line = plt.axhline(
        y=0.05, color="black", linestyle=":", label="Threshold at 0.05"
    )

    plt.title(title)
    plt.xlabel(x_label)
    plt.ylabel(y_label)

    # Rotate x-axis labels when there are many categories to prevent overlap
    if len(keys) > 5:  # Arbitrary threshold, adjust as needed
        plt.xticks(rotation=45, ha="right")

    plt.grid(axis="y", linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.legend(
        handles=[*_polarity_legend_handles(), threshold_line],
        loc="upper left",
    )
    plt.show()
def fp_pLDDT_score_9mer(dataset, *, peptide_len=30, kmer=9):
    """
    Add sliding-window mean pLDDT values for each peptide's focal-protein span.

    For every row, the ``pLDDT`` list is sliced to ``peptide_len`` entries
    starting at ``fp_seq_idxs``. The mean of every contiguous run of ``kmer``
    values inside that slice is then collected into a list column named
    ``pLDDT_slice_{kmer}mer`` (``pLDDT_slice_9mer`` with the defaults).

    Args:
        dataset: Frame with a list column ``pLDDT`` and an offset column
            ``fp_seq_idxs``.
            NOTE(review): presumably a polars DataFrame whose ``fp_seq_idxs``
            has already been exploded to one integer per row; confirm with
            the calling scripts.
        peptide_len (int, optional): Number of pLDDT entries taken from the
            offset. Defaults to 30.
        kmer (int, optional): Sliding-window length. Defaults to 9.

    Returns:
        The result of ``dataset.with_columns(...)`` with the new list column.

    Raises:
        ValueError: If ``peptide_len`` or ``kmer`` is smaller than 1.
    """
    if peptide_len < 1 or kmer < 1:
        raise ValueError("peptide_len and kmer must be positive integers")

    def _window_means(values):
        # A slice shorter than ``kmer`` yields an empty range and therefore
        # an empty list.
        return [
            values[start : start + kmer].mean()
            for start in range(len(values) - kmer + 1)
        ]

    return dataset.with_columns(
        pl.col("pLDDT")
        .list.slice(pl.col("fp_seq_idxs"), peptide_len)
        .map_elements(_window_means, return_dtype=pl.List(pl.Float64))
        .alias(f"pLDDT_slice_{kmer}mer")
    )
import polars as pl

# this one is my package
from mdaf3.AF3OutputParser import AF3Output
from mdaf3.FeatureExtraction import *
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc
from pathlib import Path
from af3_linear_epitopes import statistics as st
from af3_linear_epitopes import statistics_focal as stf

# Inference output directories for the hv_class dataset (relative to the
# directory the script is run from).
PEPTIDE_INFERENCE = "data/hv_class/peptide/inference"
FOCAL_INFERENCE = "data/hv_class/focal_protein/inference"

# ---- load staged inputs -----------------------------------------------------
fp_hv_class_dat = pl.read_parquet(
    "data/hv_class/focal_protein/staged/hv_class_focal_protein.filt.clust.parquet"
)
# keep only the rows flagged in the "representative" column
fp_hv_class_dat = fp_hv_class_dat.filter(pl.col("representative"))

hv_class_dat = pl.read_parquet(
    "data/hv_class/peptide/staged/hv_class_peptide.filt.parquet"
)

# ---- peptide-level features -------------------------------------------------
# Each step takes the frame produced by the previous one plus the peptide
# inference directory. The helix / beta percentage features are not computed
# for this dataset (that code was commented out).
all_statistics_hv_class = hv_class_dat
for add_features in (
    st.statistics,
    st.peptide_9mer,
    st.statistics_9mer,
    st.pae_statistics,
    st.pl_avg_weight,
):
    all_statistics_hv_class = add_features(
        all_statistics_hv_class, PEPTIDE_INFERENCE
    )

# ---- focal-protein features -------------------------------------------------
fp_hv_class_dat = stf.pl_fp_extract(fp_hv_class_dat, FOCAL_INFERENCE)

peptides_per_focal_job = hv_class_dat.explode(["fp_job_names", "fp_seq_idxs"])
all_statistics_hv_class_fp = peptides_per_focal_job.join(
    fp_hv_class_dat, left_on="fp_job_names", right_on="job_name"
)
all_statistics_hv_class_fp = all_statistics_hv_class_fp.explode(["fp_seq_idxs"])

# mean pLDDT over the 30 entries starting at each peptide's offset
mean_pLDDT_slice = (
    pl.col("pLDDT")
    .list.slice(pl.col("fp_seq_idxs"), 30)
    .list.mean()
    .alias("mean_pLDDT_slice")
)
all_statistics_hv_class_fp = all_statistics_hv_class_fp.with_columns(
    mean_pLDDT_slice
)

all_statistics_hv_class_fp = stf.fp_pLDDT_score_9mer(all_statistics_hv_class_fp)
all_statistics_hv_class_fp = st.pl_9mer_weight(
    all_statistics_hv_class_fp, PEPTIDE_INFERENCE
)
all_statistics_hv_class_error_fp = st.pl_structure_error(
    all_statistics_hv_class_fp, FOCAL_INFERENCE
)

# ---- write staged outputs ---------------------------------------------------
for frame, destination in (
    (
        all_statistics_hv_class,
        "data/hv_class/peptide/staged/01_hv_class.features.parquet",
    ),
    (
        all_statistics_hv_class_fp,
        "data/hv_class/focal_protein/staged/01_hv_class.exploded.parquet",
    ),
    (
        all_statistics_hv_class_error_fp,
        "data/hv_class/focal_protein/staged/01_hv_class.error.parquet",
    ),
):
    frame.write_parquet(destination)
sklearn.metrics import roc_curve, auc +from pathlib import Path +from af3_linear_epitopes import statistics as st +from af3_linear_epitopes import statistics_focal as stf + +# import dataframes from "staged" directory +fp_in_class_dat = pl.read_parquet( + "data/in_class/focal_protein/staged/in_class_focal_protein.filt.clust.parquet" +).filter(pl.col("representative")) +in_class_dat = pl.read_parquet( + "data/in_class/peptide/staged/in_class_peptide.filt.parquet" +) + +# all_statistics_in_class = st.statistics( +# in_class_dat, +# "data/hv_class/peptide/inference", +# ) + + +# all_statistics_in_class = st.peptide_9mer( +# all_statistics_in_class, +# "data/hv_class/peptide/inference", +# ) + + +# all_statistics_in_class = st.statistics_9mer( +# all_statistics_in_class, +# "data/hv_class/peptide/inference", +# ) +# all_statistics_in_class = st.pae_statistics( +# all_statistics_in_class, +# "data/hv_class/peptide/inference", +# ) +# all_statistics_in_class = st.pl_avg_weight( +# all_statistics_in_class, "data/hv_class/peptide/inference" +# ) + +# all_statistics_in_class = st.pl_helix( +# all_statistics_in_class, "data/hv_class/peptide/inference" +# ) +# all_statistics_in_class = all_statistics_in_class.with_columns( +# ( +# ( +# pl.col("helix").list.sum().cast(pl.Float64) +# / pl.col("helix").list.len().cast(pl.Float64).fill_null(0) +# ) +# * 100 +# ).alias("true__helix_percentage") +# ) + +# all_statistics_in_class = st.pl_beta( +# all_statistics_in_class, " data/hv_class/peptide/inference" +# ) +# all_statistics_in_class = all_statistics_in_class.with_columns( +# ( +# ( +# pl.col("beta").list.sum().cast(pl.Float64) +# / pl.col("beta").list.len().cast(pl.Float64).fill_null(0) +# ) +# * 100 +# ).alias("true_beta_percentage") +# ) + +fp_in_class_dat = stf.pl_fp_extract( + fp_in_class_dat, + "data/in_class/focal_protein/inference", +) +all_statistics_in_class_fp = in_class_dat.explode(["fp_job_names", "fp_seq_idxs"]).join( + fp_in_class_dat, left_on="fp_job_names", 
import polars as pl

# this one is my package
from mdaf3.AF3OutputParser import AF3Output
from mdaf3.FeatureExtraction import *
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import roc_curve, auc
from pathlib import Path
from af3_linear_epitopes import statistics as st
from af3_linear_epitopes import statistics_focal as stf

# Relative peptide inference directory shared by the feature steps below.
PEPTIDE_INFERENCE = "data/hv/peptide/inference"


def _percent_of_list(column, alias):
    """Return the expression 100 * sum(list) / len(list) for *column*, named *alias*."""
    return (
        (
            pl.col(column).list.sum().cast(pl.Float64)
            / pl.col(column).list.len().cast(pl.Float64).fill_null(0)
        )
        * 100
    ).alias(alias)


# import dataframes from "staged" directory
# NOTE(review): this is the only input addressed through "../data"; every
# other relative path in this script starts at "data/". Confirm which working
# directory is intended.
fp_test_dat = pl.read_parquet(
    "../data/hv/focal_protein/staged/00_focal_protein.filt.parquet"
)
peptide_test_dat = pl.read_parquet("data/hv/peptide/staged/00_hv.filt.parquet")

# ---- peptide-level features -------------------------------------------------
all_statistics = st.statistics(peptide_test_dat, PEPTIDE_INFERENCE)
all_statistics = st.peptide_9mer(all_statistics, PEPTIDE_INFERENCE)
all_statistics = st.statistics_9mer(all_statistics, PEPTIDE_INFERENCE)
all_statistics = st.pae_statistics(all_statistics, PEPTIDE_INFERENCE)
all_statistics = st.pl_avg_weight(all_statistics, PEPTIDE_INFERENCE)

all_statistics = st.pl_helix(all_statistics, PEPTIDE_INFERENCE)
all_statistics = all_statistics.with_columns(
    _percent_of_list("helix", "true__helix_percentage")
)

# The path passed here used to start with a space (" data/hv/..."), which
# points at a non-existent directory; it now matches the other steps.
all_statistics = st.pl_beta(all_statistics, PEPTIDE_INFERENCE)
all_statistics = all_statistics.with_columns(
    _percent_of_list("beta", "true_beta_percentage")
)

# ---- focal-protein features -------------------------------------------------
fp_test_dat = stf.pl_fp_extract(
    fp_test_dat,
    "/tgen_labs/altin/alphafold3/runs/linear_peptide/data/hv/focal_protein/inference",
)
all_statistics_fp = peptide_test_dat.explode(["fp_job_names", "fp_seq_idxs"]).join(
    fp_test_dat, left_on="fp_job_names", right_on="job_name"
)

# NOTE(review): second explode of fp_seq_idxs; presumably the column is
# still a list after the first explode. Confirm against the staged schema.
all_statistics_fp = all_statistics_fp.explode(["fp_seq_idxs"])
all_statistics_fp = all_statistics_fp.with_columns(
    pl.col("pLDDT")
    .list.slice(pl.col("fp_seq_idxs"), 30)
    .list.mean()
    .alias("mean_pLDDT_slice")
)

all_statistics_fp = stf.fp_pLDDT_score_9mer(
    all_statistics_fp,
)
# NOTE(review): the two calls below use an absolute /scratch path while the
# steps above use the relative PEPTIDE_INFERENCE directory; left unchanged.
all_statistics_fp = st.pl_9mer_weight(
    all_statistics_fp, "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/inference"
)
all_statistics = st.pl_9mer_weight(
    all_statistics, "/scratch/sromero/af3-linear-epitopes/data/hv/peptide/inference"
)

# ---- write staged outputs ---------------------------------------------------
all_statistics.write_parquet("data/hv/peptide/staged/01_hv.features.parquet")
all_statistics_fp.write_parquet("data/hv/peptide/staged/01_hv.exploded.parquet")