diff --git a/scorecards.ipynb b/scorecards.ipynb new file mode 100644 index 0000000..b95a393 --- /dev/null +++ b/scorecards.ipynb @@ -0,0 +1,641 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e7ab252b", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib as mpl\n", + "import seaborn as sns\n", + "from matplotlib.lines import Line2D\n", + "\n", + "from climatebenchpress.compressor.plotting.plot_metrics import (\n", + " _rename_compressors,\n", + " _get_legend_name,\n", + " _COMPRESSOR_ORDER,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4ee15a6b", + "metadata": {}, + "source": [ + "# Process results" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "2d242e7c", + "metadata": {}, + "outputs": [], + "source": [ + "results_file = \"metrics/all_results.csv\"\n", + "df = pd.read_csv(results_file)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ec124c29", + "metadata": {}, + "outputs": [], + "source": [ + "def create_data_matrix(\n", + " df: pd.DataFrame,\n", + " error_bound: str,\n", + " metrics: list[str] = [\n", + " \"DSSIM\",\n", + " \"MAE\",\n", + " \"Max Absolute Error\",\n", + " \"Spectral Error\",\n", + " \"Compression Ratio [raw B / enc B]\",\n", + " \"Satisfies Bound (Value)\",\n", + " ],\n", + "):\n", + " df_filtered = df[df[\"Error Bound Name\"] == error_bound].copy()\n", + " df_filtered[\"Satisfies Bound (Value)\"] = (\n", + " df_filtered[\"Satisfies Bound (Value)\"] * 100\n", + " ) # Convert to percentage\n", + "\n", + " # Get unique variables and compressors\n", + " # dataset_variables = sorted(df_filtered[['Dataset', 'Variable']].drop_duplicates().apply(lambda x: \"/\".join(x), axis=1).unique())\n", + " dataset_variables = sorted(df_filtered[\"Variable\"].unique())\n", + " compressors = sorted(\n", + " df_filtered[\"Compressor\"].unique(),\n", + " key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)),\n", + " )\n", + "\n", + " column_labels = []\n", + " for metric in metrics:\n", + " for dataset_variable in dataset_variables:\n", + " column_labels.append(f\"{dataset_variable}\\n{metric}\")\n", + "\n", + " # Initialize the data matrix\n", + " data_matrix = np.full((len(compressors), len(column_labels)), np.nan)\n", + "\n", + " # Fill the matrix with data\n", + " for i, compressor in enumerate(compressors):\n", + " for j, metric in enumerate(metrics):\n", + " for k, dataset_variable in enumerate(dataset_variables):\n", + " # Get data for this compressor-variable combination\n", + " # dataset, variable = dataset_variable.split('/')\n", + " variable = dataset_variable\n", + " subset = df_filtered[\n", + " (df_filtered[\"Compressor\"] == compressor)\n", + " & (df_filtered[\"Variable\"] == variable) # &\n", + " # (df_filtered['Dataset'] == dataset)\n", + " ]\n", + " if subset.empty:\n", + " print(f\"No data for Compressor: {compressor}, Variable: {variable}\")\n", + " continue\n", + "\n", + " if metric in [\"DSSIM\", \"Spectral Error\"] and variable in [\"ta\", \"tos\"]:\n", + " # These variables have large regions of NaN values which makes the\n", + " # DSSIM and Spectral Error values unreliable.\n", + " continue\n", + "\n", + " col_idx = j * len(dataset_variables) + k\n", + " if metric in subset.columns:\n", + " values = subset[metric]\n", + " if len(values) == 1:\n", + " data_matrix[i, col_idx] = values.iloc[0]\n", + "\n", + " return data_matrix, compressors, dataset_variables" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2a399d30", + "metadata": {}, + "outputs": [], + "source": [ + "df = df[\n", + " ~df[\"Compressor\"].isin(\n", + " [\n", + " \"bitround\",\n", + " \"jpeg2000-conservative-abs\",\n", + " \"stochround-conservative-abs\",\n", + " \"stochround-pco-conservative-abs\",\n", + " \"zfp-conservative-abs\",\n", + " \"bitround-conservative-rel\",\n", + " \"stochround-pco\",\n", + " \"stochround\",\n", + " \"zfp\",\n", + " \"jpeg2000\",\n", + " ]\n", + " )\n", + "]\n", + "df = df[~df[\"Dataset\"].str.contains(\"-tiny\")]\n", + "df = df[~df[\"Dataset\"].str.contains(\"-chunked\")]\n", + "df = _rename_compressors(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2d019b8d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: sperr, Variable: pr\n", + "No data for Compressor: sperr, Variable: ta\n", + "No data for Compressor: sperr, Variable: tos\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n", + "No data for Compressor: safeguarded-sperr, Variable: pr\n" + ] + } + ], + "source": [ + "metrics = [\n", + " \"DSSIM\",\n", + " \"MAE\",\n", + " \"Max Absolute Error\",\n", + " \"Spectral Error\",\n", + " \"Compression Ratio [raw B / enc B]\",\n", + " \"Satisfies Bound (Value)\",\n", + "]\n", + "scorecard_data = {}\n", + "for error_bound in [\"low\", \"mid\", \"high\"]:\n", + " scorecard_data[error_bound] = create_data_matrix(df, error_bound, metrics)" + ] + }, + { + "cell_type": "markdown", + "id": "ae80d757", + "metadata": {}, + "source": [ + "# Scorecard" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "871ae766", + "metadata": {}, + "outputs": [], + "source": [ + "METRICS2NAME = {\n", + " \"DSSIM\": \"dSSIM\",\n", + " \"MAE\": \"Mean Absolute Error\",\n", + " \"Compression Ratio [raw B / enc B]\": \"Compression Ratio\",\n", + " \"Satisfies Bound (Value)\": r\"% of Data Points Violating the Error Bound\",\n", + "}\n", + "\n", + "VARIABLE2NAME = {\n", + " \"10m_u_component_of_wind\": \"10u\",\n", + " \"10m_v_component_of_wind\": \"10v\",\n", + " \"mean_sea_level_pressure\": \"msl\",\n", + "}\n", + "\n", + "DATASET2PREFIX = {\n", + " \"era5-hurricane\": \"h-\",\n", + "}\n", + "\n", + "\n", + "def get_variable_label(variable):\n", + " dataset, var_name = variable.split(\"/\")\n", + " prefix = DATASET2PREFIX.get(dataset, \"\")\n", + " var_name = VARIABLE2NAME.get(var_name, var_name)\n", + " return f\"{prefix}{var_name}\"\n", + "\n", + "\n", + "def create_compression_scorecard(\n", + " data_matrix,\n", + " compressors,\n", + " variables,\n", + " metrics,\n", + " cbar=True,\n", + " ref_compressor=\"sz3\",\n", + " higher_better_metrics=[\"DSSIM\", \"Compression Ratio [raw B / enc B]\"],\n", + " save_fn=None,\n", + " compare_against_0=False,\n", + " highlight_bigger_than_one=False,\n", + "):\n", + " \"\"\"\n", + " Create a scorecard plot similar to the weather forecasting example\n", + "\n", + " Parameters:\n", + " - data_matrix: 2D array with compressors as rows, metric-variable combinations as columns\n", + " - compressors: list of compressor names\n", + " - variables: list of variable names\n", + " - metrics: list of metric names\n", + " - ref_compressor: reference compressor for relative calculations\n", + " - save_fn: filename to save plot (optional)\n", + " \"\"\"\n", + "\n", + " # Calculate relative differences vs reference compressor\n", + " ref_idx = compressors.index(ref_compressor)\n", + " ref_values = data_matrix[ref_idx, :]\n", + " if compare_against_0:\n", + " ref_values = np.zeros_like(data_matrix[ref_idx, :])\n", + "\n", + " relative_matrix = np.full_like(data_matrix, np.nan)\n", + " if highlight_bigger_than_one:\n", + " relative_matrix = np.sign(data_matrix) * 101\n", + " for j in range(data_matrix.shape[1]):\n", + " if metrics[j // len(variables)] == \"Satisfies Bound (Value)\":\n", + " # For bound satisfication lower is better (less number of pixels exceeding error bound).\n", + " relative_matrix[:, j] = -1 * relative_matrix[:, j]\n", + " else:\n", + " for i in range(len(compressors)):\n", + " for j in range(data_matrix.shape[1]):\n", + " if not np.isnan(data_matrix[i, j]) and not np.isnan(ref_values[j]):\n", + " ref_val = np.abs(ref_values[j])\n", + " if ref_val == 0.0:\n", + " ref_val = 1e-10 # Avoid division by zero\n", + " if metrics[j // len(variables)] in higher_better_metrics:\n", + " # Higher is better metrics\n", + " relative_matrix[i, j] = (\n", + " (ref_values[j] - data_matrix[i, j]) / ref_val * 100\n", + " )\n", + " elif metrics[j // len(variables)] == \"Satisfies Bound (Value)\":\n", + " relative_matrix[i, j] = 100 if data_matrix[i, j] != 0 else 0\n", + " else:\n", + " relative_matrix[i, j] = (\n", + " (data_matrix[i, j] - ref_values[j]) / ref_val * 100\n", + " )\n", + "\n", + " # Set up colormap - similar to original\n", + " reds = sns.color_palette(\"Reds\", 6)\n", + " blues = sns.color_palette(\"Blues_r\", 6)\n", + " cmap = mpl.colors.ListedColormap(blues + [(0.95, 0.95, 0.95)] + reds)\n", + " # cb_levels = [-50, -20, -10, -5, -2, -1, 1, 2, 5, 10, 20, 50]\n", + " # cb_levels = [-75, -50, -25, -10, -5, -1, 1, 5, 10, 25, 50, 75]\n", + " cb_levels = [-100, -75, -50, -25, -10, -1, 1, 10, 25, 50, 75, 100]\n", + "\n", + " norm = mpl.colors.BoundaryNorm(cb_levels, cmap.N, extend=\"both\")\n", + "\n", + " # Calculate figure dimensions\n", + " ncompressors = len(compressors)\n", + " nvariables = len(variables)\n", + " nmetrics = len(metrics)\n", + "\n", + " panel_width = (2.5 / 5) * nvariables\n", + " label_width = 1.5 * panel_width\n", + " padding_right = 0.1\n", + " panel_height = panel_width / nvariables\n", + "\n", + " title_height = panel_height * 1.25\n", + " cbar_height = panel_height * 2\n", + " spacing_height = panel_height * 0.1\n", + " spacing_width = panel_height * 0.2\n", + "\n", + " total_width = (\n", + " label_width\n", + " + nmetrics * panel_width\n", + " + (nmetrics - 1) * spacing_width\n", + " + padding_right\n", + " )\n", + " total_height = (\n", + " title_height\n", + " + cbar_height\n", + " + ncompressors * panel_height\n", + " + (ncompressors - 1) * spacing_height\n", + " )\n", + "\n", + " # Create figure and gridspec\n", + " fig = plt.figure(figsize=(total_width, total_height))\n", + " gs = mpl.gridspec.GridSpec(\n", + " ncompressors,\n", + " nmetrics,\n", + " figure=fig,\n", + " left=label_width / total_width,\n", + " right=1 - padding_right / total_width,\n", + " top=1 - (title_height / total_height),\n", + " bottom=cbar_height / total_height,\n", + " hspace=spacing_height / panel_height,\n", + " wspace=spacing_width / panel_width,\n", + " )\n", + "\n", + " # Plot each panel\n", + " for row, compressor in enumerate(compressors):\n", + " for col, metric in enumerate(metrics):\n", + " ax = fig.add_subplot(gs[row, col])\n", + "\n", + " # Get data for this metric (all variables)\n", + " start_col = col * nvariables\n", + " end_col = start_col + nvariables\n", + "\n", + " rel_values = relative_matrix[row, start_col:end_col].reshape(1, -1)\n", + " abs_values = data_matrix[row, start_col:end_col]\n", + "\n", + " # Create heatmap\n", + " img = ax.imshow(rel_values, aspect=\"auto\", cmap=cmap, norm=norm)\n", + "\n", + " # Customize axes\n", + " ax.set_xticks([])\n", + " ax.set_xticklabels([])\n", + " ax.set_yticks([])\n", + " ax.set_yticklabels([])\n", + "\n", + " # Add white grid lines\n", + " for i in range(1, nvariables):\n", + " rect = mpl.patches.Rectangle(\n", + " (i - 0.5, -0.5),\n", + " 1,\n", + " 1,\n", + " linewidth=1,\n", + " edgecolor=\"lightgrey\"\n", + " if np.isnan(abs_values[i]) and np.isnan(abs_values[i - 1])\n", + " else \"white\",\n", + " facecolor=\"none\",\n", + " )\n", + " ax.add_patch(rect)\n", + "\n", + " # Add absolute values as text\n", + " for i, val in enumerate(abs_values):\n", + " # Ensure we don't have black text on dark background\n", + " color = \"black\" if abs(rel_values[0, i]) < 75 else \"white\"\n", + " fontsize = 10\n", + " # Format numbers appropriately\n", + " if metric in [\"DSSIM\", \"Spectral Error\"] and variables[i] in [\n", + " \"ta\",\n", + " \"tos\",\n", + " ]:\n", + " # These variables have large regions of NaN values which makes the\n", + " # DSSIM and Spectral Error values unreliable.\n", + " text = \"N/A\"\n", + " color = \"black\"\n", + " elif np.isnan(val):\n", + " text = \"Crash\"\n", + " color = \"black\"\n", + " elif abs(val) > 10_000:\n", + " text = f\"{val:.1e}\"\n", + " fontsize = 8\n", + " elif abs(val) > 10:\n", + " text = f\"{val:.0f}\"\n", + " elif abs(val) > 1:\n", + " text = f\"{val:.1f}\"\n", + " elif val == 1 and metric == \"DSSIM\":\n", + " text = \"1\"\n", + " elif val == 0:\n", + " text = \"0\"\n", + " elif abs(val) < 0.01:\n", + " text = f\"{val:.1e}\"\n", + " fontsize = 8\n", + " else:\n", + " text = f\"{val:.2f}\"\n", + " ax.text(\n", + " i, 0, text, ha=\"center\", va=\"center\", fontsize=fontsize, color=color\n", + " )\n", + "\n", + " if (\n", + " row > 0\n", + " and np.isnan(val)\n", + " and np.isnan(data_matrix[row - 1, col * nvariables + i])\n", + " and compressor == f\"safeguarded-{compressors[row - 1]}\"\n", + " and not (\n", + " metric in [\"DSSIM\", \"Spectral Error\"]\n", + " and variables[i]\n", + " in [\n", + " \"ta\",\n", + " \"tos\",\n", + " ]\n", + " )\n", + " ):\n", + " ax.annotate(\n", + " \"\",\n", + " xy=(i, -0.15),\n", + " xytext=(i, -0.9),\n", + " arrowprops=dict(arrowstyle=\"->\", lw=2, color=\"lightgrey\"),\n", + " )\n", + "\n", + " # Add row labels (compressor names)\n", + " if col == 0:\n", + " ax.set_ylabel(\n", + " _get_legend_name(compressor),\n", + " rotation=0,\n", + " ha=\"right\",\n", + " va=\"center\",\n", + " labelpad=10,\n", + " fontsize=14,\n", + " )\n", + "\n", + " # Add column titles (variable names)\n", + " if row == 0:\n", + " # ax.set_title(VARIABLE2NAME.get(variable, variable), fontsize=10, pad=10)\n", + " ax.set_title(METRICS2NAME.get(metric, metric), fontsize=16, pad=10)\n", + "\n", + " # Add metric labels at the top on the top row\n", + " if row == 0:\n", + " # ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", + " # ax.set_xticks(range(nmetrics))\n", + " # ax.set_xticklabels(\n", + " # [METRICS2NAME.get(m, m) for m in metrics],\n", + " # rotation=45,\n", + " # ha='left', fontsize=8)\n", + " ax.tick_params(top=True, labeltop=True, bottom=False, labelbottom=False)\n", + " ax.set_xticks(range(nvariables))\n", + " ax.set_xticklabels(\n", + " [VARIABLE2NAME.get(v, v) for v in variables],\n", + " rotation=45,\n", + " ha=\"left\",\n", + " fontsize=12,\n", + " )\n", + "\n", + " # Style spines\n", + " for spine in ax.spines.values():\n", + " spine.set_color(\"0.7\")\n", + "\n", + " # Add colorbar\n", + " if cbar and not highlight_bigger_than_one:\n", + " rel_cbar_height = cbar_height / total_height\n", + " cax = fig.add_axes((0.4, rel_cbar_height * 0.3, 0.5, rel_cbar_height * 0.2))\n", + " cb = fig.colorbar(img, cax=cax, orientation=\"horizontal\")\n", + " cb.ax.set_xticks(cb_levels)\n", + " if highlight_bigger_than_one:\n", + " cb.ax.set_xlabel(\"Better ← |non-chunked - chunked| → Worse\", fontsize=16)\n", + " else:\n", + " cb.ax.set_xlabel(\n", + " f\"Better ← % difference vs {_get_legend_name(ref_compressor)} → Worse\",\n", + " fontsize=16,\n", + " )\n", + "\n", + " if highlight_bigger_than_one:\n", + " chunking_handles = [\n", + " Line2D(\n", + " [],\n", + " [],\n", + " marker=\"s\",\n", + " color=cmap(101),\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " label=\"Not Chunked Better\",\n", + " ),\n", + " Line2D(\n", + " [],\n", + " [],\n", + " marker=\"s\",\n", + " color=cmap(-101),\n", + " linestyle=\"None\",\n", + " markersize=10,\n", + " label=\"Chunked Better\",\n", + " ),\n", + " ]\n", + "\n", + " ax.legend(\n", + " handles=chunking_handles,\n", + " loc=\"upper left\",\n", + " ncol=2,\n", + " bbox_to_anchor=(-0.5, -0.05),\n", + " fontsize=16,\n", + " )\n", + "\n", + " # plt.tight_layout()\n", + "\n", + " if save_fn:\n", + " plt.savefig(save_fn, dpi=300, bbox_inches=\"tight\")\n", + " plt.close()\n", + " else:\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "678c927b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating scorecard for low bound...\n", + "Creating scorecard for mid bound...\n", + "Creating scorecard for high bound...\n" + ] + } + ], + "source": [ + "for bound_name, (data_matrix, compressors, variables) in scorecard_data.items():\n", + " print(f\"Creating scorecard for {bound_name} bound...\")\n", + " # Split into two rows for better readability.\n", + " create_compression_scorecard(\n", + " data_matrix[:, : 3 * len(variables)],\n", + " compressors,\n", + " variables,\n", + " metrics[:3],\n", + " ref_compressor=\"bitround-pco\",\n", + " cbar=False,\n", + " save_fn=f\"scorecards/{bound_name}_scorecard_row1.pdf\",\n", + " )\n", + "\n", + " create_compression_scorecard(\n", + " data_matrix[:, 3 * len(variables) :],\n", + " compressors,\n", + " variables,\n", + " metrics[3:],\n", + " ref_compressor=\"bitround-pco\",\n", + " save_fn=f\"scorecards/{bound_name}_scorecard_row2.pdf\",\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2d8dea1-cd87-48d1-9d5b-8fe106183cbf", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/climatebenchpress/compressor/plotting/plot_metrics.py b/src/climatebenchpress/compressor/plotting/plot_metrics.py index 408b098..1b5b1bb 100644 --- a/src/climatebenchpress/compressor/plotting/plot_metrics.py +++ b/src/climatebenchpress/compressor/plotting/plot_metrics.py @@ -13,15 +13,21 @@ _COMPRESSOR2LINEINFO = [ ("jpeg2000", ("#EE7733", "-")), - ("sperr", ("#117733", ":")), - ("zfp-round", ("#DDAA33", "--")), + ("sperr", ("#117733", "-")), + ("zfp-round", ("#DDAA33", "-")), ("zfp", ("#EE3377", "--")), - ("sz3", ("#CC3311", "-.")), - ("bitround-pco", ("#0077BB", ":")), + ("sz3", ("#CC3311", "-")), + ("bitround-pco", ("#0077BB", "-")), ("bitround", ("#33BBEE", "-")), ("stochround-pco", ("#BBBBBB", "--")), ("stochround", ("#009988", "--")), ("tthresh", ("#882255", "-.")), + ("safeguarded-sperr", ("#117733", ":")), + ("safeguarded-zfp-round", ("#DDAA33", ":")), + ("safeguarded-sz3", ("#CC3311", ":")), + ("safeguarded-zero-dssim", ("#9467BD", "--")), + ("safeguarded-zero", ("#9467BD", ":")), + ("safeguarded-bitround-pco", ("#0077BB", ":")), ] @@ -36,19 +42,38 @@ def _get_lineinfo(compressor: str) -> tuple[str, str]: _COMPRESSOR2LEGEND_NAME = [ ("jpeg2000", "JPEG2000"), ("sperr", "SPERR"), - ("zfp-round", "ZFP-ROUND"), + ("zfp-round", "ZFP"), ("zfp", "ZFP"), - ("sz3", "SZ3"), - ("bitround-pco", "BitRound + PCO"), + ("sz3", "SZ3[v3.2]"), + ("bitround-pco", "BitRound"), ("bitround", "BitRound + Zstd"), ("stochround-pco", "StochRound + PCO"), ("stochround", "StochRound + Zstd"), ("tthresh", "TTHRESH"), + ("safeguarded-sperr", "Safeguarded(SPERR)"), + ("safeguarded-zfp-round", "Safeguarded(ZFP)"), + ("safeguarded-sz3", "Safeguarded(SZ3[v3.2])"), + ("safeguarded-zero-dssim", "Safeguarded(0, dSSIM)"), + ("safeguarded-zero", "Safeguarded(0)"), + ("safeguarded-bitround-pco", "Safeguarded(BitRound)"), +] + +_COMPRESSOR_ORDER = [ + "BitRound", + "Safeguarded(BitRound)", + "ZFP", + "Safeguarded(ZFP)", + "SZ3[v3.2]", + "Safeguarded(SZ3[v3.2])", + "SPERR", + "Safeguarded(SPERR)", + "Safeguarded(0)", + "Safeguarded(0, dSSIM)", ] DISTORTION2LEGEND_NAME = { "Relative MAE": "Mean Absolute Error", - "Relative DSSIM": "DSSIM", + "Relative dSSIM": "dSSIM", "Relative MaxAbsError": "Max Absolute Error", "Spectral Error": "Spectral Error", } @@ -102,6 +127,7 @@ def plot_metrics( df = pd.read_csv(metrics_path / "all_results.csv") # Filter out excluded datasets and compressors + # bitround jpeg2000-conservative-abs stochround-conservative-abs stochround-pco-conservative-abs zfp-conservative-abs bitround-conservative-rel stochround-pco stochround zfp jpeg2000 df = df[~df["Compressor"].isin(exclude_compressor)] df = df[~df["Dataset"].isin(exclude_dataset)] is_tiny = df["Dataset"].str.endswith("-tiny") @@ -111,16 +137,16 @@ def plot_metrics( filter_chunked = is_chunked if chunked_datasets else ~is_chunked df = df[filter_chunked] - _plot_per_variable_metrics( - datasets=datasets, - compressed_datasets=compressed_datasets, - plots_path=plots_path, - all_results=df, - rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], - ) + # _plot_per_variable_metrics( + # datasets=datasets, + # compressed_datasets=compressed_datasets, + # plots_path=plots_path, + # all_results=df, + # rd_curves_metrics=["Max Absolute Error", "MAE", "DSSIM", "Spectral Error"], + # ) df = _rename_compressors(df) - normalized_df = _normalize(df) + normalized_df, normalized_mean_std = _normalize(df) _plot_bound_violations( normalized_df, bound_names, plots_path / "bound_violations.pdf" ) @@ -129,7 +155,7 @@ def plot_metrics( for metric in [ "Relative MAE", - "Relative DSSIM", + "Relative dSSIM", "Relative MaxAbsError", "Relative SpectralError", ]: @@ -138,6 +164,7 @@ def plot_metrics( normalized_df, compression_metric="Relative CR", distortion_metric=metric, + mean_std=normalized_mean_std[metric], outfile=plots_path / f"rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="mean", bound_names=bound_names, @@ -147,6 +174,7 @@ def plot_metrics( normalized_df, compression_metric="Relative CR", distortion_metric=metric, + mean_std=normalized_mean_std[metric], outfile=plots_path / f"full_rd_curve_{metric.lower().replace(' ', '_')}.pdf", agg="mean", @@ -188,7 +216,7 @@ def _normalize(data): normalize_vars = [ ("Compression Ratio [raw B / enc B]", "Relative CR"), ("MAE", "Relative MAE"), - ("DSSIM", "Relative DSSIM"), + ("DSSIM", "Relative dSSIM"), ("Max Absolute Error", "Relative MaxAbsError"), ("Spectral Error", "Relative SpectralError"), ] @@ -198,6 +226,7 @@ def _normalize(data): dssim_unreliable = normalized["Variable"].isin(["ta", "tos"]) normalized.loc[dssim_unreliable, "DSSIM"] = np.nan + normalize_mean_std = dict() for col, new_col in normalize_vars: mean_std = dict() for var in variables: @@ -213,7 +242,9 @@ def _normalize(data): axis=1, ) - return normalized + normalize_mean_std[new_col] = mean_std + + return normalized, normalize_mean_std def _plot_per_variable_metrics( @@ -408,6 +439,7 @@ def _plot_aggregated_rd_curve( normalized_df, compression_metric, distortion_metric, + mean_std, outfile: None | Path = None, agg="median", bound_names=["low", "mid", "high"], @@ -419,7 +451,10 @@ def _plot_aggregated_rd_curve( # Exclude variables that are not relevant for the distortion metric. normalized_df = normalized_df[~normalized_df["Variable"].isin(exclude_vars)] - compressors = normalized_df["Compressor"].unique() + compressors = sorted( + normalized_df["Compressor"].unique(), + key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)), + ) agg_distortion = normalized_df.groupby(["Error Bound Name", "Compressor"])[ [compression_metric, distortion_metric] ].agg(agg) @@ -447,7 +482,13 @@ def _plot_aggregated_rd_curve( if remove_outliers: # SZ3 and JPEG2000 often give outlier values and violate the bounds. - exclude_compressors = ["sz3", "jpeg2000"] + exclude_compressors = [ + "sz3", + "jpeg2000", + "safeguarded-zero-dssim", + "safeguarded-zero", + "safeguarded-sz3", + ] filtered_agg = agg_distortion[ ~agg_distortion.index.get_level_values("Compressor").isin( exclude_compressors @@ -493,28 +534,45 @@ def _plot_aggregated_rd_curve( right=True, ) plt.xlabel( - r"Mean Normalized Compression Ratio ($\uparrow$)", + r"Mean Normalised Compression Ratio ($\uparrow$)", fontsize=16, ) metric_name = DISTORTION2LEGEND_NAME.get(distortion_metric, distortion_metric) plt.ylabel( - rf"Mean Normalized {metric_name} ($\downarrow$)", + rf"Mean Normalised {metric_name} ($\downarrow$)", fontsize=16, ) plt.legend( title="Compressor", - loc="upper right", - bbox_to_anchor=(0.8, 0.99), + loc="upper left", + # bbox_to_anchor=(0.8, 0.99), fontsize=12, title_fontsize=14, ) arrow_color = "black" - if "DSSIM" in distortion_metric: + if "dSSIM" in distortion_metric: + # Annotate dSSIM = 1, accounting for the normalization + dssim_one = getattr(np, f"nan{agg}")( + [(1 - ms[0]) / ms[1] for ms in mean_std.values()] + ) + plt.axhline(dssim_one, c="k", ls="--") + plt.text( + np.percentile(plt.xlim(), 63), + dssim_one, + "dSSIM = 1", + fontsize=16, + fontweight="bold", + color="black", + ha="center", + va="center", + bbox=dict(edgecolor="none", facecolor="w", alpha=0.85), + ) + # Add an arrow pointing into the top right corner plt.annotate( "", - xy=(0.95, 0.95), + xy=(0.95, 0.875 if remove_outliers else 0.9), xycoords="axes fraction", xytext=(-60, -50), textcoords="offset points", @@ -527,7 +585,7 @@ def _plot_aggregated_rd_curve( # Attach the text to the lower left of the arrow plt.text( 0.83, - 0.92, + 0.845 if remove_outliers else 0.87, "Better", transform=plt.gca().transAxes, fontsize=16, @@ -538,7 +596,7 @@ def _plot_aggregated_rd_curve( ) # Correct the y-label to point upwards plt.ylabel( - rf"Mean Normalized {metric_name} ($\uparrow$)", + rf"Mean Normalised {metric_name} ($\uparrow$)", fontsize=16, ) else: @@ -566,7 +624,7 @@ def _plot_aggregated_rd_curve( ha="center", ) if ( - "DSSIM" in distortion_metric + "dSSIM" in distortion_metric or "MaxAbsError" in distortion_metric or "SpectralError" in distortion_metric ): @@ -579,24 +637,23 @@ def _plot_aggregated_rd_curve( def _plot_throughput(df, outfile: None | Path = None): - # Transform throughput measurements from raw B/s to s/MB for better comparison - # with instruction count measurements. encode_col = "Encode Throughput [raw B / s]" decode_col = "Decode Throughput [raw B / s]" new_df = df[["Compressor", "Error Bound Name", encode_col, decode_col]].copy() - transformed_encode_col = "Encode Throughput [s / MB]" - transformed_decode_col = "Decode Throughput [s / MB]" - new_df[transformed_encode_col] = 1e6 / new_df[encode_col] - new_df[transformed_decode_col] = 1e6 / new_df[decode_col] + transformed_encode_col = "Encode Throughput [MiB / s]" + transformed_decode_col = "Decode Throughput [MiB / s]" + new_df[transformed_encode_col] = new_df[encode_col] / (2**20) + new_df[transformed_decode_col] = new_df[decode_col] / (2**20) encode_col, decode_col = transformed_encode_col, transformed_decode_col grouped_df = _get_median_and_quantiles(new_df, encode_col, decode_col) _plot_grouped_df( grouped_df, title="", - ylabel="Throughput [s / MB]", + ylabel="Throughput [MiB / s]", logy=True, outfile=outfile, + up=True, ) @@ -610,42 +667,56 @@ def _plot_instruction_count(df, outfile: None | Path = None): ylabel="Instructions [# / raw B]", logy=True, outfile=outfile, + up=False, ) def _get_median_and_quantiles(df, encode_column, decode_column): - return df.groupby(["Compressor", "Error Bound Name"])[ - [encode_column, decode_column] - ].agg( - encode_median=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.5) - ), - encode_lower_quantile=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.25) - ), - encode_upper_quantile=pd.NamedAgg( - column=encode_column, aggfunc=lambda x: x.quantile(0.75) - ), - decode_median=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.5) - ), - decode_lower_quantile=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.25) - ), - decode_upper_quantile=pd.NamedAgg( - column=decode_column, aggfunc=lambda x: x.quantile(0.75) - ), + return ( + df.groupby(["Compressor", "Error Bound Name"])[[encode_column, decode_column]] + .agg( + encode_median=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.5) + ), + encode_lower_quantile=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.25) + ), + encode_upper_quantile=pd.NamedAgg( + column=encode_column, aggfunc=lambda x: x.quantile(0.75) + ), + decode_median=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.5) + ), + decode_lower_quantile=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.25) + ), + decode_upper_quantile=pd.NamedAgg( + column=decode_column, aggfunc=lambda x: x.quantile(0.75) + ), + ) + .sort_index( + level=0, + key=lambda ks: [_COMPRESSOR_ORDER.index(_get_legend_name(k)) for k in ks], + ) ) def _plot_grouped_df( - grouped_df, title, ylabel, outfile: None | Path = None, logy=False + grouped_df, + title, + ylabel, + outfile: None | Path = None, + logy=False, + up=False, ): fig, axes = plt.subplots(1, 3, figsize=(18, 6), sharex=True, sharey=True) # Bar width bar_width = 0.35 - compressors = grouped_df.index.levels[0].tolist() + compressors = sorted( + grouped_df.index.levels[0].tolist(), + key=lambda k: _COMPRESSOR_ORDER.index(_get_legend_name(k)), + ) x_labels = [_get_legend_name(c) for c in compressors] x_positions = range(len(x_labels)) @@ -653,7 +724,10 @@ def _plot_grouped_df( for i, error_bound in enumerate(error_bounds): ax = axes[i] - bound_data = grouped_df.xs(error_bound, level="Error Bound Name") + bound_data = grouped_df.xs(error_bound, level="Error Bound Name").sort_index( + level=0, + key=lambda ks: [_COMPRESSOR_ORDER.index(_get_legend_name(k)) for k in ks], + ) # Plot encode throughput ax.bar( @@ -664,8 +738,13 @@ def _plot_grouped_df( bound_data["encode_lower_quantile"], bound_data["encode_upper_quantile"], ], - label="Encoding", + label="Compression", + edgecolor="white", + linewidth=0, color=[_get_lineinfo(comp)[0] for comp in compressors], + hatch=[ + "O" if comp.startswith("safeguarded-") else "" for comp in compressors + ], ) # Plot decode throughput @@ -677,10 +756,13 @@ def _plot_grouped_df( bound_data["decode_lower_quantile"], bound_data["decode_upper_quantile"], ], - label="Decoding", + label="Decompression", edgecolor=[_get_lineinfo(comp)[0] for comp in compressors], fill=False, linewidth=4, + hatch=[ + "O" if comp.startswith("safeguarded-") else "" for comp in compressors + ], ) # Add labels and title @@ -690,15 +772,18 @@ def _plot_grouped_df( ax.set_title(f"{error_bound.capitalize()} Error Bound", fontsize=14) ax.grid(axis="y", linestyle="--", alpha=0.7) if i == 0: - ax.legend(fontsize=14) + ax.legend( + fontsize=14, loc="lower left" if up else "upper left", framealpha=0.9 + ) ax.set_ylabel(ylabel, fontsize=14) + if i == 1: ax.annotate( "Better", - xy=(0.1, 0.8), + xy=(0.51, 0.75), xycoords="axes fraction", - xytext=(0.1, 0.95), + xytext=(0.51, 0.92), textcoords="axes fraction", - arrowprops=dict(arrowstyle="->", lw=3, color="black"), + arrowprops=dict(arrowstyle="<-" if up else "->", lw=3, color="black"), fontsize=14, ha="center", va="bottom", @@ -720,11 +805,11 @@ def _plot_bound_violations(df, bound_names, outfile: None | Path = None): df_bound["Compressor"] = df_bound["Compressor"].map(_get_legend_name) pass_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Passed)" - ) + ).sort_index(key=lambda ks: [_COMPRESSOR_ORDER.index(k) for k in ks]) pass_fail = pass_fail.astype(np.float32) fraction_fail = df_bound.pivot( index="Compressor", columns="Variable", values="Satisfies Bound (Value)" - ) + ).sort_index(key=lambda ks: [_COMPRESSOR_ORDER.index(k) for k in ks]) annotations = fraction_fail.map( lambda x: "{:.2f}".format(x * 100) if x * 100 >= 0.01 else "<0.01" )