diff --git a/src/pytorch_tabular/ssl_models/common/noise_generators.py b/src/pytorch_tabular/ssl_models/common/noise_generators.py
index 2da372b4..d9acfeca 100644
--- a/src/pytorch_tabular/ssl_models/common/noise_generators.py
+++ b/src/pytorch_tabular/ssl_models/common/noise_generators.py
@@ -18,7 +18,7 @@ class SwapNoiseCorrupter(nn.Module):
 
     def __init__(self, probas):
         super().__init__()
-        self.probas = torch.from_numpy(np.array(probas))
+        self.probas = torch.from_numpy(np.array(probas, dtype=np.float32))
 
     def forward(self, x):
         should_swap = torch.bernoulli(self.probas.to(x.device) * torch.ones(x.shape).to(x.device))
diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py
index 7b1e89f8..2baebecb 100644
--- a/src/pytorch_tabular/tabular_datamodule.py
+++ b/src/pytorch_tabular/tabular_datamodule.py
@@ -67,7 +67,7 @@ def __init__(
             if isinstance(target, str):
                 self.y = self.y.reshape(-1, 1)  # .astype(np.int64)
         else:
-            self.y = np.zeros((self.n, 1))  # .astype(np.int64)
+            self.y = np.zeros((self.n, 1), dtype=np.float32)  # .astype(np.int64)
 
         if task == "classification":
             self.y = self.y.astype(np.int64)
@@ -502,7 +502,7 @@ def _cache_dataset(self):
 
     def split_train_val(self, train):
         logger.debug(
-            "No validation data provided." f" Using {self.config.validation_split*100}% of train data as validation"
+            f"No validation data provided. Using {self.config.validation_split * 100}% of train data as validation"
         )
         val_idx = train.sample(
            int(self.config.validation_split * len(train)),
@@ -753,9 +753,7 @@ def _load_dataset_from_cache(self, tag: str = "train"):
             try:
                 dataset = getattr(self, f"_{tag}_dataset")
             except AttributeError:
-                raise AttributeError(
-                    f"{tag}_dataset not found in memory. Please provide the data for" f" {tag} dataloader"
-                )
+                raise AttributeError(f"{tag}_dataset not found in memory. Please provide the data for {tag} dataloader")
         elif self.cache_mode is self.CACHE_MODES.DISK:
             try:
                 # get the torch version
@@ -768,10 +766,10 @@
                 dataset = torch.load(self.cache_dir / f"{tag}_dataset", weights_only=False)
             except FileNotFoundError:
                 raise FileNotFoundError(
-                    f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader"
+                    f"{tag}_dataset not found in {self.cache_dir}. Please provide the data for {tag} dataloader"
                 )
         elif self.cache_mode is self.CACHE_MODES.INFERENCE:
-            raise RuntimeError("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead")
+            raise RuntimeError("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead")
         else:
             raise ValueError(f"{self.cache_mode} is not a valid cache mode")
         return dataset
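For context, a minimal sketch (not part of the patch) of the dtype behavior the `dtype=np.float32` arguments address: without an explicit dtype, NumPy builds float64 arrays and `torch.from_numpy` preserves that dtype, so the resulting tensors no longer match torch's default float32. The `probas` values below are illustrative only.

```python
import numpy as np
import torch

# Stand-in for the per-column swap probabilities passed to SwapNoiseCorrupter.
probas = [0.1, 0.2, 0.3]

# Without an explicit dtype, NumPy creates a float64 array,
# and torch.from_numpy keeps that dtype.
t_default = torch.from_numpy(np.array(probas))
print(t_default.dtype)  # torch.float64

# With dtype=np.float32 (as in the patch), the tensor matches torch's
# default float32, avoiding float64 promotion when it is combined with
# float32 tensors such as torch.ones(...).
t_float32 = torch.from_numpy(np.array(probas, dtype=np.float32))
print(t_float32.dtype)  # torch.float32
```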