
Commit 77cead1

--added new trainer arguments --upgraded version of pytorch lightning --switched metrics to torchmetrics
1 parent 5c7738f commit 77cead1

File tree

4 files changed: +65 −20 lines changed


pytorch_tabular/config/config.py

Lines changed: 50 additions & 9 deletions
@@ -3,7 +3,7 @@
 # For license information, see LICENSE.TXT
 """Config"""
 from dataclasses import MISSING, dataclass, field
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union
 import os
 from omegaconf import OmegaConf
 
@@ -168,18 +168,31 @@ class TrainerConfig:
 
         max_epochs (int): Maximum number of epochs to be run
 
-        min_epochs (int): Minimum number of epochs to be run
+        min_epochs (int): Force training for at least these many epochs. 1 by default
 
-        gpus (int): The index of the GPU to be used. If `None`, will use CPU
+        max_time (Optional[int]): Stop training after this amount of time has passed. Disabled by default (None)
+
+        gpus (int): Number of gpus to train on (int) or which GPUs to train on (list or str). -1 uses all available GPUs. By default uses CPU (None)
 
         accumulate_grad_batches (int): Accumulates grads every k batches or as set up in the dict.
             Trainer also calls optimizer.step() for the last indivisible step number.
 
         auto_lr_find (bool): Runs a learning rate finder algorithm (see this paper) when calling trainer.tune(),
             to find optimal initial learning rate.
 
+        auto_select_gpus (bool): If enabled and `gpus` is an integer, pick available gpus automatically.
+            This is especially useful when GPUs are configured to be in 'exclusive mode', such that only one
+            process at a time can access them.
+
         check_val_every_n_epoch (int): Check val every n train epochs.
 
+        deterministic (bool): If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility.
+
+        accelerator (str): The accelerator backend to use. Defaults to None. Check this link for detailed documentation about the functionality.
+            https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#accelerator
+
+        tpu_cores (int): How many TPU cores to train on (1 or 8) / Single TPU to train on [1]. Defaults to None
+
         gradient_clip_val (float): Gradient clipping value
 
         overfit_batches (float): Uses this much data of the training set. If nonzero, will use the same training set
@@ -219,17 +232,20 @@ class TrainerConfig:
         default=64, metadata={"help": "Number of samples in each batch of training"}
     )
     fast_dev_run: bool = field(
-        default=False, metadata={"help": "Quick Debug Run of Val"}
+        default=False, metadata={"help": "runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es) of train, val and test to find any bugs (ie: a sort of unit test)."}
     )
     max_epochs: int = field(
         default=10, metadata={"help": "Maximum number of epochs to be run"}
     )
-    min_epochs: int = field(
-        default=1, metadata={"help": "Minimum number of epochs to be run"}
+    min_epochs: Optional[int] = field(
+        default=1, metadata={"help": "Force training for at least these many epochs. 1 by default"}
+    )
+    max_time: Optional[int] = field(
+        default=None, metadata={"help": "Stop training after this amount of time has passed. Disabled by default (None)"}
     )
-    gpus: Optional[int] = field(
+    gpus: Union[int, list] = field(
         default=None,
-        metadata={"help": "The index of the GPU to be used. If None, will use CPU"},
+        metadata={"help": "Number of gpus to train on (int) or which GPUs to train on (list or str). -1 uses all available GPUs. By default uses CPU (None)"},
     )
     accumulate_grad_batches: int = field(
         default=1,
@@ -243,6 +259,12 @@ class TrainerConfig:
             "help": "Runs a learning rate finder algorithm (see this paper) when calling trainer.tune(), to find optimal initial learning rate."
         },
     )
+    auto_select_gpus: bool = field(
+        default=True,
+        metadata={
+            "help": "If enabled and `gpus` is an integer, pick available gpus automatically. This is especially useful when GPUs are configured to be in 'exclusive mode', such that only one process at a time can access them."
+        },
+    )
     check_val_every_n_epoch: int = field(
         default=1, metadata={"help": "Check val every n train epochs."}
     )
@@ -255,6 +277,25 @@ class TrainerConfig:
             "help": "Uses this much data of the training set. If nonzero, will use the same training set for validation and testing. If the training dataloaders have shuffle=True, Lightning will automatically disable it. Useful for quickly debugging or trying to overfit on purpose."
         },
     )
+    deterministic: bool = field(
+        default=False,
+        metadata={
+            "help": "If true enables cudnn.deterministic. Might make your system slower, but ensures reproducibility."
+        },
+    )
+    accelerator: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "The accelerator backend to use. Defaults to None. Check this link for detailed documentation about the functionality. https://pytorch-lightning.readthedocs.io/en/latest/common/trainer.html#accelerator",
+            "choices": [None, "dp", "ddp", "ddp_cpu", "ddp2"],
+        },
+    )
+    tpu_cores: Optional[Union[List[int], str, int]] = field(
+        default=None,
+        metadata={
+            "help": "How many TPU cores to train on (1 or 8) / Single TPU to train on [1]. Defaults to None",
+        },
+    )
     profiler: Optional[str] = field(
         default=None,
         metadata={
@@ -530,7 +571,7 @@ class ModelConfig:
     metrics: Optional[List[str]] = field(
         default=None,
         metadata={
-            "help": "the list of metrics you need to track during training. The metrics should be one of the functional metrics implemented in PyTorch Lightning. By default, it is accuracy if classification and MeanSquaredLogError for regression"
+            "help": "the list of metrics you need to track during training. The metrics should be one of the functional metrics implemented in ``torchmetrics``. By default, it is accuracy if classification and mean_squared_error for regression"
         },
     )
     metrics_params: Optional[List] = field(
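
The new trainer arguments land on TrainerConfig, so they can be set directly when building a config. A minimal sketch (the field names come from the diff above; the import path and values are illustrative, not part of this commit):

from pytorch_tabular.config import TrainerConfig

trainer_config = TrainerConfig(
    max_epochs=50,
    min_epochs=5,           # force at least 5 epochs of training
    max_time=None,          # no wall-clock limit (the default)
    gpus=-1,                # -1 = all available GPUs; None falls back to CPU
    auto_select_gpus=True,  # pick free GPUs automatically when `gpus` is an int
    deterministic=True,     # enable cudnn.deterministic for reproducibility
    accelerator="ddp",      # one of the documented choices: dp/ddp/ddp_cpu/ddp2
    tpu_cores=None,
)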

pytorch_tabular/models/base_model.py

Lines changed: 4 additions & 3 deletions
@@ -8,6 +8,7 @@
 
 import pytorch_lightning as pl
 import torch
+import torchmetrics
 import torch.nn as nn
 from omegaconf import DictConfig
 
@@ -73,13 +74,13 @@ def _setup_loss(self):
     def _setup_metrics(self):
         if self.custom_metrics is None:
             self.metrics = []
-            task_module = pl.metrics.functional
+            task_module = torchmetrics.functional
             for metric in self.hparams.metrics:
                 try:
                     self.metrics.append(getattr(task_module, metric))
                 except AttributeError as e:
                     logger.error(
-                        f"{metric} is not a valid functional metric defined in the pytorch_lightning.metrics.functional module"
+                        f"{metric} is not a valid functional metric defined in the torchmetrics.functional module"
                     )
                     raise e
         else:
@@ -124,7 +125,7 @@ def calculate_metrics(self, y, y_hat, tag):
             for i in range(self.hparams.output_dim):
                 if (
                     metric.__name__
-                    == pl.metrics.functional.mean_squared_log_error.__name__
+                    == torchmetrics.functional.mean_squared_log_error.__name__
                 ):
                     # MSLE should only be used in strictly positive targets. It is undefined otherwise
                     _metric = metric(
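
With this change, each metric name from the config is resolved with getattr against torchmetrics.functional instead of pl.metrics.functional. A small sketch of that lookup (the metric name and tensors are illustrative):

import torch
import torchmetrics

# Resolve a functional metric by name, as _setup_metrics does above.
metric_fn = getattr(torchmetrics.functional, "mean_squared_error")

target = torch.tensor([1.0, 2.0, 3.0])
preds = torch.tensor([1.1, 1.9, 3.2])
print(metric_fn(preds, target))  # tensor(0.0200)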

requirements.txt

Lines changed: 10 additions & 7 deletions
@@ -1,13 +1,16 @@
-torch>=1.3
+torch>=1.4
 category-encoders==2.2.2
-numpy>=1.16.6
+numpy>=1.17.2
 pandas==1.1.5
 scikit-learn==0.23.2
-pytorch-lightning==1.0.8 #works well with wandb
-omegaconf==2.0.5
-tensorboard>=2.2.0
+pytorch-lightning==1.3.6
+omegaconf>=2.0.1
+torchmetrics>=0.3.2
+tensorboard>=2.2.0, !=2.5.0
 pytorch-tabnet==3.0.0
-PyYAML>=5.1 # OmegaConf requirement >=5.1
+PyYAML>=5.1.* # OmegaConf requirement >=5.1
 # importlib-metadata <1,>=0.12
+matplotlib>3.1
 ipywidgets
-matplotlib
+# Use dataclasses backport for Python 3.6.
+dataclasses;python_version=='3.6'
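
A quick way to sanity-check an environment against the upgraded pins (a sketch; the expected versions are read off the requirements above):

import torch
import pytorch_lightning
import torchmetrics

print(torch.__version__)              # needs >= 1.4
print(pytorch_lightning.__version__)  # pinned to 1.3.6
print(torchmetrics.__version__)       # needs >= 0.3.2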

requirements_testing.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-pip==19.2.3
+pip==20.3.1
 bump2version==0.5.11
 wheel==0.33.6
 watchdog==0.9.0
