diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py index 5c1b1fda6b..01ed9314bc 100644 --- a/onnxscript/function_libs/torch_lib/ops/vision.py +++ b/onnxscript/function_libs/torch_lib/ops/vision.py @@ -55,16 +55,16 @@ def _process_sampling_ratio_for_roi_align(sampling_ratio: int): @torch_op("torchvision::roi_align", trace_only=True) def torchvision_roi_align( input, - boxes, - output_size: Sequence[int], - spatial_scale: float = 1.0, + rois, + spatial_scale: float, + pooled_height: int, + pooled_width: int, sampling_ratio: int = -1, aligned: bool = False, ): - """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor""" - pooled_height, pooled_width = output_size - batch_indices = _process_batch_indices_for_roi_align(boxes) - rois_coords = _process_rois_for_roi_align(boxes) + """torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor""" + batch_indices = _process_batch_indices_for_roi_align(rois) + rois_coords = _process_rois_for_roi_align(rois) coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel" sampling_ratio = _process_sampling_ratio_for_roi_align(sampling_ratio) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index a28a6c9cd9..29df92f097 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1471,81 +1471,38 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs - # roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False) - - # Test 1: 
spatial_scale=1, sampling_ratio=2 - x1 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi1 = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x1, - args=(roi1, (5, 5)), - kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": True}, - ) - - # Test 2: spatial_scale=0.5, sampling_ratio=3 - x2 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi2 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x2, - args=(roi2, (5, 5)), - kwargs={"spatial_scale": 0.5, "sampling_ratio": 3, "aligned": True}, - ) - - # Test 3: spatial_scale=1.8, sampling_ratio=2 - x3 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi3 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x3, - args=(roi3, (5, 5)), - kwargs={"spatial_scale": 1.8, "sampling_ratio": 2, "aligned": True}, - ) - - # Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2) - x4 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi4 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x4, - args=(roi4, (2, 2)), - kwargs={"spatial_scale": 2.5, "sampling_ratio": 0, "aligned": True}, - ) + del op_info, kwargs - # Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2) - x5 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi5 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x5, - args=(roi5, (2, 2)), - kwargs={"spatial_scale": 2.5, "sampling_ratio": -1, "aligned": True}, - ) + def make_x(): + return torch.rand( + 1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad + ) - # Test 6: malformed boxes (test_roi_align_malformed_boxes) - x6 = 
torch.randn(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi6 = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x6, - args=(roi6, (5, 5)), - kwargs={"spatial_scale": 1.0, "sampling_ratio": 1, "aligned": True}, - ) + # rois is [K, 5] = [batch_idx, x1, y1, x2, y2] + roi_a = torch.tensor([[0, 1.5, 1.5, 3.0, 3.0]], dtype=dtype, device=device) + roi_b = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) + roi_int = torch.tensor([[0, 0.0, 0.0, 4.0, 4.0]], dtype=dtype, device=device) + roi_malformed = torch.tensor( + [[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device + ) # deliberately malformed box: x1 (2.0) > x2 (1.5) - # Test 7: aligned=False, spatial_scale=1, sampling_ratio=2 - x7 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi7 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x7, - args=(roi7, (5, 5)), - kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": False}, - ) + # (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned) + cases = [ + (roi_a, 1.0, 5, 5, 2, True), + (roi_b, 0.5, 5, 5, 3, True), + (roi_b, 1.8, 5, 5, 2, True), + (roi_b, 2.5, 2, 2, 0, True), + (roi_b, 2.5, 2, 2, -1, True), + (roi_malformed, 1.0, 5, 5, 1, True), + (roi_int, 1.0, 5, 5, 2, False), + (roi_int, 1.0, 5, 5, -1, False), + ] - # Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1 - x8 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi8 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x8, - args=(roi8, (5, 5)), - kwargs={"spatial_scale": 1.0, "sampling_ratio": -1, "aligned": False}, - ) + for rois, spatial_scale, ph, pw, sr, aligned in cases: + yield opinfo_core.SampleInput( + make_x(), + args=(rois, float(spatial_scale), int(ph), int(pw), int(sr), bool(aligned)), + ) def sample_inputs_roi_pool(op_info, device, dtype, 
requires_grad, **kwargs): @@ -3132,7 +3089,7 @@ def __init__(self): ), opinfo_core.OpInfo( "torchvision.ops.roi_align", - op=torchvision.ops.roi_align, + op=torch.ops.torchvision.roi_align.default, dtypes=common_dtype.floating_types(), sample_inputs_func=sample_inputs_roi_align, supports_out=False,