Skip to content

Commit 428cc23

Browse files
committed
-- bug fixed for encode_date columns
1 parent 132c3f2 commit 428cc23

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

pytorch_tabular/models/base_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def calculate_loss(self, y, y_hat, tag):
103103
)
104104
computed_loss = torch.stack(losses, dim=0).sum()
105105
else:
106+
#TODO loss fails with batch size of 1
106107
computed_loss = self.loss(y_hat.squeeze(), y.squeeze())
107108
self.log(
108109
f"{tag}_loss",

pytorch_tabular/tabular_datamodule.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,7 @@ def add_datepart(
433433
)
434434
df.insert(3, prefix + "Week", week)
435435
added_features.append(prefix + "Week")
436-
# Not adding Elapsed by default. Need to route it through config
436+
#TODO Not adding Elapsed by default. Need to route it through config
437437
# mask = ~field.isna()
438438
# df[prefix + "Elapsed"] = np.where(
439439
# mask, field.values.astype(np.int64) // 10 ** 9, None
@@ -443,10 +443,10 @@ def add_datepart(
443443
df.drop(field_name, axis=1, inplace=True)
444444

445445
# Removing features woth zero variations
446-
for col in added_features:
447-
if len(df[col].unique()) == 1:
448-
df.drop(columns=col, inplace=True)
449-
added_features.remove(col)
446+
# for col in added_features:
447+
# if len(df[col].unique()) == 1:
448+
# df.drop(columns=col, inplace=True)
449+
# added_features.remove(col)
450450
return df, added_features
451451

452452
def train_dataloader(self, batch_size: Optional[int] = None) -> DataLoader:

0 commit comments

Comments
 (0)