diff --git a/src/pruna/algorithms/base/pruna_base.py b/src/pruna/algorithms/base/pruna_base.py
index 0784069b..16512372 100644
--- a/src/pruna/algorithms/base/pruna_base.py
+++ b/src/pruna/algorithms/base/pruna_base.py
@@ -16,6 +16,7 @@
 import functools
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import Any, Dict, Iterable
 
 from transformers import Pipeline
@@ -355,8 +356,8 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any:
         Any
             The model after the algorithm has been applied.
         """
-        if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config._prepare_saving:
-            save_dir = smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR
+        if self.save_fn == SAVE_FUNCTIONS.save_before_apply and smash_config.prepare_saving:
+            save_dir = self.get_save_before_smash_dir(smash_config)
             save_pruna_model(model, save_dir, smash_config)
 
         # save algorithms to reapply after loading
@@ -369,7 +370,23 @@ def apply(self, model: Any, smash_config: SmashConfig) -> Any:
         prefix = self.algorithm_name + "_"
         wrapped_config = SmashConfigPrefixWrapper(smash_config, prefix)
 
-        return self._apply(model, wrapped_config)
+        result = self._apply(model, wrapped_config)
+
+        self.post_apply_hook(result, smash_config)
+        return result
+
+    def post_apply_hook(self, model: Any, smash_config: SmashConfig) -> None:
+        """
+        Hook called after ``_apply`` returns; subclasses can override it to run side effects.
+
+        Parameters
+        ----------
+        model : Any
+            The model returned by ``_apply``.
+        smash_config : SmashConfig
+            The SmashConfig object.
+        """
+        return
 
     def get_compatible_algorithms(self) -> list[str]:
         """
@@ -447,6 +464,23 @@ def get_algorithms_to_run_after_disjointly(self) -> list[str]:
         """
         return _expand_tags_into_algorithm_names(self.disjointly_compatible_after)
 
+    @staticmethod
+    def get_save_before_smash_dir(smash_config: SmashConfig) -> Path:
+        """
+        Get the directory used to cache models saved before smashing.
+
+        Parameters
+        ----------
+        smash_config : SmashConfig
+            The SmashConfig to read the cache directory from.
+
+        Returns
+        -------
+        Path
+            The resolved absolute path of the save-before-smash cache directory.
+        """
+        return (smash_config.cache_dir / SAVE_BEFORE_SMASH_CACHE_DIR).resolve()
+
 
 def wrap_handle_imports(func):
     """
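Aside (not part of the diff): a minimal sketch of how a subclass could use the new ``post_apply_hook`` extension point. The subclass name and its logging body are hypothetical; only ``post_apply_hook`` and the call site in ``apply`` come from the change above, and the names ``PrunaAlgorithmBase``, ``SmashConfig`` and ``pruna_logger`` are assumed importable as in the modules this diff touches.

    class LoggingAlgorithm(PrunaAlgorithmBase):
        def post_apply_hook(self, model: Any, smash_config: SmashConfig) -> None:
            # Called by apply() once _apply has returned; receives the resulting model.
            pruna_logger.info("Finished applying %s", self.algorithm_name)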
diff --git a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py
index 4b65d9df..b7f0f5b4 100644
--- a/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py
+++ b/src/pruna/algorithms/global_utils/recovery/perp_recoverer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from abc import ABCMeta
 from typing import Any, Dict
 
 import torch
@@ -23,7 +24,7 @@
 from pruna.algorithms.global_utils.recovery.finetuners import PrunaFinetuner
 from pruna.algorithms.global_utils.recovery.finetuners.diffusers.utils import get_denoiser_attr
 from pruna.algorithms.global_utils.recovery.utils import get_trainable_parameters
-from pruna.config.smash_config import SmashConfigPrefixWrapper
+from pruna.config.smash_config import SmashConfig, SmashConfigPrefixWrapper
 from pruna.engine.model_checks import (
     is_causal_lm,
     is_flux_pipeline,
@@ -31,11 +32,11 @@
     is_sd_pipeline,
     is_sdxl_pipeline,
 )
-from pruna.engine.save import SAVE_FUNCTIONS
+from pruna.engine.save import refresh_saved_model
 from pruna.logging.logger import pruna_logger
 
 
-class PERPRecoverer(PrunaAlgorithmBase):
+class PERPRecoverer(PrunaAlgorithmBase, metaclass=ABCMeta):
     """
     General purpose PERP recoverer using norm, head and bias finetuning and optionally HuggingFace's LoRA.
 
@@ -52,7 +53,7 @@ class PERPRecoverer(PrunaAlgorithmBase, metaclass=ABCMeta):
     """
     group_tags: list[AlgorithmTag] = [AlgorithmTag.RECOVERER]  # type: ignore[attr-defined]
-    save_fn = SAVE_FUNCTIONS.pickled
+    save_fn = None
     references: dict[str, str] = {
         "GitHub": "https://github.com/huggingface/peft",
         "Paper": "https://arxiv.org/pdf/2312.15230",
     }
@@ -63,7 +64,6 @@
 
     def __init__(self, task_name: str, use_lora: bool, use_in_place: bool, is_distillation: bool) -> None:
         self.task_name = task_name
-        self.tokenizer_required = task_name == "text_to_text"  # type: ignore[misc]
 
         if not use_lora and not use_in_place:
             raise ValueError("Arguments use_lora and use_in_place cannot both be False, please use one of the two.")
@@ -89,6 +89,11 @@ def __init__(self, task_name: str, use_lora: bool, use_in_place: bool, is_distil
         super().__init__()  # self.adapters need to be set before calling get_hyperparameters
 
+    @property
+    def tokenizer_required(self) -> bool:
+        """Whether a tokenizer is required, derived from the task name."""
+        return self.task_name == "text_to_text"
+
     def get_hyperparameters(self) -> list:
         """
         Configure all algorithm-specific hyperparameters with ConfigSpace.
@@ -181,6 +186,20 @@ def _pre_smash_hook(self, model: Any, smash_config: SmashConfigPrefixWrapper) ->
             adapter_smash_config = SmashConfigPrefixWrapper(smash_config, adapter.adapter_prefix + "_")
             adapter.pre_smash_hook(model_recovery, adapter_smash_config, seed=adapter_seed)
 
+    def post_apply_hook(self, model: Any, smash_config: SmashConfig) -> None:
+        """
+        Refresh the save-before-apply cache after recovery has modified the weights in place.
+
+        Parameters
+        ----------
+        model : Any
+            The recovered model.
+        smash_config : SmashConfig
+            The SmashConfig object.
+        """
+        if smash_config.prepare_saving:
+            refresh_saved_model(model, self.get_save_before_smash_dir(smash_config), smash_config)
+
     def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         """
         Recover performances from a given model with a given config.
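Aside (not part of the diff): turning ``tokenizer_required`` from an attribute set once in ``__init__`` into a property means it now tracks ``task_name`` instead of freezing its value at construction time. A runnable stand-in showing the behavior; ``_Demo`` is a hypothetical minimal class, not Pruna's API:

    class _Demo:
        def __init__(self, task_name: str) -> None:
            self.task_name = task_name

        @property
        def tokenizer_required(self) -> bool:
            # Mirrors the property added to PERPRecoverer above.
            return self.task_name == "text_to_text"

    assert _Demo("text_to_text").tokenizer_required
    assert not _Demo("text_to_image").tokenizer_required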
diff --git a/src/pruna/config/smash_config.py b/src/pruna/config/smash_config.py
index baaba7f7..10ad77b1 100644
--- a/src/pruna/config/smash_config.py
+++ b/src/pruna/config/smash_config.py
@@ -119,6 +119,11 @@ def __init__(
             raise ValueError(f"Unsupported configuration type: {type(configuration)}")
         self.config_space: ConfigurationSpace = self._configuration.config_space
 
+    @property
+    def prepare_saving(self) -> bool:
+        """Read-only access to the internal ``_prepare_saving`` flag."""
+        return self._prepare_saving
+
     @classmethod
     def from_list(
         cls,
diff --git a/src/pruna/engine/save.py b/src/pruna/engine/save.py
index 27101b31..c2df1265 100644
--- a/src/pruna/engine/save.py
+++ b/src/pruna/engine/save.py
@@ -355,9 +355,7 @@ def save_model_hqq(model: Any, model_path: str | Path, smash_config: SmashConfig
         # save pipeline info so we can call transformers.pipeline at load time
         save_pipeline_info(model, str(model_path))
         # pipeline loading requires a safetensor file so we save a fake, lightweight one
-        save_model(
-            torch.nn.Linear(1, 1), str(model_path / "model.safetensors"), metadata={"format": "pt"}
-        )
+        save_model(torch.nn.Linear(1, 1), str(model_path / "model.safetensors"), metadata={"format": "pt"})
         save_model_hqq(model.model, model_path, smash_config)
 
     elif is_janus_llamagen_ar(model):
@@ -470,6 +468,37 @@ def save_component(attr_name: str | None, module: torch.nn.Module, subpaths: lis
         smash_config.load_fns.append(LOAD_FUNCTIONS.hqq_diffusers.name)
 
 
+def refresh_saved_model(model: Any, model_path: Path, smash_config: SmashConfig) -> None:
+    """
+    Refresh the cached save-before-apply model so that it reflects the current model state.
+
+    Recovery modifies weights in place without changing the model's serialization
+    format. If a prior algorithm used ``save_before_apply`` (caching the model before
+    its transformation), the cached snapshot is now stale because recovery changed
+    the weights. This function refreshes that cache so the saved model includes
+    the recovered weights.
+
+    Parameters
+    ----------
+    model : Any
+        The model whose current state should be saved.
+    model_path : Path
+        The path of the cached model to refresh.
+    smash_config : SmashConfig
+        The SmashConfig object containing the save and load functions.
+    """
+    if not model_path.exists():
+        return
+
+    ori_save_fns = smash_config.save_fns[:]
+    smash_config.save_fns = [fn for fn in smash_config.save_fns if fn != SAVE_FUNCTIONS.save_before_apply.name]
+    # Re-save with recovered weights
+    shutil.rmtree(model_path, ignore_errors=True)
+    save_pruna_model(model, model_path, smash_config)
+    # Restore save_fns
+    smash_config.save_fns = ori_save_fns
+
+
 def reapply(model: Any, model_path: str | Path, smash_config: SmashConfig) -> None:
     """
     Reapply the model.
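Aside (not part of the diff): the intended call sequence for ``refresh_saved_model``, as a hedged sketch. The stand-in model and snapshot path are illustrative; ``SmashConfig(device="cpu")`` and ``cache_dir`` follow their usage elsewhere in this diff.

    import torch

    from pruna.config.smash_config import SmashConfig
    from pruna.engine.save import refresh_saved_model, save_pruna_model

    model = torch.nn.Linear(4, 4)                             # stand-in model
    smash_config = SmashConfig(device="cpu")                  # constructor usage as in the test below
    snapshot_dir = (smash_config.cache_dir / "before_smash").resolve()  # illustrative path
    save_pruna_model(model, snapshot_dir, smash_config)       # initial save-before-apply snapshot
    model.weight.data.zero_()                                 # in-place weight change ("recovery")
    refresh_saved_model(model, snapshot_dir, smash_config)    # snapshot now matches the model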
diff --git a/tests/engine/test_save.py b/tests/engine/test_save.py
index 2cd6f3e1..7a531ae7 100644
--- a/tests/engine/test_save.py
+++ b/tests/engine/test_save.py
@@ -1,18 +1,17 @@
 import os
-import pytest
-import torch
 from pathlib import Path
 from unittest.mock import patch
+
+import pytest
+import torch
+from diffusers import DiffusionPipeline
 from transformers import AutoModelForCausalLM
-from pruna.config.smash_config import SmashConfig
+
 from pruna import smash
-from pruna.engine.save import save_pruna_model
-from pruna.engine.save import save_pruna_model_to_hub
-from pruna.engine.save import SAVE_FUNCTIONS
-from pruna.engine.load import load_pruna_model
 from pruna.config.smash_config import SmashConfig
-from diffusers import DiffusionPipeline
+from pruna.engine.load import load_pruna_model
 from pruna.engine.pruna_model import PrunaModel
+from pruna.engine.save import SAVE_FUNCTIONS, save_pruna_model, save_pruna_model_to_hub
 
 
 @pytest.mark.slow
@@ -29,6 +28,7 @@ def test_save_llm_to_hub() -> None:
     )
     pruna_model.push_to_hub(upload_repo_id, private=False)
 
+
 @pytest.mark.slow
 @pytest.mark.cpu
 def test_save_diffusers_to_hub() -> None:
@@ -160,3 +160,39 @@ def test_push_to_hub_path_types(tmp_path) -> None:
             private=True
         )
         assert mock_upload.called
+
+
+@pytest.mark.cpu
+def test_perp_post_apply_hook_round_trip(tmp_path) -> None:
+    """Test that the PERP post-apply hook refreshes the save-before-apply cache so recovered weights survive save/load."""
+    from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
+    from pruna.algorithms.global_utils.recovery.perp_recoverer import PERPRecoverer
+
+    class FakeRecoverer(PERPRecoverer):
+        algorithm_name = "test_fake_recoverer"
+
+        def __init__(self):  # noqa: D107
+            pass
+
+        def _apply(self, model, smash_config):
+            model.weight.data.fill_(0.99)
+            return model
+
+    model = torch.nn.Linear(3, 2)
+    model.weight.data.fill_(0.1)
+    config = SmashConfig(device="cpu")
+
+    save_dir = PrunaAlgorithmBase.get_save_before_smash_dir(config)
+    save_pruna_model(model, save_dir, config)
+    config.save_fns.append(SAVE_FUNCTIONS.save_before_apply.name)
+
+    model = FakeRecoverer().apply(model, config)
+
+    save_path = tmp_path / "final_model"
+    save_pruna_model(model, save_path, config)
+
+    loaded_model, _ = load_pruna_model(save_path)
+    loaded_model = loaded_model.cpu()
+    assert torch.allclose(
+        loaded_model.weight, torch.full_like(loaded_model.weight, 0.99)
+    ), "Recovered weights should survive save/load through save_before_apply"
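Aside: assuming the repository's standard pytest setup, the new round-trip test can be run in isolation with ``pytest tests/engine/test_save.py::test_perp_post_apply_hook_round_trip``. It exercises the full chain added in this diff: the base-class ``post_apply_hook``, the PERP override, and ``refresh_saved_model``.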