Skip to content

Commit 495803c

Browse files
Adding informative str and repr (#507)
* added informative repr, str and repr_html * fixed some issues and added test cases * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixed some precommit errors --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d574814 commit 495803c

File tree

3 files changed

+293
-8
lines changed

3 files changed

+293
-8
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,4 @@ docs/tutorials/pytorch-tabular-covertype/
162162

163163
# Pycharm
164164
.idea/
165+
test.ipynb

src/pytorch_tabular/tabular_model.py

Lines changed: 230 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,16 @@
33
# For license information, see LICENSE.TXT
44
"""Tabular Model."""
55

6+
import html
67
import inspect
8+
import json
79
import os
10+
import uuid
811
import warnings
912
from collections import defaultdict
1013
from functools import partial
1114
from pathlib import Path
15+
from pprint import pformat
1216
from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
1317

1418
import joblib
@@ -22,7 +26,9 @@
2226
from pandas import DataFrame
2327
from pytorch_lightning import seed_everything
2428
from pytorch_lightning.callbacks import RichProgressBar
25-
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import GradientAccumulationScheduler
29+
from pytorch_lightning.callbacks.gradient_accumulation_scheduler import (
30+
GradientAccumulationScheduler,
31+
)
2632
from pytorch_lightning.tuner.tuning import Tuner
2733
from pytorch_lightning.utilities.model_summary import summarize
2834
from rich import print as rich_print
@@ -41,7 +47,11 @@
4147
)
4248
from pytorch_tabular.config.config import InferredConfig
4349
from pytorch_tabular.models.base_model import BaseModel, _CaptumModel, _GenericModel
44-
from pytorch_tabular.models.common.layers.embeddings import Embedding1dLayer, Embedding2dLayer, PreEncoded1dLayer
50+
from pytorch_tabular.models.common.layers.embeddings import (
51+
Embedding1dLayer,
52+
Embedding2dLayer,
53+
PreEncoded1dLayer,
54+
)
4555
from pytorch_tabular.tabular_datamodule import TabularDatamodule
4656
from pytorch_tabular.utils import (
4757
OOMException,
@@ -262,7 +272,9 @@ def _setup_experiment_tracking(self):
262272
"""Sets up the Experiment Tracking Framework according to the choices made in the Experimentconfig."""
263273
if self.config.log_target == "tensorboard":
264274
self.logger = pl.loggers.TensorBoardLogger(
265-
name=self.run_name, save_dir=self.config.project_name, version=self.uid
275+
name=self.run_name,
276+
save_dir=self.config.project_name,
277+
version=self.uid,
266278
)
267279
elif self.config.log_target == "wandb":
268280
self.logger = pl.loggers.WandbLogger(
@@ -1647,8 +1659,9 @@ def summary(self, model=None, max_depth: int = -1) -> None:
16471659
"""Prints a summary of the model.
16481660
16491661
Args:
1650-
max_depth (int): The maximum depth to traverse the modules and displayed in the summary.
1651-
Defaults to -1, which means will display all the modules.
1662+
max_depth (int): The maximum depth to traverse the modules and
1663+
displayed in the summary. Defaults to -1, which means will
1664+
display all the modules.
16521665
16531666
"""
16541667
if model is not None:
@@ -1666,8 +1679,215 @@ def summary(self, model=None, max_depth: int = -1) -> None:
16661679
"been initialized or passed in as an argument[/bold red]"
16671680
)
16681681

1682+
def ret_summary(self, model=None, max_depth: int = -1) -> str:
    """Builds the model summary as a string instead of printing it.

    Args:
        model: Optional model to summarize. When it is not None it is used
            in place of ``self.model``.
        max_depth (int): Forwarded to ``summarize``; the maximum depth to
            traverse the modules and display in the summary. Defaults to
            -1, which means all the modules are displayed.

    Returns:
        str: The stringified output of ``summarize`` when a model is
            available, otherwise a config-only fallback text.

    """
    if model is None and not self.has_model:
        # Nothing to summarize yet: fall back to a dump of the config.
        rule = "-" * 100
        pieces = [
            f"{self.__class__.__name__}\n",
            f"{rule}\n",
            "Config\n",
            f"{rule}\n",
            pformat(self.config.__dict__["_content"], indent=4, width=80, compact=True),
            "\nFull Model Summary once model has been initialized or passed in as an argument",
        ]
        return "".join(pieces)

    # An explicitly passed model wins over the one held by the instance.
    target = model if model is not None else self.model
    return str(summarize(target, max_depth=max_depth))
1706+
16691707
def __str__(self) -> str:
    """Returns a short, readable one-line description of this object.

    The text has the form ``ClassName(model=<name>)``. ``<name>`` is the
    class name of ``self.model`` when ``self.has_model`` is true, otherwise
    the configured model name followed by ``(Not Initialized)``.
    """
    if self.has_model:
        model_name = self.model.__class__.__name__
    else:
        model_name = self.config._model_name + "(Not Initialized)"
    return "{}(model={})".format(self.__class__.__name__, model_name)
1711+
1712+
def __repr__(self) -> str:
    """Returns a detailed representation of the TabularModel object.

    The result names the model (its class name when ``self.has_model`` is
    true, otherwise the configured model name marked ``(Not Initialized)``)
    and embeds the whole config, resolved through
    ``OmegaConf.to_container`` and rendered as indented JSON.

    Returns:
        str: ``ClassName(`` followed by the ``model=`` and ``config=``
            entries and a closing ``)``.
    """
    config_str = json.dumps(OmegaConf.to_container(self.config, resolve=True), indent=4)
    if self.has_model:
        model_str = self.model.__class__.__name__
    else:
        model_str = f"{self.config._model_name} (Not Initialized)"
    # The closing parenthesis balances the "ClassName(" opener; without it
    # the representation ended on a dangling ",\n".
    return f"{self.__class__.__name__}(\n model={model_str},\n config={config_str},\n)"
1722+
1723+
def _repr_html_(self):
    """Generate an HTML representation for Jupyter Notebook.

    The markup consists of a style sheet, a small toggle script, a header
    with the model name, a collapsible "Model Config" section and, only
    when ``self.has_model`` is true, a collapsible "Model Summary" section.

    Returns:
        str: The HTML fragment.
    """
    # Every rule is scoped under `.main-container` so that generic
    # selectors (`table`, `th`, `td`, `.hidden`, `.header`, ...) only style
    # this widget and do not leak onto other elements of the page.
    css = """
    <style>
        .main-container {
            font-family: Arial, sans-serif;
            font-size: 14px;
            border: 1px dashed #ccc;
            padding: 10px;
            margin: 10px;
            background-color: #f9f9f9;
        }
        .main-container .header {
            background-color: #e8f4fc;
            padding: 5px;
            font-weight: bold;
            text-align: center;
            border-bottom: 1px solid #ccc;
        }
        .main-container .section {
            margin: 10px 0;
            padding: 10px;
            border: 1px solid #ccc;
            background-color: #ffffff;
        }
        .main-container .step {
            border: 1px solid #ccc;
            background-color: #f0f8ff;
            margin: 5px 0;
            padding: 5px;
        }
        .main-container .sub-step {
            margin-left: 20px;
            border: 1px solid #ddd;
            background-color: #f9f9f9;
            padding: 5px;
        }
        .main-container .toggle-button {
            cursor: pointer;
            font-size: 12px;
            margin-right: 5px;
        }
        .main-container .toggle-button:hover {
            color: #0056b3;
        }
        .main-container .hidden {
            display: none;
        }
        .main-container table {
            width: 100%;
            border-collapse: collapse;
        }
        .main-container table,
        .main-container th,
        .main-container td {
            border: 1px solid black;
        }
        .main-container th,
        .main-container td {
            padding: 5px;
            text-align: left;
        }
    </style>
    <script>
        function toggleVisibility(id) {
            var element = document.getElementById(id);
            if (element.classList.contains('hidden')) {
                element.classList.remove('hidden');
            } else {
                element.classList.add('hidden');
            }
        }
    </script>
    """

    # A fresh uid per rendering keeps element ids distinct between outputs.
    uid = str(uuid.uuid4())

    # Header (Main model name)
    if self.has_model:
        model_name = self.model.__class__.__name__
        model_status = ""
    else:
        model_name = self.config._model_name
        # Leading space keeps the status from being glued to the name.
        model_status = " (Not Initialized)"
    header_html = f"<div class='header'>{html.escape(model_name)}{model_status}</div>"

    # Config Section
    config_html = self._generate_collapsible_section("Model Config", self.config, uid=uid, is_dict=True)

    # Summary Section (only available once a model exists)
    if self.has_model:
        summary_html = self._generate_collapsible_section(
            "Model Summary", self._generate_model_summary_table(), uid=uid
        )
    else:
        summary_html = ""

    # Combine sections
    return f"""
    {css}
    <div class='main-container'>
        {header_html}
        {config_html}
        {summary_html}
    </div>
    """
1820+
1821+
def _generate_collapsible_section(self, title, content, uid, is_dict=False):
    """Wraps ``content`` in a titled section that starts out hidden.

    Args:
        title: Section heading; it is HTML-escaped for display and also used
            (lower-cased, spaces replaced by underscores) to build the
            element id.
        content: HTML inserted as-is, or, when ``is_dict`` is true, a config
            object that is converted with ``OmegaConf.to_container`` and
            rendered through ``_generate_nested_collapsible_sections``.
        uid: String appended to the element id.
        is_dict (bool): Whether ``content`` is a config to expand.

    Returns:
        str: The HTML for the toggle button, the title and the section body.
    """
    container_id = "_".join(title.lower().split(" ")) + uid
    body = content
    if is_dict:
        as_container = OmegaConf.to_container(content, resolve=True)
        body = self._generate_nested_collapsible_sections(as_container, container_id)
    heading = html.escape(title)
    return f"""
    <div>
        <span
            class="toggle-button"
            onclick="toggleVisibility('{container_id}')"
        >
            &#9654;
        </span>
        <strong>{heading}</strong>
        <div id="{container_id}" class="hidden section">
            {body}
        </div>
    </div>
    """
1841+
1842+
def _generate_nested_collapsible_sections(self, content, parent_id):
    """Renders a (possibly nested) mapping as HTML.

    Non-dict values become ``<div><strong>key:</strong> value</div>`` lines.
    Dict values become a collapsible section whose body is produced by a
    recursive call.

    Args:
        content: Mapping to render. Keys and values are stringified and
            HTML-escaped, so non-string keys are accepted.
        parent_id: Id prefix used for the ids of nested sections.

    Returns:
        str: The concatenated HTML, or an empty string for an empty mapping.
    """
    fragments = []
    for key, value in content.items():
        # str() first: html.escape only accepts strings.
        label = html.escape(str(key))
        if not isinstance(value, dict):
            fragments.append(f"<div><strong>{label}:</strong> {html.escape(str(value))}</div>")
            continue

        # The id lands both in an HTML attribute and inside a JS string
        # literal, so anything but letters, digits, "_" and "-" is replaced
        # to keep quotes and markup characters out of it.
        raw_id = f"{parent_id}_{key}"
        safe_id = "".join(ch if ch.isalnum() or ch in "_-" else "_" for ch in raw_id)
        # The uuid suffix keeps ids unique even when keys repeat.
        nested_id = safe_id + str(uuid.uuid4())
        nested_content = self._generate_nested_collapsible_sections(value, nested_id)
        fragments.append(
            f"""
            <div>
                <span
                    class="toggle-button"
                    onclick="toggleVisibility('{nested_id}')"
                >
                    &#9654;
                </span>
                <strong>{label}</strong>
                <div id="{nested_id}" class="hidden section">
                    {nested_content}
                </div>
            </div>
            """
        )
    return "".join(fragments)
1866+
1867+
def _generate_model_summary_table(self):
    """Renders the layers of ``self.model`` as an HTML table.

    The data comes from ``summarize(self.model, max_depth=1)``; each entry
    of its ``_layer_summary`` mapping yields one row with the layer name,
    type, parameter count, input sizes and output sizes, all HTML-escaped.

    Returns:
        str: The ``<table>`` markup.
    """
    layer_info = summarize(self.model, max_depth=1)._layer_summary
    parts = [
        """
        <table>
            <tr>
                <th><b>Layer</b></th>
                <th><b>Type</b></th>
                <th><b>Params</b></th>
                <th><b>In sizes</b></th>
                <th><b>Out sizes</b></th>
            </tr>
        """
    ]
    for layer_name, info in layer_info.items():
        cells = (
            html.escape(layer_name),
            html.escape(info.layer_type),
            html.escape(str(info.num_parameters)),
            html.escape(str(info.in_size)),
            html.escape(str(info.out_size)),
        )
        parts.append(
            f"""
            <tr>
                <td>{cells[0]}</td>
                <td>{cells[1]}</td>
                <td>{cells[2]}</td>
                <td>{cells[3]}</td>
                <td>{cells[4]}</td>
            </tr>
            """
        )
    parts.append("</table>")
    return "".join(parts)
16711891

16721892
def feature_importance(self) -> DataFrame:
16731893
"""Returns the feature importance of the model as a pandas DataFrame."""
@@ -1998,7 +2218,10 @@ def cross_validate(
19982218
# Initialize datamodule and model in the first fold
19992219
# uses train data from this fold to fit all transformers
20002220
datamodule = self.prepare_dataloader(
2001-
train=train.iloc[train_idx], validation=train.iloc[val_idx], seed=42, **prep_dl_kwargs
2221+
train=train.iloc[train_idx],
2222+
validation=train.iloc[val_idx],
2223+
seed=42,
2224+
**prep_dl_kwargs,
20022225
)
20032226
model = self.prepare_model(datamodule, **prep_model_kwargs)
20042227
else:

tests/test_common.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131

3232
MODEL_CONFIG_SAVE_TEST = [
3333
(CategoryEmbeddingModelConfig, {"layers": "10-20"}),
34-
(AutoIntConfig, {"num_heads": 1, "num_attn_blocks": 1}),
34+
(GANDALFConfig, {}),
3535
(NodeConfig, {"num_trees": 100, "depth": 2}),
3636
(TabNetModelConfig, {"n_a": 2, "n_d": 2}),
3737
]
@@ -1247,3 +1247,64 @@ def test_model_compare_regression(regression_data, model_list, continuous_cols,
12471247
# # there may be multiple models with the same score
12481248
# best_models = comp_df.loc[comp_df[f"test_{rank_metric[0]}"] == best_score, "model"].values.tolist()
12491249
# assert best_model.model._get_name() in best_models
1250+
1251+
1252+
@pytest.mark.parametrize("model_config_class", MODEL_CONFIG_SAVE_TEST)
@pytest.mark.parametrize("continuous_cols", [list(DATASET_CONTINUOUS_COLUMNS)])
@pytest.mark.parametrize("categorical_cols", [["HouseAgeBin"]])
@pytest.mark.parametrize("custom_metrics", [None, [fake_metric]])
@pytest.mark.parametrize("custom_loss", [None, torch.nn.L1Loss()])
@pytest.mark.parametrize("custom_optimizer", [None, torch.optim.Adagrad, "SGD", "torch_optimizer.AdaBound"])
def test_str_repr(
    regression_data,
    model_config_class,
    continuous_cols,
    categorical_cols,
    custom_metrics,
    custom_loss,
    custom_optimizer,
):
    """Checks str(), repr() and _repr_html_() before and after fitting."""
    (train, test, target) = regression_data
    data_config = DataConfig(
        target=target,
        continuous_cols=continuous_cols,
        categorical_cols=categorical_cols,
    )
    config_class, base_params = model_config_class
    # Copy instead of writing into the dict owned by MODEL_CONFIG_SAVE_TEST,
    # which is shared by every parametrized case using that list.
    model_config_params = {**base_params, "task": "regression"}
    model_config = config_class(**model_config_params)
    trainer_config = TrainerConfig(
        max_epochs=3,
        checkpoints=None,
        early_stopping=None,
        accelerator="cpu",
        fast_dev_run=True,
    )
    optimizer_config = OptimizerConfig()

    tabular_model = TabularModel(
        data_config=data_config,
        model_config=model_config,
        optimizer_config=optimizer_config,
        trainer_config=trainer_config,
    )

    # Before fit: no model yet, so only the config can be shown.
    html_before = tabular_model._repr_html_()
    assert "Not Initialized" in str(tabular_model)
    assert "Not Initialized" in repr(tabular_model)
    assert "Model Summary" not in html_before
    assert "Model Config" in html_before
    assert "config" in tabular_model.__repr__()
    assert "config" not in str(tabular_model)

    tabular_model.fit(
        train=train,
        metrics=custom_metrics,
        metrics_prob_inputs=None if custom_metrics is None else [False],
        loss=custom_loss,
        optimizer=custom_optimizer,
        optimizer_params={},
    )

    # After fit: the model name and the summary section must appear.
    html_after = tabular_model._repr_html_()
    assert config_class._model_name in str(tabular_model)
    assert config_class._model_name in repr(tabular_model)
    assert "Model Summary" in html_after
    assert "Model Config" in html_after
    assert "config" in tabular_model.__repr__()
    assert config_class._model_name in html_after

0 commit comments

Comments
 (0)