From c876ea37f162a45873451b036c77e7c2a60870d8 Mon Sep 17 00:00:00 2001 From: openhands Date: Thu, 19 Mar 2026 05:57:00 +0000 Subject: [PATCH] fix: suppress SMOTE for multiclass targets (issue #36) The SMOTE guard in template_based_adaptation.py only blocked multi-target cases (len(target_columns) > 1) but allowed SMOTE through for multiclass tasks (task.is_multiclass=True). SMOTE is only valid for binary classification, so extend the guard to also skip it when the task is multiclass. Also add test_smote_not_recommended_for_multiclass to cover the edge case where a binary-imbalanced training split would trigger SMOTE but the full dataset has more than 2 classes (is_multiclass=True). Co-authored-by: openhands Signed-off-by: openhands --- .../generation/template_based_adaptation.py | 7 +++- .../test_generatedcode_additional_patterns.py | 41 +++++++++++++++++++ 2 files changed, 46 insertions(+), 2 deletions(-) diff --git a/sapientml_core/adaptation/generation/template_based_adaptation.py b/sapientml_core/adaptation/generation/template_based_adaptation.py index cf5f909..c72d6a2 100644 --- a/sapientml_core/adaptation/generation/template_based_adaptation.py +++ b/sapientml_core/adaptation/generation/template_based_adaptation.py @@ -303,8 +303,11 @@ def _populate_preprocessing_components_in_pipeline(self, preprocessing_component columns = self.dataset_summary.columns for component in preprocessing_components: - # handle special case for SMOTE, don't apply if target columns are > 1. SMOTE fails in such cases - if "PREPROCESS:Balancing:SMOTE:imblearn" == component.label_name and len(self.task.target_columns) > 1: + # handle special case for SMOTE: skip if multiple target columns (SMOTE fails), + # or if the task is multiclass (SMOTE is only valid for binary classification) + if "PREPROCESS:Balancing:SMOTE:imblearn" == component.label_name and ( + len(self.task.target_columns) > 1 or self.task.is_multiclass + ): continue rel_cols = component.get_relevant_columns( diff --git a/tests/sapientml/test_generatedcode_additional_patterns.py b/tests/sapientml/test_generatedcode_additional_patterns.py index eeb3a5c..e619efe 100644 --- a/tests/sapientml/test_generatedcode_additional_patterns.py +++ b/tests/sapientml/test_generatedcode_additional_patterns.py @@ -779,6 +779,47 @@ def test_additional_classifier_works_with_preprocess( assert "StandardScaler" in code_for_test +@pytest.mark.parametrize("adaptation_metric", ["f1"]) +@pytest.mark.parametrize("target_col", ["target_category_binary_imbalance"]) +def test_smote_not_recommended_for_multiclass( + adaptation_metric, + target_col, + setup_request_parameters, + make_tempdir, + execute_pipeline, + execute_code_for_test_ipynb, + test_data, +): + """SMOTE must be suppressed when task.is_multiclass=True. + + Tests the edge case where the full dataset has more than 2 target classes but the + training split happens to contain only 2, causing _get_target_imbalance_score to + return a high score that would normally trigger SMOTE. + """ + task, config, dataset = setup_request_parameters() + + df = test_data + config.n_models = 1 + + task.task_type = "classification" + task.adaptation_metric = adaptation_metric + task.target_columns = [target_col] + # Force multiclass flag even though the column is binary, simulating the scenario where + # the full dataset has a 3rd rare class absent from the training split. + task.is_multiclass = True + + dataset.training_dataframe = df + dataset.training_data_path = (fxdir / "datasets" / "testdata_df.csv").as_posix() + dataset.ignore_columns.extend( + ["explanatory_multi_category_num", "target_category_multi_num", "target_category_binary_num"] + ) + temp_dir = make_tempdir + pipeline_results = execute_pipeline(dataset, task, config, temp_dir, initial_timeout=60) + test_result_df = execute_code_for_test_ipynb(pipeline_results, temp_dir) + code_for_test = test_result_df.loc[0, "code_for_test"] + assert "SMOTE" not in code_for_test + + @pytest.mark.parametrize("adaptation_metric", ["MCC", "QWK"]) @pytest.mark.parametrize("target_col", ["target_category_binary_num"]) def test_additional_classifier_category_binary_num_use_proba_with_metric_default_noproba(