Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion include/tvm/relax/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ struct NonMaximumSuppressionAttrs
int id_index;
bool return_indices;
bool invalid_to_bottom;
double soft_nms_sigma;
double score_threshold;

static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
Expand All @@ -143,7 +145,11 @@ struct NonMaximumSuppressionAttrs
.def_ro("return_indices", &NonMaximumSuppressionAttrs::return_indices,
"Whether to return box indices in input data.")
.def_ro("invalid_to_bottom", &NonMaximumSuppressionAttrs::invalid_to_bottom,
"Whether to move all valid bounding boxes to the top.");
"Whether to move all valid bounding boxes to the top.")
.def_ro("soft_nms_sigma", &NonMaximumSuppressionAttrs::soft_nms_sigma,
"Sigma for soft-NMS; 0.0 means standard hard NMS.")
.def_ro("score_threshold", &NonMaximumSuppressionAttrs::score_threshold,
"Score threshold for soft-NMS validity check; 0.0 when unused.");
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("relax.attrs.NonMaximumSuppressionAttrs",
NonMaximumSuppressionAttrs, BaseAttrsNode);
Expand Down
36 changes: 26 additions & 10 deletions python/tvm/relax/frontend/tflite/tflite_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3573,10 +3573,6 @@ def convert_nms_v5(self, op):
if isinstance(soft_nms_sigma, np.ndarray):
assert soft_nms_sigma.size == 1, "only one value is expected."
soft_nms_sigma = float(soft_nms_sigma)
if soft_nms_sigma != 0.0:
raise tvm.error.OpNotImplemented(
"It is soft_nms when soft_nms_sigma != 0, which is not supported!"
)

scores_expand = relax.op.expand_dims(scores, axis=-1)
data = relax.op.concat([scores_expand, boxes], axis=-1)
Expand All @@ -3602,18 +3598,38 @@ def convert_nms_v5(self, op):
id_index=-1,
return_indices=True,
invalid_to_bottom=False,
soft_nms_sigma=soft_nms_sigma,
score_threshold=score_threshold,
)

selected_indices = relax.op.squeeze(nms_ret[0], axis=[0])
if soft_nms_sigma > 0.0:
processed_data = relax.op.squeeze(nms_ret[0], axis=[0])
indices_from_nms = nms_ret[1]
num_valid_from_nms = nms_ret[2]
else:
indices_from_nms = nms_ret[0]
num_valid_from_nms = nms_ret[1]

selected_indices = relax.op.squeeze(indices_from_nms, axis=[0])
selected_indices = relax.op.strided_slice(
selected_indices, axes=[0], begin=[0], end=[max_output_size]
)
num_valid = relax.op.reshape(nms_ret[1], [])
num_valid = relax.op.reshape(num_valid_from_nms, [])

# Clamp out-of-bound padded indices to prevent take() crash.
num_boxes = int(self.get_tensor_shape(input_tensors[0])[0])
safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes - 1)
selected_scores = relax.op.take(scores, safe_indices, axis=0)
if soft_nms_sigma > 0.0:
# Extract decayed scores from the processed data (score_index=0)
selected_scores = relax.op.strided_slice(
processed_data, axes=[1], begin=[0], end=[1]
)
selected_scores = relax.op.squeeze(selected_scores, axis=[1])
selected_scores = relax.op.strided_slice(
selected_scores, axes=[0], begin=[0], end=[max_output_size]
)
else:
# Clamp out-of-bound padded indices to prevent take() crash.
num_boxes = int(self.get_tensor_shape(input_tensors[0])[0])
safe_indices = relax.op.clip(selected_indices, min=0, max=num_boxes - 1)
selected_scores = relax.op.take(scores, safe_indices, axis=0)
Comment thread
Aharrypotter marked this conversation as resolved.

out = relax.Tuple([selected_indices, selected_scores, num_valid])
return out
Expand Down
20 changes: 18 additions & 2 deletions python/tvm/relax/op/vision/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ def non_max_suppression(
id_index=0,
return_indices=True,
invalid_to_bottom=False,
soft_nms_sigma=0.0,
score_threshold=0.0,
):
"""Non-maximum suppression operator for object detection.

Expand Down Expand Up @@ -160,12 +162,24 @@ def non_max_suppression(
Whether to move valid bounding boxes to the top of the returned tensor.
This option only affects the ``return_indices=False`` path.

soft_nms_sigma : float, optional
Sigma for soft-NMS Gaussian penalty. When ``0.0`` (default), standard
hard NMS is used. Positive values decay overlapping box scores instead
of suppressing them outright.

score_threshold : float, optional
Minimum score for a box to be eligible for selection during soft-NMS.
Only used when ``soft_nms_sigma > 0``. Defaults to ``0.0``.

Returns
-------
out : relax.Expr
If ``return_indices`` is ``True``, returns
``(box_indices, valid_box_count)`` with shapes
If ``return_indices`` is ``True`` and ``soft_nms_sigma`` is ``0.0``,
returns ``(box_indices, valid_box_count)`` with shapes
``[batch_size, num_anchors]`` and ``[batch_size, 1]``.
If ``return_indices`` is ``True`` and ``soft_nms_sigma > 0``,
returns ``(out_data, box_indices, valid_box_count)`` where
``out_data`` has the same shape as the input data.
Otherwise returns the modified data tensor.
"""
return _ffi_api.non_max_suppression(
Expand All @@ -181,4 +195,6 @@ def non_max_suppression(
id_index,
return_indices,
invalid_to_bottom,
soft_nms_sigma,
score_threshold,
)
2 changes: 2 additions & 0 deletions python/tvm/relax/transform/legalize_ops/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ def _non_max_suppression(block_builder: BlockBuilder, call: Call) -> Expr:
id_index=call.attrs.id_index,
return_indices=call.attrs.return_indices,
invalid_to_bottom=call.attrs.invalid_to_bottom,
soft_nms_sigma=call.attrs.soft_nms_sigma,
score_threshold=call.attrs.score_threshold,
)


