|
354 | 354 | }, |
355 | 355 | { |
356 | 356 | "cell_type": "code", |
357 | | - "execution_count": 1, |
| 357 | + "execution_count": null, |
358 | 358 | "metadata": {}, |
359 | 359 | "outputs": [], |
360 | 360 | "source": [ |
|
365 | 365 | }, |
366 | 366 | { |
367 | 367 | "cell_type": "code", |
368 | | - "execution_count": 2, |
| 368 | + "execution_count": null, |
369 | 369 | "metadata": {}, |
370 | | - "outputs": [ |
371 | | - { |
372 | | - "name": "stdout", |
373 | | - "output_type": "stream", |
374 | | - "text": [ |
375 | | - "Batch Size: 5\n", |
376 | | - "---------------\n", |
377 | | - "Inputs: tensor([[1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n", |
378 | | - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", |
379 | | - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", |
380 | | - " [1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,\n", |
381 | | - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0,\n", |
382 | | - " 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0],\n", |
383 | | - " [1, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,\n", |
384 | | - " 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1,\n", |
385 | | - " 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0],\n", |
386 | | - " [1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0,\n", |
387 | | - " 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", |
388 | | - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", |
389 | | - " [0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,\n", |
390 | | - " 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n", |
391 | | - " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n", |
392 | | - "Targets: tensor([ 0, 1, 6, 9, 10])\n" |
393 | | - ] |
394 | | - } |
395 | | - ], |
| 370 | + "outputs": [], |
396 | 371 | "source": [ |
397 | 372 | "from torch.utils.data import DataLoader\n", |
398 | 373 | "# Set batch size for DataLoader\n", |
|
456 | 431 | "metadata": {}, |
457 | 432 | "outputs": [], |
458 | 433 | "source": [ |
459 | | - "# from spotPython.light.pkldataset import PKLDataset\n", |
460 | | - "# import torch\n", |
461 | | - "# dataset = PKLDataset(pkl_file='./data/spotPython/data_sensitive.pkl', target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)" |
| 434 | + "from spotPython.data.pkldataset import PKLDataset\n", |
| 435 | + "import torch\n", |
| 436 | + "dataset = PKLDataset(directory=\"./data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)" |
462 | 437 | ] |
463 | 438 | }, |
464 | 439 | { |
|
467 | 442 | "metadata": {}, |
468 | 443 | "outputs": [], |
469 | 444 | "source": [ |
470 | | - "# from torch.utils.data import DataLoader\n", |
471 | | - "# # Set batch size for DataLoader\n", |
472 | | - "# batch_size = 5\n", |
473 | | - "# # Create DataLoader\n", |
474 | | - "# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n", |
| 445 | + "from torch.utils.data import DataLoader\n", |
| 446 | + "# Set batch size for DataLoader\n", |
| 447 | + "batch_size = 5\n", |
| 448 | + "# Create DataLoader\n", |
| 449 | + "dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n", |
475 | 450 | "\n", |
476 | | - "# # Iterate over the data in the DataLoader\n", |
477 | | - "# for batch in dataloader:\n", |
478 | | - "# inputs, targets = batch\n", |
479 | | - "# print(f\"Batch Size: {inputs.size(0)}\")\n", |
480 | | - "# print(\"---------------\")\n", |
481 | | - "# print(f\"Inputs: {inputs}\")\n", |
482 | | - "# print(f\"Targets: {targets}\")\n", |
483 | | - "# break" |
| 451 | + "# Iterate over the data in the DataLoader\n", |
| 452 | + "for batch in dataloader:\n", |
| 453 | + " inputs, targets = batch\n", |
| 454 | + " print(f\"Batch Size: {inputs.size(0)}\")\n", |
| 455 | + " print(\"---------------\")\n", |
| 456 | + " print(f\"Inputs: {inputs}\")\n", |
| 457 | + " print(f\"Targets: {targets}\")\n", |
| 458 | + " break" |
| 459 | + ] |
| 460 | + }, |
| 461 | + { |
| 462 | + "cell_type": "markdown", |
| 463 | + "metadata": {}, |
| 464 | + "source": [ |
| 465 | + "# Test lightdatamodule" |
| 466 | + ] |
| 467 | + }, |
| 468 | + { |
| 469 | + "cell_type": "code", |
| 470 | + "execution_count": 1, |
| 471 | + "metadata": {}, |
| 472 | + "outputs": [ |
| 473 | + { |
| 474 | + "name": "stdout", |
| 475 | + "output_type": "stream", |
| 476 | + "text": [ |
| 477 | + "Loading data from /Users/bartz/miniforge3/envs/spotCondaEnv/lib/python3.11/site-packages/spotPython/data/data.csv\n", |
| 478 | + "11\n" |
| 479 | + ] |
| 480 | + } |
| 481 | + ], |
| 482 | + "source": [ |
| 483 | + "from spotPython.data.lightdatamodule import LightDataModule\n", |
| 484 | + "from spotPython.data.csvdataset import CSVDataset\n", |
| 485 | + "from spotPython.data.pkldataset import PKLDataset\n", |
| 486 | + "import torch\n", |
| 487 | + "dataset = CSVDataset(csv_file='data.csv', target_column='prognosis', feature_type=torch.long)\n", |
| 488 | + "# dataset = PKLDataset(directory=\"./data/spotPython/\", filename=\"data_sensitive.pkl\", target_column='N', feature_type=torch.float32, target_type=torch.float64, rmNA=False)\n", |
| 489 | + "print(len(dataset))" |
| 490 | + ] |
| 491 | + }, |
| 492 | + { |
| 493 | + "cell_type": "code", |
| 494 | + "execution_count": 2, |
| 495 | + "metadata": {}, |
| 496 | + "outputs": [], |
| 497 | + "source": [ |
| 498 | + "data_module = LightDataModule(dataset=dataset, batch_size=5, test_size=0.5)" |
484 | 499 | ] |
485 | 500 | }, |
| 501 | + { |
| 502 | + "cell_type": "code", |
| 503 | + "execution_count": 3, |
| 504 | + "metadata": {}, |
| 505 | + "outputs": [ |
| 506 | + { |
| 507 | + "name": "stdout", |
| 508 | + "output_type": "stream", |
| 509 | + "text": [ |
| 510 | + "full_train_size: 0.5\n", |
| 511 | + "val_size: 0.25\n", |
| 512 | + "train_size: 0.25\n", |
| 513 | + "test_size: 0.5\n" |
| 514 | + ] |
| 515 | + } |
| 516 | + ], |
| 517 | + "source": [ |
| 518 | + "data_module.setup()" |
| 519 | + ] |
| 520 | + }, |
| 521 | + { |
| 522 | + "cell_type": "code", |
| 523 | + "execution_count": 4, |
| 524 | + "metadata": {}, |
| 525 | + "outputs": [ |
| 526 | + { |
| 527 | + "name": "stdout", |
| 528 | + "output_type": "stream", |
| 529 | + "text": [ |
| 530 | + "Training set size: 3\n" |
| 531 | + ] |
| 532 | + } |
| 533 | + ], |
| 534 | + "source": [ |
| 535 | + "print(f\"Training set size: {len(data_module.data_train)}\")" |
| 536 | + ] |
| 537 | + }, |
| 538 | + { |
| 539 | + "cell_type": "code", |
| 540 | + "execution_count": 5, |
| 541 | + "metadata": {}, |
| 542 | + "outputs": [ |
| 543 | + { |
| 544 | + "name": "stdout", |
| 545 | + "output_type": "stream", |
| 546 | + "text": [ |
| 547 | + "Validation set size: 3\n" |
| 548 | + ] |
| 549 | + } |
| 550 | + ], |
| 551 | + "source": [ |
| 552 | + "print(f\"Validation set size: {len(data_module.data_val)}\")" |
| 553 | + ] |
| 554 | + }, |
| 555 | + { |
| 556 | + "cell_type": "code", |
| 557 | + "execution_count": 6, |
| 558 | + "metadata": {}, |
| 559 | + "outputs": [ |
| 560 | + { |
| 561 | + "name": "stdout", |
| 562 | + "output_type": "stream", |
| 563 | + "text": [ |
| 564 | + "Test set size: 6\n" |
| 565 | + ] |
| 566 | + } |
| 567 | + ], |
| 568 | + "source": [ |
| 569 | + "print(f\"Test set size: {len(data_module.data_test)}\")" |
| 570 | + ] |
| 571 | + }, |
| 572 | + { |
| 573 | + "cell_type": "code", |
| 574 | + "execution_count": null, |
| 575 | + "metadata": {}, |
| 576 | + "outputs": [], |
| 577 | + "source": [] |
| 578 | + }, |
486 | 579 | { |
487 | 580 | "cell_type": "code", |
488 | 581 | "execution_count": null, |
|
0 commit comments