Skip to content

Commit d5338da

Browse files
authored
Fix tensor external data info length parsing issue. (#23526)
Fix tensor external data info length parsing issue. The old implementation was parsing a `size_t` value with `strtol` (via `OrtStrToPtrDiff`) on ARM64 MSVC (see https://github.com/microsoft/onnxruntime/blob/bf023ab3d565668c13a5334b505df0eb6acf3625/onnxruntime/core/platform/path_lib.h#L74 ). If we have `sizeof(size_t) == 8` and `sizeof(long) == 4` (as is the case for x64 and ARM64 MSVC), `strtol` will return a maximum value of `2^31-1` even for a larger, valid `size_t` value. `strtol` will also set `errno` to `ERANGE`, but we weren't checking that. Updated to use `ParseStringWithClassicLocale`, which parses directly to the target type. Added some tests.
1 parent e3e4173 commit d5338da

File tree

3 files changed

+88
-30
lines changed

3 files changed

+88
-30
lines changed

onnxruntime/core/framework/tensor_external_data_info.cc

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "tensor_external_data_info.h"
55
#include "core/common/common.h"
66
#include "core/common/narrow.h"
7+
#include "core/common/parse_string.h"
78
#include "core/common/safeint.h"
89
#include "core/common/string_utils.h"
910
#include "core/platform/path_lib.h"
@@ -18,21 +19,8 @@ using ::ONNX_NAMESPACE::StringStringEntryProto;
1819

1920
namespace onnxruntime {
2021
Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>& input,
21-
std::unique_ptr<ExternalDataInfo>& out) {
22-
auto str_to_int = [](const std::string& s, OFFSET_TYPE& result) -> Status {
23-
char* end;
24-
#ifdef _WIN32
25-
result = _strtoi64(s.c_str(), &end, 10);
26-
#else
27-
result = OrtStrToPtrDiff(s.c_str(), &end);
28-
#endif
29-
if (end != s.c_str() + s.length()) {
30-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", s, " failed");
31-
}
32-
return Status::OK();
33-
};
34-
35-
out = std::make_unique<ExternalDataInfo>();
22+
std::unique_ptr<ExternalDataInfo>& external_data_info_result) {
23+
auto external_data_info = std::make_unique<ExternalDataInfo>();
3624
PrepackedInfos prepacked_infos;
3725

3826
const int input_size = input.size();
@@ -43,17 +31,15 @@ Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>&
4331
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Need a key for the external data info");
4432
if (!stringmap.has_value())
4533
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Need a value for the external data info");
34+
4635
if (stringmap.key() == "location" && !stringmap.value().empty()) {
47-
out->rel_path_ = ToWideString(stringmap.value());
36+
external_data_info->rel_path_ = ToWideString(stringmap.value());
4837
} else if (stringmap.key() == "offset" && !stringmap.value().empty()) {
49-
ORT_RETURN_IF_ERROR(str_to_int(stringmap.value(), out->offset_));
38+
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(stringmap.value(), external_data_info->offset_));
5039
} else if (stringmap.key() == "length" && !stringmap.value().empty()) {
51-
char* end;
52-
out->length_ = narrow<size_t>(OrtStrToPtrDiff(stringmap.value().c_str(), &end));
53-
if (end != stringmap.value().c_str() + stringmap.value().length())
54-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "parsing ", stringmap.value(), " failed");
40+
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(stringmap.value(), external_data_info->length_));
5541
} else if (stringmap.key() == "checksum" && !stringmap.value().empty()) {
56-
out->checksum_ = stringmap.value();
42+
external_data_info->checksum_ = stringmap.value();
5743
} else if (stringmap.key().find("prepacked", 0) == 0) {
5844
// Starts with 'prepacked', each has its own key.
5945
// Each prepacked entry may have multiple blobs with the same key
@@ -72,10 +58,11 @@ Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>&
7258
const auto& blob = split_fields[f];
7359
auto blob_fields = utils::SplitString(blob, ";", false);
7460
if (blob_fields.size() == 3) {
75-
OFFSET_TYPE offset, len;
76-
ORT_RETURN_IF_ERROR(str_to_int(std::string(blob_fields[0]), offset));
77-
ORT_RETURN_IF_ERROR(str_to_int(std::string(blob_fields[1]), len));
78-
blob_infos.push_back(std::make_tuple(offset, narrow<size_t>(len), std::string(blob_fields[2])));
61+
OFFSET_TYPE offset;
62+
size_t len;
63+
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(blob_fields[0], offset));
64+
ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(blob_fields[1], len));
65+
blob_infos.push_back(std::make_tuple(offset, len, std::string(blob_fields[2])));
7966
}
8067
}
8168
if (blob_infos.empty()) {
@@ -88,14 +75,15 @@ Status ExternalDataInfo::Create(const RepeatedPtrField<StringStringEntryProto>&
8875
}
8976
}
9077

91-
if (out->rel_path_.empty()) {
78+
if (external_data_info->rel_path_.empty()) {
9279
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "model format error! Missing 'location'");
9380
}
9481

9582
if (!prepacked_infos.empty()) {
96-
out->prepacked_infos_ = std::move(prepacked_infos);
83+
external_data_info->prepacked_infos_ = std::move(prepacked_infos);
9784
}
9885

