From 7bfa99bda261268be0ceabe00fa8ee6f0fa62859 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Tue, 28 Apr 2026 09:39:58 +0200 Subject: [PATCH 1/2] Migrate to slnx --- .github/workflows/pr-validation.yml | 8 +++---- DacFx.sln | 37 ----------------------------- DacFx.slnx | 5 ++++ 3 files changed, 9 insertions(+), 41 deletions(-) delete mode 100644 DacFx.sln create mode 100644 DacFx.slnx diff --git a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 75d6272..308a821 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -35,8 +35,8 @@ jobs: with: dotnet-version: ${{ matrix.dotnetVersion }} dotnet-quality: ga - - run: dotnet build DacFx.sln - - run: dotnet test DacFx.sln --no-build -f ${{ matrix.targetFramework }} + - run: dotnet build DacFx.slnx + - run: dotnet test DacFx.slnx --no-build -f ${{ matrix.targetFramework }} # Test SDK builds with full framework MSBuild on Windows, with SDK itself and against SSDT installation. 
msbuildTest: @@ -57,5 +57,5 @@ jobs: with: dotnet-version: ${{ env.LATEST_DOTNET_VERSION }} dotnet-quality: preview - - run: dotnet build DacFx.sln - - run: dotnet test DacFx.sln --no-build -f net472 \ No newline at end of file + - run: dotnet build DacFx.slnx + - run: dotnet test DacFx.slnx --no-build -f net472 \ No newline at end of file diff --git a/DacFx.sln b/DacFx.sln deleted file mode 100644 index 90dde50..0000000 --- a/DacFx.sln +++ /dev/null @@ -1,37 +0,0 @@ - -Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio Version 17 -VisualStudioVersion = 17.3.32929.385 -MinimumVisualStudioVersion = 10.0.40219.1 -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Build.Sql", "src\Microsoft.Build.Sql\Microsoft.Build.Sql.csproj", "{7C194D72-97FE-49BB-8D03-B3F7C9A91835}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Build.Sql.Tests", "test\Microsoft.Build.Sql.Tests\Microsoft.Build.Sql.Tests.csproj", "{4F50196D-E946-4A58-992E-C1AE25332CE2}" -EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.Build.Sql.Templates", "src\Microsoft.Build.Sql.Templates\Microsoft.Build.Sql.Templates.csproj", "{1AAE57EF-E614-4C58-8287-C830F098A7DA}" -EndProject -Global - GlobalSection(SolutionConfigurationPlatforms) = preSolution - Debug|Any CPU = Debug|Any CPU - Release|Any CPU = Release|Any CPU - EndGlobalSection - GlobalSection(ProjectConfigurationPlatforms) = postSolution - {7C194D72-97FE-49BB-8D03-B3F7C9A91835}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {7C194D72-97FE-49BB-8D03-B3F7C9A91835}.Debug|Any CPU.Build.0 = Debug|Any CPU - {7C194D72-97FE-49BB-8D03-B3F7C9A91835}.Release|Any CPU.ActiveCfg = Release|Any CPU - {7C194D72-97FE-49BB-8D03-B3F7C9A91835}.Release|Any CPU.Build.0 = Release|Any CPU - {4F50196D-E946-4A58-992E-C1AE25332CE2}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {4F50196D-E946-4A58-992E-C1AE25332CE2}.Debug|Any CPU.Build.0 = Debug|Any CPU - 
{4F50196D-E946-4A58-992E-C1AE25332CE2}.Release|Any CPU.ActiveCfg = Release|Any CPU - {4F50196D-E946-4A58-992E-C1AE25332CE2}.Release|Any CPU.Build.0 = Release|Any CPU - {1AAE57EF-E614-4C58-8287-C830F098A7DA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU - {1AAE57EF-E614-4C58-8287-C830F098A7DA}.Debug|Any CPU.Build.0 = Debug|Any CPU - {1AAE57EF-E614-4C58-8287-C830F098A7DA}.Release|Any CPU.ActiveCfg = Release|Any CPU - {1AAE57EF-E614-4C58-8287-C830F098A7DA}.Release|Any CPU.Build.0 = Release|Any CPU - EndGlobalSection - GlobalSection(SolutionProperties) = preSolution - HideSolutionNode = FALSE - EndGlobalSection - GlobalSection(ExtensibilityGlobals) = postSolution - SolutionGuid = {D45B5BBF-A8CB-4D39-A2EA-0FB386311429} - EndGlobalSection -EndGlobal diff --git a/DacFx.slnx b/DacFx.slnx new file mode 100644 index 0000000..94de990 --- /dev/null +++ b/DacFx.slnx @@ -0,0 +1,5 @@ + + + + + From c604f5e4586f3294ec533f68a3a4f7de1ffd5d99 Mon Sep 17 00:00:00 2001 From: Shay Rojansky Date: Mon, 27 Apr 2026 18:56:08 +0200 Subject: [PATCH 2/2] Drop in Microsoft.SqlServer.VectorData Moved from Semantic Kernel --- .github/workflows/pr-validation.yml | 14 +- .gitignore | 4 +- DacFx.slnx | 2 + Directory.Packages.props | 10 + .../CallerArgumentExpressionAttribute.cs | 32 + src/LegacySupport/Index.cs | 160 +++ src/LegacySupport/IsExternalInit.cs | 13 + src/LegacySupport/NullableAttributes.cs | 174 +++ .../RequiresDynamicCodeAttribute.cs | 49 + .../RequiresUnreferencedCodeAttribute.cs | 50 + src/LegacySupport/UnreachableException.cs | 48 + src/Microsoft.Build.Sql/VersionCheckTask.cs | 7 +- .../AssemblyInfo.cs | 1 + src/Microsoft.SqlServer.VectorData/README.md | 50 + .../SqlFilterTranslator.cs | 372 ++++++ .../SqlServer.csproj | 42 + .../SqlServerCollection.cs | 889 +++++++++++++ .../SqlServerCollectionOptions.cs | 30 + .../SqlServerCommandBuilder.cs | 1106 +++++++++++++++++ .../SqlServerConstants.cs | 13 + .../SqlServerDynamicCollection.cs | 36 + .../SqlServerFilterTranslator.cs | 174 +++ 
.../SqlServerJsonSerializerContext.cs | 11 + .../SqlServerMapper.cs | 151 +++ .../SqlServerModelBuilder.cs | 106 ++ .../SqlServerServiceCollectionExtensions.cs | 195 +++ .../SqlServerVectorStore.cs | 149 +++ .../SqlServerVectorStoreOptions.cs | 36 + src/Microsoft.SqlServer.VectorData/Throw.cs | 72 ++ .../VectorStoreErrorHandler.cs | 256 ++++ .../ModelTests/SqlServerBasicModelTests.cs | 77 ++ .../ModelTests/SqlServerDynamicModelTests.cs | 17 + .../SqlServerMultiVectorModelTests.cs | 17 + .../ModelTests/SqlServerNoDataModelTests.cs | 17 + .../ModelTests/SqlServerNoVectorModelTests.cs | 17 + .../Properties/AssemblyAttributes.cs | 1 + .../README.md | 52 + .../SqlServer.ConformanceTests.csproj | 38 + .../SqlServerCollectionManagementTests.cs | 17 + .../SqlServerCommandBuilderTests.cs | 743 +++++++++++ .../SqlServerDependencyInjectionTests.cs | 84 ++ .../SqlServerDistanceFunctionTests.cs | 23 + .../SqlServerEmbeddingGenerationTests.cs | 53 + .../SqlServerFilterTests.cs | 46 + .../SqlServerHybridSearchTests.cs | 26 + .../SqlServerIndexKindTests.cs | 109 ++ .../SqlServerTestSuiteImplementationTests.cs | 14 + .../Support/AzureSqlRequiredAttribute.cs | 38 + .../Support/SqlServerTestEnvironment.cs | 27 + .../Support/SqlServerTestStore.cs | 108 ++ .../TypeTests/SqlServerDataTypeTests.cs | 25 + .../TypeTests/SqlServerEmbeddingTypeTests.cs | 27 + .../TypeTests/SqlServerKeyTypeTests.cs | 26 + .../testsettings.json | 8 + 54 files changed, 5858 insertions(+), 4 deletions(-) create mode 100644 src/LegacySupport/CallerArgumentExpressionAttribute.cs create mode 100644 src/LegacySupport/Index.cs create mode 100644 src/LegacySupport/IsExternalInit.cs create mode 100644 src/LegacySupport/NullableAttributes.cs create mode 100644 src/LegacySupport/RequiresDynamicCodeAttribute.cs create mode 100644 src/LegacySupport/RequiresUnreferencedCodeAttribute.cs create mode 100644 src/LegacySupport/UnreachableException.cs create mode 100644 src/Microsoft.SqlServer.VectorData/AssemblyInfo.cs create 
mode 100644 src/Microsoft.SqlServer.VectorData/README.md create mode 100644 src/Microsoft.SqlServer.VectorData/SqlFilterTranslator.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServer.csproj create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerCollection.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerCollectionOptions.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerCommandBuilder.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerConstants.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerDynamicCollection.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerFilterTranslator.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerJsonSerializerContext.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerMapper.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerModelBuilder.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerServiceCollectionExtensions.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerVectorStore.cs create mode 100644 src/Microsoft.SqlServer.VectorData/SqlServerVectorStoreOptions.cs create mode 100644 src/Microsoft.SqlServer.VectorData/Throw.cs create mode 100644 src/Microsoft.SqlServer.VectorData/VectorStoreErrorHandler.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerBasicModelTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerDynamicModelTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerMultiVectorModelTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerNoDataModelTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerNoVectorModelTests.cs create mode 100644 
test/Microsoft.SqlServer.VectorData.ConformanceTests/Properties/AssemblyAttributes.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/README.md create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServer.ConformanceTests.csproj create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerCollectionManagementTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerCommandBuilderTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerDependencyInjectionTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerDistanceFunctionTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerEmbeddingGenerationTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerFilterTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerHybridSearchTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerIndexKindTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerTestSuiteImplementationTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/AzureSqlRequiredAttribute.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/SqlServerTestEnvironment.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/SqlServerTestStore.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerDataTypeTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerEmbeddingTypeTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerKeyTypeTests.cs create mode 100644 test/Microsoft.SqlServer.VectorData.ConformanceTests/testsettings.json diff --git 
a/.github/workflows/pr-validation.yml b/.github/workflows/pr-validation.yml index 308a821..ea25f76 100644 --- a/.github/workflows/pr-validation.yml +++ b/.github/workflows/pr-validation.yml @@ -36,7 +36,12 @@ jobs: dotnet-version: ${{ matrix.dotnetVersion }} dotnet-quality: ga - run: dotnet build DacFx.slnx - - run: dotnet test DacFx.slnx --no-build -f ${{ matrix.targetFramework }} + + # The tests for Microsoft.SqlServer.VectorData currently require an Azure SQL instance, since on-premise + # SQL Server 2025 doesn't have the latest vector support. Once it does, we can turn on the tests here + # (they use testcontainers). + # - run: dotnet test DacFx.slnx --no-build -f ${{ matrix.targetFramework }} + - run: dotnet test test/Microsoft.Build.Sql.Tests --no-build -f ${{ matrix.targetFramework }} # Test SDK builds with full framework MSBuild on Windows, with SDK itself and against SSDT installation. msbuildTest: @@ -58,4 +63,9 @@ jobs: dotnet-version: ${{ env.LATEST_DOTNET_VERSION }} dotnet-quality: preview - run: dotnet build DacFx.slnx - - run: dotnet test DacFx.slnx --no-build -f net472 \ No newline at end of file + + # The tests for Microsoft.SqlServer.VectorData currently require an Azure SQL instance, since on-premise + # SQL Server 2025 doesn't have the latest vector support. Once it does, we can turn on the tests here + # (they use testcontainers). 
+ # - run: dotnet test DacFx.slnx --no-build -f net472 + - run: dotnet test test/Microsoft.Build.Sql.Tests --no-build -f net472 \ No newline at end of file diff --git a/.gitignore b/.gitignore index af16f14..c549502 100644 --- a/.gitignore +++ b/.gitignore @@ -353,4 +353,6 @@ MigrationBackup/ src/Microsoft.Build.Sql/[Tt]ools/ ## Ignore packages generated for testing -test/Microsoft.Build.Sql.Tests/pkg/ \ No newline at end of file +test/Microsoft.Build.Sql.Tests/pkg/ + +*.lscache diff --git a/DacFx.slnx b/DacFx.slnx index 94de990..c9d3afc 100644 --- a/DacFx.slnx +++ b/DacFx.slnx @@ -1,5 +1,7 @@ + + diff --git a/Directory.Packages.props b/Directory.Packages.props index 3f7f61d..2cc9e77 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -13,6 +13,8 @@ + + @@ -23,11 +25,19 @@ + + + + + + + + \ No newline at end of file diff --git a/src/LegacySupport/CallerArgumentExpressionAttribute.cs b/src/LegacySupport/CallerArgumentExpressionAttribute.cs new file mode 100644 index 0000000..364f9ef --- /dev/null +++ b/src/LegacySupport/CallerArgumentExpressionAttribute.cs @@ -0,0 +1,32 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable IDE0079 +#pragma warning disable SA1101 +#pragma warning disable SA1512 + +using System.Diagnostics.CodeAnalysis; + +namespace System.Runtime.CompilerServices; + +/// +/// Tags parameter that should be filled with specific caller name. +/// +[AttributeUsage(AttributeTargets.Parameter, AllowMultiple = false, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class CallerArgumentExpressionAttribute : Attribute +{ + /// + /// Initializes a new instance of the class. + /// + /// Function parameter to take the name from. + public CallerArgumentExpressionAttribute(string parameterName) + { + ParameterName = parameterName; + } + + /// + /// Gets name of the function parameter that name should be taken from. 
+ /// + public string ParameterName { get; } +} diff --git a/src/LegacySupport/Index.cs b/src/LegacySupport/Index.cs new file mode 100644 index 0000000..fa276aa --- /dev/null +++ b/src/LegacySupport/Index.cs @@ -0,0 +1,160 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +#pragma warning disable CS0436 // Type conflicts with imported type +#pragma warning disable S3427 // Method overloads with default parameter values should not overlap +#pragma warning disable SA1642 // Constructor summary documentation should begin with standard text +#pragma warning disable IDE0011 // Add braces +#pragma warning disable SA1623 // Property summary documentation should match accessors +#pragma warning disable IDE0023 // Use block body for conversion operator +#pragma warning disable S3928 // Parameter names used into ArgumentException constructors should match an existing one +#pragma warning disable LA0001 // Use the 'Microsoft.Shared.Diagnostics.Throws' class instead of explicitly throwing exception for improved performance +#pragma warning disable CA1305 // Specify IFormatProvider + +namespace System +{ + internal readonly struct Index : IEquatable + { + private readonly int _value; + + /// Construct an Index using a value and indicating if the index is from the start or from the end. + /// The index value. it has to be zero or positive number. + /// Indicating if the index is from the start or from the end. + /// + /// If the Index constructed from the end, index value 1 means pointing at the last element and index value 0 means pointing at beyond last element. 
+ /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public Index(int value, bool fromEnd = false) + { + if (value < 0) + { + ThrowValueArgumentOutOfRange_NeedNonNegNumException(); + } + + if (fromEnd) + _value = ~value; + else + _value = value; + } + + // The following private constructor exists mainly for performance reasons, to avoid the checks + private Index(int value) + { + _value = value; + } + + /// Create an Index pointing at first element. + public static Index Start => new Index(0); + + /// Create an Index pointing at beyond last element. + public static Index End => new Index(~0); + + /// Create an Index from the start at the position indicated by the value. + /// The index value from the start. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Index FromStart(int value) + { + if (value < 0) + { + ThrowValueArgumentOutOfRange_NeedNonNegNumException(); + } + + return new Index(value); + } + + /// Create an Index from the end at the position indicated by the value. + /// The index value from the end. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Index FromEnd(int value) + { + if (value < 0) + { + ThrowValueArgumentOutOfRange_NeedNonNegNumException(); + } + + return new Index(~value); + } + + /// Returns the index value. + public int Value + { + get + { + if (_value < 0) + return ~_value; + else + return _value; + } + } + + /// Indicates whether the index is from the start or the end. + public bool IsFromEnd => _value < 0; + + /// Calculate the offset from the start using the given collection length. + /// The length of the collection that the Index will be used with. length has to be a positive value. + /// + /// For performance reasons, we don't validate the input length parameter and the returned offset value against negative values. + /// We also don't validate whether the returned offset is greater than the input length. 
+ /// It is expected Index will be used with collections which always have non-negative length/count. If the returned offset is negative and + /// is then used to index a collection, it will cause an out-of-range exception, which has the same effect as the validation. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public int GetOffset(int length) + { + int offset = _value; + if (IsFromEnd) + { + // offset = length - (~value) + // offset = length + (~(~value) + 1) + // offset = length + value + 1 + + offset += length + 1; + } + + return offset; + } + + /// Indicates whether the current Index object is equal to another object of the same type. + /// An object to compare with this object. + public override bool Equals([NotNullWhen(true)] object? value) => value is Index && _value == ((Index)value)._value; + + /// Indicates whether the current Index object is equal to another Index object. + /// An object to compare with this object. + public bool Equals(Index other) => _value == other._value; + + /// Returns the hash code for this instance. + public override int GetHashCode() => _value; + + /// Converts integer number to an Index. + public static implicit operator Index(int value) => FromStart(value); + + /// Converts the value of the current Index object to its equivalent string representation. 
+ public override string ToString() + { + if (IsFromEnd) + return ToStringFromEnd(); + + return ((uint)Value).ToString(); + } + + private static void ThrowValueArgumentOutOfRange_NeedNonNegNumException() + { + throw new ArgumentOutOfRangeException("value", "value must be non-negative"); + } + + private string ToStringFromEnd() + { +#if (!NETSTANDARD2_0 && !NETFRAMEWORK) + Span span = stackalloc char[11]; // 1 for ^ and 10 for longest possible uint value + bool formatted = ((uint)Value).TryFormat(span.Slice(1), out int charsWritten); + span[0] = '^'; + return new string(span.Slice(0, charsWritten + 1)); +#else + return '^' + Value.ToString(); +#endif + } + } +} diff --git a/src/LegacySupport/IsExternalInit.cs b/src/LegacySupport/IsExternalInit.cs new file mode 100644 index 0000000..4e1b8ba --- /dev/null +++ b/src/LegacySupport/IsExternalInit.cs @@ -0,0 +1,13 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable IDE0079 +#pragma warning disable S3903 + +/* This enables support for C# 9/10 records on older frameworks */ + +namespace System.Runtime.CompilerServices; + +internal static class IsExternalInit +{ +} diff --git a/src/LegacySupport/NullableAttributes.cs b/src/LegacySupport/NullableAttributes.cs new file mode 100644 index 0000000..5fe5f9e --- /dev/null +++ b/src/LegacySupport/NullableAttributes.cs @@ -0,0 +1,174 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+ +#pragma warning disable IDE0079 +#pragma warning disable SA1101 +#pragma warning disable SA1116 +#pragma warning disable SA1117 +#pragma warning disable SA1402 +#pragma warning disable SA1512 +#pragma warning disable SA1623 +#pragma warning disable SA1642 +#pragma warning disable SA1649 +#pragma warning disable S3903 +#pragma warning disable IDE0021 // Use block body for constructors +#pragma warning disable CA1019 + +namespace System.Diagnostics.CodeAnalysis; + +#if !NETCOREAPP3_1_OR_GREATER +/// Specifies that null is allowed as an input even if the corresponding type disallows it. +[AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class AllowNullAttribute : Attribute +{ +} + +/// Specifies that null is disallowed as an input even if the corresponding type allows it. +[AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class DisallowNullAttribute : Attribute +{ +} + +/// Specifies that an output may be null even if the corresponding type disallows it. +[AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class MaybeNullAttribute : Attribute +{ +} + +/// Specifies that an output will not be null even if the corresponding type allows it. Specifies that an input argument was not null when the call returns. +[AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class NotNullAttribute : Attribute +{ +} + +/// Specifies that when a method returns , the parameter may be null even if the corresponding type disallows it. 
+[AttributeUsage(AttributeTargets.Parameter, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class MaybeNullWhenAttribute : Attribute +{ + /// Initializes the attribute with the specified return value condition. + /// + /// The return value condition. If the method returns this value, the associated parameter may be . + /// + public MaybeNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + + /// Gets the return value condition. + public bool ReturnValue { get; } +} + +/// Specifies that when a method returns , the parameter will not be null even if the corresponding type allows it. +[AttributeUsage(AttributeTargets.Parameter, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class NotNullWhenAttribute : Attribute +{ + /// Initializes the attribute with the specified return value condition. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be . + /// + public NotNullWhenAttribute(bool returnValue) => ReturnValue = returnValue; + + /// Gets the return value condition. + public bool ReturnValue { get; } +} + +/// Specifies that the method or property will ensure that the listed field and property members have not-null values. +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)] +[ExcludeFromCodeCoverage] +internal sealed class MemberNotNullAttribute : Attribute +{ + /// Initializes the attribute with a field or property member. + /// + /// The field or property member that is promised to be not-null. + /// + public MemberNotNullAttribute(string member) => Members = [member]; + + /// Initializes the attribute with the list of field and property members. + /// + /// The list of field and property members that are promised to be not-null. + /// + public MemberNotNullAttribute(params string[] members) => Members = members; + + /// Gets field or property member names. 
+ public string[] Members { get; } +} + +/// Specifies that the method or property will ensure that the listed field and property members have not-null values when returning with the specified return value condition. +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)] +[ExcludeFromCodeCoverage] +internal sealed class MemberNotNullWhenAttribute : Attribute +{ + /// Initializes the attribute with the specified return value condition and a field or property member. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be . + /// + /// + /// The field or property member that is promised to be not-null. + /// + public MemberNotNullWhenAttribute(bool returnValue, string member) + { + ReturnValue = returnValue; + Members = [member]; + } + + /// Initializes the attribute with the specified return value condition and list of field and property members. + /// + /// The return value condition. If the method returns this value, the associated parameter will not be . + /// + /// + /// The list of field and property members that are promised to be not-null. + /// + public MemberNotNullWhenAttribute(bool returnValue, params string[] members) + { + ReturnValue = returnValue; + Members = members; + } + + /// Gets the return value condition. + public bool ReturnValue { get; } + + /// Gets field or property member names. + public string[] Members { get; } +} + +/// Specifies that the output will be non-null if the named parameter is non-null. +[AttributeUsage(AttributeTargets.Parameter | AttributeTargets.Property | AttributeTargets.ReturnValue, AllowMultiple = true, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class NotNullIfNotNullAttribute : Attribute +{ + /// Initializes the attribute with the associated parameter name. + /// + /// The associated parameter name. The output will be non-null if the argument to the parameter specified is non-null. 
+ /// + public NotNullIfNotNullAttribute(string parameterName) => ParameterName = parameterName; + + /// Gets the associated parameter name. + public string ParameterName { get; } +} + +/// Applied to a method that will never return under any circumstance. +[AttributeUsage(AttributeTargets.Method, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class DoesNotReturnAttribute : Attribute +{ +} + +/// Specifies that the method will not return if the associated Boolean parameter is passed the specified value. +[AttributeUsage(AttributeTargets.Parameter, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class DoesNotReturnIfAttribute : Attribute +{ + /// Initializes the attribute with the specified parameter value. + /// + /// The condition parameter value. Code after the method will be considered unreachable by diagnostics if the argument to + /// the associated parameter matches this value. + /// + public DoesNotReturnIfAttribute(bool parameterValue) => ParameterValue = parameterValue; + + /// Gets the condition parameter value. + public bool ParameterValue { get; } +} +#endif diff --git a/src/LegacySupport/RequiresDynamicCodeAttribute.cs b/src/LegacySupport/RequiresDynamicCodeAttribute.cs new file mode 100644 index 0000000..072701f --- /dev/null +++ b/src/LegacySupport/RequiresDynamicCodeAttribute.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable IDE0079 +#pragma warning disable SA1101 +#pragma warning disable SA1116 +#pragma warning disable SA1117 +#pragma warning disable SA1512 +#pragma warning disable SA1623 +#pragma warning disable SA1642 +#pragma warning disable S3903 +#pragma warning disable S3996 + +namespace System.Diagnostics.CodeAnalysis; + +/// +/// Indicates that the specified method requires the ability to generate new code at runtime, +/// for example through . 
+/// +/// +/// This allows tools to understand which methods are unsafe to call when compiling ahead of time. +/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Class, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class RequiresDynamicCodeAttribute : Attribute +{ + /// + /// Initializes a new instance of the class + /// with the specified message. + /// + /// + /// A message that contains information about the usage of dynamic code. + /// + public RequiresDynamicCodeAttribute(string message) + { + Message = message; + } + + /// + /// Gets a message that contains information about the usage of dynamic code. + /// + public string Message { get; } + + /// + /// Gets or sets an optional URL that contains more information about the method, + /// why it requires dynamic code, and what options a consumer has to deal with it. + /// + public string? Url { get; set; } +} diff --git a/src/LegacySupport/RequiresUnreferencedCodeAttribute.cs b/src/LegacySupport/RequiresUnreferencedCodeAttribute.cs new file mode 100644 index 0000000..6ee4305 --- /dev/null +++ b/src/LegacySupport/RequiresUnreferencedCodeAttribute.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +#pragma warning disable IDE0079 +#pragma warning disable SA1101 +#pragma warning disable SA1116 +#pragma warning disable SA1117 +#pragma warning disable SA1512 +#pragma warning disable SA1623 +#pragma warning disable SA1642 +#pragma warning disable S3903 +#pragma warning disable S3996 + +namespace System.Diagnostics.CodeAnalysis; + +/// +/// Indicates that the specified method requires dynamic access to code that is not referenced +/// statically, for example through . +/// +/// +/// This allows tools to understand which methods are unsafe to call when removing unreferenced +/// code from an application. 
+/// +[AttributeUsage(AttributeTargets.Method | AttributeTargets.Constructor | AttributeTargets.Class, Inherited = false)] +[ExcludeFromCodeCoverage] +internal sealed class RequiresUnreferencedCodeAttribute : Attribute +{ + /// + /// Initializes a new instance of the class + /// with the specified message. + /// + /// + /// A message that contains information about the usage of unreferenced code. + /// + public RequiresUnreferencedCodeAttribute(string message) + { + Message = message; + } + + /// + /// Gets a message that contains information about the usage of unreferenced code. + /// + public string Message { get; } + + /// + /// Gets or sets an optional URL that contains more information about the method, + /// why it requires unreferenced code, and what options a consumer has to deal with it. + /// + public string? Url { get; set; } +} diff --git a/src/LegacySupport/UnreachableException.cs b/src/LegacySupport/UnreachableException.cs new file mode 100644 index 0000000..702dd43 --- /dev/null +++ b/src/LegacySupport/UnreachableException.cs @@ -0,0 +1,48 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; + +#pragma warning disable CA1064 // Exceptions should be public +#pragma warning disable CA1812 // Internal class that is (sometimes) never instantiated. + +namespace System.Diagnostics; + +/// +/// Exception thrown when the program executes an instruction that was thought to be unreachable. +/// +[ExcludeFromCodeCoverage] +internal sealed class UnreachableException : Exception +{ + private const string MessageText = "The program executed an instruction that was thought to be unreachable."; + + /// + /// Initializes a new instance of the class with the default error message. + /// + public UnreachableException() + : base(MessageText) + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message. 
+ /// + /// The error message that explains the reason for the exception. + public UnreachableException(string? message) + : base(message ?? MessageText) + { + } + + /// + /// Initializes a new instance of the + /// class with a specified error message and a reference to the inner exception that is the cause of + /// this exception. + /// + /// The error message that explains the reason for the exception. + /// The exception that is the cause of the current exception. + public UnreachableException(string? message, Exception? innerException) + : base(message ?? MessageText, innerException) + { + } +} diff --git a/src/Microsoft.Build.Sql/VersionCheckTask.cs b/src/Microsoft.Build.Sql/VersionCheckTask.cs index bc73085..eab1155 100644 --- a/src/Microsoft.Build.Sql/VersionCheckTask.cs +++ b/src/Microsoft.Build.Sql/VersionCheckTask.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Net.Http; using System.Threading; using Microsoft.Build.Framework; using NuGet.Versioning; @@ -47,7 +48,11 @@ public override bool Execute() Log.LogMessage(MessageImportance.Low, $"Already using the latest version of {PackageName}: {currentVersion}."); } } - catch (Exception ex) + catch (OperationCanceledException) + { + // Build was canceled or version check timed out + } + catch (Exception ex) when (ex is HttpRequestException or InvalidOperationException) { Log.LogMessage(MessageImportance.Low, $"Failed to check for the latest version of {PackageName} on NuGet: {ex.Message}"); } diff --git a/src/Microsoft.SqlServer.VectorData/AssemblyInfo.cs b/src/Microsoft.SqlServer.VectorData/AssemblyInfo.cs new file mode 100644 index 0000000..cbb67c1 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/AssemblyInfo.cs @@ -0,0 +1 @@ +// Copyright (c) Microsoft. All rights reserved. 
diff --git a/src/Microsoft.SqlServer.VectorData/README.md b/src/Microsoft.SqlServer.VectorData/README.md new file mode 100644 index 0000000..aa65d48 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/README.md @@ -0,0 +1,50 @@ +# Microsoft.SqlServer.VectorData + +SQL Server and Azure SQL provider for [Microsoft.Extensions.VectorData](https://learn.microsoft.com/en-us/dotnet/ai/vector-stores/overview). + +## Usage + +```csharp +using Microsoft.Extensions.VectorData; +using Microsoft.SqlServer.VectorData; + +// Define your record model +public sealed class BlogPost +{ + [VectorStoreKey] + public int Id { get; set; } + + [VectorStoreData] + public string? Title { get; set; } + + [VectorStoreData] + public string? Url { get; set; } + + [VectorStoreData] + public string? Content { get; set; } + + [VectorStoreVector(Dimensions: 1536)] + public ReadOnlyMemory<float> ContentEmbedding { get; set; } +} + +// Create the vector store and get a collection +var vectorStore = new SqlServerVectorStore(connectionString); +var collection = vectorStore.GetCollection<int, BlogPost>("BlogPosts"); +await collection.EnsureCollectionExistsAsync(); + +// Upsert records +await collection.UpsertAsync(new BlogPost +{ + Id = 1, + Title = "Vector search in Azure SQL", + Content = "...", + ContentEmbedding = embedding // ReadOnlyMemory<float> from your embedding provider +}); + +// Search +var results = await collection.SearchAsync(queryEmbedding, top: 5).ToListAsync(); +``` + +## Documentation + +- [Vector stores in .NET](https://learn.microsoft.com/en-us/dotnet/ai/vector-stores/overview) diff --git a/src/Microsoft.SqlServer.VectorData/SqlFilterTranslator.cs b/src/Microsoft.SqlServer.VectorData/SqlFilterTranslator.cs new file mode 100644 index 0000000..1b96676 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlFilterTranslator.cs @@ -0,0 +1,372 @@ +// Copyright (c) Microsoft. All rights reserved.
+ +using System; +using System.Diagnostics; +using System.Linq; +using System.Linq.Expressions; +using System.Text; +using Microsoft.Extensions.VectorData.ProviderServices; +using Microsoft.Extensions.VectorData.ProviderServices.Filter; + +namespace Microsoft.SqlServer.VectorData; + +#pragma warning disable MEVD9001 // Microsoft.Extensions.VectorData experimental connector-facing APIs + +internal abstract class SqlFilterTranslator : FilterTranslatorBase +{ + protected readonly StringBuilder _sql; + private readonly Expression _preprocessedExpression; + + internal SqlFilterTranslator( + CollectionModel model, + LambdaExpression lambdaExpression, + StringBuilder? sql = null) + { + Debug.Assert(lambdaExpression.Parameters.Count == 1); + this._sql = sql ?? new(); + + this._preprocessedExpression = this.PreprocessFilter(lambdaExpression, model, new FilterPreprocessingOptions { SupportsParameterization = true }); + } + + internal StringBuilder Clause => this._sql; + + internal void Translate(bool appendWhere) + { + if (appendWhere) + { + this._sql.Append("WHERE "); + } + + this.Translate(this._preprocessedExpression, isSearchCondition: true); + } + + protected void Translate(Expression? 
node, bool isSearchCondition = false) + { + switch (node) + { + case BinaryExpression binary: + this.TranslateBinary(binary); + return; + + case ConstantExpression constant: + this.TranslateConstant(constant.Value, isSearchCondition); + return; + + case QueryParameterExpression { Name: var name, Value: var value }: + this.TranslateQueryParameter(value); + return; + + case MemberExpression member: + this.TranslateMember(member, isSearchCondition); + return; + + case MethodCallExpression methodCall: + this.TranslateMethodCall(methodCall, isSearchCondition); + return; + + case UnaryExpression unary: + this.TranslateUnary(unary, isSearchCondition); + return; + + default: + throw new NotSupportedException("Unsupported NodeType in filter: " + node?.NodeType); + } + } + + protected void TranslateBinary(BinaryExpression binary) + { + // Special handling for null comparisons + switch (binary.NodeType) + { + case ExpressionType.Equal when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Right): + this._sql.Append('('); + this.Translate(binary.Left); + this._sql.Append(" IS NOT NULL)"); + return; + + case ExpressionType.Equal when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NULL)"); + return; + case ExpressionType.NotEqual when IsNull(binary.Left): + this._sql.Append('('); + this.Translate(binary.Right); + this._sql.Append(" IS NOT NULL)"); + return; + } + + this._sql.Append('('); + this.Translate(binary.Left, isSearchCondition: binary.NodeType is ExpressionType.AndAlso or ExpressionType.OrElse); + + this._sql.Append(binary.NodeType switch + { + ExpressionType.Equal => " = ", + ExpressionType.NotEqual => " <> ", + + ExpressionType.GreaterThan => " > ", + ExpressionType.GreaterThanOrEqual => " >= ", + ExpressionType.LessThan => " < ", + ExpressionType.LessThanOrEqual => " <= ", + + 
ExpressionType.AndAlso => " AND ", + ExpressionType.OrElse => " OR ", + + _ => throw new NotSupportedException("Unsupported binary expression node type: " + binary.NodeType) + }); + + this.Translate(binary.Right, isSearchCondition: binary.NodeType is ExpressionType.AndAlso or ExpressionType.OrElse); + + this._sql.Append(')'); + + static bool IsNull(Expression expression) + => expression is ConstantExpression { Value: null } or QueryParameterExpression { Value: null }; + } + + protected virtual void TranslateConstant(object? value, bool isSearchCondition) // Appends an inline constant to the SQL text as a literal. + { + switch (value) + { + case byte b: // Numeric literals are formatted with the invariant culture: StringBuilder.Append(number) uses the current culture, whose decimal separator or negative sign may not be valid SQL. + this._sql.Append(b.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; + case short s: + this._sql.Append(s.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; + case int i: + this._sql.Append(i.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; + case long l: + this._sql.Append(l.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; + + case float f: + this._sql.Append(f.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; + case double d: + this._sql.Append(d.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; + case decimal d: + this._sql.Append(d.ToString(System.Globalization.CultureInfo.InvariantCulture)); + return; + + case string untrustedInput: + // This is the only place where we allow untrusted input to be passed in, so we need to quote and escape it. + // Luckily for us, values are escaped in the same way for every provider that we support so far. + this._sql.Append('\'').Append(untrustedInput.Replace("'", "''")).Append('\''); + return; + case bool b: + this._sql.Append(b ?
"TRUE" : "FALSE"); + return; + case Guid g: + this._sql.Append('\'').Append(g.ToString()).Append('\''); + return; + + case DateTime dateTime: + case DateTimeOffset dateTimeOffset: + case Array: +#if NET + case DateOnly dateOnly: + case TimeOnly timeOnly: +#endif + throw new UnreachableException("Database-specific format, needs to be implemented in the provider's derived translator."); + + case null: + this._sql.Append("NULL"); + return; + + default: + throw new NotSupportedException("Unsupported constant type: " + value.GetType().Name); + } + } + + private void TranslateMember(MemberExpression memberExpression, bool isSearchCondition) + { + if (this.TryBindProperty(memberExpression, out var property)) + { + this.GenerateColumn(property, isSearchCondition); + return; + } + + throw new NotSupportedException($"Member access for '{memberExpression.Member.Name}' is unsupported - only member access over the filter parameter are supported"); + } + + protected virtual void GenerateColumn(PropertyModel property, bool isSearchCondition = false) + // StorageName is considered to be a safe input, we quote and escape it mostly to produce valid SQL. + => this._sql.Append('"').Append(property.StorageName.Replace("\"", "\"\"")).Append('"'); + + protected abstract void TranslateQueryParameter(object? 
value); + + private void TranslateMethodCall(MethodCallExpression methodCall, bool isSearchCondition = false) + { + // Dictionary access for dynamic mapping (r => r["SomeString"] == "foo") + if (this.TryBindProperty(methodCall, out var property)) + { + this.GenerateColumn(property, isSearchCondition); + return; + } + + switch (methodCall) + { + // Enumerable.Contains(), List.Contains(), MemoryExtensions.Contains() + case var _ when TryMatchContains(methodCall, out var source, out var item): + this.TranslateContains(source, item); + return; + + // Enumerable.Any() with a Contains predicate (r => r.Strings.Any(s => array.Contains(s))) + case { Method.Name: nameof(Enumerable.Any), Arguments: [var anySource, LambdaExpression lambda] } any + when any.Method.DeclaringType == typeof(Enumerable): + this.TranslateAny(anySource, lambda); + return; + + default: + throw new NotSupportedException($"Unsupported method call: {methodCall.Method.DeclaringType?.Name}.{methodCall.Method.Name}"); + } + } + + private void TranslateContains(Expression source, Expression item) + { + switch (source) + { + // Contains over array column (r => r.Strings.Contains("foo")) + case var _ when this.TryBindProperty(source, out _): + this.TranslateContainsOverArrayColumn(source, item); + return; + + // Contains over inline array (r => new[] { "foo", "bar" }.Contains(r.String)) + case NewArrayExpression newArray: + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in newArray.Expressions) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.Translate(element); + } + + this._sql.Append(')'); + return; + + // Contains over captured array (r => arrayLocalVariable.Contains(r.String)) + case QueryParameterExpression { Value: var value }: + this.TranslateContainsOverParameterizedArray(source, item, value); + return; + + default: + throw new NotSupportedException("Unsupported Contains expression"); + } + } + + 
protected abstract void TranslateContainsOverArrayColumn(Expression source, Expression item); + + protected abstract void TranslateContainsOverParameterizedArray(Expression source, Expression item, object? value); + + /// + /// Translates an Any() call with a Contains predicate, e.g. r.Strings.Any(s => array.Contains(s)). + /// This checks whether any element in the array column is contained in the given values. + /// + private void TranslateAny(Expression source, LambdaExpression lambda) + { + // We only support the pattern: r.ArrayColumn.Any(x => values.Contains(x)) + // where 'values' is an inline array, captured array, or captured list. + if (!this.TryBindProperty(source, out var property) + || lambda.Body is not MethodCallExpression containsCall + || !TryMatchContains(containsCall, out var valuesExpression, out var itemExpression)) + { + throw new NotSupportedException("Unsupported method call: Enumerable.Any"); + } + + // Verify that the item is the lambda parameter + if (itemExpression != lambda.Parameters[0]) + { + throw new NotSupportedException("Unsupported method call: Enumerable.Any"); + } + + // Now extract the values from valuesExpression + switch (valuesExpression) + { + // Inline array: r.Strings.Any(s => new[] { "a", "b" }.Contains(s)) + case NewArrayExpression newArray: + { + var values = new object?[newArray.Expressions.Count]; + for (var i = 0; i < newArray.Expressions.Count; i++) + { + values[i] = newArray.Expressions[i] switch + { + ConstantExpression { Value: var v } => v, + QueryParameterExpression { Value: var v } => v, + _ => throw new NotSupportedException("Unsupported method call: Enumerable.Any") + }; + } + + this.TranslateAnyContainsOverArrayColumn(property, values); + return; + } + + // Captured/parameterized array or list: r.Strings.Any(s => capturedArray.Contains(s)) + case QueryParameterExpression { Value: var value }: + this.TranslateAnyContainsOverArrayColumn(property, value); + return; + + // Constant array: shouldn't normally 
happen, but handle it + case ConstantExpression { Value: var value }: + this.TranslateAnyContainsOverArrayColumn(property, value); + return; + + default: + throw new NotSupportedException("Unsupported method call: Enumerable.Any"); + } + } + + protected abstract void TranslateAnyContainsOverArrayColumn(PropertyModel property, object? values); + + private void TranslateUnary(UnaryExpression unary, bool isSearchCondition) + { + switch (unary.NodeType) + { + case ExpressionType.Not: + // Special handling for !(a == b) and !(a != b) + if (unary.Operand is BinaryExpression { NodeType: ExpressionType.Equal or ExpressionType.NotEqual } binary) + { + this.TranslateBinary( + Expression.MakeBinary( + binary.NodeType is ExpressionType.Equal ? ExpressionType.NotEqual : ExpressionType.Equal, + binary.Left, + binary.Right)); + return; + } + + this._sql.Append("(NOT "); + this.Translate(unary.Operand, isSearchCondition); + this._sql.Append(')'); + return; + + // Handle converting non-nullable to nullable; such nodes are found in e.g. 
r => r.Int == nullableInt + case ExpressionType.Convert when Nullable.GetUnderlyingType(unary.Type) == unary.Operand.Type: + this.Translate(unary.Operand, isSearchCondition); + return; + + // Handle convert over member access, for dynamic dictionary access (r => (int)r["SomeInt"] == 8) + case ExpressionType.Convert when this.TryBindProperty(unary.Operand, out var property) && unary.Type == property.Type: + this.GenerateColumn(property, isSearchCondition); + return; + + default: + throw new NotSupportedException("Unsupported unary expression node type: " + unary.NodeType); + } + } +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServer.csproj b/src/Microsoft.SqlServer.VectorData/SqlServer.csproj new file mode 100644 index 0000000..9362a46 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServer.csproj @@ -0,0 +1,42 @@ + + + + 1.0.0-preview.1 + Microsoft.SqlServer.VectorData + $(AssemblyName) + netstandard2.0;net8.0;net462 + preview + enable + enable + $(NoWarn);MEVD9000,MEVD9001 + + SQL Server provider for Microsoft.Extensions.VectorData + SQL Server provider for Microsoft.Extensions.VectorData + + + true + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerCollection.cs b/src/Microsoft.SqlServer.VectorData/SqlServerCollection.cs new file mode 100644 index 0000000..f92d55f --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerCollection.cs @@ -0,0 +1,889 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Diagnostics; +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Linq.Expressions; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.Data.SqlTypes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ProviderServices; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.SqlServer.VectorData; + +/// +/// An implementation of backed by a SQL Server or Azure SQL database. +/// +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix (Collection) +public class SqlServerCollection +#pragma warning restore CA1711 + : VectorStoreCollection, + IKeywordHybridSearchable + where TKey : notnull + where TRecord : class +{ + /// Metadata about vector store record collection. + private readonly VectorStoreCollectionMetadata _collectionMetadata; + + private static readonly VectorSearchOptions s_defaultVectorSearchOptions = new(); + private static readonly HybridSearchOptions s_defaultHybridSearchOptions = new(); + + private readonly string _connectionString; + private readonly CollectionModel _model; + private readonly SqlServerMapper _mapper; + + /// The database schema. + private readonly string? _schema; + + /// Whether the model contains any DiskAnn vector properties, requiring Azure SQL. + private readonly bool _requiresAzureSql; + + /// Cached result of the Azure SQL engine edition check (null = not yet checked). + private bool? _isAzureSql; + + /// + /// Initializes a new instance of the class. + /// + /// Database connection string. + /// The name of the collection. + /// Optional configuration options. 
+ [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public SqlServerCollection( + string connectionString, + string name, + SqlServerCollectionOptions? options = null) + : this( + connectionString, + name, + static options => typeof(TRecord) == typeof(Dictionary) + ? throw new NotSupportedException(VectorDataStrings.NonDynamicCollectionWithDictionaryNotSupported(typeof(SqlServerDynamicCollection))) + : new SqlServerModelBuilder().Build(typeof(TRecord), typeof(TKey), options.Definition, options.EmbeddingGenerator), + options) + { + } + + internal SqlServerCollection(string connectionString, string name, Func modelFactory, SqlServerCollectionOptions? options) + { + Throw.IfNullOrWhitespace(connectionString); + Throw.IfNull(name); + + options ??= SqlServerCollectionOptions.Default; + this._schema = options.Schema; + + this._connectionString = connectionString; + this.Name = name; + this._model = modelFactory(options); + + this._mapper = new SqlServerMapper(this._model); + + // Check if any vector property uses DiskAnn, which requires Azure SQL. 
+ foreach (var vp in this._model.VectorProperties) + { + if (vp.IndexKind == IndexKind.DiskAnn) + { + this._requiresAzureSql = true; + break; + } + } + + var connectionStringBuilder = new SqlConnectionStringBuilder(connectionString); + + this._collectionMetadata = new() + { + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = connectionStringBuilder.InitialCatalog, + CollectionName = name + }; + } + + /// + public override string Name { get; } + + /// + public override async Task CollectionExistsAsync(CancellationToken cancellationToken = default) + { + using SqlConnection connection = new(this._connectionString); + using SqlCommand command = SqlServerCommandBuilder.SelectTableName( + connection, this._schema, this.Name); + + return await connection.ExecuteWithErrorHandlingAsync( + this._collectionMetadata, + "CollectionExists", + async () => + { + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + return await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + }, + cancellationToken).ConfigureAwait(false); + } + + /// + public override Task EnsureCollectionExistsAsync(CancellationToken cancellationToken = default) + => this.CreateCollectionAsync(ifNotExists: true, cancellationToken); + + private async Task CreateCollectionAsync(bool ifNotExists, CancellationToken cancellationToken) + { + using SqlConnection connection = new(this._connectionString); + + if (this._requiresAzureSql) + { + await this.EnsureAzureSqlForDiskAnnAsync(connection, cancellationToken).ConfigureAwait(false); + } + + List commands = SqlServerCommandBuilder.CreateTable( + connection, + this._schema, + this.Name, + ifNotExists, + this._model); + + foreach (SqlCommand command in commands) + { + using (command) + { + await connection.ExecuteWithErrorHandlingAsync( + this._collectionMetadata, + "CreateCollection", + () => command.ExecuteNonQueryAsync(cancellationToken), + 
cancellationToken).ConfigureAwait(false); + } + } + } + + /// + public override async Task EnsureCollectionDeletedAsync(CancellationToken cancellationToken = default) + { + using SqlConnection connection = new(this._connectionString); + using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists( + connection, this._schema, this.Name); + + await connection.ExecuteWithErrorHandlingAsync( + this._collectionMetadata, + "DeleteCollection", + () => command.ExecuteNonQueryAsync(cancellationToken), + cancellationToken).ConfigureAwait(false); + } + + /// + public override async Task DeleteAsync(TKey key, CancellationToken cancellationToken = default) + { + Throw.IfNull(key); + + using SqlConnection connection = new(this._connectionString); + using SqlCommand command = SqlServerCommandBuilder.DeleteSingle( + connection, + this._schema, + this.Name, + this._model.KeyProperty, + key); + + await connection.ExecuteWithErrorHandlingAsync( + this._collectionMetadata, + "Delete", + () => command.ExecuteNonQueryAsync(cancellationToken), + cancellationToken).ConfigureAwait(false); + } + + /// + public override async Task DeleteAsync(IEnumerable keys, CancellationToken cancellationToken = default) + { + Throw.IfNull(keys); + + using SqlConnection connection = new(this._connectionString); + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + + using SqlTransaction transaction = connection.BeginTransaction(); + int taken = 0; + + try + { + while (true) + { +#if NET + SqlCommand command = new("", connection, transaction); + await using (command.ConfigureAwait(false)) +#else + using (SqlCommand command = new("", connection, transaction)) +#endif + { + if (!SqlServerCommandBuilder.DeleteMany( + command, + this._schema, + this.Name, + this._model.KeyProperty, + keys.Skip(taken).Take(SqlServerConstants.MaxParameterCount))) + { + break; // keys is empty, there is nothing to delete + } + + checked + { + taken += command.Parameters.Count; + } + + await 
command.ExecuteNonQueryAsync(cancellationToken).ConfigureAwait(false); + } + } + + if (taken > 0) + { +#if NET + await transaction.CommitAsync(cancellationToken).ConfigureAwait(false); +#else + transaction.Commit(); +#endif + } + } + catch (DbException ex) + { +#if NET + await transaction.RollbackAsync(cancellationToken).ConfigureAwait(false); +#else + transaction.Rollback(); +#endif + + throw new VectorStoreException(ex.Message, ex) + { + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, + OperationName = "DeleteBatch" + }; + } + catch (Exception) + { +#if NET + await transaction.RollbackAsync(cancellationToken).ConfigureAwait(false); +#else + transaction.Rollback(); +#endif + + throw; + } + } + + /// + public override async Task GetAsync(TKey key, RecordRetrievalOptions? options = null, CancellationToken cancellationToken = default) + { + Throw.IfNull(key); + + bool includeVectors = options?.IncludeVectors is true; + if (includeVectors && this._model.EmbeddingGenerationRequired) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + using SqlConnection connection = new(this._connectionString); + using SqlCommand command = SqlServerCommandBuilder.SelectSingle( + connection, + this._schema, + this.Name, + this._model, + key, + includeVectors); + + return await connection.ExecuteWithErrorHandlingAsync( + this._collectionMetadata, + operationName: "Get", + async () => + { + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + return reader.HasRows + ? this._mapper.MapFromStorageToDataModel(reader, includeVectors) + : null; + }, + cancellationToken).ConfigureAwait(false); + } + + /// + public override async IAsyncEnumerable GetAsync(IEnumerable keys, RecordRetrievalOptions? 
options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Throw.IfNull(keys); + + bool includeVectors = options?.IncludeVectors is true; + if (includeVectors && this._model.EmbeddingGenerationRequired) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + using SqlConnection connection = new(this._connectionString); + using SqlCommand command = connection.CreateCommand(); + int taken = 0; + + do + { + if (command.Parameters.Count > 0) + { + command.Parameters.Clear(); // We reuse the same command for the next batch. + } + + if (!SqlServerCommandBuilder.SelectMany( + command, + this._schema, + this.Name, + this._model, + keys.Skip(taken).Take(SqlServerConstants.MaxParameterCount), + includeVectors)) + { + yield break; // keys is empty + } + + checked + { + taken += command.Parameters.Count; + } + + using SqlDataReader reader = await connection.ExecuteWithErrorHandlingAsync( + this._collectionMetadata, + operationName: "GetBatch", + () => command.ExecuteReaderAsync(cancellationToken), + cancellationToken).ConfigureAwait(false); + + while (true) + { + TRecord? record = await VectorStoreErrorHandler.RunOperationAsync( + this._collectionMetadata, + "GetBatch", + async () => await reader.ReadAsync(cancellationToken).ConfigureAwait(false) + ? this._mapper.MapFromStorageToDataModel(reader, includeVectors) + : null) + .ConfigureAwait(false); + + if (record is null) + { + break; + } + + yield return record; + } + } while (command.Parameters.Count == SqlServerConstants.MaxParameterCount); + } + + /// + public override async Task UpsertAsync(TRecord record, CancellationToken cancellationToken = default) + { + Throw.IfNull(record); + + Dictionary>? 
generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (SqlServerModelBuilder.IsVectorPropertyTypeValidCore(vectorProperty.Type, out _)) + { + continue; + } + + // We have a vector property whose type isn't natively supported - we need to generate embeddings. + Debug.Assert(vectorProperty.EmbeddingGenerator is not null); + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + generatedEmbeddings ??= new Dictionary>(vectorPropertyCount); + generatedEmbeddings[vectorProperty] = [await vectorProperty.GenerateEmbeddingAsync(vectorProperty.GetValueAsObject(record), cancellationToken).ConfigureAwait(false)]; + } + + using SqlConnection connection = new(this._connectionString); + using SqlCommand command = connection.CreateCommand(); + SqlServerCommandBuilder.Upsert( + command, + this._schema, + this.Name, + this._model, + [record], + firstRecordIndex: 0, + generatedEmbeddings); + + var keyProperty = this._model.KeyProperty; + + await connection.ExecuteWithErrorHandlingAsync( + this._collectionMetadata, + "Upsert", + async () => + { + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + + // Inject the generated key into the record if auto-generation was used + if (keyProperty.IsAutoGenerated && Equals(keyProperty.GetValueAsObject(record), default(TKey))) + { + var keyValue = reader.GetFieldValue(0); + keyProperty.SetValue(record, keyValue); + } + + return 0; + }, + cancellationToken).ConfigureAwait(false); + } + + /// + public override async Task UpsertAsync(IEnumerable records, CancellationToken cancellationToken = default) + { + 
Throw.IfNull(records); + + IReadOnlyList? recordsList = null; + + // If an embedding generator is defined, invoke it once per property for all records. + Dictionary>? generatedEmbeddings = null; + + var vectorPropertyCount = this._model.VectorProperties.Count; + for (var i = 0; i < vectorPropertyCount; i++) + { + var vectorProperty = this._model.VectorProperties[i]; + + if (SqlServerModelBuilder.IsVectorPropertyTypeValidCore(vectorProperty.Type, out _)) + { + continue; + } + + // We have a vector property whose type isn't natively supported - we need to generate embeddings. + Debug.Assert(vectorProperty.EmbeddingGenerator is not null); + + // We have a property with embedding generation; materialize the records' enumerable if needed, to + // prevent multiple enumeration. + if (recordsList is null) + { + recordsList = records is IReadOnlyList r ? r : records.ToList(); + + if (recordsList.Count == 0) + { + return; + } + + records = recordsList; + } + + // TODO: Ideally we'd group together vector properties using the same generator (and with the same input and output properties), + // and generate embeddings for them in a single batch. That's some more complexity though. + generatedEmbeddings ??= new Dictionary>(vectorPropertyCount); + generatedEmbeddings[vectorProperty] = await vectorProperty.GenerateEmbeddingsAsync(records.Select(r => vectorProperty.GetValueAsObject(r)), cancellationToken).ConfigureAwait(false); + } + + // If key auto-generation is enabled, we need to read back generated keys and inject them into records. + // Materialize the records' enumerable if needed, to allow iteration for key injection. + var keyProperty = this._model.KeyProperty; + if (keyProperty.IsAutoGenerated && recordsList is null) + { + recordsList = records is IReadOnlyList r ? 
r : records.ToList(); + + if (recordsList.Count == 0) + { + return; + } + + records = recordsList; + } + + using SqlConnection connection = new(this._connectionString); + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + + using SqlTransaction transaction = connection.BeginTransaction(); + int parametersPerRecord = this._model.Properties.Count; + int taken = 0; + int batchSize = SqlServerConstants.MaxParameterCount / parametersPerRecord; + + try + { + while (true) + { + // Materialize the batch to a list so we can iterate multiple times: + // once for building the command, once for reading back results. + var batch = records.Skip(taken).Take(batchSize).ToList(); + if (batch.Count == 0) + { + break; + } + +#if NET + SqlCommand command = new("", connection, transaction); + await using (command.ConfigureAwait(false)) +#else + using (SqlCommand command = new("", connection, transaction)) +#endif + { + if (!SqlServerCommandBuilder.Upsert( + command, + this._schema, + this.Name, + this._model, + batch, + firstRecordIndex: taken, + generatedEmbeddings)) + { + break; // records is empty (shouldn't happen given check above, but defensive) + } + + // Execute and read back the generated keys. + // Each MERGE statement returns a single result set with one row containing the key. + using SqlDataReader reader = await command.ExecuteReaderAsync(cancellationToken).ConfigureAwait(false); + + // Iterate through the records in this batch and inject generated keys where needed. 
+ foreach (var record in batch) + { + await reader.ReadAsync(cancellationToken).ConfigureAwait(false); + + // Only inject key if auto-generation is enabled and record had a default key value + if (keyProperty.IsAutoGenerated && Equals(keyProperty.GetValueAsObject(record), default(TKey))) + { + var keyValue = reader.GetFieldValue(0); + keyProperty.SetValue(record, keyValue); + } + + await reader.NextResultAsync(cancellationToken).ConfigureAwait(false); + } + + checked + { + taken += batch.Count; + } + } + } + + if (taken > 0) + { +#if NET + await transaction.CommitAsync(cancellationToken).ConfigureAwait(false); +#else + transaction.Commit(); +#endif + } + } + catch (DbException ex) + { +#if NET + await transaction.RollbackAsync(cancellationToken).ConfigureAwait(false); +#else + transaction.Rollback(); +#endif + + throw new VectorStoreException(ex.Message, ex) + { + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = this._collectionMetadata.VectorStoreName, + CollectionName = this.Name, + OperationName = "UpsertBatch" + }; + } + catch (Exception) + { +#if NET + await transaction.RollbackAsync(cancellationToken).ConfigureAwait(false); +#else + transaction.Rollback(); +#endif + throw; + } + } + + #region Search + + /// + public override async IAsyncEnumerable> SearchAsync( + TInput searchValue, + int top, + VectorSearchOptions? 
options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + { + Throw.IfNull(searchValue); + Throw.IfLessThan(top, 1); + + options ??= s_defaultVectorSearchOptions; + if (options.IncludeVectors && this._model.EmbeddingGenerationRequired) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + var vectorProperty = this._model.GetVectorPropertyOrSingle(options); + + SqlVector vector = searchValue switch + { + SqlVector v => v, + ReadOnlyMemory r => new(r), + float[] f => new(f), + Embedding e => new(e.Vector), + + _ when vectorProperty.EmbeddingGenerationDispatcher is not null + => new(((Embedding)await vectorProperty.GenerateEmbeddingAsync(searchValue, cancellationToken).ConfigureAwait(false)).Vector), + + _ => vectorProperty.EmbeddingGenerator is null + ? throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), SqlServerModelBuilder.SupportedVectorTypes)) + : throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TInput), vectorProperty.EmbeddingGenerator.GetType())) + }; + +#pragma warning disable CA2000 // Dispose objects before losing scope + // Connection and command are going to be disposed by the ReadVectorSearchResultsAsync, + // when the user is done with the results. 
+ SqlConnection connection = new(this._connectionString); + + if (vectorProperty.IndexKind == IndexKind.DiskAnn) + { + await this.EnsureAzureSqlForDiskAnnAsync(connection, cancellationToken).ConfigureAwait(false); + } + + SqlCommand command = SqlServerCommandBuilder.SelectVector( + connection, + this._schema, + this.Name, + vectorProperty, + this._model, + top, + options, + vector); +#pragma warning restore CA2000 // Dispose objects before losing scope + + await foreach (var record in this.ReadVectorSearchResultsAsync(connection, command, options.IncludeVectors, cancellationToken).ConfigureAwait(false)) + { + yield return record; + } + } + + /// + public async IAsyncEnumerable> HybridSearchAsync( + TInput searchValue, + ICollection keywords, + int top, + HybridSearchOptions? options = null, + [EnumeratorCancellation] CancellationToken cancellationToken = default) + where TInput : notnull + { + Throw.IfNull(searchValue); + Throw.IfNull(keywords); + Throw.IfLessThan(top, 1); + + options ??= s_defaultHybridSearchOptions; + if (options.IncludeVectors && this._model.EmbeddingGenerationRequired) + { + throw new NotSupportedException(VectorDataStrings.IncludeVectorsNotSupportedWithEmbeddingGeneration); + } + + var vectorProperty = this._model.GetVectorPropertyOrSingle(new VectorSearchOptions { VectorProperty = options.VectorProperty }); + var textDataProperty = this._model.GetFullTextDataPropertyOrSingle(options.AdditionalProperty); + + SqlVector vector = searchValue switch + { + SqlVector v => v, + ReadOnlyMemory r => new(r), + float[] f => new(f), + Embedding e => new(e.Vector), + + _ when vectorProperty.EmbeddingGenerationDispatcher is not null + => new(((Embedding)await vectorProperty.GenerateEmbeddingAsync(searchValue, cancellationToken).ConfigureAwait(false)).Vector), + + _ => vectorProperty.EmbeddingGenerator is null + ? 
throw new NotSupportedException(VectorDataStrings.InvalidSearchInputAndNoEmbeddingGeneratorWasConfigured(searchValue.GetType(), SqlServerModelBuilder.SupportedVectorTypes))
                : throw new InvalidOperationException(VectorDataStrings.IncompatibleEmbeddingGeneratorWasConfiguredForInputType(typeof(TInput), vectorProperty.EmbeddingGenerator.GetType()))
        };

        var keywordsCombined = string.Join(" ", keywords);

