Skip to content

Commit d8946ca

Browse files
authored
Fix ArrowArrayStreamReader for 0-columns record batch streams (#9405)
# Which issue does this PR close? - Closes #9394 # Rationale for this change PR #8944 introduced a regression that 0-column record batch streams could not longer be decoded. # What changes are included in this PR? - Construct `RecordBatch` with `try_new_with_options` using the `len` of the `ArrayData`, instead of letting it try to implicitly determine `len` by looking at the first column (this is what `try_new` does). - Slight refactor and reduction of code duplication of the existing `test_stream_round_trip_[import/export]` tests - Introduction of a new `test_stream_round_trip_no_columns` test # Are these changes tested? Yes, both export and import are tested in `test_stream_round_trip_no_columns`. # Are there any user-facing changes? 0-column record batch streams should be decodable now.
1 parent 70089ac commit d8946ca

File tree

1 file changed

+39
-36
lines changed

1 file changed

+39
-36
lines changed

arrow-array/src/ffi_stream.rs

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ use std::{
6666
use arrow_data::ffi::FFI_ArrowArray;
6767
use arrow_schema::{ArrowError, Schema, SchemaRef, ffi::FFI_ArrowSchema};
6868

69+
use crate::RecordBatchOptions;
6970
use crate::array::Array;
7071
use crate::array::StructArray;
7172
use crate::ffi::from_ffi_and_data_type;
@@ -365,7 +366,12 @@ impl Iterator for ArrowArrayStreamReader {
365366
from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone()))
366367
};
367368
Some(result.and_then(|data| {
368-
RecordBatch::try_new(self.schema.clone(), StructArray::from(data).into_parts().1)
369+
let len = data.len();
370+
RecordBatch::try_new_with_options(
371+
self.schema.clone(),
372+
StructArray::from(data).into_parts().1,
373+
&RecordBatchOptions::new().with_row_count(Some(len)),
374+
)
369375
}))
370376
} else {
371377
let last_error = self.get_stream_last_error();
@@ -419,20 +425,7 @@ mod tests {
419425
}
420426
}
421427

422-
fn _test_round_trip_export(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
423-
let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
424-
let schema = Arc::new(Schema::new_with_metadata(
425-
vec![
426-
Field::new("a", arrays[0].data_type().clone(), true)
427-
.with_metadata(metadata.clone()),
428-
Field::new("b", arrays[1].data_type().clone(), true)
429-
.with_metadata(metadata.clone()),
430-
Field::new("c", arrays[2].data_type().clone(), true)
431-
.with_metadata(metadata.clone()),
432-
],
433-
metadata,
434-
));
435-
let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
428+
fn _test_round_trip_export(batch: RecordBatch, schema: Arc<Schema>) -> Result<()> {
436429
let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
437430

438431
let reader = TestRecordBatchReader::new(schema.clone(), iter);
@@ -461,10 +454,12 @@ mod tests {
461454
}
462455

463456
let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap();
457+
let len = array.len();
464458

465-
let record_batch = RecordBatch::try_new(
459+
let record_batch = RecordBatch::try_new_with_options(
466460
SchemaRef::from(exported_schema.clone()),
467461
StructArray::from(array).into_parts().1,
462+
&RecordBatchOptions::new().with_row_count(Some(len)),
468463
)
469464
.unwrap();
470465
produced_batches.push(record_batch);
@@ -475,20 +470,7 @@ mod tests {
475470
Ok(())
476471
}
477472

478-
fn _test_round_trip_import(arrays: Vec<Arc<dyn Array>>) -> Result<()> {
479-
let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
480-
let schema = Arc::new(Schema::new_with_metadata(
481-
vec![
482-
Field::new("a", arrays[0].data_type().clone(), true)
483-
.with_metadata(metadata.clone()),
484-
Field::new("b", arrays[1].data_type().clone(), true)
485-
.with_metadata(metadata.clone()),
486-
Field::new("c", arrays[2].data_type().clone(), true)
487-
.with_metadata(metadata.clone()),
488-
],
489-
metadata,
490-
));
491-
let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap();
473+
fn _test_round_trip_import(batch: RecordBatch, schema: Arc<Schema>) -> Result<()> {
492474
let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _;
493475

494476
let reader = TestRecordBatchReader::new(schema.clone(), iter);
@@ -511,19 +493,40 @@ mod tests {
511493
}
512494

513495
#[test]
514-
fn test_stream_round_trip_export() -> Result<()> {
496+
fn test_stream_round_trip() {
515497
let array = Int32Array::from(vec![Some(2), None, Some(1), None]);
516498
let array: Arc<dyn Array> = Arc::new(array);
499+
let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
500+
501+
let schema = Arc::new(Schema::new_with_metadata(
502+
vec![
503+
Field::new("a", array.data_type().clone(), true).with_metadata(metadata.clone()),
504+
Field::new("b", array.data_type().clone(), true).with_metadata(metadata.clone()),
505+
Field::new("c", array.data_type().clone(), true).with_metadata(metadata.clone()),
506+
],
507+
metadata,
508+
));
509+
let batch = RecordBatch::try_new(schema.clone(), vec![array.clone(), array.clone(), array])
510+
.unwrap();
517511

518-
_test_round_trip_export(vec![array.clone(), array.clone(), array])
512+
_test_round_trip_export(batch.clone(), schema.clone()).unwrap();
513+
_test_round_trip_import(batch, schema).unwrap();
519514
}
520515

521516
#[test]
522-
fn test_stream_round_trip_import() -> Result<()> {
523-
let array = Int32Array::from(vec![Some(2), None, Some(1), None]);
524-
let array: Arc<dyn Array> = Arc::new(array);
517+
fn test_stream_round_trip_no_columns() {
518+
let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]);
519+
520+
let schema = Arc::new(Schema::new_with_metadata(Vec::<Field>::new(), metadata));
521+
let batch = RecordBatch::try_new_with_options(
522+
schema.clone(),
523+
Vec::<Arc<dyn Array>>::new(),
524+
&RecordBatchOptions::new().with_row_count(Some(10)),
525+
)
526+
.unwrap();
525527

526-
_test_round_trip_import(vec![array.clone(), array.clone(), array])
528+
_test_round_trip_export(batch.clone(), schema.clone()).unwrap();
529+
_test_round_trip_import(batch, schema).unwrap();
527530
}
528531

529532
#[test]

0 commit comments

Comments
 (0)