diff --git a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs index dc7100c7..5aa8b160 100644 --- a/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs +++ b/src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs @@ -154,6 +154,7 @@ private IEnumerable DeclareFriendlyOverload( int iidPpvRiidOrigIndex = -1; int iidPpvPpvOrigIndex = -1; bool iidPpvMarshalingMode = false; + bool iidPpvUseNativeOutMarshaling = false; if (this.options.FriendlyOverloads.ComOutPtrGenericOverloads) { @@ -186,10 +187,12 @@ private IEnumerable DeclareFriendlyOverload( // Skip if ppv is typed as IntPtr (UseIntPtrForComOutPointers mode). if (ppvExtern.Type is not IdentifierNameSyntax { Identifier.ValueText: nameof(IntPtr) }) { + bool ppvExternIsObjectOut = ppvExtern.Modifiers.Any(SyntaxKind.OutKeyword) + && ppvExtern.Type is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ObjectKeyword }; iidPpvRiidOrigIndex = riidOrig; iidPpvPpvOrigIndex = ppvOrig; - iidPpvMarshalingMode = ppvExtern.Modifiers.Any(SyntaxKind.OutKeyword) - && ppvExtern.Type is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ObjectKeyword }; + iidPpvMarshalingMode = this.options.AllowMarshaling; + iidPpvUseNativeOutMarshaling = iidPpvMarshalingMode && !ppvExternIsObjectOut; } } } @@ -248,28 +251,66 @@ private IEnumerable DeclareFriendlyOverload( { signatureChanged = true; IdentifierNameSyntax tName = IdentifierName("T"); + ParameterSyntax ppvExternParam = externMethodDeclaration.ParameterList.Parameters[paramIndex]; + IdentifierNameSyntax ppvName = IdentifierName(ppvExternParam.Identifier.ValueText); if (iidPpvMarshalingMode) { - parameters[paramIndex] = StripAttributes(externMethodDeclaration.ParameterList.Parameters[paramIndex]) + parameters[paramIndex] = StripAttributes(ppvExternParam) .WithType(tName.WithTrailingTrivia(TriviaList(Space))) .WithModifiers([TokenWithSpace(SyntaxKind.OutKeyword)]); - arguments[paramIndex] = Argument(DeclarationExpression( - PredefinedType(TokenWithSpace(SyntaxKind.ObjectKeyword)), - SingleVariableDesignation(Identifier("__ppv")))) - .WithRefKindKeyword(TokenWithSpace(SyntaxKind.OutKeyword)); + if (iidPpvUseNativeOutMarshaling) + { + IdentifierNameSyntax nativeLocal = IdentifierName("__ppv"); + leadingOutsideTryStatements.Add( + LocalDeclarationStatement(VariableDeclaration( + PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))), + [VariableDeclarator(nativeLocal.Identifier, EqualsValueClause(LiteralExpression(SyntaxKind.NullLiteralExpression)))]))); - IdentifierNameSyntax ppvName = IdentifierName(externMethodDeclaration.ParameterList.Parameters[paramIndex].Identifier.ValueText); - trailingStatements.Add(ExpressionStatement( - AssignmentExpression( - SyntaxKind.SimpleAssignmentExpression, - ppvName, - CastExpression(tName, IdentifierName("__ppv"))))); + arguments[paramIndex] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, nativeLocal)); + + ExpressionSyntax toManagedExpression = this.useSourceGenerators ? + InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + GenericName($"global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller", [tName]), + IdentifierName("ConvertToManaged")), + [Argument(nativeLocal)]) : + ParenthesizedExpression(ConditionalExpression( + BinaryExpression(SyntaxKind.NotEqualsExpression, nativeLocal, LiteralExpression(SyntaxKind.NullLiteralExpression)), + CastExpression(tName, InvocationExpression( + MemberAccessExpression( + SyntaxKind.SimpleMemberAccessExpression, + ParseTypeName($"global::System.Runtime.InteropServices.Marshal"), + IdentifierName("GetObjectForIUnknown")), + [Argument(CastExpression(ParseName("nint"), nativeLocal))])), + LiteralExpression(SyntaxKind.NullLiteralExpression))); + + trailingStatements.Add(ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + ppvName, + toManagedExpression))); + finallyStatements.Add(this.COMFreeNativePointerStatement(nativeLocal, tName)); + } + else + { + arguments[paramIndex] = Argument(DeclarationExpression( + PredefinedType(TokenWithSpace(SyntaxKind.ObjectKeyword)), + SingleVariableDesignation(Identifier("__ppv")))) + .WithRefKindKeyword(TokenWithSpace(SyntaxKind.OutKeyword)); + + trailingStatements.Add(ExpressionStatement( + AssignmentExpression( + SyntaxKind.SimpleAssignmentExpression, + ppvName, + CastExpression(tName, IdentifierName("__ppv"))))); + } } else { - parameters[paramIndex] = StripAttributes(externMethodDeclaration.ParameterList.Parameters[paramIndex]) + parameters[paramIndex] = StripAttributes(ppvExternParam) .WithType(PointerType(tName).WithTrailingTrivia(TriviaList(Space))) .WithModifiers([TokenWithSpace(SyntaxKind.OutKeyword)]); @@ -280,7 +321,6 @@ private IEnumerable DeclareFriendlyOverload( arguments[paramIndex] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, IdentifierName("__ppv"))); - IdentifierNameSyntax ppvName = IdentifierName(externMethodDeclaration.ParameterList.Parameters[paramIndex].Identifier.ValueText); trailingStatements.Add(ExpressionStatement( AssignmentExpression( SyntaxKind.SimpleAssignmentExpression, diff --git a/test/GenerationSandbox.BuildTask.Tests/COMTests.cs b/test/GenerationSandbox.BuildTask.Tests/COMTests.cs index e6732350..257ee393 100644 --- a/test/GenerationSandbox.BuildTask.Tests/COMTests.cs +++ b/test/GenerationSandbox.BuildTask.Tests/COMTests.cs @@ -18,6 +18,7 @@ using Windows.Win32.Graphics.Direct2D.Common; using Windows.Win32.Graphics.Direct3D; using Windows.Win32.Graphics.Direct3D11; +using Windows.Win32.Graphics.Direct3D12; using Windows.Win32.Graphics.Dxgi.Common; using Windows.Win32.NetworkManagement.WindowsFirewall; using Windows.Win32.Storage.FileSystem; @@ -33,6 +34,8 @@ namespace GenerationSandbox.BuildTask.Tests; [Trait("WindowsOnly", "true")] public partial class COMTests(ITestOutputHelper outputHelper) { + private delegate void CreateCommittedResourceGenericOverloadCompileDelegate(ID3D12Device device, in D3D12_HEAP_PROPERTIES heapProperties, in D3D12_RESOURCE_DESC resourceDesc); + private ITestOutputHelper outputHelper = outputHelper; [Fact] @@ -108,6 +111,25 @@ public void IsSHGetFileInfoEasilyCalled() SHGFI_FLAGS.SHGFI_DISPLAYNAME); } + [Fact] + public void ID3D12Device_CreateCommittedResource_GenericOverloadsCompile() + { + CreateCommittedResourceGenericOverloadCompileDelegate compileOnly = CompileOnlyCreateCommittedResourceGenericOverload; + Assert.NotNull(compileOnly); + } + + private static void CompileOnlyCreateCommittedResourceGenericOverload(ID3D12Device device, in D3D12_HEAP_PROPERTIES heapProperties, in D3D12_RESOURCE_DESC resourceDesc) + { + device.CreateCommittedResource( + in heapProperties, + D3D12_HEAP_FLAGS.D3D12_HEAP_FLAG_NONE, + in resourceDesc, + D3D12_RESOURCE_STATES.D3D12_RESOURCE_STATE_COMMON, + null, + out ID3D12Resource resource); + GC.KeepAlive(resource); + } + [Fact] [Trait("TestCategory", "RequiresHardware")] // D3D APIs fail in cloud VMs diff --git a/test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt b/test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt index de01d40a..c7e378dc 100644 --- a/test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt +++ b/test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt @@ -25,6 +25,8 @@ HMENU ICompositorInterop ID2D1HwndRenderTarget ID2D1RenderTarget +ID3D12Device +ID3D12Resource IDebugProperty IDispatch IEventSubscription @@ -85,4 +87,4 @@ IEnumVARIANT INetFwAuthorizedApplication SetupDiGetClassDevs SetupDiEnumDeviceInfo -SetupDiGetDeviceInstanceId \ No newline at end of file +SetupDiGetDeviceInstanceId diff --git a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs index 8722c005..4a4a1a2f 100644 --- a/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs +++ b/test/Microsoft.Windows.CsWin32.Tests/COMTests.cs @@ -483,6 +483,46 @@ public void ComOutPtrGenericOverload_D3D12_CreateCommandQueue(bool allowMarshali m => m.TypeParameterList?.Parameters.Count == 1); } + /// + /// Regression test for #1608. + /// Optional IID_PPV_ARGS out parameters should still get managed generic overloads. + /// + [Theory, PairwiseData] + public void ComOutPtrGenericOverload_D3D12_OptionalOutCreateCommittedResource(bool useComSourceGenerators) + { + const string methodName = "CreateCommittedResource"; + if (useComSourceGenerators) + { + this.compilation = this.starterCompilations["net10.0"]; + this.parseOptions = this.parseOptions.WithLanguageVersion(GetLanguageVersionForTfm("net10.0") ?? LanguageVersion.Latest); + } + + this.generator = this.CreateGenerator(new GeneratorOptions + { + AllowMarshaling = true, + ComInterop = new GeneratorOptions.ComInteropOptions { UseComSourceGenerators = useComSourceGenerators }, + }); + Assert.True(this.generator.TryGenerate("ID3D12Device", CancellationToken.None)); + this.CollectGeneratedCode(this.generator); + this.AssertNoDiagnostics( + this.compilation, + logAllGeneratedCode: false, + acceptable: d => useComSourceGenerators && (d.Id is "CS8795" or "CS1574")); + + List genericOverloads = this.FindGeneratedMethod(methodName) + .Where(m => m.TypeParameterList?.Parameters.Count == 1 + && m.ParameterList.Parameters.FirstOrDefault() is { } first + && first.Modifiers.Any(SyntaxKind.ThisKeyword) + && first.Type?.ToString().Contains("ID3D12Device") == true) + .ToList(); + + Assert.Contains(genericOverloads, m => + IsClassConstrainedGeneric(m) + && m.ParameterList.Parameters.All(p => p.Identifier.ValueText != "riidResource") + && m.ParameterList.Parameters.Last() is { Identifier.ValueText: "ppvResource", Type: IdentifierNameSyntax { Identifier.ValueText: "T" } } ppv + && ppv.Modifiers.Any(SyntaxKind.OutKeyword)); + } + /// /// Regression test for #374. /// IMoniker.BindToObject should get a generic overload. @@ -1096,6 +1136,9 @@ public void COMStructWrappers_ClsPragmaSuppressionsAreLoadBearing_Issue1703() private static string StripWarningDisablePragmas(string code) => string.Join("\n", code.Split('\n').Where(line => !line.TrimStart().StartsWith("#pragma warning disable", StringComparison.Ordinal))); + private static bool IsClassConstrainedGeneric(MethodDeclarationSyntax method) => + method.ConstraintClauses.Any(cc => cc.Constraints.Any(c => c is ClassOrStructConstraintSyntax { ClassOrStructKeyword.RawKind: (int)SyntaxKind.ClassKeyword })); + private static string? GetDocumentationComment(MemberDeclarationSyntax member) => member.GetLeadingTrivia() .Select(t => t.GetStructure())