From 4eb3228aad3445e53144c98efe6280a99e09a749 Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Mon, 20 Apr 2026 07:56:54 +0800 Subject: [PATCH 1/4] [Relax][TOPI][TFLite] Add soft-NMS support to non_max_suppression for TFLite NMSv5 This commit adds soft-NMS support to relax.vision.non_max_suppression and the TFLite frontend's NonMaxSuppressionV5 converter. Key changes: - Extended NonMaximumSuppressionAttrs with soft_nms_sigma and score_threshold. - Implemented Gaussian soft-NMS decay in TOPI _classic_nms_ir: score *= exp(-iou^2 / sigma) for overlapping boxes instead of hard suppression. - Introduced conditional return type: soft-NMS returns a 3-tuple (out_data, box_indices, valid_box_count) so the TFLite frontend can extract decayed scores; hard NMS keeps the original 2-tuple. - Updated Relax struct info inference to match the conditional return type. - Removed the soft_nms_sigma != 0.0 OpNotImplemented guard in TFLite frontend. - Added tests: struct info inference, legalization, TVMScript parser, and TFLite frontend E2E smoke tests for soft-NMS. --- include/tvm/relax/attrs/vision.h | 8 +- .../relax/frontend/tflite/tflite_frontend.py | 43 +++++++--- python/tvm/relax/op/vision/nms.py | 20 ++++- .../relax/transform/legalize_ops/vision.py | 2 + python/tvm/topi/testing/nms_python.py | 24 ++++-- python/tvm/topi/vision/nms.py | 46 ++++++++-- src/relax/op/vision/nms.cc | 29 ++++++- src/relax/op/vision/nms.h | 4 +- tests/python/relax/test_frontend_tflite.py | 83 ++++++++++++++++++- tests/python/relax/test_op_vision.py | 69 +++++++++++++++ .../relax/test_tvmscript_parser_op_vision.py | 70 ++++++++++++++++ 11 files changed, 361 insertions(+), 37 deletions(-) diff --git a/include/tvm/relax/attrs/vision.h b/include/tvm/relax/attrs/vision.h index 8971127d76dc..9dec6fd5037e 100644 --- a/include/tvm/relax/attrs/vision.h +++ b/include/tvm/relax/attrs/vision.h @@ -122,6 +122,8 @@ struct NonMaximumSuppressionAttrs int id_index; bool return_indices; bool invalid_to_bottom; + double soft_nms_sigma; + double score_threshold; static void RegisterReflection() { namespace refl = tvm::ffi::reflection; @@ -143,7 +145,11 @@ struct NonMaximumSuppressionAttrs .def_ro("return_indices", &NonMaximumSuppressionAttrs::return_indices, "Whether to return box indices in input data.") .def_ro("invalid_to_bottom", &NonMaximumSuppressionAttrs::invalid_to_bottom, - "Whether to move all valid bounding boxes to the top."); + "Whether to move all valid bounding boxes to the top.") + .def_ro("soft_nms_sigma", &NonMaximumSuppressionAttrs::soft_nms_sigma, + "Sigma for soft-NMS; 0.0 means standard hard NMS.") + .def_ro("score_threshold", &NonMaximumSuppressionAttrs::score_threshold, + "Score threshold for soft-NMS validity check; 0.0 when unused."); } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs", NonMaximumSuppressionAttrs, BaseAttrsNode); diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 334021b90300..221609a1c628 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -3573,10 +3573,6 @@ def convert_nms_v5(self, op): if isinstance(soft_nms_sigma, np.ndarray): assert soft_nms_sigma.size == 1, "only one value is expected." soft_nms_sigma = float(soft_nms_sigma) - if soft_nms_sigma != 0.0: - raise tvm.error.OpNotImplemented( - "It is soft_nms when soft_nms_sigma != 0, which is not supported!" - ) scores_expand = relax.op.expand_dims(scores, axis=-1) data = relax.op.concat([scores_expand, boxes], axis=-1) @@ -3602,18 +3598,39 @@ def convert_nms_v5(self, op): id_index=-1, return_indices=True, invalid_to_bottom=False, + soft_nms_sigma=soft_nms_sigma, + score_threshold=score_threshold, ) - selected_indices = relax.op.squeeze(nms_ret[0], axis=[0]) - selected_indices = relax.op.strided_slice( - selected_indices, axes=[0], begin=[0], end=[max_output_size] - ) - num_valid = relax.op.reshape(nms_ret[1], []) + if soft_nms_sigma > 0.0: + # Soft-NMS returns (out_data, box_indices, valid_box_count) + processed_data = relax.op.squeeze(nms_ret[0], axis=[0]) + selected_indices = relax.op.squeeze(nms_ret[1], axis=[0]) + selected_indices = relax.op.strided_slice( + selected_indices, axes=[0], begin=[0], end=[max_output_size] + ) + num_valid = relax.op.reshape(nms_ret[2], []) + + # Extract decayed scores from the processed data (score_index=0) + selected_scores = relax.op.strided_slice( + processed_data, axes=[1], begin=[0], end=[1] + ) + selected_scores = relax.op.squeeze(selected_scores, axis=[1]) + selected_scores = relax.op.strided_slice( + selected_scores, axes=[0], begin=[0], end=[max_output_size] + ) + else: + # Hard NMS returns (box_indices, valid_box_count) + selected_indices = relax.op.squeeze(nms_ret[0], axis=[0]) + selected_indices = relax.op.strided_slice( + selected_indices, axes=[0], begin=[0], end=[max_output_size] + ) + num_valid = relax.op.reshape(nms_ret[1], []) - # Clamp out-of-bound padded indices to prevent take() crash. - num_boxes = int(self.get_tensor_shape(input_tensors[0])[0]) - safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes - 1) - selected_scores = relax.op.take(scores, safe_indices, axis=0) + # Clamp out-of-bound padded indices to prevent take() crash. + num_boxes = int(self.get_tensor_shape(input_tensors[0])[0]) + safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes - 1) + selected_scores = relax.op.take(scores, safe_indices, axis=0) out = relax.Tuple([selected_indices, selected_scores, num_valid]) return out diff --git a/python/tvm/relax/op/vision/nms.py b/python/tvm/relax/op/vision/nms.py index 4eb3eb7f7a78..6cf6a93f9c14 100644 --- a/python/tvm/relax/op/vision/nms.py +++ b/python/tvm/relax/op/vision/nms.py @@ -115,6 +115,8 @@ def non_max_suppression( id_index=0, return_indices=True, invalid_to_bottom=False, + soft_nms_sigma=0.0, + score_threshold=0.0, ): """Non-maximum suppression operator for object detection. @@ -160,12 +162,24 @@ def non_max_suppression( Whether to move valid bounding boxes to the top of the returned tensor. This option only affects the ``return_indices=False`` path. + soft_nms_sigma : float, optional + Sigma for soft-NMS Gaussian penalty. When ``0.0`` (default), standard + hard NMS is used. Positive values decay overlapping box scores instead + of suppressing them outright. + + score_threshold : float, optional + Minimum score for a box to be eligible for selection during soft-NMS. + Only used when ``soft_nms_sigma > 0``. Defaults to ``0.0``. + Returns ------- out : relax.Expr - If ``return_indices`` is ``True``, returns - ``(box_indices, valid_box_count)`` with shapes + If ``return_indices`` is ``True`` and ``soft_nms_sigma`` is ``0.0``, + returns ``(box_indices, valid_box_count)`` with shapes ``[batch_size, num_anchors]`` and ``[batch_size, 1]``. + If ``return_indices`` is ``True`` and ``soft_nms_sigma > 0``, + returns ``(out_data, box_indices, valid_box_count)`` where + ``out_data`` has the same shape as the input data. Otherwise returns the modified data tensor. """ return _ffi_api.non_max_suppression( @@ -181,4 +195,6 @@ def non_max_suppression( id_index, return_indices, invalid_to_bottom, + soft_nms_sigma, + score_threshold, ) diff --git a/python/tvm/relax/transform/legalize_ops/vision.py b/python/tvm/relax/transform/legalize_ops/vision.py index c515fc8fe81a..4419549164cf 100644 --- a/python/tvm/relax/transform/legalize_ops/vision.py +++ b/python/tvm/relax/transform/legalize_ops/vision.py @@ -152,6 +152,8 @@ def _non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr: id_index=call.attrs.id_index, return_indices=call.attrs.return_indices, invalid_to_bottom=call.attrs.invalid_to_bottom, + soft_nms_sigma=call.attrs.soft_nms_sigma, + score_threshold=call.attrs.score_threshold, ) diff --git a/python/tvm/topi/testing/nms_python.py b/python/tvm/topi/testing/nms_python.py index 7c8c20f5b412..757da68906ae 100644 --- a/python/tvm/topi/testing/nms_python.py +++ b/python/tvm/topi/testing/nms_python.py @@ -46,6 +46,8 @@ def non_max_suppression_python( id_index=0, return_indices=True, invalid_to_bottom=False, + soft_nms_sigma=0.0, + score_threshold=0.0, ): """Numpy reference for classic non_max_suppression. @@ -71,6 +73,9 @@ def non_max_suppression_python( compacted = np.full((batch_size, num_anchors), -1, dtype="int32") valid_box_count = np.zeros((batch_size, 1), dtype="int32") + is_soft_nms = soft_nms_sigma > 0.0 + thresh = score_threshold if is_soft_nms else 0.0 + for i in range(batch_size): nkeep = int(valid_count[i]) if 0 < top_k < nkeep: @@ -89,9 +94,10 @@ def non_max_suppression_python( # Greedy NMS num_valid = 0 for j in range(nkeep): - if out_data[i, j, score_index] <= 0: - out_data[i, j, :] = -1.0 - out_box_indices[i, j] = -1 + if out_data[i, j, score_index] <= thresh: + if not is_soft_nms: + out_data[i, j, :] = -1.0 + out_box_indices[i, j] = -1 continue if 0 < max_output_size <= num_valid: out_data[i, j, :] = -1.0 @@ -102,7 +108,7 @@ def non_max_suppression_python( # Suppress overlapping boxes for k in range(j + 1, nkeep): - if out_data[i, k, score_index] <= 0: + if out_data[i, k, score_index] <= thresh: continue do_suppress = False @@ -116,8 +122,12 @@ def non_max_suppression_python( if do_suppress: iou = _iou(out_data[i, j], out_data[i, k], coord_start) if iou >= iou_threshold: - out_data[i, k, score_index] = -1.0 - out_box_indices[i, k] = -1 + if is_soft_nms: + decay = np.exp(-(iou ** 2) / soft_nms_sigma) + out_data[i, k, score_index] *= decay + else: + out_data[i, k, score_index] = -1.0 + out_box_indices[i, k] = -1 if return_indices: # Compact valid indices to top and remap to original @@ -130,6 +140,8 @@ def non_max_suppression_python( valid_box_count[i, 0] = cnt if return_indices: + if is_soft_nms: + return [out_data, compacted, valid_box_count] return [compacted, valid_box_count] if invalid_to_bottom: diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index a602527fcc8d..5773a739c99e 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -188,6 +188,8 @@ def _classic_nms_ir( out_data, out_box_indices, out_valid_box_count, + soft_nms_sigma=0.0, + score_threshold=0.0, ): """IR for classic single-class non-maximum suppression.""" with IRBuilder() as ib: @@ -200,6 +202,10 @@ def _classic_nms_ir( if out_valid_box_count is not None: out_valid_box_count = T.buffer_proxy(out_valid_box_count) + is_soft_nms = soft_nms_sigma > 0.0 + # For hard NMS the historical threshold is 0.0; for soft NMS use score_threshold. + thresh = tvm.tirx.Cast(data.dtype, T.float32(score_threshold if is_soft_nms else 0.0)) + with T.parallel(0, batch_size) as i: # Step 1: Reorder data by sorted score nkeep_buf = T.alloc_buffer((1,), "int32", scope="local") @@ -228,10 +234,10 @@ def _classic_nms_ir( num_valid_boxes[0] = T.int32(0) with T.serial(0, nkeep_local[0]) as j: - # Check if box j is still valid (score > 0) and within max_output_size + # Check if box j is still valid (score > threshold) and within max_output_size with T.If( tvm.tirx.all( - out_data[i, j, score_index] > tvm.tirx.Cast(data.dtype, T.float32(0.0)), + out_data[i, j, score_index] > thresh, tvm.tirx.Select( max_output_size > 0, num_valid_boxes[0] < max_output_size, @@ -247,8 +253,7 @@ def _classic_nms_ir( with T.If( tvm.tirx.all( k > j, - out_data[i, k, score_index] - > tvm.tirx.Cast(data.dtype, T.float32(0.0)), + out_data[i, k, score_index] > thresh, ) ): with T.Then(): @@ -322,10 +327,23 @@ def _classic_nms_ir( with T.If(iou >= iou_threshold): with T.Then(): - out_data[i, k, score_index] = tvm.tirx.Cast( - data.dtype, T.float32(-1.0) - ) - out_box_indices[i, k] = T.int32(-1) + if is_soft_nms: + # Soft-NMS Gaussian decay + decay = tvm.tirx.exp( + -(iou * iou) + / tvm.tirx.Cast( + data.dtype, + T.float32(soft_nms_sigma), + ) + ) + out_data[i, k, score_index] = ( + out_data[i, k, score_index] * decay + ) + else: + out_data[i, k, score_index] = tvm.tirx.Cast( + data.dtype, T.float32(-1.0) + ) + out_box_indices[i, k] = T.int32(-1) with T.Else(): # Box suppressed or beyond max_output_size @@ -372,6 +390,8 @@ def non_max_suppression( id_index=0, return_indices=True, invalid_to_bottom=False, + soft_nms_sigma=0.0, + score_threshold=0.0, ): """Non-maximum suppression operator for object detection. @@ -416,6 +436,12 @@ def non_max_suppression( invalid_to_bottom : optional, boolean Whether to move all valid bounding boxes to the top. + soft_nms_sigma : optional, float + Sigma for soft-NMS Gaussian penalty. 0.0 means standard hard NMS. + + score_threshold : optional, float + Minimum score for a box to be eligible during soft-NMS. + Returns ------- out : tvm.te.Tensor or tuple of tvm.te.Tensor @@ -464,6 +490,7 @@ def non_max_suppression( coord_start, score_index, id_index, return_indices, outs[0], outs[1], outs[2], + soft_nms_sigma, score_threshold, ), dtype=[data.dtype, "int32", "int32"], out_buffers=[out_data_buf, out_box_indices_buf, out_valid_box_count_buf], @@ -471,6 +498,8 @@ def non_max_suppression( name="non_max_suppression", tag="non_max_suppression", ) + if soft_nms_sigma > 0.0: + return [out_data, out_box_indices, out_valid_box_count] return [out_box_indices, out_valid_box_count] out_data, out_box_indices = te.extern( @@ -484,6 +513,7 @@ def non_max_suppression( coord_start, score_index, id_index, return_indices, outs[0], outs[1], None, + soft_nms_sigma, score_threshold, ), dtype=[data.dtype, "int32"], out_buffers=[out_data_buf, out_box_indices_buf], diff --git a/src/relax/op/vision/nms.cc b/src/relax/op/vision/nms.cc index 97508d721189..9d7144b5d7f3 100644 --- a/src/relax/op/vision/nms.cc +++ b/src/relax/op/vision/nms.cc @@ -196,8 +196,8 @@ TVM_REGISTER_OP("relax.vision.get_valid_counts") Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int max_output_size, double iou_threshold, bool force_suppress, int top_k, int coord_start, - int score_index, int id_index, bool return_indices, - bool invalid_to_bottom) { + int score_index, int id_index, bool return_indices, bool invalid_to_bottom, + double soft_nms_sigma, double score_threshold) { auto attrs = tvm::ffi::make_object(); attrs->max_output_size = max_output_size; attrs->iou_threshold = iou_threshold; @@ -208,6 +208,8 @@ Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int max_outp attrs->id_index = id_index; attrs->return_indices = return_indices; attrs->invalid_to_bottom = invalid_to_bottom; + attrs->soft_nms_sigma = soft_nms_sigma; + attrs->score_threshold = score_threshold; static const Op& op = Op::Get("relax.vision.non_max_suppression"); return Call(op, {std::move(data), std::move(valid_count), std::move(indices)}, Attrs(attrs), {}); @@ -319,7 +321,28 @@ StructInfo InferStructInfoNMS(const Call& call, const BlockBuilder& ctx) { } if (attrs->return_indices) { - // Returns (box_indices[batch, num_anchors], valid_box_count[batch, 1]) + if (attrs->soft_nms_sigma > 0.0) { + // Soft-NMS returns (out_data[batch, num_anchors, elem_length], + // box_indices[batch, num_anchors], + // valid_box_count[batch, 1]) + if (data_shape == nullptr) { + tvm::ffi::Array fields = { + TensorStructInfo(data_sinfo->dtype, /*ndim=*/3, vdev), + TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev), + TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev)}; + return TupleStructInfo(fields); + } + auto batch = data_shape->values[0]; + auto num_anchors = data_shape->values[1]; + tvm::ffi::Array fields = { + TensorStructInfo(ffi::GetRef(data_shape), data_sinfo->dtype, vdev), + TensorStructInfo(ShapeExpr({batch, num_anchors}), DataType::Int(32), vdev), + TensorStructInfo(ShapeExpr({batch, IntImm(DataType::Int(64), 1)}), DataType::Int(32), + vdev)}; + return TupleStructInfo(fields); + } + + // Hard NMS returns (box_indices[batch, num_anchors], valid_box_count[batch, 1]) if (data_shape == nullptr) { tvm::ffi::Array fields = { TensorStructInfo(DataType::Int(32), /*ndim=*/2, vdev), diff --git a/src/relax/op/vision/nms.h b/src/relax/op/vision/nms.h index 3fbd2609e289..83ca5b1bc083 100644 --- a/src/relax/op/vision/nms.h +++ b/src/relax/op/vision/nms.h @@ -44,8 +44,8 @@ Expr get_valid_counts(Expr data, double score_threshold, int id_index, int score /*! \brief Non-maximum suppression for object detection. */ Expr non_max_suppression(Expr data, Expr valid_count, Expr indices, int max_output_size, double iou_threshold, bool force_suppress, int top_k, int coord_start, - int score_index, int id_index, bool return_indices, - bool invalid_to_bottom); + int score_index, int id_index, bool return_indices, bool invalid_to_bottom, + double soft_nms_sigma = 0.0, double score_threshold = 0.0); } // namespace relax } // namespace tvm diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 92080634e2cc..1993a788793f 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -1353,7 +1353,9 @@ def _verify_nms_v5(mod, tf_func, boxes_np, scores_np): ) -def _build_nms_v5_mod(num_boxes, max_output_size, iou_threshold, score_threshold): +def _build_nms_v5_mod( + num_boxes, max_output_size, iou_threshold, score_threshold, soft_nms_sigma=0.0 +): """Convert a NonMaxSuppressionV5 TFLite model to a Relax module. Scalar params must be Python literals (not tf.constant) so TFLite can @@ -1374,7 +1376,7 @@ def func(self, boxes, scores): max_output_size=max_output_size, iou_threshold=iou_threshold, score_threshold=score_threshold, - soft_nms_sigma=0.0, + soft_nms_sigma=soft_nms_sigma, pad_to_max_output_size=True, ) return indices, out_scores, valid @@ -1612,6 +1614,49 @@ def _make_valid_boxes(rng, n): ] +_NMS_V5_SOFT_CASES = [ + pytest.param( + 6, + 6, + 0.5, + 0.0, + 0.5, + np.array( + [ + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.0, 1.0, 1.0], + [0.0, 0.1, 1.0, 1.1], + [0.0, 0.0, 1.0, 0.9], + [0.5, 0.5, 1.5, 1.5], + [0.0, 0.0, 0.3, 0.3], + ], + dtype=np.float32, + ), + np.array([0.9, 0.75, 0.6, 0.5, 0.4, 0.3], dtype=np.float32), + id="soft_nms_basic", + ), + pytest.param( + 5, + 5, + 0.5, + 0.0, + 0.3, + np.array( + [ + [0.0, 0.0, 1.0, 1.0], + [0.1, 0.1, 1.1, 1.1], + [0.2, 0.2, 1.2, 1.2], + [0.3, 0.3, 1.3, 1.3], + [2.0, 2.0, 3.0, 3.0], + ], + dtype=np.float32, + ), + np.array([0.9, 0.8, 0.7, 0.6, 0.5], dtype=np.float32), + id="soft_nms_tight_sigma", + ), +] + + @pytest.mark.parametrize( "num_boxes,max_output_size,iou_threshold,score_threshold,boxes,scores", _NMS_V5_CASES, @@ -1622,6 +1667,20 @@ def test_nms_v5(num_boxes, max_output_size, iou_threshold, score_threshold, boxe _verify_nms_v5(mod, tf_func, boxes, scores) +@pytest.mark.parametrize( + "num_boxes,max_output_size,iou_threshold,score_threshold,soft_nms_sigma,boxes,scores", + _NMS_V5_SOFT_CASES, +) +def test_nms_v5_soft( + num_boxes, max_output_size, iou_threshold, score_threshold, soft_nms_sigma, boxes, scores +): + """NON_MAX_SUPPRESSION_V5 with soft_nms_sigma: conversion smoke test + E2E correctness.""" + mod, tf_func = _build_nms_v5_mod( + num_boxes, max_output_size, iou_threshold, score_threshold, soft_nms_sigma + ) + _verify_nms_v5(mod, tf_func, boxes, scores) + + def test_nms_v5_ir(): """Verify the emitted Relax IR has correct structure for NON_MAX_SUPPRESSION_V5.""" num_boxes = 6 @@ -1646,6 +1705,26 @@ def test_nms_v5_ir(): assert f"R.Tensor(({max_output_size},)" in ir +def test_nms_v5_soft_ir(): + """Verify the emitted Relax IR passes soft_nms_sigma for NON_MAX_SUPPRESSION_V5.""" + num_boxes = 6 + max_output_size = 3 + mod, _ = _build_nms_v5_mod( + num_boxes=num_boxes, + max_output_size=max_output_size, + iou_threshold=0.5, + score_threshold=0.0, + soft_nms_sigma=0.5, + ) + + ir = mod.script() + + # soft_nms_sigma must appear in the IR + assert "soft_nms_sigma=0.5" in ir + # score_threshold must also be forwarded + assert "score_threshold=0.0" in ir + + _DETECTION_POSTPROCESS_SMOKE_CASES = [ pytest.param( { diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index b597b325f4fe..6e8a0a364621 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -302,6 +302,26 @@ def test_nms_infer_struct_info_return_indices(): ) +def test_nms_infer_struct_info_return_indices_soft_nms(): + bb = relax.BlockBuilder() + data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 10), "int32")) + _check_inference( + bb, + relax.op.vision.non_max_suppression( + data, valid_count, indices, return_indices=True, soft_nms_sigma=0.5 + ), + relax.TupleStructInfo( + [ + relax.TensorStructInfo((2, 10, 6), "float32"), + relax.TensorStructInfo((2, 10), "int32"), + relax.TensorStructInfo((2, 1), "int32"), + ] + ), + ) + + def test_nms_infer_struct_info_return_data(): bb = relax.BlockBuilder() data = relax.Var("data", R.Tensor((2, 10, 6), "float32")) @@ -457,6 +477,8 @@ def main( id_index=0, return_indices=True, invalid_to_bottom=False, + soft_nms_sigma=0.0, + score_threshold=0.0, ) return gv @@ -473,6 +495,51 @@ def main( ) +def test_nms_legalize_soft_nms(): + @tvm.script.ir_module + class NMS: + @R.function + def main( + data: R.Tensor((1, 5, 6), "float32"), + valid_count: R.Tensor((1,), "int32"), + indices: R.Tensor((1, 5), "int32"), + ) -> R.Tuple( + R.Tensor((1, 5, 6), "float32"), + R.Tensor((1, 5), "int32"), + R.Tensor((1, 1), "int32"), + ): + gv = R.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=-1, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + soft_nms_sigma=0.5, + score_threshold=0.0, + ) + return gv + + mod = LegalizeOps()(NMS) + _assert_relax_op_legalized(mod, "relax.vision.non_max_suppression") + tvm.ir.assert_structural_equal( + mod["main"].ret_struct_info, + relax.TupleStructInfo( + [ + relax.TensorStructInfo((1, 5, 6), "float32"), + relax.TensorStructInfo((1, 5), "int32"), + relax.TensorStructInfo((1, 1), "int32"), + ] + ), + ) + + def test_nms_legalize_return_data(): @tvm.script.ir_module class NMS: @@ -495,6 +562,8 @@ def main( id_index=0, return_indices=False, invalid_to_bottom=True, + soft_nms_sigma=0.0, + score_threshold=0.0, ) return gv diff --git a/tests/python/relax/test_tvmscript_parser_op_vision.py b/tests/python/relax/test_tvmscript_parser_op_vision.py index 370b68769e6e..d4755ee367f7 100644 --- a/tests/python/relax/test_tvmscript_parser_op_vision.py +++ b/tests/python/relax/test_tvmscript_parser_op_vision.py @@ -126,6 +126,8 @@ def foo( id_index=0, return_indices=True, invalid_to_bottom=False, + soft_nms_sigma=0.0, + score_threshold=0.0, ) ) return gv @@ -150,6 +152,70 @@ def foo( id_index=0, return_indices=True, invalid_to_bottom=False, + soft_nms_sigma=0.0, + score_threshold=0.0, + ) + ) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_non_max_suppression_return_indices_soft_nms(): + @R.function + def foo( + data: R.Tensor((2, 5, 6), "float32"), + valid_count: R.Tensor((2,), "int32"), + indices: R.Tensor((2, 5), "int32"), + ) -> R.Tuple( + R.Tensor((2, 5, 6), "float32"), + R.Tensor((2, 5), "int32"), + R.Tensor((2, 1), "int32"), + ): + gv: R.Tuple( + R.Tensor((2, 5, 6), "float32"), + R.Tensor((2, 5), "int32"), + R.Tensor((2, 1), "int32"), + ) = R.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=3, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + soft_nms_sigma=0.5, + score_threshold=0.0, + ) + return gv + + data = relax.Var("data", R.Tensor((2, 5, 6), "float32")) + valid_count = relax.Var("valid_count", R.Tensor((2,), "int32")) + indices = relax.Var("indices", R.Tensor((2, 5), "int32")) + + bb = relax.BlockBuilder() + with bb.function("foo", [data, valid_count, indices]): + gv = bb.emit( + relax.op.vision.non_max_suppression( + data, + valid_count, + indices, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=False, + top_k=3, + coord_start=2, + score_index=1, + id_index=0, + return_indices=True, + invalid_to_bottom=False, + soft_nms_sigma=0.5, + score_threshold=0.0, ) ) bb.emit_func_output(gv) @@ -177,6 +243,8 @@ def foo( id_index=0, return_indices=False, invalid_to_bottom=True, + soft_nms_sigma=0.0, + score_threshold=0.0, ) return gv @@ -200,6 +268,8 @@ def foo( id_index=0, return_indices=False, invalid_to_bottom=True, + soft_nms_sigma=0.0, + score_threshold=0.0, ) ) bb.emit_func_output(gv) From ab5a89ab44f7a03282ec7ded759afb28d4844dea Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Mon, 20 Apr 2026 10:23:07 +0800 Subject: [PATCH 2/4] [Relax][TOPI] Fix soft-NMS output alignment in non_max_suppression This commit fixes a critical bug in the soft-NMS path where decayed boxes whose scores dropped below the threshold were not properly invalidated, causing misalignment between selected_indices and selected_scores. Fixes in TOPI _classic_nms_ir: 1. After Gaussian decay, invalidate boxes with score <= threshold by setting their out_box_indices to -1. 2. During the compaction phase for return_indices=True, also compact out_data so that the first num_valid entries match the selected boxes in order. Fill remaining entries with -1. Fixes in the NumPy reference implementation (nms_python.py): - Synchronize the same post-decay invalidation logic to keep the reference behavior consistent with the TIR kernel and TensorFlow. Without this fix, the TFLite frontend would extract wrong scores from processed_data[:max_output_size] when a middle box was decayed below threshold but later boxes were still valid. --- python/tvm/topi/testing/nms_python.py | 6 ++++++ python/tvm/topi/vision/nms.py | 13 +++++++++++++ 2 files changed, 19 insertions(+) diff --git a/python/tvm/topi/testing/nms_python.py b/python/tvm/topi/testing/nms_python.py index 757da68906ae..83c7ecff16de 100644 --- a/python/tvm/topi/testing/nms_python.py +++ b/python/tvm/topi/testing/nms_python.py @@ -129,6 +129,12 @@ def non_max_suppression_python( out_data[i, k, score_index] = -1.0 out_box_indices[i, k] = -1 + if is_soft_nms: + # Invalidate boxes whose score dropped below threshold after decay. + for j in range(nkeep): + if out_data[i, j, score_index] <= thresh: + out_box_indices[i, j] = -1 + if return_indices: # Compact valid indices to top and remap to original cnt = 0 diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 5773a739c99e..5cbfb6d458a6 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -351,6 +351,13 @@ def _classic_nms_ir( out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) out_box_indices[i, j] = T.int32(-1) + # Step 2b: For soft-NMS, invalidate boxes whose score dropped below threshold. + if is_soft_nms: + with T.serial(0, nkeep_local[0]) as j: + with T.If(out_data[i, j, score_index] <= thresh): + with T.Then(): + out_box_indices[i, j] = T.int32(-1) + # Step 3: If return_indices, remap to original indices if return_indices: if out_valid_box_count is not None: @@ -362,6 +369,9 @@ def _classic_nms_ir( with T.serial(0, num_anchors) as j: with T.If(out_box_indices[i, j] >= 0): with T.Then(): + if is_soft_nms: + with T.serial(0, box_data_length) as k: + out_data[i, valid_idx[0], k] = out_data[i, j, k] orig_idx = out_box_indices[i, j] out_box_indices[i, valid_idx[0]] = indices[i, orig_idx] valid_idx[0] = valid_idx[0] + 1 @@ -373,6 +383,9 @@ def _classic_nms_ir( with T.If(j >= valid_idx[0]): with T.Then(): out_box_indices[i, j] = T.int32(-1) + if is_soft_nms: + with T.serial(0, box_data_length) as k: + out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) return ib.get() From b2c557fc8f7919bc49137b2f4a4a6b2b2a8ab5c1 Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Mon, 20 Apr 2026 23:40:18 +0800 Subject: [PATCH 3/4] [Relax][TOPI][TFLite] Align soft-NMS with LiteRT reference Align the soft-NMS path in non_max_suppression with LiteRT's reference behavior. - Rewrite soft-NMS candidate selection in TOPI TIR and the NumPy reference to re-pick the current best score after each decay step - Match LiteRT's Gaussian decay formula and keep decayed scores/data aligned with the returned indices for TFLite NMSv5 - Update conditional soft-NMS return documentation and add Relax/TFLite tests for reordered outputs and forwarded soft-NMS attributes Validation: - python -m pytest -n 1 tests/python/relax/test_op_vision.py -k "all_class_non_max_suppression or get_valid_counts or nms" -v Result: - 32 passed --- python/tvm/topi/testing/nms_python.py | 81 +++- python/tvm/topi/vision/nms.py | 430 +++++++++++++++------ tests/python/relax/test_frontend_tflite.py | 34 ++ tests/python/relax/test_op_vision.py | 55 +++ 4 files changed, 464 insertions(+), 136 deletions(-) diff --git a/python/tvm/topi/testing/nms_python.py b/python/tvm/topi/testing/nms_python.py index 83c7ecff16de..493270f9d4ab 100644 --- a/python/tvm/topi/testing/nms_python.py +++ b/python/tvm/topi/testing/nms_python.py @@ -64,7 +64,9 @@ def non_max_suppression_python( Returns ------- - If return_indices is True: (box_indices, valid_box_count) + If return_indices is True and soft_nms_sigma == 0.0: (box_indices, valid_box_count) + If return_indices is True and soft_nms_sigma > 0.0: + (out_data, box_indices, valid_box_count) Otherwise: modified data tensor """ batch_size, num_anchors, _ = data.shape @@ -75,6 +77,7 @@ def non_max_suppression_python( is_soft_nms = soft_nms_sigma > 0.0 thresh = score_threshold if is_soft_nms else 0.0 + soft_nms_scale = -0.5 / soft_nms_sigma if is_soft_nms else 0.0 for i in range(batch_size): nkeep = int(valid_count[i]) @@ -91,6 +94,68 @@ def non_max_suppression_python( out_data[i, j, :] = data[i, src, :] out_box_indices[i, j] = src + if is_soft_nms: + num_selected = 0 + while num_selected < nkeep and (max_output_size < 0 or num_selected < max_output_size): + best_idx = -1 + best_score = thresh + for j in range(num_selected, nkeep): + if out_box_indices[i, j] >= 0 and out_data[i, j, score_index] > best_score: + best_idx = j + best_score = out_data[i, j, score_index] + + if best_idx < 0: + break + + if best_idx != num_selected: + out_data[i, [num_selected, best_idx], :] = out_data[ + i, [best_idx, num_selected], : + ] + out_box_indices[i, [num_selected, best_idx]] = out_box_indices[ + i, [best_idx, num_selected] + ] + + selected_idx = num_selected + for j in range(selected_idx + 1, nkeep): + if out_box_indices[i, j] < 0 or out_data[i, j, score_index] <= thresh: + continue + + do_suppress = False + if force_suppress: + do_suppress = True + elif id_index >= 0: + do_suppress = ( + out_data[i, selected_idx, id_index] == out_data[i, j, id_index] + ) + else: + do_suppress = True + + if not do_suppress: + continue + + iou = _iou(out_data[i, selected_idx], out_data[i, j], coord_start) + if iou >= iou_threshold: + out_box_indices[i, j] = -1 + else: + out_data[i, j, score_index] *= np.exp(soft_nms_scale * (iou**2)) + if out_data[i, j, score_index] <= thresh: + out_box_indices[i, j] = -1 + + num_selected += 1 + + valid_box_count[i, 0] = num_selected + if return_indices: + for j in range(num_selected): + orig_idx = out_box_indices[i, j] + compacted[i, j] = int(indices[i, orig_idx]) + out_box_indices[i, j] = compacted[i, j] + for j in range(num_selected, num_anchors): + out_data[i, j, :] = -1.0 + out_box_indices[i, j] = -1 + else: + out_data[i, num_selected:, :] = -1.0 + continue + # Greedy NMS num_valid = 0 for j in range(nkeep): @@ -122,18 +187,8 @@ def non_max_suppression_python( if do_suppress: iou = _iou(out_data[i, j], out_data[i, k], coord_start) if iou >= iou_threshold: - if is_soft_nms: - decay = np.exp(-(iou ** 2) / soft_nms_sigma) - out_data[i, k, score_index] *= decay - else: - out_data[i, k, score_index] = -1.0 - out_box_indices[i, k] = -1 - - if is_soft_nms: - # Invalidate boxes whose score dropped below threshold after decay. - for j in range(nkeep): - if out_data[i, j, score_index] <= thresh: - out_box_indices[i, j] = -1 + out_data[i, k, score_index] = -1.0 + out_box_indices[i, k] = -1 if return_indices: # Compact valid indices to top and remap to original diff --git a/python/tvm/topi/vision/nms.py b/python/tvm/topi/vision/nms.py index 5cbfb6d458a6..91fbf057411b 100644 --- a/python/tvm/topi/vision/nms.py +++ b/python/tvm/topi/vision/nms.py @@ -228,140 +228,327 @@ def _classic_nms_ir( out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) out_box_indices[i, j] = T.int32(-1) - # Step 2: Apply NMS - greedy suppression - num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local") - num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf) - num_valid_boxes[0] = T.int32(0) - - with T.serial(0, nkeep_local[0]) as j: - # Check if box j is still valid (score > threshold) and within max_output_size - with T.If( - tvm.tirx.all( - out_data[i, j, score_index] > thresh, + if is_soft_nms: + # LiteRT soft-NMS selects the current highest-score candidate each round. + soft_nms_scale = tvm.tirx.Cast(data.dtype, T.float32(-0.5 / soft_nms_sigma)) + num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local") + num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf) + num_valid_boxes[0] = T.int32(0) + + with T.serial(0, nkeep_local[0]) as _: + with T.If( tvm.tirx.Select( max_output_size > 0, num_valid_boxes[0] < max_output_size, tvm.tirx.const(True), - ), - ) - ): - with T.Then(): - num_valid_boxes[0] = num_valid_boxes[0] + 1 - - # Suppress overlapping boxes - with T.serial(0, nkeep_local[0]) as k: - with T.If( - tvm.tirx.all( - k > j, - out_data[i, k, score_index] > thresh, - ) - ): + ) + ): + with T.Then(): + best_idx_buf = T.alloc_buffer((1,), "int32", scope="local") + best_idx = T.buffer_proxy(best_idx_buf) + best_idx[0] = T.int32(-1) + best_score_buf = T.alloc_buffer((1,), data.dtype, scope="local") + best_score = T.buffer_proxy(best_score_buf) + best_score[0] = thresh + + with T.serial(0, nkeep_local[0]) as j: + with T.If( + tvm.tirx.all( + j >= num_valid_boxes[0], + out_box_indices[i, j] >= 0, + out_data[i, j, score_index] > best_score[0], + ) + ): + with T.Then(): + best_idx[0] = j + best_score[0] = out_data[i, j, score_index] + + with T.If(best_idx[0] >= 0): with T.Then(): - # Check class ID match (or force_suppress) - do_suppress = tvm.tirx.const(False) - if force_suppress: - do_suppress = tvm.tirx.const(True) - elif id_index >= 0: - do_suppress = ( - out_data[i, j, id_index] == out_data[i, k, id_index] - ) - else: - do_suppress = tvm.tirx.const(True) - - with T.If(do_suppress): + with T.If(best_idx[0] != num_valid_boxes[0]): with T.Then(): - # Calculate IoU - a_l = tvm.te.min( - out_data[i, j, coord_start], - out_data[i, j, coord_start + 2], - ) - a_t = tvm.te.min( - out_data[i, j, coord_start + 1], - out_data[i, j, coord_start + 3], - ) - a_r = tvm.te.max( - out_data[i, j, coord_start], - out_data[i, j, coord_start + 2], - ) - a_b = tvm.te.max( - out_data[i, j, coord_start + 1], - out_data[i, j, coord_start + 3], - ) - - b_l = tvm.te.min( - out_data[i, k, coord_start], - out_data[i, k, coord_start + 2], + tmp_idx_buf = T.alloc_buffer( + (1,), "int32", scope="local" ) - b_t = tvm.te.min( - out_data[i, k, coord_start + 1], - out_data[i, k, coord_start + 3], - ) - b_r = tvm.te.max( - out_data[i, k, coord_start], - out_data[i, k, coord_start + 2], - ) - b_b = tvm.te.max( - out_data[i, k, coord_start + 1], - out_data[i, k, coord_start + 3], + tmp_idx = T.buffer_proxy(tmp_idx_buf) + tmp_idx[0] = out_box_indices[i, num_valid_boxes[0]] + out_box_indices[ + i, num_valid_boxes[0] + ] = out_box_indices[i, best_idx[0]] + out_box_indices[i, best_idx[0]] = tmp_idx[0] + + with T.serial(0, box_data_length) as k: + tmp_val_buf = T.alloc_buffer( + (1,), data.dtype, scope="local" + ) + tmp_val = T.buffer_proxy(tmp_val_buf) + tmp_val[0] = out_data[i, num_valid_boxes[0], k] + out_data[i, num_valid_boxes[0], k] = out_data[ + i, best_idx[0], k + ] + out_data[i, best_idx[0], k] = tmp_val[0] + + with T.serial(0, nkeep_local[0]) as j: + with T.If( + tvm.tirx.all( + j > num_valid_boxes[0], + out_box_indices[i, j] >= 0, + out_data[i, j, score_index] > thresh, ) + ): + with T.Then(): + do_suppress = tvm.tirx.const(False) + if force_suppress: + do_suppress = tvm.tirx.const(True) + elif id_index >= 0: + do_suppress = ( + out_data[i, num_valid_boxes[0], id_index] + == out_data[i, j, id_index] + ) + else: + do_suppress = tvm.tirx.const(True) + + with T.If(do_suppress): + with T.Then(): + a_l = tvm.te.min( + out_data[ + i, num_valid_boxes[0], coord_start + ], + out_data[ + i, + num_valid_boxes[0], + coord_start + 2, + ], + ) + a_t = tvm.te.min( + out_data[ + i, + num_valid_boxes[0], + coord_start + 1, + ], + out_data[ + i, + num_valid_boxes[0], + coord_start + 3, + ], + ) + a_r = tvm.te.max( + out_data[ + i, num_valid_boxes[0], coord_start + ], + out_data[ + i, + num_valid_boxes[0], + coord_start + 2, + ], + ) + a_b = tvm.te.max( + out_data[ + i, + num_valid_boxes[0], + coord_start + 1, + ], + out_data[ + i, + num_valid_boxes[0], + coord_start + 3, + ], + ) - w = tvm.te.max( - tvm.tirx.Cast(data.dtype, T.float32(0.0)), - tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l), - ) - h = tvm.te.max( - tvm.tirx.Cast(data.dtype, T.float32(0.0)), - tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t), - ) - area = h * w - u = ( - (a_r - a_l) * (a_b - a_t) - + (b_r - b_l) * (b_b - b_t) - - area - ) - iou = tvm.tirx.Select( - u <= tvm.tirx.Cast(data.dtype, T.float32(0.0)), - tvm.tirx.Cast(data.dtype, T.float32(0.0)), - area / u, - ) + b_l = tvm.te.min( + out_data[i, j, coord_start], + out_data[i, j, coord_start + 2], + ) + b_t = tvm.te.min( + out_data[i, j, coord_start + 1], + out_data[i, j, coord_start + 3], + ) + b_r = tvm.te.max( + out_data[i, j, coord_start], + out_data[i, j, coord_start + 2], + ) + b_b = tvm.te.max( + out_data[i, j, coord_start + 1], + out_data[i, j, coord_start + 3], + ) - with T.If(iou >= iou_threshold): - with T.Then(): - if is_soft_nms: - # Soft-NMS Gaussian decay - decay = tvm.tirx.exp( - -(iou * iou) - / tvm.tirx.Cast( - data.dtype, - T.float32(soft_nms_sigma), - ) + zero = tvm.tirx.Cast( + data.dtype, T.float32(0.0) ) - out_data[i, k, score_index] = ( - out_data[i, k, score_index] * decay + w = tvm.te.max( + zero, + tvm.te.min(a_r, b_r) + - tvm.te.max(a_l, b_l), ) - else: - out_data[i, k, score_index] = tvm.tirx.Cast( - data.dtype, T.float32(-1.0) + h = tvm.te.max( + zero, + tvm.te.min(a_b, b_b) + - tvm.te.max(a_t, b_t), + ) + area = h * w + u = ( + (a_r - a_l) * (a_b - a_t) + + (b_r - b_l) * (b_b - b_t) + - area + ) + iou = tvm.tirx.Select( + u + <= tvm.tirx.Cast( + data.dtype, T.float32(0.0) + ), + zero, + area / u, ) - out_box_indices[i, k] = T.int32(-1) - with T.Else(): - # Box suppressed or beyond max_output_size - with T.serial(0, box_data_length) as k: - out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) - out_box_indices[i, j] = T.int32(-1) + with T.If(iou >= iou_threshold): + with T.Then(): + out_box_indices[i, j] = T.int32(-1) + with T.If(iou < iou_threshold): + with T.Then(): + out_data[i, j, score_index] = ( + out_data[i, j, score_index] + * tvm.tirx.exp( + soft_nms_scale + * iou + * iou + ) + ) + with T.If( + out_data[i, j, score_index] + <= thresh + ): + with T.Then(): + out_box_indices[ + i, j + ] = T.int32(-1) + + num_valid_boxes[0] = num_valid_boxes[0] + 1 + + if return_indices: + out_valid_box_count[i, 0] = num_valid_boxes[0] + + with T.serial(0, num_anchors) as j: + with T.If(j < num_valid_boxes[0]): + with T.Then(): + orig_idx = out_box_indices[i, j] + out_box_indices[i, j] = indices[i, orig_idx] + with T.If(j >= num_valid_boxes[0]): + with T.Then(): + with T.serial(0, box_data_length) as k: + out_data[i, j, k] = tvm.tirx.Cast( + data.dtype, T.float32(-1.0) + ) + out_box_indices[i, j] = T.int32(-1) + else: + with T.serial(0, num_anchors) as j: + with T.If(j >= num_valid_boxes[0]): + with T.Then(): + with T.serial(0, box_data_length) as k: + out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) + else: + # Step 2: Apply hard NMS - greedy suppression + num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local") + num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf) + num_valid_boxes[0] = T.int32(0) - # Step 2b: For soft-NMS, invalidate boxes whose score dropped below threshold. - if is_soft_nms: with T.serial(0, nkeep_local[0]) as j: - with T.If(out_data[i, j, score_index] <= thresh): + with T.If( + tvm.tirx.all( + out_data[i, j, score_index] > thresh, + tvm.tirx.Select( + max_output_size > 0, + num_valid_boxes[0] < max_output_size, + tvm.tirx.const(True), + ), + ) + ): with T.Then(): + num_valid_boxes[0] = num_valid_boxes[0] + 1 + + with T.serial(0, nkeep_local[0]) as k: + with T.If( + tvm.tirx.all(k > j, out_data[i, k, score_index] > thresh) + ): + with T.Then(): + do_suppress = tvm.tirx.const(False) + if force_suppress: + do_suppress = tvm.tirx.const(True) + elif id_index >= 0: + do_suppress = ( + out_data[i, j, id_index] == out_data[i, k, id_index] + ) + else: + do_suppress = tvm.tirx.const(True) + + with T.If(do_suppress): + with T.Then(): + a_l = tvm.te.min( + out_data[i, j, coord_start], + out_data[i, j, coord_start + 2], + ) + a_t = tvm.te.min( + out_data[i, j, coord_start + 1], + out_data[i, j, coord_start + 3], + ) + a_r = tvm.te.max( + out_data[i, j, coord_start], + out_data[i, j, coord_start + 2], + ) + a_b = tvm.te.max( + out_data[i, j, coord_start + 1], + out_data[i, j, coord_start + 3], + ) + + b_l = tvm.te.min( + out_data[i, k, coord_start], + out_data[i, k, coord_start + 2], + ) + b_t = tvm.te.min( + out_data[i, k, coord_start + 1], + out_data[i, k, coord_start + 3], + ) + b_r = tvm.te.max( + out_data[i, k, coord_start], + out_data[i, k, coord_start + 2], + ) + b_b = tvm.te.max( + out_data[i, k, coord_start + 1], + out_data[i, k, coord_start + 3], + ) + + w = tvm.te.max( + tvm.tirx.Cast(data.dtype, T.float32(0.0)), + tvm.te.min(a_r, b_r) - tvm.te.max(a_l, b_l), + ) + h = tvm.te.max( + tvm.tirx.Cast(data.dtype, T.float32(0.0)), + tvm.te.min(a_b, b_b) - tvm.te.max(a_t, b_t), + ) + area = h * w + u = ( + (a_r - a_l) * (a_b - a_t) + + (b_r - b_l) * (b_b - b_t) + - area + ) + iou = tvm.tirx.Select( + u <= tvm.tirx.Cast(data.dtype, T.float32(0.0)), + tvm.tirx.Cast(data.dtype, T.float32(0.0)), + area / u, + ) + + with T.If(iou >= iou_threshold): + with T.Then(): + out_data[i, k, score_index] = tvm.tirx.Cast( + data.dtype, T.float32(-1.0) + ) + out_box_indices[i, k] = T.int32(-1) + + with T.Else(): + with T.serial(0, box_data_length) as k: + out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) out_box_indices[i, j] = T.int32(-1) - # Step 3: If return_indices, remap to original indices - if return_indices: - if out_valid_box_count is not None: - # Count valid boxes and remap indices + if return_indices: valid_idx_buf = T.alloc_buffer((1,), "int32", scope="local") valid_idx = T.buffer_proxy(valid_idx_buf) valid_idx[0] = T.int32(0) @@ -369,23 +556,16 @@ def _classic_nms_ir( with T.serial(0, num_anchors) as j: with T.If(out_box_indices[i, j] >= 0): with T.Then(): - if is_soft_nms: - with T.serial(0, box_data_length) as k: - out_data[i, valid_idx[0], k] = out_data[i, j, k] orig_idx = out_box_indices[i, j] out_box_indices[i, valid_idx[0]] = indices[i, orig_idx] valid_idx[0] = valid_idx[0] + 1 out_valid_box_count[i, 0] = valid_idx[0] - # Fill remaining with -1 with T.serial(0, num_anchors) as j: with T.If(j >= valid_idx[0]): with T.Then(): out_box_indices[i, j] = T.int32(-1) - if is_soft_nms: - with T.serial(0, box_data_length) as k: - out_data[i, j, k] = tvm.tirx.Cast(data.dtype, T.float32(-1.0)) return ib.get() @@ -458,7 +638,11 @@ def non_max_suppression( Returns ------- out : tvm.te.Tensor or tuple of tvm.te.Tensor - If return_indices is True, returns a tuple of (box_indices, valid_box_count). + If ``return_indices`` is ``True`` and ``soft_nms_sigma`` is ``0.0``, + returns ``(box_indices, valid_box_count)``. + If ``return_indices`` is ``True`` and ``soft_nms_sigma > 0``, + returns ``(out_data, box_indices, valid_box_count)`` where + ``out_data`` has the same shape as the input data. Otherwise returns the modified data tensor. """ batch_size = data.shape[0] diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py index 1993a788793f..78e45d965315 100644 --- a/tests/python/relax/test_frontend_tflite.py +++ b/tests/python/relax/test_frontend_tflite.py @@ -1654,6 +1654,40 @@ def _make_valid_boxes(rng, n): np.array([0.9, 0.8, 0.7, 0.6, 0.5], dtype=np.float32), id="soft_nms_tight_sigma", ), + pytest.param( + 3, + 3, + 0.5, + 0.3, + 0.1, + np.array( + [ + [0.0, 0.0, 1.0, 1.0], + [0.2, 0.2, 1.2, 1.2], + [2.0, 2.0, 3.0, 3.0], + ], + dtype=np.float32, + ), + np.array([0.9, 0.8, 0.75], dtype=np.float32), + id="soft_nms_threshold_hole", + ), + pytest.param( + 3, + 3, + 0.5, + 0.0, + 0.1, + np.array( + [ + [0.0, 0.0, 1.0, 1.0], + [0.2, 0.2, 1.2, 1.2], + [2.0, 2.0, 3.0, 3.0], + ], + dtype=np.float32, + ), + np.array([0.9, 0.85, 0.8], dtype=np.float32), + id="soft_nms_reorder", + ), ] diff --git a/tests/python/relax/test_op_vision.py b/tests/python/relax/test_op_vision.py index 6e8a0a364621..ef260cf18858 100644 --- a/tests/python/relax/test_op_vision.py +++ b/tests/python/relax/test_op_vision.py @@ -646,6 +646,8 @@ def _run_nms_e2e( id_index: int = 0, return_indices: bool = True, invalid_to_bottom: bool = False, + soft_nms_sigma: float = 0.0, + score_threshold: float = 0.0, ): """Run classic NMS through legalization and VM execution.""" @@ -672,6 +674,8 @@ def _run_nms_e2e( id_index=id_index, return_indices=return_indices, invalid_to_bottom=invalid_to_bottom, + soft_nms_sigma=soft_nms_sigma, + score_threshold=score_threshold, ) ) bb.emit_func_output(result) @@ -729,6 +733,57 @@ def test_nms_e2e_return_indices(): tvm.testing.assert_allclose(result[1].numpy(), ref_valid_box_count) +@tvm.testing.requires_llvm +def test_nms_e2e_soft_nms_reorders_by_decayed_score(): + """Soft-NMS should re-rank by decayed scores instead of keeping the initial order.""" + + raw_data = np.array( + [ + [ + [0.0, 0.90, 0.0, 0.0, 1.0, 1.0], + [0.0, 0.85, 0.2, 0.2, 1.2, 1.2], + [0.0, 0.80, 2.0, 2.0, 3.0, 3.0], + [-1.0, 0.99, 0.0, 0.0, 1.0, 1.0], + ] + ], + dtype="float32", + ) + valid_count_np, filtered_data_np, filtered_indices_np = _prepare_nms_inputs(raw_data) + ref_out_data, ref_indices, ref_valid_box_count = tvm.topi.testing.non_max_suppression_python( + filtered_data_np, + valid_count_np, + filtered_indices_np, + max_output_size=-1, + iou_threshold=0.5, + force_suppress=True, + top_k=-1, + coord_start=2, + score_index=1, + id_index=-1, + return_indices=True, + invalid_to_bottom=False, + soft_nms_sigma=0.1, + score_threshold=0.0, + ) + result = _run_nms_e2e( + filtered_data_np, + valid_count_np, + filtered_indices_np, + iou_threshold=0.5, + force_suppress=True, + id_index=-1, + return_indices=True, + invalid_to_bottom=False, + soft_nms_sigma=0.1, + score_threshold=0.0, + ) + + np.testing.assert_array_equal(ref_indices[0, :3], np.array([0, 2, 1], dtype="int32")) + tvm.testing.assert_allclose(result[0].numpy(), ref_out_data) + tvm.testing.assert_allclose(result[1].numpy(), ref_indices) + tvm.testing.assert_allclose(result[2].numpy(), ref_valid_box_count) + + @tvm.testing.requires_llvm def test_nms_e2e_return_indices_with_invalid_to_bottom(): """Validate that invalid_to_bottom is a no-op when returning indices.""" From c4c9e2653f283bea957dfa8a6d418a23d3fc0b95 Mon Sep 17 00:00:00 2001 From: Aharrypotter <62729549+Aharrypotter@users.noreply.github.com> Date: Tue, 21 Apr 2026 00:56:57 +0800 Subject: [PATCH 4/4] [Relax][Frontend][TFLite] Refactor NMSV5 return path and clean up ref impl - Deduplicate squeeze/strided_slice/reshape for selected_indices and num_valid across soft-NMS and hard-NMS branches - Remove redundant `if not is_soft_nms` check in the reference implementation (Greedy NMS section is only reached when False) --- .../relax/frontend/tflite/tflite_frontend.py | 25 +++++++++---------- python/tvm/topi/testing/nms_python.py | 5 ++-- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py index 221609a1c628..9df4e5cf51f2 100644 --- a/python/tvm/relax/frontend/tflite/tflite_frontend.py +++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py @@ -3603,14 +3603,20 @@ def convert_nms_v5(self, op): ) if soft_nms_sigma > 0.0: - # Soft-NMS returns (out_data, box_indices, valid_box_count) processed_data = relax.op.squeeze(nms_ret[0], axis=[0]) - selected_indices = relax.op.squeeze(nms_ret[1], axis=[0]) - selected_indices = relax.op.strided_slice( - selected_indices, axes=[0], begin=[0], end=[max_output_size] - ) - num_valid = relax.op.reshape(nms_ret[2], []) + indices_from_nms = nms_ret[1] + num_valid_from_nms = nms_ret[2] + else: + indices_from_nms = nms_ret[0] + num_valid_from_nms = nms_ret[1] + + selected_indices = relax.op.squeeze(indices_from_nms, axis=[0]) + selected_indices = relax.op.strided_slice( + selected_indices, axes=[0], begin=[0], end=[max_output_size] + ) + num_valid = relax.op.reshape(num_valid_from_nms, []) + if soft_nms_sigma > 0.0: # Extract decayed scores from the processed data (score_index=0) selected_scores = relax.op.strided_slice( processed_data, axes=[1], begin=[0], end=[1] @@ -3620,13 +3626,6 @@ def convert_nms_v5(self, op): selected_scores, axes=[0], begin=[0], end=[max_output_size] ) else: - # Hard NMS returns (box_indices, valid_box_count) - selected_indices = relax.op.squeeze(nms_ret[0], axis=[0]) - selected_indices = relax.op.strided_slice( - selected_indices, axes=[0], begin=[0], end=[max_output_size] - ) - num_valid = relax.op.reshape(nms_ret[1], []) - # Clamp out-of-bound padded indices to prevent take() crash. num_boxes = int(self.get_tensor_shape(input_tensors[0])[0]) safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes - 1) diff --git a/python/tvm/topi/testing/nms_python.py b/python/tvm/topi/testing/nms_python.py index 493270f9d4ab..c8711c70dde2 100644 --- a/python/tvm/topi/testing/nms_python.py +++ b/python/tvm/topi/testing/nms_python.py @@ -160,9 +160,8 @@ def non_max_suppression_python( num_valid = 0 for j in range(nkeep): if out_data[i, j, score_index] <= thresh: - if not is_soft_nms: - out_data[i, j, :] = -1.0 - out_box_indices[i, j] = -1 + out_data[i, j, :] = -1.0 + out_box_indices[i, j] = -1 continue if 0 < max_output_size <= num_valid: out_data[i, j, :] = -1.0