|
12 | 12 | }, |
13 | 13 | { |
14 | 14 | "cell_type": "code", |
15 | | - "execution_count": 1, |
| 15 | + "execution_count": null, |
16 | 16 | "metadata": {}, |
17 | 17 | "outputs": [], |
18 | 18 | "source": [ |
|
26 | 26 | }, |
27 | 27 | { |
28 | 28 | "cell_type": "code", |
29 | | - "execution_count": 2, |
| 29 | + "execution_count": null, |
30 | 30 | "metadata": {}, |
31 | | - "outputs": [ |
32 | | - { |
33 | | - "data": { |
34 | | - "text/plain": [ |
35 | | - "'11-torch_p040025_360min_10init_2023-04-27_00-08-17'" |
36 | | - ] |
37 | | - }, |
38 | | - "execution_count": 2, |
39 | | - "metadata": {}, |
40 | | - "output_type": "execute_result" |
41 | | - } |
42 | | - ], |
| 31 | + "outputs": [], |
43 | 32 | "source": [ |
44 | 33 | "import pickle\n", |
45 | 34 | "import socket\n", |
|
72 | 61 | }, |
73 | 62 | { |
74 | 63 | "cell_type": "code", |
75 | | - "execution_count": 3, |
| 64 | + "execution_count": null, |
76 | 65 | "metadata": {}, |
77 | | - "outputs": [ |
78 | | - { |
79 | | - "name": "stdout", |
80 | | - "output_type": "stream", |
81 | | - "text": [ |
82 | | - "spotPython 0.0.44\n", |
83 | | - "spotRiver 0.0.92\n", |
84 | | - "Note: you may need to restart the kernel to use updated packages.\n" |
85 | | - ] |
86 | | - } |
87 | | - ], |
| 66 | + "outputs": [], |
88 | 67 | "source": [ |
89 | 68 | "pip list | grep \"spot[RiverPython]\"" |
90 | 69 | ] |
91 | 70 | }, |
92 | 71 | { |
93 | 72 | "cell_type": "code", |
94 | | - "execution_count": 4, |
| 73 | + "execution_count": null, |
95 | 74 | "metadata": {}, |
96 | 75 | "outputs": [], |
97 | 76 | "source": [ |
|
110 | 89 | }, |
111 | 90 | { |
112 | 91 | "cell_type": "code", |
113 | | - "execution_count": 5, |
| 92 | + "execution_count": null, |
114 | 93 | "metadata": {}, |
115 | 94 | "outputs": [], |
116 | 95 | "source": [ |
|
174 | 153 | }, |
175 | 154 | { |
176 | 155 | "cell_type": "code", |
177 | | - "execution_count": 6, |
| 156 | + "execution_count": null, |
178 | 157 | "metadata": {}, |
179 | | - "outputs": [ |
180 | | - { |
181 | | - "name": "stdout", |
182 | | - "output_type": "stream", |
183 | | - "text": [ |
184 | | - "2.0.0\n", |
185 | | - "MPS device: mps\n" |
186 | | - ] |
187 | | - } |
188 | | - ], |
| 158 | + "outputs": [], |
189 | 159 | "source": [ |
190 | 160 | "print(torch.__version__)\n", |
191 | 161 | "# Check that MPS is available\n", |
|
212 | 182 | }, |
213 | 183 | { |
214 | 184 | "cell_type": "code", |
215 | | - "execution_count": 7, |
| 185 | + "execution_count": null, |
216 | 186 | "metadata": {}, |
217 | 187 | "outputs": [], |
218 | 188 | "source": [ |
|
237 | 207 | }, |
238 | 208 | { |
239 | 209 | "cell_type": "code", |
240 | | - "execution_count": 8, |
| 210 | + "execution_count": null, |
241 | 211 | "metadata": {}, |
242 | 212 | "outputs": [], |
243 | 213 | "source": [ |
|
266 | 236 | }, |
267 | 237 | { |
268 | 238 | "cell_type": "code", |
269 | | - "execution_count": 9, |
| 239 | + "execution_count": null, |
270 | 240 | "metadata": {}, |
271 | 241 | "outputs": [], |
272 | 242 | "source": [ |
|
287 | 257 | }, |
288 | 258 | { |
289 | 259 | "cell_type": "code", |
290 | | - "execution_count": 10, |
| 260 | + "execution_count": null, |
291 | 261 | "metadata": {}, |
292 | 262 | "outputs": [], |
293 | 263 | "source": [ |
|
323 | 293 | }, |
324 | 294 | { |
325 | 295 | "cell_type": "code", |
326 | | - "execution_count": 11, |
| 296 | + "execution_count": null, |
327 | 297 | "metadata": {}, |
328 | 298 | "outputs": [], |
329 | 299 | "source": [ |
|
341 | 311 | }, |
342 | 312 | { |
343 | 313 | "cell_type": "code", |
344 | | - "execution_count": 12, |
| 314 | + "execution_count": null, |
345 | 315 | "metadata": {}, |
346 | 316 | "outputs": [], |
347 | 317 | "source": [ |
|
362 | 332 | }, |
363 | 333 | { |
364 | 334 | "cell_type": "code", |
365 | | - "execution_count": 13, |
| 335 | + "execution_count": null, |
366 | 336 | "metadata": {}, |
367 | | - "outputs": [ |
368 | | - { |
369 | | - "name": "stdout", |
370 | | - "output_type": "stream", |
371 | | - "text": [ |
372 | | - "Files already downloaded and verified\n", |
373 | | - "Files already downloaded and verified\n" |
374 | | - ] |
375 | | - } |
376 | | - ], |
| 337 | + "outputs": [], |
377 | 338 | "source": [ |
378 | 339 | "train, test = load_data()" |
379 | 340 | ] |
380 | 341 | }, |
381 | 342 | { |
382 | 343 | "cell_type": "code", |
383 | | - "execution_count": 14, |
| 344 | + "execution_count": null, |
384 | 345 | "metadata": {}, |
385 | | - "outputs": [ |
386 | | - { |
387 | | - "data": { |
388 | | - "text/plain": [ |
389 | | - "Dataset CIFAR10\n", |
390 | | - " Number of datapoints: 50000\n", |
391 | | - " Root location: ./data\n", |
392 | | - " Split: Train\n", |
393 | | - " StandardTransform\n", |
394 | | - "Transform: Compose(\n", |
395 | | - " ToTensor()\n", |
396 | | - " Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))\n", |
397 | | - " )" |
398 | | - ] |
399 | | - }, |
400 | | - "execution_count": 14, |
401 | | - "metadata": {}, |
402 | | - "output_type": "execute_result" |
403 | | - } |
404 | | - ], |
| 346 | + "outputs": [], |
405 | 347 | "source": [ |
406 | 348 | "train.data.shape, test.data.shape\n", |
407 | 349 | "train" |
408 | 350 | ] |
409 | 351 | }, |
410 | 352 | { |
411 | 353 | "cell_type": "code", |
412 | | - "execution_count": 15, |
| 354 | + "execution_count": null, |
413 | 355 | "metadata": {}, |
414 | 356 | "outputs": [], |
415 | 357 | "source": [ |
|
432 | 374 | }, |
433 | 375 | { |
434 | 376 | "cell_type": "code", |
435 | | - "execution_count": 16, |
| 377 | + "execution_count": null, |
436 | 378 | "metadata": {}, |
437 | 379 | "outputs": [], |
438 | 380 | "source": [ |
|
458 | 400 | }, |
459 | 401 | { |
460 | 402 | "cell_type": "code", |
461 | | - "execution_count": 17, |
| 403 | + "execution_count": null, |
462 | 404 | "metadata": {}, |
463 | 405 | "outputs": [], |
464 | 406 | "source": [ |
|
487 | 429 | }, |
488 | 430 | { |
489 | 431 | "cell_type": "code", |
490 | | - "execution_count": 18, |
| 432 | + "execution_count": null, |
491 | 433 | "metadata": {}, |
492 | 434 | "outputs": [], |
493 | 435 | "source": [ |
|
505 | 447 | }, |
506 | 448 | { |
507 | 449 | "cell_type": "code", |
508 | | - "execution_count": 19, |
| 450 | + "execution_count": null, |
509 | 451 | "metadata": {}, |
510 | 452 | "outputs": [], |
511 | 453 | "source": [ |
|
536 | 478 | }, |
537 | 479 | { |
538 | 480 | "cell_type": "code", |
539 | | - "execution_count": 20, |
| 481 | + "execution_count": null, |
540 | 482 | "metadata": {}, |
541 | 483 | "outputs": [], |
542 | 484 | "source": [ |
543 | 485 | "fun = HyperTorch(seed=123, log_level=50).fun_torch\n", |
544 | | - "weights = -1.0\n", |
| 486 | + "weights = 1.0\n", |
545 | 487 | "horizon = 7*24\n", |
546 | 488 | "oml_grace_period = 2\n", |
547 | 489 | "step = 100\n", |
|
587 | 529 | }, |
588 | 530 | { |
589 | 531 | "cell_type": "code", |
590 | | - "execution_count": 21, |
| 532 | + "execution_count": null, |
591 | 533 | "metadata": {}, |
592 | 534 | "outputs": [], |
593 | 535 | "source": [ |
|
602 | 544 | }, |
603 | 545 | { |
604 | 546 | "cell_type": "code", |
605 | | - "execution_count": 22, |
| 547 | + "execution_count": null, |
606 | 548 | "metadata": {}, |
607 | | - "outputs": [ |
608 | | - { |
609 | | - "name": "stdout", |
610 | | - "output_type": "stream", |
611 | | - "text": [ |
612 | | - "| name | type | default | lower | upper |\n", |
613 | | - "|------------|--------|-----------|---------|---------|\n", |
614 | | - "| l1 | int | 5 | 2 | 9 |\n", |
615 | | - "| l2 | int | 5 | 2 | 9 |\n", |
616 | | - "| lr | float | 0.001 | 0.0001 | 0.1 |\n", |
617 | | - "| batch_size | int | 4 | 1 | 4 |\n", |
618 | | - "| epochs | int | 3 | 1 | 4 |\n" |
619 | | - ] |
620 | | - } |
621 | | - ], |
| 549 | + "outputs": [], |
622 | 550 | "source": [ |
623 | 551 | "print(gen_design_table(fun_control))" |
624 | 552 | ] |
|
636 | 564 | }, |
637 | 565 | { |
638 | 566 | "cell_type": "code", |
639 | | - "execution_count": 23, |
| 567 | + "execution_count": null, |
640 | 568 | "metadata": {}, |
641 | | - "outputs": [ |
642 | | - { |
643 | | - "name": "stdout", |
644 | | - "output_type": "stream", |
645 | | - "text": [ |
646 | | - "Using mps device\n", |
647 | | - "[1, 2000] loss: 2.323\n", |
648 | | - "[1, 4000] loss: 1.159\n" |
649 | | - ] |
650 | | - } |
651 | | - ], |
| 569 | + "outputs": [], |
652 | 570 | "source": [ |
653 | 571 | "spot_torch = spot.Spot(fun=fun,\n", |
654 | 572 | " lower = lower,\n", |
|
0 commit comments