From 0d7638b18c915c4b371770408b6962809a9bf100 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Tue, 3 Feb 2026 13:28:05 +0200 Subject: [PATCH 01/12] Plotting hacks --- .../compressor/plotting/plot_metrics.py | 118 ++++++++++++------ 1 file changed, 79 insertions(+), 39 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 408b098..0385907 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -13,15 +13,21 @@ _COMPRESSOR2LINEINFO = [ ("jpeg2000", ("#EE7733", "-")), - ("sperr", ("#117733", ":")), - ("zfp-round", ("#DDAA33", "--")), + ("sperr", ("#117733", "-")), + ("zfp-round", ("#DDAA33", "-")), ("zfp", ("#EE3377", "--")), - ("sz3", ("#CC3311", "-.")), - ("bitround-pco", ("#0077BB", ":")), + ("sz3", ("#CC3311", "-")), + ("bitround-pco", ("#0077BB", "-")), ("bitround", ("#33BBEE", "-")), ("stochround-pco", ("#BBBBBB", "--")), ("stochround", ("#009988", "--")), ("tthresh", ("#882255", "-.")), + ("safeguarded-sperr", ("#117733", ":")), + ("safeguarded-zfp-round", ("#DDAA33", ":")), + ("safeguarded-sz3", ("#CC3311", ":")), + ("safeguarded-zero-dssim", ("#9467BD", "--")), + ("safeguarded-zero", ("#9467BD", ":")), + ("safeguarded-bitround-pco", ("#0077BB", ":")), ] @@ -44,6 +50,25 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: ("stochround-pco", "StochRound + PCO"), ("stochround", "StochRound + Zstd"), ("tthresh", "TTHRESH"), + ("safeguarded-sperr", "Safeguarded(SPERR)"), + ("safeguarded-zfp-round", "Safeguarded(ZFP-ROUND)"), + ("safeguarded-sz3", "Safeguarded(SZ3)"), + ("safeguarded-zero-dssim", "Safeguarded(0, dSSIM)"), + ("safeguarded-zero", "Safeguarded(0)"), + ("safeguarded-bitround-pco", "Safeguarded(BitRound + PCO)"), +] + +_COMPRESSOR_ORDER = [ + "BitRound + PCO", + "Safeguarded(BitRound + PCO)", + "ZFP-ROUND", + "Safeguarded(ZFP-ROUND)", + "SZ3", + "Safeguarded(SZ3)", + "SPERR", + "Safeguarded(SPERR)", + "Safeguarded(0)", + "Safeguarded(0, dSSIM)", ] DISTORTION2LEGEND_NAME = { @@ -102,6 +127,7 @@ def plot_metrics( df = pd.read_csv(metrics_path / "all_results.csv") # Filter out excluded datasets and compressors + # bitround jpeg2000-conservative-abs stochround-conservative-abs stochround-pco-conservative-abs zfp-conservative-abs bitround-conservative-rel stochround-pco stochround zfp jpeg2000 df = df[~df["Compressor"].isin(exclude_compressor)] df = df[~df["Dataset"].isin(exclude_dataset)] is_tiny = df["Dataset"].str.endswith("-tiny") @@ -111,13 +137,13 @@ def plot_metrics( filter_chunked = is_chunked if chunked_datasets else ~is_chunked df = df[filter_chunked] - _plot_per_variable_metrics( - datasets=datasets, - compressed_datasets=compressed_datasets, - plots_path=plots_path, - all_results=df, - rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], - ) + # _plot_per_variable_metrics( + # datasets=datasets, + # compressed_datasets=compressed_datasets, + # plots_path=plots_path, + # all_results=df, + # rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], + # ) df = _rename_compressors(df) normalized_df = _normalize(df) @@ -419,7 +445,10 @@ def _plot_aggregated_rd_curve( # Exclude variables that are not relevant for the distortion metric. 
normalized_df = normalized_df[~normalized_df["Variable"].isin(exclude_vars)] - compressors = normalized_df["Compressor"].unique() + compressors = sorted( + normalized_df["Compressor"].unique(), + key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)), + ) agg_distortion = normalized_df.groupby(["Error Bound Name", "Compressor"])[ [compression_metric, distortion_metric] ].agg(agg) @@ -503,8 +532,8 @@ def _plot_aggregated_rd_curve( ) plt.legend( title="Compressor", - loc="upper right", - bbox_to_anchor=(0.8, 0.99), + loc="upper left", + # bbox_to_anchor=(0.8, 0.99), fontsize=12, title_fontsize=14, ) @@ -614,27 +643,32 @@ def _plot_instruction_count(df, outfile: None | Path = None): def _get_median_and_quantiles(df, encode_column, decode_column): - return df.groupby(["Compressor", "Error Bound Name"])[ - [encode_column, decode_column] - ].agg( - encode_median=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.5) - ), - encode_lower_quantile=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.25) - ), - encode_upper_quantile=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.75) - ), - decode_median=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.5) - ), - decode_lower_quantile=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.25) - ), - decode_upper_quantile=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.75) - ), + return ( + df.groupby(["Compressor", "Error Bound Name"])[[encode_column, decode_column]] + .agg( + encode_median=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.5) + ), + encode_lower_quantile=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.25) + ), + encode_upper_quantile=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.75) + ), + decode_median=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.5) + ), + decode_lower_quantile=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.25) + ), + decode_upper_quantile=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.75) + ), + ) + .sort_index( + level=0, + key=lambda ks: [_COMPRESSOR_ORDER.index(_get_legend_name(k)) for k in ks], + ) ) @@ -645,7 +679,10 @@ def _plot_grouped_df( # Bar width bar_width = 0.35 - compressors = grouped_df.index.levels[0].tolist() + compressors = sorted( + grouped_df.index.levels[0].tolist(), + key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)), + ) x_labels = [_get_legend_name(c) for c in compressors] x_positions = range(len(x_labels)) @@ -653,7 +690,10 @@ def _plot_grouped_df( for i, error_bound in enumerate(error_bounds): ax = axes[i] - bound_data = grouped_df.xs(error_bound, level="Error Bound Name") + bound_data = grouped_df.xs(error_bound, level="Error Bound Name").sort_index( + level=0, + key=lambda ks: [_COMPRESSOR_ORDER.index(_get_legend_name(k)) for k in ks], + ) # Plot encode throughput ax.bar( @@ -720,11 +760,11 @@ def _plot_bound_violations(df, bound_names, outfile: None | Path = None): df_bound["Compressor"] = df_bound["Compressor"].map(_get_legend_name) pass_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Passed)" - ) + ).sort_index(key=lambda ks: [_COMPRESSOR_ORDER.index(k) for k in ks]) pass_fail = pass_fail.astype(np.float32) fraction_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Value)" - ) + ).sort_index(key=lambda ks: [_COMPRESSOR_ORDER.index(k) for k in ks]) annotations = 
fraction_fail.map( lambda x: "{:.2f}".format(x * 100) if x * 100 >= 0.01 else "<0.01" ) From 87aacacedbfdc90434cb2b23d7504eb3d3ac50e5 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Tue, 3 Feb 2026 15:02:05 +0200 Subject: [PATCH 02/12] some improvements --- .../compressor/plotting/plot_metrics.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 0385907..fbf4990 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -42,27 +42,27 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: _COMPRESSOR2LEGEND_NAME = [ ("jpeg2000", "JPEG2000"), ("sperr", "SPERR"), - ("zfp-round", "ZFP-ROUND"), + ("zfp-round", "ZFP"), ("zfp", "ZFP"), ("sz3", "SZ3"), - ("bitround-pco", "BitRound + PCO"), + ("bitround-pco", "BitRound"), ("bitround", "BitRound + Zstd"), ("stochround-pco", "StochRound + PCO"), ("stochround", "StochRound + Zstd"), ("tthresh", "TTHRESH"), ("safeguarded-sperr", "Safeguarded(SPERR)"), - ("safeguarded-zfp-round", "Safeguarded(ZFP-ROUND)"), + ("safeguarded-zfp-round", "Safeguarded(ZFP)"), ("safeguarded-sz3", "Safeguarded(SZ3)"), ("safeguarded-zero-dssim", "Safeguarded(0, dSSIM)"), ("safeguarded-zero", "Safeguarded(0)"), - ("safeguarded-bitround-pco", "Safeguarded(BitRound + PCO)"), + ("safeguarded-bitround-pco", "Safeguarded(BitRound)"), ] _COMPRESSOR_ORDER = [ - "BitRound + PCO", - "Safeguarded(BitRound + PCO)", - "ZFP-ROUND", - "Safeguarded(ZFP-ROUND)", + "BitRound", + "Safeguarded(BitRound)", + "ZFP", + "Safeguarded(ZFP)", "SZ3", "Safeguarded(SZ3)", "SPERR", @@ -476,7 +476,13 @@ def _plot_aggregated_rd_curve( if remove_outliers: # SZ3 and JPEG2000 often give outlier values and violate the bounds. 
- exclude_compressors = ["sz3", "jpeg2000"] + exclude_compressors = [ + "sz3", + "jpeg2000", + "safeguarded-zero-dssim", + "safeguarded-zero", + "safeguarded-sz3", + ] filtered_agg = agg_distortion[ ~agg_distortion.index.get_level_values("Compressor").isin( exclude_compressors From 46ecfd8f779f08b9d194c0428411467f28bc644e Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Tue, 3 Feb 2026 15:55:44 +0200 Subject: [PATCH 03/12] Adjust SZ3 name --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index fbf4990..2e87acf 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -44,7 +44,7 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: ("sperr", "SPERR"), ("zfp-round", "ZFP"), ("zfp", "ZFP"), - ("sz3", "SZ3"), + ("sz3", "SZ3[v3.2]"), ("bitround-pco", "BitRound"), ("bitround", "BitRound + Zstd"), ("stochround-pco", "StochRound + PCO"), @@ -52,7 +52,7 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: ("tthresh", "TTHRESH"), ("safeguarded-sperr", "Safeguarded(SPERR)"), ("safeguarded-zfp-round", "Safeguarded(ZFP)"), - ("safeguarded-sz3", "Safeguarded(SZ3)"), + ("safeguarded-sz3", "Safeguarded(SZ3[v3.2])"), ("safeguarded-zero-dssim", "Safeguarded(0, dSSIM)"), ("safeguarded-zero", "Safeguarded(0)"), ("safeguarded-bitround-pco", "Safeguarded(BitRound)"), @@ -63,8 +63,8 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: "Safeguarded(BitRound)", "ZFP", "Safeguarded(ZFP)", - "SZ3", - "Safeguarded(SZ3)", + "SZ3[v3.2]", + "Safeguarded(SZ3[v3.2])", "SPERR", "Safeguarded(SPERR)", "Safeguarded(0)", From 29dd7bde1c3272427dc38183ab0acd6648b5eb04 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Wed, 4 Feb 2026 15:40:59 +0200 Subject: [PATCH 04/12] Draft safeguards scorecards --- scorecards.ipynb | 722 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 722 insertions(+) create mode 100644 scorecards.ipynb diff --git a/scorecards.ipynb b/scorecards.ipynb new file mode 100644 index 0000000..64bf6ee --- /dev/null +++ b/scorecards.ipynb @@ -0,0 +1,722 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e7ab252b", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import seaborn as sns\n", + "from matplotlib.colors import LinearSegmentedColormap\n", + "import matplotlib.patches as mpatches\n", + "from matplotlib.lines import Line2D\n", + "\n", + "from pathlib import Path\n", + "from climatebenchpress.compressor.plotting.plot_metrics import (\n", + " _rename_compressors, \n", + " _get_legend_name,\n", + " _normalize,\n", + " _get_lineinfo,\n", + " DISTORTION2LEGEND_NAME,\n", + " _COMPRESSOR_ORDER,\n", + " _savefig\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4ee15a6b", + "metadata": {}, + "source": [ + "# Process results" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2d242e7c", + "metadata": {}, + "outputs": [], + "source": [ + "results_file = \"metrics/all_results.csv\"\n", + "df = pd.read_csv(results_file)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ec124c29", + "metadata": {}, + "outputs": [], + "source": [ + "def create_data_matrix(\n", + " df: pd.DataFrame, \n", + " error_bound: str, \n", 
+ " metrics: list[str] = ['DSSIM', 'MAE', 'Max Absolute Error', \"Spectral Error\", \"Compression Ratio [raw B / enc B]\", 'Satisfies Bound (Value)']\n", + "):\n", + " df_filtered = df[df['Error Bound Name'] == error_bound].copy()\n", + " df_filtered[\"Satisfies Bound (Value)\"] = df_filtered[\"Satisfies Bound (Value)\"] * 100 # Convert to percentage\n", + "\n", + " # Get unique variables and compressors\n", + " # dataset_variables = sorted(df_filtered[['Dataset', 'Variable']].drop_duplicates().apply(lambda x: \"/\".join(x), axis=1).unique())\n", + " dataset_variables = sorted(df_filtered['Variable'].unique())\n", + " compressors = sorted(df_filtered['Compressor'].unique(), key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)))\n", + "\n", + " column_labels = []\n", + " for metric in metrics:\n", + " for dataset_variable in dataset_variables:\n", + " column_labels.append(f\"{dataset_variable}\\n{metric}\")\n", + "\n", + " # Initialize the data matrix\n", + " data_matrix = np.full((len(compressors), len(column_labels)), np.nan)\n", + "\n", + " # Fill the matrix with data\n", + " for i, compressor in enumerate(compressors):\n", + " for j, metric in enumerate(metrics):\n", + " for k, dataset_variable in enumerate(dataset_variables):\n", + " # Get data for this compressor-variable combination\n", + " # dataset, variable = dataset_variable.split('/')\n", + " variable = dataset_variable\n", + " subset = df_filtered[\n", + " (df_filtered['Compressor'] == compressor) & \n", + " (df_filtered['Variable'] == variable) #&\n", + " # (df_filtered['Dataset'] == dataset)\n", + " ]\n", + " if subset.empty:\n", + " print(f\"No data for Compressor: {compressor}, Variable: {variable}\")\n", + " continue\n", + "\n", + " if metric in [\"DSSIM\", \"Spectral Error\"] and variable in [\"ta\", \"tos\"]:\n", + " # These variables have large regions of NaN values which makes the \n", + " # DSSIM and Spectral Error values unreliable.\n", + " continue\n", + "\n", + "\n", + " col_idx = j * len(dataset_variables) + k\n", + " if metric in subset.columns:\n", + " values = subset[metric]\n", + " if len(values) == 1:\n", + " data_matrix[i, col_idx] = values.iloc[0]\n", + " \n", + " return data_matrix, compressors, dataset_variables" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2a399d30", + "metadata": {}, + "outputs": [], + "source": [ + "df = df[~df[\"Compressor\"].isin([\n", + " \"bitround\", \"jpeg2000-conservative-abs\", \"stochround-conservative-abs\",\n", + " \"stochround-pco-conservative-abs\", \"zfp-conservative-abs\",\n", + " \"bitround-conservative-rel\", \"stochround-pco\", \"stochround\", \"zfp\", \"jpeg2000\",\n", + "])]\n", + "df = df[~df[\"Dataset\"].str.contains(\"-tiny\")]\n", + "df = df[~df[\"Dataset\"].str.contains(\"-chunked\")]\n", + "df = _rename_compressors(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2d019b8d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, 
Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", 
+ "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: rlut\n" + ] + } + ], + "source": [ + "metrics = ['DSSIM', 'MAE', 'Max Absolute Error', \"Spectral Error\", \"Compression Ratio [raw B / enc B]\", 'Satisfies Bound (Value)']\n", + "scorecard_data = {}\n", + "for error_bound in [\"low\", \"mid\", \"high\"]:\n", + " scorecard_data[error_bound] = create_data_matrix(df, error_bound, metrics)" + ] + }, + { + "cell_type": "markdown", + "id": "ae80d757", + "metadata": {}, + "source": [ + "# Scorecard" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "871ae766", + "metadata": {}, + "outputs": [], + "source": [ + "METRICS2NAME = {\n", + " # \"Max Absolute Error\": \"MaxAE\",\n", + " \"MAE\": \"Mean Absolute Error\",\n", + " \"Spatial Relative Error (Value)\": \"SRE\",\n", + " \"Compression Ratio [raw B / enc B]\": \"Compression Ratio\",\n", + " \"Satisfies Bound (Value)\": r\"% of Pixels Exceeding Error Bound\",\n", + "}\n", + "\n", + "VARIABLE2NAME = {\n", + " \"10m_u_component_of_wind\": \"10u\",\n", + " \"10m_v_component_of_wind\": \"10v\",\n", + " \"mean_sea_level_pressure\": \"msl\",\n", + "}\n", + "\n", + "DATASET2PREFIX = {\n", + " \"era5-hurricane\": \"h-\",\n", + "}\n", + "\n", + "def get_variable_label(variable):\n", + " dataset, var_name = variable.split('/')\n", + " prefix = DATASET2PREFIX.get(dataset, \"\")\n", + " var_name = VARIABLE2NAME.get(var_name, var_name)\n", + " return f\"{prefix}{var_name}\"\n", + "\n", + "\n", + "def create_compression_scorecard(\n", + " data_matrix, \n", + " compressors, \n", + " variables, \n", + " metrics, \n", + " cbar=True,\n", + " ref_compressor='sz3', \n", + " higher_better_metrics=[\"DSSIM\", \"Compression Ratio [raw B / enc B]\"],\n", + " save_fn=None,\n", + " compare_against_0=False,\n", + " highlight_bigger_than_one=False\n", + "):\n", + " \"\"\"\n", + " Create a scorecard plot similar to the weather forecasting example\n", + " \n", + " Parameters:\n", + " - data_matrix: 2D array with compressors as rows, metric-variable combinations as columns\n", + " - compressors: list of compressor names\n", + " - variables: list of variable names \n", + " - metrics: list of metric names\n", + " - ref_compressor: reference compressor for relative calculations\n", + " - save_fn: filename to save plot (optional)\n", + " \"\"\"\n", + " \n", + " # Calculate relative differences vs reference compressor\n", + " ref_idx = compressors.index(ref_compressor)\n", + " ref_values = data_matrix[ref_idx, :]\n", + " if compare_against_0:\n", + " ref_values = np.zeros_like(data_matrix[ref_idx, :])\n", + " \n", + " relative_matrix = np.full_like(data_matrix, 
np.nan)\n",
+    "    if highlight_bigger_than_one:\n",
+    "        relative_matrix = np.sign(data_matrix) * 101\n",
+    "        for j in range(data_matrix.shape[1]):\n",
+    "            if metrics[j // len(variables)] == \"Satisfies Bound (Value)\":\n",
+    "                # For bound satisfaction, lower is better (fewer pixels exceeding the error bound).\n",
+    "                relative_matrix[:, j] = -1 * relative_matrix[:, j]\n",
+    "    else:\n",
+    "        for i in range(len(compressors)):\n",
+    "            for j in range(data_matrix.shape[1]):\n",
+    "                if not np.isnan(data_matrix[i, j]) and not np.isnan(ref_values[j]):\n",
+    "                    ref_val = np.abs(ref_values[j])\n",
+    "                    if ref_val == 0.0:\n",
+    "                        ref_val = 1e-10  # Avoid division by zero\n",
+    "                    if metrics[j // len(variables)] in higher_better_metrics:\n",
+    "                        # Higher is better metrics\n",
+    "                        relative_matrix[i, j] = (ref_values[j] - data_matrix[i, j]) / ref_val * 100\n",
+    "                    elif metrics[j // len(variables)] == \"Satisfies Bound (Value)\":\n",
+    "                        relative_matrix[i, j] = 100 if data_matrix[i, j] != 0 else 0\n",
+    "                    else:\n",
+    "                        relative_matrix[i, j] = (data_matrix[i, j] - ref_values[j]) / ref_val * 100\n",
+    "\n",
+    "    # Set up colormap - similar to original\n",
+    "    reds = sns.color_palette('Reds', 6)\n",
+    "    blues = sns.color_palette('Blues_r', 6)\n",
+    "    cmap = mpl.colors.ListedColormap(blues + [(0.95, 0.95, 0.95)] + reds)\n",
+    "    # cb_levels = [-50, -20, -10, -5, -2, -1, 1, 2, 5, 10, 20, 50]\n",
+    "    # cb_levels = [-75, -50, -25, -10, -5, -1, 1, 5, 10, 25, 50, 75]\n",
+    "    cb_levels = [-100, -75, -50, -25, -10, -1, 1, 10, 25, 50, 75, 100]\n",
+    "\n",
+    "    norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend='both')\n",
+    "    \n",
+    "    # Calculate figure dimensions\n",
+    "    ncompressors = len(compressors)\n",
+    "    nvariables = len(variables)\n",
+    "    nmetrics = len(metrics)\n",
+    "    \n",
+    "    panel_width = (2.5 / 5) * nvariables\n",
+    "    label_width = 1.5 * panel_width\n",
+    "    padding_right = 0.1\n",
+    "    panel_height = panel_width / nvariables\n",
+    "    \n",
+    "    title_height = panel_height * 1.25\n",
+    "    cbar_height = panel_height * 2\n",
+    "    spacing_height = panel_height * 0.1\n",
+    "    spacing_width = panel_height * 0.2\n",
+    "    \n",
+    "    total_width = label_width + nmetrics * panel_width + (nmetrics - 1) * spacing_width + padding_right\n",
+    "    total_height = title_height + cbar_height + ncompressors * panel_height + (ncompressors - 1) * spacing_height\n",
+    "    \n",
+    "    # Create figure and gridspec\n",
+    "    fig = plt.figure(figsize=(total_width, total_height))\n",
+    "    gs = mpl.gridspec.GridSpec(\n",
+    "        ncompressors, nmetrics,\n",
+    "        figure=fig,\n",
+    "        left=label_width / total_width,\n",
+    "        right=1 - padding_right / total_width,\n",
+    "        top=1 - (title_height / total_height),\n",
+    "        bottom=cbar_height / total_height,\n",
+    "        hspace=spacing_height / panel_height,\n",
+    "        wspace=spacing_width / panel_width\n",
+    "    )\n",
+    "    \n",
+    "    # Plot each panel\n",
+    "    for row, compressor in enumerate(compressors):\n",
+    "        for col, metric in enumerate(metrics):\n",
+    "            ax = fig.add_subplot(gs[row, col])\n",
+    "\n",
+    "            # Get data for this metric (all variables)\n",
+    "            start_col = col * nvariables\n",
+    "            end_col = start_col + nvariables\n",
+    "\n",
+    "            rel_values = relative_matrix[row, start_col:end_col].reshape(1, -1)\n",
+    "            abs_values = data_matrix[row, start_col:end_col]\n",
+    "            \n",
+    "            # Create heatmap\n",
+    "            img = ax.imshow(rel_values, aspect='auto', cmap=cmap, norm=norm)\n",
+    "            \n",
+    "            # Customize axes\n",
+    "            ax.set_xticks([])\n",
+    "            ax.set_xticklabels([])\n",
+    "            ax.set_yticks([])\n",
+    "            ax.set_yticklabels([])\n",
+    "            \n",
+ " # Add white grid lines\n", + " for i in range(nvariables):\n", + " rect = mpl.patches.Rectangle(\n", + " (i - 0.5, -0.5), 1, 1,\n", + " linewidth=1, edgecolor='white', facecolor='none'\n", + " )\n", + " ax.add_patch(rect)\n", + " \n", + " # Add absolute values as text\n", + " for i, val in enumerate(abs_values):\n", + " # Ensure we don't have black text on dark background\n", + " color = \"black\" if abs(rel_values[0, i]) < 75 else \"white\"\n", + " fontsize = 10\n", + " # Format numbers appropriately\n", + " if metric in [\"DSSIM\", \"Spectral Error\"] and variables[i] in [\"ta\", \"tos\"]:\n", + " # These variables have large regions of NaN values which makes the \n", + " # DSSIM and Spectral Error values unreliable.\n", + " text = \"N/A\"\n", + " color = \"black\"\n", + " elif np.isnan(val):\n", + " text = \"Fail\"\n", + " color = \"black\"\n", + " elif abs(val) > 10_000:\n", + " text = f\"{val:.1e}\"\n", + " fontsize = 8\n", + " elif abs(val) > 10:\n", + " text = f\"{val:.0f}\"\n", + " elif abs(val) > 1:\n", + " text = f\"{val:.1f}\"\n", + " elif val == 1 and metric == \"DSSIM\":\n", + " text = \"1\"\n", + " elif val == 0:\n", + " text = \"0\"\n", + " elif abs(val) < 0.01:\n", + " text = f\"{val:.1e}\"\n", + " fontsize = 8\n", + " else:\n", + " text = f\"{val:.2f}\"\n", + " ax.text(\n", + " i, \n", + " 0, \n", + " text, \n", + " ha='center', \n", + " va='center', \n", + " fontsize=fontsize, \n", + " color=color\n", + " )\n", + "\n", + " # Add row labels (compressor names)\n", + " if col == 0:\n", + " ax.set_ylabel(_get_legend_name(compressor), rotation=0, ha='right', va='center',\n", + " labelpad=10, fontsize=14)\n", + " \n", + " # Add column titles (variable names)\n", + " if row == 0:\n", + " # ax.set_title(VARIABLE2NAME.get(variable, variable), fontsize=10, pad=10)\n", + " ax.set_title(METRICS2NAME.get(metric, metric), fontsize=16, pad=10)\n", + "\n", + " # Add metric labels at the top on the top row\n", + " if row == 0:\n", + " # ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", + " # ax.set_xticks(range(nmetrics))\n", + " # ax.set_xticklabels(\n", + " # [METRICS2NAME.get(m, m) for m in metrics], \n", + " # rotation=45, \n", + " # ha='left', fontsize=8)\n", + " ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", + " ax.set_xticks(range(nvariables))\n", + " ax.set_xticklabels(\n", + " [VARIABLE2NAME.get(v, v) for v in variables],\n", + " rotation=45,\n", + " ha='left', fontsize=12)\n", + " \n", + " # Style spines\n", + " for spine in ax.spines.values():\n", + " spine.set_color('0.7')\n", + " \n", + " # Add colorbar\n", + " if cbar and not highlight_bigger_than_one:\n", + " rel_cbar_height = cbar_height / total_height\n", + " cax = fig.add_axes((0.4, rel_cbar_height * 0.3, 0.5, rel_cbar_height * 0.2))\n", + " cb = fig.colorbar(img, cax=cax, orientation='horizontal')\n", + " cb.ax.set_xticks(cb_levels)\n", + " if highlight_bigger_than_one:\n", + " cb.ax.set_xlabel('Better ← |non-chunked - chunked| → Worse', fontsize=16)\n", + " else:\n", + " cb.ax.set_xlabel(f'Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse', fontsize=16)\n", + " \n", + " if highlight_bigger_than_one:\n", + " chunking_handles = [\n", + " Line2D([], [], marker=\"s\", color=cmap(101), linestyle=\"None\", markersize=10,\n", + " label=\"Not Chunked Better\"),\n", + " Line2D([], [], marker=\"s\", color=cmap(-101), linestyle=\"None\", markersize=10,\n", + " label=\"Chunked Better\"),\n", + " ]\n", + "\n", + " ax.legend(\n", + " 
handles=chunking_handles,\n", + " loc=\"upper left\",\n", + " ncol=2,\n", + " bbox_to_anchor=(-0.5, -0.05),\n", + " fontsize=16\n", + " )\n", + "\n", + " plt.tight_layout()\n", + " \n", + " if save_fn:\n", + " plt.savefig(save_fn, dpi=300, bbox_inches='tight')\n", + " plt.close()\n", + " else:\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "678c927b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for low bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for mid bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for high bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + } + ], + "source": [ + "for bound_name, (data_matrix, compressors, variables) in scorecard_data.items():\n", + " print(f\"Creating scorecard for {bound_name} bound...\")\n", + " # Split into two rows for better readability.\n", + " create_compression_scorecard(\n", + " data_matrix[:, :3*len(variables)], \n", + " compressors, \n", + " variables, \n", + " metrics[:3],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=False,\n", + " save_fn=f\"figures_updated/{bound_name}_scorecard_row1.pdf\"\n", + " )\n", + "\n", + " create_compression_scorecard(\n", + " data_matrix[:, 3*len(variables):], \n", + " compressors, \n", + " variables, \n", + " metrics[3:],\n", + " ref_compressor=\"bitround-pco\",\n", + " save_fn=f\"figures_updated/{bound_name}_scorecard_row2.pdf\"\n", + " )" + ] + }, + { + "cell_type": "markdown", + "id": "3afb646e", + "metadata": {}, + "source": [ + "## Two-Column Scorecard" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b6fe5f55", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for low bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for mid bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for high bound...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + " plt.tight_layout()\n" + ] + } + ], + "source": [ + "for bound_name, (data_matrix, compressors, variables) in scorecard_data.items():\n", + " print(f\"Creating scorecard for {bound_name} bound...\")\n", + " # Split into two rows for better readability.\n", + " num_vars = len(variables)\n", + " create_compression_scorecard(\n", + " data_matrix[:, :2*num_vars], \n", + " compressors, \n", + " variables, \n", + " metrics[:2],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=False,\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row1.pdf\"\n", + " )\n", + "\n", + " create_compression_scorecard(\n", + " data_matrix[:, 2*num_vars:4*num_vars], \n", + " compressors, \n", + " variables, \n", + " metrics[2:4],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=False,\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row2.pdf\"\n", + " )\n", + "\n", + " create_compression_scorecard(\n", + " data_matrix[:, 4*num_vars:], \n", + " compressors, \n", + " variables, \n", + " metrics[4:],\n", + " ref_compressor=\"bitround-pco\",\n", + " 
save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row3.pdf\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c546f6b0-bb7b-4646-83bc-3f502bcf6a9f", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From fd15ac3cb95f1214d6b248b3c9891accc596a748 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Thu, 5 Feb 2026 14:46:41 +0200 Subject: [PATCH 05/12] some more tweaks --- scorecards.ipynb | 330 ++++++++++-------- .../compressor/plotting/plot_metrics.py | 23 +- 2 files changed, 203 insertions(+), 150 deletions(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index 64bf6ee..8beff56 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -18,13 +18,13 @@ "\n", "from pathlib import Path\n", "from climatebenchpress.compressor.plotting.plot_metrics import (\n", - " _rename_compressors, \n", + " _rename_compressors,\n", " _get_legend_name,\n", " _normalize,\n", " _get_lineinfo,\n", " DISTORTION2LEGEND_NAME,\n", " _COMPRESSOR_ORDER,\n", - " _savefig\n", + " _savefig,\n", ")" ] }, @@ -55,17 +55,29 @@ "outputs": [], "source": [ "def create_data_matrix(\n", - " df: pd.DataFrame, \n", - " error_bound: str, \n", - " metrics: list[str] = ['DSSIM', 'MAE', 'Max Absolute Error', \"Spectral Error\", \"Compression Ratio [raw B / enc B]\", 'Satisfies Bound (Value)']\n", + " df: pd.DataFrame,\n", + " error_bound: str,\n", + " metrics: list[str] = [\n", + " \"DSSIM\",\n", + " \"MAE\",\n", + " \"Max Absolute Error\",\n", + " \"Spectral Error\",\n", + " \"Compression Ratio [raw B / enc B]\",\n", + " \"Satisfies Bound (Value)\",\n", + " ],\n", "):\n", - " df_filtered = df[df['Error Bound Name'] == error_bound].copy()\n", - " df_filtered[\"Satisfies Bound (Value)\"] = df_filtered[\"Satisfies Bound (Value)\"] * 100 # Convert to percentage\n", + " df_filtered = df[df[\"Error Bound Name\"] == error_bound].copy()\n", + " df_filtered[\"Satisfies Bound (Value)\"] = (\n", + " df_filtered[\"Satisfies Bound (Value)\"] * 100\n", + " ) # Convert to percentage\n", "\n", " # Get unique variables and compressors\n", " # dataset_variables = sorted(df_filtered[['Dataset', 'Variable']].drop_duplicates().apply(lambda x: \"/\".join(x), axis=1).unique())\n", - " dataset_variables = sorted(df_filtered['Variable'].unique())\n", - " compressors = sorted(df_filtered['Compressor'].unique(), key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)))\n", + " dataset_variables = sorted(df_filtered[\"Variable\"].unique())\n", + " compressors = sorted(\n", + " df_filtered[\"Compressor\"].unique(),\n", + " key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)),\n", + " )\n", "\n", " column_labels = []\n", " for metric in metrics:\n", @@ -83,8 +95,8 @@ " # dataset, variable = dataset_variable.split('/')\n", " variable = dataset_variable\n", " subset = df_filtered[\n", - " (df_filtered['Compressor'] == compressor) & \n", - " (df_filtered['Variable'] == variable) #&\n", + " (df_filtered[\"Compressor\"] == compressor)\n", + " & (df_filtered[\"Variable\"] == variable) # &\n", " # (df_filtered['Dataset'] == dataset)\n", " ]\n", " if subset.empty:\n", @@ -92,17 +104,16 @@ " continue\n", 
"\n", " if metric in [\"DSSIM\", \"Spectral Error\"] and variable in [\"ta\", \"tos\"]:\n", - " # These variables have large regions of NaN values which makes the \n", + " # These variables have large regions of NaN values which makes the\n", " # DSSIM and Spectral Error values unreliable.\n", " continue\n", "\n", - "\n", " col_idx = j * len(dataset_variables) + k\n", " if metric in subset.columns:\n", " values = subset[metric]\n", " if len(values) == 1:\n", " data_matrix[i, col_idx] = values.iloc[0]\n", - " \n", + "\n", " return data_matrix, compressors, dataset_variables" ] }, @@ -113,11 +124,22 @@ "metadata": {}, "outputs": [], "source": [ - "df = df[~df[\"Compressor\"].isin([\n", - " \"bitround\", \"jpeg2000-conservative-abs\", \"stochround-conservative-abs\",\n", - " \"stochround-pco-conservative-abs\", \"zfp-conservative-abs\",\n", - " \"bitround-conservative-rel\", \"stochround-pco\", \"stochround\", \"zfp\", \"jpeg2000\",\n", - "])]\n", + "df = df[\n", + " ~df[\"Compressor\"].isin(\n", + " [\n", + " \"bitround\",\n", + " \"jpeg2000-conservative-abs\",\n", + " \"stochround-conservative-abs\",\n", + " \"stochround-pco-conservative-abs\",\n", + " \"zfp-conservative-abs\",\n", + " \"bitround-conservative-rel\",\n", + " \"stochround-pco\",\n", + " \"stochround\",\n", + " \"zfp\",\n", + " \"jpeg2000\",\n", + " ]\n", + " )\n", + "]\n", "df = df[~df[\"Dataset\"].str.contains(\"-tiny\")]\n", "df = df[~df[\"Dataset\"].str.contains(\"-chunked\")]\n", "df = _rename_compressors(df)" @@ -152,17 +174,11 @@ "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: sperr, Variable: pr\n", "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", @@ -182,17 +198,11 @@ "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: sperr, Variable: pr\n", "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", 
@@ -212,22 +222,23 @@ "No data for Compressor: sperr, Variable: ta\n", "No data for Compressor: sperr, Variable: tos\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", - "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n", "No data for Compressor: safeguarded-sperr, Variable: pr\n", - "No data for Compressor: safeguarded-sperr, Variable: rlut\n" + "No data for Compressor: safeguarded-sperr, Variable: pr\n" ] } ], "source": [ - "metrics = ['DSSIM', 'MAE', 'Max Absolute Error', \"Spectral Error\", \"Compression Ratio [raw B / enc B]\", 'Satisfies Bound (Value)']\n", + "metrics = [\n", + " \"DSSIM\",\n", + " \"MAE\",\n", + " \"Max Absolute Error\",\n", + " \"Spectral Error\",\n", + " \"Compression Ratio [raw B / enc B]\",\n", + " \"Satisfies Bound (Value)\",\n", + "]\n", "scorecard_data = {}\n", "for error_bound in [\"low\", \"mid\", \"high\"]:\n", " scorecard_data[error_bound] = create_data_matrix(df, error_bound, metrics)" @@ -249,11 +260,10 @@ "outputs": [], "source": [ "METRICS2NAME = {\n", - " # \"Max Absolute Error\": \"MaxAE\",\n", + " \"DSSIM\": \"dSSIM\",\n", " \"MAE\": \"Mean Absolute Error\",\n", - " \"Spatial Relative Error (Value)\": \"SRE\",\n", " \"Compression Ratio [raw B / enc B]\": \"Compression Ratio\",\n", - " \"Satisfies Bound (Value)\": r\"% of Pixels Exceeding Error Bound\",\n", + " \"Satisfies Bound (Value)\": r\"% of Pixels Violating Error Bound\",\n", "}\n", "\n", "VARIABLE2NAME = {\n", @@ -266,43 +276,44 @@ " \"era5-hurricane\": \"h-\",\n", "}\n", "\n", + "\n", "def get_variable_label(variable):\n", - " dataset, var_name = variable.split('/')\n", + " dataset, var_name = variable.split(\"/\")\n", " prefix = DATASET2PREFIX.get(dataset, \"\")\n", " var_name = VARIABLE2NAME.get(var_name, var_name)\n", " return f\"{prefix}{var_name}\"\n", "\n", "\n", "def create_compression_scorecard(\n", - " data_matrix, \n", - " compressors, \n", - " variables, \n", - " metrics, \n", + " data_matrix,\n", + " compressors,\n", + " variables,\n", + " metrics,\n", " cbar=True,\n", - " ref_compressor='sz3', \n", + " ref_compressor=\"sz3\",\n", " higher_better_metrics=[\"DSSIM\", \"Compression Ratio [raw B / enc B]\"],\n", " save_fn=None,\n", " compare_against_0=False,\n", - " highlight_bigger_than_one=False\n", + " highlight_bigger_than_one=False,\n", "):\n", " \"\"\"\n", " Create a scorecard plot similar to the weather forecasting example\n", - " \n", + "\n", " Parameters:\n", " - data_matrix: 2D array with compressors as rows, metric-variable combinations as columns\n", " - compressors: list of compressor names\n", - " - variables: list of variable names \n", + " - variables: list of variable names\n", " - metrics: list of metric names\n", " - ref_compressor: reference compressor for relative calculations\n", " - save_fn: filename to save plot (optional)\n", " \"\"\"\n", - " \n", + "\n", " # Calculate relative differences vs reference compressor\n", " ref_idx = compressors.index(ref_compressor)\n", " ref_values = data_matrix[ref_idx, :]\n", " if compare_against_0:\n", 
" ref_values = np.zeros_like(data_matrix[ref_idx, :])\n", - " \n", + "\n", " relative_matrix = np.full_like(data_matrix, np.nan)\n", " if highlight_bigger_than_one:\n", " relative_matrix = np.sign(data_matrix) * 101\n", @@ -319,53 +330,68 @@ " ref_val = 1e-10 # Avoid division by zero\n", " if metrics[j // len(variables)] in higher_better_metrics:\n", " # Higher is better metrics\n", - " relative_matrix[i, j] = (ref_values[j] - data_matrix[i, j]) / ref_val * 100\n", + " relative_matrix[i, j] = (\n", + " (ref_values[j] - data_matrix[i, j]) / ref_val * 100\n", + " )\n", " elif metrics[j // len(variables)] == \"Satisfies Bound (Value)\":\n", " relative_matrix[i, j] = 100 if data_matrix[i, j] != 0 else 0\n", " else:\n", - " relative_matrix[i, j] = (data_matrix[i, j] - ref_values[j]) / ref_val * 100\n", + " relative_matrix[i, j] = (\n", + " (data_matrix[i, j] - ref_values[j]) / ref_val * 100\n", + " )\n", "\n", " # Set up colormap - similar to original\n", - " reds = sns.color_palette('Reds', 6)\n", - " blues = sns.color_palette('Blues_r', 6)\n", + " reds = sns.color_palette(\"Reds\", 6)\n", + " blues = sns.color_palette(\"Blues_r\", 6)\n", " cmap = mpl.colors.ListedColormap(blues + [(0.95, 0.95, 0.95)] + reds)\n", " # cb_levels = [-50, -20, -10, -5, -2, -1, 1, 2, 5, 10, 20, 50]\n", " # cb_levels = [-75, -50, -25, -10, -5, -1, 1, 5, 10, 25, 50, 75]\n", " cb_levels = [-100, -75, -50, -25, -10, -1, 1, 10, 25, 50, 75, 100]\n", "\n", - " norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend='both')\n", - " \n", + " norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend=\"both\")\n", + "\n", " # Calculate figure dimensions\n", " ncompressors = len(compressors)\n", " nvariables = len(variables)\n", " nmetrics = len(metrics)\n", - " \n", + "\n", " panel_width = (2.5 / 5) * nvariables\n", " label_width = 1.5 * panel_width\n", " padding_right = 0.1\n", " panel_height = panel_width / nvariables\n", - " \n", + "\n", " title_height = panel_height * 1.25\n", " cbar_height = panel_height * 2\n", " spacing_height = panel_height * 0.1\n", " spacing_width = panel_height * 0.2\n", - " \n", - " total_width = label_width + nmetrics * panel_width + (nmetrics - 1) * spacing_width + padding_right\n", - " total_height = title_height + cbar_height + ncompressors * panel_height + (ncompressors - 1) * spacing_height\n", - " \n", + "\n", + " total_width = (\n", + " label_width\n", + " + nmetrics * panel_width\n", + " + (nmetrics - 1) * spacing_width\n", + " + padding_right\n", + " )\n", + " total_height = (\n", + " title_height\n", + " + cbar_height\n", + " + ncompressors * panel_height\n", + " + (ncompressors - 1) * spacing_height\n", + " )\n", + "\n", " # Create figure and gridspec\n", " fig = plt.figure(figsize=(total_width, total_height))\n", " gs = mpl.gridspec.GridSpec(\n", - " ncompressors, nmetrics,\n", + " ncompressors,\n", + " nmetrics,\n", " figure=fig,\n", " left=label_width / total_width,\n", " right=1 - padding_right / total_width,\n", " top=1 - (title_height / total_height),\n", " bottom=cbar_height / total_height,\n", " hspace=spacing_height / panel_height,\n", - " wspace=spacing_width / panel_width\n", + " wspace=spacing_width / panel_width,\n", " )\n", - " \n", + "\n", " # Plot each panel\n", " for row, compressor in enumerate(compressors):\n", " for col, metric in enumerate(metrics):\n", @@ -377,32 +403,39 @@ "\n", " rel_values = relative_matrix[row, start_col:end_col].reshape(1, -1)\n", " abs_values = data_matrix[row, start_col:end_col]\n", - " \n", + "\n", " # Create heatmap\n", - " img = 
ax.imshow(rel_values, aspect='auto', cmap=cmap, norm=norm)\n", - " \n", + " img = ax.imshow(rel_values, aspect=\"auto\", cmap=cmap, norm=norm)\n", + "\n", " # Customize axes\n", " ax.set_xticks([])\n", " ax.set_xticklabels([])\n", " ax.set_yticks([])\n", " ax.set_yticklabels([])\n", - " \n", + "\n", " # Add white grid lines\n", " for i in range(nvariables):\n", " rect = mpl.patches.Rectangle(\n", - " (i - 0.5, -0.5), 1, 1,\n", - " linewidth=1, edgecolor='white', facecolor='none'\n", + " (i - 0.5, -0.5),\n", + " 1,\n", + " 1,\n", + " linewidth=1,\n", + " edgecolor=\"white\",\n", + " facecolor=\"none\",\n", " )\n", " ax.add_patch(rect)\n", - " \n", + "\n", " # Add absolute values as text\n", " for i, val in enumerate(abs_values):\n", " # Ensure we don't have black text on dark background\n", " color = \"black\" if abs(rel_values[0, i]) < 75 else \"white\"\n", " fontsize = 10\n", " # Format numbers appropriately\n", - " if metric in [\"DSSIM\", \"Spectral Error\"] and variables[i] in [\"ta\", \"tos\"]:\n", - " # These variables have large regions of NaN values which makes the \n", + " if metric in [\"DSSIM\", \"Spectral Error\"] and variables[i] in [\n", + " \"ta\",\n", + " \"tos\",\n", + " ]:\n", + " # These variables have large regions of NaN values which makes the\n", " # DSSIM and Spectral Error values unreliable.\n", " text = \"N/A\"\n", " color = \"black\"\n", @@ -426,20 +459,20 @@ " else:\n", " text = f\"{val:.2f}\"\n", " ax.text(\n", - " i, \n", - " 0, \n", - " text, \n", - " ha='center', \n", - " va='center', \n", - " fontsize=fontsize, \n", - " color=color\n", + " i, 0, text, ha=\"center\", va=\"center\", fontsize=fontsize, color=color\n", " )\n", "\n", " # Add row labels (compressor names)\n", " if col == 0:\n", - " ax.set_ylabel(_get_legend_name(compressor), rotation=0, ha='right', va='center',\n", - " labelpad=10, fontsize=14)\n", - " \n", + " ax.set_ylabel(\n", + " _get_legend_name(compressor),\n", + " rotation=0,\n", + " ha=\"right\",\n", + " va=\"center\",\n", + " labelpad=10,\n", + " fontsize=14,\n", + " )\n", + "\n", " # Add column titles (variable names)\n", " if row == 0:\n", " # ax.set_title(VARIABLE2NAME.get(variable, variable), fontsize=10, pad=10)\n", @@ -450,37 +483,56 @@ " # ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", " # ax.set_xticks(range(nmetrics))\n", " # ax.set_xticklabels(\n", - " # [METRICS2NAME.get(m, m) for m in metrics], \n", - " # rotation=45, \n", + " # [METRICS2NAME.get(m, m) for m in metrics],\n", + " # rotation=45,\n", " # ha='left', fontsize=8)\n", " ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", " ax.set_xticks(range(nvariables))\n", " ax.set_xticklabels(\n", " [VARIABLE2NAME.get(v, v) for v in variables],\n", " rotation=45,\n", - " ha='left', fontsize=12)\n", - " \n", + " ha=\"left\",\n", + " fontsize=12,\n", + " )\n", + "\n", " # Style spines\n", " for spine in ax.spines.values():\n", - " spine.set_color('0.7')\n", - " \n", + " spine.set_color(\"0.7\")\n", + "\n", " # Add colorbar\n", " if cbar and not highlight_bigger_than_one:\n", " rel_cbar_height = cbar_height / total_height\n", " cax = fig.add_axes((0.4, rel_cbar_height * 0.3, 0.5, rel_cbar_height * 0.2))\n", - " cb = fig.colorbar(img, cax=cax, orientation='horizontal')\n", + " cb = fig.colorbar(img, cax=cax, orientation=\"horizontal\")\n", " cb.ax.set_xticks(cb_levels)\n", " if highlight_bigger_than_one:\n", - " cb.ax.set_xlabel('Better ← |non-chunked - chunked| → Worse', fontsize=16)\n", + " cb.ax.set_xlabel(\"Better ← 
|non-chunked - chunked| → Worse\", fontsize=16)\n", " else:\n", - " cb.ax.set_xlabel(f'Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse', fontsize=16)\n", - " \n", + " cb.ax.set_xlabel(\n", + " f\"Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse\",\n", + " fontsize=16,\n", + " )\n", + "\n", " if highlight_bigger_than_one:\n", " chunking_handles = [\n", - " Line2D([], [], marker=\"s\", color=cmap(101), linestyle=\"None\", markersize=10,\n", - " label=\"Not Chunked Better\"),\n", - " Line2D([], [], marker=\"s\", color=cmap(-101), linestyle=\"None\", markersize=10,\n", - " label=\"Chunked Better\"),\n", + " Line2D(\n", + " [],\n", + " [],\n", + " marker=\"s\",\n", + " color=cmap(101),\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " label=\"Not Chunked Better\",\n", + " ),\n", + " Line2D(\n", + " [],\n", + " [],\n", + " marker=\"s\",\n", + " color=cmap(-101),\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " label=\"Chunked Better\",\n", + " ),\n", " ]\n", "\n", " ax.legend(\n", @@ -488,13 +540,13 @@ " loc=\"upper left\",\n", " ncol=2,\n", " bbox_to_anchor=(-0.5, -0.05),\n", - " fontsize=16\n", + " fontsize=16,\n", " )\n", "\n", " plt.tight_layout()\n", - " \n", + "\n", " if save_fn:\n", - " plt.savefig(save_fn, dpi=300, bbox_inches='tight')\n", + " plt.savefig(save_fn, dpi=300, bbox_inches=\"tight\")\n", " plt.close()\n", " else:\n", " plt.show()" @@ -517,9 +569,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -534,9 +586,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -551,9 +603,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so 
results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] } @@ -563,22 +615,22 @@ " print(f\"Creating scorecard for {bound_name} bound...\")\n", " # Split into two rows for better readability.\n", " create_compression_scorecard(\n", - " data_matrix[:, :3*len(variables)], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, : 3 * len(variables)],\n", + " compressors,\n", + " variables,\n", " metrics[:3],\n", " ref_compressor=\"bitround-pco\",\n", " cbar=False,\n", - " save_fn=f\"figures_updated/{bound_name}_scorecard_row1.pdf\"\n", + " save_fn=f\"figures_updated/{bound_name}_scorecard_row1.pdf\",\n", " )\n", "\n", " create_compression_scorecard(\n", - " data_matrix[:, 3*len(variables):], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, 3 * len(variables) :],\n", + " compressors,\n", + " variables,\n", " metrics[3:],\n", " ref_compressor=\"bitround-pco\",\n", - " save_fn=f\"figures_updated/{bound_name}_scorecard_row2.pdf\"\n", + " save_fn=f\"figures_updated/{bound_name}_scorecard_row2.pdf\",\n", " )" ] }, @@ -607,11 +659,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -626,11 +678,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " 
plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -645,11 +697,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_14073/1415815655.py:244: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] } @@ -660,32 +712,32 @@ " # Split into two rows for better readability.\n", " num_vars = len(variables)\n", " create_compression_scorecard(\n", - " data_matrix[:, :2*num_vars], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, : 2 * num_vars],\n", + " compressors,\n", + " variables,\n", " metrics[:2],\n", " ref_compressor=\"bitround-pco\",\n", " cbar=False,\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row1.pdf\"\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row1.pdf\",\n", " )\n", "\n", " create_compression_scorecard(\n", - " data_matrix[:, 2*num_vars:4*num_vars], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, 2 * num_vars : 4 * num_vars],\n", + " compressors,\n", + " variables,\n", " metrics[2:4],\n", " ref_compressor=\"bitround-pco\",\n", " cbar=False,\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row2.pdf\"\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row2.pdf\",\n", " )\n", "\n", " create_compression_scorecard(\n", - " data_matrix[:, 4*num_vars:], \n", - " compressors, \n", - " variables, \n", + " data_matrix[:, 4 * num_vars :],\n", + " compressors,\n", + " variables,\n", " metrics[4:],\n", " ref_compressor=\"bitround-pco\",\n", - " 
save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row3.pdf\"\n", + " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row3.pdf\",\n", " )" ] }, diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 2e87acf..d15403e 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -73,7 +73,7 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: DISTORTION2LEGEND_NAME = { "Relative MAE": "Mean Absolute Error", - "Relative DSSIM": "DSSIM", + "Relative dSSIM": "dSSIM", "Relative MaxAbsError": "Max Absolute Error", "Spectral Error": "Spectral Error", } @@ -155,7 +155,7 @@ def plot_metrics( for metric in [ "Relative MAE", - "Relative DSSIM", + "Relative dSSIM", "Relative MaxAbsError", "Relative SpectralError", ]: @@ -214,7 +214,7 @@ def _normalize(data): normalize_vars = [ ("Compression Ratio [raw B / enc B]", "Relative CR"), ("MAE", "Relative MAE"), - ("DSSIM", "Relative DSSIM"), + ("DSSIM", "Relative dSSIM"), ("Max Absolute Error", "Relative MaxAbsError"), ("Spectral Error", "Relative SpectralError"), ] @@ -528,12 +528,12 @@ def _plot_aggregated_rd_curve( right=True, ) plt.xlabel( - r"Mean Normalized Compression Ratio ($\uparrow$)", + r"Mean Normalised Compression Ratio ($\uparrow$)", fontsize=16, ) metric_name = DISTORTION2LEGEND_NAME.get(distortion_metric, distortion_metric) plt.ylabel( - rf"Mean Normalized {metric_name} ($\downarrow$)", + rf"Mean Normalised {metric_name} ($\downarrow$)", fontsize=16, ) plt.legend( @@ -545,7 +545,7 @@ def _plot_aggregated_rd_curve( ) arrow_color = "black" - if "DSSIM" in distortion_metric: + if "dSSIM" in distortion_metric: # Add an arrow pointing into the top right corner plt.annotate( "", @@ -573,7 +573,7 @@ def _plot_aggregated_rd_curve( ) # Correct the y-label to point upwards plt.ylabel( - rf"Mean Normalized {metric_name} ($\uparrow$)", + rf"Mean Normalised {metric_name} ($\uparrow$)", fontsize=16, ) else: @@ -601,7 +601,7 @@ def _plot_aggregated_rd_curve( ha="center", ) if ( - "DSSIM" in distortion_metric + "dSSIM" in distortion_metric or "MaxAbsError" in distortion_metric or "SpectralError" in distortion_metric ): @@ -736,13 +736,14 @@ def _plot_grouped_df( ax.set_title(f"{error_bound.capitalize()} Error Bound", fontsize=14) ax.grid(axis="y", linestyle="--", alpha=0.7) if i == 0: - ax.legend(fontsize=14) + ax.legend(fontsize=14, loc="upper left") ax.set_ylabel(ylabel, fontsize=14) + if i == 1: ax.annotate( "Better", - xy=(0.1, 0.8), + xy=(0.1, 0.75), xycoords="axes fraction", - xytext=(0.1, 0.95), + xytext=(0.1, 0.9), textcoords="axes fraction", arrowprops=dict(arrowstyle="->", lw=3, color="black"), fontsize=14, From 77515d9d2f2d3475068b7396cfd4301bc832fdc8 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Thu, 5 Feb 2026 20:37:54 +0200 Subject: [PATCH 06/12] tweak throughput plots --- scorecards.ipynb | 7 ----- .../compressor/plotting/plot_metrics.py | 31 ++++++++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index 8beff56..f8b7b34 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -12,19 +12,12 @@ "import matplotlib.pyplot as plt\n", "import matplotlib as mpl\n", "import seaborn as sns\n", - "from matplotlib.colors import LinearSegmentedColormap\n", - "import matplotlib.patches as mpatches\n", "from matplotlib.lines import Line2D\n", "\n", - "from pathlib import Path\n", "from 
climatebenchpress.compressor.plotting.plot_metrics import (\n", " _rename_compressors,\n", " _get_legend_name,\n", - " _normalize,\n", - " _get_lineinfo,\n", - " DISTORTION2LEGEND_NAME,\n", " _COMPRESSOR_ORDER,\n", - " _savefig,\n", ")" ] }, diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index d15403e..780fbe9 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -614,24 +614,23 @@ def _plot_aggregated_rd_curve( def _plot_throughput(df, outfile: None | Path = None): - # Transform throughput measurements from raw B/s to s/MB for better comparison - # with instruction count measurements. encode_col = "Encode Throughput [raw B / s]" decode_col = "Decode Throughput [raw B / s]" new_df = df[["Compressor", "Error Bound Name", encode_col, decode_col]].copy() - transformed_encode_col = "Encode Throughput [s / MB]" - transformed_decode_col = "Decode Throughput [s / MB]" - new_df[transformed_encode_col] = 1e6 / new_df[encode_col] - new_df[transformed_decode_col] = 1e6 / new_df[decode_col] + transformed_encode_col = "Encode Throughput [MiB / s]" + transformed_decode_col = "Decode Throughput [MiB / s]" + new_df[transformed_encode_col] = new_df[encode_col] / (2**20) + new_df[transformed_decode_col] = new_df[decode_col] / (2**20) encode_col, decode_col = transformed_encode_col, transformed_decode_col grouped_df = _get_median_and_quantiles(new_df, encode_col, decode_col) _plot_grouped_df( grouped_df, title="", - ylabel="Throughput [s / MB]", + ylabel="Throughput [MiB / s]", logy=True, outfile=outfile, + up=True, ) @@ -645,6 +644,7 @@ def _plot_instruction_count(df, outfile: None | Path = None): ylabel="Instructions [# / raw B]", logy=True, outfile=outfile, + up=False, ) @@ -679,7 +679,12 @@ def _get_median_and_quantiles(df, encode_column, decode_column): def _plot_grouped_df( - grouped_df, title, ylabel, outfile: None | Path = None, logy=False + grouped_df, + title, + ylabel, + outfile: None | Path = None, + logy=False, + up=False, ): fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True) @@ -736,16 +741,18 @@ def _plot_grouped_df( ax.set_title(f"{error_bound.capitalize()} Error Bound", fontsize=14) ax.grid(axis="y", linestyle="--", alpha=0.7) if i == 0: - ax.legend(fontsize=14, loc="upper left") + ax.legend( + fontsize=14, loc="lower left" if up else "upper left", framealpha=0.9 + ) ax.set_ylabel(ylabel, fontsize=14) if i == 1: ax.annotate( "Better", - xy=(0.1, 0.75), + xy=(0.51, 0.75), xycoords="axes fraction", - xytext=(0.1, 0.9), + xytext=(0.51, 0.92), textcoords="axes fraction", - arrowprops=dict(arrowstyle="->", lw=3, color="black"), + arrowprops=dict(arrowstyle="<-" if up else "->", lw=3, color="black"), fontsize=14, ha="center", va="bottom", From 6f8fa693061afd1f2a96182d7688dcc5d73f2075 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 06:38:49 +0200 Subject: [PATCH 07/12] Use Crash instead of Fail --- scorecards.ipynb | 125 ++++------------------------------------------- 1 file changed, 9 insertions(+), 116 deletions(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index f8b7b34..b409d06 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -433,7 +433,7 @@ " text = \"N/A\"\n", " color = \"black\"\n", " elif np.isnan(val):\n", - " text = \"Fail\"\n", + " text = \"Crash\"\n", " color = \"black\"\n", " elif abs(val) > 10_000:\n", " text = f\"{val:.1e}\"\n", @@ -562,9 +562,9 @@ "name": 
"stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -579,9 +579,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] }, @@ -596,9 +596,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", + "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", " plt.tight_layout()\n" ] } @@ -614,7 +614,7 @@ " metrics[:3],\n", " ref_compressor=\"bitround-pco\",\n", " cbar=False,\n", - " save_fn=f\"figures_updated/{bound_name}_scorecard_row1.pdf\",\n", + " save_fn=f\"scorecards/{bound_name}_scorecard_row1.pdf\",\n", " )\n", "\n", " create_compression_scorecard(\n", @@ -623,114 +623,7 @@ " variables,\n", " metrics[3:],\n", " ref_compressor=\"bitround-pco\",\n", - " save_fn=f\"figures_updated/{bound_name}_scorecard_row2.pdf\",\n", - " )" - ] - }, - { - "cell_type": "markdown", - "id": "3afb646e", - "metadata": {}, - "source": [ - "## Two-Column Scorecard" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "b6fe5f55", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - 
"Creating scorecard for low bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating scorecard for mid bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating scorecard for high bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_9490/2040868336.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - } - ], - "source": [ - "for bound_name, (data_matrix, compressors, variables) in scorecard_data.items():\n", - " print(f\"Creating scorecard for {bound_name} bound...\")\n", - " # Split into two rows for better readability.\n", - " num_vars = len(variables)\n", - " create_compression_scorecard(\n", - " data_matrix[:, : 2 * num_vars],\n", - " compressors,\n", - " variables,\n", - " metrics[:2],\n", - " ref_compressor=\"bitround-pco\",\n", - " cbar=False,\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row1.pdf\",\n", - " )\n", - "\n", - " create_compression_scorecard(\n", - " data_matrix[:, 2 * num_vars : 4 * num_vars],\n", - " compressors,\n", - " variables,\n", - " metrics[2:4],\n", - " ref_compressor=\"bitround-pco\",\n", - " cbar=False,\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row2.pdf\",\n", - " )\n", - "\n", - " create_compression_scorecard(\n", - " data_matrix[:, 4 * num_vars :],\n", - " compressors,\n", - " variables,\n", - " metrics[4:],\n", - " 
ref_compressor=\"bitround-pco\",\n", - " save_fn=f\"scorecards_2cols/{bound_name}_scorecard_row3.pdf\",\n", + " save_fn=f\"scorecards/{bound_name}_scorecard_row2.pdf\",\n", " )" ] }, From 4a26f5c03f13967cba32a5fb5082b642e9a0e6d7 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 09:02:53 +0200 Subject: [PATCH 08/12] add crash->crash arrows --- scorecards.ipynb | 77 ++++++++++++++++++------------------------------ 1 file changed, 29 insertions(+), 48 deletions(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index b409d06..d787693 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -407,13 +407,15 @@ " ax.set_yticklabels([])\n", "\n", " # Add white grid lines\n", - " for i in range(nvariables):\n", + " for i in range(1, nvariables):\n", " rect = mpl.patches.Rectangle(\n", " (i - 0.5, -0.5),\n", " 1,\n", " 1,\n", " linewidth=1,\n", - " edgecolor=\"white\",\n", + " edgecolor=\"lightgrey\"\n", + " if np.isnan(abs_values[i]) and np.isnan(abs_values[i - 1])\n", + " else \"white\",\n", " facecolor=\"none\",\n", " )\n", " ax.add_patch(rect)\n", @@ -455,6 +457,27 @@ " i, 0, text, ha=\"center\", va=\"center\", fontsize=fontsize, color=color\n", " )\n", "\n", + " if (\n", + " row > 0\n", + " and np.isnan(val)\n", + " and np.isnan(data_matrix[row - 1, col * nvariables + i])\n", + " and compressor == f\"safeguarded-{compressors[row - 1]}\"\n", + " and not (\n", + " metric in [\"DSSIM\", \"Spectral Error\"]\n", + " and variables[i]\n", + " in [\n", + " \"ta\",\n", + " \"tos\",\n", + " ]\n", + " )\n", + " ):\n", + " ax.annotate(\n", + " \"\",\n", + " xy=(i, -0.15),\n", + " xytext=(i, -0.9),\n", + " arrowprops=dict(arrowstyle=\"->\", lw=2, color=\"lightgrey\"),\n", + " )\n", + "\n", " # Add row labels (compressor names)\n", " if col == 0:\n", " ax.set_ylabel(\n", @@ -536,7 +559,7 @@ " fontsize=16,\n", " )\n", "\n", - " plt.tight_layout()\n", + " # plt.tight_layout()\n", "\n", " if save_fn:\n", " plt.savefig(save_fn, dpi=300, bbox_inches=\"tight\")\n", @@ -555,52 +578,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "Creating scorecard for low bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Creating scorecard for mid bound...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ + "Creating scorecard for low bound...\n", + "Creating scorecard for mid bound...\n", "Creating scorecard for high bound...\n" ] - }, - { - "name": "stderr", - "output_type": "stream", - 
"text": [ - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n", - "/var/folders/8v/swxsmn0d4vz5yzwjhf6bc26x3g7lq6/T/ipykernel_52131/2725215813.py:285: UserWarning: This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.\n", - " plt.tight_layout()\n" - ] } ], "source": [ @@ -630,7 +611,7 @@ { "cell_type": "code", "execution_count": null, - "id": "c546f6b0-bb7b-4646-83bc-3f502bcf6a9f", + "id": "c2d8dea1-cd87-48d1-9d5b-8fe106183cbf", "metadata": {}, "outputs": [], "source": [] From 1b5ffa4953149838fe896b2ce7323c098055cea2 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 09:07:02 +0200 Subject: [PATCH 09/12] improve violation metric title --- scorecards.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scorecards.ipynb b/scorecards.ipynb index d787693..b95a393 100644 --- a/scorecards.ipynb +++ b/scorecards.ipynb @@ -256,7 +256,7 @@ " \"DSSIM\": \"dSSIM\",\n", " \"MAE\": \"Mean Absolute Error\",\n", " \"Compression Ratio [raw B / enc B]\": \"Compression Ratio\",\n", - " \"Satisfies Bound (Value)\": r\"% of Pixels Violating Error Bound\",\n", + " \"Satisfies Bound (Value)\": r\"% of Data Points Violating the Error Bound\",\n", "}\n", "\n", "VARIABLE2NAME = {\n", From 5c6aff54cf26608df6f281e043622d0325b1dfa5 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 09:16:24 +0200 Subject: [PATCH 10/12] highlight safeguards in throughput --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 780fbe9..3c5d8ca 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -716,7 +716,12 @@ def _plot_grouped_df( bound_data["encode_upper_quantile"], ], label="Encoding", + edgecolor="white", + linewidth=0, color=[_get_lineinfo(comp)[0] for comp in compressors], + hatch=[ + "O" if comp.startswith("safeguarded-") else "" for comp in compressors + ], ) # Plot decode throughput @@ -732,6 +737,9 @@ def _plot_grouped_df( edgecolor=[_get_lineinfo(comp)[0] for comp in compressors], fill=False, linewidth=4, + hatch=[ + "O" if comp.startswith("safeguarded-") else "" for comp in compressors + ], ) # Add labels and title From fe9e313581eecf66974297a1c0aeeef0cb606b57 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Fri, 6 Feb 2026 10:23:36 +0200 Subject: [PATCH 11/12] change throughout legend labels --- src/climatebenchpress/compressor/plotting/plot_metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 3c5d8ca..fc7cd7f 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -715,7 +715,7 @@ def _plot_grouped_df( bound_data["encode_lower_quantile"], bound_data["encode_upper_quantile"], ], - label="Encoding", + label="Compression", edgecolor="white", linewidth=0, color=[_get_lineinfo(comp)[0] for comp in compressors], @@ -733,7 +733,7 @@ def _plot_grouped_df( bound_data["decode_lower_quantile"], bound_data["decode_upper_quantile"], ], - label="Decoding", + 
label="Decompression", edgecolor=[_get_lineinfo(comp)[0] for comp in compressors], fill=False, linewidth=4, From 3ff315da2d9fe7331c04896444b00b291ce2a306 Mon Sep 17 00:00:00 2001 From: Juniper Tyree Date: Tue, 10 Feb 2026 21:58:01 +0200 Subject: [PATCH 12/12] Improve dSSIM rd curves --- .../compressor/plotting/plot_metrics.py | 31 ++++++++++++++++--- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index fc7cd7f..1b5b1bb 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -146,7 +146,7 @@ def plot_metrics( # ) df = _rename_compressors(df) - normalized_df = _normalize(df) + normalized_df, normalized_mean_std = _normalize(df) _plot_bound_violations( normalized_df, bound_names, plots_path / "bound_violations.pdf" ) @@ -164,6 +164,7 @@ def plot_metrics( normalized_df, compression_metric="Relative CR", distortion_metric=metric, + mean_std=normalized_mean_std[metric], outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="mean", bound_names=bound_names, @@ -173,6 +174,7 @@ def plot_metrics( normalized_df, compression_metric="Relative CR", distortion_metric=metric, + mean_std=normalized_mean_std[metric], outfile=plots_path / f"full_rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="mean", @@ -224,6 +226,7 @@ def _normalize(data): dssim_unreliable = normalized["Variable"].isin(["ta", "tos"]) normalized.loc[dssim_unreliable, "DSSIM"] = np.nan + normalize_mean_std = dict() for col, new_col in normalize_vars: mean_std = dict() for var in variables: @@ -239,7 +242,9 @@ def _normalize(data): axis=1, ) - return normalized + normalize_mean_std[new_col] = mean_std + + return normalized, normalize_mean_std def _plot_per_variable_metrics( @@ -434,6 +439,7 @@ def _plot_aggregated_rd_curve( normalized_df, compression_metric, distortion_metric, + mean_std, outfile: None | Path = None, agg="median", bound_names=["low", "mid", "high"], @@ -546,10 +552,27 @@ def _plot_aggregated_rd_curve( arrow_color = "black" if "dSSIM" in distortion_metric: + # Annotate dSSIM = 1, accounting for the normalization + dssim_one = getattr(np, f"nan{agg}")( + [(1 - ms[0]) / ms[1] for ms in mean_std.values()] + ) + plt.axhline(dssim_one, c="k", ls="--") + plt.text( + np.percentile(plt.xlim(), 63), + dssim_one, + "dSSIM = 1", + fontsize=16, + fontweight="bold", + color="black", + ha="center", + va="center", + bbox=dict(edgecolor="none", facecolor="w", alpha=0.85), + ) + # Add an arrow pointing into the top right corner plt.annotate( "", - xy=(0.95, 0.95), + xy=(0.95, 0.875 if remove_outliers else 0.9), xycoords="axes fraction", xytext=(-60, -50), textcoords="offset points", @@ -562,7 +585,7 @@ def _plot_aggregated_rd_curve( # Attach the text to the lower left of the arrow plt.text( 0.83, - 0.92, + 0.845 if remove_outliers else 0.87, "Better", transform=plt.gca().transAxes, fontsize=16,