diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/EqualsMethodToFilterTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/EqualsMethodToFilterTranslator.cs index ec3271f443f..44f623b00df 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/EqualsMethodToFilterTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/EqualsMethodToFilterTranslator.cs @@ -13,6 +13,7 @@ * limitations under the License. */ +using System; using System.Linq.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Filters; using MongoDB.Driver.Linq.Linq3Implementation.ExtensionMethods; @@ -72,8 +73,31 @@ private static AstFilter Translate(TranslationContext context, Expression expres var fieldTranslation = ExpressionToFilterFieldTranslator.Translate(context, fieldExpression); var value = valueExpression.GetConstantValue(containingExpression: expression); + + var serializerValueType = fieldTranslation.Serializer.ValueType; + if (value != null && !serializerValueType.IsInstanceOfType(value)) + { + var targetType = Nullable.GetUnderlyingType(serializerValueType) ?? serializerValueType; + var valueType = value.GetType(); + if (IsNumericType(valueType) && IsNumericType(targetType)) + { + value = Convert.ChangeType(value, targetType); + } + } + var serializedValue = SerializationHelper.SerializeValue(fieldTranslation.Serializer, value); return AstFilter.Eq(fieldTranslation.Ast, serializedValue); } + + private static bool IsNumericType(Type type) => + Type.GetTypeCode(type) switch + { + TypeCode.Byte or TypeCode.SByte or + TypeCode.Int16 or TypeCode.UInt16 or + TypeCode.Int32 or TypeCode.UInt32 or + TypeCode.Int64 or TypeCode.UInt64 or + TypeCode.Single or TypeCode.Double or TypeCode.Decimal => true, + _ => false + }; } } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/EqualsMethodToFilterTranslatorTests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/EqualsMethodToFilterTranslatorTests.cs new file mode 100644 index 00000000000..9b0236a13f3 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Translators/ExpressionToFilterTranslators/MethodTranslators/EqualsMethodToFilterTranslatorTests.cs @@ -0,0 +1,95 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using FluentAssertions; +using MongoDB.Bson.Serialization; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Translators.ExpressionToFilterTranslators.MethodTranslators +{ + public class EqualsMethodToFilterTranslatorTests + { + private static readonly RenderArgs __args = new( + BsonSerializer.SerializerRegistry.GetSerializer(), + BsonSerializer.SerializerRegistry); + + [Fact] + public void Equals_with_uint64_and_nullable_int_should_translate() + { + ulong value = 2; + + var filter = Builders.Filter.Where(e => e.NullableIntegerProperty.Equals(value)); + + filter.Render(__args).Should().Be("{ NullableIntegerProperty : 2 }"); + } + + [Fact] + public void Equals_with_int_and_nullable_int_should_translate() + { + int value = 2; + + var filter = Builders.Filter.Where(e => e.NullableIntegerProperty.Equals(value)); + + filter.Render(__args).Should().Be("{ NullableIntegerProperty : 2 }"); + } + + [Fact] + public void Equals_with_null_should_translate() + { + var filter = Builders.Filter.Where(e => e.NullableIntegerProperty.Equals(null)); + + filter.Render(__args).Should().Be("{ NullableIntegerProperty : null }"); + } + + [Fact] + public void Equals_with_uint64_and_int_should_translate() + { + ulong value = 1; + + var filter = Builders.Filter.Where(e => e.IntegerProperty.Equals(value)); + + filter.Render(__args).Should().Be("{ IntegerProperty : 1 }"); + } + + [Fact] + public void Equals_with_string_and_nullable_int_should_throw() + { + var value = "2"; + + var filter = Builders.Filter.Where(e => e.NullableIntegerProperty.Equals(value)); + + var exception = Record.Exception(() => filter.Render(__args)); + exception.Should().NotBeNull(); + } + + [Fact] + public void Equals_with_overflowing_uint64_and_nullable_int_should_throw() + { + ulong value = (ulong)int.MaxValue + 1; + + var filter = Builders.Filter.Where(e => e.NullableIntegerProperty.Equals(value)); + + var exception = Record.Exception(() => filter.Render(__args)); + exception.Should().BeOfType(); + } + + public class TestClass + { + public int IntegerProperty { get; set; } + public int? NullableIntegerProperty { get; set; } + } + } +}