Skip to content

Commit 3745e38

Browse files
alexcovingtongfoidldanmoseley
authored
Vectorize SpanHelpers<T>.IndexOf (#60974)
Co-authored-by: Günther Foidl <gue@korporal.at> Co-authored-by: Dan Moseley <danmose@microsoft.com>
1 parent a6e0f25 commit 3745e38

File tree

4 files changed

+275
-8
lines changed

4 files changed

+275
-8
lines changed

src/libraries/System.Memory/tests/Span/Contains.T.cs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Licensed to the .NET Foundation under one or more agreements.
22
// The .NET Foundation licenses this file to you under the MIT license.
33

4+
using System.Collections.Generic;
5+
using System.Linq;
46
using Xunit;
57

68
namespace System.SpanTests
@@ -193,5 +195,114 @@ public static void ContainsNull_String(string[] spanInput, bool expected)
193195
Span<string> theStrings = spanInput;
194196
Assert.Equal(expected, theStrings.Contains(null));
195197
}
198+
199+
[Theory]
200+
[InlineData(new int[] { 1, 2, 3, 4 }, 4, true)]
201+
[InlineData(new int[] { 1, 2, 3, 4 }, 5, false)]
202+
public static void Contains_Int32(int[] array, int value, bool expectedResult)
203+
{
204+
// Test with short Span
205+
Span<int> span = new Span<int>(array);
206+
bool result = span.Contains(value);
207+
Assert.Equal(result, expectedResult);
208+
209+
// Test with long Span
210+
for (int i = 0; i < 10; i++)
211+
array = array.Concat(array).ToArray();
212+
span = new Span<int>(array);
213+
result = span.Contains(value);
214+
Assert.Equal(result, expectedResult);
215+
}
216+
217+
[Theory]
218+
[InlineData(new long[] { 1, 2, 3, 4 }, 4, true)]
219+
[InlineData(new long[] { 1, 2, 3, 4 }, 5, false)]
220+
public static void Contains_Int64(long[] array, long value, bool expectedResult)
221+
{
222+
// Test with short Span
223+
Span<long> span = new Span<long>(array);
224+
bool result = span.Contains(value);
225+
Assert.Equal(result, expectedResult);
226+
227+
// Test with long Span
228+
for (int i = 0; i < 10; i++)
229+
array = array.Concat(array).ToArray();
230+
span = new Span<long>(array);
231+
result = span.Contains(value);
232+
Assert.Equal(result, expectedResult);
233+
}
234+
235+
[Theory]
236+
[InlineData(new byte[] { 1, 2, 3, 4 }, 4, true)]
237+
[InlineData(new byte[] { 1, 2, 3, 4 }, 5, false)]
238+
public static void Contains_Byte(byte[] array, byte value, bool expectedResult)
239+
{
240+
// Test with short Span
241+
Span<byte> span = new Span<byte>(array);
242+
bool result = span.Contains(value);
243+
Assert.Equal(result, expectedResult);
244+
245+
// Test with long Span
246+
for (int i = 0; i < 10; i++)
247+
array = array.Concat(array).ToArray();
248+
span = new Span<byte>(array);
249+
result = span.Contains(value);
250+
Assert.Equal(result, expectedResult);
251+
}
252+
253+
[Theory]
254+
[InlineData(new char[] { 'a', 'b', 'c', 'd' }, 'd', true)]
255+
[InlineData(new char[] { 'a', 'b', 'c', 'd' }, 'e', false)]
256+
public static void Contains_Char(char[] array, char value, bool expectedResult)
257+
{
258+
// Test with short Span
259+
Span<char> span = new Span<char>(array);
260+
bool result = span.Contains(value);
261+
Assert.Equal(result, expectedResult);
262+
263+
// Test with long Span
264+
for (int i = 0; i < 10; i++)
265+
array = array.Concat(array).ToArray();
266+
span = new Span<char>(array);
267+
result = span.Contains(value);
268+
Assert.Equal(result, expectedResult);
269+
270+
}
271+
272+
[Theory]
273+
[InlineData(new float[] { 1, 2, 3, 4 }, 4, true)]
274+
[InlineData(new float[] { 1, 2, 3, 4 }, 5, false)]
275+
public static void Contains_Float(float[] array, float value, bool expectedResult)
276+
{
277+
// Test with short Span
278+
Span<float> span = new Span<float>(array);
279+
bool result = span.Contains(value);
280+
Assert.Equal(result, expectedResult);
281+
282+
// Test with long Span
283+
for (int i = 0; i < 10; i++)
284+
array = array.Concat(array).ToArray();
285+
span = new Span<float>(array);
286+
result = span.Contains(value);
287+
Assert.Equal(result, expectedResult);
288+
}
289+
290+
[Theory]
291+
[InlineData(new double[] { 1, 2, 3, 4 }, 4, true)]
292+
[InlineData(new double[] { 1, 2, 3, 4 }, 5, false)]
293+
public static void Contains_Double(double[] array, double value, bool expectedResult)
294+
{
295+
// Test with short Span
296+
Span<double> span = new Span<double>(array);
297+
bool result = span.Contains(value);
298+
Assert.Equal(result, expectedResult);
299+
300+
// Test with long Span
301+
for (int i = 0; i < 10; i++)
302+
array = array.Concat(array).ToArray();
303+
span = new Span<double>(array);
304+
result = span.Contains(value);
305+
Assert.Equal(result, expectedResult);
306+
}
196307
}
197308
}

