diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py index 888db3e52b..63debde5bc 100644 --- a/pytensor/scalar/loop.py +++ b/pytensor/scalar/loop.py @@ -180,12 +180,16 @@ def perform(self, node, inputs, output_storage): inner_fn = self.py_perform_fn if self.is_while: - until = False - for i in range(n_steps): - *carry, until = inner_fn(*carry, *constant) - if until: - break - carry.append(until) + # If n_steps <= 0, the loop is skipped and done should be True + if n_steps <= 0: + carry.append(True) + else: + until = False + for i in range(n_steps): + *carry, until = inner_fn(*carry, *constant) + if until: + break + carry.append(until) else: if n_steps < 0: diff --git a/tests/scalar/test_loop.py b/tests/scalar/test_loop.py index d4f0f5b021..fd7d117d9f 100644 --- a/tests/scalar/test_loop.py +++ b/tests/scalar/test_loop.py @@ -149,6 +149,30 @@ def test_rebuild_dtype(): assert y.dtype == "float32" +def test_until_zero_steps(): + """Until loop should return done=True when n_steps <= 0.""" + from pytensor import Mode, function + from pytensor.scalar import float64, int64 + from pytensor.scalar.loop import ScalarLoop + + n_steps = int64("n_steps") + x0 = float64("x0") + + x = x0 + 1 + until = x > 10 + + op = ScalarLoop(init=[x0], update=[x], until=until) + + x_out, done = op(n_steps, x0) + + fn = function([n_steps, x0], [x_out, done], mode=Mode(linker="py")) + + done_val = fn(0, 0.0)[1] + + # According to docstring logic + assert done_val is True + + def test_non_scalar_error(): x0 = float64("x0") x = as_scalar(tensor_exp(x0))