Skip to content

Commit c2daf38

Browse files
committed
GH-49697: [C++][CI] Check IPC file body bounds are in sync with decoder outcome
1 parent c16a8b2 commit c2daf38

File tree

3 files changed: +107 additions, −104 deletions
(2 of the 3 changed files are shown below; the third is not included in this capture.)

cpp/src/arrow/ipc/message.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,10 +423,12 @@ static Result<std::unique_ptr<Message>> ReadMessageInternal(
423423
body, file->ReadAt(offset + metadata_length, decoder.next_required_size()));
424424
}
425425

426-
if (body->size() < decoder.next_required_size()) {
427-
return Status::IOError("Expected to be able to read ",
428-
decoder.next_required_size(),
429-
" bytes for message body, got ", body->size());
426+
if (body->size() != decoder.next_required_size()) {
427+
// The streaming decoder got out of sync with the actual advertised
428+
// metadata and body size, which signals an invalid IPC file.
429+
return Status::IOError("Invalid IPC file: advertised body size is ", body->size(),
430+
", but message decoder expects to read ",
431+
decoder.next_required_size(), " bytes instead");
430432
}
431433
RETURN_NOT_OK(decoder.Consume(body));
432434
return result;

cpp/src/arrow/ipc/reader.cc

Lines changed: 100 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -122,14 +122,14 @@ Status InvalidMessageType(MessageType expected, MessageType actual) {
122122

123123
/// \brief Structure to keep common arguments to be passed
124124
struct IpcReadContext {
125-
IpcReadContext(DictionaryMemo* memo, const IpcReadOptions& option, bool swap,
125+
IpcReadContext(DictionaryMemo* memo, const IpcReadOptions& option, bool swap_endian,
126126
MetadataVersion version = MetadataVersion::V5,
127127
Compression::type kind = Compression::UNCOMPRESSED)
128128
: dictionary_memo(memo),
129129
options(option),
130130
metadata_version(version),
131131
compression(kind),
132-
swap_endian(swap) {}
132+
swap_endian(swap_endian) {}
133133

134134
DictionaryMemo* dictionary_memo;
135135

@@ -589,6 +589,7 @@ Status DecompressBuffers(Compression::type compression, const IpcReadOptions& op
589589
}
590590
AppendFrom(field->child_data);
591591
}
592+
// Dictionary buffers are decompressed separately (see ReadDictionary).
592593
}
593594

594595
BufferPtrVector Get(const ArrayDataVector& fields) && {
@@ -613,16 +614,91 @@ Status DecompressBuffers(Compression::type compression, const IpcReadOptions& op
613614
});
614615
}
615616

617+
// Helper class to run post-ArrayLoader steps:
618+
// buffer decompression, dictionary resolution, buffer re-alignment.
619+
struct RecordBatchLoader {
620+
Result<std::shared_ptr<RecordBatch>> CreateRecordBatch(ArrayDataVector columns) {
621+
ARROW_ASSIGN_OR_RAISE(auto filtered_columns, CreateColumns(std::move(columns)));
622+
623+
std::shared_ptr<Schema> filtered_schema;
624+
if (!inclusion_mask_.empty()) {
625+
FieldVector filtered_fields;
626+
for (int i = 0; i < schema_->num_fields(); ++i) {
627+
if (inclusion_mask_[i]) {
628+
filtered_fields.push_back(schema_->field(i));
629+
}
630+
}
631+
filtered_schema = schema(std::move(filtered_fields), schema_->metadata());
632+
} else {
633+
filtered_schema = schema_;
634+
}
635+
636+
return RecordBatch::Make(std::move(filtered_schema), batch_length_,
637+
std::move(filtered_columns));
638+
}
639+
640+
Result<ArrayDataVector> CreateColumns(ArrayDataVector columns,
641+
bool resolve_dictionaries = true) {
642+
if (resolve_dictionaries) {
643+
// Dictionary resolution needs to happen on the unfiltered columns,
644+
// because fields are mapped structurally (by path in the original schema).
645+
RETURN_NOT_OK(ResolveDictionaries(columns, *context_.dictionary_memo,
646+
context_.options.memory_pool));
647+
}
648+
649+
ArrayDataVector filtered_columns;
650+
if (!inclusion_mask_.empty()) {
651+
FieldVector filtered_fields;
652+
for (int i = 0; i < schema_->num_fields(); ++i) {
653+
if (inclusion_mask_[i]) {
654+
DCHECK_NE(columns[i], nullptr);
655+
filtered_columns.push_back(std::move(columns[i]));
656+
}
657+
}
658+
columns.clear();
659+
} else {
660+
filtered_columns = std::move(columns);
661+
}
662+
663+
if (context_.compression != Compression::UNCOMPRESSED) {
664+
RETURN_NOT_OK(
665+
DecompressBuffers(context_.compression, context_.options, &filtered_columns));
666+
}
667+
668+
// Swap endian if necessary
669+
if (context_.swap_endian) {
670+
for (auto& column : filtered_columns) {
671+
ARROW_ASSIGN_OR_RAISE(
672+
column, arrow::internal::SwapEndianArrayData(std::move(column),
673+
context_.options.memory_pool));
674+
}
675+
}
676+
if (context_.options.ensure_alignment != Alignment::kAnyAlignment) {
677+
for (auto& column : filtered_columns) {
678+
ARROW_ASSIGN_OR_RAISE(
679+
column,
680+
util::EnsureAlignment(
681+
std::move(column),
682+
// The numerical value of the enum is taken literally as byte alignment
683+
static_cast<int64_t>(context_.options.ensure_alignment),
684+
context_.options.memory_pool));
685+
}
686+
}
687+
return filtered_columns;
688+
}
689+
690+
IpcReadContext context_;
691+
std::shared_ptr<Schema> schema_;
692+
int64_t batch_length_;
693+
std::vector<bool> inclusion_mask_;
694+
};
695+
616696
Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
617697
const flatbuf::RecordBatch* metadata, const std::shared_ptr<Schema>& schema,
618698
const std::vector<bool>* inclusion_mask, const IpcReadContext& context,
619699
io::RandomAccessFile* file) {
620700
ArrayLoader loader(metadata, context.metadata_version, context.options, file);
621-
622701
ArrayDataVector columns(schema->num_fields());
623-
ArrayDataVector filtered_columns;
624-
FieldVector filtered_fields;
625-
std::shared_ptr<Schema> filtered_schema;
626702

627703
for (int i = 0; i < schema->num_fields(); ++i) {
628704
const Field& field = *schema->field(i);
@@ -634,52 +710,16 @@ Result<std::shared_ptr<RecordBatch>> LoadRecordBatchSubset(
634710
return Status::IOError("Array length did not match record batch length");
635711
}
636712
columns[i] = std::move(column);
637-
if (inclusion_mask) {
638-
filtered_columns.push_back(columns[i]);
639-
filtered_fields.push_back(schema->field(i));
640-
}
641713
} else {
642714
// Skip field. This logic must be executed to advance the state of the
643715
// loader to the next field
644716
RETURN_NOT_OK(loader.SkipField(&field));
645717
}
646718
}
647719

