diff --git a/src/Config/DatabasePrimitives/DatabaseObject.cs b/src/Config/DatabasePrimitives/DatabaseObject.cs index be1eff45ba..9548eda1ba 100644 --- a/src/Config/DatabasePrimitives/DatabaseObject.cs +++ b/src/Config/DatabasePrimitives/DatabaseObject.cs @@ -235,6 +235,16 @@ public bool IsAnyColumnNullable(List columnsToCheck) return null; } + + public virtual int? GetLengthForParam(string paramName) + { + if (Columns.TryGetValue(paramName, out ColumnDefinition? columnDefinition)) + { + return columnDefinition.Length; + } + + return null; + } } /// @@ -270,6 +280,7 @@ public class ColumnDefinition public bool IsNullable { get; set; } public bool IsReadOnly { get; set; } public object? DefaultValue { get; set; } + public int? Length { get; set; } public ColumnDefinition() { } diff --git a/src/Core/Models/DbConnectionParam.cs b/src/Core/Models/DbConnectionParam.cs index 9426f8fd49..0c2c54a5e0 100644 --- a/src/Core/Models/DbConnectionParam.cs +++ b/src/Core/Models/DbConnectionParam.cs @@ -10,11 +10,12 @@ namespace Azure.DataApiBuilder.Core.Models; /// public class DbConnectionParam { - public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbType = null) + public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbType = null, int? length = null) { Value = value; DbType = dbType; SqlDbType = sqlDbType; + Length = length; } /// @@ -31,4 +32,7 @@ public DbConnectionParam(object? value, DbType? dbType = null, SqlDbType? sqlDbT // This is being made nullable // because it's not populated for DB's other than MSSQL. public SqlDbType? SqlDbType { get; set; } + + // Nullable integer parameter representing length. nullable for back compatibility and for where its not needed + public int? Length { get; set; } } diff --git a/src/Core/Models/GraphQLFilterParsers.cs b/src/Core/Models/GraphQLFilterParsers.cs index 90deb884b3..9c21a89686 100644 --- a/src/Core/Models/GraphQLFilterParsers.cs +++ b/src/Core/Models/GraphQLFilterParsers.cs @@ -483,7 +483,7 @@ private static Predicate ParseScalarType( string schemaName, string tableName, string tableAlias, - Func processLiterals, + Func processLiterals, bool isListType = false) { Column column = new(schemaName, tableName, columnName: fieldName, tableAlias); @@ -611,7 +611,7 @@ public static Predicate Parse( IInputValueDefinition argumentSchema, Column column, List fields, - Func processLiterals, + Func processLiterals, bool isListType = false) { List predicates = new(); @@ -626,6 +626,7 @@ public static Predicate Parse( variables: ctx.Variables); bool processLiteral = true; + bool lengthOverride = false; if (value is null) { @@ -671,6 +672,7 @@ public static Predicate Parse( { op = PredicateOperation.LIKE; value = $"%{EscapeLikeString((string)value)}%"; + lengthOverride = true; } break; @@ -683,16 +685,19 @@ public static Predicate Parse( { op = PredicateOperation.NOT_LIKE; value = $"%{EscapeLikeString((string)value)}%"; + lengthOverride = true; } break; case "startsWith": op = PredicateOperation.LIKE; value = $"{EscapeLikeString((string)value)}%"; + lengthOverride = true; break; case "endsWith": op = PredicateOperation.LIKE; value = $"%{EscapeLikeString((string)value)}"; + lengthOverride = true; break; case "isNull": processLiteral = false; @@ -707,7 +712,7 @@ public static Predicate Parse( predicates.Push(new PredicateOperand(new Predicate( new PredicateOperand(column), op, - GenerateRightOperand(ctx, argumentObject, name, processLiterals, value, processLiteral) // right operand + GenerateRightOperand(ctx, argumentObject, name, column, processLiterals, value, processLiteral, lengthOverride) ))); } @@ -758,17 +763,21 @@ public static Predicate Parse( /// The GraphQL middleware context, used to resolve variable values. /// The input object type describing the argument schema. /// The name of the filter operation (e.g., "eq", "in"). + /// The target column, used to derive parameter type/size metadata. /// A function to encode or parameterize literal values for database queries. /// The value to be used as the right operand in the predicate. /// Indicates whether to process the value as a literal using processLiterals, or use its string representation directly. + /// When true, indicates the parameter length should not be constrained to the column length (used for LIKE operations). /// A representing the right operand for the predicate. private static PredicateOperand GenerateRightOperand( IMiddlewareContext ctx, InputObjectType argumentObject, string operationName, - Func processLiterals, + Column column, + Func processLiterals, object value, - bool processLiteral) + bool processLiteral, + bool lengthOverride) { if (operationName.Equals("in", StringComparison.OrdinalIgnoreCase)) { @@ -778,13 +787,13 @@ private static PredicateOperand GenerateRightOperand( argumentObject.Fields[operationName], ctx.Variables)) .Where(inValue => inValue is not null) - .Select(inValue => processLiterals(inValue!, null)) + .Select(inValue => processLiterals(inValue!, column.ColumnName, false)) .ToList(); return new PredicateOperand("(" + string.Join(", ", encodedParams) + ")"); } - return new PredicateOperand(processLiteral ? processLiterals(value, null) : value.ToString()); + return new PredicateOperand(processLiteral ? $"{processLiterals(value, column.ColumnName, lengthOverride)}" : value.ToString()); } private static string EscapeLikeString(string input) diff --git a/src/Core/Resolvers/BaseQueryStructure.cs b/src/Core/Resolvers/BaseQueryStructure.cs index 7f5564f831..67bab5258a 100644 --- a/src/Core/Resolvers/BaseQueryStructure.cs +++ b/src/Core/Resolvers/BaseQueryStructure.cs @@ -116,7 +116,7 @@ public BaseQueryStructure( /// /// Value to be assigned to parameter, which can be null for nullable columns. /// The name of the parameter - backing column name for table/views or parameter name for stored procedures. - public virtual string MakeDbConnectionParam(object? value, string? paramName = null) + public virtual string MakeDbConnectionParam(object? value, string? paramName = null, bool lengthOverride = false) { string encodedParamName = GetEncodedParamName(Counter.Next()); if (!string.IsNullOrEmpty(paramName)) @@ -124,7 +124,8 @@ public virtual string MakeDbConnectionParam(object? value, string? paramName = n Parameters.Add(encodedParamName, new(value, dbType: GetUnderlyingSourceDefinition().GetDbTypeForParam(paramName), - sqlDbType: GetUnderlyingSourceDefinition().GetSqlDbTypeForParam(paramName))); + sqlDbType: GetUnderlyingSourceDefinition().GetSqlDbTypeForParam(paramName), + length: lengthOverride ? -1 : GetUnderlyingSourceDefinition().GetLengthForParam(paramName))); } else { diff --git a/src/Core/Resolvers/CosmosQueryStructure.cs b/src/Core/Resolvers/CosmosQueryStructure.cs index 06558297d8..68d83557c0 100644 --- a/src/Core/Resolvers/CosmosQueryStructure.cs +++ b/src/Core/Resolvers/CosmosQueryStructure.cs @@ -68,7 +68,7 @@ public CosmosQueryStructure( } /// - public override string MakeDbConnectionParam(object? value, string? columnName = null) + public override string MakeDbConnectionParam(object? value, string? columnName = null, bool lengthOverride = false) { string encodedParamName = $"{PARAM_NAME_PREFIX}param{Counter.Next()}"; Parameters.Add(encodedParamName, new(value)); diff --git a/src/Core/Resolvers/MsSqlQueryExecutor.cs b/src/Core/Resolvers/MsSqlQueryExecutor.cs index 368e5d6b00..7037f6aa3a 100644 --- a/src/Core/Resolvers/MsSqlQueryExecutor.cs +++ b/src/Core/Resolvers/MsSqlQueryExecutor.cs @@ -624,8 +624,17 @@ public override SqlCommand PrepareDbCommand( { SqlParameter parameter = cmd.CreateParameter(); parameter.ParameterName = parameterEntry.Key; - parameter.Value = parameterEntry.Value.Value ?? DBNull.Value; + parameter.Value = parameterEntry.Value?.Value ?? DBNull.Value; + PopulateDbTypeForParameter(parameterEntry, parameter); + + //if sqldbtype is varchar, nvarchar then set the length when explicitly provided + if (parameter.SqlDbType is SqlDbType.VarChar or SqlDbType.NVarChar or SqlDbType.Char or SqlDbType.NChar + && parameterEntry.Value?.Length is not null) + { + parameter.Size = parameterEntry.Value.Length.Value; + } + cmd.Parameters.Add(parameter); } } diff --git a/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs b/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs index 12ae3ee993..2ecbac42af 100644 --- a/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs +++ b/src/Service.Tests/UnitTests/SerializationDeserializationTests.cs @@ -526,7 +526,7 @@ private static void VerifyColumnDefinitionSerializationDeserialization(ColumnDef { // test number of properties/fields defined in Column Definition int fields = typeof(ColumnDefinition).GetFields(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance).Length; - Assert.AreEqual(fields, 8); + Assert.AreEqual(fields, 9); // test values expectedColumnDefinition.Equals(deserializedColumnDefinition);