class GrammarMatcherWrapper(GrammarMatcher):
    """Delay grammar enforcement until the model finishes "thinking".

    Tokens generated while the model is still in its thinking phase are
    accepted unconditionally; only after the end-of-thinking token has been
    seen are accept/rollback/bitmask calls forwarded to the wrapped matcher.
    """

    def __init__(self, matcher: GrammarMatcher,
                 guided_decoding_params: GuidedDecodingParams):
        super().__init__()
        self._matcher = matcher
        self._guided_decoding_params = guided_decoding_params
        # NOTE(review): hard-coded, model-specific token id — should come
        # from GuidedDecodingParams instead of a magic constant.
        self._end_thinking_token_id = 128799
        self._is_thinking = True
        # Tokens fed to the wrapped matcher since thinking ended; this bounds
        # how far the matcher may be rolled back.
        self._steps_after_thinking = 0

    def accept_token(self, token_id: int) -> bool:
        if self._is_thinking:
            if token_id == self._end_thinking_token_id:
                self._is_thinking = False
                # Bug fix: reset the counter when thinking ends so a stale
                # count cannot inflate the rollback budget.
                self._steps_after_thinking = 0
            # Bug fix: while thinking, accept every token WITHOUT consulting
            # the grammar; the original fell through and fed thinking-phase
            # tokens to the wrapped matcher.
            return True
        self._steps_after_thinking += 1
        return self._matcher.accept_token(token_id)

    def rollback(self, num_tokens: int) -> None:
        # Bug fix: the original returned early when *not* thinking, i.e. it
        # only rolled back during the thinking phase — exactly inverted.
        # While thinking, the matcher has seen nothing, so there is nothing
        # to roll back.
        if self._is_thinking:
            return
        # The matcher only saw tokens produced after thinking ended, so it
        # cannot be rolled back further than that.
        num_tokens_to_rollback = min(num_tokens, self._steps_after_thinking)
        if num_tokens > self._steps_after_thinking:
            # Rolling back across the end-of-thinking token re-enters the
            # thinking phase (the original set this flag to False here).
            self._is_thinking = True
        self._matcher.rollback(num_tokens_to_rollback)
        # Bug fix: keep the rollback budget in sync with the matcher state.
        self._steps_after_thinking -= num_tokens_to_rollback

    def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor,
                                index: int) -> None:
        self._matcher.fill_next_token_bitmask(next_token_bitmask, index)

    def is_terminated(self) -> bool:
        return self._matcher.is_terminated()
class GrammarMatcherFactoryWrapper(GrammarMatcherFactory):
    """Decorating factory: every matcher produced by the wrapped factory is
    returned wrapped in a ``GrammarMatcherWrapper``."""

    def __init__(self, factory: GrammarMatcherFactory):
        super().__init__()
        self._factory = factory

    def create(self,
               guided_decoding_params: GuidedDecodingParams) -> GrammarMatcher:
        # Delegate creation, then decorate the result.
        inner = self._factory.create(guided_decoding_params)
        return GrammarMatcherWrapper(inner, guided_decoding_params)
