diff --git a/test/test_ops.py b/test/test_ops.py index 9521f21a815..74bd3ee3522 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -770,7 +770,8 @@ def test_is_leaf_node(self, device): class TestNMS: - def _reference_nms(self, boxes, scores, iou_threshold): + @classmethod + def _reference_nms(cls, boxes, scores, iou_threshold): """ Args: boxes: boxes in corner-form @@ -825,8 +826,8 @@ def test_nms_ref(self, iou, seed): def test_nms_input_errors(self): with pytest.raises(RuntimeError): ops.nms(torch.rand(4), torch.rand(3), 0.5) - with pytest.raises(RuntimeError): - ops.nms(torch.rand(3, 5), torch.rand(3), 0.5) + with pytest.raises((RuntimeError, ValueError)): + ops.nms(torch.rand(3, 6), torch.rand(3), 0.5) with pytest.raises(RuntimeError): ops.nms(torch.rand(3, 4), torch.rand(3, 2), 0.5) with pytest.raises(RuntimeError): @@ -2007,6 +2008,87 @@ def test_cuda_cpu_consistency(self): torch.testing.assert_close(iou_cpu, iou_cuda.cpu(), atol=1e-5, rtol=1e-5) +class TestNMSRotated: + @staticmethod + def _create_tensors(N, device="cpu"): + boxes = torch.rand(N, 4, device=device) * 200 + boxes[:, 2:] += boxes[:, :2] + scores = torch.rand(N, device=device) + return boxes, scores + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_0_degree(self, iou): + torch.manual_seed(0) + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + + keep_ref = TestNMS._reference_nms(boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou) + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_90_degrees(self, iou): + torch.manual_seed(0) + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + # Swap width and height for 90 degrees so reference horizontal NMS can be used + rotated_boxes[:, 2] = boxes[:, 3] - boxes[:, 1] + rotated_boxes[:, 3] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 4] = 90 + + keep_ref = TestNMS._reference_nms(boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou) + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_nms_rotated_180_degrees(self, iou): + torch.manual_seed(0) + N = 1000 + boxes, scores = self._create_tensors(N) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + rotated_boxes[:, 4] = 180 + + keep_ref = TestNMS._reference_nms(boxes, scores, iou) + keep = ops.nms(rotated_boxes, scores, iou) + torch.testing.assert_close(keep, keep_ref, atol=0, rtol=0) + keep_non_rotated = ops.nms(boxes, scores, iou) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + + @pytest.mark.parametrize("iou", (0.2, 0.5, 0.8)) + def test_batched_nms_rotated_0_degree(self, iou): + torch.manual_seed(0) + N = 2000 + num_classes = 50 + boxes, scores = self._create_tensors(N) + idxs = torch.randint(0, num_classes, (N,)) + rotated_boxes = torch.zeros(N, 5) + rotated_boxes[:, 0] = (boxes[:, 0] + boxes[:, 2]) / 2.0 + rotated_boxes[:, 1] = (boxes[:, 1] + boxes[:, 3]) / 2.0 + rotated_boxes[:, 2] = boxes[:, 2] - boxes[:, 0] + rotated_boxes[:, 3] = boxes[:, 3] - boxes[:, 1] + backup = rotated_boxes.clone() + keep_non_rotated = ops.batched_nms(boxes, scores, idxs, iou) + keep = ops.batched_nms(rotated_boxes, scores, idxs, iou) + assert torch.allclose(rotated_boxes, backup) + torch.testing.assert_close(keep, keep_non_rotated, atol=0, rtol=0) + + def get_boxes(dtype, device): box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) diff --git a/torchvision/_autograd_registrations.py b/torchvision/_autograd_registrations.py index 18d9ced6c54..564657ee35a 100644 --- a/torchvision/_autograd_registrations.py +++ b/torchvision/_autograd_registrations.py @@ -235,6 +235,15 @@ def _autocast_nms(dets, scores, iou_threshold): ) +def _autocast_nms_rotated(dets, scores, iou_threshold): + with torch._C._ExcludeDispatchKeyGuard(_all_autocast_keys): + return torch.ops.torchvision.nms_rotated( + _autocast_cast(dets), + _autocast_cast(scores), + iou_threshold, + ) + + def _autocast_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned): orig_dtype = input.dtype with torch._C._ExcludeDispatchKeyGuard(_all_autocast_keys): @@ -358,6 +367,7 @@ def _autocast_deform_conv2d( # nms and roi_align: registered for all autocast device types for _key in ("AutocastCUDA", "AutocastCPU", "AutocastXPU"): _autocast_lib.impl("nms", _autocast_nms, _key) + _autocast_lib.impl("nms_rotated", _autocast_nms_rotated, _key) _autocast_lib.impl("roi_align", _autocast_roi_align, _key) # Other ops: CUDA autocast only diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index f75bfb77a7f..e7a183250e4 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -174,6 +174,20 @@ def meta_nms(dets, scores, iou_threshold): return dets.new_empty(num_to_keep, dtype=torch.long) +@torch.library.register_fake("torchvision::nms_rotated") +def meta_nms_rotated(dets, scores, iou_threshold): + torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") + torch._check(dets.size(1) == 5, lambda: f"boxes should have 5 elements in dimension 1, got {dets.size(1)}") + torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") + torch._check( + dets.size(0) == scores.size(0), + lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", + ) + ctx = torch._custom_ops.get_ctx() + num_to_keep = ctx.create_unbacked_symint() + return dets.new_empty(num_to_keep, dtype=torch.long) + + @register_meta("deform_conv2d") def meta_deform_conv2d( input, diff --git a/torchvision/csrc/ops/cpu/nms_kernel.cpp b/torchvision/csrc/ops/cpu/nms_kernel.cpp index 454ce118a6d..9a5eb4f242f 100644 --- a/torchvision/csrc/ops/cpu/nms_kernel.cpp +++ b/torchvision/csrc/ops/cpu/nms_kernel.cpp @@ -1,16 +1,25 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + #include #include +#include "../box_iou_rotated_utils.h" + namespace vision { namespace ops { namespace { -template +template at::Tensor nms_kernel_impl( const at::Tensor& dets, const at::Tensor& scores, - double iou_threshold) { + double iou_threshold, + IoUFunc iou_func) { TORCH_CHECK(dets.is_cpu(), "dets must be a CPU tensor"); TORCH_CHECK(scores.is_cpu(), "scores must be a CPU tensor"); TORCH_CHECK( @@ -21,13 +30,6 @@ at::Tensor nms_kernel_impl( return at::empty({0}, dets.options().dtype(at::kLong)); } - auto x1_t = dets.select(1, 0).contiguous(); - auto y1_t = dets.select(1, 1).contiguous(); - auto x2_t = dets.select(1, 2).contiguous(); - auto y2_t = dets.select(1, 3).contiguous(); - - at::Tensor areas_t = (x2_t - x1_t) * (y2_t - y1_t); - auto order_t = std::get<1>( scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); @@ -38,11 +40,6 @@ at::Tensor nms_kernel_impl( auto suppressed = suppressed_t.data_ptr(); auto keep = keep_t.data_ptr(); auto order = order_t.data_ptr(); - auto x1 = x1_t.data_ptr(); - auto y1 = y1_t.data_ptr(); - auto x2 = x2_t.data_ptr(); - auto y2 = y2_t.data_ptr(); - auto areas = areas_t.data_ptr(); int64_t num_to_keep = 0; @@ -52,26 +49,16 @@ at::Tensor nms_kernel_impl( continue; } keep[num_to_keep++] = i; - auto ix1 = x1[i]; - auto iy1 = y1[i]; - auto ix2 = x2[i]; - auto iy2 = y2[i]; - auto iarea = areas[i]; + + iou_func.set_box(i); for (int64_t _j = _i + 1; _j < ndets; _j++) { auto j = order[_j]; if (suppressed[j] == 1) { continue; } - auto xx1 = std::max(ix1, x1[j]); - auto yy1 = std::max(iy1, y1[j]); - auto xx2 = std::min(ix2, x2[j]); - auto yy2 = std::min(iy2, y2[j]); - - auto w = std::max(static_cast(0), xx2 - xx1); - auto h = std::max(static_cast(0), yy2 - yy1); - auto inter = w * h; - auto ovr = inter / (iarea + areas[j] - inter); + + auto ovr = iou_func.compute(j); if (ovr > iou_threshold) { suppressed[j] = 1; } @@ -80,6 +67,70 @@ at::Tensor nms_kernel_impl( return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); } +template +struct NonRotatedIoU { + const scalar_t* x1; + const scalar_t* y1; + const scalar_t* x2; + const scalar_t* y2; + const scalar_t* areas; + at::Tensor x1_t, y1_t, x2_t, y2_t, areas_t; + + scalar_t ix1, iy1, ix2, iy2, iarea; + + NonRotatedIoU(const at::Tensor& dets) { + x1_t = dets.select(1, 0).contiguous(); + y1_t = dets.select(1, 1).contiguous(); + x2_t = dets.select(1, 2).contiguous(); + y2_t = dets.select(1, 3).contiguous(); + areas_t = (x2_t - x1_t) * (y2_t - y1_t); + x1 = x1_t.data_ptr(); + y1 = y1_t.data_ptr(); + x2 = x2_t.data_ptr(); + y2 = y2_t.data_ptr(); + areas = areas_t.data_ptr(); + } + + void set_box(int64_t i) { + ix1 = x1[i]; + iy1 = y1[i]; + ix2 = x2[i]; + iy2 = y2[i]; + iarea = areas[i]; + } + + scalar_t compute(int64_t j) const { + auto xx1 = std::max(ix1, x1[j]); + auto yy1 = std::max(iy1, y1[j]); + auto xx2 = std::min(ix2, x2[j]); + auto yy2 = std::min(iy2, y2[j]); + + auto w = std::max(static_cast(0), xx2 - xx1); + auto h = std::max(static_cast(0), yy2 - yy1); + auto inter = w * h; + return inter / (iarea + areas[j] - inter); + } +}; + +template +struct RotatedIoU { + const at::Tensor* dets_ptr; + + RotatedIoU(const at::Tensor& dets) : dets_ptr(&dets) {} + + int64_t i; + + void set_box(int64_t i) { + this->i = i; + } + + scalar_t compute(int64_t j) const { + return single_box_iou_rotated( + (*dets_ptr)[i].template data_ptr(), + (*dets_ptr)[j].template data_ptr()); + } +}; + at::Tensor nms_kernel( const at::Tensor& dets, const at::Tensor& scores, @@ -106,7 +157,40 @@ at::Tensor nms_kernel( auto result = at::empty({0}, dets.options()); AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_kernel", [&] { - result = nms_kernel_impl(dets, scores, iou_threshold); + result = nms_kernel_impl( + dets, scores, iou_threshold, NonRotatedIoU(dets)); + }); + return result; +} + +at::Tensor nms_rotated_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + TORCH_CHECK( + dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); + TORCH_CHECK( + dets.size(1) == 5, + "boxes should have 5 elements in dimension 1, got ", + dets.size(1)); + TORCH_CHECK( + scores.dim() == 1, + "scores should be a 1d tensor, got ", + scores.dim(), + "D"); + TORCH_CHECK( + dets.size(0) == scores.size(0), + "boxes and scores should have same number of elements in ", + "dimension 0, got ", + dets.size(0), + " and ", + scores.size(0)); + + auto result = at::empty({0}, dets.options()); + + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated_kernel", [&] { + result = nms_kernel_impl( + dets, scores, iou_threshold, RotatedIoU(dets)); }); return result; } @@ -115,6 +199,9 @@ at::Tensor nms_kernel( TORCH_LIBRARY_IMPL(torchvision, CPU, m) { m.impl(TORCH_SELECTIVE_NAME("torchvision::nms"), TORCH_FN(nms_kernel)); + m.impl( + TORCH_SELECTIVE_NAME("torchvision::nms_rotated"), + TORCH_FN(nms_rotated_kernel)); } } // namespace ops diff --git a/torchvision/csrc/ops/nms_rotated.cpp b/torchvision/csrc/ops/nms_rotated.cpp new file mode 100644 index 00000000000..da619e6c32c --- /dev/null +++ b/torchvision/csrc/ops/nms_rotated.cpp @@ -0,0 +1,33 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "nms_rotated.h" + +#include +#include +#include + +namespace vision { +namespace ops { + +at::Tensor nms_rotated( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.nms_rotated.nms_rotated"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::nms_rotated", "") + .typed(); + return op.call(dets, scores, iou_threshold); +} + +TORCH_LIBRARY_FRAGMENT(torchvision, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "torchvision::nms_rotated(Tensor dets, Tensor scores, float iou_threshold) -> Tensor")); +} + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/nms_rotated.h b/torchvision/csrc/ops/nms_rotated.h new file mode 100644 index 00000000000..98bc225f691 --- /dev/null +++ b/torchvision/csrc/ops/nms_rotated.h @@ -0,0 +1,21 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include "../macros.h" + +namespace vision { +namespace ops { + +VISION_API at::Tensor nms_rotated( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold); + +} // namespace ops +} // namespace vision diff --git a/torchvision/csrc/ops/ops.h b/torchvision/csrc/ops/ops.h index 9902c3b1ecd..173f2ef77b1 100644 --- a/torchvision/csrc/ops/ops.h +++ b/torchvision/csrc/ops/ops.h @@ -3,6 +3,7 @@ #include "box_iou_rotated.h" #include "deform_conv2d.h" #include "nms.h" +#include "nms_rotated.h" #include "ps_roi_align.h" #include "ps_roi_pool.h" #include "roi_align.h" diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index a089af2c4ad..43692538232 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -32,9 +32,12 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: to the behavior of argsort in PyTorch when repeated values are present. Args: - boxes (Tensor[N, 4])): boxes to perform NMS on. They - are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and - ``0 <= y1 < y2``. + boxes (Tensor[N, K])): boxes to perform NMS on. + If K=4, boxes are expected to be in ``(x1, y1, x2, y2)`` format + with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + If K=5, boxes are expected to be in ``(cx, cy, w, h, angle)`` format + for rotated boxes, where ``(cx, cy)`` is the center, ``(w, h)`` is + width and height, and ``angle`` is the rotation angle in degrees. scores (Tensor[N]): scores for each one of the boxes iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold @@ -45,7 +48,15 @@ def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(nms) _assert_has_ops() - return torch.ops.torchvision.nms(boxes, scores, iou_threshold) + + if boxes.size(-1) == 4: + return torch.ops.torchvision.nms(boxes, scores, iou_threshold) + elif boxes.size(-1) == 5: + return torch.ops.torchvision.nms_rotated(boxes, scores, iou_threshold) + else: + raise ValueError( + f"boxes should have 4 (axis-aligned) or 5 (rotated) elements in the last dimension, got {boxes.size(-1)}" + ) def batched_nms( @@ -61,9 +72,9 @@ def batched_nms( will not be applied between elements of different categories. Args: - boxes (Tensor[N, 4]): boxes where NMS will be performed. They - are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and - ``0 <= y1 < y2``. + boxes (Tensor[N, K]): boxes where NMS will be performed. + If K=4, boxes are expected to be in ``(x1, y1, x2, y2)`` format with ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + If K=5, boxes are expected to be in ``(cx, cy, w, h, angle)`` format. scores (Tensor[N]): scores for each one of the boxes idxs (Tensor[N]): indices of the categories for each one of the boxes. iou_threshold (float): discards all overlapping boxes with IoU > iou_threshold @@ -98,7 +109,11 @@ def _batched_nms_coordinate_trick( return torch.empty((0,), dtype=torch.int64, device=boxes.device) max_coordinate = boxes.max() offsets = idxs.to(boxes) * (max_coordinate + torch.tensor(1).to(boxes)) - boxes_for_nms = boxes + offsets[:, None] + if boxes.size(-1) == 4: + boxes_for_nms = boxes + offsets[:, None] + else: + boxes_for_nms = boxes.clone() + boxes_for_nms[..., :2] = boxes[..., :2] + offsets[:, None] keep = nms(boxes_for_nms, scores, iou_threshold) return keep