Skip to content

Commit 886bca3

Browse files
authored
Use MemoryMarshal.Cast in a few places (#99835)
* Use MemoryMarshal.Cast in a few places * Removed special-casing * Fix build
1 parent 07c99ab commit 886bca3

File tree

5 files changed

+39
-31
lines changed

5 files changed

+39
-31
lines changed

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.Helpers.cs

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
using System.Runtime.CompilerServices;
77
using System.Runtime.InteropServices;
88

9+
#pragma warning disable CS8500 // This takes the address of, gets the size of, or declares a pointer to a managed type
10+
911
namespace System.Numerics.Tensors
1012
{
1113
/// <summary>Performs primitive tensor operations over spans of memory.</summary>
@@ -23,7 +25,31 @@ private static void ValidateInputOutputSpanNonOverlapping<T>(ReadOnlySpan<T> inp
2325
}
2426

2527
/// <summary>Throws an <see cref="OverflowException"/> for trying to negate the minimum value of a two-complement value.</summary>
26-
internal static void ThrowNegateTwosCompOverflow() => throw new OverflowException(SR.Overflow_NegateTwosCompNum);
28+
private static void ThrowNegateTwosCompOverflow() => throw new OverflowException(SR.Overflow_NegateTwosCompNum);
29+
30+
/// <summary>Creates a span of <typeparamref name="TTo"/> from a <typeparamref name="TFrom"/> when they're the same type.</summary>
31+
/// <remarks>
32+
/// This is the same as MemoryMarshal.Cast, except only to be used when TFrom and TTo are the same type or effectively
33+
/// the same type (e.g. int and nint in a 32-bit process). MemoryMarshal.Cast can't currently be used as it's
34+
/// TFrom/TTo are constrained to be value types.
35+
/// </remarks>
36+
private static unsafe Span<TTo> Rename<TFrom, TTo>(Span<TFrom> span)
37+
{
38+
Debug.Assert(sizeof(TFrom) == sizeof(TTo));
39+
return *(Span<TTo>*)(&span);
40+
}
41+
42+
/// <summary>Creates a span of <typeparamref name="TTo"/> from a <typeparamref name="TFrom"/> when they're the same type.</summary>
43+
/// <remarks>
44+
/// This is the same as MemoryMarshal.Cast, except only to be used when TFrom and TTo are the same type or effectively
45+
/// the same type (e.g. int and nint in a 32-bit process). MemoryMarshal.Cast can't currently be used as it's
46+
/// TFrom/TTo are constrained to be value types.
47+
/// </remarks>
48+
private static unsafe ReadOnlySpan<TTo> Rename<TFrom, TTo>(ReadOnlySpan<TFrom> span)
49+
{
50+
Debug.Assert(sizeof(TFrom) == sizeof(TTo));
51+
return *(ReadOnlySpan<TTo>*)(&span);
52+
}
2753

2854
/// <summary>Mask used to handle alignment elements before vectorized handling of the input.</summary>
2955
/// <remarks>

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/Common/TensorPrimitives.IUnaryOperator.cs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1247,12 +1247,5 @@ static void VectorizedSmall8(ref TInput xRef, ref TOutput dRef, nuint remainder)
12471247
}
12481248
}
12491249
}
1250-
1251-
/// <summary>Creates a span of <typeparamref name="TTo"/> from a <typeparamref name="TTo"/> when they're the same type.</summary>
1252-
private static unsafe Span<TTo> Rename<TFrom, TTo>(Span<TFrom> span)
1253-
{
1254-
Debug.Assert(sizeof(TFrom) == sizeof(TTo));
1255-
return MemoryMarshal.CreateSpan(ref Unsafe.As<TFrom, TTo>(ref MemoryMarshal.GetReference(span)), span.Length);
1256-
}
12571250
}
12581251
}

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.ConvertHelpers.cs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -643,13 +643,6 @@ static Vector512<uint> SingleToHalfAsWidenedUInt32(Vector512<float> value)
643643
}
644644
}
645645

646-
/// <summary>Creates a span of <typeparamref name="TTo"/> from a <typeparamref name="TTo"/> when they're the same type.</summary>
647-
private static unsafe ReadOnlySpan<TTo> Rename<TFrom, TTo>(ReadOnlySpan<TFrom> span)
648-
{
649-
Debug.Assert(sizeof(TFrom) == sizeof(TTo));
650-
return MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TFrom, TTo>(ref MemoryMarshal.GetReference(span)), span.Length);
651-
}
652-
653646
/// <summary>Gets whether <typeparamref name="T"/> is <see cref="uint"/> or <see cref="nuint"/> if in a 32-bit process.</summary>
654647
private static bool IsUInt32Like<T>() => typeof(T) == typeof(uint) || (IntPtr.Size == 4 && typeof(T) == typeof(nuint));
655648

