Skip to content

Commit 8be215c

Browse files
committed
[TMVA] Fix failing rbdt_xgboost test
* Avoid warnings from an opened file that is never closed. * Don't assume the number of features is in the `_features_count` attribute (that attribute doesn't exist with xgboost 2.0). * Support the `"reg:squarederror"` objective, which is the default regression objective in xgboost 2.0.
1 parent 30f27a5 commit 8be215c

File tree

2 files changed

+6
-14
lines changed

2 files changed

+6
-14
lines changed

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_tmva/_tree_inference.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
import cppyy
1313

1414

15-
def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs=None, tmp_path="/tmp", threshold_dtype="float"):
15+
def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs, tmp_path="/tmp", threshold_dtype="float"):
1616
# Extract objective
1717
objective_map = {
1818
"multi:softprob": "softmax", # Naming the objective softmax is more common today
1919
"binary:logistic": "logistic",
2020
"reg:linear": "identity",
21+
"reg:squarederror": "identity",
2122
}
2223
model_objective = xgb_model.objective
2324
if not model_objective in objective_map:
@@ -48,7 +49,8 @@ def SaveXGBoost(self, xgb_model, key_name, output_path, num_inputs=None, tmp_pat
4849

4950
import json
5051

51-
forest = json.load(open(tmp_path, "r"))
52+
with open(tmp_path, "r") as json_file:
53+
forest = json.load(json_file)
5254

5355
# Determine whether the model has a bias parameter and write bias trees
5456
if hasattr(xgb_model, "base_score") and "reg:" in model_objective:
@@ -96,16 +98,6 @@ def fill_arrays(node, index, inputs_base, thresholds_base):
9698
for i in range(num_trees):
9799
outputs[i] = int(i % num_outputs)
98100

99-
# Determine number of input variables
100-
if not num_inputs is None:
101-
pass
102-
elif hasattr(xgb_model, "_features_count"):
103-
num_inputs = xgb_model._features_count
104-
else:
105-
raise Exception(
106-
"Failed to get number of input variables from XGBoost model. Please provide the additional keyword argument 'num_inputs' to this function."
107-
)
108-
109101
# Store arrays in a ROOT file in a folder with the given key name
110102
# TODO: Write single values as simple integers and not vectors.
111103
f = cppyy.gbl.TFile(output_path, "RECREATE")

tmva/tmva/test/rbdt_xgboost.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _test_XGBBinary(backend, label):
2121
x, y = create_dataset(1000, 10, 2)
2222
xgb = xgboost.XGBClassifier(n_estimators=100, max_depth=3)
2323
xgb.fit(x, y)
24-
ROOT.TMVA.Experimental.SaveXGBoost(xgb, "myModel", "testXGBBinary{}.root".format(label))
24+
ROOT.TMVA.Experimental.SaveXGBoost(xgb, "myModel", "testXGBBinary{}.root".format(label), num_inputs=10)
2525
bdt = ROOT.TMVA.Experimental.RBDT[backend]("myModel", "testXGBBinary{}.root".format(label))
2626

2727
y_xgb = xgb.predict_proba(x)[:, 1].squeeze()
@@ -51,7 +51,7 @@ def _test_XGBMulticlass(backend, label):
5151
x, y = create_dataset(1000, 10, 3)
5252
xgb = xgboost.XGBClassifier(n_estimators=100, max_depth=3)
5353
xgb.fit(x, y)
54-
ROOT.TMVA.Experimental.SaveXGBoost(xgb, "myModel", "testXGBMulticlass{}.root".format(label))
54+
ROOT.TMVA.Experimental.SaveXGBoost(xgb, "myModel", "testXGBMulticlass{}.root".format(label), num_inputs=10)
5555
bdt = ROOT.TMVA.Experimental.RBDT[backend]("myModel", "testXGBMulticlass{}.root".format(label))
5656

5757
y_xgb = xgb.predict_proba(x)

0 commit comments

Comments
 (0)