Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions onnxscript/function_libs/torch_lib/ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ def _process_sampling_ratio_for_roi_align(sampling_ratio: int):
@torch_op("torchvision::roi_align", trace_only=True)
def torchvision_roi_align(
input,
boxes,
output_size: Sequence[int],
spatial_scale: float = 1.0,
rois,
spatial_scale: float,
pooled_height: int,
pooled_width: int,
sampling_ratio: int = -1,
aligned: bool = False,
):
"""roi_align(input: torch.Tensor, boxes: Union[torch.Tensor, list[torch.Tensor]], output_size: None, spatial_scale: float = 1.0, sampling_ratio: int = -1, aligned: bool = False) -> torch.Tensor"""
pooled_height, pooled_width = output_size
batch_indices = _process_batch_indices_for_roi_align(boxes)
rois_coords = _process_rois_for_roi_align(boxes)
"""torchvision::roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, bool aligned) -> Tensor"""
batch_indices = _process_batch_indices_for_roi_align(rois)
rois_coords = _process_rois_for_roi_align(rois)
coordinate_transformation_mode = "half_pixel" if aligned else "output_half_pixel"
sampling_ratio = _process_sampling_ratio_for_roi_align(sampling_ratio)

Expand Down
101 changes: 29 additions & 72 deletions tests/function_libs/torch_lib/extra_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1471,81 +1471,38 @@ def sample_inputs_replication_pad1d(op_info, device, dtype, requires_grad, **kwa


def sample_inputs_roi_align(op_info, device, dtype, requires_grad, **kwargs):
del op_info
del kwargs
# roi_align signature: (input, boxes, output_size, spatial_scale=1.0, sampling_ratio=-1, aligned=False)

# Test 1: spatial_scale=1, sampling_ratio=2
x1 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
roi1 = torch.tensor([[0, 1.5, 1.5, 3, 3]], dtype=dtype, device=device)
yield opinfo_core.SampleInput(
x1,
args=(roi1, (5, 5)),
kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": True},
)

# Test 2: spatial_scale=0.5, sampling_ratio=3
x2 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
roi2 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
yield opinfo_core.SampleInput(
x2,
args=(roi2, (5, 5)),
kwargs={"spatial_scale": 0.5, "sampling_ratio": 3, "aligned": True},
)

# Test 3: spatial_scale=1.8, sampling_ratio=2
x3 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
roi3 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
yield opinfo_core.SampleInput(
x3,
args=(roi3, (5, 5)),
kwargs={"spatial_scale": 1.8, "sampling_ratio": 2, "aligned": True},
)

# Test 4: spatial_scale=2.5, sampling_ratio=0, output_size=(2,2)
x4 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
roi4 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
yield opinfo_core.SampleInput(
x4,
args=(roi4, (2, 2)),
kwargs={"spatial_scale": 2.5, "sampling_ratio": 0, "aligned": True},
)
del op_info, kwargs

# Test 5: spatial_scale=2.5, sampling_ratio=-1, output_size=(2,2)
x5 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
roi5 = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
yield opinfo_core.SampleInput(
x5,
args=(roi5, (2, 2)),
kwargs={"spatial_scale": 2.5, "sampling_ratio": -1, "aligned": True},
)
def make_x():
return torch.rand(
1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad
)

# Test 6: malformed boxes (test_roi_align_malformed_boxes)
x6 = torch.randn(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
roi6 = torch.tensor([[0, 2, 0.3, 1.5, 1.5]], dtype=dtype, device=device)
yield opinfo_core.SampleInput(
x6,
args=(roi6, (5, 5)),
kwargs={"spatial_scale": 1.0, "sampling_ratio": 1, "aligned": True},
)
# rois is [K, 5] = [batch_idx, x1, y1, x2, y2]
roi_a = torch.tensor([[0, 1.5, 1.5, 3.0, 3.0]], dtype=dtype, device=device)
roi_b = torch.tensor([[0, 0.2, 0.3, 4.5, 3.5]], dtype=dtype, device=device)
roi_int = torch.tensor([[0, 0.0, 0.0, 4.0, 4.0]], dtype=dtype, device=device)
roi_malformed = torch.tensor(
[[0, 2.0, 0.3, 1.5, 1.5]], dtype=dtype, device=device
) # x1 > x2-ish

# Test 7: aligned=False, spatial_scale=1, sampling_ratio=2
x7 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
roi7 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
yield opinfo_core.SampleInput(
x7,
args=(roi7, (5, 5)),
kwargs={"spatial_scale": 1.0, "sampling_ratio": 2, "aligned": False},
)
# (rois, spatial_scale, pooled_h, pooled_w, sampling_ratio, aligned)
cases = [
(roi_a, 1.0, 5, 5, 2, True),
(roi_b, 0.5, 5, 5, 3, True),
(roi_b, 1.8, 5, 5, 2, True),
(roi_b, 2.5, 2, 2, 0, True),
(roi_b, 2.5, 2, 2, -1, True),
(roi_malformed, 1.0, 5, 5, 1, True),
(roi_int, 1.0, 5, 5, 2, False),
(roi_int, 1.0, 5, 5, -1, False),
]

# Test 8: aligned=False, spatial_scale=1, sampling_ratio=-1
x8 = torch.rand(1, 1, 10, 10, dtype=dtype, device=device, requires_grad=requires_grad)
roi8 = torch.tensor([[0, 0, 0, 4, 4]], dtype=dtype, device=device)
yield opinfo_core.SampleInput(
x8,
args=(roi8, (5, 5)),
kwargs={"spatial_scale": 1.0, "sampling_ratio": -1, "aligned": False},
)
for rois, spatial_scale, ph, pw, sr, aligned in cases:
yield opinfo_core.SampleInput(
make_x(),
args=(rois, float(spatial_scale), int(ph), int(pw), int(sr), bool(aligned)),
)


def sample_inputs_roi_pool(op_info, device, dtype, requires_grad, **kwargs):
Expand Down Expand Up @@ -3132,7 +3089,7 @@ def __init__(self):
),
opinfo_core.OpInfo(
"torchvision.ops.roi_align",
op=torchvision.ops.roi_align,
op=torch.ops.torchvision.roi_align.default,
dtypes=common_dtype.floating_types(),
sample_inputs_func=sample_inputs_roi_align,
supports_out=False,
Expand Down
Loading