From 6562f097e244b1fbc9bc406d52f2a0657da5d4b8 Mon Sep 17 00:00:00 2001 From: Quan Truong Date: Fri, 3 Apr 2026 00:37:26 +0000 Subject: [PATCH 1/5] [None][feat] Add per-request stream_interval support Add a per-request `stream_interval` field to `SamplingParams` and the OpenAI-compatible API (`CompletionRequest`, `ChatCompletionRequest`, `ResponsesRequest`), allowing callers to override the engine-level default on a per-request basis. Resolution order: per-request value > engine default > fallback (1). The field is passed through the SamplingParams constructor so that `_validate()` catches invalid values (<=0) at construction time. Signed-off-by: Quan Truong --- tensorrt_llm/_torch/pyexecutor/llm_request.py | 3 + tensorrt_llm/_torch/pyexecutor/py_executor.py | 2 +- tensorrt_llm/executor/base_worker.py | 1 + tensorrt_llm/llmapi/llm.py | 7 ++- tensorrt_llm/sampling_params.py | 8 ++- tensorrt_llm/serve/openai_protocol.py | 21 +++++++ .../references_committed/sampling_params.yaml | 3 + tests/unittest/llmapi/test_executor.py | 56 +++++++++++++++++++ 8 files changed, 96 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 0ba703621bb0..bbf373ec6f3b 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -621,6 +621,7 @@ def __init__( self.py_num_accepted_draft_tokens_indices = [] self.py_rewind_draft_token_separate_adjustment = 0 self.py_decoding_iter = 0 + self.py_stream_interval = None self.py_last_stream_emit_time = None self.is_attention_dp_dummy = False self.is_cuda_graph_dummy = False @@ -927,6 +928,8 @@ def executor_request_to_llm_request( LogprobMode.RAW), ) + llm_request.py_stream_interval = getattr(executor_request, + "py_stream_interval", None) llm_request.py_disaggregated_params = getattr(executor_request, "py_disaggregated_params", None) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 8ffddfc24531..d05cbbc427b6 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -3214,7 +3214,7 @@ def _handle_responses(self): should_emit = ( request.py_decoding_iter == 1 or request.is_finished - or request.py_decoding_iter % self.stream_interval == 0 + or request.py_decoding_iter % (request.py_stream_interval or self.stream_interval) == 0 or (self.stream_emit_interval_ms > 0 and request.py_last_stream_emit_time and (now - request.py_last_stream_emit_time) * 1000 > diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index 4cf8ad56e25e..ca2ea781b01f 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -567,6 +567,7 @@ def _deduce_max_tokens(request: GenerationRequest, executor_request.py_num_logprobs = request.sampling_params.logprobs executor_request.py_lora_path = py_lora_path executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode + executor_request.py_stream_interval = request.sampling_params._stream_interval # here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver if self._is_pytorch_backend and request.disaggregated_params is not None: diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 64503e3d23b2..53117b2fd374 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -787,8 +787,11 @@ def _prepare_sampling_params( sampling_params._generation_logits_auto_enabled = True if sampling_params._stream_interval is None: - sampling_params._stream_interval = getattr(self.args, - "stream_interval", 1) + if sampling_params.stream_interval is not None: + sampling_params._stream_interval = sampling_params.stream_interval + else: + sampling_params._stream_interval = getattr( + self.args, "stream_interval", 1) sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics return sampling_params diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 2043920fe452..96e7d1f36291 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -291,8 +291,8 @@ class SamplingParams: truncate_prompt_tokens: Optional[int] = None skip_special_tokens: bool = True spaces_between_special_tokens: bool = True - # Currently, _stream_interval is only used to pass llm.args.stream_interval to tokenizer. - # TODO: make this a per-request parameter. + stream_interval: Optional[int] = None + # _stream_interval is the resolved value (per-request or engine default), used internally. _stream_interval: Optional[int] = field(default=None, init=False, repr=False) def __post_init__(self): @@ -317,6 +317,10 @@ def _validate(self): For instance, while the greedy decoding with n > 1 is capable in the Executor class of C++ runtime, the LLM API disallows such combination. """ + if self.stream_interval is not None and self.stream_interval <= 0: + raise ValueError( + f"require stream_interval > 0, got stream_interval={self.stream_interval}" + ) if self.top_p is not None and (self.top_p < 0 or self.top_p > 1): raise ValueError(f"require 0 <= top_p <= 1, got top_p={self.top_p}") if self.top_k is not None and self.top_k < 0: diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index f0093a32137a..9a4e1b6196a2 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -384,6 +384,12 @@ class CompletionRequest(OpenAIBaseModel): default=None, description=("Parameters for disaggregated serving"), ) + stream_interval: Optional[int] = Field( + default=None, + description=( + "The iteration interval to create responses under the streaming mode. " + "If not set, the engine-level default is used."), + ) # doc: end-completion-extra-params @@ -431,6 +437,7 @@ def to_sampling_params(self, # completion-extra-params add_special_tokens=self.add_special_tokens, + stream_interval=self.stream_interval, # TODO: migrate to use logprobs and prompt_logprobs _return_log_probs=bool(self.logprobs), @@ -746,6 +753,12 @@ class ChatCompletionRequest(OpenAIBaseModel): ("If specified, KV cache will be salted with the provided string " "to limit the kv cache reuse on with the requests having the same string." )) + stream_interval: Optional[int] = Field( + default=None, + description=( + "The iteration interval to create responses under the streaming mode. " + "If not set, the engine-level default is used."), + ) # doc: end-chat-completion-extra-params @@ -792,6 +805,7 @@ def to_sampling_params(self, # chat-completion-extra-params add_special_tokens=self.add_special_tokens, + stream_interval=self.stream_interval, # TODO: migrate to use logprobs and prompt_logprobs _return_log_probs=bool(self.logprobs), @@ -903,6 +917,12 @@ class ResponsesRequest(OpenAIBaseModel): top_p: Optional[float] = None truncation: Optional[Literal["auto", "disabled"]] = "disabled" user: Optional[str] = None + stream_interval: Optional[int] = Field( + default=None, + description=( + "The iteration interval to create responses under the streaming mode. " + "If not set, the engine-level default is used."), + ) request_id: str = Field( default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}", @@ -948,6 +968,7 @@ def to_sampling_params( logprobs=self.top_logprobs, stop_token_ids=stop_token_ids, guided_decoding=guided_decoding, + stream_interval=self.stream_interval, ) @model_validator(mode="before") diff --git a/tests/unittest/api_stability/references_committed/sampling_params.yaml b/tests/unittest/api_stability/references_committed/sampling_params.yaml index 084001313dc3..2cc051e5d081 100644 --- a/tests/unittest/api_stability/references_committed/sampling_params.yaml +++ b/tests/unittest/api_stability/references_committed/sampling_params.yaml @@ -123,6 +123,9 @@ methods: spaces_between_special_tokens: annotation: bool default: true + stream_interval: + annotation: Optional[int] + default: null # Returning controls logprobs: annotation: Optional[int] diff --git a/tests/unittest/llmapi/test_executor.py b/tests/unittest/llmapi/test_executor.py index 338f6903b741..89148b81006e 100644 --- a/tests/unittest/llmapi/test_executor.py +++ b/tests/unittest/llmapi/test_executor.py @@ -172,6 +172,62 @@ def test_invalid_sampling_params(): SamplingParams(max_tokens=4, n=4, best_of=3, use_beam_search=True) +def test_stream_interval_validation(): + # Valid values + sp = SamplingParams(stream_interval=1) + assert sp.stream_interval == 1 + sp = SamplingParams(stream_interval=10) + assert sp.stream_interval == 10 + sp = SamplingParams(stream_interval=None) + assert sp.stream_interval is None + + # Invalid values + with pytest.raises(ValueError, match="stream_interval"): + SamplingParams(stream_interval=0) + with pytest.raises(ValueError, match="stream_interval"): + SamplingParams(stream_interval=-1) + + +def test_stream_interval_openai_protocol(): + from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, + CompletionRequest) + + # CompletionRequest: valid stream_interval passes through + req = CompletionRequest(model="m", prompt="hi", stream_interval=5) + sp = req.to_sampling_params() + assert sp.stream_interval == 5 + + # CompletionRequest: None leaves stream_interval unset + req = CompletionRequest(model="m", prompt="hi") + sp = req.to_sampling_params() + assert sp.stream_interval is None + + # CompletionRequest: invalid stream_interval is caught + req = CompletionRequest(model="m", prompt="hi", stream_interval=-1) + with pytest.raises(ValueError, match="stream_interval"): + req.to_sampling_params() + + # ChatCompletionRequest: valid stream_interval passes through + req = ChatCompletionRequest(model="m", + messages=[{ + "role": "user", + "content": "hi" + }], + stream_interval=10) + sp = req.to_sampling_params() + assert sp.stream_interval == 10 + + # ChatCompletionRequest: invalid stream_interval is caught + req = ChatCompletionRequest(model="m", + messages=[{ + "role": "user", + "content": "hi" + }], + stream_interval=0) + with pytest.raises(ValueError, match="stream_interval"): + req.to_sampling_params() + + @pytest.mark.skipif(torch.cuda.device_count() < 2 or WORLD_SIZE != 2, reason="Must run on 2 MPI ranks with at least 2 GPUs") def test_sync_generation_tp_main_node_only(llama_7b_tp2_path: Path): From 4f360bbee4266f95865fb4c3155762a15c082d00 Mon Sep 17 00:00:00 2001 From: Quan Truong Date: Fri, 3 Apr 2026 00:57:17 +0000 Subject: [PATCH 2/5] Address review feedback: schema-level validation and test coverage - Add gt=0 constraint to stream_interval Pydantic Fields so invalid values are rejected at request construction - Add ResponsesRequest test coverage - Fix import indentation in test Signed-off-by: Quan Truong --- tensorrt_llm/serve/openai_protocol.py | 3 ++ tests/unittest/llmapi/test_executor.py | 49 ++++++++++++++++++-------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 9a4e1b6196a2..458f33a4b866 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -386,6 +386,7 @@ class CompletionRequest(OpenAIBaseModel): ) stream_interval: Optional[int] = Field( default=None, + gt=0, description=( "The iteration interval to create responses under the streaming mode. " "If not set, the engine-level default is used."), @@ -755,6 +756,7 @@ class ChatCompletionRequest(OpenAIBaseModel): )) stream_interval: Optional[int] = Field( default=None, + gt=0, description=( "The iteration interval to create responses under the streaming mode. " "If not set, the engine-level default is used."), @@ -919,6 +921,7 @@ class ResponsesRequest(OpenAIBaseModel): user: Optional[str] = None stream_interval: Optional[int] = Field( default=None, + gt=0, description=( "The iteration interval to create responses under the streaming mode. " "If not set, the engine-level default is used."), diff --git a/tests/unittest/llmapi/test_executor.py b/tests/unittest/llmapi/test_executor.py index 89148b81006e..9e329bbc1475 100644 --- a/tests/unittest/llmapi/test_executor.py +++ b/tests/unittest/llmapi/test_executor.py @@ -189,8 +189,13 @@ def test_stream_interval_validation(): def test_stream_interval_openai_protocol(): - from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, - CompletionRequest) + from pydantic import ValidationError + + from tensorrt_llm.serve.openai_protocol import ( + ChatCompletionRequest, + CompletionRequest, + ResponsesRequest, + ) # CompletionRequest: valid stream_interval passes through req = CompletionRequest(model="m", prompt="hi", stream_interval=5) @@ -202,10 +207,11 @@ def test_stream_interval_openai_protocol(): sp = req.to_sampling_params() assert sp.stream_interval is None - # CompletionRequest: invalid stream_interval is caught - req = CompletionRequest(model="m", prompt="hi", stream_interval=-1) - with pytest.raises(ValueError, match="stream_interval"): - req.to_sampling_params() + # CompletionRequest: invalid stream_interval rejected at construction + with pytest.raises(ValidationError, match="stream_interval"): + CompletionRequest(model="m", prompt="hi", stream_interval=-1) + with pytest.raises(ValidationError, match="stream_interval"): + CompletionRequest(model="m", prompt="hi", stream_interval=0) # ChatCompletionRequest: valid stream_interval passes through req = ChatCompletionRequest(model="m", @@ -217,15 +223,28 @@ def test_stream_interval_openai_protocol(): sp = req.to_sampling_params() assert sp.stream_interval == 10 - # ChatCompletionRequest: invalid stream_interval is caught - req = ChatCompletionRequest(model="m", - messages=[{ - "role": "user", - "content": "hi" - }], - stream_interval=0) - with pytest.raises(ValueError, match="stream_interval"): - req.to_sampling_params() + # ChatCompletionRequest: invalid stream_interval rejected at construction + with pytest.raises(ValidationError, match="stream_interval"): + ChatCompletionRequest(model="m", + messages=[{ + "role": "user", + "content": "hi" + }], + stream_interval=0) + + # ResponsesRequest: valid stream_interval passes through + req = ResponsesRequest(model="m", input="hi", stream_interval=5) + sp = req.to_sampling_params() + assert sp.stream_interval == 5 + + # ResponsesRequest: None leaves stream_interval unset + req = ResponsesRequest(model="m", input="hi") + sp = req.to_sampling_params() + assert sp.stream_interval is None + + # ResponsesRequest: invalid stream_interval rejected at construction + with pytest.raises(ValidationError, match="stream_interval"): + ResponsesRequest(model="m", input="hi", stream_interval=-1) @pytest.mark.skipif(torch.cuda.device_count() < 2 or WORLD_SIZE != 2, From 93b1ae7579fdcf4971d0c75ebc4932be390748f3 Mon Sep 17 00:00:00 2001 From: Quan Truong Date: Fri, 3 Apr 2026 20:30:22 +0000 Subject: [PATCH 3/5] Remove redundant _stream_interval field from SamplingParams MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The private _stream_interval field was an unnecessary indirection — it only existed to hold the resolved value (per-request or engine default). Instead, resolve directly into the public stream_interval field when it is None, eliminating the confusing duplication. Signed-off-by: Quan Truong --- tensorrt_llm/executor/base_worker.py | 2 +- tensorrt_llm/executor/result.py | 2 +- tensorrt_llm/llmapi/llm.py | 9 +++------ tensorrt_llm/sampling_params.py | 2 -- 4 files changed, 5 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py index ca2ea781b01f..dad04c26ca53 100644 --- a/tensorrt_llm/executor/base_worker.py +++ b/tensorrt_llm/executor/base_worker.py @@ -567,7 +567,7 @@ def _deduce_max_tokens(request: GenerationRequest, executor_request.py_num_logprobs = request.sampling_params.logprobs executor_request.py_lora_path = py_lora_path executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode - executor_request.py_stream_interval = request.sampling_params._stream_interval + executor_request.py_stream_interval = request.sampling_params.stream_interval # here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver if self._is_pytorch_backend and request.disaggregated_params is not None: diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 8709af873b0a..17cfc0fa919c 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -677,7 +677,7 @@ def _handle_response(self, response: "GenerationExecutor.Response"): prev_text=beam_output.text, states=beam_output._incremental_states, flush=self._done, - stream_interval=self.sampling_params._stream_interval, + stream_interval=self.sampling_params.stream_interval, **kwargs) else: beam_output.text = self.tokenizer.decode( diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 53117b2fd374..addd47eb67e5 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -786,12 +786,9 @@ def _prepare_sampling_params( sampling_params.return_generation_logits = True sampling_params._generation_logits_auto_enabled = True - if sampling_params._stream_interval is None: - if sampling_params.stream_interval is not None: - sampling_params._stream_interval = sampling_params.stream_interval - else: - sampling_params._stream_interval = getattr( - self.args, "stream_interval", 1) + if sampling_params.stream_interval is None: + sampling_params.stream_interval = getattr( + self.args, "stream_interval", 1) sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics return sampling_params diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 96e7d1f36291..cdad60846e25 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -292,8 +292,6 @@ class SamplingParams: skip_special_tokens: bool = True spaces_between_special_tokens: bool = True stream_interval: Optional[int] = None - # _stream_interval is the resolved value (per-request or engine default), used internally. - _stream_interval: Optional[int] = field(default=None, init=False, repr=False) def __post_init__(self): if self.pad_id is None: From 3190a6e7ba2c6bab302bed2085d37583c30f9798 Mon Sep 17 00:00:00 2001 From: Quan Truong Date: Fri, 3 Apr 2026 21:39:41 +0000 Subject: [PATCH 4/5] change format --- tensorrt_llm/llmapi/llm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index addd47eb67e5..3ea966075172 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -787,8 +787,8 @@ def _prepare_sampling_params( sampling_params._generation_logits_auto_enabled = True if sampling_params.stream_interval is None: - sampling_params.stream_interval = getattr( - self.args, "stream_interval", 1) + sampling_params.stream_interval = getattr(self.args, + "stream_interval", 1) sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics return sampling_params From 03fb2da7d961a3a2b81d8a6b28438c92e4281daa Mon Sep 17 00:00:00 2001 From: Quan Truong Date: Fri, 3 Apr 2026 21:41:50 +0000 Subject: [PATCH 5/5] change formatting again --- tensorrt_llm/llmapi/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 3ea966075172..b1216afd7928 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -787,7 +787,7 @@ def _prepare_sampling_params( sampling_params._generation_logits_auto_enabled = True if sampling_params.stream_interval is None: - sampling_params.stream_interval = getattr(self.args, + sampling_params.stream_interval = getattr(self.args, "stream_interval", 1) sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics return sampling_params