diff --git a/XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint_WWPlots.ipynb b/XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint_WWPlots.ipynb index f0dc576..62a82da 100644 --- a/XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint_WWPlots.ipynb +++ b/XGBWW_OpenML_1049_W7W8_Alpha_Checkpoint_WWPlots.ipynb @@ -190,6 +190,7 @@ "import xgboost as xgb\n", "import torch\n", "import weightwatcher as ww\n", + "import matplotlib.pyplot as plt\n", "\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import accuracy_score, log_loss\n", @@ -228,6 +229,126 @@ " return None\n", "\n", "\n", + "\n", + "def _extract_weight_matrix(layer):\n", + " candidates = []\n", + "\n", + " if hasattr(layer, 'weight'):\n", + " candidates.append(layer.weight)\n", + "\n", + " if hasattr(layer, 'modules'):\n", + " for sublayer in layer.modules():\n", + " if sublayer is layer:\n", + " continue\n", + " if hasattr(sublayer, 'weight'):\n", + " candidates.append(sublayer.weight)\n", + "\n", + " for w in candidates:\n", + " if hasattr(w, 'detach'):\n", + " arr = w.detach().cpu().numpy()\n", + " else:\n", + " arr = np.asarray(w)\n", + " if arr.ndim == 2:\n", + " return arr\n", + "\n", + " return None\n", + "\n", + "\n", + "def _dominant_weight_values(weight_matrix, max_traps=8):\n", + " if weight_matrix is None:\n", + " return []\n", + "\n", + " flat = np.asarray(weight_matrix, dtype=float).ravel()\n", + " if flat.size == 0:\n", + " return []\n", + "\n", + " order = np.argsort(np.abs(flat))[::-1]\n", + " selected = []\n", + " for idx in order:\n", + " value = float(flat[idx])\n", + " if not np.isfinite(value):\n", + " continue\n", + " if any(np.isclose(value, prev, atol=1e-12) for prev in selected):\n", + " continue\n", + " selected.append(value)\n", + " if len(selected) >= int(max_traps):\n", + " break\n", + " return selected\n", + "\n", + "\n", + "def plot_trap_lines_on_weight_histogram(layer, matrix_name, round_idx, output_dir, max_traps=8, trap_values=None, message_prefix='analyze traps'):\n", + " weight_matrix = _extract_weight_matrix(layer)\n", + " if weight_matrix is None:\n", + " print(f'round={round_idx:4d} | skipping trap histogram for {matrix_name}: no 2D weight matrix found')\n", + " return\n", + "\n", + " values = np.asarray(weight_matrix, dtype=float).ravel()\n", + " values = values[np.isfinite(values)]\n", + " if values.size == 0:\n", + " print(f'round={round_idx:4d} | skipping trap histogram for {matrix_name}: empty/invalid weights')\n", + " return\n", + "\n", + " if trap_values is None:\n", + " selected_traps = _dominant_weight_values(weight_matrix, max_traps=max_traps)\n", + " else:\n", + " selected_traps = [float(v) for v in trap_values if np.isfinite(v)]\n", + " if not selected_traps:\n", + " print(f'round={round_idx:4d} | skipping trap histogram for {matrix_name}: no trap values selected')\n", + " return\n", + "\n", + " max_abs_weight = float(values[np.argmax(np.abs(values))])\n", + "\n", + " fig, ax = plt.subplots(figsize=(9, 6))\n", + " ax.hist(values, bins=60, density=True, alpha=0.6, color='tab:blue', edgecolor='none')\n", + " ax.set_xlabel('Weight value')\n", + " ax.set_ylabel('Density')\n", + " ax.set_title(f'Layer {matrix_name} ({round_idx}) \u2014 trap lines on weight histogram')\n", + "\n", + " cmap = plt.get_cmap('viridis')\n", + " y_max = ax.get_ylim()[1]\n", + "\n", + " for i, trap_value in enumerate(selected_traps, start=1):\n", + " color = cmap((i - 1) / max(1, len(selected_traps) - 1))\n", + " is_global_max = np.isclose(trap_value, max_abs_weight, atol=1e-12)\n", + " linestyle = (0, (7, 2, 2, 2)) if is_global_max else '--'\n", + " linewidth = 2.6 if is_global_max else 1.6\n", + "\n", + " ax.axvline(trap_value, color=color, linestyle=linestyle, linewidth=linewidth, alpha=0.95)\n", + "\n", + " label_text = f'#{i}: ({trap_value:.4g})\\n' + ('max |w|' if is_global_max else '\\u03b5=0.000')\n", + " y_pos = y_max * (0.93 - 0.08 * ((i - 1) % 5))\n", + " ax.text(\n", + " trap_value,\n", + " y_pos,\n", + " label_text,\n", + " rotation=90,\n", + " va='top',\n", + " ha='center',\n", + " fontsize=8,\n", + " color='black',\n", + " bbox=dict(boxstyle='round,pad=0.15', facecolor='white', edgecolor='none', alpha=0.55),\n", + " )\n", + "\n", + " output_dir = Path(output_dir)\n", + " output_dir.mkdir(parents=True, exist_ok=True)\n", + " figure_path = output_dir / f'{matrix_name.lower()}_trap_hist.png'\n", + " fig.tight_layout()\n", + " fig.savefig(figure_path, dpi=160)\n", + " plt.show()\n", + " plt.close(fig)\n", + " print(f'round={round_idx:4d} | {message_prefix} ({matrix_name}): saved trap histogram to {figure_path}')\n", + "\n", + "\n", + "def plot_removed_traps_on_weight_histogram(layer, matrix_name, round_idx, output_dir, removed_trap_values):\n", + " return plot_trap_lines_on_weight_histogram(\n", + " layer=layer,\n", + " matrix_name=matrix_name,\n", + " round_idx=round_idx,\n", + " output_dir=output_dir,\n", + " trap_values=removed_trap_values,\n", + " message_prefix='remove traps',\n", + " )\n", + "\n", "def layer_min_matrix_dim(layer):\n", " shape = _extract_weight_shape(layer)\n", " if shape is None:\n", @@ -450,6 +571,7 @@ "MIN_WW_EIGENVALUES = 10 # Wait until constructed matrix has enough spectral support\n", "WW_PLOTS_DIR = DRIVE_ROOT / 'ww_plots'\n", "WW_PLOTS_DIR.mkdir(parents=True, exist_ok=True)\n", + "TRAP_HIST_MAX_LINES = 8\n", "\n", "start_time = time.time()\n", "for r in range(start_round + 1, MAX_ROUNDS + 1):\n", @@ -507,6 +629,11 @@ " ww_w7 = ww_stats_for_matrix(layer_w7, 'W7', make_plot=make_ww_plots, plot_dir=round_plot_dir)\n", " ww_w8 = ww_stats_for_matrix(layer_w8, 'W8', make_plot=make_ww_plots, plot_dir=round_plot_dir)\n", "\n", + " if make_ww_plots:\n", + " plot_trap_lines_on_weight_histogram(layer_w2, 'W2', r, round_plot_dir, max_traps=TRAP_HIST_MAX_LINES)\n", + " plot_trap_lines_on_weight_histogram(layer_w7, 'W7', r, round_plot_dir, max_traps=TRAP_HIST_MAX_LINES)\n", + " plot_trap_lines_on_weight_histogram(layer_w8, 'W8', r, round_plot_dir, max_traps=TRAP_HIST_MAX_LINES)\n", + "\n", " alpha_w2 = ww_w2['alpha']\n", " alpha_w7 = ww_w7['alpha']\n", " alpha_w8 = ww_w8['alpha']\n",