Skip to content

Commit ab045a4

Browse files
authored
[Data] (De)serialization of PyArrow Extension Arrays (#51972)
<!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? This feature adds the ability to (de)serialize arbitrary PyArrow extension arrays. This is needed to use Ray in code bases that use extension arrays. ~The serialization already seemed sufficiently general, but as far as I can tell, the deserialization cannot be done in general. Hence, this setup allows registration of custom deserializers for extension types.~ ~For serialization, the selector has been changed from `ExtensionType` to `BaseExtensionType` to accommodate non-Python ExtensionArrays, like `pyarrow.FixedShapeTensorArray`.~ ~This is at the moment a proof-of-concept. If you like the idea, I suppose the registration function may need to move to a better place, and docs need adding.~ The implementation now works without registration on any extension type. ## Related issue number Closes #51959 ## Checks - [X] I've signed off every commit (by using the -s flag, i.e., `git commit -s`) in this PR. - [X] I've run `scripts/format.sh` to lint the changes in this PR. - [ ] I've included any doc changes needed for https://docs.ray.io/en/master/. - [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [X] I've made sure the tests are passing. 
Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [X] Unit tests - [ ] Release tests - [ ] This PR is not tested :( <!-- CURSOR_SUMMARY --> --- > [!NOTE] > Generalizes Arrow array (de)serialization to any `pyarrow.BaseExtensionType`, removing tensor-specific handling and adding tests for fixed/variable-shape tensors. > > - **Arrow (De)serialization**: > - Switch from tensor-specific checks to generic `pyarrow.BaseExtensionType` handling. > - Reconstruct extension arrays via `type.wrap_array(storage)`; serialize via storage payload wrapped with extension metadata. > - Remove `ray.air.util.tensor_extensions.arrow` dependencies and special-casing. > - **Tests**: > - Add roundtrip tests for `pa.FixedShapeTensorArray` and a custom variable-shape `ExtensionType`. > - Import `PicklableArrayPayload` in tests for constructing payloads. > > <sup>Written by [Cursor Bugbot](https://cursor.com/dashboard?tab=bugbot) for commit 4bbcdbe. This will update automatically on new commits. Configure [here](https://cursor.com/dashboard?tab=bugbot).</sup> <!-- /CURSOR_SUMMARY --> --------- Signed-off-by: Pim de Haan <pim@cusp.ai>
1 parent d09174e commit ab045a4

File tree

2 files changed

+64
-36
lines changed

2 files changed

+64
-36
lines changed

python/ray/_private/arrow_serialization.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
if TYPE_CHECKING:
1414
import pyarrow
1515

16-
from ray.data.extensions import ArrowTensorArray
17-
1816
RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION = (
1917
"RAY_DISABLE_CUSTOM_ARROW_JSON_OPTIONS_SERIALIZATION"
2018
)
@@ -240,12 +238,8 @@ def _array_payload_to_array(payload: "PicklableArrayPayload") -> "pyarrow.Array"
240238
"""Reconstruct an Arrow Array from a possibly nested PicklableArrayPayload."""
241239
import pyarrow as pa
242240

243-
from ray.air.util.tensor_extensions.arrow import get_arrow_extension_tensor_types
244-
245241
children = [child_payload.to_array() for child_payload in payload.children]
246242

247-
tensor_extension_types = get_arrow_extension_tensor_types()
248-
249243
if pa.types.is_dictionary(payload.type):
250244
# Dedicated path for reconstructing a DictionaryArray, since
251245
# Array.from_buffers() doesn't work for DictionaryArrays.
@@ -258,16 +252,10 @@ def _array_payload_to_array(payload: "PicklableArrayPayload") -> "pyarrow.Array"
258252
assert len(children) == 3, len(children)
259253
offsets, keys, items = children
260254
return pa.MapArray.from_arrays(offsets, keys, items)
261-
elif isinstance(
262-
payload.type,
263-
tensor_extension_types,
264-
):
265-
# Dedicated path for reconstructing an ArrowTensorArray or
266-
# ArrowVariableShapedTensorArray, both of which can't be reconstructed by the
267-
# Array.from_buffers() API.
255+
elif isinstance(payload.type, pa.BaseExtensionType):
268256
assert len(children) == 1, len(children)
269257
storage = children[0]
270-
return pa.ExtensionArray.from_storage(payload.type, storage)
258+
return payload.type.wrap_array(storage)
271259
else:
272260
# Common case: use Array.from_buffers() to construct an array of a certain type.
273261
return pa.Array.from_buffers(
@@ -288,10 +276,6 @@ def _array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
288276
"""
289277
import pyarrow as pa
290278

291-
from ray.air.util.tensor_extensions.arrow import get_arrow_extension_tensor_types
292-
293-
tensor_extension_types = get_arrow_extension_tensor_types()
294-
295279
if _is_dense_union(a.type):
296280
# Dense unions are not supported.
297281
# TODO(Clark): Support dense unions.
@@ -319,9 +303,7 @@ def _array_to_array_payload(a: "pyarrow.Array") -> "PicklableArrayPayload":
319303
return _dictionary_array_to_array_payload(a)
320304
elif pa.types.is_map(a.type):
321305
return _map_array_to_array_payload(a)
322-
elif isinstance(a.type, tensor_extension_types):
323-
return _tensor_array_to_array_payload(a)
324-
elif isinstance(a.type, pa.ExtensionType):
306+
elif isinstance(a.type, pa.BaseExtensionType):
325307
return _extension_array_to_array_payload(a)
326308
else:
327309
raise ValueError("Unhandled Arrow array type:", a.type)
@@ -630,11 +612,9 @@ def _map_array_to_array_payload(a: "pyarrow.MapArray") -> "PicklableArrayPayload
630612
)
631613

632614

633-
def _tensor_array_to_array_payload(a: "ArrowTensorArray") -> "PicklableArrayPayload":
634-
"""Serialize tensor arrays to PicklableArrayPayload."""
635-
# Offset is propagated to storage array, and the storage array items align with the
636-
# tensor elements, so we only need to do the straightforward creation of the storage
637-
# array payload.
615+
def _extension_array_to_array_payload(
616+
a: "pyarrow.ExtensionArray",
617+
) -> "PicklableArrayPayload":
638618
storage_payload = _array_to_array_payload(a.storage)
639619
return PicklableArrayPayload(
640620
type=a.type,
@@ -646,16 +626,6 @@ def _tensor_array_to_array_payload(a: "ArrowTensorArray") -> "PicklableArrayPayl
646626
)
647627

648628

649-
def _extension_array_to_array_payload(
650-
a: "pyarrow.ExtensionArray",
651-
) -> "PicklableArrayPayload":
652-
payload = _array_to_array_payload(a.storage)
653-
payload.type = a.type
654-
payload.length = len(a)
655-
payload.null_count = a.null_count
656-
return payload
657-
658-
659629
def _copy_buffer_if_needed(
660630
buf: "pyarrow.Buffer",
661631
type_: Optional["pyarrow.DataType"],

python/ray/data/tests/test_arrow_serialization.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import ray.data
1717
import ray.train
1818
from ray._private.arrow_serialization import (
19+
PicklableArrayPayload,
1920
_align_bit_offset,
2021
_bytes_for_bits,
2122
_copy_bitpacked_buffer_if_needed,
@@ -595,3 +596,60 @@ def test_custom_arrow_data_serializer_disable(shutdown_only):
595596
assert d_view["a"].chunk(0).buffers()[1].size == t["a"].chunk(0).buffers()[1].size
596597
# Check that the serialized slice view is large
597598
assert len(s_view) > 0.8 * len(s_t)
599+
600+
601+
def test_fixed_shape_tensor_array_serialization():
602+
a = pa.FixedShapeTensorArray.from_numpy_ndarray(
603+
np.arange(4 * 2 * 3).reshape(4, 2, 3)
604+
)
605+
payload = PicklableArrayPayload.from_array(a)
606+
a2 = payload.to_array()
607+
assert a == a2
608+
609+
610+
class _VariableShapeTensorType(pa.ExtensionType):
611+
def __init__(
612+
self,
613+
value_type: pa.DataType,
614+
ndim: int,
615+
) -> None:
616+
self.value_type = value_type
617+
self.ndim = ndim
618+
super().__init__(
619+
pa.struct(
620+
[
621+
pa.field("data", pa.list_(value_type)),
622+
pa.field("shape", pa.list_(pa.int32(), ndim)),
623+
]
624+
),
625+
"variable_shape_tensor",
626+
)
627+
628+
def __arrow_ext_serialize__(self) -> bytes:
629+
return b""
630+
631+
@classmethod
632+
def __arrow_ext_deserialize__(cls, storage_type: pa.DataType, serialized: bytes):
633+
ndim = storage_type[1].type.list_size
634+
value_type = storage_type[0].type.value_type
635+
return cls(value_type, ndim)
636+
637+
638+
def test_variable_shape_tensor_serialization():
639+
t = _VariableShapeTensorType(pa.float32(), 2)
640+
ar = pa.array(
641+
[
642+
{
643+
"data": np.arange(2 * 3),
644+
"shape": [2, 3],
645+
},
646+
{
647+
"data": np.arange(4 * 5),
648+
"shape": [4, 5],
649+
},
650+
],
651+
type=t,
652+
)
653+
payload = PicklableArrayPayload.from_array(ar)
654+
ar2 = payload.to_array()
655+
assert ar == ar2

0 commit comments

Comments
 (0)