From 9bd8ac5dee40ea616e7c6813b8e69d275f0e5563 Mon Sep 17 00:00:00 2001 From: matthew-peacock Date: Wed, 2 Aug 2023 13:58:46 +1000 Subject: [PATCH 1/4] Fixes for GOSS data sample strategy. 1) Value of data_sample_strategy was not written out in Config::SaveMembersToString() 2) GOSSStrategy->Bagging may modify value of bag_data_cnt_ during training, which may mean tmp_grad_ and tmp_hess_ need resizing in RF::TrainOneIter --- src/boosting/rf.hpp | 6 ++++++ src/io/config_auto.cpp | 1 + 2 files changed, 7 insertions(+) diff --git a/src/boosting/rf.hpp b/src/boosting/rf.hpp index 9a87e982483e..88ece154e432 100644 --- a/src/boosting/rf.hpp +++ b/src/boosting/rf.hpp @@ -115,6 +115,12 @@ class RF : public GBDT { const data_size_t bag_data_cnt = data_sample_strategy_->bag_data_cnt(); const std::vector>& bag_data_indices = data_sample_strategy_->bag_data_indices(); + // GOSSStrategy->Bagging may modify value of bag_data_cnt_ + if (is_use_subset && bag_data_cnt < num_data_) { + tmp_grad_.resize(num_data_); + tmp_hess_.resize(num_data_); + } + CHECK_EQ(gradients, nullptr); CHECK_EQ(hessians, nullptr); diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 0906ba4b6439..7e9fae5630bd 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -673,6 +673,7 @@ std::string Config::SaveMembersToString() const { str_buf << "[max_depth: " << max_depth << "]\n"; str_buf << "[min_data_in_leaf: " << min_data_in_leaf << "]\n"; str_buf << "[min_sum_hessian_in_leaf: " << min_sum_hessian_in_leaf << "]\n"; + str_buf << "[data_sample_strategy: " << data_sample_strategy << "]\n"; str_buf << "[bagging_fraction: " << bagging_fraction << "]\n"; str_buf << "[pos_bagging_fraction: " << pos_bagging_fraction << "]\n"; str_buf << "[neg_bagging_fraction: " << neg_bagging_fraction << "]\n"; From d9a4ddf4684125fdb1745c1d99f17d7a36d6fc9c Mon Sep 17 00:00:00 2001 From: mjmckp Date: Tue, 15 Aug 2023 12:48:39 +1000 Subject: [PATCH 2/4] Remove doc_only from data_sample_strategy This is required by LightGBMNet --- include/LightGBM/config.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index e01578396259..f470a22d6f8a 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -160,7 +160,6 @@ struct Config { // descl2 = **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations std::string boosting = "gbdt"; - // [doc-only] // type = enum // options = bagging, goss // desc = ``bagging``, Randomly Bagging Sampling From 3d3188b2ffc7c3b40a6f56ef3e0c94dff62f047f Mon Sep 17 00:00:00 2001 From: mjmckp Date: Thu, 24 Aug 2023 16:12:33 +1000 Subject: [PATCH 3/4] Revert change (will be handled in separate PR) --- include/LightGBM/config.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/LightGBM/config.h b/include/LightGBM/config.h index f470a22d6f8a..e01578396259 100644 --- a/include/LightGBM/config.h +++ b/include/LightGBM/config.h @@ -160,6 +160,7 @@ struct Config { // descl2 = **Note**: internally, LightGBM uses ``gbdt`` mode for the first ``1 / learning_rate`` iterations std::string boosting = "gbdt"; + // [doc-only] // type = enum // options = bagging, goss // desc = ``bagging``, Randomly Bagging Sampling From 6e1ca6f727a6d89271f6cc43d5bc571ff625c2c3 Mon Sep 17 00:00:00 2001 From: mjmckp Date: Thu, 24 Aug 2023 16:13:06 +1000 Subject: [PATCH 4/4] Revert change (will be handled in separate PR) --- src/io/config_auto.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/src/io/config_auto.cpp b/src/io/config_auto.cpp index 7e9fae5630bd..0906ba4b6439 100644 --- a/src/io/config_auto.cpp +++ b/src/io/config_auto.cpp @@ -673,7 +673,6 @@ std::string Config::SaveMembersToString() const { str_buf << "[max_depth: " << max_depth << "]\n"; str_buf << "[min_data_in_leaf: " << min_data_in_leaf << "]\n"; str_buf << "[min_sum_hessian_in_leaf: " << min_sum_hessian_in_leaf << "]\n"; - str_buf << "[data_sample_strategy: " << data_sample_strategy << "]\n"; str_buf << "[bagging_fraction: " << bagging_fraction << "]\n"; str_buf << "[pos_bagging_fraction: " << pos_bagging_fraction << "]\n"; str_buf << "[neg_bagging_fraction: " << neg_bagging_fraction << "]\n";