diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py index 2a708e1118e0..373ef32154d8 100644 --- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py @@ -199,7 +199,11 @@ def __init__( super().__init__() if safety_checker is None: - safety_checker = CosmosSafetyChecker() + grad_enabled = torch.is_grad_enabled() + try: + safety_checker = CosmosSafetyChecker() + finally: + torch.set_grad_enabled(grad_enabled) self.register_modules( vae=vae, diff --git a/tests/pipelines/cosmos/test_cosmos2_video2world.py b/tests/pipelines/cosmos/test_cosmos2_video2world.py index 1b814257a30a..0221499b29fb 100644 --- a/tests/pipelines/cosmos/test_cosmos2_video2world.py +++ b/tests/pipelines/cosmos/test_cosmos2_video2world.py @@ -17,6 +17,7 @@ import os import tempfile import unittest +from unittest import mock import numpy as np import PIL.Image @@ -158,6 +159,27 @@ def test_inference(self): generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]]) self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3)) + def test_default_safety_checker_preserves_grad_mode(self): + components = self.get_dummy_components() + components.pop("safety_checker") + + class GradDisablingCosmosSafetyChecker: + def __init__(self): + torch.set_grad_enabled(False) + + original_grad_mode = torch.is_grad_enabled() + torch.set_grad_enabled(True) + try: + with mock.patch( + "diffusers.pipelines.cosmos.pipeline_cosmos2_video2world.CosmosSafetyChecker", + GradDisablingCosmosSafetyChecker, + ): + self.pipeline_class(**components) + + self.assertTrue(torch.is_grad_enabled()) + finally: + torch.set_grad_enabled(original_grad_mode) + def test_components_function(self): init_components = self.get_dummy_components() init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}