Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ def __init__(
self.py_num_accepted_draft_tokens_indices = []
self.py_rewind_draft_token_separate_adjustment = 0
self.py_decoding_iter = 0
self.py_stream_interval = None
self.py_last_stream_emit_time = None
self.is_attention_dp_dummy = False
self.is_cuda_graph_dummy = False
Expand Down Expand Up @@ -927,6 +928,8 @@ def executor_request_to_llm_request(
LogprobMode.RAW),
)

llm_request.py_stream_interval = getattr(executor_request,
"py_stream_interval", None)
llm_request.py_disaggregated_params = getattr(executor_request,
"py_disaggregated_params",
None)
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3214,7 +3214,7 @@ def _handle_responses(self):
should_emit = (
request.py_decoding_iter == 1
or request.is_finished
or request.py_decoding_iter % self.stream_interval == 0
or request.py_decoding_iter % (request.py_stream_interval or self.stream_interval) == 0
or (self.stream_emit_interval_ms > 0
and request.py_last_stream_emit_time
and (now - request.py_last_stream_emit_time) * 1000 >
Expand Down
1 change: 1 addition & 0 deletions tensorrt_llm/executor/base_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ def _deduce_max_tokens(request: GenerationRequest,
executor_request.py_num_logprobs = request.sampling_params.logprobs
executor_request.py_lora_path = py_lora_path
executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode
executor_request.py_stream_interval = request.sampling_params.stream_interval

# here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver
if self._is_pytorch_backend and request.disaggregated_params is not None:
Expand Down
2 changes: 1 addition & 1 deletion tensorrt_llm/executor/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def _handle_response(self, response: "GenerationExecutor.Response"):
prev_text=beam_output.text,
states=beam_output._incremental_states,
flush=self._done,
stream_interval=self.sampling_params._stream_interval,
stream_interval=self.sampling_params.stream_interval,
**kwargs)
else:
beam_output.text = self.tokenizer.decode(
Expand Down
6 changes: 3 additions & 3 deletions tensorrt_llm/llmapi/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,9 +786,9 @@ def _prepare_sampling_params(
sampling_params.return_generation_logits = True
sampling_params._generation_logits_auto_enabled = True

if sampling_params._stream_interval is None:
sampling_params._stream_interval = getattr(self.args,
"stream_interval", 1)
if sampling_params.stream_interval is None:
sampling_params.stream_interval = getattr(self.args,
"stream_interval", 1)
sampling_params.return_perf_metrics = sampling_params.return_perf_metrics or self.args.return_perf_metrics
return sampling_params

Expand Down
8 changes: 5 additions & 3 deletions tensorrt_llm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,7 @@ class SamplingParams:
truncate_prompt_tokens: Optional[int] = None
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
# Currently, _stream_interval is only used to pass llm.args.stream_interval to tokenizer.
# TODO: make this a per-request parameter.
_stream_interval: Optional[int] = field(default=None, init=False, repr=False)
stream_interval: Optional[int] = None

def __post_init__(self):
if self.pad_id is None:
Expand All @@ -317,6 +315,10 @@ def _validate(self):
For instance, while the greedy decoding with n > 1 is capable in the
Executor class of C++ runtime, the LLM API disallows such combination.
"""
if self.stream_interval is not None and self.stream_interval <= 0:
raise ValueError(
f"require stream_interval > 0, got stream_interval={self.stream_interval}"
)
if self.top_p is not None and (self.top_p < 0 or self.top_p > 1):
raise ValueError(f"require 0 <= top_p <= 1, got top_p={self.top_p}")
if self.top_k is not None and self.top_k < 0:
Expand Down
24 changes: 24 additions & 0 deletions tensorrt_llm/serve/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,13 @@ class CompletionRequest(OpenAIBaseModel):
default=None,
description=("Parameters for disaggregated serving"),
)
stream_interval: Optional[int] = Field(
default=None,
gt=0,
description=(
"The iteration interval to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)

# doc: end-completion-extra-params

Expand Down Expand Up @@ -431,6 +438,7 @@ def to_sampling_params(self,

# completion-extra-params
add_special_tokens=self.add_special_tokens,
stream_interval=self.stream_interval,

# TODO: migrate to use logprobs and prompt_logprobs
_return_log_probs=bool(self.logprobs),
Expand Down Expand Up @@ -746,6 +754,13 @@ class ChatCompletionRequest(OpenAIBaseModel):
("If specified, KV cache will be salted with the provided string "
"to limit the kv cache reuse on with the requests having the same string."
))
stream_interval: Optional[int] = Field(
default=None,
gt=0,
description=(
"The iteration interval to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)

# doc: end-chat-completion-extra-params

Expand Down Expand Up @@ -792,6 +807,7 @@ def to_sampling_params(self,

# chat-completion-extra-params
add_special_tokens=self.add_special_tokens,
stream_interval=self.stream_interval,

# TODO: migrate to use logprobs and prompt_logprobs
_return_log_probs=bool(self.logprobs),
Expand Down Expand Up @@ -903,6 +919,13 @@ class ResponsesRequest(OpenAIBaseModel):
top_p: Optional[float] = None
truncation: Optional[Literal["auto", "disabled"]] = "disabled"
user: Optional[str] = None
stream_interval: Optional[int] = Field(
default=None,
gt=0,
description=(
"The iteration interval to create responses under the streaming mode. "
"If not set, the engine-level default is used."),
)

request_id: str = Field(
default_factory=lambda: f"resp_{str(uuid.uuid4().hex)}",
Expand Down Expand Up @@ -948,6 +971,7 @@ def to_sampling_params(
logprobs=self.top_logprobs,
stop_token_ids=stop_token_ids,
guided_decoding=guided_decoding,
stream_interval=self.stream_interval,
)

@model_validator(mode="before")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ methods:
spaces_between_special_tokens:
annotation: bool
default: true
stream_interval:
annotation: Optional[int]
default: null
# Returning controls
logprobs:
annotation: Optional[int]
Expand Down
75 changes: 75 additions & 0 deletions tests/unittest/llmapi/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,81 @@ def test_invalid_sampling_params():
SamplingParams(max_tokens=4, n=4, best_of=3, use_beam_search=True)


def test_stream_interval_validation():
    """SamplingParams accepts positive or unset stream_interval and rejects the rest."""
    # Positive integers and None are all legal and stored verbatim.
    for accepted in (1, 10, None):
        params = SamplingParams(stream_interval=accepted)
        assert params.stream_interval == accepted

    # Zero and negative intervals must be rejected during validation.
    for rejected in (0, -1):
        with pytest.raises(ValueError, match="stream_interval"):
            SamplingParams(stream_interval=rejected)


def test_stream_interval_openai_protocol():
    """stream_interval flows from each OpenAI request type into SamplingParams."""
    from pydantic import ValidationError

    from tensorrt_llm.serve.openai_protocol import (
        ChatCompletionRequest,
        CompletionRequest,
        ResponsesRequest,
    )

    chat_messages = [{"role": "user", "content": "hi"}]

    # CompletionRequest: a valid value is forwarded to SamplingParams.
    completion = CompletionRequest(model="m", prompt="hi", stream_interval=5)
    assert completion.to_sampling_params().stream_interval == 5

    # CompletionRequest: omitting the field leaves SamplingParams unset.
    completion = CompletionRequest(model="m", prompt="hi")
    assert completion.to_sampling_params().stream_interval is None

    # CompletionRequest: non-positive values fail pydantic validation (gt=0).
    for bad in (-1, 0):
        with pytest.raises(ValidationError, match="stream_interval"):
            CompletionRequest(model="m", prompt="hi", stream_interval=bad)

    # ChatCompletionRequest: a valid value is forwarded to SamplingParams.
    chat = ChatCompletionRequest(model="m",
                                 messages=chat_messages,
                                 stream_interval=10)
    assert chat.to_sampling_params().stream_interval == 10

    # ChatCompletionRequest: zero is rejected at construction.
    with pytest.raises(ValidationError, match="stream_interval"):
        ChatCompletionRequest(model="m",
                              messages=chat_messages,
                              stream_interval=0)

    # ResponsesRequest: a valid value is forwarded to SamplingParams.
    responses = ResponsesRequest(model="m", input="hi", stream_interval=5)
    assert responses.to_sampling_params().stream_interval == 5

    # ResponsesRequest: omitting the field leaves SamplingParams unset.
    responses = ResponsesRequest(model="m", input="hi")
    assert responses.to_sampling_params().stream_interval is None

    # ResponsesRequest: a negative value is rejected at construction.
    with pytest.raises(ValidationError, match="stream_interval"):
        ResponsesRequest(model="m", input="hi", stream_interval=-1)


@pytest.mark.skipif(torch.cuda.device_count() < 2 or WORLD_SIZE != 2,
reason="Must run on 2 MPI ranks with at least 2 GPUs")
def test_sync_generation_tp_main_node_only(llama_7b_tp2_path: Path):
Expand Down
Loading