Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 55 additions & 15 deletions src/Microsoft.Windows.CsWin32/Generator.FriendlyOverloads.cs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
int iidPpvRiidOrigIndex = -1;
int iidPpvPpvOrigIndex = -1;
bool iidPpvMarshalingMode = false;
bool iidPpvUseNativeOutMarshaling = false;

if (this.options.FriendlyOverloads.ComOutPtrGenericOverloads)
{
Expand Down Expand Up @@ -186,10 +187,12 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
// Skip if ppv is typed as IntPtr (UseIntPtrForComOutPointers mode).
if (ppvExtern.Type is not IdentifierNameSyntax { Identifier.ValueText: nameof(IntPtr) })
{
bool ppvExternIsObjectOut = ppvExtern.Modifiers.Any(SyntaxKind.OutKeyword)
&& ppvExtern.Type is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ObjectKeyword };
iidPpvRiidOrigIndex = riidOrig;
iidPpvPpvOrigIndex = ppvOrig;
iidPpvMarshalingMode = ppvExtern.Modifiers.Any(SyntaxKind.OutKeyword)
&& ppvExtern.Type is PredefinedTypeSyntax { Keyword.RawKind: (int)SyntaxKind.ObjectKeyword };
iidPpvMarshalingMode = this.options.AllowMarshaling;
iidPpvUseNativeOutMarshaling = iidPpvMarshalingMode && !ppvExternIsObjectOut;
}
}
}
Expand Down Expand Up @@ -248,28 +251,66 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(
{
signatureChanged = true;
IdentifierNameSyntax tName = IdentifierName("T");
ParameterSyntax ppvExternParam = externMethodDeclaration.ParameterList.Parameters[paramIndex];
IdentifierNameSyntax ppvName = IdentifierName(ppvExternParam.Identifier.ValueText);

if (iidPpvMarshalingMode)
{
parameters[paramIndex] = StripAttributes(externMethodDeclaration.ParameterList.Parameters[paramIndex])
parameters[paramIndex] = StripAttributes(ppvExternParam)
.WithType(tName.WithTrailingTrivia(TriviaList(Space)))
.WithModifiers([TokenWithSpace(SyntaxKind.OutKeyword)]);

arguments[paramIndex] = Argument(DeclarationExpression(
PredefinedType(TokenWithSpace(SyntaxKind.ObjectKeyword)),
SingleVariableDesignation(Identifier("__ppv"))))
.WithRefKindKeyword(TokenWithSpace(SyntaxKind.OutKeyword));
if (iidPpvUseNativeOutMarshaling)
{
IdentifierNameSyntax nativeLocal = IdentifierName("__ppv");
leadingOutsideTryStatements.Add(
LocalDeclarationStatement(VariableDeclaration(
PointerType(PredefinedType(Token(SyntaxKind.VoidKeyword))),
[VariableDeclarator(nativeLocal.Identifier, EqualsValueClause(LiteralExpression(SyntaxKind.NullLiteralExpression)))])));

IdentifierNameSyntax ppvName = IdentifierName(externMethodDeclaration.ParameterList.Parameters[paramIndex].Identifier.ValueText);
trailingStatements.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
ppvName,
CastExpression(tName, IdentifierName("__ppv")))));
arguments[paramIndex] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, nativeLocal));

ExpressionSyntax toManagedExpression = this.useSourceGenerators ?
InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
GenericName($"global::System.Runtime.InteropServices.Marshalling.ComInterfaceMarshaller", [tName]),
IdentifierName("ConvertToManaged")),
[Argument(nativeLocal)]) :
ParenthesizedExpression(ConditionalExpression(
BinaryExpression(SyntaxKind.NotEqualsExpression, nativeLocal, LiteralExpression(SyntaxKind.NullLiteralExpression)),
CastExpression(tName, InvocationExpression(
MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
ParseTypeName($"global::System.Runtime.InteropServices.Marshal"),
IdentifierName("GetObjectForIUnknown")),
[Argument(CastExpression(ParseName("nint"), nativeLocal))])),
LiteralExpression(SyntaxKind.NullLiteralExpression)));

