33# For license information, see LICENSE.TXT
44"""Tabular Model."""
55
6+ import html
67import inspect
8+ import json
79import os
10+ import uuid
811import warnings
912from collections import defaultdict
1013from functools import partial
1114from pathlib import Path
15+ from pprint import pformat
1216from typing import Callable , Dict , Iterable , List , Optional , Tuple , Union
1317
1418import joblib
2226from pandas import DataFrame
2327from pytorch_lightning import seed_everything
2428from pytorch_lightning .callbacks import RichProgressBar
25- from pytorch_lightning .callbacks .gradient_accumulation_scheduler import GradientAccumulationScheduler
29+ from pytorch_lightning .callbacks .gradient_accumulation_scheduler import (
30+ GradientAccumulationScheduler ,
31+ )
2632from pytorch_lightning .tuner .tuning import Tuner
2733from pytorch_lightning .utilities .model_summary import summarize
2834from rich import print as rich_print
4147)
4248from pytorch_tabular .config .config import InferredConfig
4349from pytorch_tabular .models .base_model import BaseModel , _CaptumModel , _GenericModel
44- from pytorch_tabular .models .common .layers .embeddings import Embedding1dLayer , Embedding2dLayer , PreEncoded1dLayer
50+ from pytorch_tabular .models .common .layers .embeddings import (
51+ Embedding1dLayer ,
52+ Embedding2dLayer ,
53+ PreEncoded1dLayer ,
54+ )
4555from pytorch_tabular .tabular_datamodule import TabularDatamodule
4656from pytorch_tabular .utils import (
4757 OOMException ,
@@ -262,7 +272,9 @@ def _setup_experiment_tracking(self):
262272 """Sets up the Experiment Tracking Framework according to the choices made in the Experimentconfig."""
263273 if self .config .log_target == "tensorboard" :
264274 self .logger = pl .loggers .TensorBoardLogger (
265- name = self .run_name , save_dir = self .config .project_name , version = self .uid
275+ name = self .run_name ,
276+ save_dir = self .config .project_name ,
277+ version = self .uid ,
266278 )
267279 elif self .config .log_target == "wandb" :
268280 self .logger = pl .loggers .WandbLogger (
@@ -1647,8 +1659,9 @@ def summary(self, model=None, max_depth: int = -1) -> None:
16471659 """Prints a summary of the model.
16481660
16491661 Args:
1650- max_depth (int): The maximum depth to traverse the modules and displayed in the summary.
1651- Defaults to -1, which means will display all the modules.
1662+ max_depth (int): The maximum depth to traverse the modules and
1663+ displayed in the summary. Defaults to -1, which means will
1664+ display all the modules.
16521665
16531666 """
16541667 if model is not None :
@@ -1666,8 +1679,215 @@ def summary(self, model=None, max_depth: int = -1) -> None:
16661679 "been initialized or passed in as an argument[/bold red]"
16671680 )
16681681
def ret_summary(self, model=None, max_depth: int = -1) -> str:
    """Build the model summary and hand it back as text instead of printing it.

    Args:
        model: Optional model to summarize. When given, it takes precedence
            over ``self.model``.
        max_depth (int): The maximum depth to traverse the modules and
            displayed in the summary. Defaults to -1, which means will
            display all the modules.

    Returns:
        str: The layer summary when a model is available; otherwise a
            fallback made of the class name and the pretty-printed config.

    """
    if model is None and not self.has_model:
        # No model to inspect yet: fall back to showing the configuration.
        rule = "-" * 100 + "\n"
        pieces = [
            f"{self.__class__.__name__}\n",
            rule,
            "Config\n",
            rule,
            pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True),
            "\nFull Model Summary once model has been initialized or passed in as an argument",
        ]
        return "".join(pieces)

    # An explicitly passed model wins over the one stored on the instance.
    network = self.model if model is None else model
    return str(summarize(network, max_depth=max_depth))
1706+
def __str__(self) -> str:
    """Return a short one-line label: wrapper class plus the model's name."""
    if self.has_model:
        label = self.model.__class__.__name__
    else:
        # No model built yet: show the configured name and flag it.
        label = self.config._model_name + "(Not Initialized)"
    return f"{self.__class__.__name__}(model={label})"
