Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint_WWPlots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@
"import xgboost as xgb\n",
"import torch\n",
"import weightwatcher as ww\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import accuracy_score, log_loss\n",
Expand Down Expand Up @@ -228,6 +229,126 @@
" return None\n",
"\n",
"\n",
"\n",
"def _extract_weight_matrix(layer):\n",
" candidates = []\n",
"\n",
" if hasattr(layer, 'weight'):\n",
" candidates.append(layer.weight)\n",
"\n",
" if hasattr(layer, 'modules'):\n",
" for sublayer in layer.modules():\n",
" if sublayer is layer:\n",
" continue\n",
" if hasattr(sublayer, 'weight'):\n",
" candidates.append(sublayer.weight)\n",
"\n",
" for w in candidates:\n",
" if hasattr(w, 'detach'):\n",
" arr = w.detach().cpu().numpy()\n",
" else:\n",
" arr = np.asarray(w)\n",
" if arr.ndim == 2:\n",
" return arr\n",
"\n",
" return None\n",
"\n",
"\n",
"def _dominant_weight_values(weight_matrix, max_traps=8):\n",
" if weight_matrix is None:\n",
" return []\n",
"\n",
" flat = np.asarray(weight_matrix, dtype=float).ravel()\n",
" if flat.size == 0:\n",
" return []\n",
"\n",
" order = np.argsort(np.abs(flat))[::-1]\n",
" selected = []\n",
" for idx in order:\n",
" value = float(flat[idx])\n",
" if not np.isfinite(value):\n",
" continue\n",
" if any(np.isclose(value, prev, atol=1e-12) for prev in selected):\n",
" continue\n",
" selected.append(value)\n",
" if len(selected) >= int(max_traps):\n",
" break\n",
" return selected\n",
"\n",
"\n",
"def plot_trap_lines_on_weight_histogram(layer, matrix_name, round_idx, output_dir, max_traps=8, trap_values=None, message_prefix='analyze traps'):\n",
" weight_matrix = _extract_weight_matrix(layer)\n",
" if weight_matrix is None:\n",
" print(f'round={round_idx:4d} | skipping trap histogram for {matrix_name}: no 2D weight matrix found')\n",
" return\n",
"\n",
" values = np.asarray(weight_matrix, dtype=float).ravel()\n",
" values = values[np.isfinite(values)]\n",
" if values.size == 0:\n",
" print(f'round={round_idx:4d} | skipping trap histogram for {matrix_name}: empty/invalid weights')\n",
" return\n",
"\n",
" if trap_values is None:\n",
" selected_traps = _dominant_weight_values(weight_matrix, max_traps=max_traps)\n",
" else:\n",
" selected_traps = [float(v) for v in trap_values if np.isfinite(v)]\n",
" if not selected_traps:\n",
" print(f'round={round_idx:4d} | skipping trap histogram for {matrix_name}: no trap values selected')\n",
" return\n",
"\n",
" max_abs_weight = float(values[np.argmax(np.abs(values))])\n",
"\n",
" fig, ax = plt.subplots(figsize=(9, 6))\n",
" ax.hist(values, bins=60, density=True, alpha=0.6, color='tab:blue', edgecolor='none')\n",
" ax.set_xlabel('Weight value')\n",
" ax.set_ylabel('Density')\n",
" ax.set_title(f'Layer {matrix_name} ({round_idx}) \u2014 trap lines on weight histogram')\n",
"\n",
" cmap = plt.get_cmap('viridis')\n",
" y_max = ax.get_ylim()[1]\n",
"\n",
" for i, trap_value in enumerate(selected_traps, start=1):\n",
" color = cmap((i - 1) / max(1, len(selected_traps) - 1))\n",
" is_global_max = np.isclose(trap_value, max_abs_weight, atol=1e-12)\n",
" linestyle = (0, (7, 2, 2, 2)) if is_global_max else '--'\n",
" linewidth = 2.6 if is_global_max else 1.6\n",
"\n",
" ax.axvline(trap_value, color=color, linestyle=linestyle, linewidth=linewidth, alpha=0.95)\n",
"\n",
" label_text = f'#{i}: ({trap_value:.4g})\\n' + ('max |w|' if is_global_max else '\\u03b5=0.000')\n",
" y_pos = y_max * (0.93 - 0.08 * ((i - 1) % 5))\n",
" ax.text(\n",
" trap_value,\n",
" y_pos,\n",
" label_text,\n",
" rotation=90,\n",
" va='top',\n",
" ha='center',\n",
" fontsize=8,\n",
" color='black',\n",
" bbox=dict(boxstyle='round,pad=0.15', facecolor='white', edgecolor='none', alpha=0.55),\n",
" )\n",
"\n",
" output_dir = Path(output_dir)\n",
" output_dir.mkdir(parents=True, exist_ok=True)\n",
" figure_path = output_dir / f'{matrix_name.lower()}_trap_hist.png'\n",
" fig.tight_layout()\n",
" fig.savefig(figure_path, dpi=160)\n",
" plt.show()\n",
" plt.close(fig)\n",
" print(f'round={round_idx:4d} | {message_prefix} ({matrix_name}): saved trap histogram to {figure_path}')\n",
"\n",
"\n",
"def plot_removed_traps_on_weight_histogram(layer, matrix_name, round_idx, output_dir, removed_trap_values):\n",
" return plot_trap_lines_on_weight_histogram(\n",
" layer=layer,\n",
" matrix_name=matrix_name,\n",
" round_idx=round_idx,\n",
" output_dir=output_dir,\n",
" trap_values=removed_trap_values,\n",
" message_prefix='remove traps',\n",
" )\n",
"\n",
"def layer_min_matrix_dim(layer):\n",
" shape = _extract_weight_shape(layer)\n",
" if shape is None:\n",
Expand Down Expand Up @@ -450,6 +571,7 @@
"MIN_WW_EIGENVALUES = 10 # Wait until constructed matrix has enough spectral support\n",
"WW_PLOTS_DIR = DRIVE_ROOT / 'ww_plots'\n",
"WW_PLOTS_DIR.mkdir(parents=True, exist_ok=True)\n",
"TRAP_HIST_MAX_LINES = 8\n",
"\n",
"start_time = time.time()\n",
"for r in range(start_round + 1, MAX_ROUNDS + 1):\n",
Expand Down Expand Up @@ -507,6 +629,11 @@
" ww_w7 = ww_stats_for_matrix(layer_w7, 'W7', make_plot=make_ww_plots, plot_dir=round_plot_dir)\n",
" ww_w8 = ww_stats_for_matrix(layer_w8, 'W8', make_plot=make_ww_plots, plot_dir=round_plot_dir)\n",
"\n",
" if make_ww_plots:\n",
" plot_trap_lines_on_weight_histogram(layer_w2, 'W2', r, round_plot_dir, max_traps=TRAP_HIST_MAX_LINES)\n",
" plot_trap_lines_on_weight_histogram(layer_w7, 'W7', r, round_plot_dir, max_traps=TRAP_HIST_MAX_LINES)\n",
" plot_trap_lines_on_weight_histogram(layer_w8, 'W8', r, round_plot_dir, max_traps=TRAP_HIST_MAX_LINES)\n",
"\n",
" alpha_w2 = ww_w2['alpha']\n",
" alpha_w7 = ww_w7['alpha']\n",
" alpha_w8 = ww_w8['alpha']\n",
Expand Down