Skip to content

Commit 3b2d101

Browse files
committed
GH-49697: [C++][CI] Check IPC file body bounds are in sync with decoder outcome
1 parent c16a8b2 commit 3b2d101

File tree

3 files changed

+106
-104
lines changed

3 files changed

+106
-104
lines changed

cpp/src/arrow/ipc/message.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,12 @@ static Result<std::unique_ptr<Message>> ReadMessageInternal(
423423
body, file->ReadAt(offset + metadata_length, decoder.next_required_size()));
424424
}
425425

426-
if (body->size() < decoder.next_required_size()) {
427-
return Status::IOError("Expected to be able to read ",
428-
decoder.next_required_size(),
429-
" bytes for message body, got ", body->size());
426+
if (body->size() != decoder.next_required_size()) {
427+
// The streaming decoder got out of sync with the actual advertised
428+
// metadata and body size, which signals an invalid IPC file.
429+
return Status::IOError("Invalid IPC file: advertised body size is ", body->size(),
430+
", but message decoder expects to read ",
431+
decoder.next_required_size(), " bytes instead");
430432
}
431433
RETURN_NOT_OK(decoder.Consume(body));
432434
return result;

cpp/src/arrow/ipc/reader.cc

Lines changed: 99 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,14 @@ Status InvalidMessageType(MessageType expected, MessageType actual) {
122122

123123
/// \brief Structure to keep common arguments to be passed
124124
struct IpcReadContext {
125-
IpcReadContext(DictionaryMemo* memo, const IpcReadOptions& option, bool swap,
125+
IpcReadContext(DictionaryMemo* memo, const IpcReadOptions& option, bool swap_endian,
126126
MetadataVersion version = MetadataVersion::V5,
127127
Compression::type kind = Compression::UNCOMPRESSED)
128128
: dictionary_memo(memo),
129129
options(option),
130130
metadata_version(version),
131131
compression(kind),
132-
swap_endian(swap) {}
132+
swap_endian(swap_endian) {}
133133

134134
DictionaryMemo* dictionary_memo;
135135

@@ -589,6 +589,7 @@ Status DecompressBuffers(Compression::type compression, const IpcReadOptions& op
589589
}
590590
AppendFrom(field->child_data);
591591
}
592+
// Dictionary buffers are decompressed separately (see ReadDictionary).
592593
}
593594

594595
BufferPtrVector Get(const ArrayDataVector& fields) && {
@@ -613,16 +614,90 @@ Status DecompressBuffers(Compression::type compression, const IpcReadOptions& op
613614
});
614615
}
615616

617+
// Helper class to run post-ArrayLoader steps:
618+
// buffer decompression, dictionary resolution, buffer re-alignment.
619+
struct RecordBatchLoader {
620+
Result<std::shared_ptr<RecordBatch>> CreateRecordBatch(ArrayDataVector columns) {
621+
ARROW_ASSIGN_OR_RAISE(auto filtered_columns, CreateColumns(std::move(columns)));
622+
623+
std::shared_ptr<Schema> filtered_schema;
624+
if (!inclusion_mask_.empty()) {
625+
FieldVector filtered_fields;
626+
for (int i = 0; i < schema_->num_fields(); ++i) {
627+
if (inclusion_mask_[i]) {
628+
filtered_fields.push_back(schema_->field(i));
629+
}
630+
}
631+
filtered_schema = schema(std::move(filtered_fields), schema_->metadata());
632+
} else {
633+
filtered_schema = schema_;
634+
}
635+
636+
return RecordBatch::Make(std::move(filtered_schema), batch_length_,
637+
std::move(filtered_columns));
638+
}
639+
640+
Result<ArrayDataVector> CreateColumns(ArrayDataVector columns,
641+
bool resolve_dictionaries = true) {
642+
if (resolve_dictionaries) {
643+
// Dictionary resolution needs to happen on the unfiltered columns,
644+
// because fields are mapped structurally (by path in the original schema).
645+
RETURN_NOT_OK(ResolveDictionaries(columns, *context_.dictionary_memo,
646+
context_.options.memory_pool));
647+
}
648+
649+
ArrayDataVector filtered_columns;
650+
if (!inclusion_mask_.empty()) {
651+
FieldVector filtered_fields;
652+
for (int i = 0; i < schema_->num_fields(); ++i) {
653+
if (inclusion_mask_[i]) {
654+
filtered_columns.push_back(std::move(columns[i]));
655+
}
656+
}
657+
columns.clear();
658+
} else {
659+
filtered_columns = std::move(columns);
660+
}
661+
662+
if (context_.compression != Compression::UNCOMPRESSED) {
663+
RETURN_NOT_OK(
664+
DecompressBuffers(context_.compression, context_.options, &filtered_columns));
665+
}
666+
667+
// Swap endian if necessary
668+
if (context_.swap_endian) {
669+
for (auto& column : filtered_columns) {
670+
ARROW_ASSIGN_OR_RAISE(
671+
column, arrow::internal::SwapEndianArrayData(std::move(column),
672+
context_.options.memory_pool));
673+
}
674+
}
675+
if (context_.options.ensure_alignment != Alignment::kAnyAlignment) {
676+
for (auto& column : filtered_columns) {
677+
ARROW_ASSIGN_OR_RAISE(
678+
column,
679+
util::EnsureAlignment(
680+
std::move(column),
681+
// The numerical value of the enum is taken literally as byte alignment
682+
static_cast<int64_t>(context_.options.ensure_alignment),
683+
context_.options.memory_pool));
684+
}
685+
}
686+
return filtered_columns;
687+
}
688+
689+
IpcReadContext context_;
690+
std::shared_ptr<Schema> schema_;
691+
int64_t batch_length_;
692+
std::vector<bool> inclusion_mask_;
693+
};
694+
616695
Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
617696
const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
618697
const std::vector<bool>* inclusion_mask, const IpcReadContext& context,
619698
io::RandomAccessFile* file) {
620699
ArrayLoader loader(metadata, context.metadata_version, context.options, file);
621-
622700
ArrayDataVector columns(schema->num_fields());
623-
ArrayDataVector filtered_columns;
624-
FieldVector filtered_fields;
625-
std::shared_ptr<Schema> filtered_schema;
626701

627702
for (int i = 0; i < schema->num_fields(); ++i) {
628703
const Field& field = *schema->field(i);
@@ -634,52 +709,16 @@ Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
634709
return Status::IOError("Array length did not match record batch length");
635710
}
636711
columns[i] = std::move(column);
637-
if (inclusion_mask) {
638-
filtered_columns.push_back(columns[i]);
639-
filtered_fields.push_back(schema->field(i));
640-
}
641712
} else {
642713
// Skip field. This logic must be executed to advance the state of the
643714
// loader to the next field
644715
RETURN_NOT_OK(loader.SkipField(&field));
645716
}
646717
}
647718

