Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 56b27f6

Browse files
dbogunowiczbogunowicz@arrival.combfineranBenjamin
authored
[SparseZoo Integration] update sparsezoo usage for 1.1 refactor (#955)
* [SparseZoo v2 Bridge] Save IC training artifacts to ModelDirectory directory (#864) * initial commit * remove rubbish file * Update train.py * refactor export samples * change name of folders * working, time for refactoring * ready for review * Delete hehe.py * Update helpers.py * Apply suggestions from code review Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com> Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com> * first round of edits * first sweep through all files * checking if pytorch checks pass * second sweep through all files * third sweep through all files * Update README.md * Update README.md * Update README.md * fourth sweep through all files * fifth sweep through all files * Update integrations/keras/README.md * rolling back some edits after bens comment * remove notebook output * add back onnx dataloading + batching functionality + tests * pytorch logic, orgainization, completeness fixes * use final checkpoint in eval mode IC training Co-authored-by: bogunowicz@arrival.com <bogunowicz@arrival.com> Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com> Co-authored-by: Benjamin <ben@neuralmagic.com>
1 parent a9d0585 commit 56b27f6

File tree

46 files changed

+392
-608
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+392
-608
lines changed

integrations/keras/README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,18 @@ Complete lists are available online for all [models](https://sparsezoo.neuralmag
5757

5858
Sample code for retrieving a model from the SparseZoo:
5959
```python
60-
from sparsezoo import Zoo
60+
from sparsezoo import Model
6161

62-
model = Zoo.load_model_from_stub("zoo:cv/classification/resnet_v1-50/keras/sparseml/imagenet/pruned-moderate")
62+
model = Model("zoo:cv/classification/resnet_v1-50/keras/sparseml/imagenet/pruned-moderate")
6363
print(model)
6464
```
6565

6666
Sample code for retrieving a recipe from the SparseZoo:
6767
```python
68-
from sparsezoo import Zoo
68+
from sparsezoo import Model
6969

70-
recipe = Zoo.load_recipe_from_stub("zoo:cv/classification/resnet_v1-50/keras/sparseml/imagenet/pruned-conservative/original")
70+
model = Model("zoo:cv/classification/resnet_v1-50/keras/sparseml/imagenet/pruned-conservative")
71+
recipe = model.recipes.default
7172
print(recipe)
7273
```
7374

integrations/keras/notebooks/classification.ipynb

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,26 @@
5555
},
5656
{
5757
"cell_type": "code",
58-
"execution_count": null,
59-
"metadata": {},
60-
"outputs": [],
58+
"execution_count": 2,
59+
"metadata": {},
60+
"outputs": [
61+
{
62+
"ename": "NameError",
63+
"evalue": "name 'Zoo' is not defined",
64+
"output_type": "error",
65+
"traceback": [
66+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
67+
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
68+
"Input \u001b[0;32mIn [2]\u001b[0m, in \u001b[0;36m<cell line: 37>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRecipe file not found: \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(recipe_file_path))\n\u001b[1;32m 35\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m model_file_path, recipe_file_path\n\u001b[0;32m---> 37\u001b[0m model_file_path, recipe_file_path \u001b[38;5;241m=\u001b[39m \u001b[43mdownload_model_and_recipe\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot_dir\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mLoading model \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(model_file_path))\n\u001b[1;32m 40\u001b[0m model \u001b[38;5;241m=\u001b[39m keras\u001b[38;5;241m.\u001b[39mmodels\u001b[38;5;241m.\u001b[39mload_model(model_file_path)\n",
69+
"Input \u001b[0;32mIn [2]\u001b[0m, in \u001b[0;36mdownload_model_and_recipe\u001b[0;34m(root_dir)\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;124;03mDownload pretrained model and a pruning recipe\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 12\u001b[0m model_dir \u001b[38;5;241m=\u001b[39m os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(root_dir, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmnist\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 13\u001b[0m zoo_model \u001b[38;5;241m=\u001b[39m \u001b[43mZoo\u001b[49m\u001b[38;5;241m.\u001b[39mload_model(\n\u001b[1;32m 14\u001b[0m domain\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcv\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 15\u001b[0m sub_domain\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mclassification\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 16\u001b[0m architecture\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmnist\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 17\u001b[0m sub_architecture\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 18\u001b[0m framework\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mkeras\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 19\u001b[0m repo\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msparseml\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 20\u001b[0m dataset\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmnist\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 21\u001b[0m training_scheme\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 22\u001b[0m sparse_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpruned\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 23\u001b[0m sparse_category\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mconservative\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[1;32m 24\u001b[0m sparse_target\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 25\u001b[0m override_parent_path\u001b[38;5;241m=\u001b[39mmodel_dir,\n\u001b[1;32m 26\u001b[0m )\n\u001b[1;32m 27\u001b[0m zoo_model\u001b[38;5;241m.\u001b[39mdownload()\n\u001b[1;32m 29\u001b[0m model_file_path \u001b[38;5;241m=\u001b[39m zoo_model\u001b[38;5;241m.\u001b[39mframework_files[\u001b[38;5;241m0\u001b[39m]\u001b[38;5;241m.\u001b[39mdownloaded_path()\n",
70+
"\u001b[0;31mNameError\u001b[0m: name 'Zoo' is not defined"
71+
]
72+
}
73+
],
6174
"source": [
6275
"import os\n",
6376
"from sparseml.keras.utils import keras\n",
64-
"from sparsezoo.models import Zoo\n",
77+
"from sparsezoo.models import Model\n",
6578
"\n",
6679
"# Root directory for the notebook artifacts\n",
6780
"root_dir = \"./notebooks/keras\"\n",
@@ -71,26 +84,13 @@
7184
" Download pretrained model and a pruning recipe\n",
7285
" \"\"\"\n",
7386
" model_dir = os.path.join(root_dir, \"mnist\")\n",
74-
" zoo_model = Zoo.load_model(\n",
75-
" domain=\"cv\",\n",
76-
" sub_domain=\"classification\",\n",
77-
" architecture=\"mnist\",\n",
78-
" sub_architecture=None,\n",
79-
" framework=\"keras\",\n",
80-
" repo=\"sparseml\",\n",
81-
" dataset=\"mnist\",\n",
82-
" training_scheme=None,\n",
83-
" sparse_name=\"pruned\",\n",
84-
" sparse_category=\"conservative\",\n",
85-
" sparse_target=None,\n",
86-
" override_parent_path=model_dir,\n",
87-
" )\n",
88-
" zoo_model.download()\n",
89-
"\n",
90-
" model_file_path = zoo_model.framework_files[0].downloaded_path()\n",
87+
" zoo_model = Model(...)\n",
88+
"\n",
89+
"\n",
90+
" model_file_path = zoo_model.training.default.get_file(\"model.h5\").path\n",
9191
" if not os.path.exists(model_file_path) or not model_file_path.endswith(\".h5\"):\n",
9292
" raise RuntimeError(\"Model file not found: {}\".format(model_file_path))\n",
93-
" recipe_file_path = zoo_model.recipes[0].downloaded_path()\n",
93+
" recipe_file_path = zoo_model.recipes.default.path\n",
9494
" if not os.path.exists(recipe_file_path):\n",
9595
" raise RuntimeError(\"Recipe file not found: {}\".format(recipe_file_path))\n",
9696
" return model_file_path, recipe_file_path\n",
@@ -424,9 +424,9 @@
424424
],
425425
"metadata": {
426426
"kernelspec": {
427-
"display_name": "Python (keras_pruning)",
427+
"display_name": "Python 3 (ipykernel)",
428428
"language": "python",
429-
"name": "keras_pruning"
429+
"name": "python3"
430430
},
431431
"language_info": {
432432
"codemirror_mode": {
@@ -438,7 +438,7 @@
438438
"name": "python",
439439
"nbconvert_exporter": "python",
440440
"pygments_lexer": "ipython3",
441-
"version": "3.6.9"
441+
"version": "3.8.10"
442442
}
443443
},
444444
"nbformat": 4,

integrations/keras/prune_resnet20.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from sparseml.keras.utils.callbacks import LossesAndMetricsLoggingCallback
3838
from sparseml.keras.utils.exporter import ModelExporter
3939
from sparseml.keras.utils.logger import TensorBoardLogger
40-
from sparsezoo.models import Zoo
40+
from sparsezoo import Model
4141

4242

4343
# Root directory
@@ -64,33 +64,21 @@ def download_model_and_recipe(root_dir: str):
6464
"""
6565
Download pretrained model and a pruning recipe
6666
"""
67-
model_dir = os.path.join(root_dir, "resnet20_v1")
6867

69-
# Load base model to prune
70-
base_zoo_model = Zoo.load_model(
71-
domain="cv",
72-
sub_domain="classification",
73-
architecture="resnet_v1",
74-
sub_architecture=20,
75-
framework="keras",
76-
repo="sparseml",
77-
dataset="cifar_10",
78-
training_scheme=None,
79-
sparse_name="base",
80-
sparse_category="none",
81-
sparse_target=None,
82-
override_parent_path=model_dir,
83-
)
84-
base_zoo_model.download()
85-
model_file_path = base_zoo_model.framework_files[0].downloaded_path()
86-
if not os.path.exists(model_file_path) or not model_file_path.endswith(".h5"):
87-
raise RuntimeError("Model file not found: {}".format(model_file_path))
88-
89-
# Simply use the recipe stub
68+
# Use the recipe stub
9069
recipe_file_path = (
9170
"zoo:cv/classification/resnet_v1-20/keras/sparseml/cifar_10/pruned-conservative"
9271
)
9372

73+
# Load base model to prune
74+
base_zoo_model = Model(recipe_file_path)
75+
base_zoo_model.path = os.path.join(root_dir, "resnet20_v1")
76+
checkpoint = base_zoo_model.training.default
77+
model_file_path = checkpoint.get_file("model.h5").path
78+
recipe_file_path = base_zoo_model.recipes.default.path
79+
if not os.path.exists(model_file_path) or not model_file_path.endswith(".h5"):
80+
raise RuntimeError("Model file not found: {}".format(model_file_path))
81+
9482
return model_file_path, recipe_file_path
9583

9684

integrations/pytorch/README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,17 +60,18 @@ Complete lists are available online for all [models](https://sparsezoo.neuralmag
6060

6161
Sample code for retrieving a model from the SparseZoo:
6262
```python
63-
from sparsezoo import Zoo
63+
from sparsezoo import Model
6464

65-
model = Zoo.load_model_from_stub("zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate")
65+
model = Model("zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate")
6666
print(model)
6767
```
6868

6969
Sample code for retrieving a recipe from the SparseZoo:
7070
```python
71-
from sparsezoo import Zoo
71+
from sparsezoo import Model
7272

73-
recipe = Zoo.load_recipe_from_stub("zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate/original")
73+
model = Model("zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate")
74+
recipe = model.recipes.default
7475
print(recipe)
7576
```
7677

integrations/pytorch/notebooks/classification.ipynb

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,9 +219,9 @@
219219
"metadata": {},
220220
"outputs": [],
221221
"source": [
222-
"from sparsezoo import Zoo\n",
222+
"from sparsezoo import Model, search_models\n",
223223
"\n",
224-
"recipe = Zoo.search_recipes(\n",
224+
"zoo_model = search_models(\n",
225225
" domain=\"cv\",\n",
226226
" sub_domain=\"classification\",\n",
227227
" architecture=\"resnet_v1\",\n",
@@ -231,8 +231,7 @@
231231
" dataset=\"imagenette\",\n",
232232
" sparse_name=\"pruned\",\n",
233233
")[0] # unwrap search result\n",
234-
"recipe.download()\n",
235-
"recipe_path = recipe.downloaded_path()\n",
234+
"recipe_path = zoo_model.recipes.default.path\n",
236235
"print(f\"Recipe downloaded to: {recipe_path}\")"
237236
]
238237
},
@@ -364,4 +363,4 @@
364363
},
365364
"nbformat": 4,
366365
"nbformat_minor": 4
367-
}
366+
}

integrations/pytorch/notebooks/detection.ipynb

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,9 @@
238238
"metadata": {},
239239
"outputs": [],
240240
"source": [
241-
"from sparsezoo import Zoo\n",
241+
"from sparsezoo import Model, search_models\n",
242242
"\n",
243-
"recipe = Zoo.search_recipes(\n",
243+
"zoo_model = search_models(\n",
244244
" domain=\"cv\",\n",
245245
" sub_domain=\"detection\",\n",
246246
" architecture=\"ssd\",\n",
@@ -250,8 +250,7 @@
250250
" dataset=\"voc\",\n",
251251
" sparse_name=\"pruned\",\n",
252252
")[0] # unwrap search result\n",
253-
"recipe.download()\n",
254-
"recipe_path = recipe.downloaded_path()\n",
253+
"recipe_path = zoo_model.recipes.default.path\n",
255254
"print(f\"Recipe downloaded to: {recipe_path}\")"
256255
]
257256
},
@@ -370,4 +369,4 @@
370369
},
371370
"nbformat": 4,
372371
"nbformat_minor": 4
373-
}
372+
}

