Skip to content

Commit 4c705ab

Browse files
prep cnn
1 parent 51e81e3 commit 4c705ab

3 files changed

Lines changed: 17 additions & 16 deletions

File tree

src/spotPython/light/cifar10datamodule.py renamed to src/spotPython/light/cifar10/cifar10datamodule.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ class CIFAR10DataModule(pl.LightningDataModule):
99
"""
1010
A LightningDataModule for handling CIFAR10 data.
1111
12+
Note: Torchvision provides many built-in datasets in the torchvision.datasets module,
13+
as well as utility classes for building your own datasets. All datasets are subclasses
14+
of torch.utils.data.Dataset i.e, they have __getitem__ and __len__ methods implemented.
15+
Hence, they can all be passed to a torch.utils.data.DataLoader which can load multiple
16+
samples in parallel using torch.multiprocessing workers, see [1].
17+
1218
Args:
1319
batch_size (int): The size of the batch.
1420
data_dir (str): The directory where the data is stored. Defaults to "./data".
@@ -18,6 +24,9 @@ class CIFAR10DataModule(pl.LightningDataModule):
1824
data_train (Dataset): The training dataset.
1925
data_val (Dataset): The validation dataset.
2026
data_test (Dataset): The test dataset.
27+
28+
References:
29+
[1] [https://pytorch.org/vision/stable/datasets.html](https://pytorch.org/vision/stable/datasets.html)
2130
"""
2231

2332
def __init__(self, batch_size: int, data_dir: str = "./data", num_workers: int = 0):
@@ -40,22 +49,21 @@ def setup(self, stage: Optional[str] = None) -> None:
4049
stage (Optional[str]): The current stage. Defaults to None.
4150
4251
"""
52+
# Assign appropriate data transforms, see
53+
# https://lightning.ai/docs/pytorch/latest/notebooks/course_UvA-DL/04-inception-resnet-densenet.html
54+
DATA_MEANS = (0.49139968, 0.48215841, 0.44653091)
55+
DATA_STDS = (0.24703223, 0.24348513, 0.26158784)
56+
transform = transforms.Compose(
57+
[transforms.ToTensor(), transforms.Normalize(DATA_MEANS, DATA_STDS)]
58+
)
4359
# Assign train/val datasets for use in dataloaders
4460
if stage == "fit" or stage is None:
45-
transform = transforms.Compose(
46-
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
47-
)
4861
data_full = CIFAR10(root=self.data_dir, train=True, transform=transform)
49-
# self.data_train, self.data_val = random_split(daata_full, [45000, 5000])
5062
test_abs = int(len(data_full) * 0.6)
51-
print("dm.setup(): test_abs", test_abs)
5263
self.data_train, self.data_val = random_split(data_full, [test_abs, len(data_full) - test_abs])
5364

5465
# Assign test dataset for use in dataloader(s)
5566
if stage == "test" or stage is None:
56-
transform = transforms.Compose(
57-
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
58-
)
5967
self.data_test = CIFAR10(root=self.data_dir, train=False, transform=transform)
6068

6169
def train_dataloader(self) -> DataLoader:

src/spotPython/light/cnn/netcnnbase.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,6 @@
3232
from torchvision import transforms
3333
from torchvision.datasets import CIFAR10
3434

35-
matplotlib_inline.backend_inline.set_matplotlib_formats("svg", "pdf") # For export
36-
matplotlib.rcParams["lines.linewidth"] = 2.0
37-
sns.reset_orig()
38-
39-
# PyTorch
40-
# Torchvision
41-
4235

4336
class NetCNNBase(L.LightningModule):
4437
def __init__(self, model_name, model_hparams, optimizer_name, optimizer_hparams):

src/spotPython/light/traintest_NEW.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import lightning as L
2-
from spotPython.light.cifar10datamodule import CIFAR10DataModule
2+
from spotPython.light.cifar10.cifar10datamodule import CIFAR10DataModule
33
from spotPython.light.crossvalidationdatamodule import CrossValidationDataModule
44
from spotPython.utils.eda import generate_config_id
55
from pytorch_lightning.loggers import TensorBoardLogger

0 commit comments

Comments
 (0)