This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit d64668d

dbogunowicz and Sara Adkins authored
[Cherry-Pick] Fix GHA transformers errors (#2176)
* [Export API Refactor][Fix] Varying sample and model input names on correctness validation (#2131)
  * initial commit
  * Apply suggestions from code review
  * quality
  * solving the actual problem
* Fix GHA transformer errors (#2175)
  * initial commit
  * add test_helpers
  * revert
  * fix lm_head edge case
  * Remove leftover print

Co-authored-by: Sara Adkins <sara@neuralmagic.com>
1 parent 7d2f7ce commit d64668d

2 files changed: +82 -97 lines changed

src/sparseml/modifiers/utils/layer_compressor.py

Lines changed: 6 additions & 1 deletion
@@ -122,7 +122,12 @@ def revert_layer_wrappers(self):
         Reverts wrapped root modules back to their original structure
         """
         for name, module_wrapper in self.modules.items():
-            set_layer(name, module_wrapper.layer, self.layer)
+            full_name = self._get_full_submodule_name(name)
+            if len(name) == 0:  # special case if layer has no children (i.e. lm_head)
+                with summon_full_params_context(self.model):
+                    set_layer(full_name, module_wrapper.layer, self.model)
+            else:
+                set_layer(name, module_wrapper.layer, self.layer)
             module_wrapper.free()
         self.modules = None

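Why the special case: when the module being compressed is itself a leaf (e.g. lm_head), it has no child submodules, so the wrapper is registered under the empty name and set_layer has no valid target relative to the layer itself; the fix resolves the full dotted path against the root model instead. The sketch below illustrates just that failure mode and fix. It uses a hypothetical stand-in for SparseML's set_layer, and omits summon_full_params_context (which, per the diff, wraps the assignment so sharded parameters are materialized); none of this is the library's actual implementation.

import torch


def set_layer(name: str, layer: torch.nn.Module, model: torch.nn.Module):
    # Hypothetical stand-in for SparseML's set_layer: swap in `layer` at the
    # dotted submodule path `name` inside `model`.
    parent_name, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_name) if parent_name else model
    setattr(parent, child_name, layer)


model = torch.nn.ModuleDict({"lm_head": torch.nn.Linear(8, 8)})
layer = model["lm_head"]          # the "layer" being compressed is a leaf
wrapped = torch.nn.Linear(8, 8)   # stand-in for module_wrapper.layer

# Relative to `layer`, a leaf has no children, so the wrapper name is "" and
# set_layer("", wrapped, layer) has nothing to assign to. Resolving the full
# name against the root model restores the module correctly:
full_name = "lm_head"  # assumed result of _get_full_submodule_name("") for this layout
set_layer(full_name, wrapped, model)
assert model["lm_head"] is wrapped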
tests/sparseml/transformers/utils/test_helpers.py

Lines changed: 76 additions & 96 deletions
@@ -11,15 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import os
-from collections import OrderedDict
+import shutil
 
 import pytest
 import torch
-import transformers
+from transformers import AutoConfig, AutoModelForCausalLM
 
-from huggingface_hub import snapshot_download
+from accelerate import init_empty_weights
 from sparseml.transformers.utils.helpers import (
     create_fake_dataloader,
     infer_recipe_from_model_path,
@@ -32,84 +31,49 @@
 
 
 @pytest.fixture()
-def generative_model_path(tmp_path):
-    return snapshot_download("roneneldan/TinyStories-1M", local_dir=tmp_path)
+def generative_model():
+    return "roneneldan/TinyStories-1M"
 
 
 @pytest.fixture()
-def model_path(tmp_path):
-    return Model(
-        "zoo:mobilebert-squad_wikipedia_bookcorpus-14layer_pruned50.4block_quantized",
-        tmp_path,
-    ).training.path
+def bert_model():
+    return "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned95_obs_quant-none"  # noqa E501
 
 
 @pytest.fixture()
 def sequence_length():
-    return 384
+    return 320
 
 
-@pytest.fixture()
-def dummy_inputs():
-    input_ids = torch.zeros((1, 32), dtype=torch.int64)
-    attention_mask = torch.ones((1, 32), dtype=torch.int64)
+def test_create_fake_dataloader(generative_model, sequence_length):
+    config = AutoConfig.from_pretrained(generative_model)
+    tokenizer = initialize_tokenizer(
+        generative_model, sequence_length=sequence_length, task="text-generation"
+    )
+    with init_empty_weights():
+        model = AutoModelForCausalLM.from_config(config)
 
-    return OrderedDict(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
+    expected_input_names = ["input_ids", "attention_mask"]
+    num_samples = 2
+    data_loader, input_names = create_fake_dataloader(
+        model=model,
+        tokenizer=tokenizer,
+        num_samples=num_samples,
     )
 
+    assert input_names == expected_input_names
+    for i, sample in enumerate(data_loader):
+        assert sample["input_ids"].shape == torch.Size([1, sequence_length])
+        assert sample["attention_mask"].shape == torch.Size([1, sequence_length])
+        assert set(sample.keys()) == set(expected_input_names)
+    assert i == num_samples - 1
 
-@pytest.mark.parametrize(
-    "stub",
-    [
-        "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned95_obs_quant-none",  # noqa E501
-    ],
-)
-def test_is_transformer_model(tmp_path, stub):
-    zoo_model = Model(stub, tmp_path)
+
+def test_is_transformer_model(tmp_path, bert_model):
+    zoo_model = Model(bert_model, tmp_path)
     source_path = zoo_model.training.path
     assert is_transformer_model(source_path)
-
-
-@pytest.mark.parametrize(
-    "stub",
-    [
-        "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned95_obs_quant-none",  # noqa E501
-    ],
-)
-def test_save_zoo_directory(stub, tmp_path_factory):
-    path_to_training_outputs = tmp_path_factory.mktemp("outputs")
-    save_dir = tmp_path_factory.mktemp("save_dir")
-
-    zoo_model = Model(stub, path_to_training_outputs)
-    zoo_model.download()
-
-    save_zoo_directory(
-        output_dir=save_dir,
-        training_outputs_dir=path_to_training_outputs,
-    )
-    new_zoo_model = Model(str(save_dir))
-    assert new_zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
-
-
-@pytest.mark.parametrize(
-    "model_path, recipe_found",
-    [
-        ("roneneldan/TinyStories-1M", False),
-        ("mgoin/all-MiniLM-L6-v2-quant-ds", True),
-        (
-            "zoo:mobilebert-squad_wikipedia_bookcorpus-14layer_pruned50.4block_quantized",  # noqa E501
-            True,
-        ),
-    ],
-)
-def test_infer_recipe_from_model_path(model_path, recipe_found):
-    recipe = infer_recipe_from_model_path(model_path)
-    if recipe_found:
-        assert isinstance(recipe, str)
-        return
-    assert recipe is None
+    shutil.rmtree(tmp_path)
 
 
 def test_infer_recipe_from_local_model_path(tmp_path):
@@ -124,6 +88,16 @@ def test_infer_recipe_from_local_model_path(tmp_path):
     assert recipe == recipe_path.as_posix()
 
 
+@pytest.fixture(autouse=True)
+def model_path_and_recipe_path(tmp_path):
+    model_path = tmp_path / "model.onnx"
+    recipe_path = tmp_path / "recipe.yaml"
+    recipe_path.touch()
+    model_path.touch()
+
+    return model_path, recipe_path
+
+
 @pytest.mark.parametrize(
     "model_path",
     [
@@ -140,16 +114,6 @@ def test_resolve_recipe_file(model_path, model_path_and_recipe_path):
     )
 
 
-@pytest.fixture()
-def model_path_and_recipe_path(tmp_path):
-    model_path = tmp_path / "model.onnx"
-    recipe_path = tmp_path / "recipe.yaml"
-    recipe_path.touch()
-    model_path.touch()
-
-    return model_path, recipe_path
-
-
 def test_resolve_recipe_file_from_local_path(model_path_and_recipe_path):
     model_path, recipe_path = model_path_and_recipe_path
     assert recipe_path.as_posix() == resolve_recipe_file(
@@ -165,24 +129,40 @@ def test_resolve_recipe_file_from_local_path(model_path_and_recipe_path):
     )
 
 
-def test_create_fake_dataloader(generative_model_path, sequence_length):
-    expected_input_names = ["input_ids", "attention_mask"]
-    sequence_length = 32
-    num_samples = 2
+@pytest.mark.parametrize(
+    "model, recipe_found",
+    [
+        ("roneneldan/TinyStories-1M", False),
+        ("mgoin/all-MiniLM-L6-v2-quant-ds", True),
+        (
+            "zoo:mobilebert-squad_wikipedia_bookcorpus-14layer_pruned50.4block_quantized",  # noqa E501
+            True,
+        ),
+    ],
+)
+def test_infer_recipe_from_model_path(model, recipe_found):
+    recipe = infer_recipe_from_model_path(model)
+    if recipe_found:
+        assert isinstance(recipe, str)
+        return
+    assert recipe is None
 
-    model = transformers.AutoModelForCausalLM.from_pretrained(generative_model_path)
-    tokenizer = initialize_tokenizer(
-        generative_model_path, sequence_length=sequence_length, task="text-generation"
-    )
-    data_loader, input_names = create_fake_dataloader(
-        model=model,
-        tokenizer=tokenizer,
-        num_samples=num_samples,
-    )
 
-    assert input_names == expected_input_names
-    for i, sample in enumerate(data_loader):
-        assert sample["input_ids"].shape == torch.Size([1, sequence_length])
-        assert sample["attention_mask"].shape == torch.Size([1, sequence_length])
-        assert set(sample.keys()) == set(expected_input_names)
-    assert i == num_samples - 1
+@pytest.mark.parametrize(
+    "stub",
+    [
+        "zoo:nlp/question_answering/bert-base/pytorch/huggingface/squad/pruned95_obs_quant-none",  # noqa E501
+    ],
+)
+def test_save_zoo_directory(tmp_path, stub):
+    path_to_training_outputs = Model(stub).path
+    save_dir = tmp_path
+
+    save_zoo_directory(
+        output_dir=save_dir,
+        training_outputs_dir=path_to_training_outputs,
+    )
+    zoo_model = Model(str(save_dir))
+    assert zoo_model.validate(minimal_validation=True, validate_onnxruntime=False)
+    shutil.rmtree(path_to_training_outputs)
+    shutil.rmtree(save_dir)

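A note on the test refactor above: the updated test_create_fake_dataloader no longer downloads checkpoint weights. It builds the model structurally from its config inside accelerate's init_empty_weights() context, which is enough for create_fake_dataloader to infer input names. Below is a minimal sketch of that pattern under the same model stub as the test; the final assertion about the meta device is my illustration, not part of the test.

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model skeleton from config only: the config file is fetched,
# but no checkpoint is downloaded and no real weights are allocated.
config = AutoConfig.from_pretrained("roneneldan/TinyStories-1M")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Parameters exist structurally but live on the "meta" device with no data.
assert next(model.parameters()).device.type == "meta"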
0 commit comments
