Skip to content

Commit c54532b

Browse files
committed
2 parents 3c2549c + 9023e72 commit c54532b

File tree

3 files changed

+227
-0
lines changed

3 files changed

+227
-0
lines changed

src/thirtytwo/NativeMethods.txt

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ CoGetClassObject
2222
CombineRgn
2323
COMBOBOXINFO_BUTTON_STATE
2424
CopyImage
25+
CoTaskMemAlloc
2526
CoTaskMemFree
2627
CountClipboardFormats
2728
CreateActCtx
@@ -224,6 +225,7 @@ IMarshal
224225
IModalWindow
225226
InitCommonControlsEx
226227
InitVariantFromDoubleArray
228+
INoMarshal
227229
INTERFACEDATA
228230
INVALID_HANDLE_VALUE
229231
InvalidateRect
+80
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) Jeremy W. Kuhne. All rights reserved.
2+
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
3+
4+
using System.Runtime.InteropServices;
5+
6+
namespace Windows.Win32.System.Com;
7+
8+
/// <summary>
9+
/// Lifetime management helper for a COM callable wrapper. It holds the created <typeparamref name="TObject"/>
10+
/// wrapper with he given <typeparamref name="TVTable"/>.
11+
/// </summary>
12+
/// <remarks>
13+
/// <para>
14+
/// This should not be created directly. Instead use <see cref="Lifetime{TVTable, TObject}.Allocate"/>.
15+
/// </para>
16+
/// <para>
17+
/// A COM object's memory layout is a virtual function table (vtable) pointer followed by instance data. We're
18+
/// effectively manually creating a COM object here that contains instance data of a GCHandle to the related
19+
/// managed object and a ref count.
20+
/// </para>
21+
/// </remarks>
22+
public unsafe struct Lifetime<TVTable, TObject> where TVTable : unmanaged
23+
{
24+
public TVTable* VTable;
25+
public IUnknown* Handle;
26+
public uint RefCount;
27+
28+
public static unsafe uint AddRef(IUnknown* @this)
29+
=> Interlocked.Increment(ref ((Lifetime<TVTable, TObject>*)@this)->RefCount);
30+
31+
public static unsafe uint Release(IUnknown* @this)
32+
{
33+
var lifetime = (Lifetime<TVTable, TObject>*)@this;
34+
Debug.Assert(lifetime->RefCount > 0);
35+
uint count = Interlocked.Decrement(ref lifetime->RefCount);
36+
if (count == 0)
37+
{
38+
GCHandle.FromIntPtr((nint)lifetime->Handle).Free();
39+
Interop.CoTaskMemFree(lifetime);
40+
}
41+
42+
return count;
43+
}
44+
45+
/// <summary>
46+
/// Allocate a lifetime wrapper for the given <paramref name="object"/> with the given
47+
/// <paramref name="vtable"/>.
48+
/// </summary>
49+
/// <remarks>
50+
/// <para>
51+
/// This creates a <see cref="GCHandle"/> to root the <paramref name="object"/> until ref
52+
/// counting has gone to zero.
53+
/// </para>
54+
/// <para>
55+
/// The <paramref name="vtable"/> should be fixed, typically as a static. Com calls always
56+
/// include the "this" pointer as the first argument.
57+
/// </para>
58+
/// </remarks>
59+
public static unsafe Lifetime<TVTable, TObject>* Allocate(TObject @object, TVTable* vtable)
60+
{
61+
// Manually allocate a native instance of this struct.
62+
var wrapper = (Lifetime<TVTable, TObject>*)Interop.CoTaskMemAlloc((nuint)sizeof(Lifetime<TVTable, TObject>));
63+
64+
// Assign a pointer to the vtable, allocate a GCHandle for the related object, and set the initial ref count.
65+
wrapper->VTable = vtable;
66+
wrapper->Handle = (IUnknown*)GCHandle.ToIntPtr(GCHandle.Alloc(@object));
67+
wrapper->RefCount = 1;
68+
69+
return wrapper;
70+
}
71+
72+
/// <summary>
73+
/// Gets the object wrapped by a lifetime wrapper.
74+
/// </summary>
75+
public static TObject? GetObject(IUnknown* @this)
76+
{
77+
var lifetime = (Lifetime<TVTable, TObject>*)@this;
78+
return (TObject?)GCHandle.FromIntPtr((nint)lifetime->Handle).Target;
79+
}
80+
}

src/thirtytwo_tests/Win32/System/Com/ComTests.cs

+145
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
// Copyright (c) Jeremy W. Kuhne. All rights reserved.
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

4+
using System.Runtime.CompilerServices;
5+
using System.Runtime.InteropServices;
46
using Windows.Dialogs;
7+
using Windows.Win32.Foundation;
8+
using Windows.Win32.System.Com.Marshal;
59
using Windows.Win32.UI.Shell;
10+
using InteropMarshal = global::System.Runtime.InteropServices.Marshal;
611

712
namespace Windows.Win32.System.Com;
813

@@ -27,4 +32,144 @@ public void Com_GetComPointer_SameInterfaceInstance()
2732

