diff --git a/src/pyrecest/_backend/pytorch/random.py b/src/pyrecest/_backend/pytorch/random.py index ad943bf03..081b3a432 100644 --- a/src/pyrecest/_backend/pytorch/random.py +++ b/src/pyrecest/_backend/pytorch/random.py @@ -20,11 +20,22 @@ def _choice_size(size): return size, int(_torch.prod(_torch.tensor(size)).item()) +def _choice_population(a, p=None): + """Normalize NumPy-compatible ``choice`` populations to a tensor.""" + device = p.device if _torch.is_tensor(p) else None + if isinstance(a, int): + if a <= 0: + raise ValueError("a must be greater than 0 unless no samples are taken") + return _torch.arange(a, device=device) + if _torch.is_tensor(a): + return a + return _torch.as_tensor(a, device=device) + + def choice(a, size=None, replace=True, p=None): - assert _torch.is_tensor(a), "a must be a tensor" + a = _choice_population(a, p=p) size, num_samples = _choice_size(size) if p is not None: - assert _torch.is_tensor(p), "p must be a tensor" if not replace and num_samples > len(a): raise ValueError( "Cannot take a larger sample than population when 'replace=False'." @@ -104,4 +115,4 @@ def multivariate_normal(mean, cov, size=None): size = () elif not hasattr(size, "__iter__"): size = (size,) - return _MultivariateNormal(mean, cov).sample(size) + return _MultivariateNormal(mean, cov).sample(size) \ No newline at end of file diff --git a/tests/test_backend_random.py b/tests/test_backend_random.py index ebb855826..22a0abf3a 100644 --- a/tests/test_backend_random.py +++ b/tests/test_backend_random.py @@ -28,6 +28,22 @@ def test_choice_single_sample_preserves_sample_axis(self): self.assertEqual(sample.shape, (1, 2)) + def test_choice_accepts_python_list_population(self): + samples = random.choice([10, 20, 30], size=(32,)) + + self.assertEqual(samples.shape, (32,)) + npt.assert_array_less(9, samples) + npt.assert_array_less(samples, 31) + for sample in pyrecest.backend.to_numpy(samples).tolist(): + self.assertIn(sample, (10, 20, 30)) + + def test_choice_accepts_integer_population(self): + samples = random.choice(5, size=(32,)) + + self.assertEqual(samples.shape, (32,)) + npt.assert_array_less(-1, samples) + npt.assert_array_less(samples, 5) + @unittest.skipIf( pyrecest.backend.__backend_name__ != "jax", "JAX-specific RNG state contract" ) @@ -59,4 +75,4 @@ def test_jax_multinomial_uses_and_advances_global_state(self): if __name__ == "__main__": - unittest.main() + unittest.main() \ No newline at end of file