diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index f38e785bb..eb8801cf1 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -1392,6 +1392,20 @@ class MooncakeHostMemAllocatorPyWrapper { }; PYBIND11_MODULE(store, m) { + // Object data type classification + py::enum_(m, "ObjectDataType") + .value("UNKNOWN", ObjectDataType::UNKNOWN) + .value("KVCACHE", ObjectDataType::KVCACHE) + .value("TENSOR", ObjectDataType::TENSOR) + .value("WEIGHT", ObjectDataType::WEIGHT) + .value("SAMPLE", ObjectDataType::SAMPLE) + .value("ACTIVATION", ObjectDataType::ACTIVATION) + .value("GRADIENT", ObjectDataType::GRADIENT) + .value("OPTIMIZER_STATE", ObjectDataType::OPTIMIZER_STATE) + .value("METADATA", ObjectDataType::METADATA) + .value("GENERAL", ObjectDataType::GENERAL) + .export_values(); + // Define the ReplicateConfig class py::class_(m, "ReplicateConfig") .def(py::init<>()) @@ -1402,6 +1416,7 @@ PYBIND11_MODULE(store, m) { .def_readwrite("preferred_segment", &ReplicateConfig::preferred_segment) .def_readwrite("prefer_alloc_in_same_node", &ReplicateConfig::prefer_alloc_in_same_node) + .def_readwrite("data_type", &ReplicateConfig::data_type) .def("__str__", [](const ReplicateConfig &config) { std::ostringstream oss; oss << config; diff --git a/mooncake-store/include/master_service.h b/mooncake-store/include/master_service.h index 9acc0fe1c..94d8b4810 100644 --- a/mooncake-store/include/master_service.h +++ b/mooncake-store/include/master_service.h @@ -546,10 +546,12 @@ class MasterService { const UUID& client_id_, const std::chrono::system_clock::time_point put_start_time_, size_t value_length, std::vector&& reps, - bool enable_soft_pin, bool enable_hard_pin = false) + bool enable_soft_pin, bool enable_hard_pin = false, + ObjectDataType data_type_ = ObjectDataType::UNKNOWN) : client_id(client_id_), put_start_time(put_start_time_), size(value_length), + data_type(data_type_), 
lease_timeout(), soft_pin_timeout(std::nullopt), hard_pinned(enable_hard_pin), @@ -572,6 +574,7 @@ class MasterService { // Updated by UpsertStart (Case B) to reset the discard timeout. std::chrono::system_clock::time_point put_start_time; const size_t size; + const ObjectDataType data_type{ObjectDataType::UNKNOWN}; mutable SpinLock lock; // Default constructor, creates a time_point representing @@ -975,7 +978,8 @@ class MasterService { void Create(const UUID& client_id, uint64_t total_length, std::vector replicas, bool enable_soft_pin, - bool enable_hard_pin = false) { + bool enable_hard_pin = false, + ObjectDataType data_type = ObjectDataType::UNKNOWN) { if (Exists()) { throw std::logic_error("Already exists"); } @@ -984,7 +988,7 @@ class MasterService { std::piecewise_construct, std::forward_as_tuple(key_), std::forward_as_tuple(client_id, now, total_length, std::move(replicas), enable_soft_pin, - enable_hard_pin)); + enable_hard_pin, data_type)); it_ = result.first; } diff --git a/mooncake-store/include/replica.h b/mooncake-store/include/replica.h index dba29c3bb..f3d383c74 100644 --- a/mooncake-store/include/replica.h +++ b/mooncake-store/include/replica.h @@ -90,6 +90,7 @@ struct ReplicateConfig { std::string preferred_segment{}; // Deprecated: Single preferred segment // for backward compatibility bool prefer_alloc_in_same_node{false}; + ObjectDataType data_type{ObjectDataType::UNKNOWN}; friend std::ostream& operator<<(std::ostream& os, const ReplicateConfig& config) noexcept { @@ -107,7 +108,8 @@ struct ReplicateConfig { << config.preferred_segment; } os << ", prefer_alloc_in_same_node: " - << config.prefer_alloc_in_same_node << " }"; + << config.prefer_alloc_in_same_node + << ", data_type: " << config.data_type << " }"; return os; } }; diff --git a/mooncake-store/include/types.h b/mooncake-store/include/types.h index cd6bd53dd..898d99ca5 100644 --- a/mooncake-store/include/types.h +++ b/mooncake-store/include/types.h @@ -5,6 +5,7 @@ #include #include 
#include +#include #include #include #include @@ -123,6 +124,65 @@ static constexpr uint64_t DEFAULT_PROCESSING_TASK_TIMEOUT_SEC = 300; // 0 to be no timeout static constexpr uint32_t DEFAULT_MAX_RETRY_ATTEMPTS = 10; +/** + * @brief Data type classification for objects stored in Mooncake Store. + * + * This allows the store to track what kind of data each object holds, + * enabling future type-aware policies (eviction priority, replication + * strategies, etc.). Defaults to UNKNOWN for backward compatibility. + */ +enum class ObjectDataType : uint8_t { + UNKNOWN = 0, + KVCACHE = 1, + TENSOR = 2, + WEIGHT = 3, + SAMPLE = 4, + ACTIVATION = 5, + GRADIENT = 6, + OPTIMIZER_STATE = 7, + METADATA = 8, + GENERAL = 9, + // 10-255 reserved for future types +}; + +/** + * @brief Writes the enumerator name of an ObjectDataType to a stream. + * + * A switch is used instead of a lazily built lookup table, so no table is + * allocated on first use inside this noexcept function. + * + * @param os Destination stream. + * @param type Value to print; a value without a named enumerator prints as + * "UNKNOWN". + * @return Reference to @p os. + */ +inline std::ostream& operator<<(std::ostream& os, + const ObjectDataType& type) noexcept { + switch (type) { + case ObjectDataType::UNKNOWN: + break; + case ObjectDataType::KVCACHE: + return os << "KVCACHE"; + case ObjectDataType::TENSOR: + return os << "TENSOR"; + case ObjectDataType::WEIGHT: + return os << "WEIGHT"; + case ObjectDataType::SAMPLE: + return os << "SAMPLE"; + case ObjectDataType::ACTIVATION: + return os << "ACTIVATION"; + case ObjectDataType::GRADIENT: + return os << "GRADIENT"; + case ObjectDataType::OPTIMIZER_STATE: + return os << "OPTIMIZER_STATE"; + case ObjectDataType::METADATA: + return os << "METADATA"; + case ObjectDataType::GENERAL: + return os << "GENERAL"; + } + return os << "UNKNOWN"; +} +
// Forward declarations class BufferAllocatorBase; class CachelibBufferAllocator; diff --git a/mooncake-store/src/master_service.cpp b/mooncake-store/src/master_service.cpp index 280006091..1fea47cb2 100644 --- a/mooncake-store/src/master_service.cpp +++ b/mooncake-store/src/master_service.cpp @@ -781,7 +781,8 @@ auto MasterService::AllocateAndInsertMetadata( shard->metadata.emplace( std::piecewise_construct, std::forward_as_tuple(key), std::forward_as_tuple(client_id, now, value_length, std::move(replicas), - config.with_soft_pin, config.with_hard_pin)); + config.with_soft_pin, config.with_hard_pin, + config.data_type)); shard->processing_keys.insert(key); return replica_list; @@ -3867,7 +3868,7 @@ MasterService::MetadataSerializer::DeserializeShard(const msgpack::object& obj, metadata_ptr->client_id, metadata_ptr->put_start_time, metadata_ptr->size, metadata_ptr->PopReplicas(), metadata_ptr->soft_pin_timeout.has_value(), - metadata_ptr->IsHardPinned())); + metadata_ptr->IsHardPinned(), metadata_ptr->data_type)); it->second.lease_timeout = metadata_ptr->lease_timeout; it->second.soft_pin_timeout = metadata_ptr->soft_pin_timeout; @@ -3882,12 +3883,12 @@ MasterService::MetadataSerializer::SerializeMetadata( MsgpackPacker& packer) const { // Pack ObjectMetadata using array structure for efficiency // Format: [client_id, put_start_time, size, lease_timeout, - // has_soft_pin_timeout, soft_pin_timeout, replicas_count, replicas..., - // hard_pinned] + // has_soft_pin_timeout, soft_pin_timeout, replicas_count, data_type, + // replicas..., hard_pinned] - size_t array_size = 8; // client_id, put_start_time, size, lease_timeout, + size_t array_size = 9; // client_id, put_start_time, size, lease_timeout, // has_soft_pin_timeout, soft_pin_timeout, - // replicas_count + hard_pinned + // replicas_count, data_type, hard_pinned array_size += metadata.CountReplicas(); // One element per replica packer.pack_array(array_size); @@ -3927,6
+3928,9 @@ MasterService::MetadataSerializer::SerializeMetadata( // Serialize replicas count packer.pack(static_cast(metadata.CountReplicas())); + // Serialize data_type + packer.pack(static_cast(metadata.data_type)); + // Serialize replicas for (const auto& replica : metadata.GetAllReplicas()) { auto result = Serializer::serialize( @@ -3951,9 +3955,8 @@ MasterService::MetadataSerializer::DeserializeMetadata( "deserialize ObjectMetadata state is not an array")); } - // Need at least 7 elements: client_id, put_start_time, size, lease_timeout, - // has_soft_pin_timeout, soft_pin_timeout, replicas_count - // (8th element = hard_pinned is optional for backward compat) + // Need at least 7 elements for old format, 8 for data_type-only or + // hard_pinned-only, 9 for newest format with both if (obj.via.array.size < 7) { return tl::unexpected(SerializationError( ErrorCode::DESERIALIZE_FAIL, @@ -3986,15 +3989,27 @@ MasterService::MetadataSerializer::DeserializeMetadata( // Deserialize replicas count uint32_t replicas_count = array[index++].as(); - // Array size: 7 + replicas_count (old format) or 8 + replicas_count (new - // format with hard_pinned) - if (obj.via.array.size != 7 + replicas_count && - obj.via.array.size != 8 + replicas_count) { + // Format detection: + // v1 (old): 7 + replicas_count — no data_type, no hard_pinned + // v2 (main): 8 + replicas_count — hard_pinned after replicas + // v3 (newest): 9 + replicas_count — adds data_type & hard_pinned + const uint32_t total_elements = obj.via.array.size; + const bool is_v1 = (total_elements == 7 + replicas_count); + const bool is_v2 = (total_elements == 8 + replicas_count); + const bool is_v3 = (total_elements == 9 + replicas_count); + + if (!is_v1 && !is_v2 && !is_v3) { return tl::unexpected(SerializationError( ErrorCode::DESERIALIZE_FAIL, "deserialize ObjectMetadata array size mismatch")); } + // v3 has data_type right after replicas_count + ObjectDataType data_type = ObjectDataType::UNKNOWN; + if (is_v3) { + 
data_type = static_cast(array[index++].as()); + } + // Deserialize replicas std::vector replicas; replicas.reserve(replicas_count); @@ -4020,7 +4035,7 @@ MasterService::MetadataSerializer::DeserializeMetadata( client_id, std::chrono::system_clock::time_point( std::chrono::milliseconds(put_start_time_timestamp)), - size, std::move(replicas), enable_soft_pin, is_hard_pinned); + size, std::move(replicas), enable_soft_pin, is_hard_pinned, data_type); metadata->lease_timeout = std::chrono::system_clock::time_point( std::chrono::milliseconds(lease_timestamp)); diff --git a/mooncake-store/tests/CMakeLists.txt b/mooncake-store/tests/CMakeLists.txt index 2f2d535b5..7325f48b0 100644 --- a/mooncake-store/tests/CMakeLists.txt +++ b/mooncake-store/tests/CMakeLists.txt @@ -77,6 +77,7 @@ add_store_test(task_executor_test task_executor_test.cpp) add_store_test(task_integration_test task_integration_test.cpp) add_store_test(dummy_client_get_buffer_test dummy_client_get_buffer_test.cpp) add_store_test(health_check_test health_check_test.cpp) +add_store_test(object_data_type_test object_data_type_test.cpp) add_subdirectory(e2e) add_executable(high_availability_test ha/leadership/high_availability_test.cpp) diff --git a/mooncake-store/tests/object_data_type_test.cpp b/mooncake-store/tests/object_data_type_test.cpp new file mode 100644 index 000000000..fd0de441a --- /dev/null +++ b/mooncake-store/tests/object_data_type_test.cpp @@ -0,0 +1,157 @@ +#include "types.h" +#include "replica.h" +#include "master_service.h" + +#include +#include + +#include +#include + +namespace mooncake::test { + +class ObjectDataTypeTest : public ::testing::Test { + protected: + void SetUp() override { + google::InitGoogleLogging("ObjectDataTypeTest"); + FLAGS_logtostderr = true; + } + + void TearDown() override { google::ShutdownGoogleLogging(); } + + static constexpr size_t kDefaultSegmentBase = 0x300000000; + static constexpr size_t kDefaultSegmentSize = 1024 * 1024 * 16; + + Segment
MakeSegment(std::string name = "test_segment", + size_t base = kDefaultSegmentBase, + size_t size = kDefaultSegmentSize) const { + Segment segment; + segment.id = generate_uuid(); + segment.name = std::move(name); + segment.base = base; + segment.size = size; + segment.te_endpoint = segment.name; + return segment; + } +}; + +// Verify enum values match the RFC spec (pins every enumerator to 0..9) +TEST_F(ObjectDataTypeTest, EnumValues) { + EXPECT_EQ(static_cast(ObjectDataType::UNKNOWN), 0); + EXPECT_EQ(static_cast(ObjectDataType::KVCACHE), 1); + EXPECT_EQ(static_cast(ObjectDataType::TENSOR), 2); + EXPECT_EQ(static_cast(ObjectDataType::WEIGHT), 3); + EXPECT_EQ(static_cast(ObjectDataType::SAMPLE), 4); + EXPECT_EQ(static_cast(ObjectDataType::ACTIVATION), 5); + EXPECT_EQ(static_cast(ObjectDataType::GRADIENT), 6); + EXPECT_EQ(static_cast(ObjectDataType::OPTIMIZER_STATE), 7); + EXPECT_EQ(static_cast(ObjectDataType::METADATA), 8); + EXPECT_EQ(static_cast(ObjectDataType::GENERAL), 9); +} + +// Verify stream operator produces readable output +TEST_F(ObjectDataTypeTest, StreamOperator) { + std::ostringstream oss; + oss << ObjectDataType::KVCACHE; + EXPECT_EQ(oss.str(), "KVCACHE"); + + oss.str(""); + oss << ObjectDataType::UNKNOWN; + EXPECT_EQ(oss.str(), "UNKNOWN"); + + oss.str(""); + oss << ObjectDataType::OPTIMIZER_STATE; + EXPECT_EQ(oss.str(), "OPTIMIZER_STATE"); + + oss.str(""); + oss << ObjectDataType::GENERAL; + EXPECT_EQ(oss.str(), "GENERAL"); + + // Out-of-range value should print "UNKNOWN" + oss.str(""); + oss << static_cast(200); + EXPECT_EQ(oss.str(), "UNKNOWN"); +} + +// ReplicateConfig defaults to UNKNOWN +TEST_F(ObjectDataTypeTest, ReplicateConfigDefaultDataType) { + ReplicateConfig config; + EXPECT_EQ(config.data_type, ObjectDataType::UNKNOWN); +} + +// ReplicateConfig can be set to other types +TEST_F(ObjectDataTypeTest, ReplicateConfigSetDataType) { + ReplicateConfig config; + config.data_type = ObjectDataType::WEIGHT; + EXPECT_EQ(config.data_type, ObjectDataType::WEIGHT); +} + +//
ReplicateConfig stream output includes data_type +TEST_F(ObjectDataTypeTest, ReplicateConfigStreamIncludesDataType) { + ReplicateConfig config; + config.data_type = ObjectDataType::TENSOR; + std::ostringstream oss; + oss << config; + EXPECT_NE(oss.str().find("data_type: TENSOR"), std::string::npos); +} + +// PutStart/PutEnd succeed with a non-default data_type (metadata not inspected) +TEST_F(ObjectDataTypeTest, PutStartWithDataType) { + std::unique_ptr service(new MasterService()); + Segment segment = MakeSegment(); + UUID client_id = generate_uuid(); + auto mount_result = service->MountSegment(segment, client_id); + ASSERT_TRUE(mount_result.has_value()); + + UUID put_client = generate_uuid(); + + // Put with WEIGHT type + ReplicateConfig config; + config.replica_num = 1; + config.data_type = ObjectDataType::WEIGHT; + + auto result = service->PutStart(put_client, "key_weight", 1024, config); + ASSERT_TRUE(result.has_value()); + EXPECT_FALSE(result.value().empty()); + + auto end_result = + service->PutEnd(put_client, "key_weight", ReplicaType::MEMORY); + EXPECT_TRUE(end_result.has_value()); +} + +// PutStart with default UNKNOWN data_type still works (backward compat) +TEST_F(ObjectDataTypeTest, PutStartDefaultDataType) { + std::unique_ptr service(new MasterService()); + Segment segment = MakeSegment(); + UUID client_id = generate_uuid(); + auto mount_result = service->MountSegment(segment, client_id); + ASSERT_TRUE(mount_result.has_value()); + + UUID put_client = generate_uuid(); + ReplicateConfig config; + config.replica_num = 1; + // data_type left as default (UNKNOWN) + + auto result = service->PutStart(put_client, "key_default", 1024, config); + ASSERT_TRUE(result.has_value()); + EXPECT_FALSE(result.value().empty()); +} + +// Verify all enum values can roundtrip through uint8_t cast +TEST_F(ObjectDataTypeTest, EnumRoundtrip) { + std::vector all_types = { + ObjectDataType::UNKNOWN, ObjectDataType::KVCACHE, + ObjectDataType::TENSOR, ObjectDataType::WEIGHT, + ObjectDataType::SAMPLE,
ObjectDataType::ACTIVATION, + ObjectDataType::GRADIENT, ObjectDataType::OPTIMIZER_STATE, + ObjectDataType::METADATA, ObjectDataType::GENERAL, + }; + + for (auto type : all_types) { + uint8_t raw = static_cast(type); + auto recovered = static_cast(raw); + EXPECT_EQ(type, recovered); + } +} + +} // namespace mooncake::test