2833
Assert.True(iEvents1.Pointer == iEvents2.Pointer);
2934
}
35+
36+
[Fact]
37+
public void Com_BuiltInCom_RCW_Behavior()
38+
{
39+
UnknownTest unknown = new();
40+
using ComScope<IUnknown> iUnknown = new(UnknownCCW.CreateInstance(unknown));
41+
42+
object rcw = InteropMarshal.GetObjectForIUnknown((IntPtr)iUnknown.Pointer);
43+
44+
unknown.AddRefCount.Should().Be(1);
45+
unknown.ReleaseCount.Should().Be(1);
46+
unknown.LastRefCount.Should().Be(2);
47+
unknown.QueryInterfaceGuids.Should().BeEquivalentTo([
48+
IUnknown.IID_Guid,
49+
INoMarshal.IID_Guid,
50+
IAgileObject.IID_Guid,
51+
IMarshal.IID_Guid]);
52+
53+
// Release and FinalRelease look the same from our IUnknown's perspective
54+
InteropMarshal.FinalReleaseComObject(rcw);
55+
56+
unknown.AddRefCount.Should().Be(1);
57+
unknown.ReleaseCount.Should().Be(2);
58+
unknown.LastRefCount.Should().Be(1);
59+
unknown.QueryInterfaceGuids.Should().BeEquivalentTo([
60+
IUnknown.IID_Guid,
61+
INoMarshal.IID_Guid,
62+
IAgileObject.IID_Guid,
63+
IMarshal.IID_Guid]);
64+
}
65+
66+
public interface IUnkownTest
67+
{
68+
public void QueryInterface(Guid riid);
69+
public void AddRef(uint current);
70+
public void Release(uint current);
71+
}
72+
73+
public class UnknownTest : IUnkownTest
74+
{
75+
public int AddRefCount { get; private set; }
76+
public int ReleaseCount { get; private set; }
77+
public List<Guid> QueryInterfaceGuids { get; } = [];
78+
public int LastRefCount { get; private set; }
79+
80+
void IUnkownTest.AddRef(uint current)
81+
{
82+
AddRefCount++;
83+
LastRefCount = (int)current;
84+
}
85+
86+
void IUnkownTest.QueryInterface(Guid riid)
87+
{
88+
QueryInterfaceGuids.Add(riid);
89+
}
90+
91+
void IUnkownTest.Release(uint current)
92+
{
93+
ReleaseCount++;
94+
LastRefCount = (int)current;
95+
}
96+
}
97+
98+
public static class UnknownCCW
99+
{
100+
public static unsafe IUnknown* CreateInstance(IUnkownTest @object)
101+
=> (IUnknown*)Lifetime<IUnknown.Vtbl, IUnkownTest>.Allocate(@object, CCWVTable);
102+
103+
private static readonly IUnknown.Vtbl* CCWVTable = AllocateVTable();
104+
105+
private static unsafe IUnknown.Vtbl* AllocateVTable()
106+
{
107+
// Allocate and create a static VTable for this type projection.
108+
var vtable = (IUnknown.Vtbl*)RuntimeHelpers.AllocateTypeAssociatedMemory(typeof(UnknownCCW), sizeof(IUnknown.Vtbl));
109+
110+
// IUnknown
111+
vtable->QueryInterface_1 = &QueryInterface;
112+
vtable->AddRef_2 = &AddRef;
113+
vtable->Release_3 = &Release;
114+
return vtable;
115+
}
116+
117+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
118+
private static HRESULT QueryInterface(IUnknown* @this, Guid* riid, void** ppvObject)
119+
{
120+
if (ppvObject is null)
121+
{
122+
return HRESULT.E_POINTER;
123+
}
124+
125+
var unknown = Lifetime<IUnknown.Vtbl, IUnkownTest>.GetObject(@this);
126+
if (unknown is null)
127+
{
128+
return HRESULT.COR_E_OBJECTDISPOSED;
129+
}
130+
131+
unknown.QueryInterface(*riid);
132+
133+
if (*riid == typeof(IUnknown).GUID)
134+
{
135+
*ppvObject = @this;
136+
}
137+
else
138+
{
139+
*ppvObject = null;
140+
return HRESULT.E_NOINTERFACE;
141+
}
142+
143+
Lifetime<IUnknown.Vtbl, IUnkownTest>.AddRef(@this);
144+
return HRESULT.S_OK;
145+
}
146+
147+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
148+
private static uint AddRef(IUnknown* @this)
149+
{
150+
var unknown = Lifetime<IUnknown.Vtbl, IUnkownTest>.GetObject(@this);
151+
if (unknown is null)
152+
{
153+
return HRESULT.COR_E_OBJECTDISPOSED;
154+
}
155+
156+
uint current = Lifetime<IUnknown.Vtbl, IUnkownTest>.AddRef(@this);
157+
unknown.AddRef(current);
158+
return current;
159+
}
160+
161+
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvStdcall)])]
162+
private static uint Release(IUnknown* @this)
163+
{
164+
var unknown = Lifetime<IUnknown.Vtbl, IUnkownTest>.GetObject(@this);
165+
if (unknown is null)
166+
{
167+
return HRESULT.COR_E_OBJECTDISPOSED;
168+
}
169+
170+
uint current = Lifetime<IUnknown.Vtbl, IUnkownTest>.Release(@this);
171+
unknown.Release(current);
172+
return current;
173+
}
174+
}
30175
}

0 commit comments

Comments
 (0)