Expand Down
78 changes: 75 additions & 3 deletions python/tvm/topi/testing/nms_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def non_max_suppression_python(
id_index=0,
return_indices=True,
invalid_to_bottom=False,
soft_nms_sigma=0.0,
score_threshold=0.0,
):
"""Numpy reference for classic non_max_suppression.

Expand All @@ -62,7 +64,9 @@ def non_max_suppression_python(

Returns
-------
If return_indices is True: (box_indices, valid_box_count)
If return_indices is True and soft_nms_sigma == 0.0: (box_indices, valid_box_count)
If return_indices is True and soft_nms_sigma > 0.0:
(out_data, box_indices, valid_box_count)
Otherwise: modified data tensor
"""
batch_size, num_anchors, _ = data.shape
Expand All @@ -71,6 +75,10 @@ def non_max_suppression_python(
compacted = np.full((batch_size, num_anchors), -1, dtype="int32")
valid_box_count = np.zeros((batch_size, 1), dtype="int32")

is_soft_nms = soft_nms_sigma > 0.0
thresh = score_threshold if is_soft_nms else 0.0
soft_nms_scale = -0.5 / soft_nms_sigma if is_soft_nms else 0.0

for i in range(batch_size):
nkeep = int(valid_count[i])
if 0 < top_k < nkeep:
Expand All @@ -86,10 +94,72 @@ def non_max_suppression_python(
out_data[i, j, :] = data[i, src, :]
out_box_indices[i, j] = src

if is_soft_nms:
num_selected = 0
while num_selected < nkeep and (max_output_size < 0 or num_selected < max_output_size):
best_idx = -1
best_score = thresh
for j in range(num_selected, nkeep):
if out_box_indices[i, j] >= 0 and out_data[i, j, score_index] > best_score:
best_idx = j
best_score = out_data[i, j, score_index]

if best_idx < 0:
break

if best_idx != num_selected:
out_data[i, [num_selected, best_idx], :] = out_data[
i, [best_idx, num_selected], :
]
out_box_indices[i, [num_selected, best_idx]] = out_box_indices[
i, [best_idx, num_selected]
]

selected_idx = num_selected
for j in range(selected_idx + 1, nkeep):
if out_box_indices[i, j] < 0 or out_data[i, j, score_index] <= thresh:
continue

do_suppress = False
if force_suppress:
do_suppress = True
elif id_index >= 0:
do_suppress = (
out_data[i, selected_idx, id_index] == out_data[i, j, id_index]
)
else:
do_suppress = True

if not do_suppress:
continue

iou = _iou(out_data[i, selected_idx], out_data[i, j], coord_start)
if iou >= iou_threshold:
out_box_indices[i, j] = -1
else:
out_data[i, j, score_index] *= np.exp(soft_nms_scale * (iou**2))
if out_data[i, j, score_index] <= thresh:
out_box_indices[i, j] = -1

num_selected += 1

valid_box_count[i, 0] = num_selected
if return_indices:
for j in range(num_selected):
orig_idx = out_box_indices[i, j]
compacted[i, j] = int(indices[i, orig_idx])
out_box_indices[i, j] = compacted[i, j]
for j in range(num_selected, num_anchors):
out_data[i, j, :] = -1.0
out_box_indices[i, j] = -1
else:
out_data[i, num_selected:, :] = -1.0
continue

# Greedy NMS
num_valid = 0
for j in range(nkeep):
if out_data[i, j, score_index] <= 0:
if out_data[i, j, score_index] <= thresh:
out_data[i, j, :] = -1.0
out_box_indices[i, j] = -1
continue
Comment thread
Aharrypotter marked this conversation as resolved.
Expand All @@ -102,7 +172,7 @@ def non_max_suppression_python(

# Suppress overlapping boxes
for k in range(j + 1, nkeep):
if out_data[i, k, score_index] <= 0:
if out_data[i, k, score_index] <= thresh:
continue

do_suppress = False
Expand Down Expand Up @@ -130,6 +200,8 @@ def non_max_suppression_python(
valid_box_count[i, 0] = cnt

if return_indices:
if is_soft_nms:
return [out_data, compacted, valid_box_count]
return [compacted, valid_box_count]

if invalid_to_bottom:
Expand Down
Loading
Loading