Skip to content

Commit 63df321

Browse files
committed
Apply dotnet#112461 on 9.0
1 parent c97d3a4 commit 63df321

2 files changed

Lines changed: 88 additions & 20 deletions

File tree

src/coreclr/nativeaot/Runtime/HandleTableHelpers.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,32 @@ struct ManagedObjectWrapper
9696
}
9797
};
9898

99+
template<typename T>
100+
struct Span
101+
{
102+
T* _pointer;
103+
int _length;
104+
};
105+
99106
// This structure mirrors the managed type System.Runtime.InteropServices.ComWrappers.InternalComInterfaceDispatch.
100107
struct InternalComInterfaceDispatch
101108
{
102-
void* Vtable;
103109
ManagedObjectWrapper* _thisPtr;
110+
Span<void*> Vtables;
104111
};
105112

113+
#ifdef TARGET_64BIT
114+
constexpr uintptr_t DispatchAlignment = 64;
115+
#else
116+
constexpr uintptr_t DispatchAlignment = 16;
117+
#endif
118+
119+
constexpr uintptr_t DispatchAlignmentMask = ~(DispatchAlignment - 1);
120+
106121
static ManagedObjectWrapper* ToManagedObjectWrapper(void* dispatchPtr)
107122
{
108-
return ((InternalComInterfaceDispatch*)dispatchPtr)->_thisPtr;
123+
uintptr_t dispatch = reinterpret_cast<uintptr_t>(dispatchPtr) & DispatchAlignmentMask;
124+
return ((InternalComInterfaceDispatch*)dispatch)->_thisPtr;
109125
}
110126

111127
//

src/coreclr/nativeaot/System.Private.CoreLib/src/System/Runtime/InteropServices/ComWrappers.NativeAot.cs

Lines changed: 70 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,31 @@ public static unsafe T GetInstance<T>(ComInterfaceDispatch* dispatchPtr) where T
110110

111111
internal static unsafe ManagedObjectWrapper* ToManagedObjectWrapper(ComInterfaceDispatch* dispatchPtr)
112112
{
113-
return ((InternalComInterfaceDispatch*)dispatchPtr)->_thisPtr;
113+
InternalComInterfaceDispatch* dispatch = (InternalComInterfaceDispatch*)unchecked((nuint)dispatchPtr & (nuint)InternalComInterfaceDispatch.DispatchAlignmentMask);
114+
return dispatch->_thisPtr;
114115
}
115116
}
116117

117118
internal unsafe struct InternalComInterfaceDispatch
118119
{
119-
public IntPtr Vtable;
120+
#if TARGET_64BIT
121+
internal const int DispatchAlignment = 64;
122+
internal const int NumEntriesInDispatchTable = DispatchAlignment / 8 /* sizeof(void*) */ - 1;
123+
#else
124+
internal const int DispatchAlignment = 16;
125+
internal const int NumEntriesInDispatchTable = DispatchAlignment / 4 /* sizeof(void*) */ - 1;
126+
#endif
127+
internal const ulong DispatchAlignmentMask = unchecked((ulong)~(InternalComInterfaceDispatch.DispatchAlignment - 1));
128+
120129
internal ManagedObjectWrapper* _thisPtr;
130+
131+
public DispatchTable Vtables;
132+
133+
[InlineArray(NumEntriesInDispatchTable)]
134+
internal unsafe struct DispatchTable
135+
{
136+
private IntPtr _element;
137+
}
121138
}
122139

123140
internal enum CreateComInterfaceFlagsEx
@@ -337,16 +354,23 @@ public unsafe bool Destroy()
337354
}
338355
}
339356

