diff --git a/CHANGELOG.md b/CHANGELOG.md index 2d6a2e9b4..73568a9c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,9 +2,19 @@ ### v1.10.1 +#### Features + +- Add `dbt_sqlserver_enable_safe_type_expansion` behaviour flag to allow safe column type widening during schema expansion: `varchar` → `nvarchar`, integer family promotions (`bit` → `tinyint` → `smallint` → `int` → `bigint`), and `numeric`/`decimal` precision/scale upgrades. Gated by the per-model `column_type_expansion_max_rows` config (default 1,000,000 rows). See [#699](https://github.com/dbt-msft/dbt-sqlserver/issues/699). +- Add `prefer_single_alter_column` model config to use a single `ALTER COLUMN` statement instead of the add+update+drop+rename pattern when altering column types on tables. +- Add `string_type_instance()` to preserve the NVARCHAR/NCHAR type family during column expansion, fixing incorrect promotion of NVARCHAR/NCHAR to VARCHAR. +- Add `tinyint` and `bit` to the `is_integer()` type list for correct type detection. + #### Bugfixes - Fix unit tests with empty fixtures (`rows: []`) generating invalid `limit 0` syntax; emit `top 0` instead. Also fix `get_columns_in_query()` for queries starting with a CTE, which broke unit tests with an empty `expect` block; such queries are now described via `sp_describe_first_result_set` instead of being executed. [#698](https://github.com/dbt-msft/dbt-sqlserver/issues/698) +- Fix catalog generation for NVARCHAR/NCHAR columns: use `user_type_id` instead of `system_type_id` in catalog.sql, preventing them from appearing as `SYSNAME` in `dbt docs`. [#637](https://github.com/dbt-msft/dbt-sqlserver/issues/637) +- Fix `varchar(max)` / `nvarchar(max)` columns being incorrectly treated as size `-1` during type expansion, preventing `varchar(max)` → `varchar(100)` narrowing and properly allowing `varchar(100)` → `varchar(max)` expansion. +- Fix seed table ingestion of empty numeric cells by inlining `null` literals instead of binding parameters. [#425](https://github.com/dbt-msft/dbt-sqlserver/issues/425) #### Features diff --git a/README.md b/README.md index 34f25d944..757687572 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,18 @@ The same setting is also honoured via `vars:` for backwards compatibility; the b *(default: `pyodbc`)* Set to `mssql-python` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required backend package (Python dependency), such as `pyodbc` or `mssql-python`, is not installed. +### `dbt_sqlserver_enable_safe_type_expansion` + +*(default: `false`)* When enabled, allows the adapter to widen column types during incremental model schema expansion beyond same-family string resizes. Supported safe expansions include: + +- **Cross-family string**: `varchar`/`char` → `nvarchar`/`nchar` (same or larger size) +- **Integer family**: `bit` → `tinyint` → `smallint` → `int` → `bigint` +- **Integer → numeric**: `int` → `numeric` (with sufficient precision to hold the integer range) +- **Numeric precision/scale**: `numeric(p,s)` → `numeric(p2,s2)` where precision and scale both increase +- **Fixed-money**: `smallmoney` → `money`, `money` → `numeric` (with sufficient precision) + +Safe expansions are further gated by `column_type_expansion_max_rows` (default 1,000,000 rows) to avoid long-running operations on large tables. + ### `dbt_sqlserver_use_dbt_transactions` _(default: `false`)_ When enabled, makes dbt's transaction hooks real at the SQL Server level by emitting `BEGIN TRANSACTION` / `COMMIT TRANSACTION` through the adapter's `add_begin_query` and `add_commit_query` methods. @@ -142,7 +154,28 @@ This mode is opt-in and should be tested carefully with project-specific materia ```yaml # dbt_project.yml flags: - dbt_sqlserver_use_dbt_transactions: true # <-- opt-in; default is false + dbt_sqlserver_enable_safe_type_expansion: true + dbt_sqlserver_use_dbt_transactions: true # <-- opt-in; default is false +``` + +### `column_type_expansion_max_rows` + +*(default: `1000000`)* Per-model config that limits when safe type expansion runs. When the target table exceeds this row count, safe type expansion is skipped (basic same-family string resizes still proceed). Set to `-1` to disable the check entirely. + +```sql +-- In an incremental model +{{ config(materialized='incremental', unique_key='id', + column_type_expansion_max_rows=500000) }} +``` + +### `prefer_single_alter_column` + +*(default: `false`)* Model-level config that controls how `alter_column_type` changes column types on tables. When `false` (default), the adapter uses the safer approach: add a temporary column, copy data, drop the original, and rename. When `true`, the adapter uses a single `ALTER COLUMN` statement, which is faster on small, medium tables and instant on safe type expansions but may fail for types that cannot be implicitly converted. + +```sql +-- In an incremental model +{{ config(materialized='incremental', unique_key='id', + prefer_single_alter_column=true) }} ``` **Compatibility notes:** Enabling `dbt_sqlserver_use_dbt_transactions: true` may expose transaction-state assumptions hidden by autocommit-only mode. Explicit transaction macros may interact with dbt-managed transactions, and cleanup after failed DDL/DML may differ. Review pre/post hooks for in-transaction vs out-of-transaction semantics. diff --git a/dbt/adapters/sqlserver/sqlserver_adapter.py b/dbt/adapters/sqlserver/sqlserver_adapter.py index e922b7af1..50629c191 100644 --- a/dbt/adapters/sqlserver/sqlserver_adapter.py +++ b/dbt/adapters/sqlserver/sqlserver_adapter.py @@ -15,7 +15,8 @@ from dbt.adapters.base.meta import available from dbt.adapters.base.relation import BaseRelation from dbt.adapters.capability import Capability, CapabilityDict, CapabilitySupport, Support -from dbt.adapters.events.types import SchemaCreation +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.events.types import ColTypeChange, SchemaCreation from dbt.adapters.reference_keys import _make_ref_key_dict from dbt.adapters.relation_configs import RelationConfigChangeAction from dbt.adapters.sql.impl import CREATE_SCHEMA_MACRO_NAME, SQLAdapter @@ -30,6 +31,8 @@ from dbt.adapters.sqlserver.sqlserver_connections import SQLServerConnectionManager from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation +logger = AdapterLogger("SQLServer") + class SQLServerAdapter(SQLAdapter): """ @@ -111,6 +114,16 @@ def _behavior_flags(self) -> List[BehaviorFlag]: "The new behaviour is intended to become the default in a future release." ), }, + { + "name": "dbt_sqlserver_enable_safe_type_expansion", + "default": False, + "description": ( + "Allow the SQL Server adapter to widen column types during schema expansion. " + "This enables promotions like varchar -> nvarchar, " + "bit -> tinyint -> smallint -> int -> bigint, " + "and numeric(p,s) -> numeric(p2,s2) using alter column." + ), + }, { "name": "dbt_sqlserver_use_dbt_transactions", "default": False, @@ -312,6 +325,101 @@ def render_model_constraint(cls, constraint: ModelLevelConstraint) -> Optional[s else: return None + def _get_row_count(self, relation) -> int: + """Return the number of rows in the given relation.""" + sql = f"SELECT COUNT_BIG(*) FROM {relation}" + _, cursor = self.connections.add_select_query(sql) + row = cursor.fetchone() + return int(row[0]) if row else 0 + + def expand_column_types(self, goal, current, max_rows: int = 1000000): + """Override to ensure we preserve nvarchar/nchar type family during + column expansion. Necessary same-family resizes (e.g. varchar size) + always proceed. Safe type expansions (cross-family promotions like + varchar -> nvarchar) are guarded by column_type_expansion_max_rows. + enable_safe_type_expansion is the future approach for widening.""" + + reference_columns = {c.name: c for c in self.get_columns_in_relation(goal)} + target_columns = {c.name: c for c in self.get_columns_in_relation(current)} + + enable_safe = self.behavior.dbt_sqlserver_enable_safe_type_expansion + + row_count_exceeds = False + if enable_safe and max_rows != -1: + if max_rows == 0: + row_count_exceeds = True + logger.info( + "Safe type expansion skipped for %s: column_type_expansion_max_rows is 0.", + current, + ) + else: + row_count = self._get_row_count(current) + if row_count > max_rows: + row_count_exceeds = True + logger.warning( + "Safe type expansion skipped for %s: " + "%s rows exceeds column_type_expansion_max_rows (%s). " + "Set column_type_expansion_max_rows=-1 to disable " + "this check, or increase the limit.", + current, + row_count, + max_rows, + ) + + for column_name, reference_column in reference_columns.items(): + target_column = target_columns.get(column_name) + if target_column is None: + continue + + if target_column.can_expand_to(reference_column): + pass + elif ( + enable_safe + and not row_count_exceeds + and target_column.can_expand_safe(reference_column) + ): + pass + else: + continue + + if reference_column.is_string(): + col_string_size = reference_column.string_size() + new_type = reference_column.string_type_instance(col_string_size) + else: + new_type = reference_column.data_type + fire_event( + ColTypeChange( + orig_type=target_column.data_type, + new_type=new_type, + table=_make_ref_key_dict(current), + ) + ) + self.alter_column_type(current, column_name, new_type) + + @available.parse_none + def expand_target_column_types( + self, from_relation: BaseRelation, to_relation: BaseRelation, max_rows: int = 1000000 + ) -> None: + if not isinstance(from_relation, self.Relation): + from dbt.adapters.base.impl import MacroArgTypeError + + raise MacroArgTypeError( + method_name="expand_target_column_types", + arg_name="from_relation", + got_value=from_relation, + expected_type=self.Relation, + ) + if not isinstance(to_relation, self.Relation): + from dbt.adapters.base.impl import MacroArgTypeError + + raise MacroArgTypeError( + method_name="expand_target_column_types", + arg_name="to_relation", + got_value=to_relation, + expected_type=self.Relation, + ) + self.expand_column_types(from_relation, to_relation, max_rows) + @available def parse_index(self, raw_index: Any) -> Optional[SQLServerIndexConfig]: return SQLServerIndexConfig.parse(raw_index) diff --git a/dbt/adapters/sqlserver/sqlserver_column.py b/dbt/adapters/sqlserver/sqlserver_column.py index d93281b5f..418220f13 100644 --- a/dbt/adapters/sqlserver/sqlserver_column.py +++ b/dbt/adapters/sqlserver/sqlserver_column.py @@ -37,6 +37,34 @@ class SQLServerColumn(Column): @classmethod def string_type(cls, size: int) -> str: + """Class-level string_type used by SQLAdapter.expand_column_types. + + Return a VARCHAR default for the SQLAdapter path; this keeps behaviour + consistent with the rest of dbt where class-level string_type is + generic and not instance-aware. + """ + return f"varchar({size if size > 0 else '8000'})" + + def string_type_instance(self, size: int) -> str: + """Instance-level string type selection that respects NVARCHAR/NCHAR. + + Handles MAX strings (size == -1) by emitting the appropriate + varchar(max) or nvarchar(max) DDL. Fixed-length char/nchar do not + support MAX and raise if queried with size == -1. + """ + dtype = (self.dtype or "").lower() + if size == -1: + if dtype == "varchar": + return "varchar(max)" + if dtype == "nvarchar": + return "nvarchar(max)" + raise DbtRuntimeError(f"{dtype}(max) is not a valid SQL Server type") + if dtype == "nvarchar": + return f"nvarchar({size if size > 0 else '4000'})" + if dtype == "nchar": + return f"nchar({size if size > 0 else '1'})" + if dtype == "char": + return f"char({size if size > 0 else '1'})" return f"varchar({size if size > 0 else '8000'})" def literal(self, value: Any) -> str: @@ -48,14 +76,24 @@ def data_type(self) -> str: if self.dtype.lower() == "datetime2": return "datetime2(6)" if self.is_string(): - return self.string_type(self.string_size()) - elif self.is_numeric(): + return self.string_type_instance(self.string_size()) + elif self.is_decimal_type(): return self.numeric_type(self.dtype, self.numeric_precision, self.numeric_scale) else: return self.dtype def is_string(self) -> bool: - return self.dtype.lower() in ["varchar", "char"] + return self.dtype.lower() in ["varchar", "char", "nvarchar", "nchar"] + + def is_max_string(self) -> bool: + """Return True if this is a MAX string column (char_size == -1). + + In SQL Server, MAX is represented as -1 in the catalog views. + This applies to varchar(max) and nvarchar(max). char/nchar do not + support MAX. + """ + dtype = (self.dtype or "").lower() + return dtype in ("varchar", "nvarchar") and int(self.char_size or 0) == -1 def is_number(self): return any([self.is_integer(), self.is_numeric(), self.is_float()]) @@ -64,27 +102,33 @@ def is_float(self): return self.dtype.lower() in ["float", "real"] def is_integer(self) -> bool: + # SQL Server exact numeric integer types per MS docs (all versions back to 2017). + # bit is classified as "an integer data type" by Microsoft in the Transact-SQL docs + # (https://learn.microsoft.com/en-us/sql/t-sql/data-types/bit-transact-sql). + # integer is a standard SQL synonym for int kept for ODBC compatibility. return self.dtype.lower() in [ - # real types + "bit", + "tinyint", "smallint", + "int", "integer", "bigint", - "smallserial", - "serial", - "bigserial", - # aliases - "int2", - "int4", - "int8", - "serial2", - "serial4", - "serial8", - "int", ] def is_numeric(self) -> bool: return self.dtype.lower() in ["numeric", "decimal", "money", "smallmoney"] + def is_decimal_type(self) -> bool: + """Return True for true arbitrary-precision numeric/decimal types only. + + This excludes fixed-scale money/smallmoney which are still classified + as numeric by is_numeric() for backward compatibility. + """ + return self.dtype.lower() in ["numeric", "decimal"] + + def is_fixed_numeric(self) -> bool: + return self.dtype.lower() in ["money", "smallmoney"] + def string_size(self) -> int: if not self.is_string(): raise DbtRuntimeError("Called string_size() on non-string field!") @@ -94,9 +138,126 @@ def string_size(self) -> int: return int(self.char_size) def can_expand_to(self, other_column: "SQLServerColumn") -> bool: - if not self.is_string() or not other_column.is_string(): + self_dtype = self.dtype.lower() + other_dtype = other_column.dtype.lower() + if self.is_string() and other_column.is_string(): + if self_dtype != other_dtype: + return False + self_max = self.is_max_string() + other_max = other_column.is_max_string() + # MAX -> MAX: not an expansion + if self_max and other_max: + return False + # MAX -> bounded: rejected (would be a shrink) + if self_max and not other_max: + return False + # bounded -> MAX: always an expansion + if not self_max and other_max: + return True + # bounded -> bounded: normal numeric size comparison + return other_column.string_size() > self.string_size() + return False + + @staticmethod + def _integer_digits(col: "SQLServerColumn") -> int: + """Return the number of integer digits for a numeric/integer column. + + For numeric/decimal columns: precision - scale. + For integer types: the maximum decimal precision required. + For fixed-money types: precision - scale of their effective representation. + """ + dtype = col.dtype.lower() + if col.is_decimal_type(): + prec = int(col.numeric_precision or 0) + scale = int(col.numeric_scale or 0) + return prec - scale + if col.is_fixed_numeric(): + # Treat money/smallmoney as fixed-scale numerics + if dtype == "smallmoney": + return 10 - 4 # effectively numeric(10,4) + elif dtype == "money": + return 19 - 4 # effectively numeric(19,4) + if col.is_integer(): + if dtype in ("bit",): + return 1 + if dtype in ("tinyint",): + return 3 + if dtype in ("smallint",): + return 5 + if dtype in ("bigint",): + return 19 + # int, integer + return 10 + return 0 + + @staticmethod + def _scale(col: "SQLServerColumn") -> int: + """Return the scale for numeric / fixed-money columns.""" + if col.is_decimal_type(): + return int(col.numeric_scale or 0) + if col.is_fixed_numeric(): + # smallmoney and money both have scale 4 + return 4 + return 0 + + def can_expand_safe(self, other_column: "SQLServerColumn") -> bool: + self_dtype = self.dtype.lower() + other_dtype = other_column.dtype.lower() + + if self.is_string() and other_column.is_string(): + # Cross-family varchar/char -> nvarchar/nchar guarded expansion + # Also nchar -> nvarchar (fixed-width unicode to variable-width unicode) + if (self_dtype in ("varchar", "char") and other_dtype in ("nvarchar", "nchar")) or ( + self_dtype == "nchar" and other_dtype == "nvarchar" + ): + self_max = self.is_max_string() + other_max = other_column.is_max_string() + + # varchar(max) -> nvarchar(max): allowed behind safe flag + if self_max and other_max: + return True + # varchar(max) -> nvarchar(n): rejected for every bounded n + if self_max and not other_max: + return False + # varchar(n) -> nvarchar(max): allowed + if not self_max and other_max: + return True + # varchar(n) -> nvarchar(m): normal bounded comparison + return other_column.string_size() >= self.string_size() + + # Same-family string handled by can_expand_to + return False + + if not self.is_number() or not other_column.is_number(): return False - return other_column.string_size() > self.string_size() + + int_family = ("bit", "tinyint", "smallint", "int", "bigint") + if self_dtype in int_family and other_dtype in int_family: + return int_family.index(other_dtype) > int_family.index(self_dtype) + + # Integer -> decimal/numeric expansion + if self.is_integer() and other_column.is_decimal_type(): + source_int_digits = self._integer_digits(self) + target_scale = self._scale(other_column) + target_int_digits = self._integer_digits(other_column) + return target_scale >= 0 and target_int_digits >= source_int_digits + + # Numeric/decimal <-> fixed-money type expansion + if (self.is_decimal_type() or self.is_fixed_numeric()) and ( + other_column.is_decimal_type() or other_column.is_fixed_numeric() + ): + source_scale = self._scale(self) + target_scale = self._scale(other_column) + source_int_digits = self._integer_digits(self) + target_int_digits = self._integer_digits(other_column) + + if target_scale >= source_scale and target_int_digits >= source_int_digits: + # Must be a real widening — a pure type rename without + # increasing integer digits or scale is not an expansion. + if target_int_digits > source_int_digits or target_scale > source_scale: + return True + + return False class SQLServerColumnNative(SQLServerColumn): diff --git a/dbt/adapters/sqlserver/sqlserver_configs.py b/dbt/adapters/sqlserver/sqlserver_configs.py index b12c189af..5ff178ecb 100644 --- a/dbt/adapters/sqlserver/sqlserver_configs.py +++ b/dbt/adapters/sqlserver/sqlserver_configs.py @@ -8,6 +8,8 @@ @dataclass class SQLServerConfigs(AdapterConfig): auto_provision_aad_principals: Optional[bool] = False + prefer_single_alter_column: Optional[bool] = False + column_type_expansion_max_rows: Optional[int] = None indexes: Optional[Tuple[SQLServerIndexConfig, ...]] = None # false (default) | warn | true - how index reconciliation treats # droppable indexes dbt didn't create (YAML may supply bool or str) diff --git a/dbt/include/sqlserver/macros/adapters/catalog.sql b/dbt/include/sqlserver/macros/adapters/catalog.sql index 996f76206..55d07ab76 100644 --- a/dbt/include/sqlserver/macros/adapters/catalog.sql +++ b/dbt/include/sqlserver/macros/adapters/catalog.sql @@ -113,8 +113,7 @@ cast(ep.value as nvarchar(max)) as column_comment from sys.columns as c {{ information_schema_hints() }} left join sys.types as t {{ information_schema_hints() }} - on c.system_type_id = t.system_type_id - and c.user_type_id = t.user_type_id + on c.user_type_id = t.user_type_id left join sys.extended_properties as ep {{ information_schema_hints() }} on ep.class = 1 and ep.major_id = c.object_id diff --git a/dbt/include/sqlserver/macros/adapters/columns.sql b/dbt/include/sqlserver/macros/adapters/columns.sql index 313cc0f12..685ec94f6 100644 --- a/dbt/include/sqlserver/macros/adapters/columns.sql +++ b/dbt/include/sqlserver/macros/adapters/columns.sql @@ -38,27 +38,38 @@ {% macro sqlserver__alter_column_type(relation, column_name, new_column_type) %} - {%- set tmp_column = column_name + "__dbt_alter" -%} - {% set alter_column_type %} - alter {{ relation.type }} {{ relation }} add "{{ tmp_column }}" {{ new_column_type }}; - {%- endset %} + {% set prefer_single = config.get('prefer_single_alter_column', false) %} - {% set update_column %} - update {{ relation }} set "{{ tmp_column }}" = "{{ column_name }}"; - {%- endset %} + {% if prefer_single and relation.type == 'table' %} + {% set alter_sql %} + alter {{ relation.type }} {{ relation }} + alter column "{{ column_name }}" {{ new_column_type }}; + {%- endset %} + {% do run_query(alter_sql) %} - {% set drop_column %} - alter {{ relation.type }} {{ relation }} drop column "{{ column_name }}"; - {%- endset %} + {% else %} + {%- set tmp_column = column_name + "__dbt_alter" -%} - {% set rename_column %} - exec sp_rename '{{ relation | replace('"', '') }}.{{ tmp_column }}', '{{ column_name }}', 'column' - {%- endset %} + {% set add_column %} + alter {{ relation.type }} {{ relation }} + add "{{ tmp_column }}" {{ new_column_type }}; + {%- endset %} + {% set update_column %} + update {{ relation }} set "{{ tmp_column }}" = "{{ column_name }}"; + {%- endset %} + {% set drop_column %} + alter {{ relation.type }} {{ relation }} + drop column "{{ column_name }}"; + {%- endset %} + {% set rename_column %} + exec sp_rename '{{ relation | replace('"', '') }}.{{ tmp_column }}', '{{ column_name }}', 'column' + {%- endset %} - {% do run_query(alter_column_type) %} - {% do run_query(update_column) %} - {% do run_query(drop_column) %} - {% do run_query(rename_column) %} + {% do run_query(add_column) %} + {% do run_query(update_column) %} + {% do run_query(drop_column) %} + {% do run_query(rename_column) %} + {% endif %} {% endmacro %} diff --git a/dbt/include/sqlserver/macros/materializations/models/incremental/incremental.sql b/dbt/include/sqlserver/macros/materializations/models/incremental/incremental.sql index 251638c25..8c60ba257 100644 --- a/dbt/include/sqlserver/macros/materializations/models/incremental/incremental.sql +++ b/dbt/include/sqlserver/macros/materializations/models/incremental/incremental.sql @@ -42,9 +42,11 @@ {% set contract_config = config.get('contract') %} {% if not contract_config or not contract_config.enforced %} + {% set expansion_max_rows = config.get('column_type_expansion_max_rows', 1000000) %} {% do adapter.expand_target_column_types( from_relation=temp_relation, - to_relation=target_relation) %} + to_relation=target_relation, + max_rows=expansion_max_rows) %} {% endif %} {#-- Process schema changes. Returns dict of changes if successful. Use source columns for upserting/merging --#} {% set dest_columns = process_schema_changes(on_schema_change, temp_relation, existing_relation) %} diff --git a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql index c2093bdf5..ac601481d 100644 --- a/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql +++ b/dbt/include/sqlserver/macros/materializations/snapshots/snapshot.sql @@ -65,8 +65,10 @@ {% set build_or_select_sql = snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} {% set staging_table = build_snapshot_staging_table(strategy, temp_snapshot_relation, target_relation) %} -- this may no-op if the database does not require column expansion + {% set expansion_max_rows = config.get('column_type_expansion_max_rows', 1000000) %} {% do adapter.expand_target_column_types(from_relation=staging_table, - to_relation=target_relation) %} + to_relation=target_relation, + max_rows=expansion_max_rows) %} {% set remove_columns = ['dbt_change_type', 'DBT_CHANGE_TYPE', 'dbt_unique_key', 'DBT_UNIQUE_KEY'] %} {% if unique_key | is_list %} diff --git a/dbt/include/sqlserver/macros/relations/seeds/helpers.sql b/dbt/include/sqlserver/macros/relations/seeds/helpers.sql index 34c8e726d..46b59f0a1 100644 --- a/dbt/include/sqlserver/macros/relations/seeds/helpers.sql +++ b/dbt/include/sqlserver/macros/relations/seeds/helpers.sql @@ -22,27 +22,30 @@ {% macro sqlserver__load_csv_rows(model, agate_table) %} {% set cols_sql = get_seed_column_quoted_csv(model, agate_table.column_names) %} {% set batch_size = calc_batch_size(agate_table.column_names|length) %} - {% set bindings = [] %} {% set statements = [] %} {{ log("Inserting batches of " ~ batch_size ~ " records") }} {% for chunk in agate_table.rows | batch(batch_size) %} {% set bindings = [] %} + {% set values_clause = [] %} {% for row in chunk %} - {% do bindings.extend(row) %} + {% set row_values = [] %} + {% for column in agate_table.column_names %} + {%- set val = row[loop.index0] -%} + {%- if val is none -%} + {%- do row_values.append("null") -%} + {%- else -%} + {%- do row_values.append(get_binding_char()) -%} + {%- do bindings.append(val) -%} + {%- endif -%} + {% endfor %} + {% do values_clause.append("(" ~ row_values | join(", ") ~ ")") %} {% endfor %} {% set sql %} - insert into {{ this.render() }} ({{ cols_sql }}) values - {% for row in chunk -%} - ({%- for column in agate_table.column_names -%} - {{ get_binding_char() }} - {%- if not loop.last%},{%- endif %} - {%- endfor -%}) - {%- if not loop.last%},{%- endif %} - {%- endfor %} + insert into {{ this.render() }} ({{ cols_sql }}) values {{ values_clause | join(", ") }} {% endset %} {% do adapter.add_query(sql, bindings=bindings, abridge_sql_log=True) %} diff --git a/pyproject.toml b/pyproject.toml index 753d0270f..445f2cbce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,3 +108,6 @@ known-first-party = ["dbt"] [tool.ruff.lint.flake8-tidy-imports] ban-relative-imports = "all" + +[tool.ty.environment] +python = ".venv" diff --git a/tests/functional/adapter/dbt/test_catalog.py b/tests/functional/adapter/dbt/test_catalog.py index da0924334..3ff431d8b 100644 --- a/tests/functional/adapter/dbt/test_catalog.py +++ b/tests/functional/adapter/dbt/test_catalog.py @@ -154,3 +154,49 @@ def test_docs_generate_includes_non_default_database(self, project): assert "id" in other_node["columns"] finally: self.cleanup_secondary_database(project) + + +CATALOG_COLUMN_TYPES_SQL = """ +{{ config(materialized='table') }} +select + cast('hello' as nvarchar(50)) as nv_col, + cast('h' as nchar(1)) as nc_col, + cast(1 as int) as int_col +""" + + +class TestCatalogColumnTypes: + """ + This test addresses: https://github.com/dbt-msft/dbt-sqlserver/issues/637 + catalog.sql used system_type_id instead of user_type_id causing + NVARCHAR/NCHAR columns to appear as SYSNAME in dbt docs. + """ + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "catalog_column_types_test"} + + @pytest.fixture(scope="class") + def models(self): + return {"catalog_model.sql": CATALOG_COLUMN_TYPES_SQL} + + @pytest.fixture(scope="class") + def docs(self, project): + run_dbt(["run"]) + yield run_dbt(["docs", "generate"]) + + def test_catalog_does_not_return_sysname(self, project, docs): + catalog_path = os.path.join(project.project_root, "target", "catalog.json") + with open(catalog_path) as f: + catalog = json.load(f) + + nodes = catalog.get("nodes", {}) + for node_name, node in nodes.items(): + if "catalog_model" not in node_name: + continue + for col_name, col in node.get("columns", {}).items(): + col_type = col.get("type", "").lower() + assert "sysname" not in col_type, ( + f"Column '{col_name}' has type '{col_type}' " + f"which contains SYSNAME instead of NVARCHAR/NCHAR" + ) diff --git a/tests/functional/adapter/mssql/test_column_type_expansion.py b/tests/functional/adapter/mssql/test_column_type_expansion.py new file mode 100644 index 000000000..745acc89c --- /dev/null +++ b/tests/functional/adapter/mssql/test_column_type_expansion.py @@ -0,0 +1,274 @@ +"""Functional tests for column type expansion and addition +via the incremental materialization. + +Two scenarios tested with default and native string type flags: + 1. Column type expansion via expand_target_column_types + 2. Adding a new nvarchar column via on_schema_change (append / sync_all) +""" + +import os + +import pytest + +from dbt.tests.util import run_dbt + + +def _column_type(project, schema, table, column): + rows = project.run_sql( + f""" + select t.name, c.max_length + from [{project.database}].sys.columns c + inner join [{project.database}].sys.types t + on c.user_type_id = t.user_type_id + where c.object_id = object_id('[{project.database}].[{schema}].[{table}]') + and c.name = '{column}' + """, + fetch="all", + ) + if not rows: + return None + dtype, max_length = rows[0] + if dtype in ("nchar", "nvarchar", "sysname") and max_length != -1: + return (dtype, max_length // 2) + return (dtype, max_length) + + +def write_model(project, filename, contents): + path = os.path.join(project.project_root, "models", filename) + with open(path, "w") as f: + f.write(contents) + + +# --- Model SQL for expansion test --- + +EXPAND_V1 = """ +{{ config(materialized='incremental', unique_key='id') }} +select 1 as id, cast('hello' as varchar(10)) as str_col +""" + +EXPAND_V2 = """ +{{ config(materialized='incremental', unique_key='id') }} +select 1 as id, cast('hello world' as varchar(25)) as str_col +""" + +# --- Model SQL for add-column test --- + +ADD_COL_V1 = """ +{{ + config(materialized='incremental', unique_key='id', + on_schema_change='append_new_columns') +}} +select 1 as id, cast('hello' as varchar(10)) as str_col +""" + +ADD_COL_V2 = """ +{{ + config(materialized='incremental', unique_key='id', + on_schema_change='append_new_columns') +}} +select 1 as id, + cast('hello' as varchar(10)) as str_col, + cast('hello' as nvarchar(20)) as new_col +""" + +# --- Model SQL for sync-all-columns test --- + +SYNC_V1 = """ +{{ + config(materialized='incremental', unique_key='id', + on_schema_change='sync_all_columns') +}} +select 1 as id, cast('hello' as varchar(10)) as str_col +""" + +SYNC_V2 = """ +{{ + config(materialized='incremental', unique_key='id', + on_schema_change='sync_all_columns') +}} +select 1 as id, + cast('hello world' as varchar(25)) as str_col, + cast('hello' as nvarchar(20)) as new_col +""" + + +# ============================================================================ +# Default string types (dbt_sqlserver_use_native_string_types = false) +# ============================================================================ + + +class TestExpansionDefault: + @pytest.fixture(scope="class") + def models(self): + return {"expand_test.sql": EXPAND_V1} + + def test_varchar_size_expansion(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "expand_test.sql", EXPAND_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "expand_test", "str_col") + assert typ == ("varchar", 25), f"Expected varchar(25), got {typ}" + + +class TestAddColumnDefault: + """ + This test addresses: https://github.com/dbt-msft/dbt-sqlserver/issues/446 + """ + + @pytest.fixture(scope="class") + def models(self): + return {"add_col_test.sql": ADD_COL_V1} + + def test_add_nvarchar_column(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "add_col_test.sql", ADD_COL_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "add_col_test", "new_col") + assert typ == ("nvarchar", 20), f"Expected nvarchar(20), got {typ}" + + +class TestSyncColumnsDefault: + @pytest.fixture(scope="class") + def models(self): + return {"sync_test.sql": SYNC_V1} + + def test_sync_all_columns(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "sync_test.sql", SYNC_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "sync_test", "str_col") + assert typ == ("varchar", 25), f"Expected varchar(25), got {typ}" + + typ = _column_type(project, project.test_schema, "sync_test", "new_col") + assert typ == ("nvarchar", 20), f"Expected nvarchar(20), got {typ}" + + +# ============================================================================ +# Safe type expansion: varchar -> nvarchar (requires enable_safe flag) +# ============================================================================ + + +NVARCHAR_V1 = """ +{{ config(materialized='incremental', unique_key='id', + column_type_expansion_max_rows=10) }} +select 1 as id, cast('hello' as varchar(10)) as str_col +""" + +NVARCHAR_V2 = """ +{{ config(materialized='incremental', unique_key='id', + column_type_expansion_max_rows=10) }} +select 1 as id, cast('hi' as nvarchar(25)) as str_col +""" + + +class TestVarcharToNvarcharWithoutFlag: + @pytest.fixture(scope="class") + def models(self): + return {"nvarchar_test.sql": NVARCHAR_V1} + + def test_varchar_to_nvarchar_blocked_without_flag(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "nvarchar_test.sql", NVARCHAR_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "nvarchar_test", "str_col") + assert typ == ("varchar", 10), f"Expected varchar(10), got {typ}" + + +class TestVarcharToNvarcharWithFlag: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_enable_safe_type_expansion": True}} + + @pytest.fixture(scope="class") + def models(self): + return {"nvarchar_safe_test.sql": NVARCHAR_V1} + + def test_varchar_to_nvarchar_works_with_flag(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "nvarchar_safe_test.sql", NVARCHAR_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "nvarchar_safe_test", "str_col") + assert typ == ("nvarchar", 25), f"Expected nvarchar(25), got {typ}" + + +# ============================================================================ +# Native string types (dbt_sqlserver_use_native_string_types = true) +# ============================================================================ + + +class TestExpansionNative: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_native_string_types": True}} + + @pytest.fixture(scope="class") + def models(self): + return {"expand_test.sql": EXPAND_V1} + + def test_varchar_size_expansion_native(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "expand_test.sql", EXPAND_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "expand_test", "str_col") + assert typ == ("varchar", 25), f"Expected varchar(25), got {typ}" + + +class TestAddColumnNative: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_native_string_types": True}} + + @pytest.fixture(scope="class") + def models(self): + return {"add_col_test.sql": ADD_COL_V1} + + def test_add_nvarchar_column_native(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "add_col_test.sql", ADD_COL_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "add_col_test", "new_col") + assert typ == ("nvarchar", 20), f"Expected nvarchar(20), got {typ}" + + +class TestSyncColumnsNative: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_native_string_types": True}} + + @pytest.fixture(scope="class") + def models(self): + return {"sync_test.sql": SYNC_V1} + + def test_sync_all_columns_native(self, project): + run_dbt(["run", "--full-refresh"]) + write_model(project, "sync_test.sql", SYNC_V2) + results = run_dbt(["run"]) + assert len(results) == 1 + assert results[0].status == "success" + + typ = _column_type(project, project.test_schema, "sync_test", "str_col") + assert typ == ("varchar", 25), f"Expected varchar(25), got {typ}" + + typ = _column_type(project, project.test_schema, "sync_test", "new_col") + assert typ == ("nvarchar", 20), f"Expected nvarchar(20), got {typ}" diff --git a/tests/functional/adapter/mssql/test_max_string_expansion.py b/tests/functional/adapter/mssql/test_max_string_expansion.py new file mode 100644 index 000000000..00647c0b6 --- /dev/null +++ b/tests/functional/adapter/mssql/test_max_string_expansion.py @@ -0,0 +1,109 @@ +"""Functional tests for varchar(max) / nvarchar(max) column type expansion +via expand_target_column_types(). + +These tests verify physical adapter behaviour against real SQL Server catalog +metadata — complementing the unit-level comparison/rendering tests already in +test_can_expand_to.py and test_sqlserver_column.py. +""" + +import pytest + +from dbt.adapters.sqlserver.sqlserver_relation import ( + SQLServerRelation, + SQLServerRelationType, +) +from dbt.tests.util import run_dbt + +CURRENT_BOUNDED = """ + {{ config(materialized='table') }} + select + cast('abc' as varchar(100)) as varchar_col, + cast(N'abc' as nvarchar(100)) as nvarchar_col +""" + +GOAL_MAX = """ + {{ config(materialized='table') }} + select + cast('abc' as varchar(max)) as varchar_col, + cast(N'abc' as nvarchar(max)) as nvarchar_col +""" + +CURRENT_MAX = """ + {{ config(materialized='table') }} + select + cast('abc' as varchar(max)) as varchar_col, + cast(N'abc' as nvarchar(max)) as nvarchar_col +""" + +GOAL_BOUNDED = """ + {{ config(materialized='table') }} + select + cast('abc' as varchar(100)) as varchar_col, + cast(N'abc' as nvarchar(100)) as nvarchar_col +""" + + +def _table_relation(adapter, name: str): + """Build a table-typed relation from a model name.""" + credentials = adapter.config.credentials + return SQLServerRelation.create( + database=credentials.database, + schema=adapter.config.credentials.schema, + identifier=name, + type=SQLServerRelationType.Table, + ) + + +class TestSQLServerMaxStringTypeExpansion: + @pytest.fixture(scope="class") + def models(self): + return { + "current_bounded.sql": CURRENT_BOUNDED, + "goal_max.sql": GOAL_MAX, + "current_max.sql": CURRENT_MAX, + "goal_bounded.sql": GOAL_BOUNDED, + } + + @staticmethod + def _columns_by_name(adapter, relation): + return {col.name.lower(): col for col in adapter.get_columns_in_relation(relation)} + + @staticmethod + def _assert_max_string_columns(columns): + varchar_col = columns["varchar_col"] + assert varchar_col.dtype.lower() == "varchar" + assert int(varchar_col.char_size) == -1 + + nvarchar_col = columns["nvarchar_col"] + assert nvarchar_col.dtype.lower() == "nvarchar" + assert int(nvarchar_col.char_size) == -1 + + def test_bounded_strings_expand_to_max(self, project): + """bounded varchar(100)/nvarchar(100) expand to varchar(max)/nvarchar(max).""" + run_dbt(["run", "--select", "current_bounded"]) + run_dbt(["run", "--select", "goal_max"]) + + adapter = project.adapter + current_relation = _table_relation(adapter, "current_bounded") + goal_relation = _table_relation(adapter, "goal_max") + + with adapter.connection_named("__test"): + adapter.expand_target_column_types(goal_relation, current_relation) + columns = self._columns_by_name(adapter, current_relation) + + self._assert_max_string_columns(columns) + + def test_max_strings_do_not_shrink_to_bounded(self, project): + """varchar(max)/nvarchar(max) columns are not narrowed to bounded.""" + run_dbt(["run", "--select", "current_max"]) + run_dbt(["run", "--select", "goal_bounded"]) + + adapter = project.adapter + current_relation = _table_relation(adapter, "current_max") + goal_relation = _table_relation(adapter, "goal_bounded") + + with adapter.connection_named("__test"): + adapter.expand_target_column_types(goal_relation, current_relation) + columns = self._columns_by_name(adapter, current_relation) + + self._assert_max_string_columns(columns) diff --git a/tests/functional/adapter/mssql/test_mssql_seed.py b/tests/functional/adapter/mssql/test_mssql_seed.py index 8a8cdfe22..a2b21222a 100644 --- a/tests/functional/adapter/mssql/test_mssql_seed.py +++ b/tests/functional/adapter/mssql/test_mssql_seed.py @@ -40,3 +40,41 @@ def seeds(self): def test_large_seed(self, project): run_dbt(["seed"]) + + +seed_empty_numeric_csv = """x +123 + +456 +""" + +seed_empty_numeric_yml = """ +version: 2 +seeds: + - name: seed_empty_numeric + config: + column_types: + x: numeric(18, 0) +""" + + +class TestSeedNumericColumnWithEmptyRows: + """ + This test addresses: https://github.com/dbt-msft/dbt-sqlserver/issues/425 + """ + + @pytest.fixture(scope="class") + def project_config_update(self): + return {"name": "seed_empty_numeric_test"} + + @pytest.fixture(scope="class") + def seeds(self): + return { + "seed_empty_numeric.csv": seed_empty_numeric_csv, + "schema.yml": seed_empty_numeric_yml, + } + + def test_seed_numeric_column_with_empty_rows(self, project): + results = run_dbt(["seed"]) + assert len(results) == 1 + assert results[0].status == "success" diff --git a/tests/unit/adapters/mssql/test_can_expand_to.py b/tests/unit/adapters/mssql/test_can_expand_to.py new file mode 100644 index 000000000..b76a65d19 --- /dev/null +++ b/tests/unit/adapters/mssql/test_can_expand_to.py @@ -0,0 +1,160 @@ +import pytest + +from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn + + +def col_kwargs(dtype, char_size=None, numeric_precision=0, numeric_scale=0): + return { + "column": "c", + "dtype": dtype, + "char_size": char_size, + "numeric_precision": numeric_precision, + "numeric_scale": numeric_scale, + } + + +@pytest.mark.parametrize( + "src_kwargs,tgt_kwargs,expect_with_flag,expect_without_flag", + [ + # String same-family expansions always work + (col_kwargs("varchar", char_size=10), col_kwargs("varchar", char_size=100), True, True), + (col_kwargs("char", char_size=5), col_kwargs("char", char_size=20), True, True), + (col_kwargs("nvarchar", char_size=50), col_kwargs("nvarchar", char_size=200), True, True), + (col_kwargs("nchar", char_size=10), col_kwargs("nchar", char_size=30), True, True), + # String same-size does not expand + (col_kwargs("varchar", char_size=100), col_kwargs("varchar", char_size=100), False, False), + # String smaller target does not expand + (col_kwargs("varchar", char_size=100), col_kwargs("varchar", char_size=50), False, False), + # MAX string: MAX -> bounded is rejected (same-family) + (col_kwargs("varchar", char_size=-1), col_kwargs("varchar", char_size=8000), False, False), + # String cross-family (VARCHAR -> NVARCHAR) requires flag + (col_kwargs("varchar", char_size=10), col_kwargs("nvarchar", char_size=10), True, False), + (col_kwargs("char", char_size=5), col_kwargs("nchar", char_size=5), True, False), + # String cross-family reverse (NVARCHAR -> VARCHAR) never works + (col_kwargs("nvarchar", char_size=10), col_kwargs("varchar", char_size=10), False, False), + # MAX string: bounded -> MAX is allowed (same-family) + (col_kwargs("varchar", char_size=100), col_kwargs("varchar", char_size=-1), True, True), + (col_kwargs("nvarchar", char_size=100), col_kwargs("nvarchar", char_size=-1), True, True), + # MAX string: MAX -> bounded is rejected (same-family) + (col_kwargs("varchar", char_size=-1), col_kwargs("varchar", char_size=100), False, False), + ( + col_kwargs("nvarchar", char_size=-1), + col_kwargs("nvarchar", char_size=100), + False, + False, + ), + # MAX string: MAX -> MAX is not an expansion + (col_kwargs("varchar", char_size=-1), col_kwargs("varchar", char_size=-1), False, False), + # MAX string: varchar(max) -> nvarchar(max) allowed via safe expansion + (col_kwargs("varchar", char_size=-1), col_kwargs("nvarchar", char_size=-1), True, False), + # MAX string: varchar(max) -> nvarchar(n) rejected for all bounded n + ( + col_kwargs("varchar", char_size=-1), + col_kwargs("nvarchar", char_size=4000), + False, + False, + ), + # MAX string: nvarchar(max) -> nvarchar(4000) rejected (same-family MAX -> bounded) + ( + col_kwargs("nvarchar", char_size=-1), + col_kwargs("nvarchar", char_size=4000), + False, + False, + ), + # MAX string: varchar(n) -> nvarchar(max) allowed via safe expansion + (col_kwargs("varchar", char_size=8000), col_kwargs("nvarchar", char_size=-1), True, False), + # Integer family promotions require the feature flag + (col_kwargs("int"), col_kwargs("bigint"), True, False), + (col_kwargs("bit"), col_kwargs("tinyint"), True, False), + # Integer -> numeric widening requires the feature flag + (col_kwargs("int"), col_kwargs("numeric", numeric_precision=10), True, False), + # Integer -> numeric with insufficient integer digits must be rejected + ( + col_kwargs("int"), + col_kwargs("numeric", numeric_precision=10, numeric_scale=5), + False, + False, + ), + # Integer -> numeric with sufficient integer digits may be allowed + ( + col_kwargs("int"), + col_kwargs("numeric", numeric_precision=15, numeric_scale=5), + True, + False, + ), + # Numeric/decimal promotions: precision/scale must increase; flag required + ( + col_kwargs("numeric", numeric_precision=10, numeric_scale=2), + col_kwargs("numeric", numeric_precision=12, numeric_scale=4), + True, + False, + ), + ( + col_kwargs("numeric", numeric_precision=10, numeric_scale=2), + col_kwargs("numeric", numeric_precision=12, numeric_scale=1), + False, + False, + ), + # Numeric/decimal -> numeric/decimal with shrinking integer digits must be rejected + ( + col_kwargs("numeric", numeric_precision=10, numeric_scale=2), + col_kwargs("numeric", numeric_precision=12, numeric_scale=5), + False, + False, + ), + # Numeric/decimal -> numeric/decimal with sufficient integer digits may be allowed + ( + col_kwargs("numeric", numeric_precision=10, numeric_scale=2), + col_kwargs("numeric", numeric_precision=13, numeric_scale=5), + True, + False, + ), + # Fixed-money types (MONEY/SMALLMONEY) + ( + col_kwargs("smallmoney", numeric_precision=10, numeric_scale=4), + col_kwargs("money", numeric_precision=19, numeric_scale=4), + True, + False, + ), + ( + col_kwargs("money", numeric_precision=19, numeric_scale=4), + col_kwargs("numeric", numeric_precision=20, numeric_scale=4), + True, + False, + ), + # MONEY -> NUMERIC with same capacity (int_digits=15, scale=4) is not an expansion + ( + col_kwargs("money", numeric_precision=19, numeric_scale=4), + col_kwargs("numeric", numeric_precision=19, numeric_scale=4), + False, + False, + ), + # NUMERIC -> MONEY that would shrink precision should not be allowed + ( + col_kwargs("numeric", numeric_precision=20, numeric_scale=4), + col_kwargs("money", numeric_precision=19, numeric_scale=4), + False, + False, + ), + # MONEY -> NUMERIC with shrinking integer digits must be rejected + ( + col_kwargs("money", numeric_precision=19, numeric_scale=4), + col_kwargs("numeric", numeric_precision=19, numeric_scale=5), + False, + False, + ), + # MONEY -> NUMERIC with sufficient integer digits may be allowed + ( + col_kwargs("money", numeric_precision=19, numeric_scale=4), + col_kwargs("numeric", numeric_precision=20, numeric_scale=5), + True, + False, + ), + ], +) +def test_can_expand_parametrized(src_kwargs, tgt_kwargs, expect_with_flag, expect_without_flag): + src = SQLServerColumn(**src_kwargs) + tgt = SQLServerColumn(**tgt_kwargs) + + assert src.can_expand_to(tgt) is expect_without_flag + assert (src.can_expand_to(tgt) or src.can_expand_safe(tgt)) is expect_with_flag diff --git a/tests/unit/adapters/mssql/test_expand_column_types.py b/tests/unit/adapters/mssql/test_expand_column_types.py new file mode 100644 index 000000000..3aafc53b9 --- /dev/null +++ b/tests/unit/adapters/mssql/test_expand_column_types.py @@ -0,0 +1,168 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from dbt.adapters.sqlserver.sqlserver_adapter import SQLServerAdapter +from dbt.adapters.sqlserver.sqlserver_relation import SQLServerRelation + + +@pytest.fixture +def adapter(): + config = MagicMock() + config.flags = {} + config.project_name = "test" + config.credentials.type = "sqlserver" + mp_context = MagicMock() + adapter = SQLServerAdapter(config, mp_context) + adapter._get_row_count = MagicMock(return_value=0) # type: ignore + adapter.get_columns_in_relation = MagicMock(return_value=[]) + adapter.alter_column_type = MagicMock() + # behavior is an available_property — override via the underlying Behavior object + adapter.behavior.dbt_sqlserver_enable_safe_type_expansion = True # type: ignore[attr-defined] + return adapter + + +def make_rel(name="t"): + rel = MagicMock(spec=SQLServerRelation) + rel.__str__ = lambda s: f"test_schema.{name}" + return rel + + +class TestExpandColumnTypes: + def test_skips_row_count_when_max_rows_is_negative_one(self, adapter): + adapter.expand_column_types(make_rel("goal"), make_rel("current"), max_rows=-1) + adapter._get_row_count.assert_not_called() + + def test_blocks_safe_expansion_when_max_rows_is_zero(self, adapter): + adapter.get_columns_in_relation = MagicMock(return_value=[]) + adapter.alter_column_type = MagicMock() + + goal = make_rel("goal") + goal_col = MagicMock() + goal_col.name = "c" + goal_col.dtype = "nvarchar" + goal_col.is_string = MagicMock(return_value=True) + goal_col.is_number = MagicMock(return_value=True) + goal_col.string_size = MagicMock(return_value=20) + goal_col.string_type_instance = MagicMock(return_value="nvarchar(20)") + goal_col.data_type = "nvarchar(20)" + + current = make_rel("current") + current_col = MagicMock() + current_col.name = "c" + current_col.dtype = "varchar" + current_col.is_string = MagicMock(return_value=True) + current_col.is_number = MagicMock(return_value=True) + current_col.can_expand_to = MagicMock(return_value=False) + current_col.can_expand_safe = MagicMock(return_value=True) + + adapter.get_columns_in_relation.side_effect = lambda r: ( + [goal_col] if r is goal else [current_col] + ) + + with patch("dbt.adapters.sqlserver.sqlserver_adapter.logger"): + adapter.expand_column_types(goal, current, max_rows=0) + + adapter._get_row_count.assert_not_called() + adapter.alter_column_type.assert_not_called() + + def test_reads_row_count_when_within_limit(self, adapter): + adapter._get_row_count.return_value = 50 + adapter.expand_column_types(make_rel("goal"), make_rel("current"), max_rows=100) + adapter._get_row_count.assert_called_once() + + def test_emits_warning_when_row_count_exceeds_max(self, adapter): + adapter._get_row_count.return_value = 200 + with patch("dbt.adapters.sqlserver.sqlserver_adapter.logger") as logger: + adapter.expand_column_types(make_rel("goal"), make_rel("current"), max_rows=100) + adapter._get_row_count.assert_called_once() + logger.warning.assert_called_once() + + def test_expand_target_column_types_forwards_max_rows(self, adapter): + adapter.get_columns_in_relation = MagicMock(return_value=[]) + adapter.alter_column_type = MagicMock() + + goal = make_rel("goal") + current = make_rel("current") + max_rows = 500 + + with patch.object(adapter, "expand_column_types") as mock_expand: + adapter.expand_target_column_types(goal, current, max_rows=max_rows) + + mock_expand.assert_called_once_with(goal, current, max_rows) + + @pytest.mark.parametrize( + "dtype,expected_type", + [ + ("varchar", "varchar(max)"), + ("nvarchar", "nvarchar(max)"), + ], + ) + def test_bounded_to_max_emits_max(self, adapter, dtype, expected_type): + """bounded {dtype}(100) -> {dtype}(max) should emit {dtype}(max).""" + adapter.get_columns_in_relation = MagicMock(return_value=[]) + adapter.alter_column_type = MagicMock() + + goal = make_rel("goal") + goal_col = MagicMock() + goal_col.name = "c" + goal_col.dtype = dtype + goal_col.data_type = expected_type + goal_col.is_string = MagicMock(return_value=True) + goal_col.is_number = MagicMock(return_value=False) + goal_col.string_size = MagicMock(return_value=-1) + goal_col.string_type_instance = MagicMock(return_value=expected_type) + + current = make_rel("current") + current.database = "test_db" + current.schema = "test_schema" + current.identifier = "current" + current_col = MagicMock() + current_col.name = "c" + current_col.dtype = dtype + current_col.data_type = f"{dtype}(100)" + current_col.is_string = MagicMock(return_value=True) + current_col.is_number = MagicMock(return_value=False) + current_col.can_expand_to = MagicMock(return_value=True) + current_col.can_expand_safe = MagicMock(return_value=False) + + adapter.get_columns_in_relation.side_effect = lambda r: ( + [goal_col] if r is goal else [current_col] + ) + + adapter.expand_column_types(goal, current, max_rows=-1) + + goal_col.string_type_instance.assert_called_once_with(-1) + adapter.alter_column_type.assert_called_once_with(current, "c", expected_type) + + def test_varchar_max_to_bounded_does_not_expand(self, adapter): + """varchar(max) current, varchar(100) goal should not call alter_column_type().""" + adapter.get_columns_in_relation = MagicMock(return_value=[]) + adapter.alter_column_type = MagicMock() + + goal = make_rel("goal") + goal_col = MagicMock() + goal_col.name = "c" + goal_col.dtype = "varchar" + goal_col.data_type = "varchar(100)" + goal_col.is_string = MagicMock(return_value=True) + goal_col.is_number = MagicMock(return_value=False) + + current = make_rel("current") + current_col = MagicMock() + current_col.name = "c" + current_col.dtype = "varchar" + current_col.data_type = "varchar(max)" + current_col.is_string = MagicMock(return_value=True) + current_col.is_number = MagicMock(return_value=False) + current_col.string_size = MagicMock(return_value=-1) + current_col.can_expand_to = MagicMock(return_value=False) + current_col.can_expand_safe = MagicMock(return_value=False) + + adapter.get_columns_in_relation.side_effect = lambda r: ( + [goal_col] if r is goal else [current_col] + ) + + adapter.expand_column_types(goal, current, max_rows=-1) + + adapter.alter_column_type.assert_not_called() diff --git a/tests/unit/adapters/mssql/test_sqlserver_column.py b/tests/unit/adapters/mssql/test_sqlserver_column.py new file mode 100644 index 000000000..c1d5bd112 --- /dev/null +++ b/tests/unit/adapters/mssql/test_sqlserver_column.py @@ -0,0 +1,181 @@ +import pytest +from dbt_common.exceptions import DbtRuntimeError + +from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn + + +class TestSQLServerColumnIsString: + def test_varchar_is_string(self): + col = SQLServerColumn("c", "varchar", char_size=50) + assert col.is_string() is True + + def test_char_is_string(self): + col = SQLServerColumn("c", "char", char_size=10) + assert col.is_string() is True + + def test_nvarchar_is_string(self): + col = SQLServerColumn("c", "nvarchar", char_size=100) + assert col.is_string() is True + + def test_nchar_is_string(self): + col = SQLServerColumn("c", "nchar", char_size=20) + assert col.is_string() is True + + def test_int_is_not_string(self): + col = SQLServerColumn("c", "int") + assert col.is_string() is False + + def test_numeric_is_not_string(self): + col = SQLServerColumn("c", "numeric") + assert col.is_string() is False + + def test_varchar_max_is_max_string(self): + col = SQLServerColumn("c", "varchar", char_size=-1) + assert col.is_max_string() is True + + def test_nvarchar_max_is_max_string(self): + col = SQLServerColumn("c", "nvarchar", char_size=-1) + assert col.is_max_string() is True + + def test_char_is_not_max_string(self): + col = SQLServerColumn("c", "char", char_size=-1) + assert col.is_max_string() is False + + def test_nchar_is_not_max_string(self): + col = SQLServerColumn("c", "nchar", char_size=-1) + assert col.is_max_string() is False + + +class TestSQLServerColumnStringTypeInstance: + def test_varchar_default(self): + col = SQLServerColumn("c", "varchar") + result = col.string_type_instance(100) + assert result == "varchar(100)" + + def test_varchar_max_bounded(self): + col = SQLServerColumn("c", "varchar") + result = col.string_type_instance(0) + assert result == "varchar(8000)" + + def test_nvarchar(self): + col = SQLServerColumn("c", "nvarchar") + result = col.string_type_instance(200) + assert result == "nvarchar(200)" + + def test_nvarchar_max_bounded(self): + col = SQLServerColumn("c", "nvarchar") + result = col.string_type_instance(0) + assert result == "nvarchar(4000)" + + def test_nchar(self): + col = SQLServerColumn("c", "nchar") + result = col.string_type_instance(50) + assert result == "nchar(50)" + + def test_nchar_max_bounded(self): + col = SQLServerColumn("c", "nchar") + result = col.string_type_instance(0) + assert result == "nchar(1)" + + def test_char_default(self): + col = SQLServerColumn("c", "char") + result = col.string_type_instance(5) + assert result == "char(5)" + + result = col.string_type_instance(0) + assert result == "char(1)" + + def test_varchar_max_emits_varchar_max(self): + col = SQLServerColumn("c", "varchar") + result = col.string_type_instance(-1) + assert result == "varchar(max)" + + def test_nvarchar_max_emits_nvarchar_max(self): + col = SQLServerColumn("c", "nvarchar") + result = col.string_type_instance(-1) + assert result == "nvarchar(max)" + + def test_char_max_raises(self): + col = SQLServerColumn("c", "char") + with pytest.raises(DbtRuntimeError, match=r"char\(max\) is not a valid SQL Server type"): + col.string_type_instance(-1) + + def test_nchar_max_raises(self): + col = SQLServerColumn("c", "nchar") + with pytest.raises(DbtRuntimeError, match=r"nchar\(max\) is not a valid SQL Server type"): + col.string_type_instance(-1) + + +class TestSQLServerColumnDataType: + def test_varchar_data_type(self): + col = SQLServerColumn("c", "varchar", char_size=100) + assert col.data_type == "varchar(100)" + + def test_nvarchar_data_type(self): + col = SQLServerColumn("c", "nvarchar", char_size=200) + assert col.data_type == "nvarchar(200)" + + +class TestSQLServerColumnIsFixedNumeric: + def test_money(self): + col = SQLServerColumn("c", "money") + assert col.is_fixed_numeric() is True + + def test_smallmoney(self): + col = SQLServerColumn("c", "smallmoney") + assert col.is_fixed_numeric() is True + + def test_numeric_is_not_fixed(self): + col = SQLServerColumn("c", "numeric") + assert col.is_fixed_numeric() is False + + +class TestSQLServerColumnIsNumeric: + def test_numeric(self): + col = SQLServerColumn("c", "numeric") + assert col.is_numeric() is True + + def test_decimal(self): + col = SQLServerColumn("c", "decimal") + assert col.is_numeric() is True + + def test_money_is_numeric(self): + col = SQLServerColumn("c", "money") + assert col.is_numeric() is True + + def test_smallmoney_is_numeric(self): + col = SQLServerColumn("c", "smallmoney") + assert col.is_numeric() is True + + +class TestSQLServerColumnIsDecimalType: + def test_numeric_is_decimal(self): + col = SQLServerColumn("c", "numeric") + assert col.is_decimal_type() is True + + def test_decimal_is_decimal(self): + col = SQLServerColumn("c", "decimal") + assert col.is_decimal_type() is True + + def test_money_is_not_decimal(self): + col = SQLServerColumn("c", "money") + assert col.is_decimal_type() is False + + def test_smallmoney_is_not_decimal(self): + col = SQLServerColumn("c", "smallmoney") + assert col.is_decimal_type() is False + + +class TestSQLServerColumnStringSize: + def test_string_size_with_char_size(self): + col = SQLServerColumn("c", "varchar", char_size=100) + assert col.string_size() == 100 + + def test_string_size_none_char_size(self): + col = SQLServerColumn("c", "varchar") + assert col.string_size() == 8000 + + def test_string_size_raises_on_non_string(self): + col = SQLServerColumn("c", "int") + with pytest.raises(DbtRuntimeError, match="Called string_size"): + col.string_size()