@@ -110,14 +110,31 @@ public static unsafe T GetInstance<T>(ComInterfaceDispatch* dispatchPtr) where T
110110
111111 internal static unsafe ManagedObjectWrapper * ToManagedObjectWrapper ( ComInterfaceDispatch * dispatchPtr )
112112 {
113- return ( ( InternalComInterfaceDispatch * ) dispatchPtr ) ->_thisPtr ;
113+ InternalComInterfaceDispatch * dispatch = ( InternalComInterfaceDispatch * ) unchecked ( ( nuint ) dispatchPtr & ( nuint ) InternalComInterfaceDispatch . DispatchAlignmentMask ) ;
114+ return dispatch ->_thisPtr ;
114115 }
115116 }
116117
117118 internal unsafe struct InternalComInterfaceDispatch
118119 {
119- public IntPtr Vtable ;
120+ #if TARGET_64BIT
121+ internal const int DispatchAlignment = 64 ;
122+ internal const int NumEntriesInDispatchTable = DispatchAlignment / 8 /* sizeof(void*) */ - 1 ;
123+ #else
124+ internal const int DispatchAlignment = 16 ;
125+ internal const int NumEntriesInDispatchTable = DispatchAlignment / 4 /* sizeof(void*) */ - 1 ;
126+ #endif
127+ internal const ulong DispatchAlignmentMask = unchecked ( ( ulong ) ~ ( InternalComInterfaceDispatch . DispatchAlignment - 1 ) ) ;
128+
120129 internal ManagedObjectWrapper * _thisPtr ;
130+
131+ public DispatchTable Vtables ;
132+
133+ [ InlineArray ( NumEntriesInDispatchTable ) ]
134+ internal unsafe struct DispatchTable
135+ {
136+ private IntPtr _element ;
137+ }
121138 }
122139
123140 internal enum CreateComInterfaceFlagsEx
@@ -337,16 +354,23 @@ public unsafe bool Destroy()
337354 }
338355 }
339356
357+ private unsafe IntPtr GetDispatchPointerAtIndex ( int index )
358+ {
359+ InternalComInterfaceDispatch * dispatch = & Dispatches [ index / InternalComInterfaceDispatch . NumEntriesInDispatchTable ] ;
360+ IntPtr * vtables = ( IntPtr * ) ( void * ) & dispatch ->Vtables ;
361+ return ( IntPtr ) ( & vtables [ index % InternalComInterfaceDispatch . NumEntriesInDispatchTable ] ) ;
362+ }
363+
340364 private unsafe IntPtr AsRuntimeDefined ( in Guid riid )
341365 {
342366 // The order of interface lookup here is important.
343- // See CreateCCW () for the expected order.
367+ // See CreateManagedObjectWrapper () for the expected order.
344368 int i = UserDefinedCount ;
345369 if ( ( Flags & CreateComInterfaceFlagsEx . CallerDefinedIUnknown ) == 0 )
346370 {
347371 if ( riid == IID_IUnknown )
348372 {
349- return ( IntPtr ) ( Dispatches + i ) ;
373+ return GetDispatchPointerAtIndex ( i ) ;
350374 }
351375
352376 i ++ ;
@@ -356,7 +380,7 @@ private unsafe IntPtr AsRuntimeDefined(in Guid riid)
356380 {
357381 if ( riid == IID_IReferenceTrackerTarget )
358382 {
359- return ( IntPtr ) ( Dispatches + i ) ;
383+ return GetDispatchPointerAtIndex ( i ) ;
360384 }
361385
362386 i ++ ;
@@ -365,7 +389,7 @@ private unsafe IntPtr AsRuntimeDefined(in Guid riid)
365389 {
366390 if ( riid == IID_TaggedImpl )
367391 {
368- return ( IntPtr ) ( Dispatches + i ) ;
392+ return GetDispatchPointerAtIndex ( i ) ;
369393 }
370394 }
371395
@@ -378,7 +402,7 @@ private unsafe IntPtr AsUserDefined(in Guid riid)
378402 {
379403 if ( UserDefined [ i ] . IID == riid )
380404 {
381- return ( IntPtr ) ( Dispatches + i ) ;
405+ return GetDispatchPointerAtIndex ( i ) ;
382406 }
383407 }
384408
@@ -475,7 +499,7 @@ public ManagedObjectWrapperReleaser(ManagedObjectWrapper* wrapper)
475499 // Release GC handle created when MOW was built.
476500 if ( _wrapper ->Destroy ( ) )
477501 {
478- NativeMemory . Free ( _wrapper ) ;
502+ NativeMemory . AlignedFree ( _wrapper ) ;
479503 _wrapper = null ;
480504 }
481505 else
@@ -747,6 +771,12 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
747771 return managedObjectWrapper . ComIp ;
748772 }
749773
774+ private static nuint AlignUp ( nuint value , nuint alignment )
775+ {
776+ nuint alignMask = alignment - 1 ;
777+ return ( nuint ) ( ( value + alignMask ) & ~ alignMask ) ;
778+ }
779+
750780 private unsafe ManagedObjectWrapper * CreateManagedObjectWrapper ( object instance , CreateComInterfaceFlags flags )
751781 {
752782 ComInterfaceEntry * userDefined = ComputeVtables ( instance , flags , out int userDefinedCount ) ;
@@ -777,21 +807,43 @@ public unsafe IntPtr GetOrCreateComInterfaceForObject(object instance, CreateCom
777807 // Compute size for ManagedObjectWrapper instance.
778808 int totalDefinedCount = runtimeDefinedCount + userDefinedCount ;
779809
780- // Allocate memory for the ManagedObjectWrapper.
781- IntPtr wrapperMem = ( IntPtr ) NativeMemory . Alloc (
782- ( nuint ) sizeof ( ManagedObjectWrapper ) + ( nuint ) totalDefinedCount * ( nuint ) sizeof ( InternalComInterfaceDispatch ) ) ;
810+ int numSections = totalDefinedCount / InternalComInterfaceDispatch . NumEntriesInDispatchTable ;
811+ if ( totalDefinedCount % InternalComInterfaceDispatch . NumEntriesInDispatchTable != 0 )
812+ {
813+ // Account for a trailing partial section to fit all of the defined interfaces.
814+ numSections ++ ;
815+ }
816+
817+ nuint headerSize = AlignUp ( ( nuint ) sizeof ( ManagedObjectWrapper ) , InternalComInterfaceDispatch . DispatchAlignment ) ;
783818
784- // Compute the dispatch section offset and ensure it is aligned.
785- ManagedObjectWrapper * mow = ( ManagedObjectWrapper * ) wrapperMem ;
819+ // Instead of allocating a full section even when we have a trailing one, we'll allocate only
820+ // as much space as we need to store all of our dispatch tables.
821+ nuint dispatchSectionSize = ( nuint ) totalDefinedCount * ( nuint ) sizeof ( void * ) + ( nuint ) numSections * ( nuint ) sizeof ( void * ) ;
822+
823+ // Allocate memory for the ManagedObjectWrapper with the correct alignment for our dispatch tables.
824+ IntPtr wrapperMem = ( IntPtr ) NativeMemory . AlignedAlloc (
825+ headerSize + dispatchSectionSize ,
826+ InternalComInterfaceDispatch . DispatchAlignment ) ;
786827
787- // Dispatches follow immediately after ManagedObjectWrapper
788- InternalComInterfaceDispatch * pDispatches = ( InternalComInterfaceDispatch * ) ( wrapperMem + sizeof ( ManagedObjectWrapper ) ) ;
789- for ( int i = 0 ; i < totalDefinedCount ; i ++ )
828+ // Dispatches follow the ManagedObjectWrapper.
829+ InternalComInterfaceDispatch * pDispatches = ( InternalComInterfaceDispatch * ) ( ( nuint ) wrapperMem + headerSize ) ;
830+ Span < InternalComInterfaceDispatch > dispatches = new Span < InternalComInterfaceDispatch > ( pDispatches , numSections ) ;
831+ for ( int i = 0 ; i < dispatches . Length ; i ++ )
790832 {
791- pDispatches [ i ] . Vtable = ( i < userDefinedCount ) ? userDefined [ i ] . Vtable : runtimeDefinedVtable [ i - userDefinedCount ] ;
792- pDispatches [ i ] . _thisPtr = mow ;
833+ dispatches [ i ] . _thisPtr = ( ManagedObjectWrapper * ) wrapperMem ;
834+ Span < IntPtr > dispatchVtables = dispatches [ i ] . Vtables ;
835+ for ( int j = 0 ; j < dispatchVtables . Length ; j ++ )
836+ {
837+ int index = i * dispatchVtables . Length + j ;
838+ if ( index >= totalDefinedCount )
839+ {
840+ break ;
841+ }
842+ dispatchVtables [ j ] = ( index < userDefinedCount ) ? userDefined [ index ] . Vtable : runtimeDefinedVtable [ index - userDefinedCount ] ;
843+ }
793844 }
794845
846+ ManagedObjectWrapper * mow = ( ManagedObjectWrapper * ) wrapperMem ;
795847 mow ->HolderHandle = IntPtr . Zero ;
796848 mow ->RefCount = 0 ;
797849 mow ->UserDefinedCount = userDefinedCount ;
0 commit comments