Skip to content

Commit 83e5a7a

Browse files
lightdatamodule
1 parent 28f774f commit 83e5a7a

3 files changed

Lines changed: 160 additions & 45 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 137 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@
354354
},
355355
{
356356
"cell_type": "code",
357-
"execution_count": 1,
357+
"execution_count": null,
358358
"metadata": {},
359359
"outputs": [],
360360
"source": [
@@ -365,34 +365,9 @@
365365
},
366366
{
367367
"cell_type": "code",
368-
"execution_count": 2,
368+
"execution_count": null,
369369
"metadata": {},
370-
"outputs": [
371-
{
372-
"name": "stdout",
373-
"output_type": "stream",
374-
"text": [
375-
"Batch Size: 5\n",
376-
"---------------\n",
377-
"Inputs: tensor([[1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n",
378-
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
379-
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
380-
" [1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,\n",
381-
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,\n",
382-
" 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n",
383-
" [1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n",
384-
" 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,\n",
385-
" 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],\n",
386-
" [1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0,\n",
387-
" 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
388-
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
389-
" [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,\n",
390-
" 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
391-
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
392-
"Targets: tensor([ 0, 1, 6, 9, 10])\n"
393-
]
394-
}
395-
],
370+
"outputs": [],
396371
"source": [
397372
"from torch.utils.data import DataLoader\n",
398373
"# Set batch size for DataLoader\n",
@@ -456,9 +431,9 @@
456431
"metadata": {},
457432
"outputs": [],
458433
"source": [
459-
"# from spotPython.light.pkldataset import PKLDataset\n",
460-
"# import torch\n",
461-
"# dataset = PKLDataset(pkl_file='./data/spotPython/data_sensitive.pkl', target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)"
434+
"from spotPython.data.pkldataset import PKLDataset\n",
435+
"import torch\n",
436+
"dataset = PKLDataset(directory=\"./data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)"
462437
]
463438
},
464439
{
@@ -467,22 +442,140 @@
467442
"metadata": {},
468443
"outputs": [],
469444
"source": [
470-
"# from torch.utils.data import DataLoader\n",
471-
"# # Set batch size for DataLoader\n",
472-
"# batch_size = 5\n",
473-
"# # Create DataLoader\n",
474-
"# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
445+
"from torch.utils.data import DataLoader\n",
446+
"# Set batch size for DataLoader\n",
447+
"batch_size = 5\n",
448+
"# Create DataLoader\n",
449+
"dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
475450
"\n",
476-
"# # Iterate over the data in the DataLoader\n",
477-
"# for batch in dataloader:\n",
478-
"# inputs, targets = batch\n",
479-
"# print(f\"Batch Size: {inputs.size(0)}\")\n",
480-
"# print(\"---------------\")\n",
481-
"# print(f\"Inputs: {inputs}\")\n",
482-
"# print(f\"Targets: {targets}\")\n",
483-
"# break"
451+
"# Iterate over the data in the DataLoader\n",
452+
"for batch in dataloader:\n",
453+
" inputs, targets = batch\n",
454+
" print(f\"Batch Size: {inputs.size(0)}\")\n",
455+
" print(\"---------------\")\n",
456+
" print(f\"Inputs: {inputs}\")\n",
457+
" print(f\"Targets: {targets}\")\n",
458+
" break"
459+
]
460+
},
461+
{
462+
"cell_type": "markdown",
463+
"metadata": {},
464+
"source": [
465+
"# Test lightdatamodule"
466+
]
467+
},
468+
{
469+
"cell_type": "code",
470+
"execution_count": 1,
471+
"metadata": {},
472+
"outputs": [
473+
{
474+
"name": "stdout",
475+
"output_type": "stream",
476+
"text": [
477+
"Loading data from /Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/spotPython/data/data.csv\n",
478+
"11\n"
479+
]
480+
}
481+
],
482+
"source": [
483+
"from spotPython.data.lightdatamodule import LightDataModule\n",
484+
"from spotPython.data.csvdataset import CSVDataset\n",
485+
"from spotPython.data.pkldataset import PKLDataset\n",
486+
"import torch\n",
487+
"dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n",
488+
"# dataset = PKLDataset(directory=\"./data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)\n",
489+
"print(len(dataset))"
490+
]
491+
},
492+
{
493+
"cell_type": "code",
494+
"execution_count": 2,
495+
"metadata": {},
496+
"outputs": [],
497+
"source": [
498+
"data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)"
484499
]
485500
},
501+
{
502+
"cell_type": "code",
503+
"execution_count": 3,
504+
"metadata": {},
505+
"outputs": [
506+
{
507+
"name": "stdout",
508+
"output_type": "stream",
509+
"text": [
510+
"full_train_size: 0.5\n",
511+
"val_size: 0.25\n",
512+
"train_size: 0.25\n",
513+
"test_size: 0.5\n"
514+
]
515+
}
516+
],
517+
"source": [
518+
"data_module.setup()"
519+
]
520+
},
521+
{
522+
"cell_type": "code",
523+
"execution_count": 4,
524+
"metadata": {},
525+
"outputs": [
526+
{
527+
"name": "stdout",
528+
"output_type": "stream",
529+
"text": [
530+
"Training set size: 3\n"
531+
]
532+
}
533+
],
534+
"source": [
535+
"print(f\"Training set size: {len(data_module.data_train)}\")"
536+
]
537+
},
538+
{
539+
"cell_type": "code",
540+
"execution_count": 5,
541+
"metadata": {},
542+
"outputs": [
543+
{
544+
"name": "stdout",
545+
"output_type": "stream",
546+
"text": [
547+
"Validation set size: 3\n"
548+
]
549+
}
550+
],
551+
"source": [
552+
"print(f\"Validation set size: {len(data_module.data_val)}\")"
553+
]
554+
},
555+
{
556+
"cell_type": "code",
557+
"execution_count": 6,
558+
"metadata": {},
559+
"outputs": [
560+
{
561+
"name": "stdout",
562+
"output_type": "stream",
563+
"text": [
564+
"Test set size: 6\n"
565+
]
566+
}
567+
],
568+
"source": [
569+
"print(f\"Test set size: {len(data_module.data_test)}\")"
570+
]
571+
},
572+
{
573+
"cell_type": "code",
574+
"execution_count": null,
575+
"metadata": {},
576+
"outputs": [],
577+
"source": []
578+
},
486579
{
487580
"cell_type": "code",
488581
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.42"
10+
version = "0.6.43"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

test/test_lightdatamodule.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import pytest
2+
import torch
3+
from spotPython.data.lightdatamodule import LightDataModule
4+
from spotPython.data.csvdataset import CSVDataset
5+
6+
7+
def test_light_data_module():
8+
# Create an instance of CSVDataset for testing
9+
dataset = CSVDataset(target_column='prognosis', feature_type=torch.long)
10+
11+
# Test the length of the dataset
12+
assert len(dataset) > 0
13+
14+
data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)
15+
data_module.setup()
16+
17+
# Test the length of val and train: should be equal, because test_size=0.5
18+
assert len(data_module.data_train) == len(data_module.data_val)
19+
20+
21+
if __name__ == "__main__":
22+
pytest.main(["-v", __file__])

0 commit comments

Comments
 (0)