diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py
index 0ba703621bb0..bbf373ec6f3b 100644
--- a/tensorrt_llm/_torch/pyexecutor/llm_request.py
+++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -621,6 +621,7 @@ def __init__(
         self.py_num_accepted_draft_tokens_indices = []
         self.py_rewind_draft_token_separate_adjustment = 0
         self.py_decoding_iter = 0
+        self.py_stream_interval = None
         self.py_last_stream_emit_time = None
         self.is_attention_dp_dummy = False
         self.is_cuda_graph_dummy = False
@@ -927,6 +928,8 @@ def executor_request_to_llm_request(
             LogprobMode.RAW),
     )
 
+    llm_request.py_stream_interval = getattr(executor_request,
+                                             "py_stream_interval", None)
     llm_request.py_disaggregated_params = getattr(executor_request,
                                                   "py_disaggregated_params",
                                                   None)
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index 8ffddfc24531..d05cbbc427b6 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -3214,7 +3214,7 @@ def _handle_responses(self):
             should_emit = (
                 request.py_decoding_iter == 1 or request.is_finished
-                or request.py_decoding_iter % self.stream_interval == 0
+                or request.py_decoding_iter % (request.py_stream_interval or self.stream_interval) == 0
                 or (self.stream_emit_interval_ms > 0
                     and request.py_last_stream_emit_time
                     and (now - request.py_last_stream_emit_time) * 1000 >
diff --git a/tensorrt_llm/executor/base_worker.py b/tensorrt_llm/executor/base_worker.py
index 4cf8ad56e25e..dad04c26ca53 100644
--- a/tensorrt_llm/executor/base_worker.py
+++ b/tensorrt_llm/executor/base_worker.py
@@ -567,6 +567,7 @@ def _deduce_max_tokens(request: GenerationRequest,
         executor_request.py_num_logprobs = request.sampling_params.logprobs
         executor_request.py_lora_path = py_lora_path
         executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode
+        executor_request.py_stream_interval = request.sampling_params.stream_interval
         # here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver
         if self._is_pytorch_backend and request.disaggregated_params is not None:
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py
index 8709af873b0a..17cfc0fa919c 100644
--- a/tensorrt_llm/executor/result.py
+++ b/tensorrt_llm/executor/result.py
@@ -677,7 +677,7 @@ def _handle_response(self, response: "GenerationExecutor.Response"):
                     prev_text=beam_output.text,
                     states=beam_output._incremental_states,
                     flush=self._done,
-                    stream_interval=self.sampling_params._stream_interval,
+                    stream_interval=self.sampling_params.stream_interval,
                     **kwargs)
             else:
                 beam_output.text = self.tokenizer.decode(
diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py
index 64503e3d23b2..b1216afd7928 100644
--- a/tensorrt_llm/llmapi/llm.py
+++ b/tensorrt_llm/llmapi/llm.py
@@ -786,9 +786,9 @@ def _prepare_sampling_params(
             sampling_params.return_generation_logits = True
             sampling_params._generation_logits_auto_enabled = True
 
-        if sampling_params._stream_interval is None:
-            sampling_params._stream_interval = getattr(self.args,
-                                                       "stream_interval", 1)
+        if sampling_params.stream_interval is None:
+            sampling_params.stream_interval = getattr(self.args,
+                                                      "stream_interval", 1)
 
         sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
         return sampling_params
diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py
index 2043920fe452..cdad60846e25 100644
--- a/tensorrt_llm/sampling_params.py
+++ b/tensorrt_llm/sampling_params.py
@@ -291,9 +291,7 @@ class SamplingParams:
     truncate_prompt_tokens: Optional[int] = None
     skip_special_tokens: bool = True
     spaces_between_special_tokens: bool = True
-    # Currently, _stream_interval is only used to pass llm.args.stream_interval to tokenizer.
-    # TODO: make this a per-request parameter.
-    _stream_interval: Optional[int] = field(default=None, init=False, repr=False)
+    stream_interval: Optional[int] = None
 
     def __post_init__(self):
         if self.pad_id is None:
@@ -317,6 +315,10 @@ def _validate(self):
         For instance, while the greedy decoding with n > 1 is capable in the
         Executor class of C++ runtime, the LLM API disallows such combination.
         """
+        if self.stream_interval is not None and self.stream_interval <= 0:
+            raise ValueError(
+                f"require stream_interval > 0, got stream_interval={self.stream_interval}"
+            )
         if self.top_p is not None and (self.top_p < 0 or self.top_p > 1):
             raise ValueError(f"require 0 <= top_p <= 1, got top_p={self.top_p}")
         if self.top_k is not None and self.top_k < 0:
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index f0093a32137a..458f33a4b866 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -384,6 +384,13 @@ class CompletionRequest(OpenAIBaseModel):
         default=None,
         description=("Parameters for disaggregated serving"),
     )
+    stream_interval: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description=(
+            "The iteration interval to create responses under the streaming mode. "
+            "If not set, the engine-level default is used."),
+    )
 
     # doc: end-completion-extra-params
 
@@ -431,6 +438,7 @@ def to_sampling_params(self,
 
             # completion-extra-params
             add_special_tokens=self.add_special_tokens,
+            stream_interval=self.stream_interval,
 
             # TODO: migrate to use logprobs and prompt_logprobs
             _return_log_probs=bool(self.logprobs),
@@ -746,6 +754,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
         ("If specified, KV cache will be salted with the provided string "
          "to limit the kv cache reuse on with the requests having the same string."
         ))
+    stream_interval: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description=(
+            "The iteration interval to create responses under the streaming mode. "
+            "If not set, the engine-level default is used."),
+    )
 
     # doc: end-chat-completion-extra-params
 
@@ -792,6 +807,7 @@ def to_sampling_params(self,
 
             # chat-completion-extra-params
             add_special_tokens=self.add_special_tokens,
+            stream_interval=self.stream_interval,
 
             # TODO: migrate to use logprobs and prompt_logprobs
             _return_log_probs=bool(self.logprobs),
@@ -903,6 +919,13 @@ class ResponsesRequest(OpenAIBaseModel):
     top_p: Optional[float] = None
     truncation: Optional[Literal["auto", "disabled"]] = "disabled"
     user: Optional[str] = None
+    stream_interval: Optional[int] = Field(
+        default=None,
+        gt=0,
+        description=(
+            "The iteration interval to create responses under the streaming mode. "
+            "If not set, the engine-level default is used."),
+    )
 
     request_id: str = Field(
         default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}",
@@ -948,6 +971,7 @@ def to_sampling_params(
             logprobs=self.top_logprobs,
             stop_token_ids=stop_token_ids,
             guided_decoding=guided_decoding,
+            stream_interval=self.stream_interval,
         )
 
     @model_validator(mode="before")
diff --git a/tests/unittest/api_stability/references_committed/sampling_params.yaml b/tests/unittest/api_stability/references_committed/sampling_params.yaml
index 084001313dc3..2cc051e5d081 100644
--- a/tests/unittest/api_stability/references_committed/sampling_params.yaml
+++ b/tests/unittest/api_stability/references_committed/sampling_params.yaml
@@ -123,6 +123,9 @@ methods:
       spaces_between_special_tokens:
         annotation: bool
         default: true
+      stream_interval:
+        annotation: Optional[int]
+        default: null
       # Returning controls
       logprobs:
         annotation: Optional[int]
diff --git a/tests/unittest/llmapi/test_executor.py b/tests/unittest/llmapi/test_executor.py
index 338f6903b741..9e329bbc1475 100644
--- a/tests/unittest/llmapi/test_executor.py
+++ b/tests/unittest/llmapi/test_executor.py
@@ -172,6 +172,81 @@ def test_invalid_sampling_params():
         SamplingParams(max_tokens=4, n=4, best_of=3, use_beam_search=True)
 
 
+def test_stream_interval_validation():
+    # Valid values
+    sp = SamplingParams(stream_interval=1)
+    assert sp.stream_interval == 1
+    sp = SamplingParams(stream_interval=10)
+    assert sp.stream_interval == 10
+    sp = SamplingParams(stream_interval=None)
+    assert sp.stream_interval is None
+
+    # Invalid values
+    with pytest.raises(ValueError, match="stream_interval"):
+        SamplingParams(stream_interval=0)
+    with pytest.raises(ValueError, match="stream_interval"):
+        SamplingParams(stream_interval=-1)
+
+
+def test_stream_interval_openai_protocol():
+    from pydantic import ValidationError
+
+    from tensorrt_llm.serve.openai_protocol import (
+        ChatCompletionRequest,
+        CompletionRequest,
+        ResponsesRequest,
+    )
+
+    # CompletionRequest: valid stream_interval passes through
+    req = CompletionRequest(model="m", prompt="hi", stream_interval=5)
+    sp = req.to_sampling_params()
+    assert sp.stream_interval == 5
+
+    # CompletionRequest: None leaves stream_interval unset
+    req = CompletionRequest(model="m", prompt="hi")
+    sp = req.to_sampling_params()
+    assert sp.stream_interval is None
+
+    # CompletionRequest: invalid stream_interval rejected at construction
+    with pytest.raises(ValidationError, match="stream_interval"):
+        CompletionRequest(model="m", prompt="hi", stream_interval=-1)
+    with pytest.raises(ValidationError, match="stream_interval"):
+        CompletionRequest(model="m", prompt="hi", stream_interval=0)
+
+    # ChatCompletionRequest: valid stream_interval passes through
+    req = ChatCompletionRequest(model="m",
+                                messages=[{
+                                    "role": "user",
+                                    "content": "hi"
+                                }],
+                                stream_interval=10)
+    sp = req.to_sampling_params()
+    assert sp.stream_interval == 10
+
+    # ChatCompletionRequest: invalid stream_interval rejected at construction
+    with pytest.raises(ValidationError, match="stream_interval"):
+        ChatCompletionRequest(model="m",
+                              messages=[{
+                                  "role": "user",
+                                  "content": "hi"
+                              }],
+                              stream_interval=0)
+
+    # ResponsesRequest: valid stream_interval passes through
+    req = ResponsesRequest(model="m", input="hi", stream_interval=5)
+    sp = req.to_sampling_params()
+    assert sp.stream_interval == 5
+
+    # ResponsesRequest: None leaves stream_interval unset
+    req = ResponsesRequest(model="m", input="hi")
+    sp = req.to_sampling_params()
+    assert sp.stream_interval is None
+
+    # ResponsesRequest: invalid stream_interval rejected at construction
+    with pytest.raises(ValidationError, match="stream_interval"):
+        ResponsesRequest(model="m", input="hi", stream_interval=-1)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 2 or WORLD_SIZE != 2,
                     reason="Must run on 2 MPI ranks with at least 2 GPUs")
 def test_sync_generation_tp_main_node_only(llama_7b_tp2_path: Path):