[Relax][Frontend][TFLite] Add soft-NMS support for TFLite NON_MAX_SUPPRESSION_V5 #19426

Open
Aharrypotter wants to merge 4 commits into apache:main from Aharrypotter:tflite-nms-v5-soft-nms-19412

@Aharrypotter
Contributor

Summary

This PR completes the TFLite NON_MAX_SUPPRESSION_V5 implementation in Relax by adding support for soft_nms_sigma != 0.

It extends relax.vision.non_max_suppression with soft-NMS attributes, updates the TFLite frontend to consume the soft-NMS outputs correctly, and aligns the TOPI implementation with LiteRT's reference behavior.

Relates to #19412.

Changes

  1. Relax / TOPI soft-NMS support

    • Extend NonMaximumSuppressionAttrs with soft_nms_sigma and score_threshold.
    • Thread the new attributes through Relax op registration, Python wrapper, and legalization.
    • Add soft-NMS handling to TOPI classic NMS so relax.vision.non_max_suppression can represent the NON_MAX_SUPPRESSION_V5 behavior.
  2. TFLite frontend support for NON_MAX_SUPPRESSION_V5

    • Remove the previous soft_nms_sigma != 0 unsupported-path guard in the TFLite frontend.
    • Forward soft_nms_sigma and score_threshold into relax.vision.non_max_suppression.
    • Handle the soft-NMS return path explicitly so the frontend reads decayed scores from the processed NMS output instead of re-reading the original score tensor.
  3. Soft-NMS correctness fixes

    • Fix the soft-NMS path so boxes whose scores fall below the threshold after decay are invalidated consistently.
    • Keep returned indices and decayed scores aligned in both the TOPI TIR implementation and the NumPy reference implementation.
    • Update the soft-NMS candidate selection logic to re-pick the current best candidate after each decay step, matching LiteRT's reference behavior.
    • Align the Gaussian decay formula with LiteRT.
  4. Test coverage

    • Add Relax tests for soft-NMS struct-info inference and legalization.
    • Add Relax E2E tests covering reordered outputs after score decay and other soft-NMS follow-up cases.
    • Add TFLite frontend tests for NON_MAX_SUPPRESSION_V5 with soft_nms_sigma != 0.
    • Add IR checks to verify that soft_nms_sigma and score_threshold are forwarded correctly.
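
The selection behavior described under "Soft-NMS correctness fixes" can be sketched in plain NumPy. This is a minimal illustration only, not the TOPI kernel or the code in nms_python.py: the names `iou` and `soft_nms_sketch` are hypothetical, the `[y1, x1, y2, x2]` box layout is an assumption, and the decay factor `exp(-0.5 * iou^2 / sigma)` is assumed to be the LiteRT form. It uses an eager formulation (decay every remaining candidate after each pick), which may differ in structure from the actual implementation.

```python
import numpy as np


def iou(a, b):
    """IoU of two boxes given as [y1, x1, y2, x2] (layout is an assumption)."""
    ih = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    iw = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = ih * iw
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0.0 else 0.0


def soft_nms_sketch(boxes, scores, max_out, iou_thr, score_thr, sigma):
    """Hypothetical soft-NMS sketch: returns (indices, decayed_scores) in pick order."""
    scores = np.asarray(scores, dtype="float64").copy()
    active = scores > score_thr          # candidates still eligible for selection
    picked, picked_scores = [], []
    while len(picked) < max_out and active.any():
        # Re-pick the current best candidate after every decay step.
        best = int(np.argmax(np.where(active, scores, -np.inf)))
        picked.append(best)
        picked_scores.append(float(scores[best]))
        active[best] = False
        for j in np.flatnonzero(active):
            ov = iou(boxes[best], boxes[j])
            if ov >= iou_thr:            # hard suppression still applies
                active[j] = False
                continue
            scores[j] *= np.exp(-0.5 * ov * ov / sigma)  # assumed Gaussian decay
            if scores[j] <= score_thr:   # invalidate boxes that decay below the floor
                active[j] = False
    return picked, picked_scores


# Box 1 overlaps box 0 heavily, so its score decays and it is picked after box 2.
boxes = np.array([[0.0, 0.0, 1.0, 1.0], [0.0, 0.1, 1.0, 1.1], [0.0, 2.0, 1.0, 3.0]])
picked, picked_scores = soft_nms_sketch(boxes, [0.9, 0.8, 0.7], 3, 1.0, 0.0, 0.5)
print(picked, picked_scores)
```

Because indices and decayed scores are appended together, the two lists stay aligned by construction, which is the property the correctness fixes aim for.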

Testing

python -m pytest -n 1 tests/python/relax/test_op_vision.py -k "all_class_non_max_suppression or get_valid_counts or nms" -v
python -m pytest tests/python/relax/test_frontend_tflite.py -k "nms_v5" -v

Result:

  • Relax vision tests passed locally
  • TFLite NON_MAX_SUPPRESSION_V5 coverage added for both hard-NMS and soft-NMS paths

… TFLite NMSv5

This commit adds soft-NMS support to relax.vision.non_max_suppression and
the TFLite frontend's NonMaxSuppressionV5 converter.

Key changes:
- Extended NonMaximumSuppressionAttrs with soft_nms_sigma and score_threshold.
- Implemented Gaussian soft-NMS decay in TOPI _classic_nms_ir:
  score *= exp(-iou^2 / sigma) for overlapping boxes instead of hard suppression.
- Introduced conditional return type: soft-NMS returns a 3-tuple
  (out_data, box_indices, valid_box_count) so the TFLite frontend can extract
  decayed scores; hard NMS keeps the original 2-tuple.
- Updated Relax struct info inference to match the conditional return type.
- Removed the soft_nms_sigma != 0.0 OpNotImplemented guard in TFLite frontend.
- Added tests: struct info inference, legalization, TVMScript parser, and
  TFLite frontend E2E smoke tests for soft-NMS.
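
Since the tuple arity depends on soft_nms_sigma, a caller has to branch when unpacking. A minimal sketch of that branching, assuming the 2-tuple is (box_indices, valid_box_count) and the 3-tuple prepends out_data; `unpack_nms_outputs` is a hypothetical helper for illustration, not part of TVM:

```python
def unpack_nms_outputs(outputs, soft_nms_sigma):
    """Normalize NMS outputs to (out_data, box_indices, valid_box_count).

    Hypothetical helper: out_data is None on the hard-NMS path, where no
    decayed scores are produced.
    """
    if soft_nms_sigma > 0.0:
        # Soft-NMS: decayed data is returned as the first tuple element.
        out_data, box_indices, valid_box_count = outputs
    else:
        # Hard NMS: original 2-tuple, no decayed data.
        box_indices, valid_box_count = outputs
        out_data = None
    return out_data, box_indices, valid_box_count
```

Normalizing to one shape at the call site keeps downstream code from depending on the attribute value.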

  This commit fixes a critical bug in the soft-NMS path where decayed boxes
  whose scores dropped below the threshold were not properly invalidated,
  causing misalignment between selected_indices and selected_scores.

  Fixes in TOPI _classic_nms_ir:
  1. After Gaussian decay, invalidate boxes with score <= threshold by
     setting their out_box_indices to -1.
  2. During the compaction phase for return_indices=True, also compact
     out_data so that the first num_valid entries match the selected boxes
     in order. Fill remaining entries with -1.

  Fixes in the NumPy reference implementation (nms_python.py):
  - Synchronize the same post-decay invalidation logic to keep the
    reference behavior consistent with the TIR kernel and TensorFlow.

  Without this fix, the TFLite frontend would extract wrong scores from
  processed_data[:max_output_size] when a middle box was decayed below
  threshold but later boxes were still valid.
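
The compaction step can be illustrated with NumPy. This is a sketch under assumptions, not the TIR kernel: `compact_selected` is a hypothetical name, and out_data is treated as a generic 2-D array with one row per candidate slot.

```python
import numpy as np


def compact_selected(out_data, out_box_indices):
    """Move valid rows to the front, keeping data and indices aligned.

    Slots whose index is -1 (invalidated after decay) are dropped; the
    remaining tail is filled with -1, as the commit message describes.
    """
    keep = np.flatnonzero(out_box_indices >= 0)
    num_valid = keep.size
    data = np.full_like(out_data, -1)
    idx = np.full_like(out_box_indices, -1)
    data[:num_valid] = out_data[keep]
    idx[:num_valid] = out_box_indices[keep]
    return data, idx, num_valid


# A middle box was invalidated, but a later box is still valid.
out_data = np.array([[0.9, 0.0], [0.1, 1.0], [0.7, 2.0]])
out_idx = np.array([0, -1, 2])
data, idx, n = compact_selected(out_data, out_idx)
print(idx, n)
```

After compaction, reading the first num_valid rows of the data yields scores for exactly the boxes listed in the indices, in the same order.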

Align the soft-NMS path in non_max_suppression with LiteRT's
reference behavior.

- Rewrite soft-NMS candidate selection in TOPI TIR and the NumPy
  reference to re-pick the current best score after each decay step
- Match LiteRT's Gaussian decay formula and keep decayed scores/data
  aligned with the returned indices for TFLite NMSv5
- Update conditional soft-NMS return documentation and add Relax/TFLite
  tests for reordered outputs and forwarded soft-NMS attributes
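
For reference, the two decay variants can be compared side by side. The first function is the formula quoted in the earlier commit message; the second assumes the LiteRT form carries an extra factor of 0.5 in the exponent, which is an assumption here and should be checked against the LiteRT reference kernel. Both function names are hypothetical.

```python
import math


def decay_first_commit(score, iou, sigma):
    """score * exp(-iou^2 / sigma), as quoted in the first commit message."""
    return score * math.exp(-(iou * iou) / sigma)


def decay_litert_assumed(score, iou, sigma):
    """score * exp(-0.5 * iou^2 / sigma); assumed LiteRT form, not confirmed here."""
    return score * math.exp(-0.5 * (iou * iou) / sigma)


# With zero overlap neither variant changes the score; with overlap the
# variant with the 0.5 factor decays less.
print(decay_first_commit(0.8, 0.5, 0.5), decay_litert_assumed(0.8, 0.5, 0.5))
```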

Validation:
- python -m pytest -n 1 tests/python/relax/test_op_vision.py -k "all_class_non_max_suppression or get_valid_counts or nms" -v

Result:
- 32 passed
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request implements Soft-NMS support across the TVM Relax stack, including the TFLite frontend, Relax operators, and TOPI TIR implementations. The changes introduce soft_nms_sigma and score_threshold attributes to the non_max_suppression operator and update the return signature to handle the additional output data required for soft-NMS. Feedback focuses on refactoring redundant code in the TFLite frontend, removing unreachable logic in the Python reference implementation, and consolidating duplicated IoU calculation logic within the TIR implementation to improve maintainability.

Comment thread python/tvm/relax/frontend/tflite/tflite_frontend.py
Comment thread python/tvm/topi/testing/nms_python.py
num_valid_boxes_buf = T.alloc_buffer((1,), "int32", scope="local")
num_valid_boxes = T.buffer_proxy(num_valid_boxes_buf)
num_valid_boxes[0] = T.int32(0)
if is_soft_nms:
Contributor

medium

The IoU calculation logic is duplicated in both the soft-NMS (if is_soft_nms, lines 311-401) and hard-NMS (else, lines 483-537) branches. This makes the code harder to maintain. Consider refactoring this logic into a shared helper function or macro if possible within TIR script to avoid duplication.

Contributor Author

The IoU blocks operate in different indexing contexts and are self-contained within their respective branches. Extracting a shared TIR helper would add complexity without clear benefits for this already-large diff. I'd prefer to keep them as-is.

… impl

 - Deduplicate squeeze/strided_slice/reshape for selected_indices and num_valid across soft-NMS and hard-NMS branches
 - Remove redundant `if not is_soft_nms` check in the reference implementation (the Greedy NMS section is only reached when `is_soft_nms` is False)
@Aharrypotter
Contributor Author

cc @tlopex

Member

@tlopex tlopex left a comment

A few things still need to be cleaned up before this is ready:

  1. Please add a note in python/tvm/relax/op/vision/nms.py that the return tuple shape depends on soft_nms_sigma: it is a 2-tuple when soft_nms_sigma == 0, and a 3-tuple (with decayed out_data prepended) when soft_nms_sigma > 0. That is easy to miss from the signature alone.

  2. In both python/tvm/relax/op/vision/nms.py and python/tvm/topi/vision/nms.py, please clarify that this score_threshold is a post-decay floor used only in the soft-NMS path, which is different from get_valid_counts.score_threshold (pre-filter). Since the TFLite frontend passes both, it would be good to spell that out clearly.

  3. In python/tvm/topi/vision/nms.py, the IoU computation block is duplicated between the hard-NMS and soft-NMS branches of _classic_nms_ir. Please consider pulling it into a local helper so the two paths do not drift on future fixes.

  4. In python/tvm/topi/vision/nms.py, best_idx_buf, best_score_buf, tmp_idx_buf, and tmp_val_buf are currently allocated inside nested T.If / T.serial bodies. Please move them up next to num_valid_boxes_buf under with T.parallel(0, batch_size).

  5. In python/tvm/relax/frontend/tflite/tflite_frontend.py, the soft-NMS path takes scores from processed_data, which uses -1.0 for invalid slots. TensorFlow NonMaxSuppressionV5 expects non-negative padding. Please double-check that this is really okay, and add a clip(min=0.0) if needed.
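
The effect of the clip suggested in item 5 can be shown on a plain NumPy array. This is only a sketch: `pad_safe_scores` is a hypothetical name, the input is assumed to be a 1-D array of scores already extracted from processed_data, and whether the clip is actually needed is the open question the review raises.

```python
import numpy as np


def pad_safe_scores(processed_scores):
    """Replace the -1.0 sentinel in invalid slots with 0.0 padding.

    Valid decayed scores are non-negative, so clipping at 0.0 leaves them
    unchanged and only affects the padded tail.
    """
    return np.clip(processed_scores, 0.0, None)


scores = np.array([0.9, 0.41, -1.0, -1.0])
print(pad_safe_scores(scores))
```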

2 participants