
Commit 2998ac2

committed
-- refactored data_aware_initialization
-- added docs
1 parent fab881a commit 2998ac2

File tree

4 files changed: 27 additions & 33 deletions


docs/models.md

Lines changed: 2 additions & 0 deletions
@@ -128,6 +128,8 @@ There are two methods that need to be defined in any class that inherits the Bas
 
 While this is the bare minimum, you can redefine or use any of the PyTorch Lightning standard methods to tweak your model and training to your liking.
 
+If your model needs to use custom data-aware initialization techniques (like NODE), you can override `data_aware_initialization(self, datamodule)` in the model. There you have access to the datamodule and the dataloaders for initialization.
+
 In addition to the model, you will also need to define a config. Configs are Python dataclasses, should inherit `ModelConfig`, and will have all the parameters of `ModelConfig` by default. Any additional parameter should be defined in the dataclass.
 
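For example, a custom model could use this hook to set its output bias from the training targets. A minimal sketch, assuming a hypothetical `MyRegressionModel` (its `_build_network` and `forward` are elided) and assuming batches carry a `target` key, as in the NODE code further down:

import torch

from pytorch_tabular.models.base_model import BaseModel


class MyRegressionModel(BaseModel):
    # _build_network, forward, etc. omitted for brevity

    def data_aware_initialization(self, datamodule):
        # Pull one large batch from the datamodule, mirroring what NODE does.
        batch = next(iter(datamodule.train_dataloader(batch_size=2000)))
        with torch.no_grad():
            # Hypothetical `head` layer: start its bias at the mean target so
            # the model begins from a sensible baseline prediction.
            self.head.bias.fill_(batch["target"].float().mean().item())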

pytorch_tabular/models/base_model.py

Lines changed: 3 additions & 0 deletions
@@ -158,6 +158,9 @@ def calculate_metrics(self, y, y_hat, tag):
             prog_bar=True,
         )
         return metrics
+
+    def data_aware_initialization(self, datamodule):
+        pass
 
     @abstractmethod
     def forward(self, x: Dict):

pytorch_tabular/models/node/node_model.py

Lines changed: 17 additions & 0 deletions
@@ -64,6 +64,23 @@ def __init__(self, config: DictConfig, **kwargs):
     def subset(self, x):
         return x[..., : self.hparams.output_dim].mean(dim=-2)
 
+    def data_aware_initialization(self, datamodule):
+        """Performs data-aware initialization for NODE"""
+        logger.info("Data Aware Initialization....")
+        # Need a big batch to initialize properly
+        alt_loader = datamodule.train_dataloader(batch_size=2000)
+        batch = next(iter(alt_loader))
+        for k, v in batch.items():
+            if isinstance(v, list) and (len(v) == 0):
+                # Skipping empty list
+                continue
+            # batch[k] = v.to("cpu" if self.config.gpu == 0 else "cuda")
+            batch[k] = v.to(self.device)
+
+        # single forward pass to initialize the ODST
+        with torch.no_grad():
+            self(batch)
+
     def _build_network(self):
         if self.hparams.embed_categorical:
             self.embedding_layers = nn.ModuleList(
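The big batch matters here because the ODST layers inside NODE derive their initial split thresholds from the statistics of the first batch they see, and a small batch would give noisy estimates. A schematic sketch of that idea, not the actual ODST code:

import torch


def init_split_thresholds(x: torch.Tensor, n_splits: int) -> torch.Tensor:
    """Schematic only: choose split thresholds as random empirical quantiles
    of each feature, computed over the (large) initialization batch x."""
    q = torch.rand(n_splits)            # random quantile levels in [0, 1)
    return torch.quantile(x, q, dim=0)  # shape: (n_splits, n_features)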

pytorch_tabular/tabular_model.py

Lines changed: 5 additions & 33 deletions
@@ -79,16 +79,6 @@ def __init__(
         ), "If `config` is None, `data_config`, `model_config`, `trainer_config`, and `optimizer_config` cannot be None"
         data_config = self._read_parse_config(data_config, DataConfig)
         model_config = self._read_parse_config(model_config, ModelConfig)
-        # # Re-routing to Categorical embedding Model if embed_categorical is true for NODE
-        # if (
-        #     hasattr(model_config, "_model_name")
-        #     and (model_config._model_name == "NODEModel")
-        #     and (model_config.embed_categorical)
-        #     and ("CategoryEmbedding" not in model_config._model_name)
-        # ):
-        #     model_config._model_name = (
-        #         "CategoryEmbedding" + model_config._model_name
-        #     )
         trainer_config = self._read_parse_config(trainer_config, TrainerConfig)
         optimizer_config = self._read_parse_config(
             optimizer_config, OptimizerConfig

@@ -255,23 +245,6 @@ def _prepare_callbacks(self) -> List:
         logger.debug(f"Callbacks used: {callbacks}")
         return callbacks
 
-    def data_aware_initialization(self):
-        """Performs data-aware initialization for NODE"""
-        logger.info("Data Aware Initialization....")
-        # Need a big batch to initialize properly
-        alt_loader = self.datamodule.train_dataloader(batch_size=2000)
-        batch = next(iter(alt_loader))
-        for k, v in batch.items():
-            if isinstance(v, list) and (len(v) == 0):
-                # Skipping empty list
-                continue
-            # batch[k] = v.to("cpu" if self.config.gpu == 0 else "cuda")
-            batch[k] = v.to(self.model.device)
-
-        # single forward pass to initialize the ODST
-        with torch.no_grad():
-            self.model(batch)
-
     def _prepare_dataloader(
         self, train, validation, test, target_transform=None, train_sampler=None
     ):

@@ -312,9 +285,9 @@ def _prepare_model(self, loss, metrics, optimizer, optimizer_params, reset):
             custom_optimizer=optimizer,
             custom_optimizer_params=optimizer_params,
         )
-        # Data Aware Initialization (NODE)
-        if self.config._model_name in ["NODEModel"]:
-            self.data_aware_initialization()
+        # Data Aware Initialization (for the models that need it)
+        self.model.data_aware_initialization(self.datamodule)
+
 
     def _prepare_trainer(self, max_epochs=None, min_epochs=None):
         logger.info("Preparing the Trainer...")

@@ -459,9 +432,8 @@ def fit(
         self.model.train()
         if self.config.auto_lr_find and (not self.config.fast_dev_run):
             self.trainer.tune(self.model, train_loader, val_loader)
-            # Parameters in NODE need to be initialized again
-            if self.config._model_name in ["CategoryEmbeddingNODEModel", "NODEModel"]:
-                self.data_aware_initialization()
+            # Parameters in models need to be initialized again after LR find
+            self.model.data_aware_initialization(self.datamodule)
         self.model.train()
         self.trainer.fit(self.model, train_loader, val_loader)
         logger.info("Training the model completed...")
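The net effect of the refactor, sketched below with illustrative names: dispatch moves from model-name string checks in `TabularModel` to plain inheritance, so a model that needs data-aware initialization overrides the hook, and every other model inherits the no-op from `BaseModel`.

class BaseModel:
    def data_aware_initialization(self, datamodule):
        pass  # default: most models need no data-aware initialization


class NODEModel(BaseModel):
    def data_aware_initialization(self, datamodule):
        ...  # big-batch forward pass, as in node_model.py above


# TabularModel can now invoke the hook unconditionally, instead of checking
# `self.config._model_name in ["NODEModel", ...]` before each call:
model.data_aware_initialization(datamodule)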