648-
// Dictionary resolution needs to happen on the unfiltered columns,
649-
// because fields are mapped structurally (by path in the original schema).
650-
RETURN_NOT_OK(ResolveDictionaries(columns, *context.dictionary_memo,
651-
context.options.memory_pool));
652-
653-
if (inclusion_mask) {
654-
filtered_schema = ::arrow::schema(std::move(filtered_fields), schema->metadata());
655-
columns.clear();
656-
} else {
657-
filtered_schema = schema;
658-
filtered_columns = std::move(columns);
659-
}
660-
if (context.compression != Compression::UNCOMPRESSED) {
661-
RETURN_NOT_OK(
662-
DecompressBuffers(context.compression, context.options, &filtered_columns));
663-
}
664-
665-
// swap endian in a set of ArrayData if necessary (swap_endian == true)
666-
if (context.swap_endian) {
667-
for (auto& filtered_column : filtered_columns) {
668-
ARROW_ASSIGN_OR_RAISE(filtered_column,
669-
arrow::internal::SwapEndianArrayData(filtered_column));
670-
}
671-
}
672-
auto batch = RecordBatch::Make(std::move(filtered_schema), metadata->length(),
673-
std::move(filtered_columns));
674-
675-
if (ARROW_PREDICT_FALSE(context.options.ensure_alignment != Alignment::kAnyAlignment)) {
676-
return util::EnsureAlignment(batch,
677-
// the numerical value of ensure_alignment enum is taken
678-
// literally as byte alignment
679-
static_cast<int64_t>(context.options.ensure_alignment),
680-
context.options.memory_pool);
681-
}
682-
return batch;
720+
RecordBatchLoader batch_loader{context, schema, metadata->length(),
721+
inclusion_mask ? *inclusion_mask : std::vector<bool>{}};
722+
return batch_loader.CreateRecordBatch(std::move(columns));
683723
}
684724

