From 6cc6da16956f80067da5f635382800f5bd90612a Mon Sep 17 00:00:00 2001 From: Soren Macbeth Date: Wed, 15 Jan 2025 01:22:55 -0800 Subject: [PATCH] Make tensor dtypes `np.float32` for MPS devices numpy defaults to numpy.float64 when it should be numpy.float32 This caused training to fail on MPS devices but it works on my M1 with this. --- .../ssl_models/common/noise_generators.py | 2 +- src/pytorch_tabular/tabular_datamodule.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/pytorch_tabular/ssl_models/common/noise_generators.py b/src/pytorch_tabular/ssl_models/common/noise_generators.py index 2da372b4..d9acfeca 100644 --- a/src/pytorch_tabular/ssl_models/common/noise_generators.py +++ b/src/pytorch_tabular/ssl_models/common/noise_generators.py @@ -18,7 +18,7 @@ class SwapNoiseCorrupter(nn.Module): def __init__(self, probas): super().__init__() - self.probas = torch.from_numpy(np.array(probas)) + self.probas = torch.from_numpy(np.array(probas, dtype=np.float32)) def forward(self, x): should_swap = torch.bernoulli(self.probas.to(x.device) * torch.ones(x.shape).to(x.device)) diff --git a/src/pytorch_tabular/tabular_datamodule.py b/src/pytorch_tabular/tabular_datamodule.py index 3d09bb2e..71fe5635 100644 --- a/src/pytorch_tabular/tabular_datamodule.py +++ b/src/pytorch_tabular/tabular_datamodule.py @@ -67,7 +67,7 @@ def __init__( if isinstance(target, str): self.y = self.y.reshape(-1, 1) # .astype(np.int64) else: - self.y = np.zeros((self.n, 1)) # .astype(np.int64) + self.y = np.zeros((self.n, 1), dtype=np.float32) # .astype(np.int64) if task == "classification": self.y = self.y.astype(np.int64) @@ -502,7 +502,7 @@ def _cache_dataset(self): def split_train_val(self, train): logger.debug( - "No validation data provided." f" Using {self.config.validation_split*100}% of train data as validation" + f"No validation data provided. Using {self.config.validation_split * 100}% of train data as validation" ) val_idx = train.sample( int(self.config.validation_split * len(train)), @@ -753,18 +753,16 @@ def _load_dataset_from_cache(self, tag: str = "train"): try: dataset = getattr(self, f"_{tag}_dataset") except AttributeError: - raise AttributeError( f"{tag}_dataset not found in memory. Please provide the data for" f" {tag} dataloader" - ) + raise AttributeError(f"{tag}_dataset not found in memory. Please provide the data for {tag} dataloader") elif self.cache_mode is self.CACHE_MODES.DISK: try: dataset = torch.load(self.cache_dir / f"{tag}_dataset") except FileNotFoundError: raise FileNotFoundError( - f"{tag}_dataset not found in {self.cache_dir}. Please provide the" f" data for {tag} dataloader" + f"{tag}_dataset not found in {self.cache_dir}. Please provide the data for {tag} dataloader" ) elif self.cache_mode is self.CACHE_MODES.INFERENCE: - raise RuntimeError("Cannot load dataset in inference mode. Use" " `prepare_inference_dataloader` instead") + raise RuntimeError("Cannot load dataset in inference mode. Use `prepare_inference_dataloader` instead") else: raise ValueError(f"{self.cache_mode} is not a valid cache mode") return dataset