648-
// Dictionary resolution needs to happen on the unfiltered columns,
649-
// because fields are mapped structurally (by path in the original schema).
650-
RETURN_NOT_OK(ResolveDictionaries(columns, *context.dictionary_memo,
651-
context.options.memory_pool));
652-
653-
if (inclusion_mask) {
654-
filtered_schema = ::arrow::schema(std::move(filtered_fields), schema->metadata());
655-
columns.clear();
656-
} else {
657-
filtered_schema = schema;
658-
filtered_columns = std::move(columns);
659-
}
660-
if (context.compression != Compression::UNCOMPRESSED) {
661-
RETURN_NOT_OK(
662-
DecompressBuffers(context.compression, context.options, &filtered_columns));
663-
}
664-
665-
// swap endian in a set of ArrayData if necessary (swap_endian == true)
666-
if (context.swap_endian) {
667-
for (auto& filtered_column : filtered_columns) {
668-
ARROW_ASSIGN_OR_RAISE(filtered_column,
669-
arrow::internal::SwapEndianArrayData(filtered_column));
670-
}
671-
}
672-
auto batch = RecordBatch::Make(std::move(filtered_schema), metadata->length(),
673-
std::move(filtered_columns));
674-
675-
if (ARROW_PREDICT_FALSE(context.options.ensure_alignment != Alignment::kAnyAlignment)) {
676-
return util::EnsureAlignment(batch,
677-
// the numerical value of ensure_alignment enum is taken
678-
// literally as byte alignment
679-
static_cast<int64_t>(context.options.ensure_alignment),
680-
context.options.memory_pool);
681-
}
682-
return batch;
719+
RecordBatchLoader batch_loader{context, schema, metadata->length(),
720+
inclusion_mask ? *inclusion_mask : std::vector<bool>{}};
721+
return batch_loader.CreateRecordBatch(std::move(columns));
683722
}
684723

685724
Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
@@ -845,7 +884,7 @@ Status UnpackSchemaMessage(const Message& message, const IpcReadOptions& options
845884
out_schema, field_inclusion_mask, swap_endian);
846885
}
847886