86+
external_data_info_result = std::move(external_data_info);
9987
return Status::OK();
10088
}
10189
void ExternalDataInfo::SetExternalLocationToProto(const std::filesystem::path& external_file_path,

onnxruntime/core/framework/tensor_external_data_info.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,6 @@ class ExternalDataInfo {
3232

3333
const std::string& GetChecksum() const { return checksum_; }
3434

35-
// If the value of 'offset' or 'length' field is larger the max value of ssize_t, this function will treat it as a
36-
// wrong value and return FAIL.
3735
static common::Status Create(
3836
const ::google::protobuf::RepeatedPtrField<::ONNX_NAMESPACE::StringStringEntryProto>& input,
3937
std::unique_ptr<ExternalDataInfo>& out);

onnxruntime/test/framework/tensorutils_test.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22
// Licensed under the MIT License.
33

44
#include "core/common/inlined_containers.h"
5+
#include "core/common/parse_string.h"
56
#include "core/framework/prepacked_weights.h"
67
#include "core/framework/prepacked_weights_container.h"
78
#include "core/framework/tensorprotoutils.h"
89
#include "core/graph/onnx_protobuf.h"
910
#include "test/util/include/asserts.h"
1011
#include "file_util.h"
1112

13+
#include <cstdint>
14+
#include <limits>
15+
1216
#include "gtest/gtest.h"
1317
#include "gmock/gmock.h"
1418

@@ -22,6 +26,74 @@ using namespace ONNX_NAMESPACE;
2226
namespace onnxruntime {
2327
namespace test {
2428

29+
// if `expected_error_message_substring` is nullptr, parsing is expected to be successful
30+
static void TestExternalDataInfoParsingOffsetAndLengthWithStrings(
31+
std::string_view offset_str,
32+
std::string_view length_str,
33+
const char* expected_error_message_substring = nullptr) {
34+
SCOPED_TRACE(MakeString("offset: \"", offset_str, "\", length: \"", length_str, "\""));
35+
36+
ONNX_NAMESPACE::TensorProto tensor_proto;
37+
const std::filesystem::path kExternalDataPath("test.bin");
38+
39+
tensor_proto.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL);
40+
41+
auto* location_entry = tensor_proto.add_external_data();
42+
location_entry->set_key("location");
43+
location_entry->set_value(ToUTF8String(kExternalDataPath.native()));
44+
45+
auto* offset_entry = tensor_proto.add_external_data();
46+
offset_entry->set_key("offset");
47+
offset_entry->set_value(offset_str.data(), offset_str.size());
48+
49+
auto* length_entry = tensor_proto.add_external_data();
50+
length_entry->set_key("length");
51+
length_entry->set_value(length_str.data(), length_str.size());
52+
53+
std::unique_ptr<ExternalDataInfo> external_data_info{};
54+
const auto create_status = ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info);
55+
if (expected_error_message_substring) {
56+
ASSERT_STATUS_NOT_OK_AND_HAS_SUBSTR(create_status, expected_error_message_substring);
57+
return;
58+
}
59+
ASSERT_STATUS_OK(create_status);
60+
61+
// if we got this far, assume that offset_str and length_str are able to be parsed.
62+
const auto expected_offset = ParseStringWithClassicLocale<ExternalDataInfo::OFFSET_TYPE>(offset_str);
63+
const auto expected_length = ParseStringWithClassicLocale<size_t>(length_str);
64+
65+
ASSERT_EQ(external_data_info->GetOffset(), expected_offset);
66+
ASSERT_EQ(external_data_info->GetLength(), expected_length);
67+
}
68+
69+
// if `expected_error_message_substring` is nullptr, parsing is expected to be successful
70+
static void TestExternalDataInfoParsingOffsetAndLength(intmax_t offset,
71+
uintmax_t length,
72+
const char* expected_error_message_substring = nullptr) {
73+
TestExternalDataInfoParsingOffsetAndLengthWithStrings(std::to_string(offset), std::to_string(length),
74+
expected_error_message_substring);
75+
}
76+
77+
TEST(TensorProtoUtilsTest, ParseExternalDataInfoOffsetAndLength) {
78+
TestExternalDataInfoParsingOffsetAndLength(0, 0);
79+
80+
TestExternalDataInfoParsingOffsetAndLength(0, 1024);
81+
TestExternalDataInfoParsingOffsetAndLength(0, std::numeric_limits<size_t>::max());
82+
83+
TestExternalDataInfoParsingOffsetAndLength(1024, 1024);
84+
TestExternalDataInfoParsingOffsetAndLength(std::numeric_limits<ExternalDataInfo::OFFSET_TYPE>::max(), 1024);
85+
86+
{
87+
// assuming that this value is too large to fit in either size_t or ExternalDataInfo::OFFSET_TYPE
88+
const std::string_view two_to_the_65th_power = "36893488147419103232";
89+
const std::string_view zero = "0";
90+
TestExternalDataInfoParsingOffsetAndLengthWithStrings(two_to_the_65th_power, zero, "Failed to parse value");
91+
TestExternalDataInfoParsingOffsetAndLengthWithStrings(zero, two_to_the_65th_power, "Failed to parse value");
92+
}
93+
94+
// TODO should ExternalDataInfo::Create() also reject negative offset values?
95+
}
96+
2597
// Test ExternalData functionality
2698
TEST(TensorProtoUtilsTest, SetExternalDataInformation) {
2799
ONNX_NAMESPACE::TensorProto tensor_proto;

0 commit comments

Comments
 (0)