Skip to content

Commit 82ab8fe

Browse files
Update 11_spot_hpt_torch.ipynb
1 parent df6c246 commit 82ab8fe

1 file changed

Lines changed: 31 additions & 113 deletions

File tree

notebooks/11_spot_hpt_torch.ipynb

Lines changed: 31 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
},
1313
{
1414
"cell_type": "code",
15-
"execution_count": 1,
15+
"execution_count": null,
1616
"metadata": {},
1717
"outputs": [],
1818
"source": [
@@ -26,20 +26,9 @@
2626
},
2727
{
2828
"cell_type": "code",
29-
"execution_count": 2,
29+
"execution_count": null,
3030
"metadata": {},
31-
"outputs": [
32-
{
33-
"data": {
34-
"text/plain": [
35-
"'11-torch_p040025_360min_10init_2023-04-27_00-08-17'"
36-
]
37-
},
38-
"execution_count": 2,
39-
"metadata": {},
40-
"output_type": "execute_result"
41-
}
42-
],
31+
"outputs": [],
4332
"source": [
4433
"import pickle\n",
4534
"import socket\n",
@@ -72,26 +61,16 @@
7261
},
7362
{
7463
"cell_type": "code",
75-
"execution_count": 3,
64+
"execution_count": null,
7665
"metadata": {},
77-
"outputs": [
78-
{
79-
"name": "stdout",
80-
"output_type": "stream",
81-
"text": [
82-
"spotPython 0.0.44\n",
83-
"spotRiver 0.0.92\n",
84-
"Note: you may need to restart the kernel to use updated packages.\n"
85-
]
86-
}
87-
],
66+
"outputs": [],
8867
"source": [
8968
"pip list | grep \"spot[RiverPython]\""
9069
]
9170
},
9271
{
9372
"cell_type": "code",
94-
"execution_count": 4,
73+
"execution_count": null,
9574
"metadata": {},
9675
"outputs": [],
9776
"source": [
@@ -110,7 +89,7 @@
11089
},
11190
{
11291
"cell_type": "code",
113-
"execution_count": 5,
92+
"execution_count": null,
11493
"metadata": {},
11594
"outputs": [],
11695
"source": [
@@ -174,18 +153,9 @@
174153
},
175154
{
176155
"cell_type": "code",
177-
"execution_count": 6,
156+
"execution_count": null,
178157
"metadata": {},
179-
"outputs": [
180-
{
181-
"name": "stdout",
182-
"output_type": "stream",
183-
"text": [
184-
"2.0.0\n",
185-
"MPS device: mps\n"
186-
]
187-
}
188-
],
158+
"outputs": [],
189159
"source": [
190160
"print(torch.__version__)\n",
191161
"# Check that MPS is available\n",
@@ -212,7 +182,7 @@
212182
},
213183
{
214184
"cell_type": "code",
215-
"execution_count": 7,
185+
"execution_count": null,
216186
"metadata": {},
217187
"outputs": [],
218188
"source": [
@@ -237,7 +207,7 @@
237207
},
238208
{
239209
"cell_type": "code",
240-
"execution_count": 8,
210+
"execution_count": null,
241211
"metadata": {},
242212
"outputs": [],
243213
"source": [
@@ -266,7 +236,7 @@
266236
},
267237
{
268238
"cell_type": "code",
269-
"execution_count": 9,
239+
"execution_count": null,
270240
"metadata": {},
271241
"outputs": [],
272242
"source": [
@@ -287,7 +257,7 @@
287257
},
288258
{
289259
"cell_type": "code",
290-
"execution_count": 10,
260+
"execution_count": null,
291261
"metadata": {},
292262
"outputs": [],
293263
"source": [
@@ -323,7 +293,7 @@
323293
},
324294
{
325295
"cell_type": "code",
326-
"execution_count": 11,
296+
"execution_count": null,
327297
"metadata": {},
328298
"outputs": [],
329299
"source": [
@@ -341,7 +311,7 @@
341311
},
342312
{
343313
"cell_type": "code",
344-
"execution_count": 12,
314+
"execution_count": null,
345315
"metadata": {},
346316
"outputs": [],
347317
"source": [
@@ -362,54 +332,26 @@
362332
},
363333
{
364334
"cell_type": "code",
365-
"execution_count": 13,
335+
"execution_count": null,
366336
"metadata": {},
367-
"outputs": [
368-
{
369-
"name": "stdout",
370-
"output_type": "stream",
371-
"text": [
372-
"Files already downloaded and verified\n",
373-
"Files already downloaded and verified\n"
374-
]
375-
}
376-
],
337+
"outputs": [],
377338
"source": [
378339
"train, test = load_data()"
379340
]
380341
},
381342
{
382343
"cell_type": "code",
383-
"execution_count": 14,
344+
"execution_count": null,
384345
"metadata": {},
385-
"outputs": [
386-
{
387-
"data": {
388-
"text/plain": [
389-
"Dataset CIFAR10\n",
390-
" Number of datapoints: 50000\n",
391-
" Root location: ./data\n",
392-
" Split: Train\n",
393-
" StandardTransform\n",
394-
"Transform: Compose(\n",
395-
" ToTensor()\n",
396-
" Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))\n",
397-
" )"
398-
]
399-
},
400-
"execution_count": 14,
401-
"metadata": {},
402-
"output_type": "execute_result"
403-
}
404-
],
346+
"outputs": [],
405347
"source": [
406348
"train.data.shape, test.data.shape\n",
407349
"train"
408350
]
409351
},
410352
{
411353
"cell_type": "code",
412-
"execution_count": 15,
354+
"execution_count": null,
413355
"metadata": {},
414356
"outputs": [],
415357
"source": [
@@ -432,7 +374,7 @@
432374
},
433375
{
434376
"cell_type": "code",
435-
"execution_count": 16,
377+
"execution_count": null,
436378
"metadata": {},
437379
"outputs": [],
438380
"source": [
@@ -458,7 +400,7 @@
458400
},
459401
{
460402
"cell_type": "code",
461-
"execution_count": 17,
403+
"execution_count": null,
462404
"metadata": {},
463405
"outputs": [],
464406
"source": [
@@ -487,7 +429,7 @@
487429
},
488430
{
489431
"cell_type": "code",
490-
"execution_count": 18,
432+
"execution_count": null,
491433
"metadata": {},
492434
"outputs": [],
493435
"source": [
@@ -505,7 +447,7 @@
505447
},
506448
{
507449
"cell_type": "code",
508-
"execution_count": 19,
450+
"execution_count": null,
509451
"metadata": {},
510452
"outputs": [],
511453
"source": [
@@ -536,12 +478,12 @@
536478
},
537479
{
538480
"cell_type": "code",
539-
"execution_count": 20,
481+
"execution_count": null,
540482
"metadata": {},
541483
"outputs": [],
542484
"source": [
543485
"fun = HyperTorch(seed=123, log_level=50).fun_torch\n",
544-
"weights = -1.0\n",
486+
"weights = 1.0\n",
545487
"horizon = 7*24\n",
546488
"oml_grace_period = 2\n",
547489
"step = 100\n",
@@ -587,7 +529,7 @@
587529
},
588530
{
589531
"cell_type": "code",
590-
"execution_count": 21,
532+
"execution_count": null,
591533
"metadata": {},
592534
"outputs": [],
593535
"source": [
@@ -602,23 +544,9 @@
602544
},
603545
{
604546
"cell_type": "code",
605-
"execution_count": 22,
547+
"execution_count": null,
606548
"metadata": {},
607-
"outputs": [
608-
{
609-
"name": "stdout",
610-
"output_type": "stream",
611-
"text": [
612-
"| name | type | default | lower | upper |\n",
613-
"|------------|--------|-----------|---------|---------|\n",
614-
"| l1 | int | 5 | 2 | 9 |\n",
615-
"| l2 | int | 5 | 2 | 9 |\n",
616-
"| lr | float | 0.001 | 0.0001 | 0.1 |\n",
617-
"| batch_size | int | 4 | 1 | 4 |\n",
618-
"| epochs | int | 3 | 1 | 4 |\n"
619-
]
620-
}
621-
],
549+
"outputs": [],
622550
"source": [
623551
"print(gen_design_table(fun_control))"
624552
]
@@ -636,19 +564,9 @@
636564
},
637565
{
638566
"cell_type": "code",
639-
"execution_count": 23,
567+
"execution_count": null,
640568
"metadata": {},
641-
"outputs": [
642-
{
643-
"name": "stdout",
644-
"output_type": "stream",
645-
"text": [
646-
"Using mps device\n",
647-
"[1, 2000] loss: 2.323\n",
648-
"[1, 4000] loss: 1.159\n"
649-
]
650-
}
651-
],
569+
"outputs": [],
652570
"source": [
653571
"spot_torch = spot.Spot(fun=fun,\n",
654572
" lower = lower,\n",

0 commit comments

Comments
 (0)