2020#TODO dont use embedding_dims
2121class AutoIntBackbone (pl .LightningModule ):
2222 def __init__ (self , config : DictConfig ):
23- self .embedding_cat_dim = sum ([y for x , y in config .embedding_dims ])
23+ # self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
2424 # self.hparams = config
2525 super ().__init__ ()
2626 self .save_hyperparameters (config )
2727 self ._build_network ()
2828
2929 def _build_network (self ):
30- # Category Embedding layers
31- self .cat_embedding_layers = nn .ModuleList (
32- [
33- nn .Embedding (cardinality , self .hparams .embedding_dim )
34- for cardinality in self .hparams .categorical_cardinality
35- ]
36- )
30+ if len (self .hparams .categorical_cols )> 0 :
31+ # Category Embedding layers
32+ self .cat_embedding_layers = nn .ModuleList (
33+ [
34+ nn .Embedding (cardinality , self .hparams .embedding_dim )
35+ for cardinality in self .hparams .categorical_cardinality
36+ ]
37+ )
3738 if self .hparams .batch_norm_continuous_input :
3839 self .normalizing_batch_norm = nn .BatchNorm1d (self .hparams .continuous_dim )
3940 # Continuous Embedding Layer
4041 self .cont_embedding_layer = nn .Embedding (
4142 self .hparams .continuous_dim , self .hparams .embedding_dim
4243 )
43- if self .hparams .embedding_dropout != 0 and self .embedding_cat_dim != 0 :
44+ if self .hparams .embedding_dropout != 0 and len ( self .hparams . categorical_cols ) > 0 :
4445 self .embed_dropout = nn .Dropout (self .hparams .embedding_dropout )
4546 # Deep Layers
4647 _curr_units = self .hparams .embedding_dim
@@ -91,7 +92,7 @@ def forward(self, x: Dict):
9192 # (B, N)
9293 continuous_data , categorical_data = x ["continuous" ], x ["categorical" ]
9394 x = None
94- if self .embedding_cat_dim != 0 :
95+ if len ( self .hparams . categorical_cols ) > 0 :
9596 x_cat = [
9697 embedding_layer (categorical_data [:, i ]).unsqueeze (1 )
9798 for i , embedding_layer in enumerate (self .cat_embedding_layers )
@@ -112,7 +113,7 @@ def forward(self, x: Dict):
112113 )
113114 # (B, N, E)
114115 x = x_cont if x is None else torch .cat ([x , x_cont ], 1 )
115- if self .hparams .embedding_dropout != 0 and self .embedding_cat_dim != 0 :
116+ if self .hparams .embedding_dropout != 0 and len ( self .hparams . categorical_cols ) > 0 :
116117 x = self .embed_dropout (x )
117118 if self .hparams .deep_layers :
118119 x = self .linear_layers (x )
@@ -140,7 +141,7 @@ def forward(self, x: Dict):
140141class AutoIntModel (BaseModel ):
141142 def __init__ (self , config : DictConfig , ** kwargs ):
142143 # The concatenated output dim of the embedding layer
143- self .embedding_cat_dim = sum ([y for x , y in config .embedding_dims ])
144+ # self.embedding_cat_dim = sum([y for x, y in config.embedding_dims])
144145 super ().__init__ (config , ** kwargs )
145146
146147 def _build_network (self ):
0 commit comments