From 12da08eb31cd65924c26cbdd0e9eeeaebc1c17bc Mon Sep 17 00:00:00 2001
From: Quan Truong
Date: Thu, 2 Apr 2026 21:24:56 +0000
Subject: [PATCH 1/2] Add return tokens id option

---
 tensorrt_llm/serve/openai_protocol.py |  4 ++++
 tensorrt_llm/serve/openai_server.py   | 12 ++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 6ec157e72efa..f0093a32137a 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -192,6 +192,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
     model: str
     choices: List[CompletionResponseStreamChoice]
     usage: Optional[UsageInfo] = Field(default=None)
+    prompt_token_ids: Optional[List[int]] = None


 def _response_format_to_guided_decoding_params(
         response_format: Optional[ResponseFormat],
@@ -361,6 +362,7 @@ class CompletionRequest(OpenAIBaseModel):
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     return_context_logits: bool = False
     detokenize: bool = True
+    return_token_ids: bool = False
     # doc: end-completion-sampling-params

     # doc: begin-completion-extra-params
@@ -608,6 +610,7 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
     model: str
     choices: List[ChatCompletionResponseStreamChoice]
     usage: Optional[UsageInfo] = Field(default=None)
+    prompt_token_ids: Optional[List[int]] = None


 class FunctionDefinition(OpenAIBaseModel):
@@ -684,6 +687,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     lora_request: Optional[LoRARequest] = None
+    return_token_ids: bool = False
     # doc: end-chat-completion-sampling-params

     # doc: begin-chat-completion-extra-params
diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index ffa6b2febbae..d933a13a687d 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -1023,6 +1023,7 @@ def get_role() -> str:
         async def chat_stream_generator(
                 promise: RequestOutput, postproc_params: PostprocParams) -> AsyncGenerator[str, None]:
             nonlocal did_complete
+            is_first_chunk = request.return_token_ids
             if not self.postproc_worker_enabled:
                 post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
             try:
@@ -1036,6 +1037,9 @@ async def chat_stream_generator(
                             prom_metrics["request_completed_total"] += 1
                             prom_metrics[f"request_success_total{{finished_reason=\"{choice.finish_reason}\""] += 1
+                    if is_first_chunk:
+                        pp_res.prompt_token_ids = promise.prompt_token_ids
+                        is_first_chunk = False
                     pp_res_json = pp_res.model_dump_json(exclude_unset=True)
                     yield f"data: {pp_res_json}\n\n"
                 yield f"data: [DONE]\n\n"
@@ -1137,6 +1141,8 @@ async def chat_stream_generator(
             response = await self._create_chat_response(promise, postproc_params, disaggregated_params)
             if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
                 response.prompt_token_ids = promise.prompt_token_ids
+            if request.return_token_ids:
+                response.prompt_token_ids = promise.prompt_token_ids
             raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()
             await self._extract_metrics(promise, raw_request)
             return JSONResponse(content=response.model_dump())
@@ -1260,6 +1266,8 @@ async def completion_response(promise: RequestOutput,
             if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
                 # Include prompt token ids for context-only requests
                 pp_result.prompt_token_ids = response.prompt_token_ids
+            if request.return_token_ids:
+                pp_result.prompt_token_ids = response.prompt_token_ids
             raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()
             await self._extract_metrics(response, raw_request)
@@ -1302,6 +1310,7 @@ def merge_completion_responses(responses: List[CompletionResponse]) -> Completio
         async def completion_generator(promise: RequestOutput, params: Optional[PostprocParams]):
             did_complete = False
+            is_first_chunk = request.return_token_ids
             try:
                 async for output in promise:
                     if not self.postproc_worker_enabled:
@@ -1316,6 +1325,9 @@ async def completion_generator(promise: RequestOutput, params: Optional[Postproc
                             did_complete = True
                             prom_metrics["request_completed_total"] += 1
                             prom_metrics[f"request_success_total{{finished_reason=\"{choice.finish_reason}\""] += 1
+                    if is_first_chunk:
+                        pp_res.prompt_token_ids = promise.prompt_token_ids
+                        is_first_chunk = False
                     pp_res_json = pp_res.model_dump_json(exclude_unset=True)
                     yield f"data: {pp_res_json}\n\n"
             finally:
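Reviewer note (not part of the patch): a minimal, self-contained sketch of the first-chunk behaviour PATCH 1/2 introduces, i.e. prompt token IDs are attached only to the first streamed chunk and the flag is cleared afterwards. The Chunk model and the literal token IDs below are hypothetical stand-ins; the real server assigns the field on CompletionStreamResponse / ChatCompletionStreamResponse objects.

# Illustrative sketch only: mirrors the "emit prompt_token_ids on the first
# chunk only" pattern from PATCH 1/2. Chunk and the sample data are made up.
from typing import Iterator, List, Optional

from pydantic import BaseModel


class Chunk(BaseModel):
    text: str
    prompt_token_ids: Optional[List[int]] = None


def stream_chunks(texts: List[str], prompt_token_ids: List[int],
                  return_token_ids: bool) -> Iterator[str]:
    # The flag starts as request.return_token_ids and is flipped off after the
    # first chunk has carried the IDs, just like is_first_chunk in the patch.
    attach_ids = return_token_ids
    for text in texts:
        chunk = Chunk(text=text)
        if attach_ids:
            chunk.prompt_token_ids = prompt_token_ids
            attach_ids = False
        # exclude_unset keeps prompt_token_ids out of every later chunk
        yield f"data: {chunk.model_dump_json(exclude_unset=True)}\n\n"
    yield "data: [DONE]\n\n"


if __name__ == "__main__":
    for sse_line in stream_chunks([" Paris", "."], [1, 450, 7483, 310], True):
        print(sse_line, end="")

The sketch also shows why the field is declared with a default of None in openai_protocol.py: serialization with exclude_unset=True only emits it on the chunk where the server explicitly assigned it.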
From 97dec15ebe9dcb15d48904a4a0c112b179953dc4 Mon Sep 17 00:00:00 2001
From: Quan Truong
Date: Thu, 2 Apr 2026 21:46:14 +0000
Subject: [PATCH 2/2] Change naming

---
 tensorrt_llm/serve/openai_server.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py
index d933a13a687d..55caca39cf6a 100644
--- a/tensorrt_llm/serve/openai_server.py
+++ b/tensorrt_llm/serve/openai_server.py
@@ -1023,7 +1023,7 @@ def get_role() -> str:
         async def chat_stream_generator(
                 promise: RequestOutput, postproc_params: PostprocParams) -> AsyncGenerator[str, None]:
             nonlocal did_complete
-            is_first_chunk = request.return_token_ids
+            return_token_ids_on_first_chunk = request.return_token_ids
             if not self.postproc_worker_enabled:
                 post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
             try:
@@ -1037,9 +1037,9 @@ async def chat_stream_generator(
                             prom_metrics["request_completed_total"] += 1
                             prom_metrics[f"request_success_total{{finished_reason=\"{choice.finish_reason}\""] += 1
-                    if is_first_chunk:
+                    if return_token_ids_on_first_chunk:
                         pp_res.prompt_token_ids = promise.prompt_token_ids
-                        is_first_chunk = False
+                        return_token_ids_on_first_chunk = False
                     pp_res_json = pp_res.model_dump_json(exclude_unset=True)
                     yield f"data: {pp_res_json}\n\n"
                 yield f"data: [DONE]\n\n"
@@ -1310,7 +1310,7 @@ def merge_completion_responses(responses: List[CompletionResponse]) -> Completio
         async def completion_generator(promise: RequestOutput, params: Optional[PostprocParams]):
             did_complete = False
-            is_first_chunk = request.return_token_ids
+            return_token_ids_on_first_chunk = request.return_token_ids
             try:
                 async for output in promise:
                     if not self.postproc_worker_enabled:
@@ -1325,9 +1325,9 @@ async def completion_generator(promise: RequestOutput, params: Optional[Postproc
                             did_complete = True
                             prom_metrics["request_completed_total"] += 1
                             prom_metrics[f"request_success_total{{finished_reason=\"{choice.finish_reason}\""] += 1
-                    if is_first_chunk:
+                    if return_token_ids_on_first_chunk:
                         pp_res.prompt_token_ids = promise.prompt_token_ids
-                        is_first_chunk = False
+                        return_token_ids_on_first_chunk = False
                     pp_res_json = pp_res.model_dump_json(exclude_unset=True)
                     yield f"data: {pp_res_json}\n\n"
             finally:
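Reviewer note (not part of the patch): a hedged client-side sketch for exercising the new flag end to end. It assumes an OpenAI-compatible TensorRT-LLM server is already running at http://localhost:8000 with a model registered as "my-model"; the host, port, and model name are assumptions, only the return_token_ids field and the prompt_token_ids field on the first chunk come from the patch itself.

# Hypothetical smoke test for return_token_ids against a running server.
import json

import requests

payload = {
    "model": "my-model",                  # assumed model name
    "prompt": "The capital of France is",
    "max_tokens": 8,
    "stream": True,
    "return_token_ids": True,             # flag introduced by PATCH 1/2
}

with requests.post("http://localhost:8000/v1/completions",
                   json=payload, stream=True) as resp:
    resp.raise_for_status()
    for raw in resp.iter_lines(decode_unicode=True):
        if not raw or not raw.startswith("data: "):
            continue
        data = raw[len("data: "):]
        if data.strip() == "[DONE]":
            break
        chunk = json.loads(data)
        # Per the patch, only the first chunk is expected to carry the field.
        if "prompt_token_ids" in chunk:
            print("prompt_token_ids:", chunk["prompt_token_ids"])
        print(chunk["choices"][0]["text"], end="", flush=True)
print()

For non-streaming requests the same flag should make prompt_token_ids appear once on the top-level completion or chat-completion response instead of on a stream chunk.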