trailingStatements.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
ppvName,
toManagedExpression)));
finallyStatements.Add(this.COMFreeNativePointerStatement(nativeLocal, tName));
}
else
{
arguments[paramIndex] = Argument(DeclarationExpression(
PredefinedType(TokenWithSpace(SyntaxKind.ObjectKeyword)),
SingleVariableDesignation(Identifier("__ppv"))))
.WithRefKindKeyword(TokenWithSpace(SyntaxKind.OutKeyword));

trailingStatements.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
ppvName,
CastExpression(tName, IdentifierName("__ppv")))));
}
}
else
{
parameters[paramIndex] = StripAttributes(externMethodDeclaration.ParameterList.Parameters[paramIndex])
parameters[paramIndex] = StripAttributes(ppvExternParam)
.WithType(PointerType(tName).WithTrailingTrivia(TriviaList(Space)))
.WithModifiers([TokenWithSpace(SyntaxKind.OutKeyword)]);

Expand All @@ -280,7 +321,6 @@ private IEnumerable<MethodDeclarationSyntax> DeclareFriendlyOverload(

arguments[paramIndex] = Argument(PrefixUnaryExpression(SyntaxKind.AddressOfExpression, IdentifierName("__ppv")));

IdentifierNameSyntax ppvName = IdentifierName(externMethodDeclaration.ParameterList.Parameters[paramIndex].Identifier.ValueText);
trailingStatements.Add(ExpressionStatement(
AssignmentExpression(
SyntaxKind.SimpleAssignmentExpression,
Expand Down
22 changes: 22 additions & 0 deletions test/GenerationSandbox.BuildTask.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using Windows.Win32.Graphics.Direct2D.Common;
using Windows.Win32.Graphics.Direct3D;
using Windows.Win32.Graphics.Direct3D11;
using Windows.Win32.Graphics.Direct3D12;
using Windows.Win32.Graphics.Dxgi.Common;
using Windows.Win32.NetworkManagement.WindowsFirewall;
using Windows.Win32.Storage.FileSystem;
Expand All @@ -33,6 +34,8 @@ namespace GenerationSandbox.BuildTask.Tests;
[Trait("WindowsOnly", "true")]
public partial class COMTests(ITestOutputHelper outputHelper)
{
private delegate void CreateCommittedResourceGenericOverloadCompileDelegate(ID3D12Device device, in D3D12_HEAP_PROPERTIES heapProperties, in D3D12_RESOURCE_DESC resourceDesc);

private ITestOutputHelper outputHelper = outputHelper;

[Fact]
Expand Down Expand Up @@ -108,6 +111,25 @@ public void IsSHGetFileInfoEasilyCalled()
SHGFI_FLAGS.SHGFI_DISPLAYNAME);
}

[Fact]
public void ID3D12Device_CreateCommittedResource_GenericOverloadsCompile()
{
CreateCommittedResourceGenericOverloadCompileDelegate compileOnly = CompileOnlyCreateCommittedResourceGenericOverload;
Assert.NotNull(compileOnly);
}

private static void CompileOnlyCreateCommittedResourceGenericOverload(ID3D12Device device, in D3D12_HEAP_PROPERTIES heapProperties, in D3D12_RESOURCE_DESC resourceDesc)
{
device.CreateCommittedResource<ID3D12Resource>(
in heapProperties,
D3D12_HEAP_FLAGS.D3D12_HEAP_FLAG_NONE,
in resourceDesc,
D3D12_RESOURCE_STATES.D3D12_RESOURCE_STATE_COMMON,
null,
out ID3D12Resource resource);
GC.KeepAlive(resource);
}


[Fact]
[Trait("TestCategory", "RequiresHardware")] // D3D APIs fail in cloud VMs
Expand Down
4 changes: 3 additions & 1 deletion test/GenerationSandbox.BuildTask.Tests/NativeMethods.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ HMENU
ICompositorInterop
ID2D1HwndRenderTarget
ID2D1RenderTarget
ID3D12Device
ID3D12Resource
IDebugProperty
IDispatch
IEventSubscription
Expand Down Expand Up @@ -85,4 +87,4 @@ IEnumVARIANT
INetFwAuthorizedApplication
SetupDiGetClassDevs
SetupDiEnumDeviceInfo
SetupDiGetDeviceInstanceId
SetupDiGetDeviceInstanceId
43 changes: 43 additions & 0 deletions test/Microsoft.Windows.CsWin32.Tests/COMTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,46 @@ public void ComOutPtrGenericOverload_D3D12_CreateCommandQueue(bool allowMarshali
m => m.TypeParameterList?.Parameters.Count == 1);
}

/// <summary>
/// Regression test for <see href="https://github.com/microsoft/CsWin32/issues/1608">#1608</see>.
/// Optional IID_PPV_ARGS out parameters should still get managed generic overloads.
/// </summary>
[Theory, PairwiseData]
public void ComOutPtrGenericOverload_D3D12_OptionalOutCreateCommittedResource(bool useComSourceGenerators)
{
const string methodName = "CreateCommittedResource";
if (useComSourceGenerators)
{
this.compilation = this.starterCompilations["net10.0"];
this.parseOptions = this.parseOptions.WithLanguageVersion(GetLanguageVersionForTfm("net10.0") ?? LanguageVersion.Latest);
}

this.generator = this.CreateGenerator(new GeneratorOptions
{
AllowMarshaling = true,
ComInterop = new GeneratorOptions.ComInteropOptions { UseComSourceGenerators = useComSourceGenerators },
});
Assert.True(this.generator.TryGenerate("ID3D12Device", CancellationToken.None));
this.CollectGeneratedCode(this.generator);
this.AssertNoDiagnostics(
this.compilation,
logAllGeneratedCode: false,
acceptable: d => useComSourceGenerators && (d.Id is "CS8795" or "CS1574"));

List<MethodDeclarationSyntax> genericOverloads = this.FindGeneratedMethod(methodName)
.Where(m => m.TypeParameterList?.Parameters.Count == 1
&& m.ParameterList.Parameters.FirstOrDefault() is { } first
&& first.Modifiers.Any(SyntaxKind.ThisKeyword)
&& first.Type?.ToString().Contains("ID3D12Device") == true)
.ToList();

Assert.Contains(genericOverloads, m =>
IsClassConstrainedGeneric(m)
&& m.ParameterList.Parameters.All(p => p.Identifier.ValueText != "riidResource")
&& m.ParameterList.Parameters.Last() is { Identifier.ValueText: "ppvResource", Type: IdentifierNameSyntax { Identifier.ValueText: "T" } } ppv
&& ppv.Modifiers.Any(SyntaxKind.OutKeyword));
}

/// <summary>
/// Regression test for <see href="https://github.com/microsoft/CsWin32/issues/374">#374</see>.
/// IMoniker.BindToObject should get a generic overload.
Expand Down Expand Up @@ -1096,6 +1136,9 @@ public void COMStructWrappers_ClsPragmaSuppressionsAreLoadBearing_Issue1703()
private static string StripWarningDisablePragmas(string code) =>
string.Join("\n", code.Split('\n').Where(line => !line.TrimStart().StartsWith("#pragma warning disable", StringComparison.Ordinal)));

private static bool IsClassConstrainedGeneric(MethodDeclarationSyntax method) =>
method.ConstraintClauses.Any(cc => cc.Constraints.Any(c => c is ClassOrStructConstraintSyntax { ClassOrStructKeyword.RawKind: (int)SyntaxKind.ClassKeyword }));

private static string? GetDocumentationComment(MemberDeclarationSyntax member) =>
member.GetLeadingTrivia()
.Select(t => t.GetStructure())
Expand Down
Loading