Skip to content

Commit 5266e7d

Browse files
authored
Implement the "compact dispatch sections" optimization from CoreCLR's ComWrappers implementation in NativeAOT's implementation (#112461)
* Implement the "compact dispatch sections" optimization from CoreCLR's ComWrappers implementation in NativeAOT's implementation * Various fixes to get tests passing
1 parent a087868 commit 5266e7d

File tree

2 files changed

+88
-20
lines changed

2 files changed

+88
-20
lines changed

src/coreclr/nativeaot/Runtime/HandleTableHelpers.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,32 @@ struct ManagedObjectWrapper
9696
}
9797
};
9898

99+
template<typename T>
100+
struct Span
101+
{
102+
T* _pointer;
103+
int _length;
104+
};
105+
99106
// This structure mirrors the managed type System.Runtime.InteropServices.ComWrappers.InternalComInterfaceDispatch.
100107
struct InternalComInterfaceDispatch
101108
{
102-
void* Vtable;
103109
ManagedObjectWrapper* _thisPtr;
110+
Span<void*> Vtables;
104111
};
105112

113+
#ifdef TARGET_64BIT
114+
constexpr uintptr_t DispatchAlignment = 64;
115+
#else
116+
constexpr uintptr_t DispatchAlignment = 16;
117+
#endif
118+
119+
constexpr uintptr_t DispatchAlignmentMask = ~(DispatchAlignment - 1);
120+
106121
static ManagedObjectWrapper* ToManagedObjectWrapper(void* dispatchPtr)
107122
{
108-
return ((InternalComInterfaceDispatch*)dispatchPtr)->_thisPtr;
123+
uintptr_t dispatch = reinterpret_cast<uintptr_t>(dispatchPtr) & DispatchAlignmentMask;
124+
return ((InternalComInterfaceDispatch*)dispatch)->_thisPtr;
109125
}
110126

111127
//

src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,31 @@ public static unsafe T GetInstance<T>(ComInterfaceDispatch* dispatchPtr) where T
110110

111111
internal static unsafe ManagedObjectWrapper* ToManagedObjectWrapper(ComInterfaceDispatch* dispatchPtr)
112112
{
113-
return ((InternalComInterfaceDispatch*)dispatchPtr)->_thisPtr;
113+
InternalComInterfaceDispatch* dispatch = (InternalComInterfaceDispatch*)unchecked((nuint)dispatchPtr & (nuint)InternalComInterfaceDispatch.DispatchAlignmentMask);
114+
return dispatch->_thisPtr;
114115
}
115116
}
116117

117118
internal unsafe struct InternalComInterfaceDispatch
118119
{
119-
public IntPtr Vtable;
120+
#if TARGET_64BIT
121+
internal const int DispatchAlignment = 64;
122+
internal const int NumEntriesInDispatchTable = DispatchAlignment / 8 /* sizeof(void*) */ - 1;
123+
#else
124+
internal const int DispatchAlignment = 16;
125+
internal const int NumEntriesInDispatchTable = DispatchAlignment / 4 /* sizeof(void*) */ - 1;
126+
#endif
127+
internal const ulong DispatchAlignmentMask = unchecked((ulong)~(InternalComInterfaceDispatch.DispatchAlignment - 1));
128+
120129
internal ManagedObjectWrapper* _thisPtr;
130+
131+
public DispatchTable Vtables;
132+
133+
[InlineArray(NumEntriesInDispatchTable)]
134+
internal unsafe struct DispatchTable
135+
{
136+
private IntPtr _element;
137+
}
121138
}
122139

123140
internal enum CreateComInterfaceFlagsEx
@@ -337,16 +354,23 @@ public unsafe bool Destroy()
337354
}
338355
}
339356

