diff --git a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs index a8fd4e2337819a..6409dc8eb0eb9e 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs @@ -1891,23 +1891,24 @@ internal static void SafeHandleRelease(SafeHandle pHandle) private static extern IntPtr GetCOMIPFromRCW(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget); [LibraryImport(RuntimeHelpers.QCall, EntryPoint = "StubHelpers_GetCOMIPFromRCWSlow")] - private static partial IntPtr GetCOMIPFromRCWSlow(ObjectHandleOnStack objSrc, IntPtr pCPCMD, out IntPtr ppTarget); + private static partial IntPtr GetCOMIPFromRCWSlow(ObjectHandleOnStack objSrc, IntPtr pCPCMD, out IntPtr ppTarget, [MarshalAs(UnmanagedType.Bool)] out bool pfNeedsRelease); internal static IntPtr GetCOMIPFromRCW(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget, out bool pfNeedsRelease) { IntPtr rcw = GetCOMIPFromRCW(objSrc, pCPCMD, out ppTarget); - if (rcw == IntPtr.Zero) + if (rcw != IntPtr.Zero) { - // If we didn't find the COM interface pointer in the cache we need to release the pointer. - pfNeedsRelease = true; - return GetCOMIPFromRCWWorker(objSrc, pCPCMD, out ppTarget); + pfNeedsRelease = false; + return rcw; } - pfNeedsRelease = false; - return rcw; + + // The slow path may create OLE TLS and then still resolve the interface via the RCW cache. + // Let the slow path tell us whether it returned an owned pointer that requires cleanup. + return GetCOMIPFromRCWWorker(objSrc, pCPCMD, out ppTarget, out pfNeedsRelease); [MethodImpl(MethodImplOptions.NoInlining)] - static IntPtr GetCOMIPFromRCWWorker(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget) - => GetCOMIPFromRCWSlow(ObjectHandleOnStack.Create(ref objSrc), pCPCMD, out ppTarget); + static IntPtr GetCOMIPFromRCWWorker(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget, out bool pfNeedsRelease) + => GetCOMIPFromRCWSlow(ObjectHandleOnStack.Create(ref objSrc), pCPCMD, out ppTarget, out pfNeedsRelease); } #endif // FEATURE_COMINTEROP diff --git a/src/coreclr/vm/stubhelpers.cpp b/src/coreclr/vm/stubhelpers.cpp index 36dd859b5c7a84..4cca5e9019f3dd 100644 --- a/src/coreclr/vm/stubhelpers.cpp +++ b/src/coreclr/vm/stubhelpers.cpp @@ -312,11 +312,12 @@ FCIMPLEND #include -extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget) +extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget, BOOL* pfNeedsRelease) { QCALL_CONTRACT; _ASSERTE(pMD != NULL); _ASSERTE(ppTarget != NULL); + _ASSERTE(pfNeedsRelease != NULL); IUnknown *pIntf = NULL; BEGIN_QCALL; @@ -326,6 +327,8 @@ extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHand OBJECTREF objRef = pSrc.Get(); GCPROTECT_BEGIN(objRef); + *pfNeedsRelease = FALSE; + // This snippet exists to enable OLE TLS data creation that isn't possible on the fast path. // It is practically identical to the StubHelpers::GetCOMIPFromRCW FCALL, but in the event the OLE TLS // data on this thread hasn't occurred yet, we will create it. Since this is the slow path, trying the @@ -335,17 +338,19 @@ extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHand RCW* pRCW = objRef->PassiveGetSyncBlock()->GetInteropInfoNoCreate()->GetRawRCW(); if (pRCW != NULL) { - IUnknown* pUnk = GetCOMIPFromRCW_GetTargetFromRCWCache(pOleTlsData, pRCW, pComInfo, ppTarget); - if (pUnk != NULL) - return pUnk; + pIntf = GetCOMIPFromRCW_GetTargetFromRCWCache(pOleTlsData, pRCW, pComInfo, ppTarget); } - // Still not in the cache and we've ensured the OLE TLS data was created. - SafeComHolder pRetUnk = ComObject::GetComIPFromRCWThrowing(&objRef, pComInfo->m_pInterfaceMT); - *ppTarget = GetCOMIPFromRCW_GetTarget(pRetUnk, pComInfo); - _ASSERTE(*ppTarget != NULL); + if (pIntf == NULL) + { + // Still not in the cache and we've ensured the OLE TLS data was created. + SafeComHolder pRetUnk = ComObject::GetComIPFromRCWThrowing(&objRef, pComInfo->m_pInterfaceMT); + *ppTarget = GetCOMIPFromRCW_GetTarget(pRetUnk, pComInfo); + _ASSERTE(*ppTarget != NULL); - pIntf = pRetUnk.Extract(); + pIntf = pRetUnk.Extract(); + *pfNeedsRelease = TRUE; + } GCPROTECT_END(); diff --git a/src/coreclr/vm/stubhelpers.h b/src/coreclr/vm/stubhelpers.h index 38ae6ae5b7e76e..140bac12e80885 100644 --- a/src/coreclr/vm/stubhelpers.h +++ b/src/coreclr/vm/stubhelpers.h @@ -46,7 +46,7 @@ extern "C" void QCALLTYPE StubHelpers_ProfilerEndTransitionCallback(MethodDesc* #endif #ifdef FEATURE_COMINTEROP -extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget); +extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget, BOOL* pfNeedsRelease); extern "C" void QCALLTYPE ObjectMarshaler_ConvertToNative(QCall::ObjectHandleOnStack pSrcUNSAFE, VARIANT* pDest); extern "C" void QCALLTYPE ObjectMarshaler_ConvertToManaged(VARIANT* pSrc, QCall::ObjectHandleOnStack retObject); diff --git a/src/tests/Interop/COM/NETClients/Lifetime/Program.cs b/src/tests/Interop/COM/NETClients/Lifetime/Program.cs index 9bdc9d8ea6ea2c..a0a6c254d4d996 100644 --- a/src/tests/Interop/COM/NETClients/Lifetime/Program.cs +++ b/src/tests/Interop/COM/NETClients/Lifetime/Program.cs @@ -17,8 +17,14 @@ namespace NetClient public unsafe class Program { static delegate* unmanaged GetAllocationCount; + static ITrackMyLifetimeTesting? s_agileInstance; + static Exception? s_callbackException; + + [DllImport("COMNativeServer", EntryPoint = "InvokeCallbackOnNativeThread")] + private static extern int InvokeCallbackOnNativeThread(delegate* unmanaged callback); // Initialize for all tests + [MethodImpl(MethodImplOptions.NoInlining)] static void Initialize() { var inst = new TrackMyLifetimeTesting(); @@ -45,6 +51,19 @@ static void ForceGC() } } + [UnmanagedCallersOnly] + static void InvokeObjectFromNativeThread() + { + try + { + s_agileInstance!.Method(); + } + catch (Exception e) + { + s_callbackException = e; + } + } + static void Validate_COMServer_CleanUp() { Console.WriteLine($"Calling {nameof(Validate_COMServer_CleanUp)}..."); @@ -85,36 +104,57 @@ static void Validate_COMServer_DisableEagerCleanUp() Assert.False(Marshal.AreComObjectsAvailableForCleanup()); } - [Fact] - public static int TestEntryPoint() + static void Validate_COMServer_CallOnNativeThread() { - // RegFree COM and STA apartments are not supported on Windows Nano - if (Utilities.IsWindowsNanoServer) + Console.WriteLine($"Calling {nameof(Validate_COMServer_CallOnNativeThread)}..."); + + // Need agile instance since the object will be used on a different thread + // than the creating thread and we're on an STA thread. + s_agileInstance = CreateAgileInstance(); + try + { + s_agileInstance.Method(); + + // Create a fresh native thread for each callback so the COM call runs before that thread + // has initialized the CLR's OLE TLS state. + for (int i = 0; i < 10; i++) + { + s_callbackException = null; + + Marshal.ThrowExceptionForHR(InvokeCallbackOnNativeThread(&InvokeObjectFromNativeThread)); + + Assert.True(s_callbackException is null, s_callbackException?.ToString()); + } + } + finally { - return 100; + s_agileInstance = null; } - int result = 101; + [MethodImpl(MethodImplOptions.NoInlining)] + static ITrackMyLifetimeTesting CreateAgileInstance() + => new TrackMyLifetimeTesting().CreateAgileInstance(); + } + + const int TestFailed = 101; + const int TestPassed = 100; + + static int RunOnSTAThread(Action action) + { + int result = TestFailed; - // Run the test on a new STA thread since Nano Server doesn't support the STA - // and as a result, the main application thread can't be made STA with the STAThread attribute Thread staThread = new Thread(() => { try { - // Initialization for all future tests - Initialize(); - Assert.True(GetAllocationCount != null); - - Validate_COMServer_CleanUp(); - Validate_COMServer_DisableEagerCleanUp(); + action(); } catch (Exception e) { Console.WriteLine($"Test Failure: {e}"); - result = 101; + result = TestFailed; } - result = 100; + result = TestPassed; }); staThread.SetApartmentState(ApartmentState.STA); @@ -123,5 +163,45 @@ public static int TestEntryPoint() return result; } + + [Fact] + public static int TestEntryPoint() + { + // RegFree COM and STA apartments are not supported on Windows Nano + if (Utilities.IsWindowsNanoServer) + { + return TestPassed; + } + + // Run the test on a new STA thread since Nano Server doesn't support the STA + // and as a result, the main application thread can't be made STA with the STAThread attribute + int result = RunOnSTAThread(() => + { + // Initialization for all future tests + Initialize(); + ForceGC(); + Assert.True(GetAllocationCount != null); + + Validate_COMServer_CleanUp(); + Validate_COMServer_CallOnNativeThread(); + }); + if (result != TestPassed) + { + return result; + } + + return RunOnSTAThread(() => + { + // Initialization for all future tests + Initialize(); + ForceGC(); + Assert.True(GetAllocationCount != null); + + // Manipulating the eager cleanup state cannot be changed once set, + // so we need to run this test on a separate thread after the first + // test validates that cleanup is working as expected with eager cleanup enabled. + Validate_COMServer_DisableEagerCleanUp(); + }); + } } } diff --git a/src/tests/Interop/COM/NativeServer/Exports.def b/src/tests/Interop/COM/NativeServer/Exports.def index 2d0de26b056e0b..cb6d5520db0bb1 100644 --- a/src/tests/Interop/COM/NativeServer/Exports.def +++ b/src/tests/Interop/COM/NativeServer/Exports.def @@ -1,4 +1,5 @@ EXPORTS DllGetClassObject PRIVATE DllRegisterServer PRIVATE - DllUnregisterServer PRIVATE \ No newline at end of file + DllUnregisterServer PRIVATE + InvokeCallbackOnNativeThread diff --git a/src/tests/Interop/COM/NativeServer/Servers.cpp b/src/tests/Interop/COM/NativeServer/Servers.cpp index ebe4f9df5acde2..ed47fd9591b91c 100644 --- a/src/tests/Interop/COM/NativeServer/Servers.cpp +++ b/src/tests/Interop/COM/NativeServer/Servers.cpp @@ -3,6 +3,7 @@ #include "stdafx.h" #include "Servers.h" +#include namespace { @@ -155,6 +156,20 @@ namespace } } +extern "C" HRESULT STDMETHODCALLTYPE InvokeCallbackOnNativeThread(void (STDMETHODCALLTYPE* callback)()) +{ + if (callback == nullptr) + return E_POINTER; + + std::thread worker([callback]() + { + callback(); + }); + + worker.join(); + return S_OK; +} + STDAPI DllRegisterServer(void) { HRESULT hr; diff --git a/src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h b/src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h index 3d0ac26da0a16a..6e3c3ee3c4034f 100644 --- a/src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h +++ b/src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h @@ -5,7 +5,7 @@ #include "Servers.h" -class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTesting +class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTesting, public IAgileObject { static std::atomic _instanceCount; @@ -14,8 +14,15 @@ class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTestin return _instanceCount; } +private: + const bool _isAgileInstance = false; + public: TrackMyLifetimeTesting() + : TrackMyLifetimeTesting(false) + { } + TrackMyLifetimeTesting(bool isAgileInstance) + : _isAgileInstance(isAgileInstance) { _instanceCount++; } @@ -34,11 +41,34 @@ class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTestin return S_OK; } + DEF_FUNC(CreateAgileInstance)(ITrackMyLifetimeTesting** agileInstance) + { + if (agileInstance == nullptr) + return E_POINTER; + + *agileInstance = new TrackMyLifetimeTesting(/*isAgileInstance*/ true); + return S_OK; + } + + DEF_FUNC(Method)() + { + return S_OK; + } + public: // IUnknown STDMETHOD(QueryInterface)( /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { + if (_isAgileInstance) + { + if (riid == __uuidof(IAgileObject)) + { + *ppvObject = static_cast(this); + AddRef(); + return S_OK; + } + } return DoQueryInterface(riid, ppvObject, static_cast(this)); } diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs b/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs index 7da4b220fbc3cd..c54fc066cae2be 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs @@ -446,6 +446,8 @@ internal interface IInspectableTesting2 internal interface ITrackMyLifetimeTesting { IntPtr GetAllocationCountCallback(); + ITrackMyLifetimeTesting CreateAgileInstance(); + void Method(); } } diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h index 830966f1427980..ef87bb9c920aab 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h @@ -543,6 +543,8 @@ struct __declspec(uuid("57f396a1-58a0-425f-8807-9f938a534984")) ITrackMyLifetimeTesting : IUnknown { virtual HRESULT STDMETHODCALLTYPE GetAllocationCountCallback(_Outptr_ void** fptr) = 0; + virtual HRESULT STDMETHODCALLTYPE CreateAgileInstance(ITrackMyLifetimeTesting** agileInstance) = 0; + virtual HRESULT STDMETHODCALLTYPE Method() = 0; }; // IIDs for the below types are generated by the runtime.