From a288d8d44e91193577a649ca7d72729d90cfdb8b Mon Sep 17 00:00:00 2001 From: Maximilian Schambach <28300359+MaxSchambach@users.noreply.github.com> Date: Sat, 26 Jul 2025 17:17:43 +0200 Subject: [PATCH 1/4] Update dataset.py --- mambular/data_utils/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py index db6c63a7..2ae9af9a 100644 --- a/mambular/data_utils/dataset.py +++ b/mambular/data_utils/dataset.py @@ -44,7 +44,8 @@ def __init__( self.labels = None # No labels in prediction mode def __len__(self): - return len(self.num_features_list[0]) # Use numerical features length + _feats = self.num_features_list if self.num_features_list else cat_features_list + return len(_feats[0]) def __getitem__(self, idx): """Retrieves the features and label for a given index. From 7ade7843e4bcffe4e46ba23b592f3c0f86ba0609 Mon Sep 17 00:00:00 2001 From: Maximilian Schambach <28300359+MaxSchambach@users.noreply.github.com> Date: Sat, 26 Jul 2025 17:24:34 +0200 Subject: [PATCH 2/4] Update dataset.py --- mambular/data_utils/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py index 2ae9af9a..c99a719e 100644 --- a/mambular/data_utils/dataset.py +++ b/mambular/data_utils/dataset.py @@ -44,7 +44,7 @@ def __init__( self.labels = None # No labels in prediction mode def __len__(self): - _feats = self.num_features_list if self.num_features_list else cat_features_list + _feats = self.num_features_list if self.num_features_list else self.cat_features_list return len(_feats[0]) def __getitem__(self, idx): From 4d2275c369977fa6ba8c031aa8fff3db2b0833e0 Mon Sep 17 00:00:00 2001 From: Maximilian Schambach <28300359+MaxSchambach@users.noreply.github.com> Date: Sat, 26 Jul 2025 17:26:37 +0200 Subject: [PATCH 3/4] Update dataset.py --- mambular/data_utils/dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py index c99a719e..be1e0d3a 100644 --- a/mambular/data_utils/dataset.py +++ b/mambular/data_utils/dataset.py @@ -24,11 +24,14 @@ def __init__( labels=None, regression=True, ): + assert cat_features_list or num_features_list + self.cat_features_list = cat_features_list # Categorical features tensors self.num_features_list = num_features_list # Numerical features tensors self.embeddings_list = embeddings_list # Embeddings tensors (optional) self.regression = regression + if labels is not None: if not self.regression: self.num_classes = len(np.unique(labels)) From 9e8043f153e2e0bef3cfd95a6c209c123c591498 Mon Sep 17 00:00:00 2001 From: Maximilian Schambach <28300359+MaxSchambach@users.noreply.github.com> Date: Sat, 26 Jul 2025 17:28:53 +0200 Subject: [PATCH 4/4] Update dataset.py --- mambular/data_utils/dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mambular/data_utils/dataset.py b/mambular/data_utils/dataset.py index be1e0d3a..e447f9a2 100644 --- a/mambular/data_utils/dataset.py +++ b/mambular/data_utils/dataset.py @@ -31,7 +31,6 @@ def __init__( self.embeddings_list = embeddings_list # Embeddings tensors (optional) self.regression = regression - if labels is not None: if not self.regression: self.num_classes = len(np.unique(labels))