357+
private unsafe IntPtr GetDispatchPointerAtIndex(int index)
358+
{
359+
InternalComInterfaceDispatch* dispatch = &Dispatches[index / InternalComInterfaceDispatch.NumEntriesInDispatchTable];
360+
IntPtr* vtables = (IntPtr*)(void*)&dispatch->Vtables;
361+
return (IntPtr)(&vtables[index % InternalComInterfaceDispatch.NumEntriesInDispatchTable]);
362+
}
363+
340364
private unsafe IntPtr AsRuntimeDefined(in Guid riid)
341365
{
342366
// The order of interface lookup here is important.
343-
// See CreateCCW() for the expected order.
367+
// See CreateManagedObjectWrapper() for the expected order.
344368
int i = UserDefinedCount;
345369
if ((Flags & CreateComInterfaceFlagsEx.CallerDefinedIUnknown) == 0)
346370
{
347371
if (riid == IID_IUnknown)
348372
{
349-
return (IntPtr)(Dispatches + i);
373+
return GetDispatchPointerAtIndex(i);
350374
}
351375

352376
i++;
@@ -356,7 +380,7 @@ private unsafe IntPtr AsRuntimeDefined(in Guid riid)
356380
{
357381
if (riid == IID_IReferenceTrackerTarget)
358382
{
359-
return (IntPtr)(Dispatches + i);
383+
return GetDispatchPointerAtIndex(i);
360384
}
361385

362386
i++;
@@ -365,7 +389,7 @@ private unsafe IntPtr AsRuntimeDefined(in Guid riid)
365389
{
366390
if (riid == IID_TaggedImpl)
367391
{
368-
return (IntPtr)(Dispatches + i);
392+
return GetDispatchPointerAtIndex(i);
369393
}
370394
}
371395

@@ -378,7 +402,7 @@ private unsafe IntPtr AsUserDefined(in Guid riid)
378402
{
379403
if (UserDefined[i].IID == riid)
380404
{
381-
return (IntPtr)(Dispatches + i);
405+
return GetDispatchPointerAtIndex(i);
382406
}
383407
}
384408

@@ -475,7 +499,7 @@ public ManagedObjectWrapperReleaser(ManagedObjectWrapper* wrapper)
475499
// Release GC handle created when MOW was built.
476500
if (_wrapper->Destroy())
477501
{
478-
NativeMemory.Free(_wrapper);
502+
NativeMemory.AlignedFree(_wrapper);
479503
_wrapper = null;
480504
}
481505
else
@@ -747,6 +771,12 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
747771
return managedObjectWrapper.ComIp;
748772
}
749773

774+
private static nuint AlignUp(nuint value, nuint alignment)
775+
{
776+
nuint alignMask = alignment - 1;
777+
return (nuint)((value + alignMask) & ~alignMask);
778+
}
779+
750780
private unsafe ManagedObjectWrapper* CreateManagedObjectWrapper(object instance, CreateComInterfaceFlags flags)
751781
{
752782
ComInterfaceEntry* userDefined = ComputeVtables(instance, flags, out int userDefinedCount);
@@ -777,21 +807,43 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
777807
// Compute size for ManagedObjectWrapper instance.
778808
int totalDefinedCount = runtimeDefinedCount + userDefinedCount;
779809

780-
// Allocate memory for the ManagedObjectWrapper.
781-
IntPtr wrapperMem = (IntPtr)NativeMemory.Alloc(
782-
(nuint)sizeof(ManagedObjectWrapper) + (nuint)totalDefinedCount * (nuint)sizeof(InternalComInterfaceDispatch));
810+
int numSections = totalDefinedCount / InternalComInterfaceDispatch.NumEntriesInDispatchTable;
811+
if (totalDefinedCount % InternalComInterfaceDispatch.NumEntriesInDispatchTable != 0)
812+
{
813+
// Account for a trailing partial section to fit all of the defined interfaces.
814+
numSections++;
815+
}
816+
817+
nuint headerSize = AlignUp((nuint)sizeof(ManagedObjectWrapper), InternalComInterfaceDispatch.DispatchAlignment);
783818

784-
// Compute the dispatch section offset and ensure it is aligned.
785-
ManagedObjectWrapper* mow = (ManagedObjectWrapper*)wrapperMem;
819+
// Instead of allocating a full section even when we have a trailing one, we'll allocate only
820+
// as much space as we need to store all of our dispatch tables.
821+
nuint dispatchSectionSize = (nuint)totalDefinedCount * (nuint)sizeof(void*) + (nuint)numSections * (nuint)sizeof(void*);
822+
823+
// Allocate memory for the ManagedObjectWrapper with the correct alignment for our dispatch tables.
824+
IntPtr wrapperMem = (IntPtr)NativeMemory.AlignedAlloc(
825+
headerSize + dispatchSectionSize,
826+
InternalComInterfaceDispatch.DispatchAlignment);
786827

787-
// Dispatches follow immediately after ManagedObjectWrapper
788-
InternalComInterfaceDispatch* pDispatches = (InternalComInterfaceDispatch*)(wrapperMem + sizeof(ManagedObjectWrapper));
789-
for (int i = 0; i < totalDefinedCount; i++)
828+
// Dispatches follow the ManagedObjectWrapper.
829+
InternalComInterfaceDispatch* pDispatches = (InternalComInterfaceDispatch*)((nuint)wrapperMem + headerSize);
830+
Span<InternalComInterfaceDispatch> dispatches = new Span<InternalComInterfaceDispatch>(pDispatches, numSections);
831+
for (int i = 0; i < dispatches.Length; i++)
790832
{
791-
pDispatches[i].Vtable = (i < userDefinedCount) ? userDefined[i].Vtable : runtimeDefinedVtable[i - userDefinedCount];
792-
pDispatches[i]._thisPtr = mow;
833+
dispatches[i]._thisPtr = (ManagedObjectWrapper*)wrapperMem;
834+
Span<IntPtr> dispatchVtables = dispatches[i].Vtables;
835+
for (int j = 0; j < dispatchVtables.Length; j++)
836+
{
837+
int index = i * dispatchVtables.Length + j;
838+
if (index >= totalDefinedCount)
839+
{
840+
break;
841+
}
842+
dispatchVtables[j] = (index < userDefinedCount) ? userDefined[index].Vtable : runtimeDefinedVtable[index - userDefinedCount];
843+
}
793844
}
794845

846+
ManagedObjectWrapper* mow = (ManagedObjectWrapper*)wrapperMem;
795847
mow->HolderHandle = IntPtr.Zero;
796848
mow->RefCount = 0;
797849
mow->UserDefinedCount = userDefinedCount;

0 commit comments

Comments
 (0)