Commit 5726f09: bug fixes in bagging and tta (#356)
Parent: cd5ded9

File tree: 2 files changed, +13 additions, -18 deletions

src/pytorch_tabular/tabular_model.py
Lines changed: 11 additions & 16 deletions
@@ -1424,9 +1424,8 @@ def predict(
                 for regression. For classification, the previous options are applied to the confidence
                 scores (soft voting) and then converted to final prediction. An additional option
                 "hard_voting" is available for classification.
-                If callable, should be a function that takes in a list of 2D arrays (num_samples, num_targets)
-                and returns a 2D array (num_samples, num_targets). Defaults to "mean".
-
+                If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
+                and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".
 
         Returns:
             DataFrame: Returns a dataframe with predictions and features (if `include_input_features=True`).
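The updated docstring describes the new contract for a callable aggregate: it receives the collected per-run probability arrays and must itself return the final probabilities. A minimal standalone sketch of that contract with toy shapes (not code from this commit):

import numpy as np

# Toy stand-in for pred_prob_l: one (num_samples, num_classes) probability
# array per TTA run / CV fold, e.g. 3 runs over 4 samples and 2 classes.
pred_prob_l = [np.random.dirichlet([1.0, 1.0], size=4) for _ in range(3)]

# A callable aggregate gets the whole list and must return a single
# (num_samples, num_classes) array of final probabilities.
aggregate = lambda x: np.median(x, axis=0)

bagged_pred = aggregate(pred_prob_l)      # shape (4, 2)
final_class = bagged_pred.argmax(axis=1)  # one class index per sample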
@@ -1454,7 +1453,6 @@ def add_noise(module, input, output):
 
             # Register the hook to the embedding_layer
             handle = self.model.embedding_layer.register_forward_hook(add_noise)
-            pred_l = []
             pred_prob_l = []
             for _ in range(num_tta):
                 pred_df = self._predict(
@@ -1468,11 +1466,10 @@ def add_noise(module, input, output):
                 )
                 pred_idx = pred_df.index
                 if self.config.task == "classification":
-                    pred_l.append(pred_df.values[:, -len(self.config.target) :].astype(int))
                     pred_prob_l.append(pred_df.values[:, : -len(self.config.target)])
                 elif self.config.task == "regression":
                     pred_prob_l.append(pred_df.values)
-            pred_df = self._combine_predictions(pred_l, pred_prob_l, pred_idx, aggregate_tta, None)
+            pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate_tta, None)
             # Remove the hook
             handle.remove()
         else:
@@ -1993,7 +1990,6 @@ def cross_validate(
 
     def _combine_predictions(
         self,
-        pred_l: List[DataFrame],
         pred_prob_l: List[DataFrame],
         pred_idx: Union[pd.Index, List],
         aggregate: Union[str, Callable],
@@ -2008,15 +2004,16 @@ def _combine_predictions(
         elif aggregate == "max":
             bagged_pred = np.max(pred_prob_l, axis=0)
         elif aggregate == "hard_voting" and self.config.task == "classification":
+            pred_l = [np.argmax(p, axis=1) for p in pred_prob_l]
             final_pred = np.apply_along_axis(
                 lambda x: np.argmax(np.bincount(x)),
                 axis=0,
-                arr=[p[:, -1].astype(int) for p in pred_l],
+                arr=pred_l,
             )
         elif callable(aggregate):
-            final_pred = bagged_pred = aggregate(pred_prob_l)
+            bagged_pred = aggregate(pred_prob_l)
         if self.config.task == "classification":
-            if aggregate == "hard_voting" or callable(aggregate):
+            if aggregate == "hard_voting":
                 pred_df = pd.DataFrame(
                     np.concatenate(pred_prob_l, axis=1),
                     columns=[
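After this hunk, hard voting derives the per-fold class predictions from the probabilities themselves (the new pred_l comprehension) instead of a separately collected list. A toy sketch of the majority-vote logic it implements, using made-up probabilities:

import numpy as np

# Three folds of class probabilities for 4 samples and 3 classes (toy values).
pred_prob_l = [
    np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4], [0.5, 0.4, 0.1]]),
    np.array([[0.6, 0.3, 0.1], [0.2, 0.7, 0.1], [0.1, 0.6, 0.3], [0.2, 0.5, 0.3]]),
    np.array([[0.2, 0.5, 0.3], [0.1, 0.6, 0.3], [0.2, 0.2, 0.6], [0.6, 0.2, 0.2]]),
]

# Per-fold argmax, as in the new pred_l comprehension above.
pred_l = [np.argmax(p, axis=1) for p in pred_prob_l]

# Majority vote across folds for every sample (ties go to the lower label).
final_pred = np.apply_along_axis(lambda x: np.argmax(np.bincount(x)), axis=0, arr=pred_l)
print(final_pred)  # -> [0 1 2 0] for these toy values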
@@ -2094,8 +2091,8 @@ def bagging_predict(
                 for regression. For classification, the previous options are applied to the confidence
                 scores (soft voting) and then converted to final prediction. An additional option
                 "hard_voting" is available for classification.
-                If callable, should be a function that takes in a list of 2D arrays (num_samples, num_targets)
-                and returns a 2D array (num_samples, num_targets). Defaults to "mean".
+                If callable, should be a function that takes in a list of 3D arrays (num_samples, num_cv, num_targets)
+                and returns a 2D array of final probabilities (num_samples, num_targets). Defaults to "mean".
 
             weights (Optional[List[float]], optional): The weights to be used for aggregating the predictions
                 from each fold. If None, will use equal weights. This is only used when `aggregate` is "mean".
@@ -2122,7 +2119,6 @@ def bagging_predict(
             assert aggregate != "hard_voting", "hard_voting is only available for classification"
         cv = self._check_cv(cv)
         prep_dl_kwargs, prep_model_kwargs, train_kwargs = self._split_kwargs(kwargs)
-        pred_l = []
         pred_prob_l = []
         datamodule = None
         model = None
@@ -2149,15 +2145,14 @@ def bagging_predict(
             fold_preds = self.predict(test, include_input_features=False)
             pred_idx = fold_preds.index
             if self.config.task == "classification":
-                pred_l.append(fold_preds.values[:, -len(self.config.target) :].astype(int))
                 pred_prob_l.append(fold_preds.values[:, : -len(self.config.target)])
             elif self.config.task == "regression":
                 pred_prob_l.append(fold_preds.values)
             if verbose:
                 logger.info(f"Fold {fold+1}/{cv.get_n_splits()} prediction done")
             self.model.reset_weights()
-        pred_df = self._combine_predictions(pred_l, pred_prob_l, pred_idx, aggregate, weights)
+        pred_df = self._combine_predictions(pred_prob_l, pred_idx, aggregate, weights)
         if return_raw_predictions:
-            return pred_df, pred_l, pred_prob_l
+            return pred_df, pred_prob_l
         else:
            return pred_df
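With this change, return_raw_predictions=True yields only (pred_df, pred_prob_l). Callers that previously consumed the dropped pred_l can rebuild per-fold class indices from the raw probabilities; a hedged sketch for a single-target classification task (tabular_model is a placeholder name and the bagging_predict call itself is elided):

import numpy as np

# pred_df, pred_prob_l = tabular_model.bagging_predict(..., return_raw_predictions=True)
# Each element of pred_prob_l is one fold's (num_samples, num_classes) array.
pred_l = [np.argmax(fold_probs, axis=1) for fold_probs in pred_prob_l]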

tests/test_common.py
Lines changed: 2 additions & 2 deletions
@@ -916,7 +916,7 @@ def _run_bagging(
 @pytest.mark.parametrize("cv", [2])
 @pytest.mark.parametrize(
     "aggregate",
-    ["mean", "median", "min", "max", "hard_voting", lambda x: np.argmax(np.median(x, axis=0), axis=1)],
+    ["mean", "median", "min", "max", "hard_voting", lambda x: np.median(x, axis=0)],
 )
 def test_bagging_classification(
     classification_data,
@@ -1040,7 +1040,7 @@ def _run_tta(
 @pytest.mark.parametrize("categorical_cols", [["feature_0_cat"]])
 @pytest.mark.parametrize(
     "aggregate",
-    ["mean", "median", "min", "max", "hard_voting", lambda x: np.argmax(np.median(x, axis=0), axis=1)],
+    ["mean", "median", "min", "max", "hard_voting", lambda x: np.median(x, axis=0)],
 )
 def test_tta_classification(
     classification_data,
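The parametrizations above replace the old argmax-returning lambdas with lambda x: np.median(x, axis=0), matching the new callable contract. A hypothetical call mirroring that setup; tabular_model, train, and test are placeholder names, and the exact bagging_predict parameter list beyond aggregate, cv, and weights (which appear in this diff) is an assumption:

import numpy as np

# Soft-voting bagging with a custom callable aggregate: the callable returns
# median probabilities across folds; the library converts them to final classes.
pred_df = tabular_model.bagging_predict(
    cv=2,
    train=train,
    test=test,
    aggregate=lambda x: np.median(x, axis=0),
)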
