
Commit 2c3bdf7
Author: Sara Adkins
Make SmoothQuant/LogEqualization FSDP Compatible (#2025) (#2178)
1 parent: d64668d

4 files changed: +57 −13 lines

src/sparseml/core/modifier/stage.py

Lines changed: 6 additions & 0 deletions
@@ -137,8 +137,11 @@ def initialize(self, state: "State", **kwargs):
         if self.applied:
             return
 
+        accelerator = kwargs.get("accelerator", None)
         for modifier in self.modifiers:
             modifier.initialize(state, **kwargs)
+            if accelerator:
+                accelerator.wait_for_everyone()
         state.loggers.system.info(tag="stage", string="Modifiers initialized")
 
     def finalize(self, state: "State", **kwargs):
@@ -153,8 +156,11 @@ def finalize(self, state: "State", **kwargs):
         if self.applied:
             return
 
+        accelerator = kwargs.get("accelerator", None)
         for modifier in self.modifiers:
             modifier.finalize(state, **kwargs)
+            if accelerator:
+                accelerator.wait_for_everyone()
 
         self.applied = True
         state.loggers.system.info(tag="stage", string="Modifiers finalized")
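
The new wait_for_everyone() calls act as a cross-rank barrier so every process finishes a modifier before any process starts the next one. A minimal sketch of the pattern, assuming a Hugging Face accelerate launch; `modifiers` and `state` are placeholders standing in for a StageModifiers' children and its State object, not names defined in this commit:

from accelerate import Accelerator

accelerator = Accelerator()  # one instance per rank under `accelerate launch`

for modifier in modifiers:  # placeholder list, e.g. SmoothQuant then quantization
    modifier.initialize(state, accelerator=accelerator)
    # barrier: no rank starts the next modifier until all ranks have finished
    # this one, which matters when a modifier rewrites FSDP-sharded weights
    accelerator.wait_for_everyone()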

src/sparseml/modifiers/smoothquant/pytorch.py

Lines changed: 22 additions & 12 deletions
@@ -22,6 +22,7 @@
 from sparseml.core.model.pytorch import ModifiableModelPyTorch
 from sparseml.modifiers.smoothquant.base import SmoothQuantModifier, SmoothQuantScale
 from sparseml.modifiers.utils.pytorch_helpers import run_calibration_forward
+from sparseml.utils.fsdp.helpers import get_fsdp_parent
 
 
 _LOGGER = logging.getLogger(__name__)
@@ -56,7 +57,7 @@ def on_initialize(self, state: State, **kwargs) -> bool:
 
         self._setup_scale_hooks()
         self._calibrate(state.model, calibration_dataloader)
-        self._apply_smoothing()
+        self._apply_smoothing(state.model)
 
         return True
 
@@ -138,7 +139,7 @@ def _calibrate(self, model: ModifiableModelPyTorch, calibration_dataloader: List
         del self.hooks_
 
     @torch.no_grad()
-    def _apply_smoothing(self):
+    def _apply_smoothing(self, model: ModifiableModelPyTorch):
         """
         After calibration, apply smoothing to the activations and push the transform
         into the following weights by applying the inverse to each balance weight.
@@ -162,17 +163,26 @@ def _apply_smoothing(self):
                 scales, torch.Tensor([MINIMUM_SMOOTHING_SCALE]).to(scales.device)
             )
 
-            # invert the smoothing in the following layers
-            for layer in balance_layers:
-                layer.weight.mul_(scales.view(1, -1))
-
-            # apply the smoothing
-            if smooth_layer.weight.ndim == 1:
-                smooth_layer.weight.div_(scales)
+            @torch.no_grad()
+            def smooth(module):
+                if module in balance_layers:
+                    module.weight.mul_(scales.view(1, -1))
+                elif module == smooth_layer:
+                    if module.weight.ndim == 1:
+                        module.weight.div_(scales)
+                    else:
+                        module.weight.div_(scales.view(-1, 1))
+                    if hasattr(module, "bias") and module.bias is not None:
+                        module.bias.div_(scales)
+
+            parent = get_fsdp_parent(mapping.smooth_name, model.model)
+            if parent is not None:
+                parent.apply(smooth)
             else:
-                smooth_layer.weight.div_(scales.view(-1, 1))
-            if hasattr(smooth_layer, "bias") and smooth_layer.bias is not None:
-                smooth_layer.bias.div_(scales)
+                # if we're not running with FSDP we can apply smoothing directly
+                for layer in balance_layers:
+                    smooth(layer)
+                smooth(smooth_layer)
 
     def _calculate_smoothing_scales(
         self, balance_layers: List[Module], activation_scales: torch.Tensor
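
Folding the scaling logic into a single smooth(module) callback lets it be driven by Module.apply on the enclosing FSDP unit, whose apply() gathers the sharded parameters before visiting each submodule. Below is a minimal non-FSDP sketch of that callback pattern; the layers and scale values are made up for illustration and are not the commit's mapping objects:

import torch
from torch import nn

# stand-ins for mapping.smooth_layer and mapping.balance_layers
block = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
smooth_layer, balance_layers = block[0], [block[1]]
scales = torch.full((4,), 2.0)

@torch.no_grad()
def smooth(module):
    # same shape as the closure in _apply_smoothing: scale the balance layers'
    # input channels up and the smoothing layer's output channels down
    if module in balance_layers:
        module.weight.mul_(scales.view(1, -1))
    elif module is smooth_layer:
        module.weight.div_(scales.view(-1, 1))
        if module.bias is not None:
            module.bias.div_(scales)

block.apply(smooth)  # an FSDP wrapper's apply() also gathers full params first

Because the two transforms cancel, the block computes the same function (up to floating point error) while the activation range between the two layers changes by the chosen scales.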

src/sparseml/transformers/finetune/session_mixin.py

Lines changed: 1 addition & 0 deletions
@@ -431,6 +431,7 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
             calib_data=calib_data,
             start=-1,
             copy_data=False,
+            accelerator=self.accelerator,
         )
 
         self.accelerator.wait_for_everyone()

src/sparseml/utils/fsdp/helpers.py

Lines changed: 28 additions & 1 deletion
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union
+import operator
+from typing import Optional, Union
 
 
 try:
@@ -37,6 +38,7 @@
     "set_wrapped_model",
     "unwrap_and_export_model",
     "save_pretrained_fsdp",
+    "get_fsdp_parent",
 ]
 
 
@@ -132,3 +134,28 @@ def save_pretrained_fsdp(model, accelerator, output_dir):
         save_function=accelerator.save,
         state_dict=state_dict,
     )
+
+
+def get_fsdp_parent(layer_name: str, model: Module) -> Optional[Module]:
+    """
+    Gets the closest parent of layer_name that is wrapped by FSDP. If no FSDP wrapper
+    is found, just return None.
+
+    :param layer_name: layer name in model to get parent of
+    :param model: pytorch module to search through
+    :return: FSDP wrapped parent of layer_name if available, otherwise None
+    """
+    if not is_fsdp_model(model):
+        return None
+
+    parent_name = layer_name
+    parent = operator.attrgetter(parent_name)(model)
+    while not isinstance(parent, FullyShardedDataParallel):
+        if len(parent_name) == 0:  # we've reached the root module and it's not FSDP
+            # this should never get hit because we check for an FSDP root above
+            # but while statements without a backup are too scary
+            return None
+        parent_name = ".".join(parent_name.split(".")[:-1])
+        parent = operator.attrgetter(parent_name)(model)
+
+    return parent
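
A hedged usage sketch of the new helper; the `model` object and the dotted layer path below are assumptions for illustration and would have to match an actual FSDP-wrapped model's module naming:

from sparseml.utils.fsdp.helpers import get_fsdp_parent

# `model` is assumed to be an FSDP-wrapped transformer whose decoder layers
# were auto-wrapped as their own FSDP units
parent = get_fsdp_parent("model.layers.0.mlp.down_proj", model)
if parent is not None:
    # closest enclosing FSDP unit: its apply() gathers the shards, so a callback
    # like smooth() above sees full (unsharded) weights
    parent.apply(lambda module: None)  # no-op callback standing in for real work
else:
    # not an FSDP model (or no wrapped parent): operate on modules directly
    pass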
