diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index d9b115bf8fe9..51a25e5a2036 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -509,11 +509,12 @@ class GuidedDecodingParams kSTRUCTURAL_TAG = 4, }; - explicit GuidedDecodingParams(GuideType guideType, std::optional guide = std::nullopt); + explicit GuidedDecodingParams(GuideType guideType, std::optional guide = std::nullopt, std::optional guidanceStartTokenId = std::nullopt); bool operator==(GuidedDecodingParams const& other) const; [[nodiscard]] GuideType getGuideType() const; [[nodiscard]] std::optional getGuide() const; + [[nodiscard]] std::optional getGuidanceStartTokenId() const; private: friend class Serialization; @@ -523,6 +524,7 @@ class GuidedDecodingParams /// @brief The detailed guide string. It could be a json schema, a regular expression or a EBNF grammar depending on /// mGuideType. std::optional mGuide; + std::optional mGuidanceStartTokenId; }; using RetentionPriority = SizeType32; diff --git a/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp b/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp index 83b99e0b4125..929a9271c3f4 100644 --- a/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp +++ b/cpp/tensorrt_llm/executor/guidedDecodingParams.cpp @@ -22,9 +22,10 @@ namespace tensorrt_llm::executor { -GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional guide) +GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional guide, std::optional guidanceStartTokenId) : mGuideType{guideType} , mGuide{std::move(guide)} + , mGuidanceStartTokenId{guidanceStartTokenId} { TLLM_CHECK_WITH_INFO(mGuideType == GuideType::kJSON || mGuide.has_value(), "The guide string must be provided unless using GuideType::kJSON."); @@ -32,7 +33,7 @@ GuidedDecodingParams::GuidedDecodingParams(GuideType guideType, std::optional GuidedDecodingParams::getGuide() const 
return mGuide; } +std::optional GuidedDecodingParams::getGuidanceStartTokenId() const +{ + return mGuidanceStartTokenId; +} + } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 0dd28641d1dc..b893f645c002 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -1563,13 +1563,15 @@ GuidedDecodingParams Serialization::deserializeGuidedDecodingParams(std::istream { auto guideType = su::deserializeWithGetterType(is); auto guide = su::deserializeWithGetterType(is); - return GuidedDecodingParams(guideType, guide); + auto guidanceStartTokenId = su::deserializeWithGetterType(is); + return GuidedDecodingParams(guideType, guide, guidanceStartTokenId); } void Serialization::serialize(GuidedDecodingParams const& guidedDecodingParams, std::ostream& os) { su::serialize(guidedDecodingParams.getGuideType(), os); su::serialize(guidedDecodingParams.getGuide(), os); + su::serialize(guidedDecodingParams.getGuidanceStartTokenId(), os); } size_t Serialization::serializedSize(GuidedDecodingParams const& guidedDecodingParams) @@ -1577,6 +1579,7 @@ size_t Serialization::serializedSize(GuidedDecodingParams const& guidedDecodingP size_t totalSize = 0; totalSize += su::serializedSize(guidedDecodingParams.getGuideType()); totalSize += su::serializedSize(guidedDecodingParams.getGuide()); + totalSize += su::serializedSize(guidedDecodingParams.getGuidanceStartTokenId()); return totalSize; } diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index d26c8dd70e0a..40994b55d0e1 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -542,23 +542,24 @@ void initRequestBindings(nb::module_& m) .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG); auto guidedDecodingParamsGetstate - = [](tle::GuidedDecodingParams 
const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide()); }; + = [](tle::GuidedDecodingParams const& self) { return nb::make_tuple(self.getGuideType(), self.getGuide(), self.getGuidanceStartTokenId()); }; auto guidedDecodingParamsSetstate = [](tle::GuidedDecodingParams& self, nb::tuple const& state) { - if (state.size() != 2) + if (state.size() != 3) { throw std::runtime_error("Invalid GuidedDecodingParams state!"); } new (&self) tle::GuidedDecodingParams( - nb::cast(state[0]), nb::cast>(state[1])); + nb::cast(state[0]), nb::cast>(state[1]), nb::cast>(state[2])); }; pyGuidedDecodingParams - .def(nb::init>(), nb::arg("guide_type"), - nb::arg("guide") = nb::none()) + .def(nb::init, std::optional>(), nb::arg("guide_type"), + nb::arg("guide") = nb::none(), nb::arg("guidance_start_token_id") = nb::none()) .def_prop_ro("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_prop_ro("guide", &tle::GuidedDecodingParams::getGuide) + .def_prop_ro("guidance_start_token_id", &tle::GuidedDecodingParams::getGuidanceStartTokenId) .def("__getstate__", guidedDecodingParamsGetstate) .def("__setstate__", guidedDecodingParamsSetstate); diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp index 4eb61ecde98d..19e20d5510a6 100644 --- a/cpp/tensorrt_llm/pybind/executor/request.cpp +++ b/cpp/tensorrt_llm/pybind/executor/request.cpp @@ -496,11 +496,11 @@ void initRequestBindings(pybind11::module_& m) .value("STRUCTURAL_TAG", tle::GuidedDecodingParams::GuideType::kSTRUCTURAL_TAG); auto guidedDecodingParamsGetstate - = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide()); }; + = [](tle::GuidedDecodingParams const& self) { return py::make_tuple(self.getGuideType(), self.getGuide(), self.getGuidanceStartTokenId()); }; auto guidedDecodingParamsSetstate = [](py::tuple state) { - if (state.size() != 2) + if (state.size() != 3) { throw std::runtime_error("Invalid 
GuidedDecodingParams state!"); } @@ -513,6 +513,7 @@ void initRequestBindings(pybind11::module_& m) py::arg("guide") = py::none()) .def_property_readonly("guide_type", &tle::GuidedDecodingParams::getGuideType) .def_property_readonly("guide", &tle::GuidedDecodingParams::getGuide) + .def_property_readonly("guidance_start_token_id", &tle::GuidedDecodingParams::getGuidanceStartTokenId) .def(py::pickle(guidedDecodingParamsGetstate, guidedDecodingParamsSetstate)); auto requestGetstate = [](tle::Request const& self) diff --git a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py index 68fb6bf13d83..c8065c15df2b 100644 --- a/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py +++ b/tensorrt_llm/_torch/pyexecutor/grammar_matcher.py @@ -29,6 +29,10 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: pass + @abstractmethod + def guidance_started(self) -> bool: + pass + class GrammarMatcherFactory(ABC): @@ -56,7 +60,60 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: return self._matcher.is_terminated() + + def guidance_started(self) -> bool: + return True + + +class GrammarMatcherWrapper(GrammarMatcher): + def __init__(self, matcher: GrammarMatcher, guidance_start_token_id: int): + super().__init__() + self._matcher = matcher + self._guidance_start_token_id = guidance_start_token_id + self._guidance_started = False + self._steps_after_guidance_start = 0 + + def accept_token(self, token_id: int) -> bool: + if not self._guidance_started: + if token_id == self._guidance_start_token_id: + self._guidance_started = True + self._steps_after_guidance_start = 0 + return True + else: + return True + self._steps_after_guidance_start += 1 + return self._matcher.accept_token(token_id) + + def rollback(self, num_tokens: int) -> None: + if not self._guidance_started: + return + # cannot rollback more than _steps_after_guidance_start + 
num_tokens_to_rollback = min(num_tokens, self._steps_after_guidance_start) + if num_tokens > self._steps_after_guidance_start: + self._guidance_started = False + self._steps_after_guidance_start -= num_tokens_to_rollback; self._matcher.rollback(num_tokens_to_rollback) + + def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, + index: int) -> None: + self._matcher.fill_next_token_bitmask(next_token_bitmask, index) + + def is_terminated(self) -> bool: + return self._matcher.is_terminated() + + def guidance_started(self) -> bool: + return self._guidance_started +class GrammarMatcherFactoryWrapper(GrammarMatcherFactory): + def __init__(self, factory: GrammarMatcherFactory): + super().__init__() + self._factory = factory + + def create(self, + guided_decoding_params: GuidedDecodingParams) -> GrammarMatcher: + matcher = self._factory.create(guided_decoding_params) + if guided_decoding_params.guidance_start_token_id is not None: + return GrammarMatcherWrapper(matcher, guided_decoding_params.guidance_start_token_id) + return matcher class XGrammarMatcherFactory(GrammarMatcherFactory): @@ -167,6 +224,9 @@ def fill_next_token_bitmask(self, next_token_bitmask: torch.Tensor, def is_terminated(self) -> bool: return self._is_terminated + def guidance_started(self) -> bool: + return True + def _check_err(self) -> None: if self._matcher.is_error(): raise ValueError( diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py index c27c81f4b08b..42ca1460640c 100644 --- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py +++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py @@ -10,7 +10,7 @@ from ...bindings.internal.batch_manager import LlmRequestType from ...logger import logger from ..hostfunc import hostfunc -from .grammar_matcher import (GrammarMatcher, LLGuidanceMatcherFactory, +from .grammar_matcher import (GrammarMatcher, GrammarMatcherFactoryWrapper, LLGuidanceMatcherFactory, XGrammarMatcherFactory) from .llm_request import LlmRequest from .scheduler import ScheduledRequests @@ 
-161,6 +161,7 @@ def __init__(self, raise ValueError( f"Invalid guided decoding backend: {self.guided_decoding_backend}" ) + self.grammar_matcher_factory = GrammarMatcherFactoryWrapper(self.grammar_matcher_factory) logger.info( f"Guided decoder initialized with backend: {self.guided_decoding_backend}" ) @@ -249,9 +250,10 @@ def _build(self, requests: GuidedRequests) -> None: self.num_advanced_tokens[slot] += 1 if not matcher.is_terminated(): - matcher.fill_next_token_bitmask(self.bitmask_host, offset) - self.token_mask_host[offset] = 1 - self.num_guided_tokens[slot] += 1 + if matcher.guidance_started(): + matcher.fill_next_token_bitmask(self.bitmask_host, offset) + self.token_mask_host[offset] = 1 + self.num_guided_tokens[slot] += 1 # Process draft tokens for i, tid in enumerate(req.draft_tokens, 1): accepted = matcher.accept_token(tid) @@ -260,10 +262,11 @@ def _build(self, requests: GuidedRequests) -> None: self.num_advanced_tokens[slot] += 1 if matcher.is_terminated(): break - matcher.fill_next_token_bitmask(self.bitmask_host, + if matcher.guidance_started(): + matcher.fill_next_token_bitmask(self.bitmask_host, offset + i) - self.token_mask_host[offset + i] = 1 - self.num_guided_tokens[slot] += 1 + self.token_mask_host[offset + i] = 1 + self.num_guided_tokens[slot] += 1 if req.is_draft: assert len(req.draft_tokens) == 0 diff --git a/tensorrt_llm/sampling_params.py b/tensorrt_llm/sampling_params.py index 38f0a07bbf0b..d79413967bc1 100644 --- a/tensorrt_llm/sampling_params.py +++ b/tensorrt_llm/sampling_params.py @@ -27,10 +27,14 @@ class GuidedDecodingParams: grammar: Optional[str] = None json_object: bool = False structural_tag: Optional[str] = None + guidance_start_token_id: Optional[int] = None def _validate(self): + exclude_fields = set(["guidance_start_token_id"]) num_guides = 0 for _field in fields(self): + if _field.name in exclude_fields: + continue num_guides += bool(getattr(self, _field.name)) if num_guides > 1: raise ValueError(f"Only one guide can be 
used for a request, but got {num_guides}.") @@ -459,7 +463,9 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams: return None if self.guided_decoding.json_object: - return tllme.GuidedDecodingParams(tllme.GuidedDecodingParams.GuideType.JSON) + return tllme.GuidedDecodingParams( + tllme.GuidedDecodingParams.GuideType.JSON, None, self.guided_decoding.guidance_start_token_id, + ) elif self.guided_decoding.json is not None: json_schema = self.guided_decoding.json if isinstance(json_schema, BaseModel): @@ -467,20 +473,21 @@ def _get_guided_decoding_params(self) -> tllme.GuidedDecodingParams: if isinstance(json_schema, dict): json_schema = json.dumps(json_schema) return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema + tllme.GuidedDecodingParams.GuideType.JSON_SCHEMA, json_schema, self.guided_decoding.guidance_start_token_id ) elif self.guided_decoding.regex is not None: return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex + tllme.GuidedDecodingParams.GuideType.REGEX, self.guided_decoding.regex, self.guided_decoding.guidance_start_token_id ) elif self.guided_decoding.grammar is not None: return tllme.GuidedDecodingParams( - tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar + tllme.GuidedDecodingParams.GuideType.EBNF_GRAMMAR, self.guided_decoding.grammar, self.guided_decoding.guidance_start_token_id ) elif self.guided_decoding.structural_tag is not None: return tllme.GuidedDecodingParams( tllme.GuidedDecodingParams.GuideType.STRUCTURAL_TAG, self.guided_decoding.structural_tag, + self.guided_decoding.guidance_start_token_id, ) else: return None diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index c391ff0e360c..954e02be47a9 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -100,7 +100,7 @@ class ResponseFormat(OpenAIBaseModel): schema: 
Optional[dict] = None structures: Optional[List[StructuralTag]] = None triggers: Optional[List[str]] = None - + guidance_start_token_id: Optional[int] = None class DisaggregatedParams(OpenAIBaseModel): request_type: str @@ -189,9 +189,9 @@ def _response_format_to_guided_decoding_params( raise ValueError( "The 'schema' field is required when response_format.type is 'json'." ) - return GuidedDecodingParams(json=response_format.schema) + return GuidedDecodingParams(json=response_format.schema, guidance_start_token_id=response_format.guidance_start_token_id) elif response_format.type == "json_object": - return GuidedDecodingParams(json_object=True) + return GuidedDecodingParams(json_object=True, guidance_start_token_id=response_format.guidance_start_token_id) elif response_format.type == "structural_tag": return GuidedDecodingParams( structural_tag=response_format.model_dump_json(by_alias=True,