integrations/pytorch/notebooks/sparse_quantized_transfer_learning.ipynb

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
"source": [
7777
"from sparseml.pytorch.models import ModelRegistry\n",
7878
"from sparseml.pytorch.datasets import ImagenetteDataset, ImagenetteSize\n",
79-
"from sparsezoo import Zoo\n",
79+
"from sparsezoo import Model\n",
8080
"\n",
8181
"#######################################################\n",
8282
"# Define your model below\n",
@@ -85,7 +85,7 @@
8585
"# SparseZoo stub to pre-trained sparse-quantized ResNet-50 for imagenet dataset\n",
8686
"zoo_stub_path = (\n",
8787
" \"zoo:cv/classification/resnet_v1-50/pytorch/sparseml/imagenet/pruned_quant-moderate\"\n",
88-
" \"?recipe_type=transfer_learn\"\n",
88+
" \"?recipe=transfer_learn\"\n",
8989
")\n",
9090
"model = ModelRegistry.create(\n",
9191
" key=\"resnet50\",\n",
@@ -236,9 +236,10 @@
236236
"metadata": {},
237237
"outputs": [],
238238
"source": [
239-
"from sparsezoo import Zoo\n",
239+
"from sparsezoo import Model\n",
240240
"\n",
241-
"recipe_path = Zoo.download_recipe_from_stub(zoo_stub_path)\n",
241+
"zoo_model = Model(zoo_stub_path)\n",
242+
"recipe_path = zoo_model.recipes.default.path\n",
242243
"print(f\"Recipe downloaded to: {recipe_path}\")"
243244
]
244245
},
@@ -450,4 +451,4 @@
450451
},
451452
"nbformat": 4,
452453
"nbformat_minor": 4
453-
}
454+
}

