Commit 1ccfcc2

committed
-- fixed autoint embedding issue
1 parent 428cc23 commit 1ccfcc2


pytorch_tabular/models/autoint/autoint.py

Lines changed: 13 additions & 12 deletions
@@ -20,27 +20,28 @@
 #TODO dont use embedding_dims
 class AutoIntBackbone(pl.LightningModule):
     def __init__(self, config: DictConfig):
-        self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
+        # self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
         # self.hparams = config
         super().__init__()
         self.save_hyperparameters(config)
         self._build_network()

     def _build_network(self):
-        # Category Embedding layers
-        self.cat_embedding_layers = nn.ModuleList(
-            [
-                nn.Embedding(cardinality, self.hparams.embedding_dim)
-                for cardinality in self.hparams.categorical_cardinality
-            ]
-        )
+        if len(self.hparams.categorical_cols)>0:
+            # Category Embedding layers
+            self.cat_embedding_layers = nn.ModuleList(
+                [
+                    nn.Embedding(cardinality, self.hparams.embedding_dim)
+                    for cardinality in self.hparams.categorical_cardinality
+                ]
+            )
         if self.hparams.batch_norm_continuous_input:
             self.normalizing_batch_norm = nn.BatchNorm1d(self.hparams.continuous_dim)
         # Continuous Embedding Layer
         self.cont_embedding_layer = nn.Embedding(
             self.hparams.continuous_dim, self.hparams.embedding_dim
         )
-        if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
+        if self.hparams.embedding_dropout != 0 and len(self.hparams.categorical_cols)>0:
             self.embed_dropout = nn.Dropout(self.hparams.embedding_dropout)
         # Deep Layers
         _curr_units = self.hparams.embedding_dim
@@ -91,7 +92,7 @@ def forward(self, x: Dict):
         # (B, N)
         continuous_data, categorical_data = x["continuous"], x["categorical"]
         x = None
-        if self.embedding_cat_dim != 0:
+        if len(self.hparams.categorical_cols) > 0:
             x_cat = [
                 embedding_layer(categorical_data[:, i]).unsqueeze(1)
                 for i, embedding_layer in enumerate(self.cat_embedding_layers)
@@ -112,7 +113,7 @@ def forward(self, x: Dict):
         )
         # (B, N, E)
         x = x_cont if x is None else torch.cat([x, x_cont], 1)
-        if self.hparams.embedding_dropout != 0 and self.embedding_cat_dim != 0:
+        if self.hparams.embedding_dropout != 0 and len(self.hparams.categorical_cols) > 0:
            x = self.embed_dropout(x)
         if self.hparams.deep_layers:
             x = self.linear_layers(x)
@@ -140,7 +141,7 @@ def forward(self, x: Dict):
 class AutoIntModel(BaseModel):
     def __init__(self, config: DictConfig, **kwargs):
         # The concatenated output dim of the embedding layer
-        self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
+        # self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
         super().__init__(config, **kwargs)

     def _build_network(self):
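
For illustration only, a minimal sketch of the pattern this commit applies: build and consult the categorical embedding layers only when categorical_cols is non-empty, rather than keying off an embedding_cat_dim derived from config.embedding_dims. The class name, constructor arguments, and the continuous-embedding arithmetic below are assumptions made for a self-contained example, not the repository's actual AutoIntBackbone.

import torch
import torch.nn as nn


class EmbeddingBackboneSketch(nn.Module):
    """Toy module illustrating the guard pattern from this commit (not pytorch_tabular code)."""

    def __init__(self, categorical_cardinality, categorical_cols, continuous_dim, embedding_dim):
        super().__init__()
        self.categorical_cols = categorical_cols
        # Guard at construction time: only build categorical embeddings
        # when there are categorical columns.
        if len(categorical_cols) > 0:
            self.cat_embedding_layers = nn.ModuleList(
                [nn.Embedding(card, embedding_dim) for card in categorical_cardinality]
            )
        # One learned embedding vector per continuous feature.
        self.cont_embedding_layer = nn.Embedding(continuous_dim, embedding_dim)

    def forward(self, continuous_data, categorical_data):
        x = None
        # Same guard at usage time: skip the categorical branch entirely when
        # there are no categorical columns, instead of checking a precomputed
        # embedding_cat_dim.
        if len(self.categorical_cols) > 0:
            x_cat = [
                layer(categorical_data[:, i]).unsqueeze(1)
                for i, layer in enumerate(self.cat_embedding_layers)
            ]
            x = torch.cat(x_cat, 1)  # (B, N_cat, E)
        # Scale each continuous feature's embedding by its value -> (B, N_cont, E)
        x_cont = self.cont_embedding_layer.weight.unsqueeze(0) * continuous_data.unsqueeze(2)
        return x_cont if x is None else torch.cat([x, x_cont], 1)


# Runs with zero categorical columns (the configuration the guard change appears
# to target) as well as with a mix of categorical and continuous columns.
model = EmbeddingBackboneSketch(
    categorical_cardinality=[], categorical_cols=[], continuous_dim=4, embedding_dim=8
)
out = model(torch.randn(2, 4), torch.empty(2, 0, dtype=torch.long))
print(out.shape)  # torch.Size([2, 4, 8])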
