diff --git a/src/diffusers/utils/accelerate_utils.py b/src/diffusers/utils/accelerate_utils.py index af3b712b5a9e..b1ad56f55227 100644 --- a/src/diffusers/utils/accelerate_utils.py +++ b/src/diffusers/utils/accelerate_utils.py @@ -15,6 +15,8 @@ Accelerate utilities: Utilities related to accelerate """ +import functools + from packaging import version from .import_utils import is_accelerate_available @@ -40,6 +42,7 @@ def apply_forward_hook(method): if version.parse(accelerate_version) < version.parse("0.17.0"): return method + @functools.wraps(method) def wrapper(self, *args, **kwargs): if hasattr(self, "_hf_hook") and hasattr(self._hf_hook, "pre_forward"): self._hf_hook.pre_forward(self) diff --git a/tests/others/test_utils.py b/tests/others/test_utils.py index 4600f5f3710a..646d335d2ec8 100755 --- a/tests/others/test_utils.py +++ b/tests/others/test_utils.py @@ -284,6 +284,27 @@ def _capture(target_device): ) +class ApplyForwardHookTester(unittest.TestCase): + """Tests for :func:`diffusers.utils.accelerate_utils.apply_forward_hook`.""" + + def test_preserves_wrapped_function_metadata(self): + import inspect + + from diffusers.utils.accelerate_utils import apply_forward_hook + + @apply_forward_hook + def example(self, x: int, y: int = 0) -> int: + """Example method docstring.""" + return x + y + + assert example.__name__ == "example" + assert example.__doc__ == "Example method docstring." + assert hasattr(example, "__wrapped__") + signature = inspect.signature(example) + assert list(signature.parameters) == ["self", "x", "y"] + assert signature.parameters["y"].default == 0 + + # Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py class ExpectationsTester(unittest.TestCase): def test_expectations(self):