src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/netcore/TensorPrimitives.Round.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,13 +106,13 @@ public static void Round<T>(ReadOnlySpan<T> x, int digits, MidpointRounding mode
106106
if (typeof(T) == typeof(float))
107107
{
108108
ReadOnlySpan<float> roundPower10Single = [1e0f, 1e1f, 1e2f, 1e3f, 1e4f, 1e5f, 1e6f];
109-
roundPower10 = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<float, T>(ref MemoryMarshal.GetReference(roundPower10Single)), roundPower10Single.Length);
109+
roundPower10 = Rename<float, T>(roundPower10Single);
110110
}
111111
else if (typeof(T) == typeof(double))
112112
{
113113
Debug.Assert(typeof(T) == typeof(double));
114114
ReadOnlySpan<double> roundPower10Double = [1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, 1e10, 1e11, 1e12, 1e13, 1e14, 1e15];
115-
roundPower10 = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<double, T>(ref MemoryMarshal.GetReference(roundPower10Double)), roundPower10Double.Length);
115+
roundPower10 = Rename<double, T>(roundPower10Double);
116116
}
117117
else
118118
{

src/libraries/System.Private.CoreLib/src/System/Number.Parsing.cs

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -883,16 +883,16 @@ internal static bool SpanStartsWith<TChar>(ReadOnlySpan<TChar> span, ReadOnlySpa
883883
{
884884
if (typeof(TChar) == typeof(char))
885885
{
886-
ReadOnlySpan<char> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(span)), span.Length);
887-
ReadOnlySpan<char> typedValue = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(value)), value.Length);
886+
ReadOnlySpan<char> typedSpan = MemoryMarshal.Cast<TChar, char>(span);
887+
ReadOnlySpan<char> typedValue = MemoryMarshal.Cast<TChar, char>(value);
888888
return typedSpan.StartsWith(typedValue, comparisonType);
889889
}
890890
else
891891
{
892892
Debug.Assert(typeof(TChar) == typeof(byte));
893893

894-
ReadOnlySpan<byte> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(span)), span.Length);
895-
ReadOnlySpan<byte> typedValue = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(value)), value.Length);
894+
ReadOnlySpan<byte> typedSpan = MemoryMarshal.Cast<TChar, byte>(span);
895+
ReadOnlySpan<byte> typedValue = MemoryMarshal.Cast<TChar, byte>(value);
896896
return typedSpan.StartsWithUtf8(typedValue, comparisonType);
897897
}
898898
}
@@ -903,17 +903,13 @@ internal static ReadOnlySpan<TChar> SpanTrim<TChar>(ReadOnlySpan<TChar> span)
903903
{
904904
if (typeof(TChar) == typeof(char))
905905
{
906-
ReadOnlySpan<char> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(span)), span.Length);
907-
ReadOnlySpan<char> result = typedSpan.Trim();
908-
return MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<char, TChar>(ref MemoryMarshal.GetReference(result)), result.Length);
906+
return MemoryMarshal.Cast<char, TChar>(MemoryMarshal.Cast<TChar, char>(span).Trim());
909907
}
910908
else
911909
{
912910
Debug.Assert(typeof(TChar) == typeof(byte));
913911

914-
ReadOnlySpan<byte> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(span)), span.Length);
915-
ReadOnlySpan<byte> result = typedSpan.TrimUtf8();
916-
return MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<byte, TChar>(ref MemoryMarshal.GetReference(result)), result.Length);
912+
return MemoryMarshal.Cast<byte, TChar>(MemoryMarshal.Cast<TChar, byte>(span).TrimUtf8());
917913
}
918914
}
919915

@@ -923,16 +919,16 @@ internal static bool SpanEqualsOrdinalIgnoreCase<TChar>(ReadOnlySpan<TChar> span
923919
{
924920
if (typeof(TChar) == typeof(char))
925921
{
926-
ReadOnlySpan<char> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(span)), span.Length);
927-
ReadOnlySpan<char> typedValue = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, char>(ref MemoryMarshal.GetReference(value)), value.Length);
922+
ReadOnlySpan<char> typedSpan = MemoryMarshal.Cast<TChar, char>(span);
923+
ReadOnlySpan<char> typedValue = MemoryMarshal.Cast<TChar, char>(value);
928924
return typedSpan.EqualsOrdinalIgnoreCase(typedValue);
929925
}
930926
else
931927
{
932928
Debug.Assert(typeof(TChar) == typeof(byte));
933929

934-
ReadOnlySpan<byte> typedSpan = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(span)), span.Length);
935-
ReadOnlySpan<byte> typedValue = MemoryMarshal.CreateReadOnlySpan(ref Unsafe.As<TChar, byte>(ref MemoryMarshal.GetReference(value)), value.Length);
930+
ReadOnlySpan<byte> typedSpan = MemoryMarshal.Cast<TChar, byte>(span);
931+
ReadOnlySpan<byte> typedValue = MemoryMarshal.Cast<TChar, byte>(value);
936932
return typedSpan.EqualsOrdinalIgnoreCaseUtf8(typedValue);
937933
}
938934
}

0 commit comments

Comments
 (0)