Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions tensorrt_llm/serve/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
model: str
choices: List[CompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
prompt_token_ids: Optional[List[int]] = None

def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat],
Expand Down Expand Up @@ -361,6 +362,7 @@ class CompletionRequest(OpenAIBaseModel):
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
return_context_logits: bool = False
detokenize: bool = True
return_token_ids: bool = False
# doc: end-completion-sampling-params

# doc: begin-completion-extra-params
Expand Down Expand Up @@ -608,6 +610,7 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
model: str
choices: List[ChatCompletionResponseStreamChoice]
usage: Optional[UsageInfo] = Field(default=None)
prompt_token_ids: Optional[List[int]] = None


class FunctionDefinition(OpenAIBaseModel):
Expand Down Expand Up @@ -684,6 +687,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
lora_request: Optional[LoRARequest] = None
return_token_ids: bool = False
# doc: end-chat-completion-sampling-params

# doc: begin-chat-completion-extra-params
Expand Down
12 changes: 12 additions & 0 deletions tensorrt_llm/serve/openai_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,7 @@ def get_role() -> str:
async def chat_stream_generator(
promise: RequestOutput, postproc_params: PostprocParams) -> AsyncGenerator[str, None]:
nonlocal did_complete
return_token_ids_on_first_chunk = request.return_token_ids
if not self.postproc_worker_enabled:
post_processor, args = postproc_params.post_processor, postproc_params.postproc_args
try:
Expand All @@ -1036,6 +1037,9 @@ async def chat_stream_generator(
prom_metrics["request_completed_total"] += 1
prom_metrics[f"request_success_total{{finished_reason=\"{choice.finish_reason}\""] += 1

if return_token_ids_on_first_chunk:
pp_res.prompt_token_ids = promise.prompt_token_ids
return_token_ids_on_first_chunk = False
pp_res_json = pp_res.model_dump_json(exclude_unset=True)
yield f"data: {pp_res_json}\n\n"
yield f"data: [DONE]\n\n"
Expand Down Expand Up @@ -1137,6 +1141,8 @@ async def chat_stream_generator(
response = await self._create_chat_response(promise, postproc_params, disaggregated_params)
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
response.prompt_token_ids = promise.prompt_token_ids
if request.return_token_ids:
response.prompt_token_ids = promise.prompt_token_ids
raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()
await self._extract_metrics(promise, raw_request)
return JSONResponse(content=response.model_dump())
Expand Down Expand Up @@ -1260,6 +1266,8 @@ async def completion_response(promise: RequestOutput,
if disaggregated_params and disaggregated_params.request_type and disaggregated_params.request_type == "context_only":
# Include prompt token ids for context-only requests
pp_result.prompt_token_ids = response.prompt_token_ids
if request.return_token_ids:
pp_result.prompt_token_ids = response.prompt_token_ids
raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds()
await self._extract_metrics(response, raw_request)

Expand Down Expand Up @@ -1302,6 +1310,7 @@ def merge_completion_responses(responses: List[CompletionResponse]) -> Completio

async def completion_generator(promise: RequestOutput, params: Optional[PostprocParams]):
did_complete = False
return_token_ids_on_first_chunk = request.return_token_ids
try:
async for output in promise:
if not self.postproc_worker_enabled:
Expand All @@ -1316,6 +1325,9 @@ async def completion_generator(promise: RequestOutput, params: Optional[Postproc
did_complete = True
prom_metrics["request_completed_total"] += 1
prom_metrics[f"request_success_total{{finished_reason=\"{choice.finish_reason}\""] += 1
if return_token_ids_on_first_chunk:
pp_res.prompt_token_ids = promise.prompt_token_ids
return_token_ids_on_first_chunk = False
pp_res_json = pp_res.model_dump_json(exclude_unset=True)
yield f"data: {pp_res_json}\n\n"
finally:
Expand Down
Loading