Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -208,23 +208,27 @@ def __init__(self) -> None:
)


class StrippedConstraint(StringConstraint):
"""Allows only strings that have no leading/trailing whitespace."""
class StrippedConstraint(PatternConstraint):
r"""Allows only strings that have no leading/trailing whitespace.

def validate(self, value: str, info: ValidationInfo) -> None:
if value != value.strip():
self._raise_validation_error(
value,
info,
f"String cannot have leading or trailing whitespace: {repr(value)}",
)
Uses ``\Z`` (absolute end-of-string) instead of ``$`` because
Python's ``$`` matches before a trailing ``\n``. ECMA regex (used by
JSON Schema) treats ``$`` as absolute end-of-string, so the JSON
schema output swaps ``\Z`` back to ``$``.
"""

def __init__(self) -> None:
super().__init__(
pattern=r"^(\S(.*\S)?)?\Z",
error_message="String cannot have leading or trailing whitespace: {value}",
description="String with no leading/trailing whitespace",
)

def __get_pydantic_json_schema__(
self, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> dict[str, Any]:
json_schema = handler(core_schema)
json_schema["pattern"] = r"^(\S(.*\S)?)?$"
json_schema["description"] = "String with no leading/trailing whitespace"
json_schema = super().__get_pydantic_json_schema__(core_schema, handler)
json_schema["pattern"] = self.pattern.pattern.replace(r"\Z", "$")
return json_schema


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@
["Restaurant", "gas-station", "shopping mall", "category!"],
"Invalid category format",
),
(
StrippedConstraint,
["hello", "hello world", "text with internal spaces", ""],
[" hello", "hello ", "\thello", "hello\n", " hello world "],
"leading or trailing whitespace",
),
]


Expand Down Expand Up @@ -174,22 +180,9 @@ class TestModel(BaseModel):
TestModel(pointer=ptr)
assert "JSON Pointer must start" in str(exc_info.value)

def test_whitespace_constraint_valid(self) -> None:
class TestModel(BaseModel):
text: Annotated[str, StrippedConstraint()]

for text in ["hello", "hello world", "text with internal spaces", ""]:
model = TestModel(text=text)
assert model.text == text

def test_whitespace_constraint_invalid(self) -> None:
class TestModel(BaseModel):
text: Annotated[str, StrippedConstraint()]

for text in [" hello", "hello ", "\thello", "hello\n", " hello world "]:
with pytest.raises(ValidationError) as exc_info:
TestModel(text=text)
assert "cannot have leading or trailing whitespace" in str(exc_info.value)
def test_stripped_constraint_pattern_string(self) -> None:
"""Codegen extracts the regex via constraint.pattern.pattern."""
assert StrippedConstraint().pattern.pattern == r"^(\S(.*\S)?)?\Z"


class TestJsonSchemaGeneration:
Expand Down Expand Up @@ -294,6 +287,7 @@ class TestPatternConstraintHierarchy:
SnakeCaseConstraint,
PhoneNumberConstraint,
RegionCodeConstraint,
StrippedConstraint,
WikidataIdConstraint,
],
)
Expand Down
Loading