Skip to content

Commit d86bfa6

Browse files
sorenmacbethSoren Macbeth
andauthored
update embedding_dims if new features are added (#358)
Co-authored-by: Soren Macbeth <soren@abracadaniel.local>
1 parent 5726f09 commit d86bfa6

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

src/pytorch_tabular/tabular_datamodule.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -359,6 +359,7 @@ def preprocess_data(self, data: DataFrame, stage: str = "inference") -> Tuple[Da
359359
self.config.categorical_dim = (
360360
len(self.config.categorical_cols) if self.config.categorical_cols is not None else 0
361361
)
362+
self._inferred_config = self._update_config(self.config)
362363
# Encoding Categorical Columns
363364
if len(self.config.categorical_cols) > 0:
364365
data = self._encode_categorical_columns(data, stage)

0 commit comments

Comments
 (0)