357+
private unsafe IntPtr GetDispatchPointerAtIndex(int index)
358+
{
359+
InternalComInterfaceDispatch* dispatch = &Dispatches[index / InternalComInterfaceDispatch.NumEntriesInDispatchTable];
360+
IntPtr* vtables = (IntPtr*)(void*)&dispatch->Vtables;
361+
return (IntPtr)(&vtables[index % InternalComInterfaceDispatch.NumEntriesInDispatchTable]);
362+
}
363+
340364
private unsafe IntPtr AsRuntimeDefined(in Guid riid)
341365
{
342366
// The order of interface lookup here is important.
343-
// See CreateCCW() for the expected order.
367+
// See CreateManagedObjectWrapper() for the expected order.
344368
int i = UserDefinedCount;
345369
if ((Flags & CreateComInterfaceFlagsEx.CallerDefinedIUnknown) == 0)
346370
{
347371
if (riid == IID_IUnknown)
348372
{
349-
return (IntPtr)(Dispatches + i);
373+
return GetDispatchPointerAtIndex(i);
350374
}
351375

352376
i++;
@@ -356,7 +380,7 @@ private unsafe IntPtr AsRuntimeDefined(in Guid riid)
356380
{
357381
if (riid == IID_IReferenceTrackerTarget)
358382
{
359-
return (IntPtr)(Dispatches + i);
383+
return GetDispatchPointerAtIndex(i);
360384
}
361385

362386
i++;
@@ -365,7 +389,7 @@ private unsafe IntPtr AsRuntimeDefined(in Guid riid)
365389
{
366390
if (riid == IID_TaggedImpl)
367391
{
368-
return (IntPtr)(Dispatches + i);
392+
return GetDispatchPointerAtIndex(i);
369393
}
370394
}
371395

@@ -378,7 +402,7 @@ private unsafe IntPtr AsUserDefined(in Guid riid)
378402
{
379403
if (UserDefined[i].IID == riid)
380404
{
381-
return (IntPtr)(Dispatches + i);
405+
return GetDispatchPointerAtIndex(i);
382406
}
383407
}
384408

@@ -475,7 +499,7 @@ public ManagedObjectWrapperReleaser(ManagedObjectWrapper* wrapper)
475499
// Release GC handle created when MOW was built.
476500
if (_wrapper->Destroy())
477501
{
478-
NativeMemory.Free(_wrapper);
502+
NativeMemory.AlignedFree(_wrapper);
479503
_wrapper = null;
480504
}
481505
else
@@ -732,6 +756,12 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
732756
return managedObjectWrapper.ComIp;
733757
}
734758

759+
private static nuint AlignUp(nuint value, nuint alignment)
760+
{
761+
nuint alignMask = alignment - 1;
762+
return (nuint)((value + alignMask) & ~alignMask);
763+
}
764+
735765
private unsafe ManagedObjectWrapper* CreateManagedObjectWrapper(object instance, CreateComInterfaceFlags flags)
736766
{
737767
ComInterfaceEntry* userDefined = ComputeVtables(instance, flags, out int userDefinedCount);
@@ -762,21 +792,43 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
762792
// Compute size for ManagedObjectWrapper instance.
763793
int totalDefinedCount = runtimeDefinedCount + userDefinedCount;
764794

765-
// Allocate memory for the ManagedObjectWrapper.
766-
IntPtr wrapperMem = (IntPtr)NativeMemory.Alloc(
767-
(nuint)sizeof(ManagedObjectWrapper) + (nuint)totalDefinedCount * (nuint)sizeof(InternalComInterfaceDispatch));
795+
int numSections = totalDefinedCount / InternalComInterfaceDispatch.NumEntriesInDispatchTable;
796+
if (totalDefinedCount % InternalComInterfaceDispatch.NumEntriesInDispatchTable != 0)
797+
{
798+
// Account for a trailing partial section to fit all of the defined interfaces.
799+
numSections++;
800+
}
801+
802+
nuint headerSize = AlignUp((nuint)sizeof(ManagedObjectWrapper), InternalComInterfaceDispatch.DispatchAlignment);
768803

769-
// Compute the dispatch section offset and ensure it is aligned.
770-
ManagedObjectWrapper* mow = (ManagedObjectWrapper*)wrapperMem;
804+
// Instead of allocating a full section even when we have a trailing one, we'll allocate only
805+
// as much space as we need to store all of our dispatch tables.
806+
nuint dispatchSectionSize = (nuint)totalDefinedCount * (nuint)sizeof(void*) + (nuint)numSections * (nuint)sizeof(void*);
807+
808+
// Allocate memory for the ManagedObjectWrapper with the correct alignment for our dispatch tables.
809+
IntPtr wrapperMem = (IntPtr)NativeMemory.AlignedAlloc(
810+
headerSize + dispatchSectionSize,
811+
InternalComInterfaceDispatch.DispatchAlignment);
771812

772-
// Dispatches follow immediately after ManagedObjectWrapper
773-
InternalComInterfaceDispatch* pDispatches = (InternalComInterfaceDispatch*)(wrapperMem + sizeof(ManagedObjectWrapper));
774-
for (int i = 0; i < totalDefinedCount; i++)
813+
// Dispatches follow the ManagedObjectWrapper.
814+
InternalComInterfaceDispatch* pDispatches = (InternalComInterfaceDispatch*)((nuint)wrapperMem + headerSize);
815+
Span<InternalComInterfaceDispatch> dispatches = new Span<InternalComInterfaceDispatch>(pDispatches, numSections);
816+
for (int i = 0; i < dispatches.Length; i++)
775817
{
776-
pDispatches[i].Vtable = (i < userDefinedCount) ? userDefined[i].Vtable : runtimeDefinedVtable[i - userDefinedCount];
777-
pDispatches[i]._thisPtr = mow;
818+
dispatches[i]._thisPtr = (ManagedObjectWrapper*)wrapperMem;
819+
Span<IntPtr> dispatchVtables = dispatches[i].Vtables;
820+
for (int j = 0; j < dispatchVtables.Length; j++)
821+
{
822+
int index = i * dispatchVtables.Length + j;
823+
if (index >= totalDefinedCount)
824+
{
825+
break;
826+
}
827+
dispatchVtables[j] = (index < userDefinedCount) ? userDefined[index].Vtable : runtimeDefinedVtable[index - userDefinedCount];
828+
}
778829
}
779830

831+
ManagedObjectWrapper* mow = (ManagedObjectWrapper*)wrapperMem;
780832
mow->HolderHandle = IntPtr.Zero;
781833
mow->RefCount = 0;
782834
mow->UserDefinedCount = userDefinedCount;

0 commit comments

Comments
 (0)