diff --git a/openml/datasets/functions.py b/openml/datasets/functions.py index c2eb8ee75..26c705eca 100644 --- a/openml/datasets/functions.py +++ b/openml/datasets/functions.py @@ -333,6 +333,29 @@ def _load_features_from_file(features_file: str) -> Dict: return xml_dict["oml:data_features"] +def _expand_parameter(parameter: Union[str, List[str]]) -> List[str]: + expanded_parameter = [] + if isinstance(parameter, str): + expanded_parameter = [x.strip() for x in parameter.split(",")] + elif isinstance(parameter, list): + expanded_parameter = parameter + return expanded_parameter + + +def _validated_data_attributes( + attributes: List[str], data_attributes: List[str], parameter_name: str +) -> None: + for attribute_ in attributes: + is_attribute_a_data_attribute = any([attr[0] == attribute_ for attr in data_attributes]) + if not is_attribute_a_data_attribute: + raise ValueError( + "all attribute of '{}' should be one of the data attribute. " + " Got '{}' while candidates are {}.".format( + parameter_name, attribute_, [attr[0] for attr in data_attributes] + ) + ) + + def check_datasets_active( dataset_ids: List[int], raise_error_if_not_exist: bool = True, @@ -646,6 +669,7 @@ def create_dataset( ignore_attribute : str | list Attributes that should be excluded in modelling, such as identifiers and indexes. + Can have multiple values, comma separated. citation : str Reference(s) that should be cited when building on this data. 
version_label : str, optional @@ -697,6 +721,11 @@ def create_dataset( attributes_[attr_idx] = (attr_name, attributes[attr_name]) else: attributes_ = attributes + ignore_attributes = _expand_parameter(ignore_attribute) + _validated_data_attributes(ignore_attributes, attributes_, "ignore_attribute") + + default_target_attributes = _expand_parameter(default_target_attribute) + _validated_data_attributes(default_target_attributes, attributes_, "default_target_attribute") if row_id_attribute is not None: is_row_id_an_attribute = any([attr[0] == row_id_attribute for attr in attributes_]) diff --git a/tests/test_datasets/test_dataset_functions.py b/tests/test_datasets/test_dataset_functions.py index 707b6f9c5..38b035fcf 100644 --- a/tests/test_datasets/test_dataset_functions.py +++ b/tests/test_datasets/test_dataset_functions.py @@ -901,7 +901,6 @@ def test_create_dataset_pandas(self): collection_date = "01-01-2018" language = "English" licence = "MIT" - default_target_attribute = "play" citation = "None" original_data_url = "http://openml.github.io/openml-python" paper_url = "http://openml.github.io/openml-python" @@ -913,7 +912,7 @@ def test_create_dataset_pandas(self): collection_date=collection_date, language=language, licence=licence, - default_target_attribute=default_target_attribute, + default_target_attribute="play", row_id_attribute=None, ignore_attribute=None, citation=citation, @@ -948,7 +947,7 @@ def test_create_dataset_pandas(self): collection_date=collection_date, language=language, licence=licence, - default_target_attribute=default_target_attribute, + default_target_attribute="y", row_id_attribute=None, ignore_attribute=None, citation=citation, @@ -984,7 +983,7 @@ def test_create_dataset_pandas(self): collection_date=collection_date, language=language, licence=licence, - default_target_attribute=default_target_attribute, + default_target_attribute="rnd_str", row_id_attribute=None, ignore_attribute=None, citation=citation, @@ -1420,3 +1419,118 @@ def 
def _attribute_validation_kwargs(default_target_attribute, row_id_attribute, ignore_attribute):
    """Build the ``create_dataset`` keyword arguments shared by the attribute tests.

    The data is a small five-row pandas DataFrame with the columns
    ``rnd_str``, ``outlook``, ``temperature``, ``humidity``, ``windy`` and
    ``play``; the three attribute parameters are passed through unchanged.
    """
    data = [
        ["a", "sunny", 85.0, 85.0, "FALSE", "no"],
        ["b", "sunny", 80.0, 90.0, "TRUE", "no"],
        ["c", "overcast", 83.0, 86.0, "FALSE", "yes"],
        ["d", "rainy", 70.0, 96.0, "FALSE", "yes"],
        ["e", "rainy", 68.0, 80.0, "FALSE", "yes"],
    ]
    column_names = ["rnd_str", "outlook", "temperature", "humidity", "windy", "play"]
    df = pd.DataFrame(data, columns=column_names)
    # enforce the type of each column
    df["outlook"] = df["outlook"].astype("category")
    df["windy"] = df["windy"].astype("bool")
    df["play"] = df["play"].astype("category")
    # meta-information
    return dict(
        name="pandas_testing_dataset",
        description="Synthetic dataset created from a Pandas DataFrame",
        creator="OpenML tester",
        contributor=None,
        collection_date="01-01-2018",
        language="English",
        licence="MIT",
        default_target_attribute=default_target_attribute,
        row_id_attribute=row_id_attribute,
        ignore_attribute=ignore_attribute,
        citation="None",
        attributes="auto",
        data=df,
        version_label="test",
        original_data_url="http://openml.github.io/openml-python",
        paper_url="http://openml.github.io/openml-python",
    )


@pytest.mark.parametrize(
    "default_target_attribute,row_id_attribute,ignore_attribute",
    [
        ("wrong", None, None),
        (None, "wrong", None),
        (None, None, "wrong"),
        ("wrong,sunny", None, None),
        (None, None, "wrong,sunny"),
        (["wrong", "sunny"], None, None),
        (None, None, ["wrong", "sunny"]),
    ],
)
def test_invalid_attribute_validations(
    default_target_attribute, row_id_attribute, ignore_attribute
):
    """Names that are not columns of the data must make create_dataset raise ValueError."""
    # Build the arguments outside the raises-context so that only an error
    # coming from create_dataset itself can satisfy the assertion.
    kwargs = _attribute_validation_kwargs(
        default_target_attribute, row_id_attribute, ignore_attribute
    )
    with pytest.raises(ValueError, match="should be one of the data attribute"):
        openml.datasets.functions.create_dataset(**kwargs)


@pytest.mark.parametrize(
    "default_target_attribute,row_id_attribute,ignore_attribute",
    [
        ("outlook", None, None),
        (None, "outlook", None),
        (None, None, "outlook"),
        ("outlook,windy", None, None),
        (None, None, "outlook,windy"),
        (["outlook", "windy"], None, None),
        (None, None, ["outlook", "windy"]),
    ],
)
def test_valid_attribute_validations(default_target_attribute, row_id_attribute, ignore_attribute):
    """Names that are columns of the data must be accepted without an exception."""
    kwargs = _attribute_validation_kwargs(
        default_target_attribute, row_id_attribute, ignore_attribute
    )
    openml.datasets.functions.create_dataset(**kwargs)