Skip to content

Commit cd5ded9

Browse files
Test Time Augmentation and other bug fixes (#355)
* refactored predict * added tta and test cases * made tuning wihtout cv more rbust * fixed bugs in test tuner * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3b8bbc7 commit cd5ded9

File tree

4 files changed

+366
-74
lines changed

4 files changed

+366
-74
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def __init__(
5757

5858
self.task = task
5959
self.n = data.shape[0]
60-
60+
self.target = target
6161
if target:
6262
self.y = data[target].astype(np.float32).values
6363
if isinstance(target, str):
@@ -81,16 +81,19 @@ def __init__(
8181
def data(self):
8282
"""Returns the data as a pandas dataframe."""
8383
if self.continuous_cols and self.categorical_cols:
84-
return pd.DataFrame(
84+
data = pd.DataFrame(
8585
np.concatenate([self.categorical_X, self.continuous_X], axis=1),
8686
columns=self.categorical_cols + self.continuous_cols,
8787
)
8888
elif self.continuous_cols:
89-
return pd.DataFrame(self.continuous_X, columns=self.continuous_cols)
89+
data = pd.DataFrame(self.continuous_X, columns=self.continuous_cols)
9090
elif self.categorical_cols:
91-
return pd.DataFrame(self.categorical_X, columns=self.categorical_cols)
91+
data = pd.DataFrame(self.categorical_X, columns=self.categorical_cols)
9292
else:
93-
return pd.DataFrame()
93+
data = pd.DataFrame()
94+
for i, t in enumerate(self.target):
95+
data[t] = self.y[:, i]
96+
return data
9497

9598
def __len__(self):
9699
"""Denotes the total number of samples."""

src/pytorch_tabular/tabular_model.py

Lines changed: 213 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1201,65 +1201,19 @@ def evaluate(
12011201
)
12021202
return result
12031203

1204-
def predict(
1204+
def _generate_predictions(
12051205
self,
1206-
test: DataFrame,
1207-
quantiles: Optional[List] = [0.25, 0.5, 0.75],
1208-
n_samples: Optional[int] = 100,
1209-
ret_logits=False,
1210-
include_input_features: bool = False,
1211-
device: Optional[torch.device] = None,
1212-
progress_bar: Optional[str] = None,
1213-
) -> DataFrame:
1214-
"""Uses the trained model to predict on new data and return as a dataframe.
1215-
1216-
Args:
1217-
test (DataFrame): The new dataframe with the features defined during training
1218-
quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
1219-
the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
1220-
For other models it is ignored. Defaults to [0.25, 0.5, 0.75]
1221-
n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
1222-
Ignored for non-probabilistic models. Defaults to 100
1223-
ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
1224-
with the dataframe. Defaults to False
1225-
include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
1226-
Defaults to True
1227-
progress_bar: chose progress bar for tracking the progress
1228-
1229-
Returns:
1230-
DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
1231-
If classification, it returns probabilities and final prediction
1232-
"""
1233-
warnings.warn(
1234-
"`include_input_features` will be deprecated in the next release."
1235-
" Please add index columns to the test dataframe if you want to"
1236-
" retain some features like the key or id",
1237-
DeprecationWarning,
1238-
)
1239-
assert all(q <= 1 and q >= 0 for q in quantiles), "Quantiles should be a decimal between 0 and 1"
1240-
model = self.model # default
1241-
if device is not None:
1242-
if isinstance(device, str):
1243-
device = torch.device(device)
1244-
if self.model.device != device:
1245-
model = self.model.to(device)
1246-
model.eval()
1247-
inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
1206+
model,
1207+
inference_dataloader,
1208+
quantiles,
1209+
n_samples,
1210+
ret_logits,
1211+
progress_bar,
1212+
is_probabilistic,
1213+
):
12481214
point_predictions = []
12491215
quantile_predictions = []
12501216
logits_predictions = defaultdict(list)
1251-
is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic
1252-
1253-
if progress_bar == "rich":
1254-
from rich.progress import track
1255-
1256-
progress_bar = partial(track, description="Generating Predictions...")
1257-
elif progress_bar == "tqdm":
1258-
from tqdm.auto import tqdm
1259-
1260-
progress_bar = partial(tqdm, description="Generating Predictions...")
1261-
else:
1262-
progress_bar = lambda it: it # noqa E731
12631217
for batch in progress_bar(inference_dataloader):
12641218
for k, v in batch.items():
12651219
if isinstance(v, list) and (len(v) == 0):
@@ -1275,8 +1229,6 @@ def predict(
12751229
y_hat, ret_value = model.predict(batch, ret_model_output=True)
12761230
if ret_logits:
12771231
for k, v in ret_value.items():
1278-
# if k == "backbone_features":
1279-
# continue
12801232
logits_predictions[k].append(v.detach().cpu())
12811233
point_predictions.append(y_hat.detach().cpu())
12821234
if is_probabilistic:
@@ -1288,6 +1240,19 @@ def predict(
12881240
quantile_predictions = torch.cat(quantile_predictions, dim=0).unsqueeze(-1)
12891241
if quantile_predictions.ndim == 2:
12901242
quantile_predictions = quantile_predictions.unsqueeze(-1)
1243+
return point_predictions, quantile_predictions, logits_predictions
1244+
1245+
def _format_predicitons(
1246+
self,
1247+
test,
1248+
point_predictions,
1249+
quantile_predictions,
1250+
logits_predictions,
1251+
quantiles,
1252+
ret_logits,
1253+
include_input_features,
1254+
is_probabilistic,
1255+
):
12911256
pred_df = test.copy() if include_input_features else DataFrame(index=test.index)
12921257
if self.config.task == "regression":
12931258
point_predictions = point_predictions.numpy()
@@ -1340,6 +1305,188 @@ def predict(
13401305
pred_df[f"{k}"] = v[:, i]
13411306
return pred_df
13421307

1308+
def _predict(
1309+
self,
1310+
test: DataFrame,
1311+
quantiles: Optional[List] = [0.25, 0.5, 0.75],
1312+
n_samples: Optional[int] = 100,
1313+
ret_logits=False,
1314+
include_input_features: bool = False,
1315+
device: Optional[torch.device] = None,
1316+
progress_bar: Optional[str] = None,
1317+
) -> DataFrame:
1318+
"""Uses the trained model to predict on new data and return as a dataframe.
1319+
1320+
Args:
1321+
test (DataFrame): The new dataframe with the features defined during training
1322+
quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
1323+
the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
1324+
For other models it is ignored. Defaults to [0.25, 0.5, 0.75]
1325+
n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
1326+
Ignored for non-probabilistic models. Defaults to 100
1327+
ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
1328+
with the dataframe. Defaults to False
1329+
include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
1330+
Defaults to True
1331+
progress_bar: chose progress bar for tracking the progress
1332+
1333+
Returns:
1334+
DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
1335+
If classification, it returns probabilities and final prediction
1336+
"""
1337+
assert all(q <= 1 and q >= 0 for q in quantiles), "Quantiles should be a decimal between 0 and 1"
1338+
model = self.model # default
1339+
if device is not None:
1340+
if isinstance(device, str):
1341+
device = torch.device(device)
1342+
if self.model.device != device:
1343+
model = self.model.to(device)
1344+
model.eval()
1345+
inference_dataloader = self.datamodule.prepare_inference_dataloader(test)
1346+
is_probabilistic = hasattr(model.hparams, "_probabilistic") and model.hparams._probabilistic
1347+
1348+
if progress_bar == "rich":
1349+
from rich.progress import track
1350+
1351+
progress_bar = partial(track, description="Generating Predictions...")
1352+
elif progress_bar == "tqdm":
1353+
from tqdm.auto import tqdm
1354+
1355+
progress_bar = partial(tqdm, description="Generating Predictions...")
1356+
else:
1357+
progress_bar = lambda it: it # noqa E731
1358+
point_predictions, quantile_predictions, logits_predictions = self._generate_predictions(
1359+
model,
1360+
inference_dataloader,
1361+
quantiles,
1362+
n_samples,
1363+
ret_logits,
1364+
progress_bar,
1365+
is_probabilistic,
1366+
)
1367+
pred_df = self._format_predicitons(
1368+
test,
1369+
point_predictions,
1370+
quantile_predictions,
1371+
logits_predictions,
1372+
quantiles,
1373+
ret_logits,
1374+
include_input_features,
1375+
is_probabilistic,
1376+
)
1377+
return pred_df
1378+
1379+
def predict(
1380+
self,
1381+
test: DataFrame,
1382+
quantiles: Optional[List] = [0.25, 0.5, 0.75],
1383+
n_samples: Optional[int] = 100,
1384+
ret_logits=False,
1385+
include_input_features: bool = False,
1386+
device: Optional[torch.device] = None,
1387+
progress_bar: Optional[str] = None,
1388+
test_time_augmentation: Optional[bool] = False,
1389+
num_tta: Optional[float] = 5,
1390+
alpha_tta: Optional[float] = 0.1,
1391+
aggregate_tta: Optional[str] = "mean",
1392+
) -> DataFrame:
1393+
"""Uses the trained model to predict on new data and return as a dataframe.
1394+
1395+
Args:
1396+
test (DataFrame): The new dataframe with the features defined during training
1397+
1398+
quantiles (Optional[List]): For probabilistic models like Mixture Density Networks, this specifies
1399+
the different quantiles to be extracted apart from the `central_tendency` and added to the dataframe.
1400+
For other models it is ignored. Defaults to [0.25, 0.5, 0.75]
1401+
1402+
n_samples (Optional[int]): Number of samples to draw from the posterior to estimate the quantiles.
1403+
Ignored for non-probabilistic models. Defaults to 100
1404+
1405+
ret_logits (bool): Flag to return raw model outputs/logits except the backbone features along
1406+
with the dataframe. Defaults to False
1407+
1408+
include_input_features (bool): DEPRECATED: Flag to include the input features in the returned dataframe.
1409+
Defaults to True
1410+
1411+
progress_bar: chose progress bar for tracking the progress
1412+
1413+
test_time_augmentation (bool): If True, will use test time augmentation to generate predictions.
1414+
The approach is very similar to what is described [here](https://kozodoi.me/blog/20210908/tta-tabular)
1415+
But, we add noise to the embedded inputs to handle categorical features as well.\
1416+
\\(x_{aug} = x_{orig} + \alpha * \\epsilon\\) where \\(\\epsilon \\sim \\mathcal{N}(0, 1)\\)
1417+
Defaults to False
1418+
num_tta (float): The number of augumentations to run TTA for. Defaults to 0.0
1419+
1420+
alpha_tta (float): The standard deviation of the gaussian noise to be added to the input features
1421+
1422+
aggregate_tta (Union[str, Callable], optional): The function to be used to aggregate the
1423+
predictions from each augumentation. If str, should be one of "mean", "median", "min", or "max"
1424+
for regression. For classification, the previous options are applied to the confidence
1425+
scores (soft voting) and then converted to final prediction. An additional option
1426+
"hard_voting" is available for classification.
1427+
If callable, should be a function that takes in a list of 2D arrays (num_samples, num_targets)
1428+
and returns a 2D array (num_samples, num_targets). Defaults to "mean".
1429+
1430+
1431+
Returns:
1432+
DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
1433+
If classification, it returns probabilities and final prediction
1434+
"""
1435+
warnings.warn(
1436+
"`include_input_features` will be deprecated in the next release."
1437+
" Please add index columns to the test dataframe if you want to"
1438+
" retain some features like the key or id",
1439+
DeprecationWarning,
1440+
)
1441+
if test_time_augmentation:
1442+
assert num_tta > 0, "num_tta should be greater than 0"
1443+
assert alpha_tta > 0, "alpha_tta should be greater than 0"
1444+
assert include_input_features is False, "include_input_features cannot be True for TTA."
1445+
if not callable(aggregate_tta):
1446+
assert aggregate_tta in ["mean", "median", "min", "max", "hard_voting"], (
1447+
"aggregate should be one of 'mean', 'median', 'min', 'max', or" " 'hard_voting'"
1448+
)
1449+
if self.config.task == "regression":
1450+
assert aggregate_tta != "hard_voting", "hard_voting is only available for classification"
1451+
1452+
def add_noise(module, input, output):
1453+
return output + alpha_tta * torch.randn_like(output)
1454+
1455+
# Register the hook to the embedding_layer
1456+
handle = self.model.embedding_layer.register_forward_hook(add_noise)
1457+
pred_l = []
1458+
pred_prob_l = []
1459+
for _ in range(num_tta):
1460+
pred_df = self._predict(
1461+
test,
1462+
quantiles,
1463+
n_samples,
1464+
ret_logits,
1465+
include_input_features=False,
1466+
device=device,
1467+
progress_bar=progress_bar,
1468+
)
1469+
pred_idx = pred_df.index
1470+
if self.config.task == "classification":
1471+
pred_l.append(pred_df.values[:, -len(self.config.target) :].astype(int))
1472+
pred_prob_l.append(pred_df.values[:, : -len(self.config.target)])
1473+
elif self.config.task == "regression":
1474+
pred_prob_l.append(pred_df.values)
1475+
pred_df = self._combine_predictions(pred_l, pred_prob_l, pred_idx, aggregate_tta, None)
1476+
# Remove the hook
1477+
handle.remove()
1478+
else:
1479+
pred_df = self._predict(
1480+
test,
1481+
quantiles,
1482+
n_samples,
1483+
ret_logits,
1484+
include_input_features,
1485+
device,
1486+
progress_bar,
1487+
)
1488+
return pred_df
1489+
13431490
def load_best_model(self) -> None:
13441491
"""Loads the best model after training is done."""
13451492
if self.trainer.checkpoint_callback is not None:
@@ -1708,7 +1855,8 @@ def _check_cv(self, cv):
17081855
return StratifiedKFold(cv)
17091856
else:
17101857
return KFold(cv)
1711-
elif isinstance(cv, Iterable):
1858+
elif isinstance(cv, Iterable) and not isinstance(cv, str):
1859+
# An iterable yielding (train, test) splits as arrays of indices.
17121860
return cv
17131861
elif isinstance(cv, BaseCrossValidator):
17141862
return cv
@@ -1800,11 +1948,17 @@ def cross_validate(
18001948
metric = metric if metric.startswith("test_") else "test_" + metric
18011949
elif callable(metric):
18021950
is_callable_metric = True
1951+
1952+
if isinstance(cv, BaseCrossValidator):
1953+
it = enumerate(cv.split(train, y=train[self.config.target], groups=groups))
1954+
else:
1955+
# when iterable is directly passed
1956+
it = enumerate(cv)
18031957
cv_metrics = []
18041958
datamodule = None
18051959
model = None
18061960
oof_preds = []
1807-
for fold, (train_idx, val_idx) in enumerate(cv.split(train, y=train[self.config.target], groups=groups)):
1961+
for fold, (train_idx, val_idx) in it:
18081962
if verbose:
18091963
logger.info(f"Running Fold {fold+1}/{cv.get_n_splits()}")
18101964
train_fold = train.iloc[train_idx]

0 commit comments

Comments
 (0)