integrations/pytorch/notebooks/torchvision.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,9 @@
270270
"metadata": {},
271271
"outputs": [],
272272
"source": [
273-
"from sparsezoo import Zoo\n",
273+
"from sparsezoo import Model, search_models\n",
274274
"\n",
275-
"recipe = Zoo.search_recipes(\n",
275+
"zoo_model = search_models(\n",
276276
" domain=\"cv\",\n",
277277
" sub_domain=\"classification\",\n",
278278
" architecture=\"resnet_v1\",\n",
@@ -282,8 +282,8 @@
282282
" dataset=\"imagenette\",\n",
283283
" sparse_name=\"pruned\",\n",
284284
")[0] # unwrap search result\n",
285-
"recipe.download()\n",
286-
"recipe_path = recipe.downloaded_path()\n",
285+
"\n",
286+
"recipe_path = zoo_model.recipes.default.path\n",
287287
"print(f\"Recipe downloaded to: {recipe_path}\")"
288288
]
289289
},
@@ -396,4 +396,4 @@
396396
},
397397
"nbformat": 4,
398398
"nbformat_minor": 4
399-
}
399+
}

integrations/pytorch/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
from torch.utils.data import DataLoader
2727

2828
from sparseml.pytorch.datasets import DatasetRegistry, ssd_collate_fn, yolo_collate_fn
29+
from sparseml.pytorch.image_classification.utils.helpers import (
30+
download_framework_model_by_recipe_type,
31+
)
2932
from sparseml.pytorch.models import ModelRegistry
3033
from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
3134
from sparseml.pytorch.sparsification import ConstantPruningModifier
@@ -45,7 +48,7 @@
4548
torch_distributed_zero_first,
4649
)
4750
from sparseml.utils import create_dirs
48-
from sparsezoo import Zoo
51+
from sparsezoo import Model
4952

