1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14-
1514import os
16- from collections import OrderedDict
15+ import shutil
1716
1817import pytest
1918import torch
20- import transformers
19+ from transformers import AutoConfig , AutoModelForCausalLM
2120
22- from huggingface_hub import snapshot_download
21+ from accelerate import init_empty_weights
2322from sparseml .transformers .utils .helpers import (
2423 create_fake_dataloader ,
2524 infer_recipe_from_model_path ,
3231
3332
@pytest.fixture()
def generative_model():
    """
    Provide the identifier of the generative model used by the tests.

    :return: model identifier string, suitable for ``from_pretrained`` calls
    """
    model_id = "roneneldan/TinyStories-1M"
    return model_id
3736
3837
@pytest.fixture()
def bert_model():
    """
    Provide the zoo stub of the BERT question-answering model used by the tests.

    :return: the ``zoo:`` stub string
    """
    # Split across two literals only to respect the line-length limit; the
    # concatenated value is the unchanged stub.
    return (
        "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/"
        "pruned95_obs_quant-none"
    )
4541
4642
@pytest.fixture()
def sequence_length():
    """
    Provide the sequence length shared by the tests in this module.

    :return: the sequence length as an int
    """
    tokens_per_sample = 320
    return tokens_per_sample
5046
5147
52- @pytest .fixture ()
53- def dummy_inputs ():
54- input_ids = torch .zeros ((1 , 32 ), dtype = torch .int64 )
55- attention_mask = torch .ones ((1 , 32 ), dtype = torch .int64 )
def test_create_fake_dataloader(generative_model, sequence_length):
    """
    Check that ``create_fake_dataloader`` produces the expected dummy samples.

    The model is instantiated from its config inside ``init_empty_weights()``
    rather than through ``from_pretrained``.

    :param generative_model: fixture providing the model identifier
    :param sequence_length: fixture providing the sequence length that is
        passed to the tokenizer and expected in every sample
    """
    config = AutoConfig.from_pretrained(generative_model)
    tokenizer = initialize_tokenizer(
        generative_model, sequence_length=sequence_length, task="text-generation"
    )
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    expected_input_names = ["input_ids", "attention_mask"]
    num_samples = 2
    data_loader, input_names = create_fake_dataloader(
        model=model,
        tokenizer=tokenizer,
        num_samples=num_samples,
    )

    assert input_names == expected_input_names

    # Materialize the loader first: checking a loop index after the loop would
    # raise NameError (unbound variable) for an empty loader instead of
    # producing a readable assertion failure.
    samples = list(data_loader)
    assert len(samples) == num_samples

    expected_shape = torch.Size([1, sequence_length])
    for sample in samples:
        assert set(sample.keys()) == set(expected_input_names)
        assert sample["input_ids"].shape == expected_shape
        assert sample["attention_mask"].shape == expected_shape
6270
63- @pytest .mark .parametrize (
64- "stub" ,
65- [
66- "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned95_obs_quant-none" , # noqa E501
67- ],
68- )
69- def test_is_transformer_model (tmp_path , stub ):
70- zoo_model = Model (stub , tmp_path )
71+
def test_is_transformer_model(tmp_path, bert_model):
    """
    Check that the training directory of the zoo model is recognized by
    ``is_transformer_model``.

    :param tmp_path: pytest temporary directory the zoo model is placed in
    :param bert_model: fixture providing the zoo stub of the model
    """
    try:
        zoo_model = Model(bert_model, tmp_path)
        source_path = zoo_model.training.path
        assert is_transformer_model(source_path)
    finally:
        # Clean up in ``finally`` so the model files are removed even when the
        # assertion (or the download) fails; ``ignore_errors`` keeps a cleanup
        # problem from masking the real test failure.
        shutil.rmtree(tmp_path, ignore_errors=True)
11377
11478
11579def test_infer_recipe_from_local_model_path (tmp_path ):
@@ -124,6 +88,16 @@ def test_infer_recipe_from_local_model_path(tmp_path):
12488 assert recipe == recipe_path .as_posix ()
12589
12690
# autouse=True: pytest applies this fixture to every test in its scope, not
# only to the tests that request it by name.
@pytest.fixture(autouse=True)
def model_path_and_recipe_path(tmp_path):
    """
    Create an empty ``model.onnx`` and an empty ``recipe.yaml`` in ``tmp_path``.

    :param tmp_path: pytest temporary directory that receives both files
    :return: tuple ``(model_path, recipe_path)`` pointing at the created files
    """
    onnx_file = tmp_path.joinpath("model.onnx")
    recipe_file = tmp_path.joinpath("recipe.yaml")
    for empty_file in (recipe_file, onnx_file):
        empty_file.touch()

    return onnx_file, recipe_file
99+
100+
127101@pytest .mark .parametrize (
128102 "model_path" ,
129103 [
@@ -140,16 +114,6 @@ def test_resolve_recipe_file(model_path, model_path_and_recipe_path):
140114 )
141115
142116
143- @pytest .fixture ()
144- def model_path_and_recipe_path (tmp_path ):
145- model_path = tmp_path / "model.onnx"
146- recipe_path = tmp_path / "recipe.yaml"
147- recipe_path .touch ()
148- model_path .touch ()
149-
150- return model_path , recipe_path
151-
152-
153117def test_resolve_recipe_file_from_local_path (model_path_and_recipe_path ):
154118 model_path , recipe_path = model_path_and_recipe_path
155119 assert recipe_path .as_posix () == resolve_recipe_file (
@@ -165,24 +129,40 @@ def test_resolve_recipe_file_from_local_path(model_path_and_recipe_path):
165129 )
166130
167131
168- def test_create_fake_dataloader (generative_model_path , sequence_length ):
169- expected_input_names = ["input_ids" , "attention_mask" ]
170- sequence_length = 32
171- num_samples = 2
@pytest.mark.parametrize(
    "model, recipe_found",
    [
        ("roneneldan/TinyStories-1M", False),
        ("mgoin/all-MiniLM-L6-v2-quant-ds", True),
        (
            "zoo:mobilebert-squad_wikipedia_bookcorpus-14layer_pruned50.4block_quantized",  # noqa E501
            True,
        ),
    ],
)
def test_infer_recipe_from_model_path(model, recipe_found):
    """
    Check the result type of ``infer_recipe_from_model_path`` for each model.

    :param model: model identifier or zoo stub handed to the helper
    :param recipe_found: whether a recipe is expected (a ``str`` result) or
        not (a ``None`` result)
    """
    inferred = infer_recipe_from_model_path(model)
    if not recipe_found:
        assert inferred is None
    else:
        assert isinstance(inferred, str)
172149
173- model = transformers .AutoModelForCausalLM .from_pretrained (generative_model_path )
174- tokenizer = initialize_tokenizer (
175- generative_model_path , sequence_length = sequence_length , task = "text-generation"
176- )
177- data_loader , input_names = create_fake_dataloader (
178- model = model ,
179- tokenizer = tokenizer ,
180- num_samples = num_samples ,
181- )
182150
183- assert input_names == expected_input_names
184- for i , sample in enumerate (data_loader ):
185- assert sample ["input_ids" ].shape == torch .Size ([1 , sequence_length ])
186- assert sample ["attention_mask" ].shape == torch .Size ([1 , sequence_length ])
187- assert set (sample .keys ()) == set (expected_input_names )
188- assert i == num_samples - 1
@pytest.mark.parametrize(
    "stub",
    [
        "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned95_obs_quant-none",  # noqa E501
    ],
)
def test_save_zoo_directory(tmp_path, stub):
    """
    Check that ``save_zoo_directory`` writes a directory that loads as a valid
    zoo ``Model``.

    :param tmp_path: pytest temporary directory used as the output directory
    :param stub: zoo stub of the model whose path serves as the training
        outputs directory
    """
    path_to_training_outputs = Model(stub).path
    save_dir = tmp_path

    try:
        save_zoo_directory(
            output_dir=save_dir,
            training_outputs_dir=path_to_training_outputs,
        )
        zoo_model = Model(str(save_dir))
        assert zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
    finally:
        # Clean up in ``finally`` so both directories are removed even when
        # saving or validation fails; ``ignore_errors`` keeps a cleanup
        # problem from masking the real test failure.
        shutil.rmtree(path_to_training_outputs, ignore_errors=True)
        shutil.rmtree(save_dir, ignore_errors=True)
0 commit comments