From 20848109969f9c8accb228aac0231140ce3e92cb Mon Sep 17 00:00:00 2001 From: FraGirla <66215169+FraGirla@users.noreply.github.com> Date: Wed, 25 Feb 2026 19:02:17 +0100 Subject: [PATCH 1/5] [torchlib] Fix torchvision::roi_align lowering to accept 7-arg schema MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Dynamo ONNX export calls torchvision::roi_align with (input, rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned), but torchlib’s torchvision_roi_align expected output_size and rejected the extra args. Update torchvision_roi_align to take pooled_h/pooled_w directly and add an OpInfo input_wrangler to adapt wrapper-style roi_align(input, boxes, output_size, ...). Refs pytorch/pytorch#175732 Test: pytest tests/function_libs/torch_lib/ops_test.py -k roi_align -v --- .../function_libs/torch_lib/ops/vision.py | 13 ++++--- .../function_libs/torch_lib/ops_test_data.py | 35 ++++++++++++++++++- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py index 5c1b1fda6b..65386f1577 100644 --- a/onnxscript/function_libs/torch_lib/ops/vision.py +++ b/onnxscript/function_libs/torch_lib/ops/vision.py @@ -55,16 +55,15 @@ def _process_sampling_ratio_for_roi_align(sampling_ratio: int): @torch_op("torchvision::roi_align", trace_only=True) def torchvision_roi_align( input, - boxes, - output_size: Sequence[int], - spatial_scale: float = 1.0, + rois, + spatial_scale: float, + pooled_height: int, + pooled_width: int, sampling_ratio: int = -1, aligned: bool = False, ): - """roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor""" - pooled_height, pooled_width = output_size - batch_indices = _process_batch_indices_for_roi_align(boxes) - rois_coords = _process_rois_for_roi_align(boxes) + 
batch_indices = _process_batch_indices_for_roi_align(rois) + rois_coords = _process_rois_for_roi_align(rois) coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel" sampling_ratio = _process_sampling_ratio_for_roi_align(sampling_ratio) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index a40535f4ba..d7d62c75cf 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -449,6 +449,35 @@ def _where_input_wrangler( return args, kwargs +def _torchvision_roi_align_default_input_wrangler( + args: list[Any], kwargs: dict[str, Any] +) -> tuple[list[Any], dict[str, Any]]: + # Convert: + # roi_align(input, boxes, output_size, spatial_scale=..., sampling_ratio=..., aligned=...) + # into: + # roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned) + output_size = args.pop(2) + if isinstance(output_size, np.ndarray): + if output_size.ndim == 0: + pooled_height = int(output_size) + pooled_width = int(output_size) + else: + pooled_height, pooled_width = output_size.tolist() + elif isinstance(output_size, (tuple, list)): + pooled_height, pooled_width = output_size + else: + pooled_height = output_size + pooled_width = output_size + + pooled_height = int(pooled_height) + pooled_width = int(pooled_width) + spatial_scale = float(kwargs.pop("spatial_scale", 1.0)) + sampling_ratio = int(kwargs.pop("sampling_ratio", -1)) + aligned = bool(kwargs.pop("aligned", False)) + args.extend([spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned]) + return args, {} + + # Ops to be tested for numerical consistency between onnx and pytorch # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] 
= ( @@ -1919,7 +1948,11 @@ def _where_input_wrangler( ), TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like), TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms), - TorchLibOpInfo("torchvision.ops.roi_align", vision_ops.torchvision_roi_align), + TorchLibOpInfo( + "torchvision.ops.roi_align", + vision_ops.torchvision_roi_align, + input_wrangler=_torchvision_roi_align_default_input_wrangler, + ), TorchLibOpInfo("torchvision.ops.roi_pool", vision_ops.torchvision_roi_pool), ) From a7799becdc5079d77a131e8de102d8fbab472f45 Mon Sep 17 00:00:00 2001 From: FraGirla <66215169+FraGirla@users.noreply.github.com> Date: Wed, 25 Feb 2026 19:02:26 +0100 Subject: [PATCH 2/5] Add docstring to torchvision_roi_align for improved clarity --- onnxscript/function_libs/torch_lib/ops/vision.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py index 65386f1577..513010beb7 100644 --- a/onnxscript/function_libs/torch_lib/ops/vision.py +++ b/onnxscript/function_libs/torch_lib/ops/vision.py @@ -62,6 +62,7 @@ def torchvision_roi_align( sampling_ratio: int = -1, aligned: bool = False, ): + """torchvision::roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)""" batch_indices = _process_batch_indices_for_roi_align(rois) rois_coords = _process_rois_for_roi_align(rois) coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel" From db3701c12790cf180eedbc5495a2512f34389440 Mon Sep 17 00:00:00 2001 From: FraGirla <66215169+FraGirla@users.noreply.github.com> Date: Wed, 25 Feb 2026 21:46:41 +0100 Subject: [PATCH 3/5] Added types in docstring --- onnxscript/function_libs/torch_lib/ops/vision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/vision.py b/onnxscript/function_libs/torch_lib/ops/vision.py index 513010beb7..01ed9314bc 100644 --- 
a/onnxscript/function_libs/torch_lib/ops/vision.py +++ b/onnxscript/function_libs/torch_lib/ops/vision.py @@ -62,7 +62,7 @@ def torchvision_roi_align( sampling_ratio: int = -1, aligned: bool = False, ): - """torchvision::roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned)""" + """torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor""" batch_indices = _process_batch_indices_for_roi_align(rois) rois_coords = _process_rois_for_roi_align(rois) coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel" From 7211569dcbb979a1b00408b2d28470baf34ec869 Mon Sep 17 00:00:00 2001 From: FraGirla <66215169+FraGirla@users.noreply.github.com> Date: Wed, 25 Feb 2026 21:51:35 +0100 Subject: [PATCH 4/5] Refactor sample_inputs_roi_align for improved clarity and efficiency; remove redundant tests and simplify input handling --- tests/function_libs/torch_lib/extra_opinfo.py | 97 +++++-------------- .../function_libs/torch_lib/ops_test_data.py | 36 +------ 2 files changed, 26 insertions(+), 107 deletions(-) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index a28a6c9cd9..bc1299255e 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1471,81 +1471,34 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs): - del op_info - del kwargs - # roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False) - - # Test 1: spatial_scale=1, sampling_ratio=2 - x1 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi1 = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x1, - args=(roi1, (5, 
5)), - kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": True}, - ) - - # Test 2: spatial_scale=0.5, sampling_ratio=3 - x2 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi2 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x2, - args=(roi2, (5, 5)), - kwargs={"spatial_scale": 0.5, "sampling_ratio": 3, "aligned": True}, - ) - - # Test 3: spatial_scale=1.8, sampling_ratio=2 - x3 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi3 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x3, - args=(roi3, (5, 5)), - kwargs={"spatial_scale": 1.8, "sampling_ratio": 2, "aligned": True}, - ) - - # Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2) - x4 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi4 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x4, - args=(roi4, (2, 2)), - kwargs={"spatial_scale": 2.5, "sampling_ratio": 0, "aligned": True}, - ) + del op_info, kwargs - # Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2) - x5 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi5 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x5, - args=(roi5, (2, 2)), - kwargs={"spatial_scale": 2.5, "sampling_ratio": -1, "aligned": True}, - ) + def make_x(): + return torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - # Test 6: malformed boxes (test_roi_align_malformed_boxes) - x6 = torch.randn(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi6 = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x6, - args=(roi6, (5, 5)), - kwargs={"spatial_scale": 1.0, "sampling_ratio": 1, 
"aligned": True}, - ) + # rois is [K, 5] = [batch_idx, x1, y1, x2, y2] + roi_a = torch.tensor([[0, 1.5, 1.5, 3.0, 3.0]], dtype=dtype, device=device) + roi_b = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) + roi_int = torch.tensor([[0, 0.0, 0.0, 4.0, 4.0]], dtype=dtype, device=device) + roi_malformed = torch.tensor([[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device) # x1 > x2-ish - # Test 7: aligned=False, spatial_scale=1, sampling_ratio=2 - x7 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi7 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x7, - args=(roi7, (5, 5)), - kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": False}, - ) + # (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned) + cases = [ + (roi_a, 1.0, 5, 5, 2, True), + (roi_b, 0.5, 5, 5, 3, True), + (roi_b, 1.8, 5, 5, 2, True), + (roi_b, 2.5, 2, 2, 0, True), + (roi_b, 2.5, 2, 2, -1, True), + (roi_malformed, 1.0, 5, 5, 1, True), + (roi_int, 1.0, 5, 5, 2, False), + (roi_int, 1.0, 5, 5, -1, False), + ] - # Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1 - x8 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) - roi8 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device) - yield opinfo_core.SampleInput( - x8, - args=(roi8, (5, 5)), - kwargs={"spatial_scale": 1.0, "sampling_ratio": -1, "aligned": False}, - ) + for rois, spatial_scale, ph, pw, sr, aligned in cases: + yield opinfo_core.SampleInput( + make_x(), + args=(rois, float(spatial_scale), int(ph), int(pw), int(sr), bool(aligned)), + ) def sample_inputs_roi_pool(op_info, device, dtype, requires_grad, **kwargs): @@ -3132,7 +3085,7 @@ def __init__(self): ), opinfo_core.OpInfo( "torchvision.ops.roi_align", - op=torchvision.ops.roi_align, + op=torch.ops.torchvision.roi_align.default, dtypes=common_dtype.floating_types(), sample_inputs_func=sample_inputs_roi_align, 
supports_out=False, diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index d7d62c75cf..27c931d03b 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -448,36 +448,6 @@ def _where_input_wrangler( args[0], args[1] = args[1], args[0] return args, kwargs - -def _torchvision_roi_align_default_input_wrangler( - args: list[Any], kwargs: dict[str, Any] -) -> tuple[list[Any], dict[str, Any]]: - # Convert: - # roi_align(input, boxes, output_size, spatial_scale=..., sampling_ratio=..., aligned=...) - # into: - # roi_align(input, boxes, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned) - output_size = args.pop(2) - if isinstance(output_size, np.ndarray): - if output_size.ndim == 0: - pooled_height = int(output_size) - pooled_width = int(output_size) - else: - pooled_height, pooled_width = output_size.tolist() - elif isinstance(output_size, (tuple, list)): - pooled_height, pooled_width = output_size - else: - pooled_height = output_size - pooled_width = output_size - - pooled_height = int(pooled_height) - pooled_width = int(pooled_width) - spatial_scale = float(kwargs.pop("spatial_scale", 1.0)) - sampling_ratio = int(kwargs.pop("sampling_ratio", -1)) - aligned = bool(kwargs.pop("aligned", False)) - args.extend([spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned]) - return args, {} - - # Ops to be tested for numerical consistency between onnx and pytorch # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] 
= ( @@ -1948,11 +1918,7 @@ def _torchvision_roi_align_default_input_wrangler( ), TorchLibOpInfo("zeros_like", core_ops.aten_zeros_like), TorchLibOpInfo("torchvision.ops.nms", vision_ops.torchvision_nms), - TorchLibOpInfo( - "torchvision.ops.roi_align", - vision_ops.torchvision_roi_align, - input_wrangler=_torchvision_roi_align_default_input_wrangler, - ), + TorchLibOpInfo("torchvision.ops.roi_align", vision_ops.torchvision_roi_align), TorchLibOpInfo("torchvision.ops.roi_pool", vision_ops.torchvision_roi_pool), ) From 18cd6214e2e1b99f2a289bcb4521b4c8b7a99a38 Mon Sep 17 00:00:00 2001 From: FraGirla <66215169+FraGirla@users.noreply.github.com> Date: Thu, 26 Feb 2026 00:40:30 +0100 Subject: [PATCH 5/5] Fixed lintrunner --- tests/function_libs/torch_lib/extra_opinfo.py | 8 ++++++-- tests/function_libs/torch_lib/ops_test_data.py | 1 + 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/function_libs/torch_lib/extra_opinfo.py b/tests/function_libs/torch_lib/extra_opinfo.py index bc1299255e..29df92f097 100644 --- a/tests/function_libs/torch_lib/extra_opinfo.py +++ b/tests/function_libs/torch_lib/extra_opinfo.py @@ -1474,13 +1474,17 @@ def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs): del op_info, kwargs def make_x(): - return torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad) + return torch.rand( + 1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad + ) # rois is [K, 5] = [batch_idx, x1, y1, x2, y2] roi_a = torch.tensor([[0, 1.5, 1.5, 3.0, 3.0]], dtype=dtype, device=device) roi_b = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device) roi_int = torch.tensor([[0, 0.0, 0.0, 4.0, 4.0]], dtype=dtype, device=device) - roi_malformed = torch.tensor([[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device) # x1 > x2-ish + roi_malformed = torch.tensor( + [[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device + ) # x1 > x2-ish # (rois, spatial_scale, pooled_h, pooled_w, 
sampling_ratio, aligned) cases = [ diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 27c931d03b..a40535f4ba 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -448,6 +448,7 @@ def _where_input_wrangler( args[0], args[1] = args[1], args[0] return args, kwargs + # Ops to be tested for numerical consistency between onnx and pytorch # Find the names of the OpInfos in torch/testing/_internal/common_methods_invocations.py TESTED_TORCHLIB_OPS: tuple[TorchLibOpInfo, ...] = (