diff --git a/src/System.Private.Windows.Core/src/System/Private/Windows/Ole/Composition.NativeToManagedAdapter.cs b/src/System.Private.Windows.Core/src/System/Private/Windows/Ole/Composition.NativeToManagedAdapter.cs index f8184baa64d..20a48cbaa0d 100644 --- a/src/System.Private.Windows.Core/src/System/Private/Windows/Ole/Composition.NativeToManagedAdapter.cs +++ b/src/System.Private.Windows.Core/src/System/Private/Windows/Ole/Composition.NativeToManagedAdapter.cs @@ -420,8 +420,9 @@ private static unsafe bool TryGetIStreamData( return false; } - using ComScope pStream = new((Com.IStream*)medium.hGlobal); - pStream.Value->Stat(out Com.STATSTG sstg, (uint)Com.STATFLAG.STATFLAG_DEFAULT); + // Don't wrap in ComScope - ReleaseStgMedium will release the stream. + Com.IStream* pStream = (Com.IStream*)medium.hGlobal; + pStream->Stat(out Com.STATSTG sstg, (uint)Com.STATFLAG.STATFLAG_DEFAULT); hglobal = PInvokeCore.GlobalAlloc(GLOBAL_ALLOC_FLAGS.GMEM_MOVEABLE | GLOBAL_ALLOC_FLAGS.GMEM_ZEROINIT, (uint)sstg.cbSize); @@ -433,7 +434,7 @@ private static unsafe bool TryGetIStreamData( } void* ptr = PInvokeCore.GlobalLock(hglobal); - pStream.Value->Read((byte*)ptr, (uint)sstg.cbSize, null); + pStream->Read((byte*)ptr, (uint)sstg.cbSize, null); PInvokeCore.GlobalUnlock(hglobal); return TryGetDataFromHGLOBAL(hglobal, in request, out data); diff --git a/src/System.Private.Windows.Core/tests/System.Private.Windows.Core.Tests/System/Private/Windows/Ole/IStreamNativeDataObject.cs b/src/System.Private.Windows.Core/tests/System.Private.Windows.Core.Tests/System/Private/Windows/Ole/IStreamNativeDataObject.cs new file mode 100644 index 00000000000..7b4141a99de --- /dev/null +++ b/src/System.Private.Windows.Core/tests/System.Private.Windows.Core.Tests/System/Private/Windows/Ole/IStreamNativeDataObject.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Windows.Win32; +using Windows.Win32.Foundation; +using Windows.Win32.System.Com; + +namespace System.Private.Windows.Ole; + +/// +/// A native data object mock that returns data via . +/// +internal unsafe class IStreamNativeDataObject : NativeDataObjectMock +{ + private readonly Stream _stream; + private readonly ushort _format; + + public IStreamNativeDataObject(Stream stream, ushort format) + { + _stream = stream; + _format = format; + } + + public override HRESULT QueryGetData(FORMATETC* pformatetc) + { + if (pformatetc is null) + { + return HRESULT.DV_E_FORMATETC; + } + + if (pformatetc->cfFormat != _format) + { + return HRESULT.DV_E_FORMATETC; + } + + if (pformatetc->dwAspect != (uint)DVASPECT.DVASPECT_CONTENT) + { + return HRESULT.DV_E_DVASPECT; + } + + if (pformatetc->lindex != -1) + { + return HRESULT.DV_E_LINDEX; + } + + if (pformatetc->tymed != (uint)TYMED.TYMED_ISTREAM) + { + return HRESULT.DV_E_TYMED; + } + + return HRESULT.S_OK; + } + + public override HRESULT GetData(FORMATETC* pformatetcIn, STGMEDIUM* pmedium) + { + HRESULT result = QueryGetData(pformatetcIn); + if (result.Failed) + { + return result; + } + + if (pmedium is null) + { + return HRESULT.E_POINTER; + } + + // Reset stream position for each GetData call + _stream.Position = 0; + + // Create a ComManagedStream wrapper + ComManagedStream comStream = new(_stream); + + // Return the IStream pointer in the STGMEDIUM + // Note: hGlobal is a union with pstm in STGMEDIUM + pmedium->hGlobal = (HGLOBAL)(nint)ComHelpers.GetComPointer(comStream); + pmedium->tymed = TYMED.TYMED_ISTREAM; + pmedium->pUnkForRelease = null; + + return HRESULT.S_OK; + } + + protected override void Dispose(bool disposing) => _stream.Dispose(); +} diff --git a/src/System.Private.Windows.Core/tests/System.Private.Windows.Core.Tests/System/Private/Windows/Ole/NativeToManagedAdapterTests.cs b/src/System.Private.Windows.Core/tests/System.Private.Windows.Core.Tests/System/Private/Windows/Ole/NativeToManagedAdapterTests.cs index 6da568dee95..8f54f54ba82 100644 --- a/src/System.Private.Windows.Core/tests/System.Private.Windows.Core.Tests/System/Private/Windows/Ole/NativeToManagedAdapterTests.cs +++ b/src/System.Private.Windows.Core/tests/System.Private.Windows.Core.Tests/System/Private/Windows/Ole/NativeToManagedAdapterTests.cs @@ -181,4 +181,68 @@ public void ReadStringFromHGLOBAL_Terminator_ReturnsString(bool unicode) PInvokeCore.GlobalFree(global); } } + + [Fact] + public void TryGetIStreamData_DoesNotDoubleReleaseStream() + { + // This test verifies the fix for double-releasing the IStream. + // https://github.com/dotnet/wpf/issues/11401 + // + // Previously, the code wrapped the IStream in a ComScope which would Release it, + // and then ReleaseStgMedium would also try to Release it, causing a double-release. + + MemoryStream stream = new([0xBE, 0xAD, 0xCA, 0xFE]); + using IStreamNativeDataObject dataObject = new(stream, (ushort)_format.Id); + + IDataObject* pDataObject = ComHelpers.GetComPointer(dataObject); + + // GetComPointer returns a pointer with ref count of 1. + // Add an extra reference so we can track the ref count throughout the test. + uint initialRefCount = pDataObject->AddRef(); + initialRefCount.Should().Be(2); + + object? data; + uint refCountBeforeGetData; + uint refCountAfterGetData; + + // Scope the Composition so we can observe ref count changes. + { + // Composition.Create calls AddRef twice (once for NativeToManagedAdapter, once for NativeToRuntimeAdapter) + // and takes ownership of the original ref from GetComPointer. + var composition = Composition.Create(pDataObject); + + // After Create: original(1) + our AddRef(1) + Composition's two AddRefs(2) = 4 + refCountBeforeGetData = pDataObject->AddRef(); + pDataObject->Release(); // Undo our test AddRef + + // Try to get data - this should not crash due to CFG violation + // The IStream data object will return data via TYMED_ISTREAM + data = composition.GetData(nameof(NativeToManagedAdapterTests)); + + // Verify data was retrieved successfully + data.Should().BeOfType(); + + // Check ref count after GetData - should be unchanged from before + // (the IStream from GetData should be properly released by ReleaseStgMedium only once) + refCountAfterGetData = pDataObject->AddRef(); + pDataObject->Release(); // Undo our test AddRef + + refCountAfterGetData.Should().Be( + refCountBeforeGetData, + "GetData should not leak or double-release the IDataObject"); + } + + // Note: Composition doesn't implement IDisposable, so refs are released via GC. + // We still hold our extra ref, so the object won't be collected. + + // Release our extra ref from the start of the test. + uint finalRefCount = pDataObject->Release(); + + // We should still have refs from Composition's adapters (they're not disposed yet). + // The important thing is that GetData didn't corrupt the ref count. + finalRefCount.Should().BeGreaterThan(0); + + MemoryStream result = (MemoryStream)data!; + result.ToArray().Should().Equal(0xBE, 0xAD, 0xCA, 0xFE); + } }