diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index d5d0bf8..08627a2 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -84,6 +84,49 @@ def get_column_specification(self, column, **kwargs): class DatabricksStatementCompiler(compiler.SQLCompiler): + """Compiler that wraps every bind parameter marker in backticks. + + Databricks named parameter markers only accept bare identifiers + (``[A-Za-z_][A-Za-z0-9_]*``) unless backtick-quoted. DataFrame-origin + column names frequently contain hyphens (``col-with-hyphen``), which + SQLAlchemy would otherwise render as an invalid marker + ``:col-with-hyphen`` — the parser splits on ``-`` and reports + UNBOUND_SQL_PARAMETER. + + Wrapping every marker in backticks (``:`col-with-hyphen```) is valid + for any identifier the Spark SQL grammar accepts, so we wrap + unconditionally. The backticks are SQL-side quoting only — the + parameter's logical name is the text between them, so the params + dict sent to the driver keeps the original unquoted key. + + Implementation: fix ``bindtemplate`` and ``compilation_bindtemplate`` + on the class. Every bind-render path in SQLAlchemy reads one of + these two attributes (``bindparam_string``, + ``_literal_execute_expanding_parameter``, and the insertmanyvalues + path which this dialect doesn't enable), so fixing them at the + attribute level covers all paths with no method overrides. We use + property descriptors with no-op setters because ``SQLCompiler.__init__`` + assigns the default templates from ``BIND_TEMPLATES[paramstyle]`` + during its own init — a plain class attribute would be shadowed by + that instance assignment. The no-op setter silently discards super's + assignment so our class-level value is always what gets read. + """ + + _BIND_TEMPLATE = ":`%(name)s`" + + # The no-op setter makes ``SQLCompiler.__init__``'s assignment of the + # default template a silent no-op so our class-level value is what + # every render path reads. ``# type: ignore[assignment]`` is required + # because super declares these as ``str``, and a ``property`` is a + # different type at the static-analysis level (runtime behavior is + # unchanged — the descriptor returns ``str`` on access). + bindtemplate = property( # type: ignore[assignment] + lambda self: self._BIND_TEMPLATE, lambda self, _: None + ) + compilation_bindtemplate = property( # type: ignore[assignment] + lambda self: self._BIND_TEMPLATE, lambda self, _: None + ) + def limit_clause(self, select, **kw): """Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1, since Databricks SQL doesn't support the latter. diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py index 9b19acf..804e11d 100644 --- a/tests/test_local/test_ddl.py +++ b/tests/test_local/test_ddl.py @@ -1,5 +1,5 @@ import pytest -from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine +from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine, insert from sqlalchemy.schema import ( CreateTable, DropColumnComment, @@ -114,3 +114,274 @@ def test_create_table_with_complex_type(self, metadata): assert "array_array_string ARRAY>" in output assert "map_string_string MAP" in output assert "variant_col VARIANT" in output + + +class TestBindParamQuoting(DDLTestBase): + """Regression tests for bind-parameter quoting. + + Databricks named parameter markers (``:name``) must be bare identifiers + (``[A-Za-z_][A-Za-z0-9_]*``) unless wrapped in backticks. Because + DataFrame-origin column names frequently contain hyphens (a character + that's legal inside a backtick-quoted column identifier but not in a + bare bind marker), the dialect wraps every bind name in backticks + unconditionally. The backticks are SQL-side quoting only — the params + dict sent to the driver keeps the original unquoted key. + + The behavior is gated by ``DatabricksDialect.quote_bind_params`` which + defaults to True; set ``?quote_bind_params=false`` in the URL to + disable. + """ + + def _compile_insert(self, table, values, engine=None): + stmt = insert(table).values(values) + return stmt.compile(bind=engine or self.engine) + + def test_hyphenated_column_renders_backticked_bind_marker(self): + metadata = MetaData() + table = Table( + "t", + metadata, + Column("col-with-hyphen", String()), + Column("normal_col", String()), + ) + compiled = self._compile_insert( + table, {"col-with-hyphen": "x", "normal_col": "y"} + ) + + sql = str(compiled) + # Both names are backticked at the marker site + assert ":`col-with-hyphen`" in sql + assert ":`normal_col`" in sql + # The params dict sent to the driver keeps the ORIGINAL unquoted key + # — this matches what the Databricks server expects (verified + # empirically: a backticked marker ``:`name``` binds against a plain + # ``name`` key in the params dict). + params = compiled.construct_params() + assert params["col-with-hyphen"] == "x" + assert params["normal_col"] == "y" + assert "`col-with-hyphen`" not in params + assert "`normal_col`" not in params + + def test_hyphen_and_underscore_columns_do_not_collide(self): + """A table containing both ``col-name`` and ``col_name`` must produce + two distinct bind parameters with two distinct dict keys; otherwise + one value would silently clobber the other. + """ + metadata = MetaData() + table = Table( + "t", + metadata, + Column("col-name", String()), + Column("col_name", String()), + ) + compiled = self._compile_insert( + table, {"col-name": "hyphen_value", "col_name": "underscore_value"} + ) + + sql = str(compiled) + assert ":`col-name`" in sql + assert ":`col_name`" in sql + + params = compiled.construct_params() + assert params["col-name"] == "hyphen_value" + assert params["col_name"] == "underscore_value" + + def test_plain_identifier_bind_names_are_also_backticked(self): + """Every bind name is wrapped unconditionally — the Databricks SQL + grammar accepts ``:`id``` identically to ``:id`` for plain names + (verified against a live warehouse). + """ + metadata = MetaData() + table = Table( + "t", + metadata, + Column("id", String()), + Column("name", String()), + ) + compiled = self._compile_insert(table, {"id": "1", "name": "n"}) + sql = str(compiled) + assert ":`id`" in sql + assert ":`name`" in sql + + + def test_leading_digit_column_is_backticked(self): + """Databricks bind names cannot start with a digit bare.""" + metadata = MetaData() + table = Table("t", metadata, Column("1col", String())) + compiled = self._compile_insert(table, {"1col": "x"}) + assert ":`1col`" in str(compiled) + + def test_many_special_characters_in_column_names(self): + """Column names containing characters that Delta allows (hyphens, + slashes, question marks, hash, plus, star, at, dollar, amp, pipe, + lt/gt) should render as valid backtick-quoted bind markers. We + intentionally exclude characters Delta rejects at DDL time + (space, parens, comma, equals) — those never land in a real + Databricks table, so never reach the bind-name path. + """ + # Each of these survives a CREATE TABLE in Delta (verified empirically) + # and appears verbatim inside the backtick-quoted bind name — the + # default SQLAlchemy escape map does not translate any of them. + pass_through = [ + "col-hyphen", + "col/slash", + "col?question", + "col#hash", + "col+plus", + "col*star", + "col@at", + "col$dollar", + "col&", + "col|pipe", + "colgt", + ] + metadata = MetaData() + columns = [Column(n, String()) for n in pass_through] + table = Table("t", metadata, *columns) + values = {n: f"v-{i}" for i, n in enumerate(pass_through)} + compiled = self._compile_insert(table, values) + sql = str(compiled) + params = compiled.construct_params() + for n in pass_through: + assert f":`{n}`" in sql, f"bind marker missing for {n!r}" + assert params[n] == values[n] + + def test_chars_in_sqlalchemy_default_escape_map_still_work(self): + """Characters already in SQLAlchemy's default + ``bindname_escape_characters`` (``.``, ``[``, ``]``, ``:``, ``%``) + are pre-translated by super's ``bindparam_string`` before our + backtick template wraps the resulting name. The rendered bind + name is the translated one (``col_with_dot``), inside backticks. + ``construct_params`` uses ``escaped_bind_names`` to translate + the customer's incoming dict key to match. Verified end-to-end + against a live warehouse. + """ + metadata = MetaData() + table = Table( + "t", + metadata, + Column("col.with.dot", String()), + Column("col[bracket]", String()), + Column("col:colon", String()), + Column("col%percent", String()), + ) + compiled = self._compile_insert( + table, + { + "col.with.dot": "d", + "col[bracket]": "b", + "col:colon": "c", + "col%percent": "p", + }, + ) + sql = str(compiled) + assert ":`col_with_dot`" in sql + assert ":`col_bracket_`" in sql + assert ":`colCcolon`" in sql + assert ":`colPpercent`" in sql + + params = compiled.construct_params() + assert params["col_with_dot"] == "d" + assert params["colCcolon"] == "c" + assert params["col_bracket_"] == "b" + assert params["colPpercent"] == "p" + + def test_unicode_column_names(self): + """Databricks allows arbitrary Unicode inside backtick-quoted + identifiers. Bind parameter quoting must handle Unicode names too. + """ + names = ["prénom", "姓名", "Straße"] + metadata = MetaData() + table = Table("t", metadata, *(Column(n, String()) for n in names)) + values = {n: f"v{i}" for i, n in enumerate(names)} + compiled = self._compile_insert(table, values) + sql = str(compiled) + for n in names: + assert f":`{n}`" in sql + params = compiled.construct_params() + for n in names: + assert params[n] == values[n] + + def test_sql_reserved_word_as_column_name(self): + """Reserved words used as column names must work as bind params too.""" + metadata = MetaData() + table = Table("t", metadata, Column("select", String()), Column("from", String())) + compiled = self._compile_insert(table, {"select": "s", "from": "f"}) + sql = str(compiled) + assert ":`select`" in sql + assert ":`from`" in sql + + def test_where_clause_with_hyphenated_column(self): + """The quoting must also apply when the hyphenated column appears in + a WHERE clause (SELECT / UPDATE / DELETE all share this path). + """ + from sqlalchemy import select + + metadata = MetaData() + table = Table("t", metadata, Column("col-name", String())) + stmt = select(table).where(table.c["col-name"] == "x") + compiled = stmt.compile(bind=self.engine) + # SQLAlchemy anonymizes the bind as ``_`` — the hyphen + # survives into the bind name, so it must still be backtick-quoted. + assert ":`col-name_1`" in str(compiled) + + def test_multivalues_insert_disambiguates_with_backticked_markers(self): + """Multi-row INSERT generates per-row suffixed bind names. Each + suffixed name must still render backtick-quoted correctly. + """ + metadata = MetaData() + table = Table("t", metadata, Column("col-name", String())) + stmt = insert(table).values([{"col-name": "a"}, {"col-name": "b"}]) + compiled = stmt.compile(bind=self.engine) + sql = str(compiled) + # SQLAlchemy emits e.g. `col-name_m0`, `col-name_m1` for row-level params + assert ":`col-name_m0`" in sql + assert ":`col-name_m1`" in sql + + def test_in_clause_with_hyphenated_column_compiles_to_postcompile(self): + """The initial compilation leaves an IN clause as a POSTCOMPILE + placeholder. The placeholder itself isn't a bind marker so no + quoting is needed at this stage — the actual expanded markers + (``:\\`col-name_1_1\\``, …) are rendered at expansion time by our + ``_literal_execute_expanding_parameter`` override (see + ``test_in_clause_expansion_renders_backticked_markers``). + """ + from sqlalchemy import select + + metadata = MetaData() + table = Table("t", metadata, Column("col-name", String())) + stmt = select(table).where(table.c["col-name"].in_(["a", "b"])) + sql = str(stmt.compile(bind=self.engine)) + assert "POSTCOMPILE_col-name_1" in sql + + def test_in_clause_expansion_renders_backticked_markers(self): + """Exercise the three sites that invoke + ``_literal_execute_expanding_parameter``: + + * normal execute-time expansion via ``construct_expanded_state`` + * ``compile_kwargs={'render_postcompile': True}`` — which fires + inside super's ``__init__``, before any post-super subclass + init would take effect + """ + from sqlalchemy import select + + metadata = MetaData() + table = Table("t", metadata, Column("col-name", String())) + stmt = select(table).where(table.c["col-name"].in_(["a", "b", "c"])) + + # (1) render_postcompile=True at compile time — fires inside super __init__ + rendered = str( + stmt.compile(bind=self.engine, compile_kwargs={"render_postcompile": True}) + ) + assert ":`col-name_1_1`" in rendered + assert ":`col-name_1_2`" in rendered + assert ":`col-name_1_3`" in rendered + + # (2) construct_expanded_state at execute time + compiled = stmt.compile(bind=self.engine) + expanded = compiled.construct_expanded_state( + {"col-name_1": ["a", "b", "c"]} + ) + assert ":`col-name_1_1`" in expanded.statement + assert ":`col-name_1_2`" in expanded.statement + assert ":`col-name_1_3`" in expanded.statement