5053

5154
@unique
@@ -237,9 +240,10 @@ def create_model(args: Any, num_classes: int) -> Module:
237240
with torch_distributed_zero_first(args.local_rank): # only download once locally
238241
if args.checkpoint_path == "zoo":
239242
if args.recipe_path and args.recipe_path.startswith("zoo:"):
240-
args.checkpoint_path = Zoo.download_recipe_base_framework_files(
241-
args.recipe_path, extensions=[".pth"]
242-
)[0]
243+
zoo_model = Model(args.recipe_path)
244+
args.checkpoint_path = download_framework_model_by_recipe_type(
245+
zoo_model
246+
)
243247
else:
244248
raise ValueError(
245249
"'zoo' provided as --checkpoint-path but a SparseZoo stub"

integrations/tensorflow_v1/README.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,18 @@ Complete lists are available online for all [models](https://sparsezoo.neuralmag
5757

5858
Sample code for retrieving a model from the SparseZoo:
5959
```python
60-
from sparsezoo import Zoo
60+
from sparsezoo import Model
6161

62-
model = Zoo.load_model_from_stub("zoo:cv/classification/resnet_v1-50/tensorflow_v1/sparseml/imagenette/pruned-moderate")
62+
model = Model("zoo:cv/classification/resnet_v1-50/tensorflow_v1/sparseml/imagenette/pruned-moderate")
6363
print(model)
6464
```
6565

6666
Sample code for retrieving a recipe from the SparseZoo:
6767
```python
68-
from sparsezoo import Zoo
68+
from sparsezoo import Model
6969

70-
recipe = Zoo.load_recipe_from_stub("zoo:cv/classification/resnet_v1-50/tensorflow_v1/sparseml/imagenette/pruned-moderate/original")
70+
model = Model("zoo:cv/classification/resnet_v1-50/tensorflow_v1/sparseml/imagenette/pruned-moderate")
71+
recipe = model.recipes.default
7172
print(recipe)
7273
```
7374

0 commit comments

Comments
 (0)