-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Tensor.Slice No Longer Copies #113166
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Tensor.Slice No Longer Copies #113166
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,8 @@ public sealed class Tensor<T> | |
| internal readonly nint[] _strides; | ||
| /// <summary>If the backing memory is permanently pinned (so not just using a fixed statement).</summary> | ||
| internal readonly bool _isPinned; | ||
| /// <summary>The offset of the first element in the backing memory.</summary> | ||
| internal readonly int _memoryOffset; | ||
|
|
||
| /// <summary> | ||
| /// Creates a new empty Tensor. | ||
|
|
@@ -44,13 +46,14 @@ internal Tensor() | |
| _values = []; | ||
| _lengths = []; | ||
| _strides = []; | ||
| _memoryOffset = 0; | ||
| } | ||
|
|
||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | ||
| internal Tensor(T[]? values, ReadOnlySpan<nint> lengths, bool isPinned = false) : this(values, lengths, Array.Empty<nint>(), isPinned) { } | ||
| internal Tensor(T[]? values, ReadOnlySpan<nint> lengths, bool isPinned = false, int memoryOffset = 0) : this(values, lengths, Array.Empty<nint>(), isPinned, memoryOffset) { } | ||
|
|
||
| [MethodImpl(MethodImplOptions.AggressiveInlining)] | ||
| internal Tensor(T[]? values, ReadOnlySpan<nint> lengths, ReadOnlySpan<nint> strides, bool isPinned = false) | ||
| internal Tensor(T[]? values, ReadOnlySpan<nint> lengths, ReadOnlySpan<nint> strides, bool isPinned = false, int memoryOffset = 0) | ||
| { | ||
| if (values == null) | ||
| { | ||
|
|
@@ -60,10 +63,12 @@ internal Tensor(T[]? values, ReadOnlySpan<nint> lengths, ReadOnlySpan<nint> stri | |
| _values = []; | ||
| _lengths = []; | ||
| _strides = []; | ||
| _memoryOffset = memoryOffset; | ||
| return; // returns default | ||
| } | ||
|
|
||
| _lengths = lengths.IsEmpty ? [values.Length] : lengths.ToArray(); | ||
| _memoryOffset = memoryOffset; | ||
|
|
||
| _flattenedLength = TensorSpanHelpers.CalculateTotalLength(_lengths); | ||
| _strides = strides.IsEmpty ? TensorSpanHelpers.CalculateStrides(_lengths, _flattenedLength) : strides.ToArray(); | ||
|
|
@@ -386,7 +391,7 @@ public Tensor<T> this[Tensor<bool> filter] | |
| /// Converts this <see cref="Tensor{T}"/> to a <see cref="TensorSpan{T}"/> pointing to the same backing memory."/> | ||
| /// </summary> | ||
| /// <returns><see cref="TensorSpan{T}"/></returns> | ||
| public TensorSpan<T> AsTensorSpan() => new TensorSpan<T>(ref MemoryMarshal.GetArrayDataReference(_values), _lengths, _strides, _flattenedLength); | ||
| public TensorSpan<T> AsTensorSpan() => new TensorSpan<T>(ref Unsafe.Add(ref MemoryMarshal.GetArrayDataReference(_values), _memoryOffset), _lengths, _strides, _values.Length - _memoryOffset); | ||
|
|
||
| /// <summary> | ||
| /// Converts this <see cref="Tensor{T}"/> to a <see cref="TensorSpan{T}"/> pointing to the same backing memory based on the provided ranges."/> | ||
|
|
@@ -454,26 +459,71 @@ public Tensor<T> this[Tensor<bool> filter] | |
| /// Forms a slice out of the given tensor | ||
| /// </summary> | ||
| /// <param name="start">The ranges for the slice</param> | ||
| /// <returns><see cref="Tensor{T}"/> as a copy of the provided ranges.</returns> | ||
| // REVIEW: CURRENTLY DOES A COPY. | ||
| /// <returns><see cref="Tensor{T}"/> without copying the provided ranges.</returns> | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we consider this a breaking change? Also, does Tensors use triple slash comments as source of truth? Asking in case it doesn't, in which case you will need to update dotnet-api-docs manually.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I believe we are using triple slash comments as source of truth. I'm not sure if it's a breaking change, but since the API here is still experimental we are good to make this change. |
||
| public Tensor<T> Slice(params ReadOnlySpan<NRange> start) | ||
| { | ||
| if (start.Length != Lengths.Length) | ||
| throw new ArgumentOutOfRangeException(nameof(start), "Number of dimensions to slice does not equal the number of dimensions in the span"); | ||
|
|
||
| TensorSpan<T> s = AsTensorSpan(start); | ||
| T[] values = _isPinned ? GC.AllocateArray<T>(checked((int)s.FlattenedLength), _isPinned) : (new T[s.FlattenedLength]); | ||
| var outTensor = new Tensor<T>(values, s.Lengths.ToArray(), _isPinned); | ||
| s.CopyTo(outTensor); | ||
| return outTensor; | ||
| scoped Span<nint> lengths; | ||
| scoped Span<nint> offsets; | ||
| nint[]? lengthsArray; | ||
| nint[]? offsetsArray; | ||
| if (Rank > TensorShape.MaxInlineRank) | ||
| { | ||
| lengthsArray = ArrayPool<nint>.Shared.Rent(Rank); | ||
| lengths = lengthsArray.AsSpan(0, Rank); | ||
|
|
||
| offsetsArray = ArrayPool<nint>.Shared.Rent(Rank); | ||
| offsets = offsetsArray.AsSpan(0, Rank); | ||
| } | ||
| else | ||
| { | ||
| lengths = stackalloc nint[Rank]; | ||
| offsets = stackalloc nint[Rank]; | ||
|
|
||
| lengthsArray = null; | ||
| offsetsArray = null; | ||
| } | ||
| lengths.Clear(); | ||
| offsets.Clear(); | ||
|
|
||
| for (int i = 0; i < start.Length; i++) | ||
| { | ||
| (offsets[i], lengths[i]) = start[i].GetOffsetAndLength(Lengths[i]); | ||
| } | ||
|
|
||
| // When we have an empty Tensor and someone wants to slice all of it, we should return an empty Tensor. | ||
| // FlattenedLength is computed every time so using a local to cache the value. | ||
| nint flattenedLength = FlattenedLength; | ||
| int memoryOffset = 0; | ||
|
|
||
| if (flattenedLength != 0) | ||
| { | ||
| for (int i = 0; i < offsets.Length; i++) | ||
| { | ||
| memoryOffset += (int)(Strides[i] * offsets[i]); | ||
| } | ||
| } | ||
|
|
||
| if ((memoryOffset >= _values.Length || memoryOffset < 0) && flattenedLength != 0) | ||
michaelgsharp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| ThrowHelper.ThrowIndexOutOfRangeException(); | ||
|
|
||
| Tensor<T> toReturn = new Tensor<T>(_values, lengths, Strides, _isPinned, memoryOffset); | ||
|
|
||
| if (offsetsArray != null) | ||
| ArrayPool<nint>.Shared.Return(offsetsArray); | ||
| if (lengthsArray != null) | ||
| ArrayPool<nint>.Shared.Return(lengthsArray); | ||
|
|
||
| return toReturn; | ||
| } | ||
|
|
||
| /// <summary> | ||
| /// Forms a slice out of the given tensor | ||
| /// </summary> | ||
| /// <param name="start">The start indexes for the slice</param> | ||
| /// <returns><see cref="Tensor{T}"/> as a copy of the provided ranges.</returns> | ||
| // REVIEW: CURRENTLY DOES A COPY. | ||
michaelgsharp marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| /// <returns><see cref="Tensor{T}"/> without copying the provided ranges.</returns> | ||
| public Tensor<T> Slice(params ReadOnlySpan<nint> start) | ||
| { | ||
| NRange[] ranges = new NRange[start.Length]; | ||
|
|
@@ -488,8 +538,7 @@ public Tensor<T> Slice(params ReadOnlySpan<nint> start) | |
| /// Forms a slice out of the given tensor | ||
| /// </summary> | ||
| /// <param name="startIndex">The start indexes for the slice</param> | ||
| /// <returns><see cref="Tensor{T}"/> as a copy of the provided ranges.</returns> | ||
| // REVIEW: CURRENTLY DOES A COPY. | ||
| /// <returns><see cref="Tensor{T}"/> without copying the provided ranges.</returns> | ||
| public Tensor<T> Slice(params ReadOnlySpan<NIndex> startIndex) | ||
| { | ||
| NRange[] ranges = new NRange[startIndex.Length]; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using '_values.Length - _memoryOffset' to determine the flattened length of the slice may not correctly reflect the actual number of elements in the sliced tensor. Consider calculating this value from the sliced tensor's dimensions (e.g., using TensorSpanHelpers.CalculateTotalLength(_lengths)) to ensure it accurately represents the slice's data range.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What do you think of this suggestion, @michaelgsharp?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This suggestion is incorrect sadly.