@@ -9,6 +9,12 @@ class CIFAR10DataModule(pl.LightningDataModule):
99 """
1010 A LightningDataModule for handling CIFAR10 data.
1111
12+ Note: Torchvision provides many built-in datasets in the torchvision.datasets module,
13+ as well as utility classes for building your own datasets. All datasets are subclasses
14+ of torch.utils.data.Dataset i.e, they have __getitem__ and __len__ methods implemented.
15+ Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple
16+ samples in parallel using torch.multiprocessing workers, see [1].
17+
1218 Args:
1319 batch_size (int): The size of the batch.
1420 data_dir (str): The directory where the data is stored. Defaults to "./data".
@@ -18,6 +24,9 @@ class CIFAR10DataModule(pl.LightningDataModule):
1824 data_train (Dataset): The training dataset.
1925 data_val (Dataset): The validation dataset.
2026 data_test (Dataset): The test dataset.
27+
28+ References:
29+ [1] [https://pytorch.org/vision/stable/datasets.html](https://pytorch.org/vision/stable/datasets.html)
2130 """
2231
2332 def __init__ (self , batch_size : int , data_dir : str = "./data" , num_workers : int = 0 ):
@@ -40,22 +49,21 @@ def setup(self, stage: Optional[str] = None) -> None:
4049 stage (Optional[str]): The current stage. Defaults to None.
4150
4251 """
52+ # Assign appropriate data transforms, see
53+ # https://lightning.ai/docs/pytorch/latest/notebooks/course_UvA-DL/04-inception-resnet-densenet.html
54+ DATA_MEANS = (0.49139968 , 0.48215841 , 0.44653091 )
55+ DATA_STDS = (0.24703223 , 0.24348513 , 0.26158784 )
56+ transform = transforms .Compose (
57+ [transforms .ToTensor (), transforms .Normalize (DATA_MEANS , DATA_STDS )]
58+ )
4359 # Assign train/val datasets for use in dataloaders
4460 if stage == "fit" or stage is None :
45- transform = transforms .Compose (
46- [transforms .ToTensor (), transforms .Normalize ((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))]
47- )
4861 data_full = CIFAR10 (root = self .data_dir , train = True , transform = transform )
49- # self.data_train, self.data_val = random_split(daata_full, [45000, 5000])
5062 test_abs = int (len (data_full ) * 0.6 )
51- print ("dm.setup(): test_abs" , test_abs )
5263 self .data_train , self .data_val = random_split (data_full , [test_abs , len (data_full ) - test_abs ])
5364
5465 # Assign test dataset for use in dataloader(s)
5566 if stage == "test" or stage is None :
56- transform = transforms .Compose (
57- [transforms .ToTensor (), transforms .Normalize ((0.5 , 0.5 , 0.5 ), (0.5 , 0.5 , 0.5 ))]
58- )
5967 self .data_test = CIFAR10 (root = self .data_dir , train = False , transform = transform )
6068
6169 def train_dataloader (self ) -> DataLoader :
0 commit comments