1711+
def __repr__(self) -> str:
    """Return a multi-line representation with the model name and full config.

    The config is resolved through ``OmegaConf.to_container`` and dumped as
    indented JSON. When no model has been built yet, the configured model
    name is shown with a ``(Not Initialized)`` marker.

    Returns:
        str: Text of the form ``ClassName(\\n  model=...,\\n  config=...,\\n)``.

    """
    config_str = json.dumps(OmegaConf.to_container(self.config, resolve=True), indent=4)
    if self.has_model:
        model_line = f"  model={self.model.__class__.__name__},"
    else:
        model_line = f"  model={self.config._model_name} (Not Initialized),"
    lines = [
        f"{self.__class__.__name__}(",
        model_line,
        f"  config={config_str},",
        # Close the parenthesis opened on the first line; it was left dangling before.
        ")",
    ]
    return "\n".join(lines)
1722+
def _repr_html_(self):
    """Generate an HTML representation for Jupyter Notebook.

    The output is a styled container holding a header with the model name
    (marked "(Not Initialized)" when no model exists yet), a collapsible
    "Model Config" section, and, only when a model exists, a collapsible
    "Model Summary" section.

    Returns:
        str: An HTML fragment including its own ``<style>`` and ``<script>``.

    """
    # Every rule is scoped under `.main-container` so that generic selectors
    # such as `table` or `.hidden` do not restyle unrelated notebook output
    # (for example rendered DataFrames) living in the same page.
    css = """
    <style>
        .main-container {
            font-family: Arial, sans-serif;
            font-size: 14px;
            border: 1px dashed #ccc;
            padding: 10px;
            margin: 10px;
            background-color: #f9f9f9;
        }
        .main-container .header {
            background-color: #e8f4fc;
            padding: 5px;
            font-weight: bold;
            text-align: center;
            border-bottom: 1px solid #ccc;
        }
        .main-container .section {
            margin: 10px 0;
            padding: 10px;
            border: 1px solid #ccc;
            background-color: #ffffff;
        }
        .main-container .step {
            border: 1px solid #ccc;
            background-color: #f0f8ff;
            margin: 5px 0;
            padding: 5px;
        }
        .main-container .sub-step {
            margin-left: 20px;
            border: 1px solid #ddd;
            background-color: #f9f9f9;
            padding: 5px;
        }
        .main-container .toggle-button {
            cursor: pointer;
            font-size: 12px;
            margin-right: 5px;
        }
        .main-container .toggle-button:hover {
            color: #0056b3;
        }
        .main-container .hidden {
            display: none;
        }
        .main-container table {
            width: 100%;
            border-collapse: collapse;
        }
        .main-container table, .main-container th, .main-container td {
            border: 1px solid black;
        }
        .main-container th, .main-container td {
            padding: 5px;
            text-align: left;
        }
    </style>
    <script>
        function toggleVisibility(id) {
            var element = document.getElementById(id);
            if (element.classList.contains('hidden')) {
                element.classList.remove('hidden');
            } else {
                element.classList.add('hidden');
            }
        }
    </script>
    """

    # Header (Main model name)
    # A fresh uid per rendering keeps element ids distinct between outputs.
    uid = str(uuid.uuid4())
    model_status = "" if self.has_model else "(Not Initialized)"
    model_name = self.model.__class__.__name__ if self.has_model else self.config._model_name
    header_html = f"<div class='header'>{html.escape(model_name)} {model_status}</div>"

    # Config Section
    config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True)

    # Summary Section (only meaningful once a model exists)
    if self.has_model:
        summary_html = self._generate_collapsible_section(
            "Model Summary", self._generate_model_summary_table(), uid=uid
        )
    else:
        summary_html = ""

    # Combine sections
    return f"""
    {css}
    <div class='main-container'>
        {header_html}
        {config_html}
        {summary_html}
    </div>
    """
