|
60 | 60 | PyArrowFile, |
61 | 61 | PyArrowFileIO, |
62 | 62 | StatsAggregator, |
| 63 | + _check_schema_compatible, |
63 | 64 | _ConvertToArrowSchema, |
64 | 65 | _determine_partitions, |
65 | 66 | _primitive_to_physical, |
@@ -1722,6 +1723,96 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None: |
1722 | 1723 | assert len(list(bin_packed)) == 5 |
1723 | 1724 |
|
1724 | 1725 |
|
def test_schema_mismatch_type(table_schema_simple: Schema) -> None:
    """Reject an Arrow schema whose ``bar`` column has a different type.

    The Arrow schema declares ``bar`` as ``decimal128(18, 6)`` while the
    expected message lists the table field as ``required int``, so
    ``_check_schema_compatible`` must raise ``ValueError`` carrying the
    field-by-field comparison table.
    """
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.decimal128(18, 6), nullable=False),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    # ``match`` is a regular expression, hence the escaped parentheses around
    # the decimal precision/scale. Cell padding is part of the pattern: every
    # row has to be exactly as wide as the border lines above and below it.
    expected = r"""Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field                 ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string         │
│ ❌ │ 2: bar: required int     │ 2: bar: required decimal\(18, 6\) │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean        │
└────┴──────────────────────────┴─────────────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema_compatible(table_schema_simple, other_schema)
| 1745 | + |
| 1746 | + |
def test_schema_mismatch_nullability(table_schema_simple: Schema) -> None:
    """Reject an Arrow schema whose ``bar`` column is nullable.

    The expected message lists the table field ``bar`` as ``required int``
    and the Arrow field as ``optional int``; that difference alone must make
    ``_check_schema_compatible`` raise ``ValueError``.
    """
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    # Raw string because the text is used as a regular expression. Cell
    # padding is part of the pattern: every row has to be exactly as wide as
    # the border lines above and below it.
    expected = r"""Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
│ ❌ │ 2: bar: required int     │ 2: bar: optional int     │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema_compatible(table_schema_simple, other_schema)
| 1766 | + |
| 1767 | + |
def test_schema_mismatch_missing_field(table_schema_simple: Schema) -> None:
    """Reject an Arrow schema that omits the ``bar`` column.

    The Arrow schema only provides ``foo`` and ``baz``; the expected message
    reports the table field ``bar`` (``required int``) as ``Missing`` on the
    dataframe side, and ``_check_schema_compatible`` must raise ``ValueError``.
    """
    other_schema = pa.schema((
        pa.field("foo", pa.string(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    # Raw string because the text is used as a regular expression. Cell
    # padding is part of the pattern: every row has to be exactly as wide as
    # the border lines above and below it.
    expected = r"""Mismatch in fields:
┏━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃    ┃ Table field              ┃ Dataframe field          ┃
┡━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ ✅ │ 1: foo: optional string  │ 1: foo: optional string  │
│ ❌ │ 2: bar: required int     │ Missing                  │
│ ✅ │ 3: baz: optional boolean │ 3: baz: optional boolean │
└────┴──────────────────────────┴──────────────────────────┘
"""

    with pytest.raises(ValueError, match=expected):
        _check_schema_compatible(table_schema_simple, other_schema)
| 1786 | + |
| 1787 | + |
def test_schema_mismatch_additional_field(table_schema_simple: Schema) -> None:
    """Reject an Arrow schema that carries an extra ``new_field`` column.

    ``_check_schema_compatible`` is expected to raise ``ValueError`` with a
    message that names the surplus column and hints at updating the schema.
    """
    arrow_fields = [
        pa.field("foo", pa.string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=True),
        pa.field("baz", pa.bool_(), nullable=True),
        pa.field("new_field", pa.date32(), nullable=True),
    ]
    schema_with_extra_column = pa.schema(arrow_fields)

    # Used as a regular expression, so the literal parentheses are escaped.
    message_pattern = (
        r"PyArrow table contains more columns: new_field. Update the schema first \(hint, use union_by_name\)."
    )

    with pytest.raises(ValueError, match=message_pattern):
        _check_schema_compatible(table_schema_simple, schema_with_extra_column)
| 1800 | + |
| 1801 | + |
def test_schema_downcast(table_schema_simple: Schema) -> None:
    """Accept an Arrow schema that uses ``large_string`` for the ``foo`` column.

    The remaining fields (``bar`` as non-nullable ``int32``, ``baz`` as nullable
    ``bool``) are declared alongside it; the check must complete without
    raising.
    """
    # large_string type is compatible with string type
    other_schema = pa.schema((
        pa.field("foo", pa.large_string(), nullable=True),
        pa.field("bar", pa.int32(), nullable=False),
        pa.field("baz", pa.bool_(), nullable=True),
    ))

    # Call directly instead of wrapping in try/except + pytest.fail: any
    # exception already fails the test, and pytest then reports the original
    # exception type, message and traceback instead of a generic failure text.
    _check_schema_compatible(table_schema_simple, other_schema)
| 1814 | + |
| 1815 | + |
1725 | 1816 | def test_partition_for_demo() -> None: |
1726 | 1817 | test_pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())]) |
1727 | 1818 | test_schema = Schema( |
|
0 commit comments