From 8431f09335d94b831179987e53401e67df9c51a9 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Thu, 9 Apr 2026 13:46:05 -0700 Subject: [PATCH 1/2] Fix CLR->COM RCW ownership on slow path Have the slow GetCOMIPFromRCW helper report whether the returned interface pointer actually requires cleanup, so RCW cache hits after OLE TLS initialization are not released as if they were freshly acquired.\n\nFixes #126619\n\nCo-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../src/System/StubHelpers.cs | 19 +++++++-------- src/coreclr/vm/stubhelpers.cpp | 23 +++++++++++-------- src/coreclr/vm/stubhelpers.h | 2 +- 3 files changed, 25 insertions(+), 19 deletions(-) diff --git a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs index a8fd4e2337819a..6409dc8eb0eb9e 100644 --- a/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs +++ b/src/coreclr/System.Private.CoreLib/src/System/StubHelpers.cs @@ -1891,23 +1891,24 @@ internal static void SafeHandleRelease(SafeHandle pHandle) private static extern IntPtr GetCOMIPFromRCW(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget); [LibraryImport(RuntimeHelpers.QCall, EntryPoint = "StubHelpers_GetCOMIPFromRCWSlow")] - private static partial IntPtr GetCOMIPFromRCWSlow(ObjectHandleOnStack objSrc, IntPtr pCPCMD, out IntPtr ppTarget); + private static partial IntPtr GetCOMIPFromRCWSlow(ObjectHandleOnStack objSrc, IntPtr pCPCMD, out IntPtr ppTarget, [MarshalAs(UnmanagedType.Bool)] out bool pfNeedsRelease); internal static IntPtr GetCOMIPFromRCW(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget, out bool pfNeedsRelease) { IntPtr rcw = GetCOMIPFromRCW(objSrc, pCPCMD, out ppTarget); - if (rcw == IntPtr.Zero) + if (rcw != IntPtr.Zero) { - // If we didn't find the COM interface pointer in the cache we need to release the pointer. - pfNeedsRelease = true; - return GetCOMIPFromRCWWorker(objSrc, pCPCMD, out ppTarget); + pfNeedsRelease = false; + return rcw; } - pfNeedsRelease = false; - return rcw; + + // The slow path may create OLE TLS and then still resolve the interface via the RCW cache. + // Let the slow path tell us whether it returned an owned pointer that requires cleanup. + return GetCOMIPFromRCWWorker(objSrc, pCPCMD, out ppTarget, out pfNeedsRelease); [MethodImpl(MethodImplOptions.NoInlining)] - static IntPtr GetCOMIPFromRCWWorker(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget) - => GetCOMIPFromRCWSlow(ObjectHandleOnStack.Create(ref objSrc), pCPCMD, out ppTarget); + static IntPtr GetCOMIPFromRCWWorker(object objSrc, IntPtr pCPCMD, out IntPtr ppTarget, out bool pfNeedsRelease) + => GetCOMIPFromRCWSlow(ObjectHandleOnStack.Create(ref objSrc), pCPCMD, out ppTarget, out pfNeedsRelease); } #endif // FEATURE_COMINTEROP diff --git a/src/coreclr/vm/stubhelpers.cpp b/src/coreclr/vm/stubhelpers.cpp index 36dd859b5c7a84..4cca5e9019f3dd 100644 --- a/src/coreclr/vm/stubhelpers.cpp +++ b/src/coreclr/vm/stubhelpers.cpp @@ -312,11 +312,12 @@ FCIMPLEND #include -extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget) +extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget, BOOL* pfNeedsRelease) { QCALL_CONTRACT; _ASSERTE(pMD != NULL); _ASSERTE(ppTarget != NULL); + _ASSERTE(pfNeedsRelease != NULL); IUnknown *pIntf = NULL; BEGIN_QCALL; @@ -326,6 +327,8 @@ extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHand OBJECTREF objRef = pSrc.Get(); GCPROTECT_BEGIN(objRef); + *pfNeedsRelease = FALSE; + // This snippet exists to enable OLE TLS data creation that isn't possible on the fast path. // It is practically identical to the StubHelpers::GetCOMIPFromRCW FCALL, but in the event the OLE TLS // data on this thread hasn't occurred yet, we will create it. Since this is the slow path, trying the @@ -335,17 +338,19 @@ extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHand RCW* pRCW = objRef->PassiveGetSyncBlock()->GetInteropInfoNoCreate()->GetRawRCW(); if (pRCW != NULL) { - IUnknown* pUnk = GetCOMIPFromRCW_GetTargetFromRCWCache(pOleTlsData, pRCW, pComInfo, ppTarget); - if (pUnk != NULL) - return pUnk; + pIntf = GetCOMIPFromRCW_GetTargetFromRCWCache(pOleTlsData, pRCW, pComInfo, ppTarget); } - // Still not in the cache and we've ensured the OLE TLS data was created. - SafeComHolder pRetUnk = ComObject::GetComIPFromRCWThrowing(&objRef, pComInfo->m_pInterfaceMT); - *ppTarget = GetCOMIPFromRCW_GetTarget(pRetUnk, pComInfo); - _ASSERTE(*ppTarget != NULL); + if (pIntf == NULL) + { + // Still not in the cache and we've ensured the OLE TLS data was created. + SafeComHolder pRetUnk = ComObject::GetComIPFromRCWThrowing(&objRef, pComInfo->m_pInterfaceMT); + *ppTarget = GetCOMIPFromRCW_GetTarget(pRetUnk, pComInfo); + _ASSERTE(*ppTarget != NULL); - pIntf = pRetUnk.Extract(); + pIntf = pRetUnk.Extract(); + *pfNeedsRelease = TRUE; + } GCPROTECT_END(); diff --git a/src/coreclr/vm/stubhelpers.h b/src/coreclr/vm/stubhelpers.h index 38ae6ae5b7e76e..140bac12e80885 100644 --- a/src/coreclr/vm/stubhelpers.h +++ b/src/coreclr/vm/stubhelpers.h @@ -46,7 +46,7 @@ extern "C" void QCALLTYPE StubHelpers_ProfilerEndTransitionCallback(MethodDesc* #endif #ifdef FEATURE_COMINTEROP -extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget); +extern "C" IUnknown* QCALLTYPE StubHelpers_GetCOMIPFromRCWSlow(QCall::ObjectHandleOnStack pSrc, MethodDesc* pMD, void** ppTarget, BOOL* pfNeedsRelease); extern "C" void QCALLTYPE ObjectMarshaler_ConvertToNative(QCall::ObjectHandleOnStack pSrcUNSAFE, VARIANT* pDest); extern "C" void QCALLTYPE ObjectMarshaler_ConvertToManaged(VARIANT* pSrc, QCall::ObjectHandleOnStack retObject); From f45112945f2c68099d6374fe65d200b99942cc96 Mon Sep 17 00:00:00 2001 From: Aaron R Robinson Date: Thu, 9 Apr 2026 15:59:04 -0700 Subject: [PATCH 2/2] Add COM regression coverage for RCW slow-path cleanup Extend TrackMyLifetimeTesting with a simple Method() call and add a native-thread callback test that exercises the agile RCW path covered by #126619. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../COM/NETClients/Lifetime/Program.cs | 112 +++++++++++++++--- .../Interop/COM/NativeServer/Exports.def | 3 +- .../Interop/COM/NativeServer/Servers.cpp | 15 +++ .../COM/NativeServer/TrackMyLifetimeTesting.h | 32 ++++- .../COM/ServerContracts/Server.Contracts.cs | 2 + .../COM/ServerContracts/Server.Contracts.h | 2 + 6 files changed, 148 insertions(+), 18 deletions(-) diff --git a/src/tests/Interop/COM/NETClients/Lifetime/Program.cs b/src/tests/Interop/COM/NETClients/Lifetime/Program.cs index 9bdc9d8ea6ea2c..a0a6c254d4d996 100644 --- a/src/tests/Interop/COM/NETClients/Lifetime/Program.cs +++ b/src/tests/Interop/COM/NETClients/Lifetime/Program.cs @@ -17,8 +17,14 @@ namespace NetClient public unsafe class Program { static delegate* unmanaged GetAllocationCount; + static ITrackMyLifetimeTesting? s_agileInstance; + static Exception? s_callbackException; + + [DllImport("COMNativeServer", EntryPoint = "InvokeCallbackOnNativeThread")] + private static extern int InvokeCallbackOnNativeThread(delegate* unmanaged callback); // Initialize for all tests + [MethodImpl(MethodImplOptions.NoInlining)] static void Initialize() { var inst = new TrackMyLifetimeTesting(); @@ -45,6 +51,19 @@ static void ForceGC() } } + [UnmanagedCallersOnly] + static void InvokeObjectFromNativeThread() + { + try + { + s_agileInstance!.Method(); + } + catch (Exception e) + { + s_callbackException = e; + } + } + static void Validate_COMServer_CleanUp() { Console.WriteLine($"Calling {nameof(Validate_COMServer_CleanUp)}..."); @@ -85,36 +104,57 @@ static void Validate_COMServer_DisableEagerCleanUp() Assert.False(Marshal.AreComObjectsAvailableForCleanup()); } - [Fact] - public static int TestEntryPoint() + static void Validate_COMServer_CallOnNativeThread() { - // RegFree COM and STA apartments are not supported on Windows Nano - if (Utilities.IsWindowsNanoServer) + Console.WriteLine($"Calling {nameof(Validate_COMServer_CallOnNativeThread)}..."); + + // Need agile instance since the object will be used on a different thread + // than the creating thread and we're on an STA thread. + s_agileInstance = CreateAgileInstance(); + try + { + s_agileInstance.Method(); + + // Create a fresh native thread for each callback so the COM call runs before that thread + // has initialized the CLR's OLE TLS state. + for (int i = 0; i < 10; i++) + { + s_callbackException = null; + + Marshal.ThrowExceptionForHR(InvokeCallbackOnNativeThread(&InvokeObjectFromNativeThread)); + + Assert.True(s_callbackException is null, s_callbackException?.ToString()); + } + } + finally { - return 100; + s_agileInstance = null; } - int result = 101; + [MethodImpl(MethodImplOptions.NoInlining)] + static ITrackMyLifetimeTesting CreateAgileInstance() + => new TrackMyLifetimeTesting().CreateAgileInstance(); + } + + const int TestFailed = 101; + const int TestPassed = 100; + + static int RunOnSTAThread(Action action) + { + int result = TestFailed; - // Run the test on a new STA thread since Nano Server doesn't support the STA - // and as a result, the main application thread can't be made STA with the STAThread attribute Thread staThread = new Thread(() => { try { - // Initialization for all future tests - Initialize(); - Assert.True(GetAllocationCount != null); - - Validate_COMServer_CleanUp(); - Validate_COMServer_DisableEagerCleanUp(); + action(); } catch (Exception e) { Console.WriteLine($"Test Failure: {e}"); - result = 101; + result = TestFailed; } - result = 100; + result = TestPassed; }); staThread.SetApartmentState(ApartmentState.STA); @@ -123,5 +163,45 @@ public static int TestEntryPoint() return result; } + + [Fact] + public static int TestEntryPoint() + { + // RegFree COM and STA apartments are not supported on Windows Nano + if (Utilities.IsWindowsNanoServer) + { + return TestPassed; + } + + // Run the test on a new STA thread since Nano Server doesn't support the STA + // and as a result, the main application thread can't be made STA with the STAThread attribute + int result = RunOnSTAThread(() => + { + // Initialization for all future tests + Initialize(); + ForceGC(); + Assert.True(GetAllocationCount != null); + + Validate_COMServer_CleanUp(); + Validate_COMServer_CallOnNativeThread(); + }); + if (result != TestPassed) + { + return result; + } + + return RunOnSTAThread(() => + { + // Initialization for all future tests + Initialize(); + ForceGC(); + Assert.True(GetAllocationCount != null); + + // Manipulating the eager cleanup state cannot be changed once set, + // so we need to run this test on a separate thread after the first + // test validates that cleanup is working as expected with eager cleanup enabled. + Validate_COMServer_DisableEagerCleanUp(); + }); + } } } diff --git a/src/tests/Interop/COM/NativeServer/Exports.def b/src/tests/Interop/COM/NativeServer/Exports.def index 2d0de26b056e0b..cb6d5520db0bb1 100644 --- a/src/tests/Interop/COM/NativeServer/Exports.def +++ b/src/tests/Interop/COM/NativeServer/Exports.def @@ -1,4 +1,5 @@ EXPORTS DllGetClassObject PRIVATE DllRegisterServer PRIVATE - DllUnregisterServer PRIVATE \ No newline at end of file + DllUnregisterServer PRIVATE + InvokeCallbackOnNativeThread diff --git a/src/tests/Interop/COM/NativeServer/Servers.cpp b/src/tests/Interop/COM/NativeServer/Servers.cpp index ebe4f9df5acde2..ed47fd9591b91c 100644 --- a/src/tests/Interop/COM/NativeServer/Servers.cpp +++ b/src/tests/Interop/COM/NativeServer/Servers.cpp @@ -3,6 +3,7 @@ #include "stdafx.h" #include "Servers.h" +#include namespace { @@ -155,6 +156,20 @@ namespace } } +extern "C" HRESULT STDMETHODCALLTYPE InvokeCallbackOnNativeThread(void (STDMETHODCALLTYPE* callback)()) +{ + if (callback == nullptr) + return E_POINTER; + + std::thread worker([callback]() + { + callback(); + }); + + worker.join(); + return S_OK; +} + STDAPI DllRegisterServer(void) { HRESULT hr; diff --git a/src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h b/src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h index 3d0ac26da0a16a..6e3c3ee3c4034f 100644 --- a/src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h +++ b/src/tests/Interop/COM/NativeServer/TrackMyLifetimeTesting.h @@ -5,7 +5,7 @@ #include "Servers.h" -class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTesting +class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTesting, public IAgileObject { static std::atomic _instanceCount; @@ -14,8 +14,15 @@ class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTestin return _instanceCount; } +private: + const bool _isAgileInstance = false; + public: TrackMyLifetimeTesting() + : TrackMyLifetimeTesting(false) + { } + TrackMyLifetimeTesting(bool isAgileInstance) + : _isAgileInstance(isAgileInstance) { _instanceCount++; } @@ -34,11 +41,34 @@ class TrackMyLifetimeTesting : public UnknownImpl, public ITrackMyLifetimeTestin return S_OK; } + DEF_FUNC(CreateAgileInstance)(ITrackMyLifetimeTesting** agileInstance) + { + if (agileInstance == nullptr) + return E_POINTER; + + *agileInstance = new TrackMyLifetimeTesting(/*isAgileInstance*/ true); + return S_OK; + } + + DEF_FUNC(Method)() + { + return S_OK; + } + public: // IUnknown STDMETHOD(QueryInterface)( /* [in] */ REFIID riid, /* [iid_is][out] */ _COM_Outptr_ void __RPC_FAR *__RPC_FAR *ppvObject) { + if (_isAgileInstance) + { + if (riid == __uuidof(IAgileObject)) + { + *ppvObject = static_cast(this); + AddRef(); + return S_OK; + } + } return DoQueryInterface(riid, ppvObject, static_cast(this)); } diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs b/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs index 7da4b220fbc3cd..c54fc066cae2be 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.cs @@ -446,6 +446,8 @@ internal interface IInspectableTesting2 internal interface ITrackMyLifetimeTesting { IntPtr GetAllocationCountCallback(); + ITrackMyLifetimeTesting CreateAgileInstance(); + void Method(); } } diff --git a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h index 830966f1427980..ef87bb9c920aab 100644 --- a/src/tests/Interop/COM/ServerContracts/Server.Contracts.h +++ b/src/tests/Interop/COM/ServerContracts/Server.Contracts.h @@ -543,6 +543,8 @@ struct __declspec(uuid("57f396a1-58a0-425f-8807-9f938a534984")) ITrackMyLifetimeTesting : IUnknown { virtual HRESULT STDMETHODCALLTYPE GetAllocationCountCallback(_Outptr_ void** fptr) = 0; + virtual HRESULT STDMETHODCALLTYPE CreateAgileInstance(ITrackMyLifetimeTesting** agileInstance) = 0; + virtual HRESULT STDMETHODCALLTYPE Method() = 0; }; // IIDs for the below types are generated by the runtime.