848-
Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
887+
Status ReadDictionary(const Buffer& metadata, IpcReadContext context,
849888
DictionaryKind* kind, io::RandomAccessFile* file) {
850889
const flatbuf::Message* message = nullptr;
851890
RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
@@ -860,13 +899,12 @@ Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
860899

861900
CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data");
862901

863-
Compression::type compression;
864-
RETURN_NOT_OK(GetCompression(batch_meta, &compression));
865-
if (compression == Compression::UNCOMPRESSED &&
902+
RETURN_NOT_OK(GetCompression(batch_meta, &context.compression));
903+
if (context.compression == Compression::UNCOMPRESSED &&
866904
message->version() == flatbuf::MetadataVersion::MetadataVersion_V4) {
867905
// Possibly obtain codec information from experimental serialization format
868906
// in 0.17.x
869-
RETURN_NOT_OK(GetCompressionExperimental(message, &compression));
907+
RETURN_NOT_OK(GetCompressionExperimental(message, &context.compression));
870908
}
871909

872910
const int64_t id = dictionary_batch->id();
@@ -882,16 +920,14 @@ Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
882920
const Field dummy_field("", value_type);
883921
RETURN_NOT_OK(loader.Load(&dummy_field, dict_data.get()));
884922

885-
if (compression != Compression::UNCOMPRESSED) {
886-
ArrayDataVector dict_fields{dict_data};
887-
RETURN_NOT_OK(DecompressBuffers(compression, context.options, &dict_fields));
888-
}
889-
890-
// swap endian in dict_data if necessary (swap_endian == true)
891-
if (context.swap_endian) {
892-
ARROW_ASSIGN_OR_RAISE(dict_data, ::arrow::internal::SwapEndianArrayData(
893-
dict_data, context.options.memory_pool));
894-
}
923+
// Run post-load steps: buffer decompression, etc.
924+
RecordBatchLoader batch_loader{context, /*schema=*/nullptr, batch_meta->length(),
925+
/*inclusion_mask=*/std::vector<bool>{}};
926+
ARROW_ASSIGN_OR_RAISE(
927+
auto dict_columns,
928+
batch_loader.CreateColumns({dict_data}, /*resolve_dictionaries=*/false));
929+
DCHECK_EQ(dict_columns.size(), 1);
930+
dict_data = dict_columns[0];
895931

896932
if (dictionary_batch->isDelta()) {
897933
if (kind != nullptr) {
@@ -1756,32 +1792,22 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
17561792
std::shared_ptr<Schema> out_schema;
17571793
RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, context.options.included_fields,
17581794
&inclusion_mask, &out_schema));
1759-
17601795
for (int i = 0; i < schema->num_fields(); ++i) {
17611796
const Field& field = *schema->field(i);
1762-
if (inclusion_mask.size() == 0 || inclusion_mask[i]) {
1797+
if (inclusion_mask.empty() || inclusion_mask[i]) {
17631798
// Read field
17641799
auto column = std::make_shared<ArrayData>();
17651800
RETURN_NOT_OK(loader.Load(&field, column.get()));
17661801
if (length != column->length) {
17671802
return Status::IOError("Array length did not match record batch length");
17681803
}
17691804
columns[i] = std::move(column);
1770-
if (inclusion_mask.size() > 0) {
1771-
filtered_columns.push_back(columns[i]);
1772-
filtered_fields.push_back(schema->field(i));
1773-
}
17741805
} else {
17751806
// Skip field. This logic must be executed to advance the state of the
17761807
// loader to the next field
17771808
RETURN_NOT_OK(loader.SkipField(&field));
17781809
}
17791810
}
1780-
if (inclusion_mask.size() > 0) {
1781-
filtered_schema = ::arrow::schema(std::move(filtered_fields), schema->metadata());
1782-
} else {
1783-
filtered_schema = schema;
1784-
}
17851811
return Status::OK();
17861812
}
17871813

@@ -1798,31 +1824,8 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
17981824
}
17991825
loader.read_request().FulfillRequest(buffers);
18001826

1801-
// Dictionary resolution needs to happen on the unfiltered columns,
1802-
// because fields are mapped structurally (by path in the original schema).
1803-
RETURN_NOT_OK(ResolveDictionaries(columns, *context.dictionary_memo,
1804-
context.options.memory_pool));
1805-
if (inclusion_mask.size() > 0) {
1806-
columns.clear();
1807-
} else {
1808-
filtered_columns = std::move(columns);
1809-
}
1810-
1811-
if (context.compression != Compression::UNCOMPRESSED) {
1812-
RETURN_NOT_OK(
1813-
DecompressBuffers(context.compression, context.options, &filtered_columns));
1814-
}
1815-
1816-
// swap endian in a set of ArrayData if necessary (swap_endian == true)
1817-
if (context.swap_endian) {
1818-
for (int i = 0; i < static_cast<int>(filtered_columns.size()); ++i) {
1819-
ARROW_ASSIGN_OR_RAISE(filtered_columns[i],
1820-
arrow::internal::SwapEndianArrayData(
1821-
filtered_columns[i], context.options.memory_pool));
1822-
}
1823-
}
1824-
return RecordBatch::Make(std::move(filtered_schema), length,
1825-
std::move(filtered_columns));
1827+
RecordBatchLoader batch_loader{context, schema, length, std::move(inclusion_mask)};
1828+
return batch_loader.CreateRecordBatch(std::move(columns));
18261829
}
18271830

18281831
std::shared_ptr<Schema> schema;
@@ -1834,9 +1837,6 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
18341837
ArrayDataVector columns;
18351838
io::internal::ReadRangeCache cache;
18361839
int64_t length;
1837-
ArrayDataVector filtered_columns;
1838-
FieldVector filtered_fields;
1839-
std::shared_ptr<Schema> filtered_schema;
18401840
std::vector<bool> inclusion_mask;
18411841
};
18421842

0 commit comments

Comments
 (0)