1820+
def _generate_collapsible_section(self, title, content, uid, is_dict=False):
    """Wrap ``content`` in a titled HTML section that starts out collapsed.

    Args:
        title: Section heading; also used (lower-cased, spaces replaced by
            underscores) as the prefix of the element id.
        content: Ready-made HTML, or a config object when ``is_dict`` is True.
        uid: String appended to the element id.
        is_dict: When True, ``content`` is converted with
            ``OmegaConf.to_container`` and rendered as nested sections.

    Returns:
        str: HTML with a toggle control, the escaped title and the hidden body.

    """
    section_id = "_".join(title.lower().split(" ")) + uid
    body = content
    if is_dict:
        plain_config = OmegaConf.to_container(content, resolve=True)
        body = self._generate_nested_collapsible_sections(plain_config, section_id)
    heading = html.escape(title)
    return f"""
    <div>
        <span
            class="toggle-button"
            onclick="toggleVisibility('{section_id}')"
        >
            &#9654;
        </span>
        <strong>{heading}</strong>
        <div id="{section_id}" class="hidden section">
            {body}
        </div>
    </div>
    """
1841+
def _generate_nested_collapsible_sections(self, content, parent_id):
    """Render a (possibly nested) dict as HTML, one entry per key.

    Dict values become collapsible sub-sections rendered recursively; any
    other value becomes a single ``key: value`` line. Keys and values are
    HTML-escaped.

    Args:
        content: Mapping to render.
        parent_id: Id of the enclosing section, used as the prefix for the
            ids of nested sections.

    Returns:
        str: Concatenated HTML for all entries ("" for an empty mapping).

    """
    fragments = []
    for key, value in content.items():
        # Keys are not guaranteed to be strings (e.g. integer keys), and
        # html.escape only accepts str, so stringify before escaping.
        label = html.escape(str(key))
        if not isinstance(value, dict):
            fragments.append(f"<div><strong>{label}:</strong> {html.escape(str(value))}</div>")
            continue

        # The id is embedded in an HTML attribute and in an inline JS string
        # literal, so keep only characters that are harmless in both places.
        # The uuid suffix keeps ids unique even if two keys sanitize alike.
        raw_id = f"{parent_id}_{key}"
        safe_id = "".join(ch if ch.isalnum() or ch in "_-" else "_" for ch in raw_id)
        nested_id = safe_id + str(uuid.uuid4())
        nested_content = self._generate_nested_collapsible_sections(value, nested_id)
        fragments.append(
            f"""
            <div>
                <span
                    class="toggle-button"
                    onclick="toggleVisibility('{nested_id}')"
                >
                    &#9654;
                </span>
                <strong>{label}</strong>
                <div id="{nested_id}" class="hidden section">
                    {nested_content}
                </div>
            </div>
            """
        )
    return "".join(fragments)
1866+
def _generate_model_summary_table(self):
    """Render the depth-1 summary of ``self.model`` as an HTML table.

    Each row holds a layer's name, type, parameter count, input sizes and
    output sizes, all HTML-escaped.

    Returns:
        str: The ``<table>`` markup.

    """
    layer_info = summarize(self.model, max_depth=1)._layer_summary
    pieces = [
        """
        <table>
            <tr>
                <th><b>Layer</b></th>
                <th><b>Type</b></th>
                <th><b>Params</b></th>
                <th><b>In sizes</b></th>
                <th><b>Out sizes</b></th>
            </tr>
        """
    ]
    for layer_name, details in layer_info.items():
        pieces.append(
            f"""
            <tr>
                <td>{html.escape(layer_name)}</td>
                <td>{html.escape(details.layer_type)}</td>
                <td>{html.escape(str(details.num_parameters))}</td>
                <td>{html.escape(str(details.in_size))}</td>
                <td>{html.escape(str(details.out_size))}</td>
            </tr>
            """
        )
    pieces.append("</table>")
    return "".join(pieces)
16711891
16721892 def feature_importance (self ) -> DataFrame :
16731893 """Returns the feature importance of the model as a pandas DataFrame."""
@@ -1998,7 +2218,10 @@ def cross_validate(
19982218 # Initialize datamodule and model in the first fold
19992219 # uses train data from this fold to fit all transformers
20002220 datamodule = self .prepare_dataloader (
2001- train = train .iloc [train_idx ], validation = train .iloc [val_idx ], seed = 42 , ** prep_dl_kwargs
2221+ train = train .iloc [train_idx ],
2222+ validation = train .iloc [val_idx ],
2223+ seed = 42 ,
2224+ ** prep_dl_kwargs ,
20022225 )
20032226 model = self .prepare_model (datamodule , ** prep_model_kwargs )
20042227 else :
0 commit comments