is_terminated(self) -> bool: pass + @abstractmethod + def is_thinking(self) -> bool: + pass + class GrammarMatcherFactory(ABC): @@ -56,6 +60,9 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: return self._matcher.is_terminated() + + def is_thinking(self) -> bool: + return False class GrammarMatcherWrapper(GrammarMatcher): @@ -70,6 +77,7 @@ def __init__(self, matcher: GrammarMatcher, guided_decoding_params: GuidedDecodi def accept_token(self, token_id: int) -> bool: if token_id == self._end_thinking_token_id and self._is_thinking: self._is_thinking = False + self._steps_after_thinking = 0 return True self._steps_after_thinking += 1 return self._matcher.accept_token(token_id) @@ -77,7 +85,7 @@ def accept_token(self, token_id: int) -> bool: def rollback(self, num_tokens: int) -> None: if not self._is_thinking: return - # cannot rollback more than steps after thinking + # cannot rollback more than steps_after_thinking num_tokens_to_rollback = min(num_tokens, self._steps_after_thinking) self._matcher.rollback(num_tokens_to_rollback) if num_tokens > self._steps_after_thinking: @@ -89,6 +97,9 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: return self._matcher.is_terminated() + + def is_thinking(self) -> bool: + return self._is_thinking class GrammarMatcherFactoryWrapper(GrammarMatcherFactory): def __init__(self, factory: GrammarMatcherFactory): @@ -209,6 +220,9 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: return self._is_terminated + def is_thinking(self) -> bool: + return False + def _check_err(self) -> None: if self._matcher.is_error(): raise ValueError( diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py index 0b361c95cd38..309724927b03 100644 --- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py +++ 
b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py @@ -261,6 +261,9 @@ def _build(self, requests: GuidedRequests) -> None: self.num_advanced_tokens[slot] += 1 if matcher.is_terminated(): break + if matcher.is_thinking(): + # don't apply bitmask when the matcher is thinking + continue matcher.fill_next_token_bitmask(self.bitmask_host, offset + i) self.token_mask_host[offset + i] = 1 From 3680d59bab9b11993c8a74f9c8fd7c015e475d8d Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 15 Sep 2025 10:11:21 -0700 Subject: [PATCH 03/15] fix --- tensorrt_llm/_torch/pyexecutor/grammar_matcher.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 074b31264a3e..2729968fa054 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -75,10 +75,14 @@ def __init__(self, matcher: GrammarMatcher, guided_decoding_params: GuidedDecodi self._steps_after_thinking = 0 def accept_token(self, token_id: int) -> bool: - if token_id == self._end_thinking_token_id and self._is_thinking: - self._is_thinking = False - self._steps_after_thinking = 0 - return True + print(token_id) + if self._is_thinking: + if token_id == self._end_thinking_token_id: + self._is_thinking = False + self._steps_after_thinking = 0 + return True + else: + return True self._steps_after_thinking += 1 return self._matcher.accept_token(token_id) From 9b217595d6c28a3c5f654615611a977c672ac666 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 15 Sep 2025 10:13:39 -0700 Subject: [PATCH 04/15] fix --- tensorrt_llm/_torch/pyexecutor/grammar_matcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 2729968fa054..c6e4839e78fd 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ 
b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -87,13 +87,13 @@ def accept_token(self, token_id: int) -> bool: return self._matcher.accept_token(token_id) def rollback(self, num_tokens: int) -> None: - if not self._is_thinking: + if self._is_thinking: return # cannot rollback more than steps_after_thinking num_tokens_to_rollback = min(num_tokens, self._steps_after_thinking) self._matcher.rollback(num_tokens_to_rollback) if num_tokens > self._steps_after_thinking: - self._is_thinking = False + self._is_thinking = True def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, index: int) -> None: From bc1e299857e0ec6ad2f4dc71f52faea294c51b55 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 15 Sep 2025 10:19:55 -0700 Subject: [PATCH 05/15] fix --- tensorrt_llm/_torch/pyexecutor/guided_decoder.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py index 309724927b03..86385ccf5988 100644 --- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py +++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py @@ -250,9 +250,10 @@ def _build(self, requests: GuidedRequests) -> None: self.num_advanced_tokens[slot] += 1 if not matcher.is_terminated(): - matcher.fill_next_token_bitmask(self.bitmask_host, offset) - self.token_mask_host[offset] = 1 - self.num_guided_tokens[slot] += 1 + if not matcher.is_thinking(): + matcher.fill_next_token_bitmask(self.bitmask_host, offset) + self.token_mask_host[offset] = 1 + self.num_guided_tokens[slot] += 1 # Process draft tokens for i, tid in enumerate(req.draft_tokens, 1): accepted = matcher.accept_token(tid) From 75f516be9431c316eded17b196adc1e29021fb91 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 15 Sep 2025 13:46:00 -0700 Subject: [PATCH 06/15] control thinking from api --- cpp/include/tensorrt_llm/executor/executor.h | 4 ++- .../executor/guidedDecodingParams.cpp | 10 ++++-- 
cpp/tensorrt_llm/executor/serialization.cpp | 5 ++- .../nanobind/executor/request.cpp | 7 ++-- .../_torch/pyexecutor/grammar_matcher.py | 35 ++++++++++--------- tensorrt_llm/sampling_params.py | 14 +++++--- tensorrt_llm/serve/openai_protocol.py | 6 ++-- 7 files changed, 51 insertions(+), 30 deletions(-) diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index d9b115bf8fe9..75836b0ed64c 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -509,11 +509,12 @@ class GuidedDecodingParams kSTRUCTURAL_TAG = 4, }; - explicit GuidedDecodingParams(GuideType guideType, std::optional guide = std::nullopt); + explicit GuidedDecodingParams(GuideType guideType, std::optional guide = std::nullopt, std::optional thinkingEndTokenId = std::nullopt); bool operator==(GuidedDecodingParams const& other) const; [[nodiscard]] GuideType getGuideType() const; [[nodiscard]] std::optional getGuide() const; + [[nodiscard]] std::optional getThinkingEndTokenId() const; private: friend class Serialization; @@ -523,6 +524,7 @@ class GuidedDecodingParams /// @brief The detailed guide string. It could be a json schema, a regular expression or a EBNF grammar depending on /// mGuideType. 
std::optional mGuide; + std::optional mThinkingEndTokenId; }; using RetentionPriority = SizeType32; diff --git a/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp b/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp index 83b99e0b4125..0429572d0898 100644 --- a/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp +++ b/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp @@ -22,9 +22,10 @@ namespace tensorrt_llm::executor { -GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional guide) +GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional guide, std::optional thinkingEndTokenId) : mGuideType{guideType} , mGuide{std::move(guide)} + , mThinkingEndTokenId{thinkingEndTokenId} { TLLM_CHECK_WITH_INFO(mGuideType == GuideType::kJSON || mGuide.has_value(), "The guide string must be provided unless using GuideType::kJSON."); @@ -32,7 +33,7 @@ GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional GuidedDecodingParams::getGuide() const return mGuide; } +std::optional(is); auto guide = su::deserializeWithGetterType(is); - return GuidedDecodingParams(guideType, guide); + auto thinkingEndTokenId = su::deserializeWithGetterType(is); + return GuidedDecodingParams(guideType, guide, thinkingEndTokenId); } void Serialization::serialize(GuidedDecodingParams const& guidedDecodingParams, std::ostream& os) { su::serialize(guidedDecodingParams.getGuideType(), os); su::serialize(guidedDecodingParams.getGuide(), os); + su::serialize(guidedDecodingParams.getThinkingEndTokenId(), os); } size_t Serialization::serializedSize(GuidedDecodingParams const& guidedDecodingParams) @@ -1577,6 +1579,7 @@ size_t Serialization::serializedSize(GuidedDecodingParams const& guidedDecodingP size_t totalSize = 0; totalSize += su::serializedSize(guidedDecodingParams.getGuideType()); totalSize += su::serializedSize(guidedDecodingParams.getGuide()); + totalSize += su::serializedSize(guidedDecodingParams.getThinkingEndTokenId()); return totalSize; 
} diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index d26c8dd70e0a..016a589816e0 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -542,16 +542,16 @@ void initRequestBindings(nb::module_& m) .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG); auto guidedDecodingParamsGetstate - = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide()); }; + = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide(), self.getThinkingEndTokenId()); }; auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& self, nb::tuple const& state) { - if (state.size() != 2) + if (state.size() != 3) { throw std::runtime_error("Invalid GuidedDecodingParams state!"); } new (&self) tle::GuidedDecodingParams( - nb::cast(state[0]), nb::cast>(state[1])); + nb::cast(state[0]), nb::cast>(state[1]), nb::cast>(state[2])); }; pyGuidedDecodingParams @@ -559,6 +559,7 @@ void initRequestBindings(nb::module_& m) nb::arg("guide") = nb::none()) .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_prop_ro("guide", &tle::GuidedDecodingParams::getGuide) + .def_prop_ro("thinking_end_token_id", &tle::GuidedDecodingParams::getThinkingEndTokenId) .def("__getstate__", guidedDecodingParamsGetstate) .def("__setstate__", guidedDecodingParamsSetstate); diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index c6e4839e78fd..64d282494074 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -70,30 +70,33 @@ def __init__(self, matcher: GrammarMatcher, guided_decoding_params: GuidedDecodi super().__init__() self._matcher = matcher self._guided_decoding_params = guided_decoding_params - self._end_thinking_token_id = 128799 
- self._is_thinking = True + self._end_thinking_token_id = guided_decoding_params.end_thinking_token_id + self._is_thinking = self._end_thinking_token_id is not None self._steps_after_thinking = 0 def accept_token(self, token_id: int) -> bool: print(token_id) - if self._is_thinking: - if token_id == self._end_thinking_token_id: - self._is_thinking = False - self._steps_after_thinking = 0 - return True - else: - return True - self._steps_after_thinking += 1 + if self._end_thinking_token_id: + if self._is_thinking: + if token_id == self._end_thinking_token_id: + self._is_thinking = False + self._steps_after_thinking = 0 + return True + else: + return True + self._steps_after_thinking += 1 return self._matcher.accept_token(token_id) def rollback(self, num_tokens: int) -> None: - if self._is_thinking: - return - # cannot rollback more than steps_after_thinking - num_tokens_to_rollback = min(num_tokens, self._steps_after_thinking) + num_tokens_to_rollback = num_tokens + if self._end_thinking_token_id: + if self._is_thinking: + return + # cannot rollback more than steps_after_thinking + num_tokens_to_rollback = min(num_tokens, self._steps_after_thinking) + if num_tokens > self._steps_after_thinking: + self._is_thinking = True self._matcher.rollback(num_tokens_to_rollback) - if num_tokens > self._steps_after_thinking: - self._is_thinking = True def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, index: int) -> None: diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 38f0a07bbf0b..3e08d4783ac0 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -27,6 +27,7 @@ class GuidedDecodingParams: grammar: Optional[str] = None json_object: bool = False structural_tag: Optional[str] = None + thinking_end_token_id: Optional[int] = None def _validate(self): num_guides = 0 @@ -459,7 +460,11 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams: return None if 
self.guided_decoding.json_object: - return tllme.GuidedDecodingParams(tllme.GuidedDecodingParams.GuideType.JSON) + return tllme.GuidedDecodingParams( + tllme.GuidedDecodingParams.GuideType.JSON, + None, + self.guided_decoding.thinking_end_token_id, + ) elif self.guided_decoding.json is not None: json_schema = self.guided_decoding.json if isinstance(json_schema, BaseModel): @@ -467,20 +472,21 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams: if isinstance(json_schema, dict): json_schema = json.dumps(json_schema) return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema + tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema, self.guided_decoding.thinking_end_token_id ) elif self.guided_decoding.regex is not None: return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex + tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex, self.guided_decoding.thinking_end_token_id ) elif self.guided_decoding.grammar is not None: return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar + tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar, self.guided_decoding.thinking_end_token_id ) elif self.guided_decoding.structural_tag is not None: return tllme.GuidedDecodingParams( tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG, self.guided_decoding.structural_tag, + self.guided_decoding.thinking_end_token_id, ) else: return None diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index c391ff0e360c..85800881232d 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -100,7 +100,7 @@ class ResponseFormat(OpenAIBaseModel): schema: Optional[dict] = None structures: Optional[List[StructuralTag]] = None triggers: Optional[List[str]] = None - + thinking_end_token_id: Optional[int] = None 
class DisaggregatedParams(OpenAIBaseModel): request_type: str @@ -189,9 +189,9 @@ def _response_format_to_guided_decoding_params( raise ValueError( "The 'schema' field is required when response_format.type is 'json'." ) - return GuidedDecodingParams(json=response_format.schema) + return GuidedDecodingParams(json=response_format.schema, thinking_end_token_id=response_format.thinking_end_token_id) elif response_format.type == "json_object": - return GuidedDecodingParams(json_object=True) + return GuidedDecodingParams(json_object=True, thinking_end_token_id=response_format.thinking_end_token_id) elif response_format.type == "structural_tag": return GuidedDecodingParams( structural_tag=response_format.model_dump_json(by_alias=True, From 58dbf2f83adf113b1110dade61d4865d125bee06 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 15 Sep 2025 14:15:43 -0700 Subject: [PATCH 07/15] fix --- cpp/tensorrt_llm/executor/guidedDecodingParams.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp b/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp index 0429572d0898..0910ad5ebc25 100644 --- a/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp +++ b/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp @@ -46,7 +46,7 @@ std::optional GuidedDecodingParams::getGuide() const return mGuide; } -std::optional GuidedDecodingParams::getThinkingEndTokenId() const { return mThinkingEndTokenId; } From 8bf51a89c8f7549ca1a98de44c74568ab1b430a4 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 15 Sep 2025 15:38:51 -0700 Subject: [PATCH 08/15] update name --- cpp/include/tensorrt_llm/executor/executor.h | 6 +-- .../executor/guidedDecodingParams.cpp | 10 ++--- cpp/tensorrt_llm/executor/serialization.cpp | 8 ++-- .../nanobind/executor/request.cpp | 4 +- cpp/tensorrt_llm/pybind/executor/request.cpp | 5 ++- .../_torch/pyexecutor/grammar_matcher.py | 44 +++++++++---------- .../_torch/pyexecutor/guided_decoder.py | 12 +++-- 
tensorrt_llm/sampling_params.py | 14 +++--- tensorrt_llm/serve/openai_protocol.py | 6 +-- 9 files changed, 53 insertions(+), 56 deletions(-) diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 75836b0ed64c..51a25e5a2036 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -509,12 +509,12 @@ class GuidedDecodingParams kSTRUCTURAL_TAG = 4, }; - explicit GuidedDecodingParams(GuideType guideType, std::optional guide = std::nullopt, std::optional thinkingEndTokenId = std::nullopt); + explicit GuidedDecodingParams(GuideType guideType, std::optional guide = std::nullopt, std::optional guidanceStartTokenId = std::nullopt); bool operator==(GuidedDecodingParams const& other) const; [[nodiscard]] GuideType getGuideType() const; [[nodiscard]] std::optional getGuide() const; - [[nodiscard]] std::optional getThinkingEndTokenId() const; + [[nodiscard]] std::optional getGuidanceStartTokenId() const; private: friend class Serialization; @@ -524,7 +524,7 @@ class GuidedDecodingParams /// @brief The detailed guide string. It could be a json schema, a regular expression or a EBNF grammar depending on /// mGuideType. 
std::optional mGuide; - std::optional mThinkingEndTokenId; + std::optional mGuidanceStartTokenId; }; using RetentionPriority = SizeType32; diff --git a/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp b/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp index 0910ad5ebc25..929a9271c3f4 100644 --- a/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp +++ b/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp @@ -22,10 +22,10 @@ namespace tensorrt_llm::executor { -GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional guide, std::optional thinkingEndTokenId) +GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional guide, std::optional guidanceStartTokenId) : mGuideType{guideType} , mGuide{std::move(guide)} - , mThinkingEndTokenId{thinkingEndTokenId} + , mGuidanceStartTokenId{guidanceStartTokenId} { TLLM_CHECK_WITH_INFO(mGuideType == GuideType::kJSON || mGuide.has_value(), "The guide string must be provided unless using GuideType::kJSON."); @@ -33,7 +33,7 @@ GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional GuidedDecodingParams::getGuide() const return mGuide; } -std::optional GuidedDecodingParams::getThinkingEndTokenId() const +std::optional GuidedDecodingParams::getGuidanceStartTokenId() const { - return mThinkingEndTokenId; + return mGuidanceStartTokenId; } } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index f713687cd265..9b6d0e301a56 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -1563,15 +1563,15 @@ GuidedDecodingParams Serialization::deserializeGuidedDecodingParams(std::istream { auto guideType = su::deserializeWithGetterType(is); auto guide = su::deserializeWithGetterType(is); - auto thinkingEndTokenId = su::deserializeWithGetterType(is); - return GuidedDecodingParams(guideType, guide, thinkingEndTokenId); + auto guidanceStartTokenId = 
su::deserializeWithGetterType(is); + return GuidedDecodingParams(guideType, guide, guidanceStartTokenId); } void Serialization::serialize(GuidedDecodingParams const& guidedDecodingParams, std::ostream& os) { su::serialize(guidedDecodingParams.getGuideType(), os); su::serialize(guidedDecodingParams.getGuide(), os); - su::serialize(guidedDecodingParams.getThinkingEndTokenId(), os); + su::serialize(guidedDecodingParams.getGuidanceStartTokenId(), os); } size_t Serialization::serializedSize(GuidedDecodingParams const& guidedDecodingParams) @@ -1579,7 +1579,7 @@ size_t Serialization::serializedSize(GuidedDecodingParams const& guidedDecodingP size_t totalSize = 0; totalSize += su::serializedSize(guidedDecodingParams.getGuideType()); totalSize += su::serializedSize(guidedDecodingParams.getGuide()); - totalSize += su::serializedSize(guidedDecodingParams.getThinkingEndTokenId()); + totalSize += su::serializedSize(guidedDecodingParams.getGuidanceStartTokenId()); return totalSize; } diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index 016a589816e0..2901cd74a652 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -542,7 +542,7 @@ void initRequestBindings(nb::module_& m) .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG); auto guidedDecodingParamsGetstate - = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide(), self.getThinkingEndTokenId()); }; + = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide(), self.getGuidanceStartTokenId()); }; auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& self, nb::tuple const& state) { @@ -559,7 +559,7 @@ void initRequestBindings(nb::module_& m) nb::arg("guide") = nb::none()) .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_prop_ro("guide", 
&tle::GuidedDecodingParams::getGuide) - .def_prop_ro("thinking_end_token_id", &tle::GuidedDecodingParams::getThinkingEndTokenId) + .def_prop_ro("guidance_start_token_id", &tle::GuidedDecodingParams::getGuidanceStartTokenId) .def("__getstate__", guidedDecodingParamsGetstate) .def("__setstate__", guidedDecodingParamsSetstate); diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp index 4eb61ecde98d..19e20d5510a6 100644 --- a/cpp/tensorrt_llm/pybind/executor/request.cpp +++ b/cpp/tensorrt_llm/pybind/executor/request.cpp @@ -496,11 +496,11 @@ void initRequestBindings(pybind11::module_& m) .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG); auto guidedDecodingParamsGetstate - = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); }; + = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide(), self.getGuidanceStartTokenId()); }; auto guidedDecodingParamsSetstate = [](py::tuple state) { - if (state.size() != 2) + if (state.size() != 3) { throw std::runtime_error("Invalid GuidedDecodingParams state!"); } @@ -513,6 +513,7 @@ void initRequestBindings(pybind11::module_& m) py::arg("guide") = py::none()) .def_property_readonly("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_property_readonly("guide", &tle::GuidedDecodingParams::getGuide) + .def_property_readonly("guidance_start_token_Id", &tle::GuidedDecodingParams::getGuidanceStartTokenId) .def(py::pickle(guidedDecodingParamsGetstate, guidedDecodingParamsSetstate)); auto requestGetstate = [](tle::Request const& self) diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 64d282494074..879f23a3cfde 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -30,7 +30,7 @@ def is_terminated(self) -> bool: pass 
@abstractmethod - def is_thinking(self) -> bool: + def guidance_started(self) -> bool: pass @@ -61,8 +61,8 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: return self._matcher.is_terminated() - def is_thinking(self) -> bool: - return False + def guidance_started(self) -> bool: + return True class GrammarMatcherWrapper(GrammarMatcher): @@ -70,32 +70,32 @@ def __init__(self, matcher: GrammarMatcher, guided_decoding_params: GuidedDecodi super().__init__() self._matcher = matcher self._guided_decoding_params = guided_decoding_params - self._end_thinking_token_id = guided_decoding_params.end_thinking_token_id - self._is_thinking = self._end_thinking_token_id is not None - self._steps_after_thinking = 0 + self._guidance_start_token_id = guided_decoding_params.guidance_start_token_id + self._guidance_started = self._guidance_start_token_id is None + self._steps_after_guidance_start = 0 def accept_token(self, token_id: int) -> bool: print(token_id) - if self._end_thinking_token_id: - if self._is_thinking: - if token_id == self._end_thinking_token_id: - self._is_thinking = False - self._steps_after_thinking = 0 + if self._guidance_start_token_id: + if not self._guidance_started: + if token_id == self._guidance_start_token_id: + self._guidance_started = True + self._steps_after_guidance_start = 0 return True else: return True - self._steps_after_thinking += 1 + self._steps_after_guidance_start += 1 return self._matcher.accept_token(token_id) def rollback(self, num_tokens: int) -> None: num_tokens_to_rollback = num_tokens - if self._end_thinking_token_id: - if self._is_thinking: + if self._guidance_start_token_id: + if not self._guidance_started: return - # cannot rollback more than steps_after_thinking - num_tokens_to_rollback = min(num_tokens, self._steps_after_thinking) - if num_tokens > self._steps_after_thinking: - self._is_thinking = True + # cannot rollback more than _steps_after_guidance_start + 
num_tokens_to_rollback = min(num_tokens, self._steps_after_guidance_start) + if num_tokens > self._steps_after_guidance_start: + self._guidance_started = False self._matcher.rollback(num_tokens_to_rollback) def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, @@ -105,8 +105,8 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: return self._matcher.is_terminated() - def is_thinking(self) -> bool: - return self._is_thinking + def guidance_started(self) -> bool: + return self._guidance_started class GrammarMatcherFactoryWrapper(GrammarMatcherFactory): def __init__(self, factory: GrammarMatcherFactory): @@ -227,8 +227,8 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: return self._is_terminated - def is_thinking(self) -> bool: - return False + def guidance_started(self) -> bool: + return True def _check_err(self) -> None: if self._matcher.is_error(): diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py index 86385ccf5988..42ca1460640c 100644 --- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py +++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py @@ -250,7 +250,7 @@ def _build(self, requests: GuidedRequests) -> None: self.num_advanced_tokens[slot] += 1 if not matcher.is_terminated(): - if not matcher.is_thinking(): + if matcher.guidance_started(): matcher.fill_next_token_bitmask(self.bitmask_host, offset) self.token_mask_host[offset] = 1 self.num_guided_tokens[slot] += 1 @@ -262,13 +262,11 @@ def _build(self, requests: GuidedRequests) -> None: self.num_advanced_tokens[slot] += 1 if matcher.is_terminated(): break - if matcher.is_thinking(): - # don't apply bitmask when the matcher is thinking - continue - matcher.fill_next_token_bitmask(self.bitmask_host, + if matcher.guidance_started(): + matcher.fill_next_token_bitmask(self.bitmask_host, offset + i) - self.token_mask_host[offset + i] = 
1 - self.num_guided_tokens[slot] += 1 + self.token_mask_host[offset + i] = 1 + self.num_guided_tokens[slot] += 1 if req.is_draft: assert len(req.draft_tokens) == 0 diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 3e08d4783ac0..55de359f6ed9 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -27,7 +27,7 @@ class GuidedDecodingParams: grammar: Optional[str] = None json_object: bool = False structural_tag: Optional[str] = None - thinking_end_token_id: Optional[int] = None + guidance_start_token_id: Optional[int] = None def _validate(self): num_guides = 0 @@ -461,9 +461,7 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams: if self.guided_decoding.json_object: return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.JSON, - None, - self.guided_decoding.thinking_end_token_id, + tllme.GuidedDecodingParams.GuideType.JSON, None, self.guided_decoding.guidance_start_token_id, ) elif self.guided_decoding.json is not None: json_schema = self.guided_decoding.json @@ -472,21 +470,21 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams: if isinstance(json_schema, dict): json_schema = json.dumps(json_schema) return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema, self.guided_decoding.thinking_end_token_id + tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema, self.guided_decoding.guidance_start_token_id ) elif self.guided_decoding.regex is not None: return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex, self.guided_decoding.thinking_end_token_id + tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex, self.guided_decoding.guidance_start_token_id ) elif self.guided_decoding.grammar is not None: return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar, 
self.guided_decoding.thinking_end_token_id + tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar, self.guided_decoding.guidance_start_token_id ) elif self.guided_decoding.structural_tag is not None: return tllme.GuidedDecodingParams( tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG, self.guided_decoding.structural_tag, - self.guided_decoding.thinking_end_token_id, + self.guided_decoding.guidance_start_token_id, ) else: return None diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index 85800881232d..954e02be47a9 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -100,7 +100,7 @@ class ResponseFormat(OpenAIBaseModel): schema: Optional[dict] = None structures: Optional[List[StructuralTag]] = None triggers: Optional[List[str]] = None - thinking_end_token_id: Optional[int] = None + guidance_start_token_id: Optional[int] = None class DisaggregatedParams(OpenAIBaseModel): request_type: str @@ -189,9 +189,9 @@ def _response_format_to_guided_decoding_params( raise ValueError( "The 'schema' field is required when response_format.type is 'json'." 
) - return GuidedDecodingParams(json=response_format.schema, thinking_end_token_id=response_format.thinking_end_token_id) + return GuidedDecodingParams(json=response_format.schema, guidance_start_token_id=response_format.guidance_start_token_id) elif response_format.type == "json_object": - return GuidedDecodingParams(json_object=True, thinking_end_token_id=response_format.thinking_end_token_id) + return GuidedDecodingParams(json_object=True, guidance_start_token_id=response_format.guidance_start_token_id) elif response_format.type == "structural_tag": return GuidedDecodingParams( structural_tag=response_format.model_dump_json(by_alias=True, From 4b75546b21674f14dad8c0fd1d88c11a79c50d8c Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 15 Sep 2025 15:57:31 -0700 Subject: [PATCH 09/15] fix --- cpp/tensorrt_llm/executor/serialization.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 9b6d0e301a56..b893f645c002 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -1563,7 +1563,7 @@ GuidedDecodingParams Serialization::deserializeGuidedDecodingParams(std::istream { auto guideType = su::deserializeWithGetterType(is); auto guide = su::deserializeWithGetterType(is); - auto guidanceStartTokenId = su::deserializeWithGetterType(is); + auto guidanceStartTokenId = su::deserializeWithGetterType(is); return GuidedDecodingParams(guideType, guide, guidanceStartTokenId); } From b5b872f3fdcd70f1476e38df91630743ea1a5a8c Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Mon, 15 Sep 2025 16:30:20 -0700 Subject: [PATCH 10/15] fix --- cpp/tensorrt_llm/nanobind/executor/request.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index 2901cd74a652..69b26321474f 100644 --- 
a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -555,8 +555,8 @@ void initRequestBindings(nb::module_& m) }; pyGuidedDecodingParams - .def(nb::init>(), nb::arg("guide_type"), - nb::arg("guide") = nb::none()) + .def(nb::init, std::optional(), nb::arg("guide_type"), + nb::arg("guide") = nb::none(), nb::arg("guidance_start_token_id") = nb::none()) .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_prop_ro("guide", &tle::GuidedDecodingParams::getGuide) .def_prop_ro("guidance_start_token_id", &tle::GuidedDecodingParams::getGuidanceStartTokenId) From 2b56da716fbf9a63811c33e7f51ce21371006d8c Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Tue, 16 Sep 2025 10:38:11 -0700 Subject: [PATCH 11/15] fix --- cpp/tensorrt_llm/nanobind/executor/request.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index 69b26321474f..40994b55d0e1 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -555,7 +555,7 @@ void initRequestBindings(nb::module_& m) }; pyGuidedDecodingParams - .def(nb::init, std::optional(), nb::arg("guide_type"), + .def(nb::init, std::optional>(), nb::arg("guide_type"), nb::arg("guide") = nb::none(), nb::arg("guidance_start_token_id") = nb::none()) .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_prop_ro("guide", &tle::GuidedDecodingParams::getGuide) From 9e97fda531a753fec17fc6a23049b41fda92a63a Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Tue, 16 Sep 2025 11:06:03 -0700 Subject: [PATCH 12/15] fix --- tensorrt_llm/sampling_params.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 55de359f6ed9..d79413967bc1 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -30,8 +30,11 
@@ class GuidedDecodingParams: guidance_start_token_id: Optional[int] = None def _validate(self): + exclude_fields = set(["guidance_start_token_id"]) num_guides = 0 for _field in fields(self): + if _field.name in exclude_fields: + continue num_guides += bool(getattr(self, _field.name)) if num_guides > 1: raise ValueError(f"Only one guide can be used for a request, but got {num_guides}.") From 6e1c5cebfda18e501ba38b27151aaeb269be7e42 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Tue, 16 Sep 2025 11:11:11 -0700 Subject: [PATCH 13/15] remove log --- tensorrt_llm/_torch/pyexecutor/grammar_matcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 879f23a3cfde..5d72ed6e7c20 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -75,7 +75,6 @@ def __init__(self, matcher: GrammarMatcher, guided_decoding_params: GuidedDecodi self._steps_after_guidance_start = 0 def accept_token(self, token_id: int) -> bool: - print(token_id) if self._guidance_start_token_id: if not self._guidance_started: if token_id == self._guidance_start_token_id: From dec9d782135c9f1486a394622e1872392293557e Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Tue, 16 Sep 2025 11:15:07 -0700 Subject: [PATCH 14/15] simplify --- .../_torch/pyexecutor/grammar_matcher.py | 41 +++++++++---------- 1 file changed, 20 insertions(+), 21 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 5d72ed6e7c20..19b2ac1e7be7 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -66,35 +66,32 @@ def guidance_started(self) -> bool: class GrammarMatcherWrapper(GrammarMatcher): - def __init__(self, matcher: GrammarMatcher, guided_decoding_params: GuidedDecodingParams): + def __init__(self, matcher: 
GrammarMatcher, guidance_start_token_id: int): super().__init__() self._matcher = matcher - self._guided_decoding_params = guided_decoding_params - self._guidance_start_token_id = guided_decoding_params.guidance_start_token_id - self._guidance_started = self._guidance_start_token_id is None + self._guidance_start_token_id = guidance_start_token_id + self._guidance_started = False self._steps_after_guidance_start = 0 def accept_token(self, token_id: int) -> bool: - if self._guidance_start_token_id: - if not self._guidance_started: - if token_id == self._guidance_start_token_id: - self._guidance_started = True - self._steps_after_guidance_start = 0 - return True - else: - return True - self._steps_after_guidance_start += 1 + if not self._guidance_started: + if token_id == self._guidance_start_token_id: + self._guidance_started = True + self._steps_after_guidance_start = 0 + return True + else: + return True + self._steps_after_guidance_start += 1 return self._matcher.accept_token(token_id) def rollback(self, num_tokens: int) -> None: num_tokens_to_rollback = num_tokens - if self._guidance_start_token_id: - if not self._guidance_started: - return - # cannot rollback more than _steps_after_guidance_start - num_tokens_to_rollback = min(num_tokens, self._steps_after_guidance_start) - if num_tokens > self._steps_after_guidance_start: - self._guidance_started = False + if not self._guidance_started: + return + # cannot rollback more than _steps_after_guidance_start + num_tokens_to_rollback = min(num_tokens, self._steps_after_guidance_start) + if num_tokens > self._steps_after_guidance_start: + self._guidance_started = False self._matcher.rollback(num_tokens_to_rollback) def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, @@ -115,7 +112,9 @@ def __init__(self, factory: GrammarMatcherFactory): def create(self, guided_decoding_params: GuidedDecodingParams) -> GrammarMatcher: matcher = self._factory.create(guided_decoding_params) - return 
GrammarMatcherWrapper(matcher, guided_decoding_params) + if guided_decoding_params.guidance_start_token_id: + return GrammarMatcherWrapper(matcher, guided_decoding_params.guidance_start_token_id) + return matcher class XGrammarMatcherFactory(GrammarMatcherFactory): From 69ce2ac1e206b6419df702f3cdffcf3400c80c53 Mon Sep 17 00:00:00 2001 From: Shang-Pin Sheng Date: Tue, 16 Sep 2025 11:47:23 -0700 Subject: [PATCH 15/15] remove extra --- tensorrt_llm/_torch/pyexecutor/grammar_matcher.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 19b2ac1e7be7..c8065c15df2b 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -85,7 +85,6 @@ def accept_token(self, token_id: int) -> bool: return self._matcher.accept_token(token_id) def rollback(self, num_tokens: int) -> None: - num_tokens_to_rollback = num_tokens if not self._guidance_started: return # cannot rollback more than _steps_after_guidance_start