
Commit d7c5dc4

[transformers] Delegate creating GradSampler to PruningModifiers (#831)
* Add customizable grad sampler
  * Change default arg to None
* Update MFAC with GradSampler changes
  * Update mfac tests to reflect GradSampler changes
  * Add grad_sampler_kwargs as ModifierProp
  * Fix style
* Remove apostrophe from docstring
* Fix typos
* Quality fix

Co-authored-by: Benjamin Fineran <bfineran@users.noreply.github.com>
1 parent be28f31 commit d7c5dc4
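
For orientation before the diffs: after this change a caller no longer constructs a GradSampler itself. It hands the manager a plain dict containing a data_loader_builder and a loss_function, and the MFAC pruning modifier builds the GradSampler internally, forwarding any grad_sampler_kwargs from the recipe to the builder. Below is a minimal caller-side sketch; the recipe path, toy model, and toy dataset are placeholders, while the dict keys and the initialize call mirror the trainer changes shown in the diffs.

from typing import Any, Dict, Optional

import torch
from torch.utils.data import DataLoader, TensorDataset

from sparseml.pytorch.optim import ScheduledModifierManager


def data_loader_builder(kwargs: Optional[Dict[str, Any]] = None):
    # receives grad_sampler_kwargs from the recipe (e.g. {"batch_size": 8}) and
    # yields (forward_args, forward_kwargs, loss_target) tuples for gradient sampling
    batch_size = (kwargs or {}).get("batch_size", 4)
    dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))  # placeholder data
    for inputs, targets in DataLoader(dataset, batch_size=batch_size):
        yield [inputs], {}, targets


def loss_function(model_outputs, loss_target):
    return torch.nn.functional.mse_loss(model_outputs, loss_target)


model = torch.nn.Linear(8, 1)  # placeholder model
manager = ScheduledModifierManager.from_yaml("recipe.yaml")  # hypothetical recipe path
manager.initialize(
    model,
    grad_sampler={  # the MFAC modifier now builds the GradSampler from these entries
        "data_loader_builder": data_loader_builder,
        "loss_function": loss_function,
    },
)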

File tree: 3 files changed, +114 -77 lines


src/sparseml/pytorch/sparsification/pruning/modifier_pruning_mfac.py

Lines changed: 27 additions & 7 deletions
@@ -21,7 +21,7 @@
 import os
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -83,6 +83,8 @@ class MFACPruningModifier(BaseGradualPruningModifier):
     |       num_grads: {0.0: 64, 0.5: 128, 0.75: 256, 0.85: 512}
     |       fisher_block_size: 10000
     |       available_devices: ["cuda:0"]
+    |       grad_sampler_kwargs:
+    |           batch_size: 8
 
     :param init_sparsity: the initial sparsity for the param to start with at
         start_epoch
@@ -130,6 +132,8 @@ class MFACPruningModifier(BaseGradualPruningModifier):
     :param mask_type: String to define type of sparsity to apply. May be 'unstructred'
         for unstructured pruning or 'block4' for four block pruning or a list of two
         integers for a custom block shape. Default is 'unstructured'
+    :param grad_sampler_kwargs: kwargs to override default train dataloader config
+        for gradient sampling.
     """
 
     def __init__(
@@ -151,6 +155,7 @@ def __init__(
         num_pages: int = 1,  # break computation into pages when block size is None
         available_devices: Optional[List[str]] = None,
         mask_type: str = "unstructured",
+        grad_sampler_kwargs: Optional[Dict[str, Any]] = None,
     ):
         super().__init__(
             params=params,
@@ -172,6 +177,7 @@ def __init__(
         self._fisher_block_size = fisher_block_size
         self._num_pages = num_pages
         self._mask_type = mask_type
+        self._grad_sampler_kwargs = grad_sampler_kwargs
         if available_devices is None:
             if torch.cuda.device_count() > 0:
                 self._available_devices = ["cuda:0"]
@@ -229,6 +235,13 @@ def available_devices(self) -> Optional[List[str]]:
         """
         return self._available_devices
 
+    @ModifierProp(serializable=True)
+    def grad_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
+        """
+        Return dict of training dataloader configs overridden for gradient sampling
+        """
+        return self._grad_sampler_kwargs
+
     @ModifierProp()
     def mask_type(self) -> str:
         """
@@ -259,14 +272,21 @@ def initialize(
         if "grad_sampler" in kwargs and self._use_gradient_buffering is not True:
             # set grad sampler, must be done before initialize in case pruning step
             # occurs on initialize epoch
-            grad_sampler = kwargs["grad_sampler"]
-            if not isinstance(grad_sampler, GradSampler):
-                raise ValueError(
-                    "grad_sampler must be an instance of the GradSampler class"
+            if (
+                "data_loader_builder" not in kwargs["grad_sampler"]
+                or "loss_function" not in kwargs["grad_sampler"]
+            ):
+                raise RuntimeError(
+                    "grad_sampler dict with data_loader_builder and loss_function "
+                    "must be provided to initialize GradSampler"
                 )
-            self._grad_sampler = grad_sampler
+            self._grad_sampler = GradSampler(
+                kwargs["grad_sampler"]["data_loader_builder"](
+                    self._grad_sampler_kwargs
+                ),
+                kwargs["grad_sampler"]["loss_function"],
+            )
             self.log_string("Using provided GradSampler")
-
         elif self._use_gradient_buffering is False:
             raise RuntimeError(
                 "grad_sampler must be provided when use_gradient_buffering is set"
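
In short, the modifier now owns GradSampler construction: grad_sampler_kwargs becomes a serializable ModifierProp, and initialize expects a dict rather than a ready-made GradSampler. A small sketch of the new constructor argument and the stricter validation follows; the sparsity, epoch, and params values are arbitrary placeholders.

from sparseml.pytorch.sparsification.pruning import MFACPruningModifier

modifier = MFACPruningModifier(
    params=["re:.*weight"],  # placeholder parameter selection
    init_sparsity=0.05,
    final_sparsity=0.85,
    start_epoch=0.0,
    end_epoch=10.0,
    update_frequency=1.0,
    grad_sampler_kwargs={"batch_size": 8},  # forwarded to data_loader_builder(kwargs)
)

# initialize() now raises RuntimeError unless the provided grad_sampler dict carries
# both "data_loader_builder" and "loss_function"; previously it required a GradSampler
# instance and raised ValueError otherwise.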

src/sparseml/transformers/sparsification/trainer.py

Lines changed: 27 additions & 29 deletions
@@ -28,7 +28,6 @@
 import torch
 from torch import distributed as dist
 from torch.nn import Module
-from torch.utils.data import RandomSampler
 from transformers import Trainer as TransformersTrainer
 from transformers import TrainerCallback, TrainerControl, TrainingArguments
 from transformers.file_utils import WEIGHTS_NAME
@@ -38,7 +37,6 @@
 
 from sparseml.pytorch.optim import ScheduledModifierManager, ScheduledOptimizer
 from sparseml.pytorch.utils import (
-    GradSampler,
     LoggerManager,
     ModuleSparsificationInfo,
     TensorBoardLogger,
@@ -157,9 +155,6 @@ def __init__(
         self.callback_disable_fp16 = DisableHalfPrecisionCallback(self)
         self.callback_handler.add_callback(self.callback_disable_fp16)
         self._add_tensorboard_logger_if_available()
-        self.grad_sampler = GradSampler(
-            self._mfac_data_loader(), self._mfac_loss_function
-        )
 
         model_signature = inspect.signature(self.model.forward)
         self._model_signature_columns = list(model_signature.parameters.keys())
@@ -275,7 +270,10 @@ def create_optimizer(self):
                 wrap_optim=self.scaler,
                 loggers=self.logger_manager,
                 distillation_teacher=self.teacher,
-                grad_sampler=self.grad_sampler,
+                grad_sampler={
+                    "data_loader_builder": self._data_loader_builder,
+                    "loss_function": self._loss_function,
+                },
             )
         else:
             wrap_optim_key = "optimizer"
@@ -286,16 +284,22 @@ def create_optimizer(self):
                 steps_per_epoch=self.manager_steps_per_epoch,
                 loggers=self.logger_manager,
                 initialize_kwargs={
-                    "grad_sampler": self.grad_sampler,
                     "distillation_teacher": self.teacher,
+                    "grad_sampler": {
+                        "data_loader_builder": self._data_loader_builder,
+                        "loss_function": self._loss_function,
+                    },
                 },
             )
         if not self.manager.initialized:
             self.manager.initialize(
                 self.model,
                 loggers=self.logger_manager,
                 distillation_teacher=self.teacher,
-                grad_sampler=self.grad_sampler,
+                grad_sampler={
+                    "data_loader_builder": self._data_loader_builder,
+                    "loss_function": self._loss_function,
+                },
             )
         self.manager_initialized = True
         _LOGGER.info(
@@ -663,25 +667,21 @@ def _reload_model_state(self, load_path: str, orig_state_dict: Dict[str, Any]):
             delayed_load=False,
         )
 
-    def _mfac_data_loader(self):
-        def dataloader():
-            data_loader_template = self.get_train_dataloader()
-
-            data_loader = torch.utils.data.DataLoader(
-                dataset=data_loader_template.dataset,
-                batch_size=data_loader_template.batch_size // 2,
-                sampler=RandomSampler(data_loader_template.dataset, replacement=False),
-                num_workers=data_loader_template.num_workers,
-                collate_fn=data_loader_template.collate_fn,
-                pin_memory=data_loader_template.pin_memory,
-                drop_last=data_loader_template.drop_last,
-                timeout=data_loader_template.timeout,
-                worker_init_fn=data_loader_template.worker_init_fn,
-                generator=data_loader_template.generator,
-                prefetch_factor=data_loader_template.prefetch_factor,
-                persistent_workers=data_loader_template.persistent_workers,
-            )
+    def _data_loader_builder(self, kwargs: Optional[Dict[str, Any]] = None):
+        default_loader = self.get_train_dataloader()
+        template = dict(default_loader.__dict__)
 
+        # drop attributes that will be auto-initialized
+        to_drop = [k for k in template if k.startswith("_") or k == "batch_sampler"]
+        for item in to_drop:
+            template.pop(item)
+
+        # override defaults if kwargs are given, for example via recipe
+        if kwargs:
+            template.update(kwargs)
+        data_loader = type(default_loader)(**template)
+
+        while True:  # infinite dataloading
             for sample in data_loader:
                 if self.label_smoother is not None and "labels" in sample:
                     label = sample.pop("labels")
@@ -690,9 +690,7 @@ def dataloader():
                 sample = self._prepare_inputs(sample)
                 yield [], sample, label
 
-        return dataloader
-
-    def _mfac_loss_function(self, model_outputs, loss_target):
+    def _loss_function(self, model_outputs, loss_target):
         if loss_target is not None:
             loss = self.label_smoother(model_outputs, loss_target)
         else:
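
The rewritten _data_loader_builder clones the trainer's default dataloader by reusing its public attributes as constructor kwargs, so that recipe-level grad_sampler_kwargs can override fields such as batch_size. Below is a standalone sketch of that rebuild-from-template pattern under the assumption of a plain torch DataLoader (no transformers-specific collation); rebuild_loader and the toy dataset are illustrative names only, not part of the trainer API.

from typing import Any, Dict, Optional

import torch
from torch.utils.data import DataLoader, TensorDataset


def rebuild_loader(
    default_loader: DataLoader, overrides: Optional[Dict[str, Any]] = None
) -> DataLoader:
    # copy the existing loader's attributes to use as constructor kwargs ...
    template = dict(default_loader.__dict__)
    # ... dropping private attributes and batch_sampler, which DataLoader derives itself
    for key in [k for k in template if k.startswith("_") or k == "batch_sampler"]:
        template.pop(key)
    if overrides:  # e.g. {"batch_size": 8} coming from grad_sampler_kwargs in a recipe
        template.update(overrides)
    return type(default_loader)(**template)


dataset = TensorDataset(torch.randn(32, 4), torch.randn(32, 1))
base = DataLoader(dataset, batch_size=16, shuffle=False)
half_batch = rebuild_loader(base, {"batch_size": 8})
print(half_batch.batch_size)  # 8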

tests/sparseml/pytorch/sparsification/pruning/test_modifier_pruning_mfac.py

Lines changed: 60 additions & 41 deletions
@@ -12,14 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Any, Dict, Optional
 
 import pytest
 import torch
 from torch.utils.data import DataLoader
 
 from flaky import flaky
 from sparseml.pytorch.sparsification.pruning import MFACPruningModifier
-from sparseml.pytorch.utils import GradSampler, tensor_sparsity
+from sparseml.pytorch.utils import tensor_sparsity
 from sparseml.utils import FROM_PARAM_TOKEN
 from tests.sparseml.pytorch.helpers import MLPDataset, MLPNet
 from tests.sparseml.pytorch.sparsification.pruning.helpers import (
@@ -38,29 +39,32 @@
 )
 
 
-def _device_data_loader(data_loader):
-    for sample in data_loader:
-        img, target = [t for t in sample]
-        yield [img], {}, target
-
-
-def _mfac_loss_function(model_outputs, loss_target):
-    return torch.nn.functional.mse_loss(model_outputs[0], loss_target)
+def _get_loss_function():
+    return lambda model_outputs, loss_target: torch.nn.functional.mse_loss(
+        model_outputs[0], loss_target
+    )
 
 
-def _build_gradient_sampler(
+def _get_dataloader_builder(
     dataset_lambda,
-    loss_function,
-    data_generator,
-    batch_size,
+    mfac_batch_size,
     num_grads,
     num_epochs,
    update_frequency,
 ):
-    data_length = int(batch_size * num_grads * num_epochs * (1 / update_frequency) * 2)
-    dataset = dataset_lambda(length=data_length)
-    data_loader = DataLoader(dataset, batch_size=batch_size)
-    return GradSampler(data_generator(data_loader), loss_function)
+    def dataloader_builder(kwargs: Optional[Dict[str, Any]] = None):
+        batch_size = kwargs["batch_size"] if kwargs else mfac_batch_size
+        data_length = int(
+            mfac_batch_size * num_grads * num_epochs * (1 / update_frequency) * 2
+        )
+        dataset = dataset_lambda(length=data_length)
+        data_loader = DataLoader(dataset, batch_size=batch_size)
+
+        for sample in data_loader:
+            img, target = [t for t in sample]
+            yield [img], {}, target
+
+    return dataloader_builder
 
 
 @flaky(max_runs=3, min_passes=2)
@@ -106,6 +110,9 @@ def _build_gradient_sampler(
             inter_func="cubic",
             num_grads=8,
             global_sparsity=True,
+            grad_sampler_kwargs={
+                "batch_size": 4,
+            },
         ),
     ],
     scope="function",
@@ -121,8 +128,8 @@ def _build_gradient_sampler(
 )
 class TestMFACPruningModifier(ScheduledUpdateModifierTest):
     @pytest.mark.parametrize(
-        "dataset_lambda,loss,mfac_batch_size",
-        [(MLPDataset, _mfac_loss_function, 4)],
+        "dataset_lambda,mfac_batch_size",
+        [(MLPDataset, 4)],
     )
     def test_lifecycle(
         self,
@@ -131,21 +138,21 @@ def test_lifecycle(
         optim_lambda,
         test_steps_per_epoch,  # noqa: F811
         dataset_lambda,
-        loss,
         mfac_batch_size,
     ):
         modifier = modifier_lambda()
         model = model_lambda()
         optimizer = optim_lambda(model)
-        grad_sampler = _build_gradient_sampler(
-            dataset_lambda,
-            loss,
-            _device_data_loader,
-            mfac_batch_size,
-            modifier.num_grads,
-            modifier.end_epoch - modifier.start_epoch + 1,
-            modifier.update_frequency,
-        )
+        grad_sampler = {
+            "data_loader_builder": _get_dataloader_builder(
+                dataset_lambda,
+                mfac_batch_size,
+                modifier.num_grads,
+                modifier.end_epoch - modifier.start_epoch + 1,
+                modifier.update_frequency,
+            ),
+            "loss_function": _get_loss_function(),
+        }
 
         self.initialize_helper(modifier, model, grad_sampler=grad_sampler)
         if modifier.start_epoch > 0:
@@ -222,8 +229,8 @@ def _test_final_sparsity_applied():
         _test_final_sparsity_applied()
 
     @pytest.mark.parametrize(
-        "dataset_lambda,loss,mfac_batch_size",
-        [(MLPDataset, _mfac_loss_function, 4)],
+        "dataset_lambda,mfac_batch_size",
+        [(MLPDataset, 4)],
     )
     def test_scheduled_update(
         self,
@@ -233,19 +240,20 @@ def test_scheduled_update(
         test_epoch,  # noqa: F811
         test_steps_per_epoch,  # noqa: F811
         dataset_lambda,
-        loss,
         mfac_batch_size,
     ):
         modifier = modifier_lambda()
-        grad_sampler = _build_gradient_sampler(
-            dataset_lambda,
-            loss,
-            _device_data_loader,
-            mfac_batch_size,
-            modifier.num_grads,
-            modifier.end_epoch - modifier.start_epoch + 1,
-            modifier.update_frequency,
-        )
+        grad_sampler = {
+            "data_loader_builder": _get_dataloader_builder(
+                dataset_lambda,
+                mfac_batch_size,
+                modifier.num_grads,
+                modifier.end_epoch - modifier.start_epoch + 1,
+                modifier.update_frequency,
+            ),
+            "loss_function": _get_loss_function(),
+        }
+
         super().test_scheduled_update(
             modifier_lambda,
             model_lambda,
@@ -290,6 +298,7 @@ def test_mfac_pruning_yaml(params, init_sparsity, final_sparsity):
     num_pages = 1
     available_devices = ["cpu"]
     mask_type = "block4"
+    batch_size = 4
     yaml_str = f"""
     !MFACPruningModifier
         init_sparsity: {init_sparsity}
@@ -307,6 +316,8 @@ def test_mfac_pruning_yaml(params, init_sparsity, final_sparsity):
         num_pages: {num_pages}
        available_devices: {available_devices}
        mask_type: {mask_type}
+        grad_sampler_kwargs:
+            batch_size: {batch_size}
    """
    yaml_modifier = MFACPruningModifier.load_obj(yaml_str)
    serialized_modifier = MFACPruningModifier.load_obj(
@@ -328,6 +339,9 @@ def test_mfac_pruning_yaml(params, init_sparsity, final_sparsity):
        num_pages=num_pages,
        available_devices=available_devices,
        mask_type=mask_type,
+        grad_sampler_kwargs={
+            "batch_size": batch_size,
+        },
    )
    assert isinstance(yaml_modifier, MFACPruningModifier)
    pruning_modifier_serialization_vals_test(
@@ -373,3 +387,8 @@ def test_mfac_pruning_yaml(params, init_sparsity, final_sparsity):
        == str(serialized_modifier.mask_type)
        == str(obj_modifier.mask_type)
    )
+    assert (
+        str(yaml_modifier._grad_sampler_kwargs)
+        == str(serialized_modifier._grad_sampler_kwargs)
+        == str(obj_modifier._grad_sampler_kwargs)
+    )
