diff --git a/agentlightning/types/resources.py b/agentlightning/types/resources.py index c3adf5f84..9fa8f2248 100644 --- a/agentlightning/types/resources.py +++ b/agentlightning/types/resources.py @@ -4,14 +4,8 @@ import inspect import logging -from typing import ( - Annotated, - Any, - Dict, - Literal, - Optional, - Union, -) +from pathlib import Path +from typing import Annotated, Any, Dict, Literal, Optional, Union from pydantic import BaseModel, Field @@ -138,20 +132,61 @@ class PromptTemplate(Resource): engine (Literal['jinja', 'f-string', 'poml']): The templating engine to use for rendering the prompt. I imagine users can use their own customized engines, but algos can only well operate on a subset of them. + + Notes: + ``poml`` templates accept an additional keyword argument ``_poml_format`` + when calling :meth:`format`. This value is forwarded to :func:`poml.poml` + and defaults to ``"openai_chat"``. """ resource_type: Literal["prompt_template"] = "prompt_template" template: str engine: Literal["jinja", "f-string", "poml"] - def format(self, **kwargs: Any) -> str: - """Format the prompt template with the given kwargs.""" + def format(self, **kwargs: Any) -> Any: + """Format the prompt template with the given kwargs. + + Returns: + Any: The rendered prompt. ``f-string`` and ``jinja`` engines return a + string, while ``poml`` returns the object produced by :func:`poml.poml`. + + Raises: + RuntimeError: If the required optional dependency for the configured + engine is not available. + """ if self.engine == "f-string": return self.template.format(**kwargs) - else: - raise NotImplementedError( - "Formatting prompt templates for non-f-string engines with format() helper is not supported yet." - ) + if self.engine == "jinja": + try: + from jinja2 import Template + except ImportError as exc: # pragma: no cover - defensive guard + raise RuntimeError( + "Formatting a PromptTemplate with engine 'jinja' requires the 'jinja2' package to be installed." + ) from exc + + template = Template(self.template) + return template.render(**kwargs) + if self.engine == "poml": + try: + import poml # type: ignore[import-not-found] + except ImportError as exc: # pragma: no cover - defensive guard + raise RuntimeError( + "Formatting a PromptTemplate with engine 'poml' requires the 'poml' package to be installed." + ) from exc + + poml_kwargs = dict(kwargs) + poml_format = poml_kwargs.pop("_poml_format", "openai_chat") + + template_input: Any + template_path = Path(self.template) + if template_path.suffix == ".poml" and template_path.exists(): + template_input = template_path + else: + template_input = self.template + + return poml.poml(template_input, context=poml_kwargs, format=poml_format) + + raise NotImplementedError(f"Unknown prompt template engine: {self.engine}") # Use discriminated union for proper deserialization diff --git a/tests/types/test_prompt_template.py b/tests/types/test_prompt_template.py new file mode 100644 index 000000000..34463c21b --- /dev/null +++ b/tests/types/test_prompt_template.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import sys +from pathlib import Path +from types import SimpleNamespace + +import pytest + +from agentlightning.types import PromptTemplate + + +def test_prompt_template_format_f_string() -> None: + template = PromptTemplate(template="Hello {name}!", engine="f-string") + + assert template.format(name="World") == "Hello World!" + + +def test_prompt_template_format_jinja(monkeypatch: pytest.MonkeyPatch) -> None: + class DummyTemplate: + def __init__(self, source: str) -> None: + self.source = source + + def render(self, **context: str) -> str: + result = self.source + for key, value in context.items(): + result = result.replace(f"{{{{ {key} }}}}", value) + return result + + monkeypatch.setitem(sys.modules, "jinja2", SimpleNamespace(Template=DummyTemplate)) + + template = PromptTemplate(template="Hello {{ name }}!", engine="jinja") + + assert template.format(name="World") == "Hello World!" + + +def test_prompt_template_format_poml_inline(monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[object, dict[str, object], str]] = [] + + def dummy_poml(template: object, context: dict[str, object], format: str) -> dict[str, object]: + calls.append((template, context, format)) + return {"template": template, "context": context, "format": format} + + monkeypatch.setitem(sys.modules, "poml", SimpleNamespace(poml=dummy_poml)) + + template = PromptTemplate(template="{{ name }}", engine="poml") + + result = template.format(name="World") + + assert calls == [("{{ name }}", {"name": "World"}, "openai_chat")] + assert result == {"template": "{{ name }}", "context": {"name": "World"}, "format": "openai_chat"} + + +def test_prompt_template_format_poml_path(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + calls: list[tuple[object, dict[str, object], str]] = [] + + def dummy_poml(template: object, context: dict[str, object], format: str) -> dict[str, object]: + calls.append((template, context, format)) + return {"template": template, "context": context, "format": format} + + monkeypatch.setitem(sys.modules, "poml", SimpleNamespace(poml=dummy_poml)) + + poml_file = tmp_path / "sample.poml" + poml_file.write_text("{{ name }}") + + template = PromptTemplate(template=str(poml_file), engine="poml") + + result = template.format(name="World", _poml_format="raw") + + assert len(calls) == 1 + call_template, context, output_format = calls[0] + assert isinstance(call_template, Path) + assert call_template == poml_file + assert context == {"name": "World"} + assert output_format == "raw" + assert result == {"template": poml_file, "context": {"name": "World"}, "format": "raw"}