685725
Result<std::shared_ptr<RecordBatch>> LoadRecordBatch(
@@ -845,7 +885,7 @@ Status UnpackSchemaMessage(const Message& message, const IpcReadOptions& options
845885
out_schema, field_inclusion_mask, swap_endian);
846886
}
847887

848-
Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
888+
Status ReadDictionary(const Buffer& metadata, IpcReadContext context,
849889
DictionaryKind* kind, io::RandomAccessFile* file) {
850890
const flatbuf::Message* message = nullptr;
851891
RETURN_NOT_OK(internal::VerifyMessage(metadata.data(), metadata.size(), &message));
@@ -860,13 +900,12 @@ Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
860900

861901
CHECK_FLATBUFFERS_NOT_NULL(batch_meta, "DictionaryBatch.data");
862902

863-
Compression::type compression;
864-
RETURN_NOT_OK(GetCompression(batch_meta, &compression));
865-
if (compression == Compression::UNCOMPRESSED &&
903+
RETURN_NOT_OK(GetCompression(batch_meta, &context.compression));
904+
if (context.compression == Compression::UNCOMPRESSED &&
866905
message->version() == flatbuf::MetadataVersion::MetadataVersion_V4) {
867906
// Possibly obtain codec information from experimental serialization format
868907
// in 0.17.x
869-
RETURN_NOT_OK(GetCompressionExperimental(message, &compression));
908+
RETURN_NOT_OK(GetCompressionExperimental(message, &context.compression));
870909
}
871910

872911
const int64_t id = dictionary_batch->id();
@@ -882,16 +921,14 @@ Status ReadDictionary(const Buffer& metadata, const IpcReadContext& context,
882921
const Field dummy_field("", value_type);
883922
RETURN_NOT_OK(loader.Load(&dummy_field, dict_data.get()));
884923

885-
if (compression != Compression::UNCOMPRESSED) {
886-
ArrayDataVector dict_fields{dict_data};
887-
RETURN_NOT_OK(DecompressBuffers(compression, context.options, &dict_fields));
888-
}
889-
890-
// swap endian in dict_data if necessary (swap_endian == true)
891-
if (context.swap_endian) {
892-
ARROW_ASSIGN_OR_RAISE(dict_data, ::arrow::internal::SwapEndianArrayData(
893-
dict_data, context.options.memory_pool));
894-
}
924+
// Run post-load steps: buffer decompression, etc.
925+
RecordBatchLoader batch_loader{context, /*schema=*/nullptr, batch_meta->length(),
926+
/*inclusion_mask=*/std::vector<bool>{}};
927+
ARROW_ASSIGN_OR_RAISE(
928+
auto dict_columns,
929+
batch_loader.CreateColumns({dict_data}, /*resolve_dictionaries=*/false));
930+
DCHECK_EQ(dict_columns.size(), 1);
931+
dict_data = dict_columns[0];
895932

896933
if (dictionary_batch->isDelta()) {
897934
if (kind != nullptr) {
@@ -1756,32 +1793,22 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
17561793
std::shared_ptr<Schema> out_schema;
17571794
RETURN_NOT_OK(GetInclusionMaskAndOutSchema(schema, context.options.included_fields,
17581795
&inclusion_mask, &out_schema));
1759-
17601796
for (int i = 0; i < schema->num_fields(); ++i) {
17611797
const Field& field = *schema->field(i);
1762-
if (inclusion_mask.size() == 0 || inclusion_mask[i]) {
1798+
if (inclusion_mask.empty() || inclusion_mask[i]) {
17631799
// Read field
17641800
auto column = std::make_shared<ArrayData>();
17651801
RETURN_NOT_OK(loader.Load(&field, column.get()));
17661802
if (length != column->length) {
17671803
return Status::IOError("Array length did not match record batch length");
17681804
}
17691805
columns[i] = std::move(column);
1770-
if (inclusion_mask.size() > 0) {
1771-
filtered_columns.push_back(columns[i]);
1772-
filtered_fields.push_back(schema->field(i));
1773-
}
17741806
} else {
17751807
// Skip field. This logic must be executed to advance the state of the
17761808
// loader to the next field
17771809
RETURN_NOT_OK(loader.SkipField(&field));
17781810
}
17791811
}
1780-
if (inclusion_mask.size() > 0) {
1781-
filtered_schema = ::arrow::schema(std::move(filtered_fields), schema->metadata());
1782-
} else {
1783-
filtered_schema = schema;
1784-
}
17851812
return Status::OK();
17861813
}
17871814

@@ -1798,31 +1825,8 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
17981825
}
17991826
loader.read_request().FulfillRequest(buffers);
18001827

1801-
// Dictionary resolution needs to happen on the unfiltered columns,
1802-
// because fields are mapped structurally (by path in the original schema).
1803-
RETURN_NOT_OK(ResolveDictionaries(columns, *context.dictionary_memo,
1804-
context.options.memory_pool));
1805-
if (inclusion_mask.size() > 0) {
1806-
columns.clear();
1807-
} else {
1808-
filtered_columns = std::move(columns);
1809-
}
1810-
1811-
if (context.compression != Compression::UNCOMPRESSED) {
1812-
RETURN_NOT_OK(
1813-
DecompressBuffers(context.compression, context.options, &filtered_columns));
1814-
}
1815-
1816-
// swap endian in a set of ArrayData if necessary (swap_endian == true)
1817-
if (context.swap_endian) {
1818-
for (int i = 0; i < static_cast<int>(filtered_columns.size()); ++i) {
1819-
ARROW_ASSIGN_OR_RAISE(filtered_columns[i],
1820-
arrow::internal::SwapEndianArrayData(
1821-
filtered_columns[i], context.options.memory_pool));
1822-
}
1823-
}
1824-
return RecordBatch::Make(std::move(filtered_schema), length,
1825-
std::move(filtered_columns));
1828+
RecordBatchLoader batch_loader{context, schema, length, std::move(inclusion_mask)};
1829+
return batch_loader.CreateRecordBatch(std::move(columns));
18261830
}
18271831

18281832
std::shared_ptr<Schema> schema;
@@ -1834,9 +1838,6 @@ class RecordBatchFileReaderImpl : public RecordBatchFileReader {
18341838
ArrayDataVector columns;
18351839
io::internal::ReadRangeCache cache;
18361840
int64_t length;
1837-
ArrayDataVector filtered_columns;
1838-
FieldVector filtered_fields;
1839-
std::shared_ptr<Schema> filtered_schema;
18401841
std::vector<bool> inclusion_mask;
18411842
};
18421843

Comments (0) — this commit has no comments.