#pragma warning disable CA2000 // Dispose objects before losing scope
        // Connection and command are going to be disposed by the ReadVectorSearchResultsAsync,
        // when the user is done with the results.
        SqlConnection connection = new(this._connectionString);

        if (vectorProperty.IndexKind == IndexKind.DiskAnn)
        {
            await this.EnsureAzureSqlForDiskAnnAsync(connection, cancellationToken).ConfigureAwait(false);
        }

        SqlCommand command = SqlServerCommandBuilder.SelectHybrid(
            connection,
            this._schema,
            this.Name,
            vectorProperty,
            textDataProperty,
            this._model,
            top,
            options,
            vector,
            keywordsCombined);
#pragma warning restore CA2000 // Dispose objects before losing scope

        await foreach (var record in this.ReadHybridSearchResultsAsync(connection, command, options, cancellationToken).ConfigureAwait(false))
        {
            yield return record;
        }
    }

    #endregion Search

    /// <inheritdoc />
    public override object? GetService(Type serviceType, object? serviceKey = null)
    {
        Throw.IfNull(serviceType);

        // Keyed lookups are never satisfied; otherwise hand out the collection metadata
        // or this instance when it is assignable to the requested type.
        return
            serviceKey is not null ? null :
            serviceType == typeof(VectorStoreCollectionMetadata) ? this._collectionMetadata :
            serviceType.IsInstanceOfType(this) ? this :
            null;
    }

    /// <summary>
    /// Executes a vector search command and streams the mapped records together with the value
    /// of the <c>score</c> column. The command and the connection are disposed once enumeration
    /// completes, fails, or is abandoned.
    /// </summary>
    /// <param name="connection">Connection the command runs on; disposed by this method.</param>
    /// <param name="command">Prepared search command; disposed by this method.</param>
    /// <param name="includeVectors">Forwarded to the mapper when materializing each record.</param>
    /// <param name="cancellationToken">Token observed while executing and reading.</param>
    private async IAsyncEnumerable<VectorSearchResult<TRecord>> ReadVectorSearchResultsAsync(
        SqlConnection connection,
        SqlCommand command,
        bool includeVectors,
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        try
        {
            using SqlDataReader reader = await connection.ExecuteWithErrorHandlingAsync(
                this._collectionMetadata,
                operationName: "VectorizedSearch",
                () => command.ExecuteReaderAsync(cancellationToken),
                cancellationToken).ConfigureAwait(false);

            // Resolve the ordinal of the score column lazily, on the first row only.
            int scoreIndex = -1;
            while (await reader.ReadWithErrorHandlingAsync(
                this._collectionMetadata,
                operationName: "VectorizedSearch",
                cancellationToken).ConfigureAwait(false))
            {
                if (scoreIndex < 0)
                {
                    scoreIndex = reader.GetOrdinal("score");
                }

                yield return new VectorSearchResult<TRecord>(
                    this._mapper.MapFromStorageToDataModel(reader, includeVectors),
                    reader.GetDouble(scoreIndex));
            }
        }
        finally
        {
            command.Dispose();
            connection.Dispose();
        }
    }

    /// <summary>
    /// Executes a hybrid search command and streams the mapped records together with the value
    /// of the <c>score</c> column. The command and the connection are disposed once enumeration
    /// completes, fails, or is abandoned.
    /// </summary>
    /// <param name="connection">Connection the command runs on; disposed by this method.</param>
    /// <param name="command">Prepared search command; disposed by this method.</param>
    /// <param name="options">Search options; only <c>IncludeVectors</c> is read here.</param>
    /// <param name="cancellationToken">Token observed while executing and reading.</param>
    private async IAsyncEnumerable<VectorSearchResult<TRecord>> ReadHybridSearchResultsAsync(
        SqlConnection connection,
        SqlCommand command,
        HybridSearchOptions<TRecord> options,
        [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        try
        {
            using SqlDataReader reader = await connection.ExecuteWithErrorHandlingAsync(
                this._collectionMetadata,
                operationName: "HybridSearch",
                () => command.ExecuteReaderAsync(cancellationToken),
                cancellationToken).ConfigureAwait(false);

            // Resolve the ordinal of the score column lazily, on the first row only.
            int scoreIndex = -1;
            while (await reader.ReadWithErrorHandlingAsync(
                this._collectionMetadata,
                operationName: "HybridSearch",
                cancellationToken).ConfigureAwait(false))
            {
                if (scoreIndex < 0)
                {
                    scoreIndex = reader.GetOrdinal("score");
                }

                yield return new VectorSearchResult<TRecord>(
                    this._mapper.MapFromStorageToDataModel(reader, options.IncludeVectors),
                    reader.GetDouble(scoreIndex));
            }
        }
        finally
        {
            command.Dispose();
            connection.Dispose();
        }
    }

    /// <inheritdoc />
    public override async IAsyncEnumerable<TRecord> GetAsync(Expression<Func<TRecord, bool>> filter, int top,
        FilteredRecordRetrievalOptions<TRecord>? options = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
    {
        Throw.IfNull(filter);
        Throw.IfLessThan(top, 1);

        options ??= new();

        using SqlConnection connection = new(this._connectionString);
        using SqlCommand command = SqlServerCommandBuilder.SelectWhere(
            filter,
            top,
            options,
            connection,
            this._schema,
            this.Name,
            this._model);

        using SqlDataReader reader = await connection.ExecuteWithErrorHandlingAsync(
            this._collectionMetadata,
            operationName: "GetAsync",
            () => command.ExecuteReaderAsync(cancellationToken),
            cancellationToken).ConfigureAwait(false);

        while (await reader.ReadWithErrorHandlingAsync(
            this._collectionMetadata,
            operationName: "GetAsync",
            cancellationToken).ConfigureAwait(false))
        {
            yield return this._mapper.MapFromStorageToDataModel(reader, options.IncludeVectors);
        }
    }

    /// <summary>
    /// Validates that the connection is to Azure SQL Database or SQL database in Microsoft Fabric,
    /// which is required for DiskAnn vector indexes and the VECTOR_SEARCH function.
    /// </summary>
    /// <remarks>
    /// The result of the engine edition probe is cached in <c>_isAzureSql</c>. The connection is
    /// disposed on every failure path (unsupported engine, open failure, probe failure, cancellation),
    /// because in SearchAsync/HybridSearchAsync it is not in a using block and would otherwise leak.
    /// </remarks>
    /// <param name="connection">Connection to probe; opened if necessary, disposed on failure.</param>
    /// <param name="cancellationToken">Token observed while opening the connection and probing.</param>
    /// <exception cref="NotSupportedException">The server is neither Azure SQL Database nor SQL database in Microsoft Fabric.</exception>
    private async Task EnsureAzureSqlForDiskAnnAsync(SqlConnection connection, CancellationToken cancellationToken)
    {
        const string NotSupportedMessage =
            "DiskAnn vector indexes and the VECTOR_SEARCH function require Azure SQL Database or SQL database in Microsoft Fabric. " +
            "They are not supported on SQL Server. Use a Flat index kind with VECTOR_DISTANCE instead.";

        if (this._isAzureSql is true)
        {
            return;
        }

        try
        {
            bool isAzureSql;

            if (this._isAzureSql is false)
            {
                // A previous probe already established that the engine is unsupported.
                isAzureSql = false;
            }
            else
            {
                if (connection.State != System.Data.ConnectionState.Open)
                {
                    await connection.OpenAsync(cancellationToken).ConfigureAwait(false);
                }

                using var command = connection.CreateCommand();
                command.CommandText = "SELECT SERVERPROPERTY('EngineEdition')";
                var result = await command.ExecuteScalarAsync(cancellationToken).ConfigureAwait(false);
                var engineEdition = Convert.ToInt32(result);

                // 5 = Azure SQL Database, 11 = SQL database in Microsoft Fabric
                isAzureSql = engineEdition is 5 or 11;
                this._isAzureSql = isAzureSql;
            }

            if (!isAzureSql)
            {
                throw new NotSupportedException(NotSupportedMessage);
            }
        }
        catch
        {
            // The caller relies on the result reader to dispose the connection, which never
            // happens when we fail here - release it before propagating the error.
            connection.Dispose();
            throw;
        }
    }
}
diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerCollectionOptions.cs b/src/Microsoft.SqlServer.VectorData/SqlServerCollectionOptions.cs
new file mode 100644
index 0000000..64becf6
--- /dev/null
+++ b/src/Microsoft.SqlServer.VectorData/SqlServerCollectionOptions.cs
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.VectorData;

namespace Microsoft.SqlServer.VectorData;

///
/// Options when creating a .
///
public sealed class SqlServerCollectionOptions : VectorStoreCollectionOptions
{
    internal static readonly SqlServerCollectionOptions Default = new();

    ///
    /// Initializes a new instance of the class.
    ///
    public SqlServerCollectionOptions()
    {
    }

    internal SqlServerCollectionOptions(SqlServerCollectionOptions? source) : base(source)
    {
        this.Schema = source?.Schema;
    }

    ///
    /// Gets or sets the database schema.
+ /// + public string? Schema { get; set; } +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerCommandBuilder.cs b/src/Microsoft.SqlServer.VectorData/SqlServerCommandBuilder.cs new file mode 100644 index 0000000..210d683 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerCommandBuilder.cs @@ -0,0 +1,1106 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq.Expressions; +using System.Text; +using System.Text.Json; +using Microsoft.Data.SqlClient; +using Microsoft.Data.SqlTypes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ProviderServices; +using Microsoft.Shared.Diagnostics; + +#pragma warning disable CA2100 // Review SQL queries for security vulnerabilities + +namespace Microsoft.SqlServer.VectorData; + +internal static class SqlServerCommandBuilder +{ + internal static List CreateTable( + SqlConnection connection, + string? 
schema, + string tableName, + bool ifNotExists, + CollectionModel model) + { + List commands = []; + + StringBuilder sb = new(200); + if (ifNotExists) + { + sb.Append("IF OBJECT_ID(N'"); + sb.AppendTableNameInsideLiteral(schema, tableName); + sb.AppendLine("', N'U') IS NULL"); + } + sb.AppendLine("BEGIN"); + sb.Append("CREATE TABLE "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(" ("); + + var keyStoreType = Map(model.KeyProperty); + sb.AppendIdentifier(model.KeyProperty.StorageName).Append(' ').Append(keyStoreType); + if (model.KeyProperty.IsAutoGenerated) + { + switch (keyStoreType.ToUpperInvariant()) + { + case "SMALLINT": + case "INT": + case "BIGINT": + sb.Append(" IDENTITY"); + break; + case "UNIQUEIDENTIFIER": + sb.Append(" DEFAULT NEWSEQUENTIALID()"); + break; + default: + throw new UnreachableException(); + } + } + + sb.AppendLine(","); + + foreach (var property in model.DataProperties) + { + sb.AppendIdentifier(property.StorageName).Append(' ').Append(Map(property)); + if (!property.IsNullable) + { + sb.Append(" NOT NULL"); + } + sb.AppendLine(","); + } + + foreach (var property in model.VectorProperties) + { + sb.AppendIdentifier(property.StorageName).Append(" VECTOR(").Append(property.Dimensions).Append(')'); + if (!property.IsNullable) + { + sb.Append(" NOT NULL"); + } + sb.AppendLine(","); + } + + sb.Append("PRIMARY KEY (").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(")"); + sb.AppendLine(");"); // end the table definition + + foreach (var dataProperty in model.DataProperties) + { + if (dataProperty.IsIndexed) + { + var sqlType = Map(dataProperty); + if (sqlType == "JSON") + { + sb.Append("CREATE JSON INDEX "); + } + else + { + sb.Append("CREATE INDEX "); + } + sb.AppendIndexName(tableName, dataProperty.StorageName); + sb.Append(" ON ").AppendTableName(schema, tableName); + sb.Append('(').AppendIdentifier(dataProperty.StorageName).AppendLine(");"); + } + } + + // Create full-text catalog and index for properties marked as 
IsFullTextIndexed + var fullTextProperties = new List(); + foreach (var dataProperty in model.DataProperties) + { + if (dataProperty.IsFullTextIndexed) + { + fullTextProperties.Add(dataProperty); + } + } + + if (fullTextProperties.Count > 0) + { + // Generate a unique catalog name based on the table name + var catalogName = $"ftcat_{tableName}".Replace(" ", "_"); + + // Create full-text catalog if it doesn't exist + sb.Append("IF NOT EXISTS (SELECT 1 FROM sys.fulltext_catalogs WHERE name = '").Append(catalogName.Replace("'", "''")).AppendLine("')"); + sb.Append(" CREATE FULLTEXT CATALOG ").AppendIdentifier(catalogName).AppendLine(";"); + + // Create full-text index on the table using dynamic SQL to look up the PK constraint name + // Full-text indexes require a unique index (we use the primary key) + sb.AppendLine("DECLARE @pkIndexName NVARCHAR(128);"); + sb.Append("SELECT @pkIndexName = name FROM sys.indexes WHERE object_id = OBJECT_ID(N'"); + sb.AppendTableNameInsideLiteral(schema, tableName); + sb.AppendLine("') AND is_primary_key = 1;"); + + sb.AppendLine("DECLARE @ftSql NVARCHAR(MAX);"); + sb.Append("SET @ftSql = N'CREATE FULLTEXT INDEX ON "); + sb.AppendTableNameInsideLiteral(schema, tableName).Append(" ("); + for (int i = 0; i < fullTextProperties.Count; i++) + { + sb.AppendIdentifierInsideLiteral(fullTextProperties[i].StorageName); + if (i < fullTextProperties.Count - 1) + { + sb.Append(','); + } + } + sb.Append(") KEY INDEX ' + QUOTENAME(@pkIndexName) + N' ON "); + sb.AppendIdentifierInsideLiteral(catalogName).AppendLine("';"); + sb.AppendLine("EXEC sp_executesql @ftSql;"); + } + + sb.Append("END;"); + + commands.Add(connection.CreateCommand(sb)); + + // CREATE VECTOR INDEX must be in a separate batch from CREATE TABLE. + // It is also a preview feature in SQL Server 2025, requiring PREVIEW_FEATURES to be enabled. 
+ bool hasVectorIndex = false; + foreach (var vectorProperty in model.VectorProperties) + { + switch (vectorProperty.IndexKind) + { + case IndexKind.Flat or null or "": + continue; + + case IndexKind.DiskAnn: + if (!hasVectorIndex) + { + SqlCommand enablePreview = connection.CreateCommand(); + enablePreview.CommandText = "ALTER DATABASE SCOPED CONFIGURATION SET PREVIEW_FEATURES = ON;"; + commands.Add(enablePreview); + hasVectorIndex = true; + } + + string distanceFunction = vectorProperty.DistanceFunction ?? DistanceFunction.CosineDistance; + (string distanceMetric, _) = MapDistanceFunction(distanceFunction); + + StringBuilder vectorIndexSb = new(200); + vectorIndexSb.Append("CREATE VECTOR INDEX "); + vectorIndexSb.AppendIndexName(tableName, vectorProperty.StorageName); + vectorIndexSb.Append(" ON ").AppendTableName(schema, tableName); + vectorIndexSb.Append('(').AppendIdentifier(vectorProperty.StorageName).Append(')'); + vectorIndexSb.Append(" WITH (METRIC = '").Append(distanceMetric).AppendLine("', TYPE = 'DISKANN');"); + commands.Add(connection.CreateCommand(vectorIndexSb)); + break; + + default: + throw new NotSupportedException($"Index kind '{vectorProperty.IndexKind}' is not supported by the SQL Server connector."); + } + } + + return commands; + } + + internal static SqlCommand DropTableIfExists(SqlConnection connection, string? schema, string tableName) + { + StringBuilder sb = new(50); + sb.Append("DROP TABLE IF EXISTS "); + sb.AppendTableName(schema, tableName); + + return connection.CreateCommand(sb); + } + + internal static SqlCommand SelectTableName(SqlConnection connection, string? schema, string tableName) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND (@schema is NULL or TABLE_SCHEMA = @schema) + AND TABLE_NAME = @tableName + """; + command.Parameters.AddWithValue("@schema", string.IsNullOrEmpty(schema) ? 
DBNull.Value : schema); + command.Parameters.AddWithValue("@tableName", tableName); // the name is not escaped by us, just provided as parameter + return command; + } + + internal static SqlCommand SelectTableNames(SqlConnection connection, string? schema) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND (@schema is NULL or TABLE_SCHEMA = @schema) + """; + command.Parameters.AddWithValue("@schema", string.IsNullOrEmpty(schema) ? DBNull.Value : schema); + return command; + } + + /// + /// Checks if the key property uses SQL Server IDENTITY (for int/bigint) as opposed to DEFAULT (for GUID). + /// IDENTITY columns require SET IDENTITY_INSERT ON to insert explicit values. + /// + private static bool UsesIdentity(KeyPropertyModel keyProperty) + { + if (!keyProperty.IsAutoGenerated) + { + return false; + } + + var keyStoreType = Map(keyProperty).ToUpperInvariant(); + return keyStoreType is "SMALLINT" or "INT" or "BIGINT"; + } + + // Note: since keys may be auto-generated, we can't use a single multi-value MERGE statement, since that would return + // the generated keys in undefined order (OUTPUT order is not guaranteed in MERGE). + // Use a batch of single-row MERGE statements instead - each returns a separate result set. + internal static bool Upsert( + SqlCommand command, + string? schema, + string tableName, + CollectionModel model, + IEnumerable records, + int firstRecordIndex, + Dictionary>? generatedEmbeddings) + { + var keyProperty = model.KeyProperty; + StringBuilder sb = new(500); + + int rowIndex = 0, paramIndex = 0; + + foreach (var record in records) + { + // A record needs auto-generation if the key property is auto-generated AND the record has a default key value. 
+ var needsKeyGeneration = keyProperty.IsAutoGenerated && Equals(keyProperty.GetValueAsObject(record), default(TKey)); + // Skip key in INSERT when auto-generating (IDENTITY will provide the value) + var skipKeyInInsert = needsKeyGeneration; + // For explicit keys with IDENTITY columns, we need to enable IDENTITY_INSERT + // (only for int/bigint, not for GUID which uses DEFAULT NEWSEQUENTIALID()) + var needsIdentityInsert = UsesIdentity(keyProperty) && !needsKeyGeneration; + + // Enable IDENTITY_INSERT if we're inserting an explicit value into an IDENTITY column + if (needsIdentityInsert) + { + sb.Append("SET IDENTITY_INSERT "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(" ON;"); + } + + sb.Append("MERGE INTO "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(" AS t"); + sb.Append("USING (VALUES ("); + + foreach (var property in model.Properties) + { + // Skip key in VALUES when auto-generating + if (property is KeyPropertyModel && skipKeyInInsert) + { + continue; + } + + sb.AppendParameterName(property, ref paramIndex, out var paramName).Append(','); + + var value = property is VectorPropertyModel vectorProperty && generatedEmbeddings?.TryGetValue(vectorProperty, out var ge) == true + ? ge[firstRecordIndex + rowIndex] + : property.GetValueAsObject(record); + + command.AddParameter(property, paramName, value); + } + + sb[sb.Length - 1] = ')'; // replace the last comma with a closing parenthesis + sb.Append(") AS s ("); + sb.AppendIdentifiers(model.Properties, skipKey: skipKeyInInsert); + sb.AppendLine(")"); + + if (needsKeyGeneration) + { + // When auto-generating a key, we always insert (ON condition never matches). 
+ sb.AppendLine("ON (1=0)"); + } + else + { + // For upsert, match on the key from the source + sb.Append("ON (t.").AppendIdentifier(model.KeyProperty.StorageName).Append(" = s.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(")"); + sb.AppendLine("WHEN MATCHED THEN"); + sb.Append("UPDATE SET "); + foreach (var property in model.Properties) + { + if (property is not KeyPropertyModel) // don't update the key + { + sb.Append("t.").AppendIdentifier(property.StorageName).Append(" = s.").AppendIdentifier(property.StorageName).Append(','); + } + } + --sb.Length; // remove the last comma + sb.AppendLine(); + } + + sb.AppendLine("WHEN NOT MATCHED THEN"); + sb.Append("INSERT ("); + sb.AppendIdentifiers(model.Properties, skipKey: skipKeyInInsert); + sb.AppendLine(")"); + sb.Append("VALUES ("); + sb.AppendIdentifiers(model.Properties, prefix: "s.", skipKey: skipKeyInInsert); + sb.AppendLine(")"); + sb.Append("OUTPUT inserted.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(";"); + + // Disable IDENTITY_INSERT after the MERGE + if (needsIdentityInsert) + { + sb.Append("SET IDENTITY_INSERT "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(" OFF;"); + } + + sb.AppendLine(); + + rowIndex++; + } + + if (rowIndex == 0) + { + return false; // there is nothing to do! + } + + command.CommandText = sb.ToString(); + return true; + } + + internal static SqlCommand DeleteSingle( + SqlConnection connection, string? 
schema, string tableName, + KeyPropertyModel keyProperty, object key) + { + SqlCommand command = connection.CreateCommand(); + + int paramIndex = 0; + StringBuilder sb = new(100); + sb.Append("DELETE FROM "); + sb.AppendTableName(schema, tableName); + sb.Append(" WHERE ").AppendIdentifier(keyProperty.StorageName).Append(" = "); + sb.AppendParameterName(keyProperty, ref paramIndex, out string keyParamName); + command.AddParameter(keyProperty, keyParamName, key); + + command.CommandText = sb.ToString(); + return command; + } + + internal static bool DeleteMany( + SqlCommand command, string? schema, string tableName, + KeyPropertyModel keyProperty, IEnumerable keys) + { + StringBuilder sb = new(100); + sb.Append("DELETE FROM "); + sb.AppendTableName(schema, tableName); + sb.Append(" WHERE ").AppendIdentifier(keyProperty.StorageName).Append(" IN ("); + sb.AppendKeyParameterList(keys, command, keyProperty, out bool emptyKeys); + sb.Append(')'); // close the IN clause + + if (emptyKeys) + { + return false; + } + + command.CommandText = sb.ToString(); + return true; + } + + internal static SqlCommand SelectSingle( + SqlConnection sqlConnection, string? schema, string collectionName, + CollectionModel model, + object key, + bool includeVectors) + { + SqlCommand command = sqlConnection.CreateCommand(); + + int paramIndex = 0; + StringBuilder sb = new(200); + sb.Append("SELECT "); + sb.AppendIdentifiers(model.Properties, includeVectors: includeVectors); + sb.AppendLine(); + sb.Append("FROM "); + sb.AppendTableName(schema, collectionName); + sb.AppendLine(); + sb.Append("WHERE ").AppendIdentifier(model.KeyProperty.StorageName).Append(" = "); + sb.AppendParameterName(model.KeyProperty, ref paramIndex, out string keyParamName); + command.AddParameter(model.KeyProperty, keyParamName, key); + + command.CommandText = sb.ToString(); + return command; + } + + internal static bool SelectMany( + SqlCommand command, string? 
schema, string tableName, + CollectionModel model, + IEnumerable keys, + bool includeVectors) + { + StringBuilder sb = new(200); + sb.Append("SELECT "); + sb.AppendIdentifiers(model.Properties, includeVectors: includeVectors); + sb.AppendLine(); + sb.Append("FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(); + sb.Append("WHERE ").AppendIdentifier(model.KeyProperty.StorageName).Append(" IN ("); + sb.AppendKeyParameterList(keys, command, model.KeyProperty, out bool emptyKeys); + sb.Append(')'); // close the IN clause + + if (emptyKeys) + { + return false; // there is nothing to do! + } + + command.CommandText = sb.ToString(); + return true; + } + + internal static SqlCommand SelectVector( + SqlConnection connection, string? schema, string tableName, + VectorPropertyModel vectorProperty, + CollectionModel model, + int top, + VectorSearchOptions options, + SqlVector vector) + { + string distanceFunction = vectorProperty.DistanceFunction ?? DistanceFunction.CosineDistance; + (string distanceMetric, string sorting) = MapDistanceFunction(distanceFunction); + + return UseVectorSearch(vectorProperty) + ? SelectVectorWithVectorSearch(connection, schema, tableName, vectorProperty, model, top, options, vector, distanceMetric, sorting) + : SelectVectorWithVectorDistance(connection, schema, tableName, vectorProperty, model, top, options, vector, distanceMetric, sorting); + } + + private static SqlCommand SelectVectorWithVectorDistance( + SqlConnection connection, string? 
schema, string tableName, + VectorPropertyModel vectorProperty, + CollectionModel model, + int top, + VectorSearchOptions options, + SqlVector vector, + string distanceMetric, + string sorting) + { + SqlCommand command = connection.CreateCommand(); + command.Parameters.AddWithValue("@vector", vector); + + StringBuilder sb = new(200); + + sb.Append("SELECT "); + sb.AppendIdentifiers(model.Properties, includeVectors: options.IncludeVectors); + sb.AppendLine(","); + sb.Append("VECTOR_DISTANCE('").Append(distanceMetric).Append("', ").AppendIdentifier(vectorProperty.StorageName) + .Append(", CAST(@vector AS VECTOR(").Append(vector.Length).AppendLine("))) AS [score]"); + sb.Append("FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(); + + if (options.Filter is not null) + { + int startParamIndex = command.Parameters.Count; + + SqlServerFilterTranslator translator = new(model, options.Filter, sb, startParamIndex: startParamIndex); + translator.Translate(appendWhere: true); + List parameters = translator.ParameterValues; + + foreach (object parameter in parameters) + { + command.AddParameter(vectorProperty, $"@_{startParamIndex++}", parameter); + } + + sb.AppendLine(); + } + + // If score threshold is specified, wrap in a subquery to filter on the pre-calculated score + // This avoids calculating VECTOR_DISTANCE() twice. + if (options.ScoreThreshold is not null) + { + // For SQL Server, all distance metrics return a distance (lower = more similar), so we filter with <=. + command.Parameters.AddWithValue("@scoreThreshold", options.ScoreThreshold!.Value); + + var innerQuery = sb.ToString(); + sb.Clear(); + sb.Append("SELECT * FROM (").Append(innerQuery).AppendLine(") AS [inner]"); + sb.AppendLine("WHERE [score] <= @scoreThreshold"); + } + + sb.AppendFormat("ORDER BY [score] {0}", sorting); + sb.AppendLine(); + // Negative Skip and Top values are rejected by the VectorSearchOptions property setters. + // 0 is a legal value for OFFSET. 
+ sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, top); + + command.CommandText = sb.ToString(); + return command; + } + + /// + /// Generates a SELECT query using the VECTOR_SEARCH() function for approximate nearest neighbor search + /// when the vector property has a vector index (e.g. DiskANN). + /// + private static SqlCommand SelectVectorWithVectorSearch( + SqlConnection connection, string? schema, string tableName, + VectorPropertyModel vectorProperty, + CollectionModel model, + int top, + VectorSearchOptions options, + SqlVector vector, + string distanceMetric, + string sorting) + { + SqlCommand command = connection.CreateCommand(); + command.Parameters.AddWithValue("@vector", vector); + + StringBuilder sb = new(300); + + // When skip > 0, we need a subquery since TOP and OFFSET/FETCH can't coexist in the same SELECT. + bool needsSubquery = options.Skip > 0; + + if (needsSubquery) + { + sb.Append("SELECT * FROM ("); + } + + // VECTOR_SEARCH returns all columns from the table plus a 'distance' column. + // We select the needed columns from the table alias and alias 'distance' as 'score'. + // The latest version vector indexes require SELECT TOP(N) WITH APPROXIMATE instead of the deprecated TOP_N parameter. + sb.Append("SELECT TOP(").Append(top + options.Skip).Append(") WITH APPROXIMATE "); + sb.AppendIdentifiers(model.Properties, prefix: "t.", includeVectors: options.IncludeVectors); + sb.AppendLine(","); + sb.AppendLine("s.[distance] AS [score]"); + sb.Append("FROM VECTOR_SEARCH(TABLE = "); + sb.AppendTableName(schema, tableName); + sb.Append(" AS t, COLUMN = ").AppendIdentifier(vectorProperty.StorageName); + sb.Append(", SIMILAR_TO = @vector, METRIC = '").Append(distanceMetric).AppendLine("') AS s"); + + // With latest version vector indexes, WHERE predicates are applied during the vector search process + // (iterative filtering), not after retrieval. 
+ if (options.Filter is not null) + { + int startParamIndex = command.Parameters.Count; + + SqlServerFilterTranslator translator = new(model, options.Filter, sb, startParamIndex: startParamIndex, tableAlias: "t"); + translator.Translate(appendWhere: true); + List parameters = translator.ParameterValues; + + foreach (object parameter in parameters) + { + command.AddParameter(property: null, $"@_{startParamIndex++}", parameter); + } + + sb.AppendLine(); + } + + if (options.ScoreThreshold is not null) + { + command.Parameters.AddWithValue("@scoreThreshold", options.ScoreThreshold!.Value); + sb.Append(options.Filter is not null ? "AND " : "WHERE "); + sb.AppendLine("s.[distance] <= @scoreThreshold"); + } + + sb.AppendFormat("ORDER BY [score] {0}", sorting); + + if (needsSubquery) + { + sb.AppendLine(); + sb.Append(") AS [inner]"); + sb.AppendLine(); + sb.AppendFormat("ORDER BY [score] {0}", sorting); + sb.AppendLine(); + sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, top); + } + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand SelectHybrid( + SqlConnection connection, string? schema, string tableName, + VectorPropertyModel vectorProperty, + DataPropertyModel textProperty, + CollectionModel model, + int top, + HybridSearchOptions options, + SqlVector vector, + string keywords) + { + bool useVectorSearch = UseVectorSearch(vectorProperty); + + string distanceFunction = vectorProperty.DistanceFunction ?? DistanceFunction.CosineDistance; + (string distanceMetric, _) = MapDistanceFunction(distanceFunction); + + SqlCommand command = connection.CreateCommand(); + command.Parameters.AddWithValue("@vector", vector); + command.Parameters.AddWithValue("@keywords", keywords); + + // For RRF, we need to fetch more candidates from each search than the final top count + // to allow proper merging. The number of candidates should be at least top + skip. 
+ // The RRF constant (k) is typically 60 in the literature, and that is the value + // used here (see RrfK below); it is passed to the query as the @rrfK parameter. + int candidateCount = Math.Max(top + options.Skip, 20); // Fetch at least 20 candidates + const int RrfK = 60; // Standard RRF constant + + command.Parameters.AddWithValue("@candidateCount", candidateCount); + command.Parameters.AddWithValue("@rrfK", RrfK); + + StringBuilder sb = new(1000); + + // Build the hybrid search query using CTEs with Reciprocal Rank Fusion (RRF) + // Reference: https://github.com/Azure-Samples/azure-sql-db-openai/blob/main/vector-embeddings/07-hybrid-search.sql + + // CTE 1: Keyword search using FREETEXTTABLE + sb.AppendLine("WITH keyword_search AS ("); + sb.AppendLine(" SELECT TOP(@candidateCount)"); + sb.Append(" ").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(","); + sb.AppendLine(" RANK() OVER (ORDER BY ft_rank DESC) AS [rank]"); + sb.AppendLine(" FROM ("); + sb.AppendLine(" SELECT TOP(@candidateCount)"); + sb.Append(" w.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(","); + sb.AppendLine(" ftt.[RANK] AS ft_rank"); + sb.Append(" FROM ").AppendTableName(schema, tableName).AppendLine(" w"); + sb.Append(" INNER JOIN FREETEXTTABLE(").AppendTableName(schema, tableName).Append(", ") + .AppendIdentifier(textProperty.StorageName).AppendLine(", @keywords) AS ftt"); + sb.Append(" ON w.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(" = ftt.[KEY]"); + + // Apply filter to keyword search if specified + if (options.Filter is not null) + { + int startParamIndex = command.Parameters.Count; + SqlServerFilterTranslator translator = new(model, options.Filter, sb, startParamIndex: startParamIndex, tableAlias: "w"); + translator.Translate(appendWhere: true); + foreach (object parameter in translator.ParameterValues) + { + command.AddParameter(property: null, $"@_{startParamIndex++}", parameter); + } + sb.AppendLine(); + } + + sb.AppendLine(" ORDER BY ft_rank
DESC"); + sb.AppendLine(" ) AS freetext_documents"); + sb.AppendLine("),"); + + // CTE 2: Semantic/vector search + if (useVectorSearch) + { + // Use VECTOR_SEARCH() for approximate nearest neighbor search with a vector index. + // The latest version vector indexes require SELECT TOP(N) WITH APPROXIMATE instead of the deprecated TOP_N parameter. + sb.AppendLine("semantic_search AS ("); + sb.AppendLine(" SELECT TOP(@candidateCount) WITH APPROXIMATE"); + sb.Append(" t.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(","); + sb.AppendLine(" RANK() OVER (ORDER BY s.[distance]) AS [rank]"); + sb.AppendLine(" FROM VECTOR_SEARCH(TABLE = "); + sb.Append(" ").AppendTableName(schema, tableName); + sb.Append(" AS t, COLUMN = ").AppendIdentifier(vectorProperty.StorageName); + sb.Append(", SIMILAR_TO = @vector, METRIC = '").Append(distanceMetric).AppendLine("') AS s"); + + // With latest version vector indexes, WHERE predicates are applied during the vector search process + // (iterative filtering), not after retrieval. 
+ if (options.Filter is not null) + { + int filterParamStart = command.Parameters.Count; + SqlServerFilterTranslator translator = new(model, options.Filter, sb, startParamIndex: filterParamStart, tableAlias: "t"); + translator.Translate(appendWhere: true); + foreach (object parameter in translator.ParameterValues) + { + command.AddParameter(property: null, $"@_{filterParamStart++}", parameter); + } + sb.AppendLine(); + } + + sb.AppendLine(" ORDER BY s.[distance]"); + sb.AppendLine("),"); + } + else + { + // Use VECTOR_DISTANCE() for exact brute-force search (flat index / no index) + sb.AppendLine("semantic_search AS ("); + sb.AppendLine(" SELECT TOP(@candidateCount)"); + sb.Append(" ").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(","); + sb.AppendLine(" RANK() OVER (ORDER BY cosine_distance) AS [rank]"); + sb.AppendLine(" FROM ("); + sb.AppendLine(" SELECT TOP(@candidateCount)"); + sb.Append(" w.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(","); + sb.Append(" VECTOR_DISTANCE('").Append(distanceMetric).Append("', ") + .AppendIdentifier(vectorProperty.StorageName) + .Append(", CAST(@vector AS VECTOR(").Append(vector.Length).AppendLine("))) AS cosine_distance"); + sb.Append(" FROM ").AppendTableName(schema, tableName).AppendLine(" w"); + + // Apply filter to semantic search if specified + if (options.Filter is not null) + { + // We need to re-translate the filter for the semantic search CTE + // The parameters are already added from keyword search, so we start fresh for this CTE + int filterParamStart = command.Parameters.Count; + SqlServerFilterTranslator translator = new(model, options.Filter, sb, startParamIndex: filterParamStart, tableAlias: "w"); + translator.Translate(appendWhere: true); + foreach (object parameter in translator.ParameterValues) + { + command.AddParameter(property: null, $"@_{filterParamStart++}", parameter); + } + sb.AppendLine(); + } + + sb.AppendLine(" ORDER BY cosine_distance"); + sb.AppendLine(" ) AS 
similar_documents"); + sb.AppendLine("),"); + } + + // CTE 3: Combined results with RRF scoring + sb.AppendLine("hybrid_result AS ("); + sb.AppendLine(" SELECT"); + sb.Append(" COALESCE(ss.").AppendIdentifier(model.KeyProperty.StorageName) + .Append(", ks.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(") AS combined_key,"); + sb.AppendLine(" ss.[rank] AS semantic_rank,"); + sb.AppendLine(" ks.[rank] AS keyword_rank,"); + // Cast to FLOAT to match the expected return type in C# (double) + // Use @rrfK as the RRF constant (typically 60) + sb.AppendLine(" CAST(COALESCE(1.0 / (@rrfK + ss.[rank]), 0.0) + COALESCE(1.0 / (@rrfK + ks.[rank]), 0.0) AS FLOAT) AS [score]"); + sb.AppendLine(" FROM semantic_search ss"); + sb.Append(" FULL OUTER JOIN keyword_search ks ON ss.").AppendIdentifier(model.KeyProperty.StorageName) + .Append(" = ks.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(); + sb.AppendLine(")"); + + // Final SELECT joining back to the main table + sb.Append("SELECT "); + foreach (var property in model.Properties) + { + if (!options.IncludeVectors && property is VectorPropertyModel) + { + continue; + } + sb.Append("w.").AppendIdentifier(property.StorageName).Append(','); + } + sb.Length--; // remove trailing comma + sb.AppendLine(","); + sb.AppendLine(" hr.[score]"); + sb.AppendLine("FROM hybrid_result hr"); + sb.Append("INNER JOIN ").AppendTableName(schema, tableName).AppendLine(" w"); + sb.Append(" ON hr.combined_key = w.").AppendIdentifier(model.KeyProperty.StorageName).AppendLine(); + if (options.ScoreThreshold.HasValue) + { + command.Parameters.AddWithValue("@scoreThreshold", options.ScoreThreshold.Value); + sb.AppendLine("WHERE hr.[score] >= @scoreThreshold"); + } + sb.AppendLine("ORDER BY hr.[score] DESC"); + sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, top); + + command.CommandText = sb.ToString(); + return command; + } + + internal static SqlCommand SelectWhere( + Expression> filter, + int top, + 
FilteredRecordRetrievalOptions options, + SqlConnection connection, string? schema, string tableName, + CollectionModel model) + { + SqlCommand command = connection.CreateCommand(); + + StringBuilder sb = new(200); + sb.Append("SELECT "); + sb.AppendIdentifiers(model.Properties, includeVectors: options.IncludeVectors); + sb.AppendLine(); + sb.Append("FROM "); + sb.AppendTableName(schema, tableName); + sb.AppendLine(); + if (filter is not null) + { + int startParamIndex = command.Parameters.Count; + + SqlServerFilterTranslator translator = new(model, filter, sb, startParamIndex: startParamIndex); + translator.Translate(appendWhere: true); + List parameters = translator.ParameterValues; + + foreach (object parameter in parameters) + { + command.AddParameter(property: null, $"@_{startParamIndex++}", parameter); + } + sb.AppendLine(); + } + + var orderBy = options.OrderBy?.Invoke(new()).Values; + if (orderBy is { Count: > 0 }) + { + sb.Append("ORDER BY "); + + var first = true; + foreach (var sortInfo in orderBy) + { + if (!first) + { + sb.Append(','); + } + first = false; + sb.AppendIdentifier(model.GetDataOrKeyProperty(sortInfo.PropertySelector).StorageName) + .Append(sortInfo.Ascending ? " ASC" : " DESC"); + } + + sb.AppendLine(); + } + else + { + // no order by properties, but we need to add something for OFFSET and NEXT to work + sb.AppendLine("ORDER BY (SELECT 1)"); + } + + // Negative Skip and Top values are rejected by the GetFilteredRecordOptions property setters. + // 0 is a legal value for OFFSET. + sb.AppendFormat("OFFSET {0} ROWS FETCH NEXT {1} ROWS ONLY;", options.Skip, top); + + command.CommandText = sb.ToString(); + return command; + } + + internal static StringBuilder AppendParameterName(this StringBuilder sb, PropertyModel property, ref int paramIndex, out string parameterName) + { + // In SQL Server, parameter names cannot be just a number like "@1". 
+ // Parameter names must start with an alphabetic character or an underscore + // and can be followed by alphanumeric characters or underscores. + // Since we can't guarantee that the value returned by StoragePropertyName and DataModelPropertyName + // is valid parameter name (it can contain whitespaces, or start with a number), + // we just append the ASCII letters, stop on the first non-ASCII letter + // and append the index. + int index = sb.Length; + sb.Append('@'); + foreach (char character in property.StorageName) + { + // We don't call APIs like char.IsWhitespace as they are expensive + // as they need to handle all Unicode characters. + if (character is not (>= 'a' and <= 'z' or >= 'A' and <= 'Z')) + { + break; + } + sb.Append(character); + } + // In case the column name is empty or does not start with ASCII letters, + // we provide the underscore as a prefix (allowed). + sb.Append('_'); + // To ensure the generated parameter id is unique, we append the index. + sb.Append(paramIndex++); + parameterName = sb.ToString(index, sb.Length - index); + + return sb; + } + + internal static StringBuilder AppendTableName(this StringBuilder sb, string? schema, string tableName) + { + // If the identifier contains a ], then escape it by doubling it. + // "Name with [brackets]" becomes [Name with [brackets]]]. + + if (!string.IsNullOrEmpty(schema)) + { + sb.AppendIdentifier(schema!).Append('.'); + } + + return sb.AppendIdentifier(tableName); + } + + /// + /// Appends a properly quoted and escaped SQL Server identifier to the StringBuilder. + /// If the identifier contains a ], it is escaped by doubling it. 
+ /// + internal static StringBuilder AppendIdentifier(this StringBuilder sb, string identifier) + { + sb.Append('['); + int index = sb.Length; + sb.Append(identifier); + sb.Replace("]", "]]", index, identifier.Length); + sb.Append(']'); + return sb; + } + + /// + /// Same as AppendTableName, but for use inside a SQL string literal (N'...'), + /// where single quotes must be escaped by doubling them. + /// + internal static StringBuilder AppendTableNameInsideLiteral(this StringBuilder sb, string? schema, string tableName) + { + int start = sb.Length; + sb.AppendTableName(schema, tableName); + sb.Replace("'", "''", start, sb.Length - start); + return sb; + } + + /// + /// Same as AppendIdentifier, but for use inside a SQL string literal (N'...'), + /// where single quotes must be escaped by doubling them. + /// + internal static StringBuilder AppendIdentifierInsideLiteral(this StringBuilder sb, string identifier) + { + int start = sb.Length; + sb.AppendIdentifier(identifier); + sb.Replace("'", "''", start, sb.Length - start); + return sb; + } + + private static StringBuilder AppendIdentifiers(this StringBuilder sb, + IEnumerable properties, + string? prefix = null, + bool includeVectors = true, + bool skipKey = false) + { + bool any = false; + foreach (var property in properties) + { + if (!includeVectors && property is VectorPropertyModel) + { + continue; + } + + if (skipKey && property is KeyPropertyModel) + { + continue; + } + + if (prefix is not null) + { + sb.Append(prefix); + } + sb.AppendIdentifier(property.StorageName).Append(','); + any = true; + } + + if (any) + { + --sb.Length; // remove the last comma + } + + return sb; + } + + private static StringBuilder AppendKeyParameterList(this StringBuilder sb, + IEnumerable keys, SqlCommand command, KeyPropertyModel keyProperty, out bool emptyKeys) + { + int keyIndex = 0; + foreach (TKey key in keys) + { + // The caller ensures that keys collection is not null. + // We need to ensure that none of the keys is null.
+ Throw.IfNull(key); + + sb.AppendParameterName(keyProperty, ref keyIndex, out string keyParamName); + sb.Append(','); + command.AddParameter(keyProperty, keyParamName, key); + } + + emptyKeys = keyIndex == 0; + if (!emptyKeys) + { + // Only trim the trailing comma when at least one key was appended; with no keys the last + // character was written by the caller (or the builder is empty) and must be left intact. + sb.Length--; + } + return sb; + } + + private static StringBuilder AppendIndexName(this StringBuilder sb, string tableName, string columnName) + { + int length = sb.Length; + + // "Index names must start with a letter or an underscore (_)." + sb.Append("index"); + sb.Append('_'); + AppendAllowedOnly(tableName); + sb.Append('_'); + AppendAllowedOnly(columnName); + + if (sb.Length > length + SqlServerConstants.MaxIndexNameLength) + { + sb.Length = length + SqlServerConstants.MaxIndexNameLength; + } + + return sb; + + void AppendAllowedOnly(string value) + { + foreach (char c in value) + { + // Index names can include letters, numbers, and underscores. + if (char.IsLetterOrDigit(c) || c == '_') + { + sb.Append(c); + } + } + } + } + + private static SqlCommand CreateCommand(this SqlConnection connection, StringBuilder sb) + { + SqlCommand command = connection.CreateCommand(); + command.CommandText = sb.ToString(); + return command; + } + + private static void AddParameter(this SqlCommand command, PropertyModel? property, string name, object?
value) + { + switch (value) + { + case null when property?.Type == typeof(byte[]): + command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = DBNull.Value; + break; + case null: + command.Parameters.AddWithValue(name, DBNull.Value); + break; + case byte[] buffer: + command.Parameters.Add(name, System.Data.SqlDbType.VarBinary).Value = buffer; + break; + case DateTime dateTime: + command.Parameters.Add(name, System.Data.SqlDbType.DateTime2).Value = dateTime; + break; + + // Note that SqlVector doesn't need any transformation and can be passed as-is (default case below) + case ReadOnlyMemory vector: + command.Parameters.AddWithValue(name, new SqlVector(vector)); + break; + case Embedding { Vector: var vector }: + command.Parameters.AddWithValue(name, new SqlVector(vector)); + break; + case float[] vectorArray: + command.Parameters.AddWithValue(name, new SqlVector(vectorArray)); + break; + + case string[] strings: + command.Parameters.AddWithValue(name, JsonSerializer.Serialize(strings, SqlServerJsonSerializerContext.Default.StringArray)); + break; + case List strings: + command.Parameters.AddWithValue(name, JsonSerializer.Serialize(strings, SqlServerJsonSerializerContext.Default.ListString)); + break; + + default: + command.Parameters.AddWithValue(name, value); + break; + } + } + + private static string Map(PropertyModel property) + => (Nullable.GetUnderlyingType(property.Type) ??
property.Type) switch + { + Type t when t == typeof(byte) => "TINYINT", + Type t when t == typeof(short) => "SMALLINT", + Type t when t == typeof(int) => "INT", + Type t when t == typeof(long) => "BIGINT", + Type t when t == typeof(Guid) => "UNIQUEIDENTIFIER", + Type t when t == typeof(string) && property is KeyPropertyModel => "NVARCHAR(4000)", + Type t when t == typeof(string) && property is DataPropertyModel { IsIndexed: true } => "NVARCHAR(4000)", + Type t when t == typeof(string) => "NVARCHAR(MAX)", + Type t when t == typeof(byte[]) => "VARBINARY(MAX)", + Type t when t == typeof(bool) => "BIT", + Type t when t == typeof(DateTime) => "DATETIME2", + Type t when t == typeof(DateTimeOffset) => "DATETIMEOFFSET", +#if NET + Type t when t == typeof(DateOnly) => "DATE", + Type t when t == typeof(TimeOnly) => "TIME", +#endif + Type t when t == typeof(decimal) => "DECIMAL(18,2)", + Type t when t == typeof(double) => "FLOAT", + Type t when t == typeof(float) => "REAL", + + Type t when t == typeof(string[]) || t == typeof(List) => "JSON", + + _ => throw new NotSupportedException($"Type {property.Type} is not supported.") + }; + + // Source: https://learn.microsoft.com/sql/t-sql/functions/vector-distance-transact-sql + private static (string distanceMetric, string sorting) MapDistanceFunction(string name) => name switch + { + // A value of 0 indicates that the vectors are identical in direction (cosine similarity of 1), + // while a value of 1 indicates that the vectors are orthogonal (cosine similarity of 0). + DistanceFunction.CosineDistance => ("COSINE", "ASC"), + // A value of 0 indicates that the vectors are identical, while larger values indicate greater dissimilarity. 
+ DistanceFunction.EuclideanDistance => ("EUCLIDEAN", "ASC"), + // Smaller numbers indicate more similar vectors + DistanceFunction.NegativeDotProductSimilarity => ("DOT", "ASC"), + _ => throw new NotSupportedException($"Distance function {name} is not supported.") + }; + + /// + /// Returns whether VECTOR_SEARCH() (approximate/indexed search) should be used for the given vector property, + /// as opposed to VECTOR_DISTANCE() (exact/brute-force search). + /// + private static bool UseVectorSearch(VectorPropertyModel vectorProperty) + => vectorProperty.IndexKind is not (null or "" or IndexKind.Flat); +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerConstants.cs b/src/Microsoft.SqlServer.VectorData/SqlServerConstants.cs new file mode 100644 index 0000000..3a6472d --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerConstants.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft. All rights reserved. + +namespace Microsoft.SqlServer.VectorData; + +internal static class SqlServerConstants +{ + internal const string VectorStoreSystemName = "microsoft.sql_server"; + + // The actual number is actually higher (2_100), but we want to avoid any kind of "off by one" errors. + internal const int MaxParameterCount = 2_000; + + internal const int MaxIndexNameLength = 128; +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerDynamicCollection.cs b/src/Microsoft.SqlServer.VectorData/SqlServerDynamicCollection.cs new file mode 100644 index 0000000..7381b72 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerDynamicCollection.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; + +namespace Microsoft.SqlServer.VectorData; + +/// +/// Represents a collection of vector store records in a SqlServer database, mapped to a dynamic Dictionary<string, object?>. 
+/// +#pragma warning disable CA1711 // Identifiers should not have incorrect suffix +public sealed class SqlServerDynamicCollection : SqlServerCollection> +#pragma warning restore CA1711 // Identifiers should not have incorrect suffix +{ + /// + /// Initializes a new instance of the class. + /// + /// Database connection string. + /// The name of the collection. + /// Optional configuration options for this class. + // TODO: The provider uses unsafe JSON serialization in many places, #11963 + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public SqlServerDynamicCollection(string connectionString, string name, SqlServerCollectionOptions options) + : base( + connectionString, + name, + static options => new SqlServerModelBuilder() + .BuildDynamic( + options.Definition ?? throw new ArgumentException("RecordDefinition is required for dynamic collections"), + options.EmbeddingGenerator), + options) + { + } +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerFilterTranslator.cs b/src/Microsoft.SqlServer.VectorData/SqlServerFilterTranslator.cs new file mode 100644 index 0000000..6b42fda --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerFilterTranslator.cs @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections; +using System.Collections.Generic; +#if NET +using System.Globalization; +#endif +using System.Linq.Expressions; +using System.Text; +using Microsoft.Extensions.VectorData.ProviderServices; + +namespace Microsoft.SqlServer.VectorData; + +internal sealed class SqlServerFilterTranslator : SqlFilterTranslator +{ + private readonly List _parameterValues = []; + private readonly string? 
_tableAlias; + private int _parameterIndex; + + internal SqlServerFilterTranslator( + CollectionModel model, + LambdaExpression lambdaExpression, + StringBuilder sql, + int startParamIndex, + string? tableAlias = null) + : base(model, lambdaExpression, sql) + { + this._parameterIndex = startParamIndex; + this._tableAlias = tableAlias; + } + + internal List ParameterValues => this._parameterValues; + + protected override void TranslateConstant(object? value, bool isSearchCondition) + { + switch (value) + { + case bool boolValue when isSearchCondition: + this._sql.Append(boolValue ? "1 = 1" : "1 = 0"); + return; + case bool boolValue: + this._sql.Append(boolValue ? "CAST(1 AS BIT)" : "CAST(0 AS BIT)"); + return; + case DateTime dateTime: + this._sql.Append('\'').Append(dateTime.ToString("o")).Append('\''); + return; + case DateTimeOffset dateTimeOffset: + this._sql.Append('\'').Append(dateTimeOffset.ToString("o")).Append('\''); + return; +#if NET + case DateOnly dateOnly: + this._sql.Append('\'').Append(dateOnly.ToString("o")).Append('\''); + return; + case TimeOnly timeOnly: + // The literal is already fully formatted, so append it verbatim rather than + // having AppendFormat re-parse it as a composite format string. + this._sql.Append(timeOnly.Ticks % 10000000 == 0 + ? string.Format(CultureInfo.InvariantCulture, @"'{0:HH\:mm\:ss}'", value) + : string.Format(CultureInfo.InvariantCulture, @"'{0:HH\:mm\:ss\.FFFFFFF}'", value)); + return; +#endif + + default: + base.TranslateConstant(value, isSearchCondition); + break; + } + } + + protected override void GenerateColumn(PropertyModel property, bool isSearchCondition = false) + { + // StorageName is considered to be a safe input, we quote and escape it mostly to produce valid SQL. + if (this._tableAlias is not null) + { + this._sql.Append(this._tableAlias).Append('.'); + } + this._sql.Append('[').Append(property.StorageName.Replace("]", "]]")).Append(']'); + + // "SELECT * FROM MyTable WHERE BooleanColumn;" is not supported. + // "SELECT * FROM MyTable WHERE BooleanColumn = 1;" is supported.
+ if (isSearchCondition) + { + this._sql.Append(" = 1"); + } + } + + protected override void TranslateContainsOverArrayColumn(Expression source, Expression item) + { + if (item.Type != typeof(string)) + { + throw new NotSupportedException("Unsupported Contains expression"); + } + + this._sql.Append("JSON_CONTAINS("); + this.Translate(source); + this._sql.Append(", "); + this.Translate(item); + this._sql.Append(") = 1"); + } + + protected override void TranslateContainsOverParameterizedArray(Expression source, Expression item, object? value) + { + if (value is not IEnumerable elements) + { + throw new NotSupportedException("Unsupported Contains expression"); + } + + this.Translate(item); + this._sql.Append(" IN ("); + + var isFirst = true; + foreach (var element in elements) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.TranslateConstant(element, isSearchCondition: false); + } + + this._sql.Append(')'); + } + + protected override void TranslateAnyContainsOverArrayColumn(PropertyModel property, object? values) + { + // Translate r.Strings.Any(s => array.Contains(s)) to: + // EXISTS(SELECT 1 FROM OPENJSON(column) WHERE value IN ('a', 'b', 'c')) + if (values is not IEnumerable elements) + { + throw new NotSupportedException("Unsupported Any expression"); + } + + this._sql.Append("EXISTS(SELECT 1 FROM OPENJSON("); + this.GenerateColumn(property); + this._sql.Append(") WHERE value IN ("); + + var isFirst = true; + foreach (var element in elements) + { + if (isFirst) + { + isFirst = false; + } + else + { + this._sql.Append(", "); + } + + this.TranslateConstant(element, isSearchCondition: false); + } + + this._sql.Append("))"); + } + + protected override void TranslateQueryParameter(object? 
value) + { + // For null values, simply inline rather than parameterize; parameterized NULLs require setting SqlDbType which is a bit more complicated, + // plus in any case equality with NULL requires different SQL (x IS NULL rather than x = y) + if (value is null) + { + this._sql.Append("NULL"); + } + else + { + this._parameterValues.Add(value); + // The param name is just the index, so there is no need for escaping or quoting. + // SQL Server parameters can't start with a digit (but underscore is OK). + this._sql.Append("@_").Append(this._parameterIndex++); + } + } +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerJsonSerializerContext.cs b/src/Microsoft.SqlServer.VectorData/SqlServerJsonSerializerContext.cs new file mode 100644 index 0000000..8bd9de7 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerJsonSerializerContext.cs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Microsoft.SqlServer.VectorData; + +// For mapping string[] properties to SQL Server JSON columns +[JsonSerializable(typeof(string[]))] +[JsonSerializable(typeof(List))] +internal partial class SqlServerJsonSerializerContext : JsonSerializerContext; diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerMapper.cs b/src/Microsoft.SqlServer.VectorData/SqlServerMapper.cs new file mode 100644 index 0000000..1439161 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerMapper.cs @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft. All rights reserved.
+ +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Text.Json; +using Microsoft.Data.SqlClient; +using Microsoft.Data.SqlTypes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData.ProviderServices; + +namespace Microsoft.SqlServer.VectorData; + +internal sealed class SqlServerMapper(CollectionModel model) +{ + public TRecord MapFromStorageToDataModel(SqlDataReader reader, bool includeVectors) + { + var record = model.CreateRecord()!; + + PopulateValue(reader, model.KeyProperty, record); + + foreach (var property in model.DataProperties) + { + PopulateValue(reader, property, record); + } + + if (includeVectors) + { + foreach (var property in model.VectorProperties) + { + try + { + var ordinal = reader.GetOrdinal(property.StorageName); + + if (!reader.IsDBNull(ordinal)) + { + var vector = reader.GetFieldValue>(ordinal); + + property.SetValueAsObject(record, property.Type switch + { + var t when t == typeof(SqlVector) => vector, + var t when t == typeof(ReadOnlyMemory) => vector.Memory, + var t when t == typeof(Embedding) => new Embedding(vector.Memory), + var t when t == typeof(float[]) + => MemoryMarshal.TryGetArray(vector.Memory, out ArraySegment segment) + && segment.Count == segment.Array!.Length + ? segment.Array + : vector.Memory.ToArray(), + + _ => throw new UnreachableException() + }); + } + } + catch (Exception e) + { + throw new InvalidOperationException($"Failed to deserialize vector property '{property.ModelName}'.", e); + } + } + } + + return record; + + static void PopulateValue(SqlDataReader reader, PropertyModel property, object record) + { + try + { + var ordinal = reader.GetOrdinal(property.StorageName); + + if (reader.IsDBNull(ordinal)) + { + property.SetValueAsObject(record, null); + return; + } + + switch (Nullable.GetUnderlyingType(property.Type) ?? 
property.Type) + { + case var t when t == typeof(byte): + property.SetValue(record, reader.GetByte(ordinal)); // TINYINT + break; + case var t when t == typeof(short): + property.SetValue(record, reader.GetInt16(ordinal)); // SMALLINT + break; + case var t when t == typeof(int): + property.SetValue(record, reader.GetInt32(ordinal)); // INT + break; + case var t when t == typeof(long): + property.SetValue(record, reader.GetInt64(ordinal)); // BIGINT + break; + + case var t when t == typeof(float): + property.SetValue(record, reader.GetFloat(ordinal)); // REAL + break; + case var t when t == typeof(double): + property.SetValue(record, reader.GetDouble(ordinal)); // FLOAT + break; + case var t when t == typeof(decimal): + property.SetValue(record, reader.GetDecimal(ordinal)); // DECIMAL + break; + + case var t when t == typeof(string): + property.SetValue(record, reader.GetString(ordinal)); // NVARCHAR + break; + case var t when t == typeof(Guid): + property.SetValue(record, reader.GetGuid(ordinal)); // UNIQUEIDENTIFIER + break; + case var t when t == typeof(byte[]): + property.SetValueAsObject(record, reader.GetValue(ordinal)); // VARBINARY + break; + case var t when t == typeof(bool): + property.SetValue(record, reader.GetBoolean(ordinal)); // BIT + break; + + case var t when t == typeof(DateTime): + property.SetValue(record, reader.GetDateTime(ordinal)); // DATETIME2 + break; + case var t when t == typeof(DateTimeOffset): + property.SetValue(record, reader.GetDateTimeOffset(ordinal)); // DATETIMEOFFSET + break; +#if NET + case var t when t == typeof(DateOnly): + property.SetValue(record, reader.GetFieldValue(ordinal)); // DATE + break; + case var t when t == typeof(TimeOnly): + property.SetValue(record, reader.GetFieldValue(ordinal)); // TIME + break; +#endif + + // We map string[] and List properties to SQL Server JSON columns, so deserialize from JSON here. 
+ case var t when t == typeof(string[]): + property.SetValue(record, JsonSerializer.Deserialize( + reader.GetString(ordinal), + SqlServerJsonSerializerContext.Default.StringArray)); + break; + case var t when t == typeof(List): + property.SetValue(record, JsonSerializer.Deserialize>( + reader.GetString(ordinal), + SqlServerJsonSerializerContext.Default.ListString)); + break; + + default: + throw new NotSupportedException($"Unsupported type '{property.Type.Name}' for property '{property.ModelName}'."); + } + } + catch (Exception ex) + { + throw new InvalidOperationException($"Failed to read property '{property.ModelName}' of type '{property.Type.Name}'.", ex); + } + } + } +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerModelBuilder.cs b/src/Microsoft.SqlServer.VectorData/SqlServerModelBuilder.cs new file mode 100644 index 0000000..306f349 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerModelBuilder.cs @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Data.SqlTypes; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ProviderServices; + +namespace Microsoft.SqlServer.VectorData; + +internal class SqlServerModelBuilder() : CollectionModelBuilder(s_modelBuildingOptions) +{ + internal const string SupportedVectorTypes = "SqlVector, ReadOnlyMemory, Embedding, float[]"; + internal const string SupportedIndexKinds = $"{IndexKind.Flat}, {IndexKind.DiskAnn}"; + + private static readonly CollectionModelBuildingOptions s_modelBuildingOptions = new() + { + RequiresAtLeastOneVector = false, + SupportsMultipleVectors = true, + }; + + protected override bool SupportsKeyAutoGeneration(Type keyPropertyType) + => keyPropertyType == typeof(Guid) || keyPropertyType == typeof(int) || keyPropertyType == typeof(long); + + protected override void ValidateKeyProperty(KeyPropertyModel keyProperty) + { + base.ValidateKeyProperty(keyProperty); + + var type = keyProperty.Type; + + if (type != typeof(int) && type != typeof(long) && type != typeof(string) && type != typeof(Guid)) + { + throw new NotSupportedException( + $"Property '{keyProperty.ModelName}' has unsupported type '{type.Name}'. Key properties must be one of the supported types: int, long, string, Guid."); + } + } + + protected override void ValidateProperty(PropertyModel propertyModel, VectorStoreCollectionDefinition? definition) + { + base.ValidateProperty(propertyModel, definition); + + if (propertyModel is VectorPropertyModel vectorProperty) + { + switch (vectorProperty.IndexKind) + { + case IndexKind.Flat or IndexKind.DiskAnn or null or "": + break; + default: + throw new NotSupportedException( + $"Index kind '{vectorProperty.IndexKind}' is not supported by the SQL Server connector. 
Supported index kinds: {SupportedIndexKinds}"); + } + } + } + + protected override bool IsDataPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) + { + supportedTypes = "string, short, int, long, double, float, decimal, bool, DateTime, DateTimeOffset, DateOnly, TimeOnly, Guid, byte[], string[], List"; + + if (Nullable.GetUnderlyingType(type) is Type underlyingType) + { + type = underlyingType; + } + + return type == typeof(int) // INT + || type == typeof(short) // SMALLINT + || type == typeof(byte) // TINYINT + || type == typeof(long) // BIGINT. + || type == typeof(Guid) // UNIQUEIDENTIFIER. + || type == typeof(string) // NVARCHAR + || type == typeof(byte[]) // VARBINARY + || type == typeof(bool) // BIT + || type == typeof(DateTime) // DATETIME2 + || type == typeof(DateTimeOffset) // DATETIMEOFFSET +#if NET + || type == typeof(DateOnly) // DATE + // We don't support mapping TimeSpan to TIME on purpose + // See https://github.com/microsoft/semantic-kernel/pull/10623#discussion_r1980350721 + || type == typeof(TimeOnly) // TIME +#endif + || type == typeof(decimal) // DECIMAL + || type == typeof(double) // FLOAT + || type == typeof(float) // REAL + + // We map string[] to the SQL Server 2025 JSON data type (anyone using vector search is already using 2025) + || type == typeof(string[]) // JSON + || type == typeof(List); // JSON + } + + protected override bool IsVectorPropertyTypeValid(Type type, [NotNullWhen(false)] out string? supportedTypes) + => IsVectorPropertyTypeValidCore(type, out supportedTypes); + + internal static bool IsVectorPropertyTypeValidCore(Type type, [NotNullWhen(false)] out string? supportedTypes) + { + supportedTypes = SupportedVectorTypes; + + return type == typeof(ReadOnlyMemory) + || type == typeof(ReadOnlyMemory?) 
+ || type == typeof(Embedding) + || type == typeof(float[]) + // SqlClient-specific type representing a vector + || type == typeof(SqlVector) + || type == typeof(SqlVector?); + } +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerServiceCollectionExtensions.cs b/src/Microsoft.SqlServer.VectorData/SqlServerServiceCollectionExtensions.cs new file mode 100644 index 0000000..5734bca --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerServiceCollectionExtensions.cs @@ -0,0 +1,195 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Diagnostics.CodeAnalysis; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Shared.Diagnostics; +using Microsoft.SqlServer.VectorData; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Extension methods to register instances on an . +/// +public static class SqlServerServiceCollectionExtensions +{ + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public static IServiceCollection AddSqlServerVectorStore( + this IServiceCollection services, + Func connectionStringProvider, + Func? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + => AddKeyedSqlServerVectorStore(services, serviceKey: null, connectionStringProvider, optionsProvider, lifetime); + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the vector store. + /// The connection string provider. + /// Options provider to further configure the vector store. + /// The service lifetime for the store. Defaults to . + /// The service collection. 
+ [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public static IServiceCollection AddKeyedSqlServerVectorStore( + this IServiceCollection services, + object? serviceKey, + Func connectionStringProvider, + Func? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + { + Throw.IfNull(services); + Throw.IfNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(SqlServerVectorStore), serviceKey, (sp, _) => + { + var connectionString = connectionStringProvider(sp); + var options = GetStoreOptions(sp, optionsProvider); + return new SqlServerVectorStore(connectionString, options); + }, lifetime)); + + services.Add(new ServiceDescriptor(typeof(VectorStore), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService(key), lifetime)); + + return services; + } + + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public static IServiceCollection AddSqlServerCollection( + this IServiceCollection services, + string name, + Func connectionStringProvider, + Func? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : class + => AddKeyedSqlServerCollection(services, serviceKey: null, name, connectionStringProvider, optionsProvider, lifetime); + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the collection. + /// The name of the collection. + /// The connection string provider. + /// Options provider to further configure the collection. 
+ /// The service lifetime for the store. Defaults to . + /// The service collection. + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public static IServiceCollection AddKeyedSqlServerCollection( + this IServiceCollection services, + object? serviceKey, + string name, + Func connectionStringProvider, + Func? optionsProvider = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : class + { + Throw.IfNull(services); + Throw.IfNullOrWhitespace(name); + Throw.IfNull(connectionStringProvider); + + services.Add(new ServiceDescriptor(typeof(SqlServerCollection), serviceKey, (sp, _) => + { + var connectionString = connectionStringProvider(sp); + var options = GetCollectionOptions(sp, optionsProvider); + return new SqlServerCollection(connectionString, name, options); + }, lifetime)); + + services.Add(new ServiceDescriptor(typeof(VectorStoreCollection), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); + + services.Add(new ServiceDescriptor(typeof(IVectorSearchable), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); + + services.Add(new ServiceDescriptor(typeof(IKeywordHybridSearchable), serviceKey, + static (sp, key) => sp.GetRequiredKeyedService>(key), lifetime)); + + return services; + } + + /// + /// Registers a as , with the specified connection string and service lifetime. + /// + /// + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public static IServiceCollection AddSqlServerCollection( + this IServiceCollection services, + string name, + string connectionString, + SqlServerCollectionOptions? 
options = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : class + => AddKeyedSqlServerCollection(services, serviceKey: null, name, connectionString, options, lifetime); + + /// + /// Registers a keyed as , with the specified connection string and service lifetime. + /// + /// The to register the on. + /// The key with which to associate the collection. + /// The name of the collection. + /// The connection string. + /// Options to further configure the collection. + /// The service lifetime for the store. Defaults to . + /// The service collection. + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public static IServiceCollection AddKeyedSqlServerCollection( + this IServiceCollection services, + object? serviceKey, + string name, + string connectionString, + SqlServerCollectionOptions? options = null, + ServiceLifetime lifetime = ServiceLifetime.Singleton) + where TKey : notnull + where TRecord : class + { + Throw.IfNullOrWhitespace(connectionString); + + return AddKeyedSqlServerCollection(services, serviceKey, name, _ => connectionString, _ => options!, lifetime); + } + + private static SqlServerVectorStoreOptions? GetStoreOptions(IServiceProvider sp, Func? optionsProvider) + { + var options = optionsProvider?.Invoke(sp); + if (options?.EmbeddingGenerator is not null) + { + return options; // The user has provided everything, there is nothing to change. + } + + var embeddingGenerator = sp.GetService(); + return embeddingGenerator is null + ? options // There is nothing to change. + : new(options) { EmbeddingGenerator = embeddingGenerator }; // Create a brand new copy in order to avoid modifying the original options. + } + + private static SqlServerCollectionOptions? GetCollectionOptions(IServiceProvider sp, Func? 
optionsProvider) + { + var options = optionsProvider?.Invoke(sp); + if (options?.EmbeddingGenerator is not null) + { + return options; // The user has provided everything, there is nothing to change. + } + + var embeddingGenerator = sp.GetService(); + return embeddingGenerator is null + ? options // There is nothing to change. + : new(options) { EmbeddingGenerator = embeddingGenerator }; // Create a brand new copy in order to avoid modifying the original options. + } +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerVectorStore.cs b/src/Microsoft.SqlServer.VectorData/SqlServerVectorStore.cs new file mode 100644 index 0000000..28f23dd --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerVectorStore.cs @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.AI; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ProviderServices; +using Microsoft.Shared.Diagnostics; + +namespace Microsoft.SqlServer.VectorData; + +/// +/// An implementation of backed by a SQL Server or Azure SQL database. +/// +public sealed class SqlServerVectorStore : VectorStore +{ + private readonly string _connectionString; + + /// Metadata about vector store. + private readonly VectorStoreMetadata _metadata; + + /// A general purpose definition that can be used to construct a collection when needing to proxy schema agnostic operations. + private static readonly VectorStoreCollectionDefinition s_generalPurposeDefinition = new() { Properties = [new VectorStoreKeyProperty("Key", typeof(string))] }; + + /// The database schema. + private readonly string? _schema; + + private readonly IEmbeddingGenerator? _embeddingGenerator; + + /// + /// Initializes a new instance of the class. 
+ /// + /// The connection string. + /// Optional configuration options. + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] + public SqlServerVectorStore(string connectionString, SqlServerVectorStoreOptions? options = null) + { + Throw.IfNullOrWhitespace(connectionString); + + this._connectionString = connectionString; + + options ??= SqlServerVectorStoreOptions.Defaults; + this._schema = options.Schema; + this._embeddingGenerator = options.EmbeddingGenerator; + + var connectionStringBuilder = new SqlConnectionStringBuilder(connectionString); + + this._metadata = new() + { + VectorStoreSystemName = SqlServerConstants.VectorStoreSystemName, + VectorStoreName = connectionStringBuilder.InitialCatalog + }; + } + +#pragma warning disable IDE0090 // Use 'new(...)' + /// + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] +#if NET + public override SqlServerCollection GetCollection(string name, VectorStoreCollectionDefinition? definition = null) +#else + public override VectorStoreCollection GetCollection(string name, VectorStoreCollectionDefinition? definition = null) +#endif + => typeof(TRecord) == typeof(Dictionary) + ? 
throw new ArgumentException(VectorDataStrings.GetCollectionWithDictionaryNotSupported) + : new SqlServerCollection( + this._connectionString, + name, + new() + { + Schema = this._schema, + Definition = definition, + EmbeddingGenerator = this._embeddingGenerator + }); + + /// + // TODO: The provider uses unsafe JSON serialization in many places, #11963 + [RequiresUnreferencedCode("The SQL Server provider is currently incompatible with trimming.")] + [RequiresDynamicCode("The SQL Server provider is currently incompatible with NativeAOT.")] +#if NET + public override SqlServerDynamicCollection GetDynamicCollection(string name, VectorStoreCollectionDefinition definition) +#else + public override VectorStoreCollection> GetDynamicCollection(string name, VectorStoreCollectionDefinition definition) +#endif + => new SqlServerDynamicCollection( + this._connectionString, + name, + new() + { + Schema = this._schema, + Definition = definition, + EmbeddingGenerator = this._embeddingGenerator, + } + ); +#pragma warning restore IDE0090 + + /// + public override async IAsyncEnumerable ListCollectionNamesAsync([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + using SqlConnection connection = new(this._connectionString); + using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(connection, this._schema); + + using SqlDataReader reader = await connection.ExecuteWithErrorHandlingAsync( + this._metadata, + operationName: "ListCollectionNames", + () => command.ExecuteReaderAsync(cancellationToken), + cancellationToken).ConfigureAwait(false); + + while (await reader.ReadWithErrorHandlingAsync( + this._metadata, + operationName: "ListCollectionNames", + cancellationToken).ConfigureAwait(false)) + { + yield return reader.GetString(reader.GetOrdinal("table_name")); + } + } + + /// + public override Task CollectionExistsAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetDynamicCollection(name, 
s_generalPurposeDefinition); + return collection.CollectionExistsAsync(cancellationToken); + } + + /// + public override Task EnsureCollectionDeletedAsync(string name, CancellationToken cancellationToken = default) + { + var collection = this.GetDynamicCollection(name, s_generalPurposeDefinition); + return collection.EnsureCollectionDeletedAsync(cancellationToken); + } + + /// + public override object? GetService(Type serviceType, object? serviceKey = null) + { + Throw.IfNull(serviceType); + + return + serviceKey is not null ? null : + serviceType == typeof(VectorStoreMetadata) ? this._metadata : + serviceType.IsInstanceOfType(this) ? this : + null; + } +} diff --git a/src/Microsoft.SqlServer.VectorData/SqlServerVectorStoreOptions.cs b/src/Microsoft.SqlServer.VectorData/SqlServerVectorStoreOptions.cs new file mode 100644 index 0000000..33761a4 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/SqlServerVectorStoreOptions.cs @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; + +namespace Microsoft.SqlServer.VectorData; + +/// +/// Options for creating a . +/// +public sealed class SqlServerVectorStoreOptions +{ + internal static readonly SqlServerVectorStoreOptions Defaults = new(); + + /// + /// Initializes a new instance of the class. + /// + public SqlServerVectorStoreOptions() + { + } + + internal SqlServerVectorStoreOptions(SqlServerVectorStoreOptions? source) + { + this.Schema = source?.Schema; + this.EmbeddingGenerator = source?.EmbeddingGenerator; + } + + /// + /// Gets or sets the database schema. + /// + public string? Schema { get; set; } + + /// + /// Gets or sets the default embedding generator to use when generating vector embeddings with this vector store. + /// + public IEmbeddingGenerator? 
EmbeddingGenerator { get; set; } +} diff --git a/src/Microsoft.SqlServer.VectorData/Throw.cs b/src/Microsoft.SqlServer.VectorData/Throw.cs new file mode 100644 index 0000000..9f3d6e8 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/Throw.cs @@ -0,0 +1,72 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.CompilerServices; + +#pragma warning disable CA1716 +#pragma warning disable CS8777 // Parameter must have a non-null value when exiting +namespace Microsoft.Shared.Diagnostics; +#pragma warning restore CA1716 + +/// +/// Defines static methods used to throw exceptions. +/// +[ExcludeFromCodeCoverage] +internal static partial class Throw +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static T IfNull([NotNull] T argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument is null) + { + ArgumentNullException(paramName); + } + + return argument; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + [return: NotNull] + public static string IfNullOrWhitespace([NotNull] string? 
argument, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (string.IsNullOrWhiteSpace(argument)) + { + if (argument == null) + { + ArgumentNullException(paramName); + } + else + { + ArgumentException(paramName, "Argument is whitespace"); + } + } + + return argument!; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static int IfLessThan(int argument, int min, [CallerArgumentExpression(nameof(argument))] string paramName = "") + { + if (argument < min) + { + ArgumentOutOfRangeException(paramName, argument, $"Argument less than minimum value {min}"); + } + + return argument; + } + + [DoesNotReturn] + public static void ArgumentNullException(string paramName) + => throw new ArgumentNullException(paramName); + + [DoesNotReturn] + public static void ArgumentOutOfRangeException(string paramName, object? actualValue, string? message) + => throw new ArgumentOutOfRangeException(paramName, actualValue, message); + + [DoesNotReturn] + public static void ArgumentException(string paramName, string? message) + => throw new ArgumentException(message, paramName); +} diff --git a/src/Microsoft.SqlServer.VectorData/VectorStoreErrorHandler.cs b/src/Microsoft.SqlServer.VectorData/VectorStoreErrorHandler.cs new file mode 100644 index 0000000..d31ff62 --- /dev/null +++ b/src/Microsoft.SqlServer.VectorData/VectorStoreErrorHandler.cs @@ -0,0 +1,256 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Collections.Generic; +using System.Data.Common; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Extensions.VectorData; + +#pragma warning disable MEVD9000 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. 
+ +/// +/// Contains helpers for running vector store operations and wrapping the exceptions they throw in a VectorStoreException. +/// +[ExcludeFromCodeCoverage] +internal static class VectorStoreErrorHandler +{ + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static Task RunOperationAsync( + VectorStoreMetadata metadata, + string operationName, + Func> operation) + where TException : Exception + { + return RunOperationAsync( + new VectorStoreCollectionMetadata() + { + CollectionName = null, + VectorStoreName = metadata.VectorStoreName, + VectorStoreSystemName = metadata.VectorStoreSystemName, + }, + operationName, + operation); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static async Task RunOperationAsync( + VectorStoreCollectionMetadata metadata, + string operationName, + Func> operation) + where TException : Exception + { + try + { + return await operation.Invoke().ConfigureAwait(false); + } + catch (AggregateException ex) when (ex.InnerException is TException innerEx) + { + throw new VectorStoreException("Call to vector store failed.", ex) + { + VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = metadata.CollectionName, + OperationName = operationName + }; + } + catch (TException ex) + { + throw new VectorStoreException("Call to vector store failed.", ex) + { + VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = metadata.CollectionName, + OperationName = operationName + }; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static TResult RunOperation( + VectorStoreMetadata metadata, + string operationName, + Func operation) + where TException : Exception + { + return RunOperation( + new VectorStoreCollectionMetadata() + { + CollectionName = null, + VectorStoreName = metadata.VectorStoreName, + VectorStoreSystemName = metadata.VectorStoreSystemName, + }, + operationName, + operation); + } + + 
[MethodImpl(MethodImplOptions.AggressiveInlining)] + public static TResult RunOperation( + VectorStoreCollectionMetadata metadata, + string operationName, + Func operation) + where TException : Exception + { + try + { + return operation.Invoke(); + } + catch (AggregateException ex) when (ex.InnerException is TException innerEx) + { + throw new VectorStoreException("Call to vector store failed.", ex) + { + VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = metadata.CollectionName, + OperationName = operationName + }; + } + catch (TException ex) + { + throw new VectorStoreException("Call to vector store failed.", ex) + { + VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = metadata.CollectionName, + OperationName = operationName + }; + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static async Task RunOperationAsync( + VectorStoreCollectionMetadata metadata, + string operationName, + Func operation) + where TException : Exception + { + try + { + await operation.Invoke().ConfigureAwait(false); + } + catch (AggregateException ex) when (ex.InnerException is TException innerEx) + { + throw new VectorStoreException("Call to vector store failed.", ex) + { + VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = metadata.CollectionName, + OperationName = operationName + }; + } + catch (TException ex) + { + throw new VectorStoreException("Call to vector store failed.", ex) + { + VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = metadata.CollectionName, + OperationName = operationName + }; + } + } + + internal static Task ReadWithErrorHandlingAsync( + this DbDataReader reader, + VectorStoreCollectionMetadata metadata, + string operationName, + CancellationToken cancellationToken) 
+ => VectorStoreErrorHandler.RunOperationAsync( + metadata, + operationName, + () => reader.ReadAsync(cancellationToken)); + + internal static Task ReadWithErrorHandlingAsync( + this DbDataReader reader, + VectorStoreMetadata metadata, + string operationName, + CancellationToken cancellationToken) + => VectorStoreErrorHandler.RunOperationAsync( + metadata, + operationName, + () => reader.ReadAsync(cancellationToken)); + + internal static async Task ExecuteWithErrorHandlingAsync( + this DbConnection connection, + VectorStoreMetadata metadata, + string operationName, + Func> operation, + CancellationToken cancellationToken) + { + return await ExecuteWithErrorHandlingAsync( + connection, + new VectorStoreCollectionMetadata + { + VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = null + }, + operationName, + operation, + cancellationToken).ConfigureAwait(false); + } + + internal static async Task ExecuteWithErrorHandlingAsync( + this DbConnection connection, + VectorStoreCollectionMetadata metadata, + string operationName, + Func> operation, + CancellationToken cancellationToken) + { + if (connection.State != System.Data.ConnectionState.Open) + { + await connection.OpenAsync(cancellationToken).ConfigureAwait(false); + } + + try + { + return await operation().ConfigureAwait(false); + } + catch (DbException ex) + { +#if NET + await connection.DisposeAsync().ConfigureAwait(false); +#else + connection.Dispose(); +#endif + + throw new VectorStoreException("Call to vector store failed.", ex) + { + VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = metadata.CollectionName, + OperationName = operationName + }; + } + catch (IOException ex) + { +#if NET + await connection.DisposeAsync().ConfigureAwait(false); +#else + connection.Dispose(); +#endif + + throw new VectorStoreException("Call to vector store failed.", ex) + { + 
VectorStoreSystemName = metadata.VectorStoreSystemName, + VectorStoreName = metadata.VectorStoreName, + CollectionName = metadata.CollectionName, + OperationName = operationName + }; + } + catch (Exception) + { +#if NET + await connection.DisposeAsync().ConfigureAwait(false); +#else + connection.Dispose(); +#endif + throw; + } + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerBasicModelTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerBasicModelTests.cs new file mode 100644 index 0000000..4118de0 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerBasicModelTests.cs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.VectorData; +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests.ModelTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests.ModelTests; + +public class SqlServerBasicModelTests(SqlServerBasicModelTests.Fixture fixture) + : BasicModelTests(fixture), IClassFixture +{ + private const int SqlServerMaxParameters = 2_100; + + [Fact] + private async Task Split_batches_to_account_for_max_parameter_limit() + { + var collection = fixture.Collection; + Record[] inserted = Enumerable.Range(0, SqlServerMaxParameters + 1).Select(i => new Record() + { + Key = fixture.GenerateNextKey(), + Number = 100 + i, + Text = i.ToString(), + Vector = Enumerable.Range(0, 3).Select(j => (float)(i + j)).ToArray() + }).ToArray(); + var keys = inserted.Select(record => record.Key).ToArray(); + + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); + await collection.UpsertAsync(inserted); + + var received = await collection.GetAsync(keys).ToArrayAsync(); + foreach (var record in inserted) + { + record.AssertEqual( + received.Single(r => r.Key.Equals(record.Key, StringComparison.Ordinal)), + includeVectors: false, + 
fixture.TestStore.VectorsComparable); + } + + await collection.DeleteAsync(keys); + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); + } + + [Fact] + public async Task Upsert_batch_is_atomic() + { + var collection = fixture.Collection; + Record[] inserted = Enumerable.Range(0, SqlServerMaxParameters + 1).Select(i => new Record() + { + // The last Key is set to NULL, so it must not be inserted and the whole batch should fail + Key = i < SqlServerMaxParameters ? fixture.GenerateNextKey() : null!, + Number = 100 + i, + Text = i.ToString(), + Vector = Enumerable.Range(0, 3).Select(j => (float)(i + j)).ToArray() + }).ToArray(); + + var keys = inserted.Select(record => record.Key).Where(key => key is not null).ToArray(); + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); + + VectorStoreException ex = await Assert.ThrowsAsync(() => collection.UpsertAsync(inserted)); + Assert.Equal("UpsertBatch", ex.OperationName); + + var metadata = collection.GetService(typeof(VectorStoreCollectionMetadata)) as VectorStoreCollectionMetadata; + + Assert.NotNull(metadata?.CollectionName); + Assert.Equal(metadata.CollectionName, ex.CollectionName); + + // Make sure that no records were inserted! + Assert.Empty(await collection.GetAsync(keys).ToArrayAsync()); + } + + public new class Fixture : BasicModelTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerDynamicModelTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerDynamicModelTests.cs new file mode 100644 index 0000000..038623e --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerDynamicModelTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests.ModelTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests.ModelTests; + +public class SqlServerDynamicModelTests(SqlServerDynamicModelTests.Fixture fixture) + : DynamicModelTests(fixture), IClassFixture +{ + public new class Fixture : DynamicModelTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerMultiVectorModelTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerMultiVectorModelTests.cs new file mode 100644 index 0000000..10b5ad8 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerMultiVectorModelTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests.ModelTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests.ModelTests; + +public class SqlServerMultiVectorModelTests(SqlServerMultiVectorModelTests.Fixture fixture) + : MultiVectorModelTests(fixture), IClassFixture +{ + public new class Fixture : MultiVectorModelTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerNoDataModelTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerNoDataModelTests.cs new file mode 100644 index 0000000..6938f26 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerNoDataModelTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests.ModelTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests.ModelTests; + +public class SqlServerNoDataModelTests(SqlServerNoDataModelTests.Fixture fixture) + : NoDataModelTests(fixture), IClassFixture +{ + public new class Fixture : NoDataModelTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerNoVectorModelTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerNoVectorModelTests.cs new file mode 100644 index 0000000..5d8bc1b --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/ModelTests/SqlServerNoVectorModelTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests.ModelTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests.ModelTests; + +public class SqlServerNoVectorModelTests(SqlServerNoVectorModelTests.Fixture fixture) + : NoVectorModelTests(fixture), IClassFixture +{ + public new class Fixture : NoVectorModelTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/Properties/AssemblyAttributes.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/Properties/AssemblyAttributes.cs new file mode 100644 index 0000000..cbb67c1 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/Properties/AssemblyAttributes.cs @@ -0,0 +1 @@ +// Copyright (c) Microsoft. All rights reserved. 
diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/README.md b/test/Microsoft.SqlServer.VectorData.ConformanceTests/README.md new file mode 100644 index 0000000..120d2f8 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/README.md @@ -0,0 +1,52 @@ +# SQL Server Vector Store Conformance Tests + +This project contains conformance tests for the SQL Server Vector Store implementation. + +## Running the Tests + +By default, the tests will automatically use a testcontainer to spin up a SQL Server instance. Docker must be running on your machine for this to work. + +### Using an External SQL Server Instance + +If you want to run the tests against an external SQL Server instance (e.g., Azure SQL, a local SQL Server, or any other instance), you can provide a connection string through one of the following methods: + +#### Option 1: Environment Variable + +Set the `SqlServer__ConnectionString` environment variable: + +```bash +# Bash/Linux/macOS +export SqlServer__ConnectionString="Server=myserver.database.windows.net;Database=mydb;User Id=myuser;Password=mypassword;" + +# PowerShell +$env:SqlServer__ConnectionString = "Server=myserver.database.windows.net;Database=mydb;User Id=myuser;Password=mypassword;" +``` + +#### Option 2: Configuration File + +Create a `testsettings.development.json` file in this directory with the following content: + +```json +{ + "SqlServer": { + "ConnectionString": "Server=myserver.database.windows.net;Database=mydb;User Id=myuser;Password=mypassword;" + } +} +``` + +This file is git-ignored and safe for local development. 
+ +#### Option 3: User Secrets + +```bash +cd test/Microsoft.SqlServer.VectorData.ConformanceTests +dotnet user-secrets set "SqlServer:ConnectionString" "Server=myserver.database.windows.net;Database=mydb;User Id=myuser;Password=mypassword;" +``` + +## Benefits of Using an External Instance + +Using an external SQL Server instance can be beneficial when: +- You want to avoid the overhead of spinning up Docker containers +- You need to test against Azure SQL specifically +- You want faster test execution (no container startup time) +- You're running tests in an environment where Docker is not available diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServer.ConformanceTests.csproj b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServer.ConformanceTests.csproj new file mode 100644 index 0000000..3af3590 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServer.ConformanceTests.csproj @@ -0,0 +1,38 @@ + + + + net10.0;net472 + enable + enable + Microsoft.SqlServer.VectorData.ConformanceTests + SqlServer.ConformanceTests + + false + true + + $(NoWarn);MEVD9000,MEVD9001 + true + b7762d10-e29b-4bb1-8b74-b6d69a667dd4 + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + + + + + + + + + diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerCollectionManagementTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerCollectionManagementTests.cs new file mode 100644 index 0000000..b1fa017 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerCollectionManagementTests.cs @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft. All rights reserved.
+ +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests; + +public class SqlServerCollectionManagementTests(SqlServerCollectionManagementTests.Fixture fixture) + : CollectionManagementTests(fixture), IClassFixture +{ + public class Fixture : VectorStoreFixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerCommandBuilderTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerCommandBuilderTests.cs new file mode 100644 index 0000000..da7a771 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerCommandBuilderTests.cs @@ -0,0 +1,743 @@ +// Copyright (c) Microsoft. All rights reserved. + +using System.Text; +using Microsoft.Data.SqlClient; +using Microsoft.Data.SqlTypes; +using Microsoft.Extensions.VectorData; +using Microsoft.Extensions.VectorData.ProviderServices; +using Microsoft.SqlServer.VectorData; +using Xunit; + +namespace SqlServer.ConformanceTests; + +public class SqlServerCommandBuilderTests +{ + [Theory] + [InlineData("schema", "name", "[schema].[name]")] + [InlineData(null, "name", "[name]")] + [InlineData("schema", "[brackets]", "[schema].[[brackets]]]")] + [InlineData(null, "[needsEscaping]", "[[needsEscaping]]]")] + [InlineData("needs]escaping", "[brackets]", "[needs]]escaping].[[brackets]]]")] + public void AppendTableName(string? 
schema, string table, string expectedFullName) + { + StringBuilder result = new(); + + SqlServerCommandBuilder.AppendTableName(result, schema, table); + + Assert.Equal(expectedFullName, result.ToString()); + } + + [Theory] + [InlineData("schema", "name", "[schema].[name]")] + [InlineData(null, "name", "[name]")] + [InlineData("schema", "it's", "[schema].[it''s]")] + [InlineData("it's", "name", "[it''s].[name]")] + [InlineData("it's", "it's", "[it''s].[it''s]")] + [InlineData(null, "it's", "[it''s]")] + [InlineData("schema", "[brackets]", "[schema].[[brackets]]]")] + public void AppendTableNameInsideLiteral(string? schema, string table, string expectedFullName) + { + StringBuilder result = new(); + + SqlServerCommandBuilder.AppendTableNameInsideLiteral(result, schema, table); + + Assert.Equal(expectedFullName, result.ToString()); + } + + [Theory] + [InlineData("name", "[name]")] + [InlineData("it's", "[it''s]")] + [InlineData("two''quotes", "[two''''quotes]")] + public void AppendIdentifierInsideLiteral(string identifier, string expected) + { + StringBuilder result = new(); + + SqlServerCommandBuilder.AppendIdentifierInsideLiteral(result, identifier); + + Assert.Equal(expected, result.ToString()); + } + + [Theory] + [InlineData("name", "@name_")] // typical name + [InlineData("na me", "@na_")] // contains a whitespace, an illegal parameter name character + [InlineData("123", "@_")] // starts with a digit, also not allowed + [InlineData("ĄŻŚĆ_doesNotStartWithAscii", "@_")] // starts with a non-ASCII character + public void AppendParameterName(string propertyName, string expectedPrefix) + { + StringBuilder builder = new(); + StringBuilder expectedBuilder = new(); + KeyPropertyModel keyProperty = new(propertyName, typeof(string)); + + int paramIndex = 0; // we need a dedicated variable to ensure that AppendParameterName increments the index + for (int i = 0; i < 10; i++) + { + Assert.Equal(paramIndex, i); + SqlServerCommandBuilder.AppendParameterName(builder, 
keyProperty, ref paramIndex, out string parameterName); + Assert.Equal($"{expectedPrefix}{i}", parameterName); + expectedBuilder.Append(parameterName); + } + + Assert.Equal(expectedBuilder.ToString(), builder.ToString()); + } + + [Theory] + [InlineData("schema", "simpleName", "[simpleName]")] + [InlineData("schema", "[needsEscaping]", "[[needsEscaping]]]")] + public void DropTable(string schema, string table, string expectedTable) + { + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.DropTableIfExists(connection, schema, table); + + Assert.Equal($"DROP TABLE IF EXISTS [{schema}].{expectedTable}", command.CommandText); + } + + [Theory] + [InlineData("schema", "simpleName")] + [InlineData("schema", "[needsEscaping]")] + public void SelectTableName(string schema, string table) + { + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectTableName(connection, schema, table); + + Assert.Equal( + """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND (@schema is NULL or TABLE_SCHEMA = @schema) + AND TABLE_NAME = @tableName + """ + , command.CommandText); + Assert.Equal(schema, command.Parameters[0].Value); + Assert.Equal(table, command.Parameters[1].Value); + } + + [Fact] + public void SelectTableNames() + { + const string SchemaName = "theSchemaName"; + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectTableNames(connection, SchemaName); + + Assert.Equal( + """ + SELECT TABLE_NAME + FROM INFORMATION_SCHEMA.TABLES + WHERE TABLE_TYPE = 'BASE TABLE' + AND (@schema is NULL or TABLE_SCHEMA = @schema) + """ + , command.CommandText); + Assert.Equal(SchemaName, command.Parameters[0].Value); + Assert.Equal("@schema", command.Parameters[0].ParameterName); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public void CreateTable(bool ifNotExists) + 
{ + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("simpleName", typeof(string)), + new VectorStoreDataProperty("with space", typeof(int)) { IsIndexed = true }, + new VectorStoreDataProperty("nullableInt", typeof(int?)), + new VectorStoreDataProperty("flag", typeof(bool)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 10), + new VectorStoreVectorProperty("nullableEmbedding", typeof(ReadOnlyMemory?), 10) + ]); + + using SqlConnection connection = CreateConnection(); + + var commands = SqlServerCommandBuilder.CreateTable(connection, "schema", "table", ifNotExists, model); + + var command = Assert.Single(commands); + string expectedCommand = + """ + BEGIN + CREATE TABLE [schema].[table] ( + [id] BIGINT IDENTITY, + [simpleName] NVARCHAR(MAX), + [with space] INT NOT NULL, + [nullableInt] INT, + [flag] BIT NOT NULL, + [embedding] VECTOR(10) NOT NULL, + [nullableEmbedding] VECTOR(10), + PRIMARY KEY ([id]) + ); + CREATE INDEX index_table_withspace ON [schema].[table]([with space]); + END; + """; + if (ifNotExists) + { + expectedCommand = "IF OBJECT_ID(N'[schema].[table]', N'U') IS NULL" + Environment.NewLine + expectedCommand; + } + + Assert.Equal(expectedCommand, command.CommandText, ignoreLineEndingDifferences: true); + } + + [Fact] + public void CreateTable_WithSingleQuoteInName() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + ]); + + using SqlConnection connection = CreateConnection(); + + var commands = SqlServerCommandBuilder.CreateTable(connection, "it's", "ta'ble", ifNotExists: true, model); + + var command = Assert.Single(commands); + Assert.Equal( + "IF OBJECT_ID(N'[it''s].[ta''ble]', N'U') IS NULL" + Environment.NewLine + + """ + BEGIN + CREATE TABLE [it's].[ta'ble] ( + [id] BIGINT IDENTITY, + [name] NVARCHAR(MAX), + PRIMARY KEY ([id]) + ); + END; + """, + command.CommandText, 
ignoreLineEndingDifferences: true); + } + + [Fact] + public void CreateTable_WithDiskAnnIndex() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + { + IndexKind = IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.CosineDistance + } + ]); + + using SqlConnection connection = CreateConnection(); + + var commands = SqlServerCommandBuilder.CreateTable(connection, "schema", "table", ifNotExists: false, model); + + Assert.Equal(3, commands.Count); + Assert.Equal( + """ + BEGIN + CREATE TABLE [schema].[table] ( + [id] BIGINT IDENTITY, + [name] NVARCHAR(MAX), + [embedding] VECTOR(10) NOT NULL, + PRIMARY KEY ([id]) + ); + END; + """, commands[0].CommandText, ignoreLineEndingDifferences: true); + Assert.Equal("ALTER DATABASE SCOPED CONFIGURATION SET PREVIEW_FEATURES = ON;", commands[1].CommandText); + Assert.Equal( + """ + CREATE VECTOR INDEX index_table_embedding ON [schema].[table]([embedding]) WITH (METRIC = 'COSINE', TYPE = 'DISKANN'); + + """, commands[2].CommandText, ignoreLineEndingDifferences: true); + } + + [Fact] + public void CreateTable_WithDiskAnnIndex_EuclideanDistance() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + { + IndexKind = IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.EuclideanDistance + } + ]); + + using SqlConnection connection = CreateConnection(); + + var commands = SqlServerCommandBuilder.CreateTable(connection, "schema", "table", ifNotExists: false, model); + + Assert.Equal(3, commands.Count); + Assert.Equal( + """ + BEGIN + CREATE TABLE [schema].[table] ( + [id] BIGINT IDENTITY, + [embedding] VECTOR(10) NOT NULL, + PRIMARY KEY ([id]) + ); + END; + """, commands[0].CommandText, ignoreLineEndingDifferences: true); + Assert.Equal("ALTER DATABASE SCOPED 
CONFIGURATION SET PREVIEW_FEATURES = ON;", commands[1].CommandText); + Assert.Equal( + """ + CREATE VECTOR INDEX index_table_embedding ON [schema].[table]([embedding]) WITH (METRIC = 'EUCLIDEAN', TYPE = 'DISKANN'); + + """, commands[2].CommandText, ignoreLineEndingDifferences: true); + } + + [Fact] + public void CreateTable_WithUnsupportedIndexKind_Throws() + { + Assert.Throws(() => + BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + { + IndexKind = IndexKind.Hnsw + } + ])); + } + + [Fact] + public void SelectVector_WithDiskAnnIndex() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 3) + { + IndexKind = IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.CosineDistance + } + ]); + + using SqlConnection connection = CreateConnection(); + + var options = new VectorSearchOptions> { IncludeVectors = true }; + using SqlCommand command = SqlServerCommandBuilder.SelectVector( + connection, "schema", "table", + model.VectorProperties[0], model, + top: 5, options, + new SqlVector(new float[] { 1f, 2f, 3f })); + + Assert.Equal( + """ + SELECT TOP(5) WITH APPROXIMATE t.[id],t.[name],t.[embedding], + s.[distance] AS [score] + FROM VECTOR_SEARCH(TABLE = [schema].[table] AS t, COLUMN = [embedding], SIMILAR_TO = @vector, METRIC = 'COSINE') AS s + ORDER BY [score] ASC + """, command.CommandText, ignoreLineEndingDifferences: true); + } + + [Fact] + public void SelectVector_WithDiskAnnIndex_WithSkip() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 3) + { + IndexKind = IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.CosineDistance + } + ]); + + using SqlConnection 
connection = CreateConnection(); + + var options = new VectorSearchOptions> { IncludeVectors = false, Skip = 3 }; + using SqlCommand command = SqlServerCommandBuilder.SelectVector( + connection, "schema", "table", + model.VectorProperties[0], model, + top: 5, options, + new SqlVector(new float[] { 1f, 2f, 3f })); + + Assert.Equal( + """ + SELECT * FROM (SELECT TOP(8) WITH APPROXIMATE t.[id],t.[name], + s.[distance] AS [score] + FROM VECTOR_SEARCH(TABLE = [schema].[table] AS t, COLUMN = [embedding], SIMILAR_TO = @vector, METRIC = 'COSINE') AS s + ORDER BY [score] ASC + ) AS [inner] + ORDER BY [score] ASC + OFFSET 3 ROWS FETCH NEXT 5 ROWS ONLY; + """, command.CommandText, ignoreLineEndingDifferences: true); + } + + [Fact] + public void SelectVector_WithDiskAnnIndex_WithFilter() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 3) + { + IndexKind = IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.CosineDistance + } + ]); + + using SqlConnection connection = CreateConnection(); + + var options = new VectorSearchOptions> + { + Filter = d => (string)d["name"]! 
== "test" + }; + + using SqlCommand command = SqlServerCommandBuilder.SelectVector( + connection, "schema", "table", + model.VectorProperties[0], model, + top: 5, options, + new SqlVector(new float[] { 1f, 2f, 3f })); + + Assert.Equal( + """ + SELECT TOP(5) WITH APPROXIMATE t.[id],t.[name], + s.[distance] AS [score] + FROM VECTOR_SEARCH(TABLE = [schema].[table] AS t, COLUMN = [embedding], SIMILAR_TO = @vector, METRIC = 'COSINE') AS s + WHERE (t.[name] = 'test') + ORDER BY [score] ASC + """, command.CommandText, ignoreLineEndingDifferences: true); + } + + [Fact] + public void SelectVector_WithDiskAnnIndex_WithScoreThreshold() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 3) + { + IndexKind = IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.CosineDistance + } + ]); + + using SqlConnection connection = CreateConnection(); + + var options = new VectorSearchOptions> + { + IncludeVectors = true, + ScoreThreshold = 0.5f + }; + using SqlCommand command = SqlServerCommandBuilder.SelectVector( + connection, "schema", "table", + model.VectorProperties[0], model, + top: 5, options, + new SqlVector(new float[] { 1f, 2f, 3f })); + + Assert.Equal( + """ + SELECT TOP(5) WITH APPROXIMATE t.[id],t.[name],t.[embedding], + s.[distance] AS [score] + FROM VECTOR_SEARCH(TABLE = [schema].[table] AS t, COLUMN = [embedding], SIMILAR_TO = @vector, METRIC = 'COSINE') AS s + WHERE s.[distance] <= @scoreThreshold + ORDER BY [score] ASC + """, command.CommandText, ignoreLineEndingDifferences: true); + } + + [Fact] + public void SelectVector_WithDiskAnnIndex_WithFilterAndScoreThreshold() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 3) + { + IndexKind = 
IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.CosineDistance + } + ]); + + using SqlConnection connection = CreateConnection(); + + var options = new VectorSearchOptions> + { + Filter = d => (string)d["name"]! == "test", + ScoreThreshold = 0.5f + }; + + using SqlCommand command = SqlServerCommandBuilder.SelectVector( + connection, "schema", "table", + model.VectorProperties[0], model, + top: 5, options, + new SqlVector(new float[] { 1f, 2f, 3f })); + + Assert.Equal( + """ + SELECT TOP(5) WITH APPROXIMATE t.[id],t.[name], + s.[distance] AS [score] + FROM VECTOR_SEARCH(TABLE = [schema].[table] AS t, COLUMN = [embedding], SIMILAR_TO = @vector, METRIC = 'COSINE') AS s + WHERE (t.[name] = 'test') + AND s.[distance] <= @scoreThreshold + ORDER BY [score] ASC + """, command.CommandText, ignoreLineEndingDifferences: true); + } + + [Fact] + public void SelectHybrid_WithDiskAnnIndex_WithFilter() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)) { IsFullTextIndexed = true }, + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 3) + { + IndexKind = IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.CosineDistance + } + ]); + + using SqlConnection connection = CreateConnection(); + + var options = new HybridSearchOptions> + { + Filter = d => (string)d["name"]! 
== "test" + }; + + using SqlCommand command = SqlServerCommandBuilder.SelectHybrid( + connection, "schema", "table", + model.VectorProperties[0], model.DataProperties.First(p => p.IsFullTextIndexed), model, + top: 5, options, + new SqlVector(new float[] { 1f, 2f, 3f }), + "keyword"); + + Assert.Contains("SELECT TOP(@candidateCount) WITH APPROXIMATE", command.CommandText); + Assert.Contains("WHERE (t.[name] = 'test')", command.CommandText); + Assert.Contains("VECTOR_SEARCH(TABLE =", command.CommandText); + } + + [Fact] + public void Upsert() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)) { IsAutoGenerated = false }, + new VectorStoreDataProperty("simpleString", typeof(string)), + new VectorStoreDataProperty("simpleInt", typeof(int)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + ]); + + Dictionary[] records = + [ + new Dictionary + { + { "id", 0L }, + { "simpleString", "nameValue0" }, + { "simpleInt", 134 }, + { "embedding", new ReadOnlyMemory([10.0f]) } + }, + new Dictionary + { + { "id", 1L }, + { "simpleString", "nameValue1" }, + { "simpleInt", 135 }, + { "embedding", new ReadOnlyMemory([11.0f]) } + } + ]; + + using SqlConnection connection = CreateConnection(); + using SqlCommand command = connection.CreateCommand(); + + Assert.True(SqlServerCommandBuilder.Upsert(command, "schema", "table", model, records, firstRecordIndex: 0, generatedEmbeddings: null)); + + string expectedCommand = + """" + MERGE INTO [schema].[table] AS t + USING (VALUES (@id_0,@simpleString_1,@simpleInt_2,@embedding_3)) AS s ([id],[simpleString],[simpleInt],[embedding]) + ON (t.[id] = s.[id]) + WHEN MATCHED THEN + UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] + WHEN NOT MATCHED THEN + INSERT ([id],[simpleString],[simpleInt],[embedding]) + VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding]) + OUTPUT inserted.[id]; + + MERGE INTO [schema].[table] AS t + USING 
(VALUES (@id_4,@simpleString_5,@simpleInt_6,@embedding_7)) AS s ([id],[simpleString],[simpleInt],[embedding]) + ON (t.[id] = s.[id]) + WHEN MATCHED THEN + UPDATE SET t.[simpleString] = s.[simpleString],t.[simpleInt] = s.[simpleInt],t.[embedding] = s.[embedding] + WHEN NOT MATCHED THEN + INSERT ([id],[simpleString],[simpleInt],[embedding]) + VALUES (s.[id],s.[simpleString],s.[simpleInt],s.[embedding]) + OUTPUT inserted.[id]; + + + """"; + + Assert.Equal(expectedCommand, command.CommandText, ignoreLineEndingDifferences: true); + + for (int i = 0; i < records.Length; i++) + { + Assert.Equal($"@id_{4 * i + 0}", command.Parameters[4 * i + 0].ParameterName); + Assert.Equal((long)i, command.Parameters[4 * i + 0].Value); + Assert.Equal($"@simpleString_{4 * i + 1}", command.Parameters[4 * i + 1].ParameterName); + Assert.Equal($"nameValue{i}", command.Parameters[4 * i + 1].Value); + Assert.Equal($"@simpleInt_{4 * i + 2}", command.Parameters[4 * i + 2].ParameterName); + Assert.Equal(134 + i, command.Parameters[4 * i + 2].Value); + Assert.Equal($"@embedding_{4 * i + 3}", command.Parameters[4 * i + 3].ParameterName); + var vector = Assert.IsType>(command.Parameters[4 * i + 3].Value); + Assert.Equal([10 + i], vector.Memory.ToArray()); + } + } + + [Fact] + public void DeleteSingle() + { + KeyPropertyModel keyProperty = new("id", typeof(long)); + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.DeleteSingle(connection, + "schema", "tableName", keyProperty, 123L); + + Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] = @id_0", command.CommandText); + Assert.Equal(123L, command.Parameters[0].Value); + Assert.Equal("@id_0", command.Parameters[0].ParameterName); + } + + [Fact] + public void DeleteMany() + { + string[] keys = ["key1", "key2"]; + KeyPropertyModel keyProperty = new("id", typeof(string)); + using SqlConnection connection = CreateConnection(); + using SqlCommand command = connection.CreateCommand(); + + 
Assert.True(SqlServerCommandBuilder.DeleteMany(command, "schema", "tableName", keyProperty, keys)); + + Assert.Equal("DELETE FROM [schema].[tableName] WHERE [id] IN (@id_0,@id_1)", command.CommandText); + for (int i = 0; i < keys.Length; i++) + { + Assert.Equal(keys[i], command.Parameters[i].Value); + Assert.Equal($"@id_{i}", command.Parameters[i].ParameterName); + } + } + + [Fact] + public void SelectSingle() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreDataProperty("age", typeof(int)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + ]); + + using SqlConnection connection = CreateConnection(); + + using SqlCommand command = SqlServerCommandBuilder.SelectSingle(connection, "schema", "tableName", model, 123L, includeVectors: true); + + Assert.Equal( + """"" + SELECT [id],[name],[age],[embedding] + FROM [schema].[tableName] + WHERE [id] = @id_0 + """"", command.CommandText, ignoreLineEndingDifferences: true); + Assert.Equal(123L, command.Parameters[0].Value); + Assert.Equal("@id_0", command.Parameters[0].ParameterName); + } + + [Fact] + public void SelectMany() + { + var model = BuildModel( + [ + new VectorStoreKeyProperty("id", typeof(long)), + new VectorStoreDataProperty("name", typeof(string)), + new VectorStoreDataProperty("age", typeof(int)), + new VectorStoreVectorProperty("embedding", typeof(ReadOnlyMemory), 10) + ]); + + long[] keys = [123L, 456L, 789L]; + using SqlConnection connection = CreateConnection(); + using SqlCommand command = connection.CreateCommand(); + + Assert.True(SqlServerCommandBuilder.SelectMany(command, + "schema", "tableName", model, keys, includeVectors: true)); + + Assert.Equal( + """"" + SELECT [id],[name],[age],[embedding] + FROM [schema].[tableName] + WHERE [id] IN (@id_0,@id_1,@id_2) + """"", command.CommandText, ignoreLineEndingDifferences: true); + for (int i = 0; i < keys.Length; i++) + { + 
Assert.Equal(keys[i], command.Parameters[i].Value); + Assert.Equal($"@id_{i}", command.Parameters[i].ParameterName); + } + } + + // We create a connection using a fake connection string just to be able to create the SqlCommand. + private static SqlConnection CreateConnection() + => new("Server=localhost;Database=master;Integrated Security=True;"); + + private static CollectionModel BuildModel(List properties) + => new SqlServerModelBuilder() + .BuildDynamic(new() { Properties = properties }, defaultEmbeddingGenerator: null); + +#if NET // NRT detection via NullabilityInfoContext is only available on .NET 6+ + [Fact] + public void CreateTable_WithNrtAnnotations() + { + var model = new SqlServerModelBuilder().Build( + typeof(NrtTestRecord), + typeof(long), + definition: null, + defaultEmbeddingGenerator: null); + + using SqlConnection connection = CreateConnection(); + + var commands = SqlServerCommandBuilder.CreateTable(connection, "schema", "table", ifNotExists: false, model); + + var command = Assert.Single(commands); + + Assert.Equal( + """ + BEGIN + CREATE TABLE [schema].[table] ( + [Id] BIGINT IDENTITY, + [NonNullableString] NVARCHAR(MAX) NOT NULL, + [NullableString] NVARCHAR(MAX), + [NonNullableInt] INT NOT NULL, + [NullableInt] INT, + [NonNullableBool] BIT NOT NULL, + [Embedding] VECTOR(10) NOT NULL, + PRIMARY KEY ([Id]) + ); + END; + """, command.CommandText, ignoreLineEndingDifferences: true); + } + +#pragma warning disable CS8618 // Non-nullable field must contain a non-null value when exiting constructor +#pragma warning disable CA1812 // Class is used via reflection + private sealed class NrtTestRecord + { + [VectorStoreKey] + public long Id { get; set; } + + [VectorStoreData] + public string NonNullableString { get; set; } + + [VectorStoreData] + public string? NullableString { get; set; } + + [VectorStoreData] + public int NonNullableInt { get; set; } + + [VectorStoreData] + public int? 
NullableInt { get; set; } + + [VectorStoreData] + public bool NonNullableBool { get; set; } + + [VectorStoreVector(10)] + public ReadOnlyMemory Embedding { get; set; } + } +#pragma warning restore CA1812 +#pragma warning restore CS8618 +#endif +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerDependencyInjectionTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerDependencyInjectionTests.cs new file mode 100644 index 0000000..e689589 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerDependencyInjectionTests.cs @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.AI; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.SqlServer.VectorData; +using VectorData.ConformanceTests; +using Xunit; + +namespace SqlServer.ConformanceTests; + +public class SqlServerDependencyInjectionTests + : DependencyInjectionTests.Record>, string, DependencyInjectionTests.Record> +{ + protected const string ConnectionString = "Server=localhost;Database=master;Integrated Security=True;"; + + protected override void PopulateConfiguration(ConfigurationManager configuration, object? serviceKey = null) + => configuration.AddInMemoryCollection( + [ + new(CreateConfigKey("SqlServer", serviceKey, "ConnectionString"), ConnectionString), + ]); + + private static string ConnectionStringProvider(IServiceProvider sp) + => sp.GetRequiredService().GetRequiredSection("SqlServer:ConnectionString").Value!; + + private static string ConnectionStringProvider(IServiceProvider sp, object serviceKey) + => sp.GetRequiredService().GetRequiredSection(CreateConfigKey("SqlServer", serviceKey, "ConnectionString")).Value!; + + public override IEnumerable> CollectionDelegates + { + get + { + yield return (services, serviceKey, name, lifetime) => serviceKey is null + ? 
services.AddSqlServerCollection( + name, connectionString: ConnectionString, lifetime: lifetime) + : services.AddKeyedSqlServerCollection( + serviceKey, name, connectionString: ConnectionString, lifetime: lifetime); + + yield return (services, serviceKey, name, lifetime) => serviceKey is null + ? services.AddSqlServerCollection( + name, ConnectionStringProvider, lifetime: lifetime) + : services.AddKeyedSqlServerCollection( + serviceKey, name, sp => ConnectionStringProvider(sp, serviceKey), lifetime: lifetime); + } + } + + public override IEnumerable> StoreDelegates + { + get + { + yield return (services, serviceKey, lifetime) => serviceKey is null + ? services.AddSqlServerVectorStore( + ConnectionStringProvider, lifetime: lifetime) + : services.AddKeyedSqlServerVectorStore( + serviceKey, sp => ConnectionStringProvider(sp, serviceKey), lifetime: lifetime); + } + } + + [Fact] + public void ConnectionStringProviderCantBeNull() + { + IServiceCollection services = new ServiceCollection(); + + Assert.Throws(() => services.AddSqlServerVectorStore(connectionStringProvider: null!)); + Assert.Throws(() => services.AddKeyedSqlServerVectorStore(serviceKey: "notNull", connectionStringProvider: null!)); + Assert.Throws(() => services.AddSqlServerCollection(name: "notNull", connectionStringProvider: null!)); + Assert.Throws(() => services.AddKeyedSqlServerCollection(serviceKey: "notNull", name: "notNull", connectionStringProvider: null!)); + } + + [Fact] + public void ConnectionStringCantBeNullOrEmpty() + { + IServiceCollection services = new ServiceCollection(); + + Assert.Throws(() => services.AddSqlServerCollection( + name: "notNull", connectionString: null!)); + Assert.Throws(() => services.AddSqlServerCollection( + name: "notNull", connectionString: "")); + Assert.Throws(() => services.AddKeyedSqlServerCollection( + serviceKey: "notNull", name: "notNull", connectionString: null!)); + Assert.Throws(() => services.AddKeyedSqlServerCollection( + serviceKey: "notNull", name: 
"notNull", connectionString: "")); + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerDistanceFunctionTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerDistanceFunctionTests.cs new file mode 100644 index 0000000..354906a --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerDistanceFunctionTests.cs @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests; + +public class SqlServerDistanceFunctionTests(SqlServerDistanceFunctionTests.Fixture fixture) + : DistanceFunctionTests(fixture), IClassFixture +{ + public override Task CosineSimilarity() => Assert.ThrowsAsync(base.CosineSimilarity); + public override Task DotProductSimilarity() => Assert.ThrowsAsync(base.DotProductSimilarity); + public override Task EuclideanSquaredDistance() => Assert.ThrowsAsync(base.EuclideanSquaredDistance); + public override Task HammingDistance() => Assert.ThrowsAsync(base.HammingDistance); + public override Task ManhattanDistance() => Assert.ThrowsAsync(base.ManhattanDistance); + + public new class Fixture() : DistanceFunctionTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerEmbeddingGenerationTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerEmbeddingGenerationTests.cs new file mode 100644 index 0000000..8cf414b --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerEmbeddingGenerationTests.cs @@ -0,0 +1,53 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.VectorData; +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests; + +public class SqlServerEmbeddingGenerationTests(SqlServerEmbeddingGenerationTests.StringVectorFixture stringVectorFixture, SqlServerEmbeddingGenerationTests.RomOfFloatVectorFixture romOfFloatVectorFixture) + : EmbeddingGenerationTests(stringVectorFixture, romOfFloatVectorFixture), IClassFixture, IClassFixture +{ + public new class StringVectorFixture : EmbeddingGenerationTests.StringVectorFixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + + public override VectorStore CreateVectorStore(IEmbeddingGenerator? embeddingGenerator) + => SqlServerTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services.AddSqlServerVectorStore(sp => SqlServerTestStore.Instance.ConnectionString) + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services.AddSqlServerCollection(this.CollectionName, sp => SqlServerTestStore.Instance.ConnectionString), + services => services.AddSqlServerCollection(this.CollectionName, SqlServerTestStore.Instance.ConnectionString), + ]; + } + + public new class RomOfFloatVectorFixture : EmbeddingGenerationTests.RomOfFloatVectorFixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + + public override VectorStore CreateVectorStore(IEmbeddingGenerator? 
embeddingGenerator) + => SqlServerTestStore.Instance.GetVectorStore(new() { EmbeddingGenerator = embeddingGenerator }); + + public override Func[] DependencyInjectionStoreRegistrationDelegates => + [ + services => services.AddSqlServerVectorStore(sp => SqlServerTestStore.Instance.ConnectionString) + ]; + + public override Func[] DependencyInjectionCollectionRegistrationDelegates => + [ + services => services.AddSqlServerCollection(this.CollectionName, sp => SqlServerTestStore.Instance.ConnectionString), + services => services.AddSqlServerCollection(this.CollectionName, SqlServerTestStore.Instance.ConnectionString), + ]; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerFilterTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerFilterTests.cs new file mode 100644 index 0000000..8be860c --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerFilterTests.cs @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft. All rights reserved. + +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; +using Xunit.Sdk; + +namespace SqlServer.ConformanceTests; + +#pragma warning disable CS0252 // Possible unintended reference comparison; left hand side needs cast + +public class SqlServerFilterTests(SqlServerFilterTests.Fixture fixture) + : FilterTests(fixture), IClassFixture +{ + public override async Task Not_over_Or() + { + // Test sends: WHERE (NOT (("Int" = 8) OR ("String" = 'foo'))) + // There's a NULL string in the database, and relational null semantics in conjunction with negation makes the default implementation fail. + await Assert.ThrowsAsync(() => base.Not_over_Or()); + + // Compensate by adding a null check: + await this.TestFilterAsync( + r => r.String != null && !(r.Int == 8 || r.String == "foo"), + r => r["String"] != null && !((int)r["Int"]! 
== 8 || r["String"] == "foo")); + } + + public override async Task NotEqual_with_string() + { + // As above, null semantics + negation + await Assert.ThrowsAsync(() => base.NotEqual_with_string()); + + await this.TestFilterAsync( + r => r.String != null && r.String != "foo", + r => r["String"] != null && r["String"] != "foo"); + } + + public new class Fixture : FilterTests.Fixture + { + private static readonly string s_uniqueName = Guid.NewGuid().ToString(); + + public override TestStore TestStore => SqlServerTestStore.Instance; + + protected override string CollectionNameBase => s_uniqueName; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerHybridSearchTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerHybridSearchTests.cs new file mode 100644 index 0000000..d12da12 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerHybridSearchTests.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests; + +public class SqlServerHybridSearchTests( + SqlServerHybridSearchTests.VectorAndStringFixture vectorAndStringFixture, + SqlServerHybridSearchTests.MultiTextFixture multiTextFixture) + : HybridSearchTests(vectorAndStringFixture, multiTextFixture), + IClassFixture, + IClassFixture +{ + public new class VectorAndStringFixture : HybridSearchTests.VectorAndStringFixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } + + public new class MultiTextFixture : HybridSearchTests.MultiTextFixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerIndexKindTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerIndexKindTests.cs new file mode 100644 index 0000000..5127333 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerIndexKindTests.cs @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests; +using VectorData.ConformanceTests.Support; +using Xunit; + +namespace SqlServer.ConformanceTests; + +public class SqlServerIndexKindTests(SqlServerIndexKindTests.Fixture fixture) + : IndexKindTests(fixture), IClassFixture +{ + // Latest version vector indexes are only available in Azure SQL, not in on-prem SQL Server. + // They also require at least 100 rows before the vector index can be created, + // so we override the test to insert data first, then create the index. 
+ [Fact] + public virtual async Task DiskAnn() + { + await AzureSqlHelper.EnsureAzureSqlAsync(); + const string CollectionName = "IndexKindTests_DiskAnn"; + + // Step 1: Create the table using Flat index (no vector index) so we can insert data. + VectorStoreCollectionDefinition flatDefinition = new() + { + Properties = + [ + new VectorStoreKeyProperty(nameof(SearchRecord.Key), typeof(int)), + new VectorStoreDataProperty(nameof(SearchRecord.Int), typeof(int)), + new VectorStoreVectorProperty(nameof(SearchRecord.Vector), typeof(ReadOnlyMemory), dimensions: 3) + { + IndexKind = IndexKind.Flat, + DistanceFunction = DistanceFunction.CosineDistance + } + ] + }; + + using var flatCollection = fixture.TestStore.CreateCollection(CollectionName, flatDefinition); + await flatCollection.EnsureCollectionDeletedAsync(); + await flatCollection.EnsureCollectionExistsAsync(); + + try + { + // Step 2: Insert the 3 test rows + 97 filler rows to meet the 100-row minimum. + SearchRecord[] testRecords = + [ + new() { Key = 1, Int = 1, Vector = new([1, 2, 3]) }, + new() { Key = 2, Int = 2, Vector = new([10, 30, 50]) }, + new() { Key = 3, Int = 3, Vector = new([100, 40, 70]) } + ]; + + await flatCollection.UpsertAsync(testRecords); + + var fillerRecords = Enumerable.Range(100, 97) + .Select(i => new SearchRecord + { + Key = i, + Int = i, + Vector = new([i * 0.1f, i * 0.2f, i * 0.3f]) + }) + .ToArray(); + + await flatCollection.UpsertAsync(fillerRecords); + + // Step 3: Create the DiskANN vector index via raw SQL now that data is in the table. 
+ using var connection = new SqlConnection(SqlServerTestStore.Instance.ConnectionString); + await connection.OpenAsync(); + + using (var createIndex = new SqlCommand( + $"CREATE VECTOR INDEX index_{CollectionName}_Vector ON [{CollectionName}]([Vector]) WITH (METRIC = 'COSINE', TYPE = 'DISKANN');", + connection)) + { + await createIndex.ExecuteNonQueryAsync(); + } + + // Step 4: Create a new collection instance with DiskAnn to route searches through VECTOR_SEARCH(). + VectorStoreCollectionDefinition diskAnnDefinition = new() + { + Properties = + [ + new VectorStoreKeyProperty(nameof(SearchRecord.Key), typeof(int)), + new VectorStoreDataProperty(nameof(SearchRecord.Int), typeof(int)), + new VectorStoreVectorProperty(nameof(SearchRecord.Vector), typeof(ReadOnlyMemory), dimensions: 3) + { + IndexKind = IndexKind.DiskAnn, + DistanceFunction = DistanceFunction.CosineDistance + } + ] + }; + + using var diskAnnCollection = fixture.TestStore.CreateCollection(CollectionName, diskAnnDefinition); + + var result = await diskAnnCollection.SearchAsync(new ReadOnlyMemory([10, 30, 50]), top: 1).SingleAsync(); + + Assert.NotNull(result); + Assert.Equal(2, result.Record.Int); + } + finally + { + await flatCollection.EnsureCollectionDeletedAsync(); + } + } + + public new class Fixture() : IndexKindTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerTestSuiteImplementationTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerTestSuiteImplementationTests.cs new file mode 100644 index 0000000..d51296c --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/SqlServerTestSuiteImplementationTests.cs @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using VectorData.ConformanceTests; + +namespace SqlServer.ConformanceTests; + +public class SqlServerTestSuiteImplementationTests : TestSuiteImplementationTests +{ + protected override ICollection IgnoredTestBases { get; } = + [ + // Hybrid search not supported + typeof(HybridSearchTests<>) + ]; +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/AzureSqlRequiredAttribute.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/AzureSqlRequiredAttribute.cs new file mode 100644 index 0000000..bf81a1c --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/AzureSqlRequiredAttribute.cs @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Data.SqlClient; +using Xunit; + +namespace SqlServer.ConformanceTests.Support; + +/// +/// Helpers for tests that require Azure SQL Database or SQL database in Microsoft Fabric. +/// +internal static class AzureSqlHelper +{ + private static bool? s_isAzureSql; + + public static async Task EnsureAzureSqlAsync() + { + if (s_isAzureSql is not null) + { + Assert.True(s_isAzureSql.Value, "This test requires Azure SQL Database or SQL database in Microsoft Fabric."); + return; + } + + var connectionString = SqlServerTestStore.Instance.ConnectionString; + + using var connection = new SqlConnection(connectionString); + await connection.OpenAsync(); + + using var command = connection.CreateCommand(); + command.CommandText = "SELECT SERVERPROPERTY('EngineEdition')"; + var result = await command.ExecuteScalarAsync(); + var engineEdition = Convert.ToInt32(result); + + // 5 = Azure SQL Database, 11 = SQL database in Microsoft Fabric + s_isAzureSql = engineEdition is 5 or 11; + + Assert.True(s_isAzureSql.Value, "This test requires Azure SQL Database or SQL database in Microsoft Fabric."); + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/SqlServerTestEnvironment.cs 
b/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/SqlServerTestEnvironment.cs new file mode 100644 index 0000000..cece3e9 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/SqlServerTestEnvironment.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. + +using Microsoft.Extensions.Configuration; + +namespace SqlServer.ConformanceTests.Support; + +#pragma warning disable CA1810 // Initialize all static fields when those fields are declared + +internal static class SqlServerTestEnvironment +{ + public static readonly string? ConnectionString; + + public static bool IsConnectionStringDefined => ConnectionString is not null; + + static SqlServerTestEnvironment() + { + var configuration = new ConfigurationBuilder() + .AddJsonFile(path: "testsettings.json", optional: true) + .AddJsonFile(path: "testsettings.development.json", optional: true) + .AddEnvironmentVariables() + .AddUserSecrets() + .Build(); + + var sqlServerSection = configuration.GetSection("SqlServer"); + ConnectionString = sqlServerSection["ConnectionString"]; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/SqlServerTestStore.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/SqlServerTestStore.cs new file mode 100644 index 0000000..85be664 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/Support/SqlServerTestStore.cs @@ -0,0 +1,108 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using System.Linq.Expressions; +using Microsoft.Data.SqlClient; +using Microsoft.Extensions.VectorData; +using Microsoft.SqlServer.VectorData; +using Testcontainers.MsSql; +using VectorData.ConformanceTests.Support; + +namespace SqlServer.ConformanceTests.Support; + +#pragma warning disable CA1001 // Type owns disposable fields but is not disposable + +internal sealed class SqlServerTestStore : TestStore +{ + public static SqlServerTestStore Instance { get; } = new(); + + private static readonly MsSqlContainer s_container = new MsSqlBuilder("mcr.microsoft.com/mssql/server:2025-latest") + .Build(); + + private string? _connectionString; + private bool _useExternalInstance; + + public string ConnectionString => this._connectionString ?? throw new InvalidOperationException("Not initialized"); + + public SqlServerVectorStore GetVectorStore(SqlServerVectorStoreOptions options) + => new(this.ConnectionString, options); + + public override string DefaultDistanceFunction => DistanceFunction.CosineDistance; + + private SqlServerTestStore() + { + } + + protected override async Task StartAsync() + { + if (SqlServerTestEnvironment.IsConnectionStringDefined) + { + this._connectionString = SqlServerTestEnvironment.ConnectionString!; + this._useExternalInstance = true; + } + else + { + // Use testcontainer if no external connection string is provided + await s_container.StartAsync(); + this._connectionString = s_container.GetConnectionString(); + this._useExternalInstance = false; + } + + this.DefaultVectorStore = new SqlServerVectorStore(this._connectionString); + } + + protected override async Task StopAsync() + { + // Only stop the container if we started it + if (!this._useExternalInstance) + { + await s_container.StopAsync(); + } + } + + public override async Task WaitForDataAsync( + VectorStoreCollection collection, + int recordCount, + Expression>? filter = null, + Expression>? vectorProperty = null, + int? vectorSize = null, + object? 
dummyVector = null) + { + // First wait for the base data to be visible via vector search + await base.WaitForDataAsync(collection, recordCount, filter, vectorProperty, vectorSize, dummyVector); + + // Then wait for full-text population to complete (if any full-text indexes exist) + await this.WaitForFullTextPopulationAsync(collection.Name); + } + + private async Task WaitForFullTextPopulationAsync(string tableName) + { + using var connection = new SqlConnection(this.ConnectionString); + await connection.OpenAsync(); + + // Query to check if full-text population is complete + var checkSql = @" + SELECT COUNT(*) + FROM sys.fulltext_indexes fi + JOIN sys.tables t ON fi.object_id = t.object_id + WHERE t.name = @tableName + AND fi.has_crawl_completed = 0"; + + using var command = new SqlCommand(checkSql, connection); + command.Parameters.AddWithValue("@tableName", tableName); + + for (int i = 0; i < 100; i++) // Wait up to 10 seconds + { + var result = await command.ExecuteScalarAsync(); + + if (result is int count && count == 0) + { + // Either no full-text indexes exist or all are populated + return; + } + + await Task.Delay(TimeSpan.FromMilliseconds(100)); + } + + // Don't fail the test - some tests might not need full-text + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerDataTypeTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerDataTypeTests.cs new file mode 100644 index 0000000..db4f80a --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerDataTypeTests.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests.Support; +using VectorData.ConformanceTests.TypeTests; +using Xunit; + +namespace SqlServer.ConformanceTests.TypeTests; + +public class SqlServerDataTypeTests(SqlServerDataTypeTests.Fixture fixture) + : DataTypeTests.DefaultRecord>(fixture), IClassFixture +{ + public override Task String_array() + => this.Test( + "StringArray", + ["foo", "bar"], + ["foo", "baz"], + // SQL Server doesn't support comparing JSON + isFilterable: false); + + public new class Fixture : DataTypeTests.DefaultRecord>.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerEmbeddingTypeTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerEmbeddingTypeTests.cs new file mode 100644 index 0000000..ca33bf9 --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerEmbeddingTypeTests.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using Microsoft.Data.SqlTypes; +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests.Support; +using VectorData.ConformanceTests.TypeTests; +using Xunit; + +#pragma warning disable CA2000 // Dispose objects before losing scope + +namespace SqlServer.ConformanceTests.TypeTests; + +public class SqlServerEmbeddingTypeTests(SqlServerEmbeddingTypeTests.Fixture fixture) + : EmbeddingTypeTests(fixture), IClassFixture +{ + [Fact] + public virtual Task SqlVector_of_float() + => this.Test>( + new SqlVector(new float[] { 1, 2, 3 }), + new ReadOnlyMemoryEmbeddingGenerator([1, 2, 3]), + vectorEqualityAsserter: (e, a) => Assert.Equal(e.Memory.ToArray(), a.Memory.ToArray())); + + public new class Fixture : EmbeddingTypeTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerKeyTypeTests.cs b/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerKeyTypeTests.cs new file mode 100644 index 0000000..5cc312e --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/TypeTests/SqlServerKeyTypeTests.cs @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft. All rights reserved. 
+ +using SqlServer.ConformanceTests.Support; +using VectorData.ConformanceTests.Support; +using VectorData.ConformanceTests.TypeTests; +using Xunit; + +namespace SqlServer.ConformanceTests.TypeTests; + +public class SqlServerKeyTypeTests(SqlServerKeyTypeTests.Fixture fixture) + : KeyTypeTests(fixture), IClassFixture +{ + [Fact] + public virtual Task Int() => this.Test(8, supportsAutoGeneration: true); + + [Fact] + public virtual Task Long() => this.Test(8L, supportsAutoGeneration: true); + + [Fact] + public virtual Task String() => this.Test("foo", "bar"); + + public new class Fixture : KeyTypeTests.Fixture + { + public override TestStore TestStore => SqlServerTestStore.Instance; + } +} diff --git a/test/Microsoft.SqlServer.VectorData.ConformanceTests/testsettings.json b/test/Microsoft.SqlServer.VectorData.ConformanceTests/testsettings.json new file mode 100644 index 0000000..ac88bad --- /dev/null +++ b/test/Microsoft.SqlServer.VectorData.ConformanceTests/testsettings.json @@ -0,0 +1,8 @@ +{ + // Optional: Provide a connection string to use an external SQL Server instance instead of testcontainers + // This is useful for running tests against Azure SQL or other external SQL Server instances + // If not specified, tests will automatically use a testcontainer + "SqlServer": { + "ConnectionString": "" + } +}