src/libraries/System.Private.CoreLib/src/System/Array.cs

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1232,18 +1232,28 @@ ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<char[]>(array))
12321232
}
12331233
else if (Unsafe.SizeOf<T>() == sizeof(int))
12341234
{
1235-
int result = SpanHelpers.IndexOf(
1236-
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
1237-
Unsafe.As<T, int>(ref value),
1238-
count);
1235+
int result = typeof(T).IsValueType
1236+
? SpanHelpers.IndexOfValueType(
1237+
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
1238+
Unsafe.As<T, int>(ref value),
1239+
count)
1240+
: SpanHelpers.IndexOf(
1241+
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<int[]>(array)), startIndex),
1242+
Unsafe.As<T, int>(ref value),
1243+
count);
12391244
return (result >= 0 ? startIndex : 0) + result;
12401245
}
12411246
else if (Unsafe.SizeOf<T>() == sizeof(long))
12421247
{
1243-
int result = SpanHelpers.IndexOf(
1244-
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
1245-
Unsafe.As<T, long>(ref value),
1246-
count);
1248+
int result = typeof(T).IsValueType
1249+
? SpanHelpers.IndexOfValueType(
1250+
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
1251+
Unsafe.As<T, long>(ref value),
1252+
count)
1253+
: SpanHelpers.IndexOf(
1254+
ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(Unsafe.As<long[]>(array)), startIndex),
1255+
Unsafe.As<T, long>(ref value),
1256+
count);
12471257
return (result >= 0 ? startIndex : 0) + result;
12481258
}
12491259
}

src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,18 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
279279
ref Unsafe.As<T, char>(ref MemoryMarshal.GetReference(span)),
280280
Unsafe.As<T, char>(ref value),
281281
span.Length);
282+
283+
if (Unsafe.SizeOf<T>() == sizeof(int))
284+
return -1 != SpanHelpers.IndexOfValueType(
285+
ref Unsafe.As<T, int>(ref MemoryMarshal.GetReference(span)),
286+
Unsafe.As<T, int>(ref value),
287+
span.Length);
288+
289+
if (Unsafe.SizeOf<T>() == sizeof(long))
290+
return -1 != SpanHelpers.IndexOfValueType(
291+
ref Unsafe.As<T, long>(ref MemoryMarshal.GetReference(span)),
292+
Unsafe.As<T, long>(ref value),
293+
span.Length);
282294
}
283295

284296
return SpanHelpers.Contains(ref MemoryMarshal.GetReference(span), value, span.Length);
@@ -306,6 +318,18 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
306318
ref Unsafe.As<T, char>(ref MemoryMarshal.GetReference(span)),
307319
Unsafe.As<T, char>(ref value),
308320
span.Length);
321+
322+
if (Unsafe.SizeOf<T>() == sizeof(int))
323+
return -1 != SpanHelpers.IndexOfValueType(
324+
ref Unsafe.As<T, int>(ref MemoryMarshal.GetReference(span)),
325+
Unsafe.As<T, int>(ref value),
326+
span.Length);
327+
328+
if (Unsafe.SizeOf<T>() == sizeof(long))
329+
return -1 != SpanHelpers.IndexOfValueType(
330+
ref Unsafe.As<T, long>(ref MemoryMarshal.GetReference(span)),
331+
Unsafe.As<T, long>(ref value),
332+
span.Length);
309333
}
310334

311335
return SpanHelpers.Contains(ref MemoryMarshal.GetReference(span), value, span.Length);
@@ -332,6 +356,18 @@ ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
332356
ref Unsafe.As<T, char>(ref MemoryMarshal.GetReference(span)),
333357
Unsafe.As<T, char>(ref value),
334358
span.Length);
359+
360+
if (Unsafe.SizeOf<T>() == sizeof(int))
361+
return SpanHelpers.IndexOfValueType(
362+
ref Unsafe.As<T, int>(ref MemoryMarshal.GetReference(span)),
363+
Unsafe.As<T, int>(ref value),
364+
span.Length);
365+
366+
if (Unsafe.SizeOf<T>() == sizeof(long))
367+
return SpanHelpers.IndexOfValueType(
368+
ref Unsafe.As<T, long>(ref MemoryMarshal.GetReference(span)),
369+
Unsafe.As<T, long>(ref value),
370+
span.Length);
335371
}
336372

