|
407 | 407 | " x = self.fc3(x)\n", |
408 | 408 | " return x\n", |
409 | 409 | "\n", |
410 | | - "```\n", |
| 410 | + "```\n" |
| 411 | + ] |
| 412 | + }, |
| 413 | + { |
| 414 | + "attachments": {}, |
| 415 | + "cell_type": "markdown", |
| 416 | + "metadata": {}, |
| 417 | + "source": [ |
411 | 418 | "\n", |
412 | 419 | "#### The `Net_Core` class\n", |
413 | 420 | "\n", |
|
438 | 445 | " self.sgd_momentum = sgd_momentum\n", |
439 | 446 | "```\n", |
440 | 447 | "\n", |
441 | | - ":::{.callout-note}\n", |
| 448 | + ":::{.callout-note}" |
| 449 | + ] |
| 450 | + }, |
| 451 | + { |
| 452 | + "cell_type": "markdown", |
| 453 | + "metadata": {}, |
| 454 | + "source": [ |
442 | 455 | "\n", |
443 | 456 | "\n", |
444 | 457 | "We see that the class `Net_lin_reg` has additional attributes and does not inherit from `nn` directly. It adds an additional class, `Net_core`, that takes care of additional attributes that are common to all neural network models, e.g., the learning rate multiplier `lr_mult` or the batch size `batch_size`.\n", |
@@ -63168,12 +63181,12 @@ |
63168 | 63181 | ] |
63169 | 63182 | }, |
63170 | 63183 | { |
| 63184 | + "attachments": {}, |
63171 | 63185 | "cell_type": "markdown", |
63172 | 63186 | "metadata": {}, |
63173 | 63187 | "source": [ |
63174 | 63188 | "\n", |
63175 | | - "## Tensorboard {#sec-tensorboard}\n", |
63176 | | - "\n" |
| 63189 | + "## Tensorboard {#sec-tensorboard}" |
63177 | 63190 | ] |
63178 | 63191 | }, |
63179 | 63192 | { |
@@ -63389,13 +63402,14 @@ |
63389 | 63402 | "\n", |
63390 | 63403 | "\n", |
63391 | 63404 | "```{raw}\n", |
63392 | | - "Net_CIFAR10(\n", |
63393 | | - " (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))\n", |
63394 | | - " (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", |
63395 | | - " (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))\n", |
63396 | | - " (fc1): Linear(in_features=400, out_features=64, bias=True)\n", |
| 63405 | + "Net_lin_reg(\n", |
| 63406 | + " (fc1): Linear(in_features=10, out_features=64, bias=True)\n", |
63397 | 63407 | " (fc2): Linear(in_features=64, out_features=32, bias=True)\n", |
63398 | | - " (fc3): Linear(in_features=32, out_features=10, bias=True)\n", |
| 63408 | + " (fc3): Linear(in_features=32, out_features=1, bias=True)\n", |
| 63409 | + " (relu): ReLU()\n", |
| 63410 | + " (softmax): Softmax(dim=1)\n", |
| 63411 | + " (dropout1): Dropout(p=0.0, inplace=False)\n", |
| 63412 | + " (dropout2): Dropout(p=0.0, inplace=False)\n", |
63399 | 63413 | ")\n", |
63400 | 63414 | "```\n", |
63401 | 63415 | "\n" |
|
84464 | 84478 | "metadata": {}, |
84465 | 84479 | "source": [ |
84466 | 84480 | "```{raw}\n", |
84467 | | - "Loss on hold-out set: 1.2267619131326675\n", |
84468 | | - "Accuracy on hold-out set: 0.58955\n", |
84469 | | - "Early stopping at epoch 13\n", |
| 84481 | + "Epoch: 1\n", |
| 84482 | + "Loss on hold-out set: 0.17853929138431945\n", |
| 84483 | + "MeanAbsoluteError value on hold-out data: 0.3907899856567383\n", |
| 84484 | + "Epoch: 2\n", |
| 84485 | + "Loss on hold-out set: 0.17439044278115035\n", |
| 84486 | + "MeanAbsoluteError value on hold-out data: 0.38570401072502136\n", |
84470 | 84487 | "```\n", |
84471 | 84488 | "\n", |
84472 | 84489 | "If `path` is set to a filename, e.g., `path = \"model_spot_trained.pt\"`, the weights of the trained model will be loaded from this file." |
|
84509 | 84526 | " task=fun_control[\"task\"],)" |
84510 | 84527 | ] |
84511 | 84528 | }, |
| 84529 | + { |
| 84530 | + "cell_type": "raw", |
| 84531 | + "metadata": {}, |
| 84532 | + "source": [ |
| 84533 | + "Loss on hold-out set: 1.85966069472272e-05\n", |
| 84534 | + "MeanAbsoluteError value on hold-out data: 0.0021022311411798\n", |
| 84535 | + "Final evaluation: Validation loss: 1.85966069472272e-05\n", |
| 84536 | + "Final evaluation: Validation metric: 0.0021022311411798\n", |
| 84537 | + "----------------------------------------------\n", |
| 84538 | + "(1.85966069472272e-05, nan, tensor(0.0021))" |
| 84539 | + ] |
| 84540 | + }, |
84512 | 84541 | { |
84513 | 84542 | "cell_type": "markdown", |
84514 | 84543 | "metadata": {}, |
@@ -220752,7 +220781,7 @@ |
220752 | 220781 | "name": "python", |
220753 | 220782 | "nbconvert_exporter": "python", |
220754 | 220783 | "pygments_lexer": "ipython3", |
220755 | | - "version": "3.10.11" |
| 220784 | + "version": "3.10.10" |
220756 | 220785 | } |
220757 | 220786 | }, |
220758 | 220787 | "nbformat": 4, |
|
0 commit comments