Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def __init__(
self.llm_args.enable_iter_perf_stats = reporting_info.enable_iter_perf_stats
self.llm_args.enable_iter_req_stats = reporting_info.enable_iter_req_stats
self.llm_args.stream_interval = 1
self.llm_args.stream_interval_ms = 0
self.llm_args.attention_dp_config = None
self.llm_args.batch_wait_timeout_ms = 0
self.llm_args.batch_wait_timeout_iters = 0
Expand Down
6 changes: 5 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,9 @@ def __init__(
self.py_rewind_draft_token_separate_adjustment = 0
self.py_decoding_iter = 0
self.py_stream_interval = None
self.py_last_stream_emit_time = None
self.py_stream_interval_ms = None
self.py_last_stream_emit_time: Optional[float] = None
self.py_last_stream_emit_iter = 0
self.is_attention_dp_dummy = False
self.is_cuda_graph_dummy = False
self.py_kv_transfer_start_time = None
Expand Down Expand Up @@ -930,6 +932,8 @@ def executor_request_to_llm_request(

llm_request.py_stream_interval = getattr(executor_request,
"py_stream_interval", None)
llm_request.py_stream_interval_ms = getattr(executor_request,
"py_stream_interval_ms", None)
llm_request.py_disaggregated_params = getattr(executor_request,
"py_disaggregated_params",
None)
Expand Down
30 changes: 20 additions & 10 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def __init__(
self.enable_iter_perf_stats = self.llm_args.enable_iter_perf_stats
self.enable_iter_req_stats = self.llm_args.enable_iter_req_stats
self.stream_interval = self.llm_args.stream_interval
self.stream_emit_interval_ms = self.llm_args.stream_emit_interval_ms
self.stream_interval_ms = self.llm_args.stream_interval_ms
self.attention_dp_enable_balance = (
self.llm_args.attention_dp_config is not None
and self.llm_args.attention_dp_config.enable_balance)
Expand Down Expand Up @@ -3210,17 +3210,27 @@ def _handle_responses(self):
request.update_perf_metrics(self.iter_counter)

request_done = False
request.py_last_stream_emit_iter += 1
now = get_steady_clock_now_in_seconds()
should_emit = (
request.py_decoding_iter == 1
or request.is_finished
or request.py_decoding_iter % (request.py_stream_interval or self.stream_interval) == 0
or (self.stream_emit_interval_ms > 0
and request.py_last_stream_emit_time
and (now - request.py_last_stream_emit_time) * 1000 >
self.stream_emit_interval_ms)
)

# Resolve effective intervals (per-request > engine > default)
token_interval = request.py_stream_interval or self.stream_interval
Comment thread
Thachnh marked this conversation as resolved.
time_interval_ms = request.py_stream_interval_ms or self.stream_interval_ms or 0

# Time interval takes priority; fall back to token interval
if time_interval_ms > 0:
interval_triggered = (
request.py_last_stream_emit_time is not None
and (now - request.py_last_stream_emit_time) * 1000
>= time_interval_ms)
else:
interval_triggered = request.py_last_stream_emit_iter >= token_interval

should_emit = (request.py_decoding_iter == 1
or request.is_finished or interval_triggered)

if should_emit:
request.py_last_stream_emit_iter = 0
request.py_last_stream_emit_time = now
response = request.create_response(False, self.dist.rank)
if response:
Expand Down
1 change: 1 addition & 0 deletions tensorrt_llm/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,7 @@ def _deduce_max_tokens(request: GenerationRequest,
executor_request.py_lora_path = py_lora_path
executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode
executor_request.py_stream_interval = request.sampling_params.stream_interval
executor_request.py_stream_interval_ms = request.sampling_params.stream_interval_ms

# here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver
if self._is_pytorch_backend and request.disaggregated_params is not None:
Expand Down
4 changes: 4 additions & 0 deletions tensorrt_llm/llmapi/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,10 @@ def _prepare_sampling_params(
if sampling_params.stream_interval is None:
sampling_params.stream_interval = getattr(self.args,
"stream_interval", 1)
if sampling_params.stream_interval_ms is None:
engine_val = getattr(self.args, "stream_interval_ms", 0)
if engine_val > 0:
sampling_params.stream_interval_ms = engine_val
sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
return sampling_params

Expand Down
20 changes: 6 additions & 14 deletions tensorrt_llm/llmapi/llm_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -2999,15 +2999,18 @@ class TorchLlmArgs(BaseLlmArgs):
# TODO: make this a per-request parameter
stream_interval: int = Field(
default=1,
ge=1,
description=
"The iteration interval to create responses under the streaming mode. "
"Set this to a larger value when the batch size is large, which helps reduce the streaming overhead.",
)
stream_emit_interval_ms: int = Field(
stream_interval_ms: int = Field(
default=0,
ge=0,
description=
"The time interval (milliseconds) to create responses under the streaming mode. "
"Set to 0 to disable time-based throttling.",
"The time interval in milliseconds to create responses under the streaming mode. "
"Set to 0 to disable time-based streaming throttle. "
"When stream_interval_ms is set (> 0), it takes priority over stream_interval.",
)

force_dynamic_quantization: bool = Field(
Expand Down Expand Up @@ -3220,17 +3223,6 @@ def validate_speculative_config(self):

return self

@model_validator(mode="after")
def validate_stream_interval(self):
if self.stream_interval <= 0:
raise ValueError(
f"stream_interval must be positive, got {self.stream_interval}")
if self.stream_emit_interval_ms < 0:
raise ValueError(
"stream_emit_interval_ms must be non-negative, got "
f"{self.stream_emit_interval_ms}")
return self

@model_validator(mode="after")
def validate_checkpoint_format(self):
if self.checkpoint_format is not None and self.checkpoint_loader is not None:
Expand Down
5 changes: 5 additions & 0 deletions tensorrt_llm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ class SamplingParams:
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
stream_interval: Optional[int] = None
stream_interval_ms: Optional[int] = None

def __post_init__(self):
if self.pad_id is None:
Expand Down Expand Up @@ -319,6 +320,10 @@ def _validate(self):
raise ValueError(
f"require stream_interval > 0, got stream_interval={self.stream_interval}"
)
if self.stream_interval_ms is not None and self.stream_interval_ms <= 0:
raise ValueError(
f"require stream_interval_ms > 0, got stream_interval_ms={self.stream_interval_ms}"
)
if self.top_p is not None and (self.top_p < 0 or self.top_p > 1):
raise ValueError(f"require 0 <= top_p <= 1, got top_p={self.top_p}")
if self.top_k is not None and self.top_k < 0:
Expand Down
24 changes: 24 additions & 0 deletions tensorrt_llm/serve/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,13 @@ class CompletionRequest(OpenAIBaseModel):
"The iteration interval to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)
stream_interval_ms: Optional[int] = Field(
default=None,
gt=0,
description=(
"The time interval in milliseconds to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)

# doc: end-completion-extra-params

Expand Down Expand Up @@ -439,6 +446,7 @@ def to_sampling_params(self,
# completion-extra-params
add_special_tokens=self.add_special_tokens,
stream_interval=self.stream_interval,
stream_interval_ms=self.stream_interval_ms,

# TODO: migrate to use logprobs and prompt_logprobs
_return_log_probs=bool(self.logprobs),
Expand Down Expand Up @@ -761,6 +769,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
"The iteration interval to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)
stream_interval_ms: Optional[int] = Field(
default=None,
gt=0,
description=(
"The time interval in milliseconds to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)

# doc: end-chat-completion-extra-params

Expand Down Expand Up @@ -808,6 +823,7 @@ def to_sampling_params(self,
# chat-completion-extra-params
add_special_tokens=self.add_special_tokens,
stream_interval=self.stream_interval,
stream_interval_ms=self.stream_interval_ms,

# TODO: migrate to use logprobs and prompt_logprobs
_return_log_probs=bool(self.logprobs),
Expand Down Expand Up @@ -926,6 +942,13 @@ class ResponsesRequest(OpenAIBaseModel):
"The iteration interval to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)
stream_interval_ms: Optional[int] = Field(
default=None,
gt=0,
description=(
"The time interval in milliseconds to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)

request_id: str = Field(
default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}",
Expand Down Expand Up @@ -972,6 +995,7 @@ def to_sampling_params(
stop_token_ids=stop_token_ids,
guided_decoding=guided_decoding,
stream_interval=self.stream_interval,
stream_interval_ms=self.stream_interval_ms,
)

@model_validator(mode="before")
Expand Down
3 changes: 3 additions & 0 deletions tests/unittest/api_stability/references_committed/llm.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ methods:
stream_interval:
annotation: int
default: 1
stream_interval_ms:
annotation: int
default: 0

kwargs:
annotation: Any
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ methods:
stream_interval:
annotation: Optional[int]
default: null
stream_interval_ms:
annotation: Optional[int]
default: null
# Returning controls
logprobs:
annotation: Optional[int]
Expand Down
89 changes: 89 additions & 0 deletions tests/unittest/llmapi/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,95 @@ def test_stream_interval_openai_protocol():
ResponsesRequest(model="m", input="hi", stream_interval=-1)


def test_stream_interval_ms_validation():
    """stream_interval_ms accepts positive ints or None and rejects non-positive values."""
    # Positive values and None are stored as-is.
    for accepted in (100, 1, None):
        params = SamplingParams(stream_interval_ms=accepted)
        assert params.stream_interval_ms == accepted

    # Token-based and time-based intervals may coexist on one request.
    params = SamplingParams(stream_interval=5, stream_interval_ms=100)
    assert params.stream_interval == 5
    assert params.stream_interval_ms == 100

    # Zero and negative values are rejected by validation.
    for rejected in (0, -1):
        with pytest.raises(ValueError, match="stream_interval_ms"):
            SamplingParams(stream_interval_ms=rejected)


def test_stream_interval_ms_openai_protocol():
    """stream_interval_ms flows from OpenAI protocol requests into SamplingParams."""
    from pydantic import ValidationError

    from tensorrt_llm.serve.openai_protocol import (
        ChatCompletionRequest,
        CompletionRequest,
        ResponsesRequest,
    )

    user_messages = [{"role": "user", "content": "hi"}]

    # Valid values pass through to_sampling_params() unchanged.
    sp = CompletionRequest(model="m", prompt="hi",
                           stream_interval_ms=100).to_sampling_params()
    assert sp.stream_interval_ms == 100
    sp = ChatCompletionRequest(model="m",
                               messages=user_messages,
                               stream_interval_ms=200).to_sampling_params()
    assert sp.stream_interval_ms == 200
    sp = ResponsesRequest(model="m", input="hi",
                          stream_interval_ms=50).to_sampling_params()
    assert sp.stream_interval_ms == 50

    # Omitting the field leaves it unset on the sampling params.
    sp = CompletionRequest(model="m", prompt="hi").to_sampling_params()
    assert sp.stream_interval_ms is None
    sp = ResponsesRequest(model="m", input="hi").to_sampling_params()
    assert sp.stream_interval_ms is None

    # Non-positive values are rejected at request construction time.
    for bad in (-1, 0):
        with pytest.raises(ValidationError, match="stream_interval_ms"):
            CompletionRequest(model="m", prompt="hi", stream_interval_ms=bad)
    with pytest.raises(ValidationError, match="stream_interval_ms"):
        ChatCompletionRequest(model="m",
                              messages=user_messages,
                              stream_interval_ms=0)
    with pytest.raises(ValidationError, match="stream_interval_ms"):
        ResponsesRequest(model="m", input="hi", stream_interval_ms=-1)

    # Token-based and time-based intervals may be supplied together.
    sp = CompletionRequest(model="m",
                           prompt="hi",
                           stream_interval=5,
                           stream_interval_ms=100).to_sampling_params()
    assert sp.stream_interval == 5
    assert sp.stream_interval_ms == 100


@pytest.mark.skipif(torch.cuda.device_count() < 2 or WORLD_SIZE != 2,
reason="Must run on 2 MPI ranks with at least 2 GPUs")
def test_sync_generation_tp_main_node_only(llama_7b_tp2_path: Path):
Expand Down
Loading