337373
return SpanHelpers.IndexOf(ref MemoryMarshal.GetReference(span), value, span.Length);

src/libraries/System.Private.CoreLib/src/System/SpanHelpers.T.cs

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using System.Diagnostics;
55
using System.Numerics;
66
using System.Runtime.CompilerServices;
7+
using System.Runtime.InteropServices;
78
using System.Runtime.Intrinsics;
89
using Internal.Runtime.CompilerServices;
910

@@ -291,6 +292,115 @@ public static unsafe bool Contains<T>(ref T searchSpace, T value, int length) wh
291292
return true;
292293
}
293294

295+
internal static unsafe int IndexOfValueType<T>(ref T searchSpace, T value, int length) where T : struct, IEquatable<T>
296+
{
297+
Debug.Assert(length >= 0);
298+
299+
nint index = 0; // Use nint for arithmetic to avoid unnecessary 64->32->64 truncations
300+
if (Vector.IsHardwareAccelerated && Vector<T>.IsTypeSupported && (Vector<T>.Count * 2) <= length)
301+
{
302+
Vector<T> valueVector = new Vector<T>(value);
303+
Vector<T> compareVector = default;
304+
Vector<T> matchVector = default;
305+
if ((uint)length % (uint)Vector<T>.Count != 0)
306+
{
307+
// Number of elements is not a multiple of Vector<T>.Count, so do one
308+
// check and shift only enough for the remaining set to be a multiple
309+
// of Vector<T>.Count.
310+
compareVector = Unsafe.As<T, Vector<T>>(ref Unsafe.Add(ref searchSpace, index));
311+
matchVector = Vector.Equals(valueVector, compareVector);
312+
if (matchVector != Vector<T>.Zero)
313+
{
314+
goto VectorMatch;
315+
}
316+
index += length % Vector<T>.Count;
317+
length -= length % Vector<T>.Count;
318+
}
319+
while (length > 0)
320+
{
321+
compareVector = Unsafe.As<T, Vector<T>>(ref Unsafe.Add(ref searchSpace, index));
322+
matchVector = Vector.Equals(valueVector, compareVector);
323+
if (matchVector != Vector<T>.Zero)
324+
{
325+
goto VectorMatch;
326+
}
327+
index += Vector<T>.Count;
328+
length -= Vector<T>.Count;
329+
}
330+
goto NotFound;
331+
VectorMatch:
332+
for (int i = 0; i < Vector<T>.Count; i++)
333+
if (compareVector[i].Equals(value))
334+
return (int)(index + i);
335+
}
336+
337+
while (length >= 8)
338+
{
339+
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
340+
goto Found;
341+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 1)))
342+
goto Found1;
343+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 2)))
344+
goto Found2;
345+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 3)))
346+
goto Found3;
347+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 4)))
348+
goto Found4;
349+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 5)))
350+
goto Found5;
351+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 6)))
352+
goto Found6;
353+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 7)))
354+
goto Found7;
355+
356+
length -= 8;
357+
index += 8;
358+
}
359+
360+
while (length >= 4)
361+
{
362+
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
363+
goto Found;
364+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 1)))
365+
goto Found1;
366+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 2)))
367+
goto Found2;
368+
if (value.Equals(Unsafe.Add(ref searchSpace, index + 3)))
369+
goto Found3;
370+
371+
length -= 4;
372+
index += 4;
373+
}
374+
375+
while (length > 0)
376+
{
377+
if (value.Equals(Unsafe.Add(ref searchSpace, index)))
378+
goto Found;
379+
380+
index += 1;
381+
length--;
382+
}
383+
NotFound:
384+
return -1;
385+
386+
Found: // Workaround for https://github.com/dotnet/runtime/issues/8795
387+
return (int)index;
388+
Found1:
389+
return (int)(index + 1);
390+
Found2:
391+
return (int)(index + 2);
392+
Found3:
393+
return (int)(index + 3);
394+
Found4:
395+
return (int)(index + 4);
396+
Found5:
397+
return (int)(index + 5);
398+
Found6:
399+
return (int)(index + 6);
400+
Found7:
401+
return (int)(index + 7);
402+
}
403+
294404
public static unsafe int IndexOf<T>(ref T searchSpace, T value, int length) where T : IEquatable<T>
295405
{
296406
Debug.Assert(length >= 0);

0 commit comments

Comments
 (0)