Skip to content

Commit 66be1eb

Browse files
danielcweeksFokko
andauthored
Fix literal predicate equality check (#94)
* Fix literal predicate equality check * Fix the tests * Some more fixes --------- Co-authored-by: Fokko Driesprong <fokko@tabular.io>
1 parent 4616d03 commit 66be1eb

3 files changed

Lines changed: 21 additions & 15 deletions

File tree

pyiceberg/expressions/__init__.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def __init__(self, term: Union[str, UnboundTerm[Any]]):
364364

365365
def __eq__(self, other: Any) -> bool:
366366
"""Return the equality of two instances of the UnboundPredicate class."""
367-
return self.term == other.term if isinstance(other, UnboundPredicate) else False
367+
return self.term == other.term if isinstance(other, self.__class__) else False
368368

369369
@abstractmethod
370370
def bind(self, schema: Schema, case_sensitive: bool = True) -> BooleanExpression:
@@ -531,7 +531,7 @@ def __repr__(self) -> str:
531531

532532
def __eq__(self, other: Any) -> bool:
533533
"""Return the equality of two instances of the SetPredicate class."""
534-
return self.term == other.term and self.literals == other.literals if isinstance(other, SetPredicate) else False
534+
return self.term == other.term and self.literals == other.literals if isinstance(other, self.__class__) else False
535535

536536
def __getnewargs__(self) -> Tuple[UnboundTerm[L], Set[Literal[L]]]:
537537
"""Pickle the SetPredicate class."""
@@ -664,12 +664,6 @@ def __invert__(self) -> In[L]:
664664
"""Transform the Expression into its negated version."""
665665
return In[L](self.term, self.literals)
666666

667-
def __eq__(self, other: Any) -> bool:
668-
"""Return the equality of two instances of the NotIn class."""
669-
if isinstance(other, NotIn):
670-
return self.term == other.term and self.literals == other.literals
671-
return False
672-
673667
@property
674668
def as_bound(self) -> Type[BoundNotIn[L]]:
675669
return BoundNotIn[L]
@@ -701,7 +695,7 @@ def bind(self, schema: Schema, case_sensitive: bool = True) -> BoundLiteralPredi
701695

702696
def __eq__(self, other: Any) -> bool:
703697
"""Return the equality of two instances of the LiteralPredicate class."""
704-
if isinstance(other, LiteralPredicate):
698+
if isinstance(other, self.__class__):
705699
return self.term == other.term and self.literal == other.literal
706700
return False
707701

tests/expressions/test_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ def test_greater_than() -> None:
8686

8787

8888
def test_greater_than_or_equal() -> None:
89-
assert GreaterThanOrEqual("foo", 5) == parser.parse("foo <= 5")
90-
assert GreaterThanOrEqual("foo", "a") == parser.parse("'a' >= foo")
89+
assert GreaterThanOrEqual("foo", 5) == parser.parse("foo >= 5")
90+
assert GreaterThanOrEqual("foo", "a") == parser.parse("'a' <= foo")
9191

9292

9393
def test_equal_to() -> None:

tests/test_transforms.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -825,15 +825,27 @@ def test_projection_truncate_string_literal_eq(bound_reference_str: BoundReferen
825825

826826

827827
def test_projection_truncate_string_literal_gt(bound_reference_str: BoundReference[str]) -> None:
828-
assert TruncateTransform(2).project("name", BoundGreaterThan(term=bound_reference_str, literal=literal("data"))) == EqualTo(
829-
term="name", literal=literal("da")
830-
)
828+
assert TruncateTransform(2).project(
829+
"name", BoundGreaterThan(term=bound_reference_str, literal=literal("data"))
830+
) == GreaterThanOrEqual(term="name", literal=literal("da"))
831831

832832

833833
def test_projection_truncate_string_literal_gte(bound_reference_str: BoundReference[str]) -> None:
834834
assert TruncateTransform(2).project(
835835
"name", BoundGreaterThanOrEqual(term=bound_reference_str, literal=literal("data"))
836-
) == EqualTo(term="name", literal=literal("da"))
836+
) == GreaterThanOrEqual(term="name", literal=literal("da"))
837+
838+
839+
def test_projection_truncate_string_literal_lt(bound_reference_str: BoundReference[str]) -> None:
840+
assert TruncateTransform(2).project(
841+
"name", BoundLessThan(term=bound_reference_str, literal=literal("data"))
842+
) == LessThanOrEqual(term="name", literal=literal("da"))
843+
844+
845+
def test_projection_truncate_string_literal_lte(bound_reference_str: BoundReference[str]) -> None:
846+
assert TruncateTransform(2).project(
847+
"name", BoundLessThanOrEqual(term=bound_reference_str, literal=literal("data"))
848+
) == LessThanOrEqual(term="name", literal=literal("da"))
837849

838850

839851
def test_projection_truncate_string_set_same_result(bound_reference_str: BoundReference[str]) -> None:

0 commit comments

Comments
 (0)